/*
 *      Copyright (C) 2012-2017 DataStax Inc.
 *
 *      This software can be used solely with DataStax Enterprise. Please consult the license at
 *      http://www.datastax.com/terms/datastax-dse-driver-license-terms
 */
package com.datastax.driver.core;

import com.datastax.driver.core.exceptions.DriverInternalError;
import com.datastax.driver.core.exceptions.UnsupportedFeatureException;
import io.netty.buffer.ByteBuf;
import io.netty.channel.ChannelHandler;
import io.netty.channel.ChannelHandlerContext;
import io.netty.handler.codec.MessageToMessageDecoder;
import io.netty.handler.codec.MessageToMessageEncoder;
import io.netty.util.AttributeKey;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.nio.ByteBuffer;
import java.util.*;

/**
 * A message from the CQL binary protocol.
 * <p>
 * This class holds the state shared by every message (stream id and optional custom payload) and defines the two
 * nested message families, {@link Request} and {@link Response}, together with the Netty handlers that convert
 * requests into {@link Frame}s ({@link ProtocolEncoder}) and frames into responses ({@link ProtocolDecoder}).
 */
abstract class Message {

    protected static final Logger logger = LoggerFactory.getLogger(Message.class);

    /**
     * Channel attribute holding the {@link CodecRegistry} that {@link ProtocolDecoder} hands to response decoders.
     * Declared {@code final} so the shared key can never be reassigned at runtime.
     */
    static final AttributeKey<CodecRegistry> CODEC_REGISTRY_ATTRIBUTE_KEY = AttributeKey.valueOf("com.datastax.driver.core.CodecRegistry");

    /**
     * Serializes the body of one kind of request.
     */
    interface Coder<R extends Request> {
        /**
         * Writes {@code request} into {@code dest} using the given protocol version.
         */
        void encode(R request, ByteBuf dest, ProtocolVersion version);

        /**
         * Returns the number of bytes {@link #encode} writes; {@link ProtocolEncoder} uses it to presize the buffer.
         */
        int encodedSize(R request, ProtocolVersion version);
    }

    /**
     * Deserializes the body of one kind of response.
     */
    interface Decoder<R extends Response> {
        R decode(ByteBuf body, ProtocolVersion version, CodecRegistry codecRegistry);
    }

    // -1 means "no stream id assigned yet"; Request#setStreamId relies on this sentinel.
    private volatile int streamId = -1;

    /**
     * A generic key-value custom payload. Custom payloads are simply
     * ignored by the default QueryHandler implementation server-side.
     *
     * @since Protocol V4
     */
    private volatile Map<String, ByteBuffer> customPayload;

    protected Message() {
    }

    /**
     * Sets the stream id of this message.
     *
     * @return this message, for chaining.
     */
    Message setStreamId(int streamId) {
        this.streamId = streamId;
        return this;
    }

    /**
     * @return the stream id of this message, or -1 if none has been assigned.
     */
    int getStreamId() {
        return streamId;
    }

    /**
     * @return the custom payload of this message, or {@code null} if there is none.
     */
    Map<String, ByteBuffer> getCustomPayload() {
        return customPayload;
    }

    /**
     * Sets the custom payload of this message (may be {@code null}).
     *
     * @return this message, for chaining.
     */
    Message setCustomPayload(Map<String, ByteBuffer> customPayload) {
        this.customPayload = customPayload;
        return this;
    }

    /**
     * A message sent by the driver to the server.
     */
    static abstract class Request extends Message {

        /**
         * Request kinds, each with its protocol opcode and the {@link Coder} that serializes it.
         */
        enum Type {
            // public requests
            STARTUP(1, Requests.Startup.coder),
            CREDENTIALS(4, Requests.Credentials.coder),
            OPTIONS(5, Requests.Options.coder),
            QUERY(7, Requests.Query.coder),
            PREPARE(9, Requests.Prepare.coder),
            EXECUTE(10, Requests.Execute.coder),
            REGISTER(11, Requests.Register.coder),
            BATCH(13, Requests.Batch.coder),
            AUTH_RESPONSE(15, Requests.AuthResponse.coder),

            // private requests
            CANCEL(255, Requests.Cancel.coder);

            final int opcode;
            final Coder<?> coder;

            Type(int opcode, Coder<?> coder) {
                this.opcode = opcode;
                this.coder = coder;
            }
        }

        final Type type;
        private final boolean tracingRequested;

        protected Request(Type type) {
            this(type, false);
        }

        protected Request(Type type, boolean tracingRequested) {
            this.type = type;
            this.tracingRequested = tracingRequested;
        }

        /**
         * Assigns a stream id to this request.
         * <p>
         * JAVA-1179: defensively guard against reusing the same Request object twice. If no stream id was ever
         * set, this object is used directly; otherwise a copy (including the custom payload) receives the id.
         *
         * @return the request carrying the new stream id: either {@code this} or a fresh copy.
         */
        @Override
        Request setStreamId(int streamId) {
            if (getStreamId() < 0)
                return (Request) super.setStreamId(streamId);
            else {
                Request copy = this.copy();
                copy.setStreamId(streamId);
                return copy;
            }
        }

        boolean isTracingRequested() {
            return tracingRequested;
        }

        /**
         * @return the consistency level of a QUERY, EXECUTE or BATCH request; {@code null} for other types.
         */
        ConsistencyLevel consistency() {
            switch (this.type) {
                case QUERY:
                    return ((Requests.Query) this).options.consistency;
                case EXECUTE:
                    return ((Requests.Execute) this).options.consistency;
                case BATCH:
                    return ((Requests.Batch) this).options.consistency;
                default:
                    return null;
            }
        }

        /**
         * @return the serial consistency level of a QUERY, EXECUTE or BATCH request; {@code null} for other types.
         */
        ConsistencyLevel serialConsistency() {
            switch (this.type) {
                case QUERY:
                    return ((Requests.Query) this).options.serialConsistency;
                case EXECUTE:
                    return ((Requests.Execute) this).options.serialConsistency;
                case BATCH:
                    return ((Requests.Batch) this).options.serialConsistency;
                default:
                    return null;
            }
        }

        /**
         * @return the default timestamp of a QUERY, EXECUTE or BATCH request; 0 for other types.
         */
        long defaultTimestamp() {
            switch (this.type) {
                case QUERY:
                    return ((Requests.Query) this).options.defaultTimestamp;
                case EXECUTE:
                    return ((Requests.Execute) this).options.defaultTimestamp;
                case BATCH:
                    return ((Requests.Batch) this).options.defaultTimestamp;
                default:
                    return 0;
            }
        }

        /**
         * @return the paging state of a QUERY or EXECUTE request; {@code null} for other types.
         */
        ByteBuffer pagingState() {
            switch (this.type) {
                case QUERY:
                    return ((Requests.Query) this).options.pagingState;
                case EXECUTE:
                    return ((Requests.Execute) this).options.pagingState;
                default:
                    return null;
            }
        }

        /**
         * @return a copy of this request (built by {@link #copyInternal()}) that keeps the same custom payload.
         */
        Request copy() {
            Request request = copyInternal();
            request.setCustomPayload(this.getCustomPayload());
            return request;
        }

        protected abstract Request copyInternal();

        /**
         * @return a copy of this request using {@code newConsistencyLevel}, keeping the same custom payload.
         * @throws UnsupportedOperationException if the concrete request does not override
         *                                       {@link #copyInternal(ConsistencyLevel)}.
         */
        Request copy(ConsistencyLevel newConsistencyLevel) {
            Request request = copyInternal(newConsistencyLevel);
            request.setCustomPayload(this.getCustomPayload());
            return request;
        }

        protected Request copyInternal(ConsistencyLevel newConsistencyLevel) {
            throw new UnsupportedOperationException();
        }
    }

    /**
     * A message received by the driver from the server.
     */
    static abstract class Response extends Message {

        /**
         * Response kinds, each with its protocol opcode and the {@link Decoder} that deserializes it.
         */
        enum Type {
            ERROR(0, Responses.Error.decoder),
            READY(2, Responses.Ready.decoder),
            AUTHENTICATE(3, Responses.Authenticate.decoder),
            SUPPORTED(6, Responses.Supported.decoder),
            RESULT(8, Responses.Result.decoder),
            EVENT(12, Responses.Event.decoder),
            AUTH_CHALLENGE(14, Responses.AuthChallenge.decoder),
            AUTH_SUCCESS(16, Responses.AuthSuccess.decoder);

            final int opcode;
            final Decoder<?> decoder;

            // Lookup table indexed by opcode; slots for unused opcodes stay null.
            private static final Type[] opcodeIdx;

            static {
                int maxOpcode = -1;
                for (Type type : Type.values())
                    maxOpcode = Math.max(maxOpcode, type.opcode);
                opcodeIdx = new Type[maxOpcode + 1];
                for (Type type : Type.values()) {
                    if (opcodeIdx[type.opcode] != null)
                        throw new IllegalStateException("Duplicate opcode");
                    opcodeIdx[type.opcode] = type;
                }
            }

            Type(int opcode, Decoder<?> decoder) {
                this.opcode = opcode;
                this.decoder = decoder;
            }

            /**
             * @return the response type for {@code opcode}.
             * @throws DriverInternalError if the opcode is out of range or not assigned to any response type.
             */
            static Type fromOpcode(int opcode) {
                Type t = (opcode >= 0 && opcode < opcodeIdx.length) ? opcodeIdx[opcode] : null;
                if (t == null)
                    throw new DriverInternalError(String.format("Unknown response opcode %d", opcode));
                return t;
            }
        }

        final Type type;
        protected volatile UUID tracingId;
        protected volatile List<String> warnings;

        protected Response(Type type) {
            this.type = type;
        }

        Response setTracingId(UUID tracingId) {
            this.tracingId = tracingId;
            return this;
        }

        UUID getTracingId() {
            return tracingId;
        }

        Response setWarnings(List<String> warnings) {
            this.warnings = warnings;
            return this;
        }
    }

    /**
     * Decodes incoming {@link Frame}s into {@link Response}s.
     * <p>
     * The frame body is released before {@link #decode} returns, whether decoding succeeds or throws.
     */
    @ChannelHandler.Sharable
    static class ProtocolDecoder extends MessageToMessageDecoder<Frame> {

        @Override
        protected void decode(ChannelHandlerContext ctx, Frame frame, List<Object> out) throws Exception {
            // Every read from frame.body happens inside the try so that a malformed header extension
            // cannot leak the buffer.
            try {
                // Optional extensions precede the message body. Per the native protocol v4 spec ("flags"
                // section), the tracing id comes first, then the warnings, then the custom payload.
                UUID tracingId = frame.header.flags.contains(Frame.Header.Flag.TRACING)
                        ? CBUtil.readUUID(frame.body)
                        : null;
                List<String> warnings = frame.header.flags.contains(Frame.Header.Flag.WARNING)
                        ? CBUtil.readStringList(frame.body)
                        : Collections.<String>emptyList();
                Map<String, ByteBuffer> customPayload = frame.header.flags.contains(Frame.Header.Flag.CUSTOM_PAYLOAD)
                        ? CBUtil.readBytesMap(frame.body)
                        : null;

                if (customPayload != null && logger.isTraceEnabled()) {
                    logger.trace("Received payload: {} ({} bytes total)", printPayload(customPayload), CBUtil.sizeOfBytesMap(customPayload));
                }

                CodecRegistry codecRegistry = ctx.channel().attr(CODEC_REGISTRY_ATTRIBUTE_KEY).get();
                assert codecRegistry != null;
                Response response = Response.Type.fromOpcode(frame.header.opcode).decoder.decode(frame.body, frame.header.version, codecRegistry);
                response
                        .setTracingId(tracingId)
                        .setWarnings(warnings)
                        .setCustomPayload(customPayload)
                        .setStreamId(frame.header.streamId);
                out.add(response);
            } finally {
                frame.body.release();
            }
        }

    }

    /**
     * Encodes outgoing {@link Request}s into {@link Frame}s for a fixed protocol version.
     */
    @ChannelHandler.Sharable
    static class ProtocolEncoder extends MessageToMessageEncoder<Request> {

        private final ProtocolVersion protocolVersion;

        ProtocolEncoder(ProtocolVersion version) {
            this.protocolVersion = version;
        }

        /**
         * Serializes {@code request} (prefixed with its custom payload, if any) into a new frame added to {@code out}.
         *
         * @throws UnsupportedFeatureException if the request carries a custom payload and the protocol version is
         *                                     lower than V4.
         */
        @Override
        protected void encode(ChannelHandlerContext ctx, Request request, List<Object> out) throws Exception {
            EnumSet<Frame.Header.Flag> flags = EnumSet.noneOf(Frame.Header.Flag.class);
            if (request.isTracingRequested())
                flags.add(Frame.Header.Flag.TRACING);
            if (protocolVersion == ProtocolVersion.NEWEST_BETA)
                flags.add(Frame.Header.Flag.USE_BETA);
            Map<String, ByteBuffer> customPayload = request.getCustomPayload();
            if (customPayload != null) {
                if (protocolVersion.compareTo(ProtocolVersion.V4) < 0)
                    throw new UnsupportedFeatureException(
                            protocolVersion,
                            "Custom payloads are only supported since native protocol V4");
                flags.add(Frame.Header.Flag.CUSTOM_PAYLOAD);
            }

            @SuppressWarnings("unchecked")
            Coder<Request> coder = (Coder<Request>) request.type.coder;
            int messageSize = coder.encodedSize(request, protocolVersion);
            int payloadLength = -1;
            if (customPayload != null) {
                payloadLength = CBUtil.sizeOfBytesMap(customPayload);
                messageSize += payloadLength;
            }

            ByteBuf body = ctx.alloc().buffer(messageSize);
            // Until the frame wrapping `body` has been handed downstream through `out`, this method owns the
            // buffer and must release it if anything below fails.
            boolean handedOff = false;
            try {
                if (customPayload != null) {
                    CBUtil.writeBytesMap(customPayload, body);
                    if (logger.isTraceEnabled()) {
                        logger.trace("Sending payload: {} ({} bytes total)", printPayload(customPayload), payloadLength);
                    }
                }

                coder.encode(request, body, protocolVersion);
                out.add(Frame.create(protocolVersion, request.type.opcode, request.getStreamId(), flags, body));
                handedOff = true;
            } finally {
                if (!handedOff)
                    body.release();
            }
        }
    }

    // Debugging helpers used to trace custom payloads.

    private static final char[] hexArray = "0123456789ABCDEF".toCharArray();

    // Maximum number of bytes rendered by bytesToHex before the output is truncated.
    private static final int MAX_HEX_BYTES = 50;

    /**
     * Renders a custom payload as {@code {key:0x..., key2:null}} for logging.
     *
     * @return {@code "null"} for a null map, {@code "{}"} for an empty one, otherwise a comma-separated listing.
     */
    static String printPayload(Map<String, ByteBuffer> customPayload) {
        if (customPayload == null)
            return "null";
        if (customPayload.isEmpty())
            return "{}";
        StringBuilder sb = new StringBuilder("{");
        Iterator<Map.Entry<String, ByteBuffer>> iterator = customPayload.entrySet().iterator();
        while (iterator.hasNext()) {
            Map.Entry<String, ByteBuffer> entry = iterator.next();
            sb.append(entry.getKey());
            sb.append(":");
            if (entry.getValue() == null)
                sb.append("null");
            else
                bytesToHex(entry.getValue(), sb);
            if (iterator.hasNext())
                sb.append(", ");
        }
        sb.append("}");
        return sb.toString();
    }

    /**
     * Appends {@code 0x} followed by the hex form of the buffer's remaining bytes (from its current position to
     * its limit) to {@code sb}. Output is truncated to the first {@value #MAX_HEX_BYTES} bytes and then suffixed with
     * {@code "... [TRUNCATED]"}.
     * <p>
     * Only absolute reads are used, so the buffer's position and limit are not modified.
     */
    static void bytesToHex(ByteBuffer bytes, StringBuilder sb) {
        int start = bytes.position();
        int remaining = bytes.remaining();
        int length = Math.min(remaining, MAX_HEX_BYTES);
        sb.append("0x");
        for (int i = 0; i < length; i++) {
            // Offset by the position: absolute get(int) indexes from the start of the buffer, not the position.
            int v = bytes.get(start + i) & 0xFF;
            sb.append(hexArray[v >>> 4]);
            sb.append(hexArray[v & 0x0F]);
        }
        if (remaining > MAX_HEX_BYTES)
            sb.append("... [TRUNCATED]");
    }
}
