/*
 * Decompiled with CFR 0.152.
 */
package de.julielab.jcore.ae.fte;

import com.google.gson.Gson;
import de.julielab.ipc.javabridge.Options;
import de.julielab.ipc.javabridge.ResultDecoders;
import de.julielab.ipc.javabridge.StdioBridge;
import de.julielab.jcore.types.EmbeddingVector;
import de.julielab.jcore.types.Sentence;
import de.julielab.jcore.types.Token;
import de.julielab.jcore.utility.JCoReTools;
import de.julielab.jcore.utility.index.Comparators;
import de.julielab.jcore.utility.index.JCoReSetAnnotationIndex;
import java.io.IOException;
import java.io.InputStream;
import java.nio.charset.Charset;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import org.apache.commons.io.IOUtils;
import org.apache.commons.lang.StringUtils;
import org.apache.uima.UimaContext;
import org.apache.uima.analysis_component.JCasAnnotator_ImplBase;
import org.apache.uima.analysis_engine.AnalysisEngineProcessException;
import org.apache.uima.cas.FeatureStructure;
import org.apache.uima.cas.Type;
import org.apache.uima.fit.descriptor.ConfigurationParameter;
import org.apache.uima.fit.descriptor.ResourceMetaData;
import org.apache.uima.fit.descriptor.TypeCapability;
import org.apache.uima.fit.util.JCasUtil;
import org.apache.uima.jcas.JCas;
import org.apache.uima.jcas.cas.DoubleArray;
import org.apache.uima.jcas.cas.FSArray;
import org.apache.uima.jcas.tcas.Annotation;
import org.apache.uima.resource.ResourceInitializationException;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * UIMA annotator that attaches embedding vectors computed by an external Flair Python script to
 * {@link Token} annotations.
 * <p>
 * During {@link #initialize(UimaContext)} a Python process is started that runs the bundled
 * {@code getEmbeddingScript.py} with the configured {@link #PARAM_EMBEDDING_PATH} as argument. For each
 * CAS, the covered texts of the tokens of every {@link Sentence} are joined by single spaces and sent to
 * that process as JSON. The returned vectors are stored as {@link EmbeddingVector} instances in the
 * {@code embeddingVectors} feature of the respective tokens.
 * </p>
 */
@ResourceMetaData(name="JCoRe Flair Token Embedding Annotator", description="Adds the Flair compatible embedding vectors to the token annotations.")
@TypeCapability(inputs={"de.julielab.jcore.types.Sentence", "de.julielab.jcore.types.Token"}, outputs={"de.julielab.jcore.types.EmbeddingVector"})
public class FlairTokenEmbeddingAnnotator extends JCasAnnotator_ImplBase {
    public static final String PARAM_EMBEDDING_PATH = "EmbeddingPath";
    public static final String PARAM_COMPUTATION_FILTER = "ComputationFilter";
    public static final String PARAM_EMBEDDING_SOURCE = "EmbeddingSource";
    public static final String PARAM_PYTHON_EXECUTABLE = "PythonExecutable";
    private static final Logger log = LoggerFactory.getLogger(FlairTokenEmbeddingAnnotator.class);
    /** Number of processed documents after which timing statistics are logged at DEBUG level. */
    private static final int TIME_OUTPUT_INTERVAL = 1000;
    /** Classpath location of the Python script that computes the embeddings. */
    private static final String EMBEDDING_SCRIPT_RESOURCE = "/de/julielab/jcore/ae/fte/python/getEmbeddingScript.py";
    /** Interpreter used when neither the descriptor nor the PYTHON environment variable name one. */
    private static final String DEFAULT_PYTHON_EXECUTABLE = "python3.6";

    @ConfigurationParameter(name=PARAM_EMBEDDING_PATH, description="Path to a Flair compatible embedding file. Since flair supports a range of different embeddings, a type prefix is required. The syntax is 'prefix:<path or built-in flair embedding name>. The possible prefixes are 'word', 'char', 'bytepair', 'flair', 'bert', 'elmo'.")
    private String embeddingPath;
    @ConfigurationParameter(name=PARAM_COMPUTATION_FILTER, mandatory=false, description="This parameter may be set to a fully qualified annotation type. If given, only for documents containing at least one annotation of this type embeddings will be retrieved from the computing flair python script. However, for contextualized embeddings, all embedding vectors are computed anyway and the the I/O cost is minor in comparison to the embedding computation. Thus, setting this parameter will most probably only result in small time savings.")
    private String computationFilter;
    @ConfigurationParameter(name=PARAM_EMBEDDING_SOURCE, mandatory=false, description="The value of this parameter will be set to the source feature of the EmbeddingVector annotation instance created on the tokens. If left blank, the value of the EmbeddingPath will be used.")
    private String embeddingSource;
    @ConfigurationParameter(name=PARAM_PYTHON_EXECUTABLE, mandatory=false, description="The path to the python executable. Required is a python verion >=3.6.")
    private String pythonExecutable;
    /** Connection to the external Python process; byte[] messages are decoded via {@code ResultDecoders}. */
    private StdioBridge<byte[]> flairBridge;
    private Gson gson;
    /** Accumulated wall time (ms) of all embedding requests, including I/O. */
    private long embeddingRequestTime;
    /** Accumulated wall time (ms) of the embedding requests since the last statistics output. */
    private long embeddingRequestTimeForLastInterval;
    /** Number of documents for which an embedding request was sent. */
    private int docsProcessed;

    /**
     * Reads the configuration parameters and starts the external Python embedding process.
     *
     * @param aContext the UIMA context providing the configuration parameter values
     * @throws ResourceInitializationException if the embedding script cannot be loaded or the Python
     *                                         process cannot be started
     */
    @Override
    public void initialize(UimaContext aContext) throws ResourceInitializationException {
        // Lets the base class store the context so that getContext() works.
        super.initialize(aContext);
        this.embeddingPath = (String) aContext.getConfigParameterValue(PARAM_EMBEDDING_PATH);
        this.computationFilter = (String) aContext.getConfigParameterValue(PARAM_COMPUTATION_FILTER);
        String configuredSource = (String) aContext.getConfigParameterValue(PARAM_EMBEDDING_SOURCE);
        // The parameter description promises the fallback for blank values, not only for missing ones.
        this.embeddingSource = StringUtils.isBlank(configuredSource) ? this.embeddingPath : configuredSource;
        this.pythonExecutable = resolvePythonExecutable((String) aContext.getConfigParameterValue(PARAM_PYTHON_EXECUTABLE));
        this.flairBridge = startFlairBridge(this.pythonExecutable, this.embeddingPath);
        this.gson = new Gson();
        this.docsProcessed = 0;
        this.embeddingRequestTime = 0L;
        this.embeddingRequestTimeForLastInterval = 0L;
    }

    /**
     * Determines the Python interpreter: the descriptor value, then the {@code PYTHON} environment
     * variable, then {@link #DEFAULT_PYTHON_EXECUTABLE}.
     *
     * @param fromDescriptor the value of {@link #PARAM_PYTHON_EXECUTABLE}, may be {@code null}
     * @return the interpreter to run, never {@code null}
     */
    private static String resolvePythonExecutable(String fromDescriptor) {
        if (!StringUtils.isBlank(fromDescriptor)) {
            log.info("Python executable: {} (from descriptor)", fromDescriptor);
            return fromDescriptor;
        }
        log.debug("No python executable given in the component descriptor, trying to read PYTHON environment variable.");
        String fromEnvironment = System.getenv("PYTHON");
        if (!StringUtils.isBlank(fromEnvironment)) {
            log.info("Python executable: {} (from environment variable PYTHON).", fromEnvironment);
            return fromEnvironment;
        }
        log.info("Python executable: {} (default)", DEFAULT_PYTHON_EXECUTABLE);
        return DEFAULT_PYTHON_EXECUTABLE;
    }

    /**
     * Loads the bundled Python script and starts it as an external process connected via stdin/stdout.
     *
     * @param executable    the Python interpreter to run
     * @param embeddingPath the prefixed embedding specification passed to the script as its argument
     * @return the started bridge
     * @throws ResourceInitializationException if the script cannot be read or the process cannot be started
     */
    @SuppressWarnings("unchecked")
    private static StdioBridge<byte[]> startFlairBridge(String executable, String embeddingPath) throws ResourceInitializationException {
        try {
            Options options = new Options(byte[].class);
            options.setExecutable(executable);
            options.setExternalProgramTerminationSignal("exit");
            options.setExternalProgramReadySignal("Script is ready");
            options.setTerminationSignalFromErrorStream("SyntaxError");
            // "-u": unbuffered stdio so that the bridge sees the script's output immediately.
            StdioBridge<byte[]> bridge = new StdioBridge(options, new String[]{"-u", "-c", readEmbeddingScript(), embeddingPath});
            bridge.start();
            return bridge;
        } catch (IOException e) {
            log.error("Could not create the IO bridge object.", e);
            throw new ResourceInitializationException(e);
        }
    }

    /**
     * Reads the Python embedding script from the classpath as UTF-8 and closes the stream afterwards.
     *
     * @return the script source
     * @throws IOException if the script is missing from the classpath or cannot be read
     */
    private static String readEmbeddingScript() throws IOException {
        try (InputStream scriptStream = FlairTokenEmbeddingAnnotator.class.getResourceAsStream(EMBEDDING_SCRIPT_RESOURCE)) {
            if (scriptStream == null) {
                throw new IOException("The embedding script " + EMBEDDING_SCRIPT_RESOURCE + " was not found on the classpath.");
            }
            return IOUtils.toString(scriptStream, StandardCharsets.UTF_8);
        }
    }

    /**
     * Requests embeddings for the tokens of the given CAS and stores them on the tokens.
     * <p>
     * If a {@link #PARAM_COMPUTATION_FILTER} type is configured, documents without any annotation of
     * that type are skipped. Only tokens for which the filter index (built with
     * {@code Comparators.overlapComparator()}) reports a match receive embeddings. No request is sent
     * when there is no token to embed.
     * </p>
     *
     * @param aJCas the CAS to process
     * @throws AnalysisEngineProcessException if the filter type is unknown, the request is interrupted
     *                                        or the Python process returns too few vectors
     */
    @Override
    public void process(JCas aJCas) throws AnalysisEngineProcessException {
        JCoReSetAnnotationIndex<Annotation> filterAnnotationIndex = null;
        if (!StringUtils.isBlank(this.computationFilter)) {
            Type type = aJCas.getTypeSystem().getType(this.computationFilter);
            if (type == null) {
                throw new AnalysisEngineProcessException(new IllegalArgumentException("The type " + this.computationFilter + " was not found in the type system."));
            }
            if (!aJCas.getAnnotationIndex(type).iterator().hasNext()) {
                return;
            }
            filterAnnotationIndex = new JCoReSetAnnotationIndex<>(Comparators.overlapComparator(), aJCas, type);
        }
        List<Token> tokensToAddEmbeddingsTo = new ArrayList<>();
        String json = constructEmbeddingRequest(aJCas, tokensToAddEmbeddingsTo, filterAnnotationIndex);
        if (tokensToAddEmbeddingsTo.isEmpty()) {
            log.trace("No tokens to add embeddings to, skipping the embedding request.");
            return;
        }
        // nanoTime is monotonic and therefore suited for measuring elapsed time.
        long start = System.nanoTime();
        Optional<double[][]> embeddings;
        try {
            embeddings = this.flairBridge.sendAndReceive(json).map(ResultDecoders.decodeVectors).findAny();
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            log.error("Computation of embedding vectors was interrupted", e);
            throw new AnalysisEngineProcessException(e);
        }
        long time = (System.nanoTime() - start) / 1_000_000L;
        log.trace("Sending and receiving token embeddings took {} ms", time);
        this.embeddingRequestTime += time;
        this.embeddingRequestTimeForLastInterval += time;
        writeEmbeddingsToCas(aJCas, tokensToAddEmbeddingsTo, embeddings);
        ++this.docsProcessed;
        if (this.docsProcessed % TIME_OUTPUT_INTERVAL == 0) {
            if (log.isDebugEnabled()) {
                log.debug("Embedding computation for the last {} documents took {}ms (avg: {}ms). Total time for all {} processed documents until here: {}ms ({}s)",
                        new Object[]{TIME_OUTPUT_INTERVAL, this.embeddingRequestTimeForLastInterval, this.embeddingRequestTimeForLastInterval / TIME_OUTPUT_INTERVAL, this.docsProcessed, this.embeddingRequestTime, this.embeddingRequestTime / 1000L});
            }
            this.embeddingRequestTimeForLastInterval = 0L;
        }
    }

    /**
     * Creates one {@link EmbeddingVector} per token and appends it to the token's
     * {@code embeddingVectors} array. The i-th vector belongs to the i-th token.
     *
     * @param aJCas                  the CAS the tokens belong to
     * @param tokenToAddEmbeddingsTo the tokens, in the order in which they were put into the request
     * @param embeddingOptional      the decoded vectors, possibly absent
     * @throws AnalysisEngineProcessException if fewer vectors than tokens were returned
     */
    private void writeEmbeddingsToCas(JCas aJCas, List<Token> tokenToAddEmbeddingsTo, Optional<double[][]> embeddingOptional) throws AnalysisEngineProcessException {
        if (!embeddingOptional.isPresent()) {
            log.warn("No embedding vectors were received for {} tokens.", tokenToAddEmbeddingsTo.size());
            return;
        }
        double[][] embeddingVectors = embeddingOptional.get();
        if (embeddingVectors.length < tokenToAddEmbeddingsTo.size()) {
            throw new AnalysisEngineProcessException(new IllegalStateException("Received " + embeddingVectors.length + " embedding vectors for " + tokenToAddEmbeddingsTo.size() + " tokens."));
        }
        for (int i = 0; i < tokenToAddEmbeddingsTo.size(); ++i) {
            Token token = tokenToAddEmbeddingsTo.get(i);
            double[] embedding = embeddingVectors[i];
            DoubleArray embeddingArray = new DoubleArray(aJCas, embedding.length);
            embeddingArray.copyFromArray(embedding, 0, 0, embedding.length);
            EmbeddingVector casEmbedding = new EmbeddingVector(aJCas, token.getBegin(), token.getEnd());
            casEmbedding.setSource(this.embeddingSource);
            casEmbedding.setVector(embeddingArray);
            token.setEmbeddingVectors(JCoReTools.addToFSArray((FSArray) token.getEmbeddingVectors(), casEmbedding));
        }
    }

    /**
     * Builds the JSON request for the Python script: a list of objects with a space-joined
     * {@code sentence} text and a {@code tokenIndicesToReturn} list of sentence-relative token indices.
     * Tokens whose embeddings will be written are appended to {@code tokenToAddEmbeddingsTo} in request order.
     * <p>
     * Sentences that contribute no token to {@code tokenToAddEmbeddingsTo} are left out of the request.
     * Without a filter the index list is always empty, and all vectors are written to the tokens. The
     * script therefore presumably reads an empty list as "all tokens" (NOTE(review): confirm in
     * getEmbeddingScript.py). Sending a filtered-out sentence with an empty list would then return
     * surplus vectors and misalign all following tokens.
     * </p>
     *
     * @param aJCas                  the CAS to read sentences and tokens from
     * @param tokenToAddEmbeddingsTo output list receiving the tokens to annotate
     * @param filterAnnotationIndex  restricts the tokens to embed, or {@code null} for all tokens
     * @return the JSON request
     */
    private String constructEmbeddingRequest(JCas aJCas, List<Token> tokenToAddEmbeddingsTo, JCoReSetAnnotationIndex<Annotation> filterAnnotationIndex) {
        Map<Sentence, ? extends Collection<Token>> tokensBySentence = JCasUtil.indexCovered(aJCas, Sentence.class, Token.class);
        List<Map<String, Object>> sentencesAndIndices = new ArrayList<>();
        for (Sentence sentence : JCasUtil.select(aJCas, Sentence.class)) {
            Collection<Token> sentenceTokens = tokensBySentence.get(sentence);
            if (sentenceTokens == null || sentenceTokens.isEmpty()) {
                // Nothing to embed; also avoids sending an empty sentence text to the script.
                continue;
            }
            List<String> tokenTexts = new ArrayList<>(sentenceTokens.size());
            List<Integer> tokenIndicesToReturn = new ArrayList<>();
            int tokenIndex = 0;
            for (Token token : sentenceTokens) {
                tokenTexts.add(token.getCoveredText());
                if (filterAnnotationIndex == null) {
                    tokenToAddEmbeddingsTo.add(token);
                } else if (!filterAnnotationIndex.searchSubset(token).isEmpty()) {
                    tokenIndicesToReturn.add(tokenIndex);
                    tokenToAddEmbeddingsTo.add(token);
                }
                ++tokenIndex;
            }
            if (filterAnnotationIndex != null && tokenIndicesToReturn.isEmpty()) {
                // No token of this sentence was added above, so omitting the sentence keeps the alignment.
                continue;
            }
            Map<String, Object> sentenceAndIndices = new HashMap<>();
            sentenceAndIndices.put("sentence", String.join(" ", tokenTexts));
            sentenceAndIndices.put("tokenIndicesToReturn", tokenIndicesToReturn);
            sentencesAndIndices.add(sentenceAndIndices);
        }
        return this.gson.toJson(sentencesAndIndices);
    }

    /**
     * Logs the total embedding time and stops the external Python process.
     *
     * @throws AnalysisEngineProcessException if stopping the bridge fails or is interrupted
     */
    @Override
    public void collectionProcessComplete() throws AnalysisEngineProcessException {
        super.collectionProcessComplete();
        if (log.isDebugEnabled()) {
            log.debug("The total time for embedding computation, including I/O, was {}ms ({}s)", this.embeddingRequestTime, this.embeddingRequestTime / 1000L);
        }
        if (this.flairBridge == null) {
            return;
        }
        try {
            this.flairBridge.stop();
        } catch (IOException | InterruptedException e) {
            if (e instanceof InterruptedException) {
                Thread.currentThread().interrupt();
            }
            log.error("Exception when trying shut down IO bridge to the python embedding computation script", e);
            throw new AnalysisEngineProcessException(e);
        }
    }
}

