package hex.tree.drf;

import hex.tree.CompressedTree;
import hex.tree.SharedTreeModel;
import hex.tree.drf.DRFModel;
import org.junit.Assert;
import org.junit.BeforeClass;
import org.junit.Test;
import water.DKV;
import water.Key;
import water.ModelSerializationTest;
import water.TestUtil;
import water.exceptions.H2OIllegalArgumentException;
import water.fvec.Frame;
import water.fvec.Vec;

/**
 * Checks DRF checkpointing.
 *
 * <p>A model continued from a checkpoint must grow exactly the same trees as a model trained from
 * scratch with the combined number of trees. The tree keys of the two models must still be
 * distinct. Changing the sample rate between the checkpoint and the continuation must be rejected.
 */
public class DRFCheckpointTest extends TestUtil {

    /** Seed shared by every model so the checkpointed and from-scratch runs are comparable. */
    private static final long SEED = 42L;

    /** Maximum tree depth shared by every model built in this test. */
    private static final int MAX_DEPTH = 10;

    /** Row sample rate used for both phases when a test does not override it. */
    private static final float DEFAULT_SAMPLE_RATE = 0.632f;

    @BeforeClass
    public static void stall() {
        stall_till_cloudsize(1);
    }

    @Test
    public void testCheckpointReconstruction4Multinomial() {
        testCheckPointReconstruction("smalldata/iris/iris.csv", 4, true, 5, 3);
    }

    @Test
    public void testCheckpointReconstruction4Multinomial2() {
        testCheckPointReconstruction("smalldata/junit/cars_20mpg.csv", 1, true, 5, 3);
    }

    @Test
    public void testCheckpointReconstruction4Binomial() {
        testCheckPointReconstruction("smalldata/logreg/prostate.csv", 1, true, 5, 3);
    }

    @Test
    public void testCheckpointReconstruction4Binomial2() {
        testCheckPointReconstruction("smalldata/junit/cars_20mpg.csv", 7, true, 1, 1);
    }

    /** Continuing a checkpoint with a different sample rate must be rejected. */
    @Test(expected = H2OIllegalArgumentException.class)
    public void testCheckpointWrongParams() {
        testCheckPointReconstruction("smalldata/iris/iris.csv", 4, true, 5, 3, 0.2f, 0.67f);
    }

    @Test
    public void testCheckpointReconstruction4Regression() {
        testCheckPointReconstruction("smalldata/logreg/prostate.csv", 8, false, 4, 3);
    }

    @Test
    public void testCheckpointReconstruction4Regression2() {
        testCheckPointReconstruction("smalldata/junit/cars_20mpg.csv", 1, false, 4, 3);
    }

    /**
     * Runs the checkpoint reconstruction check using {@link #DEFAULT_SAMPLE_RATE} for both the
     * initial and the continued model.
     */
    private void testCheckPointReconstruction(String dataset, int responseIdx, boolean classification,
                                              int ntreesInPriorModel, int ntreesInNewModel) {
        testCheckPointReconstruction(dataset, responseIdx, classification, ntreesInPriorModel, ntreesInNewModel,
                DEFAULT_SAMPLE_RATE, DEFAULT_SAMPLE_RATE);
    }

    /**
     * Trains three models on {@code dataset}:
     * <ul>
     *   <li>an initial model with {@code ntreesInPriorModel} trees;</li>
     *   <li>a model continued from that checkpoint up to {@code ntreesInPriorModel + ntreesInNewModel} trees;</li>
     *   <li>a validation model trained from scratch with the same total number of trees.</li>
     * </ul>
     * It then asserts that the checkpointed and validation models have equal trees but distinct tree keys.
     * All frames and models created here are deleted, whether the check passes or fails.
     *
     * @param dataset                test-data path passed to {@code parse_test_file}
     * @param responseIdx            column index of the response
     * @param classification         if {@code true}, the response column is converted to categorical first
     * @param ntreesInPriorModel     number of trees in the initial (checkpoint) model
     * @param ntreesInNewModel       number of trees added on top of the checkpoint
     * @param sampleRateInPriorModel sample rate for the initial and validation models
     * @param sampleRateInNewModel   sample rate for the model continued from the checkpoint
     */
    private void testCheckPointReconstruction(String dataset, int responseIdx, boolean classification,
                                              int ntreesInPriorModel, int ntreesInNewModel,
                                              float sampleRateInPriorModel, float sampleRateInNewModel) {
        Frame frame = null;
        DRFModel initialModel = null;
        DRFModel checkpointedModel = null;
        DRFModel validationModel = null;
        try {
            // Frame preparation happens inside the try so the parsed frame is cleaned up
            // even if the categorical conversion fails.
            frame = parse_test_file(dataset);
            // Some datasets (presumably cars_20mpg) have an "economy" column, which is dropped here.
            // Frame.remove returns null when the column is absent.
            Vec economy = frame.remove("economy");
            if (economy != null) {
                economy.remove();
            }
            DKV.put(frame);
            if (classification) {
                // replace(...) returns the previous response vec, which is no longer referenced and is removed.
                frame.replace(responseIdx, frame.vec(responseIdx).toCategoricalVec()).remove();
                DKV.put(frame._key, frame);
            }

            DRFModel.DRFParameters initialParms =
                    buildParameters(frame, responseIdx, ntreesInPriorModel, sampleRateInPriorModel);
            initialModel = (DRFModel) new DRF(initialParms, Key.make("Initial model")).trainModel().get();

            DRFModel.DRFParameters checkpointParms =
                    buildParameters(frame, responseIdx, ntreesInPriorModel + ntreesInNewModel, sampleRateInNewModel);
            checkpointParms._checkpoint = initialModel._key;
            checkpointedModel = (DRFModel) new DRF(checkpointParms, Key.make("Model from checkpoint")).trainModel().get();

            DRFModel.DRFParameters validationParms =
                    buildParameters(frame, responseIdx, ntreesInPriorModel + ntreesInNewModel, sampleRateInPriorModel);
            validationModel = (DRFModel) new DRF(validationParms, Key.make("Validation model")).trainModel().get();

            CompressedTree[][] checkpointedTrees = ModelSerializationTest.getTrees(checkpointedModel);
            CompressedTree[][] validationTrees = ModelSerializationTest.getTrees(validationModel);
            ModelSerializationTest.assertTreeEquals("The model created from checkpoint and corresponding model created from scratch should have the same trees!", checkpointedTrees, validationTrees, true);

            // Equal trees must still be stored under different keys; otherwise the two models would share storage.
            for (int iteration = 0; iteration < checkpointedTrees.length; iteration++) {
                for (int treeClass = 0; treeClass < checkpointedTrees[iteration].length; treeClass++) {
                    if (checkpointedTrees[iteration][treeClass] != null) {
                        Assert.assertNotEquals(checkpointedTrees[iteration][treeClass]._key,
                                validationTrees[iteration][treeClass]._key);
                    }
                }
            }
        } finally {
            if (frame != null) {
                frame.delete();
            }
            if (initialModel != null) {
                initialModel.delete();
            }
            if (checkpointedModel != null) {
                checkpointedModel.delete();
            }
            if (validationModel != null) {
                validationModel.delete();
            }
        }
    }

    /**
     * Builds the DRF parameters shared by all three models in the checkpoint test.
     *
     * The parameters cover the training frame, response column, seed, max depth,
     * per-iteration scoring, tree count and sample rate. The checkpoint key is left
     * unset; callers add it when needed.
     */
    private static DRFModel.DRFParameters buildParameters(Frame frame, int responseIdx, int ntrees, float sampleRate) {
        DRFModel.DRFParameters parms = new DRFModel.DRFParameters();
        parms._train = frame._key;
        parms._response_column = frame.name(responseIdx);
        parms._ntrees = ntrees;
        parms._seed = SEED;
        parms._max_depth = MAX_DEPTH;
        parms._score_each_iteration = true;
        parms._sample_rate = sampleRate;
        return parms;
    }
}
