/*
 * Decompiled with CFR 0.152.
 */
package hex.tree.gbm;

import hex.genmodel.utils.DistributionFamily;
import hex.tree.gbm.GBM;
import hex.tree.gbm.GBMModel;
import java.io.IOException;
import org.junit.Assert;
import org.junit.BeforeClass;
import org.junit.Test;
import water.DKV;
import water.Key;
import water.Keyed;
import water.Lockable;
import water.Scope;
import water.TestUtil;
import water.exceptions.H2OModelBuilderIllegalArgumentException;
import water.fvec.Frame;
import water.fvec.Vec;
import water.udf.CFuncRef;
import water.udf.JFuncUtils;
import water.udf.TestBernoulliCustomDistribution;
import water.util.FrameUtils;

public class GBMCustomDistributionTest
extends TestUtil {
    /** Training data shared by every model in this test. */
    private static final String TRAIN_FILE = "./smalldata/gbm_test/alphabet_cattest.csv";
    /** Response column; forced to categorical so the problem is binomial. */
    private static final String RESPONSE = "y";
    /** Absolute tolerance used when comparing default vs. custom training metrics. */
    private static final double TOLERANCE = 1.0E-4;

    @BeforeClass
    public static void setup() {
        stall_till_cloudsize(1);
    }

    /**
     * Verifies that GBM trained with the custom Bernoulli distribution function
     * ({@link TestBernoulliCustomDistribution}) produces the same training metrics
     * (MSE, RMSE, AUC, accuracy, precision, recall) as GBM trained with the built-in
     * {@code bernoulli} distribution, within {@link #TOLERANCE}. It also verifies that
     * requesting {@code DistributionFamily.custom} without a distribution function is
     * rejected with {@link H2OModelBuilderIllegalArgumentException}.
     */
    @Test
    public void testCustomDistribution() throws Exception {
        Frame fr = null;
        Frame fr2 = null;
        GBMModel gbm_default = null;
        GBMModel gbm_custom = null;
        CFuncRef func = this.bernoulliCustomDistribution();
        // Enter the scope before the try so the finally block only exits a scope we entered.
        Scope.enter();
        try {
            fr = parseWithCategoricalResponse(TRAIN_FILE);
            System.out.println("Creating default model GBM...");
            GBMModel.GBMParameters parms = new GBMModel.GBMParameters();
            parms._train = fr._key;
            parms._response_column = RESPONSE;
            parms._distribution = DistributionFamily.bernoulli;
            gbm_default = new GBM(parms).trainModel().get();
            Scope.track_generic(gbm_default);

            // A separately parsed copy of the data, so the custom model does not share the frame.
            fr2 = parseWithCategoricalResponse(TRAIN_FILE);
            System.out.println("Creating custom distribution model GBM...");
            GBMModel.GBMParameters customParms = new GBMModel.GBMParameters();
            customParms._train = fr2._key;
            customParms._distribution = DistributionFamily.custom;
            customParms._custom_distribution_func = func.toRef();
            customParms._response_column = RESPONSE;
            gbm_custom = new GBM(customParms).trainModel().get();
            Scope.track_generic(gbm_custom);

            assertSameMetric("MSE",
                    gbm_default._output._training_metrics.mse(),
                    gbm_custom._output._training_metrics.mse());
            assertSameMetric("RMSE",
                    gbm_default._output._training_metrics.rmse(),
                    gbm_custom._output._training_metrics.rmse());
            assertSameMetric("AUC",
                    gbm_default._output._training_metrics.auc_obj()._auc,
                    gbm_custom._output._training_metrics.auc_obj()._auc);
            assertSameMetric("accuracy",
                    gbm_default._output._training_metrics.cm().accuracy(),
                    gbm_custom._output._training_metrics.cm().accuracy());
            assertSameMetric("precision",
                    gbm_default._output._training_metrics.cm().precision(),
                    gbm_custom._output._training_metrics.cm().precision());
            assertSameMetric("recall",
                    gbm_default._output._training_metrics.cm().recall(),
                    gbm_custom._output._training_metrics.cm().recall());

            System.out.println("Creating custom distribution model GBM wrong setting...");
            GBMModel.GBMParameters invalidParms = new GBMModel.GBMParameters();
            invalidParms._train = fr._key;
            invalidParms._response_column = RESPONSE;
            invalidParms._distribution = DistributionFamily.custom;
            invalidParms._custom_distribution_func = null;
            try {
                GBMModel unexpected = new GBM(invalidParms).trainModel().get();
                // Track it so an unexpectedly built model is still cleaned up by Scope.exit().
                Scope.track_generic(unexpected);
                Assert.fail("Expected H2OModelBuilderIllegalArgumentException when distribution is custom"
                        + " and _custom_distribution_func is null");
            }
            catch (H2OModelBuilderIllegalArgumentException ex) {
                System.out.println("Catch illegal argument exception.");
            }
        }
        finally {
            FrameUtils.delete(new Lockable[]{fr, fr2, gbm_default, gbm_custom});
            DKV.remove(func.getKey());
            Scope.exit();
        }
    }

    /**
     * Parses {@code path}. If the response column {@link #RESPONSE} is not already
     * categorical, it is replaced in the returned frame with a categorical copy.
     * The replaced (original) vec returned by {@code Frame.replace} is tracked in the
     * current {@link Scope}, so it is freed on {@code Scope.exit()}.
     *
     * @param path test-data path passed to {@code parse_test_file}
     * @return the parsed frame with a categorical response column
     */
    private static Frame parseWithCategoricalResponse(String path) {
        Frame frame = parse_test_file(path);
        int responseIdx = frame.find(RESPONSE);
        if (!frame.vecs()[responseIdx].isCategorical()) {
            Scope.track(frame.replace(responseIdx, frame.vecs()[responseIdx].toCategoricalVec()));
        }
        return frame;
    }

    /**
     * Prints the pair of metric values and asserts they agree within {@link #TOLERANCE}.
     *
     * @param name     metric name used in the log line and the failure message
     * @param expected value from the built-in bernoulli model
     * @param actual   value from the custom-distribution model
     */
    private static void assertSameMetric(String name, double expected, double actual) {
        System.out.println("Test " + name + " is the same for default (" + expected + ") and custom (" + actual + ")");
        Assert.assertEquals(name + " differs between default and custom distribution", expected, actual, TOLERANCE);
    }

    /**
     * Loads the {@link TestBernoulliCustomDistribution} test function under the name
     * {@code "customDistribution.key"}. The caller removes {@code func.getKey()} from
     * DKV when done.
     *
     * @return reference to the loaded custom distribution function
     * @throws IOException if the function cannot be loaded
     */
    private CFuncRef bernoulliCustomDistribution() throws IOException {
        return JFuncUtils.loadTestFunc("customDistribution.key", TestBernoulliCustomDistribution.class);
    }
}

