/*
 * Decompiled with CFR 0.152.
 */
package de.julielab.genemapper.hpo;

import cc.mallet.classify.Classifier;
import cc.mallet.types.Alphabet;
import cc.mallet.types.InstanceList;
import com.google.common.collect.Sets;
import com.google.inject.Guice;
import com.google.inject.Injector;
import de.julielab.evaluation.entities.EntityEvaluationResults;
import de.julielab.evaluation.entities.EntityEvaluator;
import de.julielab.evaluation.entities.EvaluationData;
import de.julielab.evaluation.entities.EvaluationDataEntry;
import de.julielab.geneexpbase.candidateretrieval.SynHit;
import de.julielab.geneexpbase.configuration.Parameters;
import de.julielab.geneexpbase.genemodel.DocumentMappingResult;
import de.julielab.geneexpbase.genemodel.GeneDocument;
import de.julielab.geneexpbase.genemodel.GeneMention;
import de.julielab.geneexpbase.genemodel.MentionMappingResult;
import de.julielab.geneexpbase.hpo.HpoCorpusRegistry;
import de.julielab.geneexpbase.hpo.HpoInstance;
import de.julielab.geneexpbase.hpo.HpoRoute;
import de.julielab.geneexpbase.hpo.SplitType;
import de.julielab.genemapper.Configuration;
import de.julielab.genemapper.GeneMapper;
import de.julielab.genemapper.disambig.DypsisContextRanker;
import de.julielab.genemapper.evaluation.tools.Stats;
import de.julielab.genemapper.ioc.GeneMappingModule;
import de.julielab.genemapper.utils.GeneMapperException;
import de.julielab.genemapper.utils.GeneMapperRuntimeException;
import de.julielab.java.utilities.Color;
import de.julielab.ml.RankLibRanker;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Properties;
import java.util.Set;
import java.util.function.BiFunction;
import java.util.function.Function;
import java.util.function.Predicate;
import java.util.stream.Collectors;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * HPO route that scores a gene-mapping parameter configuration on one corpus split.
 *
 * <p>For each {@link HpoInstance}, the contextual ranking (disambiguation) model is optionally
 * trained on the instance's training documents. The instance's documents are then mapped with
 * {@link GeneMapper} and the predicted gene ids are evaluated against the gold annotations with an
 * {@link EntityEvaluator} that compares entries by character overlap.
 */
public class GeneMappingOptimizationRoute extends HpoRoute {
    private static final Logger log = LoggerFactory.getLogger(GeneMappingOptimizationRoute.class);
    public static final String GET_GENE_MAPPING_SCORE = "get_gene_mapping_score";
    /** Parameter-key prefix of the disambiguation (contextual ranking) settings. */
    private static final String DISAMBIGUATION_PREFIX = "disambiguation";
    /** Value of {@code disambiguation.algorithm} for which this route trains no model. */
    private static final String LINEAR_COMBINATION = "linear_combination";
    /** Optional whitelist of document ids; when empty, every document is used. */
    private static final Set<String> documentsFilter = Set.of();
    private final transient EntityEvaluator entityEvaluator;

    /**
     * Creates the route and its entity evaluator.
     *
     * @param configuration the gene mapper configuration, passed on to {@link HpoRoute}
     */
    public GeneMappingOptimizationRoute(Configuration configuration) {
        super(log, configuration);
        // Compare gold and predicted entries by character overlap with an overlap size of 1.
        Properties properties = new Properties();
        properties.setProperty("comparison-type", EvaluationDataEntry.ComparisonType.OVERLAP.name());
        properties.setProperty("overlap-type", EvaluationDataEntry.OverlapType.CHARS.name());
        properties.setProperty("overlap-size", "1");
        this.entityEvaluator = new EntityEvaluator(properties);
    }

    /** Creates the Guice injector from a {@link GeneMappingModule}. */
    @Override
    protected Injector createGuiceInjector(de.julielab.geneexpbase.configuration.Configuration configuration) {
        return Guice.createInjector(new GeneMappingModule(configuration));
    }

    @Override
    public String getRouteEndpoint() {
        return GET_GENE_MAPPING_SCORE;
    }

    @Override
    public int getNumSplits() {
        return 10;
    }

    @Override
    public int getDevSamplingFrequency() {
        return 0;
    }

    @Override
    protected List<HpoInstance> getActiveCorpora() {
        return List.of(HpoCorpusRegistry.gnpBc2Train(), HpoCorpusRegistry.gnpNlmiat());
    }

    @Override
    protected String getTaskName() {
        return "semrank";
    }

    @Override
    protected HpoRoute.Metric getDefaultMetric() {
        return HpoRoute.Metric.F;
    }

    /**
     * Optionally trains the disambiguation model, then maps and evaluates the documents of {@code si}.
     *
     * <p>The micro-averaged, mention-wise metric selected by {@code returnMetric} is returned
     * negated (presumably because the optimizer minimizes its objective - confirm in
     * {@link HpoRoute}). {@code cutoffTime}, {@code resourceBudget}, {@code maxResourceBudget} and
     * {@code runId} are not used by this route. The contextual ranker's classifier and ranker are
     * cleared afterwards, also when an error occurs.
     *
     * @param si           the HPO instance (corpus split) to score
     * @param parameterMap the parameter configuration; {@code do_disambiguation} enables training
     * @param seed         random seed forwarded to model training
     * @param returnMetric which micro metric (recall, precision or F) to return
     * @return the negated selected metric as a string
     * @throws GeneMapperRuntimeException wrapping any failure, including an unsupported metric
     */
    @Override
    protected String calculateScore(HpoInstance si, Parameters parameterMap, int seed, int cutoffTime, int resourceBudget, int maxResourceBudget, HpoRoute.Metric returnMetric, int runId) {
        log.info("Calculating disambiguation score for {}", si);
        GeneMapper geneMapper = this.injector.getInstance(GeneMapper.class);
        try {
            if (parameterMap.getBoolean("do_disambiguation")) {
                this.trainDisambiguationModel(si, parameterMap, seed, geneMapper);
            }
            List<GeneDocument> testPartition = this.getDocuments4Instance(si);
            EntityEvaluationResults evaluationResults = this.evaluate(testPartition, geneMapper, parameterMap);
            double microRecall = evaluationResults.getMicroStatsMentionWise().getRecall();
            double microPrecision = evaluationResults.getMicroStatsMentionWise().getPrecision();
            double microF = evaluationResults.getMicroStatsMentionWise().getF();
            double returnedValue = selectMetric(returnMetric, microRecall, microPrecision, microF);
            log.info("Instance {}: Evaluated contextRanker on {} documents. Metrics for finding the correct gene Id {}[R/P/F]{}: {} / {} / {}. Returned metric is {} with a value of {}.", si, testPartition.size(), Color.CYAN, Color.RESET, microRecall, microPrecision, microF, returnMetric, returnedValue);
            return String.valueOf(returnedValue * -1.0);
        } catch (Exception e) {
            log.error("Score calculation failed for instance {}. Parameters were:\n{}", si, parameterMap);
            throw new GeneMapperRuntimeException(e);
        } finally {
            // Always drop the trained model, also when training or evaluation failed.
            resetContextualRanking(geneMapper);
        }
    }

    /**
     * Returns the value of the requested metric.
     *
     * @throws IllegalArgumentException if {@code metric} is not recall, precision or F
     */
    private static double selectMetric(HpoRoute.Metric metric, double recall, double precision, double f) {
        switch (metric) {
            case RECALL:
                return recall;
            case PRECISION:
                return precision;
            case F:
                return f;
            default:
                throw new IllegalArgumentException("Unsupported result metric: " + metric);
        }
    }

    /**
     * Removes any classifier or ranker from the mapper's contextual ranking. This keeps a model
     * trained for one instance from being used by later evaluations, should the injector hand out
     * a shared {@link GeneMapper}.
     */
    private static void resetContextualRanking(GeneMapper geneMapper) {
        Object contextualRanking = geneMapper.getMappingCore().getContextualRanking();
        if (contextualRanking instanceof DypsisContextRanker) {
            DypsisContextRanker contextRanker = (DypsisContextRanker) contextualRanking;
            contextRanker.setClassifier(null);
            contextRanker.setRanker(null);
        }
    }

    /** Returns whether {@code doc} passes {@link #documentsFilter} (always true if the filter is empty). */
    private static boolean passesDocumentsFilter(GeneDocument doc) {
        return documentsFilter.isEmpty() || documentsFilter.contains(doc.getId());
    }

    /**
     * Trains the contextual ranking model and installs it in the mapper's {@link DypsisContextRanker}.
     *
     * <p>The training documents are mapped with {@code disambiguation.train_mode = true} to collect
     * the contextual ranking training instances. Unless {@code disambiguation.algorithm} is
     * {@code linear_combination}, a model is trained from them and installed. If
     * {@code disambiguation.scale_result_score} is set, min/max scores for result scaling are
     * determined afterwards. {@code train_mode} is reset to {@code false} even if this fails.
     *
     * @throws IllegalStateException if no training instances were produced, or if training stored
     *                               neither a classifier nor a ranker in the parameters
     */
    private void trainDisambiguationModel(HpoInstance si, Parameters parameterMap, int seed, GeneMapper geneMapper) throws GeneMapperException, IllegalAccessException {
        String prefix = DISAMBIGUATION_PREFIX;
        boolean useAllActiveCorporaForTraining = parameterMap.getBoolean(Configuration.dot(prefix, "use_all_corpus_trainsplits"), false);
        // NOTE(review): this flag is only logged; the partition below is chosen by the split type
        // alone. Confirm whether it should also select getAllCorporaTrainingDocuments4Instance.
        log.debug("Using all active corpora for training: {}", useAllActiveCorporaForTraining);
        List<GeneDocument> partition = si.getSplitType() == SplitType.DEV
                ? this.getAllCorporaTrainingDocuments4Instance(si)
                : this.getDocuments4Instance(si.getCorpus(), si.getSubcorpus(), SplitType.TRAINSPLIT, si.isMergeCorpora(), si.getCrossvalRound());
        // Null-safe: a missing algorithm setting means a model is trained, as for any other value.
        boolean trainModel = !LINEAR_COMBINATION.equals(parameterMap.get(Configuration.dot(prefix, "algorithm")));
        String trainModeKey = Configuration.dot(prefix, "train_mode");
        parameterMap.put(trainModeKey, true);
        try {
            InstanceList trainingInstances = null;
            for (GeneDocument trainDoc : partition) {
                if (!passesDocumentsFilter(trainDoc)) {
                    continue;
                }
                // Map a copy (presumably so the partition's documents stay unmodified).
                DocumentMappingResult trainResult = geneMapper.map(new GeneDocument(trainDoc), parameterMap, new Stats());
                if (!trainModel) {
                    continue;
                }
                if (trainingInstances == null) {
                    trainingInstances = trainResult.contextualRankingTrainingInstances;
                } else {
                    trainingInstances.addAll(trainResult.contextualRankingTrainingInstances);
                }
            }
            if (trainModel) {
                installTrainedModel(si, trainingInstances, parameterMap, prefix, seed, geneMapper);
            }
            if (parameterMap.getBoolean(Configuration.dot(prefix, "scale_result_score"))) {
                BiFunction<MentionMappingResult, Boolean, Map<String, List<SynHit>>> candidateListFunction = (mmr, dorank) -> mmr.tax2semanticallyOrderedCandidates;
                Function<SynHit, Double> synHitScoreFunction = SynHit::getContextualScore;
                this.setResultScalingParameters(si, parameterMap, geneMapper, prefix, partition, candidateListFunction, synHitScoreFunction);
            }
        } finally {
            parameterMap.put(trainModeKey, false);
        }
    }

    /**
     * Trains a model on {@code trainingInstances} and sets the resulting classifier or ranker on the
     * mapper's {@link DypsisContextRanker}. The data alphabet stored under
     * {@code <prefix>.data_alphabet} (if any) is then frozen.
     *
     * @throws IllegalStateException if {@code trainingInstances} is {@code null}, or if neither
     *                               {@code <prefix>.classifier} nor {@code <prefix>.ranker} is
     *                               present after training
     */
    private void installTrainedModel(HpoInstance si, InstanceList trainingInstances, Parameters parameterMap, String prefix, int seed, GeneMapper geneMapper) throws GeneMapperException, IllegalAccessException {
        if (trainingInstances == null) {
            throw new IllegalStateException("No contextual ranking training instances were produced for " + si + "; check the training partition and the documents filter.");
        }
        this.train(trainingInstances, parameterMap, prefix, seed);
        Alphabet dataAlphabet = (Alphabet) parameterMap.get(Configuration.dot(prefix, "data_alphabet"));
        DypsisContextRanker disambiguation = (DypsisContextRanker) geneMapper.getMappingCore().getContextualRanking();
        String classifierKey = Configuration.dot(prefix, "classifier");
        String rankerKey = Configuration.dot(prefix, "ranker");
        if (parameterMap.containsKey(classifierKey)) {
            disambiguation.setClassifier((Classifier) parameterMap.get(classifierKey));
        } else if (parameterMap.containsKey(rankerKey)) {
            disambiguation.setRanker((RankLibRanker) parameterMap.get(rankerKey));
        } else {
            throw new IllegalStateException("Could not find a trained classifier or ranker to set to the disambiguation object.");
        }
        if (dataAlphabet != null) {
            dataAlphabet.stopGrowth();
        }
    }

    /**
     * Maps the training partition once more with scaling switched off and records min/max scores.
     *
     * <p>It records the minimum and maximum candidate scores (via {@code synHitScoreFunction}) and
     * family-name match scores over all non-rejected mentions. Depending on the scaling switches,
     * these are stored as {@code <prefix>.min/max_result_score} and
     * {@code <prefix>.min/max_family_match_score}. The scaling switches are restored to their
     * previous values even if mapping fails.
     */
    private void setResultScalingParameters(HpoInstance si, Parameters parameterMap, GeneMapper geneMapper, String prefix, List<GeneDocument> partition, BiFunction<MentionMappingResult, Boolean, Map<String, List<SynHit>>> candidateListFunction, Function<SynHit, Double> synHitScoreFunction) throws GeneMapperException {
        log.info("Applying the model trained on {} to the same data to obtain min and max mention score values for min-max scaling.", si);
        String resultScalingKey = Configuration.dot(prefix, "scale_result_score");
        // The family-name scaling switch lives under candidate_retrieval; it must be disabled and
        // restored under the same key it is read from.
        String familyScalingKey = Configuration.dot("candidate_retrieval", "scale_family_name_match_score");
        boolean scaleResultsValues = parameterMap.getBoolean(resultScalingKey);
        boolean scaleFamilyMatchValues = parameterMap.getBoolean(familyScalingKey, false);
        Stats stats = new Stats();
        double minScore = Double.MAX_VALUE;
        double maxScore = -Double.MAX_VALUE;
        double minFam = Double.MAX_VALUE;
        double maxFam = -Double.MAX_VALUE;
        // Collect unscaled values.
        parameterMap.put(resultScalingKey, false);
        parameterMap.put(familyScalingKey, false);
        try {
            for (GeneDocument trainDoc : partition) {
                if (!passesDocumentsFilter(trainDoc)) {
                    continue;
                }
                DocumentMappingResult result = geneMapper.map(new GeneDocument(trainDoc), parameterMap, stats);
                for (MentionMappingResult mmr : result.mentionResults) {
                    if (mmr.mappedMention.isRejected()) {
                        continue;
                    }
                    Map<String, List<SynHit>> candidatesMap = candidateListFunction.apply(mmr, true);
                    for (List<SynHit> candidates : candidatesMap.values()) {
                        for (SynHit candidate : candidates) {
                            double score = synHitScoreFunction.apply(candidate);
                            if (score < minScore) {
                                minScore = score;
                            }
                            if (score > maxScore) {
                                maxScore = score;
                            }
                        }
                    }
                    if (mmr.mappedMention.matchesFamilyName()) {
                        double familyNameMatchScore = mmr.mappedMention.getFamilyNameMatchScore();
                        if (familyNameMatchScore < minFam) {
                            minFam = familyNameMatchScore;
                        }
                        if (familyNameMatchScore > maxFam) {
                            maxFam = familyNameMatchScore;
                        }
                    }
                }
            }
        } finally {
            parameterMap.put(resultScalingKey, scaleResultsValues);
            parameterMap.put(familyScalingKey, scaleFamilyMatchValues);
        }
        if (scaleResultsValues) {
            if (minScore > maxScore) {
                log.warn("No candidate scores were observed for {}; the stored result-score scaling bounds are meaningless.", si);
            }
            parameterMap.put(Configuration.dot(prefix, "min_result_score"), minScore);
            parameterMap.put(Configuration.dot(prefix, "max_result_score"), maxScore);
        }
        if (scaleFamilyMatchValues) {
            if (minFam > maxFam) {
                log.warn("No family-name matches were observed for {}; the stored family-match scaling bounds are meaningless.", si);
            }
            // NOTE(review): the bounds are stored under the disambiguation prefix while the switch
            // lives under candidate_retrieval; confirm which keys the consumer reads.
            parameterMap.put(Configuration.dot(prefix, "min_family_match_score"), minFam);
            parameterMap.put(Configuration.dot(prefix, "max_family_match_score"), maxFam);
        }
    }

    /**
     * Maps copies of the test documents and evaluates the predicted gene ids against the gold data.
     *
     * <p>Gold entries include offsets when the document has gold offsets; otherwise they are
     * document-level ids only. Predictions are the result candidates of all non-rejected mentions.
     * Diagnostic counts are logged at debug level: rejected mentions, gold mentions overlapped by a
     * prediction, and gold mentions whose id appears among the original candidates.
     *
     * @return the entity evaluation results
     * @throws NullPointerException if any gold or predicted entry lacks an entity id
     */
    private EntityEvaluationResults evaluate(List<GeneDocument> testPartition, GeneMapper geneMapper, Parameters parameterMap) throws GeneMapperException {
        EvaluationData goldData = new EvaluationData();
        EvaluationData predData = new EvaluationData();
        long numRejected = 0;
        int numPredGmWithGoldId = 0;
        int numPredGmWithGoldMention = 0;
        for (GeneDocument doc : testPartition) {
            if (!passesDocumentsFilter(doc)) {
                continue;
            }
            GeneDocument copy = new GeneDocument(doc);
            geneMapper.map(copy, parameterMap, new Stats());
            if (copy.isGoldHasOffsets()) {
                copy.getGoldGenes().values().stream().flatMap(Collection::stream).flatMap(gm -> gm.getIds().stream().map(id -> {
                    EvaluationDataEntry goldEntry = new EvaluationDataEntry(gm.getDocId(), (String) id, gm.getBegin(), gm.getEnd(), gm.getText(), gm.getTagger().name(), "Gene");
                    goldEntry.setReferenceObject(gm);
                    return goldEntry;
                })).forEach(goldData::add);
            } else {
                copy.getGoldIds().stream().map(id -> new EvaluationDataEntry(copy.getId(), (String) id)).forEach(goldData::add);
            }
            copy.getGenes().filter(Predicate.not(GeneMention::isRejected)).flatMap(gm -> gm.getMentionMappingResult().getResultCandidates().map(sh -> {
                EvaluationDataEntry e = new EvaluationDataEntry(gm.getDocId(), sh.getId(), gm.getBegin(), gm.getEnd(), gm.getText(), gm.getTagger().name(), "Gene");
                e.setReferenceObject(gm);
                return e;
            })).forEach(predData::add);
            numRejected += copy.getGenes().filter(GeneMention::isRejected).count();
            List<GeneMention> goldGenes = copy.getGoldGenes().values().stream().flatMap(Collection::stream).collect(Collectors.toList());
            for (GeneMention goldGene : goldGenes) {
                List<GeneMention> predGenes = copy.getOverlappingGenes(goldGene.getOffsets()).collect(Collectors.toList());
                if (predGenes.isEmpty()) {
                    continue;
                }
                ++numPredGmWithGoldMention;
                Set<String> candidateIds = predGenes.stream()
                        .filter(Predicate.not(GeneMention::isRejected))
                        .flatMap(gm -> gm.getMentionMappingResult().tax2originalCandidates.values().stream().flatMap(Collection::stream))
                        .map(SynHit::getId)
                        .collect(Collectors.toSet());
                if (!Sets.intersection(candidateIds, new HashSet<String>(goldGene.getIds())).isEmpty()) {
                    ++numPredGmWithGoldId;
                }
            }
        }
        goldData.forEach(e -> Objects.requireNonNull(e.getEntityId(), "gold evaluation entry without entity id"));
        predData.forEach(e -> Objects.requireNonNull(e.getEntityId(), "predicted evaluation entry without entity id"));
        log.debug("{} predicted mentions were rejected; {} gold mentions are overlapped by a predicted mention, {} of them have the gold id among the candidates.", numRejected, numPredGmWithGoldMention, numPredGmWithGoldId);
        log.debug("Evaluating on {} gold and {} test instances.", goldData.size(), predData.size());
        return this.entityEvaluator.evaluate(goldData, predData);
    }
}

