package cn.benma666.sm.sm2;

import cn.benma666.myutils.ByteUtil;
import org.bouncycastle.crypto.AsymmetricCipherKeyPair;
import org.bouncycastle.crypto.params.ECPrivateKeyParameters;
import org.bouncycastle.crypto.params.ECPublicKeyParameters;
import org.bouncycastle.math.ec.ECPoint;

import java.io.IOException;
import java.math.BigInteger;
import java.nio.charset.StandardCharsets;
import java.security.MessageDigest;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;

public class SM2EncDecUtils {

    /** Map key under which {@link #generateKeyPair()} stores the hex-encoded public key. */
    public static final String public_key = "public_key";
    /** Map key under which {@link #generateKeyPair()} stores the hex-encoded private key. */
    public static final String private_key = "private_key";

    /** Length in bytes of the C1 point at the start of a ciphertext (65 bytes, as parsed by {@link #decrypt}). */
    private static final int C1_LENGTH = 65;
    /** Length in bytes of the C3 digest produced by {@code Cipher.Dofinal}. */
    private static final int C3_LENGTH = 32;
    /** Number of bytes (C1 || C3) that precede the C2 payload in a ciphertext. */
    private static final int HEADER_LENGTH = C1_LENGTH + C3_LENGTH;

    /**
     * Generates a random SM2 key pair.
     *
     * <p>Key pairs are regenerated until the private scalar's {@link BigInteger#toByteArray()}
     * form is exactly 32 bytes, so the returned private key is always 64 hex characters.
     *
     * @return a mutable map with the hex-encoded public point under {@link #public_key} and the
     *     hex-encoded private scalar under {@link #private_key}
     */
    public static Map<String, String> generateKeyPair() {
        SM2 sm2 = SM2.Instance();
        AsymmetricCipherKeyPair key;
        ECPrivateKeyParameters ecpriv;
        do {
            key = sm2.ecc_key_pair_generator.generateKeyPair();
            ecpriv = (ECPrivateKeyParameters) key.getPrivate();
        } while (ecpriv.getD().toByteArray().length != 32);

        ECPublicKeyParameters ecpub = (ECPublicKeyParameters) key.getPublic();
        BigInteger privateKey = ecpriv.getD();
        ECPoint publicKey = ecpub.getQ();
        // Key material is deliberately not printed: a library must never write secrets to stdout.
        String pubk = ByteUtil.byteToHex(publicKey.getEncoded());
        String prik = ByteUtil.byteToHex(privateKey.toByteArray());

        Map<String, String> result = new HashMap<>();
        result.put(public_key, pubk);
        result.put(private_key, prik);
        return result;
    }

    /**
     * Encrypts {@code data} for the given public key.
     *
     * @param publicKey encoded public point, decoded with the SM2 curve's {@code decodePoint}
     * @param data plaintext; it is not modified
     * @return hex string of C1 || C3 || C2, or {@code null} if either argument is null or empty.
     *     NOTE(review): {@link #decrypt} assumes C1 is encoded in 65 bytes; confirm that
     *     {@code c1.getEncoded()} always yields the uncompressed form.
     * @throws IOException if the underlying cipher reports one
     */
    public static String encrypt(byte[] publicKey, byte[] data) throws IOException {
        if (publicKey == null || publicKey.length == 0) {
            return null;
        }
        if (data == null || data.length == 0) {
            return null;
        }

        // Cipher.Encrypt transforms its argument in place (its result is emitted as C2),
        // so work on a copy to leave the caller's array untouched.
        byte[] c2 = Arrays.copyOf(data, data.length);
        Cipher cipher = new Cipher();
        SM2 sm2 = SM2.Instance();
        ECPoint userKey = sm2.ecc_curve.decodePoint(publicKey);
        ECPoint c1 = cipher.Init_enc(sm2, userKey);
        cipher.Encrypt(c2);
        byte[] c3 = new byte[C3_LENGTH];
        cipher.Dofinal(c3);
        return new StringBuilder()
                .append(ByteUtil.byteToHex(c1.getEncoded()))
                .append(ByteUtil.byteToHex(c3))
                .append(ByteUtil.byteToHex(c2))
                .toString();
    }

    /**
     * Decrypts a ciphertext laid out as C1 (65 bytes) || C3 (32 bytes) || C2 and verifies its C3 digest.
     *
     * @param privateKey private scalar bytes, interpreted as an unsigned big-endian integer
     * @param encryptedData raw ciphertext bytes, i.e. the hex output of {@link #encrypt} decoded back to bytes
     * @return the plaintext, or {@code null} if either argument is null or empty
     * @throws IOException if the ciphertext is shorter than C1 || C3, or if the recomputed C3 does
     *     not match the received one (tampered data or wrong key)
     */
    public static byte[] decrypt(byte[] privateKey, byte[] encryptedData) throws IOException {
        if (privateKey == null || privateKey.length == 0) {
            return null;
        }
        if (encryptedData == null || encryptedData.length == 0) {
            return null;
        }
        if (encryptedData.length < HEADER_LENGTH) {
            throw new IOException("SM2 ciphertext too short: " + encryptedData.length
                    + " bytes, expected at least " + HEADER_LENGTH);
        }

        // Slice the three components directly rather than round-tripping through a hex string.
        byte[] c1Bytes = Arrays.copyOfRange(encryptedData, 0, C1_LENGTH);
        byte[] expectedC3 = Arrays.copyOfRange(encryptedData, C1_LENGTH, HEADER_LENGTH);
        byte[] c2 = Arrays.copyOfRange(encryptedData, HEADER_LENGTH, encryptedData.length);

        SM2 sm2 = SM2.Instance();
        BigInteger userD = new BigInteger(1, privateKey);

        // Rebuild the C1 point from its encoded bytes.
        ECPoint c1 = sm2.ecc_curve.decodePoint(c1Bytes);
        Cipher cipher = new Cipher();
        cipher.Init_dec(userD, c1);
        cipher.Decrypt(c2);

        // Dofinal writes the C3 it computes into its argument (encrypt relies on exactly this),
        // so compute into a fresh buffer and compare with the received C3 in constant time.
        byte[] actualC3 = new byte[C3_LENGTH];
        cipher.Dofinal(actualC3);
        if (!MessageDigest.isEqual(expectedC3, actualC3)) {
            throw new IOException("SM2 decryption failed: C3 integrity check mismatch");
        }
        return c2;
    }

    /**
     * Manual entry point. It intentionally does nothing; call {@code singleThreadTest()} or
     * {@code mutiThreadTest()} from here for ad-hoc encrypt/decrypt round-trip checks.
     */
    public static void main(String[] args) throws Exception {
    }

    /** Round-trips a short plaintext {@code counts} times on one thread and prints the average time. */
    private static void singleThreadTest() throws Exception {
        String plainText = "sourceText";
        byte[] sourceData = plainText.getBytes(StandardCharsets.UTF_8);
        Map<String, String> keymap = generateKeyPair();

        long start = System.currentTimeMillis();
        int counts = 100;
        for (int j = 0; j < counts; j++) {
            String cipherText = SM2EncDecUtils.encrypt(ByteUtil.hexToByte(keymap.get(public_key)), sourceData);
            System.out.println("加密前长度: " + plainText.length() + ";加密后长度: " + cipherText.length());
            String plainTextEncripted = new String(
                    SM2EncDecUtils.decrypt(ByteUtil.hexToByte(keymap.get(private_key)), ByteUtil.hexToByte(cipherText)),
                    StandardCharsets.UTF_8);
            // Report every iteration so that mismatches are visible too, not only successes.
            System.out.println("------解密后同原文是否一致: " + plainText.equals(plainTextEncripted) + "----------------------");
        }
        long end = System.currentTimeMillis();
        System.out.println("平均耗时:" + (end - start) / counts + "ms。");
    }

    /** Starts {@code counts} threads, each round-tripping the same plaintext {@code counts} times with one shared key pair. */
    private static void mutiThreadTest() {
        String plainText = "sourceText";
        byte[] sourceData = plainText.getBytes(StandardCharsets.UTF_8);

        Map<String, String> keymap = generateKeyPair();
        int counts = 10;
        for (int i = 0; i < counts; i++) {
            new Thread(() -> {
                try {
                    for (int j = 0; j < counts; j++) {
                        String cipherText = SM2EncDecUtils.encrypt(ByteUtil.hexToByte(keymap.get(public_key)), sourceData);
                        byte[] decrypted = SM2EncDecUtils.decrypt(
                                ByteUtil.hexToByte(keymap.get(private_key)), ByteUtil.hexToByte(cipherText));
                        if (!plainText.equals(new String(decrypted, StandardCharsets.UTF_8))) {
                            System.out.println("------解密后同原文不一致:" + Thread.currentThread().getName() + "--------------");
                        }
                    }
                } catch (IOException e) {
                    // Diagnostic harness only: surface the failure for this thread and let the others continue.
                    e.printStackTrace();
                }
                System.out.println(" --------------->线程" + Thread.currentThread().getName() + "执行完成.---------------------");
            }).start();
        }
    }
}
