/*
 * Decompiled with CFR 0.152.
 */
package cc.mallet.topics;

import cc.mallet.topics.TopicAssignment;
import cc.mallet.types.Alphabet;
import cc.mallet.types.Dirichlet;
import cc.mallet.types.FeatureSequence;
import cc.mallet.types.IDSorter;
import cc.mallet.types.Instance;
import cc.mallet.types.InstanceList;
import cc.mallet.types.LabelAlphabet;
import cc.mallet.types.LabelSequence;
import cc.mallet.util.MalletLogger;
import cc.mallet.util.Randoms;
import java.io.BufferedOutputStream;
import java.io.File;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.PrintStream;
import java.io.PrintWriter;
import java.io.Serializable;
import java.text.NumberFormat;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.logging.Logger;
import java.util.zip.GZIPOutputStream;

/**
 * A simple, unoptimized Gibbs sampler for Latent Dirichlet Allocation with
 * symmetric Dirichlet priors on document-topic ({@code alpha}) and
 * topic-word ({@code beta}) distributions.
 *
 * <p>Typical use: construct, call {@link #addInstances(InstanceList)} once,
 * then call {@link #sample(int)}.
 */
public class SimpleLDA
implements Serializable {
    private static final Logger logger = MalletLogger.getLogger(SimpleLDA.class.getName());
    /** One entry per training document: the instance plus its per-token topic assignments. */
    protected ArrayList<TopicAssignment> data = new ArrayList<TopicAssignment>();
    /** Data alphabet of the training instances; set by {@link #addInstances(InstanceList)}. */
    protected Alphabet alphabet;
    protected LabelAlphabet topicAlphabet;
    protected int numTopics;
    /** Size of {@link #alphabet} at the time instances were added. */
    protected int numTypes;
    /** Per-topic Dirichlet parameter, {@code alphaSum / numTopics}. */
    protected double alpha;
    protected double alphaSum;
    /** Per-type Dirichlet parameter. */
    protected double beta;
    /** {@code beta * numTypes}. */
    protected double betaSum;
    public static final double DEFAULT_BETA = 0.01;
    protected int[] oneDocTopicCounts;
    /** Token counts indexed by {@code [type][topic]}. */
    protected int[][] typeTopicCounts;
    /** Total number of tokens currently assigned to each topic. */
    protected int[] tokensPerTopic;
    /** Log the likelihood and top words every this many iterations; 0 disables. */
    public int showTopicsInterval = 50;
    public int wordsPerTopic = 10;
    protected Randoms random;
    protected NumberFormat formatter;
    protected boolean printLogLikelihood = false;
    private static final long serialVersionUID = 1L;
    private static final int CURRENT_SERIAL_VERSION = 0;
    private static final int NULL_INTEGER = -1;

    /**
     * Creates a model with {@code alphaSum == numberOfTopics} (i.e. alpha = 1 per topic)
     * and beta = 0.01.
     *
     * @param numberOfTopics number of topics
     */
    public SimpleLDA(int numberOfTopics) {
        this(numberOfTopics, numberOfTopics, 0.01);
    }

    /**
     * Creates a model using a freshly constructed {@link Randoms}.
     *
     * @param numberOfTopics number of topics
     * @param alphaSum sum of the document-topic Dirichlet parameters
     * @param beta per-type topic-word Dirichlet parameter
     */
    public SimpleLDA(int numberOfTopics, double alphaSum, double beta) {
        this(numberOfTopics, alphaSum, beta, new Randoms());
    }

    /** Builds a label alphabet with entries "topic0" .. "topic(numTopics-1)". */
    private static LabelAlphabet newLabelAlphabet(int numTopics) {
        LabelAlphabet labels = new LabelAlphabet();
        for (int i = 0; i < numTopics; i++) {
            labels.lookupIndex("topic" + i);
        }
        return labels;
    }

    /**
     * Creates a model with automatically named topics.
     *
     * @param numberOfTopics number of topics
     * @param alphaSum sum of the document-topic Dirichlet parameters
     * @param beta per-type topic-word Dirichlet parameter
     * @param random random number source used for sampling
     */
    public SimpleLDA(int numberOfTopics, double alphaSum, double beta, Randoms random) {
        this(SimpleLDA.newLabelAlphabet(numberOfTopics), alphaSum, beta, random);
    }

    /**
     * Creates a model whose number of topics is the size of {@code topicAlphabet}.
     *
     * @param topicAlphabet alphabet of topic labels
     * @param alphaSum sum of the document-topic Dirichlet parameters
     * @param beta per-type topic-word Dirichlet parameter
     * @param random random number source used for sampling
     */
    public SimpleLDA(LabelAlphabet topicAlphabet, double alphaSum, double beta, Randoms random) {
        this.topicAlphabet = topicAlphabet;
        this.numTopics = topicAlphabet.size();
        this.alphaSum = alphaSum;
        this.alpha = alphaSum / (double)this.numTopics;
        this.beta = beta;
        this.random = random;
        this.oneDocTopicCounts = new int[this.numTopics];
        this.tokensPerTopic = new int[this.numTopics];
        this.formatter = NumberFormat.getInstance();
        this.formatter.setMaximumFractionDigits(5);
        logger.info("Simple LDA: " + this.numTopics + " topics");
    }

    public Alphabet getAlphabet() {
        return this.alphabet;
    }

    public LabelAlphabet getTopicAlphabet() {
        return this.topicAlphabet;
    }

    public int getNumTopics() {
        return this.numTopics;
    }

    /** Returns the live internal list of topic assignments (not a copy). */
    public ArrayList<TopicAssignment> getData() {
        return this.data;
    }

    /**
     * Configures periodic topic reporting during {@link #sample(int)}.
     *
     * @param interval report every {@code interval} iterations; 0 disables reporting
     * @param n number of top words to show per topic
     */
    public void setTopicDisplay(int interval, int n) {
        this.showTopicsInterval = interval;
        this.wordsPerTopic = n;
    }

    /** Replaces the random number source with a new {@link Randoms} seeded by {@code seed}. */
    public void setRandomSeed(int seed) {
        this.random = new Randoms(seed);
    }

    /** Returns the live internal {@code [type][topic]} count matrix (not a copy). */
    public int[][] getTypeTopicCounts() {
        return this.typeTopicCounts;
    }

    /** Returns the live internal per-topic token totals (not a copy). */
    public int[] getTopicTotals() {
        return this.tokensPerTopic;
    }

    /**
     * Loads training documents and assigns every token a uniformly random initial topic.
     *
     * <p>NOTE(review): this method reallocates {@code typeTopicCounts} on each call but
     * keeps accumulating into {@code tokensPerTopic} and {@code data}, so it appears to be
     * intended to be called once per model.
     *
     * @param training instances whose data are {@link FeatureSequence}s
     */
    public void addInstances(InstanceList training) {
        this.alphabet = training.getDataAlphabet();
        this.numTypes = this.alphabet.size();
        this.betaSum = this.beta * (double)this.numTypes;
        this.typeTopicCounts = new int[this.numTypes][this.numTopics];
        for (Instance instance : training) {
            FeatureSequence tokens = (FeatureSequence)instance.getData();
            LabelSequence topicSequence = new LabelSequence(this.topicAlphabet, new int[tokens.size()]);
            int[] topics = topicSequence.getFeatures();
            for (int position = 0; position < tokens.size(); position++) {
                int topic = this.random.nextInt(this.numTopics);
                topics[position] = topic;
                this.tokensPerTopic[topic]++;
                int type = tokens.getIndexAtPosition(position);
                this.typeTopicCounts[type][topic]++;
            }
            this.data.add(new TopicAssignment(instance, topicSequence));
        }
    }

    /**
     * Runs Gibbs sampling sweeps over all documents.
     *
     * @param iterations number of full sweeps to run
     * @throws IOException declared for interface compatibility
     */
    public void sample(int iterations) throws IOException {
        for (int iteration = 1; iteration <= iterations; iteration++) {
            long iterationStart = System.currentTimeMillis();
            for (int doc = 0; doc < this.data.size(); doc++) {
                TopicAssignment assignment = this.data.get(doc);
                FeatureSequence tokenSequence = (FeatureSequence)assignment.instance.getData();
                LabelSequence topicSequence = assignment.topicSequence;
                this.sampleTopicsForOneDoc(tokenSequence, topicSequence);
            }
            long elapsedMillis = System.currentTimeMillis() - iterationStart;
            logger.fine(iteration + "\t" + elapsedMillis + "ms\t");
            if (this.showTopicsInterval != 0 && iteration % this.showTopicsInterval == 0) {
                logger.info("<" + iteration + "> Log Likelihood: " + this.modelLogLikelihood() + "\n" + this.topWords(this.wordsPerTopic));
            }
        }
    }

    /**
     * Resamples the topic of every token in one document, updating the global counts
     * ({@code typeTopicCounts}, {@code tokensPerTopic}) and the topic sequence in place.
     *
     * @param tokenSequence the document's tokens (type indices)
     * @param topicSequence the document's current topic assignments; modified in place
     * @throws IllegalStateException if no topic could be drawn for a token
     */
    protected void sampleTopicsForOneDoc(FeatureSequence tokenSequence, FeatureSequence topicSequence) {
        int[] oneDocTopics = topicSequence.getFeatures();
        int docLength = tokenSequence.getLength();

        // Count how many tokens of this document are currently assigned to each topic.
        int[] localTopicCounts = new int[this.numTopics];
        for (int position = 0; position < docLength; position++) {
            localTopicCounts[oneDocTopics[position]]++;
        }

        double[] topicTermScores = new double[this.numTopics];
        for (int position = 0; position < docLength; position++) {
            int type = tokenSequence.getIndexAtPosition(position);
            int oldTopic = oneDocTopics[position];
            int[] currentTypeTopicCounts = this.typeTopicCounts[type];

            // Remove this token from all counts before computing the conditional.
            localTopicCounts[oldTopic]--;
            this.tokensPerTopic[oldTopic]--;
            assert (this.tokensPerTopic[oldTopic] >= 0) : "old Topic " + oldTopic + " below 0";
            currentTypeTopicCounts[oldTopic]--;

            // Unnormalized conditional probability of each topic for this token.
            double sum = 0.0;
            for (int topic = 0; topic < this.numTopics; topic++) {
                double score = (this.alpha + (double)localTopicCounts[topic]) * ((this.beta + (double)currentTypeTopicCounts[topic]) / (this.betaSum + (double)this.tokensPerTopic[topic]));
                sum += score;
                topicTermScores[topic] = score;
            }

            // Draw from the discrete distribution. The walk is bounded by the last
            // topic so floating-point round-off in the running subtraction cannot
            // index past the end of the score array.
            double sample = this.random.nextUniform() * sum;
            int newTopic = -1;
            while (sample > 0.0 && newTopic < this.numTopics - 1) {
                sample -= topicTermScores[++newTopic];
            }
            if (newTopic == -1) {
                throw new IllegalStateException("SimpleLDA: New topic not sampled.");
            }

            // Put the token back under its new topic.
            oneDocTopics[position] = newTopic;
            localTopicCounts[newTopic]++;
            this.tokensPerTopic[newTopic]++;
            currentTypeTopicCounts[newTopic]++;
        }
    }

    /**
     * Computes the log likelihood of the current topic assignments and tokens under the
     * model's Dirichlet-multinomial priors, from the stored counts.
     *
     * @return the model log likelihood
     * @throws IllegalStateException if the running total becomes NaN
     */
    public double modelLogLikelihood() {
        double logLikelihood = 0.0;

        // Document-topic part.
        int[] topicCounts = new int[this.numTopics];
        double logGammaAlpha = Dirichlet.logGamma(this.alpha);
        for (int doc = 0; doc < this.data.size(); doc++) {
            int[] docTopics = this.data.get(doc).topicSequence.getFeatures();
            for (int token = 0; token < docTopics.length; token++) {
                topicCounts[docTopics[token]]++;
            }
            for (int topic = 0; topic < this.numTopics; topic++) {
                if (topicCounts[topic] > 0) {
                    logLikelihood += Dirichlet.logGamma(this.alpha + (double)topicCounts[topic]) - logGammaAlpha;
                }
            }
            logLikelihood -= Dirichlet.logGamma(this.alphaSum + (double)docTopics.length);
            Arrays.fill(topicCounts, 0);
        }
        logLikelihood += (double)this.data.size() * Dirichlet.logGamma(this.alphaSum);

        // Topic-word part.
        double logGammaBeta = Dirichlet.logGamma(this.beta);
        for (int type = 0; type < this.numTypes; type++) {
            int[] countsForType = this.typeTopicCounts[type];
            for (int topic = 0; topic < this.numTopics; topic++) {
                if (countsForType[topic] == 0) {
                    continue;
                }
                logLikelihood += Dirichlet.logGamma(this.beta + (double)countsForType[topic]) - logGammaBeta;
                if (Double.isNaN(logLikelihood)) {
                    // Library code must not terminate the JVM; report the failure instead.
                    throw new IllegalStateException("SimpleLDA: log likelihood is NaN at type " + type + ", topic " + topic + ", count " + countsForType[topic]);
                }
            }
        }
        for (int topic = 0; topic < this.numTopics; topic++) {
            logLikelihood -= Dirichlet.logGamma(this.beta * (double)this.numTypes + (double)this.tokensPerTopic[topic]);
            if (Double.isNaN(logLikelihood)) {
                throw new IllegalStateException("SimpleLDA: log likelihood is NaN after topic " + topic + " " + this.tokensPerTopic[topic]);
            }
        }
        logLikelihood += (double)this.numTopics * Dirichlet.logGamma(this.beta * (double)this.numTypes);
        if (Double.isNaN(logLikelihood)) {
            throw new IllegalStateException("SimpleLDA: log likelihood is NaN at the end");
        }
        return logLikelihood;
    }

    /**
     * Formats, one line per topic, the topic index, its token total, and its
     * highest-ranked words according to {@link IDSorter} ordering of the type-topic counts.
     *
     * @param numWords maximum number of words per topic; capped at the vocabulary size
     * @return tab/space separated report, one topic per line
     */
    public String topWords(int numWords) {
        StringBuilder output = new StringBuilder();
        IDSorter[] sortedWords = new IDSorter[this.numTypes];
        // Never read past the vocabulary when fewer types exist than words requested.
        int wordsToShow = Math.min(numWords, this.numTypes);
        for (int topic = 0; topic < this.numTopics; topic++) {
            for (int type = 0; type < this.numTypes; type++) {
                sortedWords[type] = new IDSorter(type, this.typeTopicCounts[type][topic]);
            }
            Arrays.sort(sortedWords);
            output.append(topic).append("\t").append(this.tokensPerTopic[topic]).append("\t");
            for (int i = 0; i < wordsToShow; i++) {
                output.append(this.alphabet.lookupObject(sortedWords[i].getID())).append(" ");
            }
            output.append("\n");
        }
        return output.toString();
    }

    /**
     * Writes, for each document, its index, source, and (topic, proportion) pairs in
     * sorted order.
     *
     * @param file destination file
     * @param threshold stop listing a document's topics at the first proportion below this
     * @param max maximum topics to list per document; negative or too large means all
     * @throws IOException if the file cannot be opened for writing
     */
    public void printDocumentTopics(File file, double threshold, int max) throws IOException {
        PrintWriter out = new PrintWriter(file);
        try {
            out.print("#doc source topic proportion ...\n");
            int[] topicCounts = new int[this.numTopics];
            IDSorter[] sortedTopics = new IDSorter[this.numTopics];
            for (int topic = 0; topic < this.numTopics; topic++) {
                sortedTopics[topic] = new IDSorter(topic, topic);
            }
            if (max < 0 || max > this.numTopics) {
                max = this.numTopics;
            }
            for (int doc = 0; doc < this.data.size(); doc++) {
                TopicAssignment assignment = this.data.get(doc);
                int[] currentDocTopics = assignment.topicSequence.getFeatures();
                out.print(doc);
                out.print(' ');
                Object source = assignment.instance.getSource();
                if (source != null) {
                    out.print(source);
                } else {
                    out.print("null-source");
                }
                out.print(' ');
                int docLen = currentDocTopics.length;
                for (int token = 0; token < docLen; token++) {
                    topicCounts[currentDocTopics[token]]++;
                }
                for (int topic = 0; topic < this.numTopics; topic++) {
                    sortedTopics[topic].set(topic, (float)topicCounts[topic] / (float)docLen);
                }
                Arrays.sort(sortedTopics);
                for (int i = 0; i < max; i++) {
                    if (sortedTopics[i].getWeight() < threshold) {
                        break;
                    }
                    out.print(sortedTopics[i].getID() + " " + sortedTopics[i].getWeight() + " ");
                }
                out.print(" \n");
                Arrays.fill(topicCounts, 0);
            }
        } finally {
            // Closing also flushes; without it buffered output may never reach the file.
            out.close();
        }
    }

    /**
     * Writes the per-token state, gzip-compressed, to {@code f}.
     *
     * @param f destination file
     * @throws IOException if the file cannot be opened or written
     */
    public void printState(File f) throws IOException {
        PrintStream out = new PrintStream(new GZIPOutputStream(new BufferedOutputStream(new FileOutputStream(f))));
        try {
            this.printState(out);
        } finally {
            out.close();
        }
    }

    /**
     * Writes one line per token: document index, source, position, type index, type, topic.
     * The stream is not closed.
     *
     * @param out destination stream
     */
    public void printState(PrintStream out) {
        out.println("#doc source pos typeindex type topic");
        for (int doc = 0; doc < this.data.size(); doc++) {
            TopicAssignment assignment = this.data.get(doc);
            FeatureSequence tokenSequence = (FeatureSequence)assignment.instance.getData();
            LabelSequence topicSequence = assignment.topicSequence;
            String source = "NA";
            if (assignment.instance.getSource() != null) {
                source = assignment.instance.getSource().toString();
            }
            for (int position = 0; position < topicSequence.getLength(); position++) {
                int type = tokenSequence.getIndexAtPosition(position);
                int topic = topicSequence.getIndexAtPosition(position);
                out.print(doc);
                out.print(' ');
                out.print(source);
                out.print(' ');
                out.print(position);
                out.print(' ');
                out.print(type);
                out.print(' ');
                out.print(this.alphabet.lookupObject(type));
                out.print(' ');
                out.print(topic);
                out.println();
            }
        }
    }

    /**
     * Serializes this model to {@code f}. I/O failures are reported on standard error
     * rather than thrown.
     *
     * @param f destination file
     */
    public void write(File f) {
        try {
            ObjectOutputStream oos = new ObjectOutputStream(new FileOutputStream(f));
            try {
                oos.writeObject(this);
            } finally {
                oos.close();
            }
        }
        catch (IOException e) {
            System.err.println("Exception writing file " + f + ": " + e);
        }
    }

    /**
     * Custom serialization. The field order and primitive/object kinds here must mirror
     * {@link #readObject(ObjectInputStream)} exactly.
     */
    private void writeObject(ObjectOutputStream out) throws IOException {
        out.writeInt(CURRENT_SERIAL_VERSION);
        out.writeObject(this.data);
        out.writeObject(this.alphabet);
        out.writeObject(this.topicAlphabet);
        out.writeInt(this.numTopics);
        // Written as a primitive double because readObject reads it with readDouble();
        // writing it as a boxed object made the stream unreadable.
        out.writeDouble(this.alpha);
        out.writeDouble(this.beta);
        out.writeDouble(this.betaSum);
        out.writeInt(this.showTopicsInterval);
        out.writeInt(this.wordsPerTopic);
        out.writeObject(this.random);
        out.writeObject(this.formatter);
        out.writeBoolean(this.printLogLikelihood);
        out.writeObject(this.typeTopicCounts);
        for (int ti = 0; ti < this.numTopics; ti++) {
            out.writeInt(this.tokensPerTopic[ti]);
        }
    }

    /**
     * Custom deserialization; the counterpart of {@link #writeObject(ObjectOutputStream)}.
     * Derived fields ({@code alphaSum}, {@code numTypes}) and scratch buffers are rebuilt.
     */
    @SuppressWarnings("unchecked") // the stream is produced by writeObject, which writes this.data
    private void readObject(ObjectInputStream in) throws IOException, ClassNotFoundException {
        in.readInt(); // serial version; only one version exists, so it is not inspected
        this.data = (ArrayList<TopicAssignment>)in.readObject();
        this.alphabet = (Alphabet)in.readObject();
        this.topicAlphabet = (LabelAlphabet)in.readObject();
        this.numTopics = in.readInt();
        this.alpha = in.readDouble();
        this.alphaSum = this.alpha * (double)this.numTopics;
        this.beta = in.readDouble();
        this.betaSum = in.readDouble();
        this.showTopicsInterval = in.readInt();
        this.wordsPerTopic = in.readInt();
        this.random = (Randoms)in.readObject();
        this.formatter = (NumberFormat)in.readObject();
        this.printLogLikelihood = in.readBoolean();
        this.numTypes = this.alphabet.size();
        this.typeTopicCounts = (int[][])in.readObject();
        this.tokensPerTopic = new int[this.numTopics];
        for (int ti = 0; ti < this.numTopics; ti++) {
            this.tokensPerTopic[ti] = in.readInt();
        }
        // Not part of the stream; re-create it so it is never null after loading.
        this.oneDocTopicCounts = new int[this.numTopics];
    }

    /**
     * Command-line entry point: loads an instance list, then runs 1000 sampling iterations
     * with alphaSum = 50 and beta = 0.01.
     *
     * @param args {@code args[0]} is the instance-list file; optional {@code args[1]} is the
     *             number of topics (default 200)
     * @throws IOException propagated from sampling
     */
    public static void main(String[] args) throws IOException {
        if (args.length < 1) {
            System.err.println("Usage: SimpleLDA <instance-list-file> [numTopics]");
            return;
        }
        InstanceList training = InstanceList.load(new File(args[0]));
        int numTopics = args.length > 1 ? Integer.parseInt(args[1]) : 200;
        SimpleLDA lda = new SimpleLDA(numTopics, 50.0, 0.01);
        lda.addInstances(training);
        lda.sample(1000);
    }
}

