import flair
import json
import sys
import torch
from flair.data import Sentence
from flair.models import SequenceTagger
from struct import *


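# The input comes as a stream of length-prefixed messages. The first
# 4 bytes hold the length (in bytes) of the message that follows, as a
# big-endian integer. The message itself is UTF-8 encoded text; the main
# loop below expects it to be either the word "exit" or a JSON list of
# objects whose 'text' field is the sentence to tag (with tokens
# separated by whitespace).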
def decodeString(buffer):
    """Read one length-prefixed UTF-8 message from a binary stream.

    The wire format is a 4-byte big-endian length followed by that many
    bytes of UTF-8 encoded text.

    Args:
        buffer: A binary stream exposing ``readinto`` (for example
            ``sys.stdin.buffer``).

    Returns:
        The decoded message as ``str``; a zero-length message yields ``""``.

    Raises:
        EOFError: If the stream ends (or yields no data) before the length
            prefix or the announced payload has been read completely.
        UnicodeDecodeError: If the payload is not valid UTF-8.
    """
    length = int.from_bytes(_readExactly(buffer, 4), 'big')
    return _readExactly(buffer, length).decode("utf-8")


def _readExactly(buffer, count):
    """Return a bytearray holding exactly ``count`` bytes read from ``buffer``."""
    data = bytearray(count)
    view = memoryview(data)
    filled = 0
    while filled < count:
        # readinto may deliver fewer bytes than requested, so keep reading
        # into the unfilled tail until the buffer is complete.
        received = buffer.readinto(view[filled:])
        if not received:
            # 0 (end of stream) or None (no data available): without this
            # check the zero-filled remainder would be decoded as if it
            # were real message content.
            raise EOFError(
                "stream ended after %d of %d expected bytes" % (filled, count)
            )
        filled += received
    return data

# Command-line arguments: the model path handed to SequenceTagger.load and
# the CUDA device index (kept as a string; it is appended to "cuda:").
taggerPath = sys.argv[1]
gpuNum = sys.argv[2]

# Select the requested GPU only when CUDA is available; otherwise
# flair.device is left untouched.
if torch.cuda.is_available():
    flair.device = torch.device("cuda:"+gpuNum)

tagger = SequenceTagger.load(taggerPath)

# Flush explicitly: when stdout is a pipe it is block-buffered, and a peer
# waiting for this line would otherwise never see it.
print("Ready for tagging.", flush=True)
stdbuffer = sys.stdin.buffer
while True:
    # Each message is either the word "exit" or a JSON list of tagging
    # requests.
    line = decodeString(stdbuffer)
    if line.strip() == "exit":
        sys.exit(0)
    sentenceTaggingRequests = json.loads(line)
    taggedEntities = []
    # Results are appended in request order, so the response list lines up
    # positionally with the request list.
    for sentenceToTag in sentenceTaggingRequests:
        sentence = Sentence(sentenceToTag['text'])
        tagger.predict(sentence)

        # One tag per token of every "ner" span (a span's tag is repeated
        # for each of its tokens), concatenated without a separator.
        tags = [e.tag for e in sentence.get_spans("ner") for _ in e.tokens]
        taggedEntities.append("".join(tags))
    # Flush so the response reaches the peer before the next blocking read.
    print(json.dumps(taggedEntities), flush=True)