/*
 * Decompiled with CFR 0.152.
 */
package cc.mallet.fst;

import cc.mallet.fst.CRF;
import cc.mallet.fst.SumLatticeDefault;
import cc.mallet.fst.Transducer;
import cc.mallet.optimize.Optimizable;
import cc.mallet.types.FeatureSequence;
import cc.mallet.types.FeatureVectorSequence;
import cc.mallet.types.Instance;
import cc.mallet.types.InstanceList;
import cc.mallet.types.MatrixOps;
import cc.mallet.util.MalletLogger;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.List;
import java.util.logging.Logger;

/**
 * Batched label-likelihood objective for a {@link CRF}.
 *
 * <p>The training set is processed in {@code numBatches} slices. For each slice,
 * {@link #getBatchValue(int, int[])} accumulates the batch value and the model
 * expectations, and {@link #getBatchValueGradient(double[], int, int[])} exports
 * the corresponding gradient contribution. The constraints gathered from the
 * labeled data and the parameter prior (Gaussian by default, optionally
 * hyperbolic) are folded into the <em>last</em> batch only, so that they are
 * counted exactly once when {@link #combineGradients(Collection, double[])} sums
 * the per-batch gradients.
 *
 * <p>Serialization: the stream stores the training set, the CRF, the batch count,
 * the cached values/gradients and (since stream version 1) the prior settings.
 * Expectations and constraints are not stored; they are rebuilt while reading.
 */
public class CRFOptimizableByBatchLabelLikelihood
        implements Optimizable.ByCombiningBatchGradient, Serializable {
    private static final Logger logger =
            MalletLogger.getLogger(CRFOptimizableByBatchLabelLikelihood.class.getName());

    static final double DEFAULT_GAUSSIAN_PRIOR_VARIANCE = 1.0;
    static final double DEFAULT_HYPERBOLIC_PRIOR_SLOPE = 0.2;
    static final double DEFAULT_HYPERBOLIC_PRIOR_SHARPNESS = 10.0;

    /** Model whose parameters are being optimized. */
    protected CRF crf;
    /** Instances the objective is computed over. */
    protected InstanceList trainingSet;
    /** Number of batches the training set is split into. */
    protected int numBatches;
    /** One expectations accumulator per batch, indexed by batch index. */
    protected List<CRF.Factors> expectations;
    /** Constraints gathered once from the labeled training data. */
    protected CRF.Factors constraints;
    /** Most recently computed value of each batch. */
    protected double[] cachedValue;
    /** Scratch gradient buffer per batch, each of length {@code getNumFactors()}. */
    protected List<double[]> cachedGradient;

    boolean usingHyperbolicPrior = false;
    double gaussianPriorVariance = DEFAULT_GAUSSIAN_PRIOR_VARIANCE;
    double hyperbolicPriorSlope = DEFAULT_HYPERBOLIC_PRIOR_SLOPE;
    double hyperbolicPriorSharpness = DEFAULT_HYPERBOLIC_PRIOR_SHARPNESS;

    private static final long serialVersionUID = 1L;
    /**
     * Stream layout version written by {@link #writeObject(ObjectOutputStream)}.
     * Version 0 carried no prior settings; version 1 appends them.
     */
    private static final int CURRENT_SERIAL_VERSION = 1;

    /**
     * Creates the objective and immediately gathers the constraints from {@code ilist}.
     *
     * @param crf        model to optimize
     * @param ilist      training instances
     * @param numBatches number of batches the optimizer will request
     */
    public CRFOptimizableByBatchLabelLikelihood(CRF crf, InstanceList ilist, int numBatches) {
        this.crf = crf;
        this.trainingSet = ilist;
        this.numBatches = numBatches;
        this.cachedValue = new double[this.numBatches];
        this.cachedGradient = new ArrayList<double[]>(this.numBatches);
        this.expectations = new ArrayList<CRF.Factors>(this.numBatches);
        int numFactors = crf.parameters.getNumFactors();
        for (int i = 0; i < this.numBatches; ++i) {
            this.cachedGradient.add(new double[numFactors]);
            this.expectations.add(new CRF.Factors(crf.parameters));
        }
        this.constraints = new CRF.Factors(crf.parameters);
        this.gatherConstraints(ilist);
    }

    /**
     * Picks the incrementor that accumulates into {@code target}: the plain one
     * for unit-weight instances, the weighted one otherwise.
     */
    private static Transducer.Incrementor newIncrementor(CRF.Factors target, double instanceWeight) {
        // NOTE(review): these constructor calls are kept exactly as the decompiler
        // rendered them; confirm the signatures against the CRF.Factors sources.
        if (instanceWeight == 1.0) {
            return new CRF.Factors.Incrementor(target);
        }
        return new CRF.Factors.WeightedIncrementor(target, instanceWeight);
    }

    /**
     * Resets {@link #constraints} and re-accumulates them by running a labeled
     * lattice over every instance of {@code ilist}.
     *
     * @param ilist instances whose data is a {@link FeatureVectorSequence} and whose
     *              target is a {@link FeatureSequence}
     */
    protected void gatherConstraints(InstanceList ilist) {
        logger.info("Gathering constraints...");
        assert (this.constraints.structureMatches(this.crf.parameters));
        this.constraints.zero();
        for (Instance instance : ilist) {
            FeatureVectorSequence input = (FeatureVectorSequence) instance.getData();
            FeatureSequence output = (FeatureSequence) instance.getTarget();
            double instanceWeight = ilist.getInstanceWeight(instance);
            Transducer.Incrementor incrementor = newIncrementor(this.constraints, instanceWeight);
            // The lattice is built only for its side effect on the incrementor.
            new SumLatticeDefault(this.crf, input, output, incrementor);
        }
        this.constraints.assertNotNaNOrInfinite();
    }

    /**
     * Zeroes the expectations of the given batch, then accumulates them while
     * summing {@code instanceWeight * (labeledWeight - unlabeledWeight)} over the
     * instances in {@code [batchAssignments[0], batchAssignments[1])}.
     *
     * <p>Instances whose weight difference is infinite are left out of the sum
     * (their expectations are still accumulated) and reported in one warning.
     *
     * @param batchIndex       index of the batch whose expectations are refreshed
     * @param batchAssignments {@code {startInclusive, endExclusive}} indices into the training set
     * @return the batch value, without any prior term
     */
    protected double getExpectationValue(int batchIndex, int[] batchAssignments) {
        CRF.Factors batchExpectations = this.expectations.get(batchIndex);
        batchExpectations.zero();
        int numInfLabeledWeight = 0;
        int numInfUnlabeledWeight = 0;
        int numInfWeight = 0;
        double value = 0.0;
        for (int ii = batchAssignments[0]; ii < batchAssignments[1]; ++ii) {
            Instance instance = (Instance) this.trainingSet.get(ii);
            double instanceWeight = this.trainingSet.getInstanceWeight(instance);
            FeatureVectorSequence input = (FeatureVectorSequence) instance.getData();
            FeatureSequence output = (FeatureSequence) instance.getTarget();

            // Labeled pass: no incrementor, only the total weight is needed.
            double labeledWeight =
                    new SumLatticeDefault(this.crf, input, output, (Transducer.Incrementor) null).getTotalWeight();
            if (Double.isInfinite(labeledWeight)) {
                ++numInfLabeledWeight;
            }

            // Unlabeled pass: accumulates this batch's expectations.
            Transducer.Incrementor incrementor = newIncrementor(batchExpectations, instanceWeight);
            double unlabeledWeight =
                    new SumLatticeDefault(this.crf, input, null, incrementor).getTotalWeight();
            if (Double.isInfinite(unlabeledWeight)) {
                ++numInfUnlabeledWeight;
            }

            double weight = labeledWeight - unlabeledWeight;
            if (Double.isInfinite(weight)) {
                ++numInfWeight;
                continue;
            }
            value += weight * instanceWeight;
        }
        batchExpectations.assertNotNaNOrInfinite();
        if (numInfLabeledWeight > 0 || numInfUnlabeledWeight > 0 || numInfWeight > 0) {
            logger.warning("Batch: " + batchIndex + ", Number of instances with:\n\t -infinite labeled weight: " + numInfLabeledWeight + "\n\t -infinite unlabeled weight: " + numInfUnlabeledWeight + "\n\t -infinite weight: " + numInfWeight);
        }
        return value;
    }

    /**
     * Computes and caches the value of one batch. The prior term is added to the
     * last batch only.
     *
     * @param batchIndex       batch to evaluate, in {@code [0, numBatches)}
     * @param batchAssignments {@code {startInclusive, endExclusive}} indices into the training set
     * @return the batch value
     */
    @Override
    public double getBatchValue(int batchIndex, int[] batchAssignments) {
        assert (batchIndex >= 0 && batchIndex < this.numBatches) : "Incorrect batch index: " + batchIndex + ", range(0, " + this.numBatches + ")";
        assert (batchAssignments.length == 2 && batchAssignments[0] <= batchAssignments[1]) : "Invalid batch assignments: " + Arrays.toString(batchAssignments);
        double value = this.getExpectationValue(batchIndex, batchAssignments);
        if (batchIndex == this.numBatches - 1) {
            if (this.usingHyperbolicPrior) {
                value += this.crf.parameters.hyberbolicPrior(this.hyperbolicPriorSlope, this.hyperbolicPriorSharpness);
            } else {
                value += this.crf.parameters.gaussianPrior(this.gaussianPriorVariance);
            }
        }
        assert (!Double.isNaN(value) && !Double.isInfinite(value)) : "Label likelihood is NaN/Infinite, batchIndex: " + batchIndex + ", batchAssignments: " + Arrays.toString(batchAssignments);
        this.cachedValue[batchIndex] = value;
        return value;
    }

    /**
     * Copies the gradient contribution of one batch into {@code buffer}.
     *
     * <p>For the last batch the stored expectations are modified in place: the
     * constraints are subtracted and the prior gradient is added. They are only
     * reset by the next {@link #getBatchValue(int, int[])} call for that batch, so
     * calling this method twice in a row for the last batch applies those terms
     * twice.
     *
     * @param buffer           destination, at least {@code getNumParameters()} long
     * @param batchIndex       batch whose gradient is requested, in {@code [0, numBatches)}
     * @param batchAssignments {@code {startInclusive, endExclusive}}; only validated here
     */
    @Override
    public void getBatchValueGradient(double[] buffer, int batchIndex, int[] batchAssignments) {
        assert (batchIndex >= 0 && batchIndex < this.numBatches) : "Incorrect batch index: " + batchIndex + ", range(0, " + this.numBatches + ")";
        assert (batchAssignments.length == 2 && batchAssignments[0] <= batchAssignments[1]) : "Invalid batch assignments: " + Arrays.toString(batchAssignments);
        CRF.Factors batchExpectations = this.expectations.get(batchIndex);
        if (batchIndex == this.numBatches - 1) {
            this.crf.parameters.assertNotNaN();
            batchExpectations.plusEquals(this.constraints, -1.0);
            if (this.usingHyperbolicPrior) {
                batchExpectations.plusEqualsHyperbolicPriorGradient(this.crf.parameters, -this.hyperbolicPriorSlope, this.hyperbolicPriorSharpness);
            } else {
                batchExpectations.plusEqualsGaussianPriorGradient(this.crf.parameters, -this.gaussianPriorVariance);
            }
            batchExpectations.assertNotNaNOrInfinite();
        }
        double[] gradient = this.cachedGradient.get(batchIndex);
        batchExpectations.getParameters(gradient);
        System.arraycopy(gradient, 0, buffer, 0, gradient.length);
    }

    /**
     * Overwrites {@code buffer} with the negated element-wise sum of all batch gradients.
     *
     * @param batchGradients gradients previously produced per batch
     * @param buffer         destination of length {@code getNumParameters()}
     */
    @Override
    public void combineGradients(Collection<double[]> batchGradients, double[] buffer) {
        assert (buffer.length == this.crf.parameters.getNumFactors()) : "Incorrect buffer length: " + buffer.length + ", expected: " + this.crf.parameters.getNumFactors();
        Arrays.fill(buffer, 0.0);
        for (double[] gradient : batchGradients) {
            MatrixOps.plusEquals(buffer, gradient);
        }
        MatrixOps.timesEquals(buffer, -1.0);
    }

    /** @return the number of batches this objective was created with */
    @Override
    public int getNumBatches() {
        return this.numBatches;
    }

    /** Selects the hyperbolic prior ({@code true}) or the Gaussian prior ({@code false}). */
    public void setUseHyperbolicPrior(boolean f) {
        this.usingHyperbolicPrior = f;
    }

    /** Sets the slope passed to the hyperbolic prior. */
    public void setHyperbolicPriorSlope(double p) {
        this.hyperbolicPriorSlope = p;
    }

    /** Sets the sharpness passed to the hyperbolic prior. */
    public void setHyperbolicPriorSharpness(double p) {
        this.hyperbolicPriorSharpness = p;
    }

    /** @return the hyperbolic prior slope (the method name is kept for compatibility) */
    public double getUseHyperbolicPriorSlope() {
        return this.hyperbolicPriorSlope;
    }

    /** @return the hyperbolic prior sharpness (the method name is kept for compatibility) */
    public double getUseHyperbolicPriorSharpness() {
        return this.hyperbolicPriorSharpness;
    }

    /** Sets the variance passed to the Gaussian prior. */
    public void setGaussianPriorVariance(double p) {
        this.gaussianPriorVariance = p;
    }

    /** @return the variance passed to the Gaussian prior */
    public double getGaussianPriorVariance() {
        return this.gaussianPriorVariance;
    }

    /** @return the number of factors in the CRF parameters */
    @Override
    public int getNumParameters() {
        return this.crf.parameters.getNumFactors();
    }

    /** Copies the CRF parameters into {@code buffer}. */
    @Override
    public void getParameters(double[] buffer) {
        this.crf.parameters.getParameters(buffer);
    }

    /** @return the CRF parameter at {@code index} */
    @Override
    public double getParameter(int index) {
        return this.crf.parameters.getParameter(index);
    }

    /** Replaces all CRF parameters and notifies the CRF that its weights changed. */
    @Override
    public void setParameters(double[] buff) {
        this.crf.parameters.setParameters(buff);
        this.crf.weightsValueChanged();
    }

    /** Replaces one CRF parameter and notifies the CRF that its weights changed. */
    @Override
    public void setParameter(int index, double value) {
        this.crf.parameters.setParameter(index, value);
        this.crf.weightsValueChanged();
    }

    /**
     * Writes the version, training set, CRF, batch count, cached values, one
     * cached gradient per batch and finally the prior settings.
     */
    private void writeObject(ObjectOutputStream out) throws IOException {
        out.writeInt(CURRENT_SERIAL_VERSION);
        out.writeObject(this.trainingSet);
        out.writeObject(this.crf);
        out.writeInt(this.numBatches);
        out.writeObject(this.cachedValue);
        for (double[] gradient : this.cachedGradient) {
            out.writeObject(gradient);
        }
        // Appended in stream version 1.
        out.writeBoolean(this.usingHyperbolicPrior);
        out.writeDouble(this.gaussianPriorVariance);
        out.writeDouble(this.hyperbolicPriorSlope);
        out.writeDouble(this.hyperbolicPriorSharpness);
    }

    /**
     * Restores the state written by {@link #writeObject(ObjectOutputStream)} and
     * rebuilds the expectations and constraints, which are not part of the stream.
     * Version 0 streams carry no prior settings, so the defaults are used for them.
     */
    private void readObject(ObjectInputStream in) throws IOException, ClassNotFoundException {
        int version = in.readInt();
        this.trainingSet = (InstanceList) in.readObject();
        this.crf = (CRF) in.readObject();
        this.numBatches = in.readInt();
        this.cachedValue = (double[]) in.readObject();
        this.cachedGradient = new ArrayList<double[]>(this.numBatches);
        for (int i = 0; i < this.numBatches; ++i) {
            // The list starts empty, so elements must be appended; set(i, ...) would
            // throw IndexOutOfBoundsException.
            this.cachedGradient.add((double[]) in.readObject());
        }

        // Field initializers do not run during deserialization, so the prior
        // settings have to be assigned explicitly on every path.
        if (version >= 1) {
            this.usingHyperbolicPrior = in.readBoolean();
            this.gaussianPriorVariance = in.readDouble();
            this.hyperbolicPriorSlope = in.readDouble();
            this.hyperbolicPriorSharpness = in.readDouble();
        } else {
            this.usingHyperbolicPrior = false;
            this.gaussianPriorVariance = DEFAULT_GAUSSIAN_PRIOR_VARIANCE;
            this.hyperbolicPriorSlope = DEFAULT_HYPERBOLIC_PRIOR_SLOPE;
            this.hyperbolicPriorSharpness = DEFAULT_HYPERBOLIC_PRIOR_SHARPNESS;
        }

        // Recreate the state the constructor would have built.
        this.expectations = new ArrayList<CRF.Factors>(this.numBatches);
        for (int i = 0; i < this.numBatches; ++i) {
            this.expectations.add(new CRF.Factors(this.crf.parameters));
        }
        this.constraints = new CRF.Factors(this.crf.parameters);
        this.gatherConstraints(this.trainingSet);
    }

    /** Factory producing {@link CRFOptimizableByBatchLabelLikelihood} instances. */
    public static class Factory {
        /**
         * @param crf          model to optimize
         * @param trainingData training instances
         * @param numBatches   number of batches
         * @return a new batch label-likelihood objective
         */
        public Optimizable.ByCombiningBatchGradient newCRFOptimizable(CRF crf, InstanceList trainingData, int numBatches) {
            return new CRFOptimizableByBatchLabelLikelihood(crf, trainingData, numBatches);
        }
    }
}

