package ai.h2o.automl.leaderboard;

import ai.h2o.automl.AutoML;
import ai.h2o.automl.events.EventLog;
import hex.Model;
import hex.tree.gbm.GBM;
import hex.tree.gbm.GBMModel;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import org.junit.Assert;
import org.junit.BeforeClass;
import org.junit.Test;
import water.Key;
import water.Keyed;
import water.TestUtil;
import water.fvec.Frame;
import water.fvec.Vec;
import water.util.Log;
import water.util.TwoDimTable;

/**
 * Tests for {@link Leaderboard}: rendering to a {@link TwoDimTable} (empty leaderboards, with and
 * without a sort metric, and with extension columns) and the {@code rankTsv} output.
 *
 * <p>Every test releases the objects it creates in a {@code finally} block so that cleanup runs
 * exactly once, whether the test passes or fails.
 */
public class LeaderboardTest extends TestUtil {

    /** Key used by every test to obtain its {@link EventLog}. */
    private static final Key<AutoML> DUMMY = Key.make();

    @BeforeClass
    public static void setup() {
        stall_till_cloudsize(1);
    }

    @Test
    public void test_toTwoDimTable_with_empty_models_and_without_sort_metric() {
        assertEmptyLeaderboardTable("dummy_lb_no_sort_metric", null);
    }

    @Test
    public void test_toTwoDimTable_with_empty_models_and_with_sort_metric() {
        assertEmptyLeaderboardTable("dummy_lb_logloss_sort_metric", "logloss");
    }

    /**
     * Creates a leaderboard with no models over an empty frame and checks that it still renders
     * as a {@link TwoDimTable} carrying the "no models" description.
     *
     * @param projectName name passed to {@link Leaderboard#getOrMake}
     * @param sortMetric  sort metric passed to {@link Leaderboard#getOrMake}; may be {@code null}
     */
    private static void assertEmptyLeaderboardTable(String projectName, String sortMetric) {
        EventLog eventLog = EventLog.getOrMake(DUMMY);
        Leaderboard leaderboard = null;
        try {
            leaderboard = Leaderboard.getOrMake(projectName, eventLog, new Frame(new Vec[0]), sortMetric);
            TwoDimTable table = leaderboard.toTwoDimTable(new String[0]);
            Assert.assertNotNull("empty leaderboard should also produce a TwoDimTable", table);
            Assert.assertEquals("no models in this leaderboard", table.getTableDescription());
        } finally {
            // Single cleanup path: runs once whether the assertions pass or throw.
            if (leaderboard != null) {
                leaderboard.remove();
            }
            eventLog.remove();
        }
    }

    @Test
    public void test_rank_tsv() {
        EventLog eventLog = EventLog.getOrMake(DUMMY);
        Leaderboard leaderboard = null;
        GBMModel model = null;
        Frame train = null;
        try {
            train = parse_test_file("./smalldata/logreg/prostate_train.csv");

            GBMModel.GBMParameters params = new GBMModel.GBMParameters();
            params._train = train._key;
            params._nfolds = 2;
            params._seed = 1L;
            params._response_column = "CAPSULE";
            model = new GBM(params).trainModel().get();

            leaderboard = Leaderboard.getOrMake("dummy_rank_tsv", eventLog, (Frame) null, "mae");
            leaderboard.addModel(model._key);

            String rankTsv = leaderboard.rankTsv();
            Log.info(rankTsv);
            Assert.assertEquals(
                    "Error\n[0.3448260574357465, 0.19959320678410908, 0.44675855535636816, 0.19959320678410908, 0.31468498072970547]\n",
                    rankTsv);
        } finally {
            if (leaderboard != null) {
                leaderboard.remove();
            }
            eventLog.remove();
            if (model != null) {
                // The model was trained with nfolds > 0, so its CV models must be released too.
                model.deleteCrossValidationModels();
                model.delete();
            }
            if (train != null) {
                train.delete();
            }
        }
    }

    @Test
    public void test_leaderboard_table_with_extensions() {
        // Everything registered here is removed (cascading) in insertion order at the end.
        final List<Keyed<?>> toDelete = new ArrayList<>();
        try {
            final Frame train = parse_test_file("./smalldata/logreg/prostate_train.csv").toCategoricalCol("CAPSULE");
            toDelete.add(train);

            GBMModel.GBMParameters params = new GBMModel.GBMParameters();
            params._train = train._key;
            params._nfolds = 3;
            params._seed = 1L;
            params._response_column = "CAPSULE";
            GBMModel model = new GBM(params).trainModel().get();
            toDelete.add(model);

            EventLog eventLog = EventLog.getOrMake(DUMMY);
            toDelete.add(eventLog);

            Leaderboard leaderboard = Leaderboard.getOrMake("leaderboard_with_ext", eventLog, (Frame) null, (String) null);
            toDelete.add(leaderboard);
            leaderboard.setExtensionsProvider(new LeaderboardExtensionsProvider() {
                @Override
                public LeaderboardCell[] createExtensions(Model m) {
                    return new LeaderboardCell[]{new TrainingTime(m), new ScoringTimePerRow(m, (Frame) null, train)};
                }
            });
            leaderboard.addModel(model._key);

            // Default table: no extension columns.
            TwoDimTable table = leaderboard.toTwoDimTable(new String[0]);
            Assert.assertEquals(1L, table.getRowDim());
            Assert.assertEquals(7L, table.getColDim());
            Assert.assertEquals("Leaderboard for project leaderboard_with_ext", table.getTableHeader());
            Assert.assertEquals("models sorted in order of auc, best first", table.getTableDescription());
            Assert.assertArrayEquals(new String[]{"model_id", "auc", "logloss", "aucpr", "mean_per_class_error", "rmse", "mse"}, table.getColHeaders());
            Assert.assertArrayEquals(new String[]{"string", "double", "double", "double", "double", "double", "double"}, table.getColTypes());
            Assert.assertArrayEquals(new String[]{"%s", "%.6f", "%.6f", "%.6f", "%.6f", "%.6f", "%.6f"}, table.getColFormats());

            // "ALL" adds every extension column provided above.
            TwoDimTable allExt = leaderboard.toTwoDimTable(new String[]{"ALL"});
            Assert.assertEquals(9L, allExt.getColDim());
            Assert.assertArrayEquals(new String[]{"model_id", "auc", "logloss", "aucpr", "mean_per_class_error", "rmse", "mse", "training_time_ms", "predict_time_per_row_ms"}, allExt.getColHeaders());
            Assert.assertArrayEquals(new String[]{"string", "double", "double", "double", "double", "double", "double", "long", "double"}, allExt.getColTypes());
            Assert.assertArrayEquals(new String[]{"%s", "%.6f", "%.6f", "%.6f", "%.6f", "%.6f", "%.6f", "%s", "%.6f"}, allExt.getColFormats());
            Assert.assertTrue(allExt.get(0, 7) instanceof Long);
            Assert.assertTrue(allExt.get(0, 8) instanceof Double);
            Assert.assertTrue((Long) allExt.get(0, 7) > 0);
            Assert.assertTrue((Double) allExt.get(0, 8) > 0.0d);

            // A single named extension is added on its own.
            TwoDimTable oneExt = leaderboard.toTwoDimTable(new String[]{"training_time_ms"});
            Assert.assertEquals(8L, oneExt.getColDim());
            Assert.assertArrayEquals(new String[]{"model_id", "auc", "logloss", "aucpr", "mean_per_class_error", "rmse", "mse", "training_time_ms"}, oneExt.getColHeaders());

            // Unknown extension names are ignored.
            TwoDimTable withUnknown = leaderboard.toTwoDimTable(new String[]{"unknown", "training_time_ms"});
            Assert.assertEquals(8L, withUnknown.getColDim());
            Assert.assertArrayEquals(new String[]{"model_id", "auc", "logloss", "aucpr", "mean_per_class_error", "rmse", "mse", "training_time_ms"}, withUnknown.getColHeaders());
        } finally {
            for (Keyed<?> keyed : toDelete) {
                keyed.remove(true);
            }
        }
    }
}
