package com.dyadicsec.provider;

import com.dyadicsec.pkcs11.CKException;

import javax.crypto.KeyAgreementSpi;
import javax.crypto.SecretKey;
import javax.crypto.ShortBufferException;
import javax.crypto.spec.SecretKeySpec;
import java.security.Key;
import java.security.SecureRandom;
import java.security.NoSuchAlgorithmException;
import java.security.InvalidAlgorithmParameterException;
import java.security.InvalidKeyException;
import java.security.ProviderException;
import java.security.KeyStoreException;
import java.security.spec.AlgorithmParameterSpec;
import java.security.spec.ECPoint;

/**
 * ECDH {@link KeyAgreementSpi} that derives the shared secret with the PKCS#11 key held by an
 * {@link ECPrivateKey} of this provider.
 *
 * <p>Only two-party agreement is supported: {@code init} with an {@link ECPrivateKey}, then a
 * single {@code doPhase(peerPublicKey, true)}, then one of the {@code generateSecret} methods.
 * As specified by {@link javax.crypto.KeyAgreement#generateSecret()}, a successful secret
 * generation resets the agreement. It can then be reused with the same private key by calling
 * {@code doPhase} again; calling {@code init} starts over with a new key.
 *
 * <p>No synchronization is performed; an instance must not be shared between threads without
 * external locking.
 *
 * Created by valery.osheter on 19-Apr-16.
 */
public class ECDHKeyAgreement extends KeyAgreementSpi
{
    /** Private key set by {@code engineInit}; {@code null} until successfully initialized. */
    private ECPrivateKey prvKey = null;
    /** Peer public point set by {@code engineDoPhase}; {@code null} until the phase has run. */
    private ECPoint pub = null;

    /**
     * Initializes the agreement with the local EC private key.
     *
     * @param key          must be an {@link ECPrivateKey} of this provider
     * @param secureRandom ignored
     * @throws InvalidKeyException if {@code key} has the wrong type or {@code save()} fails
     */
    @Override
    protected void engineInit(Key key, SecureRandom secureRandom) throws InvalidKeyException
    {
        if (key instanceof ECPrivateKey == false) throw new InvalidKeyException("CKKey must be instance of CKECPrivateKey");
        ECPrivateKey newKey = (ECPrivateKey) key;

        // NOTE(review): save() presumably makes sure the key is available on the token before
        // it is used for ECDH -- confirm against ECPrivateKey.save().
        try { newKey.save(); }
        catch (KeyStoreException e) { throw new InvalidKeyException(e); }

        // Commit state only after the key has been accepted, and drop any peer point left over
        // from a previous agreement so it can never be reused with the new key.
        prvKey = newKey;
        pub = null;
    }

    /**
     * Initializes the agreement with the local EC private key; algorithm parameters are not
     * supported.
     *
     * @param key                    must be an {@link ECPrivateKey} of this provider
     * @param algorithmParameterSpec must be {@code null}
     * @param secureRandom           ignored
     * @throws InvalidAlgorithmParameterException if {@code algorithmParameterSpec} is non-null
     * @throws InvalidKeyException                if the key is rejected (see the two-arg init)
     */
    @Override
    protected void engineInit(Key key, AlgorithmParameterSpec algorithmParameterSpec, SecureRandom secureRandom) throws InvalidKeyException, InvalidAlgorithmParameterException
    {
        if (algorithmParameterSpec != null) throw new InvalidAlgorithmParameterException("Parameters not supported");
        engineInit(key, secureRandom);
    }

    /**
     * Records the peer's public point. Only a single, final phase is supported.
     *
     * @param key       the peer's {@link java.security.interfaces.ECPublicKey}
     * @param lastPhase must be {@code true}
     * @return always {@code null} (two-party agreement produces no intermediate key)
     * @throws IllegalStateException if not initialized, {@code lastPhase} is false, or the phase
     *                               was already executed since the last init/generateSecret
     * @throws InvalidKeyException   if {@code key} is not an EC public key on the same curve
     */
    @Override
    protected Key engineDoPhase(Key key, boolean lastPhase) throws InvalidKeyException, IllegalStateException
    {
        if (prvKey == null) throw new IllegalStateException("Not initialized");
        if (!lastPhase) throw new IllegalStateException("Only two party agreement supported, lastPhase must be true");
        if (pub != null) throw new IllegalStateException("Phase already executed");
        if ((key instanceof java.security.interfaces.ECPublicKey)==false) throw new InvalidKeyException("CKKey must be a CKPublicKey with algorithm EC");

        java.security.interfaces.ECPublicKey pubKey = (java.security.interfaces.ECPublicKey)key;
        // Accept the key if the curve parameters are equal, or at least the group orders match.
        // The order-only fallback works around an IBM JDK bug in ECParameterSpec equality.
        if (!pubKey.getParams().equals(prvKey.getParams()) &&
            !pubKey.getParams().getOrder().equals(prvKey.getParams().getOrder())) // IBM bug
        {
          throw new InvalidKeyException("EC curve doesn't match");
        }
        pub = pubKey.getW();

        return null;
    }

    /**
     * Computes the ECDH shared secret and resets the agreement for reuse with the same key.
     *
     * @return the raw shared secret returned by the PKCS#11 ECDH operation
     * @throws IllegalStateException if {@code init} or {@code doPhase} has not been called
     * @throws ProviderException     if the underlying PKCS#11 operation fails
     */
    @Override
    protected byte[] engineGenerateSecret() throws IllegalStateException
    {
        if ((prvKey == null) || (pub == null)) throw new IllegalStateException("Not initialized correctly");
        byte[] secret;
        try { secret = prvKey.pkcs11Key.ecdh(pub); }
        catch (CKException e) { throw new ProviderException(e); }
        // Per the KeyAgreement contract, generating the secret resets the agreement: the same
        // private key is kept, but a new doPhase is required before the next secret.
        pub = null;
        return secret;
    }

    /**
     * Computes the shared secret into {@code out} starting at {@code outOffset}.
     *
     * @return the number of bytes written
     * @throws IllegalStateException if {@code init} or {@code doPhase} has not been called
     * @throws ShortBufferException  if fewer than the curve's secret size bytes are available
     */
    @Override
    protected int engineGenerateSecret(byte[] out, int outOffset) throws IllegalStateException, ShortBufferException
    {
        // Check state first so an uninitialized agreement reports IllegalStateException rather
        // than a NullPointerException from prvKey.curve.
        if ((prvKey == null) || (pub == null)) throw new IllegalStateException("Not initialized correctly");
        int secretLen = prvKey.curve.getSize();
        // Subtraction form avoids int overflow of outOffset + secretLen.
        int available = out.length - outOffset;
        if (secretLen > available) throw new ShortBufferException("Need " + secretLen + " bytes, only " + available + " available");
        byte[] secret = engineGenerateSecret();
        System.arraycopy(secret, 0, out, outOffset, secret.length);
        return secret.length;
    }

    /**
     * Computes the shared secret and wraps it as a {@link SecretKey}.
     *
     * @param algorithm only {@code "TlsPremasterSecret"} is supported
     * @return the shared secret as a {@code TlsPremasterSecret} key
     * @throws NoSuchAlgorithmException if {@code algorithm} is null or unsupported
     * @throws IllegalStateException    if {@code init} or {@code doPhase} has not been called
     */
    @Override
    protected SecretKey engineGenerateSecret(String algorithm) throws IllegalStateException, NoSuchAlgorithmException, InvalidKeyException
    {
        if (algorithm == null) throw new NoSuchAlgorithmException("Algorithm must not be null");
        if (!(algorithm.equals("TlsPremasterSecret"))) throw new NoSuchAlgorithmException("Only supported for algorithm TlsPremasterSecret");
        return new SecretKeySpec(engineGenerateSecret(), "TlsPremasterSecret");
    }
}
