package cn.schoolwow.quickserver.handler;

import cn.schoolwow.quickserver.domain.Client;
import cn.schoolwow.quickserver.exception.HttpStatusException;
import cn.schoolwow.quickserver.response.HttpStatus;
import cn.schoolwow.quickserver.util.QuickServerUtil;
import cn.schoolwow.quickserver.websocket.WebSocketSession;
import cn.schoolwow.quickserver.websocket.WebSocketSessionImpl;
import cn.schoolwow.quickserver.websocket.domain.OPCode;
import cn.schoolwow.quickserver.websocket.domain.WebSocketFrame;
import cn.schoolwow.quickserver.websocket.stream.WebSocketFrameStream;
import cn.schoolwow.quickserver.websocket.stream.WebSocketFrameStreamImpl;
import cn.schoolwow.quickserver.websocket.stream.WebSocketStream;
import cn.schoolwow.quickserver.websocket.stream.WebSocketStreamImpl;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.security.NoSuchAlgorithmException;
import java.util.List;

/**
 * WebSocket protocol handler.
 * <p>
 * Validates the (already parsed) HTTP upgrade request, answers the opening handshake and then
 * runs the frame loop for the connection, forwarding events to
 * {@code client.webSocketServerListener}.
 */
public class WebSocketHandler implements Handler{
    private static final Logger logger = LoggerFactory.getLogger(WebSocketHandler.class);

    /**
     * Frame stream wrapping the client socket; created at the start of {@link #handle(Client)}.
     * NOTE(review): this per-connection state lives in an instance field, so a single handler
     * instance must not serve several connections concurrently — confirm how handlers are created.
     */
    private WebSocketFrameStream webSocketFrameStream;

    /**
     * Performs the WebSocket handshake and, if it succeeds, processes frames until the client
     * sends a Close frame.
     *
     * @param client the connection whose HTTP request line and headers have already been parsed
     * @return always {@code null}; no further handler runs after this one
     * @throws Exception if creating the streams or writing the handshake fails with an error other
     *                   than {@link HttpStatusException} (which is answered with an HTTP error response)
     */
    @Override
    public Handler handle(Client client) throws Exception {
        WebSocketStream webSocketStream = new WebSocketStreamImpl(client.socket);
        webSocketFrameStream = new WebSocketFrameStreamImpl(webSocketStream);
        try {
            upgrade(client);
        }catch (HttpStatusException e){
            logger.error("websocket握手失败", e);
            sendHandshakeFailureResponse(e);
            //the failure response advertises "Connection: close", so actually close the connection
            closeSocket(client);
            return null;
        }
        handleWebSocketFrame(client);
        return null;
    }

    /**
     * Validates the upgrade request and writes the "101" handshake response.
     *
     * @throws HttpStatusException if the request is not a valid WebSocket upgrade request,
     *                             or (status 500) if the accept hash cannot be computed
     * @throws IOException         if writing the response fails
     */
    private void upgrade(Client client) throws IOException {
        checkRequestLine(client);
        String secWebSocketKey = checkRequestHeader(client);
        //compute Sec-WebSocket-Accept from the client's key
        String secWebSocketAccept;
        try {
            secWebSocketAccept = QuickServerUtil.calculateSecWebSocketAccept(secWebSocketKey);
        } catch (NoSuchAlgorithmException e) {
            logger.error("服务端错误", e);
            throw new HttpStatusException(500, "服务器内部错误!");
        }
        logger.trace("计算Sec-WebSocket-Accept头部的值,请求头Sec-Web-Socket-Key:{},响应头Sec-WebSocket-Accept:{}", secWebSocketKey, secWebSocketAccept);
        //send the handshake response
        StringBuilder responseBuilder = new StringBuilder();
        responseBuilder.append("HTTP/1.1 101 Web Socket Protocol Handshake\r\n");
        responseBuilder.append("Upgrade: websocket\r\n");
        responseBuilder.append("Connection: Upgrade\r\n");
        responseBuilder.append("Sec-WebSocket-Accept: ").append(secWebSocketAccept).append("\r\n");
        responseBuilder.append("\r\n");

        WebSocketStream webSocketStream = webSocketFrameStream.getWebSocketStream();
        webSocketStream.write(responseBuilder.toString().getBytes(StandardCharsets.UTF_8));
        webSocketStream.flush();
    }

    /**
     * Checks the request line: the method must be GET and the protocol HTTP/1.1.
     *
     * @throws HttpStatusException (400) when either requirement is not met
     */
    private void checkRequestLine(Client client) throws IOException {
        if(!"GET".equalsIgnoreCase(client.httpRequestMeta.method)){
            throw new HttpStatusException(400, "请求方法必须为get!客户端方法:"+client.httpRequestMeta.method);
        }
        if(!"HTTP/1.1".equalsIgnoreCase(client.httpRequestMeta.protocol)){
            throw new HttpStatusException(400, "当前仅支持HTTP/1.1!客户端版本:"+client.httpRequestMeta.protocol);
        }
    }

    /**
     * Checks the headers required for a WebSocket upgrade.
     *
     * @return the value of the Sec-WebSocket-Key header
     * @throws HttpStatusException (400) when a required header is missing or has an invalid value
     */
    private String checkRequestHeader(Client client) throws IOException {
        if(!client.httpRequestMeta.headers.containsKey("host")){
            throw new HttpStatusException(400, "请求头部未包含host!");
        }
        //Upgrade and Connection are comma-separated token lists (Firefox, for example, sends
        //"Connection: keep-alive, Upgrade"), so search for the token instead of comparing the whole value
        if(!headerContainsToken(client, "upgrade", "websocket")){
            throw new HttpStatusException(400, "请求头部未包含upgrade或者upgrade的值不为websocket!");
        }
        if(!headerContainsToken(client, "connection", "upgrade")){
            throw new HttpStatusException(400, "请求头部未包含connection或者connection的值不为upgrade!");
        }
        String secWebSocketKey = firstHeaderValue(client, "Sec-WebSocket-Key");
        if(null==secWebSocketKey||secWebSocketKey.trim().isEmpty()){
            throw new HttpStatusException(400, "请求头部未包含Sec-WebSocket-Key!");
        }
        String secWebSocketVersion = firstHeaderValue(client, "Sec-WebSocket-Version");
        if(null==secWebSocketVersion||!"13".equals(secWebSocketVersion.trim())){
            throw new HttpStatusException(400, "请求头部未包含Sec-WebSocket-Version或者Sec-WebSocket-Version的值不为13!");
        }
        return secWebSocketKey.trim();
    }

    /**
     * Returns the first value of the given header, or {@code null} when the header is absent or
     * has no values.
     */
    private static String firstHeaderValue(Client client, String name){
        List<String> values = client.httpRequestMeta.headers.get(name);
        if(null==values||values.isEmpty()){
            return null;
        }
        return values.get(0);
    }

    /**
     * Returns whether any value of the given header, split on commas, contains {@code token}
     * (case-insensitive, surrounding whitespace ignored).
     */
    private static boolean headerContainsToken(Client client, String name, String token){
        List<String> values = client.httpRequestMeta.headers.get(name);
        if(null==values){
            return false;
        }
        for(String value : values){
            if(null==value){
                continue;
            }
            for(String part : value.split(",")){
                if(token.equalsIgnoreCase(part.trim())){
                    return true;
                }
            }
        }
        return false;
    }

    /**
     * Writes a plain-text HTTP error response describing a failed handshake.
     * Write failures are logged and otherwise ignored.
     */
    private void sendHandshakeFailureResponse(HttpStatusException e){
        byte[] body = e.getDescription().getBytes(StandardCharsets.UTF_8);
        StringBuilder responseBuilder = new StringBuilder();
        responseBuilder.append("HTTP/1.1 ").append(e.getStatus()).append(' ').append(HttpStatus.getStatus(e.getStatus()).statusMessage).append("\r\n");
        responseBuilder.append("Connection: close\r\n");
        //the body is UTF-8 text, so declare the charset explicitly
        responseBuilder.append("Content-Type: text/plain; charset=utf-8\r\n");
        responseBuilder.append("Content-Length: ").append(body.length).append("\r\n");
        responseBuilder.append("\r\n");
        responseBuilder.append(e.getDescription());
        logger.trace("服务端响应错误报文:\r\n{}", responseBuilder.toString());

        try {
            webSocketFrameStream.getWebSocketStream().write(responseBuilder.toString().getBytes(StandardCharsets.UTF_8));
            webSocketFrameStream.getWebSocketStream().flush();
        }catch (IOException ex){
            logger.error("写入服务端响应报文失败", ex);
        }
    }

    /**
     * Runs the frame loop: calls onOpen, answers Ping frames with Pong frames, forwards text and
     * binary frames to the listener and stops at the client's Close frame (then calls onClose).
     * Any exception is reported through onError. The socket is always closed on exit.
     */
    private void handleWebSocketFrame(Client client) {
        WebSocketSession webSocketSession = new WebSocketSessionImpl(webSocketFrameStream);
        try {
            client.webSocketServerListener.onOpen(webSocketSession);
            WebSocketFrame clientWebSocketFrame = webSocketFrameStream.getClientWebSocketFrame();
            //keep running until the client sends a Close frame
            while(!OPCode.Close.equals(clientWebSocketFrame.opCode)){
                switch (clientWebSocketFrame.opCode){
                    case PingFrame:{
                        logger.trace("客户端发送了Ping帧,服务端回应Pong帧");
                        //RFC 6455 §5.5.3: the Pong must carry the same application data as the Ping
                        WebSocketFrame webSocketFrame = new WebSocketFrame();
                        webSocketFrame.fin = true;
                        webSocketFrame.opCode = OPCode.PongFrame;
                        webSocketFrame.mask = false;
                        webSocketFrame.payload = clientWebSocketFrame.payload;
                        webSocketFrame.payloadLength = clientWebSocketFrame.payloadLength;
                        webSocketFrameStream.writeWebSocketFrame(webSocketFrame);
                    }break;
                    case PongFrame:{
                        logger.trace("客户端发送了PongFrame,服务端不做任何回应");
                    }break;
                    case TextFrame:{
                        String text = new String(clientWebSocketFrame.payload, StandardCharsets.UTF_8);
                        client.webSocketServerListener.onTextMessage(text, webSocketSession);
                    }break;
                    case BinaryFrame:{
                        client.webSocketServerListener.onBinaryMessage(clientWebSocketFrame.payload, webSocketSession);
                    }break;
                    default:{
                        logger.warn("当前无法处理的帧类型:{}", clientWebSocketFrame.opCode);
                    }break;
                }
                clientWebSocketFrame = webSocketFrameStream.getClientWebSocketFrame();
            }
            logCloseReason(clientWebSocketFrame);
            //NOTE(review): RFC 6455 §5.5.1 expects an endpoint to answer a Close frame with its own
            //Close frame before dropping the TCP connection; none is sent here — confirm whether
            //that is handled elsewhere (e.g. by the session) before adding it.
            client.webSocketServerListener.onClose(webSocketSession);
        }catch (Exception e){
            try {
                client.webSocketServerListener.onError(e, webSocketSession);
            } catch (IOException ex) {
                logger.error("websocket处理onError回调时发生异常", ex);
            }
        }finally {
            closeSocket(client);
        }
    }

    /**
     * Logs the status code and UTF-8 reason carried by the client's Close frame, if it has a body.
     * Per RFC 6455 §5.5.1 a Close body, when present, starts with a 2-byte status code.
     */
    private void logCloseReason(WebSocketFrame closeFrame){
        if(closeFrame.payloadLength<=0||null==closeFrame.payload){
            return;
        }
        if(closeFrame.payload.length<2){
            logger.warn("客户端关闭帧格式错误,负载长度:{}", closeFrame.payload.length);
            return;
        }
        String description = new String(closeFrame.payload, 2, closeFrame.payload.length-2, StandardCharsets.UTF_8);
        logger.debug("链接关闭!关闭原因:{}, 描述:{}", closeFrame.closeCode, description);
    }

    /**Closes the client socket, logging instead of propagating any failure.*/
    private void closeSocket(Client client){
        try {
            client.socket.close();
        } catch (IOException e) {
            logger.warn("关闭socket连接失败", e);
        }
    }
}
