/*
 * Decompiled with CFR 0.152.
 */
package hex.tree.gbm;

import hex.Model;
import hex.tree.gbm.GBM;
import hex.tree.gbm.GBMModel;
import java.util.Arrays;
import org.junit.Assert;
import org.junit.Assume;
import org.junit.BeforeClass;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
import water.Key;
import water.Keyed;
import water.Scope;
import water.TestUtil;
import water.fvec.Frame;
import water.fvec.TestFrameBuilder;
import water.fvec.Vec;
import water.util.Log;

/**
 * Checks that GBM trains and scores tiny categorical frames under every
 * {@link Model.Parameters.CategoricalEncodingScheme} (each test runs once per scheme, see {@link #data()}).
 *
 * <p>The training frames are perfectly separable ("A" -> "V", "B" -> "C"), and the model is a single tree
 * with learning rate 1, so the tests expect the mapping to be recovered exactly.
 *
 * <p>NOTE(review): frames built via {@link TestFrameBuilder} are not tracked explicitly in this class; this
 * presumably relies on the builder registering them with the current {@link Scope} -- confirm.
 */
@RunWith(Parameterized.class)
public class GBMEncodingTest extends TestUtil {

    /** Name of the response column in every training frame used here. */
    private static final String TARGET = "Response";

    /** Encoding scheme under test; injected by the {@link Parameterized} runner. */
    @Parameterized.Parameter
    public Model.Parameters.CategoricalEncodingScheme encoding;

    /** Waits for a cloud of size 1 (via {@code TestUtil.stall_till_cloudsize}) before any test runs. */
    @BeforeClass
    public static void stall() {
        stall_till_cloudsize(1);
    }

    /**
     * Supplies every categorical encoding scheme as a test parameter.
     * The scheme's name is used in the display name of each parameterized run.
     */
    @Parameterized.Parameters(name = "{0}")
    public static Iterable<?> data() {
        return Arrays.asList(Model.Parameters.CategoricalEncodingScheme.values());
    }

    /**
     * Trains on a 5-row frame where ColA alone determines the response.
     * Checks that training predictions reproduce the response exactly and that a fresh "A" row scores as "V".
     */
    @Test
    public void testGBM_BasicCategoricalEncoding() {
        assumeEncodingSupported();
        Scope.enter();
        try {
            Frame train = new TestFrameBuilder()
                    .withName("trainEncoding")
                    .withColNames(new String[]{"ColA", TARGET})
                    .withVecTypes(new byte[]{Vec.T_CAT, Vec.T_CAT})
                    .withDataForCol(0, ar("B", "B", "A", "A", "A"))
                    .withDataForCol(1, ar("C", "C", "V", "V", "V"))
                    .build();
            GBMModel gbm = trainGbm(train, 1);

            Frame trainPreds = gbm.score(train);
            Scope.track(trainPreds);
            assertStringVecEquals(train.vec(TARGET), trainPreds.vec(0));

            Frame test = new TestFrameBuilder()
                    .withName("testEncoding")
                    .withColNames(new String[]{"ColA"})
                    .withVecTypes(new byte[]{Vec.T_CAT})
                    .withDataForCol(0, ar("A"))
                    .build();
            Frame testPreds = gbm.score(test);
            Scope.track(testPreds);
            Assert.assertEquals("V", testPreds.vec(0).stringAt(0));
        } finally {
            Scope.exit();
        }
    }

    /**
     * Scores a frame containing levels ("D", "E") that never appeared in training.
     * Checks that every row is scored and that the known level "A" still predicts "V".
     */
    @Test
    public void testGBM_CategoricalEncodingWithUnseenCategories() {
        assumeEncodingSupported();
        Scope.enter();
        try {
            Frame train = new TestFrameBuilder()
                    .withName("trainEncoding")
                    .withColNames(new String[]{"ColA", TARGET})
                    .withVecTypes(new byte[]{Vec.T_CAT, Vec.T_CAT})
                    .withDataForCol(0, ar("B", "B", "A", "A", "A", "B", "A"))
                    .withDataForCol(1, ar("C", "C", "V", "V", "V", "C", "V"))
                    .build();
            GBMModel gbm = trainGbm(train, 3);

            // Smoke-score the training frame; predictions are only tracked, not asserted.
            Frame trainPreds = gbm.score(train);
            Scope.track(trainPreds);

            Frame test = new TestFrameBuilder()
                    .withName("testEncoding")
                    .withColNames(new String[]{"ColA"})
                    .withVecTypes(new byte[]{Vec.T_CAT})
                    .withDataForCol(0, ar("A", "D", "E"))
                    .build();
            Frame testPreds = gbm.score(test);
            Scope.track(testPreds);
            // The rows with unseen levels must still be scored, not dropped.
            Assert.assertEquals(test.numRows(), testPreds.numRows());
            Assert.assertEquals("V", testPreds.vec(0).stringAt(0));
        } finally {
            Scope.exit();
        }
    }

    /**
     * Scores frames that contain only some of the training features. It checks three cases:
     * <ul>
     *   <li>ColA plus an extra unknown ColZ;</li>
     *   <li>only ColB;</li>
     *   <li>only the unknown ColZ, which is expected to be rejected with an
     *       {@link IllegalArgumentException} mentioning "no columns in common".</li>
     * </ul>
     */
    @Test
    public void testGBM_CategoricalEncodingWithPredictionsOnFeaturesSubset() {
        assumeEncodingSupported();
        Scope.enter();
        try {
            Frame train = new TestFrameBuilder()
                    .withName("trainEncoding")
                    .withColNames(new String[]{"ColA", "ColB", TARGET})
                    .withVecTypes(new byte[]{Vec.T_CAT, Vec.T_NUM, Vec.T_CAT})
                    .withDataForCol(0, ar("B", "B", "A", "A", "A", "B", "A"))
                    .withDataForCol(1, ar(new long[]{2L, 2L, 1L, 1L, 1L, 2L, 1L}))
                    .withDataForCol(2, ar("C", "C", "V", "V", "V", "C", "V"))
                    .build();
            GBMModel gbm = trainGbm(train, 3);

            // Smoke-score the training frame; predictions are only tracked, not asserted.
            Frame trainPreds = gbm.score(train);
            Scope.track(trainPreds);

            // Categorical feature present, numeric feature missing, plus an extra column the model never saw.
            Frame testCat = new TestFrameBuilder()
                    .withName("testEncodingCat")
                    .withColNames(new String[]{"ColA", "ColZ"})
                    .withVecTypes(new byte[]{Vec.T_CAT, Vec.T_NUM})
                    .withDataForCol(0, ar("A"))
                    .withDataForCol(1, ard(new double[]{0.0}))
                    .build();
            Frame catPreds = gbm.score(testCat);
            Scope.track(catPreds);
            Assert.assertEquals("V", catPreds.vec(0).stringAt(0));

            // Only the numeric feature present.
            Frame testNum = new TestFrameBuilder()
                    .withName("testEncodingNum")
                    .withColNames(new String[]{"ColB"})
                    .withVecTypes(new byte[]{Vec.T_NUM})
                    .withDataForCol(0, ar(new long[]{1L}))
                    .build();
            Frame numPreds = gbm.score(testNum);
            Scope.track(numPreds);
            Assert.assertEquals("V", numPreds.vec(0).stringAt(0));

            // No feature in common with the training frame: scoring must be rejected.
            Frame testNoCommon = new TestFrameBuilder()
                    .withName("testEncodingNoCommon")
                    .withColNames(new String[]{"ColZ"})
                    .withVecTypes(new byte[]{Vec.T_NUM})
                    .withDataForCol(0, ar(new long[]{1L}))
                    .build();
            try {
                // Track the result so it is cleaned up if scoring unexpectedly succeeds.
                Scope.track(gbm.score(testNoCommon));
                Assert.fail("Should have thrown IllegalArgumentException");
            } catch (IllegalArgumentException e) {
                String message = e.getMessage();
                Assert.assertTrue(
                        "Expected exception due to no column in common with training data, but got: " + message,
                        message != null && message.contains("no columns in common"));
            }
        } finally {
            Scope.exit();
        }
    }

    /**
     * Marks the current test as skipped (instead of letting it pass vacuously) for the OneHotInternal
     * encoding. For every other encoding, it logs the encoding in use.
     */
    private void assumeEncodingSupported() {
        // NOTE(review): OneHotInternal is excluded by this suite, presumably because GBM does not support it -- confirm.
        Assume.assumeTrue("OneHotInternal encoding is not exercised for GBM",
                encoding != Model.Parameters.CategoricalEncodingScheme.OneHotInternal);
        Log.info("Using encoding " + encoding);
    }

    /**
     * Trains a single-tree GBM on {@code train} with the encoding under test and registers it with the
     * current {@link Scope}.
     *
     * <p>Settings: seed 1, learning rate 1.0, min_rows 1.0, 1 tree.
     *
     * @param train    training frame; must contain a {@value #TARGET} column
     * @param maxDepth maximum depth of the single tree
     * @return the trained model (already tracked in the current scope)
     */
    private GBMModel trainGbm(Frame train, int maxDepth) {
        GBMModel.GBMParameters parms = new GBMModel.GBMParameters();
        parms._seed = 1L;
        parms._train = train._key;
        parms._response_column = TARGET;
        parms._ntrees = 1;
        parms._max_depth = maxDepth;
        parms._learn_rate = 1.0;
        parms._min_rows = 1.0;
        parms._categorical_encoding = encoding;
        if (encoding == Model.Parameters.CategoricalEncodingScheme.EnumLimited) {
            // EnumLimited: cap levels at 2 (ColA has exactly two levels, "A" and "B").
            parms._max_categorical_levels = 2;
        }
        GBMModel gbm = new GBM(parms).trainModel().get();
        Scope.track_generic(gbm);
        return gbm;
    }
}

