package com.dyadicsec.pkcs11;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Map;
import static com.dyadicsec.cryptoki.CK.*;

/**
 * PKCS#11 wrapper for a PRF key ({@code DYCKK_ADV_PRF}, stored as {@code CKO_PRIVATE_KEY}).
 *
 * <p>Besides lookup and generation, the class offers:
 * <ul>
 *   <li>{@link #derive(int, byte[], int, int)} / {@link #prf(int, byte[], int)} - derive a session
 *       secret key (or its raw bytes) from a purpose and a tweak;</li>
 *   <li>{@link #encrypt(byte[], byte[])} / {@link #decrypt(byte[], byte[])} - authenticated
 *       encryption with a per-message key derived from a random tweak.</li>
 * </ul>
 *
 * <p>NOTE(review): the per-message key/IV expansion previously wrote the counter values 1 and 2 into
 * the same byte, so the effective counter blocks were {@code 0, 2, 0} and the IV block was identical
 * to the first key block. The expansion now uses counters {@code 0, 1, 2}. Data produced with the old
 * expansion cannot be decrypted with the corrected one - confirm against any peer implementation and
 * any stored ciphertext before deploying.
 *
 * Created by valery.osheter on 25-Jun-17.
 */
public final class CKPRFKey extends CKKey
{
    /** Creates an empty wrapper preset to the PRF key type and the private-key object class. */
    protected CKPRFKey()
    {
        keyType = DYCKK_ADV_PRF;
        clazz = CKO_PRIVATE_KEY;
    }

    /**
     * Adds the attributes this key type reads, on top of those requested by the superclass.
     *
     * @param template attribute map to be filled; {@code CKA_DECRYPT} is added to it
     */
    void prepareReadTemplate(Map<Integer, CK_ATTRIBUTE> template)
    {
        super.prepareReadTemplate(template);
        addReadTemplate(template, CKA_DECRYPT);
    }

    /**
     * Stores the attribute values that were read into {@code policy}.
     *
     * <p>Only {@code CKA_DECRYPT} is taken from the template here; the remaining flags set below are
     * fixed values for this key type.
     *
     * @param template attribute map previously prepared by {@link #prepareReadTemplate(Map)}
     * @throws CKException if the superclass or attribute conversion fails
     */
    void saveReadTemplate(Map<Integer, CK_ATTRIBUTE> template) throws CKException
    {
        super.saveReadTemplate(template);
        policy.cka_decrypt = template.get(CKA_DECRYPT).toBool();
        policy.cka_encrypt = policy.cka_private = policy.cka_sensitive = true;
        policy.cka_sign = policy.cka_verify = policy.cka_wrap = policy.cka_unwrap = policy.cka_extractable = false;
    }

    /**
     * Looks up a PRF key by name.
     *
     * @param slot slot to search
     * @param name key name
     * @return whatever {@code CKObject.find} returns for the given class, key type and name
     */
    public static CKPRFKey find(Slot slot, String name)
    {
        return (CKPRFKey) CKObject.find(slot, CKO_PRIVATE_KEY, DYCKK_ADV_PRF, name);
    }

    /**
     * Looks up a PRF key by UID.
     *
     * @param slot slot to search
     * @param uid  key UID
     * @return whatever {@code CKObject.find} returns for this class and UID
     */
    public static CKPRFKey find(Slot slot, long uid)
    {
        return CKObject.find(slot, CKPRFKey.class, uid);
    }

    /**
     * Lists the PRF keys of a slot.
     *
     * @param slot slot to enumerate
     * @return the list produced by {@code CKObject.list} for the PRF key class and type
     */
    public static ArrayList<CKPRFKey> list(Slot slot)
    {
        return CKObject.list(slot, CKPRFKey.class, CKO_PRIVATE_KEY, DYCKK_ADV_PRF);
    }

    /**
     * Generates a new PRF key with the {@code DYCKM_PRF_KEY_GEN} mechanism.
     *
     * <p>Only the {@code cka_token}, {@code cka_decrypt} and {@code cka_derive} flags of the policy are
     * sent in the generation template; the key id is computed from {@code name}.
     *
     * @param slot   slot in which to generate the key
     * @param name   key name, converted to {@code CKA_ID} via {@code Utils.name2id}
     * @param policy key policy; when {@code null} a default {@code Policy} is used
     * @return the wrapper of the generated key, with {@code policy} attached
     * @throws CKException if key generation fails
     */
    public static CKPRFKey generate(Slot slot, String name, Policy policy) throws CKException
    {
        CKPRFKey key = new CKPRFKey();
        if (policy==null) policy = new Policy();

        CK_ATTRIBUTE[] t =
                {
                        new CK_ATTRIBUTE(CKA_TOKEN, policy.cka_token),
                        new CK_ATTRIBUTE(CKA_CLASS, CKO_PRIVATE_KEY),
                        new CK_ATTRIBUTE(CKA_KEY_TYPE, DYCKK_ADV_PRF),
                        new CK_ATTRIBUTE(CKA_DECRYPT, policy.cka_decrypt),
                        new CK_ATTRIBUTE(CKA_DERIVE, policy.cka_derive),
                        new CK_ATTRIBUTE(CKA_ID, Utils.name2id(name)),
                };

        key.generateKey(slot, DYCKM_PRF_KEY_GEN, t);
        key.policy = policy;
        return key;
    }

    /**
     * Computes PRF output bytes with purpose {@code 0}.
     *
     * @param tweak      PRF tweak
     * @param secretSize number of output bytes
     * @return the derived bytes
     * @throws CKException if derivation or value extraction fails
     */
    public byte[] prf(byte[] tweak, int secretSize) throws CKException
    {
        return prf(0, tweak, secretSize);
    }

    /**
     * Derives a non-token, non-sensitive secret key from this PRF key.
     *
     * @param purpose purpose value passed to {@code DYCK_PRF_PARAMS}
     * @param tweak   tweak passed to {@code DYCK_PRF_PARAMS}
     * @param keyType PKCS#11 key type of the derived key
     * @param bitSize size of the derived key in bits; converted to bytes with integer division by 8
     * @return the derived secret key
     * @throws CKException if derivation fails
     */
    public CKSecretKey derive(int purpose, byte[] tweak, int keyType, int bitSize) throws CKException
    {
        CK_ATTRIBUTE[] t =
                {
                        new CK_ATTRIBUTE(CKA_TOKEN, false),
                        new CK_ATTRIBUTE(CKA_CLASS, CKO_SECRET_KEY),
                        new CK_ATTRIBUTE(CKA_KEY_TYPE, keyType),
                        new CK_ATTRIBUTE(CKA_SENSITIVE, false),
                        new CK_ATTRIBUTE(CKA_VALUE_LEN, bitSize/8),
                };

        CK_MECHANISM m = new DYCK_PRF_PARAMS(purpose, tweak, bitSize/8);
        return derive(CKSecretKey.class, m, t);
    }

    /**
     * Computes PRF output bytes by deriving a temporary generic secret key, reading its value and
     * destroying it.
     *
     * @param purpose purpose value passed to the derivation
     * @param tweak   PRF tweak
     * @param outLen  number of output bytes
     * @return the value of the temporary derived key
     * @throws CKException if derivation, value extraction or key destruction fails
     */
    public byte[] prf(int purpose, byte[] tweak, int outLen) throws CKException
    {
        CKSecretKey temp = derive(purpose, tweak, CKK_GENERIC_SECRET, outLen*8);
        try { return temp.getValue(); }
        finally { temp.destroy(); }
    }


    /** Length in bytes of the random tweak that prefixes every ciphertext. */
    static final int PRF_TWEAK_LEN = 16;
    /** Length in bytes of the GCM authentication tag. */
    static final int PRF_GCM_TAG_LEN = 12;
    /** AES block length in bytes. */
    static final int AES_BLOCK_LEN = 16;

    /** Number of counter blocks expanded per message: two for the content key, one for the IV. */
    private static final int EXPANSION_BLOCKS = 3;

    /**
     * Expands the derived master key into per-message key material.
     *
     * <p>Encrypts {@value #EXPANSION_BLOCKS} AES blocks with {@code CKM_AES_ECB}; block {@code i} is all
     * zero except for its last byte, which holds the counter {@code i} (0, 1, 2). Callers use the first
     * two output blocks as the content key and the third as the GCM IV.
     *
     * @param masterKey key derived from the PRF key for the message tweak
     * @return the ECB output, {@code EXPANSION_BLOCKS * AES_BLOCK_LEN} bytes as consumed by the callers
     * @throws CKException if the ECB encryption fails
     */
    private static byte[] expandKeyMaterial(CKSecretKey masterKey) throws CKException
    {
        byte[] blocks = new byte[AES_BLOCK_LEN*EXPANSION_BLOCKS]; // zero
        for (int i = 0; i < EXPANSION_BLOCKS; i++)
        {
            // Distinct counter in the last byte of each block, so every output block is different.
            blocks[AES_BLOCK_LEN*(i+1)-1] = (byte) i;
        }
        return masterKey.encrypt(new CK_MECHANISM(CKM_AES_ECB), blocks, blocks.length);
    }

    /**
     * Destroys the temporary keys and releases the session, attempting every step even when an
     * earlier one throws. The order (master key, content key, session) matches the historic cleanup.
     *
     * @param masterKey  derived master key; never {@code null}
     * @param contentKey per-message content key, or {@code null} if it was not created
     * @param session    session to release; passed to {@code slot.releaseSession} as is
     * @throws CKException if any of the cleanup steps fails
     */
    private void cleanup(CKSecretKey masterKey, CKSecretKey contentKey, Session session) throws CKException
    {
        try
        {
            masterKey.destroy();
        }
        finally
        {
            try
            {
                if (contentKey!=null) contentKey.destroy();
            }
            finally
            {
                slot.releaseSession(session);
            }
        }
    }

    /**
     * Encrypts data with AES-GCM under a per-message key derived from this PRF key.
     *
     * <p>A random {@value #PRF_TWEAK_LEN}-byte tweak is generated, a 256-bit AES master key is derived
     * from it, and the master key is expanded into a content key and an IV. The result is the tweak
     * followed by the GCM output written by the session, sized as
     * {@code PRF_TWEAK_LEN + in.length + PRF_GCM_TAG_LEN} bytes.
     *
     * @param aad additional authenticated data passed to {@code CK_GCM_PARAMS}
     * @param in  plaintext
     * @return tweak followed by the GCM output
     * @throws CKException if any cryptographic operation or the cleanup fails
     */
    public byte[] encrypt(byte[] aad, byte[] in) throws CKException
    {
        byte[] tweak = slot.generateRandom(PRF_TWEAK_LEN);

        CKSecretKey KM = derive(0, tweak, CKK_AES, 256);
        CKSecretKey K = null;
        Session session = null;
        try
        {
            byte[] e = expandKeyMaterial(KM);
            byte[] KValue = Arrays.copyOfRange(e, 0, AES_BLOCK_LEN*2);
            byte[] IV = Arrays.copyOfRange(e, AES_BLOCK_LEN*2, AES_BLOCK_LEN*3);

            K = CKSecretKey.create(slot, null, new Policy().setToken(false), CKK_AES, KValue);

            byte[] out = new byte[PRF_TWEAK_LEN+in.length+PRF_GCM_TAG_LEN];
            System.arraycopy(tweak, 0, out, 0, PRF_TWEAK_LEN);

            CK_MECHANISM m = new CK_GCM_PARAMS(IV, aad, PRF_GCM_TAG_LEN*8);
            session = K.encryptInit(m);
            session.encrypt(in, 0, in.length, out, PRF_TWEAK_LEN);
            return out;
        }
        finally
        {
            cleanup(KM, K, session);
        }
    }

    /**
     * Decrypts data produced by {@link #encrypt(byte[], byte[])}.
     *
     * <p>The first {@value #PRF_TWEAK_LEN} bytes of {@code in} are taken as the tweak; the same key/IV
     * expansion as in encryption is applied and the remainder is decrypted with AES-GCM.
     *
     * @param aad additional authenticated data passed to {@code CK_GCM_PARAMS}
     * @param in  tweak followed by the GCM output
     * @return the plaintext, {@code in.length - PRF_TWEAK_LEN - PRF_GCM_TAG_LEN} bytes
     * @throws CKException with {@code CKR_ENCRYPTED_DATA_LEN_RANGE} if {@code in} is shorter than
     *                     tweak plus tag, or if any cryptographic operation or the cleanup fails
     */
    public byte[] decrypt(byte[] aad, byte[] in) throws CKException
    {
        int outLen = in.length - (PRF_TWEAK_LEN+PRF_GCM_TAG_LEN);
        if (outLen<0) throw new CKException("Decrypt using PRF", CKR_ENCRYPTED_DATA_LEN_RANGE);
        byte[] tweak = Arrays.copyOfRange(in, 0, PRF_TWEAK_LEN);

        CKSecretKey KM = derive(0, tweak, CKK_AES, 256);
        CKSecretKey K = null;
        Session session = null;
        try
        {
            byte[] e = expandKeyMaterial(KM);
            byte[] KValue = Arrays.copyOfRange(e, 0, AES_BLOCK_LEN*2);
            byte[] IV = Arrays.copyOfRange(e, AES_BLOCK_LEN*2, AES_BLOCK_LEN*3);

            K = CKSecretKey.create(slot, null, new Policy().setToken(false), CKK_AES, KValue);

            byte[] out = new byte[outLen];
            CK_MECHANISM m = new CK_GCM_PARAMS(IV, aad, PRF_GCM_TAG_LEN*8);
            session = K.decryptInit(m);
            session.decrypt(in, PRF_TWEAK_LEN, in.length - PRF_TWEAK_LEN, out, 0);
            return out;
        }
        finally
        {
            cleanup(KM, K, session);
        }
    }


    /**
     * Re-keys this key by delegating to the superclass and narrowing the result type.
     *
     * @return the result of {@code super.rekey()} cast to {@code CKPRFKey}
     * @throws CKException if the superclass operation fails
     */
    public CKPRFKey rekey() throws CKException
    {
        return (CKPRFKey)super.rekey();
    }

}
