package com.dyadicsec.pkcs11;

import java.math.BigInteger;
import java.util.ArrayList;
import java.util.Map;
import static com.dyadicsec.cryptoki.CK.*;

import static com.dyadicsec.pkcs11.Utils.bigInt2Bytes;

/**
 * Created by valery.osheter on 22-Jun-17.
 */
/**
 * RSA private key object held by a PKCS#11 slot.
 *
 * <p>The modulus ({@link #getN()}), public exponent ({@link #getE()}) and bit size
 * ({@link #getBitSize()}) are cached on the instance. They are either filled in by the
 * factory methods of this class or fetched on first access via {@code read()}.
 *
 * Created by valery.osheter on 22-Jun-17.
 */
public final class CKRSAPrivateKey extends CKPrivateKey
{
    /** Modulus size in bits; 0 means "not known yet". */
    int bitSize = 0;
    /** Cached modulus; null means "not read yet". */
    BigInteger N = null;
    /** Cached public exponent; null means "not read yet". */
    BigInteger E = null;
    /** Lazily created public-key counterpart, see {@link #getPublicKey()}. */
    CKRSAPublicKey pubKey = null;

    CKRSAPrivateKey()
    {
        keyType = CKK_RSA;
    }

    /**
     * Adds the RSA-specific attributes (public exponent and modulus) to the set of
     * attributes fetched from the token, on top of those requested by the superclass.
     */
    void prepareReadTemplate(Map<Integer, CK_ATTRIBUTE> template)
    {
        super.prepareReadTemplate(template);
        addReadTemplate(template, CKA_PUBLIC_EXPONENT);
        addReadTemplate(template, CKA_MODULUS);
    }

    /**
     * Stores the attribute values fetched from the token into the cached fields.
     *
     * @param template attributes read from the token. It must contain
     *                 {@code CKA_MODULUS} and {@code CKA_PUBLIC_EXPONENT}, as requested by
     *                 {@link #prepareReadTemplate(Map)}.
     */
    void saveReadTemplate(Map<Integer, CK_ATTRIBUTE> template) throws CKException
    {
        super.saveReadTemplate(template);
        byte[] modulus = template.get(CKA_MODULUS).getValue();
        N = new BigInteger(1, modulus);
        E = new BigInteger(1, template.get(CKA_PUBLIC_EXPONENT).getValue());
        // Size from the numeric value (rounded up to whole bytes) rather than the raw attribute
        // length, so a modulus returned with leading zero padding does not inflate the key size.
        bitSize = (N.bitLength() + 7) / 8 * 8;
    }

    /**
     * Returns the public exponent. If it is not cached yet, the key's attributes are read first.
     */
    public BigInteger getE() throws CKException
    {
        if (E==null) read();
        return E;
    }

    /**
     * Returns the modulus. If it is not cached yet, the key's attributes are read first.
     */
    public BigInteger getN() throws CKException
    {
        if (N==null) read();
        return N;
    }

    /**
     * Returns the modulus size in bits. If it is not known yet, the key's attributes are read first.
     */
    public int getBitSize() throws CKException
    {
        if (bitSize==0) read();
        return bitSize;
    }

    /**
     * Returns the matching RSA public key. It is built from this key's modulus and exponent
     * on the first call and cached afterwards. It is created in the same slot with no name
     * and a null policy; how a null policy is interpreted is up to {@code CKRSAPublicKey.create}.
     */
    public CKRSAPublicKey getPublicKey() throws CKException
    {
        if (pubKey==null) pubKey = CKRSAPublicKey.create(slot, null, null, getN(), getE());
        return pubKey;
    }

    /**
     * Builds the attribute template for an RSA private key object, e.g. for unwrapping.
     * The template is also used by {@link #generate}.
     *
     * @param name   key name, converted to {@code CKA_ID} via {@code Utils.name2id}
     * @param policy key policy; a default {@code Policy} is used when null
     * @return a new template array
     */
    public static CK_ATTRIBUTE[] getUnwrapTemplate(String name, Policy policy)
    {
        if (policy==null) policy = new Policy();
        return new CK_ATTRIBUTE[]
                {
                        new CK_ATTRIBUTE(CKA_TOKEN, policy.cka_token),
                        new CK_ATTRIBUTE(CKA_CLASS, CKO_PRIVATE_KEY),
                        new CK_ATTRIBUTE(CKA_KEY_TYPE, CKK_RSA),
                        new CK_ATTRIBUTE(CKA_EXTRACTABLE, policy.cka_extractable),
                        new CK_ATTRIBUTE(CKA_SENSITIVE, policy.cka_sensitive),
                        new CK_ATTRIBUTE(CKA_DECRYPT, policy.cka_decrypt),
                        new CK_ATTRIBUTE(CKA_SIGN, policy.cka_sign),
                        new CK_ATTRIBUTE(CKA_UNWRAP, policy.cka_unwrap),
                        new CK_ATTRIBUTE(CKA_ID, Utils.name2id(name)),
                };
    }

    /**
     * Generates a new RSA key pair in the slot using {@code CKM_RSA_PKCS_KEY_PAIR_GEN}.
     *
     * <p>The public half is requested as a session object ({@code CKA_TOKEN = false}). No public
     * exponent is specified in the public template, so the token chooses it.
     *
     * @param slot     target slot
     * @param name     key name, stored as {@code CKA_ID}
     * @param policy   key policy; a default {@code Policy} is used when null
     * @param bitsSize modulus size in bits
     * @return the private key handle
     */
    public static CKRSAPrivateKey generate(Slot slot, String name, Policy policy, int bitsSize) throws CKException
    {
        if (policy==null) policy = new Policy();
        CKRSAPrivateKey key = new CKRSAPrivateKey();

        // Identical to the unwrap template; reuse it instead of duplicating the attribute list.
        CK_ATTRIBUTE[] tPrv = getUnwrapTemplate(name, policy);
        CK_ATTRIBUTE[] tPub =
                {
                        new CK_ATTRIBUTE(CKA_TOKEN, false),
                        new CK_ATTRIBUTE(CKA_CLASS, CKO_PUBLIC_KEY),
                        new CK_ATTRIBUTE(CKA_KEY_TYPE, CKK_RSA),
                        new CK_ATTRIBUTE(CKA_MODULUS_BITS, bitsSize),
                };

        key.generateKeyPair(slot, CKM_RSA_PKCS_KEY_PAIR_GEN, tPub, tPrv);
        key.bitSize = bitsSize;
        key.policy = policy;
        key.name = name;
        return key;
    }

    /**
     * Imports an RSA private key from its CRT components.
     *
     * <p>All integers are encoded with {@code bigInt2Bytes}. The private exponent is requested at
     * the modulus byte length, and the CRT components at half of it. The exact padding semantics
     * of the size argument are defined by {@code Utils.bigInt2Bytes}.
     *
     * @param slot   target slot
     * @param name   key name, stored as {@code CKA_ID}
     * @param policy key policy; a default {@code Policy} is used when null
     * @return the private key handle, with modulus, exponent and bit size already cached
     */
    public static CKRSAPrivateKey create(Slot slot, String name, Policy policy,
                                         BigInteger N, BigInteger E, BigInteger D, BigInteger P, BigInteger Q, BigInteger DP, BigInteger DQ, BigInteger QINV) throws CKException
    {
        if (policy==null) policy = new Policy();
        CKRSAPrivateKey key = new CKRSAPrivateKey();

        byte[] NBuf = bigInt2Bytes(N, 0);
        int keySize = NBuf.length;

        CK_ATTRIBUTE[] t =
                {
                        new CK_ATTRIBUTE(CKA_TOKEN, policy.cka_token),
                        new CK_ATTRIBUTE(CKA_CLASS, CKO_PRIVATE_KEY),
                        new CK_ATTRIBUTE(CKA_KEY_TYPE, CKK_RSA),
                        new CK_ATTRIBUTE(CKA_EXTRACTABLE, policy.cka_extractable),
                        new CK_ATTRIBUTE(CKA_SENSITIVE, policy.cka_sensitive),
                        new CK_ATTRIBUTE(CKA_DECRYPT, policy.cka_decrypt),
                        new CK_ATTRIBUTE(CKA_SIGN, policy.cka_sign),
                        new CK_ATTRIBUTE(CKA_UNWRAP, policy.cka_unwrap),
                        new CK_ATTRIBUTE(CKA_MODULUS, NBuf),
                        // Same unsigned encoding as the modulus. BigInteger.toByteArray() is
                        // two's-complement and may prepend a 0x00 sign byte.
                        new CK_ATTRIBUTE(CKA_PUBLIC_EXPONENT, bigInt2Bytes(E, 0)),
                        new CK_ATTRIBUTE(CKA_PRIVATE_EXPONENT, bigInt2Bytes(D, keySize)),
                        new CK_ATTRIBUTE(CKA_PRIME_1, bigInt2Bytes(P, keySize/2)),
                        new CK_ATTRIBUTE(CKA_PRIME_2, bigInt2Bytes(Q, keySize/2)),
                        new CK_ATTRIBUTE(CKA_EXPONENT_1, bigInt2Bytes(DP, keySize/2)),
                        new CK_ATTRIBUTE(CKA_EXPONENT_2, bigInt2Bytes(DQ, keySize/2)),
                        new CK_ATTRIBUTE(CKA_COEFFICIENT, bigInt2Bytes(QINV, keySize/2)),
                        new CK_ATTRIBUTE(CKA_ID, Utils.name2id(name)),
                };

        key.create(slot, t);
        key.bitSize = keySize*8;
        key.policy = policy;
        key.name = name;
        key.E = E;
        key.N = N;
        return key;
    }

    /** Looks up an RSA private key in the slot by name. */
    public static CKRSAPrivateKey find(Slot slot, String name)
    {
        return (CKRSAPrivateKey) CKObject.find(slot, CKO_PRIVATE_KEY, CKK_RSA, name);
    }

    /** Looks up an RSA private key in the slot by its uid. */
    public static CKRSAPrivateKey find(Slot slot, long uid)
    {
        return CKObject.find(slot, CKRSAPrivateKey.class, uid);
    }

    /** Lists the RSA private keys in the slot. */
    public static ArrayList<CKRSAPrivateKey> list(Slot slot)
    {
        return CKObject.list(slot, CKRSAPrivateKey.class, CKO_PRIVATE_KEY, CKK_RSA);
    }

    /** Signs {@code in} with the given mechanism code. */
    public byte[] sign(int mechanism, byte[] in) throws CKException
    {
        return sign(new CK_MECHANISM(mechanism), in);
    }

    /**
     * Signs {@code in} with the given mechanism. The third argument passed to the inherited
     * {@code sign} is the modulus byte length (presumably the expected output size).
     */
    public byte[] sign(CK_MECHANISM mechanism, byte[] in) throws CKException
    {
        return sign(mechanism, in, getBitSize()/8);
    }

    /** Decrypts {@code in} with the given mechanism code. */
    public byte[] decrypt(int mechanism, byte[] in) throws CKException
    {
        // Delegate like sign(int, byte[]) does, so the size logic lives in one place.
        return decrypt(new CK_MECHANISM(mechanism), in);
    }

    /**
     * Decrypts {@code in} with the given mechanism. The third argument passed to the inherited
     * {@code decrypt} is the modulus byte length (presumably the maximum output size).
     */
    public byte[] decrypt(CK_MECHANISM mechanism, byte[] in) throws CKException
    {
        return decrypt(mechanism, in, getBitSize()/8);
    }
}
