/**
 * Copyright (c) 2019, Sinlmao (888@1st.com).
 * <p>
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 * <p>
 * http://www.apache.org/licenses/LICENSE-2.0
 * <p>
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package cn.sinlmao.commons.network.ssl;

import javax.net.ssl.*;
import java.io.IOException;
import java.net.InetAddress;
import java.net.Socket;
import java.security.*;
import java.security.cert.TrustAnchor;
import java.security.cert.X509Certificate;
import java.util.HashSet;
import java.util.Set;

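/**
 * program Sinlmao Commons Network Utils
 * description An {@link SSLSocketFactory} that builds its own {@link SSLContext} and, when a
 * protocol list is configured, enables only those protocol versions on every socket it creates.
 * create 2020-06-18 01:28
 *
 * @author Sinlmao
 */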
public class ProtocolSSLFactory extends SSLSocketFactory {

    /** Protocol name passed to {@link SSLContext#getInstance(String)} when the caller gives none. */
    public final static String PROTOCOL_DEFAULT = "TLS";

    /** Default key-store format name; not used inside this class (presumably for callers loading client certificates). */
    public final static String CERT_FORMAT_DEFAULT = "PKCS12";

    /** Default certificate standard name; not used inside this class (presumably for callers loading certificates). */
    public final static String CERT_STANDARD_DEFAULT = "X509";

    ////////////////////////////////////////////////////////////////////////////////////////////////////////////////////

    /** Protocol versions enabled on every created socket; {@code null} or empty keeps the socket's defaults. */
    private final String[] protocols;
    /** The SSL context whose socket factory is wrapped by {@link #factory}. */
    private final SSLContext context;
    /** Delegate that actually creates the sockets. */
    private final SSLSocketFactory factory;

    /**
     * Creates a factory using the given key managers, default trust managers, a fresh
     * {@link SecureRandom}, the {@link #PROTOCOL_DEFAULT} context and no protocol restriction.
     *
     * @param keyManagers key managers; {@code null} or empty selects the defaults
     * @return a new factory
     * @throws Exception if the SSL context or default managers cannot be created
     */
    public static ProtocolSSLFactory getInstance(KeyManager[] keyManagers) throws Exception {
        return getInstance(keyManagers, null, new SecureRandom(), PROTOCOL_DEFAULT, new String[0]);
    }

    /**
     * Creates a factory using the given key and trust managers, a fresh {@link SecureRandom},
     * the {@link #PROTOCOL_DEFAULT} context and no protocol restriction.
     *
     * @param keyManagers   key managers; {@code null} or empty selects the defaults
     * @param trustManagers trust managers; {@code null} or empty selects the defaults
     * @return a new factory
     * @throws Exception if the SSL context or default managers cannot be created
     */
    public static ProtocolSSLFactory getInstance(KeyManager[] keyManagers, TrustManager[] trustManagers) throws Exception {
        return getInstance(keyManagers, trustManagers, new SecureRandom(), PROTOCOL_DEFAULT, new String[0]);
    }

    /**
     * Creates a factory using the given managers and random source, the {@link #PROTOCOL_DEFAULT}
     * context and no protocol restriction.
     *
     * @param keyManagers   key managers; {@code null} or empty selects the defaults
     * @param trustManagers trust managers; {@code null} or empty selects the defaults
     * @param secureRandom  random source handed to {@link SSLContext#init}; may be {@code null}
     * @return a new factory
     * @throws Exception if the SSL context or default managers cannot be created
     */
    public static ProtocolSSLFactory getInstance(KeyManager[] keyManagers, TrustManager[] trustManagers, SecureRandom secureRandom) throws Exception {
        return getInstance(keyManagers, trustManagers, secureRandom, PROTOCOL_DEFAULT, new String[0]);
    }

    /**
     * Creates a factory using the {@link #PROTOCOL_DEFAULT} context whose sockets enable only
     * {@code protocols}.
     *
     * <p>NOTE(review): this overload and the five-argument one are both variable-arity, so a call
     * listing protocol names as separate {@code String} arguments (e.g.
     * {@code getInstance(km, tm, rnd, "TLSv1.2")}) is rejected by the compiler as ambiguous.
     * Pass the names as an explicit {@code String[]} to select this overload.
     *
     * @param keyManagers   key managers; {@code null} or empty selects the defaults
     * @param trustManagers trust managers; {@code null} or empty selects the defaults
     * @param secureRandom  random source handed to {@link SSLContext#init}; may be {@code null}
     * @param protocols     protocol versions to enable; {@code null} or empty keeps socket defaults
     * @return a new factory
     * @throws Exception if the SSL context or default managers cannot be created
     */
    public static ProtocolSSLFactory getInstance(KeyManager[] keyManagers, TrustManager[] trustManagers, SecureRandom secureRandom, String... protocols) throws Exception {
        return getInstance(keyManagers, trustManagers, secureRandom, PROTOCOL_DEFAULT, protocols);
    }

    /**
     * Creates a fully configured factory.
     *
     * @param keyManagers     key managers; {@code null} or empty selects the defaults
     * @param trustManagers   trust managers; {@code null} or empty selects the defaults
     * @param secureRandom    random source handed to {@link SSLContext#init}; may be {@code null}
     * @param contextProtocol protocol for {@link SSLContext#getInstance(String)};
     *                        {@code null} selects {@link #PROTOCOL_DEFAULT}
     * @param protocols       protocol versions to enable; {@code null} or empty keeps socket defaults
     * @return a new factory
     * @throws Exception if the SSL context or default managers cannot be created
     */
    public static ProtocolSSLFactory getInstance(KeyManager[] keyManagers, TrustManager[] trustManagers, SecureRandom secureRandom,
                                                 String contextProtocol, String... protocols) throws Exception {
        return new ProtocolSSLFactory(keyManagers, trustManagers, secureRandom, contextProtocol, protocols);
    }

    ////////////////////////////////////////////////////////////////////////////////////////////////////////////////////

    private ProtocolSSLFactory(KeyManager[] keyManagers, TrustManager[] trustManagers, SecureRandom secureRandom,
                               String contextProtocol, String... protocols) throws Exception {

        super();

        // Only build the default managers when the caller did not supply any.
        if (keyManagers == null || keyManagers.length == 0) {
            keyManagers = defaultKeyManagers();
        }
        if (trustManagers == null || trustManagers.length == 0) {
            trustManagers = defaultTrustManagers();
        }

        // SSLContext.getInstance throws on null, so fall back to the class default instead.
        SSLContext context = SSLContext.getInstance(contextProtocol == null ? PROTOCOL_DEFAULT : contextProtocol);
        context.init(keyManagers, trustManagers, secureRandom);

        // Copy so later changes to the caller's array cannot alter sockets created afterwards.
        this.protocols = protocols == null ? null : protocols.clone();
        this.context = context;
        this.factory = context.getSocketFactory();
    }

    /**
     * Returns the key managers of the default {@link KeyManagerFactory} initialised with no key store.
     */
    private static KeyManager[] defaultKeyManagers()
            throws NoSuchAlgorithmException, KeyStoreException, UnrecoverableKeyException {
        KeyManagerFactory keyManagerFactory = KeyManagerFactory.getInstance(KeyManagerFactory.getDefaultAlgorithm());
        keyManagerFactory.init(null, null);
        return keyManagerFactory.getKeyManagers();
    }

    /**
     * Returns the trust managers of the default {@link TrustManagerFactory} initialised with a
     * {@code null} key store (for the SunJSSE provider this means the JVM's default trust store).
     */
    private static TrustManager[] defaultTrustManagers()
            throws NoSuchAlgorithmException, KeyStoreException {
        TrustManagerFactory trustManagerFactory = TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm());
        trustManagerFactory.init((KeyStore) null);
        return trustManagerFactory.getTrustManagers();
    }

    /** {@inheritDoc} Delegates to the wrapped context's socket factory. */
    @Override
    public String[] getDefaultCipherSuites() {
        return factory.getDefaultCipherSuites();
    }

    /** {@inheritDoc} Delegates to the wrapped context's socket factory. */
    @Override
    public String[] getSupportedCipherSuites() {
        return factory.getSupportedCipherSuites();
    }

    /** Creates an unconnected SSL socket with the configured protocols enabled. */
    @Override
    public Socket createSocket() throws IOException {
        final SSLSocket socket = (SSLSocket) factory.createSocket();
        setProtocols(socket);
        return socket;
    }

    /** Layers an SSL socket over {@code s} and enables the configured protocols. */
    @Override
    public SSLSocket createSocket(Socket s, String host, int port, boolean autoClose) throws IOException {
        final SSLSocket socket = (SSLSocket) factory.createSocket(s, host, port, autoClose);
        setProtocols(socket);
        return socket;
    }

    /** Creates an SSL socket connected to {@code host:port} with the configured protocols enabled. */
    @Override
    public Socket createSocket(String host, int port) throws IOException {
        final SSLSocket socket = (SSLSocket) factory.createSocket(host, port);
        setProtocols(socket);
        return socket;
    }

    /** Creates an SSL socket bound locally and connected to {@code host:port}, with the configured protocols enabled. */
    @Override
    public Socket createSocket(String host, int port, InetAddress localHost, int localPort) throws IOException {
        final SSLSocket socket = (SSLSocket) factory.createSocket(host, port, localHost, localPort);
        setProtocols(socket);
        return socket;
    }

    /** Creates an SSL socket connected to {@code host:port} with the configured protocols enabled. */
    @Override
    public Socket createSocket(InetAddress host, int port) throws IOException {
        final SSLSocket socket = (SSLSocket) factory.createSocket(host, port);
        setProtocols(socket);
        return socket;
    }

    /** Creates an SSL socket bound locally and connected to {@code address:port}, with the configured protocols enabled. */
    @Override
    public Socket createSocket(InetAddress address, int port, InetAddress localAddress, int localPort) throws IOException {
        final SSLSocket socket = (SSLSocket) factory.createSocket(address, port, localAddress, localPort);
        setProtocols(socket);
        return socket;
    }

    /** Restricts {@code socket} to the configured protocols; leaves it untouched when none are configured. */
    private void setProtocols(SSLSocket socket) {
        if (protocols != null && protocols.length > 0) {
            socket.setEnabledProtocols(protocols);
        }
    }

    ////////////////////////////////////////////////////////////////////////////////////////////////////////////////////

    /**
     * Returns one {@link TrustAnchor} per issuer accepted by {@link #getDefaultX509TrustManager()}.
     *
     * @return a new mutable set of trust anchors (possibly empty)
     * @throws NoSuchAlgorithmException if the default trust manager algorithm is unavailable
     * @throws KeyStoreException        if the default trust manager factory cannot be initialised
     * @throws IllegalStateException    if no {@link X509TrustManager} is available
     */
    public static Set<TrustAnchor> getDefaultRootCAs()
            throws NoSuchAlgorithmException, KeyStoreException {
        X509TrustManager x509tm = getDefaultX509TrustManager();
        Set<TrustAnchor> rootCAs = new HashSet<>();
        for (X509Certificate c : x509tm.getAcceptedIssuers()) {
            rootCAs.add(new TrustAnchor(c, null));
        }
        return rootCAs;
    }

    /**
     * Returns the first {@link X509TrustManager} produced by the default {@link TrustManagerFactory}.
     *
     * @return the default X.509 trust manager
     * @throws NoSuchAlgorithmException if the default trust manager algorithm is unavailable
     * @throws KeyStoreException        if the default trust manager factory cannot be initialised
     * @throws IllegalStateException    if the factory yields no {@link X509TrustManager}
     */
    public static X509TrustManager getDefaultX509TrustManager()
            throws NoSuchAlgorithmException, KeyStoreException {
        for (TrustManager tm : defaultTrustManagers()) {
            if (tm instanceof X509TrustManager) {
                return (X509TrustManager) tm;
            }
        }
        throw new IllegalStateException("X509TrustManager is not found");
    }
}