package cn.geminis.crypto.csp.soft.gm;

import cn.geminis.core.util.ByteUtils;
import cn.geminis.core.util.FileUtils;
import cn.geminis.crypto.core.key.PrivateKey;
import cn.geminis.crypto.core.key.PublicKey;
import cn.geminis.crypto.csp.*;
import cn.geminis.crypto.csp.soft.SoftRandomGenerator;
import cn.geminis.crypto.csp.soft.rsa.Sha1HMac;
import org.bouncycastle.asn1.gm.GMObjectIdentifiers;

import java.util.Objects;

/**
 * CSP factory for the Chinese commercial cryptography (GM) algorithm suite: SM2 for signing,
 * asymmetric encryption and key agreement, SM3 for digests and SM4 as the block cipher.
 *
 * <p>The factory is bound to a single software SM2 key pair stored in two files,
 * {@code sm2PublicKey.dat} and {@code sm2PrivateKey.dat}. Their names are appended verbatim to
 * the {@code keyPath} given to the constructor. The private key file is stored encrypted with the
 * block cipher from {@link #createBlockCipher()}, keyed by material derived from the PIN. If
 * neither file exists, a new key pair is generated and written.
 *
 * @author Allen
 */
public class Sm2CspFactory extends AbstractCspFactory {

    private static final String PRIVATE_KEY_FILE_NAME = "sm2PrivateKey.dat";
    private static final String PUBLIC_KEY_FILE_NAME = "sm2PublicKey.dat";

    /** Key used by {@link #createBlockCipher()}; derived from the PIN digest in the constructor. */
    private final byte[] mainKey;
    private final PrivateKey privateKey;
    private final PublicKey publicKey;

    /**
     * Loads the SM2 key pair from {@code keyPath} and then registers this factory. If neither key
     * file exists yet, a new key pair is generated and persisted first.
     *
     * @param keyPath prefix prepended verbatim to the key file names; when it denotes a directory
     *                it must end with a path separator
     * @param pin     PIN whose digest protects the private key file
     * @throws NullPointerException  if {@code keyPath} or {@code pin} is {@code null}
     * @throws IllegalStateException if only one of the two key files exists; the surviving file is
     *                               left untouched instead of being overwritten
     * @throws RuntimeException      if the stored private key cannot be decrypted
     *                               (presumably a wrong PIN or a corrupted file)
     */
    public Sm2CspFactory(String keyPath, String pin) {
        Objects.requireNonNull(keyPath, "keyPath");
        Objects.requireNonNull(pin, "pin");

        // The private-key encryption key is derived from the digest of the PIN.
        // NOTE(review): pin.getBytes() uses the platform default charset, and the derivation is an
        // unsalted single digest. Both are kept because changing either would make existing key
        // files undecryptable. Consider a salted KDF with an explicit charset plus a migration.
        // NOTE(review): createDigest()/createBlockCipher() are overridable yet called here, before
        // construction has completed.
        var digest = this.createDigest();
        var pinDigest = digest.digest(pin.getBytes());
        // Presumably keeps the first 16 bytes (SM4 uses a 128-bit key) — confirm ByteUtils.duplicate.
        this.mainKey = ByteUtils.duplicate(pinDigest, 16);

        var cipher = this.createBlockCipher();

        String pkFileName = keyPath + PUBLIC_KEY_FILE_NAME;
        String skFileName = keyPath + PRIVATE_KEY_FILE_NAME;

        // A null result is treated as "file absent".
        byte[] pkData = FileUtils.readFile(pkFileName);
        byte[] skData = FileUtils.readFile(skFileName);

        if (pkData == null && skData == null) {
            // No key pair yet: generate one and persist it, with the private key encrypted.
            var keypair = new Sm2KeyGenerator().generateKeyPair();
            pkData = keypair.getPublicKey().getEncoded();
            skData = cipher.encrypt(keypair.getPrivateKey().getEncoded());

            FileUtils.writeFile(pkFileName, pkData);
            FileUtils.writeFile(skFileName, skData);
        } else if (pkData == null || skData == null) {
            // Regenerating here would overwrite the surviving file and could destroy an existing
            // private key, so refuse instead of silently replacing it.
            throw new IllegalStateException("Incomplete SM2 key pair: expected both "
                    + pkFileName + " and " + skFileName + " to exist");
        }

        // skData always holds the encrypted private key at this point (loaded or freshly written).
        try {
            skData = cipher.decrypt(skData);
        } catch (Exception e) {
            throw new RuntimeException("解密PIN码保护的私钥错误", e);
        }

        this.privateKey = new PrivateKey(skData);
        this.publicKey = new PublicKey(pkData);

        register();
    }

    /** Returns a new software random generator. */
    @Override
    public RandomGenerator createRandomGenerator() {
        return new SoftRandomGenerator();
    }

    /** Returns the SM3 digest OID. */
    @Override
    public String getDigestAlgOid() {
        return GMObjectIdentifiers.sm3.getId();
    }

    @Override
    public String getDigestAlgName() {
        return "SM3";
    }

    /** Returns a new SM3 digest. */
    @Override
    public AbstractDigest createDigest() {
        return new Sm3Digest();
    }

    /** Returns the SM3-with-SM2 signature OID. */
    @Override
    public String getSignerAlgOid() {
        return GMObjectIdentifiers.sm2sign_with_sm3.getId();
    }

    /**
     * Returns the SM2 public-key encryption OID ({@code 1.2.156.10197.1.301.3}).
     *
     * <p>This is deliberately distinct from {@link #getSignerAlgOid()}: the SM3-with-SM2 signature
     * OID does not identify an encryption algorithm.
     */
    @Override
    public String getAsyncEncryptionAlgOid() {
        return GMObjectIdentifiers.sm2encrypt.getId();
    }

    @Override
    public String getSignerAlgName() {
        return "SM2";
    }

    /** Returns an SM2 signer bound to this factory's key pair. */
    @Override
    public AbstractSigner createSigner() {
        return new Sm2Signer(publicKey, privateKey);
    }

    /** Returns an SM4 block cipher keyed with the PIN-derived main key. */
    @Override
    public AbstractBlockCipher createBlockCipher() {
        return new Sm4BlockCipher(this.mainKey);
    }

    /** Returns an SM2 asymmetric cipher bound to this factory's key pair. */
    @Override
    public AbstractAsymmetricBlockCipher createAsymmetricBlockCipher() {
        return new Sm2BlockCipher(this.publicKey, this.privateKey);
    }

    /** Returns a new, independent SM2 key pair generator. */
    @Override
    public KeyPairGenerator createKeyPairGenerator() {
        return new Sm2KeyGenerator();
    }

    /** Returns an SM2 key agreement bound to this factory's key pair. */
    @Override
    public AbstractAgreement createAgreement() {
        return new Sm2Agreement(this.publicKey, this.privateKey);
    }

    /**
     * Returns an HMAC implementation.
     *
     * <p>NOTE(review): this is the SHA-1 HMAC from the RSA provider rather than an SM3-based HMAC;
     * confirm this is intended for a GM-suite factory.
     */
    @Override
    public AbstractMac createMac() {
        return new Sha1HMac();
    }

    /**
     * Returns the SM4 OID in ECB mode.
     *
     * <p>NOTE(review): confirm that {@code Sm4BlockCipher} actually operates in ECB mode.
     */
    @Override
    public String getBlockCipherAlgOid() {
        return GMObjectIdentifiers.sms4_ecb.getId();
    }

    /**
     * Returns the X9.62 {@code id-ecPublicKey} OID. SM2 public keys are conventionally encoded
     * under this algorithm identifier with SM2 curve parameters.
     */
    @Override
    public String getKeyPairOid() {
        return "1.2.840.10045.2.1";
    }

    /**
     * Releases nothing. The PIN-derived key and the key pair stay referenced by this object until
     * it is garbage-collected.
     */
    @Override
    public void close() {
    }

}
