/*
 * Decompiled with CFR 0.152.
 */
package cc.mallet.grmm.learning;

import cc.mallet.grmm.inference.Inferencer;
import cc.mallet.grmm.learning.ACRF;
import cc.mallet.grmm.learning.ACRFEvaluator;
import cc.mallet.grmm.learning.DefaultAcrfTrainer;
import cc.mallet.grmm.types.Assignment;
import cc.mallet.grmm.types.AssignmentIterator;
import cc.mallet.grmm.types.Factor;
import cc.mallet.grmm.types.VarSet;
import cc.mallet.grmm.types.Variable;
import cc.mallet.grmm.util.CachingOptimizable;
import cc.mallet.optimize.Optimizable;
import cc.mallet.types.FeatureVector;
import cc.mallet.types.Instance;
import cc.mallet.types.InstanceList;
import cc.mallet.types.Sequence;
import cc.mallet.types.SparseVector;
import cc.mallet.util.FileUtils;
import cc.mallet.util.MalletLogger;
import cc.mallet.util.Maths;
import cc.mallet.util.Timing;
import java.io.File;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.PrintStream;
import java.util.ArrayList;
import java.util.BitSet;
import java.util.Iterator;
import java.util.List;
import java.util.logging.Logger;

/**
 * Trainer for {@link ACRF}s that optimizes a per-factor conditional objective instead of the
 * full joint likelihood.
 *
 * <p>For every clique and every variable in it, the objective contains the log probability of
 * that variable's observed value given the observed values of the other variables of the same
 * clique, normalized only over the variable's own outcomes and scored only by that clique's
 * weights (see {@code Maxable.computeValueGradientForAssn}). Judging by the class name this is
 * piecewise pseudolikelihood (PWPL). A Gaussian prior over all weights is added.
 *
 * <p>With {@link #CONDITION_WW} selected, training runs in two phases. The first is a short
 * phase on the plain objective. Then "wrong-wrong" terms are harvested from the model's clique
 * marginals, and training continues with those terms included.
 */
public class PwplACRFTrainer extends DefaultAcrfTrainer {
    private static final Logger logger = MalletLogger.getLogger(PwplACRFTrainer.class.getName());

    /** When {@code true}, every gradient component is printed to standard output as it is computed. */
    public static boolean printGradient = false;

    /** Wrong-wrong mode: optimize only the plain per-factor objective. */
    public static final int NO_WRONG_WRONG = 0;

    /** Wrong-wrong mode: add conditional wrong-wrong terms after an initial training phase. */
    public static final int CONDITION_WW = 1;

    /** One of {@link #NO_WRONG_WRONG} or {@link #CONDITION_WW}. */
    private int wrongWrongType = NO_WRONG_WRONG;

    /** Iteration count passed to the initial training phase when wrong-wrongs are enabled. */
    private int wrongWrongIter = 10;

    /** Clique assignments whose marginal exceeds this value are candidate wrong-wrongs. */
    private double wrongWrongThreshold = 0.1;

    /** Directory into which {@code initial-acrf.ser.gz} is written after the initial phase. */
    private File outputPrefix = new File(".");

    /**
     * Creates the per-factor objective for {@code acrf} over {@code training}.
     *
     * @return a new {@link Maxable}; it collects its empirical constraints in its constructor
     */
    @Override
    public Optimizable.ByGradientValue createOptimizable(ACRF acrf, InstanceList training) {
        return new Maxable(acrf, training);
    }

    /** @return the minimum clique-marginal probability for harvesting a wrong-wrong */
    public double getWrongWrongThreshold() {
        return this.wrongWrongThreshold;
    }

    /** @param wrongWrongThreshold minimum clique-marginal probability for harvesting a wrong-wrong */
    public void setWrongWrongThreshold(double wrongWrongThreshold) {
        this.wrongWrongThreshold = wrongWrongThreshold;
    }

    /** @return the wrong-wrong mode, {@link #NO_WRONG_WRONG} or {@link #CONDITION_WW} */
    public int getWrongWrongType() {
        return this.wrongWrongType;
    }

    /**
     * Selects the wrong-wrong mode.
     *
     * @param wrongWrongType {@link #NO_WRONG_WRONG} or {@link #CONDITION_WW}
     * @throws IllegalArgumentException for any other value (such values could only fail later,
     *     deep inside objective evaluation)
     */
    public void setWrongWrongType(int wrongWrongType) {
        if (wrongWrongType != NO_WRONG_WRONG && wrongWrongType != CONDITION_WW) {
            throw new IllegalArgumentException("Unknown wrong-wrong type: " + wrongWrongType);
        }
        this.wrongWrongType = wrongWrongType;
    }

    /** @return the iteration count used for the initial training phase */
    public int getWrongWrongIter() {
        return this.wrongWrongIter;
    }

    /** @param wrongWrongIter iteration count used for the initial training phase */
    public void setWrongWrongIter(int wrongWrongIter) {
        this.wrongWrongIter = wrongWrongIter;
    }

    /** @return the directory into which the intermediate model is written */
    public File getOutputPrefix() {
        return this.outputPrefix;
    }

    /**
     * @param outputPrefix directory into which {@code initial-acrf.ser.gz} is written
     * @throws IllegalArgumentException if {@code outputPrefix} is {@code null}
     */
    public void setOutputPrefix(File outputPrefix) {
        if (outputPrefix == null) {
            throw new IllegalArgumentException("outputPrefix must not be null");
        }
        this.outputPrefix = outputPrefix;
    }

    /**
     * Trains {@code acrf}.
     *
     * <p>With {@link #NO_WRONG_WRONG} this simply delegates to the superclass. Otherwise it runs
     * these steps in order:
     * <ol>
     *   <li>Runs superclass training with {@code wrongWrongIter} as the iteration count.</li>
     *   <li>Writes the intermediate model to {@code <outputPrefix>/initial-acrf.ser.gz}.</li>
     *   <li>Harvests wrong-wrong terms from the current model.</li>
     *   <li>Trains again with {@code numIter}.</li>
     *   <li>Logs the training-set joint log-likelihood.</li>
     * </ol>
     *
     * <p>When wrong-wrongs are enabled, {@code macrf} must be a {@link Maxable}, as returned by
     * {@link #createOptimizable}; otherwise a {@link ClassCastException} is thrown.
     *
     * @return the result of the final call to the superclass {@code train}
     */
    @Override
    public boolean train(ACRF acrf, InstanceList trainingList, InstanceList validationList, InstanceList testSet, ACRFEvaluator eval, int numIter, Optimizable.ByGradientValue macrf) {
        if (this.wrongWrongType == NO_WRONG_WRONG) {
            return super.train(acrf, trainingList, validationList, testSet, eval, numIter, macrf);
        }
        Maxable bipwMaxable = (Maxable) macrf;
        logger.info("BiconditionalPiecewiseACRFTrainer: Initial training");
        super.train(acrf, trainingList, validationList, testSet, eval, this.wrongWrongIter, macrf);
        FileUtils.writeGzippedObject(new File(this.outputPrefix, "initial-acrf.ser.gz"), acrf);
        logger.info("BiconditionalPiecewiseACRFTrainer: Adding wrong-wrongs");
        bipwMaxable.addWrongWrong(trainingList);
        logger.info("BiconditionalPiecewiseACRFTrainer: Retraining with wrong-wrongs");
        boolean converged = super.train(acrf, trainingList, validationList, testSet, eval, numIter, macrf);
        reportTrainingLikelihood(acrf, trainingList);
        return converged;
    }

    /**
     * Logs, per training instance and in total, the log joint probability of the observed
     * labeling as reported by {@code acrf}'s inferencer. No prior term is included.
     */
    public static void reportTrainingLikelihood(ACRF acrf, InstanceList trainingList) {
        double total = 0.0;
        Inferencer inf = acrf.getInferencer();
        for (int i = 0; i < trainingList.size(); i++) {
            Instance inst = (Instance) trainingList.get(i);
            ACRF.UnrolledGraph unrolled = acrf.unroll(inst);
            inf.computeMarginals(unrolled);
            double lik = inf.lookupLogJoint(unrolled.getAssignment());
            total += lik;
            logger.info("...instance " + i + " likelihood = " + lik);
        }
        logger.info("Unregularized joint likelihood = " + total);
    }

    /**
     * Optimizable per-factor objective for {@link PwplACRFTrainer}.
     *
     * <p>Parameters are laid out in two parts. First come the default weights of every template,
     * in template order. Then come every per-assignment weight vector of every template, in
     * template order and then assignment order. {@link #getParameters},
     * {@link #setParametersInternal} and {@link #computeValueGradient} all share this layout.
     *
     * <p>This is an inner, non-static class because it reads the enclosing trainer's wrong-wrong
     * settings.
     */
    public class Maxable extends CachingOptimizable.ByGradient {
        private ACRF acrf;
        InstanceList trainData;
        private ACRF.Template[] templates;
        /** Instances whose value was infinite on the first evaluation; skipped on later evaluations. */
        protected BitSet infiniteValues = null;
        private int numParameters;
        private static final double DEFAULT_GAUSSIAN_PRIOR_VARIANCE = 10.0;
        private double gaussianPriorVariance = DEFAULT_GAUSSIAN_PRIOR_VARIANCE;
        /** Empirical feature sums, indexed [template][assignment]. */
        SparseVector[][] constraints;
        /** Model-expected feature sums, indexed [template][assignment]. */
        SparseVector[][] expectations;
        /** Empirical sums for each template's default weights. */
        SparseVector[] defaultConstraints;
        /** Model-expected sums for each template's default weights. */
        SparseVector[] defaultExpectations;
        /** Profiling: number of computeValueGradientForAssn calls during the current computeValue. */
        private int numCvgaCalls = 0;
        /** Profiling: accumulated {@code Timing.elapsedTime()} of those calls. */
        private long timePerCvgaCall = 0L;
        /** Wrong-wrong terms per training instance; {@code null} until addWrongWrong has run. */
        private List<List<WrongWrong>> allWrongWrongs;

        public double getGaussianPriorVariance() {
            return this.gaussianPriorVariance;
        }

        public void setGaussianPriorVariance(double gaussianPriorVariance) {
            this.gaussianPriorVariance = gaussianPriorVariance;
        }

        /** Initializes every template's weights from {@code training} and counts the parameters. */
        private void initWeights(InstanceList training) {
            for (ACRF.Template tmpl : this.templates) {
                this.numParameters += tmpl.initWeights(training);
            }
        }

        /** Allocates zeroed constraint/expectation vectors shaped like the corresponding weights. */
        private void initConstraintsExpectations() {
            int numTemplates = this.templates.length;
            this.defaultConstraints = new SparseVector[numTemplates];
            this.defaultExpectations = new SparseVector[numTemplates];
            for (int tidx = 0; tidx < numTemplates; tidx++) {
                SparseVector defaults = this.templates[tidx].getDefaultWeights();
                this.defaultConstraints[tidx] = (SparseVector) defaults.cloneMatrixZeroed();
                this.defaultExpectations[tidx] = (SparseVector) defaults.cloneMatrixZeroed();
            }
            this.constraints = new SparseVector[numTemplates][];
            this.expectations = new SparseVector[numTemplates][];
            for (int tidx = 0; tidx < numTemplates; tidx++) {
                SparseVector[] weights = this.templates[tidx].getWeights();
                this.constraints[tidx] = new SparseVector[weights.length];
                this.expectations[tidx] = new SparseVector[weights.length];
                for (int i = 0; i < weights.length; i++) {
                    this.constraints[tidx][i] = (SparseVector) weights[i].cloneMatrixZeroed();
                    this.expectations[tidx][i] = (SparseVector) weights[i].cloneMatrixZeroed();
                }
            }
        }

        void resetProfilingForCall() {
            this.numCvgaCalls = 0;
            this.timePerCvgaCall = 0L;
        }

        /** Zeroes every default and per-assignment expectation vector. */
        void resetExpectations() {
            for (int tidx = 0; tidx < this.expectations.length; tidx++) {
                this.defaultExpectations[tidx].setAll(0.0);
                for (SparseVector vec : this.expectations[tidx]) {
                    vec.setAll(0.0);
                }
            }
        }

        /** Zeroes every default and per-assignment constraint vector. */
        void resetConstraints() {
            for (int tidx = 0; tidx < this.constraints.length; tidx++) {
                this.defaultConstraints[tidx].setAll(0.0);
                for (SparseVector vec : this.constraints[tidx]) {
                    vec.setAll(0.0);
                }
            }
        }

        /**
         * Initializes weights for {@code acrf}'s templates from {@code ilist}, allocates the
         * constraint and expectation vectors, and collects the empirical constraints.
         */
        protected Maxable(ACRF acrf, InstanceList ilist) {
            logger.finest("Initializing OptimizableACRF.");
            this.acrf = acrf;
            this.templates = acrf.getTemplates();
            this.trainData = ilist;
            this.initWeights(this.trainData);
            this.initConstraintsExpectations();
            int numInstances = this.trainData.size();
            this.cachedGradientStale = true;
            this.cachedValueStale = true;
            logger.info("Number of training instances = " + numInstances);
            logger.info("Number of parameters = " + this.numParameters);
            this.describePrior();
            logger.fine("Computing constraints");
            this.collectConstraints(this.trainData);
        }

        private void describePrior() {
            logger.info("Using gaussian prior with variance " + this.gaussianPriorVariance);
        }

        @Override
        public int getNumParameters() {
            return this.numParameters;
        }

        /**
         * Copies all weights into {@code buf} using the class-level parameter layout.
         *
         * @throws IllegalArgumentException if {@code buf.length != getNumParameters()}
         */
        @Override
        public void getParameters(double[] buf) {
            if (buf.length != this.numParameters) {
                throw new IllegalArgumentException("Argument is not of the  correct dimensions");
            }
            int idx = 0;
            for (ACRF.Template tmpl : this.templates) {
                double[] values = tmpl.getDefaultWeights().getValues();
                System.arraycopy(values, 0, buf, idx, values.length);
                idx += values.length;
            }
            for (ACRF.Template tmpl : this.templates) {
                for (SparseVector weightVec : tmpl.getWeights()) {
                    double[] values = weightVec.getValues();
                    System.arraycopy(values, 0, buf, idx, values.length);
                    idx += values.length;
                }
            }
        }

        /**
         * Writes {@code params} into the templates' weight storage (class-level layout) and marks
         * the cached value and gradient stale.
         */
        @Override
        protected void setParametersInternal(double[] params) {
            this.cachedGradientStale = true;
            this.cachedValueStale = true;
            int idx = 0;
            for (ACRF.Template tmpl : this.templates) {
                double[] values = tmpl.getDefaultWeights().getValues();
                System.arraycopy(params, idx, values, 0, values.length);
                idx += values.length;
            }
            for (ACRF.Template tmpl : this.templates) {
                for (SparseVector weightVec : tmpl.getWeights()) {
                    double[] values = weightVec.getValues();
                    System.arraycopy(params, idx, values, 0, values.length);
                    idx += values.length;
                }
            }
        }

        public SparseVector[] getExpectations(int cnum) {
            return this.expectations[cnum];
        }

        public SparseVector[] getConstraints(int cnum) {
            return this.constraints[cnum];
        }

        /** Prints all parameters, tab-separated, on one line of standard output. */
        public void printParameters() {
            double[] buf = new double[this.numParameters];
            this.getParameters(buf);
            for (double v : buf) {
                System.out.print(String.valueOf(v) + "\t");
            }
            System.out.println();
        }

        /**
         * Computes the objective and, as a side effect, refills the expectation vectors.
         *
         * <p>On the first call, instances with infinite value are logged, recorded in
         * {@link #infiniteValues} and skipped from then on. On later calls, an infinite value on any
         * other instance, or a NaN value on any instance, makes this return negative infinity.
         */
        @Override
        protected double computeValue() {
            double retval = 0.0;
            int numInstances = this.trainData.size();
            long start = System.currentTimeMillis();
            long unrollTime = 0L;
            this.resetProfilingForCall();
            boolean initializingInfiniteValues = false;
            if (this.infiniteValues == null) {
                this.infiniteValues = new BitSet();
                initializingInfiniteValues = true;
            }
            this.resetExpectations();
            for (int i = 0; i < numInstances; i++) {
                Instance instance = (Instance) this.trainData.get(i);
                long unrollStart = System.currentTimeMillis();
                ACRF.UnrolledGraph unrolled = this.acrf.unrollStructureOnly(instance);
                unrollTime += System.currentTimeMillis() - unrollStart;
                Assignment observations = unrolled.getAssignment();
                double value = this.collectExpectationsAndValue(unrolled, observations, i);
                if (Double.isInfinite(value)) {
                    if (initializingInfiniteValues) {
                        logger.warning("Instance " + instance.getName() + " has infinite value; skipping.");
                        this.infiniteValues.set(i);
                    } else if (!this.infiniteValues.get(i)) {
                        logger.warning("Infinite value on instance " + instance.getName() + " returning -infinity");
                        return Double.NEGATIVE_INFINITY;
                    }
                } else if (Double.isNaN(value)) {
                    System.out.println("NaN on instance " + i + " : " + instance.getName());
                    this.printDebugInfo(unrolled);
                    logger.warning("Value is NaN in ACRF.getValue() Instance " + i + " : " + "returning -infinity... ");
                    return Double.NEGATIVE_INFINITY;
                } else {
                    retval += value;
                }
            }
            retval += this.gaussianLogPrior();
            long end = System.currentTimeMillis();
            logger.info("ACRF Inference time (ms) = " + (end - start));
            logger.info("ACRF unroll time (ms) = " + unrollTime);
            logger.info("getValue (loglikelihood) = " + retval);
            logger.info("Number cVGA calls = " + this.numCvgaCalls);
            logger.info("Total cVGA time (ms) = " + this.timePerCvgaCall);
            return retval;
        }

        /**
         * Returns the Gaussian log-prior {@code -sum(w^2) / (2 * variance)} over all valid weights.
         * It covers both default weights and per-assignment weights, so that it matches the
         * {@code -w / variance} term that {@link #computeValueGradient} applies to every parameter.
         * Infinite or NaN weights are skipped (with a warning).
         */
        private double gaussianLogPrior() {
            double priorDenom = 2.0 * this.gaussianPriorVariance;
            double logPrior = 0.0;
            for (int tidx = 0; tidx < this.templates.length; tidx++) {
                SparseVector defaults = this.templates[tidx].getDefaultWeights();
                for (int loc = 0; loc < defaults.numLocations(); loc++) {
                    double w = defaults.valueAtLocation(loc);
                    if (this.weightValid(w, tidx, loc)) {
                        logPrior -= w * w / priorDenom;
                    }
                }
                SparseVector[] weights = this.templates[tidx].getWeights();
                for (int j = 0; j < weights.length; j++) {
                    for (int fnum = 0; fnum < weights[j].numLocations(); fnum++) {
                        double w = weights[j].valueAtLocation(fnum);
                        if (this.weightValid(w, tidx, j)) {
                            logPrior -= w * w / priorDenom;
                        }
                    }
                }
            }
            return logPrior;
        }

        /**
         * Fills {@code grad} with {@code constraint - expectation - w / variance} per parameter, in
         * the class-level layout. It relies on the expectations left by the most recent
         * {@link #computeValue}. Infinite per-assignment weights get a zero gradient.
         *
         * <p>Constraint and expectation vectors were cloned from the weight vectors, so the
         * same sparse location {@code j} is read from all three.
         */
        @Override
        protected void computeValueGradient(double[] grad) {
            int gidx = 0;
            for (int tidx = 0; tidx < this.templates.length; tidx++) {
                SparseVector theseWeights = this.templates[tidx].getDefaultWeights();
                SparseVector theseConstraints = this.defaultConstraints[tidx];
                SparseVector theseExpectations = this.defaultExpectations[tidx];
                for (int j = 0; j < theseWeights.numLocations(); j++) {
                    double weight = theseWeights.valueAtLocation(j);
                    double constraint = theseConstraints.valueAtLocation(j);
                    double expectation = theseExpectations.valueAtLocation(j);
                    if (printGradient) {
                        System.out.println(" gradient [" + gidx + "] = " + constraint + " (ctr) - " + expectation + " (exp) - " + weight / this.gaussianPriorVariance + " (reg)  [feature=DEFAULT]");
                    }
                    grad[gidx++] = constraint - expectation - weight / this.gaussianPriorVariance;
                }
            }
            for (int tidx = 0; tidx < this.templates.length; tidx++) {
                SparseVector[] weights = this.templates[tidx].getWeights();
                for (int i = 0; i < weights.length; i++) {
                    SparseVector thisWeightVec = weights[i];
                    SparseVector thisConstraintVec = this.constraints[tidx][i];
                    SparseVector thisExpectationVec = this.expectations[tidx][i];
                    for (int j = 0; j < thisWeightVec.numLocations(); j++) {
                        double w = thisWeightVec.valueAtLocation(j);
                        double constraint = thisConstraintVec.valueAtLocation(j);
                        double expectation = thisExpectationVec.valueAtLocation(j);
                        double gradient;
                        if (Double.isInfinite(w)) {
                            // The alphabet is keyed by feature index, not by sparse location.
                            Object fname = this.acrf.getInputAlphabet().lookupObject(thisWeightVec.indexAtLocation(j));
                            logger.warning("Infinite weight for node index " + i + " feature " + fname);
                            gradient = 0.0;
                        } else {
                            gradient = constraint - w / this.gaussianPriorVariance - expectation;
                        }
                        if (printGradient) {
                            Object fname = this.acrf.getInputAlphabet().lookupObject(thisWeightVec.indexAtLocation(j));
                            System.out.println(" gradient [" + gidx + "] = " + constraint + " (ctr) - " + expectation + " (exp) - " + w / this.gaussianPriorVariance + " (reg)  [feature=" + fname + "]");
                        }
                        grad[gidx++] = gradient;
                    }
                }
            }
        }

        /**
         * Returns instance {@code inum}'s contribution to the objective and adds its expectations.
         *
         * <p>The contribution has one term per (clique, variable) pair, plus any wrong-wrong terms
         * for the current mode. Cliques whose template index is -1 are skipped; presumably those
         * templates carry no trainable weights (confirm).
         */
        private double collectExpectationsAndValue(ACRF.UnrolledGraph unrolled, Assignment observations, int inum) {
            double value = 0.0;
            for (Iterator it = unrolled.unrolledVarSetIterator(); it.hasNext(); ) {
                ACRF.UnrolledVarSet clique = (ACRF.UnrolledVarSet) it.next();
                if (clique.getTemplate().index == -1) {
                    continue;
                }
                for (int vi = 0; vi < clique.size(); vi++) {
                    value += this.computeValueGradientForAssn(observations, clique, clique.get(vi));
                }
            }
            switch (PwplACRFTrainer.this.wrongWrongType) {
                case NO_WRONG_WRONG:
                    break;
                case CONDITION_WW:
                    value += this.addConditionalWW(unrolled, inum);
                    break;
                default:
                    throw new IllegalStateException();
            }
            return value;
        }

        /**
         * Sums the conditional terms of instance {@code inum}'s harvested wrong-wrongs.
         * Each term conditions the wrong-wrong's variable on the stored, partly wrong clique
         * assignment. Returns 0 if no wrong-wrongs have been harvested yet.
         */
        private double addConditionalWW(ACRF.UnrolledGraph unrolled, int inum) {
            double value = 0.0;
            if (this.allWrongWrongs != null) {
                for (WrongWrong ww : this.allWrongWrongs.get(inum)) {
                    Variable target = ww.findVariable(unrolled);
                    ACRF.UnrolledVarSet clique = ww.findVarSet(unrolled);
                    Assignment wrong = Assignment.makeFromSingleIndex(clique, ww.assnIdx);
                    value += this.computeValueGradientForAssn(wrong, clique, target);
                }
            }
            return value;
        }

        /**
         * Computes one conditional term of the objective for a single clique.
         *
         * <p>The clique's other variables are fixed at their values in {@code observations}. Each
         * outcome of {@code target} is then scored with only this clique's weights. The resulting
         * normalized distribution is added to the expectation vectors, weighted by each outcome's
         * probability.
         *
         * @return {@code log P(observed value of target | rest of clique)}
         */
        private double computeValueGradientForAssn(Assignment observations, ACRF.UnrolledVarSet clique, Variable target) {
            ++this.numCvgaCalls;
            Timing timing = new Timing();
            ACRF.Template tmpl = clique.getTemplate();
            int tidx = tmpl.index;
            Assignment cliqueAssn = Assignment.restriction(observations, clique);
            int numOutcomes = target.getNumOutcomes();
            double[] vals = new double[numOutcomes];
            int[] singles = new int[numOutcomes];
            for (int outcome = 0; outcome < numOutcomes; outcome++) {
                cliqueAssn.setValue(target, outcome);
                vals[outcome] = this.computeLogFactorValue(cliqueAssn, tmpl, clique.getFv());
                singles[outcome] = cliqueAssn.singleIndex();
            }
            double logZ = Maths.sumLogProb(vals);
            for (int outcome = 0; outcome < numOutcomes; outcome++) {
                double marginal = Math.exp(vals[outcome] - logZ);
                int expIdx = singles[outcome];
                this.expectations[tidx][expIdx].plusEqualsSparse(clique.getFv(), marginal);
                if (this.defaultExpectations[tidx].location(expIdx) != -1) {
                    this.defaultExpectations[tidx].incrementValue(expIdx, marginal);
                }
            }
            int observedVal = observations.get(target);
            this.timePerCvgaCall += timing.elapsedTime();
            return vals[observedVal] - logZ;
        }

        /**
         * Returns the unnormalized log score of one clique assignment. The score is the dot product
         * of the assignment's weight vector with {@code fv}, plus that assignment's default weight.
         */
        private double computeLogFactorValue(Assignment cliqueAssn, ACRF.Template tmpl, FeatureVector fv) {
            int idx = cliqueAssn.singleIndex();
            SparseVector w = tmpl.getWeights()[idx];
            return w.dotProduct(fv) + tmpl.getDefaultWeight(idx);
        }

        /**
         * Adds the empirical feature sums of {@code ilist} to the constraint vectors.
         *
         * <p>Each observed clique contributes its features {@code clique.size()} times. This
         * matches the objective, which has one term per variable of the clique. Each harvested
         * wrong-wrong contributes its clique features once, to its stored assignment.
         */
        public void collectConstraints(InstanceList ilist) {
            for (int inum = 0; inum < ilist.size(); inum++) {
                logger.finest("*** Collecting constraints for instance " + inum);
                Instance inst = (Instance) ilist.get(inum);
                ACRF.UnrolledGraph unrolled = new ACRF.UnrolledGraph(inst, this.templates, null, false);
                for (Iterator it = unrolled.unrolledVarSetIterator(); it.hasNext(); ) {
                    ACRF.UnrolledVarSet clique = (ACRF.UnrolledVarSet) it.next();
                    int tidx = clique.getTemplate().index;
                    if (tidx == -1) {
                        continue;
                    }
                    int assn = clique.lookupAssignmentNumber();
                    this.constraints[tidx][assn].plusEqualsSparse(clique.getFv(), clique.size());
                    if (this.defaultConstraints[tidx].location(assn) != -1) {
                        this.defaultConstraints[tidx].incrementValue(assn, clique.size());
                    }
                }
                if (this.allWrongWrongs != null) {
                    for (WrongWrong ww : this.allWrongWrongs.get(inum)) {
                        ACRF.UnrolledVarSet clique = ww.findVarSet(unrolled);
                        int tidx = clique.getTemplate().index;
                        int wrong2rightId = ww.assnIdx;
                        this.constraints[tidx][wrong2rightId].plusEqualsSparse(clique.getFv(), 1.0);
                        if (this.defaultConstraints[tidx].location(wrong2rightId) != -1) {
                            this.defaultConstraints[tidx].incrementValue(wrong2rightId, 1.0);
                        }
                    }
                }
            }
        }

        /** Writes the current gradient, one component per line, to {@code fileName}. */
        void dumpGradientToFile(String fileName) {
            PrintStream w = null;
            try {
                double[] grad = new double[this.getNumParameters()];
                this.getValueGradient(grad);
                w = new PrintStream(new FileOutputStream(fileName));
                for (int i = 0; i < this.numParameters; i++) {
                    w.println(grad[i]);
                }
            } catch (IOException e) {
                System.err.println("Could not open output file.");
                e.printStackTrace();
            } finally {
                if (w != null) {
                    w.close();
                }
            }
        }

        /** Prints every template's default constraints and default expectations to standard output. */
        void dumpDefaults() {
            System.out.println("Default constraints");
            for (int i = 0; i < this.defaultConstraints.length; i++) {
                System.out.println("Template " + i);
                this.defaultConstraints[i].print();
            }
            System.out.println("Default expectations");
            for (int i = 0; i < this.defaultExpectations.length; i++) {
                System.out.println("Template " + i);
                this.defaultExpectations[i].print();
            }
        }

        /** Dumps the model and, per clique, the observed assignment and factor value. */
        void printDebugInfo(ACRF.UnrolledGraph unrolled) {
            this.acrf.print(System.err);
            Assignment assn = unrolled.getAssignment();
            for (Iterator it = unrolled.unrolledVarSetIterator(); it.hasNext(); ) {
                ACRF.UnrolledVarSet clique = (ACRF.UnrolledVarSet) it.next();
                System.out.println("Clique " + clique);
                this.dumpAssnForClique(assn, clique);
                Factor ptl = unrolled.factorOf(clique);
                System.out.println("Value = " + ptl.value(assn));
                System.out.println(ptl);
            }
        }

        /** Prints each clique variable with its assigned object and outcome index. */
        void dumpAssnForClique(Assignment assn, ACRF.UnrolledVarSet clique) {
            for (int vi = 0; vi < clique.size(); vi++) {
                Variable var = clique.get(vi);
                System.out.println(var + " ==> " + assn.getObject(var) + "  (" + assn.get(var) + ")");
            }
        }

        /** Returns {@code false} (with a warning) if {@code w} is infinite or NaN. */
        private boolean weightValid(double w, int cnum, int j) {
            if (Double.isInfinite(w)) {
                logger.warning("Weight is infinite for clique " + cnum + " assignment " + j);
                return false;
            }
            if (Double.isNaN(w)) {
                logger.warning("Weight is Nan for clique " + cnum + " assignment " + j);
                return false;
            }
            return true;
        }

        /**
         * Harvests wrong-wrong terms using the current parameters, then recomputes the constraints.
         *
         * <p>For every training instance this runs inference. It then examines every clique
         * assignment whose marginal exceeds the trainer's threshold. For each variable {@code var}
         * of such an assignment, a wrong-wrong is recorded when {@code var} is correct and some
         * other clique variable is wrong. Constraints are then recomputed with those terms
         * included, and cached values are marked stale.
         */
        private void addWrongWrong(InstanceList training) {
            List<List<WrongWrong>> harvested = new ArrayList<List<WrongWrong>>(training.size());
            int totalAdded = 0;
            for (int i = 0; i < training.size(); i++) {
                List<WrongWrong> forInstance = new ArrayList<WrongWrong>();
                harvested.add(forInstance);
                Instance instance = (Instance) training.get(i);
                ACRF.UnrolledGraph unrolled = this.acrf.unroll(instance);
                if (unrolled.factors().size() == 0) {
                    System.err.println("WARNING: FactorGraph for instance " + instance.getName() + " : no factors.");
                    continue;
                }
                Inferencer inf = this.acrf.getInferencer();
                inf.computeMarginals(unrolled);
                Assignment target = unrolled.getAssignment();
                for (Iterator it = unrolled.unrolledVarSetIterator(); it.hasNext(); ) {
                    ACRF.UnrolledVarSet vs = (ACRF.UnrolledVarSet) it.next();
                    Factor marg = inf.lookupMarginal(vs);
                    for (AssignmentIterator assnIt = vs.assignmentIterator(); assnIt.hasNext(); assnIt.advance()) {
                        if (marg.value(assnIt) <= PwplACRFTrainer.this.wrongWrongThreshold) {
                            continue;
                        }
                        Assignment assn = assnIt.assignment();
                        for (int vi = 0; vi < vs.size(); vi++) {
                            Variable var = vs.get(vi);
                            if (this.isWrong2RightAssn(target, assn, var)) {
                                forInstance.add(new WrongWrong(unrolled, vs, var, assn.singleIndex()));
                            }
                        }
                    }
                }
                logger.info("WrongWrongs: Instance " + i + " : " + instance.getName() + " Num added = " + forInstance.size());
                totalAdded += forInstance.size();
            }
            this.allWrongWrongs = harvested;
            this.resetConstraints();
            this.collectConstraints(training);
            this.forceStale();
            logger.info("Total timesteps = " + this.totalTimesteps(training));
            logger.info("Total WrongWrongs = " + totalAdded);
        }

        /** Sums the sizes of the instances' data; assumes each instance's data is a {@link Sequence}. */
        private int totalTimesteps(InstanceList ilist) {
            int total = 0;
            for (int i = 0; i < ilist.size(); i++) {
                Instance inst = (Instance) ilist.get(i);
                total += ((Sequence) inst.getData()).size();
            }
            return total;
        }

        /**
         * Returns {@code true} iff {@code assn} disagrees with {@code target} on at least one
         * variable other than {@code toExclude}, and agrees with {@code target} on
         * {@code toExclude}.
         */
        private boolean isWrong2RightAssn(Assignment target, Assignment assn, Variable toExclude) {
            boolean someOtherWrong = false;
            for (Variable variable : assn.getVars()) {
                if (variable != toExclude && assn.get(variable) != target.get(variable)) {
                    someOtherWrong = true;
                    break;
                }
            }
            return someOtherWrong && assn.get(toExclude) == target.get(toExclude);
        }

        /**
         * A harvested wrong-wrong term, stored as graph indices so it can be re-resolved in a freshly
         * unrolled graph of the same instance. This assumes unrolling assigns indices
         * deterministically; confirm.
         */
        private class WrongWrong {
            int varIdx;
            int vsIdx;
            /** Single index of the (partly wrong) clique assignment. */
            int assnIdx;

            public WrongWrong(ACRF.UnrolledGraph graph, VarSet vs, Variable var, int assnIdx) {
                this.varIdx = graph.getIndex(var);
                this.vsIdx = graph.getIndex(vs);
                this.assnIdx = assnIdx;
            }

            public ACRF.UnrolledVarSet findVarSet(ACRF.UnrolledGraph unrolled) {
                return unrolled.getUnrolledVarSet(this.vsIdx);
            }

            public Variable findVariable(ACRF.UnrolledGraph unrolled) {
                return unrolled.get(this.varIdx);
            }
        }
    }
}

