/*
 * Decompiled with CFR 0.152.
 */
package cc.mallet.grmm.inference;

import cc.mallet.grmm.inference.AbstractBeliefPropagation;
import cc.mallet.grmm.types.Assignment;
import cc.mallet.grmm.types.DirectedModel;
import cc.mallet.grmm.types.Factor;
import cc.mallet.grmm.types.FactorGraph;
import cc.mallet.grmm.types.HashVarSet;
import cc.mallet.grmm.types.TableFactor;
import cc.mallet.grmm.types.Tree;
import cc.mallet.grmm.types.VarSet;
import cc.mallet.grmm.types.Variable;
import cc.mallet.util.MalletLogger;
import gnu.trove.THashMap;
import gnu.trove.THashSet;
import gnu.trove.TIntObjectHashMap;
import java.io.File;
import java.io.FileWriter;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.PrintWriter;
import java.io.Reader;
import java.io.Serializable;
import java.io.Writer;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Random;
import java.util.Set;
import java.util.logging.Level;
import java.util.logging.Logger;
import org._3pq.jgrapht.Edge;
import org._3pq.jgrapht.Graph;
import org._3pq.jgrapht.graph.SimpleGraph;
import org._3pq.jgrapht.traverse.BreadthFirstIterator;
import org.jdom.Document;
import org.jdom.Element;
import org.jdom.JDOMException;
import org.jdom.input.SAXBuilder;

public class TRP
extends AbstractBeliefPropagation {
    /** Logger shared by TRP and its nested terminator and factory classes. */
    private static Logger logger = MalletLogger.getLogger(TRP.class.getName());
    // Compile-time flag; nothing else in this file reads it.
    private static final boolean reportSpanningTrees = false;
    /** Supplies the tree used on each iteration; defaulted in initForGraph when null. */
    private TreeFactory factory;
    /** Decides whether another iteration should run; defaulted in initForGraph when null. */
    private TerminationCondition terminator;
    /** Random source used by AlmostRandomTreeFactory to shuffle factors. */
    private Random random = new Random();
    /** Maps a factor index to a boxed count of how many trees have included that factor. */
    private transient TIntObjectHashMap factorTouched;
    /** Convergence flag returned by isConverged(); reset to false in initForGraph. */
    private transient boolean hasConverged;
    /** When non-null, per-iteration dump files are written into this directory. */
    private transient File verboseOutputDirectory = null;
    private static final long serialVersionUID = 1L;

    /**
     * Creates an instance with no tree factory and no termination condition;
     * both are filled in with defaults when inference is initialized for a graph.
     */
    public TRP() {
        this(null, null);
    }

    /**
     * Creates an instance that draws its trees from the given factory.
     *
     * @param f tree factory to use, or null for the default
     */
    public TRP(TreeFactory f) {
        this(f, null);
    }

    /**
     * Creates an instance that stops according to the given condition.
     *
     * @param cond termination condition to use, or null for the default
     */
    public TRP(TerminationCondition cond) {
        this(null, cond);
    }

    /**
     * Creates an instance with an explicit tree factory and termination condition.
     *
     * @param f    tree factory to use, or null for the default
     * @param cond termination condition to use, or null for the default
     */
    public TRP(TreeFactory f, TerminationCondition cond) {
        terminator = cond;
        factory = f;
    }

    /**
     * Builds a default-configured TRP and installs a
     * {@code MaxProductMessageStrategy} as its messager.
     *
     * @return the newly created instance
     */
    public static TRP createForMaxProduct() {
        TRP maxProduct = new TRP();
        AbstractBeliefPropagation.MaxProductMessageStrategy strategy =
                new AbstractBeliefPropagation.MaxProductMessageStrategy();
        maxProduct.setMessager(strategy);
        return maxProduct;
    }

    /**
     * Replaces the termination condition.
     *
     * @param cond the condition to consult before each iteration
     * @return this instance, so calls can be chained
     */
    public TRP setTerminator(TerminationCondition cond) {
        terminator = cond;
        return this;
    }

    /**
     * Replaces the factory that supplies the tree for each iteration.
     *
     * @param factory the tree source to use; when null, initForGraph installs
     *                an AlmostRandomTreeFactory
     * @return this instance, so calls can be chained
     */
    public TRP setFactory(TreeFactory factory) {
        this.factory = factory;
        return this;
    }

    /**
     * Replaces the random source with a new generator seeded by the given value.
     *
     * @param seed seed for the new {@link Random}
     */
    public void setRandomSeed(long seed) {
        Random seeded = new Random(seed);
        random = seeded;
    }

    /**
     * Sets the directory that receives per-iteration dump files.
     *
     * @param verboseOutputDirectory target directory, or null to disable the dumps
     */
    public void setVerboseOutputDirectory(File verboseOutputDirectory) {
        this.verboseOutputDirectory = verboseOutputDirectory;
    }

    /**
     * Returns the stored convergence flag. The flag is reset to false each time
     * inference is initialized for a graph.
     *
     * @return the current value of the convergence flag
     */
    public boolean isConverged() {
        return hasConverged;
    }

    /**
     * Prepares per-run state: delegates to the superclass, clears the touch
     * counts and convergence flag, and makes sure a tree factory and a
     * termination condition are in place.
     *
     * @param m the factor graph about to be processed
     */
    @Override
    protected void initForGraph(FactorGraph m) {
        super.initForGraph(m);
        factorTouched = new TIntObjectHashMap(m.numVariables());
        hasConverged = false;
        if (factory == null) {
            factory = new AlmostRandomTreeFactory();
        }
        if (terminator != null) {
            // A caller-supplied (or previously used) terminator may carry state.
            terminator.reset();
        } else {
            terminator = new DefaultConvergenceTerminator();
        }
    }

    /**
     * Converts a graph into a rooted {@code Tree}. The first vertex returned by
     * the vertex set's iterator becomes the root; vertices are then visited
     * breadth-first from it, and every neighbour that is not the visited
     * vertex's parent is attached beneath it.
     *
     * @param g the graph to convert
     * @return the resulting tree
     * @throws RuntimeException if the graph has no vertices
     */
    private static Tree graphToTree(Graph g) throws Exception {
        if (g.vertexSet().isEmpty()) {
            throw new RuntimeException("Empty graph.");
        }
        Tree result = new Tree();
        Object root = g.vertexSet().iterator().next();
        result.add(root);
        for (BreadthFirstIterator bfs = new BreadthFirstIterator(g, root); bfs.hasNext();) {
            Object current = bfs.next();
            for (Iterator edgeIt = g.edgesOf(current).iterator(); edgeIt.hasNext();) {
                Edge edge = (Edge)edgeIt.next();
                Object neighbour = edge.oppositeVertex(current);
                // Skip the edge that leads back up to the already-attached parent.
                if (result.getParent(current) != neighbour) {
                    result.addNode(current, neighbour);
                    assert (result.getParent(neighbour) == current);
                }
            }
        }
        return result;
    }

    /**
     * Runs inference on the graph: while the termination condition allows,
     * obtains the next tree from the factory, propagates messages over it and
     * optionally dumps diagnostics. The number of iterations performed is
     * stored in {@code iterUsed}.
     *
     * @param m the factor graph to run inference on
     */
    @Override
    public void computeMarginals(FactorGraph m) {
        resetMessagesSentAtStart();
        initForGraph(m);
        int iter = 0;
        while (terminator.shouldContinue(this)) {
            // The log shows the zero-based index; the dump files use the one-based count.
            logger.finer("TRP iteration " + iter);
            iter++;
            Tree current = factory.nextTree(m);
            propagate(current);
            dumpForIter(iter, current);
        }
        iterUsed = iter;
        logger.info("TRP used " + iter + " iterations.");
        doneWithGraph(m);
    }

    /**
     * Writes diagnostic files for one iteration into the verbose output
     * directory, if one is configured: {@code iter<N>.txt} (output of
     * {@code dump}), {@code beliefs<N>.txt} (current marginals) and
     * {@code tree<N>.txt} (the tree just used). This is best-effort: an I/O
     * failure is logged and inference continues.
     *
     * @param iter iteration number used in the file names
     * @param tree the tree that was propagated over in this iteration
     */
    private void dumpForIter(int iter, Tree tree) {
        if (this.verboseOutputDirectory == null) {
            return;
        }
        try {
            // Each writer is closed in a finally block so a failure while
            // dumping does not leak the underlying file handle.
            FileWriter writer = new FileWriter(new File(this.verboseOutputDirectory, "iter" + iter + ".txt"));
            try {
                this.dump(new PrintWriter((Writer)writer, true));
            }
            finally {
                writer.close();
            }
            FileWriter bfWriter = new FileWriter(new File(this.verboseOutputDirectory, "beliefs" + iter + ".txt"));
            try {
                this.dumpBeliefs(new PrintWriter((Writer)bfWriter, true));
            }
            finally {
                bfWriter.close();
            }
            FileWriter treeWriter = new FileWriter(new File(this.verboseOutputDirectory, "tree" + iter + ".txt"));
            try {
                treeWriter.write(tree.toString());
                treeWriter.write("\n");
            }
            finally {
                treeWriter.close();
            }
        }
        catch (IOException e) {
            logger.log(Level.WARNING, "TRP: could not write verbose output for iteration " + iter, e);
        }
    }

    /**
     * Prints the marginal of every variable in the current model, each
     * followed by a blank line.
     *
     * @param writer destination for the dump
     */
    private void dumpBeliefs(PrintWriter writer) {
        for (int vi = 0; vi < mdlCurrent.numVariables(); vi++) {
            Variable variable = mdlCurrent.get(vi);
            Factor marginal = lookupMarginal(variable);
            String text = marginal.dumpToString();
            writer.println(text);
            writer.println();
        }
    }

    /**
     * Performs one full pass over a tree: messages first flow from the
     * descendants up to the root, then from the root back down.
     *
     * @param tree the tree to propagate over
     */
    private void propagate(Tree tree) {
        final Object treeRoot = tree.getRoot();
        lambdaPropagation(tree, treeRoot);
        piPropagation(tree, treeRoot);
    }

    /**
     * Upward pass: every non-root node sends a message to its parent, with
     * each node sending only after all of its own descendants have sent.
     *
     * @param tree the tree to propagate over
     * @param root the root of that tree
     */
    private void lambdaPropagation(Tree tree, Object root) {
        LinkedList frontier = new LinkedList();
        LinkedList sendOrder = new LinkedList();
        frontier.addAll(tree.getChildren(root));
        // Breadth-first walk; prepending reverses it so deeper nodes come first.
        while (!frontier.isEmpty()) {
            Object node = frontier.removeFirst();
            frontier.addAll(tree.getChildren(node));
            sendOrder.addFirst(node);
        }
        for (Iterator it = sendOrder.iterator(); it.hasNext();) {
            Object child = it.next();
            Object parent = tree.getParent(child);
            sendMessage(mdlCurrent, child, parent);
        }
    }

    /**
     * Downward pass: starting at the root and proceeding breadth-first, every
     * node sends a message to each of its children.
     *
     * @param tree the tree to propagate over
     * @param root the root of that tree
     */
    private void piPropagation(Tree tree, Object root) {
        LinkedList<Object> frontier = new LinkedList<Object>();
        frontier.add(root);
        while (!frontier.isEmpty()) {
            Object sender = frontier.removeFirst();
            for (Iterator it = tree.getChildren(sender).iterator(); it.hasNext();) {
                Object child = it.next();
                sendMessage(mdlCurrent, sender, child);
                frontier.add(child);
            }
        }
    }

    /**
     * Dispatches a message between two tree nodes to the typed overload that
     * matches the sender: factor-to-variable or variable-to-factor. A sender
     * that is neither a Factor nor a Variable is ignored.
     *
     * @param fg     the factor graph the nodes belong to
     * @param parent the sending node
     * @param child  the receiving node
     */
    private void sendMessage(FactorGraph fg, Object parent, Object child) {
        if (logger.isLoggable(Level.FINER)) {
            logger.finer("Sending message: " + parent + " --> " + child);
        }
        if (parent instanceof Factor) {
            Factor sender = (Factor)parent;
            Variable receiver = (Variable)child;
            sendMessage(fg, sender, receiver);
            return;
        }
        if (parent instanceof Variable) {
            Variable sender = (Variable)parent;
            Factor receiver = (Factor)child;
            sendMessage(fg, sender, receiver);
        }
    }

    /**
     * Checks whether every factor of the current model has a non-zero touch
     * count.
     *
     * @return false as soon as a factor with a zero count is found, true otherwise
     */
    private boolean allEdgesTouched() {
        for (Iterator it = mdlCurrent.factorsIterator(); it.hasNext();) {
            Factor factor = (Factor)it.next();
            int idx = mdlCurrent.getIndex(factor);
            if (getNumTouches(idx) == 0) {
                logger.finest("***TRP continuing: factor " + idx + " not touched.");
                return false;
            }
        }
        return true;
    }

    /**
     * Records one more use of the given factor.
     *
     * @param factor a factor of the current model
     */
    private void touchFactor(Factor factor) {
        incrementTouches(mdlCurrent.getIndex(factor));
    }

    /**
     * Tells whether the given factor has a positive touch count.
     *
     * @param factor a factor of the current model
     * @return true if the factor's touch count is greater than zero
     */
    private boolean isFactorTouched(Factor factor) {
        int touches = getNumTouches(mdlCurrent.getIndex(factor));
        return touches > 0;
    }

    /**
     * Looks up the touch count stored for a factor index.
     *
     * @param idx1 index of the factor in the current model
     * @return the stored count, or 0 when nothing has been stored for that index
     */
    private int getNumTouches(int idx1) {
        Integer stored = (Integer)factorTouched.get(idx1);
        if (stored == null) {
            return 0;
        }
        return stored.intValue();
    }

    /**
     * Adds one to the touch count stored for a factor index.
     *
     * @param idx1 index of the factor in the current model
     */
    private void incrementTouches(int idx1) {
        int nt = this.getNumTouches(idx1);
        // Integer.valueOf replaces the deprecated Integer constructor and can
        // reuse cached instances for small counts.
        this.factorTouched.put(idx1, (Object)Integer.valueOf(nt + 1));
    }

    /**
     * Not supported: always throws.
     *
     * @param m   ignored
     * @param var ignored
     * @return never returns normally
     * @throws UnsupportedOperationException always
     */
    public Factor query(DirectedModel m, Variable var) {
        throw new UnsupportedOperationException("GRMM doesn't yet do directed models.");
    }

    /**
     * Builds an assignment by taking, for every variable of the current model,
     * the argmax of that variable's marginal. Each marginal is cast to
     * {@code TableFactor}.
     *
     * @return an assignment over the current model's variables
     */
    public Assignment bestAssignment() {
        int[] outcomes = new int[mdlCurrent.numVariables()];
        for (int i = 0; i < outcomes.length; i++) {
            Variable variable = mdlCurrent.get(i);
            TableFactor marginal = (TableFactor)lookupMarginal(variable);
            outcomes[i] = marginal.argmax();
        }
        return new Assignment(mdlCurrent, outcomes);
    }

    /**
     * Returns a copy of this inferencer. The termination condition is cloned
     * so the copy has its own iteration state. An
     * {@code AlmostRandomTreeFactory} is replaced by a fresh one bound to the
     * copy; any other factory is shared with the original.
     *
     * @return the copy
     * @throws RuntimeException if cloning is not supported
     */
    public Object clone() {
        try {
            TRP dup = (TRP)super.clone();
            if (this.terminator != null) {
                dup.terminator = (TerminationCondition)this.terminator.clone();
            }
            // AlmostRandomTreeFactory is an inner class: an instance copied by
            // reference would stay bound to the original TRP and read/update
            // the original's touch counts. The factory holds no state of its
            // own, so a fresh one bound to the copy is equivalent.
            if (this.factory != null && this.factory.getClass() == AlmostRandomTreeFactory.class) {
                dup.factory = dup.new AlmostRandomTreeFactory();
            }
            return dup;
        }
        catch (CloneNotSupportedException e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Serialization hook; writes this object's fields using default serialization.
     *
     * @param out stream to write to
     * @throws IOException if writing fails
     */
    private void writeObject(ObjectOutputStream out) throws IOException {
        out.defaultWriteObject();
    }

    /**
     * Deserialization hook; restores this object's fields using default
     * deserialization. Fields declared transient (factorTouched, hasConverged,
     * verboseOutputDirectory) are not restored here; factorTouched and
     * hasConverged are assigned again in initForGraph.
     *
     * @param in stream to read from
     * @throws IOException            if reading fails
     * @throws ClassNotFoundException if a serialized class cannot be resolved
     */
    private void readObject(ObjectInputStream in) throws IOException, ClassNotFoundException {
        in.defaultReadObject();
    }

    /**
     * Tree factory that builds a randomized tree over the graph's variables
     * and a subset of its factors, giving priority to factors the enclosing
     * TRP has not used yet. Uses the enclosing instance's random source and
     * touch counts.
     */
    public class AlmostRandomTreeFactory
    implements TreeFactory {
        private static final long serialVersionUID = -7461763414516915264L;

        /**
         * Selects factors in shuffled order such that no chosen factor joins
         * two variables that are already connected, marks the chosen factors
         * as touched, and converts the resulting variable/factor graph to a
         * tree.
         *
         * @param fullGraph the factor graph to draw a tree from
         * @return the tree produced by {@code graphToTree}
         * @throws RuntimeException wrapping any exception raised while building the tree
         */
        @Override
        public Tree nextTree(FactorGraph fullGraph) {
            SimpleUnionFind components = new SimpleUnionFind();
            ArrayList candidates = new ArrayList(fullGraph.factors());
            ArrayList<Factor> chosen = new ArrayList<Factor>(fullGraph.numVariables());
            Collections.shuffle(candidates, random);
            try {
                // Pass 1: untouched factors get first pick. A touched factor is
                // skipped before the connectivity test is evaluated.
                for (Iterator it = candidates.iterator(); it.hasNext();) {
                    Factor factor = (Factor)it.next();
                    VarSet vars = factor.varSet();
                    if (!isFactorTouched(factor) && components.noPairConnected(vars)) {
                        chosen.add(factor);
                        components.unionAll(factor);
                        it.remove();
                    }
                }
                // Pass 2: fill in with whatever remaining factors still fit.
                for (Iterator it = candidates.iterator(); it.hasNext();) {
                    Factor factor = (Factor)it.next();
                    VarSet vars = factor.varSet();
                    if (components.noPairConnected(vars)) {
                        chosen.add(factor);
                        components.unionAll(factor);
                    }
                }
                for (Iterator it = chosen.iterator(); it.hasNext();) {
                    touchFactor((Factor)it.next());
                }
                // Graph with one vertex per variable and per chosen factor,
                // and an edge between each chosen factor and its variables.
                SimpleGraph g = new SimpleGraph();
                for (Iterator varIt = fullGraph.variablesIterator(); varIt.hasNext();) {
                    Variable var = (Variable)varIt.next();
                    g.addVertex((Object)var);
                }
                for (Iterator it = chosen.iterator(); it.hasNext();) {
                    Factor factor = (Factor)it.next();
                    g.addVertex((Object)factor);
                    for (Iterator varIt = factor.varSet().iterator(); varIt.hasNext();) {
                        Variable var = (Variable)varIt.next();
                        g.addEdge((Object)factor, (Object)var);
                    }
                }
                return graphToTree((Graph)g);
            }
            catch (Exception e) {
                e.printStackTrace();
                throw new RuntimeException(e);
            }
        }
    }

    /**
     * Termination condition that stops once the TRP's
     * {@code hasConverged(delta)} test succeeds.
     */
    public static class ConvergenceTerminator
    implements TerminationCondition {
        /** Threshold passed to the TRP's convergence test. */
        double delta = 0.01;

        /** Creates a terminator with the default threshold of 0.01. */
        public ConvergenceTerminator() {
        }

        /**
         * Creates a terminator with the given threshold.
         *
         * @param delta threshold passed to the TRP's convergence test
         */
        public ConvergenceTerminator(double delta) {
            this.delta = delta;
        }

        /** No state to reset. */
        @Override
        public void reset() {
        }

        /**
         * Runs the convergence test, records its outcome on the TRP so that
         * {@code isConverged()} reflects it, and then snapshots the current
         * messages via {@code copyOldMessages()}.
         *
         * @param trp the inferencer being run
         * @return true if the convergence test failed and another iteration should run
         */
        @Override
        public boolean shouldContinue(TRP trp) {
            boolean converged = trp.hasConverged(this.delta);
            // Without this assignment the flag behind isConverged() is only
            // ever reset to false and can never report convergence.
            trp.hasConverged = converged;
            trp.copyOldMessages();
            return !converged;
        }

        @Override
        public Object clone() throws CloneNotSupportedException {
            return super.clone();
        }
    }

    /**
     * Termination condition combining an iteration cap with a convergence
     * test: it stops when the cap is exceeded, keeps going while some factor
     * is still untouched, and otherwise defers to the convergence test.
     */
    public static class DefaultConvergenceTerminator
    implements TerminationCondition {
        ConvergenceTerminator cterminator;
        IterationTerminator iterminator;
        /** Warning logged when the iteration cap ends the run. */
        String msg;

        /** Uses a convergence threshold of 0.001 and a cap of 1000 iterations. */
        public DefaultConvergenceTerminator() {
            this(0.001, 1000);
        }

        /**
         * @param delta   threshold for the convergence test
         * @param maxIter iteration cap
         */
        public DefaultConvergenceTerminator(double delta, int maxIter) {
            cterminator = new ConvergenceTerminator(delta);
            iterminator = new IterationTerminator(maxIter);
            msg = "***TRP quitting: over " + maxIter + " iterations";
        }

        /** Resets the iteration counter and the convergence terminator. */
        @Override
        public void reset() {
            iterminator.reset();
            cterminator.reset();
        }

        /**
         * @param trp the inferencer being run
         * @return false when the iteration cap is exceeded; true while some
         *         factor is untouched; otherwise the convergence terminator's answer
         */
        @Override
        public boolean shouldContinue(TRP trp) {
            boolean someUntouched = !trp.allEdgesTouched();
            if (iterminator.shouldContinue(trp)) {
                // The convergence test is consulted only once every factor has been used.
                return someUntouched || cterminator.shouldContinue(trp);
            }
            logger.warning(msg);
            if (someUntouched) {
                logger.warning("***TRP warning: Not all edges used!");
            }
            return false;
        }

        /** Returns a copy whose two component terminators are themselves cloned. */
        @Override
        public Object clone() throws CloneNotSupportedException {
            DefaultConvergenceTerminator copy = (DefaultConvergenceTerminator)super.clone();
            copy.iterminator = (IterationTerminator)iterminator.clone();
            copy.cterminator = (ConvergenceTerminator)cterminator.clone();
            return copy;
        }
    }

    /**
     * Termination condition that allows a fixed number of iterations:
     * {@code shouldContinue} returns true for the first {@code max} calls
     * after a reset and false afterwards.
     */
    public static class IterationTerminator
    implements TerminationCondition {
        /** Number of shouldContinue calls since the last reset. */
        int current;
        /** Maximum number of iterations allowed. */
        int max;

        /** Restarts the count at zero. */
        @Override
        public void reset() {
            this.current = 0;
        }

        /**
         * @param m maximum number of iterations allowed
         */
        public IterationTerminator(int m) {
            this.max = m;
            this.reset();
        }

        /**
         * Counts this call and reports whether the cap has been exceeded.
         *
         * @param trp the inferencer being run (not used)
         * @return true while the call count is at most {@code max}
         */
        @Override
        public boolean shouldContinue(TRP trp) {
            ++this.current;
            boolean keepGoing = this.current <= this.max;
            // Log "quitting" only when this call actually stops the run; the
            // call with current == max still returns true.
            if (!keepGoing) {
                logger.finest("***TRP quitting: Iteration " + this.current + " > " + this.max);
            }
            return keepGoing;
        }

        @Override
        public Object clone() throws CloneNotSupportedException {
            return super.clone();
        }
    }

    /**
     * Minimal disjoint-set structure: every object maps to the set object
     * holding all members of its group, and two objects are in the same group
     * exactly when they map to the same set instance.
     */
    private static class SimpleUnionFind {
        private Map obj2set = new THashMap();

        private SimpleUnionFind() {
        }

        /**
         * Returns the group containing the object, registering a new
         * singleton group if the object has not been seen before.
         */
        private Set findSet(Object obj) {
            Set container = (Set)this.obj2set.get(obj);
            if (container != null) {
                return container;
            }
            THashSet newSet = new THashSet();
            newSet.add(obj);
            this.obj2set.put(obj, newSet);
            return newSet;
        }

        /**
         * Merges the groups of the two objects. The smaller group is folded
         * into the larger one so that fewer members need to be re-pointed;
         * nothing is done when both objects already share a group.
         */
        private void union(Object obj1, Object obj2) {
            Set keep = this.findSet(obj1);
            Set absorb = this.findSet(obj2);
            if (keep == absorb) {
                return;
            }
            if (keep.size() < absorb.size()) {
                Set swap = keep;
                keep = absorb;
                absorb = swap;
            }
            keep.addAll(absorb);
            for (Object member : absorb) {
                this.obj2set.put(member, keep);
            }
        }

        /**
         * Tests whether all variables of the set lie in pairwise different groups.
         *
         * @param varSet the variables to test
         * @return false if any two of the variables share a group, true otherwise
         */
        public boolean noPairConnected(VarSet varSet) {
            int size = varSet.size();
            for (int i = 0; i < size; ++i) {
                for (int j = i + 1; j < size; ++j) {
                    Variable v1 = varSet.get(i);
                    Variable v2 = varSet.get(j);
                    if (this.findSet(v1) == this.findSet(v2)) {
                        return false;
                    }
                }
            }
            return true;
        }

        /**
         * Merges the factor and all of its variables into a single group.
         *
         * @param factor the factor whose variables are to be connected
         */
        public void unionAll(Factor factor) {
            VarSet varSet = factor.varSet();
            for (int i = 0; i < varSet.size(); ++i) {
                Variable var = varSet.get(i);
                this.union(var, factor);
            }
        }
    }

    /**
     * Strategy that decides, before each iteration, whether the inferencer
     * should keep going. Implementations are cloneable and serializable.
     */
    public static interface TerminationCondition
    extends Cloneable,
    Serializable {
        /**
         * Called before every iteration.
         *
         * @param var1 the inferencer being run
         * @return true to run another iteration, false to stop
         */
        public boolean shouldContinue(TRP var1);

        /** Clears any per-run state; invoked from initForGraph on an already-installed condition. */
        public void reset();

        /** Returns a copy of this condition. */
        public Object clone() throws CloneNotSupportedException;
    }

    /** Serializable source of the trees that the inferencer propagates over, one per iteration. */
    public static interface TreeFactory
    extends Serializable {
        /**
         * Produces the tree for the next iteration.
         *
         * @param var1 the factor graph inference is running on
         * @return the tree to propagate over
         */
        public Tree nextTree(FactorGraph var1);
    }

    /**
     * Tree factory that replays a fixed list of trees in order, starting over
     * from the first tree once the list is exhausted. Trees can be supplied
     * directly or parsed from XML documents made of {@code VAR} and
     * {@code FACTOR} elements.
     */
    public static class TreeListFactory
    implements TreeFactory {
        private List lst;
        // TreeFactory is Serializable but list iterators generally are not, so
        // the cursor is kept out of the serialized form and rebuilt on demand.
        private transient Iterator it;

        /**
         * @param l the trees to cycle through; the list is used as-is, not copied
         */
        public TreeListFactory(List l) {
            this.lst = l;
            this.it = this.lst.iterator();
        }

        /**
         * @param arr the trees to cycle through; copied into an internal list
         */
        public TreeListFactory(Tree[] arr) {
            this.lst = new ArrayList<Tree>(Arrays.asList(arr));
            this.it = this.lst.iterator();
        }

        /**
         * Parses one tree from each reader.
         *
         * @param fg         factor graph used to resolve variable names and factors
         * @param readerList list of {@link Reader}s, each supplying one XML tree document
         * @return a factory over the parsed trees, in reader order
         * @throws RuntimeException wrapping any JDOM or I/O failure
         */
        public static TreeListFactory makeFromReaders(FactorGraph fg, List readerList) {
            ArrayList<Tree> treeList = new ArrayList<Tree>();
            Iterator it = readerList.iterator();
            while (it.hasNext()) {
                try {
                    Reader reader = (Reader)it.next();
                    Tree tree = TreeListFactory.treeFromDocument(fg, new SAXBuilder().build(reader));
                    // Diagnostic dump goes to the logger rather than stdout.
                    if (logger.isLoggable(Level.FINE)) {
                        logger.fine(tree.dumpToString());
                    }
                    treeList.add(tree);
                }
                catch (JDOMException e) {
                    throw new RuntimeException(e);
                }
                catch (IOException e) {
                    throw new RuntimeException(e);
                }
            }
            return new TreeListFactory(treeList);
        }

        /**
         * Parses one tree from each file.
         *
         * @param fg       factor graph used to resolve variable names and factors
         * @param fileList list of {@link File}s, each holding one XML tree document
         * @return a factory over the parsed trees, in file order
         * @throws RuntimeException wrapping any JDOM or I/O failure
         */
        public static TreeListFactory readFromFiles(FactorGraph fg, List fileList) {
            ArrayList<Tree> treeList = new ArrayList<Tree>();
            Iterator it = fileList.iterator();
            while (it.hasNext()) {
                try {
                    File treeFile = (File)it.next();
                    treeList.add(TreeListFactory.treeFromDocument(fg, new SAXBuilder().build(treeFile)));
                }
                catch (JDOMException e) {
                    throw new RuntimeException(e);
                }
                catch (IOException e) {
                    throw new RuntimeException(e);
                }
            }
            return new TreeListFactory(treeList);
        }

        /** Builds a tree from a parsed document: the first child of the document's root element is the tree root. */
        private static Tree treeFromDocument(FactorGraph fg, Document doc) {
            Element treeElt = doc.getRootElement();
            Element rootElt = (Element)treeElt.getChildren().get(0);
            return TreeListFactory.readTreeRec(fg, rootElt);
        }

        /** Recursively converts an element and its child elements into a tree rooted at the element's object. */
        private static Tree readTreeRec(FactorGraph fg, Element elt) {
            ArrayList<Tree> subtrees = new ArrayList<Tree>();
            for (Iterator childIt = elt.getChildren().iterator(); childIt.hasNext();) {
                Element child = (Element)childIt.next();
                subtrees.add(TreeListFactory.readTreeRec(fg, child));
            }
            Object parent = TreeListFactory.objFromElt(fg, elt);
            return Tree.makeFromSubtree(parent, subtrees);
        }

        /**
         * Resolves an element to a graph object: a {@code VAR} element to the
         * variable named by its {@code NAME} attribute, a {@code FACTOR}
         * element to the factor over the whitespace-separated variable names
         * in its {@code VARS} attribute.
         *
         * @throws RuntimeException if the element is of another type or a
         *                          FACTOR element lacks the VARS attribute
         */
        private static Object objFromElt(FactorGraph fg, Element elt) {
            String type = elt.getName();
            if (type.equals("VAR")) {
                String vname = elt.getAttributeValue("NAME");
                return fg.findVariable(vname);
            }
            if (type.equals("FACTOR")) {
                String varSetStr = elt.getAttributeValue("VARS");
                if (varSetStr == null) {
                    throw new RuntimeException("FACTOR element has no VARS attribute: " + elt);
                }
                // Trim first: leading whitespace would otherwise make split
                // yield an empty first name.
                String[] vnames = varSetStr.trim().split("\\s+");
                Variable[] vars = new Variable[vnames.length];
                for (int i = 0; i < vnames.length; ++i) {
                    vars[i] = fg.findVariable(vnames[i]);
                }
                return fg.factorOf(new HashVarSet(vars));
            }
            throw new RuntimeException("Can't figure out element " + elt);
        }

        /**
         * Returns the next tree in the list, wrapping around to the first tree
         * after the last one has been returned.
         *
         * @param mdl ignored
         * @return the next tree
         * @throws java.util.NoSuchElementException if the list is empty
         */
        @Override
        public Tree nextTree(FactorGraph mdl) {
            // The cursor is null right after deserialization.
            if (this.it == null || !this.it.hasNext()) {
                this.it = this.lst.iterator();
            }
            return (Tree)this.it.next();
        }
    }
}

