package cn.xnatural.xnet;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.channels.AsynchronousSocketChannel;
import java.nio.channels.ClosedChannelException;
import java.nio.channels.CompletionHandler;
import java.util.*;
import java.util.concurrent.ConcurrentLinkedQueue;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.function.BiConsumer;
import java.util.stream.Collectors;

/**
 * One AIO TCP connection session: a bidirectional stream of framed messages.
 *
 * <p>Wire format of a single message:
 * <pre>
 *   fieldCount[byte], fieldLength[int] x fieldCount, fieldContent x fieldCount
 * </pre>
 *
 * <p>Reading: at most one asynchronous read is outstanding at a time; every completed read is
 * parsed by {@link #doRead(ByteBuffer)} and then the next read is issued.
 *
 * <p>Writing: messages are queued in {@link #waiting}; the {@link #_writing} flag guarantees that
 * only one record is being written to the channel at a time.
 */
public class XioSession implements AutoCloseable {
    protected static final Logger log = LoggerFactory.getLogger(XioSession.class);
    /**
     * The underlying asynchronous socket channel.
     */
    protected final AsynchronousSocketChannel channel;
    /**
     * {@link XioHandler} that executes callbacks and handles received messages.
     */
    public final    XioHandler                handler;
    /**
     * Buffer that every read on the channel fills (size: attribute {@code receiveMsgBufferSize}).
     */
    protected final ByteBuffer                buf;
    /**
     * Time (epoch millis) of the last completed read or submitted write.
     */
    protected       Long                      lastUsed = System.currentTimeMillis();
    /**
     * Whether this session has been closed.
     */
    protected final AtomicBoolean _closed = new AtomicBoolean(false);
    /**
     * Whether a record is currently being written to the channel.
     */
    protected final AtomicBoolean _writing = new AtomicBoolean(false);
    /**
     * Queue of records waiting to be written.
     */
    protected final Queue<WriteRecord>        waiting  = new ConcurrentLinkedQueue<>();
    /**
     * The message currently being received, or null between messages.
     */
    protected ReadRecord rr;
    /**
     * Read completion handler: parses the received bytes, then schedules the next read.
     */
    protected final CompletionHandler<Integer, ByteBuffer>  readHandler  = new CompletionHandler<Integer, ByteBuffer>() {
        @Override
        public void completed(Integer count, ByteBuffer buf) {
            lastUsed = System.currentTimeMillis();
            if (count > 0) {
                buf.flip();
                try { // only one read is pending at any time, which avoids ReadPendingException
                    doRead(buf);
                    read();
                } catch (Exception ex) {
                    log.error(handler.getClass().getName(), ex);
                    close();
                }
            } else if (count < 0) {
                // -1: the peer reached end-of-stream. The channel still reports isOpen(),
                // so the session has to be closed explicitly.
                close();
            } else {
                // 0 bytes: the buffer had no room left (it still holds an unparsed partial header),
                // so no further progress is possible on this session.
                log.warn("Receive buffer has no space left, closing: {}", XioSession.this);
                close();
            }
        }

        @Override
        public void failed(Throwable ex, ByteBuffer buf) {
            if (!(ex instanceof ClosedChannelException)) {
                log.error(channel.toString(), ex);
            }
            close();
        }
    };
    /**
     * Write completion handler: keeps writing the chunks of the current record, then moves on to
     * the next queued record. Any failure fails the record and closes the session.
     */
    protected final CompletionHandler<Integer, WriteRecord> writeHandler = new CompletionHandler<Integer, WriteRecord>() {
        @Override
        public void completed(Integer result, WriteRecord record) {
            ByteBuffer data;
            try {
                data = record.poll();
            } catch (RuntimeException ex) { // reading the source stream failed
                failed(ex, record);
                return;
            }
            if (data == null) { // the current record has been written completely
                if (record.okFn != null) handler.exec(record.okFn);
                closeQuietly(record);
                _writing.set(false);
                trigger();
            } else if (_closed.get()) {
                // closed in the middle of a record: report it instead of silently abandoning it
                failed(new ClosedChannelException(), record);
            } else {
                try {
                    channel.write(data, record, this);
                } catch (RuntimeException ex) {
                    failed(ex, record);
                }
            }
        }

        @Override
        public void failed(Throwable ex, WriteRecord record) {
            try {
                if (record.failFn != null) handler.exec(() -> record.failFn.accept(ex, XioSession.this));
                else if (!(ex instanceof ClosedChannelException)) {
                    log.error(ex.getClass().getName() + " " + getRemoteAddress() + " ->" + getLocalAddress(), ex);
                }
            } finally {
                closeQuietly(record);
                close();
                // Release the write flag only after the session is closed, then drain the queue so
                // that every record still waiting gets its failFn notification.
                _writing.set(false);
                trigger();
            }
        }
    };


    /**
     * Creates a {@link XioSession}.
     * @param channel {@link AsynchronousSocketChannel}, required
     * @param handler {@link XioHandler}, required
     * @throws NullPointerException if channel or handler is null
     */
    public XioSession(AsynchronousSocketChannel channel, XioHandler handler) {
        this.channel = Objects.requireNonNull(channel, "Param channel required");
        this.handler = Objects.requireNonNull(handler, "Param handler required");
        this.buf = ByteBuffer.allocate(handler.getAttr("receiveMsgBufferSize", Integer.class, 1024 * 1024));
    }


    /**
     * Starts receiving data.
     * @return this session
     */
    public XioSession start() { read(); return this; }


    /**
     * Closes the session. Only the first call has an effect: queued writes are failed, the channel
     * is shut down and closed, and {@link #doClose(XioSession)} is invoked.
     */
    @Override
    public void close() {
        if (!_closed.compareAndSet(false, true)) return;
        trigger(); // notify the failFn of records that were never sent
        // Best effort: the channel may already be half closed or closed by the peer.
        try { channel.shutdownInput(); } catch (Exception ignored) { /* nothing useful to do */ }
        try { channel.shutdownOutput(); } catch (Exception ignored) { /* nothing useful to do */ }
        try { channel.close(); } catch (Exception ignored) { /* nothing useful to do */ }
        doClose(this);
        log.trace("closed: {}", this);
    }


    /**
     * Hook for subclasses: drop references to the given (now closed) {@link XioSession}.
     * @param stream the session that was closed
     */
    protected void doClose(XioSession stream) {}


    /**
     * Parses received data.
     * Format:
     *  fieldCount[byte], [length of field 1[int], length of field 2[int]...], [content of field 1, content of field 2...]
     *
     * <p>On return the buffer is ready to be filled by the next channel read: it is either cleared
     * (everything consumed) or compacted (an incomplete header is kept for the next round).
     * @param buf buffer in read mode (already flipped)
     */
    protected void doRead(ByteBuffer buf) {
        // A loop instead of recursion: a big buffer full of small messages must not overflow the stack.
        while (true) {
            if (!buf.hasRemaining()) { buf.clear(); return; } // nothing left to parse, reset the buffer

            // start of a new message
            if (rr == null) {
                // Peek at the field count without consuming it: the count and all the field lengths
                // must be parsed together, otherwise a header split across two reads would be lost.
                int announced = buf.get(buf.position());
                if (announced > 0 && buf.remaining() < 1 + announced * 4) { buf.compact(); return; }
                ReadRecord record = createReadRecord(buf.get());
                for (int i = 0; i < record.fieldCnt; i++) {
                    int length = buf.getInt();
                    if (length < 0) throw new IllegalArgumentException("Negative field length: " + length);
                    record.fields.add(new XioStream(length));
                }
                rr = record;
            }

            // the field that is currently being received
            XioStream field = null;
            for (XioStream f : rr.fields) {
                if (f.leftReceived() > 0) { field = f; break; }
            }
            if (field == null) { // the whole message (all of its fields) has been received
                rr = null;
                continue;
            }
            // The header ended exactly at the end of the buffer: wait for more data instead of
            // feeding an empty chunk (which would make the first-chunk dispatch below fire twice).
            if (!buf.hasRemaining()) { buf.clear(); return; }

            // fill the current field
            byte[] bs = new byte[(int) Math.min(field.leftReceived(), buf.remaining())];
            buf.get(bs);
            field.addStream(bs);

            // First chunk of the first field: dispatch the message to the business handler right
            // away, the remaining data keeps flowing into the field streams while it is processed.
            if (rr.fields.get(0) == field && field.received <= bs.length) {
                // Snapshot now: rr may already point to another message (or be null) when the task runs.
                final List<XioStream> snapshot = new ArrayList<>(rr.fields);
                handler.exec(() -> {
                    try {
                        handler.handle(snapshot, this);
                    } catch (Exception ex) {
                        log.error(getClass().getName() + " receive error", ex);
                    }
                });
            }
        }
    }


    /**
     * Creates the record for a new incoming message.
     * @param fieldCnt number of fields announced by the message
     * @return new {@link ReadRecord}
     */
    protected ReadRecord createReadRecord(byte fieldCnt) { return new ReadRecord(fieldCnt); }


    /**
     * Writes a message to the stream.
     * @param fields the fields of the message
     * @param failFn failure callback, may be null
     * @param okFn success callback, may be null
     * @throws IllegalArgumentException if fields is null, empty or has too many elements
     */
    public void write(List<XioStream> fields, BiConsumer<Throwable, XioSession> failFn, Runnable okFn) {
        if (fields == null) throw new IllegalArgumentException("Param fields required");
        if (_closed.get() || !channel.isOpen()) {
            close(); // makes sure references to this session are released (see doClose)
            if (failFn == null) {
                log.error("Already closed. " + this);
            } else {
                failFn.accept(new ClosedChannelException(), this);
            }
            return; // do not enqueue: the failure has been reported exactly once
        }
        lastUsed = System.currentTimeMillis();
        waiting.offer(new WriteRecord(fields, failFn, okFn));
        trigger();
    }

    /**
     * Writes a message to the stream.
     * @param fields the fields of the message
     * @throws IllegalArgumentException if fields is null or empty
     */
    public void write(XioStream... fields) {
        if (fields == null || fields.length < 1) throw new IllegalArgumentException("Param fields required");
        write(Arrays.asList(fields), null, null);
    }

    /**
     * Writes a single-field message to the stream.
     * @param data content of the only field
     * @param failFn failure callback, may be null
     * @param okFn success callback, may be null
     * @throws IllegalArgumentException if data is null or empty
     */
    public void write(byte[] data, BiConsumer<Throwable, XioSession> failFn, Runnable okFn) {
        if (data == null || data.length < 1) throw new IllegalArgumentException("Param data required");
        write(Collections.singletonList(new XioStream(data)), failFn, okFn);
    }

    /**
     * {@link #write(byte[], BiConsumer, Runnable)}
     * @param fields content of every field
     * @throws IllegalArgumentException if fields is null or empty
     */
    public void write(byte[]... fields) {
        if (fields == null || fields.length < 1) throw new IllegalArgumentException("Param fields required");
        List<XioStream> streams = new ArrayList<>(fields.length);
        for (byte[] field : fields) streams.add(new XioStream(field));
        write(streams, null, null);
    }

    /**
     * {@link #write(List, BiConsumer, Runnable)}
     * @param fields content of every field
     * @throws IllegalArgumentException if fields is null or empty
     */
    public void write(List<byte[]> fields) {
        if (fields == null || fields.isEmpty()) throw new IllegalArgumentException("Param fields required");
        write(fields.stream().map(XioStream::new).collect(Collectors.toList()), null, null);
    }


    /**
     * Starts writing the next queued record if no write is in progress.
     * When the session is closed, the queue is drained and the failFn of every record is notified.
     */
    protected void trigger() {
        while (!waiting.isEmpty()) {
            if (!_writing.compareAndSet(false, true)) return; // only one writer at a time
            WriteRecord record = waiting.poll();
            if (record == null) { // another thread emptied the queue in the meantime
                _writing.set(false);
                continue;
            }
            if (_closed.get()) { // closed: notify the failFn of the records that were never sent
                try {
                    if (record.failFn != null) record.failFn.accept(new ClosedChannelException(), this);
                } catch (Exception ex) {
                    log.error(getClass().getName() + " failFn error", ex);
                } finally {
                    closeQuietly(record);
                    _writing.set(false);
                }
                continue;
            }
            try {
                ByteBuffer first = record.poll();
                if (first == null) throw new IllegalStateException("No data to write");
                channel.write(first, record, writeHandler);
            } catch (RuntimeException ex) {
                // The write never started, so the handler will not be invoked by the channel:
                // fail the record here, otherwise the write flag would stay set forever.
                writeHandler.failed(ex, record);
            }
            return;
        }
    }


    /**
     * Issues the next asynchronous read, unless the session is closed.
     */
    protected void read() {
        if (_closed.get()) return;
        if (!channel.isOpen()) { close(); return; }
        channel.read(buf, buf, readHandler);
    }


    /**
     * Remote address of the connection.
     * @return the address as string, or null if not connected, closed or not obtainable
     */
    public String getRemoteAddress() {
        try {
            Object address = channel.getRemoteAddress();
            return address == null ? null : address.toString();
        } catch (ClosedChannelException ex) {
            return null; // expected once the session is closed, not worth an error log
        } catch (IOException e) {
            log.error("",e);
        }
        return null;
    }


    /**
     * Local address of the connection.
     * @return the address as string, or null if not bound, closed or not obtainable
     */
    public String getLocalAddress() {
        try {
            Object address = channel.getLocalAddress();
            return address == null ? null : address.toString();
        } catch (ClosedChannelException ex) {
            return null; // expected once the session is closed, not worth an error log
        } catch (IOException e) {
            log.error("",e);
        }
        return null;
    }


    @Override
    public String toString() { return XioSession.class.getSimpleName() + "@" + Integer.toHexString(hashCode()) + "[" + channel.toString() + "]"; }


    /**
     * Releases the streams of a write record; a failure to do so is only logged so that the
     * caller can always go on to reset the write state.
     */
    private void closeQuietly(WriteRecord record) {
        try {
            record.close();
        } catch (Exception ex) {
            log.error(getClass().getName() + " close record error", ex);
        }
    }


    /**
     * State of one outgoing message; one instance is created per write.
     */
    protected class WriteRecord implements AutoCloseable {
        protected final BiConsumer<Throwable, XioSession> failFn;
        protected final Runnable        okFn;
        // content of every field
        protected final List<XioStream> fields;
        // maximum payload size of one chunk when the message is written in several chunks
        protected final int             perSendLen = handler.getAttr("perSendLen", Integer.class, 1024 * 10);
        // the chunk that is currently being written to the channel
        protected       ByteBuffer      data;
        // whether the message header (field count + field lengths) has already been produced
        protected       boolean         headerWritten;

        /**
         * @param fields the fields of the message, 1 to {@link Byte#MAX_VALUE} elements
         * @param failFn failure callback, may be null
         * @param okFn success callback, may be null
         * @throws IllegalArgumentException if fields is empty or has too many elements
         */
        public WriteRecord(List<XioStream> fields, BiConsumer<Throwable, XioSession> failFn, Runnable okFn) {
            this.failFn = failFn;
            this.okFn = okFn;
            this.fields = new LinkedList<>(fields);
            if (fields.isEmpty()) throw new IllegalArgumentException("Not found field");
            if (fields.size() > Byte.MAX_VALUE) throw new IllegalArgumentException("Too many field, must < " + Byte.MAX_VALUE);
        }

        /**
         * Returns the next chunk that has to be written to the channel.
         * @return null: nothing is left to write
         * @throws RuntimeException wrapping the IOException if a field stream can not be read
         */
        ByteBuffer poll() {
            // The channel may have written only part of the chunk: finish it first. This check
            // has to come before the "all fields consumed" check, otherwise the tail of the last
            // chunk would be dropped.
            if (data != null && data.hasRemaining()) return data;
            XioStream field = null;
            for (XioStream f : fields) {
                if (!f.isEnd()) { field = f; break; }
            }
            if (field == null) return null; // all data has been sent
            try {
                if (!headerWritten) { // first chunk: header followed by the first piece of payload
                    byte[] bs = new byte[field.length > perSendLen ? perSendLen : (int) field.length];
                    int count = Math.max(field.read(bs), 0); // only the bytes actually read are sent
                    data = ByteBuffer.allocate(count + 1 + (fields.size() * 4));
                    // fieldCount[byte], [length of field 1[int], ...], [content of field 1, ...]
                    data.put((byte) fields.size());
                    for (XioStream f : fields) data.putInt((int) f.length);
                    data.put(bs, 0, count).flip();
                    headerWritten = true;
                } else {
                    long left = field.available(); // what is left to read from the stream
                    byte[] bs = new byte[left > perSendLen ? perSendLen : (int) left];
                    int count = Math.max(field.read(bs), 0); // only the bytes actually read are sent
                    data = ByteBuffer.wrap(bs, 0, count);
                }
            } catch (IOException ex) {
                throw new RuntimeException(ex);
            }
            return data;
        }

        /**
         * Closes the stream of every field.
         */
        @Override
        public void close() {
            for (XioStream field : fields) field.close();
        }
    }


    /**
     * State of the message that is currently being received.
     */
    protected class ReadRecord implements AutoCloseable {
        // number of fields of the message
        protected final byte            fieldCnt;
        // one stream per field, in message order
        protected final List<XioStream> fields;

        /**
         * @param fieldCnt number of fields announced by the message, must be at least 1
         * @throws RuntimeException if fieldCnt is less than 1
         */
        public ReadRecord(byte fieldCnt) {
            if (fieldCnt < 1) throw new RuntimeException("Msg no field");
            this.fieldCnt = fieldCnt;
            this.fields = new ArrayList<>(fieldCnt);
        }

        /**
         * Closes the stream of every field.
         */
        @Override
        public void close() {
            for (XioStream stream : fields) stream.close();
        }
    }
}
