package hex.tree.xgboost;

import hex.tree.xgboost.XGBoostModel;
import java.util.Map;
import org.junit.Assert;
import org.junit.Test;
import water.H2O;

/**
 * Unit tests for {@link XGBoostModel} parameter handling.
 *
 * <p>Covers two things:
 * <ul>
 *   <li>how {@code createParams} translates {@code _nthread} into the {@code "nthread"} entry;</li>
 *   <li>what {@code XGBoostParameters.gpuIncompatibleParams()} reports.</li>
 * </ul>
 */
public class XGBoostModelTest {

    /**
     * Verifies the {@code "nthread"} entry that {@code createParams} produces for the CPU backend.
     *
     * <ul>
     *   <li>When {@code _nthread} is left at its default, {@code H2O.ARGS.nthreads} is expected.</li>
     *   <li>An explicit value below {@code H2O.ARGS.nthreads} is expected to be kept.</li>
     *   <li>An explicit value above {@code H2O.ARGS.nthreads} is expected to be capped to it.</li>
     * </ul>
     */
    @Test
    public void testCreateParamsNThreads() throws Exception {
        // _nthread left unset: expect the node-wide thread count.
        XGBoostModel.XGBoostParameters defaultThreads = cpuParameters();
        Assert.assertEquals(Integer.valueOf(H2O.ARGS.nthreads), createdNThread(defaultThreads));

        // Explicit value below the node-wide count: expect it unchanged.
        // NOTE(review): when H2O.ARGS.nthreads == 1 this sets _nthread = 0; confirm that case is intended.
        XGBoostModel.XGBoostParameters belowLimit = cpuParameters();
        belowLimit._nthread = H2O.ARGS.nthreads - 1;
        Assert.assertEquals(Integer.valueOf(H2O.ARGS.nthreads - 1), createdNThread(belowLimit));

        // Explicit value above the node-wide count: expect it capped at that count.
        XGBoostModel.XGBoostParameters aboveLimit = cpuParameters();
        aboveLimit._nthread = H2O.ARGS.nthreads + 1;
        Assert.assertEquals(Integer.valueOf(H2O.ARGS.nthreads), createdNThread(aboveLimit));
    }

    /**
     * Verifies that {@code max_depth} values of 16 and 0 are each reported as the single
     * GPU-incompatible parameter, with the expected message.
     */
    @Test
    public void gpuIncompatibleParametersMaxDepth() {
        XGBoostModel.XGBoostParameters parameters = new XGBoostModel.XGBoostParameters();

        parameters._max_depth = 16;
        Map<?, ?> tooDeep = parameters.gpuIncompatibleParams();
        // JUnit's contract is assertEquals(expected, actual); the order matters for failure messages.
        Assert.assertEquals(1, tooDeep.size());
        Assert.assertEquals(
                "16 . Max depth must be greater than 0 and lower than 16 for GPU backend.",
                tooDeep.get("max_depth"));

        parameters._max_depth = 0;
        Map<?, ?> zeroDepth = parameters.gpuIncompatibleParams();
        Assert.assertEquals(1, zeroDepth.size());
        Assert.assertEquals(
                "0 . Max depth must be greater than 0 and lower than 16 for GPU backend.",
                zeroDepth.get("max_depth"));
    }

    /**
     * Verifies that selecting the {@code lossguide} grow policy is reported as the single
     * GPU-incompatible parameter, with the enum constant itself as the map value.
     */
    @Test
    public void gpuIncompatibleParametersGrowPolicy() {
        XGBoostModel.XGBoostParameters parameters = new XGBoostModel.XGBoostParameters();
        parameters._grow_policy = XGBoostModel.XGBoostParameters.GrowPolicy.lossguide;

        Map<?, ?> incompatible = parameters.gpuIncompatibleParams();
        Assert.assertEquals(1, incompatible.size());
        Assert.assertEquals(
                XGBoostModel.XGBoostParameters.GrowPolicy.lossguide,
                incompatible.get("grow_policy"));
    }

    /** Creates a fresh parameter object with the CPU backend selected. */
    private static XGBoostModel.XGBoostParameters cpuParameters() {
        XGBoostModel.XGBoostParameters parameters = new XGBoostModel.XGBoostParameters();
        parameters._backend = XGBoostModel.XGBoostParameters.Backend.cpu;
        return parameters;
    }

    /**
     * Returns the {@code "nthread"} value that {@code createParams} builds for {@code parameters}.
     *
     * <p>Uses the same second argument ({@code 2}) and {@code null} {@code String[]} as the
     * original test. NOTE(review): the meaning of {@code 2} is not visible here (possibly a
     * class count) — confirm against {@code createParams}.
     */
    private static Object createdNThread(XGBoostModel.XGBoostParameters parameters) throws Exception {
        return XGBoostModel.createParams(parameters, 2, (String[]) null).get().get("nthread");
    }
}
