package hex.tree.gbm;

import hex.genmodel.utils.DistributionFamily;
import hex.tree.gbm.GBMModel;
import java.io.IOException;
import org.junit.Assert;
import org.junit.BeforeClass;
import org.junit.Test;
import water.DKV;
import water.Key;
import water.Lockable;
import water.Scope;
import water.TestUtil;
import water.exceptions.H2OModelBuilderIllegalArgumentException;
import water.fvec.Frame;
import water.udf.CFuncRef;
import water.udf.JFuncUtils;
import water.udf.TestBernoulliCustomDistribution;
import water.util.FrameUtils;

/**
 * Tests GBM training with a user-supplied (custom) distribution function.
 *
 * <p>The custom function used here is loaded from {@link TestBernoulliCustomDistribution}. A GBM
 * trained with it is expected to reproduce the training metrics of a GBM trained with the
 * built-in {@link DistributionFamily#bernoulli} distribution on the same data.
 */
public class GBMCustomDistributionTest extends TestUtil {

    /** Absolute tolerance used when comparing metrics of the default and custom models. */
    private static final double DELTA = 1.0E-4d;

    /** Dataset shared by both models; its response column is made categorical before training. */
    private static final String DATASET = "./smalldata/gbm_test/alphabet_cattest.csv";

    /** Name of the response column in {@link #DATASET}. */
    private static final String RESPONSE = "y";

    @BeforeClass
    public static void setup() {
        stall_till_cloudsize(1);
    }

    /**
     * Trains one GBM with the built-in bernoulli distribution and one with the custom bernoulli
     * distribution, and checks that their training metrics agree within {@link #DELTA}. Then checks
     * that asking for {@link DistributionFamily#custom} without supplying a function is rejected
     * with {@link H2OModelBuilderIllegalArgumentException}.
     */
    @Test
    public void testCustomDistribution() throws Exception {
        Frame defaultFrame = null;
        Frame customFrame = null;
        GBMModel defaultModel = null;
        GBMModel customModel = null;
        // Kept separate so an unexpectedly successful "wrong setting" run cannot clobber customModel.
        GBMModel invalidModel = null;
        CFuncRef bernoulliCustomDistribution = bernoulliCustomDistribution();
        Scope.enter();
        try {
            defaultFrame = parseWithCategoricalResponse();
            System.out.println("Creating default model GBM...");
            GBMModel.GBMParameters defaultParams = new GBMModel.GBMParameters();
            defaultParams._train = defaultFrame._key;
            defaultParams._response_column = RESPONSE;
            defaultParams._distribution = DistributionFamily.bernoulli;
            defaultModel = Scope.track_generic(new GBM(defaultParams).trainModel().get());

            customFrame = parseWithCategoricalResponse();
            System.out.println("Creating custom distribution model GBM...");
            GBMModel.GBMParameters customParams = new GBMModel.GBMParameters();
            customParams._train = customFrame._key;
            customParams._distribution = DistributionFamily.custom;
            customParams._custom_distribution_func = bernoulliCustomDistribution.toRef();
            customParams._response_column = RESPONSE;
            customModel = Scope.track_generic(new GBM(customParams).trainModel().get());

            assertMetricEquals("MSE",
                    defaultModel._output._training_metrics.mse(),
                    customModel._output._training_metrics.mse());
            assertMetricEquals("RMSE",
                    defaultModel._output._training_metrics.rmse(),
                    customModel._output._training_metrics.rmse());
            assertMetricEquals("AUC",
                    defaultModel._output._training_metrics.auc_obj()._auc,
                    customModel._output._training_metrics.auc_obj()._auc);
            assertMetricEquals("accuracy",
                    defaultModel._output._training_metrics.cm().accuracy(),
                    customModel._output._training_metrics.cm().accuracy());
            assertMetricEquals("precision",
                    defaultModel._output._training_metrics.cm().precision(),
                    customModel._output._training_metrics.cm().precision());
            assertMetricEquals("recall",
                    defaultModel._output._training_metrics.cm().recall(),
                    customModel._output._training_metrics.cm().recall());

            System.out.println("Creating custom distribution model GBM wrong setting...");
            GBMModel.GBMParameters invalidParams = new GBMModel.GBMParameters();
            invalidParams._train = defaultFrame._key;
            invalidParams._response_column = RESPONSE;
            invalidParams._distribution = DistributionFamily.custom;
            invalidParams._custom_distribution_func = null;
            try {
                invalidModel = Scope.track_generic(new GBM(invalidParams).trainModel().get());
                // Without this the negative check passes vacuously when no exception is raised.
                Assert.fail("Expected H2OModelBuilderIllegalArgumentException when the custom "
                        + "distribution is requested without a custom distribution function");
            } catch (H2OModelBuilderIllegalArgumentException e) {
                System.out.println("Catch illegal argument exception.");
            }
        } finally {
            // Single cleanup path for both success and failure (replaces duplicated try/catch code).
            FrameUtils.delete(defaultFrame, customFrame, defaultModel, customModel, invalidModel);
            DKV.remove(bernoulliCustomDistribution.getKey());
            Scope.exit();
        }
    }

    /**
     * Parses {@link #DATASET} and converts its {@link #RESPONSE} column to categorical if it is not
     * already. The replaced (original) response vector is tracked in the current {@link Scope}.
     *
     * @return the parsed frame; the caller is responsible for deleting it
     */
    private Frame parseWithCategoricalResponse() {
        Frame frame = parse_test_file(DATASET);
        int responseIdx = frame.find(RESPONSE);
        if (!frame.vecs()[responseIdx].isCategorical()) {
            Scope.track(frame.replace(responseIdx, frame.vecs()[responseIdx].toCategoricalVec()));
        }
        return frame;
    }

    /**
     * Prints a default-vs-custom comparison line for one metric and asserts the two values agree
     * within {@link #DELTA}.
     *
     * @param metric   human-readable metric name used in the output and failure message
     * @param expected metric value of the model trained with the built-in distribution
     * @param actual   metric value of the model trained with the custom distribution
     */
    private static void assertMetricEquals(String metric, double expected, double actual) {
        System.out.println("Test " + metric + " is the same for default (" + expected
                + ") and custom (" + actual + ")");
        Assert.assertEquals(metric + " of custom distribution model differs from default",
                expected, actual, DELTA);
    }

    /**
     * Loads the test bernoulli custom distribution and registers it under
     * {@code "customDistribution.key"}.
     *
     * @return a reference to the loaded custom distribution function
     * @throws IOException if the function class cannot be loaded
     */
    private CFuncRef bernoulliCustomDistribution() throws IOException {
        return JFuncUtils.loadTestFunc("customDistribution.key", TestBernoulliCustomDistribution.class);
    }
}
