/*
 * Decompiled with CFR 0.152.
 */
package de.julielab.speciesassignment.cooccurrence;

import com.google.common.collect.HashMultimap;
import com.google.common.collect.HashMultiset;
import com.google.common.collect.Multimap;
import com.google.common.collect.Multiset;
import de.julielab.geneexpbase.configuration.Parameters;
import de.julielab.geneexpbase.genemodel.Acronym;
import de.julielab.geneexpbase.genemodel.AcronymLongform;
import de.julielab.geneexpbase.genemodel.GeneDocument;
import de.julielab.geneexpbase.genemodel.GeneMention;
import de.julielab.geneexpbase.genemodel.GeneSet;
import de.julielab.geneexpbase.genemodel.GeneSpeciesOccurrence;
import de.julielab.speciesassignment.Configuration;
import de.julielab.speciesassignment.spi.SynonymSpeciesCooccurrenceService;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.HashSet;
import java.util.IdentityHashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.OptionalInt;
import java.util.Set;
import java.util.function.BiFunction;
import java.util.function.Function;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import javax.inject.Inject;
import org.apache.commons.lang3.tuple.ImmutablePair;
import org.apache.commons.lang3.tuple.Pair;
import org.jetbrains.annotations.Nullable;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Decides on taxonomy IDs for the gene mentions of a document on the level of gene sets.
 *
 * <p>For every gene set with more than one mention, {@link #infer(GeneDocument, Parameters)} tries to find a
 * single taxonomy ID for the whole set and delegates it to the set's mentions. When no such ID can be found,
 * each mention receives its own best a priori candidate instead.</p>
 */
public class SingularToSetInference {
    private static final Logger log = LoggerFactory.getLogger(SingularToSetInference.class);

    /** Number of tie breaker slots; only ranks in {@code 1..NUM_TIE_BREAKERS} are honoured. */
    private static final int NUM_TIE_BREAKERS = 4;

    private final SynonymSpeciesCooccurrenceService cooccurrenceService;

    /**
     * @param cooccurrenceService service used to look up a priori synonym/taxonomy scores
     */
    @Inject
    public SingularToSetInference(SynonymSpeciesCooccurrenceService cooccurrenceService) {
        this.cooccurrenceService = cooccurrenceService;
    }

    /**
     * Rebuilds the gene sets of {@code document} and assigns taxonomy IDs to the mentions of every gene set
     * that has more than one element.
     *
     * <p>If the mentions of a set carry more than one distinct taxonomy ID, the strategies enabled in
     * {@code parameterMap} are tried in a fixed order until one yields an ID. If the mentions agree on exactly
     * one ID, that ID is used. When an ID was found it is delegated via
     * {@link #delegateGsTaxId(GeneSet, String, List)}; otherwise each mention gets its own best a priori
     * candidate. Gene sets split off during delegation are added to the document after the loop.</p>
     *
     * @param document     the document whose gene sets are reset, re-agglomerated and processed
     * @param parameterMap configuration switches and tie breaker ranks
     */
    public void infer(GeneDocument document, Parameters parameterMap) {
        document.resetGeneSets();
        document.agglomerateByAcronyms();
        document.agglomerateByNames(false);
        List<BiFunction<GeneSet, Set<String>, Optional<String>>> tieBreakers = this.initTieBreakers(parameterMap);
        List<Collection<GeneMention>> newGeneSets = new ArrayList<>();
        for (GeneSet geneSet : document.getGeneSets()) {
            if (geneSet.size() <= 1) {
                continue;
            }
            Set<String> taxIdsInGs = geneSet.stream()
                    .map(GeneMention::getTaxonomyIds)
                    .flatMap(Collection::stream)
                    .collect(Collectors.toSet());
            String gsTaxId = null;
            if (taxIdsInGs.size() > 1) {
                gsTaxId = this.resolveAmbiguousTaxId(geneSet, taxIdsInGs, tieBreakers, parameterMap);
            } else if (!taxIdsInGs.isEmpty()) {
                // Exactly one taxonomy ID in the whole set: nothing to disambiguate.
                gsTaxId = taxIdsInGs.iterator().next();
            }
            if (gsTaxId != null) {
                this.delegateGsTaxId(geneSet, gsTaxId, newGeneSets);
            } else {
                this.assignHighestScoringAPrioriCandidates(geneSet, parameterMap);
            }
            this.makeAbbreviationsConsistent(geneSet, document);
        }
        // Added only after the loop so the collection of gene sets is not changed while it is iterated.
        for (Collection<GeneMention> extractedGeneSet : newGeneSets) {
            document.addGeneSet(extractedGeneSet);
        }
    }

    /**
     * Runs the enabled inference strategies in their fixed order and returns the first taxonomy ID found.
     * Order: long form, compound occurrence, preceding sentence, candidate frequency, succeeding sentence,
     * highest gene set score.
     *
     * @return the inferred taxonomy ID or {@code null} if no enabled strategy produced one
     */
    @Nullable
    private String resolveAmbiguousTaxId(GeneSet geneSet, Set<String> taxIdsInGs, List<BiFunction<GeneSet, Set<String>, Optional<String>>> tieBreakers, Parameters parameterMap) {
        if (parameterMap.getBoolean(Configuration.PARAM_SYNONYM_APRIORI_INFERENCE_FROM_LONGFORM)) {
            Optional<String> longformTaxId = this.getLongformTaxId(geneSet);
            if (longformTaxId.isPresent()) {
                return longformTaxId.get();
            }
        }
        String gsTaxId = null;
        if (parameterMap.getBoolean(Configuration.PARAM_SYNONYM_APRIORI_INFERENCE_FROM_COMPOUND)) {
            gsTaxId = this.inferFromOccurrenceType(geneSet, GeneSpeciesOccurrence.COMPOUND, taxIdsInGs, tieBreakers);
        }
        if (gsTaxId == null && parameterMap.getBoolean(Configuration.PARAM_SYNONYM_APRIORI_INFERENCE_FROM_SENTENCE_PROCEED)) {
            gsTaxId = this.inferFromOccurrenceType(geneSet, GeneSpeciesOccurrence.SENTENCE_PRECED, taxIdsInGs, tieBreakers);
        }
        if (gsTaxId == null && parameterMap.getBoolean(Configuration.PARAM_SYNONYM_APRIORI_INFERENCE_FROM_CANDIDATE_FREQUENCY)) {
            gsTaxId = this.inferFromSingularAssignments(geneSet, taxIdsInGs, tieBreakers);
        }
        if (gsTaxId == null && parameterMap.getBoolean(Configuration.PARAM_SYNONYM_APRIORI_INFERENCE_FROM_SENTENCE_SUCCEED)) {
            gsTaxId = this.inferFromOccurrenceType(geneSet, GeneSpeciesOccurrence.SENTENCE_SUCCED, taxIdsInGs, tieBreakers);
        }
        if (gsTaxId == null && parameterMap.getBoolean(Configuration.PARAM_SYNONYM_APRIORI_INFERENCE_TO_HIGHEST_GENESET_SCORE)) {
            gsTaxId = this.inferFromHighestGenesetScoreTax(parameterMap, taxIdsInGs, geneSet, null);
        }
        return gsTaxId;
    }

    /**
     * Returns the taxonomy ID out of {@code taxonomyCandidatesInGeneset} with the highest summed a priori
     * score over the gene set.
     *
     * @param parameterMap                passed through to the cooccurrence service
     * @param taxonomyCandidatesInGeneset the taxonomy IDs to choose from
     * @param gs                          the gene set to score the candidates on
     * @param gsTaxId                     fallback returned unchanged when no candidate got a positive score
     * @return the best scored taxonomy ID or {@code gsTaxId}
     */
    public String inferFromHighestGenesetScoreTax(Parameters parameterMap, Set<String> taxonomyCandidatesInGeneset, GeneSet gs, String gsTaxId) {
        return this.getTaxIdHighestScoredOnGeneSet(gs, taxonomyCandidatesInGeneset, parameterMap)
                .map(Pair::getLeft)
                .orElse(gsTaxId);
    }

    /**
     * Counts how often each taxonomy candidate occurs across the mentions of {@code gs} and returns the most
     * frequent one, consulting the tie breakers when several candidates share the highest count.
     *
     * @return the chosen taxonomy ID or {@code null} if there are no candidates or the tie stays unresolved
     */
    @Nullable
    private String inferFromSingularAssignments(GeneSet gs, Set<String> taxIdsInGs, List<BiFunction<GeneSet, Set<String>, Optional<String>>> tieBreakers) {
        Map<String, Integer> taxCounts = gs.stream()
                .map(GeneMention::getTaxonomyCandidates)
                .flatMap(Collection::stream)
                .collect(Collectors.toMap(Function.identity(), taxId -> 1, Integer::sum));
        return this.chooseAmongMajority(gs, this.getMajorityTaxIds(taxCounts), taxIdsInGs, tieBreakers);
    }

    /**
     * Picks one taxonomy ID from the given majority set. A single majority ID is returned directly. For
     * several tied IDs the tie breakers are applied in list order; the first result that is also contained in
     * {@code taxIdsInGs} wins.
     *
     * @return the selected taxonomy ID or {@code null} if the majority set is empty or no tie breaker decided
     */
    @Nullable
    private String chooseAmongMajority(GeneSet gs, Set<String> majorityTaxIds, Set<String> taxIdsInGs, List<BiFunction<GeneSet, Set<String>, Optional<String>>> tieBreakers) {
        if (majorityTaxIds.isEmpty()) {
            return null;
        }
        if (majorityTaxIds.size() == 1) {
            return majorityTaxIds.iterator().next();
        }
        for (BiFunction<GeneSet, Set<String>, Optional<String>> tieBreaker : tieBreakers) {
            Optional<String> resolved = tieBreaker.apply(gs, majorityTaxIds);
            if (resolved.isPresent() && taxIdsInGs.contains(resolved.get())) {
                return resolved.get();
            }
        }
        return null;
    }

    /**
     * For every mention in {@code gs} flagged as abbreviation long form, looks up a gene mention overlapping
     * the first acronym of an overlapping long form annotation and gives it the long form's taxonomy IDs. If
     * the IDs differed, the acronym mention is additionally moved into {@code gs}.
     */
    private void makeAbbreviationsConsistent(GeneSet gs, GeneDocument document) {
        // Work on a snapshot: acronym mentions may be added to gs inside the loop.
        List<GeneMention> snapshot = gs.stream().collect(Collectors.toList());
        for (GeneMention gm : snapshot) {
            if (!gm.isAbbreviationLongForm()) {
                continue;
            }
            Optional<AcronymLongform> longform = document.getOverlappingAcronymLongforms(gm.getOffsets()).stream().findAny();
            if (!longform.isPresent()) {
                continue;
            }
            Acronym acronym = longform.get().getAcronyms().get(0);
            Optional<GeneMention> overlappingGene = document.getOverlappingGenes(acronym.getOffsets()).findAny();
            if (!overlappingGene.isPresent()) {
                continue;
            }
            GeneMention acroGene = overlappingGene.get();
            if (!acroGene.getTaxonomyIds().equals(gm.getTaxonomyIds())) {
                // Detach the acronym mention from its current gene set before attaching it to this one.
                acroGene.getSingleGeneSet().remove(acroGene);
                acroGene.getGeneSets().remove(acroGene.getSingleGeneSet());
                acroGene.addGeneSet(gs);
                gs.add(acroGene);
            }
            acroGene.setTaxonomyIds(gm.getTaxonomyIds());
        }
    }

    /**
     * Fallback when no taxonomy ID could be found for the whole gene set: each mention without a
     * {@code COMPOUND} occurrence gets its own taxonomy ID(s). Preference order: best a priori scored ID among
     * the mention's candidates, then the candidates themselves (all of them if conjunctive, otherwise one),
     * then the best a priori scored ID without candidate restriction.
     */
    private void assignHighestScoringAPrioriCandidates(GeneSet gs, Parameters parameterMap) {
        for (GeneMention gm : gs) {
            // NOTE(review): the lookup happens before the COMPOUND check, as in the original statement order.
            Set<Pair<String, Double>> bestCandidates = this.cooccurrenceService.getBestAPrioriTaxIdsForBestGeneCandidates(gm, null, gm.getTaxonomyCandidates(), parameterMap);
            Optional<Pair<String, Double>> first = highestScored(bestCandidates);
            if (gm.getTaxonomyOccurrences().containsValue(GeneSpeciesOccurrence.COMPOUND)) {
                continue;
            }
            if (first.isPresent()) {
                gm.setTaxonomyIds(List.of(first.get().getLeft()));
            } else if (!gm.getTaxonomyCandidates().isEmpty()) {
                if (gm.isTaxonomyCandidatesConjunctive()) {
                    gm.setTaxonomyIds(new ArrayList<String>(gm.getTaxonomyCandidates()));
                } else {
                    gm.setTaxonomyIds(List.of(gm.getTaxonomyCandidates().iterator().next()));
                }
            } else {
                Set<Pair<String, Double>> bestAPrioriTax = this.cooccurrenceService.getBestAPrioriTaxIdsForBestGeneCandidates(gm, null, null, parameterMap);
                Optional<Pair<String, Double>> singleBest = highestScored(bestAPrioriTax);
                if (singleBest.isPresent()) {
                    gm.setTaxonomyIds(List.of(singleBest.get().getLeft()));
                }
            }
        }
    }

    /** Returns the pair with the highest score (right element), or empty for an empty collection. */
    private static Optional<Pair<String, Double>> highestScored(Collection<Pair<String, Double>> scoredTaxIds) {
        return scoredTaxIds.stream().max(Comparator.<Pair<String, Double>>comparingDouble(Pair::getRight));
    }

    /**
     * Applies the gene set level taxonomy ID to the mentions of {@code gs}.
     *
     * <ul>
     *   <li>Mentions having {@code gsTaxId} with a {@code SPECIES_PREFIX} occurrence get {@code gsTaxId}.</li>
     *   <li>Mentions without any {@code SPECIES_PREFIX} occurrence that either have {@code gsTaxId} as
     *   {@code COMPOUND} occurrence or have no {@code COMPOUND} occurrence at all get {@code gsTaxId}, or all
     *   of their candidates if those are conjunctive and contain {@code gsTaxId}.</li>
     *   <li>All other mentions keep their own {@code SPECIES_PREFIX}/{@code COMPOUND} taxonomy IDs, are
     *   removed from {@code gs} and are grouped by taxonomy ID into new collections appended to
     *   {@code newGeneSets}.</li>
     * </ul>
     *
     * @param gs          the gene set to process; mentions may be removed from it
     * @param gsTaxId     the taxonomy ID decided on for the gene set
     * @param newGeneSets receives the collections of mentions split off from {@code gs}
     */
    public void delegateGsTaxId(GeneSet gs, String gsTaxId, List<Collection<GeneMention>> newGeneSets) {
        HashMultimap<String, GeneMention> extractedGeneSets = HashMultimap.create();
        // Identity-based on purpose: only the very instances seen here must be removed. A set of
        // System.identityHashCode values would not be safe because identity hash codes are not unique.
        Set<GeneMention> mentionsToRemove = Collections.newSetFromMap(new IdentityHashMap<>());
        for (GeneMention gm : gs) {
            Multimap<String, GeneSpeciesOccurrence> taxonomyOccurrences = gm.getTaxonomyOccurrences();
            Collection<GeneSpeciesOccurrence> gsTaxOccurrences = taxonomyOccurrences.get(gsTaxId);
            if (gsTaxOccurrences.contains(GeneSpeciesOccurrence.SPECIES_PREFIX)) {
                gm.setTaxonomyIds(List.of(gsTaxId));
                continue;
            }
            boolean anyPrefix = taxonomyOccurrences.containsValue(GeneSpeciesOccurrence.SPECIES_PREFIX);
            boolean anyCompound = taxonomyOccurrences.containsValue(GeneSpeciesOccurrence.COMPOUND);
            boolean gsTaxIsCompound = gsTaxOccurrences.contains(GeneSpeciesOccurrence.COMPOUND);
            if (!anyPrefix && (gsTaxIsCompound || !anyCompound)) {
                if (gm.getTaxonomyCandidates().contains(gsTaxId) && gm.isTaxonomyCandidatesConjunctive()) {
                    gm.setTaxonomyIds(new ArrayList<String>(gm.getTaxonomyCandidates()));
                } else {
                    gm.setTaxonomyIds(List.of(gsTaxId));
                }
                continue;
            }
            // The mention has strong evidence for another species: keep that and split it off.
            for (String taxId : taxonomyOccurrences.keySet()) {
                Collection<GeneSpeciesOccurrence> occurrences = taxonomyOccurrences.get(taxId);
                if (occurrences.contains(GeneSpeciesOccurrence.SPECIES_PREFIX) || occurrences.contains(GeneSpeciesOccurrence.COMPOUND)) {
                    gm.setTaxonomyIds(List.of(taxId));
                    extractedGeneSets.put(taxId, gm);
                }
            }
            mentionsToRemove.add(gm);
        }
        for (String taxId : extractedGeneSets.keySet()) {
            newGeneSets.add(extractedGeneSets.get(taxId));
        }
        Iterator<GeneMention> it = gs.iterator();
        while (it.hasNext()) {
            if (mentionsToRemove.contains(it.next())) {
                it.remove();
            }
        }
    }

    /**
     * Returns the taxonomy ID that most often occurs with the given occurrence type in the gene set,
     * consulting the tie breakers when several IDs share the highest count.
     *
     * @param gs             the gene set to inspect
     * @param occurrenceType the kind of gene/species occurrence to count
     * @param taxIdsInGs     taxonomy IDs present in the gene set; a tie breaker result must be one of them
     * @param tieBreakers    tie breakers in the order in which they are tried
     * @return the chosen taxonomy ID or {@code null} if there were no counts or the tie stays unresolved
     */
    @Nullable
    public String inferFromOccurrenceType(GeneSet gs, GeneSpeciesOccurrence occurrenceType, Set<String> taxIdsInGs, List<BiFunction<GeneSet, Set<String>, Optional<String>>> tieBreakers) {
        Map<String, Integer> occurrenceTypeTaxIdCounts = gs.getSpeciesOccurrenceCounts(occurrenceType);
        return this.chooseAmongMajority(gs, this.getMajorityTaxIds(occurrenceTypeTaxIdCounts), taxIdsInGs, tieBreakers);
    }

    /**
     * Builds the list of tie breakers ordered by their configured rank.
     *
     * <p>Each tie breaker's rank is read from {@code parameterMap}. Only ranks {@code 1..4} are used; if two
     * tie breakers are configured with the same rank, the one registered later replaces the earlier one.</p>
     */
    private List<BiFunction<GeneSet, Set<String>, Optional<String>>> initTieBreakers(Parameters parameterMap) {
        Map<Integer, String> orderMap = new HashMap<>();
        orderMap.put(parameterMap.getInt(Configuration.PARAM_SYNONYM_APRIORI_INFERENCE_TIEBREAKER_MAJORITYSYNONYM_RANK), Configuration.PARAM_SYNONYM_APRIORI_INFERENCE_TIEBREAKER_MAJORITYSYNONYM_RANK);
        orderMap.put(parameterMap.getInt(Configuration.PARAM_SYNONYM_APRIORI_INFERENCE_TIEBREAKER_MAX_APRIORI_SYN_TAX_PAIR_RANK), Configuration.PARAM_SYNONYM_APRIORI_INFERENCE_TIEBREAKER_MAX_APRIORI_SYN_TAX_PAIR_RANK);
        orderMap.put(parameterMap.getInt(Configuration.PARAM_SYNONYM_APRIORI_INFERENCE_TIEBREAKER_HIGHEST_GENESET_SCORE_RANK), Configuration.PARAM_SYNONYM_APRIORI_INFERENCE_TIEBREAKER_HIGHEST_GENESET_SCORE_RANK);
        orderMap.put(parameterMap.getInt(Configuration.PARAM_SYNONYM_APRIORI_INFERENCE_TIEBREAKER_ACRONYM_LONGFORM_RANK), Configuration.PARAM_SYNONYM_APRIORI_INFERENCE_TIEBREAKER_ACRONYM_LONGFORM_RANK);
        log.debug("Using the following tie breaker order: {}", orderMap);
        List<BiFunction<GeneSet, Set<String>, Optional<String>>> tiebreakers = new ArrayList<>(NUM_TIE_BREAKERS);
        for (int rank = 1; rank <= NUM_TIE_BREAKERS; ++rank) {
            String tiebreakerName = orderMap.get(rank);
            if (tiebreakerName == null) {
                continue;
            }
            BiFunction<GeneSet, Set<String>, Optional<String>> tieBreaker = this.createTieBreaker(tiebreakerName, parameterMap);
            if (tieBreaker != null) {
                tiebreakers.add(tieBreaker);
            }
        }
        return tiebreakers;
    }

    /**
     * Creates the tie breaker function belonging to the given rank parameter name.
     *
     * @return the tie breaker or {@code null} if the name is not one of the known rank parameters
     */
    @Nullable
    private BiFunction<GeneSet, Set<String>, Optional<String>> createTieBreaker(String tiebreakerName, Parameters parameterMap) {
        if (tiebreakerName.equals(Configuration.PARAM_SYNONYM_APRIORI_INFERENCE_TIEBREAKER_MAJORITYSYNONYM_RANK)) {
            return (gs, tiedTaxIds) -> {
                Optional<String> majoritySynonym = this.getMajoritySynonym(gs);
                if (!majoritySynonym.isPresent()) {
                    return Optional.empty();
                }
                Optional<Pair<String, Double>> bestAPrioriTaxId = this.cooccurrenceService.getBestAPrioriTaxId(majoritySynonym.get(), null, tiedTaxIds, parameterMap);
                return bestAPrioriTaxId.map(Pair::getLeft);
            };
        }
        if (tiebreakerName.equals(Configuration.PARAM_SYNONYM_APRIORI_INFERENCE_TIEBREAKER_MAX_APRIORI_SYN_TAX_PAIR_RANK)) {
            return (gs, tiedTaxIds) -> this.getMaxAPrioriTaxonomyAssignment(gs, tiedTaxIds, GeneSpeciesOccurrence.COMPOUND, parameterMap);
        }
        if (tiebreakerName.equals(Configuration.PARAM_SYNONYM_APRIORI_INFERENCE_TIEBREAKER_HIGHEST_GENESET_SCORE_RANK)) {
            return (gs, tiedTaxIds) -> this.getTaxIdHighestScoredOnGeneSet(gs, tiedTaxIds, parameterMap).map(Pair::getLeft);
        }
        if (tiebreakerName.equals(Configuration.PARAM_SYNONYM_APRIORI_INFERENCE_TIEBREAKER_ACRONYM_LONGFORM_RANK)) {
            return (gs, tiedTaxIds) -> this.getLongformTaxId(gs);
        }
        return null;
    }

    /**
     * Returns the taxonomy candidate of the first mention in iteration order that overlaps an acronym long
     * form annotation and has exactly one taxonomy candidate.
     */
    private Optional<String> getLongformTaxId(GeneSet geneSet) {
        for (GeneMention gm : geneSet) {
            GeneDocument document = gm.getGeneDocument();
            Optional<AcronymLongform> longformOpt = document.getOverlappingAcronymLongforms(gm.getOffsets()).stream().findAny();
            if (!longformOpt.isPresent() || gm.getTaxonomyCandidates().size() != 1) {
                continue;
            }
            Optional<String> taxId = gm.getTaxonomyCandidates().stream().findAny();
            if (log.isTraceEnabled()) {
                log.trace("Got tax ID {} for long form {}", taxId.get(), longformOpt.get().getText());
            }
            return taxId;
        }
        return Optional.empty();
    }

    /**
     * Sums, for each given taxonomy ID, the a priori scores the cooccurrence service returns for all mentions
     * of the gene set and returns the ID with the highest sum together with that sum.
     *
     * @return the best taxonomy ID and its score; empty if no ID reached a sum greater than {@code 0}
     */
    private Optional<Pair<String, Double>> getTaxIdHighestScoredOnGeneSet(GeneSet gs, Set<String> taxIds, Parameters parameterMap) {
        double highestScore = 0.0;
        String bestTaxId = null;
        for (String taxId : taxIds) {
            double sum = gs.stream()
                    .flatMap(gm -> this.cooccurrenceService.getBestAPrioriTaxIdsForBestGeneCandidates(gm, null, Set.of(taxId), parameterMap).stream())
                    .mapToDouble(Pair::getRight)
                    .sum();
            // Strictly greater: the first ID reaching a score wins over later IDs with the same score.
            if (sum > highestScore) {
                highestScore = sum;
                bestTaxId = taxId;
            }
        }
        if (bestTaxId == null) {
            return Optional.empty();
        }
        log.trace("Max a priori geneset scored tax ID: {} ({})", bestTaxId, highestScore);
        Pair<String, Double> best = new ImmutablePair<>(bestTaxId, highestScore);
        return Optional.of(best);
    }

    /**
     * Among the mentions of {@code gs} that have a taxonomy candidate with the given occurrence type, finds
     * the highest a priori scored taxonomy ID restricted to {@code tiedTaxIds}.
     *
     * @throws IllegalStateException if a mention selected for the occurrence type has no taxonomy candidate
     *                               with that occurrence type
     */
    private Optional<String> getMaxAPrioriTaxonomyAssignment(GeneSet gs, Set<String> tiedTaxIds, GeneSpeciesOccurrence occurrenceType, Parameters parameterMap) {
        Pair<String, Double> highestScoredTaxAssignment = null;
        Iterator<GeneMention> it = gs.getGeneMentionsWithSpeciesOccurrence(occurrenceType).iterator();
        while (it.hasNext()) {
            GeneMention gm = it.next();
            Optional<String> taxIdOpt = gm.getTaxonomyCandidateWithOccurrence(occurrenceType);
            if (taxIdOpt.isEmpty()) {
                throw new IllegalStateException("A gene mentioned that was picked for having a taxonomy ID with occurrence type " + occurrenceType + " did not return a taxonomy ID for this occurrence type.");
            }
            Set<Pair<String, Double>> scoredTaxIds = this.cooccurrenceService.getBestAPrioriTaxIdsForBestGeneCandidates(gm, null, tiedTaxIds, parameterMap);
            for (Pair<String, Double> scoredTaxId : scoredTaxIds) {
                if (highestScoredTaxAssignment == null || highestScoredTaxAssignment.getRight() < scoredTaxId.getRight()) {
                    highestScoredTaxAssignment = scoredTaxId;
                }
            }
        }
        if (highestScoredTaxAssignment == null) {
            return Optional.empty();
        }
        log.trace("Highest a priori scored tax-synonym pair has tax and score: {} ({})", highestScoredTaxAssignment.getLeft(), highestScoredTaxAssignment.getRight());
        return Optional.of(highestScoredTaxAssignment.getLeft());
    }

    /**
     * Returns the synonym that occurs most often among the best candidate synonyms of the considered
     * mentions, provided that there is exactly one such synonym.
     *
     * @param geneSet        the gene set whose mentions are inspected
     * @param occurrenceType if not {@code null}, only mentions with this species occurrence type are counted
     * @return the unique most frequent synonym; empty if there are no synonyms or the maximum is shared
     */
    private Optional<String> getMajoritySynonym(GeneSet geneSet, @Nullable GeneSpeciesOccurrence occurrenceType) {
        Stream<GeneMention> genes = occurrenceType != null ? geneSet.getGeneMentionsWithSpeciesOccurrence(occurrenceType) : geneSet.stream();
        Multiset<String> countedSynonyms = genes
                .map(GeneMention::getAllBestCandidateSynonyms)
                .flatMap(Collection::stream)
                .collect(Collectors.toCollection(HashMultiset::create));
        // elementSet() holds each synonym once, which is enough to find the maximum count.
        OptionalInt max = countedSynonyms.elementSet().stream().mapToInt(countedSynonyms::count).max();
        if (!max.isPresent()) {
            return Optional.empty();
        }
        int maxCount = max.getAsInt();
        Set<String> majoritySynonyms = countedSynonyms.elementSet().stream()
                .filter(synonym -> countedSynonyms.count(synonym) == maxCount)
                .collect(Collectors.toSet());
        if (majoritySynonyms.size() == 1) {
            return Optional.of(majoritySynonyms.iterator().next());
        }
        return Optional.empty();
    }

    /** Same as {@link #getMajoritySynonym(GeneSet, GeneSpeciesOccurrence)} without occurrence type filter. */
    private Optional<String> getMajoritySynonym(GeneSet geneSet) {
        return this.getMajoritySynonym(geneSet, null);
    }

    /**
     * Returns all taxonomy IDs that share the highest count in {@code taxCounts}.
     *
     * @param taxCounts taxonomy ID to count
     * @return the IDs with the maximum count; an empty set if {@code taxCounts} is empty
     */
    private Set<String> getMajorityTaxIds(Map<String, Integer> taxCounts) {
        OptionalInt max = taxCounts.values().stream().mapToInt(Integer::intValue).max();
        if (!max.isPresent()) {
            return Collections.emptySet();
        }
        // Compare as primitives: '==' on boxed Integers compares references and fails for counts > 127.
        int maxCount = max.getAsInt();
        return taxCounts.entrySet().stream()
                .filter(entry -> entry.getValue() == maxCount)
                .map(Map.Entry::getKey)
                .collect(Collectors.toSet());
    }
}

