/*
 * Decompiled with CFR 0.152.
 */
package cc.mallet.topics;

import cc.mallet.topics.TopicAssignment;
import cc.mallet.types.Alphabet;
import cc.mallet.types.AlphabetFactory;
import cc.mallet.types.FeatureSequence;
import cc.mallet.types.IDSorter;
import cc.mallet.types.Instance;
import cc.mallet.types.InstanceList;
import cc.mallet.types.LabelAlphabet;
import cc.mallet.types.LabelSequence;
import cc.mallet.util.MalletLogger;
import cc.mallet.util.Randoms;
import gnu.trove.TIntIntHashMap;
import java.io.BufferedOutputStream;
import java.io.File;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.PrintStream;
import java.io.Serializable;
import java.text.NumberFormat;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.logging.Logger;
import java.util.zip.GZIPOutputStream;

/**
 * Gibbs-style sampler for a topic model in which the number of topics is not
 * fixed: a topic is dropped as soon as no token is assigned to it, and a brand
 * new topic can be created for any token during sampling (logged as
 * "Non-Parametric LDA").
 *
 * <p>For each token the sampler scores every live topic {@code t} as
 * <pre>
 *   (n_dt + alpha * docs_t / (totalDocTopics + gamma)) * (n_wt + beta) / (n_t + betaSum)
 * </pre>
 * and the "open a new topic" option as
 * <pre>
 *   alpha * gamma / (numTypes * (totalDocTopics + gamma))
 * </pre>
 * where {@code n_dt} is the count of topic {@code t} in the current document,
 * {@code docs_t} the number of documents using {@code t}, {@code n_wt} the
 * count of the current word type in {@code t} and {@code n_t} the number of
 * tokens in {@code t}.
 *
 * <p>Instances are not thread-safe: all counts are plain mutable fields.
 */
public class NPTopicModel implements Serializable {
    private static final Logger logger = MalletLogger.getLogger(NPTopicModel.class.getName());

    /** Documents together with their current per-token topic assignments. */
    protected ArrayList<TopicAssignment> data = new ArrayList<>();
    /** Vocabulary of the training data; set by {@link #addInstances}. */
    protected Alphabet alphabet;
    /** Alphabet handed to every {@link LabelSequence} holding topic ids. */
    protected LabelAlphabet topicAlphabet = AlphabetFactory.labelAlphabetOfSize(1);
    /** Largest topic id handed out so far; new topics get {@code maxTopic + 1}. */
    protected int maxTopic;
    /** Number of live topics, i.e. the number of keys in {@link #docsPerTopic}. */
    protected int numTopics;
    /** Vocabulary size. */
    protected int numTypes;
    /** Weight of the corpus-level topic proportions in the per-token score. */
    protected double alpha;
    /** Added to {@link #totalDocTopics} in the score denominators; scales the new-topic score. */
    protected double gamma;
    /** Smoothing added to every type/topic count. */
    protected double beta;
    /** {@code beta * numTypes}. */
    protected double betaSum;
    public static final double DEFAULT_BETA = 0.01;
    /** Indexed by word type: topic id -> number of tokens of that type in the topic. */
    protected TIntIntHashMap[] typeTopicCounts;
    /** Topic id -> total number of tokens assigned to the topic. */
    protected TIntIntHashMap tokensPerTopic;
    /** Topic id -> number of documents containing at least one token of the topic. */
    protected TIntIntHashMap docsPerTopic;
    /** Sum of all values in {@link #docsPerTopic}. */
    protected int totalDocTopics = 0;
    /** Log the top words every this many iterations; 0 disables the report. */
    public int showTopicsInterval = 50;
    /** Maximum number of words listed per topic in the report. */
    public int wordsPerTopic = 10;
    protected Randoms random;
    protected NumberFormat formatter;
    protected boolean printLogLikelihood = false;

    /**
     * Creates an empty model; call {@link #addInstances} before {@link #sample}.
     *
     * @param alpha weight of the corpus-level topic proportions
     * @param gamma new-topic parameter
     * @param beta  per-type smoothing of the topic-word counts
     */
    public NPTopicModel(double alpha, double gamma, double beta) {
        this.alpha = alpha;
        this.gamma = gamma;
        this.beta = beta;
        this.random = new Randoms();
        this.tokensPerTopic = new TIntIntHashMap();
        this.docsPerTopic = new TIntIntHashMap();
        this.formatter = NumberFormat.getInstance();
        this.formatter.setMaximumFractionDigits(5);
        logger.info("Non-Parametric LDA");
    }

    /**
     * Configures the periodic top-words report written by {@link #sample}.
     *
     * @param interval number of iterations between reports (0 disables them)
     * @param n        maximum number of words shown per topic
     */
    public void setTopicDisplay(int interval, int n) {
        this.showTopicsInterval = interval;
        this.wordsPerTopic = n;
    }

    /**
     * Replaces the random number generator with one seeded by {@code seed}.
     *
     * @param seed seed for the new generator
     */
    public void setRandomSeed(int seed) {
        this.random = new Randoms(seed);
    }

    /**
     * Loads the training documents and assigns every token a topic drawn
     * uniformly from {@code [0, initialTopics)}, building all count tables.
     *
     * <p>Random initialisation may leave some of the {@code initialTopics} ids
     * unused. {@link #numTopics} is therefore set to the number of topics that
     * actually received tokens, so that it always matches the key set of
     * {@link #docsPerTopic} which the sampler iterates over.
     *
     * @param training      instances whose data is a {@link FeatureSequence}
     * @param initialTopics number of topic ids to draw initial assignments from
     */
    public void addInstances(InstanceList training, int initialTopics) {
        this.alphabet = training.getDataAlphabet();
        this.numTypes = this.alphabet.size();
        this.betaSum = this.beta * (double) this.numTypes;

        this.typeTopicCounts = new TIntIntHashMap[this.numTypes];
        for (int type = 0; type < this.numTypes; ++type) {
            this.typeTopicCounts[type] = new TIntIntHashMap();
        }

        for (Instance instance : training) {
            // Topics already seen in this document, so docsPerTopic is bumped once per (doc, topic).
            TIntIntHashMap docTopicCounts = new TIntIntHashMap();
            FeatureSequence tokens = (FeatureSequence) instance.getData();
            LabelSequence topicSequence = new LabelSequence(this.topicAlphabet, new int[tokens.size()]);
            // Writes to this array are relied upon to be visible through topicSequence.
            int[] topics = topicSequence.getFeatures();

            for (int position = 0; position < tokens.size(); ++position) {
                int topic = this.random.nextInt(initialTopics);
                topics[position] = topic;
                this.tokensPerTopic.adjustOrPutValue(topic, 1, 1);

                if (docTopicCounts.containsKey(topic)) {
                    docTopicCounts.adjustValue(topic, 1);
                } else {
                    docTopicCounts.put(topic, 1);
                    this.docsPerTopic.adjustOrPutValue(topic, 1, 1);
                    ++this.totalDocTopics;
                }

                int type = tokens.getIndexAtPosition(position);
                this.typeTopicCounts[type].adjustOrPutValue(topic, 1, 1);
            }
            this.data.add(new TopicAssignment(instance, topicSequence));
        }

        // New topic ids continue after the initial id range so they can never collide.
        this.maxTopic = initialTopics - 1;
        // Count only topics that are really in use; the sampler indexes docsPerTopic.keys()
        // with 0..numTopics-1 and would run off the end if unused ids were counted.
        this.numTopics = this.docsPerTopic.size();
    }

    /**
     * Runs the sampler over all documents for the given number of iterations,
     * logging the elapsed time and current topic count after each sweep and
     * the top words every {@link #showTopicsInterval} iterations.
     *
     * @param iterations number of full sweeps over the corpus
     * @throws IOException declared for interface compatibility; nothing in this
     *                     method currently raises it
     */
    public void sample(int iterations) throws IOException {
        for (int iteration = 1; iteration <= iterations; ++iteration) {
            long iterationStart = System.currentTimeMillis();

            for (TopicAssignment document : this.data) {
                FeatureSequence tokenSequence = (FeatureSequence) document.instance.getData();
                this.sampleTopicsForOneDoc(tokenSequence, document.topicSequence);
            }

            long elapsedMillis = System.currentTimeMillis() - iterationStart;
            logger.info(iteration + "\t" + elapsedMillis + "ms\t" + this.numTopics);

            if (this.showTopicsInterval != 0 && iteration % this.showTopicsInterval == 0) {
                logger.info("<" + iteration + "> #Topics: " + this.numTopics + "\n" + this.topWords(this.wordsPerTopic));
            }
        }
    }

    /**
     * Resamples the topic of every token in one document, updating all count
     * tables as it goes. Topics that lose their last token are removed; a
     * token may also open a brand new topic.
     *
     * @param tokenSequence word-type indices of the document
     * @param topicSequence current topic assignments; overwritten through the
     *                      array returned by {@code getFeatures()}
     */
    protected void sampleTopicsForOneDoc(FeatureSequence tokenSequence, FeatureSequence topicSequence) {
        int[] topics = topicSequence.getFeatures();
        int docLength = tokenSequence.getLength();

        // Topic id -> number of tokens in this document currently assigned to it.
        TIntIntHashMap localTopicCounts = new TIntIntHashMap();
        for (int position = 0; position < docLength; ++position) {
            localTopicCounts.adjustOrPutValue(topics[position], 1, 1);
        }

        // One slot per live topic plus a final slot for "create a new topic".
        double[] topicTermScores = new double[this.numTopics + 1];
        int[] allTopics = this.docsPerTopic.keys();

        for (int position = 0; position < docLength; ++position) {
            int type = tokenSequence.getIndexAtPosition(position);
            int oldTopic = topics[position];
            TIntIntHashMap currentTypeTopicCounts = this.typeTopicCounts[type];

            // --- Remove the token from all counts. ---
            if (localTopicCounts.get(oldTopic) == 1) {
                // Last token of this topic in the document.
                localTopicCounts.remove(oldTopic);
                --this.totalDocTopics;
                if (this.docsPerTopic.get(oldTopic) == 1) {
                    // ... and this was the only document using the topic: retire it.
                    assert (this.tokensPerTopic.get(oldTopic) == 1);
                    this.docsPerTopic.remove(oldTopic);
                    this.tokensPerTopic.remove(oldTopic);
                    --this.numTopics;
                    allTopics = this.docsPerTopic.keys();
                    topicTermScores = new double[this.numTopics + 1];
                } else {
                    this.docsPerTopic.adjustValue(oldTopic, -1);
                    this.tokensPerTopic.adjustValue(oldTopic, -1);
                }
            } else {
                localTopicCounts.adjustValue(oldTopic, -1);
                this.tokensPerTopic.adjustValue(oldTopic, -1);
            }

            if (currentTypeTopicCounts.get(oldTopic) == 1) {
                currentTypeTopicCounts.remove(oldTopic);
            } else {
                currentTypeTopicCounts.adjustValue(oldTopic, -1);
            }

            // --- Score every live topic and the new-topic option. ---
            double docTopicNorm = (double) this.totalDocTopics + this.gamma;
            double sum = 0.0;
            for (int i = 0; i < this.numTopics; ++i) {
                int topic = allTopics[i];
                topicTermScores[i] =
                        ((double) localTopicCounts.get(topic)
                                + this.alpha * ((double) this.docsPerTopic.get(topic) / docTopicNorm))
                        * ((double) currentTypeTopicCounts.get(topic) + this.beta)
                        / ((double) this.tokensPerTopic.get(topic) + this.betaSum);
                sum += topicTermScores[i];
            }
            topicTermScores[this.numTopics] =
                    this.alpha * this.gamma / ((double) this.numTypes * docTopicNorm);
            sum += topicTermScores[this.numTopics];

            // --- Draw a slot proportionally to its score. ---
            // Slot 0 is always inspected (so a draw of exactly 0.0 cannot yield index -1) and
            // the index is capped at numTopics (so round-off cannot step past the last slot).
            double remaining = this.random.nextUniform() * sum - topicTermScores[0];
            int chosen = 0;
            while (remaining > 0.0 && chosen < this.numTopics) {
                ++chosen;
                remaining -= topicTermScores[chosen];
            }

            // --- Add the token back under its new topic. ---
            if (chosen < this.numTopics) {
                int newTopic = allTopics[chosen];
                topics[position] = newTopic;
                currentTypeTopicCounts.adjustOrPutValue(newTopic, 1, 1);
                this.tokensPerTopic.adjustValue(newTopic, 1);
                if (localTopicCounts.containsKey(newTopic)) {
                    localTopicCounts.adjustValue(newTopic, 1);
                } else {
                    localTopicCounts.put(newTopic, 1);
                    this.docsPerTopic.adjustValue(newTopic, 1);
                    ++this.totalDocTopics;
                }
            } else {
                // The extra slot was drawn: open a new topic with a fresh id.
                int newTopic = this.maxTopic + 1;
                this.maxTopic = newTopic;
                ++this.numTopics;
                topics[position] = newTopic;
                localTopicCounts.put(newTopic, 1);
                this.docsPerTopic.put(newTopic, 1);
                ++this.totalDocTopics;
                currentTypeTopicCounts.put(newTopic, 1);
                this.tokensPerTopic.put(newTopic, 1);
                allTopics = this.docsPerTopic.keys();
                topicTermScores = new double[this.numTopics + 1];
            }
        }
    }

    /**
     * Builds a report with one line per live topic:
     * {@code <topic id> TAB <token count> TAB <word> <word> ...}, listing words
     * in the order produced by sorting {@link IDSorter}s built from the
     * type/topic counts, and stopping at the first word whose count is below 1.
     *
     * @param numWords maximum number of words per topic; values larger than the
     *                 vocabulary size are capped to the vocabulary size
     * @return the multi-line report
     */
    public String topWords(int numWords) {
        StringBuilder output = new StringBuilder();
        IDSorter[] sortedWords = new IDSorter[this.numTypes];
        // Never read past the end of sortedWords when the vocabulary is small.
        int limit = Math.min(numWords, this.numTypes);

        for (int topic : this.docsPerTopic.keys()) {
            for (int type = 0; type < this.numTypes; ++type) {
                sortedWords[type] = new IDSorter(type, this.typeTopicCounts[type].get(topic));
            }
            Arrays.sort(sortedWords);

            output.append(topic).append("\t").append(this.tokensPerTopic.get(topic)).append("\t");
            for (int i = 0; i < limit; ++i) {
                if (sortedWords[i].getWeight() < 1.0) {
                    break;
                }
                output.append(this.alphabet.lookupObject(sortedWords[i].getID())).append(" ");
            }
            output.append("\n");
        }
        return output.toString();
    }

    /**
     * Writes the sampling state, gzip-compressed, to a file.
     *
     * @param f destination file
     * @throws IOException if the file cannot be opened or a write error is
     *                     detected on the stream
     */
    public void printState(File f) throws IOException {
        // fileOut is listed separately so it is closed even if building the wrappers fails.
        try (FileOutputStream fileOut = new FileOutputStream(f);
             PrintStream out = new PrintStream(new GZIPOutputStream(new BufferedOutputStream(fileOut)))) {
            this.printState(out);
            // PrintStream never throws on write; surface any swallowed error explicitly.
            if (out.checkError()) {
                throw new IOException("Error while writing sampling state to " + f);
            }
        }
    }

    /**
     * Writes one line per token: document index, source ("NA" when the
     * instance has none), position, type index, type and topic id, separated
     * by single spaces, preceded by a header line.
     *
     * @param out destination stream; not closed by this method
     */
    public void printState(PrintStream out) {
        out.println("#doc source pos typeindex type topic");

        for (int doc = 0; doc < this.data.size(); ++doc) {
            TopicAssignment document = this.data.get(doc);
            FeatureSequence tokenSequence = (FeatureSequence) document.instance.getData();
            LabelSequence topicSequence = document.topicSequence;

            String source = "NA";
            if (document.instance.getSource() != null) {
                source = document.instance.getSource().toString();
            }

            for (int position = 0; position < topicSequence.getLength(); ++position) {
                int type = tokenSequence.getIndexAtPosition(position);
                int topic = topicSequence.getIndexAtPosition(position);
                out.print(doc);
                out.print(' ');
                out.print(source);
                out.print(' ');
                out.print(position);
                out.print(' ');
                out.print(type);
                out.print(' ');
                out.print(this.alphabet.lookupObject(type));
                out.print(' ');
                out.print(topic);
                out.println();
            }
        }
    }

    /**
     * Command-line entry point: loads a serialized instance list and runs
     * 1000 sampling iterations with alpha=5.0, gamma=10.0, beta=0.1.
     *
     * @param args {@code args[0]} is the instance-list file; optional
     *             {@code args[1]} is the initial number of topics (default 200)
     * @throws IOException propagated from {@link #sample}
     */
    public static void main(String[] args) throws IOException {
        if (args.length < 1) {
            System.err.println("Usage: NPTopicModel <instance-list-file> [initial-topics]");
            return;
        }
        InstanceList training = InstanceList.load(new File(args[0]));
        int numTopics = args.length > 1 ? Integer.parseInt(args[1]) : 200;
        NPTopicModel lda = new NPTopicModel(5.0, 10.0, 0.1);
        lda.addInstances(training, numTopics);
        lda.sample(1000);
    }
}

