package hex.deepwater;

import hex.Model;
import hex.ScoreKeeper;
import hex.genmodel.utils.DistributionFamily;
import java.io.File;
import java.lang.reflect.Field;
import java.net.URL;
import java.util.Arrays;
import java.util.Locale;
import javax.imageio.ImageIO;
import water.H2O;
import water.exceptions.H2OIllegalArgumentException;
import water.fvec.Vec;
import water.parser.BufferedString;
import water.util.ArrayUtils;
import water.util.Log;

/* loaded from: input_file:hex/deepwater/DeepWaterParameters.class */
public class DeepWaterParameters extends Model.Parameters {
    public String _network_definition_file;
    public String _network_parameters_file;
    public String _export_native_parameters_prefix;
    public String _mean_image_file;
    public double _clip_gradient = 10.0d;
    public boolean _gpu = true;
    public int[] _device_id = {0};
    public Network _network = Network.auto;
    public Backend _backend = Backend.mxnet;
    public ProblemType _problem_type = ProblemType.auto;
    public int[] _image_shape = {0, 0};
    public int _channels = 3;
    public boolean _overwrite_with_best_model = true;
    public boolean _autoencoder = false;
    public boolean _sparse = false;
    public boolean _use_all_factor_levels = true;
    public MissingValuesHandling _missing_values_handling = MissingValuesHandling.MeanImputation;
    public boolean _standardize = true;
    public double _epochs = 10.0d;
    public Activation _activation = null;
    public int[] _hidden = null;
    public double _input_dropout_ratio = 0.0d;
    public double[] _hidden_dropout_ratios = null;
    public long _train_samples_per_iteration = -2;
    public double _target_ratio_comm_to_comp = 0.05d;
    public double _learning_rate = 0.005d;
    public double _learning_rate_annealing = 1.0E-6d;
    public double _momentum_start = 0.9d;
    public double _momentum_ramp = 10000.0d;
    public double _momentum_stable = 0.99d;
    public double _score_interval = 5.0d;
    public long _score_training_samples = 10000;
    public long _score_validation_samples = 0;
    public double _score_duty_cycle = 0.1d;
    public boolean _quiet_mode = false;
    public boolean _replicate_training_data = true;
    public boolean _single_node_mode = false;
    public boolean _shuffle_training_data = true;
    public int _mini_batch_size = 32;
    protected boolean _cache_data = true;

    /* loaded from: input_file:hex/deepwater/DeepWaterParameters$Activation.class */
    public enum Activation {
        Rectifier,
        Tanh
    }

    /* loaded from: input_file:hex/deepwater/DeepWaterParameters$Backend.class */
    public enum Backend {
        unknown,
        mxnet,
        caffe,
        tensorflow,
        xgrpc
    }

    /* loaded from: input_file:hex/deepwater/DeepWaterParameters$MissingValuesHandling.class */
    public enum MissingValuesHandling {
        MeanImputation,
        Skip
    }

    /* loaded from: input_file:hex/deepwater/DeepWaterParameters$Network.class */
    public enum Network {
        auto,
        user,
        lenet,
        alexnet,
        vgg,
        googlenet,
        inception_bn,
        resnet
    }

    /* loaded from: input_file:hex/deepwater/DeepWaterParameters$ProblemType.class */
    public enum ProblemType {
        auto,
        image,
        text,
        dataset
    }

    /* JADX INFO: Access modifiers changed from: package-private */
    /* loaded from: input_file:hex/deepwater/DeepWaterParameters$Sanity.class */
    public static class Sanity {
        private static final transient String[] cp_modifiable = {"_seed", "_checkpoint", "_epochs", "_score_interval", "_train_samples_per_iteration", "_target_ratio_comm_to_comp", "_score_duty_cycle", "_score_training_samples", "_score_validation_samples", "_score_validation_sampling", "_classification_stop", "_regression_stop", "_stopping_rounds", "_stopping_metric", "_quiet_mode", "_max_confusion_matrix_size", "_max_hit_ratio_k", "_diagnostics", "_variable_importances", "_replicate_training_data", "_shuffle_training_data", "_single_node_mode", "_overwrite_with_best_model", "_mini_batch_size", "_network_parameters_file", "_clip_gradient", "_learning_rate", "_learning_rate_annealing", "_gpu", "_sparse", "_device_id", "_input_dropout_ratio", "_hidden_dropout_ratios", "_cache_data", "_export_native_parameters_prefix", "_image_shape"};
        private static final transient String[] cp_not_modifiable = {"_drop_na20_cols", "_missing_values_handling", "_response_column", "_activation", "_use_all_factor_levels", "_problem_type", "_channels", "_standardize", "_autoencoder", "_network", "_backend", "_momentum_start", "_momentum_ramp", "_momentum_stable", "_ignore_const_cols", "_max_categorical_features", "_nfolds", "_distribution", "_network_definition_file", "_mean_image_file"};

        Sanity() {
        }

        static void checkCompleteness() {
            for (Field field : DeepWaterParameters.class.getDeclaredFields()) {
                if (!ArrayUtils.contains(cp_not_modifiable, field.getName()) && !ArrayUtils.contains(cp_modifiable, field.getName()) && !field.getName().equals("_hidden") && !field.getName().equals("_ignored_columns") && !field.getName().equals("$jacocoData")) {
                    throw H2O.unimpl("Please add " + field.getName() + " to either cp_modifiable or cp_not_modifiable");
                }
            }
        }

        /* JADX INFO: Access modifiers changed from: package-private */
        public static void checkIfParameterChangeAllowed(DeepWaterParameters deepWaterParameters, DeepWaterParameters deepWaterParameters2) {
            checkCompleteness();
            if (deepWaterParameters2._nfolds != 0) {
                throw new UnsupportedOperationException("nfolds must be 0: Cross-validation is not supported during checkpoint restarts.");
            }
            if ((deepWaterParameters2._valid == null) != (deepWaterParameters._valid == null)) {
                throw new H2OIllegalArgumentException("Presence of validation dataset must agree with the checkpointed model.");
            }
            if (!deepWaterParameters2._autoencoder && (deepWaterParameters2._response_column == null || !deepWaterParameters2._response_column.equals(deepWaterParameters._response_column))) {
                throw new H2OIllegalArgumentException("Response column (" + deepWaterParameters2._response_column + ") is not the same as for the checkpointed model: " + deepWaterParameters._response_column);
            }
            if (!Arrays.equals(deepWaterParameters2._ignored_columns, deepWaterParameters._ignored_columns)) {
                throw new H2OIllegalArgumentException("Ignored columns must be the same as for the checkpointed model.");
            }
            loop0: for (Field field : deepWaterParameters.getClass().getFields()) {
                if (ArrayUtils.contains(cp_not_modifiable, field.getName())) {
                    for (Field field2 : deepWaterParameters2.getClass().getFields()) {
                        if (field.equals(field2)) {
                            try {
                                if ((field2.get(deepWaterParameters2) == null || field.get(deepWaterParameters) == null || !field.get(deepWaterParameters).toString().equals(field2.get(deepWaterParameters2).toString())) && (field.get(deepWaterParameters) != null || field2.get(deepWaterParameters2) != null)) {
                                    throw new H2OIllegalArgumentException("Cannot change parameter: '" + field.getName() + "': " + field.get(deepWaterParameters) + " -> " + field2.get(deepWaterParameters2));
                                    break loop0;
                                }
                            } catch (IllegalAccessException e) {
                                e.printStackTrace();
                            }
                        }
                    }
                }
            }
        }

        /* JADX INFO: Access modifiers changed from: package-private */
        public static void updateParametersDuringCheckpointRestart(DeepWaterParameters deepWaterParameters, DeepWaterParameters deepWaterParameters2, boolean z, boolean z2) {
            for (Field field : deepWaterParameters2.getClass().getFields()) {
                if (ArrayUtils.contains(cp_modifiable, field.getName())) {
                    for (Field field2 : deepWaterParameters.getClass().getFields()) {
                        if (field.equals(field2)) {
                            try {
                                if (field2.get(deepWaterParameters) == null || field.get(deepWaterParameters2) == null || !field.get(deepWaterParameters2).toString().equals(field2.get(deepWaterParameters).toString())) {
                                    if (field.get(deepWaterParameters2) != null || field2.get(deepWaterParameters) != null) {
                                        if (!deepWaterParameters2._quiet_mode && !z2) {
                                            Log.info(new Object[]{"Applying user-requested modification of '" + field.getName() + "': " + field.get(deepWaterParameters2) + " -> " + field2.get(deepWaterParameters)});
                                        }
                                        if (z) {
                                            field.set(deepWaterParameters2, field2.get(deepWaterParameters));
                                        }
                                    }
                                }
                            } catch (IllegalAccessException e) {
                                e.printStackTrace();
                            }
                        }
                    }
                }
            }
        }

        /* JADX INFO: Access modifiers changed from: package-private */
        public static void modifyParms(DeepWaterParameters deepWaterParameters, DeepWaterParameters deepWaterParameters2, int i) {
            if (H2O.CLOUD.size() == 1 && deepWaterParameters._replicate_training_data) {
                if (!deepWaterParameters._quiet_mode) {
                    Log.info(new Object[]{"_replicate_training_data: Disabling replicate_training_data on 1 node."});
                }
                deepWaterParameters2._replicate_training_data = false;
            }
            if (deepWaterParameters._distribution == DistributionFamily.AUTO) {
                if (i > 1) {
                    deepWaterParameters2._distribution = i == 2 ? DistributionFamily.bernoulli : DistributionFamily.multinomial;
                } else {
                    deepWaterParameters2._distribution = DistributionFamily.gaussian;
                }
            }
            if (deepWaterParameters._single_node_mode && (H2O.CLOUD.size() == 1 || !deepWaterParameters._replicate_training_data)) {
                if (!deepWaterParameters._quiet_mode) {
                    Log.info(new Object[]{"_single_node_mode: Disabling single_node_mode (only for multi-node operation with replicated training data)."});
                }
                deepWaterParameters2._single_node_mode = false;
            }
            if (deepWaterParameters._overwrite_with_best_model && deepWaterParameters._nfolds != 0) {
                if (!deepWaterParameters._quiet_mode) {
                    Log.info(new Object[]{"_overwrite_with_best_model: Disabling overwrite_with_best_model in combination with n-fold cross-validation."});
                }
                deepWaterParameters2._overwrite_with_best_model = false;
            }
            if (deepWaterParameters._problem_type == ProblemType.auto) {
                deepWaterParameters2._problem_type = deepWaterParameters.guessProblemType();
                if (!deepWaterParameters._quiet_mode) {
                    Log.info(new Object[]{"_problem_type: Automatically selecting problem_type: " + deepWaterParameters2._problem_type.toString()});
                }
            }
            if (deepWaterParameters._categorical_encoding == Model.Parameters.CategoricalEncodingScheme.AUTO) {
                if (!deepWaterParameters._quiet_mode) {
                    Log.info(new Object[]{"_categorical_encoding: Automatically enabling OneHotInternal categorical encoding."});
                }
                deepWaterParameters2._categorical_encoding = Model.Parameters.CategoricalEncodingScheme.OneHotInternal;
            }
            if (deepWaterParameters._nfolds != 0 && deepWaterParameters._overwrite_with_best_model) {
                if (!deepWaterParameters._quiet_mode) {
                    Log.info(new Object[]{"_overwrite_with_best_model: Automatically disabling overwrite_with_best_model, since the final model is the only scored model with n-fold cross-validation."});
                }
                deepWaterParameters2._overwrite_with_best_model = false;
            }
            if (deepWaterParameters._network == Network.auto || deepWaterParameters._network == null) {
                if (deepWaterParameters._network_definition_file == null || deepWaterParameters._network_definition_file.equals("")) {
                    if (deepWaterParameters2._problem_type == ProblemType.image) {
                        deepWaterParameters2._network = Network.inception_bn;
                    }
                    if (deepWaterParameters2._problem_type == ProblemType.text || deepWaterParameters2._problem_type == ProblemType.dataset) {
                        deepWaterParameters2._network = null;
                        if (deepWaterParameters._hidden == null) {
                            deepWaterParameters2._hidden = new int[]{200, 200};
                            deepWaterParameters2._activation = Activation.Rectifier;
                            deepWaterParameters2._hidden_dropout_ratios = new double[deepWaterParameters2._hidden.length];
                        }
                    }
                    if (!deepWaterParameters._quiet_mode && deepWaterParameters2._network != null && deepWaterParameters2._network != Network.user) {
                        Log.info(new Object[]{"_network: Using " + deepWaterParameters2._network + " model by default."});
                    }
                } else {
                    if (!deepWaterParameters._quiet_mode) {
                        Log.info(new Object[]{"_network_definition_file: Automatically setting network type to 'user', since a network definition file was provided."});
                    }
                    deepWaterParameters2._network = Network.user;
                }
            }
            if (deepWaterParameters._autoencoder && deepWaterParameters._stopping_metric == ScoreKeeper.StoppingMetric.AUTO) {
                if (!deepWaterParameters._quiet_mode) {
                    Log.info(new Object[]{"_stopping_metric: Automatically setting stopping_metric to MSE for autoencoder."});
                }
                deepWaterParameters2._stopping_metric = ScoreKeeper.StoppingMetric.MSE;
            }
            if (deepWaterParameters2._hidden != null) {
                if (deepWaterParameters2._hidden_dropout_ratios == null) {
                    if (!deepWaterParameters._quiet_mode) {
                        Log.info(new Object[]{"_hidden_dropout_ratios: Automatically setting hidden_dropout_ratios to 0 for all layers."});
                    }
                    deepWaterParameters2._hidden_dropout_ratios = new double[deepWaterParameters2._hidden.length];
                }
                if (deepWaterParameters2._activation == null) {
                    deepWaterParameters2._activation = Activation.Rectifier;
                    if (!deepWaterParameters._quiet_mode) {
                        Log.info(new Object[]{"_activation: Automatically setting activation to " + deepWaterParameters2._activation + " for all layers."});
                    }
                }
                if (deepWaterParameters._quiet_mode) {
                    return;
                }
                Log.info(new Object[]{"Hidden layers: " + Arrays.toString(deepWaterParameters2._hidden)});
                Log.info(new Object[]{"Activation function: " + deepWaterParameters2._activation});
                Log.info(new Object[]{"Input dropout ratio: " + deepWaterParameters2._input_dropout_ratio});
                Log.info(new Object[]{"Hidden layer dropout ratio: " + Arrays.toString(deepWaterParameters2._hidden_dropout_ratios)});
            }
        }
    }

    public String algoName() {
        return "DeepWater";
    }

    public String fullName() {
        return "Deep Water";
    }

    public String javaName() {
        return DeepWaterModel.class.getName();
    }

    protected double defaultStoppingTolerance() {
        return 0.0d;
    }

    public DeepWaterParameters() {
        this._stopping_rounds = 5;
    }

    public long progressUnits() {
        if (train() == null) {
            return 1L;
        }
        return (long) Math.ceil(this._epochs * train().numRows());
    }

    public float learningRate(double d) {
        return (float) (this._learning_rate / (1.0d + (this._learning_rate_annealing * d)));
    }

    public final float momentum(double d) {
        double d2 = this._momentum_start;
        if (this._momentum_ramp > 0.0d) {
            d2 = d >= this._momentum_ramp ? this._momentum_stable : d2 + (((this._momentum_stable - this._momentum_start) * d) / this._momentum_ramp);
        }
        return (float) d2;
    }

    /* JADX INFO: Access modifiers changed from: package-private */
    public void validate(DeepWater deepWater, boolean z) {
        boolean isClassifier = (z || deepWater.nclasses() != 0) ? deepWater.isClassifier() : this._distribution == DistributionFamily.bernoulli || this._distribution == DistributionFamily.bernoulli;
        if (this._mini_batch_size < 1) {
            deepWater.error("_mini_batch_size", "Mini-batch size must be >= 1");
        }
        if (this._weights_column != null && z) {
            Vec vec = train().vec(this._weights_column);
            if (!vec.isInt() || vec.max() > 1.0d || vec.min() < 0.0d) {
                deepWater.error("_weights_column", "only supporting weights of 0 or 1 right now");
            }
        }
        if (this._clip_gradient <= 0.0d) {
            deepWater.error("_clip_gradient", "Clip gradient must be >= 0");
        }
        if (this._hidden != null && this._network_definition_file != null && !this._network_definition_file.isEmpty()) {
            deepWater.error("_hidden", "Cannot provide hidden layers and a network definition file at the same time.");
        }
        if (this._activation != null && this._network_definition_file != null && !this._network_definition_file.isEmpty()) {
            deepWater.error("_activation", "Cannot provide activation functions and a network definition file at the same time.");
        }
        if (this._problem_type == ProblemType.image) {
            if (this._image_shape.length != 2) {
                deepWater.error("_image_shape", "image_shape must have 2 dimensions (width, height)");
            }
            if (this._image_shape[0] < 0) {
                deepWater.error("_image_shape", "image_shape[0] must be >=1 or automatic (0).");
            }
            if (this._image_shape[1] < 0) {
                deepWater.error("_image_shape", "image_shape[1] must be >=1 or automatic (0).");
            }
            if (this._channels != 1 && this._channels != 3) {
                deepWater.error("_channels", "channels must be either 1 or 3.");
            }
        } else if (this._problem_type != ProblemType.auto) {
            deepWater.warn("_image_shape", "image shape is ignored, only used for image_classification");
            deepWater.warn("_channels", "channels shape is ignored, only used for image_classification");
            deepWater.warn("_mean_image_file", "mean_image_file shape is ignored, only used for image_classification");
        }
        if (this._categorical_encoding == Model.Parameters.CategoricalEncodingScheme.Enum) {
            deepWater.error("_categorical_encoding", "categorical encoding scheme cannot be Enum: the neural network must have numeric columns as input.");
        }
        if (this._autoencoder) {
            deepWater.error("_autoencoder", "Autoencoder is not supported right now.");
        }
        if (this._network == Network.user) {
            if (this._network_definition_file == null || this._network_definition_file.isEmpty()) {
                deepWater.error("_network_definition_file", "network_definition_file must be provided if the network is user-specified.");
            } else if (!new File(this._network_definition_file).exists()) {
                deepWater.error("_network_definition_file", "network_definition_file " + this._network_definition_file + " not found.");
            }
        } else if (this._network_definition_file != null && !this._network_definition_file.isEmpty() && this._network != Network.auto) {
            deepWater.error("_network_definition_file", "network_definition_file cannot be provided if a pre-defined network is chosen.");
        }
        if (this._network_parameters_file != null && !this._network_parameters_file.isEmpty() && !new File(this._network_parameters_file).exists()) {
            deepWater.error("_network_parameters_file", "network_parameters_file " + this._network_parameters_file + " not found.");
        }
        if (this._checkpoint != null) {
            DeepWaterModel deepWaterModel = this._checkpoint.get();
            if (deepWaterModel == null) {
                deepWater.error("_width", "Invalid checkpoint provided: width mismatch.");
            }
            if (!Arrays.equals(this._image_shape, deepWaterModel.get_params()._image_shape)) {
                deepWater.error("_width", "Invalid checkpoint provided: width mismatch.");
            }
        }
        if (!this._autoencoder) {
            if (isClassifier) {
                deepWater.hide("_regression_stop", "regression_stop is used only with regression.");
            } else {
                deepWater.hide("_classification_stop", "classification_stop is used only with classification.");
            }
            if ((!isClassifier && this._valid != null) || this._valid == null) {
                deepWater.hide("_score_validation_sampling", "score_validation_sampling requires classification and a validation frame.");
            }
        } else if (this._nfolds > 1) {
            deepWater.error("_nfolds", "N-fold cross-validation is not supported for Autoencoder.");
        }
        if (H2O.CLOUD.size() == 1 && this._replicate_training_data) {
            deepWater.hide("_replicate_training_data", "replicate_training_data is only valid with cloud size greater than 1.");
        }
        if (this._single_node_mode && (H2O.CLOUD.size() == 1 || !this._replicate_training_data)) {
            deepWater.hide("_single_node_mode", "single_node_mode is only used with multi-node operation with replicated training data.");
        }
        if (H2O.ARGS.client && this._single_node_mode) {
            deepWater.error("_single_node_mode", "Cannot run on a single node in client mode");
        }
        if (this._autoencoder) {
            deepWater.hide("_use_all_factor_levels", "use_all_factor_levels is mandatory in combination with autoencoder.");
        }
        if (this._nfolds != 0) {
            deepWater.hide("_overwrite_with_best_model", "overwrite_with_best_model is unsupported in combination with n-fold cross-validation.");
        }
        if (z) {
            deepWater.checkDistributions();
        }
        if (this._score_training_samples < 0) {
            deepWater.error("_score_training_samples", "Number of training samples for scoring must be >= 0 (0 for all).");
        }
        if (this._score_validation_samples < 0) {
            deepWater.error("_score_validation_samples", "Number of training samples for scoring must be >= 0 (0 for all).");
        }
        if (isClassifier && deepWater.hasOffsetCol()) {
            deepWater.error("_offset_column", "Offset is only supported for regression.");
        }
        if (z) {
            if (!isClassifier && this._balance_classes) {
                deepWater.error("_balance_classes", "balance_classes requires classification.");
            }
            if (this._class_sampling_factors != null && !this._balance_classes) {
                deepWater.error("_class_sampling_factors", "class_sampling_factors requires balance_classes to be enabled.");
            }
            if (this._replicate_training_data && null != train() && train().byteSize() > (0.9d * H2O.CLOUD.free_mem()) / H2O.CLOUD.size() && H2O.CLOUD.size() > 1) {
                deepWater.error("_replicate_training_data", "Compressed training dataset takes more than 90% of avg. free available memory per node (" + ((0.9d * H2O.CLOUD.free_mem()) / H2O.CLOUD.size()) + "), cannot run with replicate_training_data.");
            }
        }
        if (!this._autoencoder || this._stopping_metric == ScoreKeeper.StoppingMetric.AUTO || this._stopping_metric == ScoreKeeper.StoppingMetric.MSE) {
            return;
        }
        deepWater.error("_stopping_metric", "Stopping metric must either be AUTO or MSE for autoencoder.");
    }

    /* JADX INFO: Access modifiers changed from: package-private */
    public ProblemType guessProblemType() {
        if (this._problem_type != ProblemType.auto) {
            return this._problem_type;
        }
        boolean z = false;
        boolean z2 = false;
        String str = null;
        Vec vec = train().vec(0);
        if (vec.isString() || vec.isCategorical()) {
            str = vec.atStr(new BufferedString(), 0L).toString();
            try {
                ImageIO.read(new File(str));
                z = true;
            } catch (Throwable th) {
            }
            try {
                ImageIO.read(new URL(str));
                z = true;
            } catch (Throwable th2) {
            }
        }
        if (str != null) {
            if (!z && (str.endsWith(".jpg") || str.endsWith(".png") || str.endsWith(".tif"))) {
                z = true;
                Log.warn(new Object[]{"Cannot read first image at " + str + " - Check data."});
            } else if (vec.isString() && train().numCols() <= 4) {
                z2 = true;
            }
        }
        return z ? ProblemType.image : z2 ? ProblemType.text : ProblemType.dataset;
    }
}
