package org.apache.mahout.classifier.df.split;

import org.apache.mahout.classifier.df.data.Data;
import org.apache.mahout.classifier.df.data.DataLoader;
import org.apache.mahout.classifier.df.data.Dataset;
import org.apache.mahout.classifier.df.data.DescriptorException;
import org.apache.mahout.classifier.df.data.conditions.Condition;
import org.apache.mahout.common.MahoutTestCase;
import org.junit.Test;

/**
 * Tests {@link RegressionSplit#computeSplit} against small synthetic tables whose expected
 * information-gain and split values are pinned in {@link #testComputeSplit()}.
 */
@Deprecated
public final class RegressionSplitTest extends MahoutTestCase {

    /** Absolute tolerance used for every floating-point comparison in this test. */
    private static final double TOLERANCE = 1.0e-6;

    /**
     * Builds three {@link Data} tables that share a single {@link Dataset}. The dataset is
     * generated with descriptor {@code "C N L"} from the first table.
     *
     * <ul>
     *   <li>index 0: 20 rows whose first column cycles through A, B, C;</li>
     *   <li>index 1: 20 rows whose first column alternates between A and B;</li>
     *   <li>index 2: 10 rows whose first column is always A.</li>
     * </ul>
     *
     * @return the three loaded tables, in the order listed above
     * @throws DescriptorException if the dataset descriptor cannot be processed
     */
    private static Data[] generateTrainingData() throws DescriptorException {
        String[] mixed = new String[20];
        for (int row = 0; row < mixed.length; row++) {
            switch (row % 3) {
                case 0:
                    mixed[row] = "A," + (40 - row) + ',' + (row + 20);
                    break;
                case 1:
                    mixed[row] = "B," + (row + 20) + ',' + (40 - row);
                    break;
                default:
                    mixed[row] = "C," + (row + 20) + ',' + (row + 20);
                    break;
            }
        }

        String[] alternating = new String[20];
        for (int row = 0; row < alternating.length; row++) {
            alternating[row] = row % 2 == 0
                    ? "A," + (50 - row) + ',' + (row + 10)
                    : "B," + (row + 10) + ',' + (50 - row);
        }

        String[] uniform = new String[10];
        for (int row = 0; row < uniform.length; row++) {
            uniform[row] = "A," + (40 - row) + ',' + (row + 20);
        }

        // The dataset is generated from the mixed table only; all three tables are loaded
        // against it (array initializers evaluate left to right, so the load order is fixed).
        Dataset dataset = DataLoader.generateDataset("C N L", true, mixed);
        return new Data[] {
            DataLoader.loadData(dataset, mixed),
            DataLoader.loadData(dataset, alternating),
            DataLoader.loadData(dataset, uniform),
        };
    }

    /**
     * Checks the information gain and split value computed for attribute 1 and attribute 0 on
     * the generated tables and on subsets of them.
     */
    @Test
    public void testComputeSplit() throws DescriptorException {
        Data[] tables = generateTrainingData();
        RegressionSplit splitter = new RegressionSplit();

        // Attribute 1 over the whole mixed table.
        Split split = splitter.computeSplit(tables[0], 1);
        assertEquals(180.0, split.getIg(), TOLERANCE);
        assertEquals(38.0, split.getSplit(), TOLERANCE);

        // Attribute 1 again, on the subset of the mixed table selected by lesser(1, 38.0).
        split = splitter.computeSplit(tables[0].subset(Condition.lesser(1, 38.0)), 1);
        assertEquals(76.5, split.getIg(), TOLERANCE);
        assertEquals(21.5, split.getSplit(), TOLERANCE);

        // Attribute 0 over the alternating table: the expected split value is NaN.
        split = splitter.computeSplit(tables[1], 0);
        assertEquals(2205.0, split.getIg(), TOLERANCE);
        assertEquals(Double.NaN, split.getSplit(), TOLERANCE);

        // Attribute 1 on the subset of the alternating table selected by equals(0, 0.0).
        split = splitter.computeSplit(tables[1].subset(Condition.equals(0, 0.0)), 1);
        assertEquals(250.0, split.getIg(), TOLERANCE);
        assertEquals(41.0, split.getSplit(), TOLERANCE);
    }
}
