/**
 * WebSocketFactory.java
 */
package itez.plat.socket.websocket;

import java.util.List;
import java.util.concurrent.ConcurrentHashMap;

import javax.websocket.Session;

import itez.core.runtime.service.Ioc;
import itez.kit.EClass;
import itez.kit.ELog;
import itez.kit.EStr;
import itez.kit.log.ELogBase;
import itez.plat.socket.model.Channels;
import itez.plat.socket.model.Tokens;
import itez.plat.socket.service.ChannelsService;
import itez.plat.socket.service.TokensService;

/**
 * <p>
 * Unified WebSocket handling factory.
 * </p>
 * <p>
 * Keeps the registry of connected clients, routes lifecycle events
 * (online / offline / message / error) to the per-channel {@link ISocketService},
 * and delivers outgoing {@link SocketMsg}s either 1:1 or as a broadcast to every
 * token of the sender's channel. All state is static: the client and service
 * caches are {@link ConcurrentHashMap}s and the online counter is guarded by the
 * class monitor.
 * </p>
 * 
 * <p>Copyright(C) 2017-2020 <a href="http://www.itez.com.cn">上游科技</a></p>
 * 
 * @author		<a href="mailto:netwild@qq.com">Z.Mingyu</a>
 * @date		2020-10-18 13:57:19
 */
public class WebSocketFactory {
	
	private static final ELogBase log = ELog.log(WebSocketFactory.class);

	/**
	 * Client registry.
	 * Every connected client is tracked in this map.
	 * key: sid (WebSocket session id)
	 * val: SocketClient
	 */
	static ConcurrentHashMap<String, SocketClient> clients = new ConcurrentHashMap<String, SocketClient>();
	
	/**
	 * Cache of channel handler services.
	 * key: channel code
	 * val: service instance
	 */
	static ConcurrentHashMap<String, ISocketService> sers = new ConcurrentHashMap<String, ISocketService>();
	
	/**
	 * Heartbeat-check worker, started lazily by {@link #startHeartCheck()}.
	 */
	static HeartCheck heartCheck = new HeartCheck();
	
	/**
	 * Token service.
	 */
	static TokensService tokenSer = Ioc.get(TokensService.class);
	/**
	 * Channel service.
	 */
	static ChannelsService channSer = Ioc.get(ChannelsService.class);

	/**
	 * Online counter; read and written only through the synchronized accessors below.
	 */
	private static int onlineCount = 0;
	
	/**
	 * Code of the built-in default channel, served by {@link SocketServiceDef}.
	 */
	public static final String DEF_CHANNEL = "DefaultChannel";
	
	/**
	 * <p>
	 * Registers a newly connected client.
	 * </p>
	 * Marks the token online, caches the client, notifies the channel service
	 * (if one exists) and makes sure the heartbeat check is running.
	 * 
	 * @param session the newly opened WebSocket session
	 * @param token   the communication token the client connected with
	 * @return the online count after registration
	 */
	public static int online(Session session, Tokens token){
		String sid = session.getId();
		tokenSer.online(token, sid);
		SocketClient client = new SocketClient(token, session);
		addOnlineCount();
		addClientCache(client);
		
		// Channel-specific logic
		ISocketService ser = getSocketService(token.getChannel());
		if(ser != null) ser.online(token, client);
		
		// Start the heartbeat check (no-op if it is already running)
		startHeartCheck();
		
		return getOnlineCount();
	}
	
	/**
	 * <p>
	 * Unregisters a disconnected client.
	 * </p>
	 * Does nothing but return the current count when the session is not cached
	 * (e.g. it was already evicted as a dead receiver).
	 * 
	 * @param session the closed WebSocket session
	 * @param token   the session's communication token
	 * @return the online count after unregistration
	 */
	public static int offline(Session session, Tokens token){
		String sid = session.getId();
		SocketClient client = getClient(sid);
		if(client == null) return getOnlineCount();
		tokenSer.offline(token);

		// Notify the channel service before evicting, so the client is still
		// resolvable via getClient(sid) during the callback.
		ISocketService ser = getSocketService(token.getChannel());
		if(ser != null) ser.offline(token, client);

		evictClient(sid, client);
		return getOnlineCount();
	}
	
	/**
	 * <p>
	 * Handles an incoming text message.
	 * </p>
	 * Refreshes the client's ping timestamp, answers heartbeat packets with a
	 * pong, and forwards every other message to the channel service.
	 * Messages from a session that is not in the client cache are logged and dropped.
	 * 
	 * @param session the session the message arrived on
	 * @param token   the session's communication token
	 * @param message the received text
	 */
	public static void onMessage(Session session, Tokens token, String message){
		String sid = session.getId();
		SocketClient client = getClient(sid);
		if(client == null){
			log.error("发送者不存在：{}", sid);
			return;
		}

		client.updatePing();
		// The text is already decoded. Re-encoding it to UTF-8 and decoding it with the
		// platform default charset would corrupt non-ASCII text on non-UTF-8 JVMs.
		SocketMsg msg = SocketMsg.parse(sid, message);
		if(msg.isPingPong()){ // heartbeat packet
			SocketMsg pong = new SocketMsg(null, sid, "").setPong();
			sendMessage(pong);
		}else{ // regular message
			ISocketService ser = getSocketService(token.getChannel());
			if(ser != null) ser.onMessage(client, msg);
		}
	}
	
	/**
	 * <p>
	 * Handles a transport error.
	 * </p>
	 * Forwards the error to the channel service, if any. The client passed on is
	 * null when the session is not in the client cache.
	 * 
	 * @param session the session on which the error occurred
	 * @param token   the session's communication token
	 * @param error   the error that was raised
	 */
	public static void onError(Session session, Tokens token, Throwable error){
		ISocketService ser = getSocketService(token.getChannel());
		if(ser != null) ser.onError(token, getClient(session.getId()), error);
	}
	
	/**
	 * <p>
	 * Sends a message.
	 * </p>
	 * A receiver id of {@code "*"} broadcasts to every token returned for the
	 * sender's domain and channel; any other id is a 1:1 delivery.
	 * 
	 * @param msg the message to send
	 */
	public static void sendMessage(SocketMsg msg){
		String toId = msg.getToId();
		if(EStr.isEmpty(toId)){
			log.error("未指定接收者");
		}else if(!toId.equals("*")){ // 1:1
			sendMessage(toId, msg);
		}else{ // broadcast
			String fromId = msg.getFromId();
			if(EStr.isEmpty(fromId)){
				log.error("未指定发送者");
				return;
			}
			SocketClient client = getClient(fromId);
			if(client == null){
				log.error("发送者不存在：{}", fromId);
			}else{
				List<Tokens> tokens = tokenSer.getTokenByChannel(client.getDomain(), client.getChannel());
				if(tokens != null) tokens.forEach(t -> sendMessage(t.getSid(), msg));
			}
		}
	}
	
	/**
	 * <p>
	 * Sends a message to a single receiver (1:1).
	 * </p>
	 * A receiver whose connection is no longer open has its token marked offline
	 * and is evicted from the cache.
	 * 
	 * @param toId receiver session id (may be null/empty, which is logged as "not found")
	 * @param msg  the message to send
	 */
	private static void sendMessage(String toId, SocketMsg msg){
		// getClient() tolerates a null/empty id; ConcurrentHashMap.get(null) would throw NPE
		// and abort an in-progress broadcast.
		SocketClient client = getClient(toId);
		if(client == null){
			log.error("接收者不存在：{}", toId);
		}else if(!client.isOpen()){
			tokenSer.offline(client.getToken());
			evictClient(toId, client);
			log.error("接收者已掉线：{}", toId);
		}else{
			client.sendMessage(msg);
		}
	}
	
	/**
	 * Removes {@code client} from the cache under {@code sid} and decrements the
	 * online counter, but only if this call actually performed the removal.
	 * offline() and the dead-receiver branch of sendMessage() can hit the same sid
	 * concurrently; the conditional removal keeps the counter from dropping twice.
	 * 
	 * @param sid    session id of the client (non-null)
	 * @param client the cached client expected under {@code sid}
	 */
	private static void evictClient(String sid, SocketClient client){
		if(clients.remove(sid, client)) subOnlineCount();
	}
	
	/**
	 * <p>
	 * Returns the handler service for a channel, creating and caching it on first use.
	 * </p>
	 * 
	 * @param channel channel code
	 * @return the service, or null if the channel is empty, unknown, or has no
	 *         usable service class
	 */
	private static ISocketService getSocketService(String channel){
		if(EStr.isEmpty(channel)) return null;
		ISocketService ser = sers.get(channel);
		if(ser != null) return ser;
		ser = createSocketService(channel);
		if(ser == null) return null;
		// Cache under the requested channel code so every channel keeps its own service;
		// putIfAbsent keeps a single instance when two threads race on the first lookup.
		ISocketService existing = sers.putIfAbsent(channel, ser);
		return existing != null ? existing : ser;
	}
	
	/**
	 * Instantiates the handler service for {@code channel}. For DEF_CHANNEL this is
	 * SocketServiceDef. Otherwise it is the class configured on the matching
	 * Channels record.
	 * 
	 * @param channel non-empty channel code
	 * @return a new service instance, or null if the channel is unknown, has no class
	 *         configured, or the created object does not implement ISocketService
	 */
	private static ISocketService createSocketService(String channel){
		if(DEF_CHANNEL.equals(channel)) return new SocketServiceDef();
		Channels chann = channSer.findByCode(channel);
		if(chann == null) return null;
		String clazz = chann.getClazz();
		if(EStr.isEmpty(clazz)) return null;
		Object t = EClass.newInstance(clazz);
		return (t instanceof ISocketService) ? (ISocketService) t : null;
	}
	
	/**
	 * <p>
	 * Returns the token object for a communication token id.
	 * The token object must have been created beforehand.
	 * </p>
	 * 
	 * @param tokenId token id
	 * @return the token from the token service, or null if {@code tokenId} is empty
	 */
	public static Tokens getToken(String tokenId){
		if(EStr.isEmpty(tokenId)) return null;
		Tokens tokens = tokenSer.getToken(tokenId);
		return tokens;
	}
	
	/**
	 * Starts the heartbeat check if it is not already marked runnable.
	 * Synchronized so concurrent online() calls cannot both pass the check and
	 * call start() twice.
	 * NOTE(review): if HeartCheck is a Thread that can stop and be flagged
	 * not-runnable again, a second start() on it would fail — confirm its lifecycle.
	 */
	public static synchronized void startHeartCheck(){
		if(!heartCheck.isRunAble()){
			heartCheck.setRunAble(true);
			heartCheck.start();
		}
	}
	
	/**
	 * <p>
	 * Returns the cached client for a session id.
	 * </p>
	 * 
	 * @param sid session id
	 * @return the client, or null if {@code sid} is empty or not cached
	 */
	public static SocketClient getClient(String sid){
		if(EStr.isEmpty(sid)) return null;
		SocketClient client = clients.get(sid);
		return client;
	}
	
	/**
	 * Caches a client under its session id.
	 * 
	 * @param client the client to cache
	 */
	public static void addClientCache(SocketClient client){
		clients.put(client.getSid(), client);
	}
	
	/**
	 * Removes a client from the cache; ignores an empty id. Does not touch the counter.
	 * 
	 * @param sid session id
	 */
	public static void removeClientCache(String sid){
		if(EStr.isEmpty(sid)) return;
		clients.remove(sid);
	}
	
	/**
	 * @return the current online count
	 */
	public static synchronized int getOnlineCount() {
		return onlineCount;
	}
	
	/**
	 * Increments the online count.
	 */
	static synchronized void addOnlineCount() {
		onlineCount++;
	}
	
	/**
	 * Decrements the online count.
	 */
	static synchronized void subOnlineCount() {
		onlineCount--;
	}
	
}
