package hex.ensemble;

import hex.StackedEnsembleModel;
import hex.genmodel.utils.DistributionFamily;
import hex.tree.drf.DRF;
import hex.tree.drf.DRFModel;
import hex.tree.gbm.GBM;
import hex.tree.gbm.GBMModel;
import org.junit.Assert;
import org.junit.BeforeClass;
import org.junit.Test;
import water.DKV;
import water.Key;
import water.Scope;
import water.TestUtil;
import water.fvec.Frame;

/**
 * Verifies that training GBM and DRF base models (with a fold column and kept cross-validation
 * predictions) leaves the training frame's checksum unchanged, and that a StackedEnsemble can then
 * be trained on top of those base models.
 */
public class CheckSumTest extends TestUtil {
    @BeforeClass
    public static void stall() {
        stall_till_cloudsize(1);
    }

    @Test
    public void checkSumTest() {
        Frame gbmTrain = null;
        Frame drfTrain = null;
        GBMModel gbm = null;
        GBMModel.GBMParameters gbmParms = new GBMModel.GBMParameters();
        DRFModel drf = null;
        DRFModel.DRFParameters drfParms = new DRFModel.DRFParameters();
        StackedEnsembleModel ensemble = null;

        Scope.enter();
        try {
            Frame train = parse_test_file("./smalldata/stackedensembles/stacking_fold.csv");
            DKV.put(train);
            // Convert the response to categorical; the replaced (numeric) vec is tracked so the
            // Scope disposes of it on exit.
            int responseIdx = train.find("response");
            Scope.track(train.replace(responseIdx, train.vecs()[responseIdx].toCategoricalVec()));
            DKV.put(train._key, train);

            gbmParms._train = train._key;
            gbmParms._response_column = "response";
            gbmParms._ntrees = 10;
            gbmParms._max_depth = 3;
            gbmParms._min_rows = 2.0d;
            // Exact double value of 0.2f, kept to preserve the original training configuration.
            gbmParms._learn_rate = 0.20000000298023224d;
            gbmParms._distribution = DistributionFamily.bernoulli;
            gbmParms._fold_column = "fold_column";
            gbmParms._keep_cross_validation_predictions = true;
            gbmParms._seed = 1L;
            gbm = (GBMModel) new GBM(gbmParms).trainModel().get();
            gbmTrain = gbm._parms.train();
            // Training must not alter the training frame's contents.
            Assert.assertEquals(train.checksum(), gbmTrain.checksum());

            drfParms._train = train._key;
            drfParms._response_column = "response";
            drfParms._distribution = DistributionFamily.bernoulli;
            drfParms._fold_column = "fold_column";
            drfParms._keep_cross_validation_predictions = true;
            drfParms._seed = 1L;
            drf = (DRFModel) new DRF(drfParms).trainModel().get();
            drfTrain = drf._parms.train();
            Assert.assertEquals(train.checksum(), drfTrain.checksum());

            StackedEnsembleModel.StackedEnsembleParameters seParms =
                new StackedEnsembleModel.StackedEnsembleParameters();
            seParms._train = train._key;
            seParms._response_column = "response";
            seParms._base_models = new Key[]{gbm._key, drf._key};
            ensemble = (StackedEnsembleModel) new StackedEnsemble(seParms).trainModel().get();
        } finally {
            // A single cleanup pass, run exactly once whether the test passed or failed; the
            // nested finally guarantees the Scope is exited even if a removal throws.
            try {
                if (gbmTrain != null) {
                    gbmTrain.remove();
                }
                if (drfTrain != null) {
                    drfTrain.remove();
                }
                if (gbm != null) {
                    gbm.delete();
                    gbmParms._train.remove();
                    for (Key cvPreds : gbm._output._cross_validation_predictions) {
                        cvPreds.remove();
                    }
                    gbm._output._cross_validation_holdout_predictions_frame_id.remove();
                    gbm.deleteCrossValidationModels();
                }
                if (drf != null) {
                    drf.delete();
                    drfParms._train.remove();
                    for (Key cvPreds : drf._output._cross_validation_predictions) {
                        cvPreds.remove();
                    }
                    drf._output._cross_validation_holdout_predictions_frame_id.remove();
                    drf.deleteCrossValidationModels();
                }
                if (ensemble != null) {
                    ensemble.delete();
                    ensemble.remove();
                    ensemble._output._metalearner._output._training_metrics.remove();
                    ensemble._output._metalearner.remove();
                }
            } finally {
                Scope.exit();
            }
        }
    }
}
