package ai.h2o.automl;

import hex.Model;
import hex.ScoreKeeper;
import hex.deeplearning.DeepLearningModel;
import hex.ensemble.StackedEnsembleModel;
import hex.glm.GLMModel;
import hex.grid.HyperSpaceSearchCriteria;
import hex.tree.drf.DRFModel;
import hex.tree.gbm.GBMModel;
import hex.tree.xgboost.XGBoostModel;
import java.text.DateFormat;
import java.text.SimpleDateFormat;
import java.util.ArrayList;
import java.util.Date;
import java.util.List;
import java.util.stream.Stream;
import water.Iced;
import water.Key;
import water.exceptions.H2OIllegalArgumentException;
import water.exceptions.H2OIllegalValueException;
import water.fvec.Frame;
import water.util.ArrayUtils;
import water.util.IcedHashMap;
import water.util.Log;
import water.util.PojoUtils;

/* loaded from: input_file:ai/h2o/automl/AutoMLBuildSpec.class */
public class AutoMLBuildSpec extends Iced {
    // Formats the timestamp suffix of auto-generated project names (used by project()).
    // NOTE(review): java.text.SimpleDateFormat is not thread-safe and this instance is a static
    // shared by all AutoMLBuildSpec objects, so every use must be externally synchronized.
    // NOTE(review): the single "H" gives an unpadded hour-of-day (e.g. "9" instead of "09"), so
    // generated names do not sort chronologically across the 9h/10h boundary; "HH" may have been
    // intended, but changing it would alter the format of generated project names.
    private static final DateFormat projectTimeStampFormat = new SimpleDateFormat("yyyyMMdd_HmmssSSS");
    // Run-level controls: project name, class balancing, cross-validation options, stopping criteria.
    public final AutoMLBuildControl build_control = new AutoMLBuildControl();
    // Input frames and column roles.
    public final AutoMLInput input_spec = new AutoMLInput();
    // Algorithm selection, modeling plan and custom algo parameters.
    public final AutoMLBuildModels build_models = new AutoMLBuildModels();

    /* loaded from: input_file:ai/h2o/automl/AutoMLBuildSpec$AutoMLBuildControl.class */
    /**
     * Run-level controls of an AutoML build: project naming, class balancing, cross-validation
     * bookkeeping, checkpoint export and the stopping criteria of the whole run.
     */
    public static final class AutoMLBuildControl extends Iced {
        public float[] class_sampling_factors;
        public String project_name = null;
        public boolean balance_classes = false;
        public float max_after_balance_size = 5.0f;
        public int nfolds = 5;
        public boolean keep_cross_validation_predictions = false;
        public boolean keep_cross_validation_models = false;
        public boolean keep_cross_validation_fold_assignment = false;
        public String export_checkpoints_dir = null;
        public final AutoMLStoppingCriteria stopping_criteria = new AutoMLStoppingCriteria();

        /**
         * Creates build controls and (re)initializes {@link #stopping_criteria}.
         *
         * <p>The values written here match the {@link AutoMLStoppingCriteria} constructor, except
         * that {@code max_runtime_secs} is set to {@code 0} instead of {@code 3600}.
         * NOTE(review): 0 presumably means "not specified by the user" — confirm in AutoML.
         */
        public AutoMLBuildControl() {
            final AutoMLStoppingCriteria criteria = stopping_criteria;
            criteria.set_max_models(0);
            criteria.set_max_runtime_secs(0d);
            criteria.set_max_runtime_secs_per_model(0d);
            criteria.set_stopping_rounds(3);
            criteria.set_stopping_tolerance(1e-3);
            criteria.set_stopping_metric(ScoreKeeper.StoppingMetric.AUTO);
        }
    }

    /* loaded from: input_file:ai/h2o/automl/AutoMLBuildSpec$AutoMLBuildModels.class */
    /**
     * Options selecting which models an AutoML run builds and how they are customized.
     */
    public static final class AutoMLBuildModels extends Iced {
        // Algorithms to leave out of the run; null unless assigned.
        public Algo[] exclude_algos;
        // Algorithms to restrict the run to; null unless assigned.
        // NOTE(review): presumably mutually exclusive with exclude_algos — confirm where validated.
        public Algo[] include_algos;
        // Explicit sequence of modeling steps; null unless assigned.
        // NOTE(review): presumably null means "use the default plan" — confirm in AutoML.
        public StepDefinition[] modeling_plan;
        // Per-algo custom parameter overrides; starts with no customizations.
        public AutoMLCustomParameters algo_parameters = new AutoMLCustomParameters();
    }

    /* loaded from: input_file:ai/h2o/automl/AutoMLBuildSpec$AutoMLCustomParameters.class */
    public static final class AutoMLCustomParameters extends Iced {
        static final String ALGO_PARAMS_ALL_ENABLED = "sys.ai.h2o.automl.algo_parameters.all.enabled";
        private static final String[] ALLOWED_PARAMETERS = {"monotone_constraints"};
        private static final String ROOT_PARAM = "algo_parameters";
        private final IcedHashMap<String, String[]> _algoParameterNames = new IcedHashMap<>();
        private final IcedHashMap<String, Model.Parameters> _algoParameters = new IcedHashMap<>();

        /* loaded from: input_file:ai/h2o/automl/AutoMLBuildSpec$AutoMLCustomParameters$AutoMLCustomParameter.class */
        public static final class AutoMLCustomParameter<V> extends Iced {
            private Algo _algo;
            private String _name;
            private V _value;

            private AutoMLCustomParameter(String str, V v) {
                this._name = str;
                this._value = v;
            }

            private AutoMLCustomParameter(Algo algo, String str, V v) {
                this._algo = algo;
                this._name = str;
                this._value = v;
            }
        }

        /* loaded from: input_file:ai/h2o/automl/AutoMLBuildSpec$AutoMLCustomParameters$Builder.class */
        public static final class Builder {
            private final transient List<AutoMLCustomParameter> _anyAlgoParams = new ArrayList();
            private final transient List<AutoMLCustomParameter> _specificAlgoParams = new ArrayList();

            public <V> Builder add(String str, V v) {
                assertParameterAllowed(str);
                this._anyAlgoParams.add(new AutoMLCustomParameter(str, v));
                return this;
            }

            public <V> Builder add(Algo algo, String str, V v) {
                assertParameterAllowed(str);
                this._specificAlgoParams.add(new AutoMLCustomParameter(algo, str, v));
                return this;
            }

            public AutoMLCustomParameters build() {
                AutoMLCustomParameters autoMLCustomParameters = new AutoMLCustomParameters();
                for (AutoMLCustomParameter autoMLCustomParameter : this._anyAlgoParams) {
                    if (!autoMLCustomParameters.addParameter(autoMLCustomParameter._name, autoMLCustomParameter._value)) {
                        throw new H2OIllegalValueException(autoMLCustomParameter._name, AutoMLCustomParameters.ROOT_PARAM, autoMLCustomParameter._value);
                    }
                }
                for (AutoMLCustomParameter autoMLCustomParameter2 : this._specificAlgoParams) {
                    if (!autoMLCustomParameters.addParameter(autoMLCustomParameter2._algo, autoMLCustomParameter2._name, autoMLCustomParameter2._value)) {
                        throw new H2OIllegalValueException(autoMLCustomParameter2._name, AutoMLCustomParameters.ROOT_PARAM, autoMLCustomParameter2._value);
                    }
                }
                return autoMLCustomParameters;
            }

            private void assertParameterAllowed(String str) {
                if (!Boolean.parseBoolean(System.getProperty(AutoMLCustomParameters.ALGO_PARAMS_ALL_ENABLED, "false")) && !ArrayUtils.contains(AutoMLCustomParameters.ALLOWED_PARAMETERS, str)) {
                    throw new H2OIllegalValueException(AutoMLCustomParameters.ROOT_PARAM, str);
                }
            }
        }

        public static Builder create() {
            return new Builder();
        }

        public boolean hasCustomParams(Algo algo) {
            return this._algoParameterNames.get(algo.name()) != null;
        }

        public boolean hasCustomParam(Algo algo, String str) {
            return ArrayUtils.contains((Object[]) this._algoParameterNames.get(algo.name()), str);
        }

        public void applyCustomParameters(Algo algo, Model.Parameters parameters) {
            if (hasCustomParams(algo)) {
                PojoUtils.copyProperties(parameters, getCustomizedDefaults(algo), PojoUtils.FieldNaming.CONSISTENT, (String[]) null, (String[]) Stream.of((Object[]) getCustomParameterNames(algo)).map(str -> {
                    return "_" + str;
                }).toArray(i -> {
                    return new String[i];
                }));
            }
        }

        String[] getCustomParameterNames(Algo algo) {
            return (String[]) this._algoParameterNames.get(algo.name());
        }

        Model.Parameters getCustomizedDefaults(Algo algo) {
            if (!this._algoParameters.containsKey(algo.name())) {
                this._algoParameters.put(algo.name(), defaultParameters(algo));
            }
            return (Model.Parameters) this._algoParameters.get(algo.name());
        }

        private Model.Parameters defaultParameters(Algo algo) {
            switch (algo) {
                case DeepLearning:
                    return new DeepLearningModel.DeepLearningParameters();
                case DRF:
                    return new DRFModel.DRFParameters();
                case GBM:
                    return new GBMModel.GBMParameters();
                case GLM:
                    return new GLMModel.GLMParameters();
                case StackedEnsemble:
                    return new StackedEnsembleModel.StackedEnsembleParameters();
                case XGBoost:
                    return new XGBoostModel.XGBoostParameters();
                default:
                    throw new H2OIllegalArgumentException("Custom parameters are not supported for " + algo.name() + ".");
            }
        }

        private void addParameterName(Algo algo, String str) {
            if (!this._algoParameterNames.containsKey(algo.name())) {
                this._algoParameterNames.put(algo.name(), new String[]{str});
                return;
            }
            String[] strArr = (String[]) this._algoParameterNames.get(algo.name());
            if (ArrayUtils.contains(strArr, str)) {
                return;
            }
            this._algoParameterNames.put(algo.name(), ArrayUtils.append(strArr, new String[]{str}));
        }

        /* JADX INFO: Access modifiers changed from: private */
        public <V> boolean addParameter(String str, V v) {
            boolean z = false;
            for (Algo algo : Algo.values()) {
                z |= addParameter(algo, str, v);
            }
            return z;
        }

        /* JADX INFO: Access modifiers changed from: private */
        public <V> boolean addParameter(Algo algo, String str, V v) {
            Model.Parameters customizedDefaults = getCustomizedDefaults(algo);
            try {
                if (setField(customizedDefaults, str, v, PojoUtils.FieldNaming.DEST_HAS_UNDERSCORES) || setField(customizedDefaults, str, v, PojoUtils.FieldNaming.CONSISTENT)) {
                    addParameterName(algo, str);
                    return true;
                }
                Log.debug(new Object[]{"Could not set custom param " + str + " for algo " + algo});
                return false;
            } catch (IllegalArgumentException e) {
                throw new H2OIllegalValueException(str, ROOT_PARAM, v);
            }
        }

        private <D, V> boolean setField(D d, String str, V v, PojoUtils.FieldNaming fieldNaming) {
            try {
                PojoUtils.setField(d, str, v, fieldNaming);
                return true;
            } catch (IllegalArgumentException e) {
                try {
                    PojoUtils.getFieldValue(d, str, fieldNaming);
                    throw e;
                } catch (IllegalArgumentException e2) {
                    return false;
                }
            }
        }
    }

    /* loaded from: input_file:ai/h2o/automl/AutoMLBuildSpec$AutoMLInput.class */
    /**
     * Input specification of an AutoML run: frame keys and column roles.
     * NOTE(review): the per-field roles below are inferred from field names — confirm against the
     * AutoML driver. All fields except sort_metric are null unless assigned.
     */
    public static final class AutoMLInput extends Iced {
        // Frame the models are trained on.
        public Key<Frame> training_frame;
        // Optional frame used for validation.
        public Key<Frame> validation_frame;
        // Optional frame presumably used to train stacked-ensemble blenders.
        public Key<Frame> blending_frame;
        // Optional frame presumably used to score the leaderboard.
        public Key<Frame> leaderboard_frame;
        // Name of the target column.
        public String response_column;
        // Optional column holding explicit cross-validation fold assignments.
        public String fold_column;
        // Optional column holding observation weights.
        public String weights_column;
        // Columns to exclude from the predictors.
        public String[] ignored_columns;
        // Metric used to rank models; defaults to StoppingMetric.AUTO's name, i.e. "AUTO".
        public String sort_metric = ScoreKeeper.StoppingMetric.AUTO.name();
    }

    /* loaded from: input_file:ai/h2o/automl/AutoMLBuildSpec$AutoMLStoppingCriteria.class */
    /**
     * Stopping criteria of an AutoML run.
     *
     * <p>Every criterion except {@code max_runtime_secs_per_model} is delegated to an embedded
     * {@link HyperSpaceSearchCriteria.RandomDiscreteValueSearchCriteria}, which callers can obtain
     * through {@link #getSearchCriteria()}.
     */
    public static final class AutoMLStoppingCriteria extends Iced {
        /** Presumably a sentinel for an automatically chosen stopping tolerance — not used in this class. */
        public static final int AUTO_STOPPING_TOLERANCE = -1;

        private final HyperSpaceSearchCriteria.RandomDiscreteValueSearchCriteria _searchCriteria =
                new HyperSpaceSearchCriteria.RandomDiscreteValueSearchCriteria();
        private double _max_runtime_secs_per_model = 0.0d;

        /**
         * Creates criteria with max_models = 0, max_runtime_secs = 3600, max_runtime_secs_per_model = 0,
         * stopping_rounds = 3, stopping_tolerance = 0.001 and the AUTO stopping metric.
         */
        public AutoMLStoppingCriteria() {
            final HyperSpaceSearchCriteria.RandomDiscreteValueSearchCriteria criteria = _searchCriteria;
            criteria.set_max_models(0);
            criteria.set_max_runtime_secs(3600d);
            _max_runtime_secs_per_model = 0d;
            criteria.set_stopping_rounds(3);
            criteria.set_stopping_tolerance(1e-3);
            criteria.set_stopping_metric(ScoreKeeper.StoppingMetric.AUTO);
        }

        /** @return the underlying search criteria (shared, not a copy) */
        public HyperSpaceSearchCriteria.RandomDiscreteValueSearchCriteria getSearchCriteria() {
            return _searchCriteria;
        }

        public long seed() {
            return _searchCriteria.seed();
        }

        public void set_seed(long seed) {
            _searchCriteria.set_seed(seed);
        }

        public int max_models() {
            return _searchCriteria.max_models();
        }

        public void set_max_models(int maxModels) {
            _searchCriteria.set_max_models(maxModels);
        }

        public double max_runtime_secs() {
            return _searchCriteria.max_runtime_secs();
        }

        public void set_max_runtime_secs(double seconds) {
            _searchCriteria.set_max_runtime_secs(seconds);
        }

        /** Per-model time budget; the only criterion stored in this class itself. */
        public double max_runtime_secs_per_model() {
            return _max_runtime_secs_per_model;
        }

        public void set_max_runtime_secs_per_model(double seconds) {
            _max_runtime_secs_per_model = seconds;
        }

        public int stopping_rounds() {
            return _searchCriteria.stopping_rounds();
        }

        public void set_stopping_rounds(int rounds) {
            _searchCriteria.set_stopping_rounds(rounds);
        }

        public ScoreKeeper.StoppingMetric stopping_metric() {
            return _searchCriteria.stopping_metric();
        }

        public void set_stopping_metric(ScoreKeeper.StoppingMetric metric) {
            _searchCriteria.set_stopping_metric(metric);
        }

        public double stopping_tolerance() {
            return _searchCriteria.stopping_tolerance();
        }

        public void set_stopping_tolerance(double tolerance) {
            _searchCriteria.set_stopping_tolerance(tolerance);
        }

        /** Sets the stopping tolerance to the frame-dependent default computed by the search criteria. */
        public void set_default_stopping_tolerance_for_frame(Frame frame) {
            _searchCriteria.set_default_stopping_tolerance_for_frame(frame);
        }

        /** @return the frame-dependent default stopping tolerance, without modifying any criteria */
        public static double default_stopping_tolerance_for_frame(Frame frame) {
            return HyperSpaceSearchCriteria.RandomDiscreteValueSearchCriteria.default_stopping_tolerance_for_frame(frame);
        }
    }

    /**
     * Returns the project name of this AutoML run, generating one on first call if none was set.
     *
     * <p>A generated name is {@code "AutoML_"} followed by the current time (JVM default time zone)
     * formatted with {@code projectTimeStampFormat}. It is stored in
     * {@code build_control.project_name}, so later calls return the same value.
     *
     * @return the user-supplied or generated project name
     */
    public String project() {
        if (this.build_control.project_name == null) {
            final String timestamp;
            // SimpleDateFormat is not thread-safe, and this formatter is a static shared by every
            // AutoMLBuildSpec instance: serialize access so concurrent runs cannot corrupt it.
            synchronized (projectTimeStampFormat) {
                timestamp = projectTimeStampFormat.format(new Date());
            }
            this.build_control.project_name = "AutoML_" + timestamp;
        }
        return this.build_control.project_name;
    }
}
