/*
 * Decompiled with CFR 0.152.
 */
package cc.mallet.classify;

import cc.mallet.classify.MaxEnt;
import cc.mallet.classify.constraints.ge.MaxEntGEConstraint;
import cc.mallet.optimize.Optimizable;
import cc.mallet.types.FeatureVector;
import cc.mallet.types.Instance;
import cc.mallet.types.InstanceList;
import cc.mallet.types.MatrixOps;
import cc.mallet.util.MalletProgressMessageLogger;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.logging.Logger;

/**
 * Optimizable objective for fitting a {@link MaxEnt} classifier's parameters to a set of
 * {@link MaxEntGEConstraint}s evaluated on the training instances that have no target.
 *
 * <p>The value is {@code objWeight * sum(constraint.getValue())} plus a Gaussian prior term
 * {@code -sum(p^2) / (2 * gaussianPriorVariance)} over all parameters. The gradient has the same
 * two parts.
 *
 * <p>Parameters are laid out as {@code numLabels} rows of {@code numFeatures + 1} weights. Within
 * each row, the column at {@link #defaultFeatureIndex} (one past the last data feature) holds the
 * default feature.
 *
 * <p>Value and gradient are cached. Any setter that changes parameters or objective settings marks
 * the cache stale.
 */
public class MaxEntOptimizableByGE implements Optimizable.ByGradientValue {
    private static final Logger progressLogger =
            MalletProgressMessageLogger.getLogger(MaxEntOptimizableByGE.class.getName() + "-pl");

    /** True when {@link #cachedValue} / {@link #cachedGradient} must be recomputed. */
    protected boolean cacheStale = true;
    /** Column of the default feature within each label's parameter row (= data alphabet size). */
    protected int defaultFeatureIndex;
    /** Temperature passed to the classifier when scoring instances; also divides the gradient. */
    protected double temperature = 1.0;
    /** Multiplier on the constraint (GE) part of the value and gradient. */
    protected double objWeight = 1.0;
    /** Last computed objective value (GE term plus Gaussian prior term). */
    protected double cachedValue;
    /** Variance of the Gaussian prior on every parameter. */
    protected double gaussianPriorVariance = 1.0;
    /** Last computed gradient, same layout as {@link #parameters}. */
    protected double[] cachedGradient;
    /** Parameter vector being optimized; shared with {@link #classifier}. */
    protected double[] parameters;
    /** Training data; only instances whose target is null contribute to the GE term. */
    protected InstanceList trainingList;
    /** Classifier whose parameter array is {@link #parameters}. */
    protected MaxEnt classifier;
    /** Constraints defining the GE term. */
    protected ArrayList<MaxEntGEConstraint> constraints;

    /**
     * Creates the objective and lets every constraint preprocess the training list.
     *
     * @param trainingList training instances; its data and target alphabets fix the parameter size
     * @param constraints constraints defining the GE term
     * @param initClassifier classifier to start from, or {@code null} to start from an all-zero
     *     parameter vector. When non-null, its {@code parameters} array is shared (not copied), so
     *     optimization updates that classifier in place.
     */
    public MaxEntOptimizableByGE(InstanceList trainingList, ArrayList<MaxEntGEConstraint> constraints, MaxEnt initClassifier) {
        this.trainingList = trainingList;
        int numFeatures = trainingList.getDataAlphabet().size();
        int numLabels = trainingList.getTargetAlphabet().size();
        // The default feature occupies the extra column right after the data features.
        this.defaultFeatureIndex = numFeatures;
        this.cachedGradient = new double[(numFeatures + 1) * numLabels];
        this.cachedValue = 0.0;
        if (initClassifier != null) {
            this.parameters = initClassifier.parameters;
            this.classifier = initClassifier;
        } else {
            this.parameters = new double[(numFeatures + 1) * numLabels];
            this.classifier = new MaxEnt(trainingList.getPipe(), this.parameters);
        }
        this.constraints = constraints;
        for (MaxEntGEConstraint constraint : constraints) {
            constraint.preProcess(trainingList);
        }
    }

    /** Sets the Gaussian prior variance and invalidates the cached value/gradient. */
    public void setGaussianPriorVariance(double variance) {
        this.gaussianPriorVariance = variance;
        this.cacheStale = true;
    }

    /** Sets the scoring temperature and invalidates the cached value/gradient. */
    public void setTemperature(double temp) {
        this.temperature = temp;
        this.cacheStale = true;
    }

    /** Sets the weight of the GE term and invalidates the cached value/gradient. */
    public void setWeight(double weight) {
        this.objWeight = weight;
        this.cacheStale = true;
    }

    /** Returns the classifier whose parameters this objective optimizes. */
    public MaxEnt getClassifier() {
        return this.classifier;
    }

    /**
     * Returns the objective value, recomputing value and gradient if the cache is stale.
     *
     * <p>When {@code objWeight == 0} the value is 0.0, the gradient is all zeros, and the
     * Gaussian prior is not applied.
     */
    @Override
    public double getValue() {
        if (!this.cacheStale) {
            return this.cachedValue;
        }
        if (this.objWeight == 0.0) {
            // Keep the gradient consistent with the reported zero value instead of leaving
            // behind whatever an earlier evaluation with a non-zero weight computed.
            Arrays.fill(this.cachedGradient, 0.0);
            return 0.0;
        }
        for (MaxEntGEConstraint constraint : this.constraints) {
            constraint.zeroExpectations();
        }
        Arrays.fill(this.cachedGradient, 0.0);

        // Width of one label's parameter row: the data features plus the default feature.
        int numFeatures = this.trainingList.getDataAlphabet().size() + 1;
        int numLabels = this.trainingList.getTargetAlphabet().size();

        double[][] scores = computeModelExpectations(numLabels);

        double value = 0.0;
        for (MaxEntGEConstraint constraint : this.constraints) {
            value += constraint.getValue();
        }
        value *= this.objWeight;

        accumulateGradient(scores, numFeatures, numLabels);

        this.cachedValue = value;
        this.cacheStale = false;
        // Adds the prior term to cachedValue and cachedGradient as a side effect.
        double reg = this.getRegularization();
        progressLogger.info("Value (GE=" + value + " Gaussian prior= " + reg + ") = " + this.cachedValue);
        return this.cachedValue;
    }

    /**
     * Scores each instance that has no target and passes it to every constraint's
     * {@code computeExpectations}.
     *
     * @param numLabels size of the target alphabet
     * @return one score row per training instance; rows of instances with a target stay all zero
     */
    private double[][] computeModelExpectations(int numLabels) {
        int numInstances = this.trainingList.size();
        double[][] scores = new double[numInstances][numLabels];
        for (int ii = 0; ii < numInstances; ii++) {
            Instance instance = (Instance) this.trainingList.get(ii);
            if (instance.getTarget() != null) {
                continue;
            }
            double instanceWeight = this.trainingList.getInstanceWeight(instance);
            FeatureVector fv = (FeatureVector) instance.getData();
            this.classifier.getClassificationScoresWithTemperature(instance, this.temperature, scores[ii]);
            for (MaxEntGEConstraint constraint : this.constraints) {
                constraint.computeExpectations(fv, scores[ii], instanceWeight);
            }
        }
        return scores;
    }

    /**
     * Adds the GE part of the gradient to {@link #cachedGradient}.
     *
     * <p>For each instance without a target and each label with a non-zero score, the label's
     * row receives {@code fv * w}, and its default-feature column receives {@code w}, where
     * {@code w = objWeight * instanceWeight * score[label]
     * * (constraintValue[label] - sum_l score[l] * constraintValue[l]) / temperature}.
     *
     * @param scores per-instance scores produced by {@link #computeModelExpectations(int)}
     * @param numFeatures width of one label's parameter row
     * @param numLabels size of the target alphabet
     */
    private void accumulateGradient(double[][] scores, int numFeatures, int numLabels) {
        double[] constraintValue = new double[numLabels];
        for (int ii = 0; ii < this.trainingList.size(); ii++) {
            Instance instance = (Instance) this.trainingList.get(ii);
            if (instance.getTarget() != null) {
                continue;
            }
            double instanceWeight = this.trainingList.getInstanceWeight(instance);
            FeatureVector fv = (FeatureVector) instance.getData();
            double[] instanceScores = scores[ii];

            // Summed constraint feature value per label, and its expectation under the scores.
            Arrays.fill(constraintValue, 0.0);
            double instanceExpectation = 0.0;
            for (MaxEntGEConstraint constraint : this.constraints) {
                constraint.preProcess(fv);
                for (int label = 0; label < numLabels; label++) {
                    double val = constraint.getCompositeConstraintFeatureValue(fv, label);
                    constraintValue[label] += val;
                    instanceExpectation += val * instanceScores[label];
                }
            }

            for (int label = 0; label < numLabels; label++) {
                double score = instanceScores[label];
                if (score == 0.0) {
                    continue;
                }
                assert (!Double.isInfinite(score));
                double weight = this.objWeight * instanceWeight * score
                        * (constraintValue[label] - instanceExpectation) / this.temperature;
                assert (!Double.isNaN(weight));
                MatrixOps.rowPlusEquals(this.cachedGradient, numFeatures, label, fv, weight);
                this.cachedGradient[numFeatures * label + this.defaultFeatureIndex] += weight;
            }
        }
    }

    /**
     * Computes the Gaussian prior term and folds it into the cache.
     *
     * <p>Side effects: subtracts {@code p / gaussianPriorVariance} from each gradient entry and
     * adds the returned term to {@link #cachedValue}. Calling it more than once per evaluation
     * applies the prior more than once.
     *
     * @return {@code -sum(p^2) / (2 * gaussianPriorVariance)} over all parameters
     */
    protected double getRegularization() {
        double regularization = 0.0;
        for (int pi = 0; pi < this.parameters.length; pi++) {
            double p = this.parameters[pi];
            regularization -= p * p / (2.0 * this.gaussianPriorVariance);
            this.cachedGradient[pi] -= p / this.gaussianPriorVariance;
        }
        this.cachedValue += regularization;
        return regularization;
    }

    /** Copies the gradient into {@code buffer}, recomputing it first if the cache is stale. */
    @Override
    public void getValueGradient(double[] buffer) {
        if (this.cacheStale) {
            this.getValue();
        }
        assert (buffer.length == this.cachedGradient.length);
        System.arraycopy(this.cachedGradient, 0, buffer, 0, buffer.length);
    }

    @Override
    public int getNumParameters() {
        return this.parameters.length;
    }

    @Override
    public double getParameter(int index) {
        return this.parameters[index];
    }

    @Override
    public void getParameters(double[] buffer) {
        assert (buffer.length == this.parameters.length);
        System.arraycopy(this.parameters, 0, buffer, 0, buffer.length);
    }

    @Override
    public void setParameter(int index, double value) {
        this.cacheStale = true;
        this.parameters[index] = value;
    }

    @Override
    public void setParameters(double[] params) {
        assert (params.length == this.parameters.length);
        this.cacheStale = true;
        System.arraycopy(params, 0, this.parameters, 0, this.parameters.length);
    }
}

