/*
 * Decompiled with CFR 0.152.
 */
package cc.mallet.types;

import cc.mallet.classify.Classification;
import cc.mallet.types.FeatureVector;
import cc.mallet.types.Instance;
import cc.mallet.types.InstanceList;
import cc.mallet.types.LabelVector;
import cc.mallet.types.Labeling;
import cc.mallet.types.RankedFeatureVector;
import cc.mallet.util.MalletLogger;
import java.util.logging.Logger;

/**
 * A {@link RankedFeatureVector} that scores each feature by comparing an {@link InstanceList}'s
 * true labels with a model's predicted label distributions.
 *
 * <p>For every class {@code c} and feature {@code f} the computation builds:
 * <ul>
 *   <li>{@code p[c][f]}: a uniform smoothing value plus the sum, over every stored location of
 *       {@code f} in each instance, of the true label weight of {@code c} divided by
 *       {@code numInstances + 1};</li>
 *   <li>{@code q[c][f]}: the same quantity using the model's predicted weight for {@code c};</li>
 *   <li>{@code alphas[c][f] = log(p * (1 - q) / (q * (1 - p)))};</li>
 *   <li>{@code qeag[c][f]}: the sum of {@code w * e^alpha - w} over those same locations, where
 *       {@code w} is the predicted weight divided by {@code numInstances}, plus the total of all
 *       such {@code w}.</li>
 * </ul>
 * The score of {@code f} is the sum, over classes whose {@code alpha} is positive and finite, of
 * {@code alpha * p - log(qeag)}.
 *
 * <p>Feature values are expected to be binary (every stored value equal to 1.0). This is only
 * checked by assertions.
 */
public class KLGain
extends RankedFeatureVector {
    private static final Logger logger = MalletLogger.getLogger(KLGain.class.getName());

    /**
     * Divisor used to derive the diagnostic logging stride: roughly one feature is logged per
     * {@code 1 / LOG_SAMPLE_DIVISOR} of the data alphabet.
     */
    private static final int LOG_SAMPLE_DIVISOR = 100;

    /**
     * Computes the score of every feature in {@code ilist}'s data alphabet.
     *
     * @param ilist instances whose data are {@link FeatureVector}s and whose targets are labelings
     * @param classifications predicted label distributions, aligned by index with {@code ilist}
     * @return an array indexed by feature index holding each feature's score
     * @throws IllegalArgumentException if fewer classifications than instances are supplied
     */
    private static double[] calcKLGains(InstanceList ilist, LabelVector[] classifications) {
        int numInstances = ilist.size();
        int numClasses = ilist.getTargetAlphabet().size();
        int numFeatures = ilist.getDataAlphabet().size();
        assert (numInstances > 0);
        if (classifications.length < numInstances) {
            throw new IllegalArgumentException("Expected at least " + numInstances
                    + " classifications (one per instance), got " + classifications.length);
        }
        logger.info("Starting klgains, #instances=" + numInstances);

        double[][] p = new double[numClasses][numFeatures];
        double[][] q = new double[numClasses][numFeatures];
        accumulateLabelWeights(ilist, classifications, numClasses, numFeatures, p, q);
        double[][] alphas = computeAlphas(p, q, numClasses, numFeatures);
        double[][] qeag = computeQeag(ilist, classifications, alphas, numClasses, numFeatures);

        double[] klgains = new double[numFeatures];
        for (int c = 0; c < numClasses; c++) {
            for (int f = 0; f < numFeatures; f++) {
                double alpha = alphas[c][f];
                // Only cells where p exceeds q (alpha > 0) and alpha is finite contribute.
                if (alpha > 0.0 && !Double.isInfinite(alpha)) {
                    klgains[f] += alpha * p[c][f] - Math.log(qeag[c][f]);
                }
            }
        }
        logSampledStatistics(ilist, p, q, alphas, qeag, klgains, numClasses);
        return klgains;
    }

    /**
     * Fills {@code p} (true-label weights) and {@code q} (predicted-label weights) with smoothed,
     * per-feature class weights. The asserted totals check that each set of weights sums to
     * about 1.
     */
    private static void accumulateLabelWeights(InstanceList ilist, LabelVector[] classifications,
            int numClasses, int numFeatures, double[][] p, double[][] q) {
        int numInstances = ilist.size();
        // The data is treated as numInstances + 1 instances. The extra instance's mass is spread
        // uniformly over every (class, feature) cell as smoothing.
        double numInExpectation = numInstances + 1.0;
        double smoothing = 1.0 / (numInExpectation * numFeatures * numClasses);
        double trueLabelWeightSum = 0.0;
        double modelLabelWeightSum = 0.0;
        for (int c = 0; c < numClasses; c++) {
            for (int f = 0; f < numFeatures; f++) {
                p[c][f] = smoothing;
                q[c][f] = smoothing;
                trueLabelWeightSum += smoothing;
                modelLabelWeightSum += smoothing;
            }
        }

        for (int i = 0; i < numInstances; i++) {
            LabelVector predicted = classifications[i];
            assert (predicted.getLabelAlphabet() == ilist.getTargetAlphabet());
            Instance inst = (Instance) ilist.get(i);
            Labeling labeling = inst.getLabeling();
            FeatureVector fv = (FeatureVector) inst.getData();
            for (int c = 0; c < numClasses; c++) {
                double trueLabelWeight = labeling.value(c) / numInExpectation;
                double modelLabelWeight = predicted.value(c) / numInExpectation;
                trueLabelWeightSum += trueLabelWeight;
                modelLabelWeightSum += modelLabelWeight;
                if (trueLabelWeight == 0.0 && modelLabelWeight == 0.0) {
                    continue; // nothing to add for this class
                }
                for (int loc = 0; loc < fv.numLocations(); loc++) {
                    int fi = fv.indexAtLocation(loc);
                    assert (fv.valueAtLocation(loc) == 1.0);
                    p[c][fi] += trueLabelWeight;
                    q[c][fi] += modelLabelWeight;
                }
            }
        }
        assert (Math.abs(trueLabelWeightSum - 1.0) < 0.001) : "trueLabelWeightSum should be 1.0, it was " + trueLabelWeightSum;
        assert (Math.abs(modelLabelWeightSum - 1.0) < 0.001) : "modelLabelWeightSum should be 1.0, it was " + modelLabelWeightSum;
    }

    /** Returns {@code log(p * (1 - q) / (q * (1 - p)))} for every (class, feature) cell. */
    private static double[][] computeAlphas(double[][] p, double[][] q, int numClasses,
            int numFeatures) {
        double[][] alphas = new double[numClasses][numFeatures];
        for (int c = 0; c < numClasses; c++) {
            for (int f = 0; f < numFeatures; f++) {
                alphas[c][f] = Math.log(p[c][f] * (1.0 - q[c][f]) / (q[c][f] * (1.0 - p[c][f])));
            }
        }
        return alphas;
    }

    /**
     * Returns, per cell, the sum of {@code w * e^alpha - w} over feature occurrences plus the
     * total predicted weight. Here {@code w} is the predicted class weight divided by the
     * (unsmoothed) number of instances.
     */
    private static double[][] computeQeag(InstanceList ilist, LabelVector[] classifications,
            double[][] alphas, int numClasses, int numFeatures) {
        int numInstances = ilist.size();
        double[][] qeag = new double[numClasses][numFeatures];
        double modelLabelWeightSum = 0.0;
        for (int i = 0; i < numInstances; i++) {
            LabelVector predicted = classifications[i];
            assert (predicted.getLabelAlphabet() == ilist.getTargetAlphabet());
            Instance inst = (Instance) ilist.get(i);
            FeatureVector fv = (FeatureVector) inst.getData();
            for (int c = 0; c < numClasses; c++) {
                // Divided by numInstances (not numInstances + 1), so if every prediction sums
                // to 1 the total of these weights is 1.
                double modelLabelWeight = predicted.value(c) / (double) numInstances;
                modelLabelWeightSum += modelLabelWeight;
                for (int loc = 0; loc < fv.numLocations(); loc++) {
                    int fi = fv.indexAtLocation(loc);
                    qeag[c][fi] += modelLabelWeight * Math.exp(alphas[c][fi]) - modelLabelWeight;
                }
            }
        }
        for (int c = 0; c < numClasses; c++) {
            for (int f = 0; f < numFeatures; f++) {
                qeag[c][f] += modelLabelWeightSum;
            }
        }
        return qeag;
    }

    /**
     * Logs the intermediate statistics for an evenly spaced sample of features. The stride is
     * clamped to at least 1, so alphabets with fewer than {@link #LOG_SAMPLE_DIVISOR} features
     * log every feature instead of failing with an integer division by zero.
     */
    private static void logSampledStatistics(InstanceList ilist, double[][] p, double[][] q,
            double[][] alphas, double[][] qeag, double[] klgains, int numClasses) {
        int numFeatures = klgains.length;
        logger.info("klgains.length=" + klgains.length);
        int stride = Math.max(1, numFeatures / LOG_SAMPLE_DIVISOR);
        for (int f = 0; f < numFeatures; f += stride) {
            Object feature = ilist.getDataAlphabet().lookupObject(f);
            for (int c = 0; c < numClasses; c++) {
                logger.info("c=" + c + " p[" + feature + "] = " + p[c][f]);
                logger.info("c=" + c + " q[" + feature + "] = " + q[c][f]);
                logger.info("c=" + c + " alphas[" + feature + "] = " + alphas[c][f]);
                logger.info("c=" + c + " qeag[" + feature + "] = " + qeag[c][f]);
            }
            logger.info("klgains[" + feature + "] = " + klgains[f]);
        }
    }

    /**
     * Ranks the features of {@code ilist} by their score against the given predictions.
     *
     * @param ilist instances whose data are binary {@link FeatureVector}s
     * @param classifications predicted label distributions, one per instance in {@code ilist}
     */
    public KLGain(InstanceList ilist, LabelVector[] classifications) {
        super(ilist.getDataAlphabet(), KLGain.calcKLGains(ilist, classifications));
    }

    /** Extracts the label vector of each classification, preserving order. */
    private static LabelVector[] getLabelVectorsFromClassifications(Classification[] c) {
        LabelVector[] ret = new LabelVector[c.length];
        for (int i = 0; i < c.length; i++) {
            ret[i] = c[i].getLabelVector();
        }
        return ret;
    }

    /**
     * Ranks the features of {@code ilist} using the label vectors of {@code classifications}.
     *
     * @param ilist instances whose data are binary {@link FeatureVector}s
     * @param classifications one classification per instance in {@code ilist}
     */
    public KLGain(InstanceList ilist, Classification[] classifications) {
        super(ilist.getDataAlphabet(), KLGain.calcKLGains(ilist, KLGain.getLabelVectorsFromClassifications(classifications)));
    }
}

