package com.unbound.provider.kmip;

import com.unbound.common.HEX;
import com.unbound.provider.kmip.attribute.BigNumAttribute;
import com.unbound.provider.kmip.request.RequestItem;
import com.unbound.provider.kmip.request.RequestMessage;
import com.unbound.provider.kmip.response.ResponseItem;
import com.unbound.provider.kmip.response.ResponseMessage;

import java.math.BigInteger;
import java.nio.charset.StandardCharsets;
import java.util.Arrays;
import java.util.List;

import static com.unbound.common.Converter.*;

/**
 * Bidirectional KMIP TTLV codec.
 *
 * <p>The same sequence of {@code convert*} calls is used in both directions: when the converter is in
 * write mode ({@link #isWrite()}) values are encoded into the buffer, otherwise they are decoded from it.
 * Every item starts with an 8-byte header: a big-endian 4-byte tag whose top byte is the tag mode
 * (0x42 or 0x54) and whose low byte is the item type, followed by a big-endian 4-byte value length.
 * Item values are zero-padded to a multiple of 8 bytes.
 *
 * <p>A write-mode converter whose buffer is {@code null} writes nothing and only advances the offset;
 * the static encoders use such a pass to measure the encoded size before allocating the real buffer.
 *
 * <p>Instances carry a mutable offset and perform no synchronization.
 *
 * Created by valery.osheter on 18-Nov-15.
 */
public class KMIPConverter
{
  /** {@code true} when encoding, {@code false} when decoding. */
  private final boolean write;
  /** Current position in {@link #pointer}. */
  private int offset;
  /** Number of readable bytes in {@link #pointer}; only consulted when decoding. */
  private final int size;
  /** Encode/decode buffer; {@code null} during the size-measuring encode pass. */
  private final byte[] pointer;

  /**
   * Creates a decoding converter with no input: {@link #getNextTag()} returns 0 and any required read fails.
   */
  public KMIPConverter()
  {
    this(false, null, 0, 0);
  }

  /**
   * Creates a decoding converter over {@code in}, starting at offset 0.
   *
   * @param in encoded data; must not be {@code null}
   */
  public KMIPConverter(byte[] in)
  {
    this(false, in, 0, in.length);
  }

  private KMIPConverter(boolean write, byte[] pointer, int offset, int size)
  {
    this.write = write;
    this.pointer = pointer;
    this.offset = offset;
    this.size = size;
  }

  /** Rounds {@code length} up to the next multiple of 8 (the item padding unit). */
  private static int padLength(int length)
  {
    return (length + 7) & ~7;
  }

  /** @return the current read/write position */
  public int getOffset()
  {
    return offset;
  }

  /**
   * Peeks at the tag of the next item without consuming it.
   *
   * @return the next tag, or 0 if fewer than 8 bytes remain
   */
  public int getNextTag()
  {
    if (offset + 8 > size) return 0;
    return getBE4(pointer, offset);
  }

  /**
   * Always throws; used as the single error exit of this class.
   *
   * @throws KMIPConvertException always, with {@code message}
   */
  public static void setError(String message) throws KMIPConvertException
  {
    throw new KMIPConvertException(message);
  }

  /**
   * Reads and validates an 8-byte item header with the expected {@code tag}, advancing past the header.
   *
   * @return the declared (unpadded) value length; never negative
   * @throws KMIPConvertException if the header is truncated, the tag differs, the length is negative or
   *         runs past the end of the data, the tag mode/type is unknown, or a fixed-size type has the
   *         wrong length
   */
  private int readTag(int tag) throws KMIPConvertException
  {
    if (offset + 8 > size) setError("Invalid length");
    int tag2 = getBE4(pointer, offset);
    if (tag2 != tag) setError("Unexpected tag 0x" + HEX.toString(tag2));

    int length = getBE4(pointer, offset + 4);
    // The length comes from the wire: reject negatives and compare it against the remaining space
    // before padding it, so that padLength() cannot overflow into a negative value.
    int remaining = size - offset - 8;
    if (length < 0 || length > remaining || padLength(length) > remaining) setError("Invalid length");

    byte tagMode = (byte) (tag >> 24);
    if (tagMode != 0x42 && tagMode != 0x54) setError("Unexpected tag 0x" + HEX.toString(tag2));
    byte tagType = (byte) tag;

    // Fixed-size types must declare exactly the size their convertOptional() overload reads/skips,
    // otherwise decoding would read past the item and lose sync with the stream.
    switch (tagType)
    {
      case KMIP.TagType.Integer:
        if (length != 4) setError("Invalid integer length");
        break;
      case KMIP.TagType.Enumeration:
        if (length != 4) setError("Invalid enumeration length");
        break;
      case KMIP.TagType.Interval:
        if (length != 4) setError("Invalid interval length");
        break;

      case KMIP.TagType.LongInteger:
        if (length != 8) setError("Invalid long integer length");
        break;
      case KMIP.TagType.Boolean:
        if (length != 8) setError("Invalid boolean length");
        break;
      case KMIP.TagType.DateTime:
        if (length != 8) setError("Invalid date length");
        break;

      case KMIP.TagType.Structure:
      case KMIP.TagType.TextString:
      case KMIP.TagType.ByteString:
      case KMIP.TagType.BigInteger:
        // Variable length; already bounded by the size check above.
        break;

      default:
        setError("Invalid tag type 0x" + HEX.toString(tagType));
    }

    offset += 8;
    return length;
  }

  /**
   * Encodes or decodes an item header.
   *
   * <p>Writing: emits the header when {@code present} and returns {@code length}. Reading: ignores
   * {@code length}/{@code present}, and consumes the header only if the next tag equals {@code tag}.
   *
   * @return the value length, or -1 if the item is absent (nothing is consumed or written)
   */
  public int convertTag(int tag, int length, boolean present) throws KMIPConvertException
  {
    if (write)
    {
      if (!present) return -1;
      if (pointer != null)
      {
        setBE4(pointer, offset, tag);
        setBE4(pointer, offset + 4, length);
      }
      offset += 8;
      return length;
    }

    // When decoding, presence is decided by the data, not by the caller's argument.
    if (getNextTag() != tag) return -1;
    return readTag(tag);
  }

  /** Decoding only: skips every consecutive item whose tag equals {@code tag}. */
  public void skip(int tag) throws KMIPConvertException
  {
    assert (!write);

    while (tag == getNextTag())
    {
      int length = readTag(tag); // never negative: readTag rejects negative lengths
      offset += padLength(length);
    }
  }

  /**
   * Encodes/decodes an optional Boolean item (8-byte value, last byte non-zero means true).
   *
   * @return the value written/read, or {@code null} if the item is absent
   */
  public Boolean convertOptional(int tag, Boolean value) throws KMIPConvertException
  {
    assert ((byte) tag == KMIP.TagType.Boolean);

    if (0 > convertTag(tag, 8, value != null)) return null;

    if (write)
    {
      if (pointer != null)
      {
        Arrays.fill(pointer, offset, offset + 7, (byte) 0);
        pointer[offset + 7] = Boolean.TRUE.equals(value) ? (byte) 1 : (byte) 0;
      }
    }
    else
    {
      value = pointer[offset + 7] != 0;
    }

    offset += 8;
    return value;
  }

  /** Required variant of {@link #convertOptional(int, Boolean)}: fails if the value is missing. */
  public Boolean convert(int tag, Boolean value) throws KMIPConvertException
  {
    if (write && value == null) setError("Expected write for tag 0x" + HEX.toString(tag));
    Boolean newValue = convertOptional(tag, value);
    if (!write && newValue == null) setError("Expected read for tag 0x" + HEX.toString(tag));
    return newValue;
  }

  /** Encodes/decodes an optional 4-byte Integer, Enumeration or Interval item. */
  public Integer convertOptional(int tag, Integer value) throws KMIPConvertException
  {
    return convertOptional(tag, value, 0);
  }

  /**
   * Encodes/decodes an optional 4-byte Integer, Enumeration or Interval item (value padded to 8 bytes).
   *
   * @param actualTag currently ignored
   * @return the value written/read, or {@code null} if the item is absent
   */
  public Integer convertOptional(int tag, Integer value, int actualTag) throws KMIPConvertException
  {
    assert ((byte) tag == KMIP.TagType.Integer || (byte) tag == KMIP.TagType.Enumeration || (byte) tag == KMIP.TagType.Interval);

    if (0 > convertTag(tag, 4, value != null)) return null;

    if (write)
    {
      if (pointer != null)
      {
        setBE4(pointer, offset, value == null ? 0 : value);
        Arrays.fill(pointer, offset + 4, offset + 8, (byte) 0);
      }
    }
    else
    {
      value = getBE4(pointer, offset);
    }

    offset += 8;
    return value;
  }

  /** Required variant of {@link #convertOptional(int, Integer)}. */
  public Integer convert(int tag, Integer value) throws KMIPConvertException
  {
    return convert(tag, value, 0);
  }

  /** Required variant of {@link #convertOptional(int, Integer, int)}: fails if the value is missing. */
  public Integer convert(int tag, Integer value, int actualTag) throws KMIPConvertException
  {
    if (write && value == null) setError("Expected write for tag 0x" + HEX.toString(tag));
    Integer newValue = convertOptional(tag, value, actualTag);
    if (!write && newValue == null) setError("Expected read for tag 0x" + HEX.toString(tag));
    return newValue;
  }

  /**
   * Encodes/decodes an optional 8-byte LongInteger or DateTime item.
   *
   * @return the value written/read, or {@code null} if the item is absent
   */
  public Long convertOptional(int tag, Long value) throws KMIPConvertException
  {
    assert ((byte) tag == KMIP.TagType.LongInteger || (byte) tag == KMIP.TagType.DateTime);

    if (0 > convertTag(tag, 8, value != null)) return null;

    if (write)
    {
      if (pointer != null) setBE8(pointer, offset, value == null ? 0 : value);
    }
    else
    {
      value = getBE8(pointer, offset);
    }

    offset += 8;
    return value;
  }

  /** Required variant of {@link #convertOptional(int, Long)}: fails if the value is missing. */
  public Long convert(int tag, Long value) throws KMIPConvertException
  {
    if (write && value == null) setError("Expected write for tag 0x" + HEX.toString(tag));
    Long newValue = convertOptional(tag, value);
    if (!write && newValue == null) setError("Expected read for tag 0x" + HEX.toString(tag));
    return newValue;
  }

  /**
   * Encodes/decodes an optional UTF-8 TextString item (value zero-padded to 8 bytes).
   *
   * @return the value written/read, or {@code null} if the item is absent
   */
  public String convertOptional(int tag, String value) throws KMIPConvertException
  {
    assert ((byte) tag == KMIP.TagType.TextString);

    byte[] utf8 = value == null ? null : value.getBytes(StandardCharsets.UTF_8);

    int length = utf8 == null ? 0 : utf8.length;
    length = convertTag(tag, length, value != null);
    if (length < 0) return null;
    int pad = padLength(length);

    if (write)
    {
      if (pointer != null)
      {
        System.arraycopy(utf8, 0, pointer, offset, length);
        Arrays.fill(pointer, offset + length, offset + pad, (byte) 0);
      }
    }
    else
    {
      value = new String(pointer, offset, length, StandardCharsets.UTF_8);
    }

    offset += pad;
    return value;
  }

  /** Required variant of {@link #convertOptional(int, String)}: fails if the value is missing. */
  public String convert(int tag, String value) throws KMIPConvertException
  {
    if (write && value == null) setError("Expected write for tag 0x" + HEX.toString(tag));
    String newValue = convertOptional(tag, value);
    if (!write && newValue == null) setError("Expected read for tag 0x" + HEX.toString(tag));
    return newValue;
  }

  /**
   * Encodes/decodes an optional ByteString item (value zero-padded to 8 bytes).
   *
   * @return the value written, a fresh copy of the value read, or {@code null} if the item is absent
   */
  public byte[] convertOptional(int tag, byte[] value) throws KMIPConvertException
  {
    assert ((byte) tag == KMIP.TagType.ByteString);

    int length = value == null ? 0 : value.length;
    length = convertTag(tag, length, value != null);
    if (length < 0) return null;
    int pad = padLength(length);

    if (write)
    {
      if (pointer != null)
      {
        System.arraycopy(value, 0, pointer, offset, length);
        Arrays.fill(pointer, offset + length, offset + pad, (byte) 0);
      }
    }
    else
    {
      value = Arrays.copyOfRange(pointer, offset, offset + length);
    }

    offset += pad;
    return value;
  }

  /** Required variant of {@link #convertOptional(int, byte[])}: fails if the value is missing. */
  public byte[] convert(int tag, byte[] value) throws KMIPConvertException
  {
    if (write && value == null) setError("Expected write for tag 0x" + HEX.toString(tag));
    byte[] newValue = convertOptional(tag, value);
    if (!write && newValue == null) setError("Expected read for tag 0x" + HEX.toString(tag));
    return newValue;
  }

  /**
   * Encodes/decodes an optional BigInteger item.
   *
   * <p>When writing, the length is the byte count needed for {@code bitLength() + 1} bits (room for a
   * sign bit), rounded up to a multiple of 8; serialization itself is delegated to
   * {@code BigNumAttribute.toBin}/{@code fromBin}.
   *
   * @return the value written/read, or {@code null} if the item is absent
   */
  public BigInteger convertOptional(int tag, BigInteger value) throws KMIPConvertException
  {
    assert ((byte) tag == KMIP.TagType.BigInteger);

    int length = 0;
    if (write && value != null)
    {
      int bits = value.bitLength();
      int num = (bits + 7) / 8;
      // An extra byte is needed when the magnitude fills the last byte, to hold the sign bit.
      length = padLength(num + (((bits & 0x07) == 0) ? 1 : 0));
    }

    length = convertTag(tag, length, value != null);
    if (length < 0) return null;
    int pad = padLength(length);

    if (write)
    {
      if (pointer != null)
      {
        BigNumAttribute.toBin(value, pointer, offset, pad);
      }
    }
    else
    {
      value = BigNumAttribute.fromBin(pointer, offset, length);
    }

    offset += pad;
    return value;
  }

  /** Required variant of {@link #convertOptional(int, BigInteger)}: fails if the value is missing. */
  public BigInteger convert(int tag, BigInteger value) throws KMIPConvertException
  {
    if (write && value == null) setError("Expected write for tag 0x" + HEX.toString(tag));
    BigInteger newValue = convertOptional(tag, value);
    if (!write && newValue == null) setError("Expected read for tag 0x" + HEX.toString(tag));
    return newValue;
  }

  /**
   * Encodes/decodes a required enum constant as a TextString holding its {@code toString()} form.
   *
   * @param en       value to write (ignored when reading)
   * @param enumType enum class used to resolve the decoded name
   * @return the resolved constant, or {@code null} if the string is empty
   * @throws KMIPConvertException if the item is missing or the name is not a constant of {@code enumType}
   */
  public <T extends Enum<T>> T convert(int tag, Enum<?> en, Class<T> enumType) throws KMIPConvertException
  {
    String name = convert(tag, en == null ? null : en.toString());
    if (name == null || name.isEmpty()) return null;
    try
    {
      return Enum.valueOf(enumType, name);
    }
    catch (IllegalArgumentException e)
    {
      // Decoded names come from the wire; report them as a conversion error, not an unchecked exception.
      throw new KMIPConvertException("Unknown " + enumType.getSimpleName() + " value " + name);
    }
  }

  /**
   * Opens an optional Structure item. The header is written with length 0; {@link #convertEnd(int)}
   * patches the real length when writing.
   *
   * @return the offset of the structure header, or -1 if the structure is absent
   */
  int convertBeginOptional(int tag, boolean present) throws KMIPConvertException
  {
    assert ((byte) tag == KMIP.TagType.Structure);

    int length = convertTag(tag, 0, present);
    if (length < 0) return -1;
    return offset - 8;
  }

  /**
   * Opens a required Structure item.
   *
   * @return the offset of the structure header, to be passed to {@link #convertEnd(int)}
   * @throws KMIPConvertException if the structure is absent
   */
  public int convertBegin(int tag) throws KMIPConvertException
  {
    int result = convertBeginOptional(tag, true);
    if (result < 0) setError("Expected tag 0x" + HEX.toString(tag));
    return result;
  }

  /**
   * Closes a Structure opened at {@code begin}; does nothing when {@code begin} is negative.
   *
   * <p>Writing: stores the content length in the header and pads the content to 8 bytes. Reading:
   * moves the offset to just past the padded structure, skipping any unread contents.
   */
  public void convertEnd(int begin) throws KMIPConvertException
  {
    if (begin < 0) return;
    if (write)
    {
      int length = offset - (begin + 8);
      int pad = padLength(length);
      if (pointer != null)
      {
        setBE4(pointer, begin + 4, length);
        Arrays.fill(pointer, offset, offset + pad - length, (byte) 0);
      }
      offset += pad - length;
    }
    else
    {
      int length = getBE4(pointer, begin + 4);
      offset = begin + 8 + padLength(length);
    }
  }

  /** Writes every element of {@code list}, or appends every consecutive {@code tag} item read. */
  public void convertIntList(int tag, List<Integer> list) throws KMIPConvertException
  {
    if (write) for (Integer value : list) convert(tag, value);
    else while (getNextTag() == tag) list.add(convert(tag, (Integer) null));
  }

  /** Writes every element of {@code list}, or appends every consecutive {@code tag} item read. */
  public void convertStrList(int tag, List<String> list) throws KMIPConvertException
  {
    if (write) for (String value : list) convert(tag, value);
    else while (getNextTag() == tag) list.add(convert(tag, (String) null));
  }

  /** Writes every element of {@code list}, or appends every consecutive {@code tag} item read. */
  public void convertBufList(int tag, List<byte[]> list) throws KMIPConvertException
  {
    if (write) for (byte[] value : list) convert(tag, value);
    else while (getNextTag() == tag) list.add(convert(tag, (byte[]) null));
  }

  /** @return {@code true} if this converter encodes */
  public boolean isWrite()
  {
    return write;
  }

  /** @return {@code true} if this converter decodes */
  public boolean isRead()
  {
    return !isWrite();
  }

  /**
   * Encodes/decodes a request item preceded by a {@code DyPresent} Boolean flag.
   *
   * @return the item written/read, or {@code null} if the flag is false
   */
  public RequestItem convertRequestItem(RequestItem requestItem) throws KMIPConvertException
  {
    boolean present = convert(KMIP.Tag.DyPresent, requestItem != null);
    if (!present) return requestItem;

    if (write)
    {
      RequestItem.convert(this, requestItem);
    }
    else
    {
      requestItem = RequestItem.convert(this, null);
      if (requestItem == null) throw new KMIPConvertException("Invalid request item");
    }
    return requestItem;
  }

  /**
   * Encodes/decodes a response item preceded by a {@code DyPresent} Boolean flag.
   *
   * @return the item written/read, or {@code null} if the flag is false
   */
  public ResponseItem convertResponseItem(ResponseItem responseItem) throws KMIPConvertException
  {
    boolean present = convert(KMIP.Tag.DyPresent, responseItem != null);
    if (!present) return responseItem;

    if (write)
    {
      ResponseItem.convert(this, responseItem);
    }
    else
    {
      responseItem = ResponseItem.convert(this, null);
      if (responseItem == null) throw new KMIPConvertException("Invalid response item");
    }
    return responseItem;
  }

  /** Encodes {@code req} into a newly allocated, exactly-sized buffer. */
  public static byte[] convert(RequestMessage req) throws KMIPConvertException
  {
    // Pass 1: with a null buffer only the offset advances, yielding the encoded size.
    KMIPConverter sizer = new KMIPConverter(true, null, 0, 0);
    req.convert(sizer);
    // Pass 2: encode into a buffer of exactly that size.
    KMIPConverter writer = new KMIPConverter(true, new byte[sizer.offset], 0, 0);
    req.convert(writer);
    return writer.pointer;
  }

  /** Encodes {@code resp} into a newly allocated, exactly-sized buffer. */
  public static byte[] convert(ResponseMessage resp) throws KMIPConvertException
  {
    // Pass 1: with a null buffer only the offset advances, yielding the encoded size.
    KMIPConverter sizer = new KMIPConverter(true, null, 0, 0);
    resp.convert(sizer);
    // Pass 2: encode into a buffer of exactly that size.
    KMIPConverter writer = new KMIPConverter(true, new byte[sizer.offset], 0, 0);
    resp.convert(writer);
    return writer.pointer;
  }

  /**
   * Passes a custom attribute through as raw encoded bytes.
   *
   * <p>Writing: copies {@code value} verbatim (it must already be encoded and padded). Reading: captures
   * the raw bytes of the next item, together with any immediately following items that share its tag.
   *
   * @return {@code value} when writing, or the captured bytes when reading
   * @throws KMIPConvertException when reading at end of data or over a malformed item
   */
  public byte[] convertCustomAttribute(byte[] value) throws KMIPConvertException
  {
    if (write)
    {
      if (pointer != null)
      {
        System.arraycopy(value, 0, pointer, offset, value.length);
      }
      offset += value.length;
      return value;
    }

    int begin = offset;
    skip(getNextTag());
    return Arrays.copyOfRange(pointer, begin, offset);
  }

  /** Decodes a request message from {@code in}. */
  public static RequestMessage convertRequestMessage(byte[] in) throws KMIPConvertException
  {
    KMIPConverter converter = new KMIPConverter(in);
    RequestMessage req = new RequestMessage();
    req.convert(converter);
    return req;
  }

  /** Decodes a response message from {@code in}. */
  public static ResponseMessage convertResponseMessage(byte[] in) throws KMIPConvertException
  {
    KMIPConverter converter = new KMIPConverter(in);
    ResponseMessage resp = new ResponseMessage();
    resp.convert(converter);
    return resp;
  }
}
