package hex.pdp;

import hex.PartialDependence;
import hex.tree.gbm.GBM;
import hex.tree.gbm.GBMModel;
import org.junit.Assert;
import org.junit.BeforeClass;
import org.junit.Test;
import water.DKV;
import water.IcedWrapper;
import water.Key;
import water.Scope;
import water.TestUtil;
import water.fvec.Frame;
import water.fvec.Vec;
import water.util.Log;
import water.util.TwoDimTable;

/**
 * Tests for {@link PartialDependence} computed on GBM models trained on the prostate and weather
 * datasets: 1D PDPs (all columns, picked columns, a single row), constant-weight invariance, and
 * 2D PDPs checked against statistics obtained by scoring the model directly.
 */
public class PartialDependenceTest extends TestUtil {
    /** Absolute tolerance for comparing 2D PDP cells with manually computed scoring statistics. */
    static double _tot = 1.0E-10d;

    /** Absolute tolerance passed to {@code equalTwoDimTables} when comparing two PDP tables. */
    private static final double TABLE_TOLERANCE = 1.0E-10d;

    private static final String PROSTATE_PATH = "smalldata/prostate/prostate.csv";

    /** Prostate columns that the tests convert from numeric to categorical before training. */
    private static final String[] PROSTATE_CATEGORICAL_COLS = {"RACE", "GLEASON", "DPROS", "DCAPS", "CAPSULE"};

    @BeforeClass
    public static void setup() {
        stall_till_cloudsize(1);
    }

    /** PDP over all predictors of a binary (CAPSULE) GBM on prostate. */
    @Test
    public void prostateBinary() {
        Frame frame = null;
        GBMModel model = null;
        PartialDependence pdp = null;
        try {
            frame = parse_test_file(PROSTATE_PATH);
            convertToCategoricals(frame, PROSTATE_CATEGORICAL_COLS);
            DKV.put(frame);
            model = trainGbm(frame, "CAPSULE", "ID");
            pdp = newPdp(model, frame);
            pdp.execImpl().get();
            logTables(pdp);
        } finally {
            cleanup(frame, model, pdp);
        }
    }

    /** PDPs with a constant weight column (all 2.0) must match unweighted PDPs for a binary model. */
    @Test
    public void prostateBinaryWeights() {
        Scope.enter();
        try {
            checkConstantWeightsMatchUnweighted("CAPSULE");
        } finally {
            Scope.exit();
        }
    }

    /**
     * The 1D PDPs produced alongside 2D PDPs must equal those of a 1D-only run, and every 2D PDP cell
     * must match the statistics of scoring the model with both columns pinned to the grid point.
     */
    @Test
    public void prostate2Dpdp() {
        Scope.enter();
        try {
            Frame frame = parse_test_file(PROSTATE_PATH);
            convertToCategoricals(frame, PROSTATE_CATEGORICAL_COLS);
            DKV.put(frame);
            Scope.track(frame);
            GBMModel model = trainGbm(frame, "CAPSULE", "ID");
            // Track immediately so the model is cleaned up even if a PDP run below fails.
            Scope.track_generic(model);

            String[] oneDimCols = {"RACE", "VOL"};

            PartialDependence pdp1d = newPdp(model, frame);
            pdp1d._cols = new String[]{"RACE", "VOL"};
            pdp1d.execImpl().get();
            Scope.track_generic(pdp1d);

            PartialDependence pdp2d = newPdp(model, frame);
            pdp2d._cols = new String[]{"RACE", "VOL"};
            pdp2d._col_pairs_2dpdp = new String[][]{{"AGE", "RACE"}, {"AGE", "PSA"}};
            pdp2d._user_cols = new String[]{"AGE", "PSA"};
            pdp2d._num_user_splits = new int[]{3, 3};
            pdp2d._user_splits_present = true;
            // Splits laid out per user column (3 each): AGE -> {65, 61, 72}, PSA -> {1.4, 6.7, 20}.
            pdp2d._user_splits = new double[]{65.0d, 61.0d, 72.0d, 1.4d, 6.7d, 20.0d};
            pdp2d.execImpl().get();
            Scope.track_generic(pdp2d);

            for (int i = 0; i < oneDimCols.length; i++) {
                Assert.assertTrue(
                        "pdp from 1d pdp only and pdp from 2d pdp differ for col " + oneDimCols[i] + ".",
                        equalTwoDimTables(pdp2d._partial_dependence_data[i], pdp1d._partial_dependence_data[i], TABLE_TOLERANCE));
            }

            double[] ageSplits = {65.0d, 61.0d, 72.0d};
            double[] stats = new double[3];
            // RACE was converted to categorical above; these are its level indices.
            assertCorrect2Dpdp(frame, pdp2d._partial_dependence_data[2].getCellValues(), "AGE", "RACE",
                    false, true, ageSplits, new double[]{0.0d, 1.0d, 2.0d}, model, _tot, stats);
            assertCorrect2Dpdp(frame, pdp2d._partial_dependence_data[3].getCellValues(), "AGE", "PSA",
                    false, false, ageSplits, new double[]{1.4d, 6.7d, 20.0d}, model, _tot, stats);
        } finally {
            Scope.exit();
        }
    }

    /**
     * Asserts that each row of a 2D PDP table matches the statistics of scoring {@code model} on a
     * copy of {@code frame} with {@code col1}/{@code col2} pinned to that row's grid point.
     *
     * @param frame           training frame; re-fetched from the DKV via its key
     * @param pdpCells        2D PDP cell values; each row is {col1 value, col2 value, mean, sigma, sigma/sqrt(n)}
     * @param col1            first column of the pair (outer grid loop)
     * @param col2            second column of the pair (inner grid loop)
     * @param col1Categorical whether {@code col1} is categorical (its pinned values are level indices)
     * @param col2Categorical whether {@code col2} is categorical (its pinned values are level indices)
     * @param col1Values      grid values expected for {@code col1}
     * @param col2Values      grid values expected for {@code col2}
     * @param model           model the PDP was computed for
     * @param tolerance       absolute tolerance for the statistic columns
     * @param stats           scratch array of length >= 3, overwritten for every row
     */
    public void assertCorrect2Dpdp(Frame frame, IcedWrapper[][] pdpCells, String col1, String col2,
                                   boolean col1Categorical, boolean col2Categorical,
                                   double[] col1Values, double[] col2Values, GBMModel model,
                                   double tolerance, double[] stats) {
        // Without this check, a short (or empty) table would pass vacuously.
        Assert.assertEquals("2D pdp table must contain one row per grid point.",
                (long) col1Values.length * col2Values.length, pdpCells.length);
        for (int row = 0; row < pdpCells.length; row++) {
            IcedWrapper[] cells = pdpCells[row];
            double v1 = col1Values[row / col2Values.length];
            double v2 = col2Values[row % col2Values.length];
            Assert.assertEquals(v1, cellValue(cells[0]), 0.0d);
            Assert.assertEquals(v2, cellValue(cells[1]), 0.0d);
            grab2DStats(stats, frame, v1, v2, col1, col2, col1Categorical, col2Categorical, model);
            for (int k = 0; k < 3; k++) {
                Assert.assertTrue("2D pdp statistic " + k + " differs at row " + row + ".",
                        Math.abs(stats[k] - cellValue(cells[k + 2])) < tolerance);
            }
        }
    }

    /**
     * Scores {@code model} on a deep copy of {@code frame} where {@code col1} and {@code col2} are
     * replaced by constant columns, and writes the mean, sigma and sigma/sqrt(nrows) of score column 2
     * into {@code stats[0..2]}. All temporary frames and vecs are removed before returning.
     */
    public void grab2DStats(double[] stats, Frame frame, double v1, double v2, String col1, String col2,
                            boolean col1Categorical, boolean col2Categorical, GBMModel model) {
        Scope.enter();
        try {
            Frame copy = frame._key.get().deepCopy(Key.make().toString());
            Scope.track(copy);
            // New frame over the copy's vecs so columns can be swapped without touching `copy` itself.
            Frame pinned = new Frame(copy.names(), copy.vecs());

            Vec orig1 = pinned.remove(col1);
            Vec const1 = orig1.makeCon(v1);
            if (col1Categorical) {
                const1.setDomain(copy.vec(col1).domain());
            }
            pinned.add(col1, const1);

            Vec orig2 = pinned.remove(col2);
            Vec const2 = orig2.makeCon(v2);
            if (col2Categorical) {
                const2.setDomain(copy.vec(col2).domain());
            }
            pinned.add(col2, const2);

            Scope.track(pinned);
            Scope.track(const1);
            Scope.track(orig1);
            Scope.track(const2);
            Scope.track(orig2);

            Frame score = model.score(pinned);
            Scope.track(score);
            // NOTE(review): column 2 is presumably the class-1 probability of this binomial model.
            Vec prediction = score.vec(2);
            stats[0] = prediction.mean();
            stats[1] = prediction.sigma();
            stats[2] = stats[1] / Math.sqrt(score.numRows());
        } finally {
            Scope.exit();
        }
    }

    /** PDP restricted to a single row (row index 1) of prostate for a binary GBM. */
    @Test
    public void prostateBinaryRow() {
        Frame frame = null;
        GBMModel model = null;
        PartialDependence pdp = null;
        try {
            frame = parse_test_file(PROSTATE_PATH);
            convertToCategoricals(frame, PROSTATE_CATEGORICAL_COLS);
            DKV.put(frame);
            model = trainGbm(frame, "CAPSULE", "ID");
            pdp = newPdp(model, frame);
            pdp._row_index = 1L;
            pdp.execImpl().get();
            logTables(pdp);
        } finally {
            cleanup(frame, model, pdp);
        }
    }

    /** Requesting two columns must yield exactly two PDP tables. */
    @Test
    public void prostateBinaryPickCols() {
        Frame frame = null;
        GBMModel model = null;
        PartialDependence pdp = null;
        try {
            frame = parse_test_file(PROSTATE_PATH);
            convertToCategoricals(frame, PROSTATE_CATEGORICAL_COLS);
            DKV.put(frame);
            model = trainGbm(frame, "CAPSULE", "ID");
            pdp = newPdp(model, frame);
            pdp._cols = new String[]{"DPROS", "GLEASON"};
            pdp.execImpl().get();
            logTables(pdp);
            Assert.assertEquals(2, pdp._partial_dependence_data.length);
        } finally {
            cleanup(frame, model, pdp);
        }
    }

    /** PDPs with a constant weight column (all 2.0) must match unweighted PDPs for a regression model. */
    @Test
    public void prostateRegressionWeighted() {
        Scope.enter();
        try {
            checkConstantWeightsMatchUnweighted("AGE");
        } finally {
            Scope.exit();
        }
    }

    /** PDP over all predictors of a regression (AGE) GBM on prostate. */
    @Test
    public void prostateRegression() {
        Frame frame = null;
        GBMModel model = null;
        PartialDependence pdp = null;
        try {
            frame = parse_test_file(PROSTATE_PATH);
            convertToCategoricals(frame, PROSTATE_CATEGORICAL_COLS);
            DKV.put(frame);
            model = trainGbm(frame, "AGE", "ID");
            pdp = newPdp(model, frame);
            pdp.execImpl().get();
            logTables(pdp);
        } finally {
            cleanup(frame, model, pdp);
        }
    }

    /** PDP with 33 bins on three weather predictors of a binary (RainTomorrow) GBM. */
    @Test
    public void weatherBinary() {
        Frame frame = null;
        GBMModel model = null;
        PartialDependence pdp = null;
        try {
            frame = parse_test_file("smalldata/junit/weather.csv");
            model = trainGbm(frame, "RainTomorrow", "Date", "RISK_MM", "EvapMM");
            pdp = newPdp(model, frame);
            pdp._nbins = 33;
            pdp._cols = new String[]{"Sunshine", "MaxWindPeriod", "WindSpeed9am"};
            pdp.execImpl().get();
            logTables(pdp);
        } finally {
            cleanup(frame, model, pdp);
        }
    }

    /**
     * Trains a GBM for {@code response} on prostate with an extra constant "weights" column (all 2.0)
     * and asserts that the AGE and RACE PDPs are identical with and without that weight column.
     * Everything created is tracked in the caller's {@link Scope}.
     */
    private void checkConstantWeightsMatchUnweighted(String response) {
        Frame frame = parse_test_file(PROSTATE_PATH);
        convertToCategoricals(frame, PROSTATE_CATEGORICAL_COLS);
        Scope.track(frame);
        Vec template = frame.anyVec();
        Vec weights = template.makeCon(2.0d);
        frame.add(new String[]{"weights"}, new Vec[]{weights});
        Scope.track(template);
        Scope.track(weights);
        DKV.put(frame);

        GBMModel model = trainGbm(frame, response, "ID");
        // Track immediately so the model is cleaned up even if a PDP run below fails.
        Scope.track_generic(model);

        String[] cols = {"AGE", "RACE"};

        PartialDependence unweighted = newPdp(model, frame);
        unweighted._cols = cols.clone();
        unweighted.execImpl().get();
        Scope.track_generic(unweighted);

        PartialDependence weighted = newPdp(model, frame);
        weighted._cols = cols.clone();
        // The weight column was appended last.
        weighted._weight_column_index = frame.numCols() - 1;
        weighted.execImpl().get();
        Scope.track_generic(weighted);

        for (int i = 0; i < cols.length; i++) {
            Assert.assertTrue(
                    "pdp with constant weight and without weight generated different answers for column " + cols[i] + ".",
                    equalTwoDimTables(unweighted._partial_dependence_data[i], weighted._partial_dependence_data[i], TABLE_TOLERANCE));
        }
    }

    /** Replaces each named column of {@code frame} (in place) with its categorical conversion. */
    private static void convertToCategoricals(Frame frame, String... cols) {
        for (String col : cols) {
            Vec numeric = frame.remove(col);
            frame.add(col, numeric.toCategoricalVec());
            numeric.remove();
        }
    }

    /** Trains a default GBM on {@code train} for {@code response}, ignoring the given columns. */
    private static GBMModel trainGbm(Frame train, String response, String... ignoredCols) {
        GBMModel.GBMParameters parms = new GBMModel.GBMParameters();
        parms._train = train._key;
        parms._ignored_columns = ignoredCols;
        parms._response_column = response;
        return (GBMModel) new GBM(parms).trainModel().get();
    }

    /** Creates (but does not run) a PDP job with 10 bins for {@code model} over {@code frame}. */
    private static PartialDependence newPdp(GBMModel model, Frame frame) {
        PartialDependence pdp = new PartialDependence(Key.make());
        pdp._nbins = 10;
        pdp._model_id = model._key;
        pdp._frame_id = frame._key;
        return pdp;
    }

    /** Logs every PDP table produced by {@code pdp}. */
    private static void logTables(PartialDependence pdp) {
        for (TwoDimTable table : pdp._partial_dependence_data) {
            Log.info(table);
        }
    }

    /** Parses a PDP table cell as a double. */
    private static double cellValue(IcedWrapper cell) {
        return Double.parseDouble(cell.toString());
    }

    /** Removes whichever of the given objects were created; null arguments are skipped. */
    private static void cleanup(Frame frame, GBMModel model, PartialDependence pdp) {
        if (frame != null) {
            frame.remove();
        }
        if (model != null) {
            model.remove();
        }
        if (pdp != null) {
            pdp.remove();
        }
    }
}
