/*
 * Decompiled with CFR 0.152.
 */
package cc.mallet.topics;

import cc.mallet.types.IDSorter;
import cc.mallet.util.Randoms;
import gnu.trove.TIntIntHashMap;
import gnu.trove.TIntObjectHashMap;
import java.io.BufferedReader;
import java.io.BufferedWriter;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileReader;
import java.io.FileWriter;
import java.io.IOException;
import java.io.InputStreamReader;
import java.io.PrintWriter;
import java.text.NumberFormat;
import java.util.Arrays;
import java.util.zip.GZIPInputStream;

/**
 * Gibbs sampler that assigns one hidden state to every document, where documents
 * are arranged in sequences and each document is summarized by its per-topic
 * token counts.
 *
 * <p>The per-document topic counts are read from a token-level state file
 * ({@link #loadTopicsFromFile(String)}), sequence membership from a metadata file
 * ({@link #loadSequenceIDsFromFile(String)}). {@link #initialize()} draws an
 * initial state for every document and {@link #sample()} repeatedly resamples
 * each document's state given its neighbours in the sequence and the topic
 * counts currently assigned to every state.
 *
 * <p>Typical use (see {@link #main(String[])}): construct, {@code setGamma},
 * optionally {@code setRandomSeed}, {@code loadAlphaFromFile},
 * {@code loadSequenceIDsFromFile}, {@code initialize}, {@code sample}.
 *
 * <p>Instances keep mutable sampler state in plain fields and perform no
 * synchronization.
 */
public class MultinomialHMM {
    /** Number of sampling iterations between snapshots written by {@link #sample()}. */
    private static final int SNAPSHOT_INTERVAL = 10;
    /** Maximum number of top topics listed per state by {@link #printStateTransitions()}. */
    private static final int TOP_TOPICS_SHOWN = 4;

    int numTopics;
    int numStates;
    // One more than the largest document index seen in the topics file.
    int numDocs;
    // Number of runs of consecutive documents that share a sequence ID.
    int numSequences;
    // Per-topic smoothing added to state-topic counts, and its sum.
    double[] alpha;
    double alphaSum;
    // beta / betaSum are not referenced by the code in this class.
    double beta;
    double betaSum;
    // Smoothing added to state-to-state transition counts; gammaSum = gamma * numStates.
    double gamma;
    double gammaSum;
    // Smoothing added to initial-state counts; sumPi = pi * numStates.
    double pi;
    double sumPi;
    // document index -> (topic -> number of tokens of that topic in the document)
    TIntObjectHashMap<TIntIntHashMap> documentTopics;
    int[] documentSequenceIDs;
    // Current hidden state of every document.
    int[] documentStates;
    // stateTopicCounts[s][t]: tokens of topic t in documents currently assigned to state s.
    int[][] stateTopicCounts;
    int[] stateTopicTotals;
    // stateStateTransitions[s][t]: adjacent document pairs (same sequence) in states s then t.
    int[][] stateStateTransitions;
    // stateTransitionTotals[s]: number of transitions leaving state s.
    int[] stateTransitionTotals;
    // initialStateCounts[s]: sequences whose first document is in state s.
    int[] initialStateCounts;
    // Largest per-document count observed for each topic; sizes topicLogGammaCache.
    int[] maxTokensPerTopic;
    // Largest document length observed; sizes docLogGammaCache.
    int maxDocLength;
    // topicLogGammaCache[s][t][n] = sum_{i=1..n} log(alpha[t] + i - 1 + stateTopicCounts[s][t])
    double[][][] topicLogGammaCache;
    // docLogGammaCache[s][n] = sum_{i=1..n} log(alphaSum + i - 1 + stateTopicTotals[s])
    double[][] docLogGammaCache;
    int numIterations = 1000;
    // burninPeriod, saveSampleInterval, optimizeInterval and showTopicsInterval are
    // stored but not consulted by the sampler code in this class.
    int burninPeriod = 200;
    int saveSampleInterval = 10;
    int optimizeInterval = 0;
    int showTopicsInterval = 50;
    // Human-readable key words per topic, filled by loadAlphaFromFile.
    String[] topicKeys;
    Randoms random;
    NumberFormat formatter = NumberFormat.getInstance();

    /**
     * Creates a model and loads the per-document topic counts.
     *
     * @param numberOfTopics number of topics; topic indices in the state file must be
     *                       in {@code [0, numberOfTopics)}
     * @param topicsFilename token-level state file, optionally gzip-compressed
     *                       (detected by a {@code .gz} suffix)
     * @param numStates      number of hidden states
     * @throws IOException if the state file cannot be read
     */
    public MultinomialHMM(int numberOfTopics, String topicsFilename, int numStates) throws IOException {
        this.formatter.setMaximumFractionDigits(5);
        System.out.println("LDA HMM: " + numberOfTopics);
        this.documentTopics = new TIntObjectHashMap<TIntIntHashMap>();
        this.numTopics = numberOfTopics;
        // Symmetric default: every alpha[t] is 1.0 until loadAlphaFromFile is called.
        this.alphaSum = numberOfTopics;
        this.alpha = new double[numberOfTopics];
        Arrays.fill(this.alpha, this.alphaSum / (double) this.numTopics);
        this.topicKeys = new String[this.numTopics];
        this.loadTopicsFromFile(topicsFilename);
        this.documentStates = new int[this.numDocs];
        this.documentSequenceIDs = new int[this.numDocs];

        // Find the largest per-topic count and the longest document so the
        // log-gamma caches can be indexed directly by count.
        this.maxTokensPerTopic = new int[this.numTopics];
        this.maxDocLength = 0;
        for (int doc = 0; doc < this.numDocs; doc++) {
            if (!this.documentTopics.containsKey(doc)) {
                continue;
            }
            TIntIntHashMap topicCounts = this.documentTopics.get(doc);
            int docLength = 0;
            for (int topic : topicCounts.keys()) {
                int topicCount = topicCounts.get(topic);
                if (topicCount > this.maxTokensPerTopic[topic]) {
                    this.maxTokensPerTopic[topic] = topicCount;
                }
                docLength += topicCount;
            }
            if (docLength > this.maxDocLength) {
                this.maxDocLength = docLength;
            }
        }

        this.numStates = numStates;
        this.initialStateCounts = new int[numStates];
        this.topicLogGammaCache = new double[numStates][this.numTopics][];
        for (int state = 0; state < numStates; state++) {
            for (int topic = 0; topic < this.numTopics; topic++) {
                this.topicLogGammaCache[state][topic] = new double[this.maxTokensPerTopic[topic] + 1];
            }
        }
        System.out.println(this.maxDocLength);
        this.docLogGammaCache = new double[numStates][this.maxDocLength + 1];
    }

    /** Sets the smoothing added to every state-to-state transition count. */
    public void setGamma(double g) {
        this.gamma = g;
    }

    /** Sets the number of sweeps performed by {@link #sample()}. */
    public void setNumIterations(int numIterations) {
        this.numIterations = numIterations;
    }

    /** Stores the burn-in period. */
    public void setBurninPeriod(int burninPeriod) {
        this.burninPeriod = burninPeriod;
    }

    /** Stores the topic display interval. */
    public void setTopicDisplayInterval(int interval) {
        this.showTopicsInterval = interval;
    }

    /** Replaces the random number generator with one seeded by {@code seed}. */
    public void setRandomSeed(int seed) {
        this.random = new Randoms(seed);
    }

    /** Stores the optimization interval. */
    public void setOptimizeInterval(int interval) {
        this.optimizeInterval = interval;
    }

    /**
     * Allocates the count tables and draws an initial state for every document.
     *
     * <p>Must be called after the sequence IDs are in place and before
     * {@link #sample()}. If {@code gamma} is still 0 the transition terms can
     * evaluate to {@code log(0)}.
     */
    public void initialize() {
        if (this.random == null) {
            this.random = new Randoms();
        }
        this.gammaSum = this.gamma * (double) this.numStates;
        this.stateTopicCounts = new int[this.numStates][this.numTopics];
        this.stateTopicTotals = new int[this.numStates];
        this.stateStateTransitions = new int[this.numStates][this.numStates];
        this.stateTransitionTotals = new int[this.numStates];
        // Every other count table is rebuilt above; clear this one too so that a
        // second call to initialize() does not double-count sequence starts.
        Arrays.fill(this.initialStateCounts, 0);
        this.pi = 1000.0;
        this.sumPi = (double) this.numStates * this.pi;
        // Derive the sequence count from the IDs actually in use so the
        // initial-state prior and the diagnostics see the real value.
        this.numSequences = this.countSequences();

        // Only the keys of this map are used: it makes the recache step visit
        // every topic, so all caches start out consistent with the empty counts.
        TIntIntHashMap allTopicsDummy = new TIntIntHashMap();
        for (int topic = 0; topic < this.numTopics; topic++) {
            allTopicsDummy.put(topic, 1);
        }
        for (int state = 0; state < this.numStates; state++) {
            this.recacheStateTopicDistribution(state, allTopicsDummy);
        }
        for (int doc = 0; doc < this.numDocs; doc++) {
            this.sampleState(doc, this.random, true);
        }
    }

    /**
     * Counts runs of equal consecutive values in {@code documentSequenceIDs},
     * starting from a sentinel "previous" ID of -1 (the same sentinel
     * {@code sampleState} uses for the document before index 0).
     */
    private int countSequences() {
        int count = 0;
        int currentSequenceID = -1;
        for (int sequenceID : this.documentSequenceIDs) {
            if (sequenceID != currentSequenceID) {
                ++count;
            }
            currentSequenceID = sequenceID;
        }
        return count;
    }

    /**
     * Recomputes the cumulative log caches of {@code state} for the topics that
     * occur as keys of {@code topicCounts}, plus the state's document-length cache.
     *
     * @param state       state whose caches are refreshed
     * @param topicCounts only the key set is consulted; values are ignored
     */
    private void recacheStateTopicDistribution(int state, TIntIntHashMap topicCounts) {
        int[] countsForState = this.stateTopicCounts[state];
        double[][] cachesForState = this.topicLogGammaCache[state];
        for (int topic : topicCounts.keys()) {
            double[] cache = cachesForState[topic];
            cache[0] = 0.0;
            for (int i = 1; i < cache.length; i++) {
                cache[i] = cache[i - 1] + Math.log(this.alpha[topic] + (double) i - 1.0 + (double) countsForState[topic]);
            }
        }
        double[] docCache = this.docLogGammaCache[state];
        docCache[0] = 0.0;
        for (int i = 1; i < docCache.length; i++) {
            docCache[i] = docCache[i - 1] + Math.log(this.alphaSum + (double) i - 1.0 + (double) this.stateTopicTotals[state]);
        }
    }

    /**
     * Runs {@code numIterations} sweeps over all documents, printing the time of
     * each sweep in milliseconds. Every {@value #SNAPSHOT_INTERVAL} sweeps the
     * transition matrix, the state-topic counts and the document states are
     * written to {@code state_state_matrix.N}, {@code state_topics.N} and
     * {@code states.N} in the working directory.
     *
     * @throws IOException if a snapshot file cannot be opened
     */
    public void sample() throws IOException {
        long startTime = System.currentTimeMillis();
        for (int iterations = 1; iterations <= this.numIterations; iterations++) {
            long iterationStart = System.currentTimeMillis();
            for (int doc = 0; doc < this.numDocs; doc++) {
                this.sampleState(doc, this.random, false);
            }
            System.out.print((System.currentTimeMillis() - iterationStart) + " ");
            if (iterations % SNAPSHOT_INTERVAL == 0) {
                System.out.println("<" + iterations + "> ");
                this.writeTextFile("state_state_matrix." + iterations, this.stateTransitionMatrix());
                this.writeTextFile("state_topics." + iterations, this.stateTopics());
                this.writeDocumentStates("states." + iterations);
            }
            System.out.flush();
        }

        long seconds = Math.round((double) (System.currentTimeMillis() - startTime) / 1000.0);
        long minutes = seconds / 60L;
        seconds %= 60L;
        long hours = minutes / 60L;
        minutes %= 60L;
        long days = hours / 24L;
        hours %= 24L;
        System.out.print("\nTotal time: ");
        if (days != 0L) {
            System.out.print(days);
            System.out.print(" days ");
        }
        if (hours != 0L) {
            System.out.print(hours);
            System.out.print(" hours ");
        }
        if (minutes != 0L) {
            System.out.print(minutes);
            System.out.print(" minutes ");
        }
        System.out.print(seconds);
        System.out.println(" seconds");
    }

    /** Writes {@code contents} to {@code filename}, closing the file even on failure. */
    private void writeTextFile(String filename, String contents) throws IOException {
        PrintWriter out = new PrintWriter(new BufferedWriter(new FileWriter(filename)));
        try {
            out.print(contents);
        } finally {
            out.close();
        }
    }

    /** Writes the current state of every document, one per line, to {@code filename}. */
    private void writeDocumentStates(String filename) throws IOException {
        PrintWriter out = new PrintWriter(new BufferedWriter(new FileWriter(filename)));
        try {
            for (int state : this.documentStates) {
                out.println(state);
            }
        } finally {
            out.close();
        }
    }

    /**
     * Opens {@code filename} for line-oriented reading, decompressing on the fly
     * when the name ends in {@code .gz}. The platform default charset is used.
     */
    private static BufferedReader openReader(String filename) throws IOException {
        if (!filename.endsWith(".gz")) {
            return new BufferedReader(new FileReader(new File(filename)));
        }
        FileInputStream raw = new FileInputStream(filename);
        try {
            return new BufferedReader(new InputStreamReader(new GZIPInputStream(raw)));
        } catch (IOException e) {
            // The gzip header could not be read; do not leak the file handle.
            raw.close();
            throw e;
        }
    }

    /**
     * Reads a space-separated token-level state file and accumulates, for every
     * document, how many tokens were assigned to each topic. Lines starting with
     * {@code #} are skipped. Column 0 is the document index and column 4 the
     * topic; columns 1 and 2 must also be integers. {@code numDocs} becomes one
     * more than the largest document index seen.
     *
     * @param stateFilename path of the state file ({@code .gz} is decompressed)
     * @throws IOException           if the file cannot be read
     * @throws NumberFormatException if one of the integer columns is malformed
     */
    public void loadTopicsFromFile(String stateFilename) throws IOException {
        BufferedReader in = openReader(stateFilename);
        try {
            this.numDocs = 0;
            String line;
            while ((line = in.readLine()) != null) {
                if (line.startsWith("#")) {
                    continue;
                }
                String[] fields = line.split(" ");
                int doc = Integer.parseInt(fields[0]);
                // Columns 1 and 2 are not used, but they are still parsed so that a
                // malformed line is rejected rather than silently accepted.
                Integer.parseInt(fields[1]);
                Integer.parseInt(fields[2]);
                int topic = Integer.parseInt(fields[4]);

                if (!this.documentTopics.containsKey(doc)) {
                    this.documentTopics.put(doc, new TIntIntHashMap());
                }
                TIntIntHashMap topicCounts = this.documentTopics.get(doc);
                if (topicCounts.containsKey(topic)) {
                    topicCounts.increment(topic);
                } else {
                    topicCounts.put(topic, 1);
                }
                if (doc >= this.numDocs) {
                    this.numDocs = doc + 1;
                }
            }
        } finally {
            in.close();
        }
        System.out.println("loaded topics, " + this.numDocs + " documents");
    }

    /**
     * Reads a whitespace-separated topic keys file: column 0 is the topic index
     * and columns 2 onward are stored as that topic's key words. Empty lines are
     * skipped. {@code alphaSum} is recomputed as the sum over the topics listed.
     *
     * <p>NOTE(review): every listed topic gets {@code alpha = 1.0}; column 1 is
     * not read. Presumably intentional - confirm before using the file's value.
     *
     * @param alphaFilename path of the keys file
     * @throws IOException if the file cannot be read
     */
    public void loadAlphaFromFile(String alphaFilename) throws IOException {
        this.alphaSum = 0.0;
        BufferedReader in = new BufferedReader(new FileReader(new File(alphaFilename)));
        try {
            String line;
            while ((line = in.readLine()) != null) {
                if (line.equals("")) {
                    continue;
                }
                String[] fields = line.split("\\s+");
                int topic = Integer.parseInt(fields[0]);
                this.alpha[topic] = 1.0;
                this.alphaSum += this.alpha[topic];
                StringBuilder topicKey = new StringBuilder();
                for (int i = 2; i < fields.length; i++) {
                    topicKey.append(fields[i]).append(" ");
                }
                this.topicKeys[topic] = topicKey.toString();
            }
        } finally {
            in.close();
        }
        System.out.println("loaded alpha");
    }

    /**
     * Reads one sequence ID per document from a tab-separated file (column 0 of
     * line {@code i} is the ID of document {@code i}) and counts the number of
     * sequences, i.e. runs of equal consecutive IDs.
     *
     * <p>Lines beyond the number of documents known from the topics file are
     * counted for the mismatch warning but otherwise ignored.
     *
     * @param sequenceFilename path of the sequence metadata file
     * @throws IOException if the file cannot be read
     */
    public void loadSequenceIDsFromFile(String sequenceFilename) throws IOException {
        int doc = 0;
        int currentSequenceID = -1;
        // Start from zero so that loading a second file does not add to the first count.
        this.numSequences = 0;
        BufferedReader in = new BufferedReader(new FileReader(new File(sequenceFilename)));
        try {
            String line;
            while ((line = in.readLine()) != null) {
                if (doc < this.documentSequenceIDs.length) {
                    String[] fields = line.split("\\t");
                    int sequenceID = Integer.parseInt(fields[0]);
                    this.documentSequenceIDs[doc] = sequenceID;
                    if (sequenceID != currentSequenceID) {
                        ++this.numSequences;
                    }
                    currentSequenceID = sequenceID;
                }
                ++doc;
            }
        } finally {
            in.close();
        }
        if (doc != this.numDocs) {
            System.out.println("Warning: number of documents with topics (" + this.numDocs + ") is not equal to number of docs with sequence IDs (" + doc + ")");
        }
        System.out.println("loaded sequence");
    }

    /**
     * Resamples the hidden state of one document.
     *
     * <p>Steps: (1) unless initializing, remove the document's tokens and its
     * transitions from the count tables; (2) for every candidate state compute a
     * log score made of a transition/initial-state term and a topic-count term
     * read from the log caches; (3) draw the new state from the normalized
     * scores; (4) add the document's tokens and transitions back under the new
     * state.
     *
     * <p>Documents without an entry in {@code documentTopics} are skipped
     * entirely. NOTE(review): their neighbours still read
     * {@code documentStates} for them (which stays 0) - confirm that inputs never
     * contain such gaps inside a sequence.
     *
     * @param doc          document index
     * @param r            random number generator used for the draw
     * @param initializing {@code true} during the first pass, when the document
     *                     has no state yet and only earlier documents are placed
     */
    private void sampleState(int doc, Randoms r, boolean initializing) {
        if (!this.documentTopics.containsKey(doc)) {
            return;
        }
        TIntIntHashMap topicCounts = this.documentTopics.get(doc);
        int[] docTopics = topicCounts.keys();
        int oldState = this.documentStates[doc];

        // (1a) Take the document's tokens out of its current state.
        int docLength = 0;
        for (int topic : docTopics) {
            int topicCount = topicCounts.get(topic);
            if (!initializing) {
                this.stateTopicCounts[oldState][topic] -= topicCount;
            }
            docLength += topicCount;
        }
        if (!initializing) {
            this.stateTopicTotals[oldState] -= docLength;
            this.recacheStateTopicDistribution(oldState, topicCounts);
        }

        // Position of the document within its sequence. -1 stands for "no neighbour";
        // while initializing, later documents are not placed yet and are ignored.
        int previousSequenceID = doc > 0 ? this.documentSequenceIDs[doc - 1] : -1;
        int sequenceID = this.documentSequenceIDs[doc];
        int nextSequenceID = -1;
        if (!initializing && doc < this.numDocs - 1) {
            nextSequenceID = this.documentSequenceIDs[doc + 1];
        }
        boolean startsSequence = previousSequenceID != sequenceID;
        boolean endsSequence = sequenceID != nextSequenceID;

        double[] stateLogLikelihoods = new double[this.numStates];
        double[] samplingDistribution = new double[this.numStates];
        // Shared denominator of the initial-state terms (constant across states).
        double initialStateNormalizer = (double) (this.numSequences - 1) + this.sumPi;
        int previousState;
        int nextState;

        // (1b) Remove the document's transitions and (2a) compute the
        // transition / initial-state part of every state's score.
        if (initializing) {
            if (startsSequence) {
                for (int state = 0; state < this.numStates; state++) {
                    stateLogLikelihoods[state] = Math.log(((double) this.initialStateCounts[state] + this.pi) / initialStateNormalizer);
                }
            } else {
                previousState = this.documentStates[doc - 1];
                for (int state = 0; state < this.numStates; state++) {
                    stateLogLikelihoods[state] = Math.log((double) this.stateStateTransitions[previousState][state] + this.gamma);
                    if (Double.isInfinite(stateLogLikelihoods[state])) {
                        System.out.println("infinite end");
                    }
                }
            }
        } else if (startsSequence && endsSequence) {
            // Single-document sequence: only the initial-state count is involved.
            this.initialStateCounts[oldState]--;
            for (int state = 0; state < this.numStates; state++) {
                stateLogLikelihoods[state] = Math.log(((double) this.initialStateCounts[state] + this.pi) / initialStateNormalizer);
            }
        } else if (startsSequence) {
            // First document of a longer sequence: initial state plus one outgoing transition.
            this.initialStateCounts[oldState]--;
            nextState = this.documentStates[doc + 1];
            this.stateStateTransitions[oldState][nextState]--;
            assert (this.stateStateTransitions[oldState][nextState] >= 0);
            this.stateTransitionTotals[oldState]--;
            for (int state = 0; state < this.numStates; state++) {
                double numerator = ((double) this.stateStateTransitions[state][nextState] + this.gamma) * ((double) this.initialStateCounts[state] + this.pi);
                stateLogLikelihoods[state] = Math.log(numerator / initialStateNormalizer);
                if (Double.isInfinite(stateLogLikelihoods[state])) {
                    System.out.println("infinite beginning");
                }
            }
        } else if (endsSequence) {
            // Last document of a sequence: one incoming transition only. The
            // outgoing total of previousState is unchanged because the transition
            // is re-added, to whichever state is drawn, below.
            previousState = this.documentStates[doc - 1];
            this.stateStateTransitions[previousState][oldState]--;
            assert (this.stateStateTransitions[previousState][oldState] >= 0);
            for (int state = 0; state < this.numStates; state++) {
                stateLogLikelihoods[state] = Math.log((double) this.stateStateTransitions[previousState][state] + this.gamma);
                if (Double.isInfinite(stateLogLikelihoods[state])) {
                    System.out.println("infinite end");
                }
            }
        } else {
            // Interior document: one incoming and one outgoing transition.
            nextState = this.documentStates[doc + 1];
            this.stateStateTransitions[oldState][nextState]--;
            if (this.stateStateTransitions[oldState][nextState] < 0) {
                System.out.println(this.printStateTransitions());
                System.out.println(String.valueOf(oldState) + " -> " + nextState);
                System.out.println(sequenceID);
            }
            assert (this.stateStateTransitions[oldState][nextState] >= 0);
            this.stateTransitionTotals[oldState]--;
            previousState = this.documentStates[doc - 1];
            this.stateStateTransitions[previousState][oldState]--;
            assert (this.stateStateTransitions[previousState][oldState] >= 0);

            for (int state = 0; state < this.numStates; state++) {
                double incoming = (double) this.stateStateTransitions[previousState][state] + this.gamma;
                double outgoing;
                double outgoingTotal;
                if (previousState == state && state == nextState) {
                    // The incoming transition previousState -> state is itself a
                    // transition out of `state` into nextState, so it adds one to
                    // both the outgoing count and the outgoing total.
                    outgoing = (double) (this.stateStateTransitions[state][nextState] + 1) + this.gamma;
                    outgoingTotal = (double) (this.stateTransitionTotals[state] + 1) + this.gammaSum;
                } else if (previousState == state) {
                    // The incoming transition leaves `state`, so only the total grows.
                    outgoing = (double) this.stateStateTransitions[state][nextState] + this.gamma;
                    outgoingTotal = (double) (this.stateTransitionTotals[state] + 1) + this.gammaSum;
                } else {
                    outgoing = (double) this.stateStateTransitions[state][nextState] + this.gamma;
                    outgoingTotal = (double) this.stateTransitionTotals[state] + this.gammaSum;
                }
                stateLogLikelihoods[state] = Math.log(incoming * outgoing / outgoingTotal);
                if (Double.isInfinite(stateLogLikelihoods[state])) {
                    System.out.println("infinite middle: " + doc);
                    System.out.println(String.valueOf(previousState) + " -> " + state + " -> " + nextState);
                    System.out.println(String.valueOf(this.stateStateTransitions[previousState][state]) + " -> " + this.stateStateTransitions[state][nextState] + " / " + this.stateTransitionTotals[state]);
                }
            }
        }

        // (2b) Add the topic-count part of every state's score from the caches.
        double max = Double.NEGATIVE_INFINITY;
        for (int state = 0; state < this.numStates; state++) {
            // NOTE(review): integer division - subtracts floor(total / 10). Looks like
            // a penalty on states with many outgoing transitions; confirm intent.
            stateLogLikelihoods[state] -= (double) (this.stateTransitionTotals[state] / 10);
            double[][] logGammaCacheForState = this.topicLogGammaCache[state];
            for (int topic : docTopics) {
                int count = topicCounts.get(topic);
                stateLogLikelihoods[state] += logGammaCacheForState[topic][count];
            }
            stateLogLikelihoods[state] -= this.docLogGammaCache[state][docLength];
            if (stateLogLikelihoods[state] > max) {
                max = stateLogLikelihoods[state];
            }
        }

        // (3) Exponentiate relative to the maximum to avoid underflow, then draw.
        double sum = 0.0;
        for (int state = 0; state < this.numStates; state++) {
            samplingDistribution[state] = Math.exp(stateLogLikelihoods[state] - max);
            sum += samplingDistribution[state];
            if (Double.isNaN(samplingDistribution[state])) {
                System.out.println(stateLogLikelihoods[state]);
            }
            assert (!Double.isNaN(samplingDistribution[state]));
        }
        int newState = r.nextDiscrete(samplingDistribution, sum);
        this.documentStates[doc] = newState;

        // (4a) Put the document's tokens into the new state. Only topics present
        // in the document contribute, mirroring the removal loop above.
        for (int topic : docTopics) {
            this.stateTopicCounts[newState][topic] += topicCounts.get(topic);
        }
        this.stateTopicTotals[newState] += docLength;
        this.recacheStateTopicDistribution(newState, topicCounts);

        // (4b) Put the document's transitions back under the new state.
        if (initializing) {
            if (startsSequence) {
                this.initialStateCounts[newState]++;
            } else {
                previousState = this.documentStates[doc - 1];
                this.stateStateTransitions[previousState][newState]++;
                // The transition leaves previousState, so it is previousState's
                // outgoing total that grows. This keeps stateTransitionTotals[s]
                // equal to the row sum of stateStateTransitions[s], which is the
                // invariant the non-initializing branches maintain.
                this.stateTransitionTotals[previousState]++;
            }
        } else if (startsSequence && endsSequence) {
            this.initialStateCounts[newState]++;
        } else if (startsSequence) {
            this.initialStateCounts[newState]++;
            nextState = this.documentStates[doc + 1];
            this.stateStateTransitions[newState][nextState]++;
            this.stateTransitionTotals[newState]++;
        } else if (endsSequence) {
            previousState = this.documentStates[doc - 1];
            this.stateStateTransitions[previousState][newState]++;
        } else {
            previousState = this.documentStates[doc - 1];
            this.stateStateTransitions[previousState][newState]++;
            nextState = this.documentStates[doc + 1];
            this.stateStateTransitions[newState][nextState]++;
            this.stateTransitionTotals[newState]++;
        }
    }

    /**
     * Builds a human-readable report: for every state its top topics (at most
     * {@value #TOP_TOPICS_SHOWN}) with token counts and key words, its
     * initial-state count, its outgoing total and its row of the transition
     * matrix (the self-transition is shown in brackets).
     *
     * @return the report text
     */
    public String printStateTransitions() {
        StringBuilder out = new StringBuilder();
        IDSorter[] sortedTopics = new IDSorter[this.numTopics];
        int topicsShown = Math.min(TOP_TOPICS_SHOWN, this.numTopics);
        for (int s = 0; s < this.numStates; s++) {
            int total = this.stateTopicTotals[s];
            for (int topic = 0; topic < this.numTopics; topic++) {
                // An empty state would give 0 / 0.0 = NaN, which is not a usable sort key.
                double weight = total == 0 ? 0.0 : (double) this.stateTopicCounts[s][topic] / (double) total;
                sortedTopics[topic] = new IDSorter(topic, weight);
            }
            Arrays.sort(sortedTopics);
            out.append("\n" + s + "\n");
            for (int i = 0; i < topicsShown; i++) {
                int topic = sortedTopics[i].getID();
                out.append(this.stateTopicCounts[s][topic] + "\t" + this.topicKeys[topic] + "\n");
            }
            out.append("\n");
            out.append("[" + this.initialStateCounts[s] + "/" + this.numSequences + "] ");
            out.append("[" + this.stateTransitionTotals[s] + "]");
            for (int t = 0; t < this.numStates; t++) {
                out.append("\t");
                if (s == t) {
                    out.append("[" + this.stateStateTransitions[s][t] + "]");
                } else {
                    out.append(this.stateStateTransitions[s][t]);
                }
            }
            out.append("\n");
        }
        return out.toString();
    }

    /**
     * Formats the transition count matrix, one row per source state, every
     * count followed by a tab and every row by a newline.
     */
    public String stateTransitionMatrix() {
        StringBuilder out = new StringBuilder();
        for (int s = 0; s < this.numStates; s++) {
            for (int t = 0; t < this.numStates; t++) {
                out.append(this.stateStateTransitions[s][t]);
                out.append("\t");
            }
            out.append("\n");
        }
        return out.toString();
    }

    /**
     * Formats the state-topic count matrix, one row per state, every count
     * followed by a tab and every row by a newline.
     */
    public String stateTopics() {
        StringBuilder out = new StringBuilder();
        for (int s = 0; s < this.numStates; s++) {
            for (int topic = 0; topic < this.numTopics; topic++) {
                out.append(this.stateTopicCounts[s][topic]).append("\t");
            }
            out.append("\n");
        }
        return out.toString();
    }

    /**
     * Command-line entry point: trains a 150-state model with gamma 1.0 and
     * random seed 1.
     *
     * @param args number of topics, state file, keys file, sequence metadata file
     * @throws IOException if any input or snapshot file cannot be accessed
     */
    public static void main(String[] args) throws IOException {
        if (args.length != 4) {
            System.err.println("Usage: MultinomialHMM [num topics] [lda state file] [lda keys file] [sequence metadata file]");
            // A usage error is a failure; report it to the calling shell as such.
            System.exit(1);
        }
        int numTopics = Integer.parseInt(args[0]);
        MultinomialHMM hmm = new MultinomialHMM(numTopics, args[1], 150);
        hmm.setGamma(1.0);
        hmm.setRandomSeed(1);
        hmm.loadAlphaFromFile(args[2]);
        hmm.loadSequenceIDsFromFile(args[3]);
        hmm.initialize();
        hmm.sample();
    }
}

