/*
 * Decompiled with CFR 0.152.
 */
package hex.tree.drf;

import hex.Model;
import hex.grid.Grid;
import hex.grid.GridSearch;
import hex.tree.drf.DRF;
import hex.tree.drf.DRFModel;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
import java.util.Random;
import java.util.Set;
import org.junit.Assert;
import org.junit.BeforeClass;
import org.junit.Test;
import water.DKV;
import water.Job;
import water.Key;
import water.Keyed;
import water.TestUtil;
import water.fvec.Frame;
import water.fvec.Vec;
import water.test.util.GridTestUtils;
import water.util.ArrayUtils;

/**
 * Grid-search tests for {@link DRF}. Covers an exhaustive hyper space containing an illegal
 * value, a hyper space made only of duplicated values, a randomly sized hyper space, and a
 * parameter-checksum collision check.
 */
public class DRFGridTest extends TestUtil {

    /** Waits until a one-node cloud is available before any test runs. */
    @BeforeClass
    public static void setup() {
        stall_till_cloudsize(1);
    }

    /**
     * Runs a Cartesian grid over ntrees, max_depth, mtries and sample_rate on cars.csv.
     *
     * <p>The response is a categorical copy of "cylinders". The sample_rate dimension holds one
     * legal value (0.5) and one illegal value (2.0). The test checks three things: every point of
     * the hyper space either produced a model or was recorded as a failure; successful models
     * cover exactly the legal combinations; failures cover exactly the illegal ones.
     */
    @Test
    public void testCarsGrid() {
        Grid grid = null;
        Frame fr = null;
        Vec old = null;
        try {
            fr = parse_test_file("smalldata/junit/cars.csv");
            fr.remove("name").remove();
            // Replace the response with a categorical copy. The original Vec is no longer part
            // of the frame, so it is removed separately in the finally block.
            old = fr.remove("cylinders");
            fr.add("cylinders", old.toCategoricalVec());
            DKV.put(fr);

            final Double[] legalSampleRateOpts = new Double[]{0.5};
            final Double[] illegalSampleRateOpts = new Double[]{2.0};
            Map<String, Object[]> hyperParms = new HashMap<String, Object[]>();
            hyperParms.put("_ntrees", new Integer[]{2, 4});
            hyperParms.put("_max_depth", new Integer[]{10, 20});
            hyperParms.put("_mtries", new Integer[]{-1, 4});
            hyperParms.put("_sample_rate", ArrayUtils.join(legalSampleRateOpts, illegalSampleRateOpts));

            String[] hyperParamNames = hyperParms.keySet().toArray(new String[hyperParms.size()]);
            Arrays.sort(hyperParamNames);
            int hyperSpaceSize = ArrayUtils.crossProductSize(hyperParms);

            DRFModel.DRFParameters params = new DRFModel.DRFParameters();
            params._train = fr._key;
            params._response_column = "cylinders";
            Job gs = GridSearch.startGridSearch(null, params, hyperParms);
            grid = (Grid) gs.get();

            Grid.SearchFailure failures = grid.getFailures();
            Assert.assertEquals("Size of grid should match to size of hyper space",
                    hyperSpaceSize, grid.getModelCount() + failures.getFailureCount());

            // Sort a copy: getHyperNames() may hand out the grid's own array. Sorting it in
            // place would silently reorder the grid's state.
            String[] gridHyperNames = grid.getHyperNames().clone();
            Arrays.sort(gridHyperNames);
            Assert.assertArrayEquals("Hyper parameters names should match!", hyperParamNames, gridHyperNames);

            Model[] ms = grid.getModels();
            Map<String, Set<Object>> usedModelParams = GridTestUtils.initMap(hyperParamNames);
            for (Model m : ms) {
                DRFModel drf = (DRFModel) m;
                DRFModel.DRFOutput out = (DRFModel.DRFOutput) drf._output;
                System.out.println(out._scored_train[out._ntrees]._mse + " "
                        + Arrays.deepToString((Object[]) ArrayUtils.zip(grid.getHyperNames(), grid.getHyperValues(drf._parms))));
                GridTestUtils.extractParams(usedModelParams, drf._parms, hyperParamNames);
            }
            // Built models must span exactly the legal part of the hyper space...
            hyperParms.put("_sample_rate", legalSampleRateOpts);
            GridTestUtils.assertParamsEqual("Grid models parameters have to cover specified hyper space",
                    hyperParms, usedModelParams);

            Map<String, Set<Object>> failedHyperParams = GridTestUtils.initMap(hyperParamNames);
            for (Model.Parameters failedParams : failures.getFailedParameters()) {
                GridTestUtils.extractParams(failedHyperParams, failedParams, hyperParamNames);
            }
            // ...and the recorded failures exactly the illegal part.
            hyperParms.put("_sample_rate", illegalSampleRateOpts);
            GridTestUtils.assertParamsEqual("Failed model parameters have to correspond to specified hyper space",
                    hyperParms, failedHyperParams);
        } finally {
            if (old != null) {
                old.remove();
            }
            if (fr != null) {
                fr.remove();
            }
            if (grid != null) {
                grid.remove();
            }
        }
    }

    /**
     * Runs a grid whose every dimension lists the same value twice, on cars_20mpg.csv with
     * "economy" as the response. Asserts that at least one model is returned and that all
     * returned models share one model key.
     */
    @Test
    public void testDuplicatesCarsGrid() {
        Grid grid = null;
        Frame fr = null;
        try {
            fr = parse_test_file("smalldata/junit/cars_20mpg.csv");
            fr.remove("name").remove();
            // Move the response to the last column. The Vec stays owned by the frame, so
            // fr.remove() below cleans it up; it must not be removed a second time.
            Vec old = fr.remove("economy");
            fr.add("economy", old);
            DKV.put(fr);

            Map<String, Object[]> hyperParms = new HashMap<String, Object[]>();
            hyperParms.put("_ntrees", new Integer[]{5, 5});
            hyperParms.put("_max_depth", new Integer[]{2, 2});
            hyperParms.put("_mtries", new Integer[]{-1, -1});
            hyperParms.put("_sample_rate", new Double[]{0.1, 0.1});

            DRFModel.DRFParameters params = new DRFModel.DRFParameters();
            params._train = fr._key;
            params._response_column = "economy";
            Job gs = GridSearch.startGridSearch(null, params, hyperParms);
            grid = (Grid) gs.get();

            Model[] models = grid.getModels();
            Assert.assertTrue("Number of returned models has to be > 0", models.length > 0);
            Key modelKey = models[0]._key;
            for (Model m : models) {
                // Compare keys by value, not by reference.
                Assert.assertEquals("Number of constructed models has to be equal to 1", modelKey, m._key);
            }
        } finally {
            if (fr != null) {
                fr.remove();
            }
            if (grid != null) {
                grid.remove();
            }
        }
    }

    /**
     * Builds a randomly sized hyper space with 1 to 3 values per dimension for ntrees,
     * max_depth, mtries and sample_rate. It then runs a grid search on cars.csv with
     * "economy (mpg)" as the response.
     *
     * <p>The test checks that every point of the space either produced a model or was recorded
     * as a failure. It then trains one stand-alone DRF with a randomly chosen point of that space
     * and prints its training MSE.
     *
     * <p>All random choices (dimension sizes, values, the chosen point) come from one
     * {@link Random} whose seed is printed first, so the hyper space can be reproduced.
     */
    @Test
    public void testRandomCarsGrid() {
        Grid grid = null;
        DRFModel drfRebuilt = null;
        Frame fr = null;
        try {
            fr = parse_test_file("smalldata/junit/cars.csv");
            fr.remove("name").remove();
            // Move the response to the last column; the Vec stays owned by the frame.
            Vec old = fr.remove("economy (mpg)");
            fr.add("economy (mpg)", old);
            DKV.put(fr);

            long seed = System.nanoTime();
            // Print the seed before anything can fail so the run is reproducible.
            System.out.println("Test seed: " + seed);
            Random rng = new Random(seed);
            int ntreesDim = rng.nextInt(3) + 1;
            int maxDepthDim = rng.nextInt(3) + 1;
            int mtriesDim = rng.nextInt(3) + 1;
            int sampleRateDim = rng.nextInt(3) + 1;

            Integer[] ntreesSpace = randomSubset(ArrayUtils.interval((Integer) 1, (Integer) 15), ntreesDim, rng);
            Integer[] maxDepthSpace = randomSubset(ArrayUtils.interval((Integer) 1, (Integer) 10), maxDepthDim, rng);
            Integer[] mtriesSpace = randomSubset(ArrayUtils.interval((Integer) 1, (Integer) 5), mtriesDim, rng);
            Double[] sampleRateSpace = randomSubset(
                    ArrayUtils.interval((Double) 0.01, (Double) 0.99, (Double) 0.01), sampleRateDim, rng);
            System.out.println("ntrees search space: " + Arrays.toString(ntreesSpace));
            System.out.println("max_depth search space: " + Arrays.toString(maxDepthSpace));
            System.out.println("mtries search space: " + Arrays.toString(mtriesSpace));
            System.out.println("sample_rate search space: " + Arrays.toString(sampleRateSpace));

            Map<String, Object[]> hyperParms = new HashMap<String, Object[]>();
            hyperParms.put("_ntrees", ntreesSpace);
            hyperParms.put("_max_depth", maxDepthSpace);
            hyperParms.put("_mtries", mtriesSpace);
            hyperParms.put("_sample_rate", sampleRateSpace);

            DRFModel.DRFParameters params = new DRFModel.DRFParameters();
            params._train = fr._key;
            params._response_column = "economy (mpg)";
            Job gs = GridSearch.startGridSearch(null, params, hyperParms);
            grid = (Grid) gs.get();

            Model[] ms = grid.getModels();
            int numModels = ms.length;
            System.out.println("Grid consists of " + numModels + " models");
            Grid.SearchFailure failures = grid.getFailures();
            // Values within each dimension are distinct, so the space size is the plain product.
            // Every point yields either a model or a recorded failure.
            int hyperSpaceSize = ntreesDim * maxDepthDim * sampleRateDim * mtriesDim;
            Assert.assertEquals("Number of models plus failures should match hyper space size",
                    hyperSpaceSize, numModels + failures.getFailureCount());

            // Pick one point of the hyper space and train a stand-alone model with exactly those
            // hyper parameters. All four dimensions are applied, including sample_rate.
            Integer ntreeVal = ntreesSpace[rng.nextInt(ntreesSpace.length)];
            Integer maxDepthVal = maxDepthSpace[rng.nextInt(maxDepthSpace.length)];
            Integer mtriesVal = mtriesSpace[rng.nextInt(mtriesSpace.length)];
            Double sampleRateVal = sampleRateSpace[rng.nextInt(sampleRateSpace.length)];
            params._ntrees = ntreeVal;
            params._max_depth = maxDepthVal;
            params._mtries = mtriesVal;
            params._sample_rate = sampleRateVal;
            drfRebuilt = (DRFModel) new DRF(params).trainModel().get();
            // NOTE(review): the matching grid model is not looked up here, so the MSE is only
            // reported, not compared against the grid.
            DRFModel.DRFOutput rebuiltOut = (DRFModel.DRFOutput) drfRebuilt._output;
            double rebuiltMSE = rebuiltOut._scored_train[rebuiltOut._ntrees]._mse;
            System.out.println("The rebuilt model's MSE: " + rebuiltMSE);
        } finally {
            if (fr != null) {
                fr.remove();
            }
            if (grid != null) {
                grid.remove();
            }
            if (drfRebuilt != null) {
                drfRebuilt.remove();
            }
        }
    }

    /**
     * Returns {@code count} elements of {@code values}, drawn without replacement in random
     * order.
     *
     * <p>The shuffle is driven by {@code rng}, so the result is reproducible from the RNG's seed.
     * The returned array has the same runtime component type as {@code values}.
     *
     * @throws IndexOutOfBoundsException if {@code count} exceeds {@code values.length}
     */
    private static <T> T[] randomSubset(T[] values, int count, Random rng) {
        ArrayList<T> shuffled = new ArrayList<T>(Arrays.asList(values));
        Collections.shuffle(shuffled, rng);
        return shuffled.subList(0, count).toArray(Arrays.copyOf(values, 0));
    }

    /**
     * Builds two DRF parameter sets that differ in mtries, max_depth and ntrees. They share the
     * seed, sample_rate, training frame and response. The test asserts that their checksums
     * differ.
     */
    @Test
    public void testCollisionOfDRFParamsChecksum() {
        Frame fr = null;
        try {
            fr = parse_test_file("smalldata/junit/cars.csv");
            fr.remove("name").remove();
            // Move the response to the last column; the Vec stays owned by the frame.
            Vec old = fr.remove("economy (mpg)");
            fr.add("economy (mpg)", old);
            DKV.put(fr);

            DRFModel.DRFParameters params1 = new DRFModel.DRFParameters();
            params1._train = fr._key;
            params1._response_column = "economy (mpg)";
            params1._seed = -4522296119273841674L;
            params1._mtries = 3;
            params1._max_depth = 15;
            params1._ntrees = 9;
            params1._sample_rate = 0.6499996781349182;

            DRFModel.DRFParameters params2 = new DRFModel.DRFParameters();
            params2._train = fr._key;
            params2._response_column = "economy (mpg)";
            params2._seed = -4522296119273841674L;
            params2._mtries = 1;
            params2._max_depth = 1;
            params2._ntrees = 13;
            params2._sample_rate = 0.6499996781349182;

            long csum1 = params1.checksum();
            long csum2 = params2.checksum();
            Assert.assertNotEquals("Checksums should be different", csum1, csum2);
        } finally {
            if (fr != null) {
                fr.remove();
            }
        }
    }
}

