package hex.genmodel.algos.deeplearning;

import hex.ModelCategory;
import hex.genmodel.CategoricalEncoding;
import hex.genmodel.GenModel;
import hex.genmodel.MojoModel;
import hex.genmodel.utils.DistributionFamily;
import java.io.Serializable;

/* loaded from: input_file:hex/genmodel/algos/deeplearning/DeeplearningMojoModel.class */
public class DeeplearningMojoModel extends MojoModel {
    /** Mini-batch size recorded with the model; not read by the scoring code in this class. */
    public int _mini_batch_size;
    /** Number of numeric input columns (sizes the numeric scratch buffer and the un-normalization loop). */
    public int _nums;
    /** Number of categorical input columns (sizes the categorical scratch buffer). */
    public int _cats;
    /** Categorical offsets handed to {@code setInput} when building the input layer. */
    public int[] _catoffsets;
    /** Per-numeric-column multipliers; autoencoder output is un-normalized as {@code out / _normmul + _normsub}. */
    public double[] _normmul;
    /** Per-numeric-column offsets paired with {@link #_normmul}. */
    public double[] _normsub;
    /** Response multiplier (index 0 used); regression output is un-normalized as {@code out / mul + sub}. */
    public double[] _normrespmul;
    /** Response offset (index 0 used) paired with {@link #_normrespmul}. */
    public double[] _normrespsub;
    /** Passed to {@code setInput} when building the input layer. */
    public boolean _use_all_factor_levels;
    /** Activation name used for every hidden layer (and the output layer of an autoencoder). */
    public String _activation;
    /** Per-layer activation names, filled in by {@link #init()}. */
    public String[] _allActivations;
    /** Not read by the scoring code in this class. */
    public boolean _imputeMeans;
    /** Layer sizes: {@code _units[0]} is the input width, {@code _units[i + 1]} the output width of layer {@code i}. */
    public int[] _units;
    /** Per-layer dropout ratios handed to {@code NeuralNetwork}. */
    public double[] _all_drop_out_ratios;
    /** Per-layer weights and biases handed to {@code NeuralNetwork}. */
    public StoreWeightsBias[] _weightsAndBias;
    /** Not read by the scoring code in this class. */
    public int[] _catNAFill;
    /** Number of weight layers ({@code _units.length - 1}), set by {@link #init()}. */
    public int _numLayers;
    /** Distribution family; selects the inverse link applied to regression / modified_huber output. */
    public DistributionFamily _family;
    protected String _genmodel_encoding;
    protected String[] _orig_names;
    protected String[][] _orig_domain_values;
    protected double[] _orig_projection_array;

    /** Weights ({@code float}) and biases ({@code double}) of a single layer. */
    public static class StoreWeightsBias implements Serializable {
        float[] _wValues;
        double[] _bValues;

        // Kept public for backward compatibility (the decompiled class exposed it as public).
        public StoreWeightsBias(float[] wValues, double[] bValues) {
            this._wValues = wValues;
            this._bValues = bValues;
        }
    }

    // Kept public for backward compatibility (the decompiled class exposed it as public).
    public DeeplearningMojoModel(String[] columns, String[][] domains, String responseColumn) {
        super(columns, domains, responseColumn);
    }

    /**
     * Derives {@link #_numLayers} and {@link #_allActivations} from {@link #_units}.
     * Hidden layers use {@link #_activation}; the output layer uses {@link #_activation} for an
     * autoencoder, "Softmax" for a classifier and "Linear" otherwise.
     *
     * @throws IllegalStateException if fewer than two layer sizes (input + output) are present
     */
    public void init() {
        int layerSizes = (_units == null) ? 0 : _units.length;
        if (layerSizes < 2) {
            // Without this check the output-layer assignment below fails with index -1.
            throw new IllegalStateException(
                "Deep learning MOJO needs at least an input and an output layer, but got "
                    + layerSizes + " layer sizes");
        }
        _numLayers = layerSizes - 1;
        _allActivations = new String[_numLayers];
        int outputLayer = _numLayers - 1;
        for (int layer = 0; layer < outputLayer; layer++) {
            _allActivations[layer] = _activation;
        }
        if (isAutoEncoder()) {
            _allActivations[outputLayer] = _activation;
        } else {
            _allActivations[outputLayer] = isClassifier() ? "Softmax" : "Linear";
        }
    }

    /**
     * Scores one row by forward-propagating it through every layer.
     *
     * @param dataRow input row (must not be null)
     * @param offset  ignored by this implementation
     * @param preds   output buffer, filled by {@link #modifyOutputs}
     * @return {@code preds}
     */
    @Override
    public final double[] score0(double[] dataRow, double offset, double[] preds) {
        assert dataRow != null : "doubles are null";
        double[] neurons = new double[_units[0]];
        setInput(dataRow, neurons, new double[_nums], new int[_cats], _nums, _cats, _catoffsets,
            _normmul, _normsub, _use_all_factor_levels, true);
        for (int layer = 0; layer < _numLayers; layer++) {
            neurons = new NeuralNetwork(_allActivations[layer], _all_drop_out_ratios[layer],
                _weightsAndBias[layer], neurons, _units[layer + 1]).fprop1Layer();
        }
        assert isAutoEncoder() || _nclasses == neurons.length
            : "nclasses " + _nclasses + " neuronsOutput.length " + neurons.length;
        return modifyOutputs(neurons, preds, dataRow);
    }

    /**
     * Converts raw network output into the prediction layout.
     * <ul>
     *   <li>Autoencoder: copies {@code out} into {@code preds}, un-normalizing the trailing
     *       {@link #_nums} entries when {@link #_normmul} is non-empty.</li>
     *   <li>modified_huber: {@code preds[2] = linkInv(out[0])}, {@code preds[1] = 1 - preds[2]},
     *       {@code preds[0]} = predicted label.</li>
     *   <li>Regression: {@code preds[0]} = un-normalized, inverse-linked {@code out[0]}.</li>
     *   <li>Classification: {@code preds[1..]} = {@code out}, optionally class-balance corrected,
     *       {@code preds[0]} = predicted label.</li>
     * </ul>
     *
     * @param out     raw output of the last layer
     * @param preds   destination buffer
     * @param dataRow original input row (used for label prediction)
     * @return {@code preds}
     * @throws RuntimeException if a regression target or class probability is NaN
     */
    public double[] modifyOutputs(double[] out, double[] preds, double[] dataRow) {
        if (isAutoEncoder()) {
            if (_normmul == null || _normmul.length <= 0) {
                System.arraycopy(out, 0, preds, 0, out.length);
            } else {
                // The trailing _nums outputs are numeric columns and must be un-normalized.
                int numStart = out.length - _nums;
                System.arraycopy(out, 0, preds, 0, numStart);
                for (int col = 0; col < _nums; col++) {
                    int idx = numStart + col;
                    preds[idx] = (out[idx] / _normmul[col]) + _normsub[col];
                }
            }
        } else if (_family == DistributionFamily.modified_huber) {
            // Bug fix: the probability used to be linkInv(-1) (a constant placeholder) instead of
            // being derived from the network output, and the label was left at -1.
            preds[2] = linkInv(_family, out[0]);
            preds[1] = 1.0d - preds[2];
            preds[0] = GenModel.getPrediction(preds, _priorClassDistrib, dataRow, _defaultThreshold);
        } else if (!isClassifier()) {
            double value = (_normrespmul != null)
                ? (out[0] / _normrespmul[0]) + _normrespsub[0]
                : out[0];
            preds[0] = linkInv(_family, value);
            if (Double.isNaN(preds[0])) {
                throw new RuntimeException("Predicted regression target NaN!");
            }
        } else {
            assert preds.length == out.length + 1;
            for (int k = 0; k < preds.length - 1; k++) {
                preds[k + 1] = out[k];
                if (Double.isNaN(preds[k + 1])) {
                    throw new RuntimeException("Predicted class probability NaN!");
                }
            }
            if (_balanceClasses) {
                GenModel.correctProbabilities(preds, _priorClassDistrib, _modelClassDistrib);
            }
            preds[0] = GenModel.getPrediction(preds, _priorClassDistrib, dataRow, _defaultThreshold);
        }
        return preds;
    }

    /**
     * Applies the inverse link for {@code family}. Logistic for bernoulli / quasibinomial /
     * modified_huber / ordinal, exponential for multinomial / poisson / gamma / tweedie, identity
     * otherwise. Both non-identity forms clamp {@code exp} at 1e19.
     */
    private double linkInv(DistributionFamily family, double value) {
        switch (family) {
            case bernoulli:
            case quasibinomial:
            case modified_huber:
            case ordinal:
                return 1.0d / (1.0d + Math.min(1.0E19d, Math.exp(-value)));
            case multinomial:
            case poisson:
            case gamma:
            case tweedie:
                return Math.min(1.0E19d, Math.exp(value));
            default:
                return value;
        }
    }

    /** Scores one row with a zero offset; see {@link #score0(double[], double, double[])}. */
    @Override
    public double[] score0(double[] dataRow, double[] preds) {
        return score0(dataRow, 0.0d, preds);
    }

    /**
     * Returns the prediction buffer size: the input width for an autoencoder,
     * {@code nclasses() + 1} for a classifier, and 2 otherwise.
     */
    @Override
    public int getPredsSize(ModelCategory modelCategory) {
        if (modelCategory == ModelCategory.AutoEncoder) {
            return _units[0];
        }
        return isClassifier() ? nclasses() + 1 : 2;
    }

    /**
     * Mean squared reconstruction error between an original row and its reconstruction.
     * When numeric normalization is present, differences in the trailing {@link #_nums} entries
     * are rescaled by {@link #_normmul} before squaring.
     *
     * @param original      input row (non-empty)
     * @param reconstructed reconstruction of the same length
     * @return mean of the (scaled) squared differences
     */
    public double calculateReconstructionErrorPerRowData(double[] original, double[] reconstructed) {
        assert original != null && original.length > 0 && reconstructed != null && reconstructed.length > 0;
        assert original.length == reconstructed.length;
        int numStart = original.length - _nums;
        boolean scaleNumerics = _normmul != null && _normmul.length > 0 && _nums > 0;
        double sum = 0.0d;
        for (int i = 0; i < original.length; i++) {
            double scale = (scaleNumerics && i >= numStart) ? _normmul[i - numStart] : 1.0d;
            double diff = (reconstructed[i] - original[i]) * scale;
            sum += diff * diff;
        }
        return sum / original.length;
    }

    /**
     * Maps the stored genmodel encoding name to a {@link CategoricalEncoding}.
     * "AUTO", "SortByResponse" and "OneHotInternal" map to {@code AUTO}; "Binary", "Eigen" and
     * "LabelEncoder" map to their namesakes. Any other name, including null, yields null.
     * NOTE(review): the AUTO mapping presumably means the model expands those encodings itself — confirm.
     */
    @Override
    public CategoricalEncoding getCategoricalEncoding() {
        if (_genmodel_encoding == null) {
            return null;
        }
        switch (_genmodel_encoding) {
            case "AUTO":
            case "SortByResponse":
            case "OneHotInternal":
                return CategoricalEncoding.AUTO;
            case "Binary":
                return CategoricalEncoding.Binary;
            case "Eigen":
                return CategoricalEncoding.Eigen;
            case "LabelEncoder":
                return CategoricalEncoding.LabelEncoder;
            default:
                return null;
        }
    }

    @Override
    public String[] getOrigNames() {
        return _orig_names;
    }

    @Override
    public double[] getOrigProjectionArray() {
        return _orig_projection_array;
    }

    @Override
    public String[][] getOrigDomainValues() {
        return _orig_domain_values;
    }
}
