package org.apache.mahout.classifier.cbayes;

import java.util.Iterator;
import java.util.Map;
import org.apache.mahout.common.Model;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:WEB-INF/lib/mahout-core-0.1.jar:org/apache/mahout/classifier/cbayes/CBayesModel.class */
/**
 * Model backing the complement naive Bayes classifier in the {@code cbayes} package.
 *
 * <p>Raw weights live in the inherited {@code featureLabelWeights} structure, indexed first by
 * feature index and then keyed by label index. For a label {@code l} and feature {@code f} with
 * stored weight {@code w(l, f)}, every weight derived here starts from the smoothed complement
 * log-ratio
 *
 * <pre>
 * log((sumFeatureWeight(f) - w(l, f) + 1) / (sigma_jSigma_k - sumLabelWeight(l) + |featureList|))
 * </pre>
 *
 * <p>NOTE(review): reading {@code sumFeatureWeight(f) - w(l, f)} as "weight of f in every label
 * except l" assumes {@code getSumFeatureWeight}/{@code getSumLabelWeight} return totals across all
 * labels/features and that {@code sigma_jSigma_k} is the grand total - confirm against
 * {@link Model}.
 */
public class CBayesModel extends Model {
    private static final Logger log = LoggerFactory.getLogger(CBayesModel.class);

    /** Additive smoothing term added to the complement feature weight in the numerator. */
    private static final double COMPLEMENT_SMOOTHING = 1.0d;

    /**
     * Returns the normalized complement weight of {@code feature} for {@code label}, computed on the
     * fly from the stored raw weight.
     *
     * @param label label index
     * @param feature feature index into {@code featureLabelWeights}
     * @return the negated complement log-ratio divided by {@code getThetaNormalizer(label)}; a
     *     missing (label, feature) entry contributes a raw weight of 0
     */
    @Override
    protected double getWeight(Integer label, Integer feature) {
        double logRatio = complementLogRatio(
                label, feature, storedWeight(label, feature), this.featureList.size());
        return -logRatio / getThetaNormalizer(label);
    }

    /**
     * Returns the weight currently stored for ({@code label}, {@code feature}) without any
     * transformation.
     *
     * @param label label index
     * @param feature feature index into {@code featureLabelWeights}
     * @return the stored weight, or 0 when no entry exists for {@code label}
     */
    @Override
    protected double getWeightUnprocessed(Integer label, Integer feature) {
        return storedWeight(label, feature);
    }

    /**
     * Rescales every entry of {@code thetaNormalizer} by the smallest absolute value among them,
     * logging the map before and after.
     *
     * <p>An empty map is left untouched. NOTE(review): if any normalizer is exactly 0 the divisor
     * is 0 and the results become infinite/NaN - behavior preserved from the original.
     */
    @Override
    public void initializeNormalizer() {
        log.info("{}", this.thetaNormalizer);
        double divisor = Double.MAX_VALUE;
        for (Double value : this.thetaNormalizer.values()) {
            divisor = smallerMagnitude(divisor, value.doubleValue());
        }
        // Update through the entry rather than Map.put so the map is never modified via its own
        // API while being iterated.
        for (Map.Entry<Integer, Double> entry : this.thetaNormalizer.entrySet()) {
            entry.setValue(Double.valueOf(entry.getValue().doubleValue() / divisor));
        }
        log.info("{}", this.thetaNormalizer);
    }

    /**
     * Converts the stored raw weights into final model weights.
     *
     * <ol>
     *   <li>Replace each non-zero stored weight with its complement log-ratio, while summing the
     *       log-ratios of every (label, feature) pair per label.</li>
     *   <li>Scale the per-label sums by the smallest absolute sum.</li>
     *   <li>Replace each non-zero stored weight with its negated value divided by its label's
     *       scaled sum.</li>
     * </ol>
     */
    @Override
    public void generateModel() {
        int featureCount = this.featureList.size();
        int labelCount = this.labelList.size();

        double[] perLabelNormalizer = storeComplementLogRatios(featureCount, labelCount);
        log.info("Normalizing Weights");
        scaleBySmallestMagnitude(perLabelNormalizer);
        applyPerLabelNormalization(perLabelNormalizer, featureCount, labelCount);
    }

    /**
     * Returns the same value as {@link #getWeight(Integer, Integer)}.
     *
     * @param label label index
     * @param feature feature index
     * @return the normalized complement weight
     */
    @Override
    public double featureWeight(Integer label, Integer feature) {
        return getWeight(label, feature);
    }

    /** Looks up the stored weight for a (label, feature) pair with a single map access. */
    private double storedWeight(Integer label, Integer feature) {
        Map<Integer, Double> labelWeights = this.featureLabelWeights.get(feature.intValue());
        Double weight = labelWeights.get(label);
        return weight == null ? 0.0d : weight.doubleValue();
    }

    /** Computes the smoothed complement log-ratio shared by getWeight and generateModel. */
    private double complementLogRatio(
            Integer label, Integer feature, double labelFeatureWeight, double vocabularySize) {
        double numerator = (getSumFeatureWeight(feature) - labelFeatureWeight) + COMPLEMENT_SMOOTHING;
        double denominator = (this.sigma_jSigma_k - getSumLabelWeight(label)) + vocabularySize;
        return Math.log(numerator / denominator);
    }

    /** Pass 1: stores log-ratios for non-zero weights and returns the per-label log-ratio sums. */
    private double[] storeComplementLogRatios(int featureCount, int labelCount) {
        double vocabularySize = featureCount;
        double[] perLabelSums = new double[labelCount];
        for (int f = 0; f < featureCount; f++) {
            Integer feature = Integer.valueOf(f);
            for (int l = 0; l < labelCount; l++) {
                Integer label = Integer.valueOf(l);
                double raw = getWeightUnprocessed(label, feature);
                double logRatio = complementLogRatio(label, feature, raw, vocabularySize);
                if (raw != 0.0d) {
                    setWeight(label, feature, Double.valueOf(logRatio));
                }
                // Every pair contributes to the label sum, including ones with no stored weight.
                perLabelSums[l] += logRatio;
            }
        }
        return perLabelSums;
    }

    /**
     * Pass 2: divides every element by the smallest absolute element (no-op for an empty array).
     * NOTE(review): a zero element makes the divisor 0 - behavior preserved from the original.
     */
    private static void scaleBySmallestMagnitude(double[] values) {
        double divisor = Double.MAX_VALUE;
        for (double value : values) {
            divisor = smallerMagnitude(divisor, value);
        }
        for (int i = 0; i < values.length; i++) {
            values[i] /= divisor;
        }
    }

    /** Pass 3: rewrites each non-zero stored weight as -(weight / its label's normalizer). */
    private void applyPerLabelNormalization(double[] perLabelNormalizer, int featureCount, int labelCount) {
        for (int f = 0; f < featureCount; f++) {
            Integer feature = Integer.valueOf(f);
            for (int l = 0; l < labelCount; l++) {
                Integer label = Integer.valueOf(l);
                double logRatio = getWeightUnprocessed(label, feature);
                if (logRatio != 0.0d) {
                    setWeight(label, feature, Double.valueOf(-(logRatio / perLabelNormalizer[l])));
                }
            }
        }
    }

    /**
     * Returns {@code |candidate|} when it is strictly smaller than {@code currentMin}, otherwise
     * {@code currentMin} (so a NaN candidate never replaces the running minimum).
     */
    private static double smallerMagnitude(double currentMin, double candidate) {
        double magnitude = Math.abs(candidate);
        return currentMin > magnitude ? magnitude : currentMin;
    }
}
