package hex.tree.xgboost;

import hex.KeyValue;
import hex.SplitFrame;
import hex.tree.xgboost.XGBoostModel;
import junit.framework.TestCase;
import org.junit.Assert;
import org.junit.Assume;
import org.junit.Test;
import org.junit.runner.RunWith;
import water.H2O;
import water.Key;
import water.Scope;
import water.TestUtil;
import water.fvec.Frame;
import water.runner.CloudSize;
import water.runner.H2ORunner;

/**
 * Checks which XGBoost {@code tree_method} is accepted or selected depending on the H2O cloud size.
 *
 * <p>Every test guards itself with {@link Assume} on {@link H2O#getCloudSize()}:
 * {@link #shouldBuildExactOnSingleNode()} runs only on a one-node cloud, while the remaining tests
 * run only when the cloud has more than one node (otherwise they are reported as skipped).
 */
@CloudSize(1)
@RunWith(H2ORunner.class)
public class XGBoostMultiNodeTest extends TestUtil {

    /** Seed used for every model; numerically equal to {@code 0xCAFEBABE} read as a signed int. */
    private static final long SEED = -889275714L;

    /**
     * Parses the bank-marketing demo dataset and splits it 70/30.
     *
     * @return a two-element array {training frame, validation frame}; the parsed source frame and
     *     both splits are registered with the current {@link Scope}
     */
    private Frame[] loadData() {
        Frame full = Scope.track(parse_test_file("smalldata/demos/bank-additional-full.csv"));
        // The null destination-key array is cast so the constructor overload is unambiguous.
        SplitFrame splitter = new SplitFrame(full, new double[]{0.7d, 0.3d}, (Key[]) null);
        splitter.exec().get();
        Frame train = Scope.track((Frame) splitter._destination_frames[0].get());
        Frame valid = Scope.track((Frame) splitter._destination_frames[1].get());
        return new Frame[]{train, valid};
    }

    /**
     * Builds the baseline parameters shared by all tests: a freshly split dataset, response
     * column {@code "y"}, 100 trees of depth 3 and a fixed seed.
     *
     * @return new, mutable parameters that each test specializes further
     */
    private XGBoostModel.XGBoostParameters makeParms() {
        Frame[] data = loadData();
        XGBoostModel.XGBoostParameters parms = new XGBoostModel.XGBoostParameters();
        parms._train = data[0]._key;
        parms._valid = data[1]._key;
        parms._response_column = "y";
        parms._ntrees = 100;
        parms._max_depth = 3;
        parms._seed = SEED;
        return parms;
    }

    /**
     * Looks up a native XGBoost parameter reported by a trained model.
     *
     * @param model trained model whose {@code _output._native_parameters} table is searched
     * @param name  parameter name, matched against column 0 of the table
     * @return the column-1 value of the first row whose column 0 equals {@code name}
     * @throws AssertionError if no row carries {@code name}
     */
    private static Object nativeParameter(XGBoostModel model, String name) {
        for (int row = 0; row < model._output._native_parameters.getRowDim(); row++) {
            if (name.equals(model._output._native_parameters.get(row, 0))) {
                return model._output._native_parameters.get(row, 1);
            }
        }
        throw new AssertionError("Native parameter '" + name + "' not found in model output");
    }

    /**
     * Trains a model and asserts that training fails with a message containing {@code expected}.
     *
     * @param parms    parameters that are expected to be rejected
     * @param expected substring the failure message must contain
     * @throws AssertionError if training succeeds, or fails with a different (or null) message;
     *     in the latter case the original exception is attached as the cause
     */
    private static void assertTrainingFails(XGBoostModel.XGBoostParameters parms, String expected) {
        Exception failure = null;
        try {
            // Track the result so it is cleaned up even if training unexpectedly succeeds.
            Scope.track_generic(new XGBoost(parms).trainModel().get());
        } catch (Exception e) {
            failure = e;
        }
        Assert.assertNotNull("Expected exception, but none was thrown", failure);
        String message = failure.getMessage();
        if (message == null || !message.contains(expected)) {
            AssertionError error = new AssertionError("Unexpected exception: " + message);
            error.initCause(failure);
            throw error;
        }
    }

    @Test
    public void shouldBuildExactOnSingleNode() {
        Assume.assumeTrue(H2O.getCloudSize() == 1);
        Scope.enter();
        try {
            XGBoostModel.XGBoostParameters parms = makeParms();
            parms._tree_method = XGBoostModel.XGBoostParameters.TreeMethod.exact;
            parms._build_tree_one_node = true;
            XGBoostModel model = new XGBoost(parms).trainModel().get();
            Scope.track_generic(model);
            Assert.assertEquals("exact", nativeParameter(model, "tree_method"));
        } finally {
            // Single cleanup path: runs exactly once whether the body passes or throws.
            Scope.exit();
        }
    }

    @Test
    public void shouldFailWithExact() {
        Assume.assumeTrue(H2O.getCloudSize() > 1);
        Scope.enter();
        try {
            XGBoostModel.XGBoostParameters parms = makeParms();
            parms._tree_method = XGBoostModel.XGBoostParameters.TreeMethod.exact;
            assertTrainingFails(parms, "exact is not supported in distributed environment");
        } finally {
            Scope.exit();
        }
    }

    @Test
    public void shouldFailWithMonotoneApprox() {
        Assume.assumeTrue(H2O.getCloudSize() > 1);
        Scope.enter();
        try {
            XGBoostModel.XGBoostParameters parms = makeParms();
            parms._monotone_constraints = new KeyValue[]{
                    new KeyValue("duration", -1.0d), new KeyValue("age", -1.0d)};
            parms._tree_method = XGBoostModel.XGBoostParameters.TreeMethod.approx;
            assertTrainingFails(parms, "approx is not supported with _monotone_constraints");
        } finally {
            Scope.exit();
        }
    }

    @Test
    public void shouldUseHistWithMonotoneAuto() {
        Assume.assumeTrue(H2O.getCloudSize() > 1);
        Scope.enter();
        try {
            XGBoostModel.XGBoostParameters parms = makeParms();
            parms._monotone_constraints = new KeyValue[]{
                    new KeyValue("duration", -1.0d), new KeyValue("age", -1.0d)};
            parms._tree_method = XGBoostModel.XGBoostParameters.TreeMethod.auto;
            XGBoostModel model = new XGBoost(parms).trainModel().get();
            Scope.track_generic(model);
            Assert.assertEquals("hist", nativeParameter(model, "tree_method"));
        } finally {
            Scope.exit();
        }
    }
}
