/*
 * Decompiled with CFR 0.152.
 */
package cc.mallet.grmm.learning;

import cc.mallet.grmm.learning.ACRF;
import cc.mallet.grmm.learning.DefaultAcrfTrainer;
import cc.mallet.grmm.types.Assignment;
import cc.mallet.grmm.types.AssignmentIterator;
import cc.mallet.grmm.types.Factor;
import cc.mallet.grmm.types.Variable;
import cc.mallet.grmm.util.CachingOptimizable;
import cc.mallet.optimize.Optimizable;
import cc.mallet.types.Instance;
import cc.mallet.types.InstanceList;
import cc.mallet.types.SparseVector;
import cc.mallet.util.MalletLogger;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.PrintStream;
import java.io.Serializable;
import java.util.BitSet;
import java.util.Iterator;
import java.util.logging.Logger;

public class PiecewiseACRFTrainer
extends DefaultAcrfTrainer {
    private static final Logger logger = MalletLogger.getLogger(PiecewiseACRFTrainer.class.getName());
    private static final boolean printGradient = false;

    @Override
    public Optimizable.ByGradientValue createOptimizable(ACRF acrf, InstanceList training) {
        return new Maxable(acrf, training);
    }

    public static class Maxable
    extends CachingOptimizable.ByGradient
    implements Serializable {
        private ACRF acrf;
        InstanceList trainData;
        private ACRF.Template[] templates;
        private ACRF.Template[] fixedTmpls;
        protected BitSet infiniteValues = null;
        private int numParameters;
        private static final double DEFAULT_GAUSSIAN_PRIOR_VARIANCE = 10.0;
        private double gaussianPriorVariance = 10.0;
        SparseVector[][] constraints;
        SparseVector[][] expectations;
        SparseVector[] defaultConstraints;
        SparseVector[] defaultExpectations;
        int numInBatch = 0;

        public double getGaussianPriorVariance() {
            return this.gaussianPriorVariance;
        }

        public void setGaussianPriorVariance(double gaussianPriorVariance) {
            this.gaussianPriorVariance = gaussianPriorVariance;
        }

        private void initWeights(InstanceList training) {
            int tidx = 0;
            while (tidx < this.templates.length) {
                this.numParameters += this.templates[tidx].initWeights(training);
                ++tidx;
            }
        }

        private void initConstraintsExpectations() {
            this.defaultConstraints = new SparseVector[this.templates.length];
            this.defaultExpectations = new SparseVector[this.templates.length];
            int tidx = 0;
            while (tidx < this.templates.length) {
                SparseVector defaults = this.templates[tidx].getDefaultWeights();
                this.defaultConstraints[tidx] = (SparseVector)defaults.cloneMatrixZeroed();
                this.defaultExpectations[tidx] = (SparseVector)defaults.cloneMatrixZeroed();
                ++tidx;
            }
            this.constraints = new SparseVector[this.templates.length][];
            this.expectations = new SparseVector[this.templates.length][];
            tidx = 0;
            while (tidx < this.templates.length) {
                ACRF.Template tmpl = this.templates[tidx];
                SparseVector[] weights = tmpl.getWeights();
                this.constraints[tidx] = new SparseVector[weights.length];
                this.expectations[tidx] = new SparseVector[weights.length];
                int i = 0;
                while (i < weights.length) {
                    this.constraints[tidx][i] = (SparseVector)weights[i].cloneMatrixZeroed();
                    this.expectations[tidx][i] = (SparseVector)weights[i].cloneMatrixZeroed();
                    ++i;
                }
                ++tidx;
            }
        }

        void resetExpectations() {
            int tidx = 0;
            while (tidx < this.expectations.length) {
                this.defaultExpectations[tidx].setAll(0.0);
                int i = 0;
                while (i < this.expectations[tidx].length) {
                    this.expectations[tidx][i].setAll(0.0);
                    ++i;
                }
                ++tidx;
            }
        }

        void resetConstraints() {
            int tidx = 0;
            while (tidx < this.constraints.length) {
                this.defaultConstraints[tidx].setAll(0.0);
                int i = 0;
                while (i < this.constraints[tidx].length) {
                    this.constraints[tidx][i].setAll(0.0);
                    ++i;
                }
                ++tidx;
            }
        }

        protected Maxable(ACRF acrf, InstanceList ilist) {
            logger.finest("Initializing OptimizableACRF.");
            this.acrf = acrf;
            this.templates = acrf.getTemplates();
            this.fixedTmpls = acrf.getFixedTemplates();
            this.trainData = ilist;
            this.initWeights(this.trainData);
            this.initConstraintsExpectations();
            int numInstances = this.trainData.size();
            this.cachedGradientStale = true;
            this.cachedValueStale = true;
            logger.info("Number of training instances = " + numInstances);
            logger.info("Number of parameters = " + this.numParameters);
            this.describePrior();
            logger.fine("Computing constraints");
            this.collectConstraints(this.trainData);
        }

        private void describePrior() {
            logger.info("Using gaussian prior with variance " + this.gaussianPriorVariance);
        }

        @Override
        public int getNumParameters() {
            return this.numParameters;
        }

        @Override
        public void getParameters(double[] buf) {
            ACRF.Template tmpl;
            if (buf.length != this.numParameters) {
                throw new IllegalArgumentException("Argument is not of the  correct dimensions");
            }
            int idx = 0;
            int tidx = 0;
            while (tidx < this.templates.length) {
                tmpl = this.templates[tidx];
                SparseVector defaults = tmpl.getDefaultWeights();
                double[] values = defaults.getValues();
                System.arraycopy(values, 0, buf, idx, values.length);
                idx += values.length;
                ++tidx;
            }
            tidx = 0;
            while (tidx < this.templates.length) {
                tmpl = this.templates[tidx];
                SparseVector[] weights = tmpl.getWeights();
                int assn = 0;
                while (assn < weights.length) {
                    double[] values = weights[assn].getValues();
                    System.arraycopy(values, 0, buf, idx, values.length);
                    idx += values.length;
                    ++assn;
                }
                ++tidx;
            }
        }

        @Override
        protected void setParametersInternal(double[] params) {
            ACRF.Template tmpl;
            this.cachedGradientStale = true;
            this.cachedValueStale = true;
            int idx = 0;
            int tidx = 0;
            while (tidx < this.templates.length) {
                tmpl = this.templates[tidx];
                SparseVector defaults = tmpl.getDefaultWeights();
                double[] values = defaults.getValues();
                System.arraycopy(params, idx, values, 0, values.length);
                idx += values.length;
                ++tidx;
            }
            tidx = 0;
            while (tidx < this.templates.length) {
                tmpl = this.templates[tidx];
                SparseVector[] weights = tmpl.getWeights();
                int assn = 0;
                while (assn < weights.length) {
                    double[] values = weights[assn].getValues();
                    System.arraycopy(params, idx, values, 0, values.length);
                    idx += values.length;
                    ++assn;
                }
                ++tidx;
            }
        }

        public SparseVector[] getExpectations(int cnum) {
            return this.expectations[cnum];
        }

        public SparseVector[] getConstraints(int cnum) {
            return this.constraints[cnum];
        }

        public void printParameters() {
            double[] buf = new double[this.numParameters];
            this.getParameters(buf);
            int len = buf.length;
            int w = 0;
            while (w < len) {
                System.out.print(String.valueOf(buf[w]) + "\t");
                ++w;
            }
            System.out.println();
        }

        @Override
        protected double computeValue() {
            double retval = 0.0;
            int numInstances = this.trainData.size();
            long start = System.currentTimeMillis();
            long unrollTime = 0L;
            boolean initializingInfiniteValues = false;
            if (this.infiniteValues == null) {
                this.infiniteValues = new BitSet();
                initializingInfiniteValues = true;
            }
            this.resetExpectations();
            int i = 0;
            while (i < numInstances) {
                retval += this.computeValueForInstance(i);
                ++i;
            }
            long end = System.currentTimeMillis();
            logger.info("ACRF Inference time (ms) = " + (end - start));
            logger.info("ACRF unroll time (ms) = " + unrollTime);
            logger.info("getValue (loglikelihood) = " + (retval += this.computePrior()));
            return retval;
        }

        private double computePrior() {
            double retval = 0.0;
            double priorDenom = 2.0 * this.gaussianPriorVariance;
            int tidx = 0;
            while (tidx < this.templates.length) {
                SparseVector[] weights = this.templates[tidx].getWeights();
                int j = 0;
                while (j < weights.length) {
                    int fnum = 0;
                    while (fnum < weights[j].numLocations()) {
                        double w = weights[j].valueAtLocation(fnum);
                        if (this.weightValid(w, tidx, j)) {
                            retval += -w * w / priorDenom;
                        }
                        ++fnum;
                    }
                    ++j;
                }
                ++tidx;
            }
            return retval;
        }

        private double computeValueForInstance(int i) {
            double retval = 0.0;
            Instance instance = (Instance)this.trainData.get(i);
            ACRF.UnrolledGraph unrolled = new ACRF.UnrolledGraph(instance, this.templates, this.fixedTmpls);
            if (unrolled.numVariables() == 0) {
                return 0.0;
            }
            Assignment observations = unrolled.getAssignment();
            double value = this.collectExpectationsAndValue(unrolled, observations);
            if (Double.isNaN(value)) {
                System.out.println("NaN on instance " + i + " : " + instance.getName());
                this.printDebugInfo(unrolled);
                logger.warning("Value is NaN in ACRF.getValue() Instance " + i + " : " + "returning -infinity... ");
                return Double.NEGATIVE_INFINITY;
            }
            return retval += value;
        }

        @Override
        protected void computeValueGradient(double[] grad) {
            this.computeValueGradient(grad, 1.0);
        }

        private void computeValueGradient(double[] grad, double priorScale) {
            int gidx = 0;
            int tidx = 0;
            while (tidx < this.templates.length) {
                SparseVector theseWeights = this.templates[tidx].getDefaultWeights();
                SparseVector theseConstraints = this.defaultConstraints[tidx];
                SparseVector theseExpectations = this.defaultExpectations[tidx];
                int j = 0;
                while (j < theseWeights.numLocations()) {
                    double weight = theseWeights.valueAtLocation(j);
                    double constraint = theseConstraints.valueAtLocation(j);
                    double expectation = theseExpectations.valueAtLocation(j);
                    grad[gidx++] = constraint - expectation - priorScale * (weight / this.gaussianPriorVariance);
                    ++j;
                }
                ++tidx;
            }
            tidx = 0;
            while (tidx < this.templates.length) {
                ACRF.Template tmpl = this.templates[tidx];
                SparseVector[] weights = tmpl.getWeights();
                int i = 0;
                while (i < weights.length) {
                    SparseVector thisWeightVec = weights[i];
                    SparseVector thisConstraintVec = this.constraints[tidx][i];
                    SparseVector thisExpectationVec = this.expectations[tidx][i];
                    int j = 0;
                    while (j < thisWeightVec.numLocations()) {
                        double gradient;
                        double w = thisWeightVec.valueAtLocation(j);
                        double constraint = thisConstraintVec.valueAtLocation(j);
                        double expectation = thisExpectationVec.valueAtLocation(j);
                        if (Double.isInfinite(w)) {
                            logger.warning("Infinite weight for node index " + i + " feature " + this.acrf.getInputAlphabet().lookupObject(j));
                            gradient = 0.0;
                        } else {
                            gradient = constraint - priorScale * (w / this.gaussianPriorVariance) - expectation;
                        }
                        grad[gidx++] = gradient;
                        ++j;
                    }
                    ++i;
                }
                ++tidx;
            }
        }

        private double collectExpectationsAndValue(ACRF.UnrolledGraph unrolled, Assignment observations) {
            double value = 0.0;
            Iterator it = unrolled.unrolledVarSetIterator();
            while (it.hasNext()) {
                ACRF.UnrolledVarSet clique = (ACRF.UnrolledVarSet)it.next();
                int tidx = clique.getTemplate().index;
                if (tidx == -1) continue;
                Factor ptl = unrolled.factorOf(clique);
                double logZ = Math.log(ptl.sum());
                AssignmentIterator assnIt = clique.assignmentIterator();
                int i = 0;
                while (assnIt.hasNext()) {
                    double marginal = Math.exp(ptl.logValue(assnIt) - logZ);
                    this.expectations[tidx][i].plusEqualsSparse(clique.getFv(), marginal);
                    if (this.defaultExpectations[tidx].location(i) != -1) {
                        this.defaultExpectations[tidx].incrementValue(i, marginal);
                    }
                    assnIt.advance();
                    ++i;
                }
                value += ptl.logValue(observations) - logZ;
            }
            return value;
        }

        public void collectConstraints(InstanceList ilist) {
            int inum = 0;
            while (inum < ilist.size()) {
                logger.finest("*** Collecting constraints for instance " + inum);
                this.collectConstraintsForInstance(ilist, inum);
                ++inum;
            }
        }

        private void collectConstraintsForInstance(InstanceList ilist, int inum) {
            Instance inst = (Instance)ilist.get(inum);
            ACRF.UnrolledGraph unrolled = new ACRF.UnrolledGraph(inst, this.templates, null, false);
            Iterator it = unrolled.unrolledVarSetIterator();
            while (it.hasNext()) {
                ACRF.UnrolledVarSet clique = (ACRF.UnrolledVarSet)it.next();
                int tidx = clique.getTemplate().index;
                if (tidx == -1) continue;
                int assn = clique.lookupAssignmentNumber();
                this.constraints[tidx][assn].plusEqualsSparse(clique.getFv());
                if (this.defaultConstraints[tidx].location(assn) == -1) continue;
                this.defaultConstraints[tidx].incrementValue(assn, 1.0);
            }
        }

        void dumpGradientToFile(String fileName) {
            try {
                double[] grad = new double[this.getNumParameters()];
                this.getValueGradient(grad);
                PrintStream w = new PrintStream(new FileOutputStream(fileName));
                int i = 0;
                while (i < this.numParameters) {
                    w.println(grad[i]);
                    ++i;
                }
                w.close();
            }
            catch (IOException e) {
                System.err.println("Could not open output file.");
                e.printStackTrace();
            }
        }

        void dumpDefaults() {
            System.out.println("Default constraints");
            int i = 0;
            while (i < this.defaultConstraints.length) {
                System.out.println("Template " + i);
                this.defaultConstraints[i].print();
                ++i;
            }
            System.out.println("Default expectations");
            i = 0;
            while (i < this.defaultExpectations.length) {
                System.out.println("Template " + i);
                this.defaultExpectations[i].print();
                ++i;
            }
        }

        void printDebugInfo(ACRF.UnrolledGraph unrolled) {
            this.acrf.print(System.err);
            Assignment assn = unrolled.getAssignment();
            Iterator it = unrolled.unrolledVarSetIterator();
            while (it.hasNext()) {
                ACRF.UnrolledVarSet clique = (ACRF.UnrolledVarSet)it.next();
                System.out.println("Clique " + clique);
                this.dumpAssnForClique(assn, clique);
                Factor ptl = unrolled.factorOf(clique);
                System.out.println("Value = " + ptl.value(assn));
                System.out.println(ptl);
            }
        }

        void dumpAssnForClique(Assignment assn, ACRF.UnrolledVarSet clique) {
            for (Variable var : clique) {
                System.out.println(var + " ==> " + assn.getObject(var) + "  (" + assn.get(var) + ")");
            }
        }

        private boolean weightValid(double w, int cnum, int j) {
            if (Double.isInfinite(w)) {
                logger.warning("Weight is infinite for clique " + cnum + "assignment " + j);
                return false;
            }
            if (Double.isNaN(w)) {
                logger.warning("Weight is Nan for clique " + cnum + "assignment " + j);
                return false;
            }
            return true;
        }

        public double computeValueAndGradient(int instance) {
            ++this.numInBatch;
            this.collectConstraintsForInstance(this.trainData, instance);
            double value = this.computeValueForInstance(instance);
            return value += this.computePrior() / (double)this.trainData.size();
        }

        public int getNumInstances() {
            return this.trainData.size();
        }

        public void getCachedGradient(double[] grad) {
            this.computeValueGradient(grad, (double)this.numInBatch / (double)this.trainData.size());
        }

        public void resetValueGradient() {
            this.resetExpectations();
            this.resetConstraints();
        }
    }
}

