package com.dyadicsec.provider;

import com.dyadicsec.pkcs11.*;
import static com.dyadicsec.cryptoki.CK.*;

import javax.crypto.spec.SecretKeySpec;
import java.security.KeyStoreException;
import java.security.NoSuchAlgorithmException;
import java.security.ProviderException;


/**
 * JCE {@link javax.crypto.SecretKey} implementation backed by a PKCS#11 secret key object
 * ({@link CKSecretKey}).
 *
 * <p>An instance is prepared by one of the {@code initFor*} methods and then turned into a
 * PKCS#11 object by the matching {@link DYKey} hook:
 * <ul>
 *   <li>{@link #initForImport} keeps a software copy of the key material, whose raw bytes
 *       {@link #create} passes to {@link CKSecretKey#create}.</li>
 *   <li>{@link #initForGenerate} records the requested size, which {@link #generate} passes to
 *       {@link CKSecretKey#generate}.</li>
 *   <li>{@link #initForUnwrap} records the unwrap context consumed by {@link #unwrap}.</li>
 * </ul>
 * An instance may also wrap an existing {@link CKSecretKey} directly via the one-argument
 * constructor; key type and size are then fetched lazily from that object.
 *
 * <p>Originally created by saar.peer on 29-Jun-16.
 */
public class SecretKey extends DYKey implements javax.crypto.SecretKey
{
    private static final long serialVersionUID = 1L;

    /** Sentinel meaning "key type not specified / not yet known". */
    private static final int UNKNOWN_KEY_TYPE = -1;

    /** Token-side key object; {@code null} until supplied, created, generated or unwrapped. */
    CKSecretKey pkcs11Key = null;
    /** Key size in bits; 0 means "not yet known" and is resolved lazily from {@link #pkcs11Key}. */
    int bitSize = 0;
    /** PKCS#11 key type (CKK_* / DYCKK_* constant), or {@link #UNKNOWN_KEY_TYPE}. */
    private int keyType = UNKNOWN_KEY_TYPE;
    /** Software copy of the key material; set only by {@link #initForImport}. */
    private javax.crypto.SecretKey sw = null;
    /** Parameters converted to a key policy when the token object is created/generated/unwrapped. */
    private KeyParameters keyParams = null;
    /** Unwrap context set by {@link #initForUnwrap}; cleared once {@link #unwrap} succeeds. */
    private UnwrapInfo unwrapInfo = null;

    /** Creates an uninitialized key; one of the {@code initFor*} methods must be called next. */
    SecretKey()
    {
    }

    /**
     * Wraps an existing PKCS#11 secret key object.
     *
     * @param pkcs11Key the token-side key object
     */
    SecretKey(CKSecretKey pkcs11Key)
    {
        this.pkcs11Key = pkcs11Key;
    }

    @Override
    protected CKSecretKey getPkcs11Key() { return pkcs11Key; }

    /**
     * Unwraps the key using the context recorded by {@link #initForUnwrap}.
     *
     * @param alias label passed into the unwrap template
     * @throws KeyStoreException if this key was not prepared for unwrapping (or was already
     *         unwrapped), or if the unwrap itself fails
     */
    @Override
    protected void unwrap(String alias) throws KeyStoreException
    {
        if (unwrapInfo == null) throw new KeyStoreException("Secret key is not initialized for unwrap");
        CK_ATTRIBUTE[] t = CKSecretKey.getUnwrapTemplate(alias, KeyParameters.toPolicy(keyParams), keyType);
        pkcs11Key = unwrapInfo.unwrap(CKSecretKey.class, t);
        unwrapInfo = null; // the context is single-use
    }

    /**
     * Creates a token object from the key material recorded by {@link #initForImport}.
     *
     * @param store key store whose slot receives the new object
     * @param alias label for the new object
     * @throws KeyStoreException if this key was not prepared for import, or creation fails
     */
    @Override
    protected void create(KeyStore store, String alias) throws KeyStoreException
    {
        if (sw == null) throw new KeyStoreException("Secret key is not initialized for import");
        try { pkcs11Key = CKSecretKey.create(store.slot, alias, KeyParameters.toPolicy(keyParams), keyType, sw.getEncoded()); }
        catch (CKException e) { throw new KeyStoreException(e); }
    }

    /**
     * Generates a new token key using the type and size recorded by {@link #initForGenerate}.
     *
     * @param store key store whose slot receives the new object
     * @param alias label for the new object
     * @throws KeyStoreException if generation fails
     */
    @Override
    protected void generate(KeyStore store, String alias) throws KeyStoreException
    {
        try { pkcs11Key = CKSecretKey.generate(store.slot, alias, KeyParameters.toPolicy(keyParams), keyType, bitSize); }
        catch (CKException e) { throw new KeyStoreException(e); }
    }

    /**
     * Prepares this key to be created by {@link #unwrap}.
     *
     * @param unwrapInfo unwrap context
     * @param keyType    PKCS#11 key type of the unwrapped key
     * @return this key
     */
    SecretKey initForUnwrap(UnwrapInfo unwrapInfo, int keyType)
    {
        this.unwrapInfo = unwrapInfo;
        this.keyType = keyType;
        return this;
    }

    /**
     * Prepares this key to be created from software key material by {@link #create}.
     *
     * @param keyParams key policy parameters
     * @param keyType   explicit PKCS#11 key type, or -1 to derive it from the spec's algorithm
     * @param keySpec   key material; its algorithm must be one accepted by {@link #algToKeyType}
     *                  (case-insensitive) or start with "Hmac"
     * @return this key
     * @throws KeyStoreException if the spec's algorithm is not supported
     */
    SecretKey initForImport(KeyParameters keyParams, int keyType, SecretKeySpec keySpec) throws KeyStoreException
    {
        String alg = keySpec.getAlgorithm();
        int algKeyType = importAlgToKeyType(alg);

        // Reject unrecognized algorithms in every case; otherwise a -1 key type could slip
        // through when the caller relies on the spec to determine the type.
        if (algKeyType == UNKNOWN_KEY_TYPE) throw new KeyStoreException("Unsupported key type " + alg);
        if (keyType == UNKNOWN_KEY_TYPE) keyType = algKeyType;

        this.sw = keySpec;
        this.keyType = keyType;
        this.keyParams = keyParams;
        this.bitSize = keySpec.getEncoded().length * 8;
        return this;
    }

    /**
     * Maps a {@link SecretKeySpec} algorithm name to a PKCS#11 key type for import.
     * Accepts every name understood by {@link #algToKeyType}, plus any name starting with
     * "Hmac" (e.g. "HmacSHA256"), which maps to {@code CKK_GENERIC_SECRET} like "Hmac" does.
     *
     * @param alg algorithm name
     * @return the key type, or {@link #UNKNOWN_KEY_TYPE} if unsupported
     */
    private static int importAlgToKeyType(String alg)
    {
        if (alg == null) return UNKNOWN_KEY_TYPE;
        if (alg.regionMatches(true, 0, "Hmac", 0, 4)) return CKK_GENERIC_SECRET;
        try { return algToKeyType(alg); }
        catch (NoSuchAlgorithmException ignored) { return UNKNOWN_KEY_TYPE; }
    }

    /**
     * Prepares this key to be generated on the token by {@link #generate}.
     *
     * @param keyParams  key policy parameters
     * @param keyType    PKCS#11 key type to generate
     * @param genBitSize requested key size in bits
     * @return this key
     */
    SecretKey initForGenerate(KeyParameters keyParams, int keyType, int genBitSize)
    {
        this.keyParams = keyParams;
        bitSize = genBitSize;
        this.keyType = keyType;
        return this;
    }

    /**
     * Returns the key size in bits, fetching it from the token object on first use if unknown.
     *
     * @throws KeyStoreException if the size must be fetched and the query fails
     */
    int getBitSize() throws KeyStoreException
    {
        if (bitSize == 0)
        {
            try { bitSize = pkcs11Key.getBitSize(); }
            catch (CKException e) { throw new KeyStoreException(e); }
        }
        return bitSize;
    }

    /**
     * Returns the PKCS#11 key type, fetching it from the token object on first use if unknown.
     *
     * @throws KeyStoreException if the type must be fetched and the query fails
     */
    int getKeyType() throws KeyStoreException
    {
        if (keyType == UNKNOWN_KEY_TYPE)
        {
            try { keyType = pkcs11Key.getKeyType(); }
            catch (CKException e) { throw new KeyStoreException(e); }
        }
        return keyType;
    }

    /**
     * Best-effort variant of {@link #getKeyType()} for JCE accessors that cannot throw checked
     * exceptions.
     *
     * @return the key type, or {@link #UNKNOWN_KEY_TYPE} if it is unknown and cannot be fetched
     */
    private int resolveKeyTypeQuietly()
    {
        if (keyType == UNKNOWN_KEY_TYPE && pkcs11Key != null)
        {
            try { return getKeyType(); }
            catch (KeyStoreException ignored) { return UNKNOWN_KEY_TYPE; } // reported as "Unknown"
        }
        return keyType;
    }

    /**
     * Returns the JCE algorithm name for this key's type, or "Unknown" if the type is not
     * recognized or cannot be determined.
     */
    @Override
    public String getAlgorithm()
    {
        switch (resolveKeyTypeQuietly())
        {
            case CKK_AES: return "AES";
            case CKK_DES3: return "DESede";
            case CKK_GENERIC_SECRET: return "Hmac";
            case DYCKK_AES_XTS: return "AESXTS";
            case DYCKK_AES_SIV: return "AESSIV";
        }
        return "Unknown";
    }

    /** Returns the software key's format if one was imported, otherwise "RAW". */
    @Override
    public String getFormat()
    {
        if (sw != null) return sw.getFormat();
        return "RAW";
    }

    /**
     * Returns the key bytes: from the imported software key if present, otherwise read from the
     * token object. Returns {@code null} if neither is available.
     *
     * @throws ProviderException if reading the value from the token object fails
     */
    @Override
    public byte[] getEncoded()
    {
        if (sw != null) return sw.getEncoded();
        if (pkcs11Key != null)
        {
            try { return pkcs11Key.getValue(); }
            catch (CKException e) { throw new ProviderException(e); }
        }
        return null;
    }

    /**
     * Maps a JCE algorithm name (case-insensitive) to a PKCS#11 key type.
     *
     * @param alg one of "AES", "DESede", "Hmac", "AESXTS", "AESSIV"
     * @return the matching CKK_* / DYCKK_* constant
     * @throws NoSuchAlgorithmException if the name is not supported
     */
    static int algToKeyType(String alg) throws NoSuchAlgorithmException
    {
        if (alg.equalsIgnoreCase("AES")) return CKK_AES;
        if (alg.equalsIgnoreCase("DESede")) return CKK_DES3;
        if (alg.equalsIgnoreCase("Hmac")) return CKK_GENERIC_SECRET;
        if (alg.equalsIgnoreCase("AESXTS")) return DYCKK_AES_XTS;
        if (alg.equalsIgnoreCase("AESSIV")) return DYCKK_AES_SIV;
        throw new NoSuchAlgorithmException("Unsupported algorithm " + alg);
    }

}
