package cc.mallet.topics;

import cc.mallet.topics.LDAHyper;
import cc.mallet.types.FeatureCounter;
import cc.mallet.types.FeatureSequence;
import cc.mallet.types.FeatureVector;
import cc.mallet.types.IDSorter;
import cc.mallet.types.Instance;
import cc.mallet.types.InstanceList;
import cc.mallet.types.LabelAlphabet;
import cc.mallet.types.LabelSequence;
import cc.mallet.util.Randoms;
import gnu.trove.TIntIntHashMap;
import java.io.BufferedOutputStream;
import java.io.File;
import java.io.FileOutputStream;
import java.io.FileWriter;
import java.io.IOException;
import java.io.PrintStream;
import java.io.PrintWriter;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Iterator;
import java.util.zip.GZIPOutputStream;
import org.apache.lucene.search.suggest.FileDictionary;

/* loaded from: input_file:cc/mallet/topics/LDAStream.class */
/**
 * Extends {@link LDAHyper} with collapsed Gibbs-sampling inference over the {@code testing}
 * instances, plus writers for the resulting topic assignments and word/topic statistics.
 *
 * <p>All inference methods add the test tokens' topic assignments to the shared
 * {@code typeTopicCounts} and {@code tokensPerTopic} arrays, so they modify the model's
 * word/topic statistics as a side effect.
 */
public class LDAStream extends LDAHyper {

    /** Value stored in a topic sequence for a token that has no topic assignment. */
    private static final int UNASSIGNED = -1;

    /**
     * Column separator used by {@link #printTheta} and {@link #printPhi}. It is a plain tab
     * literal rather than Lucene's {@code FileDictionary.DEFAULT_FIELD_DELIMITER}: that constant
     * has the same value but couples this class to an unrelated library.
     */
    private static final String FIELD_DELIMITER = "\t";

    /** Progress is printed for every iteration (or document) index that is a multiple of this. */
    private static final int PROGRESS_INTERVAL = 100;

    /** Test instances paired with their topic assignments; rebuilt by every inference method. */
    protected ArrayList<LDAHyper.Topication> test;

    /** Forwards {@code i} unchanged to the matching {@link LDAHyper} constructor. */
    public LDAStream(int i) {
        super(i);
    }

    /** Forwards all arguments unchanged to the matching {@link LDAHyper} constructor. */
    public LDAStream(int i, double d, double d2) {
        super(i, d, d2);
    }

    /** Forwards all arguments unchanged to the matching {@link LDAHyper} constructor. */
    public LDAStream(int i, double d, double d2, Randoms randoms) {
        super(i, d, d2, randoms);
    }

    /** Forwards all arguments unchanged to the matching {@link LDAHyper} constructor. */
    public LDAStream(LabelAlphabet labelAlphabet, double d, double d2, Randoms randoms) {
        super(labelAlphabet, d, d2, randoms);
    }

    /**
     * Returns the live list (not a copy) of test documents built by the most recent inference
     * call, or {@code null} if no inference method has run yet.
     */
    public ArrayList<LDAHyper.Topication> getTest() {
        return this.test;
    }

    /**
     * Randomly initializes topics for every {@code testing} document, then resamples all test
     * documents once per sweep. In this implementation it behaves exactly like
     * {@link #inference(int)}.
     *
     * @param maxIteration index of the last sweep; {@code maxIteration + 1} sweeps are run
     *     (iterations 0 through {@code maxIteration} inclusive)
     */
    public void inferenceAll(int maxIteration) {
        sampleAllTestDocuments(maxIteration);
    }

    /**
     * Adds the test documents to the training data, refreshes the cached histograms, and
     * re-runs estimation over everything.
     *
     * @param iterations forwarded to {@code estimate}
     * @throws IllegalStateException if no inference method has been run yet
     * @throws IOException propagated from {@code estimate}
     */
    public void estimateAll(int iterations) throws IOException {
        if (this.test == null) {
            throw new IllegalStateException(
                    "estimateAll requires a prior call to one of the inference methods");
        }
        this.data.addAll(this.test);
        initializeHistogramsAndCachedValues();
        estimate(iterations);
    }

    /**
     * Randomly initializes topics for every {@code testing} document, then resamples all test
     * documents once per sweep.
     *
     * @param maxIteration index of the last sweep; {@code maxIteration + 1} sweeps are run
     */
    public void inference(int maxIteration) {
        sampleAllTestDocuments(maxIteration);
    }

    /**
     * Randomly initializes topics for every {@code testing} document, then fully resamples each
     * document in turn ({@code maxIteration + 1} sweeps per document) before moving to the next.
     *
     * @param maxIteration index of the last per-document sweep
     */
    public void inferenceOneByOne(int maxIteration) {
        initializeTestTopics();
        long startMillis = System.currentTimeMillis();
        for (int docIndex = 0; docIndex < this.test.size(); docIndex++) {
            LDAHyper.Topication doc = this.test.get(docIndex);
            FeatureSequence tokens = (FeatureSequence) doc.instance.getData();
            for (int iteration = 0; iteration <= maxIteration; iteration++) {
                sampleTopicsForOneTestDoc(tokens, doc.topicSequence);
            }
            if (docIndex % PROGRESS_INTERVAL == 0) {
                System.out.println("Docnum: " + docIndex);
            }
        }
        printInferenceTime(startMillis);
    }

    /**
     * Randomly initializes topics for every {@code testing} document, then resamples them using
     * fixed, externally supplied document/topic proportions instead of per-document counts.
     *
     * @param maxIteration index of the last sweep; {@code maxIteration + 1} sweeps are run
     * @param instanceList entry {@code d} holds the proportions for test document {@code d} as a
     *     {@link FeatureVector}; its {@code getValues()} array is read by topic index (NOTE: assumes
     *     a dense vector of length {@code numTopics} — confirm with the producer)
     * @throws IllegalArgumentException if {@code instanceList} has fewer entries than
     *     {@code testing}; checked before any model counts are modified
     */
    public void inferenceWithTheta(int maxIteration, InstanceList instanceList) {
        if (instanceList.size() < this.testing.size()) {
            throw new IllegalArgumentException("theta list has " + instanceList.size()
                    + " entries but there are " + this.testing.size() + " test documents");
        }
        initializeTestTopics();
        // Each document's proportions are fixed across sweeps, so extract them once.
        double[][] thetas = new double[this.test.size()][];
        for (int docIndex = 0; docIndex < thetas.length; docIndex++) {
            thetas[docIndex] = ((FeatureVector) instanceList.get(docIndex).getData()).getValues();
        }
        long startMillis = System.currentTimeMillis();
        for (int iteration = 0; iteration <= maxIteration; iteration++) {
            if (iteration % PROGRESS_INTERVAL == 0) {
                System.out.println("Iteration: " + iteration);
            }
            for (int docIndex = 0; docIndex < thetas.length; docIndex++) {
                LDAHyper.Topication doc = this.test.get(docIndex);
                sampleTopicsForOneDocWithTheta(
                        (FeatureSequence) doc.instance.getData(), doc.topicSequence, thetas[docIndex]);
            }
        }
        printInferenceTime(startMillis);
    }

    /**
     * Writes each document's smoothed topic proportions,
     * {@code (n_dt + alpha[t]) / (n_d + sum(alpha))}, as a name line, one
     * {@code topic<t><TAB><value>} line per topic, and a blank line.
     *
     * <p>Unassigned tokens are skipped and excluded from {@code n_d}. The alpha total is
     * recomputed from {@code alpha} (as {@link #printDocumentTopics} does) rather than trusting a
     * cached value.
     *
     * @param arrayList documents to write
     * @param file destination; overwritten
     * @param d currently ignored
     * @param i currently ignored
     * @throws IOException if the file cannot be opened or written
     */
    public void printTheta(ArrayList<LDAHyper.Topication> arrayList, File file, double d, int i)
            throws IOException {
        double alphaTotal = sumAlpha();
        int[] docTopicCounts = new int[this.numTopics];
        try (PrintWriter out = new PrintWriter(new FileWriter(file))) {
            for (LDAHyper.Topication doc : arrayList) {
                Arrays.fill(docTopicCounts, 0);
                int assigned = 0;
                for (int topic : doc.topicSequence.getFeatures()) {
                    if (topic != UNASSIGNED) {
                        docTopicCounts[topic]++;
                        assigned++;
                    }
                }
                out.println(doc.instance.getName());
                for (int topic = 0; topic < this.numTopics; topic++) {
                    out.println("topic" + topic + FIELD_DELIMITER
                            + ((docTopicCounts[topic] + this.alpha[topic]) / (assigned + alphaTotal)));
                }
                out.println();
            }
            // PrintWriter swallows I/O errors; close now and surface any failure.
            out.close();
            if (out.checkError()) {
                throw new IOException("Failed to write theta to " + file);
            }
        }
    }

    /**
     * Writes, for every topic, a {@code Topic<TAB><t>} header followed by one
     * {@code word<TAB>probability} line per word with a non-zero count, where probability is
     * {@code (count + beta) / (tokensPerTopic[t] + betaSum)}, and then a blank line.
     *
     * @param file destination; overwritten
     * @param d currently ignored
     * @throws IOException if the file cannot be opened or written
     */
    public void printPhi(File file, double d) throws IOException {
        FeatureCounter[] topicWordCounts = new FeatureCounter[this.numTopics];
        for (int topic = 0; topic < this.numTopics; topic++) {
            topicWordCounts[topic] = new FeatureCounter(this.alphabet);
        }
        for (int type = 0; type < this.numTypes; type++) {
            TIntIntHashMap typeCounts = this.typeTopicCounts[type];
            for (int topic : typeCounts.keys()) {
                topicWordCounts[topic].increment(type, typeCounts.get(topic));
            }
        }
        try (PrintWriter out = new PrintWriter(new FileWriter(file))) {
            for (int topic = 0; topic < this.numTopics; topic++) {
                out.println("Topic\t" + topic);
                FeatureVector words = topicWordCounts[topic].toFeatureVector();
                for (int loc = 0; loc < words.numLocations(); loc++) {
                    out.println(((String) this.alphabet.lookupObject(words.indexAtLocation(loc)))
                            + FIELD_DELIMITER
                            + ((((int) words.valueAtLocation(loc)) + this.beta)
                                    / (this.tokensPerTopic[topic] + this.betaSum)));
                }
                out.println();
            }
            // PrintWriter swallows I/O errors; close now and surface any failure.
            out.close();
            if (out.checkError()) {
                throw new IOException("Failed to write phi to " + file);
            }
        }
    }

    /**
     * Writes document/topic proportions for {@code arrayList} to {@code file} using the
     * unbounded defaults of {@link #printDocumentTopics(ArrayList, PrintWriter, double, int)}.
     *
     * @throws IOException if the file cannot be opened or written
     */
    public void printDocumentTopics(ArrayList<LDAHyper.Topication> arrayList, File file)
            throws IOException {
        try (PrintWriter out = new PrintWriter(new FileWriter(file))) {
            printDocumentTopics(arrayList, out); // closes out
            if (out.checkError()) {
                throw new IOException("Failed to write document topics to " + file);
            }
        }
    }

    /**
     * Writes every topic of every document (threshold 0, no topic limit) and then closes
     * {@code printWriter}.
     */
    public void printDocumentTopics(ArrayList<LDAHyper.Topication> arrayList, PrintWriter printWriter) {
        printDocumentTopics(arrayList, printWriter, 0.0d, -1);
    }

    /**
     * Writes one line per document: {@code <index> <source> } followed by
     * {@code <topic> <proportion> } pairs in {@link IDSorter} order (presumably highest proportion
     * first — confirm against {@code IDSorter.compareTo}). A document's list stops at the first
     * proportion below {@code d} or after {@code i} topics. Proportions are
     * {@code (n_dt + alpha[t]) / (docLength + alphaSum)}.
     *
     * <p>Side effects: sets {@code this.alphaSum} to the sum of {@code alpha}, and closes
     * {@code printWriter} when done.
     *
     * @param d minimum proportion to print
     * @param i maximum topics per document; values below 0 or above {@code numTopics} mean all
     */
    public void printDocumentTopics(ArrayList<LDAHyper.Topication> arrayList, PrintWriter printWriter,
            double d, int i) {
        printWriter.print("#doc source topic proportion ...\n");
        int maxTopics = (i < 0 || i > this.numTopics) ? this.numTopics : i;
        // Refresh the cached alpha total once; it does not change between documents.
        this.alphaSum = sumAlpha();

        int[] docTopicCounts = new int[this.numTopics];
        IDSorter[] sorters = new IDSorter[this.numTopics];
        for (int topic = 0; topic < this.numTopics; topic++) {
            sorters[topic] = new IDSorter(topic, topic);
        }

        for (int docIndex = 0; docIndex < arrayList.size(); docIndex++) {
            LDAHyper.Topication doc = arrayList.get(docIndex);
            int[] topics = doc.topicSequence.getFeatures();
            Object source = doc.instance.getSource();

            printWriter.print(docIndex);
            printWriter.print(' ');
            if (source != null) {
                printWriter.print(source);
            } else {
                printWriter.print("null-source");
            }
            printWriter.print(' ');

            int docLength = topics.length;
            int assigned = 0;
            Arrays.fill(docTopicCounts, 0);
            for (int topic : topics) {
                if (topic != UNASSIGNED) {
                    docTopicCounts[topic]++;
                    assigned++;
                }
            }
            assert assigned == docLength
                    : "document " + docIndex + " has " + (docLength - assigned) + " unassigned tokens";

            for (int topic = 0; topic < this.numTopics; topic++) {
                sorters[topic].set(topic,
                        (docTopicCounts[topic] + this.alpha[topic]) / (docLength + this.alphaSum));
            }
            Arrays.sort(sorters);
            for (int rank = 0; rank < maxTopics && sorters[rank].getWeight() >= d; rank++) {
                printWriter.print(sorters[rank].getID() + " " + sorters[rank].getWeight() + " ");
            }
            printWriter.print(" \n");
        }
        printWriter.close();
    }

    /**
     * Writes the per-token sampling state of {@code arrayList} to a gzip-compressed file using
     * {@link #printState(ArrayList, PrintStream)}.
     *
     * @throws IOException if the file cannot be opened or written
     */
    public void printState(ArrayList<LDAHyper.Topication> arrayList, File file) throws IOException {
        // FileOutputStream is its own resource so it is closed even if the gzip header write fails.
        try (FileOutputStream fileOut = new FileOutputStream(file);
                PrintStream out = new PrintStream(
                        new GZIPOutputStream(new BufferedOutputStream(fileOut)))) {
            printState(arrayList, out);
            // Close now so the gzip trailer is written before checking for swallowed errors;
            // the implicit close at the end of the try block is then a no-op.
            out.close();
            if (out.checkError()) {
                throw new IOException("Failed to write sampling state to " + file);
            }
        }
    }

    /**
     * Writes a header and then one line per token:
     * {@code <doc> <source|NA> <pos> <typeindex> <type> <topic>}. Does not close
     * {@code printStream}.
     */
    public void printState(ArrayList<LDAHyper.Topication> arrayList, PrintStream printStream) {
        printStream.println("#doc source pos typeindex type topic");
        for (int docIndex = 0; docIndex < arrayList.size(); docIndex++) {
            LDAHyper.Topication doc = arrayList.get(docIndex);
            FeatureSequence tokens = (FeatureSequence) doc.instance.getData();
            LabelSequence topics = doc.topicSequence;
            Object rawSource = doc.instance.getSource();
            String source = rawSource != null ? rawSource.toString() : "NA";
            for (int pos = 0; pos < topics.getLength(); pos++) {
                int type = tokens.getIndexAtPosition(pos);
                printStream.print(docIndex);
                printStream.print(' ');
                printStream.print(source);
                printStream.print(' ');
                printStream.print(pos);
                printStream.print(' ');
                printStream.print(type);
                printStream.print(' ');
                printStream.print(this.alphabet.lookupObject(type));
                printStream.print(' ');
                printStream.print(topics.getIndexAtPosition(pos));
                printStream.println();
            }
        }
    }

    /** Shared driver for {@link #inference(int)} and {@link #inferenceAll(int)}. */
    private void sampleAllTestDocuments(int maxIteration) {
        initializeTestTopics();
        long startMillis = System.currentTimeMillis();
        for (int iteration = 0; iteration <= maxIteration; iteration++) {
            if (iteration % PROGRESS_INTERVAL == 0) {
                System.out.println("Iteration: " + iteration);
            }
            for (LDAHyper.Topication doc : this.test) {
                sampleTopicsForOneTestDoc((FeatureSequence) doc.instance.getData(), doc.topicSequence);
            }
        }
        printInferenceTime(startMillis);
    }

    /**
     * Rebuilds {@link #test} from {@code testing}. Every token gets a uniformly random topic,
     * which is added to {@code typeTopicCounts} and {@code tokensPerTopic}.
     *
     * <p>Draws come from the model's {@code random} generator rather than from a fresh, unseeded
     * {@code Randoms} per document, so inference is reproducible when the model is seeded.
     */
    private void initializeTestTopics() {
        this.test = new ArrayList<>(this.testing.size());
        for (Instance instance : this.testing) {
            FeatureSequence tokens = (FeatureSequence) instance.getData();
            LabelSequence topicSequence =
                    new LabelSequence(this.topicAlphabet, new int[instanceLength(instance)]);
            int[] topics = topicSequence.getFeatures();
            for (int pos = 0; pos < topics.length; pos++) {
                int topic = this.random.nextInt(this.numTopics);
                topics[pos] = topic;
                addTokenToCounts(this.typeTopicCounts[tokens.getIndexAtPosition(pos)], topic);
            }
            this.test.add(new LDAHyper.Topication(instance, this, topicSequence));
        }
    }

    /**
     * Resamples every assigned token of one document from
     * {@code (n_wt + beta) / (n_t + betaSum) * (n_dt + alpha[t])}, with the token's own
     * assignment removed from all counts first. Unassigned tokens are left untouched.
     */
    private void sampleTopicsForOneTestDoc(FeatureSequence featureSequence, LabelSequence labelSequence) {
        int[] topics = labelSequence.getFeatures();
        int docLength = featureSequence.getLength();

        int[] docTopicCounts = new int[this.numTopics];
        for (int pos = 0; pos < docLength; pos++) {
            if (topics[pos] != UNASSIGNED) {
                docTopicCounts[topics[pos]]++;
            }
        }

        double[] weights = new double[this.numTopics];
        for (int pos = 0; pos < docLength; pos++) {
            int oldTopic = topics[pos];
            if (oldTopic == UNASSIGNED) {
                continue;
            }
            int type = featureSequence.getIndexAtPosition(pos);
            docTopicCounts[oldTopic]--;
            TIntIntHashMap typeCounts = removeTokenFromCounts(type, oldTopic);

            double total = 0.0d;
            for (int topic = 0; topic < this.numTopics; topic++) {
                double weight = ((typeCounts.get(topic) + this.beta)
                        / (this.tokensPerTopic[topic] + this.betaSum))
                        * (docTopicCounts[topic] + this.alpha[topic]);
                total += weight;
                weights[topic] = weight;
            }

            int newTopic = this.random.nextDiscrete(weights, total);
            topics[pos] = newTopic;
            docTopicCounts[newTopic]++;
            addTokenToCounts(typeCounts, newTopic);
        }
    }

    /**
     * Like {@link #sampleTopicsForOneTestDoc}, but the document factor is the fixed {@code theta}
     * (read by topic index) instead of per-document counts plus alpha.
     */
    private void sampleTopicsForOneDocWithTheta(FeatureSequence featureSequence,
            LabelSequence labelSequence, double[] theta) {
        int[] topics = labelSequence.getFeatures();
        int docLength = featureSequence.getLength();
        double[] weights = new double[this.numTopics];
        for (int pos = 0; pos < docLength; pos++) {
            int oldTopic = topics[pos];
            if (oldTopic == UNASSIGNED) {
                continue;
            }
            int type = featureSequence.getIndexAtPosition(pos);
            TIntIntHashMap typeCounts = removeTokenFromCounts(type, oldTopic);

            double total = 0.0d;
            for (int topic = 0; topic < this.numTopics; topic++) {
                double weight = ((typeCounts.get(topic) + this.beta)
                        / (this.tokensPerTopic[topic] + this.betaSum)) * theta[topic];
                total += weight;
                weights[topic] = weight;
            }

            int newTopic = this.random.nextDiscrete(weights, total);
            topics[pos] = newTopic;
            addTokenToCounts(typeCounts, newTopic);
        }
    }

    /**
     * Removes one token of {@code type} from {@code topic} in the model counts. The map entry is
     * dropped when it would reach zero.
     *
     * @return the type's topic-count map, for reuse by the caller
     */
    private TIntIntHashMap removeTokenFromCounts(int type, int topic) {
        TIntIntHashMap typeCounts = this.typeTopicCounts[type];
        // The token being removed is currently counted here, so the count must be positive;
        // otherwise adjustValue would silently no-op while tokensPerTopic is still decremented.
        assert typeCounts.get(topic) > 0 : "type " + type + " has no count for topic " + topic;
        if (typeCounts.get(topic) == 1) {
            typeCounts.remove(topic);
        } else {
            typeCounts.adjustValue(topic, -1);
        }
        this.tokensPerTopic[topic]--;
        return typeCounts;
    }

    /** Adds one token to {@code topic} in the given type's map and in {@code tokensPerTopic}. */
    private void addTokenToCounts(TIntIntHashMap typeCounts, int topic) {
        typeCounts.adjustOrPutValue(topic, 1, 1);
        this.tokensPerTopic[topic]++;
    }

    /** Sums {@code alpha[0..numTopics)} in index order. */
    private double sumAlpha() {
        double total = 0.0d;
        for (int topic = 0; topic < this.numTopics; topic++) {
            total += this.alpha[topic];
        }
        return total;
    }

    /**
     * Prints {@code Total inferencing time: [D days ][H hours ][M minutes ]S seconds} for the time
     * elapsed since {@code startMillis}, rounded to whole seconds. Zero day/hour/minute
     * components are omitted.
     */
    private static void printInferenceTime(long startMillis) {
        long totalSeconds = Math.round((System.currentTimeMillis() - startMillis) / 1000.0d);
        long totalMinutes = totalSeconds / 60;
        long totalHours = totalMinutes / 60;
        long days = totalHours / 24;
        long hours = totalHours % 24;
        long minutes = totalMinutes % 60;
        long seconds = totalSeconds % 60;

        StringBuilder message = new StringBuilder("\nTotal inferencing time: ");
        if (days != 0) {
            message.append(days).append(" days ");
        }
        if (hours != 0) {
            message.append(hours).append(" hours ");
        }
        if (minutes != 0) {
            message.append(minutes).append(" minutes ");
        }
        message.append(seconds).append(" seconds");
        System.out.println(message);
    }
}
