/*
 * Decompiled with CFR 0.152.
 */
package cc.mallet.fst.semi_supervised.pr;

import cc.mallet.fst.CRF;
import cc.mallet.fst.SumLattice;
import cc.mallet.fst.Transducer;
import cc.mallet.fst.semi_supervised.pr.CachedDotTransitionIterator;
import cc.mallet.types.LabelAlphabet;
import cc.mallet.types.LabelVector;
import cc.mallet.types.MatrixOps;
import cc.mallet.types.Sequence;
import cc.mallet.util.MalletLogger;
import java.util.Arrays;
import java.util.logging.Level;
import java.util.logging.Logger;

/**
 * A forward-backward {@link SumLattice} whose transition weights are obtained through a
 * precomputed {@code cachedDots} array, consumed via {@link CachedDotTransitionIterator}.
 *
 * <p>All weights are in log space. After the public constructor returns, the lattice holds the
 * log total path weight ({@link #getTotalWeight()}), per-position log state marginals
 * ({@link #getGammas()}), optionally the per-transition log marginals ({@link #getXis()}), and,
 * when an output alphabet was supplied, a per-position distribution over outputs
 * ({@link #getLabelingAtPosition(int)}).
 *
 * <p>Transitions are enumerated with a {@code null} output sequence, so the lattice sums over
 * all paths; the {@code output} sequence given to the constructor is only stored and logged.
 */
public class SumLatticeDefaultCachedDot implements SumLattice {
    private static final Logger logger = MalletLogger.getLogger(SumLatticeDefaultCachedDot.class.getName());

    /**
     * Not read by this class: the public constructor's {@code saveXis} parameter shadows it.
     * Retained because it is {@code protected} and may be referenced elsewhere.
     */
    protected static boolean saveXis = false;

    /** The transducer whose states index the lattice columns. */
    Transducer t;
    /** Log of the summed weight of all complete paths; negative infinity if none exists. */
    double totalWeight;
    /** The input sequence the lattice was built over. */
    Sequence input;
    /** The optional output sequence passed to the constructor (stored and logged only). */
    Sequence output;
    /** Lattice nodes indexed {@code [inputPosition][stateIndex]}; allocated lazily, null if never reached. */
    LatticeNode[][] nodes;
    /** Number of lattice positions, i.e. {@code input.size() + 1}. */
    int latticeLength;
    /** Log state marginals indexed {@code [inputPosition][stateIndex]}; negative infinity if unreachable. */
    double[][] gammas;
    /** Log transition marginals indexed {@code [inputPosition][source][dest]}; null unless requested. */
    double[][][] xis;
    /** Per-position output distributions; null unless an output alphabet was supplied. */
    LabelVector[] labelings;

    /** For subclasses that build the lattice themselves; leaves every field at its default. */
    protected SumLatticeDefaultCachedDot() {
    }

    /**
     * Returns the node at {@code (ip, stateIndex)}, allocating it (alpha and beta both negative
     * infinity) if it does not exist yet.
     */
    protected LatticeNode getLatticeNode(int ip, int stateIndex) {
        if (this.nodes[ip][stateIndex] == null) {
            this.nodes[ip][stateIndex] = new LatticeNode(ip, this.t.getState(stateIndex));
        }
        return this.nodes[ip][stateIndex];
    }

    /**
     * Runs forward-backward over {@code input} and optionally reports expectations to
     * {@code incrementor}.
     *
     * <p>If no complete path has finite weight ({@link #getTotalWeight()} is negative infinity),
     * construction stops right after the forward pass: every gamma stays negative infinity,
     * {@code incrementor} is never called and no labelings are built.
     *
     * @param trans the transducer; its states are cast to {@link CRF.State}
     * @param input the input sequence; the lattice has {@code input.size() + 1} positions
     * @param output optional output sequence of the same length as {@code input}; it is only
     *     stored and logged and does not constrain the lattice
     * @param cachedDots values indexed {@code [inputPosition][sourceStateIndex]} that are handed
     *     to {@link CachedDotTransitionIterator} for the transitions leaving that state at that
     *     position (the meaning of the innermost dimension is defined by that iterator)
     * @param incrementor may be {@code null}; otherwise receives final-state, then transition,
     *     then initial-state expectations
     * @param saveXis whether to keep the transition marginals exposed by {@link #getXis()}
     * @param outputAlphabet may be {@code null}; otherwise per-position output distributions are
     *     accumulated and exposed through {@link #getLabelingAtPosition(int)}
     */
    public SumLatticeDefaultCachedDot(Transducer trans, Sequence input, Sequence output, double[][][] cachedDots, Transducer.Incrementor incrementor, boolean saveXis, LabelAlphabet outputAlphabet) {
        assert (output == null || input.size() == output.size());
        this.t = trans;
        this.input = input;
        this.output = output;
        this.latticeLength = input.size() + 1;
        int numStates = this.t.numStates();
        this.nodes = new LatticeNode[this.latticeLength][numStates];
        this.gammas = new double[this.latticeLength][numStates];
        if (saveXis) {
            this.xis = new double[this.latticeLength][numStates][numStates];
        }
        double[][] outputCounts = outputAlphabet == null ? null : new double[this.latticeLength][outputAlphabet.size()];
        for (double[] row : this.gammas) {
            Arrays.fill(row, Double.NEGATIVE_INFINITY);
        }
        // From here on, "xis requested" is expressed as this.xis != null so the helpers do not
        // need the parameter (which shadows the static field of the same name).
        if (this.xis != null) {
            for (double[][] plane : this.xis) {
                for (double[] row : plane) {
                    Arrays.fill(row, Double.NEGATIVE_INFINITY);
                }
            }
        }

        this.forwardPass(cachedDots, numStates);
        this.totalWeight = this.sumFinalWeights(numStates);
        logger.fine("totalWeight=" + this.totalWeight);
        if (this.totalWeight == Double.NEGATIVE_INFINITY) {
            // No complete path: marginals are undefined, so stop before the backward pass.
            return;
        }
        this.initializeFinalPosition(incrementor, numStates);
        this.backwardPass(cachedDots, incrementor, outputAlphabet, outputCounts, numStates);
        if (incrementor != null) {
            this.incrementInitialStates(incrementor, numStates);
        }
        if (outputAlphabet != null) {
            this.buildLabelings(outputAlphabet, outputCounts);
        }
        if (logger.isLoggable(Level.FINE)) {
            logger.fine("Lattice:");
            for (double[] row : this.gammas) {
                StringBuilder sb = new StringBuilder();
                for (double gamma : row) {
                    sb.append(' ').append(gamma);
                }
                logger.fine(sb.toString());
            }
        }
    }

    /**
     * Seeds position 0 with each state's finite initial weight, then propagates alpha forward
     * along every transition out of each reachable node. Ends with a FINE-level dump of alphas.
     */
    private void forwardPass(double[][][] cachedDots, int numStates) {
        logger.fine("Starting Foward pass");
        boolean atLeastOneInitialState = false;
        for (int i = 0; i < numStates; i++) {
            double initialWeight = this.t.getState(i).getInitialWeight();
            if (initialWeight > Double.NEGATIVE_INFINITY) {
                this.getLatticeNode(0, i).alpha = initialWeight;
                atLeastOneInitialState = true;
            }
        }
        if (!atLeastOneInitialState) {
            logger.warning("There are no starting states!");
        }
        for (int ip = 0; ip < this.latticeLength - 1; ip++) {
            for (int i = 0; i < numStates; i++) {
                LatticeNode source = this.nodes[ip][i];
                if (source == null || source.alpha == Double.NEGATIVE_INFINITY) {
                    continue; // unreachable: contributes nothing to later positions
                }
                Transducer.State state = this.t.getState(i);
                CachedDotTransitionIterator iter = new CachedDotTransitionIterator((CRF.State) state, this.input, ip, null, cachedDots[ip][i]);
                if (logger.isLoggable(Level.FINE)) {
                    logger.fine(" Starting Foward transition iteration from state " + state.getName() + " on input " + this.input.get(ip).toString() + " and output " + (this.output == null ? "(null)" : this.output.get(ip).toString()));
                }
                while (iter.hasNext()) {
                    Transducer.State destination = iter.nextState();
                    if (logger.isLoggable(Level.FINE)) {
                        logger.fine("Forward Lattice[inputPos=" + ip + "][source=" + state.getName() + "][dest=" + destination.getName() + "]");
                    }
                    LatticeNode destinationNode = this.getLatticeNode(ip + 1, destination.getIndex());
                    // Overwritten by every incoming transition; keeps the last one iterated.
                    destinationNode.output = iter.getOutput();
                    double transitionWeight = iter.getWeight();
                    if (logger.isLoggable(Level.FINE)) {
                        logger.fine("BEFORE update: destinationNode.alpha=" + destinationNode.alpha);
                    }
                    destinationNode.alpha = Transducer.sumLogProb(destinationNode.alpha, source.alpha + transitionWeight);
                    if (logger.isLoggable(Level.FINE)) {
                        logger.fine("transitionWeight=" + transitionWeight + " nodes[" + ip + "][" + i + "].alpha=" + source.alpha + " destinationNode.alpha=" + destinationNode.alpha);
                    }
                }
            }
        }
        if (logger.isLoggable(Level.FINE)) {
            logger.fine("Forward Lattice:");
            for (LatticeNode[] row : this.nodes) {
                StringBuilder sb = new StringBuilder();
                for (LatticeNode node : row) {
                    sb.append(' ').append(node == null ? "<null>" : String.valueOf(node.alpha));
                }
                logger.fine(sb.toString());
            }
        }
    }

    /** Log-sums alpha plus final weight over every node allocated at the last lattice position. */
    private double sumFinalWeights(int numStates) {
        int last = this.latticeLength - 1;
        double total = Double.NEGATIVE_INFINITY;
        for (int i = 0; i < numStates; i++) {
            LatticeNode node = this.nodes[last][i];
            if (node != null) {
                total = Transducer.sumLogProb(total, node.alpha + this.t.getState(i).getFinalWeight());
            }
        }
        return total;
    }

    /**
     * Sets beta of each last-position node to its state's final weight, fills the matching
     * gamma, and reports the final-state probability to {@code incrementor} when non-null.
     */
    private void initializeFinalPosition(Transducer.Incrementor incrementor, int numStates) {
        int last = this.latticeLength - 1;
        for (int i = 0; i < numStates; i++) {
            LatticeNode node = this.nodes[last][i];
            if (node == null) {
                continue;
            }
            Transducer.State state = this.t.getState(i);
            node.beta = state.getFinalWeight();
            this.gammas[last][i] = node.alpha + node.beta - this.totalWeight;
            if (incrementor != null) {
                double p = Math.exp(this.gammas[last][i]);
                assert (p >= 0.0 && p <= 1.000001) : "p=" + p + ", gamma=" + this.gammas[last][i];
                incrementor.incrementFinalState(state, p);
            }
        }
    }

    /**
     * Propagates beta backward from position {@code latticeLength - 2} down to 0, computing each
     * transition's log marginal (xi) and each reachable node's gamma. Transition probabilities
     * are reported to {@code incrementor} and/or added to {@code outputCounts[ip]} at the index
     * of the transition's output, when those are non-null.
     */
    private void backwardPass(double[][][] cachedDots, Transducer.Incrementor incrementor, LabelAlphabet outputAlphabet, double[][] outputCounts, int numStates) {
        for (int ip = this.latticeLength - 2; ip >= 0; ip--) {
            for (int i = 0; i < numStates; i++) {
                LatticeNode source = this.nodes[ip][i];
                if (source == null || source.alpha == Double.NEGATIVE_INFINITY) {
                    continue; // unreachable: its gamma stays negative infinity
                }
                Transducer.State state = this.t.getState(i);
                CachedDotTransitionIterator iter = new CachedDotTransitionIterator((CRF.State) state, this.input, ip, null, cachedDots[ip][i]);
                while (iter.hasNext()) {
                    Transducer.State destination = iter.nextState();
                    if (logger.isLoggable(Level.FINE)) {
                        logger.fine("Backward Lattice[inputPos=" + ip + "][source=" + state.getName() + "][dest=" + destination.getName() + "]");
                    }
                    int j = destination.getIndex();
                    LatticeNode destinationNode = this.nodes[ip + 1][j];
                    if (destinationNode == null) {
                        continue; // destination never reached in the forward pass
                    }
                    double transitionWeight = iter.getWeight();
                    assert (!Double.isNaN(transitionWeight));
                    double oldBeta = source.beta;
                    assert (!Double.isNaN(oldBeta));
                    source.beta = Transducer.sumLogProb(oldBeta, destinationNode.beta + transitionWeight);
                    assert (!Double.isNaN(source.beta)) : "dest.beta=" + destinationNode.beta + " trans=" + transitionWeight + " sum=" + (destinationNode.beta + transitionWeight) + " oldBeta=" + oldBeta;
                    double xi = source.alpha + transitionWeight + destinationNode.beta - this.totalWeight;
                    if (this.xis != null) {
                        this.xis[ip][i][j] = xi;
                    }
                    assert (!Double.isNaN(source.alpha));
                    assert (!Double.isNaN(destinationNode.beta));
                    assert (!Double.isNaN(this.totalWeight));
                    if (incrementor == null && outputAlphabet == null) {
                        continue;
                    }
                    double p = Math.exp(xi);
                    assert (p >= 0.0 && p <= 1.000001) : "p=" + p + ", xis[" + ip + "][" + i + "][" + j + "]=" + xi;
                    if (incrementor != null) {
                        incrementor.incrementTransition(iter, p);
                    }
                    if (outputAlphabet != null) {
                        int outputIndex = outputAlphabet.lookupIndex(iter.getOutput(), false);
                        assert (outputIndex >= 0);
                        outputCounts[ip][outputIndex] += p;
                    }
                }
                this.gammas[ip][i] = source.alpha + source.beta - this.totalWeight;
            }
        }
    }

    /** Reports {@code exp(gamma)} at position 0 for every state, reachable or not. */
    private void incrementInitialStates(Transducer.Incrementor incrementor, int numStates) {
        for (int i = 0; i < numStates; i++) {
            double p = Math.exp(this.gammas[0][i]);
            assert (p >= 0.0 && p <= 1.000001) : "p=" + p;
            incrementor.incrementInitialState(this.t.getState(i), p);
        }
    }

    /**
     * Wraps each position's accumulated output probabilities in a {@link LabelVector}. The slot
     * for the last lattice position is left null because no transition leaves it.
     */
    private void buildLabelings(LabelAlphabet outputAlphabet, double[][] outputCounts) {
        this.labelings = new LabelVector[this.latticeLength];
        for (int ip = this.latticeLength - 2; ip >= 0; ip--) {
            assert (Math.abs(1.0 - MatrixOps.sum(outputCounts[ip])) < 1.0E-6);
            this.labelings[ip] = new LabelVector(outputAlphabet, outputCounts[ip]);
        }
    }

    /** Returns the backing xi array, or null if xis were not requested at construction. */
    @Override
    public double[][][] getXis() {
        return this.xis;
    }

    /** Returns the backing log state-marginal array, indexed {@code [inputPosition][stateIndex]}. */
    @Override
    public double[][] getGammas() {
        return this.gammas;
    }

    /** Returns the log of the summed weight of all complete paths. */
    @Override
    public double getTotalWeight() {
        assert (!Double.isNaN(this.totalWeight));
        return this.totalWeight;
    }

    @Override
    public double getGammaWeight(int inputPosition, Transducer.State s2) {
        return this.gammas[inputPosition][s2.getIndex()];
    }

    /** Log marginal of being in state {@code stateIndex} at {@code inputPosition}. */
    public double getGammaWeight(int inputPosition, int stateIndex) {
        return this.gammas[inputPosition][stateIndex];
    }

    @Override
    public double getGammaProbability(int inputPosition, Transducer.State s2) {
        return Math.exp(this.gammas[inputPosition][s2.getIndex()]);
    }

    /** Marginal probability of being in state {@code stateIndex} at {@code inputPosition}. */
    public double getGammaProbability(int inputPosition, int stateIndex) {
        return Math.exp(this.gammas[inputPosition][stateIndex]);
    }

    /**
     * Probability of taking the {@code s1 -> s2} transition out of position {@code ip}.
     *
     * @throws IllegalStateException if xis were not saved at construction
     */
    @Override
    public double getXiProbability(int ip, Transducer.State s1, Transducer.State s2) {
        this.requireSavedXis();
        return Math.exp(this.xis[ip][s1.getIndex()][s2.getIndex()]);
    }

    /**
     * Log marginal of taking the {@code s1 -> s2} transition out of position {@code ip}.
     *
     * @throws IllegalStateException if xis were not saved at construction
     */
    @Override
    public double getXiWeight(int ip, Transducer.State s1, Transducer.State s2) {
        this.requireSavedXis();
        return this.xis[ip][s1.getIndex()][s2.getIndex()];
    }

    /** Throws if the xi array was not allocated. */
    private void requireSavedXis() {
        if (this.xis == null) {
            throw new IllegalStateException("xis were not saved.");
        }
    }

    /** Number of lattice positions ({@code input.size() + 1}). */
    @Override
    public int length() {
        return this.latticeLength;
    }

    @Override
    public Sequence getInput() {
        return this.input;
    }

    /**
     * Returns the forward (alpha) log weight at {@code (ip, s2)}. Goes through
     * {@link #getLatticeNode(int, int)}, so a never-reached cell gets a node allocated and the
     * result is negative infinity.
     */
    @Override
    public double getAlpha(int ip, Transducer.State s2) {
        return this.getLatticeNode(ip, s2.getIndex()).alpha;
    }

    /**
     * Returns the backward (beta) log weight at {@code (ip, s2)}. Goes through
     * {@link #getLatticeNode(int, int)}, so a never-reached cell gets a node allocated and the
     * result is negative infinity.
     */
    @Override
    public double getBeta(int ip, Transducer.State s2) {
        return this.getLatticeNode(ip, s2.getIndex()).beta;
    }

    /**
     * Returns the output distribution at {@code outputPosition}, or null when no output alphabet
     * was supplied, when the total weight was negative infinity, or at the last lattice position.
     */
    @Override
    public LabelVector getLabelingAtPosition(int outputPosition) {
        return this.labelings == null ? null : this.labelings[outputPosition];
    }

    @Override
    public Transducer getTransducer() {
        return this.t;
    }

    /** One (position, state) cell of the lattice, holding its forward and backward log weights. */
    protected class LatticeNode {
        int inputPosition;
        Transducer.State state;
        /** Output of the last incoming transition seen during the forward pass. */
        Object output;
        double alpha = Double.NEGATIVE_INFINITY;
        double beta = Double.NEGATIVE_INFINITY;

        LatticeNode(int inputPosition, Transducer.State state) {
            this.inputPosition = inputPosition;
            this.state = state;
        }
    }
}

