package org.apache.mahout.classifier.df.split;

import java.util.Arrays;
import java.util.Comparator;
import org.apache.mahout.classifier.df.data.Data;
import org.apache.mahout.classifier.df.data.Instance;

/**
 * Split criterion for regression trees: a split is scored by how much it reduces the sum of
 * squared deviations (SS) of the label from its mean, i.e. the total SS of all instances minus
 * the summed SS within each child.
 */
public class RegressionSplit extends IgSplit {

    /**
     * Orders instances by ascending value of a single attribute. Uses {@link Double#compare}, so
     * the ordering is total (NaN sorts after every other value).
     */
    private static final class InstanceComparator implements Comparator<Instance> {
        private final int attr;

        InstanceComparator(int attr) {
            this.attr = attr;
        }

        @Override
        public int compare(Instance first, Instance second) {
            return Double.compare(first.get(attr), second.get(attr));
        }
    }

    /**
     * Computes the best split of {@code data} on attribute {@code attr}.
     *
     * @param data instances to split
     * @param attr index of the attribute to split on
     * @return a threshold split if {@code attr} is numerical, otherwise a categorical split with
     *     one child per attribute value
     */
    @Override
    public Split computeSplit(Data data, int attr) {
        if (data.getDataset().isNumerical(attr)) {
            return numericalSplit(data, attr);
        }
        return categoricalSplit(data, attr);
    }

    /**
     * Scores a split on a categorical attribute, grouping instances by attribute value.
     *
     * @param data instances to split
     * @param attr index of a categorical attribute
     * @return a split whose gain is the total label SS minus the summed SS within each value
     *     group; the gain is 0 when {@code data} is empty
     */
    private static Split categoricalSplit(Data data, int attr) {
        Instance[] instances = instancesOf(data);
        if (instances.length == 0) {
            // Avoids 0/0 = NaN in the total-SS formula below.
            return new Split(attr, 0.0);
        }
        double[] labels = labelsOf(data, instances);
        // SS is shift-invariant; centering the labels avoids catastrophic cancellation in
        // sumSq - sum^2 / n when the labels are large in magnitude.
        double shift = mean(labels);

        int nbValues = data.getDataset().nbValues(attr);
        double[] sums = new double[nbValues];
        double[] sumSquares = new double[nbValues];
        double[] counts = new double[nbValues];
        double totalSum = 0.0;
        double totalSumSquares = 0.0;
        for (int k = 0; k < instances.length; k++) {
            // The categorical value is used directly as an array index, so it must lie in
            // [0, nbValues(attr)).
            int value = (int) instances[k].get(attr);
            double y = labels[k] - shift;
            double ySquared = y * y;
            sums[value] += y;
            sumSquares[value] += ySquared;
            counts[value] += 1.0;
            totalSum += y;
            totalSumSquares += ySquared;
        }
        double totalSs = totalSumSquares - totalSum * totalSum / instances.length;
        return new Split(attr, totalSs - sumOfSquaredDeviations(sums, sumSquares, counts));
    }

    /**
     * Finds the threshold on a numerical attribute that minimises the summed label SS of the two
     * children. Candidate thresholds are midpoints between consecutive distinct attribute values
     * in sorted order.
     *
     * @param data instances to split
     * @param attr index of a numerical attribute
     * @return the best threshold split; when {@code data} is empty or no candidate threshold
     *     exists (e.g. all values are equal) the gain is 0 and the threshold is NaN
     */
    static Split numericalSplit(Data data, int attr) {
        Instance[] instances = instancesOf(data);
        if (instances.length == 0) {
            // The scan below reads instances[0]; there is nothing to split anyway.
            return new Split(attr, 0.0, Double.NaN);
        }
        Arrays.sort(instances, new InstanceComparator(attr));
        double[] labels = labelsOf(data, instances);
        // SS is shift-invariant; centering avoids catastrophic cancellation (see categoricalSplit).
        double shift = mean(labels);

        double totalSum = 0.0;
        double totalSumSquares = 0.0;
        for (double label : labels) {
            double y = label - shift;
            totalSum += y;
            totalSumSquares += y * y;
        }
        int n = instances.length;
        double totalSs = totalSumSquares - totalSum * totalSum / n;

        // Index 0: instances already scanned (left child); index 1: the remainder (right child).
        double[] sums = {0.0, totalSum};
        double[] sumSquares = {0.0, totalSumSquares};
        double[] counts = {0.0, n};

        // Sentinel: only overwritten by a strictly smaller candidate.
        double bestWithinSs = Double.MAX_VALUE;
        double bestSplit = Double.NaN;
        double previous = instances[0].get(attr);
        for (int k = 0; k < n; k++) {
            double value = instances[k].get(attr);
            // A boundary is only possible where the sorted value changes; equal values stay
            // on the same side.
            if (value > previous) {
                double withinSs = sumOfSquaredDeviations(sums, sumSquares, counts);
                if (withinSs < bestWithinSs) {
                    bestWithinSs = withinSs;
                    bestSplit = (value + previous) / 2.0;
                }
            }
            previous = value;

            // Move this instance from the right child to the left child.
            double y = labels[k] - shift;
            double ySquared = y * y;
            sums[0] += y;
            sumSquares[0] += ySquared;
            counts[0] += 1.0;
            sums[1] -= y;
            sumSquares[1] -= ySquared;
            counts[1] -= 1.0;
        }

        if (bestWithinSs == Double.MAX_VALUE) {
            // No usable candidate: keeping everything together yields no gain.
            return new Split(attr, 0.0, Double.NaN);
        }
        return new Split(attr, totalSs - bestWithinSs, bestSplit);
    }

    /** Copies the instances of {@code data} into a new array, in index order. */
    private static Instance[] instancesOf(Data data) {
        Instance[] instances = new Instance[data.size()];
        for (int k = 0; k < instances.length; k++) {
            instances[k] = data.get(k);
        }
        return instances;
    }

    /** Returns the label of each instance, as reported by the dataset of {@code data}. */
    private static double[] labelsOf(Data data, Instance[] instances) {
        double[] labels = new double[instances.length];
        for (int k = 0; k < instances.length; k++) {
            labels[k] = data.getDataset().getLabel(instances[k]);
        }
        return labels;
    }

    /** Arithmetic mean of a non-empty array. */
    private static double mean(double[] values) {
        double sum = 0.0;
        for (double v : values) {
            sum += v;
        }
        return sum / values.length;
    }

    /**
     * Sums, over every non-empty group, the squared deviations from that group's mean.
     *
     * <p>Each group's SS is computed from its sum, sum of squares and count as
     * {@code sumSq - sum * sum / count}; groups with a count of 0 contribute nothing.
     */
    private static double sumOfSquaredDeviations(double[] sums, double[] sumSquares, double[] counts) {
        double total = 0.0;
        for (int g = 0; g < sums.length; g++) {
            if (counts[g] > 0.0) {
                total += sumSquares[g] - (sums[g] * sums[g]) / counts[g];
            }
        }
        return total;
    }
}
