package mikera.vectorz;

import mikera.util.Rand;
import mikera.vectorz.ops.Constant;
import mikera.vectorz.ops.Linear;
import mikera.vectorz.ops.Power;
import mikera.vectorz.ops.Quadratic;
import org.junit.Assert;
import org.junit.Test;

/* loaded from: input_file:mikera/vectorz/TestSpecialOps.class */
/**
 * JUnit tests for the scalar operators reachable through {@link Ops} and the
 * {@code mikera.vectorz.ops} factory classes used below.
 *
 * <p>The tests cover point values, analytic derivatives (cross-checked against a
 * central finite difference), operator composition, repeated differentiation,
 * and the declared output range of selected operators.
 */
public class TestSpecialOps {
    /** Half-width of the interval used for the central finite-difference estimate. */
    private static final double EPSILON = 1.0E-5d;

    /** Absolute tolerance used by most value comparisons in this class. */
    private static final double TOLERANCE = 1.0E-4d;

    /**
     * Tolerance used by {@link #testOp()}. The literal is the value of the
     * {@code float} constant {@code 0.001f} widened to {@code double}.
     */
    private static final double LOOSE_TOLERANCE = 0.0010000000474974513d;

    /** Allowed relative deviation between the analytic and the numeric derivative. */
    private static final double DERIVATIVE_RATIO_TOLERANCE = 0.01d;

    /** Number of random inputs fed to each operator by {@link #testOp(Op)}. */
    private static final int SAMPLE_COUNT = 100;

    /**
     * Runs {@link #testDerivativeAt(Op, double)} for every supplied point.
     *
     * @param op the operator whose derivative is checked
     * @param dArr the input values at which to check the derivative
     */
    private void testDerivativesAt(Op op, double... dArr) {
        for (double x : dArr) {
            testDerivativeAt(op, x);
        }
    }

    /** Checks that {@code Ops.SINH} maps zero to exactly zero. */
    @Test
    public void testSinh() {
        Assert.assertEquals(0.0d, Ops.SINH.apply(0.0d), 0.0d);
    }

    /**
     * Compares {@code op.derivative(d)} with a central finite-difference estimate
     * built from {@code op.apply}.
     *
     * <p>The comparison is relative: the ratio numeric/analytic must be within
     * {@link #DERIVATIVE_RATIO_TOLERANCE} of 1. When the analytic derivative is
     * exactly zero the ratio is undefined, so the numeric estimate itself must be
     * within the same tolerance of zero (expressed as {@code numeric + 1 ~ 1}).
     *
     * @param op the operator under test
     * @param d the point at which the derivative is evaluated
     */
    private void testDerivativeAt(Op op, double d) {
        double analytic = op.derivative(d);
        double numeric = (op.apply(d + EPSILON) - op.apply(d - EPSILON)) / (2.0d * EPSILON);
        double comparable;
        if (analytic == 0.0d) {
            comparable = numeric + 1.0d;
        } else {
            comparable = numeric / analytic;
        }
        Assert.assertEquals(1.0d, comparable, DERIVATIVE_RATIO_TOLERANCE);
    }

    /**
     * Applies {@code Ops.LOGISTIC} in place to an array whose first element is
     * 1000 and expects that element to end up within tolerance of 1; then applies
     * {@code Ops.LINEAR} in place and expects the element to still be within
     * tolerance of 1.
     */
    @Test
    public void testOp() {
        double[] values = new double[10];
        values[0] = 1000.0d;
        Ops.LOGISTIC.applyTo(values);
        Assert.assertEquals(1.0d, values[0], LOOSE_TOLERANCE);
        Ops.LINEAR.applyTo(values);
        Assert.assertEquals(1.0d, values[0], LOOSE_TOLERANCE);
    }

    /**
     * Checks derivative values at selected points for the logistic, scaled
     * logistic and softplus operators, checks derivative-for-output on random
     * values for the linear and stochastic logistic operators, and then
     * cross-checks the analytic derivative of many operators against a finite
     * difference at the listed sample points.
     */
    @Test
    public void testDerivatives() {
        Assert.assertEquals(0.0d, Ops.LOGISTIC.derivativeForOutput(1.0d), TOLERANCE);
        Assert.assertEquals(0.0d, Ops.LOGISTIC.derivativeForOutput(0.0d), TOLERANCE);
        Assert.assertEquals(0.0d, Ops.LOGISTIC.derivative(-100.0d), TOLERANCE);
        Assert.assertEquals(0.0d, Ops.LOGISTIC.derivative(100.0d), TOLERANCE);

        Assert.assertEquals(0.0d, Ops.SCALED_LOGISTIC.derivativeForOutput(1.0d), TOLERANCE);
        Assert.assertEquals(0.0d, Ops.SCALED_LOGISTIC.derivativeForOutput(0.0d), TOLERANCE);
        Assert.assertEquals(0.0d, Ops.SCALED_LOGISTIC.derivative(-100.0d), TOLERANCE);
        Assert.assertEquals(0.0d, Ops.SCALED_LOGISTIC.derivative(100.0d), TOLERANCE);

        Assert.assertEquals(1.0d, Ops.SOFTPLUS.derivativeForOutput(100.0d), TOLERANCE);
        Assert.assertEquals(0.0d, Ops.SOFTPLUS.derivativeForOutput(0.0d), TOLERANCE);
        Assert.assertEquals(1.0d, Ops.SOFTPLUS.derivative(100.0d), TOLERANCE);
        Assert.assertEquals(0.0d, Ops.SOFTPLUS.derivative(-100.0d), TOLERANCE);

        for (int i = 0; i < 10; i++) {
            double output = Rand.nextDouble();
            Assert.assertEquals(1.0d, Ops.LINEAR.derivativeForOutput(output), TOLERANCE);
            // The stochastic logistic is expected to report the same derivative as the plain logistic.
            Assert.assertEquals(
                    Ops.STOCHASTIC_LOGISTIC.derivativeForOutput(output),
                    Ops.LOGISTIC.derivativeForOutput(output),
                    TOLERANCE);
        }

        testDerivativesAt(Ops.LINEAR, 0.0d, 0.1d, -0.1d, 1.0d, -1.0d, 10.0d, -10.0d, 100.0d, -100.0d);
        testDerivativesAt(Ops.LOGISTIC, 0.0d, 0.1d, -0.1d, 1.0d, -1.0d, 10.0d, -10.0d, 100.0d, -100.0d);
        testDerivativesAt(Ops.EXP, 0.0d, 0.1d, -0.1d, 1.0d, -1.0d, 10.0d, -10.0d, 100.0d, -100.0d);
        testDerivativesAt(Ops.LOG, 0.1d, 1.0d, 10.0d, 100.0d, 1000.0d);
        testDerivativesAt(Ops.TANH, 0.0d, 0.1d, -0.1d, 1.0d, -1.0d, 10.0d, -10.0d, 100.0d, -100.0d);
        testDerivativesAt(Ops.SOFTPLUS, 0.0d, 0.1d, -0.1d, 1.0d, -1.0d, 10.0d, -10.0d);
        testDerivativesAt(Quadratic.create(1.0d, 2.0d, 3.0d), 0.0d, 0.1d, -0.1d, 1.0d, -1.0d, 10.0d, -10.0d);
        testDerivativesAt(Linear.create(-11.0d, 2.0d), 0.0d, 0.1d, -0.1d, 1.0d, -1.0d, 10.0d, -10.0d);
        testDerivativesAt(Ops.RECIPROCAL, 0.1d, -0.1d, 1.0d, -1.0d, 10.0d, -10.0d);
        testDerivativesAt(Ops.SQRT, 0.001d, 0.1d, 1.0d, 10.0d, 100.0d, 45654.0d);
        testDerivativesAt(Ops.SIN, 0.1d, -0.1d, 1.0d, -1.0d, 10.0d, -10.0d);
        testDerivativesAt(Ops.COS, 0.1d, -0.1d, 1.0d, -1.0d, 10.0d, -10.0d);
        testDerivativesAt(Ops.TAN, 0.1d, -0.1d, 1.0d, -1.0d, 10.0d, -10.0d);
        testDerivativesAt(Ops.NEGATE, 0.0d, 0.1d, -0.1d, 1.0d, -1.0d, 10.0d, -10.0d);
        testDerivativesAt(Ops.SIN.compose(Ops.EXP), 0.1d, -0.1d, 1.0d, -1.0d, 2.0d, -2.0d, 3.0d, -3.0d);
        testDerivativesAt(Ops.COS.product(Ops.SOFTPLUS), 0.1d, -0.1d, 1.0d, -1.0d, 2.0d, -2.0d, 3.0d, -3.0d);
        testDerivativesAt(Ops.TANH.sum(Ops.SQUARE), 0.1d, -0.1d, 1.0d, -1.0d, 2.0d, -2.0d, 3.0d, -3.0d);
        testDerivativesAt(Ops.ACOS, 0.0d, 0.1d, -0.1d, 0.99d, -0.99d);
        testDerivativesAt(Ops.ASIN, 0.0d, 0.1d, -0.1d, 0.99d, -0.99d);
        testDerivativesAt(Ops.ATAN, 0.0d, 0.1d, -0.1d, 0.99d, -0.99d);
        testDerivativesAt(Power.create(0.2d), 0.1d, 1.0d, 2.0d, 3.0d, 10.0d);
        testDerivativesAt(Power.create(1.4d), 0.1d, 1.0d, 2.0d, 3.0d, 10.0d);
        testDerivativesAt(Power.create(-1.4d), 0.1d, 1.0d, 2.0d, 3.0d, 10.0d);
    }

    /**
     * Checks that composing {@code Ops.SIN} with {@code Linear.create(1, 0)}, and
     * with {@code Linear.create(0.5, 0)} after {@code Linear.create(2, 0)},
     * yields something equal to {@code Ops.SIN}.
     */
    @Test
    public void testCompositions() {
        Assert.assertEquals(Ops.SIN, Ops.compose(Linear.create(1.0d, 0.0d), Ops.SIN));
        Assert.assertEquals(
                Ops.SIN,
                Ops.compose(Linear.create(0.5d, 0.0d), Ops.compose(Linear.create(2.0d, 0.0d), Ops.SIN)));
    }

    /**
     * Checks repeated differentiation via {@code getDerivativeOp()}: the fourth
     * derivative of sine and of cosine, and the first derivative of exp, must be
     * the very same operator instance; the second derivative of a quadratic must
     * be a {@link Constant} whose own derivative evaluates to zero; and the
     * derivative of (constant + sine) must be the cosine instance.
     */
    @Test
    public void testDerivativeChains() {
        Op sine = Ops.SIN;
        Assert.assertSame(
                sine,
                sine.getDerivativeOp().getDerivativeOp().getDerivativeOp().getDerivativeOp());

        Op cosine = Ops.COS;
        Assert.assertSame(
                cosine,
                cosine.getDerivativeOp().getDerivativeOp().getDerivativeOp().getDerivativeOp());

        Assert.assertSame(Ops.EXP, Ops.EXP.getDerivativeOp());

        Op secondDerivative =
                Quadratic.create(Math.random(), Math.random(), Math.random())
                        .getDerivativeOp()
                        .getDerivativeOp();
        Op thirdDerivative = secondDerivative.getDerivativeOp();
        Assert.assertEquals(Constant.class, secondDerivative.getClass());
        Assert.assertEquals(0.0d, thirdDerivative.apply(Math.random()), 1.0E-5d);

        Assert.assertSame(cosine, Constant.create(10.0d).sum(sine).getDerivativeOp());
    }

    /**
     * Checks the declared output bounds of the logistic and tanh operators and
     * the lower domain bound of a fractional power operator.
     */
    @Test
    public void testRange() {
        Assert.assertEquals(0.0d, Ops.LOGISTIC.minValue(), TOLERANCE);
        Assert.assertEquals(1.0d, Ops.LOGISTIC.maxValue(), TOLERANCE);
        Assert.assertEquals(-1.0d, Ops.TANH.minValue(), TOLERANCE);
        Assert.assertEquals(1.0d, Ops.TANH.maxValue(), TOLERANCE);
        Assert.assertEquals(0.0d, Power.create(0.3d).minDomain(), 0.0d);
    }

    /** Runs the range check of {@link #testOp(Op)} for a selection of operators. */
    @Test
    public void testAllOps() {
        testOp(Ops.LOGISTIC);
        testOp(Ops.LINEAR);
        testOp(Ops.STOCHASTIC_BINARY);
        testOp(Ops.STOCHASTIC_LOGISTIC);
        testOp(Ops.TANH);
    }

    /**
     * Applies {@code op} in place to a copy of {@link #SAMPLE_COUNT} values drawn
     * from {@code Rand.n(0.0, 10.0)} and asserts that every result lies within
     * {@code [op.minValue(), op.maxValue()]}.
     *
     * <p>This is a helper, not a JUnit test case: it takes a parameter and
     * carries no {@code @Test} annotation.
     *
     * @param op the operator whose declared output range is verified
     */
    public void testOp(Op op) {
        double[] inputs = new double[SAMPLE_COUNT];
        for (int i = 0; i < inputs.length; i++) {
            inputs[i] = Rand.n(0.0d, 10.0d);
        }

        // Work on a copy so the untouched inputs can be reported on failure.
        double[] outputs = new double[SAMPLE_COUNT];
        System.arraycopy(inputs, 0, outputs, 0, inputs.length);
        op.applyTo(outputs);

        double upper = op.maxValue();
        double lower = op.minValue();
        for (int i = 0; i < outputs.length; i++) {
            double result = outputs[i];
            Assert.assertTrue(
                    "output " + result + " for input " + inputs[i] + " exceeds maxValue " + upper,
                    result <= upper);
            Assert.assertTrue(
                    "output " + result + " for input " + inputs[i] + " is below minValue " + lower,
                    result >= lower);
        }
    }
}
