/*
 * Decompiled with CFR 0.152.
 */
package cc.mallet.fst;

import cc.mallet.fst.CRF;
import cc.mallet.fst.SumLatticeDefault;
import cc.mallet.fst.Transducer;
import cc.mallet.optimize.Optimizable;
import cc.mallet.types.FeatureSequence;
import cc.mallet.types.FeatureVectorSequence;
import cc.mallet.types.Instance;
import cc.mallet.types.InstanceList;
import cc.mallet.types.MatrixOps;
import cc.mallet.util.MalletLogger;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.Serializable;
import java.util.BitSet;
import java.util.logging.Logger;

/**
 * {@link Optimizable.ByGradientValue} adapter that trains a {@link CRF} by maximizing the
 * instance-weighted label log-likelihood of a training set plus a prior term.
 *
 * <p>For each training instance the value contribution is the total weight of the lattice
 * restricted to the instance's target sequence minus the total weight of the unrestricted
 * lattice, multiplied by the instance weight. Instances whose difference is infinite are
 * skipped. The prior is Gaussian by default, or hyperbolic after
 * {@code setUseHyperbolicPrior(true)}.
 *
 * <p>Value and gradient are cached. They are recomputed only when the CRF's
 * {@code weightsValueChangeStamp} differs from the stamp recorded at the last successful
 * computation.
 */
public class CRFOptimizableByLabelLikelihood
implements Optimizable.ByGradientValue,
Serializable {
    private static final Logger logger = MalletLogger.getLogger(CRFOptimizableByLabelLikelihood.class.getName());
    static final double DEFAULT_GAUSSIAN_PRIOR_VARIANCE = 1.0;
    static final double DEFAULT_HYPERBOLIC_PRIOR_SLOPE = 0.2;
    static final double DEFAULT_HYPERBOLIC_PRIOR_SHARPNESS = 10.0;
    protected InstanceList trainingSet;
    // Arbitrary sentinel; only meaningful once cachedValueWeightsStamp matches the CRF's stamp.
    protected double cachedValue = -1.23456789E8;
    protected double[] cachedGradient;
    // Indices of instances found to have infinite weight on the first value computation.
    protected BitSet infiniteValues = null;
    protected CRF crf;
    // Statistics gathered from lattices restricted to the target label sequences.
    protected CRF.Factors constraints;
    // Statistics gathered from unrestricted lattices during the last value computation.
    protected CRF.Factors expectations;
    private int cachedValueWeightsStamp = -1;
    private int cachedGradientWeightsStamp = -1;
    boolean usingHyperbolicPrior = false;
    double gaussianPriorVariance = DEFAULT_GAUSSIAN_PRIOR_VARIANCE;
    double hyperbolicPriorSlope = DEFAULT_HYPERBOLIC_PRIOR_SLOPE;
    double hyperbolicPriorSharpness = DEFAULT_HYPERBOLIC_PRIOR_SHARPNESS;
    private static final long serialVersionUID = 1L;
    private static final int CURRENT_SERIAL_VERSION = 0;

    /**
     * Creates an optimizable over {@code crf}'s parameters and immediately gathers the
     * constraint statistics for {@code ilist}.
     *
     * @param crf the model whose parameters are optimized (mutated through setParameters)
     * @param ilist training instances; data must be a {@link FeatureVectorSequence} and target a
     *     {@link FeatureSequence}
     */
    public CRFOptimizableByLabelLikelihood(CRF crf, InstanceList ilist) {
        this.crf = crf;
        this.trainingSet = ilist;
        this.cachedGradient = new double[crf.parameters.getNumFactors()];
        this.constraints = new CRF.Factors(crf.parameters);
        this.expectations = new CRF.Factors(crf.parameters);
        this.gatherConstraints(ilist);
    }

    /**
     * Returns an incrementor that accumulates lattice statistics into {@code factors}, scaling
     * them by {@code instanceWeight} unless that weight is exactly 1.
     */
    private static Transducer.Incrementor newIncrementor(CRF.Factors factors, double instanceWeight) {
        if (instanceWeight == 1.0) {
            return new CRF.Factors.Incrementor(factors);
        }
        return new CRF.Factors.WeightedIncrementor(factors, instanceWeight);
    }

    /**
     * Recomputes {@link #constraints} from scratch. For each instance, a lattice restricted to
     * the instance's target sequence is run and its statistics are accumulated, weighted by the
     * instance weight.
     *
     * @param ilist instances whose data is a {@link FeatureVectorSequence} and whose target is a
     *     {@link FeatureSequence}
     */
    protected void gatherConstraints(InstanceList ilist) {
        assert (this.constraints.structureMatches(this.crf.parameters));
        this.constraints.zero();
        for (Instance instance : ilist) {
            FeatureVectorSequence input = (FeatureVectorSequence) instance.getData();
            FeatureSequence output = (FeatureSequence) instance.getTarget();
            Transducer.Incrementor incrementor =
                    newIncrementor(this.constraints, ilist.getInstanceWeight(instance));
            // The lattice is constructed only for its side effect of feeding the incrementor.
            new SumLatticeDefault(this.crf, input, output, incrementor);
        }
    }

    @Override
    public int getNumParameters() {
        return this.crf.parameters.getNumFactors();
    }

    @Override
    public void getParameters(double[] buffer) {
        this.crf.parameters.getParameters(buffer);
    }

    @Override
    public double getParameter(int index) {
        return this.crf.parameters.getParameter(index);
    }

    /** Copies {@code buff} into the CRF parameters and notifies the CRF that weights changed. */
    @Override
    public void setParameters(double[] buff) {
        this.crf.parameters.setParameters(buff);
        this.crf.weightsValueChanged();
    }

    /** Sets one CRF parameter and notifies the CRF that weights changed. */
    @Override
    public void setParameter(int index, double value) {
        this.crf.parameters.setParameter(index, value);
        this.crf.weightsValueChanged();
    }

    /**
     * Computes the data term of the objective and refills {@link #expectations} as a side effect.
     *
     * <p>For every instance this adds {@code instanceWeight * (labeledWeight - unlabeledWeight)},
     * where the weights are the total weights of the target-restricted and unrestricted lattices.
     * Instances with an infinite difference are logged and skipped. On the first call their
     * indices are recorded in {@link #infiniteValues}.
     *
     * @return the summed, instance-weighted value over all non-skipped instances
     * @throws IllegalStateException if an instance recorded as finite on the first call now has
     *     an infinite weight
     */
    protected double getExpectationValue() {
        boolean initializingInfiniteValues = false;
        if (this.infiniteValues == null) {
            this.infiniteValues = new BitSet();
            initializingInfiniteValues = true;
        }
        assert (this.expectations.structureMatches(this.crf.parameters));
        this.expectations.zero();
        double value = 0.0;
        int numInfLabeledWeight = 0;
        int numInfUnlabeledWeight = 0;
        int numInfWeight = 0;
        for (int ii = 0; ii < this.trainingSet.size(); ii++) {
            Instance instance = (Instance) this.trainingSet.get(ii);
            double instanceWeight = this.trainingSet.getInstanceWeight(instance);
            FeatureVectorSequence input = (FeatureVectorSequence) instance.getData();
            FeatureSequence output = (FeatureSequence) instance.getTarget();
            double labeledWeight = new SumLatticeDefault(this.crf, input, output, null).getTotalWeight();
            // Unnamed instances are reported by position so that logging never dereferences null.
            String instanceName = instance.getName() == null ? "instance#" + ii : instance.getName().toString();
            if (Double.isInfinite(labeledWeight)) {
                ++numInfLabeledWeight;
                logger.warning(instanceName + " has -infinite labeled weight.\n" + (instance.getSource() != null ? instance.getSource() : ""));
            }
            Transducer.Incrementor incrementor = newIncrementor(this.expectations, instanceWeight);
            double unlabeledWeight = new SumLatticeDefault(this.crf, input, null, incrementor).getTotalWeight();
            if (Double.isInfinite(unlabeledWeight)) {
                ++numInfUnlabeledWeight;
                logger.warning(instanceName + " has -infinite unlabeled weight.\n" + (instance.getSource() != null ? instance.getSource() : ""));
            }
            double weight = labeledWeight - unlabeledWeight;
            if (Double.isInfinite(weight)) {
                ++numInfWeight;
                logger.warning(instanceName + " has -infinite weight; skipping.");
                if (initializingInfiniteValues) {
                    this.infiniteValues.set(ii);
                } else if (!this.infiniteValues.get(ii)) {
                    throw new IllegalStateException("Instance " + ii + " (" + instanceName + ") used to have non-infinite value, but now it has infinite value.");
                }
            } else {
                value += weight * instanceWeight;
            }
        }
        if (numInfLabeledWeight > 0 || numInfUnlabeledWeight > 0 || numInfWeight > 0) {
            logger.warning("Number of instances with:\n\t -infinite labeled weight: " + numInfLabeledWeight + "\n" + "\t -infinite unlabeled weight: " + numInfUnlabeledWeight + "\n" + "\t -infinite weight: " + numInfWeight);
        }
        return value;
    }

    /**
     * Returns the objective value: the data term from {@link #getExpectationValue()} plus either
     * the hyperbolic or the Gaussian prior term. The result is cached per CRF weights stamp.
     */
    @Override
    public double getValue() {
        int currentStamp = this.crf.weightsValueChangeStamp;
        if (currentStamp != this.cachedValueWeightsStamp) {
            long startingTime = System.currentTimeMillis();
            double value = this.getExpectationValue();
            if (this.usingHyperbolicPrior) {
                value += this.crf.parameters.hyberbolicPrior(this.hyperbolicPriorSlope, this.hyperbolicPriorSharpness);
            } else {
                value += this.crf.parameters.gaussianPrior(this.gaussianPriorVariance);
            }
            this.cachedValue = value;
            assert (!Double.isNaN(this.cachedValue) && !Double.isInfinite(this.cachedValue)) : "Label likelihood is NaN/Infinite";
            // Record the stamp only after a successful computation. Otherwise an exception
            // above would leave a stale value that later calls would treat as current.
            this.cachedValueWeightsStamp = currentStamp;
            logger.info("getValue() (loglikelihood, optimizable by label likelihood) = " + this.cachedValue);
            long endingTime = System.currentTimeMillis();
            logger.fine("Inference milliseconds = " + (endingTime - startingTime));
        }
        return this.cachedValue;
    }

    /** Asserts that parameters, expectations and constraints contain no NaN (or infinite) entries. */
    private void assertNotNaNOrInfinite() {
        this.crf.parameters.assertNotNaN();
        this.expectations.assertNotNaNOrInfinite();
        this.constraints.assertNotNaNOrInfinite();
    }

    /**
     * Copies the objective gradient into {@code buffer}.
     *
     * <p>The gradient is computed by starting from {@link #expectations} (which it
     * <em>overwrites</em>), subtracting {@link #constraints}, adding the prior's gradient term,
     * and negating the result. It is cached per CRF weights stamp.
     *
     * @param buffer destination; must hold at least {@link #getNumParameters()} entries
     */
    @Override
    public void getValueGradient(double[] buffer) {
        if (this.cachedGradientWeightsStamp != this.crf.weightsValueChangeStamp) {
            this.cachedGradientWeightsStamp = this.crf.weightsValueChangeStamp;
            // Ensures expectations reflect the current weights before they are modified below.
            this.getValue();
            this.assertNotNaNOrInfinite();
            this.expectations.plusEquals(this.constraints, -1.0);
            if (this.usingHyperbolicPrior) {
                this.expectations.plusEqualsHyperbolicPriorGradient(this.crf.parameters, -this.hyperbolicPriorSlope, this.hyperbolicPriorSharpness);
            } else {
                this.expectations.plusEqualsGaussianPriorGradient(this.crf.parameters, -this.gaussianPriorVariance);
            }
            this.expectations.assertNotNaNOrInfinite();
            this.expectations.getParameters(this.cachedGradient);
            MatrixOps.timesEquals(this.cachedGradient, -1.0);
        }
        System.arraycopy(this.cachedGradient, 0, buffer, 0, this.cachedGradient.length);
    }

    /** Selects the hyperbolic prior ({@code true}) or the Gaussian prior ({@code false}). */
    public void setUseHyperbolicPrior(boolean f) {
        this.usingHyperbolicPrior = f;
    }

    public void setHyperbolicPriorSlope(double p) {
        this.hyperbolicPriorSlope = p;
    }

    public void setHyperbolicPriorSharpness(double p) {
        this.hyperbolicPriorSharpness = p;
    }

    /** Returns the hyperbolic prior slope (the name is historical; this is not a flag). */
    public double getUseHyperbolicPriorSlope() {
        return this.hyperbolicPriorSlope;
    }

    /** Returns the hyperbolic prior sharpness (the name is historical; this is not a flag). */
    public double getUseHyperbolicPriorSharpness() {
        return this.hyperbolicPriorSharpness;
    }

    public void setGaussianPriorVariance(double p) {
        this.gaussianPriorVariance = p;
    }

    public double getGaussianPriorVariance() {
        return this.gaussianPriorVariance;
    }

    /**
     * Writes the serial version, followed by the training set, cached value, cached gradient,
     * infinite-value set and CRF. Prior settings and the constraint/expectation factors are not
     * written.
     */
    private void writeObject(ObjectOutputStream out) throws IOException {
        out.writeInt(CURRENT_SERIAL_VERSION);
        out.writeObject(this.trainingSet);
        out.writeDouble(this.cachedValue);
        out.writeObject(this.cachedGradient);
        out.writeObject(this.infiniteValues);
        out.writeObject(this.crf);
    }

    /**
     * Reads the state written by {@link #writeObject} and rebuilds the state that the stream does
     * not carry.
     *
     * <p>Field initializers do not run during deserialization. Without the explicit resets
     * below, the cache stamps would be 0 (possibly matching the CRF's stamp), the prior settings
     * would be 0.0, and the constraint/expectation factors would be null.
     */
    private void readObject(ObjectInputStream in) throws IOException, ClassNotFoundException {
        // Serial version; only CURRENT_SERIAL_VERSION has ever been written, so it is not consulted.
        in.readInt();
        this.trainingSet = (InstanceList) in.readObject();
        this.cachedValue = in.readDouble();
        this.cachedGradient = (double[]) in.readObject();
        this.infiniteValues = (BitSet) in.readObject();
        this.crf = (CRF) in.readObject();
        // Force recomputation: the caches are not tied to any known weights stamp.
        this.cachedValueWeightsStamp = -1;
        this.cachedGradientWeightsStamp = -1;
        // Prior settings are not serialized; fall back to the documented defaults.
        this.usingHyperbolicPrior = false;
        this.gaussianPriorVariance = DEFAULT_GAUSSIAN_PRIOR_VARIANCE;
        this.hyperbolicPriorSlope = DEFAULT_HYPERBOLIC_PRIOR_SLOPE;
        this.hyperbolicPriorSharpness = DEFAULT_HYPERBOLIC_PRIOR_SHARPNESS;
        // Factors are not serialized; rebuild them exactly as the constructor does.
        this.constraints = new CRF.Factors(this.crf.parameters);
        this.expectations = new CRF.Factors(this.crf.parameters);
        this.gatherConstraints(this.trainingSet);
    }

    /** Factory that produces label-likelihood optimizables for a CRF and its training data. */
    public static class Factory {
        public Optimizable.ByGradientValue newCRFOptimizable(CRF crf, InstanceList trainingData) {
            return new CRFOptimizableByLabelLikelihood(crf, trainingData);
        }
    }
}

