package hex.tree;

import hex.genmodel.algos.tree.SharedTreeGraph;
import hex.tree.CompressedTree;
import hex.tree.gbm.GBM;
import hex.tree.gbm.GBMModel;
import java.io.IOException;
import org.junit.Assert;
import org.junit.BeforeClass;
import org.junit.Test;
import water.DKV;
import water.Key;
import water.Scope;
import water.TestUtil;
import water.fvec.Frame;

/**
 * Tests for {@link CompressedTree}: conversion of a trained model's trees to
 * shared-tree subgraphs, and retrieval of tree coordinates from a tree
 * instance.
 */
public class CompressedTreeTest extends TestUtil {
    /** Number of trees the GBM model is trained with in {@link #testToSharedTreeSubgraph()}. */
    private static final int NTREES = 5;

    /** Waits until a cloud of at least one node is available before any test runs. */
    @BeforeClass
    public static void setup() {
        stall_till_cloudsize(1);
    }

    /**
     * Trains a small GBM and checks that every tree's
     * {@code toSharedTreeSubgraph} result equals the subgraph at the same
     * index in the graph computed from the model's MOJO.
     *
     * @throws IOException propagated from the MOJO conversion / graph computation
     */
    @Test
    public void testToSharedTreeSubgraph() throws IOException {
        // Enter before the try so that exit is only ever paired with a successful enter.
        Scope.enter();
        try {
            GBMModel model = trainGbm(NTREES);
            SharedTreeGraph mojoGraph = model.toMojo()._computeGraph(-1);
            Assert.assertEquals(NTREES, mojoGraph.subgraphArray.size());
            for (int treeIdx = 0; treeIdx < NTREES; treeIdx++) {
                // Only the tree stored at class index 0 is inspected for each iteration.
                CompressedTree tree = model._output._treeKeys[treeIdx][0].get();
                Assert.assertNotNull(tree);
                CompressedTree auxTree = model._output._treeKeysAux[treeIdx][0].get();
                Assert.assertEquals(
                        mojoGraph.subgraphArray.get(treeIdx),
                        tree.toSharedTreeSubgraph(auxTree, model._output._names, model._output._domains));
            }
        } finally {
            // finally guarantees exactly one exit, even when the body (or an assertion) throws.
            Scope.exit();
        }
    }

    /**
     * Builds a {@link CompressedTree} with tree id 42 and class 17, stores it
     * in the DKV, and checks that {@code getTreeCoords()} reports those same
     * two values.
     */
    @Test
    public void testMakeTreeKey() {
        Scope.enter();
        try {
            CompressedTree tree = new CompressedTree(new byte[0], 7, 123L, 42, 17);
            Scope.track_generic(tree);
            DKV.put(tree);
            CompressedTree.TreeCoords coords = tree.getTreeCoords();
            Assert.assertEquals(42, coords._treeId);
            Assert.assertEquals(17, coords._clazz);
        } finally {
            Scope.exit();
        }
    }

    /**
     * Trains a GBM model on the prostate dataset with {@code CAPSULE}
     * converted to a categorical response and the {@code ID} column ignored.
     * The frame and the model are registered with the current {@link Scope},
     * so the caller must have entered a scope beforehand.
     *
     * @param i number of trees to build
     * @return the trained model
     */
    private GBMModel trainGbm(int i) {
        Frame train = Scope.track(parse_test_file("smalldata/logreg/prostate.csv"));
        // Swap the response column for its categorical version and remove the vec that replace() returns.
        train.replace(train.find("CAPSULE"), train.vec("CAPSULE").toCategoricalVec()).remove();
        // Re-publish the modified frame under its key.
        DKV.put(train._key, train);

        GBMModel.GBMParameters params = new GBMModel.GBMParameters();
        params._train = train._key;
        params._ignored_columns = new String[]{"ID"};
        params._response_column = "CAPSULE";
        params._ntrees = i;
        params._score_each_iteration = true;

        GBMModel model = new GBM(params).trainModel().get();
        return Scope.track_generic(model);
    }
}
