package org.apache.mahout.classifier.bayes.algorithm;

import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.PriorityQueue;
import org.apache.commons.lang.mutable.MutableDouble;
import org.apache.mahout.classifier.ClassifierResult;
import org.apache.mahout.classifier.bayes.common.ByScoreLabelResultComparator;
import org.apache.mahout.classifier.bayes.exceptions.InvalidDatastoreException;
import org.apache.mahout.classifier.bayes.interfaces.Algorithm;
import org.apache.mahout.classifier.bayes.interfaces.Datastore;
import org.apache.mahout.math.function.ObjectIntProcedure;
import org.apache.mahout.math.map.OpenObjectIntHashMap;

/* loaded from: input_file:org/apache/mahout/classifier/bayes/algorithm/CBayesAlgorithm.class */
/**
 * Classification {@link Algorithm} that scores a tokenized document against every label using
 * weights read from a {@link Datastore}.
 *
 * <p>NOTE(review): the class name suggests Complementary Naive Bayes; the per-term weighting
 * formula lives in {@link #featureWeight(Datastore, String, String)}.
 */
public class CBayesAlgorithm implements Algorithm {

    /**
     * Picks the single label whose document weight is highest and strictly positive.
     *
     * @param document        tokens of the document to classify; repeated tokens count repeatedly
     * @param datastore       source of the model weights
     * @param defaultCategory label reported when no label scores above zero
     * @return the winning label with its score, or {@code defaultCategory} with score {@code 0.0}
     *     when no label has a positive weight
     * @throws InvalidDatastoreException if the datastore cannot list the labels
     * @throws IllegalStateException wrapping an {@link InvalidDatastoreException} raised while
     *     looking up per-term weights
     */
    @Override
    public ClassifierResult classifyDocument(String[] document, Datastore datastore, String defaultCategory)
            throws InvalidDatastoreException {
        ClassifierResult result = new ClassifierResult(defaultCategory);
        // Seeded with 0.0 rather than Double.MIN_VALUE: MIN_VALUE is the smallest *positive*
        // double, not the most negative one. Either way only positive scores can win, but 0.0
        // makes that explicit and reports a clean 0.0 fallback score, like the top-N overload.
        double bestScore = 0.0d;
        for (String label : datastore.getKeys("labelWeight")) {
            double score = documentWeight(datastore, label, document);
            if (score > bestScore) {
                bestScore = score;
                result.setLabel(label);
            }
        }
        result.setScore(bestScore);
        return result;
    }

    /**
     * Returns up to {@code numResults} labels with strictly positive document weight. They are
     * ranked by {@link ByScoreLabelResultComparator}, with the highest-ranked result first.
     *
     * @param document        tokens of the document to classify
     * @param datastore       source of the model weights
     * @param defaultCategory label reported when no label scores above zero
     * @param numResults      maximum number of results to return; values {@code <= 0} yield the
     *     default result
     * @return the retained results, strongest first, or a single {@code defaultCategory} result
     *     with score {@code 0.0} when nothing qualifies
     * @throws InvalidDatastoreException if the datastore cannot list the labels
     * @throws IllegalStateException wrapping an {@link InvalidDatastoreException} raised while
     *     looking up per-term weights
     */
    @Override
    public ClassifierResult[] classifyDocument(String[] document, Datastore datastore, String defaultCategory,
            int numResults) throws InvalidDatastoreException {
        Collection<String> labels = datastore.getKeys("labelWeight");
        // The heap head is the weakest retained result (per the comparator). It is evicted as
        // soon as the heap grows past numResults, so the heap briefly holds numResults + 1
        // entries. PriorityQueue rejects a capacity below 1, and there can never be more entries
        // than labels.
        int capacity = Math.max(1, Math.min(numResults, labels.size()) + 1);
        PriorityQueue<ClassifierResult> best =
                new PriorityQueue<ClassifierResult>(capacity, new ByScoreLabelResultComparator());
        for (String label : labels) {
            double score = documentWeight(datastore, label, document);
            if (score > 0.0d) {
                best.add(new ClassifierResult(label, score));
                if (best.size() > numResults) {
                    best.remove();
                }
            }
        }
        if (best.isEmpty()) {
            return new ClassifierResult[] {new ClassifierResult(defaultCategory, 0.0d)};
        }
        // The heap drains weakest-first; fill from the back so the array ends up strongest-first.
        ClassifierResult[] results = new ClassifierResult[best.size()];
        for (int k = results.length - 1; k >= 0; k--) {
            results[k] = best.remove();
        }
        return results;
    }

    /**
     * Computes the weight of one feature (term) for one label:
     * {@code log((sigma_j(feature) - weight(feature, label) + alpha_i)
     *          / (sigma_jSigma_k - labelWeight(label) + vocabCount))
     *        / thetaNormalizer(label)}.
     *
     * <p>NOTE(review): reading the keys, the numerator looks like the feature's count outside
     * {@code label} plus a smoothing term, and the denominator like the matching complement total.
     * Confirm against the trainer that writes these weights.
     *
     * @param datastore source of the model weights
     * @param label     label to weigh the feature against
     * @param feature   the term being weighed
     * @return the normalized feature weight
     * @throws InvalidDatastoreException if any lookup fails
     */
    @Override
    public double featureWeight(Datastore datastore, String label, String feature) throws InvalidDatastoreException {
        double featureLabelWeight = datastore.getWeight("weight", feature, label);
        double vocabCount = datastore.getWeight("sumWeight", "vocabCount");
        double sigmaJ = datastore.getWeight("weight", feature, "sigma_j");
        double sigmaJSigmaK = datastore.getWeight("sumWeight", "sigma_jSigma_k");
        double labelWeight = datastore.getWeight("labelWeight", label);
        double alphaI = datastore.getWeight("params", "alpha_i");
        double numerator = sigmaJ - featureLabelWeight + alphaI;
        double denominator = sigmaJSigmaK - labelWeight + vocabCount;
        return Math.log(numerator / denominator) / datastore.getWeight("thetaNormalizer", label);
    }

    /**
     * Fetches the {@code labelWeight} keys and discards them.
     *
     * <p>NOTE(review): presumably done to make the datastore load or validate its label data up
     * front, so that a broken datastore fails here rather than during classification. Confirm.
     *
     * @param datastore the datastore to prepare
     * @throws InvalidDatastoreException if the label keys cannot be read
     */
    @Override
    public void initialize(Datastore datastore) throws InvalidDatastoreException {
        datastore.getKeys("labelWeight");
    }

    /**
     * Scores a document against one label. The score is the sum, over distinct terms, of
     * {@code termCount * featureWeight(datastore, label, term)}.
     *
     * @param datastore source of the model weights
     * @param label     label to score against
     * @param document  tokens of the document; an empty document scores {@code 0.0}
     * @return the document weight for {@code label}
     * @throws IllegalStateException wrapping an {@link InvalidDatastoreException} from a feature
     *     lookup, because this method does not declare the checked exception
     */
    @Override
    public double documentWeight(final Datastore datastore, final String label, String[] document) {
        // Count term frequencies so each distinct term is looked up once and weighted by its count.
        OpenObjectIntHashMap<String> termCounts = new OpenObjectIntHashMap<String>(document.length / 2);
        for (String term : document) {
            if (termCounts.containsKey(term)) {
                termCounts.put(term, termCounts.get(term) + 1);
            } else {
                termCounts.put(term, 1);
            }
        }
        final MutableDouble total = new MutableDouble(0.0d);
        termCounts.forEachPair(new ObjectIntProcedure<String>() {
            @Override
            public boolean apply(String term, int count) {
                try {
                    total.add(count * CBayesAlgorithm.this.featureWeight(datastore, label, term));
                    return true; // keep iterating over every term
                } catch (InvalidDatastoreException e) {
                    throw new IllegalStateException(e);
                }
            }
        });
        return total.doubleValue();
    }

    /**
     * Lists the labels known to the model.
     *
     * @param datastore source of the model data
     * @return the {@code labelWeight} keys of the datastore
     * @throws InvalidDatastoreException if the keys cannot be read
     */
    @Override
    public Collection<String> getLabels(Datastore datastore) throws InvalidDatastoreException {
        return datastore.getKeys("labelWeight");
    }
}
