package cc.mallet.topics;

import cc.mallet.types.Alphabet;
import cc.mallet.types.FeatureSequence;
import cc.mallet.types.InstanceList;
import cc.mallet.util.Randoms;
import java.io.BufferedWriter;
import java.io.File;
import java.io.FileWriter;
import java.io.IOException;
import java.io.PrintWriter;
import java.text.NumberFormat;
import java.util.Arrays;
import org.apache.jena.sparql.sse.Tags;
import org.apache.lucene.search.suggest.FileDictionary;

/* loaded from: input_file:cc/mallet/topics/PAM4L.class */
public class PAM4L {
    // Topic counts, set by the constructor.
    int numSuperTopics;
    int numSubTopics;
    // Super-topic weights used by the sampler; the constructor fills every entry with alphaSum / numSuperTopics.
    double[] alpha;
    double alphaSum;
    // subAlphas[superTopic][subTopic]: start at 1.0 and are re-estimated by learnParameters() during estimate().
    double[][] subAlphas;
    // subAlphaSums[superTopic]: sum of the corresponding subAlphas row.
    double[] subAlphaSums;
    // Word smoothing constant, and vBeta = beta * numTypes (computed in estimate()).
    double beta;
    double vBeta;
    // Training documents and the sizes derived from them in estimate().
    InstanceList ilist;
    int numTypes;
    int numTokens;
    // Current assignment of every token, indexed [document][position].
    int[][] superTopics;
    int[][] subTopics;
    // Counts for the document currently being sampled; rebuilt at the start of sampleTopicsForOneDoc().
    int[][] superSubCounts;
    int[] superCounts;
    // Scratch buffers reused for every token while sampling.
    double[] superWeights;
    double[] subWeights;
    double[][] superSubWeights;
    double[] cumulativeSuperWeights;
    // Corpus-wide counts: [type][subTopic], [subTopic], [superTopic], [superTopic][subTopic].
    int[][] typeSubTopicCounts;
    int[] tokensPerSubTopic;
    int[] tokensPerSuperTopic;
    int[][] tokensPerSuperSubTopic;
    // superTopicHistograms[superTopic][n]: number of documents with n tokens in that super-topic after the latest sweep.
    // subTopicHistograms[superTopic][subTopic][n]: the same for each super-/sub-topic pair.
    int[][] superTopicHistograms;
    int[][][] subTopicHistograms;
    // Assigned in the constructor; not read anywhere in the visible code of this class.
    Runtime runtime;
    // Limited to 5 fraction digits; used when printing alphas and word proportions.
    NumberFormat formatter;

    /* JADX INFO: Access modifiers changed from: package-private */
    /* loaded from: input_file:cc/mallet/topics/PAM4L$IDSorter.class */
    /**
     * Pairs an integer id with a score so that an array of entries can be sorted
     * by score, highest first.
     */
    public class IDSorter implements Comparable {
        /** Id carried along with the score (a type index or a sub-topic index in this file). */
        int wi;
        /** Score the entries are ordered by. */
        double p;

        public IDSorter(int i, double d) {
            this.wi = i;
            this.p = d;
        }

        /**
         * Orders entries by descending score.
         *
         * <p>Uses {@link Double#compare(double, double)} with swapped operands so the
         * ordering is total even for NaN scores (NaN sorts first). Hand-written
         * {@code >} / {@code ==} tests return 1 in both directions for NaN, which
         * breaks the {@code Comparable} contract and can make {@code Arrays.sort} throw.
         * Note that 0.0 now sorts before -0.0 instead of comparing equal.
         *
         * @param obj the other {@code IDSorter}
         * @return a negative value when this score is higher, zero when equal, positive otherwise
         * @throws ClassCastException if {@code obj} is not an {@code IDSorter}
         */
        @Override // java.lang.Comparable
        public final int compareTo(Object obj) {
            return Double.compare(((IDSorter) obj).p, this.p);
        }
    }

    /**
     * Creates a model with the default hyperparameters: an alpha sum of 50 and a
     * beta of 0.001.
     *
     * @param i  number of super-topics
     * @param i2 number of sub-topics
     */
    public PAM4L(int i, int i2) {
        this(i, i2, 50.0, 0.001);
    }

    public PAM4L(int i, int i2, double d, double d2) {
        this.formatter = NumberFormat.getInstance();
        this.formatter.setMaximumFractionDigits(5);
        this.numSuperTopics = i;
        this.numSubTopics = i2;
        this.alphaSum = d;
        this.alpha = new double[i];
        Arrays.fill(this.alpha, d / this.numSuperTopics);
        this.subAlphas = new double[i][i2];
        this.subAlphaSums = new double[i];
        for (int i3 = 0; i3 < i; i3++) {
            Arrays.fill(this.subAlphas[i3], 1.0d);
        }
        Arrays.fill(this.subAlphaSums, i2);
        this.beta = d2;
        this.runtime = Runtime.getRuntime();
    }

    /**
     * Randomly initialises the topic assignment of every token and then runs the
     * sampler for a fixed number of iterations, printing progress to standard output.
     *
     * <p>All count and scratch arrays, including {@code numTokens}, are rebuilt from
     * scratch, so the method can be called again on the same object.
     *
     * @param instanceList documents to train on; the data of each instance is cast to
     *                     {@code FeatureSequence}
     * @param i            number of sampling iterations
     * @param i2           optimisation interval: on every iteration after the first that is a
     *                     multiple of this value the sub-topic alphas are re-estimated with
     *                     {@link #learnParameters}; 0 disables it
     * @param i3           display interval: on every iteration after the first that is a multiple
     *                     of this value the top 5 words per topic are printed; 0 disables it
     * @param i4           not used
     * @param str          not used
     * @param randoms      source of random numbers for initialisation and sampling
     */
    public void estimate(InstanceList instanceList, int i, int i2, int i3, int i4, String str, Randoms randoms) {
        this.ilist = instanceList;
        this.numTypes = this.ilist.getDataAlphabet().size();
        final int numDocs = this.ilist.size();

        // Jagged arrays: one row per document, allocated below once its length is known.
        this.superTopics = new int[numDocs][];
        this.subTopics = new int[numDocs][];
        this.superSubCounts = new int[this.numSuperTopics][this.numSubTopics];
        this.superCounts = new int[this.numSuperTopics];
        this.superWeights = new double[this.numSuperTopics];
        this.subWeights = new double[this.numSubTopics];
        this.superSubWeights = new double[this.numSuperTopics][this.numSubTopics];
        this.cumulativeSuperWeights = new double[this.numSuperTopics];
        this.typeSubTopicCounts = new int[this.numTypes][this.numSubTopics];
        this.tokensPerSubTopic = new int[this.numSubTopics];
        this.tokensPerSuperTopic = new int[this.numSuperTopics];
        this.tokensPerSuperSubTopic = new int[this.numSuperTopics][this.numSubTopics];
        this.vBeta = this.beta * this.numTypes;
        // Reset so that a repeated call does not double-count tokens.
        this.numTokens = 0;

        final long startTime = System.currentTimeMillis();

        // Random initial assignment; the corpus-wide counts are updated to match.
        int maxTokens = 0;
        for (int doc = 0; doc < numDocs; doc++) {
            FeatureSequence tokens = (FeatureSequence) this.ilist.get(doc).getData();
            int docLength = tokens.getLength();
            if (docLength > maxTokens) {
                maxTokens = docLength;
            }
            this.numTokens += docLength;
            this.superTopics[doc] = new int[docLength];
            this.subTopics[doc] = new int[docLength];
            for (int pos = 0; pos < docLength; pos++) {
                int superTopic = randoms.nextInt(this.numSuperTopics);
                this.superTopics[doc][pos] = superTopic;
                this.tokensPerSuperTopic[superTopic]++;
                int subTopic = randoms.nextInt(this.numSubTopics);
                this.subTopics[doc][pos] = subTopic;
                this.typeSubTopicCounts[tokens.getIndexAtPosition(pos)][subTopic]++;
                this.tokensPerSubTopic[subTopic]++;
                this.tokensPerSuperSubTopic[superTopic][subTopic]++;
            }
        }
        System.out.println("max tokens: " + maxTokens);

        // A per-document count can range from 0 to maxTokens inclusive.
        this.superTopicHistograms = new int[this.numSuperTopics][maxTokens + 1];
        this.subTopicHistograms = new int[this.numSuperTopics][this.numSubTopics][maxTokens + 1];

        for (int iteration = 0; iteration < i; iteration++) {
            long iterationStart = System.currentTimeMillis();
            clearHistograms();
            sampleTopicsForAllDocs(randoms);

            if (iteration > 0) {
                if (i3 != 0 && iteration % i3 == 0) {
                    System.out.println();
                    printTopWords(5, false);
                }
                if (i2 != 0 && iteration % i2 == 0) {
                    long optimizeStart = System.currentTimeMillis();
                    for (int superTopic = 0; superTopic < this.numSuperTopics; superTopic++) {
                        learnParameters(this.subAlphas[superTopic], this.subTopicHistograms[superTopic],
                                this.superTopicHistograms[superTopic]);
                        // Keep the cached row sum in step with the re-estimated alphas.
                        double rowSum = 0.0;
                        for (int subTopic = 0; subTopic < this.numSubTopics; subTopic++) {
                            rowSum += this.subAlphas[superTopic][subTopic];
                        }
                        this.subAlphaSums[superTopic] = rowSum;
                    }
                    System.out.print("[o:" + (System.currentTimeMillis() - optimizeStart) + "]");
                }
            }

            // Hard-coded diagnostic threshold kept as-is: word counts are dumped on every
            // iteration past 1107.
            if (iteration > 1107) {
                printWordCounts();
            }
            if (iteration % 10 == 0) {
                System.out.println("<" + iteration + "> ");
            }
            System.out.print((System.currentTimeMillis() - iterationStart) + " ");
            System.out.flush();
        }

        // Report the elapsed wall-clock time, omitting leading units that are zero.
        long totalSeconds = Math.round((System.currentTimeMillis() - startTime) / 1000.0);
        long totalMinutes = totalSeconds / 60;
        long seconds = totalSeconds % 60;
        long totalHours = totalMinutes / 60;
        long minutes = totalMinutes % 60;
        long days = totalHours / 24;
        long hours = totalHours % 24;
        System.out.print("\nTotal time: ");
        if (days != 0) {
            System.out.print(days);
            System.out.print(" days ");
        }
        if (hours != 0) {
            System.out.print(hours);
            System.out.print(" hours ");
        }
        if (minutes != 0) {
            System.out.print(minutes);
            System.out.print(" minutes ");
        }
        System.out.print(seconds);
        System.out.println(" seconds");
    }

    /** Zeroes every bucket of the super-topic and sub-topic histograms. */
    private void clearHistograms() {
        int superTopic = 0;
        while (superTopic < this.numSuperTopics) {
            Arrays.fill(this.superTopicHistograms[superTopic], 0);
            int subTopic = 0;
            while (subTopic < this.numSubTopics) {
                Arrays.fill(this.subTopicHistograms[superTopic][subTopic], 0);
                subTopic++;
            }
            superTopic++;
        }
    }

    /**
     * Re-estimates a vector of parameters in place from count histograms, using 200
     * rounds of a multiplicative fixed-point update.
     *
     * <p>In each round a parameter is replaced by
     * {@code old * numerator / denominator}. Both terms are sums of histogram counts
     * weighted by running sums of {@code 1 / (x + n - 1)}, with {@code x} being the
     * old parameter (numerator) or the sum of all parameters (denominator).
     *
     * <p>If the denominator is exactly zero (no entry of {@code iArr2} above index 0 is
     * set), the parameter keeps its current value. Dividing by zero there would turn it
     * into NaN.
     *
     * @param dArr  parameters to update in place
     * @param iArr  one histogram per parameter; {@code iArr[k][n]} is the number of
     *              observations in which dimension {@code k} had count {@code n}.
     *              Must have at least {@code dArr.length} rows
     * @param iArr2 histogram of total counts, indexed the same way
     */
    public void learnParameters(double[] dArr, int[][] iArr, int[] iArr2) {
        double parametersSum = 0.0;
        for (double value : dArr) {
            parametersSum += value;
        }

        // Highest count with a non-zero bucket per dimension, or -1 if the histogram is empty.
        int[] maxNonZeroIndex = new int[iArr.length];
        Arrays.fill(maxNonZeroIndex, -1);
        for (int k = 0; k < iArr.length; k++) {
            int[] histogram = iArr[k];
            for (int n = 0; n < histogram.length; n++) {
                if (histogram[n] > 0) {
                    maxNonZeroIndex[k] = n;
                }
            }
        }

        for (int iteration = 0; iteration < 200; iteration++) {
            double denominator = 0.0;
            double runningInverse = 0.0;
            for (int n = 1; n < iArr2.length; n++) {
                runningInverse += 1.0 / ((parametersSum + n) - 1.0);
                denominator += iArr2[n] * runningInverse;
            }

            parametersSum = 0.0;
            for (int k = 0; k < dArr.length; k++) {
                int limit = maxNonZeroIndex[k];
                if (limit == -1) {
                    // No observations at all for this dimension: pin it to a tiny value.
                    dArr[k] = 1.0E-6;
                } else if (denominator != 0.0) {
                    double oldValue = dArr[k];
                    int[] histogram = iArr[k];
                    double numerator = 0.0;
                    double inverseSum = 0.0;
                    for (int n = 1; n <= limit; n++) {
                        inverseSum += 1.0 / ((oldValue + n) - 1.0);
                        numerator += histogram[n] * inverseSum;
                    }
                    dArr[k] = numerator * (oldValue / denominator);
                    if (Double.isNaN(dArr[k])) {
                        // Diagnostic dump for a numerically broken update.
                        System.out.println("parametersK *= " + oldValue + " / " + denominator);
                        for (int n = 1; n < histogram.length; n++) {
                            System.out.print(histogram[n] + " ");
                        }
                        System.out.println();
                    }
                }
                // Otherwise the denominator is zero and the parameter is left unchanged.
                parametersSum += dArr[k];
            }
        }
    }

    /**
     * Runs one sampling sweep: resamples the topic assignments of every document in order.
     *
     * @param randoms source of random numbers
     */
    private void sampleTopicsForAllDocs(Randoms randoms) {
        final int numDocs = this.superTopics.length;
        for (int doc = 0; doc < numDocs; doc++) {
            FeatureSequence tokens = (FeatureSequence) this.ilist.get(doc).getData();
            int[] docSuperTopics = this.superTopics[doc];
            int[] docSubTopics = this.subTopics[doc];
            sampleTopicsForOneDoc(tokens, docSuperTopics, docSubTopics, randoms);
        }
    }

    /**
     * Resamples the (super-topic, sub-topic) pair of every token in one document and
     * then records the document's final counts in the histograms.
     *
     * @param featureSequence the document's tokens
     * @param iArr            current super-topic of each token; updated in place
     * @param iArr2           current sub-topic of each token; updated in place
     * @param randoms         source of random numbers (one uniform draw per token)
     */
    private void sampleTopicsForOneDoc(FeatureSequence featureSequence, int[] iArr, int[] iArr2, Randoms randoms) {
        final int docLength = featureSequence.getLength();

        // Rebuild the per-document counts from the current assignments.
        for (int superTopic = 0; superTopic < this.numSuperTopics; superTopic++) {
            Arrays.fill(this.superSubCounts[superTopic], 0);
        }
        Arrays.fill(this.superCounts, 0);
        for (int pos = 0; pos < docLength; pos++) {
            this.superSubCounts[iArr[pos]][iArr2[pos]]++;
            this.superCounts[iArr[pos]]++;
        }

        for (int pos = 0; pos < docLength; pos++) {
            final int type = featureSequence.getIndexAtPosition(pos);
            final int oldSuper = iArr[pos];
            final int oldSub = iArr2[pos];

            // Take the token out of every count before computing its sampling distribution.
            this.superSubCounts[oldSuper][oldSub]--;
            this.superCounts[oldSuper]--;
            this.typeSubTopicCounts[type][oldSub]--;
            this.tokensPerSuperTopic[oldSuper]--;
            this.tokensPerSubTopic[oldSub]--;
            this.tokensPerSuperSubTopic[oldSuper][oldSub]--;

            for (int superTopic = 0; superTopic < this.numSuperTopics; superTopic++) {
                Arrays.fill(this.superSubWeights[superTopic], 0.0);
            }
            Arrays.fill(this.superWeights, 0.0);
            Arrays.fill(this.subWeights, 0.0);
            Arrays.fill(this.cumulativeSuperWeights, 0.0);

            final int[] typeCounts = this.typeSubTopicCounts[type];
            for (int superTopic = 0; superTopic < this.numSuperTopics; superTopic++) {
                this.superWeights[superTopic] = (this.superCounts[superTopic] + this.alpha[superTopic])
                        / (this.superCounts[superTopic] + this.subAlphaSums[superTopic]);
            }
            for (int subTopic = 0; subTopic < this.numSubTopics; subTopic++) {
                this.subWeights[subTopic] = (typeCounts[subTopic] + this.beta)
                        / (this.tokensPerSubTopic[subTopic] + this.vBeta);
            }

            // Unnormalised weight of every pair, plus a running total per super-topic row.
            double totalWeight = 0.0;
            for (int superTopic = 0; superTopic < this.numSuperTopics; superTopic++) {
                final double[] rowWeights = this.superSubWeights[superTopic];
                final int[] rowCounts = this.superSubCounts[superTopic];
                final double[] rowAlphas = this.subAlphas[superTopic];
                final double superWeight = this.superWeights[superTopic];
                for (int subTopic = 0; subTopic < this.numSubTopics; subTopic++) {
                    rowWeights[subTopic] = superWeight * this.subWeights[subTopic]
                            * (rowCounts[subTopic] + rowAlphas[subTopic]);
                    totalWeight += rowWeights[subTopic];
                }
                this.cumulativeSuperWeights[superTopic] = totalWeight;
            }

            // Pick the super-topic whose cumulative range contains the draw, then walk
            // back from the top of that range to find the sub-topic.
            final double draw = randoms.nextUniform() * totalWeight;
            int newSuper = 0;
            while (draw > this.cumulativeSuperWeights[newSuper]) {
                newSuper++;
            }
            final double[] chosenRow = this.superSubWeights[newSuper];
            double remaining = this.cumulativeSuperWeights[newSuper] - chosenRow[0];
            int newSub = 0;
            while (draw < remaining) {
                newSub++;
                remaining -= chosenRow[newSub];
            }

            // Record the new assignment in every count.
            iArr[pos] = newSuper;
            iArr2[pos] = newSub;
            this.superSubCounts[newSuper][newSub]++;
            this.superCounts[newSuper]++;
            this.typeSubTopicCounts[type][newSub]++;
            this.tokensPerSuperTopic[newSuper]++;
            this.tokensPerSubTopic[newSub]++;
            this.tokensPerSuperSubTopic[newSuper][newSub]++;
        }

        // One histogram entry per document for each super-topic and each pair.
        for (int superTopic = 0; superTopic < this.numSuperTopics; superTopic++) {
            this.superTopicHistograms[superTopic][this.superCounts[superTopic]]++;
            final int[] rowCounts = this.superSubCounts[superTopic];
            for (int subTopic = 0; subTopic < this.numSubTopics; subTopic++) {
                this.subTopicHistograms[superTopic][subTopic][rowCounts[subTopic]]++;
            }
        }
    }

    /**
     * Prints a table to standard output with one row per super-topic and one
     * tab-terminated cell per sub-topic. Each cell holds the pair's token count
     * followed by its formatted alpha in parentheses.
     */
    public void printWordCounts() {
        StringBuilder table = new StringBuilder();
        for (int superTopic = 0; superTopic < this.numSuperTopics; superTopic++) {
            for (int subTopic = 0; subTopic < this.numSubTopics; subTopic++) {
                table.append(this.tokensPerSuperSubTopic[superTopic][subTopic])
                        .append(" (")
                        .append(this.formatter.format(this.subAlphas[superTopic][subTopic]))
                        .append(")\t");
            }
            table.append("\n");
        }
        System.out.println(table);
    }

    /**
     * Prints the highest-ranked words of every sub-topic, followed by the sub-topics
     * with the largest alphas (at most 10) for every super-topic.
     *
     * <p>Words are ranked by their share of the sub-topic's tokens, computed in
     * floating point. A sub-topic with no tokens gives every word a share of 0 rather
     * than dividing by zero. The number of words shown is capped at the vocabulary size.
     *
     * @param i number of words to show per sub-topic
     * @param z if true, print one word per line with its formatted share; otherwise print
     *          one summary line per sub-topic
     */
    public void printTopWords(int i, boolean z) {
        final int wordsToShow = Math.min(i, this.numTypes);
        IDSorter[] rankedTypes = new IDSorter[this.numTypes];
        IDSorter[] rankedSubTopics = new IDSorter[this.numSubTopics];
        String[] topicSummaries = new String[this.numSubTopics];

        for (int subTopic = 0; subTopic < this.numSubTopics; subTopic++) {
            final int topicTokens = this.tokensPerSubTopic[subTopic];
            for (int type = 0; type < this.numTypes; type++) {
                // Cast before dividing: integer division would truncate every share to 0 or 1.
                double share = topicTokens == 0
                        ? 0.0
                        : ((double) this.typeSubTopicCounts[type][subTopic]) / topicTokens;
                rankedTypes[type] = new IDSorter(type, share);
            }
            Arrays.sort(rankedTypes);

            StringBuilder summary = new StringBuilder();
            for (int rank = 0; rank < wordsToShow; rank++) {
                summary.append(this.ilist.getDataAlphabet().lookupObject(rankedTypes[rank].wi));
                summary.append(" ");
            }
            topicSummaries[subTopic] = summary.toString();

            if (z) {
                System.out.println("\nTopic " + subTopic);
                for (int rank = 0; rank < wordsToShow; rank++) {
                    System.out.println(this.ilist.getDataAlphabet().lookupObject(rankedTypes[rank].wi).toString()
                            + "\t" + this.formatter.format(rankedTypes[rank].p));
                }
            } else {
                System.out.println("Topic " + subTopic + ":\t[" + topicTokens + "]\t" + topicSummaries[subTopic]);
            }
        }

        final int subTopicsToShow = Math.min(this.numSubTopics, 10);
        for (int superTopic = 0; superTopic < this.numSuperTopics; superTopic++) {
            for (int subTopic = 0; subTopic < this.numSubTopics; subTopic++) {
                rankedSubTopics[subTopic] = new IDSorter(subTopic, this.subAlphas[superTopic][subTopic]);
            }
            Arrays.sort(rankedSubTopics);
            System.out.println("\nSuper-topic " + superTopic + "[" + this.tokensPerSuperTopic[superTopic] + "]\t");
            for (int rank = 0; rank < subTopicsToShow; rank++) {
                int subTopic = rankedSubTopics[rank].wi;
                System.out.println(subTopic + ":\t" + this.formatter.format(this.subAlphas[superTopic][subTopic])
                        + "\t" + topicSummaries[subTopic]);
            }
        }
    }

    /**
     * Writes the topic proportions of every document to a file, with no threshold and
     * no limit on the number of topics listed.
     *
     * <p>The writer is always closed, so buffered output reaches the file and the
     * file handle is released even if writing fails part-way.
     *
     * @param file destination file; overwritten if it exists
     * @throws IOException if the file cannot be opened for writing
     */
    public void printDocumentTopics(File file) throws IOException {
        PrintWriter writer = new PrintWriter(new BufferedWriter(new FileWriter(file)));
        try {
            printDocumentTopics(writer, 0.0d, -1);
        } finally {
            writer.close();
        }
    }

    /**
     * Writes one line per document after a header line. Each line holds the document
     * index, its source (or {@code null-source}), the sub-topic proportions, the
     * separator {@code " , "} and then the super-topic proportions.
     *
     * <p>Proportions are the fraction of the document's tokens assigned to a topic and
     * are listed as {@code index value} pairs in descending order. A list stops at the
     * first topic whose proportion is zero or below the threshold. A document without
     * tokens lists no topics.
     *
     * <p>The writer is flushed but not closed.
     *
     * @param printWriter destination
     * @param d           minimum proportion for a topic to be listed
     * @param i           maximum number of topics per list; a negative value lists all
     *                    sub-topics and all super-topics
     */
    public void printDocumentTopics(PrintWriter printWriter, double d, int i) {
        printWriter.println("#doc source subtopic-proportions , supertopic-proportions");

        // Separate limits, so that "all" means all sub-topics and all super-topics respectively.
        final int maxSubTopics = i < 0 ? this.numSubTopics : i;
        final int maxSuperTopics = i < 0 ? this.numSuperTopics : i;

        double[] superProportions = new double[this.numSuperTopics];
        double[] subProportions = new double[this.numSubTopics];

        for (int doc = 0; doc < this.superTopics.length; doc++) {
            // Start from zero for every document; entries not printed would otherwise leak
            // into the next document's proportions.
            Arrays.fill(superProportions, 0.0);
            Arrays.fill(subProportions, 0.0);

            printWriter.print(doc);
            printWriter.print(' ');
            Object source = this.ilist.get(doc).getSource();
            printWriter.print(source != null ? source.toString() : "null-source");
            printWriter.print(' ');

            final int docLength = this.subTopics[doc].length;
            for (int pos = 0; pos < docLength; pos++) {
                superProportions[this.superTopics[doc][pos]] += 1.0;
                subProportions[this.subTopics[doc][pos]] += 1.0;
            }
            // Skip normalisation for an empty document to avoid 0/0 = NaN.
            if (docLength > 0) {
                for (int superTopic = 0; superTopic < this.numSuperTopics; superTopic++) {
                    superProportions[superTopic] /= docLength;
                }
                for (int subTopic = 0; subTopic < this.numSubTopics; subTopic++) {
                    subProportions[subTopic] /= docLength;
                }
            }

            printTopProportions(printWriter, subProportions, maxSubTopics, d);
            printWriter.print(" , ");
            printTopProportions(printWriter, superProportions, maxSuperTopics, d);
            printWriter.println();
        }
        printWriter.flush();
    }

    /**
     * Prints up to {@code limit} {@code index value} pairs in descending order of value,
     * stopping at the first value that is zero or below {@code threshold}. Printed
     * entries are set to zero in {@code proportions}.
     */
    private static void printTopProportions(PrintWriter out, double[] proportions, int limit, double threshold) {
        for (int rank = 0; rank < limit; rank++) {
            int best = -1;
            double bestValue = 0.0;
            for (int k = 0; k < proportions.length; k++) {
                if (proportions[k] > bestValue) {
                    bestValue = proportions[k];
                    best = k;
                }
            }
            if (best == -1 || proportions[best] < threshold) {
                return;
            }
            out.print(best + " " + proportions[best] + " ");
            proportions[best] = 0.0;
        }
    }

    /**
     * Writes the topic assignment of every token to a file.
     *
     * @param file destination file
     * @throws IOException if the file cannot be opened for writing
     */
    public void printState(File file) throws IOException {
        FileWriter fileWriter = new FileWriter(file);
        BufferedWriter buffered = new BufferedWriter(fileWriter);
        PrintWriter writer = new PrintWriter(buffered);
        // The PrintWriter overload closes the writer when it has finished.
        printState(writer);
    }

    /**
     * Writes a header line and then one space-separated line per token: document
     * index, position, type index, type, super-topic and sub-topic. The writer is
     * closed afterwards.
     *
     * @param printWriter destination; closed by this method
     */
    public void printState(PrintWriter printWriter) {
        final Alphabet alphabet = this.ilist.getDataAlphabet();
        printWriter.println("#doc pos typeindex type super-topic sub-topic");

        final int numDocs = this.superTopics.length;
        for (int doc = 0; doc < numDocs; doc++) {
            final FeatureSequence tokens = (FeatureSequence) this.ilist.get(doc).getData();
            final int docLength = this.superTopics[doc].length;
            for (int pos = 0; pos < docLength; pos++) {
                final int type = tokens.getIndexAtPosition(pos);
                printWriter.println(doc + " " + pos + " " + type + " " + alphabet.lookupObject(type)
                        + " " + this.superTopics[doc][pos] + " " + this.subTopics[doc][pos]);
            }
        }
        printWriter.close();
    }

    /**
     * Command-line entry point: loads an instance list, trains a model and prints the
     * top words of every topic.
     *
     * <p>Arguments, all optional except the first:
     * <ol>
     *   <li>path of the serialized instance list</li>
     *   <li>number of iterations (default 1000)</li>
     *   <li>number of top words to print (default 20)</li>
     *   <li>number of super-topics (default 10)</li>
     *   <li>number of sub-topics (default 10)</li>
     * </ol>
     *
     * @param strArr command-line arguments
     * @throws IOException declared for callers; no checked exception is visible in this method's own code
     */
    public static void main(String[] strArr) throws IOException {
        // Without this check a missing file argument fails with ArrayIndexOutOfBoundsException.
        if (strArr.length < 1) {
            System.err.println("Usage: PAM4L <instance-list-file> [iterations] [top-words] [super-topics] [sub-topics]");
            System.exit(1);
            return;
        }
        InstanceList instances = InstanceList.load(new File(strArr[0]));
        int numIterations = strArr.length > 1 ? Integer.parseInt(strArr[1]) : 1000;
        int numTopWords = strArr.length > 2 ? Integer.parseInt(strArr[2]) : 20;
        int numSuperTopics = strArr.length > 3 ? Integer.parseInt(strArr[3]) : 10;
        int numSubTopics = strArr.length > 4 ? Integer.parseInt(strArr[4]) : 10;
        System.out.println("Data loaded.");
        PAM4L model = new PAM4L(numSuperTopics, numSubTopics);
        // Optimise the alphas every 50 iterations; no periodic top-word display.
        model.estimate(instances, numIterations, 50, 0, 50, null, new Randoms());
        model.printTopWords(numTopWords, true);
    }
}
