package ai.h2o.automl;

import ai.h2o.automl.ModelSelectionStrategies;
import ai.h2o.automl.ModelingStep;
import ai.h2o.automl.dummy.DummyBuilder;
import ai.h2o.automl.dummy.DummyModel;
import ai.h2o.automl.dummy.DummyStepsProvider;
import hex.Model;
import hex.grid.Grid;
import hex.grid.HyperSpaceSearchCriteria;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Date;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import org.junit.After;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;
import org.junit.runner.RunWith;
import water.DKV;
import water.Job;
import water.Key;
import water.Keyed;
import water.TestUtil;
import water.fvec.Frame;
import water.fvec.Vec;
import water.runner.CloudSize;
import water.runner.H2ORunner;

@CloudSize(1)
@RunWith(H2ORunner.class)
/* loaded from: input_file:ai/h2o/automl/ModelingStepTest.class */
public class ModelingStepTest {
    private List<Keyed> toDelete = new ArrayList();
    private AutoML aml;
    private Frame fr;

    /* loaded from: input_file:ai/h2o/automl/ModelingStepTest$DummyGridStep.class */
    private static class DummyGridStep extends ModelingStep.GridStep<DummyModel> {
        public DummyGridStep(IAlgo iAlgo, String str, int i, AutoML autoML) {
            super(iAlgo, str, i, autoML);
        }

        protected Job<Grid> startJob() {
            DummyModel.DummyModelParameters dummyModelParameters = new DummyModel.DummyModelParameters();
            HashMap hashMap = new HashMap();
            hashMap.put("_tag", new String[]{"one", "two", "three"});
            return hyperparameterSearch(dummyModelParameters, hashMap);
        }
    }

    /* loaded from: input_file:ai/h2o/automl/ModelingStepTest$DummyModelStep.class */
    private static class DummyModelStep extends ModelingStep.ModelStep<DummyModel> {
        public DummyModelStep(IAlgo iAlgo, String str, int i, AutoML autoML) {
            super(iAlgo, str, i, autoML);
        }

        protected Job<DummyModel> startJob() {
            return trainModel(new DummyModel.DummyModelParameters());
        }
    }

    /* loaded from: input_file:ai/h2o/automl/ModelingStepTest$DummySelectionStep.class */
    private static class DummySelectionStep extends ModelingStep.SelectionStep<DummyModel> {
        boolean _useSearch;

        public DummySelectionStep(IAlgo iAlgo, String str, int i, AutoML autoML, boolean z) {
            super(iAlgo, str, i, autoML);
            this._useSearch = z;
        }

        protected Job<Models> startTraining(Key key, double d) {
            Job startModel;
            DummyModel.DummyModelParameters dummyModelParameters = new DummyModel.DummyModelParameters();
            setCommonModelBuilderParams(dummyModelParameters);
            ((Model.Parameters) dummyModelParameters)._max_runtime_secs = d;
            if (this._useSearch) {
                HashMap hashMap = new HashMap();
                hashMap.put("_tag", new String[]{"uno", "due", "tre", "quattro"});
                startModel = startSearch(Key.make(key + "_expsearch"), dummyModelParameters, hashMap, new HyperSpaceSearchCriteria.CartesianSearchCriteria());
            } else {
                startModel = startModel(Key.make(key + "_expmodel"), dummyModelParameters);
            }
            return asModelsJob(startModel, key);
        }

        protected ModelSelectionStrategy getSelectionStrategy() {
            return new ModelSelectionStrategies.KeepBestN(10, () -> {
                return makeTmpLeaderboard("for_selection");
            });
        }
    }

    /* loaded from: input_file:ai/h2o/automl/ModelingStepTest$TestingModelSteps.class */
    private class TestingModelSteps extends DummyStepsProvider.DummyModelSteps {
        public TestingModelSteps(AutoML autoML) {
            super(autoML);
            this.defaultModels = new ModelingStep[]{new DummyModelStep(DummyBuilder.algo, "dummy_model", 10, aml())};
            this.grids = new ModelingStep[]{new DummyGridStep(DummyBuilder.algo, "dummy_grid", 50, aml())};
            this.exploitation = new ModelingStep[]{new DummySelectionStep(DummyBuilder.algo, "dummy_exploitation_single", 10, aml(), false), new DummySelectionStep(DummyBuilder.algo, "dummy_exploitation_multi", 10, aml(), true)};
        }
    }

    @Before
    public void setup() {
        DummyStepsProvider dummyStepsProvider = new DummyStepsProvider();
        dummyStepsProvider.modelStepsFactory = autoML -> {
            return new TestingModelSteps(autoML);
        };
        ModelingStepsRegistry.registerProvider(dummyStepsProvider);
        this.fr = new Frame(Key.make("dummy_fr"), new String[]{"A", "B", "target"}, new Vec[]{TestUtil.ivec(new int[]{1, 2, 3, 4, 5}), TestUtil.ivec(new int[]{1, 2, 3, 4, 5}), TestUtil.cvec(new String[]{"foo", "foo", "foo", "bar", "bar"})});
        DKV.put(this.fr);
        this.toDelete.add(this.fr);
        AutoMLBuildSpec autoMLBuildSpec = new AutoMLBuildSpec();
        autoMLBuildSpec.input_spec.training_frame = this.fr._key;
        autoMLBuildSpec.input_spec.response_column = "target";
        autoMLBuildSpec.build_models.modeling_plan = new StepDefinition[]{new StepDefinition("dummy")};
        autoMLBuildSpec.build_models.exploitation_ratio = 0.5d;
        this.aml = new AutoML((Key) null, new Date(), autoMLBuildSpec);
        DKV.put(this.aml);
        this.toDelete.add(this.aml);
    }

    @After
    public void cleanup() {
        this.toDelete.forEach((v0) -> {
            v0.remove();
        });
    }

    @Test
    public void testModelStep() {
        Keyed keyed = (DummyModel) ((ModelingStep) Arrays.stream(this.aml.getExecutionPlan()).filter(modelingStep -> {
            return "dummy_model".equals(modelingStep._id);
        }).findFirst().get()).startJob().get();
        this.toDelete.add(keyed);
        Assert.assertNotNull(keyed);
        Assert.assertEquals(this.aml.getBuildSpec().input_spec.response_column, ((DummyModel.DummyModelParameters) ((DummyModel) keyed)._parms)._response_column);
        Assert.assertEquals(this.aml.getBuildSpec().build_control.nfolds, ((DummyModel.DummyModelParameters) ((DummyModel) keyed)._parms)._nfolds);
        Assert.assertTrue(((DummyModel.DummyModelParameters) ((DummyModel) keyed)._parms)._max_runtime_secs > 0.0d);
        Assert.assertEquals(this.aml.getBuildSpec().build_control.stopping_criteria.stopping_metric(), ((DummyModel.DummyModelParameters) ((DummyModel) keyed)._parms)._stopping_metric);
    }

    @Test
    public void testGridStep() {
        Keyed keyed = (Grid) ((ModelingStep) Arrays.stream(this.aml.getExecutionPlan()).filter(modelingStep -> {
            return "dummy_grid".equals(modelingStep._id);
        }).findFirst().get()).startJob().get();
        this.toDelete.add(keyed);
        Assert.assertEquals(3L, keyed.getModelCount());
        for (Model model : keyed.getModels()) {
            Assert.assertEquals(this.aml.getBuildSpec().input_spec.response_column, model._parms._response_column);
            Assert.assertEquals(this.aml.getBuildSpec().build_control.nfolds, model._parms._nfolds);
            Assert.assertTrue(model._parms._max_runtime_secs > 0.0d);
            Assert.assertEquals(this.aml.getBuildSpec().build_control.stopping_criteria.stopping_metric(), model._parms._stopping_metric);
        }
    }

    @Test
    public void testSelectionStepSingleModel() {
        Keyed keyed = (Models) ((ModelingStep) Arrays.stream(this.aml.getExecutionPlan()).filter(modelingStep -> {
            return "dummy_exploitation_single".equals(modelingStep._id);
        }).findFirst().get()).startJob().get();
        this.toDelete.add(keyed);
        Assert.assertEquals(1L, keyed.getModelCount());
        for (Model model : keyed.getModels()) {
            Assert.assertEquals(this.aml.getBuildSpec().input_spec.response_column, model._parms._response_column);
            Assert.assertEquals(this.aml.getBuildSpec().build_control.nfolds, model._parms._nfolds);
            Assert.assertTrue(model._parms._max_runtime_secs > 0.0d);
            Assert.assertEquals(this.aml.getBuildSpec().build_control.stopping_criteria.stopping_metric(), model._parms._stopping_metric);
        }
    }

    @Test
    public void testSelectionStepMultipleModels() {
        Keyed keyed = (Models) ((ModelingStep) Arrays.stream(this.aml.getExecutionPlan()).filter(modelingStep -> {
            return "dummy_exploitation_multi".equals(modelingStep._id);
        }).findFirst().get()).startJob().get();
        this.toDelete.add(keyed);
        Assert.assertEquals(4L, keyed.getModelCount());
        for (Model model : keyed.getModels()) {
            Assert.assertEquals(this.aml.getBuildSpec().input_spec.response_column, model._parms._response_column);
            Assert.assertEquals(this.aml.getBuildSpec().build_control.nfolds, model._parms._nfolds);
            Assert.assertTrue(model._parms._max_runtime_secs > 0.0d);
            Assert.assertEquals(this.aml.getBuildSpec().build_control.stopping_criteria.stopping_metric(), model._parms._stopping_metric);
        }
    }
}
