package hex.deeplearning;

import hex.deeplearning.DeepLearningModel;
import java.util.Arrays;
import java.util.Iterator;
import java.util.TreeMap;
import org.junit.Assert;
import org.junit.BeforeClass;
import org.junit.Test;
import water.DKV;
import water.Key;
import water.Scope;
import water.TestUtil;
import water.fvec.Frame;
import water.parser.ParseDataset;
import water.util.FrameUtils;
import water.util.Log;

/* loaded from: input_file:hex/deeplearning/DeepLearningReproducibilityTest.class */
/**
 * Checks Deep Learning reproducibility on {@code smalldata/junit/weather.csv}.
 *
 * <p>The same model is trained {@link #NUM_REPEATS} times with {@code _reproducible} enabled and
 * then {@link #NUM_REPEATS} times with it disabled. The dataset serves as both the training and
 * the validation frame. In reproducible mode the test asserts that the following agree across
 * all runs:
 * <ul>
 *   <li>a sampled weight;</li>
 *   <li>the model checksum;</li>
 *   <li>the validation loss;</li>
 *   <li>the predictions.</li>
 * </ul>
 * In non-reproducible mode it only logs how much the validation loss varies and how far its mean
 * is from the reproducible-mode loss.
 */
public class DeepLearningReproducibilityTest extends TestUtil {
    /** Dataset used for both training and validation. */
    private static final String DATA_FILE = "smalldata/junit/weather.csv";

    /** Number of models trained per reproducibility setting. */
    private static final int NUM_REPEATS = 3;

    /** Number of extra scoring passes used to check that scoring a fixed model is deterministic. */
    private static final int NUM_RESCORES = 5;

    @BeforeClass
    public static void setup() {
        stall_till_cloudsize(1);
    }

    @Test
    public void run() {
        // Reference copy of the data: every re-parse must be bit-identical to it.
        Frame golden = ParseDataset.parse(Key.make("golden.hex"), new Key[]{TestUtil.makeNfsFileVec(DATA_FILE)._key});
        try {
            // Accumulates the loss tables of all modes; each log call prints everything so far.
            StringBuilder report = new StringBuilder();
            float reproLoss = 0.0f;
            // Reproducible mode runs first, so its loss is available as the reference
            // when the non-reproducible runs are summarized.
            for (boolean reproducible : new boolean[]{true, false}) {
                Scope.enter();
                Frame[] predictions = new Frame[NUM_REPEATS];
                try {
                    double[] weights = new double[NUM_REPEATS];
                    long[] checksums = new long[NUM_REPEATS];
                    TreeMap<Integer, Float> losses = new TreeMap<>();
                    for (int repeat = 0; repeat < NUM_REPEATS; repeat++) {
                        // Per-repeat locals, so a failure never re-deletes a previous repeat's objects.
                        DeepLearningModel model = null;
                        Frame data = null;
                        try {
                            data = ParseDataset.parse(Key.make("data.hex"), new Key[]{TestUtil.makeNfsFileVec(DATA_FILE)._key});
                            TestUtil.assertBitIdentical(data, golden);
                            model = buildModel(data, reproducible);
                            predictions[repeat] = model.score(data);
                            assertScoringDeterministic(model, data, predictions[repeat]);
                            Log.info("Prediction:\n" + FrameUtils.chunkSummary(predictions[repeat]).toString());
                            weights[repeat] = model.model_info().get_weights(0).get(23, 4);
                            checksums[repeat] = model.model_info().checksum_impl();
                            losses.put(repeat, model.loss());
                        } finally {
                            if (model != null) {
                                model.delete();
                            }
                            // data is both the training and the validation frame: delete it once.
                            if (data != null) {
                                data.delete();
                            }
                        }
                    }
                    appendLossReport(report, reproducible, losses);
                    Log.info(report.toString());
                    if (reproducible) {
                        assertRunsIdentical(weights, checksums, losses, predictions);
                        reproLoss = losses.get(0);
                    } else {
                        logLossVariation(losses, reproLoss);
                    }
                } finally {
                    // Runs for both modes and on assertion failures, so prediction frames never leak.
                    for (Frame prediction : predictions) {
                        if (prediction != null) {
                            prediction.delete();
                        }
                    }
                    Scope.exit();
                }
            }
        } finally {
            golden.delete();
        }
    }

    /**
     * Trains a Deep Learning model that uses {@code data} for both training and validation.
     *
     * <p>As a side effect, the last column of {@code data} is replaced by its categorical version,
     * and the updated frame is written back to the DKV.
     *
     * @param data         parsed dataset; its last column is the response
     * @param reproducible value for {@code _reproducible}
     * @return the trained model (caller must delete it)
     */
    private static DeepLearningModel buildModel(Frame data, boolean reproducible) {
        DeepLearningModel.DeepLearningParameters params = new DeepLearningModel.DeepLearningParameters();
        params._train = data._key;
        params._valid = data._key;
        int responseIdx = data.names().length - 1;
        params._response_column = data.names()[responseIdx];
        // The replaced (original) response vec is handed to the enclosing Scope.
        // Presumably it is cleaned up at Scope.exit(); confirm against water.Scope.
        Scope.track(data.replace(responseIdx, data.vecs()[responseIdx].toCategoricalVec()));
        DKV.put(data);
        params._ignored_columns = new String[]{"EvapMM", "RISK_MM"};
        params._activation = DeepLearningModel.DeepLearningParameters.Activation.RectifierWithDropout;
        params._hidden = new int[]{32, 58};
        params._l1 = 1.0E-5d;
        params._l2 = 3.0E-5d;
        params._seed = 48830L;
        params._loss = DeepLearningModel.DeepLearningParameters.Loss.CrossEntropy;
        params._input_dropout_ratio = 0.2d;
        params._train_samples_per_iteration = 3L;
        params._hidden_dropout_ratios = new double[]{0.4d, 0.1d};
        params._epochs = 1.32d;
        params._quiet_mode = true;
        params._reproducible = reproducible;
        return (DeepLearningModel) new DeepLearning(params).trainModel().get();
    }

    /**
     * Re-scores {@code data} {@link #NUM_RESCORES} times and asserts that each result is
     * bit-identical to {@code expected}. Each temporary score frame is deleted even if the
     * assertion fails.
     */
    private static void assertScoringDeterministic(DeepLearningModel model, Frame data, Frame expected) {
        for (int pass = 0; pass < NUM_RESCORES; pass++) {
            Frame rescored = model.score(data);
            try {
                TestUtil.assertBitIdentical(expected, rescored);
            } finally {
                rescored.delete();
            }
        }
    }

    /**
     * Appends the per-repeat validation losses of one mode to {@code report}.
     *
     * @param report       accumulated report text
     * @param reproducible which mode the losses belong to
     * @param losses       repeat index mapped to validation loss
     */
    private static void appendLossReport(StringBuilder report, boolean reproducible, TreeMap<Integer, Float> losses) {
        report.append("Reproducibility: ").append(reproducible ? "on" : "off").append("\n");
        report.append("Repeat # --> Validation Loss\n");
        for (Integer repeat : losses.keySet()) {
            report.append(repeat).append(" --> ").append(losses.get(repeat)).append("\n");
        }
        report.append('\n');
    }

    /**
     * Asserts that all reproducible-mode runs agree, comparing:
     * <ul>
     *   <li>the sampled weight (exactly);</li>
     *   <li>the validation loss (exactly);</li>
     *   <li>the model checksum (exactly);</li>
     *   <li>every prediction column (within 1e-5).</li>
     * </ul>
     */
    private static void assertRunsIdentical(double[] weights, long[] checksums,
                                            TreeMap<Integer, Float> losses, Frame[] predictions) {
        for (double weight : weights) {
            Assert.assertTrue(Arrays.toString(weights), weight == weights[0]);
        }
        Float referenceLoss = losses.get(0);
        for (Float loss : losses.values()) {
            Assert.assertTrue(losses.toString(), loss.equals(referenceLoss));
        }
        for (long checksum : checksums) {
            Assert.assertTrue(Arrays.toString(checksums), checksum == checksums[0]);
        }
        for (Frame prediction : predictions) {
            for (int col = 0; col < prediction.vecs().length; col++) {
                TestUtil.assertVecEquals(prediction.vecs()[col], predictions[0].vecs()[col], 1.0E-5d);
            }
        }
    }

    /**
     * Logs statistics for the non-reproducible runs:
     * <ul>
     *   <li>the mean validation loss;</li>
     *   <li>the population standard deviation of the loss;</li>
     *   <li>the distance of the mean from {@code reproLoss}, in standard deviations.</li>
     * </ul>
     * A non-reproducible run whose loss equals run 0's is reported as a warning.
     */
    private static void logLossVariation(TreeMap<Integer, Float> losses, float reproLoss) {
        double mean = 0.0d;
        for (Float loss : losses.values()) {
            mean += loss;
        }
        mean /= losses.size();
        Float firstLoss = losses.get(0);
        for (int repeat = 1; repeat < NUM_REPEATS; repeat++) {
            // Compare values, not Float object identity (the previous `!=` check never compared values).
            // Non-reproducible mode does not guarantee different results, so this is reported, not asserted.
            if (losses.get(repeat).equals(firstLoss)) {
                Log.warn("Non-reproducible run " + repeat + " produced the same validation loss as run 0: " + firstLoss);
            }
        }
        Log.info("mean error: " + mean);
        double sumOfSquares = 0.0d;
        for (Float loss : losses.values()) {
            double diff = loss - mean;
            sumOfSquares += diff * diff;
        }
        double stddev = Math.sqrt(sumOfSquares / losses.size());
        Log.info("standard deviation: " + stddev);
        Log.info("difference to reproducible mode: " + (Math.abs(mean - reproLoss) / stddev) + " standard deviations");
    }
}
