/*
 * Decompiled with CFR 0.152.
 */
package com.aliasi.classify;

import com.aliasi.classify.ScoredClassification;
import com.aliasi.stats.Statistics;
import com.aliasi.util.Math;
import com.aliasi.util.Pair;
import com.aliasi.util.ScoredObject;
import java.util.Arrays;

/**
 * A scored classification whose ranked categories additionally carry
 * conditional probabilities {@code P(category|input)}.
 *
 * <p>The constructors validate that every conditional probability lies in
 * {@code [0.0, 1.0]} (NaN is rejected) and that the probabilities sum to 1.0
 * within a tolerance ({@code 0.01} unless specified). The conditional
 * probability array is copied on construction, so later changes to the
 * caller's array do not affect this object.
 *
 * <p>The static factories {@link #createLogProbs(String[], double[])} and
 * {@link #createProbs(String[], double[])} build instances from unnormalized
 * base-2 log probabilities or probability ratios respectively.
 */
public class ConditionalClassification
extends ScoredClassification {
    private final double[] mConditionalProbs;

    /** Default tolerance for the sum-to-one check on conditional probabilities. */
    private static final double TOLERANCE = 0.01;

    /**
     * Constructs a classification that uses the conditional probabilities as
     * scores, checked with the default tolerance.
     *
     * @param categories categories, indexed by rank
     * @param conditionalProbs conditional probability (and score) of each category
     * @throws IllegalArgumentException if the probabilities are invalid
     */
    public ConditionalClassification(String[] categories, double[] conditionalProbs) {
        this(categories, conditionalProbs, conditionalProbs, TOLERANCE);
    }

    /**
     * Constructs a classification with separate scores and conditional
     * probabilities, checked with the default tolerance.
     *
     * @param categories categories, indexed by rank
     * @param scores score of each category
     * @param conditionalProbs conditional probability of each category
     * @throws IllegalArgumentException if the probabilities are invalid
     */
    public ConditionalClassification(String[] categories, double[] scores, double[] conditionalProbs) {
        this(categories, scores, conditionalProbs, TOLERANCE);
    }

    /**
     * Constructs a classification that uses the conditional probabilities as
     * scores, checked with the given tolerance.
     *
     * @param categories categories, indexed by rank
     * @param conditionalProbs conditional probability (and score) of each category
     * @param tolerance allowed absolute deviation of the probability sum from 1.0
     * @throws IllegalArgumentException if the tolerance or probabilities are invalid
     */
    public ConditionalClassification(String[] categories, double[] conditionalProbs, double tolerance) {
        this(categories, conditionalProbs, conditionalProbs, tolerance);
    }

    /**
     * Constructs a classification with separate scores and conditional
     * probabilities, checked with the given tolerance.
     *
     * @param categories categories, indexed by rank
     * @param scores score of each category
     * @param conditionalProbs conditional probability of each category; copied
     * @param tolerance allowed absolute deviation of the probability sum from 1.0
     * @throws IllegalArgumentException if {@code tolerance} is negative or NaN,
     *     if {@code conditionalProbs} and {@code categories} differ in length,
     *     if any probability is outside {@code [0.0, 1.0]} or NaN, or if the
     *     probabilities do not sum to 1.0 within {@code tolerance}
     */
    public ConditionalClassification(String[] categories, double[] scores, double[] conditionalProbs, double tolerance) {
        super(categories, scores);
        if (tolerance < 0.0 || Double.isNaN(tolerance)) {
            String msg = "Tolerance must be a non-negative number. Found tolerance=" + tolerance;
            throw new IllegalArgumentException(msg);
        }
        if (conditionalProbs.length != categories.length) {
            String msg = "Arrays must be same length. Found categories.length=" + categories.length
                + " conditionalProbs.length=" + conditionalProbs.length;
            throw new IllegalArgumentException(msg);
        }
        for (int i = 0; i < conditionalProbs.length; ++i) {
            double p = conditionalProbs[i];
            // Negated in-range test so that NaN, which fails every comparison, is rejected too.
            if (!(p >= 0.0 && p <= 1.0)) {
                String msg = "Conditional probabilities must be  between 0.0 and 1.0. Found conditionalProbs[" + i + "]=" + p;
                throw new IllegalArgumentException(msg);
            }
        }
        double sum = Math.sum(conditionalProbs);
        if (sum < 1.0 - tolerance || sum > 1.0 + tolerance) {
            String msg = "Conditional probabilities must sum to 1.0. Acceptable tolerance=" + tolerance + " Found sum=" + sum;
            throw new IllegalArgumentException(msg);
        }
        // Defensive copy: the validated values must not change behind our back.
        this.mConditionalProbs = conditionalProbs.clone();
    }

    /**
     * Returns the conditional probability of the category at the given rank.
     *
     * @param rank zero-based rank
     * @return the conditional probability at {@code rank}
     * @throws IllegalArgumentException if {@code rank} is out of range
     */
    public double conditionalProbability(int rank) {
        if (rank < 0 || rank > this.mConditionalProbs.length - 1) {
            String msg = "Require rank in range 0.." + (this.mConditionalProbs.length - 1) + " Found rank=" + rank;
            throw new IllegalArgumentException(msg);
        }
        return this.mConditionalProbs[rank];
    }

    /**
     * Returns the conditional probability of the first rank whose category
     * equals {@code category}.
     *
     * @param category category to look up
     * @return its conditional probability
     * @throws IllegalArgumentException if no rank has that category; the
     *     message lists the valid categories
     */
    public double conditionalProbability(String category) {
        for (int rank = 0; rank < this.size(); ++rank) {
            if (this.category(rank).equals(category)) {
                return this.conditionalProbability(rank);
            }
        }
        StringBuilder msg = new StringBuilder();
        msg.append(category).append(" is not a valid category for this classification.  Valid categories are:");
        for (int rank = 0; rank < this.size(); ++rank) {
            msg.append(' ').append(this.category(rank)).append(',');
        }
        // Drop the trailing separator (the final ',' or, with no categories, the ':').
        msg.setLength(msg.length() - 1);
        throw new IllegalArgumentException(msg.toString());
    }

    /**
     * Returns a header line followed by one line per rank of the form
     * {@code rank=category score conditionalProbability}.
     */
    @Override
    public String toString() {
        StringBuilder sb = new StringBuilder();
        sb.append("Rank  Category  Score  P(Category|Input)\n");
        for (int i = 0; i < this.size(); ++i) {
            sb.append(i).append('=').append(this.category(i))
                .append(' ').append(this.score(i))
                .append(' ').append(this.conditionalProbability(i))
                .append('\n');
        }
        return sb.toString();
    }

    /**
     * Creates a classification from base-2 log (joint) probabilities. They are
     * converted to conditional probabilities, and the categories are sorted by
     * descending probability. If every log probability is negative infinity,
     * the result is uniform.
     *
     * @param categories categories, parallel to {@code logProbabilities}
     * @param logProbabilities base-2 log probabilities; must be non-positive and not NaN
     * @return the sorted conditional classification
     * @throws IllegalArgumentException if the arrays differ in length or any log prob is invalid
     */
    public static ConditionalClassification createLogProbs(String[] categories, double[] logProbabilities) {
        ConditionalClassification.verifyLengths(categories, logProbabilities);
        ConditionalClassification.verifyLogProbs(logProbabilities);
        double[] linearProbs = ConditionalClassification.logJointToConditional(logProbabilities);
        Pair<String[], double[]> catsProbs = ConditionalClassification.sort(categories, linearProbs);
        return new ConditionalClassification(catsProbs.a(), catsProbs.b());
    }

    /**
     * Creates a classification from non-negative, finite probability ratios
     * (unnormalized probabilities). If all ratios are zero, each category gets
     * probability {@code 1/n} in the given order. Otherwise the result is
     * sorted by descending probability.
     *
     * @param categories categories, parallel to {@code probabilityRatios}
     * @param probabilityRatios non-negative finite ratios
     * @return the conditional classification
     * @throws IllegalArgumentException if the arrays differ in length or any ratio is invalid
     */
    public static ConditionalClassification createProbs(String[] categories, double[] probabilityRatios) {
        // Check lengths up front so both branches below report mismatches the same way.
        ConditionalClassification.verifyLengths(categories, probabilityRatios);
        for (int i = 0; i < probabilityRatios.length; ++i) {
            if (probabilityRatios[i] < 0.0 || Double.isInfinite(probabilityRatios[i]) || Double.isNaN(probabilityRatios[i])) {
                String msg = "Probability ratios must be non-negative and finite. Found probabilityRatios[" + i + "]=" + probabilityRatios[i];
                throw new IllegalArgumentException(msg);
            }
        }
        if (Math.sum(probabilityRatios) == 0.0) {
            double[] conditionalProbs = new double[probabilityRatios.length];
            Arrays.fill(conditionalProbs, 1.0 / (double) probabilityRatios.length);
            return new ConditionalClassification(categories, conditionalProbs);
        }
        double[] logProbs = new double[probabilityRatios.length];
        for (int i = 0; i < probabilityRatios.length; ++i) {
            logProbs[i] = Math.log2(probabilityRatios[i]);
        }
        return ConditionalClassification.createLogProbs(categories, logProbs);
    }

    /**
     * Checks that every log probability is non-positive and not NaN.
     *
     * @throws IllegalArgumentException on the first offending value
     */
    static void verifyLogProbs(double[] logProbabilities) {
        for (double x : logProbabilities) {
            if (Double.isNaN(x) || x > 0.0) {
                String msg = "Log probs must be non-positive numbers. Found x=" + x;
                throw new IllegalArgumentException(msg);
            }
        }
    }

    /**
     * Checks that the two parallel arrays have the same length.
     *
     * @throws IllegalArgumentException if the lengths differ
     */
    static void verifyLengths(String[] categories, double[] logProbabilities) {
        if (categories.length != logProbabilities.length) {
            String msg = "Arrays must be same length. Found categories.length=" + categories.length + " logProbabilities.length=" + logProbabilities.length;
            throw new IllegalArgumentException(msg);
        }
    }

    /**
     * Returns new copies of the parallel arrays, sorted by descending value.
     * Sorting uses {@link Arrays#sort(Object[], java.util.Comparator)}, which
     * is stable, so tied values keep their input order. The inputs are not
     * modified.
     *
     * @throws IllegalArgumentException if the arrays differ in length
     */
    static Pair<String[], double[]> sort(String[] categories, double[] vals) {
        ConditionalClassification.verifyLengths(categories, vals);
        ScoredObject[] scoredObjects = new ScoredObject[categories.length];
        for (int i = 0; i < categories.length; ++i) {
            scoredObjects[i] = new ScoredObject<String>(categories[i], vals[i]);
        }
        Arrays.sort(scoredObjects, ScoredObject.reverseComparator());
        String[] categoriesSorted = new String[scoredObjects.length];
        double[] valsSorted = new double[scoredObjects.length];
        for (int i = 0; i < scoredObjects.length; ++i) {
            categoriesSorted[i] = (String) scoredObjects[i].getObject();
            valsSorted[i] = scoredObjects[i].score();
        }
        return new Pair<String[], double[]>(categoriesSorted, valsSorted);
    }

    /**
     * Converts base-2 log joint probabilities into normalized conditional
     * probabilities. Values are shifted by the maximum before exponentiating
     * so the largest ratio is 1.0, avoiding underflow for very negative inputs.
     *
     * <p>Side effect: entries in the open interval {@code (0, 1e-10)} are
     * clamped to {@code 0.0} in the argument array itself.
     *
     * <p>If every entry is negative infinity (all joint probabilities zero),
     * a uniform distribution is returned, the same convention
     * {@link #createProbs(String[], double[])} uses for all-zero ratios.
     *
     * @param logJointProbs base-2 log joint probabilities; modified as described above
     * @return conditional probabilities, parallel to the input
     * @throws IllegalArgumentException if any entry is positive (after clamping) or NaN
     */
    static double[] logJointToConditional(double[] logJointProbs) {
        for (int i = 0; i < logJointProbs.length; ++i) {
            // Tolerate tiny positive values (presumably rounding noise) by clamping them in place.
            if (logJointProbs[i] > 0.0 && logJointProbs[i] < 1.0E-10) {
                logJointProbs[i] = 0.0;
            }
            if (logJointProbs[i] > 0.0 || Double.isNaN(logJointProbs[i])) {
                StringBuilder sb = new StringBuilder();
                sb.append("Joint probs must be zero or negative. Found log2JointProbs[" + i + "]=" + logJointProbs[i]);
                for (int k = 0; k < logJointProbs.length; ++k) {
                    sb.append("\nlogJointProbs[" + k + "]=" + logJointProbs[k]);
                }
                throw new IllegalArgumentException(sb.toString());
            }
        }
        double max = Math.maximum(logJointProbs);
        if (max == Double.NEGATIVE_INFINITY) {
            // Without this, (-inf) - (-inf) is NaN for every entry and normalization
            // would be asked to scale an all-zero vector.
            double[] uniform = new double[logJointProbs.length];
            Arrays.fill(uniform, 1.0 / (double) logJointProbs.length);
            return uniform;
        }
        double[] probRatios = new double[logJointProbs.length];
        for (int i = 0; i < logJointProbs.length; ++i) {
            double ratio = java.lang.Math.pow(2.0, logJointProbs[i] - max);
            // Defensive guards: after the checks above the exponent is <= 0, so ratio is in [0, 1].
            if (ratio == Double.POSITIVE_INFINITY) {
                ratio = (double) Float.MAX_VALUE;
            } else if (ratio == Double.NEGATIVE_INFINITY || Double.isNaN(ratio)) {
                ratio = 0.0;
            }
            probRatios[i] = ratio;
        }
        return Statistics.normalize(probRatios);
    }
}

