/*
 * Decompiled with CFR 0.152.
 */
package cc.mallet.types;

import cc.mallet.classify.Classification;
import cc.mallet.classify.Trial;
import cc.mallet.types.Alphabet;
import cc.mallet.types.AlphabetCarrying;
import cc.mallet.types.Label;
import cc.mallet.types.LabelAlphabet;
import cc.mallet.types.LabelVector;
import java.io.Serializable;
import java.text.DecimalFormat;
import java.util.Arrays;

/**
 * Accumulates per-label confusion counts at a fixed set of score thresholds. Precision, recall
 * and positive-rate statistics can be read from these counts.
 *
 * <p>For label index {@code l} and threshold index {@code t}, {@code counts[l][t]} is a
 * four-element array indexed by {@link #TRUE_POSITIVE}, {@link #FALSE_POSITIVE},
 * {@link #FALSE_NEGATIVE} and {@link #TRUE_NEGATIVE}. When a classification is added, a label
 * counts as "predicted" at threshold {@code t} if its score is {@code >= thresholds[t]}. It
 * counts as "correct" if it is the best label of the instance's labeling.
 */
public class ROCData implements AlphabetCarrying, Serializable {
    private static final long serialVersionUID = -2060194953037720640L;
    /** Index of the true-positive cell in a per-threshold count array. */
    public static final int TRUE_POSITIVE = 0;
    /** Index of the false-positive cell in a per-threshold count array. */
    public static final int FALSE_POSITIVE = 1;
    /** Index of the false-negative cell in a per-threshold count array. */
    public static final int FALSE_NEGATIVE = 2;
    /** Index of the true-negative cell in a per-threshold count array. */
    public static final int TRUE_NEGATIVE = 3;
    /** Number of confusion cells stored per (label, threshold) pair. */
    private static final int NUM_CELLS = 4;

    private final LabelAlphabet labelAlphabet;
    /** Counts indexed as [label index][threshold index][confusion cell]. */
    private final int[][][] counts;
    /** Score thresholds, kept sorted in ascending order. */
    private final double[] thresholds;

    /**
     * Creates an empty accumulator.
     *
     * @param thresholds score thresholds at which counts are collected. They are copied and
     *     sorted ascending; the caller's array is not modified.
     * @param labelAlphabet alphabet of the labels being evaluated
     */
    public ROCData(double[] thresholds, LabelAlphabet labelAlphabet) {
        // Copy before sorting so the caller's array is not reordered and later changes to it
        // cannot break the sorted order that add() and the binary searches rely on.
        this.thresholds = thresholds.clone();
        Arrays.sort(this.thresholds);
        this.counts = new int[labelAlphabet.size()][this.thresholds.length][NUM_CELLS];
        this.labelAlphabet = labelAlphabet;
    }

    /**
     * Records one classification in the counts of every label at every threshold.
     *
     * @param classification the classification to record. The best index of its instance's
     *     labeling is taken as the correct label.
     * @throws IllegalArgumentException if the classification's label vector does not share
     *     this object's alphabet
     */
    public void add(Classification classification) {
        int correctIndex = classification.getInstance().getLabeling().getBestIndex();
        LabelVector lv = classification.getLabelVector();
        double[] values = lv.getValues();
        if (!Alphabet.alphabetsMatch(this, lv)) {
            throw new IllegalArgumentException("Alphabets do not match");
        }
        int numLabels = this.labelAlphabet.size();
        for (int label = 0; label < numLabels; label++) {
            boolean isCorrect = label == correctIndex;
            double labelValue = values[label];
            int[][] thresholdCounts = this.counts[label];
            for (int t = 0; t < this.thresholds.length; t++) {
                int cell;
                if (labelValue >= this.thresholds[t]) {
                    cell = isCorrect ? TRUE_POSITIVE : FALSE_POSITIVE;
                } else {
                    cell = isCorrect ? FALSE_NEGATIVE : TRUE_NEGATIVE;
                }
                thresholdCounts[t][cell]++;
            }
        }
    }

    /**
     * Records every classification contained in a trial.
     *
     * @param trial the classifications to record
     */
    public void add(Trial trial) {
        for (Classification classification : trial) {
            this.add(classification);
        }
    }

    /**
     * Adds another accumulator's counts into this one, cell by cell.
     *
     * @param rocData the accumulator whose counts are added; it is not modified
     * @throws IllegalArgumentException if the alphabets or the thresholds differ
     */
    public void add(ROCData rocData) {
        if (!Alphabet.alphabetsMatch(this, rocData)) {
            throw new IllegalArgumentException("Alphabets do not match");
        }
        if (!Arrays.equals(this.thresholds, rocData.thresholds)) {
            throw new IllegalArgumentException("Thresholds do not match");
        }
        for (int c = 0; c < this.counts.length; c++) {
            int[][] thisClassCounts = this.counts[c];
            int[][] otherClassCounts = rocData.counts[c];
            for (int t = 0; t < thisClassCounts.length; t++) {
                int[] thisThrCounts = thisClassCounts[t];
                int[] otherThrCounts = otherClassCounts[t];
                for (int s = 0; s < thisThrCounts.length; s++) {
                    thisThrCounts[s] += otherThrCounts[s];
                }
            }
        }
    }

    @Override
    public Alphabet getAlphabet() {
        return this.labelAlphabet;
    }

    @Override
    public Alphabet[] getAlphabets() {
        return new Alphabet[]{this.labelAlphabet};
    }

    /**
     * Returns the count table of one label, indexed as [threshold index][confusion cell].
     *
     * <p>The returned array is the live internal storage. Modifying it changes this object.
     *
     * @param label the label whose counts are requested
     * @return the per-threshold confusion counts for {@code label}
     */
    public int[][] getCounts(Label label) {
        return this.counts[label.getIndex()];
    }

    /**
     * Returns the confusion counts for a label at the largest configured threshold that is
     * {@code <= threshold}.
     *
     * <p>The returned array is the live internal storage. Modifying it changes this object.
     *
     * @param label the label whose counts are requested
     * @param threshold the query threshold
     * @return a four-element array indexed by the {@code TRUE_POSITIVE}..{@code TRUE_NEGATIVE}
     *     constants
     * @throws IllegalArgumentException if {@code threshold} is below every configured threshold
     */
    public int[] getCounts(Label label, double threshold) {
        return this.counts[label.getIndex()][this.thresholdIndex(threshold)];
    }

    public LabelAlphabet getLabelAlphabet() {
        return this.labelAlphabet;
    }

    /**
     * Computes TP / (TP + FP) for a label at the given threshold.
     *
     * @return the precision, or {@code NaN} if nothing was predicted positive
     * @throws IllegalArgumentException if {@code threshold} is below every configured threshold
     */
    public double getPrecision(Label label, double threshold) {
        int[] cells = this.getCounts(label, threshold);
        return (double) cells[TRUE_POSITIVE]
            / (double) (cells[TRUE_POSITIVE] + cells[FALSE_POSITIVE]);
    }

    /**
     * Computes precision restricted to instances whose score for {@code label} falls into the
     * threshold bucket containing {@code score}. That bucket runs from the largest threshold
     * {@code <= score} up to (but excluding) the next threshold. The last bucket is unbounded
     * above.
     *
     * @return the bucket precision, or {@code NaN} if the bucket holds no predicted positives
     * @throws IllegalArgumentException if {@code score} is below every configured threshold
     */
    public double getPrecisionForScore(Label label, double score) {
        int[][] buckets = this.counts[label.getIndex()];
        int index = this.thresholdIndex(score);
        double tp = buckets[index][TRUE_POSITIVE];
        double fp = buckets[index][FALSE_POSITIVE];
        if (index < this.thresholds.length - 1) {
            // Counts recorded by add() are cumulative ("score >= threshold"). Subtracting the
            // next threshold's counts isolates this bucket.
            tp -= buckets[index + 1][TRUE_POSITIVE];
            fp -= buckets[index + 1][FALSE_POSITIVE];
        }
        return tp / (tp + fp);
    }

    /**
     * Computes the percentage of recorded instances that were predicted positive for
     * {@code label} at the given threshold.
     *
     * @throws IllegalArgumentException if {@code threshold} is below every configured threshold
     */
    public double getPositivePercent(Label label, double threshold) {
        int[] cells = this.getCounts(label, threshold);
        int positive = cells[TRUE_POSITIVE] + cells[FALSE_POSITIVE];
        int total = positive + cells[FALSE_NEGATIVE] + cells[TRUE_NEGATIVE];
        return (double) positive / (double) total * 100.0;
    }

    /**
     * Computes TP / (TP + FN) for a label at the given threshold.
     *
     * @return the recall, or {@code NaN} if there were no actual positives
     * @throws IllegalArgumentException if {@code threshold} is below every configured threshold
     */
    public double getRecall(Label label, double threshold) {
        int[] cells = this.getCounts(label, threshold);
        return (double) cells[TRUE_POSITIVE]
            / (double) (cells[TRUE_POSITIVE] + cells[FALSE_NEGATIVE]);
    }

    /**
     * Returns the configured thresholds in ascending order.
     *
     * @return a copy, so callers cannot disturb the internal sorted order
     */
    public double[] getThresholds() {
        return this.thresholds.clone();
    }

    /**
     * Overwrites the confusion counts for a label at the largest configured threshold that is
     * {@code <= threshold}.
     *
     * @param newCounts replacement counts; must have exactly four elements
     * @throws IllegalArgumentException if {@code newCounts} has the wrong length or
     *     {@code threshold} is below every configured threshold
     */
    public void setCounts(Label label, double threshold, int[] newCounts) {
        int[] oldCounts = this.counts[label.getIndex()][this.thresholdIndex(threshold)];
        if (newCounts.length != oldCounts.length) {
            throw new IllegalArgumentException("Array of counts must contain " + oldCounts.length + " elements.");
        }
        System.arraycopy(newCounts, 0, oldCounts, 0, oldCounts.length);
    }

    /**
     * Maps a query threshold to the index of the largest configured threshold that is
     * {@code <= threshold}.
     *
     * @throws IllegalArgumentException if no thresholds are configured, or the query is below
     *     the smallest one (previously this surfaced as an ArrayIndexOutOfBoundsException)
     */
    private int thresholdIndex(double threshold) {
        if (this.thresholds.length == 0) {
            throw new IllegalArgumentException("No thresholds are configured");
        }
        int index = Arrays.binarySearch(this.thresholds, threshold);
        if (index < 0) {
            // binarySearch returns -(insertionPoint) - 1 on a miss, so this yields
            // insertionPoint - 1: the largest threshold below the query.
            index = -index - 2;
        }
        if (index < 0) {
            throw new IllegalArgumentException("Threshold " + threshold
                + " is below the smallest configured threshold " + this.thresholds[0]);
        }
        return index;
    }

    /**
     * Renders one tab-separated table per label. Each table has one row per threshold with the
     * raw confusion counts followed by precision and recall. Precision and recall are shown as
     * 0 when their denominators are 0.
     */
    @Override
    public String toString() {
        StringBuilder buf = new StringBuilder();
        DecimalFormat format = new DecimalFormat("0.####");
        for (int i = 0; i < this.labelAlphabet.size(); i++) {
            int[][] labelData = this.counts[i];
            buf.append("ROC data for ");
            buf.append(this.labelAlphabet.lookupObject(i).toString());
            buf.append('\n');
            buf.append("THR\tTP\tFP\tFN\tTN\tPrecis\tRecall\n");
            for (int t = 0; t < this.thresholds.length; t++) {
                int[] cells = labelData[t];
                buf.append(this.thresholds[t]);
                for (int count : cells) {
                    buf.append('\t').append(count);
                }
                double tp = cells[TRUE_POSITIVE];
                double predictedPositive = tp + (double) cells[FALSE_POSITIVE];
                double precision = predictedPositive != 0.0 ? tp / predictedPositive : 0.0;
                double actualPositive = tp + (double) cells[FALSE_NEGATIVE];
                double recall = actualPositive != 0.0 ? tp / actualPositive : 0.0;
                buf.append('\t').append(format.format(precision));
                buf.append('\t').append(format.format(recall));
                buf.append('\n');
            }
            buf.append('\n');
        }
        return buf.toString();
    }
}

