package ai.h2o.mojos.runtime.transforms;

import ai.h2o.mojos.runtime.frame.MojoColumn;
import ai.h2o.mojos.runtime.frame.MojoFrameMeta;
import ai.h2o.mojos.runtime.utils.SB;
import java.util.Arrays;
import java.util.EnumMap;
import java.util.Map;

/* loaded from: input_file:ai/h2o/mojos/runtime/transforms/o.class */
/**
 * Builder for the FTRL transform whose generated classes are named {@code MojoTransformFTRL_*}.
 *
 * <p>Holds the trained state:
 * <ul>
 *   <li>the {@code z} and {@code n} coefficient rows, indexed first by label id in the generated code;</li>
 *   <li>the hyper-parameters {@code alpha}, {@code beta}, {@code lambda1} and {@code lambda2};</li>
 *   <li>the feature interactions, label ids and feature names.</li>
 * </ul>
 * {@link #build()} passes this state to {@link b}, which emits the Java source of a specialised
 * scoring transform.
 *
 * <p>NOTE(review): {@link #build()} rewrites the stored coefficient rows without synchronization.
 */
public class o extends MojoTransformBuilder implements InterfaceC0022c {
    /** {@code z} rows (one per label); each element is a {@code double[]} or a {@code float[]}. */
    private final Object[] zColumns;
    /** {@code n} rows (one per label); replaced by {@code sqrt(n)} on the first {@link #build()}. */
    private final Object[] nColumns;
    /** Each interaction lists the input positions whose bucket ids are summed into one extra feature. */
    private final int[][] interactions;
    private final int[] labelIds;
    private final double alpha;
    private final double beta;
    private final double lambda1;
    private final double lambda2;
    private final a modelType;
    /** Passed to the generated constructor as {@code mantissaNBits} (used as {@code 52 - mantissaNBits}). */
    private final int mantissaBits;
    private final int numBins;
    private final String[] featureNames;
    /** Whether {@link #nColumns} already holds {@code sqrt(n)}; the root must be taken only once. */
    private boolean nSqrtApplied;
    /** Element type currently stored in {@link #zColumns} and {@link #nColumns}. */
    private MojoColumn.Type coefficientType;

    /**
     * Kind of model being scored. It selects the link function emitted by {@link b}:
     * <ul>
     *   <li>REGRESSION: none;</li>
     *   <li>BINOMIAL: sigmoid;</li>
     *   <li>MULTINOMIAL: exp followed by normalization when there are more than two labels,
     *       otherwise sigmoid.</li>
     * </ul>
     */
    public enum a {
        REGRESSION,
        BINOMIAL,
        MULTINOMIAL
    }

    /**
     * Generates the Java source of an FTRL scoring transform. The source is specialised for three
     * things: the sequence of input column kinds, the output scalar type, and the model type.
     *
     * <p>NOTE(review): {@link K} presumably assembles the pieces returned by the overrides below
     * (class name, field declarations, constructor and scoring method), then compiles and
     * instantiates the result. Confirm against {@code K}.
     */
    static class b extends K implements InterfaceC0022c {
        /** Kind of each input column, in input order. */
        private final M[] inputKinds;
        /** Kind of the output columns; its {@code i} field is the scalar type name used in the generated code. */
        private final M outputKind;
        private final a modelType;
        /** For every kind, the ascending positions of the input columns of that kind (possibly empty). */
        private final Map<M, int[]> indicesByKind;

        b(MojoColumn.Type[] inputTypes, MojoColumn.Type outputType, a modelType) {
            this.inputKinds = new M[inputTypes.length];
            for (int col = 0; col < inputTypes.length; col++) {
                this.inputKinds[col] = M.a(inputTypes[col]);
            }
            this.indicesByKind = new EnumMap<>(M.class);
            for (M kind : M.values()) {
                int count = 0;
                for (M other : this.inputKinds) {
                    if (other == kind) {
                        count++;
                    }
                }
                int[] positions = new int[count];
                int next = 0;
                for (int col = 0; col < this.inputKinds.length; col++) {
                    if (this.inputKinds[col] == kind) {
                        positions[next++] = col;
                    }
                }
                this.indicesByKind.put(kind, positions);
            }
            this.outputKind = M.a(outputType);
            this.modelType = modelType;
        }

        /**
         * Returns the fully qualified name of the generated class.
         *
         * <p>The name is built from:
         * <ul>
         *   <li>the input kinds, run-length encoded so that consecutive equal kinds become
         *       {@code KIND_count_} when the count is greater than 1;</li>
         *   <li>the output kind;</li>
         *   <li>the model type.</li>
         * </ul>
         */
        @Override
        final String a() {
            SB sb = new SB();
            sb.p("ai.h2o.mojos.runtime.transforms.MojoTransformFTRL_");
            int run = 0;
            M prev = null;
            if (this.inputKinds.length > 0) {
                sb.p(this.inputKinds[0] + "_");
                prev = this.inputKinds[0];
            }
            for (M kind : this.inputKinds) {
                if (kind == prev) {
                    run++;
                } else {
                    if (run > 1) {
                        sb.p(run).p("_");
                    }
                    sb.p(kind + "_");
                    run = 1;
                }
                prev = kind;
            }
            if (run > 1) {
                sb.p(run).p("_");
            }
            sb.p("_" + this.outputKind + "_");
            sb.p(this.modelType.name());
            return sb.toString();
        }

        /**
         * Returns the field declarations of the generated class.
         *
         * <p>The list has 14 fixed entries. It then adds one {@code final int[]} index field for
         * every input kind that occurs in more than one column.
         */
        @Override
        final String[] b() {
            String scalar = this.outputKind.i;
            String[] fixed = {
                "int[] inputIndices",
                "int[] outputIndices",
                scalar + " alpha",
                scalar + " beta",
                scalar + " lambda1",
                scalar + " lambda2",
                "int numBins",
                "int numFeatures",
                "ai.h2o.mojos.runtime.utils.HashUtils hasher",
                "final " + scalar + "[][] z",
                "final " + scalar + "[][] n",
                "final int[][] interactions",
                "final int[] dataLabelIDs",
                "final long[] featureHashes"
            };
            int extra = 0;
            for (M kind : M.values()) {
                if (this.indicesByKind.get(kind).length > 1) {
                    extra++;
                }
            }
            String[] fields = Arrays.copyOf(fixed, fixed.length + extra);
            int next = fixed.length;
            for (M kind : M.values()) {
                if (this.indicesByKind.get(kind).length > 1) {
                    fields[next++] = "final int[] " + indexFieldName(kind);
                }
            }
            return fields;
        }

        /**
         * Returns the source of the generated constructor.
         *
         * <p>The constructor:
         * <ul>
         *   <li>stores the parameters, casting the hyper-parameters to the scalar type;</li>
         *   <li>pre-hashes the feature names;</li>
         *   <li>initialises the per-kind index fields declared by {@link #b()}.</li>
         * </ul>
         *
         * @param className simple name of the generated class (used as the constructor name)
         */
        @Override
        final String a(String className) {
            String scalar = this.outputKind.i;
            SB sb = new SB();
            sb.p("  public ").p(className).p("(int[] inputIndices, int[] outputIndices, double alpha, double beta, double lambda1, double lambda2, int numBins, int mantissaNBits, ").p(scalar).p("[][] z, ").p(scalar).p("[][] n, int[][] interactions, int[] dataLabelIDs, String[] featureNames) {").nl()
                .p("    this.inputIndices = inputIndices;").nl()
                .p("    this.outputIndices = outputIndices;").nl()
                .p("    this.alpha = (").p(scalar).p(") alpha;").nl()
                .p("    this.beta = (").p(scalar).p(") beta;").nl()
                .p("    this.lambda1 = (").p(scalar).p(") lambda1;").nl()
                .p("    this.lambda2 = (").p(scalar).p(")lambda2;").nl()
                .p("    this.numBins = numBins;").nl()
                .p("    this.hasher = new ai.h2o.mojos.runtime.utils.HashUtils(52 - mantissaNBits);").nl()
                .p("    this.z = z;").nl()
                .p("    this.n = n;").nl()
                .p("    this.interactions = interactions;").nl()
                .p("    this.dataLabelIDs = dataLabelIDs;").nl()
                .p("    this.featureHashes = new long[featureNames.length];").nl()
                .p("    for (int i = 0; i < this.featureHashes.length; i += 1) {").nl()
                .p("      this.featureHashes[i] = this.hasher.hash(featureNames[i]);").nl()
                .p("    }").nl()
                .p("    this.numFeatures = this.featureHashes.length + this.interactions.length;").nl();
            for (M kind : M.values()) {
                int[] positions = this.indicesByKind.get(kind);
                if (positions.length > 1) {
                    String separator = "{";
                    sb.p("    this.").p(indexFieldName(kind)).p(" = new int[] ");
                    for (int position : positions) {
                        sb.p(separator).p(position);
                        separator = ", ";
                    }
                    sb.p("};").nl();
                }
            }
            sb.p("  }").nl();
            return sb.toString();
        }

        /**
         * Returns the source of the generated scoring method. For every row, the method:
         * <ol>
         *   <li>hashes each input value (plus its feature-name hash) into a bucket in
         *       {@code [0, numBins)};</li>
         *   <li>adds the interaction buckets;</li>
         *   <li>computes the FTRL weight sum {@code wTx} for each label;</li>
         *   <li>applies the model type's link function.</li>
         * </ol>
         * Finally it normalises the outputs when there are more than two labels.
         *
         * @param methodName name of the generated {@code void (MojoFrame)} method
         */
        @Override
        final String b(String methodName) {
            String scalar = this.outputKind.i;
            String column = scalar + "[]";
            SB code = new SB();
            code.p("  public void ").p(methodName).p("(MojoFrame frame) {").nl()
                .p("    ").p(column + "[]").p(" outputs = new ").p(scalar).p("[outputIndices.length][];").nl()
                .p("    for (int i = 0; i < outputIndices.length; i += 1) {").nl()
                .p("      outputs[i] = (").p(column).p(") frame.getColumnData(outputIndices[i]);").nl()
                .p("    }").nl();
            appendInputDeclarations(code, 4);
            if (this.modelType == a.BINOMIAL) {
                code.p("    byte binomialIdx = (byte) 0;").nl();
            }
            code.p("    int nrows = frame.getNrows();").nl()
                .p("    ").p(scalar).p(" inverseAlpha = (").p(scalar).p(") (1 / alpha);").nl()
                .p("    ").p(scalar).p(" rr = (").p(scalar).p(") (beta * inverseAlpha + lambda2);").nl()
                .p("    for (int i = 0; i < nrows; i += 1) {").nl()
                .p("      long[] tmp = new long[numFeatures];").nl()
                .p("      long hash;").nl();
            appendFeatureHashing(code, 6);
            code.p("      for (int k = 0; k < interactions.length; k += 1) {").nl()
                .p("        int[] interaction = interactions[k];").nl()
                .p("        int idx = inputIndices.length + k;").nl()
                .p("        for (int z = 0; z < interaction.length; z += 1) {").nl()
                .p("          tmp[idx] += tmp[interaction[z]];").nl()
                .p("        }").nl()
                .p("        tmp[idx] %= numBins;").nl()
                .p("      }").nl()
                .p("      for (int k = 0; k < dataLabelIDs.length; k += 1) {").nl()
                .p("        int labelID = dataLabelIDs[k];").nl();
            if (this.modelType == a.BINOMIAL) {
                // Label 1 is not scored directly; its column is filled as 1 - p(other) after the row loop.
                code.p("        if (labelID == 1) {").nl()
                    .p("          binomialIdx = (byte) (k == 0 ? 0 : 1);").nl()
                    .p("          continue;").nl()
                    .p("        }").nl();
            }
            code.p("        ").p(scalar).p(" wTx = 0;").nl()
                .p("        for (int j = 0; j < tmp.length; j += 1) {").nl()
                .p("          int idx = (int) tmp[j];").nl()
                .p("          wTx -= (").p(scalar).p(") Math.copySign(Math.max(Math.abs(z[labelID][idx]) - lambda1, (").p(scalar).p(") 0.0) / (n[labelID][idx] * inverseAlpha + rr), z[labelID][idx]);").nl()
                .p("        }").nl();
            switch (this.modelType) {
                case BINOMIAL:
                    code.p("        wTx = 1 / (1 + (").p(scalar).p(") Math.exp((double) -wTx));").nl();
                    break;
                case MULTINOMIAL:
                    code.p("        if (dataLabelIDs.length > 2) {").nl()
                        .p("          wTx = (").p(scalar).p(") Math.exp((double) wTx);").nl()
                        .p("        } else {").nl()
                        .p("          wTx = 1 / (1 + (").p(scalar).p(") Math.exp((double) -wTx));").nl()
                        .p("        }").nl();
                    break;
                default:
                    // REGRESSION: the raw linear score is emitted unchanged.
                    break;
            }
            code.p("        outputs[k][i] = wTx;").nl()
                .p("      }").nl()
                .p("    }").nl();
            if (this.modelType == a.BINOMIAL) {
                code.p("    byte tmp = (byte) (binomialIdx == 0 ? 1 : 0);").nl()
                    .p("    for (int i = 0; i < nrows; i += 1) {").nl()
                    .p("      outputs[binomialIdx][i] = 1 - outputs[tmp][i];").nl()
                    .p("    }").nl();
            }
            code.p("    if (dataLabelIDs.length > 2) {").nl()
                .p("      for (int i = 0; i < nrows; i += 1) {").nl()
                .p("        ").p(scalar).p(" sum = 0;").nl()
                .p("        for(int c = 0; c < outputs.length; c += 1) {").nl()
                .p("          sum += outputs[c][i];").nl()
                .p("        }").nl()
                .p("        for(int c = 0; c < outputs.length; c += 1) {").nl()
                .p("          outputs[c][i] /= sum;").nl()
                .p("        }").nl()
                .p("      }").nl()
                .p("    }").nl()
                .p("  }").nl();
            return code.toString();
        }

        /** Name of the generated {@code int[]} field listing the input positions of {@code kind}. */
        private static String indexFieldName(M kind) {
            return "iindices" + kind.h;
        }

        /**
         * Name of the generated local variable that holds the column data of {@code kind}.
         * The name is plural when several columns share that kind.
         */
        private static String inputVarName(M kind, int columnCount) {
            return columnCount > 1 ? "inputs" + kind.h : "input" + kind.h;
        }

        /**
         * Emits local declarations that fetch every input column from {@code frame}.
         * Kinds with several columns are grouped into a 2-D array.
         *
         * @param indent number of spaces prefixed to each emitted line
         */
        private SB appendInputDeclarations(SB sb, int indent) {
            String pad = new String(new char[indent]).replace((char) 0, ' ');
            for (M kind : M.values()) {
                int[] positions = this.indicesByKind.get(kind);
                String var = inputVarName(kind, positions.length);
                if (positions.length > 1) {
                    String indexField = indexFieldName(kind);
                    sb.p(pad).p(kind.i).p("[][] ").p(var).p(" = new ").p(kind.i).p("[").p(indexField).p(".length][];").nl()
                        .p(pad).p("for (int i = 0; i < ").p(indexField).p(".length; i += 1) {").nl()
                        .p(pad).p("  ").p(var).p("[i] = (").p(kind.i).p("[]) frame.getColumnData(inputIndices[").p(indexField).p("[i]]);").nl()
                        .p(pad).p("}").nl();
                } else if (positions.length == 1) {
                    sb.p(pad).p(kind.i).p("[] ").p(var).p(" = (").p(kind.i).p("[]) frame.getColumnData(inputIndices[").p(positions[0]).p("]);").nl();
                }
            }
            return sb;
        }

        /**
         * Emits the per-row hashing of every input value into {@code tmp[position]}.
         *
         * <p>For each value, the emitted code adds the value's hash to its feature-name hash. An NA
         * value (as detected by {@code M.a(String)}) uses {@code M.Int64Gen.k} in place of the value
         * hash. The sum is then reduced into {@code [0, numBins)}; the reduction treats a negative
         * {@code long} as unsigned.
         *
         * @param indent number of spaces prefixed to each emitted line
         */
        private SB appendFeatureHashing(SB sb, int indent) {
            String pad = new String(new char[indent]).replace((char) 0, ' ');
            for (M kind : M.values()) {
                int[] positions = this.indicesByKind.get(kind);
                String var = inputVarName(kind, positions.length);
                if (positions.length > 1) {
                    String indexField = indexFieldName(kind);
                    sb.p(pad).p("for (int x = 0; x < ").p(var).p(".length; x += 1) {").nl()
                        .p(pad).p("  hash = (").p(kind.a(var + "[x][i]")).p(" ? ").p(M.Int64Gen.k).p(" : hasher.hash(").p(var).p("[x][i])) + featureHashes[").p(indexField).p("[x]];").nl()
                        .p(pad).p("  tmp[").p(indexField).p("[x]] = hash < 0 ? (((((hash >>> 1) % numBins) << 1) + (hash & 1L)) % numBins) : (hash % numBins);").nl()
                        .p(pad).p("}").nl();
                } else if (positions.length == 1) {
                    sb.p(pad).p("hash = (").p(kind.a(var + "[i]")).p(" ? ").p(M.Int64Gen.k).p(" : hasher.hash(").p(var).p("[i])) + featureHashes[").p(positions[0]).p("];").nl()
                        .p(pad).p("tmp[").p(positions[0]).p("] = hash < 0 ? (((((hash >>> 1) % numBins) << 1) + (hash & 1L)) % numBins) : (hash % numBins);").nl();
                }
            }
            return sb;
        }
    }

    /**
     * Stores the FTRL state and checks it with {@code assert} statements (active only under {@code -ea}).
     *
     * @param coefficientType element type of the rows in {@code z} and {@code n}
     */
    private o(MojoFrameMeta meta, int[] inputIndices, int[] outputIndices, double alpha, double beta,
              double lambda1, double lambda2, a modelType, int mantissaBits, int numBins,
              Object[] z, Object[] n, MojoColumn.Type coefficientType, int[][] interactions,
              int[] labelIds, String[] featureNames) {
        super(meta, inputIndices, outputIndices);
        this.labelIds = labelIds;
        this.featureNames = featureNames;
        this.interactions = interactions;
        this.alpha = alpha;
        this.beta = beta;
        this.lambda1 = lambda1;
        this.lambda2 = lambda2;
        this.modelType = modelType;
        this.mantissaBits = mantissaBits;
        this.numBins = numBins;
        this.nSqrtApplied = false;
        this.coefficientType = coefficientType;
        assert labelIds != null : "Label IDs array cannot be null";
        assert featureNames != null : "Feature names array cannot be null";
        assert featureNames.length == inputIndices.length
            : "Feature names array length doesn't match number of input columns";
        assert interactions != null : "Interactions array cannot be null";
        for (int[] interaction : interactions) {
            assert interaction != null : "Null interaction found in interactions array";
            assert interaction.length > 0 : "Interaction array must have at least one element";
        }
        assert z != null : "z array cannot be null";
        // Force runtime type Object[]: build() later stores float[]/double[] rows of either type,
        // which a double[][] copy (e.g. from clone()) would reject with ArrayStoreException.
        this.zColumns = Arrays.copyOf(z, z.length, Object[].class);
        assert n != null : "n array cannot be null";
        this.nColumns = Arrays.copyOf(n, n.length, Object[].class);
        assert numBins > 0 : "Number of bins must be positive";
        assert mantissaBits >= 0 : "Number of mantissa bits must be non-negative";
        assert modelType == a.REGRESSION || modelType == a.BINOMIAL || modelType == a.MULTINOMIAL
            : "Model type must be one of {REGRESSION, BINOMIAL, MULTINOMIAL}";
    }

    /**
     * Creates a builder from double-precision {@code z}/{@code n} rows (one row per label).
     * Under {@code -ea}, asserts that each matrix has no null rows and equal row lengths.
     */
    public o(MojoFrameMeta meta, int[] inputIndices, int[] outputIndices, double alpha, double beta,
             double lambda1, double lambda2, a modelType, int mantissaBits, int numBins,
             double[][] z, double[][] n, int[][] interactions, int[] labelIds, String[] featureNames) {
        this(meta, inputIndices, outputIndices, alpha, beta, lambda1, lambda2, modelType, mantissaBits,
            numBins, z, n, MojoColumn.Type.Float64, interactions, labelIds, featureNames);
        assertRectangular(z, "z");
        assertRectangular(n, "n");
    }

    /** Asserts (under {@code -ea}) that {@code matrix} has no null rows and that all rows share one length. */
    private static void assertRectangular(double[][] matrix, String name) {
        if (matrix.length > 0) {
            assert matrix[0] != null : "Null column found in " + name + " array";
            int width = matrix[0].length;
            for (double[] row : matrix) {
                assert row != null : "Null column found in " + name + " array";
                assert row.length == width : "Column lengths in " + name + " array do not match";
            }
        }
    }

    /**
     * Builds the FTRL transform for the current input and output column types.
     *
     * <p>The scalar type of the generated code, and of the coefficient arrays handed to it, follows
     * the first output column; it defaults to {@code Float32} when there are no outputs. For any
     * other output type, {@code z} and {@code n} are passed as {@code null}.
     *
     * <p>This method may be called repeatedly with different output types. The stored rows are
     * converted between precisions as needed, and {@code sqrt(n)} is applied only once.
     *
     * @return the transform produced by {@link b}
     */
    public MojoTransform build() {
        MojoColumn.Type[] inputTypes = getInputTypes();
        MojoColumn.Type[] outputTypes = getOutputTypes();
        // NOTE(review): the int arguments look like allowed-type bitmasks; confirm in a.a.
        ai.h2o.mojos.runtime.transforms.a.a.a(inputTypes, 498, "Time64 not allowed as an input type");
        ai.h2o.mojos.runtime.transforms.a.a.a(outputTypes, 193, "Output columns must be of the same type");
        MojoColumn.Type type = outputTypes.length > 0 ? outputTypes[0] : MojoColumn.Type.Float32;
        Object zArg = null;
        Object nArg = null;
        if (type == MojoColumn.Type.Float32) {
            prepareCoefficients(type);
            zArg = Arrays.copyOf(this.zColumns, this.zColumns.length, float[][].class);
            nArg = Arrays.copyOf(this.nColumns, this.nColumns.length, float[][].class);
        } else if (type == MojoColumn.Type.Float64) {
            prepareCoefficients(type);
            zArg = Arrays.copyOf(this.zColumns, this.zColumns.length, double[][].class);
            nArg = Arrays.copyOf(this.nColumns, this.nColumns.length, double[][].class);
        }
        // Argument order follows the generated constructor: ..., numBins, mantissaNBits, z, n, ...
        return new b(inputTypes, type, this.modelType).a(this.iindices, this.oindices,
            Double.valueOf(this.alpha), Double.valueOf(this.beta),
            Double.valueOf(this.lambda1), Double.valueOf(this.lambda2),
            Integer.valueOf(this.numBins), Integer.valueOf(this.mantissaBits),
            zArg, nArg, this.interactions, this.labelIds, this.featureNames);
    }

    /**
     * Brings the stored rows into {@code target} precision.
     *
     * <p>On the first call, {@code n} is replaced by {@code sqrt(n)}. The generated scoring code
     * multiplies {@code n} by {@code 1 / alpha} directly, so it expects the root to be taken already.
     * Later calls only convert precision.
     */
    private void prepareCoefficients(MojoColumn.Type target) {
        if (!this.nSqrtApplied) {
            for (int row = 0; row < this.nColumns.length; row++) {
                this.nColumns[row] = sqrtColumn(this.nColumns[row], target);
            }
            this.nSqrtApplied = true;
        } else if (this.coefficientType != target) {
            for (int row = 0; row < this.nColumns.length; row++) {
                this.nColumns[row] = convertColumn(this.nColumns[row], target);
            }
        }
        if (this.coefficientType != target) {
            for (int row = 0; row < this.zColumns.length; row++) {
                this.zColumns[row] = convertColumn(this.zColumns[row], target);
            }
        }
        this.coefficientType = target;
    }

    /** Returns a new row holding the element-wise square root of {@code column}, in {@code target} precision. */
    private static Object sqrtColumn(Object column, MojoColumn.Type target) {
        if (column instanceof double[]) {
            double[] src = (double[]) column;
            if (target == MojoColumn.Type.Float32) {
                float[] out = new float[src.length];
                for (int r = 0; r < src.length; r++) {
                    // The value is narrowed to float before taking the root.
                    out[r] = (float) Math.sqrt((float) src[r]);
                }
                return out;
            }
            double[] out = new double[src.length];
            for (int r = 0; r < src.length; r++) {
                out[r] = Math.sqrt(src[r]);
            }
            return out;
        }
        float[] src = (float[]) column;
        if (target == MojoColumn.Type.Float32) {
            float[] out = new float[src.length];
            for (int r = 0; r < src.length; r++) {
                out[r] = (float) Math.sqrt(src[r]);
            }
            return out;
        }
        double[] out = new double[src.length];
        for (int r = 0; r < src.length; r++) {
            out[r] = Math.sqrt(src[r]);
        }
        return out;
    }

    /** Returns {@code column} converted to {@code target} precision, or {@code column} itself if it already matches. */
    private static Object convertColumn(Object column, MojoColumn.Type target) {
        if (target == MojoColumn.Type.Float32 && column instanceof double[]) {
            double[] src = (double[]) column;
            float[] out = new float[src.length];
            for (int r = 0; r < src.length; r++) {
                out[r] = (float) src[r];
            }
            return out;
        }
        if (target == MojoColumn.Type.Float64 && column instanceof float[]) {
            float[] src = (float[]) column;
            double[] out = new double[src.length];
            for (int r = 0; r < src.length; r++) {
                out[r] = src[r];
            }
            return out;
        }
        return column;
    }
}
