package hex.deeplearning;

import hex.deeplearning.Storage;
import java.util.Arrays;
import java.util.Random;
import org.junit.Assert;
import org.junit.BeforeClass;
import org.junit.Test;
import water.TestUtil;
import water.util.ArrayUtils;

/**
 * Statistical tests for {@link Dropout}.
 *
 * <p>Over many trials, checks two things:
 * <ul>
 *   <li>the sum of an all-ones activation vector after
 *       {@link Dropout#randomlySparsifyActivation} matches the expected number of kept units for
 *       several dropout ratios;</li>
 *   <li>the number of active units reported after {@link Dropout#fillBytes} matches its ratio.</li>
 * </ul>
 */
public class DropoutTest extends TestUtil {
    /** Number of units in the activation vector / dropout mask. */
    private static final int UNITS = 1000;
    /** Number of random trials averaged over. */
    private static final int TRIALS = 10000;
    /** Fixed master seed so that any failure can be reproduced exactly. */
    private static final long SEED = 0x5EEDL;

    @BeforeClass
    public static void setup() {
        stall_till_cloudsize(1);
    }

    @Test
    public void test() throws Exception {
        Storage.DenseVector activations = new Storage.DenseVector(UNITS);
        // One generator for the whole test: seeded for reproducibility, not re-created per trial.
        Random rng = new Random(SEED);

        double keptSumRatio030 = 0.0;
        double keptSumRatio000 = 0.0;
        double keptSumRatio100 = 0.0;
        double activeCountRatio0314 = 0.0;

        for (int trial = 0; trial < TRIALS; trial++) {
            long seed = rng.nextLong();
            keptSumRatio030 += sparsifiedSum(new Dropout(UNITS, 0.3d), activations, seed);
            keptSumRatio000 += sparsifiedSum(new Dropout(UNITS, 0.0d), activations, seed + 1);
            keptSumRatio100 += sparsifiedSum(new Dropout(UNITS, 1.0d), activations, seed + 2);

            Dropout dropout = new Dropout(UNITS, 0.314d);
            dropout.fillBytes(seed + 3);
            for (int unit = 0; unit < UNITS; unit++) {
                boolean active = dropout.unit_active(unit);
                // Querying the same unit twice must give the same answer. This check runs
                // unconditionally rather than depending on the JVM's -ea flag.
                if (active != dropout.unit_active(unit)) {
                    Assert.fail("unit_active(" + unit + ") returned inconsistent results");
                }
                if (active) {
                    activeCountRatio0314 += 1.0d;
                }
            }
        }

        Assert.assertEquals("mean kept units for ratio 0.3",
                700.0d, keptSumRatio030 / TRIALS, 1.0d);
        Assert.assertEquals("ratio 0.0 must keep every unit",
                1000.0d, keptSumRatio000 / TRIALS, 0.0d);
        Assert.assertEquals("ratio 1.0 must drop every unit",
                0.0d, keptSumRatio100 / TRIALS, 0.0d);
        Assert.assertEquals("mean active units for ratio 0.314",
                686.0d, activeCountRatio0314 / TRIALS, 1.0d);
    }

    /**
     * Resets {@code activations} to all ones, applies {@code dropout} with {@code seed}, and
     * returns the sum of the resulting values.
     *
     * @param dropout the dropout mask generator to apply
     * @param activations the vector to overwrite and sparsify in place
     * @param seed the seed passed to {@link Dropout#randomlySparsifyActivation}
     * @return the element sum of {@code activations} after sparsification
     */
    private static double sparsifiedSum(Dropout dropout, Storage.DenseVector activations, long seed) {
        Arrays.fill(activations.raw(), 1.0d);
        dropout.randomlySparsifyActivation(activations, seed);
        return ArrayUtils.sum(activations.raw());
    }
}
