package hex.glrm;

import hex.DataInfo;
import hex.genmodel.algos.glrm.GlrmInitialization;
import hex.genmodel.algos.glrm.GlrmLoss;
import hex.glrm.GLRMModel;
import hex.grid.Grid;
import hex.grid.GridSearch;
import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;
import org.junit.Assert;
import org.junit.BeforeClass;
import org.junit.Test;
import water.Key;
import water.TestUtil;
import water.fvec.Frame;
import water.util.ArrayUtils;

/**
 * Grid-search tests for GLRM on the iris dataset.
 *
 * <p>Each test builds a hyper-parameter space over {@code _k} and {@code _transform}, runs grid
 * searches into a single destination grid key, and checks the resulting grid's size,
 * hyper-parameter names and model keys.
 */
public class GLRMGridTest extends TestUtil {

    /** Prefix of the destination grid key; each test appends {@code Key.rand()} to it. */
    private static final String GRID_KEY_PREFIX = "GLRM_grid_iris";

    @BeforeClass
    public static void setup() {
        stall_till_cloudsize(1);
    }

    /**
     * Runs the same grid search twice into the same grid key. Checks that each run yields a grid of
     * the expected size with the expected hyper-parameter names, and that both runs report
     * identical model keys.
     */
    @Test
    public void testMultipleGridInvocation() {
        Grid grid = null;
        Frame frame = null;
        try {
            frame = parse_test_file("smalldata/iris/iris_wheader.csv");
            HashMap<String, Object[]> hyperParms = irisHyperSpace();
            String[] expectedNames = sortedKeys(hyperParms);
            int expectedSize = ArrayUtils.crossProductSize(hyperParms);
            GLRMModel.GLRMParameters params = irisParams(frame);

            Key<Grid> gridKey = Key.make(GRID_KEY_PREFIX + Key.rand());
            // One array of model keys per invocation; the two arrays are compared after the loop.
            Key<?>[][] modelKeys = new Key<?>[2][];
            for (int i = 0; i < modelKeys.length; i++) {
                grid = GridSearch.startGridSearch(gridKey, params, hyperParms).get();
                modelKeys[i] = grid.getModelKeys();
                assertGridShape(grid, expectedSize, expectedNames);
            }
            Assert.assertArrayEquals("The model keys should be same between two iterations!",
                    modelKeys[0], modelKeys[1]);
        } finally {
            cleanup(frame, grid);
        }
    }

    /**
     * Runs a grid search, then runs a second one into the same grid key with a new {@code _k} value.
     * Checks that the grid then accounts for both hyper spaces and that its model names are unique.
     */
    @Test
    public void testGridAppend() {
        Grid grid = null;
        Frame frame = null;
        try {
            frame = parse_test_file("smalldata/iris/iris_wheader.csv");
            HashMap<String, Object[]> hyperParms = irisHyperSpace();
            String[] initialNames = sortedKeys(hyperParms);
            int initialSize = ArrayUtils.crossProductSize(hyperParms);
            GLRMModel.GLRMParameters params = irisParams(frame);
            Key<Grid> gridKey = Key.make(GRID_KEY_PREFIX + Key.rand());

            // Assign to `grid` right away so the finally block removes it even if the checks below fail.
            grid = GridSearch.startGridSearch(gridKey, params, hyperParms).get();
            assertGridShape(grid, initialSize, initialNames);

            // Replace _k with a value not in the first space; the set of hyper-parameter names is unchanged.
            hyperParms.put("_k", new Integer[]{3});
            String[] appendedNames = sortedKeys(hyperParms);
            int appendedSize = ArrayUtils.crossProductSize(hyperParms);
            Assert.assertArrayEquals("Names of hyperspaces should be same!", initialNames, appendedNames);

            grid = GridSearch.startGridSearch(gridKey, params, hyperParms).get();
            assertGridShape(grid, initialSize + appendedSize, appendedNames);

            HashSet<String> modelNames = new HashSet<>(grid.getModelCount());
            for (Key<?> key : grid.getModelKeys()) {
                modelNames.add(key.toString());
            }
            Assert.assertEquals("Model names should be unique!", grid.getModelCount(), modelNames.size());
        } finally {
            cleanup(frame, grid);
        }
    }

    /** Builds a fresh, mutable hyper space of {@code _k} x {@code _transform} (2 x 2 points). */
    private static HashMap<String, Object[]> irisHyperSpace() {
        HashMap<String, Object[]> hyperParms = new HashMap<>();
        hyperParms.put("_k", new Integer[]{2, 4});
        hyperParms.put("_transform", new DataInfo.TransformType[]{
                DataInfo.TransformType.NONE, DataInfo.TransformType.DEMEAN});
        return hyperParms;
    }

    /** Builds the fixed GLRM parameters (seed, absolute loss, SVD init) trained on {@code frame}. */
    private static GLRMModel.GLRMParameters irisParams(Frame frame) {
        GLRMModel.GLRMParameters params = new GLRMModel.GLRMParameters();
        params._train = frame._key;
        params._seed = 4224L;
        params._loss = GlrmLoss.Absolute;
        params._init = GlrmInitialization.SVD;
        return params;
    }

    /** Returns the hyper-parameter names of {@code hyperParms}, sorted. */
    private static String[] sortedKeys(HashMap<String, Object[]> hyperParms) {
        String[] names = hyperParms.keySet().toArray(new String[0]);
        Arrays.sort(names);
        return names;
    }

    /** Returns a sorted copy; the input is left untouched in case it is shared by its owner. */
    private static String[] sortedCopy(String[] names) {
        String[] copy = Arrays.copyOf(names, names.length);
        Arrays.sort(copy);
        return copy;
    }

    /**
     * Asserts that the grid holds one entry (built model or recorded failure) per hyper-space point,
     * and that its hyper-parameter names match {@code expectedNames}, ignoring order.
     */
    private static void assertGridShape(Grid grid, int expectedSize, String[] expectedNames) {
        Assert.assertEquals("Size of grid should match to size of hyper space", expectedSize,
                grid.getModelCount() + grid.getFailures().getFailureCount());
        Assert.assertArrayEquals("Hyper parameters names should match!", expectedNames,
                sortedCopy(grid.getHyperNames()));
    }

    /** Removes the frame and the grid, if present; the grid is removed even if frame removal throws. */
    private static void cleanup(Frame frame, Grid grid) {
        try {
            if (frame != null) {
                frame.remove();
            }
        } finally {
            if (grid != null) {
                grid.remove();
            }
        }
    }
}
