package org.apache.mahout.math.solver;

import org.apache.mahout.math.DenseMatrix;
import org.apache.mahout.math.DenseVector;
import org.apache.mahout.math.MahoutTestCase;
import org.apache.mahout.math.Matrix;
import org.apache.mahout.math.Vector;
import org.junit.Test;

/**
 * Tests for {@link ConjugateGradientSolver}: a plain solve, an unpreconditioned versus
 * Jacobi-preconditioned solve of a harder system, and early termination by tolerance or by
 * iteration limit.
 */
public class TestConjugateGradientSolver extends MahoutTestCase {

    /** Largest acceptable norm of {@code A*x - b} for a solve that is expected to converge. */
    private static final double CONVERGED_RESIDUAL = 1.0e-6;

    /** Largest acceptable residual norm reported by the solver itself after convergence. */
    private static final double REPORTED_RESIDUAL = 1.0e-9;

    @Test
    public void testConjugateGradientSolver() {
        Matrix a = getA();
        Vector b = getB();
        ConjugateGradientSolver solver = new ConjugateGradientSolver();

        Vector x = solver.solve(a, b);

        assertEquals(0.0, residualNorm(a, x, b), CONVERGED_RESIDUAL);
        assertEquals(0.0, solver.getResidualNorm(), REPORTED_RESIDUAL);
        assertEquals(10, solver.getIterations());
    }

    @Test
    public void testConditionedConjugateGradientSolver() {
        Matrix a = getIllConditionedMatrix();
        Vector b = getB();
        Preconditioner jacobi = new JacobiConditioner(a);
        ConjugateGradientSolver solver = new ConjugateGradientSolver();

        // Without a preconditioner (explicitly typed null keeps overload resolution unambiguous).
        Vector unconditioned = solver.solve(a, b, (Preconditioner) null, 100, 1.0e-9);
        assertEquals(0.0, residualNorm(a, unconditioned, b), CONVERGED_RESIDUAL);
        assertEquals(0.0, solver.getResidualNorm(), REPORTED_RESIDUAL);
        assertEquals(16, solver.getIterations());

        // With Jacobi preconditioning the same accuracy is expected in one iteration fewer.
        Vector conditioned = solver.solve(a, b, jacobi, 100, 1.0e-9);
        assertEquals(0.0, residualNorm(a, conditioned, b), CONVERGED_RESIDUAL);
        assertEquals(0.0, solver.getResidualNorm(), REPORTED_RESIDUAL);
        assertEquals(15, solver.getIterations());
    }

    @Test
    public void testEarlyStop() {
        Matrix a = getA();
        Vector b = getB();
        ConjugateGradientSolver solver = new ConjugateGradientSolver();

        // Loose error bound (0.1) with an iteration budget of 10: the solver is expected to stop
        // after 7 iterations with a residual that is small but not fully converged.
        double looseToleranceResidual =
            residualNorm(a, solver.solve(a, b, (Preconditioner) null, 10, 0.1), b);
        assertTrue(looseToleranceResidual > CONVERGED_RESIDUAL);
        assertEquals(0.0, looseToleranceResidual, 0.1);
        assertEquals(7, solver.getIterations());

        // Tight error bound but only 7 iterations allowed: the iteration cap ends the solve early.
        double iterationCappedResidual =
            residualNorm(a, solver.solve(a, b, (Preconditioner) null, 7, 1.0e-9), b);
        assertTrue(iterationCappedResidual > CONVERGED_RESIDUAL);
        assertEquals(0.0, iterationCappedResidual, 0.1);
        assertEquals(7, solver.getIterations());
    }

    /**
     * Returns the Euclidean norm of {@code a * x - b}, which measures how far {@code x} is from
     * solving {@code a * x = b}.
     */
    private static double residualNorm(Matrix a, Vector x, Vector b) {
        return Math.sqrt(a.times(x).getDistanceSquared(b));
    }

    /** Returns the 10x10 system matrix used by the well-conditioned tests (one column per line). */
    private static Matrix getA() {
        return reshape(new double[] {
            11.7155649822794d, -0.7125253363083646d, 4.647361396186018d, 1.6020939468348456d, -4.6789817799137134d, -0.814041676343497d, -4.5995617505618345d, -1.174907004277534d, -1.6747995811678336d, 3.1922255171058342d,
            -0.7125253363083646d, 12.340057968399487d, -2.6498099427000645d, 0.5264507222630669d, 0.3783428369189767d, -2.117018615918881d, 2.369513425219053d, 3.8182131490333013d, 6.528594229827035d, 2.8564814419366353d,
            4.647361396186018d, -2.6498099427000645d, 16.13179339216685d, -0.0409475448061225d, 1.4805687075608227d, -2.995807648462895d, -2.5288893025027264d, -0.9614557539842487d, -2.2974738351519077d, -1.5516184284572598d,
            1.6020939468348456d, 0.5264507222630669d, -0.0409475448061225d, 4.194680212269448d, -2.5210038046912198d, 0.6634899962909317d, 0.4036187419205338d, -0.2829211393003727d, -0.2283091172980954d, 1.1253516563552464d,
            -4.6789817799137134d, 0.3783428369189767d, 1.4805687075608227d, -2.5210038046912198d, 19.430736186273343d, -2.5200132222091787d, 2.374851197144451d, 11.642659844330552d, -0.1508136510863874d, 4.347134388806351d,
            -0.814041676343497d, -2.117018615918881d, -2.995807648462895d, 0.6634899962909317d, -2.5200132222091787d, 7.671233441970075d, -3.868777362950285d, -3.045341871159153d, -0.1155580876143619d, -2.402545946742212d,
            -4.5995617505618345d, 2.369513425219053d, -2.5288893025027264d, 0.4036187419205338d, 2.374851197144451d, -3.868777362950285d, 10.468166605747008d, 1.652718086617123d, 2.9341795819365384d, -2.17081763727631d,
            -1.174907004277534d, 3.8182131490333013d, -0.9614557539842487d, -0.2829211393003727d, 11.642659844330552d, -3.045341871159153d, 1.652718086617123d, 16.005061693417623d, 1.1689747208793086d, 1.666509094595487d,
            -1.6747995811678336d, 6.528594229827035d, -2.2974738351519077d, -0.2283091172980954d, -0.1508136510863874d, -0.1155580876143619d, 2.9341795819365384d, 1.1689747208793086d, 6.479432975163748d, -1.9197339981871877d,
            3.1922255171058342d, 2.8564814419366353d, -1.5516184284572598d, 1.1253516563552464d, 4.347134388806351d, -2.402545946742212d, -2.17081763727631d, 1.666509094595487d, -1.9197339981871877d, 18.91490213563446d
        }, 10, 10);
    }

    /** Returns the right-hand side vector shared by every test in this class. */
    private static Vector getB() {
        return new DenseVector(new double[] {
            -0.552252d, 0.03843d, 0.058392d, -1.234496d, 1.240369d,
            0.373649d, 0.505113d, 0.503723d, 1.21534d, -0.391908d
        });
    }

    /**
     * Returns the 10x10 matrix used by the preconditioning test (one column per line). The values
     * are reproduced exactly as originally recorded. A few mirrored entries differ in the last digit,
     * so the matrix is only approximately symmetric.
     */
    private static Matrix getIllConditionedMatrix() {
        return reshape(new double[] {
            0.00695278043678842d, 0.09911830022078683d, 0.01309584636255063d, 0.00652917453032394d, 0.04337631487735064d, 0.14232165273321387d, 0.05808722912361313d, -0.06591965049732287d, 0.06055771542862332d, 0.00577423310349649d,
            0.09911830022078683d, 1.5007140241806143d, 0.14988743575884242d, 0.07195514527480981d, 0.6374736234175272d, 1.3071181902041469d, 0.8215160938511595d, -0.7261612552458794d, 1.0349013600202295d, 0.12800239664439328d,
            0.01309584636255063d, 0.14988743575884242d, 0.04068462583124965d, 0.02147022047006482d, 0.0738811358014665d, 0.58070223915076d, 0.11280336266257514d, -0.21690068430020618d, 0.04065087561300068d, -0.00876895259593769d,
            0.00652917453032394d, 0.07195514527480981d, 0.02147022047006482d, 0.01140105250542524d, 0.03624164348693958d, 0.31291554581393255d, 0.05648457235205666d, -0.1150758301607778d, 0.01475756130709823d, -0.00584453679519805d,
            0.04337631487735064d, 0.6374736234175272d, 0.07388113580146649d, 0.03624164348693959d, 0.2749154320076057d, 0.7341054316874812d, 0.36120630002843257d, -0.36583546331208316d, 0.41472509341940017d, 0.0458145875825548d,
            0.14232165273321387d, 1.3071181902041467d, 0.58070223915076d, 0.31291554581393255d, 0.7341054316874812d, 9.02536073121807d, 1.254263855828831d, -3.1618633512559464d, -0.19740140818905436d, -0.26613760880058035d,
            0.05808722912361314d, 0.8215160938511595d, 0.11280336266257514d, 0.05648457235205667d, 0.36120630002843257d, 1.2542638558288313d, 0.4866105845160682d, -0.570305113365622d, 0.491512804648181d, 0.04428280690189127d,
            -0.06591965049732286d, -0.7261612552458794d, -0.21690068430020618d, -0.11507583016077781d, -0.36583546331208316d, -3.1618633512559464d, -0.570305113365622d, 1.1627081503807895d, -0.14837898963724327d, 0.05917203395002889d,
            0.06055771542862331d, 1.0349013600202293d, 0.04065087561300068d, 0.01475756130709823d, 0.4147250934194002d, -0.19740140818905436d, 0.49151280464818103d, -0.14837898963724327d, 0.8669382068204972d, 0.1408968875257034d,
            0.00577423310349649d, 0.12800239664439328d, -0.00876895259593769d, -0.00584453679519805d, 0.0458145875825548d, -0.26613760880058035d, 0.04428280690189126d, 0.05917203395002889d, 0.1408968875257034d, 0.02901858439788401d
        }, 10, 10);
    }

    /**
     * Builds a dense {@code rows x columns} matrix from values given in column-major order: the
     * first {@code rows} values fill column 0, the next {@code rows} fill column 1, and so on.
     *
     * @param values matrix entries in column-major order; must contain exactly
     *     {@code rows * columns} elements
     * @param rows number of rows
     * @param columns number of columns
     * @return the populated matrix
     * @throws IllegalArgumentException if {@code values.length != rows * columns}
     */
    private static Matrix reshape(double[] values, int rows, int columns) {
        // Fail fast with a clear message. Otherwise a short array would silently leave zeros and a
        // long one would fail inside DenseMatrix.set.
        if (values.length != rows * columns) {
            throw new IllegalArgumentException("Expected " + (rows * columns) + " values for a "
                + rows + "x" + columns + " matrix but got " + values.length);
        }
        Matrix matrix = new DenseMatrix(rows, columns);
        int next = 0;
        for (int column = 0; column < columns; column++) {
            for (int row = 0; row < rows; row++) {
                matrix.set(row, column, values[next++]);
            }
        }
        return matrix;
    }
}
