package ai.h2o.automl.modeling;

import ai.h2o.automl.Algo;
import ai.h2o.automl.AutoML;
import ai.h2o.automl.ModelSelectionStrategies;
import ai.h2o.automl.ModelSelectionStrategy;
import ai.h2o.automl.ModelingStep;
import ai.h2o.automl.ModelingSteps;
import ai.h2o.automl.Models;
import ai.h2o.automl.events.EventLogEntry;
import hex.genmodel.utils.DistributionFamily;
import hex.grid.GridSearch;
import hex.grid.HyperSpaceSearchCriteria;
import hex.grid.SequentialWalker;
import hex.tree.xgboost.XGBoostModel;
import java.util.ArrayList;
import java.util.Date;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import water.Job;
import water.Key;

/* loaded from: input_file:ai/h2o/automl/modeling/XGBoostSteps.class */
public class XGBoostSteps extends ModelingSteps {
    /** Provider name used by every step declared in this class; equal to {@code Algo.XGBoost.name()}. */
    static final String NAME = Algo.XGBoost.name();
    /** Steps handed out by {@link #getDefaultModels()}; populated once in the constructor. */
    private final ModelingStep[] defaults;
    /** Steps handed out by {@link #getGrids()}; populated once in the constructor. */
    private final ModelingStep[] grids;
    /** Steps handed out by {@link #getOptionals()}; populated once in the constructor. */
    private final ModelingStep[] exploitation;

    /* loaded from: input_file:ai/h2o/automl/modeling/XGBoostSteps$DefaultXGBoostGridStep.class */
    /**
     * Grid step that never emulates LightGBM (the flag is hard-wired to {@code false})
     * and defines the hyper-parameter candidates to search over.
     */
    static class DefaultXGBoostGridStep extends XGBoostGridStep {
        /**
         * @param str    identifier of this step
         * @param autoML the AutoML instance this step belongs to
         */
        public DefaultXGBoostGridStep(String str, AutoML autoML) {
            super(str, autoML, false);
        }

        /**
         * Returns the shared base parameters with {@code _scale_pos_weight} reset to the
         * value carried by a freshly constructed parameter object.
         */
        @Override // ai.h2o.automl.modeling.XGBoostSteps.XGBoostGridStep, ai.h2o.automl.ModelingStep.GridStep
        /* renamed from: prepareModelParameters, reason: merged with bridge method [inline-methods] */
        public XGBoostModel.XGBoostParameters mo37prepareModelParameters() {
            XGBoostModel.XGBoostParameters params = super.mo37prepareModelParameters();
            // When class balancing applies, candidates for this value come from the search space instead.
            params._scale_pos_weight = new XGBoostModel.XGBoostParameters()._scale_pos_weight;
            return params;
        }

        /**
         * Builds the search space: parameter field name mapped to its candidate values.
         *
         * @return a new mutable map; never {@code null}
         */
        @Override // ai.h2o.automl.ModelingStep.GridStep
        public Map<String, Object[]> prepareSearchParameters() {
            Map<String, Object[]> searchParams = new HashMap<>();

            if (this._emulateLightGBM) {
                searchParams.put("_max_leaves", new Integer[]{32, 1024, 32768, 1048576});
                searchParams.put("_max_depth", new Integer[]{10, 20, 50});
            } else {
                searchParams.put("_max_depth", new Integer[]{3, 6, 9, 12, 15});
                // A non-integer weights column adds two sub-1 candidates for _min_rows.
                boolean integerOrNoWeights = aml().getWeightsColumn() == null || aml().getWeightsColumn().isInt();
                searchParams.put("_min_rows", integerOrNoWeights
                        ? new Double[]{1.0d, 3.0d, 5.0d, 10.0d, 15.0d, 20.0d}
                        : new Double[]{0.01d, 0.1d, 1.0d, 3.0d, 5.0d, 10.0d, 15.0d, 20.0d});
            }

            searchParams.put("_sample_rate", new Double[]{0.6d, 0.8d, 1.0d});
            searchParams.put("_col_sample_rate", new Double[]{0.6d, 0.8d, 1.0d});
            searchParams.put("_col_sample_rate_per_tree", new Double[]{0.7d, 0.8d, 0.9d, 1.0d});
            // gbtree is listed twice on purpose in the original candidate list; kept as-is.
            searchParams.put("_booster", new XGBoostModel.XGBoostParameters.Booster[]{
                    XGBoostModel.XGBoostParameters.Booster.gbtree,
                    XGBoostModel.XGBoostParameters.Booster.gbtree,
                    XGBoostModel.XGBoostParameters.Booster.dart});
            searchParams.put("_reg_lambda", new Float[]{0.001f, 0.01f, 0.1f, 1.0f, 10.0f, 100.0f});
            searchParams.put("_reg_alpha", new Float[]{0.001f, 0.01f, 0.1f, 0.5f, 1.0f});

            boolean balanceBinaryClasses = aml().getBuildSpec().build_control.balance_classes
                    && aml().getDistributionFamily().equals(DistributionFamily.bernoulli);
            if (balanceBinaryClasses) {
                double[] classDistribution = aml().getClassDistribution();
                // NOTE(review): if classDistribution[1] is 0 this ratio is infinite or NaN — confirm callers rule that out.
                float negPosRatio = (float) (classDistribution[0] / classDistribution[1]);
                // Always >= 1 for finite positive ratios, whichever class dominates.
                float imbalanceRatio = negPosRatio < 1.0f ? 1.0f / negPosRatio : negPosRatio;
                searchParams.put("_scale_pos_weight", new Float[]{1.0f, negPosRatio});
                searchParams.put("_max_delta_step", new Float[]{
                        0.0f,
                        Math.min(5.0f, imbalanceRatio / 2.0f),
                        Math.min(10.0f, imbalanceRatio)});
            }
            return searchParams;
        }
    }

    /* loaded from: input_file:ai/h2o/automl/modeling/XGBoostSteps$XGBoostExploitationStep.class */
    /**
     * Base class for selection steps that start from an XGBoost model already trained
     * in the current run. A step of this kind can only run once such a model exists.
     */
    static abstract class XGBoostExploitationStep extends ModelingStep.SelectionStep<XGBoostModel> {
        boolean _emulateLightGBM;

        /**
         * @param str    identifier of this step
         * @param autoML the AutoML instance this step belongs to
         * @param z      value stored in {@code _emulateLightGBM}
         */
        public XGBoostExploitationStep(String str, AutoML autoML, boolean z) {
            super(XGBoostSteps.NAME, Algo.XGBoost, str, autoML);
            this._emulateLightGBM = z;
            // With a positive exploitation ratio, the model-count constraint is ignored for this step.
            if (autoML.getBuildSpec().build_models.exploitation_ratio > 0.0d) {
                this._ignoredConstraints = new AutoML.Constraint[]{AutoML.Constraint.MODEL_COUNT};
            }
        }

        /**
         * Returns the first XGBoost model among the trained models.
         *
         * @throws IndexOutOfBoundsException if no XGBoost model has been trained yet;
         *         {@link #canRun()} checks for that case
         */
        protected XGBoostModel getBestXGB() {
            return getBestXGBs(1).get(0);
        }

        /**
         * Collects up to {@code i} XGBoost models, in the order {@code getTrainedModels()}
         * yields them (presumably best first — confirm against the leaderboard ordering).
         *
         * @param i maximum number of models to return; a value of 0 or less yields an empty list
         * @return a new list holding at most {@code i} models, possibly empty
         */
        protected List<XGBoostModel> getBestXGBs(int i) {
            List<XGBoostModel> best = new ArrayList<>();
            // Iterate as Object so the instanceof filter is a real check and models of
            // other algorithms are skipped instead of failing on an implicit cast.
            for (Object candidate : getTrainedModels()) {
                if (best.size() >= i) {
                    break;
                }
                if (candidate instanceof XGBoostModel) {
                    best.add((XGBoostModel) candidate);
                }
            }
            return best;
        }

        /** Runnable only when the parent allows it and at least one XGBoost model is available. */
        @Override // ai.h2o.automl.ModelingStep
        public boolean canRun() {
            return super.canRun() && !getBestXGBs(1).isEmpty();
        }
    }

    /* JADX INFO: Access modifiers changed from: package-private */
    /* loaded from: input_file:ai/h2o/automl/modeling/XGBoostSteps$XGBoostGridStep.class */
    /**
     * Common parent of the XGBoost grid steps: remembers whether LightGBM emulation is
     * requested and derives its base parameters from
     * {@link XGBoostSteps#prepareModelParameters(AutoML, boolean)}.
     */
    public static abstract class XGBoostGridStep extends ModelingStep.GridStep<XGBoostModel> {
        boolean _emulateLightGBM;

        /**
         * @param str    identifier of this step
         * @param autoML the AutoML instance this step belongs to
         * @param z      value stored in {@code _emulateLightGBM}
         */
        public XGBoostGridStep(String str, AutoML autoML, boolean z) {
            super(XGBoostSteps.NAME, Algo.XGBoost, str, autoML);
            _emulateLightGBM = z;
        }

        /** Returns a fresh set of base parameters for this step. */
        @Override // ai.h2o.automl.ModelingStep.GridStep
        /* renamed from: prepareModelParameters */
        public XGBoostModel.XGBoostParameters mo37prepareModelParameters() {
            final AutoML owner = aml();
            final boolean emulate = _emulateLightGBM;
            return XGBoostSteps.prepareModelParameters(owner, emulate);
        }
    }

    /* JADX INFO: Access modifiers changed from: package-private */
    /* loaded from: input_file:ai/h2o/automl/modeling/XGBoostSteps$XGBoostModelStep.class */
    /**
     * Common parent of the single-model XGBoost steps. On top of the shared base
     * parameters it sets {@code _scale_pos_weight} when class balancing is enabled
     * for a bernoulli distribution.
     */
    public static abstract class XGBoostModelStep extends ModelingStep.ModelStep<XGBoostModel> {
        boolean _emulateLightGBM;

        /**
         * @param str    identifier of this step
         * @param autoML the AutoML instance this step belongs to
         * @param z      value stored in {@code _emulateLightGBM}
         */
        XGBoostModelStep(String str, AutoML autoML, boolean z) {
            super(XGBoostSteps.NAME, Algo.XGBoost, str, autoML);
            _emulateLightGBM = z;
        }

        /**
         * Returns the base parameters, with {@code _scale_pos_weight} set to
         * {@code classDistribution[0] / classDistribution[1]} when balancing applies.
         */
        @Override // ai.h2o.automl.ModelingStep.ModelStep
        /* renamed from: prepareModelParameters */
        public XGBoostModel.XGBoostParameters mo35prepareModelParameters() {
            final XGBoostModel.XGBoostParameters params = XGBoostSteps.prepareModelParameters(aml(), _emulateLightGBM);
            // Leave the weight untouched unless balancing is on AND the problem is bernoulli.
            if (!aml().getBuildSpec().build_control.balance_classes
                    || !aml().getDistributionFamily().equals(DistributionFamily.bernoulli)) {
                return params;
            }
            final double[] distribution = aml().getClassDistribution();
            final double negPosRatio = distribution[0] / distribution[1];
            params._scale_pos_weight = (float) negPosRatio;
            return params;
        }
    }

    /**
     * Creates the base XGBoost parameters shared by the steps in this class:
     * 10000 trees, scored every 5 trees.
     *
     * @param autoML not read by this method
     * @param z      when {@code true}, additionally selects the {@code hist} tree method
     *               and the {@code lossguide} grow policy
     * @return a newly created parameter object
     */
    static XGBoostModel.XGBoostParameters prepareModelParameters(AutoML autoML, boolean z) {
        final XGBoostModel.XGBoostParameters params = new XGBoostModel.XGBoostParameters();
        params._ntrees = 10000;
        params._score_tree_interval = 5;
        if (!z) {
            return params;
        }
        params._grow_policy = XGBoostModel.XGBoostParameters.GrowPolicy.lossguide;
        params._tree_method = XGBoostModel.XGBoostParameters.TreeMethod.hist;
        return params;
    }

    /**
     * Registers every XGBoost step offered by this provider:
     * <ul>
     *   <li>three default models ({@code def_1}..{@code def_3}) with fixed hyper-parameters,</li>
     *   <li>one grid ({@code grid_1}),</li>
     *   <li>two exploitation steps working from the best XGBoost model found so far:
     *       {@code lr_annealing} retrains it with learning-rate annealing, and
     *       {@code lr_search} walks sequentially through a table of learn rates.</li>
     * </ul>
     *
     * @param autoML the AutoML instance the steps belong to
     */
    public XGBoostSteps(AutoML autoML) {
        super(autoML);

        this.defaults = new XGBoostModelStep[]{
            new XGBoostModelStep("def_1", aml(), false) { // from class: ai.h2o.automl.modeling.XGBoostSteps.1
                @Override // ai.h2o.automl.modeling.XGBoostSteps.XGBoostModelStep, ai.h2o.automl.ModelingStep.ModelStep
                public XGBoostModel.XGBoostParameters mo35prepareModelParameters() {
                    XGBoostModel.XGBoostParameters params = super.mo35prepareModelParameters();
                    params._max_depth = 10;
                    params._min_rows = 5.0d;
                    params._sample_rate = 0.6d;
                    params._col_sample_rate = 0.8d;
                    params._col_sample_rate_per_tree = 0.8d;
                    if (this._emulateLightGBM) {
                        // Leaf budget of a full tree at the configured depth, then double the depth limit.
                        params._max_leaves = 1 << params._max_depth;
                        params._max_depth *= 2;
                    }
                    return params;
                }
            },
            new XGBoostModelStep("def_2", aml(), false) { // from class: ai.h2o.automl.modeling.XGBoostSteps.2
                @Override // ai.h2o.automl.modeling.XGBoostSteps.XGBoostModelStep, ai.h2o.automl.ModelingStep.ModelStep
                public XGBoostModel.XGBoostParameters mo35prepareModelParameters() {
                    XGBoostModel.XGBoostParameters params = super.mo35prepareModelParameters();
                    params._max_depth = 15;
                    params._min_rows = 10.0d;
                    params._sample_rate = 0.6d;
                    params._col_sample_rate = 0.8d;
                    params._col_sample_rate_per_tree = 0.8d;
                    if (this._emulateLightGBM) {
                        params._max_leaves = 1 << params._max_depth;
                        params._max_depth *= 2;
                    }
                    return params;
                }
            },
            new XGBoostModelStep("def_3", aml(), false) { // from class: ai.h2o.automl.modeling.XGBoostSteps.3
                @Override // ai.h2o.automl.modeling.XGBoostSteps.XGBoostModelStep, ai.h2o.automl.ModelingStep.ModelStep
                public XGBoostModel.XGBoostParameters mo35prepareModelParameters() {
                    XGBoostModel.XGBoostParameters params = super.mo35prepareModelParameters();
                    params._max_depth = 5;
                    params._min_rows = 3.0d;
                    params._sample_rate = 0.8d;
                    params._col_sample_rate = 0.8d;
                    params._col_sample_rate_per_tree = 0.8d;
                    if (this._emulateLightGBM) {
                        params._max_leaves = 1 << params._max_depth;
                        params._max_depth *= 2;
                    }
                    return params;
                }
            }
        };

        this.grids = new XGBoostGridStep[]{new DefaultXGBoostGridStep("grid_1", aml())};

        this.exploitation = new ModelingStep[]{
            new XGBoostExploitationStep("lr_annealing", aml(), false) { // from class: ai.h2o.automl.modeling.XGBoostSteps.4
                // Remembered so the selection strategy can name its temporary leaderboard after the result.
                Key<Models> resultKey = null;

                @Override // ai.h2o.automl.ModelingStep.SelectionStep
                protected Job<Models> startTraining(Key key, double d) {
                    this.resultKey = key;
                    XGBoostModel bestXGB = getBestXGB();
                    aml().eventLog().info(EventLogEntry.Stage.ModelSelection, "Retraining best XGBoost with learning rate annealing: " + bestXGB._key);
                    // Explicit cast keeps the assignment valid should clone() be declared with a broader return type.
                    XGBoostModel.XGBoostParameters params = (XGBoostModel.XGBoostParameters) bestXGB._input_parms.clone();
                    params._max_runtime_secs = 0.0d;
                    params._learn_rate_annealing = 0.99d;
                    initTimeConstraints(params, d);
                    setStoppingCriteria(params, new XGBoostModel.XGBoostParameters());
                    return asModelsJob(startModel(Key.make(key + "_model"), params), key);
                }

                @Override // ai.h2o.automl.ModelingStep.SelectionStep
                protected ModelSelectionStrategy getSelectionStrategy() {
                    // The first lambda argument is not used: the new models are always compared against the current best XGBoost.
                    return (keyArr, keyArr2) ->
                        new ModelSelectionStrategies.KeepBestN(1, () ->
                            makeTmpLeaderboard(Objects.toString(this.resultKey, this._provider + "_" + this._id))
                        ).select(new Key[]{getBestXGB()._key}, keyArr2);
                }
            },
            new XGBoostExploitationStep("lr_search", aml(), false) { // from class: ai.h2o.automl.modeling.XGBoostSteps.5
                // Same role and type as in the lr_annealing step.
                Key<Models> resultKey = null;

                @Override // ai.h2o.automl.ModelingStep.SelectionStep
                protected ModelSelectionStrategy getSelectionStrategy() {
                    return (keyArr, keyArr2) ->
                        new ModelSelectionStrategies.KeepBestN(1, () ->
                            makeTmpLeaderboard(Objects.toString(this.resultKey, this._provider + "_" + this._id))
                        ).select(new Key[]{getBestXGB()._key}, keyArr2);
                }

                @Override // ai.h2o.automl.ModelingStep.SelectionStep
                protected Job<Models> startTraining(Key key, double d) {
                    this.resultKey = key;
                    XGBoostModel bestXGB = getBestXGB();
                    aml().eventLog().info(EventLogEntry.Stage.ModelSelection, "Applying learning rate search on best XGBoost: " + bestXGB._key);
                    XGBoostModel.XGBoostParameters params = (XGBoostModel.XGBoostParameters) bestXGB._input_parms.clone();
                    XGBoostModel.XGBoostParameters defaults = new XGBoostModel.XGBoostParameters();
                    params._max_runtime_secs = 0.0d;
                    // The per-model budget is passed as 0 here; the overall budget d goes to the search criteria below.
                    initTimeConstraints(params, 0.0d);
                    setStoppingCriteria(params, defaults);
                    int interval = params._score_tree_interval;
                    // First row names the parameter fields; each following row is one candidate,
                    // pairing a smaller learn rate with a proportionally larger scoring interval.
                    Object[][] learnRateTable = new Object[][]{
                        new Object[]{"_learn_rate", "_score_tree_interval"},
                        new Object[]{0.5d, interval},
                        new Object[]{0.2d, 2 * interval},
                        new Object[]{0.1d, 3 * interval},
                        new Object[]{0.05d, 4 * interval},
                        new Object[]{0.02d, 5 * interval},
                        new Object[]{0.01d, 6 * interval},
                        new Object[]{0.005d, 7 * interval},
                        new Object[]{0.002d, 8 * interval},
                        new Object[]{0.001d, 9 * interval},
                        new Object[]{5.0E-4d, 10 * interval}
                    };
                    aml().eventLog().info(EventLogEntry.Stage.ModelTraining, "AutoML: starting " + this.resultKey + " model training").setNamedValue("start_" + this._provider + "_" + this._id, new Date(), EventLogEntry.epochFormat.get());
                    // NOTE(review): the (int) cast drops fractional seconds, so a budget below 1s becomes 0 — confirm that is acceptable.
                    HyperSpaceSearchCriteria.SequentialSearchCriteria searchCriteria = new HyperSpaceSearchCriteria.SequentialSearchCriteria(
                        HyperSpaceSearchCriteria.StoppingCriteria.create()
                            .maxRuntimeSecs((int) d)
                            .stoppingMetric(params._stopping_metric)
                            .stoppingRounds(3)
                            .stoppingTolerance(params._stopping_tolerance)
                            .build());
                    SequentialWalker walker = new SequentialWalker(params, learnRateTable, new GridSearch.SimpleParametersBuilderFactory(), searchCriteria);
                    return asModelsJob(GridSearch.startGridSearch(Key.make(key + "_grid"), walker, 1), key);
                }
            }
        };
    }

    /** Returns the provider name of these steps, i.e. {@link #NAME}. */
    @Override // ai.h2o.automl.ModelingSteps
    public String getProvider() {
        return NAME;
    }

    /** Returns the default model steps created in the constructor (the internal array itself, not a copy). */
    @Override // ai.h2o.automl.ModelingSteps
    protected ModelingStep[] getDefaultModels() {
        return this.defaults;
    }

    /** Returns the grid steps created in the constructor (the internal array itself, not a copy). */
    @Override // ai.h2o.automl.ModelingSteps
    protected ModelingStep[] getGrids() {
        return this.grids;
    }

    /** Returns the exploitation steps created in the constructor as this provider's optional steps (the internal array itself, not a copy). */
    @Override // ai.h2o.automl.ModelingSteps
    protected ModelingStep[] getOptionals() {
        return this.exploitation;
    }
}
