package cn.ac.caict.entity;

import cn.ac.caict.codec.crypto.asymmetric.AsymmetricCodec;
import cn.ac.caict.codec.crypto.asymmetric.AsymmetricPKCS8KeyWrapper;
import cn.ac.caict.codec.crypto.asymmetric.rsa.RSACodec;
import cn.ac.caict.codec.crypto.asymmetric.sm2.SM2Codec;
import cn.ac.caict.codec.crypto.symmetric.AESCodec;
import cn.ac.caict.codec.crypto.symmetric.SM4Codec;
import cn.ac.caict.codec.crypto.symmetric.SymmetricCodec;
import cn.ac.caict.codec.text.Base64Codec;
import cn.ac.caict.codec.text.HexCodec;
import cn.ac.caict.codec.text.TextCodec;
import cn.ac.caict.exception.CaictCryptoException;

import javax.crypto.BadPaddingException;
import javax.crypto.IllegalBlockSizeException;
import java.nio.charset.Charset;
import java.nio.charset.StandardCharsets;
import java.security.NoSuchAlgorithmException;
import java.security.NoSuchProviderException;
import java.security.PrivateKey;
import java.security.PublicKey;
import java.security.spec.InvalidKeySpecException;
import java.util.function.Function;

import static cn.ac.caict.constants.SecureConstants.DataCryptoAlg.AES;
import static cn.ac.caict.constants.SecureConstants.DataCryptoAlg.SM4;
import static cn.ac.caict.constants.SecureConstants.KeyCryptoAlg.RSA;
import static cn.ac.caict.constants.SecureConstants.KeyCryptoAlg.SM2;

/**
 * Cipher codec for request/response entities.
 * <p>
 * Combines an asymmetric algorithm (RSA or SM2) with a symmetric algorithm (AES or SM4):
 * <ul>
 *     <li>the asymmetric algorithm encrypts/decrypts the random symmetric key and
 *     signs/verifies data;</li>
 *     <li>the symmetric algorithm encrypts/decrypts the payload itself.</li>
 * </ul>
 * Encrypted random keys and signatures are hex-encoded; encrypted payloads are rendered
 * with the configured {@link TextCodec} (Base64 when none is supplied).
 * <p>
 * Instances are normally created through {@link #rsa()}, {@link #sm2()} or {@link #keyAlg(String)}.
 */
public class CaictEntityCipherCodec {

    /**
     * Asymmetric codecs used for the random key and for signatures.
     */
    private static final RSACodec RSA_CODEC = new RSACodec();
    private static final SM2Codec SM2_CODEC = new SM2Codec();


    /**
     * Symmetric codecs used for the payload.
     */
    private static final AESCodec AES_CODEC = new AESCodec();
    private static final SM4Codec SM4_CODEC = new SM4Codec();


    /**
     * byte[] -> String text codecs.
     */
    private static final HexCodec HEX_CODEC = new HexCodec();
    private static final Base64Codec BASE64_CODEC = new Base64Codec();

    /**
     * Renders encrypted payload bytes as text (and parses them back).
     */
    private final TextCodec<byte[], String> textCodec;

    private final AsymmetricCodec asymmetricCodec;
    private final SymmetricCodec symmetricCodec;

    /**
     * Own (client) private key: used to sign and to decrypt the random key.
     */
    private final PrivateKey privateKey;

    /**
     * Peer (server) public key:
     * a> encrypts the random key of a request;
     * b> verifies the signature of a response.
     */
    private final PublicKey publicKey;


    private final Encryptor encryptor;
    private final Decryptor decryptor;
    private final Signer signer;


    private final String keyAlg;
    private final String dataAlg;


    /**
     * Creates a codec from already-resolved components.
     *
     * @param asymmetricCodec codec for key encryption and signatures
     * @param privateKey      own private key (signing, key decryption)
     * @param publicKey       peer public key (key encryption, signature verification)
     * @param symmetricCodec  codec for payload encryption
     * @param textCodec       payload text codec; {@code null} selects Base64
     * @param keyAlg          name of the asymmetric algorithm, reported by {@link #getKeyAlg()}
     * @param dataAlg         name of the symmetric algorithm, reported by {@link #getDataAlg()}
     */
    public CaictEntityCipherCodec(AsymmetricCodec asymmetricCodec,
                                  PrivateKey privateKey,
                                  PublicKey publicKey,
                                  SymmetricCodec symmetricCodec,
                                  TextCodec<byte[], String> textCodec,
                                  String keyAlg,
                                  String dataAlg
    ) {
        this.asymmetricCodec = asymmetricCodec;
        this.privateKey = privateKey;
        this.publicKey = publicKey;
        this.symmetricCodec = symmetricCodec;
        // Fall back to Base64 when the caller did not choose a text codec.
        this.textCodec = textCodec != null ? textCodec : BASE64_CODEC;
        this.keyAlg = keyAlg;
        this.dataAlg = dataAlg;

        // The helpers keep a back-reference to this instance, so create them only
        // after every other field has been assigned.
        this.encryptor = new Encryptor(this);
        this.decryptor = new Decryptor(this);
        this.signer = new Signer(this);
    }


    /**
     * Starts a builder for the given asymmetric algorithm.
     *
     * @throws IllegalStateException if {@code keyAlg} is neither RSA nor SM2
     */
    public static CaictEntityCipherCodecBuilder keyAlg(String keyAlg) {
        return new CaictEntityCipherCodecBuilder(keyAlg);
    }


    /**
     * Starts a builder using RSA as the asymmetric algorithm.
     */
    public static CaictEntityCipherCodecBuilder rsa() {
        return keyAlg(RSA);
    }

    /**
     * Starts a builder using SM2 as the asymmetric algorithm.
     */
    public static CaictEntityCipherCodecBuilder sm2() {
        return keyAlg(SM2);
    }


    /**
     * Generates a random symmetric key with the configured symmetric codec.
     */
    public byte[] randomKey() {
        return symmetricCodec.randomKey();
    }

    /**
     * Encrypts a random symmetric key with the peer public key and hex-encodes the result.
     */
    public String encryptRandomKey(byte[] key) {
        return HEX_CODEC.encode(encryptor.key(key));
    }

    /**
     * Hex-decodes an encrypted symmetric key and decrypts it with the own private key.
     */
    public byte[] decryptRandomKey(String key) {
        return decryptor.key(HEX_CODEC.decode(key));
    }


    /**
     * Encrypts {@code data} (UTF-8) with the symmetric key and renders it with the text codec.
     */
    public String encrypt(String data, byte[] key) {
        return encrypt(data, StandardCharsets.UTF_8, key);
    }

    /**
     * Encrypts {@code data} (encoded with {@code charset}) with the symmetric key and
     * renders the ciphertext with the text codec.
     */
    public String encrypt(String data, Charset charset, byte[] key) {
        return cipher(this.textCodec::encode,
                this.encryptor, data.getBytes(charset), key);
    }


    /**
     * Decrypts data:
     * a> parses the text with the text codec and decrypts it with the symmetric key;
     * b> turns the plaintext bytes into a UTF-8 string.
     */
    public String decrypt(String data, byte[] key) {
        return decrypt(data, StandardCharsets.UTF_8, key);
    }

    /**
     * Same as {@link #decrypt(String, byte[])} but decodes the plaintext with {@code charset}.
     */
    public String decrypt(String data, Charset charset, byte[] key) {
        return cipher(
                (bytes) -> new String(bytes, charset),
                this.decryptor, textCodec.decode(data), key
        );
    }

    /**
     * Runs the symmetric step of {@code dataCodec} and converts the resulting bytes to a string.
     */
    private String cipher(Function<byte[], String> stringCodec, Codec dataCodec, byte[] data, byte[] key) {
        return stringCodec.apply(
                dataCodec.data(data, key)
        );
    }


    /**
     * Signs {@code signData} (UTF-8) with the own private key; returns a hex signature.
     */
    public String sign(String signData) {
        return sign(signData, StandardCharsets.UTF_8);
    }

    /**
     * Signs {@code signData} encoded with {@code charset}; returns a hex signature.
     */
    public String sign(String signData, Charset charset) {
        return sign(signData.getBytes(charset));
    }

    /**
     * Signs raw bytes with the own private key; returns a hex signature.
     */
    public String sign(byte[] signData) {
        return this.signer.sign(signData);
    }

    /**
     * Verifies a hex signature over {@code data} (UTF-8) with the peer public key.
     */
    public boolean verifySign(String signDataHex, String data) {
        return verifySign(signDataHex, data, StandardCharsets.UTF_8);
    }

    /**
     * Verifies a hex signature over {@code data} encoded with {@code charset}.
     */
    public boolean verifySign(String signDataHex, String data, Charset charset) {
        return verifySign(signDataHex, data.getBytes(charset));
    }

    /**
     * Verifies a hex signature over raw bytes with the peer public key.
     */
    public boolean verifySign(String signDataHex, byte[] data) {
        return this.signer.verifySign(signDataHex, data);
    }


    public String getKeyAlg() {
        return keyAlg;
    }

    public String getDataAlg() {
        return dataAlg;
    }

    /**
     * Builder for {@link CaictEntityCipherCodec}.
     * <p>
     * Exactly one asymmetric algorithm (RSA/SM2) and at most one symmetric algorithm
     * (AES/SM4) are selected at any time; the last selection wins. When no symmetric
     * algorithm is selected, SM4 is used.
     */
    public static class CaictEntityCipherCodecBuilder {

        private AsymmetricCodec asymmetricCodec;
        private SymmetricCodec symmetricCodec;

        /**
         * Own (client) private key: signing and decryption.
         */
        private PrivateKey privateKey;

        /**
         * Peer (server) public key: encryption and signature verification.
         */
        private PublicKey publicKey;


        private boolean rsa;
        private boolean sm2;
        private boolean aes;
        private boolean sm4;

        private String publicKeyPKCS8;
        private String privateKeyPKCS8;


        private TextCodec<byte[], String> textCodec;

        /**
         * @param alg asymmetric algorithm name (RSA or SM2)
         * @throws IllegalStateException for any other value
         */
        public CaictEntityCipherCodecBuilder(String alg) {
            reset();
            switch (alg) {
                case RSA:
                    this.rsa = true;
                    break;
                case SM2:
                    this.sm2 = true;
                    break;
                default:
                    throw new IllegalStateException("Unexpected value: " + alg);
            }
        }

        /**
         * Clears the asymmetric algorithm selection.
         */
        private void reset() {
            this.rsa = false;
            this.sm2 = false;
        }


        public CaictEntityCipherCodecBuilder rsa() {
            reset();
            this.rsa = true;
            return this;
        }

        public CaictEntityCipherCodecBuilder sm2() {
            reset();
            this.sm2 = true;
            return this;
        }


        /**
         * Selects the symmetric algorithm by name, replacing any previous selection.
         *
         * @throws IllegalStateException if {@code alg} is neither AES nor SM4
         */
        public CaictEntityCipherCodecBuilder dataAlg(String alg) {
            switch (alg) {
                case AES:
                    // Clear the other flag so the label and the codec cannot disagree.
                    this.aes = true;
                    this.sm4 = false;
                    break;
                case SM4:
                    this.sm4 = true;
                    this.aes = false;
                    break;
                default:
                    throw new IllegalStateException("Unexpected value: " + alg);
            }
            return this;
        }

        public CaictEntityCipherCodecBuilder aes() {
            this.aes = true;
            this.sm4 = false;
            return this;
        }

        public CaictEntityCipherCodecBuilder sm4() {
            this.sm4 = true;
            this.aes = false;
            return this;
        }

        /**
         * Sets the payload text codec; {@code null} (the default) means Base64.
         */
        public CaictEntityCipherCodecBuilder textCodec(TextCodec<byte[], String> textCodec) {
            this.textCodec = textCodec;
            return this;
        }


        /**
         * Sets the peer public key as a PKCS8-wrapped string.
         */
        public CaictEntityCipherCodecBuilder publicKey(String otherPublicKeyPKCS8) {
            this.publicKeyPKCS8 = otherPublicKeyPKCS8;
            return this;
        }


        /**
         * Sets the own private key as a PKCS8-wrapped string.
         */
        public CaictEntityCipherCodecBuilder privateKey(String selfPrivateKeyPKCS8) {
            this.privateKeyPKCS8 = selfPrivateKeyPKCS8;
            return this;
        }


        /**
         * Resolves the codecs and keys and creates the codec.
         *
         * @throws CaictCryptoException if a key cannot be parsed
         */
        public CaictEntityCipherCodec build() {

            chooseAsymmetricCodec();
            chooseSymmetricCodec();


            return new CaictEntityCipherCodec(
                    asymmetricCodec, privateKey, publicKey,
                    symmetricCodec,
                    textCodec,
                    rsa ? RSA : SM2,
                    aes ? AES : SM4
            );
        }

        /**
         * Picks the symmetric codec using the same condition as the reported data algorithm,
         * so an unselected algorithm yields SM4 for both (instead of a null codec).
         */
        private void chooseSymmetricCodec() {
            this.symmetricCodec = aes ? AES_CODEC : SM4_CODEC;
        }

        /**
         * Picks the asymmetric codec and parses the PKCS8-wrapped public and private keys.
         */
        private void chooseAsymmetricCodec() {
            // Mirrors the reported key algorithm (rsa ? RSA : SM2); exactly one flag is set.
            this.asymmetricCodec = rsa ? RSA_CODEC : SM2_CODEC;
            try {
                this.publicKey = this.asymmetricCodec.keyCodec().getX509PublicKey(
                        AsymmetricPKCS8KeyWrapper.decode(publicKeyPKCS8, AsymmetricPKCS8KeyWrapper.PUBLIC)
                );

                this.privateKey = this.asymmetricCodec.keyCodec().getPrivateKeyFromPKCS8(
                        AsymmetricPKCS8KeyWrapper.decode(privateKeyPKCS8, AsymmetricPKCS8KeyWrapper.PRIVATE)
                );
            } catch (NoSuchProviderException | NoSuchAlgorithmException | InvalidKeySpecException e) {
                throw new CaictCryptoException(e.getMessage(), e);
            }
        }

    }


    /**
     * Signs with the own private key and verifies with the peer public key;
     * signatures are hex-encoded.
     */
    public static class Signer {

        private final CaictEntityCipherCodec codec;

        public Signer(CaictEntityCipherCodec codec) {
            this.codec = codec;
        }


        /**
         * @return hex-encoded signature of {@code signData}
         * @throws CaictCryptoException wrapping any signing failure
         */
        public String sign(byte[] signData) {
            try {
                return HEX_CODEC.encode(this.codec.asymmetricCodec.signatureCodec().sign(signData, codec.privateKey));
            } catch (Exception e) {
                throw new CaictCryptoException(e.getMessage(), e);
            }
        }

        /**
         * @return whether {@code signDataHex} is a valid signature of {@code data}
         * @throws CaictCryptoException wrapping any verification failure
         */
        public boolean verifySign(String signDataHex, byte[] data) {
            try {
                return this.codec.asymmetricCodec.signatureCodec().verify(data, HEX_CODEC.decode(signDataHex), codec.publicKey);
            } catch (Exception e) {
                throw new CaictCryptoException(e.getMessage(), e);
            }
        }
    }

    /**
     * One direction (encrypt or decrypt) of the key and data processing.
     */
    public interface Codec {


        /**
         * Processes the symmetric key with the asymmetric algorithm.
         * Encrypt: RSA/SM2 encryption.
         * Decrypt: RSA/SM2 decryption.
         */
        byte[] key(byte[] key);

        /**
         * Processes the payload with the symmetric algorithm.
         * Encrypt: AES/SM4 encryption.
         * Decrypt: AES/SM4 decryption.
         */
        byte[] data(byte[] data, byte[] key);


    }


    /**
     * Encryption direction: key with the peer public key, data with the symmetric key.
     */
    public static class Encryptor implements Codec {

        private final CaictEntityCipherCodec codec;

        public Encryptor(CaictEntityCipherCodec codec) {
            this.codec = codec;
        }

        /**
         * Encrypts the symmetric key with the peer public key.
         */
        @Override
        public byte[] key(byte[] key) {
            try {
                return this.codec.asymmetricCodec.encrypt(
                        key,
                        this.codec.publicKey
                );
            } catch (Exception e) {
                throw new CaictCryptoException("encrypt random key error.", e);
            }
        }


        /**
         * Encrypts the payload with the symmetric key.
         */
        @Override
        public byte[] data(byte[] data, byte[] key) {
            try {
                return this.codec.symmetricCodec.encode(
                        data, key, null
                );
            } catch (BadPaddingException | IllegalBlockSizeException e) {
                throw new CaictCryptoException(e.getMessage(), e);
            }
        }
    }

    /**
     * Decryption direction: key with the own private key, data with the symmetric key.
     * Operates on raw bytes; text decoding happens in the enclosing codec.
     */
    public static class Decryptor implements Codec {

        private final CaictEntityCipherCodec codec;

        public Decryptor(CaictEntityCipherCodec codec) {
            this.codec = codec;
        }

        /**
         * Decrypts the symmetric key with the own private key.
         */
        @Override
        public byte[] key(byte[] key) {
            try {
                return this.codec.asymmetricCodec.decrypt(
                        key,
                        this.codec.privateKey
                );
            } catch (Exception e) {
                throw new CaictCryptoException("decrypt random key error.", e);
            }
        }


        /**
         * Decrypts the payload with the symmetric key.
         */
        @Override
        public byte[] data(byte[] data, byte[] key) {
            try {
                return this.codec.symmetricCodec.decode(
                        data, key, null
                );
            } catch (BadPaddingException | IllegalBlockSizeException e) {
                throw new CaictCryptoException(e.getMessage(), e);
            }
        }
    }

}
