/*
 * Decompiled with CFR 0.152.
 */
package hex.tree.gbm;

import hex.Model;
import hex.genmodel.utils.DistributionFamily;
import hex.grid.Grid;
import hex.grid.GridSearch;
import hex.tree.gbm.GBM;
import hex.tree.gbm.GBMModel;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
import java.util.Random;
import java.util.Set;
import org.junit.Assert;
import org.junit.BeforeClass;
import org.junit.Test;
import water.DKV;
import water.Job;
import water.Key;
import water.Keyed;
import water.TestUtil;
import water.fvec.Frame;
import water.fvec.Vec;
import water.test.util.GridTestUtils;
import water.util.ArrayUtils;

/**
 * Grid-search tests for GBM, run against the small "cars" JUnit data sets.
 *
 * <p>Every test follows the same shape: parse a frame, move the response column to the end,
 * publish the frame to the DKV, launch a grid search through {@code GridSearch.startGridSearch},
 * check the resulting {@link Grid}, and finally remove every keyed object the test created.
 */
public class GBMGridTest
extends TestUtil {
    /**
     * Calls the inherited {@code stall_till_cloudsize(1)} once before any test in this class runs.
     */
    @BeforeClass
    public static void setup() {
        stall_till_cloudsize(1);
    }

    /**
     * Runs a multinomial grid over {@code cars.csv} whose learn-rate axis mixes legal values with
     * one illegal value ({@code -1.0}).
     *
     * <p>Checks that models plus failures add up to the size of the hyper space, that the grid
     * reports the requested hyper-parameter names, that the built models cover the hyper space
     * restricted to the legal learn rates, and that the recorded failures cover the hyper space
     * restricted to the illegal learn rate.
     */
    @Test
    public void testCarsGrid() {
        Grid grid = null;
        Frame fr = null;
        Vec old = null;
        try {
            fr = parse_test_file("smalldata/junit/cars.csv");
            fr.remove("name").remove();
            // Swap "cylinders" for its categorical version, appended as the last column. The
            // original Vec is no longer part of the frame, so it is removed on its own below.
            old = fr.remove("cylinders");
            fr.add("cylinders", old.toCategoricalVec());
            DKV.put(fr);

            Double[] legalLearnRateOpts = new Double[]{0.01, 0.1, 0.3};
            Double[] illegalLearnRateOpts = new Double[]{-1.0};
            HashMap<String, Object[]> hyperParms = new HashMap<String, Object[]>();
            hyperParms.put("_ntrees", new Integer[]{1, 2});
            hyperParms.put("_distribution", new DistributionFamily[]{DistributionFamily.multinomial});
            hyperParms.put("_max_depth", new Integer[]{1, 2, 5});
            hyperParms.put("_learn_rate", ArrayUtils.join(legalLearnRateOpts, illegalLearnRateOpts));

            // Sorted so the names can be compared with the (sorted) names reported by the grid.
            String[] hyperParamNames = hyperParms.keySet().toArray(new String[hyperParms.size()]);
            Arrays.sort(hyperParamNames);
            int hyperSpaceSize = ArrayUtils.crossProductSize(hyperParms);

            GBMModel.GBMParameters params = new GBMModel.GBMParameters();
            params._train = fr._key;
            params._response_column = "cylinders";

            Job gs = GridSearch.startGridSearch(null, params, hyperParms);
            grid = (Grid)gs.get();

            Grid.SearchFailure failures = grid.getFailures();
            Assert.assertEquals("Size of grid (models+failures) should match to size of hyper space",
                    hyperSpaceSize, grid.getModelCount() + failures.getFailureCount());

            String[] gridHyperNames = grid.getHyperNames();
            Arrays.sort(gridHyperNames);
            Assert.assertArrayEquals("Hyper parameters names should match!", hyperParamNames, gridHyperNames);

            // Collect the hyper-parameter values actually used by the built models.
            Map<String, Set<Object>> usedHyperParams = GridTestUtils.initMap(hyperParamNames);
            for (Key mKey : grid.getModelKeys()) {
                GBMModel gbm = (GBMModel)mKey.get();
                GBMModel.GBMOutput output = (GBMModel.GBMOutput)gbm._output;
                System.out.println(output._scored_train[output._ntrees]._mse + " "
                        + Arrays.deepToString(ArrayUtils.zip(grid.getHyperNames(), grid.getHyperValues(gbm._parms))));
                GridTestUtils.extractParams(usedHyperParams, gbm._parms, hyperParamNames);
            }
            // The search is finished, so the map is reused as the expected value:
            // first restricted to the legal learn rates ...
            hyperParms.put("_learn_rate", legalLearnRateOpts);
            GridTestUtils.assertParamsEqual("Grid models parameters have to cover specified hyper space", hyperParms, usedHyperParams);

            Map<String, Set<Object>> failedHyperParams = GridTestUtils.initMap(hyperParamNames);
            for (Model.Parameters failedParams : failures.getFailedParameters()) {
                GridTestUtils.extractParams(failedHyperParams, failedParams, hyperParamNames);
            }
            // ... then restricted to the illegal learn rate for the recorded failures.
            hyperParms.put("_learn_rate", illegalLearnRateOpts);
            GridTestUtils.assertParamsEqual("Failed model parameters have to correspond to specified hyper space", hyperParms, failedHyperParams);
        }
        finally {
            if (old != null) {
                old.remove();
            }
            if (fr != null) {
                fr.remove();
            }
            if (grid != null) {
                grid.remove();
            }
        }
    }

    /**
     * Runs a gaussian grid over {@code cars_20mpg.csv} in which every hyper-parameter value is
     * listed twice, and checks that all models returned by the grid share one and the same key.
     */
    @Test
    public void testDuplicatesCarsGrid() {
        Grid grid = null;
        Frame fr = null;
        Vec old = null;
        try {
            fr = parse_test_file("smalldata/junit/cars_20mpg.csv");
            fr.remove("name").remove();
            // Remove and re-add "economy" so the response ends up as the last column.
            old = fr.remove("economy");
            fr.add("economy", old);
            DKV.put(fr);

            // Each axis repeats a single value, so every combination describes the same model.
            HashMap<String, Object[]> hyperParms = new HashMap<String, Object[]>();
            hyperParms.put("_distribution", new DistributionFamily[]{DistributionFamily.gaussian});
            hyperParms.put("_ntrees", new Integer[]{5, 5});
            hyperParms.put("_max_depth", new Integer[]{2, 2});
            hyperParms.put("_learn_rate", new Double[]{0.1, 0.1});

            GBMModel.GBMParameters params = new GBMModel.GBMParameters();
            params._train = fr._key;
            params._response_column = "economy";

            Job gs = GridSearch.startGridSearch(null, params, hyperParms);
            grid = (Grid)gs.get();

            Model[] models = grid.getModels();
            Assert.assertTrue("Number of returned models has to be > 0", models.length > 0);
            // Compare keys by value (equals), not by object identity.
            Key modelKey = models[0]._key;
            for (Model m : models) {
                Assert.assertEquals("Number of constructed models has to be equal to 1", modelKey, m._key);
            }
        }
        finally {
            if (old != null) {
                old.remove();
            }
            if (fr != null) {
                fr.remove();
            }
            if (grid != null) {
                grid.remove();
            }
        }
    }

    /**
     * Runs the same single-point grid search twice into the same destination key
     * ({@code "accumulating_grid"}) and checks that the grid still holds exactly one model.
     */
    @Test
    public void testGridAccumulation() {
        Grid grid = null;
        Frame fr = null;
        Vec old = null;
        try {
            fr = parse_test_file("smalldata/junit/cars_20mpg.csv");
            fr.remove("name").remove();
            // Remove and re-add "economy" so the response ends up as the last column.
            old = fr.remove("economy");
            fr.add("economy", old);
            DKV.put(fr);

            HashMap<String, Object[]> hyperParms = new HashMap<String, Object[]>();
            hyperParms.put("_distribution", new DistributionFamily[]{DistributionFamily.gaussian});
            hyperParms.put("_ntrees", new Integer[]{2});
            hyperParms.put("_max_depth", new Integer[]{2});
            hyperParms.put("_learn_rate", new Double[]{0.1});

            GBMModel.GBMParameters params = new GBMModel.GBMParameters();
            params._train = fr._key;
            params._response_column = "economy";

            Key accumulatingGridKey = Key.make("accumulating_grid");
            // First search populates the grid stored under the key ...
            Job gs = GridSearch.startGridSearch(accumulatingGridKey, params, hyperParms);
            grid = (Grid)gs.get();
            // ... the second, identical search targets the same key.
            gs = GridSearch.startGridSearch(accumulatingGridKey, params, hyperParms);
            grid = (Grid)gs.get();

            Model[] models = grid.getModels();
            Assert.assertEquals("Number of returned models has to be 1", 1, models.length);
        }
        finally {
            if (old != null) {
                old.remove();
            }
            if (fr != null) {
                fr.remove();
            }
            if (grid != null) {
                grid.remove();
            }
        }
    }

    /**
     * Runs a gaussian grid over {@code cars.csv} with a randomly sized and randomly chosen search
     * space (1 to 4 values per axis for {@code _ntrees}, {@code _max_depth}, {@code _learn_rate}).
     *
     * <p>Checks that the grid holds one model per combination, then rebuilds a model with one
     * randomly picked combination directly through {@link GBM} and prints its training MSE.
     * The random seed is printed so that a failing run can be reproduced.
     */
    @Test
    public void testRandomCarsGrid() {
        Grid grid = null;
        GBMModel gbmRebuilt = null;
        Frame fr = null;
        Vec old = null;
        try {
            fr = parse_test_file("smalldata/junit/cars.csv");
            fr.remove("name").remove();
            // Remove and re-add the response so it ends up as the last column.
            old = fr.remove("economy (mpg)");
            fr.add("economy (mpg)", old);
            DKV.put(fr);

            // One logged seed drives every random decision below.
            long seed = new Random().nextLong();
            System.out.println("testRandomCarsGrid random seed: " + seed);
            Random rng = new Random(seed);

            int ntreesDim = rng.nextInt(4) + 1;
            int maxDepthDim = rng.nextInt(4) + 1;
            int learnRateDim = rng.nextInt(4) + 1;

            Integer[] ntreesSpace = pickRandomSubset(
                    ArrayUtils.interval(Integer.valueOf(1), Integer.valueOf(25)), ntreesDim, rng);
            Integer[] maxDepthSpace = pickRandomSubset(
                    ArrayUtils.interval(Integer.valueOf(1), Integer.valueOf(10)), maxDepthDim, rng);
            Double[] learnRateSpace = pickRandomSubset(
                    ArrayUtils.interval(Double.valueOf(0.01), Double.valueOf(1.0), Double.valueOf(0.01)), learnRateDim, rng);

            HashMap<String, Object[]> hyperParms = new HashMap<String, Object[]>();
            hyperParms.put("_distribution", new DistributionFamily[]{DistributionFamily.gaussian});
            hyperParms.put("_ntrees", ntreesSpace);
            hyperParms.put("_max_depth", maxDepthSpace);
            hyperParms.put("_learn_rate", learnRateSpace);

            GBMModel.GBMParameters params = new GBMModel.GBMParameters();
            params._train = fr._key;
            params._response_column = "economy (mpg)";

            Job gs = GridSearch.startGridSearch(null, params, hyperParms);
            grid = (Grid)gs.get();

            System.out.println("ntrees search space: " + Arrays.toString(ntreesSpace));
            System.out.println("max_depth search space: " + Arrays.toString(maxDepthSpace));
            System.out.println("learn_rate search space: " + Arrays.toString(learnRateSpace));

            Model[] ms = grid.getModels();
            int numModels = ms.length;
            System.out.println("Grid consists of " + numModels + " models");
            Assert.assertEquals("Grid should hold one model per hyper-parameter combination",
                    ntreesDim * maxDepthDim * learnRateDim, numModels);

            // Rebuild one randomly picked combination outside of the grid.
            // The grid's parameter object is reused, as before.
            int ntreeVal = ntreesSpace[rng.nextInt(ntreesSpace.length)];
            int maxDepthVal = maxDepthSpace[rng.nextInt(maxDepthSpace.length)];
            double learnRateVal = learnRateSpace[rng.nextInt(learnRateSpace.length)];
            params._distribution = DistributionFamily.gaussian;
            params._ntrees = ntreeVal;
            params._max_depth = maxDepthVal;
            params._learn_rate = learnRateVal;

            GBM gbm = new GBM(params);
            gbmRebuilt = (GBMModel)gbm.trainModel().get();
            Assert.assertTrue("GBM builder should be stopped after training", gbm.isStopped());

            GBMModel.GBMOutput rebuiltOutput = (GBMModel.GBMOutput)gbmRebuilt._output;
            double rebuiltMSE = rebuiltOutput._scored_train[rebuiltOutput._ntrees]._mse;
            System.out.println("The rebuilt model's MSE: " + rebuiltMSE);
        }
        finally {
            if (old != null) {
                old.remove();
            }
            if (fr != null) {
                fr.remove();
            }
            if (grid != null) {
                grid.remove();
            }
            if (gbmRebuilt != null) {
                gbmRebuilt.remove();
            }
        }
    }

    /**
     * Returns {@code count} distinct positions of {@code pool}, picked by shuffling a copy of the
     * pool with {@code rng} and keeping its first {@code count} elements.
     *
     * @param pool  candidate values; not modified
     * @param count number of values to pick, between 0 and {@code pool.length}
     * @param rng   random source used for the shuffle
     * @return a new array of length {@code count} with the same runtime type as {@code pool}
     * @throws IllegalArgumentException if {@code count} is negative or exceeds {@code pool.length}
     */
    private static <T> T[] pickRandomSubset(T[] pool, int count, Random rng) {
        if (count < 0 || count > pool.length) {
            throw new IllegalArgumentException("Cannot pick " + count + " values out of " + pool.length);
        }
        T[] shuffled = pool.clone();
        // Arrays.asList is a write-through view, so shuffling the list shuffles the array copy.
        Collections.shuffle(Arrays.asList(shuffled), rng);
        return Arrays.copyOf(shuffled, count);
    }
}

