package hex.tree.xgboost;

import hex.SplitFrame;
import hex.genmodel.utils.DistributionFamily;
import hex.tree.xgboost.XGBoostModel;
import java.util.Arrays;
import java.util.Collection;
import org.junit.BeforeClass;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
import water.DKV;
import water.H2O;
import water.Key;
import water.Scope;
import water.TestUtil;
import water.fvec.Frame;

/**
 * Parameterized test that trains an XGBoost model for a given booster / distribution / response
 * combination and checks that scoring the same frame yields matching predictions regardless of
 * the value of the {@code sys.ai.h2o.xgboost.predict.native.enable} system property.
 *
 * <p>NOTE(review): the property presumably switches between the native and the Java predictor
 * implementation; the switch itself is not visible from this file - confirm against XGBoostModel.
 */
@RunWith(Parameterized.class)
public class XGBoostPredictImplComparisonTest extends TestUtil {

    /** System property toggled between the two scoring runs. */
    private static final String NATIVE_PREDICT_PROPERTY = "sys.ai.h2o.xgboost.predict.native.enable";

    /** Absolute tolerance used when comparing the two prediction frames. */
    private static final double ABS_DELTA = 1.0E-10d;

    /** Name of the {@code XGBoostParameters.Booster} constant under test. */
    @Parameterized.Parameter
    public String booster;

    /** Name of the {@code DistributionFamily} constant under test. */
    @Parameterized.Parameter(1)
    public String distribution;

    /** Name of the response column used for training. */
    @Parameterized.Parameter(2)
    public String response;

    @BeforeClass
    public static void setup() {
        stall_till_cloudsize(1);
    }

    /**
     * Builds the parameter table: every booster is combined with every
     * (distribution, response) pair, boosters varying slowest.
     *
     * @return one {@code {booster, distribution, response}} row per test case
     */
    @Parameterized.Parameters(name = "XGBoost(booster={0},distribution={1},response={2})")
    public static Collection<Object[]> data() {
        final String[] boosters = {"gbtree", "dart", "gblinear"};
        // {distribution, response column}
        final String[][] distributionCases = {
            {"AUTO", "AGE"},
            {"bernoulli", "CAPSULE"},
            {"multinomial", "CAPSULE"},
            {"gaussian", "AGE"},
            {"gamma", "AGE"},
            {"poisson", "AGE"},
            {"tweedie", "AGE"},
        };

        final Object[][] rows = new Object[boosters.length * distributionCases.length][];
        int next = 0;
        for (String boosterName : boosters) {
            for (String[] distributionCase : distributionCases) {
                rows[next++] = new Object[]{boosterName, distributionCase[0], distributionCase[1]};
            }
        }
        return Arrays.asList(rows);
    }

    /**
     * Trains a small model on a 70/30 split of the prostate data set, scores the validation
     * split once with the native-predict property set to {@code "true"} and once with
     * {@code "false"}, and asserts that both prediction frames are equal within tolerance.
     */
    @Test
    public void testPredictionsAreSame() {
        // Snapshot the JVM-wide property so this test does not leak its setting to other tests.
        final String previousSetting = System.getProperty(NATIVE_PREDICT_PROPERTY);
        Scope.enter();
        try {
            Frame data = Scope.track(parse_test_file("./smalldata/prostate/prostate.csv"));
            // Columns 1 and 3 are turned into categoricals; replace() hands back the old
            // vec, which is tracked so it gets cleaned up with the scope.
            Scope.track(data.replace(1, data.vecs()[1].toCategoricalVec()));
            Scope.track(data.replace(3, data.vecs()[3].toCategoricalVec()));
            DKV.put(data);

            // Split into train (70%) and validation (30%) frames.
            SplitFrame splitter = new SplitFrame(data, new double[]{0.7d, 0.3d}, null);
            splitter.exec().get();
            Key[] splitKeys = splitter._destination_frames;
            Frame trainFrame = Scope.track((Frame) splitKeys[0].get());
            Frame validFrame = Scope.track((Frame) splitKeys[1].get());

            XGBoostModel.XGBoostParameters params = new XGBoostModel.XGBoostParameters();
            params._booster = XGBoostModel.XGBoostParameters.Booster.valueOf(booster);
            params._distribution = DistributionFamily.valueOf(distribution);
            params._ntrees = 10;
            params._max_depth = 5;
            params._train = trainFrame._key;
            params._valid = validFrame._key;
            params._response_column = response;

            XGBoostModel model = new XGBoost(params).trainModel().get();
            Scope.track_generic(model);

            System.setProperty(NATIVE_PREDICT_PROPERTY, "true");
            Frame predsWithPropertyOn = Scope.track(model.score(validFrame));

            System.setProperty(NATIVE_PREDICT_PROPERTY, "false");
            Frame predsWithPropertyOff = Scope.track(model.score(validFrame));

            assertFrameEquals(predsWithPropertyOn, predsWithPropertyOff, ABS_DELTA, getRelDelta(params));
        } finally {
            try {
                restoreProperty(NATIVE_PREDICT_PROPERTY, previousSetting);
            } finally {
                Scope.exit();
            }
        }
    }

    /**
     * Restores a system property to a previously captured value.
     *
     * @param name  property name
     * @param value value to restore, or {@code null} if the property was not set
     */
    private static void restoreProperty(String name, String value) {
        if (value == null) {
            System.clearProperty(name);
        } else {
            System.setProperty(name, value);
        }
    }

    /**
     * Picks the relative tolerance for the prediction comparison.
     *
     * @param xGBoostParameters parameters of the trained model
     * @return {@code 0.001} when a GPU backend is in use, {@code 1e-6} for the
     *         {@code gblinear} booster, otherwise {@code null} (no relative tolerance passed)
     */
    private Double getRelDelta(XGBoostModel.XGBoostParameters xGBoostParameters) {
        if (usesGpu(xGBoostParameters)) {
            return 0.001d;
        }
        if ("gblinear".equals(booster)) {
            return 1.0E-6d;
        }
        return null;
    }

    /**
     * Tells whether the model is expected to run on a GPU backend.
     *
     * @param xGBoostParameters parameters of the trained model
     * @return {@code true} if the backend is {@code gpu}, or if it is {@code auto} and
     *         {@code XGBoost.hasGPU} reports a GPU (id 0) on the first cloud member
     */
    private boolean usesGpu(XGBoostModel.XGBoostParameters xGBoostParameters) {
        final XGBoostModel.XGBoostParameters.Backend backend = xGBoostParameters._backend;
        if (backend == XGBoostModel.XGBoostParameters.Backend.gpu) {
            return true;
        }
        return backend == XGBoostModel.XGBoostParameters.Backend.auto
            && XGBoost.hasGPU(H2O.CLOUD.members()[0], 0);
    }
}
