/*
 * Decompiled with CFR 0.152.
 */
package hex.tree.drf;

import hex.tree.CompressedTree;
import hex.tree.SharedTreeModel;
import hex.tree.drf.DRF;
import hex.tree.drf.DRFModel;
import org.junit.Assert;
import org.junit.BeforeClass;
import org.junit.Test;
import water.DKV;
import water.Iced;
import water.Key;
import water.Keyed;
import water.ModelSerializationTest;
import water.TestUtil;
import water.exceptions.H2OIllegalArgumentException;
import water.fvec.Frame;
import water.fvec.Vec;

/**
 * Verifies that a DRF model continued from a checkpoint is equivalent to a model trained
 * from scratch with the same total number of trees.
 *
 * <p>Each scenario works in three steps:
 * <ol>
 *   <li>Train an initial model with {@code ntreesInPriorModel} trees.</li>
 *   <li>Continue that model from its checkpoint up to
 *       {@code ntreesInPriorModel + ntreesInNewModel} trees.</li>
 *   <li>Train a validation model from scratch with the same total tree count and the prior
 *       model's sample rate.</li>
 * </ol>
 * The trees of the checkpointed model and the validation model must be equal. Their
 * {@link CompressedTree} keys must still differ, since they belong to independent models.
 */
public class DRFCheckpointTest
extends TestUtil {

    /** Fixed seed shared by all models so that checkpointed and fresh training are comparable. */
    private static final long SEED = 42L;

    /** Maximum tree depth shared by all models in a scenario. */
    private static final int MAX_DEPTH = 10;

    /** Sample rate used when a scenario does not specify one explicitly. */
    private static final float DEFAULT_SAMPLE_RATE = 0.632f;

    @BeforeClass
    public static void stall() {
        stall_till_cloudsize(1);
    }

    @Test
    public void testCheckpointReconstruction4Multinomial() {
        testCheckPointReconstruction("smalldata/iris/iris.csv", 4, true, 5, 3);
    }

    @Test
    public void testCheckpointReconstruction4Multinomial2() {
        testCheckPointReconstruction("smalldata/junit/cars_20mpg.csv", 1, true, 5, 3);
    }

    @Test
    public void testCheckpointReconstruction4Binomial() {
        testCheckPointReconstruction("smalldata/logreg/prostate.csv", 1, true, 5, 3);
    }

    @Test
    public void testCheckpointReconstruction4Binomial2() {
        testCheckPointReconstruction("smalldata/junit/cars_20mpg.csv", 7, true, 1, 1);
    }

    /**
     * Continuing from a checkpoint with a sample rate different from the prior model's
     * is expected to be rejected with {@link H2OIllegalArgumentException}.
     */
    @Test(expected=H2OIllegalArgumentException.class)
    public void testCheckpointWrongParams() {
        testCheckPointReconstruction("smalldata/iris/iris.csv", 4, true, 5, 3, 0.2f, 0.67f);
    }

    @Test
    public void testCheckpointReconstruction4Regression() {
        testCheckPointReconstruction("smalldata/logreg/prostate.csv", 8, false, 4, 3);
    }

    @Test
    public void testCheckpointReconstruction4Regression2() {
        testCheckPointReconstruction("smalldata/junit/cars_20mpg.csv", 1, false, 4, 3);
    }

    /**
     * Runs a checkpoint-reconstruction scenario using the default sample rate for every model.
     */
    private void testCheckPointReconstruction(String dataset, int responseIdx, boolean classification, int ntreesInPriorModel, int ntreesInNewModel) {
        testCheckPointReconstruction(dataset, responseIdx, classification, ntreesInPriorModel, ntreesInNewModel, DEFAULT_SAMPLE_RATE, DEFAULT_SAMPLE_RATE);
    }

    /**
     * Runs a full checkpoint-reconstruction scenario and cleans up every frame and model it creates.
     *
     * @param dataset               path of the test file to parse
     * @param responseIdx           index of the response column, counted after the
     *                              {@code economy} column (if any) has been dropped
     * @param classification        when {@code true}, the response column is converted to categorical
     * @param ntreesInPriorModel    number of trees in the initial model
     * @param ntreesInNewModel      number of additional trees built when continuing from the checkpoint
     * @param sampleRateInPriorModel sample rate of the initial model and of the from-scratch validation model
     * @param sampleRateInNewModel  sample rate of the model continued from the checkpoint
     */
    private void testCheckPointReconstruction(String dataset, int responseIdx, boolean classification, int ntreesInPriorModel, int ntreesInNewModel, float sampleRateInPriorModel, float sampleRateInNewModel) {
        Frame f = null;
        DRFModel model = null;
        DRFModel modelFromCheckpoint = null;
        DRFModel modelFinal = null;
        try {
            // Frame preparation is inside the try so the frame is cleaned up even if preparation fails.
            f = parse_test_file(dataset);
            f = prepareFrame(f, responseIdx, classification);

            int totalTrees = ntreesInPriorModel + ntreesInNewModel;

            DRFModel.DRFParameters drfParams = makeParams(f, responseIdx, ntreesInPriorModel, sampleRateInPriorModel);
            model = (DRFModel) new DRF(drfParams, Key.make("Initial model")).trainModel().get();

            DRFModel.DRFParameters drfFromCheckpointParams = makeParams(f, responseIdx, totalTrees, sampleRateInNewModel);
            drfFromCheckpointParams._checkpoint = model._key;
            modelFromCheckpoint = (DRFModel) new DRF(drfFromCheckpointParams, Key.make("Model from checkpoint")).trainModel().get();

            // The reference model is trained from scratch with the prior model's sample rate.
            DRFModel.DRFParameters drfFinalParams = makeParams(f, responseIdx, totalTrees, sampleRateInPriorModel);
            modelFinal = (DRFModel) new DRF(drfFinalParams, Key.make("Validation model")).trainModel().get();

            CompressedTree[][] treesFromCheckpoint = ModelSerializationTest.getTrees(modelFromCheckpoint);
            CompressedTree[][] treesFromFinalModel = ModelSerializationTest.getTrees(modelFinal);
            ModelSerializationTest.assertTreeEquals("The model created from checkpoint and corresponding model created from scratch should have the same trees!", treesFromCheckpoint, treesFromFinalModel, true);
            assertDistinctTreeKeys(treesFromCheckpoint, treesFromFinalModel);
        }
        finally {
            if (f != null) {
                f.delete();
            }
            if (model != null) {
                model.delete();
            }
            if (modelFromCheckpoint != null) {
                modelFromCheckpoint.delete();
            }
            if (modelFinal != null) {
                modelFinal.delete();
            }
        }
    }

    /**
     * Drops the {@code economy} column if present, then optionally converts the response column
     * to categorical. Stores the resulting frame in the DKV.
     *
     * @return the same frame instance, modified in place
     */
    private static Frame prepareFrame(Frame f, int responseIdx, boolean classification) {
        Vec economy = f.remove("economy");
        if (economy != null) {
            economy.remove();
        }
        DKV.put(f);
        if (classification) {
            Vec respVec = f.vec(responseIdx);
            // replace() returns the old numeric vec, which is no longer referenced and is removed.
            f.replace(responseIdx, respVec.toCategoricalVec()).remove();
            DKV.put(f);
        }
        return f;
    }

    /**
     * Builds DRF parameters shared by every model in a scenario. The seed, max depth and
     * per-iteration scoring are fixed, so models differ only in tree count, sample rate and checkpoint.
     */
    private static DRFModel.DRFParameters makeParams(Frame f, int responseIdx, int ntrees, float sampleRate) {
        DRFModel.DRFParameters params = new DRFModel.DRFParameters();
        params._train = f._key;
        params._response_column = f.name(responseIdx);
        params._ntrees = ntrees;
        params._seed = SEED;
        params._max_depth = MAX_DEPTH;
        params._score_each_iteration = true;
        params._sample_rate = sampleRate;
        return params;
    }

    /**
     * Asserts that corresponding non-null trees of two independently trained models are
     * stored under different keys, i.e. that the checkpointed model does not share tree
     * objects with the reference model.
     */
    private static void assertDistinctTreeKeys(CompressedTree[][] treesFromCheckpoint, CompressedTree[][] treesFromFinalModel) {
        for (int tree = 0; tree < treesFromCheckpoint.length; ++tree) {
            for (int clazz = 0; clazz < treesFromCheckpoint[tree].length; ++clazz) {
                CompressedTree a = treesFromCheckpoint[tree][clazz];
                if (a == null) continue;
                CompressedTree b = treesFromFinalModel[tree][clazz];
                Assert.assertNotEquals("Trees at [" + tree + "][" + clazz + "] of independent models must have different keys", a._key, b._key);
            }
        }
    }
}

