/*
 * Decompiled with CFR 0.152.
 */
package de.julielab.speciesassignment.mlcandidateranker;

import cc.mallet.classify.Classification;
import cc.mallet.classify.Classifier;
import cc.mallet.pipe.Pipe;
import cc.mallet.types.Instance;
import cc.mallet.types.InstanceList;
import de.julielab.geneexpbase.candidateretrieval.CandidateRetrieval;
import de.julielab.geneexpbase.candidateretrieval.GeneCandidateRetrievalException;
import de.julielab.geneexpbase.candidateretrieval.QueryGenerator;
import de.julielab.geneexpbase.configuration.Parameters;
import de.julielab.geneexpbase.genemodel.GeneDocument;
import de.julielab.geneexpbase.genemodel.GeneMention;
import de.julielab.geneexpbase.genemodel.GeneSet;
import de.julielab.geneexpbase.genemodel.GeneSpeciesOccurrence;
import de.julielab.geneexpbase.genemodel.MentionMappingResult;
import de.julielab.java.utilities.FileUtilities;
import de.julielab.ml.RankLibRanker;
import de.julielab.speciesassignment.Configuration;
import de.julielab.speciesassignment.GeneSpeciesAssigner;
import de.julielab.speciesassignment.SpeciesAssignmentException;
import de.julielab.speciesassignment.SpeciesAssignmentRuntimeException;
import de.julielab.speciesassignment.candidateretrieval.CandidateSetterForSpeciesTagger;
import de.julielab.speciesassignment.cooccurrence.SingularToSetInference;
import de.julielab.speciesassignment.mlcandidateranker.FeatureNormalization;
import de.julielab.speciesassignment.mlcandidateranker.SpeciesInstanceTools;
import de.julielab.speciesassignment.services.SpeciesAssignmentSentenceSmoothing;
import de.julielab.speciesassignment.services.SpeciesHintSetter;
import de.julielab.speciesassignment.spi.SpeciesAssignmentFilter;
import de.julielab.speciesassignment.spi.SpeciesAssignmentSmoothing;
import de.julielab.speciesassignment.spi.SpeciesDocumentScoringService;
import de.julielab.speciesassignment.spi.SpeciesReferenceMapper;
import de.julielab.speciesassignment.spi.SynonymSpeciesCooccurrenceService;
import java.io.BufferedInputStream;
import java.io.BufferedOutputStream;
import java.io.File;
import java.io.IOException;
import java.io.InputStream;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.OutputStream;
import java.util.ArrayList;
import java.util.Collections;
import java.util.EnumSet;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.OptionalDouble;
import java.util.Set;
import java.util.stream.Collectors;
import javax.annotation.Nullable;
import javax.inject.Inject;
import javax.inject.Named;
import org.apache.commons.lang3.Range;
import org.apache.commons.lang3.tuple.Pair;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class MLSpeciesAssigner
implements GeneSpeciesAssigner {
    private static final Logger log = LoggerFactory.getLogger(MLSpeciesAssigner.class);
    private final CandidateSetterForSpeciesTagger candidateSetterForSpeciesTagger;
    private final SingularToSetInference toSetInference;
    private final SpeciesAssignmentSmoothing assignmentSmoothing;
    private final SynonymSpeciesCooccurrenceService synonymSpeciesCooccurrenceService;
    private final SpeciesDocumentScoringService speciesDocumentScoringService;
    private final SpeciesAssignmentFilter speciesAssignmentFilter;
    private final SpeciesReferenceMapper speciesReferenceMapper;
    private Classifier classifier;
    private RankLibRanker ranker;
    private final QueryGenerator queryGenerator;
    private SpeciesInstanceTools instanceTools;
    private FeatureNormalization featureNormalization;
    private Pipe instancePipe;

    /**
     * Creates the assigner from its injected collaborators and, when a model file is given,
     * immediately loads the serialized model from it via {@link #loadModel(File)}.
     *
     * @param modelFile                         serialized model to load; when {@code null} no model is loaded
     *                                          and the instance starts without classifier or ranker
     * @param candidateSetterForSpeciesTagger   used to set candidates on documents and to obtain the candidate retrieval
     * @param queryGenerator                    passed to the candidate retrieval in {@link #setSpeciesHints}
     * @param speciesInstanceTools              creates the ML instances and receives the loaded instance pipe
     * @param toSetInference                    inference step applied after prediction when enabled by configuration
     * @param assignmentSmoothing               smoothing step applied after prediction when configured as "sentencewise"
     * @param synonymSpeciesCooccurrenceService source of a-priori taxonomy IDs for the fallback algorithms
     * @param speciesDocumentScoringService     source of document level taxonomy scores
     * @param speciesAssignmentFilter           applied to the document before instances are created
     * @param speciesReferenceMapper            receives the taxonomy occurrences of every gene
     * @throws SpeciesAssignmentException if loading the model file fails
     */
    @Inject
    public MLSpeciesAssigner(@Nullable @Named(value="mlSpeciesAssignerModelFile") File modelFile, CandidateSetterForSpeciesTagger candidateSetterForSpeciesTagger, QueryGenerator queryGenerator, SpeciesInstanceTools speciesInstanceTools, SingularToSetInference toSetInference, SpeciesAssignmentSentenceSmoothing assignmentSmoothing, SynonymSpeciesCooccurrenceService synonymSpeciesCooccurrenceService, SpeciesDocumentScoringService speciesDocumentScoringService, SpeciesAssignmentFilter speciesAssignmentFilter, SpeciesReferenceMapper speciesReferenceMapper) throws SpeciesAssignmentException {
        // Candidate retrieval collaborators.
        this.queryGenerator = queryGenerator;
        this.candidateSetterForSpeciesTagger = candidateSetterForSpeciesTagger;
        // Feature creation.
        this.instanceTools = speciesInstanceTools;
        // Services consulted before, during and after prediction.
        this.speciesAssignmentFilter = speciesAssignmentFilter;
        this.speciesReferenceMapper = speciesReferenceMapper;
        this.speciesDocumentScoringService = speciesDocumentScoringService;
        this.synonymSpeciesCooccurrenceService = synonymSpeciesCooccurrenceService;
        this.toSetInference = toSetInference;
        this.assignmentSmoothing = assignmentSmoothing;
        if (modelFile == null) {
            // No model configured: classifier/ranker may still be provided through the setters.
            return;
        }
        this.loadModel(modelFile);
    }

    /**
     * Sets the classifier used for prediction. Only the classifier field is changed; the
     * instance pipe and feature normalization held by this object are left as they are.
     *
     * @param classifier the classifier to use; when non-null it takes precedence over the
     *                   ranker in {@code assignWithTrainedModel}
     */
    public void setClassifier(Classifier classifier) {
        this.classifier = classifier;
    }

    /**
     * Opens {@code inputFile} through {@link FileUtilities#getInputStreamFromFile} and loads
     * the serialized model from the resulting stream.
     *
     * @param inputFile the file containing the serialized model
     * @throws SpeciesAssignmentException wrapping any exception raised while opening or reading the file
     */
    public void loadModel(File inputFile) throws SpeciesAssignmentException {
        log.debug("Loading species assignment model from {}", inputFile);
        try (InputStream modelStream = FileUtilities.getInputStreamFromFile(inputFile)) {
            this.loadModel(modelStream);
        } catch (Exception cause) {
            // The trailing throwable is logged with its stack trace; the file fills the placeholder.
            log.error("Could not load model from {}", inputFile, cause);
            throw new SpeciesAssignmentException(cause);
        }
    }

    /**
     * Reads a serialized model followed by its {@link FeatureNormalization} from {@code input}
     * and installs both on this assigner.
     * <p>
     * The first object in the stream must be either a {@link Classifier} or a
     * {@link RankLibRanker}. The whole stream is read and validated before any field is
     * changed, so a failed load leaves the previously installed model untouched. On success
     * the model field of the other kind is cleared: prediction prefers the classifier over the
     * ranker, so a stale classifier would otherwise shadow a freshly loaded ranker.
     * <p>
     * After loading, the instance pipe is handed to the instance tools and its data alphabet
     * is stopped from growing. The stream is closed by this method.
     *
     * @param input stream positioned at the start of the serialized model
     * @throws IOException                if reading from the stream fails
     * @throws ClassNotFoundException     if a serialized class cannot be resolved
     * @throws ClassCastException         if the stream does not contain a supported model followed by a
     *                                    {@link FeatureNormalization}
     * @throws SpeciesAssignmentException declared for callers; not thrown directly by the code visible here
     */
    public void loadModel(InputStream input) throws IOException, ClassNotFoundException, SpeciesAssignmentException {
        Object model;
        Pipe loadedPipe;
        FeatureNormalization loadedNormalization;
        // NOTE(security): Java native deserialization executes code of the serialized classes.
        // Only load model files from trusted sources.
        try (ObjectInputStream ois = new ObjectInputStream(input)) {
            model = ois.readObject();
            if (model instanceof Classifier) {
                loadedPipe = ((Classifier) model).getInstancePipe();
            } else if (model instanceof RankLibRanker) {
                loadedPipe = ((RankLibRanker) model).getInstancePipe();
            } else {
                throw new ClassCastException("Unsupported species assignment model type: " + (model == null ? "null" : model.getClass().getName()));
            }
            loadedNormalization = (FeatureNormalization) ois.readObject();
        } catch (Exception e) {
            log.error("Exception occurred while trying to load species assignment model", e);
            throw e;
        }
        // Everything was read successfully: commit the new state in one go.
        if (model instanceof Classifier) {
            this.classifier = (Classifier) model;
            this.ranker = null;
        } else {
            this.ranker = (RankLibRanker) model;
            this.classifier = null;
        }
        this.instancePipe = loadedPipe;
        this.featureNormalization = loadedNormalization;
        this.instanceTools.injectServices(this.instancePipe);
        this.instancePipe.getDataAlphabet().stopGrowth();
    }

    /**
     * Serializes the current model to {@code outputFile}. The file is opened through
     * {@link FileUtilities#getOutputStreamToFile} and closed when writing has finished.
     *
     * @param outputFile   destination file
     * @param parameterMap forwarded unchanged to {@link #saveModel(OutputStream, Parameters)}
     * @throws IOException if the file cannot be opened or written
     */
    public void saveModel(File outputFile, Parameters parameterMap) throws IOException {
        try (OutputStream modelStream = FileUtilities.getOutputStreamToFile(outputFile)) {
            this.saveModel(modelStream, parameterMap);
        }
    }

    /**
     * Serializes the current model followed by the feature normalization to {@code output},
     * i.e. exactly the two-object layout that {@link #loadModel(InputStream)} reads back.
     * <p>
     * Exactly one model object is written. When both a classifier and a ranker are set, the
     * classifier is written because it is also the one preferred for prediction. Writing both
     * (or none) would produce a stream that {@code loadModel} cannot read.
     * <p>
     * The stream is closed once the objects have been written. When no model is set the
     * method fails before touching {@code output}.
     *
     * @param output       destination stream
     * @param parameterMap not used by this method
     * @throws IOException           if writing to the stream fails
     * @throws IllegalStateException if neither a classifier nor a ranker is set
     */
    public void saveModel(OutputStream output, Parameters parameterMap) throws IOException {
        Object model = this.classifier != null ? this.classifier : this.ranker;
        if (model == null) {
            throw new IllegalStateException("Cannot save the species assignment model: neither a classifier nor a ranker is set.");
        }
        try (ObjectOutputStream oos = new ObjectOutputStream(output)) {
            oos.writeObject(model);
            oos.writeObject(this.featureNormalization);
        }
    }

    /**
     * Replaces the instance tools that create the ML instances in {@code assign} and that
     * receive the instance pipe in {@code loadModel}.
     *
     * @param instanceTools the instance tools to use from now on
     */
    public void setInstanceTools(SpeciesInstanceTools instanceTools) {
        this.instanceTools = instanceTools;
    }

    /**
     * Assigns species to the genes of {@code document} without collecting classifications.
     * Delegates to the four-argument overload, passing {@code null} for both collector maps.
     *
     * @param document     the document whose genes should receive taxonomy IDs
     * @param parameterMap the configuration, forwarded unchanged
     * @throws SpeciesAssignmentException if the delegated assignment fails
     */
    @Override
    public void assign(GeneDocument document, Parameters parameterMap) throws SpeciesAssignmentException {
        this.assign(document, null, null, parameterMap);
    }

    /**
     * Retrieves gene candidates for {@code gm} without taxonomy restriction, stores them in
     * the mention's mapping result and then lets {@link SpeciesHintSetter} derive the species
     * hints. A mapping result is created first when the mention has none.
     *
     * @param gm           the gene mention to set hints for
     * @param parameterMap the configuration, passed on to {@link SpeciesHintSetter}
     */
    @Override
    public void setSpeciesHints(GeneMention gm, Parameters parameterMap) {
        if (gm.getMentionMappingResult() == null) {
            gm.setMentionMappingResult(new MentionMappingResult(gm));
        }
        CandidateRetrieval retrieval = this.candidateSetterForSpeciesTagger.getCandidateRetrieval();
        MentionMappingResult mappingResult = gm.getMentionMappingResult();
        // An empty list as second argument is used for the unrestricted retrieval.
        mappingResult.candidatesNoTaxRestriction = retrieval.getCandidates(gm, Collections.emptyList(), this.queryGenerator);
        SpeciesHintSetter.setSpeciesHints(gm, parameterMap);
    }

    /**
     * Does nothing; this implementation has an empty shutdown hook.
     */
    @Override
    public void shutdown() {
    }

    /**
     * Assigns species to the genes of {@code document} without passing a configuration.
     * Delegates to {@code assign(document, null)} and converts the checked
     * {@link SpeciesAssignmentException} into a {@link SpeciesAssignmentRuntimeException}.
     * <p>
     * NOTE(review): the parameters passed on here are {@code null}, while the four-argument
     * overload reads the machine learning algorithm from its parameters. As far as this file
     * shows, that call cannot succeed with {@code null} parameters - confirm whether default
     * parameters were meant to be supplied here.
     *
     * @param document the document whose genes should receive taxonomy IDs
     */
    @Override
    public void assign(GeneDocument document) {
        try {
            this.assign(document, null);
        }
        catch (SpeciesAssignmentException e) {
            throw new SpeciesAssignmentRuntimeException(e);
        }
    }

    /**
     * Sets the feature normalization that is applied to the instances before prediction and
     * that is written after the model by {@code saveModel}.
     *
     * @param featureNormalization the normalization to use from now on
     */
    public void setFeatureNormalization(FeatureNormalization featureNormalization) {
        this.featureNormalization = featureNormalization;
    }

    /**
     * Creates the ML instances for the genes of {@code document} and, when a classifier or a
     * ranker is available, predicts and assigns taxonomy IDs.
     * <p>
     * Steps, in order:
     * <ol>
     * <li>set gene candidates unless the document already has the state
     * {@code SYNONYM_CANDIDATES_ASSIGNED};</li>
     * <li>require the states {@code GENES_SELECTED}, {@code SPECIES_CANDIDATES_ASSIGNED},
     * {@code SYNONYM_CANDIDATES_ASSIGNED} and {@code REFERENCE_SPECIES_ADDED};</li>
     * <li>apply the species assignment filter and pass the taxonomy occurrences of every gene
     * to the species reference mapper;</li>
     * <li>create the instances, with parameter usage tracking switched on for the duration of
     * the instance creation;</li>
     * <li>in prediction mode only: normalize features, predict with
     * {@link #assignWithTrainedModel}, run the configured inference step and the sentencewise
     * smoothing if configured;</li>
     * <li>add the states {@code SPECIES_SCORES_ASSIGNED} and
     * {@code SPECIES_ASSIGNED_TO_GENES} (also when no model is set).</li>
     * </ol>
     * The feature pipe is the pipe of the loaded model if there is one, otherwise the value of
     * the {@code "featurePipe"} parameter.
     *
     * @param document          the document to process
     * @param classificationMap optional collector for the classifier output per gene mention offsets
     * @param instanceListMap   optional collector mapping offsets to gene mentions (deprecated)
     * @param parameterMap      the configuration; must not be {@code null} because the machine
     *                          learning algorithm is read from it
     * @throws SpeciesAssignmentException if candidate retrieval fails
     * @throws NullPointerException       if {@code parameterMap} is {@code null}
     * @throws IllegalArgumentException   if the configured algorithm is not one of
     *                                    {@code svm}, {@code maxent} or {@code ltr} (including when it is missing)
     * @throws IllegalStateException      if a model is set but no feature normalization
     */
    public void assign(GeneDocument document, @Nullable Map<Range<Integer>, List<Classification>> classificationMap, @Deprecated @Nullable Map<Range<Integer>, GeneMention> instanceListMap, Parameters parameterMap) throws SpeciesAssignmentException {
        if (parameterMap == null) {
            // Fail before any document state is touched instead of half way through.
            throw new NullPointerException("parameterMap must not be null: the machine learning algorithm is read from it");
        }
        try {
            if (!document.hasState(GeneDocument.State.SYNONYM_CANDIDATES_ASSIGNED)) {
                this.candidateSetterForSpeciesTagger.setCandidates(document, (Map<String, Object>)parameterMap);
            }
            document.expectState(EnumSet.of(GeneDocument.State.GENES_SELECTED, GeneDocument.State.SPECIES_CANDIDATES_ASSIGNED, GeneDocument.State.SYNONYM_CANDIDATES_ASSIGNED, GeneDocument.State.REFERENCE_SPECIES_ADDED));
            this.speciesAssignmentFilter.filterAssignments(document);
            document.getGenes().map(GeneMention::getTaxonomyOccurrences).forEach(this.speciesReferenceMapper::addReferences);
            boolean predictionMode = this.classifier != null || this.ranker != null;
            String mlAlgorithmKey = de.julielab.geneexpbase.configuration.Configuration.dot("species_assignment", "ml", "algorithm");
            Pipe featurePipe = this.instancePipe != null ? this.instancePipe : (Pipe)parameterMap.getOrDefault((Object)"featurePipe", (Object)null);
            InstanceList instances;
            String trackingName = MLSpeciesAssigner.class.getSimpleName();
            parameterMap.startParameterUsageTracking(trackingName);
            try {
                Object mlAlgorithm = parameterMap.get(mlAlgorithmKey);
                // Constant-first equals keeps a missing (null) algorithm on the IllegalArgumentException path.
                boolean supported = "svm".equals(mlAlgorithm) || "maxent".equals(mlAlgorithm) || "ltr".equals(mlAlgorithm);
                if (!supported) {
                    throw new IllegalArgumentException("Unsupported machine learning method " + mlAlgorithm);
                }
                // All supported algorithms currently share the same instance creation.
                instances = this.instanceTools.createInstanceListsForMaxEntPerGeneMention(document, featurePipe, parameterMap);
            } finally {
                // Always stop tracking, also when instance creation fails.
                parameterMap.stopParameterUsageTracking(trackingName);
            }
            if (predictionMode) {
                if (this.featureNormalization == null) {
                    throw new IllegalStateException("Feature normalization function is not set");
                }
                this.featureNormalization.applyFeatureNormalization(instances, parameterMap);
                this.assignWithTrainedModel(document, classificationMap, instanceListMap, parameterMap);
                // The two inference strategies are mutually exclusive; the a-priori inference wins.
                if (parameterMap.getBoolean(Configuration.PARAM_SYNONYM_APRIORI_INFERENCE_ENABLE)) {
                    this.toSetInference.infer(document, parameterMap);
                } else if (parameterMap.getBoolean(Configuration.PARAM_ML_TAX_SCORE_INFERENCE_ENABLE)) {
                    this.delegateMaxScoreToGeneSets(document);
                }
                if ("sentencewise".equals(parameterMap.getString(Configuration.PARAM_SPECIES_ASSIGNMENT_SMOOTHING))) {
                    this.assignmentSmoothing.smooth(document, parameterMap);
                }
            }
            document.addState(GeneDocument.State.SPECIES_SCORES_ASSIGNED);
            document.addState(GeneDocument.State.SPECIES_ASSIGNED_TO_GENES);
        }
        catch (GeneCandidateRetrievalException e) {
            throw new SpeciesAssignmentException(e);
        }
    }

    /**
     * For every gene set of the document, sums the processed taxonomy scores of all member
     * mentions per taxonomy ID and assigns the taxonomy ID with the highest sum to the members.
     * <p>
     * Only sums strictly greater than 0 can win; a gene set without such a sum is left
     * unchanged. Members whose taxonomy occurrences contain {@code COMPOUND} or
     * {@code SPECIES_PREFIX} keep their current taxonomy IDs.
     *
     * @param document the document whose gene sets are processed
     */
    private void delegateMaxScoreToGeneSets(GeneDocument document) {
        for (GeneSet geneSet : document.getGeneSets()) {
            // Accumulate the per-taxonomy score over all mentions of the set.
            Map<String, Double> summedScores = new HashMap<>();
            for (GeneMention member : geneSet) {
                member.getProcessedTaxonomyScores().forEach((taxId, taxScore) -> summedScores.merge(taxId, taxScore, Double::sum));
            }
            String winningTax = null;
            double winningSum = 0.0;
            for (Map.Entry<String, Double> summed : summedScores.entrySet()) {
                Double sum = summed.getValue();
                if (sum > winningSum) {
                    winningSum = sum;
                    winningTax = summed.getKey();
                }
            }
            if (winningTax == null) {
                continue;
            }
            for (GeneMention member : geneSet) {
                boolean hasOwnSpeciesEvidence = member.getTaxonomyOccurrences().values().contains(GeneSpeciesOccurrence.COMPOUND)
                        || member.getTaxonomyOccurrences().values().contains(GeneSpeciesOccurrence.SPECIES_PREFIX);
                if (!hasOwnSpeciesEvidence) {
                    member.setTaxonomyIds(List.of(winningTax));
                }
            }
        }
    }

    /**
     * Ranks the instances of {@code gm} with the ranker and stores one taxonomy score per
     * ranked instance on the mention. The taxonomy ID is the instance source and the score is
     * the instance property {@code "score"}, except for taxonomy IDs whose occurrences contain
     * {@code COMPOUND_PRECED}: those always receive the score 1.0.
     * <p>
     * NOTE(review): the classifier based counterpart checks {@code COMPOUND} instead of
     * {@code COMPOUND_PRECED} - confirm that the difference is intended.
     *
     * @param gm the gene mention to score
     */
    private void setTaxonomyScoresWithRanker(GeneMention gm) {
        for (Instance ranked : this.ranker.rank(gm.getInstances())) {
            String taxId = (String) ranked.getSource();
            double rankerScore = (Double) ranked.getProperty("score");
            boolean compoundPreceded = gm.getTaxonomyOccurrences().get(taxId).contains(GeneSpeciesOccurrence.COMPOUND_PRECED);
            if (compoundPreceded) {
                gm.setTaxonomyScore(taxId, 1.0);
            } else {
                gm.setTaxonomyScore(taxId, rankerScore);
            }
        }
    }

    /**
     * Scores the taxonomy candidates of every gene mention with the trained model and assigns
     * exactly one taxonomy ID per mention.
     * <p>
     * Per gene mention:
     * <ol>
     * <li>raw scores are set with the classifier if there is one, otherwise with the ranker;</li>
     * <li>the processed scores start as a copy of the raw scores; optionally taxonomy IDs
     * without gene candidates are removed and weighted document level scores are added to the
     * taxonomy IDs still present;</li>
     * <li>if the highest processed score is below the acceptance threshold and the fallback
     * algorithm yields a taxonomy ID among the scored ones, that ID replaces all processed
     * scores with a score of 1.0;</li>
     * <li>the taxonomy ID with the highest processed score is assigned;</li>
     * <li>if the mention still has no taxonomy ID, the fallback algorithm is asked again
     * without restriction, and as a last resort {@code 9606} is assigned.</li>
     * </ol>
     *
     * @param document          the document whose genes are processed
     * @param classificationMap optional collector for classifier output, passed to the classifier scoring
     * @param instanceListMap   optional collector for gene mentions, passed to the classifier scoring (deprecated)
     * @param parameterMap      the configuration
     * @throws IllegalStateException if neither classifier nor ranker is set
     */
    public void assignWithTrainedModel(GeneDocument document, @Nullable Map<Range<Integer>, List<Classification>> classificationMap, @Deprecated @Nullable Map<Range<Integer>, GeneMention> instanceListMap, Parameters parameterMap) {
        double acceptanceThreshold = parameterMap.getDouble(Configuration.PARAM_SPECIES_ML_ACCEPTANCE_THRESHOLD);
        Map<String, Double> docLevelTaxScores = this.speciesDocumentScoringService.computeTaxDocScores(document, parameterMap);
        for (GeneMention gm : document.getGenesIterable()) {
            // Step 1: raw model scores.
            if (this.classifier != null) {
                this.setTaxonomyScoresWithClassifier(classificationMap, instanceListMap, gm);
            } else if (this.ranker != null) {
                this.setTaxonomyScoresWithRanker(gm);
            } else {
                throw new IllegalStateException("Either the classifier or the ranker need not to be null but both are.");
            }

            // Step 2: post-processing on a copy of the raw scores.
            gm.setProcessedTaxonomyScores(new HashMap<String, Double>(gm.getTaxonomyScores()));
            Map<String, Double> processedScores = gm.getProcessedTaxonomyScores();
            if (parameterMap.getBoolean(Configuration.PARAM_SPECIES_REMOVE_TAX_WITHOUT_GENE_CANDIDATES)) {
                this.removeTaxonomyCandidatesWithoutGeneCandidates(gm.getMentionMappingResult(), processedScores);
            }
            if (parameterMap.getBoolean(de.julielab.geneexpbase.configuration.Configuration.dot("species_assignment", "ml", "add_doc_level_apriori_scores"))) {
                double aprioriFactor = parameterMap.getDouble(Configuration.PARAM_ML_DOC_LEVEL_APRIORI_FACTOR);
                for (Map.Entry<String, Double> docLevelScore : docLevelTaxScores.entrySet()) {
                    String taxId = docLevelScore.getKey();
                    // Document level evidence never introduces a new taxonomy candidate.
                    if (processedScores.containsKey(taxId)) {
                        processedScores.merge(taxId, docLevelScore.getValue() * aprioriFactor, Double::sum);
                    }
                }
            }

            // Step 3: fallback when even the best score is not convincing.
            OptionalDouble topScore = processedScores.values().stream().mapToDouble(Double::doubleValue).max();
            if (topScore.isPresent() && topScore.getAsDouble() < acceptanceThreshold) {
                String belowThresholdFallback = this.getFallbackTaxonomyId(document, gm, gm.getTaxonomyScores().keySet(), this.synonymSpeciesCooccurrenceService, this.speciesDocumentScoringService, parameterMap);
                if (belowThresholdFallback != null) {
                    gm.setTaxonomyIds(List.of(belowThresholdFallback));
                    gm.setTaxonomyCandidates(Set.of(belowThresholdFallback));
                    processedScores.clear();
                    gm.setProcessedTaxonomyScore(belowThresholdFallback, 1.0);
                }
            }

            // Step 4: pick the taxonomy ID with the highest processed score.
            if (!processedScores.isEmpty()) {
                String topTax = null;
                double topValue = -Double.MAX_VALUE;
                for (Map.Entry<String, Double> scored : processedScores.entrySet()) {
                    Double candidateScore = scored.getValue();
                    if (candidateScore > topValue) {
                        topValue = candidateScore;
                        topTax = scored.getKey();
                    }
                }
                if (topTax != null) {
                    gm.setTaxonomyIds(List.of(topTax));
                    gm.setTaxonomyCandidates(Set.of(topTax));
                }
            }

            // Step 5: last resorts for mentions that are still without a taxonomy ID.
            if (gm.getTaxonomyIds().isEmpty()) {
                String unrestrictedFallback = this.getFallbackTaxonomyId(document, gm, Collections.emptySet(), this.synonymSpeciesCooccurrenceService, this.speciesDocumentScoringService, parameterMap);
                if (unrestrictedFallback != null) {
                    gm.setTaxonomyIds(List.of(unrestrictedFallback));
                    gm.setTaxonomyCandidates(Set.of(unrestrictedFallback));
                }
            }
            if (gm.getTaxonomyIds().isEmpty()) {
                log.debug("Could not obtain any taxonomy ID for gene mention {} in document {} with acceptance threshold {} and fallback algorithm {}. Setting tax ID 9606", gm, document.getId(), acceptanceThreshold, parameterMap.getString(Configuration.PARAM_SPECIES_ML_FALLBACK_ALGORITHM));
                gm.setTaxonomyIds(List.of("9606"));
                gm.setTaxonomyCandidates(Set.of("9606"));
            }
        }
    }

    /**
     * Removes every taxonomy ID from {@code processedTaxonomyScores} for which the mapping
     * result has no (or only an empty list of) original gene candidates.
     *
     * @param mmr                     the mapping result providing {@code tax2originalCandidates}
     * @param processedTaxonomyScores the score map to prune in place
     */
    private void removeTaxonomyCandidatesWithoutGeneCandidates(MentionMappingResult mmr, Map<String, Double> processedTaxonomyScores) {
        // Collect first, remove afterwards: the key set must not change while it is iterated.
        List<String> taxIdsWithoutCandidates = new ArrayList<>();
        for (String taxId : processedTaxonomyScores.keySet()) {
            if (mmr.tax2originalCandidates.getOrDefault(taxId, Collections.emptyList()).isEmpty()) {
                taxIdsWithoutCandidates.add(taxId);
            }
        }
        for (String taxId : taxIdsWithoutCandidates) {
            processedTaxonomyScores.remove(taxId);
        }
    }

    /**
     * Classifies the instances of {@code gm} and stores one taxonomy score per instance on the
     * mention. The score is the value the classification assigns to the target label
     * {@code 1.0f}; the taxonomy ID is the instance source. Taxonomy IDs whose occurrences
     * contain {@code COMPOUND} always receive the score 1.0.
     * <p>
     * When the collector maps are given, the classifications respectively the mention itself
     * are recorded under the offsets of the mention.
     *
     * @param classificationMap optional collector: offsets to classifications
     * @param instanceListMap   optional collector: offsets to gene mention (deprecated)
     * @param gm                the gene mention to score
     */
    public void setTaxonomyScoresWithClassifier(@Nullable Map<Range<Integer>, List<Classification>> classificationMap, @Deprecated @Nullable Map<Range<Integer>, GeneMention> instanceListMap, GeneMention gm) {
        ArrayList<Classification> predictions = this.classifier.classify(gm.getInstances());
        if (classificationMap != null) {
            classificationMap.put(gm.getOffsets(), predictions);
        }
        if (instanceListMap != null) {
            instanceListMap.put(gm.getOffsets(), gm);
        }
        // Classifications and instances are parallel: position i belongs to instance i.
        int position = 0;
        for (Classification prediction : predictions) {
            Instance instance = (Instance) gm.getInstances().get(position);
            position++;
            int positiveLabelIndex = gm.getInstances().getPipe().getTargetAlphabet().lookupIndex(Float.valueOf(1.0f));
            int positiveRank = prediction.getLabeling().getRank(positiveLabelIndex);
            double positiveValue = prediction.getLabeling().getValueAtRank(positiveRank);
            String taxId = (String) instance.getSource();
            boolean compound = gm.getTaxonomyOccurrences().get(taxId).contains(GeneSpeciesOccurrence.COMPOUND);
            if (compound) {
                gm.setTaxonomyScore(taxId, 1.0);
            } else {
                gm.setTaxonomyScore(taxId, positiveValue);
            }
        }
    }

    /**
     * Determines a taxonomy ID with the configured fallback algorithm
     * ({@code Configuration.PARAM_SPECIES_ML_FALLBACK_ALGORITHM}):
     * <ul>
     * <li>{@code apriori_all_syn_max}: best document level a-priori taxonomy ID of the
     * cooccurrence service;</li>
     * <li>{@code apriori_synlevel}: best a-priori taxonomy IDs for the best gene candidates of
     * the mention; a single result is returned directly, anything else goes through the tie
     * breaker;</li>
     * <li>{@code apriori_doclevel}: taxonomy ID with the best document level score.</li>
     * </ul>
     *
     * @param document                  the document of the gene mention
     * @param gm                        the gene mention that needs a taxonomy ID
     * @param filterTaxIds              taxonomy IDs passed to the services as filter
     * @param speciesCooccurrenceService source of a-priori taxonomy IDs
     * @param documentScoringService    source of document level taxonomy scores
     * @param parameterMap              the configuration
     * @return the fallback taxonomy ID, or {@code null} when the algorithm is unknown or finds nothing
     */
    private String getFallbackTaxonomyId(GeneDocument document, GeneMention gm, Set<String> filterTaxIds, SynonymSpeciesCooccurrenceService speciesCooccurrenceService, SpeciesDocumentScoringService documentScoringService, Parameters parameterMap) {
        String fallbackAlgorithm = parameterMap.getString(Configuration.PARAM_SPECIES_ML_FALLBACK_ALGORITHM);
        switch (fallbackAlgorithm) {
            case "apriori_all_syn_max": {
                Optional<Pair<String, Double>> docLevelApriori = speciesCooccurrenceService.getBestDocumentLevelAPrioriTaxId(document, filterTaxIds, parameterMap);
                if (docLevelApriori.isPresent()) {
                    return docLevelApriori.get().getLeft();
                }
                break;
            }
            case "apriori_synlevel": {
                Set<Pair<String, Double>> synLevelApriori = speciesCooccurrenceService.getBestAPrioriTaxIdsForBestGeneCandidates(gm, null, filterTaxIds, parameterMap);
                if (synLevelApriori.size() == 1) {
                    return synLevelApriori.iterator().next().getLeft();
                }
                // Zero or several results: let the tie breaker decide.
                String tieBroken = this.applyTaxTieBreaker(document, gm, speciesCooccurrenceService, documentScoringService, parameterMap, synLevelApriori);
                if (tieBroken != null) {
                    return tieBroken;
                }
                break;
            }
            case "apriori_doclevel": {
                String docLevelBest = this.getBestDocumentLevelTaxId(document, documentScoringService, parameterMap, filterTaxIds);
                if (docLevelBest != null) {
                    return docLevelBest;
                }
                break;
            }
            default:
                break;
        }
        log.trace("Could not obtain a taxonomy ID by falling back algorithm {} for gene mention {} with candidate taxonomy IDs {} in document {}", fallbackAlgorithm, gm, filterTaxIds, document.getId());
        return null;
    }

    /**
     * Chooses one taxonomy ID among {@code taxIdCandidates} with the configured tie breaker
     * algorithm ({@code Configuration.PARAM_SPECIES_ML_TIE_BREAKER_ALGORITHM}):
     * <ul>
     * <li>{@code apriori_doclevel}: tied ID with the best document level score;</li>
     * <li>{@code apriori_all_syn_max}: best document level a-priori ID among the tied IDs;</li>
     * <li>{@code apriori_synlevel}: tied ID with the highest a-priori score for the best gene
     * candidates; only scores greater than 0 can win and the result is returned as is, even
     * when it is {@code null}.</li>
     * </ul>
     *
     * @param document                  the document of the gene mention
     * @param gm                        the gene mention the tie belongs to
     * @param speciesCooccurrenceService source of a-priori taxonomy IDs
     * @param documentScoringService    source of document level taxonomy scores
     * @param parameterMap              the configuration
     * @param taxIdCandidates           the tied taxonomy IDs with their scores
     * @return the chosen taxonomy ID, or {@code null} when the tie could not be broken
     */
    private String applyTaxTieBreaker(GeneDocument document, GeneMention gm, SynonymSpeciesCooccurrenceService speciesCooccurrenceService, SpeciesDocumentScoringService documentScoringService, Parameters parameterMap, Set<Pair<String, Double>> taxIdCandidates) {
        String tieBreakerAlgorithm = parameterMap.getString(Configuration.PARAM_SPECIES_ML_TIE_BREAKER_ALGORITHM);
        Set<String> tiedTaxIds = taxIdCandidates.stream().map(Pair::getLeft).collect(Collectors.toSet());
        switch (tieBreakerAlgorithm) {
            case "apriori_doclevel": {
                String docLevelWinner = this.getBestDocumentLevelTaxId(document, documentScoringService, parameterMap, tiedTaxIds);
                if (docLevelWinner != null) {
                    return docLevelWinner;
                }
                break;
            }
            case "apriori_all_syn_max": {
                Optional<Pair<String, Double>> docLevelApriori = speciesCooccurrenceService.getBestDocumentLevelAPrioriTaxId(document, tiedTaxIds, parameterMap);
                if (docLevelApriori.isPresent()) {
                    return docLevelApriori.get().getLeft();
                }
                break;
            }
            case "apriori_synlevel": {
                Set<Pair<String, Double>> synLevelApriori = speciesCooccurrenceService.getBestAPrioriTaxIdsForBestGeneCandidates(gm, null, tiedTaxIds, parameterMap);
                String winner = null;
                double winnerScore = 0.0;
                for (Pair<String, Double> scoredTax : synLevelApriori) {
                    if (scoredTax.getRight() > winnerScore) {
                        winnerScore = scoredTax.getRight();
                        winner = scoredTax.getLeft();
                    }
                }
                // Returned directly: this branch does not report an unbroken tie in the log.
                return winner;
            }
            default:
                break;
        }
        log.debug("Could not break a taxonomy ID tie with algorithm {} for gene mention {} between the taxonomy IDs {} in document {}", tieBreakerAlgorithm, gm.getText(), tiedTaxIds, document.getId());
        return null;
    }

    /**
     * Returns the taxonomy ID from {@code tiedTaxIds} with the highest document level score
     * as computed by {@code documentScoringService}.
     * <p>
     * Only scores strictly greater than 0 can win. On equal scores the entry encountered
     * first in the iteration order of the score map is kept.
     *
     * @param document               the document to compute the taxonomy scores for
     * @param documentScoringService computes the document level taxonomy scores
     * @param parameterMap           the configuration, passed to the scoring service
     * @param tiedTaxIds             the taxonomy IDs to choose from
     * @return the best taxonomy ID, or {@code null} when none of the given IDs has a positive score
     */
    @Nullable
    private String getBestDocumentLevelTaxId(GeneDocument document, SpeciesDocumentScoringService documentScoringService, Parameters parameterMap, Set<String> tiedTaxIds) {
        Map<String, Double> docLevelTaxScores = documentScoringService.computeTaxDocScores(document, parameterMap);
        String bestTiedTaxId = null;
        double bestScore = 0.0;
        for (Map.Entry<String, Double> taxScore : docLevelTaxScores.entrySet()) {
            if (!tiedTaxIds.contains(taxScore.getKey())) {
                continue;
            }
            double score = taxScore.getValue();
            if (score > bestScore) {
                bestScore = score;
                bestTiedTaxId = taxScore.getKey();
            }
        }
        return bestTiedTaxId;
    }

    /**
     * Sets the ranker used for prediction. Only the ranker field is changed; the ranker is
     * consulted by {@code assignWithTrainedModel} only when no classifier is set.
     *
     * @param ranker the ranker to use from now on
     */
    public void setRanker(RankLibRanker ranker) {
        this.ranker = ranker;
    }
}

