package cn.schoolwow.quickserver.websocket.stream;

import cn.schoolwow.quickserver.exception.HttpStatusException;
import cn.schoolwow.quickserver.util.BitUtil;
import cn.schoolwow.quickserver.util.QuickServerUtil;
import cn.schoolwow.quickserver.websocket.domain.CloseCode;
import cn.schoolwow.quickserver.websocket.domain.OPCode;
import cn.schoolwow.quickserver.websocket.domain.WebSocketFrame;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.util.Arrays;

public class WebSocketFrameStreamImpl implements WebSocketFrameStream{
    private Logger logger = LoggerFactory.getLogger(WebSocketFrameStreamImpl.class);

    /**分段大小,默认为1MB*/
    private static final int FRAGMENT_SIZE = 1024*1024;

    /**输入输出流*/
    private WebSocketStream webSocketStream;

    /**读锁*/
    private final Object readLock = new Object();

    /**写锁*/
    private final Object writeLock = new Object();

    public WebSocketFrameStreamImpl(WebSocketStream webSocketStream) {
        this.webSocketStream = webSocketStream;
    }

    @Override
    public WebSocketFrame getClientWebSocketFrame() throws IOException {
        WebSocketFrame webSocketFrame = getCompleteDataFrame();
        if(!webSocketFrame.mask){
            WebSocketFrame closeWebSocketFrame = WebSocketFrame.newInstance()
                    .opCode(OPCode.Close)
                    .closeCode(CloseCode.PROTOCOL_ERROR)
                    .mask(true);
            writeWebSocketFrame(closeWebSocketFrame);
            webSocketStream.close();
            throw new IOException("客户端数据帧未设置掩码!");
        }
        return webSocketFrame;
    }

    @Override
    public WebSocketFrame getServerWebSocketFrame() throws IOException {
        WebSocketFrame webSocketFrame = getCompleteDataFrame();
        if(webSocketFrame.mask){
            WebSocketFrame closeWebSocketFrame = WebSocketFrame.newInstance()
                    .opCode(OPCode.Close)
                    .closeCode(CloseCode.PROTOCOL_ERROR)
                    .mask(true);
            writeWebSocketFrame(closeWebSocketFrame);
            webSocketStream.close();
            throw new IOException("服务端数据帧设置了掩码!");
        }
        return webSocketFrame;
    }

    @Override
    public void writeWebSocketFrame(WebSocketFrame webSocketFrame) throws IOException {
        synchronized (writeLock){
            //判断负载大小,超时指定大小则分段发送
            if(null==webSocketFrame.payload||webSocketFrame.payload.length<FRAGMENT_SIZE){
                doWriteWebSocketFrame(webSocketFrame);
                return;
            }
            ByteArrayInputStream baos = new ByteArrayInputStream(webSocketFrame.payload);
            //分段数组
            byte[] fragmentBytes = new byte[FRAGMENT_SIZE];
            //记录已发送帧大小
            int total = 0;
            //发送起始帧
            int length = baos.read(fragmentBytes,0,fragmentBytes.length);
            WebSocketFrame startWebSocketFrame = WebSocketFrame.newInstance()
                    .fin(false)
                    .opCode(webSocketFrame.opCode)
                    .mask(webSocketFrame.mask)
                    .payload(fragmentBytes);
            total += length;
            logger.trace("发送起始帧,数据长度范围:{}-{}", 0, total);
            doWriteWebSocketFrame(startWebSocketFrame);
            //发送持续帧
            while((length = baos.read(fragmentBytes,0,fragmentBytes.length))==fragmentBytes.length){
                WebSocketFrame continueWebSocketFrame = WebSocketFrame.newInstance()
                        .fin(false)
                        .opCode(OPCode.ContinueFrame)
                        .mask(webSocketFrame.mask)
                        .payload(fragmentBytes);
                logger.trace("发送持续帧,数据长度范围:{}-{}", total, total+length);
                total += length;
                doWriteWebSocketFrame(continueWebSocketFrame);
            }
            //发送结尾帧
            byte[] endFragmentBytes = new byte[length];
            System.arraycopy(fragmentBytes, 0, endFragmentBytes, 0, endFragmentBytes.length);
            WebSocketFrame endWebSocketFrame = WebSocketFrame.newInstance()
                    .fin(true)
                    .opCode(OPCode.ContinueFrame)
                    .mask(webSocketFrame.mask)
                    .payload(endFragmentBytes);
            logger.trace("发送结尾帧,数据长度范围:{}-{}",total,total+length);
            doWriteWebSocketFrame(endWebSocketFrame);
            baos.close();
        }
    }

    @Override
    public WebSocketStream getWebSocketStream() {
        return webSocketStream;
    }

    private WebSocketFrame getWebSocketFrame() throws IOException {
        WebSocketFrame webSocketFrame = new WebSocketFrame();

        int[] bits = webSocketStream.readBitByte();
        webSocketFrame.fin = bits[0]==1;
        webSocketFrame.opCode = OPCode.getOPCode(BitUtil.getBitValue(bits,4,7));
        logger.trace("读取websocket帧第1个字节:{}", Arrays.toString(bits));

        bits = webSocketStream.readBitByte();
        webSocketFrame.mask = bits[0]==1;
        webSocketFrame.payloadLength = BitUtil.getBitValue(bits,1,7);
        if(webSocketFrame.payloadLength==126){
            logger.trace("负载长度为126,读取接下来2个字节作为实际长度!");
            webSocketFrame.payloadLength = webSocketStream.readShort();
        }else if(webSocketFrame.payloadLength==127){
            logger.trace("负载长度为127,读取接下来8个字节作为实际长度!");
            webSocketFrame.payloadLength = webSocketStream.readLong();
        }
        logger.trace("负载数据长度为:{}字节", webSocketFrame.payloadLength);
        if(webSocketFrame.mask){
            webSocketFrame.maskKey = new byte[4];
            int length = webSocketStream.read(webSocketFrame.maskKey,0,webSocketFrame.maskKey.length);
            logger.trace("读取掩码!key长度:{},值:{}", length, Arrays.toString(webSocketFrame.maskKey));
            if(length!=webSocketFrame.maskKey.length){
                throw new HttpStatusException(400, "读取掩码失败!预期字节长度:4,实际长度:"+length);
            }
        }
        webSocketFrame.payload = new byte[(int) webSocketFrame.payloadLength];
        webSocketStream.read(webSocketFrame.payload, 0, webSocketFrame.payload.length);
        if(webSocketFrame.mask&&webSocketFrame.payloadLength>0){
            logger.trace("服务端进行反掩码操作,掩码key:{},负载数据长度:{}", Arrays.toString(webSocketFrame.maskKey),webSocketFrame.payload.length);
            mask(webSocketFrame.payload, webSocketFrame.maskKey);
        }
        if(OPCode.Close.equals(webSocketFrame.opCode)){
            webSocketFrame.closeCode = CloseCode.getCloseCodeByCode(webSocketFrame.payload);
        }
        return webSocketFrame;
    }

    private void doWriteWebSocketFrame(WebSocketFrame webSocketFrame) throws IOException {
        try (WebSocketStream cacheStream = new WebSocketStreamImpl();){
            int[] bits = new int[8];
            bits[0] = webSocketFrame.fin?1:0;
            BitUtil.setBitValue(bits,4,7,webSocketFrame.opCode.value);
            logger.trace("写入websocket帧第1个字节:{}", Arrays.toString(bits));
            cacheStream.writeBit(bits);

            bits = new int[8];
            bits[0] = webSocketFrame.mask?1:0;
            if(webSocketFrame.payloadLength<126){
                BitUtil.setBitValue(bits,1,7,(int) webSocketFrame.payloadLength);
                cacheStream.writeBit(bits);
                logger.trace("写入websocket帧第2个字节:{},负载数据长度:{}", Arrays.toString(bits), webSocketFrame.payloadLength);
            }else if(webSocketFrame.payloadLength<Short.MAX_VALUE){
                BitUtil.setBitValue(bits,1,7,126);
                cacheStream.writeBit(bits);
                cacheStream.writeShort((int) webSocketFrame.payloadLength);
                logger.trace("写入websocket帧第2个字节:{},126+{}(大小占2字节)", Arrays.toString(bits), webSocketFrame.payloadLength);
            }else {
                BitUtil.setBitValue(bits,1,7, 127);
                cacheStream.writeBit(bits);
                cacheStream.writeLong(webSocketFrame.payloadLength);
                logger.trace("写入websocket帧第2个字节:{},127+{}(大小占8字节)", Arrays.toString(bits), webSocketFrame.payloadLength);
            }
            if(webSocketFrame.mask){
                cacheStream.write(webSocketFrame.maskKey);
                logger.trace("写入掩码key:{}", Arrays.toString(webSocketFrame.maskKey));
                if(OPCode.Close.equals(webSocketFrame.opCode)){
                    logger.trace("写入关闭状态码:{}", webSocketFrame.closeCode);
                }else if(null!=webSocketFrame.payload){
                    logger.trace("写入负载数据,长度:{}, 文本:{}, 十六进制:{}",
                            webSocketFrame.payloadLength,
                            new String(webSocketFrame.payload,StandardCharsets.UTF_8),
                            QuickServerUtil.byteArrayToHex(webSocketFrame.payload)
                    );
                }
                if(webSocketFrame.payloadLength>0&&null!=webSocketFrame.payload){
                    mask(webSocketFrame.payload, webSocketFrame.maskKey);
                }
            }
            if(null!=webSocketFrame.payload){
                cacheStream.write(webSocketFrame.payload);
            }
            webSocketStream.write(cacheStream.toByteArray());
            webSocketStream.flush();
            //还原数据
            if(webSocketFrame.mask&&webSocketFrame.payloadLength>0&&null!=webSocketFrame.payload){
                mask(webSocketFrame.payload, webSocketFrame.maskKey);
            }
        }
    }

    private WebSocketFrame getCompleteDataFrame() throws IOException {
        synchronized (readLock){
            WebSocketFrame webSocketFrame = getWebSocketFrame();
            boolean fin = webSocketFrame.fin;
            while(!fin){
                logger.trace("当前帧不是最后一个片段!继续读取下一个帧!");
                logger.trace("当前负载长度:{}", webSocketFrame.payload.length);
                WebSocketFrame nextFrame = getWebSocketFrame();
                byte[] data = new byte[nextFrame.payload.length+nextFrame.payload.length];
                System.arraycopy(webSocketFrame.payload, 0, data, 0, webSocketFrame.payload.length);
                System.arraycopy(nextFrame.payload, 0, data, webSocketFrame.payload.length, nextFrame.payload.length);
                webSocketFrame.payload = data;
                logger.trace("待复制帧长度:{},复制以后负载长度:{}", nextFrame.payload.length, webSocketFrame.payload.length);
                fin = nextFrame.fin;
            }
            return webSocketFrame;
        }
    }

    /**进行掩码操作*/
    private void mask(byte[] payload, byte[] maskKey){
        for(int i=0;i<payload.length;i++){
            payload[i] = (byte) (payload[i] ^ maskKey[i%4]);
        }
    }

}