/*
 * Decompiled with CFR 0.152.
 */
package cc.mallet.topics;

import cc.mallet.topics.MarginalProbEstimator;
import cc.mallet.topics.TopicAssignment;
import cc.mallet.topics.TopicInferencer;
import cc.mallet.topics.WorkerRunnable;
import cc.mallet.types.Alphabet;
import cc.mallet.types.AugmentableFeatureVector;
import cc.mallet.types.Dirichlet;
import cc.mallet.types.FeatureSequence;
import cc.mallet.types.FeatureSequenceWithBigrams;
import cc.mallet.types.IDSorter;
import cc.mallet.types.Instance;
import cc.mallet.types.InstanceList;
import cc.mallet.types.LabelAlphabet;
import cc.mallet.types.LabelSequence;
import cc.mallet.types.MatrixOps;
import cc.mallet.types.RankedFeatureVector;
import cc.mallet.util.MalletLogger;
import cc.mallet.util.Randoms;
import gnu.trove.TObjectIntHashMap;
import java.io.BufferedOutputStream;
import java.io.BufferedReader;
import java.io.ByteArrayOutputStream;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.FileWriter;
import java.io.IOException;
import java.io.InputStreamReader;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.PrintStream;
import java.io.PrintWriter;
import java.io.Serializable;
import java.text.NumberFormat;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Formatter;
import java.util.Iterator;
import java.util.Locale;
import java.util.TreeSet;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.logging.Logger;
import java.util.zip.GZIPInputStream;
import java.util.zip.GZIPOutputStream;

public class ParallelTopicModel
implements Serializable {
    /** Topic value meaning "no topic assigned"; tokens carrying it are skipped when the count tables are built. */
    public static final int UNASSIGNED_TOPIC = -1;
    public static Logger logger = MalletLogger.getLogger(ParallelTopicModel.class.getName());
    /** Training documents, each paired with its per-token topic assignments. */
    public ArrayList<TopicAssignment> data = new ArrayList<TopicAssignment>();
    /** Data alphabet of the training instances; set by addInstances. */
    public Alphabet alphabet;
    /** Alphabet whose entries name the topics. */
    public LabelAlphabet topicAlphabet;
    public int numTopics;
    /** Bit mask that extracts the topic id from a packed type-topic entry (see setNumTopics). */
    public int topicMask;
    /** Number of low-order bits of a packed type-topic entry that hold the topic id. */
    public int topicBits;
    /** Size of the data alphabet. */
    public int numTypes;
    /** Sum of all document lengths, computed by initializeHistograms. */
    public int totalTokens;
    /** Per-topic Dirichlet parameters over topics. */
    public double[] alpha;
    /** Sum of the entries of {@link #alpha}. */
    public double alphaSum;
    public double beta;
    /** {@code beta} multiplied by the number of types. */
    public double betaSum;
    /** When true, optimizeAlpha learns a single shared concentration instead of one value per topic. */
    public boolean usingSymmetricAlpha = false;
    public static final double DEFAULT_BETA = 0.01;
    /**
     * For each type, an array of packed entries {@code (count << topicBits) + topic},
     * kept in descending order; a zero entry marks the end of the used slots.
     */
    public int[][] typeTopicCounts;
    /** Number of tokens currently assigned to each topic. */
    public int[] tokensPerTopic;
    /** Histogram indexed by document length; filled from the workers by optimizeAlpha. */
    public int[] docLengthCounts;
    /** Per-topic histograms indexed by count; filled from the workers by optimizeAlpha. */
    public int[][] topicDocCounts;
    /** Number of iterations run by estimate(). */
    public int numIterations = 1000;
    /** Iterations that must pass before estimate() collects alpha statistics or optimizes hyperparameters. */
    public int burninPeriod = 200;
    /** After burn-in, alpha statistics are collected every this many iterations. */
    public int saveSampleInterval = 10;
    /** After burn-in, alpha and beta are re-estimated every this many iterations; 0 disables optimization. */
    public int optimizeInterval = 50;
    public int temperingInterval = 0;
    /** Top words are logged every this many iterations; 0 disables the display. */
    public int showTopicsInterval = 50;
    /** Number of words per topic requested from displayTopWords. */
    public int wordsPerTopic = 7;
    /** The sampling state is written every this many iterations; 0 disables it. */
    public int saveStateInterval = 0;
    /** Prefix of the state files; ".<iteration>" is appended. */
    public String stateFilename = null;
    /** The serialized model is written every this many iterations; 0 disables it. */
    public int saveModelInterval = 0;
    /** Prefix of the serialized model files; ".<iteration>" is appended. */
    public String modelFilename = null;
    /** Seed for the random number generators; -1 selects an unseeded generator. */
    public int randomSeed = -1;
    public NumberFormat formatter;
    /** Whether estimate() logs the per-token log likelihood every 10 iterations. */
    public boolean printLogLikelihood = true;
    /** Number of tokens of each type in the training data. */
    int[] typeTotals;
    /** Largest entry of {@link #typeTotals}. */
    int maxTypeCount;
    /** Number of workers (and pool threads) used by estimate(). */
    int numThreads = 1;
    private static final long serialVersionUID = 1L;
    private static final int CURRENT_SERIAL_VERSION = 0;
    private static final int NULL_INTEGER = -1;

    /**
     * Creates a model with the given number of topics, an alpha sum equal to
     * the number of topics, and the default beta.
     *
     * @param numberOfTopics number of topics
     */
    public ParallelTopicModel(int numberOfTopics) {
        this(numberOfTopics, (double) numberOfTopics, DEFAULT_BETA);
    }

    /**
     * Creates a model whose topics are named "topic0", "topic1", ... in a
     * freshly built label alphabet.
     *
     * @param numberOfTopics number of topics
     * @param alphaSum       sum of the per-topic alpha values
     * @param beta           beta value
     */
    public ParallelTopicModel(int numberOfTopics, double alphaSum, double beta) {
        this(newLabelAlphabet(numberOfTopics), alphaSum, beta);
    }

    /**
     * Builds a label alphabet by looking up the entries "topic0" through
     * "topic(numTopics - 1)", in that order.
     *
     * @param numTopics number of topic labels to create
     * @return the populated alphabet
     */
    private static LabelAlphabet newLabelAlphabet(int numTopics) {
        LabelAlphabet labels = new LabelAlphabet();
        for (int topic = 0; topic < numTopics; topic++) {
            labels.lookupIndex("topic" + topic);
        }
        return labels;
    }

    public ParallelTopicModel(LabelAlphabet topicAlphabet, double alphaSum, double beta) {
        this.topicAlphabet = topicAlphabet;
        this.alphaSum = alphaSum;
        this.beta = beta;
        this.setNumTopics(topicAlphabet.size());
        this.formatter = NumberFormat.getInstance();
        this.formatter.setMaximumFractionDigits(5);
        logger.info("Mallet LDA: " + this.numTopics + " topics, " + this.topicBits + " topic bits, " + Integer.toBinaryString(this.topicMask) + " topic mask");
    }

    /** @return the data alphabet stored by addInstances, or null if none has been set */
    public Alphabet getAlphabet() {
        return alphabet;
    }

    /** @return the alphabet whose entries name the topics */
    public LabelAlphabet getTopicAlphabet() {
        return topicAlphabet;
    }

    /** @return the current number of topics */
    public int getNumTopics() {
        return numTopics;
    }

    /**
     * Sets the number of topics and rebuilds everything derived from it: the
     * topic bit mask and bit width used to pack type-topic entries, a uniform
     * alpha vector ({@code alphaSum / numTopics} per topic), and a zeroed
     * tokens-per-topic array.
     *
     * @param numTopics new number of topics
     */
    public void setNumTopics(int numTopics) {
        this.numTopics = numTopics;

        // An exact power of two needs only (numTopics - 1) as its mask;
        // otherwise the mask covers the next power of two above numTopics.
        boolean exactPowerOfTwo = Integer.bitCount(numTopics) == 1;
        topicMask = exactPowerOfTwo
                ? numTopics - 1
                : Integer.highestOneBit(numTopics) * 2 - 1;
        topicBits = Integer.bitCount(topicMask);

        alpha = new double[numTopics];
        Arrays.fill(alpha, alphaSum / (double) numTopics);
        tokensPerTopic = new int[numTopics];
    }

    /** @return the live list of documents with their topic assignments (not a copy) */
    public ArrayList<TopicAssignment> getData() {
        return data;
    }

    /** @return the live per-type arrays of packed topic counts (not a copy) */
    public int[][] getTypeTopicCounts() {
        return typeTopicCounts;
    }

    /** @return the live array of token totals per topic (not a copy) */
    public int[] getTokensPerTopic() {
        return tokensPerTopic;
    }

    /**
     * Sets how many iterations {@code estimate()} runs.
     *
     * @param numIterations number of iterations
     */
    public void setNumIterations(int numIterations) {
        this.numIterations = numIterations;
    }

    /**
     * Sets the number of iterations that must pass before {@code estimate()}
     * starts collecting alpha statistics and optimizing hyperparameters.
     *
     * @param burninPeriod burn-in length in iterations
     */
    public void setBurninPeriod(int burninPeriod) {
        this.burninPeriod = burninPeriod;
    }

    /**
     * Configures the periodic top-word display of {@code estimate()}.
     *
     * @param interval log the top words every this many iterations; 0 disables the display
     * @param n        number of words to show per topic
     */
    public void setTopicDisplay(int interval, int n) {
        showTopicsInterval = interval;
        wordsPerTopic = n;
    }

    /**
     * Sets the seed handed to the random number generators this model creates.
     *
     * @param seed the seed; -1 selects an unseeded generator
     */
    public void setRandomSeed(int seed) {
        randomSeed = seed;
    }

    /**
     * Sets how often hyperparameters are optimized and lowers the sample
     * interval to the same value if it was larger.
     *
     * @param interval optimization interval in iterations
     */
    public void setOptimizeInterval(int interval) {
        optimizeInterval = interval;
        // Never sample less often than we optimize.
        saveSampleInterval = Math.min(saveSampleInterval, optimizeInterval);
    }

    /**
     * Chooses whether {@code optimizeAlpha} learns one shared concentration
     * (true) or a separate alpha per topic (false).
     *
     * @param b true for a symmetric alpha
     */
    public void setSymmetricAlpha(boolean b) {
        usingSymmetricAlpha = b;
    }

    /**
     * Stores the tempering interval.
     *
     * @param interval the interval value
     */
    public void setTemperingInterval(int interval) {
        temperingInterval = interval;
    }

    /**
     * Sets the number of workers (and pool threads) {@code estimate()} uses.
     *
     * @param threads number of threads
     */
    public void setNumThreads(int threads) {
        numThreads = threads;
    }

    /**
     * Configures periodic saving of the sampling state during {@code estimate()}.
     *
     * @param interval write the state every this many iterations; 0 disables it
     * @param filename file name prefix; "." and the iteration number are appended
     */
    public void setSaveState(int interval, String filename) {
        saveStateInterval = interval;
        stateFilename = filename;
    }

    /**
     * Configures periodic saving of the serialized model during {@code estimate()}.
     *
     * @param interval write the model every this many iterations; 0 disables it
     * @param filename file name prefix; "." and the iteration number are appended
     */
    public void setSaveSerializedModel(int interval, String filename) {
        saveModelInterval = interval;
        modelFilename = filename;
    }

    /**
     * Loads training documents into the model. The data alphabet, type count
     * and beta sum are taken from the instance list; every token of every
     * document receives a topic drawn from {@code random.nextInt(numTopics)};
     * then the count tables and histograms are rebuilt.
     *
     * @param training instances whose data are {@code FeatureSequence}s
     */
    public void addInstances(InstanceList training) {
        alphabet = training.getDataAlphabet();
        numTypes = alphabet.size();
        betaSum = beta * (double) numTypes;

        // A seed of -1 means "do not seed".
        Randoms random;
        if (randomSeed == -1) {
            random = new Randoms();
        } else {
            random = new Randoms(randomSeed);
        }

        for (Instance instance : training) {
            FeatureSequence tokens = (FeatureSequence) instance.getData();
            LabelSequence topicSequence = new LabelSequence(topicAlphabet, new int[tokens.size()]);
            int[] topics = topicSequence.getFeatures();
            for (int position = 0; position < topics.length; position++) {
                topics[position] = random.nextInt(numTopics);
            }
            data.add(new TopicAssignment(instance, topicSequence));
        }

        buildInitialTypeTopicCounts();
        initializeHistograms();
    }

    /**
     * Restores hyperparameters and topic assignments from a gzipped state file.
     *
     * <p>Leading lines that start with "#" are headers: "#alpha : " carries the
     * space-separated per-topic alpha values (which also fixes the number of
     * topics) and "#beta : " carries beta. Every following line describes one
     * token; its fourth field is the type index and its sixth field is the
     * topic. Token lines are consumed in the order of the documents already
     * held in {@code data}, after which the count tables and histograms are
     * rebuilt.
     *
     * <p>The file is always closed, including on failure.
     *
     * @param stateFile gzipped state file to read
     * @throws IOException if the file cannot be opened or read
     * @throws IllegalStateException if a token line does not match the type at
     *         the same position in the loaded instances, or if the file ends
     *         before every token has been given a topic
     */
    public void initializeFromState(File stateFile) throws IOException {
        FileInputStream fileStream = new FileInputStream(stateFile);
        try {
            BufferedReader reader = new BufferedReader(new InputStreamReader(new GZIPInputStream(fileStream)));
            try {
                String line = reader.readLine();

                // Header section; the null check covers an empty or header-only file.
                while (line != null && line.startsWith("#")) {
                    if (line.startsWith("#alpha : ")) {
                        String[] alphaFields = line.replace("#alpha : ", "").split(" ");
                        setNumTopics(alphaFields.length);
                        alphaSum = 0.0;
                        for (int topic = 0; topic < alphaFields.length; topic++) {
                            alpha[topic] = Double.parseDouble(alphaFields[topic]);
                            alphaSum += alpha[topic];
                        }
                    } else if (line.startsWith("#beta : ")) {
                        beta = Double.parseDouble(line.replace("#beta : ", ""));
                        betaSum = beta * (double) numTypes;
                    }
                    line = reader.readLine();
                }

                // Token section: one line per token, in document order.
                for (TopicAssignment document : data) {
                    FeatureSequence tokens = (FeatureSequence) document.instance.getData();
                    int[] topics = document.topicSequence.getFeatures();
                    for (int position = 0; position < tokens.size(); position++) {
                        if (line == null) {
                            // Fail loudly rather than reuse the previous row.
                            throw new IllegalStateException("state file " + stateFile + " ended before every token was assigned a topic");
                        }
                        String[] fields = line.split(" ");
                        int type = tokens.getIndexAtPosition(position);
                        if (type != Integer.parseInt(fields[3])) {
                            throw new IllegalStateException("instance list and state do not match: " + line);
                        }
                        topics[position] = Integer.parseInt(fields[5]);
                        line = reader.readLine();
                    }
                }
            } finally {
                reader.close();
            }
        } finally {
            // Also covers a failure while the gzip/reader wrappers are being built.
            fileStream.close();
        }

        buildInitialTypeTopicCounts();
        initializeHistograms();
    }

    /**
     * Rebuilds {@code typeTopicCounts}, {@code tokensPerTopic},
     * {@code typeTotals} and {@code maxTypeCount} from the current topic
     * assignments in {@code data}.
     *
     * <p>Each type gets an array with {@code min(numTopics, tokens of that type)}
     * slots. A slot packs {@code (count << topicBits) + topic}; slots are kept
     * in descending order and an empty slot is 0. Tokens whose topic is
     * {@code UNASSIGNED_TOPIC} are not counted.
     */
    public void buildInitialTypeTopicCounts() {
        typeTopicCounts = new int[numTypes][];
        tokensPerTopic = new int[numTopics];
        typeTotals = new int[numTypes];

        // Pass 1: count how often each type occurs.
        for (TopicAssignment document : data) {
            FeatureSequence tokens = (FeatureSequence) document.instance.getData();
            for (int position = 0; position < tokens.getLength(); position++) {
                typeTotals[tokens.getIndexAtPosition(position)]++;
            }
        }

        // Pass 2: size each type's slot array and track the largest type total.
        maxTypeCount = 0;
        for (int type = 0; type < numTypes; type++) {
            if (typeTotals[type] > maxTypeCount) {
                maxTypeCount = typeTotals[type];
            }
            typeTopicCounts[type] = new int[Math.min(numTopics, typeTotals[type])];
        }

        // Pass 3: tally every assigned token into the packed tables.
        for (TopicAssignment document : data) {
            FeatureSequence tokens = (FeatureSequence) document.instance.getData();
            int[] topics = document.topicSequence.getFeatures();
            for (int position = 0; position < tokens.size(); position++) {
                int topic = topics[position];
                if (topic == UNASSIGNED_TOPIC) {
                    continue;
                }
                tokensPerTopic[topic]++;

                int type = tokens.getIndexAtPosition(position);
                int[] packed = typeTopicCounts[type];

                // Find the slot already holding this topic, or the first empty slot.
                int slot = 0;
                int slotTopic = packed[slot] & topicMask;
                while (packed[slot] > 0 && slotTopic != topic) {
                    slot++;
                    if (slot == packed.length) {
                        logger.info("overflow on type " + type);
                    }
                    slotTopic = packed[slot] & topicMask;
                }

                int slotCount = packed[slot] >> topicBits;
                if (slotCount == 0) {
                    // First token of this topic for the type.
                    packed[slot] = (1 << topicBits) + topic;
                } else {
                    packed[slot] = ((slotCount + 1) << topicBits) + topic;
                    // Bubble the grown entry toward the front to keep descending order.
                    while (slot > 0 && packed[slot] > packed[slot - 1]) {
                        int swapped = packed[slot];
                        packed[slot] = packed[slot - 1];
                        packed[slot - 1] = swapped;
                        slot--;
                    }
                }
            }
        }
    }

    /**
     * Replaces the model's topic totals and packed type-topic counts with the
     * sum of the corresponding tables of the first {@code numThreads} workers.
     *
     * <p>The model's tables are cleared first. Each worker entry is then
     * unpacked into (topic, count), added to the matching slot of the model's
     * array for that type (or to the first empty slot), and bubbled forward so
     * that the array stays in descending order.
     *
     * @param runnables workers whose counts are merged
     */
    public void sumTypeTopicCounts(WorkerRunnable[] runnables) {
        // Reset the global tables; each used prefix ends at the first non-positive slot.
        Arrays.fill(tokensPerTopic, 0);
        for (int type = 0; type < numTypes; type++) {
            int[] cleared = typeTopicCounts[type];
            for (int slot = 0; slot < cleared.length && cleared[slot] > 0; slot++) {
                cleared[slot] = 0;
            }
        }

        for (int thread = 0; thread < numThreads; thread++) {
            WorkerRunnable worker = runnables[thread];

            int[] sourceTotals = worker.getTokensPerTopic();
            for (int topic = 0; topic < numTopics; topic++) {
                tokensPerTopic[topic] += sourceTotals[topic];
            }

            int[][] sourceTypeTopicCounts = worker.getTypeTopicCounts();
            for (int type = 0; type < numTypes; type++) {
                int[] sourceCounts = sourceTypeTopicCounts[type];
                int[] targetCounts = typeTopicCounts[type];

                for (int sourceIndex = 0;
                        sourceIndex < sourceCounts.length && sourceCounts[sourceIndex] > 0;
                        sourceIndex++) {
                    int topic = sourceCounts[sourceIndex] & topicMask;
                    int count = sourceCounts[sourceIndex] >> topicBits;

                    // Locate this topic's slot in the target, or the first empty one.
                    int targetIndex = 0;
                    int targetTopic = targetCounts[targetIndex] & topicMask;
                    while (targetCounts[targetIndex] > 0 && targetTopic != topic) {
                        targetIndex++;
                        if (targetIndex == targetCounts.length) {
                            logger.info("overflow in merging on type " + type);
                        }
                        targetTopic = targetCounts[targetIndex] & topicMask;
                    }

                    int merged = (targetCounts[targetIndex] >> topicBits) + count;
                    targetCounts[targetIndex] = (merged << topicBits) + topic;

                    // Restore descending order.
                    while (targetIndex > 0 && targetCounts[targetIndex] > targetCounts[targetIndex - 1]) {
                        int swapped = targetCounts[targetIndex];
                        targetCounts[targetIndex] = targetCounts[targetIndex - 1];
                        targetCounts[targetIndex - 1] = swapped;
                        targetIndex--;
                    }
                }
            }
        }
    }

    /**
     * Computes the total token count and the longest document length, logs
     * both, and allocates the histograms used for alpha optimization with
     * {@code maxTokens + 1} buckets each.
     */
    private void initializeHistograms() {
        int maxTokens = 0;
        totalTokens = 0;

        for (TopicAssignment document : data) {
            int seqLen = ((FeatureSequence) document.instance.getData()).getLength();
            if (seqLen > maxTokens) {
                maxTokens = seqLen;
            }
            totalTokens += seqLen;
        }

        logger.info("max tokens: " + maxTokens);
        logger.info("total tokens: " + totalTokens);

        int buckets = maxTokens + 1;
        docLengthCounts = new int[buckets];
        topicDocCounts = new int[numTopics][buckets];
    }

    /*
     * Unable to fully structure code
     */
    public void optimizeAlpha(WorkerRunnable[] runnables) {
        block17: {
            Arrays.fill(this.docLengthCounts, 0);
            topic = 0;
            while (topic < this.topicDocCounts.length) {
                Arrays.fill(this.topicDocCounts[topic], 0);
                ++topic;
            }
            thread = 0;
            while (thread < this.numThreads) {
                sourceLengthCounts = runnables[thread].getDocLengthCounts();
                sourceTopicCounts = runnables[thread].getTopicDocCounts();
                count = 0;
                while (count < sourceLengthCounts.length) {
                    if (sourceLengthCounts[count] > 0) {
                        v0 = count;
                        this.docLengthCounts[v0] = this.docLengthCounts[v0] + sourceLengthCounts[count];
                        sourceLengthCounts[count] = 0;
                    }
                    ++count;
                }
                topic = 0;
                while (topic < this.numTopics) {
                    if (!this.usingSymmetricAlpha) {
                        count = 0;
                        while (count < sourceTopicCounts[topic].length) {
                            if (sourceTopicCounts[topic][count] > 0) {
                                v1 = this.topicDocCounts[topic];
                                v2 = count;
                                v1[v2] = v1[v2] + sourceTopicCounts[topic][count];
                                sourceTopicCounts[topic][count] = 0;
                            }
                            ++count;
                        }
                    } else {
                        count = 0;
                        while (count < sourceTopicCounts[topic].length) {
                            if (sourceTopicCounts[topic][count] > 0) {
                                v3 = this.topicDocCounts[0];
                                v4 = count;
                                v3[v4] = v3[v4] + sourceTopicCounts[topic][count];
                                sourceTopicCounts[topic][count] = 0;
                            }
                            ++count;
                        }
                    }
                    ++topic;
                }
                ++thread;
            }
            if (this.usingSymmetricAlpha) {
                this.alphaSum = Dirichlet.learnSymmetricConcentration(this.topicDocCounts[0], this.docLengthCounts, this.numTopics, this.alphaSum);
                topic = 0;
                while (topic < this.numTopics) {
                    this.alpha[topic] = this.alphaSum / (double)this.numTopics;
                    ++topic;
                }
            } else {
                try {
                    this.alphaSum = Dirichlet.learnParameters(this.alpha, this.topicDocCounts, this.docLengthCounts, 1.001, 1.0, 1);
                    break block17;
                }
                catch (RuntimeException e) {
                    ParallelTopicModel.logger.warning("Dirichlet optimization has become unstable. Resetting to alpha_t = 1.0.");
                    this.alphaSum = this.numTopics;
                    topic = 0;
                    ** while (topic < this.numTopics)
                }
lbl-1000:
                // 1 sources

                {
                    this.alpha[topic] = 1.0;
                    ++topic;
                    continue;
lbl64:
                    // 1 sources

                    break;
                }
            }
        }
    }

    /**
     * Discards all collected alpha statistics and resets every alpha to 1.0.
     *
     * <p>The model's histograms are zeroed, every positive bucket in the
     * histograms of the first {@code numThreads} workers is zeroed without
     * being accumulated, and the alpha sum becomes the number of topics.
     *
     * @param runnables workers whose histograms are cleared
     */
    public void temperAlpha(WorkerRunnable[] runnables) {
        Arrays.fill(docLengthCounts, 0);
        for (int[] topicHistogram : topicDocCounts) {
            Arrays.fill(topicHistogram, 0);
        }

        for (int thread = 0; thread < numThreads; thread++) {
            int[] sourceLengthCounts = runnables[thread].getDocLengthCounts();
            int[][] sourceTopicCounts = runnables[thread].getTopicDocCounts();

            for (int bucket = 0; bucket < sourceLengthCounts.length; bucket++) {
                if (sourceLengthCounts[bucket] > 0) {
                    sourceLengthCounts[bucket] = 0;
                }
            }

            for (int topic = 0; topic < numTopics; topic++) {
                for (int bucket = 0; bucket < sourceTopicCounts[topic].length; bucket++) {
                    if (sourceTopicCounts[topic][bucket] > 0) {
                        sourceTopicCounts[topic][bucket] = 0;
                    }
                }
            }
        }

        for (int topic = 0; topic < numTopics; topic++) {
            alpha[topic] = 1.0;
        }
        alphaSum = numTopics;
    }

    /**
     * Re-estimates beta from the current type-topic counts.
     *
     * <p>Two histograms are built: how many (type, topic) entries have each
     * count value, and how many topics have each total size. These are passed
     * to {@code Dirichlet.learnSymmetricConcentration} to obtain the new beta
     * sum; beta is that sum divided by the number of types. The new values are
     * logged and pushed to the first {@code numThreads} workers.
     *
     * @param runnables workers that receive the updated beta
     */
    public void optimizeBeta(WorkerRunnable[] runnables) {
        // Histogram of per-(type, topic) counts, read from the packed entries.
        int[] countHistogram = new int[maxTypeCount + 1];
        for (int type = 0; type < numTypes; type++) {
            int[] packed = typeTopicCounts[type];
            for (int slot = 0; slot < packed.length && packed[slot] > 0; slot++) {
                countHistogram[packed[slot] >> topicBits]++;
            }
        }

        // Histogram of topic sizes; its length depends on the largest topic.
        int maxTopicSize = 0;
        for (int topic = 0; topic < numTopics; topic++) {
            if (tokensPerTopic[topic] > maxTopicSize) {
                maxTopicSize = tokensPerTopic[topic];
            }
        }
        int[] topicSizeHistogram = new int[maxTopicSize + 1];
        for (int topic = 0; topic < numTopics; topic++) {
            topicSizeHistogram[tokensPerTopic[topic]]++;
        }

        betaSum = Dirichlet.learnSymmetricConcentration(countHistogram, topicSizeHistogram, numTypes, betaSum);
        beta = betaSum / (double) numTypes;
        logger.info("[beta: " + formatter.format(beta) + "] ");

        for (int thread = 0; thread < numThreads; thread++) {
            runnables[thread].resetBeta(beta, betaSum);
        }
    }

    /**
     * Runs {@code numIterations} sampling iterations over the loaded documents.
     *
     * <p>With more than one thread the documents are split into contiguous
     * ranges, one per worker, and each worker samples against its own copy of
     * the count tables. After every iteration the copies are summed into the
     * model and copied back to the workers. With one thread the single worker
     * operates directly on the model's tables and is run on the calling thread.
     *
     * <p>Each iteration may first log the top words, write the sampling state
     * and write the serialized model, according to the configured intervals.
     * After burn-in, alpha statistics are collected every
     * {@code saveSampleInterval} iterations and alpha and beta are optimized
     * every {@code optimizeInterval} iterations. Every tenth iteration logs
     * the iteration number, with the log likelihood per token when
     * {@code printLogLikelihood} is set. The total elapsed time is logged at
     * the end.
     *
     * <p>The thread pool is shut down even when an iteration throws. An
     * interrupt received while waiting for the workers does not abort the
     * run; the calling thread's interrupt status is restored before this
     * method returns.
     *
     * @throws IOException if writing the state or the serialized model fails
     */
    public void estimate() throws IOException {
        long startTime = System.currentTimeMillis();
        WorkerRunnable[] runnables = new WorkerRunnable[numThreads];
        int docsPerThread = data.size() / numThreads;
        int offset = 0;

        if (numThreads > 1) {
            for (int thread = 0; thread < numThreads; thread++) {
                // Every worker starts from a private copy of the global counts.
                int[] runnableTotals = new int[numTopics];
                System.arraycopy(tokensPerTopic, 0, runnableTotals, 0, numTopics);
                int[][] runnableCounts = new int[numTypes][];
                for (int type = 0; type < numTypes; type++) {
                    int[] counts = new int[typeTopicCounts[type].length];
                    System.arraycopy(typeTopicCounts[type], 0, counts, 0, counts.length);
                    runnableCounts[type] = counts;
                }
                // The last worker also takes whatever the integer division left over.
                if (thread == numThreads - 1) {
                    docsPerThread = data.size() - offset;
                }
                Randoms random = randomSeed == -1 ? new Randoms() : new Randoms(randomSeed);
                runnables[thread] = new WorkerRunnable(numTopics, alpha, alphaSum, beta, random, data, runnableCounts, runnableTotals, offset, docsPerThread);
                runnables[thread].initializeAlphaStatistics(docLengthCounts.length);
                offset += docsPerThread;
            }
        } else {
            Randoms random = randomSeed == -1 ? new Randoms() : new Randoms(randomSeed);
            runnables[0] = new WorkerRunnable(numTopics, alpha, alphaSum, beta, random, data, typeTopicCounts, tokensPerTopic, offset, docsPerThread);
            runnables[0].initializeAlphaStatistics(docLengthCounts.length);
            runnables[0].makeOnlyThread();
        }

        ExecutorService executor = Executors.newFixedThreadPool(numThreads);
        // Remember an interrupt instead of dropping it; restored in the finally block.
        boolean interrupted = false;
        try {
            for (int iteration = 1; iteration <= numIterations; iteration++) {
                long iterationStart = System.currentTimeMillis();

                if (showTopicsInterval != 0 && iteration % showTopicsInterval == 0) {
                    logger.info("\n" + displayTopWords(wordsPerTopic, false));
                }
                if (saveStateInterval != 0 && iteration % saveStateInterval == 0) {
                    printState(new File(String.valueOf(stateFilename) + '.' + iteration));
                }
                if (saveModelInterval != 0 && iteration % saveModelInterval == 0) {
                    write(new File(String.valueOf(modelFilename) + '.' + iteration));
                }

                boolean collectStatistics = iteration > burninPeriod
                        && optimizeInterval != 0
                        && iteration % saveSampleInterval == 0;

                if (numThreads > 1) {
                    for (int thread = 0; thread < numThreads; thread++) {
                        if (collectStatistics) {
                            runnables[thread].collectAlphaStatistics();
                        }
                        logger.fine("submitting thread " + thread);
                        executor.submit(runnables[thread]);
                    }

                    // Give the workers a moment to start before polling them.
                    try {
                        Thread.sleep(20L);
                    } catch (InterruptedException e) {
                        interrupted = true;
                    }

                    boolean finished = false;
                    while (!finished) {
                        try {
                            Thread.sleep(10L);
                        } catch (InterruptedException e) {
                            interrupted = true;
                        }
                        finished = true;
                        for (int thread = 0; thread < numThreads; thread++) {
                            finished = finished && runnables[thread].isFinished;
                        }
                    }

                    sumTypeTopicCounts(runnables);

                    // Push the merged global counts back into every worker's copy.
                    for (int thread = 0; thread < numThreads; thread++) {
                        int[] runnableTotals = runnables[thread].getTokensPerTopic();
                        System.arraycopy(tokensPerTopic, 0, runnableTotals, 0, numTopics);
                        int[][] runnableCounts = runnables[thread].getTypeTopicCounts();
                        for (int type = 0; type < numTypes; type++) {
                            int[] targetCounts = runnableCounts[type];
                            int[] sourceCounts = typeTopicCounts[type];
                            for (int index = 0; index < sourceCounts.length; index++) {
                                if (sourceCounts[index] != 0) {
                                    targetCounts[index] = sourceCounts[index];
                                } else if (targetCounts[index] != 0) {
                                    targetCounts[index] = 0;
                                } else {
                                    // Both arrays are empty from here on.
                                    break;
                                }
                            }
                        }
                    }
                } else {
                    if (collectStatistics) {
                        runnables[0].collectAlphaStatistics();
                    }
                    runnables[0].run();
                }

                long elapsedMillis = System.currentTimeMillis() - iterationStart;
                if (elapsedMillis < 1000L) {
                    logger.fine(elapsedMillis + "ms ");
                } else {
                    logger.fine(elapsedMillis / 1000L + "s ");
                }

                if (iteration > burninPeriod && optimizeInterval != 0 && iteration % optimizeInterval == 0) {
                    optimizeAlpha(runnables);
                    optimizeBeta(runnables);
                    logger.fine("[O " + (System.currentTimeMillis() - iterationStart) + "] ");
                }

                if (iteration % 10 == 0) {
                    if (printLogLikelihood) {
                        logger.info("<" + iteration + "> LL/token: " + formatter.format(modelLogLikelihood() / (double) totalTokens));
                    } else {
                        logger.info("<" + iteration + ">");
                    }
                }
            }
        } finally {
            // Always release the pool threads, even when an iteration throws.
            executor.shutdownNow();
            if (interrupted) {
                Thread.currentThread().interrupt();
            }
        }

        long seconds = Math.round((double) (System.currentTimeMillis() - startTime) / 1000.0);
        long minutes = seconds / 60L;
        seconds %= 60L;
        long hours = minutes / 60L;
        minutes %= 60L;
        long days = hours / 24L;
        hours %= 24L;

        StringBuilder timeReport = new StringBuilder();
        timeReport.append("\nTotal time: ");
        if (days != 0L) {
            timeReport.append(days);
            timeReport.append(" days ");
        }
        if (hours != 0L) {
            timeReport.append(hours);
            timeReport.append(" hours ");
        }
        if (minutes != 0L) {
            timeReport.append(minutes);
            timeReport.append(" minutes ");
        }
        timeReport.append(seconds);
        timeReport.append(" seconds");
        logger.info(timeReport.toString());
    }

    public void maximize(int iterations) {
        int iteration = 0;
        int totalChange = Integer.MAX_VALUE;
        double[] topicCoefficients = new double[this.numTopics];
        while (iteration < iterations && totalChange > 0) {
            long iterationStart = System.currentTimeMillis();
            totalChange = 0;
            int doc = 0;
            while (doc < this.data.size()) {
                FeatureSequence tokenSequence = (FeatureSequence)this.data.get((int)doc).instance.getData();
                LabelSequence topicSequence = this.data.get((int)doc).topicSequence;
                int[] oneDocTopics = topicSequence.getFeatures();
                int docLength = tokenSequence.getLength();
                int[] localTopicCounts = new int[this.numTopics];
                int position = 0;
                while (position < docLength) {
                    int n = oneDocTopics[position];
                    localTopicCounts[n] = localTopicCounts[n] + 1;
                    ++position;
                }
                int globalMaxTopic = 0;
                double globalMaxScore = 0.0;
                int topic = 0;
                while (topic < this.numTopics) {
                    topicCoefficients[topic] = (this.alpha[topic] + (double)localTopicCounts[topic]) / (this.betaSum + (double)this.tokensPerTopic[topic]);
                    if (this.beta * topicCoefficients[topic] > globalMaxScore) {
                        globalMaxTopic = topic;
                        globalMaxScore = this.beta * topicCoefficients[topic];
                    }
                    ++topic;
                }
                double[] topicTermScores = new double[this.numTopics];
                int position2 = 0;
                while (position2 < docLength) {
                    int temp;
                    int currentValue;
                    int currentTopic;
                    int type = tokenSequence.getIndexAtPosition(position2);
                    int oldTopic = oneDocTopics[position2];
                    int[] currentTypeTopicCounts = this.typeTopicCounts[type];
                    int n = oldTopic;
                    localTopicCounts[n] = localTopicCounts[n] - 1;
                    int n2 = oldTopic;
                    this.tokensPerTopic[n2] = this.tokensPerTopic[n2] - 1;
                    topicCoefficients[oldTopic] = (this.alpha[oldTopic] + (double)localTopicCounts[oldTopic]) / (this.betaSum + (double)this.tokensPerTopic[oldTopic]);
                    if (oldTopic == globalMaxTopic) {
                        globalMaxScore = this.beta * topicCoefficients[oldTopic];
                        int topic2 = 0;
                        while (topic2 < this.numTopics) {
                            if (this.beta * topicCoefficients[topic2] > globalMaxScore) {
                                globalMaxTopic = topic2;
                                globalMaxScore = this.beta * topicCoefficients[topic2];
                            }
                            ++topic2;
                        }
                    }
                    int newTopic = globalMaxTopic;
                    double maxScore = globalMaxScore;
                    assert (this.tokensPerTopic[oldTopic] >= 0) : "old Topic " + oldTopic + " below 0";
                    int index = 0;
                    boolean alreadyDecremented = false;
                    while (index < currentTypeTopicCounts.length && currentTypeTopicCounts[index] > 0) {
                        currentTopic = currentTypeTopicCounts[index] & this.topicMask;
                        currentValue = currentTypeTopicCounts[index] >> this.topicBits;
                        if (!alreadyDecremented && currentTopic == oldTopic) {
                            currentTypeTopicCounts[index] = --currentValue == 0 ? 0 : (currentValue << this.topicBits) + oldTopic;
                            int subIndex = index;
                            while (subIndex < currentTypeTopicCounts.length - 1 && currentTypeTopicCounts[subIndex] < currentTypeTopicCounts[subIndex + 1]) {
                                temp = currentTypeTopicCounts[subIndex];
                                currentTypeTopicCounts[subIndex] = currentTypeTopicCounts[subIndex + 1];
                                currentTypeTopicCounts[subIndex + 1] = temp;
                                ++subIndex;
                            }
                            alreadyDecremented = true;
                            continue;
                        }
                        double score = topicCoefficients[currentTopic] * (this.beta + (double)currentValue);
                        if (score > maxScore) {
                            newTopic = currentTopic;
                            maxScore = score;
                        }
                        ++index;
                    }
                    oneDocTopics[position2] = newTopic;
                    int n3 = newTopic;
                    localTopicCounts[n3] = localTopicCounts[n3] + 1;
                    int n4 = newTopic;
                    this.tokensPerTopic[n4] = this.tokensPerTopic[n4] + 1;
                    index = 0;
                    boolean foundTopic = false;
                    while (!foundTopic && index < currentTypeTopicCounts.length) {
                        currentTopic = currentTypeTopicCounts[index] & this.topicMask;
                        currentValue = currentTypeTopicCounts[index] >> this.topicBits;
                        if (currentTopic == newTopic) {
                            currentTypeTopicCounts[index] = (currentValue + 1 << this.topicBits) + newTopic;
                            while (index > 0 && currentTypeTopicCounts[index] > currentTypeTopicCounts[index - 1]) {
                                temp = currentTypeTopicCounts[index];
                                currentTypeTopicCounts[index] = currentTypeTopicCounts[index - 1];
                                currentTypeTopicCounts[index - 1] = temp;
                            }
                            foundTopic = true;
                        } else if (currentValue == 0) {
                            currentTypeTopicCounts[index] = (1 << this.topicBits) + newTopic;
                            foundTopic = true;
                        }
                        ++index;
                    }
                    topicCoefficients[newTopic] = (this.alpha[newTopic] + (double)localTopicCounts[newTopic]) / (this.betaSum + (double)this.tokensPerTopic[newTopic]);
                    if (this.beta * topicCoefficients[newTopic] > globalMaxScore) {
                        globalMaxScore = this.beta * topicCoefficients[newTopic];
                        globalMaxTopic = newTopic;
                    }
                    if (newTopic != oldTopic) {
                        ++totalChange;
                    }
                    ++position2;
                }
                ++doc;
            }
            long elapsedMillis = System.currentTimeMillis() - iterationStart;
            logger.info(String.valueOf(iteration) + "\t" + elapsedMillis + "ms\t" + totalChange + "\t" + this.modelLogLikelihood() / (double)this.totalTokens);
            ++iteration;
        }
    }

    /**
     * Collects, per topic, every word type that currently has a positive count
     * in that topic.
     *
     * <p>Each used entry of {@code typeTopicCounts[type]} packs a topic id in the
     * bits selected by {@code topicMask} and the count in the bits above
     * {@code topicBits}; the first non-positive entry ends the used prefix.
     *
     * @return a list of {@code numTopics} sets; set {@code t} contains one
     *         {@code IDSorter(type, count)} per type counted in topic {@code t},
     *         held in {@code IDSorter}'s natural ordering
     */
    public ArrayList<TreeSet<IDSorter>> getSortedWords() {
        ArrayList<TreeSet<IDSorter>> wordsByTopic = new ArrayList<TreeSet<IDSorter>>(this.numTopics);
        for (int t = 0; t < this.numTopics; t++) {
            wordsByTopic.add(new TreeSet<IDSorter>());
        }
        for (int type = 0; type < this.numTypes; type++) {
            int[] packedCounts = this.typeTopicCounts[type];
            for (int slot = 0; slot < packedCounts.length; slot++) {
                int packed = packedCounts[slot];
                if (packed <= 0) {
                    // End of the populated part of this type's array.
                    break;
                }
                int topicId = packed & this.topicMask;
                int count = packed >> this.topicBits;
                wordsByTopic.get(topicId).add(new IDSorter(type, count));
            }
        }
        return wordsByTopic;
    }

    /**
     * Returns the leading words of every topic as alphabet entries.
     *
     * @param numWords maximum number of words to return per topic
     * @return an array with one row per topic; row {@code t} holds at most
     *         {@code numWords} objects looked up in {@code alphabet}, in the
     *         iteration order of that topic's set from {@link #getSortedWords()}
     */
    public Object[][] getTopWords(int numWords) {
        ArrayList<TreeSet<IDSorter>> wordsByTopic = this.getSortedWords();
        Object[][] topWords = new Object[this.numTopics][];
        for (int t = 0; t < this.numTopics; t++) {
            TreeSet<IDSorter> ranked = wordsByTopic.get(t);
            // A topic may hold fewer distinct words than requested.
            int keep = Math.min(numWords, ranked.size());
            topWords[t] = new Object[keep];
            Iterator<IDSorter> it = ranked.iterator();
            for (int k = 0; k < keep; k++) {
                topWords[t][k] = this.alphabet.lookupObject(it.next().getID());
            }
        }
        return topWords;
    }

    /**
     * Writes the top-words report to {@code file}.
     *
     * <p>The stream is closed in a {@code finally} block so the file handle is
     * released even if building or printing the report throws.
     *
     * @param file        destination file
     * @param numWords    maximum number of words to list per topic
     * @param useNewLines {@code true} for one word per line, {@code false} for
     *                    one topic per line
     * @throws IOException if the file cannot be opened for writing
     */
    public void printTopWords(File file, int numWords, boolean useNewLines) throws IOException {
        PrintStream out = new PrintStream(file);
        try {
            this.printTopWords(out, numWords, useNewLines);
        } finally {
            out.close();
        }
    }

    /**
     * Prints the text produced by {@link #displayTopWords(int, boolean)} to
     * {@code out}. The stream is not closed.
     *
     * @param out           destination stream
     * @param numWords      maximum number of words to list per topic
     * @param usingNewLines layout flag forwarded to {@code displayTopWords}
     */
    public void printTopWords(PrintStream out, int numWords, boolean usingNewLines) {
        String report = this.displayTopWords(numWords, usingNewLines);
        out.print(report);
    }

    /**
     * Renders the leading words of each topic as text.
     *
     * <p>With {@code usingNewLines} the output for a topic is a header line
     * {@code topic<TAB>alpha} followed by one {@code word<TAB>weight} line per
     * word. Otherwise each topic is a single line:
     * {@code topic<TAB>alpha<TAB>word word ...}. Numbers are rendered with the
     * instance's {@code formatter}.
     *
     * @param numWords      maximum number of words to list per topic
     * @param usingNewLines selects the multi-line layout when {@code true}
     * @return the complete report
     */
    public String displayTopWords(int numWords, boolean usingNewLines) {
        StringBuilder report = new StringBuilder();
        ArrayList<TreeSet<IDSorter>> wordsByTopic = this.getSortedWords();
        for (int t = 0; t < this.numTopics; t++) {
            Iterator<IDSorter> it = wordsByTopic.get(t).iterator();
            // The topic header is the same in both layouts apart from its terminator.
            report.append(t).append("\t").append(this.formatter.format(this.alpha[t]));
            report.append(usingNewLines ? "\n" : "\t");
            for (int emitted = 0; it.hasNext() && emitted < numWords; emitted++) {
                IDSorter entry = it.next();
                report.append(this.alphabet.lookupObject(entry.getID()));
                if (usingNewLines) {
                    report.append("\t").append(this.formatter.format(entry.getWeight())).append("\n");
                } else {
                    report.append(" ");
                }
            }
            if (!usingNewLines) {
                report.append("\n");
            }
        }
        return report.toString();
    }

    /**
     * Writes an XML document listing, for every topic, its alpha, its token
     * total and its leading words with rank and count.
     *
     * <p>Word text and attribute values are written as-is; no XML escaping is
     * applied.
     *
     * @param out      destination writer; not closed by this method
     * @param numWords maximum number of {@code <word>} elements per topic
     */
    public void topicXMLReport(PrintWriter out, int numWords) {
        ArrayList<TreeSet<IDSorter>> wordsByTopic = this.getSortedWords();
        out.println("<?xml version='1.0' ?>");
        out.println("<topicModel>");
        for (int t = 0; t < this.numTopics; t++) {
            out.println("  <topic id='" + t + "' alpha='" + this.alpha[t] + "' totalTokens='" + this.tokensPerTopic[t] + "'>");
            Iterator<IDSorter> it = topicWords(wordsByTopic, t);
            // Ranks are 1-based in the output.
            for (int rank = 1; it.hasNext() && rank <= numWords; rank++) {
                IDSorter entry = it.next();
                Object word = this.alphabet.lookupObject(entry.getID());
                out.println("\t<word rank='" + rank + "' count='" + entry.getWeight() + "'>" + word + "</word>");
            }
            out.println("  </topic>");
        }
        out.println("</topicModel>");
    }

    /** Returns an iterator over the sorted words of topic {@code topic}. */
    private static Iterator<IDSorter> topicWords(ArrayList<TreeSet<IDSorter>> wordsByTopic, int topic) {
        return wordsByTopic.get(topic).iterator();
    }

    /**
     * Writes an XML report that lists, for each topic, its leading words, its
     * most frequent multi-word phrases, and a comma-separated "titles" attribute
     * assembled from the heaviest words and phrases.
     *
     * <p>A phrase is a run of consecutive tokens in one document that carry the
     * same topic (and, when the document is a {@code FeatureSequenceWithBigrams},
     * whose bi-index is not -1). Phrase strings are counted per topic.
     *
     * <p>Text is written without XML escaping.
     *
     * @param out      destination writer; not closed by this method
     * @param numWords maximum number of {@code <word>} and {@code <phrase>}
     *                 elements written per topic
     */
    public void topicPhraseXMLReport(PrintWriter out, int numWords) {
        // Locals deliberately shadow the fields of the same name.
        int numTopics = this.getNumTopics();
        // phrases[t] maps a phrase string to how often it was seen in topic t.
        TObjectIntHashMap[] phrases = new TObjectIntHashMap[numTopics];
        Alphabet alphabet = this.getAlphabet();
        int ti = 0;
        while (ti < numTopics) {
            phrases[ti] = new TObjectIntHashMap();
            ++ti;
        }
        // Pass 1: scan every document and collect same-topic token runs.
        int di = 0;
        while (di < this.getData().size()) {
            TopicAssignment t = this.getData().get(di);
            Instance instance = t.instance;
            FeatureSequence fvs = (FeatureSequence)instance.getData();
            boolean withBigrams = false;
            if (fvs instanceof FeatureSequenceWithBigrams) {
                withBigrams = true;
            }
            // prevtopic/prevfeature remember the token that may start a phrase;
            // sb is the phrase under construction (null when none is open).
            int prevtopic = -1;
            int prevfeature = -1;
            int topic = -1;
            StringBuffer sb = null;
            int feature = -1;
            int doclen = fvs.size();
            int pi = 0;
            while (pi < doclen) {
                feature = fvs.getIndexAtPosition(pi);
                topic = this.getData().get((int)di).topicSequence.getIndexAtPosition(pi);
                // Equivalent to: topic == prevtopic && (!withBigrams || biIndex != -1).
                // NOTE(review): presumably a bi-index of -1 means "no bigram with the
                // preceding token" — confirm against FeatureSequenceWithBigrams.
                if (!(topic != prevtopic || withBigrams && ((FeatureSequenceWithBigrams)fvs).getBiIndexAtPosition(pi) == -1)) {
                    if (sb == null) {
                        // Open a new phrase from the remembered token plus this one.
                        sb = new StringBuffer(String.valueOf(alphabet.lookupObject(prevfeature).toString()) + " " + alphabet.lookupObject(feature));
                    } else {
                        sb.append(" ");
                        sb.append(alphabet.lookupObject(feature));
                    }
                } else if (sb != null) {
                    // The run ended: record the phrase under the topic it belonged to.
                    String sbs = sb.toString();
                    // NOTE(review): presumably get() returns 0 for an absent key — confirm
                    // against the Trove TObjectIntHashMap documentation.
                    if (phrases[prevtopic].get(sbs) == 0) {
                        phrases[prevtopic].put(sbs, 0);
                    }
                    phrases[prevtopic].increment(sbs);
                    // State is reset to "nothing remembered"; the current token is not
                    // stored as a potential phrase start.
                    prevfeature = -1;
                    prevtopic = -1;
                    sb = null;
                } else {
                    // No phrase open: remember this token as a possible phrase start.
                    prevtopic = topic;
                    prevfeature = feature;
                }
                ++pi;
            }
            // A phrase still open when the document ends is not recorded, because
            // sb is only flushed in the branch above.
            ++di;
        }
        // Pass 2: emit one <topic> element per topic.
        out.println("<?xml version='1.0' ?>");
        out.println("<topics>");
        ArrayList<TreeSet<IDSorter>> topicSortedWords = this.getSortedWords();
        // probs is allocated but never read in this method.
        double[] probs = new double[alphabet.size()];
        int ti2 = 0;
        while (ti2 < numTopics) {
            // The opening tag is left unterminated here; the titles attribute and
            // closing '>' are written after the words and phrases are processed.
            out.print("  <topic id=\"" + ti2 + "\" alpha=\"" + this.alpha[ti2] + "\" totalTokens=\"" + this.tokensPerTopic[ti2] + "\" ");
            // Child elements are buffered so they can be printed after the tag.
            ByteArrayOutputStream bout = new ByteArrayOutputStream();
            PrintStream pout = new PrintStream(bout);
            // Candidate title entries (words and phrases) with their weights.
            AugmentableFeatureVector titles = new AugmentableFeatureVector(new Alphabet());
            int word = 0;
            Iterator<IDSorter> iterator = topicSortedWords.get(ti2).iterator();
            while (iterator.hasNext() && word < numWords) {
                IDSorter info = iterator.next();
                // weight = this word's count divided by the topic's token total.
                pout.println("\t<word weight=\"" + info.getWeight() / (double)this.tokensPerTopic[ti2] + "\" count=\"" + Math.round(info.getWeight()) + "\">" + alphabet.lookupObject(info.getID()) + "</word>");
                // word is incremented before the comparison, so only the first 19
                // words of a topic become title candidates.
                if (++word >= 20) continue;
                titles.add(alphabet.lookupObject(info.getID()), info.getWeight());
            }
            // Copy the phrase counts into a double[] so they can be ranked.
            Object[] keys = phrases[ti2].keys();
            int[] values = phrases[ti2].getValues();
            double[] counts = new double[keys.length];
            int i = 0;
            while (i < counts.length) {
                counts[i] = values[i];
                ++i;
            }
            double countssum = MatrixOps.sum(counts);
            Alphabet alph = new Alphabet(keys);
            RankedFeatureVector rfv = new RankedFeatureVector(alph, counts);
            int max = rfv.numLocations() < numWords ? rfv.numLocations() : numWords;
            int ri = 0;
            while (ri < max) {
                int fi = rfv.getIndexAtRank(ri);
                // weight = this phrase's count divided by all phrase counts of the topic.
                pout.println("\t<phrase weight=\"" + counts[fi] / countssum + "\" count=\"" + values[fi] + "\">" + alph.lookupObject(fi) + "</phrase>");
                // Phrases among the first 20 ranks seen more than 20 times become title
                // candidates, weighted at 100 times their count.
                if (ri < 20 && values[fi] > 20) {
                    titles.add(alph.lookupObject(fi), (double)(100 * values[fi]));
                }
                ++ri;
            }
            // Build the titles attribute from the ranked candidates.
            StringBuffer titlesStringBuffer = new StringBuffer();
            rfv = new RankedFeatureVector(titles.getAlphabet(), titles);
            int numTitles = 10;
            int ri2 = 0;
            while (ri2 < numTitles && ri2 < rfv.numLocations()) {
                // Skip a candidate whose text already occurs (as a substring) in the
                // titles collected so far, and allow one extra rank in its place.
                if (titlesStringBuffer.indexOf(rfv.getObjectAtRank(ri2).toString()) == -1) {
                    titlesStringBuffer.append(rfv.getObjectAtRank(ri2));
                    // The separator depends only on the rank index, so the string can end
                    // with ", " when the candidates run out before numTitles is reached.
                    if (ri2 < numTitles - 1) {
                        titlesStringBuffer.append(", ");
                    }
                } else {
                    ++numTitles;
                }
                ++ri2;
            }
            out.println("titles=\"" + titlesStringBuffer.toString() + "\">");
            out.print(bout.toString());
            out.println("  </topic>");
            ++ti2;
        }
        out.println("</topics>");
    }

    /**
     * Writes one line per word type to {@code file} in the form
     * {@code typeIndex word topic:count topic:count ...}, listing only the
     * topics with a positive count for that type.
     *
     * <p>The writer is closed in a {@code finally} block so the file handle is
     * released even if a lookup or write throws part-way through.
     *
     * @param file destination file
     * @throws IOException if the file cannot be opened for writing
     */
    public void printTypeTopicCounts(File file) throws IOException {
        PrintWriter out = new PrintWriter(new FileWriter(file));
        try {
            for (int type = 0; type < this.numTypes; type++) {
                StringBuilder line = new StringBuilder();
                line.append(type).append(" ").append(this.alphabet.lookupObject(type));
                int[] packedCounts = this.typeTopicCounts[type];
                // Entries pack (count << topicBits) + topic; a non-positive entry ends the list.
                for (int slot = 0; slot < packedCounts.length && packedCounts[slot] > 0; slot++) {
                    int topic = packedCounts[slot] & this.topicMask;
                    int count = packedCounts[slot] >> this.topicBits;
                    line.append(" ").append(topic).append(":").append(count);
                }
                out.println(line);
            }
        } finally {
            out.close();
        }
    }

    /**
     * Writes the topic/word weight table to {@code file} by delegating to
     * {@link #printTopicWordWeights(PrintWriter)}.
     *
     * <p>The writer is closed in a {@code finally} block so the file handle is
     * released even if the delegate throws.
     *
     * @param file destination file
     * @throws IOException if the file cannot be opened, or if the delegate throws
     */
    public void printTopicWordWeights(File file) throws IOException {
        PrintWriter out = new PrintWriter(new FileWriter(file));
        try {
            this.printTopicWordWeights(out);
        } finally {
            out.close();
        }
    }

    /**
     * Prints one {@code topic<TAB>word<TAB>weight} line for every combination of
     * topic and word type, where weight is {@code beta} plus the count of that
     * type in that topic (just {@code beta} when the type has no entry for the
     * topic).
     *
     * @param out destination writer; not closed by this method
     * @throws IOException declared by the signature; nothing in this body throws it directly
     */
    public void printTopicWordWeights(PrintWriter out) throws IOException {
        for (int t = 0; t < this.numTopics; t++) {
            for (int type = 0; type < this.numTypes; type++) {
                int[] packedCounts = this.typeTopicCounts[type];
                double weight = this.beta;
                // Scan the populated prefix for this topic's packed entry.
                for (int slot = 0; slot < packedCounts.length && packedCounts[slot] > 0; slot++) {
                    if ((packedCounts[slot] & this.topicMask) == t) {
                        weight += (double)(packedCounts[slot] >> this.topicBits);
                        break;
                    }
                }
                out.println(t + "\t" + this.alphabet.lookupObject(type) + "\t" + weight);
            }
        }
    }

    /**
     * Returns the smoothed topic distribution of the document stored at index
     * {@code instanceID} in {@code data}.
     *
     * @param instanceID index into {@code data}
     * @return the result of {@link #getTopicProbabilities(LabelSequence)} for
     *         that document's topic sequence
     */
    public double[] getTopicProbabilities(int instanceID) {
        return this.getTopicProbabilities(this.data.get(instanceID).topicSequence);
    }

    /**
     * Computes a normalized topic distribution from a topic sequence: the count
     * of each topic in {@code topics} plus that topic's {@code alpha}, divided by
     * the sum of those values over all topics.
     *
     * @param topics topic assignment per token position
     * @return an array of length {@code numTopics}
     */
    public double[] getTopicProbabilities(LabelSequence topics) {
        double[] distribution = new double[this.numTopics];
        for (int pos = 0; pos < topics.getLength(); pos++) {
            distribution[topics.getIndexAtPosition(pos)] += 1.0;
        }
        // Add the prior and accumulate the normalizer in the same pass.
        double total = 0.0;
        for (int t = 0; t < this.numTopics; t++) {
            distribution[t] += this.alpha[t];
            total += distribution[t];
        }
        for (int t = 0; t < this.numTopics; t++) {
            distribution[t] /= total;
        }
        return distribution;
    }

    /**
     * Writes the per-document topic proportions to {@code file} by delegating to
     * {@link #printDocumentTopics(PrintWriter)}.
     *
     * <p>The writer is closed in a {@code finally} block so the file handle is
     * released even if the delegate throws.
     *
     * @param file destination file
     * @throws IOException if the file cannot be opened for writing
     */
    public void printDocumentTopics(File file) throws IOException {
        PrintWriter out = new PrintWriter(new FileWriter(file));
        try {
            this.printDocumentTopics(out);
        } finally {
            out.close();
        }
    }

    /**
     * Prints one line per document:
     * {@code docIndex<TAB>name<TAB>p(topic 0)<TAB>...<TAB>p(topic numTopics-1)},
     * where each proportion is {@code (alpha[t] + count[t]) / (docLength + alphaSum)}.
     * Documents whose instance name is {@code null} are labelled {@code no-name}.
     *
     * @param out destination writer; not closed by this method
     */
    public void printDenseDocumentTopics(PrintWriter out) {
        int[] countsForDoc = new int[this.numTopics];
        for (int doc = 0; doc < this.data.size(); doc++) {
            // Start each document from a clean histogram.
            Arrays.fill(countsForDoc, 0);
            TopicAssignment assignment = this.data.get(doc);
            int[] docTopics = assignment.topicSequence.getFeatures();
            StringBuilder line = new StringBuilder();
            line.append(doc).append("\t");
            Object docName = assignment.instance.getName();
            if (docName == null) {
                line.append("no-name");
            } else {
                line.append(docName);
            }
            int docLen = docTopics.length;
            for (int pos = 0; pos < docLen; pos++) {
                countsForDoc[docTopics[pos]]++;
            }
            for (int t = 0; t < this.numTopics; t++) {
                double proportion = (this.alpha[t] + (double)countsForDoc[t]) / ((double)docLen + this.alphaSum);
                line.append("\t").append(proportion);
            }
            out.println(line);
        }
    }

    /**
     * Prints the topic proportions of every document with a threshold of 0.0
     * and no limit on the number of topics listed (a negative {@code max}).
     *
     * @param out destination writer; not closed by this method
     */
    public void printDocumentTopics(PrintWriter out) {
        final double noThreshold = 0.0;
        final int allTopics = -1;
        this.printDocumentTopics(out, noThreshold, allTopics);
    }

    /**
     * Prints, after a header line, one line per document:
     * {@code docIndex<TAB>name<TAB>topic<TAB>proportion<TAB>...}. Topics are
     * listed in the order produced by sorting {@code IDSorter} entries, where
     * each proportion is {@code (alpha[t] + count[t]) / (docLength + alphaSum)}.
     * Listing stops at the first topic whose proportion is below
     * {@code threshold} or after {@code max} topics.
     *
     * @param out       destination writer; not closed by this method
     * @param threshold minimum proportion for a topic to be listed
     * @param max       maximum topics per document; a negative value or one
     *                  above {@code numTopics} means all topics
     */
    public void printDocumentTopics(PrintWriter out, double threshold, int max) {
        out.print("#doc name topic proportion ...\n");
        int[] countsForDoc = new int[this.numTopics];
        // Reusable sorter objects, re-populated for every document.
        IDSorter[] ranked = new IDSorter[this.numTopics];
        for (int t = 0; t < this.numTopics; t++) {
            ranked[t] = new IDSorter(t, t);
        }
        int limit = (max < 0 || max > this.numTopics) ? this.numTopics : max;
        for (int doc = 0; doc < this.data.size(); doc++) {
            TopicAssignment assignment = this.data.get(doc);
            int[] docTopics = assignment.topicSequence.getFeatures();
            StringBuilder line = new StringBuilder();
            line.append(doc).append("\t");
            Object docName = assignment.instance.getName();
            if (docName == null) {
                line.append("no-name");
            } else {
                line.append(docName);
            }
            line.append("\t");
            int docLen = docTopics.length;
            for (int pos = 0; pos < docLen; pos++) {
                countsForDoc[docTopics[pos]]++;
            }
            for (int t = 0; t < this.numTopics; t++) {
                ranked[t].set(t, (this.alpha[t] + (double)countsForDoc[t]) / ((double)docLen + this.alphaSum));
            }
            Arrays.sort(ranked);
            for (int i = 0; i < limit; i++) {
                IDSorter entry = ranked[i];
                if (entry.getWeight() < threshold) {
                    break;
                }
                line.append(entry.getID()).append("\t").append(entry.getWeight()).append("\t");
            }
            out.println(line);
            Arrays.fill(countsForDoc, 0);
        }
    }

    /**
     * Builds a topic-by-type matrix of token counts using only the documents
     * selected by {@code documentMask}.
     *
     * <p>If {@code smoothed}, {@code beta} is added to every cell. If
     * {@code normalized}, every row is scaled by the reciprocal of that topic's
     * token total in the sub-corpus (plus {@code numTypes * beta} when smoothed).
     * When a topic has no tokens and {@code smoothed} is false the reciprocal is
     * infinite, so that row becomes NaN.
     *
     * @param documentMask one flag per entry of {@code data}; {@code true}
     *                     includes the document
     * @param normalized   scale each topic row as described above
     * @param smoothed     add {@code beta} to every cell
     * @return a {@code numTopics x numTypes} matrix
     */
    public double[][] getSubCorpusTopicWords(boolean[] documentMask, boolean normalized, boolean smoothed) {
        double[][] matrix = new double[this.numTopics][this.numTypes];
        int[] tokensInTopic = new int[this.numTopics];
        for (int doc = 0; doc < this.data.size(); doc++) {
            if (!documentMask[doc]) {
                continue;
            }
            int[] words = ((FeatureSequence)this.data.get(doc).instance.getData()).getFeatures();
            int[] topics = this.data.get(doc).topicSequence.getFeatures();
            for (int pos = 0; pos < topics.length; pos++) {
                int t = topics[pos];
                matrix[t][words[pos]] += 1.0;
                tokensInTopic[t]++;
            }
        }
        if (smoothed) {
            for (int t = 0; t < this.numTopics; t++) {
                double[] row = matrix[t];
                for (int type = 0; type < this.numTypes; type++) {
                    row[type] += this.beta;
                }
            }
        }
        if (normalized) {
            for (int t = 0; t < this.numTopics; t++) {
                double denominator = (double)tokensInTopic[t];
                if (smoothed) {
                    denominator += (double)this.numTypes * this.beta;
                }
                double scale = 1.0 / denominator;
                double[] row = matrix[t];
                for (int type = 0; type < this.numTypes; type++) {
                    row[type] *= scale;
                }
            }
        }
        return matrix;
    }

    /**
     * Builds a topic-by-type matrix from the packed {@code typeTopicCounts}.
     *
     * <p>If {@code smoothed}, {@code beta} is added to every cell. If
     * {@code normalized}, every row is scaled by the reciprocal of
     * {@code tokensPerTopic[topic]} (plus {@code numTypes * beta} when smoothed).
     * When a topic has no tokens and {@code smoothed} is false the reciprocal is
     * infinite, so that row becomes NaN.
     *
     * @param normalized scale each topic row as described above
     * @param smoothed   add {@code beta} to every cell
     * @return a {@code numTopics x numTypes} matrix
     */
    public double[][] getTopicWords(boolean normalized, boolean smoothed) {
        double[][] matrix = new double[this.numTopics][this.numTypes];
        for (int type = 0; type < this.numTypes; type++) {
            int[] packedCounts = this.typeTopicCounts[type];
            // Unpack (count << topicBits) + topic entries until the first non-positive one.
            for (int slot = 0; slot < packedCounts.length && packedCounts[slot] > 0; slot++) {
                int t = packedCounts[slot] & this.topicMask;
                int count = packedCounts[slot] >> this.topicBits;
                matrix[t][type] += (double)count;
            }
        }
        if (smoothed) {
            for (int t = 0; t < this.numTopics; t++) {
                double[] row = matrix[t];
                for (int type = 0; type < this.numTypes; type++) {
                    row[type] += this.beta;
                }
            }
        }
        if (normalized) {
            for (int t = 0; t < this.numTopics; t++) {
                double denominator = (double)this.tokensPerTopic[t];
                if (smoothed) {
                    denominator += (double)this.numTypes * this.beta;
                }
                double scale = 1.0 / denominator;
                double[] row = matrix[t];
                for (int type = 0; type < this.numTypes; type++) {
                    row[type] *= scale;
                }
            }
        }
        return matrix;
    }

    /**
     * Builds a document-by-topic matrix of token counts.
     *
     * <p>If {@code smoothed}, {@code alpha[topic]} is added to each cell. If
     * {@code normalized}, each row is multiplied by the reciprocal of its sum;
     * a row that sums to zero (an empty document without smoothing) therefore
     * becomes NaN.
     *
     * @param normalized scale each document row to sum to one
     * @param smoothed   add the per-topic {@code alpha} to every row
     * @return a {@code data.size() x numTopics} matrix
     */
    public double[][] getDocumentTopics(boolean normalized, boolean smoothed) {
        double[][] matrix = new double[this.data.size()][this.numTopics];
        for (int doc = 0; doc < this.data.size(); doc++) {
            double[] row = matrix[doc];
            int[] topics = this.data.get(doc).topicSequence.getFeatures();
            for (int pos = 0; pos < topics.length; pos++) {
                row[topics[pos]] += 1.0;
            }
            if (smoothed) {
                for (int t = 0; t < this.numTopics; t++) {
                    row[t] += this.alpha[t];
                }
            }
            if (normalized) {
                double rowSum = 0.0;
                for (int t = 0; t < this.numTopics; t++) {
                    rowSum += row[t];
                }
                double scale = 1.0 / rowSum;
                for (int t = 0; t < this.numTopics; t++) {
                    row[t] *= scale;
                }
            }
        }
        return matrix;
    }

    /**
     * Builds, for every topic, a sorted set with one entry per document whose
     * weight is
     * {@code (count[topic] + smoothing) / (docLength + numTopics * smoothing)}.
     *
     * @param smoothing value added to every topic count before normalizing
     * @return a list of {@code numTopics} sets of {@code IDSorter(doc, weight)},
     *         each held in {@code IDSorter}'s natural ordering
     */
    public ArrayList<TreeSet<IDSorter>> getTopicDocuments(double smoothing) {
        ArrayList<TreeSet<IDSorter>> docsByTopic = new ArrayList<TreeSet<IDSorter>>(this.numTopics);
        for (int t = 0; t < this.numTopics; t++) {
            docsByTopic.add(new TreeSet<IDSorter>());
        }
        int[] countsForDoc = new int[this.numTopics];
        for (int doc = 0; doc < this.data.size(); doc++) {
            int[] topics = this.data.get(doc).topicSequence.getFeatures();
            for (int pos = 0; pos < topics.length; pos++) {
                countsForDoc[topics[pos]]++;
            }
            double denominator = (double)topics.length + (double)this.numTopics * smoothing;
            for (int t = 0; t < this.numTopics; t++) {
                double weight = ((double)countsForDoc[t] + smoothing) / denominator;
                docsByTopic.get(t).add(new IDSorter(doc, weight));
                // Reset as we go so the histogram is clean for the next document.
                countsForDoc[t] = 0;
            }
        }
        return docsByTopic;
    }

    /**
     * Prints the leading documents of every topic, listing at most 100
     * documents per topic.
     *
     * @param out destination writer; not closed by this method
     */
    public void printTopicDocuments(PrintWriter out) {
        final int defaultMax = 100;
        this.printTopicDocuments(out, defaultMax);
    }

    /**
     * Prints, after a header line, up to {@code max} lines per topic in the form
     * {@code topic doc name proportion}, walking each topic's set from
     * {@link #getTopicDocuments(double)} (called with smoothing 10.0) in
     * iteration order.
     *
     * <p>The instance name is an {@code Object}; it is rendered with
     * {@code toString()} rather than cast to {@code String}, so names of other
     * types no longer raise {@code ClassCastException}. A {@code null} name is
     * printed as {@code no-name}.
     *
     * @param out destination writer; not closed by this method
     * @param max maximum number of documents listed per topic
     */
    public void printTopicDocuments(PrintWriter out, int max) {
        out.println("#topic doc name proportion ...");
        ArrayList<TreeSet<IDSorter>> topicSortedDocuments = this.getTopicDocuments(10.0);
        for (int topic = 0; topic < this.numTopics; topic++) {
            int printed = 0;
            for (IDSorter sorter : topicSortedDocuments.get(topic)) {
                if (printed == max) {
                    break;
                }
                int doc = sorter.getID();
                double proportion = sorter.getWeight();
                Object rawName = this.data.get(doc).instance.getName();
                String name = rawName == null ? "no-name" : rawName.toString();
                out.format("%d %d %s %f\n", topic, doc, name, proportion);
                ++printed;
            }
        }
    }

    /**
     * Writes the model state, gzip-compressed, to {@code f} by delegating to
     * {@link #printState(PrintStream)}.
     *
     * <p>Both the compressing stream and the underlying file stream are closed
     * in {@code finally} blocks, so the file handle is released even if the
     * gzip wrapper cannot be created or the delegate throws.
     *
     * @param f destination file
     * @throws IOException if the file cannot be opened or the gzip stream
     *                     cannot be set up
     */
    public void printState(File f) throws IOException {
        FileOutputStream fileOut = new FileOutputStream(f);
        try {
            PrintStream out = new PrintStream(new GZIPOutputStream(new BufferedOutputStream(fileOut)));
            try {
                this.printState(out);
            } finally {
                // Closing the PrintStream flushes the buffer and finishes the gzip stream.
                out.close();
            }
        } finally {
            // No-op when the chain above was already closed; otherwise releases the handle.
            fileOut.close();
        }
    }

    /**
     * Prints the token-level state of the model: a column header, the alpha
     * values, beta, and then one line per token in the form
     * {@code doc source pos typeindex type topic}. A document whose instance
     * has no source is reported with source {@code NA}.
     *
     * @param out destination stream; not closed by this method
     */
    public void printState(PrintStream out) {
        out.println("#doc source pos typeindex type topic");
        out.print("#alpha : ");
        for (int t = 0; t < this.numTopics; t++) {
            out.print(this.alpha[t] + " ");
        }
        out.println();
        out.println("#beta : " + this.beta);
        for (int doc = 0; doc < this.data.size(); doc++) {
            TopicAssignment assignment = this.data.get(doc);
            FeatureSequence tokens = (FeatureSequence)assignment.instance.getData();
            LabelSequence topicSequence = assignment.topicSequence;
            Object rawSource = assignment.instance.getSource();
            String source = rawSource != null ? rawSource.toString() : "NA";
            // All token lines of a document are formatted into one buffer and printed together.
            Formatter rows = new Formatter(new StringBuilder(), Locale.US);
            for (int pos = 0; pos < topicSequence.getLength(); pos++) {
                int type = tokens.getIndexAtPosition(pos);
                int topic = topicSequence.getIndexAtPosition(pos);
                rows.format("%d %s %d %d %s %d\n", doc, source, pos, type, this.alphabet.lookupObject(type), topic);
            }
            out.print(rows);
        }
    }

    /**
     * Computes a log-likelihood value for the current topic assignments from
     * log-gamma terms over (a) the per-document topic counts, using
     * {@code alpha}/{@code alphaSum}, and (b) the per-type topic counts and
     * per-topic token totals, using {@code beta}.
     *
     * @return the accumulated log likelihood; {@code 0.0} if the running sum
     *         becomes NaN or infinite inside one of the checked steps. A NaN
     *         detected only after the very last step is logged and returned
     *         unchanged.
     */
    public double modelLogLikelihood() {
        double total = 0.0;

        // logGamma(alpha[t]) is needed for every document, so compute it once.
        double[] alphaLogGammas = new double[numTopics];
        for (int t = 0; t < numTopics; t++) {
            alphaLogGammas[t] = Dirichlet.logGammaStirling(alpha[t]);
        }

        // Document/topic terms. The counts buffer is reused and cleared per document.
        int[] perDocCounts = new int[numTopics];
        int numDocs = data.size();
        for (int d = 0; d < numDocs; d++) {
            int[] assignments = data.get(d).topicSequence.getFeatures();
            for (int assigned : assignments) {
                perDocCounts[assigned]++;
            }
            for (int t = 0; t < numTopics; t++) {
                int n = perDocCounts[t];
                if (n > 0) {
                    total += Dirichlet.logGammaStirling(alpha[t] + (double) n) - alphaLogGammas[t];
                }
            }
            total -= Dirichlet.logGammaStirling(alphaSum + (double) assignments.length);
            Arrays.fill(perDocCounts, 0);
        }
        total += (double) numDocs * Dirichlet.logGammaStirling(alphaSum);

        // Type/topic terms. Each row is scanned until its first non-positive
        // entry; the count is recovered from an entry by shifting right by topicBits.
        int nonZeroTypeTopics = 0;
        for (int w = 0; w < numTypes; w++) {
            int[] packedRow = typeTopicCounts[w];
            for (int i = 0; i < packedRow.length && packedRow[i] > 0; i++) {
                int count = packedRow[i] >> topicBits;
                nonZeroTypeTopics++;
                total += Dirichlet.logGammaStirling(beta + (double) count);
                if (Double.isNaN(total)) {
                    logger.warning("NaN in log likelihood calculation");
                    return 0.0;
                }
                if (Double.isInfinite(total)) {
                    logger.warning("infinite log likelihood");
                    return 0.0;
                }
            }
        }

        // Per-topic normalisation terms.
        for (int t = 0; t < numTopics; t++) {
            total -= Dirichlet.logGammaStirling(beta * (double) numTypes + (double) tokensPerTopic[t]);
            if (Double.isNaN(total)) {
                logger.info("NaN after topic " + t + " " + tokensPerTopic[t]);
                return 0.0;
            }
            if (Double.isInfinite(total)) {
                logger.info("Infinite value after topic " + t + " " + tokensPerTopic[t]);
                return 0.0;
            }
        }

        total += Dirichlet.logGammaStirling(beta * (double) numTypes) * (double) numTopics;
        total -= Dirichlet.logGammaStirling(beta) * (double) nonZeroTypeTopics;

        // NOTE(review): unlike the checks above, a NaN here is only logged and
        // still returned to the caller. Preserved as-is; confirm whether intended.
        if (Double.isNaN(total)) {
            logger.info("at the end");
        } else if (Double.isInfinite(total)) {
            logger.info("Infinite value beta " + beta + " * " + numTypes);
            return 0.0;
        }
        return total;
    }

    /**
     * Builds a {@link TopicInferencer} from this model's current type-topic
     * counts, per-topic token totals and hyperparameters.
     *
     * <p>The alphabet handed to the inferencer is taken from the first entry
     * of {@code data}, so this method throws if {@code data} is empty.
     *
     * @return a new inferencer constructed from the current model state
     */
    public TopicInferencer getInferencer() {
        Alphabet dataAlphabet = data.get(0).instance.getDataAlphabet();
        return new TopicInferencer(typeTopicCounts, tokensPerTopic, dataAlphabet, alpha, beta, betaSum);
    }

    /**
     * Builds a {@link MarginalProbEstimator} from this model's topic count,
     * hyperparameters, type-topic counts and per-topic token totals.
     *
     * @return a new estimator constructed from the current model state
     */
    public MarginalProbEstimator getProbEstimator() {
        return new MarginalProbEstimator(
                numTopics,
                alpha,
                alphaSum,
                beta,
                typeTopicCounts,
                tokensPerTopic);
    }

    /**
     * Custom serialization hook: writes a version tag followed by the model's
     * fields in a fixed order.
     *
     * <p>The order of the writes below must stay in lock-step with the reads
     * in {@code readObject}; do not reorder one without the other.
     *
     * @param out stream the model state is written to
     * @throws IOException if any underlying write fails
     */
    private void writeObject(ObjectOutputStream out) throws IOException {
        // Leading version tag; currently always 0.
        final int serialVersion = 0;
        out.writeInt(serialVersion);

        out.writeObject(data);
        out.writeObject(alphabet);
        out.writeObject(topicAlphabet);

        out.writeInt(numTopics);
        out.writeInt(topicMask);
        out.writeInt(topicBits);
        out.writeInt(numTypes);

        out.writeObject(alpha);
        out.writeDouble(alphaSum);
        out.writeDouble(beta);
        out.writeDouble(betaSum);

        out.writeObject(typeTopicCounts);
        out.writeObject(tokensPerTopic);
        out.writeObject(docLengthCounts);
        out.writeObject(topicDocCounts);

        out.writeInt(numIterations);
        out.writeInt(burninPeriod);
        out.writeInt(saveSampleInterval);
        out.writeInt(optimizeInterval);
        out.writeInt(showTopicsInterval);
        out.writeInt(wordsPerTopic);

        out.writeInt(saveStateInterval);
        out.writeObject(stateFilename);
        out.writeInt(saveModelInterval);
        out.writeObject(modelFilename);

        out.writeInt(randomSeed);
        out.writeObject(formatter);
        out.writeBoolean(printLogLikelihood);
        out.writeInt(numThreads);
    }

    /**
     * Custom deserialization hook: reads the version tag and then restores
     * the model's fields in the same order {@code writeObject} emitted them.
     *
     * @param in stream the model state is read from
     * @throws IOException if any underlying read fails
     * @throws ClassNotFoundException if a serialized object's class cannot be resolved
     */
    private void readObject(ObjectInputStream in) throws IOException, ClassNotFoundException {
        // The version tag must be consumed to keep the stream aligned; its value is not used.
        in.readInt();

        data = (ArrayList<TopicAssignment>) in.readObject();
        alphabet = (Alphabet) in.readObject();
        topicAlphabet = (LabelAlphabet) in.readObject();

        numTopics = in.readInt();
        topicMask = in.readInt();
        topicBits = in.readInt();
        numTypes = in.readInt();

        alpha = (double[]) in.readObject();
        alphaSum = in.readDouble();
        beta = in.readDouble();
        betaSum = in.readDouble();

        typeTopicCounts = (int[][]) in.readObject();
        tokensPerTopic = (int[]) in.readObject();
        docLengthCounts = (int[]) in.readObject();
        topicDocCounts = (int[][]) in.readObject();

        numIterations = in.readInt();
        burninPeriod = in.readInt();
        saveSampleInterval = in.readInt();
        optimizeInterval = in.readInt();
        showTopicsInterval = in.readInt();
        wordsPerTopic = in.readInt();

        saveStateInterval = in.readInt();
        stateFilename = (String) in.readObject();
        saveModelInterval = in.readInt();
        modelFilename = (String) in.readObject();

        randomSeed = in.readInt();
        formatter = (NumberFormat) in.readObject();
        printLogLikelihood = in.readBoolean();
        numThreads = in.readInt();
    }

    /**
     * Serializes this model to the given file using Java object serialization.
     *
     * <p>I/O failures are reported on {@code System.err} and are not rethrown,
     * so callers cannot detect a failed write from this method alone.
     *
     * @param serializedModelFile destination file
     */
    public void write(File serializedModelFile) {
        try {
            FileOutputStream fileOut = new FileOutputStream(serializedModelFile);
            try {
                ObjectOutputStream oos = new ObjectOutputStream(fileOut);
                oos.writeObject(this);
                // close() flushes the object stream's buffer before closing the file.
                oos.close();
            }
            finally {
                // Releases the file handle when the constructor or writeObject threw;
                // a second close after a successful oos.close() has no effect.
                fileOut.close();
            }
        }
        catch (IOException e) {
            System.err.println("Problem serializing ParallelTopicModel to file " + serializedModelFile + ": " + e);
        }
    }

    /**
     * Deserializes a model from the given file and then calls
     * {@code initializeHistograms()} on it.
     *
     * <p>NOTE(review): this uses Java native deserialization, which is unsafe
     * on files from untrusted sources. Only load model files you trust.
     *
     * @param f file containing a serialized {@code ParallelTopicModel}
     * @return the deserialized model
     * @throws Exception if the file cannot be opened or read, or if the
     *         serialized data cannot be turned back into a model
     */
    public static ParallelTopicModel read(File f) throws Exception {
        ParallelTopicModel topicModel;
        FileInputStream fileIn = new FileInputStream(f);
        try {
            ObjectInputStream ois = new ObjectInputStream(fileIn);
            topicModel = (ParallelTopicModel) ois.readObject();
        }
        finally {
            // Always release the file handle, including when deserialization fails.
            fileIn.close();
        }
        topicModel.initializeHistograms();
        return topicModel;
    }

    /**
     * Command-line entry point: loads an instance list, trains a model and
     * writes the resulting state to {@code state.gz}.
     *
     * <p>Arguments:
     * <ol>
     *   <li>path to a serialized {@code InstanceList} (required)</li>
     *   <li>number of topics (optional, defaults to 200)</li>
     *   <li>number of threads (optional; when omitted the model's own default is kept)</li>
     * </ol>
     *
     * @param args command-line arguments as described above
     */
    public static void main(String[] args) {
        if (args.length == 0) {
            System.err.println("Usage: ParallelTopicModel <instance-list-file> [num-topics] [num-threads]");
            return;
        }
        try {
            InstanceList training = InstanceList.load(new File(args[0]));
            int numTopics = args.length > 1 ? Integer.parseInt(args[1]) : 200;
            ParallelTopicModel lda = new ParallelTopicModel(numTopics, 50.0, 0.01);
            lda.printLogLikelihood = true;
            lda.setTopicDisplay(50, 7);
            lda.addInstances(training);
            // The thread count is optional, like the topic count. Indexing args[2]
            // unconditionally would abort the run when it is not supplied.
            if (args.length > 2) {
                lda.setNumThreads(Integer.parseInt(args[2]));
            }
            lda.estimate();
            logger.info("printing state");
            lda.printState(new File("state.gz"));
            logger.info("finished printing");
        }
        catch (Exception e) {
            e.printStackTrace();
        }
    }
}

