package cn.geminis.crypto.csp.soft.gm;

import cn.geminis.crypto.core.key.PrivateKey;
import cn.geminis.crypto.core.key.PublicKey;
import cn.geminis.crypto.csp.AbstractAgreement;
import cn.geminis.crypto.csp.parameter.CalcAgreementCipherParameters;
import cn.geminis.crypto.csp.parameter.InitAgreementCipherParameters;
import org.bouncycastle.asn1.ASN1Sequence;
import org.bouncycastle.asn1.DERSequence;
import org.bouncycastle.asn1.DERSequenceGenerator;
import org.bouncycastle.asn1.pkcs.PrivateKeyInfo;
import org.bouncycastle.asn1.x509.SubjectPublicKeyInfo;
import org.bouncycastle.crypto.agreement.SM2KeyExchange;
import org.bouncycastle.crypto.params.*;

import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.math.BigInteger;

/**
 * SM2 key agreement (key exchange) built on BouncyCastle's {@link SM2KeyExchange}.
 * <p>
 * An instance holds a long-term private key and a temporary key pair. The temporary pair is
 * generated in the constructor. It can be exported with {@link #getSession()} and restored with
 * {@link #setSession(byte[])}.
 *
 * @author Allen
 */
public class Sm2Agreement extends AbstractAgreement {

    private final SM2KeyExchange agreement = new SM2KeyExchange();

    /** Long-term private key passed to the constructor. */
    private final AsymmetricKeyParameter privateKey;
    /** Temporary private key; generated in the constructor, replaceable via {@link #setSession(byte[])}. */
    private AsymmetricKeyParameter tempPrivateKey;
    /** Temporary public key matching {@link #tempPrivateKey}. */
    private AsymmetricKeyParameter tempPublicKey;

    /**
     * Creates an agreement and generates a fresh temporary SM2 key pair.
     *
     * @param publicKey  public key handed to {@link AbstractAgreement} (presumably the local party's;
     *                   this class does not read it directly)
     * @param privateKey long-term private key used by {@link #init(InitAgreementCipherParameters)}
     */
    public Sm2Agreement(PublicKey publicKey, PrivateKey privateKey) {
        super(publicKey);
        this.privateKey = privateKey.getKeyParameter();

        try (var keyPairGenerator = new Sm2KeyGenerator()) {
            var keyPair = keyPairGenerator.generateKeyPair();
            this.tempPrivateKey = keyPair.getPrivateKey().getKeyParameter();
            this.tempPublicKey = keyPair.getPublicKey().getKeyParameter();
        }
    }

    /**
     * Returns the key length passed as {@code kLen} to {@code SM2KeyExchange.calculateKey}.
     * <p>
     * NOTE(review): BouncyCastle interprets {@code kLen} as a number of bits, so the derived key is
     * 8 bytes long. Confirm this matches what {@link AbstractAgreement} callers expect.
     *
     * @return the derived key length, 64
     */
    @Override
    public int getFieldSize() {
        return 64;
    }

    /**
     * Initialises the SM2 key exchange with the long-term and temporary private keys.
     * <p>
     * The temporary private key is read at this point. Restoring a different one with
     * {@link #setSession(byte[])} afterwards only takes effect after {@code init} is called again.
     *
     * @param param supplies the initiator/responder role and the local user ID
     * @throws ClassCastException if the keys are not EC keys
     */
    @Override
    public void init(InitAgreementCipherParameters param) {
        agreement.init(new ParametersWithID(
                new SM2KeyExchangePrivateParameters(
                        param.isInitiator(),
                        (ECPrivateKeyParameters) this.privateKey,
                        (ECPrivateKeyParameters) this.tempPrivateKey),
                param.getId()
        ));
    }

    /**
     * Derives the shared key from the peer's long-term and temporary public keys.
     *
     * @param param the peer's public key, temporary public key and user ID
     * @return the derived key bytes as a non-negative integer
     * @throws ClassCastException if the peer keys are not EC public keys
     */
    @Override
    public BigInteger calculateAgreement(CalcAgreementCipherParameters param) {
        byte[] key = agreement.calculateKey(
                getFieldSize(),
                new ParametersWithID(
                        new SM2KeyExchangePublicParameters(
                                (ECPublicKeyParameters) param.getPublicKey().getKeyParameter(),
                                (ECPublicKeyParameters) param.getTempPublicKey().getKeyParameter()),
                        param.getId())
        );
        // The KDF output is raw key material, not a two's-complement number. Read it as an unsigned
        // magnitude so a leading byte >= 0x80 does not produce a negative value, which would make
        // the key bytes unrecoverable at a fixed length.
        return new BigInteger(1, key);
    }

    /**
     * Encodes the temporary key pair as a DER SEQUENCE of {@code SubjectPublicKeyInfo} and
     * {@code PrivateKeyInfo}.
     * <p>
     * The encoding contains the temporary private key unencrypted; callers must protect it.
     *
     * @return the DER-encoded session
     * @throws RuntimeException if encoding fails
     */
    @Override
    public byte[] getSession() {
        try (var stream = new ByteArrayOutputStream()) {
            var generator = new DERSequenceGenerator(stream);
            generator.addObject(new PublicKey(this.tempPublicKey).getSubjectPublicKeyInfo());
            generator.addObject(new PrivateKey(this.tempPrivateKey).getPrivateKeyInfo());
            generator.close();
            return stream.toByteArray();
        } catch (IOException e) {
            throw new RuntimeException("获取密钥协商器Session错误", e);
        }
    }

    /**
     * Restores a temporary key pair previously produced by {@link #getSession()}.
     * <p>
     * Both keys are decoded before either field is assigned, so invalid input leaves the current
     * temporary key pair untouched.
     *
     * @param session DER SEQUENCE of {@code SubjectPublicKeyInfo} and {@code PrivateKeyInfo}
     * @throws IllegalArgumentException if the input is not a two-element SEQUENCE
     * @throws RuntimeException         if the input cannot be parsed
     */
    @Override
    public void setSession(byte[] session) {
        try {
            var primitive = DERSequence.fromByteArray(session);
            if (!(primitive instanceof ASN1Sequence) || ((ASN1Sequence) primitive).size() != 2) {
                throw new IllegalArgumentException("设置密钥协商器Session错误");
            }
            var sequence = (ASN1Sequence) primitive;
            var publicKeyInfo = SubjectPublicKeyInfo.getInstance(sequence.getObjectAt(0));
            var privateKeyInfo = PrivateKeyInfo.getInstance(sequence.getObjectAt(1));

            // Decode both keys before assigning, so a failure cannot leave a mismatched pair.
            var restoredPublicKey = new PublicKey(publicKeyInfo).getKeyParameter();
            var restoredPrivateKey = new PrivateKey(privateKeyInfo).getKeyParameter();
            this.tempPublicKey = restoredPublicKey;
            this.tempPrivateKey = restoredPrivateKey;
        } catch (IOException e) {
            throw new RuntimeException("设置密钥协商器Session错误", e);
        }
    }

    /**
     * Returns the current temporary public key, to be sent to the peer.
     *
     * @return a new {@link PublicKey} wrapping the temporary public key
     */
    @Override
    public PublicKey getTempPublicKey() {
        return new PublicKey(this.tempPublicKey);
    }

    /** No resources are held by this implementation, so there is nothing to release. */
    @Override
    public void close() {
        // Intentionally empty.
    }
}
