/*
 * Decompiled with CFR 0.152.
 */
package de.julielab.genemapper.classification;

import com.fasterxml.jackson.core.JsonFactory;
import com.fasterxml.jackson.core.JsonGenerator;
import com.fasterxml.jackson.core.type.TypeReference;
import com.fasterxml.jackson.databind.JsonMappingException;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.google.inject.Guice;
import com.google.inject.Inject;
import com.google.inject.Injector;
import de.julielab.geneexpbase.ioc.ServicesShutdownHub;
import de.julielab.geneexpbase.services.CacheService;
import de.julielab.geneexpbase.services.ShutdownRequiring;
import de.julielab.genemapper.Configuration;
import de.julielab.genemapper.hpo.GnHpoServer;
import de.julielab.genemapper.ioc.GeneMappingModule;
import de.julielab.genemapper.utils.GeneMapperRuntimeException;
import de.julielab.ipc.javabridge.Options;
import de.julielab.ipc.javabridge.StdioBridge;
import de.julielab.java.utilities.IOStreamUtilities;
import java.io.IOException;
import java.io.InputStream;
import java.io.StringWriter;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.concurrent.ArrayBlockingQueue;
import java.util.concurrent.BlockingQueue;
import javax.cache.Cache;
import org.apache.commons.codec.digest.DigestUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Scores sentence pairs with a transformer model that is executed by external Python processes.
 * <p>
 * The classpath script {@code /transformerCandidateRanking.py} is run through a pool of
 * {@link StdioBridge} instances. The pool size is taken from the configuration key
 * {@code python_process_limit} and defaults to 1. Each call to
 * {@link #classifySentencePairs(List, List)} borrows one bridge from a blocking queue, so
 * concurrent callers wait until a bridge is free. Scores are cached in the
 * {@code transformer-candidate-ranking-cache} cache. The cache key is derived from the model path
 * and the pair text.
 */
public class TransformerClassifier implements ShutdownRequiring {
    private static final Logger log = LoggerFactory.getLogger(TransformerClassifier.class);
    /** Classpath location of the Python script that every bridge executes. */
    private static final String SCRIPT_RESOURCE = "/transformerCandidateRanking.py";
    private static final String PYTHON_EXECUTABLE = "python";
    /** Model type argument passed to the script. */
    private static final String MODEL_TYPE = "bert";
    /** Passed to the script as its last argument (presumably a GPU index — TODO confirm in the script). */
    private static final int GPU_NUM = 0;
    /** Jackson type token for the JSON list of score strings returned by the script. */
    private static final TypeReference<List<String>> STRING_LIST_TYPE = new TypeReference<List<String>>() {};

    /** Idle bridges; a caller takes one for the duration of a request and puts it back afterwards. */
    private final BlockingQueue<StdioBridge<String>> bridges;
    /** Every bridge created by this instance, so that {@link #shutdown()} also reaches borrowed bridges. */
    private final List<StdioBridge<String>> allBridges;
    private final ObjectMapper objectMapper;
    private final Cache<String, Float> taggingCache;
    private final JsonFactory jsonFactory;
    private final String modelPath;

    /**
     * Creates the bridge pool. The external processes are started lazily on first use.
     *
     * @param cacheService  provides the {@code transformer-candidate-ranking-cache} score cache
     * @param modelPath     model location handed to the Python script; also part of every cache key
     * @param configuration read for {@code python_process_limit}, the number of bridges (default "1")
     * @throws IOException if the Python script cannot be read from the classpath
     * @throws IllegalArgumentException if {@code python_process_limit} is smaller than 1
     */
    @Inject
    public TransformerClassifier(CacheService cacheService, String modelPath, de.julielab.geneexpbase.configuration.Configuration configuration) throws IOException {
        this.modelPath = modelPath;
        Options<String> params = new Options<String>(String.class);
        params.setExecutable(PYTHON_EXECUTABLE);
        params.setExternalProgramReadySignal("Ready for tagging.");
        params.setExternalProgramTerminationSignal("exit");
        params.setTerminationSignalFromErrorStream("SyntaxError");
        // Only lines prefixed with "result:" carry output; the prefix (7 characters) is stripped.
        params.setResultLineIndicator(l -> l.startsWith("result:"));
        params.setResultReshaper(l -> l.substring(7));
        String script = loadScript();
        int numBridges = Integer.parseInt((String)configuration.getOrDefault((Object)"python_process_limit", "1"));
        if (numBridges < 1) {
            throw new IllegalArgumentException("python_process_limit must be at least 1 but was " + numBridges);
        }
        this.bridges = new ArrayBlockingQueue<>(numBridges);
        this.allBridges = new ArrayList<>(numBridges);
        for (int i = 0; i < numBridges; ++i) {
            StdioBridge<String> bridge = new StdioBridge<String>(params, "-u", "-c", script, MODEL_TYPE, modelPath, String.valueOf(GPU_NUM));
            this.bridges.add(bridge);
            this.allBridges.add(bridge);
        }
        this.jsonFactory = new JsonFactory();
        this.objectMapper = new ObjectMapper();
        this.taggingCache = cacheService.getCacheManager().getCache("transformer-candidate-ranking-cache");
    }

    /** Small manual smoke test: scores three example pairs and prints the result. */
    public static void main(String[] args) throws IOException {
        Injector injector = Guice.createInjector(new GeneMappingModule(new Configuration(GnHpoServer.CONFIG_WITH_SA_AND_LUCENE_SETTINGS)));
        try {
            TransformerClassifier c = new TransformerClassifier(injector.getInstance(CacheService.class), "../built-resources/models/transformer-v23-allMatches-includeFpMentions", new Configuration());
            try {
                List<Float> floats = c.classifySentencePairs(List.of("symbol mtor synonym mechanistic target of rapamycin", "symbol akt1 synonym akt", "symbol il2 synonym interleukin-2"), List.of("<< mtor >> frap akt1", "mtor frap << akt1 >>", "mtor << frap >> akt1"));
                System.out.println(floats);
            } finally {
                c.shutdown();
            }
        } finally {
            injector.getInstance(ServicesShutdownHub.class).shutdown();
        }
    }

    public String getModelPath() {
        return this.modelPath;
    }

    /**
     * Scores each pair {@code (l1.get(i), l2.get(i))}.
     * <p>
     * Cached scores are reused. All remaining pairs are sent to one bridge in a single JSON request,
     * and the resulting scores are written to the cache.
     *
     * @param l1 first element of each pair (sent as "left")
     * @param l2 second element of each pair (sent as "right"); must have at least {@code l1.size()}
     *           elements, and elements beyond that are ignored
     * @return one score per element of {@code l1}, in the same order
     * @throws IOException if communicating with the external process fails, or if it returns no
     *                     result or a different number of scores than pairs sent
     * @throws IllegalArgumentException if {@code l2} is shorter than {@code l1}
     * @throws GeneMapperRuntimeException if the response cannot be mapped to a list of strings, or
     *                                    the thread is interrupted while waiting for a bridge
     */
    public List<Float> classifySentencePairs(List<String> l1, List<String> l2) throws IOException {
        int n = l1.size();
        if (l2.size() < n) {
            throw new IllegalArgumentException("The second sentence list has fewer elements (" + l2.size() + ") than the first (" + n + ").");
        }
        List<Float> labels = new ArrayList<>(n);
        String[] cacheKeys = new String[n];
        List<Integer> uncachedIndices = new ArrayList<>(n);
        for (int i = 0; i < n; ++i) {
            // The plain concatenation can map different pairs to the same key (e.g. "aSEP"+"b" and "a"+"SEPb").
            // It is kept unchanged so that existing cache entries stay valid.
            String cacheKey = DigestUtils.md5Hex(this.modelPath + l1.get(i) + "SEP" + l2.get(i));
            cacheKeys[i] = cacheKey;
            Float cached = this.taggingCache.get(cacheKey);
            // A null placeholder for a cache miss is always replaced below, or an exception is thrown.
            labels.add(cached);
            if (cached == null) {
                uncachedIndices.add(i);
            }
        }
        if (uncachedIndices.isEmpty()) {
            return labels;
        }
        String request = toJsonRequest(l1, l2, uncachedIndices);
        List<String> results;
        try {
            results = classifyOnBridge(request);
        } catch (JsonMappingException e) {
            log.error("Error with transformer-based classification. Model path is {}", this.modelPath);
            throw new GeneMapperRuntimeException(e);
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new GeneMapperRuntimeException(e);
        }
        // Scores are matched to pairs by position, so the counts must agree exactly.
        if (results.size() != uncachedIndices.size()) {
            throw new IOException("The external classification process returned " + results.size() + " results for " + uncachedIndices.size() + " sentence pairs. Model path is " + this.modelPath);
        }
        for (int k = 0; k < results.size(); ++k) {
            int index = uncachedIndices.get(k);
            float label = Float.parseFloat(results.get(k));
            labels.set(index, label);
            this.taggingCache.put(cacheKeys[index], label);
        }
        return labels;
    }

    /** Serializes the selected pairs as a JSON array of {@code {"left": ..., "right": ...}} objects. */
    private String toJsonRequest(List<String> l1, List<String> l2, List<Integer> indices) throws IOException {
        StringWriter sw = new StringWriter();
        try (JsonGenerator generator = this.jsonFactory.createGenerator(sw)) {
            generator.writeStartArray();
            for (int i : indices) {
                generator.writeStartObject();
                generator.writeStringField("left", l1.get(i));
                generator.writeStringField("right", l2.get(i));
                generator.writeEndObject();
            }
            generator.writeEndArray();
        }
        return sw.toString();
    }

    /**
     * Borrows a bridge, starts it if needed, sends the request and parses the first result line
     * as a JSON list of strings. The bridge is always returned to the pool.
     */
    private List<String> classifyOnBridge(String request) throws IOException, InterruptedException {
        StdioBridge<String> bridge = this.bridges.take();
        try {
            if (!bridge.isRunning()) {
                bridge.start();
            }
            Iterator<?> response = bridge.sendAndReceive(request).iterator();
            if (!response.hasNext()) {
                throw new IOException("The external classification process did not return a result. Model path is " + this.modelPath);
            }
            return this.objectMapper.readValue((String) response.next(), STRING_LIST_TYPE);
        } finally {
            // Only bridges taken from this bounded queue are offered back, so this cannot fail
            // unless the pool is corrupted.
            if (!this.bridges.offer(bridge)) {
                throw new IllegalStateException("Could not put back an IO bridge.");
            }
        }
    }

    /** Loads the Python script from the classpath and closes the resource stream afterwards. */
    private String loadScript() throws IOException {
        try (InputStream in = this.getClass().getResourceAsStream(SCRIPT_RESOURCE)) {
            if (in == null) {
                throw new IOException("Could not find the classpath resource " + SCRIPT_RESOURCE);
            }
            return IOStreamUtilities.getStringFromInputStream(in);
        }
    }

    /**
     * Stops every running bridge, including bridges currently borrowed by a caller.
     * <p>
     * A failure to stop one bridge is logged and does not prevent the others from being stopped.
     * If an interruption occurs, the thread's interrupt status is restored at the end.
     */
    @Override
    public void shutdown() {
        boolean interrupted = false;
        for (StdioBridge<String> stdioBridge : this.allBridges) {
            try {
                if (!stdioBridge.isRunning()) continue;
                stdioBridge.stop();
            } catch (InterruptedException e) {
                // Clear-and-remember, so that stopping the remaining bridges is not cut short.
                interrupted = true;
                log.error("Exception while stopping external process", e);
            } catch (IOException e) {
                log.error("Exception while stopping external process", e);
            }
        }
        if (interrupted) {
            Thread.currentThread().interrupt();
        }
    }
}

