/*
 * Decompiled with CFR 0.152.
 */
package cc.mallet.grmm.inference;

import cc.mallet.grmm.inference.AbstractInferencer;
import cc.mallet.grmm.inference.MessageArray;
import cc.mallet.grmm.types.AbstractTableFactor;
import cc.mallet.grmm.types.Assignment;
import cc.mallet.grmm.types.AssignmentIterator;
import cc.mallet.grmm.types.Factor;
import cc.mallet.grmm.types.FactorGraph;
import cc.mallet.grmm.types.Factors;
import cc.mallet.grmm.types.LogTableFactor;
import cc.mallet.grmm.types.TableFactor;
import cc.mallet.grmm.types.VarSet;
import cc.mallet.grmm.types.Variable;
import cc.mallet.util.MalletLogger;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.PrintWriter;
import java.io.Serializable;
import java.util.Iterator;
import java.util.List;
import java.util.logging.Level;
import java.util.logging.Logger;

/**
 * Base class for belief-propagation inferencers over a {@link FactorGraph}.
 *
 * <p>Holds the current message array, a snapshot of earlier messages (used by
 * {@link #hasConverged()} and by damped message strategies), and per-variable beliefs that
 * are computed lazily from the incoming messages. How an individual message is computed is
 * delegated to a {@link MessageStrategy} (sum-product by default). The message schedule is
 * left to subclasses, which presumably call {@link #initForGraph}, {@link #sendMessage},
 * {@link #copyOldMessages}, {@link #hasConverged()} and {@link #doneWithGraph}.
 */
public abstract class AbstractBeliefPropagation
extends AbstractInferencer {
    protected static Logger logger = MalletLogger.getLogger(AbstractBeliefPropagation.class.getName());
    /** If true, marginals returned by {@link #lookupMarginal(Variable)} are normalized. */
    protected boolean normalizeBeliefs = true;
    /** Messages sent by all instances since class load (plain, unsynchronized counter). */
    private static int totalMessagesSent = 0;
    /** Messages sent by this instance. */
    private transient int myMessagesSent = 0;
    /** Value of {@link #myMessagesSent} at the last {@link #resetMessagesSentAtStart()}. */
    private transient int messagesSentAtStart = 0;
    /** Per-entry tolerance used by {@link #hasConverged()}. */
    private double threshold = 1.0E-5;
    /** If true, final messages are stored in, and reused from, the graph's inference cache. */
    protected boolean useCaching = false;
    /** Computes individual messages. */
    private MessageStrategy messager;
    /** Iterations used by the last run; not set in this class, presumably set by subclasses. */
    protected transient int iterUsed;
    /** Current messages. */
    private transient MessageArray messages;
    /** Snapshot of the messages taken by {@link #copyOldMessages()} (or the initial messages). */
    private transient MessageArray oldMessages;
    /** Lazily computed single-variable beliefs, indexed by the graph's variable index. */
    private transient Factor[] bel;
    /** Graph most recently passed to {@link #initForGraph}. */
    protected transient FactorGraph mdlCurrent;
    protected transient int[] assignedVertexPtls;
    private static final long serialVersionUID = 1L;

    /** Creates an inferencer that uses sum-product messages without damping. */
    protected AbstractBeliefPropagation() {
        this(new SumProductMessageStrategy());
    }

    /**
     * Creates an inferencer that uses the given message strategy.
     *
     * @param messager strategy used to compute messages
     */
    protected AbstractBeliefPropagation(MessageStrategy messager) {
        this.messager = messager;
    }

    /** @return the strategy used to compute messages */
    public MessageStrategy getMessager() {
        return this.messager;
    }

    /**
     * Replaces the message strategy.
     *
     * @param messager new strategy
     * @return this inferencer, for chaining
     */
    public AbstractBeliefPropagation setMessager(MessageStrategy messager) {
        this.messager = messager;
        return this;
    }

    /** @return number of messages sent by all instances of this class */
    public static int getTotalMessagesSent() {
        return totalMessagesSent;
    }

    /** @return number of messages sent by this instance */
    public int getMessagesSent() {
        return this.myMessagesSent;
    }

    /** @return messages sent by this instance since the last {@link #resetMessagesSentAtStart()} */
    public int getMessagesUsedLastTime() {
        return this.myMessagesSent - this.messagesSentAtStart;
    }

    /** Marks the current message count as the start of a new run. */
    protected void resetMessagesSentAtStart() {
        this.messagesSentAtStart = this.myMessagesSent;
    }

    /** Loads the current messages from the graph's inference cache for this class. */
    private void retrieveCachedMessages(FactorGraph m3) {
        this.messages = (MessageArray)m3.getInferenceCache(this.getClass());
    }

    /** Stores the current messages in the graph's inference cache for this class. */
    private void cacheMessages(FactorGraph m3) {
        m3.setInferenceCache(this.getClass(), this.messages);
    }

    private void clearOldMessages() {
        this.oldMessages = null;
    }

    /**
     * Snapshots the current messages into {@code oldMessages}.
     *
     * <p>The message strategy is re-pointed at the new snapshot so that strategies which read
     * the old messages (e.g. damped sum-product) see the latest snapshot rather than the array
     * registered in {@link #initForGraph}.
     */
    protected final void copyOldMessages() {
        this.clearOldMessages();
        this.oldMessages = this.messages.duplicate();
        if (this.messager != null) {
            this.messager.setMessageArray(this.messages, this.oldMessages);
        }
    }

    /**
     * Tests convergence against the default threshold.
     *
     * @return {@code true} if no message entry changed by more than the threshold
     */
    protected final boolean hasConverged() {
        return this.hasConverged(this.threshold);
    }

    /**
     * Tests whether every entry of every snapshotted message differs from the current message
     * for the same (from, to) pair by at most {@code threshold}.
     *
     * <p>Pairs with no snapshotted message are skipped.
     *
     * @param threshold maximum allowed absolute difference per table entry
     * @return {@code true} if all compared entries are within {@code threshold}
     */
    protected final boolean hasConverged(double threshold) {
        MessageArray.Iterator msgIt = this.oldMessages.iterator();
        while (msgIt.hasNext()) {
            Factor oldMsg = (Factor)msgIt.next();
            Object from = msgIt.from();
            Object to = msgIt.to();
            if (oldMsg == null) {
                continue;
            }
            Factor newMsg = this.messages.get(from, to);
            assert (newMsg != null) : "Message went from nonnull to null " + from + " --> " + to;
            if (newMsg == null) {
                // With assertions disabled, report "not converged" instead of an NPE below.
                return false;
            }
            AssignmentIterator it = oldMsg.assignmentIterator();
            while (it.hasNext()) {
                Assignment assn = (Assignment)it.next();
                double diff = Math.abs(oldMsg.value(assn) - newMsg.value(assn));
                if (diff > threshold) {
                    return false;
                }
            }
        }
        return true;
    }

    /**
     * Initializes {@code oldMessages}: either a snapshot of cached messages (when caching is
     * enabled and a cache entry exists), or a fresh {@link TableFactor} over each variable for
     * every variable/factor edge in both directions.
     */
    private void initOldMessages(FactorGraph fg) {
        if (this.useCaching && fg.getInferenceCache(this.getClass()) != null) {
            logger.info("AsyncLoopyBP: Reusing previous marginals");
            this.retrieveCachedMessages(fg);
            this.copyOldMessages();
            return;
        }
        this.oldMessages = new MessageArray(fg);
        Iterator it = fg.factorsIterator();
        while (it.hasNext()) {
            Factor factor = (Factor)it.next();
            VarSet varset = factor.varSet();
            for (Object varObj : varset) {
                Variable var = (Variable)varObj;
                this.oldMessages.put(var, factor, (Factor)new TableFactor(var));
                this.oldMessages.put(factor, var, (Factor)new TableFactor(var));
            }
        }
    }

    /**
     * Prepares per-graph state: belief slots, the current message array (cached or fresh),
     * the old-message snapshot, and the message strategy's view of both arrays.
     *
     * @param mdl graph to run inference on
     */
    protected void initForGraph(FactorGraph mdl) {
        this.mdlCurrent = mdl;
        int numV = mdl.numVariables();
        this.bel = new Factor[numV];
        Object cache = mdl.getInferenceCache(this.getClass());
        this.messages = this.useCaching && cache != null ? (MessageArray)cache : new MessageArray(mdl);
        this.initOldMessages(mdl);
        this.messager.setMessageArray(this.messages, this.oldMessages);
    }

    /** Sends a variable-to-factor message and updates the message counters. */
    protected void sendMessage(FactorGraph mdl, Variable from, Factor to) {
        ++totalMessagesSent;
        ++this.myMessagesSent;
        this.messager.sendMessage(mdl, from, to);
    }

    /** Sends a factor-to-variable message and updates the message counters. */
    protected void sendMessage(FactorGraph mdl, Factor from, Variable to) {
        ++totalMessagesSent;
        ++this.myMessagesSent;
        this.messager.sendMessage(mdl, from, to);
    }

    /**
     * Drops the old-message snapshot and, if caching is enabled, stores the current messages
     * in the graph's inference cache.
     */
    protected void doneWithGraph(FactorGraph mdl) {
        this.clearOldMessages();
        if (this.useCaching) {
            this.cacheMessages(mdl);
        }
    }

    /** @return iterations used by the last run */
    public int iterationsUsed() {
        return this.iterUsed;
    }

    /**
     * Returns the belief for a single variable: the product of all messages sent to it,
     * normalized if {@link #normalizeBeliefs} is set. The result is cached until the next
     * {@link #initForGraph}.
     *
     * @throws IllegalArgumentException if {@code var} is not in the current graph
     */
    @Override
    public Factor lookupMarginal(Variable var) {
        int idx = this.mdlCurrent.getIndex(var);
        // Valid slots are 0 .. bel.length - 1.
        if (idx < 0 || idx >= this.bel.length) {
            throw new IllegalArgumentException("Cannot find variable " + var + " in factor graph " + this.mdlCurrent);
        }
        if (this.bel[idx] == null) {
            // Integer.MIN_VALUE is used as "exclude no sender" (presumably never a real index).
            Factor marg = this.messager.msgProduct(null, idx, Integer.MIN_VALUE);
            if (this.normalizeBeliefs) {
                marg.normalize();
            }
            assert (marg.varSet().size() == 1) : "Invalid marginal for var " + var + ": " + marg;
            assert (marg.varSet().contains(var)) : "Invalid marginal for var " + var + ": " + marg;
            this.bel[idx] = marg;
        }
        return this.bel[idx];
    }

    /** Dumps the current messages via {@link MessageArray#dump()}. */
    @Override
    public void dump() {
        this.messages.dump();
    }

    /** Prints the total number of messages sent (all instances) to standard error. */
    @Override
    public void reportTime() {
        System.err.println("AbstractBeliefPropagation: Total messages sent = " + totalMessagesSent);
    }

    /** Dumps the current messages to {@code writer}. */
    public void dump(PrintWriter writer) {
        this.messages.dump(writer);
    }

    private void writeObject(ObjectOutputStream out) throws IOException {
        out.defaultWriteObject();
    }

    private void readObject(ObjectInputStream in) throws IOException, ClassNotFoundException {
        in.defaultReadObject();
    }

    /**
     * Returns the marginal over {@code c}. A single-variable set delegates to
     * {@link #lookupMarginal(Variable)}; otherwise {@code c} must be the scope of at least one
     * factor in the current graph.
     *
     * @throws UnsupportedOperationException if no factor in the graph has scope {@code c}
     */
    @Override
    public Factor lookupMarginal(VarSet c) {
        if (c.size() == 1) {
            return this.lookupMarginal(c.get(0));
        }
        List factors = this.mdlCurrent.allFactorsOf(c);
        if (factors.isEmpty()) {
            throw new UnsupportedOperationException("Cannot compute marginal of " + c + ": Must be either a single variable or a factor in the graph.");
        }
        return this.lookupMarginal(c, factors);
    }

    /**
     * Multiplies the given factors together with every existing variable-to-factor message
     * into them, then normalizes the result.
     */
    private Factor lookupMarginal(VarSet vs, List factors) {
        Factor marginal = Factors.multiplyAll(factors);
        // Raw List: iterate as Object and cast (an enhanced-for typed as Factor would not compile).
        for (Object factorObj : factors) {
            Factor factor = (Factor)factorObj;
            for (Object varObj : vs) {
                Variable var = (Variable)varObj;
                Factor msg = this.messages.get(var, factor);
                // A null message means inference never touched this edge.
                if (msg != null) {
                    marginal.multiplyBy(msg);
                }
            }
        }
        marginal.normalize();
        return marginal;
    }

    /**
     * Approximates the log joint of {@code assn} from the beliefs: the sum of log factor
     * marginals minus {@code (degree - 1)} times each variable's log marginal.
     */
    @Override
    public double lookupLogJoint(Assignment assn) {
        double accum = 0.0;
        Iterator it = this.mdlCurrent.variablesIterator();
        while (it.hasNext()) {
            Variable var = (Variable)it.next();
            int deg = this.mdlCurrent.getDegree(var);
            // Degree-1 variables contribute zero; skip computing their marginal.
            if (deg == 1) continue;
            Factor ptl = this.lookupMarginal(var);
            accum -= (double)(deg - 1) * ptl.logValue(assn);
        }
        it = this.mdlCurrent.varSetIterator();
        while (it.hasNext()) {
            VarSet varSet = (VarSet)it.next();
            Factor p12 = this.lookupMarginal(varSet);
            accum += p12.logValue(assn);
        }
        return accum;
    }

    /**
     * Base for message strategies: stores the current and old message arrays and provides the
     * product of incoming messages.
     */
    public static abstract class AbstractMessageStrategy
    implements MessageStrategy {
        protected MessageArray messages;
        protected MessageArray oldMessages;

        @Override
        public void setMessageArray(MessageArray msgs, MessageArray oldMsgs) {
            this.messages = msgs;
            this.oldMessages = oldMsgs;
        }

        /**
         * Multiplies into {@code product} every current message addressed to node {@code idx},
         * except the one sent from node {@code excludeMsgFrom}.
         *
         * @param product accumulator, modified in place; if null, a fresh table over the
         *     variable at {@code idx} is created (log-space if the message array is)
         * @param idx message-array index of the receiving node
         * @param excludeMsgFrom index of the sender to skip
         * @return the accumulated product
         */
        @Override
        public Factor msgProduct(Factor product, int idx, int excludeMsgFrom) {
            if (product == null) {
                product = this.createEmptyFactorForVar(idx);
            }
            MessageArray.ToMsgsIterator it = this.messages.toMessagesIterator(idx);
            while (it.hasNext()) {
                it.next();
                int j = it.currentFromIdx();
                Factor msg = it.currentMessage();
                if (j == excludeMsgFrom) continue;
                product.multiplyBy(msg);
            }
            return product;
        }

        /** Creates a table over the variable at {@code idx}, matching the array's log-space setting. */
        private Factor createEmptyFactorForVar(int idx) {
            Variable var = (Variable)this.messages.idx2obj(idx);
            return this.messages.isInLogSpace() ? new LogTableFactor(var) : new TableFactor(var);
        }
    }

    /** Max-product messages: factor-to-variable messages maximize out the other variables. */
    public static class MaxProductMessageStrategy
    extends AbstractMessageStrategy
    implements Serializable {
        private static final long serialVersionUID = 1L;
        private static final int CURRENT_SERIAL_VERSION = 1;

        /** Sends from * (all incoming messages except from {@code to}), max-marginalized onto {@code to}. */
        @Override
        public void sendMessage(FactorGraph mdl, Factor from, Variable to) {
            int fromIdx = this.messages.getIndex(from);
            int toIdx = this.messages.getIndex(to);
            Factor product = from.duplicate();
            this.msgProduct(product, fromIdx, toIdx);
            Factor msg = product.extractMax(to);
            msg.normalize();
            assert (msg.varSet().size() == 1);
            assert (msg.varSet().contains(to));
            this.messages.put(fromIdx, toIdx, msg);
        }

        /** Sends the normalized product of all messages into {@code from} except the one from {@code to}. */
        @Override
        public void sendMessage(FactorGraph mdl, Variable from, Factor to) {
            int fromIdx = this.messages.getIndex(from);
            int toIdx = this.messages.getIndex(to);
            Factor msg = this.msgProduct(null, fromIdx, toIdx);
            msg.normalize();
            assert (msg.varSet().size() == 1);
            assert (msg.varSet().contains(from));
            this.messages.put(fromIdx, toIdx, msg);
        }

        private void writeObject(ObjectOutputStream out) throws IOException {
            out.defaultWriteObject();
            out.writeInt(CURRENT_SERIAL_VERSION);
        }

        private void readObject(ObjectInputStream in) throws IOException, ClassNotFoundException {
            in.defaultReadObject();
            in.readInt();
        }
    }

    /** Computes individual BP messages over a shared {@link MessageArray}. */
    public static interface MessageStrategy {
        /** Supplies the current and old message arrays the strategy reads and writes. */
        public void setMessageArray(MessageArray var1, MessageArray var2);

        /** Computes and stores the message from factor {@code var2} to variable {@code var3}. */
        public void sendMessage(FactorGraph var1, Factor var2, Variable var3);

        /** Computes and stores the message from variable {@code var2} to factor {@code var3}. */
        public void sendMessage(FactorGraph var1, Variable var2, Factor var3);

        /** Multiplies incoming messages to node {@code var2}, skipping sender {@code var3}, into {@code var1}. */
        public Factor msgProduct(Factor var1, int var2, int var3);
    }

    /**
     * Sum-product messages with optional damping of factor-to-variable messages: when
     * {@code damping < 1}, the stored message is
     * {@code damping * new + (1 - damping) * normalized(old)}.
     */
    public static class SumProductMessageStrategy
    extends AbstractMessageStrategy
    implements Serializable {
        /** Weight on the freshly computed message; 1.0 disables damping. */
        private double damping = 1.0;
        private static final long serialVersionUID = 1L;
        private static final int CURRENT_SERIAL_VERSION = 2;

        /** Creates an undamped strategy. */
        public SumProductMessageStrategy() {
        }

        /** @param damping weight on new messages; values below 1.0 enable damping */
        public SumProductMessageStrategy(double damping) {
            this.damping = damping;
        }

        /** Sends from * (all incoming messages except from {@code to}), summed onto {@code to}, with damping. */
        @Override
        public void sendMessage(FactorGraph mdl, Factor from, Variable to) {
            int fromIdx = this.messages.getIndex(from);
            int toIdx = this.messages.getIndex(to);
            Factor product = from.duplicate();
            this.msgProduct(product, fromIdx, toIdx);
            Factor msg = product.marginalize(to);
            msg.normalize();
            // Guarded on FINEST but emitted at INFO so default handlers still print it.
            if (logger.isLoggable(Level.FINEST)) {
                logger.info("MSG " + from + " --> " + to);
                logger.info("FACTOR: " + from.dumpToString());
                logger.info("MSG: " + msg.dumpToString());
                logger.info("END MSG " + from + " --> " + to);
            }
            assert (msg.varSet().size() == 1);
            assert (msg.varSet().contains(to));
            this.makeDampedUpdate(fromIdx, toIdx, msg);
        }

        /** Sends the normalized product of all messages into {@code from} except the one from {@code to}. */
        @Override
        public void sendMessage(FactorGraph mdl, Variable from, Factor to) {
            int fromIdx = this.messages.getIndex(from);
            int toIdx = this.messages.getIndex(to);
            Factor msg = this.msgProduct(null, fromIdx, toIdx);
            msg.normalize();
            assert (msg.varSet().size() == 1);
            assert (msg.varSet().contains(from));
            this.messages.put(fromIdx, toIdx, msg);
        }

        /**
         * Stores {@code msg}, blended with the old message for the same edge when damping is
         * enabled and an old message exists. Assumes both messages are table factors.
         */
        private void makeDampedUpdate(int fromIdx, int toIdx, Factor msg) {
            Factor oldMsg;
            if (this.damping < 1.0 && (oldMsg = this.oldMessages.get(fromIdx, toIdx)) != null) {
                AbstractTableFactor oldTbl = (AbstractTableFactor)oldMsg.duplicate();
                oldTbl.normalize();
                oldTbl.timesEquals(1.0 - this.damping);
                AbstractTableFactor tbl = (AbstractTableFactor)msg;
                tbl.timesEquals(this.damping);
                tbl.plusEquals(oldTbl);
                msg = tbl;
            }
            this.messages.put(fromIdx, toIdx, msg);
        }

        private void writeObject(ObjectOutputStream out) throws IOException {
            out.defaultWriteObject();
            out.writeInt(CURRENT_SERIAL_VERSION);
            // Redundant with defaultWriteObject, but part of the established stream format.
            out.writeDouble(this.damping);
        }

        private void readObject(ObjectInputStream in) throws IOException, ClassNotFoundException {
            in.defaultReadObject();
            int version = in.readInt();
            if (version >= 2) {
                this.damping = in.readDouble();
            } else {
                // Pre-damping streams lack the field, and field initializers do not run during
                // deserialization, so it would be 0.0 (every update would keep the old message).
                this.damping = 1.0;
            }
        }
    }
}

