/*
 * Decompiled with CFR 0.152.
 */
package cc.mallet.fst.semi_supervised;

import cc.mallet.fst.CRF;
import cc.mallet.fst.Transducer;
import cc.mallet.fst.semi_supervised.constraints.GEConstraint;
import cc.mallet.types.FeatureVector;
import cc.mallet.types.FeatureVectorSequence;
import cc.mallet.types.LogNumber;
import java.util.ArrayList;

/**
 * Lattice used to compute the gradient contribution of Generalized Expectation (GE)
 * constraints for a {@link CRF} on a single instance.
 *
 * <p>For every position {@code ip} and transition {@code s1 -> s2}, the node
 * {@code lattice[ip][s1]} stores two log-space values, {@code alpha[s2]} and
 * {@code beta[s2]}. The constructor fills {@code alpha} in a forward pass and
 * {@code beta} in a backward pass. The backward pass adds the resulting
 * per-transition contributions into the supplied {@code CRF.Factors} gradient.
 */
public class GELattice {
    /** Number of lattice positions: the input sequence length plus one. */
    protected int latticeLength;
    protected Transducer transducer;
    protected int numStates;
    /** {@code lattice[ip][s1]} holds the alpha/beta values of transitions leaving s1 at position ip. */
    protected LatticeNode[][] lattice;
    /**
     * {@code dotCache[ip][s1][s2]} is the summed (two-state plus one-state) constraint
     * feature value of transition s1 -> s2 at position ip, stored as a signed LogNumber.
     * It is {@code null} when both parts are zero.
     */
    protected LogNumber[][][] dotCache;

    /**
     * Builds the lattice and accumulates the GE gradient for one sequence.
     *
     * @param fvs                 input feature vectors, one per position
     * @param gammas              log state marginals, indexed {@code [position][state]}
     * @param xis                 log transition marginals, indexed {@code [position][from][to]}
     * @param transducer          the model; it is cast to {@link CRF}
     * @param reverseTrans        {@code reverseTrans[s]} lists the states with a transition into s
     * @param reverseTransIndices {@code reverseTransIndices[s][i]} is the transition index of
     *                            {@code reverseTrans[s][i] -> s}, passed to {@code getWeightNames}
     * @param gradient            factors into which gradient contributions are added; must not be null
     * @param constraints         GE constraints (both one-state and two-state)
     * @param check               not consulted by this constructor; call {@link #check} explicitly
     */
    public GELattice(FeatureVectorSequence fvs, double[][] gammas, double[][][] xis, Transducer transducer, int[][] reverseTrans, int[][] reverseTransIndices, CRF.Factors gradient, ArrayList<GEConstraint> constraints, boolean check) {
        assert (gradient != null);
        this.latticeLength = fvs.size() + 1;
        this.transducer = transducer;
        this.numStates = transducer.numStates();
        this.lattice = new LatticeNode[this.latticeLength][this.numStates];
        for (int ip = 0; ip < this.latticeLength; ++ip) {
            for (int s = 0; s < this.numStates; ++s) {
                this.lattice[ip][s] = new LatticeNode();
            }
        }
        this.dotCache = new LogNumber[this.latticeLength][this.numStates][this.numStates];

        // One-state and two-state constraints are handled differently in the forward pass.
        ArrayList<GEConstraint> oneStateConstraints = new ArrayList<GEConstraint>();
        ArrayList<GEConstraint> twoStateConstraints = new ArrayList<GEConstraint>();
        for (GEConstraint constraint : constraints) {
            if (constraint.isOneStateConstraint()) {
                oneStateConstraints.add(constraint);
            } else {
                twoStateConstraints.add(constraint);
            }
        }
        CRF crf = (CRF) transducer;
        double dotEx = this.runForward(crf, oneStateConstraints, twoStateConstraints, gammas, xis, reverseTrans, fvs);
        this.runBackward(crf, gammas, xis, reverseTrans, reverseTransIndices, fvs, dotEx, gradient);
    }

    /**
     * Converts a plain value into a signed log-space number.
     * Negative values are stored as log(-v) with a negative sign; all others as log(v).
     */
    private static LogNumber signedLog(double v) {
        return v < 0.0 ? new LogNumber(Math.log(-v), false) : new LogNumber(Math.log(v), true);
    }

    /**
     * Forward pass: fills {@link #dotCache} and the {@code alpha} entry of every node.
     *
     * @param constraints1 one-state constraints
     * @param constraints2 two-state constraints
     * @return the sum over positions and transitions of the constraint feature values,
     *         each weighted by its marginal probability ({@code exp(gamma)} for one-state
     *         values, {@code exp(xi)} for two-state values)
     */
    private double runForward(CRF crf, ArrayList<GEConstraint> constraints1, ArrayList<GEConstraint> constraints2, double[][] gammas, double[][][] xis, int[][] reverseTrans, FeatureVectorSequence fvs) {
        double dotEx = 0.0;
        LogNumber[] oneStateValueCache = new LogNumber[this.numStates];
        LogNumber nuAlpha = new LogNumber(Double.NEGATIVE_INFINITY, true);
        LogNumber temp = new LogNumber(Double.NEGATIVE_INFINITY, true);
        for (int ip = 0; ip < this.latticeLength - 1; ++ip) {
            FeatureVector fv = fvs.get(ip);
            for (GEConstraint constraint : constraints1) {
                constraint.preProcess(fv);
            }
            for (GEConstraint constraint : constraints2) {
                constraint.preProcess(fv);
            }
            // Each one-state value is computed once per (ip, curr), using the first prev
            // that reaches curr (presumably one-state constraints ignore prev).
            boolean[] oneStateValComputed = new boolean[this.numStates];
            for (int prev = 0; prev < this.numStates; ++prev) {
                // nuAlpha = sum of the alphas of all transitions into prev at position ip - 1.
                nuAlpha.set(Double.NEGATIVE_INFINITY, true);
                if (ip != 0) {
                    for (int prevPrev : reverseTrans[prev]) {
                        nuAlpha.plusEquals(this.lattice[ip - 1][prevPrev].alpha[prev]);
                    }
                }
                assert (!Double.isNaN(nuAlpha.logVal));
                CRF.State prevState = (CRF.State) crf.getState(prev);
                LatticeNode node = this.lattice[ip][prev];
                double[] xi = xis[ip][prev];
                double gamma = gammas[ip][prev];
                for (int ci = 0; ci < prevState.numDestinations(); ++ci) {
                    int curr = prevState.getDestinationState(ci).getIndex();
                    double dot = 0.0;
                    for (GEConstraint constraint : constraints2) {
                        dot += constraint.getCompositeConstraintFeatureValue(fv, ip, prev, curr);
                    }
                    if (!oneStateValComputed[curr]) {
                        double osVal = 0.0;
                        for (GEConstraint constraint : constraints1) {
                            osVal += constraint.getCompositeConstraintFeatureValue(fv, ip, prev, curr);
                        }
                        // False for both zero and NaN, which are cached as "no value".
                        if (osVal < 0.0 || osVal > 0.0) {
                            dotEx += Math.exp(gammas[ip + 1][curr]) * osVal;
                            oneStateValueCache[curr] = signedLog(osVal);
                        } else {
                            oneStateValueCache[curr] = null;
                        }
                        oneStateValComputed[curr] = true;
                    }
                    // Combine the two-state and one-state parts. With no two-state part the
                    // cached one-state LogNumber is shared (it is never mutated afterwards).
                    LogNumber combined;
                    if (dot == 0.0) {
                        combined = oneStateValueCache[curr];
                    } else {
                        dotEx += Math.exp(xi[curr]) * dot;
                        combined = signedLog(dot);
                        if (oneStateValueCache[curr] != null) {
                            combined.plusEquals(oneStateValueCache[curr]);
                        }
                    }
                    this.dotCache[ip][prev][curr] = combined;

                    // Local term: p(prev -> curr at ip) * feature value of this transition.
                    if (combined != null) {
                        temp.set(xi[curr], true);
                        temp.timesEquals(combined);
                        node.alpha[curr].plusEquals(temp);
                    }
                    // Propagated term: incoming alpha mass scaled by xi / gamma.
                    if (gamma == Double.NEGATIVE_INFINITY) {
                        node.alpha[curr] = new LogNumber(Double.NEGATIVE_INFINITY, true);
                    } else {
                        temp.set(xi[curr] - gamma, true);
                        temp.timesEquals(nuAlpha);
                        node.alpha[curr].plusEquals(temp);
                    }
                    assert (!Double.isNaN(node.alpha[curr].logVal)) : "xi: " + xi[curr] + ", gamma: " + gamma + ", constraint feature: " + this.dotCache[ip][prev][curr] + ", nuApha: " + nuAlpha + " dot: " + dot;
                }
            }
        }
        return dotEx;
    }

    /**
     * Backward pass: fills the {@code beta} entry of every node and adds, for each
     * transition, {@code exp(alpha) + exp(beta) - p(transition) * dotEx} to the gradient
     * (presumably the covariance of the transition indicator with the total constraint
     * feature value).
     */
    private void runBackward(CRF crf, double[][] gammas, double[][][] xis, int[][] reverseTrans, int[][] reverseTransIndices, FeatureVectorSequence fvs, double dotEx, CRF.Factors gradient) {
        LogNumber nuBeta = new LogNumber(Double.NEGATIVE_INFINITY, true);
        LogNumber dot = new LogNumber(Double.NEGATIVE_INFINITY, true);
        LogNumber temp = new LogNumber(Double.NEGATIVE_INFINITY, true);
        LogNumber temp2 = new LogNumber(Double.NEGATIVE_INFINITY, true);
        for (int ip = this.latticeLength - 2; ip >= 0; --ip) {
            FeatureVector fv = fvs.get(ip);
            for (int curr = 0; curr < this.numStates; ++curr) {
                // nuBeta = sum of betas leaving curr at ip + 1.
                // dot = sum of p(curr -> next) * feature value over transitions leaving curr at ip + 1.
                nuBeta.set(Double.NEGATIVE_INFINITY, true);
                dot.set(Double.NEGATIVE_INFINITY, true);
                CRF.State currState = (CRF.State) crf.getState(curr);
                for (int ni = 0; ni < currState.numDestinations(); ++ni) {
                    int next = currState.getDestinationState(ni).getIndex();
                    nuBeta.plusEquals(this.lattice[ip + 1][curr].beta[next]);
                    assert (!Double.isNaN(nuBeta.logVal));
                    // At the last position dotCache is all null, so xis[ip + 1] is never read there.
                    LogNumber nextDot = this.dotCache[ip + 1][curr][next];
                    if (nextDot != null) {
                        temp.set(xis[ip + 1][curr][next], true);
                        temp.timesEquals(nextDot);
                        dot.plusEquals(temp);
                    }
                }
                double gamma = gammas[ip + 1][curr];
                int[] prevStates = reverseTrans[curr];
                for (int pi = 0; pi < prevStates.length; ++pi) {
                    int prev = prevStates[pi];
                    CRF.State prevState = (CRF.State) crf.getState(prev);
                    LatticeNode node = this.lattice[ip][prev];
                    double xi = xis[ip][prev][curr];
                    if (gamma == Double.NEGATIVE_INFINITY) {
                        node.beta[curr] = new LogNumber(Double.NEGATIVE_INFINITY, true);
                    } else {
                        temp.set(dot.logVal, dot.sign);
                        temp.plusEquals(nuBeta);
                        temp2.set(xi - gamma, true);
                        temp.timesEquals(temp2);
                        node.beta[curr].plusEquals(temp);
                    }
                    assert (!Double.isNaN(node.beta[curr].logVal)) : "xi: " + xi + ", gamma: " + gamma + ", dot: " + dot + ", nuBeta: " + nuBeta;
                    double transProb = Math.exp(xi);
                    double covFirstTerm = node.alpha[curr].exp() + node.beta[curr].exp();
                    double contribution = covFirstTerm - transProb * dotEx;
                    // The weights attached to transition prev -> curr.
                    String[] weightNames = prevState.getWeightNames(reverseTransIndices[curr][pi]);
                    for (String weightName : weightNames) {
                        int weightsIndex = crf.getWeightsIndex(weightName);
                        gradient.weights[weightsIndex].plusEqualsSparse(fv, contribution);
                        gradient.defaultWeights[weightsIndex] += contribution;
                    }
                }
            }
        }
    }

    /**
     * Debugging aid: asserts that, at every position, the alpha + beta values summed over
     * all transitions match the directly computed expected constraint feature value, and
     * that their average over positions matches as well. It only has an effect when Java
     * assertions are enabled.
     */
    public void check(ArrayList<GEConstraint> constraints, double[][] gammas, double[][][] xis, FeatureVectorSequence fvs) {
        int numPositions = this.latticeLength - 1;
        double ex1 = 0.0;
        for (int ip = 0; ip < numPositions; ++ip) {
            FeatureVector fv = fvs.get(ip);
            for (int si1 = 0; si1 < this.numStates; ++si1) {
                for (int si2 = 0; si2 < this.numStates; ++si2) {
                    double dot = 0.0;
                    for (GEConstraint constraint : constraints) {
                        dot += constraint.getCompositeConstraintFeatureValue(fv, ip, si1, si2);
                    }
                    ex1 += Math.exp(xis[ip][si1][si2]) * dot;
                }
            }
        }
        double ex2 = 0.0;
        for (int ip = 0; ip < numPositions; ++ip) {
            double ex3 = 0.0;
            for (int s1 = 0; s1 < this.numStates; ++s1) {
                LatticeNode node = this.lattice[ip][s1];
                for (int s2 = 0; s2 < this.numStates; ++s2) {
                    ex3 += node.alpha[s2].exp() + node.beta[s2].exp();
                }
            }
            // Two-sided tolerance: a one-sided check would accept any ex3 larger than ex1.
            assert (Math.abs(ex1 - ex3) < 1.0E-6) : String.valueOf(ex1) + " " + ex3;
            ex2 += ex3;
        }
        // Average outside the assert (no side effects in assertions); skip for empty
        // sequences, where 0 / 0 would yield NaN and fail spuriously.
        if (numPositions > 0) {
            double avg = ex2 / (double) numPositions;
            assert (Math.abs(ex1 - avg) < 1.0E-6) : String.valueOf(ex1) + " " + avg;
        }
    }

    /** Returns the alpha value of transition {@code s1 -> s2} at position {@code ip}. */
    public LogNumber getAlpha(int ip, int s1, int s2) {
        return this.lattice[ip][s1].alpha[s2];
    }

    /** Returns the beta value of transition {@code s1 -> s2} at position {@code ip}. */
    public LogNumber getBeta(int ip, int s1, int s2) {
        return this.lattice[ip][s1].beta[s2];
    }

    /**
     * Per-(position, source state) storage. {@code alpha[s2]} and {@code beta[s2]} hold
     * log-space values for the transition to s2. Every entry starts at log(0) = -infinity.
     */
    protected class LatticeNode {
        protected LogNumber[] alpha;
        protected LogNumber[] beta;

        public LatticeNode() {
            this.alpha = new LogNumber[GELattice.this.numStates];
            this.beta = new LogNumber[GELattice.this.numStates];
            // Each entry needs its own instance because entries are mutated in place.
            for (int si = 0; si < GELattice.this.numStates; ++si) {
                this.alpha[si] = new LogNumber(Double.NEGATIVE_INFINITY, true);
                this.beta[si] = new LogNumber(Double.NEGATIVE_INFINITY, true);
            }
        }
    }
}

