package cn.xnatural.xnet;

import cn.xnatural.sched.Sched;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.io.IOException;
import java.net.*;
import java.nio.channels.*;
import java.security.SecureRandom;
import java.time.Duration;
import java.util.*;
import java.util.concurrent.*;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.stream.Collectors;
import java.util.stream.Stream;


/**
 * TCP (AIO) server. Listens on a TCP port and hands every accepted connection to the registered
 * protocol handlers ({@link XNetBase}), which are tried in registration order.
 */
public class XNet implements AutoCloseable {
    protected static final Logger log = LoggerFactory.getLogger(XNet.class);
    /**
     * Attribute (configuration) map
     */
    protected final Map<String, Object> attrs;
    /**
     * Protocol handlers; each new connection is offered to them in list order
     */
    protected final List<XNetBase> servers = new LinkedList<>();
    /**
     * Worker thread pool
     */
    protected final ExecutorService exec;
    /**
     * Time-based task scheduler
     */
    protected final Sched sched;
    protected AsynchronousChannelGroup cg;
    protected AsynchronousServerSocketChannel ssc;

    /**
     * Listen address [host]:port, parsed lazily from attribute "hp" (default ":7001")
     */
    protected final Lazier<Hp> _hp = new Lazier<>(() -> Hp.parse(getStr("hp").orElse(":7001").trim()));
    /**
     * Completion handler for newly accepted TCP connections
     */
    protected final CompletionHandler<AsynchronousSocketChannel, XNet> acceptor = new CompletionHandler<AsynchronousSocketChannel, XNet>() {

        @Override
        public void completed(AsynchronousSocketChannel channel, XNet srv) { doAccept(channel); }

        @Override
        public void failed(Throwable ex, XNet srv) {
            // A closed channel is the normal outcome of close(); anything else is logged.
            // NOTE(review): accept() is not re-armed here, so after any other failure the
            // server stops accepting new connections.
            if (!(ex instanceof ClosedChannelException)) {
                log.error(ex.getMessage() == null ? ex.getClass().getSimpleName() : ex.getMessage(), ex);
            }
        }
    };


    /**
     * Create an {@link XNet}
     * @param attrs attribute map (a new {@link ConcurrentHashMap} is used when null)
     *              hp: listen address [host]:port, default ":7001"
     *              backlog: pending-connection queue length passed to bind, default 128
     *              cluster.name: when present, {@link #cluster()} is registered on start
     *              delimiter: delimiter
     *              writeTimeout: data write timeout, in milliseconds
     *              connection.maxIdle: maximum connection lifetime
     * @param exec worker thread pool; when null a pool with threads named "xnet-N" is created
     * @param sched time-based task scheduler; when null one backed by the worker pool is created
     */
    public XNet(Map<String, Object> attrs, ExecutorService exec, Sched sched) {
        this.attrs = attrs == null ? new ConcurrentHashMap<>() : attrs;
        // NOTE: with an unbounded LinkedBlockingQueue a ThreadPoolExecutor never grows past its
        // core size, so this default pool effectively runs at most 4 threads (max 8 is unused).
        this.exec = exec == null ? new ThreadPoolExecutor(
                4,8, 4, TimeUnit.HOURS,
                new LinkedBlockingQueue<>(),
                new ThreadFactory() {
                    final AtomicInteger i = new AtomicInteger(1);
                    @Override
                    public Thread newThread(Runnable r) { return new Thread(r, "xnet-" + i.getAndIncrement()); }
                }
        ) : exec;
        this.sched = sched == null ? new Sched(this.exec) : sched;
    }


    /**
     * {@link #XNet(Map, ExecutorService, Sched)}
     * @param hp host:port. eg: localhost:7001 or 127.0.0.1:7001 or :7001 (null = default ":7001")
     * @param exec worker thread pool
     */
    public XNet(String hp, ExecutorService exec) {
        this(hpAttrs(hp), exec, null);
    }


    /**
     * {@link #XNet(Map, ExecutorService, Sched)}
     * @param hp host:port. eg: localhost:7001 or 127.0.0.1:7001 or :7001
     */
    public XNet(String hp) { this(hp, null); }


    /**
     * Build the mutable attribute map for the hp-only constructors.
     * A null hp is simply left out so the default listen address applies
     * (Collectors.toMap, used previously, throws NullPointerException on null values).
     */
    private static Map<String, Object> hpAttrs(String hp) {
        final Map<String, Object> m = new HashMap<>();
        if (hp != null) m.put("hp", hp);
        return m;
    }


    /**
     * Start: bind the listen address, register ServiceLoader-discovered protocol handlers
     * (plus {@link #cluster()} when attribute "cluster.name" is set), start every handler and
     * begin accepting connections.
     * @return this
     * @throws IllegalStateException if already started
     * @throws RuntimeException wrapping any {@link IOException} raised while starting
     */
    public XNet start() {
        if (ssc != null) throw new IllegalStateException(XNet.class.getSimpleName() + " is already running");
        try {
            // Reuse the group if an earlier (failed) start attempt already created it
            if (cg == null) cg = AsynchronousChannelGroup.withThreadPool(exec);
            ssc = AsynchronousServerSocketChannel.open(cg);
            ssc.setOption(StandardSocketOptions.SO_REUSEADDR, true);
            // ssc.setOption(StandardSocketOptions.SO_RCVBUF, getInteger("so_revbuf").orElse(1024 * 1024 * 1));

            final Hp hp = getHp();
            // Empty host -> bind the wildcard address on the port
            InetSocketAddress addr = (hp.host != null && !hp.host.isEmpty()) ?
                    new InetSocketAddress(hp.host, hp.port) : new InetSocketAddress(hp.port);

            ssc.bind(addr, getInteger("backlog").orElse(128));
            ServiceLoader<XNetBase> plugins = ServiceLoader.load(XNetBase.class);
            plugins.iterator().forEachRemaining(servers::add);
            if (getStr("cluster.name").isPresent()) {
                cluster();
            }
            for (XNetBase server : servers) { server.start(); }
            log.info("Start listen {}", getHp());
            exec.execute(this::accept);
        } catch (IOException ex) {
            // Release the half-opened socket and clear it, so the instance is not left looking
            // "already running" and start() can be retried
            if (ssc != null) {
                try { ssc.close(); } catch (IOException e) { ex.addSuppressed(e); }
                ssc = null;
            }
            throw new RuntimeException(XNet.class.getSimpleName() + " starting error", ex);
        }
        return this;
    }


    /**
     * Return the registered {@link HttpServer}, registering a new one if none exists yet
     */
    public HttpServer http() {
        return (HttpServer) servers.stream().filter(o -> o instanceof HttpServer).findFirst().orElseGet(() -> {
            HttpServer srv = new HttpServer(this);
            servers.add(srv);
            return srv;
        });
    }

    /**
     * Return the registered {@link Cluster}, registering a new one if none exists yet.
     * Ensures {@link #http()} is registered first.
     */
    public Cluster cluster() {
        http();
        return (Cluster) servers.stream().filter(o -> o instanceof Cluster).findFirst().orElseGet(() -> {
            Cluster srv = new Cluster(this);
            servers.add(srv);
            return srv;
        });
    }

    /**
     * Return the registered {@link AP}, registering a new one if none exists yet.
     * Ensures {@link #cluster()} is registered first.
     */
    public AP ap() {
        cluster();
        return (AP) servers.stream().filter(o -> o instanceof AP).findFirst().orElseGet(() -> {
            AP srv = new AP(this);
            servers.add(srv);
            return srv;
        });
    }

//    public CP cp() {
//        cluster();
//        return (CP) servers.stream().filter(o -> o instanceof CP).findFirst().orElseGet(() -> {
//            CP srv = new CP(this);
//            servers.add(srv);
//            return srv;
//        });
//    }


    /**
     * Add a custom protocol handler
     */
    public XNet add(XNetBase plugin) {
        servers.add(plugin);
        return this;
    }


    /**
     * Single entry point for running tasks on the worker pool
     */
    protected void exec(Runnable fn) {
        exec.execute(fn);
    }


    /**
     * Schedule fn to run once after the given delay
     */
    protected void sched(Duration time, Runnable fn) {
        sched.after(time, fn);
    }


    /**
     * Close: shut down the channel group and the listen socket (if started), close every protocol
     * handler, then the scheduler and the worker pool. Safe to call when start() was never called.
     * A failure in one step is logged and does not prevent the remaining steps.
     */
    @Override
    public void close() {
        if (cg != null) {
            // Terminating a group created with withThreadPool also shuts its executor down
            try { cg.shutdown(); cg.shutdownNow(); } catch (Exception e) {
                log.error("", e);
            }
        }
        if (ssc != null) {
            try { ssc.close(); } catch (IOException e) {
                log.error("", e);
            }
        }
        for (XNetBase server : servers) {
            try {
                server.close();
            } catch (Exception e) {
                log.error("Close protocol error: " + server, e);
            }
        }
        sched.close();
        exec.shutdown();
    }


    /**
     * Wait (asynchronously) for the next incoming connection
     */
    protected void accept() { ssc.accept(this, acceptor); }


    /**
     * Handle a newly accepted connection: offer it to each protocol handler in turn until one
     * claims it (its callback not reporting false); close it if none does. The next accept is
     * always re-armed, even when there are no handlers or a handler throws.
     * @param channel {@link AsynchronousSocketChannel}
     */
    protected void doAccept(final AsynchronousSocketChannel channel) {
        try {
            if (servers.isEmpty()) {
                log.warn("No found protocol");
                closeChannel(channel, "Not found protocol error: ");
                return;
            }
            new Runnable() {
                int i = 0;
                @Override
                public void run() {
                    if (i >= servers.size()) { // no handler recognised the protocol
                        closeChannel(channel, "Unknown protocol error: ");
                        return;
                    }
                    final AtomicBoolean flag = new AtomicBoolean(false);
                    final XNetBase server = servers.get(i);
                    try {
                        server.accept(channel, f -> {
                            // Advance to the next handler at most once per handler
                            if (!f && flag.compareAndSet(false, true)) {
                                i++; run();
                            }
                        });
                    } catch (RuntimeException ex) {
                        log.error("Protocol accept error: " + channel, ex);
                        // Close only if this handler had not already passed the connection on
                        if (flag.compareAndSet(false, true)) {
                            closeChannel(channel, "Close channel error: ");
                        }
                    }
                }
            }.run();
        } finally {
            // Keep accepting new connections no matter how this one was handled
            accept();
        }
    }


    /**
     * Close a channel, logging (not throwing) any close failure with the given message prefix
     */
    private static void closeChannel(AsynchronousSocketChannel channel, String errMsg) {
        try {
            channel.close();
        } catch (IOException e) {
            log.error(errMsg + channel, e);
        }
    }


    /**
     * hp -> host:port
     */
    public Hp getHp() { return _hp.get(); }


    /**
     * Get this machine's IPv4 address: the first IPv4 address of a network interface that is
     * up, not loopback and not virtual.
     * @return the address, or null if none is found or the interfaces cannot be listed
     */
    public static String ipv4() {
        try {
            for (Enumeration<NetworkInterface> en = NetworkInterface.getNetworkInterfaces(); en.hasMoreElements(); ) {
                NetworkInterface current = en.nextElement();
                if (!current.isUp() || current.isLoopback() || current.isVirtual()) continue;
                Enumeration<InetAddress> addresses = current.getInetAddresses();
                while (addresses.hasMoreElements()) {
                    InetAddress addr = addresses.nextElement();
                    if (addr.isLoopbackAddress()) continue;
                    if (addr instanceof Inet4Address) {
                        return addr.getHostAddress();
                    }
                }
            }
        } catch (SocketException e) {
            log.error("", e);
        }
        return null;
    }


    /**
     * Whether ex, or any exception in its cause chain, is a network connection error
     * (connect refused, connection reset, write failed, broken pipe, connect timed out).
     */
    public static boolean isConnectError(Throwable ex) {
        for (Throwable e = ex; e != null; ) {
            if (
                    e instanceof ConnectException
                    || (e instanceof SocketException && e.getMessage() != null && e.getMessage().contains("Connection reset"))
                    || (e instanceof SocketException && e.getMessage() != null && e.getMessage().contains("Write failed")) // data never reached the peer
                    || (e instanceof SocketException && e.getMessage() != null && e.getMessage().contains("Broken pipe")) // connection dropped
                    || (e instanceof SocketTimeoutException && e.getMessage() != null && e.getMessage().contains("connect timed out"))
            ) { // connection errors are safe to retry
                return true;
            }
            e = e.getCause();
        }
        return false;
    }


    protected static final char[] CS = "0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ".toCharArray();
    protected static final SecureRandom SR = new SecureRandom();

    /**
     * 21-character nano id over [0-9a-zA-Z]
     */
    public static String nanoId() { return nanoId(21, CS); }

    /**
     * nano id of the given length over [0-9a-zA-Z]
     */
    public static String nanoId(int len) { return nanoId(len, CS); }

    /**
     * Generate a nano id: a random string drawn uniformly from the given alphabet using
     * {@link SecureRandom}.
     * @param len length of the generated id, must be >= 1
     * @param CS alphabet, must be non-empty
     * @throws IllegalArgumentException if len < 1 or CS is null/empty
     */
    public static String nanoId(int len, char[] CS) {
        if (len < 1) throw new IllegalArgumentException("Param len must >= 1");
        if (CS == null || CS.length < 1) throw new IllegalArgumentException("Param CS required");
        if (CS.length > 256) {
            // A single random byte can only address 256 symbols (and is sign-extended below),
            // so draw each index directly; otherwise part of the alphabet is never produced.
            final StringBuilder sb = new StringBuilder(len);
            for (int i = 0; i < len; i++) sb.append(CS[SR.nextInt(CS.length)]);
            return sb.toString();
        }
        // Smallest (2^k - 1) covering every alphabet index; masked bytes outside the alphabet are
        // rejected, which keeps the distribution uniform
        final int mask = (2 << (int)Math.floor(Math.log(CS.length - 1) / Math.log(2))) - 1;
        // Random bytes per batch, over-provisioned (factor 1.6) to cover rejected bytes
        final int step = (int)Math.ceil(1.6 * mask * len / CS.length);
        final StringBuilder sb = new StringBuilder(len);
        while(true) {
            byte[] bytes = new byte[step];
            SR.nextBytes(bytes);
            for(int i = 0; i < step; ++i) {
                int idx = bytes[i] & mask;
                if (idx < CS.length) {
                    sb.append(CS[idx]);
                    if (sb.length() == len) {
                        return sb.toString();
                    }
                }
            }
        }
    }


    /**
     * Get an attribute value (null if absent)
     */
    public Object getAttr(String key) { return attrs.get(key); }


    /**
     * Set an attribute value
     */
    public XNet setAttr(String key, Object value) { attrs.put(key, value); return this; }


    protected Optional<String> getStr(String key) {
        return Optional.ofNullable(getAttr(key)).map(Object::toString);
    }

    protected Optional<Integer> getInteger(String key) {
        return Optional.ofNullable(getAttr(key)).map(Object::toString).map(Integer::valueOf);
    }

    protected Optional<Long> getLong(String key) {
        return Optional.ofNullable(getAttr(key)).map(Object::toString).map(Long::valueOf);
    }

    protected Optional<Boolean> getBoolean(String key) {
        return Optional.ofNullable(getAttr(key)).map(Object::toString).map(Boolean::valueOf);
    }

    protected Optional<Double> getDouble(String key) {
        return Optional.ofNullable(getAttr(key)).map(Object::toString).map(Double::valueOf);
    }
}
