package opennlp.tools.ml.maxent.quasinewton;

import opennlp.tools.ml.maxent.quasinewton.LineSearch;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;

/* loaded from: input_file:opennlp/tools/ml/maxent/quasinewton/LineSearchTest.class */
/**
 * Exercises {@code LineSearch.doLineSearch} on two one-dimensional quadratic
 * objectives, checking the step size reported by the search result for
 * downhill directions, uphill directions, and a start at the minimum.
 */
public class LineSearchTest {
    /** Lower bound for a "sane" step size and delta used when comparing against zero. */
    private static final double TOLERANCE = 0.01d;

    /**
     * Objective {@code f(x) = (x - 2)^2 + 4}; its minimum is at {@code x = 2}.
     */
    public class QuadraticFunction1 implements Function {
        /** Returns {@code (x[0] - 2)^2 + 4}. */
        public double valueAt(double[] x) {
            double shifted = x[0] - 2.0d;
            return StrictMath.pow(shifted, 2.0d) + 4.0d;
        }

        /** Returns the single-element gradient {@code 2 * (x[0] - 2)}. */
        public double[] gradientAt(double[] x) {
            double slope = 2.0d * (x[0] - 2.0d);
            return new double[]{slope};
        }

        /** The function has exactly one variable. */
        public int getDimension() {
            return 1;
        }
    }

    /**
     * Objective {@code f(x) = x^2}; its minimum is at {@code x = 0}.
     */
    public class QuadraticFunction2 implements Function {
        /** Returns {@code x[0]^2}. */
        public double valueAt(double[] x) {
            return StrictMath.pow(x[0], 2.0d);
        }

        /** Returns the single-element gradient {@code 2 * x[0]}. */
        public double[] gradientAt(double[] x) {
            double slope = 2.0d * x[0];
            return new double[]{slope};
        }

        /** The function has exactly one variable. */
        public int getDimension() {
            return 1;
        }
    }

    /**
     * Runs one line search on {@code objective} from the one-dimensional point
     * {@code start} along the one-dimensional {@code direction}.
     *
     * <p>The initial result object is built from the objective's value and
     * gradient at the start point; {@code 1.0} is passed as the last argument
     * of {@code doLineSearch} (presumably the initial step size - confirm
     * against {@code LineSearch}).
     *
     * @return the step size read back from the result object after the search
     */
    private static double searchStepSize(Function objective, double start, double direction) {
        double[] point = {start};
        double valueAtStart = objective.valueAt(point);
        double[] gradientAtStart = objective.gradientAt(point);
        LineSearch.LineSearchResult result =
                LineSearch.LineSearchResult.getInitialObject(valueAtStart, gradientAtStart, point);
        LineSearch.doLineSearch(objective, new double[]{direction}, result, 1.0d);
        return result.getStepSize();
    }

    /** True when {@code stepSize} lies in the half-open range {@code (TOLERANCE, 1.0]}. */
    private static boolean isSaneStepSize(double stepSize) {
        return stepSize > TOLERANCE && stepSize <= 1.0d;
    }

    /** Asserts that the search produced a usable step size. */
    private static void assertStepAccepted(double stepSize) {
        Assertions.assertTrue(isSaneStepSize(stepSize));
    }

    /** Asserts that the search produced no usable step: out of the sane range and zero within TOLERANCE. */
    private static void assertStepRejected(double stepSize) {
        Assertions.assertFalse(isSaneStepSize(stepSize));
        Assertions.assertEquals(0.0d, stepSize, TOLERANCE);
    }

    /** From x = 0 on (x - 2)^2 + 4 the gradient is negative, so direction +1 heads downhill. */
    @Test
    void testLineSearchDeterminesSaneStepLength1() {
        assertStepAccepted(searchStepSize(new QuadraticFunction1(), 0.0d, 1.0d));
    }

    /** From x = -2 on x^2 the gradient is negative, so direction +1 heads downhill. */
    @Test
    void testLineSearchDeterminesSaneStepLength2() {
        assertStepAccepted(searchStepSize(new QuadraticFunction2(), -2.0d, 1.0d));
    }

    /** From x = 0 on (x - 2)^2 + 4, direction -1 moves away from the minimum. */
    @Test
    void testLineSearchFailsWithWrongDirection1() {
        assertStepRejected(searchStepSize(new QuadraticFunction1(), 0.0d, -1.0d));
    }

    /** From x = -2 on x^2, direction -1 moves away from the minimum. */
    @Test
    void testLineSearchFailsWithWrongDirection2() {
        assertStepRejected(searchStepSize(new QuadraticFunction2(), -2.0d, -1.0d));
    }

    /** From x = 4 on (x - 2)^2 + 4, direction +1 moves away from the minimum. */
    @Test
    void testLineSearchFailsWithWrongDirection3() {
        assertStepRejected(searchStepSize(new QuadraticFunction1(), 4.0d, 1.0d));
    }

    /** From x = 2 on x^2, direction +1 moves away from the minimum. */
    @Test
    void testLineSearchFailsWithWrongDirection4() {
        assertStepRejected(searchStepSize(new QuadraticFunction2(), 2.0d, 1.0d));
    }

    /** Starting at the minimum of x^2 (zero gradient) with direction -1. */
    @Test
    void testLineSearchFailsAtMinimum1() {
        assertStepRejected(searchStepSize(new QuadraticFunction2(), 0.0d, -1.0d));
    }

    /** Starting at the minimum of x^2 (zero gradient) with direction +1. */
    @Test
    void testLineSearchFailsAtMinimum2() {
        assertStepRejected(searchStepSize(new QuadraticFunction2(), 0.0d, 1.0d));
    }
}
