package ai.h2o.automl.modeling;

import ai.h2o.automl.Algo;
import ai.h2o.automl.AutoML;
import ai.h2o.automl.ModelingStep;
import ai.h2o.automl.ModelingSteps;
import hex.genmodel.utils.DistributionFamily;
import hex.grid.Grid;
import hex.tree.xgboost.XGBoostModel;
import java.util.HashMap;
import water.Job;

/* loaded from: input_file:ai/h2o/automl/modeling/XGBoostSteps.class */
/**
 * Modeling steps that AutoML runs for the XGBoost algorithm: three default
 * models with fixed hyperparameters and one hyperparameter grid search.
 */
public class XGBoostSteps extends ModelingSteps {
    /** Default (fixed-hyperparameter) model steps; assigned once in the constructor. */
    private final ModelingStep[] defaults;
    /** Grid-search steps; assigned once in the constructor. */
    private final ModelingStep[] grids;

    /**
     * Base class for XGBoost grid-search steps. Carries the LightGBM-emulation
     * flag and builds the baseline parameters shared by every XGBoost step.
     */
    static abstract class XGBoostGridStep extends ModelingStep.GridStep<XGBoostModel> {
        boolean _emulateLightGBM;

        /**
         * @param id              forwarded to the super constructor (presumably the step identifier)
         * @param weight          forwarded to the super constructor (presumably the step weight — confirm against ModelingStep)
         * @param autoML          the AutoML instance this step belongs to
         * @param emulateLightGBM whether the step should use the LightGBM-like configuration
         */
        public XGBoostGridStep(String id, int weight, AutoML autoML, boolean emulateLightGBM) {
            super(Algo.XGBoost, id, weight, autoML);
            this._emulateLightGBM = emulateLightGBM;
        }

        /** @return baseline XGBoost parameters for this step's AutoML instance and emulation flag */
        XGBoostModel.XGBoostParameters prepareModelParameters() {
            return XGBoostSteps.prepareModelParameters(aml(), this._emulateLightGBM);
        }
    }

    /**
     * Base class for single-model XGBoost steps. Carries the LightGBM-emulation
     * flag and builds the baseline parameters shared by every XGBoost step.
     */
    static abstract class XGBoostModelStep extends ModelingStep.ModelStep<XGBoostModel> {
        boolean _emulateLightGBM;

        /**
         * @param id              forwarded to the super constructor (presumably the step identifier)
         * @param weight          forwarded to the super constructor (presumably the step weight — confirm against ModelingStep)
         * @param autoML          the AutoML instance this step belongs to
         * @param emulateLightGBM whether the step should use the LightGBM-like configuration
         */
        XGBoostModelStep(String id, int weight, AutoML autoML, boolean emulateLightGBM) {
            super(Algo.XGBoost, id, weight, autoML);
            this._emulateLightGBM = emulateLightGBM;
        }

        /** @return baseline XGBoost parameters for this step's AutoML instance and emulation flag */
        XGBoostModel.XGBoostParameters prepareModelParameters() {
            return XGBoostSteps.prepareModelParameters(aml(), this._emulateLightGBM);
        }

        /**
         * Builds the parameters of a default model: the baseline parameters plus
         * the given tree shape and row sampling. Column sampling is fixed at 0.8
         * (both per split and per tree) for every default model.
         *
         * <p>When LightGBM emulation is enabled, the leaf count is capped at
         * {@code 2^maxDepth}, the depth limit is doubled, and {@code minRows}
         * is applied as the minimum hessian sum per leaf.
         *
         * @param maxDepth   maximum tree depth
         * @param minRows    value for {@code _min_rows}
         * @param sampleRate row sampling rate
         * @return fully configured parameters, ready to be trained
         */
        XGBoostModel.XGBoostParameters prepareDefaultModelParameters(int maxDepth, double minRows, double sampleRate) {
            XGBoostModel.XGBoostParameters params = prepareModelParameters();
            params._max_depth = maxDepth;
            params._min_rows = minRows;
            params._sample_rate = sampleRate;
            params._col_sample_rate = 0.8d;
            params._col_sample_rate_per_tree = 0.8d;
            if (this._emulateLightGBM) {
                // Leaf count must be derived from the original depth, before it is doubled.
                params._max_leaves = 1 << params._max_depth;
                params._max_depth *= 2;
                params._min_sum_hessian_in_leaf = (float) params._min_rows;
            }
            return params;
        }
    }

    /**
     * Creates the baseline XGBoost parameters shared by all steps.
     *
     * @param autoML          AutoML instance whose response column determines the distribution
     * @param emulateLightGBM when {@code true}, selects the {@code hist} tree method
     *                        and the {@code lossguide} grow policy
     * @return a new parameter object with distribution, scoring interval,
     *         stopping rounds, tree count and learning rate set
     */
    static XGBoostModel.XGBoostParameters prepareModelParameters(AutoML autoML, boolean emulateLightGBM) {
        XGBoostModel.XGBoostParameters params = new XGBoostModel.XGBoostParameters();
        if (emulateLightGBM) {
            params._tree_method = XGBoostModel.XGBoostParameters.TreeMethod.hist;
            params._grow_policy = XGBoostModel.XGBoostParameters.GrowPolicy.lossguide;
        }
        // Binary non-numeric response -> bernoulli; any other categorical -> multinomial; else AUTO.
        if (autoML.getResponseColumn().isBinary() && !autoML.getResponseColumn().isNumeric()) {
            params._distribution = DistributionFamily.bernoulli;
        } else if (autoML.getResponseColumn().isCategorical()) {
            params._distribution = DistributionFamily.multinomial;
        } else {
            params._distribution = DistributionFamily.AUTO;
        }
        params._score_tree_interval = 5;
        params._stopping_rounds = 5;
        params._ntrees = 10000;
        params._learn_rate = 0.05d;
        return params;
    }

    /**
     * Registers the XGBoost default-model steps ({@code def_1}..{@code def_3})
     * and the grid-search step ({@code grid_1}). All of them are created with
     * LightGBM emulation disabled.
     *
     * @param autoML the AutoML instance the steps are created for
     */
    public XGBoostSteps(AutoML autoML) {
        super(autoML);
        this.defaults = new XGBoostModelStep[]{
            new XGBoostModelStep("def_1", 10, aml(), false) {
                @Override
                protected Job<XGBoostModel> startJob() {
                    return trainModel(prepareDefaultModelParameters(10, 5.0d, 0.6d));
                }
            },
            new XGBoostModelStep("def_2", 10, aml(), false) {
                @Override
                protected Job<XGBoostModel> startJob() {
                    return trainModel(prepareDefaultModelParameters(20, 10.0d, 0.6d));
                }
            },
            new XGBoostModelStep("def_3", 10, aml(), false) {
                @Override
                protected Job<XGBoostModel> startJob() {
                    return trainModel(prepareDefaultModelParameters(5, 3.0d, 0.8d));
                }
            }
        };
        this.grids = new XGBoostGridStep[]{
            new XGBoostGridStep("grid_1", 100, aml(), false) {
                @Override
                protected Job<Grid> startJob() {
                    XGBoostModel.XGBoostParameters baseParams = prepareModelParameters();
                    // Keys are XGBoostParameters field names; values are the candidate settings.
                    HashMap<String, Object[]> searchParams = new HashMap<>();
                    if (this._emulateLightGBM) {
                        searchParams.put("_max_leaves", new Integer[]{32, 1024, 32768, 1048576});
                        searchParams.put("_max_depth", new Integer[]{10, 20, 50});
                        searchParams.put("_min_sum_hessian_in_leaf", new Double[]{0.01d, 0.1d, 1.0d, 3.0d, 5.0d, 10.0d, 15.0d, 20.0d});
                    } else {
                        searchParams.put("_max_depth", new Integer[]{5, 10, 15, 20});
                        searchParams.put("_min_rows", new Double[]{0.01d, 0.1d, 1.0d, 3.0d, 5.0d, 10.0d, 15.0d, 20.0d});
                    }
                    searchParams.put("_sample_rate", new Double[]{0.6d, 0.8d, 1.0d});
                    searchParams.put("_col_sample_rate", new Double[]{0.6d, 0.8d, 1.0d});
                    searchParams.put("_col_sample_rate_per_tree", new Double[]{0.7d, 0.8d, 0.9d, 1.0d});
                    // NOTE(review): gbtree is listed twice — presumably to make it twice as likely
                    // as dart when candidates are sampled; keep the duplicate unless confirmed otherwise.
                    searchParams.put("_booster", new XGBoostModel.XGBoostParameters.Booster[]{
                        XGBoostModel.XGBoostParameters.Booster.gbtree,
                        XGBoostModel.XGBoostParameters.Booster.gbtree,
                        XGBoostModel.XGBoostParameters.Booster.dart});
                    searchParams.put("_reg_lambda", new Float[]{0.001f, 0.01f, 0.1f, 1.0f, 10.0f, 100.0f});
                    searchParams.put("_reg_alpha", new Float[]{0.001f, 0.01f, 0.1f, 0.5f, 1.0f});
                    return hyperparameterSearch(baseParams, searchParams);
                }
            }
        };
    }

    /** @return the default-model steps created in the constructor */
    @Override
    protected ModelingStep[] getDefaultModels() {
        return this.defaults;
    }

    /** @return the grid-search steps created in the constructor */
    @Override
    protected ModelingStep[] getGrids() {
        return this.grids;
    }
}
