/*
 * Decompiled with CFR 0.152.
 */
package com.aliasi.cluster;

import com.aliasi.corpus.ObjectHandler;
import com.aliasi.stats.Statistics;
import com.aliasi.symbol.SymbolTable;
import com.aliasi.tokenizer.Tokenizer;
import com.aliasi.tokenizer.TokenizerFactory;
import com.aliasi.util.AbstractExternalizable;
import com.aliasi.util.FeatureExtractor;
import com.aliasi.util.Iterators;
import com.aliasi.util.Math;
import com.aliasi.util.ObjectToCounterMap;
import com.aliasi.util.ObjectToDoubleMap;
import com.aliasi.util.Strings;
import java.io.IOException;
import java.io.ObjectInput;
import java.io.ObjectOutput;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Iterator;
import java.util.Map;
import java.util.Random;
import java.util.Set;

public class LatentDirichletAllocation
implements Serializable {
    /** Serialization version identifier for this class. */
    static final long serialVersionUID = 6313662446710242382L;
    /** Document-topic prior supplied to the constructor; validated there as finite and positive. */
    private final double mDocTopicPrior;
    /** Word probabilities indexed as {@code [topic][word]}; every row has the same length. */
    private final double[][] mTopicWordProbs;

    /**
     * Constructs a model from a document-topic prior and a matrix of
     * word probabilities indexed as {@code [topic][word]}.
     *
     * <p>The matrix is stored by reference, not copied, so later changes
     * made by the caller to the array are visible through this instance.
     *
     * @param docTopicPrior prior for topics in a document; must be finite
     *     and strictly positive
     * @param topicWordProbs word probabilities per topic; must contain at
     *     least one topic, all rows must have the same length, and every
     *     entry must lie in {@code [0.0, 1.0]}
     * @throws IllegalArgumentException if the prior is not finite and
     *     positive, if there are no topics, if rows differ in length, or if
     *     any probability is NaN or outside {@code [0.0, 1.0]}
     */
    public LatentDirichletAllocation(double docTopicPrior, double[][] topicWordProbs) {
        if (docTopicPrior <= 0.0 || Double.isNaN(docTopicPrior) || Double.isInfinite(docTopicPrior)) {
            String msg = "Document-topic prior must be finite and positive. Found docTopicPrior=" + docTopicPrior;
            throw new IllegalArgumentException(msg);
        }
        int numTopics = topicWordProbs.length;
        if (numTopics < 1) {
            String msg = "Require non-empty topic-word probabilities.";
            throw new IllegalArgumentException(msg);
        }
        int numWords = topicWordProbs[0].length;
        for (int topic = 1; topic < numTopics; ++topic) {
            if (topicWordProbs[topic].length != numWords) {
                String msg = "All topics must have the same number of words. topicWordProbs[0].length=" + topicWordProbs[0].length + " topicWordProbs[" + topic + "].length=" + topicWordProbs[topic].length;
                throw new IllegalArgumentException(msg);
            }
        }
        for (int topic = 0; topic < numTopics; ++topic) {
            for (int word = 0; word < numWords; ++word) {
                double prob = topicWordProbs[topic][word];
                // Written as a negated range test so that NaN, for which every
                // ordered comparison is false, is rejected as well.
                if (!(prob >= 0.0 && prob <= 1.0)) {
                    String msg = "All probabilities must be between 0.0 and 1.0 Found topicWordProbs[" + topic + "][" + word + "]=" + prob;
                    throw new IllegalArgumentException(msg);
                }
            }
        }
        this.mDocTopicPrior = docTopicPrior;
        this.mTopicWordProbs = topicWordProbs;
    }

    /**
     * Returns the number of topics in this model.
     *
     * @return the number of rows of the topic-word probability matrix
     */
    public int numTopics() {
        return this.mTopicWordProbs.length;
    }

    /**
     * Returns the number of words in this model.
     *
     * @return the length of the first topic's word-probability row
     */
    public int numWords() {
        return this.mTopicWordProbs[0].length;
    }

    /**
     * Returns the document-topic prior this model was constructed with.
     *
     * @return the document-topic prior
     */
    public double documentTopicPrior() {
        return this.mDocTopicPrior;
    }

    /**
     * Returns the probability of the given word under the given topic.
     *
     * @param topic topic index
     * @param word word index
     * @return the stored probability at {@code [topic][word]}
     * @throws ArrayIndexOutOfBoundsException if either index is out of range
     */
    public double wordProbability(int topic, int word) {
        return this.mTopicWordProbs[topic][word];
    }

    /**
     * Returns a fresh copy of the word probabilities for a topic.
     * Modifying the returned array does not affect this model.
     *
     * @param topic topic index
     * @return a newly allocated array of that topic's word probabilities
     * @throws ArrayIndexOutOfBoundsException if the topic is out of range
     */
    public double[] wordProbabilities(int topic) {
        return this.mTopicWordProbs[topic].clone();
    }

    /**
     * Draws Gibbs samples of per-token topic assignments for a single
     * document, holding this model's word probabilities fixed.
     *
     * <p>Every token first receives a topic drawn from
     * {@code random.nextInt(numTopics())}. Each epoch then revisits every
     * token in order and redraws its topic with weight
     * {@code (count of that topic among the other tokens + documentTopicPrior())
     * * wordProbability(topic, word)}. From epoch {@code burnin} onward, every
     * {@code sampleLag}-th epoch closes the current row of the result and
     * continues sampling in a copy placed in the next row.
     *
     * @param tokens word identifiers of the document, used as the word
     *     index in {@link #wordProbability(int,int)}
     * @param numSamples number of rows to return; at least 1
     * @param burnin number of epochs before the first recorded row;
     *     non-negative
     * @param sampleLag number of epochs between recorded rows; at least 1
     * @param random source of randomness
     * @return an array of {@code numSamples} rows, each holding one topic
     *     per token
     * @throws IllegalArgumentException if {@code burnin} is negative or
     *     {@code numSamples} or {@code sampleLag} is less than 1
     */
    public short[][] sampleTopics(int[] tokens, int numSamples, int burnin, int sampleLag, Random random) {
        if (burnin < 0) {
            throw new IllegalArgumentException("Burnin period must be non-negative. Found burnin=" + burnin);
        }
        if (numSamples < 1) {
            throw new IllegalArgumentException("Number of samples must be at least 1. Found numSamples=" + numSamples);
        }
        if (sampleLag < 1) {
            throw new IllegalArgumentException("Sample lag must be at least 1. Found sampleLag=" + sampleLag);
        }
        final double prior = this.documentTopicPrior();
        final int tokenTotal = tokens.length;
        final int topicTotal = this.numTopics();
        int[] assignedPerTopic = new int[topicTotal];
        short[][] samples = new short[numSamples][tokenTotal];

        // Random initial assignment, written straight into the first row.
        int sampleIndex = 0;
        short[] current = samples[sampleIndex];
        for (int pos = 0; pos < tokenTotal; ++pos) {
            int initial = random.nextInt(topicTotal);
            ++assignedPerTopic[initial];
            current[pos] = (short) initial;
        }

        double[] cumulative = new double[topicTotal];
        final int sweeps = burnin + sampleLag * (numSamples - 1);
        for (int sweep = 0; sweep < sweeps; ++sweep) {
            for (int pos = 0; pos < tokenTotal; ++pos) {
                int word = tokens[pos];
                short previous = current[pos];
                // Take the token out of the counts before resampling it.
                if (--assignedPerTopic[previous] < 0) {
                    throw new IllegalArgumentException("bomb");
                }
                // Unnormalized cumulative weights over topics.
                double running = 0.0;
                for (int k = 0; k < topicTotal; ++k) {
                    running += ((double) assignedPerTopic[k] + prior) * this.wordProbability(k, word);
                    cumulative[k] = running;
                }
                int drawn = Statistics.sample(cumulative, random);
                ++assignedPerTopic[drawn];
                current[pos] = (short) drawn;
            }
            boolean recordNow = sweep >= burnin && (sweep - burnin) % sampleLag == 0;
            if (recordNow) {
                // Freeze the finished row and continue in a copy in the next row.
                short[] snapshot = current;
                current = samples[++sampleIndex];
                System.arraycopy(snapshot, 0, current, 0, tokenTotal);
            }
        }
        return samples;
    }

    /**
     * Delegates unchanged to
     * {@link #bayesTopicEstimate(int[],int,int,int,Random)}; despite the
     * name, no separate estimator is implemented here.
     *
     * @param tokens word identifiers of the document
     * @param numSamples number of Gibbs samples to draw
     * @param burnin number of burn-in epochs
     * @param sampleLag number of epochs between samples
     * @param random source of randomness
     * @return the result of {@code bayesTopicEstimate} for the same arguments
     */
    double[] mapTopicEstimate(int[] tokens, int numSamples, int burnin, int sampleLag, Random random) {
        return this.bayesTopicEstimate(tokens, numSamples, burnin, sampleLag, random);
    }

    /**
     * Estimates the topic distribution of a document as the relative
     * frequency of each topic over all token assignments in all samples
     * returned by {@link #sampleTopics(int[],int,int,int,Random)}.
     *
     * <p>If the document has no tokens there are no assignments to count;
     * the uniform distribution over topics is returned in that case rather
     * than the NaN values a 0/0 division would produce.
     *
     * @param tokens word identifiers of the document
     * @param numSamples number of Gibbs samples to draw; at least 1
     * @param burnin number of burn-in epochs; non-negative
     * @param sampleLag number of epochs between samples; at least 1
     * @param random source of randomness
     * @return an array with one entry per topic, summing to 1
     * @throws IllegalArgumentException if the sampling parameters are
     *     rejected by {@code sampleTopics}
     */
    public double[] bayesTopicEstimate(int[] tokens, int numSamples, int burnin, int sampleLag, Random random) {
        short[][] sampleTopics = this.sampleTopics(tokens, numSamples, burnin, sampleLag, random);
        int numTopics = this.numTopics();
        int[] counts = new int[numTopics];
        long totalCount = 0L;
        for (short[] topics : sampleTopics) {
            for (short topic : topics) {
                ++counts[topic];
            }
            totalCount += topics.length;
        }
        double[] result = new double[numTopics];
        if (totalCount == 0L) {
            // No tokens were sampled: fall back to the uniform distribution.
            Arrays.fill(result, 1.0 / (double) numTopics);
            return result;
        }
        for (int topic = 0; topic < numTopics; ++topic) {
            result[topic] = (double) counts[topic] / (double) totalCount;
        }
        return result;
    }

    /**
     * Returns a new {@code Serializer} wrapping this instance.
     * NOTE(review): presumably the Java serialization replacement hook;
     * {@code Serializer} is declared outside the visible part of the file.
     *
     * @return a {@code Serializer} constructed from this instance
     */
    Object writeReplace() {
        return new Serializer(this);
    }

    /**
     * Runs a collapsed Gibbs sampler over a corpus and returns the final
     * sample.
     *
     * <p>The vocabulary size is taken as the largest word identifier plus
     * one. Every token starts with a topic from
     * {@code random.nextInt(numTopics)}. Each epoch then visits every token
     * of every document in order, removes it from the counts, redraws its
     * topic with weight {@code (docTopicCount + docTopicPrior) *
     * (wordTopicCount + topicWordPrior) / (topicTotalCount + numWords *
     * topicWordPrior)}, and adds it back. Epochs are numbered from 0 to
     * {@code burninEpochs + sampleLag * (numSamples - 1)} inclusive; from
     * {@code burninEpochs} onward every {@code sampleLag}-th epoch produces
     * a {@link GibbsSample}.
     *
     * <p>The {@code GibbsSample} objects are built directly on the
     * sampler's working arrays rather than on copies, so a sample handed to
     * the handler reflects later epochs if it is retained.
     *
     * @param docWords word identifiers per document; all non-negative
     * @param numTopics number of topics; at least 1
     * @param docTopicPrior prior count of a topic in a document
     * @param topicWordPrior prior count of a word in a topic
     * @param burninEpochs epochs before the first sample; non-negative
     * @param sampleLag epochs between samples; at least 1
     * @param numSamples number of samples; at least 1
     * @param random source of randomness
     * @param handler receives every sample; may be {@code null}
     * @return the sample produced in the final epoch
     * @throws IllegalArgumentException if the inputs are rejected by
     *     {@code validateInputs}
     */
    public static GibbsSample gibbsSampler(int[][] docWords, short numTopics, double docTopicPrior, double topicWordPrior, int burninEpochs, int sampleLag, int numSamples, Random random, ObjectHandler<GibbsSample> handler) {
        LatentDirichletAllocation.validateInputs(docWords, numTopics, docTopicPrior, topicWordPrior, burninEpochs, sampleLag, numSamples);
        int numDocs = docWords.length;
        int numWords = LatentDirichletAllocation.max(docWords) + 1;
        int numTokens = 0;
        short[][] currentSample = new short[numDocs][];
        for (int doc = 0; doc < numDocs; ++doc) {
            numTokens += docWords[doc].length;
            currentSample[doc] = new short[docWords[doc].length];
        }
        int[][] docTopicCount = new int[numDocs][numTopics];
        int[][] wordTopicCount = new int[numWords][numTopics];
        int[] topicTotalCount = new int[numTopics];

        // Random initial topic for every token.
        for (int doc = 0; doc < numDocs; ++doc) {
            for (int tok = 0; tok < docWords[doc].length; ++tok) {
                int word = docWords[doc][tok];
                int topic = random.nextInt(numTopics);
                currentSample[doc][tok] = (short) topic;
                ++docTopicCount[doc][topic];
                ++wordTopicCount[word][topic];
                ++topicTotalCount[topic];
            }
        }

        double numWordsTimesTopicWordPrior = (double) numWords * topicWordPrior;
        double[] topicDistro = new double[numTopics];
        int numEpochs = burninEpochs + sampleLag * (numSamples - 1);
        for (int epoch = 0; epoch <= numEpochs; ++epoch) {
            int numChangedTopics = 0;
            for (int doc = 0; doc < numDocs; ++doc) {
                int[] docWordsDoc = docWords[doc];
                short[] currentSampleDoc = currentSample[doc];
                int[] docTopicCountDoc = docTopicCount[doc];
                for (int tok = 0; tok < docWordsDoc.length; ++tok) {
                    int word = docWordsDoc[tok];
                    int[] wordTopicCountWord = wordTopicCount[word];
                    int currentTopic = currentSampleDoc[tok];

                    // Exclude the token being resampled from all counts. The
                    // counts are exact integers, so this gives the same weights
                    // as subtracting 1.0 inside the formula for the current topic.
                    --docTopicCountDoc[currentTopic];
                    --wordTopicCountWord[currentTopic];
                    --topicTotalCount[currentTopic];

                    // Unnormalized cumulative weights over topics.
                    double cumulative = 0.0;
                    for (int topic = 0; topic < numTopics; ++topic) {
                        cumulative += ((double) docTopicCountDoc[topic] + docTopicPrior) * ((double) wordTopicCountWord[topic] + topicWordPrior) / ((double) topicTotalCount[topic] + numWordsTimesTopicWordPrior);
                        topicDistro[topic] = cumulative;
                    }

                    int sampledTopic = Statistics.sample(topicDistro, random);
                    ++docTopicCountDoc[sampledTopic];
                    ++wordTopicCountWord[sampledTopic];
                    ++topicTotalCount[sampledTopic];
                    if (sampledTopic != currentTopic) {
                        currentSampleDoc[tok] = (short) sampledTopic;
                        ++numChangedTopics;
                    }
                }
            }
            if (epoch >= burninEpochs && (epoch - burninEpochs) % sampleLag == 0) {
                GibbsSample sample = new GibbsSample(epoch, currentSample, docWords, docTopicPrior, topicWordPrior, docTopicCount, wordTopicCount, topicTotalCount, numChangedTopics, numWords, numTokens);
                if (handler != null) {
                    handler.handle(sample);
                }
                if (epoch == numEpochs) {
                    return sample;
                }
            }
        }
        throw new IllegalStateException("unreachable in practice because of return if epoch==numEpochs");
    }

    /**
     * Validates the inputs and returns a new {@code SampleIterator}
     * constructed from them.
     *
     * @param docWords word identifiers per document
     * @param numTopics number of topics
     * @param docTopicPrior document-topic prior
     * @param topicWordPrior topic-word prior
     * @param random source of randomness
     * @return an iterator over Gibbs samples
     * @throws IllegalArgumentException if the inputs are rejected by
     *     {@code validateInputs}
     */
    public static Iterator<GibbsSample> gibbsSample(int[][] docWords, short numTopics, double docTopicPrior, double topicWordPrior, Random random) {
        LatentDirichletAllocation.validateInputs(docWords, numTopics, docTopicPrior, topicWordPrior);
        return new SampleIterator(docWords, numTopics, docTopicPrior, topicWordPrior, random);
    }

    /**
     * Tokenizes a collection of texts into arrays of symbol identifiers.
     *
     * <p>All texts are tokenized once to count tokens, the counter is
     * pruned with {@code minCount}, and each surviving token is registered
     * in the symbol table via {@code getOrAddSymbol}. Every text is then
     * tokenized a second time with
     * {@link #tokenizeDocument(CharSequence,TokenizerFactory,SymbolTable)}.
     *
     * <p>NOTE(review): {@code prune(minCount)} presumably drops tokens
     * counted fewer than {@code minCount} times; confirm against
     * {@code ObjectToCounterMap}.
     *
     * @param texts documents to tokenize
     * @param tokenizerFactory factory producing a tokenizer per document
     * @param symbolTable table that receives the surviving tokens
     * @param minCount threshold passed to the counter's {@code prune}
     * @return one array of symbol identifiers per input text, in input order
     */
    public static int[][] tokenizeDocuments(CharSequence[] texts, TokenizerFactory tokenizerFactory, SymbolTable symbolTable, int minCount) {
        ObjectToCounterMap<String> tokenCounter = new ObjectToCounterMap<String>();
        for (CharSequence text : texts) {
            char[] cs = Strings.toCharArray(text);
            Tokenizer tokenizer = tokenizerFactory.tokenizer(cs, 0, cs.length);
            for (String token : tokenizer) {
                tokenCounter.increment(token);
            }
        }
        tokenCounter.prune(minCount);
        // Iterate the typed key set directly; a raw Set cannot be iterated as String.
        for (String token : tokenCounter.keySet()) {
            symbolTable.getOrAddSymbol(token);
        }
        int[][] docTokenId = new int[texts.length][];
        for (int i = 0; i < docTokenId.length; ++i) {
            docTokenId[i] = LatentDirichletAllocation.tokenizeDocument(texts[i], tokenizerFactory, symbolTable);
        }
        return docTokenId;
    }

    /**
     * Tokenizes one text and maps each token to its identifier in the
     * symbol table, dropping tokens for which {@code symbolToID} returns a
     * negative value.
     *
     * @param text document to tokenize
     * @param tokenizerFactory factory producing the tokenizer
     * @param symbolTable table used to look up token identifiers
     * @return identifiers of the retained tokens, in token order
     */
    public static int[] tokenizeDocument(CharSequence text, TokenizerFactory tokenizerFactory, SymbolTable symbolTable) {
        char[] chars = Strings.toCharArray(text);
        ArrayList<Integer> known = new ArrayList<Integer>();
        for (String token : tokenizerFactory.tokenizer(chars, 0, chars.length)) {
            int id = symbolTable.symbolToID(token);
            if (id >= 0) {
                known.add(id);
            }
        }
        int[] ids = new int[known.size()];
        int next = 0;
        for (Integer id : known) {
            ids[next++] = id;
        }
        return ids;
    }

    /**
     * Returns the largest value in a two-dimensional array, with a floor
     * of zero: the result is 0 when the array has no elements or when all
     * elements are negative.
     *
     * @param xs rows of values to scan
     * @return the largest element, or 0 if no element exceeds 0
     */
    static int max(int[][] xs) {
        int largest = 0;
        for (int[] row : xs) {
            for (int value : row) {
                // Fully qualified because this file imports another class named Math.
                largest = java.lang.Math.max(largest, value);
            }
        }
        return largest;
    }

    /**
     * Returns the absolute difference of two values divided by the sum of
     * their absolute values.
     *
     * <p>When both values are zero the denominator is zero; the values are
     * equal, so 0.0 is returned instead of the NaN a 0/0 division would
     * give.
     *
     * @param x first value
     * @param y second value
     * @return {@code |x - y| / (|x| + |y|)}, or 0.0 if both are zero
     */
    static double relativeDifference(double x, double y) {
        double scale = java.lang.Math.abs(x) + java.lang.Math.abs(y);
        if (scale == 0.0) {
            return 0.0;
        }
        return java.lang.Math.abs(x - y) / scale;
    }

    /**
     * Validates the corpus, topic count and priors through the
     * four-argument overload, then checks the sampling schedule.
     *
     * @param docWords word identifiers per document
     * @param numTopics number of topics
     * @param docTopicPrior document-topic prior
     * @param topicWordPrior topic-word prior
     * @param burninEpochs number of burn-in epochs; non-negative
     * @param sampleLag number of epochs between samples; at least 1
     * @param numSamples number of samples; at least 1
     * @throws IllegalArgumentException if any argument is rejected
     */
    static void validateInputs(int[][] docWords, short numTopics, double docTopicPrior, double topicWordPrior, int burninEpochs, int sampleLag, int numSamples) {
        LatentDirichletAllocation.validateInputs(docWords, numTopics, docTopicPrior, topicWordPrior);
        if (burninEpochs < 0) {
            throw new IllegalArgumentException("Number of burnin epochs must be non-negative. Found burninEpochs=" + burninEpochs);
        }
        if (sampleLag < 1) {
            throw new IllegalArgumentException("Sample lag must be positive. Found sampleLag=" + sampleLag);
        }
        if (numSamples < 1) {
            throw new IllegalArgumentException("Number of samples must be positive. Found numSamples=" + numSamples);
        }
    }

    /**
     * Validates a corpus, topic count and pair of priors.
     *
     * <p>Both priors must be strictly positive, matching the error
     * messages and the check applied by the model constructor. A zero
     * topic-word prior would let the Gibbs sampling weight of an empty
     * topic evaluate to 0/0.
     *
     * @param docWords word identifiers per document; each must be
     *     non-negative
     * @param numTopics number of topics; at least 1
     * @param docTopicPrior document-topic prior; finite and positive
     * @param topicWordPrior topic-word prior; finite and positive
     * @throws IllegalArgumentException if any argument violates the above
     */
    static void validateInputs(int[][] docWords, short numTopics, double docTopicPrior, double topicWordPrior) {
        for (int doc = 0; doc < docWords.length; ++doc) {
            for (int tok = 0; tok < docWords[doc].length; ++tok) {
                if (docWords[doc][tok] < 0) {
                    String msg = "All tokens must have IDs greater than or equal to 0. Found docWords[" + doc + "][" + tok + "]=" + docWords[doc][tok];
                    throw new IllegalArgumentException(msg);
                }
            }
        }
        if (numTopics < 1) {
            String msg = "Num topics must be positive. Found numTopics=" + numTopics;
            throw new IllegalArgumentException(msg);
        }
        if (Double.isInfinite(docTopicPrior) || Double.isNaN(docTopicPrior) || docTopicPrior <= 0.0) {
            String msg = "Document-topic prior must be finite and positive. Found docTopicPrior=" + docTopicPrior;
            throw new IllegalArgumentException(msg);
        }
        if (Double.isInfinite(topicWordPrior) || Double.isNaN(topicWordPrior) || topicWordPrior <= 0.0) {
            String msg = "Topic-word prior must be finite and positive. Found topicWordPrior=" + topicWordPrior;
            throw new IllegalArgumentException(msg);
        }
    }

    /**
     * Returns a new {@code ExpectedTopicFeatureExtractor} built from this
     * model and the given tokenizer factory, symbol table and prefix.
     *
     * @param tokenizerFactory factory used to tokenize input text
     * @param symbolTable table mapping tokens to word identifiers
     * @param featurePrefix prefix for the generated feature names
     * @return the feature extractor
     */
    FeatureExtractor<CharSequence> expectedTopicFeatureExtractor(TokenizerFactory tokenizerFactory, SymbolTable symbolTable, String featurePrefix) {
        return new ExpectedTopicFeatureExtractor(this, tokenizerFactory, symbolTable, featurePrefix);
    }

    /**
     * Returns a new {@code BayesTopicFeatureExtractor} built from this
     * model and the given tokenizer factory, symbol table, prefix and
     * sampling schedule.
     *
     * @param tokenizerFactory factory used to tokenize input text
     * @param symbolTable table mapping tokens to word identifiers
     * @param featurePrefix prefix for the generated feature names
     * @param burnIn number of burn-in epochs used per extraction
     * @param sampleLag number of epochs between samples
     * @param numSamples number of samples drawn per extraction
     * @return the feature extractor
     */
    FeatureExtractor<CharSequence> bayesTopicFeatureExtractor(TokenizerFactory tokenizerFactory, SymbolTable symbolTable, String featurePrefix, int burnIn, int sampleLag, int numSamples) {
        return new BayesTopicFeatureExtractor(this, tokenizerFactory, symbolTable, featurePrefix, burnIn, sampleLag, numSamples);
    }

    /**
     * Builds one feature name per topic by appending the topic index to
     * the prefix. A {@code null} prefix is rendered as the text
     * {@code "null"}.
     *
     * @param prefix text placed before each topic index
     * @param numTopics number of names to generate
     * @return the names {@code prefix + 0} through
     *     {@code prefix + (numTopics - 1)}
     */
    static String[] genFeatures(String prefix, int numTopics) {
        String[] names = new String[numTopics];
        for (int topic = 0; topic < names.length; ++topic) {
            names[topic] = prefix + topic;
        }
        return names;
    }

    /**
     * Feature extractor that maps a text to the topic proportions
     * estimated for it by the wrapped model's {@code mapTopicEstimate}.
     * Each call creates a new {@link Random}, so repeated calls on the
     * same text are not guaranteed to give the same values.
     */
    static class BayesTopicFeatureExtractor implements FeatureExtractor<CharSequence>, Serializable {
        static final long serialVersionUID = 8883227852502200365L;
        private final LatentDirichletAllocation mLda;
        private final TokenizerFactory mTokenizerFactory;
        private final SymbolTable mSymbolTable;
        /** Feature name for each topic, indexed by topic. */
        private final String[] mFeatures;
        private final int mBurnin;
        private final int mSampleLag;
        private final int mNumSamples;

        /**
         * Creates an extractor whose feature names are the prefix followed
         * by the topic index.
         */
        public BayesTopicFeatureExtractor(LatentDirichletAllocation lda, TokenizerFactory tokenizerFactory, SymbolTable symbolTable, String featurePrefix, int burnin, int sampleLag, int numSamples) {
            this(lda, tokenizerFactory, symbolTable, LatentDirichletAllocation.genFeatures(featurePrefix, lda.numTopics()), burnin, sampleLag, numSamples);
        }

        /** Creates an extractor with explicit feature names; used when deserializing. */
        BayesTopicFeatureExtractor(LatentDirichletAllocation lda, TokenizerFactory tokenizerFactory, SymbolTable symbolTable, String[] features, int burnin, int sampleLag, int numSamples) {
            this.mLda = lda;
            this.mTokenizerFactory = tokenizerFactory;
            this.mSymbolTable = symbolTable;
            this.mFeatures = features;
            this.mBurnin = burnin;
            this.mSampleLag = sampleLag;
            this.mNumSamples = numSamples;
        }

        /**
         * Tokenizes the text, keeps tokens whose symbol identifier lies in
         * {@code [0, numWords)}, estimates topic proportions for them and
         * returns every strictly positive proportion under its feature name.
         *
         * @param cSeq text to extract features from
         * @return map from feature name to topic proportion
         */
        @Override
        public Map<String, Double> features(CharSequence cSeq) {
            final int topicTotal = this.mLda.numTopics();
            char[] chars = Strings.toCharArray(cSeq);
            Tokenizer tokenizer = this.mTokenizerFactory.tokenizer(chars, 0, chars.length);
            ArrayList<Integer> knownIds = new ArrayList<Integer>();
            for (String token : tokenizer) {
                int id = this.mSymbolTable.symbolToID(token);
                boolean inVocabulary = id >= 0 && id < this.mLda.numWords();
                if (inVocabulary) {
                    knownIds.add(id);
                }
            }
            int[] tokens = new int[knownIds.size()];
            int next = 0;
            for (Integer id : knownIds) {
                tokens[next++] = id;
            }
            double[] estimate = this.mLda.mapTopicEstimate(tokens, this.mNumSamples, this.mBurnin, this.mSampleLag, new Random());
            ObjectToDoubleMap<String> features = new ObjectToDoubleMap<String>(topicTotal * 3 / 2);
            for (int topic = 0; topic < topicTotal; ++topic) {
                double proportion = estimate[topic];
                if (proportion > 0.0) {
                    features.set(this.mFeatures[topic], proportion);
                }
            }
            return features;
        }

        /** Returns a {@link Serializer} proxy wrapping this extractor. */
        Object writeReplace() {
            return new Serializer(this);
        }

        /**
         * Externalizable proxy. Writes, in order: the model, the tokenizer
         * factory, the symbol table, the feature names, then burn-in,
         * sample lag and sample count. Reading consumes the same order.
         */
        static class Serializer extends AbstractExternalizable {
            static final long serialVersionUID = 6719636683732661958L;
            final BayesTopicFeatureExtractor mFeatureExtractor;

            /** No-argument constructor; leaves the wrapped extractor null. */
            public Serializer() {
                this(null);
            }

            Serializer(BayesTopicFeatureExtractor featureExtractor) {
                this.mFeatureExtractor = featureExtractor;
            }

            @Override
            public void writeExternal(ObjectOutput out) throws IOException {
                BayesTopicFeatureExtractor source = this.mFeatureExtractor;
                out.writeObject(source.mLda);
                out.writeObject(source.mTokenizerFactory);
                out.writeObject(source.mSymbolTable);
                Serializer.writeUTFs(source.mFeatures, out);
                out.writeInt(source.mBurnin);
                out.writeInt(source.mSampleLag);
                out.writeInt(source.mNumSamples);
            }

            @Override
            public Object read(ObjectInput in) throws IOException, ClassNotFoundException {
                LatentDirichletAllocation lda = (LatentDirichletAllocation) in.readObject();
                TokenizerFactory tokenizerFactory = (TokenizerFactory) in.readObject();
                SymbolTable symbolTable = (SymbolTable) in.readObject();
                String[] features = Serializer.readUTFs(in);
                int burnIn = in.readInt();
                int sampleLag = in.readInt();
                int numSamples = in.readInt();
                return new BayesTopicFeatureExtractor(lda, tokenizerFactory, symbolTable, features, burnIn, sampleLag, numSamples);
            }
        }
    }

    /**
     * Feature extractor that maps a text to topic proportions without
     * sampling: each in-vocabulary token adds its per-word topic
     * distribution to a vector that starts at the document-topic prior,
     * and the vector is then normalized to sum to one.
     */
    static class ExpectedTopicFeatureExtractor
    implements FeatureExtractor<CharSequence>,
    Serializable {
        static final long serialVersionUID = -7996546432550775177L;
        /** Topic distribution per word, indexed as {@code [word][topic]}. */
        private final double[][] mWordTopicProbs;
        private final double mDocTopicPrior;
        private final TokenizerFactory mTokenizerFactory;
        private final SymbolTable mSymbolTable;
        /** Feature name for each topic, indexed by topic. */
        private final String[] mFeatures;

        /**
         * Builds the extractor from a model by transposing its topic-word
         * probabilities and normalizing each word's row to sum to one.
         *
         * <p>A word whose probability is zero under every topic keeps an
         * all-zero row, so it contributes nothing to a document instead of
         * turning the whole row into NaN.
         *
         * @param lda model supplying word probabilities and the prior
         * @param tokenizerFactory factory used to tokenize input text
         * @param symbolTable table mapping tokens to word identifiers
         * @param featurePrefix prefix for the generated feature names
         */
        public ExpectedTopicFeatureExtractor(LatentDirichletAllocation lda, TokenizerFactory tokenizerFactory, SymbolTable symbolTable, String featurePrefix) {
            int numWords = lda.numWords();
            int numTopics = lda.numTopics();
            double[][] wordTopicProbs = new double[numWords][numTopics];
            for (int word = 0; word < numWords; ++word) {
                double[] topicProbs = wordTopicProbs[word];
                for (int topic = 0; topic < numTopics; ++topic) {
                    topicProbs[topic] = lda.wordProbability(topic, word);
                }
                double sum = Math.sum(topicProbs);
                // Only the exact-zero sum is skipped; dividing by it would give 0/0.
                if (sum != 0.0) {
                    for (int topic = 0; topic < numTopics; ++topic) {
                        topicProbs[topic] = topicProbs[topic] / sum;
                    }
                }
            }
            this.mWordTopicProbs = wordTopicProbs;
            this.mDocTopicPrior = lda.documentTopicPrior();
            this.mTokenizerFactory = tokenizerFactory;
            this.mSymbolTable = symbolTable;
            this.mFeatures = LatentDirichletAllocation.genFeatures(featurePrefix, numTopics);
        }

        /** Creates an extractor from already-normalized state; used when deserializing. */
        ExpectedTopicFeatureExtractor(double docTopicPrior, double[][] wordTopicProbs, TokenizerFactory tokenizerFactory, SymbolTable symbolTable, String[] features) {
            this.mWordTopicProbs = wordTopicProbs;
            this.mDocTopicPrior = docTopicPrior;
            this.mTokenizerFactory = tokenizerFactory;
            this.mSymbolTable = symbolTable;
            this.mFeatures = features;
        }

        /**
         * Tokenizes the text and returns the normalized topic vector,
         * keeping only strictly positive entries. Tokens whose symbol
         * identifier is negative or beyond the vocabulary are ignored.
         *
         * @param cSeq text to extract features from
         * @return map from feature name to topic proportion
         */
        @Override
        public Map<String, Double> features(CharSequence cSeq) {
            // One feature name exists per topic, so this is the topic count
            // even when the vocabulary is empty and there is no row to measure.
            int numTopics = this.mFeatures.length;
            char[] cs = Strings.toCharArray(cSeq);
            Tokenizer tokenizer = this.mTokenizerFactory.tokenizer(cs, 0, cs.length);
            double[] vals = new double[numTopics];
            Arrays.fill(vals, this.mDocTopicPrior);
            for (String token : tokenizer) {
                int symbol = this.mSymbolTable.symbolToID(token);
                if (symbol < 0 || symbol >= this.mWordTopicProbs.length) {
                    continue;
                }
                double[] topicProbs = this.mWordTopicProbs[symbol];
                for (int k = 0; k < numTopics; ++k) {
                    vals[k] = vals[k] + topicProbs[k];
                }
            }
            ObjectToDoubleMap<String> featMap = new ObjectToDoubleMap<String>(numTopics * 3 / 2);
            double sum = Math.sum(vals);
            for (int k = 0; k < numTopics; ++k) {
                if (vals[k] > 0.0) {
                    featMap.set(this.mFeatures[k], vals[k] / sum);
                }
            }
            return featMap;
        }

        /** Returns a {@link Serializer} proxy wrapping this extractor. */
        Object writeReplace() {
            return new Serializer(this);
        }

        /**
         * Externalizable proxy. Writes, in order: the prior, the number of
         * words, one double array per word, the tokenizer factory, the
         * symbol table and the feature names. Reading consumes the same
         * order.
         */
        static class Serializer
        extends AbstractExternalizable {
            static final long serialVersionUID = -1472744781627035426L;
            final ExpectedTopicFeatureExtractor mFeatures;

            /** No-argument constructor; leaves the wrapped extractor null. */
            public Serializer() {
                this(null);
            }

            public Serializer(ExpectedTopicFeatureExtractor features) {
                this.mFeatures = features;
            }

            @Override
            public void writeExternal(ObjectOutput out) throws IOException {
                out.writeDouble(this.mFeatures.mDocTopicPrior);
                out.writeInt(this.mFeatures.mWordTopicProbs.length);
                for (int w = 0; w < this.mFeatures.mWordTopicProbs.length; ++w) {
                    Serializer.writeDoubles(this.mFeatures.mWordTopicProbs[w], out);
                }
                out.writeObject(this.mFeatures.mTokenizerFactory);
                out.writeObject(this.mFeatures.mSymbolTable);
                Serializer.writeUTFs(this.mFeatures.mFeatures, out);
            }

            @Override
            public Object read(ObjectInput in) throws IOException, ClassNotFoundException {
                double docTopicPrior = in.readDouble();
                int numWords = in.readInt();
                double[][] wordTopicProbs = new double[numWords][];
                for (int w = 0; w < numWords; ++w) {
                    wordTopicProbs[w] = Serializer.readDoubles(in);
                }
                TokenizerFactory tokenizerFactory = (TokenizerFactory)in.readObject();
                SymbolTable symbolTable = (SymbolTable)in.readObject();
                String[] features = Serializer.readUTFs(in);
                return new ExpectedTopicFeatureExtractor(docTopicPrior, wordTopicProbs, tokenizerFactory, symbolTable, features);
            }
        }
    }

    public static class GibbsSample {
        /** Epoch number this sample was created with. */
        private final int mEpoch;
        /** Topic assigned to each token, indexed as {@code [doc][token]}. */
        private final short[][] mTopicSample;
        /** Word identifier of each token, indexed as {@code [doc][token]}. */
        private final int[][] mDocWords;
        private final double mDocTopicPrior;
        private final double mTopicWordPrior;
        /** Counts indexed as {@code [doc][topic]}. */
        private final int[][] mDocTopicCount;
        /** Counts indexed as {@code [word][topic]}. */
        private final int[][] mWordTopicCount;
        /** Counts indexed by topic; its length is reported as the number of topics. */
        private final int[] mTopicCount;
        private final int mNumChangedTopics;
        private final int mNumWords;
        private final int mNumTokens;

        /**
         * Stores the given state. All array arguments are kept by
         * reference; no copies are made, so later modification of those
         * arrays by the caller is visible through this sample.
         */
        GibbsSample(int epoch, short[][] topicSample, int[][] docWords, double docTopicPrior, double topicWordPrior, int[][] docTopicCount, int[][] wordTopicCount, int[] topicCount, int numChangedTopics, int numWords, int numTokens) {
            this.mEpoch = epoch;
            this.mTopicSample = topicSample;
            this.mDocWords = docWords;
            this.mDocTopicPrior = docTopicPrior;
            this.mTopicWordPrior = topicWordPrior;
            this.mDocTopicCount = docTopicCount;
            this.mWordTopicCount = wordTopicCount;
            this.mTopicCount = topicCount;
            this.mNumChangedTopics = numChangedTopics;
            this.mNumWords = numWords;
            this.mNumTokens = numTokens;
        }

        /** Returns the epoch number this sample was created with. */
        public int epoch() {
            return this.mEpoch;
        }

        /** Returns the number of documents, i.e. the length of the document-word array. */
        public int numDocuments() {
            return this.mDocWords.length;
        }

        /** Returns the number of distinct words this sample was created with. */
        public int numWords() {
            return this.mNumWords;
        }

        /** Returns the total token count this sample was created with. */
        public int numTokens() {
            return this.mNumTokens;
        }

        /** Returns the number of topics, i.e. the length of the per-topic count array. */
        public int numTopics() {
            return this.mTopicCount.length;
        }

        /**
         * Returns the topic assigned to a token.
         *
         * @param doc document index
         * @param token token position within the document
         * @return the topic stored at {@code [doc][token]}
         */
        public short topicSample(int doc, int token) {
            return this.mTopicSample[doc][token];
        }

        /**
         * Returns the word identifier of a token.
         *
         * @param doc document index
         * @param token token position within the document
         * @return the word identifier stored at {@code [doc][token]}
         */
        public int word(int doc, int token) {
            return this.mDocWords[doc][token];
        }

        /** Returns the document-topic prior this sample was created with. */
        public double documentTopicPrior() {
            return this.mDocTopicPrior;
        }

        /** Returns the topic-word prior this sample was created with. */
        public double topicWordPrior() {
            return this.mTopicWordPrior;
        }

        /**
         * Returns the stored count for a document and topic.
         *
         * @param doc document index
         * @param topic topic index
         * @return the count stored at {@code [doc][topic]}
         */
        public int documentTopicCount(int doc, int topic) {
            return this.mDocTopicCount[doc][topic];
        }

        /**
         * Returns the number of tokens in a document.
         *
         * @param doc document index
         * @return the length of that document's word array
         */
        public int documentLength(int doc) {
            return this.mDocWords[doc].length;
        }

        /**
         * Returns the stored count for a topic and word. Note the
         * underlying array is indexed word-first, the reverse of this
         * method's parameter order.
         *
         * @param topic topic index
         * @param word word index
         * @return the count stored at {@code [word][topic]}
         */
        public int topicWordCount(int topic, int word) {
            return this.mWordTopicCount[word][topic];
        }

        /**
         * Returns the stored total count for a topic.
         *
         * @param topic topic index
         * @return the count stored for that topic
         */
        public int topicCount(int topic) {
            return this.mTopicCount[topic];
        }

        /** Returns the changed-topic count this sample was created with. */
        public int numChangedTopics() {
            return this.mNumChangedTopics;
        }

        /**
         * Returns the smoothed probability estimate of a word in a topic:
         * {@code (topicWordCount + prior) / (topicCount + numWords * prior)},
         * where {@code prior} is {@link #topicWordPrior()}.
         *
         * @param topic topic index
         * @param word word identifier
         * @return the estimate for the word given the topic
         */
        public double topicWordProb(int topic, int word) {
            double numerator = topicWordCount(topic, word) + topicWordPrior();
            double denominator = topicCount(topic) + numWords() * topicWordPrior();
            return numerator / denominator;
        }

        /**
         * Returns the total count of a word, summed over every topic.
         *
         * @param word word identifier
         * @return the sum of {@link #topicWordCount(int, int)} across all topics
         */
        public int wordCount(int word) {
            int total = 0;
            for (int k = 0; k < numTopics(); ++k) {
                total += topicWordCount(k, word);
            }
            return total;
        }

        /**
         * Returns the smoothed probability estimate of a topic in a document:
         * {@code (documentTopicCount + prior) / (documentLength + numTopics * prior)},
         * where {@code prior} is {@link #documentTopicPrior()}.
         *
         * @param doc document index
         * @param topic topic index
         * @return the estimate for the topic given the document
         */
        public double documentTopicProb(int doc, int topic) {
            double numerator = documentTopicCount(doc, topic) + documentTopicPrior();
            double denominator = documentLength(doc) + numTopics() * documentTopicPrior();
            return numerator / denominator;
        }

        /**
         * Returns the sum, over every token of every document, of
         * {@code Math.log2} of that token's word probability.
         *
         * <p>A token's word probability is the sum over topics of
         * {@link #topicWordProb(int, int)} times
         * {@link #documentTopicProb(int, int)}.
         *
         * @return the accumulated log (base 2) probability of the corpus
         */
        public double corpusLog2Probability() {
            final int docTotal = numDocuments();
            final int topicTotal = numTopics();
            double log2Sum = 0.0;
            for (int d = 0; d < docTotal; ++d) {
                final int length = documentLength(d);
                for (int t = 0; t < length; ++t) {
                    final int w = word(d, t);
                    // Marginalize the token's word over all topics.
                    double tokenProb = 0.0;
                    for (int k = 0; k < topicTotal; ++k) {
                        tokenProb += topicWordProb(k, w) * documentTopicProb(d, k);
                    }
                    log2Sum += Math.log2(tokenProb);
                }
            }
            return log2Sum;
        }

        /**
         * Builds an LDA model from this sample.
         *
         * <p>Each topic-word probability is
         * {@code (topicWordCount + prior) / (topicCount + numWords * prior)},
         * with {@code prior} being {@link #topicWordPrior()}. The model is
         * given this sample's document-topic prior.
         *
         * @return a new model based on this sample's counts
         */
        public LatentDirichletAllocation lda() {
            final int topicTotal = numTopics();
            final int vocabSize = numWords();
            final double prior = topicWordPrior();
            final double[][] probs = new double[topicTotal][vocabSize];
            for (int k = 0; k < topicTotal; ++k) {
                // The normalizer is constant across the words of a topic.
                final double norm = topicCount(k) + vocabSize * prior;
                final double[] row = probs[k];
                for (int w = 0; w < vocabSize; ++w) {
                    row[w] = (topicWordCount(k, w) + prior) / norm;
                }
            }
            return new LatentDirichletAllocation(mDocTopicPrior, probs);
        }
    }

    /**
     * Buffered iterator producing {@link GibbsSample}s over a fixed corpus.
     *
     * <p>The constructor assigns every token a topic drawn with
     * {@code Random.nextInt(numTopics)} and tallies the matching counts.
     * Each call to {@link #bufferNext()} then resamples the topic of every
     * token once, in document and token order.
     */
    static class SampleIterator
    extends Iterators.Buffered<GibbsSample> {
        private final int[][] mDocWords;
        private final short mNumTopics;
        private final double mDocTopicPrior;
        private final double mTopicWordPrior;
        private final Random mRandom;
        private final int mNumDocs;
        private final int mNumWords;
        private final int mNumTokens;
        private final short[][] mCurrentSample;
        private final int[][] mDocTopicCount;
        private final int[][] mWordTopicCount;
        private final int[] mTopicTotalCount;
        private int mNumChangedTopics;
        private int mEpoch = 0;

        /**
         * Sets up the sampler state and draws the initial topic assignments.
         *
         * @param docWords word identifiers indexed {@code [doc][token]}; stored by reference
         * @param numTopics number of topics to sample from
         * @param docTopicPrior document-topic prior used in the sampling weights
         * @param topicWordPrior topic-word prior used in the sampling weights
         * @param random source of randomness for all draws
         */
        SampleIterator(int[][] docWords, short numTopics, double docTopicPrior, double topicWordPrior, Random random) {
            mDocWords = docWords;
            mNumTopics = numTopics;
            mDocTopicPrior = docTopicPrior;
            mTopicWordPrior = topicWordPrior;
            mRandom = random;
            mNumDocs = docWords.length;
            // Word identifiers index the count table, so its size is max id + 1.
            mNumWords = LatentDirichletAllocation.max(docWords) + 1;

            // Size the per-document assignment arrays and count the tokens.
            mCurrentSample = new short[mNumDocs][];
            int tokenTotal = 0;
            for (int d = 0; d < mNumDocs; ++d) {
                int length = docWords[d].length;
                tokenTotal += length;
                mCurrentSample[d] = new short[length];
            }
            mNumTokens = tokenTotal;
            // Before any sweep, the changed count starts at the token count.
            mNumChangedTopics = tokenTotal;

            mDocTopicCount = new int[mNumDocs][numTopics];
            mWordTopicCount = new int[mNumWords][numTopics];
            mTopicTotalCount = new int[numTopics];

            // Initial assignment: one random draw per token, in order.
            for (int d = 0; d < mNumDocs; ++d) {
                int[] words = docWords[d];
                for (int t = 0; t < words.length; ++t) {
                    int assigned = mRandom.nextInt(numTopics);
                    mCurrentSample[d][t] = (short) assigned;
                    ++mDocTopicCount[d][assigned];
                    ++mWordTopicCount[words[t]][assigned];
                    ++mTopicTotalCount[assigned];
                }
            }
        }

        /**
         * Creates the sample for the current epoch, then resamples the topic
         * of every token once.
         *
         * <p>The returned sample is built before the sweep, but it holds
         * references to (not copies of) the assignment and count arrays,
         * which the sweep updates in place. Its epoch and changed-topic
         * count are the values held before the sweep.
         *
         * @return the sample constructed at the start of this call
         */
        @Override
        protected GibbsSample bufferNext() {
            GibbsSample snapshot = new GibbsSample(mEpoch, mCurrentSample, mDocWords, mDocTopicPrior, mTopicWordPrior, mDocTopicCount, mWordTopicCount, mTopicTotalCount, mNumChangedTopics, mNumWords, mNumTokens);
            ++mEpoch;

            final double vocabPriorMass = (double) mNumWords * mTopicWordPrior;
            final double[] cumulative = new double[mNumTopics];
            int changed = 0;

            for (int d = 0; d < mNumDocs; ++d) {
                final int[] words = mDocWords[d];
                final short[] assignments = mCurrentSample[d];
                final int[] docCounts = mDocTopicCount[d];

                for (int t = 0; t < words.length; ++t) {
                    final int[] wordCounts = mWordTopicCount[words[t]];
                    final int oldTopic = assignments[t];

                    // Running (unnormalized) cumulative weights over topics.
                    // The token's own current assignment is removed from the
                    // counts of its current topic.
                    for (int k = 0; k < mNumTopics; ++k) {
                        double mass = (k == oldTopic)
                            ? topicMass(docCounts[k] - 1.0, wordCounts[k] - 1.0, mTopicTotalCount[k] - 1.0, vocabPriorMass)
                            : topicMass(docCounts[k], wordCounts[k], mTopicTotalCount[k], vocabPriorMass);
                        cumulative[k] = (k == 0) ? mass : mass + cumulative[k - 1];
                    }

                    final int newTopic = Statistics.sample(cumulative, mRandom);
                    if (newTopic == oldTopic) {
                        continue;
                    }

                    // Move the token's contribution from the old topic to the new one.
                    assignments[t] = (short) newTopic;
                    --docCounts[oldTopic];
                    --wordCounts[oldTopic];
                    --mTopicTotalCount[oldTopic];
                    ++docCounts[newTopic];
                    ++wordCounts[newTopic];
                    ++mTopicTotalCount[newTopic];
                    ++changed;
                }
            }
            mNumChangedTopics = changed;
            return snapshot;
        }

        /**
         * Computes one topic's unnormalized sampling weight:
         * {@code (docCount + docTopicPrior) * (wordCount + topicWordPrior)
         * / (topicTotal + vocabPriorMass)}.
         */
        private double topicMass(double docCount, double wordCount, double topicTotal, double vocabPriorMass) {
            return (docCount + mDocTopicPrior) * (wordCount + mTopicWordPrior) / (topicTotal + vocabPriorMass);
        }
    }

    /**
     * Externalizable proxy for {@link LatentDirichletAllocation}.
     *
     * <p>The written form is the document-topic prior, then the number of
     * topics, then one array of doubles per topic.
     */
    static class Serializer
    extends AbstractExternalizable {
        static final long serialVersionUID = 4725870665020270825L;
        final LatentDirichletAllocation mLda;

        /** Creates a proxy with no model. */
        public Serializer() {
            this(null);
        }

        /**
         * Creates a proxy wrapping the given model.
         *
         * @param lda the model whose state {@link #writeExternal(ObjectOutput)} writes
         */
        public Serializer(LatentDirichletAllocation lda) {
            mLda = lda;
        }

        /**
         * Reads a model in the order {@link #writeExternal(ObjectOutput)}
         * wrote it.
         *
         * @param in source of the serialized data
         * @return a new {@link LatentDirichletAllocation} built from the data read
         * @throws IOException if reading from {@code in} fails
         */
        @Override
        public Object read(ObjectInput in) throws IOException {
            final double docTopicPrior = in.readDouble();
            final int topicTotal = in.readInt();
            final double[][] topicWordProbs = new double[topicTotal][];
            for (int k = 0; k < topicWordProbs.length; ++k) {
                topicWordProbs[k] = readDoubles(in);
            }
            return new LatentDirichletAllocation(docTopicPrior, topicWordProbs);
        }

        /**
         * Writes the wrapped model's prior, topic count and per-topic
         * word-probability arrays.
         *
         * @param out destination for the serialized data
         * @throws IOException if writing to {@code out} fails
         */
        @Override
        public void writeExternal(ObjectOutput out) throws IOException {
            out.writeDouble(mLda.mDocTopicPrior);
            out.writeInt(mLda.mTopicWordProbs.length);
            for (double[] wordProbs : mLda.mTopicWordProbs) {
                writeDoubles(wordProbs, out);
            }
        }
    }
}

