package com.dyadicsec.provider;

import com.dyadicsec.pkcs11.*;
import static com.dyadicsec.cryptoki.CK.*;

import java.io.ByteArrayOutputStream;
import java.lang.Object;
import java.security.PrivateKey;
import java.security.PublicKey;
import java.security.Signature;
import java.security.SignatureSpi;
import java.security.InvalidAlgorithmParameterException;
import java.security.InvalidKeyException;
import java.security.ProviderException;
import java.security.KeyStoreException;
import java.security.SignatureException;
import java.security.InvalidParameterException;

/**
 * JCA {@link SignatureSpi} implementing RSA PKCS#1 v1.5 signatures
 * (NONEwithRSA, SHA1withRSA, SHA256withRSA, SHA384withRSA, SHA512withRSA).
 * <p>
 * Signing is performed on the token through a PKCS#11 session opened with this
 * provider's {@link RSAPrivateKey}. Verification is performed in software by a
 * delegate {@link Signature} obtained from a JDK provider ({@code SunRsaSign},
 * or {@code SunJCE} for NONEwithRSA).
 * <p>
 * For {@code CKM_RSA_PKCS} (NONEwithRSA) the input is buffered in memory and sent
 * to the token in a single sign call. For the hash-based mechanisms the input is
 * streamed with sign-update calls and completed with a sign-final call.
 * <p>
 * Created by valery.osheter on 19-Apr-16.
 */
@SuppressWarnings("deprecation")
public class RSASignature extends SignatureSpi
{
    /** PKCS#11 mechanism ({@code CKM_RSA_PKCS} or a {@code CKM_SHA*_RSA_PKCS}) chosen by the concrete subclass. */
    private final int mechanismType;
    /** Token key used for signing; set by {@link #engineInitSign}. */
    private RSAPrivateKey prvKey = null;
    /** Software verifier; non-null only while this object is initialized for verification. */
    private Signature pubSignature = null;
    /** Currently open token signing session, or {@code null} if none is open. */
    private Session session = null;
    /** Accumulates the input for NONEwithRSA, which is signed in a single call. */
    private ByteArrayOutputStream buffer = null;

    /**
     * Returns the name of the JDK provider used for software verification.
     *
     * @param mechanismType one of the supported {@code CKM_*RSA_PKCS} mechanism constants
     * @return {@code "SunJCE"} for {@code CKM_RSA_PKCS}, {@code "SunRsaSign"} for the hash-based mechanisms
     * @throws InvalidAlgorithmParameterException if the mechanism is not supported
     */
    static String getSunProvider(int mechanismType) throws InvalidAlgorithmParameterException
    {
        switch (mechanismType)
        {
            case CKM_RSA_PKCS:           return "SunJCE";
            case CKM_SHA1_RSA_PKCS:      return "SunRsaSign";
            case CKM_SHA256_RSA_PKCS:    return "SunRsaSign";
            case CKM_SHA384_RSA_PKCS:    return "SunRsaSign";
            case CKM_SHA512_RSA_PKCS:    return "SunRsaSign";
        }
        throw new InvalidAlgorithmParameterException("Unsupported hash algorithm: " + mechanismType);
    }

    /**
     * Maps a mechanism to the hash-name prefix of the matching JCA algorithm
     * (the caller appends {@code "WithRSA"}).
     *
     * @param mechanismType one of the supported {@code CKM_*RSA_PKCS} mechanism constants
     * @return {@code "NONE"}, {@code "SHA1"}, {@code "SHA256"}, {@code "SHA384"} or {@code "SHA512"}
     * @throws InvalidAlgorithmParameterException if the mechanism is not supported
     */
    static String mechanismTypeToHashName(int mechanismType) throws InvalidAlgorithmParameterException
    {
        switch (mechanismType)
        {
            case CKM_RSA_PKCS:           return "NONE";
            case CKM_SHA1_RSA_PKCS:      return "SHA1";
            case CKM_SHA256_RSA_PKCS:    return "SHA256";
            case CKM_SHA384_RSA_PKCS:    return "SHA384";
            case CKM_SHA512_RSA_PKCS:    return "SHA512";
        }
        throw new InvalidAlgorithmParameterException("Unsupported hash algorithm: " + mechanismType);
    }

    RSASignature(int mechanismType)
    {
        this.mechanismType = mechanismType;
    }

    /**
     * Switches this object into verify mode using a software verifier.
     * <p>
     * A provider {@link RSAPublicKey} is unwrapped to its software key, and any
     * {@link java.security.interfaces.RSAPublicKey} is accepted directly. The
     * object's state changes only after the delegate is fully initialized, so a
     * failed call leaves the previous state intact.
     *
     * @param publicKey the RSA public key to verify with
     * @throws InvalidKeyException if the key type is unsupported, the JDK verifier
     *         cannot be obtained, or the delegate rejects the key
     */
    @Override
    protected void engineInitVerify(PublicKey publicKey) throws InvalidKeyException
    {
        if (publicKey instanceof RSAPublicKey)
        {
            publicKey = ((RSAPublicKey) publicKey).getSoftwareKey();
        }
        else if (!(publicKey instanceof java.security.interfaces.RSAPublicKey))
        {
            throw new InvalidKeyException("Invalid key type");
        }

        Signature verifier;
        try
        {
            verifier = Signature.getInstance(mechanismTypeToHashName(mechanismType) + "WithRSA",
                                             getSunProvider(mechanismType));
        }
        catch (java.security.GeneralSecurityException e)
        {
            throw new InvalidKeyException("engineInitVerify failed", e);
        }
        verifier.initVerify(publicKey);

        // Commit only after the delegate is ready. Release any token session left
        // over from a previous signing operation so it does not leak.
        closeSession();
        pubSignature = verifier;
    }

    /**
     * Opens a token signing session for {@link #prvKey} if none is open yet.
     * For NONEwithRSA the input buffer is created on first use and cleared
     * whenever a new session is opened.
     *
     * @throws ProviderException wrapping the {@link CKException} if the session cannot be started
     */
    private void checkInit()
    {
        if (session != null) return;
        try { session = prvKey.pkcs11Key.signInit(new CK_MECHANISM(mechanismType)); }
        catch (CKException e) { throw new ProviderException(e); }

        if (mechanismType == CKM_RSA_PKCS)
        {
            if (buffer == null) buffer = new ByteArrayOutputStream();
            buffer.reset();
        }
    }

    /**
     * Switches this object into sign mode with the given token key and opens a
     * fresh signing session.
     *
     * @param privateKey must be this provider's {@link RSAPrivateKey}
     * @throws InvalidKeyException if the key type is unsupported or {@code save()} fails
     */
    @Override
    protected void engineInitSign(PrivateKey privateKey) throws InvalidKeyException
    {
        if (!(privateKey instanceof RSAPrivateKey)) throw new InvalidKeyException("Invalid key type");
        RSAPrivateKey key = (RSAPrivateKey) privateKey;

        try { key.save(); }
        catch (KeyStoreException e) { throw new InvalidKeyException(e); }

        prvKey = key;
        // Leave verify mode; otherwise updates would still be routed to the old verifier.
        pubSignature = null;
        closeSession();
        checkInit();
    }

    /** Feeds one byte to the active operation (software verify or token sign). */
    @Override
    protected void engineUpdate(byte b) throws SignatureException
    {
        if (pubSignature != null)
        {
            pubSignature.update(b);
            return;
        }

        byte[] in = {b};
        engineUpdate(in, 0, 1);
    }

    /**
     * Feeds a byte range to the active operation. In verify mode the data goes to
     * the software verifier. In sign mode it is buffered (NONEwithRSA) or streamed
     * to the token session (hash-based mechanisms).
     *
     * @throws SignatureException if the token rejects the update
     */
    @Override
    protected void engineUpdate(byte[] in, int inOffset, int inLen) throws SignatureException
    {
        if (pubSignature != null)
        {
            pubSignature.update(in, inOffset, inLen);
            return;
        }

        checkInit();
        if (mechanismType == CKM_RSA_PKCS)
        {
            buffer.write(in, inOffset, inLen);
        }
        else
        {
            try { session.signUpdate(in, inOffset, inLen); }
            catch (CKException e) { throw new SignatureException(e); }
        }
    }

    /**
     * Produces the signature over the data supplied since the last sign or
     * initialization. The token session is closed afterwards, whether or not
     * signing succeeds, so the next update starts a new operation.
     *
     * @return the signature bytes returned by the token
     * @throws SignatureException if the token reports an error
     */
    @Override
    protected byte[] engineSign() throws SignatureException
    {
        checkInit();

        try
        {
            // NOTE(review): the key's bit size is passed through to the session call;
            // presumably it is used to size the output — confirm against Session.
            int size = prvKey.pkcs11Key.getBitSize();
            if (mechanismType == CKM_RSA_PKCS) return session.sign(buffer.toByteArray(), size);
            return session.signFinal(size);
        }
        catch (CKException e) { throw new SignatureException(e); }
        finally { closeSession(); }
    }

    /**
     * Verifies {@code sigBytes} with the software verifier set up by
     * {@link #engineInitVerify}.
     *
     * @throws SignatureException if this object is not initialized for verification,
     *         or if the delegate reports an error
     */
    @Override
    protected boolean engineVerify(byte[] sigBytes) throws SignatureException
    {
        if (pubSignature == null) throw new SignatureException("Not initialized for verification");
        return pubSignature.verify(sigBytes);
    }

    /** Not supported; always throws {@link UnsupportedOperationException}. */
    @Override
    protected void engineSetParameter(String param, Object value) throws InvalidParameterException
    {
        throw new UnsupportedOperationException("setParameter() not supported");
    }

    /** Not supported; always throws {@link UnsupportedOperationException}. */
    @Override
    protected Object engineGetParameter(String param) throws InvalidParameterException
    {
        throw new UnsupportedOperationException("getParameter() not supported");
    }

    /** NONEwithRSA: signs caller-supplied data with {@code CKM_RSA_PKCS}. */
    public static final class NONEwithRSA extends RSASignature
    {
        public NONEwithRSA() { super(CKM_RSA_PKCS); }
    }

    /** SHA1withRSA backed by {@code CKM_SHA1_RSA_PKCS}. */
    public static final class SHA1withRSA extends RSASignature
    {
        public SHA1withRSA() { super(CKM_SHA1_RSA_PKCS); }
    }

    /** SHA256withRSA backed by {@code CKM_SHA256_RSA_PKCS}. */
    public static final class SHA256withRSA extends RSASignature
    {
        public SHA256withRSA() { super(CKM_SHA256_RSA_PKCS); }
    }

    /** SHA384withRSA backed by {@code CKM_SHA384_RSA_PKCS}. */
    public static final class SHA384withRSA extends RSASignature
    {
        public SHA384withRSA() { super(CKM_SHA384_RSA_PKCS); }
    }

    /** SHA512withRSA backed by {@code CKM_SHA512_RSA_PKCS}. */
    public static final class SHA512withRSA extends RSASignature
    {
        public SHA512withRSA() { super(CKM_SHA512_RSA_PKCS); }
    }

    /** Closes the open token session, if any, and always forgets it. */
    private void closeSession()
    {
        try
        {
            if (session != null) session.close();
        }
        finally
        {
            session = null;
        }
    }

}
