package ai.h2o.automl;

import hex.Model;
import java.util.Date;
import java.util.Random;
import junit.framework.TestCase;
import org.junit.Assert;
import org.junit.BeforeClass;
import org.junit.Test;
import water.DKV;
import water.Key;
import water.TestUtil;
import water.fvec.Frame;

/* loaded from: input_file:ai/h2o/automl/AutoMLTest.class */
/**
 * Tests for {@link AutoML} covering:
 * <ul>
 *   <li>leaderboard composition with and without cross-validation;</li>
 *   <li>runs bounded by a runtime limit;</li>
 *   <li>retention of cross-validation fold assignments;</li>
 *   <li>work planning;</li>
 *   <li>how the training frame is split into training, validation and leaderboard frames.</li>
 * </ul>
 *
 * <p>Every test creates frames and AutoML objects that must be deleted explicitly. Cleanup
 * therefore runs in a {@code finally} block, so it also happens when an assertion fails.
 */
public class AutoMLTest extends TestUtil {

    private static final String PROSTATE_TRAIN = "./smalldata/logreg/prostate_train.csv";
    private static final String PROSTATE_TEST = "./smalldata/logreg/prostate_test.csv";
    private static final String AIRLINES = "./smalldata/airlines/allyears2k_headers.zip";

    @BeforeClass
    public static void setup() {
        stall_till_cloudsize(1);
    }

    /**
     * Counts leaderboard models by whether their key starts with {@code "StackedEnsemble"}.
     *
     * @param aml              a finished AutoML run
     * @param stackedEnsembles {@code true} to count Stacked Ensembles, {@code false} for all others
     * @return the number of matching model keys on the leaderboard
     */
    private static int countLeaderboardModels(AutoML aml, boolean stackedEnsembles) {
        int count = 0;
        for (Key<?> key : aml.leaderboard().getModelKeys()) {
            if (key.toString().startsWith("StackedEnsemble") == stackedEnsembles) {
                count++;
            }
        }
        return count;
    }

    /**
     * Returns the fraction of {@code whole}'s rows that {@code part} holds.
     * The cast forces floating-point division; {@code long / long} would truncate to 0.
     */
    private static double rowFraction(Frame part, Frame whole) {
        return (double) part.numRows() / whole.numRows();
    }

    /**
     * Deletes the cross-validation fold-assignment frame referenced by {@code model}, if any.
     * Null-safe, so cleanup cannot mask the original failure with a NullPointerException.
     */
    private static void deleteFoldAssignment(Model model) {
        if (model == null || model._output._cross_validation_fold_assignment_frame_id == null) {
            return;
        }
        Frame foldAssignment = DKV.getGet(model._output._cross_validation_fold_assignment_frame_id);
        if (foldAssignment != null) {
            foldAssignment.delete();
        }
    }

    @Test
    public void test_basic_automl_behaviour_using_cv() {
        AutoML autoML = null;
        Frame frame = null;
        try {
            AutoMLBuildSpec autoMLBuildSpec = new AutoMLBuildSpec();
            frame = parse_test_file(PROSTATE_TRAIN);
            autoMLBuildSpec.input_spec.training_frame = frame._key;
            autoMLBuildSpec.input_spec.response_column = "CAPSULE";
            autoMLBuildSpec.build_control.stopping_criteria.set_max_models(3);
            autoMLBuildSpec.build_control.keep_cross_validation_models = false;
            autoMLBuildSpec.build_control.keep_cross_validation_predictions = false;
            autoML = AutoML.startAutoML(autoMLBuildSpec);
            autoML.get();

            // Expect 3 standard models and 2 Stacked Ensembles: 5 leaderboard entries in total.
            Assert.assertEquals("wrong amount of standard models", 3L, countLeaderboardModels(autoML, false));
            Assert.assertEquals("wrong amount of SE models", 2L, countLeaderboardModels(autoML, true));
            Assert.assertEquals(5L, autoML.leaderboard().getModelCount());
        } finally {
            if (autoML != null) {
                autoML.deleteWithChildren();
            }
            if (frame != null) {
                frame.delete();
            }
        }
    }

    @Test
    public void test_automl_with_cv_disabled() {
        AutoML autoML = null;
        Frame frame = null;
        try {
            AutoMLBuildSpec autoMLBuildSpec = new AutoMLBuildSpec();
            frame = parse_test_file(PROSTATE_TRAIN);
            autoMLBuildSpec.input_spec.training_frame = frame._key;
            autoMLBuildSpec.input_spec.response_column = "CAPSULE";
            autoMLBuildSpec.build_control.stopping_criteria.set_max_models(3);
            autoMLBuildSpec.build_control.nfolds = 0;
            autoMLBuildSpec.build_control.keep_cross_validation_models = false;
            autoMLBuildSpec.build_control.keep_cross_validation_predictions = false;
            autoML = AutoML.startAutoML(autoMLBuildSpec);
            autoML.get();

            Assert.assertEquals("wrong amount of standard models", 3L, countLeaderboardModels(autoML, false));
            Assert.assertEquals("no Stacked Ensemble expected if cross-validation is disabled", 0L,
                    countLeaderboardModels(autoML, true));
            Assert.assertEquals(3L, autoML.leaderboard().getModelCount());
        } finally {
            if (autoML != null) {
                autoML.deleteWithChildren();
            }
            if (frame != null) {
                frame.delete();
            }
        }
    }

    @Test
    public void test_automl_basic_behaviour_on_timeout() {
        AutoML autoML = null;
        Frame frame = null;
        try {
            AutoMLBuildSpec autoMLBuildSpec = new AutoMLBuildSpec();
            frame = parse_test_file(PROSTATE_TRAIN);
            autoMLBuildSpec.input_spec.training_frame = frame._key;
            autoMLBuildSpec.input_spec.response_column = "CAPSULE";
            // Random budget of 1..30 seconds. A bare nextInt(30) could yield 0, which is not a
            // timeout. NOTE(review): H2O presumably treats 0 as "no runtime limit", so with no
            // max_models that draw would let the run fall back to its default budget.
            autoMLBuildSpec.build_control.stopping_criteria.set_max_runtime_secs(1 + new Random().nextInt(30));
            autoMLBuildSpec.build_control.keep_cross_validation_models = false;
            autoMLBuildSpec.build_control.keep_cross_validation_predictions = false;
            autoML = AutoML.startAutoML(autoMLBuildSpec);
            autoML.get();
        } finally {
            if (autoML != null) {
                autoML.deleteWithChildren();
            }
            if (frame != null) {
                frame.delete();
            }
        }
    }

    @Test
    public void test_automl_basic_behaviour_on_grid_timeout() {
        AutoML autoML = null;
        Frame frame = null;
        try {
            AutoMLBuildSpec autoMLBuildSpec = new AutoMLBuildSpec();
            frame = parse_test_file(PROSTATE_TRAIN);
            autoMLBuildSpec.input_spec.training_frame = frame._key;
            autoMLBuildSpec.input_spec.response_column = "CAPSULE";
            autoMLBuildSpec.build_models.exclude_algos = new Algo[]{Algo.DeepLearning, Algo.DRF, Algo.GLM};
            autoMLBuildSpec.build_control.stopping_criteria.set_max_runtime_secs(8.0d);
            autoMLBuildSpec.build_control.keep_cross_validation_models = false;
            autoMLBuildSpec.build_control.keep_cross_validation_predictions = false;
            autoML = AutoML.startAutoML(autoMLBuildSpec);
            autoML.get();
        } finally {
            if (autoML != null) {
                autoML.deleteWithChildren();
            }
            if (frame != null) {
                frame.delete();
            }
        }
    }

    @Test
    public void KeepCrossValidationFoldAssignmentEnabledTest() {
        AutoML autoML = null;
        Frame frame = null;
        Model model = null;
        try {
            AutoMLBuildSpec autoMLBuildSpec = new AutoMLBuildSpec();
            frame = parse_test_file(PROSTATE_TRAIN);
            autoMLBuildSpec.input_spec.training_frame = frame._key;
            autoMLBuildSpec.input_spec.response_column = "CAPSULE";
            autoMLBuildSpec.build_control.stopping_criteria.set_max_models(1);
            autoMLBuildSpec.build_control.stopping_criteria.set_max_runtime_secs(30.0d);
            autoMLBuildSpec.build_control.keep_cross_validation_fold_assignment = true;
            autoML = AutoML.makeAutoML(Key.make(), new Date(), autoMLBuildSpec);
            AutoML.startAutoML(autoML);
            autoML.get();

            model = autoML.leader();
            Assert.assertNotNull("AutoML should have produced a leader model", model);
            Assert.assertTrue(model._parms._keep_cross_validation_fold_assignment);
            Assert.assertNotNull(model._output._cross_validation_fold_assignment_frame_id);
        } finally {
            if (autoML != null) {
                autoML.deleteWithChildren();
            }
            if (frame != null) {
                frame.remove();
            }
            deleteFoldAssignment(model);
        }
    }

    @Test
    public void KeepCrossValidationFoldAssignmentDisabledTest() {
        AutoML autoML = null;
        Frame frame = null;
        try {
            AutoMLBuildSpec autoMLBuildSpec = new AutoMLBuildSpec();
            frame = parse_test_file(AIRLINES);
            autoMLBuildSpec.input_spec.training_frame = frame._key;
            autoMLBuildSpec.input_spec.response_column = "IsDepDelayed";
            autoMLBuildSpec.build_control.stopping_criteria.set_max_models(1);
            autoMLBuildSpec.build_control.stopping_criteria.set_max_runtime_secs(30.0d);
            autoMLBuildSpec.build_control.keep_cross_validation_fold_assignment = false;
            autoML = AutoML.makeAutoML(Key.make(), new Date(), autoMLBuildSpec);
            AutoML.startAutoML(autoML);
            autoML.get();

            Model leader = autoML.leader();
            Assert.assertNotNull("AutoML should have produced a leader model", leader);
            Assert.assertFalse(leader._parms._keep_cross_validation_fold_assignment);
            Assert.assertNull(leader._output._cross_validation_fold_assignment_frame_id);
        } finally {
            if (autoML != null) {
                autoML.deleteWithChildren();
            }
            if (frame != null) {
                frame.delete();
            }
        }
    }

    @Test
    public void testWorkPlan() {
        AutoML autoML = null;
        Frame frame = null;
        try {
            AutoMLBuildSpec autoMLBuildSpec = new AutoMLBuildSpec();
            frame = parse_test_file(AIRLINES);
            autoMLBuildSpec.input_spec.training_frame = frame._key;
            autoMLBuildSpec.input_spec.response_column = "IsDepDelayed";
            autoML = new AutoML(Key.make(), new Date(), autoMLBuildSpec);
            // JUnit's assertEquals takes (expected, actual).
            Assert.assertEquals(380, autoML.planWork().remainingWork());
            // planWork() is re-evaluated against the mutated spec. Excluding these two
            // algorithms is expected to remove 200 units of planned work.
            autoMLBuildSpec.build_models.exclude_algos = new Algo[]{Algo.DeepLearning, Algo.XGBoost};
            Assert.assertEquals(380 - 200, autoML.planWork().remainingWork());
        } finally {
            if (autoML != null) {
                autoML.deleteWithChildren();
            }
            if (frame != null) {
                frame.remove();
            }
        }
    }

    @Test
    public void test_training_frame_partition_when_cv_disabled_and_validation_frame_missing() {
        AutoML autoML = null;
        Frame frame = null;
        Frame frame2 = null;
        try {
            AutoMLBuildSpec autoMLBuildSpec = new AutoMLBuildSpec();
            frame = parse_test_file(PROSTATE_TRAIN);
            frame2 = parse_test_file(PROSTATE_TEST);
            autoMLBuildSpec.input_spec.response_column = "CAPSULE";
            autoMLBuildSpec.input_spec.training_frame = frame._key;
            autoMLBuildSpec.input_spec.validation_frame = null;
            autoMLBuildSpec.input_spec.leaderboard_frame = frame2._key;
            autoMLBuildSpec.build_control.nfolds = 0;
            autoMLBuildSpec.build_control.stopping_criteria.set_max_models(1);
            autoMLBuildSpec.build_control.stopping_criteria.set_seed(1L);
            autoML = AutoML.startAutoML(autoMLBuildSpec);
            autoML.get();

            // The training frame is split 90/10 into training/validation; the leaderboard is untouched.
            Assert.assertEquals(0.9d, rowFraction(autoML.getTrainingFrame(), frame), 0.01d);
            Assert.assertEquals(0.1d, rowFraction(autoML.getValidationFrame(), frame), 0.01d);
            Assert.assertEquals(frame2.numRows(), autoML.getLeaderboardFrame().numRows());
        } finally {
            if (autoML != null) {
                autoML.deleteWithChildren();
            }
            if (frame != null) {
                frame.remove();
            }
            if (frame2 != null) {
                frame2.remove();
            }
        }
    }

    @Test
    public void test_training_frame_partition_when_cv_disabled_and_leaderboard_frame_missing() {
        AutoML autoML = null;
        Frame frame = null;
        Frame frame2 = null;
        try {
            AutoMLBuildSpec autoMLBuildSpec = new AutoMLBuildSpec();
            frame = parse_test_file(PROSTATE_TRAIN);
            frame2 = parse_test_file(PROSTATE_TEST);
            autoMLBuildSpec.input_spec.response_column = "CAPSULE";
            autoMLBuildSpec.input_spec.training_frame = frame._key;
            autoMLBuildSpec.input_spec.validation_frame = frame2._key;
            autoMLBuildSpec.input_spec.leaderboard_frame = null;
            autoMLBuildSpec.build_control.nfolds = 0;
            autoMLBuildSpec.build_control.stopping_criteria.set_max_models(1);
            autoMLBuildSpec.build_control.stopping_criteria.set_seed(1L);
            autoML = AutoML.startAutoML(autoMLBuildSpec);
            autoML.get();

            // The training frame is split 90/10 into training/leaderboard; validation is untouched.
            Assert.assertEquals(0.9d, rowFraction(autoML.getTrainingFrame(), frame), 0.01d);
            Assert.assertEquals(frame2.numRows(), autoML.getValidationFrame().numRows());
            Assert.assertEquals(0.1d, rowFraction(autoML.getLeaderboardFrame(), frame), 0.01d);
        } finally {
            if (autoML != null) {
                autoML.deleteWithChildren();
            }
            if (frame != null) {
                frame.remove();
            }
            if (frame2 != null) {
                frame2.remove();
            }
        }
    }

    @Test
    public void test_training_frame_partition_when_cv_disabled_and_both_validation_and_leaderboard_frames_missing() {
        AutoML autoML = null;
        Frame frame = null;
        try {
            AutoMLBuildSpec autoMLBuildSpec = new AutoMLBuildSpec();
            frame = parse_test_file(PROSTATE_TRAIN);
            autoMLBuildSpec.input_spec.response_column = "CAPSULE";
            autoMLBuildSpec.input_spec.training_frame = frame._key;
            autoMLBuildSpec.input_spec.validation_frame = null;
            autoMLBuildSpec.input_spec.leaderboard_frame = null;
            autoMLBuildSpec.build_control.nfolds = 0;
            autoMLBuildSpec.build_control.stopping_criteria.set_max_models(1);
            autoMLBuildSpec.build_control.stopping_criteria.set_seed(1L);
            autoML = AutoML.startAutoML(autoMLBuildSpec);
            autoML.get();

            // The training frame is split 80/10/10 into training/validation/leaderboard.
            Assert.assertEquals(0.8d, rowFraction(autoML.getTrainingFrame(), frame), 0.01d);
            Assert.assertEquals(0.1d, rowFraction(autoML.getValidationFrame(), frame), 0.01d);
            Assert.assertEquals(0.1d, rowFraction(autoML.getLeaderboardFrame(), frame), 0.01d);
        } finally {
            if (autoML != null) {
                autoML.deleteWithChildren();
            }
            if (frame != null) {
                frame.remove();
            }
        }
    }

    @Test
    public void test_training_frame_not_partitioned_when_cv_enabled() {
        AutoML autoML = null;
        Frame frame = null;
        try {
            AutoMLBuildSpec autoMLBuildSpec = new AutoMLBuildSpec();
            frame = parse_test_file(PROSTATE_TRAIN);
            autoMLBuildSpec.input_spec.response_column = "CAPSULE";
            autoMLBuildSpec.input_spec.training_frame = frame._key;
            autoMLBuildSpec.input_spec.validation_frame = null;
            autoMLBuildSpec.input_spec.leaderboard_frame = null;
            autoMLBuildSpec.build_control.stopping_criteria.set_max_models(1);
            autoMLBuildSpec.build_control.stopping_criteria.set_seed(1L);
            autoML = AutoML.startAutoML(autoMLBuildSpec);
            autoML.get();

            // With cross-validation enabled (nfolds left at its default), no split is made.
            Assert.assertEquals(frame.numRows(), autoML.getTrainingFrame().numRows());
            Assert.assertNull(autoML.getValidationFrame());
            Assert.assertNull(autoML.getLeaderboardFrame());
        } finally {
            if (autoML != null) {
                autoML.deleteWithChildren();
            }
            if (frame != null) {
                frame.remove();
            }
        }
    }
}
