/*
 * Decompiled with CFR 0.152.
 */
package cc.mallet.classify;

import cc.mallet.classify.Boostable;
import cc.mallet.classify.ClassifierTrainer;
import cc.mallet.classify.MCMaxEnt;
import cc.mallet.optimize.LimitedMemoryBFGS;
import cc.mallet.optimize.Optimizable;
import cc.mallet.types.Alphabet;
import cc.mallet.types.FeatureSelection;
import cc.mallet.types.FeatureVector;
import cc.mallet.types.Instance;
import cc.mallet.types.InstanceList;
import cc.mallet.types.LabelAlphabet;
import cc.mallet.types.Labeling;
import cc.mallet.types.MatrixOps;
import cc.mallet.util.CommandOption;
import cc.mallet.util.MalletLogger;
import cc.mallet.util.MalletProgressMessageLogger;
import cc.mallet.util.Maths;
import java.io.Serializable;
import java.util.Arrays;
import java.util.Iterator;
import java.util.logging.Logger;

/**
 * Trainer for {@link MCMaxEnt} ("multi-conditional" maximum-entropy) classifiers.
 *
 * <p>{@link #train(InstanceList)} maximizes, with {@link LimitedMemoryBFGS}, the weighted objective
 * <pre>
 *     sum_i w_i * log score(y_i | x_i)                 (conditional term, always)
 *   + sum_i w_i * sum_f x_if * log theta(f | y_i)      (generative term, multi-conditional only)
 *   - prior(parameters)                                (Gaussian or hyperbolic)
 * </pre>
 * where {@code score} comes from {@link MCMaxEnt#getClassificationScores} (presumably normalized
 * class probabilities — its logarithm is taken) and {@code theta(. | y)} is the softmax of
 * parameter row {@code y}. The hyperbolic prior is supported for the objective value only;
 * requesting its gradient throws {@link UnsupportedOperationException}.
 *
 * <p>Parameters form a row-major {@code numLabels x numFeatures} matrix; the last column
 * ({@code numFeatures - 1}) is an always-on default feature that receives weight 1 per instance.
 */
public class MCMaxEntTrainer extends ClassifierTrainer<MCMaxEnt> implements Boostable, Serializable {
    private static final Logger logger = MalletLogger.getLogger(MCMaxEntTrainer.class.getName());
    private static final Logger progressLogger =
            MalletProgressMessageLogger.getLogger(MCMaxEntTrainer.class.getName() + "-pl");

    /** Number of non-cached objective evaluations performed by the current optimizable. */
    int numGetValueCalls = 0;
    /** Number of non-cached gradient evaluations performed by the current optimizable. */
    int numGetValueGradientCalls = 0;
    /**
     * Iteration budget shown by {@link #toString()}. NOTE(review): {@link #train} does not read
     * this field; it calls {@code optimize()} with no iteration limit.
     */
    int numIterations = 10;

    public static final String EXP_GAIN = "exp";
    public static final String GRADIENT_GAIN = "grad";
    public static final String INFORMATION_GAIN = "info";
    static final double DEFAULT_GAUSSIAN_PRIOR_VARIANCE = 0.1;
    static final double DEFAULT_HYPERBOLIC_PRIOR_SLOPE = 0.2;
    static final double DEFAULT_HYPERBOLIC_PRIOR_SHARPNESS = 10.0;
    static final Class<?> DEFAULT_MAXIMIZER_CLASS = LimitedMemoryBFGS.class;

    /** When true, the generative term {@code log p(x | y)} is added to the objective. */
    boolean usingMultiConditionalTraining = true;
    /** Selects the hyperbolic prior instead of the Gaussian prior. */
    boolean usingHyperbolicPrior = false;
    double gaussianPriorVariance = DEFAULT_GAUSSIAN_PRIOR_VARIANCE;
    double hyperbolicPriorSlope = DEFAULT_HYPERBOLIC_PRIOR_SLOPE;
    double hyperbolicPriorSharpness = DEFAULT_HYPERBOLIC_PRIOR_SHARPNESS;
    /** NOTE(review): not read anywhere in this class; {@link #train} always uses L-BFGS. */
    Class<?> maximizerClass = DEFAULT_MAXIMIZER_CLASS;
    /** NOTE(review): not read anywhere in this class. */
    double generativeWeighting = 1.0;
    /**
     * Optimizable built by the most recent {@link #train} call. NOTE(review): this inner class is
     * not {@link Serializable}, so serializing a trainer after training will fail.
     */
    MaximizableTrainer mt;
    /** Optional classifier whose parameters seed (and are updated by) training. */
    MCMaxEnt initialClassifier;

    static CommandOption.Boolean usingMultiConditionalTrainingOption = new CommandOption.Boolean(MCMaxEntTrainer.class, "useMCTraining", "true|false", true, true, "Use MultiConditional Training", null);
    static CommandOption.Boolean usingHyperbolicPriorOption = new CommandOption.Boolean(MCMaxEntTrainer.class, "useHyperbolicPrior", "true|false", false, false, "Use hyperbolic (close to L1 penalty) prior over parameters", null);
    // NOTE(review): this option defaults to 10.0, whereas DEFAULT_GAUSSIAN_PRIOR_VARIANCE is 0.1.
    static CommandOption.Double gaussianPriorVarianceOption = new CommandOption.Double(MCMaxEntTrainer.class, "gaussianPriorVariance", "FLOAT", true, 10.0, "Variance of the gaussian prior over parameters", null);
    static CommandOption.Double hyperbolicPriorSlopeOption = new CommandOption.Double(MCMaxEntTrainer.class, "hyperbolicPriorSlope", "FLOAT", true, 0.2, "Slope of the (L1 penalty) hyperbolic prior over parameters", null);
    static CommandOption.Double hyperbolicPriorSharpnessOption = new CommandOption.Double(MCMaxEntTrainer.class, "hyperbolicPriorSharpness", "FLOAT", true, 10.0, "Sharpness of the (L1 penalty) hyperbolic prior over parameters", null);
    static final CommandOption.List commandOptions = new CommandOption.List("MCMaximum Entropy Classifier", new CommandOption[]{usingHyperbolicPriorOption, gaussianPriorVarianceOption, hyperbolicPriorSlopeOption, hyperbolicPriorSharpnessOption, usingMultiConditionalTrainingOption});

    /** Returns the command-line options understood by this trainer. */
    public static CommandOption.List getCommandOptionList() {
        return commandOptions;
    }

    /**
     * Configures the trainer from the static command-line options of this class.
     *
     * @param col ignored; the values are read from the static option fields of this class,
     *            not from {@code col}
     */
    public MCMaxEntTrainer(CommandOption.List col) {
        this.usingHyperbolicPrior = usingHyperbolicPriorOption.value;
        this.gaussianPriorVariance = gaussianPriorVarianceOption.value;
        this.hyperbolicPriorSlope = hyperbolicPriorSlopeOption.value;
        this.hyperbolicPriorSharpness = hyperbolicPriorSharpnessOption.value;
        this.usingMultiConditionalTraining = usingMultiConditionalTrainingOption.value;
    }

    /**
     * Creates a trainer that continues from {@code initialClassifier}; its parameter array is
     * optimized in place.
     */
    public MCMaxEntTrainer(MCMaxEnt initialClassifier) {
        this.initialClassifier = initialClassifier;
    }

    /** Creates a trainer with a Gaussian prior and default settings. */
    public MCMaxEntTrainer() {
        this(false);
    }

    /** Creates a trainer using the hyperbolic prior if {@code useHyperbolicPrior}, else Gaussian. */
    public MCMaxEntTrainer(boolean useHyperbolicPrior) {
        this.usingHyperbolicPrior = useHyperbolicPrior;
    }

    /** Creates a trainer with a Gaussian prior of the given variance. */
    public MCMaxEntTrainer(double gaussianPriorVariance) {
        this.usingHyperbolicPrior = false;
        this.gaussianPriorVariance = gaussianPriorVariance;
    }

    /**
     * Creates a trainer with a Gaussian prior of the given variance.
     *
     * @param useMultiConditionalTraining whether to include the generative {@code log p(x | y)} term
     */
    public MCMaxEntTrainer(double gaussianPriorVariance, boolean useMultiConditionalTraining) {
        this.usingHyperbolicPrior = false;
        this.usingMultiConditionalTraining = useMultiConditionalTraining;
        this.gaussianPriorVariance = gaussianPriorVariance;
    }

    /** Creates a trainer with a hyperbolic prior of the given slope and sharpness. */
    public MCMaxEntTrainer(double hyperbolicPriorSlope, double hyperbolicPriorSharpness) {
        this.usingHyperbolicPrior = true;
        this.hyperbolicPriorSlope = hyperbolicPriorSlope;
        this.hyperbolicPriorSharpness = hyperbolicPriorSharpness;
    }

    /**
     * Builds the optimizable objective for {@code ilist}, or an empty, unusable one when
     * {@code ilist} is null.
     */
    public Optimizable.ByGradientValue getMaximizableTrainer(InstanceList ilist) {
        return ilist == null ? new MaximizableTrainer() : new MaximizableTrainer(ilist, null);
    }

    /** Sets {@link #numIterations} (currently only reported by {@link #toString()}). */
    public MCMaxEntTrainer setNumIterations(int i) {
        this.numIterations = i;
        return this;
    }

    public MCMaxEntTrainer setUseHyperbolicPrior(boolean useHyperbolicPrior) {
        this.usingHyperbolicPrior = useHyperbolicPrior;
        return this;
    }

    /** Sets the Gaussian prior variance and switches to the Gaussian prior. */
    public MCMaxEntTrainer setGaussianPriorVariance(double gaussianPriorVariance) {
        this.usingHyperbolicPrior = false;
        this.gaussianPriorVariance = gaussianPriorVariance;
        return this;
    }

    /** Sets the hyperbolic prior slope and switches to the hyperbolic prior. */
    public MCMaxEntTrainer setHyperbolicPriorSlope(double hyperbolicPriorSlope) {
        this.usingHyperbolicPrior = true;
        this.hyperbolicPriorSlope = hyperbolicPriorSlope;
        return this;
    }

    /** Sets the hyperbolic prior sharpness and switches to the hyperbolic prior. */
    public MCMaxEntTrainer setHyperbolicPriorSharpness(double hyperbolicPriorSharpness) {
        this.usingHyperbolicPrior = true;
        this.hyperbolicPriorSharpness = hyperbolicPriorSharpness;
        return this;
    }

    /**
     * Returns the classifier produced by the last {@link #train} call. Before any training, it
     * returns the initial classifier (which may be null) instead of throwing.
     */
    @Override
    public MCMaxEnt getClassifier() {
        return this.mt != null ? this.mt.getClassifier() : this.initialClassifier;
    }

    /**
     * Trains a classifier on {@code trainingSet} with L-BFGS and returns it.
     *
     * <p>Side effects: stops growth of the training set's label alphabet and adds the default
     * feature index to the training set's feature selection(s), if any.
     */
    @Override
    public MCMaxEnt train(InstanceList trainingSet) {
        logger.fine("trainingSet.size() = " + trainingSet.size());
        this.mt = new MaximizableTrainer(trainingSet, this.initialClassifier);
        LimitedMemoryBFGS maximizer = new LimitedMemoryBFGS(this.mt);
        maximizer.setTolerance(1.0E-5);
        maximizer.optimize();
        logger.info("MCMaxEnt ngetValueCalls:" + this.getValueCalls() + "\nMCMaxEnt ngetValueGradientCalls:" + this.getValueGradientCalls());
        progressLogger.info("\n");
        return this.mt.getClassifier();
    }

    /** Returns how many gradients were computed (cache misses only). */
    public int getValueGradientCalls() {
        return this.numGetValueGradientCalls;
    }

    /** Returns how many objective values were computed (cache misses only). */
    public int getValueCalls() {
        return this.numGetValueCalls;
    }

    @Override
    public String toString() {
        return "MCMaxEntTrainer,numIterations=" + this.numIterations + (this.usingHyperbolicPrior ? ",hyperbolicPriorSlope=" + this.hyperbolicPriorSlope + ",hyperbolicPriorSharpness=" + this.hyperbolicPriorSharpness : ",gaussianPriorVariance=" + this.gaussianPriorVariance);
    }

    /**
     * The multi-conditional objective and its gradient over one training list, exposed to the
     * optimizer. Value and gradient are cached until the parameters change.
     */
    private class MaximizableTrainer implements Optimizable.ByGradientValue {
        /** Row-major numLabels x numFeatures weights; shared with {@link #theClassifier}. */
        double[] parameters;
        /** Empirical feature counts (gradient constraint term), same layout as parameters. */
        double[] constraints;
        double[] cachedGradient;
        MCMaxEnt theClassifier;
        InstanceList trainingList;
        double cachedValue;
        boolean cachedValueStale;
        boolean cachedGradientStale;
        int numLabels;
        int numFeatures;
        int defaultFeatureIndex;
        FeatureSelection featureSelection;
        FeatureSelection[] perLabelFeatureSelection;
        /**
         * Snapshot of the trainer's multi-conditional flag. It is taken once so that the
         * constraints built in the constructor and the terms added in getValue() always agree.
         */
        final boolean multiConditional = MCMaxEntTrainer.this.usingMultiConditionalTraining;

        /** Empty optimizable with no data; only useful as a placeholder. */
        public MaximizableTrainer() {
        }

        /**
         * Builds the objective for {@code ilist}, optionally starting from
         * {@code initialClassifier}'s parameters.
         */
        public MaximizableTrainer(InstanceList ilist, MCMaxEnt initialClassifier) {
            this.trainingList = ilist;
            Alphabet fd = ilist.getDataAlphabet();
            LabelAlphabet ld = (LabelAlphabet) ilist.getTargetAlphabet();
            ld.stopGrowth();
            this.numLabels = ld.size();
            // One extra column for the always-on default feature.
            this.numFeatures = fd.size() + 1;
            this.defaultFeatureIndex = this.numFeatures - 1;
            int numParameters = this.numLabels * this.numFeatures;
            // Java zero-initializes new arrays.
            this.parameters = new double[numParameters];
            this.constraints = new double[numParameters];
            this.cachedGradient = new double[numParameters];
            this.featureSelection = ilist.getFeatureSelection();
            this.perLabelFeatureSelection = ilist.getPerLabelFeatureSelection();
            // Keep the default feature trainable under any feature selection (mutates ilist's selections).
            if (this.featureSelection != null) {
                this.featureSelection.add(this.defaultFeatureIndex);
            }
            if (this.perLabelFeatureSelection != null) {
                for (FeatureSelection selection : this.perLabelFeatureSelection) {
                    selection.add(this.defaultFeatureIndex);
                }
            }
            assert this.featureSelection == null || this.perLabelFeatureSelection == null;
            if (initialClassifier != null) {
                // NOTE(review): assumes the classifier's parameter array has numLabels * numFeatures
                // entries, matching the constraint/gradient arrays allocated above.
                this.theClassifier = initialClassifier;
                this.parameters = this.theClassifier.parameters;
                this.featureSelection = this.theClassifier.featureSelection;
                this.perLabelFeatureSelection = this.theClassifier.perClassFeatureSelection;
                this.defaultFeatureIndex = this.theClassifier.defaultFeatureIndex;
                assert initialClassifier.getInstancePipe() == ilist.getPipe();
            } else {
                this.theClassifier = new MCMaxEnt(ilist.getPipe(), this.parameters, this.featureSelection, this.perLabelFeatureSelection);
            }
            this.cachedValueStale = true;
            this.cachedGradientStale = true;

            // Observed features count once for p(y|x) and, under multi-conditional training,
            // once more for p(x|y). The default feature only appears in p(y|x).
            double featureCopies = this.multiConditional ? 2.0 : 1.0;
            logger.fine("Number of instances in training list = " + this.trainingList.size());
            for (Instance inst : this.trainingList) {
                double instanceWeight = this.trainingList.getInstanceWeight(inst);
                assert !Double.isNaN(instanceWeight) : "instanceWeight is NaN";
                Labeling labeling = inst.getLabeling();
                FeatureVector fv = (FeatureVector) inst.getData();
                Alphabet fdict = fv.getAlphabet();
                assert fdict == fd;
                int li = labeling.getBestIndex();
                MatrixOps.rowPlusEquals(this.constraints, this.numFeatures, li, fv, featureCopies * instanceWeight);
                this.logNaNFeatures(inst, fv, fdict);
                this.constraints[li * this.numFeatures + this.defaultFeatureIndex] += instanceWeight;
            }
        }

        /** Logs each NaN-valued feature of {@code fv}, then the instance name if any were found. */
        private void logNaNFeatures(Instance inst, FeatureVector fv, Alphabet fdict) {
            boolean hasNaN = false;
            for (int loc = 0; loc < fv.numLocations(); loc++) {
                if (Double.isNaN(fv.valueAtLocation(loc))) {
                    logger.info("NaN for feature " + fdict.lookupObject(fv.indexAtLocation(loc)).toString());
                    hasNaN = true;
                }
            }
            if (hasNaN) {
                logger.info("NaN in instance: " + inst.getName());
            }
        }

        public MCMaxEnt getClassifier() {
            return this.theClassifier;
        }

        @Override
        public double getParameter(int index) {
            return this.parameters[index];
        }

        @Override
        public void setParameter(int index, double v) {
            this.cachedValueStale = true;
            this.cachedGradientStale = true;
            this.parameters[index] = v;
        }

        @Override
        public int getNumParameters() {
            return this.parameters.length;
        }

        /**
         * Copies the current parameters into {@code buff}.
         *
         * @throws IllegalArgumentException if {@code buff} is null or not exactly
         *         {@link #getNumParameters()} long (previously such a call silently did nothing)
         */
        @Override
        public void getParameters(double[] buff) {
            if (buff == null || buff.length != this.parameters.length) {
                throw new IllegalArgumentException("Parameter buffer must have length " + this.parameters.length
                        + " but was " + (buff == null ? "null" : String.valueOf(buff.length)));
            }
            System.arraycopy(this.parameters, 0, buff, 0, this.parameters.length);
        }

        /**
         * Overwrites the parameters (in place, so the classifier sees them) with {@code buff}.
         *
         * @throws IllegalArgumentException if {@code buff} is null or its length differs from
         *         {@link #getNumParameters()}; reallocating would detach the parameters from the
         *         classifier and from the gradient/constraint arrays
         */
        @Override
        public void setParameters(double[] buff) {
            if (buff == null || buff.length != this.parameters.length) {
                throw new IllegalArgumentException("Parameter buffer must have length " + this.parameters.length
                        + " but was " + (buff == null ? "null" : String.valueOf(buff.length)));
            }
            this.cachedValueStale = true;
            this.cachedGradientStale = true;
            System.arraycopy(buff, 0, this.parameters, 0, buff.length);
        }

        /**
         * Returns the penalized (multi-)conditional log-likelihood, recomputing it if the
         * parameters changed.
         *
         * <p>As a side effect, a recomputation also accumulates the expectation part of the
         * gradient into {@code cachedGradient}; {@link #getValueGradient} completes it.
         */
        @Override
        public double getValue() {
            if (this.cachedValueStale) {
                ++MCMaxEntTrainer.this.numGetValueCalls;
                this.cachedValue = 0.0;
                this.cachedGradientStale = true;
                Arrays.fill(this.cachedGradient, 0.0);
                double[] scores = new double[this.numLabels];
                double[][] probs = null;
                double[][] lprobs = null;
                if (this.multiConditional) {
                    probs = new double[this.numLabels][this.numFeatures];
                    lprobs = new double[this.numLabels][this.numFeatures];
                    this.computeRowSoftmax(probs, lprobs);
                }
                for (Instance instance : this.trainingList) {
                    double instanceWeight = this.trainingList.getInstanceWeight(instance);
                    Labeling labeling = instance.getLabeling();
                    this.theClassifier.getClassificationScores(instance, scores);
                    FeatureVector fv = (FeatureVector) instance.getData();
                    int li = labeling.getBestIndex();
                    double value = -(instanceWeight * Math.log(scores[li]));
                    if (Double.isNaN(value)) {
                        logger.fine("MCMaxEntTrainer: Instance " + instance.getName() + " has NaN value. log(scores)= " + Math.log(scores[li]) + " scores = " + scores[li] + " has instance weight = " + instanceWeight);
                    }
                    if (Double.isInfinite(value)) {
                        // NOTE(review): despite the message, this does not skip the instance. It
                        // aborts the pass, then caches and returns negative infinity. It leaves
                        // cachedGradient only partially accumulated, and a later getValueGradient()
                        // will use that partial gradient.
                        logger.warning("Instance " + instance.getSource() + " has infinite value; skipping value and gradient");
                        this.cachedValue -= value;
                        this.cachedValueStale = false;
                        return -value;
                    }
                    this.cachedValue += value;
                    // Conditional expectation: subtract score-weighted features for every label.
                    for (int si = 0; si < this.numLabels; si++) {
                        if (scores[si] == 0.0) {
                            continue;
                        }
                        assert !Double.isInfinite(scores[si]);
                        double expectationWeight = -instanceWeight * scores[si];
                        MatrixOps.rowPlusEquals(this.cachedGradient, this.numFeatures, si, fv, expectationWeight);
                        this.cachedGradient[si * this.numFeatures + this.defaultFeatureIndex] += expectationWeight;
                    }
                    if (!this.multiConditional) {
                        continue;
                    }
                    // Generative term log p(x | y): value uses the log-softmax of the true label's row;
                    // its expectation is the instance's total feature count times that row's softmax.
                    double nCounts = MatrixOps.sum(fv);
                    this.cachedValue -= instanceWeight * fv.dotProduct(lprobs[li]);
                    int rowStart = li * this.numFeatures;
                    for (int fi = 0; fi < this.numFeatures; fi++) {
                        this.cachedGradient[rowStart + fi] += -instanceWeight * nCounts * probs[li][fi];
                    }
                }
                this.addPriorPenalty();
                // Accumulated as a cost above; the optimizer maximizes, so flip the sign.
                this.cachedValue *= -1.0;
                this.cachedValueStale = false;
                progressLogger.info("Value (loglikelihood) = " + this.cachedValue);
            }
            return this.cachedValue;
        }

        /**
         * Fills {@code probs[y]} with the softmax of parameter row {@code y} and {@code lprobs[y]}
         * with its logarithm. Each row is shifted by its own maximum, so the row sum lies in
         * {@code [1, numFeatures]}; it cannot underflow to zero the way a global shift could.
         */
        private void computeRowSoftmax(double[][] probs, double[][] lprobs) {
            for (int li = 0; li < this.numLabels; li++) {
                int rowStart = li * this.numFeatures;
                double rowMax = Double.NEGATIVE_INFINITY;
                for (int fi = 0; fi < this.numFeatures; fi++) {
                    rowMax = Math.max(rowMax, this.parameters[rowStart + fi]);
                }
                double sum = 0.0;
                for (int fi = 0; fi < this.numFeatures; fi++) {
                    probs[li][fi] = Math.exp(this.parameters[rowStart + fi] - rowMax);
                    sum += probs[li][fi];
                }
                assert sum > 0.0;
                double logSum = Math.log(sum);
                for (int fi = 0; fi < this.numFeatures; fi++) {
                    probs[li][fi] /= sum;
                    lprobs[li][fi] = this.parameters[rowStart + fi] - rowMax - logSum;
                }
            }
        }

        /** Adds the (positive) prior penalty of all parameters to {@code cachedValue}. */
        private void addPriorPenalty() {
            int count = this.numLabels * this.numFeatures;
            if (MCMaxEntTrainer.this.usingHyperbolicPrior) {
                double slope = MCMaxEntTrainer.this.hyperbolicPriorSlope;
                double sharpness = MCMaxEntTrainer.this.hyperbolicPriorSharpness;
                for (int i = 0; i < count; i++) {
                    this.cachedValue += slope / sharpness * Math.log(Maths.cosh(sharpness * this.parameters[i]));
                }
            } else {
                double twiceVariance = 2.0 * MCMaxEntTrainer.this.gaussianPriorVariance;
                for (int i = 0; i < count; i++) {
                    double param = this.parameters[i];
                    this.cachedValue += param * param / twiceVariance;
                }
            }
        }

        /**
         * Copies the gradient of {@link #getValue()} into {@code buffer}, recomputing it if stale.
         * The gradient is zeroed for features outside the (per-label) feature selection, and any
         * {@code -Infinity} entries are replaced by 0.
         *
         * @throws IllegalArgumentException if {@code buffer} is null or has the wrong length
         * @throws UnsupportedOperationException if the hyperbolic prior is selected; this is
         *         checked before the cached gradient is modified
         */
        @Override
        public void getValueGradient(double[] buffer) {
            if (buffer == null || buffer.length != this.parameters.length) {
                throw new IllegalArgumentException("Gradient buffer must have length " + this.parameters.length
                        + " but was " + (buffer == null ? "null" : String.valueOf(buffer.length)));
            }
            if (this.cachedGradientStale) {
                if (MCMaxEntTrainer.this.usingHyperbolicPrior) {
                    throw new UnsupportedOperationException("Hyperbolic prior not yet implemented.");
                }
                ++MCMaxEntTrainer.this.numGetValueGradientCalls;
                if (this.cachedValueStale) {
                    // Fills cachedGradient with the (negated) expectations.
                    this.getValue();
                }
                MatrixOps.plusEquals(this.cachedGradient, this.constraints);
                MatrixOps.plusEquals(this.cachedGradient, this.parameters, -1.0 / MCMaxEntTrainer.this.gaussianPriorVariance);
                MatrixOps.substitute(this.cachedGradient, Double.NEGATIVE_INFINITY, 0.0);
                for (int labelIndex = 0; labelIndex < this.numLabels; labelIndex++) {
                    FeatureSelection selection = this.perLabelFeatureSelection == null
                            ? this.featureSelection
                            : this.perLabelFeatureSelection[labelIndex];
                    MatrixOps.rowSetAll(this.cachedGradient, this.numFeatures, labelIndex, 0.0, selection, false);
                }
                this.cachedGradientStale = false;
            }
            System.arraycopy(this.cachedGradient, 0, buffer, 0, this.cachedGradient.length);
        }

        /**
         * Returns {@code -log(exp(-a) + exp(-b))}, i.e. the negative log of the sum of two
         * probabilities given as negative logs. Returns +Infinity when both inputs are +Infinity.
         */
        public double sumNegLogProb(double a, double b) {
            if (a == Double.POSITIVE_INFINITY && b == Double.POSITIVE_INFINITY) {
                return Double.POSITIVE_INFINITY;
            }
            double smaller = a > b ? b : a;
            double larger = a > b ? a : b;
            // log1p keeps precision when exp(smaller - larger) is tiny.
            return smaller - Math.log1p(Math.exp(smaller - larger));
        }
    }
}

