package framework.crypto;

import framework.config.AESCryptoConfig;
import framework.exceptions.ConfigurationException;
import lombok.Getter;
import lombok.SneakyThrows;
import org.apache.commons.lang3.StringUtils;

import javax.crypto.Cipher;
import java.util.List;
import java.util.concurrent.CopyOnWriteArrayList;

/**
 * AES implementation that allows multiple configured secret keys.
 *
 * <p>The configured secret key string is split on commas and one {@link AESCrypto} delegate is built per key.
 * Encryption always uses the key selected by the configured secret index; decryption selects the key from the
 * index byte stored at the end of the ciphertext. Data encrypted under any of the configured keys can therefore
 * still be decrypted.
 *
 * <p>Ciphertext layout produced by {@link #encode(byte[])}: {@code [IV][encrypted data][key index (1 byte)]}.
 * When the cipher uses no IV, the IV part is empty.
 */
public class AESCryptoPlus implements AESCrypto {

    /**
     * Largest allowed key index. The index is stored as a single byte at the end of every ciphertext.
     */
    private static final int MAX_SECRET_INDEX = 99;

    /**
     * Key algorithm name, passed unchanged to every per-key delegate.
     */
    @Getter
    private final String algorithm;

    /**
     * Cipher transformation in the form algorithm/mode/padding, passed unchanged to every per-key delegate.
     */
    @Getter
    private final String cipherAlgorithm;

    /**
     * The raw configured secret key string: one or more keys separated by commas.
     */
    @Getter
    private final String secretKey;

    /**
     * Index (into {@link #cryptoList}) of the key used for encryption.
     * It is written as the last byte of every ciphertext.
     */
    private final int secretIndex;

    /**
     * One delegate per configured key, in the order the keys appear in {@link #secretKey}.
     */
    @Getter
    private final List<AESCrypto> cryptoList = new CopyOnWriteArrayList<>();

    /**
     * Validates the configuration and builds one delegate per comma-separated secret key.
     *
     * <p>Note: this constructor calls the overridable {@link #createAESCrypto(AESCryptoConfig)}. An override
     * therefore runs before any subclass fields are initialized.
     *
     * @param config configuration. Algorithm, cipher algorithm, secret key and secret index must all be set.
     * @throws ConfigurationException (thrown sneakily) if a required value is missing, or if the secret index
     *                                is outside {@code 0..99} or does not address one of the configured keys
     */
    @SneakyThrows
    public AESCryptoPlus(AESCryptoConfig config) {
        if (StringUtils.isBlank(config.getAlgorithm())) {
            throw new ConfigurationException("Not config algorithm");
        }
        if (StringUtils.isBlank(config.getCipherAlgorithm())) {
            throw new ConfigurationException("Not config cipher algorithm");
        }
        if (StringUtils.isBlank(config.getSecretKey())) {
            throw new ConfigurationException("Not config secret key");
        }
        Integer configuredIndex = config.getSecretIndex();
        if (configuredIndex == null) {
            throw new ConfigurationException("Not config secret index, please use AESCryptoImpl instead of AESCryptoPlus");
        }
        if (configuredIndex < 0) {
            throw new ConfigurationException("Secret index min allow 0");
        }
        if (configuredIndex > MAX_SECRET_INDEX) {
            throw new ConfigurationException("Secret index max allow 99");
        }

        // String.split drops trailing empty strings, so "k1,k2," yields two keys.
        String[] keys = config.getSecretKey().split(",");
        if (configuredIndex > keys.length - 1) {
            throw new ConfigurationException("Secret index exceeds Secret key length");
        }

        this.secretIndex = configuredIndex;
        this.algorithm = config.getAlgorithm();
        this.cipherAlgorithm = config.getCipherAlgorithm();
        this.secretKey = config.getSecretKey();

        for (String key : keys) {
            AESCryptoConfig cryptoConfig = new AESCryptoConfig();
            cryptoConfig.setAlgorithm(this.algorithm);
            cryptoConfig.setCipherAlgorithm(this.cipherAlgorithm);
            cryptoConfig.setSecretKey(key);
            this.cryptoList.add(this.createAESCrypto(cryptoConfig));
        }
    }

    /**
     * Builds the delegate used for a single secret key. Subclasses may override this to supply another
     * implementation. It is called from the constructor.
     *
     * @param cryptoConfig single-key configuration (algorithm, cipher algorithm and one secret key)
     * @return the delegate for that key; this implementation returns a new {@code AESCryptoImpl}
     */
    protected AESCrypto createAESCrypto(AESCryptoConfig cryptoConfig) {
        return new AESCryptoImpl(cryptoConfig);
    }

    /**
     * Returns an encryption cipher from the delegate selected by the configured secret index.
     *
     * @return the delegate's encryption cipher
     */
    @Override
    @SneakyThrows
    public Cipher getEncodeCipher() {
        List<AESCrypto> cryptos = getCryptoList();
        return cryptos.get(this.secretIndex).getEncodeCipher();
    }

    /**
     * Returns a decryption cipher from the delegate whose index is stored in the last byte of the ciphertext.
     *
     * @param encryptedData ciphertext produced by {@link #encode(byte[])}, including the trailing index byte
     * @return the selected delegate's decryption cipher
     * @throws IllegalArgumentException if the data is null, shorter than 2 bytes, or carries an index that
     *                                  does not address a configured key
     */
    @Override
    @SneakyThrows
    public Cipher getDecodeCipher(byte[] encryptedData) {
        if (encryptedData == null || encryptedData.length < 2) {
            throw new IllegalArgumentException("encryptedData invalid");
        }
        List<AESCrypto> cryptos = getCryptoList();
        int index = encryptedData[encryptedData.length - 1];
        // Reject the index before List.get: a value beyond the configured keys would otherwise surface as
        // IndexOutOfBoundsException instead of the documented IllegalArgumentException.
        if (index < 0 || index > MAX_SECRET_INDEX || index >= cryptos.size()) {
            throw new IllegalArgumentException("encryptedData invalid");
        }
        // NOTE(review): the delegate receives the full buffer, including the trailing index byte. Presumably it
        // only reads the IV prefix; confirm against AESCryptoImpl.
        return cryptos.get(index).getDecodeCipher(encryptedData);
    }

    /**
     * Encrypts {@code content} with the key at the configured secret index.
     *
     * @param content plaintext bytes
     * @return {@code [IV][encrypted data][key index]}; the IV part is empty when the cipher uses no IV
     */
    @Override
    @SneakyThrows
    public byte[] encode(byte[] content) {
        Cipher cipher = this.getEncodeCipher();
        byte[] encrypted = cipher.doFinal(content);
        byte[] iv = ivOf(cipher);
        byte[] buf = new byte[iv.length + encrypted.length + 1];
        System.arraycopy(iv, 0, buf, 0, iv.length);
        System.arraycopy(encrypted, 0, buf, iv.length, encrypted.length);
        // secretIndex is validated to 0..99 in the constructor, so the narrowing cast is lossless.
        buf[buf.length - 1] = (byte) this.secretIndex;
        return buf;
    }

    /**
     * Decrypts data produced by {@link #encode(byte[])}. The key is chosen from the trailing index byte.
     *
     * @param content {@code [IV][encrypted data][key index]}
     * @return the plaintext bytes
     * @throws IllegalArgumentException if the data is malformed (see {@link #getDecodeCipher(byte[])}) or too
     *                                  short to hold the IV and the index byte
     */
    @Override
    @SneakyThrows
    public byte[] decode(byte[] content) {
        Cipher cipher = this.getDecodeCipher(content);
        byte[] iv = ivOf(cipher);
        // The encrypted payload sits between the IV prefix and the trailing index byte.
        int dataLength = content.length - iv.length - 1;
        if (dataLength < 0) {
            throw new IllegalArgumentException("encryptedData invalid");
        }
        byte[] data = new byte[dataLength];
        System.arraycopy(content, iv.length, data, 0, dataLength);
        return cipher.doFinal(data);
    }

    /**
     * Returns the cipher's IV, or an empty array when the cipher has none.
     * {@link Cipher#getIV()} returns null for transformations that use no IV, such as ECB.
     */
    private static byte[] ivOf(Cipher cipher) {
        byte[] iv = cipher.getIV();
        return iv == null ? new byte[0] : iv;
    }
}
