package hex.tree;

import hex.DistributionFactory;
import hex.genmodel.utils.DistributionFamily;
import hex.tree.DHistogram;
import hex.tree.DTree;
import hex.tree.SharedTreeModel;
import org.junit.Assert;
import org.junit.Test;
import water.Key;

/**
 * Unit tests for {@link DTree#findBestSplitPoint}: NA handling and split-point search under
 * prediction bounds coming from {@link Constraints}.
 */
public class DTreeTest {

    /** Absolute tolerance used when comparing squared errors of a found split. */
    private static final double TOLERANCE = 1e-8;

    /** Number of synthetic rows generated by {@link #updateHisto}. */
    private static final int N_ROWS = 1000;

    /** The first {@code N_LOW_ROWS} synthetic rows get {@link #LOW_RESPONSE}; the rest get {@link #HIGH_RESPONSE}. */
    private static final int N_LOW_ROWS = 600;

    private static final double LOW_RESPONSE = 0.25;
    private static final double HIGH_RESPONSE = 0.45;

    /** Expected NA direction and squared errors of a split; compared against an actual {@link DTree.Split}. */
    public static class ExpectedSplitInfo {
        final DHistogram.NASplitDir _nasplit;
        final double _se;
        final double _se0;
        final double _se1;

        ExpectedSplitInfo(DHistogram.NASplitDir nasplit, double se, double se0, double se1) {
            this._nasplit = nasplit;
            this._se = se;
            this._se0 = se0;
            this._se1 = se1;
        }

        /**
         * Asserts that {@code split} is non-null and matches this expectation.
         *
         * @param split the split returned by {@link DTree#findBestSplitPoint}
         */
        void checkSplit(DTree.Split split) {
            Assert.assertNotNull(split);
            Assert.assertEquals(this._nasplit, split._nasplit);
            Assert.assertEquals(this._se, split._se, TOLERANCE);
            Assert.assertEquals(this._se0, split._se0, TOLERANCE);
            Assert.assertEquals(this._se1, split._se1, TOLERANCE);
        }
    }

    /**
     * Regression test (presumably for issue PUBDEV-6495; name-derived, confirm).
     *
     * <p>Builds two 2-bin histograms that differ only in the constructor argument after the
     * boolean flag (0.01 vs -9.0). It checks that the third argument of findBestSplitPoint
     * (20 vs 21) decides whether a split is found in the first histogram. It then checks that
     * the second histogram yields a split whose squared error equals the weighted SE of the
     * non-NA rows, and that this SE is lower than the first split's SE.
     */
    @Test
    public void testFindBestSplitPoint_pubdev6495() {
        final double[] weights = {10.0d, 10.0d, 100.0d, 1.0d};
        final double[] values = {Double.NaN, Double.NaN, 0.5d, 1.5d};
        final double[] responses = {0.0d, 1.0d, 0.0d, 1.0d};
        final int[] rows = {0, 1, 2, 3};

        final DHistogram histo = new DHistogram("test_hs", 2, 2, (byte) 0, 0.0d, 2.0d, true, 0.01d,
                SharedTreeModel.SharedTreeParameters.HistogramType.UniformAdaptive, 123L, (Key) null, (Constraints) null);
        histo.init();
        histo.updateHisto(weights, (double[]) null, values, responses, (double[]) null, rows, rows.length, 0);
        final DTree.Split split = DTree.findBestSplitPoint(histo, 0, 20.0d, 0, Double.NaN, Double.NaN, false);
        Assert.assertNotNull(split);
        // NOTE(review): the two NA rows weigh 10 + 10 = 20, which suggests the third argument is a
        // minimum-rows/weight threshold that the NA side just meets at 20 and fails at 21 — confirm.
        Assert.assertNull(DTree.findBestSplitPoint(histo, 0, 21.0d, 0, Double.NaN, Double.NaN, false));

        // Same histogram, but with -9.0 in place of 0.01 (presumably a split-improvement threshold — TODO confirm).
        final DHistogram histo2 = new DHistogram("test_hs", 2, 2, (byte) 0, 0.0d, 2.0d, true, -9.0d,
                SharedTreeModel.SharedTreeParameters.HistogramType.UniformAdaptive, 123L, (Key) null, (Constraints) null);
        histo2.init();
        histo2.updateHisto(weights, (double[]) null, values, responses, (double[]) null, rows, rows.length, 0);
        final DTree.Split split2 = DTree.findBestSplitPoint(histo2, 0, 21.0d, 0, Double.NaN, Double.NaN, false);
        Assert.assertNotNull(split2);
        // JUnit order is (expected, actual, delta).
        Assert.assertEquals(seNonNA(weights, values, responses), split2._se, 0.0d);
        Assert.assertTrue(split2._se < split._se);
    }

    /** All rows with the low response are NA, so the expected split separates NAs from the rest. */
    @Test
    public void testFindBestSplitPoint_bounds_NAvsREST() {
        final DHistogram histo = makeHisto(1, 0.3d, 0.4d);
        final ExpectedSplitInfo expected = updateHisto(histo, 0.3d, 0.4d, 1.0d);
        expected.checkSplit(DTree.findBestSplitPoint(histo, 0, 0.0d, 0, 0.3d, 0.4d, true));
    }

    /** Only 10% of the low-response rows are NA, so the expected split sends NAs left. */
    @Test
    public void testFindBestSplitPoint_bounds_NAs() {
        final DHistogram histo = makeHisto(100, 0.3d, 0.4d);
        final ExpectedSplitInfo expected = updateHisto(histo, 0.3d, 0.4d, 0.1d);
        expected.checkSplit(DTree.findBestSplitPoint(histo, 0, 100.0d, 0, 0.3d, 0.4d, true));
    }

    /**
     * Creates a histogram over [0, 10) carrying a gaussian {@link Constraints} with bounds [min, max].
     *
     * @param nbins passed as the second DHistogram constructor argument (presumably the bin count — TODO confirm)
     * @param min   expected {@code _min} of the resulting constraints
     * @param max   expected {@code _max} of the resulting constraints
     * @return the (not yet initialized) histogram
     */
    private static DHistogram makeHisto(int nbins, double min, double max) {
        final Constraints constraints = new Constraints(new int[]{1}, DistributionFactory.getDistribution(DistributionFamily.gaussian), true)
                .withNewConstraint(0, 1, min)
                .withNewConstraint(0, 0, max);
        // Guard the assumption the expected values below rely on.
        Assert.assertEquals(min, constraints._min, 0.0d);
        Assert.assertEquals(max, constraints._max, 0.0d);
        return new DHistogram("test_hs", nbins, 2, (byte) 0, 0.0d, 10.0d, false, 0.01d,
                SharedTreeModel.SharedTreeParameters.HistogramType.UniformAdaptive, 123L, (Key) null, constraints);
    }

    /**
     * Initializes {@code histo} and fills it with {@link #N_ROWS} unit-weight rows.
     *
     * <p>Row {@code r} has value {@code r / 100.0}, or NaN for the first
     * {@code (int) (N_LOW_ROWS * naFraction)} rows. Its response is {@link #LOW_RESPONSE} for
     * the first {@link #N_LOW_ROWS} rows and {@link #HIGH_RESPONSE} otherwise.
     *
     * @param histo      histogram to initialize and fill
     * @param min        lower prediction bound; low-response rows are scored against it
     * @param max        upper prediction bound; high-response rows are scored against it
     * @param naFraction fraction of the low-response rows that are NA
     * @return the split the test expects: NAvsREST when {@code naFraction == 1.0}, otherwise NALeft
     */
    private static ExpectedSplitInfo updateHisto(DHistogram histo, double min, double max, double naFraction) {
        histo.init();
        final int naCount = (int) (N_LOW_ROWS * naFraction);
        final double[] weights = new double[N_ROWS];
        final double[] values = new double[N_ROWS];
        final double[] responses = new double[N_ROWS];
        final int[] rows = new int[N_ROWS];
        final double mean = ((N_LOW_ROWS * LOW_RESPONSE) + ((N_ROWS - N_LOW_ROWS) * HIGH_RESPONSE)) / N_ROWS;
        double seTotal = 0.0d; // all rows, predicted at the global mean
        double seLow = 0.0d;   // low-response rows, predicted at min
        double seHigh = 0.0d;  // high-response rows, predicted at max
        for (int row = 0; row < N_ROWS; row++) {
            weights[row] = 1.0d;
            rows[row] = row;
            values[row] = row < naCount ? Double.NaN : row / 100.0d;
            responses[row] = row < N_LOW_ROWS ? LOW_RESPONSE : HIGH_RESPONSE;
            seTotal += (mean - responses[row]) * (mean - responses[row]);
            if (row < N_LOW_ROWS) {
                seLow += (min - responses[row]) * (min - responses[row]);
            } else {
                seHigh += (max - responses[row]) * (max - responses[row]);
            }
        }
        histo.updateHisto(weights, new double[N_ROWS], values, responses, (double[]) null, rows, N_ROWS, 0);
        if (naFraction == 1.0d) {
            // NA rows are exactly the low-response rows: se0 covers the non-NA (high) rows, se1 the NA rows.
            return new ExpectedSplitInfo(DHistogram.NASplitDir.NAvsREST, seTotal, seHigh, seLow);
        }
        // NAs are a subset of the low-response rows and are expected on the left, alongside them.
        return new ExpectedSplitInfo(DHistogram.NASplitDir.NALeft, seNonNA(weights, values, responses), seLow, seHigh);
    }

    /**
     * Computes the weighted sum of squared errors about the weighted mean, over rows whose value is not NaN.
     *
     * @param weights   per-row weights
     * @param values    per-row column values; NaN rows are skipped
     * @param responses per-row responses
     * @return {@code sum(w * (y - mean)^2)} over non-NA rows
     */
    private static double seNonNA(double[] weights, double[] values, double[] responses) {
        double weightedSum = 0.0d;
        double weightTotal = 0.0d;
        for (int row = 0; row < responses.length; row++) {
            if (!Double.isNaN(values[row])) {
                weightTotal += weights[row];
                weightedSum += weights[row] * responses[row];
            }
        }
        final double mean = weightedSum / weightTotal;
        double se = 0.0d;
        for (int row = 0; row < responses.length; row++) {
            if (!Double.isNaN(values[row])) {
                se += weights[row] * (responses[row] - mean) * (responses[row] - mean);
            }
        }
        return se;
    }
}
