/*
 * Decompiled with CFR 0.152.
 */
package hex;

import hex.genmodel.utils.DistributionFamily;
import hex.tree.gbm.GBM;
import hex.tree.gbm.GBMModel;
import junit.framework.TestCase;
import org.junit.Assert;
import org.junit.BeforeClass;
import org.junit.Test;
import water.DKV;
import water.Key;
import water.Keyed;
import water.MRTask;
import water.TestUtil;
import water.fvec.Chunk;
import water.fvec.Frame;
import water.fvec.Vec;
import water.rapids.ast.prims.advmath.AstKFold;
import water.util.ArrayUtils;

/**
 * Checks that a GBM trained with {@code _keep_cross_validation_fold_assignment = true} exposes a
 * fold-assignment frame on its output. Covers both an explicit fold column
 * ({@code _fold_column}) and implicitly generated folds ({@code _nfolds}).
 */
public class CrossValidFoldAssignmentsTest extends TestUtil {

    /** Number of cross-validation folds used by every test in this class. */
    private static final int NFOLDS = 3;

    /** Seed passed to {@link AstKFold#kfoldColumn} so the explicit fold column is reproducible. */
    private static final long FOLD_SEED = 543216789L;

    @BeforeClass
    public static void setup() {
        stall_till_cloudsize(1);
    }

    /**
     * Trains with an explicit fold column. The kept fold-assignment frame must have one row per
     * training row and be bit-identical to the supplied fold column.
     */
    @Test
    public void checkFoldAssignmentsAreKeptWithoutMakeCopy() {
        Frame tfr = null;
        Frame cvFoldAssignmentFrame = null;
        GBMModel gbm = null;
        try {
            tfr = parse_test_file("smalldata/iris/iris_wheader.csv");
            Frame foldId = addFoldColumn(tfr);
            DKV.put(tfr);

            GBMModel.GBMParameters parms = baseParameters(tfr);
            parms._fold_column = "foldId";
            gbm = new GBM(parms).trainModel().get();

            Assert.assertNotNull(gbm._output._cross_validation_fold_assignment_frame_id);
            cvFoldAssignmentFrame = DKV.getGet(gbm._output._cross_validation_fold_assignment_frame_id);
            Assert.assertEquals(tfr.numRows(), cvFoldAssignmentFrame.numRows());
            assertBitIdentical(foldId, cvFoldAssignmentFrame);
        } finally {
            // The fold column was appended to tfr, so removing tfr also cleans it up.
            if (tfr != null) {
                tfr.remove();
            }
            if (gbm != null) {
                gbm.delete();
                gbm.deleteCrossValidationModels();
            }
            if (cvFoldAssignmentFrame != null) {
                cvFoldAssignmentFrame.delete();
            }
        }
    }

    /**
     * Trains with an explicit fold column, then deletes the training frame. The vec backing the
     * kept "fold_assignment" column must no longer be present in the DKV afterwards.
     */
    @Test
    public void checkFoldAssignmentsAreBeingRemovedAsSideEffectOfRemovingTrainingFrame() {
        Frame tfr = null;
        Frame cvFoldAssignmentFrame = null;
        GBMModel gbm = null;
        try {
            tfr = parse_test_file("smalldata/iris/iris_wheader.csv");
            addFoldColumn(tfr);
            DKV.put(tfr);

            GBMModel.GBMParameters parms = baseParameters(tfr);
            parms._fold_column = "foldId";
            gbm = new GBM(parms).trainModel().get();

            tfr.delete();
            // Already deleted: clear the reference so the finally block does not touch it again,
            // while a failure before this point still gets tfr cleaned up.
            tfr = null;

            Assert.assertNotNull(gbm._output._cross_validation_fold_assignment_frame_id);
            cvFoldAssignmentFrame = DKV.getGet(gbm._output._cross_validation_fold_assignment_frame_id);
            Assert.assertNull(DKV.get(cvFoldAssignmentFrame.vec("fold_assignment")._key));
        } finally {
            if (tfr != null) {
                tfr.remove();
            }
            if (gbm != null) {
                gbm.delete();
                gbm.deleteCrossValidationModels();
            }
            if (cvFoldAssignmentFrame != null) {
                cvFoldAssignmentFrame.delete();
            }
        }
    }

    /**
     * Trains with {@code _nfolds} and no fold column. The kept fold-assignment frame must have
     * one row per training row, and every row must carry a valid fold id in {@code [0, NFOLDS)}.
     */
    @Test
    public void checkImplicitFoldAssignmentsAreKeptWithoutMakeCopy() {
        Frame tfr = null;
        Frame cvFoldAssignmentFrame = null;
        GBMModel gbm = null;
        try {
            tfr = parse_test_file("smalldata/iris/iris_wheader.csv");

            GBMModel.GBMParameters parms = baseParameters(tfr);
            parms._nfolds = NFOLDS;
            gbm = new GBM(parms).trainModel().get();

            Assert.assertNotNull(gbm._output._cross_validation_fold_assignment_frame_id);
            cvFoldAssignmentFrame = DKV.getGet(gbm._output._cross_validation_fold_assignment_frame_id);
            Assert.assertNotNull(cvFoldAssignmentFrame);
            Assert.assertEquals(tfr.numRows(), cvFoldAssignmentFrame.numRows());
            // CheckFoldTask throws on any out-of-range id; the per-fold counts must cover every row.
            long[] foldCounts = new CheckFoldTask(NFOLDS).doAll(cvFoldAssignmentFrame)._foldCnt;
            Assert.assertEquals(tfr.numRows(), ArrayUtils.sum(foldCounts));
        } finally {
            if (tfr != null) {
                tfr.remove();
            }
            if (gbm != null) {
                gbm.delete();
                gbm.deleteCrossValidationModels();
            }
            if (cvFoldAssignmentFrame != null) {
                cvFoldAssignmentFrame.delete();
            }
        }
    }

    /**
     * Builds the GBM parameters shared by all tests: a tiny multinomial GBM (1 tree, depth 1) on
     * the "class" column. CV predictions are discarded, but the fold assignment is kept.
     * Callers still choose between {@code _fold_column} and {@code _nfolds}.
     *
     * @param train training frame; its key becomes {@code _train}
     * @return freshly allocated parameters
     */
    private static GBMModel.GBMParameters baseParameters(Frame train) {
        GBMModel.GBMParameters parms = new GBMModel.GBMParameters();
        parms._train = train._key;
        parms._response_column = "class";
        parms._ntrees = 1;
        parms._max_depth = 1;
        parms._distribution = DistributionFamily.multinomial;
        parms._keep_cross_validation_predictions = false;
        parms._keep_cross_validation_fold_assignment = true;
        return parms;
    }

    /**
     * Creates a seeded k-fold assignment column named "foldId" with one entry per row of
     * {@code tfr}, and appends it to {@code tfr}.
     *
     * @param tfr frame to extend; must contain a "class" column (used only for its row layout)
     * @return the one-column frame that was appended; callers clean it up by removing {@code tfr}
     */
    private static Frame addFoldColumn(Frame tfr) {
        Vec folds = AstKFold.kfoldColumn(tfr.vec("class").makeZero(), NFOLDS, FOLD_SEED);
        Frame foldId = new Frame(new String[]{"foldId"}, new Vec[]{folds});
        tfr.add(foldId);
        return foldId;
    }

    /**
     * Counts the rows assigned to each fold in a single-column fold-assignment frame. Throws
     * {@link IllegalStateException} for any value that is not an integer in {@code [0, nfolds)}.
     */
    private static class CheckFoldTask extends MRTask<CheckFoldTask> {
        /** Row count per fold; the array index is the fold id. */
        private long[] _foldCnt;

        private CheckFoldTask(int nfolds) {
            _foldCnt = new long[nfolds];
        }

        @Override
        public void map(Chunk c) {
            for (int i = 0; i < c._len; ++i) {
                double val = c.atd(i);
                // Fold ids are zero-based, so a value equal to nfolds is out of range. Reject it
                // here rather than letting it index past the end of _foldCnt. The int round-trip
                // test also rejects NaN and non-integral values.
                if ((int) val != val || val < 0.0 || val >= _foldCnt.length) {
                    throw new IllegalStateException("Unexpected value: " + val);
                }
                _foldCnt[(int) val]++;
            }
        }

        @Override
        public void reduce(CheckFoldTask mrt) {
            _foldCnt = ArrayUtils.add(_foldCnt, mrt._foldCnt);
        }
    }
}

