/*
 * Decompiled with CFR 0.152.
 */
package de.julielab.speciesassignment.evaluation;

import com.google.common.collect.Multiset;
import de.julielab.geneexpbase.genemodel.GeneDocument;
import de.julielab.speciesassignment.evaluation.SpeciesCountDocument;
import de.julielab.speciesassignment.evaluation.SpeciesCrossValPartition;
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Comparator;
import java.util.Deque;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.TreeSet;
import java.util.function.Predicate;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class SpeciesCountDocumentPartitioning {
    private static final Logger log = LoggerFactory.getLogger(SpeciesCountDocumentPartitioning.class);

    public static List<SpeciesCrossValPartition> partitionDocuments(Stream<GeneDocument> documents, int numSplits) {
        ArrayList<SpeciesCrossValPartition> partitioning = new ArrayList<SpeciesCrossValPartition>(numSplits);
        for (int i = 0; i < numSplits; ++i) {
            partitioning.add(new SpeciesCrossValPartition());
        }
        List<SpeciesCountDocument> countDocs = documents.map(SpeciesCountDocument::new).collect(Collectors.toList());
        HashMap<String, TreeSet<SpeciesCountDocument>> tax2docs = new HashMap<String, TreeSet<SpeciesCountDocument>>();
        countDocs.forEach(doc -> doc.getDocument().getGoldGenes().values().stream().flatMap(Collection::stream).map(gm -> gm.getIds().get(0)).distinct().forEach(tax -> tax2docs.compute((String)tax, (k, v) -> v != null ? v : new TreeSet<SpeciesCountDocument>(Comparator.comparingInt(d -> d.getCount((String)tax)).thenComparing(d -> d.getId()).reversed())).add(doc)));
        HashSet<String> assignedDocIds = new HashSet<String>();
        int docNum = 0;
        Predicate<Map> notEmptyFunc = docs -> docs.values().stream().anyMatch(Predicate.not(Collection::isEmpty));
        Set<String> allTaxIds = countDocs.stream().map(SpeciesCountDocument::getTaxCounts).map(Multiset::elementSet).flatMap(Collection::stream).collect(Collectors.toSet());
        while (notEmptyFunc.test(tax2docs)) {
            int splitNum = docNum % numSplits;
            SpeciesCrossValPartition partition = (SpeciesCrossValPartition)partitioning.get(splitNum);
            if (partition.isEmpty()) {
                SpeciesCountDocumentPartitioning.addAnyDocumentToPartition(tax2docs, assignedDocIds, partition);
            } else {
                Deque<String> taxesOrderedByFrequency = partition.getTaxesOrderedByFrequency(allTaxIds);
                while (!taxesOrderedByFrequency.isEmpty() && !SpeciesCountDocumentPartitioning.assignNextAvailableDocument4Tax(tax2docs, assignedDocIds, partition, (String)taxesOrderedByFrequency.poll())) {
                }
            }
            ++docNum;
        }
        int i = 0;
        for (SpeciesCountDocument emptyDoc : () -> countDocs.stream().filter(cd2 -> cd2.getDocument().getGoldGenes().isEmpty()).iterator()) {
            int index = i % numSplits;
            ((SpeciesCrossValPartition)partitioning.get(index)).add(emptyDoc);
            ++i;
        }
        log.debug("Partitioned {} documents into {} partitions. Size distribution: {}", partitioning.stream().mapToInt(Collection::size).sum(), partitioning.size(), partitioning.stream().map(Collection::size).map(String::valueOf).collect(Collectors.joining(", ")));
        return partitioning;
    }

    private static void addAnyDocumentToPartition(Map<String, TreeSet<SpeciesCountDocument>> tax2docs, Set<String> assignedDocIds, SpeciesCrossValPartition partition) {
        ArrayDeque<String> allTax = new ArrayDeque<String>(tax2docs.keySet());
        while (!allTax.isEmpty() && !SpeciesCountDocumentPartitioning.assignNextAvailableDocument4Tax(tax2docs, assignedDocIds, partition, (String)allTax.poll())) {
        }
    }

    private static boolean assignNextAvailableDocument4Tax(Map<String, TreeSet<SpeciesCountDocument>> tax2docs, Set<String> assignedDocIds, SpeciesCrossValPartition partition, String tax) {
        SpeciesCountDocument document = tax2docs.get(tax).pollFirst();
        while (document != null && !assignedDocIds.add(document.getId())) {
            document = tax2docs.get(tax).pollFirst();
        }
        if (document != null) {
            partition.add(document);
            assignedDocIds.add(document.getId());
            return true;
        }
        return false;
    }
}

