/*
 * Decompiled with CFR 0.152.
 */
package cc.mallet.topics;

import cc.mallet.types.Alphabet;
import cc.mallet.types.FeatureSequence;
import cc.mallet.types.Instance;
import cc.mallet.types.InstanceList;
import cc.mallet.util.Randoms;
import java.io.BufferedWriter;
import java.io.File;
import java.io.FileWriter;
import java.io.IOException;
import java.io.PrintWriter;
import java.text.NumberFormat;
import java.util.Arrays;

/**
 * Four-level Pachinko Allocation Model (PAM). Every token carries a super-topic and a
 * sub-topic assignment; the pair is resampled jointly by Gibbs sampling in
 * {@link #estimate}. Super-topic proportions use the Dirichlet parameters {@code alpha}
 * (each initialized to {@code alphaSum / numSuperTopics}); every super-topic has its own
 * Dirichlet parameters {@code subAlphas[superTopic]} over sub-topics (initialized to 1.0 and
 * optionally re-estimated with {@link #learnParameters}); sub-topic/word distributions are
 * smoothed with the symmetric parameter {@code beta}.
 *
 * <p>Sampling scratch space ({@code superSubCounts}, {@code superWeights}, ...) lives in
 * instance fields, so one instance must not be used from several threads at once.
 */
public class PAM4L {
    /** Number of super-topics. */
    int numSuperTopics;
    /** Number of sub-topics (the level whose distributions generate words). */
    int numSubTopics;
    /** Dirichlet parameters over super-topics. */
    double[] alpha;
    double alphaSum;
    /** Dirichlet parameters over sub-topics, indexed [superTopic][subTopic]. */
    double[][] subAlphas;
    /** Row sums of {@code subAlphas}; recomputed after each optimization step. */
    double[] subAlphaSums;
    double beta;
    /** {@code beta * numTypes}: smoothing mass in the sub-topic/word denominator. */
    double vBeta;
    InstanceList ilist;
    int numTypes;
    int numTokens;
    /** Current super-topic assignment, indexed [document][position]. */
    int[][] superTopics;
    /** Current sub-topic assignment, indexed [document][position]. */
    int[][] subTopics;
    /** Scratch: (super, sub) assignment counts of the document currently being sampled. */
    int[][] superSubCounts;
    /** Scratch: super-topic assignment counts of the document currently being sampled. */
    int[] superCounts;
    /** Scratch sampling weights; every entry is overwritten for each token. */
    double[] superWeights;
    double[] subWeights;
    double[][] superSubWeights;
    double[] cumulativeSuperWeights;
    /** Corpus-wide counts indexed [wordType][subTopic]. */
    int[][] typeSubTopicCounts;
    int[] tokensPerSubTopic;
    int[] tokensPerSuperTopic;
    int[][] tokensPerSuperSubTopic;
    /** [superTopic][n]: number of documents with exactly n tokens in that super-topic (current sweep). */
    int[][] superTopicHistograms;
    /** [superTopic][subTopic][n]: number of documents with exactly n tokens in that pair (current sweep). */
    int[][][] subTopicHistograms;
    Runtime runtime;
    NumberFormat formatter = NumberFormat.getInstance();

    /** Creates a model with {@code alphaSum = 50.0} and {@code beta = 0.001}. */
    public PAM4L(int superTopics, int subTopics) {
        this(superTopics, subTopics, 50.0, 0.001);
    }

    /**
     * Creates a model.
     *
     * @param superTopics number of super-topics
     * @param subTopics number of sub-topics
     * @param alphaSum total Dirichlet mass over super-topics, split evenly among them
     * @param beta symmetric smoothing parameter for sub-topic/word distributions
     */
    public PAM4L(int superTopics, int subTopics, double alphaSum, double beta) {
        this.formatter.setMaximumFractionDigits(5);
        this.numSuperTopics = superTopics;
        this.numSubTopics = subTopics;
        this.alphaSum = alphaSum;
        this.alpha = new double[superTopics];
        Arrays.fill(this.alpha, alphaSum / (double) superTopics);
        this.subAlphas = new double[superTopics][subTopics];
        for (double[] row : this.subAlphas) {
            Arrays.fill(row, 1.0);
        }
        this.subAlphaSums = new double[superTopics];
        Arrays.fill(this.subAlphaSums, (double) subTopics);
        this.beta = beta;
        this.runtime = Runtime.getRuntime();
    }

    /**
     * Randomly initializes all assignments, then runs {@code numIterations} Gibbs sweeps.
     *
     * @param documents instances whose data are {@link FeatureSequence}s
     * @param numIterations number of full sampling sweeps
     * @param optimizeInterval re-estimate {@code subAlphas} every this many iterations (0 disables)
     * @param showTopicsInterval print the top 5 words per sub-topic every this many iterations (0 disables)
     * @param outputModelInterval currently unused by this implementation
     * @param outputModelFilename currently unused by this implementation
     * @param r random source used for initialization and sampling
     */
    public void estimate(InstanceList documents, int numIterations, int optimizeInterval, int showTopicsInterval, int outputModelInterval, String outputModelFilename, Randoms r) {
        this.ilist = documents;
        this.numTypes = this.ilist.getDataAlphabet().size();
        int numDocs = this.ilist.size();
        this.superTopics = new int[numDocs][];
        this.subTopics = new int[numDocs][];
        this.superSubCounts = new int[this.numSuperTopics][this.numSubTopics];
        this.superCounts = new int[this.numSuperTopics];
        this.superWeights = new double[this.numSuperTopics];
        this.subWeights = new double[this.numSubTopics];
        this.superSubWeights = new double[this.numSuperTopics][this.numSubTopics];
        this.cumulativeSuperWeights = new double[this.numSuperTopics];
        this.typeSubTopicCounts = new int[this.numTypes][this.numSubTopics];
        this.tokensPerSubTopic = new int[this.numSubTopics];
        this.tokensPerSuperTopic = new int[this.numSuperTopics];
        this.tokensPerSuperSubTopic = new int[this.numSuperTopics][this.numSubTopics];
        this.vBeta = this.beta * (double) this.numTypes;
        // Reset so that calling estimate() again does not keep adding to the old total.
        this.numTokens = 0;

        long startTime = System.currentTimeMillis();
        int maxTokens = this.initializeAssignments(r);
        System.out.println("max tokens: " + maxTokens);
        // Per-document counts never exceed the longest document, so maxTokens + 1 bins suffice.
        this.superTopicHistograms = new int[this.numSuperTopics][maxTokens + 1];
        this.subTopicHistograms = new int[this.numSuperTopics][this.numSubTopics][maxTokens + 1];

        for (int iteration = 0; iteration < numIterations; iteration++) {
            long iterationStart = System.currentTimeMillis();
            this.clearHistograms();
            this.sampleTopicsForAllDocs(r);
            if (iteration > 0) {
                if (showTopicsInterval != 0 && iteration % showTopicsInterval == 0) {
                    System.out.println();
                    this.printTopWords(5, false);
                }
                // outputModelInterval / outputModelFilename: no model output is implemented here.
                if (optimizeInterval != 0 && iteration % optimizeInterval == 0) {
                    long optimizeTime = System.currentTimeMillis();
                    this.optimizeSubAlphas();
                    System.out.print("[o:" + (System.currentTimeMillis() - optimizeTime) + "]");
                }
            }
            // Debug output kept from the original: dump (super, sub) token counts on every
            // iteration past 1107.
            if (iteration > 1107) {
                this.printWordCounts();
            }
            if (iteration % 10 == 0) {
                System.out.println("<" + iteration + "> ");
            }
            System.out.print((System.currentTimeMillis() - iterationStart) + " ");
            System.out.flush();
        }
        this.printElapsedTime(startTime);
    }

    /**
     * Assigns every token a uniformly random (super, sub) pair and fills the global counts.
     *
     * @return the length of the longest document
     */
    private int initializeAssignments(Randoms r) {
        int maxTokens = 0;
        for (int di = 0; di < this.superTopics.length; di++) {
            FeatureSequence fs = (FeatureSequence) ((Instance) this.ilist.get(di)).getData();
            int seqLen = fs.getLength();
            maxTokens = Math.max(maxTokens, seqLen);
            this.numTokens += seqLen;
            this.superTopics[di] = new int[seqLen];
            this.subTopics[di] = new int[seqLen];
            for (int si = 0; si < seqLen; si++) {
                // The super-topic is drawn before the sub-topic; keep this order of random draws.
                int superTopic = r.nextInt(this.numSuperTopics);
                int subTopic = r.nextInt(this.numSubTopics);
                this.superTopics[di][si] = superTopic;
                this.subTopics[di][si] = subTopic;
                this.tokensPerSuperTopic[superTopic]++;
                this.typeSubTopicCounts[fs.getIndexAtPosition(si)][subTopic]++;
                this.tokensPerSubTopic[subTopic]++;
                this.tokensPerSuperSubTopic[superTopic][subTopic]++;
            }
        }
        return maxTokens;
    }

    /** Re-estimates each super-topic's sub-topic Dirichlet from this sweep's histograms. */
    private void optimizeSubAlphas() {
        for (int superTopic = 0; superTopic < this.numSuperTopics; superTopic++) {
            this.learnParameters(this.subAlphas[superTopic], this.subTopicHistograms[superTopic], this.superTopicHistograms[superTopic]);
            double sum = 0.0;
            for (double a : this.subAlphas[superTopic]) {
                sum += a;
            }
            this.subAlphaSums[superTopic] = sum;
        }
    }

    /** Prints the wall-clock time elapsed since {@code startTime} as days/hours/minutes/seconds. */
    private void printElapsedTime(long startTime) {
        long seconds = Math.round((double) (System.currentTimeMillis() - startTime) / 1000.0);
        long minutes = seconds / 60L;
        seconds %= 60L;
        long hours = minutes / 60L;
        minutes %= 60L;
        long days = hours / 24L;
        hours %= 24L;
        System.out.print("\nTotal time: ");
        if (days != 0L) {
            System.out.print(days);
            System.out.print(" days ");
        }
        if (hours != 0L) {
            System.out.print(hours);
            System.out.print(" hours ");
        }
        if (minutes != 0L) {
            System.out.print(minutes);
            System.out.print(" minutes ");
        }
        System.out.print(seconds);
        System.out.println(" seconds");
    }

    /** Zeroes the per-sweep document-count histograms. */
    private void clearHistograms() {
        for (int superTopic = 0; superTopic < this.numSuperTopics; superTopic++) {
            Arrays.fill(this.superTopicHistograms[superTopic], 0);
            for (int subTopic = 0; subTopic < this.numSubTopics; subTopic++) {
                Arrays.fill(this.subTopicHistograms[superTopic][subTopic], 0);
            }
        }
    }

    /**
     * Updates Dirichlet parameters in place with 200 fixed-point iterations of
     * {@code a_k <- a_k * sum_n h_k[n] (psi(a_k + n) - psi(a_k)) / sum_n L[n] (psi(A + n) - psi(A))},
     * where {@code A = sum_k a_k}. The digamma differences are accumulated incrementally as
     * {@code sum_{i=1..n} 1 / (a + i - 1)}.
     *
     * @param parameters Dirichlet parameters, updated in place
     * @param observations {@code observations[k][n]}: number of groups in which outcome k occurred exactly n times
     * @param observationLengths {@code observationLengths[n]}: number of groups with exactly n observations in total
     */
    public void learnParameters(double[] parameters, int[][] observations, int[] observationLengths) {
        double parametersSum = 0.0;
        for (double parameter : parameters) {
            parametersSum += parameter;
        }

        // nonZeroLimits[k] = largest n >= 1 with observations[k][n] > 0, or -1 when outcome k was
        // never observed. Index 0 (groups without outcome k) carries no information here; counting
        // it made the 1e-6 floor unreachable and produced 0 or NaN parameters instead.
        int[] nonZeroLimits = new int[observations.length];
        Arrays.fill(nonZeroLimits, -1);
        for (int k = 0; k < observations.length; k++) {
            int[] histogram = observations[k];
            for (int n = 1; n < histogram.length; n++) {
                if (histogram[n] > 0) {
                    nonZeroLimits[k] = n;
                }
            }
        }

        for (int iteration = 0; iteration < 200; iteration++) {
            double denominator = 0.0;
            double currentDigamma = 0.0;
            for (int n = 1; n < observationLengths.length; n++) {
                currentDigamma += 1.0 / (parametersSum + (double) n - 1.0);
                denominator += (double) observationLengths[n] * currentDigamma;
            }
            if (!(denominator > 0.0)) {
                // No group has any observation (or the sum is not finite-positive): the update
                // would be 0/0, so leave the parameters unchanged.
                return;
            }

            parametersSum = 0.0;
            for (int k = 0; k < parameters.length; k++) {
                int nonZeroLimit = nonZeroLimits[k];
                if (nonZeroLimit == -1) {
                    parameters[k] = 1.0E-6;
                    parametersSum += 1.0E-6;
                    continue;
                }
                double oldParametersK = parameters[k];
                int[] histogram = observations[k];
                double numerator = 0.0;
                currentDigamma = 0.0;
                for (int n = 1; n <= nonZeroLimit; n++) {
                    currentDigamma += 1.0 / (oldParametersK + (double) n - 1.0);
                    numerator += (double) histogram[n] * currentDigamma;
                }
                parameters[k] = numerator * (oldParametersK / denominator);
                if (Double.isNaN(parameters[k])) {
                    System.out.println("parametersK *= " + oldParametersK + " / " + denominator);
                    for (int n = 1; n < histogram.length; n++) {
                        System.out.print(histogram[n] + " ");
                    }
                    System.out.println();
                }
                parametersSum += parameters[k];
            }
        }
    }

    /** Runs one Gibbs sweep over every document. */
    private void sampleTopicsForAllDocs(Randoms r) {
        for (int di = 0; di < this.superTopics.length; di++) {
            this.sampleTopicsForOneDoc((FeatureSequence) ((Instance) this.ilist.get(di)).getData(), this.superTopics[di], this.subTopics[di], r);
        }
    }

    /**
     * Resamples the (super, sub) pair of every token in one document, then records the
     * document's final per-topic counts in the histograms.
     *
     * @param oneDocTokens the document's word types
     * @param superTopics the document's super-topic assignments, updated in place
     * @param subTopics the document's sub-topic assignments, updated in place
     * @param r random source
     */
    private void sampleTopicsForOneDoc(FeatureSequence oneDocTokens, int[] superTopics, int[] subTopics, Randoms r) {
        int docLen = oneDocTokens.getLength();

        // Rebuild this document's counts from its current assignments.
        for (int[] row : this.superSubCounts) {
            Arrays.fill(row, 0);
        }
        Arrays.fill(this.superCounts, 0);
        for (int si = 0; si < docLen; si++) {
            this.superSubCounts[superTopics[si]][subTopics[si]]++;
            this.superCounts[superTopics[si]]++;
        }

        for (int si = 0; si < docLen; si++) {
            int type = oneDocTokens.getIndexAtPosition(si);
            int oldSuper = superTopics[si];
            int oldSub = subTopics[si];

            // Remove this token's current assignment from every count.
            this.superSubCounts[oldSuper][oldSub]--;
            this.superCounts[oldSuper]--;
            this.typeSubTopicCounts[type][oldSub]--;
            this.tokensPerSuperTopic[oldSuper]--;
            this.tokensPerSubTopic[oldSub]--;
            this.tokensPerSuperSubTopic[oldSuper][oldSub]--;

            // Every entry of the scratch weight arrays is assigned below, so no clearing is needed.
            int[] currentTypeSubTopicCounts = this.typeSubTopicCounts[type];
            for (int t = 0; t < this.numSuperTopics; t++) {
                this.superWeights[t] = ((double) this.superCounts[t] + this.alpha[t]) / ((double) this.superCounts[t] + this.subAlphaSums[t]);
            }
            for (int s = 0; s < this.numSubTopics; s++) {
                this.subWeights[s] = ((double) currentTypeSubTopicCounts[s] + this.beta) / ((double) this.tokensPerSubTopic[s] + this.vBeta);
            }
            double cumulativeWeight = 0.0;
            for (int t = 0; t < this.numSuperTopics; t++) {
                double[] currentSuperSubWeights = this.superSubWeights[t];
                int[] currentSuperSubCounts = this.superSubCounts[t];
                double[] currentSubAlpha = this.subAlphas[t];
                double currentSuperWeight = this.superWeights[t];
                for (int s = 0; s < this.numSubTopics; s++) {
                    currentSuperSubWeights[s] = currentSuperWeight * this.subWeights[s] * ((double) currentSuperSubCounts[s] + currentSubAlpha[s]);
                    cumulativeWeight += currentSuperSubWeights[s];
                }
                this.cumulativeSuperWeights[t] = cumulativeWeight;
            }

            // Inverse-CDF draw: locate the super-topic bucket, then walk its sub-topic slices
            // downward from the bucket's upper edge. The index guards only matter when
            // floating-point rounding (or a sample of exactly 0) would walk past the last slot.
            double sample = r.nextUniform() * cumulativeWeight;
            int newSuper = 0;
            while (newSuper < this.numSuperTopics - 1 && sample > this.cumulativeSuperWeights[newSuper]) {
                newSuper++;
            }
            double[] chosenWeights = this.superSubWeights[newSuper];
            double remaining = this.cumulativeSuperWeights[newSuper] - chosenWeights[0];
            int newSub = 0;
            while (newSub < this.numSubTopics - 1 && sample < remaining) {
                remaining -= chosenWeights[++newSub];
            }

            superTopics[si] = newSuper;
            subTopics[si] = newSub;
            this.superSubCounts[newSuper][newSub]++;
            this.superCounts[newSuper]++;
            this.typeSubTopicCounts[type][newSub]++;
            this.tokensPerSuperTopic[newSuper]++;
            this.tokensPerSubTopic[newSub]++;
            this.tokensPerSuperSubTopic[newSuper][newSub]++;
        }

        for (int t = 0; t < this.numSuperTopics; t++) {
            this.superTopicHistograms[t][this.superCounts[t]]++;
            int[] currentSuperSubCounts = this.superSubCounts[t];
            for (int s = 0; s < this.numSubTopics; s++) {
                this.subTopicHistograms[t][s][currentSuperSubCounts[s]]++;
            }
        }
    }

    /** Prints a [superTopic] x [subTopic] table of token counts, each with its sub-alpha. */
    public void printWordCounts() {
        StringBuilder output = new StringBuilder();
        for (int superTopic = 0; superTopic < this.numSuperTopics; superTopic++) {
            for (int subTopic = 0; subTopic < this.numSubTopics; subTopic++) {
                output.append(this.tokensPerSuperSubTopic[superTopic][subTopic] + " (" + this.formatter.format(this.subAlphas[superTopic][subTopic]) + ")\t");
            }
            output.append("\n");
        }
        System.out.println(output);
    }

    /**
     * Prints the most probable words of every sub-topic, then, for each super-topic, its
     * (at most 10) sub-topics with the largest sub-alpha.
     *
     * @param numWords words per sub-topic; capped at the vocabulary size
     * @param useNewLines print one word (with probability) per line instead of one line per topic
     */
    public void printTopWords(int numWords, boolean useNewLines) {
        Alphabet alphabet = this.ilist.getDataAlphabet();
        // Requesting more words than the vocabulary holds used to run off the end of the array.
        int wordsToShow = Math.min(numWords, this.numTypes);
        IDSorter[] wp = new IDSorter[this.numTypes];
        String[] subTopicTerms = new String[this.numSubTopics];
        for (int subTopic = 0; subTopic < this.numSubTopics; subTopic++) {
            int total = this.tokensPerSubTopic[subTopic];
            for (int wi = 0; wi < this.numTypes; wi++) {
                // An empty sub-topic would otherwise give 0/0 = NaN for every word.
                double p = total == 0 ? 0.0 : (double) this.typeSubTopicCounts[wi][subTopic] / (double) total;
                wp[wi] = new IDSorter(wi, p);
            }
            Arrays.sort(wp);
            StringBuilder topicTerms = new StringBuilder();
            for (int i = 0; i < wordsToShow; i++) {
                topicTerms.append(alphabet.lookupObject(wp[i].wi));
                topicTerms.append(" ");
            }
            subTopicTerms[subTopic] = topicTerms.toString();
            if (useNewLines) {
                System.out.println("\nTopic " + subTopic);
                for (int i = 0; i < wordsToShow; i++) {
                    System.out.println(alphabet.lookupObject(wp[i].wi).toString() + "\t" + this.formatter.format(wp[i].p));
                }
            } else {
                System.out.println("Topic " + subTopic + ":\t[" + this.tokensPerSubTopic[subTopic] + "]\t" + subTopicTerms[subTopic]);
            }
        }

        int maxSubTopics = Math.min(10, this.numSubTopics);
        IDSorter[] sortedSubTopics = new IDSorter[this.numSubTopics];
        for (int superTopic = 0; superTopic < this.numSuperTopics; superTopic++) {
            for (int subTopic = 0; subTopic < this.numSubTopics; subTopic++) {
                sortedSubTopics[subTopic] = new IDSorter(subTopic, this.subAlphas[superTopic][subTopic]);
            }
            Arrays.sort(sortedSubTopics);
            System.out.println("\nSuper-topic " + superTopic + "[" + this.tokensPerSuperTopic[superTopic] + "]\t");
            for (int i = 0; i < maxSubTopics; i++) {
                int subTopic = sortedSubTopics[i].wi;
                System.out.println(subTopic + ":\t" + this.formatter.format(this.subAlphas[superTopic][subTopic]) + "\t" + subTopicTerms[subTopic]);
            }
        }
    }

    /**
     * Writes every document's topic proportions to {@code f}, then closes the file.
     */
    public void printDocumentTopics(File f) throws IOException {
        // This method opens the writer, so it must close it; otherwise the buffered output is lost.
        try (PrintWriter pw = new PrintWriter(new BufferedWriter(new FileWriter(f)))) {
            this.printDocumentTopics(pw, 0.0, -1);
        }
    }

    /**
     * Writes one line per document: index, source, then "topic proportion" pairs for
     * sub-topics, " , ", and the same for super-topics, each list in descending order.
     * The writer is not closed.
     *
     * @param pw destination
     * @param threshold stop a list at the first proportion below this value
     * @param max maximum entries per list; negative means the number of topics at that level
     */
    public void printDocumentTopics(PrintWriter pw, double threshold, int max) {
        pw.println("#doc source subtopic-proportions , supertopic-proportions");
        // Resolve the default per level; the old code applied the sub-topic default to both.
        int maxSub = max < 0 ? this.numSubTopics : max;
        int maxSuper = max < 0 ? this.numSuperTopics : max;
        double[] superTopicDist = new double[this.numSuperTopics];
        double[] subTopicDist = new double[this.numSubTopics];
        for (int di = 0; di < this.superTopics.length; di++) {
            pw.print(di);
            pw.print(' ');
            Object source = ((Instance) this.ilist.get(di)).getSource();
            pw.print(source != null ? source.toString() : "null-source");
            pw.print(' ');

            // Start each document from zero: unprinted values of the previous document
            // used to leak into this one.
            Arrays.fill(superTopicDist, 0.0);
            Arrays.fill(subTopicDist, 0.0);
            int docLen = this.subTopics[di].length;
            for (int si = 0; si < docLen; si++) {
                superTopicDist[this.superTopics[di][si]] += 1.0;
                subTopicDist[this.subTopics[di][si]] += 1.0;
            }
            if (docLen > 0) {
                for (int t = 0; t < this.numSuperTopics; t++) {
                    superTopicDist[t] /= (double) docLen;
                }
                for (int s = 0; s < this.numSubTopics; s++) {
                    subTopicDist[s] /= (double) docLen;
                }
            }

            printTopProportions(pw, subTopicDist, threshold, maxSub);
            pw.print(" , ");
            printTopProportions(pw, superTopicDist, threshold, maxSuper);
            pw.println();
        }
    }

    /**
     * Prints up to {@code limit} "index value " pairs from {@code dist} in descending order,
     * stopping early when the largest remaining value is below {@code threshold} or no
     * positive value is left. Printed entries are set to 0 in {@code dist}.
     */
    private static void printTopProportions(PrintWriter pw, double[] dist, double threshold, int limit) {
        for (int printed = 0; printed < limit; printed++) {
            double maxValue = 0.0;
            int maxIndex = -1;
            for (int i = 0; i < dist.length; i++) {
                if (dist[i] > maxValue) {
                    maxValue = dist[i];
                    maxIndex = i;
                }
            }
            if (maxIndex == -1 || dist[maxIndex] < threshold) {
                break;
            }
            pw.print(maxIndex + " " + dist[maxIndex] + " ");
            dist[maxIndex] = 0.0;
        }
    }

    /** Writes the full sampling state to {@code f}. */
    public void printState(File f) throws IOException {
        this.printState(new PrintWriter(new BufferedWriter(new FileWriter(f))));
    }

    /**
     * Writes one line per token ("doc pos typeindex type super-topic sub-topic") and then
     * closes {@code pw}.
     */
    public void printState(PrintWriter pw) {
        Alphabet a = this.ilist.getDataAlphabet();
        pw.println("#doc pos typeindex type super-topic sub-topic");
        for (int di = 0; di < this.superTopics.length; di++) {
            FeatureSequence fs = (FeatureSequence) ((Instance) this.ilist.get(di)).getData();
            for (int si = 0; si < this.superTopics[di].length; si++) {
                int type = fs.getIndexAtPosition(si);
                pw.print(di);
                pw.print(' ');
                pw.print(si);
                pw.print(' ');
                pw.print(type);
                pw.print(' ');
                pw.print(a.lookupObject(type));
                pw.print(' ');
                pw.print(this.superTopics[di][si]);
                pw.print(' ');
                pw.print(this.subTopics[di][si]);
                pw.println();
            }
        }
        pw.close();
    }

    /**
     * Command-line entry point:
     * {@code instance-list-file [iterations=1000 [top-words=20 [super-topics=10 [sub-topics=10]]]]}.
     */
    public static void main(String[] args) throws IOException {
        if (args.length < 1) {
            System.err.println("Usage: PAM4L instance-list-file [iterations [top-words [super-topics [sub-topics]]]]");
            System.exit(1);
        }
        InstanceList ilist = InstanceList.load(new File(args[0]));
        int numIterations = args.length > 1 ? Integer.parseInt(args[1]) : 1000;
        int numTopWords = args.length > 2 ? Integer.parseInt(args[2]) : 20;
        int numSuperTopics = args.length > 3 ? Integer.parseInt(args[3]) : 10;
        int numSubTopics = args.length > 4 ? Integer.parseInt(args[4]) : 10;
        System.out.println("Data loaded.");
        PAM4L pam = new PAM4L(numSuperTopics, numSubTopics);
        pam.estimate(ilist, numIterations, 50, 0, 50, null, new Randoms());
        pam.printTopWords(numTopWords, true);
    }

    /** Pairs an index with a score; the natural order puts higher scores first. */
    class IDSorter implements Comparable<IDSorter> {
        int wi;
        double p;

        public IDSorter(int wi, double p) {
            this.wi = wi;
            this.p = p;
        }

        /** Descending by {@code p}; {@link Double#compare} keeps the order total even with NaN. */
        @Override
        public final int compareTo(IDSorter other) {
            return Double.compare(other.p, this.p);
        }
    }
}

