package com.dyadicsec.provider;

import com.dyadicsec.pkcs11.*;
import static com.dyadicsec.cryptoki.CK.*;

import javax.crypto.*;
import javax.crypto.spec.OAEPParameterSpec;
import javax.crypto.spec.PSource;
import javax.crypto.spec.SecretKeySpec;
import java.lang.Exception;
import java.math.BigInteger;
import java.security.Key;
import java.security.SecureRandom;
import java.security.InvalidAlgorithmParameterException;
import java.security.NoSuchAlgorithmException;
import java.security.InvalidKeyException;
import java.security.KeyStoreException;
import java.security.NoSuchProviderException;
import java.security.AlgorithmParameters;
import java.security.interfaces.RSAKey;
import java.security.spec.AlgorithmParameterSpec;
import java.security.spec.InvalidParameterSpecException;
import java.security.spec.MGF1ParameterSpec;
import java.util.Arrays;

/**
 * Created by valery.osheter on 19-Apr-16.
 */
public final class RSACipher extends CipherSpi
{
    /** Private key used for decrypt/unwrap; cleared when the cipher is initialised for encrypt/wrap. */
    private RSAPrivateKey prvKey;
    /** Public key used for encrypt/wrap; cleared when the cipher is initialised for decrypt/unwrap. */
    private RSAPublicKey pubKey;
    /** PKCS#11 mechanism selected by the padding name; PKCS#1 v1.5 until a padding is set. */
    private int mechanismType = CKM_RSA_PKCS;
    /** OAEP message digest mechanism (only meaningful for {@code CKM_RSA_PKCS_OAEP}). */
    private int oaepHashType = CKM_SHA_1;
    /** OAEP mask generation function identifier (only meaningful for {@code CKM_RSA_PKCS_OAEP}). */
    private int oaepMgfType = CKG_MGF1_SHA1;
    /** OAEP encoding input taken from a {@code PSource.PSpecified}; stays null unless parameters were supplied. */
    private byte[] oaepSource;
    /** Shared empty array returned by the update calls, which never produce output. */
    private static final byte[] B0 = {};
    /** Accumulates input between update calls; allocated by {@code init} from the key size. */
    private byte[] buffer;
    /** Total number of input bytes supplied so far; may exceed {@code buffer.length} to flag overflow. */
    private int bufferOffset;
    /** The {@code Cipher.*_MODE} value passed to the last initialisation. */
    private int opmode;
    /** Randomness source supplied at initialisation; forwarded to the software cipher. */
    private SecureRandom secureRandom;
    /** OAEP parameters, either supplied by the caller or lazily built from the current settings. */
    private OAEPParameterSpec oaepSpec;
    /** Key handed over to {@code doFinal()} when the software cipher performs the wrap. */
    private Key wrappedKey;

    /**
     * Creates an uninitialised cipher using the field defaults above.
     */
    public RSACipher()
    {
    }

    /**
     * Translates a digest name into the matching PKCS#11 digest mechanism.
     * The comparison is done on the upper-cased name, so it is case-insensitive.
     *
     * @param hashName digest name such as "SHA1", "SHA-1", "SHA-256", "SHA-384" or "SHA-512"
     * @return the corresponding {@code CKM_SHA*} constant
     * @throws InvalidAlgorithmParameterException if the name is none of the supported digests
     */
    private static int hashNameToType(String hashName) throws InvalidAlgorithmParameterException
    {
        final String upper = hashName.toUpperCase();
        if ("SHA1".equals(upper) || "SHA-1".equals(upper))
        {
            return CKM_SHA_1;
        }
        else if ("SHA-256".equals(upper))
        {
            return CKM_SHA256;
        }
        else if ("SHA-384".equals(upper))
        {
            return CKM_SHA384;
        }
        else if ("SHA-512".equals(upper))
        {
            return CKM_SHA512;
        }
        throw new InvalidAlgorithmParameterException("OAEP hash algorithm not supported: " + upper);
    }

    /**
     * Maps a PKCS#11 digest mechanism to the MGF1 identifier built on the same digest.
     *
     * @param hashType one of the {@code CKM_SHA*} digest mechanisms
     * @return the matching {@code CKG_MGF1_*} constant
     * @throws InvalidAlgorithmParameterException if the digest has no supported MGF1 variant
     */
    private static int hashTypeToMgfType(int hashType) throws InvalidAlgorithmParameterException
    {
        final int mgfType;
        switch (hashType)
        {
            case CKM_SHA_1:
                mgfType = CKG_MGF1_SHA1;
                break;
            case CKM_SHA256:
                mgfType = CKG_MGF1_SHA256;
                break;
            case CKM_SHA384:
                mgfType = CKG_MGF1_SHA384;
                break;
            case CKM_SHA512:
                mgfType = CKG_MGF1_SHA512;
                break;
            default:
                throw new InvalidAlgorithmParameterException("OAEP MGF hash algorithm not supported: " + hashType);
        }
        return mgfType;
    }

    /**
     * Returns the digest name for a PKCS#11 digest mechanism.
     *
     * @param mechanismType one of the {@code CKM_SHA*} digest mechanisms
     * @return "SHA-1", "SHA-256", "SHA-384" or "SHA-512"
     * @throws InvalidAlgorithmParameterException if the mechanism is not one of the supported digests
     */
    private static String hashTypeToName(int mechanismType) throws InvalidAlgorithmParameterException
    {
        if (mechanismType == CKM_SHA_1) return "SHA-1";
        if (mechanismType == CKM_SHA256) return "SHA-256";
        if (mechanismType == CKM_SHA384) return "SHA-384";
        if (mechanismType == CKM_SHA512) return "SHA-512";
        throw new InvalidAlgorithmParameterException("Unsupported OAEP hash algorithm: " + mechanismType);
    }

    /**
     * Translates a JCA padding name into the PKCS#11 RSA mechanism that implements it.
     * Matching is case-insensitive.
     *
     * @param padding padding name: "NoPadding", "PKCS1Padding", "OAEPPadding" or
     *                "OAEPWith&lt;digest&gt;AndMGF1Padding"
     * @return {@code CKM_RSA_X_509}, {@code CKM_RSA_PKCS} or {@code CKM_RSA_PKCS_OAEP}
     * @throws NoSuchPaddingException if the name is null or not one of the supported paddings
     */
    private static int paddingToMechanismType(String padding) throws NoSuchPaddingException
    {
        if (padding == null) throw new NoSuchPaddingException("Unsupported padding: null");
        // Locale.ROOT keeps the comparison independent of the JVM default locale; with a Turkish
        // default locale "Padding" would otherwise upper-case to "PADD\u0130NG" and never match.
        padding = padding.toUpperCase(java.util.Locale.ROOT);
        if (padding.equals("NOPADDING")) return CKM_RSA_X_509;
        if (padding.equals("PKCS1PADDING")) return CKM_RSA_PKCS;
        if (padding.equals("OAEPPADDING")) return CKM_RSA_PKCS_OAEP;
        if (padding.startsWith("OAEPWITH") && padding.endsWith("ANDMGF1PADDING")) return CKM_RSA_PKCS_OAEP;
        throw new NoSuchPaddingException("Unsupported padding: " + padding);
    }

    /**
     * Extracts the OAEP digest from an OAEP padding name.
     * "OAEPPadding" selects SHA-1; "OAEPWith&lt;digest&gt;AndMGF1Padding" selects the named digest.
     * Matching is case-insensitive.
     *
     * @param padding an OAEP padding name
     * @return the {@code CKM_SHA*} mechanism of the digest
     * @throws NoSuchPaddingException if the name is null, is not an OAEP padding, or names an
     *                                unsupported digest (the digest failure is attached as the cause)
     */
    private static int oaepPaddingToHashType(String padding) throws NoSuchPaddingException
    {
        final String prefix = "OAEPWITH";
        final String suffix = "ANDMGF1PADDING";

        if (padding == null) throw new NoSuchPaddingException("padding not supported: null");
        // Locale.ROOT: upper-casing must not depend on the default locale (e.g. the Turkish dotted I).
        padding = padding.toUpperCase(java.util.Locale.ROOT);
        if (padding.equals("OAEPPADDING")) return CKM_SHA_1;
        if (padding.startsWith(prefix) && padding.endsWith(suffix))
        {
            // The digest name is whatever sits between the fixed prefix and suffix.
            String hashName = padding.substring(prefix.length(), padding.length() - suffix.length());
            try { return hashNameToType(hashName); }
            catch (InvalidAlgorithmParameterException e)
            {
                // NoSuchPaddingException has no cause-taking constructor, hence initCause.
                NoSuchPaddingException ex = new NoSuchPaddingException("padding not supported: " + padding);
                ex.initCause(e);
                throw ex;
            }
        }
        throw new NoSuchPaddingException("padding not supported: " + padding);
    }

    /**
     * Returns the JCA MGF1 parameter spec for a PKCS#11 MGF identifier.
     *
     * @param mgfType one of the {@code CKG_MGF1_*} constants
     * @return the predefined {@link MGF1ParameterSpec} using the same digest
     * @throws InvalidAlgorithmParameterException if the identifier is not supported
     */
    private static MGF1ParameterSpec mgfTypeToSpec(int mgfType) throws InvalidAlgorithmParameterException
    {
        if (mgfType == CKG_MGF1_SHA1) return MGF1ParameterSpec.SHA1;
        if (mgfType == CKG_MGF1_SHA256) return MGF1ParameterSpec.SHA256;
        if (mgfType == CKG_MGF1_SHA384) return MGF1ParameterSpec.SHA384;
        if (mgfType == CKG_MGF1_SHA512) return MGF1ParameterSpec.SHA512;
        throw new InvalidAlgorithmParameterException("Unsupported OAEP MGF hash algorithm: " + mgfType);
    }

    /**
     * Builds the JCA padding name that corresponds to a PKCS#11 RSA mechanism.
     *
     * @param mechanismType {@code CKM_RSA_X_509}, {@code CKM_RSA_PKCS} or {@code CKM_RSA_PKCS_OAEP}
     * @param oaepHashType  digest mechanism; only consulted for the OAEP mechanism
     * @return the padding part of a cipher transformation string
     * @throws NoSuchPaddingException             if the mechanism is none of the three above
     * @throws InvalidAlgorithmParameterException if the OAEP digest is not supported
     */
    private static String paddingTypeToName(int mechanismType, int oaepHashType) throws NoSuchPaddingException, InvalidAlgorithmParameterException
    {
        if (mechanismType == CKM_RSA_X_509)
        {
            return "NOPadding";
        }
        if (mechanismType == CKM_RSA_PKCS)
        {
            return "PKCS1Padding";
        }
        if (mechanismType == CKM_RSA_PKCS_OAEP)
        {
            String digestName = hashTypeToName(oaepHashType);
            return "OAEPWith" + digestName + "AndMGF1Padding";
        }
        throw new NoSuchPaddingException("padding not supported");
    }

    /**
     * Accepts the modes "NONE" and "ECB" (case-insensitive); nothing is stored because
     * the mode does not influence the operation.
     *
     * @param mode requested cipher mode
     * @throws NoSuchAlgorithmException for any other mode
     */
    @Override
    protected void engineSetMode(String mode) throws NoSuchAlgorithmException
    {
        final String upper = mode.toUpperCase();
        if (upper.equals("NONE") || upper.equals("ECB"))
        {
            return;
        }
        throw new NoSuchAlgorithmException("Mode not supported: " + upper);
    }

    /**
     * Selects the PKCS#11 mechanism for the requested padding and, for OAEP, the digest
     * named in the padding string.
     *
     * @param padding JCA padding name
     * @throws NoSuchPaddingException if the padding is not supported
     */
    @Override
    protected void engineSetPadding(String padding) throws NoSuchPaddingException
    {
        mechanismType = paddingToMechanismType(padding);
        if (mechanismType != CKM_RSA_PKCS_OAEP)
        {
            return;
        }

        oaepHashType = oaepPaddingToHashType(padding);
        // The padding name alone always yields MGF1 over SHA-1, whatever the OAEP digest is.
        // NOTE(review): this looks like it mirrors the JDK default for OAEPWith...AndMGF1Padding - confirm.
        oaepMgfType = CKG_MGF1_SHA1;
    }

    /**
     * @return always 0
     */
    @Override
    protected int engineGetBlockSize()
    {
        return 0;
    }

    /**
     * Reports the size of the internal buffer as the output size, independent of the input length.
     *
     * @param inputLen ignored
     * @return the buffer length, or 0 while no buffer has been allocated
     */
    @Override
    protected int engineGetOutputSize(int inputLen)
    {
        if (buffer == null)
        {
            return 0;
        }
        return buffer.length;
    }

    /**
     * @return always null; this cipher does not use an IV
     */
    @Override
    protected byte[] engineGetIV()
    {
        return null;
    }

    /**
     * Returns the OAEP parameters in effect, creating and caching them from the current
     * digest/MGF settings (with the default PSource) when none have been stored yet.
     *
     * @return the OAEP spec, or null when no spec is stored and the padding is not OAEP
     * @throws InvalidAlgorithmParameterException if the current digest or MGF setting is unsupported
     */
    private AlgorithmParameterSpec getParameterSpec() throws InvalidAlgorithmParameterException
    {
        if (oaepSpec != null)
        {
            return oaepSpec;
        }
        if (mechanismType != CKM_RSA_PKCS_OAEP)
        {
            return null;
        }

        oaepSpec = new OAEPParameterSpec(
                hashTypeToName(oaepHashType),
                "MGF1",
                mgfTypeToSpec(oaepMgfType),
                PSource.PSpecified.DEFAULT);
        return oaepSpec;
    }

    /**
     * Returns the OAEP parameters of this cipher wrapped in an {@link AlgorithmParameters} object.
     *
     * @return "OAEP" algorithm parameters, or null when the cipher has no parameter spec
     *         (i.e. it is not using OAEP padding)
     * @throws RuntimeException if the parameters cannot be built; the underlying security
     *                          exception is attached as the cause
     */
    @Override
    protected AlgorithmParameters engineGetParameters()
    {
        try
        {
            AlgorithmParameterSpec spec = getParameterSpec();
            if (spec == null) return null;
            AlgorithmParameters params = AlgorithmParameters.getInstance("OAEP");
            params.init(spec);
            return params;
        }
        // Only the checked security failures are translated; Errors and unchecked exceptions
        // propagate untouched instead of being masked behind a generic message.
        catch (java.security.GeneralSecurityException e)
        {
            throw new RuntimeException("Cannot build OAEP algorithm parameters", e);
        }
    }

    /**
     * Common initialisation behind all {@code engineInit} overloads: stores the key for the
     * requested mode, applies optional OAEP parameters and allocates an empty input buffer
     * of one modulus length.
     *
     * @param opmode    one of the {@code Cipher.*_MODE} constants
     * @param key       a public key for encrypt/wrap (provider key or any
     *                  {@code java.security.interfaces.RSAPublicKey}); a provider
     *                  {@code RSAPrivateKey} for decrypt/unwrap
     * @param paramSpec optional {@link OAEPParameterSpec}; only allowed with OAEP padding
     * @throws InvalidKeyException                if the mode is unknown, the key type does not fit
     *                                            the mode, or the key cannot be saved/queried
     * @throws InvalidAlgorithmParameterException if parameters are given for a non-OAEP padding or
     *                                            describe an unsupported digest, MGF or PSource
     */
    void init(int opmode, Key key, AlgorithmParameterSpec paramSpec) throws InvalidKeyException, InvalidAlgorithmParameterException
    {
        int keyBits = 0;
        this.opmode = opmode;
        switch (opmode)
        {
            case Cipher.ENCRYPT_MODE:
            case Cipher.WRAP_MODE:
                prvKey = null;
                if (key instanceof RSAPublicKey) pubKey = (RSAPublicKey) key;
                else if (key instanceof java.security.interfaces.RSAPublicKey) pubKey = new RSAPublicKey((java.security.interfaces.RSAPublicKey) key);
                else throw new InvalidKeyException("Invalid key type");
                if (pubKey.prvKey!=null)
                {
                    try { pubKey.prvKey.save(); }
                    catch (KeyStoreException e) { throw new InvalidKeyException(e); }
                }
                keyBits = pubKey.getBitSize();
                break;

            case Cipher.DECRYPT_MODE:
            case Cipher.UNWRAP_MODE:
                pubKey = null;
                if (key instanceof RSAPrivateKey) prvKey = (RSAPrivateKey) key;
                else throw new InvalidKeyException("Invalid key type");
                try { prvKey.save(); }
                catch (KeyStoreException e) { throw new InvalidKeyException(e); }
                try { keyBits = prvKey.getBitSize(); }
                catch (KeyStoreException e) { throw new InvalidKeyException(e); }
                break;

            default:
                throw new InvalidKeyException("Unknown mode: " + opmode);
        }

        // NOTE(review): when paramSpec is null, OAEP settings left by an earlier init are kept as they are.
        if (paramSpec != null)
        {
            if (mechanismType != CKM_RSA_PKCS_OAEP) throw new InvalidAlgorithmParameterException("Wrong padding parameter");
            if (!(paramSpec instanceof OAEPParameterSpec)) throw new InvalidAlgorithmParameterException("Wrong Parameters for OAEP Padding");

            // Validate everything into locals first so that a rejected spec leaves the
            // previously configured OAEP state untouched.
            OAEPParameterSpec spec = (OAEPParameterSpec) paramSpec;
            int hashType = hashNameToType(spec.getDigestAlgorithm());

            String mgfAlgName = spec.getMGFAlgorithm();
            if (!"MGF1".equalsIgnoreCase(mgfAlgName)) throw new InvalidAlgorithmParameterException("Unsupported MGF algorithm: " + mgfAlgName);

            AlgorithmParameterSpec mgfParam = spec.getMGFParameters();
            if (!(mgfParam instanceof MGF1ParameterSpec)) throw new InvalidAlgorithmParameterException("Unsupported MGF hash");
            String mgfHashName = ((MGF1ParameterSpec) mgfParam).getDigestAlgorithm();
            int mgfType = hashTypeToMgfType(hashNameToType(mgfHashName));

            // Test the type rather than the algorithm name so the cast below cannot fail.
            PSource s = spec.getPSource();
            if (!(s instanceof PSource.PSpecified)) throw new InvalidAlgorithmParameterException("Unsupported pSource " + s.getAlgorithm() + "; PSpecified only");
            byte[] source = ((PSource.PSpecified) s).getValue();

            oaepSpec = spec;
            oaepHashType = hashType;
            oaepMgfType = mgfType;
            oaepSource = source;
        }

        // Round up: a modulus whose bit length is not a multiple of 8 still occupies the
        // partially used leading byte.
        buffer = new byte[(keyBits + 7) / 8];
        bufferOffset = 0;
    }

    /**
     * Initialises the cipher without explicit algorithm parameters.
     *
     * @param opmode       one of the {@code Cipher.*_MODE} constants
     * @param key          key matching the mode
     * @param secureRandom randomness source, stored for later use
     * @throws InvalidKeyException if the key is rejected; a parameter failure reported by
     *                             {@code init} is also surfaced as this type, with its cause
     */
    @Override
    protected void engineInit(int opmode, Key key, SecureRandom secureRandom) throws InvalidKeyException
    {
        this.secureRandom = secureRandom;
        try
        {
            init(opmode, key, null);
        }
        catch (InvalidAlgorithmParameterException e)
        {
            throw new InvalidKeyException("Wrong parameters", e);
        }
    }

    /**
     * Initialises the cipher with an optional parameter spec.
     *
     * @param opmode                 one of the {@code Cipher.*_MODE} constants
     * @param key                    key matching the mode
     * @param algorithmParameterSpec OAEP parameters, or null
     * @param secureRandom           randomness source, stored for later use
     * @throws InvalidKeyException                if the key is rejected
     * @throws InvalidAlgorithmParameterException if the parameters are rejected
     */
    @Override
    protected void engineInit(int opmode,
                              Key key,
                              AlgorithmParameterSpec algorithmParameterSpec,
                              SecureRandom secureRandom) throws InvalidKeyException, InvalidAlgorithmParameterException
    {
        this.secureRandom = secureRandom;
        init(opmode, key, algorithmParameterSpec);
    }

    /**
     * Initialises the cipher from encoded algorithm parameters, which are decoded as an
     * {@link OAEPParameterSpec} before delegating to {@code init}.
     *
     * @param opmode              one of the {@code Cipher.*_MODE} constants
     * @param key                 key matching the mode
     * @param algorithmParameters OAEP parameters, or null for none
     * @param secureRandom        randomness source, stored for later use
     * @throws InvalidKeyException                if the key is rejected
     * @throws InvalidAlgorithmParameterException if the parameters cannot be converted to an
     *                                            OAEP spec or are rejected by {@code init}
     */
    @Override
    protected void engineInit(int opmode, Key key, AlgorithmParameters algorithmParameters, SecureRandom secureRandom) throws InvalidKeyException, InvalidAlgorithmParameterException
    {
        this.secureRandom = secureRandom;
        OAEPParameterSpec spec = null;

        if (algorithmParameters != null)
        {
            try { spec = algorithmParameters.getParameterSpec(OAEPParameterSpec.class); }
            // Undecodable parameters are a parameter problem, not a key problem.
            catch (InvalidParameterSpecException e) { throw new InvalidAlgorithmParameterException("Wrong parameters", e); }
        }

        init(opmode, key, spec);
    }

    /**
     * Appends input to the internal buffer. Input that would not fit is not copied, but the
     * running length is still advanced so the caller can detect the overflow afterwards.
     *
     * @param in       source array; null (or a zero length) makes the call a no-op
     * @param inOffset start of the data in {@code in}
     * @param inLen    number of bytes to append
     */
    private void update(byte[] in, int inOffset, int inLen)
    {
        if (in == null || inLen == 0)
        {
            return;
        }

        final int end = bufferOffset + inLen;
        if (end <= buffer.length)
        {
            System.arraycopy(in, inOffset, buffer, bufferOffset, inLen);
        }
        bufferOffset = end;
    }

    /**
     * Buffers the input; output is only produced by the final call.
     *
     * @return always the shared empty array
     */
    @Override
    protected byte[] engineUpdate(byte[] in, int inOffset, int inLen)
    {
        this.update(in, inOffset, inLen);
        return B0;
    }

    /**
     * Buffers the input; the output array is never written to.
     *
     * @return always 0, the number of bytes stored in {@code out}
     */
    @Override
    protected int engineUpdate(byte[] in, int inOffset, int inLen, byte[] out, int outOffset) throws ShortBufferException
    {
        this.update(in, inOffset, inLen);
        return 0;
    }

    /**
     * Runs the operation on the buffered input.
     * <p>
     * Without a public key the buffered data is decrypted through a PKCS#11 session on the
     * private key. With a public key the work is delegated to a "SunJCE" RSA cipher
     * initialised with the software form of that key: it wraps {@code wrappedKey} in wrap
     * mode and otherwise processes the buffered bytes.
     *
     * @return the result bytes of the operation
     * @throws InvalidAlgorithmParameterException if the token reports bad arguments or an invalid
     *                                            mechanism / mechanism parameter
     * @throws InvalidKeyException                for any other token failure
     */
    private byte[] doFinal() throws BadPaddingException, IllegalBlockSizeException, NoSuchPaddingException, NoSuchAlgorithmException, InvalidAlgorithmParameterException, InvalidKeyException, NoSuchProviderException
    {
        if (pubKey == null)
        {
            CK_MECHANISM m;
            if (mechanismType == CKM_RSA_PKCS_OAEP) m = new CK_RSA_PKCS_OAEP_PARAMS(oaepHashType, oaepMgfType, oaepSource);
            else m = new CK_MECHANISM(mechanismType);

            Session session = null;
            try
            {
                session = prvKey.pkcs11Key.decryptInit(m);
                byte[] out = new byte[buffer.length];
                int outLen = session.decrypt(buffer, 0, bufferOffset, out, 0);
                // Trim only when the plaintext is shorter than the scratch array.
                if (outLen == buffer.length) return out;
                return Arrays.copyOf(out, outLen);
            }
            catch (CKException e)
            {
                boolean parameterProblem = e.getRV() == CKR_ARGUMENTS_BAD
                        || e.getRV() == CKR_MECHANISM_INVALID
                        || e.getRV() == CKR_MECHANISM_PARAM_INVALID;
                if (parameterProblem) throw new InvalidAlgorithmParameterException(e);
                throw new InvalidKeyException(e);
            }
            finally
            {
                if (session != null) session.close();
            }
        }

        String transformation = "RSA/ECB/" + paddingTypeToName(mechanismType, oaepHashType);
        Cipher cipher = Cipher.getInstance(transformation, "SunJCE");
        cipher.init(opmode, pubKey.getSoftwareKey(), getParameterSpec(), secureRandom);
        return (opmode == Cipher.WRAP_MODE)
                ? cipher.wrap(wrappedKey)
                : cipher.doFinal(buffer, 0, bufferOffset);
    }

    /**
     * Appends the last piece of input and runs the operation on everything buffered so far.
     * The buffered length is reset afterwards, whether the call succeeds or fails, so the
     * cipher can be used for another operation without being re-initialised.
     *
     * @param in       final input, may be null
     * @param inOffset start of the data in {@code in}
     * @param inLen    number of input bytes
     * @return the result of the operation
     * @throws IllegalBlockSizeException if the accumulated input is longer than the buffer,
     *                                   or the underlying cipher reports it
     * @throws BadPaddingException       if the underlying operation reports bad padding, or fails
     *                                   for any other reason (that failure is attached as the cause)
     */
    @Override
    protected byte[] engineDoFinal(byte[] in, int inOffset, int inLen) throws IllegalBlockSizeException, BadPaddingException
    {
        try
        {
            update(in, inOffset, inLen);
            if (bufferOffset > buffer.length) throw new IllegalBlockSizeException("Input must not exceed " + buffer.length + " bytes");

            try { return doFinal(); }
            // Keep the two declared exception types intact instead of flattening them.
            catch (IllegalBlockSizeException e) { throw e; }
            catch (BadPaddingException e) { throw e; }
            catch (Exception e)
            {
                // BadPaddingException has no cause-taking constructor, hence initCause.
                BadPaddingException ex = new BadPaddingException("engineDoFinal failed");
                ex.initCause(e);
                throw ex;
            }
        }
        finally
        {
            // Without this a second operation would append to the input of the previous one.
            bufferOffset = 0;
        }
    }

    /**
     * Completes the operation and copies the result into the caller's array.
     *
     * @param out       destination array
     * @param outOffset position in {@code out} where the result starts
     * @return number of bytes written to {@code out}
     * @throws ShortBufferException if the result does not fit behind {@code outOffset}
     */
    @Override
    protected int engineDoFinal(byte[] in, int inOffset, int inLen, byte[] out, int outOffset) throws ShortBufferException, IllegalBlockSizeException, BadPaddingException
    {
        final byte[] result = engineDoFinal(in, inOffset, inLen);
        final int produced = result.length;
        if (outOffset + produced > out.length)
        {
            throw new ShortBufferException("Output buffer is too small");
        }
        System.arraycopy(result, 0, out, outOffset, produced);
        return produced;
    }

    /**
     * Computes the UID under which the current public key is looked up: the first 8 bytes of
     * the SHA-256 digest of the modulus bytes, converted to a long and XOR-ed with a fixed mask.
     *
     * @return the derived UID
     * @throws InvalidKeyException if SHA-256 from the "SUN" provider is unavailable
     */
    private long getPubKeyUID() throws InvalidKeyException
    {
        java.security.MessageDigest sha256;
        try
        {
            sha256 = java.security.MessageDigest.getInstance("SHA-256", "SUN");
        }
        catch (NoSuchProviderException e)
        {
            throw new InvalidKeyException(e);
        }
        catch (NoSuchAlgorithmException e)
        {
            throw new InvalidKeyException(e);
        }

        BigInteger modulus = pubKey.getModulus();
        int modulusByteSize = Utils.bigIntByteSize(modulus);
        byte[] modulusBytes = Utils.bigInt2Bytes(modulus, modulusByteSize);

        byte[] digest = sha256.digest(modulusBytes);
        byte[] uid = Arrays.copyOf(digest, 8);
        return Utils.bytesToUID(uid) ^ 0x0f0f0f0f0f0f0f0fL;
    }

    /**
     * Wraps a key with the RSA public key.
     * <p>
     * A provider {@code SecretKey} under OAEP padding is wrapped by the token itself: the
     * public key is located in the secret key's slot by its derived UID and asked to wrap the
     * PKCS#11 key object. Every other key is wrapped through {@code doFinal()}, which needs
     * the key to expose a non-empty encoding.
     *
     * @param key the key to wrap
     * @return the wrapped key bytes
     * @throws InvalidKeyException if the key is null, cannot be saved, has no usable encoding,
     *                             is longer than the buffer, the public key is not found on
     *                             the token, or the wrap operation fails
     */
    @Override
    protected byte[] engineWrap(Key key) throws InvalidKeyException,  IllegalBlockSizeException
    {
        if (key == null) throw new InvalidKeyException("Key to wrap must not be null");

        if (pubKey!=null && (key instanceof SecretKey) && (mechanismType== CKM_RSA_PKCS_OAEP))
        {
            try { ((SecretKey)key).save(); }
            catch (KeyStoreException e) { throw new InvalidKeyException(e); }

            long pubKeyUID = getPubKeyUID();
            Slot slot = ((SecretKey)key).pkcs11Key.getSlot();
            CKRSAPublicKey pkcs11Key = slot.findObject(CKRSAPublicKey.class, pubKeyUID);
            if (pkcs11Key==null) throw new InvalidKeyException("Public key not found");

            CK_MECHANISM m = new CK_RSA_PKCS_OAEP_PARAMS(oaepHashType, oaepMgfType, oaepSource);
            try { return pkcs11Key.wrap(m, ((SecretKey)key).pkcs11Key); }
            catch (CKException e)  { throw new InvalidKeyException(e); }
        }

        byte[] encoded = key.getEncoded();
        if ((encoded == null) || (encoded.length == 0)) throw new InvalidKeyException("Could not obtain encoded key");
        if (encoded.length > buffer.length) throw new InvalidKeyException("CKKey is too long for wrapping");

        wrappedKey = key;
        try { return doFinal(); }
        catch (Exception e) { throw new InvalidKeyException("Wrapping failed", e); }
        // Do not keep a reference to the caller's key beyond this call.
        finally { wrappedKey = null; }
    }

    /**
     * Unwraps a secret key with the RSA private key.
     * <p>
     * If the private key's policy allows decryption, the wrapped bytes are decrypted and the
     * plaintext is returned as a {@link SecretKeySpec}. Otherwise a provider {@code SecretKey}
     * is returned that was initialised for unwrap with the mechanism, the private key and the
     * wrapped bytes.
     *
     * @param wrappedKey     the wrapped key bytes
     * @param algorithm      algorithm name of the key being unwrapped
     * @param wrappedKeyType must be {@code Cipher.SECRET_KEY}
     * @return the unwrapped key
     * @throws UnsupportedOperationException if {@code wrappedKeyType} is not {@code Cipher.SECRET_KEY}
     * @throws InvalidKeyException           if the input is null or longer than the buffer, or
     *                                       decryption fails
     */
    @Override
    protected Key engineUnwrap(byte[] wrappedKey, String algorithm, int wrappedKeyType) throws InvalidKeyException, NoSuchAlgorithmException
    {
        if (wrappedKeyType != Cipher.SECRET_KEY) throw new UnsupportedOperationException("wrappedKeyType == " + wrappedKeyType);
        if (wrappedKey == null) throw new InvalidKeyException("Wrapped key must not be null");
        if (wrappedKey.length > buffer.length) throw new InvalidKeyException("CKKey is too long for unwrapping");

        CK_MECHANISM m = null;
        if (mechanismType==CKM_RSA_PKCS_OAEP) m = new CK_RSA_PKCS_OAEP_PARAMS(oaepHashType, oaepMgfType, oaepSource);
        else m = new CK_MECHANISM(mechanismType);

        boolean decrypt = false;
        try { decrypt = prvKey.pkcs11Key.getPolicy().getDecrypt(); }
        catch (CKException ignored)
        {
            // Best effort: if the policy cannot be read, behave as if decryption is not
            // permitted and take the unwrap-info path below.
        }

        int keyType = SecretKey.algToKeyType(algorithm);
        if (decrypt)
        {
            byte[] unwrappedData = null;
            try
            {
                unwrappedData = engineDoFinal(wrappedKey, 0, wrappedKey.length);
                return new SecretKeySpec(unwrappedData, algorithm);
            }
            catch (IllegalBlockSizeException e) { throw new InvalidKeyException(e); }
            catch (BadPaddingException e) { throw new InvalidKeyException(e); }
            finally
            {
                // SecretKeySpec keeps its own copy of the bytes, so the temporary plaintext
                // can be wiped rather than left in the heap.
                if (unwrappedData != null) Arrays.fill(unwrappedData, (byte) 0);
            }
        }
        else
        {
            UnwrapInfo unwrapInfo = new UnwrapInfo(m, prvKey.pkcs11Key, wrappedKey);
            return new SecretKey().initForUnwrap(unwrapInfo, keyType);
        }
    }

    /**
     * Returns the key size in bits, computed as the byte length of the RSA modulus times 8.
     *
     * @param key an RSA key
     * @return the modulus size in bits, rounded to whole bytes
     * @throws InvalidKeyException if the key is null or not an {@link RSAKey}
     */
    @Override
    protected int engineGetKeySize(Key key) throws InvalidKeyException
    {
        // Report a wrong key type through the declared exception instead of a ClassCastException.
        if (!(key instanceof RSAKey))
        {
            throw new InvalidKeyException("Not an RSA key: " + (key == null ? "null" : key.getClass().getName()));
        }
        return Utils.bigIntByteSize(((RSAKey)key).getModulus())*8;
    }
}
