/*
 * Decompiled with CFR 0.152.
 */
package cc.mallet.fst;

import cc.mallet.fst.MaxLattice;
import cc.mallet.fst.MaxLatticeFactory;
import cc.mallet.fst.Transducer;
import cc.mallet.types.ArraySequence;
import cc.mallet.types.Sequence;
import cc.mallet.types.SequencePairAlignment;
import cc.mallet.util.MalletLogger;
import cc.mallet.util.search.AStar;
import cc.mallet.util.search.AStarState;
import cc.mallet.util.search.SearchNode;
import cc.mallet.util.search.SearchState;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.PrintWriter;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.List;
import java.util.logging.Level;
import java.util.logging.Logger;

/**
 * Default {@link MaxLattice}: runs a Viterbi forward pass over the input when it is
 * constructed, and answers best-path / n-best-path queries by running an A* search
 * backwards through the resulting lattice.
 *
 * <p>The lattice has {@code input.size() + 1} positions. Position 0 holds the states with
 * a finite initial weight. The node for a state at position {@code ip + 1} is reached by a
 * transition taken at input position {@code ip}; it stores the best accumulated weight
 * ({@code delta}), the predecessor that produced it, and the output of that best incoming
 * transition.
 *
 * <p>Transition weights needed by the backward search are recomputed on demand. At most
 * {@code maxCaches} per-position weight tables are kept, recycled in least-recently-used
 * order.
 */
public class MaxLatticeDefault
implements MaxLattice {
    private static final Logger logger = MalletLogger.getLogger(MaxLatticeDefault.class.getName());
    /** Upper bound on live weight tables used by the constructors that take no explicit bound. */
    private static final int DEFAULT_MAX_CACHES = 100000;

    private final Transducer t;
    private final Sequence<Object> input;
    /** May be null when no output sequence was supplied. */
    private final Sequence<Object> providedOutput;
    /** Number of lattice positions: {@code input.size() + 1}. */
    private final int latticeLength;
    /** {@code lattice[position][stateIndex]}; entries are created lazily. */
    private final ViterbiNode[][] lattice;
    /** Head (most recently used) and tail (least recently used) of the weight-table LRU list. */
    private WeightCache first;
    private WeightCache last;
    /** {@code caches[position]} is the weight table for transitions leaving {@code position}, or null. */
    private final WeightCache[] caches;
    private int numCaches;
    private final int maxCaches;
    private List<SequencePairAlignment<Object, ViterbiNode>> viterbiNodeAlignmentCache = null;
    private List<SequencePairAlignment<Object, Transducer.State>> stateAlignmentCache = null;
    private List<SequencePairAlignment<Object, Object>> outputAlignmentCache = null;

    @Override
    public Transducer getTransducer() {
        return this.t;
    }

    public Sequence getInput() {
        return this.input;
    }

    public Sequence getProvidedOutput() {
        return this.providedOutput;
    }

    /**
     * Returns the transition-weight table for lattice position {@code position} and marks it
     * as most recently used. If no table is cached for that position, one is filled in. A new
     * table is allocated while fewer than {@code maxCaches} exist; otherwise the least recently
     * used table is evicted and reused.
     */
    private WeightCache getCache(int position) {
        WeightCache cache = this.caches[position];
        if (cache == null) {
            if (this.numCaches < this.maxCaches) {
                cache = new WeightCache(position);
                if (this.numCaches++ == 0) {
                    this.first = this.last = cache;
                }
            } else {
                // Recycle the least recently used table for this position.
                cache = this.last;
                this.caches[cache.position] = null;
                cache.init(position);
            }
            this.fillCache(cache, position);
            this.caches[position] = cache;
        }
        this.moveToFront(cache);
        return cache;
    }

    /**
     * Records the weight of every transition leaving a reachable state at {@code position}.
     * Entries for unreachable sources stay at negative infinity.
     */
    private void fillCache(WeightCache cache, int position) {
        int numStates = this.t.numStates();
        for (int i = 0; i < numStates; i++) {
            ViterbiNode node = this.lattice[position][i];
            if (node == null || node.delta == Double.NEGATIVE_INFINITY) {
                continue;
            }
            Transducer.TransitionIterator iter =
                this.t.getState(i).transitionIterator(this.input, position, this.providedOutput, position);
            while (iter.hasNext()) {
                Transducer.State d = iter.next();
                cache.weight[i][d.getIndex()] = iter.getWeight();
            }
        }
    }

    /**
     * Moves {@code cache} to the head of the LRU list. Both the {@code next} and {@code prev}
     * links of its former neighbours are repaired, so the list stays consistent in both
     * directions.
     */
    private void moveToFront(WeightCache cache) {
        if (cache == this.first) {
            return;
        }
        // Unlink from the current position.
        if (cache == this.last) {
            this.last = cache.prev;
        }
        if (cache.prev != null) {
            cache.prev.next = cache.next;
        }
        if (cache.next != null) {
            cache.next.prev = cache.prev;
        }
        // Relink at the head.
        cache.prev = null;
        cache.next = this.first;
        this.first.prev = cache;
        this.first = cache;
    }

    /** Returns the lattice node for {@code stateIndex} at position {@code ip}, creating it if absent. */
    protected ViterbiNode getViterbiNode(int ip, int stateIndex) {
        if (this.lattice[ip][stateIndex] == null) {
            this.lattice[ip][stateIndex] = new ViterbiNode(ip, this.t.getState(stateIndex));
        }
        return this.lattice[ip][stateIndex];
    }

    public MaxLatticeDefault(Transducer t, Sequence inputSequence) {
        this(t, inputSequence, null, DEFAULT_MAX_CACHES);
    }

    public MaxLatticeDefault(Transducer t, Sequence inputSequence, Sequence outputSequence) {
        this(t, inputSequence, outputSequence, DEFAULT_MAX_CACHES);
    }

    /**
     * Builds the lattice and runs the Viterbi forward pass.
     *
     * @param t              the transducer whose states and transitions are searched
     * @param inputSequence  the input; must not be null
     * @param outputSequence optional output sequence passed to the transition iterators (may be null)
     * @param maxCaches      maximum number of per-position weight tables kept for the backward
     *                       search; values below 1 are treated as 1
     */
    public MaxLatticeDefault(Transducer t, Sequence inputSequence, Sequence outputSequence, int maxCaches) {
        this.t = t;
        this.maxCaches = Math.max(1, maxCaches);
        assert (inputSequence != null);
        if (logger.isLoggable(Level.FINE)) {
            logSequences(inputSequence, outputSequence);
        }
        this.input = inputSequence;
        this.providedOutput = outputSequence;
        this.latticeLength = this.input.size() + 1;
        int numStates = t.numStates();
        this.lattice = new ViterbiNode[this.latticeLength][numStates];
        this.caches = new WeightCache[this.latticeLength - 1];
        logger.fine("Starting Viterbi");
        this.seedInitialStates(numStates);
        this.forwardPass(numStates);
    }

    /** Logs the input and (optional) output sequences at FINE level. */
    private static void logSequences(Sequence inputSequence, Sequence outputSequence) {
        logger.fine("Starting ViterbiLattice");
        logger.fine("Input: ");
        for (int ip = 0; ip < inputSequence.size(); ip++) {
            logger.fine(" " + inputSequence.get(ip));
        }
        logger.fine("\nOutput: ");
        if (outputSequence == null) {
            logger.fine("null");
        } else {
            for (int op = 0; op < outputSequence.size(); op++) {
                logger.fine(" " + outputSequence.get(op));
            }
        }
        logger.fine("\n");
    }

    /** Creates a position-0 node, with delta = initial weight, for every state that can start. */
    private void seedInitialStates(int numStates) {
        boolean anyInitialState = false;
        for (int i = 0; i < numStates; i++) {
            double initialWeight = this.t.getState(i).getInitialWeight();
            if (initialWeight > Double.NEGATIVE_INFINITY) {
                this.getViterbiNode(0, i).delta = initialWeight;
                anyInitialState = true;
            }
        }
        if (!anyInitialState) {
            logger.warning("Viterbi: No initial states!");
        }
    }

    /**
     * Viterbi recursion: for each position, extends every reachable node along its outgoing
     * transitions. Each destination keeps the best accumulated weight, the predecessor that
     * produced it and that transition's output. Final-state weights are added on the last step.
     */
    private void forwardPass(int numStates) {
        for (int ip = 0; ip < this.latticeLength - 1; ip++) {
            for (int i = 0; i < numStates; i++) {
                ViterbiNode source = this.lattice[ip][i];
                if (source == null || source.delta == Double.NEGATIVE_INFINITY) {
                    continue;
                }
                Transducer.State s = this.t.getState(i);
                Transducer.TransitionIterator iter = s.transitionIterator(this.input, ip, this.providedOutput, ip);
                if (logger.isLoggable(Level.FINE)) {
                    logger.fine(" Starting Viterbi transition iteration from state " + s.getName() + " on input " + this.input.get(ip));
                }
                while (iter.hasNext()) {
                    Transducer.State destination = iter.next();
                    if (logger.isLoggable(Level.FINE)) {
                        logger.fine("Viterbi[inputPos=" + ip + "][source=" + s.getName() + "][dest=" + destination.getName() + "]");
                    }
                    ViterbiNode destinationNode = this.getViterbiNode(ip + 1, destination.getIndex());
                    double weight = source.delta + iter.getWeight();
                    if (ip == this.latticeLength - 2) {
                        weight += destination.getFinalWeight();
                    }
                    if (weight > destinationNode.delta) {
                        if (logger.isLoggable(Level.FINE)) {
                            logger.fine("Viterbi[inputPos=" + ip + "][source][dest=" + destination.getName() + "] weight increased to " + weight + " by source=" + s.getName());
                        }
                        destinationNode.delta = weight;
                        destinationNode.maxWeightPredecessor = source;
                        // Record the output of the transition that achieved the best delta, so it
                        // stays consistent with maxWeightPredecessor.
                        destinationNode.output = iter.getOutput();
                    }
                }
            }
        }
    }

    /** Returns the Viterbi weight of {@code stateIndex} at lattice position {@code ip}. */
    @Override
    public double getDelta(int ip, int stateIndex) {
        return this.getViterbiNode(ip, stateIndex).delta;
    }

    /**
     * Returns up to {@code n} best lattice paths, in the order produced by the backward A*
     * search. Each path has one node per lattice position ({@code input.size() + 1} nodes).
     * Fewer than {@code n} paths are returned when the search runs out of paths. A previously
     * computed, longer result list may be returned as is.
     */
    public List<SequencePairAlignment<Object, ViterbiNode>> bestViterbiNodeSequences(int n) {
        if (this.viterbiNodeAlignmentCache != null && this.viterbiNodeAlignmentCache.size() >= n) {
            return this.viterbiNodeAlignmentCache;
        }
        // Reachable nodes at the last position are the starting points of the backward search.
        List<ViterbiNode> reachableFinal = new ArrayList<ViterbiNode>();
        for (ViterbiNode node : this.lattice[this.latticeLength - 1]) {
            if (node != null && node.delta > Double.NEGATIVE_INFINITY) {
                reachableFinal.add(node);
            }
        }
        AStarState[] finalNodes = reachableFinal.toArray(new ViterbiNode[reachableFinal.size()]);
        AStar search = new AStar(finalNodes, this.latticeLength * this.t.numStates());
        ArrayList<SequencePairAlignment<Object, ViterbiNode>> outputs = new ArrayList<SequencePairAlignment<Object, ViterbiNode>>(n);
        while (outputs.size() < n && search.hasNext()) {
            SearchNode ans = search.next();
            double weight = -ans.getCost();
            // The search ends at position 0; following parents walks forward to the last position.
            ViterbiNode[] seq = new ViterbiNode[this.latticeLength];
            for (int j = 0; j < this.latticeLength; j++) {
                ViterbiNode v = (ViterbiNode) ans.getState();
                assert (v.inputPosition == j);
                seq[j] = v;
                ans = ans.getParent();
            }
            outputs.add(new SequencePairAlignment<Object, ViterbiNode>(this.input, new ArraySequence<ViterbiNode>(seq), weight));
        }
        this.viterbiNodeAlignmentCache = outputs;
        return outputs;
    }

    /**
     * Returns up to {@code n} best state alignments (one state per lattice position). Fewer are
     * returned if fewer paths exist.
     */
    public List<SequencePairAlignment<Object, Transducer.State>> bestStateAlignments(int n) {
        if (this.stateAlignmentCache != null && this.stateAlignmentCache.size() >= n) {
            return this.stateAlignmentCache;
        }
        List<SequencePairAlignment<Object, ViterbiNode>> paths = this.bestViterbiNodeSequences(n);
        int count = Math.min(n, paths.size());
        ArrayList<SequencePairAlignment<Object, Transducer.State>> ret = new ArrayList<SequencePairAlignment<Object, Transducer.State>>(count);
        for (int i = 0; i < count; i++) {
            SequencePairAlignment<Object, ViterbiNode> path = paths.get(i);
            Sequence nodes = path.output();
            Transducer.State[] ss = new Transducer.State[this.latticeLength];
            for (int j = 0; j < this.latticeLength; j++) {
                ss[j] = ((ViterbiNode) nodes.get(j)).state;
            }
            ret.add(new SequencePairAlignment<Object, Transducer.State>(this.input, new ArraySequence<Transducer.State>(ss), path.getWeight()));
        }
        this.stateAlignmentCache = ret;
        return ret;
    }

    /** Returns the single best state alignment; throws if the lattice has no complete path. */
    public SequencePairAlignment<Object, Transducer.State> bestStateAlignment() {
        return this.bestStateAlignments(1).get(0);
    }

    /** Returns up to {@code n} best state sequences; fewer if fewer paths exist. */
    @Override
    public List<Sequence<Transducer.State>> bestStateSequences(int n) {
        List<SequencePairAlignment<Object, Transducer.State>> alignments = this.bestStateAlignments(n);
        int count = Math.min(n, alignments.size());
        ArrayList<Sequence<Transducer.State>> ret = new ArrayList<Sequence<Transducer.State>>(count);
        for (int i = 0; i < count; i++) {
            ret.add(alignments.get(i).output());
        }
        return ret;
    }

    /** Returns the best state sequence; throws if the lattice has no complete path. */
    @Override
    public Sequence<Transducer.State> bestStateSequence() {
        return this.bestStateAlignments(1).get(0).output();
    }

    /**
     * Returns up to {@code n} best output alignments. Output {@code j} is the output stored on
     * the path's node at lattice position {@code j + 1}. Fewer are returned if fewer paths exist.
     */
    public List<SequencePairAlignment<Object, Object>> bestOutputAlignments(int n) {
        if (this.outputAlignmentCache != null && this.outputAlignmentCache.size() >= n) {
            return this.outputAlignmentCache;
        }
        List<SequencePairAlignment<Object, ViterbiNode>> paths = this.bestViterbiNodeSequences(n);
        int count = Math.min(n, paths.size());
        ArrayList<SequencePairAlignment<Object, Object>> ret = new ArrayList<SequencePairAlignment<Object, Object>>(count);
        for (int i = 0; i < count; i++) {
            SequencePairAlignment<Object, ViterbiNode> path = paths.get(i);
            Sequence nodes = path.output();
            Object[] ss = new Object[this.latticeLength - 1];
            for (int j = 0; j < this.latticeLength - 1; j++) {
                ss[j] = ((ViterbiNode) nodes.get(j + 1)).output;
            }
            ret.add(new SequencePairAlignment<Object, Object>(this.input, new ArraySequence<Object>(ss), path.getWeight()));
        }
        this.outputAlignmentCache = ret;
        return ret;
    }

    /** Returns the single best output alignment; throws if the lattice has no complete path. */
    public SequencePairAlignment<Object, Object> bestOutputAlignment() {
        return this.bestOutputAlignments(1).get(0);
    }

    /** Returns up to {@code n} best output sequences; fewer if fewer paths exist. */
    @Override
    public List<Sequence<Object>> bestOutputSequences(int n) {
        List<SequencePairAlignment<Object, Object>> alignments = this.bestOutputAlignments(n);
        int count = Math.min(n, alignments.size());
        ArrayList<Sequence<Object>> ret = new ArrayList<Sequence<Object>>(count);
        for (int i = 0; i < count; i++) {
            ret.add(alignments.get(i).output());
        }
        return ret;
    }

    /** Returns the best output sequence; throws if the lattice has no complete path. */
    @Override
    public Sequence<Object> bestOutputSequence() {
        return this.bestOutputAlignments(1).get(0).output();
    }

    /** Returns the weight of the best path; throws if the lattice has no complete path. */
    public double bestWeight() {
        return this.bestOutputAlignments(1).get(0).getWeight();
    }

    /**
     * Increments, with count 1.0, the initial state, final state and every transition along the
     * single best path.
     *
     * <p>A transition at input position {@code ip} is identified by its destination state and
     * output. Both are stored on the path's node at position {@code ip + 1}.
     *
     * @throws IllegalStateException if zero or several transitions match a step of the path
     */
    public void incrementTransducer(Transducer.Incrementor incrementor) {
        SequencePairAlignment<Object, ViterbiNode> viterbiNodeAlignment = this.bestViterbiNodeSequences(1).get(0);
        Sequence nodes = viterbiNodeAlignment.output();
        int sequenceLength = nodes.size();
        // One node per lattice position: one more than the number of input elements.
        assert (sequenceLength == viterbiNodeAlignment.input().size() + 1);
        incrementor.incrementInitialState(((ViterbiNode) nodes.get(0)).state, 1.0);
        incrementor.incrementFinalState(((ViterbiNode) nodes.get(sequenceLength - 1)).state, 1.0);
        for (int ip = 0; ip < sequenceLength - 1; ip++) {
            ViterbiNode from = (ViterbiNode) nodes.get(ip);
            ViterbiNode to = (ViterbiNode) nodes.get(ip + 1);
            Transducer.TransitionIterator iter = from.state.transitionIterator(this.input, ip, this.providedOutput, ip);
            int numIncrements = 0;
            while (iter.hasNext()) {
                if (iter.next().equals(to.state) && sameOutput(iter.getOutput(), to.output)) {
                    incrementor.incrementTransition(iter, 1.0);
                    ++numIncrements;
                }
            }
            if (numIncrements > 1) {
                throw new IllegalStateException("More than one satisfying transition found.");
            }
            if (numIncrements == 0) {
                throw new IllegalStateException("No satisfying transition found.");
            }
        }
    }

    /** Null-safe equality used to match a transition's output against the one stored on a node. */
    private static boolean sameOutput(Object a, Object b) {
        return a == null ? b == null : a.equals(b);
    }

    /**
     * Returns the fraction of positions where the best output's {@code toString()} equals the
     * reference's {@code toString()}. An empty output yields NaN (0/0).
     */
    @Override
    public double elementwiseAccuracy(Sequence referenceOutput) {
        Sequence<Object> output = this.bestOutputSequence();
        assert (referenceOutput.size() == output.size());
        int accuracy = 0;
        for (int i = 0; i < output.size(); i++) {
            if (referenceOutput.get(i).toString().equals(output.get(i).toString())) {
                ++accuracy;
            }
        }
        logger.info("Number correct: " + accuracy + " out of " + output.size());
        return (double) accuracy / (double) output.size();
    }

    /**
     * Same measure as {@link #elementwiseAccuracy}, additionally printing each predicted element
     * to {@code out} when it is non-null.
     */
    public double tokenAccuracy(Sequence referenceOutput, PrintWriter out) {
        Sequence<Object> output = this.bestOutputSequence();
        assert (referenceOutput.size() == output.size());
        int accuracy = 0;
        for (int i = 0; i < output.size(); i++) {
            String testString = output.get(i).toString();
            if (out != null) {
                out.println(testString);
            }
            if (referenceOutput.get(i).toString().equals(testString)) {
                ++accuracy;
            }
        }
        logger.info("Number correct: " + accuracy + " out of " + output.size());
        return (double) accuracy / (double) output.size();
    }

    /** Factory producing {@link MaxLatticeDefault} lattices with the default cache bound. */
    public static class Factory
    extends MaxLatticeFactory
    implements Serializable {
        private static final long serialVersionUID = 1L;
        private static final int CURRENT_SERIAL_VERSION = 1;

        @Override
        public MaxLattice newMaxLattice(Transducer trans, Sequence inputSequence, Sequence outputSequence) {
            return new MaxLatticeDefault(trans, inputSequence, outputSequence);
        }

        private void writeObject(ObjectOutputStream out) throws IOException {
            out.writeInt(CURRENT_SERIAL_VERSION);
        }

        private void readObject(ObjectInputStream in) throws IOException, ClassNotFoundException {
            in.readInt(); // serial version; only one version exists
        }
    }

    /**
     * One (position, state) cell of the lattice. It is also an A* search state for the
     * backward n-best search. The goal is any position-0 node whose state has a finite
     * initial weight, and the heuristic is {@code -delta}.
     */
    private class ViterbiNode
    implements AStarState {
        int inputPosition;
        Transducer.State state;
        /** Output of the best-scoring transition into this node (null at position 0). */
        Object output;
        double delta = Double.NEGATIVE_INFINITY;
        ViterbiNode maxWeightPredecessor = null;

        ViterbiNode(int inputPosition, Transducer.State state) {
            this.inputPosition = inputPosition;
            this.state = state;
        }

        @Override
        public double completionCost() {
            return -this.delta;
        }

        @Override
        public boolean isFinal() {
            return this.inputPosition == 0 && this.state.getInitialWeight() > Double.NEGATIVE_INFINITY;
        }

        @Override
        public SearchState.NextStateIterator getNextStates() {
            return new PreviousStateIterator();
        }

        /**
         * Enumerates predecessor nodes at {@code inputPosition - 1} that have a finite-weight
         * transition into this node. The cost of each step is the negated transition weight.
         */
        private class PreviousStateIterator
        extends SearchState.NextStateIterator {
            private int prev = 0;
            private boolean found;
            private double weight;
            /** weights[s] = weight of the transition s -> this node's state; null at position 0. */
            private double[] weights;

            private PreviousStateIterator() {
                if (ViterbiNode.this.inputPosition > 0) {
                    int j = ViterbiNode.this.state.getIndex();
                    int numStates = MaxLatticeDefault.this.t.numStates();
                    this.weights = new double[numStates];
                    WeightCache c = MaxLatticeDefault.this.getCache(ViterbiNode.this.inputPosition - 1);
                    for (int s = 0; s < numStates; s++) {
                        this.weights[s] = c.weight[s][j];
                    }
                }
            }

            /** Advances {@code prev} to the next source with a finite weight, if not already there. */
            private void lookAhead() {
                if (this.weights != null && !this.found) {
                    while (this.prev < MaxLatticeDefault.this.t.numStates()) {
                        if (this.weights[this.prev] > Double.NEGATIVE_INFINITY) {
                            this.found = true;
                            return;
                        }
                        ++this.prev;
                    }
                }
            }

            @Override
            public boolean hasNext() {
                this.lookAhead();
                return this.weights != null && this.prev < MaxLatticeDefault.this.t.numStates();
            }

            @Override
            public SearchState nextState() {
                this.lookAhead();
                this.weight = this.weights[this.prev++];
                this.found = false;
                return MaxLatticeDefault.this.getViterbiNode(ViterbiNode.this.inputPosition - 1, this.prev - 1);
            }

            @Override
            public double cost() {
                return -this.weight;
            }

            public double weight() {
                return this.weight;
            }
        }
    }

    /**
     * Transition-weight table for one lattice position, {@code weight[source][destination]},
     * linked into the LRU list via {@code prev}/{@code next}.
     */
    private class WeightCache {
        private WeightCache prev;
        private WeightCache next;
        private double[][] weight;
        private int position;

        private WeightCache(int position) {
            int numStates = MaxLatticeDefault.this.t.numStates();
            this.weight = new double[numStates][numStates];
            this.init(position);
        }

        /** Rebinds this table to {@code position} and resets every weight to negative infinity. */
        private void init(int position) {
            this.position = position;
            for (double[] row : this.weight) {
                for (int j = 0; j < row.length; j++) {
                    row[j] = Double.NEGATIVE_INFINITY;
                }
            }
        }
    }
}

