package com.dyadicsec.pkcs11;

import java.io.IOException;
import java.math.BigInteger;
import java.security.spec.ECPoint;
import java.util.ArrayList;
import java.util.Map;
import java.util.Objects;

import static com.dyadicsec.cryptoki.CK.*;

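/**
 * PKCS#11 elliptic-curve private key object ({@code CKO_PRIVATE_KEY} / {@code CKK_EC}) on a {@code Slot}.
 *
 * Created by valery.osheter on 22-Jun-17.
 */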
public class CKECPrivateKey extends CKPrivateKey
{
    /** Curve of this key; set by generate/create or loaded lazily through read(). */
    ECCurve curve = null;
    /** Point decoded from CKA_EC_POINT; loaded lazily through read(). */
    ECPoint point = null;
    /** Public-key object created on first call to getPublicKey() and cached afterwards. */
    CKECPublicKey pubKey = null;

    CKECPrivateKey()
    {
        keyType = CKK_EC;
    }

    /**
     * Requests the EC-specific attributes (CKA_EC_PARAMS, CKA_EC_POINT) in addition to
     * those requested by the superclass.
     */
    void prepareReadTemplate(Map<Integer, CK_ATTRIBUTE> template)
    {
        super.prepareReadTemplate(template);
        addReadTemplate(template, CKA_EC_PARAMS);
        addReadTemplate(template, CKA_EC_POINT);
    }

    /**
     * Decodes the values returned for the read template into {@code curve} and {@code point}.
     *
     * @throws CKException if an EC attribute is absent from the template, the curve is not
     *                     supported, or the point cannot be DER-decoded
     */
    void saveReadTemplate(Map<Integer, CK_ATTRIBUTE> template) throws CKException
    {
        super.saveReadTemplate(template);

        // Report a missing attribute through the declared exception type instead of an NPE.
        CK_ATTRIBUTE params = template.get(CKA_EC_PARAMS);
        if (params == null) throw new CKException("Missing CKA_EC_PARAMS", 0);

        curve = ECCurve.find(params.getValue());
        if (curve == null) throw new CKException("Unsupported EC curve", 0);

        CK_ATTRIBUTE ecPoint = template.get(CKA_EC_POINT);
        if (ecPoint == null) throw new CKException("Missing CKA_EC_POINT", 0);

        try { point = curve.derDecodePoint(ecPoint.getValue()); }
        catch (IOException e) { throw new CKException(e, "Can't decode ECPoint", 0); }
    }

    /**
     * Builds the private-key template shared by unwrap, generate and create.
     * The common policy-driven attributes come first, then {@code extra} in the given order,
     * and CKA_ID is last. CKA_DERIVE is true only when the policy allows derive and does not
     * allow sign.
     *
     * @param name   key name, stored as CKA_ID via Utils.name2id
     * @param policy non-null key policy
     * @param extra  additional attributes placed before CKA_ID
     * @return the assembled template
     */
    private static CK_ATTRIBUTE[] privateKeyTemplate(String name, Policy policy, CK_ATTRIBUTE... extra)
    {
        ArrayList<CK_ATTRIBUTE> t = new ArrayList<CK_ATTRIBUTE>();
        t.add(new CK_ATTRIBUTE(CKA_TOKEN, policy.cka_token));
        t.add(new CK_ATTRIBUTE(CKA_CLASS, CKO_PRIVATE_KEY));
        t.add(new CK_ATTRIBUTE(CKA_KEY_TYPE, CKK_EC));
        t.add(new CK_ATTRIBUTE(CKA_EXTRACTABLE, policy.cka_extractable));
        t.add(new CK_ATTRIBUTE(CKA_SIGN, policy.cka_sign));
        t.add(new CK_ATTRIBUTE(CKA_DERIVE, policy.cka_derive && !policy.cka_sign));
        for (CK_ATTRIBUTE a : extra) t.add(a);
        t.add(new CK_ATTRIBUTE(CKA_ID, Utils.name2id(name)));
        return t.toArray(new CK_ATTRIBUTE[t.size()]);
    }

    /**
     * Returns the attribute template to use when unwrapping an EC private key.
     *
     * @param name   key name, stored as CKA_ID
     * @param policy key policy; a default {@code new Policy()} is used when null
     * @return the private-key template
     */
    public static CK_ATTRIBUTE[] getUnwrapTemplate(String name, Policy policy)
    {
        if (policy==null) policy = new Policy();
        return privateKeyTemplate(name, policy);
    }

    /**
     * Generates an EC key pair on the slot with CKM_EC_KEY_PAIR_GEN.
     * The public half is requested as a session object (CKA_TOKEN=false).
     *
     * @param slot   slot to generate on
     * @param name   key name, stored as CKA_ID
     * @param policy key policy; a default {@code new Policy()} is used when null
     * @param curve  curve to generate on; must not be null
     * @return the private-key wrapper with curve, policy and name set
     * @throws CKException if key-pair generation fails
     */
    public static CKECPrivateKey generate(Slot slot, String name, Policy policy, ECCurve curve) throws CKException
    {
        Objects.requireNonNull(curve, "curve");
        CKECPrivateKey key = new CKECPrivateKey();
        if (policy==null) policy = new Policy();

        CK_ATTRIBUTE[] tPrv = privateKeyTemplate(name, policy);
        CK_ATTRIBUTE[] tPub =
                {
                        new CK_ATTRIBUTE(CKA_TOKEN, false),
                        new CK_ATTRIBUTE(CKA_CLASS, CKO_PUBLIC_KEY),
                        new CK_ATTRIBUTE(CKA_KEY_TYPE, CKK_EC),
                        new CK_ATTRIBUTE(CKA_EC_PARAMS, curve.getOidBin()),
                };

        key.generateKeyPair(slot, CKM_EC_KEY_PAIR_GEN, tPub, tPrv);
        key.curve = curve;
        key.policy = policy;
        key.name = name;
        return key;
    }

    /**
     * Creates (imports) an EC private key with the given scalar value on the slot.
     *
     * @param slot   slot to create the object on
     * @param name   key name, stored as CKA_ID
     * @param policy key policy; a default {@code new Policy()} is used when null
     * @param curve  curve of the key; must not be null
     * @param x      private scalar, encoded as CKA_VALUE padded to the curve size
     * @return the private-key wrapper with curve, policy and name set
     * @throws CKException if object creation fails
     */
    public static CKECPrivateKey create(Slot slot, String name, Policy policy, ECCurve curve, BigInteger x) throws CKException
    {
        Objects.requireNonNull(curve, "curve");
        CKECPrivateKey key = new CKECPrivateKey();
        if (policy==null) policy = new Policy();

        CK_ATTRIBUTE[] t = privateKeyTemplate(name, policy,
                new CK_ATTRIBUTE(CKA_EC_PARAMS, curve.getOidBin()),
                new CK_ATTRIBUTE(CKA_VALUE, Utils.bigInt2Bytes(x, curve.getSize())));

        key.create(slot, t);
        key.curve = curve;
        key.policy = policy;
        key.name = name;
        return key;
    }

    /** Returns the key's curve, reading the object's attributes first if it is not yet known. */
    public ECCurve getCurve() throws CKException
    {
        if (curve==null) read();
        return curve;
    }

    /** Returns the key's EC point (CKA_EC_POINT), reading the object's attributes first if needed. */
    public ECPoint getPoint() throws CKException
    {
        if (point==null) read();
        return point;
    }

    /**
     * Computes a raw ECDH shared secret with a peer point.
     * A temporary (CKA_TOKEN=false), non-sensitive generic-secret key of curve-size length is
     * derived with CK_ECDH1_DERIVE_PARAMS and CKD_NULL. Its value is read, and the derived
     * object is always destroyed afterwards.
     *
     * @param point the peer's public point; encoded with this key's curve
     * @return the CKA_VALUE of the derived secret
     * @throws CKException if derivation or attribute retrieval fails
     */
    public byte[] ecdh(ECPoint point) throws CKException
    {
        // Use one curve reference for both encoding and length, rather than mixing the getter and the field.
        ECCurve ecCurve = getCurve();
        byte[] encodedPoint = ecCurve.derEncodePoint(point);

        CK_ATTRIBUTE[] t =
                {
                        new CK_ATTRIBUTE(CKA_TOKEN, false),
                        new CK_ATTRIBUTE(CKA_CLASS, CKO_SECRET_KEY),
                        new CK_ATTRIBUTE(CKA_KEY_TYPE, CKK_GENERIC_SECRET),
                        new CK_ATTRIBUTE(CKA_SENSITIVE, false),
                        new CK_ATTRIBUTE(CKA_VALUE_LEN, ecCurve.getSize()),
                };

        CK_MECHANISM m = new CK_ECDH1_DERIVE_PARAMS(CKD_NULL, encodedPoint, null);

        int derivedHandle = slot.deriveKey(m, handle, t);
        CK_ATTRIBUTE[] v = { new CK_ATTRIBUTE(CKA_VALUE) };
        try
        {
            slot.getAttributeValue(derivedHandle, v);
        }
        finally
        {
            // Do not leave the derived secret object behind, even on failure.
            slot.destroyObject(derivedHandle);
        }

        return v[0].getValue();
    }

    /**
     * Signs {@code data} with CKM_ECDSA. Per PKCS#11, CKM_ECDSA signs its input as given,
     * so callers presumably pass an already-computed hash.
     */
    public byte[] sign(byte[] data) throws CKException
    {
        return sign(CKM_ECDSA, data);
    }

    /**
     * Signs {@code data} with the given mechanism, allowing an output of twice the curve size.
     */
    public byte[] sign(int mechanism, byte[] data) throws CKException
    {
        return sign(new CK_MECHANISM(mechanism), data, getCurve().getSize()*2);
    }

    /**
     * Returns a public-key object for this key's curve and point.
     * It is created with CKECPublicKey.create(slot, null, null, ...) on the first call and cached.
     */
    public CKECPublicKey getPublicKey() throws CKException
    {
        if (pubKey==null) pubKey = CKECPublicKey.create(slot, null, null, getCurve(), getPoint());
        return pubKey;
    }

    /** Finds an EC private key on the slot by name. */
    public static CKECPrivateKey find(Slot slot, String name)
    {
        return (CKECPrivateKey) CKObject.find(slot, CKO_PRIVATE_KEY, CKK_EC, name);
    }

    /** Finds an EC private key on the slot by its uid. */
    public static CKECPrivateKey find(Slot slot, long uid)
    {
        return CKObject.find(slot, CKECPrivateKey.class, uid);
    }

    /** Lists the EC private keys (CKO_PRIVATE_KEY / CKK_EC) on the slot. */
    public static ArrayList<CKECPrivateKey> list(Slot slot)
    {
        return CKObject.list(slot, CKECPrivateKey.class, CKO_PRIVATE_KEY, CKK_EC);
    }


}
