/*
 * Decompiled with CFR 0.152.
 */
package hex.tree.drf;

import ai.h2o.algos.tree.INode;
import ai.h2o.algos.tree.INodeStat;
import hex.genmodel.GenModel;
import hex.genmodel.algos.tree.SharedTreeNode;
import hex.genmodel.algos.tree.SharedTreeSubgraph;
import hex.genmodel.algos.tree.TreeSHAP;
import hex.genmodel.easy.EasyPredictModelWrapper;
import hex.genmodel.easy.RowData;
import hex.genmodel.easy.exception.PredictException;
import hex.genmodel.easy.prediction.BinomialModelPrediction;
import hex.genmodel.easy.prediction.RegressionModelPrediction;
import hex.genmodel.utils.DistributionFamily;
import hex.tree.drf.DRF;
import hex.tree.drf.DRFModel;
import hex.util.NaiveTreeSHAP;
import java.io.IOException;
import junit.framework.TestCase;
import org.junit.Assert;
import org.junit.BeforeClass;
import org.junit.Test;
import water.Key;
import water.Keyed;
import water.MRTask;
import water.MemoryManager;
import water.Scope;
import water.TestUtil;
import water.fvec.Chunk;
import water.fvec.Frame;
import water.fvec.NewChunk;
import water.util.ArrayUtils;

/**
 * Checks that DRF SHAP contributions are consistent, for a gaussian model (response {@code age})
 * and a bernoulli model (response {@code survived}) trained on {@code titanic_alt.csv}.
 *
 * <ul>
 *   <li>Per tree: {@link TreeSHAP} agrees with {@link NaiveTreeSHAP}, and the naive contributions
 *       sum to the value {@code NaiveTreeSHAP.calculateContributions} returns.</li>
 *   <li>Per model: the row sums of {@code scoreContributions} pass {@code testJavaScoring}, and the
 *       in-cluster contributions exactly match those produced by the MOJO through
 *       {@link EasyPredictModelWrapper}.</li>
 * </ul>
 */
public class DRFPredictContribsTest
extends TestUtil {
    /** Number of trees each test trains, and the number of trees checked by {@link CheckTreeSHAPTask}. */
    private static final int NTREES = 5;

    @BeforeClass
    public static void stall() {
        stall_till_cloudsize(1);
    }

    @Test
    public void testPredictContribsGaussian() {
        Scope.enter();
        try {
            Frame fr = Scope.track(parse_test_file("smalldata/junit/titanic_alt.csv"));
            DRFModel drf = trainModel(fr, DistributionFamily.gaussian, "age");
            checkTreeSHAP(drf, fr);
        } finally {
            Scope.exit();
        }
    }

    @Test
    public void testScoreContributionsGaussian() throws IOException, PredictException {
        Scope.enter();
        try {
            Frame fr = Scope.track(parse_test_file("smalldata/junit/titanic_alt.csv"));
            DRFModel drf = trainModel(fr, DistributionFamily.gaussian, "age");
            checkScoreContributions(drf, fr, "contributions_regression_titanic", false);
        } finally {
            Scope.exit();
        }
    }

    @Test
    public void testPredictContribsBinomial() {
        Scope.enter();
        try {
            Frame fr = Scope.track(parse_test_file("smalldata/junit/titanic_alt.csv"));
            fr.toCategoricalCol(fr.find("survived"));
            DRFModel drf = trainModel(fr, DistributionFamily.bernoulli, "survived");
            checkTreeSHAP(drf, fr);
        } finally {
            Scope.exit();
        }
    }

    @Test
    public void testScoreContributionsBinomial() throws IOException, PredictException {
        Scope.enter();
        try {
            Frame fr = Scope.track(parse_test_file("smalldata/junit/titanic_alt.csv"));
            fr.toCategoricalCol(fr.find("survived"));
            DRFModel drf = trainModel(fr, DistributionFamily.bernoulli, "survived");
            checkScoreContributions(drf, fr, "contributions_binomial_titanic", true);
        } finally {
            Scope.exit();
        }
    }

    /**
     * Trains a small DRF model with a fixed seed on {@code fr} and tracks it in the current {@link Scope}.
     *
     * @param fr           training frame
     * @param distribution distribution family of the model
     * @param response     name of the response column
     * @return the trained model
     */
    private static DRFModel trainModel(Frame fr, DistributionFamily distribution, String response) {
        DRFModel.DRFParameters parms = new DRFModel.DRFParameters();
        parms._train = fr._key;
        parms._distribution = distribution;
        parms._response_column = response;
        parms._ntrees = NTREES;
        parms._max_depth = 4;
        parms._min_rows = 1.0;
        parms._nbins = 50;
        parms._score_each_iteration = true;
        parms._seed = 42L;
        DRFModel drf = new DRF(parms).trainModel().get();
        Scope.track_generic(drf);
        return drf;
    }

    /**
     * Runs {@link CheckTreeSHAPTask} for every one of the {@link #NTREES} trees over {@code fr}
     * adapted to the model's training layout.
     */
    private static void checkTreeSHAP(DRFModel drf, Frame fr) {
        // adaptTestForTrain works on the frame it is given; adapt a copy rather than fr itself.
        Frame adapted = new Frame(fr);
        drf.adaptTestForTrain(adapted, true, false);
        for (int tree = 0; tree < NTREES; tree++) {
            new CheckTreeSHAPTask(drf, tree).doAll(adapted);
        }
    }

    /**
     * Scores contributions in-cluster and checks them against both the model's Java scoring
     * (via row sums) and the MOJO's per-row contributions.
     *
     * @param drf              trained model
     * @param fr               frame to score
     * @param contributionsKey key name for the contributions frame
     * @param binomial         {@code true} to score rows with {@code predictBinomial},
     *                         {@code false} for {@code predictRegression}
     */
    private static void checkScoreContributions(DRFModel drf, Frame fr, String contributionsKey, boolean binomial)
            throws IOException, PredictException {
        Frame contributions = drf.scoreContributions(fr, Key.make(contributionsKey));
        Scope.track(contributions);

        // One numeric output column (3 is the value of Vec.T_NUM) holding the per-row sum of all
        // contribution columns; these sums are handed to testJavaScoring as the expected predictions.
        Frame contribsAggregated = new RowSumTask().doAll((byte) 3, contributions).outputFrame();
        Scope.track(contribsAggregated);
        Assert.assertTrue("Row sums of contributions should pass Java scoring check",
                drf.testJavaScoring(fr, contribsAggregated, 1.0E-5));

        EasyPredictModelWrapper.Config cfg = new EasyPredictModelWrapper.Config()
                .setModel(drf.toMojo())
                .setEnableContributions(true);
        EasyPredictModelWrapper wrapper = new EasyPredictModelWrapper(cfg);
        for (long row = 0L; row < fr.numRows(); row++) {
            RowData rd = toRowData(fr, drf._output._names, row);
            float[] mojoContribs = binomial
                    ? wrapper.predictBinomial(rd).contributions
                    : wrapper.predictRegression(rd).contributions;
            // Whole-row comparison: run once per row (not once per column).
            Assert.assertArrayEquals("Contributions should match, row=" + row,
                    toNumericRow(contributions, row), ArrayUtils.toDouble(mojoContribs), 0.0);
        }
    }

    /** Emits, for every row, the sum of the values of all input columns. */
    private static class RowSumTask
    extends MRTask<RowSumTask> {
        private RowSumTask() {
        }

        @Override
        public void map(Chunk[] cs, NewChunk nc) {
            for (int i = 0; i < cs[0]._len; i++) {
                double sum = 0.0;
                for (Chunk c : cs) {
                    sum += c.atd(i);
                }
                nc.addNum(sum);
            }
        }
    }

    /**
     * For each row, compares {@link TreeSHAP} against {@link NaiveTreeSHAP} on one tree (class index 0)
     * of the model, and checks that the naive contributions sum to the value NaiveTreeSHAP returns.
     */
    private static class CheckTreeSHAPTask
    extends MRTask<CheckTreeSHAPTask> {
        final DRFModel _model;
        final int _tree;
        // Built locally in setupLocal(); transient, so it is not serialized with the task.
        transient SharedTreeNode[] _nodes;

        private CheckTreeSHAPTask(DRFModel model, int tree) {
            this._model = model;
            this._tree = tree;
        }

        @Override
        protected void setupLocal() {
            SharedTreeSubgraph tree = this._model.getSharedTreeSubgraph(this._tree, 0);
            this._nodes = tree.nodesArray.toArray(new SharedTreeNode[0]);
        }

        @Override
        @SuppressWarnings({"rawtypes", "unchecked"}) // TreeSHAP/NaiveTreeSHAP are used through raw types here
        public void map(Chunk[] cs) {
            // The same node array supplies both the tree structure (INode) and node statistics (INodeStat).
            TreeSHAP treeSHAP = new TreeSHAP((INode[]) this._nodes, (INodeStat[]) this._nodes, 0);
            NaiveTreeSHAP naiveTreeSHAP = new NaiveTreeSHAP((INode[]) this._nodes, (INodeStat[]) this._nodes, 0, 0.0);
            // Buffers are reused across rows and cleared before each row.
            double[] row = MemoryManager.malloc8d(cs.length);
            float[] contribs = MemoryManager.malloc4f(cs.length);
            double[] naiveContribs = MemoryManager.malloc8d(cs.length);
            for (int i = 0; i < cs[0]._len; i++) {
                for (int j = 0; j < cs.length; j++) {
                    row[j] = cs[j].atd(i);
                    contribs[j] = 0.0f;
                    naiveContribs[j] = 0.0;
                }
                treeSHAP.calculateContributions((Object) row, contribs);
                double expValPred = naiveTreeSHAP.calculateContributions(row, naiveContribs);
                double contribPred = ArrayUtils.sum(naiveContribs);
                long globalRow = cs[0].start() + i;
                Assert.assertEquals("Naive contributions should sum to the naive result, row=" + globalRow,
                        expValPred, contribPred, 1.0E-5);
                Assert.assertArrayEquals("TreeSHAP should match NaiveTreeSHAP, row=" + globalRow,
                        naiveContribs, ArrayUtils.toDouble(contribs), 1.0E-5);
            }
        }
    }
}

