/*
 * Decompiled with CFR 0.152.
 */
package de.julielab.genemapper.hpo;

import com.google.inject.Guice;
import com.google.inject.Injector;
import com.google.inject.Module;
import de.julielab.geneexpbase.configuration.Parameters;
import de.julielab.geneexpbase.genemodel.GeneDocument;
import de.julielab.geneexpbase.genemodel.GeneMention;
import de.julielab.geneexpbase.genemodel.GeneSet;
import de.julielab.geneexpbase.hpo.HpoCorpusRegistry;
import de.julielab.geneexpbase.hpo.HpoInstance;
import de.julielab.geneexpbase.hpo.HpoRoute;
import de.julielab.geneexpbase.hpo.InspectionFilePrinter;
import de.julielab.genemapper.Configuration;
import de.julielab.genemapper.GeneMapper;
import de.julielab.genemapper.evaluation.GeneIdCorrectnessRenderer;
import de.julielab.genemapper.ioc.GeneMappingModule;
import de.julielab.genemapper.mappingcores.DypsisMappingCore;
import de.julielab.genemapper.utils.GeneMapperException;
import de.julielab.genemapper.utils.GeneMapperInitializationException;
import de.julielab.genemapper.utils.GeneMapperRuntimeException;
import de.julielab.java.utilities.Color;
import de.julielab.java.utilities.FileUtilities;
import java.io.BufferedWriter;
import java.io.File;
import java.io.IOException;
import java.nio.file.Path;
import java.text.DecimalFormat;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Properties;
import java.util.Set;
import java.util.function.BiFunction;
import java.util.function.Function;
import java.util.function.Predicate;
import java.util.stream.Collectors;
import java.util.stream.DoubleStream;
import java.util.stream.Stream;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Hyperparameter-optimization route that scores a gene mapper configuration by the
 * quality of the gene sets it produces.
 *
 * <p>For each document of an instance the gene mapper is run on a copy of the document
 * and a per-document score is derived from three counts (see
 * {@link #calculateGenesetScore(int, int, int)}). The instance score is the mean of all
 * per-document scores that are not NaN. {@link #calculateScore} returns that mean
 * negated, as a string (presumably because the calling optimizer minimizes its
 * objective - confirm against {@code HpoRoute}).
 *
 * <p>NOTE(review): most non-private methods here look like implementations of
 * {@code HpoRoute} hooks; {@code @Override} was not added because the superclass is not
 * visible from this file.
 */
public class GenesetsOptimizationRoute extends HpoRoute {
    private static final Logger log = LoggerFactory.getLogger(GenesetsOptimizationRoute.class);
    /** Endpoint name returned by {@link #getRouteEndpoint()}. */
    public static final String GET_GENESET_SCORE = "get_geneset_score";
    /** Whether {@link #evaluateParameters} writes a human-readable inspection file. */
    private static final boolean PRINT_INSPECTION_FILES = true;
    /** Second argument passed to {@code parseInstanceName} in {@link #main(String[])}. */
    private static final String INSPECTION_FILE_SUFFIX = "";
    /** When non-empty, only documents whose ID is contained here are evaluated. */
    private static final Set<String> documentsFilter = Set.of();

    /**
     * Creates the route for the given gene mapper configuration.
     *
     * @param configuration the configuration handed to the {@code HpoRoute} constructor
     */
    public GenesetsOptimizationRoute(Configuration configuration) {
        super(log, (de.julielab.geneexpbase.configuration.Configuration)configuration);
    }

    /**
     * Stand-alone driver: scores every {@code genesets-merged-testsplit-<i>} instance
     * with the parameters from the hard-coded properties file, prints each score and
     * finally their average.
     *
     * @param args ignored
     * @throws IOException declared for the configuration loading
     * @throws GeneMapperInitializationException declared for the gene mapper setup
     */
    public static void main(String[] args) throws IOException, GeneMapperInitializationException {
        Configuration configuration = new Configuration(new File("smac/gene_mapper_configurations/genemapper_for_geneset_opt.properties"));
        GenesetsOptimizationRoute route = new GenesetsOptimizationRoute(configuration);
        DoubleStream.Builder scores = DoubleStream.builder();
        // Iterate over exactly the splits this route declares instead of repeating the literal.
        int numSplits = route.getNumSplits();
        for (int split = 0; split < numSplits; ++split) {
            HpoInstance instance = route.parseInstanceName("genesets-merged-testsplit-" + split, INSPECTION_FILE_SUFFIX);
            // A fresh Parameters object per split, so no state can leak between evaluations.
            Parameters parameters = new Parameters((Properties)((Object)configuration));
            String score = route.calculateScore(instance, parameters, 1, 0, Integer.MAX_VALUE, Integer.MAX_VALUE, HpoRoute.Metric.NDCG, 1);
            System.out.println(score);
            scores.accept(Double.parseDouble(score));
        }
        System.out.println("Average: " + scores.build().average().orElse(Double.NaN));
    }

    /**
     * Builds the Guice injector from a single {@link GeneMappingModule}.
     *
     * @param configuration the configuration passed to the module
     * @return the new injector
     */
    protected Injector createGuiceInjector(de.julielab.geneexpbase.configuration.Configuration configuration) {
        return Guice.createInjector(new GeneMappingModule(configuration));
    }

    /** @return {@link #GET_GENESET_SCORE} */
    public String getRouteEndpoint() {
        return GET_GENESET_SCORE;
    }

    /** @return the number of splits, always 10 */
    public int getNumSplits() {
        return 10;
    }

    /** @return always 0 */
    public int getDevSamplingFrequency() {
        return 0;
    }

    /** @return the BC2 train and NLM-IAT corpus instances from {@link HpoCorpusRegistry} */
    protected List<HpoInstance> getActiveCorpora() {
        return List.of(HpoCorpusRegistry.gnpBc2Train(), HpoCorpusRegistry.gnpNlmiat());
    }

    /** @return the task name, {@code "genesets"} */
    protected String getTaskName() {
        return "genesets";
    }

    /** @return {@link HpoRoute.Metric#NDCG} */
    protected HpoRoute.Metric getDefaultMetric() {
        return HpoRoute.Metric.NDCG;
    }

    /**
     * Evaluates the given parameters on one instance and returns the negated score.
     *
     * <p>Candidate-setting omission is switched on at the mapping core for the duration
     * of the evaluation and switched off again afterwards, also on failure.
     *
     * @param si the instance to evaluate
     * @param parameterMap the parameters handed to the gene mapper
     * @param seed unused here
     * @param cutoffTime unused here
     * @param resourceBudget unused here
     * @param maxResourceBudget unused here
     * @param returnMetric only used for logging; the gene set score is always computed
     * @param runId unused here
     * @return the instance score multiplied by -1, rendered via {@link String#valueOf(double)}
     * @throws GeneMapperRuntimeException wrapping any exception raised during evaluation
     */
    protected String calculateScore(HpoInstance si, Parameters parameterMap, int seed, int cutoffTime, int resourceBudget, int maxResourceBudget, HpoRoute.Metric returnMetric, int runId) {
        GeneMapper geneMapper = (GeneMapper)this.injector.getInstance(GeneMapper.class);
        DypsisMappingCore mappingCore = (DypsisMappingCore)geneMapper.getMappingCore();
        mappingCore.setOmitCandidateSetting(true);
        try {
            double result = this.evaluate(si, parameterMap, geneMapper, returnMetric);
            return String.valueOf(-result);
        }
        catch (Exception e) {
            log.error("Exception in score calculation", e);
            throw new GeneMapperRuntimeException(e);
        }
        finally {
            mappingCore.setOmitCandidateSetting(false);
        }
    }

    /**
     * Fetches the documents of the instance and scores them, logging before and after.
     *
     * @return the (non-negated) instance score
     */
    private double evaluate(HpoInstance si, Parameters parameterMap, GeneMapper geneMapper, HpoRoute.Metric returnMetric) throws GeneMapperException {
        List<GeneDocument> testPartition = this.getDocuments4Instance(si.getCorpus(), si.getSubcorpus(), si.getSplitType(), si.isMergeCorpora(), si.getCrossvalRound());
        log.info("Evaluating instance {}", si);
        double result = this.evaluateParameters(si, parameterMap, geneMapper, testPartition, returnMetric);
        log.info("Got {} {} for instance {}.", returnMetric, result, si);
        return result;
    }

    /**
     * Runs the gene mapper on a copy of every (filtered) document, computes the
     * per-document gene set score and returns the mean of all non-NaN scores.
     *
     * @param si the instance being evaluated; used for the inspection file and messages
     * @param parameterMap the parameters handed to the gene mapper
     * @param geneMapper the mapper to run
     * @param testPartition the documents to evaluate; they are copied, not modified here
     * @param returnMetric unused here
     * @return the mean of the per-document scores that are not NaN
     * @throws IllegalStateException if no document yields a non-NaN score
     * @throws GeneMapperException if the mapping fails
     */
    private double evaluateParameters(HpoInstance si, Parameters parameterMap, GeneMapper geneMapper, List<GeneDocument> testPartition, HpoRoute.Metric returnMetric) throws GeneMapperException {
        List<GeneDocument> copies = new ArrayList<>(testPartition.size());
        Map<String, Double> documentScores = new HashMap<>();
        for (GeneDocument testDoc : testPartition) {
            if (!documentsFilter.isEmpty() && !documentsFilter.contains(testDoc.getId())) {
                continue;
            }
            GeneDocument copy = new GeneDocument(testDoc);
            copies.add(copy);
            geneMapper.map(copy, parameterMap);
            int numDifferentGoldIds = (int)copy.getGenes().flatMap(GeneMention::getAllGoldIds).distinct().count();
            // Only gene sets containing at least one mention with gold mentions are counted.
            int numGenesets = (int)copy.getGeneSets().stream().filter(geneSet -> geneSet.stream().anyMatch(GeneMention::hasGoldMentions)).count();
            // Despite the parameter name downstream, this is the total number of
            // contradicting gold IDs summed over all gene sets.
            int numContradictingIds = 0;
            for (GeneSet geneSet : copy.getGeneSets()) {
                numContradictingIds += geneSet.getContradictingGoldGeneIds().size();
            }
            documentScores.put(copy.getId(), this.calculateGenesetScore(numDifferentGoldIds, numGenesets, numContradictingIds));
        }
        // NaN scores (documents with all-zero counts) are excluded from the mean.
        double score = documentScores.values().stream()
                .mapToDouble(Double::doubleValue)
                .filter(documentScore -> !Double.isNaN(documentScore))
                .average()
                .orElseThrow(() -> new IllegalStateException("No document of instance " + si + " yielded a defined gene set score; cannot compute an average."));
        if (PRINT_INSPECTION_FILES) {
            this.printInspectionFile("genesets", score, documentScores, si, copies);
        }
        return score;
    }

    /**
     * Computes {@code gold / (gold + |genesets - gold| + contradictions)}.
     *
     * <p>The result is 1 when the number of gene sets equals the number of distinct
     * gold IDs and there are no contradictions, and it is NaN when all three arguments
     * are 0 (floating-point 0/0).
     *
     * @param numDifferentGoldIds number of distinct gold IDs in the document
     * @param numGenesets number of gene sets that contain a mention with gold mentions
     * @param numGenesetsWithContradictingIds penalty term added to the denominator
     * @return the score, or NaN if all arguments are 0
     */
    private double calculateGenesetScore(int numDifferentGoldIds, int numGenesets, int numGenesetsWithContradictingIds) {
        int denominator = numDifferentGoldIds + Math.abs(numGenesets - numDifferentGoldIds) + numGenesetsWithContradictingIds;
        return (double)numDifferentGoldIds / (double)denominator;
    }

    /**
     * Writes a text file with the overall score and, per document, its score, its gold
     * IDs and the gene set inspection text.
     *
     * <p>The file is written to
     * {@code smac/inspectionfiles-<idType>/<hpoInstance>-<score>.txt}. Failures to create
     * the directory or to write the file are logged and do not propagate.
     *
     * @param idType suffix of the output directory name
     * @param fileMetricValue the overall score, used in the file name and header
     * @param documentScores per-document scores keyed by document ID
     * @param hpoInstance the evaluated instance
     * @param documents the documents to describe, in output order
     */
    protected void printInspectionFile(String idType, double fileMetricValue, Map<String, Double> documentScores, HpoInstance hpoInstance, List<GeneDocument> documents) {
        DecimalFormat df = new DecimalFormat("0.##");
        File output = Path.of("smac", "inspectionfiles-" + idType, hpoInstance.toString() + "-" + df.format(fileMetricValue) + ".txt").toFile();
        File directory = output.getParentFile();
        // mkdirs() returns false both on failure and when another process created the
        // directory concurrently, hence the final isDirectory() check.
        if (!directory.exists() && !directory.mkdirs() && !directory.isDirectory()) {
            log.error("Could not create inspection file directory {}", directory);
            return;
        }
        BiFunction<GeneMention, String, GeneDocument.MentionCorrectness> correctnessFunction = GeneMention::getGenesetCorrectnessLevel;
        Map<GeneDocument.MentionCorrectness, BiFunction<GeneMention, String, String>> renderMap = Map.of(GeneDocument.MentionCorrectness.CORRECT_ID, GeneIdCorrectnessRenderer::renderCorrectGeneSetMention, GeneDocument.MentionCorrectness.WRONG_ID, GeneIdCorrectnessRenderer::renderWrongGeneSetMention, GeneDocument.MentionCorrectness.CANT_FIND, GeneIdCorrectnessRenderer::renderNoCorrectMentionGeneSet);
        try (BufferedWriter bw = FileUtilities.getWriterToFile(output)) {
            bw.write("Instance: " + hpoInstance);
            bw.newLine();
            bw.write(String.format("Overall score : %s%.2f%s", Color.RED, fileMetricValue, Color.RESET));
            bw.newLine();
            bw.newLine();
            for (GeneDocument document : documents) {
                // Documents without a recorded score are reported as NaN rather than
                // failing on a null unboxing.
                double docScore = documentScores.getOrDefault(document.getId(), Double.NaN);
                Color col = InspectionFilePrinter.getHighlightColor(docScore);
                bw.write(String.format("%s doc score: %2$s%3$.2f, %4$s", document.getId(), col, docScore, Color.RESET));
                bw.newLine();
                Set<Integer> goldIds = collectGoldIds(document);
                bw.write("Gold tax IDs in doc: " + goldIds.stream().map(String::valueOf).collect(Collectors.joining(" ")));
                bw.newLine();
                bw.write(document.getGenesetInspectionText(correctnessFunction, renderMap));
                bw.newLine();
                bw.newLine();
            }
        }
        catch (IOException e) {
            log.error("Could not write inspection file {}", output, e);
        }
    }

    /**
     * Collects the gold IDs of a document as integers in ascending order.
     *
     * <p>If the document's gold annotation has offsets, the IDs are taken from its gold
     * gene mentions; otherwise from the document-level gold IDs.
     */
    private static Set<Integer> collectGoldIds(GeneDocument document) {
        Stream<String> ids = document.isGoldHasOffsets()
                ? document.getGoldGenes().values().stream().flatMap(Collection::stream).flatMap(gm -> gm.getIds().stream())
                : document.getGoldIds().stream();
        // LinkedHashSet keeps the sorted order while dropping duplicates.
        return ids.map(Integer::parseInt).sorted().collect(Collectors.toCollection(LinkedHashSet::new));
    }
}

