package ai.h2o.automl;

import ai.h2o.automl.AutoML;
import ai.h2o.automl.AutoMLBuildSpec;
import ai.h2o.automl.ModelSelectionStrategies;
import ai.h2o.automl.ModelSelectionStrategy;
import ai.h2o.automl.StepResultState;
import ai.h2o.automl.WorkAllocations;
import ai.h2o.automl.events.EventLog;
import ai.h2o.automl.events.EventLogEntry;
import ai.h2o.automl.preprocessing.PreprocessingConfig;
import ai.h2o.automl.preprocessing.PreprocessingStep;
import hex.Model;
import hex.ModelBuilder;
import hex.ModelContainer;
import hex.ScoreKeeper;
import hex.genmodel.utils.DistributionFamily;
import hex.grid.Grid;
import hex.grid.GridSearch;
import hex.grid.HyperSpaceSearchCriteria;
import hex.grid.HyperSpaceWalker;
import hex.leaderboard.Leaderboard;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.Date;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.function.Consumer;
import java.util.function.Predicate;
import jsr166y.CountedCompleter;
import org.apache.commons.lang.builder.ToStringBuilder;
import org.apache.commons.math3.optimization.direct.CMAESOptimizer;
import water.Futures;
import water.H2O;
import water.Iced;
import water.Job;
import water.Key;
import water.Keyed;
import water.exceptions.H2OIllegalArgumentException;
import water.util.ArrayUtils;
import water.util.Countdown;
import water.util.EnumUtils;
import water.util.Log;

/* loaded from: input_file:ai/h2o/automl/ModelingStep.class */
public abstract class ModelingStep<M extends Model> extends Iced<ModelingStep> {
    // Shared filters over WorkAllocations.Work. They have no initializer here.
    // NOTE(review): presumably assigned in a static initializer outside this view — confirm.
    static Predicate<WorkAllocations.Work> isDefaultModel;
    // Combined with a "not StackedEnsemble" filter in GridStep.setSearchCriteria to size max_models.
    static Predicate<WorkAllocations.Work> isExplorationWork;
    static Predicate<WorkAllocations.Work> isExploitationWork;
    // Owning AutoML run, returned by aml(); transient, so it is not meant to be serialized with the step.
    private final transient AutoML _aml;
    // Algorithm trained by this step (a VirtualAlgo for DynamicStep).
    protected final IAlgo _algo;
    // Provider name; together with _id it forms getGlobalId() and drives equals()/hashCode().
    protected final String _provider;
    // Step identifier within its provider.
    protected final String _id;
    // Relative training weight, reported through getWeight()/makeWork().
    protected int _weight;
    // Priority group; compared by _isSamePriorityGroup when splitting the remaining time budget.
    protected int _priorityGroup;
    // Human-readable description ("<provider> <id>"), used as the Job description.
    protected String _description;
    // Lazily cached allocation, filled by getAllocatedWork().
    protected WorkAllocations.Work _work;
    // NOTE(review): presumably the StepDefinition this step was created from — confirm with its writer.
    StepDefinition _fromDef;
    // Decompiler artifact mirroring javac's assertion flag; read by startSearch/startModel/constructor.
    // NOTE(review): initialized outside this view — confirm.
    static final /* synthetic */ boolean $assertionsDisabled;
    // Constraints this step opts out of; checked by ignores().
    protected AutoML.Constraint[] _ignoredConstraints = new AutoML.Constraint[0];
    // Callbacks invoked once by onDone(Job), then cleared.
    private final transient List<Consumer<Job>> _onDone = new ArrayList();
    // Matches work items that belong to the same priority group as this step.
    final transient Predicate<WorkAllocations.Work> _isSamePriorityGroup = work -> {
        return work._priorityGroup == this._priorityGroup;
    };

    /**
     * A step that does not train anything itself: it only exposes a lazily built collection of
     * sub-steps (see {@link #prepareModelingSteps()}). {@link #canRun()} always returns
     * {@code false} and {@link #startJob()} returns {@code null}.
     */
    public static abstract class DynamicStep<M extends Model> extends ModelingStep<M> {
        public static final int DEFAULT_DYNAMIC_TRAINING_WEIGHT = 20;
        public static final int DEFAULT_DYNAMIC_GROUP = 100;

        /** Sub-steps, created on first access by {@link #initSubSteps()}. */
        private transient Collection<ModelingStep> _subSteps;

        /** Placeholder algorithm for dynamic steps; its {@code name()} is {@code "virtual"}. */
        public static class VirtualAlgo implements IAlgo {
            @Override
            public String name() {
                return "virtual";
            }
        }

        /** Creates a dynamic step in {@link #DEFAULT_DYNAMIC_GROUP} with {@link #DEFAULT_DYNAMIC_TRAINING_WEIGHT}. */
        public DynamicStep(String provider, String id, AutoML autoML) {
            this(provider, id, DEFAULT_DYNAMIC_GROUP, DEFAULT_DYNAMIC_TRAINING_WEIGHT, autoML);
        }

        /** Creates a dynamic step backed by a {@link VirtualAlgo}. */
        public DynamicStep(String provider, String id, int priorityGroup, int weight, AutoML autoML) {
            super(provider, new VirtualAlgo(), id, priorityGroup, weight, autoML);
        }

        /** Dynamic steps never run directly; only their sub-steps do. */
        @Override
        public boolean canRun() {
            return false;
        }

        @Override
        protected Job<M> startJob() {
            return null;
        }

        @Override
        protected WorkAllocations.JobType getJobType() {
            return WorkAllocations.JobType.Dynamic;
        }

        /**
         * Builds a key of type {@code "decision"} through {@code AutoML.makeKey}.
         * The boolean is passed through unchanged (presumably a uniqueness-counter flag — confirm).
         *
         * <p>Must be {@code public}: the base-class {@code makeKey} is public, and an override may not
         * weaken access (the previous {@code protected} modifier did not compile).
         */
        @Override
        public Key<Models> makeKey(String name, boolean withCounter) {
            return aml().makeKey(name, "decision", withCounter);
        }

        /** Builds the sub-steps once, on first use. */
        private void initSubSteps() {
            if (_subSteps == null) {
                _subSteps = prepareModelingSteps();
            }
        }

        @Override
        public Iterator<? extends ModelingStep> iterateSubSteps() {
            initSubSteps();
            return _subSteps.iterator();
        }

        /**
         * Finds the sub-step whose {@code _id} equals {@code id}.
         * Public for the same reason as {@link #makeKey}: the base method is public.
         */
        @Override
        public Optional<? extends ModelingStep> getSubStep(String id) {
            initSubSteps();
            return _subSteps.stream().filter(step -> step._id.equals(id)).findFirst();
        }

        /** Creates the sub-steps exposed by this dynamic step. */
        protected abstract Collection<ModelingStep> prepareModelingSteps();
    }

    /* loaded from: input_file:ai/h2o/automl/ModelingStep$GridStep.class */
    public static abstract class GridStep<M extends Model> extends ModelingStep<M> {
        public static final int DEFAULT_GRID_TRAINING_WEIGHT = 30;
        public static final int DEFAULT_GRID_GROUP = 2;
        protected static final int GRID_STOPPING_ROUND_FACTOR = 2;

        public GridStep(String str, IAlgo iAlgo, String str2, AutoML autoML) {
            this(str, iAlgo, str2, 2, 30, autoML);
        }

        public GridStep(String str, IAlgo iAlgo, String str2, int i, int i2, AutoML autoML) {
            super(str, iAlgo, str2, i, i2, autoML);
        }

        @Override // ai.h2o.automl.ModelingStep
        protected WorkAllocations.JobType getJobType() {
            return WorkAllocations.JobType.HyperparamSearch;
        }

        @Override // ai.h2o.automl.ModelingStep
        public boolean isResumable() {
            return true;
        }

        public abstract Model.Parameters prepareModelParameters();

        public abstract Map<String, Object[]> prepareSearchParameters();

        @Override // ai.h2o.automl.ModelingStep
        protected Job<Grid> startJob() {
            return hyperparameterSearch(prepareModelParameters(), prepareSearchParameters());
        }

        @Override // ai.h2o.automl.ModelingStep
        protected Key<Grid> makeKey(String str, boolean z) {
            return aml().makeKey(str, "grid", z);
        }

        protected Job<Grid> hyperparameterSearch(Model.Parameters parameters, Map<String, Object[]> map) {
            return hyperparameterSearch(null, parameters, map);
        }

        /* JADX INFO: Access modifiers changed from: protected */
        public Job<Grid> hyperparameterSearch(Key<Grid> key, Model.Parameters parameters, Map<String, Object[]> map) {
            try {
                Model.Parameters parameters2 = (Model.Parameters) parameters.getClass().newInstance();
                initTimeConstraints(parameters, CMAESOptimizer.DEFAULT_STOPFITNESS);
                setCommonModelBuilderParams(parameters);
                setStoppingCriteria(parameters, parameters2);
                setCustomParams(parameters);
                setDistributionParameters(parameters);
                HyperSpaceSearchCriteria.RandomDiscreteValueSearchCriteria randomDiscreteValueSearchCriteria = (HyperSpaceSearchCriteria.RandomDiscreteValueSearchCriteria) aml().getBuildSpec().build_control.stopping_criteria.getSearchCriteria().m2203clone();
                setSearchCriteria(randomDiscreteValueSearchCriteria, parameters);
                if (null == key) {
                    key = makeKey(this._provider, true);
                }
                aml().trackKeys(key);
                Log.debug("Hyperparameter search: " + this._provider + ", time remaining (ms): " + aml().timeRemainingMs());
                aml().eventLog().debug(EventLogEntry.Stage.ModelTraining, randomDiscreteValueSearchCriteria.max_runtime_secs() == CMAESOptimizer.DEFAULT_STOPFITNESS ? "No time limitation for " + key : "Time assigned for " + key + ": " + randomDiscreteValueSearchCriteria.max_runtime_secs() + "s");
                return startSearch(key, parameters, map, randomDiscreteValueSearchCriteria);
            } catch (Exception e) {
                aml().eventLog().error(EventLogEntry.Stage.ModelTraining, "Internal error doing hyperparameter search");
                throw new H2OIllegalArgumentException("Hyperparameter search can't create a new instance of Model.Parameters subclass: " + parameters.getClass());
            }
        }

        /* JADX INFO: Access modifiers changed from: protected */
        public void setSearchCriteria(HyperSpaceSearchCriteria.RandomDiscreteValueSearchCriteria randomDiscreteValueSearchCriteria, Model.Parameters parameters) {
            WorkAllocations.Work allocatedWork = getAllocatedWork();
            double timeRemainingMs = limitModelTrainingTime() ? (((float) aml().timeRemainingMs()) * getWorkAllocations().remainingWorkRatio(allocatedWork, this._isSamePriorityGroup)) / 1000.0d : CMAESOptimizer.DEFAULT_STOPFITNESS;
            int ceil = (int) Math.ceil(aml().remainingModels() * getWorkAllocations().remainingWorkRatio(allocatedWork, isExplorationWork.and(work -> {
                return work._algo != Algo.StackedEnsemble;
            })));
            randomDiscreteValueSearchCriteria.set_max_runtime_secs(randomDiscreteValueSearchCriteria.max_runtime_secs() == CMAESOptimizer.DEFAULT_STOPFITNESS ? timeRemainingMs : Math.min(randomDiscreteValueSearchCriteria.max_runtime_secs(), timeRemainingMs));
            randomDiscreteValueSearchCriteria.set_max_models(randomDiscreteValueSearchCriteria.max_models() == 0 ? ceil : Math.min(randomDiscreteValueSearchCriteria.max_models(), ceil));
            randomDiscreteValueSearchCriteria.set_stopping_rounds(parameters._stopping_rounds * 2);
        }
    }

    /** A step that trains a single model of {@code _algo} with {@link #prepareModelParameters()}. */
    public static abstract class ModelStep<M extends Model> extends ModelingStep<M> {
        public static final int DEFAULT_MODEL_TRAINING_WEIGHT = 10;
        public static final int DEFAULT_MODEL_GROUP = 1;

        /** Creates a model step in {@link #DEFAULT_MODEL_GROUP} with {@link #DEFAULT_MODEL_TRAINING_WEIGHT}. */
        public ModelStep(String provider, IAlgo algo, String id, AutoML autoML) {
            this(provider, algo, id, DEFAULT_MODEL_GROUP, DEFAULT_MODEL_TRAINING_WEIGHT, autoML);
        }

        public ModelStep(String provider, IAlgo algo, String id, int priorityGroup, int weight, AutoML autoML) {
            super(provider, algo, id, priorityGroup, weight, autoML);
        }

        @Override
        protected WorkAllocations.JobType getJobType() {
            return WorkAllocations.JobType.ModelBuild;
        }

        public abstract Model.Parameters prepareModelParameters();

        @Override
        protected Job<M> startJob() {
            return trainModel(prepareModelParameters());
        }

        /** Trains a model under a freshly generated key. */
        protected Job<M> trainModel(Model.Parameters parameters) {
            return trainModel(null, parameters);
        }

        /**
         * Applies the AutoML-wide settings to {@code parameters}, assigns this step's share of the
         * remaining time budget, and starts training.
         *
         * @param key        result key, or {@code null} to create one from the algo name
         * @param parameters model parameters; mutated in place
         * @return the started training job
         */
        public Job<M> trainModel(Key<M> key, Model.Parameters parameters) {
            String algoName = ModelBuilder.algoName(_algo.urlName());
            if (key == null) {
                key = makeKey(algoName, true);
            }
            // Parameters of a freshly made builder, used as the reference defaults below.
            // (The decompiled code declared this with an undefined type `P`.)
            Model.Parameters defaults = ModelBuilder.make(_algo.urlName(), null, null)._parms;
            initTimeConstraints(parameters, 0.0);
            setCommonModelBuilderParams(parameters);
            setSeed(parameters, defaults, SeedPolicy.Incremental);
            setStoppingCriteria(parameters, defaults);
            setCustomParams(parameters);
            setDistributionParameters(parameters);

            // A max_runtime_secs of 0 means "no time limit"; otherwise keep the tighter of the
            // user limit and this step's share of the remaining time.
            if (limitModelTrainingTime()) {
                double budgetSecs = (((float) aml().timeRemainingMs()) * getWorkAllocations().remainingWorkRatio(getAllocatedWork(), _isSamePriorityGroup)) / 1000.0d;
                parameters._max_runtime_secs = parameters._max_runtime_secs == 0.0
                        ? budgetSecs
                        : Math.min(parameters._max_runtime_secs, budgetSecs);
            } else {
                parameters._max_runtime_secs = 0.0;
            }

            Log.debug("Training model: " + algoName + ", time remaining (ms): " + aml().timeRemainingMs());
            aml().eventLog().debug(EventLogEntry.Stage.ModelTraining, parameters._max_runtime_secs == 0.0
                    ? "No time limitation for " + key
                    : "Time assigned for " + key + ": " + parameters._max_runtime_secs + "s");
            return startModel(key, parameters);
        }
    }

    /**
     * Seed handling policy handed to {@code setSeed} ({@code ModelStep} uses {@link #Incremental}).
     * NOTE(review): the names suggest no seed / shared global seed / per-model incremented seed —
     * confirm against {@code setSeed}. Declaration order is preserved so ordinals do not change.
     */
    public enum SeedPolicy { None, Global, Incremental }

    /**
     * A step that trains candidate models into a step-local leaderboard. A
     * {@link ModelSelectionStrategy} then decides which of them are added to, and which models are
     * removed from, the AutoML leaderboard.
     */
    public static abstract class SelectionStep<M extends Model> extends ModelingStep<M> {
        public static final int DEFAULT_SELECTION_TRAINING_WEIGHT = 20;
        public static final int DEFAULT_SELECTION_GROUP = 3;

        /** Creates a selection step in {@link #DEFAULT_SELECTION_GROUP} with {@link #DEFAULT_SELECTION_TRAINING_WEIGHT}. */
        public SelectionStep(String provider, IAlgo algo, String id, AutoML autoML) {
            this(provider, algo, id, DEFAULT_SELECTION_GROUP, DEFAULT_SELECTION_TRAINING_WEIGHT, autoML);
        }

        public SelectionStep(String provider, IAlgo algo, String id, int priorityGroup, int weight, AutoML autoML) {
            super(provider, algo, id, priorityGroup, weight, autoML);
        }

        @Override
        protected WorkAllocations.JobType getJobType() {
            return WorkAllocations.JobType.Selection;
        }

        /**
         * Builds a key of type {@code "selection"} through {@code AutoML.makeKey}.
         * Public because the base-class {@code makeKey} is public and an override may not weaken
         * access (the previous {@code protected} modifier did not compile).
         */
        @Override
        public Key<Models> makeKey(String name, boolean withCounter) {
            return aml().makeKey(name, "selection", withCounter);
        }

        /**
         * Creates (or fetches) a leaderboard named {@code name}. It shares the AutoML leaderboard's
         * frame and sort metric.
         *
         * @param eventLog log to attach; if {@code null}, one is created under {@code name} and is
         *                 removed by the holder's {@code cleanup()}
         */
        public ModelSelectionStrategies.LeaderboardHolder makeLeaderboard(String name, final EventLog eventLog) {
            Leaderboard amlLeaderboard = aml().leaderboard();
            final EventLog leaderboardLog = eventLog == null ? EventLog.getOrMake(Key.make(name)) : eventLog;
            final Leaderboard leaderboard = Leaderboard.getOrMake(name, leaderboardLog.asLogger(EventLogEntry.Stage.ModelTraining), amlLeaderboard.leaderboardFrame(), amlLeaderboard.getSortMetric());
            return new ModelSelectionStrategies.LeaderboardHolder() {
                @Override
                public Leaderboard get() {
                    return leaderboard;
                }

                @Override
                public void cleanup() {
                    leaderboard.removeModels(leaderboard.getModelKeys(), false);
                    leaderboard.remove(false);
                    // Only remove the event log if this holder created it.
                    if (eventLog == null) {
                        leaderboardLog.remove();
                    }
                }
            };
        }

        /** Creates a temporary leaderboard named {@code "tmp_" + name} with its own event log. */
        public ModelSelectionStrategies.LeaderboardHolder makeTmpLeaderboard(String name) {
            return makeLeaderboard("tmp_" + name, null);
        }

        @Override
        protected Job<Models> startJob() {
            final Key<Model>[] trainedModelsKeys = getTrainedModelsKeys();
            final Key<Models> resultKey = makeKey(_provider + "_" + _id, false);
            aml().trackKeys(resultKey);
            final Job job = new Job(resultKey, Models.class.getName(), _description);
            WorkAllocations.Work work = getAllocatedWork();
            // 0 means "no time limit".
            final double budgetSecs = limitModelTrainingTime()
                    ? (((float) aml().timeRemainingMs()) * getWorkAllocations().remainingWorkRatio(work)) / 1000.0d
                    : 0.0;
            aml().eventLog().debug(EventLogEntry.Stage.ModelTraining, budgetSecs == 0.0
                    ? "No time limitation for " + resultKey
                    : "Time assigned for " + resultKey + ": " + budgetSecs + "s");
            return job.start(new H2O.H2OCountedCompleter() {
                final Models result;
                final Key<Models> selectionKey;
                final EventLog selectionEventLog;
                final ModelSelectionStrategies.LeaderboardHolder selectionLeaderboard;

                {
                    result = new Models(resultKey, Model.class, job);
                    selectionKey = Key.make(resultKey + "_select");
                    selectionEventLog = EventLog.getOrMake(selectionKey);
                    selectionLeaderboard = SelectionStep.this.makeLeaderboard(selectionKey.toString(), selectionEventLog);
                    result.delete_and_lock(job);
                }

                @Override
                public void compute2() {
                    ModelSelectionStrategy.Selection selection = null;
                    try {
                        ModelingStepsExecutor executor = new ModelingStepsExecutor(selectionLeaderboard.get(), selectionEventLog, Countdown.fromSeconds(budgetSecs));
                        executor.start();
                        StepResultState state = executor.monitor(SelectionStep.this.startTraining(selectionKey, budgetSecs), SelectionStep.this, job);
                        if (state.is(StepResultState.ResultStatus.success)) {
                            Log.debug("Selection leaderboard " + selectionLeaderboard.get()._key, selectionLeaderboard.get().toLogString());
                            selection = SelectionStep.this.getSelectionStrategy().select(trainedModelsKeys, selectionLeaderboard.get().getModelKeys());
                            Leaderboard leaderboard = SelectionStep.this.aml().leaderboard();
                            Log.debug("Selection result for job " + resultKey, ToStringBuilder.reflectionToString(selection));
                            leaderboard.removeModels(selection._remove, false);
                            SelectionStep.this.aml().trackKeys(selection._remove);
                            leaderboard.addModels(selection._add);
                        } else if (state.is(StepResultState.ResultStatus.failed)) {
                            throw toUnchecked(state.error());
                        } else if (state.is(StepResultState.ResultStatus.cancelled)) {
                            throw new Job.JobCancelledException();
                        }
                    } finally {
                        // Always release the result, and record any selected models even if a
                        // later step failed (the decompiled catch path skipped this).
                        result.unlock(job);
                        if (selection != null) {
                            result.addModels(selection._add);
                        }
                    }
                    tryComplete();
                }

                @Override
                public void onCompletion(CountedCompleter caller) {
                    Keyed.remove(selectionKey, new Futures(), false);
                    // Detach the pre-existing models (removeModels with false), then drop (true)
                    // candidates that did not make it into the result.
                    selectionLeaderboard.get().removeModels(trainedModelsKeys, false);
                    Key[] rejected = Arrays.stream(selectionLeaderboard.get().getModelKeys())
                            .filter(candidate -> !ArrayUtils.contains(result.getModelKeys(), candidate))
                            .toArray(Key[]::new);
                    selectionLeaderboard.get().removeModels(rejected, true);
                    selectionLeaderboard.cleanup();
                    if (!SelectionStep.this.aml().eventLog()._key.equals(selectionEventLog._key)) {
                        selectionEventLog.remove();
                    }
                    super.onCompletion(caller);
                }

                @Override
                public boolean onExceptionalCompletion(Throwable ex, CountedCompleter caller) {
                    result.unlock(job._key, false);
                    Keyed.remove(selectionKey);
                    selectionLeaderboard.get().remove();
                    if (!SelectionStep.this.aml().eventLog()._key.equals(selectionEventLog._key)) {
                        selectionEventLog.remove();
                    }
                    return super.onExceptionalCompletion(ex, caller);
                }
            }, work._weight, budgetSecs);
        }

        /** Starts the job that trains candidate models, with result key {@code key} and a budget in seconds (0 = unlimited). */
        protected abstract Job<Models> startTraining(Key<Models> key, double maxRuntimeSecs);

        /** Returns the strategy that picks which candidates enter the AutoML leaderboard. */
        protected abstract ModelSelectionStrategy getSelectionStrategy();

        /**
         * Wraps {@code job} in a job producing a {@link Models} under {@code key}. A single
         * {@link Model} result is added as-is; a {@link ModelContainer} contributes its models and
         * is then removed. A {@code null} result is accepted only if a stop was requested.
         */
        public Job<Models> asModelsJob(final Job job, final Key<Models> key) {
            final Job modelsJob = new Job(key, Models.class.getName(), job._description);
            return modelsJob.start(new H2O.H2OCountedCompleter() {
                final Models models;

                {
                    models = new Models(key, Model.class, modelsJob);
                    models.delete_and_lock(modelsJob);
                }

                @Override
                public void compute2() {
                    Keyed produced;
                    try {
                        ModelingStepsExecutor.ensureStopRequestPropagated(job, modelsJob, 1000);
                        produced = job.get();
                    } finally {
                        // Release the lock even if the wrapped job failed.
                        models.unlock(modelsJob);
                    }
                    if (produced instanceof Model) {
                        models.addModel(produced.getKey());
                    } else if (produced instanceof ModelContainer) {
                        models.addModels(((ModelContainer) produced).getModelKeys());
                        produced.remove(false);
                    } else if (produced != null || !modelsJob.stop_requested()) {
                        throw new H2OIllegalArgumentException("Can only convert jobs producing a single Model or ModelContainer.");
                    }
                    tryComplete();
                }
            }, job._work, job._max_runtime_msecs);
        }

        /**
         * Turns a failure reported by a {@link StepResultState} into a throwable usable from
         * {@code compute2()}. Runtime exceptions are returned as-is and errors are rethrown; anything
         * else is wrapped, rather than failing with a ClassCastException as a blind cast would.
         */
        private static RuntimeException toUnchecked(Throwable error) {
            if (error instanceof RuntimeException) {
                return (RuntimeException) error;
            }
            if (error instanceof Error) {
                throw (Error) error;
            }
            return new RuntimeException(error);
        }
    }

    /**
     * Applies preprocessing to the base parameters, logs the start event, and launches a grid search
     * with parallelism 1 that stops after {@code _maxConsecutiveModelFailures} consecutive failures.
     *
     * @param key                      grid result key (must not be null)
     * @param mp                       base model parameters (must not be null)
     * @param map                      non-empty hyperparameter space
     * @param hyperSpaceSearchCriteria search criteria (must not be null)
     * @return the started grid-search job
     */
    protected <MP extends Model.Parameters> Job<Grid> startSearch(Key<Grid> key, MP mp, Map<String, Object[]> map, HyperSpaceSearchCriteria hyperSpaceSearchCriteria) {
        // Decompiled form of four `assert` preconditions, checked in the original order.
        boolean invalid = key == null || mp == null || map.size() <= 0 || hyperSpaceSearchCriteria == null;
        if (!$assertionsDisabled && invalid) {
            throw new AssertionError();
        }
        applyPreprocessing(mp);
        String startEventName = "start_" + _provider + "_" + _id;
        String message = "AutoML: starting " + key + " hyperparameter search";
        aml().eventLog()
             .info(EventLogEntry.Stage.ModelTraining, message)
             .setNamedValue(startEventName, new Date(), EventLogEntry.epochFormat.get());
        return GridSearch
                .create(key, HyperSpaceWalker.BaseWalker.WalkerFactory.create(mp, map, new GridSearch.SimpleParametersBuilderFactory(), hyperSpaceSearchCriteria))
                .withParallelism(1)
                .withMaxConsecutiveFailures(aml()._maxConsecutiveModelFailures)
                .start();
    }

    /**
     * Applies preprocessing, creates a model builder for {@code _algo}, and logs the start event.
     * The builder's validation messages are forwarded to the AutoML event log before training
     * starts on the H2O node.
     *
     * @param key result model key (must not be null)
     * @param mp  model parameters (must not be null)
     * @return the training job
     */
    public <MP extends Model.Parameters> Job<M> startModel(Key<M> key, MP mp) {
        // Decompiled form of `assert` preconditions. The class declares `$assertionsDisabled`
        // explicitly, so a real `assert` here would likely clash with the compiler-synthesized field.
        if (!$assertionsDisabled && key == null) {
            throw new AssertionError();
        }
        // Was `mp == 0` (decompiler artifact that does not compile for a reference type).
        if (!$assertionsDisabled && mp == null) {
            throw new AssertionError();
        }
        Job job = new Job(key, ModelBuilder.javaName(_algo.urlName()), _description);
        applyPreprocessing(mp);
        // ModelBuilder.make expects a Key<Model>; Key<M> is not implicitly convertible.
        @SuppressWarnings("unchecked")
        Key<Model> modelKey = (Key<Model>) (Key) key;
        ModelBuilder builder = ModelBuilder.make(_algo.urlName(), job, modelKey);
        builder._parms = mp;
        aml().eventLog().info(EventLogEntry.Stage.ModelTraining, "AutoML: starting " + key + " model training").setNamedValue("start_" + _provider + "_" + _id, new Date(), EventLogEntry.epochFormat.get());
        builder.init(false);
        for (ModelBuilder.ValidationMessage message : builder._messages) {
            String text = message.field() + " param, " + message.message();
            if (message.log_level() == 2) {
                // log level 2 is reported as a warning
                aml().eventLog().warn(EventLogEntry.Stage.ModelTraining, text);
            } else if (message.log_level() == 1) {
                // log level 1 is reported as an error
                aml().eventLog().error(EventLogEntry.Stage.ModelTraining, text);
            }
        }
        return builder.trainModelOnH2ONode();
    }

    /**
     * Checks whether {@code parameters}, combined with the common AutoML settings, validate without
     * severity-1 (error) messages on any of {@code fieldNames}. Works on a clone, so
     * {@code parameters} is not modified.
     *
     * @return {@code false} if any listed field has an error, or if building the model builder
     *         throws {@link H2OIllegalArgumentException}
     */
    private boolean validParameters(Model.Parameters parameters, String[] fieldNames) {
        try {
            // Was `m2203clone()`, a decompiler-renamed alias of clone().
            Model.Parameters candidate = parameters.clone();
            setCommonModelBuilderParams(candidate);
            ModelBuilder builder = ModelBuilder.make(candidate);
            builder.init(false);
            return Arrays.stream(fieldNames)
                    .allMatch(field -> builder.getMessagesByFieldAndSeverity(field, (byte) 1).length == 0);
        } catch (H2OIllegalArgumentException e) {
            return false;
        }
    }

    /**
     * Applies the AutoML distribution family to {@code parameters}, copying the matching
     * family-specific setting (custom function, huber alpha, tweedie power, quantile alpha).
     * Falls back to {@code AUTO} when the algo rejects the family, and logs an info event when the
     * fallback happened.
     */
    protected void setDistributionParameters(Model.Parameters parameters) {
        // Read once and always through aml() (the original mixed aml() with direct _aml access).
        DistributionFamily requested = aml().getDistributionFamily();
        switch (requested) {
            case custom:
                parameters._custom_distribution_func = aml().getBuildSpec().build_control.custom_distribution_func;
                break;
            case huber:
                parameters._huber_alpha = aml().getBuildSpec().build_control.huber_alpha;
                break;
            case tweedie:
                parameters._tweedie_power = aml().getBuildSpec().build_control.tweedie_power;
                break;
            case quantile:
                parameters._quantile_alpha = aml().getBuildSpec().build_control.quantile_alpha;
                break;
            default:
                break;
        }
        try {
            parameters.setDistributionFamily(requested);
        } catch (H2OIllegalArgumentException e) {
            parameters.setDistributionFamily(DistributionFamily.AUTO);
        }
        // The algo may accept the setter but still report a validation error on these fields.
        if (!validParameters(parameters, new String[]{"_distribution", "_family"})) {
            parameters.setDistributionFamily(DistributionFamily.AUTO);
        }
        if (!requested.equals(parameters.getDistributionFamily())) {
            aml().eventLog().info(EventLogEntry.Stage.ModelTraining, "Algo " + parameters.algoName() + " doesn't support " + requested.name() + " distribution. Using AUTO distribution instead.");
        }
    }

    /**
     * @param provider      provider name (part of the global id)
     * @param algo          algorithm trained by this step
     * @param id            step id within the provider
     * @param priorityGroup non-negative priority group
     * @param weight        relative training weight
     * @param autoML        owning AutoML run
     */
    protected ModelingStep(String provider, IAlgo algo, String id, int priorityGroup, int weight, AutoML autoML) {
        // Decompiled form of `assert priorityGroup >= 0`.
        if (!$assertionsDisabled && priorityGroup < 0) {
            throw new AssertionError();
        }
        _aml = autoML;
        _provider = provider;
        _id = id;
        _algo = algo;
        _weight = weight;
        _priorityGroup = priorityGroup;
        _description = provider + " " + id;
    }

    /** @return the provider name this step was declared by */
    public String getProvider() {
        return _provider;
    }

    /** @return this step's id within its provider */
    public String getId() {
        return _id;
    }

    /** @return {@code "<provider>:<id>"} */
    public String getGlobalId() {
        return _provider + ":" + _id;
    }

    /** @return the algorithm trained by this step */
    public IAlgo getAlgo() {
        return _algo;
    }

    /** @return the relative training weight */
    public int getWeight() {
        return _weight;
    }

    /** @return the priority group */
    public int getPriorityGroup() {
        return _priorityGroup;
    }

    /** @return whether {@code run()} should register the job result as a resumable key; false by default */
    public boolean isResumable() {
        return false;
    }

    /** @return whether {@code constraint} is in this step's ignored constraints */
    public boolean ignores(AutoML.Constraint constraint) {
        return ArrayUtils.contains(_ignoredConstraints, constraint);
    }

    /**
     * @return true when the step does not ignore the TIMEOUT constraint and the build spec
     *         sets no {@code max_models} limit (0)
     */
    public boolean limitModelTrainingTime() {
        if (ignores(AutoML.Constraint.TIMEOUT)) {
            return false;
        }
        return aml().getBuildSpec().build_control.stopping_criteria.max_models() == 0;
    }

    /** @return true when this step has an allocation with a positive weight */
    public boolean canRun() {
        WorkAllocations.Work work = getAllocatedWork();
        if (work == null) {
            return false;
        }
        return work._weight > 0;
    }

    /**
     * Starts this step's job. If the job has a result key, the key is registered with the session
     * and, when {@link #isResumable()}, recorded as resumable.
     *
     * @return the started job, possibly {@code null}
     */
    public Job run() {
        Job job = startJob();
        if (job == null || job._result == null) {
            return job;
        }
        register(job._result);
        if (isResumable()) {
            aml().session().addResumableKey(job._result);
        }
        return job;
    }

    /** @return the sub-steps of this step; none by default */
    public Iterator<? extends ModelingStep> iterateSubSteps() {
        Iterator<ModelingStep> none = Collections.emptyIterator();
        return none;
    }

    /** @return the sub-step with the given id; always empty by default */
    public Optional<? extends ModelingStep> getSubStep(String str) {
        Optional<ModelingStep> absent = Optional.empty();
        return absent;
    }

    /** @return the kind of job this step produces */
    protected abstract WorkAllocations.JobType getJobType();

    /** Starts the job for this step; may return {@code null}. */
    protected abstract Job startJob();

    /** Invokes every registered completion callback with {@code job}, then forgets all of them. */
    public void onDone(Job job) {
        for (Consumer<Job> callback : _onDone) {
            callback.accept(job);
        }
        _onDone.clear();
    }

    /** Records this step as the source of {@code key} in the AutoML session. */
    public void register(Key key) {
        aml().session().registerKeySource(key, this);
    }

    /** @return the owning AutoML run */
    public AutoML aml() {
        return _aml;
    }

    /** @return this step's work allocation, looked up once by id and algo and then cached */
    public WorkAllocations.Work getAllocatedWork() {
        if (_work != null) {
            return _work;
        }
        _work = getWorkAllocations().getAllocation(_id, _algo);
        return _work;
    }

    /** @return a new work item describing this step (id, algo, job type, group, weight) */
    public WorkAllocations.Work makeWork() {
        String id = getId();
        IAlgo algo = getAlgo();
        WorkAllocations.JobType type = getJobType();
        int group = getPriorityGroup();
        int weight = getWeight();
        return new WorkAllocations.Work(id, algo, type, group, weight);
    }

    /** Builds a key with no type suffix through {@code AutoML.makeKey}. */
    public Key makeKey(String str, boolean z) {
        return aml().makeKey(str, null, z);
    }

    /** @return the AutoML run's work allocations */
    protected WorkAllocations getWorkAllocations() {
        return aml()._workAllocations;
    }

    /** @return the models currently on the AutoML leaderboard */
    public Model[] getTrainedModels() {
        return aml().leaderboard().getModels();
    }

    /** @return the keys of the models currently on the AutoML leaderboard */
    public Key<Model>[] getTrainedModelsKeys() {
        return aml().leaderboard().getModelKeys();
    }

    /** @return whether cross-validation is enabled for the AutoML run */
    public boolean isCVEnabled() {
        return aml().isCVEnabled();
    }

    /** Steps are equal when they share the same concrete class, provider and id. */
    @Override
    public boolean equals(Object obj) {
        if (obj == this) {
            return true;
        }
        if (obj == null || obj.getClass() != getClass()) {
            return false;
        }
        ModelingStep that = (ModelingStep) obj;
        return _provider.equals(that._provider) && _id.equals(that._id);
    }

    /** Consistent with {@link #equals(Object)}: hashes provider and id. */
    @Override
    public int hashCode() {
        return Objects.hash(_provider, _id);
    }

    /**
     * Copies the AutoML-wide settings onto {@code parameters}:
     * <ul>
     *   <li>training and validation frames, response and ignored columns;</li>
     *   <li>cross-validation, weighting and class-balancing settings, via the overridable setters;</li>
     *   <li>custom metric, CV retention flags and checkpoint export dir;</li>
     *   <li>a main-model time budget factor of 2.</li>
     * </ul>
     */
    protected void setCommonModelBuilderParams(Model.Parameters parameters) {
        parameters._train = aml()._trainingFrame._key;
        if (aml()._validationFrame != null) {
            parameters._valid = aml()._validationFrame._key;
        }

        AutoMLBuildSpec spec = aml().getBuildSpec();
        parameters._response_column = spec.input_spec.response_column;
        parameters._ignored_columns = spec.input_spec.ignored_columns;

        setCrossValidationParams(parameters);
        setWeightingParams(parameters);
        setClassBalancingParams(parameters);

        parameters._custom_metric_func = spec.build_control.custom_metric_func;
        parameters._keep_cross_validation_models = spec.build_control.keep_cross_validation_models;
        // Fold assignments are only kept when nfolds is non-zero.
        boolean keepFoldAssignment = spec.build_control.nfolds != 0
                && spec.build_control.keep_cross_validation_fold_assignment;
        parameters._keep_cross_validation_fold_assignment = keepFoldAssignment;
        parameters._export_checkpoints_dir = spec.build_control.export_checkpoints_dir;
        parameters._main_model_time_budget_factor = 2.0d;
    }

    /**
     * Configures cross-validation on {@code parameters} from the AutoML build spec.
     *
     * <p>CV predictions are kept whenever there is no blending frame, or when the user asked for them.
     * If a fold column is given, it is used as-is and nfolds/fold assignment are left untouched.
     * Otherwise nfolds is copied, and Modulo fold assignment is used when nfolds exceeds 1.
     *
     * @param parameters the builder parameters, mutated in place
     */
    protected void setCrossValidationParams(Model.Parameters parameters) {
        AutoMLBuildSpec spec = aml().getBuildSpec();
        boolean noBlendingFrame = aml().getBlendingFrame() == null;
        parameters._keep_cross_validation_predictions =
                noBlendingFrame || spec.build_control.keep_cross_validation_predictions;
        parameters._fold_column = spec.input_spec.fold_column;

        if (spec.input_spec.fold_column != null) {
            return;
        }
        parameters._nfolds = spec.build_control.nfolds;
        if (spec.build_control.nfolds > 1) {
            parameters._fold_assignment = Model.Parameters.FoldAssignmentScheme.Modulo;
        }
    }

    /**
     * Copies the weights column from the AutoML input spec onto the builder parameters.
     *
     * @param parameters the builder parameters, mutated in place
     */
    protected void setWeightingParams(Model.Parameters parameters) {
        parameters._weights_column = this.aml()
                .getBuildSpec()
                .input_spec
                .weights_column;
    }

    /**
     * Enables class balancing on the builder parameters when the build control requests it.
     * If balancing is disabled, the parameters are left untouched.
     *
     * @param parameters the builder parameters, mutated in place
     */
    protected void setClassBalancingParams(Model.Parameters parameters) {
        AutoMLBuildSpec spec = aml().getBuildSpec();
        if (!spec.build_control.balance_classes) {
            return;
        }
        parameters._balance_classes = spec.build_control.balance_classes;
        parameters._class_sampling_factors = spec.build_control.class_sampling_factors;
        parameters._max_after_balance_size = spec.build_control.max_after_balance_size;
    }

    /**
     * Applies any user-supplied per-algorithm parameters for this step's algo.
     * Does nothing when the build spec carries no custom parameters.
     *
     * @param parameters the builder parameters, mutated in place
     */
    protected void setCustomParams(Model.Parameters parameters) {
        AutoMLBuildSpec.AutoMLCustomParameters customParams =
                aml().getBuildSpec().build_models.algo_parameters;
        if (customParams != null) {
            customParams.applyCustomParameters(_algo, parameters);
        }
    }

    /**
     * Applies every configured preprocessing step to the builder parameters.
     *
     * <p>Each step returns a completer, which is queued in {@code _onDone} and run once the job finishes.
     * Does nothing when no preprocessing is configured.
     *
     * @param parameters the builder parameters handed to each preprocessing step
     */
    protected void applyPreprocessing(Model.Parameters parameters) {
        if (aml().getPreprocessing() != null) {
            for (PreprocessingStep step : aml().getPreprocessing()) {
                PreprocessingStep.Completer completer = step.apply(parameters, getPreprocessingConfig());
                _onDone.add(job -> completer.run());
            }
        }
    }

    /* JADX INFO: Access modifiers changed from: protected */
    /**
     * Supplies the configuration passed to each preprocessing step; subclasses may override it.
     *
     * @return a new, default {@link PreprocessingConfig}
     */
    public PreprocessingConfig getPreprocessingConfig() {
        final PreprocessingConfig config = new PreprocessingConfig();
        return config;
    }

    /* JADX INFO: Access modifiers changed from: protected */
    /**
     * Fills in early-stopping criteria from the AutoML build spec.
     *
     * <p>Each criterion (metric, rounds, tolerance) is overwritten only when it still equals the
     * corresponding value in {@code defaults}. An AUTO metric is then resolved from the response column.
     *
     * @param parms the builder parameters, mutated in place
     * @param defaults reference parameters used to detect unchanged values; presumably a
     *     default-constructed instance — confirm with callers
     */
    public void setStoppingCriteria(Model.Parameters parms, Model.Parameters defaults) {
        AutoMLBuildSpec spec = aml().getBuildSpec();

        if (parms._stopping_metric == defaults._stopping_metric) {
            parms._stopping_metric = spec.build_control.stopping_criteria.stopping_metric();
        }
        if (parms._stopping_metric == ScoreKeeper.StoppingMetric.AUTO) {
            // NOTE(review): cardinality() == -1 presumably marks a non-categorical (regression)
            // response — confirm against Vec.cardinality().
            boolean nonCategoricalResponse = aml().getResponseColumn().cardinality() == -1;
            if (nonCategoricalResponse) {
                parms._stopping_metric = ScoreKeeper.StoppingMetric.deviance;
            } else {
                parms._stopping_metric = ScoreKeeper.StoppingMetric.logloss;
            }
        }

        if (parms._stopping_rounds == defaults._stopping_rounds) {
            parms._stopping_rounds = spec.build_control.stopping_criteria.stopping_rounds();
        }
        if (parms._stopping_tolerance == defaults._stopping_tolerance) {
            parms._stopping_tolerance = spec.build_control.stopping_criteria.stopping_tolerance();
        }
    }

    /**
     * Assigns a seed according to {@code seedPolicy}, but only if the seed still equals the reference.
     *
     * <ul>
     *   <li>{@code Global}: use the build spec's global seed.</li>
     *   <li>{@code Incremental}: if the AutoML incremental seed counter equals the reference seed, use
     *       the reference seed. Otherwise take the counter's value and advance it.</li>
     *   <li>any other policy: leave the seed unchanged.</li>
     * </ul>
     *
     * @param parms the builder parameters, mutated in place
     * @param defaults reference parameters whose seed marks "not set"; presumably a
     *     default-constructed instance — confirm with callers
     * @param seedPolicy how to derive the seed
     */
    protected void setSeed(Model.Parameters parms, Model.Parameters defaults, SeedPolicy seedPolicy) {
        AutoMLBuildSpec spec = aml().getBuildSpec();
        if (parms._seed != defaults._seed) {
            return; // seed differs from the reference: keep it
        }
        switch (seedPolicy) {
            case Global:
                parms._seed = spec.build_control.stopping_criteria.seed();
                break;
            case Incremental:
                if (_aml._incrementalSeed.get() == defaults._seed) {
                    parms._seed = defaults._seed;
                } else {
                    parms._seed = _aml._incrementalSeed.getAndIncrement();
                }
                break;
            default:
                break;
        }
    }

    /* JADX INFO: Access modifiers changed from: protected */
    /**
     * Sets a per-model runtime limit when the caller has not set one ({@code _max_runtime_secs == 0}).
     *
     * <p>The limit is the build spec's {@code max_runtime_secs_per_model}. When {@code upperLimit} is
     * positive, the limit is capped at {@code upperLimit}.
     *
     * @param parameters the builder parameters, mutated in place
     * @param upperLimit optional cap in seconds; values {@code <= 0} mean "no cap"
     */
    public void initTimeConstraints(Model.Parameters parameters, double upperLimit) {
        AutoMLBuildSpec buildSpec = aml().getBuildSpec();
        // 0 is treated here as "no runtime limit set yet"; any explicit value is left alone.
        if (parameters._max_runtime_secs == 0) {
            double maxPerModel = buildSpec.build_control.stopping_criteria.max_runtime_secs_per_model();
            parameters._max_runtime_secs = upperLimit <= 0 ? maxPerModel : Math.min(maxPerModel, upperLimit);
        }
    }

    /**
     * Returns the leaderboard's sort metric.
     *
     * @return the sort metric name, or {@code null} when there is no leaderboard
     */
    private String getSortMetric() {
        Leaderboard board = aml().leaderboard();
        return board != null ? board.getSortMetric() : null;
    }

    /**
     * Maps a metric name to a {@link ScoreKeeper.StoppingMetric}.
     *
     * <p>{@code "mean_residual_deviance"} maps to {@code deviance}. Any other name is resolved through
     * {@code EnumUtils.valueOf}. Null or unrecognised names yield {@code AUTO}.
     *
     * @param str the metric name, may be {@code null}
     * @return the matching stopping metric, never {@code null}
     */
    private static ScoreKeeper.StoppingMetric metricValueOf(String str) {
        if (str == null) {
            return ScoreKeeper.StoppingMetric.AUTO;
        }
        // This name differs from the enum constant it corresponds to, so map it explicitly.
        if ("mean_residual_deviance".equals(str)) {
            return ScoreKeeper.StoppingMetric.deviance;
        }
        try {
            return (ScoreKeeper.StoppingMetric) EnumUtils.valueOf(ScoreKeeper.StoppingMetric.class, str);
        } catch (IllegalArgumentException ignored) {
            // Unknown metric names deliberately fall back to AUTO instead of failing.
            return ScoreKeeper.StoppingMetric.AUTO;
        }
    }

    static {
        // NOTE(review): decompiler artifact — $assertionsDisabled is normally a compiler-synthesized field.
        $assertionsDisabled = !ModelingStep.class.desiredAssertionStatus();
        // Matches plain single-model build jobs.
        isDefaultModel = work -> work._type == WorkAllocations.JobType.ModelBuild;
        // Matches single-model builds as well as hyperparameter searches.
        isExplorationWork = work -> work._type == WorkAllocations.JobType.ModelBuild
                || work._type == WorkAllocations.JobType.HyperparamSearch;
        // Matches selection jobs.
        isExploitationWork = work -> work._type == WorkAllocations.JobType.Selection;
    }
}
