package cn.allbs.websocket.util;

import cn.hutool.core.date.DateUtil;
import lombok.extern.slf4j.Slf4j;
import org.springframework.web.socket.WebSocketSession;

import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;

/**
 * @author ChenQi
 */
@Slf4j
public class WebSocketSessionManager {

    /**
     * 保存连接 session 的地方
     */
    private static final Map<String, WebSocketSession> SESSION_POOL = new ConcurrentHashMap<>();

    /**
     * 静态变量，用来记录当前在线连接数。应该把它设计成线程安全的。
     */
    private static int onlineCount = 0;

    /**
     * 添加 session
     *
     * @param userName 用户名称
     */
    public static void add(String userName, WebSocketSession session) {
        log.info(DateUtil.now() + "新增连接,连接用户为" + userName);
        // 增加连接数
        addOnlineCount();
        // 添加 session
        SESSION_POOL.put(userName, session);
    }

    /**
     * 删除 session,会返回删除的 session
     *
     * @param userName 用户名
     * @return websocket session
     */
    public static WebSocketSession remove(String userName) {
        // 减少连接数
        subOnlineCount();
        // 删除 session
        return SESSION_POOL.remove(userName);
    }

    /**
     * 删除并同步关闭连接
     *
     * @param userName 指定用户
     */
    public static void removeAndClose(String userName) {
        WebSocketSession session = remove(userName);
        if (session != null) {
            try {
                // 关闭连接
                session.close();
            } catch (IOException e) {
                e.printStackTrace();
            }
        }
    }

    /**
     * 获得 session
     *
     * @param userName 用户名称
     * @return WebSocketSession
     */
    public static WebSocketSession get(String userName) {
        // 获得 session
        return SESSION_POOL.get(userName);
    }

    /**
     * 查询所有在线用户session
     *
     * @return 在线用户session集合
     */
    public static List<WebSocketSession> getAll() {
        List<WebSocketSession> list = new ArrayList<>();
        // 获取所有session
        SESSION_POOL.forEach((k, v) -> list.add(v));
        return list;
    }

    /**
     * 查询所有在线用户名称
     *
     * @return 在线用户名称list
     */
    public static List<String> getAllUserName() {
        List<String> userList = new ArrayList<>();
        SESSION_POOL.forEach((k, v) -> userList.add(k));
        return userList;
    }

    /**
     * 统计在线用户数量
     *
     * @return 用户数量
     */
    public static synchronized int getOnlineCount() {
        return onlineCount;
    }

    /**
     * 在线用户加1
     */
    public static synchronized void addOnlineCount() {
        onlineCount++;
    }

    /**
     * 在线用户减1
     */
    public static synchronized void subOnlineCount() {
        onlineCount--;
    }
}
