/*
 * Decompiled with CFR 0.152.
 */
package de.jungblut.ner;

import de.jungblut.math.DoubleMatrix;
import de.jungblut.math.DoubleVector;
import de.jungblut.math.dense.DenseDoubleMatrix;
import de.jungblut.math.minimize.CostFunction;
import de.jungblut.math.minimize.CostGradientTuple;
import de.jungblut.math.minimize.DenseMatrixFolder;
import java.util.Iterator;

/**
 * Cost function for a softmax (multinomial) classifier over sparse features.
 *
 * <p>The cost is the negative conditional log-likelihood of the observed classes, summed over
 * all rows, plus an L2 penalty equal to the negative log of a zero-mean Gaussian prior with
 * variance {@link #SIGMA_SQUARED} on every weight (up to a constant). The gradient with respect
 * to every weight is returned alongside the cost.
 *
 * <p>Every non-zero feature is treated as an indicator: only its index is used, its stored
 * value is ignored. This matches binary feature encodings; non-binary values are not weighted.
 */
public final class ConditionalLikelihoodCostFunction implements CostFunction {

    /** Variance of the Gaussian prior on each weight; larger values mean weaker regularization. */
    private static final double SIGMA_SQUARED = 100.0;

    /**
     * Terms more than this far below the maximum (in log space) are dropped by {@link #logSum},
     * since exp(-30) is far below double precision relative to 1.
     */
    private static final double LOG_SUM_CUTOFF = 30.0;

    private final DoubleMatrix features;
    private final DoubleMatrix outcome;
    private final int m;
    private final int classes;

    /**
     * @param features one row per example; non-zero column indices are the active features.
     * @param outcome one row per example. Either a single column holding the class index
     *     (interpreted as a two-class problem), or one column per class, in which case the column
     *     with the maximum value is the observed class.
     */
    public ConditionalLikelihoodCostFunction(DoubleMatrix features, DoubleMatrix outcome) {
        this.features = features;
        this.outcome = outcome;
        this.m = outcome.getRowCount();
        this.classes = outcome.getColumnCount() == 1 ? 2 : outcome.getColumnCount();
    }

    /**
     * Evaluates the regularized negative conditional log-likelihood and its gradient.
     *
     * @param input all weights folded into one vector; it is unfolded into a matrix with
     *     {@code classes} rows and {@code input.getLength() / classes} columns.
     * @return the cost and the folded gradient (same layout as {@code input}).
     */
    @Override
    public CostGradientTuple evaluateCost(DoubleVector input) {
        DoubleMatrix theta = DenseMatrixFolder.unfoldMatrix(input, this.classes,
                input.getLength() / this.classes);
        DenseDoubleMatrix gradient = new DenseDoubleMatrix(theta.getRowCount(), theta.getColumnCount());
        double cost = 0.0;
        double[] deltas = new double[this.classes];
        for (int row = 0; row < this.m; ++row) {
            DoubleVector rowVector = this.features.getRowVector(row);
            // Loop-invariant per row: fetched once instead of once per class and feature.
            DoubleVector outcomeRow = this.outcome.getRowVector(row);
            double[] logScores = unnormalizedLogProbabilities(theta, rowVector);
            double z = logSum(logScores);
            for (int i = 0; i < this.classes; ++i) {
                double logProb = logScores[i] - z;
                boolean observed = correctPrediction(i, outcomeRow);
                // d(cost)/d(theta[i][j]) = P(i | x) - 1{i is observed}, for every active feature j.
                deltas[i] = Math.exp(logProb) - (observed ? 1.0 : 0.0);
                if (observed) {
                    // Subtract the log-probability directly: -log(exp(logProb)) would become
                    // +Infinity once exp() underflows for a very unlikely observed class.
                    cost -= logProb;
                }
            }
            accumulateGradient(gradient, rowVector, deltas);
        }
        DoubleVector foldGradient = DenseMatrixFolder.foldMatrix((DoubleMatrix) gradient);
        cost += computeLogPrior(input, foldGradient);
        return new CostGradientTuple(cost, foldGradient);
    }

    /** Sums, per class, the weights of every non-zero feature of {@code rowVector}. */
    private double[] unnormalizedLogProbabilities(DoubleMatrix theta, DoubleVector rowVector) {
        double[] logScores = new double[this.classes];
        Iterator<DoubleVector.DoubleVectorElement> it = rowVector.iterateNonZero();
        while (it.hasNext()) {
            int featureIndex = it.next().getIndex();
            for (int i = 0; i < this.classes; ++i) {
                logScores[i] += theta.get(i, featureIndex);
            }
        }
        return logScores;
    }

    /** Adds {@code deltas[i]} to gradient cell (i, j) for every class i and active feature j. */
    private static void accumulateGradient(DenseDoubleMatrix gradient, DoubleVector rowVector,
            double[] deltas) {
        Iterator<DoubleVector.DoubleVectorElement> it = rowVector.iterateNonZero();
        while (it.hasNext()) {
            int featureIndex = it.next().getIndex();
            for (int i = 0; i < deltas.length; ++i) {
                gradient.set(i, featureIndex, gradient.get(i, featureIndex) + deltas[i]);
            }
        }
    }

    /**
     * Returns whether {@code classIndex} is the observed class for an outcome row.
     *
     * @param outcome a single-element row holding the class index, or a row with one entry per
     *     class whose maximum marks the observed class.
     */
    static boolean correctPrediction(int classIndex, DoubleVector outcome) {
        return outcome.getLength() == 1 ? (int) outcome.get(0) == classIndex : outcome.maxIndex() == classIndex;
    }

    /**
     * Computes the L2 penalty {@code sum(theta^2) / (2 * SIGMA_SQUARED)} and adds its gradient
     * {@code theta / SIGMA_SQUARED} to {@code gradient} in place.
     *
     * @param theta the folded weight vector.
     * @param gradient the folded gradient, same length as {@code theta}; mutated.
     * @return the penalty to add to the cost.
     */
    static double computeLogPrior(DoubleVector theta, DoubleVector gradient) {
        double prior = 0.0;
        for (int i = 0; i < theta.getLength(); ++i) {
            double weight = theta.get(i);
            prior += weight * weight / 2.0 / SIGMA_SQUARED;
            gradient.set(i, gradient.get(i) + weight / SIGMA_SQUARED);
        }
        return prior;
    }

    /**
     * Numerically stable {@code log(sum(exp(logInputs)))}.
     *
     * <p>Factors out the maximum term and skips terms more than {@link #LOG_SUM_CUTOFF} below it,
     * whose contribution is negligible.
     *
     * @return the log of the summed exponentials; {@code Double.NEGATIVE_INFINITY} for an empty
     *     array (the log of an empty sum).
     */
    static double logSum(double[] logInputs) {
        if (logInputs.length == 0) {
            return Double.NEGATIVE_INFINITY;
        }
        int maxIdx = 0;
        double max = logInputs[0];
        for (int i = 1; i < logInputs.length; ++i) {
            if (logInputs[i] > max) {
                maxIdx = i;
                max = logInputs[i];
            }
        }
        double cutoff = max - LOG_SUM_CUTOFF;
        double intermediate = 0.0;
        for (int i = 0; i < logInputs.length; ++i) {
            if (i != maxIdx && logInputs[i] > cutoff) {
                intermediate += Math.exp(logInputs[i] - max);
            }
        }
        // log1p stays accurate when intermediate is tiny, and yields exactly max when it is 0.
        return max + Math.log1p(intermediate);
    }
}

