package cn.blankcat.websocket.v1.service;

import cn.blankcat.dto.audio.AudioAction;
import cn.blankcat.dto.channel.Channel;
import cn.blankcat.dto.forum.ForumAuditResult;
import cn.blankcat.dto.forum.Post;
import cn.blankcat.dto.forum.Reply;
import cn.blankcat.dto.forum.Thread;
import cn.blankcat.dto.guild.Guild;
import cn.blankcat.dto.interaction.Interaction;
import cn.blankcat.dto.member.Member;
import cn.blankcat.dto.message.Message;
import cn.blankcat.dto.message.MessageAudit;
import cn.blankcat.dto.message.MessageDelete;
import cn.blankcat.dto.websocket.*;
import cn.blankcat.websocket.v1.WebsocketApi;
import cn.blankcat.websocket.v1.WebsocketApiData;
import cn.blankcat.websocket.v1.handler.AbstractWebsocketHandler;
import cn.blankcat.websocket.v1.handler.ReadyHandler;
import cn.blankcat.websocket.v1.handler.WebsocketHandler;
import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.ObjectMapper;
import okhttp3.*;
import okio.ByteString;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import javax.annotation.Nullable;
import java.io.IOException;
import java.util.*;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;

public class WebsocketService implements WebsocketApi {

    private static final Logger logger = LoggerFactory.getLogger("WebsocketService");
    public static final Map<Class<?>, List<AbstractWebsocketHandler<?>>> handlers = new HashMap<>();
    private final WebsocketApiData websocketApiData = new WebsocketApiData();
    private final ExecutorService fixedThreadPool = Executors.newFixedThreadPool(3);

    public WebsocketApiData getWebsocketApiData() {
        return websocketApiData;
    }

    static {
        handlers.put(Guild.class, new ArrayList<>());
        handlers.put(Message.class, new ArrayList<>());
        handlers.put(MessageDelete.class, new ArrayList<>());
        handlers.put(MessageAudit.class, new ArrayList<>());
        handlers.put(Channel.class, new ArrayList<>());
        handlers.put(Member.class, new ArrayList<>());
        handlers.put(AudioAction.class, new ArrayList<>());
        handlers.put(Thread.class, new ArrayList<>());
        handlers.put(Post.class, new ArrayList<>());
        handlers.put(Reply.class, new ArrayList<>());
        handlers.put(ForumAuditResult.class, new ArrayList<>());
        handlers.put(Interaction.class, new ArrayList<>());
        handlers.put(WSReadyData.class, new ArrayList<>());
        handlers.put(WSHelloData.class, new ArrayList<>());
        handlers.put(WSPayload.class, new ArrayList<>());
        new ReadyHandler().registry();
    }

    private class MyWebSocketListener extends WebSocketListener {
        @Override
        public void onOpen(WebSocket webSocket, Response response) {
            super.onOpen(webSocket, response);
        }

        @Override
        public void onMessage(WebSocket webSocket, String text) {
            //logger.info("收到信息-->[{}]", text);
            websocketEventDispatcher(text);
            super.onMessage(webSocket, text);
        }

        @Override
        public void onMessage(WebSocket webSocket, ByteString bytes) {
            super.onMessage(webSocket, bytes);
        }

        @Override
        public void onClosing(WebSocket webSocket, int code, String reason) {
            super.onClosing(webSocket, code, reason);
        }

        @Override
        public void onClosed(WebSocket webSocket, int code, String reason) {
            super.onClosed(webSocket, code, reason);
        }

        @Override
        public void onFailure(WebSocket webSocket, Throwable t, @Nullable Response response) {
            if (response != null && response.body() != null) {
                MediaType mediaType = response.body().contentType();
                try {
                    String responseString = response.body().string();
                    if (response.code() > 209) {
                        logger.warn("请求错误, 错误内容为:{}", responseString);
                    }
                    super.onFailure(webSocket, t, response.newBuilder().body(ResponseBody.create(mediaType, responseString)).build());
                } catch (IOException e) {
                    e.printStackTrace();
                }
            }
        }
    }

    private void websocketEventDispatcher(String text) {
        ObjectMapper mapper = new ObjectMapper();
        try {
            WSPayload wsPayload = mapper.readValue(text, WSPayload.class);
            // 保存当前seq
            if (wsPayload.getSeq() > 0){
                session().setLastSeq(wsPayload.getSeq());
            }
            WSEventType wsEventType = WSEventType.ofValue(wsPayload.getType());
            Class<?> clazz = WSEvent.eventClassMap.get(wsEventType);
            if (clazz == null){
                buildInHandler(wsPayload);
            }else {
                for (WebsocketHandler handler : handlers.get(clazz)){
                    handler.handle(text, this);
                }
            }

        } catch (JsonProcessingException e) {
            e.printStackTrace();
        }
    }

    private void buildInHandler(WSPayload payload) throws JsonProcessingException {
        WSOPCode wsOPCode = WSOPCode.ofValue(payload.getOpCode());
        switch (wsOPCode){
            case WSHello: {
                // 接收到 hello 后需要开始发心跳
                fixedThreadPool.execute(()->{
                    Object obj = payload.getData();
                    WSHelloData helloData = mapper.convertValue(obj, WSHelloData.class);
                    this.keepHeart(helloData.getHeartbeatInterval());
                });
                break;
            }
            case WSHeartbeatAck:{
                // 心跳 ack 不需要业务处理
                break;
            }
            case WSReconnect:{
                // 达到连接时长，需要重新连接，此时可以通过 resume 续传原连接上的事件
                resume();
                break;
            }
            case WSInvalidSession:{
                // 无效的 sessionLog，需要重新鉴权
                identify();
                break;
            }
        }
    }

    private void keepHeart(int period) {
        logger.info("启动心跳进程成功,即将以{}ms的间隔发送心跳...", period * 4L / 5);
        websocketApiData.getHeartBeatTicker().schedule(new TimerTask() {
            @Override
            public void run() {
                // 持续发送心跳
                WSPayload<Long> wsPayload = new WSPayload<>();
                wsPayload.setOpCode(WSOPCode.WSHeartbeat.getValue());
                wsPayload.setData(session().getLastSeq());
                write(wsPayload);
            }
        }, 1000, period * 4L / 5);
    }

    private void setupWebsocketClient(Session session){
        OkHttpClient httpClient = new OkHttpClient.Builder()
                .readTimeout(3, TimeUnit.SECONDS)
                .writeTimeout(3, TimeUnit.SECONDS)
                .connectTimeout(3, TimeUnit.SECONDS)
                .build();
        WebSocket websocket = httpClient.newWebSocket(new Request.Builder().get().url(session.getUrl()).build(), new MyWebSocketListener());
        websocketApiData.setConn(websocket);
    }


    public WebsocketApi init(Session session) {
        websocketApiData.setSession(session);
        websocketApiData.setHeartBeatTicker(new Timer());
        setupWebsocketClient(session);
        return this;
    }

    @Override
    public void connect() {
        fixedThreadPool.execute(()->{
            this.keepHeart(41250);
        });
    }

    @Override
    public boolean identify() {
        Session session = session();
        if (session.getIntent() == Intent.IntentNone.getValue()){
            session.setIntent(Intent.IntentGuilds.getValue());
        }

        WSIdentityData data = new WSIdentityData();
        data.setToken(session.getToken().getRealString());
        data.setIntents(session.getIntent());
        Long[] shard = {session.getShards().getShardID(), session.getShards().getShardCount()};
        data.setShard(shard);

        WSPayload<WSIdentityData> payload = new WSPayload<>(data);
        payload.setData(data);
        payload.setOpCode(WSOPCode.WSIdentity.getValue());
        payload.setType("");
        logger.info("正在进行鉴权...");
        return write(payload);
    }

    @Override
    public Session session() {
        return websocketApiData.getSession();
    }

    @Override
    public boolean resume() {
        Session session = session();

        WSResumeData data = new WSResumeData();
        data.setToken(session.getToken().getRealString());
        data.setSessionId(session.getId());
        data.setSeq(session.getLastSeq());

        WSPayload<WSResumeData> payload = new WSPayload<>(data);
        payload.setData(data);
        payload.setOpCode(WSOPCode.WSResume.getValue());
        payload.setType("");
        logger.info("正在恢复连接...");
        return write(payload);
    }

    @Override
    public boolean write(WSPayload payload) {
        try {
            String json = mapper.writeValueAsString(payload);
            boolean send = websocketApiData.getConn().send(json);
            logger.info("{}(session)写入[{}]信息:{}", websocketApiData.getSession().getId(), WSOPCode.ofValue(payload.getOpCode()), json);
            return send;
        } catch (JsonProcessingException e) {
            e.printStackTrace();
        }
        return false;
    }

    @Override
    public boolean close() {
        return websocketApiData.getConn().close(100, null);
    }




}
