package hex.glrm;

import hex.CreateFrame;
import hex.DataInfo;
import hex.ModelMetrics;
import hex.SplitFrame;
import hex.genmodel.algos.glrm.GlrmInitialization;
import hex.genmodel.algos.glrm.GlrmLoss;
import hex.genmodel.algos.glrm.GlrmRegularizer;
import hex.glrm.GLRMModel;
import hex.pca.PCA;
import hex.pca.PCAModel;
import hex.pca.PCAWideDataSetsTests;
import java.io.FileInputStream;
import java.util.Arrays;
import java.util.Random;
import java.util.TreeMap;
import java.util.concurrent.ExecutionException;
import org.junit.Assert;
import org.junit.BeforeClass;
import org.junit.Ignore;
import org.junit.Test;
import water.DKV;
import water.Key;
import water.Scope;
import water.TestUtil;
import water.fvec.Frame;
import water.fvec.UploadFileVec;
import water.fvec.Vec;
import water.parser.ParseDataset;
import water.rapids.Rapids;
import water.util.ArrayUtils;
import water.util.FileUtils;
import water.util.FrameUtils;
import water.util.Log;

/* loaded from: input_file:hex/glrm/GLRMTest.class */
public class GLRMTest extends TestUtil {
    /** Default absolute tolerance (1e-6) for numeric comparisons in these tests. */
    public final double TOLERANCE = 1.0E-6d;
    // Decompiler-visible form of the javac-synthesized assertion flag. The explicit
    // "if (!$assertionsDisabled && ...) throw new AssertionError()" checks in this class read it
    // in place of Java `assert` statements. Its initializer is not visible in this chunk.
    static final /* synthetic */ boolean $assertionsDisabled;

    /**
     * One-time class setup: waits (via {@code TestUtil.stall_till_cloudsize}) for a test cloud
     * of size 1 before any test in this class runs.
     */
    @BeforeClass
    public static void setup() {
        final int requiredCloudSize = 1;
        stall_till_cloudsize(requiredCloudSize);
    }

    /**
     * Returns the sum of squared element-wise differences between two vectors.
     *
     * <p>Despite the name, no square root or averaging is applied. Iteration is driven by the
     * length of {@code dArr2}, so {@code dArr} must be at least as long.
     *
     * @param dArr  reference values
     * @param dArr2 values to compare against the reference
     * @return sum over i of (dArr[i] - dArr2[i])^2
     */
    public double errStddev(double[] dArr, double[] dArr2) {
        double sumSquares = 0.0d;
        int idx = 0;
        while (idx < dArr2.length) {
            final double diff = dArr[idx] - dArr2[idx];
            sumSquares += diff * diff;
            idx++;
        }
        return sumSquares;
    }

    /**
     * Sign-tolerant sum of squared differences between two column-wise eigenvector matrices.
     * Uses this class's default tolerance ({@link #TOLERANCE}) for the sign decision.
     *
     * @param dArr  reference matrix
     * @param dArr2 matrix to compare against the reference
     * @return see {@link #errEigvec(double[][], double[][], double)}
     */
    public double errEigvec(double[][] dArr, double[][] dArr2) {
        final double signTolerance = TOLERANCE;
        return errEigvec(dArr, dArr2, signTolerance);
    }

    /**
     * Sum of squared differences between two matrices, compared column by column.
     *
     * <p>A column of {@code dArr2} is negated before comparison when its first-row entry differs
     * from the reference's by more than {@code d}. This presumably tolerates the sign ambiguity
     * of eigenvectors. Dimensions are taken from {@code dArr2}, which must be non-empty.
     *
     * @param dArr  reference matrix, indexed [row][column]
     * @param dArr2 matrix to compare, indexed [row][column]
     * @param d     tolerance used only for the per-column sign decision
     * @return accumulated squared error over all compared entries
     */
    public double errEigvec(double[][] dArr, double[][] dArr2, double d) {
        double total = 0.0d;
        final int numCols = dArr2[0].length;
        final int numRows = dArr2.length;
        for (int col = 0; col < numCols; col++) {
            // Multiplying by +/-1.0 is exact, so this matches negating the value directly.
            final double sign = Math.abs(dArr[0][col] - dArr2[0][col]) > d ? -1.0d : 1.0d;
            for (int row = 0; row < numRows; row++) {
                final double diff = dArr[row][col] - sign * dArr2[row][col];
                total += diff * diff;
            }
        }
        return total;
    }

    /**
     * Verifies that a trained GLRM model assigned the expected loss function to every column.
     *
     * <p>Each entry of {@code _loss_by_col_idx} is an index into the original training frame. It
     * is mapped through {@code _output._permutation} to its position in the model's column order,
     * and that position must carry the matching entry of {@code _loss_by_col}. Every other column
     * must use {@code _multi_loss} if it is among the first {@code _ncats} columns, or
     * {@code _loss} otherwise. If either per-column array is {@code null}, only the
     * {@code _ncats} range check is performed.
     *
     * <p>Each user-specified loss stays paired with its own adapted index. Previously the adapted
     * indices were sorted on their own, and the search position in the sorted array was then used
     * to index the unsorted {@code _loss_by_col}, mismatching losses and columns whenever the
     * mapped indices were not already ascending.
     *
     * @param gLRMParameters parameters the model was trained with
     * @param gLRMModel      trained model whose {@code _output._lossFunc} is checked
     */
    public static void checkLossbyCol(GLRMModel.GLRMParameters gLRMParameters, GLRMModel gLRMModel) {
        final int ncats = gLRMModel._output._ncats;
        final GlrmLoss[] actual = gLRMModel._output._lossFunc;
        if (!$assertionsDisabled && (ncats < 0 || ncats > actual.length)) {
            throw new AssertionError();
        }
        if (null == gLRMParameters._loss_by_col || null == gLRMParameters._loss_by_col_idx) {
            return;
        }
        Assert.assertEquals(gLRMParameters._loss_by_col.length, gLRMParameters._loss_by_col_idx.length);

        // Default expectation per adapted column: categorical columns come first (< ncats).
        final GlrmLoss[] expected = new GlrmLoss[actual.length];
        for (int col = 0; col < expected.length; col++) {
            expected[col] = col < ncats ? gLRMParameters._multi_loss : gLRMParameters._loss;
        }

        // Override with user-specified losses at their adapted (permuted) positions.
        for (int i = 0; i < gLRMParameters._loss_by_col_idx.length; i++) {
            int adapted = -1;
            for (int j = 0; j < gLRMModel._output._permutation.length; j++) {
                if (gLRMModel._output._permutation[j] == gLRMParameters._loss_by_col_idx[i]) {
                    adapted = j;
                    break;
                }
            }
            // Indices that do not map to a model column are ignored, as before.
            if (adapted >= 0 && adapted < expected.length) {
                expected[adapted] = gLRMParameters._loss_by_col[i];
            }
        }

        for (int col = 0; col < actual.length; col++) {
            Assert.assertEquals(expected[col], actual[col]);
        }
    }

    /**
     * Ignored by default (needs the bigdata census file). Steps:
     * <ol>
     *   <li>Uploads and parses {@code ACS_13_5YR_DP02_cleaned.zip}.</li>
     *   <li>Moves its first column into a separate categorical frame {@code acs_zcta_fr}.</li>
     *   <li>Trains a one-iteration rank-10 GLRM on the remaining columns.</li>
     *   <li>Runs a Rapids expression selecting rows of the first two representation columns
     *       whose {@code name} matches one of six codes.</li>
     * </ol>
     *
     * <p>Changes from the decompiled version:
     * <ul>
     *   <li>The upload stream is closed.</li>
     *   <li>Upload failures abort the test with their cause, instead of being swallowed and
     *       resurfacing as a parse error on a missing key.</li>
     *   <li>The model is deleted on every exit path.</li>
     * </ul>
     */
    @Test
    @Ignore
    public void testSubset() throws InterruptedException, ExecutionException {
        try (FileInputStream in = new FileInputStream(
                FileUtils.getFile("bigdata/laptop/census/ACS_13_5YR_DP02_cleaned.zip"))) {
            UploadFileVec.readPut("train", in, new UploadFileVec.ReadPutStats());
        } catch (Exception e) {
            throw new RuntimeException("Failed to upload ACS_13_5YR_DP02_cleaned.zip", e);
        }
        ParseDataset.parse(Key.make("train_parsed"), new Key[]{Key.make("train")});
        Frame train = DKV.getGet("train_parsed");
        GLRMModel model = null;
        try {
            Log.info(new Object[]{"num chunks: ", Integer.valueOf(train.anyVec().nChunks())});
            // First column becomes a standalone categorical frame; the rest is the training data.
            Frame acsZctaFr = new Frame(Key.make("acs_zcta_fr"), new String[]{"name"},
                    new Vec[]{train.vec(0).toCategoricalVec()});
            DKV.put(acsZctaFr);
            train.remove(0).remove();
            DKV.put(train);

            GLRMModel.GLRMParameters params = new GLRMModel.GLRMParameters();
            params._train = train._key;
            params._gamma_x = 0.25d;
            params._gamma_y = 0.5d;
            params._regularization_x = GlrmRegularizer.Quadratic;
            params._regularization_y = GlrmRegularizer.L1;
            params._k = 10;
            params._transform = DataInfo.TransformType.STANDARDIZE;
            params._max_iterations = 1;
            params._loss = GlrmLoss.Quadratic;

            Scope.enter();
            try {
                model = (GLRMModel) new GLRM(params).trainModel().get();
                Rapids.exec("(tmp= py_4 (rows (cols_py " + model._output._representation_key + " [0 1]) (tmp= py_3 (| (| (| (| (| (== (tmp= py_2 " + acsZctaFr._key + ") \"10065\") (== py_2 \"11219\")) (== py_2 \"66753\")) (== py_2 \"84104\")) (== py_2 \"94086\")) (== py_2 \"95014\")))))");
            } finally {
                acsZctaFr.delete();
                Scope.exit(new Key[0]);
            }
        } finally {
            if (train != null) {
                train.delete();
            }
            if (model != null) {
                model.delete();
            }
        }
    }

    /**
     * Smoke test: trains a rank-3 GLRM on standardized USArrests. It uses quadratic
     * regularization on X and Y (gamma 0.5 each) and a user-supplied initial Y. It then scores
     * the training frame and logs the reconstruction metrics. Nothing is asserted beyond
     * successful completion. All created frames and the model are deleted on every exit path.
     */
    @Test
    public void testArrests() throws InterruptedException, ExecutionException {
        // Initial Y for GlrmInitialization.User: k = 3 rows by 4 columns.
        Frame initY = ArrayUtils.frame(new double[][]{
            {1.24256408d, 0.7828393d, -0.5209066d, -0.003416473d},
            {0.50786248d, 1.1068225d, -1.2117642d, 2.484202941d},
            {0.07163341d, 1.4788032d, 0.9989801d, 1.042878388d}});
        GLRMModel model = null;
        Frame train = null;
        try {
            train = parse_test_file(Key.make("arrests.hex"), "smalldata/pca_test/USArrests.csv");
            GLRMModel.GLRMParameters params = new GLRMModel.GLRMParameters();
            params._train = train._key;
            params._gamma_y = 0.5d;
            params._gamma_x = 0.5d;
            params._regularization_x = GlrmRegularizer.Quadratic;
            params._regularization_y = GlrmRegularizer.Quadratic;
            params._k = 3;
            params._transform = DataInfo.TransformType.STANDARDIZE;
            params._init = GlrmInitialization.User;
            params._recover_svd = false;
            params._user_y = initY._key;
            params._seed = 1234L;
            model = (GLRMModel) new GLRM(params).trainModel().get();
            Log.info(new Object[]{"Iteration " + model._output._iterations + ": Objective value = " + model._output._objective});
            model.score(train).delete();
            ModelMetricsGLRM metrics = (ModelMetricsGLRM) ModelMetrics.getFromDKV(model, train);
            Log.info(new Object[]{"Numeric Sum of Squared Error = " + metrics._numerr + "\tCategorical Misclassification Error = " + metrics._caterr});
        } finally {
            initY.delete();
            if (train != null) {
                train.delete();
            }
            if (model != null) {
                model.delete();
            }
        }
    }

    /**
     * Smoke test: trains a rank-10 GLRM on standardized benign.csv with SVD initialization and
     * quadratic regularization (gamma 0.25 on X and Y). It then scores the training frame and
     * logs the reconstruction metrics. Cleanup runs exactly once, on every exit path.
     */
    @Test
    public void testBenignSVD() throws InterruptedException, ExecutionException {
        GLRMModel model = null;
        Frame train = null;
        try {
            train = parse_test_file(Key.make("benign.hex"), "smalldata/logreg/benign.csv");
            GLRMModel.GLRMParameters params = new GLRMModel.GLRMParameters();
            params._train = train._key;
            params._k = 10;
            params._gamma_y = 0.25d;
            params._gamma_x = 0.25d;
            params._regularization_x = GlrmRegularizer.Quadratic;
            params._regularization_y = GlrmRegularizer.Quadratic;
            params._transform = DataInfo.TransformType.STANDARDIZE;
            params._init = GlrmInitialization.SVD;
            params._min_step_size = 1.0E-5d;
            params._recover_svd = false;
            params._max_iterations = 2000;
            model = (GLRMModel) new GLRM(params).trainModel().get();
            Log.info(new Object[]{"Iteration " + model._output._iterations + ": Objective value = " + model._output._objective});
            model.score(train).delete();
            ModelMetricsGLRM metrics = (ModelMetricsGLRM) ModelMetrics.getFromDKV(model, train);
            Log.info(new Object[]{"Numeric Sum of Squared Error = " + metrics._numerr + "\tCategorical Misclassification Error = " + metrics._caterr});
        } finally {
            if (train != null) {
                train.delete();
            }
            if (model != null) {
                model.delete();
            }
        }
    }

    /**
     * Trains an unregularized rank-4 GLRM on standardized USArrests from a user-supplied initial
     * Y, with {@code _recover_svd} enabled. It asserts that the final objective equals the
     * numeric squared reconstruction error on the training frame, within {@link #TOLERANCE}.
     */
    @Test
    public void testArrestsSVD() throws InterruptedException, ExecutionException {
        // Initial Y for GlrmInitialization.User: k = 4 rows by 4 columns.
        Frame initY = ArrayUtils.frame(new double[][]{
            {1.24256408d, 0.7828393d, -0.5209066d, -0.003416473d},
            {0.50786248d, 1.1068225d, -1.2117642d, 2.484202941d},
            {0.07163341d, 1.4788032d, 0.9989801d, 1.042878388d},
            {0.23234938d, 0.230868d, -1.0735927d, -0.184916602d}});
        GLRMModel model = null;
        Frame train = null;
        try {
            train = parse_test_file(Key.make("arrests.hex"), "smalldata/pca_test/USArrests.csv");
            GLRMModel.GLRMParameters params = new GLRMModel.GLRMParameters();
            params._train = train._key;
            params._k = 4;
            params._transform = DataInfo.TransformType.STANDARDIZE;
            params._init = GlrmInitialization.User;
            params._user_y = initY._key;
            params._max_iterations = 1000;
            params._min_step_size = 1.0E-8d;
            params._recover_svd = true;
            model = (GLRMModel) new GLRM(params).trainModel().get();
            Log.info(new Object[]{"Iteration " + model._output._iterations + ": Objective value = " + model._output._objective});
            model.score(train).delete();
            ModelMetricsGLRM metrics = (ModelMetricsGLRM) ModelMetrics.getFromDKV(model, train);
            Log.info(new Object[]{"Numeric Sum of Squared Error = " + metrics._numerr + "\tCategorical Misclassification Error = " + metrics._caterr});
            // With no regularization the objective should be exactly the numeric reconstruction error.
            Assert.assertEquals(model._output._objective, metrics._numerr, TOLERANCE);
        } finally {
            initY.delete();
            if (train != null) {
                train.delete();
            }
            if (model != null) {
                model.delete();
            }
        }
    }

    /**
     * Smoke test: trains a rank-4 GLRM on standardized USArrests. It uses Huber loss,
     * non-negative regularization on X and Y (gamma 1.0), and PlusPlus initialization, then
     * logs the final objective. Cleanup runs exactly once, on every exit path.
     */
    @Test
    public void testArrestsPlusPlus() throws InterruptedException, ExecutionException {
        GLRMModel model = null;
        Frame train = null;
        try {
            train = parse_test_file(Key.make("arrests.hex"), "smalldata/pca_test/USArrests.csv");
            GLRMModel.GLRMParameters params = new GLRMModel.GLRMParameters();
            params._train = train._key;
            params._k = 4;
            params._loss = GlrmLoss.Huber;
            params._regularization_x = GlrmRegularizer.NonNegative;
            params._regularization_y = GlrmRegularizer.NonNegative;
            params._gamma_y = 1.0d;
            params._gamma_x = 1.0d;
            params._transform = DataInfo.TransformType.STANDARDIZE;
            params._init = GlrmInitialization.PlusPlus;
            params._max_iterations = 100;
            params._min_step_size = 1.0E-8d;
            params._recover_svd = true;
            model = (GLRMModel) new GLRM(params).trainModel().get();
            Log.info(new Object[]{"Iteration " + model._output._iterations + ": Objective value = " + model._output._objective});
        } finally {
            if (train != null) {
                train.delete();
            }
            if (model != null) {
                model.delete();
            }
        }
    }

    /**
     * For increasing fractions of missing values, trains an unregularized full-rank GLRM on
     * standardized USArrests. For each fraction it:
     * <ul>
     *   <li>records the average squared error of the recovered singular values and eigenvectors
     *       against reference values;</li>
     *   <li>asserts that the objective equals the numeric reconstruction error.</li>
     * </ul>
     * A summary table is logged at the end.
     *
     * <p>Every iteration has its own scope and locals. {@code Scope.exit} and cleanup run on
     * failure too, so a failed fraction neither unbalances the scope stack nor re-deletes the
     * previous iteration's objects.
     */
    @Test
    public void testArrestsMissing() throws InterruptedException, ExecutionException {
        // Reference values compared against _singular_vals and _eigenvectors_raw below.
        final double[] refSingularVals = {11.024148d, 6.964086d, 4.179904d, 2.915146d};
        final double[][] refEigvecs = {
            {-0.5358995d, 0.4181809d, -0.3412327d, 0.6492278d},
            {-0.5831836d, 0.1879856d, -0.2681484d, -0.74340748d},
            {-0.2781909d, -0.8728062d, -0.3780158d, 0.13387773d},
            {-0.5434321d, -0.1673186d, 0.8177779d, 0.08902432d}};
        final TreeMap<Double, Double> stddevErrByMissing = new TreeMap<Double, Double>();
        final TreeMap<Double, Double> eigvecErrByMissing = new TreeMap<Double, Double>();

        for (double missingFraction : new double[]{0.0d, 0.1d, 0.25d, 0.5d, 0.75d, 0.9d}) {
            Frame train = null;
            GLRMModel model = null;
            Scope.enter();
            try {
                train = parse_test_file(Key.make("arrests.hex"), "smalldata/pca_test/USArrests.csv");
                if (missingFraction > 0.0d) {
                    // Temporary DKV frame over train's vecs for MissingInserter (addressed by key);
                    // presumably the inserted NAs are then visible through train.
                    Frame view = new Frame(Key.make(), train.names(), train.vecs());
                    DKV.put(view._key, view);
                    new FrameUtils.MissingInserter(view._key, 1234L, missingFraction).execImpl().get();
                    DKV.remove(view._key);
                }
                GLRMModel.GLRMParameters params = new GLRMModel.GLRMParameters();
                params._train = train._key;
                params._k = train.numCols();
                params._loss = GlrmLoss.Quadratic;
                params._regularization_x = GlrmRegularizer.None;
                params._regularization_y = GlrmRegularizer.None;
                params._transform = DataInfo.TransformType.STANDARDIZE;
                params._init = GlrmInitialization.PlusPlus;
                params._max_iterations = 1000;
                params._seed = 1234L;
                params._recover_svd = true;
                model = (GLRMModel) new GLRM(params).trainModel().get();
                Log.info(new Object[]{(100.0d * missingFraction) + "% missing values: Objective = " + model._output._objective});

                double stddevErr = errStddev(refSingularVals, model._output._singular_vals) / params._k;
                double eigvecErr = errEigvec(refEigvecs, model._output._eigenvectors_raw) / params._k;
                Log.info(new Object[]{"Avg SSE in Std Dev = " + stddevErr + "\tAvg SSE in Eigenvectors = " + eigvecErr});
                stddevErrByMissing.put(missingFraction, stddevErr);
                eigvecErrByMissing.put(missingFraction, eigvecErr);

                model.score(train).delete();
                ModelMetricsGLRM metrics = (ModelMetricsGLRM) ModelMetrics.getFromDKV(model, train);
                Log.info(new Object[]{"Numeric Sum of Squared Error = " + metrics._numerr + "\tCategorical Misclassification Error = " + metrics._caterr});
                Assert.assertEquals(model._output._objective, metrics._numerr, TOLERANCE);
            } finally {
                Scope.exit(new Key[0]);
                if (train != null) {
                    train.delete();
                }
                if (model != null) {
                    model.delete();
                }
            }
        }

        StringBuilder report = new StringBuilder();
        report.append("\nMissing Fraction --> Avg SSE in Std Dev\n");
        for (String entry : Arrays.toString(stddevErrByMissing.entrySet().toArray()).split(",")) {
            report.append(entry.replace("=", " --> ")).append("\n");
        }
        report.append("\n");
        report.append("Missing Fraction --> Avg SSE in Eigenvectors\n");
        for (String entry : Arrays.toString(eigvecErrByMissing.entrySet().toArray()).split(",")) {
            report.append(entry.replace("=", " --> ")).append("\n");
        }
        Log.info(new Object[]{report.toString()});
    }

    /**
     * Trains a rank-12 GLRM on standardized benign.csv. The default loss is Quadratic, with
     * Absolute loss on original column 2 and Huber on column 5. The test then checks the model's
     * per-column losses with {@link #checkLossbyCol} and logs training-frame metrics. Cleanup runs
     * exactly once, on every exit path.
     */
    @Test
    public void testSetColumnLoss() throws InterruptedException, ExecutionException {
        GLRMModel model = null;
        Frame train = null;
        try {
            train = parse_test_file(Key.make("benign.hex"), "smalldata/logreg/benign.csv");
            GLRMModel.GLRMParameters params = new GLRMModel.GLRMParameters();
            params._train = train._key;
            params._k = 12;
            params._loss = GlrmLoss.Quadratic;
            params._loss_by_col = new GlrmLoss[]{GlrmLoss.Absolute, GlrmLoss.Huber};
            params._loss_by_col_idx = new int[]{2, 5};
            params._transform = DataInfo.TransformType.STANDARDIZE;
            params._init = GlrmInitialization.PlusPlus;
            params._min_step_size = 1.0E-5d;
            params._recover_svd = false;
            params._max_iterations = 2000;
            model = (GLRMModel) new GLRM(params).trainModel().get();
            Log.info(new Object[]{"Iteration " + model._output._iterations + ": Objective value = " + model._output._objective});
            checkLossbyCol(params, model);
            model.score(train).delete();
            ModelMetricsGLRM metrics = (ModelMetricsGLRM) ModelMetrics.getFromDKV(model, train);
            Log.info(new Object[]{"Numeric Sum of Squared Error = " + metrics._numerr + "\tCategorical Misclassification Error = " + metrics._caterr});
        } finally {
            if (train != null) {
                train.delete();
            }
            if (model != null) {
                model.delete();
            }
        }
    }

    /**
     * Builds a random numeric/categorical frame with 10% missing values and a random size. Its
     * parameters and seed are printed for reproduction. The frame is split 80/20 and a rank-3
     * GLRM is trained on the larger part. The test then checks:
     * <ul>
     *   <li>scoring either split yields as many rows as the corresponding representation
     *       frame;</li>
     *   <li>{@code testJavaScoring} agrees on the held-out split.</li>
     * </ul>
     *
     * <p>Everything is tracked in a scope that is exited exactly once. Previously a failing
     * {@code Scope.exit} on the success path triggered a second exit from the catch block.
     */
    @Test
    public void testGLRMPredMojo() {
        Scope.enter();
        try {
            CreateFrame createFrame = new CreateFrame();
            Random random = new Random();
            int numRows = random.nextInt(10000) + 50000;
            int numCols = random.nextInt(17) + 3;
            createFrame.rows = numRows;
            createFrame.cols = numCols;
            createFrame.binary_fraction = 0.0d;
            createFrame.string_fraction = 0.0d;
            createFrame.time_fraction = 0.0d;
            createFrame.has_response = false;
            createFrame.positive_response = true;
            createFrame.missing_fraction = 0.1d;
            createFrame.seed = System.currentTimeMillis();
            System.out.println("Createframe parameters: rows: " + numRows + " cols:" + numCols + " seed: " + createFrame.seed);

            final double testFraction = 0.2d;
            SplitFrame splitFrame = new SplitFrame(
                    Scope.track(new Frame[]{(Frame) createFrame.execImpl().get()}),
                    new double[]{1.0d - testFraction, testFraction},
                    new Key[]{Key.make("train.hex"), Key.make("test.hex")});
            splitFrame.exec().get();
            Key[] splitKeys = splitFrame._destination_frames;
            Frame train = DKV.get(splitKeys[0]).get();
            Frame test = DKV.get(splitKeys[1]).get();
            Scope.track(new Frame[]{train});
            Scope.track(new Frame[]{test});

            GLRMModel.GLRMParameters params = new GLRMModel.GLRMParameters();
            params._train = train._key;
            params._k = 3;
            params._loss = GlrmLoss.Quadratic;
            params._init = GlrmInitialization.SVD;
            params._max_iterations = 10;
            params._regularization_x = GlrmRegularizer.Quadratic;
            params._gamma_x = 0.0d;
            params._gamma_y = 0.0d;
            params._seed = createFrame.seed;
            GLRMModel model = new GLRM(params).trainModel().get();
            Scope.track_generic(model);

            Frame trainRepr = DKV.get(model.gen_representation_key(train)).get();
            Scope.track(new Frame[]{trainRepr});
            Frame trainPred = model.score(train);
            Scope.track(new Frame[]{trainPred});
            Assert.assertEquals(trainPred.numRows(), trainRepr.numRows());

            Frame testPred = model.score(test);
            Scope.track(new Frame[]{testPred});
            Frame testRepr = DKV.get(model.gen_representation_key(test)).get();
            Scope.track(new Frame[]{testRepr});
            Assert.assertEquals(testPred.numRows(), testRepr.numRows());
            Assert.assertTrue(model.testJavaScoring(test, testPred, 1.0E-6d, 1.0E-6d, 1.0d));
        } finally {
            Scope.exit(new Key[0]);
        }
    }

    /**
     * Fits rank-4 GLRMs on untransformed USArrests from a user-supplied initial Y, under four
     * regularizer combinations. Each fit logs its objective, archetypes and reconstruction
     * metrics:
     * <ul>
     *   <li>non-negative matrix factorization;</li>
     *   <li>orthogonal non-negative factorization;</li>
     *   <li>quadratic clustering (k-means);</li>
     *   <li>quadratic mixture (soft k-means).</li>
     * </ul>
     *
     * <p>Fixes relative to the decompiled version:
     * <ul>
     *   <li>The soft k-means case now uses the {@code Simplex} X regularizer. It previously
     *       repeated the k-means {@code UnitOneSparse} configuration verbatim.</li>
     *   <li>Each model is deleted exactly once, including when its fit fails.</li>
     *   <li>Frames and the scope are released on every exit path.</li>
     * </ul>
     */
    @Test
    public void testRegularizers() throws InterruptedException, ExecutionException {
        // Initial Y for GlrmInitialization.User: k = 4 rows by 4 columns (raw, untransformed scale).
        Frame initY = ArrayUtils.frame(new double[][]{
            {13.2d, 236.0d, 58.0d, 21.2d},
            {10.0d, 263.0d, 48.0d, 44.5d},
            {8.1d, 294.0d, 80.0d, 31.0d},
            {8.8d, 190.0d, 50.0d, 19.5d}});
        Frame train = null;
        Scope.enter();
        try {
            train = parse_test_file(Key.make("arrests.hex"), "smalldata/pca_test/USArrests.csv");
            GLRMModel.GLRMParameters params = new GLRMModel.GLRMParameters();
            params._train = train._key;
            params._k = 4;
            params._init = GlrmInitialization.User;
            params._user_y = initY._key;
            params._transform = DataInfo.TransformType.NONE;
            params._recover_svd = false;
            params._max_iterations = 1000;
            params._seed = 1234L;

            params._gamma_y = 1.0d;
            params._gamma_x = 1.0d;
            params._regularization_x = GlrmRegularizer.NonNegative;
            params._regularization_y = GlrmRegularizer.NonNegative;
            fitAndLogRegularizedGlrm("\nNon-negative matrix factorization", params, train);

            params._gamma_y = 1.0d;
            params._gamma_x = 1.0d;
            params._regularization_x = GlrmRegularizer.OneSparse;
            params._regularization_y = GlrmRegularizer.NonNegative;
            fitAndLogRegularizedGlrm("\nOrthogonal non-negative matrix factorization", params, train);

            params._gamma_x = 1.0d;
            params._gamma_y = 0.0d;
            params._regularization_x = GlrmRegularizer.UnitOneSparse;
            params._regularization_y = GlrmRegularizer.None;
            fitAndLogRegularizedGlrm("\nQuadratic clustering (k-means)", params, train);

            // Soft k-means: rows of X constrained to the probability simplex rather than one-hot.
            params._gamma_x = 1.0d;
            params._gamma_y = 0.0d;
            params._regularization_x = GlrmRegularizer.Simplex;
            params._regularization_y = GlrmRegularizer.None;
            fitAndLogRegularizedGlrm("\nQuadratic mixture (soft k-means)", params, train);
        } finally {
            initY.delete();
            if (train != null) {
                train.delete();
            }
            Scope.exit(new Key[0]);
        }
    }

    /**
     * Logs {@code title}, then trains a GLRM with {@code params}. It logs the iteration count,
     * objective, archetypes and training-frame reconstruction metrics. The model is always
     * deleted afterwards, even if training or scoring fails.
     */
    private static void fitAndLogRegularizedGlrm(String title, GLRMModel.GLRMParameters params, Frame train)
            throws InterruptedException, ExecutionException {
        Log.info(new Object[]{title});
        GLRMModel model = null;
        try {
            model = (GLRMModel) new GLRM(params).trainModel().get();
            Log.info(new Object[]{"Iteration " + model._output._iterations + ": Objective value = " + model._output._objective});
            Log.info(new Object[]{"Archetypes:\n" + model._output._archetypes.toString()});
            model.score(train).delete();
            ModelMetricsGLRM metrics = (ModelMetricsGLRM) ModelMetrics.getFromDKV(model, train);
            Log.info(new Object[]{"Numeric Sum of Squared Error = " + metrics._numerr + "\tCategorical Misclassification Error = " + metrics._caterr});
        } finally {
            if (model != null) {
                model.delete();
            }
        }
    }

    /**
     * Checks that GLRM reports the same variance-importance table as PCA (power method) on the
     * USArrests data, for both the DEMEAN and STANDARDIZE transforms.
     *
     * <p>GLRM uses quadratic loss, SVD initialization, no regularization
     * ({@code gamma_x = gamma_y = 0}) and {@code _recover_svd = true}. For each transform the PCA
     * standard deviations and eigenvectors are also compared with hard-coded reference values,
     * all within a tolerance of 1e-6.
     */
    @Test
    public void testArrestsVarianceMetrics() throws InterruptedException, ExecutionException {
        // Reference PCA output for the DEMEAN transform.
        double[] stddevDemean = {83.7324d, 14.212402d, 6.489426d, 2.48279d};
        double[][] eigvecDemean = ard(
                ard(0.04170432d, -0.04482166d, 0.07989066d, -0.99492173d),
                ard(0.99522128d, -0.05876003d, -0.06756974d, 0.0389383d),
                ard(0.04633575d, 0.97685748d, -0.20054629d, -0.05816914d),
                ard(0.0751555d, 0.20071807d, 0.97408059d, 0.07232502d));
        // Reference PCA output for the STANDARDIZE transform.
        double[] stddevStandardize = {1.5748783d, 0.9948694d, 0.5971291d, 0.4164494d};
        double[][] eigvecStandardize = ard(
                ard(-0.5358995d, 0.4181809d, -0.3412327d, 0.6492278d),
                ard(-0.5831836d, 0.1879856d, -0.2681484d, -0.74340748d),
                ard(-0.2781909d, -0.8728062d, -0.3780158d, 0.13387773d),
                ard(-0.5434321d, -0.1673186d, 0.8177779d, 0.08902432d));

        Frame frame = null;
        try {
            frame = parse_test_file(Key.make("arrests.hex"), "smalldata/pca_test/USArrests.csv");
            DataInfo.TransformType[] transforms = {
                    DataInfo.TransformType.DEMEAN, DataInfo.TransformType.STANDARDIZE};
            for (DataInfo.TransformType transformType : transforms) {
                // Declared per iteration so a failure in a later iteration never re-deletes
                // models that an earlier iteration has already cleaned up.
                PCAModel pcaModel = null;
                GLRMModel glrmModel = null;
                try {
                    PCAModel.PCAParameters pcaParams = new PCAModel.PCAParameters();
                    pcaParams._train = frame._key;
                    pcaParams._k = 4;
                    pcaParams._transform = transformType;
                    pcaParams._max_iterations = 1000;
                    pcaParams._pca_method = PCAModel.PCAParameters.Method.Power;
                    pcaModel = (PCAModel) new PCA(pcaParams).trainModel().get();

                    GLRMModel.GLRMParameters glrmParams = new GLRMModel.GLRMParameters();
                    glrmParams._train = frame._key;
                    glrmParams._k = 4;
                    glrmParams._transform = transformType;
                    glrmParams._loss = GlrmLoss.Quadratic;
                    glrmParams._init = GlrmInitialization.SVD;
                    glrmParams._max_iterations = 2000;
                    // No regularization, so GLRM with quadratic loss should recover PCA.
                    glrmParams._gamma_x = 0.0d;
                    glrmParams._gamma_y = 0.0d;
                    glrmParams._recover_svd = true;
                    glrmModel = (GLRMModel) new GLRM(glrmParams).trainModel().get();
                    if (glrmModel == null) {
                        throw new AssertionError("GLRM training returned no model");
                    }

                    if (transformType == DataInfo.TransformType.DEMEAN) {
                        TestUtil.checkStddev(stddevDemean, pcaModel._output._std_deviation, 1.0E-6d);
                        TestUtil.checkEigvec(eigvecDemean, pcaModel._output._eigenvectors, 1.0E-6d);
                    } else if (transformType == DataInfo.TransformType.STANDARDIZE) {
                        TestUtil.checkStddev(stddevStandardize, pcaModel._output._std_deviation, 1.0E-6d);
                        TestUtil.checkEigvec(eigvecStandardize, pcaModel._output._eigenvectors, 1.0E-6d);
                    }
                    TestUtil.checkIcedArrays(pcaModel._output._importance.getCellValues(),
                            glrmModel._output._importance.getCellValues(), 1.0E-6d);
                } finally {
                    if (pcaModel != null) {
                        pcaModel.delete();
                    }
                    if (glrmModel != null) {
                        glrmModel.delete();
                    }
                }
            }
        } finally {
            if (frame != null) {
                frame.delete();
            }
        }
    }

    /**
     * Verifies that GLRM trained through the wide-dataset code path agrees with GLRM trained
     * through the default path on the prostate dataset, after injecting a few missing values.
     *
     * <p>Both runs share one parameter set: k = 3, DEMEAN transform, seed 12345,
     * gamma_x = 1.0, gamma_y = 0.5, quadratic X/Y regularization and SVD recovery. SVD
     * initialization is requested only when none of the frame's column types is "Enum". The
     * standard deviations, the {@code checkEigvec} results for archetypes and eigenvectors, and
     * the scored frames are compared with a tolerance of 1e-6.
     */
    @Test
    public void testWideDataSetGLRMCat() throws InterruptedException, ExecutionException {
        Scope.enter();
        try {
            Frame trainFrame = parse_test_file(Key.make("Prostrate_CAT"), PCAWideDataSetsTests._prostateDataset);
            // Track the frame first so it is released even if the NA injection below fails.
            Scope.track(trainFrame);
            trainFrame.vec(0).setNA(0L);
            trainFrame.vec(3).setNA(10L);
            trainFrame.vec(5).setNA(20L);
            DKV.put(trainFrame);

            GLRMModel.GLRMParameters params = new GLRMModel.GLRMParameters();
            params._train = trainFrame._key;
            params._k = 3;
            params._transform = DataInfo.TransformType.DEMEAN;
            params._seed = 12345L;
            params._gamma_x = 1.0d;
            params._gamma_y = 0.5d;
            if (!Arrays.asList(trainFrame.typesStr()).contains("Enum")) {
                params._init = GlrmInitialization.SVD;
            }
            params._regularization_x = GlrmRegularizer.Quadratic;
            params._regularization_y = GlrmRegularizer.Quadratic;
            params._recover_svd = true;

            // Reference model trained through the default code path.
            GLRMModel narrowModel = new GLRM(params).trainModel().get();
            // Track before scoring so the model is removed even if scoring throws.
            Scope.track_generic(narrowModel);
            Frame narrowScore = narrowModel.score(trainFrame);
            Scope.track(narrowScore);

            // Same parameters, forced through the wide-dataset code path.
            GLRM wideBuilder = new GLRM(params);
            wideBuilder.setWideDataset(true);
            GLRMModel wideModel = wideBuilder.trainModel().get();
            Scope.track_generic(wideModel);
            Frame wideScore = wideModel.score(trainFrame);
            Scope.track(wideScore);

            TestUtil.checkStddev(wideModel._output._std_deviation, narrowModel._output._std_deviation, 1.0E-6d);
            Assert.assertTrue(Arrays.equals(
                    TestUtil.checkEigvec(narrowModel._output._archetypes_raw._archetypes,
                            wideModel._output._archetypes_raw._archetypes, 1.0E-6d),
                    TestUtil.checkEigvec(wideModel._output._eigenvectors,
                            narrowModel._output._eigenvectors, 1.0E-6d)));
            Assert.assertTrue(TestUtil.isIdenticalUpToRelTolerance(wideScore, narrowScore, 1.0E-6d));
        } finally {
            Scope.exit();
        }
    }

    /**
     * Verifies that GLRM trained through the wide-dataset code path agrees with GLRM trained
     * through the default path on the decathlon dataset.
     *
     * <p>Both runs share one parameter set: k = 3, no transform, seed 12345, gamma_x = 1.0,
     * gamma_y = 0.5, quadratic X/Y regularization and SVD recovery. SVD initialization is
     * requested only when none of the frame's column types is "Enum". The standard deviations,
     * the {@code checkEigvec} results for archetypes and eigenvectors, and the scored frames are
     * compared with a tolerance of 1e-6.
     */
    @Test
    public void testWideDataSetGLRMDec() throws InterruptedException, ExecutionException {
        Scope.enter();
        try {
            Frame trainFrame = parse_test_file(Key.make("deacathlon"), "smalldata/pca_test/decathlon.csv");
            Scope.track(trainFrame);
            DKV.put(trainFrame);

            GLRMModel.GLRMParameters params = new GLRMModel.GLRMParameters();
            params._train = trainFrame._key;
            params._k = 3;
            params._transform = DataInfo.TransformType.NONE;
            params._seed = 12345L;
            params._gamma_x = 1.0d;
            params._gamma_y = 0.5d;
            if (!Arrays.asList(trainFrame.typesStr()).contains("Enum")) {
                params._init = GlrmInitialization.SVD;
            }
            params._regularization_x = GlrmRegularizer.Quadratic;
            params._regularization_y = GlrmRegularizer.Quadratic;
            params._recover_svd = true;

            // Reference model trained through the default code path.
            GLRMModel narrowModel = new GLRM(params).trainModel().get();
            // Track before scoring so the model is removed even if scoring throws.
            Scope.track_generic(narrowModel);
            Frame narrowScore = narrowModel.score(trainFrame);
            Scope.track(narrowScore);

            // Same parameters, forced through the wide-dataset code path.
            GLRM wideBuilder = new GLRM(params);
            wideBuilder.setWideDataset(true);
            GLRMModel wideModel = wideBuilder.trainModel().get();
            Scope.track_generic(wideModel);
            Frame wideScore = wideModel.score(trainFrame);
            Scope.track(wideScore);

            TestUtil.checkStddev(wideModel._output._std_deviation, narrowModel._output._std_deviation, 1.0E-6d);
            Assert.assertTrue(Arrays.equals(
                    TestUtil.checkEigvec(narrowModel._output._archetypes_raw._archetypes,
                            wideModel._output._archetypes_raw._archetypes, 1.0E-6d),
                    TestUtil.checkEigvec(wideModel._output._eigenvectors,
                            narrowModel._output._eigenvectors, 1.0E-6d)));
            Assert.assertTrue(TestUtil.isIdenticalUpToRelTolerance(wideScore, narrowScore, 1.0E-6d));
        } finally {
            Scope.exit();
        }
    }

    static {
        // Set the flag to true exactly when the JVM runs GLRMTest with assertions disabled
        // (the inverse of the class's desired assertion status).
        boolean assertionsEnabled = GLRMTest.class.desiredAssertionStatus();
        $assertionsDisabled = !assertionsEnabled;
    }
}
