package com.dyadicsec.provider;

import com.dyadicsec.pkcs11.*;
import static com.dyadicsec.cryptoki.CK.*;

import javax.crypto.MacSpi;
import javax.crypto.spec.IvParameterSpec;
import java.security.Key;
import java.security.InvalidKeyException;
import java.security.InvalidAlgorithmParameterException;
import java.security.KeyStoreException;
import java.security.ProviderException;
import java.security.spec.AlgorithmParameterSpec;


/**
 * JCE {@link MacSpi} that computes HMAC (SHA-1/256/384/512), AES-CMAC and AES-GMAC values with a
 * PKCS#11 sign operation ({@code signInit} / {@code signUpdate} / {@code signFinal}) on the key's
 * {@code pkcs11Key}.
 * <p>
 * A sign session is opened by {@link #engineInit} and closed by {@link #engineReset}, which
 * {@link #engineDoFinal} always calls. After a reset, the next update or doFinal lazily re-opens a
 * session with the most recently committed key and parameters.
 * <p>
 * Instances keep unsynchronized mutable state and must not be used from several threads at once.
 * <p>
 * Created by valery.osheter on 14-Mar-17.
 */
public class Mac extends MacSpi {

    /** Key committed by the most recent {@link #engineInit}; reused to re-open a session after a reset. */
    private com.dyadicsec.provider.SecretKey secretKey = null;
    /** PKCS#11 {@code CKM_*} mechanism computed by this instance. */
    private final int mechanism;
    /** Open sign session, or {@code null} when no session is currently open. */
    private Session session = null;
    /** Parameters committed by the most recent {@link #engineInit} (an IV spec for GMAC, otherwise null). */
    private AlgorithmParameterSpec paramSpec = null;
    /** Scratch buffer for single-byte updates, allocated on first use. */
    private byte[] oneByte = null;

    /**
     * @param mechanism PKCS#11 {@code CKM_*} MAC mechanism; expected to be one of the values handled
     *                  by {@link #engineGetMacLength} and {@link #engineInit}
     */
    Mac(int mechanism)
    {
        this.mechanism = mechanism;
    }

    /**
     * Returns the MAC length in bytes for this instance's mechanism.
     *
     * @return 20/32/48/64 for the HMAC variants, 16 for CMAC and GMAC, 0 for any other mechanism
     */
    @Override
    protected int engineGetMacLength() {
        switch (mechanism)
        {
            case CKM_SHA_1_HMAC : return 20;
            case CKM_SHA256_HMAC : return 32;
            case CKM_SHA384_HMAC : return 48;
            case CKM_SHA512_HMAC : return 64;
            case CKM_AES_CMAC : return 16;
            case CKM_AES_GMAC : return 16;
        }
        return 0;
    }

    /**
     * Validates {@code key} and {@code paramSpec}, then opens a new PKCS#11 sign session.
     * <p>
     * All validation runs before any state is touched, so a rejected key or parameter leaves this
     * instance exactly as it was. Once validation passes, any session left open by a previous
     * initialization is closed, and the new key and parameters are committed before the token is
     * asked to start signing.
     *
     * @param key       a {@code com.dyadicsec.provider.SecretKey}: key type {@code CKK_GENERIC_SECRET}
     *                  for HMAC, {@code CKK_AES} for CMAC and GMAC
     * @param paramSpec {@code null}, or (GMAC only) an {@link IvParameterSpec}
     * @throws InvalidKeyException if the key is null, of the wrong class or key type, cannot be
     *         saved, or the token rejects it when starting the sign operation
     * @throws InvalidAlgorithmParameterException if {@code paramSpec} is not accepted for this
     *         mechanism, or the token reports {@code CKR_ARGUMENTS_BAD},
     *         {@code CKR_MECHANISM_INVALID} or {@code CKR_MECHANISM_PARAM_INVALID}
     */
    @Override
    protected void engineInit(Key key, AlgorithmParameterSpec paramSpec) throws InvalidKeyException, InvalidAlgorithmParameterException
    {
        if (key==null) throw new InvalidKeyException("Invalid key");

        if (mechanism== CKM_AES_GMAC)
        {
            if (paramSpec!=null && !(paramSpec instanceof IvParameterSpec)) throw new InvalidAlgorithmParameterException("IvParameterSpec required");
        }
        else
        {
            if (paramSpec!=null) throw new InvalidAlgorithmParameterException("Parameters not supported");
        }

        if (!(key instanceof SecretKey)) throw new InvalidKeyException("Invalid key type");
        SecretKey newKey = (SecretKey)key;

        int keyType;
        try
        {
            newKey.save();
            keyType = newKey.getKeyType();
        }
        catch (KeyStoreException e) { throw new InvalidKeyException(e); }

        CK_MECHANISM m;
        switch (mechanism)
        {
            case CKM_SHA_1_HMAC :
            case CKM_SHA256_HMAC :
            case CKM_SHA384_HMAC :
            case CKM_SHA512_HMAC :
                if (keyType != CKK_GENERIC_SECRET) throw new InvalidKeyException("Invalid key type");
                m = new CK_MECHANISM(mechanism);
                break;

            case CKM_AES_CMAC :
                if (keyType != CKK_AES) throw new InvalidKeyException("Invalid key type");
                m = new CK_MECHANISM(mechanism);
                break;

            case CKM_AES_GMAC :
                if (keyType != CKK_AES) throw new InvalidKeyException("Invalid key type");
                // NOTE(review): without an IvParameterSpec a fixed all-zero 12-byte IV is used, and the
                // same IV is reused each time ensureInit() re-opens the session after a reset. GMAC
                // must not reuse an IV under the same key (NIST SP 800-38D) - confirm callers supply one.
                m = new CK_MECHANISM(mechanism, paramSpec==null ? new byte[12] : ((IvParameterSpec)paramSpec).getIV());
                break;

            default:
                // Not reachable through the nested subclasses below; guards against a mechanism being
                // added to the constructor without a matching case here.
                throw new ProviderException("Unsupported MAC mechanism: " + mechanism);
        }

        // Validation passed: close any session still open from a previous init or an unfinished
        // message (otherwise it would leak), then commit the new key/params for ensureInit().
        engineReset();
        secretKey = newKey;
        this.paramSpec = paramSpec;

        try { session = newKey.pkcs11Key.signInit(m); }
        catch (CKException e)
        {
            // session is still null here, so nothing is left open on the token.
            if (e.getRV()== CKR_ARGUMENTS_BAD ||
                    e.getRV()== CKR_MECHANISM_INVALID ||
                    e.getRV()== CKR_MECHANISM_PARAM_INVALID) throw new InvalidAlgorithmParameterException(e);
            throw new InvalidKeyException(e);
        }
    }

    /**
     * Feeds a single byte to the MAC via a reused one-byte buffer.
     */
    @Override
    protected void engineUpdate(byte b)
    {
        if (oneByte == null) oneByte = new byte[1];
        oneByte[0] = b;
        engineUpdate(oneByte, 0, 1);
    }

    /**
     * Re-opens a sign session with the committed key and parameters if none is open (e.g. after
     * {@link #engineReset} or {@link #engineDoFinal}). Initialization failures are rethrown as
     * {@link ProviderException} because the update/doFinal engine methods cannot throw checked
     * exceptions.
     */
    private void ensureInit()
    {
        if (session==null)
        {
            try { engineInit(secretKey, paramSpec); }
            catch (InvalidAlgorithmParameterException e) { throw new ProviderException(e); }
            catch (InvalidKeyException e) { throw new ProviderException(e); }
        }
    }

    /**
     * Feeds {@code len} bytes of {@code b}, starting at {@code ofs}, to the token's sign operation.
     *
     * @throws ProviderException if the session cannot be (re)opened or the token reports an error
     */
    @Override
    protected void engineUpdate(byte[] b, int ofs, int len)
    {
        ensureInit();
        try { session.signUpdate(b, ofs, len); }
        catch (CKException e) { throw new ProviderException(e); }
    }

    /**
     * Finishes the sign operation and returns the MAC. The session is always closed afterwards,
     * whether or not the token call succeeded.
     *
     * @return the MAC bytes produced by {@code signFinal}, requested with length
     *         {@link #engineGetMacLength()}
     * @throws ProviderException if the session cannot be (re)opened or the token reports an error
     */
    @Override
    protected byte[] engineDoFinal()
    {
        ensureInit();

        int outLen = engineGetMacLength();
        try { return session.signFinal(outLen); }
        catch (CKException e) { throw new ProviderException(e); }
        finally { engineReset(); }
    }

    /**
     * Closes the current sign session, if any. The committed key and parameters are kept so the
     * next update or doFinal can lazily re-open a session.
     */
    @Override
    protected void engineReset()
    {
        // Clear the field first so this instance never holds on to a closed session, even if
        // close() throws.
        Session current = session;
        session = null;
        if (current!=null) current.close();
    }

    /** HMAC with SHA-1 ({@code CKM_SHA_1_HMAC}). */
    public static final class HmacSHA1 extends Mac
    {
        public HmacSHA1() { super(CKM_SHA_1_HMAC); }
    }

    /** HMAC with SHA-256 ({@code CKM_SHA256_HMAC}). */
    public static final class HmacSHA256 extends Mac
    {
        public HmacSHA256() { super(CKM_SHA256_HMAC); }
    }

    /** HMAC with SHA-384 ({@code CKM_SHA384_HMAC}). */
    public static final class HmacSHA384 extends Mac
    {
        public HmacSHA384() { super(CKM_SHA384_HMAC); }
    }

    /** HMAC with SHA-512 ({@code CKM_SHA512_HMAC}). */
    public static final class HmacSHA512 extends Mac
    {
        public HmacSHA512() { super(CKM_SHA512_HMAC); }
    }

    /** AES-CMAC ({@code CKM_AES_CMAC}). */
    public static final class CMAC extends Mac
    {
        public CMAC() { super(CKM_AES_CMAC); }
    }

    /** AES-GMAC ({@code CKM_AES_GMAC}); accepts an optional {@link IvParameterSpec}. */
    public static final class GMAC extends Mac
    {
        public GMAC() { super(CKM_AES_GMAC); }
    }
}
