/*
 * Decompiled with CFR 0.152.
 */
package cc.mallet.topics;

import cc.mallet.types.Alphabet;
import cc.mallet.types.FeatureSequence;
import cc.mallet.types.IDSorter;
import cc.mallet.types.Instance;
import cc.mallet.types.InstanceList;
import cc.mallet.util.Randoms;
import java.io.File;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.PrintWriter;
import java.io.Serializable;
import java.util.Arrays;

/**
 * Infers topic distributions for new documents from the counts of an already-trained topic model.
 *
 * <p>For each document, {@link #getSampledDistribution} repeatedly resamples the topic of every
 * token. The sampling mass is split into three buckets:
 * <ul>
 *   <li>a smoothing-only bucket {@code alpha[t] * beta / (tokensPerTopic[t] + betaSum)}, summed
 *       once in the constructor;</li>
 *   <li>a document-topic bucket {@code beta * localCount[t] / (tokensPerTopic[t] + betaSum)};</li>
 *   <li>a topic-word bucket driven by the token type's row of {@code typeTopicCounts}.</li>
 * </ul>
 * The model arrays ({@code typeTopicCounts}, {@code tokensPerTopic}, {@code alpha}) are only read,
 * never modified, and are held by reference (not copied).
 *
 * <p>Each row of {@code typeTopicCounts} packs a topic id into the low {@code topicBits} bits of
 * each {@code int} and a count into the remaining high bits. A non-positive entry ends the row.
 *
 * <p>Not safe for concurrent use: sampling temporarily mutates {@link #cachedCoefficients} and
 * draws from a single shared {@link Randoms} instance.
 */
public class TopicInferencer implements Serializable {
    /** Number of topics; equal to {@code tokensPerTopic.length}. */
    protected int numTopics;
    /** All-ones bit mask that extracts a topic id from a packed type-topic entry. */
    protected int topicMask;
    /** Number of low bits used for the topic id in a packed entry; the count sits above them. */
    protected int topicBits;
    /** Number of token types (rows of {@code typeTopicCounts}); larger type indices are ignored. */
    protected int numTypes;
    /** Per-topic Dirichlet parameters; must have at least {@code numTopics} entries. */
    protected double[] alpha;
    protected double beta;
    protected double betaSum;
    /** Packed (topic, count) entries per type; see the class comment for the encoding. */
    protected int[][] typeTopicCounts;
    /** Total token count per topic in the trained model. */
    protected int[] tokensPerTopic;
    Alphabet alphabet;
    protected Randoms random = null;
    /** Total mass of the smoothing-only bucket, precomputed in the constructor. */
    double smoothingOnlyMass = 0.0;
    /**
     * Per-topic coefficients {@code (alpha[t] + localCount[t]) / (tokensPerTopic[t] + betaSum)}.
     * They hold the document-independent value {@code alpha[t] / (tokensPerTopic[t] + betaSum)}
     * between calls to {@link #getSampledDistribution}.
     */
    double[] cachedCoefficients;
    private static final long serialVersionUID = 1L;
    private static final int CURRENT_SERIAL_VERSION = 0;
    private static final int NULL_INTEGER = -1;

    /**
     * Creates an inferencer over the given trained-model statistics.
     *
     * @param typeTopicCounts packed (topic, count) entries per type; held by reference
     * @param tokensPerTopic total tokens assigned to each topic; its length defines the topic count
     * @param alphabet the model's alphabet; stored and serialized, not otherwise used here
     * @param alpha per-topic Dirichlet parameters; must have at least {@code tokensPerTopic.length}
     *     entries
     * @param beta the topic-word smoothing parameter
     * @param betaSum the topic-word smoothing total (presumably {@code beta} times the vocabulary
     *     size; it is not computed here, so confirm with the caller)
     */
    public TopicInferencer(int[][] typeTopicCounts, int[] tokensPerTopic, Alphabet alphabet, double[] alpha, double beta, double betaSum) {
        this.tokensPerTopic = tokensPerTopic;
        this.typeTopicCounts = typeTopicCounts;
        this.alphabet = alphabet;
        this.numTopics = tokensPerTopic.length;
        this.numTypes = typeTopicCounts.length;

        // Choose a mask wide enough to hold every topic id in [0, numTopics).
        if (Integer.bitCount(this.numTopics) == 1) {
            this.topicMask = this.numTopics - 1;
        } else {
            this.topicMask = Integer.highestOneBit(this.numTopics) * 2 - 1;
        }
        this.topicBits = Integer.bitCount(this.topicMask);

        this.alpha = alpha;
        this.beta = beta;
        this.betaSum = betaSum;
        this.cachedCoefficients = new double[this.numTopics];
        for (int topic = 0; topic < this.numTopics; topic++) {
            this.smoothingOnlyMass += alpha[topic] * beta / ((double) tokensPerTopic[topic] + betaSum);
            this.cachedCoefficients[topic] = alpha[topic] / ((double) tokensPerTopic[topic] + betaSum);
        }
        this.random = new Randoms();
    }

    /**
     * Replaces the random number generator with one seeded by {@code seed}, for reproducible
     * sampling.
     *
     * @param seed the seed passed to {@link Randoms}
     */
    public void setRandomSeed(int seed) {
        this.random = new Randoms(seed);
    }

    /**
     * Samples a topic distribution for a single document.
     *
     * <p>Tokens are ignored if their type index is {@code >= numTypes} or their type has an empty
     * row in {@code typeTopicCounts}. Every other token starts in the first topic listed in its
     * type's row (presumably the type's most frequent topic, if rows are sorted by count; this is
     * not checked here). Each such token is resampled once per iteration.
     *
     * <p>After {@code burnIn} iterations, every {@code thinning}-th iteration adds
     * {@code alpha[t] + localCount[t]} to a running total for each topic. If no iteration was
     * collected (e.g. {@code numIterations <= burnIn}), the final sampling state is used instead.
     * The returned proportions are normalised to sum to 1.
     *
     * <p>{@link #cachedCoefficients} is restored to its document-independent values before this
     * method returns, even if it throws.
     *
     * @param instance an instance whose data is a {@link FeatureSequence} (otherwise a
     *     {@link ClassCastException} is thrown)
     * @param numIterations number of sweeps over the document
     * @param thinning interval, in iterations, between collected samples after burn-in; a value of
     *     0 throws {@link ArithmeticException} once an iteration past {@code burnIn} is reached
     * @param burnIn number of initial iterations that are never collected
     * @return a new array of length {@code numTopics} with the normalised topic proportions
     */
    public double[] getSampledDistribution(Instance instance, int numIterations, int thinning, int burnIn) {
        FeatureSequence tokens = (FeatureSequence) instance.getData();
        int docLength = tokens.size();
        int[] topics = new int[docLength];
        // Per-document topic counts, plus a sorted, dense list of topics whose local count is > 0.
        int[] localTopicCounts = new int[this.numTopics];
        int[] localTopicIndex = new int[this.numTopics];
        double[] topicTermScores = new double[this.numTopics];
        double[] result = new double[this.numTopics];

        try {
            // Initial assignment: the first topic in each token type's row.
            for (int position = 0; position < docLength; position++) {
                int type = tokens.getIndexAtPosition(position);
                if (hasTopicCounts(type)) {
                    int topic = this.typeTopicCounts[type][0] & this.topicMask;
                    topics[position] = topic;
                    localTopicCounts[topic]++;
                }
            }

            int nonZeroTopics = 0;
            for (int topic = 0; topic < this.numTopics; topic++) {
                if (localTopicCounts[topic] != 0) {
                    localTopicIndex[nonZeroTopics++] = topic;
                }
            }

            // Make the document bucket and coefficients reflect the initial assignment.
            double topicBetaMass = 0.0;
            for (int d = 0; d < nonZeroTopics; d++) {
                int topic = localTopicIndex[d];
                topicBetaMass += betaTerm(topic, localTopicCounts[topic]);
                this.cachedCoefficients[topic] = coefficient(topic, localTopicCounts[topic]);
            }

            double sum = 0.0;
            for (int iteration = 1; iteration <= numIterations; iteration++) {
                for (int position = 0; position < docLength; position++) {
                    int type = tokens.getIndexAtPosition(position);
                    if (!hasTopicCounts(type)) {
                        continue;
                    }
                    int[] currentTypeTopicCounts = this.typeTopicCounts[type];

                    // Remove the token from its current topic.
                    int oldTopic = topics[position];
                    topicBetaMass -= betaTerm(oldTopic, localTopicCounts[oldTopic]);
                    localTopicCounts[oldTopic]--;
                    if (localTopicCounts[oldTopic] == 0) {
                        int d = 0;
                        while (localTopicIndex[d] != oldTopic) {
                            d++;
                        }
                        System.arraycopy(localTopicIndex, d + 1, localTopicIndex, d, nonZeroTopics - d - 1);
                        nonZeroTopics--;
                    }
                    topicBetaMass += betaTerm(oldTopic, localTopicCounts[oldTopic]);
                    this.cachedCoefficients[oldTopic] = coefficient(oldTopic, localTopicCounts[oldTopic]);
                    if (this.cachedCoefficients[oldTopic] <= 0.0) {
                        System.out.println("zero or less coefficient: " + oldTopic + " = (" + this.alpha[oldTopic] + " + " + localTopicCounts[oldTopic] + ") / ( " + this.tokensPerTopic[oldTopic] + " + " + this.betaSum + " );");
                    }

                    // Topic-word bucket: scores for the topics listed in this type's row.
                    double topicTermMass = 0.0;
                    int termCount = 0;
                    while (termCount < currentTypeTopicCounts.length && currentTypeTopicCounts[termCount] > 0) {
                        int entry = currentTypeTopicCounts[termCount];
                        double score = this.cachedCoefficients[entry & this.topicMask] * (double) (entry >> this.topicBits);
                        topicTermMass += score;
                        topicTermScores[termCount] = score;
                        termCount++;
                    }

                    double sample = this.random.nextUniform() * (this.smoothingOnlyMass + topicBetaMass + topicTermMass);
                    int newTopic;
                    if (sample < topicTermMass) {
                        // do-while so a draw of exactly 0.0 still selects an entry; the bound
                        // guards against rounding walking past the last scored entry.
                        int i = -1;
                        do {
                            i++;
                            sample -= topicTermScores[i];
                        } while (sample > 0.0 && i < termCount - 1);
                        newTopic = currentTypeTopicCounts[i] & this.topicMask;
                    } else {
                        sample -= topicTermMass;
                        if (sample < topicBetaMass && nonZeroTopics > 0) {
                            // Document bucket; fall back to the last dense topic if rounding
                            // leaves residual mass.
                            sample /= this.beta;
                            newTopic = localTopicIndex[nonZeroTopics - 1];
                            for (int d = 0; d < nonZeroTopics; d++) {
                                int topic = localTopicIndex[d];
                                sample -= (double) localTopicCounts[topic] / denominator(topic);
                                if (sample <= 0.0) {
                                    newTopic = topic;
                                    break;
                                }
                            }
                        } else {
                            // Smoothing-only bucket; stop at the last topic if rounding leaves
                            // residual mass.
                            sample -= topicBetaMass;
                            sample /= this.beta;
                            newTopic = 0;
                            sample -= this.alpha[newTopic] / denominator(newTopic);
                            while (sample > 0.0 && newTopic < this.numTopics - 1) {
                                newTopic++;
                                sample -= this.alpha[newTopic] / denominator(newTopic);
                            }
                        }
                    }

                    // Add the token to its newly sampled topic.
                    topics[position] = newTopic;
                    topicBetaMass -= betaTerm(newTopic, localTopicCounts[newTopic]);
                    localTopicCounts[newTopic]++;
                    if (localTopicCounts[newTopic] == 1) {
                        // Insertion step that keeps localTopicIndex sorted.
                        int d = nonZeroTopics;
                        while (d > 0 && localTopicIndex[d - 1] > newTopic) {
                            localTopicIndex[d] = localTopicIndex[d - 1];
                            d--;
                        }
                        localTopicIndex[d] = newTopic;
                        nonZeroTopics++;
                    }
                    this.cachedCoefficients[newTopic] = coefficient(newTopic, localTopicCounts[newTopic]);
                    topicBetaMass += betaTerm(newTopic, localTopicCounts[newTopic]);
                }

                if (iteration > burnIn && (iteration - burnIn) % thinning == 0) {
                    for (int topic = 0; topic < this.numTopics; topic++) {
                        double weight = this.alpha[topic] + (double) localTopicCounts[topic];
                        result[topic] += weight;
                        sum += weight;
                    }
                }
            }

            // No sample was collected: use the final state of the chain.
            if (sum == 0.0) {
                for (int topic = 0; topic < this.numTopics; topic++) {
                    result[topic] = this.alpha[topic] + (double) localTopicCounts[topic];
                    sum += result[topic];
                }
            }
            for (int topic = 0; topic < this.numTopics; topic++) {
                result[topic] /= sum;
            }
            return result;
        } finally {
            // Restore document-independent coefficients so an exception cannot poison later calls.
            for (int topic = 0; topic < this.numTopics; topic++) {
                this.cachedCoefficients[topic] = this.alpha[topic] / denominator(topic);
            }
        }
    }

    /** Returns whether {@code type} is within the model's types and has at least one row entry. */
    private boolean hasTopicCounts(int type) {
        return type < this.numTypes && this.typeTopicCounts[type].length != 0;
    }

    /** Returns {@code tokensPerTopic[topic] + betaSum}, the denominator shared by all buckets. */
    private double denominator(int topic) {
        return (double) this.tokensPerTopic[topic] + this.betaSum;
    }

    /** Returns the document-bucket mass contributed by {@code topic} at the given local count. */
    private double betaTerm(int topic, int localCount) {
        return this.beta * (double) localCount / denominator(topic);
    }

    /** Returns the topic-word coefficient for {@code topic} at the given local count. */
    private double coefficient(int topic, int localCount) {
        return (this.alpha[topic] + (double) localCount) / denominator(topic);
    }

    /**
     * Infers a topic distribution for every instance and writes one line per document.
     *
     * <p>The file starts with a {@code #doc name topic proportion ...} header. Each line holds the
     * document number, a tab, and the instance name (or {@code no-name}). Then:
     * <ul>
     *   <li>If {@code threshold > 0}, it lists up to {@code max} (topic id, proportion) pairs in
     *       sorted order, stopping at the first proportion below {@code threshold}.</li>
     *   <li>Otherwise, it lists every topic's proportion in topic order, and {@code max} is
     *       ignored.</li>
     * </ul>
     *
     * @param instances documents to infer
     * @param distributionsFile destination file (created or truncated)
     * @param numIterations see {@link #getSampledDistribution}
     * @param thinning see {@link #getSampledDistribution}
     * @param burnIn see {@link #getSampledDistribution}
     * @param threshold minimum proportion printed in sorted mode; {@code <= 0} selects dense output
     * @param max maximum number of topics per line in sorted mode; negative or too large means all
     * @throws IOException if the file cannot be opened or an error occurs while writing it
     */
    public void writeInferredDistributions(InstanceList instances, File distributionsFile, int numIterations, int thinning, int burnIn, double threshold, int max) throws IOException {
        PrintWriter out = new PrintWriter(distributionsFile);
        try {
            out.print("#doc name topic proportion ...\n");
            IDSorter[] sortedTopics = new IDSorter[this.numTopics];
            for (int topic = 0; topic < this.numTopics; topic++) {
                sortedTopics[topic] = new IDSorter(topic, topic);
            }
            if (max < 0 || max > this.numTopics) {
                max = this.numTopics;
            }

            int doc = 0;
            for (Instance instance : instances) {
                double[] topicDistribution = this.getSampledDistribution(instance, numIterations, thinning, burnIn);
                StringBuilder builder = new StringBuilder();
                builder.append(doc);
                builder.append("\t");
                if (instance.getName() != null) {
                    builder.append(instance.getName());
                } else {
                    builder.append("no-name");
                }

                if (threshold > 0.0) {
                    // Sorter objects are reused across documents; refill them before sorting.
                    for (int topic = 0; topic < this.numTopics; topic++) {
                        sortedTopics[topic].set(topic, topicDistribution[topic]);
                    }
                    Arrays.sort(sortedTopics);
                    for (int i = 0; i < max; i++) {
                        if (sortedTopics[i].getWeight() < threshold) {
                            break;
                        }
                        builder.append("\t" + sortedTopics[i].getID() + "\t" + sortedTopics[i].getWeight());
                    }
                } else {
                    for (int topic = 0; topic < this.numTopics; topic++) {
                        builder.append("\t" + topicDistribution[topic]);
                    }
                }
                out.println(builder);
                doc++;
            }

            // PrintWriter swallows IOExceptions; surface them as this method's contract promises.
            if (out.checkError()) {
                throw new IOException("Error writing topic distributions to " + distributionsFile);
            }
        } finally {
            out.close();
        }
    }

    /** Custom serialized form; field order must stay in sync with {@link #readObject}. */
    private void writeObject(ObjectOutputStream out) throws IOException {
        out.writeInt(CURRENT_SERIAL_VERSION);
        out.writeObject(this.alphabet);
        out.writeInt(this.numTopics);
        out.writeInt(this.topicMask);
        out.writeInt(this.topicBits);
        out.writeInt(this.numTypes);
        out.writeObject(this.alpha);
        out.writeDouble(this.beta);
        out.writeDouble(this.betaSum);
        out.writeObject(this.typeTopicCounts);
        out.writeObject(this.tokensPerTopic);
        out.writeObject(this.random);
        out.writeDouble(this.smoothingOnlyMass);
        out.writeObject(this.cachedCoefficients);
    }

    /** Reads the form produced by {@link #writeObject}, field by field in the same order. */
    private void readObject(ObjectInputStream in) throws IOException, ClassNotFoundException {
        // Only CURRENT_SERIAL_VERSION (0) has ever been written, so the value is not inspected.
        int version = in.readInt();
        this.alphabet = (Alphabet) in.readObject();
        this.numTopics = in.readInt();
        this.topicMask = in.readInt();
        this.topicBits = in.readInt();
        this.numTypes = in.readInt();
        this.alpha = (double[]) in.readObject();
        this.beta = in.readDouble();
        this.betaSum = in.readDouble();
        this.typeTopicCounts = (int[][]) in.readObject();
        this.tokensPerTopic = (int[]) in.readObject();
        this.random = (Randoms) in.readObject();
        this.smoothingOnlyMass = in.readDouble();
        this.cachedCoefficients = (double[]) in.readObject();
    }

    /**
     * Loads an inferencer previously saved with Java serialization.
     *
     * <p>SECURITY NOTE: this uses native Java deserialization. Only read files from trusted
     * sources.
     *
     * @param f the serialized inferencer file
     * @return the deserialized inferencer
     * @throws Exception if the file cannot be read or does not contain a {@code TopicInferencer}
     */
    public static TopicInferencer read(File f) throws Exception {
        FileInputStream fileIn = new FileInputStream(f);
        try {
            // Closing the underlying stream releases the file even if the header is invalid.
            return (TopicInferencer) new ObjectInputStream(fileIn).readObject();
        } finally {
            fileIn.close();
        }
    }
}

