/*
 * Decompiled with CFR 0.152.
 */
package cc.mallet.topics;

import cc.mallet.types.Alphabet;
import cc.mallet.types.Dirichlet;
import cc.mallet.types.FeatureSequence;
import cc.mallet.types.IDSorter;
import cc.mallet.types.Instance;
import cc.mallet.types.InstanceList;
import cc.mallet.util.Randoms;
import gnu.trove.TIntIntHashMap;
import gnu.trove.TObjectDoubleHashMap;
import java.io.BufferedWriter;
import java.io.File;
import java.io.FileNotFoundException;
import java.io.FileWriter;
import java.io.IOException;
import java.io.PrintWriter;
import java.util.ArrayList;
import java.util.Arrays;

public class HierarchicalLDA {
    // Training corpus; each instance's data is read as a FeatureSequence.
    InstanceList instances;
    // Instance list handed to initialize(); stored here but not read by any method in this class.
    InstanceList testing;
    // Root of the topic tree; every document path starts here.
    NCRPNode rootNode;
    // Scratch reference written while initialize() assigns tokens to nodes.
    NCRPNode node;
    // Length of every root-to-leaf path.
    int numLevels;
    // Number of training documents (instances.size() at initialize time).
    int numDocuments;
    // Size of the data alphabet (number of distinct word types).
    int numTypes;
    // Added to the per-level token counts in sampleTopics(); also the Dirichlet parameter in empiricalLikelihood().
    double alpha = 10.0;
    // Weight given to opening a new child in NCRPNode.select() and calculateNCRP().
    double gamma = 1.0;
    // Added to every per-type count when computing word probabilities.
    double eta = 0.1;
    // eta * numTypes, computed in initialize().
    double etaSum;
    // levels[doc][token] is the path level the token is currently assigned to.
    int[][] levels;
    // documentLeaves[doc] is the leaf node that ends the document's current path.
    NCRPNode[] documentLeaves;
    // Running count of nodes created so far; used to hand out NCRPNode.nodeID values.
    int totalNodes = 0;
    // File name used by the no-argument printState().
    String stateFile = "hlda.state";
    // Random source supplied to initialize().
    Randoms random;
    // When true, estimate() prints a "." per iteration and the iteration number every 50 iterations.
    boolean showProgress = true;
    // estimate() prints the tree whenever the iteration number is a multiple of this value.
    int displayTopicsInterval = 50;
    // Number of top words printed per node by printNode().
    int numWordsToDisplay = 10;

    /**
     * Sets alpha, the value added to per-level token counts in
     * {@link #sampleTopics(int)} and used as the Dirichlet parameter in
     * {@link #empiricalLikelihood(int, InstanceList)}.
     *
     * @param alpha the new value
     */
    public void setAlpha(double alpha) {
        this.alpha = alpha;
    }

    /**
     * Sets gamma, the weight given to creating a new child node when a path is
     * chosen.
     *
     * @param gamma the new value
     */
    public void setGamma(double gamma) {
        this.gamma = gamma;
    }

    /**
     * Sets eta, the value added to every per-type count when word
     * probabilities are computed, and refreshes the derived {@code etaSum}.
     *
     * <p>{@code etaSum} is {@code eta * numTypes}. Recomputing it here keeps the
     * two values consistent when this setter is called after
     * {@code initialize}. Before initialization {@code numTypes} is 0, so
     * {@code etaSum} is 0 until {@code initialize} computes it again.
     *
     * @param eta the new value
     */
    public void setEta(double eta) {
        this.eta = eta;
        this.etaSum = eta * (double)this.numTypes;
    }

    /**
     * Sets the file name that the no-argument {@link #printState()} writes to.
     *
     * @param stateFile path of the state file
     */
    public void setStateFile(String stateFile) {
        this.stateFile = stateFile;
    }

    /**
     * Configures how the topic tree is printed during {@link #estimate(int)}.
     *
     * @param interval the tree is printed whenever the iteration number is a
     *                 multiple of this value
     * @param words    number of top words shown for each node
     */
    public void setTopicDisplay(int interval, int words) {
        this.displayTopicsInterval = interval;
        this.numWordsToDisplay = words;
    }

    /**
     * Enables or disables the per-iteration progress output printed by
     * {@link #estimate(int)}.
     *
     * @param showProgress true to print progress to standard output
     */
    public void setProgressDisplay(boolean showProgress) {
        this.showProgress = showProgress;
    }

    /**
     * Stores the corpus and builds an initial random tree: every document is
     * given a path of {@code numLevels} nodes and each of its tokens is assigned
     * to a uniformly chosen level on that path.
     *
     * @param instances training documents; the data of each instance must be a
     *                  {@link FeatureSequence} (only the first instance is checked)
     * @param testing   instance list stored on this object for later use
     * @param numLevels length of every root-to-leaf path; must be at least 1
     * @param random    random source used for all sampling
     * @throws IllegalArgumentException if {@code numLevels} is less than 1 or
     *         the first instance does not hold a {@code FeatureSequence}
     */
    public void initialize(InstanceList instances, InstanceList testing, int numLevels, Randoms random) {
        if (numLevels < 1) {
            throw new IllegalArgumentException("numLevels must be at least 1, got " + numLevels);
        }
        this.instances = instances;
        this.testing = testing;
        this.numLevels = numLevels;
        this.random = random;
        // An empty list has no first instance to inspect; it simply yields a tree with only a root.
        if (instances.size() > 0 && !(((Instance)instances.get(0)).getData() instanceof FeatureSequence)) {
            throw new IllegalArgumentException("Input must be a FeatureSequence, using the --feature-sequence option when importing data, for example");
        }
        this.numDocuments = instances.size();
        this.numTypes = instances.getDataAlphabet().size();
        this.etaSum = this.eta * (double)this.numTypes;
        NCRPNode[] path = new NCRPNode[numLevels];
        this.rootNode = new NCRPNode(this.numTypes);
        this.levels = new int[this.numDocuments][];
        this.documentLeaves = new NCRPNode[this.numDocuments];
        for (int doc = 0; doc < this.numDocuments; ++doc) {
            FeatureSequence fs = (FeatureSequence)((Instance)instances.get(doc)).getData();
            int seqLen = fs.getLength();
            // Draw the document's path top-down; each node on it gains one customer.
            path[0] = this.rootNode;
            ++this.rootNode.customers;
            for (int level = 1; level < numLevels; ++level) {
                path[level] = path[level - 1].select();
                ++path[level].customers;
            }
            this.node = path[numLevels - 1];
            this.levels[doc] = new int[seqLen];
            this.documentLeaves[doc] = this.node;
            // Assign every token to a random level and record it in that node's counts.
            for (int token = 0; token < seqLen; ++token) {
                int type = fs.getIndexAtPosition(token);
                int level = random.nextInt(numLevels);
                this.levels[doc][token] = level;
                this.node = path[level];
                ++this.node.totalTokens;
                ++this.node.typeCounts[type];
            }
        }
    }

    /**
     * Runs the sampler for the given number of iterations. Each iteration
     * resamples the path of every document and then the level assignment of
     * every token.
     *
     * <p>When progress display is on, a "." is printed per iteration and the
     * iteration number every 50 iterations. The tree is printed whenever the
     * iteration number is a multiple of the topic display interval; an interval
     * of 0 turns that printing off instead of dividing by zero.
     *
     * @param numIterations number of sweeps to run; values below 1 do nothing
     */
    public void estimate(int numIterations) {
        for (int iteration = 1; iteration <= numIterations; ++iteration) {
            for (int doc = 0; doc < this.numDocuments; ++doc) {
                this.samplePath(doc, iteration);
            }
            for (int doc = 0; doc < this.numDocuments; ++doc) {
                this.sampleTopics(doc);
            }
            if (this.showProgress) {
                System.out.print(".");
                if (iteration % 50 == 0) {
                    System.out.println(" " + iteration);
                }
            }
            // Guard the modulus: a zero interval would throw ArithmeticException.
            if (this.displayTopicsInterval != 0 && iteration % this.displayTopicsInterval == 0) {
                this.printNodes();
            }
        }
    }

    /**
     * Resamples the tree path of one document.
     *
     * <p>The document is removed from its current path (both customer counts
     * and token counts), every node remaining in the tree is given a log weight
     * made of a prior term ({@link #calculateNCRP}) and a word-likelihood term
     * ({@link #calculateWordLikelihood}), one node is drawn in proportion to the
     * exponentiated weights, and the document is attached to the path ending at
     * that node. If the drawn node is not a leaf, new nodes are created beneath
     * it down to leaf depth.
     *
     * @param doc       index of the document in {@code instances}
     * @param iteration current iteration number; only passed through to
     *                  {@code calculateWordLikelihood}
     */
    public void samplePath(int doc, int iteration) {
        int i;
        NCRPNode[] path = new NCRPNode[this.numLevels];
        NCRPNode node = this.documentLeaves[doc];
        int level = this.numLevels - 1;
        // Recover the current path by following parent links from the leaf up to the root.
        while (level >= 0) {
            path[level] = node;
            node = node.parent;
            --level;
        }
        // Remove the document's customer counts; nodes left with no customers are detached from their parent.
        this.documentLeaves[doc].dropPath();
        TObjectDoubleHashMap nodeWeights = new TObjectDoubleHashMap();
        // Seed the map with a log prior weight for every node still reachable from the root.
        this.calculateNCRP((TObjectDoubleHashMap<NCRPNode>)nodeWeights, this.rootNode, 0.0);
        // typeCounts[level] maps word type -> number of this document's tokens assigned to that level.
        TIntIntHashMap[] typeCounts = new TIntIntHashMap[this.numLevels];
        level = 0;
        while (level < this.numLevels) {
            typeCounts[level] = new TIntIntHashMap();
            ++level;
        }
        int[] docLevels = this.levels[doc];
        FeatureSequence fs = (FeatureSequence)((Instance)this.instances.get(doc)).getData();
        int token = 0;
        // Tally the document's tokens per level and subtract them from the nodes of the old path.
        while (token < docLevels.length) {
            level = docLevels[token];
            int type = fs.getIndexAtPosition(token);
            if (!typeCounts[level].containsKey(type)) {
                typeCounts[level].put(type, 1);
            } else {
                typeCounts[level].increment(type);
            }
            int n = type;
            path[level].typeCounts[n] = path[level].typeCounts[n] - 1;
            assert (path[level].typeCounts[type] >= 0);
            --path[level].totalTokens;
            assert (path[level].totalTokens >= 0);
            ++token;
        }
        // newTopicWeights[level] is the log likelihood of this document's tokens at that level
        // under a node with no existing counts (smoothing terms only). Index 0 stays 0.0
        // because the loop starts at level 1.
        double[] newTopicWeights = new double[this.numLevels];
        level = 1;
        while (level < this.numLevels) {
            int[] types = typeCounts[level].keys();
            int totalTokens = 0;
            int[] nArray = types;
            int n = types.length;
            int n2 = 0;
            while (n2 < n) {
                int t = nArray[n2];
                i = 0;
                while (i < typeCounts[level].get(t)) {
                    int n3 = level;
                    newTopicWeights[n3] = newTopicWeights[n3] + Math.log((this.eta + (double)i) / (this.etaSum + (double)totalTokens));
                    ++totalTokens;
                    ++i;
                }
                ++n2;
            }
            ++level;
        }
        // Add the word-likelihood term to the weight of every node in the map.
        this.calculateWordLikelihood((TObjectDoubleHashMap<NCRPNode>)nodeWeights, this.rootNode, 0.0, typeCounts, newTopicWeights, 0, iteration);
        NCRPNode[] nodes = (NCRPNode[])nodeWeights.keys((Object[])new NCRPNode[0]);
        double[] weights = new double[nodes.length];
        double sum = 0.0;
        double max = Double.NEGATIVE_INFINITY;
        i = 0;
        // Find the largest log weight so it can be subtracted before exponentiating.
        while (i < nodes.length) {
            if (nodeWeights.get((Object)nodes[i]) > max) {
                max = nodeWeights.get((Object)nodes[i]);
            }
            ++i;
        }
        i = 0;
        // Convert log weights to unnormalized weights relative to the maximum and total them.
        while (i < nodes.length) {
            weights[i] = Math.exp(nodeWeights.get((Object)nodes[i]) - max);
            sum += weights[i];
            ++i;
        }
        // Draw a node; the index refers to the order in which the map returned its keys.
        node = nodes[this.random.nextDiscrete(weights, sum)];
        // A non-leaf choice means new nodes are created below it down to leaf depth.
        if (!node.isLeaf()) {
            node = node.getNewLeaf();
        }
        // Re-add the document's customer counts along the new path.
        node.addPath();
        this.documentLeaves[doc] = node;
        level = this.numLevels - 1;
        // Walk from the new leaf up to the root, adding the document's per-level token counts back.
        while (level >= 0) {
            int[] types;
            int[] nArray = types = typeCounts[level].keys();
            int n = types.length;
            int n4 = 0;
            while (n4 < n) {
                int t;
                int n5 = t = nArray[n4];
                node.typeCounts[n5] = node.typeCounts[n5] + typeCounts[level].get(t);
                node.totalTokens += typeCounts[level].get(t);
                ++n4;
            }
            node = node.parent;
            --level;
        }
    }

    /**
     * Recursively records a log prior weight for {@code node} and every node
     * beneath it.
     *
     * <p>Each child is visited with the incoming weight plus the log of its
     * share of the parent's customers. After all children have been visited,
     * the node itself is stored with the incoming weight plus the log of the
     * share reserved for a new child.
     *
     * @param nodeWeights map that receives one entry per visited node
     * @param node        subtree root to process
     * @param weight      log weight accumulated on the way down to {@code node}
     */
    public void calculateNCRP(TObjectDoubleHashMap<NCRPNode> nodeWeights, NCRPNode node, double weight) {
        double denominator = (double)node.customers + this.gamma;
        int childCount = node.children.size();
        for (int index = 0; index < childCount; ++index) {
            NCRPNode child = node.children.get(index);
            double childWeight = weight + Math.log((double)child.customers / denominator);
            this.calculateNCRP(nodeWeights, child, childWeight);
        }
        // Children are inserted before their parent; keep that order so the map's key order is unchanged.
        double ownWeight = weight + Math.log(this.gamma / denominator);
        nodeWeights.put(node, ownWeight);
    }

    /**
     * Recursively adds a word-likelihood term to the weight of {@code node} and
     * of every node beneath it.
     *
     * <p>The term for a node is the log likelihood of the document's tokens at
     * this level given the node's current counts. Children are visited first.
     * Afterwards the node's own adjustment also includes the entries of
     * {@code newTopicWeights} for all deeper levels.
     *
     * @param nodeWeights     map whose existing entry for each node is adjusted
     * @param node            subtree root to process
     * @param weight          value passed down the recursion; not used in the
     *                        adjustment applied to the map
     * @param typeCounts      per-level maps from word type to token count for
     *                        the document being resampled
     * @param newTopicWeights per-level log likelihood under an empty node
     * @param level           depth of {@code node}
     * @param iteration       passed through the recursion; otherwise unused
     */
    public void calculateWordLikelihood(TObjectDoubleHashMap<NCRPNode> nodeWeights, NCRPNode node, double weight, TIntIntHashMap[] typeCounts, double[] newTopicWeights, int level, int iteration) {
        TIntIntHashMap countsHere = typeCounts[level];
        double logLikelihood = 0.0;
        int tokensSeen = 0;
        for (int type : countsHere.keys()) {
            int occurrences = countsHere.get(type);
            for (int k = 0; k < occurrences; ++k) {
                logLikelihood += Math.log((this.eta + (double)node.typeCounts[type] + (double)k) / (this.etaSum + (double)node.totalTokens + (double)tokensSeen));
                ++tokensSeen;
            }
        }

        double passedDown = weight + logLikelihood;
        for (NCRPNode child : node.children) {
            this.calculateWordLikelihood(nodeWeights, child, passedDown, typeCounts, newTopicWeights, level + 1, iteration);
        }

        // Levels below this node would be filled by new, empty nodes.
        for (int deeper = level + 1; deeper < this.numLevels; ++deeper) {
            logLikelihood += newTopicWeights[deeper];
        }
        nodeWeights.adjustValue(node, logLikelihood);
    }

    /**
     * Adds {@code weight} to the map entry of {@code node} and, recursively, of
     * its descendants. A node that has no entry in the map is skipped together
     * with its whole subtree.
     *
     * @param nodeWeights map of node weights to adjust
     * @param node        subtree root to process
     * @param weight      amount added to each visited entry
     */
    public void propagateTopicWeight(TObjectDoubleHashMap<NCRPNode> nodeWeights, NCRPNode node, double weight) {
        if (nodeWeights.containsKey(node)) {
            int childCount = node.children.size();
            for (int index = 0; index < childCount; ++index) {
                this.propagateTopicWeight(nodeWeights, node.children.get(index), weight);
            }
            nodeWeights.adjustValue(node, weight);
        }
    }

    /**
     * Resamples the level assignment of every token in one document while its
     * path stays fixed.
     *
     * <p>For each token, its current assignment is removed from the counts, a
     * weight is computed for each level from the document's level counts and
     * the node's type counts, a new level is drawn, and the counts are restored
     * for the chosen level.
     *
     * @param doc index of the document in {@code instances}
     */
    public void sampleTopics(int doc) {
        FeatureSequence tokens = (FeatureSequence)((Instance)this.instances.get(doc)).getData();
        int length = tokens.getLength();
        int[] assignments = this.levels[doc];

        // Rebuild the path from the leaf upwards.
        NCRPNode[] path = new NCRPNode[this.numLevels];
        NCRPNode cursor = this.documentLeaves[doc];
        for (int depth = this.numLevels - 1; depth >= 0; --depth) {
            path[depth] = cursor;
            cursor = cursor.parent;
        }

        // Count how many of the document's tokens sit at each level.
        int[] levelCounts = new int[this.numLevels];
        for (int position = 0; position < length; ++position) {
            ++levelCounts[assignments[position]];
        }

        double[] levelWeights = new double[this.numLevels];
        for (int position = 0; position < length; ++position) {
            int type = tokens.getIndexAtPosition(position);

            // Take the token out of the counts at its current level.
            int oldLevel = assignments[position];
            --levelCounts[oldLevel];
            NCRPNode oldNode = path[oldLevel];
            --oldNode.typeCounts[type];
            --oldNode.totalTokens;

            double total = 0.0;
            for (int depth = 0; depth < this.numLevels; ++depth) {
                NCRPNode candidate = path[depth];
                levelWeights[depth] = (this.alpha + (double)levelCounts[depth]) * (this.eta + (double)candidate.typeCounts[type]) / (this.etaSum + (double)candidate.totalTokens);
                total += levelWeights[depth];
            }

            // Draw the new level and put the token back into the counts there.
            int newLevel = this.random.nextDiscrete(levelWeights, total);
            assignments[position] = newLevel;
            ++levelCounts[newLevel];
            NCRPNode newNode = path[newLevel];
            ++newNode.typeCounts[type];
            ++newNode.totalTokens;
        }
    }

    /**
     * Writes the current sampler state to the configured state file.
     *
     * <p>The writer is closed in a {@code finally} block. Without that, the
     * buffered output is never flushed and the file can be left empty or
     * truncated, and the file handle stays open.
     *
     * @throws FileNotFoundException declared for callers; see {@link FileWriter}
     * @throws IOException if the state file cannot be opened for writing
     */
    public void printState() throws IOException, FileNotFoundException {
        PrintWriter out = new PrintWriter(new BufferedWriter(new FileWriter(this.stateFile)));
        try {
            this.printState(out);
        } finally {
            out.close();
        }
    }

    /**
     * Writes one line per token of every training document to {@code out}.
     *
     * <p>Each line holds the node IDs of the document's path from leaf to root,
     * then the token's type index, the alphabet entry for that type, and the
     * level the token is assigned to, separated by single spaces with a
     * trailing space. The writer is neither flushed nor closed here.
     *
     * @param out destination writer
     * @throws IOException declared for callers; not thrown by the code in this method
     */
    public void printState(PrintWriter out) throws IOException {
        Alphabet alphabet = this.instances.getDataAlphabet();
        int doc = 0;
        for (Instance instance : this.instances) {
            FeatureSequence tokens = (FeatureSequence)instance.getData();
            int[] assignments = this.levels[doc];

            // Node IDs from the leaf up to the root, each followed by a space.
            StringBuilder pathIds = new StringBuilder();
            NCRPNode cursor = this.documentLeaves[doc];
            for (int depth = this.numLevels - 1; depth >= 0; --depth) {
                pathIds.append(cursor.nodeID).append(" ");
                cursor = cursor.parent;
            }
            String prefix = pathIds.toString();

            int length = tokens.getLength();
            for (int position = 0; position < length; ++position) {
                int type = tokens.getIndexAtPosition(position);
                int assigned = assignments[position];
                out.println(prefix + type + " " + alphabet.lookupObject(type) + " " + assigned + " ");
            }
            ++doc;
        }
    }

    /**
     * Prints the whole topic tree to standard output without word weights.
     */
    public void printNodes() {
        this.printNodes(false);
    }

    /**
     * Prints the whole topic tree to standard output, starting at the root
     * with no indentation.
     *
     * @param withWeight whether each top word is followed by its weight
     */
    public void printNodes(boolean withWeight) {
        NCRPNode root = this.rootNode;
        this.printNode(root, 0, withWeight);
    }

    /**
     * Prints one node and then, recursively, its children, each one indent
     * step deeper.
     *
     * <p>A line consists of two spaces per indent step, the node's token count
     * and customer count as {@code tokens/customers}, and the node's top words.
     *
     * @param node       node to print
     * @param indent     number of indent steps for this node
     * @param withWeight whether each top word is followed by its weight
     */
    public void printNode(NCRPNode node, int indent, boolean withWeight) {
        StringBuilder line = new StringBuilder();
        for (int step = 0; step < indent; ++step) {
            line.append("  ");
        }
        line.append(node.totalTokens).append("/").append(node.customers).append(" ");
        line.append(node.getTopWords(this.numWordsToDisplay, withWeight));
        System.out.println(line.toString());

        int childCount = node.children.size();
        for (int index = 0; index < childCount; ++index) {
            this.printNode(node.children.get(index), indent + 1, withWeight);
        }
    }

    /**
     * Estimates the log likelihood of a set of documents by sampling.
     *
     * <p>Each sample draws a path through the existing tree and a set of level
     * proportions from a Dirichlet, mixes the per-node word distributions along
     * the path into one distribution over word types, and scores every test
     * document under it. For each document the per-sample log likelihoods are
     * combined as the log of their mean, computed relative to the maximum. The
     * returned value is the sum of these quantities over all documents.
     *
     * @param numSamples number of sampled distributions
     * @param testing    documents to score; each instance's data is read as a
     *                   {@link FeatureSequence}
     * @return the summed per-document log likelihood estimate
     */
    public double empiricalLikelihood(int numSamples, InstanceList testing) {
        NCRPNode[] path = new NCRPNode[this.numLevels];
        path[0] = this.rootNode;
        Dirichlet dirichlet = new Dirichlet(this.numLevels, this.alpha);
        double[] logWordProbs = new double[this.numTypes];
        double[][] likelihoods = new double[testing.size()][numSamples];

        for (int sample = 0; sample < numSamples; ++sample) {
            // Draw a path among existing nodes, then the level proportions.
            for (int level = 1; level < this.numLevels; ++level) {
                path[level] = path[level - 1].selectExisting();
            }
            double[] levelWeights = dirichlet.nextDistribution();

            // Mix the smoothed word distributions of the path's nodes; store the log.
            for (int type = 0; type < this.numTypes; ++type) {
                double probability = 0.0;
                for (int level = 0; level < this.numLevels; ++level) {
                    NCRPNode node = path[level];
                    probability += levelWeights[level] * (this.eta + (double)node.typeCounts[type]) / (this.etaSum + (double)node.totalTokens);
                }
                logWordProbs[type] = Math.log(probability);
            }

            // Score every test document under this sample.
            for (int doc = 0; doc < testing.size(); ++doc) {
                FeatureSequence tokens = (FeatureSequence)((Instance)testing.get(doc)).getData();
                int length = tokens.getLength();
                double[] docRow = likelihoods[doc];
                for (int position = 0; position < length; ++position) {
                    docRow[sample] += logWordProbs[tokens.getIndexAtPosition(position)];
                }
            }
        }

        double total = 0.0;
        double logNumSamples = Math.log(numSamples);
        for (int doc = 0; doc < testing.size(); ++doc) {
            double[] docRow = likelihoods[doc];

            double max = Double.NEGATIVE_INFINITY;
            for (int sample = 0; sample < numSamples; ++sample) {
                if (docRow[sample] > max) {
                    max = docRow[sample];
                }
            }

            // Sum relative to the maximum so the exponentials stay in range.
            double sum = 0.0;
            for (int sample = 0; sample < numSamples; ++sample) {
                sum += Math.exp(docRow[sample] - max);
            }
            total += Math.log(sum) + max - logNumSamples;
        }
        return total;
    }

    /**
     * Command-line entry point: loads a training and a testing instance list,
     * initializes a sampler with 5 levels and runs 250 iterations.
     *
     * <p>If fewer than two arguments are given, a usage line is printed to
     * standard error and the method returns. Any exception raised while loading
     * or sampling is printed as a stack trace.
     *
     * @param args {@code args[0]} is the training instance file,
     *             {@code args[1]} the testing instance file
     */
    public static void main(String[] args) {
        if (args.length < 2) {
            System.err.println("Usage: HierarchicalLDA <training instances> <testing instances>");
            return;
        }
        try {
            InstanceList instances = InstanceList.load(new File(args[0]));
            InstanceList testing = InstanceList.load(new File(args[1]));
            HierarchicalLDA sampler = new HierarchicalLDA();
            sampler.initialize(instances, testing, 5, new Randoms());
            sampler.estimate(250);
        }
        catch (Exception e) {
            e.printStackTrace();
        }
    }

    /**
     * One node of the topic tree. A node keeps the number of documents whose
     * path passes through it, the word-type counts of the tokens assigned to
     * it, and links to its parent and children.
     */
    class NCRPNode {
        // Number of documents whose path currently includes this node.
        int customers = 0;
        ArrayList<NCRPNode> children;
        // Null for the root.
        NCRPNode parent;
        // Depth of this node; the root is at level 0.
        int level;
        // Total number of tokens assigned to this node.
        int totalTokens;
        // typeCounts[type] is the number of tokens of that word type assigned to this node.
        int[] typeCounts;
        // Taken from the enclosing sampler's running node counter at construction.
        public int nodeID;

        /**
         * Creates a node with no customers, no children and zeroed counts.
         *
         * @param parent     parent node, or null for a root
         * @param dimensions length of the type-count array
         * @param level      depth of the new node
         */
        public NCRPNode(NCRPNode parent, int dimensions, int level) {
            this.parent = parent;
            this.children = new ArrayList<NCRPNode>();
            this.level = level;
            this.totalTokens = 0;
            this.typeCounts = new int[dimensions];
            this.nodeID = HierarchicalLDA.this.totalNodes++;
        }

        /**
         * Creates a root node (no parent, level 0).
         *
         * @param dimensions length of the type-count array
         */
        public NCRPNode(int dimensions) {
            this(null, dimensions, 0);
        }

        /**
         * Creates a new child one level below this node and registers it.
         *
         * @return the new child
         */
        public NCRPNode addChild() {
            NCRPNode child = new NCRPNode(this, this.typeCounts.length, this.level + 1);
            this.children.add(child);
            return child;
        }

        /**
         * @return true if this node is at the deepest level of the tree
         */
        public boolean isLeaf() {
            return this.level == HierarchicalLDA.this.numLevels - 1;
        }

        /**
         * Creates a chain of new nodes below this node down to leaf depth.
         *
         * @return the new leaf, or this node if it is already at leaf depth
         */
        public NCRPNode getNewLeaf() {
            NCRPNode node = this;
            for (int l = this.level; l < HierarchicalLDA.this.numLevels - 1; ++l) {
                node = node.addChild();
            }
            return node;
        }

        /**
         * Removes one customer from this node and from each ancestor up to the
         * root. This is meant to be called on a leaf. A node whose customer
         * count reaches zero is detached from its parent.
         *
         * <p>The root has no parent to be detached from, so it stays in place
         * even when its count reaches zero, as happens while the only document
         * of a corpus is being resampled.
         */
        public void dropPath() {
            NCRPNode node = this;
            for (int l = 0; l < HierarchicalLDA.this.numLevels; ++l) {
                --node.customers;
                if (node.customers == 0 && node.parent != null) {
                    node.parent.remove(node);
                }
                node = node.parent;
            }
        }

        /**
         * Removes the given node from this node's children.
         *
         * @param node child to remove
         */
        public void remove(NCRPNode node) {
            this.children.remove(node);
        }

        /**
         * Adds one customer to this node and to each ancestor up to the root.
         * This is meant to be called on a leaf.
         */
        public void addPath() {
            NCRPNode node = this;
            ++node.customers;
            for (int l = 1; l < HierarchicalLDA.this.numLevels; ++l) {
                node = node.parent;
                ++node.customers;
            }
        }

        /**
         * Draws one of the existing children with probability proportional to
         * its customer count.
         *
         * <p>The weights are the raw customer counts and their total is passed
         * to the sampler explicitly. Dividing by {@code gamma + customers}
         * would produce weights that do not sum to one; the remaining mass
         * would presumably fall on the last child. That depends on how
         * {@code Randoms.nextDiscrete(double[])} treats such input and should
         * be confirmed against its documentation.
         *
         * @return the chosen child
         */
        public NCRPNode selectExisting() {
            double[] weights = new double[this.children.size()];
            double sum = 0.0;
            int i = 0;
            for (NCRPNode child : this.children) {
                weights[i] = (double)child.customers;
                sum += weights[i];
                ++i;
            }
            int choice = HierarchicalLDA.this.random.nextDiscrete(weights, sum);
            return this.children.get(choice);
        }

        /**
         * Draws either a new child (weight proportional to gamma) or one of the
         * existing children (weight proportional to its customer count).
         *
         * @return the chosen child; a freshly created one if the new-child
         *         option was drawn
         */
        public NCRPNode select() {
            double denominator = HierarchicalLDA.this.gamma + (double)this.customers;
            double[] weights = new double[this.children.size() + 1];
            // Slot 0 stands for "open a new child".
            weights[0] = HierarchicalLDA.this.gamma / denominator;
            int i = 1;
            for (NCRPNode child : this.children) {
                weights[i] = (double)child.customers / denominator;
                ++i;
            }
            int choice = HierarchicalLDA.this.random.nextDiscrete(weights);
            if (choice == 0) {
                return this.addChild();
            }
            return this.children.get(choice - 1);
        }

        /**
         * Builds a space-separated list of this node's leading words, in the
         * order produced by sorting {@link IDSorter} entries built from the
         * type counts.
         *
         * @param numWords   maximum number of words to include; capped at the
         *                   number of word types
         * @param withWeight if true, each word is followed by ":" and its
         *                   sorter weight
         * @return the words, each followed by a single space
         */
        public String getTopWords(int numWords, boolean withWeight) {
            IDSorter[] sortedTypes = new IDSorter[HierarchicalLDA.this.numTypes];
            for (int type = 0; type < sortedTypes.length; ++type) {
                sortedTypes[type] = new IDSorter(type, this.typeCounts[type]);
            }
            Arrays.sort(sortedTypes);
            Alphabet alphabet = HierarchicalLDA.this.instances.getDataAlphabet();
            StringBuilder out = new StringBuilder();
            // Never read past the end of the array when fewer types exist than words requested.
            int limit = Math.min(numWords, sortedTypes.length);
            for (int i = 0; i < limit; ++i) {
                IDSorter entry = sortedTypes[i];
                if (withWeight) {
                    out.append(alphabet.lookupObject(entry.getID()) + ":" + entry.getWeight() + " ");
                } else {
                    out.append(alphabet.lookupObject(entry.getID()) + " ");
                }
            }
            return out.toString();
        }
    }
}

