/*
 * Decompiled with CFR 0.152.
 */
package cc.mallet.topics;

import cc.mallet.types.Alphabet;
import cc.mallet.types.AugmentableFeatureVector;
import cc.mallet.types.Dirichlet;
import cc.mallet.types.FeatureSequence;
import cc.mallet.types.FeatureSequenceWithBigrams;
import cc.mallet.types.IDSorter;
import cc.mallet.types.Instance;
import cc.mallet.types.InstanceList;
import cc.mallet.types.LabelAlphabet;
import cc.mallet.types.LabelSequence;
import cc.mallet.types.Labeling;
import cc.mallet.types.MatrixOps;
import cc.mallet.types.RankedFeatureVector;
import cc.mallet.util.Randoms;
import gnu.trove.TIntIntHashMap;
import gnu.trove.TObjectIntHashMap;
import java.io.BufferedOutputStream;
import java.io.ByteArrayOutputStream;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.FileWriter;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.PrintStream;
import java.io.PrintWriter;
import java.io.Serializable;
import java.text.NumberFormat;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Iterator;
import java.util.List;
import java.util.TreeSet;
import java.util.zip.GZIPOutputStream;

public class LDAHyper
implements Serializable {
    // One entry per training document: the instance together with its current topic assignments.
    protected ArrayList<Topication> data = new ArrayList();
    // Vocabulary of the training data; null until the first addInstances call, then fixed.
    protected Alphabet alphabet;
    // Labels for the topics ("topic0", "topic1", ... when built from a topic count).
    protected LabelAlphabet topicAlphabet;
    protected int numTopics;
    // Number of word types in `alphabet`; the length of `typeTopicCounts`.
    protected int numTypes;
    // Per-topic Dirichlet prior over document-topic proportions; re-learned during estimate().
    protected double[] alpha;
    protected double alphaSum;
    // Symmetric Dirichlet smoothing on topic-word distributions.
    protected double beta;
    // beta * numTypes, recomputed whenever the vocabulary size changes.
    protected double betaSum;
    public static final double DEFAULT_BETA = 0.01;
    // Sum over topics of alpha[t] * beta / (tokensPerTopic[t] + betaSum); kept incrementally up to date.
    protected double smoothingOnlyMass = 0.0;
    // Per topic: alpha[t] / (tokensPerTopic[t] + betaSum). While one document is sampled,
    // entries for that document's topics temporarily hold (alpha[t] + docCount) / (tokensPerTopic[t] + betaSum).
    protected double[] cachedCoefficients;
    // Reset to zero at the start of each sweep; not otherwise updated in the code visible here
    // (NOTE(review): may be diagnostics used elsewhere — confirm).
    int topicTermCount = 0;
    int betaTopicCount = 0;
    int smoothingOnlyCount = 0;
    // Optional held-out instances used for periodic likelihood reporting in estimate().
    protected InstanceList testing = null;
    // Scratch per-document topic histogram used by oldSampleTopicsForOneDoc.
    protected int[] oneDocTopicCounts;
    // For each word type: topic index -> number of tokens of that type assigned to the topic.
    protected TIntIntHashMap[] typeTopicCounts;
    // Total number of tokens assigned to each topic.
    protected int[] tokensPerTopic;
    // Histogram of sampled document lengths; input to Dirichlet.learnParameters.
    protected int[] docLengthCounts;
    // topicDocCounts[t][c]: number of sampled documents with exactly c tokens in topic t.
    protected int[][] topicDocCounts;
    // Global sweep counter, carried across estimate() calls.
    public int iterationsSoFar = 0;
    public int numIterations = 1000;
    // Sweeps before alpha optimization and histogram collection start.
    public int burninPeriod = 20;
    // Histograms are recorded on sweeps divisible by this value (after burn-in).
    public int saveSampleInterval = 5;
    // Alpha is re-optimized every this many sweeps after burn-in; 0 disables.
    public int optimizeInterval = 20;
    // Top words are printed every this many sweeps; 0 disables.
    public int showTopicsInterval = 10;
    public int wordsPerTopic = 7;
    // Set by setModelOutput; not read by the code visible here (NOTE(review): confirm whether it is used elsewhere).
    protected int outputModelInterval = 0;
    protected String outputModelFilename;
    // State is written to stateFilename + '.' + iteration every this many sweeps; 0 disables.
    protected int saveStateInterval = 0;
    protected String stateFilename = null;
    protected Randoms random;
    // Used to format alpha values in printTopWords (at most 5 fraction digits).
    protected NumberFormat formatter;
    // When true, the model log-likelihood is printed every 10 sweeps.
    protected boolean printLogLikelihood = false;
    private static final long serialVersionUID = 1L;
    // Presumably used by custom serialization further down the file — not visible here.
    private static final int CURRENT_SERIAL_VERSION = 0;
    private static final int NULL_INTEGER = -1;

    /**
     * Creates a model with {@code numberOfTopics} topics, a total prior mass
     * {@code alphaSum} equal to the number of topics, and {@link #DEFAULT_BETA}.
     *
     * @param numberOfTopics number of topics to learn
     */
    public LDAHyper(int numberOfTopics) {
        this(numberOfTopics, (double) numberOfTopics, DEFAULT_BETA);
    }

    /**
     * Creates a model with the given hyperparameters and a freshly constructed,
     * unseeded random source.
     *
     * @param numberOfTopics number of topics to learn
     * @param alphaSum       total Dirichlet mass, split evenly across topics
     * @param beta           topic-word smoothing
     */
    public LDAHyper(int numberOfTopics, double alphaSum, double beta) {
        this(numberOfTopics, alphaSum, beta, new Randoms());
    }

    /**
     * Builds a label alphabet holding the entries "topic0" .. "topic(numTopics-1)",
     * added in index order.
     *
     * @param numTopics number of topic labels to create
     * @return the populated alphabet
     */
    private static LabelAlphabet newLabelAlphabet(int numTopics) {
        LabelAlphabet labels = new LabelAlphabet();
        for (int index = 0; index < numTopics; ++index) {
            labels.lookupIndex("topic" + index);
        }
        return labels;
    }

    /**
     * Creates a model whose topic labels are generated as "topic0", "topic1", ...
     *
     * @param numberOfTopics number of topics to learn
     * @param alphaSum       total Dirichlet mass, split evenly across topics
     * @param beta           topic-word smoothing
     * @param random         random source used for sampling
     */
    public LDAHyper(int numberOfTopics, double alphaSum, double beta, Randoms random) {
        this(newLabelAlphabet(numberOfTopics), alphaSum, beta, random);
    }

    public LDAHyper(LabelAlphabet topicAlphabet, double alphaSum, double beta, Randoms random) {
        this.topicAlphabet = topicAlphabet;
        this.numTopics = topicAlphabet.size();
        this.alphaSum = alphaSum;
        this.alpha = new double[this.numTopics];
        Arrays.fill(this.alpha, alphaSum / (double)this.numTopics);
        this.beta = beta;
        this.random = random;
        this.oneDocTopicCounts = new int[this.numTopics];
        this.tokensPerTopic = new int[this.numTopics];
        this.formatter = NumberFormat.getInstance();
        this.formatter.setMaximumFractionDigits(5);
        System.err.println("LDA: " + this.numTopics + " topics");
    }

    /** Returns the training vocabulary, or {@code null} before any instances were added. */
    public Alphabet getAlphabet() {
        return alphabet;
    }

    /** Returns the alphabet of topic labels supplied at construction. */
    public LabelAlphabet getTopicAlphabet() {
        return topicAlphabet;
    }

    /** Returns the number of topics in the model. */
    public int getNumTopics() {
        return numTopics;
    }

    /**
     * Returns the model's live list of documents with their topic assignments.
     * The list is not copied, so changes made through it affect the model.
     */
    public ArrayList<Topication> getData() {
        return data;
    }

    /**
     * Returns how many tokens of word type {@code featureIndex} are currently assigned to
     * {@code topicIndex}. This is 0 when the pair has no entry.
     */
    public int getCountFeatureTopic(int featureIndex, int topicIndex) {
        TIntIntHashMap countsForType = typeTopicCounts[featureIndex];
        return countsForType.get(topicIndex);
    }

    /** Returns the total number of tokens currently assigned to {@code topicIndex}. */
    public int getCountTokensPerTopic(int topicIndex) {
        return tokensPerTopic[topicIndex];
    }

    /**
     * Sets held-out instances. When non-null, estimate() prints the log-likelihood,
     * empirical likelihood and mutual information each time it shows topics.
     */
    public void setTestingInstances(InstanceList testing) {
        this.testing = testing;
    }

    /** Sets the number of sweeps run by the no-argument {@code estimate()}. */
    public void setNumIterations(int numIterations) {
        this.numIterations = numIterations;
    }

    /** Sets how many sweeps pass before histograms are saved and alpha is optimized. */
    public void setBurninPeriod(int burninPeriod) {
        this.burninPeriod = burninPeriod;
    }

    /**
     * Configures periodic topic display during estimate().
     *
     * @param interval show top words every {@code interval} sweeps (0 disables)
     * @param n        number of words shown per topic
     */
    public void setTopicDisplay(int interval, int n) {
        this.wordsPerTopic = n;
        this.showTopicsInterval = interval;
    }

    /** Replaces the model's random source with one seeded by {@code seed}. */
    public void setRandomSeed(int seed) {
        Randoms seeded = new Randoms(seed);
        this.random = seeded;
    }

    /** Sets how often, in sweeps after burn-in, alpha is re-optimized (0 disables). */
    public void setOptimizeInterval(int interval) {
        this.optimizeInterval = interval;
    }

    /**
     * Records the model-output interval and file name.
     * NOTE(review): these fields are not read by estimate() in the visible code — confirm
     * whether anything else uses them.
     */
    public void setModelOutput(int interval, String filename) {
        this.outputModelFilename = filename;
        this.outputModelInterval = interval;
    }

    /**
     * Enables periodic state dumps during estimate().
     *
     * @param interval dump every {@code interval} sweeps (0 disables)
     * @param filename base file name; the sweep number is appended after a '.'
     */
    public void setSaveState(int interval, String filename) {
        this.stateFilename = filename;
        this.saveStateInterval = interval;
    }

    /** Returns the token count of an instance whose data is a {@link FeatureSequence}. */
    protected int instanceLength(Instance instance) {
        FeatureSequence tokens = (FeatureSequence) instance.getData();
        return tokens.size();
    }

    /**
     * Binds the model to a vocabulary, or grows the per-type count table when the already
     * bound vocabulary has gained entries since the last call.
     *
     * @param alphabet the data alphabet of incoming instances
     * @throws IllegalArgumentException if a different alphabet than the bound one is passed
     */
    private void initializeForTypes(Alphabet alphabet) {
        if (this.alphabet == null) {
            this.alphabet = alphabet;
            this.numTypes = alphabet.size();
            this.typeTopicCounts = new TIntIntHashMap[this.numTypes];
            for (int fi = 0; fi < this.numTypes; ++fi) {
                this.typeTopicCounts[fi] = new TIntIntHashMap();
            }
            this.betaSum = this.beta * (double) this.numTypes;
        } else if (alphabet != this.alphabet) {
            throw new IllegalArgumentException("Cannot change Alphabet.");
        } else if (alphabet.size() != this.numTypes) {
            // The shared alphabet grew: keep the existing per-type maps and add empty ones.
            int previousLength = this.typeTopicCounts.length;
            this.numTypes = alphabet.size();
            TIntIntHashMap[] grown = new TIntIntHashMap[this.numTypes];
            System.arraycopy(this.typeTopicCounts, 0, grown, 0, previousLength);
            for (int i = previousLength; i < this.numTypes; ++i) {
                grown[i] = new TIntIntHashMap();
            }
            // The resized table must replace the old one. Otherwise indexing a new type
            // (>= previousLength) in addInstances throws ArrayIndexOutOfBoundsException.
            this.typeTopicCounts = grown;
            this.betaSum = this.beta * (double) this.numTypes;
        }
    }

    /**
     * Resizes {@link #typeTopicCounts} to {@link #numTypes} entries. Existing per-type maps
     * are kept and each new slot gets an empty map. Shrinking is not supported: copying
     * the old entries indexes past the new array's end.
     * NOTE(review): not called anywhere in the visible part of this file.
     */
    private void initializeTypeTopicCounts() {
        TIntIntHashMap[] previous = this.typeTopicCounts;
        TIntIntHashMap[] resized = new TIntIntHashMap[this.numTypes];
        int type;
        for (type = 0; type < previous.length; ++type) {
            resized[type] = previous[type];
        }
        for (; type < this.numTypes; ++type) {
            resized[type] = new TIntIntHashMap();
        }
        this.typeTopicCounts = resized;
    }

    /**
     * Adds training documents, giving every token a uniformly random initial topic.
     * <p>
     * Initial topics are drawn from the model's own random source, so a seed set through
     * {@link #setRandomSeed(int)} makes initialization reproducible. Previously a fresh,
     * unseeded generator was created for every document.
     *
     * @param training instances whose data are {@link FeatureSequence}s over a shared alphabet
     */
    public void addInstances(InstanceList training) {
        this.initializeForTypes(training.getDataAlphabet());
        ArrayList<LabelSequence> topicSequences = new ArrayList<LabelSequence>();
        for (Instance instance : training) {
            int[] topics = new int[this.instanceLength(instance)];
            for (int i = 0; i < topics.length; ++i) {
                topics[i] = this.random.nextInt(this.numTopics);
            }
            topicSequences.add(new LabelSequence(this.topicAlphabet, topics));
        }
        this.addInstances(training, topicSequences);
    }

    /**
     * Adds training documents with caller-supplied initial topic assignments. It updates the
     * type-topic and per-topic token counts, then rebuilds the cached sampling values and
     * histograms.
     *
     * @param training instances whose data are {@link FeatureSequence}s over a shared alphabet
     * @param topics   one topic sequence per instance, in the same order
     * @throws IllegalArgumentException if {@code training} and {@code topics} differ in size
     *                                  (checked before any state is modified)
     */
    public void addInstances(InstanceList training, List<LabelSequence> topics) {
        this.initializeForTypes(training.getDataAlphabet());
        // A Java assert is disabled by default, so check explicitly before mutating anything.
        if (training.size() != topics.size()) {
            throw new IllegalArgumentException("Number of topic sequences (" + topics.size()
                    + ") does not match number of instances (" + training.size() + ")");
        }
        for (int i = 0; i < training.size(); ++i) {
            Topication t = new Topication((Instance) training.get(i), this, topics.get(i));
            this.data.add(t);
            FeatureSequence tokenSequence = (FeatureSequence) t.instance.getData();
            LabelSequence topicSequence = t.topicSequence;
            for (int pi = 0; pi < topicSequence.getLength(); ++pi) {
                int topic = topicSequence.getIndexAtPosition(pi);
                this.typeTopicCounts[tokenSequence.getIndexAtPosition(pi)].adjustOrPutValue(topic, 1, 1);
                ++this.tokensPerTopic[topic];
            }
        }
        this.initializeHistogramsAndCachedValues();
    }

    /**
     * Recomputes the sampler's cached values from the current counts:
     * {@code smoothingOnlyMass} = sum over t of alpha[t] * beta / (n_t + betaSum), and
     * {@code cachedCoefficients[t]} = alpha[t] / (n_t + betaSum).
     * It also allocates histograms sized for the longest document and prints the
     * maximum and total token counts to standard error.
     */
    protected void initializeHistogramsAndCachedValues() {
        int longestDoc = 0;
        int tokenTotal = 0;
        for (Topication doc : this.data) {
            int length = ((FeatureSequence) doc.instance.getData()).getLength();
            longestDoc = Math.max(longestDoc, length);
            tokenTotal += length;
        }

        this.smoothingOnlyMass = 0.0;
        this.cachedCoefficients = new double[this.numTopics];
        for (int t = 0; t < this.numTopics; ++t) {
            double denominator = (double) this.tokensPerTopic[t] + this.betaSum;
            this.smoothingOnlyMass += this.alpha[t] * this.beta / denominator;
            this.cachedCoefficients[t] = this.alpha[t] / denominator;
        }

        System.err.println("max tokens: " + longestDoc);
        System.err.println("total tokens: " + tokenTotal);
        // Indexed by document length / per-topic count, so size longestDoc + 1.
        this.docLengthCounts = new int[longestDoc + 1];
        this.topicDocCounts = new int[this.numTopics][longestDoc + 1];
    }

    /**
     * Runs {@link #numIterations} sampling sweeps.
     *
     * @throws IOException if writing a state snapshot fails
     */
    public void estimate() throws IOException {
        int sweeps = this.numIterations;
        this.estimate(sweeps);
    }

    /**
     * Runs {@code iterationsThisRound} Gibbs-sampling sweeps over every training document,
     * continuing the global counter {@link #iterationsSoFar}.
     * <p>
     * Before each sweep, depending on the configured intervals, this method may:
     * <ul>
     *   <li>print the top words, plus held-out statistics when {@link #testing} is set;</li>
     *   <li>write the sampler state;</li>
     *   <li>re-optimize {@code alpha}, once past {@link #burninPeriod}.</li>
     * </ul>
     * Histograms for that optimization are recorded on sweeps at or past burn-in that are
     * divisible by {@link #saveSampleInterval}; 0 disables recording. Per-sweep timing and
     * the total elapsed time go to standard output.
     *
     * @param iterationsThisRound number of sweeps to run; zero or less runs none
     * @throws IOException if writing a state snapshot fails
     */
    public void estimate(int iterationsThisRound) throws IOException {
        long startTime = System.currentTimeMillis();
        int maxIteration = this.iterationsSoFar + iterationsThisRound;
        // Exclusive bound: the previous inclusive "<=" ran iterationsThisRound + 1 sweeps.
        while (this.iterationsSoFar < maxIteration) {
            long iterationStart = System.currentTimeMillis();
            if (this.showTopicsInterval != 0 && this.iterationsSoFar != 0 && this.iterationsSoFar % this.showTopicsInterval == 0) {
                System.out.println();
                this.printTopWords(System.out, this.wordsPerTopic, false);
                if (this.testing != null) {
                    double el = this.empiricalLikelihood(1000, this.testing);
                    double ll = this.modelLogLikelihood();
                    double mi = this.topicLabelMutualInformation();
                    System.out.println(String.valueOf(ll) + "\t" + el + "\t" + mi);
                }
            }
            if (this.saveStateInterval != 0 && this.iterationsSoFar % this.saveStateInterval == 0) {
                this.printState(new File(String.valueOf(this.stateFilename) + '.' + this.iterationsSoFar));
            }
            if (this.iterationsSoFar > this.burninPeriod && this.optimizeInterval != 0 && this.iterationsSoFar % this.optimizeInterval == 0) {
                this.reoptimizeAlphaAndRefreshCaches();
            }
            this.smoothingOnlyCount = 0;
            this.betaTopicCount = 0;
            this.topicTermCount = 0;
            // Like the other intervals, 0 means "never"; it used to throw ArithmeticException.
            boolean saveSample = this.saveSampleInterval != 0
                    && this.iterationsSoFar >= this.burninPeriod
                    && this.iterationsSoFar % this.saveSampleInterval == 0;
            int numDocs = this.data.size();
            for (int di = 0; di < numDocs; ++di) {
                Topication doc = this.data.get(di);
                FeatureSequence tokenSequence = (FeatureSequence) doc.instance.getData();
                this.sampleTopicsForOneDoc(tokenSequence, doc.topicSequence, saveSample, true);
            }
            long elapsedMillis = System.currentTimeMillis() - iterationStart;
            if (elapsedMillis < 1000L) {
                System.out.print(String.valueOf(elapsedMillis) + "ms ");
            } else {
                System.out.print(String.valueOf(elapsedMillis / 1000L) + "s ");
            }
            if (this.iterationsSoFar % 10 == 0) {
                System.out.println("<" + this.iterationsSoFar + "> ");
                if (this.printLogLikelihood) {
                    System.out.println(this.modelLogLikelihood());
                }
            }
            System.out.flush();
            ++this.iterationsSoFar;
        }
        printTotalElapsedTime(startTime);
    }

    /**
     * Re-learns alpha from the collected histograms, refreshes the smoothing-only mass and
     * the per-topic coefficients that depend on alpha, then clears the histograms.
     */
    private void reoptimizeAlphaAndRefreshCaches() {
        this.alphaSum = Dirichlet.learnParameters(this.alpha, this.topicDocCounts, this.docLengthCounts);
        this.smoothingOnlyMass = 0.0;
        for (int topic = 0; topic < this.numTopics; ++topic) {
            double denominator = (double) this.tokensPerTopic[topic] + this.betaSum;
            this.smoothingOnlyMass += this.alpha[topic] * this.beta / denominator;
            this.cachedCoefficients[topic] = this.alpha[topic] / denominator;
        }
        this.clearHistograms();
    }

    /** Prints "Total time: [d days] [h hours] [m minutes] s seconds" since {@code startTime}. */
    private static void printTotalElapsedTime(long startTime) {
        long seconds = Math.round((double) (System.currentTimeMillis() - startTime) / 1000.0);
        long minutes = seconds / 60L;
        seconds %= 60L;
        long hours = minutes / 60L;
        minutes %= 60L;
        long days = hours / 24L;
        hours %= 24L;
        System.out.print("\nTotal time: ");
        if (days != 0L) {
            System.out.print(days);
            System.out.print(" days ");
        }
        if (hours != 0L) {
            System.out.print(hours);
            System.out.print(" hours ");
        }
        if (minutes != 0L) {
            System.out.print(minutes);
            System.out.print(" minutes ");
        }
        System.out.print(seconds);
        System.out.println(" seconds");
    }

    /** Zeroes the document-length and per-topic document-count histograms. */
    private void clearHistograms() {
        Arrays.fill(this.docLengthCounts, 0);
        for (int[] countsForTopic : this.topicDocCounts) {
            Arrays.fill(countsForTopic, 0);
        }
    }

    /**
     * Legacy per-token Gibbs sampler for one document (not called in the visible code).
     * <p>
     * Candidate topics for a token are only the topics its word type currently holds in
     * {@link #typeTopicCounts}. Each candidate is weighted by
     * (n_{w|t} + beta) / (n_t + betaSum) * (n_{t|d} + alpha[t]).
     * NOTE(review): this is a sparse approximation — topics the word type has never been
     * assigned are not considered.
     *
     * @param featureSequence             word types of the document
     * @param topicSequence               current topic of each token; updated in place when readjusting
     * @param saveStateForAlphaEstimation whether to add this document to the alpha histograms
     * @param readjustTopicsAndStats      whether to remove/re-add counts and store the new topics
     */
    private void oldSampleTopicsForOneDoc(FeatureSequence featureSequence, FeatureSequence topicSequence, boolean saveStateForAlphaEstimation, boolean readjustTopicsAndStats) {
        int[] oneDocTopics = topicSequence.getFeatures();
        int docLen = featureSequence.getLength();
        Arrays.fill(this.oneDocTopicCounts, 0);
        if (readjustTopicsAndStats) {
            for (int token = 0; token < docLen; ++token) {
                ++this.oneDocTopicCounts[oneDocTopics[token]];
            }
        }
        for (int token = 0; token < docLen; ++token) {
            int type = featureSequence.getIndexAtPosition(token);
            int oldTopic = oneDocTopics[token];
            TIntIntHashMap currentTypeTopicCounts = this.typeTopicCounts[type];
            assert (currentTypeTopicCounts.size() != 0);
            if (readjustTopicsAndStats) {
                // Remove this token from all counts.
                --this.oneDocTopicCounts[oldTopic];
                int adjustedValue = currentTypeTopicCounts.adjustOrPutValue(oldTopic, -1, -1);
                if (adjustedValue == 0) {
                    currentTypeTopicCounts.remove(oldTopic);
                } else if (adjustedValue == -1) {
                    throw new IllegalStateException("Token count in topic went negative.");
                }
                --this.tokensPerTopic[oldTopic];
            }
            int[] topicIndices = currentTypeTopicCounts.keys();
            int[] topicCounts = currentTypeTopicCounts.getValues();
            int newTopic;
            if (topicIndices.length == 0) {
                // The token was the only occurrence of its type, so no candidate topic
                // remains. Keep its current topic instead of sampling from an empty
                // distribution, which would index position -1.
                newTopic = oldTopic;
            } else {
                // Weights are stored by candidate slot i (aligned with topicIndices),
                // not by topic id; the array only has topicIndices.length slots.
                double[] topicDistribution = new double[topicIndices.length];
                double topicDistributionSum = 0.0;
                for (int i = 0; i < topicIndices.length; ++i) {
                    int topic = topicIndices[i];
                    double weight = ((double) topicCounts[i] + this.beta) / ((double) this.tokensPerTopic[topic] + this.betaSum) * ((double) this.oneDocTopicCounts[topic] + this.alpha[topic]);
                    topicDistributionSum += weight;
                    topicDistribution[i] = weight;
                }
                newTopic = topicIndices[this.random.nextDiscrete(topicDistribution, topicDistributionSum)];
            }
            if (readjustTopicsAndStats) {
                oneDocTopics[token] = newTopic;
                ++this.oneDocTopicCounts[newTopic];
                this.typeTopicCounts[type].adjustOrPutValue(newTopic, 1, 1);
                ++this.tokensPerTopic[newTopic];
            }
        }
        if (saveStateForAlphaEstimation) {
            ++this.docLengthCounts[docLen];
            for (int topic = 0; topic < this.numTopics; ++topic) {
                ++this.topicDocCounts[topic][this.oneDocTopicCounts[topic]];
            }
        }
    }

    /**
     * Resamples the topic of every token in one document, updating all global counts and
     * cached masses incrementally.
     * <p>
     * The unnormalized probability of topic t for a token of word type w is split into
     * three buckets. This appears to follow the SparseLDA decomposition.
     * <ul>
     *   <li>{@code smoothingOnlyMass} = sum_t alpha[t] * beta / (n_t + betaSum)</li>
     *   <li>{@code topicBetaMass} = sum over the document's topics of beta * n_{t|d} / (n_t + betaSum)</li>
     *   <li>{@code topicTermMass} = sum over the word's topics of n_{w|t} * (alpha[t] + n_{t|d}) / (n_t + betaSum)</li>
     * </ul>
     * A uniform draw over the total picks a bucket, then a topic within it.
     * NOTE(review): {@code readjustTopicsAndStats} is not consulted — counts are always updated.
     *
     * @param tokenSequence   word types of the document
     * @param topicSequence   current topic per token; overwritten in place with the new samples
     * @param shouldSaveState whether to add this document to the alpha-optimization histograms
     */
    protected void sampleTopicsForOneDoc(FeatureSequence tokenSequence, FeatureSequence topicSequence, boolean shouldSaveState, boolean readjustTopicsAndStats) {
        int topic;
        int[] oneDocTopics = topicSequence.getFeatures();
        int docLength = tokenSequence.getLength();
        // Per-document topic histogram n_{t|d}.
        TIntIntHashMap localTopicCounts = new TIntIntHashMap();
        int position = 0;
        while (position < docLength) {
            localTopicCounts.adjustOrPutValue(oneDocTopics[position], 1, 1);
            ++position;
        }
        // Initialize the document bucket and switch the coefficients of this document's
        // topics to (alpha + n_{t|d}) / (n_t + betaSum).
        double topicBetaMass = 0.0;
        int[] nArray = localTopicCounts.keys();
        int n = nArray.length;
        int n2 = 0;
        while (n2 < n) {
            int topic2 = nArray[n2];
            int n3 = localTopicCounts.get(topic2);
            topicBetaMass += this.beta * (double)n3 / ((double)this.tokensPerTopic[topic2] + this.betaSum);
            this.cachedCoefficients[topic2] = (this.alpha[topic2] + (double)n3) / ((double)this.tokensPerTopic[topic2] + this.betaSum);
            ++n2;
        }
        double topicTermMass = 0.0;
        // Scratch buffer for topic-term scores, indexed by position in the word's key array.
        double[] topicTermScores = new double[this.numTopics];
        int position2 = 0;
        while (position2 < docLength) {
            double sample;
            int type = tokenSequence.getIndexAtPosition(position2);
            int oldTopic = oneDocTopics[position2];
            TIntIntHashMap currentTypeTopicCounts = this.typeTopicCounts[type];
            // --- Remove the token's current assignment from every count. ---
            assert (currentTypeTopicCounts.get(oldTopic) >= 0);
            if (currentTypeTopicCounts.get(oldTopic) == 1) {
                currentTypeTopicCounts.remove(oldTopic);
            } else {
                currentTypeTopicCounts.adjustValue(oldTopic, -1);
            }
            // Subtract oldTopic's old contribution to the two masses, decrement the counts,
            // then add back its contribution under the new counts.
            this.smoothingOnlyMass -= this.alpha[oldTopic] * this.beta / ((double)this.tokensPerTopic[oldTopic] + this.betaSum);
            topicBetaMass -= this.beta * (double)localTopicCounts.get(oldTopic) / ((double)this.tokensPerTopic[oldTopic] + this.betaSum);
            if (localTopicCounts.get(oldTopic) == 1) {
                localTopicCounts.remove(oldTopic);
            } else {
                localTopicCounts.adjustValue(oldTopic, -1);
            }
            int n4 = oldTopic;
            this.tokensPerTopic[n4] = this.tokensPerTopic[n4] - 1;
            this.smoothingOnlyMass += this.alpha[oldTopic] * this.beta / ((double)this.tokensPerTopic[oldTopic] + this.betaSum);
            topicBetaMass += this.beta * (double)localTopicCounts.get(oldTopic) / ((double)this.tokensPerTopic[oldTopic] + this.betaSum);
            this.cachedCoefficients[oldTopic] = (this.alpha[oldTopic] + (double)localTopicCounts.get(oldTopic)) / ((double)this.tokensPerTopic[oldTopic] + this.betaSum);
            // --- Topic-term bucket: only topics this word type currently has. ---
            topicTermMass = 0.0;
            int[] topicTermIndices = currentTypeTopicCounts.keys();
            int[] topicTermValues = currentTypeTopicCounts.getValues();
            int i = 0;
            while (i < topicTermIndices.length) {
                int topic3 = topicTermIndices[i];
                double score = this.cachedCoefficients[topic3] * (double)topicTermValues[i];
                topicTermMass += score;
                topicTermScores[i] = score;
                ++i;
            }
            // --- Draw uniformly over the total mass and locate the bucket, then the topic. ---
            double origSample = sample = this.random.nextUniform() * (this.smoothingOnlyMass + topicBetaMass + topicTermMass);
            int newTopic = -1;
            if (sample < topicTermMass) {
                // Topic-term bucket: walk the scores until the remainder is exhausted.
                i = -1;
                while (sample > 0.0) {
                    sample -= topicTermScores[++i];
                }
                newTopic = topicTermIndices[i];
            } else if ((sample -= topicTermMass) < topicBetaMass) {
                // Document bucket: divide out beta, then walk n_{t|d} / (n_t + betaSum).
                // If the walk never reaches zero (rounding), the last key visited is used.
                sample /= this.beta;
                topicTermIndices = localTopicCounts.keys();
                topicTermValues = localTopicCounts.getValues();
                i = 0;
                while (i < topicTermIndices.length) {
                    newTopic = topicTermIndices[i];
                    if (!((sample -= (double)topicTermValues[i] / ((double)this.tokensPerTopic[newTopic] + this.betaSum)) <= 0.0)) {
                        ++i;
                        continue;
                    }
                    break;
                }
            } else {
                // Smoothing-only bucket: walk alpha[t] / (n_t + betaSum) over all topics.
                sample -= topicBetaMass;
                sample /= this.beta;
                int topic4 = 0;
                while (topic4 < this.numTopics) {
                    if ((sample -= this.alpha[topic4] / ((double)this.tokensPerTopic[topic4] + this.betaSum)) <= 0.0) {
                        newTopic = topic4;
                        break;
                    }
                    ++topic4;
                }
            }
            // Fallback when the walk overran (e.g. accumulated rounding): report it and use the last topic.
            if (newTopic == -1) {
                System.err.println("LDAHyper sampling error: " + origSample + " " + sample + " " + this.smoothingOnlyMass + " " + topicBetaMass + " " + topicTermMass);
                newTopic = this.numTopics - 1;
            }
            // --- Record the new assignment and re-add its contributions to the masses. ---
            oneDocTopics[position2] = newTopic;
            currentTypeTopicCounts.adjustOrPutValue(newTopic, 1, 1);
            this.smoothingOnlyMass -= this.alpha[newTopic] * this.beta / ((double)this.tokensPerTopic[newTopic] + this.betaSum);
            topicBetaMass -= this.beta * (double)localTopicCounts.get(newTopic) / ((double)this.tokensPerTopic[newTopic] + this.betaSum);
            localTopicCounts.adjustOrPutValue(newTopic, 1, 1);
            int n5 = newTopic;
            this.tokensPerTopic[n5] = this.tokensPerTopic[n5] + 1;
            this.cachedCoefficients[newTopic] = (this.alpha[newTopic] + (double)localTopicCounts.get(newTopic)) / ((double)this.tokensPerTopic[newTopic] + this.betaSum);
            this.smoothingOnlyMass += this.alpha[newTopic] * this.beta / ((double)this.tokensPerTopic[newTopic] + this.betaSum);
            topicBetaMass += this.beta * (double)localTopicCounts.get(newTopic) / ((double)this.tokensPerTopic[newTopic] + this.betaSum);
            assert (currentTypeTopicCounts.get(newTopic) >= 0);
            ++position2;
        }
        // Restore the document-independent coefficients alpha / (n_t + betaSum) for this
        // document's topics, ready for the next document.
        int[] nArray2 = localTopicCounts.keys();
        int n6 = nArray2.length;
        int n7 = 0;
        while (n7 < n6) {
            topic = nArray2[n7];
            this.cachedCoefficients[topic] = this.alpha[topic] / ((double)this.tokensPerTopic[topic] + this.betaSum);
            ++n7;
        }
        // Record the document length and per-topic counts for Dirichlet.learnParameters.
        if (shouldSaveState) {
            int n8 = docLength;
            this.docLengthCounts[n8] = this.docLengthCounts[n8] + 1;
            nArray2 = localTopicCounts.keys();
            n6 = nArray2.length;
            n7 = 0;
            while (n7 < n6) {
                topic = nArray2[n7];
                int[] nArray3 = this.topicDocCounts[topic];
                int n9 = localTopicCounts.get(topic);
                nArray3[n9] = nArray3[n9] + 1;
                ++n7;
            }
        }
    }

    /**
     * Returns every word type paired with its token count in {@code topic}, sorted by
     * {@link IDSorter}'s natural ordering. Types with no tokens in the topic are included
     * with weight 0.
     *
     * @param topic topic index
     * @return an array of length {@code numTypes}
     */
    public IDSorter[] getSortedTopicWords(int topic) {
        // Declared as IDSorter[] (was Object[]), which the IDSorter[] return type requires to compile.
        IDSorter[] sortedTypes = new IDSorter[this.numTypes];
        for (int type = 0; type < this.numTypes; ++type) {
            sortedTypes[type] = new IDSorter(type, this.typeTopicCounts[type].get(topic));
        }
        Arrays.sort(sortedTypes);
        return sortedTypes;
    }

    /**
     * Writes the top-words report to {@code file}. The file is closed even if printing fails.
     *
     * @param file        destination file (created or truncated)
     * @param numWords    number of words per topic
     * @param useNewLines one word per line if true, otherwise one line per topic
     * @throws IOException if the file cannot be opened
     */
    public void printTopWords(File file, int numWords, boolean useNewLines) throws IOException {
        PrintStream out = new PrintStream(file);
        try {
            this.printTopWords(out, numWords, useNewLines);
        } finally {
            out.close();
        }
    }

    /**
     * Prints up to {@code numWords} highest-count words for every topic.
     * <p>
     * With {@code usingNewLines}, each topic gets a "Topic N" header followed by one
     * "word\tcount" line per word. Otherwise each topic is one line:
     * "topic\talpha\ttokens\tw1 w2 ...".
     *
     * @param out           destination stream
     * @param numWords      maximum number of words per topic
     * @param usingNewLines output layout, as described above
     */
    public void printTopWords(PrintStream out, int numWords, boolean usingNewLines) {
        for (int topic = 0; topic < this.numTopics; ++topic) {
            // Sort a list rather than using a TreeSet. A TreeSet drops entries whose
            // compareTo returns 0, which may happen for words with tied counts depending on
            // IDSorter's tie-breaking (NOTE(review): confirm). Arrays.sort is stable.
            ArrayList<IDSorter> candidates = new ArrayList<IDSorter>();
            for (int type = 0; type < this.numTypes; ++type) {
                if (this.typeTopicCounts[type].containsKey(topic)) {
                    candidates.add(new IDSorter(type, this.typeTopicCounts[type].get(topic)));
                }
            }
            IDSorter[] sortedWords = candidates.toArray(new IDSorter[candidates.size()]);
            Arrays.sort(sortedWords);
            // Show exactly numWords words; the old "word = 1; word < numWords" showed one fewer.
            int shown = Math.min(numWords, sortedWords.length);
            if (usingNewLines) {
                out.println("Topic " + topic);
                for (int rank = 0; rank < shown; ++rank) {
                    IDSorter info = sortedWords[rank];
                    out.println(this.alphabet.lookupObject(info.getID()) + "\t" + (int) info.getWeight());
                }
            } else {
                out.print(String.valueOf(topic) + "\t" + this.formatter.format(this.alpha[topic]) + "\t" + this.tokensPerTopic[topic] + "\t");
                for (int rank = 0; rank < shown; ++rank) {
                    out.print(this.alphabet.lookupObject(sortedWords[rank].getID()) + " ");
                }
                out.println();
            }
        }
    }

    /**
     * Writes an XML summary of every topic: its id, alpha, total token count, and up to
     * {@code numWords} highest-count words with 1-based ranks.
     *
     * @param out      destination writer
     * @param numWords maximum number of words per topic
     */
    public void topicXMLReport(PrintWriter out, int numWords) {
        out.println("<?xml version='1.0' ?>");
        out.println("<topicModel>");
        for (int topic = 0; topic < this.numTopics; ++topic) {
            out.println("  <topic id='" + topic + "' alpha='" + this.alpha[topic] + "' totalTokens='" + this.tokensPerTopic[topic] + "'>");
            // A stable sort over a list avoids the TreeSet dropping entries that compare equal.
            ArrayList<IDSorter> candidates = new ArrayList<IDSorter>();
            for (int type = 0; type < this.numTypes; ++type) {
                if (this.typeTopicCounts[type].containsKey(topic)) {
                    candidates.add(new IDSorter(type, this.typeTopicCounts[type].get(topic)));
                }
            }
            IDSorter[] sortedWords = candidates.toArray(new IDSorter[candidates.size()]);
            Arrays.sort(sortedWords);
            // Ranks 1..numWords inclusive; the old "word < numWords" emitted one word too few.
            int shown = Math.min(numWords, sortedWords.length);
            for (int rank = 1; rank <= shown; ++rank) {
                IDSorter info = sortedWords[rank - 1];
                out.println("    <word rank='" + rank + "'>" + this.alphabet.lookupObject(info.getID()) + "</word>");
            }
            out.println("  </topic>");
        }
        out.println("</topicModel>");
    }

    /**
     * Writes an XML report of every topic. For each topic it lists:
     * <ul>
     *   <li>its top {@code numWords} terms by P(word | topic);</li>
     *   <li>its most frequent multi-word phrases;</li>
     *   <li>a {@code titles} attribute built from the highest-ranked terms and phrases.</li>
     * </ul>
     *
     * @param out      destination stream
     * @param numWords maximum number of terms and of phrases listed per topic
     */
    public void topicXMLReportPhrases(PrintStream out, int numWords) {
        int numTopics = this.getNumTopics();
        Alphabet alphabet = this.getAlphabet();
        TObjectIntHashMap[] phrases = this.countTopicPhrases(numTopics, alphabet);
        out.println("<?xml version='1.0' ?>");
        out.println("<topics>");
        double[] probs = new double[alphabet.size()];
        for (int ti = 0; ti < numTopics; ++ti) {
            out.print("  <topic id=\"" + ti + "\" alpha=\"" + this.alpha[ti] + "\" totalTokens=\"" + this.tokensPerTopic[ti] + "\" ");
            // Term/phrase lines are buffered because the titles attribute derived from them
            // must be written first.
            ByteArrayOutputStream bout = new ByteArrayOutputStream();
            PrintStream pout = new PrintStream(bout);
            AugmentableFeatureVector titles = new AugmentableFeatureVector(new Alphabet());
            int topicTokens = this.getCountTokensPerTopic(ti);
            for (int type = 0; type < probs.length; ++type) {
                // An empty topic would otherwise produce 0/0 = NaN weights.
                probs[type] = topicTokens == 0 ? 0.0 : (double) this.getCountFeatureTopic(type, ti) / (double) topicTokens;
            }
            RankedFeatureVector rfv = new RankedFeatureVector(alphabet, probs);
            // Do not ask for more ranks than there are words.
            int termLimit = Math.min(numWords, probs.length);
            for (int ri = 0; ri < termLimit; ++ri) {
                int fi = rfv.getIndexAtRank(ri);
                pout.println("      <term weight=\"" + probs[fi] + "\" count=\"" + this.getCountFeatureTopic(fi, ti) + "\">" + alphabet.lookupObject(fi) + "</term>");
                if (ri < 20) {
                    titles.add(alphabet.lookupObject(fi), (double) this.getCountFeatureTopic(fi, ti));
                }
            }
            Object[] keys = phrases[ti].keys();
            int[] values = phrases[ti].getValues();
            double[] counts = new double[keys.length];
            for (int i = 0; i < counts.length; ++i) {
                counts[i] = values[i];
            }
            double countssum = MatrixOps.sum(counts);
            Alphabet alph = new Alphabet(keys);
            rfv = new RankedFeatureVector(alph, counts);
            int max = Math.min(rfv.numLocations(), numWords);
            for (int ri = 0; ri < max; ++ri) {
                int fi = rfv.getIndexAtRank(ri);
                pout.println("      <phrase weight=\"" + counts[fi] / countssum + "\" count=\"" + values[fi] + "\">" + alph.lookupObject(fi) + "</phrase>");
                if (ri < 20 && values[fi] > 20) {
                    titles.add(alph.lookupObject(fi), (double) (100 * values[fi]));
                }
            }
            StringBuffer titlesStringBuffer = new StringBuffer();
            rfv = new RankedFeatureVector(titles.getAlphabet(), titles);
            // Up to 10 distinct titles. A candidate already contained in the buffer is
            // skipped and the limit extended by one.
            int numTitles = 10;
            for (int ri = 0; ri < numTitles && ri < rfv.numLocations(); ++ri) {
                String candidate = rfv.getObjectAtRank(ri).toString();
                if (titlesStringBuffer.indexOf(candidate) == -1) {
                    titlesStringBuffer.append(candidate);
                    if (ri < numTitles - 1) {
                        titlesStringBuffer.append(", ");
                    }
                } else {
                    ++numTitles;
                }
            }
            pout.flush();
            out.println("titles=\"" + titlesStringBuffer.toString() + "\">");
            // PrintStream does not override toString(); the buffered XML lives in bout.
            out.print(bout.toString());
            out.println("  </topic>");
        }
        out.println("</topics>");
    }

    /**
     * Counts, per topic, phrases formed by runs of two or more consecutive tokens assigned
     * to the same topic. With {@link FeatureSequenceWithBigrams} data, a token additionally
     * extends a run only when {@code getBiIndexAtPosition} is not -1. A token that breaks a
     * run may start the next one. A run still open at the end of a document is counted.
     */
    private TObjectIntHashMap[] countTopicPhrases(int numTopics, Alphabet alphabet) {
        TObjectIntHashMap[] phrases = new TObjectIntHashMap[numTopics];
        for (int ti = 0; ti < numTopics; ++ti) {
            phrases[ti] = new TObjectIntHashMap();
        }
        for (int di = 0; di < this.getData().size(); ++di) {
            Topication t = this.getData().get(di);
            FeatureSequence fvs = (FeatureSequence) t.instance.getData();
            boolean withBigrams = fvs instanceof FeatureSequenceWithBigrams;
            int prevtopic = -1;
            int prevfeature = -1;
            StringBuffer sb = null;
            int doclen = fvs.size();
            for (int pi = 0; pi < doclen; ++pi) {
                int feature = fvs.getIndexAtPosition(pi);
                int topic = t.topicSequence.getIndexAtPosition(pi);
                boolean extendsRun = topic == prevtopic
                        && (!withBigrams || ((FeatureSequenceWithBigrams) fvs).getBiIndexAtPosition(pi) != -1);
                if (extendsRun) {
                    if (sb == null) {
                        sb = new StringBuffer(alphabet.lookupObject(prevfeature).toString() + " " + alphabet.lookupObject(feature));
                    } else {
                        sb.append(" ");
                        sb.append(alphabet.lookupObject(feature));
                    }
                } else {
                    if (sb != null) {
                        incrementPhraseCount(phrases[prevtopic], sb.toString());
                        sb = null;
                    }
                    // The token that broke the run may begin the next phrase.
                    prevtopic = topic;
                    prevfeature = feature;
                }
            }
            if (sb != null) {
                incrementPhraseCount(phrases[prevtopic], sb.toString());
            }
        }
        return phrases;
    }

    /** Adds one occurrence of {@code phrase} to {@code counts}, creating the entry if needed. */
    private static void incrementPhraseCount(TObjectIntHashMap counts, String phrase) {
        if (counts.get(phrase) == 0) {
            counts.put(phrase, 0);
        }
        counts.increment(phrase);
    }

    /**
     * Writes the per-document topic proportions of every training document to {@code f},
     * using the same format as {@link #printDocumentTopics(PrintWriter)}.
     *
     * <p>The writer is always closed. Previously it was never closed, so buffered output
     * could be lost and the file handle leaked.
     *
     * @param f destination file (created or truncated)
     * @throws IOException if the file cannot be opened or any write to it fails
     */
    public void printDocumentTopics(File f) throws IOException {
        PrintWriter pw = new PrintWriter(new FileWriter(f));
        try {
            this.printDocumentTopics(pw);
            // PrintWriter swallows IOExceptions; checkError() flushes and reports them.
            if (pw.checkError()) {
                throw new IOException("Error writing document topics to " + f);
            }
        } finally {
            pw.close();
        }
    }

    /**
     * Prints the topic proportions of every training document, listing all topics
     * with no minimum-proportion cutoff.
     *
     * @param pw destination writer; it is neither flushed nor closed here
     */
    public void printDocumentTopics(PrintWriter pw) {
        final double noMinimumProportion = 0.0;
        final int listAllTopics = -1;
        printDocumentTopics(pw, noMinimumProportion, listAllTopics);
    }

    /**
     * Prints, for each training document, its index, its source, and its topics with the
     * fraction of the document's tokens assigned to each.
     *
     * <p>Each line has the form {@code di source topic proportion topic proportion ...}.
     * Documents without a source are printed as {@code null-source}. Topics appear in the
     * order produced by {@code Arrays.sort} on {@link IDSorter}, presumably descending weight
     * (the threshold cut-off relies on this; TODO confirm). Listing stops after {@code max}
     * topics, or at the first topic whose proportion is below {@code threshold}.
     *
     * <p>Proportions are undefined for a document with no tokens, so no topics are listed for
     * it. Previously 0/0 produced {@code NaN} weights.
     *
     * @param pw destination writer; it is neither flushed nor closed here
     * @param threshold minimum proportion a topic must have to be printed
     * @param max maximum number of topics per document; negative or larger than the number
     *            of topics means "all topics"
     */
    public void printDocumentTopics(PrintWriter pw, double threshold, int max) {
        pw.print("#doc source topic proportion ...\n");
        int[] topicCounts = new int[this.numTopics];
        IDSorter[] sortedTopics = new IDSorter[this.numTopics];
        for (int topic = 0; topic < this.numTopics; topic++) {
            sortedTopics[topic] = new IDSorter(topic, topic);
        }
        if (max < 0 || max > this.numTopics) {
            max = this.numTopics;
        }
        for (int di = 0; di < this.data.size(); di++) {
            Topication document = this.data.get(di);
            int[] currentDocTopics = document.topicSequence.getFeatures();
            pw.print(di);
            pw.print(' ');
            Object source = document.instance.getSource();
            pw.print(source != null ? source : "null-source");
            pw.print(' ');
            int docLen = currentDocTopics.length;
            if (docLen > 0) {
                for (int token = 0; token < docLen; token++) {
                    topicCounts[currentDocTopics[token]]++;
                }
                // The sorter array is reused across documents. Every slot is reassigned
                // with set(topic, weight), so the order left by the previous sort is irrelevant.
                for (int topic = 0; topic < this.numTopics; topic++) {
                    sortedTopics[topic].set(topic, (float) topicCounts[topic] / (float) docLen);
                }
                Arrays.sort(sortedTopics);
                for (int i = 0; i < max; i++) {
                    if (sortedTopics[i].getWeight() < threshold) {
                        break;
                    }
                    pw.print(String.valueOf(sortedTopics[i].getID()) + " " + sortedTopics[i].getWeight() + " ");
                }
                Arrays.fill(topicCounts, 0);
            }
            pw.print(" \n");
        }
    }

    /**
     * Writes the per-token topic assignments of all training documents to {@code f} as
     * gzip-compressed text, using the format of {@link #printState(PrintStream)}.
     *
     * <p>The stream is closed even if printing fails. Write errors, which
     * {@link PrintStream} would otherwise swallow silently (e.g. a full disk), are
     * reported as an {@link IOException}.
     *
     * @param f destination file (created or truncated)
     * @throws IOException if the file cannot be opened or any write to it fails
     */
    public void printState(File f) throws IOException {
        PrintStream out = new PrintStream(new GZIPOutputStream(new BufferedOutputStream(new FileOutputStream(f))));
        try {
            this.printState(out);
            if (out.checkError()) {
                throw new IOException("Error writing sampling state to " + f);
            }
        } finally {
            out.close();
        }
    }

    /**
     * Prints a header line followed by one line per token of every training document:
     * {@code docIndex source position typeIndex type topic}.
     * Documents without a source are labelled {@code NA}.
     *
     * @param out destination stream; it is not closed here
     */
    public void printState(PrintStream out) {
        out.println("#doc source pos typeindex type topic");
        for (int docIndex = 0; docIndex < this.data.size(); docIndex++) {
            Topication document = this.data.get(docIndex);
            FeatureSequence tokens = (FeatureSequence) document.instance.getData();
            LabelSequence topics = document.topicSequence;
            Object rawSource = document.instance.getSource();
            String source = (rawSource == null) ? "NA" : rawSource.toString();
            for (int position = 0; position < topics.getLength(); position++) {
                int typeIndex = tokens.getIndexAtPosition(position);
                int topicIndex = topics.getIndexAtPosition(position);
                StringBuilder line = new StringBuilder();
                line.append(docIndex).append(' ')
                    .append(source).append(' ')
                    .append(position).append(' ')
                    .append(typeIndex).append(' ')
                    .append(this.alphabet.lookupObject(typeIndex)).append(' ')
                    .append(topicIndex);
                out.println(line);
            }
        }
    }

    /**
     * Serializes this model to {@code f} using Java object serialization.
     *
     * <p>Failures are reported on standard error rather than thrown, as before. The
     * underlying file stream is now closed even when serialization fails part-way.
     *
     * @param f destination file (created or truncated)
     */
    public void write(File f) {
        FileOutputStream fileOut = null;
        try {
            fileOut = new FileOutputStream(f);
            ObjectOutputStream oos = new ObjectOutputStream(fileOut);
            oos.writeObject(this);
            // Closing the object stream flushes it and closes fileOut as well.
            oos.close();
            fileOut = null;
        }
        catch (IOException e) {
            System.err.println("LDAHyper.write: Exception writing LDAHyper to file " + f + ": " + e);
        }
        finally {
            if (fileOut != null) {
                try {
                    fileOut.close();
                }
                catch (IOException ignored) {
                    // The primary failure has already been reported above.
                }
            }
        }
    }

    /**
     * Reads a model previously saved with {@link #write(File)}, then calls
     * {@code initializeTypeTopicCounts()} on it.
     *
     * <p>SECURITY NOTE: this uses Java native deserialization. Only read files from
     * trusted sources.
     *
     * @param f file to read
     * @return the model, or {@code null} if it could not be read. Errors are reported on
     *         standard error. The underlying stream is now closed on every path.
     */
    public static LDAHyper read(File f) {
        LDAHyper lda = null;
        FileInputStream fileIn = null;
        try {
            fileIn = new FileInputStream(f);
            ObjectInputStream ois = new ObjectInputStream(fileIn);
            lda = (LDAHyper) ois.readObject();
            lda.initializeTypeTopicCounts();
            // Closing the object stream also closes fileIn.
            ois.close();
            fileIn = null;
        }
        catch (IOException e) {
            System.err.println("Exception reading file " + f + ": " + e);
        }
        catch (ClassNotFoundException e) {
            System.err.println("Exception reading file " + f + ": " + e);
        }
        finally {
            if (fileIn != null) {
                try {
                    fileIn.close();
                }
                catch (IOException ignored) {
                    // Nothing more useful to do; any primary failure was already reported.
                }
            }
        }
        return lda;
    }

    /**
     * Custom Java serialization for this model.
     *
     * <p>Writes a leading format-version int (0), then the fields below in a fixed order.
     * {@code readObject} reads them back in exactly the same order, so any change here
     * is a change to the serialized format.
     *
     * @param out the stream supplied by Java serialization
     * @throws IOException if writing to {@code out} fails
     */
    private void writeObject(ObjectOutputStream out) throws IOException {
        // Serialized-format version; readObject reads this int back first.
        out.writeInt(0);
        out.writeObject(this.data);
        out.writeObject(this.alphabet);
        out.writeObject(this.topicAlphabet);
        out.writeInt(this.numTopics);
        // NOTE(review): alphaSum and numTypes are not written. readObject derives numTypes
        // from alphabet.size(); whether alphaSum is restored is not visible in this method.
        out.writeObject(this.alpha);
        out.writeDouble(this.beta);
        out.writeDouble(this.betaSum);
        out.writeDouble(this.smoothingOnlyMass);
        out.writeObject(this.cachedCoefficients);
        out.writeInt(this.iterationsSoFar);
        out.writeInt(this.numIterations);
        out.writeInt(this.burninPeriod);
        out.writeInt(this.saveSampleInterval);
        out.writeInt(this.optimizeInterval);
        out.writeInt(this.showTopicsInterval);
        out.writeInt(this.wordsPerTopic);
        out.writeInt(this.outputModelInterval);
        out.writeObject(this.outputModelFilename);
        out.writeInt(this.saveStateInterval);
        out.writeObject(this.stateFilename);
        out.writeObject(this.random);
        out.writeObject(this.formatter);
        out.writeBoolean(this.printLogLikelihood);
        out.writeObject(this.docLengthCounts);
        out.writeObject(this.topicDocCounts);
        // One count map per word type. readObject reads back alphabet.size() maps, so this
        // assumes numTypes == alphabet.size() at write time (TODO confirm).
        int fi = 0;
        while (fi < this.numTypes) {
            out.writeObject(this.typeTopicCounts[fi]);
            ++fi;
        }
        // Per-topic token totals, written as raw ints rather than as an array object.
        int ti = 0;
        while (ti < this.numTopics) {
            out.writeInt(this.tokensPerTopic[ti]);
            ++ti;
        }
    }

    /**
     * Custom Java deserialization that mirrors {@code writeObject} field for field.
     *
     * <p>Two fields are not part of the serialized form and are rebuilt here:
     * {@code numTypes}, taken from the alphabet size, and {@code alphaSum}, recomputed as
     * the sum of {@code alpha}. Previously {@code alphaSum} stayed 0.0 after deserialization,
     * which corrupted {@code modelLogLikelihood()}, since that method uses it as the
     * document-topic Dirichlet parameter sum.
     *
     * @param in the stream supplied by Java serialization
     * @throws IOException if reading fails
     * @throws ClassNotFoundException if a serialized class cannot be resolved
     */
    private void readObject(ObjectInputStream in) throws IOException, ClassNotFoundException {
        int version = in.readInt(); // serialized-format version; currently only version 0 exists
        @SuppressWarnings("unchecked") // writeObject stores this.data, an ArrayList<Topication>
        ArrayList<Topication> restoredData = (ArrayList<Topication>) in.readObject();
        this.data = restoredData;
        this.alphabet = (Alphabet) in.readObject();
        this.topicAlphabet = (LabelAlphabet) in.readObject();
        this.numTopics = in.readInt();
        this.alpha = (double[]) in.readObject();
        this.alphaSum = (this.alpha == null) ? 0.0 : MatrixOps.sum(this.alpha);
        this.beta = in.readDouble();
        this.betaSum = in.readDouble();
        this.smoothingOnlyMass = in.readDouble();
        this.cachedCoefficients = (double[]) in.readObject();
        this.iterationsSoFar = in.readInt();
        this.numIterations = in.readInt();
        this.burninPeriod = in.readInt();
        this.saveSampleInterval = in.readInt();
        this.optimizeInterval = in.readInt();
        this.showTopicsInterval = in.readInt();
        this.wordsPerTopic = in.readInt();
        this.outputModelInterval = in.readInt();
        this.outputModelFilename = (String) in.readObject();
        this.saveStateInterval = in.readInt();
        this.stateFilename = (String) in.readObject();
        this.random = (Randoms) in.readObject();
        this.formatter = (NumberFormat) in.readObject();
        this.printLogLikelihood = in.readBoolean();
        this.docLengthCounts = (int[]) in.readObject();
        this.topicDocCounts = (int[][]) in.readObject();
        this.numTypes = this.alphabet.size();
        this.typeTopicCounts = new TIntIntHashMap[this.numTypes];
        for (int fi = 0; fi < this.numTypes; fi++) {
            this.typeTopicCounts[fi] = (TIntIntHashMap) in.readObject();
        }
        this.tokensPerTopic = new int[this.numTopics];
        for (int ti = 0; ti < this.numTopics; ti++) {
            this.tokensPerTopic[ti] = in.readInt();
        }
    }

    /**
     * Returns the mutual information, in bits, between the topic assigned to each training
     * token and the best label of that token's document.
     *
     * <p>Every token contributes one (topic, label) observation. The label is
     * {@code instance.getLabeling().getBestIndex()} of its document. The result is
     * {@code H(topic) + H(label) - H(topic, label)} over those observations.
     *
     * @return the mutual information in bits. Returns 0.0 when there are no documents
     *         (previously an IndexOutOfBoundsException) or when the first document has
     *         no target alphabet.
     */
    public double topicLabelMutualInformation() {
        if (this.data.isEmpty() || this.data.get(0).instance.getTargetAlphabet() == null) {
            return 0.0;
        }
        int targetAlphabetSize = this.data.get(0).instance.getTargetAlphabet().size();
        int[][] topicLabelCounts = new int[this.numTopics][targetAlphabetSize];
        int[] topicCounts = new int[this.numTopics];
        int[] labelCounts = new int[targetAlphabetSize];
        int total = 0;
        for (int doc = 0; doc < this.data.size(); doc++) {
            Topication document = this.data.get(doc);
            int label = document.instance.getLabeling().getBestIndex();
            int[] docTopics = document.topicSequence.getFeatures();
            for (int topic : docTopics) {
                topicLabelCounts[topic][label]++;
                topicCounts[topic]++;
                labelCounts[label]++;
            }
            total += docTopics.length;
        }
        double topicEntropy = accumulateEntropyBits(0.0, topicCounts, total);
        double labelEntropy = accumulateEntropyBits(0.0, labelCounts, total);
        // Accumulate row by row so the floating-point summation order matches a single
        // topic-major pass over the joint table.
        double jointEntropy = 0.0;
        for (int[] row : topicLabelCounts) {
            jointEntropy = accumulateEntropyBits(jointEntropy, row, total);
        }
        return topicEntropy + labelEntropy - jointEntropy;
    }

    /**
     * Subtracts {@code p * log2(p)} from {@code entropy} for every nonzero count, with
     * {@code p = count / total}, and returns the result. Zero counts contribute nothing.
     */
    private static double accumulateEntropyBits(double entropy, int[] counts, int total) {
        double log2 = Math.log(2.0);
        for (int count : counts) {
            if (count != 0) {
                double p = (double) count / (double) total;
                entropy -= p * Math.log(p) / log2;
            }
        }
        return entropy;
    }

    /**
     * Estimates the log likelihood of the documents in {@code testing} by Monte Carlo
     * sampling of document-topic distributions from the Dirichlet prior {@code alpha}.
     *
     * <p>For each sample, a topic distribution is drawn and mixed with the beta-smoothed
     * topic-word estimates {@code (beta + n_wt) / (betaSum + n_t)}. The result is a
     * per-type word distribution that scores every test token. Each document's
     * log-likelihood is the log of the mean likelihood over samples, computed stably with
     * log-sum-exp. Token type indices {@code >= numTypes} (outside the model vocabulary)
     * are skipped.
     *
     * @param numSamples number of prior samples; must be positive
     * @param testing documents whose data are {@link FeatureSequence}s
     * @return the sum over documents of their estimated log-likelihoods
     * @throws IllegalArgumentException if {@code numSamples <= 0}. Previously 0 gave NaN
     *         and a negative value gave NegativeArraySizeException.
     */
    public double empiricalLikelihood(int numSamples, InstanceList testing) {
        if (numSamples <= 0) {
            throw new IllegalArgumentException("numSamples must be positive, got " + numSamples);
        }
        double[][] likelihoods = new double[testing.size()][numSamples];
        double[] multinomial = new double[this.numTypes];
        Dirichlet topicPrior = new Dirichlet(this.alpha);
        for (int sample = 0; sample < numSamples; sample++) {
            double[] topicDistribution = topicPrior.nextDistribution();
            Arrays.fill(multinomial, 0.0);
            for (int topic = 0; topic < this.numTopics; topic++) {
                for (int type = 0; type < this.numTypes; type++) {
                    multinomial[type] += topicDistribution[topic] * (this.beta + (double) this.typeTopicCounts[type].get(topic)) / (this.betaSum + (double) this.tokensPerTopic[topic]);
                }
            }
            for (int type = 0; type < this.numTypes; type++) {
                assert (multinomial[type] > 0.0);
                multinomial[type] = Math.log(multinomial[type]);
            }
            for (int doc = 0; doc < testing.size(); doc++) {
                FeatureSequence fs = (FeatureSequence) ((Instance) testing.get(doc)).getData();
                int seqLen = fs.getLength();
                for (int token = 0; token < seqLen; token++) {
                    int type = fs.getIndexAtPosition(token);
                    if (type < this.numTypes) {
                        likelihoods[doc][sample] += multinomial[type];
                    }
                }
            }
        }
        double logNumSamples = Math.log(numSamples);
        double totalLogLikelihood = 0.0;
        for (int doc = 0; doc < testing.size(); doc++) {
            // log((1/S) * sum_s exp(l_s)), shifted by max to avoid underflow.
            double max = Double.NEGATIVE_INFINITY;
            for (int sample = 0; sample < numSamples; sample++) {
                if (likelihoods[doc][sample] > max) {
                    max = likelihoods[doc][sample];
                }
            }
            double sum = 0.0;
            for (int sample = 0; sample < numSamples; sample++) {
                sum += Math.exp(likelihoods[doc][sample] - max);
            }
            totalLogLikelihood += Math.log(sum) + max - logNumSamples;
        }
        return totalLogLikelihood;
    }

    /**
     * Returns the log joint probability of the current topic assignments and the training
     * words, log p(z) + log p(w | z), with the multinomials integrated out.
     *
     * <p>Both parts are Dirichlet-multinomial (Polya) likelihoods:
     * <pre>
     *   log p(z)     = sum_d [ lnG(sum_k a_k) - lnG(sum_k a_k + N_d)
     *                          + sum_k ( lnG(a_k + n_dk) - lnG(a_k) ) ]
     *   log p(w | z) = sum_k [ lnG(V b) - lnG(V b + n_k)
     *                          + sum_w ( lnG(b + n_wk) - lnG(b) ) ]
     * </pre>
     * Here V is the number of word types ({@code numTypes}) and b is {@code beta}. Zero
     * counts contribute zero to the inner sums, so only nonzero entries are visited.
     *
     * <p>Fix: the topic-word prior sum is {@code beta * numTypes}, not
     * {@code beta * numTopics}, and its {@code lnG(V b)} normalizer is added once per topic
     * rather than once overall.
     *
     * @return the log likelihood, computed with {@code Dirichlet.logGammaStirling}
     */
    public double modelLogLikelihood() {
        double logLikelihood = 0.0;

        // Document-topic part: one Dirichlet-multinomial per document with parameters alpha.
        int[] topicCounts = new int[this.numTopics];
        double[] topicLogGammas = new double[this.numTopics];
        for (int topic = 0; topic < this.numTopics; topic++) {
            topicLogGammas[topic] = Dirichlet.logGammaStirling(this.alpha[topic]);
        }
        for (int doc = 0; doc < this.data.size(); doc++) {
            int[] docTopics = this.data.get(doc).topicSequence.getFeatures();
            for (int topic : docTopics) {
                topicCounts[topic]++;
            }
            for (int topic = 0; topic < this.numTopics; topic++) {
                if (topicCounts[topic] > 0) {
                    logLikelihood += Dirichlet.logGammaStirling(this.alpha[topic] + (double) topicCounts[topic]) - topicLogGammas[topic];
                }
            }
            logLikelihood -= Dirichlet.logGammaStirling(this.alphaSum + (double) docTopics.length);
            Arrays.fill(topicCounts, 0);
        }
        logLikelihood += (double) this.data.size() * Dirichlet.logGammaStirling(this.alphaSum);

        // Topic-word part: one symmetric Dirichlet-multinomial per topic over numTypes types.
        int nonZeroTypeTopics = 0;
        for (int type = 0; type < this.numTypes; type++) {
            TIntIntHashMap counts = this.typeTopicCounts[type];
            for (int topic : counts.keys()) {
                int count = counts.get(topic);
                if (count > 0) {
                    nonZeroTypeTopics++;
                    logLikelihood += Dirichlet.logGammaStirling(this.beta + (double) count);
                }
            }
        }
        double topicWordPriorSum = this.beta * (double) this.numTypes;
        for (int topic = 0; topic < this.numTopics; topic++) {
            logLikelihood -= Dirichlet.logGammaStirling(topicWordPriorSum + (double) this.tokensPerTopic[topic]);
        }
        logLikelihood += Dirichlet.logGammaStirling(topicWordPriorSum) * (double) this.numTopics;
        logLikelihood -= Dirichlet.logGammaStirling(this.beta) * (double) nonZeroTypeTopics;
        return logLikelihood;
    }

    /**
     * Command-line entry point that trains an LDAHyper model.
     *
     * <p>Arguments: {@code training-instances [num-topics (default 200)] [testing-instances]}.
     * Prints a usage message and returns when no arguments are given. Previously this
     * failed with an ArrayIndexOutOfBoundsException.
     *
     * @param args command-line arguments as described above
     * @throws IOException if an instance file cannot be loaded
     */
    public static void main(String[] args) throws IOException {
        if (args.length < 1) {
            System.err.println("Usage: LDAHyper <training-instances> [num-topics] [testing-instances]");
            return;
        }
        InstanceList training = InstanceList.load(new File(args[0]));
        int numTopics = args.length > 1 ? Integer.parseInt(args[1]) : 200;
        // NOTE(review): the testing list is loaded but never passed to the model;
        // confirm whether it should be wired in.
        InstanceList testing = args.length > 2 ? InstanceList.load(new File(args[2])) : null;
        LDAHyper lda = new LDAHyper(numTopics, 50.0, 0.01);
        lda.printLogLikelihood = true;
        lda.setTopicDisplay(50, 7);
        lda.addInstances(training);
        lda.estimate();
    }

    /**
     * Pairs a training {@link Instance} with its per-token topic assignments.
     *
     * <p>This is a non-static inner class, so each instance also holds an implicit
     * reference to its enclosing LDAHyper, in addition to the explicit {@code model}
     * field. Custom serialization writes a version int followed by {@code instance},
     * {@code model}, {@code topicSequence} and {@code topicDistribution}, in that order.
     */
    public class Topication
    implements Serializable {
        public Instance instance;
        public LDAHyper model;
        public LabelSequence topicSequence;
        public Labeling topicDistribution;
        private static final long serialVersionUID = 1L;
        private static final int CURRENT_SERIAL_VERSION = 0;

        /**
         * @param instance the document this entry describes
         * @param model the model the assignments belong to
         * @param topicSequence one topic label per token of {@code instance}
         */
        public Topication(Instance instance, LDAHyper model, LabelSequence topicSequence) {
            this.topicSequence = topicSequence;
            this.model = model;
            this.instance = instance;
        }

        /** Writes the format version, then the four public fields in declaration order. */
        private void writeObject(ObjectOutputStream stream) throws IOException {
            stream.writeInt(CURRENT_SERIAL_VERSION);
            stream.writeObject(this.instance);
            stream.writeObject(this.model);
            stream.writeObject(this.topicSequence);
            stream.writeObject(this.topicDistribution);
        }

        /** Reads back exactly what {@code writeObject} wrote, in the same order. */
        private void readObject(ObjectInputStream stream) throws IOException, ClassNotFoundException {
            int serialVersion = stream.readInt(); // only version 0 exists so far
            this.instance = (Instance) stream.readObject();
            this.model = (LDAHyper) stream.readObject();
            this.topicSequence = (LabelSequence) stream.readObject();
            this.topicDistribution = (Labeling) stream.readObject();
        }
    }
}

