/*
 * Decompiled with CFR 0.152.
 */
package hex.tree.gbm;

import hex.tree.CompressedTree;
import hex.tree.SharedTreeModel;
import hex.tree.gbm.GBM;
import hex.tree.gbm.GBMModel;
import org.junit.Assert;
import org.junit.BeforeClass;
import org.junit.Ignore;
import org.junit.Test;
import water.DKV;
import water.Iced;
import water.Key;
import water.Keyed;
import water.ModelSerializationTest;
import water.TestUtil;
import water.exceptions.H2OIllegalArgumentException;
import water.fvec.Frame;
import water.fvec.Vec;
import water.util.VecUtils;

/**
 * Checks that a GBM model continued from a checkpoint ends up with the same trees as a GBM
 * model trained from scratch with the same total number of trees and the same settings.
 */
public class GBMCheckpointTest
extends TestUtil {

    /**
     * Sentinel sample rate meaning "do not set {@code _sample_rate}", so the GBM parameter
     * default is used.
     */
    private static final float DEFAULT_SAMPLE_RATE = Float.NaN;

    /** Seed shared by every model built in this class so the builds are comparable. */
    private static final long SEED = 42L;

    @BeforeClass
    public static void stall() {
        stall_till_cloudsize(1);
    }

    @Test
    public void testCheckpointReconstruction4Multinomial() {
        testCheckPointReconstruction("smalldata/iris/iris.csv", 4, true, 5, 3);
    }

    @Test
    public void testCheckpointReconstruction4Multinomial2() {
        testCheckPointReconstruction("smalldata/junit/cars_20mpg.csv", 1, true, 5, 3);
    }

    @Test
    public void testCheckpointReconstruction4Binomial() {
        testCheckPointReconstruction("smalldata/logreg/prostate.csv", 1, true, 5, 3);
    }

    @Test
    public void testCheckpointReconstruction4Binomial2() {
        testCheckPointReconstruction("smalldata/junit/cars_20mpg.csv", 7, true, 2, 2);
    }

    /**
     * Continues a checkpoint with a different sample rate than the prior model was trained with
     * and expects an {@link H2OIllegalArgumentException}.
     */
    // NOTE(review): ignored without a recorded reason; confirm whether changing _sample_rate on a
    // GBM checkpoint is actually rejected before re-enabling.
    @Test(expected = H2OIllegalArgumentException.class)
    @Ignore
    public void testCheckpointWrongParams() {
        testCheckPointReconstruction("smalldata/iris/iris.csv", 4, true, 5, 3, 0.2f, 0.67f);
    }

    @Test
    public void testCheckpointReconstruction4Regression() {
        testCheckPointReconstruction("smalldata/logreg/prostate.csv", 8, false, 5, 3);
    }

    @Test
    public void testCheckpointReconstruction4Regression2() {
        testCheckPointReconstruction("smalldata/junit/cars_20mpg.csv", 1, false, 5, 3);
    }

    /**
     * Runs the checkpoint-reconstruction check leaving {@code _sample_rate} at its default for
     * every model.
     */
    private void testCheckPointReconstruction(String dataset, int responseIdx, boolean classification, int ntreesInPriorModel, int ntreesInNewModel) {
        testCheckPointReconstruction(dataset, responseIdx, classification, ntreesInPriorModel, ntreesInNewModel, DEFAULT_SAMPLE_RATE, DEFAULT_SAMPLE_RATE);
    }

    /**
     * Trains three GBM models on {@code dataset}:
     * <ol>
     *   <li>an initial model with {@code ntreesInPriorModel} trees,</li>
     *   <li>a model continued from the initial model's checkpoint up to
     *       {@code ntreesInPriorModel + ntreesInNewModel} trees,</li>
     *   <li>a model trained from scratch with that same total number of trees.</li>
     * </ol>
     * Asserts that models 2 and 3 have equal trees while every pair of corresponding trees is
     * stored under distinct keys. The parsed frame and all models are deleted afterwards, also
     * when an assertion or setup step fails.
     *
     * @param dataset                test file to parse
     * @param responseIdx            index of the response column, counted after the optional
     *                               {@code "economy"} column has been dropped
     * @param classification         if true, the response column is converted to categorical
     * @param ntreesInPriorModel     number of trees in the initial model
     * @param ntreesInNewModel       number of trees added on top of the checkpoint
     * @param sampleRateInPriorModel {@code _sample_rate} of the initial model, or
     *                               {@link #DEFAULT_SAMPLE_RATE} to keep the default
     * @param sampleRateInNewModel   {@code _sample_rate} of the checkpointed and from-scratch
     *                               models, or {@link #DEFAULT_SAMPLE_RATE} to keep the default
     */
    private void testCheckPointReconstruction(String dataset, int responseIdx, boolean classification, int ntreesInPriorModel, int ntreesInNewModel, float sampleRateInPriorModel, float sampleRateInNewModel) {
        Frame f = null;
        GBMModel model = null;
        GBMModel modelFromCheckpoint = null;
        GBMModel modelFinal = null;
        try {
            // Parse inside the try so the frame is cleaned up even if setup fails.
            f = parse_test_file(dataset);
            // Only some datasets (e.g. cars_20mpg.csv) have an "economy" column; drop it if present.
            removeColumnIfPresent(f, "economy");
            DKV.put(f);
            if (classification) {
                Vec respVec = f.vec(responseIdx);
                // replace() hands back the numeric vec that was swapped out; delete it.
                f.replace(responseIdx, VecUtils.toCategoricalVec(respVec)).remove();
                DKV.put(f._key, f);
            }

            String response = f.name(responseIdx);
            int totalTrees = ntreesInPriorModel + ntreesInNewModel;

            GBMModel.GBMParameters gbmParams = checkpointTestParams(f, response, ntreesInPriorModel, sampleRateInPriorModel);
            model = trainGbm(gbmParams, "Initial model");

            GBMModel.GBMParameters gbmFromCheckpointParams = checkpointTestParams(f, response, totalTrees, sampleRateInNewModel);
            gbmFromCheckpointParams._checkpoint = model._key;
            modelFromCheckpoint = trainGbm(gbmFromCheckpointParams, "Model from checkpoint");

            GBMModel.GBMParameters gbmFinalParams = checkpointTestParams(f, response, totalTrees, sampleRateInNewModel);
            modelFinal = trainGbm(gbmFinalParams, "Validation model");

            CompressedTree[][] treesFromCheckpoint = ModelSerializationTest.getTrees(modelFromCheckpoint);
            CompressedTree[][] treesFromFinalModel = ModelSerializationTest.getTrees(modelFinal);
            ModelSerializationTest.assertTreeEquals("The model created from checkpoint and corresponding model created from scratch should have the same trees!", treesFromCheckpoint, treesFromFinalModel, true);
            // Same tree content, but each model must reference trees stored under its own keys.
            for (int tree = 0; tree < treesFromCheckpoint.length; ++tree) {
                for (int clazz = 0; clazz < treesFromCheckpoint[tree].length; ++clazz) {
                    CompressedTree a = treesFromCheckpoint[tree][clazz];
                    if (a == null) {
                        continue;
                    }
                    CompressedTree b = treesFromFinalModel[tree][clazz];
                    Assert.assertNotEquals(a._key, b._key);
                }
            }
        }
        finally {
            if (f != null) {
                f.delete();
            }
            if (model != null) {
                model.delete();
            }
            if (modelFromCheckpoint != null) {
                modelFromCheckpoint.delete();
            }
            if (modelFinal != null) {
                modelFinal.delete();
            }
        }
    }

    /**
     * Builds the parameters shared by all models of the checkpoint-reconstruction check.
     * {@code _sample_rate} is only set when {@code sampleRate} is not {@link #DEFAULT_SAMPLE_RATE}.
     */
    private static GBMModel.GBMParameters checkpointTestParams(Frame train, String responseColumn, int ntrees, float sampleRate) {
        GBMModel.GBMParameters params = new GBMModel.GBMParameters();
        params._train = train._key;
        params._response_column = responseColumn;
        params._ntrees = ntrees;
        params._seed = SEED;
        params._max_depth = 5;
        params._learn_rate_annealing = 0.9;
        params._score_each_iteration = true;
        if (!Float.isNaN(sampleRate)) {
            params._sample_rate = sampleRate;
        }
        return params;
    }

    /** Trains a GBM with {@code params} under a model key named {@code modelName} and returns the model. */
    private static GBMModel trainGbm(GBMModel.GBMParameters params, String modelName) {
        return (GBMModel) new GBM(params, Key.make(modelName)).trainModel().get();
    }

    /**
     * Removes column {@code name} from {@code fr} and deletes its vec. When {@code Frame.remove}
     * returns null (column absent) nothing is deleted.
     */
    private static void removeColumnIfPresent(Frame fr, String name) {
        Vec v = fr.remove(name);
        if (v != null) {
            v.remove();
        }
    }

    /** Moves column {@code name} to the last position of {@code fr}, keeping its name. */
    private static void moveColumnToEnd(Frame fr, String name) {
        fr.add(name, fr.remove(name));
    }

    /**
     * Checkpoint reconstruction on the PUBDEV-1829 train/validation files, with
     * {@code economy_20mpg} as the response column. Currently ignored.
     */
    @Test
    @Ignore(value="PUBDEV-1829")
    public void testCheckpointReconstruction4BinomialPUBDEV1829() {
        Frame tr = null;
        Frame val = null;
        GBMModel model = null;
        GBMModel modelFromCheckpoint = null;
        GBMModel modelFinal = null;
        try {
            tr = parse_test_file("smalldata/jira/gbm_checkpoint_train.csv");
            val = parse_test_file("smalldata/jira/gbm_checkpoint_valid.csv");
            for (Frame fr : new Frame[]{tr, val}) {
                removeColumnIfPresent(fr, "name");
                removeColumnIfPresent(fr, "economy");
                // Put the response last; the vec stays owned by the frame, so no separate cleanup.
                moveColumnToEnd(fr, "economy_20mpg");
                DKV.put(fr);
            }

            model = trainGbm(pubdev1829Params(tr, val, 5), "Initial model");

            GBMModel.GBMParameters gbmFromCheckpointParams = pubdev1829Params(tr, val, 10);
            gbmFromCheckpointParams._checkpoint = model._key;
            modelFromCheckpoint = trainGbm(gbmFromCheckpointParams, "Model from checkpoint");

            modelFinal = trainGbm(pubdev1829Params(tr, val, 10), "Validation model");

            CompressedTree[][] treesFromCheckpoint = ModelSerializationTest.getTrees(modelFromCheckpoint);
            CompressedTree[][] treesFromFinalModel = ModelSerializationTest.getTrees(modelFinal);
            ModelSerializationTest.assertTreeEquals("The model created from checkpoint and corresponding model created from scratch should have the same trees!", treesFromCheckpoint, treesFromFinalModel, true);
        }
        finally {
            if (tr != null) {
                tr.delete();
            }
            if (val != null) {
                val.delete();
            }
            if (model != null) {
                model.delete();
            }
            if (modelFromCheckpoint != null) {
                modelFromCheckpoint.delete();
            }
            if (modelFinal != null) {
                modelFinal.delete();
            }
        }
    }

    /** Builds the parameters shared by all models of the PUBDEV-1829 check. */
    private static GBMModel.GBMParameters pubdev1829Params(Frame train, Frame valid, int ntrees) {
        GBMModel.GBMParameters params = new GBMModel.GBMParameters();
        params._train = train._key;
        params._valid = valid._key;
        params._response_column = "economy_20mpg";
        params._ntrees = ntrees;
        params._max_depth = 5;
        params._min_rows = 10.0;
        params._score_each_iteration = true;
        params._seed = SEED;
        return params;
    }
}

