/*
 * Decompiled with CFR 0.152.
 */
package hex.deeplearning;

import hex.Model;
import hex.ModelMetrics;
import hex.ScoringInfo;
import hex.deeplearning.DeepLearning;
import hex.deeplearning.DeepLearningMLPReference;
import hex.deeplearning.DeepLearningModel;
import hex.deeplearning.DeepLearningModelInfo;
import hex.deeplearning.DeepLearningScoringInfo;
import hex.deeplearning.DeepLearningTask;
import hex.deeplearning.Neurons;
import hex.genmodel.GenModel;
import java.util.Random;
import org.junit.Assert;
import org.junit.BeforeClass;
import org.junit.Ignore;
import org.junit.Test;
import water.DKV;
import water.Key;
import water.Keyed;
import water.TestUtil;
import water.fvec.Frame;
import water.fvec.Vec;
import water.util.ArrayUtils;
import water.util.Log;
import water.util.MathUtils;
import water.util.RandomUtils;

public class DeepLearningIrisTest
extends TestUtil {
    static final String PATH = "smalldata/iris/iris.csv";
    Frame _train;
    Frame _test;

    @BeforeClass
    public static void setup() {
        DeepLearningIrisTest.stall_till_cloudsize((int)1);
    }

    @Test
    public void run() throws Exception {
        this.runFraction(0.05f);
    }

    private void compareVal(double a, double b, double abseps, double releps) {
        if (!MathUtils.compare((double)a, (double)b, (double)abseps, (double)releps)) {
            Assert.assertEquals((String)"Not equal: ", (double)a, (double)b, (double)0.0);
        }
    }

    /*
     * WARNING - Removed try catching itself - possible behaviour change.
     */
    void runFraction(float fraction) {
        long seed0 = 912559L;
        int num_runs = 0;
        Frame frame = null;
        try {
            frame = DeepLearningIrisTest.parse_test_file((Key)Key.make((String)"iris.hex"), (String)PATH);
            for (int repeat = 0; repeat < 5; ++repeat) {
                DeepLearningModel.DeepLearningParameters.Activation[] activations = new DeepLearningModel.DeepLearningParameters.Activation[]{DeepLearningModel.DeepLearningParameters.Activation.Tanh, DeepLearningModel.DeepLearningParameters.Activation.Rectifier};
                DeepLearningModel.DeepLearningParameters.Loss[] losses = new DeepLearningModel.DeepLearningParameters.Loss[]{DeepLearningModel.DeepLearningParameters.Loss.Quadratic, DeepLearningModel.DeepLearningParameters.Loss.CrossEntropy};
                DeepLearningModel.DeepLearningParameters.InitialWeightDistribution[] dists = new DeepLearningModel.DeepLearningParameters.InitialWeightDistribution[]{DeepLearningModel.DeepLearningParameters.InitialWeightDistribution.Normal, DeepLearningModel.DeepLearningParameters.InitialWeightDistribution.Uniform, DeepLearningModel.DeepLearningParameters.InitialWeightDistribution.UniformAdaptive};
                long seed = seed0 + (long)repeat;
                Random rng = new Random(seed);
                double[] initial_weight_scales = new double[]{1.0E-4 + rng.nextDouble()};
                double[] holdout_ratios = new double[]{0.1 + rng.nextDouble() * 0.8};
                double[] momenta = new double[]{rng.nextDouble() * 0.99};
                int[] hiddens = new int[]{1, 2 + rng.nextInt(50)};
                int[] epochs = new int[]{1, 2 + rng.nextInt(50)};
                double[] rates = new double[]{0.01, 1.0E-5 + rng.nextDouble() * 0.1};
                for (DeepLearningModel.DeepLearningParameters.Activation activation : activations) {
                    for (DeepLearningModel.DeepLearningParameters.Loss loss : losses) {
                        for (DeepLearningModel.DeepLearningParameters.InitialWeightDistribution dist : dists) {
                            for (double scale : initial_weight_scales) {
                                for (double holdout_ratio : holdout_ratios) {
                                    for (double momentum : momenta) {
                                        for (int hidden : hiddens) {
                                            for (int epoch : epochs) {
                                                for (double rate : rates) {
                                                    DeepLearningModel mymodel = null;
                                                    Frame trainPredict = null;
                                                    Frame testPredict = null;
                                                    try {
                                                        double bb;
                                                        double ba;
                                                        double b;
                                                        double a;
                                                        int o;
                                                        int o2;
                                                        ++num_runs;
                                                        if (fraction < rng.nextFloat()) continue;
                                                        Log.info((Object[])new Object[]{""});
                                                        Log.info((Object[])new Object[]{"STARTING."});
                                                        Log.info((Object[])new Object[]{"Running with " + activation.name() + " activation function and " + loss.name() + " loss function."});
                                                        Log.info((Object[])new Object[]{"Initialization with " + dist.name() + " distribution and " + scale + " scale, holdout ratio " + holdout_ratio});
                                                        Log.info((Object[])new Object[]{"Using " + hidden + " hidden layer neurons and momentum: " + momentum});
                                                        Log.info((Object[])new Object[]{"Using seed " + seed});
                                                        int trial = 0;
                                                        do {
                                                            Log.info((Object[])new Object[]{"Trial #" + ++trial});
                                                            if (this._train != null) {
                                                                this._train.delete();
                                                            }
                                                            if (this._test != null) {
                                                                this._test.delete();
                                                            }
                                                            Random rand = RandomUtils.getRNG((long[])new long[]{seed});
                                                            double[][] rows = new double[(int)frame.numRows()][frame.numCols()];
                                                            String[] names = new String[frame.numCols()];
                                                            for (int c = 0; c < frame.numCols(); ++c) {
                                                                names[c] = "ColumnName" + c;
                                                                int r = 0;
                                                                while ((long)r < frame.numRows()) {
                                                                    rows[r][c] = frame.vecs()[c].at((long)r);
                                                                    ++r;
                                                                }
                                                            }
                                                            for (int i = rows.length - 1; i >= 0; --i) {
                                                                int shuffle = rand.nextInt(i + 1);
                                                                double[] row = rows[shuffle];
                                                                rows[shuffle] = rows[i];
                                                                rows[i] = row;
                                                            }
                                                            int limit = (int)((double)frame.numRows() * holdout_ratio);
                                                            this._train = ArrayUtils.frame((String[])names, (double[][])((double[][])ArrayUtils.subarray((Object[])rows, (int)0, (int)limit)));
                                                            this._test = ArrayUtils.frame((String[])names, (double[][])((double[][])ArrayUtils.subarray((Object[])rows, (int)limit, (int)((int)frame.numRows() - limit))));
                                                            String respname = this._train.lastVecName();
                                                            Vec resp = this._train.lastVec().toCategoricalVec();
                                                            this._train.remove(respname).remove();
                                                            this._train.add(respname, resp);
                                                            DKV.put((Keyed)this._train);
                                                            Vec vresp = this._test.lastVec().toCategoricalVec();
                                                            this._test.remove(respname).remove();
                                                            this._test.add(respname, vresp);
                                                            DKV.put((Keyed)this._test);
                                                        } while (this._train.lastVec().cardinality() < 3);
                                                        DeepLearningMLPReference ref = new DeepLearningMLPReference();
                                                        ref.init(activation, RandomUtils.getRNG((long[])new long[]{seed}), holdout_ratio, hidden);
                                                        DeepLearningModel.DeepLearningParameters p = new DeepLearningModel.DeepLearningParameters();
                                                        p._train = this._train._key;
                                                        p._response_column = this._train.lastVecName();
                                                        assert (this._train.lastVec().isCategorical());
                                                        p._ignored_columns = null;
                                                        p._seed = seed;
                                                        p._hidden = new int[]{hidden};
                                                        p._adaptive_rate = false;
                                                        p._rho = 0.0;
                                                        p._epsilon = 0.0;
                                                        p._rate = rate / (1.0 - momentum);
                                                        p._activation = activation;
                                                        p._max_w2 = Float.POSITIVE_INFINITY;
                                                        p._input_dropout_ratio = 0.0;
                                                        p._rate_annealing = 0.0;
                                                        p._l1 = 0.0;
                                                        p._loss = loss;
                                                        p._l2 = 0.0;
                                                        p._momentum_start = p._momentum_stable = momentum;
                                                        p._momentum_ramp = 0.0;
                                                        p._initial_weight_distribution = dist;
                                                        p._initial_weight_scale = scale;
                                                        p._valid = null;
                                                        p._quiet_mode = true;
                                                        p._fast_mode = false;
                                                        p._nesterov_accelerated_gradient = false;
                                                        p._train_samples_per_iteration = 0L;
                                                        p._ignore_const_cols = false;
                                                        p._shuffle_training_data = false;
                                                        p._classification_stop = -1.0;
                                                        p._force_load_balance = false;
                                                        p._overwrite_with_best_model = false;
                                                        p._replicate_training_data = false;
                                                        p._mini_batch_size = 1;
                                                        p._single_node_mode = true;
                                                        p._epochs = 0.0;
                                                        p._elastic_averaging = false;
                                                        mymodel = (DeepLearningModel)new DeepLearning(p).trainModel().get();
                                                        p._epochs = epoch;
                                                        Neurons[] neurons = DeepLearningTask.makeNeuronsForTraining((DeepLearningModelInfo)mymodel.model_info());
                                                        Neurons l = neurons[1];
                                                        for (o2 = 0; o2 < l._a[0].size(); ++o2) {
                                                            for (int i = 0; i < l._previous._a[0].size(); ++i) {
                                                                ref._nn.ihWeights[i][o2] = l._w.get(o2, i);
                                                            }
                                                            ref._nn.hBiases[o2] = l._b.get(o2);
                                                        }
                                                        l = neurons[2];
                                                        for (o2 = 0; o2 < l._a[0].size(); ++o2) {
                                                            for (int i = 0; i < l._previous._a[0].size(); ++i) {
                                                                ref._nn.hoWeights[i][o2] = l._w.get(o2, i);
                                                            }
                                                            ref._nn.oBiases[o2] = l._b.get(o2);
                                                        }
                                                        ref.train((int)p._epochs, rate, p._momentum_stable, loss, seed);
                                                        mymodel.delete();
                                                        DeepLearning dl = new DeepLearning(p);
                                                        mymodel = (DeepLearningModel)dl.trainModel().get();
                                                        Assert.assertTrue((mymodel.model_info().get_processed_total() == (long)epoch * dl.train().numRows() ? 1 : 0) != 0);
                                                        double abseps = 1.0E-6;
                                                        double releps = 1.0E-6;
                                                        neurons = DeepLearningTask.makeNeuronsForTesting((DeepLearningModelInfo)mymodel.model_info());
                                                        l = neurons[1];
                                                        for (o = 0; o < l._a[0].size(); ++o) {
                                                            for (int i = 0; i < l._previous._a[0].size(); ++i) {
                                                                a = ref._nn.ihWeights[i][o];
                                                                b = l._w.get(o, i);
                                                                this.compareVal(a, b, 1.0E-6, 1.0E-6);
                                                            }
                                                            ba = ref._nn.hBiases[o];
                                                            bb = l._b.get(o);
                                                            this.compareVal(ba, bb, 1.0E-6, 1.0E-6);
                                                        }
                                                        Log.info((Object[])new Object[]{"Weights and biases for hidden layer: PASS"});
                                                        l = neurons[2];
                                                        for (o = 0; o < l._a[0].size(); ++o) {
                                                            for (int i = 0; i < l._previous._a[0].size(); ++i) {
                                                                a = ref._nn.hoWeights[i][o];
                                                                b = l._w.get(o, i);
                                                                this.compareVal(a, b, 1.0E-6, 1.0E-6);
                                                            }
                                                            ba = ref._nn.oBiases[o];
                                                            bb = l._b.get(o);
                                                            this.compareVal(ba, bb, 1.0E-6, 1.0E-6);
                                                        }
                                                        Log.info((Object[])new Object[]{"Weights and biases for output layer: PASS"});
                                                        Frame fpreds = mymodel.score(this._test);
                                                        try {
                                                            int i = 0;
                                                            while ((long)i < this._test.numRows()) {
                                                                double[] xValues = new double[neurons[0]._a[0].size()];
                                                                System.arraycopy(ref._testData[i], 0, xValues, 0, xValues.length);
                                                                double[] ref_preds = ref._nn.ComputeOutputs(xValues);
                                                                double[] preds = new double[ref_preds.length + 1];
                                                                for (int j = 0; j < ref_preds.length; ++j) {
                                                                    preds[j + 1] = ref_preds[j];
                                                                }
                                                                preds[0] = GenModel.getPrediction((double[])preds, null, (double[])xValues, (double)0.5);
                                                                Assert.assertTrue((preds[0] == (double)((int)fpreds.vecs()[0].at((long)i)) ? 1 : 0) != 0);
                                                                ++i;
                                                            }
                                                        }
                                                        finally {
                                                            if (fpreds != null) {
                                                                fpreds.delete();
                                                            }
                                                        }
                                                        Log.info((Object[])new Object[]{"Predicted values: PASS"});
                                                        double trainErr = ref._nn.Accuracy(ref._trainData);
                                                        double testErr = ref._nn.Accuracy(ref._testData);
                                                        trainPredict = mymodel.score(this._train);
                                                        testPredict = mymodel.score(this._test);
                                                        ModelMetrics mmtrain = ModelMetrics.getFromDKV((Model)mymodel, (Frame)this._train);
                                                        ModelMetrics mmtest = ModelMetrics.getFromDKV((Model)mymodel, (Frame)this._test);
                                                        double myTrainErr = mmtrain.cm().err();
                                                        double myTestErr = mmtest.cm().err();
                                                        Log.info((Object[])new Object[]{"H2O  training error : " + myTrainErr * 100.0 + "%, test error: " + myTestErr * 100.0 + "%"});
                                                        Log.info((Object[])new Object[]{"REF  training error : " + trainErr * 100.0 + "%, test error: " + testErr * 100.0 + "%"});
                                                        this.compareVal(trainErr, myTrainErr, 1.0E-6, 1.0E-6);
                                                        this.compareVal(testErr, myTestErr, 1.0E-6, 1.0E-6);
                                                        Log.info((Object[])new Object[]{"Scoring: PASS"});
                                                        float best_err = Float.MAX_VALUE;
                                                        for (ScoringInfo e : mymodel.scoring_history()) {
                                                            DeepLearningScoringInfo err = (DeepLearningScoringInfo)e;
                                                            best_err = Math.min(best_err, (float)(Double.isNaN(err.scored_train._classError) ? (double)best_err : err.scored_train._classError));
                                                        }
                                                        Log.info((Object[])new Object[]{"Actual best error : " + best_err * 100.0f + "%."});
                                                        if (p._overwrite_with_best_model) {
                                                            Frame bestPredict = null;
                                                            try {
                                                                bestPredict = mymodel.score(this._train);
                                                                ModelMetrics mmbest = ModelMetrics.getFromDKV((Model)mymodel, (Frame)this._train);
                                                                double bestErr = mmbest.cm().err();
                                                                Log.info((Object[])new Object[]{"Best_model's error : " + bestErr * 100.0 + "%."});
                                                                this.compareVal(bestErr, best_err, 1.0E-6, 1.0E-6);
                                                            }
                                                            finally {
                                                                if (bestPredict != null) {
                                                                    bestPredict.delete();
                                                                }
                                                            }
                                                        }
                                                        Log.info((Object[])new Object[]{"Parameters combination " + num_runs + ": PASS"});
                                                    }
                                                    finally {
                                                        if (mymodel != null) {
                                                            mymodel.delete();
                                                        }
                                                        if (this._train != null) {
                                                            this._train.delete();
                                                        }
                                                        if (this._test != null) {
                                                            this._test.delete();
                                                        }
                                                        if (trainPredict != null) {
                                                            trainPredict.delete();
                                                        }
                                                        if (testPredict != null) {
                                                            testPredict.delete();
                                                        }
                                                    }
                                                }
                                            }
                                        }
                                    }
                                }
                            }
                        }
                    }
                }
            }
        }
        catch (Throwable t) {
            t.printStackTrace();
            throw new RuntimeException(t);
        }
        finally {
            if (frame != null) {
                frame.delete();
            }
        }
    }

    public static class Short
    extends DeepLearningIrisTest {
        @Override
        @Test
        @Ignore
        public void run() throws Exception {
            this.runFraction(0.05f);
        }
    }

    public static class Long
    extends DeepLearningIrisTest {
        @Override
        @Test
        @Ignore
        public void run() throws Exception {
            this.runFraction(0.1f);
        }
    }
}

