/*
 * Decompiled with CFR 0.152.
 */
package cc.mallet.topics;

import cc.mallet.classify.MaxEnt;
import cc.mallet.optimize.LimitedMemoryBFGS;
import cc.mallet.optimize.OptimizationException;
import cc.mallet.pipe.Noop;
import cc.mallet.pipe.Pipe;
import cc.mallet.topics.DMROptimizable;
import cc.mallet.topics.LDAHyper;
import cc.mallet.types.FeatureCounter;
import cc.mallet.types.FeatureSequence;
import cc.mallet.types.FeatureVector;
import cc.mallet.types.Instance;
import cc.mallet.types.InstanceList;
import cc.mallet.types.LabelSequence;
import cc.mallet.types.MatrixOps;
import java.io.File;
import java.io.IOException;
import java.io.PrintStream;

/**
 * Dirichlet-multinomial regression (DMR) topic model.
 *
 * <p>Extends {@link LDAHyper} so that the Dirichlet prior over topics is computed per document
 * from the features stored as that document's instance <em>target</em>:
 * {@code alpha[t] = exp(w[t][intercept] + (w[t] . x))}, where {@code x} is the document's
 * target {@link FeatureVector}. The weights {@code w} are held by a {@link MaxEnt} classifier
 * whose flat parameter array is laid out one row per topic with {@link #numFeatures} columns.
 * The last column ({@link #defaultFeatureIndex}) is the intercept. It is the only term used for
 * documents without a target.
 *
 * <p>The weights are re-fit by {@link #learnParameters()} every {@code optimizeInterval}
 * iterations once the burn-in period has passed.
 */
public class DMRTopicModel extends LDAHyper {
    /** Regression weights; {@code null} until the first call to {@link #learnParameters()}. */
    MaxEnt dmrParameters = null;
    /** Number of target features plus one extra column for the intercept; set by estimate(int). */
    int numFeatures;
    /** Column index of the intercept weight, always {@code numFeatures - 1}. */
    int defaultFeatureIndex;
    /** Pipe describing the feature-to-topic regression problem; created lazily. */
    Pipe parameterPipe = null;
    /**
     * Per-document alphas recorded at the end of each {@link #learnParameters()} call.
     * NOTE(review): written but not read anywhere in this class.
     */
    double[][] alphaCache;
    /** Per-document sum of the corresponding {@link #alphaCache} row. */
    double[] alphaSumCache;
    private static final long serialVersionUID = 1L;
    private static final int CURRENT_SERIAL_VERSION = 0;
    private static final int NULL_INTEGER = -1;

    /**
     * Creates a DMR model.
     *
     * @param numberOfTopics number of topics, passed through to {@link LDAHyper}
     */
    public DMRTopicModel(int numberOfTopics) {
        super(numberOfTopics);
    }

    /**
     * Runs Gibbs-sampling sweeps over the training data and periodically re-fits the DMR weights.
     *
     * <p>The loop bound is inclusive, so each call performs {@code iterationsThisRound + 1}
     * sweeps. This is kept unchanged for compatibility with existing runs.
     *
     * @param iterationsThisRound number of additional iterations to run
     * @throws IOException if a periodic state snapshot cannot be written
     * @throws IllegalStateException if no instances were added, or they carry no target alphabet
     */
    @Override
    public void estimate(int iterationsThisRound) throws IOException {
        int numDocs = this.data.size();
        if (numDocs == 0) {
            throw new IllegalStateException(
                    "No training instances; call addInstances() before estimate()");
        }
        Instance first = ((LDAHyper.Topication) this.data.get(0)).instance;
        if (first.getTargetAlphabet() == null) {
            throw new IllegalStateException(
                    "Training instances have no target alphabet; DMR needs document features as targets");
        }
        // One column per target feature, plus a trailing intercept column.
        this.numFeatures = first.getTargetAlphabet().size() + 1;
        this.defaultFeatureIndex = this.numFeatures - 1;
        this.alphaCache = new double[numDocs][this.numTopics];
        this.alphaSumCache = new double[numDocs];

        long startTime = System.currentTimeMillis();
        int maxIteration = this.iterationsSoFar + iterationsThisRound;
        for (; this.iterationsSoFar <= maxIteration; ++this.iterationsSoFar) {
            long iterationStart = System.currentTimeMillis();
            if (this.showTopicsInterval != 0 && this.iterationsSoFar != 0
                    && this.iterationsSoFar % this.showTopicsInterval == 0) {
                System.out.println();
                this.printTopWords(System.out, this.wordsPerTopic, false);
            }
            if (this.saveStateInterval != 0 && this.iterationsSoFar % this.saveStateInterval == 0) {
                this.printState(new File(String.valueOf(this.stateFilename) + '.' + this.iterationsSoFar + ".gz"));
            }
            // Re-fit the regression only after burn-in, every optimizeInterval iterations.
            if (this.iterationsSoFar > this.burninPeriod && this.optimizeInterval != 0
                    && this.iterationsSoFar % this.optimizeInterval == 0) {
                this.learnParameters();
            }
            for (int doc = 0; doc < numDocs; ++doc) {
                LDAHyper.Topication topication = (LDAHyper.Topication) this.data.get(doc);
                FeatureSequence tokenSequence = (FeatureSequence) topication.instance.getData();
                LabelSequence topicSequence = topication.topicSequence;
                if (this.dmrParameters != null) {
                    // Once weights exist, each document is sampled under its own alphas.
                    this.setAlphas(topication.instance);
                }
                this.sampleTopicsForOneDoc(tokenSequence, topicSequence, false, false);
            }
            long ms = System.currentTimeMillis() - iterationStart;
            if (ms > 1000L) {
                // Divide in floating point so Math.round really rounds (integer division truncated).
                System.out.print(String.valueOf(Math.round(ms / 1000.0)) + "s ");
            } else {
                System.out.print(String.valueOf(ms) + "ms ");
            }
            if (this.iterationsSoFar % 10 == 0) {
                System.out.println("<" + this.iterationsSoFar + "> ");
                if (this.printLogLikelihood) {
                    System.out.println(this.modelLogLikelihood());
                }
            }
            System.out.flush();
        }
        printTotalTime(startTime);
    }

    /**
     * Prints the wall-clock time elapsed since {@code startTime} as days/hours/minutes/seconds.
     * Zero-valued days, hours and minutes are omitted.
     */
    private static void printTotalTime(long startTime) {
        long seconds = Math.round((double) (System.currentTimeMillis() - startTime) / 1000.0);
        long minutes = seconds / 60L;
        seconds %= 60L;
        long hours = minutes / 60L;
        minutes %= 60L;
        long days = hours / 24L;
        hours %= 24L;
        System.out.print("\nTotal time: ");
        if (days != 0L) {
            System.out.print(days);
            System.out.print(" days ");
        }
        if (hours != 0L) {
            System.out.print(hours);
            System.out.print(" hours ");
        }
        if (minutes != 0L) {
            System.out.print(minutes);
            System.out.print(" minutes ");
        }
        System.out.print(seconds);
        System.out.println(" seconds");
    }

    /**
     * Adds the freshly computed {@code alpha[topic]} into the sampler's cached totals.
     * It updates {@code alphaSum}, {@code smoothingOnlyMass} and {@code cachedCoefficients[topic]}.
     * Callers reset {@code alphaSum} and {@code smoothingOnlyMass} to zero before topic 0.
     */
    private void accumulateAlpha(int topic) {
        double denominator = (double) this.tokensPerTopic[topic] + this.betaSum;
        this.alphaSum += this.alpha[topic];
        this.smoothingOnlyMass += this.alpha[topic] * this.beta / denominator;
        this.cachedCoefficients[topic] = this.alpha[topic] / denominator;
    }

    /**
     * Sets the alphas from the intercept weights alone: {@code alpha[t] = exp(w[t][intercept])}.
     * Requires {@link #dmrParameters} to be non-null.
     */
    public void setAlphas() {
        double[] parameters = this.dmrParameters.getParameters();
        this.alphaSum = 0.0;
        this.smoothingOnlyMass = 0.0;
        for (int topic = 0; topic < this.numTopics; ++topic) {
            this.alpha[topic] = Math.exp(parameters[topic * this.numFeatures + this.defaultFeatureIndex]);
            this.accumulateAlpha(topic);
        }
    }

    /**
     * Sets the alphas for a document whose only feature is {@code featureIndex} with value 1.
     * The result is {@code alpha[t] = exp(w[t][featureIndex] + w[t][intercept])}.
     *
     * @param featureIndex column of the single active feature
     */
    public void setAlphas(int featureIndex) {
        double[] parameters = this.dmrParameters.getParameters();
        this.alphaSum = 0.0;
        this.smoothingOnlyMass = 0.0;
        for (int topic = 0; topic < this.numTopics; ++topic) {
            int row = topic * this.numFeatures;
            this.alpha[topic] = Math.exp(parameters[row + featureIndex] + parameters[row + this.defaultFeatureIndex]);
            this.accumulateAlpha(topic);
        }
    }

    /**
     * Sets the alphas for one document from its target feature vector.
     * The result is {@code alpha[t] = exp(w[t][intercept] + (w[t] . x))}.
     * A document with a {@code null} target falls back to {@link #setAlphas()}.
     *
     * @param instance document whose target is expected to be a {@link FeatureVector} or null
     */
    public void setAlphas(Instance instance) {
        FeatureVector features = (FeatureVector) instance.getTarget();
        if (features == null) {
            this.setAlphas();
            return;
        }
        double[] parameters = this.dmrParameters.getParameters();
        this.alphaSum = 0.0;
        this.smoothingOnlyMass = 0.0;
        for (int topic = 0; topic < this.numTopics; ++topic) {
            double logAlpha = parameters[topic * this.numFeatures + this.defaultFeatureIndex]
                    + MatrixOps.rowDotProduct(parameters, this.numFeatures, topic, features, this.defaultFeatureIndex, null);
            this.alpha[topic] = Math.exp(logAlpha);
            this.accumulateAlpha(topic);
        }
    }

    /**
     * Re-fits the DMR weights to the current topic assignments, then records per-document alphas.
     *
     * <p>Each document with a target becomes one training example. Its features are the
     * document's target, and its "label" is the vector of topic counts from its current
     * assignments. Documents without a target are skipped. The optimizer uses Gaussian priors
     * with variance 0.5 on regular weights and 100.0 on intercepts. Relies on
     * {@link #numFeatures} and the alpha caches having been initialised by estimate(int).
     */
    public void learnParameters() {
        if (this.parameterPipe == null) {
            this.parameterPipe = new Noop();
            this.parameterPipe.setDataAlphabet(((LDAHyper.Topication) this.data.get(0)).instance.getTargetAlphabet());
            this.parameterPipe.setTargetAlphabet(this.topicAlphabet);
        }
        InstanceList parameterInstances = new InstanceList(this.parameterPipe);
        if (this.dmrParameters == null) {
            this.dmrParameters = new MaxEnt(this.parameterPipe, new double[this.numFeatures * this.numTopics]);
        }
        for (int doc = 0; doc < this.data.size(); ++doc) {
            LDAHyper.Topication topication = (LDAHyper.Topication) this.data.get(doc);
            if (topication.instance.getTarget() == null) {
                continue;
            }
            FeatureCounter counter = new FeatureCounter(this.topicAlphabet);
            for (int topic : topication.topicSequence.getFeatures()) {
                counter.increment(topic);
            }
            parameterInstances.add(new Instance(topication.instance.getTarget(), counter.toFeatureVector(), null, null));
        }

        DMROptimizable optimizable = new DMROptimizable(parameterInstances, this.dmrParameters);
        optimizable.setRegularGaussianPriorVariance(0.5);
        optimizable.setInterceptGaussianPriorVariance(100.0);
        LimitedMemoryBFGS optimizer = new LimitedMemoryBFGS(optimizable);
        // Best effort: the optimizer may stop with an OptimizationException without converging.
        // Whatever weights it reached are still used below. The second run presumably restarts
        // L-BFGS from where the first one stopped.
        for (int attempt = 0; attempt < 2; ++attempt) {
            try {
                optimizer.optimize();
            } catch (OptimizationException ignored) {
                // Deliberately ignored; see comment above.
            }
        }
        this.dmrParameters = optimizable.getClassifier();

        for (int doc = 0; doc < this.data.size(); ++doc) {
            Instance instance = ((LDAHyper.Topication) this.data.get(doc)).instance;
            if (instance.getTarget() == null) {
                continue;
            }
            this.setAlphas(instance);
            System.arraycopy(this.alpha, 0, this.alphaCache[doc], 0, this.numTopics);
            this.alphaSumCache[doc] = this.alphaSum;
        }
    }

    /**
     * Prints the top words of each topic. Once weights exist, the alphas are first reset to the
     * intercept-only values so that the printout does not reflect whichever document was last
     * sampled.
     */
    @Override
    public void printTopWords(PrintStream out, int numWords, boolean usingNewLines) {
        if (this.dmrParameters != null) {
            this.setAlphas();
        }
        super.printTopWords(out, numWords, usingNewLines);
    }

    /**
     * Writes the learned regression weights to {@code parameterFile} in MaxEnt's text format.
     * Does nothing if no weights have been learned yet.
     *
     * @param parameterFile destination file (created or truncated)
     * @throws IOException if the file cannot be opened or any write fails
     */
    public void writeParameters(File parameterFile) throws IOException {
        if (this.dmrParameters == null) {
            return;
        }
        PrintStream out = new PrintStream(parameterFile);
        try {
            this.dmrParameters.print(out);
        } finally {
            out.close();
        }
        // PrintStream never throws on write; surface any recorded failure (including on close).
        if (out.checkError()) {
            throw new IOException("Failed to write DMR parameters to " + parameterFile);
        }
    }

    /**
     * Command-line entry point.
     *
     * <p>Arguments are: training instance file, optional number of topics (default 200), and an
     * optional testing instance file. Writes {@code dmr.parameters} and {@code dmr.state.gz} to
     * the working directory.
     */
    public static void main(String[] args) throws IOException {
        if (args.length < 1) {
            throw new IllegalArgumentException(
                    "Usage: DMRTopicModel <training instances> [num topics] [testing instances]");
        }
        InstanceList training = InstanceList.load(new File(args[0]));
        int numTopics = args.length > 1 ? Integer.parseInt(args[1]) : 200;
        // NOTE(review): the testing list is loaded (so an unreadable file still fails here) but is
        // never used; held-out evaluation appears unimplemented.
        InstanceList testing = args.length > 2 ? InstanceList.load(new File(args[2])) : null;
        DMRTopicModel lda = new DMRTopicModel(numTopics);
        lda.setOptimizeInterval(100);
        lda.setTopicDisplay(100, 10);
        lda.addInstances(training);
        lda.estimate();
        lda.writeParameters(new File("dmr.parameters"));
        lda.printState(new File("dmr.state.gz"));
    }
}

