package hex.ensemble;

import hex.GLMHelper;
import hex.Model;
import hex.ModelMetrics;
import hex.SplitFrame;
import hex.ensemble.Metalearner;
import hex.ensemble.StackedEnsembleModel;
import hex.genmodel.utils.DistributionFamily;
import hex.glm.GLM;
import hex.glm.GLMModel;
import hex.grid.Grid;
import hex.grid.GridSearch;
import hex.splitframe.ShuffleSplitFrame;
import hex.tree.drf.DRF;
import hex.tree.drf.DRFModel;
import hex.tree.gbm.GBM;
import hex.tree.gbm.GBMModel;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Random;
import org.junit.Assert;
import org.junit.BeforeClass;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.ExpectedException;
import water.DKV;
import water.H2O;
import water.Job;
import water.Key;
import water.Keyed;
import water.MRTask;
import water.Scope;
import water.TestUtil;
import water.fvec.Chunk;
import water.fvec.Frame;
import water.fvec.NewChunk;
import water.fvec.Vec;
import water.util.ArrayUtils;
import water.util.Log;

/* loaded from: input_file:hex/ensemble/StackedEnsembleTest.class */
public class StackedEnsembleTest extends TestUtil {

    /** JUnit rule that tests in this class can use to declare an expected exception. */
    @Rule
    public ExpectedException expectedException = ExpectedException.none();

    /**
     * Airlines columns that the airlines-based tests drop from the frame before training.
     * NOTE(review): these look like columns only known after departure (delays, cancellations) —
     * presumably excluded to avoid target leakage; confirm.
     */
    static final String[] ignored_aircols = {
        "DepTime", "ArrTime", "AirTime", "ArrDelay", "DepDelay",
        "TaxiIn", "TaxiOut", "Cancelled", "CancellationCode", "Diverted",
        "CarrierDelay", "WeatherDelay", "NASDelay", "SecurityDelay", "LateAircraftDelay",
        "IsDepDelayed"
    };

    /**
     * Dataset-specific preparation hook handed to {@code basicEnsemble}.
     *
     * <p>Implementations may mutate the parsed frame (for example drop columns) and return the index
     * of the response column. The gaussian (regression) cases in this class return the bitwise
     * complement {@code ~index}; presumably {@code basicEnsemble} decodes a negative value as
     * "keep the response numeric" — confirm against its implementation.
     */
    public abstract class PrepData {
        private PrepData() {
            // Private: only code nested inside StackedEnsembleTest can subclass this hook.
        }

        /** Prepares {@code fr} in place and returns the (possibly complemented) response index. */
        abstract int prep(Frame fr);
    }

    /* loaded from: input_file:hex/ensemble/StackedEnsembleTest$Pubdev6157MRTask.class */
    public static class Pubdev6157MRTask extends MRTask {
        private final int nclasses;

        Pubdev6157MRTask(int i) {
            this.nclasses = i;
        }

        public void map(Chunk[] chunkArr, NewChunk[] newChunkArr) {
            Random random = new Random();
            NewChunk newChunk = newChunkArr[0];
            NewChunk newChunk2 = newChunkArr[1];
            for (int i = 0; i < chunkArr[0]._len; i++) {
                long start = chunkArr[0].start() + i;
                newChunk.addNum(random.nextDouble());
                newChunk2.addNum(start % 2 == 0 ? this.nclasses - 1 : start % this.nclasses);
            }
        }
    }

    /** Runs once before any test in this class: waits via stall_till_cloudsize for a cloud of size 1. */
    @BeforeClass
    public static void stall() {
        final int requiredCloudSize = 1;
        stall_till_cloudsize(requiredCloudSize);
    }

    /**
     * Runs basicEnsemble with the AUTO metalearner on cars (gaussian), airlines (bernoulli),
     * iris (multinomial) and the prostate train/test pair (bernoulli).
     */
    @Test
    public void testBasicEnsembleAUTOMetalearner() {
        final Metalearner.Algorithm algo = Metalearner.Algorithm.AUTO;

        // cars: drop the "name" column; regression cases return the complemented index (~idx).
        PrepData cars = new PrepData() {
            @Override
            int prep(Frame fr) {
                fr.remove("name").remove();
                return ~fr.find("economy (mpg)");
            }
        };
        // airlines: strip every column listed in ignored_aircols.
        PrepData airlines = new PrepData() {
            @Override
            int prep(Frame fr) {
                for (String col : ignored_aircols) {
                    fr.remove(col).remove();
                }
                return fr.find("IsArrDelayed");
            }
        };
        PrepData iris = new PrepData() {
            @Override
            int prep(Frame fr) {
                return fr.find("class");
            }
        };
        PrepData prostate = new PrepData() {
            @Override
            int prep(Frame fr) {
                return fr.find("CAPSULE");
            }
        };

        basicEnsemble("./smalldata/junit/cars.csv", null, cars,
                false, DistributionFamily.gaussian, algo, false);
        basicEnsemble("./smalldata/airlines/allyears2k_headers.zip", null, airlines,
                false, DistributionFamily.bernoulli, algo, false);
        basicEnsemble("./smalldata/iris/iris_wheader.csv", null, iris,
                false, DistributionFamily.multinomial, algo, false);
        basicEnsemble("./smalldata/logreg/prostate_train.csv", "./smalldata/logreg/prostate_test.csv", prostate,
                false, DistributionFamily.bernoulli, algo, false);
    }

    /**
     * Runs basicEnsemble with the GBM metalearner on cars (gaussian), airlines (bernoulli),
     * iris (multinomial) and the prostate train/test pair (bernoulli).
     */
    @Test
    public void testBasicEnsembleGBMMetalearner() {
        final Metalearner.Algorithm algo = Metalearner.Algorithm.gbm;

        // cars: drop the "name" column; regression cases return the complemented index (~idx).
        PrepData cars = new PrepData() {
            @Override
            int prep(Frame fr) {
                fr.remove("name").remove();
                return ~fr.find("economy (mpg)");
            }
        };
        // airlines: strip every column listed in ignored_aircols.
        PrepData airlines = new PrepData() {
            @Override
            int prep(Frame fr) {
                for (String col : ignored_aircols) {
                    fr.remove(col).remove();
                }
                return fr.find("IsArrDelayed");
            }
        };
        PrepData iris = new PrepData() {
            @Override
            int prep(Frame fr) {
                return fr.find("class");
            }
        };
        PrepData prostate = new PrepData() {
            @Override
            int prep(Frame fr) {
                return fr.find("CAPSULE");
            }
        };

        basicEnsemble("./smalldata/junit/cars.csv", null, cars,
                false, DistributionFamily.gaussian, algo, false);
        basicEnsemble("./smalldata/airlines/allyears2k_headers.zip", null, airlines,
                false, DistributionFamily.bernoulli, algo, false);
        basicEnsemble("./smalldata/iris/iris_wheader.csv", null, iris,
                false, DistributionFamily.multinomial, algo, false);
        basicEnsemble("./smalldata/logreg/prostate_train.csv", "./smalldata/logreg/prostate_test.csv", prostate,
                false, DistributionFamily.bernoulli, algo, false);
    }

    /**
     * Runs basicEnsemble with the DRF metalearner on cars (gaussian), airlines (bernoulli),
     * iris (multinomial) and the prostate train/test pair (bernoulli).
     */
    @Test
    public void testBasicEnsembleDRFMetalearner() {
        final Metalearner.Algorithm algo = Metalearner.Algorithm.drf;

        // cars: drop the "name" column; regression cases return the complemented index (~idx).
        PrepData cars = new PrepData() {
            @Override
            int prep(Frame fr) {
                fr.remove("name").remove();
                return ~fr.find("economy (mpg)");
            }
        };
        // airlines: strip every column listed in ignored_aircols.
        PrepData airlines = new PrepData() {
            @Override
            int prep(Frame fr) {
                for (String col : ignored_aircols) {
                    fr.remove(col).remove();
                }
                return fr.find("IsArrDelayed");
            }
        };
        PrepData iris = new PrepData() {
            @Override
            int prep(Frame fr) {
                return fr.find("class");
            }
        };
        PrepData prostate = new PrepData() {
            @Override
            int prep(Frame fr) {
                return fr.find("CAPSULE");
            }
        };

        basicEnsemble("./smalldata/junit/cars.csv", null, cars,
                false, DistributionFamily.gaussian, algo, false);
        basicEnsemble("./smalldata/airlines/allyears2k_headers.zip", null, airlines,
                false, DistributionFamily.bernoulli, algo, false);
        basicEnsemble("./smalldata/iris/iris_wheader.csv", null, iris,
                false, DistributionFamily.multinomial, algo, false);
        basicEnsemble("./smalldata/logreg/prostate_train.csv", "./smalldata/logreg/prostate_test.csv", prostate,
                false, DistributionFamily.bernoulli, algo, false);
    }

    /**
     * Runs basicEnsemble with the Deep Learning metalearner on cars (gaussian), airlines (bernoulli),
     * iris (multinomial) and the prostate train/test pair (bernoulli).
     */
    @Test
    public void testBasicEnsembleDeepLearningMetalearner() {
        final Metalearner.Algorithm algo = Metalearner.Algorithm.deeplearning;

        // cars: drop the "name" column; regression cases return the complemented index (~idx).
        PrepData cars = new PrepData() {
            @Override
            int prep(Frame fr) {
                fr.remove("name").remove();
                return ~fr.find("economy (mpg)");
            }
        };
        // airlines: strip every column listed in ignored_aircols.
        PrepData airlines = new PrepData() {
            @Override
            int prep(Frame fr) {
                for (String col : ignored_aircols) {
                    fr.remove(col).remove();
                }
                return fr.find("IsArrDelayed");
            }
        };
        PrepData iris = new PrepData() {
            @Override
            int prep(Frame fr) {
                return fr.find("class");
            }
        };
        PrepData prostate = new PrepData() {
            @Override
            int prep(Frame fr) {
                return fr.find("CAPSULE");
            }
        };

        basicEnsemble("./smalldata/junit/cars.csv", null, cars,
                false, DistributionFamily.gaussian, algo, false);
        basicEnsemble("./smalldata/airlines/allyears2k_headers.zip", null, airlines,
                false, DistributionFamily.bernoulli, algo, false);
        basicEnsemble("./smalldata/iris/iris_wheader.csv", null, iris,
                false, DistributionFamily.multinomial, algo, false);
        basicEnsemble("./smalldata/logreg/prostate_train.csv", "./smalldata/logreg/prostate_test.csv", prostate,
                false, DistributionFamily.bernoulli, algo, false);
    }

    /**
     * Runs basicEnsemble with the GLM metalearner on a range of datasets: one gaussian problem,
     * five bernoulli problems and three multinomial problems.
     */
    @Test
    public void testBasicEnsembleGLMMetalearner() {
        final Metalearner.Algorithm algo = Metalearner.Algorithm.glm;

        // Regression on cars: the complemented index (~idx) is what every gaussian case returns.
        PrepData carsEconomy = new PrepData() {
            @Override
            int prep(Frame fr) {
                fr.remove("name").remove();
                return ~fr.find("economy (mpg)");
            }
        };
        PrepData treeMinMax = new PrepData() {
            @Override
            int prep(Frame fr) {
                return fr.find("response");
            }
        };
        // prostate.csv: drop the ID column before using CAPSULE as the response.
        PrepData prostateCapsule = new PrepData() {
            @Override
            int prep(Frame fr) {
                fr.remove("ID").remove();
                return fr.find("CAPSULE");
            }
        };
        PrepData prostateSplit = new PrepData() {
            @Override
            int prep(Frame fr) {
                return fr.find("CAPSULE");
            }
        };
        PrepData alphabet = new PrepData() {
            @Override
            int prep(Frame fr) {
                return fr.find("y");
            }
        };
        PrepData airlines = new PrepData() {
            @Override
            int prep(Frame fr) {
                for (String col : ignored_aircols) {
                    fr.remove(col).remove();
                }
                return fr.find("IsArrDelayed");
            }
        };
        PrepData prostateRace = new PrepData() {
            @Override
            int prep(Frame fr) {
                fr.remove("ID").remove();
                return fr.find("RACE");
            }
        };
        PrepData carsCylinders = new PrepData() {
            @Override
            int prep(Frame fr) {
                fr.remove("name").remove();
                return fr.find("cylinders");
            }
        };
        PrepData iris = new PrepData() {
            @Override
            int prep(Frame fr) {
                return fr.find("class");
            }
        };

        basicEnsemble("./smalldata/junit/cars.csv", null, carsEconomy,
                false, DistributionFamily.gaussian, algo, false);
        basicEnsemble("./smalldata/junit/test_tree_minmax.csv", null, treeMinMax,
                false, DistributionFamily.bernoulli, algo, false);
        basicEnsemble("./smalldata/logreg/prostate.csv", null, prostateCapsule,
                false, DistributionFamily.bernoulli, algo, false);
        basicEnsemble("./smalldata/logreg/prostate_train.csv", "./smalldata/logreg/prostate_test.csv", prostateSplit,
                false, DistributionFamily.bernoulli, algo, false);
        basicEnsemble("./smalldata/gbm_test/alphabet_cattest.csv", null, alphabet,
                false, DistributionFamily.bernoulli, algo, false);
        basicEnsemble("./smalldata/airlines/allyears2k_headers.zip", null, airlines,
                false, DistributionFamily.bernoulli, algo, false);
        basicEnsemble("./smalldata/logreg/prostate.csv", null, prostateRace,
                false, DistributionFamily.multinomial, algo, false);
        basicEnsemble("./smalldata/junit/cars.csv", null, carsCylinders,
                false, DistributionFamily.multinomial, algo, false);
        basicEnsemble("./smalldata/iris/iris_wheader.csv", null, iris,
                false, DistributionFamily.multinomial, algo, false);
    }

    /* JADX WARN: Multi-variable type inference failed */
    /* JADX WARN: Type inference failed for: r0v13, types: [java.lang.String[], java.lang.String[][]] */
    @Test
    public void testPubDev6157() {
        try {
            Scope.enter();
            Vec makeConN = Vec.makeConN(100000L, H2O.ARGS.nthreads * 4);
            Scope.track(makeConN);
            byte[] bArr = {3, 4};
            ?? r0 = new String[bArr.length];
            r0[r0.length - 1] = new String[4];
            for (int i = 0; i < 4; i++) {
                r0[r0.length - 1][i] = "Level" + i;
            }
            final Frame outputFrame = new Pubdev6157MRTask(4).doAll(bArr, new Vec[]{makeConN}).outputFrame(Key.make(), (String[]) null, (String[][]) r0);
            Scope.track(new Frame[]{outputFrame});
            GLMModel.GLMParameters gLMParameters = new GLMModel.GLMParameters();
            gLMParameters._train = outputFrame._key;
            gLMParameters._response_column = outputFrame.lastVecName();
            gLMParameters._family = GLMModel.GLMParameters.Family.multinomial;
            gLMParameters._max_iterations = 1;
            gLMParameters._seed = 42L;
            gLMParameters._auto_rebalance = false;
            final GLMModel gLMModel = new GLM(gLMParameters).trainModelOnH2ONode().get();
            Scope.track_generic(gLMModel);
            Assert.assertNotNull(gLMModel);
            final Job job = new Job(Key.make(), gLMParameters.javaName(), gLMParameters.algoName());
            job.start(new H2O.H2OCountedCompleter() { // from class: hex.ensemble.StackedEnsembleTest.26
                public void compute2() {
                    GLMHelper.runBigScore(gLMModel, outputFrame, false, false, job);
                    tryComplete();
                }
            }, 1L).get();
            basicEnsemble("./smalldata/iris/iris_wheader.csv", null, new PrepData() { // from class: hex.ensemble.StackedEnsembleTest.27
                @Override // hex.ensemble.StackedEnsembleTest.PrepData
                int prep(Frame frame) {
                    return frame.find("class");
                }
            }, false, DistributionFamily.multinomial, Metalearner.Algorithm.glm, false);
            Scope.exit(new Key[0]);
        } catch (Throwable th) {
            Scope.exit(new Key[0]);
            throw th;
        }
    }

    /**
     * Same datasets as the AUTO-metalearner test, but with the final basicEnsemble flag set to
     * {@code true} (presumably enabling blending mode — see the method name).
     */
    @Test
    public void testBlending() {
        final Metalearner.Algorithm algo = Metalearner.Algorithm.AUTO;
        final boolean useBlending = true;

        // cars: drop the "name" column; regression cases return the complemented index (~idx).
        PrepData cars = new PrepData() {
            @Override
            int prep(Frame fr) {
                fr.remove("name").remove();
                return ~fr.find("economy (mpg)");
            }
        };
        // airlines: strip every column listed in ignored_aircols.
        PrepData airlines = new PrepData() {
            @Override
            int prep(Frame fr) {
                for (String col : ignored_aircols) {
                    fr.remove(col).remove();
                }
                return fr.find("IsArrDelayed");
            }
        };
        PrepData iris = new PrepData() {
            @Override
            int prep(Frame fr) {
                return fr.find("class");
            }
        };
        PrepData prostate = new PrepData() {
            @Override
            int prep(Frame fr) {
                return fr.find("CAPSULE");
            }
        };

        basicEnsemble("./smalldata/junit/cars.csv", null, cars,
                false, DistributionFamily.gaussian, algo, useBlending);
        basicEnsemble("./smalldata/airlines/allyears2k_headers.zip", null, airlines,
                false, DistributionFamily.bernoulli, algo, useBlending);
        basicEnsemble("./smalldata/iris/iris_wheader.csv", null, iris,
                false, DistributionFamily.multinomial, algo, useBlending);
        basicEnsemble("./smalldata/logreg/prostate_train.csv", "./smalldata/logreg/prostate_test.csv", prostate,
                false, DistributionFamily.bernoulli, algo, useBlending);
    }

    /**
     * Checks the {@code _keep_base_model_predictions} option of Stacked Ensemble. The scenario:
     *
     * <ul>
     *   <li>Train a 2x2 GBM grid on a 50/30/20 train/blending/valid split of cars.csv.</li>
     *   <li>With the option off, no prediction keys are retained.</li>
     *   <li>With the option on, one prediction key per base model is stored in DKV.</li>
     *   <li>A second ensemble reuses the exact same cached keys.</li>
     *   <li>Adding a validation frame doubles the number of keys.</li>
     *   <li>{@code deleteBaseModelPredictions()} clears the keys and removes them from DKV.</li>
     * </ul>
     *
     * <p>Cleanup runs exactly once in {@code finally}. The previous catch-and-rethrow layout also
     * re-ran cleanup if cleanup itself threw.
     */
    @Test
    public void testBaseModelPredictionsCaching() {
        Grid grid = null;
        List<Model> seModels = new ArrayList<>();
        List<Frame> frames = new ArrayList<>();
        Scope.enter();
        try {
            final long seed = 24576L;
            Frame cars = parse_test_file("./smalldata/junit/cars.csv");
            Frame[] splits = ShuffleSplitFrame.shuffleSplitFrame(
                    cars,
                    new Key[]{
                        Key.make(cars._key + "_train"),
                        Key.make(cars._key + "_blending"),
                        Key.make(cars._key + "_valid")},
                    new double[]{0.5d, 0.3d, 0.2d},
                    seed);
            cars.remove();
            Frame train = splits[0];
            Frame blending = splits[1];
            Frame valid = splits[2];
            frames.addAll(Arrays.asList(train, blending, valid));

            GBMModel.GBMParameters gbmParams = new GBMModel.GBMParameters();
            gbmParams._distribution = DistributionFamily.gaussian;
            gbmParams._train = train._key;
            gbmParams._response_column = "economy (mpg)";
            gbmParams._seed = seed;
            // Plain map instead of a double-brace anonymous subclass (which captured the test instance).
            Map<String, Object[]> hyperParams = new HashMap<>();
            hyperParams.put("_ntrees", new Integer[]{3, 5});
            hyperParams.put("_learn_rate", new Double[]{0.1d, 0.2d});
            grid = (Grid) GridSearch.startGridSearch((Key) null, gbmParams, hyperParams).get();
            Model[] baseModels = grid.getModels();
            Assert.assertEquals(4L, baseModels.length);

            StackedEnsembleModel.StackedEnsembleParameters seParams =
                    new StackedEnsembleModel.StackedEnsembleParameters();
            // NOTE(review): bernoulli with the numeric "economy (mpg)" response and gaussian base
            // models looks inconsistent; kept unchanged — confirm intent.
            seParams._distribution = DistributionFamily.bernoulli;
            seParams._train = train._key;
            seParams._blending = blending._key;
            seParams._response_column = "economy (mpg)";
            seParams._base_models = grid.getModelKeys();
            seParams._seed = seed;

            // 1. Predictions not kept: no keys recorded.
            seParams._keep_base_model_predictions = false;
            StackedEnsembleModel noCache = new StackedEnsemble(seParams).trainModel().get();
            seModels.add(noCache);
            Assert.assertNull(noCache._output._base_model_predictions_keys);

            // 2. Predictions kept: one key per base model, each present in DKV.
            seParams._keep_base_model_predictions = true;
            StackedEnsembleModel cached = new StackedEnsemble(seParams).trainModel().get();
            seModels.add(cached);
            Assert.assertNotNull(cached._output._base_model_predictions_keys);
            Assert.assertEquals(baseModels.length, cached._output._base_model_predictions_keys.length);
            Key[] cachedKeys = cached._output._base_model_predictions_keys;
            for (Key key : cachedKeys) {
                Assert.assertNotNull("prediction key is not stored in DKV", DKV.get(key));
            }

            // 3. Same setup again: the cached prediction keys are reused as-is.
            seParams._keep_base_model_predictions = true;
            StackedEnsembleModel reused = new StackedEnsemble(seParams).trainModel().get();
            seModels.add(reused);
            Assert.assertNotNull(reused._output._base_model_predictions_keys);
            Assert.assertEquals(baseModels.length, reused._output._base_model_predictions_keys.length);
            Assert.assertArrayEquals(cachedKeys, reused._output._base_model_predictions_keys);

            // 4. With a validation frame, keys double and still include the cached training keys.
            seParams._keep_base_model_predictions = true;
            seParams._valid = valid._key;
            StackedEnsembleModel withValid = new StackedEnsemble(seParams).trainModel().get();
            seModels.add(withValid);
            Assert.assertNotNull(withValid._output._base_model_predictions_keys);
            Assert.assertEquals(baseModels.length * 2, withValid._output._base_model_predictions_keys.length);
            for (Key key : cachedKeys) {
                Assert.assertTrue(ArrayUtils.contains(withValid._output._base_model_predictions_keys, key));
            }

            // 5. Explicit deletion clears the field and removes the frames from DKV.
            withValid.deleteBaseModelPredictions();
            Assert.assertNull(withValid._output._base_model_predictions_keys);
            for (Key key : cachedKeys) {
                Assert.assertNull(DKV.get(key));
            }
        } finally {
            Scope.exit();
            if (grid != null) {
                grid.delete();
            }
            for (Model model : seModels) {
                model.delete();
            }
            for (Frame frame : frames) {
                frame.remove();
            }
        }
    }

    /**
     * Checks scoring of a blending-mode Stacked Ensemble. Setup:
     *
     * <ul>
     *   <li>Base models: a 2x2 GBM grid over {@code _ntrees} and {@code _learn_rate}.</li>
     *   <li>Data: a 70/30 train/blending split of prostate_train.</li>
     * </ul>
     *
     * <p>Scoring prostate_test must return one row per test row and three columns.
     *
     * <p>Every frame, grid, model and prediction frame created here is removed in {@code finally},
     * whether or not the test body succeeds.
     */
    @Test
    public void test_SE_scoring_with_blending() {
        List<Keyed> deletables = new ArrayList<>();
        try {
            final long seed = 62832L;
            Frame train = parse_test_file("./smalldata/logreg/prostate_train.csv");
            deletables.add(train);
            Frame test = parse_test_file("./smalldata/logreg/prostate_test.csv");
            deletables.add(test);

            // Convert CAPSULE to categorical in both frames (classification response); each frame
            // looks up its own column index rather than assuming identical layouts.
            int trainRespIdx = train.find("CAPSULE");
            train.replace(trainRespIdx, train.vec(trainRespIdx).toCategoricalVec()).remove();
            DKV.put(train);
            int testRespIdx = test.find("CAPSULE");
            test.replace(testRespIdx, test.vec(testRespIdx).toCategoricalVec()).remove();
            DKV.put(test);

            SplitFrame splitter = new SplitFrame(train, new double[]{0.7d, 0.3d}, (Key[]) null);
            splitter.exec().get();
            Key<Frame>[] splitKeys = splitter._destination_frames;
            Frame trainSplit = splitKeys[0].get();
            deletables.add(trainSplit);
            Frame blendingSplit = splitKeys[1].get();
            deletables.add(blendingSplit);

            GBMModel.GBMParameters gbmParams = new GBMModel.GBMParameters();
            gbmParams._train = trainSplit._key;
            gbmParams._response_column = "CAPSULE";
            gbmParams._seed = seed;
            Map<String, Object[]> hyperParams = new HashMap<>();
            hyperParams.put("_ntrees", new Integer[]{3, 5});
            hyperParams.put("_learn_rate", new Double[]{0.1d, 0.2d});
            // The raw (Key) null makes the call unchecked, so the result needs an explicit cast.
            Grid grid = (Grid) GridSearch.startGridSearch((Key) null, gbmParams, hyperParams).get();
            deletables.add(grid);
            Model[] baseModels = grid.getModels();
            deletables.addAll(Arrays.asList(baseModels));
            Assert.assertEquals(4L, baseModels.length);

            StackedEnsembleModel.StackedEnsembleParameters seParams =
                    new StackedEnsembleModel.StackedEnsembleParameters();
            seParams._train = trainSplit._key;
            seParams._blending = blendingSplit._key;
            seParams._response_column = "CAPSULE";
            seParams._base_models = grid.getModelKeys();
            seParams._seed = seed;
            StackedEnsembleModel seModel = new StackedEnsemble(seParams).trainModel().get();
            deletables.add(seModel);

            Frame predictions = seModel.score(test);
            deletables.add(predictions);
            // Presumably: predicted label plus one probability per CAPSULE class.
            Assert.assertEquals(3L, predictions.numCols());
            Assert.assertEquals(test.numRows(), predictions.numRows());
        } finally {
            for (Keyed keyed : deletables) {
                if (keyed instanceof Model) {
                    ((Model) keyed).deleteCrossValidationPreds();
                }
                keyed.remove();
            }
        }
    }

    /**
     * Trains a Stacked Ensemble over a cross-validated 2x2 GBM grid plus a cross-validated GLM on
     * cars_train, then scores cars_test.
     *
     * <p>Before scoring, the last test row's {@code cylinders} is overwritten with 7 and the column is
     * made categorical. Per the test name, presumably 7 is a level that never appears in training.
     * The prediction for that row must be positive.
     *
     * <p>All created objects are removed in {@code finally}.
     */
    @Test
    public void test_SE_with_GLM_can_do_predictions_on_frames_with_unseen_categorical_values() {
        List<Keyed> deletables = new ArrayList<>();
        try {
            final long seed = 62832L;
            Frame train = parse_test_file("./smalldata/testng/cars_train.csv");
            deletables.add(train);
            Frame test = parse_test_file("./smalldata/testng/cars_test.csv");
            deletables.add(test);

            int testCylIdx = test.find("cylinders");
            Assert.assertTrue(test.vec(testCylIdx).isInt());
            Vec testCyl = test.vec(testCylIdx);
            // Capture the row index now: testCyl is removed below when the column is replaced.
            final long lastRow = testCyl.length() - 1;
            testCyl.set(lastRow, 7L);
            test.replace(testCylIdx, testCyl.toCategoricalVec()).remove();
            DKV.put(test);
            Assert.assertTrue(test.vec(testCylIdx).isCategorical());

            int trainCylIdx = train.find("cylinders");
            train.replace(trainCylIdx, train.vec(trainCylIdx).toCategoricalVec()).remove();
            DKV.put(train);

            GBMModel.GBMParameters gbmParams = new GBMModel.GBMParameters();
            gbmParams._train = train._key;
            gbmParams._response_column = "economy (mpg)";
            gbmParams._seed = seed;
            gbmParams._keep_cross_validation_models = false;
            gbmParams._keep_cross_validation_predictions = true;
            gbmParams._fold_assignment = Model.Parameters.FoldAssignmentScheme.Modulo;
            gbmParams._nfolds = 5;
            Map<String, Object[]> hyperParams = new HashMap<>();
            hyperParams.put("_ntrees", new Integer[]{3, 5});
            hyperParams.put("_learn_rate", new Double[]{0.1d, 0.2d});
            // The raw (Key) null makes the call unchecked, so the result needs an explicit cast.
            Grid grid = (Grid) GridSearch.startGridSearch((Key) null, gbmParams, hyperParams).get();
            deletables.add(grid);
            Model[] gbmModels = grid.getModels();
            deletables.addAll(Arrays.asList(gbmModels));
            Assert.assertEquals(4L, gbmModels.length);

            GLMModel.GLMParameters glmParams = new GLMModel.GLMParameters();
            glmParams._train = train._key;
            glmParams._response_column = "economy (mpg)";
            glmParams._seed = seed;
            glmParams._keep_cross_validation_models = false;
            glmParams._keep_cross_validation_predictions = true;
            glmParams._fold_assignment = Model.Parameters.FoldAssignmentScheme.Modulo;
            glmParams._nfolds = 5;
            glmParams._alpha = new double[]{0.1d, 0.2d, 0.4d};
            glmParams._lambda_search = true;
            GLMModel glmModel = new GLM(glmParams).trainModel().get();
            deletables.add(glmModel);

            StackedEnsembleModel.StackedEnsembleParameters seParams =
                    new StackedEnsembleModel.StackedEnsembleParameters();
            seParams._train = train._key;
            seParams._response_column = "economy (mpg)";
            seParams._base_models = (Key[]) ArrayUtils.append(grid.getModelKeys(), new Key[]{glmModel._key});
            seParams._seed = seed;
            StackedEnsembleModel seModel = new StackedEnsemble(seParams).trainModel().get();
            deletables.add(seModel);

            Frame predictions = seModel.score(test);
            deletables.add(predictions);
            Assert.assertTrue(predictions.vec(0).at(lastRow) > 0.0d);
        } finally {
            for (Keyed keyed : deletables) {
                if (keyed instanceof Model) {
                    ((Model) keyed).deleteCrossValidationPreds();
                }
                keyed.remove();
            }
        }
    }

    /** Cross-validation mode (no blending frame) with the level-one frame kept. */
    @Test
    public void testKeepLevelOneFrameCVMode() {
        final boolean keepLevelOneFrame = true;
        final boolean blendingMode = false;
        testSEModelCanBeSafelyRemoved(keepLevelOneFrame, blendingMode);
    }

    /** Cross-validation mode (no blending frame) with the level-one frame discarded. */
    @Test
    public void testDoNotKeepLevelOneFrameCVMode() {
        final boolean keepLevelOneFrame = false;
        final boolean blendingMode = false;
        testSEModelCanBeSafelyRemoved(keepLevelOneFrame, blendingMode);
    }

    /** Blending mode (test frame used as blending frame) with the level-one frame kept. */
    @Test
    public void testKeepLevelOneFrameBlendingMode() {
        final boolean keepLevelOneFrame = true;
        final boolean blendingMode = true;
        testSEModelCanBeSafelyRemoved(keepLevelOneFrame, blendingMode);
    }

    /** Blending mode (test frame used as blending frame) with the level-one frame discarded. */
    @Test
    public void testDoNotKeepLevelOneFrameBlendingMode() {
        final boolean keepLevelOneFrame = false;
        final boolean blendingMode = true;
        testSEModelCanBeSafelyRemoved(keepLevelOneFrame, blendingMode);
    }

    /**
     * Trains a 4-model GBM grid on the prostate data and stacks it with a StackedEnsemble. It
     * checks the level-one frame according to {@code keepLevelOneFrame}, deletes the ensemble,
     * then retrains a GBM and a StackedEnsemble from the same parameters to show that the
     * training frame and the base models survived the deletion.
     *
     * @param keepLevelOneFrame value for {@code _keep_levelone_frame} on the ensemble
     * @param blending when true the test frame is used as the blending frame; otherwise base
     *                 models use 3-fold modulo cross-validation
     */
    private void testSEModelCanBeSafelyRemoved(boolean keepLevelOneFrame, boolean blending) {
        // Frames, grids and models all extend water.Lockable, so one list can clean up everything.
        final ArrayList<water.Lockable> deletables = new ArrayList<>();
        try {
            final Frame train = parse_test_file("./smalldata/logreg/prostate_train.csv");
            deletables.add(train);
            final Frame test = parse_test_file("./smalldata/logreg/prostate_test.csv");
            deletables.add(test);
            final String target = "CAPSULE";
            final int targetIdx = train.find(target);
            train.replace(targetIdx, train.vec(targetIdx).toCategoricalVec()).remove();
            DKV.put(train);
            test.replace(targetIdx, test.vec(targetIdx).toCategoricalVec()).remove();
            DKV.put(test);

            final GBMModel.GBMParameters gbmParams = new GBMModel.GBMParameters();
            gbmParams._train = train._key;
            gbmParams._response_column = target;
            gbmParams._seed = 1L;
            if (!blending) {
                gbmParams._nfolds = 3;
                gbmParams._keep_cross_validation_models = false;
                gbmParams._keep_cross_validation_predictions = true;
                gbmParams._fold_assignment = Model.Parameters.FoldAssignmentScheme.Modulo;
            }
            final HashMap<String, Object[]> hyperParams = new HashMap<>();
            hyperParams.put("_ntrees", new Integer[]{3, 5});
            hyperParams.put("_learn_rate", new Double[]{0.1d, 0.2d});
            final Grid grid = GridSearch.startGridSearch((Key) null, gbmParams, hyperParams).get();
            final Model[] gridModels = grid.getModels();
            deletables.add(grid);
            deletables.addAll(Arrays.asList(gridModels));
            Assert.assertEquals(4, gridModels.length);

            final StackedEnsembleModel.StackedEnsembleParameters seParams = new StackedEnsembleModel.StackedEnsembleParameters();
            seParams._train = train._key;
            seParams._response_column = target;
            seParams._base_models = grid.getModelKeys();
            seParams._seed = 1L;
            seParams._keep_levelone_frame = keepLevelOneFrame;
            if (blending) {
                seParams._blending = test._key;
            }
            final StackedEnsembleModel se = new StackedEnsemble(seParams).trainModel().get();
            deletables.add(se);
            deletables.add(se.score(test));

            if (keepLevelOneFrame) {
                final Frame levelOne = se._output._levelone_frame_id;
                // One column per base model plus the response column.
                Assert.assertEquals(gridModels.length + 1, levelOne.numCols());
                // The level-one frame is built from the blending frame in blending mode, else from train.
                final Frame levelOneSource = blending ? test : train;
                Assert.assertEquals(levelOneSource.numRows(), levelOne.numRows());
                TestUtil.assertBitIdentical(new Frame(new Vec[]{levelOneSource.vec(target)}), new Frame(new Vec[]{levelOne.vec(target)}));
            } else {
                Assert.assertNull(se._output._levelone_frame_id);
            }

            se.delete();
            // Retraining from the same parameters references the training frame and the grid
            // model keys, so it should fail if deleting the ensemble had removed them.
            deletables.add(new GBM(gbmParams).trainModel().get());
            deletables.add(new StackedEnsemble(seParams).trainModel().get());
        } finally {
            for (water.Lockable item : deletables) {
                if (item instanceof Model) {
                    ((Model) item).deleteCrossValidationPreds();
                }
                item.delete();
            }
        }
    }

    /**
     * Trains a GBM and a DRF base model plus a StackedEnsemble on top of them. It checks the
     * ensemble's scoring and metrics, then returns the ensemble's output after removing every
     * frame and model this method created.
     *
     * @param trainingFile path of the training data, parsed with {@code parse_test_file}
     * @param validationFile path of optional validation data, or {@code null}
     * @param prep prepares each parsed frame and returns the response column index; a negative
     *             value {@code i} is decoded as {@code ~i}
     * @param dupeTrainingFrameToValidationFrame when true, a copy of the training frame is used as
     *             the validation frame for all models
     * @param family distribution of the base models; for bernoulli, multinomial and modified_huber
     *             a non-categorical response is converted to categorical
     * @param metalearnerAlgo metalearner algorithm of the ensemble
     * @param blendingMode when true, the training data is split 60/40 into training and blending
     *             frames, and the DRF is trained without cross-validation
     * @return the output of the trained ensemble (the model itself has been deleted)
     */
    public StackedEnsembleModel.StackedEnsembleOutput basicEnsemble(String trainingFile, String validationFile, PrepData prep, boolean dupeTrainingFrameToValidationFrame, DistributionFamily family, Metalearner.Algorithm metalearnerAlgo, boolean blendingMode) {
        final HashSet<Frame> framesBefore = new HashSet<>(Arrays.asList(Frame.fetchAll()));
        GBMModel gbmModel = null;
        DRFModel drfModel = null;
        StackedEnsembleModel seModel = null;
        Frame trainFrame = null;
        Frame validFrame = null;
        Frame blendFrame = null;
        Frame preds = null;
        try {
            Scope.enter();
            trainFrame = parse_test_file(trainingFile);
            validFrame = validationFile != null ? parse_test_file(validationFile) : null;
            int responseIdx = prep.prep(trainFrame);
            if (validFrame != null) {
                prep.prep(validFrame);
            }
            // Decode the complemented index before any use. Decoding it only after the
            // categorical conversion below would index vecs() with a negative value.
            if (responseIdx < 0) {
                responseIdx = ~responseIdx;
            }

            if (blendingMode) {
                final Frame[] splits = ShuffleSplitFrame.shuffleSplitFrame(trainFrame,
                        new Key[]{Key.make(trainFrame._key + "_train"), Key.make(trainFrame._key + "_blending")},
                        new double[]{0.6d, 0.4d}, 1764L);
                trainFrame.remove();
                trainFrame = splits[0];
                blendFrame = splits[1];
            }

            final boolean categoricalResponse = family == DistributionFamily.bernoulli
                    || family == DistributionFamily.multinomial
                    || family == DistributionFamily.modified_huber;
            if (categoricalResponse && !trainFrame.vecs()[responseIdx].isCategorical()) {
                // replace() returns the old numeric vec; Scope tracks it so Scope.exit() frees it.
                Scope.track(trainFrame.replace(responseIdx, trainFrame.vecs()[responseIdx].toCategoricalVec()));
                if (validFrame != null) {
                    Scope.track(validFrame.replace(responseIdx, validFrame.vecs()[responseIdx].toCategoricalVec()));
                }
                if (blendFrame != null) {
                    Scope.track(blendFrame.replace(responseIdx, blendFrame.vecs()[responseIdx].toCategoricalVec()));
                }
            }
            DKV.put(trainFrame);
            if (validFrame != null) {
                DKV.put(validFrame);
            }
            if (blendFrame != null) {
                DKV.put(blendFrame);
            }

            final String response = trainFrame._names[responseIdx];

            // GBM base model: always trained with 5-fold modulo CV (also in blending mode).
            final GBMModel.GBMParameters gbmParams = new GBMModel.GBMParameters();
            gbmParams._train = trainFrame._key;
            gbmParams._valid = validFrame == null ? null : validFrame._key;
            gbmParams._response_column = response;
            gbmParams._ntrees = 5;
            gbmParams._distribution = family;
            gbmParams._max_depth = 4;
            gbmParams._min_rows = 1.0d;
            gbmParams._nbins = 50;
            gbmParams._learn_rate = 0.2f; // float literal widened to double, as in the original
            gbmParams._score_each_iteration = true;
            gbmParams._fold_assignment = Model.Parameters.FoldAssignmentScheme.Modulo;
            gbmParams._keep_cross_validation_predictions = true;
            gbmParams._nfolds = 5;
            gbmParams._seed = 1764L;
            if (dupeTrainingFrameToValidationFrame) {
                // Drop a parsed validation frame first; otherwise it would leak in the DKV.
                if (validFrame != null) {
                    validFrame.remove();
                }
                validFrame = new Frame(trainFrame);
                DKV.put(validFrame);
                gbmParams._valid = validFrame._key;
            }
            final GBM gbm = new GBM(gbmParams);
            gbmModel = (GBMModel) gbm.trainModel().get();
            Assert.assertTrue(gbm.isStopped());

            // DRF base model: cross-validated only outside blending mode.
            final DRFModel.DRFParameters drfParams = new DRFModel.DRFParameters();
            drfParams._train = trainFrame._key;
            drfParams._valid = validFrame == null ? null : validFrame._key;
            drfParams._response_column = response;
            drfParams._distribution = family;
            drfParams._ntrees = 5;
            drfParams._max_depth = 4;
            drfParams._min_rows = 1.0d;
            drfParams._nbins = 50;
            drfParams._score_each_iteration = true;
            drfParams._seed = 1764L;
            if (!blendingMode) {
                drfParams._fold_assignment = Model.Parameters.FoldAssignmentScheme.Modulo;
                drfParams._keep_cross_validation_predictions = true;
                drfParams._nfolds = 5;
            }
            final DRF drf = new DRF(drfParams);
            drfModel = (DRFModel) drf.trainModel().get();
            Assert.assertTrue(drf.isStopped());

            final StackedEnsembleModel.StackedEnsembleParameters seParams = new StackedEnsembleModel.StackedEnsembleParameters();
            seParams._train = trainFrame._key;
            seParams._valid = validFrame == null ? null : validFrame._key;
            seParams._blending = blendFrame == null ? null : blendFrame._key;
            seParams._response_column = response;
            seParams._metalearner_algorithm = metalearnerAlgo;
            seParams._base_models = new Key[]{gbmModel._key, drfModel._key};
            seParams._seed = 1764L;
            final StackedEnsemble seBuilder = new StackedEnsemble(seParams);
            seModel = (StackedEnsembleModel) seBuilder.trainModel().get();

            // Score a copy of the training frame and compare with the stored training metrics.
            final Frame trainCopy = new Frame(trainFrame);
            DKV.put(trainCopy);
            preds = seModel.score(trainCopy);
            Assert.assertTrue(seModel.testJavaScoring(trainCopy, preds, 1.0E-15d, 0.01d));
            Assert.assertTrue(seBuilder.isStopped());
            Assert.assertEquals(seModel._output._training_metrics.mse(), ModelMetrics.getFromDKV(seModel, trainCopy).mse(), 1.0E-15d);
            trainCopy.remove();

            if (validFrame != null) {
                final ModelMetrics validMetrics = seModel._output._validation_metrics;
                final Frame validCopy = new Frame(validFrame);
                DKV.put(validCopy);
                seModel.score(validCopy).remove();
                Assert.assertEquals(validMetrics.mse(), ModelMetrics.getFromDKV(seModel, validCopy).mse(), 1.0E-15d);
                validCopy.remove();
            }
            return seModel._output;
        } finally {
            if (trainFrame != null) {
                trainFrame.remove();
            }
            if (validFrame != null) {
                validFrame.remove();
            }
            if (blendFrame != null) {
                blendFrame.remove();
            }
            deleteModelAndCvArtifacts(gbmModel, true);
            deleteModelAndCvArtifacts(drfModel, !blendingMode);
            if (preds != null) {
                preds.delete();
            }
            // NOTE(review): this collects frames that existed before the call but are gone now.
            // It therefore detects deleted pre-existing frames, not frames leaked by this method,
            // which the message suggests was the intent. Confirm before changing the direction.
            final HashSet<Frame> missingFrames = new HashSet<>(framesBefore);
            missingFrames.removeAll(Arrays.asList(Frame.fetchAll()));
            Assert.assertEquals("finish with the same number of Frames as we started: " + missingFrames, 0L, missingFrames.size());
            if (seModel != null) {
                seModel.delete();
                seModel._output._metalearner.delete();
            }
            Scope.exit();
        }
    }

    /**
     * Deletes {@code model}. When {@code trainedWithCv} is true, it also removes the model's
     * cross-validation prediction frames, its holdout prediction frame and its CV models.
     * Does nothing for a {@code null} model.
     */
    private static void deleteModelAndCvArtifacts(Model<?, ?, ?> model, boolean trainedWithCv) {
        if (model == null) {
            return;
        }
        model.delete();
        if (trainedWithCv) {
            for (Key<?> cvPredKey : model._output._cross_validation_predictions) {
                cvPredKey.remove();
            }
            model._output._cross_validation_holdout_predictions_frame_id.remove();
            model.deleteCrossValidationModels();
        }
    }

    /**
     * Runs the missing-response scoring scenario once per metalearner algorithm except XGBoost.
     * Any exception fails the test; the exception is kept as the cause of the failure.
     */
    @Test
    public void test_SE_scoring_with_missing_response_column() {
        for (Metalearner.Algorithm algo : Metalearner.Algorithm.values()) {
            // NOTE(review): XGBoost is skipped, presumably because it is not available to this
            // test module — confirm.
            if (algo == Metalearner.Algorithm.xgboost) {
                continue;
            }
            try {
                test_SE_scoring_with_missing_response_column(algo);
            } catch (Exception e) {
                Log.err(e);
                // Chain the cause so the original stack trace reaches the test report
                // (Assert.fail would keep only the message).
                throw new AssertionError("StackedEnsemble scoring failed with algo " + algo + ": " + e.getMessage(), e);
            }
        }
    }

    /**
     * Trains a 4-model GBM grid and a binomial GLM on the prostate training data. It stacks the
     * grid models with {@code algorithm} as metalearner and scores a test frame whose "CAPSULE"
     * response column was removed. Asserts that the last predicted label is 0 or 1.
     *
     * @param algorithm metalearner algorithm of the StackedEnsemble
     */
    private void test_SE_scoring_with_missing_response_column(Metalearner.Algorithm algorithm) {
        // Frames, grids and models all extend water.Lockable, so one list can clean up everything.
        final ArrayList<water.Lockable> deletables = new ArrayList<>();
        try {
            final Frame train = parse_test_file("./smalldata/testng/prostate_train.csv");
            deletables.add(train);
            final Frame test = parse_test_file("./smalldata/testng/prostate_test.csv");
            deletables.add(test);
            final String target = "CAPSULE";
            final int targetIdx = train.find(target);
            train.replace(targetIdx, train.vec(targetIdx).toCategoricalVec()).remove();
            DKV.put(train);
            // Drop the response from the test frame: scoring must work without it.
            test.remove(targetIdx).remove();
            DKV.put(test);

            final GBMModel.GBMParameters gbmParams = new GBMModel.GBMParameters();
            gbmParams._train = train._key;
            gbmParams._response_column = target;
            gbmParams._seed = 1L;
            gbmParams._keep_cross_validation_models = false;
            gbmParams._keep_cross_validation_predictions = true;
            gbmParams._fold_assignment = Model.Parameters.FoldAssignmentScheme.Modulo;
            gbmParams._nfolds = 5;
            final HashMap<String, Object[]> hyperParams = new HashMap<>();
            hyperParams.put("_ntrees", new Integer[]{3, 5});
            hyperParams.put("_learn_rate", new Double[]{0.1d, 0.2d});
            final Grid grid = GridSearch.startGridSearch((Key) null, gbmParams, hyperParams).get();
            final Model[] gridModels = grid.getModels();
            deletables.add(grid);
            deletables.addAll(Arrays.asList(gridModels));
            Assert.assertEquals(4, gridModels.length);

            final GLMModel.GLMParameters glmParams = new GLMModel.GLMParameters();
            glmParams._train = train._key;
            glmParams._response_column = target;
            glmParams._family = GLMModel.GLMParameters.Family.binomial;
            glmParams._seed = 1L;
            glmParams._keep_cross_validation_models = false;
            glmParams._keep_cross_validation_predictions = true;
            glmParams._fold_assignment = Model.Parameters.FoldAssignmentScheme.Modulo;
            glmParams._nfolds = 5;
            glmParams._alpha = new double[]{0.1d, 0.2d, 0.4d};
            glmParams._lambda_search = true;
            final GLMModel glm = new GLM(glmParams).trainModel().get();
            deletables.add(glm);
            deletables.add(glm.score(test));

            final StackedEnsembleModel.StackedEnsembleParameters seParams = new StackedEnsembleModel.StackedEnsembleParameters();
            seParams._train = train._key;
            seParams._response_column = target;
            // NOTE(review): only the grid models are stacked; the GLM trained above is not a base
            // model. Confirm whether it was meant to be appended here.
            seParams._base_models = grid.getModelKeys();
            seParams._metalearner_algorithm = algorithm;
            seParams._seed = 1L;
            final StackedEnsembleModel se = new StackedEnsemble(seParams).trainModel().get();
            deletables.add(se);
            final Frame preds = se.score(test);
            deletables.add(preds);

            final double lastPrediction = preds.vec(0).at(test.vec(0).length() - 1);
            Assert.assertTrue(Arrays.asList(0.0d, 1.0d).contains(lastPrediction));
        } finally {
            for (water.Lockable item : deletables) {
                if (item instanceof Model) {
                    ((Model) item).deleteCrossValidationPreds();
                }
                item.delete();
            }
        }
    }

    /**
     * The metalearner fold column "class" is absent from the ensemble's training frame, so
     * constructing the StackedEnsemble must throw IllegalArgumentException.
     */
    @Test
    public void testMissingFoldColumn_trainingFrame() {
        GBMModel baseModel = null;
        try {
            Scope.enter();
            final Frame withFoldColumn = TestUtil.parse_test_file("./smalldata/iris/iris_wheader.csv");
            Scope.track(new Frame[]{withFoldColumn});
            // Parsed with column 4 skipped; the expected message below shows it has no 'class' column.
            final Frame withoutFoldColumn = TestUtil.parse_test_file("./smalldata/iris/iris_wheader.csv", new int[]{4});
            Scope.track(new Frame[]{withoutFoldColumn});

            final GBMModel.GBMParameters gbmParams = new GBMModel.GBMParameters();
            gbmParams._train = withFoldColumn._key;
            gbmParams._fold_column = "class";
            gbmParams._seed = 65261L;
            gbmParams._response_column = "petal_len";
            gbmParams._ntrees = 1;
            gbmParams._keep_cross_validation_predictions = true;
            baseModel = (GBMModel) new GBM(gbmParams).trainModel().get();
            Assert.assertNotNull(baseModel);

            final StackedEnsembleModel.StackedEnsembleParameters seParams = new StackedEnsembleModel.StackedEnsembleParameters();
            seParams._train = withoutFoldColumn._key;
            seParams._response_column = "petal_len";
            seParams._metalearner_algorithm = Metalearner.Algorithm.AUTO;
            seParams._base_models = new Key[]{baseModel._key};
            seParams._seed = 65261L;
            seParams._metalearner_fold_column = "class";

            expectedException.expect(IllegalArgumentException.class);
            expectedException.expectMessage("Specified fold column 'class' not found in one of the supplied data frames. Available column names are: [sepal_len, sepal_wid, petal_wid, petal_len]");
            new StackedEnsemble(seParams);
            Assert.fail("Expected the Stack Ensemble Model never to be initialized successfully.");
        } finally {
            Scope.exit();
            if (baseModel != null) {
                baseModel.deleteCrossValidationModels();
                baseModel.deleteCrossValidationPreds();
                baseModel.remove();
            }
        }
    }

    /**
     * The metalearner fold column "class" is absent from the ensemble's validation frame, so
     * constructing the StackedEnsemble must throw IllegalArgumentException.
     */
    @Test
    public void testMissingFoldColumn_validationFrame() {
        GBMModel baseModel = null;
        try {
            Scope.enter();
            final Frame withFoldColumn = TestUtil.parse_test_file("./smalldata/iris/iris_wheader.csv");
            Scope.track(new Frame[]{withFoldColumn});
            // Parsed with column 4 skipped; the expected message below shows it has no 'class' column.
            final Frame withoutFoldColumn = TestUtil.parse_test_file("./smalldata/iris/iris_wheader.csv", new int[]{4});
            Scope.track(new Frame[]{withoutFoldColumn});

            final GBMModel.GBMParameters gbmParams = new GBMModel.GBMParameters();
            gbmParams._train = withFoldColumn._key;
            gbmParams._valid = withFoldColumn._key;
            gbmParams._fold_column = "class";
            gbmParams._seed = 65261L;
            gbmParams._response_column = "petal_len";
            gbmParams._ntrees = 1;
            gbmParams._keep_cross_validation_predictions = true;
            baseModel = (GBMModel) new GBM(gbmParams).trainModel().get();
            Assert.assertNotNull(baseModel);

            final StackedEnsembleModel.StackedEnsembleParameters seParams = new StackedEnsembleModel.StackedEnsembleParameters();
            seParams._train = withFoldColumn._key;
            seParams._valid = withoutFoldColumn._key;
            seParams._response_column = "petal_len";
            seParams._metalearner_algorithm = Metalearner.Algorithm.AUTO;
            seParams._base_models = new Key[]{baseModel._key};
            seParams._seed = 65261L;
            seParams._metalearner_fold_column = "class";

            expectedException.expect(IllegalArgumentException.class);
            expectedException.expectMessage("Specified fold column 'class' not found in one of the supplied data frames. Available column names are: [sepal_len, sepal_wid, petal_len, petal_wid]");
            new StackedEnsemble(seParams);
            Assert.fail("Expected the Stack Ensemble Model never to be initialized successfully.");
        } finally {
            Scope.exit();
            if (baseModel != null) {
                baseModel.deleteCrossValidationModels();
                baseModel.deleteCrossValidationPreds();
                baseModel.remove();
            }
        }
    }

    /**
     * The metalearner fold column "class" is absent from the ensemble's blending frame, so
     * constructing the StackedEnsemble must throw IllegalArgumentException.
     */
    @Test
    public void testMissingFoldColumn_blendingFrame() {
        GBMModel baseModel = null;
        try {
            Scope.enter();
            final Frame withFoldColumn = TestUtil.parse_test_file("./smalldata/iris/iris_wheader.csv");
            Scope.track(new Frame[]{withFoldColumn});
            // Parsed with column 4 skipped; the expected message below shows it has no 'class' column.
            final Frame withoutFoldColumn = TestUtil.parse_test_file("./smalldata/iris/iris_wheader.csv", new int[]{4});
            Scope.track(new Frame[]{withoutFoldColumn});

            final GBMModel.GBMParameters gbmParams = new GBMModel.GBMParameters();
            gbmParams._train = withFoldColumn._key;
            gbmParams._fold_column = "class";
            gbmParams._seed = 65261L;
            gbmParams._response_column = "petal_len";
            gbmParams._ntrees = 1;
            gbmParams._keep_cross_validation_predictions = true;
            baseModel = (GBMModel) new GBM(gbmParams).trainModel().get();
            Assert.assertNotNull(baseModel);

            final StackedEnsembleModel.StackedEnsembleParameters seParams = new StackedEnsembleModel.StackedEnsembleParameters();
            seParams._train = withFoldColumn._key;
            seParams._blending = withoutFoldColumn._key;
            seParams._response_column = "petal_len";
            seParams._metalearner_algorithm = Metalearner.Algorithm.AUTO;
            seParams._base_models = new Key[]{baseModel._key};
            seParams._seed = 65261L;
            seParams._metalearner_fold_column = "class";

            expectedException.expect(IllegalArgumentException.class);
            expectedException.expectMessage("Specified fold column 'class' not found in one of the supplied data frames. Available column names are: [sepal_len, sepal_wid, petal_len, petal_wid]");
            new StackedEnsemble(seParams);
            Assert.fail("Expected the Stack Ensemble Model never to be initialized successfully.");
        } finally {
            Scope.exit();
            if (baseModel != null) {
                baseModel.deleteCrossValidationModels();
                baseModel.deleteCrossValidationPreds();
                baseModel.remove();
            }
        }
    }

    /**
     * The ensemble's training frame has its "class" column replaced by a string-typed copy
     * (appended at the end). Constructing the StackedEnsemble with "class" as the metalearner
     * fold column must throw IllegalArgumentException.
     */
    @Test
    public void testInvalidFoldColumn_trainingFrame() {
        GBMModel baseModel = null;
        try {
            Scope.enter();
            final Frame iris = TestUtil.parse_test_file("./smalldata/iris/iris_wheader.csv");
            Scope.track(new Frame[]{iris});

            final GBMModel.GBMParameters gbmParams = new GBMModel.GBMParameters();
            gbmParams._train = iris._key;
            gbmParams._fold_column = "class";
            gbmParams._seed = 65261L;
            gbmParams._response_column = "petal_len";
            gbmParams._ntrees = 1;
            gbmParams._keep_cross_validation_predictions = true;
            baseModel = (GBMModel) new GBM(gbmParams).trainModel().get();
            Assert.assertNotNull(baseModel);

            // New frame over the same vecs, with "class" swapped for a string vec.
            final Frame stringFolds = new Frame(Key.make(), iris.names(), iris.vecs());
            stringFolds.add("class", stringFolds.remove("class").toStringVec());
            DKV.put(stringFolds);
            Scope.track(new Frame[]{stringFolds});

            final StackedEnsembleModel.StackedEnsembleParameters seParams = new StackedEnsembleModel.StackedEnsembleParameters();
            seParams._train = stringFolds._key;
            seParams._response_column = "petal_len";
            seParams._metalearner_algorithm = Metalearner.Algorithm.AUTO;
            seParams._base_models = new Key[]{baseModel._key};
            seParams._seed = 24301L;
            seParams._metalearner_fold_column = "class";

            expectedException.expect(IllegalArgumentException.class);
            expectedException.expectMessage("Specified fold column 'class' not found in one of the supplied data frames. Available column names are: [sepal_len, sepal_wid, petal_wid, petal_len]");
            new StackedEnsemble(seParams);
            Assert.fail("Expected the Stack Ensemble Model never to be initialized successfully.");
        } finally {
            Scope.exit();
            if (baseModel != null) {
                baseModel.deleteCrossValidationModels();
                baseModel.deleteCrossValidationPreds();
                baseModel.remove();
            }
        }
    }
}
