package hex.genmodel.algos.deeplearning;

import hex.a;
import hex.genmodel.GenModel;
import hex.genmodel.MojoModel;
import hex.genmodel.utils.DistributionFamily;
import java.io.Serializable;

/* loaded from: input_file:hex/genmodel/algos/deeplearning/DeeplearningMojoModel.class */
/**
 * Scoring-only (MOJO) representation of a feed-forward deep learning model.
 *
 * <p>Member names are obfuscated. The comments below describe how each member is used inside
 * this class; meanings that cannot be established from this class alone are marked as
 * assumptions.
 */
public class DeeplearningMojoModel extends MojoModel {
    /** Not read in this class. */
    public int r;
    /**
     * Presumably the number of numeric input columns: passed to the input encoder, sizes the
     * numeric scratch buffer, and is the width of the de-scaled tail of an autoencoder
     * reconstruction.
     */
    public int s;
    /**
     * Presumably the number of categorical input columns: passed to the input encoder and sizes
     * the categorical scratch buffer.
     */
    public int t;
    /** Passed to the input encoder; presumably categorical column offsets. */
    public int[] u;
    /**
     * Per-numeric-column scale factors (presumably): passed to the input encoder, divided out of
     * autoencoder reconstructions, and used as error weights in the reconstruction-error method.
     */
    public double[] v;
    /**
     * Per-numeric-column offsets (presumably): passed to the input encoder and added back to
     * autoencoder reconstructions.
     */
    public double[] w;
    /** Response scale for regression output; when null, the raw network output is used. */
    public double[] x;
    /** Response offset for regression output; only read when {@code x} is non-null. */
    public double[] y;
    /** Flag passed to the input encoder (presumably "use all factor levels"). */
    public boolean z;
    /** Activation name for hidden layers (and for the output layer in the j() case). */
    public String A;
    /** Per-layer activation names, built by {@link #l()}. */
    private String[] H;
    /** Not read in this class. */
    public boolean B;
    /** Units per layer; {@code C[0]} is the width of the encoded input vector. */
    public int[] C;
    /** Per-layer value passed to {@code NeuralNetwork} (presumably a dropout ratio). */
    public double[] D;
    /** Per-layer weights and biases. */
    public StoreWeightsBias[] E;
    /** Not read in this class. */
    public int[] F;
    /** Number of non-input layers ({@code C.length - 1}), set by {@link #l()}. */
    private int I;
    /** Distribution family whose inverse link is applied to the raw network output. */
    public DistributionFamily G;
    /** Numeric scratch buffer handed to the input encoder on every scoring call. */
    private double[] J;
    /** Categorical scratch buffer handed to the input encoder on every scoring call. */
    private int[] K;
    /** Synthetic flag: {@code true} when Java assertions are disabled for this class. */
    private static /* synthetic */ boolean L;

    /**
     * Weights and biases of a single layer. The fields are package-private; presumably they are
     * read by {@code NeuralNetwork}.
     */
    public static class StoreWeightsBias implements Serializable {

        /** Layer weights (presumably). */
        float[] f970a;

        /** Layer biases (presumably). */
        double[] f971b;

        public StoreWeightsBias(float[] fArr, double[] dArr) {
            this.f970a = fArr;
            this.f971b = dArr;
        }
    }

    public DeeplearningMojoModel(String[] strArr, String[][] strArr2, String str) {
        super(strArr, strArr2, str);
        // NOTE(review): s and t are not assigned before this point in this class (presumably
        // still 0), so these buffers are placeholders; l() reallocates them once the fields
        // have been populated.
        this.J = new double[this.s];
        this.K = new int[this.t];
    }

    /**
     * Finishes initialization after the public fields have been populated. It builds the
     * per-layer activation table and sizes the scratch buffers.
     *
     * <p>Hidden layers use {@code A}. The output layer uses {@code A} when j() holds (presumably
     * the autoencoder case), "Softmax" when i() holds (presumably a classifier), and "Linear"
     * otherwise.
     */
    public final void l() {
        this.I = this.C.length - 1;
        this.H = new String[this.I];
        int outputLayer = this.I - 1;
        for (int layer = 0; layer < outputLayer; layer++) {
            this.H[layer] = this.A;
        }
        this.H[outputLayer] = j() ? this.A : i() ? "Softmax" : "Linear";
        this.J = new double[this.s];
        this.K = new int[this.t];
    }

    /**
     * Scores one row: encodes the inputs, runs the forward pass layer by layer, then
     * post-processes the raw output into {@code preds}.
     *
     * <p>NOTE(review): the scratch buffers J and K are instance fields reused by every call, so
     * concurrent scoring on one instance may interfere. Confirm that callers do not share an
     * instance across threads.
     *
     * @param row raw input row; must not be null
     * @param offset ignored by this model
     * @param preds output buffer, filled in place
     * @return {@code preds}
     */
    @Override // hex.genmodel.GenModel
    public final double[] a(double[] row, double offset, double[] preds) {
        if (!L && row == null) {
            throw new AssertionError("doubles are null");
        }
        // Presumably encodes the raw row into the first layer's input vector.
        double[] activations = new double[this.C[0]];
        a(row, activations, this.J, this.K, this.s, this.t, this.u, this.v, this.w, this.z, true);
        // Each layer's output becomes the next layer's input.
        for (int layer = 0; layer < this.I; layer++) {
            NeuralNetwork net = new NeuralNetwork(this.H[layer], this.D[layer], this.E[layer], activations, this.C[layer + 1]);
            activations = NeuralNetwork.a(net.f972a).a(net.f974c == 1 ? net.a() : net.b(), net.f973b, net.f974c);
        }
        if (!L && !j() && this.j != activations.length) {
            throw new AssertionError("nclasses " + this.j + " neuronsOutput.length " + activations.length);
        }
        return b(activations, preds, row);
    }

    /**
     * Converts the raw network output into predictions.
     *
     * @param out raw output of the last layer
     * @param preds destination buffer
     * @param row original input row (only used by the classifier label helper)
     * @return {@code preds}
     */
    private double[] b(double[] out, double[] preds, double[] row) {
        if (j()) {
            // Reconstruction output: copy it, undoing numeric scaling on the trailing s entries
            // when scale factors exist.
            if (this.v == null || this.v.length <= 0) {
                System.arraycopy(out, 0, preds, 0, out.length);
            } else {
                int numStart = out.length - this.s;
                System.arraycopy(out, 0, preds, 0, numStart);
                for (int k = 0; k < this.s; k++) {
                    int idx = numStart + k;
                    preds[idx] = (out[idx] / this.v[k]) + this.w[k];
                }
            }
        } else if (this.G == DistributionFamily.modified_huber) {
            // The inverse link must be applied to the network output. Applying it to the
            // constant -1 gave every row the same probability.
            // NOTE(review): preds[0] (the label slot) is left at -1; confirm callers expect that.
            preds[0] = -1.0d;
            preds[2] = a(this.G, out[0]);
            preds[1] = 1.0d - preds[2];
        } else if (!i()) {
            // Regression: optional response de-scaling, then the inverse link.
            if (this.x != null) {
                preds[0] = (out[0] / this.x[0]) + this.y[0];
            } else {
                preds[0] = out[0];
            }
            preds[0] = a(this.G, preds[0]);
            if (Double.isNaN(preds[0])) {
                throw new RuntimeException("Predicted regression target NaN!");
            }
        } else {
            // Classifier: probabilities go to preds[1..], the predicted label to preds[0].
            if (!L && preds.length != out.length + 1) {
                throw new AssertionError();
            }
            for (int k = 0; k < preds.length - 1; k++) {
                preds[k + 1] = out[k];
                if (Double.isNaN(preds[k + 1])) {
                    throw new RuntimeException("Predicted class probability NaN!");
                }
            }
            if (this.k) {
                // Presumably corrects probabilities for class balancing (prior vs. model distribution).
                GenModel.a(preds, this.m, this.n);
            }
            // Presumably selects the label from the probabilities using threshold l.
            preds[0] = GenModel.a(preds, this.m, row, this.l);
        }
        return preds;
    }

    /**
     * Applies the inverse link of {@code distributionFamily} to {@code d}:
     * <ul>
     *   <li>logistic for bernoulli / quasibinomial / modified_huber / ordinal;</li>
     *   <li>exponential for multinomial / poisson / gamma / tweedie;</li>
     *   <li>identity otherwise.</li>
     * </ul>
     * Exponentials are capped at 1e19, so very large inputs do not yield infinity.
     */
    private static double a(DistributionFamily distributionFamily, double d) {
        switch (distributionFamily) {
            case bernoulli:
            case quasibinomial:
            case modified_huber:
            case ordinal:
                return 1.0d / (1.0d + Math.min(1.0E19d, Math.exp(-d)));
            case multinomial:
            case poisson:
            case gamma:
            case tweedie:
                return Math.min(1.0E19d, Math.exp(d));
            default:
                return d;
        }
    }

    /** Scores one row with a zero offset (the offset is ignored anyway). */
    @Override // hex.genmodel.GenModel
    public final double[] a(double[] dArr, double[] dArr2) {
        return a(dArr, 0.0d, dArr2);
    }

    /**
     * Size of the prediction array for the given model category.
     *
     * @return {@code C[0]} for AutoEncoder, {@code c() + 1} when i() holds, and 2 otherwise
     */
    @Override // hex.genmodel.GenModel
    public final int a(a aVar) {
        if (aVar == a.AutoEncoder) {
            return this.C[0];
        }
        return i() ? c() + 1 : 2;
    }

    /**
     * Mean squared difference between two equal-length rows. Presumably these are an original
     * row and its reconstruction.
     *
     * <p>When scale factors are present ({@code v} non-empty and {@code s > 0}), each difference
     * on the trailing {@code s} entries is multiplied by the matching scale factor before
     * squaring.
     */
    public final double b(double[] original, double[] reconstructed) {
        if (!L && (original == null || original.length <= 0 || reconstructed == null || reconstructed.length <= 0)) {
            throw new AssertionError();
        }
        if (!L && original.length != reconstructed.length) {
            throw new AssertionError();
        }
        int numStart = original.length - this.s;
        boolean scaled = this.v != null && this.v.length > 0 && this.s > 0;
        double sum = 0.0d;
        for (int i = 0; i < original.length; i++) {
            double weight = (scaled && i >= numStart) ? this.v[i - numStart] : 1.0d;
            sum += Math.pow((reconstructed[i] - original[i]) * weight, 2.0d);
        }
        return sum / original.length;
    }

    static {
        L = !DeeplearningMojoModel.class.desiredAssertionStatus();
    }
}
