package hex.deeplearning;

import au.com.bytecode.opencsv.CSVWriter;
import hex.DataInfo;
import hex.Model;
import hex.deeplearning.DeepLearningModel;
import hex.deeplearning.Storage;
import hex.genmodel.utils.DistributionFamily;
import java.util.Arrays;
import java.util.Random;
import org.apache.commons.math3.optimization.direct.CMAESOptimizer;
import water.AutoBuffer;
import water.DKV;
import water.H2O;
import water.H2ONode;
import water.Iced;
import water.IcedUtils;
import water.Key;
import water.fvec.Frame;
import water.util.ArrayUtils;
import water.util.FrameUtils;
import water.util.Log;
import water.util.MathUtils;
import water.util.PrettyPrint;
import water.util.RandomBase;
import water.util.RandomUtils;
import water.util.TwoDimTable;

/* loaded from: input_file:hex/deeplearning/DeepLearningModelInfo.class */
public final class DeepLearningModelInfo extends Iced<DeepLearningModelInfo> {
    // Most recently built summary table; assigned by createSummaryTable().
    public TwoDimTable summaryTable;
    // Input data description; assigned by the public constructor.
    public DataInfo data_info;
    // One weight matrix per layer transition (hidden layers + 1), allocated by the public constructor.
    private Storage.DenseRowMatrix[] dense_row_weights;
    // One bias vector per layer transition, same indexing as dense_row_weights.
    private Storage.DenseVector[] biases;
    // Only allocated when the model is an autoencoder with _sparsity_beta > 0; otherwise null.
    private Storage.DenseVector[] avg_activations;
    // Momentum state; only allocated by allocateHelperArrays() when has_momenta() is true.
    private Storage.DenseRowMatrix[] dense_row_weights_momenta;
    private Storage.DenseVector[] biases_momenta;
    // Adaptive-rate state; only allocated by allocateHelperArrays() when adaDelta() is true and has_momenta() is false.
    private Storage.DenseRowMatrix[] dense_row_ada_dx_g;
    private Storage.DenseVector[] biases_ada_dx_g;
    // One flag per categorical column (null when there are none); updated by checkMissingCats().
    private boolean[] _saw_missing_cats;
    // Private clone of the parameters handed to the constructor / set_params().
    public DeepLearningModel.DeepLearningParameters parameters;
    Key<Model> _model_id;
    // Per-layer statistics, each of length units.length - 1; reset and refilled by computeStats().
    private double[] mean_rate;
    private double[] rms_rate;
    private double[] mean_bias;
    private double[] rms_bias;
    private double[] mean_weight;
    public double[] rms_weight;
    // Mean of avg_activations per hidden layer; allocated together with avg_activations, otherwise null.
    public double[] mean_a;
    // Set to true (and never reset in the visible code) by setUnstable().
    private volatile boolean unstable;
    // Sample counters; read and written only through the synchronized *_processed_* accessors.
    private long processed_global;
    private long processed_local;
    // Layer sizes: units[0] is the input layer, then the hidden layers, and the last entry is the output layer.
    int[] units;
    // True when the constructor was given more than one class.
    final boolean _classification;
    final Frame _train;
    final Frame _valid;
    public static GradientCheck gradientCheck;
    public static GradientCheck gradientCheckBias;
    // Decompiler-exposed synthetic flag: the explicit AssertionError checks in this class are skipped when it is true.
    static final /* synthetic */ boolean $assertionsDisabled;

    /* loaded from: input_file:hex/deeplearning/DeepLearningModelInfo$GradientCheck.class */
    public static class GradientCheck {
        int layer;
        int row;
        int col;
        double gradient = CMAESOptimizer.DEFAULT_STOPFITNESS;

        GradientCheck(int i, int i2, int i3) {
            this.layer = i;
            this.row = i2;
            this.col = i3;
        }

        /* JADX INFO: Access modifiers changed from: package-private */
        public void apply(int i, int i2, int i3, double d) {
            if (i2 == this.row && i3 == this.col && i == this.layer) {
                this.gradient += d;
            }
        }
    }

    /** @return the {@link DataInfo} held by this model info */
    public DataInfo data_info() {
        return data_info;
    }

    /**
     * Counts the model's weights and biases.
     *
     * @return the sum of {@code size()} over all non-null weight matrices and
     *         all bias vectors
     */
    public long size() {
        long total = 0;
        for (Storage.DenseRowMatrix weights : this.dense_row_weights) {
            if (weights != null) {
                total += weights.size();
            }
        }
        // The decompiled code indexed an undefined register ("r0") here;
        // the array being walked is the bias array.
        for (Storage.DenseVector bias : this.biases) {
            total += bias.size();
        }
        return total;
    }

    /* JADX INFO: Access modifiers changed from: package-private */
    /**
     * Records, per categorical column, whether the given row holds that
     * column's last level (index {@code _catOffsets[c + 1] - 1}). A flag that
     * is already set is left untouched. Does nothing when either the argument
     * or the flag array is null.
     *
     * @param iArr one value per categorical column, compared against the
     *             column's last offset
     */
    public void checkMissingCats(int[] iArr) {
        if (iArr != null && _saw_missing_cats != null) {
            for (int c = 0; c < iArr.length; c++) {
                if (!$assertionsDisabled && !data_info._catMissing[c]) {
                    throw new AssertionError();
                }
                if (_saw_missing_cats[c]) {
                    continue;
                }
                _saw_missing_cats[c] = iArr[c] == data_info._catOffsets[c + 1] - 1;
            }
        }
    }

    /* JADX INFO: Access modifiers changed from: package-private */
    /**
     * @return true unless both {@code _momentum_start} and
     *         {@code _momentum_stable} equal the zero constant
     */
    public boolean has_momenta() {
        boolean startIsZero = get_params()._momentum_start == CMAESOptimizer.DEFAULT_STOPFITNESS;
        if (!startIsZero) {
            return true;
        }
        return !(get_params()._momentum_stable == CMAESOptimizer.DEFAULT_STOPFITNESS);
    }

    /* JADX INFO: Access modifiers changed from: package-private */
    /** @return the {@code _adaptive_rate} flag of the current parameters */
    public boolean adaDelta() {
        DeepLearningModel.DeepLearningParameters params = get_params();
        return params._adaptive_rate;
    }

    /**
     * @param i index into the weight-matrix array
     * @return the weight matrix stored at index {@code i}
     */
    public final Storage.DenseRowMatrix get_weights(int i) {
        return dense_row_weights[i];
    }

    /**
     * @param i index into the bias-vector array
     * @return the bias vector stored at index {@code i}
     */
    public final Storage.DenseVector get_biases(int i) {
        return biases[i];
    }

    /**
     * @param i index into the weight-momenta array
     * @return the weight-momenta matrix stored at index {@code i}
     */
    public final Storage.DenseRowMatrix get_weights_momenta(int i) {
        return dense_row_weights_momenta[i];
    }

    /**
     * @param i index into the bias-momenta array
     * @return the bias-momenta vector stored at index {@code i}
     */
    public final Storage.DenseVector get_biases_momenta(int i) {
        return biases_momenta[i];
    }

    /**
     * @param i index into the adaptive-rate weight state array
     * @return the adaptive-rate state matrix stored at index {@code i}
     */
    public final Storage.DenseRowMatrix get_ada_dx_g(int i) {
        return dense_row_ada_dx_g[i];
    }

    /**
     * @param i index into the adaptive-rate bias state array
     * @return the adaptive-rate bias state vector stored at index {@code i}
     */
    public final Storage.DenseVector get_biases_ada_dx_g(int i) {
        return biases_ada_dx_g[i];
    }

    /**
     * @param i index into the average-activation array
     * @return the average-activation vector stored at index {@code i}
     */
    public final Storage.DenseVector get_avg_activations(int i) {
        return avg_activations[i];
    }

    /** @return the parameter object held by this model info */
    public final DeepLearningModel.DeepLearningParameters get_params() {
        return parameters;
    }

    /**
     * Replaces the held parameters with a clone of the given ones and stores
     * the model key.
     *
     * @param deepLearningParameters parameters to clone
     * @param key                    model key to remember
     */
    public final void set_params(DeepLearningModel.DeepLearningParameters deepLearningParameters, Key<Model> key) {
        DeepLearningModel.DeepLearningParameters copy =
                (DeepLearningModel.DeepLearningParameters) deepLearningParameters.m1445clone();
        parameters = copy;
        _model_id = key;
    }

    /** @return the current value of the {@code unstable} flag */
    public boolean isUnstable() {
        return unstable;
    }

    /**
     * Marks the model as unstable. {@link #computeStats()} is run first, but
     * only when the flag was not already set.
     */
    public void setUnstable() {
        boolean alreadyFlagged = unstable;
        if (!alreadyFlagged) {
            computeStats();
        }
        unstable = true;
    }

    /** @return the global processed counter (read while holding this object's monitor) */
    public synchronized long get_processed_global() {
        return processed_global;
    }

    /** @param j new value of the global processed counter */
    public synchronized void set_processed_global(long j) {
        processed_global = j;
    }

    /** @param j amount added to the global processed counter */
    public synchronized void add_processed_global(long j) {
        processed_global = processed_global + j;
    }

    /** @return the local processed counter (read while holding this object's monitor) */
    public synchronized long get_processed_local() {
        return processed_local;
    }

    /** @param j new value of the local processed counter */
    public synchronized void set_processed_local(long j) {
        processed_local = j;
    }

    /** @param j amount added to the local processed counter */
    public synchronized void add_processed_local(long j) {
        processed_local = processed_local + j;
    }

    /** @return the sum of the global and local processed counters */
    public synchronized long get_processed_total() {
        long total = processed_global;
        total += processed_local;
        return total;
    }

    /**
     * No-argument constructor: leaves the frames null and the classification
     * and unstable flags false.
     */
    private DeepLearningModelInfo() {
        this._train = null;
        this._valid = null;
        this._classification = false;
        this.unstable = false;
    }

    /**
     * Builds the model state: clones and sanity-adjusts the parameters,
     * derives the layer sizes, and allocates the weight matrices, bias
     * vectors, optional average-activation vectors, helper arrays and
     * per-layer statistics arrays. No values are initialised here beyond
     * allocation.
     *
     * @param deepLearningParameters parameters; a clone is stored
     * @param key                    model key to remember
     * @param dataInfo               input data description
     * @param i                      number of classes; classification is assumed when this is greater than 1
     * @param frame                  training frame; its response column's cardinality is read for classification
     * @param frame2                 validation frame (stored only)
     * @throws IllegalArgumentException if momentum and adaptive rate are both enabled
     */
    public DeepLearningModelInfo(DeepLearningModel.DeepLearningParameters deepLearningParameters, Key key, DataInfo dataInfo, int i, Frame frame, Frame frame2) {
        this.unstable = false;
        this._classification = i > 1;
        this._train = frame;
        this._valid = frame2;
        this.data_info = dataInfo;
        this.parameters = (DeepLearningModel.DeepLearningParameters) deepLearningParameters.m1445clone();
        this._model_id = key;
        DeepLearningModel.DeepLearningParameters.Sanity.modifyParms(this.parameters, this.parameters, i);
        int fullN = dataInfo.fullN();
        // Output layer size: all inputs for an autoencoder, 1 for regression or
        // modified_huber, otherwise the response column's cardinality.
        int cardinality = get_params()._autoencoder ? fullN : (!this._classification || this.parameters._distribution == DistributionFamily.modified_huber) ? 1 : frame.vec(this.parameters._response_column).cardinality();
        if (!get_params()._autoencoder && !$assertionsDisabled && cardinality != i && this.parameters._distribution != DistributionFamily.modified_huber) {
            throw new AssertionError();
        }
        this._saw_missing_cats = dataInfo._cats > 0 ? new boolean[this.data_info._cats] : null;
        if (!$assertionsDisabled && fullN <= 0) {
            throw new AssertionError();
        }
        if (!$assertionsDisabled && cardinality <= 0) {
            throw new AssertionError();
        }
        if (has_momenta() && adaDelta()) {
            throw new IllegalArgumentException("Cannot have non-zero momentum and adaptive rate at the same time.");
        }
        int length = get_params()._hidden.length;
        // units = [input, hidden..., output]
        this.units = new int[length + 2];
        // The comparison is arranged so that _nums + _max_categorical_features cannot overflow int.
        if (get_params()._max_categorical_features <= Integer.MAX_VALUE - dataInfo._nums) {
            this.units[0] = Math.min(dataInfo._nums + get_params()._max_categorical_features, fullN);
        } else {
            this.units[0] = fullN;
        }
        System.arraycopy(get_params()._hidden, 0, this.units, 1, length);
        this.units[length + 1] = cardinality;
        // z: more than 1000 inputs (print top categorical levels);
        // z2: more than 100000 inputs (additionally log warnings and suggestions).
        boolean z = ((long) this.units[0]) > 1000;
        boolean z2 = ((long) this.units[0]) > 100000;
        if (z) {
            // Return value is discarded here.
            dataInfo._adaptedFrame.domains();
            if (z2) {
                Log.warn("===================================================================================================================================");
                Object[] objArr = new Object[1];
                objArr[0] = fullN + " input features" + (dataInfo._cats > 0 ? " (after categorical one-hot encoding)" : "") + ". Can be slow and require a lot of memory.";
                Log.warn(objArr);
            }
            FrameUtils.printTopCategoricalLevels(dataInfo._adaptedFrame, z2, 10);
            if (z2) {
                Log.warn("Suggestions:");
                Log.warn(" *) Limit the size of the first hidden layer");
                if (dataInfo._cats > 0) {
                    Log.warn(" *) Limit the total number of one-hot encoded features by setting 'categorical_encoding=\"enum_limited\"'");
                    Log.warn(" *) Limit the total number of one-hot encoded features with the parameter 'max_categorical_features' (experimental)");
                    Log.warn(" *) Run h2o.interaction(...,pairwise=F) on high-cardinality categorical columns to limit the factor count, see http://docs.h2o.ai/h2o/latest-stable/h2o-docs/data-science/deep-learning.html#faq");
                }
                Log.warn("===================================================================================================================================");
            }
        }
        // Per-layer row multiplier: 2 for hidden layers with a Maxout activation, 1 otherwise;
        // the output layer always uses 1.
        int[] iArr = new int[length + 1];
        for (int i2 = 0; i2 < length; i2++) {
            iArr[i2] = (get_params()._activation == DeepLearningModel.DeepLearningParameters.Activation.Maxout || get_params()._activation == DeepLearningModel.DeepLearningParameters.Activation.MaxoutWithDropout) ? 2 : 1;
        }
        iArr[length] = 1;
        // Weight matrix i has (multiplier * units[i + 1]) rows and units[i] columns.
        this.dense_row_weights = new Storage.DenseRowMatrix[length + 1];
        this.dense_row_weights[0] = new Storage.DenseRowMatrix(iArr[0] * this.units[1], this.units[0]);
        for (int i3 = 1; i3 <= length; i3++) {
            this.dense_row_weights[i3] = new Storage.DenseRowMatrix(iArr[i3] * this.units[i3 + 1], this.units[i3]);
        }
        this.biases = new Storage.DenseVector[length + 1];
        for (int i4 = 0; i4 <= length; i4++) {
            this.biases[i4] = new Storage.DenseVector(iArr[i4] * this.units[i4 + 1]);
        }
        // Average activations (and their means) exist only for autoencoders with a positive sparsity beta.
        if (get_params()._autoencoder && get_params()._sparsity_beta > CMAESOptimizer.DEFAULT_STOPFITNESS) {
            this.avg_activations = new Storage.DenseVector[length];
            this.mean_a = new double[length];
            for (int i5 = 0; i5 < length; i5++) {
                this.avg_activations[i5] = new Storage.DenseVector(iArr[i5] * this.units[i5 + 1]);
            }
        }
        allocateHelperArrays();
        // One statistics slot per layer transition.
        this.mean_rate = new double[this.units.length - 1];
        this.rms_rate = new double[this.units.length - 1];
        this.mean_bias = new double[this.units.length - 1];
        this.rms_bias = new double[this.units.length - 1];
        this.mean_weight = new double[this.units.length - 1];
        this.rms_weight = new double[this.units.length - 1];
    }

    /**
     * Allocates the optimiser state that matches the current parameters:
     * momentum matrices/vectors when {@link #has_momenta()} is true,
     * otherwise the adaptive-rate state when {@link #adaDelta()} is true,
     * otherwise nothing.
     */
    void allocateHelperArrays() {
        final int layerCount = this.units.length - 1;
        // Row multiplier per layer: 2 under a Maxout activation, 1 otherwise; the last layer is always 1.
        int[] rowMult = new int[layerCount];
        for (int l = 0; l < layerCount; l++) {
            DeepLearningModel.DeepLearningParameters.Activation act = get_params()._activation;
            boolean maxout = act == DeepLearningModel.DeepLearningParameters.Activation.Maxout
                    || act == DeepLearningModel.DeepLearningParameters.Activation.MaxoutWithDropout;
            rowMult[l] = maxout ? 2 : 1;
        }
        rowMult[layerCount - 1] = 1;

        if (has_momenta()) {
            Storage.DenseRowMatrix[] weightMomenta = new Storage.DenseRowMatrix[this.dense_row_weights.length];
            this.dense_row_weights_momenta = weightMomenta;
            if (this.dense_row_weights[0] != null) {
                weightMomenta[0] = new Storage.DenseRowMatrix(rowMult[0] * this.units[1], this.units[0]);
            }
            for (int l = 1; l < weightMomenta.length; l++) {
                weightMomenta[l] = new Storage.DenseRowMatrix(rowMult[l] * this.units[l + 1], this.units[l]);
            }
            Storage.DenseVector[] biasMomenta = new Storage.DenseVector[this.biases.length];
            this.biases_momenta = biasMomenta;
            for (int l = 0; l < biasMomenta.length; l++) {
                biasMomenta[l] = new Storage.DenseVector(rowMult[l] * this.units[l + 1]);
            }
        } else if (adaDelta()) {
            // Adaptive-rate state holds two values per weight / bias.
            Storage.DenseRowMatrix[] weightState = new Storage.DenseRowMatrix[this.dense_row_weights.length];
            this.dense_row_ada_dx_g = weightState;
            weightState[0] = new Storage.DenseRowMatrix(rowMult[0] * 2 * this.units[1], this.units[0]);
            for (int l = 1; l < weightState.length; l++) {
                weightState[l] = new Storage.DenseRowMatrix(rowMult[l] * this.units[l + 1], 2 * this.units[l]);
            }
            Storage.DenseVector[] biasState = new Storage.DenseVector[this.biases.length];
            this.biases_ada_dx_g = biasState;
            for (int l = 0; l < biasState.length; l++) {
                biasState[l] = new Storage.DenseVector(rowMult[l] * 2 * this.units[l + 1]);
            }
        }
    }

    /* JADX INFO: Access modifiers changed from: package-private */
    /**
     * Recomputes the layer statistics and builds the "Status of Neuron Layers"
     * table with one row per neuron layer. The result is also stored in
     * {@link #summaryTable}.
     *
     * @return the freshly built summary table
     */
    public TwoDimTable createSummaryTable() {
        computeStats();
        Neurons[] neurons = DeepLearningTask.makeNeuronsForTesting(this);

        // Table description: task, distribution, loss, model size and training progress.
        StringBuilder description = new StringBuilder();
        if (!get_params()._autoencoder) {
            description.append("predicting ").append(this.data_info._adaptedFrame.lastVecName()).append(", ");
        }
        if (get_params()._autoencoder) {
            description.append("auto-encoder");
        } else if (this._classification) {
            description.append(this.units[this.units.length - 1]).append("-class classification");
        } else {
            description.append("regression");
        }
        description.append(", ").append(get_params()._distribution).append(" distribution, ");
        description.append(get_params()._loss).append(" loss, ");
        description.append(String.format("%,d", Long.valueOf(size()))).append(" weights/biases, ");
        description.append(PrettyPrint.bytes(new AutoBuffer().put(this).buf().length)).append(", ");
        description.append(String.format("%,d", Long.valueOf(get_processed_global()))).append(" training samples, mini-batch size ");
        description.append(String.format("%,d", Integer.valueOf(get_params()._mini_batch_size)));

        String[] rowHeaders = new String[neurons.length];
        String[] colHeaders = {"Layer", "Units", "Type", "Dropout", "L1", "L2", "Mean Rate", "Rate RMS", "Momentum", "Mean Weight", "Weight RMS", "Mean Bias", "Bias RMS"};
        String[] colTypes = {"int", "int", "string", "double", "double", "double", "double", "double", "double", "double", "double", "double", "double"};
        String[] colFormats = {"%d", "%d", "%s", "%2.2f %%", "%5f", "%5f", "%5f", "%5f", "%5f", "%5f", "%5f", "%5f", "%5f"};
        TwoDimTable table = new TwoDimTable("Status of Neuron Layers", description.toString(), rowHeaders, colHeaders, colTypes, colFormats, "");

        for (int row = 0; row < neurons.length; row++) {
            Neurons layer = neurons[row];
            table.set(row, 0, Integer.valueOf(row + 1));
            table.set(row, 1, Integer.valueOf(layer.units));
            table.set(row, 2, layer.getClass().getSimpleName());
            if (row == 0) {
                // The input layer only reports its dropout ratio.
                table.set(row, 3, Double.valueOf(layer.params._input_dropout_ratio * 100.0d));
                continue;
            }
            // The last layer gets no dropout entry.
            if (row < neurons.length - 1) {
                if (layer.params._hidden_dropout_ratios != null) {
                    table.set(row, 3, Double.valueOf(layer.params._hidden_dropout_ratios[row - 1] * 100.0d));
                } else {
                    table.set(row, 3, 0);
                }
            }
            // Statistics arrays are indexed by layer transition, i.e. one less than the table row.
            final int stat = row - 1;
            table.set(row, 4, Double.valueOf(layer.params._l1));
            table.set(row, 5, Double.valueOf(layer.params._l2));
            table.set(row, 6, Double.valueOf(get_params()._adaptive_rate ? this.mean_rate[stat] : layer.rate(get_processed_total())));
            table.set(row, 7, Double.valueOf(get_params()._adaptive_rate ? this.rms_rate[stat] : CMAESOptimizer.DEFAULT_STOPFITNESS));
            table.set(row, 8, Float.valueOf(get_params()._adaptive_rate ? 0.0f : layer.momentum(get_processed_total())));
            table.set(row, 9, Double.valueOf(this.mean_weight[stat]));
            table.set(row, 10, Double.valueOf(this.rms_weight[stat]));
            table.set(row, 11, Double.valueOf(this.mean_bias[stat]));
            table.set(row, 12, Double.valueOf(this.rms_bias[stat]));
        }
        this.summaryTable = table;
        return this.summaryTable;
    }

    /**
     * Renders the model state as text: optionally one line per hidden layer
     * with its mean activation, followed by the summary table. Returns an
     * empty string in quiet mode.
     *
     * @return the textual summary, or "" when {@code _quiet_mode} is set
     */
    @Override
    public String toString() {
        StringBuilder sb = new StringBuilder();
        if (!get_params()._quiet_mode) {
            // mean_a is only allocated for autoencoders with a positive sparsity
            // beta, so a positive beta alone is not enough to dereference it.
            if (get_params()._sparsity_beta > 0.0d && this.mean_a != null) {
                for (int i = 0; i < get_params()._hidden.length; i++) {
                    sb.append("Average activation in hidden layer ").append(i).append(" is  ").append(this.mean_a[i]).append(" \n");
                }
            }
            createSummaryTable();
            sb.append(this.summaryTable.toString(1));
        }
        return sb.toString();
    }

    /**
     * Extends {@link #toString()} with a dump of every weight matrix and bias
     * vector, the momenta when present, the layer sizes and the processed
     * counters.
     *
     * @return the full textual dump
     */
    public String toStringAll() {
        final int layerCount = this.units.length - 1;
        StringBuilder out = new StringBuilder(toString());

        for (int l = 0; l < layerCount; l++) {
            String weights = Arrays.toString(get_weights(l).raw());
            out.append("\nweights[").append(l).append("][]=").append(weights);
        }
        for (int l = 0; l < layerCount; l++) {
            String bias = Arrays.toString(get_biases(l).raw());
            out.append("\nbiases[").append(l).append("][]=").append(bias);
        }
        if (has_momenta()) {
            for (int l = 0; l < layerCount; l++) {
                String momenta = Arrays.toString(get_weights_momenta(l).raw());
                out.append("\nweights_momenta[").append(l).append("][]=").append(momenta);
            }
        }
        if (this.biases_momenta != null) {
            for (int l = 0; l < layerCount; l++) {
                String momenta = Arrays.toString(this.biases_momenta[l].raw());
                out.append("\nbiases_momenta[").append(l).append("][]=").append(momenta);
            }
        }
        out.append("\nunits[]=").append(Arrays.toString(this.units));
        out.append("\nprocessed global: ").append(get_processed_global());
        out.append("\nprocessed local:  ").append(get_processed_local());
        out.append("\nprocessed total:  ").append(get_processed_total());
        out.append(CSVWriter.DEFAULT_LINE_END);
        return out.toString();
    }

    /* JADX INFO: Access modifiers changed from: package-private */
    /**
     * Initialises weights and biases. Weights are randomised first; hidden
     * biases are then filled according to the activation (0.5 for the first
     * hidden layer and 1.0 for later ones under Rectifier/Maxout variants,
     * 0 under Tanh variants) and output biases are zeroed. Finally, for every
     * layer where both a user-given weight frame and bias frame are supplied,
     * their contents overwrite the initial values.
     *
     * @param keyArr  per-layer keys of user-given weight frames; may be null
     *                or contain null entries
     * @param keyArr2 per-layer keys of user-given bias frames; may be null
     *                or contain null entries
     * @throws IllegalArgumentException if a referenced frame cannot be found
     *                                  or its dimensions do not match the layer
     */
    public void initializeMembers(Key<Frame>[] keyArr, Key<Frame>[] keyArr2) {
        randomizeWeights();
        final DeepLearningModel.DeepLearningParameters.Activation activation = get_params()._activation;
        final int hiddenLayers = get_params()._hidden.length;
        for (int i = 0; i < hiddenLayers; i++) {
            if (activation == DeepLearningModel.DeepLearningParameters.Activation.Rectifier
                    || activation == DeepLearningModel.DeepLearningParameters.Activation.RectifierWithDropout
                    || activation == DeepLearningModel.DeepLearningParameters.Activation.Maxout
                    || activation == DeepLearningModel.DeepLearningParameters.Activation.MaxoutWithDropout) {
                Arrays.fill(this.biases[i].raw(), i == 0 ? 0.5d : 1.0d);
            } else if (activation == DeepLearningModel.DeepLearningParameters.Activation.Tanh
                    || activation == DeepLearningModel.DeepLearningParameters.Activation.TanhWithDropout) {
                Arrays.fill(this.biases[i].raw(), 0.0d);
            }
        }
        Arrays.fill(this.biases[this.biases.length - 1].raw(), 0.0d);
        if (keyArr == null && keyArr2 == null) {
            Log.info("Created random initial model state.");
            return;
        }
        Log.info("Initializing initial model state from user-given weights/biases.");
        for (int layer = 0; layer < hiddenLayers + 1; layer++) {
            // Exactly one of the two arrays may be null; treat that the same
            // as a missing entry instead of dereferencing the null array.
            Key<Frame> weightsKey = keyArr == null ? null : keyArr[layer];
            Key<Frame> biasesKey = keyArr2 == null ? null : keyArr2[layer];
            if (weightsKey == null) {
                Log.info("No user-given weight matrix given for weights #" + (layer + 1) + ". Initializing those weights randomly.");
                continue;
            }
            if (biasesKey == null) {
                Log.info("No user-given bias vector given for biases #" + (layer + 1) + ". Initializing those biases randomly.");
                continue;
            }
            Frame weightsFrame = weightsKey.get();
            if (weightsFrame == null) {
                throw new IllegalArgumentException("User-given weight matrix for weights #" + (layer + 1) + " '" + weightsKey.toString() + "' not found. Initializing those weights randomly.");
            }
            if (weightsFrame.numRows() != get_weights(layer).rows() || weightsFrame.numCols() != get_weights(layer).cols()) {
                throw new IllegalArgumentException("Dimensionality mismatch: initial_weights matrix #" + layer + " should have " + get_weights(layer).rows() + " rows and " + get_weights(layer).cols() + " columns, but has " + weightsFrame.numRows() + " rows and " + weightsFrame.numCols() + " columns.");
            }
            Frame biasesFrame = biasesKey.get();
            if (biasesFrame == null) {
                throw new IllegalArgumentException("User-given bias vector for biases #" + (layer + 1) + " '" + biasesKey.toString() + "' not found. Initializing those biases randomly.");
            }
            if (biasesFrame.numRows() != get_biases(layer).size() || biasesFrame.numCols() != 1) {
                throw new IllegalArgumentException("Dimensionality mismatch: initial_biases vector #" + layer + " should have " + get_biases(layer).size() + " rows and 1 column, but has " + biasesFrame.numRows() + " rows and " + biasesFrame.numCols() + " column(s).");
            }
            for (int c = 0; c < weightsFrame.numCols(); c++) {
                for (int r = 0; r < weightsFrame.numRows(); r++) {
                    get_weights(layer).set(r, c, (float) weightsFrame.vec(c).at(r));
                }
            }
            // Bound the copy by the bias frame itself rather than by the weight frame.
            for (int r = 0; r < biasesFrame.numRows(); r++) {
                get_biases(layer).set(r, (float) biasesFrame.vec(0).at(r));
            }
        }
    }

    /* JADX INFO: Access modifiers changed from: package-private */
    /**
     * Initialises this model from a pretrained autoencoder: weights are
     * randomised, then every weight matrix except the last and every hidden
     * bias vector is copied from the given model; the output biases are
     * zeroed.
     *
     * @param deepLearningModelInfo pretrained model to copy from
     * @throws IllegalArgumentException if a copied weight matrix differs in
     *                                  rows or columns between the two models
     */
    public void initializeFromPretrainedModel(DeepLearningModelInfo deepLearningModelInfo) {
        if (!$assertionsDisabled && !deepLearningModelInfo.parameters._autoencoder) {
            throw new AssertionError();
        }
        randomizeWeights();
        final int copiedLayers = this.dense_row_weights.length - 1;
        for (int layer = 0; layer < copiedLayers; layer++) {
            if (get_weights(layer).rows() != deepLearningModelInfo.get_weights(layer).rows()) {
                throw new IllegalArgumentException("Mismatch between weights in pretrained model and this model: rows in layer " + layer + ": " + deepLearningModelInfo.get_weights(layer).rows() + " vs " + get_weights(layer).rows() + ". Enable ignored_const_cols for both models and/or check categorical levels for consistency.");
            }
            if (get_weights(layer).cols() != deepLearningModelInfo.get_weights(layer).cols()) {
                throw new IllegalArgumentException("Mismatch between weights in pretrained model and this model: cols in layer " + layer + ": " + deepLearningModelInfo.get_weights(layer).cols() + " vs " + get_weights(layer).cols() + ". Enable ignored_const_cols for both models and/or check categorical levels for consistency.");
            }
            for (int r = 0; r < get_weights(layer).rows(); r++) {
                for (int c = 0; c < get_weights(layer).cols(); c++) {
                    get_weights(layer).set(r, c, deepLearningModelInfo.get_weights(layer).get(r, c));
                }
            }
        }
        final int hiddenLayers = get_params()._hidden.length;
        for (int layer = 0; layer < hiddenLayers; layer++) {
            Storage.DenseVector target = this.biases[layer];
            for (int k = 0; k < target.raw().length; k++) {
                target.set(k, deepLearningModelInfo.biases[layer].get(k));
            }
        }
        Storage.DenseVector outputBias = this.biases[this.biases.length - 1];
        Arrays.fill(outputBias.raw(), CMAESOptimizer.DEFAULT_STOPFITNESS);
    }

    /**
     * Adds another model's state element-wise into this one: weights, biases,
     * average activations (when both models have them), momenta (when
     * enabled) and the adaptive-rate weight state (when enabled). The other
     * model's local processed counter is added to this model's.
     *
     * @param deepLearningModelInfo model whose state is added into this one
     */
    public void add(DeepLearningModelInfo deepLearningModelInfo) {
        for (int i = 0; i < this.dense_row_weights.length; i++) {
            ArrayUtils.add(get_weights(i).raw(), deepLearningModelInfo.get_weights(i).raw());
        }
        for (int i = 0; i < this.biases.length; i++) {
            ArrayUtils.add(this.biases[i].raw(), deepLearningModelInfo.biases[i].raw());
        }
        // Accumulate the other model's average activations. The previous code
        // added its biases here, which have the same length but are unrelated.
        if (this.avg_activations != null && deepLearningModelInfo.avg_activations != null) {
            for (int i = 0; i < this.avg_activations.length; i++) {
                ArrayUtils.add(this.avg_activations[i].raw(), deepLearningModelInfo.avg_activations[i].raw());
            }
        }
        if (has_momenta()) {
            if (!$assertionsDisabled && !deepLearningModelInfo.has_momenta()) {
                throw new AssertionError();
            }
            for (int i = 0; i < this.dense_row_weights_momenta.length; i++) {
                ArrayUtils.add(get_weights_momenta(i).raw(), deepLearningModelInfo.get_weights_momenta(i).raw());
            }
            for (int i = 0; i < this.biases_momenta.length; i++) {
                ArrayUtils.add(this.biases_momenta[i].raw(), deepLearningModelInfo.biases_momenta[i].raw());
            }
        }
        if (adaDelta()) {
            if (!$assertionsDisabled && !deepLearningModelInfo.adaDelta()) {
                throw new AssertionError();
            }
            for (int i = 0; i < this.dense_row_ada_dx_g.length; i++) {
                ArrayUtils.add(get_ada_dx_g(i).raw(), deepLearningModelInfo.get_ada_dx_g(i).raw());
            }
        }
        add_processed_local(deepLearningModelInfo.get_processed_local());
    }

    /**
     * Multiplies the model state by {@code d}, implemented as a division by
     * its reciprocal.
     *
     * @param d factor
     */
    protected void mult(double d) {
        double reciprocal = 1.0d / d;
        div(reciprocal);
    }

    /* JADX INFO: Access modifiers changed from: protected */
    /**
     * Divides the model state element-wise by {@code d}: weights, biases,
     * average activations (when present), momenta (when enabled) and the
     * adaptive-rate weight state (when enabled). Matrix data is divided by
     * {@code d} narrowed to float, vector data by {@code d} itself.
     *
     * @param d divisor
     */
    public void div(double d) {
        final int weightLayers = this.dense_row_weights.length;
        for (int layer = 0; layer < weightLayers; layer++) {
            ArrayUtils.div(get_weights(layer).raw(), (float) d);
        }
        for (int layer = 0; layer < this.biases.length; layer++) {
            ArrayUtils.div(this.biases[layer].raw(), d);
        }
        if (this.avg_activations != null) {
            for (int layer = 0; layer < this.avg_activations.length; layer++) {
                ArrayUtils.div(this.avg_activations[layer].raw(), d);
            }
        }
        if (has_momenta()) {
            final int momentaLayers = this.dense_row_weights_momenta.length;
            for (int layer = 0; layer < momentaLayers; layer++) {
                ArrayUtils.div(get_weights_momenta(layer).raw(), (float) d);
            }
            for (int layer = 0; layer < this.biases_momenta.length; layer++) {
                ArrayUtils.div(this.biases_momenta[layer].raw(), d);
            }
        }
        if (adaDelta()) {
            final int adaLayers = this.dense_row_ada_dx_g.length;
            for (int layer = 0; layer < adaLayers; layer++) {
                ArrayUtils.div(get_ada_dx_g(layer).raw(), (float) d);
            }
        }
    }

    /**
     * Draws a value between {@code d} and {@code d2} by scaling a single
     * {@code nextFloat()} draw.
     *
     * @param random source of randomness
     * @param d      lower end of the range
     * @param d2     upper end of the range
     * @return {@code d + nextFloat() * (d2 - d)}
     */
    double uniformDist(Random random, double d, double d2) {
        double span = d2 - d;
        return d + (random.nextFloat() * span);
    }

    /**
     * Fills every weight matrix with random values according to
     * {@code _initial_weight_distribution}. Each layer uses its own generator
     * seeded from {@code _seed}, a fixed offset and the layer index.
     * Under UniformAdaptive the range is +/- sqrt(6 / (fan-in + fan-out)),
     * scaled by 4 for the last layer of a classification model.
     */
    private void randomizeWeights() {
        final int lastLayer = this.dense_row_weights.length - 1;
        for (int layer = 0; layer <= lastLayer; layer++) {
            RandomBase rng = RandomUtils.getRNG(get_params()._seed + 195911405 + layer + 1);
            final double range = Math.sqrt(6.0d / (this.units[layer] + this.units[layer + 1]));
            final boolean scaledOutput = layer == lastLayer && this._classification;
            for (int r = 0; r < get_weights(layer).rows(); r++) {
                for (int c = 0; c < get_weights(layer).cols(); c++) {
                    DeepLearningModel.DeepLearningParameters.InitialWeightDistribution dist = get_params()._initial_weight_distribution;
                    if (dist == DeepLearningModel.DeepLearningParameters.InitialWeightDistribution.UniformAdaptive) {
                        if (scaledOutput) {
                            get_weights(layer).set(r, c, (float) (4.0d * uniformDist(rng, -range, range)));
                        } else {
                            get_weights(layer).set(r, c, (float) uniformDist(rng, -range, range));
                        }
                    } else if (dist == DeepLearningModel.DeepLearningParameters.InitialWeightDistribution.Uniform) {
                        get_weights(layer).set(r, c, (float) uniformDist(rng, -get_params()._initial_weight_scale, get_params()._initial_weight_scale));
                    } else if (dist == DeepLearningModel.DeepLearningParameters.InitialWeightDistribution.Normal) {
                        get_weights(layer).set(r, c, (float) (rng.nextGaussian() * get_params()._initial_weight_scale));
                    }
                }
            }
        }
    }

    /**
     * Computes one importance value per input unit from the weights of the
     * first layer (when there is no hidden layer) or of the first two layers
     * (otherwise), then scales the result by its maximum. Entries belonging
     * to a categorical column's last level are zeroed when that level was
     * never flagged by {@link #checkMissingCats(int[])}.
     *
     * @return importance per input unit, of length {@code units[0]}
     */
    public float[] computeVariableImportances() {
        // Freshly allocated float arrays are already all zeros.
        float[] importance = new float[this.units[0]];
        if (this.units.length == 2) {
            // No hidden layer: sum of absolute weights leaving each input.
            for (int in = 0; in < this.units[0]; in++) {
                for (int out = 0; out < this.units[1]; out++) {
                    importance[in] += Math.abs(get_weights(0).get(out, in));
                }
            }
        } else {
            float[][] share = new float[this.units[0]][this.units[2]];
            float[] absSumIntoHidden = new float[this.units[1]];
            float[] absSumIntoSecond = new float[this.units[2]];

            // Absolute weight mass arriving at each unit of layer 1 and layer 2.
            for (int h = 0; h < this.units[1]; h++) {
                for (int in = 0; in < this.units[0]; in++) {
                    absSumIntoHidden[h] += Math.abs(get_weights(0).get(h, in));
                }
            }
            for (int o = 0; o < this.units[2]; o++) {
                for (int h = 0; h < this.units[1]; h++) {
                    absSumIntoSecond[o] += Math.abs(get_weights(1).get(o, h));
                }
            }

            // Contribution of every input to every layer-2 unit via each layer-1 unit.
            for (int in = 0; in < this.units[0]; in++) {
                for (int o = 0; o < this.units[2]; o++) {
                    for (int h = 0; h < this.units[1]; h++) {
                        float first = get_weights(0).get(h, in);
                        float second = get_weights(1).get(o, h);
                        share[in][o] += (((Math.abs(first) / absSumIntoHidden[h]) * Math.abs(second)) / absSumIntoSecond[o]);
                    }
                }
            }

            // Normalise each layer-2 column so that it sums to one over the inputs.
            for (int o = 0; o < this.units[2]; o++) {
                float columnTotal = 0.0f;
                for (int in = 0; in < this.units[0]; in++) {
                    columnTotal += share[in][o];
                }
                for (int in = 0; in < this.units[0]; in++) {
                    share[in][o] /= columnTotal;
                }
            }
            for (int in = 0; in < this.units[0]; in++) {
                importance[in] = ArrayUtils.sum(share[in]);
            }
        }
        ArrayUtils.div(importance, ArrayUtils.maxValue(importance));
        if (this._saw_missing_cats != null) {
            for (int cat = 0; cat < this._saw_missing_cats.length; cat++) {
                if (!$assertionsDisabled && !this.data_info._catMissing[cat]) {
                    throw new AssertionError();
                }
                if (this._saw_missing_cats[cat]) {
                    continue;
                }
                importance[this.data_info._catOffsets[cat + 1] - 1] = 0.0f;
            }
        }
        return importance;
    }

    /**
     * Recomputes the per-layer diagnostic statistics held in {@code mean_bias}/{@code rms_bias},
     * {@code mean_weight}/{@code rms_weight}, {@code mean_rate}/{@code rms_rate} and, for sparse
     * autoencoders, {@code mean_a}.
     *
     * <p>For each of the {@code units.length - 1} layers the "mean" is the arithmetic mean of the
     * values and the "rms" is the square root (via {@code MathUtils.approxSqrt}) of the mean squared
     * deviation from that mean. Rate statistics are only produced when {@code _adaptive_rate} is
     * set; otherwise {@code mean_rate} and {@code rms_rate} are left at zero.
     *
     * <p>Side effect: {@code unstable} is set (never cleared here) when any bias/weight statistic of
     * a layer is NaN or exceeds the magnitude limits below.
     */
    public void computeStats() {
        final int layerCount = this.units.length - 1;
        // One row per layer; each row is allocated inside the layer loop once the weight count is known.
        final float[][] rates = get_params()._adaptive_rate ? new float[layerCount][] : null;
        final float epsilon = (float) get_params()._epsilon;
        // Magnitude limits beyond which a layer is flagged as numerically unstable.
        final double weightLimit = 1.0E10d;
        final double biasLimit = 100000.0d;

        // Mean hidden activation per hidden layer, only tracked for autoencoders with a sparsity penalty.
        if (get_params()._autoencoder && get_params()._sparsity_beta > 0.0d) {
            for (int h = 0; h < get_params()._hidden.length; h++) {
                this.mean_a[h] = 0.0d;
                for (int j = 0; j < this.avg_activations[h].size(); j++) {
                    this.mean_a[h] += this.avg_activations[h].get(j);
                }
                this.mean_a[h] /= this.avg_activations[h].size();
            }
        }

        for (int layer = 0; layer < layerCount; layer++) {
            this.rms_rate[layer] = 0.0d;
            this.mean_rate[layer] = 0.0d;
            this.rms_bias[layer] = 0.0d;
            this.mean_bias[layer] = 0.0d;
            this.rms_weight[layer] = 0.0d;
            this.mean_weight[layer] = 0.0d;

            // First pass: accumulate sums for the means.
            for (int u = 0; u < this.biases[layer].size(); u++) {
                this.mean_bias[layer] += this.biases[layer].get(u);
            }
            if (rates != null) {
                rates[layer] = new float[get_weights(layer).raw().length];
            }
            for (int u = 0; u < get_weights(layer).raw().length; u++) {
                this.mean_weight[layer] += get_weights(layer).raw()[u];
                if (rates != null) {
                    // The adaptive-rate buffer stores two values per weight at [2u] and [2u + 1];
                    // the diagnostic rate is sqrt(first + eps) * 1/sqrt(second + eps).
                    rates[layer][u] = MathUtils.approxSqrt(get_ada_dx_g(layer).raw()[2 * u] + epsilon)
                            * MathUtils.approxInvSqrt(get_ada_dx_g(layer).raw()[(2 * u) + 1] + epsilon);
                    this.mean_rate[layer] += rates[layer][u];
                }
            }
            this.mean_bias[layer] /= this.biases[layer].size();
            this.mean_weight[layer] /= get_weights(layer).size();
            if (rates != null) {
                this.mean_rate[layer] /= rates[layer].length;
            }

            // Second pass: accumulate squared deviations from the means.
            for (int u = 0; u < this.biases[layer].size(); u++) {
                final double biasDelta = this.biases[layer].get(u) - this.mean_bias[layer];
                this.rms_bias[layer] += biasDelta * biasDelta;
            }
            for (int u = 0; u < get_weights(layer).size(); u++) {
                final double weightDelta = get_weights(layer).raw()[u] - this.mean_weight[layer];
                this.rms_weight[layer] += weightDelta * weightDelta;
                if (rates != null) {
                    final double rateDelta = rates[layer][u] - this.mean_rate[layer];
                    this.rms_rate[layer] += rateDelta * rateDelta;
                }
            }
            this.rms_bias[layer] = MathUtils.approxSqrt(this.rms_bias[layer] / this.biases[layer].size());
            this.rms_weight[layer] = MathUtils.approxSqrt(this.rms_weight[layer] / get_weights(layer).size());
            if (rates != null) {
                this.rms_rate[layer] = MathUtils.approxSqrt(this.rms_rate[layer] / rates[layer].length);
            }

            // Sticky flag: once a layer looks broken the model stays marked unstable.
            this.unstable |= Double.isNaN(this.mean_bias[layer])
                    || Double.isNaN(this.rms_bias[layer])
                    || Double.isNaN(this.mean_weight[layer])
                    || Double.isNaN(this.rms_weight[layer])
                    || Math.abs(this.mean_weight[layer]) > weightLimit
                    || this.rms_weight[layer] > weightLimit
                    || Math.abs(this.mean_bias[layer]) > biasLimit
                    || this.rms_bias[layer] > biasLimit;
        }
    }

    /**
     * Produces a checksum for this model info.
     *
     * <p>Statistics are refreshed with {@link #computeStats()} first. The running value starts from
     * the seed's bit pattern reinterpreted as a double plus {@code size() * get_processed_total()},
     * and every entry of the bias, weight and rate statistics is then mixed in, each scaled by the
     * next draw of a fixed-seed {@link Random} so the result is reproducible.
     *
     * <p>NOTE(review): the decompiler reports that this method's access was changed from
     * {@code protected}; confirm against the original sources before narrowing it.
     *
     * @return the raw long bits of the accumulated double
     */
    public long checksum_impl() {
        computeStats();
        final Random mixer = new Random(-557122629L);
        double acc = Double.longBitsToDouble(get_params()._seed) + (size() * get_processed_total());
        // Order matters: each value is paired with the next draw from the fixed-seed generator.
        final double[][] statGroups = {
            this.mean_bias, this.rms_bias, this.mean_weight, this.rms_weight, this.mean_rate, this.rms_rate
        };
        for (double[] group : statGroups) {
            for (int k = 0; k < group.length; k++) {
                acc += mixer.nextDouble() * (group[k] + 123.23d);
            }
        }
        return Double.doubleToRawLongBits(acc);
    }

    /**
     * Folds {@code deepLearningModelInfo} into the elastic-average model stored in the DKV and
     * returns the updated average.
     *
     * <p>If no average is stored yet, or the moving rate is exactly 1, the average becomes a deep
     * copy of the given model. Otherwise the given model is scaled by the rate <em>in place</em>,
     * the stored average is scaled by {@code 1 - rate}, the two are added, and the average takes
     * over the given model's global processed count. In both cases the average's local processed
     * count is reset to 0 and the average is written back under
     * {@link #elasticAverageModelInfoKey()}.
     *
     * @param deepLearningModelInfo the model to blend in; mutated by {@code mult} on the blend path
     * @return the elastic-average model that was stored
     */
    public static DeepLearningModelInfo timeAverage(DeepLearningModelInfo deepLearningModelInfo) {
        final float movingRate = (float) deepLearningModelInfo.get_params()._elastic_averaging_moving_rate;
        // Decompiled form of an assertion that the rate lies in (0, 1].
        if (!$assertionsDisabled && (movingRate <= 0.0f || movingRate > 1.0f)) {
            throw new AssertionError();
        }
        DeepLearningModelInfo average =
                (DeepLearningModelInfo) DKV.getGet(deepLearningModelInfo.elasticAverageModelInfoKey());
        if (average != null && movingRate != 1.0f) {
            // average = rate * model + (1 - rate) * average
            deepLearningModelInfo.mult(movingRate);
            average.mult(1.0f - movingRate);
            average.add(deepLearningModelInfo);
            average.set_processed_global(deepLearningModelInfo.get_processed_global());
        } else {
            average = (DeepLearningModelInfo) IcedUtils.deepCopy(deepLearningModelInfo);
        }
        average.set_processed_local(0L);
        DKV.put(average.elasticAverageModelInfoKey(), average);
        return average;
    }

    /**
     * Builds the key under which the given node's copy of this model info is stored.
     *
     * @param h2ONode the node whose index is appended to the key name and which is passed on to
     *     {@code Key.make}
     * @return a key named {@code <_model_id>.node<index>}
     */
    public Key localModelInfoKey(H2ONode h2ONode) {
        final String keyName = this._model_id + ".node" + h2ONode.index();
        return Key.make(keyName, (byte) 31, true, h2ONode);
    }

    /**
     * Builds the key under which the elastic-average model info is stored. The first entry of
     * {@code H2O.CLOUD._memary} is the node passed to {@code Key.make}.
     *
     * @return a key named {@code <_model_id>.elasticaverage}
     */
    public Key elasticAverageModelInfoKey() {
        final String keyName = this._model_id + ".elasticaverage";
        return Key.make(keyName, (byte) 31, true, H2O.CLOUD._memary[0]);
    }

    static {
        $assertionsDisabled = !DeepLearningModelInfo.class.desiredAssertionStatus();
        gradientCheck = null;
        gradientCheckBias = null;
    }
}
