package com.unbound.provider;

import com.unbound.common.crypto.SystemProvider;
import com.unbound.provider.kmip.KMIP;

import javax.crypto.*;
import javax.crypto.spec.OAEPParameterSpec;
import javax.crypto.spec.PSource;
import javax.crypto.spec.SecretKeySpec;
import java.io.IOException;
import java.security.*;
import java.security.interfaces.RSAPublicKey;
import java.security.spec.AlgorithmParameterSpec;
import java.security.spec.InvalidParameterSpecException;
import java.security.spec.MGF1ParameterSpec;
import java.util.Arrays;
import java.util.Locale;


/**
 * JCE {@link CipherSpi} for RSA with NoPadding, PKCS1Padding and OAEP padding
 * ({@code OAEPPadding} / {@code OAEPWith<hash>AndMGF1Padding}).
 * <p>
 * Encryption and wrapping take an {@link RSAPublicKey} and are delegated to the system
 * provider's RSA cipher. Decryption and unwrapping take a {@code UBRSAPrivateKey} and are
 * performed by its {@code decrypt} method using KMIP padding / hash identifiers.
 * <p>
 * Input is accumulated in an internal buffer of at most 512 bytes and processed in one shot.
 * The buffer is wiped and emptied after every final operation, so the instance can be reused
 * without re-initialisation.
 */
public final class RSACipher extends CipherSpi
{
  private final static byte[] B0 = new byte[0];

  /** MGF1 hash size used when OAEP is selected by padding name only (MGF1 with SHA-1). */
  private final static int DEFAULT_OAEP_MGF_BIT_SIZE = 160;

  private UBRSAPrivateKey prvKey = null;
  private RSAPublicKey pubKey = null;
  private SecureRandom secureRandom = null;
  private OAEPParameterSpec oaepSpec = null;
  private byte[] buffer = new byte[512]; // max rsa size
  private int bufferSize = 0;
  // Number of bytes passed to update(); a value above buffer.length marks "input too long".
  private int bufferOffset = 0;
  private int opmode = 0;
  private boolean isOaep = false;
  private boolean isRaw = false;
  private int oaepHashBitSize = 0;
  private int oaepMgfBitSize = 0;
  private byte[] oaepSource = null;
  // OAEP hash size implied by the padding name; restored when init() gets no parameters.
  private int defaultOaepHashBitSize = 0;

  /** Maps a hash bit size to the matching MGF1 parameter spec. */
  private static MGF1ParameterSpec mgfBitSizeToSpec(int bitSize) throws InvalidAlgorithmParameterException
  {
    switch (bitSize)
    {
      case 160: return MGF1ParameterSpec.SHA1;
      case 256: return MGF1ParameterSpec.SHA256;
      case 384: return MGF1ParameterSpec.SHA384;
      case 512: return MGF1ParameterSpec.SHA512;
    }
    throw new InvalidAlgorithmParameterException("Unsupported OAEP MGF hash algorithm");
  }

  /** Maps a hash bit size to its standard JCA digest name. */
  private static String hashBitSizeToName(int bitSize) throws InvalidAlgorithmParameterException
  {
    switch (bitSize)
    {
      case 160: return "SHA-1";
      case 256: return "SHA-256";
      case 384: return "SHA-384";
      case 512: return "SHA-512";
    }
    throw new InvalidAlgorithmParameterException("Unsupported OAEP hash algorithm");
  }

  /** Maps a hash bit size to the KMIP hashing-algorithm constant, or 0 if unknown. */
  private static int hashBitSizeToKmipHashAlg(int bitSize)
  {
    switch (bitSize)
    {
      case 160: return KMIP.HashingAlgorithm.SHA_1;
      case 256: return KMIP.HashingAlgorithm.SHA_256;
      case 384: return KMIP.HashingAlgorithm.SHA_384;
      case 512: return KMIP.HashingAlgorithm.SHA_512;
    }
    return 0;
  }

  /** Builds the padding part of a JCA transformation string for the system provider. */
  private static String paddingTypeToName(boolean isRaw, boolean isOaep, int oaepHashBitSize) throws InvalidAlgorithmParameterException
  {
    if (isRaw) return "NOPadding";
    if (!isOaep) return "PKCS1Padding";
    return "OAEPWith" + hashBitSizeToName(oaepHashBitSize) + "AndMGF1Padding";
  }

  /**
   * Maps a digest name (case-insensitive, with or without the dash) to its bit size.
   *
   * @throws InvalidAlgorithmParameterException if the digest is not supported
   */
  private static int hashNameToBitSize(String hashName) throws InvalidAlgorithmParameterException
  {
    // Locale.ROOT: locale-sensitive casing (e.g. Turkish) must not affect algorithm names.
    hashName = hashName.toUpperCase(Locale.ROOT);
    if (hashName.equals("SHA1") || hashName.equals("SHA-1")) return 160;
    if (hashName.equals("SHA256") || hashName.equals("SHA-256")) return 256;
    if (hashName.equals("SHA384") || hashName.equals("SHA-384")) return 384;
    if (hashName.equals("SHA512") || hashName.equals("SHA-512")) return 512;
    throw new InvalidAlgorithmParameterException("OAEP hash algorithm not supported: " + hashName);
  }

  /**
   * Extracts the OAEP hash size from a padding name: {@code OAEPPadding} means SHA-1,
   * {@code OAEPWith<hash>AndMGF1Padding} means {@code <hash>}.
   *
   * @throws NoSuchPaddingException if the name is not an OAEP padding with a supported hash
   */
  private static int oaepPaddingToHashBitSize(String padding) throws NoSuchPaddingException
  {
    padding = padding.toUpperCase(Locale.ROOT);
    if (padding.equals("OAEPPADDING")) return 160; // SHA1
    if (padding.startsWith("OAEPWITH") && padding.endsWith("ANDMGF1PADDING"))
    {
      // Strip the "OAEPWITH" (8 chars) prefix and "ANDMGF1PADDING" (14 chars) suffix.
      String hashName = padding.substring(8, padding.length() - 14);
      try { return hashNameToBitSize(hashName); }
      catch (InvalidAlgorithmParameterException e) { throw new NoSuchPaddingException("padding not supported: " + padding); }
    }
    throw new NoSuchPaddingException("padding not supported: " + padding);
  }

  /**
   * Returns the effective OAEP parameters: those given to init(), or defaults built from the
   * padding name (lazily cached). Returns null when OAEP is not in use.
   */
  private AlgorithmParameterSpec getParameterSpec() throws InvalidAlgorithmParameterException
  {
    if (oaepSpec == null)
    {
      if (!isOaep) return null;
      String oaepHashName = hashBitSizeToName(oaepHashBitSize);
      MGF1ParameterSpec mgfSpec = mgfBitSizeToSpec(oaepMgfBitSize);
      oaepSpec = new OAEPParameterSpec(oaepHashName, "MGF1", mgfSpec, PSource.PSpecified.DEFAULT);
    }
    return oaepSpec;
  }

  /** Wipes any buffered input and empties the buffer. */
  private void resetBuffer()
  {
    Arrays.fill(buffer, 0, Math.min(bufferOffset, buffer.length), (byte) 0);
    bufferOffset = 0;
  }

  /**
   * Common initialisation for all engineInit overloads.
   *
   * @param opmode    one of Cipher.ENCRYPT_MODE / WRAP_MODE (public key) or DECRYPT_MODE / UNWRAP_MODE (UB private key)
   * @param key       RSAPublicKey for encrypt/wrap, UBRSAPrivateKey for decrypt/unwrap
   * @param paramSpec OAEPParameterSpec, or null to use the defaults implied by the padding name
   * @throws InvalidKeyException                if the key type does not match the mode, or the mode is unknown
   * @throws InvalidAlgorithmParameterException if parameters are given for non-OAEP padding or are unsupported
   */
  private void init(int opmode, Key key, AlgorithmParameterSpec paramSpec) throws InvalidKeyException, InvalidAlgorithmParameterException
  {
    this.opmode = opmode;
    resetBuffer();

    switch (opmode)
    {
      case Cipher.ENCRYPT_MODE:
      case Cipher.WRAP_MODE:
        prvKey = null;
        if (key instanceof RSAPublicKey) pubKey = (RSAPublicKey) key;
        else throw new InvalidKeyException("Invalid key type");
        bufferSize = (pubKey.getModulus().bitLength() + 7) / 8;
        break;

      case Cipher.DECRYPT_MODE:
      case Cipher.UNWRAP_MODE:
        pubKey = null;
        if (key instanceof UBRSAPrivateKey) prvKey = (UBRSAPrivateKey) key;
        else throw new InvalidKeyException("Invalid key type");
        bufferSize = (prvKey.getBitSize() + 7) / 8;
        break;

      default:
        throw new InvalidKeyException("Unknown mode: " + opmode);
    }

    if (paramSpec == null)
    {
      // Re-initialising without parameters must not keep OAEP parameters from an earlier init.
      oaepSpec = null;
      oaepSource = null;
      if (isOaep)
      {
        oaepHashBitSize = defaultOaepHashBitSize;
        oaepMgfBitSize = DEFAULT_OAEP_MGF_BIT_SIZE;
      }
      return;
    }

    if (!isOaep) throw new InvalidAlgorithmParameterException("Wrong padding parameter");
    if (!(paramSpec instanceof OAEPParameterSpec)) throw new InvalidAlgorithmParameterException("Wrong Parameters for OAEP Padding");
    OAEPParameterSpec spec = (OAEPParameterSpec) paramSpec;

    int hashBits = hashNameToBitSize(spec.getDigestAlgorithm());

    String mgfAlgName = spec.getMGFAlgorithm();
    if (!"MGF1".equalsIgnoreCase(mgfAlgName)) throw new InvalidAlgorithmParameterException("Unsupported MGF algorithm: " + mgfAlgName);
    AlgorithmParameterSpec mgfParam = spec.getMGFParameters();
    if (!(mgfParam instanceof MGF1ParameterSpec)) throw new InvalidAlgorithmParameterException("Unsupported MGF hash");
    int mgfBits = hashNameToBitSize(((MGF1ParameterSpec) mgfParam).getDigestAlgorithm());

    PSource s = spec.getPSource();
    if (!(s instanceof PSource.PSpecified)) throw new InvalidAlgorithmParameterException("Unsupported pSource " + s.getAlgorithm() + "; PSpecified only");

    // Commit only after every parameter has been validated.
    oaepSpec = spec;
    oaepHashBitSize = hashBits;
    oaepMgfBitSize = mgfBits;
    oaepSource = ((PSource.PSpecified) s).getValue();
  }

  /**
   * Appends input to the buffer. Data that does not fit is not copied, but the offset is moved
   * past buffer.length so that the final operation can report the input as too long.
   */
  private void update(byte[] in, int inOffset, int inLen)
  {
    if ((inLen == 0) || (in == null)) return;
    int room = buffer.length - bufferOffset; // subtraction form avoids int overflow on huge inLen
    if (inLen <= room)
    {
      System.arraycopy(in, inOffset, buffer, bufferOffset, inLen);
      bufferOffset += inLen;
    }
    else bufferOffset = buffer.length + 1; // saturating "too long" marker
  }

  /**
   * Runs the one-shot RSA operation over the buffered input (or wraps {@code wrappedKey} in
   * WRAP_MODE). The buffer is always wiped and emptied afterwards.
   *
   * @param wrappedKey key to wrap in WRAP_MODE; ignored otherwise
   * @return ciphertext, plaintext, or wrapped key bytes
   */
  private byte[] doFinal(Key wrappedKey) throws BadPaddingException, IllegalBlockSizeException, InvalidAlgorithmParameterException, InvalidKeyException, IOException
  {
    try
    {
      if (pubKey != null)
      {
        Cipher cipher = SystemProvider.Cipher.getInstance("RSA/ECB/" + paddingTypeToName(isRaw, isOaep, oaepHashBitSize));
        cipher.init(opmode, pubKey, getParameterSpec(), secureRandom);
        if (opmode == Cipher.WRAP_MODE) return cipher.wrap(wrappedKey);
        return cipher.doFinal(buffer, 0, bufferOffset);
      }

      byte[] in = Arrays.copyOfRange(buffer, 0, bufferOffset);
      // NOTE(review): NoPadding (isRaw) decryption is sent to the key as PKCS#1 v1.5 padding;
      // confirm whether the backend supports raw RSA and map it explicitly.
      int kmipPadding = isOaep ? KMIP.PaddingMethod.OAEP : KMIP.PaddingMethod.PKCS1_V1_5;
      int kmipHashAlg = 0;
      int kmipMgfAlg = 0;
      if (isOaep)
      {
        kmipHashAlg = hashBitSizeToKmipHashAlg(oaepHashBitSize);
        kmipMgfAlg = hashBitSizeToKmipHashAlg(oaepMgfBitSize);
      }

      return prvKey.decrypt(in, kmipPadding, kmipHashAlg, kmipMgfAlg, oaepSource);
    }
    finally
    {
      // JCE contract: after a final operation the cipher returns to its initialised state.
      resetBuffer();
    }
  }

  // ------------------------- interface --------------------------
  @Override
  protected void engineSetMode(String mode) throws NoSuchAlgorithmException
  {
    mode = mode.toUpperCase(Locale.ROOT);
    if (!mode.equals("NONE") && !mode.equals("ECB")) throw new NoSuchAlgorithmException("Mode not supported: " + mode);
  }

  @Override
  protected void engineSetPadding(String padding) throws NoSuchPaddingException
  {
    padding = padding.toUpperCase(Locale.ROOT);
    boolean oaep = padding.equals("OAEPPADDING")
        || (padding.startsWith("OAEPWITH") && padding.endsWith("ANDMGF1PADDING"));

    if (oaep)
    {
      // Resolve the hash first so an unsupported hash leaves the current padding untouched.
      int hashBits = oaepPaddingToHashBitSize(padding);
      isRaw = false;
      isOaep = true;
      defaultOaepHashBitSize = hashBits;
      oaepHashBitSize = hashBits;
      oaepMgfBitSize = DEFAULT_OAEP_MGF_BIT_SIZE;
    }
    else if (padding.equals("NOPADDING")) { isRaw = true; isOaep = false; }
    else if (padding.equals("PKCS1PADDING")) { isRaw = false; isOaep = false; }
    else throw new NoSuchPaddingException("Unsupported padding: " + padding);
  }

  @Override
  protected int engineGetBlockSize()
  {
    return 0;
  }

  @Override
  protected int engineGetOutputSize(int inputLen)
  {
    return bufferSize;
  }

  @Override
  protected byte[] engineGetIV()
  {
    return null;
  }

  @Override
  protected AlgorithmParameters engineGetParameters()
  {
    try
    {
      AlgorithmParameterSpec spec = getParameterSpec();
      if (spec == null) return null;
      AlgorithmParameters params = AlgorithmParameters.getInstance("OAEP");
      params.init(spec);
      return params;
    }
    catch (GeneralSecurityException e) { throw new ProviderException("Invalid algorithm parameters not supported", e); }
  }

  @Override
  protected void engineInit(int opmode, Key key, SecureRandom secureRandom) throws InvalidKeyException
  {
    this.secureRandom = secureRandom;
    try { init(opmode, key, null); }
    catch (InvalidAlgorithmParameterException e) { throw new InvalidKeyException("Wrong parameters", e); }
  }

  @Override
  protected void engineInit(int opmode, Key key, AlgorithmParameterSpec algorithmParameterSpec, SecureRandom secureRandom) throws InvalidKeyException, InvalidAlgorithmParameterException
  {
    this.secureRandom = secureRandom;
    init(opmode, key, algorithmParameterSpec);
  }

  @Override
  protected void engineInit(int opmode, Key key, AlgorithmParameters algorithmParameters, SecureRandom secureRandom) throws InvalidKeyException, InvalidAlgorithmParameterException
  {
    this.secureRandom = secureRandom;
    OAEPParameterSpec spec = null;

    if (algorithmParameters != null)
    {
      // A parameter-conversion failure is a parameter problem, not a key problem.
      try { spec = algorithmParameters.getParameterSpec(OAEPParameterSpec.class); }
      catch (InvalidParameterSpecException e) { throw new InvalidAlgorithmParameterException("Wrong parameters", e); }
    }

    init(opmode, key, spec);
  }

  @Override
  protected byte[] engineUpdate(byte[] in, int inOffset, int inLen)
  {
    update(in, inOffset, inLen);
    return B0;
  }

  @Override
  protected int engineUpdate(byte[] in, int inOffset, int inLen, byte[] out, int outOffset) throws ShortBufferException
  {
    update(in, inOffset, inLen);
    return 0;
  }

  @Override
  protected byte[] engineDoFinal(byte[] in, int inOffset, int inLen) throws IllegalBlockSizeException, BadPaddingException
  {
    update(in, inOffset, inLen);
    if (bufferOffset > buffer.length)
    {
      resetBuffer(); // let the instance be reused after rejecting oversized input
      throw new IllegalBlockSizeException("Input must be under " + buffer.length + " bytes");
    }
    try { return doFinal(null); }
    catch (BadPaddingException e) { throw e; }
    catch (IllegalBlockSizeException e) { throw e; }
    catch (Exception e)
    {
      BadPaddingException bpe = new BadPaddingException("engineDoFinal failed");
      bpe.initCause(e);
      throw bpe;
    }
  }

  @Override
  protected int engineDoFinal(byte[] in, int inOffset, int inLen, byte[] out, int outOffset) throws ShortBufferException, IllegalBlockSizeException, BadPaddingException
  {
    byte[] b = engineDoFinal(in, inOffset, inLen);
    if (b.length > out.length - outOffset) throw new ShortBufferException("Output buffer is too small");
    System.arraycopy(b, 0, out, outOffset, b.length);
    return b.length;
  }

  @Override
  protected byte[] engineWrap(Key key) throws InvalidKeyException, IllegalBlockSizeException
  {
    byte[] encoded = key.getEncoded();
    if ((encoded == null) || (encoded.length == 0)) throw new InvalidKeyException("Could not obtain encoded key");
    if (encoded.length > buffer.length) throw new InvalidKeyException("CKKey is too long for wrapping");

    try { return doFinal(key); }
    catch (InvalidKeyException e) { throw e; }
    catch (IllegalBlockSizeException e) { throw e; }
    catch (Exception e) { throw new InvalidKeyException("Wrapping failed", e); }
  }

  @Override
  protected Key engineUnwrap(byte[] wrappedKey, String algorithm, int wrappedKeyType) throws InvalidKeyException, NoSuchAlgorithmException
  {
    if (wrappedKeyType != Cipher.SECRET_KEY) throw new UnsupportedOperationException("wrappedKeyType == " + wrappedKeyType);
    if (wrappedKey.length > buffer.length) throw new InvalidKeyException("Key is too long for unwrapping");

    // The wrapped bytes must be buffered: doFinal decrypts the buffer contents.
    resetBuffer();
    update(wrappedKey, 0, wrappedKey.length);

    byte[] key;
    try
    {
      key = doFinal(null);
    }
    catch (Exception e)
    {
      throw new ProviderException(e);
    }

    return new SecretKeySpec(key, algorithm);
  }
}
