/*
 * Decompiled with CFR 0.152.
 */
package cc.mallet.classify;

import cc.mallet.classify.Classification;
import cc.mallet.classify.Classifier;
import cc.mallet.pipe.Pipe;
import cc.mallet.types.Alphabet;
import cc.mallet.types.FeatureVector;
import cc.mallet.types.Instance;
import cc.mallet.types.InstanceList;
import cc.mallet.types.LabelVector;
import cc.mallet.types.Labeling;
import cc.mallet.types.Multinomial;
import cc.mallet.types.RankedFeatureVector;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.Serializable;
import java.util.Arrays;

/**
 * A multinomial Naive Bayes classifier.
 *
 * <p>The model consists of a log-space prior over class indices ({@link #prior}) and, for each
 * class index, a log-space multinomial over feature indices ({@link #p}). Scoring an instance
 * adds the log prior to the value-weighted sum of per-feature log probabilities and then
 * normalizes the per-class scores into a probability distribution.
 */
public class NaiveBayes extends Classifier implements Serializable {
    /** Log-space prior over class indices. */
    Multinomial.Logged prior;
    /** Log-space feature distributions, one per class, indexed by class index. */
    Multinomial.Logged[] p;
    private static final long serialVersionUID = 1L;
    private static final int CURRENT_SERIAL_VERSION = 1;

    /**
     * Creates a classifier from already-logged distributions.
     *
     * @param instancePipe the pipe that produced the training data
     * @param prior log-space prior over class indices
     * @param classIndex2FeatureProb log-space feature distribution for each class index; stored
     *     by reference, not copied
     */
    public NaiveBayes(Pipe instancePipe, Multinomial.Logged prior, Multinomial.Logged[] classIndex2FeatureProb) {
        super(instancePipe);
        this.prior = prior;
        this.p = classIndex2FeatureProb;
    }

    /** Wraps each multinomial in a {@link Multinomial.Logged}, preserving array order. */
    private static Multinomial.Logged[] logMultinomials(Multinomial[] m) {
        Multinomial.Logged[] ml = new Multinomial.Logged[m.length];
        for (int i = 0; i < m.length; i++) {
            ml[i] = new Multinomial.Logged(m[i]);
        }
        return ml;
    }

    /**
     * Creates a classifier from plain (non-logged) distributions, converting each to log space.
     *
     * @param dataPipe the pipe that produced the training data
     * @param prior prior over class indices
     * @param classIndex2FeatureProb feature distribution for each class index
     */
    public NaiveBayes(Pipe dataPipe, Multinomial prior, Multinomial[] classIndex2FeatureProb) {
        this(dataPipe, new Multinomial.Logged(prior), NaiveBayes.logMultinomials(classIndex2FeatureProb));
    }

    /**
     * Returns the per-class feature distributions.
     *
     * @return the internal array itself (not a copy); changes to it affect this classifier
     */
    public Multinomial.Logged[] getMultinomials() {
        return this.p;
    }

    /** Returns the log-space prior over class indices. */
    public Multinomial.Logged getPriors() {
        return this.prior;
    }

    /**
     * Prints, for each modeled class, its highest-probability features and their probabilities
     * to standard output.
     *
     * <p>Requires {@code instancePipe} to be non-null, since both alphabets are taken from it.
     *
     * @param numToPrint maximum number of features to print per class; clamped to the data
     *     alphabet size
     */
    public void printWords(int numToPrint) {
        Alphabet alphabet = this.instancePipe.getDataAlphabet();
        Alphabet targetAlphabet = this.instancePipe.getTargetAlphabet();
        int numFeatures = alphabet.size();
        // Labels added to the target alphabet after training have no entry in p; skip them,
        // consistent with how classify() treats such classes.
        int numLabels = Math.min(targetAlphabet.size(), this.p.length);
        int limit = Math.min(numToPrint, numFeatures);
        double[] probs = new double[numFeatures];
        for (int li = 0; li < numLabels; li++) {
            // addProbabilities accumulates, so the buffer must be cleared for each class.
            Arrays.fill(probs, 0.0);
            this.p[li].addProbabilities(probs);
            RankedFeatureVector rfv = new RankedFeatureVector(alphabet, probs);
            System.out.println("\nFeature probabilities " + targetAlphabet.lookupObject(li));
            for (int rank = 0; rank < limit; rank++) {
                System.out.println(rfv.getObjectAtRank(rank) + " " + rfv.getValueAtRank(rank));
            }
        }
    }

    /**
     * Classifies an instance whose data is a {@link FeatureVector}.
     *
     * <p>For each class the score is {@code log P(class) + sum_f value(f) * log P(f | class)}.
     * Features whose index lies outside a class's distribution, and classes with no entry in
     * {@link #p}, contribute no feature terms. The scores are then normalized to sum to 1.
     *
     * @param instance instance whose data is a {@code FeatureVector}
     * @return the classification with a normalized {@link LabelVector} over the label alphabet
     */
    @Override
    public Classification classify(Instance instance) {
        int numClasses = this.getLabelAlphabet().size();
        double[] scores = new double[numClasses];
        FeatureVector fv = (FeatureVector) instance.getData();
        assert (this.instancePipe == null || fv.getAlphabet() == this.instancePipe.getDataAlphabet());

        this.prior.addLogProbabilities(scores);

        // Classes with index >= p.length have no feature model; they keep only the prior score.
        int numModeledClasses = Math.min(numClasses, this.p.length);
        int fvisize = fv.numLocations();
        for (int fvi = 0; fvi < fvisize; fvi++) {
            int fi = fv.indexAtLocation(fvi);
            double value = fv.valueAtLocation(fvi);
            for (int ci = 0; ci < numModeledClasses; ci++) {
                // Features unseen when a class distribution was built are ignored.
                if (fi < this.p[ci].size()) {
                    scores[ci] += value * this.p[ci].logProbability(fi);
                }
            }
        }

        normalizeLogScores(scores);
        return new Classification(instance, this, new LabelVector(this.getLabelAlphabet(), scores));
    }

    /**
     * Converts log-space scores in place into probabilities that sum to 1 (a softmax).
     *
     * <p>The maximum score is subtracted before exponentiating so that large magnitudes do not
     * overflow or underflow. If no score exceeds {@code -Infinity}, that subtraction would yield
     * NaN for every entry, so a uniform distribution is produced instead.
     */
    private static void normalizeLogScores(double[] scores) {
        double maxScore = Double.NEGATIVE_INFINITY;
        for (double score : scores) {
            if (score > maxScore) {
                maxScore = score;
            }
        }
        if (maxScore == Double.NEGATIVE_INFINITY) {
            Arrays.fill(scores, 1.0 / scores.length);
            return;
        }
        double sum = 0.0;
        for (int ci = 0; ci < scores.length; ci++) {
            scores[ci] = Math.exp(scores[ci] - maxScore);
            sum += scores[ci];
        }
        for (int ci = 0; ci < scores.length; ci++) {
            scores[ci] /= sum;
        }
    }

    /**
     * Computes {@code sum_f value(f) * log P(f | label)} for one instance and label.
     *
     * <p>Applies the same bounds rules as {@link #classify}. A label with no entry in {@link #p}
     * contributes 0, and a feature index outside the label's distribution is skipped.
     */
    private double dataLogProbability(Instance instance, int labelIndex) {
        FeatureVector fv = (FeatureVector) instance.getData();
        if (labelIndex >= this.p.length) {
            return 0.0;
        }
        Multinomial.Logged featureDist = this.p[labelIndex];
        int fvisize = fv.numLocations();
        double logProb = 0.0;
        for (int fvi = 0; fvi < fvisize; fvi++) {
            int fi = fv.indexAtLocation(fvi);
            if (fi < featureDist.size()) {
                logProb += fv.valueAtLocation(fvi) * featureDist.logProbability(fi);
            }
        }
        return logProb;
    }

    /**
     * Computes the weighted log-likelihood of the instances' feature data.
     *
     * <p>Labeled instances are scored under their best label. Unlabeled instances are scored
     * under every label this classifier predicts with non-zero weight, each term weighted by
     * that prediction. Every term is also multiplied by the instance weight.
     *
     * @param ilist instances to score
     * @return the summed weighted data log-likelihood
     */
    public double dataLogLikelihood(InstanceList ilist) {
        double logLikelihood = 0.0;
        for (int ii = 0; ii < ilist.size(); ii++) {
            double instanceWeight = ilist.getInstanceWeight(ii);
            Instance inst = (Instance) ilist.get(ii);
            Labeling labeling = inst.getLabeling();
            if (labeling != null) {
                logLikelihood += instanceWeight * this.dataLogProbability(inst, labeling.getBestIndex());
                continue;
            }
            Labeling predicted = this.classify(inst).getLabeling();
            for (int lpos = 0; lpos < predicted.numLocations(); lpos++) {
                int li = predicted.indexAtLocation(lpos);
                double labelWeight = predicted.valueAtLocation(lpos);
                if (labelWeight != 0.0) {
                    logLikelihood += instanceWeight * labelWeight * this.dataLogProbability(inst, li);
                }
            }
        }
        return logLikelihood;
    }

    /**
     * Computes the weighted log-likelihood of the instances' labels under this classifier.
     *
     * <p>Unlabeled instances are skipped. For a single-location labeling the log of the
     * predicted probability of its best label is used. Otherwise each non-zero label contributes
     * its weight times the log of its predicted probability. Every term is also multiplied by
     * the instance weight.
     *
     * @param ilist instances to score
     * @return the summed weighted label log-likelihood; may be {@code -Infinity} if a true label
     *     has predicted probability 0
     */
    public double labelLogLikelihood(InstanceList ilist) {
        double logLikelihood = 0.0;
        for (int ii = 0; ii < ilist.size(); ii++) {
            double instanceWeight = ilist.getInstanceWeight(ii);
            Instance inst = (Instance) ilist.get(ii);
            Labeling labeling = inst.getLabeling();
            if (labeling == null) {
                continue;
            }
            Labeling predicted = this.classify(inst).getLabeling();
            if (labeling.numLocations() == 1) {
                logLikelihood += instanceWeight * Math.log(predicted.value(labeling.getBestIndex()));
                continue;
            }
            for (int lpos = 0; lpos < labeling.numLocations(); lpos++) {
                int li = labeling.indexAtLocation(lpos);
                double labelWeight = labeling.valueAtLocation(lpos);
                if (labelWeight != 0.0) {
                    logLikelihood += instanceWeight * labelWeight * Math.log(predicted.value(li));
                }
            }
        }
        return logLikelihood;
    }

    /** Writes the serial version, then the instance pipe, the prior and the per-class distributions. */
    private void writeObject(ObjectOutputStream out) throws IOException {
        out.writeInt(CURRENT_SERIAL_VERSION);
        out.writeObject(this.getInstancePipe());
        out.writeObject(this.prior);
        out.writeObject(this.p);
    }

    /**
     * Reads the fields in the order written by {@link #writeObject}.
     *
     * @throws ClassNotFoundException if the stored serial version does not match the current one
     */
    private void readObject(ObjectInputStream in) throws IOException, ClassNotFoundException {
        int version = in.readInt();
        if (version != CURRENT_SERIAL_VERSION) {
            throw new ClassNotFoundException("Mismatched NaiveBayes versions: wanted " + CURRENT_SERIAL_VERSION + ", got " + version);
        }
        this.instancePipe = (Pipe) in.readObject();
        this.prior = (Multinomial.Logged) in.readObject();
        this.p = (Multinomial.Logged[]) in.readObject();
    }
}

