package hex.tree;

import hex.genmodel.algos.tree.SharedTreeNode;
import hex.genmodel.algos.tree.SharedTreeSubgraph;
import hex.glm.GLM;
import hex.glm.GLMModel;
import hex.schemas.TreeV3;
import hex.tree.SharedTreeModel;
import hex.tree.TreeHandler;
import hex.tree.gbm.GBM;
import hex.tree.gbm.GBMModel;
import java.util.ArrayDeque;
import org.apache.commons.lang.ArrayUtils;
import org.junit.Assert;
import org.junit.BeforeClass;
import org.junit.Test;
import water.DKV;
import water.Key;
import water.Keyed;
import water.Scope;
import water.TestUtil;
import water.api.schemas3.KeyV3;
import water.fvec.Frame;

/* loaded from: input_file:hex/tree/TreeHandlerTest.class */
public class TreeHandlerTest extends TestUtil {
    /**
     * Class-level setup, run once before any test in this class: delegates to
     * {@code TestUtil.stall_till_cloudsize} with a requested size of 1.
     */
    @BeforeClass
    public static void stall() {
        TestUtil.stall_till_cloudsize(1);
    }

    /**
     * Trains a one-tree GBM on the airlines data set (response {@code Dest}), converts the first
     * tree into a {@link SharedTreeSubgraph} and checks the {@code TreeHandler.TreeProperties}
     * returned by {@code TreeHandler.convertSharedTreeSubgraph}:
     * <ul>
     *   <li>the description, threshold, feature and NA arrays have one entry per subgraph node;</li>
     *   <li>the left/right child arrays have equal length and, walked in first-in-first-out order
     *       from the root, agree with the node numbers of the subgraph's children;</li>
     *   <li>the number of child links found plus one (the root) equals the node count.</li>
     * </ul>
     * Cleanup runs exactly once in a {@code finally} block whether the test passes or fails.
     */
    @Test
    public void testSharedTreeSubgraphConversion() {
        Frame trainingFrame = null;
        GBMModel model = null;
        Scope.enter();
        try {
            trainingFrame = parse_test_file("./smalldata/airlines/allyears2k_headers.zip");
            DKV.put(trainingFrame);

            GBMModel.GBMParameters params = new GBMModel.GBMParameters();
            params._train = trainingFrame._key;
            params._response_column = "Dest";
            params._ntrees = 1;
            params._seed = 0L;
            model = (GBMModel) new GBM(params).trainModel().get();

            SharedTreeModel.SharedTreeOutput output = model._output;
            Assert.assertEquals(params._ntrees, output._ntrees);
            Assert.assertEquals(params._ntrees, output._treeKeys.length);
            Assert.assertEquals(params._ntrees, output._treeKeysAux.length);

            SharedTreeSubgraph subgraph = output._treeKeys[0][0].get()
                    .toSharedTreeSubgraph(output._treeKeysAux[0][0].get(), output._names, output._domains);
            Assert.assertNotNull(subgraph);

            TreeHandler.TreeProperties properties = TreeHandler.convertSharedTreeSubgraph(subgraph);
            Assert.assertNotNull(properties);
            final int nodeCount = subgraph.nodesArray.size();
            Assert.assertEquals(nodeCount, properties._descriptions.length);
            Assert.assertEquals(nodeCount, properties._thresholds.length);
            Assert.assertEquals(nodeCount, properties._features.length);
            Assert.assertEquals(nodeCount, properties._nas.length);

            int[] leftChildren = properties._leftChildren;
            int[] rightChildren = properties._rightChildren;
            Assert.assertEquals(leftChildren.length, rightChildren.length);

            // FIFO walk of the subgraph: nodes are visited in the order their parents enqueued them.
            ArrayDeque<SharedTreeNode> pending = new ArrayDeque<>();
            pending.add(subgraph.rootNode);
            int childLinks = 0;
            for (int idx = 0; idx < leftChildren.length; idx++) {
                SharedTreeNode current = pending.poll();
                // Fail with a message instead of an NPE when the arrays describe more nodes than the subgraph has.
                Assert.assertNotNull("No subgraph node left for array index " + idx, current);

                // -1 marks "no child" in the converted arrays.
                if (leftChildren[idx] != -1) {
                    SharedTreeNode left = current.getLeftChild();
                    Assert.assertNotNull("Missing left child for array index " + idx, left);
                    Assert.assertEquals(leftChildren[idx], left.getNodeNumber());
                    pending.add(left);
                    childLinks++;
                }
                if (rightChildren[idx] != -1) {
                    SharedTreeNode right = current.getRightChild();
                    Assert.assertNotNull("Missing right child for array index " + idx, right);
                    Assert.assertEquals(rightChildren[idx], right.getNodeNumber());
                    pending.add(right);
                    childLinks++;
                }
            }
            // Every node except the root is reached through exactly one child link.
            Assert.assertEquals(nodeCount, childLinks + 1);
        } finally {
            Scope.exit(new Key[0]);
            if (trainingFrame != null) {
                trainingFrame.remove();
            }
            if (model != null) {
                model.remove();
            }
        }
    }

    /**
     * Trains a one-tree GBM on {@code iris2.csv} (response {@code response}), converts the first
     * tree and checks that there is exactly one description per subgraph node and that no
     * description is empty.
     * <p>
     * The scope opened by {@code Scope.enter()} is closed in the {@code finally} block so the
     * enter/exit calls stay balanced, as in the other tests of this class.
     */
    @Test
    public void testSharedTreeSubgraphConversion_inclusiveLevelsIris() {
        Frame trainingFrame = null;
        GBMModel model = null;
        Scope.enter();
        try {
            trainingFrame = parse_test_file("./smalldata/iris/iris2.csv");
            DKV.put(trainingFrame);

            GBMModel.GBMParameters params = new GBMModel.GBMParameters();
            params._train = trainingFrame._key;
            params._response_column = "response";
            params._ntrees = 1;
            params._seed = 0L;
            model = (GBMModel) new GBM(params).trainModel().get();

            SharedTreeModel.SharedTreeOutput output = model._output;
            Assert.assertEquals(params._ntrees, output._ntrees);
            Assert.assertEquals(params._ntrees, output._treeKeys.length);
            Assert.assertEquals(params._ntrees, output._treeKeysAux.length);

            SharedTreeSubgraph subgraph = output._treeKeys[0][0].get()
                    .toSharedTreeSubgraph(output._treeKeysAux[0][0].get(), output._names, output._domains);
            Assert.assertNotNull(subgraph);

            TreeHandler.TreeProperties properties = TreeHandler.convertSharedTreeSubgraph(subgraph);
            Assert.assertNotNull(properties);

            String[] descriptions = properties._descriptions;
            Assert.assertEquals(subgraph.nodesArray.size(), descriptions.length);
            for (String description : descriptions) {
                Assert.assertFalse(description.isEmpty());
            }
        } finally {
            if (trainingFrame != null) {
                trainingFrame.remove();
            }
            if (model != null) {
                model.remove();
            }
            // Previously missing: balances the Scope.enter() above.
            Scope.exit(new Key[0]);
        }
    }

    /**
     * Checks the argument validation of {@code TreeHandler.getTree} against a multinomial GBM
     * trained on {@code iris2.csv}. Each of the following requests must be rejected with an
     * {@link IllegalArgumentException} carrying the expected message:
     * <ol>
     *   <li>a model key that was never stored;</li>
     *   <li>a key pointing to a GLM model (not tree-based);</li>
     *   <li>tree number 1 on a model with a single tree;</li>
     *   <li>a tree class that is not a level of the response column;</li>
     *   <li>a negative tree number.</li>
     * </ol>
     * Cleanup runs exactly once in a {@code finally} block whether the test passes or fails.
     */
    @Test
    public void testSharedTreeSubgraphConversion_argumentValidationMultinomial() {
        Frame trainingFrame = null;
        GBMModel model = null;
        GLMModel glmModel = null;
        Scope.enter();
        try {
            trainingFrame = parse_test_file("./smalldata/iris/iris2.csv");
            DKV.put(trainingFrame);

            GBMModel.GBMParameters params = new GBMModel.GBMParameters();
            params._train = trainingFrame._key;
            params._response_column = "response";
            params._ntrees = 1;
            params._seed = 0L;
            model = (GBMModel) new GBM(params).trainModel().get();

            TreeHandler handler = new TreeHandler();
            TreeV3 request = new TreeV3();

            // 1. Freshly made key: nothing is stored under it.
            request.model = new KeyV3.ModelKeyV3(Key.make());
            assertGetTreeRejected(handler, request, "Given model does not exist");

            // 2. Existing model, but a GLM rather than a tree model.
            glmModel = new GLMModel(Key.make(), new GLMModel.GLMParameters(GLMModel.GLMParameters.Family.binomial), (GLM) null, (double[]) null, 1.0d, 1.0d, 1L);
            DKV.put(glmModel);
            request.model = new KeyV3.ModelKeyV3(glmModel._key);
            assertGetTreeRejected(handler, request, "Given model is not tree-based.");

            // 3. Valid class, tree index beyond the single tree that was built.
            request.tree_number = 1;
            request.tree_class = "Iris-setosa";
            request.model = new KeyV3.ModelKeyV3(model._key);
            assertGetTreeRejected(handler, request, "There is no such tree number.");

            // 4. Valid tree index, unknown response level.
            request.tree_number = 0;
            request.tree_class = "NonExistingCategoricalLevel";
            assertGetTreeRejected(handler, request, "There is no such tree class. Given categorical level does not exist in response column: NonExistingCategoricalLevel");

            // 5. Negative tree index.
            request.tree_number = -1;
            assertGetTreeRejected(handler, request, "Tree number must be greater than 0.");
        } finally {
            Scope.exit(new Key[0]);
            if (trainingFrame != null) {
                trainingFrame.remove();
            }
            if (model != null) {
                model.remove();
            }
            if (glmModel != null) {
                glmModel.remove();
            }
        }
    }

    /**
     * Calls {@code handler.getTree(3, request)} and fails the test unless the call throws an
     * {@link IllegalArgumentException} whose message contains {@code expectedMessagePart}.
     *
     * @param handler             handler under test
     * @param request             request passed to {@code getTree}
     * @param expectedMessagePart text that must occur in the exception message
     */
    private static void assertGetTreeRejected(TreeHandler handler, TreeV3 request, String expectedMessagePart) {
        try {
            handler.getTree(3, request);
        } catch (IllegalArgumentException e) {
            String message = e.getMessage();
            Assert.assertTrue("Unexpected exception message: " + message,
                    message != null && message.contains(expectedMessagePart));
            return;
        }
        Assert.fail("Expected IllegalArgumentException containing: " + expectedMessagePart);
    }

    /**
     * Trains a one-tree regression GBM on {@code iris2.csv} (response {@code Sepal.Length}) and
     * checks that {@code TreeHandler.getTree} rejects a request naming a tree class with an
     * {@link IllegalArgumentException} containing
     * {@code "There are no tree classes for regression."}.
     * Cleanup runs exactly once in a {@code finally} block whether the test passes or fails.
     */
    @Test
    public void testSharedTreeSubgraphConversion_argumentValidationRegression() {
        Frame trainingFrame = null;
        GBMModel model = null;
        Scope.enter();
        try {
            trainingFrame = parse_test_file("./smalldata/iris/iris2.csv");
            DKV.put(trainingFrame);

            GBMModel.GBMParameters params = new GBMModel.GBMParameters();
            params._train = trainingFrame._key;
            params._ntrees = 1;
            params._seed = 0L;
            params._response_column = "Sepal.Length";

            TreeV3 request = new TreeV3();
            model = (GBMModel) new GBM(params).trainModel().get();
            request.model = new KeyV3.ModelKeyV3(model._key);
            request.tree_class = "NonExistingClass";

            try {
                new TreeHandler().getTree(3, request);
                Assert.fail("Expected IllegalArgumentException: a tree class was given for a regression model");
            } catch (IllegalArgumentException e) {
                Assert.assertTrue(e.getMessage().contains("There are no tree classes for regression."));
            }
        } finally {
            if (trainingFrame != null) {
                trainingFrame.remove();
            }
            if (model != null) {
                model.remove();
            }
            Scope.exit(new Key[0]);
        }
    }

    /**
     * Trains a one-tree binomial GBM on {@code airlines_train.csv} (response
     * {@code IsDepDelayed}) and checks how {@code TreeHandler.getTree} treats the tree class:
     * <ul>
     *   <li>{@code "YES"} is rejected with an {@link IllegalArgumentException} stating that only
     *       the {@code NO} class tree was built;</li>
     *   <li>{@code "NO"} and the empty string both yield a non-null result.</li>
     * </ul>
     * Cleanup runs exactly once in a {@code finally} block whether the test passes or fails.
     */
    @Test
    public void testSharedTreeSubgraphConversion_argumentValidationBinomial() {
        Frame trainingFrame = null;
        GBMModel model = null;
        Scope.enter();
        try {
            trainingFrame = parse_test_file("./smalldata/testng/airlines_train.csv");
            DKV.put(trainingFrame);

            GBMModel.GBMParameters params = new GBMModel.GBMParameters();
            params._train = trainingFrame._key;
            params._ntrees = 1;
            params._seed = 0L;
            params._response_column = "IsDepDelayed";

            TreeV3 request = new TreeV3();
            model = (GBMModel) new GBM(params).trainModel().get();
            request.model = new KeyV3.ModelKeyV3(model._key);
            TreeHandler handler = new TreeHandler();

            request.tree_class = "YES";
            try {
                handler.getTree(3, request);
                Assert.fail("Expected IllegalArgumentException for tree class YES on a binomial model");
            } catch (IllegalArgumentException e) {
                Assert.assertTrue(e.getMessage().contains("For binomial, only one tree class has been built per each iteration: NO"));
            }

            request.tree_class = "NO";
            Assert.assertNotNull(handler.getTree(3, request));

            request.tree_class = "";
            Assert.assertNotNull(handler.getTree(3, request));
        } finally {
            if (trainingFrame != null) {
                trainingFrame.remove();
            }
            if (model != null) {
                model.remove();
            }
            Scope.exit(new Key[0]);
        }
    }

    /**
     * Trains a one-tree GBM on {@code airlines_train.csv} (response {@code IsDepDelayed}) and
     * checks that the number of non-null entries in the converted {@code _nas} array equals the
     * number of NA-directing splits counted by {@link #checkNaPath} on the subgraph.
     * Cleanup runs exactly once in a {@code finally} block whether the test passes or fails.
     */
    @Test
    public void testNaHandling_airlines_train() {
        Frame trainingFrame = null;
        GBMModel model = null;
        Scope.enter();
        try {
            trainingFrame = parse_test_file("./smalldata/testng/airlines_train.csv");
            DKV.put(trainingFrame);

            GBMModel.GBMParameters params = new GBMModel.GBMParameters();
            params._train = trainingFrame._key;
            params._ntrees = 1;
            params._seed = 0L;
            params._response_column = "IsDepDelayed";
            model = (GBMModel) new GBM(params).trainModel().get();

            SharedTreeSubgraph subgraph = model.getSharedTreeSubgraph(0, 0);
            TreeHandler.TreeProperties properties = TreeHandler.convertSharedTreeSubgraph(subgraph);
            Assert.assertNotNull(properties);

            // Walk starts at the root with no columns excluded yet.
            int expectedNaSplits = checkNaPath(subgraph.rootNode, new String[0]);
            int actualNaEntries = 0;
            for (String na : properties._nas) {
                if (na != null) {
                    actualNaEntries++;
                }
            }
            Assert.assertEquals(expectedNaSplits, actualNaEntries);
        } finally {
            if (trainingFrame != null) {
                trainingFrame.remove();
            }
            if (model != null) {
                model.remove();
            }
            Scope.exit(new Key[0]);
        }
    }

    /**
     * Trains a one-tree GBM on {@code cars_nice_header.csv} (response {@code power}, with the
     * columns name, economy, displacement, weight, acceleration and year ignored) and checks that
     * the number of non-null entries in the converted {@code _nas} array equals the number of
     * NA-directing splits counted by {@link #checkNaPath} on the subgraph.
     * <p>
     * The scope opened by {@code Scope.enter()} is closed in the {@code finally} block so the
     * enter/exit calls stay balanced, as in the other tests of this class.
     */
    @Test
    public void testNaHandling_cars() {
        Frame trainingFrame = null;
        GBMModel model = null;
        Scope.enter();
        try {
            trainingFrame = parse_test_file("./smalldata/junit/cars_nice_header.csv");
            DKV.put(trainingFrame);

            GBMModel.GBMParameters params = new GBMModel.GBMParameters();
            params._train = trainingFrame._key;
            params._ignored_columns = new String[]{"name", "economy", "displacement", "weight", "acceleration", "year"};
            params._response_column = "power";
            params._ntrees = 1;
            params._seed = 0L;
            model = (GBMModel) new GBM(params).trainModel().get();

            SharedTreeModel.SharedTreeOutput output = model._output;
            Assert.assertEquals(params._ntrees, output._ntrees);
            Assert.assertEquals(params._ntrees, output._treeKeys.length);
            Assert.assertEquals(params._ntrees, output._treeKeysAux.length);

            SharedTreeSubgraph subgraph = model.getSharedTreeSubgraph(0, 0);
            Assert.assertNotNull(subgraph);
            TreeHandler.TreeProperties properties = TreeHandler.convertSharedTreeSubgraph(subgraph);
            Assert.assertNotNull(properties);

            // Walk starts at the root with no columns excluded yet.
            int expectedNaSplits = checkNaPath(subgraph.rootNode, new String[0]);
            int actualNaEntries = 0;
            for (String na : properties._nas) {
                if (na != null) {
                    actualNaEntries++;
                }
            }
            Assert.assertEquals(expectedNaSplits, actualNaEntries);
        } finally {
            if (trainingFrame != null) {
                trainingFrame.remove();
            }
            if (model != null) {
                model.remove();
            }
            // Previously missing: balances the Scope.enter() above.
            Scope.exit(new Key[0]);
        }
    }

    /**
     * Recursively validates NA routing below {@code sharedTreeNode} and counts the splits that
     * send NAs to exactly one child.
     * <p>
     * The test is failed when a node marks both children as NA-inclusive, when a node marks a
     * child as NA-inclusive for a column already listed in {@code strArr}, or when a node whose
     * column is not listed in {@code strArr} does not mark exactly one child as NA-inclusive.
     *
     * @param sharedTreeNode node to inspect; a node without children contributes 0
     * @param strArr         column names for which NAs were already routed away on the path
     *                       leading to this node
     * @return number of nodes in this subtree with exactly one NA-inclusive child
     */
    private int checkNaPath(SharedTreeNode sharedTreeNode, String[] strArr) {
        final SharedTreeNode left = sharedTreeNode.getLeftChild();
        final SharedTreeNode right = sharedTreeNode.getRightChild();
        if (left == null && right == null) {
            // Terminal node: no split, nothing to verify.
            return 0;
        }

        final boolean naGoesRight = right != null && right.isInclusiveNa();
        final boolean naGoesLeft = left != null && left.isInclusiveNa();
        if (naGoesLeft && naGoesRight) {
            Assert.fail("Parent node " + sharedTreeNode.getNodeNumber() + " includes NAs for both children");
        }

        final boolean exactlyOneSide = naGoesLeft != naGoesRight;
        final boolean alreadyExcluded = ArrayUtils.contains(strArr, sharedTreeNode.getColName());
        if (alreadyExcluded) {
            if (naGoesRight || naGoesLeft) {
                Assert.fail("Parent node " + sharedTreeNode.getNodeNumber() + " includes NAs for column " + sharedTreeNode.getColName());
            }
        } else if (!exactlyOneSide) {
            Assert.fail("Parent node " + sharedTreeNode.getNodeNumber() + " should set NA direction to one of its children");
        }

        // The child that does NOT receive NAs inherits the column as "excluded" for its subtree.
        String[] excludedBelowLeft = strArr;
        if (naGoesRight) {
            excludedBelowLeft = (String[]) ArrayUtils.add(strArr, sharedTreeNode.getColName());
        }
        String[] excludedBelowRight = strArr;
        if (naGoesLeft) {
            excludedBelowRight = (String[]) ArrayUtils.add(strArr, sharedTreeNode.getColName());
        }

        int naSplits = exactlyOneSide ? 1 : 0;
        if (left != null) {
            naSplits += checkNaPath(left, excludedBelowLeft);
        }
        if (right != null) {
            naSplits += checkNaPath(right, excludedBelowRight);
        }
        return naSplits;
    }
}
