package cn.xnatural.xnet;

import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.channels.AsynchronousSocketChannel;
import java.nio.channels.ClosedChannelException;
import java.nio.channels.CompletionHandler;
import java.util.Queue;
import java.util.concurrent.ConcurrentLinkedQueue;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.function.Consumer;

import static cn.xnatural.xnet.HttpServer.log;

/**
 * Http AIO 数据流
 */
/**
 * Asynchronous (AIO) I/O session for a single HTTP connection.
 * <p>
 * Incoming bytes are read from {@link #channel} into a receive buffer and fed to the decoder of the current
 * {@link HttpRequest}. Once the connection has been upgraded, they go to the decoder of the upgraded
 * {@link Protocol} instead. Outgoing buffers are queued in {@link #waiting} and written one at a time:
 * {@link #writing} guards against overlapping {@code channel.write} calls.
 */
public class HttpIOSession implements AutoCloseable {
    protected final AsynchronousSocketChannel channel;
    public final    HttpServer                server;
    // Request currently being decoded
    protected       HttpRequest               request;
    // Parser of an upgraded protocol (non-null once the connection has been upgraded away from plain HTTP)
    protected Protocol protocol;
    // Protocol confirmation callback: called once with true/false after the first decode attempt, then cleared
    protected Consumer<Boolean> confirm;
    // Time (ms) of the last successful read or accepted write
    protected       Long                      lastUsed = System.currentTimeMillis();
    protected final AtomicBoolean             closed   = new AtomicBoolean(false);
    // true while a channel.write is in flight
    protected final AtomicBoolean             writing  = new AtomicBoolean(false);
    // Queue of buffers waiting to be written
    protected final Queue<ByteBuffer>      waiting     = new ConcurrentLinkedQueue<>();
    // Receive buffer; size from server attribute "receiveMsgBufferSize" (default 512 KiB); released in close()
    protected final Lazier<ByteBuffer> _buf        = new Lazier<>(() -> ByteBuffer.allocate(HttpIOSession.this.server.getAttr("receiveMsgBufferSize", Integer.class,1024 * 512)));
    // Read completion handler
    protected final CompletionHandler<Integer, ByteBuffer> readHandler = new CompletionHandler<Integer, ByteBuffer>() {
        @Override
        public void completed(Integer count, ByteBuffer buf) {
            if (count > 0) {
                lastUsed = System.currentTimeMillis();
                buf.flip();
                doRead(buf);
                // Keep unconsumed bytes at the front of the buffer for the next read
                buf.compact();
                // Only issue the next read after this one has completed, to avoid ReadPendingException
                read();
            } else {
                // count == 0: e.g. the buffer has no remaining space (an upload larger than the buffer)
                // count  < 0: end of stream reached
                // (original note: browsers also seem to send empty data frequently)
                // TODO: needs investigation. No further read is issued here, so the session only goes
                //       away if the channel is already closed or it is closed from elsewhere.
                if (!channel.isOpen()) close();
            }
        }

        @Override
        public void failed(Throwable ex, ByteBuffer buf) {
            // A closed channel is an expected way for a pending read to end; anything else is logged
            if (!(ex instanceof ClosedChannelException)) {
                log.error(ex.getClass().getSimpleName() + " " + getRemoteAddress() + " ->" + getLocalAddress(), ex);
            }
            close();
        }
    };
    // Write completion handler
    protected final CompletionHandler<Integer, ByteBuffer> writeHandler = new CompletionHandler<Integer, ByteBuffer>() {
        @Override
        public void completed(Integer result, ByteBuffer buf) {
            // Linux AIO may not write the whole buffer at once; keep writing until it is drained
            if (buf.hasRemaining()) {
                if (!closed.get()) channel.write(buf, buf, writeHandler);
            } else {
                // This buffer is done: release the write slot and start on the next queued buffer
                writing.set(false);
                trigger();
            }
        }

        @Override
        public void failed(Throwable ex, ByteBuffer buf) {
            if (!(ex instanceof ClosedChannelException)) {
                log.error(ex.getClass().getName() + " " + getRemoteAddress() + " ->" + getLocalAddress(), ex);
            }
            close();
        }
    };


    /**
     * @param channel the accepted client channel; must not be null
     * @param server  the owning server; must not be null
     * @throws NullPointerException if either argument is null
     */
    protected HttpIOSession(AsynchronousSocketChannel channel, HttpServer server) {
        if (channel == null) throw new NullPointerException("Param channel required");
        if (server == null) throw new NullPointerException("Param server required");
        this.channel = channel;
        this.server = server;
    }


    /**
     * Starts receiving and processing data.
     * @param confirm callback told (once) whether the first decode attempt succeeded
     */
    protected void start(Consumer<Boolean> confirm) {
        this.confirm = confirm;
        read();
    }


    /**
     * Closes the session.
     * <p>
     * Before closing, gives pending writes a chance to finish (see {@link #awaitPendingWrites()}). It then
     * shuts down and closes the channel, drops any unsent buffers, releases the receive buffer and calls
     * {@link #doClose(HttpIOSession)} exactly once. Calling this on an already closed session returns immediately.
     */
    @Override
    public void close() {
        // Already closed: nothing can drain the write queue any more, so waiting would only stall the caller
        if (closed.get()) return;
        awaitPendingWrites();
        if (closed.compareAndSet(false, true)) {
            // Best-effort teardown: each step may fail independently (e.g. channel already closed by the peer)
            try { channel.shutdownOutput(); } catch(Exception ignored) {}
            try { channel.shutdownInput(); } catch(Exception ignored) {}
            try { channel.close(); } catch(Exception ignored) {}
            // trigger() is a no-op once closed, so queued buffers can never be sent; release them
            waiting.clear();
            _buf.clear(); // release the receive buffer
            doClose(this);
        }
    }


    /**
     * Waits until no write is in flight and the write queue is empty.
     * Polls every 50 ms, for at most the server attribute {@code closeWaitLimit} ms (default 2000).
     * Stops early if the current thread is interrupted; the interrupt flag is restored.
     */
    private void awaitPendingWrites() {
        final long perWait = 50;
        final long waitLimit = server.getAttr("closeWaitLimit", Long.class, 2000L);
        for (long waited = 0; waited < waitLimit && (writing.get() || !waiting.isEmpty()); waited += perWait) {
            try {
                Thread.sleep(perWait);
            } catch (InterruptedException e) {
                // Preserve the interrupt for callers and stop waiting; proceed with closing right away
                Thread.currentThread().interrupt();
                log.error("", e);
                return;
            }
        }
    }


    /**
     * Hook for subclasses, called once at the end of {@link #close()}.
     * Override to clear references to this {@link HttpIOSession}.
     * @param session the session being closed ({@code this})
     */
    protected void doClose(HttpIOSession session) {}


    /**
     * Sends data to the client.
     * The buffer is queued and written after any earlier buffers. The call is ignored when the session is
     * closed or {@code buf} is null.
     * @param buf data to send
     */
    public void write(ByteBuffer buf) {
        if (closed.get() || buf == null) return;
        lastUsed = System.currentTimeMillis();
        waiting.offer(buf);
        trigger();
    }


    /**
     * Starts writing the next queued buffer, unless a write is already in flight.
     * The write completion handler calls this again, so the queue is drained one buffer at a time.
     */
    protected void trigger() {
        if (closed.get() || waiting.isEmpty()) return;
        // Only one thread may own the write slot at a time
        if (!writing.compareAndSet(false, true)) return;
        ByteBuffer buf = waiting.poll();
        if (buf != null) channel.write(buf, buf, writeHandler);
        else writing.set(false); // queue was drained concurrently; release the slot
    }


    /**
     * Issues the next asynchronous read into the receive buffer.
     * Does nothing if the session is closed, and closes the session if the channel is no longer open.
     */
    protected void read() {
        if (closed.get()) return;
        if (!channel.isOpen()) close();
        else channel.read(_buf.get(), _buf.get(), readHandler);
    }


    /**
     * Reads and parses incoming data.
     * <ul>
     *   <li>If the current request is complete, a new {@link HttpRequest} is started. If the connection has
     *       been upgraded, the bytes are passed to the protocol decoder instead; a decode error there is
     *       logged and closes the session.</li>
     *   <li>After the first successful decode, {@link #confirm} is called with {@code true} and the session is
     *       registered in {@code server.connections}. {@code server.clean()} is called when the connection
     *       count exceeds the server attribute {@code maxConnectionCountToClean} (default 10).</li>
     *   <li>If the first decode fails, {@link #confirm} is called with {@code false}. Later decode errors are
     *       logged and close the session.</li>
     *   <li>A completed request is passed to {@code server.receive}.</li>
     * </ul>
     * @param buf request bytes, in read mode (already flipped)
     */
    protected void doRead(ByteBuffer buf) {
        if (request == null) request = new HttpRequest(this);
        else if (request.decoder.complete) {
            if (protocol == null) request = new HttpRequest(this);
            else { // The connection has been upgraded to another protocol
                try {
                    protocol.decoder(this).decode(buf);
                } catch (Exception e) {
                    log.error(request.getUpgrade() + " decode error. from: " + getRemoteAddress(), e);
                    close();
                }
                return;
            }
        }
        try {
            request.decoder.decode(buf);
            if (confirm != null) {
                confirm.accept(true);
                confirm = null; // run only once
                server.connections.offer(this);
                log.debug("New {} Connection from: {}, connected: {}", request.getUpgrade() == null ? "HTTP" : request.getUpgrade(), getRemoteAddress(), server.connections.size());
                if (server.connections.size() > server.getAttr("maxConnectionCountToClean", Integer.class, 10)) {
                    try { server.clean(); } catch (Throwable ex) {
                        log.error("clean error", ex);
                    }
                }
            }
        } catch (Throwable ex) {
            if (confirm != null) {
                confirm.accept(false);
                confirm = null; // run only once
            } else {
                log.error("Http decode error. from: " + getRemoteAddress(), ex);
                close(); return;
            }
        }
        if (request.decoder.complete) {
            server.receive(request);
        }
    }


    /**
     * Returns the remote address as a string.
     * @return the remote address, or null if the channel is not connected, is closed, or the lookup fails
     */
    public String getRemoteAddress() {
        try {
            Object address = channel.getRemoteAddress();
            // AsynchronousSocketChannel returns null when the socket is not connected
            return address == null ? null : address.toString();
        } catch (ClosedChannelException e) {
            // Expected after close (e.g. when called from a failure handler); not an error
            return null;
        } catch (IOException e) {
            log.error("",e);
        }
        return null;
    }


    /**
     * Returns the local address as a string.
     * @return the local address, or null if the channel is not bound, is closed, or the lookup fails
     */
    public String getLocalAddress() {
        try {
            Object address = channel.getLocalAddress();
            // NetworkChannel returns null when the socket is not bound
            return address == null ? null : address.toString();
        } catch (ClosedChannelException e) {
            // Expected after close (e.g. when called from a failure handler); not an error
            return null;
        } catch (IOException e) {
            log.error("",e);
        }
        return null;
    }


    @Override
    public String toString() {
        return HttpIOSession.class.getSimpleName() + "@" + Integer.toHexString(hashCode()) + "[" + channel + "]";
    }
}
