package ai.h2o.automl.targetencoding;

import hex.ModelMetricsBinomial;
import hex.ScoreKeeper;
import hex.genmodel.utils.DistributionFamily;
import hex.tree.gbm.GBM;
import hex.tree.gbm.GBMModel;
import java.util.Iterator;
import java.util.Map;
import org.junit.After;
import org.junit.Assert;
import org.junit.BeforeClass;
import org.junit.Test;
import water.Key;
import water.Scope;
import water.TestUtil;
import water.fvec.Frame;

/**
 * Compares GBM models trained on target-encoded Titanic data against a GBM trained on the raw
 * data. Each test encodes the {@code cabin}, {@code embarked} and {@code home.dest} columns with
 * a different holdout type, trains a GBM with AUC-based early stopping, and asserts that the AUC
 * on the test split is strictly higher than the AUC of the baseline model from
 * {@link #trainDefaultGBM(String)}.
 *
 * <p>Every test cleans up in a {@code finally} block (encoding map, model, tracked frames) so
 * that a failing assertion does not leave keys behind for subsequent tests.
 */
public class TargetEncodingTitanicBenchmark extends TestUtil {

    /** Name of the response column in all Titanic data sets used here. */
    private static final String RESPONSE_COLUMN = "survived";

    /** Name of the fold-assignment column added for the K-fold test. */
    private static final String FOLD_COLUMN = "fold";

    /** Seed shared by fold assignment, target encoding and GBM training. */
    private static final long SEED = 1234L;

    // Holdout-type codes handed to TargetEncoder.applyTargetEncoding.
    // NOTE(review): the names are inferred from the tests that use each value; presumably they
    // mirror the strategy constants declared by TargetEncoder -- confirm and reference those instead.
    private static final byte HOLDOUT_LEAVE_ONE_OUT = 0;
    private static final byte HOLDOUT_KFOLD = 1;
    private static final byte HOLDOUT_NONE = 2;

    @BeforeClass
    public static void setup() {
        stall_till_cloudsize(1);
    }

    /**
     * Encodes the training frame with the K-fold holdout type (using a 5-fold {@code fold}
     * column) and checks that the resulting GBM beats the baseline AUC.
     */
    @Test
    public void KFoldHoldoutTypeTest() {
        Scope.enter();
        GBMModel model = null;
        Map<String, Frame> encodingMap = null;
        try {
            // The original categorical columns and the fold column are hidden from the GBM.
            String[] ignoredColumns = {"cabin", "home.dest", "embarked", FOLD_COLUMN};
            TargetEncoder encoder = new TargetEncoder(new String[]{"cabin", "home.dest", "embarked"}, new BlendingParams(3.0d, 1.0d));

            Frame train = parse_test_file(Key.make("titanic_train_parsed"), "smalldata/gbm_test/titanic_train.csv");
            Frame valid = parse_test_file(Key.make("titanic_valid_parsed"), "smalldata/gbm_test/titanic_valid.csv");
            Frame test = parse_test_file(Key.make("titanic_test_parsed"), "smalldata/gbm_test/titanic_test.csv");
            asFactor(train, RESPONSE_COLUMN);
            asFactor(valid, RESPONSE_COLUMN);
            asFactor(test, RESPONSE_COLUMN);
            printOutColumnsMetadata(test);
            Scope.track(train, valid, test);

            TargetEncoderFrameHelper.addKFoldColumn(train, FOLD_COLUMN, 5, SEED);
            dropUnusedColumns(train, valid, test);

            encodingMap = encoder.prepareEncodingMap(train, RESPONSE_COLUMN, FOLD_COLUMN, true);

            // Training data uses the overload without the extra double argument; validation and
            // test data pass 0.0 explicitly (presumably a noise level -- TODO confirm).
            Frame trainEncoded = encoder.applyTargetEncoding(train, RESPONSE_COLUMN, encodingMap, HOLDOUT_KFOLD, FOLD_COLUMN, true, true, SEED);
            Frame validEncoded = encoder.applyTargetEncoding(valid, RESPONSE_COLUMN, encodingMap, HOLDOUT_NONE, FOLD_COLUMN, true, 0.0d, true, SEED);
            Frame testEncoded = encoder.applyTargetEncoding(test, RESPONSE_COLUMN, encodingMap, HOLDOUT_NONE, FOLD_COLUMN, true, 0.0d, true, SEED);
            Scope.track(trainEncoded, validEncoded, testEncoded);
            printOutColumnsMetadata(trainEncoded);

            // NOTE(review): this test uses the multinomial distribution while the other two use
            // AUTO; kept as found because switching it would change the trained model.
            GBMModel.GBMParameters parameters =
                    buildGbmParameters(trainEncoded, validEncoded, RESPONSE_COLUMN, DistributionFamily.multinomial, ignoredColumns);
            GBM gbm = new GBM(parameters);
            model = gbm.trainModel().get();
            System.out.println(model._output._variable_importances.toString(2, true));
            Assert.assertTrue("GBM job is expected to be stopped after training", gbm.isStopped());

            double aucWithEncoding = scoreAuc(model, testEncoded, RESPONSE_COLUMN);
            double aucWithoutEncoding = trainDefaultGBM(RESPONSE_COLUMN);
            System.out.println("AUC with encoding:" + aucWithEncoding);
            System.out.println("AUC without encoding:" + aucWithoutEncoding);
            Assert.assertTrue("Target encoding is expected to improve AUC", aucWithoutEncoding < aucWithEncoding);
        } finally {
            encodingMapCleanUp(encodingMap);
            deleteModel(model);
            Scope.exit();
        }
    }

    /**
     * Encodes the training frame with the leave-one-out holdout type and checks that the
     * resulting GBM beats the baseline AUC.
     */
    @Test
    public void leaveOneOutHoldoutTypeTest() {
        GBMModel model = null;
        Map<String, Frame> encodingMap = null;
        Scope.enter();
        try {
            // The same columns are target-encoded and then ignored (in raw form) by the GBM.
            String[] teColumns = {"cabin", "embarked", "home.dest"};
            TargetEncoder encoder = new TargetEncoder(teColumns, new BlendingParams(3.0d, 1.0d));

            Frame train = parse_test_file(Key.make("titanic_train_parsed"), "smalldata/gbm_test/titanic_train.csv");
            Frame valid = parse_test_file(Key.make("titanic_valid_parsed"), "smalldata/gbm_test/titanic_valid.csv");
            Frame test = parse_test_file(Key.make("titanic_test_parsed"), "smalldata/gbm_test/titanic_test.csv");
            asFactor(train, RESPONSE_COLUMN);
            asFactor(valid, RESPONSE_COLUMN);
            asFactor(test, RESPONSE_COLUMN);
            Scope.track(train, valid, test);
            dropUnusedColumns(train, valid, test);

            // No fold column is involved for leave-one-out, hence the null.
            encodingMap = encoder.prepareEncodingMap(train, RESPONSE_COLUMN, (String) null, true);

            Frame trainEncoded = encoder.applyTargetEncoding(train, RESPONSE_COLUMN, encodingMap, HOLDOUT_LEAVE_ONE_OUT, true, true, SEED);
            Frame validEncoded = encoder.applyTargetEncoding(valid, RESPONSE_COLUMN, encodingMap, HOLDOUT_NONE, true, 0.0d, true, SEED);
            Frame testEncoded = encoder.applyTargetEncoding(test, RESPONSE_COLUMN, encodingMap, HOLDOUT_NONE, true, 0.0d, true, SEED);
            Scope.track(trainEncoded, validEncoded, testEncoded);

            GBMModel.GBMParameters parameters =
                    buildGbmParameters(trainEncoded, validEncoded, RESPONSE_COLUMN, DistributionFamily.AUTO, teColumns);
            GBM gbm = new GBM(parameters);
            model = gbm.trainModel().get();
            Assert.assertTrue("GBM job is expected to be stopped after training", gbm.isStopped());
            System.out.println(model._output._variable_importances.toString(2, true));

            double aucWithEncoding = scoreAuc(model, testEncoded, RESPONSE_COLUMN);
            double aucWithoutEncoding = trainDefaultGBM(RESPONSE_COLUMN);
            System.out.println("AUC with encoding:" + aucWithEncoding);
            System.out.println("AUC without encoding:" + aucWithoutEncoding);
            Assert.assertTrue("Target encoding is expected to improve AUC", aucWithoutEncoding < aucWithEncoding);
        } finally {
            encodingMapCleanUp(encodingMap);
            deleteModel(model);
            Scope.exit();
        }
    }

    /**
     * Builds the encoding map from a separate holdout file and applies it to the training,
     * validation and test frames with the "none" holdout type, then checks that the resulting
     * GBM beats the baseline AUC.
     */
    @Test
    public void noneHoldoutTypeTest() {
        // Declared outside the try block so the model is deleted even when the test fails.
        GBMModel model = null;
        Map<String, Frame> encodingMap = null;
        Scope.enter();
        try {
            String[] teColumns = {"cabin", "embarked", "home.dest"};
            TargetEncoder encoder = new TargetEncoder(teColumns, new BlendingParams(3.0d, 1.0d));

            Frame train = parse_test_file(Key.make("titanic_train_parsed"), "smalldata/gbm_test/titanic_train_wteh.csv");
            Frame teHoldout = parse_test_file(Key.make("titanic_te_holdout_parsed"), "smalldata/gbm_test/titanic_te_holdout.csv");
            Frame valid = parse_test_file(Key.make("titanic_valid_parsed"), "smalldata/gbm_test/titanic_valid.csv");
            Frame test = parse_test_file(Key.make("titanic_test_parsed"), "smalldata/gbm_test/titanic_test.csv");
            asFactor(train, RESPONSE_COLUMN);
            asFactor(teHoldout, RESPONSE_COLUMN);
            asFactor(valid, RESPONSE_COLUMN);
            asFactor(test, RESPONSE_COLUMN);
            Scope.track(train, teHoldout, valid, test);
            dropUnusedColumns(train, teHoldout, valid, test);

            // The holdout file's "cabin" column is converted to a factor before it is used to
            // build the encoding map.
            Frame teHoldoutWithFactorCabin = asFactor(teHoldout, "cabin");
            Scope.track(teHoldoutWithFactorCabin);

            encodingMap = encoder.prepareEncodingMap(teHoldoutWithFactorCabin, RESPONSE_COLUMN, (String) null, true);

            Frame trainEncoded = encoder.applyTargetEncoding(train, RESPONSE_COLUMN, encodingMap, HOLDOUT_NONE, true, true, SEED);
            Frame validEncoded = encoder.applyTargetEncoding(valid, RESPONSE_COLUMN, encodingMap, HOLDOUT_NONE, true, 0.0d, true, SEED);
            Frame testEncoded = encoder.applyTargetEncoding(test, RESPONSE_COLUMN, encodingMap, HOLDOUT_NONE, true, 0.0d, true, SEED);
            Scope.track(trainEncoded, validEncoded, testEncoded);

            GBMModel.GBMParameters parameters =
                    buildGbmParameters(trainEncoded, validEncoded, RESPONSE_COLUMN, DistributionFamily.AUTO, teColumns);
            GBM gbm = new GBM(parameters);
            model = gbm.trainModel().get();
            Assert.assertTrue("GBM job is expected to be stopped after training", gbm.isStopped());
            System.out.println(model._output._variable_importances.toString(2, true));

            double aucWithEncoding = scoreAuc(model, testEncoded, RESPONSE_COLUMN);
            double aucWithoutEncoding = trainDefaultGBM(RESPONSE_COLUMN);
            System.out.println("AUC with encoding:" + aucWithEncoding);
            System.out.println("AUC without encoding:" + aucWithoutEncoding);
            Assert.assertTrue("Target encoding is expected to improve AUC", aucWithoutEncoding < aucWithEncoding);
        } finally {
            encodingMapCleanUp(encodingMap);
            deleteModel(model);
            Scope.exit();
        }
    }

    /**
     * Trains the baseline GBM on the raw (not target-encoded) Titanic data and returns its AUC
     * on the test split.
     *
     * @param responseColumn name of the response column
     * @return AUC computed from the third prediction column against the response column
     */
    private double trainDefaultGBM(String responseColumn) {
        GBMModel model = null;
        Scope.enter();
        try {
            // All three keys carry a "2" suffix so they cannot collide with the frames of the
            // calling test, which are still tracked under the un-suffixed names.
            Frame train = parse_test_file(Key.make("titanic_train_parsed2"), "smalldata/gbm_test/titanic_train.csv");
            Frame valid = parse_test_file(Key.make("titanic_valid_parsed2"), "smalldata/gbm_test/titanic_valid.csv");
            Frame test = parse_test_file(Key.make("titanic_test_parsed2"), "smalldata/gbm_test/titanic_test.csv");
            Scope.track(train, test, valid);
            dropUnusedColumns(train, valid, test);

            // NOTE(review): the response is not converted to a factor here and quasibinomial is
            // used instead of AUTO; kept as found.
            GBMModel.GBMParameters parameters =
                    buildGbmParameters(train, valid, responseColumn, DistributionFamily.quasibinomial, null);
            GBM gbm = new GBM(parameters);
            model = gbm.trainModel().get();
            Assert.assertTrue("Baseline GBM job is expected to be stopped after training", gbm.isStopped());

            Frame predictions = model.score(test);
            Scope.track(predictions);
            printOutFrameAsTable(predictions, false, predictions.numRows());
            return ModelMetricsBinomial.make(predictions.vec(2), test.vec(parameters._response_column))._auc._auc;
        } finally {
            deleteModel(model);
            Scope.exit();
        }
    }

    /**
     * Creates the GBM parameters shared by all models in this class: 1000 trees of depth 5,
     * scored every 10 trees, with early stopping on AUC (5 rounds, tolerance 0.001).
     *
     * @param train          training frame
     * @param valid          validation frame
     * @param responseColumn name of the response column
     * @param distribution   distribution family to train with
     * @param ignoredColumns columns to hide from the model, or {@code null} to leave the
     *                       parameter at its default
     */
    private static GBMModel.GBMParameters buildGbmParameters(Frame train, Frame valid, String responseColumn,
                                                             DistributionFamily distribution, String[] ignoredColumns) {
        GBMModel.GBMParameters parameters = new GBMModel.GBMParameters();
        parameters._train = train._key;
        parameters._valid = valid._key;
        parameters._response_column = responseColumn;
        parameters._score_tree_interval = 10;
        parameters._ntrees = 1000;
        parameters._max_depth = 5;
        parameters._distribution = distribution;
        parameters._stopping_tolerance = 0.001d;
        parameters._stopping_metric = ScoreKeeper.StoppingMetric.AUC;
        parameters._stopping_rounds = 5;
        if (ignoredColumns != null) {
            parameters._ignored_columns = ignoredColumns;
        }
        parameters._seed = SEED;
        return parameters;
    }

    /**
     * Scores {@code testFrame} with {@code model}, tracks the prediction frame in the current
     * scope and returns the AUC of the third prediction column against the response column.
     */
    private static double scoreAuc(GBMModel model, Frame testFrame, String responseColumn) {
        Frame predictions = model.score(testFrame);
        Scope.track(predictions);
        return ModelMetricsBinomial.make(predictions.vec(2), testFrame.vec(responseColumn))._auc._auc;
    }

    /** Removes the {@code name}, {@code ticket}, {@code boat} and {@code body} columns from every frame. */
    private static void dropUnusedColumns(Frame... frames) {
        for (Frame frame : frames) {
            frame.remove(new String[]{"name", "ticket", "boat", "body"});
        }
    }

    /** Deletes the model and its cross-validation models; does nothing for {@code null}. */
    private static void deleteModel(GBMModel model) {
        if (model != null) {
            model.delete();
            model.deleteCrossValidationModels();
        }
    }

    @After
    public void afterEach() {
        // Intentionally empty: every test cleans up after itself in its finally block.
    }

    /**
     * Deletes every frame held by the encoding map.
     *
     * @param map encoding map to clean up; {@code null} is accepted so this can be called from a
     *            {@code finally} block before the map has been created
     */
    private void encodingMapCleanUp(Map<String, Frame> map) {
        if (map == null) {
            return;
        }
        for (Frame encodingFrame : map.values()) {
            encodingFrame.delete();
        }
    }
}
