/*
 * Decompiled with CFR 0.152.
 */
package hex.kmeans;

import hex.Model;
import hex.grid.Grid;
import hex.grid.GridSearch;
import hex.kmeans.KMeans;
import hex.kmeans.KMeansModel;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
import java.util.Random;
import java.util.Set;
import org.junit.Assert;
import org.junit.BeforeClass;
import org.junit.Ignore;
import org.junit.Test;
import water.DKV;
import water.Job;
import water.Key;
import water.Keyed;
import water.TestUtil;
import water.fvec.Frame;
import water.test.util.GridTestUtils;
import water.util.ArrayUtils;

/**
 * Grid-search tests for {@link KMeans}.
 *
 * <p>They check that a hyper-parameter grid yields a model or a recorded failure for every
 * combination. They check that identical combinations are built only once, and that
 * user-specified initial centers can be used inside a grid.
 */
public class KMeansGridTest extends TestUtil {

    /** Training data shared by all tests. */
    private static final String IRIS_FILE = "smalldata/iris/iris_wheader.csv";

    @BeforeClass
    public static void setup() {
        stall_till_cloudsize(1);
    }

    /**
     * Builds the three user-specified initial centers (3 rows x 4 columns) used by the
     * user-points tests.
     */
    private static Frame makeUserPoints() {
        return ArrayUtils.frame(ard(new double[][]{
            ard(new double[]{5.0, 3.4, 1.5, 0.2}),
            ard(new double[]{7.0, 3.2, 4.7, 1.4}),
            ard(new double[]{6.5, 3.0, 5.8, 2.2})}));
    }

    /**
     * Returns the first {@code count} elements of a copy of {@code values} shuffled with
     * {@code rng}. The input array is not modified.
     *
     * @param values candidate values
     * @param count  number of values to keep; must be in {@code [0, values.length]}
     * @param rng    source of randomness, so a seeded run is reproducible
     * @return a new array of the same component type as {@code values}
     */
    private static <T> T[] randomSubset(T[] values, int count, Random rng) {
        ArrayList<T> shuffled = new ArrayList<T>(Arrays.asList(values));
        Collections.shuffle(shuffled, rng);
        return shuffled.subList(0, count).toArray(Arrays.copyOf(values, 0));
    }

    /**
     * Runs a full grid over k, init and seed on iris.
     *
     * <p>k = 0 is illegal. Every combination that uses it must show up among the grid's
     * failures, and every legal combination must produce a model.
     */
    @Test
    public void testIrisGrid() {
        Grid grid = null;
        Frame fr = null;
        try {
            fr = parse_test_file(IRIS_FILE);
            HashMap<String, Object[]> hyperParms = new HashMap<String, Object[]>();
            Integer[] legalKOpts = new Integer[]{1, 2, 3, 4, 5};
            Integer[] illegalKOpts = new Integer[]{0};
            hyperParms.put("_k", ArrayUtils.join((Object[]) legalKOpts, (Object[]) illegalKOpts));
            hyperParms.put("_init", new KMeans.Initialization[]{
                KMeans.Initialization.Random, KMeans.Initialization.PlusPlus, KMeans.Initialization.Furthest});
            hyperParms.put("_seed", new Long[]{1L, 123456789L, 987654321L});
            String[] hyperParamNames = hyperParms.keySet().toArray(new String[hyperParms.size()]);
            Arrays.sort(hyperParamNames);
            int hyperSpaceSize = ArrayUtils.crossProductSize(hyperParms);

            KMeansModel.KMeansParameters params = new KMeansModel.KMeansParameters();
            params._train = fr._key;
            Job gs = GridSearch.startGridSearch(null, (Model.Parameters) params, hyperParms);
            grid = (Grid) gs.get();

            // Every point of the hyper space is either a built model or a recorded failure.
            Grid.SearchFailure failures = grid.getFailures();
            Assert.assertEquals("Size of grid should match to size of hyper space",
                (long) hyperSpaceSize, (long) (grid.getModelCount() + failures.getFailureCount()));

            // Sort a copy: getHyperNames() may expose the grid's own array, which is used
            // again below together with getHyperValues().
            Object[] gridHyperNames = grid.getHyperNames().clone();
            Arrays.sort(gridHyperNames);
            Assert.assertArrayEquals("Hyper parameters names should match!", (Object[]) hyperParamNames, gridHyperNames);

            // Collect the hyper values actually used by the successfully built models.
            Map<String, Set<Object>> usedModelParams = GridTestUtils.initMap(hyperParamNames);
            for (Model m : grid.getModels()) {
                KMeansModel kmm = (KMeansModel) m;
                System.out.println(((KMeansModel.KMeansOutput) kmm._output)._tot_withinss + " "
                    + Arrays.deepToString((Object[]) ArrayUtils.zip((Object[]) grid.getHyperNames(), (Object[]) grid.getHyperValues(kmm._parms))));
                GridTestUtils.extractParams(usedModelParams, kmm._parms, hyperParamNames);
            }
            // Built models must cover exactly the legal k values.
            hyperParms.put("_k", legalKOpts);
            GridTestUtils.assertParamsEqual("Grid models parameters have to cover specified hyper space", hyperParms, usedModelParams);

            // Failures must cover exactly the illegal k values.
            Map<String, Set<Object>> failedHyperParams = GridTestUtils.initMap(hyperParamNames);
            for (Model.Parameters failedParams : failures.getFailedParameters()) {
                GridTestUtils.extractParams(failedHyperParams, (KMeansModel.KMeansParameters) failedParams, hyperParamNames);
            }
            hyperParms.put("_k", illegalKOpts);
            GridTestUtils.assertParamsEqual("Failed model parameters have to correspond to specified hyper space", hyperParms, failedHyperParams);
        } finally {
            if (fr != null) {
                fr.remove();
            }
            if (grid != null) {
                grid.remove();
            }
        }
    }

    /**
     * Supplies the same (k, init, seed) combination three times.
     *
     * <p>The grid must build it only once, so every returned model has to share a single key.
     */
    @Test
    public void testDuplicatesCarsGrid() {
        Grid grid = null;
        Frame fr = null;
        try {
            fr = parse_test_file(IRIS_FILE);
            // Drop the "class" column, remove the returned column object, and re-publish the frame.
            fr.remove("class").remove();
            DKV.put((Keyed) fr);
            HashMap<String, Object[]> hyperParms = new HashMap<String, Object[]>();
            hyperParms.put("_k", new Integer[]{3, 3, 3});
            hyperParms.put("_init", new KMeans.Initialization[]{
                KMeans.Initialization.Random, KMeans.Initialization.Random, KMeans.Initialization.Random});
            hyperParms.put("_seed", new Long[]{123456789L, 123456789L, 123456789L});
            KMeansModel.KMeansParameters params = new KMeansModel.KMeansParameters();
            params._train = fr._key;
            Job gs = GridSearch.startGridSearch(null, (Model.Parameters) params, hyperParms);
            grid = (Grid) gs.get();

            Model[] models = grid.getModels();
            Assert.assertTrue("Number of returned models has to be > 0", models.length > 0);
            Key modelKey = models[0]._key;
            for (Model m : models) {
                // Compare keys by value rather than by reference identity.
                Assert.assertEquals("Number of constructed models has to be equal to 1", modelKey, m._key);
            }
        } finally {
            if (fr != null) {
                fr.remove();
            }
            if (grid != null) {
                grid.remove();
            }
        }
    }

    /**
     * Grids over all four initialization modes with a fixed k and seed, with user points
     * supplied. Expects exactly one model per initialization mode.
     */
    @Test
    public void testUserPointsCarsGrid() {
        Grid grid = null;
        Frame fr = null;
        Frame init = makeUserPoints();
        try {
            fr = parse_test_file(IRIS_FILE);
            fr.remove("class").remove();
            DKV.put((Keyed) fr);
            HashMap<String, Object[]> hyperParms = new HashMap<String, Object[]>();
            hyperParms.put("_k", new Integer[]{3});
            hyperParms.put("_init", new KMeans.Initialization[]{
                KMeans.Initialization.Random, KMeans.Initialization.PlusPlus,
                KMeans.Initialization.User, KMeans.Initialization.Furthest});
            hyperParms.put("_seed", new Long[]{123456789L});
            KMeansModel.KMeansParameters params = new KMeansModel.KMeansParameters();
            params._train = fr._key;
            // Set for every combination in the grid, including the non-User initializations.
            params._user_points = init._key;
            Job gs = GridSearch.startGridSearch(null, (Model.Parameters) params, hyperParms);
            grid = (Grid) gs.get();

            int numModels = grid.getModels().length;
            System.out.println("Grid consists of " + numModels + " models");
            Assert.assertEquals("Grid should contain one model per initialization mode", 4L, (long) numModels);
        } finally {
            if (fr != null) {
                fr.remove();
            }
            if (init != null) {
                init.remove();
            }
            if (grid != null) {
                grid.remove();
            }
        }
    }

    /**
     * Builds a grid over a random subset of k, init, seed and standardize values and checks
     * that every combination produced a model. It then retrains one randomly chosen
     * combination directly.
     */
    @Test
    @Ignore(value="PUBDEV-1675")
    public void testRandomCarsGrid() {
        Grid grid = null;
        KMeansModel kmRebuilt = null;
        Frame fr = null;
        Frame init = makeUserPoints();
        try {
            fr = parse_test_file(IRIS_FILE);
            fr.remove("class").remove();
            DKV.put((Keyed) fr);

            // Derive all randomness from one logged seed so a failing run can be reproduced.
            long testSeed = new Random().nextLong();
            System.out.println("Random seed: " + testSeed);
            Random rng = new Random(testSeed);

            Integer[] kArr = new Integer[50];
            for (int i = 0; i < kArr.length; ++i) {
                kArr[i] = i + 1;
            }
            Integer[] kSpace = randomSubset(kArr, rng.nextInt(4) + 1, rng);
            KMeans.Initialization[] initSpace = randomSubset(new KMeans.Initialization[]{
                KMeans.Initialization.Random, KMeans.Initialization.User,
                KMeans.Initialization.PlusPlus, KMeans.Initialization.Furthest}, rng.nextInt(4) + 1, rng);
            Long[] seedSpace = randomSubset(new Long[]{0L, 1L, 123456789L, 987654321L}, rng.nextInt(4) + 1, rng);
            // _standardize is assigned a boolean below, so the grid gets Boolean values rather than 0/1.
            Boolean[] standardizeSpace = randomSubset(new Boolean[]{true, false}, rng.nextInt(2) + 1, rng);

            HashMap<String, Object[]> hyperParms = new HashMap<String, Object[]>();
            hyperParms.put("_k", kSpace);
            hyperParms.put("_init", initSpace);
            hyperParms.put("_seed", seedSpace);
            hyperParms.put("_standardize", standardizeSpace);
            System.out.println("k search space: " + Arrays.toString(kSpace));
            System.out.println("init search space: " + Arrays.toString(initSpace));
            System.out.println("seed search space: " + Arrays.toString(seedSpace));
            System.out.println("standardize search space: " + Arrays.toString(standardizeSpace));

            KMeansModel.KMeansParameters params = new KMeansModel.KMeansParameters();
            params._train = fr._key;
            if (Arrays.asList(initSpace).contains(KMeans.Initialization.User)) {
                // NOTE(review): init has 3 rows; User init combined with k != 3 may be rejected. Confirm.
                params._user_points = init._key;
            }
            Job gs = GridSearch.startGridSearch(null, (Model.Parameters) params, hyperParms);
            grid = (Grid) gs.get();

            int numModels = grid.getModels().length;
            System.out.println("Grid consists of " + numModels + " models");
            int expectedModels = kSpace.length * initSpace.length * seedSpace.length * standardizeSpace.length;
            Assert.assertEquals("Grid should contain one model per hyper-parameter combination",
                (long) expectedModels, (long) numModels);

            // Retrain one randomly chosen combination outside the grid.
            params._k = kSpace[rng.nextInt(kSpace.length)];
            params._init = initSpace[rng.nextInt(initSpace.length)];
            params._seed = seedSpace[rng.nextInt(seedSpace.length)];
            params._standardize = standardizeSpace[rng.nextInt(standardizeSpace.length)];
            kmRebuilt = (KMeansModel) new KMeans(params).trainModel().get();
            double rebuiltBetweenss = ((KMeansModel.KMeansOutput) kmRebuilt._output)._betweenss;
            System.out.println("The rebuilt model's betweenss: " + rebuiltBetweenss);
            // TODO: compare against the grid's model for the same hyper-parameters once it can be looked up.
        } finally {
            if (fr != null) {
                fr.remove();
            }
            if (grid != null) {
                grid.remove();
            }
            if (kmRebuilt != null) {
                kmRebuilt.remove();
            }
            if (init != null) {
                init.remove();
            }
        }
    }
}

