package org.apache.mahout.df.split;

import java.util.Arrays;
import org.apache.commons.lang.ArrayUtils;
import org.apache.mahout.df.data.Data;
import org.apache.mahout.df.data.DataUtils;
import org.apache.mahout.df.data.Instance;

/* loaded from: input_file:WEB-INF/lib/mahout-core-0.2.jar:org/apache/mahout/df/split/OptIgSplit.class */
/**
 * Split computation that scores an attribute by the entropy of the label
 * distribution minus the size-weighted entropy of the partitions the attribute
 * induces.
 *
 * <p>The numerical path keeps its working counters in instance fields, so a
 * single instance carries state from one {@link #numericalSplit} call to the
 * next until {@link #initCounts} resets it.
 */
public class OptIgSplit extends IgSplit {
    /** Label counts per candidate value: {@code counts[valueIndex][label]}. */
    private int[][] counts;
    /** Label counts of the instances not yet moved into {@link #countLess}. */
    private int[] countAll;
    /** Label counts of the instances already passed by the threshold scan. */
    private int[] countLess;

    /**
     * Dispatches to the numerical or categorical scoring routine depending on
     * how the dataset classifies the attribute.
     *
     * @param data instances to evaluate
     * @param i    index of the attribute to score
     * @return the split computed for that attribute
     */
    @Override
    public Split computeSplit(Data data, int i) {
        if (data.getDataset().isNumerical(i)) {
            return numericalSplit(data, i);
        }
        return categoricalSplit(data, i);
    }

    /**
     * Scores an attribute by partitioning the data on each value returned by
     * {@code data.values(attr)}.
     *
     * @param data instances to evaluate
     * @param attr index of the attribute to score
     * @return a split carrying the attribute index and its gain
     */
    private static Split categoricalSplit(Data data, int attr) {
        final double[] attrValues = data.values(attr);
        final int labelCount = data.getDataset().nblabels();
        final int[][] perValue = new int[attrValues.length][labelCount];
        final int[] perLabel = new int[labelCount];

        // Tally every instance by (attribute value, label) and by label alone.
        final int total = data.size();
        for (int row = 0; row < total; row++) {
            final Instance instance = data.get(row);
            final int bucket = ArrayUtils.indexOf(attrValues, instance.get(attr));
            final int label = instance.label;
            perValue[bucket][label]++;
            perLabel[label]++;
        }

        final double labelEntropy = entropy(perLabel, total);
        final double invTotal = 1.0d / total;

        // Accumulate the entropy of each value's partition, weighted by its share of the data.
        double conditionalEntropy = 0.0d;
        for (int[] bucketCounts : perValue) {
            final int bucketSize = DataUtils.sum(bucketCounts);
            conditionalEntropy += bucketSize * invTotal * entropy(bucketCounts, bucketSize);
        }

        return new Split(attr, labelEntropy - conditionalEntropy);
    }

    /**
     * Fetches the attribute's values from {@code data} and sorts that array in
     * ascending order.
     *
     * @param data instances to read from
     * @param attr index of the attribute
     * @return the sorted value array
     */
    private static double[] sortedValues(Data data, int attr) {
        final double[] sorted = data.values(attr);
        Arrays.sort(sorted);
        return sorted;
    }

    /**
     * Allocates fresh, zeroed counters sized for the given candidate values and
     * the dataset's number of labels.
     *
     * @param data instances whose dataset supplies the label count
     * @param dArr candidate values; one row of {@code counts} is created per entry
     */
    protected void initCounts(Data data, double[] dArr) {
        final int labelCount = data.getDataset().nblabels();
        this.counts = new int[dArr.length][labelCount];
        this.countAll = new int[labelCount];
        this.countLess = new int[labelCount];
    }

    /**
     * Fills {@code counts} and {@code countAll} by tallying each instance under
     * the position of its attribute value in {@code dArr} and under its label.
     *
     * @param data instances to tally
     * @param i    index of the attribute being scored
     * @param dArr candidate values used to locate each instance's row
     */
    protected void computeFrequencies(Data data, int i, double[] dArr) {
        final int total = data.size();
        for (int row = 0; row < total; row++) {
            final Instance instance = data.get(row);
            final int bucket = ArrayUtils.indexOf(dArr, instance.get(i));
            final int label = instance.label;
            this.counts[bucket][label]++;
            this.countAll[label]++;
        }
    }

    /**
     * Scans the sorted attribute values as candidate thresholds and keeps the
     * one with the highest gain.
     *
     * @param data instances to evaluate
     * @param i    index of the attribute to score
     * @return a split carrying the attribute index, the best gain and the chosen value
     * @throws IllegalStateException if no candidate produced a gain above the
     *         initial {@code -1.0} (for instance when there are no values)
     */
    protected Split numericalSplit(Data data, int i) {
        final double[] candidates = sortedValues(data, i);
        initCounts(data, candidates);
        computeFrequencies(data, i, candidates);

        final int total = data.size();
        final double labelEntropy = entropy(this.countAll, total);
        final double invTotal = 1.0d / total;

        int bestIndex = -1;
        double bestGain = -1.0d;

        for (int index = 0; index < candidates.length; index++) {
            double gain = labelEntropy;

            // Partition made of the candidates at earlier positions in the sorted array.
            final int passed = DataUtils.sum(this.countLess);
            gain -= (passed * invTotal) * entropy(this.countLess, passed);

            // Partition made of the current candidate and everything after it.
            final int remaining = DataUtils.sum(this.countAll);
            gain -= (remaining * invTotal) * entropy(this.countAll, remaining);

            // Strict comparison: on ties the earliest candidate wins.
            if (gain > bestGain) {
                bestGain = gain;
                bestIndex = index;
            }

            // Move the current candidate's counts across before trying the next threshold.
            DataUtils.add(this.countLess, this.counts[index]);
            DataUtils.dec(this.countAll, this.counts[index]);
        }

        if (bestIndex == -1) {
            throw new IllegalStateException("no best split found !");
        }
        return new Split(i, bestGain, candidates[bestIndex]);
    }

    /**
     * Computes the base-2 entropy of a label distribution given as raw counts.
     *
     * @param labelCounts number of instances per label
     * @param total       divisor used to turn counts into proportions
     * @return the entropy, or {@code 0.0} when {@code total} is zero
     */
    private static double entropy(int[] labelCounts, int total) {
        if (total == 0) {
            return 0.0d;
        }
        final double invTotal = 1.0d / total;
        double result = 0.0d;
        for (int count : labelCounts) {
            if (count == 0) {
                // log(0) would poison the sum; an empty label contributes nothing.
                continue;
            }
            final double proportion = count * invTotal;
            result += ((-proportion) * Math.log(proportion)) / LOG2;
        }
        return result;
    }
}
