/*
 * Decompiled with CFR 0.152.
 */
package hex.ensemble;

import hex.Model;
import hex.ensemble.StackedEnsemble;
import hex.ensemble.StackedEnsembleModel;
import hex.grid.Grid;
import hex.grid.GridSearch;
import hex.tree.drf.DRFModel;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import org.junit.Assert;
import org.junit.BeforeClass;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
import water.Job;
import water.Key;
import water.Lockable;
import water.Scope;
import water.TestUtil;
import water.fvec.Frame;
import water.fvec.TestFrameBuilder;
import water.fvec.Vec;
import water.util.ArrayUtils;
import water.util.Log;

/**
 * Parameterized test: for every {@link Model.Parameters.CategoricalEncodingScheme} it trains a
 * small 2x2 DRF grid using that encoding, stacks the grid's models into a
 * {@link StackedEnsembleModel}, and checks the predictions of both the base models and the
 * ensemble on small hand-built frames.
 */
@RunWith(Parameterized.class)
public class StackedEnsembleEncodingTest extends TestUtil {

    /** Name of the categorical response column of every training frame. */
    private static final String RESPONSE = "Response";

    /** Encoding used by the DRF base models; injected by the {@link Parameterized} runner. */
    @Parameterized.Parameter
    public Model.Parameters.CategoricalEncodingScheme encoding;

    @BeforeClass
    public static void stall() {
        stall_till_cloudsize(1);
    }

    /** Runs every test once per categorical encoding scheme. */
    @Parameterized.Parameters
    public static Iterable<?> data() {
        return Arrays.asList(Model.Parameters.CategoricalEncodingScheme.values());
    }

    /**
     * Trains on a frame where ColA "A" always maps to "V" and "B" always maps to "C". Checks that
     * the ensemble reproduces the training response and predicts the expected classes on a
     * separate test frame.
     */
    @Test
    public void testSE_BasicCategoricalEncoding() {
        if (isSkippedEncoding()) {
            return;
        }
        Log.info("Using encoding " + encoding);
        List<Lockable<?>> deletables = new ArrayList<>();
        Scope.enter();
        try {
            Frame train = new TestFrameBuilder()
                    .withName("trainEncoding")
                    .withColNames(new String[]{"ColA", RESPONSE})
                    .withVecTypes(new byte[]{Vec.T_CAT, Vec.T_CAT})
                    .withDataForCol(0, new String[]{"B", "B", "A", "A", "A", "B", "A"})
                    .withDataForCol(1, new String[]{"C", "C", "V", "V", "V", "C", "V"})
                    .build();
            TrainedModels trained = trainGridAndEnsemble(drfParams(train), train, deletables);

            Frame trainPreds = scoreTracked(trained.ensemble, train);
            assertStringVecEquals(train.vec(RESPONSE), trainPreds.vec(0));

            Frame test = new TestFrameBuilder()
                    .withName("testEncoding")
                    .withColNames(new String[]{"ColA"})
                    .withVecTypes(new byte[]{Vec.T_CAT})
                    .withDataForCol(0, new String[]{"A", "B"})
                    .build();
            Frame ensemblePreds = assertFirstRowPredictedByAll("V", trained, test);
            Assert.assertEquals("C", ensemblePreds.vec(0).stringAt(1));
        } finally {
            Scope.exit();
            deleteAll(deletables);
        }
    }

    /**
     * Scores a test frame containing level "D", which never occurs in the training frame, and
     * checks the prediction for the first row ("A").
     */
    @Test
    public void testSE_CategoricalEncodingWithUnseenCategories() {
        if (isSkippedEncoding()) {
            return;
        }
        Log.info("Using encoding " + encoding);
        List<Lockable<?>> deletables = new ArrayList<>();
        Scope.enter();
        try {
            Frame train = new TestFrameBuilder()
                    .withName("trainEncoding")
                    .withColNames(new String[]{"ColA", RESPONSE})
                    .withVecTypes(new byte[]{Vec.T_CAT, Vec.T_CAT})
                    .withDataForCol(0, new String[]{"B", "B", "A", "A", "A", "B", "A", "E"})
                    .withDataForCol(1, new String[]{"C", "C", "V", "V", "V", "C", "V", "V"})
                    .build();
            TrainedModels trained = trainGridAndEnsemble(drfParams(train), train, deletables);
            // Training-frame predictions are computed (and tracked) but not checked here.
            scoreTracked(trained.ensemble, train);

            Frame test = new TestFrameBuilder()
                    .withName("testEncoding")
                    .withColNames(new String[]{"ColA"})
                    .withVecTypes(new byte[]{Vec.T_CAT})
                    .withDataForCol(0, new String[]{"A", "D", "E"})
                    .build();
            assertFirstRowPredictedByAll("V", trained, test);
        } catch (IllegalArgumentException e) {
            // An IllegalArgumentException is tolerated rather than failing the test. It is logged
            // so it is not discarded silently. Cleanup runs exactly once, in the finally block.
            // NOTE(review): tolerating it also bypasses the assertions above - confirm which
            // encodings are expected to throw and narrow this accordingly.
            Log.warn("Ignoring IllegalArgumentException for encoding " + encoding + ": " + e.getMessage());
        } finally {
            Scope.exit();
            deleteAll(deletables);
        }
    }

    /**
     * Trains on a categorical column (ColA) and a numeric column (ColB) with per-tree column
     * sampling. It then scores frames that contain only part of the training features:
     * <ul>
     *   <li>ColA plus an unknown ColZ;</li>
     *   <li>ColB only;</li>
     *   <li>no training column at all, which must be rejected with an IllegalArgumentException.</li>
     * </ul>
     */
    @Test
    public void testSE_CategoricalEncodingWithPredictionsOnFeaturesSubset() {
        if (isSkippedEncoding()) {
            return;
        }
        Log.info("Using encoding " + encoding);
        List<Lockable<?>> deletables = new ArrayList<>();
        Scope.enter();
        try {
            Frame train = new TestFrameBuilder()
                    .withName("trainEncoding")
                    .withColNames(new String[]{"ColA", "ColB", RESPONSE})
                    .withVecTypes(new byte[]{Vec.T_CAT, Vec.T_NUM, Vec.T_CAT})
                    .withDataForCol(0, new String[]{"B", "B", "A", "A", "A", "B", "A"})
                    .withDataForCol(1, new long[]{2L, 2L, 1L, 1L, 1L, 2L, 1L})
                    .withDataForCol(2, new String[]{"C", "C", "V", "V", "V", "C", "V"})
                    .build();
            DRFModel.DRFParameters params = drfParams(train);
            params._col_sample_rate_per_tree = 0.5;
            TrainedModels trained = trainGridAndEnsemble(params, train, deletables);
            // Training-frame predictions are computed (and tracked) but not checked here.
            scoreTracked(trained.ensemble, train);

            Frame testCat = new TestFrameBuilder()
                    .withName("testEncodingCat")
                    .withColNames(new String[]{"ColA", "ColZ"})
                    .withVecTypes(new byte[]{Vec.T_CAT, Vec.T_NUM})
                    .withDataForCol(0, new String[]{"A"})
                    .withDataForCol(1, new double[]{0.0})
                    .build();
            assertFirstRowPredictedByAll("V", trained, testCat);

            Frame testNum = new TestFrameBuilder()
                    .withName("testEncodingNum")
                    .withColNames(new String[]{"ColB"})
                    .withVecTypes(new byte[]{Vec.T_NUM})
                    .withDataForCol(0, new long[]{1L})
                    .build();
            assertFirstRowPredictedByAll("V", trained, testNum);

            Frame testNoCommon = new TestFrameBuilder()
                    .withName("testEncodingNoCommon")
                    .withColNames(new String[]{"ColZ"})
                    .withVecTypes(new byte[]{Vec.T_NUM})
                    .withDataForCol(0, new long[]{1L})
                    .build();
            for (Model<?, ?, ?> model : trained.baseModels) {
                assertNoColumnsInCommon(model, testNoCommon);
            }
            assertNoColumnsInCommon(trained.ensemble, testNoCommon);
        } finally {
            Scope.exit();
            deleteAll(deletables);
        }
    }

    /**
     * Returns whether the current encoding is skipped. Every test returns early (and so passes
     * trivially) for {@code OneHotInternal}.
     * NOTE(review): the reason is not visible here - presumably the DRF base models do not
     * support this scheme; consider {@code Assume.assumeTrue} so the run is reported as skipped.
     */
    private boolean isSkippedEncoding() {
        return encoding == Model.Parameters.CategoricalEncodingScheme.OneHotInternal;
    }

    /**
     * Builds the DRF parameters shared by every grid in this class. The settings are fixed
     * (sample rate 1, min rows 1, seed 1, 2-fold CV keeping the CV predictions) plus the
     * encoding under test. For {@code EnumLimited}, the number of categorical levels is capped
     * at 2.
     *
     * @param train training frame; its key is used as {@code _train}
     * @return freshly built parameters that callers may tweak further
     */
    private DRFModel.DRFParameters drfParams(Frame train) {
        DRFModel.DRFParameters params = new DRFModel.DRFParameters();
        params._train = train._key;
        params._response_column = RESPONSE;
        params._sample_rate = 1.0;
        params._min_rows = 1.0;
        params._seed = 1L;
        params._nfolds = 2;
        params._keep_cross_validation_models = false;
        // Cross-validation predictions are kept; the stacked ensemble is built from these models.
        params._keep_cross_validation_predictions = true;
        params._categorical_encoding = encoding;
        if (encoding == Model.Parameters.CategoricalEncodingScheme.EnumLimited) {
            params._max_categorical_levels = 2;
        }
        return params;
    }

    /**
     * Runs a 2x2 DRF grid search (ntrees in {1, 2} x max_depth in {2, 3}) and asserts that it
     * produced 4 models. It then trains a stacked ensemble (seed 1) over all grid models.
     * The grid, its models and the ensemble are appended to {@code deletables} as soon as they
     * exist, so that the caller's cleanup removes them even if a later step fails.
     *
     * @param params     DRF parameters for every grid member
     * @param train      training frame for the ensemble
     * @param deletables receives every object created here
     * @return the grid's models together with the ensemble
     */
    private static TrainedModels trainGridAndEnsemble(DRFModel.DRFParameters params, Frame train,
                                                      List<Lockable<?>> deletables) {
        // A plain map instead of double-brace initialization, which would create an anonymous
        // HashMap subclass that captures the enclosing test instance.
        Map<String, Object[]> hyperParams = new HashMap<>();
        hyperParams.put("_ntrees", new Integer[]{1, 2});
        hyperParams.put("_max_depth", new Integer[]{2, 3});

        Job<Grid> gridSearch = GridSearch.startGridSearch(null, params, hyperParams);
        Grid grid = gridSearch.get();
        deletables.add(grid);
        Model[] gridModels = grid.getModels();
        for (Model<?, ?, ?> model : gridModels) {
            deletables.add(model);
        }
        Assert.assertEquals(4, gridModels.length);

        StackedEnsembleModel.StackedEnsembleParameters seParams = new StackedEnsembleModel.StackedEnsembleParameters();
        seParams._train = train._key;
        seParams._response_column = RESPONSE;
        // Defensive copy of the key array, matching the original append-to-empty-array idiom.
        seParams._base_models = grid.getModelKeys().clone();
        seParams._seed = 1L;
        StackedEnsembleModel se = new StackedEnsemble(seParams).trainModel().get();
        deletables.add(se);
        return new TrainedModels(gridModels, se);
    }

    /** Scores {@code frame} with {@code model}, tracks the result in the current Scope and returns it. */
    private static Frame scoreTracked(Model<?, ?, ?> model, Frame frame) {
        Frame preds = model.score(frame);
        Scope.track(preds);
        return preds;
    }

    /**
     * Scores {@code frame} with every base model and then with the ensemble. Asserts that each
     * of them predicts {@code expected} for the first row.
     *
     * @return the ensemble's predictions (already tracked in the current Scope)
     */
    private static Frame assertFirstRowPredictedByAll(String expected, TrainedModels trained, Frame frame) {
        for (Model<?, ?, ?> model : trained.baseModels) {
            Assert.assertEquals(expected, scoreTracked(model, frame).vec(0).stringAt(0));
        }
        Frame ensemblePreds = scoreTracked(trained.ensemble, frame);
        Assert.assertEquals(expected, ensemblePreds.vec(0).stringAt(0));
        return ensemblePreds;
    }

    /**
     * Asserts that scoring {@code frame} with {@code model} throws an IllegalArgumentException
     * whose message mentions "no columns in common".
     */
    private static void assertNoColumnsInCommon(Model<?, ?, ?> model, Frame frame) {
        try {
            Scope.track(model.score(frame));
            Assert.fail("Should have thrown IllegalArgumentException");
        } catch (IllegalArgumentException e) {
            String message = e.getMessage();
            Assert.assertTrue("Expected exception due to no column in common with training data, but got: " + message,
                    message != null && message.contains("no columns in common"));
        }
    }

    /**
     * Deletes every collected object. For models, the cross-validation predictions are deleted
     * first, since the parameters ask for them to be kept.
     */
    private static void deleteAll(List<Lockable<?>> deletables) {
        for (Lockable<?> lockable : deletables) {
            if (lockable instanceof Model) {
                ((Model<?, ?, ?>) lockable).deleteCrossValidationPreds();
            }
            lockable.delete();
        }
    }

    /** The base models of a grid together with the stacked ensemble trained on them. */
    private static final class TrainedModels {
        final Model[] baseModels;
        final StackedEnsembleModel ensemble;

        TrainedModels(Model[] baseModels, StackedEnsembleModel ensemble) {
            this.baseModels = baseModels;
            this.ensemble = ensemble;
        }
    }
}

