package cn.schoolwow.quickhttp.util;

import java.nio.charset.StandardCharsets;
import java.security.MessageDigest;
import java.security.NoSuchAlgorithmException;
import java.security.SecureRandom;
import java.util.Base64;

/**
 * Static helpers for the WebSocket protocol (RFC 6455): payload masking, the handshake
 * Sec-WebSocket-Accept computation, hex conversion, and bit-level access to bit arrays
 * (one {@code int} element per bit, most significant bit first).
 */
public class WebSocketUtil {
    /** Fixed GUID that RFC 6455 appends to Sec-WebSocket-Key before hashing. */
    private static final String WEB_SOCKET_GUID = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11";

    /** Upper-case hexadecimal digits, indexed by nibble value. */
    private static final char[] HEX_DIGITS = "0123456789ABCDEF".toCharArray();

    /** Random source behind {@link #randomBytes(byte[])}, e.g. for Sec-WebSocket-Key values. */
    private static final SecureRandom SECURE_RANDOM = new SecureRandom();

    /**
     * Applies the masking XOR to {@code payload} in place: byte {@code i} is XORed with
     * {@code maskKey[i % 4]}. Since XOR is its own inverse, applying it twice with the same
     * key restores the original data.
     *
     * @param payload payload bytes, modified in place
     * @param maskKey masking key; only its first 4 bytes are used and it should hold at least 4
     */
    public static void mask(byte[] payload, byte[] maskKey) {
        for (int i = 0; i < payload.length; i++) {
            payload[i] = (byte) (payload[i] ^ maskKey[i % 4]);
        }
    }

    /**
     * Fills {@code bytes} with random bytes from a {@link SecureRandom}.
     *
     * @param bytes array to fill in place
     */
    public static void randomBytes(byte[] bytes) {
        SECURE_RANDOM.nextBytes(bytes);
    }

    /**
     * Computes the Sec-WebSocket-Accept value: Base64(SHA-1(key + GUID)).
     *
     * @param secWebSocketKey value of the request header Sec-WebSocket-Key
     * @return value for the response header Sec-WebSocket-Accept
     * @throws NoSuchAlgorithmException if no SHA-1 implementation is available
     */
    public static String calculateSecWebSocketAccept(String secWebSocketKey) throws NoSuchAlgorithmException {
        // Encode explicitly: String.getBytes() without a charset depends on the platform default.
        byte[] plainText = (secWebSocketKey + WEB_SOCKET_GUID).getBytes(StandardCharsets.UTF_8);
        // "SHA-1" is the standard algorithm name; "SHA" is only a provider-specific alias.
        byte[] sha1 = MessageDigest.getInstance("SHA-1").digest(plainText);
        return Base64.getEncoder().encodeToString(sha1);
    }

    /**
     * Converts a byte array to an upper-case hexadecimal string, two digits per byte.
     *
     * @param bytes bytes to convert
     * @return the hex string, or the literal string {@code "null"} when {@code bytes}
     *         is {@code null} or empty (kept for backward compatibility)
     */
    public static String byteArrayToHex(byte[] bytes) {
        if (null == bytes || bytes.length == 0) {
            return "null";
        }
        StringBuilder builder = new StringBuilder(bytes.length * 2);
        for (byte b : bytes) {
            builder.append(HEX_DIGITS[(b >>> 4) & 0x0F]).append(HEX_DIGITS[b & 0x0F]);
        }
        return builder.toString();
    }

    /**
     * Converts a hexadecimal string (either case) to a byte array.
     *
     * @param hex hex string; must have an even length
     * @return the decoded bytes, or {@code null} when {@code hex} is {@code null} or empty
     * @throws IllegalArgumentException if the length is odd
     * @throws NumberFormatException if a character is not a hexadecimal digit
     */
    public static byte[] hexToByteArray(String hex) {
        if (null == hex || hex.isEmpty()) {
            return null;
        }
        // An odd trailing digit cannot form a byte; previously it was silently dropped.
        if ((hex.length() & 1) != 0) {
            throw new IllegalArgumentException("十六进制字符串长度必须为偶数!当前长度:" + hex.length());
        }
        byte[] result = new byte[hex.length() / 2];
        for (int j = 0; j < result.length; j++) {
            int high = hexDigit(hex, 2 * j);
            int low = hexDigit(hex, 2 * j + 1);
            result[j] = (byte) ((high << 4) | low);
        }
        return result;
    }

    /** Returns the value (0-15) of the hex digit at {@code index}, rejecting anything else. */
    private static int hexDigit(String hex, int index) {
        int digit = Character.digit(hex.charAt(index), 16);
        if (digit < 0) {
            throw new NumberFormatException("非法的十六进制字符:" + hex.charAt(index) + ",索引:" + index);
        }
        return digit;
    }

    /**
     * Splits a byte into 8 bits, most significant bit first.
     *
     * @param b the byte
     * @return array of length 8 whose elements are 0 or 1
     */
    public static int[] getBitByte(byte b) {
        int[] bits = new int[8];
        for (int i = 0; i < 8; i++) {
            bits[i] = (b >> (7 - i)) & 1;
        }
        return bits;
    }

    /**
     * Splits a byte array into bits, 8 per byte, most significant bit first.
     *
     * @param bytes the bytes
     * @return array of length {@code bytes.length * 8} whose elements are 0 or 1
     */
    public static int[] getBitBytes(byte[] bytes) {
        int[] bits = new int[bytes.length * 8];
        for (int i = 0; i < bytes.length; i++) {
            System.arraycopy(getBitByte(bytes[i]), 0, bits, i * 8, 8);
        }
        return bits;
    }

    /**
     * Reads {@code bits[startIndex..endIndex]} (both inclusive) as an unsigned number,
     * most significant bit first. Leading zero bits may make the range wider than 31 bits.
     *
     * @return the numeric value of the range
     * @throws IllegalArgumentException if an index is negative, {@code startIndex > endIndex},
     *         an element in the range is not 0 or 1, or the value exceeds {@link Integer#MAX_VALUE}
     */
    public static int getBitValue(int[] bits, int startIndex, int endIndex) {
        if (startIndex < 0 || endIndex < 0) {
            throw new IllegalArgumentException("startIndex和endIndex必须大于等于0!");
        }
        if (startIndex > endIndex) {
            throw new IllegalArgumentException("startIndex必须小于等于end!");
        }
        // Accumulate in a long so wide ranges cannot wrap around before the range check.
        long value = 0;
        for (int i = startIndex; i <= endIndex; i++) {
            int bit = bits[i];
            if (bit != 0 && bit != 1) {
                throw parseFailure(value, bits, startIndex, endIndex);
            }
            value = (value << 1) | bit;
            // Checked every step, so value <= Integer.MAX_VALUE before each shift (no long overflow).
            if (value > Integer.MAX_VALUE) {
                throw parseFailure(value, bits, startIndex, endIndex);
            }
        }
        return (int) value;
    }

    /** Builds the "data parse failed" exception, including the whole bit array for diagnosis. */
    private static IllegalArgumentException parseFailure(long value, int[] bits, int startIndex, int endIndex) {
        StringBuilder builder = new StringBuilder(bits.length);
        for (int b : bits) {
            builder.append(b);
        }
        return new IllegalArgumentException("数据解析失败!当前计算值:" + value + ",当前数组:" + builder + ",开始索引:" + startIndex + ",结束索引:" + endIndex);
    }

    /**
     * Packs a bit array into bytes, 8 elements per byte, most significant bit first.
     * An element equal to 1 sets its bit; any other value leaves it clear.
     *
     * @param bits bit array whose length is a multiple of 8
     * @return the packed bytes
     * @throws IllegalArgumentException if the length is not a multiple of 8
     */
    public static byte[] bits2Bytes(int[] bits) {
        if (bits.length % 8 != 0) {
            throw new IllegalArgumentException("bit数组的长度必须为8的倍数!");
        }
        byte[] bytes = new byte[bits.length / 8];
        for (int i = 0; i < bytes.length; i++) {
            int packed = 0;
            for (int j = 0; j < 8; j++) {
                if (bits[i * 8 + j] == 1) {
                    packed |= 0x80 >>> j;
                }
            }
            bytes[i] = (byte) packed;
        }
        return bytes;
    }

    /**
     * Writes {@code value} into {@code bits[startIndex..endIndex]} (both inclusive), most
     * significant bit first, overwriting every bit in the range (0 bits are cleared too).
     *
     * @throws IllegalArgumentException if an index is negative, {@code startIndex > endIndex},
     *         {@code value} is negative, or {@code value} does not fit in the range
     */
    public static void setBitValue(int[] bits, int startIndex, int endIndex, int value) {
        if (startIndex < 0 || endIndex < 0) {
            throw new IllegalArgumentException("startIndex和endIndex必须大于等于0!startIndex:" + startIndex + ",endIndex:" + endIndex);
        }
        if (startIndex > endIndex) {
            throw new IllegalArgumentException("startIndex必须小于等于end!startIndex:" + startIndex + ",endIndex:" + endIndex);
        }
        if (value < 0) {
            throw new IllegalArgumentException("value值不能为负数!当前value值:" + value);
        }
        int width = endIndex - startIndex + 1;
        // long arithmetic: an int shift by 31 or 32 would wrap to a negative/tiny bound.
        long maxValue = 1L << Math.min(width, 32);
        if (value >= maxValue) {
            throw new IllegalArgumentException("value值必须小于指定位数长度!当前value值:" + value + ",允许的最大值:" + maxValue);
        }
        for (int i = startIndex; i <= endIndex; i++) {
            int shift = endIndex - i;
            // value is non-negative, so every bit at position 31 or higher is 0.
            bits[i] = shift >= 31 ? 0 : (value >>> shift) & 1;
        }
    }
}