package hex.tree.xgboost;

import biz.k11i.xgboost.Predictor;
import biz.k11i.xgboost.tree.RegTree;
import hex.genmodel.algos.tree.SharedTreeGraph;
import hex.genmodel.algos.tree.SharedTreeNode;
import hex.genmodel.algos.tree.SharedTreeSubgraph;
import hex.schemas.TreeV3;
import hex.tree.TreeHandler;
import hex.tree.xgboost.XGBoostModel;
import java.io.ByteArrayInputStream;
import org.junit.Assert;
import org.junit.Assume;
import org.junit.BeforeClass;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.ExpectedException;
import water.DKV;
import water.ExtensionManager;
import water.Key;
import water.Keyed;
import water.Scope;
import water.TestUtil;
import water.api.schemas3.KeyV3;
import water.fvec.Frame;

/* loaded from: input_file:hex/tree/xgboost/XGBoostTreeConverterTest.class */
/**
 * Tests conversion of trained XGBoost boosters into {@link SharedTreeGraph} form, both directly via
 * {@link XGBoostModel#convert} and through {@link TreeHandler#getTree}.
 *
 * <p>Each test cleans up in a {@code finally} block so cleanup runs exactly once, whether the test
 * body succeeds or throws.
 */
public class XGBoostTreeConverterTest extends TestUtil {

    private static final String AIRLINES_TRAIN = "./smalldata/testng/airlines_train.csv";

    @Rule
    public ExpectedException expectedException = ExpectedException.none();

    /** Waits for a one-node cloud and skips the whole class if the XGBoost extension is not enabled. */
    @BeforeClass
    public static void stall() {
        stall_till_cloudsize(1);
        Assume.assumeTrue("XGBoost was not loaded!\nH2O XGBoost needs binary compatible environment; Make sure that you have correct libraries installed and correctly configured LD_LIBRARY_PATH, especially make sure that CUDA libraries are available if you are running on GPU!", ExtensionManager.getInstance().isCoreExtensionsEnabled(XGBoostExtension.NAME));
    }

    @Test
    public void convertXGBoostTree_weather() throws Exception {
        Frame frame = null;
        XGBoostModel model = null;
        Scope.enter();
        try {
            frame = parse_test_file("./smalldata/junit/weather.csv");
            toCategorical(frame, "PressureChange");
            // The frame's vecs were replaced in place; re-publish it to the DKV.
            DKV.put(frame);

            XGBoostModel.XGBoostParameters parms = new XGBoostModel.XGBoostParameters();
            parms._ntrees = 1;
            parms._max_depth = 3;
            parms._train = frame._key;
            parms._response_column = "PressureChange";
            parms._reg_lambda = 0.0f;
            model = new XGBoost(parms).trainModel().get();

            assertConvertedTree(model, parms, "down");
        } finally {
            Scope.exit();
            if (frame != null) {
                frame.remove();
            }
            if (model != null) {
                model.delete();
            }
        }
    }

    @Test
    public void convertXGBoostTree_airlines() throws Exception {
        Frame frame = null;
        XGBoostModel model = null;
        Scope.enter();
        try {
            frame = parse_test_file(AIRLINES_TRAIN);
            toCategorical(frame, "IsDepDelayed");
            // The frame's vecs were replaced in place; re-publish it to the DKV.
            DKV.put(frame);

            XGBoostModel.XGBoostParameters parms = new XGBoostModel.XGBoostParameters();
            parms._ntrees = 1;
            parms._max_depth = 5;
            parms._ignored_columns = new String[]{"fYear", "fMonth", "fDayofMonth", "fDayOfWeek", "UniqueCarrier", "Dest"};
            parms._train = frame._key;
            parms._response_column = "IsDepDelayed";
            parms._reg_lambda = 0.0f;
            model = new XGBoost(parms).trainModel().get();

            assertConvertedTree(model, parms, "NO");
        } finally {
            Scope.exit();
            if (frame != null) {
                frame.remove();
            }
            if (model != null) {
                model.delete();
            }
        }
    }

    /** Converting a regression model's tree with an explicit tree class must be rejected. */
    @Test
    public void convertXGBoostTree_airlines_wrong_tree_class() throws Exception {
        expectedException.expect(IllegalArgumentException.class);
        expectedException.expectMessage("There should be no tree class specified for regression.");
        Frame frame = null;
        XGBoostModel model = null;
        Scope.enter();
        try {
            // Explicit per-column type codes (NOTE(review): presumably water.fvec.Vec.T_* constants; verify
            // against the CSV header) so that "Distance" can serve as a regression response.
            frame = parse_test_file(AIRLINES_TRAIN, "NA", 1, new byte[]{4, 2, 2, 2, 2, 4, 4, 4, 3});
            DKV.put(frame);

            XGBoostModel.XGBoostParameters parms = new XGBoostModel.XGBoostParameters();
            parms._ntrees = 1;
            parms._max_depth = 5;
            parms._ignored_columns = new String[]{"fYear", "fMonth", "fDayofMonth", "fDayOfWeek", "UniqueCarrier", "Dest"};
            parms._train = frame._key;
            parms._response_column = "Distance";
            model = new XGBoost(parms).trainModel().get();

            // Expected to throw from XGBoostModel.convert because a tree class is passed for regression.
            assertConvertedTree(model, parms, "NO");
        } finally {
            Scope.exit();
            if (frame != null) {
                frame.remove();
            }
            if (model != null) {
                model.delete();
            }
        }
    }

    @Test
    public void testXGBoostBinomialClass_noTreeClassSpecified() {
        Scope.enter();
        try {
            XGBoostModel model = trainAirlinesModel("IsDepDelayed");
            TreeHandler handler = new TreeHandler();
            Assert.assertNotNull(handler.getTree(3, treeRequest(model)));
        } finally {
            Scope.exit();
        }
    }

    @Test
    public void testXGBoostRegression_noTreeClassSpecified() {
        Scope.enter();
        try {
            XGBoostModel model = trainAirlinesModel("Distance");
            TreeHandler handler = new TreeHandler();
            Assert.assertNotNull(handler.getTree(3, treeRequest(model)));
        } finally {
            Scope.exit();
        }
    }

    @Test
    public void testXGBoostMultinomial_noTreeClassSpecified() {
        Scope.enter();
        try {
            XGBoostModel model = trainAirlinesModel("Origin");
            TreeHandler handler = new TreeHandler();
            TreeV3 request = treeRequest(model);
            // Arm the expectation only after setup so a setup-time IllegalArgumentException is not
            // mistaken for the error under test.
            expectedException.expect(IllegalArgumentException.class);
            expectedException.expectMessage("Model category 'Multinomial' requires tree class to be specified.");
            Assert.assertNotNull(handler.getTree(3, request));
        } finally {
            Scope.exit();
        }
    }

    /**
     * Replaces {@code column} of {@code frame} with a categorical copy. The replaced (original) Vec
     * is tracked in the current Scope so it is cleaned up on {@code Scope.exit()}.
     */
    private static void toCategorical(Frame frame, String column) {
        int idx = frame.find(column);
        Scope.track(frame.replace(idx, frame.vecs()[idx].toCategoricalVec()));
    }

    /**
     * Parses the model's raw booster bytes and asserts its first tree has nodes. Then converts tree 0
     * for {@code treeClass} and checks the resulting graph.
     *
     * @param model     trained model to inspect
     * @param parms     parameters the model was trained with (source of expected tree count/depth)
     * @param treeClass tree class passed to {@link XGBoostModel#convert}
     * @throws Exception if the booster bytes cannot be parsed, or whatever {@code convert} throws
     */
    private static void assertConvertedTree(XGBoostModel model, XGBoostModel.XGBoostParameters parms,
                                            String treeClass) throws Exception {
        RegTree firstTree = new Predictor(new ByteArrayInputStream(model.model_info()._boosterBytes))
                .getBooster().getGroupedTrees()[0][0];
        Assert.assertNotNull(firstTree.getNodes());

        SharedTreeGraph graph = model.convert(0, treeClass);
        Assert.assertNotNull(graph);
        Assert.assertEquals(parms._ntrees, graph.subgraphArray.size());

        SharedTreeSubgraph tree = (SharedTreeSubgraph) graph.subgraphArray.get(0);
        // Assumes nodes are stored level by level, so the last node sits at the deepest level.
        SharedTreeNode lastNode = (SharedTreeNode) tree.nodesArray.get(tree.nodesArray.size() - 1);
        Assert.assertEquals(parms._max_depth, lastNode.getDepth());
    }

    /**
     * Trains a single depth-5 tree on the airlines training set predicting {@code responseColumn}.
     * Both the parsed frame and the model are tracked in the current Scope.
     */
    private XGBoostModel trainAirlinesModel(String responseColumn) {
        Frame train = parse_test_file(AIRLINES_TRAIN);
        Scope.track(train);

        XGBoostModel.XGBoostParameters parms = new XGBoostModel.XGBoostParameters();
        parms._ntrees = 1;
        parms._max_depth = 5;
        parms._ignored_columns = new String[]{"fYear", "fMonth", "fDayofMonth", "fDayOfWeek", "UniqueCarrier"};
        parms._train = train._key;
        parms._response_column = responseColumn;
        parms._reg_lambda = 0.0f;

        XGBoostModel model = new XGBoost(parms).trainModel().get();
        Scope.track_generic(model);
        return model;
    }

    /** Builds a tree request for {@code model} that sets no tree number or tree class. */
    private static TreeV3 treeRequest(XGBoostModel model) {
        TreeV3 request = new TreeV3();
        request.model = new KeyV3.ModelKeyV3(model._key);
        return request;
    }
}
