package hex.tree.xgboost;

import hex.genmodel.MojoModel;
import hex.genmodel.MojoReaderBackendFactory;
import hex.genmodel.algos.xgboost.XGBoostJavaMojoModel;
import hex.tree.xgboost.XGBoostModel;
import java.io.IOException;
import java.io.InputStream;
import org.junit.Assert;
import org.junit.BeforeClass;
import org.junit.Test;
import water.DKV;
import water.Key;
import water.Scope;
import water.TestUtil;
import water.fvec.Frame;
import water.util.Log;

/* loaded from: input_file:hex/tree/xgboost/DefaultMojoImplTest.class */
public class DefaultMojoImplTest extends TestUtil {
    @BeforeClass
    public static void setup() {
        stall_till_cloudsize(1);
    }

    @Test
    public void testGBTree() throws IOException {
        Scope.enter();
        try {
            Assert.assertEquals(XGBoostJavaMojoModel.class.getName(), trainModel(XGBoostModel.XGBoostParameters.Booster.gbtree).toMojo().getClass().getName());
            Scope.exit(new Key[0]);
        } catch (Throwable th) {
            Scope.exit(new Key[0]);
            throw th;
        }
    }

    @Test
    public void testDART() throws IOException {
        Scope.enter();
        try {
            Assert.assertEquals(XGBoostJavaMojoModel.class.getName(), trainModel(XGBoostModel.XGBoostParameters.Booster.dart).toMojo().getClass().getName());
            Scope.exit(new Key[0]);
        } catch (Throwable th) {
            Scope.exit(new Key[0]);
            throw th;
        }
    }

    @Test
    public void testGBLinear() throws IOException {
        Scope.enter();
        try {
            Assert.assertEquals(XGBoostJavaMojoModel.class.getName(), trainModel(XGBoostModel.XGBoostParameters.Booster.gblinear).toMojo().getClass().getName());
            Scope.exit(new Key[0]);
        } catch (Throwable th) {
            Scope.exit(new Key[0]);
            throw th;
        }
    }

    @Test
    public void testOldDARTMojoUsesNativeScoring() throws IOException {
        Scope.enter();
        try {
            Assert.assertEquals(XGBoostJavaMojoModel.class.getName(), MojoModel.load(MojoReaderBackendFactory.createReaderBackend(getClass().getResourceAsStream("oldDart.mojo"), MojoReaderBackendFactory.CachingStrategy.MEMORY)).getClass().getName());
            Scope.exit(new Key[0]);
        } catch (Throwable th) {
            Scope.exit(new Key[0]);
            throw th;
        }
    }

    private static XGBoostModel trainModel(XGBoostModel.XGBoostParameters.Booster booster) {
        Frame parse_test_file = parse_test_file("./smalldata/prostate/prostate.csv");
        Scope.track(new Frame[]{parse_test_file});
        Scope.track(parse_test_file.replace(1, parse_test_file.vecs()[1].toCategoricalVec()));
        Scope.track(parse_test_file.replace(3, parse_test_file.vecs()[3].toCategoricalVec()));
        DKV.put(parse_test_file);
        XGBoostModel.XGBoostParameters xGBoostParameters = new XGBoostModel.XGBoostParameters();
        xGBoostParameters._train = parse_test_file._key;
        xGBoostParameters._response_column = "AGE";
        xGBoostParameters._ignored_columns = new String[]{"ID"};
        xGBoostParameters._booster = booster;
        if (!XGBoostModel.XGBoostParameters.Booster.gblinear.equals(booster)) {
            xGBoostParameters._ntrees = 5;
        }
        XGBoostModel xGBoostModel = new XGBoost(xGBoostParameters).trainModel().get();
        Scope.track_generic(xGBoostModel);
        Log.info(new Object[]{xGBoostModel});
        return xGBoostModel;
    }
}
