package cn.schoolwow.quickserver.websocket.stream;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.io.ByteArrayOutputStream;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.IOException;
import java.net.Socket;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.List;
import java.util.Scanner;
import java.util.TreeMap;

/**
 * WebSocket stream.
 *
 * <p>Works in one of two modes:
 * <ul>
 *   <li>socket mode ({@link #WebSocketStreamImpl(Socket)}): reads from and writes to the socket;</li>
 *   <li>in-memory mode ({@link #WebSocketStreamImpl()}): writes go to an internal byte buffer that
 *       can be retrieved with {@link #toByteArray()}; read methods are not available.</li>
 * </ul>
 *
 * <p>In socket mode every read goes directly through a single {@link DataInputStream} with no
 * read-ahead layer on top of it, so line-oriented reads (HTTP handshake) and binary reads
 * (frames) can be interleaved without any bytes being swallowed by an intermediate buffer.
 */
public class WebSocketStreamImpl implements WebSocketStream {
    private static final Logger logger = LoggerFactory.getLogger(WebSocketStreamImpl.class);

    /** Chunk size used by {@link #readByteArray(int)}. */
    private static final int READ_CHUNK_SIZE = 8192;

    /** Underlying socket; {@code null} in in-memory mode. */
    private final Socket socket;

    /** Socket input stream; {@code null} in in-memory mode. */
    private final DataInputStream in;

    /** In-memory sink backing {@link #out}; {@code null} in socket mode. */
    private final ByteArrayOutputStream baos;

    /** Output stream: the socket's output stream, or {@link #baos} in in-memory mode. */
    private final DataOutputStream out;

    /** Creates an in-memory stream; written bytes can be fetched with {@link #toByteArray()}. */
    public WebSocketStreamImpl() {
        this.socket = null;
        this.in = null;
        this.baos = new ByteArrayOutputStream();
        this.out = new DataOutputStream(baos);
    }

    /**
     * Creates a stream bound to the given socket.
     *
     * @param socket connected socket
     * @throws IOException if the socket's streams cannot be obtained
     */
    public WebSocketStreamImpl(Socket socket) throws IOException {
        this.socket = socket;
        this.in = new DataInputStream(socket.getInputStream());
        this.out = new DataOutputStream(socket.getOutputStream());
        this.baos = null;
    }

    @Override
    public int read() throws IOException {
        return in.read();
    }

    /**
     * Fills {@code b} completely, blocking until all bytes have arrived.
     *
     * @return {@code b.length}
     * @throws IllegalArgumentException if the stream ends before {@code b} is filled
     */
    @Override
    public int read(byte[] b) throws IOException {
        return read(b, 0, b.length);
    }

    /**
     * Reads exactly {@code len} bytes into {@code b} starting at {@code off}.
     *
     * <p>A single socket read may legitimately return fewer bytes than requested, so this keeps
     * reading until {@code len} bytes have been received or the stream ends.
     *
     * @return {@code len}
     * @throws IllegalArgumentException if the stream ends before {@code len} bytes were read
     */
    @Override
    public int read(byte[] b, int off, int len) throws IOException {
        int length = readUntilFullOrEof(b, off, len);
        if (length != len) {
            throw new IllegalArgumentException("读取字节数组失败！期望读取长度:" + len + ",实际读取长度:" + length);
        }
        return length;
    }

    /**
     * Reads exactly {@code length} bytes and returns them as a new array.
     *
     * <p>Data is accumulated in chunks rather than pre-allocating {@code length} bytes, so a bogus
     * (e.g. peer-supplied) length cannot force a huge allocation before any data arrives.
     *
     * @param length number of bytes to read; must not be negative
     * @return the bytes read
     * @throws IllegalArgumentException if {@code length} is negative or the stream ends early
     */
    @Override
    public byte[] readByteArray(int length) throws IOException {
        if (length < 0) {
            throw new IllegalArgumentException("读取指定长度字节失败!期望字节长度不能为负数:" + length);
        }
        byte[] buffer = new byte[Math.min(READ_CHUNK_SIZE, length)];
        ByteArrayOutputStream collected = new ByteArrayOutputStream(buffer.length);
        int remaining = length;
        while (remaining > 0) {
            int count = in.read(buffer, 0, Math.min(buffer.length, remaining));
            if (count == -1) {
                // End of stream: stop and let the size check below report the shortfall.
                break;
            }
            collected.write(buffer, 0, count);
            remaining -= count;
        }
        if (collected.size() != length) {
            throw new IllegalArgumentException("读取指定长度字节失败!期望字节长度:" + length + ",当前字节长度:" + collected.size());
        }
        return collected.toByteArray();
    }

    /**
     * Reads header lines up to (and consuming) the first empty line, or until end of stream.
     *
     * <p>A line that starts with a letter and contains ':' starts a new header; any other line is
     * appended to the most recent header's last value (obsolete line folding). Repeated headers
     * accumulate their values in order. Names are compared case-insensitively.
     *
     * @return header name to list of values; empty if the stream ended immediately
     */
    @Override
    public TreeMap<String, List<String>> readHeaders() throws IOException {
        TreeMap<String, List<String>> headers = new TreeMap<>(String.CASE_INSENSITIVE_ORDER);
        String lastPutName = null;
        String line = readLine();
        // null means the peer closed the stream before the terminating empty line.
        while (line != null && !line.isEmpty()) {
            int colon = line.indexOf(':');
            if (colon >= 0 && Character.isLetter(line.charAt(0))) {
                String name = line.substring(0, colon);
                String value = line.substring(colon + 1).trim();
                List<String> valueList = headers.get(name);
                if (valueList == null) {
                    valueList = new ArrayList<>();
                    headers.put(name, valueList);
                }
                valueList.add(value);
                lastPutName = name;
            } else if (lastPutName != null) {
                // Continuation line: append to the last value of the previous header.
                List<String> valueList = headers.get(lastPutName);
                int last = valueList.size() - 1;
                valueList.set(last, valueList.get(last).trim() + line);
            }
            line = readLine();
        }
        return headers;
    }

    /** Reads one byte and returns its 8 bits, most significant bit first. */
    @Override
    public int[] readBitByte() throws IOException {
        return readBitByte(1);
    }

    /**
     * Reads {@code nByte} bytes and expands them into bits, most significant bit of each byte first.
     *
     * @param nByte number of bytes to read
     * @return array of {@code nByte * 8} values, each 0 or 1
     * @throws IOException if the stream ends before all bytes are read
     */
    @Override
    public int[] readBitByte(int nByte) throws IOException {
        int[] bits = new int[nByte * 8];
        for (int i = 0; i < nByte; i++) {
            int b = in.read();
            if (b == -1) {
                throw new IOException("输入流读取到末尾了!");
            }
            for (int j = 0; j < 8; j++) {
                bits[i * 8 + j] = (b >> (7 - j)) & 1;
            }
        }
        return bits;
    }

    /** Writes {@code line} followed by CRLF, encoded as UTF-8. */
    @Override
    public void writeLine(String line) throws IOException {
        out.write((line + "\r\n").getBytes(StandardCharsets.UTF_8));
    }

    /**
     * Packs bits (most significant bit first) into bytes and writes them.
     *
     * @param bits bit values; length must be a multiple of 8
     * @throws IllegalArgumentException if {@code bits.length} is not a multiple of 8
     */
    @Override
    public void writeBit(int[] bits) throws IOException {
        if (bits.length % 8 != 0) {
            throw new IllegalArgumentException("bits参数长度必须为8的倍数!");
        }
        int byteCount = bits.length / 8;
        for (int i = 0; i < byteCount; i++) {
            int b = 0;
            for (int j = 0; j < 8; j++) {
                b = b | (bits[i * 8 + j] << (7 - j));
            }
            write(b);
        }
    }

    @Override
    public void flush() throws IOException {
        out.flush();
    }

    /** Number of bytes written to the output stream so far. */
    @Override
    public int size() {
        return this.out.size();
    }

    /**
     * Returns the bytes written so far. Only available in in-memory mode.
     *
     * @throws IllegalStateException in socket mode, where there is no in-memory buffer
     */
    @Override
    public byte[] toByteArray() {
        if (baos == null) {
            throw new IllegalStateException("套接字模式下不支持toByteArray");
        }
        return baos.toByteArray();
    }

    @Override
    public void readFully(byte[] b) throws IOException {
        in.readFully(b);
    }

    @Override
    public void readFully(byte[] b, int off, int len) throws IOException {
        in.readFully(b, off, len);
    }

    @Override
    public int skipBytes(int n) throws IOException {
        return in.skipBytes(n);
    }

    @Override
    public boolean readBoolean() throws IOException {
        return in.readBoolean();
    }

    @Override
    public byte readByte() throws IOException {
        return in.readByte();
    }

    @Override
    public int readUnsignedByte() throws IOException {
        return in.readUnsignedByte();
    }

    @Override
    public short readShort() throws IOException {
        return in.readShort();
    }

    @Override
    public int readUnsignedShort() throws IOException {
        return in.readUnsignedShort();
    }

    @Override
    public char readChar() throws IOException {
        return in.readChar();
    }

    @Override
    public int readInt() throws IOException {
        return in.readInt();
    }

    @Override
    public long readLong() throws IOException {
        return in.readLong();
    }

    @Override
    public float readFloat() throws IOException {
        return in.readFloat();
    }

    @Override
    public double readDouble() throws IOException {
        return in.readDouble();
    }

    /**
     * Reads one line terminated by LF or CRLF and decodes it as UTF-8 (the terminator is dropped).
     *
     * <p>Bytes are pulled one at a time straight from the input stream so that nothing beyond the
     * line terminator is consumed; subsequent binary reads see exactly the following bytes.
     *
     * @return the line, or {@code null} if the stream ended before any byte was read
     */
    @Override
    public String readLine() throws IOException {
        ByteArrayOutputStream lineBytes = new ByteArrayOutputStream();
        int b;
        while ((b = in.read()) != -1 && b != '\n') {
            lineBytes.write(b);
        }
        if (b == -1 && lineBytes.size() == 0) {
            return null;
        }
        byte[] bytes = lineBytes.toByteArray();
        int length = bytes.length;
        if (length > 0 && bytes[length - 1] == '\r') {
            length--;
        }
        return new String(bytes, 0, length, StandardCharsets.UTF_8);
    }

    @Override
    public String readUTF() throws IOException {
        return in.readUTF();
    }

    @Override
    public void write(int b) throws IOException {
        out.write(b);
    }

    @Override
    public void write(byte[] b) throws IOException {
        out.write(b);
    }

    @Override
    public void write(byte[] b, int off, int len) throws IOException {
        out.write(b, off, len);
    }

    @Override
    public void writeBoolean(boolean v) throws IOException {
        out.writeBoolean(v);
    }

    @Override
    public void writeByte(int v) throws IOException {
        out.writeByte(v);
    }

    @Override
    public void writeShort(int v) throws IOException {
        out.writeShort(v);
    }

    @Override
    public void writeChar(int v) throws IOException {
        out.writeChar(v);
    }

    @Override
    public void writeInt(int v) throws IOException {
        out.writeInt(v);
    }

    @Override
    public void writeLong(long v) throws IOException {
        out.writeLong(v);
    }

    @Override
    public void writeFloat(float v) throws IOException {
        out.writeFloat(v);
    }

    @Override
    public void writeDouble(double v) throws IOException {
        out.writeDouble(v);
    }

    @Override
    public void writeBytes(String s) throws IOException {
        out.writeBytes(s);
    }

    @Override
    public void writeChars(String s) throws IOException {
        out.writeChars(s);
    }

    @Override
    public void writeUTF(String s) throws IOException {
        out.writeUTF(s);
    }

    /**
     * Flushes and closes the output, then always closes the socket (which also closes its input
     * stream). Errors are logged, not thrown. Does nothing in in-memory mode.
     */
    @Override
    public void close() {
        if (null == socket) {
            return;
        }
        try {
            // FilterOutputStream.close() flushes pending bytes before closing.
            out.close();
        } catch (IOException e) {
            logger.error("关闭套接字失败", e);
        } finally {
            // Release the socket even if closing the output stream failed.
            try {
                socket.close();
            } catch (IOException e) {
                logger.error("关闭套接字失败", e);
            }
        }
    }

    @Override
    public Socket getSocket() {
        return this.socket;
    }

    /**
     * Reads into {@code b[off..off+len)} until it is full or the stream ends.
     *
     * @return number of bytes actually read (less than {@code len} only at end of stream)
     */
    private int readUntilFullOrEof(byte[] b, int off, int len) throws IOException {
        int total = 0;
        while (total < len) {
            int count = in.read(b, off + total, len - total);
            if (count == -1) {
                break;
            }
            total += count;
        }
        return total;
    }
}