package com.dyadicsec.provider;

import com.dyadicsec.pkcs11.*;
import static com.dyadicsec.cryptoki.CK.*;

import javax.crypto.*;
import javax.crypto.spec.GCMParameterSpec;
import javax.crypto.spec.IvParameterSpec;
import java.security.InvalidKeyException;
import java.security.ProviderException;
import java.security.NoSuchAlgorithmException;
import java.security.InvalidAlgorithmParameterException;
import java.security.AlgorithmParameters;
import java.security.SecureRandom;
import java.security.GeneralSecurityException;
import java.security.KeyStoreException;
import java.security.Key;
import java.security.spec.AlgorithmParameterSpec;
import java.security.spec.InvalidParameterSpecException;
import java.util.Arrays;
import java.util.LinkedList;


/**
 * Created by saar.peer on 29-Jun-16.
 */
public class SecretKeyCipher extends CipherSpi
{

    // Key handed to the most recent engineInit call.
    private SecretKey secretKey = null;

    // Key type constant this cipher was constructed with; keys of any other type are rejected in engineInit.
    private final int keyType;
    // True when initialised with Cipher.WRAP_MODE or Cipher.UNWRAP_MODE.
    private boolean wrap = false;
    // True when initialised with Cipher.ENCRYPT_MODE or Cipher.WRAP_MODE.
    private boolean encrypt = true;
    // Set for the XTS, SIV and WRAP modes: update calls only accumulate input in 'buffer' and the
    // data is processed in doFinal.
    private boolean singleOp = false;
    // Parameters supplied to (or generated by) engineInit.
    private AlgorithmParameterSpec paramSpec = null;
    // Set for the CCM, GCM and SIV modes; engineUpdateAAD is only accepted when this is true.
    private boolean aad = false;
    // True when PKCS5PADDING was selected in engineSetPadding.
    private boolean padding = false;
    // Set by ensureInit once the mechanism (and, outside wrap mode, the session) is prepared;
    // cleared at the start of engineInit.
    private boolean initialized = false;
    // Input accumulated by updateSingleOp for single-operation modes.
    private byte[] buffer = null;
    // Additional authenticated data accumulated by engineUpdateAAD (non-SIV modes).
    private byte[] auth = null;
    // SIV only: one entry per engineUpdateAAD call.
    private LinkedList<byte[]> siv_headers = null;
    // Session obtained from encryptInit/decryptInit; null when no operation is open.
    private Session session = null;
    // Mechanism built by prepareMechanism for the current mode and parameters.
    private CK_MECHANISM mechanism = null;

    // Selected mode, one of the constants below; stays 0 until engineSetMode is called.
    private int mode = 0;
    private static final int ECB = 1;
    private static final int CBC = 2;
    private static final int CCM = 3;
    private static final int GCM = 4;
    private static final int OFB = 5;
    private static final int CFB = 6;
    private static final int CTR = 7;
    private static final int XTS = 8;
    private static final int SIV = 9;
    // Selected by the "WRAP" mode name; maps to CKM_AES_KEY_WRAP / CKM_AES_KEY_WRAP_PAD.
    private static final int NIST = 10;

    // Shared empty array used as the initial value of 'buffer' and 'auth' and as the "no output" result.
    private final static byte[] B0 = new byte[0];

    /**
     * Creates a cipher bound to a single key type.
     *
     * @param keyType key type constant (for example {@code CKK_AES}) that keys used with this cipher must have
     */
    SecretKeyCipher(int keyType)
    {
        this.keyType = keyType;
    }

    /**
     * Verifies that a mode requiring the given key type may be used with this cipher.
     *
     * @param keyType key type the requested mode requires
     * @throws NoSuchAlgorithmException if this cipher was constructed for a different key type
     */
    private void checkValidAlg(int keyType) throws NoSuchAlgorithmException
    {
        if (this.keyType == keyType)
        {
            return;
        }
        throw new NoSuchAlgorithmException("Mode not supported");
    }


    /**
     * Selects the cipher mode by name (case-insensitive) and records the flags that go with it:
     * {@code aad} for CCM, GCM and SIV; {@code singleOp} for XTS, SIV and WRAP.
     *
     * @param mode mode name taken from the transformation string
     * @throws NoSuchAlgorithmException if the name is unknown or the mode is not available
     *                                  for the key type this cipher was constructed with
     */
    @Override
    protected void engineSetMode(String mode) throws NoSuchAlgorithmException
    {
        mode = mode.toUpperCase();

        if (mode.equalsIgnoreCase("CCM"))
        {
            this.mode = CCM;
            checkValidAlg(CKK_AES);
            aad = true;
            return;
        }
        if (mode.equalsIgnoreCase("GCM"))
        {
            this.mode = GCM;
            checkValidAlg(CKK_AES);
            aad = true;
            return;
        }
        if (mode.equalsIgnoreCase("ECB"))
        {
            this.mode = ECB;
            return;
        }
        if (mode.equalsIgnoreCase("CBC"))
        {
            this.mode = CBC;
            return;
        }
        if (mode.equalsIgnoreCase("CTR"))
        {
            this.mode = CTR;
            checkValidAlg(CKK_AES);
            return;
        }
        if (mode.equalsIgnoreCase("OFB64"))
        {
            this.mode = OFB;
            checkValidAlg(CKK_DES3);
            return;
        }
        if (mode.equalsIgnoreCase("OFB128"))
        {
            this.mode = OFB;
            checkValidAlg(CKK_AES);
            return;
        }
        if (mode.equalsIgnoreCase("CFB64"))
        {
            this.mode = CFB;
            checkValidAlg(CKK_DES3);
            return;
        }
        if (mode.equalsIgnoreCase("CFB128"))
        {
            this.mode = CFB;
            checkValidAlg(CKK_AES);
            return;
        }
        if (mode.equalsIgnoreCase("XTS"))
        {
            this.mode = XTS;
            checkValidAlg(DYCKK_AES_XTS);
            singleOp = true;
            return;
        }
        if (mode.equalsIgnoreCase("SIV"))
        {
            this.mode = SIV;
            checkValidAlg(DYCKK_AES_SIV);
            singleOp = true;
            aad = true;
            return;
        }
        if (mode.equalsIgnoreCase("WRAP"))
        {
            this.mode = NIST;
            checkValidAlg(CKK_AES);
            singleOp = true;
            return;
        }

        throw new NoSuchAlgorithmException("Mode not supported: " + mode);
    }

    /**
     * Selects the padding scheme. {@code NOPADDING} is always accepted; {@code PKCS5PADDING}
     * is accepted only when the already-selected mode is CBC or WRAP.
     *
     * @param padding padding name taken from the transformation string (case-insensitive)
     * @throws NoSuchPaddingException for any other name or mode/padding combination
     */
    @Override
    protected void engineSetPadding(String padding) throws NoSuchPaddingException
    {
        if (padding.equalsIgnoreCase("NOPADDING"))
        {
            this.padding = false;
            return;
        }

        boolean pkcs5 = padding.equalsIgnoreCase("PKCS5PADDING");
        boolean modeAllowsPadding = (mode == CBC) || (mode == NIST);
        if (!pkcs5 || !modeAllowsPadding)
        {
            throw new NoSuchPaddingException("padding not supported");
        }
        this.padding = true;
    }

    /**
     * @return 8 when this cipher was constructed for {@code CKK_DES3} keys, 16 otherwise
     */
    @Override
    protected int engineGetBlockSize()
    {
        if (keyType == CKK_DES3)
        {
            return 8;
        }
        return 16;
    }

    /**
     * Returns an output buffer size sufficient for the next update/doFinal call.
     *
     * When a session is open, the size is obtained from {@code encdecLen}. No session exists
     * in wrap/unwrap mode, or for AAD modes before the first data call (initialisation is deferred
     * there so that AAD can still be supplied). In those cases the conservative bound from
     * {@code getOutBufLen} is returned instead of dereferencing a null session or buffer.
     *
     * @param inputLen number of input bytes the caller is about to supply
     * @return a length that is at least as large as the pending output
     */
    @Override
    protected int engineGetOutputSize(int inputLen)
    {
        // Single-operation modes also have to account for input buffered by earlier update calls.
        int pending = (singleOp && buffer != null) ? buffer.length : 0;
        int total = pending + inputLen;

        if (session == null)
        {
            return getOutBufLen(total);
        }
        return encdecLen(new byte[total], 0, total);
    }

    /**
     * Returns the IV carried by the current parameters.
     *
     * Besides {@code IvParameterSpec}, the GCM and CCM parameter specs also carry an IV,
     * so it is returned for those as well.
     *
     * @return the IV, or {@code null} when no parameters are set or they carry no IV
     */
    @Override
    protected byte[] engineGetIV()
    {
        if (paramSpec instanceof IvParameterSpec) return ((IvParameterSpec)paramSpec).getIV();
        if (paramSpec instanceof GCMParameterSpec) return ((GCMParameterSpec)paramSpec).getIV();
        if (paramSpec instanceof CCMParameterSpec) return ((CCMParameterSpec)paramSpec).getIV();
        return null;
    }

    /**
     * Returns the size of the given key in bits.
     *
     * Provider keys report their own bit size; for any other key the size is derived
     * from the length of its encoding.
     *
     * @param key key to measure
     * @return key size in bits
     * @throws InvalidKeyException if the provider key cannot report its size or the key has no encoding
     */
    @Override
    protected int engineGetKeySize(Key key) throws InvalidKeyException
    {
        if (!(key instanceof SecretKey))
        {
            byte[] raw = key.getEncoded();
            if (raw == null)
            {
                throw new InvalidKeyException("Invalid key value");
            }
            return raw.length * 8;
        }

        try
        {
            return ((SecretKey)key).getBitSize();
        }
        catch (KeyStoreException e)
        {
            throw new InvalidKeyException(e);
        }
    }

    /**
     * Encodes the current parameter spec as an {@code AlgorithmParameters} object obtained from SunJCE.
     *
     * A {@code GCMParameterSpec} must be loaded into "GCM" parameters; the plain "AES"
     * parameters only accept an IV spec, so the algorithm name is chosen from the spec type.
     *
     * @return the encoded parameters, or {@code null} when no parameters are set
     * @throws ProviderException if the parameters cannot be created or initialised
     */
    @Override
    protected AlgorithmParameters engineGetParameters()
    {
        if (paramSpec==null) return null;

        final String algorithm;
        if (paramSpec instanceof GCMParameterSpec) algorithm = "GCM";
        else algorithm = (keyType == CKK_DES3) ? "DESede" : "AES";

        try
        {
            AlgorithmParameters params = AlgorithmParameters.getInstance(algorithm, "SunJCE");
            params.init(paramSpec);
            return params;
        }
        catch (GeneralSecurityException e) { throw new ProviderException("Could not encode parameters", e); }
    }

    /**
     * Maps the selected mode, key type and padding flag to a mechanism type constant.
     *
     * @return the mechanism type, or -1 when no known mode has been selected
     */
    private int getMechanismType()
    {
        final boolean aes = (keyType == CKK_AES);

        switch (mode)
        {
            case ECB:
                return aes ? CKM_AES_ECB : CKM_DES3_ECB;
            case CBC:
                if (padding)
                {
                    return aes ? CKM_AES_CBC_PAD : CKM_DES3_CBC_PAD;
                }
                return aes ? CKM_AES_CBC : CKM_DES3_CBC;
            case OFB:
                return aes ? CKM_AES_OFB : CKM_DES_OFB64;
            case CFB:
                return aes ? CKM_AES_CFB128 : CKM_DES_CFB64;
            case CTR:
                return CKM_AES_CTR;
            case GCM:
                return CKM_AES_GCM;
            case CCM:
                return CKM_AES_CCM;
            case XTS:
                return DYCKM_AES_XTS;
            case SIV:
                return DYCKM_AES_SIV;
            case NIST:
                return padding ? CKM_AES_KEY_WRAP_PAD : CKM_AES_KEY_WRAP;
            default:
                return -1;
        }
    }

    /**
     * Builds the mechanism for IV-based modes.
     *
     * If the current mechanism object already has the required type it is reused and only
     * its buffer is replaced with the IV; otherwise a new mechanism is created.
     *
     * @param spec IV parameters for the operation
     * @return mechanism carrying the IV
     */
    private CK_MECHANISM getMechIV(IvParameterSpec spec)
    {
        final int type = getMechanismType();
        final boolean reusable = (mechanism != null) && (mechanism.getType() == type);

        if (!reusable)
        {
            return new CK_MECHANISM(type, spec.getIV());
        }

        mechanism.setBuffer(spec.getIV());
        return mechanism;
    }

    /**
     * Builds the SIV mechanism from the AAD headers collected by {@code engineUpdateAAD}.
     *
     * @return SIV parameters holding the headers, or {@code null} headers when none were supplied
     */
    private CK_MECHANISM getMechSIV()
    {
        // The no-arg toArray() returns Object[], which cannot be cast to byte[][];
        // the typed overload yields an array of the right runtime type.
        byte[][] headers = (siv_headers == null) ? null : siv_headers.toArray(new byte[0][]);
        return new DYCK_AES_SIV_PARAMS(headers);
    }


    /**
     * Builds the GCM mechanism from the IV and tag length in {@code spec} and the accumulated AAD.
     *
     * An existing mechanism of type {@code CKM_AES_GCM} is re-initialised in place;
     * otherwise a new parameter object is created.
     *
     * @param spec GCM parameters for the operation
     * @return mechanism configured for GCM
     */
    private CK_MECHANISM getMechGCM(GCMParameterSpec spec)
    {
        final boolean reusable = (mechanism != null) && (mechanism.getType() == CKM_AES_GCM);

        if (!reusable)
        {
            return new CK_GCM_PARAMS(spec.getIV(), auth, spec.getTLen());
        }

        CK_GCM_PARAMS gcm = (CK_GCM_PARAMS)mechanism;
        gcm.init(spec.getIV(), auth, spec.getTLen());
        return mechanism;
    }

    /**
     * Builds the CCM mechanism from the data size, IV and tag size in {@code spec}
     * and the accumulated AAD.
     *
     * @param spec CCM parameters for the operation
     * @return a new CCM parameter object
     */
    private CK_MECHANISM getMechCCM(CCMParameterSpec spec)
    {
        CK_MECHANISM ccm = new CK_CCM_PARAMS(spec.getDataSize(), spec.getIV(), auth, spec.getTagSize());
        return ccm;
    }

    /**
     * Runs {@code ensureInit} from methods that cannot declare its checked exceptions,
     * rethrowing them wrapped in a {@code ProviderException}.
     */
    private void ensureInitOperation()
    {
        try
        {
            ensureInit();
        }
        catch (InvalidKeyException | InvalidAlgorithmParameterException e)
        {
            throw new ProviderException(e);
        }
    }

    /**
     * Opens an encrypt or decrypt operation on the key with the current mechanism and stores the session.
     *
     * @throws InvalidAlgorithmParameterException if the failure code indicates bad arguments or a
     *                                            bad mechanism / mechanism parameter
     * @throws InvalidKeyException                for any other failure
     */
    private void initSession() throws InvalidKeyException, InvalidAlgorithmParameterException
    {
        try
        {
            if (encrypt)
            {
                session = secretKey.pkcs11Key.encryptInit(mechanism);
            }
            else
            {
                session = secretKey.pkcs11Key.decryptInit(mechanism);
            }
        }
        catch (CKException e)
        {
            final int rv = e.getRV();
            final boolean parameterProblem =
                    rv == CKR_ARGUMENTS_BAD ||
                    rv == CKR_MECHANISM_INVALID ||
                    rv == CKR_MECHANISM_PARAM_INVALID;

            if (parameterProblem)
            {
                throw new InvalidAlgorithmParameterException(e);
            }
            throw new InvalidKeyException(e);
        }
    }

    /**
     * Hands the current session, if any, back to the key's slot and forgets it.
     */
    private void releaseSession()
    {
        if (session != null)
        {
            secretKey.pkcs11Key.getSlot().releaseSession(session);
            session = null;
        }
    }


    /**
     * Rebuilds {@code mechanism} for the selected mode from the current parameters.
     *
     * The field is cleared first, so it stays {@code null} if building fails.
     *
     * @throws InvalidAlgorithmParameterException if no supported mode has been selected
     */
    private void prepareMechanism() throws InvalidAlgorithmParameterException
    {
        mechanism = null;

        if (mode == SIV)
        {
            mechanism = getMechSIV();
        }
        else if (mode == GCM)
        {
            mechanism = getMechGCM((GCMParameterSpec)paramSpec);
        }
        else if (mode == CCM)
        {
            mechanism = getMechCCM((CCMParameterSpec)paramSpec);
        }
        else if (mode == CTR || mode == CBC || mode == OFB || mode == CFB || mode == XTS || mode == NIST)
        {
            // All of these modes take just an IV.
            mechanism = getMechIV((IvParameterSpec)paramSpec);
        }
        else if (mode == ECB)
        {
            mechanism = new CK_MECHANISM(getMechanismType());
        }
        else
        {
            throw new InvalidAlgorithmParameterException("Invalid PKCS#11 mechanism");
        }
    }

    /**
     * Lazily prepares the mechanism and, outside wrap/unwrap mode, opens the session.
     *
     * Initialisation is deferred for AAD modes so that AAD supplied after {@code engineInit}
     * can be built into the mechanism. Once the mechanism is built, the single-operation
     * buffer and the AAD accumulator are reset for the next operation.
     *
     * {@code encdec} releases the session when an operation finishes while {@code initialized}
     * stays set. To let the cipher be used again without a new {@code engineInit}, a finished
     * operation (initialised, not wrap mode, no session) is re-initialised with the same
     * parameters instead of dereferencing the null session later.
     *
     * @throws IllegalStateException              if a GCM or CCM encryption would be repeated with the same IV
     * @throws InvalidKeyException                if the session cannot be opened with the key
     * @throws InvalidAlgorithmParameterException if the mechanism cannot be built or is rejected
     */
    private void ensureInit() throws InvalidKeyException, InvalidAlgorithmParameterException
    {
        if (initialized)
        {
            // Wrap/unwrap keeps no session here; a live session means the operation is still open.
            if (wrap || session != null) return;

            // The previous operation finished. Re-encrypting under the same key and IV breaks
            // GCM/CCM, so require a fresh engineInit with a new IV for those.
            if (encrypt && (mode == GCM || mode == CCM))
            {
                throw new IllegalStateException("Cannot reuse the same key and IV for multiple encryptions");
            }
        }

        buffer = null;
        if (session!=null) session.close();
        session = null;

        prepareMechanism();

        if (!wrap)
        {
            initSession();
        }

        if (singleOp) buffer = B0;
        if (aad) auth = B0;

        initialized = true;
    }

    /**
     * Initialises the cipher with a key and an optional parameter spec.
     *
     * The key must be a provider {@code SecretKey} of the key type this cipher was constructed
     * for. Parameters are validated per mode:
     * <ul>
     *   <li>SIV: wrap/unwrap only.</li>
     *   <li>GCM / CCM: the matching parameter spec is required.</li>
     *   <li>WRAP: an optional 8-byte IV; an empty IV is used when none is given.</li>
     *   <li>CTR, CBC, OFB, CFB, XTS: a block-sized IV, generated from {@code secureRandom}
     *       when encrypting without parameters. XTS rejects wrap/unwrap.</li>
     * </ul>
     * For AAD modes the mechanism/session setup is deferred until the first data call;
     * otherwise it happens here.
     *
     * @param opmode                 one of the {@code Cipher} operation mode constants
     * @param key                    key to use
     * @param algorithmParameterSpec parameters, may be {@code null}
     * @param secureRandom           randomness source used to generate an IV when needed
     * @throws InvalidKeyException                if the key is not a provider key of the right type
     * @throws InvalidAlgorithmParameterException if the parameters do not fit the selected mode
     */
    @Override
    protected void engineInit(int opmode, Key key, AlgorithmParameterSpec algorithmParameterSpec, SecureRandom secureRandom)
            throws InvalidKeyException, InvalidAlgorithmParameterException
    {
        initialized = false;

        if (!(key instanceof SecretKey)) throw new InvalidKeyException("Invalid key type");
        secretKey = (SecretKey)key;

        try
        {
            secretKey.save();
            if (secretKey.getKeyType()!=keyType) throw new InvalidKeyException("Invalid key type");
        }
        catch (KeyStoreException e)
        {
            throw new InvalidKeyException(e);
        }

        wrap = (opmode == Cipher.WRAP_MODE) || (opmode == Cipher.UNWRAP_MODE);
        encrypt = (opmode == Cipher.ENCRYPT_MODE) || (opmode == Cipher.WRAP_MODE);
        paramSpec = algorithmParameterSpec;

        switch (mode)
        {
            case SIV:
                if (!wrap) throw new InvalidAlgorithmParameterException("SIV doesn't support encrypt/decrypt");
                break;

            case GCM:
                if (!(paramSpec instanceof GCMParameterSpec)) throw new InvalidAlgorithmParameterException("GCMParameterSpec required");
                break;

            case CCM:
                if (!(paramSpec instanceof CCMParameterSpec)) throw new InvalidAlgorithmParameterException("CCMParameterSpec required");
                break;

            case NIST:
                if (paramSpec==null)
                {
                    // No IV supplied: pass an empty one to the mechanism.
                    paramSpec = new IvParameterSpec(new byte[0]);
                }
                else
                {
                    if (!(paramSpec instanceof IvParameterSpec)) throw new InvalidAlgorithmParameterException("IvParameterSpec required");
                    int ivLen = ((IvParameterSpec)paramSpec).getIV().length;
                    if (ivLen!=8) throw new InvalidAlgorithmParameterException("Invalid IV length");
                }
                break;

            case CTR:
            case CBC:
            case OFB:
            case CFB:
            case XTS:
                if (wrap && mode==XTS) throw new InvalidAlgorithmParameterException("XTS doesn't support wrap/unwrap");

                final int blockSize = engineGetBlockSize();
                if (paramSpec==null && encrypt)
                {
                    // Encrypting without parameters: generate a random block-sized IV.
                    if (secureRandom==null) throw new InvalidAlgorithmParameterException("Can't generate IV");
                    byte[] iv = new byte[blockSize];
                    secureRandom.nextBytes(iv);
                    paramSpec = new IvParameterSpec(iv);
                }
                // instanceof is false for null, so this also covers the missing-parameters case.
                if (!(paramSpec instanceof IvParameterSpec)) throw new InvalidAlgorithmParameterException("IvParameterSpec required");
                if (((IvParameterSpec)paramSpec).getIV().length!=blockSize) throw new InvalidAlgorithmParameterException("Invalid IV length");
                break;

            default:
                // ECB (and an unset mode) take no parameters here; an unset mode is rejected by prepareMechanism.
                break;
        }

        // AAD modes must wait for engineUpdateAAD before the mechanism can be built.
        if (aad) { auth = B0; siv_headers = null; }
        else ensureInit();
    }

    /**
     * Initialises the cipher from encoded {@code AlgorithmParameters}.
     *
     * The parameters are converted to the spec class matching the selected mode (CCM, GCM,
     * or an IV spec for everything else) and passed to the spec-based {@code engineInit}.
     *
     * @param opmode              one of the {@code Cipher} operation mode constants
     * @param key                 key to use
     * @param algorithmParameters encoded parameters, may be {@code null}
     * @param secureRandom        randomness source used to generate an IV when needed
     * @throws InvalidKeyException                if the key is rejected
     * @throws InvalidAlgorithmParameterException if the parameters cannot be converted or are rejected
     */
    @Override
    protected void engineInit(int opmode, Key key, AlgorithmParameters algorithmParameters, SecureRandom secureRandom) throws InvalidKeyException, InvalidAlgorithmParameterException
    {
        AlgorithmParameterSpec spec = null;

        Class<? extends AlgorithmParameterSpec> clazz = IvParameterSpec.class;
        switch (mode)
        {
            case CCM: clazz = CCMParameterSpec.class; break;
            case GCM: clazz = GCMParameterSpec.class; break;
        }

        if (algorithmParameters!=null)
        {
            try { spec = algorithmParameters.getParameterSpec(clazz); }
            catch (InvalidParameterSpecException ipse)
            {
                // Keep the original failure as the cause so the reason is not lost.
                throw new InvalidAlgorithmParameterException("Wrong parameter", ipse);
            }
        }

        engineInit(opmode, key, spec, secureRandom);
    }

    /**
     * Initialises the cipher without explicit parameters by delegating to the spec-based
     * {@code engineInit} with a {@code null} spec.
     *
     * @param opmode       one of the {@code Cipher} operation mode constants
     * @param key          key to use
     * @param secureRandom randomness source used to generate an IV when needed
     * @throws InvalidKeyException if the key is rejected, or wrapping a parameter failure
     *                             raised by the delegate
     */
    @Override
    protected void engineInit(int opmode, Key key, SecureRandom secureRandom) throws InvalidKeyException
    {
        final AlgorithmParameterSpec noParams = null;
        try
        {
            engineInit(opmode, key, noParams, secureRandom);
        }
        catch (InvalidAlgorithmParameterException e)
        {
            throw new InvalidKeyException(e);
        }
    }

    /**
     * Appends {@code inLen} bytes of {@code in}, starting at {@code inOffset}, to {@code buffer}.
     *
     * @param in       source array
     * @param inOffset index of the first byte to append
     * @param inLen    number of bytes to append
     */
    private void updateSingleOp(byte[] in, int inOffset, int inLen)
    {
        final int used = buffer.length;
        // copyOf carries over the existing bytes and leaves room for the new ones.
        byte[] grown = Arrays.copyOf(buffer, used + inLen);
        System.arraycopy(in, inOffset, grown, used, inLen);
        buffer = grown;
    }

    /**
     * Supplies additional authenticated data.
     *
     * In SIV mode every call adds one separate header; in the other AAD modes the bytes are
     * appended to a single accumulated AAD array.
     *
     * @param src    array containing the AAD
     * @param offset index of the first AAD byte in {@code src}
     * @param len    number of AAD bytes
     * @throws IllegalStateException if the selected mode does not accept AAD
     */
    @Override
    protected void engineUpdateAAD(byte[] src,
                                   int offset,
                                   int len) throws IllegalStateException, UnsupportedOperationException
    {
        if (!aad) throw new IllegalStateException("Cipher does not accept AAD");

        if (mode==SIV)
        {
            if (siv_headers==null) siv_headers = new LinkedList<byte[]>();
            // copyOfRange takes an exclusive end index, not a length.
            siv_headers.add(Arrays.copyOfRange(src, offset, offset + len));
            return;
        }

        int oldSize = auth.length;
        byte[] newBuffer = new byte[oldSize+len];
        if (oldSize>0) System.arraycopy(auth, 0, newBuffer, 0, oldSize);
        System.arraycopy(src, offset, newBuffer, oldSize, len);
        auth = newBuffer;
    }

    /**
     * Processes a chunk of input and returns whatever output it produces.
     *
     * Single-operation modes only buffer the input and return an empty array.
     *
     * @param in       input array; {@code null} or an empty range yields an empty result
     * @param inOffset index of the first input byte
     * @param inLen    number of input bytes
     * @return the output bytes produced by this call, possibly empty
     */
    @Override
    protected byte[] engineUpdate(byte[] in, int inOffset, int inLen)
    {
        if (in == null || inLen == 0)
        {
            return B0;
        }

        ensureInitOperation();

        if (singleOp)
        {
            updateSingleOp(in, inOffset, inLen);
            return B0;
        }

        final int expected = encdecUpdateLen(in, inOffset, inLen);
        final byte[] result = new byte[expected];
        final int produced = encdecUpdate(in, inOffset, inLen, result, 0);

        // Trim when fewer bytes were written than announced.
        return (produced == expected) ? result : Arrays.copyOf(result, produced);
    }

    /**
     * Processes a chunk of input, writing any output into the caller's buffer.
     *
     * Single-operation modes only buffer the input and report zero bytes written.
     *
     * @param in        input array; {@code null} or an empty range is a no-op
     * @param inOffset  index of the first input byte
     * @param inLen     number of input bytes
     * @param out       destination array
     * @param outOffset index in {@code out} where output starts
     * @return number of bytes written to {@code out}
     * @throws ShortBufferException if {@code out} has too little room after {@code outOffset}
     */
    @Override
    protected int engineUpdate(byte[] in, int inOffset, int inLen, byte[] out, int outOffset) throws ShortBufferException
    {
        if (in == null || inLen == 0)
        {
            return 0;
        }

        ensureInitOperation();

        if (singleOp)
        {
            updateSingleOp(in, inOffset, inLen);
            return 0;
        }

        final int needed = encdecUpdateLen(in, inOffset, inLen);
        final int available = out.length - outOffset;
        if (needed > available)
        {
            throw new ShortBufferException();
        }
        return encdecUpdate(in, inOffset, inLen, out, outOffset);
    }

    /**
     * Computes a generous output buffer size: the input length plus three blocks when
     * encrypting (block + data + block + tag) or one block when decrypting.
     *
     * @param inLen number of input bytes
     * @return buffer size to allocate
     */
    private int getOutBufLen(int inLen)
    {
        final int blockLen;
        if (keyType == CKK_DES3)
        {
            blockLen = 8;
        }
        else
        {
            blockLen = 16;
        }

        final int extraBlocks = encrypt ? 3 : 1;
        return inLen + blockLen * extraBlocks;
    }


    /**
     * Finishes the operation and returns the output.
     *
     * In single-operation modes the supplied input is first appended to the buffered input,
     * and the whole buffer is processed in one call.
     *
     * @param in       final input chunk
     * @param inOffset index of the first input byte
     * @param inLen    number of input bytes
     * @return the output of the operation
     * @throws IllegalBlockSizeException if the data length is rejected
     * @throws BadPaddingException       if the data is rejected as invalid
     * @throws AEADBadTagException       if the data is rejected as invalid in an AAD mode
     */
    @Override
    protected byte[] engineDoFinal(byte[] in, int inOffset, int inLen)
            throws IllegalBlockSizeException, BadPaddingException, AEADBadTagException
    {
        ensureInitOperation();

        byte[] data = in;
        int dataOffset = inOffset;
        int dataLen = inLen;

        if (singleOp)
        {
            engineUpdate(in, inOffset, inLen);
            data = buffer;
            dataOffset = 0;
            dataLen = buffer.length;
        }

        final int capacity = getOutBufLen(dataLen);
        final byte[] result = new byte[capacity == 0 ? 1 : capacity];

        final int produced = encdec(data, dataOffset, dataLen, result, 0);
        return (produced == capacity) ? result : Arrays.copyOf(result, produced);
    }

    /**
     * Finishes the operation, writing the output into the caller's buffer.
     *
     * In single-operation modes any supplied input is first appended to the buffered input,
     * and the whole buffer is processed in one call.
     *
     * @param in        final input chunk, may be {@code null} in single-operation modes
     * @param inOffset  index of the first input byte
     * @param inLen     number of input bytes
     * @param out       destination array
     * @param outOffset index in {@code out} where output starts
     * @return number of bytes written to {@code out}
     * @throws ShortBufferException      if {@code out} has too little room after {@code outOffset}
     * @throws IllegalBlockSizeException if the data length is rejected
     * @throws BadPaddingException       if the data is rejected as invalid
     * @throws AEADBadTagException       if the data is rejected as invalid in an AAD mode
     */
    @Override
    protected int engineDoFinal(byte[] in, int inOffset, int inLen, byte[] out, int outOffset)
            throws ShortBufferException, IllegalBlockSizeException, BadPaddingException, AEADBadTagException
    {
        ensureInitOperation();

        byte[] data = in;
        int dataOffset = inOffset;
        int dataLen = inLen;

        if (singleOp)
        {
            if (in != null && inLen > 0)
            {
                engineUpdate(in, inOffset, inLen);
            }
            data = buffer;
            dataOffset = 0;
            dataLen = buffer.length;
        }

        final int needed = encdecLen(data, dataOffset, dataLen);
        final int available = out.length - outOffset;
        if (needed > available)
        {
            throw new ShortBufferException();
        }
        return encdec(data, dataOffset, dataLen, out, outOffset);
    }

    /**
     * Calls the session's encrypt/decrypt with no output buffer and returns the value it reports,
     * which callers use as the required output length.
     *
     * @throws ProviderException if the session call fails
     */
    private int encdecLen(byte[] in, int inOffset, int inLen)
    {
        try
        {
            if (encrypt)
            {
                return session.encrypt(in, inOffset, inLen, null, 0);
            }
            return session.decrypt(in, inOffset, inLen, null, 0);
        }
        catch (CKException e)
        {
            throw new ProviderException(e);
        }
    }

    /**
     * Runs the session's encrypt or decrypt call and always releases the session afterwards.
     *
     * Failure codes are translated to JCE exceptions, with the original {@code CKException}
     * attached as the cause so the underlying return code is not lost.
     *
     * @param in        input array
     * @param inOffset  index of the first input byte
     * @param inLen     number of input bytes
     * @param out       destination array
     * @param outOffset index in {@code out} where output starts
     * @return the value reported by the session call (used by callers as the output length)
     * @throws IllegalBlockSizeException for data-length failures
     * @throws AEADBadTagException       for invalid-data failures in an AAD mode
     * @throws BadPaddingException       for invalid-data failures in other modes
     * @throws ProviderException         for any other failure
     */
    private int encdec(byte[] in, int inOffset, int inLen, byte[] out, int outOffset)
            throws IllegalBlockSizeException, BadPaddingException, AEADBadTagException
    {
        try { return encrypt ? session.encrypt(in, inOffset, inLen, out, outOffset) : session.decrypt(in, inOffset, inLen, out, outOffset); }
        catch (CKException e)
        {
            int rv = e.getRV();
            if (rv== CKR_DATA_LEN_RANGE || rv== CKR_ENCRYPTED_DATA_LEN_RANGE)
            {
                IllegalBlockSizeException sizeError = new IllegalBlockSizeException();
                sizeError.initCause(e);
                throw sizeError;
            }
            if (rv== CKR_DATA_INVALID || rv== CKR_ENCRYPTED_DATA_INVALID)
            {
                // AEADBadTagException is a BadPaddingException, so one variable covers both cases.
                BadPaddingException dataError = aad ? new AEADBadTagException() : new BadPaddingException();
                dataError.initCause(e);
                throw dataError;
            }
            throw new ProviderException(e);
        }
        finally { releaseSession(); }
    }

    /**
     * Calls the session's encryptUpdate/decryptUpdate with no output buffer and returns the value
     * it reports, which callers use as the required output length.
     *
     * @throws ProviderException if the session call fails
     */
    private int encdecUpdateLen(byte[] in, int inOffset, int inLen)
    {
        try
        {
            if (encrypt)
            {
                return session.encryptUpdate(in, inOffset, inLen, null, 0);
            }
            return session.decryptUpdate(in, inOffset, inLen, null, 0);
        }
        catch (CKException e)
        {
            throw new ProviderException(e);
        }
    }

    /**
     * Feeds one chunk to the session's encryptUpdate/decryptUpdate, writing into {@code out}.
     *
     * @return the value reported by the session call (used by callers as the number of bytes written)
     * @throws ProviderException if the session call fails
     */
    private int encdecUpdate(byte[] in, int inOffset, int inLen, byte[] out, int outOffset)
    {
        try
        {
            if (encrypt)
            {
                return session.encryptUpdate(in, inOffset, inLen, out, outOffset);
            }
            return session.decryptUpdate(in, inOffset, inLen, out, outOffset);
        }
        catch (CKException e)
        {
            throw new ProviderException(e);
        }
    }

    /**
     * Wraps a key with the cipher's key.
     *
     * Provider keys (secret, RSA private, EC private) are saved and wrapped through the
     * wrapping key's {@code wrap} call. Any other key is wrapped by encrypting its encoding.
     *
     * The encoding is encrypted directly through {@code encdec} rather than {@code engineDoFinal}.
     * {@code engineDoFinal} may run the lazy initialisation, which closes the session opened here,
     * and in single-operation modes it would also prepend input left in {@code buffer} by earlier calls.
     *
     * @param key key to wrap
     * @return the wrapped key bytes
     * @throws InvalidKeyException       if the key is null, cannot be saved or encoded, or wrapping fails
     * @throws IllegalBlockSizeException if the encoding length is rejected
     */
    @Override
    protected byte[] engineWrap(Key key) throws IllegalBlockSizeException, InvalidKeyException
    {
        if (key == null) throw new InvalidKeyException("Key to be wrapped must not be null");

        CKKey wrappedPkcs11Key = null;
        try
        {
            if (key instanceof SecretKey)
            {
                ((SecretKey)key).save();
                wrappedPkcs11Key = ((SecretKey)key).pkcs11Key;
            }
            else if (key instanceof RSAPrivateKey)
            {
                ((RSAPrivateKey)key).save();
                wrappedPkcs11Key = ((RSAPrivateKey)key).pkcs11Key;
            }
            else if (key instanceof ECPrivateKey)
            {
                ((ECPrivateKey)key).save();
                wrappedPkcs11Key = ((ECPrivateKey)key).pkcs11Key;
            }
        }
        catch (KeyStoreException e) { throw new InvalidKeyException(e); }

        try { prepareMechanism(); }
        catch (InvalidAlgorithmParameterException e) { throw new InvalidKeyException(e); }

        if (wrappedPkcs11Key!=null)
        {
            try { return secretKey.pkcs11Key.wrap(mechanism, wrappedPkcs11Key, 0); }
            catch (CKException e) { throw new InvalidKeyException(e); }
        }

        byte[] encodedKey = key.getEncoded();
        if ((encodedKey == null) || (encodedKey.length == 0)) throw new InvalidKeyException("Cannot get an encoding of the key to be wrapped");

        try { initSession(); }
        catch (InvalidAlgorithmParameterException e) { throw new ProviderException(e); }

        byte[] out = new byte[getOutBufLen(encodedKey.length)];
        try
        {
            int realOutLen = encdec(encodedKey, 0, encodedKey.length, out, 0);
            return (realOutLen == out.length) ? out : Arrays.copyOf(out, realOutLen);
        }
        catch (BadPaddingException e)  { throw new ProviderException(e); }
        finally { releaseSession(); }
    }

    /**
     * Unwraps a key that was wrapped with the cipher's key.
     *
     * Supported targets are RSA and EC private keys and secret keys; the resulting provider key
     * is initialised from the mechanism, the unwrapping key and the wrapped bytes.
     *
     * @param wrappedKey          wrapped key bytes
     * @param wrappedKeyAlgorithm algorithm of the wrapped key
     * @param wrappedKeyType      {@code Cipher.PRIVATE_KEY} or {@code Cipher.SECRET_KEY}
     * @return the unwrapped provider key
     * @throws InvalidKeyException      if the mechanism cannot be built or the key type/algorithm
     *                                  is not supported
     * @throws NoSuchAlgorithmException declared for the secret-key algorithm lookup
     */
    @Override
    protected Key engineUnwrap(byte[] wrappedKey, String wrappedKeyAlgorithm, int wrappedKeyType) throws InvalidKeyException,  NoSuchAlgorithmException
    {
        try
        {
            prepareMechanism();
        }
        catch (InvalidAlgorithmParameterException e)
        {
            throw new InvalidKeyException(e);
        }

        final UnwrapInfo info = new UnwrapInfo(mechanism, secretKey.pkcs11Key, wrappedKey);

        if (wrappedKeyType == Cipher.PRIVATE_KEY)
        {
            if (wrappedKeyAlgorithm.equalsIgnoreCase("RSA"))
            {
                return new RSAPrivateKey().initForUnwrap(info);
            }
            if (wrappedKeyAlgorithm.equalsIgnoreCase("EC"))
            {
                return new ECPrivateKey().initForUnwrap(info);
            }
            throw new InvalidKeyException("Unsupported wrappedKeyAlgorithm " + wrappedKeyAlgorithm);
        }

        if (wrappedKeyType == Cipher.SECRET_KEY)
        {
            return new SecretKey().initForUnwrap(info, SecretKey.algToKeyType(wrappedKeyAlgorithm));
        }

        throw new InvalidKeyException("Unsupported wrappedKeyType");
    }

    /** Cipher variant constructed for {@code CKK_AES} keys. */
    public static final class AES extends SecretKeyCipher
    {
        public AES()
        {
            super(CKK_AES);
        }
    }

    /** Cipher variant constructed for {@code DYCKK_AES_XTS} keys. */
    public static final class AESXTS extends SecretKeyCipher
    {
        public AESXTS()
        {
            super(DYCKK_AES_XTS);
        }
    }

    /** Cipher variant constructed for {@code DYCKK_AES_SIV} keys. */
    public static final class AESSIV extends SecretKeyCipher
    {
        public AESSIV()
        {
            super(DYCKK_AES_SIV);
        }
    }

    /** Cipher variant constructed for {@code CKK_DES3} keys. */
    public static final class DES3 extends SecretKeyCipher
    {
        public DES3()
        {
            super(CKK_DES3);
        }
    }
}
