import json
import numpy as np
import sys
import torch
from simpletransformers.classification import ClassificationModel
from struct import *


def decodeString(buffer):
    """
    Read one length-prefixed UTF-8 message from a binary stream.

    Wire format: a 4-byte big-endian unsigned length, followed by exactly
    that many bytes of UTF-8 encoded text.

    :param buffer: Binary stream offering ``readinto`` (the script passes
        ``sys.stdin.buffer``).
    :return: The decoded message text; an empty string for a zero-length
        message.
    :raises EOFError: If the stream ends before the length prefix or the
        announced payload has been read completely.
    :raises UnicodeDecodeError: If the payload is not valid UTF-8.
    """
    def _read_exactly(count):
        # A single readinto() call may deliver fewer bytes than requested,
        # so keep reading until the buffer is full.
        data = bytearray(count)
        view = memoryview(data)
        filled = 0
        while filled < count:
            received = buffer.readinto(view[filled:])
            if not received:
                # 0 (or None) means no further data could be obtained.
                # Fail loudly instead of decoding the zero-filled remainder.
                raise EOFError(
                    "stream ended after %d of %d expected bytes" % (filled, count)
                )
            filled += received
        return data

    length = int.from_bytes(_read_exactly(4), 'big')
    return _read_exactly(length).decode("utf-8")

# Command-line arguments: model type, model path, CUDA device index.
modelType = sys.argv[1]
modelPath = sys.argv[2]
# Command-line arguments are strings, while cuda_device is an integer
# parameter of ClassificationModel (library default: -1). Convert so that a
# value such as "-1" is passed as the integer -1.
# NOTE(review): confirm against the installed simpletransformers version.
try:
    gpuNum = int(sys.argv[3])
except ValueError:
    # Keep the raw value so invocations that pass a non-numeric placeholder
    # behave exactly as before.
    gpuNum = sys.argv[3]

# Use the GPU whenever torch reports one as available.
useCuda = torch.cuda.is_available()

model_args = {
    "dataloader_num_workers": 1,
    "silent": True,
}
model = ClassificationModel(
    modelType,
    modelPath,
    use_cuda=useCuda,
    cuda_device=gpuNum,
    args=model_args,
)

# Announce readiness, then serve requests from stdin until "exit" arrives.
# Output is flushed explicitly: when stdout is a pipe it is block-buffered,
# so a parent process waiting for a line might otherwise never receive it.
print("Ready for tagging.", flush=True)
stdbuffer = sys.stdin.buffer
while True:
    # Each message is either the literal "exit" or a JSON array of objects;
    # only the 'left' and 'right' entries of each object are read.
    line = decodeString(stdbuffer)
    if line.strip() == "exit":
        sys.exit(0)
    sentencePairClassificationRequests = json.loads(line)
    pairsList = [
        [sentencePair['left'], sentencePair['right']]
        for sentencePair in sentencePairClassificationRequests
    ]
    if not pairsList:
        # Nothing to classify: answer with an empty result instead of
        # handing an empty batch to the model.
        print("result:[]", flush=True)
        continue
    labels, raw_outputs = model.predict(pairsList)
    # Normalise the raw scores of the '1' label (second column) over all
    # pairs of this request: the values sum to 1 across the pairs, i.e. this
    # is a softmax over the candidates, not over the labels of one pair.
    # Subtracting the maximum first leaves the result mathematically
    # unchanged but keeps np.exp from overflowing on large raw scores.
    positiveScores = raw_outputs[:, 1]
    expScores = np.exp(positiveScores - np.max(positiveScores))
    probs = expScores / np.sum(expScores)
    print("result:" + json.dumps(probs.tolist()), flush=True)