package com.dyadicsec.provider;

import com.dyadicsec.pkcs11.*;
import java.math.BigInteger;
import java.security.KeyStoreException;
import java.security.KeyFactory;
import java.security.NoSuchProviderException;
import java.security.NoSuchAlgorithmException;
import java.security.interfaces.RSAPrivateCrtKey;
import java.security.spec.*;

/**
 * Created by valery.osheter on 19-Apr-16.
 *
 * <p>RSA private key exposed through the {@link RSAPrivateCrtKey} interface. An instance
 * can hold any of the following, depending on how it was initialised:
 * <ul>
 *   <li>a software key ({@code sw}) whose components are used by {@link #create} and are
 *       returned by the private-component getters;</li>
 *   <li>a PKCS#11 key object ({@code pkcs11Key}) from which only the modulus, public
 *       exponent and bit size are read;</li>
 *   <li>pending unwrap information ({@code unwrapInfo}) consumed by {@link #unwrap};</li>
 *   <li>generation parameters ({@code bitSize}, {@code genPublicKey}) consumed by
 *       {@link #generate}.</li>
 * </ul>
 * Private components (d, p, q, dP, dQ, qInv) are only available while a software key is
 * held; otherwise the corresponding getters return {@code null}.
 */
public final class RSAPrivateKey extends DYKey implements RSAPrivateCrtKey
{
    private static final long serialVersionUID = 1L;

    /** Providers tried first, in order, when a software RSA {@link KeyFactory} is needed. */
    private static final String[] PREFERRED_RSA_PROVIDERS = { "SunRsaSign", "IBMJCE" };
    /** Provider tried last; a failure to find it is reported to the caller. */
    private static final String LAST_RESORT_RSA_PROVIDER = "IBMJSSE2";

    /** Software key material, set by the {@code initForImport} methods. */
    private RSAPrivateCrtKey sw = null;
    /** PKCS#11 key object, set by the constructor or by unwrap/create/generate. */
    CKRSAPrivateKey pkcs11Key = null;
    /** Pending unwrap data; cleared once {@link #unwrap} succeeds. */
    private UnwrapInfo unwrapInfo = null;
    /** Key parameters converted to a policy via {@code KeyParameters.toPolicy}. */
    KeyParameters keyParams = null;
    /** Optional public key that is initialised from the generated key pair. */
    private RSAPublicKey genPublicKey = null;
    /** Key size in bits; 0 means "not yet known" (see {@link #getBitSize()}). */
    private int bitSize = 0;

    /** Creates an empty key; one of the {@code initFor...} methods is expected to follow. */
    RSAPrivateKey()
    {
    }

    /**
     * Wraps an already existing PKCS#11 RSA private key.
     *
     * @param pkcs11Key the PKCS#11 key object to expose
     */
    RSAPrivateKey(CKRSAPrivateKey pkcs11Key)
    {
        this.pkcs11Key = pkcs11Key;
    }

    /** @return the PKCS#11 key object currently held, or {@code null} if there is none */
    @Override
    protected CKPrivateKey getPkcs11Key()
    {
        return pkcs11Key;
    }

    /**
     * Unwraps the pending key under the given alias and stores the resulting PKCS#11 key.
     * The unwrap information is discarded after a successful unwrap.
     *
     * @param alias alias passed to the unwrap template
     * @throws KeyStoreException if no unwrap information is pending, or if the unwrap
     *                           operation itself reports one
     */
    @Override
    protected void unwrap(String alias) throws KeyStoreException
    {
        // Report a wrong-state call through the declared exception instead of an NPE.
        if (unwrapInfo == null)
        {
            throw new KeyStoreException("No unwrap information available for RSA private key");
        }

        CK_ATTRIBUTE[] template = CKRSAPrivateKey.getUnwrapTemplate(alias, KeyParameters.toPolicy(keyParams));
        pkcs11Key = unwrapInfo.unwrap(CKRSAPrivateKey.class, template);
        unwrapInfo = null;
    }

    /**
     * Creates a PKCS#11 key in the store's slot from the components of the software key.
     *
     * @param store key store whose {@code slot} receives the key
     * @param alias alias under which the key is created
     * @throws KeyStoreException if no software key is held, or wrapping a {@link CKException}
     *                           raised by the PKCS#11 layer
     */
    @Override
    protected void create(KeyStore store, String alias) throws KeyStoreException
    {
        // Report a wrong-state call through the declared exception instead of an NPE.
        if (sw == null)
        {
            throw new KeyStoreException("No software RSA key material available to import");
        }

        try
        {
            pkcs11Key = CKRSAPrivateKey.create(store.slot, alias, KeyParameters.toPolicy(keyParams),
                    sw.getModulus(),
                    sw.getPublicExponent(),
                    sw.getPrivateExponent(),
                    sw.getPrimeP(),
                    sw.getPrimeQ(),
                    sw.getPrimeExponentP(),
                    sw.getPrimeExponentQ(),
                    sw.getCrtCoefficient());
        }
        catch (CKException e)
        {
            throw new KeyStoreException(e);
        }
    }

    /**
     * Generates a new PKCS#11 RSA key of {@code bitSize} bits in the store's slot and, when a
     * companion public key was registered, initialises it with the new modulus and exponent.
     *
     * @param store key store whose {@code slot} receives the key
     * @param alias alias under which the key is generated
     * @throws KeyStoreException wrapping any failure of generation or of public key
     *                           initialisation
     */
    @Override
    protected void generate(KeyStore store, String alias) throws KeyStoreException
    {
        try
        {
            pkcs11Key = CKRSAPrivateKey.generate(store.slot, alias, KeyParameters.toPolicy(keyParams), bitSize);
        }
        catch (CKException e)
        {
            throw new KeyStoreException(e);
        }

        if (genPublicKey != null)
        {
            try
            {
                genPublicKey.init(pkcs11Key.getN(), pkcs11Key.getE());
            }
            // NOTE(review): Throwable also wraps Errors; kept as-is because the exceptions
            // declared by RSAPublicKey.init are not visible from this file.
            catch (Throwable e)
            {
                throw new KeyStoreException(e);
            }
        }
    }

    /** @return {@code true} if a software key is held */
    @Override
    protected boolean swKeyPresent()
    {
        return sw != null;
    }

    /** @return {@code true} if unwrap information is still pending */
    @Override
    protected boolean unwrapInfoPresent()
    {
        return unwrapInfo != null;
    }

    /**
     * Prepares this key to be unwrapped later by {@link #unwrap}.
     *
     * @param unwrapInfo the unwrap data to keep
     * @return this key
     */
    RSAPrivateKey initForUnwrap(UnwrapInfo unwrapInfo)
    {
        this.unwrapInfo = unwrapInfo;
        return this;
    }

    /**
     * Prepares this key for import by converting a key specification into a software key.
     *
     * @param keyParams key parameters to associate with the key
     * @param keySpec   specification of the RSA private key
     * @return this key
     * @throws NoSuchAlgorithmException if the last-resort provider does not offer RSA
     * @throws InvalidKeySpecException  if the specification is rejected by the key factory or
     *                                  does not yield a CRT private key
     * @throws NoSuchProviderException  if none of the known providers is installed
     */
    RSAPrivateKey initForImport(KeyParameters keyParams, KeySpec keySpec) throws NoSuchAlgorithmException, InvalidKeySpecException, NoSuchProviderException
    {
        this.keyParams = keyParams;

        KeyFactory kf = rsaKeyFactory();
        java.security.PrivateKey generated = kf.generatePrivate(keySpec);

        // A spec without CRT components produces a plain RSA private key; reject it with the
        // declared checked exception rather than letting the cast fail.
        if (!(generated instanceof RSAPrivateCrtKey))
        {
            throw new InvalidKeySpecException("Key specification does not describe an RSA CRT private key: "
                    + (generated == null ? "null" : generated.getClass().getName()));
        }

        this.sw = (RSAPrivateCrtKey) generated;
        return this;
    }

    /**
     * Prepares this key for import from an existing software key.
     *
     * @param key the software key whose components will be imported
     * @return this key
     */
    RSAPrivateKey initForImport(RSAPrivateCrtKey key)
    {
        this.sw = key;
        return this;
    }

    /**
     * Prepares this key for generation.
     *
     * @param keyParams    key parameters to associate with the key
     * @param genPublicKey companion public key to link and later initialise; may be
     *                     {@code null}, in which case no public key is linked
     * @param genBitSize   requested key size in bits
     * @return this key
     */
    RSAPrivateKey initForGenerate(KeyParameters keyParams, RSAPublicKey genPublicKey, int genBitSize)
    {
        this.keyParams = keyParams;
        this.genPublicKey = genPublicKey;
        bitSize = genBitSize;

        // generate() treats the public key as optional, so do the same here.
        if (genPublicKey != null)
        {
            genPublicKey.prvKey = this;
        }
        return this;
    }

    /** @return the private exponent of the software key, or {@code null} if none is held */
    @Override
    public BigInteger getPrivateExponent()
    {
        return (sw != null) ? sw.getPrivateExponent() : null;
    }

    /** @return the constant algorithm name {@code "RSA"} */
    @Override
    public String getAlgorithm()
    {
        return "RSA";
    }

    /** @return the software key's format, or {@code "PKCS#8"} if no software key is held */
    @Override
    public String getFormat()
    {
        return (sw != null) ? sw.getFormat() : "PKCS#8";
    }

    /** @return the software key's encoding, or {@code null} if no software key is held */
    @Override
    public byte[] getEncoded()
    {
        return (sw != null) ? sw.getEncoded() : null;
    }

    /** @return prime P of the software key, or {@code null} if none is held */
    @Override
    public BigInteger getPrimeP()
    {
        return (sw != null) ? sw.getPrimeP() : null;
    }

    /** @return prime Q of the software key, or {@code null} if none is held */
    @Override
    public BigInteger getPrimeQ()
    {
        return (sw != null) ? sw.getPrimeQ() : null;
    }

    /** @return prime exponent P of the software key, or {@code null} if none is held */
    @Override
    public BigInteger getPrimeExponentP()
    {
        return (sw != null) ? sw.getPrimeExponentP() : null;
    }

    /** @return prime exponent Q of the software key, or {@code null} if none is held */
    @Override
    public BigInteger getPrimeExponentQ()
    {
        return (sw != null) ? sw.getPrimeExponentQ() : null;
    }

    /** @return CRT coefficient of the software key, or {@code null} if none is held */
    @Override
    public BigInteger getCrtCoefficient()
    {
        return (sw != null) ? sw.getCrtCoefficient() : null;
    }

    /**
     * Returns the modulus, preferring the software key over the PKCS#11 key.
     *
     * @return the modulus, or {@code null} if no key is held or the PKCS#11 query fails
     */
    @Override
    public BigInteger getModulus()
    {
        if (sw != null)
        {
            return sw.getModulus();
        }
        if (pkcs11Key == null)
        {
            return null;
        }

        try
        {
            return pkcs11Key.getN();
        }
        catch (CKException e)
        {
            // The interface method cannot throw a checked exception; report "unknown".
            return null;
        }
    }

    /**
     * Returns the public exponent, preferring the software key over the PKCS#11 key.
     *
     * @return the public exponent, or {@code null} if no key is held or the PKCS#11 query fails
     */
    @Override
    public BigInteger getPublicExponent()
    {
        if (sw != null)
        {
            return sw.getPublicExponent();
        }
        if (pkcs11Key == null)
        {
            return null;
        }

        try
        {
            return pkcs11Key.getE();
        }
        catch (CKException e)
        {
            // The interface method cannot throw a checked exception; report "unknown".
            return null;
        }
    }

    /**
     * Returns the key size in bits, computing and caching it on first use. The size is taken
     * from the software key's modulus when present, otherwise from the PKCS#11 key.
     *
     * @return the key size in bits, or 0 if it is not known and no key is held
     * @throws KeyStoreException wrapping a {@link CKException} raised by the PKCS#11 layer
     */
    int getBitSize() throws KeyStoreException
    {
        if (bitSize != 0)
        {
            return bitSize;
        }

        if (sw != null)
        {
            bitSize = Utils.bigIntByteSize(sw.getModulus()) * 8;
        }
        else if (pkcs11Key != null)
        {
            try
            {
                bitSize = pkcs11Key.getBitSize();
            }
            catch (CKException e)
            {
                throw new KeyStoreException(e);
            }
        }
        return bitSize;
    }

    /**
     * Locates a software RSA key factory from an explicit list of providers. Providers in
     * {@link #PREFERRED_RSA_PROVIDERS} are skipped silently when unavailable; the failure of
     * the last-resort provider is propagated.
     */
    private static KeyFactory rsaKeyFactory() throws NoSuchAlgorithmException, NoSuchProviderException
    {
        for (String provider : PREFERRED_RSA_PROVIDERS)
        {
            try
            {
                return KeyFactory.getInstance("RSA", provider);
            }
            catch (NoSuchProviderException ignored)
            {
                // Provider not installed in this JRE; try the next candidate.
            }
            catch (NoSuchAlgorithmException ignored)
            {
                // Provider present but without RSA; try the next candidate.
            }
        }
        return KeyFactory.getInstance("RSA", LAST_RESORT_RSA_PROVIDER);
    }

}
