package hex.deeplearning;

import hex.DataInfo;
import hex.DistributionFactory;
import hex.Model;
import hex.ModelBuilder;
import hex.ModelBuilderListener;
import hex.ModelCategory;
import hex.ScoreKeeper;
import hex.ToEigenVec;
import hex.deeplearning.DeepLearningModel;
import hex.genmodel.utils.DistributionFamily;
import hex.glm.GLMTask;
import hex.util.EffectiveParametersUtils;
import hex.util.LinearAlgebraUtils;
import java.lang.reflect.Field;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Iterator;
import jsr166y.CountedCompleter;
import org.apache.commons.math3.optimization.direct.CMAESOptimizer;
import water.DKV;
import water.H2O;
import water.IcedUtils;
import water.Job;
import water.Key;
import water.Lockable;
import water.Scope;
import water.exceptions.H2OIllegalArgumentException;
import water.exceptions.H2OModelBuilderIllegalArgumentException;
import water.fvec.Frame;
import water.fvec.RebalanceDataSet;
import water.fvec.Vec;
import water.init.Linpack;
import water.init.NetworkTest;
import water.util.ArrayUtils;
import water.util.Log;
import water.util.MRUtils;
import water.util.PrettyPrint;

/* loaded from: input_file:hex/deeplearning/DeepLearning.class */
public class DeepLearning extends ModelBuilder<DeepLearningModel, DeepLearningModel.DeepLearningParameters, DeepLearningModel.DeepLearningModelOutput> {
    // Decompiler-reconstructed synthetic flag used in place of `assert` statements: when true,
    // the guarded assertion checks are skipped. NOTE(review): its initializer is not visible in
    // this chunk — presumably a static block using DeepLearning.class.desiredAssertionStatus().
    static final /* synthetic */ boolean $assertionsDisabled;

    /* loaded from: input_file:hex/deeplearning/DeepLearning$DeepLearningDriver.class */
    public class DeepLearningDriver extends ModelBuilder<DeepLearningModel, DeepLearningModel.DeepLearningParameters, DeepLearningModel.DeepLearningModelOutput>.Driver {
        static final /* synthetic */ boolean $assertionsDisabled;

        public DeepLearningDriver() {
            super();
        }

        @Override // hex.ModelBuilder.Driver
        public void computeImpl() {
            DeepLearning.this.init(true);
            if (Model.evaluateAutoModelParameters()) {
                initActualParamValues();
            }
            Model.Parameters clone = ((DeepLearningModel.DeepLearningParameters) DeepLearning.this._parms).m2168clone();
            if (DeepLearning.this.error_count() > 0) {
                throw H2OModelBuilderIllegalArgumentException.makeFromBuilder(DeepLearning.this);
            }
            buildModel();
            checkNonAutoParmsNotChanged(clone, DeepLearning.this._parms);
        }

        public void checkNonAutoParmsNotChanged(Model.Parameters parameters, Model.Parameters parameters2) {
            try {
                for (Field field : parameters.getClass().getFields()) {
                    field.getType();
                    Object obj = field.get(parameters);
                    if (obj != null && !"AUTO".equalsIgnoreCase(obj.toString())) {
                        Object obj2 = field.get(parameters2);
                        if (!$assertionsDisabled && !obj.toString().equalsIgnoreCase(obj2.toString())) {
                            throw new AssertionError("Found non-AUTO value in _parms which has changed during DL model training");
                        }
                    }
                }
            } catch (IllegalAccessException e) {
                throw new RuntimeException("Error while checking param changes during DL model training", e);
            }
        }

        public final void buildModel() {
            DeepLearningModel deepLearningModel;
            Lockable lockable = null;
            ArrayList arrayList = new ArrayList();
            if (((DeepLearningModel.DeepLearningParameters) DeepLearning.this._parms)._checkpoint == null) {
                deepLearningModel = new DeepLearningModel(DeepLearning.this.dest(), (DeepLearningModel.DeepLearningParameters) DeepLearning.this._parms, new DeepLearningModel.DeepLearningModelOutput(DeepLearning.this), DeepLearning.this._train, DeepLearning.this._valid, DeepLearning.this.nclasses());
                if (((DeepLearningModel.DeepLearningParameters) DeepLearning.this._parms)._pretrained_autoencoder != null) {
                    DeepLearningModel deepLearningModel2 = (DeepLearningModel) DKV.getGet(((DeepLearningModel.DeepLearningParameters) DeepLearning.this._parms)._pretrained_autoencoder);
                    if (deepLearningModel2 == null) {
                        throw new H2OIllegalArgumentException("The pretrained model '" + ((DeepLearningModel.DeepLearningParameters) DeepLearning.this._parms)._pretrained_autoencoder + "' cannot be found.");
                    }
                    if (((DeepLearningModel.DeepLearningParameters) DeepLearning.this._parms)._autoencoder || !((DeepLearningModel.DeepLearningParameters) deepLearningModel2._parms)._autoencoder) {
                        throw new H2OIllegalArgumentException("The pretrained model must be unsupervised (an autoencoder), and the model to be trained must be supervised.");
                    }
                    Log.info("Loading model parameters of input and hidden layers from the pretrained autoencoder model.");
                    deepLearningModel.model_info().initializeFromPretrainedModel(deepLearningModel2.model_info());
                } else {
                    deepLearningModel.model_info().initializeMembers(((DeepLearningModel.DeepLearningParameters) DeepLearning.this._parms)._initial_weights, ((DeepLearningModel.DeepLearningParameters) DeepLearning.this._parms)._initial_biases);
                }
            } else {
                DeepLearningModel deepLearningModel3 = (DeepLearningModel) DKV.getGet(((DeepLearningModel.DeepLearningParameters) DeepLearning.this._parms)._checkpoint);
                if (deepLearningModel3 == null) {
                    throw new IllegalArgumentException("Checkpoint not found.");
                }
                Log.info("Resuming from checkpoint.");
                DeepLearning.this._job.update(0L, "Resuming from checkpoint");
                if (DeepLearning.this.isClassifier() != ((DeepLearningModel.DeepLearningModelOutput) deepLearningModel3._output).isClassifier()) {
                    throw new H2OIllegalArgumentException("Response type must be the same as for the checkpointed model.");
                }
                if (DeepLearning.this.isSupervised() != ((DeepLearningModel.DeepLearningModelOutput) deepLearningModel3._output).isSupervised()) {
                    throw new H2OIllegalArgumentException("Model type must be the same as for the checkpointed model.");
                }
                DeepLearningModel.DeepLearningParameters.Sanity.checkIfParameterChangeAllowed((DeepLearningModel.DeepLearningParameters) deepLearningModel3._input_parms, (DeepLearningModel.DeepLearningParameters) DeepLearning.this._parms);
                try {
                    try {
                        for (String str : deepLearningModel3.adaptTestForTrain(DeepLearning.this._train, true, false)) {
                            Log.warn(str);
                        }
                        for (String str2 : deepLearningModel3.adaptTestForTrain(DeepLearning.this._valid, true, false)) {
                            Log.warn(str2);
                        }
                        DataInfo makeDataInfo = DeepLearning.makeDataInfo(DeepLearning.this._train, DeepLearning.this._valid, (DeepLearningModel.DeepLearningParameters) DeepLearning.this._parms, DeepLearning.this.nclasses());
                        DKV.put(makeDataInfo);
                        arrayList.add(makeDataInfo._key);
                        deepLearningModel = new DeepLearningModel(DeepLearning.this.dest(), (DeepLearningModel.DeepLearningParameters) DeepLearning.this._parms, deepLearningModel3, false, makeDataInfo);
                        deepLearningModel.write_lock(DeepLearning.this._job);
                        if (!Arrays.equals(((DeepLearningModel.DeepLearningModelOutput) deepLearningModel._output)._names, ((DeepLearningModel.DeepLearningModelOutput) deepLearningModel3._output)._names)) {
                            throw new H2OIllegalArgumentException("The columns of the training data must be the same as for the checkpointed model. Check ignored columns (or disable ignore_const_cols).");
                        }
                        if (!Arrays.deepEquals(((DeepLearningModel.DeepLearningModelOutput) deepLearningModel._output)._domains, ((DeepLearningModel.DeepLearningModelOutput) deepLearningModel3._output)._domains)) {
                            throw new H2OIllegalArgumentException("Categorical factor levels of the training data must be the same as for the checkpointed model.");
                        }
                        if (makeDataInfo.fullN() != deepLearningModel3.model_info().data_info().fullN()) {
                            throw new H2OIllegalArgumentException("Total number of predictors is different than for the checkpointed model.");
                        }
                        if (((DeepLearningModel.DeepLearningParameters) DeepLearning.this._parms)._epochs <= deepLearningModel3.epoch_counter) {
                            throw new H2OIllegalArgumentException("Total number of epochs must be larger than the number of epochs already trained for the checkpointed model (" + deepLearningModel3.epoch_counter + ").");
                        }
                        DeepLearningModel.DeepLearningParameters deepLearningParameters = deepLearningModel.model_info().get_params();
                        if (!$assertionsDisabled && deepLearningParameters == deepLearningModel3.model_info().get_params()) {
                            throw new AssertionError();
                        }
                        if (!$assertionsDisabled && deepLearningParameters == DeepLearning.this._parms) {
                            throw new AssertionError();
                        }
                        if (!$assertionsDisabled && deepLearningParameters == deepLearningModel3._parms) {
                            throw new AssertionError();
                        }
                        DeepLearningModel.DeepLearningParameters.Sanity.updateParametersDuringCheckpointRestart((DeepLearningModel.DeepLearningParameters) DeepLearning.this._parms, (DeepLearningModel.DeepLearningParameters) deepLearningModel3._parms, false, false);
                        DeepLearningModel.DeepLearningParameters.Sanity.updateParametersDuringCheckpointRestart((DeepLearningModel.DeepLearningParameters) DeepLearning.this._parms, deepLearningModel.model_info().get_params(), true, true);
                        DeepLearningModel.DeepLearningParameters.Sanity.modifyParms((DeepLearningModel.DeepLearningParameters) DeepLearning.this._parms, deepLearningModel.model_info().get_params(), DeepLearning.this.nclasses());
                        Log.info("Continuing training after " + String.format("%.3f", Double.valueOf(deepLearningModel3.epoch_counter)) + " epochs from the checkpointed model.");
                        deepLearningModel.update(DeepLearning.this._job);
                        if (deepLearningModel != null) {
                            deepLearningModel.unlock(DeepLearning.this._job);
                        }
                    } catch (H2OIllegalArgumentException e) {
                        if (0 != 0) {
                            lockable.unlock(DeepLearning.this._job);
                            lockable.delete();
                        }
                        throw e;
                    }
                } catch (Throwable th) {
                    if (0 != 0) {
                        lockable.unlock(DeepLearning.this._job);
                    }
                    throw th;
                }
            }
            DistributionFamily distributionFamily = deepLearningModel.model_info().get_params()._distribution;
            if (Model.evaluateAutoModelParameters() && ((DeepLearningModel.DeepLearningParameters) DeepLearning.this._parms)._distribution == DistributionFamily.AUTO) {
                ((DeepLearningModel.DeepLearningParameters) DeepLearning.this._parms)._distribution = distributionFamily;
                ((DeepLearningModel.DeepLearningParameters) deepLearningModel._parms)._distribution = distributionFamily;
            }
            trainModel(deepLearningModel);
            Iterator it = arrayList.iterator();
            while (it.hasNext()) {
                DKV.remove((Key) it.next());
            }
            ArrayList arrayList2 = new ArrayList();
            try {
                if (((DeepLearningModel.DeepLearningParameters) DeepLearning.this._parms)._export_weights_and_biases && ((DeepLearningModel.DeepLearningModelOutput) deepLearningModel._output).weights != null && ((DeepLearningModel.DeepLearningModelOutput) deepLearningModel._output).biases != null) {
                    for (Key key : Arrays.asList(((DeepLearningModel.DeepLearningModelOutput) deepLearningModel._output).weights)) {
                        arrayList2.add(key);
                        for (Vec vec : ((Frame) DKV.getGet(key)).vecs()) {
                            arrayList2.add(vec._key);
                        }
                    }
                    for (Key key2 : Arrays.asList(((DeepLearningModel.DeepLearningModelOutput) deepLearningModel._output).biases)) {
                        arrayList2.add(key2);
                        for (Vec vec2 : ((Frame) DKV.getGet(key2)).vecs()) {
                            arrayList2.add(vec2._key);
                        }
                    }
                }
                Scope.exit((Key[]) arrayList2.toArray(new Key[arrayList2.size()]));
            } catch (Throwable th2) {
                Scope.exit((Key[]) arrayList2.toArray(new Key[arrayList2.size()]));
                throw th2;
            }
        }

        /* JADX WARN: Multi-variable type inference failed */
        public final DeepLearningModel trainModel(DeepLearningModel deepLearningModel) {
            DeepLearningModel deepLearningModel2;
            Frame frame = null;
            if (deepLearningModel == null) {
                try {
                    deepLearningModel = (DeepLearningModel) DKV.get(DeepLearning.this.dest()).get();
                } catch (Throwable th) {
                    if (!((DeepLearningModel.DeepLearningParameters) DeepLearning.this._parms)._quiet_mode) {
                        Log.info("==============================================================================================================================================================================");
                        if (DeepLearning.this.stop_requested()) {
                            if (DeepLearning.this.timeout()) {
                                DeepLearning.this.warn("_max_runtime_secs", "Deep Learning model training was interrupted due to timeout.  Increase _max_runtime_secs or set it to 0 to disable it.");
                            }
                            Log.info("Deep Learning model training was interrupted.");
                        } else {
                            Log.info("Finished training the Deep Learning model.");
                            if (deepLearningModel != null) {
                                Log.info(deepLearningModel);
                            }
                        }
                        Log.info("==============================================================================================================================================================================");
                    }
                    if (deepLearningModel != null) {
                        deepLearningModel.deleteElasticAverageModels();
                        deepLearningModel.unlock(DeepLearning.this._job);
                        if (deepLearningModel.actual_best_model_key != null) {
                            if (!$assertionsDisabled && deepLearningModel.actual_best_model_key == deepLearningModel._key) {
                                throw new AssertionError();
                            }
                            DKV.remove(deepLearningModel.actual_best_model_key);
                        }
                    }
                    throw th;
                }
            }
            Object[] objArr = new Object[1];
            objArr[0] = "Model category: " + (((DeepLearningModel.DeepLearningParameters) DeepLearning.this._parms)._autoencoder ? "Auto-Encoder" : DeepLearning.this.isClassifier() ? "Classification" : "Regression");
            Log.info(objArr);
            Log.info("Number of model parameters (weights/biases): " + String.format("%,d", Long.valueOf(deepLearningModel.model_info().size())));
            deepLearningModel.write_lock(DeepLearning.this._job);
            DeepLearning.this._job.update(0L, "Setting up training data...");
            DeepLearningModel.DeepLearningParameters deepLearningParameters = deepLearningModel.model_info().get_params();
            Frame frame2 = new Frame(deepLearningParameters._train, DeepLearning.this._train.names(), DeepLearning.this._train.vecs());
            Frame frame3 = DeepLearning.this._valid != null ? new Frame(deepLearningParameters._valid, DeepLearning.this._valid.names(), DeepLearning.this._valid.vecs()) : null;
            Frame frame4 = frame2;
            if (((DeepLearningModel.DeepLearningModelOutput) deepLearningModel._output).isClassifier() && deepLearningParameters._balance_classes) {
                DeepLearning.this._job.update(0L, "Balancing class distribution of training data...");
                float[] fArr = new float[frame4.lastVec().domain().length];
                if (deepLearningParameters._class_sampling_factors != null) {
                    if (deepLearningParameters._class_sampling_factors.length != frame4.lastVec().domain().length) {
                        throw new IllegalArgumentException("class_sampling_factors must have " + frame4.lastVec().domain().length + " elements");
                    }
                    fArr = (float[]) deepLearningParameters._class_sampling_factors.clone();
                }
                frame4 = MRUtils.sampleFrameStratified(frame4, frame4.lastVec(), frame4.vec(((DeepLearningModel.DeepLearningModelOutput) deepLearningModel._output).weightsName()), fArr, deepLearningParameters._max_after_balance_size * ((float) frame4.numRows()), deepLearningParameters._seed, true, false);
                Vec lastVec = frame4.lastVec();
                Vec vec = frame4.vec(((DeepLearningModel.DeepLearningModelOutput) deepLearningModel._output).weightsName());
                MRUtils.ClassDist classDist = new MRUtils.ClassDist(lastVec);
                ((DeepLearningModel.DeepLearningModelOutput) deepLearningModel._output)._modelClassDist = DeepLearning.this._weights != null ? classDist.doAll(lastVec, vec).relDist() : classDist.doAll(lastVec).relDist();
            }
            deepLearningModel.training_rows = frame4.numRows();
            if (DeepLearning.this._weights != null && DeepLearning.this._weights.min() == CMAESOptimizer.DEFAULT_STOPFITNESS && DeepLearning.this._weights.max() == 1.0d && DeepLearning.this._weights.isInt()) {
                deepLearningModel.training_rows = Math.round(frame4.numRows() * DeepLearning.this._weights.mean());
                Log.warn("Not counting " + (frame4.numRows() - deepLearningModel.training_rows) + " rows with weight=0 towards an epoch.");
            }
            Log.info("One epoch corresponds to " + deepLearningModel.training_rows + " training data rows.");
            Frame sampleFrame = MRUtils.sampleFrame(frame4, deepLearningParameters._score_training_samples, deepLearningParameters._seed);
            if (sampleFrame != frame4) {
                Scope.track(sampleFrame);
            }
            if (!((DeepLearningModel.DeepLearningParameters) DeepLearning.this._parms)._quiet_mode) {
                Log.info("Number of chunks of the training data: " + frame4.anyVec().nChunks());
            }
            if (frame3 != null) {
                deepLearningModel.validation_rows = frame3.numRows();
                if (((DeepLearningModel.DeepLearningModelOutput) deepLearningModel._output).isClassifier() && deepLearningParameters._balance_classes && deepLearningParameters._score_validation_sampling == DeepLearningModel.DeepLearningParameters.ClassSamplingMethod.Stratified) {
                    DeepLearning.this._job.update(0L, "Sampling validation data (stratified)...");
                    frame = MRUtils.sampleFrameStratified(frame3, frame3.lastVec(), frame3.vec(((DeepLearningModel.DeepLearningModelOutput) deepLearningModel._output).weightsName()), (float[]) null, deepLearningParameters._score_validation_samples > 0 ? deepLearningParameters._score_validation_samples : frame3.numRows(), deepLearningParameters._seed + 1, false, false);
                } else {
                    DeepLearning.this._job.update(0L, "Sampling validation data...");
                    frame = MRUtils.sampleFrame(frame3, deepLearningParameters._score_validation_samples, deepLearningParameters._seed + 1);
                    if (frame != frame3) {
                        Scope.track(frame);
                    }
                }
                if (!((DeepLearningModel.DeepLearningParameters) DeepLearning.this._parms)._quiet_mode) {
                    Log.info("Number of chunks of the validation data: " + frame.anyVec().nChunks());
                }
            }
            deepLearningModel.actual_train_samples_per_iteration = DeepLearning.computeTrainSamplesPerIteration(deepLearningParameters, deepLearningModel.training_rows, deepLearningModel);
            if (deepLearningParameters._replicate_training_data) {
                if (deepLearningModel.actual_train_samples_per_iteration == deepLearningModel.training_rows * (deepLearningParameters._single_node_mode ? 1 : H2O.CLOUD.size()) && !deepLearningParameters._shuffle_training_data && H2O.CLOUD.size() > 1 && !deepLearningParameters._reproducible) {
                    if (!deepLearningParameters._quiet_mode) {
                        Log.info("Enabling training data shuffling, because all nodes train on the full dataset (replicated training data).");
                    }
                    deepLearningParameters._shuffle_training_data = true;
                }
            }
            if (!deepLearningParameters._shuffle_training_data && deepLearningModel.actual_train_samples_per_iteration == deepLearningModel.training_rows && frame4.anyVec().nChunks() == 1) {
                if (!deepLearningParameters._quiet_mode) {
                    Log.info("Enabling training data shuffling to avoid training rows in the same order over and over (no Hogwild since there's only 1 chunk).");
                }
                deepLearningParameters._shuffle_training_data = true;
            }
            long currentTimeMillis = System.currentTimeMillis();
            deepLearningModel._timeLastIterationEnter = currentTimeMillis;
            if (((DeepLearningModel.DeepLearningParameters) DeepLearning.this._parms)._autoencoder) {
                DeepLearning.this._job.update(0L, "Scoring null model of autoencoder...");
                if (!deepLearningParameters._quiet_mode) {
                    Log.info("Scoring the null model of the autoencoder.");
                }
                deepLearningModel.doScoring(sampleFrame, frame, DeepLearning.this._job._key, 0, false);
            }
            deepLearningModel.update(DeepLearning.this._job);
            deepLearningModel.total_setup_time_ms += currentTimeMillis - DeepLearning.this._job.start_time();
            Log.info("Total setup time: " + PrettyPrint.msecs(deepLearningModel.total_setup_time_ms, true));
            Log.info("Starting to train the Deep Learning model.");
            DeepLearning.this._job.update(0L, "Training...");
            while (true) {
                deepLearningModel.iterations++;
                deepLearningModel.set_model_info(deepLearningParameters._epochs == CMAESOptimizer.DEFAULT_STOPFITNESS ? deepLearningModel.model_info() : (H2O.CLOUD.size() <= 1 || !deepLearningParameters._replicate_training_data) ? ((DeepLearningTask) new DeepLearningTask(DeepLearning.this._job._key, deepLearningModel.model_info(), rowFraction(frame4, deepLearningParameters, deepLearningModel), deepLearningModel.iterations).doAll(frame4)).model_info() : deepLearningParameters._single_node_mode ? new DeepLearningTask2(DeepLearning.this._job._key, frame4, deepLearningModel.model_info(), rowFraction(frame4, deepLearningParameters, deepLearningModel), deepLearningModel.iterations).doAll(Key.make(H2O.SELF)).model_info() : new DeepLearningTask2(DeepLearning.this._job._key, frame4, deepLearningModel.model_info(), rowFraction(frame4, deepLearningParameters, deepLearningModel), deepLearningModel.iterations).doAllNodes().model_info());
                if (DeepLearning.this.stop_requested() && !DeepLearning.this.timeout()) {
                    throw new Job.JobCancelledException();
                }
                if (!deepLearningModel.doScoring(sampleFrame, frame, DeepLearning.this._job._key, deepLearningModel.iterations, false)) {
                    break;
                }
                if (DeepLearning.this.timeout()) {
                    DeepLearning.this._job.update((long) (deepLearningParameters._epochs * frame4.numRows()));
                    break;
                }
            }
            if (!DeepLearning.this.stop_requested() && ((DeepLearningModel.DeepLearningParameters) DeepLearning.this._parms)._overwrite_with_best_model && deepLearningModel.actual_best_model_key != null && ((DeepLearningModel.DeepLearningParameters) DeepLearning.this._parms)._nfolds == 0 && (deepLearningModel2 = (DeepLearningModel) DKV.getGet(deepLearningModel.actual_best_model_key)) != null && deepLearningModel2.loss() < deepLearningModel.loss() && Arrays.equals(deepLearningModel2.model_info().units, deepLearningModel.model_info().units)) {
                if (!((DeepLearningModel.DeepLearningParameters) DeepLearning.this._parms)._quiet_mode) {
                    Log.info("Setting the model to be the best model so far (based on scoring history).");
                    Log.info("Best model's loss: " + deepLearningModel2.loss() + " vs this model's loss (before overwriting it with the best model): " + deepLearningModel.loss());
                }
                DeepLearningModelInfo deepLearningModelInfo = (DeepLearningModelInfo) IcedUtils.deepCopy(deepLearningModel2.model_info());
                deepLearningModelInfo.set_processed_global(deepLearningModel.model_info().get_processed_global());
                deepLearningModelInfo.set_processed_local(deepLearningModel.model_info().get_processed_local());
                DeepLearningModel.DeepLearningParameters deepLearningParameters2 = deepLearningModel.model_info().get_params();
                deepLearningModel.set_model_info(deepLearningModelInfo);
                deepLearningModel.model_info().parameters = deepLearningParameters2;
                deepLearningModel.update(DeepLearning.this._job);
                deepLearningModel.doScoring(sampleFrame, frame, DeepLearning.this._job._key, deepLearningModel.iterations, true);
                if (deepLearningModel2.loss() != deepLearningModel.loss()) {
                    if (!((DeepLearningModel.DeepLearningParameters) DeepLearning.this._parms)._quiet_mode) {
                        Log.info("Best model's loss: " + deepLearningModel2.loss() + " vs this model's loss (after overwriting it with the best model) : " + deepLearningModel.loss());
                    }
                    Log.warn("Even though the model was reset to the previous best model, we observe different scoring results. Most likely, the data set has changed during a checkpoint restart. If so, please compare the metrics to observe your data shift.");
                }
            }
            deepLearningModel.model_info().data_info().coefNames();
            if (!((DeepLearningModel.DeepLearningParameters) DeepLearning.this._parms)._quiet_mode) {
                Log.info("==============================================================================================================================================================================");
                if (DeepLearning.this.stop_requested()) {
                    if (DeepLearning.this.timeout()) {
                        DeepLearning.this.warn("_max_runtime_secs", "Deep Learning model training was interrupted due to timeout.  Increase _max_runtime_secs or set it to 0 to disable it.");
                    }
                    Log.info("Deep Learning model training was interrupted.");
                } else {
                    Log.info("Finished training the Deep Learning model.");
                    if (deepLearningModel != null) {
                        Log.info(deepLearningModel);
                    }
                }
                Log.info("==============================================================================================================================================================================");
            }
            if (deepLearningModel != null) {
                deepLearningModel.deleteElasticAverageModels();
                deepLearningModel.unlock(DeepLearning.this._job);
                if (deepLearningModel.actual_best_model_key != null) {
                    if (!$assertionsDisabled && deepLearningModel.actual_best_model_key == deepLearningModel._key) {
                        throw new AssertionError();
                    }
                    DKV.remove(deepLearningModel.actual_best_model_key);
                }
            }
            return deepLearningModel;
        }

        public void initActualParamValues() {
            if (!((DeepLearningModel.DeepLearningParameters) DeepLearning.this._parms)._autoencoder) {
                EffectiveParametersUtils.initStoppingMetric(DeepLearning.this._parms, DeepLearning.this.isClassifier());
            } else if (((DeepLearningModel.DeepLearningParameters) DeepLearning.this._parms)._stopping_metric == ScoreKeeper.StoppingMetric.AUTO) {
                ((DeepLearningModel.DeepLearningParameters) DeepLearning.this._parms)._stopping_metric = ScoreKeeper.StoppingMetric.MSE;
            }
            EffectiveParametersUtils.initCategoricalEncoding(DeepLearning.this._parms, Model.Parameters.CategoricalEncodingScheme.OneHotInternal);
        }

        private float computeRowUsageFraction(long j, long j2, boolean z) {
            float f = ((float) j2) / ((float) j);
            if (z) {
                f /= H2O.CLOUD.size();
            }
            if ($assertionsDisabled || f > 0.0f) {
                return f;
            }
            throw new AssertionError();
        }

        private float rowFraction(Frame frame, DeepLearningModel.DeepLearningParameters deepLearningParameters, DeepLearningModel deepLearningModel) {
            return computeRowUsageFraction(frame.numRows(), deepLearningModel.actual_train_samples_per_iteration, deepLearningParameters._replicate_training_data);
        }

        // The four methods below are marked bridge/synthetic by the decompiler. Each one forwards its
        // arguments unchanged to the ModelBuilder.Driver superclass implementation and adds no logic.

        /** Delegates exceptional-completion handling to {@code ModelBuilder.Driver}. */
        @Override // hex.ModelBuilder.Driver, jsr166y.CountedCompleter
        public /* bridge */ /* synthetic */ boolean onExceptionalCompletion(Throwable th, CountedCompleter countedCompleter) {
            return super.onExceptionalCompletion(th, countedCompleter);
        }

        /** Delegates normal-completion handling to {@code ModelBuilder.Driver}. */
        @Override // hex.ModelBuilder.Driver, jsr166y.CountedCompleter
        public /* bridge */ /* synthetic */ void onCompletion(CountedCompleter countedCompleter) {
            super.onCompletion(countedCompleter);
        }

        /** Delegates task execution to {@code ModelBuilder.Driver.compute2()}. */
        @Override // hex.ModelBuilder.Driver, water.H2O.H2OCountedCompleter
        public /* bridge */ /* synthetic */ void compute2() {
            super.compute2();
        }

        /** Delegates listener registration to {@code ModelBuilder.Driver}. */
        @Override // hex.ModelBuilder.Driver
        public /* bridge */ /* synthetic */ void setCallback(ModelBuilderListener modelBuilderListener) {
            super.setCallback(modelBuilderListener);
        }

        // Decompiled form of javac's assertion support: the flag is true when assertions are disabled
        // for the top-level DeepLearning class, and guards the explicit AssertionError checks above.
        static {
            $assertionsDisabled = !DeepLearning.class.desiredAssertionStatus();
        }
    }

    /**
     * Creates a builder for the given parameters and immediately runs the cheap (non-expensive)
     * validation pass via {@code init(false)}.
     *
     * @param deepLearningParameters model parameters
     */
    public DeepLearning(DeepLearningModel.DeepLearningParameters deepLearningParameters) {
        super(deepLearningParameters);
        // NOTE(review): calls an overridable method from the constructor; subclasses overriding
        // init() will see a partially constructed instance.
        init(false);
    }

    /**
     * Creates a builder that writes its model under the given destination key, then runs the cheap
     * validation pass via {@code init(false)}.
     *
     * @param deepLearningParameters model parameters
     * @param key                    destination key for the resulting model
     */
    public DeepLearning(DeepLearningModel.DeepLearningParameters deepLearningParameters, Key<DeepLearningModel> key) {
        super(deepLearningParameters, key);
        init(false);
    }

    /**
     * Creates a builder with default parameters, forwarding {@code z} to the ModelBuilder
     * constructor. Unlike the other constructors, this one does not call {@code init}.
     * NOTE(review): presumably the algorithm-registration constructor used at startup; confirm
     * against ModelBuilder.
     *
     * @param z flag forwarded unchanged to the superclass constructor
     */
    public DeepLearning(boolean z) {
        super(new DeepLearningModel.DeepLearningParameters(), z);
    }

    /**
     * Lists the model categories this builder can produce. A new array is allocated on every call,
     * so callers may modify the result freely.
     */
    @Override // hex.ModelBuilder
    public ModelCategory[] can_build() {
        final ModelCategory[] categories = {
            ModelCategory.Regression,
            ModelCategory.Binomial,
            ModelCategory.Multinomial,
            ModelCategory.AutoEncoder
        };
        return categories;
    }

    /** Reports POJO support; always {@code true} for DeepLearning. */
    @Override // hex.ModelBuilder
    public boolean havePojo() {
        return true;
    }

    /** Reports MOJO support; always {@code true} for DeepLearning. */
    @Override // hex.ModelBuilder
    public boolean haveMojo() {
        return true;
    }

    /** Returns the shared {@link LinearAlgebraUtils#toEigen} converter. */
    @Override // hex.ModelBuilder
    public ToEigenVec getToEigenVec() {
        return LinearAlgebraUtils.toEigen;
    }

    /** A DeepLearning model is supervised unless it is configured as an autoencoder. */
    @Override // hex.ModelBuilder
    public boolean isSupervised() {
        final boolean autoencoder = ((DeepLearningModel.DeepLearningParameters) this._parms)._autoencoder;
        return !autoencoder;
    }

    /* JADX INFO: Access modifiers changed from: protected */
    /** Creates the driver that runs the training job for this builder. */
    @Override // hex.ModelBuilder
    public DeepLearningDriver trainModelImpl() {
        final DeepLearningDriver driver = new DeepLearningDriver();
        return driver;
    }

    /**
     * Validates the builder state and parameters and records errors on the builder.
     *
     * <p>Runs the ModelBuilder checks, the DeepLearning parameter validation and the eigen-projection
     * setup, then rejects unsupported distributions. The memory-footprint check runs only when
     * {@code z} is {@code true} and no error has been recorded so far.
     *
     * @param z {@code true} from the training driver's {@code computeImpl}, {@code false} from the
     *          constructors (the cheap pass)
     */
    @Override // hex.ModelBuilder
    public void init(boolean z) {
        super.init(z);
        ((DeepLearningModel.DeepLearningParameters) this._parms).validate(this, z);
        this._orig_projection_array = LinearAlgebraUtils.toEigenProjectionArray(this._origTrain, this._train, z);

        final DistributionFamily[] supportedFamilies = {
            DistributionFamily.AUTO,
            DistributionFamily.bernoulli,
            DistributionFamily.multinomial,
            DistributionFamily.gaussian,
            DistributionFamily.poisson,
            DistributionFamily.gamma,
            DistributionFamily.laplace,
            DistributionFamily.quantile,
            DistributionFamily.huber,
            DistributionFamily.tweedie
        };
        final DistributionFamily requested = ((DeepLearningModel.DeepLearningParameters) this._parms)._distribution;
        if (!ArrayUtils.contains(supportedFamilies, requested)) {
            error("_distribution", requested.name() + " is not supported for DeepLearning in current H2O.");
        }

        if (!z || error_count() != 0) {
            return;
        }
        checkMemoryFootPrint();
    }

    /* JADX INFO: Access modifiers changed from: package-private */
    /**
     * Builds the {@link DataInfo} describing how the training (and validation) frame is expanded and
     * transformed for DeepLearning, then runs a {@link GLMTask.YMUTask} pass over the adapted frame.
     *
     * <p>When a weights column is present without an offset column, the predictor (and, for a single
     * numeric response, the response) means/sigmas are replaced with the weighted values computed by
     * that pass.
     *
     * @param train    training frame; its last vec is treated as the response
     * @param valid    validation frame (may be {@code null} — NOTE(review): confirm DataInfo accepts it)
     * @param parms    DeepLearning parameters
     * @param nclasses response class count passed to YMUTask; {@code 1} is treated as a single numeric
     *                 response (TODO confirm against callers)
     * @return the configured DataInfo
     * @throws H2OIllegalArgumentException if missing-value handling is {@code Skip} and no rows remain
     */
    public static DataInfo makeDataInfo(Frame train, Frame valid, DeepLearningModel.DeepLearningParameters parms, int nclasses) {
        final boolean autoencoder = parms._autoencoder;
        final boolean skipMissing = parms._missing_values_handling == DeepLearningModel.DeepLearningParameters.MissingValuesHandling.Skip;

        // Predictor transform: none unless standardization is requested; then autoencoders normalize,
        // sparse inputs are descaled, and everything else is standardized.
        final DataInfo.TransformType predictorTransform;
        if (!parms._standardize) {
            predictorTransform = DataInfo.TransformType.NONE;
        } else if (autoencoder) {
            predictorTransform = DataInfo.TransformType.NORMALIZE;
        } else if (parms._sparse) {
            predictorTransform = DataInfo.TransformType.DESCALE;
        } else {
            predictorTransform = DataInfo.TransformType.STANDARDIZE;
        }

        // Response transform: standardize a non-categorical response only when the distribution's link
        // function maps an arbitrary probe value to itself (i.e. it behaves as the identity there).
        // The distribution is only constructed when that check is actually needed.
        DataInfo.TransformType responseTransform = DataInfo.TransformType.NONE;
        if (parms._standardize && !train.lastVec().isCategorical()) {
            final double probe = 0.782347234d;
            if (DistributionFactory.getDistribution(parms).link(probe) == probe) {
                responseTransform = DataInfo.TransformType.STANDARDIZE;
            }
        }

        DataInfo dataInfo = new DataInfo(train, valid,
                autoencoder ? 0 : 1,                          // number of responses
                autoencoder || parms._use_all_factor_levels,
                predictorTransform,
                responseTransform,
                skipMissing,
                false,
                true,
                parms._weights_column != null,
                parms._offset_column != null,
                parms._fold_column != null);

        GLMTask.YMUTask ymu = new GLMTask.YMUTask(dataInfo, nclasses, !autoencoder && nclasses == 1, skipMissing, !autoencoder, true)
                .doAll(dataInfo._adaptedFrame);
        if (ymu.wsum() == 0.0d && skipMissing) {
            throw new H2OIllegalArgumentException("No rows left in the dataset after filtering out rows with missing values. Ignore columns with many NAs or set missing_values_handling to 'MeanImputation'.");
        }

        final boolean hasWeights = parms._weights_column != null;
        final boolean hasOffset = parms._offset_column != null;
        if (hasWeights && hasOffset) {
            Log.warn("Combination of offset and weights can lead to slight differences because Rollupstats aren't weighted - need to re-calculate weighted mean/sigma of the response including offset terms.");
        }
        if (hasWeights && !hasOffset) {
            dataInfo.updateWeightedSigmaAndMean(ymu.predictorSDs(), ymu.predictorMeans());
            if (nclasses == 1) {
                dataInfo.updateWeightedSigmaAndMeanForResponse(ymu.responseSDs(), ymu.responseMeans());
            }
        }
        return dataInfo;
    }

    /**
     * Estimates the number of trainable parameters (weights plus biases) of the configured network and
     * records an error on {@code _hidden} when the estimate exceeds 1e8. Skipped when training continues
     * from a checkpoint.
     *
     * <p>Fixes relative to the decompiled original:
     * <ul>
     *   <li>Hidden-to-hidden weight counts are multiplied in {@code long}. The int product could overflow
     *       and let an oversized model pass the check.</li>
     *   <li>The response's own encoding is removed from the input width using
     *       {@code abs(cardinality())}. Subtracting the raw value (-1 for a numeric response) added a
     *       column instead of removing one.</li>
     * </ul>
     */
    @Override // hex.ModelBuilder
    protected void checkMemoryFootPrint_impl() {
        final DeepLearningModel.DeepLearningParameters parms = (DeepLearningModel.DeepLearningParameters) this._parms;
        if (parms._checkpoint != null) {
            return;
        }
        final boolean autoencoder = parms._autoencoder;

        // Width of the response encoding (0 for an autoencoder, which has no response column).
        // cardinality() of a numeric response is presumably -1; abs() turns it into one output unit.
        final long responseWidth = autoencoder ? 0 : Math.abs(this._train.lastVec().cardinality());

        // Expanded input width over all factor levels, minus the response's own contribution.
        long inputs = LinearAlgebraUtils.numColsExp(this._train, true) - responseWidth;

        // One extra input per categorical predictor (presumably a missing-value level).
        final String[][] domains = this._train.domains();
        final int numPredictors = this._train.numCols() - (autoencoder ? 0 : 1);
        for (int col = 0; col < numPredictors; col++) {
            if (domains[col] != null) {
                inputs++;
            }
        }

        final long outputs = autoencoder ? inputs : responseWidth;
        final int[] hidden = parms._hidden;
        long modelSize;
        if (hidden.length == 0) {
            modelSize = inputs * outputs;
        } else {
            // Weights: input->first hidden, hidden->hidden, last hidden->output.
            modelSize = inputs * hidden[0];
            for (int layer = 1; layer < hidden.length; layer++) {
                modelSize += (long) hidden[layer - 1] * hidden[layer];
            }
            modelSize += hidden[hidden.length - 1] * outputs;
            // Biases: one per hidden unit and one per output unit.
            for (int units : hidden) {
                modelSize += units;
            }
            modelSize += outputs;
        }
        if (modelSize > 1.0E8d) {
            error("_hidden", "Model is too large: " + modelSize + " parameters. Try reducing the number of neurons in the hidden layers (or reduce the number of categorical factors).");
        }
    }

    /**
     * Adjusts the cross-validation main model's parameters using the finished fold models.
     *
     * <p>Always turns off {@code _overwrite_with_best_model}. When neither convergence-based early
     * stopping nor a runtime limit is configured, nothing else changes. Otherwise early stopping is
     * disabled, the main-model runtime budget is set, and {@code _epochs} becomes the mean
     * {@code epoch_counter} of the fold models' last scoring events.
     *
     * @param modelBuilderArr builders of the cross-validation fold models
     */
    @Override // hex.ModelBuilder
    public void cv_computeAndSetOptimalParameters(ModelBuilder<DeepLearningModel, DeepLearningModel.DeepLearningParameters, DeepLearningModel.DeepLearningModelOutput>[] modelBuilderArr) {
        final DeepLearningModel.DeepLearningParameters before = (DeepLearningModel.DeepLearningParameters) this._parms;
        before._overwrite_with_best_model = false;
        final boolean nothingToAdjust = before._stopping_rounds == 0 && before._max_runtime_secs == 0.0d;
        if (nothingToAdjust) {
            return;
        }
        before._stopping_rounds = 0;
        setMaxRuntimeSecsForMainModel();

        double epochSum = 0.0d;
        for (int k = 0; k < modelBuilderArr.length; k++) {
            final DeepLearningModel foldModel = (DeepLearningModel) DKV.getGet(modelBuilderArr[k].dest());
            epochSum += foldModel.last_scored().epoch_counter;
        }

        final DeepLearningModel.DeepLearningParameters after = (DeepLearningModel.DeepLearningParameters) this._parms;
        after._epochs = epochSum / modelBuilderArr.length;
        if (!after._quiet_mode) {
            warn("_epochs", "Setting optimal _epochs to " + after._epochs + " for cross-validation main model based on early stopping of cross-validation models.");
            warn("_stopping_rounds", "Disabling convergence-based early stopping for cross-validation main model.");
            if (after._main_model_time_budget_factor == 0.0d) {
                warn("_max_runtime_secs", "Disabling maximum allowed runtime for cross-validation main model.");
            }
        }
    }

    /* JADX INFO: Access modifiers changed from: protected */
    /**
     * Optionally rebalances {@code frame} into {@link #desiredChunks(Frame, boolean)} chunks.
     *
     * <p>Returns the input unchanged when it is {@code null}, when neither {@code _force_load_balance}
     * nor {@code _reproducible} is set, or when no rebalancing is needed. In reproducible mode the target
     * is one chunk; otherwise frames that already have at least the target number of chunks are kept.
     * A rebalanced frame is stored under {@code str + ".chks" + chunks} and tracked in the current Scope.
     *
     * @param frame frame to rebalance (may be {@code null})
     * @param z     forwarded to {@code desiredChunks}
     * @param str   base name; its last five characters (or all of it, if shorter) label progress/log
     *              messages. Shorter names previously caused a StringIndexOutOfBoundsException.
     * @return the original or the rebalanced frame
     */
    @Override // hex.ModelBuilder
    public Frame rebalance(Frame frame, boolean z, String str) {
        if (frame == null) {
            return null;
        }
        final DeepLearningModel.DeepLearningParameters parms = (DeepLearningModel.DeepLearningParameters) this._parms;
        if (!parms._force_load_balance && !parms._reproducible) {
            return frame;
        }
        final String label = str.length() <= 5 ? str : str.substring(str.length() - 5);
        final int originalChunks = frame.anyVec().nChunks();
        this._job.update(0L, "Load balancing " + label + " data...");
        final int chunks = desiredChunks(frame, z);
        if (parms._reproducible) {
            if (!$assertionsDisabled && chunks != 1) {
                throw new AssertionError();
            }
            if (!parms._quiet_mode) {
                Log.warn("Reproducibility enforced - using only 1 thread - can be slow.");
            }
            if (originalChunks == 1) {
                return frame;
            }
        } else if (originalChunks >= chunks) {
            if (!parms._quiet_mode) {
                Log.info("Dataset already contains " + originalChunks + " chunks. No need to rebalance.");
            }
            return frame;
        }
        if (!parms._quiet_mode) {
            Log.info("Rebalancing " + label + " dataset into " + chunks + " chunks.");
        }
        final Key<Frame> newKey = Key.make(str + ".chks" + chunks);
        final RebalanceDataSet rebalancer = new RebalanceDataSet(frame, newKey, chunks);
        H2O.submitTask(rebalancer).join();
        final Frame rebalanced = (Frame) DKV.get(newKey).get();
        Scope.track(rebalanced);
        return rebalanced;
    }

    /**
     * Target chunk count for rebalancing: exactly one in reproducible mode. Otherwise it is
     * {@code 4 * NUMCPUS} per node (a single node when {@code z} is {@code true}, the whole cloud
     * otherwise), capped at the frame's row count.
     */
    @Override // hex.ModelBuilder
    protected int desiredChunks(Frame frame, boolean z) {
        if (!((DeepLearningModel.DeepLearningParameters) this._parms)._reproducible) {
            final int nodes = z ? 1 : H2O.CLOUD.size();
            final long target = 4 * H2O.NUMCPUS * nodes;
            return (int) Math.min(target, frame.numRows());
        }
        return 1;
    }

    /**
     * Resolves the number of training samples per iteration from
     * {@code _train_samples_per_iteration}:
     * <ul>
     *   <li>{@code 0}, or {@code -1} without replicated training data: one epoch ({@code numRows}).</li>
     *   <li>{@code -1} with replicated training data: {@code numRows} times the node count (1 in
     *       single-node mode).</li>
     *   <li>{@code -2}: auto-tuned from measured compute and network speed.</li>
     *   <li>any positive value: that value, capped at {@code _epochs * numRows} and at least 1.</li>
     * </ul>
     *
     * @param mp      DeepLearning parameters
     * @param numRows number of training rows (NOTE(review): local vs. global count — confirm at call site)
     * @param model   model whose size drives auto-tuning; its {@code time_for_communication_us} is
     *                written in the auto-tuned case
     * @return the resolved samples per iteration (asserted to be at least 1)
     */
    static long computeTrainSamplesPerIteration(DeepLearningModel.DeepLearningParameters mp, long numRows, DeepLearningModel model) {
        final long requested = mp._train_samples_per_iteration;
        if (!$assertionsDisabled && requested != 0 && requested != -1 && requested != -2 && requested < 1) {
            throw new AssertionError();
        }
        final long tspi;
        if (requested == 0 || (!mp._replicate_training_data && requested == -1)) {
            tspi = numRows;
            if (!mp._quiet_mode) {
                Log.info("Setting train_samples_per_iteration (" + mp._train_samples_per_iteration + ") to one epoch: #rows (" + tspi + ").");
            }
        } else if (requested == -1) {
            tspi = (mp._single_node_mode ? 1 : H2O.CLOUD.size()) * numRows;
            if (!mp._quiet_mode) {
                Log.info("Setting train_samples_per_iteration (" + mp._train_samples_per_iteration + ") to #nodes x #rows (" + tspi + ").");
            }
        } else if (requested == -2) {
            tspi = autoTuneTrainSamplesPerIteration(mp, numRows, model);
        } else {
            tspi = Math.max(1L, Math.min(requested, (long) (mp._epochs * numRows)));
        }
        if (!$assertionsDisabled && tspi < 1) {
            throw new AssertionError();
        }
        return tspi;
    }

    /**
     * Auto-tunes samples per iteration (the {@code -2} setting). The target is for communication to
     * take about {@code _target_ratio_comm_to_comp} of total iteration time, using heartbeat GFlops
     * (Linpack fallback) and a network test sized by the model.
     * Side effect: sets {@code model.time_for_communication_us}.
     */
    private static long autoTuneTrainSamplesPerIteration(DeepLearningModel.DeepLearningParameters mp, long numRows, DeepLearningModel model) {
        final int cloudSize = H2O.CLOUD.size();
        final boolean oneNode = mp._single_node_mode || cloudSize == 1;

        // Aggregate compute power from node heartbeats. A NaN sum falls back to a local Linpack run.
        double gflops = 0.0d;
        for (int i = 0; i < H2O.CLOUD._memary.length; i++) {
            gflops += H2O.CLOUD._memary[i]._heartbeat._gflops;
        }
        if (mp._single_node_mode) {
            gflops /= cloudSize;
        }
        if (Double.isNaN(gflops)) {
            gflops = Linpack.run(H2O.SELF._heartbeat._cpus_allowed) * (mp._single_node_mode ? 1 : cloudSize);
        }
        if (!$assertionsDisabled && Double.isNaN(gflops)) {
            throw new AssertionError();
        }

        // Network test with a 1-unit message and one of (model size * 4), clamped to the int range.
        final long modelSize = model.model_info().size();
        final long payload = modelSize * 4;
        final int[] msgSizes = {1, ((long) ((int) payload)) == payload ? (int) payload : Integer.MAX_VALUE};
        final double[] networkTimes = new double[msgSizes.length];
        new NetworkTest.NetworkTester(msgSizes, (double[][]) null, networkTimes, ((double) modelSize) > 1000000.0d ? 1 : 5, false, true).compute2();

        // Communication rounds: 2 * floor(log2(nodes)), or 1 when a single node participates.
        final int numHops = oneNode ? 1 : 2 * ((int) Math.floor(Math.log(cloudSize) / Math.log(2.0d)));

        // Heuristic relative cost per weight, depending on the activation.
        double costPerWeight = 50.0d;
        if (mp._activation == DeepLearningModel.DeepLearningParameters.Activation.Maxout || mp._activation == DeepLearningModel.DeepLearningParameters.Activation.MaxoutWithDropout) {
            costPerWeight = 50.0d * 8.0d;
        } else if (mp._activation == DeepLearningModel.DeepLearningParameters.Activation.Tanh || mp._activation == DeepLearningModel.DeepLearningParameters.Activation.TanhWithDropout) {
            costPerWeight = 50.0d * 5.0d;
        }

        final double commToCompRatio = oneNode ? 0.001d : mp._target_ratio_comm_to_comp;
        model.time_for_communication_us = (cloudSize == 1 ? 10000.0d : 100000.0d) + (numHops * networkTimes[1]);
        // Double arithmetic throughout: the former int product 10000 * units[0] could overflow.
        final double timePerRowUs = ((((costPerWeight * modelSize) + (10000.0d * model.model_info().units[0])) / (gflops * 1.0E9d)) / H2O.SELF._heartbeat._cpus_allowed) * 1000000.0d;
        if (!$assertionsDisabled && Double.isNaN(timePerRowUs)) {
            throw new AssertionError();
        }

        // comm / (comm + compute) == ratio  =>  compute == comm / ratio - comm.
        // Capped at 10x what the -1 setting would use.
        long tspi = Math.min((long) (((model.time_for_communication_us / commToCompRatio) - model.time_for_communication_us) / timePerRowUs),
                (mp._single_node_mode ? 1 : cloudSize) * numRows * 10);
        // Snap down to a whole number of epochs when the remainder is under 20% of an epoch. This needs
        // floating-point division; the former long division always produced 0.
        if (tspi > numRows && Math.abs(tspi % numRows) / (double) numRows < 0.2d) {
            tspi -= tspi % numRows;
        }
        tspi = Math.min(tspi, (long) ((mp._epochs * numRows) / 10.0d));
        if (cloudSize == 1 || mp._single_node_mode) {
            // At most 10 seconds' worth of rows at the estimated speed. The (long) casts saturate, so a
            // tiny timePerRowUs no longer wraps the cap to a negative number.
            tspi = Math.min(tspi, (long) (10.0d * (long) (1000000.0d / timePerRowUs)));
        }
        final long result = Math.min(100000L * cloudSize, Math.max(1L, tspi));

        if (!mp._quiet_mode) {
            Log.info("Auto-tuning parameter 'train_samples_per_iteration':");
            Log.info("Estimated compute power : " + (Math.round(gflops * 100.0d) / 100.0d) + " GFlops");
            Log.info("Estimated time for comm : " + PrettyPrint.usecs((long) model.time_for_communication_us));
            Log.info("Estimated time per row  : " + (((long) timePerRowUs) > 0 ? PrettyPrint.usecs((long) timePerRowUs) : timePerRowUs + " usecs"));
            Log.info("Estimated training speed: " + ((int) (1000000.0d / timePerRowUs)) + " rows/sec");
            Log.info("Setting train_samples_per_iteration (" + mp._train_samples_per_iteration + ") to auto-tuned value: " + result);
        }
        return result;
    }

    // Decompiled form of javac's assertion support: the flag is true when assertions are disabled for
    // DeepLearning, and guards the explicit AssertionError checks in rebalance() and
    // computeTrainSamplesPerIteration().
    static {
        $assertionsDisabled = !DeepLearning.class.desiredAssertionStatus();
    }
}
