package com.dyadicsec.pkcs11;

import java.io.IOException;
import java.math.BigInteger;
import java.security.AlgorithmParameters;
import java.security.spec.ECGenParameterSpec;
import java.security.spec.ECParameterSpec;
import java.security.spec.ECPoint;
import java.util.Arrays;

/**
 * Fixed registry of the elliptic curves supported by this package.
 *
 * <p>Each entry pairs a standard curve name with:
 * <ul>
 *   <li>its field size in bits;</li>
 *   <li>the DER encoding of its OBJECT IDENTIFIER;</li>
 *   <li>the domain parameters resolved from the installed "EC" {@link AlgorithmParameters} provider.</li>
 * </ul>
 *
 * <p>Static {@code find} methods look a curve up by name, domain parameters, bit size or encoded OID.
 * Instance methods DER encode/decode EC points in uncompressed form ({@code 0x04 || X || Y})
 * wrapped in an OCTET STRING.
 *
 * <p>Created by valery.osheter on 20-Apr-16.
 */
public final class ECCurve
{
    private final ECParameterSpec spec;
    private final String name;
    private final int bits;
    private final byte[] oidBin;

    /**
     * Creates a registry entry and resolves its domain parameters.
     *
     * @param name   standard curve name passed to {@link ECGenParameterSpec}
     * @param bits   field size of the curve in bits
     * @param oidBin DER encoding (tag 0x06, length, value) of the curve's OBJECT IDENTIFIER
     * @throws IllegalStateException if the installed providers cannot resolve {@code name}
     */
    private ECCurve(String name, int bits, byte[] oidBin)
    {
        this.name = name;
        this.bits = bits;
        this.oidBin = oidBin;
        this.spec = resolveSpec(name);
    }

    /** Asks the installed "EC" AlgorithmParameters implementation for the named curve's parameters. */
    private static ECParameterSpec resolveSpec(String name)
    {
        try
        {
            AlgorithmParameters parameters = AlgorithmParameters.getInstance("EC");
            parameters.init(new ECGenParameterSpec(name));
            return parameters.getParameterSpec(ECParameterSpec.class);
        }
        catch (java.security.GeneralSecurityException e)
        {
            // Thrown during static initialisation; name the curve so the failure is diagnosable.
            throw new IllegalStateException("Cannot resolve EC parameters for curve " + name, e);
        }
    }

    /**
     * Returns the DER-encoded OBJECT IDENTIFIER of this curve.
     *
     * @return a fresh copy, so callers cannot corrupt the shared curve table
     */
    public byte[] getOidBin() { return oidBin.clone(); }

    /** Returns the domain parameters resolved for this curve. */
    public ECParameterSpec getSpec() { return spec; }

    /** Returns the number of bytes needed to hold one point coordinate, i.e. {@code ceil(bits / 8)}. */
    public int getSize() { return (bits+7)/8; }

    /** Returns the field size of this curve in bits. */
    public int getBits() { return bits; }

    private static final ECCurve[] curves = {
            new ECCurve("secp256r1", 256, new byte[] { (byte)0x06, (byte)0x08, (byte)0x2a, (byte)0x86, (byte)0x48, (byte)0xce, (byte)0x3d, (byte)0x03, (byte)0x01, (byte)0x07 }), // 1.2.840.10045.3.1.7
            new ECCurve("secp384r1", 384, new byte[] { (byte)0x06, (byte)0x05, (byte)0x2b, (byte)0x81, (byte)0x04, (byte)0x00, (byte)0x22 }), // 1.3.132.0.34
            new ECCurve("secp521r1", 521, new byte[] { (byte)0x06, (byte)0x05, (byte)0x2b, (byte)0x81, (byte)0x04, (byte)0x00, (byte)0x23 }), // 1.3.132.0.35
            new ECCurve("secp256k1", 256, new byte[] { (byte)0x06, (byte)0x05, (byte)0x2b, (byte)0x81, (byte)0x04, (byte)0x00, (byte)0x0a }), // 1.3.132.0.10
    };

    /**
     * Looks a curve up by its standard name (case-sensitive).
     *
     * @param name curve name such as {@code "secp256r1"}; may be {@code null}
     * @return the matching curve, or {@code null} if none matches
     */
    public static ECCurve find(String name)
    {
        for (ECCurve curve : curves)
        {
            if (curve.name.equals(name)) return curve;
        }
        return null;
    }

    /**
     * Looks a curve up by its domain parameters.
     *
     * <p>A spec matches when it equals the registered spec, or when its curve equation (field, a, b),
     * generator, order and cofactor all coincide. Specs produced by a different provider
     * for the same named curve therefore still match. A spec that merely shares the group order does not.
     *
     * @param spec domain parameters to look up; may be {@code null}
     * @return the matching curve, or {@code null} if none matches or {@code spec} is {@code null}
     */
    public static ECCurve find(ECParameterSpec spec)
    {
        if (spec == null) return null;

        for (ECCurve curve : curves)
        {
            if (curve.spec.equals(spec) || sameDomain(curve.spec, spec)) return curve;
        }
        return null;
    }

    /** Compares two EC domain parameter sets value-by-value (ECParameterSpec has no value equals). */
    private static boolean sameDomain(ECParameterSpec a, ECParameterSpec b)
    {
        return a.getCofactor() == b.getCofactor()
                && a.getOrder().equals(b.getOrder())
                && a.getCurve().equals(b.getCurve())
                && a.getGenerator().equals(b.getGenerator());
    }

    /**
     * Looks a curve up by field size.
     *
     * <p>The first registered match wins. For 256 that is secp256r1, not secp256k1.
     *
     * @param bits field size in bits
     * @return the first curve with that size, or {@code null} if none
     */
    public static ECCurve find(int bits)
    {
        for (ECCurve curve : curves)
        {
            if (curve.getBits()==bits) return curve;
        }
        return null;
    }

    /**
     * Looks a curve up by the DER encoding of its OBJECT IDENTIFIER.
     *
     * @param oidBin encoded OID including tag and length bytes; may be {@code null}
     * @return the matching curve, or {@code null} if none matches
     */
    public static ECCurve find(byte[] oidBin)
    {
        for (ECCurve curve : curves)
        {
            if (Arrays.equals(oidBin, curve.oidBin)) return curve;
        }
        return null;
    }

    /**
     * Decodes a DER OCTET STRING holding an uncompressed point ({@code 0x04 || X || Y}).
     *
     * @param data DER-encoded OCTET STRING
     * @return the decoded affine point (not checked to lie on the curve)
     * @throws IOException if the content is not an uncompressed point of this curve's size, or
     *                     (presumably) if {@code DER.decode} rejects the encoding
     */
    public ECPoint derDecodePoint(byte[] data) throws IOException
    {
        byte[] octets = DER.decode(DER.TAG_OCTETSTRING, data);
        int n = getSize();

        // Check the format byte first so compressed points (0x02/0x03) get an accurate message.
        if (octets.length > 0 && octets[0] != 4) throw new IOException("Only uncompressed point format supported");
        if (octets.length != 1 + n * 2) throw new IOException("Point does not match field size");

        byte[] xb = Arrays.copyOfRange(octets, 1, 1 + n);
        byte[] yb = Arrays.copyOfRange(octets, 1 + n, 1 + n * 2);
        return new ECPoint(new BigInteger(1, xb), new BigInteger(1, yb));
    }

    /**
     * Converts a non-negative integer to a big-endian, left-zero-padded array of exactly {@code size} bytes.
     *
     * @throws IllegalArgumentException if {@code value} is negative or needs more than {@code size} bytes
     */
    static private byte[] bigIntToByteArray(BigInteger value, int size)
    {
        if (value.signum() < 0) throw new IllegalArgumentException("EC coordinate must be non-negative");

        byte[] bytes = value.toByteArray();
        // toByteArray() may prepend a 0x00 sign byte; strip all leading zeros.
        int offset = 0;
        while (offset<bytes.length && bytes[offset]==0) offset++;
        int count = bytes.length-offset;
        if (count > size) throw new IllegalArgumentException("EC coordinate does not fit in " + size + " bytes");

        byte[] out = new byte[size];
        System.arraycopy(bytes, offset, out, size-count, count);
        return out;
    }

    /**
     * Encodes a point as a DER OCTET STRING in uncompressed form ({@code 0x04 || X || Y}).
     * Each coordinate is padded to {@link #getSize()} bytes.
     *
     * @param point affine point to encode
     * @return the DER encoding
     * @throws IllegalArgumentException if {@code point} is the point at infinity or a coordinate
     *                                  does not fit this curve's coordinate size
     */
    public byte[] derEncodePoint(ECPoint point)
    {
        // POINT_INFINITY is a singleton with null affine coordinates; it has no uncompressed encoding.
        if (point == ECPoint.POINT_INFINITY) throw new IllegalArgumentException("Cannot encode the point at infinity");

        int s = getSize();
        return DER.encode(DER.TAG_OCTETSTRING, DER.cat(
                new byte[]{4},
                bigIntToByteArray(point.getAffineX(), s),
                bigIntToByteArray(point.getAffineY(), s)));
    }

}
