/*
 * Decompiled with CFR 0.152.
 */
package hex;

import hex.Model;
import hex.deeplearning.DeepLearning;
import hex.deeplearning.DeepLearningModel;
import hex.genmodel.utils.DistributionFamily;
import hex.glm.GLM;
import hex.glm.GLMModel;
import hex.tree.drf.DRF;
import hex.tree.drf.DRFModel;
import hex.tree.gbm.GBM;
import hex.tree.gbm.GBMModel;
import org.junit.Assert;
import org.junit.BeforeClass;
import org.junit.Test;
import water.DKV;
import water.Key;
import water.Keyed;
import water.MRTask;
import water.TestUtil;
import water.fvec.Chunk;
import water.fvec.Frame;
import water.fvec.Vec;
import water.rapids.ast.prims.advmath.AstKFold;
import water.util.ArrayUtils;

public class XValPredictionsCheck
extends TestUtil {
    @BeforeClass
    public static void setup() {
        XValPredictionsCheck.stall_till_cloudsize((int)1);
    }

    /*
     * WARNING - Removed try catching itself - possible behaviour change.
     */
    @Test
    public void testXValPredictions() {
        int nfolds = 3;
        Frame tfr = null;
        try {
            tfr = XValPredictionsCheck.parse_test_file((String)"smalldata/iris/iris_wheader.csv");
            Frame foldId = new Frame(new String[]{"foldId"}, new Vec[]{AstKFold.kfoldColumn((Vec)tfr.vec("class").makeZero(), (int)3, (long)543216789L)});
            tfr.add(foldId);
            DKV.put((Keyed)tfr);
            GBMModel.GBMParameters parms = new GBMModel.GBMParameters();
            parms._train = tfr._key;
            parms._response_column = "class";
            parms._ntrees = 1;
            parms._max_depth = 1;
            parms._fold_column = "foldId";
            parms._distribution = DistributionFamily.multinomial;
            parms._keep_cross_validation_predictions = true;
            GBM job = new GBM(parms);
            GBMModel gbm = (GBMModel)job.trainModel().get();
            this.checkModel((Model)gbm, foldId.anyVec(), 3);
            DRFModel.DRFParameters parmsDRF = new DRFModel.DRFParameters();
            parmsDRF._train = tfr._key;
            parmsDRF._response_column = "class";
            parmsDRF._ntrees = 1;
            parmsDRF._max_depth = 1;
            parmsDRF._fold_column = "foldId";
            parmsDRF._distribution = DistributionFamily.multinomial;
            parmsDRF._keep_cross_validation_predictions = true;
            DRF drfJob = new DRF(parmsDRF);
            DRFModel drf = (DRFModel)drfJob.trainModel().get();
            this.checkModel((Model)drf, foldId.anyVec(), 3);
            GLMModel.GLMParameters parmsGLM = new GLMModel.GLMParameters();
            parmsGLM._train = tfr._key;
            parmsGLM._response_column = "sepal_len";
            parmsGLM._fold_column = "foldId";
            parmsGLM._keep_cross_validation_predictions = true;
            GLM glmJob = new GLM(parmsGLM);
            GLMModel glm = (GLMModel)glmJob.trainModel().get();
            this.checkModel((Model)glm, foldId.anyVec(), 1);
            DeepLearningModel.DeepLearningParameters parmsDL = new DeepLearningModel.DeepLearningParameters();
            parmsDL._train = tfr._key;
            parmsDL._response_column = "class";
            parmsDL._hidden = new int[]{1};
            parmsDL._epochs = 1.0;
            parmsDL._fold_column = "foldId";
            parmsDL._keep_cross_validation_predictions = true;
            DeepLearning dlJob = new DeepLearning(parmsDL);
            DeepLearningModel dl = (DeepLearningModel)dlJob.trainModel().get();
            this.checkModel((Model)dl, foldId.anyVec(), 3);
        }
        finally {
            if (tfr != null) {
                tfr.remove();
            }
        }
    }

    /*
     * WARNING - Removed try catching itself - possible behaviour change.
     */
    void checkModel(Model m, Vec foldId, int nclass) {
        if (!(m instanceof DRFModel)) {
            Assert.assertEquals((long)m._output._training_metrics._nobs, (long)m._output._cross_validation_metrics._nobs);
        }
        try {
            Key[] xvalKeys = m._output._cross_validation_predictions;
            Key xvalKey = m._output._cross_validation_holdout_predictions_frame_id;
            final int[] id = new int[1];
            for (Key k : xvalKeys) {
                Frame preds = (Frame)DKV.getGet((Key)k);
                assert (preds.numRows() == foldId.length());
                Vec[] vecs = new Vec[nclass + 1];
                vecs[0] = foldId;
                if (nclass == 1) {
                    vecs[1] = preds.anyVec();
                } else {
                    System.arraycopy(preds.vecs(ArrayUtils.range((int)1, (int)nclass)), 0, vecs, 1, nclass);
                }
                new MRTask(){

                    public void map(Chunk[] cs) {
                        Chunk foldId = cs[0];
                        for (int r = 0; r < cs[0]._len; ++r) {
                            if (foldId.at8(r) == (long)id[0]) continue;
                            for (int i = 1; i < cs.length; ++i) {
                                assert (cs[i].atd(r) == 0.0);
                            }
                        }
                    }
                }.doAll(vecs);
                id[0] = id[0] + 1;
                preds.delete();
            }
            Keyed.remove((Key)xvalKey);
        }
        finally {
            m.delete(true);
        }
    }
}

