/*
 * Decompiled with CFR 0.152.
 */
package com.aliasi.classify;

import com.aliasi.classify.ConditionalClassifierEvaluator;
import com.aliasi.classify.JointClassification;
import com.aliasi.classify.JointClassifier;
import com.aliasi.util.Math;

/**
 * An evaluator for {@link JointClassifier}s.
 *
 * <p>On top of the statistics reported by {@link ConditionalClassifierEvaluator},
 * this evaluator reports averages of the joint log (base 2) probabilities
 * assigned by the classifier, and the total log (base 2) probability of the
 * evaluated corpus.
 *
 * <p>All statistics are computed from the reference categories and the
 * classifications stored by the superclass. Each stored classification is
 * cast to {@link JointClassification}.
 *
 * @param <E> the type of objects being classified
 */
public class JointClassifierEvaluator<E>
extends ConditionalClassifierEvaluator<E> {

    /**
     * Construct an evaluator for the specified joint classifier over the
     * specified categories.
     *
     * @param classifier the joint classifier to evaluate
     * @param categories the categories being evaluated
     * @param storeInputs flag passed through unchanged to the superclass
     *        constructor
     */
    public JointClassifierEvaluator(JointClassifier<E> classifier, String[] categories, boolean storeInputs) {
        super(classifier, categories, storeInputs);
    }

    /**
     * Set the classifier being evaluated.
     *
     * @param classifier the new joint classifier to evaluate
     */
    @Override
    public void setClassifier(JointClassifier<E> classifier) {
        this.setClassifier(classifier, JointClassifierEvaluator.class);
    }

    /**
     * Return the joint classifier being evaluated.
     *
     * @return the classifier being evaluated
     */
    @Override
    @SuppressWarnings("unchecked") // the constructor and setClassifier above only accept JointClassifier<E>
    public JointClassifier<E> classifier() {
        return (JointClassifier<E>) super.classifier();
    }

    /**
     * Return the average joint log (base 2) probability assigned to the
     * response category, over all cases whose reference category is the
     * specified one.
     *
     * <p>Only cases whose classification actually ranks the response category
     * contribute to the average. If no case contributes, the result is
     * {@code 0.0 / 0}, i.e. {@link Double#NaN}.
     *
     * @param refCategory the reference category selecting the cases
     * @param responseCategory the category whose joint probability is averaged
     * @return the average joint log2 probability of the response category
     */
    public double averageLog2JointProbability(String refCategory, String responseCategory) {
        this.validateCategory(refCategory);
        this.validateCategory(responseCategory);
        double sum = 0.0;
        int count = 0;
        for (int i = 0; i < this.mReferenceCategories.size(); ++i) {
            if (!((String) this.mReferenceCategories.get(i)).equals(refCategory)) {
                continue;
            }
            JointClassification c = (JointClassification) this.mClassifications.get(i);
            int rank = rankOf(c, responseCategory);
            if (rank >= 0) {
                sum += c.jointLog2Probability(rank);
                ++count;
            }
        }
        return sum / (double) count;
    }

    /**
     * Return the average, over all evaluated cases, of the joint log (base 2)
     * probability that the classification assigns to the case's reference
     * category.
     *
     * <p>The denominator is the total number of cases. A case whose
     * classification does not rank its reference category contributes
     * {@code 0.0} to the sum but is still counted. With no cases, the result is
     * {@link Double#NaN}.
     *
     * @return the average joint log2 probability of the reference categories
     */
    public double averageLog2JointProbabilityReference() {
        int numCases = this.mReferenceCategories.size();
        double sum = 0.0;
        for (int i = 0; i < numCases; ++i) {
            String refCategory = (String) this.mReferenceCategories.get(i);
            JointClassification c = (JointClassification) this.mClassifications.get(i);
            int rank = rankOf(c, refCategory);
            if (rank >= 0) {
                sum += c.jointLog2Probability(rank);
            }
        }
        return sum / (double) numCases;
    }

    /**
     * Return the total log (base 2) probability of the evaluated corpus.
     *
     * <p>This is the sum, over all stored classifications, of log2 of the sum
     * of that classification's joint probabilities. A case whose joint
     * probabilities are all zero (or which ranks no categories) contributes
     * {@link Double#NEGATIVE_INFINITY}.
     *
     * @return the corpus log2 joint probability
     */
    public double corpusLog2JointProbability() {
        double total = 0.0;
        for (int i = 0; i < this.mClassifications.size(); ++i) {
            JointClassification c = (JointClassification) this.mClassifications.get(i);
            total += log2SumOfJointProbabilities(c);
        }
        return total;
    }

    /**
     * Return the rank at which the classification lists the category, or
     * {@code -1} if it is not listed. The first matching rank is returned.
     */
    private static int rankOf(JointClassification c, String category) {
        for (int rank = 0; rank < c.size(); ++rank) {
            if (c.category(rank).equals(category)) {
                return rank;
            }
        }
        return -1;
    }

    /**
     * Return log2 of the sum of 2^x over the classification's joint log2
     * probabilities x. It shifts by the maximum before exponentiating so that
     * small probabilities do not underflow.
     */
    private static double log2SumOfJointProbabilities(JointClassification c) {
        double maxJointLog2P = Double.NEGATIVE_INFINITY;
        for (int rank = 0; rank < c.size(); ++rank) {
            double jointLog2P = c.jointLog2Probability(rank);
            if (jointLog2P > maxJointLog2P) {
                maxJointLog2P = jointLog2P;
            }
        }
        // If the maximum is infinite, shifting by it would compute
        // Inf - Inf = NaN. An all-zero (or empty) distribution sums to
        // probability 0, i.e. log2 = -Infinity, so return the maximum directly.
        if (Double.isInfinite(maxJointLog2P)) {
            return maxJointLog2P;
        }
        double sum = 0.0;
        for (int rank = 0; rank < c.size(); ++rank) {
            sum += java.lang.Math.pow(2.0, c.jointLog2Probability(rank) - maxJointLog2P);
        }
        return maxJointLog2P + Math.log2(sum);
    }

    @Override
    void baseToString(StringBuilder sb) {
        super.baseToString(sb);
        sb.append("Average Log2 Joint Probability Reference=")
          .append(this.averageLog2JointProbabilityReference())
          .append("\n");
    }

    @Override
    void oneVsAllToString(StringBuilder sb, String category, int i) {
        super.oneVsAllToString(sb, category, i);
        sb.append("Average Joint Probability Histogram=\n");
        this.appendCategoryLine(sb);
        // Fetch the category array once rather than on every iteration.
        String[] categories = this.categories();
        for (int j = 0; j < this.numCategories(); ++j) {
            if (j > 0) {
                sb.append(',');
            }
            sb.append(this.averageLog2JointProbability(category, categories[j]));
        }
        sb.append("\n");
    }
}

