package org.apache.mahout.classifier.df.split;

import java.io.Serializable;
import java.util.Arrays;
import java.util.Comparator;
import org.apache.mahout.cf.taste.impl.common.FullRunningAverage;
import org.apache.mahout.classifier.df.data.Data;
import org.apache.mahout.classifier.df.data.Instance;

/**
 * {@link IgSplit} implementation for regression, i.e. datasets whose label is numerical.
 *
 * <p>The gain of a split is the reduction of the label's sum of squared deviations from the mean
 * (SSE): the SSE of the whole data minus the SSE remaining inside the branches. Sums of squared
 * deviations are maintained incrementally with Welford's update, so each attribute is processed in
 * a single pass over the data (after sorting, for numerical attributes).
 */
@Deprecated
public class RegressionSplit extends IgSplit {

    /**
     * Orders instances by the value of a single attribute, using {@link Double#compare}.
     *
     * <p>NOTE(review): the decompiler reported this class as private in the original source; it is
     * kept public here so that no existing reference breaks.
     */
    public static class InstanceComparator implements Comparator<Instance>, Serializable {
        private final int attr;

        InstanceComparator(int attr) {
            this.attr = attr;
        }

        @Override
        public int compare(Instance first, Instance second) {
            return Double.compare(first.get(attr), second.get(attr));
        }
    }

    /**
     * Computes the split of {@code data} on attribute {@code attr}.
     *
     * @param data the data to split
     * @param attr index of the attribute to split on
     * @return a one-branch-per-value split when the attribute is categorical, otherwise the best
     *         threshold split of the numerical attribute
     */
    @Override
    public Split computeSplit(Data data, int attr) {
        return data.getDataset().isNumerical(attr)
                ? numericalSplit(data, attr)
                : categoricalSplit(data, attr);
    }

    /**
     * Computes the gain of splitting {@code data} on the categorical attribute {@code attr}, with one
     * branch per attribute value.
     *
     * @return a split whose gain is the total label SSE minus the sum of the per-value label SSEs
     */
    private static Split categoricalSplit(Data data, int attr) {
        int nbValues = data.getDataset().nbValues(attr);
        FullRunningAverage[] valueAverages = new FullRunningAverage[nbValues];
        double[] valueSk = new double[nbValues];
        for (int v = 0; v < nbValues; v++) {
            valueAverages[v] = new FullRunningAverage();
        }
        FullRunningAverage totalAverage = new FullRunningAverage();
        double totalSk = 0.0;

        for (int index = 0; index < data.size(); index++) {
            Instance instance = data.get(index);
            // the categorical value is stored as a double; it indexes the per-value accumulators
            int value = (int) instance.get(attr);
            double label = data.getDataset().getLabel(instance);
            valueSk[value] = addToSk(valueAverages[value], valueSk[value], label);
            totalSk = addToSk(totalAverage, totalSk, label);
        }

        double gain = totalSk;
        for (double sk : valueSk) {
            gain -= sk;
        }
        return new Split(attr, gain);
    }

    /**
     * Finds the best threshold for the numerical attribute {@code attr}.
     *
     * <p>Instances are sorted by attribute value and then swept once, moving one instance at a time
     * from the upper branch to the lower branch. Candidate thresholds are the midpoints between
     * consecutive distinct values.
     *
     * @return a split carrying the chosen threshold and its gain (total label SSE minus the SSE of
     *         the two branches). When the attribute has fewer than two distinct values, no
     *         threshold exists: the threshold is {@code NaN} and the gain is 0.
     */
    private static Split numericalSplit(Data data, int attr) {
        Instance[] instances = new Instance[data.size()];
        for (int index = 0; index < instances.length; index++) {
            instances[index] = data.get(index);
        }
        Arrays.sort(instances, new InstanceComparator(attr));

        // [0] = branch below the candidate threshold, [1] = branch at or above it
        FullRunningAverage[] averages = {new FullRunningAverage(), new FullRunningAverage()};
        double[] sk = new double[2];

        // initially every instance sits in the upper branch
        for (Instance instance : instances) {
            sk[1] = addToSk(averages[1], sk[1], data.getDataset().getLabel(instance));
        }
        double totalSk = sk[1];

        double split = Double.NaN;
        double previousValue = Double.NaN;
        double bestVal = Double.MAX_VALUE;
        // Starts at totalSk (not 0) so that an attribute with no candidate threshold reports zero
        // gain, consistent with categoricalSplit, instead of the maximum possible gain totalSk.
        double bestSk = totalSk;

        for (Instance instance : instances) {
            double label = data.getDataset().getLabel(instance);
            double value = instance.get(attr);
            // Only thresholds between two distinct consecutive values are evaluated. The test is
            // false for the first instance because previousValue is NaN.
            if (value > previousValue) {
                // NOTE(review): the threshold is chosen by the unweighted sum of the two branch
                // variances, while the reported gain is based on SSE; kept as in the original.
                double curVal = sk[0] / averages[0].getCount() + sk[1] / averages[1].getCount();
                if (curVal < bestVal) {
                    bestVal = curVal;
                    bestSk = sk[0] + sk[1];
                    split = (value + previousValue) / 2.0;
                }
            }
            // move the current instance from the upper branch to the lower one
            sk[0] = addToSk(averages[0], sk[0], label);
            sk[1] = removeFromSk(averages[1], sk[1], label);
            previousValue = value;
        }

        return new Split(attr, totalSk - bestSk, split);
    }

    /**
     * Adds {@code x} to {@code average} and returns the updated sum of squared deviations (Welford).
     *
     * @param average running mean of the values accumulated so far; updated in place
     * @param sk sum of squared deviations from the mean of the values accumulated so far
     * @param x the value to add
     * @return the sum of squared deviations including {@code x} (0 for the first value)
     */
    private static double addToSk(FullRunningAverage average, double sk, double x) {
        if (average.getCount() == 0) {
            average.addDatum(x);
            return 0.0;
        }
        double previousMean = average.getAverage();
        average.addDatum(x);
        return sk + (x - previousMean) * (x - average.getAverage());
    }

    /**
     * Removes {@code x} from {@code average} and returns the updated sum of squared deviations
     * (inverse Welford update).
     *
     * @param average running mean that currently includes {@code x}; updated in place
     * @param sk sum of squared deviations of the values currently in {@code average}
     * @param x the value to remove
     * @return the sum of squared deviations without {@code x}
     */
    private static double removeFromSk(FullRunningAverage average, double sk, double x) {
        double previousMean = average.getAverage();
        average.removeDatum(x);
        return sk - (x - previousMean) * (x - average.getAverage());
    }
}
