/*
 * Decompiled with CFR 0.152.
 */
package de.jungblut.classification.bayes;

import de.jungblut.classification.AbstractClassifier;
import de.jungblut.datastructure.Iterables;
import de.jungblut.math.DoubleMatrix;
import de.jungblut.math.DoubleVector;
import de.jungblut.math.dense.DenseDoubleVector;
import de.jungblut.math.sparse.SparseDoubleRowMatrix;
import de.jungblut.math.tuple.Tuple;
import de.jungblut.writable.MatrixWritable;
import de.jungblut.writable.VectorWritable;
import java.io.DataInput;
import java.io.DataOutput;
import java.io.IOException;
import java.util.Iterator;
import org.apache.commons.math3.util.FastMath;

public final class MultinomialNaiveBayes
extends AbstractClassifier {
    private static final double LOW_PROBABILITY = FastMath.log((double)1.0E-8);
    private DoubleMatrix probabilityMatrix;
    private DoubleVector classPriorProbability;
    private boolean verbose;

    public MultinomialNaiveBayes() {
    }

    public MultinomialNaiveBayes(boolean verbose) {
        this.verbose = verbose;
    }

    private MultinomialNaiveBayes(DoubleMatrix probabilityMatrix, DoubleVector classProbability) {
        this.probabilityMatrix = probabilityMatrix;
        this.classPriorProbability = classProbability;
    }

    @Override
    public void train(Iterable<DoubleVector> features, Iterable<DoubleVector> outcome) {
        Iterator<DoubleVector> outcomeIterator;
        Iterator<DoubleVector> featureIterator = features.iterator();
        Tuple<DoubleVector, DoubleVector> first = Iterables.consumeNext(featureIterator, outcomeIterator = outcome.iterator());
        int numDistinctClasses = ((DoubleVector)first.getSecond()).getDimension();
        numDistinctClasses = numDistinctClasses == 1 ? 2 : numDistinctClasses;
        this.probabilityMatrix = new SparseDoubleRowMatrix(numDistinctClasses, ((DoubleVector)first.getFirst()).getDimension());
        int[] tokenPerClass = new int[numDistinctClasses];
        int[] numDocumentsPerClass = new int[numDistinctClasses];
        this.observe((DoubleVector)first.getFirst(), (DoubleVector)first.getSecond(), numDistinctClasses, tokenPerClass, numDocumentsPerClass);
        int numDocumentsSeen = 1;
        while ((first = Iterables.consumeNext(featureIterator, outcomeIterator)) != null) {
            this.observe((DoubleVector)first.getFirst(), (DoubleVector)first.getSecond(), numDistinctClasses, tokenPerClass, numDocumentsPerClass);
            ++numDocumentsSeen;
        }
        for (int row = 0; row < numDistinctClasses; ++row) {
            DoubleVector rowVector = this.probabilityMatrix.getRowVector(row);
            Iterator iterateNonZero = rowVector.iterateNonZero();
            double normalizer = FastMath.log((double)(tokenPerClass[row] + this.probabilityMatrix.getColumnCount() - 1));
            while (iterateNonZero.hasNext()) {
                DoubleVector.DoubleVectorElement next = (DoubleVector.DoubleVectorElement)iterateNonZero.next();
                double currentWordCount = next.getValue();
                double logProbability = FastMath.log((double)currentWordCount) - normalizer;
                this.probabilityMatrix.set(row, next.getIndex(), logProbability);
            }
            if (!this.verbose) continue;
            System.out.println("Computed " + row + " / " + numDistinctClasses + "!");
        }
        this.classPriorProbability = new DenseDoubleVector(numDistinctClasses);
        for (int i = 0; i < numDistinctClasses; ++i) {
            double prior = FastMath.log((double)numDocumentsPerClass[i]) - FastMath.log((double)numDocumentsSeen);
            this.classPriorProbability.set(i, prior);
        }
    }

    private void observe(DoubleVector document, DoubleVector outcome, int numDistinctClasses, int[] tokenPerClass, int[] numDocumentsPerClass) {
        int predictedClass = outcome.maxIndex();
        if (numDistinctClasses == 2) {
            predictedClass = (int)outcome.get(0);
        }
        int n = predictedClass;
        tokenPerClass[n] = tokenPerClass[n] + document.getLength();
        int n2 = predictedClass;
        numDocumentsPerClass[n2] = numDocumentsPerClass[n2] + 1;
        Iterator iterateNonZero = document.iterateNonZero();
        while (iterateNonZero.hasNext()) {
            DoubleVector.DoubleVectorElement next = (DoubleVector.DoubleVectorElement)iterateNonZero.next();
            double currentCount = this.probabilityMatrix.get(predictedClass, next.getIndex());
            this.probabilityMatrix.set(predictedClass, next.getIndex(), currentCount + next.getValue());
        }
    }

    @Override
    public DoubleVector predict(DoubleVector features) {
        return this.getProbabilityDistribution(features);
    }

    private double getProbabilityForClass(DoubleVector document, int classIndex) {
        double probabilitySum = 0.0;
        Iterator iterateNonZero = document.iterateNonZero();
        while (iterateNonZero.hasNext()) {
            DoubleVector.DoubleVectorElement next = (DoubleVector.DoubleVectorElement)iterateNonZero.next();
            double wordCount = next.getValue();
            double probabilityOfToken = this.probabilityMatrix.get(classIndex, next.getIndex());
            if (probabilityOfToken == 0.0) {
                probabilityOfToken = LOW_PROBABILITY;
            }
            probabilitySum += wordCount * probabilityOfToken;
        }
        return probabilitySum;
    }

    private DenseDoubleVector getProbabilityDistribution(DoubleVector document) {
        int numClasses = this.classPriorProbability.getLength();
        DenseDoubleVector distribution = new DenseDoubleVector(numClasses);
        for (int i = 0; i < numClasses; ++i) {
            double probability = this.getProbabilityForClass(document, i);
            distribution.set(i, probability);
        }
        double maxProbability = distribution.max();
        double probabilitySum = 0.0;
        for (int i = 0; i < numClasses; ++i) {
            double probability = distribution.get(i);
            double normalizedProbability = FastMath.exp((double)(probability - maxProbability + this.classPriorProbability.get(i)));
            distribution.set(i, normalizedProbability);
            probabilitySum += normalizedProbability;
        }
        distribution = (DenseDoubleVector)distribution.divide(probabilitySum);
        return distribution;
    }

    DoubleVector getClassProbability() {
        return this.classPriorProbability;
    }

    DoubleMatrix getProbabilityMatrix() {
        return this.probabilityMatrix;
    }

    public static MultinomialNaiveBayes deserialize(DataInput in) throws IOException {
        MatrixWritable matrixWritable = new MatrixWritable();
        matrixWritable.readFields(in);
        DoubleVector classProbability = VectorWritable.readVector(in);
        return new MultinomialNaiveBayes(matrixWritable.getMatrix(), classProbability);
    }

    public static void serialize(MultinomialNaiveBayes model, DataOutput out) throws IOException {
        new MatrixWritable(model.probabilityMatrix).write(out);
        VectorWritable.writeVector(model.classPriorProbability, out);
    }
}

