package ai.h2o.mojos.runtime.transforms;

import ai.h2o.mojos.runtime.frame.MojoFrame;
import ai.h2o.mojos.runtime.frame.MojoFrameMeta;
import ai.h2o.mojos.runtime.transforms.util.MojoTransformBuilderUtils;

/* loaded from: input_file:ai/h2o/mojos/runtime/transforms/MojoTransformKMeansBuilder.class */
/**
 * Transform that scores rows against a fixed set of k-means centroids.
 *
 * <p>Each row is formed from the {@code n} input columns ({@code iindices}) and compared with each
 * of the {@code k} centroids by Euclidean distance. Depending on {@link OutputType}, the transform
 * writes one of two outputs:
 * <ul>
 *   <li>{@link OutputType#LABELS}: the index of the nearest centroid, into a single Int32 column.</li>
 *   <li>{@link OutputType#DISTANCES}: the distance to every centroid, into {@code k} columns. Output
 *       column {@code j} receives the distance to centroid {@code j}.</li>
 * </ul>
 *
 * <p>An instance holds its centroids either in single precision ({@code float[][]}, Float32
 * columns) or in double precision ({@code double[][]}, Float64 columns), never both.
 */
public class MojoTransformKMeansBuilder extends MojoTransform {

    /*
     * Column type codes handed to MojoTransformBuilderUtils' type assertions. The values are taken
     * from the original checks, whose error messages name the corresponding type.
     * NOTE(review): these presumably mirror the project's column-type constants; consider
     * referencing those directly.
     */
    private static final int TYPE_INT32 = 16;
    private static final int TYPE_FLOAT32 = 64;
    private static final int TYPE_FLOAT64 = 128;

    /** Single-precision centroids, one row per centroid; {@code null} for a Float64 transform. */
    private final float[][] centroids32;
    /** Double-precision centroids, one row per centroid; {@code null} for a Float32 transform. */
    private final double[][] centroids64;
    /** Whether to emit nearest-centroid labels or per-centroid distances. */
    private final OutputType outputType;
    /** Number of features: the width of each centroid, equal to the number of input columns. */
    private final int n;
    /** Number of centroids. */
    private final int k;

    /** Kind of output produced by the transform. */
    public enum OutputType {
        /** Index of the nearest centroid, written to a single Int32 column. */
        LABELS,
        /** Euclidean distance to every centroid, one output column per centroid. */
        DISTANCES
    }

    /**
     * Shared constructor: stores the centroids and validates their shapes (via {@code assert}).
     * Exactly one of {@code centroids32} / {@code centroids64} must be non-null. Column types are
     * checked by the public constructors.
     *
     * @param mojoFrameMeta frame metadata (unused here)
     * @param inputIndices indices of the {@code n} feature columns
     * @param outputIndices output column indices: one for {@link OutputType#LABELS}, {@code k} for
     *     {@link OutputType#DISTANCES}
     * @param centroids32 single-precision centroids, or {@code null}
     * @param centroids64 double-precision centroids, or {@code null}
     * @param outputType kind of output to produce; must not be {@code null}
     */
    private MojoTransformKMeansBuilder(MojoFrameMeta mojoFrameMeta, int[] inputIndices, int[] outputIndices,
            float[][] centroids32, double[][] centroids64, OutputType outputType) {
        super(inputIndices, outputIndices);
        assert outputType != null;
        assert (centroids32 == null) != (centroids64 == null)
                : "exactly one of centroids32 / centroids64 must be provided";
        this.centroids32 = centroids32;
        this.centroids64 = centroids64;
        this.outputType = outputType;
        this.k = centroids32 != null ? centroids32.length : centroids64.length;
        assert this.k > 0 : "at least one centroid is required";
        this.n = centroids32 != null ? centroids32[0].length : centroids64[0].length;
        assert this.n == this.iindices.length;
        assert allCentroidsHaveWidth(centroids32, centroids64, inputIndices.length);
        if (outputType == OutputType.LABELS) {
            assert outputIndices.length == 1;
        } else {
            // kmeansDistances* writes one column per centroid, so there must be exactly k outputs.
            assert outputIndices.length == this.k : "DISTANCES needs one output column per centroid";
        }
    }

    /**
     * Creates a Float32 k-means transform.
     *
     * @param mojoFrameMeta metadata used to check the column types: Float32 inputs, a Float32
     *     output for DISTANCES, and an Int32 output for LABELS
     * @param inputIndices indices of the Float32 feature columns; length must equal the centroid width
     * @param outputIndices output column indices (one for LABELS, one per centroid for DISTANCES)
     * @param centroids non-empty array of centroids, each of width {@code inputIndices.length}
     * @param outputType kind of output to produce
     */
    public MojoTransformKMeansBuilder(MojoFrameMeta mojoFrameMeta, int[] inputIndices, int[] outputIndices,
            float[][] centroids, OutputType outputType) {
        this(mojoFrameMeta, inputIndices, outputIndices, centroids, null, outputType);
        checkColumnTypes(mojoFrameMeta, TYPE_FLOAT32,
                "Input columns must be of Float32 type", "Output column must be of Float32 type");
    }

    /**
     * Creates a Float64 k-means transform.
     *
     * @param mojoFrameMeta metadata used to check the column types: Float64 inputs, a Float64
     *     output for DISTANCES, and an Int32 output for LABELS
     * @param inputIndices indices of the Float64 feature columns; length must equal the centroid width
     * @param outputIndices output column indices (one for LABELS, one per centroid for DISTANCES)
     * @param centroids non-empty array of centroids, each of width {@code inputIndices.length}
     * @param outputType kind of output to produce
     */
    public MojoTransformKMeansBuilder(MojoFrameMeta mojoFrameMeta, int[] inputIndices, int[] outputIndices,
            double[][] centroids, OutputType outputType) {
        this(mojoFrameMeta, inputIndices, outputIndices, null, centroids, outputType);
        checkColumnTypes(mojoFrameMeta, TYPE_FLOAT64,
                "Input columns must be of Float64 type", "Output column must be of Float64 type");
    }

    /**
     * Checks the column types of inputs and outputs against the metadata.
     *
     * @param meta frame metadata
     * @param featureType type code required for input columns and for DISTANCES output columns
     * @param inputMessage error message if an input column has the wrong type
     * @param distanceMessage error message if a DISTANCES output column has the wrong type
     */
    private void checkColumnTypes(MojoFrameMeta meta, int featureType, String inputMessage, String distanceMessage) {
        MojoTransformBuilderUtils.assertTypes(meta, this.iindices, featureType, inputMessage);
        switch (this.outputType) {
            case LABELS:
                MojoTransformBuilderUtils.assertType(meta.getColumnType(this.oindices[0]), TYPE_INT32,
                        "Output column must be of Int32 type");
                break;
            case DISTANCES:
                MojoTransformBuilderUtils.assertTypes(meta, this.oindices, featureType, distanceMessage);
                break;
            default:
                break;
        }
    }

    /** Returns {@code true} when every centroid row (of whichever array is non-null) has {@code width} entries. */
    private static boolean allCentroidsHaveWidth(float[][] centroids32, double[][] centroids64, int width) {
        if (centroids32 != null) {
            for (float[] centroid : centroids32) {
                if (centroid.length != width) {
                    return false;
                }
            }
        }
        if (centroids64 != null) {
            for (double[] centroid : centroids64) {
                if (centroid.length != width) {
                    return false;
                }
            }
        }
        return true;
    }

    /**
     * Writes labels or distances for every row of {@code mojoFrame}, using the precision
     * the centroids were supplied in.
     *
     * @throws UnsupportedOperationException for an unknown {@link OutputType}
     */
    @Override // ai.h2o.mojos.runtime.transforms.MojoTransform
    public void transform(MojoFrame mojoFrame) {
        final boolean singlePrecision = this.centroids32 != null;
        switch (this.outputType) {
            case LABELS:
                if (singlePrecision) {
                    kmeansLabelsFloat32(mojoFrame);
                } else {
                    kmeansLabelsFloat64(mojoFrame);
                }
                return;
            case DISTANCES:
                if (singlePrecision) {
                    kmeansDistancesFloat32(mojoFrame);
                } else {
                    kmeansDistancesFloat64(mojoFrame);
                }
                return;
            default:
                throw new UnsupportedOperationException(this.outputType.toString());
        }
    }

    /**
     * Writes the index of the nearest Float32 centroid for each row into the Int32 output column.
     *
     * <p>The squared distance is evaluated in the expanded form ||x||^2 - 2 x.c + ||c||^2.
     * Cancellation in that form can produce small negative values, so results are clamped to 0.
     * A row equal coordinate-by-coordinate to a centroid is assigned that centroid immediately.
     * Presumably this is so that rounding in the expanded form cannot let another centroid win.
     * Ties go to the lowest index (strict {@code <}). A row for which no centroid yields a distance
     * below +infinity (e.g. NaN features) gets {@link Integer#MIN_VALUE}.
     */
    private void kmeansLabelsFloat32(MojoFrame mojoFrame) {
        final float[][] features = inputColumns32(mojoFrame);
        final float[] centroidNorms = squaredNorms(this.centroids32);
        final int[] labels = (int[]) mojoFrame.getColumnData(this.oindices[0]);
        final int nrows = mojoFrame.getNrows();
        for (int row = 0; row < nrows; row++) {
            final float rowNorm = rowSquaredNorm(features, row);
            float bestDistance = Float.POSITIVE_INFINITY;
            int label = Integer.MIN_VALUE;
            for (int c = 0; c < this.k; c++) {
                final float[] centroid = this.centroids32[c];
                float dot = 0.0f;
                boolean exact = true;
                for (int j = 0; j < this.n; j++) {
                    final float x = features[j][row];
                    final float m = centroid[j];
                    exact &= x == m;
                    dot += x * m;
                }
                if (exact) {
                    label = c;
                    break;
                }
                float distance = (dot * -2.0f) + rowNorm + centroidNorms[c];
                if (distance < 0.0f) {
                    distance = 0.0f;
                }
                if (distance < bestDistance) {
                    bestDistance = distance;
                    label = c;
                }
            }
            labels[row] = label;
        }
    }

    /**
     * Writes the index of the nearest Float64 centroid for each row into the Int32 output column.
     * Same algorithm and tie/NaN behavior as {@link #kmeansLabelsFloat32(MojoFrame)}, in double
     * precision.
     */
    private void kmeansLabelsFloat64(MojoFrame mojoFrame) {
        final double[][] features = inputColumns64(mojoFrame);
        final double[] centroidNorms = squaredNorms(this.centroids64);
        final int[] labels = (int[]) mojoFrame.getColumnData(this.oindices[0]);
        final int nrows = mojoFrame.getNrows();
        for (int row = 0; row < nrows; row++) {
            final double rowNorm = rowSquaredNorm(features, row);
            double bestDistance = Double.POSITIVE_INFINITY;
            int label = Integer.MIN_VALUE;
            for (int c = 0; c < this.k; c++) {
                final double[] centroid = this.centroids64[c];
                double dot = 0.0d;
                boolean exact = true;
                for (int j = 0; j < this.n; j++) {
                    final double x = features[j][row];
                    final double m = centroid[j];
                    exact &= x == m;
                    dot += x * m;
                }
                if (exact) {
                    label = c;
                    break;
                }
                double distance = (dot * -2.0d) + rowNorm + centroidNorms[c];
                if (distance < 0.0d) {
                    distance = 0.0d;
                }
                if (distance < bestDistance) {
                    bestDistance = distance;
                    label = c;
                }
            }
            labels[row] = label;
        }
    }

    /**
     * Writes, for each row, the Euclidean distance to every Float32 centroid. Output column
     * {@code c} receives the distance to centroid {@code c}. Negative squared distances caused
     * by cancellation are clamped to 0 before the square root.
     */
    private void kmeansDistancesFloat32(MojoFrame mojoFrame) {
        final float[][] features = inputColumns32(mojoFrame);
        final float[][] distances = new float[this.k][];
        for (int c = 0; c < this.k; c++) {
            distances[c] = (float[]) mojoFrame.getColumnData(this.oindices[c]);
        }
        final float[] centroidNorms = squaredNorms(this.centroids32);
        final int nrows = mojoFrame.getNrows();
        for (int row = 0; row < nrows; row++) {
            final float rowNorm = rowSquaredNorm(features, row);
            for (int c = 0; c < this.k; c++) {
                final float[] centroid = this.centroids32[c];
                float dot = 0.0f;
                for (int j = 0; j < this.n; j++) {
                    dot += features[j][row] * centroid[j];
                }
                float squared = (rowNorm - (2.0f * dot)) + centroidNorms[c];
                if (squared < 0.0f) {
                    squared = 0.0f;
                }
                distances[c][row] = (float) Math.sqrt(squared);
            }
        }
    }

    /**
     * Writes, for each row, the Euclidean distance to every Float64 centroid. Output column
     * {@code c} receives the distance to centroid {@code c}. Negative squared distances caused
     * by cancellation are clamped to 0 before the square root.
     */
    private void kmeansDistancesFloat64(MojoFrame mojoFrame) {
        final double[][] features = inputColumns64(mojoFrame);
        final double[][] distances = new double[this.k][];
        for (int c = 0; c < this.k; c++) {
            distances[c] = (double[]) mojoFrame.getColumnData(this.oindices[c]);
        }
        final double[] centroidNorms = squaredNorms(this.centroids64);
        final int nrows = mojoFrame.getNrows();
        for (int row = 0; row < nrows; row++) {
            final double rowNorm = rowSquaredNorm(features, row);
            for (int c = 0; c < this.k; c++) {
                final double[] centroid = this.centroids64[c];
                double dot = 0.0d;
                for (int j = 0; j < this.n; j++) {
                    dot += features[j][row] * centroid[j];
                }
                double squared = ((-2.0d) * dot) + rowNorm + centroidNorms[c];
                if (squared < 0.0d) {
                    squared = 0.0d;
                }
                distances[c][row] = Math.sqrt(squared);
            }
        }
    }

    /** Fetches the {@code n} Float32 input columns, indexed {@code [feature][row]}. */
    private float[][] inputColumns32(MojoFrame mojoFrame) {
        final float[][] columns = new float[this.n][];
        for (int j = 0; j < this.n; j++) {
            columns[j] = (float[]) mojoFrame.getColumnData(this.iindices[j]);
        }
        return columns;
    }

    /** Fetches the {@code n} Float64 input columns, indexed {@code [feature][row]}. */
    private double[][] inputColumns64(MojoFrame mojoFrame) {
        final double[][] columns = new double[this.n][];
        for (int j = 0; j < this.n; j++) {
            columns[j] = (double[]) mojoFrame.getColumnData(this.iindices[j]);
        }
        return columns;
    }

    /** Returns the squared Euclidean norm of each vector, summed in index order. */
    private static float[] squaredNorms(float[][] vectors) {
        final float[] norms = new float[vectors.length];
        for (int c = 0; c < vectors.length; c++) {
            float sum = 0.0f;
            for (float v : vectors[c]) {
                sum += v * v;
            }
            norms[c] = sum;
        }
        return norms;
    }

    /** Returns the squared Euclidean norm of each vector, summed in index order. */
    private static double[] squaredNorms(double[][] vectors) {
        final double[] norms = new double[vectors.length];
        for (int c = 0; c < vectors.length; c++) {
            double sum = 0.0d;
            for (double v : vectors[c]) {
                sum += v * v;
            }
            norms[c] = sum;
        }
        return norms;
    }

    /** Returns the squared Euclidean norm of row {@code row} across the given columns. */
    private static float rowSquaredNorm(float[][] columns, int row) {
        float sum = 0.0f;
        for (float[] column : columns) {
            final float x = column[row];
            sum += x * x;
        }
        return sum;
    }

    /** Returns the squared Euclidean norm of row {@code row} across the given columns. */
    private static double rowSquaredNorm(double[][] columns, int row) {
        double sum = 0.0d;
        for (double[] column : columns) {
            final double x = column[row];
            sum += x * x;
        }
        return sum;
    }
}
