package hex.pdp;

import hex.PartialDependence;
import hex.tree.gbm.GBM;
import hex.tree.gbm.GBMModel;
import org.junit.Assert;
import org.junit.BeforeClass;
import org.junit.Test;
import water.DKV;
import water.Key;
import water.TestUtil;
import water.fvec.Frame;
import water.fvec.Vec;
import water.util.Log;
import water.util.TwoDimTable;

/**
 * Tests for {@link PartialDependence}: each test trains a GBM on a small dataset, computes the
 * partial dependence tables for the model and logs them.
 *
 * <p>All DKV-resident objects (frame, model, partial-dependence job) are released in a
 * {@code finally} block so that each one is removed exactly once, whether the test passes or fails.
 */
public class PartialDependenceTest extends TestUtil {

    /** Prostate columns that are numerically encoded but treated as categoricals by these tests. */
    private static final String[] PROSTATE_CATEGORICALS = {"RACE", "GLEASON", "DPROS", "DCAPS", "CAPSULE"};

    @BeforeClass
    public static void setup() {
        stall_till_cloudsize(1);
    }

    @Test
    public void prostateBinary() {
        Frame frame = null;
        GBMModel model = null;
        PartialDependence pd = null;
        try {
            frame = parse_test_file("smalldata/prostate/prostate.csv");
            convertToCategorical(frame, PROSTATE_CATEGORICALS);
            model = trainGbm(frame, "CAPSULE", "ID");
            pd = new PartialDependence(Key.make());
            pd._nbins = 10;
            pd._model_id = model._key;
            pd._frame_id = frame._key;
            pd.execImpl().get();
            logTables(pd);
        } finally {
            if (frame != null) {
                frame.remove();
            }
            if (model != null) {
                model.remove();
            }
            if (pd != null) {
                pd.remove();
            }
        }
    }

    @Test
    public void prostateBinaryPickCols() {
        Frame frame = null;
        GBMModel model = null;
        PartialDependence pd = null;
        try {
            frame = parse_test_file("smalldata/prostate/prostate.csv");
            convertToCategorical(frame, PROSTATE_CATEGORICALS);
            model = trainGbm(frame, "CAPSULE", "ID");
            pd = new PartialDependence(Key.make());
            pd._cols = new String[]{"DPROS", "GLEASON"};
            pd._nbins = 10;
            pd._model_id = model._key;
            pd._frame_id = frame._key;
            pd.execImpl().get();
            logTables(pd);
            // One partial-dependence table is expected per requested column.
            Assert.assertEquals(2, pd._partial_dependence_data.length);
        } finally {
            if (frame != null) {
                frame.remove();
            }
            if (model != null) {
                model.remove();
            }
            if (pd != null) {
                pd.remove();
            }
        }
    }

    @Test
    public void prostateRegression() {
        Frame frame = null;
        GBMModel model = null;
        PartialDependence pd = null;
        try {
            frame = parse_test_file("smalldata/prostate/prostate.csv");
            convertToCategorical(frame, PROSTATE_CATEGORICALS);
            // Numeric response (AGE) makes this a regression model.
            model = trainGbm(frame, "AGE", "ID");
            pd = new PartialDependence(Key.make());
            pd._nbins = 10;
            pd._model_id = model._key;
            pd._frame_id = frame._key;
            pd.execImpl().get();
            logTables(pd);
        } finally {
            if (frame != null) {
                frame.remove();
            }
            if (model != null) {
                model.remove();
            }
            if (pd != null) {
                pd.remove();
            }
        }
    }

    @Test
    public void weatherBinary() {
        Frame frame = null;
        GBMModel model = null;
        PartialDependence pd = null;
        try {
            frame = parse_test_file("smalldata/junit/weather.csv");
            model = trainGbm(frame, "RainTomorrow", "Date", "RISK_MM", "EvapMM");
            pd = new PartialDependence(Key.make());
            pd._nbins = 33;
            pd._cols = new String[]{"Sunshine", "MaxWindPeriod", "WindSpeed9am"};
            pd._model_id = model._key;
            pd._frame_id = frame._key;
            pd.execImpl().get();
            logTables(pd);
            // One partial-dependence table is expected per requested column.
            Assert.assertEquals(3, pd._partial_dependence_data.length);
        } finally {
            if (frame != null) {
                frame.remove();
            }
            if (model != null) {
                model.remove();
            }
            if (pd != null) {
                pd.remove();
            }
        }
    }

    /**
     * Replaces each named column of {@code frame} with its categorical version, removes the
     * original vec, and re-publishes the modified frame to the DKV.
     *
     * @param frame frame to modify in place
     * @param cols  names of the columns to convert
     */
    private static void convertToCategorical(Frame frame, String... cols) {
        for (String col : cols) {
            Vec original = frame.remove(col);
            frame.add(col, original.toCategoricalVec());
            // The converted column is a new vec; the numeric original is no longer referenced.
            original.remove();
        }
        DKV.put(frame);
    }

    /**
     * Trains a default GBM on {@code train}.
     *
     * @param train    training frame (must already be in the DKV)
     * @param response name of the response column
     * @param ignored  columns to exclude from training
     * @return the trained model; the caller is responsible for removing it
     */
    private static GBMModel trainGbm(Frame train, String response, String... ignored) {
        GBMModel.GBMParameters parms = new GBMModel.GBMParameters();
        parms._train = train._key;
        parms._ignored_columns = ignored;
        parms._response_column = response;
        return (GBMModel) new GBM(parms).trainModel().get();
    }

    /** Logs every partial-dependence table produced by {@code pd}. */
    private static void logTables(PartialDependence pd) {
        for (TwoDimTable table : pd._partial_dependence_data) {
            Log.info(table);
        }
    }
}
