package ai.h2o.automl;

import ai.h2o.automl.AutoMLBuildSpec;
import ai.h2o.automl.WorkAllocations;
import ai.h2o.automl.events.EventLogEntry;
import ai.h2o.automl.leaderboard.Leaderboard;
import hex.Model;
import hex.ModelBuilder;
import hex.ScoreKeeper;
import hex.ensemble.StackedEnsembleModel;
import hex.grid.Grid;
import hex.grid.GridSearch;
import hex.grid.HyperSpaceSearchCriteria;
import java.util.Map;
import water.Iced;
import water.Job;
import water.Key;
import water.exceptions.H2OIllegalArgumentException;
import water.util.EnumUtils;
import water.util.Log;

/* loaded from: input_file:ai/h2o/automl/ModelingStep.class */
/**
 * Base class for one step of an AutoML run.
 *
 * <p>A step is identified by an algorithm ({@link #_algo}) and an id ({@link #_id}), carries a
 * weight used when building its {@link WorkAllocations.Work} item, and knows how to start the
 * {@link Job} that performs it. Two concrete flavours are provided as nested classes:
 * {@link ModelStep} (a single model build) and {@link GridStep} (a hyperparameter search).
 *
 * <p>The owning {@link AutoML} instance is held in a {@code transient} field.
 *
 * @param <M> the model type associated with this step
 */
public abstract class ModelingStep<M extends Model> extends Iced<ModelingStep> {
    private transient AutoML _aml;
    protected final Algo _algo;
    protected final String _id;
    protected int _weight;
    protected boolean _ignoreConstraints;
    protected String _description;
    StepDefinition _fromDef;

    /**
     * A step whose job is a hyperparameter search producing a {@link Grid}.
     *
     * @param <M> the model type associated with this step
     */
    public static abstract class GridStep<M extends Model> extends ModelingStep<M> {
        /** Default weight constant for grid steps (not referenced elsewhere in this file). */
        public static final int DEFAULT_GRID_TRAINING_WEIGHT = 20;

        /**
         * @param algo    algorithm this step uses
         * @param str     step id
         * @param i       step weight
         * @param autoML  owning AutoML instance
         */
        public GridStep(Algo algo, String str, int i, AutoML autoML) {
            super(algo, str, i, autoML);
        }

        // Declared public: the base declaration is public and an override may not narrow access.
        @Override
        public abstract Job<Grid> startJob();

        /** Builds a {@code HyperparamSearch} work item from this step's id, algo and weight. */
        @Override
        public WorkAllocations.Work makeWork() {
            return new WorkAllocations.Work(this._id, this._algo, WorkAllocations.JobType.HyperparamSearch, this._weight);
        }

        /** Looks up the allocation registered for this step's id and algo. */
        @Override
        public WorkAllocations.Work getAllocatedWork() {
            return getWorkAllocations().getAllocation(this._id, this._algo);
        }

        /** Delegates grid key creation to the owning AutoML instance. */
        @Override
        protected Key<Grid> makeKey(String str, boolean z) {
            return aml().gridKey(str, z);
        }

        /**
         * Starts a hyperparameter search with a key generated from the algo name.
         *
         * @param parameters base model parameters; mutated in place
         * @param map        hyperparameter name to candidate values
         * @return the grid search job
         */
        public Job<Grid> hyperparameterSearch(Model.Parameters parameters, Map<String, Object[]> map) {
            return hyperparameterSearch(null, parameters, map);
        }

        /**
         * Configures {@code parameters} from the AutoML build spec, derives the runtime and
         * model-count limits for the search from this step's allocated work, and starts the grid
         * search.
         *
         * @param key        grid key to use, or {@code null} to generate one from the algo name
         * @param parameters base model parameters; mutated in place
         * @param map        hyperparameter name to candidate values
         * @return the grid search job
         * @throws H2OIllegalArgumentException if anything in the setup or launch fails; the
         *         original exception is attached as the cause
         */
        protected Job<Grid> hyperparameterSearch(Key<Grid> key, Model.Parameters parameters, Map<String, Object[]> map) {
            try {
                // A fresh instance supplies the defaults that setStoppingCriteria compares against.
                Model.Parameters defaults = (Model.Parameters) parameters.getClass().newInstance();
                setCommonModelBuilderParams(parameters);
                setStoppingCriteria(parameters, defaults, false);
                setCustomParams(parameters);

                HyperSpaceSearchCriteria.RandomDiscreteValueSearchCriteria searchCriteria = aml().getBuildSpec().build_control.stopping_criteria.getSearchCriteria().clone();
                WorkAllocations.Work allocatedWork = getAllocatedWork();

                // Share of the remaining time given to this step, converted from ms to seconds.
                double maxRuntimeSecs = (((float) aml().timeRemainingMs()) * getWorkAllocations().remainingWorkRatio(allocatedWork)) / 1000.0d;
                // Share of the remaining model budget; the ratio is computed over non-StackedEnsemble work only.
                int maxModels = (int) Math.ceil(aml().remainingModels() * getWorkAllocations().remainingWorkRatio(allocatedWork, work -> work._algo != Algo.StackedEnsemble));

                // A value of 0 on the criteria is replaced outright; otherwise the tighter limit wins.
                if (searchCriteria.max_runtime_secs() == 0.0d) {
                    searchCriteria.set_max_runtime_secs(maxRuntimeSecs);
                } else {
                    searchCriteria.set_max_runtime_secs(Math.min(searchCriteria.max_runtime_secs(), maxRuntimeSecs));
                }
                if (searchCriteria.max_models() == 0) {
                    searchCriteria.set_max_models(maxModels);
                } else {
                    searchCriteria.set_max_models(Math.min(searchCriteria.max_models(), maxModels));
                }

                if (null == key) {
                    key = makeKey(this._algo.name(), true);
                }
                aml().addGridKey(key);
                aml().eventLog().info(EventLogEntry.Stage.ModelTraining, "AutoML: starting " + key + " hyperparameter search");
                Log.debug(new Object[]{"Hyperparameter search: " + this._algo.name() + ", time remaining (ms): " + aml().timeRemainingMs()});
                return GridSearch.startGridSearch(key, parameters, map, new GridSearch.SimpleParametersBuilderFactory(), searchCriteria, 1);
            } catch (Exception e) {
                aml().eventLog().warn(EventLogEntry.Stage.ModelTraining, "Internal error doing hyperparameter search");
                H2OIllegalArgumentException failure = new H2OIllegalArgumentException("Hyperparameter search can't create a new instance of Model.Parameters subclass: " + parameters.getClass());
                // The catch covers the whole setup, so keep the real cause available to callers.
                failure.initCause(e);
                throw failure;
            }
        }
    }

    /**
     * A step whose job is a single model build.
     *
     * @param <M> the model type produced by this step
     */
    public static abstract class ModelStep<M extends Model> extends ModelingStep<M> {
        /** Default weight constant for model steps (not referenced elsewhere in this file). */
        public static final int DEFAULT_MODEL_TRAINING_WEIGHT = 10;

        /**
         * @param algo    algorithm this step uses
         * @param str     step id
         * @param i       step weight
         * @param autoML  owning AutoML instance
         */
        public ModelStep(Algo algo, String str, int i, AutoML autoML) {
            super(algo, str, i, autoML);
        }

        // Declared public: the base declaration is public and an override may not narrow access.
        @Override
        public abstract Job<M> startJob();

        /** Builds a {@code ModelBuild} work item from this step's id, algo and weight. */
        @Override
        public WorkAllocations.Work makeWork() {
            return new WorkAllocations.Work(this._id, this._algo, WorkAllocations.JobType.ModelBuild, this._weight);
        }

        /** Looks up the allocation registered for this step's id and algo. */
        @Override
        public WorkAllocations.Work getAllocatedWork() {
            return getWorkAllocations().getAllocation(this._id, this._algo);
        }

        /** Delegates model key creation to the owning AutoML instance. */
        @Override
        public Key<M> makeKey(String str, boolean z) {
            return (Key<M>) aml().modelKey(str, z);
        }

        /**
         * Trains a model with a key generated from the algo name.
         *
         * @param parameters model parameters; mutated in place
         * @return the training job, or {@code null} if training was skipped
         */
        public Job<M> trainModel(Model.Parameters parameters) {
            return trainModel(null, parameters);
        }

        /**
         * Configures {@code parameters} from the AutoML build spec, caps the runtime by the time
         * remaining in the AutoML run (unless constraints are ignored), and starts training.
         *
         * @param key        model key to use, or {@code null} to generate one from the algo name
         * @param parameters model parameters; mutated in place
         * @return the training job, or {@code null} when starting it raised an
         *         {@link H2OIllegalArgumentException} (a warning is written to the event log)
         */
        public Job<M> trainModel(Key<M> key, Model.Parameters parameters) {
            String algoName = ModelBuilder.algoName(this._algo.urlName());
            if (null == key) {
                key = makeKey(algoName, true);
            }
            ModelBuilder builder = ModelBuilder.make(this._algo.urlName(), new Job(key, ModelBuilder.javaName(this._algo.urlName()), this._description), key);
            // The builder's own parameters serve as the defaults for the stopping-criteria comparison.
            Model.Parameters defaults = builder._parms;
            builder._parms = parameters;
            setCommonModelBuilderParams(builder._parms);
            setStoppingCriteria(builder._parms, defaults, true);
            setCustomParams(builder._parms);

            if (this._ignoreConstraints) {
                builder._parms._max_runtime_secs = 0.0d;
            } else if (builder._parms._max_runtime_secs == 0.0d) {
                builder._parms._max_runtime_secs = aml().timeRemainingMs() / 1000.0d;
            } else {
                builder._parms._max_runtime_secs = Math.min(builder._parms._max_runtime_secs, aml().timeRemainingMs() / 1000.0d);
            }

            aml().eventLog().info(EventLogEntry.Stage.ModelTraining, "AutoML: starting " + key + " model training");
            builder.init(false);
            Log.debug(new Object[]{"Training model: " + algoName + ", time remaining (ms): " + aml().timeRemainingMs()});
            try {
                return builder.trainModelOnH2ONode();
            } catch (H2OIllegalArgumentException e) {
                aml().eventLog().warn(EventLogEntry.Stage.ModelTraining, "Skipping training of model " + key + " due to exception: " + e);
                return null;
            }
        }
    }

    /**
     * @param algo    algorithm this step uses
     * @param str     step id
     * @param i       step weight
     * @param autoML  owning AutoML instance
     */
    protected ModelingStep(Algo algo, String str, int i, AutoML autoML) {
        this._algo = algo;
        this._id = str;
        this._weight = i;
        this._aml = autoML;
        this._description = algo.name() + " " + str;
    }

    /** @return the work allocated to this step; {@link #canRun()} treats {@code null} as "cannot run" */
    public abstract WorkAllocations.Work getAllocatedWork();

    /**
     * Creates the key for this step's result.
     *
     * @param str name passed through to the key factory
     * @param z   flag passed through to the key factory (meaning defined by the AutoML key methods)
     */
    protected abstract Key makeKey(String str, boolean z);

    /** @return a new work item describing this step */
    public abstract WorkAllocations.Work makeWork();

    /** @return the job for this step, as produced by the concrete subclass */
    public abstract Job startJob();

    /** @return the owning AutoML instance */
    public AutoML aml() {
        return this._aml;
    }

    /** @return {@code true} when this step has allocated work */
    public boolean canRun() {
        return getAllocatedWork() != null;
    }

    /** @return the work allocations of the owning AutoML instance */
    protected WorkAllocations getWorkAllocations() {
        return aml()._workAllocations;
    }

    /** @return the models currently on the AutoML leaderboard */
    public Model[] getTrainedModels() {
        return aml().leaderboard().getModels();
    }

    /** @return whether the owning AutoML instance reports cross-validation as enabled */
    public boolean isCVEnabled() {
        return aml().isCVEnabled();
    }

    /**
     * Copies frames, columns, cross-validation and class-balancing settings from the AutoML
     * instance and its build spec onto {@code parameters}.
     *
     * @param parameters model parameters to populate; mutated in place
     */
    void setCommonModelBuilderParams(Model.Parameters parameters) {
        parameters._train = aml()._trainingFrame._key;
        if (null != aml()._validationFrame) {
            parameters._valid = aml()._validationFrame._key;
        }
        AutoMLBuildSpec buildSpec = aml().getBuildSpec();
        parameters._response_column = buildSpec.input_spec.response_column;
        parameters._ignored_columns = buildSpec.input_spec.ignored_columns;

        // Stacked ensembles skip the fold, weights and class-balancing settings.
        if (!(parameters instanceof StackedEnsembleModel.StackedEnsembleParameters)) {
            // Without a blending frame, CV predictions are always kept regardless of the build spec.
            parameters._keep_cross_validation_predictions = aml().getBlendingFrame() == null || buildSpec.build_control.keep_cross_validation_predictions;
            parameters._fold_column = buildSpec.input_spec.fold_column;
            parameters._weights_column = buildSpec.input_spec.weights_column;
            if (buildSpec.input_spec.fold_column == null) {
                parameters._nfolds = buildSpec.build_control.nfolds;
                if (buildSpec.build_control.nfolds > 1) {
                    parameters._fold_assignment = Model.Parameters.FoldAssignmentScheme.Modulo;
                }
            }
            if (buildSpec.build_control.balance_classes) {
                parameters._balance_classes = buildSpec.build_control.balance_classes;
                parameters._class_sampling_factors = buildSpec.build_control.class_sampling_factors;
                parameters._max_after_balance_size = buildSpec.build_control.max_after_balance_size;
            }
        }

        parameters._keep_cross_validation_models = buildSpec.build_control.keep_cross_validation_models;
        parameters._keep_cross_validation_fold_assignment = buildSpec.build_control.nfolds != 0 && buildSpec.build_control.keep_cross_validation_fold_assignment;
        parameters._export_checkpoints_dir = buildSpec.build_control.export_checkpoints_dir;
    }

    /**
     * Applies the build spec's custom algo parameters for this step's algo, if any are configured.
     *
     * @param parameters model parameters to adjust; mutated in place
     */
    void setCustomParams(Model.Parameters parameters) {
        AutoMLBuildSpec.AutoMLCustomParameters customParameters = aml().getBuildSpec().build_models.algo_parameters;
        if (customParameters != null) {
            customParameters.applyCustomParameters(this._algo, parameters);
        }
    }

    /**
     * Applies the build spec's stopping criteria to {@code parameters}. The per-model max runtime
     * is always overwritten; seed, stopping metric, rounds and tolerance are overwritten only
     * when they still equal the corresponding value in {@code parameters2}.
     *
     * @param parameters  model parameters to adjust; mutated in place
     * @param parameters2 reference parameters whose values mark a field as "not customised"
     * @param z           when {@code true}, the seed may be replaced by the build spec seed plus
     *                    a counter of individually trained models
     */
    void setStoppingCriteria(Model.Parameters parameters, Model.Parameters parameters2, boolean z) {
        AutoMLBuildSpec buildSpec = aml().getBuildSpec();
        parameters._max_runtime_secs = buildSpec.build_control.stopping_criteria.max_runtime_secs_per_model();

        // A build-spec seed of -1 leaves the model's seed untouched.
        if (z && parameters._seed == parameters2._seed && buildSpec.build_control.stopping_criteria.seed() != -1) {
            parameters._seed = buildSpec.build_control.stopping_criteria.seed() + aml().individualModelsTrained.getAndIncrement();
        }

        if (parameters._stopping_metric == parameters2._stopping_metric) {
            parameters._stopping_metric = buildSpec.build_control.stopping_criteria.stopping_metric();
        }
        // Still AUTO: derive the metric from the leaderboard's sort metric, using logloss for "auc".
        if (parameters._stopping_metric == ScoreKeeper.StoppingMetric.AUTO) {
            String sortMetric = getSortMetric();
            if (sortMetric == null) {
                parameters._stopping_metric = ScoreKeeper.StoppingMetric.AUTO;
            } else if ("auc".equals(sortMetric)) {
                parameters._stopping_metric = ScoreKeeper.StoppingMetric.logloss;
            } else {
                parameters._stopping_metric = metricValueOf(sortMetric);
            }
        }

        if (parameters._stopping_rounds == parameters2._stopping_rounds) {
            parameters._stopping_rounds = buildSpec.build_control.stopping_criteria.stopping_rounds();
        }
        if (parameters._stopping_tolerance == parameters2._stopping_tolerance) {
            parameters._stopping_tolerance = buildSpec.build_control.stopping_criteria.stopping_tolerance();
        }
    }

    /** @return the leaderboard's sort metric, or {@code null} when there is no leaderboard */
    private String getSortMetric() {
        Leaderboard leaderboard = aml().leaderboard();
        if (leaderboard == null) {
            return null;
        }
        return leaderboard.getSortMetric();
    }

    /**
     * Maps a metric name to a {@link ScoreKeeper.StoppingMetric}.
     *
     * @param str metric name; may be {@code null}
     * @return {@code deviance} for {@code "mean_residual_deviance"}, otherwise the constant
     *         resolved by {@link EnumUtils#valueOf}; {@code AUTO} when {@code str} is
     *         {@code null} or cannot be resolved
     */
    private static ScoreKeeper.StoppingMetric metricValueOf(String str) {
        if (str == null) {
            return ScoreKeeper.StoppingMetric.AUTO;
        }
        if ("mean_residual_deviance".equals(str)) {
            return ScoreKeeper.StoppingMetric.deviance;
        }
        try {
            return EnumUtils.valueOf(ScoreKeeper.StoppingMetric.class, str);
        } catch (IllegalArgumentException ignored) {
            // Unknown metric name: fall back to AUTO.
            return ScoreKeeper.StoppingMetric.AUTO;
        }
    }
}
