/*
 * Decompiled with CFR 0.152.
 */
package hex.tree;

import hex.genmodel.algos.gbm.GbmMojoModel;
import hex.genmodel.algos.tree.SharedTreeGraph;
import hex.genmodel.algos.tree.SharedTreeMojoModel;
import hex.genmodel.algos.tree.SharedTreeNode;
import hex.genmodel.algos.tree.SharedTreeSubgraph;
import hex.tree.CompressedTree;
import hex.tree.gbm.GBM;
import hex.tree.gbm.GBMModel;
import java.io.IOException;
import org.junit.Assert;
import org.junit.BeforeClass;
import org.junit.Test;
import water.DKV;
import water.Iced;
import water.Key;
import water.Keyed;
import water.Scope;
import water.TestUtil;
import water.fvec.Frame;
import water.fvec.Vec;

/**
 * Tests for {@link CompressedTree}: conversion into a {@link SharedTreeSubgraph}, consistency of
 * leaf node ids with the decision paths produced by {@link SharedTreeMojoModel}, and decoding of
 * tree coordinates.
 */
public class CompressedTreeTest extends TestUtil {

    @BeforeClass
    public static void setup() {
        stall_till_cloudsize(1);
    }

    /**
     * Trains a small GBM and checks that each tree's {@link CompressedTree#toSharedTreeSubgraph}
     * output equals the corresponding subgraph computed from the model's MOJO.
     */
    @Test
    public void testToSharedTreeSubgraph() throws IOException {
        int ntrees = 5;
        // Enter the scope before the try so that finally never exits a scope that was not entered.
        Scope.enter();
        try {
            GBMModel model = trainGbm(ntrees);
            GBMModel.GBMOutput output = (GBMModel.GBMOutput) model._output;
            GbmMojoModel mojo = (GbmMojoModel) model.toMojo();
            SharedTreeGraph expectedGraph = mojo.computeGraph(-1);
            Assert.assertEquals(ntrees, expectedGraph.subgraphArray.size());
            for (int i = 0; i < ntrees; ++i) {
                // Only class index 0 is inspected for each tree.
                CompressedTree tree = (CompressedTree) output._treeKeys[i][0].get();
                Assert.assertNotNull(tree);
                CompressedTree auxTreeInfo = (CompressedTree) output._treeKeysAux[i][0].get();
                SharedTreeSubgraph sg = tree.toSharedTreeSubgraph(auxTreeInfo, output._names, output._domains);
                Assert.assertEquals(expectedGraph.subgraphArray.get(i), sg);
            }
        } finally {
            Scope.exit();
        }
    }

    /**
     * For every training row and every tree, scores the row to obtain its decision path, walks
     * that path in the converted {@link SharedTreeSubgraph}, and checks that it ends on a leaf
     * whose node number equals the id reported by {@link SharedTreeMojoModel#getLeafNodeId}.
     */
    @Test
    public void testNodeIdAssignment() throws IOException {
        int ntrees = 5;
        // Enter the scope before the try so that finally never exits a scope that was not entered.
        Scope.enter();
        try {
            GBMModel model = trainGbm(ntrees);
            GBMModel.GBMOutput output = (GBMModel.GBMOutput) model._output;
            GbmMojoModel mojo = (GbmMojoModel) model.toMojo();
            SharedTreeGraph expectedGraph = mojo.computeGraph(-1);
            Assert.assertEquals(ntrees, expectedGraph.subgraphArray.size());
            double[][] data = frameToMatrix(getAdaptedTrainFrame(model));
            for (int i = 0; i < ntrees; ++i) {
                CompressedTree tree = (CompressedTree) output._treeKeys[i][0].get();
                CompressedTree auxTreeInfo = (CompressedTree) output._treeKeysAux[i][0].get();
                Assert.assertNotNull(tree);
                Assert.assertNotNull(auxTreeInfo);
                SharedTreeSubgraph sg = tree.toSharedTreeSubgraph(auxTreeInfo, output._names, output._domains);
                for (double[] row : data) {
                    // The boolean flag is set because the result is decoded below into an
                    // 'L'/'R' path and a node id rather than used as a prediction.
                    double leafAssignment = SharedTreeMojoModel.scoreTree(tree._bits, row, true, output._domains);
                    String nodePath = SharedTreeMojoModel.getDecisionPath(leafAssignment);
                    int nodeId = SharedTreeMojoModel.getLeafNodeId(leafAssignment, auxTreeInfo._bits);

                    // Follow the decision path from the root: 'L' goes left, anything else goes right.
                    SharedTreeNode n = sg.rootNode;
                    for (int j = 0; j < nodePath.length(); ++j) {
                        n = nodePath.charAt(j) == 'L' ? n.getLeftChild() : n.getRightChild();
                    }
                    // The path must terminate on a leaf whose number matches the computed id.
                    Assert.assertNull(n.getLeftChild());
                    Assert.assertNull(n.getRightChild());
                    Assert.assertEquals("Path " + nodePath + " in tree #" + i, n.getNodeNumber(), nodeId);
                }
            }
        } finally {
            Scope.exit();
        }
    }

    /**
     * Checks that a {@link CompressedTree} stored in the DKV reports, through
     * {@link CompressedTree#getTreeCoords()}, the tree id and class index it was constructed with.
     */
    @Test
    public void testMakeTreeKey() {
        Scope.enter();
        try {
            // Assumes the constructor arguments are (bits, seed, treeId, classIndex); the
            // assertions below only rely on the last two being 42 and 17. TODO confirm.
            CompressedTree ct = new CompressedTree(new byte[0], 123L, 42, 17);
            Scope.track_generic(ct);
            DKV.put(ct);
            CompressedTree.TreeCoords tc = ct.getTreeCoords();
            Assert.assertEquals(42, tc._treeId);
            Assert.assertEquals(17, tc._clazz);
        } finally {
            Scope.exit();
        }
    }

    /**
     * Copies a frame into a row-major {@code double[][]} (one inner array per row, columns in
     * frame order), reading each column once through a {@link Vec.Reader}.
     *
     * @param f frame to materialize; its row count is assumed to fit in an {@code int}
     * @return {@code rows[r][c]} equal to the value of column {@code c} at row {@code r}
     */
    private static double[][] frameToMatrix(Frame f) {
        double[][] rows = new double[(int) f.numRows()][];
        for (int r = 0; r < rows.length; ++r) {
            rows[r] = new double[f.numCols()];
        }
        for (int c = 0; c < f.numCols(); ++c) {
            Vec.Reader vecReader = new Vec.Reader(f.vec(c));
            for (int r = 0; r < rows.length; ++r) {
                rows[r][c] = vecReader.at(r);
            }
        }
        return rows;
    }

    /**
     * Fetches the model's training frame and adapts it via {@code adaptTestForTrain}, so that
     * its columns can be fed row by row to the tree scorer.
     *
     * @param m trained model whose {@code _train} key is still present in the DKV
     * @return the training frame itself (the same object that was adapted, not a copy)
     */
    private static Frame getAdaptedTrainFrame(GBMModel m) {
        Frame f = m._parms._train.get();
        String[] warns = m.adaptTestForTrain(f, false, false);
        // Use a JUnit assertion rather than a Java assert so the check runs even without -ea.
        Assert.assertTrue("adaptTestForTrain reported warnings", warns == null || warns.length == 0);
        return f;
    }

    /**
     * Trains a seeded GBM on the prostate dataset, with {@code CAPSULE} converted to a
     * categorical response and {@code ID} ignored. Both the frame and the model are registered
     * with the current {@link Scope}, so callers must have entered one.
     *
     * @param ntrees number of trees to build
     * @return the trained model
     */
    private GBMModel trainGbm(int ntrees) {
        Frame f = Scope.track(parse_test_file("smalldata/logreg/prostate.csv"));
        String response = "CAPSULE";
        // Swap the response column for its categorical version and remove the vec returned by
        // replace() (presumably the original numeric column) so that it is not leaked.
        f.replace(f.find(response), f.vec(response).toCategoricalVec()).remove();
        DKV.put(f._key, f);
        GBMModel.GBMParameters gbmParams = new GBMModel.GBMParameters();
        gbmParams._seed = 123L;
        gbmParams._train = f._key;
        gbmParams._ignored_columns = new String[]{"ID"};
        gbmParams._response_column = response;
        gbmParams._ntrees = ntrees;
        gbmParams._score_each_iteration = true;
        GBMModel model = new GBM(gbmParams).trainModel().get();
        Scope.track_generic(model);
        return model;
    }
}

