/*
 * Decompiled with CFR 0.152.
 */
package opennlp.tools.ml.maxent.quasinewton;

import opennlp.tools.ml.maxent.quasinewton.Function;
import opennlp.tools.ml.maxent.quasinewton.QNMinimizer;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;

/**
 * Runs {@link QNMinimizer} against two smooth two-dimensional objectives whose
 * minimizers can be read directly off their formulas, and checks that the point
 * and value returned by the minimizer agree with them.
 */
public class QNMinimizerTest {
    /** Absolute tolerance applied to each coordinate of the returned point. */
    private static final double POINT_TOLERANCE = 1.0E-5;

    /** Absolute tolerance applied to the objective value at the returned point. */
    private static final double VALUE_TOLERANCE = 1.0E-10;

    /** The quadratic bowl should be minimized at (1, 5), where its value is 10. */
    @Test
    void testQuadraticFunction() {
        QNMinimizerTest.assertMinimum(new QNMinimizer(), new QuadraticFunction(), 1.0, 5.0, 10.0);
    }

    /** The Rosenbrock function should be minimized at (1, 1), where its value is 0. */
    @Test
    void testRosenbrockFunction() {
        QNMinimizerTest.assertMinimum(new QNMinimizer(), new Rosenbrock(), 1.0, 1.0, 0.0);
    }

    /**
     * Minimizes {@code objective} with {@code solver} and asserts the outcome.
     *
     * @param solver          minimizer under test
     * @param objective       two-dimensional function to minimize
     * @param expectedX0      expected first coordinate of the minimizer
     * @param expectedX1      expected second coordinate of the minimizer
     * @param expectedMinimum expected objective value at the minimizer
     */
    private static void assertMinimum(QNMinimizer solver, Function objective,
                                      double expectedX0, double expectedX1, double expectedMinimum) {
        double[] argMin = solver.minimize(objective);
        // Evaluate before asserting so the value is computed even if a coordinate check fails later.
        double attained = objective.valueAt(argMin);
        Assertions.assertEquals(expectedX0, argMin[0], POINT_TOLERANCE);
        Assertions.assertEquals(expectedX1, argMin[1], POINT_TOLERANCE);
        Assertions.assertEquals(expectedMinimum, attained, VALUE_TOLERANCE);
    }

    /**
     * The function {@code f(x) = (x0 - 1)^2 + (x1 - 5)^2 + 10}.
     */
    public static class QuadraticFunction
    implements Function {
        /** @return the number of variables, always 2 */
        public int getDimension() {
            return 2;
        }

        /**
         * @param x point to evaluate at; only {@code x[0]} and {@code x[1]} are read
         * @return {@code (x0 - 1)^2 + (x1 - 5)^2 + 10}
         */
        public double valueAt(double[] x) {
            double offset0 = x[0] - 1.0;
            double offset1 = x[1] - 5.0;
            return StrictMath.pow(offset0, 2.0) + StrictMath.pow(offset1, 2.0) + 10.0;
        }

        /**
         * @param x point to evaluate at; only {@code x[0]} and {@code x[1]} are read
         * @return a new array {@code {2 (x0 - 1), 2 (x1 - 5)}}
         */
        public double[] gradientAt(double[] x) {
            double[] gradient = new double[2];
            gradient[0] = 2.0 * (x[0] - 1.0);
            gradient[1] = 2.0 * (x[1] - 5.0);
            return gradient;
        }
    }

    /**
     * The Rosenbrock function {@code f(x) = (1 - x0)^2 + 100 (x1 - x0^2)^2}.
     */
    public static class Rosenbrock
    implements Function {
        /** @return the number of variables, always 2 */
        public int getDimension() {
            return 2;
        }

        /**
         * @param x point to evaluate at; only {@code x[0]} and {@code x[1]} are read
         * @return {@code (1 - x0)^2 + 100 (x1 - x0^2)^2}
         */
        public double valueAt(double[] x) {
            double distanceFromOne = 1.0 - x[0];
            double valleyResidual = x[1] - StrictMath.pow(x[0], 2.0);
            return StrictMath.pow(distanceFromOne, 2.0) + 100.0 * StrictMath.pow(valleyResidual, 2.0);
        }

        /**
         * @param x point to evaluate at; only {@code x[0]} and {@code x[1]} are read
         * @return a new array holding the partial derivatives with respect to
         *         {@code x0} and {@code x1}, in that order
         */
        public double[] gradientAt(double[] x) {
            // The term (x1 - x0^2) appears in both partial derivatives.
            double valleyResidual = x[1] - StrictMath.pow(x[0], 2.0);
            double partialX0 = -2.0 * (1.0 - x[0]) - 400.0 * valleyResidual * x[0];
            double partialX1 = 200.0 * valleyResidual;
            return new double[]{partialX0, partialX1};
        }
    }
}

