package cn.blankcat.websocket;

import cn.blankcat.config.BotConfig;
import cn.blankcat.dto.websocket.WSEvent;
import cn.blankcat.dto.websocket.WSPayload;
import cn.blankcat.dto.websocket.WSUrl;
import cn.blankcat.openapi.GatewayService;
import cn.blankcat.openapi.RetrofitManager;
import cn.blankcat.websocket.handler.WebsocketHandler;
import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.ObjectMapper;
import okhttp3.*;
import okio.ByteString;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import javax.annotation.Nullable;
import java.io.IOException;
import java.util.concurrent.TimeUnit;

public class WebsocketManager {

    private static final Logger logger = LoggerFactory.getLogger("websocketManager");

    private WebsocketManager(){};

    private static class MyWebSocketListener extends WebSocketListener{
        @Override
        public void onOpen(WebSocket webSocket, Response response) {
            super.onOpen(webSocket, response);
        }

        @Override
        public void onMessage(WebSocket webSocket, String text) {
            websocketEventDispatch(text);
            super.onMessage(webSocket, text);
        }

        @Override
        public void onMessage(WebSocket webSocket, ByteString bytes) {
            super.onMessage(webSocket, bytes);
        }

        @Override
        public void onClosing(WebSocket webSocket, int code, String reason) {
            super.onClosing(webSocket, code, reason);
        }

        @Override
        public void onClosed(WebSocket webSocket, int code, String reason) {
            super.onClosed(webSocket, code, reason);
        }

        @Override
        public void onFailure(WebSocket webSocket, Throwable t, @Nullable Response response) {
            if (response != null && response.body() != null) {
                MediaType mediaType = response.body().contentType();
                try {
                    String responseString = response.body().string();
                    if (response.code() > 209) {
                        logger.warn("请求错误, 错误内容为:{}", responseString);
                    }
                    super.onFailure(webSocket, t, response.newBuilder().body(ResponseBody.create(mediaType, responseString)).build());
                } catch (IOException e) {
                    e.printStackTrace();
                }
            }
        }
    }

    private enum Singleton {
        INSTANCE;

        private WebSocket webSocket;

        Singleton() {
            webSocket = getOkHttpClient().newWebSocket(getRequest(), new MyWebSocketListener());
            webSocket.send("hello");
        }

        public static WebSocket getInstance() {
            return INSTANCE.webSocket;
        }

        public static void setInstance(WebSocket webSocket) {
            INSTANCE.webSocket = webSocket;
        }
    }

    /**
     * 获得一个全局单例的WebSocket对象
     * @return 获得一个全局单例的WebSocket对象
     */
    public static WebSocket getInstance() {
        return Singleton.getInstance();
    }

    /**
     * 根据用户提供的WebSocket对象创建一个新的WebSocket连接
     * @param webSocket 要设置的WebSocket对象
     */
    public static void setInstance(WebSocket webSocket) {
        Singleton.setInstance(webSocket);
    }

    /**
     * 根据默认的配置创建一个新的WebSocket连接
     */
    public static void setInstance() {
        setInstance(getOkHttpClient().newWebSocket(getRequest(), new MyWebSocketListener()));
    }

    private static OkHttpClient getOkHttpClient() {
        return new OkHttpClient.Builder()
                .readTimeout(3, TimeUnit.SECONDS)
                .writeTimeout(3, TimeUnit.SECONDS)
                .connectTimeout(3, TimeUnit.SECONDS)
                .build();
    }

    private static Request getRequest(){
        String url = "wss://api.sgroup.qq.com/websocket/";
        try {
            WSUrl wsUrl = RetrofitManager.getInstance()
                    .create(GatewayService.class)
                    .getWebsocketGateway()
                    .execute()
                    .body();
            url = wsUrl == null ? url : wsUrl.getUrl();
            return new Request.Builder().get().url(url).build();
        } catch (IOException e) {
            e.printStackTrace();

        }
        return new Request.Builder().get().url(url).build();
    }

    private static void websocketEventDispatch(String text) {
        ObjectMapper mapper = new ObjectMapper();
        try {
            WSPayload wsPayload = mapper.readValue(text, WSPayload.class);
            WSEvent.WSEventType wsEventType = WSEvent.WSEventType.ofValue(wsPayload.getType());
            if (wsPayload.getSeq() > 0){
                BotConfig.DEFAULT.setSeq(wsPayload.getSeq());
            }
            if (wsEventType == null) {
                // 有三类不是官方定义的type, 这里加上去方便自定义handler处理
                if (!text.contains("heartbeat_interval")
                        && !text.contains("\"op\":9")
                        && !text.contains("\"op\":11")) {
                    return;
                }
                wsEventType = WSEvent.WSEventType.ofValue((wsPayload.getOpCode() == WSPayload.OPCode.WSHello.getValue()
                            ? "heartbeat_interval"
                            : wsPayload.getOpCode() == WSPayload.OPCode.WSInvalidSession.getValue()
                                ? "\"op\":9"
                                : "\"op\":11"
                        )
                );
                if (wsEventType == null) {
                    return;
                }
            }
            Class<?> clazz = WSEvent.eventClassMap.get(wsEventType);
            for (WebsocketHandler websocketHandler : WebsocketService.CLASS_HANDLER_MAP.get(clazz)) {
                websocketHandler.handle(text);
            }
        } catch (JsonProcessingException e) {
            e.printStackTrace();
        }
    }
}
