package org.apache.mahout.classifier.cbayes;

import java.util.Collection;
import java.util.HashMap;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.Map;
import org.apache.hadoop.util.PriorityQueue;
import org.apache.mahout.classifier.ClassifierResult;
import org.apache.mahout.common.Classifier;
import org.apache.mahout.common.Model;

/**
 * Classifier that scores each label of a {@link Model} by summing, over the distinct terms of a
 * document, {@code termFrequency * model.featureWeight(label, term)}.
 *
 * <p>Lower (more negative) weights are treated as better matches: only labels whose weight is
 * strictly below zero are ever reported, and when none qualifies a caller-supplied default label
 * with score {@code 0.0} is returned instead.
 */
public class CBayesClassifier implements Classifier {

    /**
     * Bounded priority queue of classification results, ordered by ascending score with ties
     * broken by label so that the ordering is deterministic.
     */
    private static class ClassifierResultPriorityQueue extends PriorityQueue<ClassifierResult> {

        /**
         * @param maxSize capacity passed to {@code initialize}; presumably the maximum number of
         *     results the queue retains (per Hadoop {@code PriorityQueue}) — confirm
         */
        private ClassifierResultPriorityQueue(int maxSize) {
            initialize(maxSize);
        }

        /**
         * Returns {@code true} when {@code a} orders before {@code b}: by lower score, or, for
         * exactly equal scores, by lexicographically smaller label.
         */
        protected boolean lessThan(Object a, Object b) {
            ClassifierResult left = (ClassifierResult) a;
            ClassifierResult right = (ClassifierResult) b;
            double leftScore = left.getScore();
            double rightScore = right.getScore();
            if (leftScore == rightScore) {
                return left.getLabel().compareTo(right.getLabel()) < 0;
            }
            return leftScore < rightScore;
        }
    }

    /**
     * Scores every label of {@code model} against {@code document} and returns up to
     * {@code numResults} of them.
     *
     * <p>Only labels whose weight is strictly negative are queued. Results are returned in the
     * queue's pop order. If no label qualifies, a single result for {@code defaultCategory} with
     * score {@code 0.0} is returned, so the collection is never empty.
     *
     * @param model model supplying the labels and per-feature weights
     * @param document tokens of the document to classify
     * @param defaultCategory label reported when no label has a negative weight
     * @param numResults capacity of the result priority queue
     * @return a non-empty, mutable collection of results
     */
    @Override // org.apache.mahout.common.Classifier
    public Collection<ClassifierResult> classify(Model model, String[] document,
            String defaultCategory, int numResults) {
        Collection<String> labels = model.getLabels();
        ClassifierResultPriorityQueue queue = new ClassifierResultPriorityQueue(numResults);
        for (String label : labels) {
            double weight = documentWeight(model, label, document);
            if (weight < 0.0d) {
                queue.insert(new ClassifierResult(label, weight));
            }
        }

        // NOTE(review): assuming Hadoop PriorityQueue.insert semantics (a full queue evicts its
        // least element), this retains the numResults weights closest to zero and pops them in
        // ascending order, whereas the single-result classify(...) prefers the most negative
        // weight. Confirm which ordering callers expect before relying on the first element.
        Collection<ClassifierResult> results = new LinkedList<ClassifierResult>();
        ClassifierResult next;
        while ((next = (ClassifierResult) queue.pop()) != null) {
            results.add(next);
        }
        if (results.isEmpty()) {
            results.add(new ClassifierResult(defaultCategory, 0.0d));
        }
        return results;
    }

    /**
     * Returns the single best label for {@code document}: the one with the lowest weight, provided
     * that weight is strictly negative.
     *
     * <p>If no label scores below zero, the result carries {@code defaultCategory} and score
     * {@code 0.0}. On exact ties the label encountered first in {@code model.getLabels()} wins.
     *
     * @param model model supplying the labels and per-feature weights
     * @param document tokens of the document to classify
     * @param defaultCategory label reported when no label has a negative weight
     * @return the best-matching result
     */
    @Override // org.apache.mahout.common.Classifier
    public ClassifierResult classify(Model model, String[] document, String defaultCategory) {
        ClassifierResult result = new ClassifierResult(defaultCategory);
        double bestWeight = 0.0d;
        for (String label : model.getLabels()) {
            double weight = documentWeight(model, label, document);
            if (weight < bestWeight) {
                bestWeight = weight;
                result.setLabel(label);
            }
        }
        result.setScore(bestWeight);
        return result;
    }

    /**
     * Computes the weight of {@code document} under {@code label}.
     *
     * <p>Each distinct term contributes its occurrence count multiplied by
     * {@code model.featureWeight(label, term)}.
     *
     * @param model model supplying per-feature weights
     * @param label label to score the document against
     * @param document tokens of the document; repeated tokens are counted
     * @return the summed weight ({@code 0.0} for an empty document)
     */
    @Override // org.apache.mahout.common.Classifier
    public double documentWeight(Model model, String label, String[] document) {
        // Term frequencies. An int[1] holder is incremented in place, avoiding a re-box per token.
        // The initial capacity is kept at 1000 so entry iteration (and hence summation) order
        // matches the previous implementation.
        Map<String, int[]> termCounts = new HashMap<String, int[]>(1000);
        for (String term : document) {
            int[] count = termCounts.get(term);
            if (count == null) {
                count = new int[1];
                termCounts.put(term, count);
            }
            count[0]++;
        }

        double weight = 0.0d;
        for (Map.Entry<String, int[]> entry : termCounts.entrySet()) {
            weight += entry.getValue()[0] * model.featureWeight(label, entry.getKey());
        }
        return weight;
    }
}
