/*
 * Decompiled with CFR 0.152.
 */
package hex.pdp;

import hex.PartialDependence;
import hex.tree.gbm.GBM;
import hex.tree.gbm.GBMModel;
import org.junit.Assert;
import org.junit.BeforeClass;
import org.junit.Test;
import water.DKV;
import water.IcedWrapper;
import water.Key;
import water.Keyed;
import water.Scope;
import water.TestUtil;
import water.fvec.Frame;
import water.fvec.Vec;
import water.util.Log;
import water.util.TwoDimTable;

/**
 * Tests for {@link PartialDependence} computed on GBM models trained on the
 * prostate and weather data sets.
 *
 * <p>Most tests only verify that the computation completes and log the resulting
 * tables; the weighted and 2D tests additionally compare the produced tables
 * against reference results.
 */
public class PartialDependenceTest
extends TestUtil {
    /** Absolute tolerance used when comparing partial dependence values. */
    static final double _tot = 1.0E-10;

    /** Path of the prostate data set used by most tests. */
    private static final String PROSTATE_CSV = "smalldata/prostate/prostate.csv";

    /** Prostate columns that every test converts to categorical before training. */
    private static final String[] PROSTATE_CATEGORICAL_COLS = new String[]{"RACE", "GLEASON", "DPROS", "DCAPS", "CAPSULE"};

    @BeforeClass
    public static void setup() {
        PartialDependenceTest.stall_till_cloudsize(1);
    }

    /**
     * Parses the prostate data set and replaces each column listed in
     * {@link #PROSTATE_CATEGORICAL_COLS} with its categorical version.
     *
     * <p>The converted columns are re-added at the end of the frame. The caller owns
     * the returned frame and is responsible for {@code DKV.put} and cleanup.
     *
     * @return the parsed frame with categorical columns
     */
    private static Frame parseProstateWithCategoricals() {
        Frame fr = PartialDependenceTest.parse_test_file(PROSTATE_CSV);
        boolean converted = false;
        try {
            for (String name : PROSTATE_CATEGORICAL_COLS) {
                Vec numeric = fr.remove(name);
                fr.add(name, numeric.toCategoricalVec());
                numeric.remove();
            }
            converted = true;
        }
        finally {
            // The caller never sees the frame if conversion fails, so release it here.
            if (!converted) {
                fr.remove();
            }
        }
        return fr;
    }

    /**
     * Trains a GBM with default parameters on the given frame.
     *
     * @param train frame to train on
     * @param response name of the response column
     * @param ignored names of the columns to exclude from training
     * @return the trained model; the caller is responsible for removing it
     */
    private static GBMModel trainGbmModel(Frame train, String response, String ... ignored) {
        GBMModel.GBMParameters parms = new GBMModel.GBMParameters();
        parms._train = train._key;
        parms._ignored_columns = ignored;
        parms._response_column = response;
        return (GBMModel)new GBM(parms).trainModel().get();
    }

    /**
     * Creates a {@link PartialDependence} bound to the given model and frame.
     * The computation is not started; callers set any further options and then
     * invoke {@code execImpl().get()}.
     *
     * @param model model to explain
     * @param fr frame to compute the partial dependence on
     * @param nbins value assigned to {@code _nbins}
     * @return the configured, not yet executed, partial dependence
     */
    private static PartialDependence newPartialDependence(GBMModel model, Frame fr, int nbins) {
        PartialDependence pdp = new PartialDependence(Key.make());
        pdp._nbins = nbins;
        pdp._model_id = model._key;
        pdp._frame_id = fr._key;
        return pdp;
    }

    /**
     * Appends a constant column named {@code "weights"} to the frame and registers the
     * new vector (and the vector it was derived from) with the current {@link Scope}.
     * Must be called between {@code Scope.enter()} and {@code Scope.exit()}.
     *
     * @param fr frame to extend
     * @param weight constant value of the new column
     * @return index of the appended column
     */
    private static int appendConstantWeightColumn(Frame fr, double weight) {
        Vec template = fr.anyVec();
        Vec constant = template.makeCon(weight);
        // Track before adding so the vector is released even if the add fails.
        Scope.track(constant);
        Scope.track(template);
        fr.add(new String[]{"weights"}, new Vec[]{constant});
        return fr.numCols() - 1;
    }

    /** Logs every table produced by the given partial dependence computation. */
    private static void logPartialDependenceTables(PartialDependence pdp) {
        for (TwoDimTable table : pdp._partial_dependence_data) {
            Log.info(table);
        }
    }

    /** Removes every non-null argument, in the order given. */
    private static void removeAllNonNull(Keyed ... keyeds) {
        for (Keyed keyed : keyeds) {
            if (keyed != null) {
                keyed.remove();
            }
        }
    }

    /** Computes partial dependence for a binomial GBM on prostate and logs the tables. */
    @Test
    public void prostateBinary() {
        Frame fr = null;
        GBMModel model = null;
        PartialDependence partialDependence = null;
        try {
            fr = PartialDependenceTest.parseProstateWithCategoricals();
            DKV.put(fr);
            model = PartialDependenceTest.trainGbmModel(fr, "CAPSULE", "ID");
            partialDependence = PartialDependenceTest.newPartialDependence(model, fr, 10);
            partialDependence.execImpl().get();
            PartialDependenceTest.logPartialDependenceTables(partialDependence);
        }
        finally {
            PartialDependenceTest.removeAllNonNull(fr, model, partialDependence);
        }
    }

    /**
     * Verifies that, for a binomial GBM, a constant weight column yields the same
     * partial dependence tables as no weight column at all.
     */
    @Test
    public void prostateBinaryWeights() {
        Scope.enter();
        try {
            Frame fr = PartialDependenceTest.parseProstateWithCategoricals();
            Scope.track(fr);
            int weightIndex = PartialDependenceTest.appendConstantWeightColumn(fr, 2.0);
            DKV.put(fr);

            GBMModel model = PartialDependenceTest.trainGbmModel(fr, "CAPSULE", "ID");
            // Track immediately so the model is released even if a later step throws.
            Scope.track_generic(model);

            PartialDependence partialDependence = PartialDependenceTest.newPartialDependence(model, fr, 10);
            partialDependence._cols = new String[]{"AGE", "RACE"};
            partialDependence.execImpl().get();
            Scope.track_generic(partialDependence);

            PartialDependence partialDependenceW = PartialDependenceTest.newPartialDependence(model, fr, 10);
            partialDependenceW._cols = new String[]{"AGE", "RACE"};
            partialDependenceW._weight_column_index = weightIndex;
            partialDependenceW.execImpl().get();
            Scope.track_generic(partialDependenceW);

            // JUnit assertions are used so the checks run even when the JVM has -ea disabled.
            Assert.assertTrue("pdp with constant weight and without weight generated different answers for column AGE.",
                    PartialDependenceTest.equalTwoDimTables(partialDependence._partial_dependence_data[0], partialDependenceW._partial_dependence_data[0], _tot));
            Assert.assertTrue("pdp with constant weight and without weight generated different answers for column RACE.",
                    PartialDependenceTest.equalTwoDimTables(partialDependence._partial_dependence_data[1], partialDependenceW._partial_dependence_data[1], _tot));
        }
        finally {
            Scope.exit();
        }
    }

    /**
     * Verifies that requesting 2D partial dependence does not change the 1D tables, and
     * that each 2D table row matches statistics recomputed by scoring the model on a
     * frame with both columns held constant.
     */
    @Test
    public void prostate2Dpdp() {
        Scope.enter();
        try {
            Frame fr = PartialDependenceTest.parseProstateWithCategoricals();
            DKV.put(fr);
            Scope.track(fr);

            GBMModel model = PartialDependenceTest.trainGbmModel(fr, "CAPSULE", "ID");
            // Track immediately so the model is released even if a later step throws.
            Scope.track_generic(model);

            // Reference run: 1D partial dependence only.
            PartialDependence partialDependence1 = PartialDependenceTest.newPartialDependence(model, fr, 10);
            partialDependence1._cols = new String[]{"RACE", "VOL"};
            partialDependence1.execImpl().get();
            Scope.track_generic(partialDependence1);

            // Same 1D columns plus two 2D column pairs with user-defined split points.
            PartialDependence partialDependence = PartialDependenceTest.newPartialDependence(model, fr, 10);
            partialDependence._cols = new String[]{"RACE", "VOL"};
            partialDependence._col_pairs_2dpdp = new String[][]{{"AGE", "RACE"}, {"AGE", "PSA"}};
            partialDependence._user_cols = new String[]{"AGE", "PSA"};
            partialDependence._num_user_splits = new int[]{3, 3};
            partialDependence._user_splits_present = true;
            partialDependence._user_splits = new double[]{65.0, 61.0, 72.0, 1.4, 6.7, 20.0};
            partialDependence.execImpl().get();
            Scope.track_generic(partialDependence);

            Assert.assertTrue("pdp from 1d pdp only and pdp from 2d pdp differ for col RACE.",
                    PartialDependenceTest.equalTwoDimTables(partialDependence._partial_dependence_data[0], partialDependence1._partial_dependence_data[0], _tot));
            Assert.assertTrue("pdp from 1d pdp only and pdp from 2d pdp differ for col VOL.",
                    PartialDependenceTest.equalTwoDimTables(partialDependence._partial_dependence_data[1], partialDependence1._partial_dependence_data[1], _tot));

            double[] ageSplit = new double[]{65.0, 61.0, 72.0};
            double[] psaSplit = new double[]{1.4, 6.7, 20.0};
            double[] raceSplit = new double[]{0.0, 1.0, 2.0};
            double[] tstats = new double[3];
            // Tables 0 and 1 are the 1D results checked above; 2 and 3 are the 2D pairs.
            this.assertCorrect2Dpdp(fr, partialDependence._partial_dependence_data[2].getCellValues(), "AGE", "RACE", false, true, ageSplit, raceSplit, model, _tot, tstats);
            this.assertCorrect2Dpdp(fr, partialDependence._partial_dependence_data[3].getCellValues(), "AGE", "PSA", false, false, ageSplit, psaSplit, model, _tot, tstats);
        }
        finally {
            Scope.exit();
        }
    }

    /**
     * Checks every row of a 2D partial dependence table against values recomputed with
     * {@link #grab2DStats}.
     *
     * <p>Rows are expected in the order produced by iterating {@code colVals} in the outer
     * loop and {@code col2Vals} in the inner loop. Cells 0 and 1 of each row must equal the
     * two split values exactly; cells 2, 3 and 4 must be within {@code tot} of
     * {@code tstats[0]}, {@code tstats[1]} and {@code tstats[2]} respectively.
     *
     * @param fr frame the partial dependence was computed on
     * @param cellVs cell values of the 2D partial dependence table
     * @param col name of the first column of the pair
     * @param col2 name of the second column of the pair
     * @param cat whether {@code col} is categorical
     * @param cat2 whether {@code col2} is categorical
     * @param colVals split values used for {@code col}
     * @param col2Vals split values used for {@code col2}
     * @param model model that was explained
     * @param tot absolute tolerance for the statistic comparisons
     * @param tstats scratch array of length 3, overwritten for every row
     */
    public void assertCorrect2Dpdp(Frame fr, IcedWrapper[][] cellVs, String col, String col2, boolean cat, boolean cat2, double[] colVals, double[] col2Vals, GBMModel model, double tot, double[] tstats) {
        for (int row = 0; row < cellVs.length; ++row) {
            double expected1 = colVals[row / col2Vals.length];
            double expected2 = col2Vals[row % col2Vals.length];
            IcedWrapper[] cells = cellVs[row];
            Assert.assertTrue("row " + row + ": unexpected split value for " + col, expected1 == Double.parseDouble(cells[0].toString()));
            Assert.assertTrue("row " + row + ": unexpected split value for " + col2, expected2 == Double.parseDouble(cells[1].toString()));
            this.grab2DStats(tstats, fr, expected1, expected2, col, col2, cat, cat2, model);
            for (int stat = 0; stat < 3; ++stat) {
                double actual = Double.parseDouble(cells[stat + 2].toString());
                Assert.assertTrue("row " + row + ": statistic " + stat + " differs from recomputed value", Math.abs(tstats[stat] - actual) < tot);
            }
        }
    }

    /**
     * Recomputes reference statistics for one 2D partial dependence cell: scores the model
     * on a copy of {@code fr} in which {@code col} and {@code col2} are replaced by the
     * constants {@code value} and {@code value2}.
     *
     * <p>On return {@code tstats[0]} holds the mean and {@code tstats[1]} the sigma of
     * prediction column 2, and {@code tstats[2]} holds sigma divided by the square root of
     * the number of rows. All temporary frames and vectors are released on exit.
     *
     * @param tstats output array of length at least 3
     * @param fr frame the partial dependence was computed on
     * @param value constant to substitute for {@code col}
     * @param value2 constant to substitute for {@code col2}
     * @param col name of the first column
     * @param col2 name of the second column
     * @param cat whether {@code col} is categorical (its domain is copied to the constant)
     * @param cat2 whether {@code col2} is categorical (its domain is copied to the constant)
     * @param model model used for scoring
     */
    public void grab2DStats(double[] tstats, Frame fr, double value, double value2, String col, String col2, boolean cat, boolean cat2, GBMModel model) {
        Scope.enter();
        try {
            Frame tfr = ((Frame)fr._key.get()).deepCopy(Key.make().toString());
            Scope.track(tfr);
            Frame test = new Frame(tfr.names(), tfr.vecs());

            Vec orig = test.remove(col);
            Vec cons = orig.makeCon(value);
            Scope.track(cons);
            Scope.track(orig);
            if (cat) {
                cons.setDomain(tfr.vec(col).domain());
            }
            test.add(col, cons);

            Vec orig2 = test.remove(col2);
            Vec cons2 = orig2.makeCon(value2);
            Scope.track(cons2);
            Scope.track(orig2);
            if (cat2) {
                cons2.setDomain(tfr.vec(col2).domain());
            }
            test.add(col2, cons2);
            Scope.track(test);

            Frame preds = model.score(test);
            Scope.track(preds);
            // Column 2 of the predictions is presumably the positive-class probability - TODO confirm.
            Vec predicted = preds.vec(2);
            tstats[0] = predicted.mean();
            tstats[1] = predicted.sigma();
            tstats[2] = tstats[1] / Math.sqrt(preds.numRows());
        }
        finally {
            Scope.exit();
        }
    }

    /** Computes partial dependence with {@code _row_index} set to 1 and logs the tables. */
    @Test
    public void prostateBinaryRow() {
        Frame fr = null;
        GBMModel model = null;
        PartialDependence partialDependence = null;
        try {
            fr = PartialDependenceTest.parseProstateWithCategoricals();
            DKV.put(fr);
            model = PartialDependenceTest.trainGbmModel(fr, "CAPSULE", "ID");
            partialDependence = PartialDependenceTest.newPartialDependence(model, fr, 10);
            partialDependence._row_index = 1L;
            partialDependence.execImpl().get();
            PartialDependenceTest.logPartialDependenceTables(partialDependence);
        }
        finally {
            PartialDependenceTest.removeAllNonNull(fr, model, partialDependence);
        }
    }

    /** Verifies that restricting {@code _cols} to two columns produces exactly two tables. */
    @Test
    public void prostateBinaryPickCols() {
        Frame fr = null;
        GBMModel model = null;
        PartialDependence partialDependence = null;
        try {
            fr = PartialDependenceTest.parseProstateWithCategoricals();
            DKV.put(fr);
            model = PartialDependenceTest.trainGbmModel(fr, "CAPSULE", "ID");
            partialDependence = PartialDependenceTest.newPartialDependence(model, fr, 10);
            partialDependence._cols = new String[]{"DPROS", "GLEASON"};
            partialDependence.execImpl().get();
            PartialDependenceTest.logPartialDependenceTables(partialDependence);
            Assert.assertEquals(2, partialDependence._partial_dependence_data.length);
        }
        finally {
            PartialDependenceTest.removeAllNonNull(fr, model, partialDependence);
        }
    }

    /**
     * Verifies that, for a regression GBM, a constant weight column yields the same
     * partial dependence tables as no weight column at all.
     */
    @Test
    public void prostateRegressionWeighted() {
        Scope.enter();
        try {
            Frame fr = PartialDependenceTest.parseProstateWithCategoricals();
            Scope.track(fr);
            int weightIndex = PartialDependenceTest.appendConstantWeightColumn(fr, 2.0);
            DKV.put(fr);

            GBMModel model = PartialDependenceTest.trainGbmModel(fr, "AGE", "ID");
            Scope.track_generic(model);

            PartialDependence partialDependence = PartialDependenceTest.newPartialDependence(model, fr, 10);
            partialDependence._cols = new String[]{"AGE", "RACE"};
            partialDependence.execImpl().get();
            Scope.track_generic(partialDependence);

            PartialDependence partialDependenceW = PartialDependenceTest.newPartialDependence(model, fr, 10);
            partialDependenceW._weight_column_index = weightIndex;
            partialDependenceW._cols = new String[]{"AGE", "RACE"};
            partialDependenceW.execImpl().get();
            Scope.track_generic(partialDependenceW);

            // JUnit assertions are used so the checks run even when the JVM has -ea disabled.
            Assert.assertTrue("pdp with constant weight and without weight generated different answers for column AGE.",
                    PartialDependenceTest.equalTwoDimTables(partialDependence._partial_dependence_data[0], partialDependenceW._partial_dependence_data[0], _tot));
            Assert.assertTrue("pdp with constant weight and without weight generated different answers for column RACE.",
                    PartialDependenceTest.equalTwoDimTables(partialDependence._partial_dependence_data[1], partialDependenceW._partial_dependence_data[1], _tot));
        }
        finally {
            Scope.exit();
        }
    }

    /** Computes partial dependence for a regression GBM on prostate and logs the tables. */
    @Test
    public void prostateRegression() {
        Frame fr = null;
        GBMModel model = null;
        PartialDependence partialDependence = null;
        try {
            fr = PartialDependenceTest.parseProstateWithCategoricals();
            DKV.put(fr);
            model = PartialDependenceTest.trainGbmModel(fr, "AGE", "ID");
            partialDependence = PartialDependenceTest.newPartialDependence(model, fr, 10);
            partialDependence.execImpl().get();
            PartialDependenceTest.logPartialDependenceTables(partialDependence);
        }
        finally {
            PartialDependenceTest.removeAllNonNull(fr, model, partialDependence);
        }
    }

    /** Computes partial dependence for three weather columns with 33 bins and logs the tables. */
    @Test
    public void weatherBinary() {
        Frame fr = null;
        GBMModel model = null;
        PartialDependence partialDependence = null;
        try {
            fr = PartialDependenceTest.parse_test_file("smalldata/junit/weather.csv");
            model = PartialDependenceTest.trainGbmModel(fr, "RainTomorrow", "Date", "RISK_MM", "EvapMM");
            partialDependence = PartialDependenceTest.newPartialDependence(model, fr, 33);
            partialDependence._cols = new String[]{"Sunshine", "MaxWindPeriod", "WindSpeed9am"};
            partialDependence.execImpl().get();
            PartialDependenceTest.logPartialDependenceTables(partialDependence);
        }
        finally {
            PartialDependenceTest.removeAllNonNull(fr, model, partialDependence);
        }
    }
}

