package cn.caict.encryption.utils.sm2;

import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.UnsupportedEncodingException;
import java.math.BigInteger;
import java.nio.charset.StandardCharsets;
import java.security.MessageDigest;
import java.security.PublicKey;
import java.security.SecureRandom;
import java.util.Arrays;

import cn.caict.encryption.exception.EncException;
import cn.caict.encryption.key.PrivateKey;
import cn.caict.encryption.model.KeyMember;
import cn.caict.encryption.model.KeyType;
import cn.caict.encryption.utils.hash.SM3Digest;
import cn.caict.encryption.utils.hex.HexFormat;
import org.bouncycastle.crypto.DerivationFunction;
import org.bouncycastle.crypto.digests.SHA256Digest;
import org.bouncycastle.crypto.digests.ShortenedDigest;
import org.bouncycastle.crypto.generators.KDF1BytesGenerator;
import org.bouncycastle.crypto.params.ECDomainParameters;
import org.bouncycastle.crypto.params.ISO18033KDFParameters;
import org.bouncycastle.math.ec.ECCurve;
import org.bouncycastle.math.ec.ECPoint;

/**
 * SM2公钥加密算法实现 包括 -密钥对生成 -签名,验签
 *
 */
/**
 * SM2 elliptic-curve public-key algorithm on the 256-bit prime curve defined by the constants
 * below: key-pair generation, signing / verification, and encryption / decryption.
 *
 * <p>All state is static; the public no-arg constructor exists only so that the instance method
 * {@link #generateKeyPair()} can be called.
 *
 * <p>NOTE: {@link #encrypt(byte[], ECPoint)} and {@link #decrypt(byte[], BigInteger)} derive the
 * keystream with BouncyCastle {@code KDF1BytesGenerator} over SHA-256 truncated to 20 bytes, and
 * compute the C3 tag with the same truncated SHA-256. The SM2 encryption standard uses SM3 for
 * both, so ciphertexts produced here are only readable by this class. The scheme is deliberately
 * left unchanged so that existing ciphertexts remain decryptable.
 */
public class SM2 {
    /** Order n of the base point G. */
    private static final BigInteger n = new BigInteger(
            "FFFFFFFE" + "FFFFFFFF" + "FFFFFFFF" + "FFFFFFFF" + "7203DF6B" + "21C6052B" + "53BBF409" + "39D54123", 16);
    /** Prime p of the underlying field Fp. */
    private static final BigInteger p = new BigInteger(
            "FFFFFFFE" + "FFFFFFFF" + "FFFFFFFF" + "FFFFFFFF" + "FFFFFFFF" + "00000000" + "FFFFFFFF" + "FFFFFFFF", 16);
    /** Curve coefficient a of y^2 = x^3 + a*x + b (mod p). */
    private static final BigInteger a = new BigInteger(
            "FFFFFFFE" + "FFFFFFFF" + "FFFFFFFF" + "FFFFFFFF" + "FFFFFFFF" + "00000000" + "FFFFFFFF" + "FFFFFFFC", 16);
    /** Curve coefficient b of y^2 = x^3 + a*x + b (mod p). */
    private static final BigInteger b = new BigInteger(
            "28E9FA9E" + "9D9F5E34" + "4D5A9E4B" + "CF6509A7" + "F39789F5" + "15AB8F92" + "DDBCBD41" + "4D940E93", 16);
    /** Affine x coordinate of the base point G. */
    private static final BigInteger gx = new BigInteger(
            "32C4AE2C" + "1F198119" + "5F990446" + "6A39C994" + "8FE30BBF" + "F2660BE1" + "715A4589" + "334C74C7", 16);
    /** Affine y coordinate of the base point G. */
    private static final BigInteger gy = new BigInteger(
            "BC3736A2" + "F4F6779C" + "59BDCEE3" + "6B692153" + "D0A9877C" + "C62A4740" + "02DF32E5" + "2139F0A0", 16);

    /** Width in bytes of a field element or scalar in every fixed-length encoding used here. */
    private static final int COORDINATE_LENGTH = 32;
    /** Length of an uncompressed point encoding (0x04 || X || Y), i.e. of ciphertext part C1. */
    private static final int C1_LENGTH = 1 + 2 * COORDINATE_LENGTH;
    /** Length of the C3 tag produced by {@link #calculateHash}. */
    private static final int C3_LENGTH = 20;
    /** 2^256, used to reinterpret a negative 256-bit two's-complement value as unsigned. */
    private static final BigInteger TWO_POW_256 = BigInteger.ONE.shiftLeft(8 * COORDINATE_LENGTH);
    /** User identifier used by {@link #verify(byte[], byte[], KeyMember)}. */
    private static final String DEFAULT_USER_ID = "1234567812345678";

    // NOTE(review): a single SM3Digest instance is shared by every caller and thread. Confirm that
    // SM3Digest.Hash keeps no per-call mutable state; otherwise concurrent signing is unsafe.
    private static final SM3Digest SM3 = new SM3Digest();

    private static final SecureRandom SECURE_RANDOM = new SecureRandom();

    /** The SM2 curve y^2 = x^3 + a*x + b over Fp. */
    public static ECCurve.Fp curve = new ECCurve.Fp(p, a, b);
    /** Base point G (order n) of {@link #curve}. */
    public static ECPoint G = curve.createPoint(gx, gy);

    /**
     * Creates an instance. All state is static; an instance is only needed to call
     * {@link #generateKeyPair()}.
     */
    public SM2() {
    }

    /**
     * Draws a uniformly random integer in {@code [1, max - 1]} from a {@link SecureRandom}.
     *
     * <p>Uses rejection sampling over {@code max.bitLength()} bits. Every retry uses the full bit
     * length, so the result never loses entropy, and zero is rejected.
     *
     * @param max exclusive upper bound; must be greater than 1 or this never terminates
     * @return a random value r with {@code 1 <= r < max}
     */
    private static BigInteger random(BigInteger max) {
        int bits = max.bitLength();
        BigInteger candidate;
        do {
            candidate = new BigInteger(bits, SECURE_RANDOM);
        } while (candidate.signum() == 0 || candidate.compareTo(max) >= 0);
        return candidate;
    }

    /**
     * Checks whether {@code param} lies in the half-open interval {@code [min, max)}.
     *
     * @param param value to test
     * @param min inclusive lower bound
     * @param max exclusive upper bound
     * @return true if {@code min <= param < max}
     */
    private static boolean between(BigInteger param, BigInteger min, BigInteger max) {
        return param.compareTo(min) >= 0 && param.compareTo(max) < 0;
    }

    /**
     * Validates a public key point.
     *
     * <p>The point must not be infinity, and its coordinates must lie in [0, p). It must satisfy
     * the curve equation, and [n]Q must equal infinity.
     *
     * @param publicKey candidate point (may be non-normalized)
     * @return true if the point is a valid SM2 public key
     */
    private static boolean checkPublicKey(ECPoint publicKey) {
        if (publicKey == null || publicKey.isInfinity()) {
            return false;
        }
        ECPoint affine = publicKey.normalize();
        BigInteger x = affine.getXCoord().toBigInteger();
        BigInteger y = affine.getYCoord().toBigInteger();
        if (!between(x, BigInteger.ZERO, p) || !between(y, BigInteger.ZERO, p)) {
            return false;
        }
        BigInteger rhs = x.pow(3).add(a.multiply(x)).add(b).mod(p);
        BigInteger lhs = y.pow(2).mod(p);
        return lhs.equals(rhs) && publicKey.multiply(n).isInfinity();
    }

    /**
     * Generates a new key pair with private key d in [1, n-2] and public key [d]G.
     *
     * @return the key pair, or {@code null} if the derived public key fails validation
     */
    public SM2KeyPair generateKeyPair() {
        BigInteger d = random(n.subtract(BigInteger.ONE));
        SM2KeyPair keyPair = new SM2KeyPair(G.multiply(d).normalize(), d);
        return checkPublicKey(keyPair.getPublicKey()) ? keyPair : null;
    }

    /**
     * Concatenates byte arrays in order.
     *
     * @param params arrays to concatenate; none may be null
     * @return a new array holding all inputs back to back
     */
    private static byte[] join(byte[]... params) {
        ByteArrayOutputStream out = new ByteArrayOutputStream();
        for (byte[] param : params) {
            // The (byte[], int, int) overload does not declare IOException.
            out.write(param, 0, param.length);
        }
        return out.toByteArray();
    }

    /**
     * Computes the SM3 digest of the concatenation of {@code params}.
     *
     * @param params byte arrays hashed in order
     * @return the digest bytes
     * @throws IllegalStateException if the underlying SM3 implementation fails
     */
    private static byte[] sm3hash(byte[]... params) {
        try {
            return SM3.Hash(join(params));
        } catch (Exception e) {
            // Previously swallowed (returning null), which only surfaced later as an NPE.
            throw new IllegalStateException("SM3 hash failed", e);
        }
    }

    /**
     * Encodes a non-negative integer below 2^256 as exactly 32 big-endian bytes.
     *
     * <p>{@link BigInteger#toByteArray()} yields 33 bytes when the top bit is set (extra sign
     * byte). It yields fewer than 32 bytes when the value has leading zero bytes. Both cases are
     * normalized here.
     *
     * @param value non-negative value below 2^256
     * @return a 32-byte big-endian encoding
     */
    private static byte[] toFixedLength(BigInteger value) {
        byte[] raw = value.toByteArray();
        if (raw.length == COORDINATE_LENGTH) {
            return raw;
        }
        byte[] out = new byte[COORDINATE_LENGTH];
        int copy = Math.min(raw.length, COORDINATE_LENGTH);
        // Take the low-order bytes and right-align them; this drops a sign byte or left-pads.
        System.arraycopy(raw, raw.length - copy, out, COORDINATE_LENGTH - copy, copy);
        return out;
    }

    /**
     * Computes the user hash Z_A = SM3(ENTL_A || ID_A || a || b || x_G || y_G || x_A || y_A).
     *
     * @param IDA signer identifier, encoded as UTF-8
     * @param aPublicKey signer public key
     * @return the 32-byte Z_A value
     * @throws IllegalArgumentException if the identifier is longer than 65535 bits
     */
    private static byte[] ZA(String IDA, ECPoint aPublicKey) {
        byte[] idaBytes = IDA.getBytes(StandardCharsets.UTF_8);
        int entlenA = idaBytes.length * 8;
        if (entlenA > 0xFFFF) {
            throw new IllegalArgumentException("IDA too long: " + idaBytes.length + " bytes");
        }
        // ENTL_A is the bit length of ID_A as a 2-byte big-endian integer.
        byte[] ENTLA = new byte[] { (byte) (entlenA >>> 8), (byte) entlenA };
        ECPoint pub = aPublicKey.normalize();
        return sm3hash(ENTLA, idaBytes, toFixedLength(a), toFixedLength(b),
                toFixedLength(gx), toFixedLength(gy),
                toFixedLength(pub.getXCoord().toBigInteger()),
                toFixedLength(pub.getYCoord().toBigInteger()));
    }

    /**
     * Signs a message.
     *
     * @param M message; each char is mapped to one byte via ISO-8859-1
     * @param IDA signer identifier
     * @param keyPair signer key pair
     * @return the signature (r, s)
     * @throws UnsupportedEncodingException never thrown; retained for source compatibility
     */
    public static Signature sign(String M, String IDA, SM2KeyPair keyPair) throws UnsupportedEncodingException {
        BigInteger d = keyPair.getPrivateKey();
        byte[] digestInput = join(ZA(IDA, keyPair.getPublicKey()), M.getBytes(StandardCharsets.ISO_8859_1));
        BigInteger e = new BigInteger(1, sm3hash(digestInput));
        // (1 + d)^-1 mod n does not depend on k, so compute it once.
        BigInteger inverse = d.add(BigInteger.ONE).modInverse(n);
        while (true) {
            BigInteger k = random(n);
            BigInteger x1 = G.multiply(k).normalize().getXCoord().toBigInteger();
            BigInteger r = e.add(x1).mod(n);
            if (r.signum() == 0 || r.add(k).equals(n)) {
                continue;
            }
            BigInteger s = inverse.multiply(k.subtract(r.multiply(d))).mod(n);
            if (s.signum() == 0) {
                // The SM2 signing procedure requires a fresh k when s == 0.
                continue;
            }
            return new Signature(r, s);
        }
    }

    /**
     * Signs a message and returns the 64-byte r || s encoding.
     *
     * @param M message; each char is mapped to one byte via ISO-8859-1
     * @param IDA signer identifier
     * @param keyPair signer key pair
     * @return 64-byte signature (32-byte r followed by 32-byte s)
     * @throws UnsupportedEncodingException never thrown; retained for source compatibility
     */
    public static byte[] signWithBytes(String M, String IDA, SM2KeyPair keyPair) throws UnsupportedEncodingException {
        return sign(M, IDA, keyPair).toByte();
    }

    /**
     * Verifies a signature.
     *
     * @param M message; each char is mapped to one byte via ISO-8859-1
     * @param signature signature (r, s)
     * @param IDA signer identifier
     * @param aPublicKey signer public key
     * @return true if the signature is valid
     * @throws UnsupportedEncodingException never thrown; retained for source compatibility
     */
    public static boolean verify(String M, Signature signature, String IDA, ECPoint aPublicKey) throws UnsupportedEncodingException {
        if (signature == null) {
            return false;
        }
        if (!between(signature.r, BigInteger.ONE, n) || !between(signature.s, BigInteger.ONE, n)) {
            return false;
        }

        byte[] digestInput = join(ZA(IDA, aPublicKey), M.getBytes(StandardCharsets.ISO_8859_1));
        BigInteger e = new BigInteger(1, sm3hash(digestInput));
        BigInteger t = signature.r.add(signature.s).mod(n);
        if (t.signum() == 0) {
            return false;
        }

        ECPoint p1 = G.multiply(signature.s).normalize();
        ECPoint p2 = aPublicKey.multiply(t).normalize();
        ECPoint sum = p1.add(p2).normalize();
        if (sum.isInfinity()) {
            // No x coordinate exists, so the signature cannot be valid.
            return false;
        }
        BigInteger R = e.add(sum.getXCoord().toBigInteger()).mod(n);
        return R.equals(signature.r);
    }

    /**
     * Verifies a 64-byte r || s signature.
     *
     * @param M message; each char is mapped to one byte via ISO-8859-1
     * @param signature exactly 64 bytes: 32-byte r followed by 32-byte s; any other length fails
     * @param IDA signer identifier
     * @param aPublicKey signer public key
     * @return true if the signature is valid
     * @throws UnsupportedEncodingException never thrown; retained for source compatibility
     */
    public static boolean verify(String M, byte[] signature, String IDA, ECPoint aPublicKey) throws UnsupportedEncodingException {
        if (signature == null || signature.length != 2 * COORDINATE_LENGTH) {
            return false;
        }
        BigInteger r = new BigInteger(1, Arrays.copyOfRange(signature, 0, COORDINATE_LENGTH));
        BigInteger s = new BigInteger(1, Arrays.copyOfRange(signature, COORDINATE_LENGTH, 2 * COORDINATE_LENGTH));
        return verify(M, new Signature(r, s), IDA, aPublicKey);
    }

    /**
     * Verifies a 64-byte signature over raw message bytes.
     *
     * <p>Uses the public key held by {@code member} and the fixed identifier "1234567812345678".
     *
     * @param msg message bytes
     * @param sign 64-byte r || s signature
     * @param member holder of the raw 65-byte public key (prefix || X || Y)
     * @return true if the signature is valid; false if it is invalid or the raw key is too short
     * @throws UnsupportedEncodingException never thrown; retained for source compatibility
     */
    public static boolean verify(byte[] msg, byte[] sign, KeyMember member) throws UnsupportedEncodingException {
        ECPoint publicKey = decodeRawPublicKey(member.getRawPKey());
        if (publicKey == null) {
            return false;
        }
        // ISO-8859-1 maps every byte to exactly one char, so this round trip is lossless.
        return verify(new String(msg, StandardCharsets.ISO_8859_1), sign, DEFAULT_USER_ID, publicKey);
    }

    /**
     * Decodes a raw uncompressed public key (1 prefix byte, 32-byte X, 32-byte Y).
     *
     * <p>The prefix byte is not inspected. The point is not validated against the curve.
     *
     * @param raw raw key bytes; bytes past the first 65 are ignored
     * @return the point, or {@code null} if {@code raw} is null or shorter than 65 bytes
     */
    private static ECPoint decodeRawPublicKey(byte[] raw) {
        if (raw == null || raw.length < C1_LENGTH) {
            return null;
        }
        BigInteger x = new BigInteger(1, Arrays.copyOfRange(raw, 1, 1 + COORDINATE_LENGTH));
        BigInteger y = new BigInteger(1, Arrays.copyOfRange(raw, 1 + COORDINATE_LENGTH, C1_LENGTH));
        return curve.createPoint(x, y);
    }

    /**
     * Reinterprets a negative value, obtained by parsing 32 bytes as signed two's complement, as
     * the unsigned 256-bit value those bytes represent. Non-negative values are returned unchanged.
     *
     * @param bigInteger value parsed with {@code new BigInteger(byte[])} from at most 32 bytes
     * @return the unsigned interpretation
     */
    public static BigInteger bigIntegerPreHandle(BigInteger bigInteger) {
        return bigInteger.signum() < 0 ? bigInteger.add(TWO_POW_256) : bigInteger;
    }

    /**
     * Encodes a public key as 65 bytes: 0x04 || X (32 bytes) || Y (32 bytes).
     *
     * @param pubKey public key point
     * @return the uncompressed 65-byte encoding
     */
    public static byte[] getRawPubKey(ECPoint pubKey) {
        ECPoint affine = pubKey.normalize();
        return join(new byte[] { (byte) 4 },
                toFixedLength(affine.getAffineXCoord().toBigInteger()),
                toFixedLength(affine.getAffineYCoord().toBigInteger()));
    }

    /**
     * Generates a fresh key pair.
     *
     * @return the key pair, or {@code null} if validation of the generated public key fails
     */
    public static SM2KeyPair getSM2KeyPair() {
        return new SM2().generateKeyPair();
    }

    /**
     * Encodes the private key of a key pair as exactly 32 big-endian bytes.
     *
     * @param sm2KeyPair key pair
     * @return 32-byte raw private key
     */
    public static byte[] getRawSkey(SM2KeyPair sm2KeyPair) {
        return toFixedLength(sm2KeyPair.getPrivateKey());
    }

    /**
     * Encodes the public key of a key pair as 65 bytes: 0x04 || X || Y.
     *
     * @param sm2KeyPair key pair
     * @return 65-byte raw public key
     */
    public static byte[] getRawPubKey(SM2KeyPair sm2KeyPair) {
        return getRawPubKey(sm2KeyPair.getPublicKey());
    }

    /**
     * Prints {@code b} to standard output as upper-case hex followed by a newline.
     *
     * @param b bytes to print
     * @return the same bytes as lower-case hex, two digits per byte
     */
    public static String printHexString(byte[] b) {
        StringBuilder builder = new StringBuilder(b.length * 2);
        for (byte value : b) {
            int unsigned = value & 0xFF;
            System.out.print(String.format("%02X", unsigned));
            // Each byte is appended once (it was previously duplicated for values below 0x10).
            builder.append(String.format("%02x", unsigned));
        }
        System.out.println();
        return builder.toString();
    }

    /**
     * Parses an encoded public key hex string into {@code member}.
     *
     * <p>The expected layout is 0xB0, an algorithm byte (0x65 = ED25519, 0x7A = SM2), 0x66, and
     * then the raw key.
     *
     * @param bPkey hex-encoded public key
     * @param member receives the raw key bytes and the key type
     * @throws EncException if the string is null, not hex, or has an invalid header
     */
    private static void getPublicKey(String bPkey, KeyMember member) throws EncException {
        if (null == bPkey) {
            throw new EncException("public key cannot be null");
        }
        if (!HexFormat.isHexString(bPkey)) {
            throw new EncException("public key (" + bPkey + ") is invalid, please check");
        }

        byte[] buffPKey = HexFormat.hexToByte(bPkey);
        if (buffPKey.length < 3 || buffPKey[0] != (byte) 0xB0) {
            throw new EncException("public key (" + bPkey + ") is invalid, please check");
        }

        // The second header byte selects the algorithm.
        KeyType type;
        if (buffPKey[1] == 0x65) {
            type = KeyType.ED25519;
        } else if (buffPKey[1] == 0x7A) {
            type = KeyType.SM2;
        } else {
            throw new EncException("public key (" + bPkey + ") is invalid, please check");
        }

        if (buffPKey[2] != 0x66) {
            throw new EncException("public key (" + bPkey + ") is invalid, please check");
        }

        member.setRawPKey(Arrays.copyOfRange(buffPKey, 3, buffPKey.length));
        member.setKeyType(type);
    }

    /**
     * Reports whether every byte of {@code buffer} is zero; an empty buffer yields true.
     *
     * @param buffer bytes to inspect
     * @return true if no byte is non-zero
     */
    private static boolean allZero(byte[] buffer) {
        for (byte value : buffer) {
            if (value != 0) {
                return false;
            }
        }
        return true;
    }

    /**
     * Computes the C3 tag: SHA-256 truncated to 20 bytes over x2 || M || y2.
     *
     * <p>NOTE: x2 and y2 are encoded with {@link BigInteger#toByteArray()}, whose length varies.
     * That encoding is not the fixed 32-byte form the standard uses. It is kept so that existing
     * ciphertexts still decrypt.
     *
     * @param x2 x coordinate of the shared point
     * @param M plaintext
     * @param y2 y coordinate of the shared point
     * @return 20-byte tag
     */
    private static byte[] calculateHash(BigInteger x2, byte[] M, BigInteger y2) {
        ShortenedDigest digest = new ShortenedDigest(new SHA256Digest(), C3_LENGTH);
        byte[] buf = x2.toByteArray();
        digest.update(buf, 0, buf.length);
        digest.update(M, 0, M.length);
        buf = y2.toByteArray();
        digest.update(buf, 0, buf.length);
        byte[] out = new byte[C3_LENGTH];
        digest.doFinal(out, 0);
        return out;
    }

    /**
     * Derives {@code length} keystream bytes using KDF1 over SHA-256 truncated to 20 bytes.
     *
     * @param seed KDF input (the uncompressed encoding of the shared point)
     * @param length number of bytes to derive
     * @return derived keystream
     */
    private static byte[] deriveKeystream(byte[] seed, int length) {
        DerivationFunction kdf = new KDF1BytesGenerator(new ShortenedDigest(new SHA256Digest(), C3_LENGTH));
        byte[] t = new byte[length];
        kdf.init(new ISO18033KDFParameters(seed));
        kdf.generateBytes(t, 0, t.length);
        return t;
    }

    /**
     * Encrypts {@code msg} for the holder of the encoded public key {@code publickey}.
     *
     * @param msg non-empty plaintext
     * @param publickey hex-encoded public key (0xB0 0x7A 0x66 || 65-byte raw key)
     * @return ciphertext C1 || C2 || C3
     * @throws EncException if msg is null or empty, or the key is malformed, not SM2, or not a
     *     valid curve point
     */
    public static byte[] encrypt(byte[] msg, String publickey) throws EncException {
        if (msg == null || msg.length == 0) {
            throw new EncException("msg is null, please check");
        }
        KeyMember member = new KeyMember();
        getPublicKey(publickey, member);
        if (member.getKeyType() != KeyType.SM2) {
            throw new EncException("public key (" + publickey + ") is invalid, please check");
        }

        ECPoint ecPublicKey = decodeRawPublicKey(member.getRawPKey());
        // The key is external input; reject off-curve or wrong-order points (invalid-curve attacks).
        if (ecPublicKey == null || !checkPublicKey(ecPublicKey)) {
            throw new EncException("public key (" + publickey + ") is invalid, please check");
        }
        return encrypt(msg, ecPublicKey);
    }

    /**
     * Encrypts {@code msg} under {@code publicKey}.
     *
     * <p>The output layout is C1 ([k]G, 65 bytes uncompressed) || C2 (msg XOR keystream) || C3
     * (20-byte tag). If the derived keystream is all zero, a new k is drawn.
     *
     * @param msg non-empty plaintext
     * @param publicKey recipient public key; not validated here
     * @return ciphertext C1 || C2 || C3
     * @throws EncException if msg is null or empty
     */
    public static byte[] encrypt(byte[] msg, ECPoint publicKey) throws EncException {
        if (msg == null || msg.length == 0) {
            throw new EncException("msg is null, please check");
        }

        while (true) {
            // 1. Random k in [1, n-1].
            BigInteger k = random(n);
            // 2. C1 = [k]G, uncompressed encoding.
            byte[] c1 = G.multiply(k).getEncoded(false);
            // 3. [k]PB = (x2, y2).
            ECPoint kpb = publicKey.multiply(k).normalize();
            // 4. t = KDF(encoded [k]PB, klen).
            byte[] t = deriveKeystream(kpb.getEncoded(false), msg.length);
            if (allZero(t)) {
                // An all-zero keystream would leak the plaintext; retry with a new k.
                continue;
            }
            // 5. C2 = M XOR t.
            byte[] c2 = new byte[msg.length];
            for (int i = 0; i < msg.length; i++) {
                c2[i] = (byte) (msg[i] ^ t[i]);
            }
            // 6. C3 = Hash(x2 || M || y2).
            byte[] c3 = calculateHash(kpb.getXCoord().toBigInteger(), msg, kpb.getYCoord().toBigInteger());
            // 7. C = C1 || C2 || C3.
            return join(c1, c2, c3);
        }
    }

    /**
     * Decrypts with an encoded private key string.
     *
     * @param encryptData ciphertext C1 || C2 || C3
     * @param encPrivateKey encoded private key, decoded by {@code PrivateKey.getRawPrivateKey}
     * @return the plaintext, or {@code null} if the integrity tag does not match
     * @throws EncException if the key or the ciphertext is malformed
     */
    public static byte[] decrypt(byte[] encryptData, String encPrivateKey) throws EncException {
        return decrypt(encryptData, PrivateKey.getRawPrivateKey(encPrivateKey));
    }

    /**
     * Decrypts with a raw big-endian private key.
     *
     * @param encryptData ciphertext C1 || C2 || C3
     * @param rawPrivateKey private key as unsigned big-endian bytes
     * @return the plaintext, or {@code null} if the integrity tag does not match
     * @throws EncException if the key is null or empty, or the ciphertext is malformed
     */
    public static byte[] decrypt(byte[] encryptData, byte[] rawPrivateKey) throws EncException {
        if (rawPrivateKey == null || rawPrivateKey.length == 0) {
            throw new EncException("private key is invalid, please check");
        }
        // Unsigned parse; this is what bigIntegerPreHandle approximated, without its short-array crash.
        return decrypt(encryptData, new BigInteger(1, rawPrivateKey));
    }

    /**
     * Decrypts a ciphertext produced by {@link #encrypt(byte[], ECPoint)}.
     *
     * @param encryptData ciphertext C1 (65 bytes) || C2 (at least 1 byte) || C3 (20 bytes)
     * @param privateKey recipient private key d
     * @return the plaintext, or {@code null} if the integrity tag does not match
     * @throws EncException if the ciphertext is too short or the derived keystream is all zero
     */
    public static byte[] decrypt(byte[] encryptData, BigInteger privateKey) throws EncException {
        if (encryptData == null || encryptData.length <= C1_LENGTH + C3_LENGTH) {
            throw new EncException("encrypt data is invalid, please check");
        }

        byte[] c1Bytes = Arrays.copyOfRange(encryptData, 0, C1_LENGTH);
        ECPoint c1 = curve.decodePoint(c1Bytes).normalize();
        // [dB]C1 = (x2, y2).
        ECPoint dBC1 = c1.multiply(privateKey).normalize();
        // t = KDF(encoded [dB]C1, klen).
        int klen = encryptData.length - C1_LENGTH - C3_LENGTH;
        byte[] t = deriveKeystream(dBC1.getEncoded(false), klen);
        if (allZero(t)) {
            throw new EncException("allZero, please check");
        }
        // M' = C2 XOR t.
        byte[] m = new byte[klen];
        for (int i = 0; i < klen; i++) {
            m[i] = (byte) (encryptData[C1_LENGTH + i] ^ t[i]);
        }
        // u = Hash(x2 || M' || y2); accept only if u == C3.
        byte[] c3 = Arrays.copyOfRange(encryptData, encryptData.length - C3_LENGTH, encryptData.length);
        byte[] u = calculateHash(dBC1.getXCoord().toBigInteger(), m, dBC1.getYCoord().toBigInteger());
        // Constant-time comparison avoids leaking how many tag bytes matched.
        return MessageDigest.isEqual(u, c3) ? m : null;
    }

    /** SM2 signature value (r, s). */
    public static class Signature {
        BigInteger r;
        BigInteger s;

        /**
         * @param r signature component r
         * @param s signature component s
         */
        public Signature(BigInteger r, BigInteger s) {
            this.r = r;
            this.s = s;
        }

        /**
         * Returns r and s as lower-case hex, each zero-padded to 64 digits, so the two halves can
         * be split unambiguously.
         */
        @Override
        public String toString() {
            return String.format("%064x%064x", r, s);
        }

        /**
         * Encodes the signature as 64 bytes: 32-byte big-endian r followed by 32-byte big-endian s.
         *
         * @return 64-byte r || s
         */
        public byte[] toByte() {
            return join(toFixedLength(r), toFixedLength(s));
        }
    }
}