package net.maizegenetics.pangenome.hapCalling;

import com.google.common.collect.Multiset;
import java.io.FileNotFoundException;
import java.io.PrintStream;
import java.io.PrintWriter;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.TreeMap;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import net.maizegenetics.analysis.imputation.BackwardForwardVariableStateNumber;
import net.maizegenetics.analysis.imputation.ViterbiAlgorithmVariableStateNumber;
import net.maizegenetics.dna.map.Chromosome;
import net.maizegenetics.pangenome.api.CreateGraphUtils;
import net.maizegenetics.pangenome.api.FilterGraphPlugin;
import net.maizegenetics.pangenome.api.HaplotypeGraph;
import net.maizegenetics.pangenome.api.HaplotypeNode;
import net.maizegenetics.pangenome.api.ReferenceRange;
import net.maizegenetics.pangenome.api.ReferenceRangeEmissionProbability;
import net.maizegenetics.pangenome.api.ReferenceRangeTransitionProbability;
import net.maizegenetics.taxa.TaxaList;
import org.apache.log4j.Logger;

/* loaded from: input_file:net/maizegenetics/pangenome/hapCalling/ConvertGBSToSNPs.class */
/**
 * Converts per-haplotype read counts (hapid -> count) into paths through a {@link HaplotypeGraph}.
 *
 * <p>Typical use: supply counts with {@link #hapidCountMap(Map)} / {@link #hapidCounts(Multiset)}, configure
 * the filter and probability parameters with the fluent setters, call one of the {@code filterHaplotypeGraph}
 * overloads, and then either
 * <ul>
 *   <li>{@link #haplotypeCountsToPath()} (Viterbi most-probable node per reference range), or</li>
 *   <li>{@link #haplotypeCountsToPathProbability()} (forward/backward gammas) followed by
 *       {@link #nodeListFromProbabilities(double)}.</li>
 * </ul>
 *
 * <p>Instances hold mutable state and perform no synchronization.
 */
public class ConvertGBSToSNPs {
    private static final Logger myLogger = Logger.getLogger(ConvertGBSToSNPs.class);

    // Development-machine inputs used by the no-argument countHaplotypeNodesFromFastQ().
    private static final String DEFAULT_FASTQ_FILE =
            "/workdir/pjb39/phg.test/B73W22Cross_W22-Brink-std_D0DW9ACXX_4_CCACTCA.fastq";
    private static final String DEFAULT_HAPLOTYPES_GENOME_DIR = "/workdir/pjb39/phg.test/";
    private static final String DEFAULT_REF_GENOME_FILE =
            "/workdir/pjb39/phg.test/Zea_mays.AGPv4.dna.toplevel.fa.gz";

    private HaplotypeGraph myGraph;
    private List<HaplotypeNode> nodesOnPath;
    private Multiset<Integer> myHapidCounts = null;
    private Multiset<HaplotypeNode> myHapNodeCounts = null;
    // hapid -> inclusion count; missing hapids are treated as 0 everywhere in this class.
    private Map<Integer, Integer> myHapidCountMap = null;
    // hapid -> exclusion count; only consulted when the emission method is not inclusionOnly.
    private Map<Integer, Integer> myHapidExclusionCountMap = null;
    private List<List<HaplotypeNode>> pathNodeLists = null;
    private List<double[]> pathGammas = null;
    private int minTaxaPerRefRange = 1;
    private int minReadsPerRefRange = 0;
    private int maxReadsPerRefRangeKB = 10000;
    private String myTaxaListString = null;
    private TaxaList myTaxaList = null;
    private double minTransitionProb = 0.001d;
    private double probReadMappedCorrectly = 0.99d;
    private double maxNodesPerRange = 10000.0d;
    private boolean splitTaxa = false;
    private ReferenceRangeEmissionProbability.METHOD myEmissionMethod =
            ReferenceRangeEmissionProbability.METHOD.inclusionOnly;
    private double transitionProbSameTaxon = 0.99d;
    private String targetTaxon = null;

    /**
     * Runs {@link #countHaplotypeNodesFromFastQ(String, String, String)} on the hard-coded development inputs.
     *
     * @return this instance, for chaining
     */
    public ConvertGBSToSNPs countHaplotypeNodesFromFastQ() {
        return countHaplotypeNodesFromFastQ(DEFAULT_FASTQ_FILE, DEFAULT_HAPLOTYPES_GENOME_DIR, DEFAULT_REF_GENOME_FILE);
    }

    /**
     * Counts reads per haplotype node with {@code FastqToHapCountPlugin} and stores both the node counts
     * and the derived hapid -> count map.
     *
     * @param fastqFile           fastq file passed to the plugin's {@code readFile}
     * @param haplotypesGenomeDir value passed to the plugin's {@code haplotypesGenomeFile}
     * @param refGenomeFile       value passed to the plugin's {@code refGenomeFile}
     * @return this instance, for chaining
     */
    public ConvertGBSToSNPs countHaplotypeNodesFromFastQ(String fastqFile, String haplotypesGenomeDir, String refGenomeFile) {
        @SuppressWarnings("unchecked") // the plugin's first datum is expected to be a Multiset of HaplotypeNode
        Multiset<HaplotypeNode> nodeCounts = (Multiset<HaplotypeNode>) new FastqToHapCountPlugin(null, false)
                .configFile(ConvertGBSUtils.configFileName)
                .readFile(fastqFile)
                .haplotypesGenomeFile(haplotypesGenomeDir)
                .refGenomeFile(refGenomeFile)
                .processData(null)
                .getData(0)
                .getData();
        this.myHapNodeCounts = nodeCounts;
        // Integer::sum guards against two distinct node objects sharing an id, which would otherwise throw.
        this.myHapidCountMap = nodeCounts.entrySet().stream()
                .collect(Collectors.toMap(
                        entry -> Integer.valueOf(entry.getElement().id()),
                        Multiset.Entry::getCount,
                        Integer::sum));
        myLogger.debug("number of elements in myHapNodeCounts = " + this.myHapNodeCounts.size());
        myLogger.debug("size of myHapidCountMap = " + this.myHapidCountMap.size());
        return this;
    }

    /**
     * Prints, per chromosome, how often the node containing {@code taxonName} has the highest (or
     * near-highest, &gt; 80% of max) read count among the nodes of a reference range. Ranges with no
     * reads, or with no node containing the taxon, are ignored.
     *
     * @param haplotypeGraph graph whose ranges are examined
     * @param taxonName      taxon looked up in each node's taxa list
     * @throws IllegalArgumentException if no hapid count map has been set
     */
    public void testHapidCounts(HaplotypeGraph haplotypeGraph, String taxonName) {
        if (this.myHapidCountMap == null) {
            throw new IllegalArgumentException("myHapidCountMap is null.");
        }
        System.out.println("------------------------------");
        System.out.println("taxon chr nodesWithTaxon nodesWithTaxonMax nodesNearMax pMax pNearMax");
        for (Chromosome chromosome : haplotypeGraph.chromosomes()) {
            int nodesWithTaxon = 0;
            int nodesWithTaxonMax = 0;
            int nodesNearMax = 0;
            for (List<HaplotypeNode> nodes : haplotypeGraph.tree(chromosome).values()) {
                int[] counts = readCounts(nodes);
                if (Arrays.stream(counts).sum() <= 0) {
                    continue;
                }
                Optional<HaplotypeNode> taxonNode = nodes.stream()
                        .filter(node -> node.taxaList().indexOf(taxonName) > -1)
                        .findAny();
                if (!taxonNode.isPresent()) {
                    continue;
                }
                // Sum > 0 guarantees maxCount > 0, so the ratio below is well defined.
                int maxCount = Arrays.stream(counts).max().orElse(0);
                int taxonCount = countFor(taxonNode.get().id());
                nodesWithTaxon++;
                if (taxonCount == maxCount) {
                    nodesWithTaxonMax++;
                }
                // Floating-point division: integer division would only ever count exact maxima.
                if ((double) taxonCount / maxCount > 0.8d) {
                    nodesNearMax++;
                }
            }
            // Double division yields NaN (not an ArithmeticException) when no range qualified.
            System.out.printf("%s %s %d %d %d %1.2f %1.2f%n", taxonName, chromosome.getName(),
                    nodesWithTaxon, nodesWithTaxonMax, nodesNearMax,
                    (double) nodesWithTaxonMax / nodesWithTaxon, (double) nodesNearMax / nodesWithTaxon);
        }
    }

    /**
     * Filters {@code haplotypeGraph} into this instance's working graph.
     *
     * <p>When {@code minReadsPerRange > 0}, a range is removed if its total count is below that minimum,
     * its largest single-node count exceeds {@code maxReadsPerRangeKB}, it has more than
     * {@code maxNodesPerRange} nodes, or all its nodes have the same count (uninformative). Taxa filters are
     * then applied, missing-sequence nodes are added, and nodes are optionally split by taxon.
     *
     * @param haplotypeGraph graph to filter; not modified by this method's own code
     * @return this instance, for chaining
     */
    public ConvertGBSToSNPs filterHaplotypeGraph(HaplotypeGraph haplotypeGraph) {
        FilterGraphPlugin filterGraphPlugin = new FilterGraphPlugin(null, false);
        if (this.minReadsPerRefRange > 0) {
            List<ReferenceRange> rangesToRemove = new ArrayList<>();
            for (ReferenceRange referenceRange : haplotypeGraph.referenceRanges()) {
                List<HaplotypeNode> nodes = haplotypeGraph.nodes(referenceRange);
                int[] counts = readCounts(nodes);
                int total = Arrays.stream(counts).sum();
                int maxCount = Arrays.stream(counts).max().orElse(0);
                // NOTE(review): unlike filterHaplotypeGraph(graph, list), this compares against
                // maxReadsPerRefRangeKB without scaling by range length — confirm which is intended.
                boolean failsThresholds = total < this.minReadsPerRefRange
                        || maxCount > this.maxReadsPerRefRangeKB
                        || nodes.size() > this.maxNodesPerRange;
                // Equal counts on every node cannot discriminate between haplotypes.
                boolean uninformative = Arrays.stream(counts).allMatch(count -> count == counts[0]);
                if (failsThresholds || uninformative) {
                    rangesToRemove.add(referenceRange);
                }
            }
            filterGraphPlugin.refRanges(rangesToRemove);
        }
        applyTaxaFilters(filterGraphPlugin);
        myLogger.debug(String.format("before filtering hapgraph: %d nodes.%n", haplotypeGraph.numberOfNodes()));
        this.myGraph = filterGraphPlugin.filter(haplotypeGraph);
        myLogger.debug(String.format("after filtering hapgraph: %d nodes.%n", this.myGraph.numberOfNodes()));
        this.myGraph = CreateGraphUtils.addMissingSequenceNodes(this.myGraph);
        if (this.splitTaxa) {
            long startTime = System.currentTimeMillis();
            System.out.println("starting split nodes by taxa");
            myLogger.debug("starting split nodes by taxa");
            this.myGraph = CreateGraphUtils.nodesSplitByIndividualTaxa(this.myGraph, this.transitionProbSameTaxon);
            myLogger.debug(String.format("split by taxa took %d ms.", System.currentTimeMillis() - startTime));
        }
        return this;
    }

    /**
     * Filters {@code haplotypeGraph} into this instance's working graph, optionally restricted to a set of
     * reference range ids.
     *
     * <p>A range is removed if its id is not in {@code rangeIds} (when non-null), its total count is below
     * {@code minReadsPerRange}, its largest single-node count exceeds {@code maxReadsPerRangeKB} scaled by the
     * range length in kb (rounded up), or it has more than {@code maxNodesPerRange} nodes.
     *
     * @param haplotypeGraph graph to filter
     * @param rangeIds       reference range ids to keep, or {@code null} to consider every range
     * @return this instance, for chaining
     */
    public ConvertGBSToSNPs filterHaplotypeGraph(HaplotypeGraph haplotypeGraph, List<Integer> rangeIds) {
        myLogger.debug("Filtering haplotype graph prior to path finding");
        FilterGraphPlugin filterGraphPlugin = new FilterGraphPlugin(null, false);
        // Hash lookup instead of List.contains for every range.
        Set<Integer> rangeIdsToKeep = rangeIds == null ? null : new HashSet<>(rangeIds);
        List<ReferenceRange> rangesToRemove = new ArrayList<>();
        for (ReferenceRange referenceRange : haplotypeGraph.referenceRanges()) {
            if (rangeIdsToKeep != null && !rangeIdsToKeep.contains(referenceRange.id())) {
                rangesToRemove.add(referenceRange);
                continue;
            }
            List<HaplotypeNode> nodes = haplotypeGraph.nodes(referenceRange);
            int[] counts = readCounts(nodes);
            int total = Arrays.stream(counts).sum();
            int maxCount = Arrays.stream(counts).max().orElse(0);
            if (total < this.minReadsPerRefRange
                    || maxCount > maxReadsAllowed(referenceRange)
                    || nodes.size() > this.maxNodesPerRange) {
                rangesToRemove.add(referenceRange);
            }
        }
        filterGraphPlugin.refRanges(rangesToRemove);
        applyTaxaFilters(filterGraphPlugin);
        this.myGraph = filterGraphPlugin.filter(haplotypeGraph);
        System.out.println("adding missing sequence nodes");
        this.myGraph = CreateGraphUtils.addMissingSequenceNodes(this.myGraph);
        System.out.println("splitTaxa = " + this.splitTaxa);
        if (this.splitTaxa) {
            long startTime = System.currentTimeMillis();
            System.out.println("starting split nodes by taxa");
            this.myGraph = CreateGraphUtils.nodesSplitByIndividualTaxa(this.myGraph, this.transitionProbSameTaxon);
            System.out.printf("split by taxa took %d ms.%n", System.currentTimeMillis() - startTime);
        }
        return this;
    }

    /** Prints every taxon in the working graph to standard output, one per line. */
    public void listTaxa() {
        System.out.println("taxa in graph:");
        this.myGraph.taxaInGraph().stream().forEach(System.out::println);
    }

    /**
     * Debug helper: prints the count of every node in the first 21 reference ranges of the working graph.
     * Nodes without a count are shown as 0.
     */
    public void showNodeCounts() {
        int rangesShown = 0;
        for (ReferenceRange referenceRange : this.myGraph.referenceRanges()) {
            int nodeIndex = 0;
            for (HaplotypeNode haplotypeNode : this.myGraph.nodes(referenceRange)) {
                System.out.printf("%s, %d, node %d, id %d: %d%n", referenceRange.chromosome().getName(),
                        referenceRange.start(), nodeIndex++, haplotypeNode.id(), countFor(haplotypeNode.id()));
            }
            // Stop after 21 ranges; return rather than exiting the JVM from a library method.
            if (++rangesShown > 20) {
                return;
            }
        }
    }

    /**
     * Finds, chromosome by chromosome, the most probable node in each reference range of the working graph
     * using {@link ViterbiAlgorithmVariableStateNumber}. The result is also stored for {@link #nodesOnPath()}.
     * Chromosomes with no reference ranges are skipped.
     *
     * @return the selected nodes, in chromosome order and tree order within each chromosome
     */
    public List<HaplotypeNode> haplotypeCountsToPath() {
        List<HaplotypeNode> path = new ArrayList<>();
        for (Chromosome chromosome : this.myGraph.chromosomes()) {
            myLogger.info("Getting path for chromosome " + chromosome.getName());
            TreeMap<ReferenceRange, List<HaplotypeNode>> tree = this.myGraph.tree(chromosome);
            myLogger.info("Extracted graph tree for chromosome " + chromosome.getName());
            if (tree.isEmpty()) {
                myLogger.warn("No reference ranges for chromosome " + chromosome.getName() + "; skipping");
                continue;
            }
            ReferenceRangeEmissionProbability emission = buildEmissionProbability(tree);
            List<List<HaplotypeNode>> rangeNodes = new ArrayList<>(tree.values());
            ReferenceRangeTransitionProbability transition = buildTransitionProbability(rangeNodes);
            int numberOfRanges = rangeNodes.size();
            double[] initialProbabilities = startProbabilities(rangeNodes.get(0));
            long viterbiStart = System.currentTimeMillis();
            // The observation array is only a placeholder; emissions come from the count maps.
            ViterbiAlgorithmVariableStateNumber viterbi = new ViterbiAlgorithmVariableStateNumber(
                    new byte[numberOfRanges], transition, emission, initialProbabilities);
            viterbi.initialize();
            viterbi.calculate();
            byte[] states = viterbi.getMostProbableStateSequence();
            myLogger.info(String.format("Viterbi algorithm calculated in %d ms.", System.currentTimeMillis() - viterbiStart));
            for (int i = 0; i < numberOfRanges; i++) {
                // NOTE(review): states are bytes; a range with more than 127 nodes would widen to a
                // negative index here — confirm how ViterbiAlgorithmVariableStateNumber encodes states.
                path.add(rangeNodes.get(i).get(states[i]));
            }
        }
        this.nodesOnPath = path;
        return path;
    }

    /**
     * Runs {@link BackwardForwardVariableStateNumber} per chromosome and collects its gamma arrays
     * (one {@code double[]} per reference range, presumably posterior probabilities of that range's nodes).
     * The node lists are stored alongside for later use. Chromosomes with no reference ranges are skipped.
     *
     * @return the gamma arrays in chromosome order; also retained for {@link #nodeListFromProbabilities(double)}
     */
    public List<double[]> haplotypeCountsToPathProbability() {
        this.pathGammas = new ArrayList<>();
        this.pathNodeLists = new ArrayList<>();
        for (Chromosome chromosome : this.myGraph.chromosomes()) {
            myLogger.info("Getting path for chromosome " + chromosome.getName());
            TreeMap<ReferenceRange, List<HaplotypeNode>> tree = this.myGraph.tree(chromosome);
            this.pathNodeLists.addAll(tree.values());
            myLogger.info("Extracted graph tree for chromosome " + chromosome.getName());
            if (tree.isEmpty()) {
                myLogger.warn("No reference ranges for chromosome " + chromosome.getName() + "; skipping");
                continue;
            }
            ReferenceRangeEmissionProbability emission = buildEmissionProbability(tree);
            List<List<HaplotypeNode>> rangeNodes = new ArrayList<>(tree.values());
            ReferenceRangeTransitionProbability transition = buildTransitionProbability(rangeNodes);
            BackwardForwardVariableStateNumber forwardBackward = new BackwardForwardVariableStateNumber();
            forwardBackward.emission(emission)
                    .transition(transition)
                    .initialStateProbability(startProbabilities(rangeNodes.get(0)))
                    .observations(new int[rangeNodes.size()]) // placeholder observations
                    .calculateAlpha()
                    .calculateBeta();
            this.pathGammas.addAll(forwardBackward.gamma());
        }
        return this.pathGammas;
    }

    /**
     * Picks, for each reference range, the node with the highest gamma when that gamma is at least
     * {@code minProbability}, and optionally writes a per-node report.
     *
     * <p>Gamma arrays are paired with {@code referenceRangeList()} of the working graph by position.
     *
     * @param minProbability minimum gamma for a range's best node to be included
     * @param outputFile     tab-delimited report file, or {@code null} for no report
     * @return the selected nodes
     * @throws IllegalStateException if {@link #haplotypeCountsToPathProbability()} has not been run
     */
    public List<HaplotypeNode> nodeListFromProbabilities(double minProbability, String outputFile) {
        if (this.pathGammas == null) {
            throw new IllegalStateException(
                    "haplotypeCountsToPathProbability() must be called before nodeListFromProbabilities()");
        }
        List<HaplotypeNode> selectedNodes = new ArrayList<>();
        Iterator<ReferenceRange> rangeIterator = this.myGraph.referenceRangeList().iterator();
        for (double[] gammas : this.pathGammas) {
            ReferenceRange range = rangeIterator.next();
            if (gammas.length == 0) {
                continue;
            }
            int best = argMax(gammas);
            if (gammas[best] >= minProbability) {
                selectedNodes.add(this.myGraph.nodes(range).get(best));
            }
        }
        if (outputFile != null) {
            writeProbabilityReport(outputFile);
        }
        return selectedNodes;
    }

    /**
     * Same as {@link #nodeListFromProbabilities(double, String)} without writing a report.
     *
     * @param minProbability minimum gamma for a range's best node to be included
     * @return the selected nodes
     */
    public List<HaplotypeNode> nodeListFromProbabilities(double minProbability) {
        return nodeListFromProbabilities(minProbability, null);
    }

    /**
     * Computes initial-state probabilities for the nodes of one reference range as each node's share of
     * the range's total count. When the range has no counts at all, a uniform distribution is returned.
     *
     * @param nodes nodes of a single reference range
     * @return one probability per node, in the same order
     */
    public double[] startProbabilities(List<HaplotypeNode> nodes) {
        int[] counts = readCounts(nodes);
        int total = Arrays.stream(counts).sum();
        double[] probabilities = new double[counts.length];
        if (total <= 0) {
            Arrays.fill(probabilities, 1.0d / counts.length);
            return probabilities;
        }
        for (int i = 0; i < counts.length; i++) {
            // Floating-point division: integer division would zero out every non-maximal share.
            probabilities[i] = (double) counts[i] / total;
        }
        return probabilities;
    }

    /**
     * For each reference range of {@code haplotypeGraph}, the fraction of its reads that fall on its
     * best-supported node (NaN for a range without reads).
     *
     * @param haplotypeGraph graph supplying ranges and nodes
     * @param multiset       hapid counts
     * @return one value per range, in {@code referenceRanges()} order
     */
    public double[] probabilityOfBeingCorrect(HaplotypeGraph haplotypeGraph, Multiset<Integer> multiset) {
        return haplotypeGraph.referenceRanges().stream()
                .mapToDouble(range -> nodeCorrectProbability(haplotypeGraph.nodes(range).stream()
                        .mapToInt(node -> multiset.count(node.id()))
                        .toArray()))
                .toArray();
    }

    /**
     * For each range in {@code treeMap}, the fraction of its reads on its best-supported node
     * (NaN for a range without reads).
     *
     * @param multiset hapid counts
     * @param treeMap  reference range -> nodes
     * @return one value per range, in key order
     */
    public double[] probabilityOfBeingCorrect(Multiset<Integer> multiset, TreeMap<ReferenceRange, List<HaplotypeNode>> treeMap) {
        double[] result = new double[treeMap.size()];
        int index = 0;
        for (List<HaplotypeNode> nodes : treeMap.values()) {
            result[index++] = nodeCorrectProbability(nodes.stream()
                    .mapToInt(node -> multiset.count(node.id()))
                    .toArray());
        }
        return result;
    }

    /**
     * For each range in {@code treeMap}, the fraction of its reads on its best-supported node
     * (NaN for a range without reads).
     *
     * @param map     hapid -> count; missing hapids count as 0
     * @param treeMap reference range -> nodes
     * @return one value per range, in key order
     */
    public double[] probabilityOfBeingCorrect(Map<Integer, Integer> map, TreeMap<ReferenceRange, List<HaplotypeNode>> treeMap) {
        double[] result = new double[treeMap.size()];
        int index = 0;
        for (List<HaplotypeNode> nodes : treeMap.values()) {
            // Look up by hapid: the map is keyed by Integer, not by HaplotypeNode.
            result[index++] = nodeCorrectProbability(nodes.stream()
                    .mapToInt(node -> map.getOrDefault(node.id(), 0))
                    .toArray());
        }
        return result;
    }

    /** Returns max(counts) / sum(counts) as a double, or NaN when the counts sum to zero. */
    private double nodeCorrectProbability(int[] counts) {
        int total = 0;
        int max = 0;
        for (int count : counts) {
            total += count;
            max = Math.max(max, count);
        }
        return total == 0 ? Double.NaN : (double) max / total;
    }

    /** Returns the nodes chosen by the last {@link #haplotypeCountsToPath()} call, or null before any call. */
    public List<HaplotypeNode> nodesOnPath() {
        return this.nodesOnPath;
    }

    /** Returns the graph produced by the last {@code filterHaplotypeGraph} call. */
    public HaplotypeGraph filteredGraph() {
        return this.myGraph;
    }

    /** Returns the hapid -> inclusion count map (the stored instance, not a copy). */
    public Map<Integer, Integer> hapidCountMap() {
        return this.myHapidCountMap;
    }

    /**
     * Sets hapid counts from a multiset and derives the hapid -> count map from it.
     *
     * @param multiset hapid counts
     * @return this instance, for chaining
     */
    public ConvertGBSToSNPs hapidCounts(Multiset<Integer> multiset) {
        this.myHapidCounts = multiset;
        this.myHapidCountMap = multiset.entrySet().stream()
                .collect(Collectors.toMap(Multiset.Entry::getElement, Multiset.Entry::getCount));
        return this;
    }

    /** Sets the hapid -> inclusion count map (stored by reference). */
    public ConvertGBSToSNPs hapidCountMap(Map<Integer, Integer> map) {
        this.myHapidCountMap = map;
        return this;
    }

    /** Sets the hapid -> exclusion count map (stored by reference). */
    public ConvertGBSToSNPs hapidExclusionCountMap(Map<Integer, Integer> map) {
        this.myHapidExclusionCountMap = map;
        return this;
    }

    /** Sets the minimum total count a range needs to survive filtering. */
    public ConvertGBSToSNPs minReadsPerRange(int minReads) {
        this.minReadsPerRefRange = minReads;
        return this;
    }

    /** Sets the maximum single-node count threshold (per kb in {@code filterHaplotypeGraph(graph, list)}). */
    public ConvertGBSToSNPs maxReadsPerRangeKB(int maxReads) {
        this.maxReadsPerRefRangeKB = maxReads;
        return this;
    }

    /** Sets the taxa filter by name; ignored when a {@link TaxaList} filter is also set. */
    public ConvertGBSToSNPs taxaFilterList(String taxaList) {
        this.myTaxaListString = taxaList;
        return this;
    }

    /** Sets the taxa filter; takes precedence over a filter given by name. */
    public ConvertGBSToSNPs taxaFilterList(TaxaList taxaList) {
        this.myTaxaList = taxaList;
        return this;
    }

    /** Sets the value passed to {@code FilterGraphPlugin.minCountTaxa} (skipped when not positive). */
    public ConvertGBSToSNPs minTaxaPerRange(int minTaxa) {
        this.minTaxaPerRefRange = minTaxa;
        return this;
    }

    /** Sets the maximum number of nodes a range may have to survive filtering. */
    public ConvertGBSToSNPs maxNodesPerRange(int maxNodes) {
        this.maxNodesPerRange = maxNodes;
        return this;
    }

    /** Sets the value passed to the emission builder's {@code probabilityCorrect}. */
    public ConvertGBSToSNPs probabilityReadMappingCorrect(double probability) {
        this.probReadMappedCorrectly = probability;
        return this;
    }

    /** Sets the minimum transition probability passed to {@link ReferenceRangeTransitionProbability}. */
    public ConvertGBSToSNPs minTransitionProbability(double probability) {
        this.minTransitionProb = probability;
        return this;
    }

    /** Sets the emission probability method; non-inclusionOnly methods also use the exclusion counts. */
    public ConvertGBSToSNPs emissionProbabilityMethod(ReferenceRangeEmissionProbability.METHOD method) {
        this.myEmissionMethod = method;
        return this;
    }

    /** Sets whether filtering splits nodes by individual taxon. */
    public ConvertGBSToSNPs splitTaxa(boolean split) {
        this.splitTaxa = split;
        return this;
    }

    /** Sets the probability passed to {@code CreateGraphUtils.nodesSplitByIndividualTaxa}. */
    public ConvertGBSToSNPs transitionProbabilitySameTaxon(double probability) {
        this.transitionProbSameTaxon = probability;
        return this;
    }

    /** Sets the taxon reported in the {@code hasTarget} column of the probability report. */
    public ConvertGBSToSNPs targetTaxon(String taxon) {
        this.targetTaxon = taxon;
        return this;
    }

    // ---------------------------------------------------------------- private helpers

    /** Inclusion count for a hapid, or 0 when the hapid has no entry. */
    private int countFor(int hapid) {
        return this.myHapidCountMap.getOrDefault(hapid, 0);
    }

    /** Exclusion count for a hapid, or 0 when no exclusion map is set or the hapid has no entry. */
    private int exclusionCountFor(int hapid) {
        return this.myHapidExclusionCountMap == null ? 0 : this.myHapidExclusionCountMap.getOrDefault(hapid, 0);
    }

    /** Inclusion counts for the given nodes, in list order. */
    private int[] readCounts(List<HaplotypeNode> nodes) {
        return nodes.stream().mapToInt(node -> countFor(node.id())).toArray();
    }

    /** Largest single-node count allowed for a range: maxReadsPerRangeKB scaled by range length in kb, rounded up. */
    private int maxReadsAllowed(ReferenceRange referenceRange) {
        double lengthKb = ((referenceRange.end() - referenceRange.start()) + 1) / 1000.0d;
        return (int) Math.ceil(this.maxReadsPerRefRangeKB * lengthKb);
    }

    /** Applies the min-taxa and taxa-list settings shared by both filter overloads. */
    private void applyTaxaFilters(FilterGraphPlugin filterGraphPlugin) {
        if (this.minTaxaPerRefRange > 0) {
            filterGraphPlugin.minCountTaxa(Integer.valueOf(this.minTaxaPerRefRange));
        }
        if (this.myTaxaList != null) {
            filterGraphPlugin.taxaList(this.myTaxaList);
        } else if (this.myTaxaListString != null) {
            filterGraphPlugin.taxaList(this.myTaxaListString);
        }
    }

    /** Builds the emission model for one chromosome's tree and logs its setup time. */
    private ReferenceRangeEmissionProbability buildEmissionProbability(TreeMap<ReferenceRange, List<HaplotypeNode>> tree) {
        long startTime = System.currentTimeMillis();
        ReferenceRangeEmissionProbability.Builder builder = new ReferenceRangeEmissionProbability.Builder()
                .nodeMap(tree)
                .inclusionCountMap(this.myHapidCountMap)
                .probabilityCorrect(this.probReadMappedCorrectly)
                .method(this.myEmissionMethod);
        if (this.myEmissionMethod != ReferenceRangeEmissionProbability.METHOD.inclusionOnly) {
            builder.exclusionCountMap(this.myHapidExclusionCountMap);
        }
        ReferenceRangeEmissionProbability emission = builder.build();
        myLogger.info(String.format("emission probability set up in %d ms.", System.currentTimeMillis() - startTime));
        myLogger.info(emission.toString());
        return emission;
    }

    /** Builds the transition model for one chromosome's node lists and logs its setup time. */
    private ReferenceRangeTransitionProbability buildTransitionProbability(List<List<HaplotypeNode>> rangeNodes) {
        long startTime = System.currentTimeMillis();
        ReferenceRangeTransitionProbability transition =
                new ReferenceRangeTransitionProbability(rangeNodes, this.myGraph, this.minTransitionProb);
        myLogger.info(String.format("transition probability set up in %d ms.", System.currentTimeMillis() - startTime));
        return transition;
    }

    /** Index of the largest value (first one on ties); 0 for a single-element array. */
    private static int argMax(double[] values) {
        int best = 0;
        for (int i = 1; i < values.length; i++) {
            if (values[i] > values[best]) {
                best = i;
            }
        }
        return best;
    }

    /** Writes one tab-delimited line per node and range; failures to open the file are logged, not thrown. */
    private void writeProbabilityReport(String outputFile) {
        try (PrintWriter printWriter = new PrintWriter(outputFile)) {
            printWriter.println("chr\tstart\thasTarget\tprob\tnTaxa\tincludeCount\texcludeCount");
            Iterator<ReferenceRange> rangeIterator = this.myGraph.referenceRangeList().iterator();
            for (double[] gammas : this.pathGammas) {
                ReferenceRange range = rangeIterator.next();
                int nodeIndex = 0;
                for (HaplotypeNode node : this.myGraph.nodes(range)) {
                    printWriter.print(range.chromosome().getName() + "\t");
                    printWriter.print(Integer.toString(range.start()) + "\t");
                    printWriter.print(Boolean.toString(node.taxaList().indexOf(this.targetTaxon) >= 0) + "\t");
                    printWriter.print(gammas[nodeIndex++] + "\t");
                    printWriter.print(node.numTaxa() + "\t");
                    printWriter.print(countFor(node.id()) + "\t");
                    printWriter.println(exclusionCountFor(node.id()));
                }
            }
        } catch (FileNotFoundException e) {
            myLogger.error(String.format("Unable to open %s for output in ConvertGBSToSNPs.nodeListFromProbabilities", outputFile), e);
        }
    }
}
