package net.sparkworks.cargo.client.config;

import lombok.extern.slf4j.Slf4j;
import net.sparkworks.common.client.InternalCommunicationRestTemplateFactory;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.core.io.Resource;
import org.springframework.http.MediaType;
import org.springframework.http.converter.FormHttpMessageConverter;
import org.springframework.http.converter.HttpMessageConverter;
import org.springframework.http.converter.ResourceHttpMessageConverter;
import org.springframework.http.converter.StringHttpMessageConverter;
import org.springframework.http.converter.json.MappingJackson2HttpMessageConverter;
import org.springframework.security.oauth2.client.DefaultOAuth2ClientContext;
import org.springframework.security.oauth2.client.OAuth2RestTemplate;
import org.springframework.security.oauth2.client.resource.OAuth2ProtectedResourceDetails;
import org.springframework.security.oauth2.client.token.AccessTokenRequest;
import org.springframework.security.oauth2.client.token.DefaultAccessTokenRequest;
import org.springframework.security.oauth2.client.token.grant.password.ResourceOwnerPasswordResourceDetails;
import org.springframework.security.oauth2.config.annotation.web.configuration.EnableOAuth2Client;
import org.springframework.web.client.RestOperations;
import org.springframework.web.client.RestTemplate;

import javax.net.ssl.HttpsURLConnection;
import javax.net.ssl.SSLContext;
import javax.net.ssl.TrustManager;
import javax.net.ssl.X509TrustManager;
import java.security.KeyManagementException;
import java.security.NoSuchAlgorithmException;
import java.security.cert.X509Certificate;
import java.util.Arrays;

/**
 * Spring configuration that exposes the {@link RestOperations} bean used by the Cargo client.
 *
 * <p>Depending on {@link CargoInternalCommunicationConfiguration#isEnabled()} the bean is either the
 * template supplied by {@link InternalCommunicationRestTemplateFactory} or an {@link OAuth2RestTemplate}
 * built from the values of {@link CargoClientAuthConfig}.
 *
 * <p>The auth and SSL configuration beans are optional injections and may therefore be {@code null}.
 */
@Slf4j
@Configuration
@EnableOAuth2Client
public class CargoRestTemplateConfig {

    /** Name under which the Cargo {@link RestOperations} bean is registered. */
    public static final String SW_CARGO_REST_TEMPLATE_NAME = "cargoRestTemplate";

    /** OAuth2 client settings; {@code null} when no such bean is available. */
    private final CargoClientAuthConfig sparkworksCargoClientAuthConfig;

    /** SSL relaxation flags; {@code null} when no such bean is available. */
    private final CargoSslConfiguration cargoSslConfiguration;

    private final CargoInternalCommunicationConfiguration cargoInternalCommunicationConfiguration;

    /**
     * Creates the configuration.
     *
     * @param sparkworksCargoClientAuthConfig OAuth2 client settings, optional (may be {@code null})
     * @param cargoSslConfiguration SSL relaxation flags, optional (may be {@code null})
     * @param cargoInternalCommunicationConfiguration switch selecting the kind of rest template to build
     */
    public CargoRestTemplateConfig(@Autowired(required = false) CargoClientAuthConfig sparkworksCargoClientAuthConfig,
                                   @Autowired(required = false) CargoSslConfiguration cargoSslConfiguration,
                                   CargoInternalCommunicationConfiguration cargoInternalCommunicationConfiguration) {
        this.sparkworksCargoClientAuthConfig = sparkworksCargoClientAuthConfig;
        this.cargoSslConfiguration = cargoSslConfiguration;
        this.cargoInternalCommunicationConfiguration = cargoInternalCommunicationConfiguration;
    }

    /**
     * Builds the OAuth2 resource description from the injected {@link CargoClientAuthConfig}.
     *
     * @return a freshly created {@link ResourceOwnerPasswordResourceDetails} populated from the auth configuration
     * @throws IllegalStateException if no {@link CargoClientAuthConfig} bean was injected
     */
    public OAuth2ProtectedResourceDetails oAuth2ProtectedResourceDetails() {
        if (sparkworksCargoClientAuthConfig == null) {
            // The auth configuration is an optional injection; fail with a clear message instead of an NPE.
            throw new IllegalStateException(
                    "CargoClientAuthConfig is required to build the OAuth2 resource details but no such bean is available");
        }
        ResourceOwnerPasswordResourceDetails resource = new ResourceOwnerPasswordResourceDetails();
        resource.setAccessTokenUri(sparkworksCargoClientAuthConfig.getAccessTokenUrl());
        resource.setClientId(sparkworksCargoClientAuthConfig.getClientId());
        resource.setClientSecret(sparkworksCargoClientAuthConfig.getClientSecret());
        resource.setGrantType(sparkworksCargoClientAuthConfig.getGrantType());
        resource.setScope(sparkworksCargoClientAuthConfig.getScope());
        resource.setUsername(sparkworksCargoClientAuthConfig.getUsername());
        resource.setPassword(sparkworksCargoClientAuthConfig.getPassword());
        return resource;
    }

    /**
     * Creates the Cargo rest template and installs its message converters.
     *
     * <p>When internal communication is disabled, the optional SSL relaxations are applied first and an
     * {@link OAuth2RestTemplate} is created; otherwise the template from
     * {@link InternalCommunicationRestTemplateFactory} is used.
     *
     * @return the configured rest template
     * @throws IllegalStateException if an OAuth2 template is needed but no {@link CargoClientAuthConfig} is available
     */
    @Bean(name = SW_CARGO_REST_TEMPLATE_NAME)
    public RestOperations restTemplate() {
        final boolean internalCommunication = cargoInternalCommunicationConfiguration.isEnabled();
        if (!internalCommunication) {
            configureSslOnDisabledInternalCommunication();
        }

        // Jackson converter that accepts every media type, also used for multipart parts.
        MappingJackson2HttpMessageConverter jackson = new MappingJackson2HttpMessageConverter();
        jackson.setSupportedMediaTypes(Arrays.asList(MediaType.ALL));
        HttpMessageConverter<Resource> resource = new ResourceHttpMessageConverter();
        FormHttpMessageConverter formHttpMessageConverter = new FormHttpMessageConverter();
        formHttpMessageConverter.addPartConverter(jackson);
        // Original author's note: "This is hope driven programming".
        formHttpMessageConverter.addPartConverter(resource);

        final RestTemplate restTemplate;
        if (internalCommunication) {
            restTemplate = InternalCommunicationRestTemplateFactory.internalCommunicationRestTemplate();
        } else {
            AccessTokenRequest atr = new DefaultAccessTokenRequest();
            restTemplate = new OAuth2RestTemplate(oAuth2ProtectedResourceDetails(), new DefaultOAuth2ClientContext(atr));
        }
        // Converter order is significant for content negotiation and is kept as originally defined:
        // a default Jackson converter precedes the accept-all one.
        restTemplate.setMessageConverters(Arrays.asList(new StringHttpMessageConverter(),
                new MappingJackson2HttpMessageConverter(), jackson, resource, formHttpMessageConverter));
        return restTemplate;
    }

    /**
     * Applies the SSL relaxations requested by {@link CargoSslConfiguration}, if such a bean is present.
     *
     * <p>Both relaxations are installed through static {@link HttpsURLConnection} defaults.
     */
    private void configureSslOnDisabledInternalCommunication() {
        if (cargoSslConfiguration == null) {
            // Optional bean is absent: nothing was requested, keep the default SSL behaviour.
            return;
        }
        if (cargoSslConfiguration.isDisableCnCheck()) {
            log.warn("Disabling HTTPS hostname verification for HttpsURLConnection");
            HttpsURLConnection.setDefaultHostnameVerifier((hostname, session) -> true);
        }
        if (cargoSslConfiguration.isDisableCertValidation()) {
            log.warn("Disabling SSL certificate validation for HttpsURLConnection");
            disableSSLCertificateChecking();
        }
    }

    /**
     * Installs a default SSL socket factory whose trust manager accepts every certificate chain.
     *
     * <p>SECURITY: this removes protection against man-in-the-middle attacks for connections using the
     * {@link HttpsURLConnection} defaults. A failure to install the factory is logged and otherwise ignored.
     */
    private void disableSSLCertificateChecking() {
        TrustManager[] trustAllCerts = new TrustManager[] { new X509TrustManager() {
            @Override
            public X509Certificate[] getAcceptedIssuers() {
                // The X509TrustManager contract requires a non-null (possibly empty) array.
                return new X509Certificate[0];
            }

            @Override
            public void checkClientTrusted(X509Certificate[] chain, String authType) {
                // Intentionally empty: every client certificate is accepted.
            }

            @Override
            public void checkServerTrusted(X509Certificate[] chain, String authType) {
                // Intentionally empty: every server certificate is accepted.
            }
        } };

        try {
            SSLContext sc = SSLContext.getInstance("TLS");
            sc.init(null, trustAllCerts, new java.security.SecureRandom());
            HttpsURLConnection.setDefaultSSLSocketFactory(sc.getSocketFactory());
        } catch (KeyManagementException | NoSuchAlgorithmException e) {
            log.error("Could not disable SSL certificate checking; certificate validation stays enabled", e);
        }
    }

}


