package cc.mallet.topics;

import cc.mallet.classify.MaxEnt;
import cc.mallet.optimize.Optimizable;
import cc.mallet.types.Alphabet;
import cc.mallet.types.Dirichlet;
import cc.mallet.types.FeatureVector;
import cc.mallet.types.Instance;
import cc.mallet.types.InstanceList;
import cc.mallet.types.MatrixOps;
import cc.mallet.util.MalletLogger;
import cc.mallet.util.MalletProgressMessageLogger;
import java.text.DecimalFormat;
import java.text.NumberFormat;
import java.util.Iterator;
import java.util.logging.Logger;

/* loaded from: input_file:cc/mallet/topics/DMROptimizable.class */
/**
 * Objective function (value + gradient) for fitting the parameters of a {@link MaxEnt} model whose
 * exponentiated scores act as Dirichlet parameters over label counts.
 *
 * <p>For every training instance whose target is a non-null {@link FeatureVector} of counts, the
 * classifier's unnormalized scores are exponentiated to give one Dirichlet parameter
 * {@code alpha_k} per label. The value is the Dirichlet-multinomial log likelihood of the target
 * counts under those parameters (omitting the parameter-independent multinomial coefficient):
 * {@code sum_k [lgamma(alpha_k + n_k) - lgamma(alpha_k)] - [lgamma(sum alpha + N) - lgamma(sum alpha)]}.
 * The log of a Gaussian prior is added on top. The intercept ("default") feature has its own
 * variance, which by default is wider than that of the regular features.
 *
 * <p>Parameters are stored label-major: the weight for {@code (label, feature)} lives at index
 * {@code label * numFeatures + feature}, and the intercept uses {@code feature == defaultFeatureIndex}.
 */
public class DMROptimizable implements Optimizable.ByGradientValue {
    private static final Logger logger = MalletLogger.getLogger(DMROptimizable.class.getName());
    private static final Logger progressLogger =
            MalletProgressMessageLogger.getLogger(DMROptimizable.class.getName() + "-pl");

    static final double DEFAULT_GAUSSIAN_PRIOR_VARIANCE = 1.0d;
    static final double DEFAULT_LARGE_GAUSSIAN_PRIOR_VARIANCE = 100.0d;
    static final double DEFAULT_GAUSSIAN_PRIOR_MEAN = 0.0d;

    MaxEnt classifier;
    InstanceList trainingList;
    int numGetValueCalls = 0;
    int numGetValueGradientCalls = 0;
    int numIterations = Integer.MAX_VALUE;
    NumberFormat formatter = null;
    double gaussianPriorMean = DEFAULT_GAUSSIAN_PRIOR_MEAN;
    double gaussianPriorVariance = DEFAULT_GAUSSIAN_PRIOR_VARIANCE;
    double defaultFeatureGaussianPriorVariance = DEFAULT_LARGE_GAUSSIAN_PRIOR_VARIANCE;
    double[] parameters;
    double[] cachedGradient;
    double cachedValue;
    boolean cachedValueStale;
    boolean cachedGradientStale;
    int numLabels;
    int numFeatures;
    int defaultFeatureIndex;

    /**
     * Creates an instance with only the default prior settings. The training data, classifier,
     * parameters and formatter are left unset.
     */
    public DMROptimizable() {
    }

    /**
     * Creates an objective over {@code instanceList}.
     *
     * @param instanceList training instances; the data is a {@link FeatureVector} of features and
     *     the target is a {@link FeatureVector} of label counts (instances with a null target are
     *     ignored)
     * @param maxEnt an existing classifier whose parameter array is optimized in place, or
     *     {@code null} to create a fresh {@link MaxEnt} backed by a zero-initialized array
     */
    public DMROptimizable(InstanceList instanceList, MaxEnt maxEnt) {
        this.trainingList = instanceList;
        Alphabet dataAlphabet = instanceList.getDataAlphabet();
        this.numLabels = instanceList.getTargetAlphabet().size();
        // One extra feature slot per label for the intercept.
        this.numFeatures = dataAlphabet.size() + 1;
        this.defaultFeatureIndex = this.numFeatures - 1;

        if (maxEnt != null) {
            this.classifier = maxEnt;
            // Share the classifier's own array so that parameter updates are visible to its scoring.
            this.parameters = this.classifier.getParameters();
            this.defaultFeatureIndex = this.classifier.getDefaultFeatureIndex();
            assert maxEnt.getInstancePipe() == instanceList.getPipe();
        } else {
            this.parameters = new double[this.numLabels * this.numFeatures];
            this.classifier = new MaxEnt(instanceList.getPipe(), this.parameters);
        }
        this.cachedGradient = new double[this.numLabels * this.numFeatures];

        this.formatter = new DecimalFormat("0.###E0");
        this.cachedValueStale = true;
        this.cachedGradientStale = true;

        logger.fine("Number of instances in training list = " + this.trainingList.size());
        reportNaNFeatures(dataAlphabet);
    }

    /** Logs every NaN feature value found in the data of instances that have a target. */
    private void reportNaNFeatures(Alphabet dataAlphabet) {
        for (Instance instance : this.trainingList) {
            FeatureVector target = (FeatureVector) instance.getTarget();
            if (target == null) {
                continue;
            }
            FeatureVector features = (FeatureVector) instance.getData();
            assert features.getAlphabet() == dataAlphabet;
            boolean sawNaN = false;
            for (int loc = 0; loc < features.numLocations(); loc++) {
                if (Double.isNaN(features.valueAtLocation(loc))) {
                    logger.info("NaN for feature "
                            + dataAlphabet.lookupObject(features.indexAtLocation(loc)).toString());
                    sawNaN = true;
                }
            }
            if (sawNaN) {
                logger.info("NaN in instance: " + instance.getName());
            }
        }
    }

    /** Sets the Gaussian prior variance applied to the intercept (default-feature) weights. */
    public void setInterceptGaussianPriorVariance(double d) {
        this.defaultFeatureGaussianPriorVariance = d;
    }

    /** Sets the Gaussian prior variance applied to all non-intercept weights. */
    public void setRegularGaussianPriorVariance(double d) {
        this.gaussianPriorVariance = d;
    }

    /** Returns the classifier whose parameters this objective optimizes. */
    public MaxEnt getClassifier() {
        return this.classifier;
    }

    @Override // cc.mallet.optimize.Optimizable
    public double getParameter(int i) {
        return this.parameters[i];
    }

    @Override // cc.mallet.optimize.Optimizable
    public void setParameter(int i, double d) {
        this.cachedValueStale = true;
        this.cachedGradientStale = true;
        this.parameters[i] = d;
    }

    @Override // cc.mallet.optimize.Optimizable
    public int getNumParameters() {
        return this.parameters.length;
    }

    /**
     * Copies the current parameters into {@code buffer}.
     *
     * @throws IllegalArgumentException if {@code buffer} is null or not exactly
     *     {@link #getNumParameters()} long (a replacement array could never reach the caller)
     */
    @Override // cc.mallet.optimize.Optimizable
    public void getParameters(double[] buffer) {
        if (buffer == null || buffer.length != this.parameters.length) {
            throw new IllegalArgumentException("parameter buffer must have length "
                    + this.parameters.length + ", got "
                    + (buffer == null ? "null" : String.valueOf(buffer.length)));
        }
        System.arraycopy(this.parameters, 0, buffer, 0, this.parameters.length);
    }

    /**
     * Overwrites the parameters in place with {@code newParameters}.
     *
     * @throws IllegalArgumentException if the length differs from {@link #getNumParameters()};
     *     replacing the array would detach it from the one the classifier scores with
     */
    @Override // cc.mallet.optimize.Optimizable
    public void setParameters(double[] newParameters) {
        assert newParameters != null;
        if (newParameters.length != this.parameters.length) {
            throw new IllegalArgumentException("expected " + this.parameters.length
                    + " parameters, got " + newParameters.length);
        }
        this.cachedValueStale = true;
        this.cachedGradientStale = true;
        System.arraycopy(newParameters, 0, this.parameters, 0, newParameters.length);
    }

    /**
     * Returns the total log likelihood of all instances with a target, plus the log prior.
     * The result is cached until a parameter changes. If any instance's likelihood is infinite,
     * its negation is cached and returned immediately, without the prior.
     */
    @Override // cc.mallet.optimize.Optimizable.ByGradientValue
    public double getValue() {
        if (!this.cachedValueStale) {
            return this.cachedValue;
        }
        this.numGetValueCalls++;
        this.cachedValue = 0.0d;
        double[] scores = new double[this.trainingList.getTargetAlphabet().size()];

        for (Instance instance : this.trainingList) {
            FeatureVector target = (FeatureVector) instance.getTarget();
            if (target == null) {
                continue;
            }
            this.classifier.getUnnormalizedClassificationScores(instance, scores);
            double sumScores = exponentiateInPlace(scores);

            // Log likelihood of THIS instance only. It must start from zero for every instance,
            // otherwise earlier instances are re-added on each later iteration.
            double value = 0.0d;
            double totalLength = 0.0d;
            for (int loc = 0; loc < target.numLocations(); loc++) {
                int label = target.indexAtLocation(loc);
                double count = target.valueAtLocation(loc);
                value += Dirichlet.logGammaStirling(scores[label] + count)
                        - Dirichlet.logGammaStirling(scores[label]);
                totalLength += count;
            }
            value -= Dirichlet.logGammaStirling(sumScores + totalLength)
                    - Dirichlet.logGammaStirling(sumScores);

            if (Double.isNaN(value)) {
                logger.fine("DCMMaxEntTrainer: Instance " + instance.getName() + " has NaN value.");
                for (int label : target.getIndices()) {
                    logger.fine("log(scores)= " + Math.log(scores[label]) + " scores = " + scores[label]);
                }
            }
            if (Double.isInfinite(value)) {
                logger.warning("Instance " + instance.getSource()
                        + " has infinite value; skipping value and gradient");
                this.cachedValue -= value;
                this.cachedValueStale = false;
                return -value;
            }
            this.cachedValue += value;
        }

        double prior = logPrior();
        double likelihood = this.cachedValue;
        this.cachedValue += prior;
        this.cachedValueStale = false;
        progressLogger.info("Value (likelihood=" + this.formatter.format(likelihood)
                + " prior=" + this.formatter.format(prior) + ") = "
                + this.formatter.format(this.cachedValue));
        return this.cachedValue;
    }

    /**
     * Unnormalized Gaussian log prior over all weights. The intercept uses
     * {@code defaultFeatureGaussianPriorVariance} and the rest use {@code gaussianPriorVariance}.
     */
    private double logPrior() {
        double prior = 0.0d;
        for (int label = 0; label < this.numLabels; label++) {
            int rowStart = label * this.numFeatures;
            // NOTE(review): assumes the intercept is the last slot (numFeatures - 1); a classifier
            // with a different defaultFeatureIndex would have that slot penalized twice.
            for (int feature = 0; feature < this.numFeatures - 1; feature++) {
                double offset = this.parameters[rowStart + feature] - this.gaussianPriorMean;
                prior -= (offset * offset) / (2.0d * this.gaussianPriorVariance);
            }
            double interceptOffset =
                    this.parameters[rowStart + this.defaultFeatureIndex] - this.gaussianPriorMean;
            prior -= (interceptOffset * interceptOffset)
                    / (2.0d * this.defaultFeatureGaussianPriorVariance);
        }
        return prior;
    }

    /**
     * Computes the gradient of {@link #getValue()} with respect to every parameter and copies it
     * into {@code buffer}, which must be exactly {@link #getNumParameters()} long. Any
     * {@code -Infinity} entries are replaced by 0.
     */
    @Override // cc.mallet.optimize.Optimizable.ByGradientValue
    public void getValueGradient(double[] buffer) {
        MatrixOps.setAll(this.cachedGradient, 0.0d);
        double[] scores = new double[this.trainingList.getTargetAlphabet().size()];

        for (Instance instance : this.trainingList) {
            FeatureVector target = (FeatureVector) instance.getTarget();
            if (target == null) {
                continue;
            }
            this.classifier.getUnnormalizedClassificationScores(instance, scores);
            double sumScores = exponentiateInPlace(scores);
            FeatureVector features = (FeatureVector) instance.getData();

            double totalLength = 0.0d;
            for (double count : target.getValues()) {
                totalLength += count;
            }
            // Derivative of the -[lgamma(sum alpha + N) - lgamma(sum alpha)] term, shared by all labels.
            double totalDigamma = Dirichlet.digamma(sumScores + totalLength) - Dirichlet.digamma(sumScores);

            // digamma(alpha_k + n_k) - digamma(alpha_k) for each observed label. This does not
            // depend on the feature, so it is computed once rather than once per feature.
            int numTargetLocations = target.numLocations();
            double[] labelDigamma = new double[numTargetLocations];
            for (int loc = 0; loc < numTargetLocations; loc++) {
                labelDigamma[loc] =
                        digammaIncrement(scores[target.indexAtLocation(loc)], target.valueAtLocation(loc));
            }

            for (int f = 0; f < features.numLocations(); f++) {
                int feature = features.indexAtLocation(f);
                double featureValue = features.valueAtLocation(f);
                if (featureValue == 0.0d) {
                    continue;
                }
                for (int label = 0; label < this.numLabels; label++) {
                    this.cachedGradient[label * this.numFeatures + feature] -=
                            featureValue * scores[label] * totalDigamma;
                }
                for (int loc = 0; loc < numTargetLocations; loc++) {
                    int label = target.indexAtLocation(loc);
                    this.cachedGradient[label * this.numFeatures + feature] +=
                            featureValue * scores[label] * labelDigamma[loc];
                }
            }

            // The intercept behaves like a feature whose value is always 1.
            for (int label = 0; label < this.numLabels; label++) {
                this.cachedGradient[label * this.numFeatures + this.defaultFeatureIndex] -=
                        scores[label] * totalDigamma;
            }
            for (int loc = 0; loc < numTargetLocations; loc++) {
                int label = target.indexAtLocation(loc);
                this.cachedGradient[label * this.numFeatures + this.defaultFeatureIndex] +=
                        scores[label] * labelDigamma[loc];
            }
        }
        this.numGetValueGradientCalls++;

        // Gradient of the Gaussian log prior.
        for (int label = 0; label < this.numLabels; label++) {
            int rowStart = label * this.numFeatures;
            for (int feature = 0; feature < this.numFeatures - 1; feature++) {
                double weight = this.parameters[rowStart + feature];
                this.cachedGradient[rowStart + feature] -=
                        (weight - this.gaussianPriorMean) / this.gaussianPriorVariance;
            }
            int interceptIndex = rowStart + this.defaultFeatureIndex;
            double intercept = this.parameters[interceptIndex];
            this.cachedGradient[interceptIndex] -=
                    (intercept - this.gaussianPriorMean) / this.defaultFeatureGaussianPriorVariance;
        }

        MatrixOps.substitute(this.cachedGradient, Double.NEGATIVE_INFINITY, 0.0d);
        assert buffer != null && buffer.length == this.parameters.length;
        System.arraycopy(this.cachedGradient, 0, buffer, 0, this.cachedGradient.length);
    }

    /** Replaces each score with {@code exp(score)} in place and returns the sum of the results. */
    private static double exponentiateInPlace(double[] scores) {
        double sum = 0.0d;
        for (int i = 0; i < scores.length; i++) {
            scores[i] = Math.exp(scores[i]);
            sum += scores[i];
        }
        return sum;
    }

    /**
     * Returns {@code digamma(alpha + count) - digamma(alpha)}.
     *
     * <p>For counts below 20 this uses the recurrence {@code sum_{i=0}^{count-1} 1 / (alpha + i)}.
     * The recurrence is exact for integer counts; a fractional count is effectively rounded up,
     * because the loop runs while {@code i < count}. Larger counts call {@link Dirichlet#digamma}
     * directly.
     */
    private static double digammaIncrement(double alpha, double count) {
        if (count < 20.0d) {
            double sum = 0.0d;
            for (int i = 0; i < count; i++) {
                sum += 1.0d / (alpha + i);
            }
            return sum;
        }
        return Dirichlet.digamma(alpha + count) - Dirichlet.digamma(alpha);
    }
}
