/*
 * Decompiled with CFR 0.152.
 */
package hex.tree;

import hex.genmodel.algos.tree.SharedTreeNode;
import hex.genmodel.algos.tree.SharedTreeSubgraph;
import hex.genmodel.utils.DistributionFamily;
import hex.glm.GLMModel;
import hex.schemas.TreeV3;
import hex.tree.CompressedTree;
import hex.tree.SharedTreeModel;
import hex.tree.TreeHandler;
import hex.tree.gbm.GBM;
import hex.tree.gbm.GBMModel;
import hex.tree.isofor.IsolationForest;
import hex.tree.isofor.IsolationForestModel;
import java.util.ArrayDeque;
import java.util.regex.Pattern;
import org.apache.commons.lang.ArrayUtils;
import org.junit.Assert;
import org.junit.BeforeClass;
import org.junit.Test;
import water.DKV;
import water.Key;
import water.Keyed;
import water.Scope;
import water.TestUtil;
import water.api.schemas3.KeyV3;
import water.fvec.Frame;

/**
 * Tests for {@link TreeHandler}.
 *
 * <p>Covered areas:
 * <ul>
 *   <li>conversion of a {@link SharedTreeSubgraph} into {@link TreeHandler.TreeProperties};</li>
 *   <li>argument validation performed by {@link TreeHandler#getTree};</li>
 *   <li>consistency of the NA-direction information on tree nodes with the NA descriptions
 *       produced by the conversion.</li>
 * </ul>
 *
 * <p>Every test that calls {@code Scope.enter()} balances it with {@code Scope.exit()} in a
 * {@code finally} block.
 */
public class TreeHandlerTest
extends TestUtil {

    /**
     * First argument passed to {@link TreeHandler#getTree} by every call in this class.
     * NOTE(review): presumably the REST API version; confirm against TreeHandler.
     */
    private static final int GET_TREE_VERSION = 3;

    /** The root node's description must mention the column the root splits on. */
    private static final Pattern ROOT_SPLIT_DESCRIPTION = Pattern.compile(".* and splits on column '.+'.*");

    @BeforeClass
    public static void stall() {
        stall_till_cloudsize(1);
    }

    /**
     * Checks the conversion of the first tree of a one-tree GBM (airlines data, response
     * {@code Dest}) into {@link TreeHandler.TreeProperties}.
     *
     * <p>It checks that:
     * <ul>
     *   <li>every per-node array has one entry per subgraph node;</li>
     *   <li>the root description names a split column;</li>
     *   <li>the left/right child indices agree with a breadth-first walk of the subgraph.</li>
     * </ul>
     */
    @Test
    public void testSharedTreeSubgraphConversion() {
        Frame tfr = null;
        GBMModel model = null;
        Scope.enter();
        try {
            tfr = parse_test_file("./smalldata/airlines/allyears2k_headers.zip");
            DKV.put(tfr);
            GBMModel.GBMParameters parms = singleTreeGbmParams(tfr, "Dest");
            model = new GBM(parms).trainModel().get();
            SharedTreeModel.SharedTreeOutput sharedTreeOutput = model._output;
            assertTreeCount(parms._ntrees, sharedTreeOutput);

            SharedTreeSubgraph sharedTreeSubgraph = firstTreeSubgraph(sharedTreeOutput);
            Assert.assertNotNull(sharedTreeSubgraph);
            TreeHandler.TreeProperties treeProperties = TreeHandler.convertSharedTreeSubgraph(sharedTreeSubgraph);
            Assert.assertNotNull(treeProperties);

            int nodeCount = sharedTreeSubgraph.nodesArray.size();
            Assert.assertEquals(nodeCount, treeProperties._descriptions.length);
            Assert.assertEquals(nodeCount, treeProperties._thresholds.length);
            Assert.assertEquals(nodeCount, treeProperties._features.length);
            Assert.assertEquals(nodeCount, treeProperties._nas.length);
            Assert.assertTrue(treeProperties._descriptions[0],
                    ROOT_SPLIT_DESCRIPTION.matcher(treeProperties._descriptions[0]).matches());

            int[] leftChildren = treeProperties._leftChildren;
            int[] rightChildren = treeProperties._rightChildren;
            Assert.assertEquals(leftChildren.length, rightChildren.length);

            // Walk the subgraph breadth-first (FIFO queue, left child before right child).
            // The i-th dequeued node must correspond to entry i of the child-index arrays,
            // where -1 marks a missing child.
            ArrayDeque<SharedTreeNode> pending = new ArrayDeque<>();
            pending.add(sharedTreeSubgraph.rootNode);
            int nonRootNodesFound = 0;
            for (int i = 0; i < leftChildren.length; i++) {
                SharedTreeNode node = pending.poll();
                Assert.assertNotNull("Child-index arrays list more nodes than are reachable from the root", node);
                if (leftChildren[i] != -1) {
                    SharedTreeNode leftChild = node.getLeftChild();
                    Assert.assertNotNull(leftChild);
                    Assert.assertEquals(leftChildren[i], leftChild.getNodeNumber());
                    pending.add(leftChild);
                    nonRootNodesFound++;
                }
                if (rightChildren[i] != -1) {
                    SharedTreeNode rightChild = node.getRightChild();
                    Assert.assertNotNull(rightChild);
                    Assert.assertEquals(rightChildren[i], rightChild.getNodeNumber());
                    pending.add(rightChild);
                    nonRootNodesFound++;
                }
            }
            // Every node except the root is discovered exactly once, as somebody's child.
            Assert.assertEquals(nodeCount, nonRootNodesFound + 1);
        }
        finally {
            Scope.exit();
            if (tfr != null) {
                tfr.remove();
            }
            if (model != null) {
                model.remove();
            }
        }
    }

    /**
     * Converts the first tree of a one-tree GBM trained on iris (response {@code response}) and
     * checks that every node gets a non-empty description.
     */
    @Test
    public void testSharedTreeSubgraphConversion_inclusiveLevelsIris() {
        Frame tfr = null;
        GBMModel model = null;
        Scope.enter();
        try {
            tfr = parse_test_file("./smalldata/iris/iris2.csv");
            DKV.put(tfr);
            GBMModel.GBMParameters parms = singleTreeGbmParams(tfr, "response");
            model = new GBM(parms).trainModel().get();
            SharedTreeModel.SharedTreeOutput sharedTreeOutput = model._output;
            assertTreeCount(parms._ntrees, sharedTreeOutput);

            SharedTreeSubgraph sharedTreeSubgraph = firstTreeSubgraph(sharedTreeOutput);
            Assert.assertNotNull(sharedTreeSubgraph);
            TreeHandler.TreeProperties treeProperties = TreeHandler.convertSharedTreeSubgraph(sharedTreeSubgraph);
            Assert.assertNotNull(treeProperties);

            String[] nodeDescriptions = treeProperties._descriptions;
            Assert.assertEquals(sharedTreeSubgraph.nodesArray.size(), nodeDescriptions.length);
            for (String nodeDescription : nodeDescriptions) {
                Assert.assertFalse(nodeDescription.isEmpty());
            }
        }
        finally {
            if (tfr != null) {
                tfr.remove();
            }
            if (model != null) {
                model.remove();
            }
            // Balances the Scope.enter() above so the scope does not leak into later tests.
            Scope.exit();
        }
    }

    /**
     * Exercises the argument validation of {@link TreeHandler#getTree} with a multinomial GBM.
     *
     * <p>The invalid arguments tried are:
     * <ul>
     *   <li>a model key that is not in the DKV;</li>
     *   <li>a non-tree (GLM) model;</li>
     *   <li>an out-of-range tree index;</li>
     *   <li>an unknown tree class;</li>
     *   <li>a negative tree number.</li>
     * </ul>
     */
    @Test
    public void testSharedTreeSubgraphConversion_argumentValidationMultinomial() {
        Frame tfr = null;
        GBMModel model = null;
        GLMModel nonTreeBasedModel = null;
        Scope.enter();
        try {
            tfr = parse_test_file("./smalldata/iris/iris2.csv");
            DKV.put(tfr);
            GBMModel.GBMParameters parms = singleTreeGbmParams(tfr, "response");
            model = new GBM(parms).trainModel().get();
            TreeHandler treeHandler = new TreeHandler();
            TreeV3 args = new TreeV3();

            // A freshly made key refers to nothing in the DKV.
            args.model = new KeyV3.ModelKeyV3(Key.make());
            assertGetTreeFailsWith(treeHandler, args, "Given model does not exist");

            nonTreeBasedModel = new GLMModel(Key.make(), new GLMModel.GLMParameters(GLMModel.GLMParameters.Family.binomial), null, null, 1.0, 1.0, 1L);
            DKV.put(nonTreeBasedModel);
            args.model = new KeyV3.ModelKeyV3(nonTreeBasedModel._key);
            assertGetTreeFailsWith(treeHandler, args, "Given model is not tree-based.");

            // The model has a single tree, so index 1 is out of range.
            args.tree_number = 1;
            args.tree_class = tfr.vec(parms._response_column).domain()[0];
            args.model = new KeyV3.ModelKeyV3(model._key);
            IllegalArgumentException outOfRange = expectGetTreeFailure(treeHandler, args);
            Assert.assertEquals("Invalid tree index: 1. Tree index must be in range [0, 0].", outOfRange.getMessage());

            args.tree_number = 0;
            args.tree_class = "NonExistingCategoricalLevel";
            assertGetTreeFailsWith(treeHandler, args,
                    "There is no such tree class. Given categorical level does not exist in response column: NonExistingCategoricalLevel");

            args.tree_number = -1;
            assertGetTreeFailsWith(treeHandler, args,
                    "Invalid tree number: " + args.tree_number + ". Tree number must be >= 0.");
        }
        finally {
            Scope.exit();
            if (tfr != null) {
                tfr.remove();
            }
            if (model != null) {
                model.remove();
            }
            if (nonTreeBasedModel != null) {
                nonTreeBasedModel.remove();
            }
        }
    }

    /** Checks that asking a regression GBM for a tree class is rejected. */
    @Test
    public void testSharedTreeSubgraphConversion_argumentValidationRegression() {
        Frame tfr = null;
        GBMModel regressionModel = null;
        Scope.enter();
        try {
            tfr = parse_test_file("./smalldata/iris/iris2.csv");
            DKV.put(tfr);
            GBMModel.GBMParameters parms = singleTreeGbmParams(tfr, "Sepal.Length");
            regressionModel = new GBM(parms).trainModel().get();

            TreeV3 args = new TreeV3();
            args.model = new KeyV3.ModelKeyV3(regressionModel._key);
            args.tree_class = "NonExistingClass";
            assertGetTreeFailsWith(new TreeHandler(), args, "There are no tree classes for Regression.");
        }
        finally {
            if (tfr != null) {
                tfr.remove();
            }
            if (regressionModel != null) {
                regressionModel.remove();
            }
            Scope.exit();
        }
    }

    /**
     * Checks tree-class handling for a binomial GBM on {@code IsDepDelayed}. Class "YES" is
     * rejected with a message naming "NO"; both "NO" and an empty class return a tree.
     */
    @Test
    public void testSharedTreeSubgraphConversion_argumentValidationBinomial() {
        Frame tfr = null;
        GBMModel model = null;
        Scope.enter();
        try {
            tfr = parse_test_file("./smalldata/testng/airlines_train.csv");
            DKV.put(tfr);
            GBMModel.GBMParameters parms = singleTreeGbmParams(tfr, "IsDepDelayed");
            model = new GBM(parms).trainModel().get();

            TreeHandler treeHandler = new TreeHandler();
            TreeV3 args = new TreeV3();
            args.model = new KeyV3.ModelKeyV3(model._key);

            args.tree_class = "YES";
            assertGetTreeFailsWith(treeHandler, args,
                    "For binomial, only one tree class has been built per each iteration: NO");

            args.tree_class = "NO";
            Assert.assertNotNull(treeHandler.getTree(GET_TREE_VERSION, args));

            args.tree_class = "";
            Assert.assertNotNull(treeHandler.getTree(GET_TREE_VERSION, args));
        }
        finally {
            if (tfr != null) {
                tfr.remove();
            }
            if (model != null) {
                model.remove();
            }
            Scope.exit();
        }
    }

    /**
     * Compares two counts on a binomial GBM trained on the airlines data. The first is the number
     * of splits that send NAs to exactly one child, found by walking the tree. The second is the
     * number of non-null NA descriptions produced by the conversion. They must be equal.
     */
    @Test
    public void testNaHandling_airlines_train() {
        Frame tfr = null;
        GBMModel model = null;
        Scope.enter();
        try {
            tfr = parse_test_file("./smalldata/testng/airlines_train.csv");
            DKV.put(tfr);
            GBMModel.GBMParameters parms = singleTreeGbmParams(tfr, "IsDepDelayed");
            model = new GBM(parms).trainModel().get();

            SharedTreeSubgraph sharedTreeSubgraph = model.getSharedTreeSubgraph(0, 0);
            TreeHandler.TreeProperties treeProperties = TreeHandler.convertSharedTreeSubgraph(sharedTreeSubgraph);
            Assert.assertNotNull(treeProperties);

            int naSplits = checkNaPath(sharedTreeSubgraph.rootNode, new String[0]);
            Assert.assertEquals(naSplits, countNonNull(treeProperties._nas));
        }
        finally {
            if (tfr != null) {
                tfr.remove();
            }
            if (model != null) {
                model.remove();
            }
            Scope.exit();
        }
    }

    /**
     * Performs the same NA-split consistency check as {@link #testNaHandling_airlines_train()},
     * using a regression GBM on the cars data (response {@code power}) with several predictors
     * ignored.
     */
    @Test
    public void testNaHandling_cars() {
        Frame tfr = null;
        GBMModel model = null;
        Scope.enter();
        try {
            tfr = parse_test_file("./smalldata/junit/cars_nice_header.csv");
            DKV.put(tfr);
            GBMModel.GBMParameters parms = singleTreeGbmParams(tfr, "power");
            parms._ignored_columns = new String[]{"name", "economy", "displacement", "weight", "acceleration", "year"};
            model = new GBM(parms).trainModel().get();
            assertTreeCount(parms._ntrees, model._output);

            SharedTreeSubgraph sharedTreeSubgraph = model.getSharedTreeSubgraph(0, 0);
            Assert.assertNotNull(sharedTreeSubgraph);
            TreeHandler.TreeProperties treeProperties = TreeHandler.convertSharedTreeSubgraph(sharedTreeSubgraph);
            Assert.assertNotNull(treeProperties);

            int naSplits = checkNaPath(sharedTreeSubgraph.rootNode, new String[0]);
            Assert.assertEquals(naSplits, countNonNull(treeProperties._nas));
        }
        finally {
            if (tfr != null) {
                tfr.remove();
            }
            if (model != null) {
                model.remove();
            }
            // Balances the Scope.enter() above so the scope does not leak into later tests.
            Scope.exit();
        }
    }

    /**
     * Recursively verifies the NA-direction invariants of the subtree below {@code parentNode}.
     *
     * <p>The invariants are:
     * <ul>
     *   <li>no node includes NAs for both children;</li>
     *   <li>a node splitting on a column listed in {@code previousNaSplits} includes NAs for
     *       neither child;</li>
     *   <li>every other internal node includes NAs for exactly one child.</li>
     * </ul>
     *
     * @param parentNode       root of the subtree to check
     * @param previousNaSplits columns whose NAs an ancestor already routed to the other branch,
     *                         so no NA of that column can reach this subtree
     * @return number of internal nodes in the subtree that include NAs for exactly one child
     */
    private static int checkNaPath(SharedTreeNode parentNode, String[] previousNaSplits) {
        SharedTreeNode leftChild = parentNode.getLeftChild();
        SharedTreeNode rightChild = parentNode.getRightChild();
        if (leftChild == null && rightChild == null) {
            return 0;
        }
        boolean isRightInclusive = rightChild != null && rightChild.isInclusiveNa();
        boolean isLeftInclusive = leftChild != null && leftChild.isInclusiveNa();
        if (isLeftInclusive && isRightInclusive) {
            Assert.fail("Parent node " + parentNode.getNodeNumber() + " includes NAs for both children");
        }
        boolean splitsOnExcluded = ArrayUtils.contains(previousNaSplits, parentNode.getColName());
        if (splitsOnExcluded && (isRightInclusive || isLeftInclusive)) {
            Assert.fail("Parent node " + parentNode.getNodeNumber() + " includes NAs for column " + parentNode.getColName());
        } else if (!splitsOnExcluded && !(isLeftInclusive ^ isRightInclusive)) {
            Assert.fail("Parent node " + parentNode.getNodeNumber() + " should set NA direction to one of its children");
        }
        // NAs sent one way can no longer reach the opposite branch, so that branch inherits this
        // node's column as "already routed".
        String[] leftPreviousSplits = isRightInclusive
                ? (String[]) ArrayUtils.add(previousNaSplits, parentNode.getColName())
                : previousNaSplits;
        String[] rightPreviousSplits = isLeftInclusive
                ? (String[]) ArrayUtils.add(previousNaSplits, parentNode.getColName())
                : previousNaSplits;
        int naSplits = (isRightInclusive ^ isLeftInclusive) ? 1 : 0;
        if (leftChild != null) {
            naSplits += checkNaPath(leftChild, leftPreviousSplits);
        }
        if (rightChild != null) {
            naSplits += checkNaPath(rightChild, rightPreviousSplits);
        }
        return naSplits;
    }

    /**
     * Trains an Isolation Forest on the airlines data. It then checks that
     * {@link TreeHandler#getTree} returns a tree for every tree number in {@code [0, _ntrees)}.
     */
    @Test
    public void testEmptyInheritedCategoricalLevels() {
        Scope.enter();
        try {
            Frame trainingFrame = parse_test_file("./smalldata/testng/airlines_train.csv");
            Scope.track_generic(trainingFrame);
            IsolationForestModel.IsolationForestParameters parms = new IsolationForestModel.IsolationForestParameters();
            parms._train = trainingFrame._key;
            parms._distribution = DistributionFamily.AUTO;
            parms._response_column = "IsDepDelayed";
            parms._ntrees = 10;
            parms._max_depth = 10;
            parms._seed = 65261L;
            IsolationForestModel model = new IsolationForest(parms).trainModel().get();
            Scope.track_generic(model);

            TreeHandler treeHandler = new TreeHandler();
            TreeV3 arguments = new TreeV3();
            arguments.model = new KeyV3.ModelKeyV3(model._key);
            for (int i = 0; i < parms._ntrees; i++) {
                arguments.tree_number = i;
                Assert.assertNotNull(treeHandler.getTree(GET_TREE_VERSION, arguments));
            }
        }
        finally {
            Scope.exit();
        }
    }

    /**
     * Builds parameters for a single-tree GBM with seed 0.
     *
     * @param train          frame to train on
     * @param responseColumn name of the response column in {@code train}
     * @return the parameters; callers may adjust them further before training
     */
    private static GBMModel.GBMParameters singleTreeGbmParams(Frame train, String responseColumn) {
        GBMModel.GBMParameters parms = new GBMModel.GBMParameters();
        parms._train = train._key;
        parms._response_column = responseColumn;
        parms._ntrees = 1;
        parms._seed = 0L;
        return parms;
    }

    /**
     * Asserts that the model output reports {@code expectedTrees} trees and holds that many rows
     * of tree keys and auxiliary tree keys.
     */
    private static void assertTreeCount(int expectedTrees, SharedTreeModel.SharedTreeOutput output) {
        Assert.assertEquals(expectedTrees, output._ntrees);
        Assert.assertEquals(expectedTrees, output._treeKeys.length);
        Assert.assertEquals(expectedTrees, output._treeKeysAux.length);
    }

    /**
     * Builds the {@link SharedTreeSubgraph} of tree 0, class 0 from the compressed tree and its
     * auxiliary data, using the model's column names and domains.
     */
    private static SharedTreeSubgraph firstTreeSubgraph(SharedTreeModel.SharedTreeOutput output) {
        CompressedTree auxCompressedTree = output._treeKeysAux[0][0].get();
        CompressedTree compressedTree = output._treeKeys[0][0].get();
        return compressedTree.toSharedTreeSubgraph(auxCompressedTree, output._names, output._domains);
    }

    /**
     * Calls {@link TreeHandler#getTree} and expects it to throw {@link IllegalArgumentException}.
     *
     * @return the thrown exception, for further checks on its message
     * @throws AssertionError if {@code getTree} returns normally
     */
    private static IllegalArgumentException expectGetTreeFailure(TreeHandler treeHandler, TreeV3 args) {
        try {
            treeHandler.getTree(GET_TREE_VERSION, args);
        }
        catch (IllegalArgumentException e) {
            return e;
        }
        throw new AssertionError("Expected TreeHandler.getTree to throw IllegalArgumentException");
    }

    /**
     * Asserts that {@link TreeHandler#getTree} rejects {@code args} with an
     * {@link IllegalArgumentException} whose message contains {@code expectedFragment}.
     */
    private static void assertGetTreeFailsWith(TreeHandler treeHandler, TreeV3 args, String expectedFragment) {
        IllegalArgumentException e = expectGetTreeFailure(treeHandler, args);
        Assert.assertTrue("Unexpected message: " + e.getMessage(), e.getMessage().contains(expectedFragment));
    }

    /** Returns how many elements of {@code values} are non-null. */
    private static int countNonNull(String[] values) {
        int count = 0;
        for (String value : values) {
            if (value != null) {
                count++;
            }
        }
        return count;
    }
}

