/*
 * Decompiled with CFR 0.152.
 */
package cc.mallet.grmm.learning;

import cc.mallet.grmm.learning.ACRF;
import cc.mallet.grmm.learning.DefaultAcrfTrainer;
import cc.mallet.grmm.types.Assignment;
import cc.mallet.grmm.types.AssignmentIterator;
import cc.mallet.grmm.types.Factor;
import cc.mallet.grmm.types.HashVarSet;
import cc.mallet.grmm.types.TableFactor;
import cc.mallet.grmm.types.VarSet;
import cc.mallet.grmm.types.Variable;
import cc.mallet.grmm.util.CachingOptimizable;
import cc.mallet.optimize.Optimizable;
import cc.mallet.types.Instance;
import cc.mallet.types.InstanceList;
import cc.mallet.types.SparseVector;
import cc.mallet.util.MalletLogger;
import gnu.trove.THashMap;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.PrintStream;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.BitSet;
import java.util.Iterator;
import java.util.List;
import java.util.logging.Logger;

public class PseudolikelihoodACRFTrainer
extends DefaultAcrfTrainer {
    /** Logger for training progress and diagnostics. */
    private static final Logger logger = MalletLogger.getLogger(PseudolikelihoodACRFTrainer.class.getName());
    /** Structure type: one local conditional per variable (see {@code VariablesIterator}). */
    public static final int BY_VARIABLE = 0;
    /** Structure type: one local conditional per pairwise factor (see {@code EdgesIterator}). */
    public static final int BY_EDGE = 1;
    /** Which pseudolikelihood structure to use; {@link #BY_VARIABLE} or {@link #BY_EDGE}. */
    private int structureType = BY_VARIABLE;

    /**
     * Returns the pseudolikelihood structure used when building local conditionals.
     *
     * @return the structure type, normally {@link #BY_VARIABLE} or {@link #BY_EDGE}
     */
    public int getStructureType() {
        return structureType;
    }

    /**
     * Selects the pseudolikelihood structure used during training.
     *
     * <p>The value is validated here so that a misconfiguration is reported immediately, instead of
     * only when the first local-conditional iterator is created during training.
     *
     * @param structureType {@link #BY_VARIABLE} or {@link #BY_EDGE}
     * @throws IllegalArgumentException if {@code structureType} is not one of the supported values
     */
    public void setStructureType(int structureType) {
        if (structureType != BY_VARIABLE && structureType != BY_EDGE) {
            throw new IllegalArgumentException("Unknown structured pseudolikelihood type " + structureType);
        }
        this.structureType = structureType;
    }

    /**
     * Builds the pseudolikelihood objective for {@code acrf} over {@code training}.
     * Constructing the objective also collects the empirical counts from {@code training}.
     *
     * @param acrf     the model to train
     * @param training the training instances
     * @return a gradient-based optimizable wrapping the pseudolikelihood objective
     */
    @Override
    public Optimizable.ByGradientValue createOptimizable(ACRF acrf, InstanceList training) {
        Maxable objective = new Maxable(acrf, training);
        return objective;
    }

    /**
     * Creates the iterator over local conditionals that matches the configured structure type.
     *
     * @param acrf     the unrolled graph of one instance
     * @param observed the observed assignment of that instance
     * @return a per-variable or per-edge iterator
     * @throws IllegalArgumentException if the structure type is neither {@link #BY_VARIABLE} nor {@link #BY_EDGE}
     */
    private CliquesIterator makeCliquesIterator(ACRF.UnrolledGraph acrf, Assignment observed) {
        // Dispatch on the named constants rather than their literal values so the two stay in sync.
        switch (this.structureType) {
            case BY_VARIABLE:
                return new VariablesIterator(acrf, observed);
            case BY_EDGE:
                return new EdgesIterator(acrf, observed);
            default:
                throw new IllegalArgumentException("Unknown structured pseudolikelihood type " + this.structureType);
        }
    }

    private static interface CliquesIterator {
        public boolean hasNext();

        public void advance();

        public Factor localConditional();

        public ACRF.UnrolledVarSet[] cliques();
    }

    /**
     * Edge-wise pseudolikelihood. Each position corresponds to one factor of the graph. The local conditional
     * is the joint table over that factor's two variables, given the observed values of all other variables.
     *
     * <p>Every factor adjacent to any clique variable becomes a position. {@link #advance()} therefore
     * rejects, with an explicit exception, any factor that does not have exactly two variables.
     */
    private static class EdgesIterator implements CliquesIterator {
        private ACRF.UnrolledGraph graph;
        private Assignment observed;
        /** Iterates over the keys of {@link #cliquesByEdge}. */
        private Iterator cursor;
        /** Cliques contributing to the current edge's conditional. */
        private List currentCliqueList;
        /** Local conditional at the current position. */
        private Factor ptl;
        /** Maps each factor to the (duplicate-free) list of cliques sharing at least one variable with it. */
        private THashMap cliquesByEdge;

        public EdgesIterator(ACRF.UnrolledGraph acrf, Assignment observed) {
            this.graph = acrf;
            this.observed = observed;
            this.cliquesByEdge = new THashMap();
            Iterator cliqueIt = acrf.unrolledVarSetIterator();
            while (cliqueIt.hasNext()) {
                ACRF.UnrolledVarSet clique = (ACRF.UnrolledVarSet) cliqueIt.next();
                for (int v1idx = 0; v1idx < clique.size(); v1idx++) {
                    Variable v1 = clique.get(v1idx);
                    List adjlist = this.graph.allFactorsContaining(v1);
                    // The list is raw, so iterate as Object and cast; a typed enhanced-for does not compile.
                    for (Object o : adjlist) {
                        Factor factor = (Factor) o;
                        List cliquesForFactor = (List) this.cliquesByEdge.get(factor);
                        if (cliquesForFactor == null) {
                            cliquesForFactor = new ArrayList();
                            this.cliquesByEdge.put(factor, cliquesForFactor);
                        }
                        if (!cliquesForFactor.contains(clique)) {
                            cliquesForFactor.add(clique);
                        }
                    }
                }
            }
            this.cursor = this.cliquesByEdge.keySet().iterator();
        }

        @Override
        public boolean hasNext() {
            return this.cursor.hasNext();
        }

        /**
         * Moves to the next factor and builds the product of all adjacent clique potentials. Each potential is
         * reduced to the factor's two variables.
         *
         * @throws IllegalStateException if the factor is not pairwise, or a clique has no potential
         */
        @Override
        public void advance() {
            Factor pairFactor = (Factor) this.cursor.next();
            VarSet pairVarSet = pairFactor.varSet();
            // Checked explicitly rather than with assert. With assertions disabled, a factor of 3+ variables
            // would otherwise be silently treated as an edge over its first two variables.
            if (pairVarSet.size() != 2) {
                throw new IllegalStateException("Edge pseudolikelihood requires pairwise factors, but found a factor over "
                        + pairVarSet.size() + " variables: " + pairFactor);
            }
            Variable v1 = pairVarSet.get(0);
            Variable v2 = pairVarSet.get(1);
            this.ptl = new TableFactor(new Variable[]{v1, v2});
            // Condition on the observed values of every variable except the two edge variables.
            HashVarSet others = new HashVarSet(this.observed.varSet());
            others.remove(v1);
            others.remove(v2);
            Assignment localObs = (Assignment) this.observed.marginalize(others);
            this.currentCliqueList = (List) this.cliquesByEdge.get(pairFactor);
            for (Object o : this.currentCliqueList) {
                ACRF.UnrolledVarSet clique = (ACRF.UnrolledVarSet) o;
                Factor cliquePtl = this.graph.factorOf(clique);
                if (cliquePtl == null) {
                    throw new IllegalStateException("Could not find potential for clique " + clique);
                }
                boolean hasV1 = clique.contains(v1);
                boolean hasV2 = clique.contains(v2);
                if (!hasV1 && !hasV2) {
                    throw new RuntimeException("Illegal state: clique has neither edge variable");
                }
                // A potential over exactly the two edge variables is used as-is. Any other potential is
                // sliced, fixing its remaining variables at their observed values.
                Factor slice = (hasV1 && hasV2 && cliquePtl.varSet().size() == 2)
                        ? cliquePtl
                        : cliquePtl.slice(localObs);
                this.ptl.multiplyBy(slice);
            }
        }

        @Override
        public Factor localConditional() {
            return this.ptl;
        }

        @Override
        public ACRF.UnrolledVarSet[] cliques() {
            List cliques = this.currentCliqueList;
            // Raw toArray(T[]) is typed Object[]; the runtime array is UnrolledVarSet[], so the cast is safe.
            return (ACRF.UnrolledVarSet[]) cliques.toArray(new ACRF.UnrolledVarSet[cliques.size()]);
        }
    }

    public class Maxable
    extends CachingOptimizable.ByGradient
    implements Serializable {
        /** The model being trained; used for debug printing and feature-name lookup. */
        private ACRF acrf;
        /** Training instances over which the pseudolikelihood is summed. */
        InstanceList trainData;
        /** Templates whose weights are optimized. */
        private ACRF.Template[] templates;
        /** Fixed templates: included when unrolling graphs for the value, but not part of the parameter vector. */
        private ACRF.Template[] fixedTmpls;
        /**
         * Bit i is set when instance i had an infinite value on the first value computation.
         * Such instances are excluded from later values.
         */
        protected BitSet infiniteValues = null;
        /** Total parameter count accumulated from the templates' {@code initWeights}. */
        private int numParameters;
        private static final double DEFAULT_GAUSSIAN_PRIOR_VARIANCE = 10.0;
        /** Variance of the zero-mean Gaussian prior on the weights. */
        private double gaussianPriorVariance = DEFAULT_GAUSSIAN_PRIOR_VARIANCE;
        /** Empirical feature counts, indexed [template][assignment]. */
        SparseVector[][] constraints;
        /** Expected feature counts under the local conditionals, indexed [template][assignment]. */
        SparseVector[][] expectations;
        /** Empirical counts for each template's default weights (indexed by assignment). */
        SparseVector[] defaultConstraints;
        /** Expected counts for each template's default weights (indexed by assignment). */
        SparseVector[] defaultExpectations;

        /** @return the variance of the zero-mean Gaussian prior placed on the weights */
        public double getGaussianPriorVariance() {
            return gaussianPriorVariance;
        }

        /**
         * Sets the variance of the zero-mean Gaussian prior on the weights.
         *
         * <p>Both the objective value and its gradient depend on this variance, so the cached results are
         * marked stale. Presumably {@code CachingOptimizable} reuses cached results while these flags are false.
         * Without this, a value computed under the old prior could be returned after the change.
         *
         * @param gaussianPriorVariance the new prior variance
         */
        public void setGaussianPriorVariance(double gaussianPriorVariance) {
            this.gaussianPriorVariance = gaussianPriorVariance;
            this.cachedValueStale = true;
            this.cachedGradientStale = true;
        }

        /**
         * Lets every template allocate its weights for {@code training}.
         * The number of weights each template reports is added to {@code numParameters}.
         */
        private void initWeights(InstanceList training) {
            for (ACRF.Template template : this.templates) {
                this.numParameters += template.initWeights(training);
            }
        }

        /**
         * Allocates zeroed count vectors shaped like the templates' weights.
         * One set holds constraints (empirical counts) and the other holds expectations.
         */
        private void initConstraintsExpectations() {
            int numTemplates = this.templates.length;
            this.defaultConstraints = new SparseVector[numTemplates];
            this.defaultExpectations = new SparseVector[numTemplates];
            for (int t = 0; t < numTemplates; t++) {
                SparseVector defaults = this.templates[t].getDefaultWeights();
                this.defaultConstraints[t] = (SparseVector) defaults.cloneMatrixZeroed();
                this.defaultExpectations[t] = (SparseVector) defaults.cloneMatrixZeroed();
            }
            this.constraints = new SparseVector[numTemplates][];
            this.expectations = new SparseVector[numTemplates][];
            for (int t = 0; t < numTemplates; t++) {
                SparseVector[] weights = this.templates[t].getWeights();
                SparseVector[] templateConstraints = this.constraints[t] = new SparseVector[weights.length];
                SparseVector[] templateExpectations = this.expectations[t] = new SparseVector[weights.length];
                for (int a = 0; a < weights.length; a++) {
                    templateConstraints[a] = (SparseVector) weights[a].cloneMatrixZeroed();
                    templateExpectations[a] = (SparseVector) weights[a].cloneMatrixZeroed();
                }
            }
        }

        /** Zeroes every expectation vector, both default and per-assignment, before a new pass over the data. */
        void resetExpectations() {
            for (int t = 0; t < this.expectations.length; t++) {
                this.defaultExpectations[t].setAll(0.0);
                for (SparseVector expectation : this.expectations[t]) {
                    expectation.setAll(0.0);
                }
            }
        }

        /**
         * Builds the objective. It allocates weights and count vectors for every template and logs a summary.
         * It then collects the empirical (constraint) counts from {@code ilist} once, up front.
         *
         * @param acrf  the model whose templates are trained
         * @param ilist the training instances
         */
        protected Maxable(ACRF acrf, InstanceList ilist) {
            logger.finest("Initializing OptimizableACRF.");
            this.acrf = acrf;
            this.templates = acrf.getTemplates();
            this.fixedTmpls = acrf.getFixedTemplates();
            this.trainData = ilist;
            initWeights(trainData);
            initConstraintsExpectations();
            int instanceCount = trainData.size();
            // Nothing has been computed yet, so both caches start out stale.
            cachedGradientStale = true;
            cachedValueStale = true;
            logger.info("Number of training instances = " + instanceCount);
            logger.info("Number of parameters = " + numParameters);
            describePrior();
            logger.fine("Computing constraints");
            collectConstraints(trainData);
        }

        /** Logs the prior variance currently in effect. */
        private void describePrior() {
            String description = "Using gaussian prior with variance " + gaussianPriorVariance;
            logger.info(description);
        }

        /** @return the length of the flattened parameter vector (default weights plus per-assignment weights) */
        @Override
        public int getNumParameters() {
            return numParameters;
        }

        /**
         * Copies all weights into {@code buf}. The layout is every template's default weights first, followed
         * by each template's per-assignment weight vectors in order.
         *
         * @param buf destination of length {@link #getNumParameters()}
         * @throws IllegalArgumentException if {@code buf} has the wrong length
         */
        @Override
        public void getParameters(double[] buf) {
            if (buf.length != numParameters) {
                throw new IllegalArgumentException("Argument is not of the  correct dimensions");
            }
            int offset = 0;
            for (ACRF.Template template : templates) {
                double[] values = template.getDefaultWeights().getValues();
                System.arraycopy(values, 0, buf, offset, values.length);
                offset += values.length;
            }
            for (ACRF.Template template : templates) {
                for (SparseVector weightVec : template.getWeights()) {
                    double[] values = weightVec.getValues();
                    System.arraycopy(values, 0, buf, offset, values.length);
                    offset += values.length;
                }
            }
        }

        /**
         * Writes {@code params} back into the templates' weights, using the same layout as
         * {@code getParameters}, and invalidates the cached value and gradient.
         * Writes go into the arrays returned by {@code getValues()}, which presumably are the
         * vectors' backing storage.
         */
        @Override
        protected void setParametersInternal(double[] params) {
            this.cachedGradientStale = true;
            this.cachedValueStale = true;
            int offset = 0;
            for (ACRF.Template template : this.templates) {
                double[] values = template.getDefaultWeights().getValues();
                System.arraycopy(params, offset, values, 0, values.length);
                offset += values.length;
            }
            for (ACRF.Template template : this.templates) {
                for (SparseVector weightVec : template.getWeights()) {
                    double[] values = weightVec.getValues();
                    System.arraycopy(params, offset, values, 0, values.length);
                    offset += values.length;
                }
            }
        }

        /** @return the live (not copied) expected-count vectors of template {@code cnum}, one per assignment */
        public SparseVector[] getExpectations(int cnum) {
            SparseVector[] perAssignment = expectations[cnum];
            return perAssignment;
        }

        /** @return the live (not copied) empirical-count vectors of template {@code cnum}, one per assignment */
        public SparseVector[] getConstraints(int cnum) {
            SparseVector[] perAssignment = constraints[cnum];
            return perAssignment;
        }

        /** Prints the flattened parameter vector to standard output on one tab-separated line. */
        public void printParameters() {
            double[] params = new double[this.numParameters];
            this.getParameters(params);
            for (double value : params) {
                System.out.print(String.valueOf(value) + "\t");
            }
            System.out.println();
        }

        /**
         * Computes the penalized log pseudolikelihood of the training data at the current parameters.
         *
         * <p>As a side effect, the expectation vectors are reset and re-accumulated for the gradient.
         * On the first call, instances with an infinite value are recorded in {@link #infiniteValues} and
         * skipped from then on. Any other infinite or NaN value makes the method return negative infinity.
         *
         * @return sum of per-instance log pseudolikelihoods plus the Gaussian log-prior of all weights
         */
        @Override
        protected double computeValue() {
            double retval = 0.0;
            int numInstances = this.trainData.size();
            long start = System.currentTimeMillis();
            long unrollTime = 0L;
            boolean initializingInfiniteValues = false;
            if (this.infiniteValues == null) {
                this.infiniteValues = new BitSet();
                initializingInfiniteValues = true;
            }
            this.resetExpectations();
            for (int i = 0; i < numInstances; i++) {
                Instance instance = (Instance) this.trainData.get(i);
                long unrollStart = System.currentTimeMillis();
                ACRF.UnrolledGraph unrolled = new ACRF.UnrolledGraph(instance, this.templates, this.fixedTmpls);
                unrollTime += System.currentTimeMillis() - unrollStart;
                if (unrolled.numVariables() == 0) {
                    continue;
                }
                Assignment observations = unrolled.getAssignment();
                double value = this.collectExpectationsAndValue(unrolled, observations);
                if (Double.isInfinite(value)) {
                    if (initializingInfiniteValues) {
                        logger.warning("Instance " + instance.getName() + " has infinite value; skipping.");
                        this.infiniteValues.set(i);
                    } else if (!this.infiniteValues.get(i)) {
                        logger.warning("Infinite value on instance " + instance.getName() + "; returning -infinity");
                        return Double.NEGATIVE_INFINITY;
                    }
                } else if (Double.isNaN(value)) {
                    System.out.println("NaN on instance " + i + " : " + instance.getName());
                    this.printDebugInfo(unrolled);
                    logger.warning("Value is NaN in ACRF.getValue() Instance " + i + " : " + "returning -infinity... ");
                    return Double.NEGATIVE_INFINITY;
                } else {
                    retval += value;
                }
            }
            retval += this.gaussianPriorLogValue();
            long end = System.currentTimeMillis();
            logger.info("ACRF Inference time (ms) = " + (end - start));
            logger.info("ACRF unroll time (ms) = " + unrollTime);
            logger.info("getValue (loglikelihood) = " + retval);
            return retval;
        }

        /**
         * Gaussian log-prior, up to an additive constant: the sum of -w^2 / (2 * variance) over every valid
         * weight. Invalid weights (infinite or NaN) are skipped.
         *
         * <p>The default weights are included as well as the per-assignment weights. They are part of the
         * parameter vector, and {@code computeValueGradient} applies the matching -w / variance term to them.
         * Omitting them would make the value and gradient inconsistent.
         */
        private double gaussianPriorLogValue() {
            double priorDenom = 2.0 * this.gaussianPriorVariance;
            double logPrior = 0.0;
            for (int tidx = 0; tidx < this.templates.length; tidx++) {
                // Default weights are indexed by assignment number.
                SparseVector defaults = this.templates[tidx].getDefaultWeights();
                for (int loc = 0; loc < defaults.numLocations(); loc++) {
                    double w = defaults.valueAtLocation(loc);
                    if (this.weightValid(w, tidx, defaults.indexAtLocation(loc))) {
                        logPrior += -w * w / priorDenom;
                    }
                }
                SparseVector[] weights = this.templates[tidx].getWeights();
                for (int j = 0; j < weights.length; j++) {
                    for (int fnum = 0; fnum < weights[j].numLocations(); fnum++) {
                        double w = weights[j].valueAtLocation(fnum);
                        if (this.weightValid(w, tidx, j)) {
                            logPrior += -w * w / priorDenom;
                        }
                    }
                }
            }
            return logPrior;
        }

        /**
         * Fills {@code grad} with the gradient of the penalized log pseudolikelihood. Each component is
         * constraint - expectation - weight / variance. The order follows the parameter layout: all default
         * weights first, then each template's per-assignment weights.
         * A per-assignment weight that is infinite gets a zero gradient component and a warning.
         *
         * @param grad destination of length {@link #getNumParameters()}
         */
        @Override
        protected void computeValueGradient(double[] grad) {
            int gidx = 0;
            for (int tidx = 0; tidx < this.templates.length; tidx++) {
                SparseVector theseWeights = this.templates[tidx].getDefaultWeights();
                SparseVector theseConstraints = this.defaultConstraints[tidx];
                SparseVector theseExpectations = this.defaultExpectations[tidx];
                for (int j = 0; j < theseWeights.numLocations(); j++) {
                    double weight = theseWeights.valueAtLocation(j);
                    double constraint = theseConstraints.valueAtLocation(j);
                    double expectation = theseExpectations.valueAtLocation(j);
                    grad[gidx++] = constraint - expectation - weight / this.gaussianPriorVariance;
                }
            }
            for (int tidx = 0; tidx < this.templates.length; tidx++) {
                SparseVector[] weights = this.templates[tidx].getWeights();
                for (int i = 0; i < weights.length; i++) {
                    SparseVector thisWeightVec = weights[i];
                    SparseVector thisConstraintVec = this.constraints[tidx][i];
                    SparseVector thisExpectationVec = this.expectations[tidx][i];
                    for (int j = 0; j < thisWeightVec.numLocations(); j++) {
                        double w = thisWeightVec.valueAtLocation(j);
                        double constraint = thisConstraintVec.valueAtLocation(j);
                        double expectation = thisExpectationVec.valueAtLocation(j);
                        double gradient;
                        if (Double.isInfinite(w)) {
                            // j is a position in the sparse vector; translate it to the feature index before
                            // looking the feature up, otherwise the wrong feature is reported.
                            int featureIdx = thisWeightVec.indexAtLocation(j);
                            logger.warning("Infinite weight for template " + tidx + " assignment " + i + " feature "
                                    + this.acrf.getInputAlphabet().lookupObject(featureIdx));
                            gradient = 0.0;
                        } else {
                            gradient = constraint - w / this.gaussianPriorVariance - expectation;
                        }
                        grad[gidx++] = gradient;
                    }
                }
            }
        }

        /**
         * Accumulates expected counts into {@code expectations} / {@code defaultExpectations} and returns the
         * log pseudolikelihood of one unrolled instance.
         *
         * <p>The clique iterator yields local conditionals, per variable or per edge depending on the trainer's
         * structure type. For each one, every joint assignment of the conditional's variables is visited. The
         * assignment's normalized probability is added, as a weight, to the expected counts of each clique that
         * contributed to that conditional. All other variables keep their observed values.
         *
         * @param unrolled     the unrolled graph of a single training instance
         * @param observations the observed assignment of that instance
         * @return the sum over local conditionals of (log value at the observations - log normalizer)
         */
        private double collectExpectationsAndValue(ACRF.UnrolledGraph unrolled, Assignment observations) {
            double value = 0.0;
            CliquesIterator it = PseudolikelihoodACRFTrainer.this.makeCliquesIterator(unrolled, observations);
            while (it.hasNext()) {
                it.advance();
                TableFactor ptl = (TableFactor)it.localConditional();
                // Log normalizer of the local conditional.
                double logZ = ptl.logsum();
                ACRF.UnrolledVarSet[] cliques = it.cliques();
                // Working copy of the observations; only the local conditional's variables are overwritten below.
                Assignment assn = (Assignment)observations.duplicate();
                AssignmentIterator assnIt = ptl.assignmentIterator();
                while (assnIt.hasNext()) {
                    // Normalized probability of the current local assignment.
                    double marginal = Math.exp(ptl.logValue(assnIt) - logZ);
                    Assignment currentAssn = assnIt.assignment();
                    int vi = 0;
                    while (vi < currentAssn.numVariables()) {
                        Variable var = currentAssn.getVariable(vi);
                        // NOTE(review): writes row 0, presumably because observations hold a single row.
                        assn.setValue(0, var, currentAssn.get(var));
                        ++vi;
                    }
                    int cidx = 0;
                    while (cidx < cliques.length) {
                        ACRF.UnrolledVarSet clique = cliques[cidx];
                        int tidx = clique.getTemplate().index;
                        // Cliques whose template index is -1 shape the conditional but get no expectations.
                        if (tidx != -1) {
                            int assnIdx = clique.lookupNumberOfAssignment(assn);
                            this.expectations[tidx][assnIdx].plusEqualsSparse(clique.getFv(), marginal);
                            // Only update a default weight that exists for this assignment index.
                            if (this.defaultExpectations[tidx].location(assnIdx) != -1) {
                                this.defaultExpectations[tidx].incrementValue(assnIdx, marginal);
                            }
                        }
                        ++cidx;
                    }
                    assnIt.advance();
                }
                value += ptl.logValue(observations) - logZ;
            }
            return value;
        }

        /**
         * Adds one instance's empirical counts to {@code constraints} / {@code defaultConstraints}. A clique is
         * counted once for every local conditional it contributes to, which matches how expectations are counted.
         */
        private void collectConstraintsForGraph(ACRF.UnrolledGraph unrolled, Assignment observations) {
            CliquesIterator positions = PseudolikelihoodACRFTrainer.this.makeCliquesIterator(unrolled, observations);
            while (positions.hasNext()) {
                positions.advance();
                for (ACRF.UnrolledVarSet clique : positions.cliques()) {
                    int tidx = clique.getTemplate().index;
                    if (tidx < 0) {
                        // Negative template index: no trainable counts for this clique.
                        continue;
                    }
                    int assnIdx = clique.lookupNumberOfAssignment(observations);
                    this.constraints[tidx][assnIdx].plusEqualsSparse(clique.getFv(), 1.0);
                    SparseVector defaults = this.defaultConstraints[tidx];
                    if (defaults.location(assnIdx) != -1) {
                        defaults.incrementValue(assnIdx, 1.0);
                    }
                }
            }
        }

        /**
         * Accumulates empirical counts over every instance in {@code ilist}. Graphs are unrolled without
         * the fixed templates.
         */
        public void collectConstraints(InstanceList ilist) {
            for (int inum = 0; inum < ilist.size(); inum++) {
                logger.finest("*** Collecting constraints for instance " + inum);
                Instance inst = (Instance) ilist.get(inum);
                ACRF.UnrolledGraph unrolled = new ACRF.UnrolledGraph(inst, this.templates, null, true);
                this.collectConstraintsForGraph(unrolled, unrolled.getAssignment());
            }
        }

        /**
         * Debug helper that writes the current value gradient to {@code fileName}, one component per line.
         *
         * <p>The stream is closed even if writing fails. {@link PrintStream} records write errors instead of
         * throwing them, so its error flag is checked and any failure is reported on standard error.
         *
         * @param fileName path of the file to create or overwrite
         */
        void dumpGradientToFile(String fileName) {
            try {
                double[] grad = new double[this.getNumParameters()];
                this.getValueGradient(grad);
                PrintStream w = new PrintStream(new FileOutputStream(fileName));
                try {
                    for (double component : grad) {
                        w.println(component);
                    }
                } finally {
                    w.close();
                }
                if (w.checkError()) {
                    System.err.println("Error while writing gradient to " + fileName);
                }
            }
            catch (IOException e) {
                System.err.println("Could not open output file.");
                e.printStackTrace();
            }
        }

        /** Debug helper that prints the default-weight constraints, then expectations, template by template. */
        void dumpDefaults() {
            printByTemplate("Default constraints", this.defaultConstraints);
            printByTemplate("Default expectations", this.defaultExpectations);
        }

        /** Prints {@code header}, then each vector preceded by a "Template i" line. */
        private void printByTemplate(String header, SparseVector[] vectors) {
            System.out.println(header);
            for (int t = 0; t < vectors.length; t++) {
                System.out.println("Template " + t);
                vectors[t].print();
            }
        }

        /**
         * Debug helper, used when a value is NaN. Prints the model to standard error. Then, on standard
         * output, it prints for every clique its observed assignment, its potential's value there, and the
         * potential itself.
         */
        void printDebugInfo(ACRF.UnrolledGraph unrolled) {
            this.acrf.print(System.err);
            Assignment observed = unrolled.getAssignment();
            for (Iterator cliqueIt = unrolled.varSetIterator(); cliqueIt.hasNext(); ) {
                ACRF.UnrolledVarSet clique = (ACRF.UnrolledVarSet) cliqueIt.next();
                System.out.println("Clique " + clique);
                this.dumpAssnForClique(observed, clique);
                Factor potential = unrolled.factorOf(clique);
                System.out.println("Value = " + potential.value(observed));
                System.out.println(potential);
            }
        }

        /**
         * Debug helper that prints, for each variable of {@code clique}, its value object and its index in
         * {@code assn}.
         */
        void dumpAssnForClique(Assignment assn, ACRF.UnrolledVarSet clique) {
            // Indexed access (as used elsewhere in this file) rather than a typed enhanced-for. The latter
            // compiles only if the var-set type is a generic Iterable<Variable>.
            for (int vi = 0; vi < clique.size(); vi++) {
                Variable var = clique.get(vi);
                System.out.println(var + " ==> " + assn.getObject(var) + "  (" + assn.get(var) + ")");
            }
        }

        /**
         * Reports whether {@code w} is a finite number, logging a warning when it is not.
         *
         * @param w    the weight to check
         * @param cnum template index, used only in the warning
         * @param j    assignment index, used only in the warning
         * @return false if {@code w} is infinite or NaN, true otherwise
         */
        private boolean weightValid(double w, int cnum, int j) {
            if (Double.isInfinite(w)) {
                logger.warning("Weight is infinite for clique " + cnum + " assignment " + j);
                return false;
            }
            if (Double.isNaN(w)) {
                logger.warning("Weight is NaN for clique " + cnum + " assignment " + j);
                return false;
            }
            return true;
        }
    }

    /**
     * Variable-wise pseudolikelihood. Each position corresponds to one variable of the graph. Its local
     * conditional is the product of every clique potential containing that variable, with the clique's
     * other variables fixed at their observed values.
     */
    private static class VariablesIterator implements CliquesIterator {
        private ACRF.UnrolledGraph graph;
        private Assignment observed;
        /** Graph index of the current variable; -1 before the first call to advance(). */
        private int vidx = -1;
        /** Local conditional at the current position. */
        private Factor ptl;
        /** cliquesByVar[i] lists the cliques containing the variable with graph index i. */
        private List[] cliquesByVar;

        public VariablesIterator(ACRF.UnrolledGraph acrf, Assignment observed) {
            this.graph = acrf;
            this.observed = observed;
            this.cliquesByVar = new List[this.graph.numVariables()];
            for (int i = 0; i < this.cliquesByVar.length; i++) {
                this.cliquesByVar[i] = new ArrayList();
            }
            Iterator cliqueIt = acrf.unrolledVarSetIterator();
            while (cliqueIt.hasNext()) {
                ACRF.UnrolledVarSet clique = (ACRF.UnrolledVarSet) cliqueIt.next();
                // Loop index deliberately not named vidx, so it does not shadow the cursor field.
                for (int cvi = 0; cvi < clique.size(); cvi++) {
                    Variable var = clique.get(cvi);
                    this.cliquesByVar[this.graph.getIndex(var)].add(clique);
                }
            }
        }

        @Override
        public boolean hasNext() {
            return this.vidx < this.graph.numVariables() - 1;
        }

        /**
         * Moves to the next variable and builds its local conditional.
         *
         * @throws IllegalStateException if a clique containing the variable has no potential
         */
        @Override
        public void advance() {
            ++this.vidx;
            Variable var = this.graph.get(this.vidx);
            this.ptl = new TableFactor(var);
            // The list is raw, so iterate as Object and cast; a typed enhanced-for does not compile.
            for (Object o : this.cliquesByVar[this.vidx]) {
                ACRF.UnrolledVarSet clique = (ACRF.UnrolledVarSet) o;
                Factor cliquePtl = this.graph.factorOf(clique);
                if (cliquePtl == null) {
                    throw new IllegalStateException("Could not find potential for clique " + clique);
                }
                // Fix every other variable of the potential at its observed value.
                HashVarSet neighbours = new HashVarSet(cliquePtl.varSet());
                neighbours.remove(var);
                Assignment nbrAssn = (Assignment) this.observed.marginalize(neighbours);
                this.ptl.multiplyBy(cliquePtl.slice(nbrAssn));
            }
        }

        @Override
        public Factor localConditional() {
            return this.ptl;
        }

        @Override
        public ACRF.UnrolledVarSet[] cliques() {
            List cliques = this.cliquesByVar[this.vidx];
            // Raw toArray(T[]) is typed Object[]; the runtime array is UnrolledVarSet[], so the cast is safe.
            return (ACRF.UnrolledVarSet[]) cliques.toArray(new ACRF.UnrolledVarSet[cliques.size()]);
        }
    }
}

