/*
 * Decompiled with CFR 0.152.
 */
package ai.h2o.targetencoding;

import ai.h2o.targetencoding.BlendingParams;
import ai.h2o.targetencoding.TargetEncoder;
import ai.h2o.targetencoding.TargetEncoderFrameHelper;
import hex.ModelMetricsBinomial;
import hex.ScoreKeeper;
import hex.genmodel.utils.DistributionFamily;
import hex.tree.gbm.GBM;
import hex.tree.gbm.GBMModel;
import java.util.Arrays;
import java.util.Map;
import org.junit.Assert;
import org.junit.BeforeClass;
import org.junit.Ignore;
import org.junit.Test;
import water.Key;
import water.Scope;
import water.TestUtil;
import water.fvec.Frame;
import water.fvec.Vec;
import water.util.IcedHashMapGeneric;
import water.util.Log;

@Ignore(value="Ignoring benchmark tests")
public class TargetEncodingAirlinesBenchmark
extends TestUtil {
    /**
     * Runs once before any test in this class and calls {@code stall_till_cloudsize} with a
     * size of 1 (presumably blocks until a single-node H2O cloud is available - confirm in
     * {@code TestUtil}).
     */
    @BeforeClass
    public static void setup() {
        final int requiredCloudSize = 1;
        TargetEncodingAirlinesBenchmark.stall_till_cloudsize(requiredCloudSize);
    }

    /**
     * Benchmark: target encoding applied with the {@code KFold} data-leakage handling strategy.
     *
     * <p>Builds an encoding map for {@code Origin} and {@code Dest} from the training frame (after
     * {@code TargetEncoderFrameHelper.addKFoldColumn} adds a {@code "fold"} column), applies it to
     * the train, validation and test frames, trains a GBM with AUC-based early stopping on the
     * encoded data, and asserts that its test AUC is higher than the AUC returned by
     * {@code trainDefaultGBM}.
     *
     * <p>Clean-up runs in a {@code finally} block and skips anything that was never created, so
     * an early failure surfaces as itself rather than as a {@code NullPointerException} thrown
     * while cleaning up, and {@code Scope.exit} is always reached.
     */
    @Test
    public void KFoldHoldoutTypeTest() {
        Scope.enter();
        GBMModel gbm = null;
        IcedHashMapGeneric<String, Frame> encodingMap = null;
        try {
            Frame airlinesTrainWithTEH = TargetEncodingAirlinesBenchmark.parse_test_file(
                    Key.make("airlines_train"),
                    "smalldata/airlines/target_encoding/airlines_train_with_teh.csv");
            Frame airlinesValid = TargetEncodingAirlinesBenchmark.parse_test_file(
                    Key.make("airlines_valid"),
                    "smalldata/airlines/target_encoding/airlines_valid.csv");
            // NOTE(review): trainDefaultGBM scores on smalldata/airlines/AirlinesTest.csv.zip,
            // so the two AUCs compared below may come from different test files - confirm intent.
            Frame airlinesTestFrame = TargetEncodingAirlinesBenchmark.parse_test_file(
                    Key.make("airlines_test"),
                    "smalldata/airlines/target_encoding/airlines_test.csv");
            Scope.track(airlinesTrainWithTEH, airlinesValid, airlinesTestFrame);

            long startTimeEncoding = System.currentTimeMillis();

            String foldColumnName = "fold";
            TargetEncoderFrameHelper.addKFoldColumn(airlinesTrainWithTEH, foldColumnName, 5, 1234L);

            BlendingParams params = new BlendingParams(5.0, 1.0);
            String[] teColumns = new String[]{"Origin", "Dest"};
            TargetEncoder tec = new TargetEncoder(teColumns);
            String targetColumnName = "IsDepDelayed";

            boolean withBlendedAvg = true;
            boolean withNoiseOnlyForTraining = true;
            boolean withImputationForNAsInOriginalColumns = true;

            encodingMap = tec.prepareEncodingMap(airlinesTrainWithTEH, targetColumnName, foldColumnName, true);

            long seed = 1234L;
            long seedForGBM = 1234L;

            Frame trainEncoded;
            if (withNoiseOnlyForTraining) {
                // Overload without an explicit noise level; every other call passes 0.0.
                trainEncoded = tec.applyTargetEncoding(airlinesTrainWithTEH, targetColumnName, encodingMap,
                        TargetEncoder.DataLeakageHandlingStrategy.KFold, foldColumnName, withBlendedAvg,
                        withImputationForNAsInOriginalColumns, params, seed);
            } else {
                trainEncoded = tec.applyTargetEncoding(airlinesTrainWithTEH, targetColumnName, encodingMap,
                        TargetEncoder.DataLeakageHandlingStrategy.KFold, foldColumnName, withBlendedAvg, 0.0,
                        withImputationForNAsInOriginalColumns, params, seed);
            }
            Frame validEncoded = tec.applyTargetEncoding(airlinesValid, targetColumnName, encodingMap,
                    TargetEncoder.DataLeakageHandlingStrategy.None, foldColumnName, withBlendedAvg, 0.0,
                    withImputationForNAsInOriginalColumns, params, seed);
            Frame testEncoded = tec.applyTargetEncoding(airlinesTestFrame, targetColumnName, encodingMap,
                    TargetEncoder.DataLeakageHandlingStrategy.None, foldColumnName, withBlendedAvg, 0.0,
                    withImputationForNAsInOriginalColumns, params, seed);
            this.printOutColumnsMetadata(testEncoded);
            testEncoded = tec.ensureTargetColumnIsBinaryCategorical(testEncoded, targetColumnName);
            Scope.track(trainEncoded, validEncoded, testEncoded);

            long finishTimeEncoding = System.currentTimeMillis();
            System.out.println("Calculation of encodings took: " + (finishTimeEncoding - startTimeEncoding));

            long startTime = System.currentTimeMillis();
            GBMModel.GBMParameters parms = new GBMModel.GBMParameters();
            parms._train = trainEncoded._key;
            parms._response_column = targetColumnName;
            parms._score_tree_interval = 10;
            parms._ntrees = 1000;
            parms._max_depth = 5;
            parms._distribution = DistributionFamily.AUTO;
            parms._valid = validEncoded._key;
            parms._stopping_tolerance = 0.001;
            parms._stopping_metric = ScoreKeeper.StoppingMetric.AUC;
            parms._stopping_rounds = 5;
            // Train on the encoded columns only: drop the raw categoricals and the fold column.
            parms._ignored_columns = TargetEncodingAirlinesBenchmark.concat(
                    new String[]{"IsDepDelayed_REC", foldColumnName}, teColumns);
            parms._seed = seedForGBM;

            GBM job = new GBM(parms);
            gbm = job.trainModel().get();
            Assert.assertTrue(job.isStopped());
            long finishTime = System.currentTimeMillis();
            System.out.println("Calculation took: " + (finishTime - startTime));

            Frame preds = gbm.score(testEncoded);
            Scope.track(preds);
            ModelMetricsBinomial mm = ModelMetricsBinomial.make(preds.vec(2), testEncoded.vec(parms._response_column));
            double auc = mm._auc._auc;

            double auc2 = this.trainDefaultGBM(targetColumnName, tec);
            System.out.println("AUC with encoding:" + auc);
            System.out.println("AUC without encoding:" + auc2);
            Assert.assertTrue(
                    "Expected AUC with encoding (" + auc + ") to exceed AUC without encoding (" + auc2 + ")",
                    auc2 < auc);
        } finally {
            try {
                // Either resource may be missing if the test failed before creating it.
                if (encodingMap != null) {
                    this.encodingMapCleanUp(encodingMap);
                }
                if (gbm != null) {
                    gbm.delete();
                    gbm.deleteCrossValidationModels();
                }
            } finally {
                Scope.exit();
            }
        }
    }

    /**
     * Benchmark: target encoding applied with the {@code None} data-leakage handling strategy.
     *
     * <p>The encoding map for {@code Origin} and {@code Dest} is prepared from a separate holdout
     * frame ({@code airlines_te_holdout.csv}) and applied to the train, validation and test frames.
     * A GBM with AUC-based early stopping is trained on the encoded data and its test AUC is
     * asserted to be higher than the AUC returned by {@code trainDefaultGBM}.
     *
     * <p>The encoding frames and the trained model are released in {@code finally}, so they are
     * cleaned up even when an assertion or a training/scoring step fails.
     */
    @Test
    public void noneHoldoutTypeTest() {
        Scope.enter();
        GBMModel gbm = null;
        IcedHashMapGeneric<String, Frame> encodingMap = null;
        try {
            Frame airlinesTrainWithoutTEH = TargetEncodingAirlinesBenchmark.parse_test_file(
                    Key.make("airlines_train"),
                    "smalldata/airlines/target_encoding/airlines_train_without_teh.csv");
            Frame airlinesTEHoldout = TargetEncodingAirlinesBenchmark.parse_test_file(
                    Key.make("airlines_te_holdout"),
                    "smalldata/airlines/target_encoding/airlines_te_holdout.csv");
            Frame airlinesValid = TargetEncodingAirlinesBenchmark.parse_test_file(
                    Key.make("airlines_valid"),
                    "smalldata/airlines/target_encoding/airlines_valid.csv");
            Frame airlinesTestFrame = TargetEncodingAirlinesBenchmark.parse_test_file(
                    Key.make("airlines_test"),
                    "smalldata/airlines/AirlinesTest.csv.zip");
            Scope.track(airlinesTrainWithoutTEH, airlinesTEHoldout, airlinesValid, airlinesTestFrame);

            long startTimeEncoding = System.currentTimeMillis();

            BlendingParams params = new BlendingParams(3.0, 1.0);
            String[] teColumns = new String[]{"Origin", "Dest"};
            TargetEncoder tec = new TargetEncoder(teColumns);
            String targetColumnName = "IsDepDelayed";
            boolean withBlendedAvg = true;
            boolean withImputationForNAsInOriginalColumns = true;

            // No fold column: the map is built from the dedicated holdout frame.
            encodingMap = tec.prepareEncodingMap(airlinesTEHoldout, targetColumnName, null);

            Frame trainEncoded = tec.applyTargetEncoding(airlinesTrainWithoutTEH, targetColumnName, encodingMap,
                    TargetEncoder.DataLeakageHandlingStrategy.None, withBlendedAvg, 0.0,
                    withImputationForNAsInOriginalColumns, params, 1234L);
            Frame validEncoded = tec.applyTargetEncoding(airlinesValid, targetColumnName, encodingMap,
                    TargetEncoder.DataLeakageHandlingStrategy.None, withBlendedAvg, 0.0,
                    withImputationForNAsInOriginalColumns, params, 1234L);
            Frame testEncoded = tec.applyTargetEncoding(airlinesTestFrame, targetColumnName, encodingMap,
                    TargetEncoder.DataLeakageHandlingStrategy.None, withBlendedAvg, 0.0,
                    withImputationForNAsInOriginalColumns, params, 1234L);
            testEncoded = tec.ensureTargetColumnIsBinaryCategorical(testEncoded, targetColumnName);
            Scope.track(trainEncoded, validEncoded, testEncoded);

            long finishTimeEncoding = System.currentTimeMillis();
            System.out.println("Calculation of encodings took: " + (finishTimeEncoding - startTimeEncoding));

            // Only logs a warning when row counts differ; does not fail the test.
            this.checkNumRows(airlinesTrainWithoutTEH, trainEncoded);
            this.checkNumRows(airlinesValid, validEncoded);
            this.checkNumRows(airlinesTestFrame, testEncoded);

            long startTime = System.currentTimeMillis();
            GBMModel.GBMParameters parms = new GBMModel.GBMParameters();
            parms._train = trainEncoded._key;
            parms._response_column = "IsDepDelayed";
            parms._score_tree_interval = 10;
            parms._ntrees = 1000;
            parms._max_depth = 5;
            parms._distribution = DistributionFamily.AUTO;
            parms._valid = validEncoded._key;
            parms._stopping_tolerance = 0.001;
            parms._stopping_metric = ScoreKeeper.StoppingMetric.AUC;
            parms._stopping_rounds = 5;
            // Train on the encoded columns only: drop the raw categoricals.
            parms._ignored_columns = TargetEncodingAirlinesBenchmark.concat(
                    new String[]{"IsDepDelayed_REC"}, teColumns);
            parms._seed = 1234L;

            GBM job = new GBM(parms);
            gbm = job.trainModel().get();
            Assert.assertTrue(job.isStopped());
            long finishTime = System.currentTimeMillis();
            System.out.println("Calculation took: " + (finishTime - startTime));

            Frame preds = gbm.score(testEncoded);
            Scope.track(preds);
            ModelMetricsBinomial mm = ModelMetricsBinomial.make(preds.vec(2), testEncoded.vec(parms._response_column));
            double auc = mm._auc._auc;

            double auc2 = this.trainDefaultGBM(targetColumnName, tec);
            System.out.println("AUC with encoding:" + auc);
            System.out.println("AUC without encoding:" + auc2);
            Assert.assertTrue(
                    "Expected AUC with encoding (" + auc + ") to exceed AUC without encoding (" + auc2 + ")",
                    auc2 < auc);
        } finally {
            try {
                // Either resource may be missing if the test failed before creating it.
                if (encodingMap != null) {
                    this.encodingMapCleanUp(encodingMap);
                }
                if (gbm != null) {
                    gbm.delete();
                    gbm.deleteCrossValidationModels();
                }
            } finally {
                Scope.exit();
            }
        }
    }

    /**
     * Trains a baseline GBM on the un-encoded airlines frames and returns its AUC on the test frame.
     *
     * <p>The frames are parsed under their own keys ({@code airlines_*_d}) inside a nested
     * {@code Scope}; the model and the scope are released before the method returns.
     *
     * @param targetColumnName name of the response column; also passed to
     *                         {@code ensureTargetColumnIsBinaryCategorical} for each frame
     * @param tec              encoder used only for {@code ensureTargetColumnIsBinaryCategorical}
     * @return the AUC computed from the model's predictions on the test frame
     */
    private double trainDefaultGBM(String targetColumnName, TargetEncoder tec) {
        Scope.enter();
        GBMModel baselineModel = null;
        try {
            Frame trainFrame = TargetEncodingAirlinesBenchmark.parse_test_file(
                    Key.make("airlines_train_d"),
                    "smalldata/airlines/target_encoding/airlines_train_with_teh.csv");
            Frame validFrame = TargetEncodingAirlinesBenchmark.parse_test_file(
                    Key.make("airlines_valid_d"),
                    "smalldata/airlines/target_encoding/airlines_valid.csv");
            Frame testFrame = TargetEncodingAirlinesBenchmark.parse_test_file(
                    Key.make("airlines_test_d"),
                    "smalldata/airlines/AirlinesTest.csv.zip");
            Scope.track(trainFrame, validFrame, testFrame);

            trainFrame = tec.ensureTargetColumnIsBinaryCategorical(trainFrame, targetColumnName);
            validFrame = tec.ensureTargetColumnIsBinaryCategorical(validFrame, targetColumnName);
            testFrame = tec.ensureTargetColumnIsBinaryCategorical(testFrame, targetColumnName);

            GBMModel.GBMParameters baselineParms = new GBMModel.GBMParameters();
            // Data and response.
            baselineParms._train = trainFrame._key;
            baselineParms._valid = validFrame._key;
            baselineParms._response_column = targetColumnName;
            baselineParms._ignored_columns = new String[]{"IsDepDelayed_REC"};
            // Model shape.
            baselineParms._ntrees = 1000;
            baselineParms._max_depth = 5;
            baselineParms._distribution = DistributionFamily.AUTO;
            baselineParms._seed = 1234L;
            // Scoring and early stopping.
            baselineParms._score_tree_interval = 10;
            baselineParms._stopping_metric = ScoreKeeper.StoppingMetric.AUC;
            baselineParms._stopping_rounds = 5;
            baselineParms._stopping_tolerance = 0.001;

            GBM baselineJob = new GBM(baselineParms);
            baselineModel = (GBMModel) baselineJob.trainModel().get();
            Assert.assertTrue(baselineJob.isStopped());

            Frame predictions = baselineModel.score(testFrame);
            Scope.track(predictions);
            ModelMetricsBinomial metrics =
                    ModelMetricsBinomial.make(predictions.vec(2), testFrame.vec(baselineParms._response_column));
            return metrics._auc._auc;
        } finally {
            if (baselineModel != null) {
                baselineModel.delete();
                baselineModel.deleteCrossValidationModels();
            }
            Scope.exit();
        }
    }

    /**
     * Logs a warning when the two frames do not have the same number of rows.
     *
     * <p>Nothing is asserted; a mismatch in either direction only produces a log entry.
     *
     * @param before frame as it was prior to the manipulation
     * @param after  frame produced by the manipulation
     */
    public void checkNumRows(Frame before, Frame after) {
        long rowDifference = before.numRows() - after.numRows();
        if (rowDifference == 0L) {
            return;
        }
        String warning = String.format("Number of rows has dropped by %d after manipulations with frame ( %s , %s ).", rowDifference, before._key, after._key);
        Log.warn(warning);
    }

    /**
     * Deletes every frame held as a value in the given encoding map.
     *
     * <p>Safe to call from failure paths: a {@code null} map (for example when the test failed
     * before the map was prepared) is a no-op, and {@code null} values are skipped, so clean-up
     * cannot throw a {@code NullPointerException} that would hide the original failure.
     *
     * @param encodingMap map whose frame values are deleted; may be {@code null}
     */
    private void encodingMapCleanUp(Map<String, Frame> encodingMap) {
        if (encodingMap == null) {
            return;
        }
        for (Map.Entry<String, Frame> entry : encodingMap.entrySet()) {
            Frame encodingFrame = entry.getValue();
            if (encodingFrame != null) {
                encodingFrame.delete();
            }
        }
    }

    /**
     * Returns a new array holding the elements of {@code first} followed by those of {@code second}.
     *
     * <p>The result is created with {@code Arrays.copyOf(first, ...)}, so it has the same runtime
     * array type as {@code first}. Neither input array is modified.
     *
     * @param first  leading elements; must not be {@code null}
     * @param second trailing elements; must not be {@code null}
     * @param <T>    element type
     * @return a freshly allocated array of length {@code first.length + second.length}
     */
    public static <T> T[] concat(T[] first, T[] second) {
        final int headLength = first.length;
        final int tailLength = second.length;
        // Grow a copy of the head, then drop the tail into the freshly added slots.
        final T[] joined = Arrays.copyOf(first, headLength + tailLength);
        System.arraycopy(second, 0, joined, headLength, tailLength);
        return joined;
    }
}

