package mikera.vectorz;

import java.lang.reflect.Field;
import java.lang.reflect.Modifier;
import java.util.ArrayList;
import java.util.List;
import mikera.transformz.TestTransformz;
import mikera.util.Maths;
import mikera.util.Rand;
import mikera.vectorz.ops.Clamp;
import mikera.vectorz.ops.Composed;
import mikera.vectorz.ops.Constant;
import mikera.vectorz.ops.GaussianNoise;
import mikera.vectorz.ops.Identity;
import mikera.vectorz.ops.Linear;
import mikera.vectorz.ops.Logistic;
import mikera.vectorz.ops.Offset;
import mikera.vectorz.ops.Power;
import mikera.vectorz.ops.Quadratic;
import mikera.vectorz.ops.StochasticBinary;
import mikera.vectorz.util.VectorzException;
import org.junit.Assert;
import org.junit.Test;

/* loaded from: input_file:mikera/vectorz/TestOps.class */
/**
 * JUnit tests for scalar {@link Op} implementations and the predefined instances in
 * {@link Ops}.
 *
 * <p>Individual tests pin exact values for specific functions. {@link #genericTests()}
 * runs a randomised battery of property checks over many ops:
 * <ul>
 *   <li>range bounds;</li>
 *   <li>inverse;</li>
 *   <li>determinism;</li>
 *   <li>vector/array application;</li>
 *   <li>transforms;</li>
 *   <li>derivatives.</li>
 * </ul>
 */
public class TestOps {
    /**
     * Every static {@link Op}-typed field declared on {@link Ops}, plus
     * {@link Constant#ONE}. It is populated reflectively by the static initializer at the
     * end of this class.
     */
    public static final List<Op> ALL_OPS = new ArrayList<Op>();

    /** Composing 2x+1 with 100x+10 applied first must give 2(100x+10)+1, i.e. 221 at x = 1. */
    @Test
    public void testComposedOp() {
        Op composed = Composed.compose(Linear.create(2.0d, 1.0d), Linear.create(100.0d, 10.0d));
        Vector v = Vector.of(new double[]{1.0d, 2.0d});
        v.applyOp(composed);
        Assert.assertEquals(221.0d, v.get(0), 0.0d);
    }

    /** Checks the logistic function and its derivative at both tails and at zero. */
    @Test
    public void testLogistic() {
        Op logistic = Ops.LOGISTIC;
        Assert.assertEquals(0.0d, logistic.apply(-1000.0d), 1.0E-4d);
        Assert.assertEquals(0.5d, logistic.apply(0.0d), 1.0E-4d);
        Assert.assertEquals(1.0d, logistic.apply(1000.0d), 1.0E-4d);
        Assert.assertEquals(0.0d, logistic.derivative(-1000.0d), 1.0E-4d);
        Assert.assertEquals(0.25d, logistic.derivative(0.0d), 1.0E-4d);
        Assert.assertEquals(0.0d, logistic.derivative(1000.0d), 1.0E-4d);
    }

    /** Checks tanh and its derivative at both tails and at zero. */
    @Test
    public void testTanh() {
        Op tanh = Ops.TANH;
        Assert.assertEquals(-1.0d, tanh.apply(-1000.0d), 1.0E-4d);
        Assert.assertEquals(0.0d, tanh.apply(0.0d), 1.0E-4d);
        Assert.assertEquals(1.0d, tanh.apply(1000.0d), 1.0E-4d);
        Assert.assertEquals(0.0d, tanh.derivative(-1000.0d), 1.0E-4d);
        Assert.assertEquals(1.0d, tanh.derivative(0.0d), 1.0E-4d);
        Assert.assertEquals(0.0d, tanh.derivative(1000.0d), 1.0E-4d);
    }

    /** Checks log10 and its inverse (10^x) at a few exact points. */
    @Test
    public void testLog10() {
        Op log10 = Ops.LOG10;
        Assert.assertEquals(1.0d, log10.apply(10.0d), 1.0E-4d);
        Assert.assertEquals(0.0d, log10.apply(1.0d), 1.0E-4d);
        Assert.assertEquals(-1.0d, log10.apply(0.1d), 1.0E-4d);
        Assert.assertEquals(10.0d, log10.applyInverse(1.0d), 1.0E-4d);
        Assert.assertEquals(1.0d, log10.applyInverse(0.0d), 1.0E-4d);
        Assert.assertEquals(0.1d, log10.applyInverse(-1.0d), 1.0E-4d);
    }

    /** Checks that ABS applied in place to a vector takes the absolute value of each element. */
    @Test
    public void testAbs() {
        Op abs = Ops.ABS;
        Vector v = Vector.of(new double[]{-1.0d, 2.0d, -3.0d});
        abs.applyTo(v);
        Assert.assertEquals(Vector.of(new double[]{1.0d, 2.0d, 3.0d}), v);
    }

    /** Spot-checks one known value for each of the standard elementwise functions. */
    @Test
    public void testFunctions() {
        Assert.assertEquals(1.0d, Ops.ABS.apply(-1.0d), 1.0E-4d);
        Assert.assertEquals(1.0d, Ops.SIN.apply(Math.PI / 2), 1.0E-4d);
        Assert.assertEquals(1.0d, Ops.COS.apply(0.0d), 1.0E-4d);
        Assert.assertEquals(1.0d, Ops.TAN.apply(Math.PI / 4), 1.0E-4d);
        Assert.assertEquals(9.0d, Ops.SQUARE.apply(3.0d), 1.0E-4d);
        Assert.assertEquals(3.0d, Ops.SQRT.apply(9.0d), 1.0E-4d);
        Assert.assertEquals(3.0d, Ops.CEIL.apply(2.1d), 1.0E-4d);
        Assert.assertEquals(-3.0d, Ops.FLOOR.apply(-2.1d), 1.0E-4d);
        Assert.assertEquals(-3.0d, Ops.RINT.apply(-3.4d), 1.0E-4d);
        Assert.assertEquals(2.0d, Ops.CBRT.apply(8.0d), 1.0E-4d);
        Assert.assertEquals(1.0d, Ops.SIGNUM.apply(801.0d), 1.0E-4d);
        Assert.assertEquals(0.25d, Ops.RECIPROCAL.apply(4.0d), 1.0E-4d);
        Assert.assertEquals(1.25d, Ops.LINEAR.apply(1.25d), 1.0E-4d);
        Assert.assertEquals(1.25d, Ops.IDENTITY.apply(1.25d), 1.0E-4d);
        Assert.assertEquals(Math.E, Ops.EXP.apply(1.0d), 1.0E-4d);
        Assert.assertEquals(2.0d, Ops.LOG.apply(Math.E * Math.E), 1.0E-4d);
        Assert.assertEquals(3.0d, Ops.LOG10.apply(1000.0d), 1.0E-4d);
        Assert.assertEquals(0.5d, Ops.LOGISTIC.apply(0.0d), 1.0E-4d);
    }

    /** Checks softplus, log(1 + e^x), and its derivative at both tails and at zero. */
    @Test
    public void testSoftplus() {
        Op softplus = Ops.SOFTPLUS;
        Assert.assertEquals(0.0d, softplus.apply(-1000.0d), 1.0E-4d);
        Assert.assertEquals(Math.log(2.0d), softplus.apply(0.0d), 1.0E-4d);
        Assert.assertEquals(1000.0d, softplus.apply(1000.0d), 1.0E-4d);
        Assert.assertEquals(0.0d, softplus.derivative(-1000.0d), 1.0E-4d);
        Assert.assertEquals(0.5d, softplus.derivative(0.0d), 1.0E-4d);
        Assert.assertEquals(1.0d, softplus.derivative(1000.0d), 1.0E-4d);
    }

    /** Ops.LINEAR must be initialised. */
    @Test
    public void testLinear() {
        Assert.assertNotNull(Ops.LINEAR);
    }

    /**
     * Applies {@code op} to a wide random input and checks that any non-NaN result lies
     * within {@code [op.minValue(), op.maxValue()]}.
     */
    private void testApply(Op op) {
        double result = op.apply(Rand.nextGaussian() * 1000.0d);
        if (Double.isNaN(result)) {
            // The input is not restricted to the op's domain, so NaN is tolerated here.
            return;
        }
        Assert.assertTrue(result <= op.maxValue());
        Assert.assertTrue(result >= op.minValue());
    }

    /**
     * Checks that the bulk-application paths produce the same results as each other:
     * <ul>
     *   <li>{@code op.applyTo(Vector)};</li>
     *   <li>{@code Vector.applyOp(op)};</li>
     *   <li>the op's length-10 transform;</li>
     *   <li>both {@code double[]} overloads of {@code applyTo}.</li>
     * </ul>
     *
     * <p>Skipped for stochastic ops, whose repeated applications need not agree. Also
     * skipped for domain-bounded ops, because the Gaussian inputs used here are not
     * clamped to the domain.
     */
    private void testVectorApply(Op op) {
        if (op.isStochastic() || op.isDomainBounded()) {
            return;
        }
        int length = 10;
        Vector source = Vector.createLength(length);
        Vectorz.fillGaussian(source);

        Vector viaOpApplyTo = source.clone();
        Vector viaVectorApplyOp = source.clone();
        op.applyTo(viaOpApplyTo);
        viaVectorApplyOp.applyOp(op);
        Assert.assertEquals(viaOpApplyTo, viaVectorApplyOp);

        Vector viaTransform = Vector.createLength(length);
        op.getTransform(length).transform(source, viaTransform);
        Assert.assertEquals(viaOpApplyTo, viaTransform);

        double[] wholeArray = new double[length];
        double[] arrayRange = new double[length];
        source.getElements(wholeArray, 0);
        source.getElements(arrayRange, 0);
        op.applyTo(wholeArray);
        op.applyTo(arrayRange, 0, arrayRange.length);
        Assert.assertTrue(viaVectorApplyOp.equalsArray(arrayRange));
        Assert.assertTrue(viaOpApplyTo.equalsArray(wholeArray));
    }

    /** Runs the generic transform tests on the op's 1- and 10-dimensional transforms (non-stochastic ops only). */
    private void testTransforms(Op op) {
        if (op.isStochastic()) {
            return;
        }
        TestTransformz.doTransformTests(op.getTransform(1));
        TestTransformz.doTransformTests(op.getTransform(10));
    }

    /**
     * For bounded ops, checks the following:
     * <ul>
     *   <li>{@code minValue <= averageValue <= maxValue};</li>
     *   <li>100 random in-domain samples all map into {@code [minValue, maxValue]}.</li>
     * </ul>
     */
    private void testBounds(Op op) {
        if (!op.isBounded()) {
            return;
        }
        double min = op.minValue();
        double max = op.maxValue();
        double average = op.averageValue();
        Assert.assertTrue(min <= average);
        Assert.assertTrue(average <= max);
        for (int sample = 0; sample < 100; sample++) {
            double x = Rand.nextGaussian() * 1000.0d;
            if (op.isDomainBounded()) {
                x = Maths.bound(x, op.minDomain(), op.maxDomain());
            }
            double y = op.apply(x);
            Assert.assertTrue(y <= max);
            Assert.assertTrue(y >= min);
        }
    }

    /**
     * Checks the derivative contract of {@code op}.
     *
     * <p>If the op reports no derivative, both {@code derivative} and
     * {@code derivativeForOutput} must throw. Otherwise both must be callable
     * ({@code derivativeForOutput} only when advertised). For deterministic ops,
     * {@code getDerivativeOp().apply(x)} must also match {@code derivative(x)}.
     */
    private void testDerivative(Op op) {
        double x = Rand.nextGaussian() * 100.0d;
        if (x < op.minDomain()) {
            // Negating guarantees x >= minDomain() only when minDomain() <= 0. maxDomain()
            // is not enforced, so NaN results are possible and are handled below.
            x = -x;
        }
        double y = op.apply(x);

        if (!op.hasDerivative()) {
            // Each "did it throw?" flag is checked outside its try block. Assert.fail
            // throws AssertionError, which catch (Throwable) would otherwise swallow,
            // making these checks impossible to fail.
            boolean derivativeThrew = false;
            try {
                op.derivative(x);
            } catch (Throwable expected) {
                derivativeThrew = true;
            }
            if (!derivativeThrew) {
                Assert.fail("Derivative did not throw exception!");
            }

            boolean derivativeForOutputThrew = false;
            try {
                op.derivativeForOutput(x);
            } catch (Throwable expected) {
                derivativeForOutputThrew = true;
            }
            if (!derivativeForOutputThrew) {
                Assert.fail("Derivative for output did not throw exception!");
            }
            return;
        }

        op.derivative(x);
        if (op.hasDerivativeForOutput()) {
            op.derivativeForOutput(y);
        }
        Op derivativeOp = op.getDerivativeOp();
        if (Double.isNaN(x) || Double.isNaN(y) || op.isStochastic()) {
            return;
        }
        Assert.assertEquals(op.derivative(x), derivativeOp.apply(x), 1.0E-5d);
    }

    /**
     * If {@code op} has an inverse, checks that op followed by its inverse is stable after
     * one round trip. Otherwise, checks that {@code getInverse()} throws.
     */
    private void testInverse(Op op) {
        if (op.hasInverse()) {
            Op inverse = op.getInverse();
            // Round-trip once first, so that the value lies in the inverse's image. A
            // single round trip from a raw input need not reproduce it (e.g. for a
            // non-injective op), but a second round trip from there must.
            double roundTripped = inverse.apply(op.apply(Rand.nextGaussian()));
            Assert.assertEquals(roundTripped, inverse.apply(op.apply(roundTripped)), 1.0E-4d);
            return;
        }

        // Checked outside the try block so that Assert.fail's AssertionError is not swallowed.
        boolean threw = false;
        try {
            op.getInverse();
        } catch (Throwable expected) {
            threw = true;
        }
        if (!threw) {
            Assert.fail("getInverse did not throw exception!");
        }
    }

    /**
     * For ops that do not report themselves as stochastic, checks that repeated
     * application to the same input gives the exact same result.
     */
    private void testStochastic(Op op) {
        if (op.isStochastic()) {
            return;
        }
        for (int sample = 0; sample < 30; sample++) {
            double x = Rand.nextGaussian() * 20.0d;
            double expected = op.apply(x);
            for (int repeat = 0; repeat < 30; repeat++) {
                Assert.assertEquals(expected, op.apply(x), 0.0d);
            }
        }
    }

    /**
     * Checks the chain rule: the derivative of {@code outer.compose(inner)} at x must equal
     * outer'(inner(x)) * inner'(x).
     */
    private void testComposedDerivative(Op outer, Op inner) {
        Op composed = outer.compose(inner);
        if (!composed.hasDerivative()) {
            return;
        }
        double x = Rand.nextGaussian();
        double innerValue = inner.apply(x);
        double innerDerivative = inner.derivative(x);
        double outerDerivative = outer.derivative(innerValue);
        // A zero outer derivative is treated as exactly zero. This keeps an infinite inner
        // derivative from producing NaN (0 * Infinity).
        double expected = (outerDerivative == 0.0d) ? 0.0d : outerDerivative * innerDerivative;
        Assert.assertEquals(expected, composed.derivative(x), 0.001d);
    }

    /** Runs every generic property check against a single op. */
    private void doOpTest(Op op) {
        testApply(op);
        testInverse(op);
        testStochastic(op);
        testVectorApply(op);
        testTransforms(op);
        testBounds(op);
        testDerivative(op);
        TestTransformz.doITransformTests(op.getTransform(3));
    }

    /**
     * Tests the composition of {@code outer} after {@code inner}, built both ways. It runs
     * the generic checks on {@code outer.compose(inner)} and
     * {@code Composed.compose(outer, inner)}. For deterministic compositions it also checks
     * that they agree with applying {@code inner} then {@code outer}, and checks the chain
     * rule.
     */
    private void doComposeTest(Op outer, Op inner) {
        Op viaMethod = outer.compose(inner);
        doOpTest(viaMethod);
        Op viaFactory = Composed.compose(outer, inner);
        doOpTest(viaFactory);
        if (viaFactory.isStochastic()) {
            return;
        }
        AVector sequential = Vectorz.createUniformRandomVector(10);
        AVector withMethod = sequential.clone();
        AVector withFactory = sequential.clone();
        sequential.applyOp(inner);
        sequential.applyOp(outer);
        withMethod.applyOp(viaMethod);
        withFactory.applyOp(viaFactory);
        Assert.assertTrue(sequential.epsilonEquals(withMethod, 1.0E-5d));
        Assert.assertTrue(sequential.epsilonEquals(withFactory, 1.0E-5d));
        testComposedDerivative(outer, inner);
    }

    /** Runs the generic property checks over a representative set of ops and compositions. */
    @Test
    public void genericTests() {
        // Simple parameterised ops.
        doOpTest(Constant.create(5.0d));
        doOpTest(Linear.create(0.5d, 3.0d));
        doOpTest(Identity.INSTANCE);
        doOpTest(Offset.create(1.3d));
        doOpTest(Clamp.ZERO_TO_ONE);

        // Predefined ops.
        doOpTest(Ops.ABS);
        doOpTest(Ops.SIGNUM);
        doOpTest(Ops.CEIL);
        doOpTest(Ops.FLOOR);
        doOpTest(Ops.RINT);
        doOpTest(Ops.LINEAR);
        doOpTest(Ops.LOGISTIC);
        doOpTest(Ops.SCALED_LOGISTIC);
        doOpTest(Ops.STOCHASTIC_BINARY);
        doOpTest(Ops.STOCHASTIC_LOGISTIC);
        doOpTest(Ops.SOFTPLUS);
        doOpTest(Ops.RECTIFIER);
        doOpTest(Ops.RBF_NORMAL);
        doOpTest(Ops.SQRT);
        doOpTest(Ops.CBRT);
        doOpTest(Ops.SQUARE);
        doOpTest(Ops.TO_DEGREES);
        doOpTest(Ops.TO_RADIANS);
        doOpTest(Ops.EXP);
        doOpTest(Ops.LOG);
        doOpTest(Ops.LOG10);
        doOpTest(Ops.TANH);
        doOpTest(Ops.COSH);
        doOpTest(Ops.SINH);
        doOpTest(Ops.SIN);
        doOpTest(Ops.COS);
        doOpTest(Ops.TAN);
        doOpTest(Ops.ACOS);
        doOpTest(Ops.ASIN);
        doOpTest(Ops.ATAN);

        // Powers and polynomials.
        doOpTest(Power.create(0.5d));
        doOpTest(Power.create(1.0d));
        doOpTest(Power.create(2.0d));
        doOpTest(Power.create(3.2d));
        doOpTest(Power.create(-0.5d));
        doOpTest(Power.create(0.0d));
        doOpTest(Quadratic.create(2.0d, 3.0d, 4.0d));
        doOpTest(Quadratic.create(0.0d, 3.0d, 4.0d));

        // Derived ops.
        doOpTest(Ops.LINEAR.product(Quadratic.create(0.0d, 3.0d, 4.0d)));
        doOpTest(Ops.LINEAR.divide(Quadratic.create(0.0d, 3.0d, 4.0d)));

        // Compositions.
        doComposeTest(Linear.create(0.31d, 0.12d), Linear.create(-100.0d, 11.0d));
        doComposeTest(StochasticBinary.INSTANCE, GaussianNoise.create(2.0d));
        doComposeTest(Logistic.INSTANCE, Linear.create(10.0d, -0.2d));
    }

    static {
        // Collect the value of every static Op-typed field declared on Ops.
        for (Field field : Ops.class.getDeclaredFields()) {
            boolean isStatic = Modifier.isStatic(field.getModifiers());
            if (isStatic && Op.class.isAssignableFrom(field.getType())) {
                try {
                    ALL_OPS.add((Op) field.get(null));
                } catch (Throwable th) {
                    throw new VectorzException("Problem analysing Ops", th);
                }
            }
        }
        ALL_OPS.add(Constant.ONE);
    }
}
