package org.apache.flink.ml.common.linalg;

import org.junit.Assert;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.ExpectedException;

/**
 * Unit tests for {@link BLAS}: asum, scal, dot, axpy, gemv and gemm, checked against
 * hand-computed values or against a naive triple-loop matrix product.
 *
 * <p>All assertions pass the expected value first and the actual value second (the JUnit
 * convention), so failure messages label the two values correctly.
 */
public class BLASTest {
    private static final double TOL = 1.0E-8d;

    /**
     * 2x3 matrix whose data is supplied column by column, i.e. [[1, 2, 3], [4, 5, 6]] (the
     * expected gemv results below depend on this layout).
     */
    private final DenseMatrix mat = new DenseMatrix(2, 3, new double[]{1.0d, 4.0d, 2.0d, 5.0d, 3.0d, 6.0d});
    /** Dense [1, 2]. */
    private final DenseVector dv1 = new DenseVector(new double[]{1.0d, 2.0d});
    /** Dense [1, 2, 3]. */
    private final DenseVector dv2 = new DenseVector(new double[]{1.0d, 2.0d, 3.0d});
    /** Sparse form of [1, 2] (both entries stored). */
    private final SparseVector spv1 = new SparseVector(2, new int[]{0, 1}, new double[]{1.0d, 2.0d});
    /** Sparse form of [1, 0, 3]. */
    private final SparseVector spv2 = new SparseVector(3, new int[]{0, 2}, new double[]{1.0d, 3.0d});

    @Rule
    public ExpectedException thrown = ExpectedException.none();

    @Test
    public void testAsum() throws Exception {
        // |1| + |2| = 3 for both the dense and the sparse representation.
        Assert.assertEquals(3.0d, BLAS.asum(this.dv1), TOL);
        Assert.assertEquals(3.0d, BLAS.asum(this.spv1), TOL);
    }

    @Test
    public void testScal() throws Exception {
        // Scale copies so the shared fixtures are left untouched.
        DenseVector denseCopy = this.dv1.clone();
        BLAS.scal(0.5d, denseCopy);
        Assert.assertArrayEquals(new double[]{0.5d, 1.0d}, denseCopy.getData(), TOL);

        SparseVector sparseCopy = this.spv1.clone();
        BLAS.scal(0.5d, sparseCopy);
        // Scaling must not change the sparsity pattern, only the stored values.
        Assert.assertArrayEquals(this.spv1.getIndices(), sparseCopy.getIndices());
        Assert.assertArrayEquals(new double[]{0.5d, 1.0d}, sparseCopy.getValues(), TOL);
    }

    @Test
    public void testDot() throws Exception {
        // [1, 2] . [1, 1] = 3
        Assert.assertEquals(3.0d, BLAS.dot(this.dv1, DenseVector.ones(2)), TOL);
    }

    @Test
    public void testAxpy() throws Exception {
        DenseVector ones = DenseVector.ones(2);

        // y = 1 * [1, 2] + [1, 1]
        BLAS.axpy(1.0d, this.dv1, ones);
        Assert.assertArrayEquals(new double[]{2.0d, 3.0d}, ones.getData(), TOL);

        // Same update with a sparse x: [1, 2] + [2, 3]
        BLAS.axpy(1.0d, this.spv1, ones);
        Assert.assertArrayEquals(new double[]{3.0d, 5.0d}, ones.getData(), TOL);

        // Raw-array overload with offsets: only the second element of y is updated.
        BLAS.axpy(1, 1.0d, new double[]{1.0d}, 0, ones.getData(), 1);
        Assert.assertArrayEquals(new double[]{3.0d, 6.0d}, ones.getData(), TOL);
    }

    /**
     * Reference matrix product computed with a straightforward triple loop.
     *
     * @param denseMatrix left operand (m x k)
     * @param denseMatrix2 right operand (k x n)
     * @return a new m x n matrix holding {@code denseMatrix * denseMatrix2}
     */
    private DenseMatrix simpleMM(DenseMatrix denseMatrix, DenseMatrix denseMatrix2) {
        DenseMatrix denseMatrix3 = new DenseMatrix(denseMatrix.numRows(), denseMatrix2.numCols());
        for (int i = 0; i < denseMatrix.numRows(); i++) {
            for (int i2 = 0; i2 < denseMatrix2.numCols(); i2++) {
                double d = 0.0d;
                for (int i3 = 0; i3 < denseMatrix.numCols(); i3++) {
                    d += denseMatrix.get(i, i3) * denseMatrix2.get(i3, i2);
                }
                denseMatrix3.set(i, i2, d);
            }
        }
        return denseMatrix3;
    }

    @Test
    public void testGemm() throws Exception {
        DenseMatrix rand = DenseMatrix.rand(3, 2);
        DenseMatrix rand2 = DenseMatrix.rand(2, 4);
        DenseMatrix rand3 = DenseMatrix.rand(3, 4);
        DenseMatrix rand4 = DenseMatrix.rand(4, 2);
        DenseMatrix rand5 = DenseMatrix.rand(4, 3);

        // Each call uses beta = 0, and each result is compared with the plain product, so the
        // reused output matrix's previous contents must not leak into the result.
        DenseMatrix zeros = DenseMatrix.zeros(3, 4);
        BLAS.gemm(1.0d, rand, false, rand2, false, 0.0d, zeros);
        Assert.assertArrayEquals(simpleMM(rand, rand2).getData(), zeros.getData(), TOL);
        BLAS.gemm(1.0d, rand, false, rand4, true, 0.0d, zeros);
        Assert.assertArrayEquals(simpleMM(rand, rand4.transpose()).getData(), zeros.getData(), TOL);

        DenseMatrix zeros2 = DenseMatrix.zeros(2, 4);
        BLAS.gemm(1.0d, rand, true, rand3, false, 0.0d, zeros2);
        Assert.assertArrayEquals(simpleMM(rand.transpose(), rand3).getData(), zeros2.getData(), TOL);
        BLAS.gemm(1.0d, rand, true, rand5, true, 0.0d, zeros2);
        Assert.assertArrayEquals(
                simpleMM(rand.transpose(), rand5.transpose()).getData(), zeros2.getData(), TOL);
    }

    @Test
    public void testGemmSizeCheck() throws Exception {
        // Inner dimensions disagree: (3x2) * (4x2).
        this.thrown.expect(IllegalArgumentException.class);
        BLAS.gemm(1.0d, DenseMatrix.rand(3, 2), false, DenseMatrix.rand(4, 2), false, 0.0d, DenseMatrix.zeros(3, 4));
    }

    @Test
    public void testGemmTransposeSizeCheck() throws Exception {
        // After transposing, inner dimensions disagree: (2x3) * (2x4).
        this.thrown.expect(IllegalArgumentException.class);
        BLAS.gemm(1.0d, DenseMatrix.rand(3, 2), true, DenseMatrix.rand(4, 2), true, 0.0d, DenseMatrix.zeros(3, 4));
    }

    @Test
    public void testGemvDense() throws Exception {
        // mat * [1, 2, 3] = [14, 32]; scaled by alpha = 2.
        DenseVector ones = DenseVector.ones(2);
        BLAS.gemv(2.0d, this.mat, false, this.dv2, 0.0d, ones);
        Assert.assertArrayEquals(new double[]{28.0d, 64.0d}, ones.getData(), TOL);

        // With beta = 1 the original [1, 1] is added on top.
        DenseVector ones2 = DenseVector.ones(2);
        BLAS.gemv(2.0d, this.mat, false, this.dv2, 1.0d, ones2);
        Assert.assertArrayEquals(new double[]{29.0d, 65.0d}, ones2.getData(), TOL);
    }

    @Test
    public void testGemvDenseTranspose() throws Exception {
        // mat^T * [1, 2] = [9, 12, 15].
        DenseVector ones = DenseVector.ones(3);
        BLAS.gemv(1.0d, this.mat, true, this.dv1, 0.0d, ones);
        Assert.assertArrayEquals(new double[]{9.0d, 12.0d, 15.0d}, ones.getData(), TOL);

        DenseVector ones2 = DenseVector.ones(3);
        BLAS.gemv(1.0d, this.mat, true, this.dv1, 1.0d, ones2);
        Assert.assertArrayEquals(new double[]{10.0d, 13.0d, 16.0d}, ones2.getData(), TOL);
    }

    @Test
    public void testGemvSparse() throws Exception {
        // mat * [1, 0, 3] = [10, 22]; scaled by alpha = 2.
        DenseVector ones = DenseVector.ones(2);
        BLAS.gemv(2.0d, this.mat, false, this.spv2, 0.0d, ones);
        Assert.assertArrayEquals(new double[]{20.0d, 44.0d}, ones.getData(), TOL);

        DenseVector ones2 = DenseVector.ones(2);
        BLAS.gemv(2.0d, this.mat, false, this.spv2, 1.0d, ones2);
        Assert.assertArrayEquals(new double[]{21.0d, 45.0d}, ones2.getData(), TOL);
    }

    @Test
    public void testGemvSparseTranspose() throws Exception {
        // mat^T * [1, 2] = [9, 12, 15]; scaled by alpha = 2.
        DenseVector ones = DenseVector.ones(3);
        BLAS.gemv(2.0d, this.mat, true, this.spv1, 0.0d, ones);
        Assert.assertArrayEquals(new double[]{18.0d, 24.0d, 30.0d}, ones.getData(), TOL);

        DenseVector ones2 = DenseVector.ones(3);
        BLAS.gemv(2.0d, this.mat, true, this.spv1, 1.0d, ones2);
        Assert.assertArrayEquals(new double[]{19.0d, 25.0d, 31.0d}, ones2.getData(), TOL);
    }

    @Test
    public void testGemvSizeCheck() throws Exception {
        // mat has 3 columns but x has length 2.
        this.thrown.expect(IllegalArgumentException.class);
        BLAS.gemv(2.0d, this.mat, false, this.dv1, 0.0d, DenseVector.ones(2));
    }

    @Test
    public void testGemvTransposeSizeCheck() throws Exception {
        // mat^T produces 3 rows but y has length 2.
        this.thrown.expect(IllegalArgumentException.class);
        BLAS.gemv(2.0d, this.mat, true, this.dv1, 0.0d, DenseVector.ones(2));
    }
}
