/*
 * Decompiled with CFR 0.152.
 */
package hex.tree;

import hex.genmodel.utils.DistributionFamily;
import hex.tree.Constraints;
import hex.tree.DHistogram;
import hex.tree.DTree;
import hex.tree.SharedTreeModel;
import org.junit.Assert;
import org.junit.Test;

/**
 * Tests for {@link DTree#findBestSplitPoint}: handling of rows with a missing (NaN) predictor,
 * and split-prediction bounds taken from {@link Constraints}.
 */
public class DTreeTest {

    /**
     * Regression test for PUBDEV-6495.
     *
     * <p>Two of the four rows have a NaN predictor and a combined weight of 20. On the first
     * histogram a split is found when the third argument of {@code findBestSplitPoint} is 20 and
     * none is found at 21. NOTE(review): that argument is presumably a minimum row-weight
     * threshold; confirm against {@code DTree}.
     *
     * <p>The second histogram differs only in its 7th constructor argument (-9.0 instead of 0.01).
     * With it a split is found at 21. Its squared error must equal the SE of the non-NA rows and
     * be strictly smaller than the first split's SE.
     */
    @Test
    public void testFindBestSplitPoint_pubdev6495() {
        double[] ws = {10.0, 10.0, 100.0, 1.0};
        double[] cs = {Double.NaN, Double.NaN, 0.5, 1.5};
        double[] ys = {0.0, 1.0, 0.0, 1.0};
        int[] rows = {0, 1, 2, 3};

        DHistogram hs = new DHistogram("test_hs", 2, 2, 0, 0.0, 2.0, 0.01,
                SharedTreeModel.SharedTreeParameters.HistogramType.UniformAdaptive, 123L, null, null);
        hs.init();
        hs.updateHisto(ws, null, cs, ys, rows, rows.length, 0);
        DTree.Split s1 = DTree.findBestSplitPoint(hs, 0, 20.0, 0, Double.NaN, Double.NaN, false);
        Assert.assertNotNull(s1);
        DTree.Split s2 = DTree.findBestSplitPoint(hs, 0, 21.0, 0, Double.NaN, Double.NaN, false);
        Assert.assertNull(s2);

        DHistogram hsN = new DHistogram("test_hs", 2, 2, 0, 0.0, 2.0, -9.0,
                SharedTreeModel.SharedTreeParameters.HistogramType.UniformAdaptive, 123L, null, null);
        hsN.init();
        hsN.updateHisto(ws, null, cs, ys, rows, rows.length, 0);
        DTree.Split s3 = DTree.findBestSplitPoint(hsN, 0, 21.0, 0, Double.NaN, Double.NaN, false);
        Assert.assertNotNull(s3);
        // Expected value first, so a failure message reads correctly.
        Assert.assertEquals(seNonNA(ws, cs, ys), s3._se, 0.0);
        Assert.assertTrue(s3._se < s1._se);
    }

    /**
     * With a single bin and every p0 row's predictor missing, the best split must be
     * {@code NAvsREST}, with the squared errors computed by {@link #updateHisto}.
     */
    @Test
    public void testFindBestSplitPoint_bounds_NAvsREST() {
        double min_pred = 0.3;
        double max_pred = 0.4;
        DHistogram hs = makeHisto(1, min_pred, max_pred);
        ExpectedSplitInfo expectedSplitInfo = updateHisto(hs, min_pred, max_pred, 1.0);
        DTree.Split split = DTree.findBestSplitPoint(hs, 0, 0.0, 0, min_pred, max_pred, true);
        expectedSplitInfo.checkSplit(split);
    }

    /**
     * With 100 bins and 10% of the p0 rows' predictor missing, the best split must be
     * {@code NALeft}, with the squared errors computed by {@link #updateHisto}.
     */
    @Test
    public void testFindBestSplitPoint_bounds_NAs() {
        double min_pred = 0.3;
        double max_pred = 0.4;
        DHistogram hs = makeHisto(100, min_pred, max_pred);
        ExpectedSplitInfo expectedSplitInfo = updateHisto(hs, min_pred, max_pred, 0.1);
        DTree.Split split = DTree.findBestSplitPoint(hs, 0, 100.0, 0, min_pred, max_pred, true);
        expectedSplitInfo.checkSplit(split);
    }

    /**
     * Builds a histogram over predictor range [0, 10) whose {@link Constraints} carry the given
     * bounds.
     *
     * <p>The method asserts that the constraints report {@code _min == min_pred} and
     * {@code _max == max_pred}.
     *
     * @param nbins    number of bins
     * @param min_pred expected lower bound ({@code c._min})
     * @param max_pred expected upper bound ({@code c._max})
     * @return an un-initialized histogram; callers must call {@code init()} before use
     */
    private static DHistogram makeHisto(int nbins, double min_pred, double max_pred) {
        Constraints c = new Constraints(new int[]{1}, DistributionFamily.gaussian, true)
                .withNewConstraint(0, 1, min_pred)
                .withNewConstraint(0, 0, max_pred);
        Assert.assertEquals(min_pred, c._min, 0.0);
        Assert.assertEquals(max_pred, c._max, 0.0);
        return new DHistogram("test_hs", nbins, 2, 0, 0.0, 10.0, 0.01,
                SharedTreeModel.SharedTreeParameters.HistogramType.UniformAdaptive, 123L, null, c);
    }

    /**
     * Initializes {@code hs}, fills it with {@code N} unit-weight rows and returns the split the
     * tests expect.
     *
     * <p>Row layout:
     * <ul>
     *   <li>the first {@code S} rows have target {@code p0} and the remaining rows have target
     *       {@code p1};</li>
     *   <li>the first {@code (int) (S * na_percent)} rows have a NaN predictor;</li>
     *   <li>every other row {@code i} has predictor {@code i / 100}.</li>
     * </ul>
     *
     * @param hs         histogram to (re)initialize and fill
     * @param min_pred   prediction the p0 rows are scored against
     * @param max_pred   prediction the p1 rows are scored against
     * @param na_percent fraction of the p0 rows whose predictor is NaN; exactly 1.0 selects the
     *                   {@code NAvsREST} expectation, any other value {@code NALeft}
     * @return the expected split direction and squared errors
     */
    private static ExpectedSplitInfo updateHisto(DHistogram hs, double min_pred, double max_pred, double na_percent) {
        hs.init();
        final int N = 1000;                     // total rows
        final int S = 600;                      // rows with target p0; the other N - S have p1
        final int NA = (int) (S * na_percent);  // leading rows whose predictor is NaN
        double[] ws = new double[N];
        double[] cs = new double[N];
        double[] ys = new double[N];
        int[] rows = new int[N];
        final double p0 = 0.25;
        final double p1 = 0.45;
        final double ys_mean = (S * p0 + (N - S) * p1) / N;
        double se = 0.0;          // SE of all targets around their mean
        double se_min_pred = 0.0; // SE of the p0 rows when predicted as min_pred
        double se_max_pred = 0.0; // SE of the p1 rows when predicted as max_pred
        for (int i = 0; i < N; ++i) {
            ws[i] = 1.0;
            rows[i] = i;
            cs[i] = i < NA ? Double.NaN : i / 100.0;
            ys[i] = i < S ? p0 : p1;
            se += (ys_mean - ys[i]) * (ys_mean - ys[i]);
            if (i < S) {
                se_min_pred += (min_pred - ys[i]) * (min_pred - ys[i]);
            } else {
                se_max_pred += (max_pred - ys[i]) * (max_pred - ys[i]);
            }
        }
        hs.updateHisto(ws, null, cs, ys, rows, N, 0);
        if (na_percent == 1.0) {
            // NOTE(review): se0/se1 are deliberately swapped relative to the NALeft case.
            // Presumably the NA group (the p0 rows) ends up on the se1 side here; confirm in DTree.
            return new ExpectedSplitInfo(DHistogram.NASplitDir.NAvsREST, se, se_max_pred, se_min_pred);
        }
        double nna_se = seNonNA(ws, cs, ys);
        return new ExpectedSplitInfo(DHistogram.NASplitDir.NALeft, nna_se, se_min_pred, se_max_pred);
    }

    /**
     * Computes the weighted squared error of {@code ys} around their weighted mean, using only
     * rows whose predictor {@code cs[i]} is not NaN.
     *
     * <p>If every predictor is NaN, the mean is 0/0 = NaN and so is the result.
     */
    private static double seNonNA(double[] ws, double[] cs, double[] ys) {
        double nna_y_sum = 0.0;
        double nna_y_weight = 0.0;
        for (int i = 0; i < ys.length; ++i) {
            if (Double.isNaN(cs[i])) {
                continue;
            }
            nna_y_weight += ws[i];
            nna_y_sum += ws[i] * ys[i];
        }
        double nna_y_mean = nna_y_sum / nna_y_weight;
        double nna_se = 0.0;
        for (int i = 0; i < ys.length; ++i) {
            if (Double.isNaN(cs[i])) {
                continue;
            }
            nna_se += ws[i] * (ys[i] - nna_y_mean) * (ys[i] - nna_y_mean);
        }
        return nna_se;
    }

    /** Expected NA direction and squared errors of a {@link DTree.Split}. */
    private static class ExpectedSplitInfo {
        final DHistogram.NASplitDir _nasplit;
        final double _se;
        final double _se0;
        final double _se1;

        ExpectedSplitInfo(DHistogram.NASplitDir nasplit, double se, double se0, double se1) {
            this._nasplit = nasplit;
            this._se = se;
            this._se0 = se0;
            this._se1 = se1;
        }

        /**
         * Asserts that {@code split} is non-null, has the expected NA direction, and has
         * {@code _se}, {@code _se0} and {@code _se1} within 1e-8 of the expected values.
         */
        void checkSplit(DTree.Split split) {
            Assert.assertNotNull(split);
            Assert.assertEquals(this._nasplit, split._nasplit);
            Assert.assertEquals(this._se, split._se, 1.0E-8);
            Assert.assertEquals(this._se0, split._se0, 1.0E-8);
            Assert.assertEquals(this._se1, split._se1, 1.0E-8);
        }
    }
}

