package hex.gam.MatrixFrameUtils;

import hex.Model;
import hex.gam.GAM;
import hex.gam.GAMModel;
import hex.gam.GamSplines.ThinPlateRegressionUtils;
import hex.glm.GLMModel;
import hex.quantile.Quantile;
import hex.quantile.QuantileModel;
import java.lang.reflect.Field;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Comparator;
import java.util.List;
import water.DKV;
import water.Key;
import water.MemoryManager;
import water.Scope;
import water.fvec.Frame;
import water.fvec.Vec;
import water.util.ArrayUtils;
import water.util.TwoDimTable;

/* loaded from: input_file:hex/gam/MatrixFrameUtils/GamUtils.class */
public class GamUtils {
    // Decompiler-exposed form of the compiler-generated flag behind `assert` statements.
    // It is set once in the static initializer to the negation of
    // GamUtils.class.desiredAssertionStatus(), i.e. true when assertions are disabled.
    static final /* synthetic */ boolean $assertionsDisabled;

    /* loaded from: input_file:hex/gam/MatrixFrameUtils/GamUtils$AllocateType.class */
    /**
     * Selects the two dimensions that {@code allocate2DArray(type, n)} passes to
     * {@code MemoryManager.malloc8d(first, second)} for a given size {@code n}.
     */
    public enum AllocateType {
        // first = n - 1, second = n
        firstOneLess,
        // first = n, second = n
        sameOrig,
        // first = n - 1, second = n - 1
        bothOneLess,
        // first = n - 2, second = n
        firstTwoLess
    }

    /* JADX WARN: Type inference failed for: r0v1, types: [double[][], double[][][]] */
    public static double[][][] allocate3DArrayCS(int i, GAMModel.GAMParameters gAMParameters, AllocateType allocateType) {
        ?? r0 = new double[i];
        int i2 = 0;
        for (int i3 = 0; i3 < i; i3++) {
            if (gAMParameters._gam_columns_sorted[i3].length == 1) {
                int i4 = i2;
                i2++;
                r0[i4] = allocate2DArray(allocateType, gAMParameters._num_knots_sorted[i3]);
            }
        }
        return r0;
    }

    /* JADX WARN: Type inference failed for: r0v1, types: [double[][], double[][][]] */
    public static double[][][] allocate3DArray(int i, GAMModel.GAMParameters gAMParameters, AllocateType allocateType) {
        ?? r0 = new double[i];
        for (int i2 = 0; i2 < i; i2++) {
            if (gAMParameters._bs_sorted[i2] != 2) {
                r0[i2] = allocate2DArray(allocateType, gAMParameters._num_knots_sorted[i2]);
            } else {
                r0[i2] = allocate2DArray(allocateType, (gAMParameters._num_knots_sorted[i2] + gAMParameters._spline_orders_sorted[i2]) - 2);
            }
        }
        return r0;
    }

    /**
     * Replaces the entry of {@code dArr} for every group whose {@code _bs_sorted} code is 2 with a
     * freshly allocated square array of size {@code _num_knots_sorted + _spline_orders_sorted - 2}.
     * Entries for all other groups are left untouched.
     *
     * @param dArr           per-group 2D arrays, indexed like the sorted parameter arrays
     * @param gAMParameters  parameters supplying the sorted spline type, knot count and order arrays
     */
    public static void removeCenteringIS(double[][][] dArr, GAMModel.GAMParameters gAMParameters) {
        final int groupCount = gAMParameters._bs_sorted.length;
        for (int col = 0; col < groupCount; col++) {
            if (gAMParameters._bs_sorted[col] != 2) {
                continue;
            }
            int size = (gAMParameters._num_knots_sorted[col] + gAMParameters._spline_orders_sorted[col]) - 2;
            dArr[col] = allocate2DArray(AllocateType.sameOrig, size);
        }
    }

    /* JADX WARN: Type inference failed for: r0v1, types: [double[][], double[][][]] */
    public static double[][][] allocate3DArrayTP(int i, GAMModel.GAMParameters gAMParameters, int[] iArr, int[] iArr2) {
        ?? r0 = new double[i];
        int i2 = 0;
        int length = gAMParameters._gam_columns.length;
        for (int i3 = 0; i3 < length; i3++) {
            if (gAMParameters._bs_sorted[i3] == 1) {
                r0[i2] = MemoryManager.malloc8d(iArr[i2], iArr2[i2]);
                i2++;
            }
        }
        return r0;
    }

    /**
     * Allocates a 2D array whose two dimensions are derived from {@code i} according to
     * {@code allocateType}.
     *
     * @param allocateType  shape rule: (i-1, i), (i, i), (i-1, i-1) or (i-2, i)
     * @param i             base size
     * @return the array returned by {@code MemoryManager.malloc8d} for the chosen dimensions
     * @throws IllegalArgumentException if {@code allocateType} is not one of the four known values
     */
    public static double[][] allocate2DArray(AllocateType allocateType, int i) {
        final int first;
        final int second;
        switch (allocateType) {
            case firstOneLess:
                first = i - 1;
                second = i;
                break;
            case sameOrig:
                first = i;
                second = i;
                break;
            case bothOneLess:
                first = i - 1;
                second = i - 1;
                break;
            case firstTwoLess:
                first = i - 2;
                second = i;
                break;
            default:
                throw new IllegalArgumentException("fileMode can only be firstOneLess, sameOrig, bothOneLess or firstTwoLess.");
        }
        return MemoryManager.malloc8d(first, second);
    }

    /**
     * Returns the indices {@code 0..i-1} ordered by descending value of {@code dArr}.
     *
     * <p>Equal values keep their original index order, because {@link Arrays#sort(Object[], Comparator)}
     * is stable. NaN entries compare equal to each other and are ranked after every real value.
     * This keeps the comparator a total order; the previous version returned 0 for any comparison
     * involving NaN, which violates the {@link Comparator} contract and can leave the result
     * arbitrarily ordered or make the sort throw.
     *
     * @param i     number of leading entries of {@code dArr} to rank
     * @param dArr  values to rank; must hold at least {@code i} entries
     * @return a new array of {@code i} indices, largest value first, NaN last
     */
    public static Integer[] sortCoeffMags(int i, final double[] dArr) {
        Integer[] order = new Integer[i];
        for (int idx = 0; idx < order.length; idx++) {
            order[idx] = Integer.valueOf(idx);
        }
        Arrays.sort(order, new Comparator<Integer>() {
            @Override
            public int compare(Integer left, Integer right) {
                double a = dArr[left.intValue()];
                double b = dArr[right.intValue()];
                if (a < b) {
                    return 1;
                }
                if (a > b) {
                    return -1;
                }
                // Reaching here means the values are equal or at least one of them is NaN.
                boolean aIsNaN = Double.isNaN(a);
                boolean bIsNaN = Double.isNaN(b);
                if (aIsNaN == bIsNaN) {
                    return 0;
                }
                // Exactly one side is NaN: push it behind the real value.
                return aIsNaN ? 1 : -1;
            }
        });
        return order;
    }

    /**
     * Checks whether two column-name arrays describe the same set of columns, ignoring the
     * response column {@code str}.
     *
     * <p>The arrays must have equal length, except that the array containing the response may be
     * one longer when the other does not contain it. Every non-response name in {@code strArr}
     * must then also occur in {@code strArr2}.
     *
     * <p>The response is matched by value. The previous reference comparison ({@code !=}) treated
     * an equal but non-identical response string as an ordinary column and returned {@code false}.
     *
     * @param strArr   column names to check
     * @param strArr2  reference column names
     * @param str      name of the response column, excluded from the name-by-name comparison
     * @return {@code true} if the lengths are compatible and every non-response name of
     *         {@code strArr} is present in {@code strArr2}
     */
    public static boolean equalColNames(String[] strArr, String[] strArr2, String str) {
        List<String> reference = Arrays.asList(strArr2);
        boolean firstHasResponse = Arrays.asList(strArr).contains(str);
        boolean secondHasResponse = reference.contains(str);

        // Only the side holding the response is allowed to be one column longer.
        int expectedLength = strArr2.length;
        if (firstHasResponse && !secondHasResponse) {
            expectedLength = strArr2.length + 1;
        } else if (!firstHasResponse && secondHasResponse) {
            expectedLength = strArr2.length - 1;
        }
        if (strArr.length != expectedLength) {
            return false;
        }

        for (String name : strArr) {
            boolean isResponse = name == null ? str == null : name.equals(str);
            if (!isResponse && !reference.contains(name)) {
                return false;
            }
        }
        return true;
    }

    /**
     * Copies every row of {@code dArr} into the row with the same index in {@code dArr2}.
     * The destination rows must already exist and be at least as long as the source rows.
     *
     * @param dArr   source rows
     * @param dArr2  pre-allocated destination rows
     */
    public static void copy2DArray(double[][] dArr, double[][] dArr2) {
        int row = 0;
        for (double[] source : dArr) {
            System.arraycopy(source, 0, dArr2[row], 0, source.length);
            row++;
        }
    }

    /**
     * Copies every row of {@code iArr} into the row with the same index in {@code iArr2}.
     * The destination rows must already exist and be at least as long as the source rows.
     *
     * @param iArr   source rows
     * @param iArr2  pre-allocated destination rows
     */
    public static void copy2DArray(int[][] iArr, int[][] iArr2) {
        int row = 0;
        for (int[] source : iArr) {
            System.arraycopy(source, 0, iArr2[row], 0, source.length);
            row++;
        }
    }

    /**
     * Transfers the cross-validation results of a GLM model onto a GAM model's output.
     *
     * <p>The CV metrics reference is shared, the metrics summary and per-fold scoring-history
     * tables are copied, and, depending on the {@code _keep_cross_validation_*} flags, the
     * prediction frames, holdout prediction frame and fold-assignment frame are deep-copied under
     * fresh keys and stored in the DKV. CV models are rebuilt through {@code buildCVGamModels}.
     *
     * @param gAMModel       destination model whose output is updated
     * @param gLMModel       source model carrying the cross-validation results
     * @param gAMParameters  supplies the {@code _keep_cross_validation_*} flags
     * @param str            forwarded unchanged to {@code buildCVGamModels}
     */
    public static void copyCVGLMtoGAMModel(GAMModel gAMModel, GLMModel gLMModel, GAMModel.GAMParameters gAMParameters, String str) {
        GAMModel.GAMModelOutput gamOut = (GAMModel.GAMModelOutput) gAMModel._output;
        GLMModel.GLMOutput glmOut = (GLMModel.GLMOutput) gLMModel._output;

        gamOut._cross_validation_metrics = glmOut._cross_validation_metrics;
        gamOut._cross_validation_metrics_summary = GAMModelUtils.copyTwoDimTable(glmOut._cross_validation_metrics_summary, "GLM cross-validation metrics summary");

        int foldCount = glmOut._cv_scoring_history.length;
        gamOut._glm_cv_scoring_history = new TwoDimTable[foldCount];
        if (gAMParameters._keep_cross_validation_predictions) {
            gamOut._cross_validation_predictions = new Key[foldCount];
        }

        for (int fold = 0; fold < foldCount; fold++) {
            gamOut._glm_cv_scoring_history[fold] = GAMModelUtils.copyTwoDimTable(glmOut._cv_scoring_history[fold], glmOut._cv_scoring_history[fold].getTableHeader());
            if (!gAMParameters._keep_cross_validation_predictions) {
                continue;
            }
            Frame foldPredictions = DKV.getGet(glmOut._cross_validation_predictions[fold]).deepCopy(Key.make().toString());
            DKV.put(foldPredictions);
            gamOut._cross_validation_predictions[fold] = foldPredictions.getKey();
        }

        if (gAMParameters._keep_cross_validation_models) {
            gamOut._cross_validation_models = buildCVGamModels(gAMModel, gLMModel, gAMParameters, str);
        }

        if (gAMParameters._keep_cross_validation_predictions) {
            Frame holdoutPredictions = DKV.getGet(glmOut._cross_validation_holdout_predictions_frame_id).deepCopy(Key.make().toString());
            DKV.put(holdoutPredictions);
            gamOut._cross_validation_holdout_predictions_frame_id = holdoutPredictions.getKey();
        }

        if (gAMParameters._keep_cross_validation_fold_assignment) {
            Frame foldAssignment = DKV.getGet(glmOut._cross_validation_fold_assignment_frame_id).deepCopy(Key.make().toString());
            DKV.put(foldAssignment);
            gamOut._cross_validation_fold_assignment_frame_id = foldAssignment.getKey();
        }
    }

    /**
     * Builds one GAM model per cross-validation GLM model and returns the keys of the new models.
     *
     * <p>For every key in the GLM output's {@code _cross_validation_models}, a fresh copy of the
     * GAM parameters is made, a GAM model is trained with it, and the GLM coefficients and output
     * are then copied onto that GAM model before it is stored in the DKV.
     *
     * @param gAMModel       main GAM model; only its {@code _nclass} is read here
     * @param gLMModel       GLM model whose output lists the cross-validation model keys
     * @param gAMParameters  parameters copied for each fold via {@code makeGAMParameters}
     * @param str            optional column name appended to the copied parameters'
     *                       {@code _ignored_columns}; skipped when {@code null}
     *                       (presumably the fold column - TODO confirm against callers)
     * @return keys of the GAM models, in the same order as the GLM cross-validation models
     */
    public static Key[] buildCVGamModels(GAMModel gAMModel, GLMModel gLMModel, GAMModel.GAMParameters gAMParameters, String str) {
        int length = ((GLMModel.GLMOutput) gLMModel._output)._cross_validation_models.length;
        Key[] keyArr = new Key[length];
        for (int i = 0; i < length; i++) {
            GLMModel get = DKV.getGet(((GLMModel.GLMOutput) gLMModel._output)._cross_validation_models[i]);
            GAMModel.GAMParameters makeGAMParameters = makeGAMParameters(gAMParameters);
            if (str != null) {
                // Add str to the ignored columns of the per-fold parameter copy only.
                if (makeGAMParameters._ignored_columns != null) {
                    ArrayList arrayList = new ArrayList(Arrays.asList(makeGAMParameters._ignored_columns));
                    arrayList.add(str);
                    makeGAMParameters._ignored_columns = (String[]) arrayList.toArray(new String[0]);
                } else {
                    makeGAMParameters._ignored_columns = new String[]{str};
                }
            }
            // Train with _max_iterations forced to 1, then restore the configured value.
            // NOTE(review): presumably the short run only builds the model structure, since the
            // coefficients are overwritten from the GLM model just below - confirm.
            int i2 = makeGAMParameters._max_iterations;
            makeGAMParameters._max_iterations = 1;
            GAMModel gAMModel2 = new GAM(makeGAMParameters).trainModel().get();
            makeGAMParameters._max_iterations = i2;
            GAMModelUtils.copyGLMCoeffs(get, gAMModel2, makeGAMParameters, gAMModel._nclass);
            GAMModelUtils.copyGLMtoGAMModel(gAMModel2, get, gAMParameters, true);
            keyArr[i] = gAMModel2.getKey();
            // Store the updated model in the DKV under its key.
            DKV.put(gAMModel2);
        }
        return keyArr;
    }

    /**
     * Creates a new {@code GAMParameters} populated from {@code gAMParameters} with
     * cross-validation switched off.
     *
     * <p>Fields declared on {@code GAMModel.GAMParameters} and on {@code Model.Parameters} are
     * copied through {@code setParamField}. Afterwards {@code _nfolds} is set to 0, the three
     * {@code _keep_cross_validation_*} flags are cleared and {@code _train} is copied explicitly.
     *
     * @param gAMParameters  parameters to copy from
     * @return the new parameter object
     */
    public static GAMModel.GAMParameters makeGAMParameters(GAMModel.GAMParameters gAMParameters) {
        GAMModel.GAMParameters copy = new GAMModel.GAMParameters();
        Field[] gamFields = GAMModel.GAMParameters.class.getDeclaredFields();
        Field[] baseFields = Model.Parameters.class.getDeclaredFields();

        // GAM-specific fields first, then the fields declared on the base parameter class.
        setParamField(gAMParameters, copy, false, gamFields, Collections.emptyList());
        setParamField(gAMParameters, copy, true, baseFields, Collections.emptyList());

        // Overrides applied after the reflective copy.
        copy._nfolds = 0;
        copy._keep_cross_validation_predictions = false;
        copy._keep_cross_validation_fold_assignment = false;
        copy._keep_cross_validation_models = false;
        copy._train = gAMParameters._train;
        return copy;
    }

    /**
     * Copies field values from {@code source} to {@code target} by reflection, matching fields by name.
     *
     * <p>Reconstructed from the bytecode dump that the decompiler left in place of the method body;
     * the decompiled stub threw {@link UnsupportedOperationException} unconditionally.
     *
     * <p>For every entry of {@code fields} whose name is not in {@code excludeList}, the field of
     * the same name is looked up on the class of {@code target} (or on its direct superclass when
     * {@code superClassParams} is {@code true}) and set to the value read from {@code source}.
     * The copy is best-effort: a field that is missing, inaccessible or rejects the value is
     * skipped and the loop continues with the next one.
     *
     * @param source            object the values are read from
     * @param target            object the values are written to
     * @param superClassParams  {@code true} to resolve target fields on {@code target}'s direct
     *                          superclass, {@code false} to resolve them on its own class
     * @param fields            fields to copy, as declared on the source side
     * @param excludeList       names of fields to skip; {@code null} or empty means skip none
     */
    public static void setParamField(Model.Parameters source, Model.Parameters target, boolean superClassParams, Field[] fields, List<String> excludeList) {
        boolean nothingExcluded = excludeList == null || excludeList.size() == 0;
        for (Field sourceField : fields) {
            try {
                if (!nothingExcluded && excludeList.contains(sourceField.getName())) {
                    continue;
                }
                Class<?> owner = superClassParams ? target.getClass().getSuperclass() : target.getClass();
                Field targetField = owner.getDeclaredField(sourceField.getName());
                targetField.set(target, sourceField.get(source));
            } catch (ReflectiveOperationException | RuntimeException ignored) {
                // Deliberately ignored: the bytecode installs a catch-all handler around each field
                // so that one uncopyable field does not abort the rest. Errors are left to propagate.
            }
        }
    }

    /**
     * Appends the key of every Vec of the given frames to {@code list}.
     * Frame keys that resolve to {@code null} in the DKV are skipped.
     *
     * @param list    destination list of keys
     * @param keyArr  keys of the frames whose Vec keys are collected
     */
    public static void keepFrameKeys(List<Key> list, Key<Frame>... keyArr) {
        for (int k = 0; k < keyArr.length; k++) {
            Frame resolved = DKV.getGet(keyArr[k]);
            if (resolved == null) {
                continue;
            }
            Vec[] columns = resolved.vecs();
            for (int v = 0; v < columns.length; v++) {
                list.add(columns[v]._key);
            }
        }
    }

    /**
     * Initialises {@code _bs} with a default code per GAM column group: 1 for groups holding more
     * than one column, 0 otherwise.
     *
     * @param gAMParameters  parameters whose {@code _bs} array is replaced
     */
    public static void setDefaultBSType(GAMModel.GAMParameters gAMParameters) {
        gAMParameters._bs = new int[gAMParameters._gam_columns.length];
        for (int group = 0; group < gAMParameters._bs.length; group++) {
            boolean multiColumn = gAMParameters._gam_columns[group].length > 1;
            gAMParameters._bs[group] = multiColumn ? 1 : 0;
        }
    }

    /**
     * Allocates {@code _m} and {@code _M} with {@code i} entries and fills them, in encounter
     * order, for every GAM column group whose {@code _bs} code is 1. The values come from
     * {@code ThinPlateRegressionUtils.calculatem} and {@code calculateM}, applied to the number of
     * columns in the group.
     *
     * @param gAMParameters  parameters that are read and updated
     * @param i              length of the {@code _m} and {@code _M} arrays
     */
    public static void setThinPlateParameters(GAMModel.GAMParameters gAMParameters, int i) {
        final int groupCount = gAMParameters._gam_columns.length;
        gAMParameters._m = MemoryManager.malloc4(i);
        gAMParameters._M = MemoryManager.malloc4(i);
        int slot = 0;
        for (int group = 0; group < groupCount; group++) {
            if (gAMParameters._bs[group] != 1) {
                continue;
            }
            int columnCount = gAMParameters._gam_columns[group].length;
            gAMParameters._m[slot] = ThinPlateRegressionUtils.calculatem(columnCount);
            gAMParameters._M[slot] = ThinPlateRegressionUtils.calculateM(columnCount, gAMParameters._m[slot]);
            slot++;
        }
    }

    /**
     * Rebuilds {@code _gamPredSize} so that single-column groups occupy the slots starting at 0
     * (each with value 1) and multi-column groups occupy the slots starting at {@code i} (each
     * with the number of columns in the group).
     *
     * @param gAMParameters  parameters that are read and updated
     * @param i              first slot used for groups with more than one column
     */
    public static void setGamPredSize(GAMModel.GAMParameters gAMParameters, int i) {
        final int groupCount = gAMParameters._gam_columns.length;
        int multiSlot = i;
        int singleSlot = 0;
        gAMParameters._gamPredSize = MemoryManager.malloc4(groupCount);
        for (int group = 0; group < groupCount; group++) {
            if (gAMParameters._gam_columns[group].length == 1) {
                gAMParameters._gamPredSize[singleSlot++] = 1;
            } else {
                gAMParameters._gamPredSize[multiSlot++] = gAMParameters._gam_columns[group].length;
            }
        }
    }

    /**
     * Generates {@code i} knots for a frame as quantiles at the evenly spaced probabilities
     * {@code 0, 1/(i-1), ..., 1}.
     *
     * <p>A copy of the frame is temporarily stored in the DKV to serve as training frame of a
     * {@code Quantile} model; the first row of that model's {@code _quantiles} output is returned.
     * The work runs between {@code Scope.enter()} and {@code Scope.exit(...)}. The exit is issued
     * exactly once from a {@code finally} block, replacing the decompiled catch-all form in which
     * a failure inside {@code Scope.exit} triggered a second {@code Scope.exit} call.
     *
     * @param frame  frame to compute quantiles on
     * @param i      number of knots; must be greater than 1 (checked only when assertions are enabled)
     * @return array of {@code i} knot values
     */
    public static double[] generateKnotsOneColumn(Frame frame, int i) {
        double[] knots = MemoryManager.malloc8d(i);
        try {
            Scope.enter();
            Frame keyedFrame = new Frame(frame);
            DKV.put(keyedFrame);
            double[] probs = MemoryManager.malloc8d(i);
            // Decompiled form of "assert i > 1"; kept as-is so it keeps using the class's existing flag.
            if (!$assertionsDisabled && i <= 1) {
                throw new AssertionError();
            }
            double step = 1.0d / (i - 1);
            for (int k = 0; k < i; k++) {
                probs[k] = k * step;
            }
            QuantileModel.QuantileParameters quantileParameters = new QuantileModel.QuantileParameters();
            quantileParameters._train = keyedFrame._key;
            quantileParameters._probs = probs;
            QuantileModel quantileModel = new Quantile(quantileParameters).trainModel().get();
            // The temporary frame entry is only needed while the quantile model trains.
            DKV.remove(keyedFrame._key);
            Scope.track_generic(quantileModel);
            System.arraycopy(quantileModel._output._quantiles[0], 0, knots, 0, i);
        } finally {
            Scope.exit(new Key[0]);
        }
        return knots;
    }

    /**
     * Builds a frame holding the columns of sorted GAM group {@code i} followed by a column named
     * {@code "weights_column"}.
     *
     * <p>The weights are taken from {@code frame} when {@code _weights_column} is set; otherwise a
     * Vec created by {@code Vec.makeOne} with the frame's row count is used and tracked in the
     * current {@code Scope}.
     *
     * @param i              index into {@code _gam_columns_sorted}
     * @param gAMParameters  parameters supplying the column names and the weights column name
     * @param frame          frame the Vecs are taken from
     * @return the assembled frame
     */
    public static Frame prepareGamVec(int i, GAMModel.GAMParameters gAMParameters, Frame frame) {
        Vec weights;
        if (gAMParameters._weights_column == null) {
            weights = Scope.track(Vec.makeOne(frame.numRows()));
        } else {
            weights = frame.vec(gAMParameters._weights_column);
        }
        Frame gamVecs = new Frame(new Vec[0]);
        for (String column : gAMParameters._gam_columns_sorted[i]) {
            gamVecs.add(column, frame.vec(column));
        }
        gamVecs.add("weights_column", weights);
        return gamVecs;
    }

    /**
     * Generates the names of the expanded columns for sorted GAM group {@code i}.
     *
     * <p>Names have the form {@code <first column>_<tag>_<index>}, where the tag is {@code cr} for
     * {@code _bs_sorted} code 0, {@code is} for code 2 and {@code tp} otherwise. The number of
     * names is {@code _num_knots_sorted[i]} for code 0 and
     * {@code _num_knots_sorted[i] + _spline_orders_sorted[i] - 2} for every other code.
     *
     * @param i              index into the sorted parameter arrays
     * @param gAMParameters  parameters supplying the sorted arrays
     * @return the generated column names
     */
    public static String[] generateGamColNames(int i, GAMModel.GAMParameters gAMParameters) {
        String[] names;
        if (gAMParameters._bs_sorted[i] == 0) {
            names = new String[gAMParameters._num_knots_sorted[i]];
        } else {
            names = new String[(gAMParameters._num_knots_sorted[i] + gAMParameters._spline_orders_sorted[i]) - 2];
        }

        String base = gAMParameters._gam_columns_sorted[i][0] + "_";
        String prefix;
        if (gAMParameters._bs_sorted[i] == 0) {
            prefix = base + "cr_";
        } else if (gAMParameters._bs_sorted[i] == 2) {
            prefix = base + "is_";
        } else {
            prefix = base + "tp_";
        }

        for (int n = 0; n < names.length; n++) {
            names[n] = prefix + n;
        }
        return names;
    }

    /**
     * Generates column names for sorted GAM group {@code i}: first {@code _num_knots_sorted[i]}
     * names of the form {@code str + index}, then one name per row of {@code iArr} built by
     * {@code genPolyBasisNames} from the group's column names and that row.
     *
     * @param i              index into the sorted parameter arrays
     * @param gAMParameters  parameters supplying {@code _num_knots_sorted} and {@code _gam_columns_sorted}
     * @param iArr           one int array per additional name, passed to {@code genPolyBasisNames}
     * @param str            prefix for the knot-based names
     * @return the generated names, knot-based names first
     */
    public static String[] generateGamColNamesThinPlateKnots(int i, GAMModel.GAMParameters gAMParameters, int[][] iArr, String str) {
        final int knotCount = gAMParameters._num_knots_sorted[i];
        final int polyCount = iArr.length;
        String[] names = new String[knotCount + polyCount];
        int next = 0;
        while (next < knotCount) {
            names[next] = str + next;
            next++;
        }
        for (int p = 0; p < polyCount; p++) {
            names[knotCount + p] = genPolyBasisNames(gAMParameters._gam_columns_sorted[i], iArr[p]);
        }
        return names;
    }

    /**
     * Joins column names and their paired integers into a single name of the form
     * {@code name0_int0_name1_int1...}, using {@code "_"} as the only separator.
     *
     * @param strArr  column names
     * @param iArr    integer paired with each column name; needs at least {@code strArr.length} entries
     * @return the joined name, or an empty string when {@code strArr} is empty
     */
    public static String genPolyBasisNames(String[] strArr, int[] iArr) {
        StringBuilder joined = new StringBuilder();
        for (int idx = 0; idx < strArr.length; idx++) {
            if (idx > 0) {
                joined.append('_');
            }
            joined.append(strArr[idx]).append('_').append(iArr[idx]);
        }
        return joined.toString();
    }

    /**
     * Appends the columns of the per-group GAM frames to {@code frame} and moves the special
     * columns to the end.
     *
     * <p>The response, offset, weights and {@code str} columns are first removed from
     * {@code frame}. The columns of each frame referenced by {@code keyArr} are then added, and
     * finally the removed columns are re-added in the order offset, {@code str}, weights, response.
     * {@code frame} is modified in place and returned.
     *
     * @param gAMParameters  supplies the response/offset/weights column names, the ignored columns
     *                       and {@code _gam_columns_sorted}
     * @param frame          frame to extend; mutated
     * @param keyArr         keys of the frames holding the GAM columns, one per sorted group
     * @param str            optional extra column to move to the end; skipped when {@code null}
     *                       (presumably the fold column - TODO confirm against callers)
     * @return the same {@code frame} instance
     */
    public static Frame buildGamFrame(GAMModel.GAMParameters gAMParameters, Frame frame, Key<Frame>[] keyArr, String str) {
        Vec remove = frame.remove(gAMParameters._response_column);
        List arrayList = gAMParameters._ignored_columns == null ? new ArrayList() : Arrays.asList(gAMParameters._ignored_columns);
        Vec remove2 = gAMParameters._offset_column != null ? frame.remove(gAMParameters._offset_column) : null;
        Vec remove3 = gAMParameters._weights_column != null ? frame.remove(gAMParameters._weights_column) : null;
        Vec remove4 = str != null ? frame.remove(str) : null;
        for (int i = 0; i < gAMParameters._gam_columns_sorted.length; i++) {
            // The source frame is tracked in the current Scope; its Vecs are moved into `frame`.
            Frame track = Scope.track(new Frame[]{(Frame) keyArr[i].get()});
            frame.add(track.names(), track.removeAll());
            // NOTE(review): _gam_columns_sorted[i] is indexed as a String[] elsewhere in this class,
            // while the list is built from the _ignored_columns array, so this contains() looks like
            // it can never match and the remove below would be dead code - confirm the intent
            // (per-column check?) before changing it.
            if (arrayList.contains(gAMParameters._gam_columns_sorted[i])) {
                frame.remove(gAMParameters._gam_columns_sorted[i]);
            }
        }
        // Re-add the special columns so that they end up last, with the response as the final column.
        if (remove2 != null) {
            frame.add(gAMParameters._offset_column, remove2);
        }
        if (str != null) {
            frame.add(str, remove4);
        }
        if (remove3 != null) {
            frame.add(gAMParameters._weights_column, remove3);
        }
        if (remove != null) {
            frame.add(gAMParameters._response_column, remove);
        }
        return frame;
    }

    /**
     * Moves the columns of all frames referenced by {@code keyArr} into one new frame created with
     * a fresh key. Each source frame is tracked in the current {@code Scope} and loses its Vecs
     * through {@code removeAll()}.
     *
     * @param keyArr  keys of the frames to concatenate, in column order
     * @return the new frame holding all columns
     */
    public static Frame concateGamVecs(Key<Frame>[] keyArr) {
        Frame combined = new Frame(Key.make());
        for (int k = 0; k < keyArr.length; k++) {
            Frame part = Scope.track(new Frame[]{(Frame) keyArr[k].get()});
            String[] partNames = part.names();
            combined.add(partNames, part.removeAll());
        }
        return combined;
    }

    /**
     * Builds the {@code *_sorted} parameter arrays by grouping the GAM column groups by their
     * {@code _bs} code.
     *
     * <p>Groups with code 0 fill the slots starting at 0, groups with code 2 fill the slots
     * starting at {@code i}, and all remaining groups fill the slots starting at {@code i + i2};
     * the original relative order is kept within each section. For code-2 groups the spline order
     * is copied as well; for the other groups {@code _spline_orders_sorted} keeps the value it got
     * from allocation.
     *
     * @param gAMParameters  parameters that are read and updated
     * @param i              first slot of the code-2 section (presumably the number of code-0
     *                       groups - TODO confirm against callers)
     * @param i2             size of the code-2 section (presumably the number of code-2 groups -
     *                       TODO confirm against callers)
     */
    public static void sortGAMParameters(GAMModel.GAMParameters gAMParameters, int i, int i2) {
        int groupCount = gAMParameters._gam_columns.length;
        int nextCode0 = 0;
        int nextCode2 = i;
        int nextOther = i + i2;
        // Two-dimensional: one String[] of column names per group. The decompiler emitted
        // "new String[length]", which does not type-check; setGamParameters stores a String[] per slot.
        gAMParameters._gam_columns_sorted = new String[groupCount][];
        gAMParameters._num_knots_sorted = MemoryManager.malloc4(groupCount);
        gAMParameters._scale_sorted = MemoryManager.malloc8d(groupCount);
        gAMParameters._bs_sorted = MemoryManager.malloc4(groupCount);
        gAMParameters._gamPredSize = MemoryManager.malloc4(groupCount);
        gAMParameters._spline_orders_sorted = MemoryManager.malloc4(groupCount);
        for (int group = 0; group < groupCount; group++) {
            if (gAMParameters._bs[group] == 0) {
                setGamParameters(gAMParameters, group, nextCode0++);
            } else if (gAMParameters._bs[group] == 2) {
                setGamParameters(gAMParameters, group, nextCode2);
                gAMParameters._spline_orders_sorted[nextCode2++] = gAMParameters._spline_orders[group];
            } else {
                setGamParameters(gAMParameters, group, nextOther++);
            }
        }
    }

    /**
     * Copies the settings of GAM column group {@code i} into slot {@code i2} of the sorted arrays:
     * a clone of the column names, the knot count, the scale, the number of columns in the group
     * (into {@code _gamPredSize}) and the {@code _bs} code.
     *
     * @param gAMParameters  parameters that are read and updated
     * @param i              index of the group in the unsorted arrays
     * @param i2             destination slot in the sorted arrays
     */
    public static void setGamParameters(GAMModel.GAMParameters gAMParameters, int i, int i2) {
        String[] columns = (String[]) gAMParameters._gam_columns[i].clone();
        gAMParameters._gam_columns_sorted[i2] = columns;
        gAMParameters._num_knots_sorted[i2] = gAMParameters._num_knots[i];
        gAMParameters._scale_sorted[i2] = gAMParameters._scale[i];
        gAMParameters._gamPredSize[i2] = columns.length;
        gAMParameters._bs_sorted[i2] = gAMParameters._bs[i];
    }

    /**
     * Replaces {@code _scale} with a new array holding 1.0 for every GAM column group.
     *
     * @param gAMParameters  parameters whose {@code _scale} array is replaced
     */
    public static void setDefaultScale(GAMModel.GAMParameters gAMParameters) {
        double[] scales = new double[gAMParameters._gam_columns.length];
        Arrays.fill(scales, 1.0d);
        gAMParameters._scale = scales;
    }

    // Decompiled form of the compiler-generated assertion setup: caches whether assertions are
    // disabled for this class so that the guarded checks in the methods can be skipped.
    static {
        $assertionsDisabled = !GamUtils.class.desiredAssertionStatus();
    }
}
