package com.dyadicsec.provider;

import com.dyadicsec.pkcs11.CKException;
import com.dyadicsec.pkcs11.CK_MECHANISM;
import com.dyadicsec.pkcs11.Session;

import javax.crypto.*;
import javax.crypto.spec.SecretKeySpec;
import java.security.*;
import java.security.spec.AlgorithmParameterSpec;

import static com.dyadicsec.cryptoki.CK.DYCKM_LIMA;

/**
 * JCE {@link CipherSpi} for the {@code DYCKM_LIMA} mechanism. All cryptographic work is delegated
 * to a PKCS#11 {@link Session} obtained from the key object.
 *
 * <p>The cipher is not streaming: data passed to {@code update} is only accumulated in memory and
 * the whole message is processed in a single call when {@code doFinal} is invoked. Plaintext is
 * limited to {@link #MAX_MSG_SIZE} bytes and ciphertext to {@link #MAX_ENC_SIZE} bytes.
 *
 * <p>Instances keep mutable state (mode, keys, pending input) and perform no synchronization.
 *
 * Created by valery.osheter on 21-Mar-17.
 */
public final class LIMACipher extends CipherSpi
{
    /** Term of the ciphertext size formula, see {@link #getEncSize(int)}. */
    static final int N = 1024;
    /** Number of bytes added to the message length in the ciphertext size formula. */
    static final int SEED_SIZE = 32;
    /** Largest plaintext (in bytes) accepted for encryption or wrapping. */
    static final int MAX_MSG_SIZE = 96;
    /** Ciphertext size that corresponds to a plaintext of {@link #MAX_MSG_SIZE} bytes. */
    static final int MAX_ENC_SIZE = getEncSize(MAX_MSG_SIZE);

    private int mode = Cipher.ENCRYPT_MODE;
    /** Input accumulated by {@code update}; {@code null} when nothing is pending. */
    private byte[] buffer = null;
    private LIMAPrivateKey prvKey = null;
    private LIMAPublicKey pubKey = null;

    /**
     * Returns a new array holding {@code a} followed by {@code bLen} bytes of {@code b} starting at
     * {@code bOffset}. A fresh array is always allocated so the result never aliases a caller's array.
     *
     * @param a       existing data, may be {@code null} (treated as empty)
     * @param b       array to append from; not read when {@code bLen} is 0
     * @param bOffset index of the first byte of {@code b} to append
     * @param bLen    number of bytes of {@code b} to append
     */
    private static byte[] combine(byte[] a, byte[] b, int bOffset, int bLen)
    {
        int aLen = (a == null) ? 0 : a.length;
        byte[] result = new byte[aLen + bLen];
        if (aLen > 0) System.arraycopy(a, 0, result, 0, aLen);
        if (bLen > 0) System.arraycopy(b, bOffset, result, aLen, bLen);
        return result;
    }

    /** Number of input bytes accumulated so far and not yet processed. */
    private int bufferedLength()
    {
        return (buffer == null) ? 0 : buffer.length;
    }

    /**
     * Result size for a complete message of {@code totalLen} bytes in the current mode.
     * Never negative: a ciphertext too short to decode is reported as 0.
     */
    private int outputSizeFor(int totalLen)
    {
        switch (mode)
        {
            case Cipher.WRAP_MODE:
            case Cipher.ENCRYPT_MODE: return getEncSize(totalLen);
            case Cipher.UNWRAP_MODE:
            case Cipher.DECRYPT_MODE: return Math.max(0, getDecSize(totalLen));
        }
        return 0;
    }

    /**
     * Rejects a total input length that exceeds the limit of the current mode.
     * Takes a {@code long} so that the caller's sum cannot overflow before being checked.
     *
     * @throws ProviderException if the limit for encryption or decryption is exceeded
     */
    private void checkInputLength(long totalLen)
    {
        if ( (mode == Cipher.ENCRYPT_MODE && totalLen > MAX_MSG_SIZE ) ||
                (mode == Cipher.DECRYPT_MODE && totalLen > MAX_ENC_SIZE ))
            throw new ProviderException("Input is too long");
    }

    /** Accepted and ignored: this cipher has no configurable mode. */
    @Override
    protected void engineSetMode(String mode)
    {
    }

    /** Accepted and ignored: this cipher has no configurable padding. */
    @Override
    protected void engineSetPadding(String mode)
    {
    }

    /** @return 0, this is not a block cipher */
    @Override
    protected int engineGetBlockSize()
    {
        return 0;
    }

    /** Ciphertext length in bytes for a plaintext of {@code inputLen} bytes. */
    static int getEncSize(int inputLen) { return ((inputLen+ SEED_SIZE)  * 8 + N)*3 + 3; }

    /**
     * Inverse of {@link #getEncSize(int)}: plaintext length for a ciphertext of {@code inputLen} bytes.
     * The result is negative when {@code inputLen} is smaller than {@code getEncSize(0)}.
     */
    static int getDecSize(int inputLen) { return ((inputLen - 3)/3 - N) / 8 - SEED_SIZE; }

    /**
     * Size of the buffer needed by the next {@code doFinal}, given {@code inputLen} more input bytes.
     * Data already buffered by earlier {@code update} calls is taken into account, and the result
     * is never negative.
     */
    @Override
    protected int engineGetOutputSize(int inputLen)
    {
        return outputSizeFor(bufferedLength() + inputLen);
    }

    /** @return {@code null}, no IV is used */
    @Override
    protected byte[] engineGetIV()
    {
        return null;
    }

    /** @return {@code null}, no algorithm parameters are used */
    @Override
    protected AlgorithmParameters engineGetParameters()
    {
        return null;
    }

    /**
     * Initializes the cipher and discards any pending input.
     *
     * @param mode         one of the {@link Cipher} operation modes
     * @param key          a {@code LIMAPublicKey} for encrypt/wrap, a {@code LIMAPrivateKey} for decrypt/unwrap
     * @param secureRandom ignored
     * @throws InvalidKeyException if {@code key} is not of the type required by {@code mode}
     */
    @Override
    protected void engineInit(int mode, Key key, SecureRandom secureRandom) throws InvalidKeyException
    {
        this.mode = mode;
        buffer = null;

        switch (mode)
        {
            case Cipher.WRAP_MODE:
            case Cipher.ENCRYPT_MODE:
                if (!(key instanceof LIMAPublicKey)) throw new InvalidKeyException("CKKey type must be CKLIMAPublicKey");
                pubKey = (LIMAPublicKey)key;
                break;

            case Cipher.UNWRAP_MODE:
            case Cipher.DECRYPT_MODE:
                if (!(key instanceof LIMAPrivateKey)) throw new InvalidKeyException("CKKey type must be CKLIMAPrivateKey");
                prvKey = (LIMAPrivateKey)key;
                break;
        }
    }

    /** Same as {@link #engineInit(int, Key, SecureRandom)}; the parameter spec is ignored. */
    @Override
    protected void engineInit(int mode, Key key, AlgorithmParameterSpec algorithmParameterSpec, SecureRandom secureRandom)
            throws InvalidKeyException
    {
        engineInit(mode, key, null);
    }

    /** Same as {@link #engineInit(int, Key, SecureRandom)}; the parameters are ignored. */
    @Override
    protected void engineInit(int mode, Key key, AlgorithmParameters algorithmParameters, SecureRandom secureRandom)
            throws InvalidKeyException
    {
        engineInit(mode, key, null);
    }

    /**
     * Buffers the given input; no output is produced until {@code doFinal}.
     *
     * @return an empty array
     */
    @Override
    protected byte[] engineUpdate(byte[] in, int inOffset, int inLen)
    {
        engineUpdate(in, inOffset, inLen, null, 0);
        return new byte[0];
    }

    /**
     * Buffers {@code inLen} bytes of {@code in} starting at {@code inOffset}. {@code out} is not used.
     * The length limit is checked before the pending data is modified, so a rejected call
     * leaves the cipher state unchanged.
     *
     * @return 0, nothing is written to {@code out}
     * @throws ProviderException if the accumulated input would exceed the limit for the current mode
     */
    @Override
    protected int engineUpdate(byte[] in, int inOffset, int inLen, byte[] out, int outOffset) throws ProviderException
    {
        // Cipher.doFinal() without data reaches this method with a null array.
        int len = (in == null) ? 0 : inLen;
        checkInputLength((long) bufferedLength() + len);
        if (len > 0) buffer = combine(buffer, in, inOffset, len);
        return 0;
    }

    /**
     * Appends the final input, processes the whole message and returns the result.
     * Pending input is always discarded afterwards, whether or not the operation succeeded.
     *
     * @throws ProviderException if the input is too long or too short, or the PKCS#11 operation fails
     */
    @Override
    protected byte[] engineDoFinal(byte[] in, int inOffset, int inLen) throws ProviderException
    {
        try
        {
            engineUpdate(in, inOffset, inLen);
            byte[] out = new byte[outputSizeFor(bufferedLength())];
            doFinal(out, 0);
            return out;
        }
        catch (ShortBufferException e) { throw new ProviderException(e); }
        finally { buffer = null; }
    }

    /**
     * Appends the final input, processes the whole message and writes the result to {@code out}.
     *
     * <p>The capacity of {@code out} is verified before any input is consumed, so after a
     * {@link ShortBufferException} the call can be repeated with a larger buffer. In every other
     * case the pending input is discarded when this method returns or throws.
     *
     * @return number of bytes written to {@code out}
     * @throws ShortBufferException if {@code out} cannot hold the result
     * @throws ProviderException    if the input is too long or too short, or the PKCS#11 operation fails
     */
    @Override
    protected int engineDoFinal(byte[] in, int inOffset, int inLen, byte[] out, int outOffset) throws ShortBufferException {
        int extra = (in == null) ? 0 : inLen;
        checkInputLength((long) bufferedLength() + extra);
        if (outOffset + outputSizeFor(bufferedLength() + extra) > out.length) throw new ShortBufferException();

        try
        {
            engineUpdate(in, inOffset, inLen);
            return doFinal(out, outOffset);
        }
        finally { buffer = null; }
    }

    /**
     * Runs the PKCS#11 encrypt or decrypt operation over the buffered input.
     * Does not clear the buffer; callers are responsible for that.
     *
     * @return size of the result as computed from the input length
     * @throws ShortBufferException if {@code out} cannot hold the result
     * @throws ProviderException    if a ciphertext is too short or the PKCS#11 operation fails
     */
    private int doFinal(byte[] out, int outOffset) throws ShortBufferException
    {
        byte[] data = (buffer == null) ? new byte[0] : buffer;
        if (mode == Cipher.DECRYPT_MODE && getDecSize(data.length) < 0)
            throw new ProviderException("Input is too short");

        int outLen = outputSizeFor(data.length);
        if (outOffset + outLen > out.length) throw new ShortBufferException();
        Session session = null;

        try
        {
            switch (mode)
            {
                case Cipher.ENCRYPT_MODE:
                    try {
                        // NOTE(review): save() presumably makes sure the key exists on the token - confirm.
                        pubKey.prvKey.save();
                        session = pubKey.pkcs11Key.encryptInit(new CK_MECHANISM(DYCKM_LIMA));
                        session.encrypt(data, 0, data.length, out, outOffset);
                    }
                    catch (CKException e ) { throw new ProviderException(e); }
                    catch (KeyStoreException e ) { throw new ProviderException(e); }
                    break;

                case Cipher.DECRYPT_MODE:
                    try {
                        prvKey.save();
                        session = prvKey.pkcs11Key.decryptInit(new CK_MECHANISM(DYCKM_LIMA));
                        session.decrypt(data, 0, data.length, out, outOffset);
                    }
                    catch (CKException e ) { throw new ProviderException(e); }
                    catch (KeyStoreException e ) { throw new ProviderException(e); }
                    break;
            }

            return outLen;
        }
        finally { if (session!=null) session.close(); }
    }

    /** @return 128 for every key */
    @Override
    protected int engineGetKeySize(Key key)
    {
        return 128;
    }

    /**
     * Wraps {@code key} by encrypting its encoded form with the public key given to {@code init}.
     *
     * @throws InvalidKeyException if the key has no encoding or the encoding exceeds {@link #MAX_MSG_SIZE}
     * @throws ProviderException   if the PKCS#11 operation fails
     */
    @Override
    protected byte[] engineWrap(Key key) throws InvalidKeyException, ProviderException
    {

        byte[] keyBuf = key.getEncoded();
        if ((keyBuf == null) || (keyBuf.length == 0)) throw new InvalidKeyException("Could not obtain encoded key");
        if (keyBuf.length > MAX_MSG_SIZE) throw new InvalidKeyException("CKKey is too long for wrapping");
        byte[] out = new byte[getEncSize(keyBuf.length)];
        Session session = null;
        try {
            pubKey.prvKey.save();
            session = pubKey.pkcs11Key.encryptInit(new CK_MECHANISM(DYCKM_LIMA));
            session.encrypt(keyBuf, 0, keyBuf.length, out, 0);
        }
        catch (CKException e ) { throw new ProviderException(e); }
        catch (KeyStoreException e ) { throw new ProviderException(e); }
        // Release the session on every path, as doFinal does.
        finally { if (session != null) session.close(); }
        return out;
    }

    /**
     * Unwraps a secret key by decrypting {@code wrappedKey} with the private key given to {@code init}.
     *
     * @param wrappedKey     the wrapped key bytes
     * @param algorithm      algorithm name assigned to the returned key; "AES" is used when it is
     *                       {@code null} or empty
     * @param wrappedKeyType must be {@link Cipher#SECRET_KEY}
     * @throws UnsupportedOperationException if {@code wrappedKeyType} is not {@link Cipher#SECRET_KEY}
     * @throws InvalidKeyException           if {@code wrappedKey} is longer than {@link #MAX_ENC_SIZE}
     *                                       or too short to contain a key
     * @throws ProviderException             if the PKCS#11 operation fails
     */
    @Override
    protected Key engineUnwrap(byte[] wrappedKey, String algorithm, int wrappedKeyType) throws InvalidKeyException, ProviderException
    {
        if (wrappedKeyType != Cipher.SECRET_KEY) throw new UnsupportedOperationException("wrappedKeyType == " + wrappedKeyType);
        if (wrappedKey.length > MAX_ENC_SIZE) throw new InvalidKeyException("CKKey is too long for unwrapping");
        int keyLen = getDecSize(wrappedKey.length);
        // SecretKeySpec rejects empty key material, and a negative size cannot be allocated.
        if (keyLen <= 0) throw new InvalidKeyException("CKKey is too short for unwrapping");
        byte[] out = new byte[keyLen];
        Session session = null;
        try {
            prvKey.save();
            session = prvKey.pkcs11Key.decryptInit(new CK_MECHANISM(DYCKM_LIMA));
            session.decrypt(wrappedKey, 0, wrappedKey.length, out, 0);
        }
        catch (CKException e ) { throw new ProviderException(e); }
        catch (KeyStoreException e ) { throw new ProviderException(e); }
        // Release the session on every path, as doFinal does.
        finally { if (session != null) session.close(); }

        String keyAlgorithm = (algorithm == null || algorithm.length() == 0) ? "AES" : algorithm;
        return new SecretKeySpec(out, keyAlgorithm);
    }
}
