/*
 * Decompiled with CFR 0.152.
 */
package de.julielab.jcore.ae.flairner;

import de.julielab.jcore.ae.annotationadder.AnnotationAdderAnnotator;
import de.julielab.jcore.ae.annotationadder.AnnotationAdderConfiguration;
import de.julielab.jcore.ae.annotationadder.AnnotationAdderHelper;
import de.julielab.jcore.ae.annotationadder.AnnotationOffsetException;
import de.julielab.jcore.ae.annotationadder.annotationrepresentations.TextAnnotation;
import de.julielab.jcore.ae.flairner.NerTaggingResponse;
import de.julielab.jcore.ae.flairner.PythonConnector;
import de.julielab.jcore.ae.flairner.StdioPythonConnector;
import de.julielab.jcore.ae.flairner.TaggedEntity;
import de.julielab.jcore.ae.flairner.TokenEmbedding;
import de.julielab.jcore.types.EmbeddingVector;
import de.julielab.jcore.types.EntityMention;
import de.julielab.jcore.types.Sentence;
import de.julielab.jcore.types.Token;
import de.julielab.jcore.utility.JCoReAnnotationTools;
import de.julielab.jcore.utility.JCoReTools;
import de.julielab.jcore.utility.index.Comparators;
import de.julielab.jcore.utility.index.IndexTermGenerator;
import de.julielab.jcore.utility.index.JCoReTreeMapAnnotationIndex;
import de.julielab.jcore.utility.index.TermGenerators;
import java.io.IOException;
import java.lang.reflect.InvocationTargetException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.stream.Collectors;
import java.util.stream.StreamSupport;
import org.apache.uima.UimaContext;
import org.apache.uima.analysis_component.JCasAnnotator_ImplBase;
import org.apache.uima.analysis_engine.AnalysisEngineProcessException;
import org.apache.uima.cas.CASException;
import org.apache.uima.cas.FeatureStructure;
import org.apache.uima.cas.text.AnnotationIndex;
import org.apache.uima.fit.descriptor.ConfigurationParameter;
import org.apache.uima.fit.descriptor.ResourceMetaData;
import org.apache.uima.fit.descriptor.TypeCapability;
import org.apache.uima.jcas.JCas;
import org.apache.uima.jcas.cas.DoubleArray;
import org.apache.uima.jcas.cas.FSArray;
import org.apache.uima.jcas.tcas.Annotation;
import org.apache.uima.resource.ResourceInitializationException;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

@ResourceMetaData(name="JCoRe Flair Named Entity Recognizer", description="This component starts a child process to a python interpreter and loads a Flair sequence tagging model. Sentences are taken from the CAS, sent to Flair for tagging and the results are written into the CAS. The annotation type to use can be configured. It must be a subtype of de.julielab.jcore.types.EntityMention. The tag of each entity is written to the specificType feature.")
@TypeCapability(inputs={"de.julielab.jcore.types.Sentence", "de.julielab.jcore.types.Token"})
public class FlairNerAnnotator
extends JCasAnnotator_ImplBase {
    public static final String PARAM_ANNOTATION_TYPE = "AnnotationType";
    public static final String PARAM_FLAIR_MODEL = "FlairModel";
    public static final String PARAM_PYTHON_EXECUTABLE = "PythonExecutable";
    public static final String PARAM_STORE_EMBEDDINGS = "StoreEmbeddings";
    public static final String PARAM_GPU_NUM = "GpuNumber";
    public static final String PARAM_COMPONENT_ID = "ComponentId";
    public static final String GPU_NUM_SYS_PROP = "flairner.device";
    private static final Logger log = LoggerFactory.getLogger(FlairNerAnnotator.class);
    private PythonConnector connector;
    @ConfigurationParameter(name="AnnotationType", description="The UIMA type of which annotations should be created, e.g. de.julielab.jcore.types.EntityMention, of which the given type must be a subclass of. The tag of the entities is written to the specificType feature.")
    private String entityClass;
    @ConfigurationParameter(name="FlairModel", description="Path to the Flair sequence tagger model.")
    private String flairModel;
    @ConfigurationParameter(name="PythonExecutable", mandatory=false, description="The path to the python executable. Required is a python verion >=3.6. Defaults to 'python'.")
    private String pythonExecutable;
    @ConfigurationParameter(name="StoreEmbeddings", mandatory=false, description="Optional. Possible values: ALL, ENTITIES, NONE. The FLAIR SequenceTagger first computes the embeddings for each sentence and uses those as input for the actual NER algorithm. By default, the embeddings are not stored. By setting this parameter to ALL, the embeddings of all tokens of the sentence are retrieved from flair and stored in the embeddingVectors feature of each token. Setting the parameter to ENTITIES will restrict the embedding storage to those tokens which overlap with an entity recognized by FLAIR.")
    private StoreEmbeddings storeEmbeddings;
    @ConfigurationParameter(name="GpuNumber", mandatory=false, defaultValue={"0"}, description="Specifies the GPU device number to be used for FLAIR. This setting can be overwritten by the Java system property 'flairner.device'.")
    private int gpuNum;
    @ConfigurationParameter(name="ComponentId", mandatory=false, description="Specifies the componentId feature value given to the created annotations. Defaults to 'FlairNerAnnotator'.")
    private String componentId;
    private AnnotationAdderConfiguration adderConfig;

    public void initialize(UimaContext aContext) throws ResourceInitializationException {
        Optional<String> pythonExecutableOpt;
        this.entityClass = (String)aContext.getConfigParameterValue(PARAM_ANNOTATION_TYPE);
        this.flairModel = (String)aContext.getConfigParameterValue(PARAM_FLAIR_MODEL);
        this.storeEmbeddings = StoreEmbeddings.valueOf(Optional.ofNullable((String)aContext.getConfigParameterValue(PARAM_STORE_EMBEDDINGS)).orElse(StoreEmbeddings.NONE.name()));
        this.gpuNum = Optional.ofNullable((Integer)aContext.getConfigParameterValue(PARAM_GPU_NUM)).orElse(0);
        this.componentId = Optional.ofNullable((String)aContext.getConfigParameterValue(PARAM_COMPONENT_ID)).orElse(((Object)((Object)this)).getClass().getSimpleName());
        if (System.getProperty(GPU_NUM_SYS_PROP) != null) {
            try {
                this.gpuNum = Integer.valueOf(System.getProperty(GPU_NUM_SYS_PROP));
                log.info("The GPU device number is set to '" + this.gpuNum + "' by the system property 'flairner.device'. This causes the setting in the UIMA descriptor to be ignored.");
            }
            catch (NumberFormatException e) {
                log.error("The system property 'flairner.device' is set to '" + System.getProperty(GPU_NUM_SYS_PROP) + "' which cannot be parsed to an integer. Please provide the device number of the GPU to use.", (Throwable)e);
            }
        }
        if (!(pythonExecutableOpt = Optional.ofNullable((String)aContext.getConfigParameterValue(PARAM_PYTHON_EXECUTABLE))).isPresent()) {
            log.debug("No python executable given in the component descriptor, trying to read PYTHON environment variable.");
            String pythonExecutableEnv = System.getenv("PYTHON");
            if (pythonExecutableEnv != null) {
                this.pythonExecutable = pythonExecutableEnv;
                log.info("Python executable: {} (from environment variable PYTHON).", (Object)this.pythonExecutable);
            }
        } else {
            this.pythonExecutable = pythonExecutableOpt.get();
            log.info("Python executable: {} (from descriptor)", (Object)this.pythonExecutable);
        }
        if (this.pythonExecutable == null) {
            this.pythonExecutable = "python";
            log.info("Python executable: {} (default)", (Object)this.pythonExecutable);
        }
        try {
            this.connector = new StdioPythonConnector(this.flairModel, this.pythonExecutable, this.storeEmbeddings, this.gpuNum);
            this.connector.start();
        }
        catch (IOException e) {
            log.error("Could not start the python connector", (Throwable)e);
            throw new ResourceInitializationException((Throwable)e);
        }
        this.adderConfig = new AnnotationAdderConfiguration();
        this.adderConfig.setOffsetMode(AnnotationAdderAnnotator.OffsetMode.TOKEN);
        this.adderConfig.setSplitTokensAtWhitespace(true);
        this.adderConfig.setDefaultUimaType(this.entityClass);
        log.info("{}: {}", (Object)PARAM_ANNOTATION_TYPE, (Object)this.entityClass);
        log.info("{}: {}", (Object)PARAM_FLAIR_MODEL, (Object)this.flairModel);
        log.info("{}: {}", (Object)PARAM_STORE_EMBEDDINGS, (Object)this.storeEmbeddings);
        log.info("{}: {}", (Object)PARAM_GPU_NUM, (Object)this.gpuNum);
    }

    public void process(JCas aJCas) throws AnalysisEngineProcessException {
        int i = 0;
        AnnotationIndex sentIndex = aJCas.getAnnotationIndex(Sentence.class);
        HashMap<String, Sentence> sentenceMap = new HashMap<String, Sentence>();
        for (Sentence sentence : sentIndex) {
            if (sentence.getId() == null) {
                sentence.setId("s" + i++);
            }
            sentenceMap.put(sentence.getId(), sentence);
        }
        try {
            AnnotationAdderHelper helper = new AnnotationAdderHelper();
            NerTaggingResponse taggingResponse = this.connector.tagSentences(StreamSupport.stream(sentIndex.spliterator(), false));
            List<TaggedEntity> taggedEntities = taggingResponse.getTaggedEntities();
            for (TaggedEntity entity : taggedEntities) {
                Sentence sentence = (Sentence)sentenceMap.get(entity.getDocumentId());
                EntityMention em = (EntityMention)JCoReAnnotationTools.getAnnotationByClassName((JCas)aJCas, (String)this.entityClass);
                helper.setAnnotationOffsetsRelativeToSentence(sentence, (Annotation)em, (TextAnnotation)entity, this.adderConfig);
                em.setSpecificType(entity.getTag());
                em.setConfidence(String.valueOf(entity.getLabelConfidence()));
                em.setComponentId(this.componentId);
                em.addToIndexes();
            }
            this.addTokenEmbeddings(aJCas, sentenceMap, helper, taggingResponse);
        }
        catch (IOException e) {
            log.error("Could not tag entities", (Throwable)e);
            throw new AnalysisEngineProcessException((Throwable)e);
        }
        catch (ClassNotFoundException | IllegalAccessException | InstantiationException | NoSuchMethodException | InvocationTargetException e) {
            log.error("Could not create an instance of the entity class {}", (Object)this.entityClass);
            throw new AnalysisEngineProcessException((Throwable)e);
        }
        catch (CASException e) {
            log.error("Could not set the entity offsets", (Throwable)e);
            throw new AnalysisEngineProcessException((Throwable)e);
        }
        catch (AnnotationOffsetException e) {
            String docId = JCoReTools.getDocId((JCas)aJCas);
            log.error("Could not set the offsets of an annotation in document {}", (Object)docId);
            throw new AnalysisEngineProcessException((Throwable)e);
        }
    }

    private void addTokenEmbeddings(JCas aJCas, Map<String, Sentence> sentenceMap, AnnotationAdderHelper helper, NerTaggingResponse taggingResponse) throws CASException {
        List<TokenEmbedding> tokenEmbeddings = taggingResponse.getTokenEmbeddings();
        JCoReTreeMapAnnotationIndex tokenIndex = null;
        if (!tokenEmbeddings.isEmpty()) {
            tokenIndex = new JCoReTreeMapAnnotationIndex(Comparators.longOverlapComparator(), (IndexTermGenerator)TermGenerators.longOffsetTermGenerator(), (IndexTermGenerator)TermGenerators.longOffsetTermGenerator(), aJCas, Token.type);
        }
        HashMap<Token, List> originalTokenEmbeddings = new HashMap<Token, List>();
        for (TokenEmbedding tokenEmbedding : tokenEmbeddings) {
            Sentence sentence = sentenceMap.get(tokenEmbedding.getSentenceId());
            List tokens = (List)helper.createSentenceTokenMap(sentence, this.adderConfig).get(sentence);
            Token subtoken = (Token)tokens.get(tokenEmbedding.getTokenId() - 1);
            List overlappingOriginalTokens = tokenIndex.searchFuzzy((Annotation)subtoken).collect(Collectors.toList());
            for (Token originalToken : overlappingOriginalTokens) {
                List embeddingsOfToken = originalTokenEmbeddings.compute(originalToken, (t, l) -> {
                    if (l != null) {
                        return l;
                    }
                    return new ArrayList();
                });
                embeddingsOfToken.add(tokenEmbedding.getVector());
            }
        }
        for (Token token : originalTokenEmbeddings.keySet()) {
            int j;
            List subTokenEmbeddings = (List)originalTokenEmbeddings.get(token);
            double[] avgEmbedding = (double[])subTokenEmbeddings.get(0);
            for (j = 1; j < subTokenEmbeddings.size(); ++j) {
                for (int k = 0; k < avgEmbedding.length; ++k) {
                    int n = k;
                    avgEmbedding[n] = avgEmbedding[n] + ((double[])subTokenEmbeddings.get(j))[k];
                }
            }
            if (subTokenEmbeddings.size() > 1) {
                j = 0;
                while (j < avgEmbedding.length) {
                    int n = j++;
                    avgEmbedding[n] = avgEmbedding[n] / (double)subTokenEmbeddings.size();
                }
            }
            EmbeddingVector embeddingVector = new EmbeddingVector(aJCas, token.getBegin(), token.getEnd());
            DoubleArray uimaVector = new DoubleArray(aJCas, avgEmbedding.length);
            uimaVector.copyFromArray(avgEmbedding, 0, 0, avgEmbedding.length);
            embeddingVector.setVector(uimaVector);
            embeddingVector.setSource(this.flairModel);
            embeddingVector.setComponentId(this.componentId);
            token.setEmbeddingVectors(JCoReTools.addToFSArray((FSArray)token.getEmbeddingVectors(), (FeatureStructure)embeddingVector));
        }
    }

    public void collectionProcessComplete() throws AnalysisEngineProcessException {
        try {
            this.connector.shutdown();
        }
        catch (InterruptedException e) {
            log.error("Could not shutdown the python connector", (Throwable)e);
            throw new AnalysisEngineProcessException((Throwable)e);
        }
    }

    public static enum StoreEmbeddings {
        ALL,
        ENTITIES,
        NONE;

    }
}

