package org.apache.flink.ml.common.linalg;

import org.junit.Assert;
import org.junit.Test;

/**
 * Unit tests for {@link DenseMatrix}.
 *
 * <p>Flat arrays returned by {@link DenseMatrix#getData()} are compared in column-major order:
 * e.g. the 2x3 matrix {@code [[1, 3, 5], [2, 4, 6]]} is expected to expose
 * {@code [1, 2, 3, 4, 5, 6]}.
 */
public class DenseMatrixTest {
    /** Absolute tolerance used for floating-point comparisons. */
    private static final double TOL = 1.0E-6d;

    /**
     * Asserts that two 2-D arrays have identical shapes and element-wise equal values within
     * {@link #TOL}. Uses JUnit assertions (not the {@code assert} keyword) so the shape checks
     * run regardless of whether the JVM was started with {@code -ea}.
     *
     * @param expected reference values, indexed as {@code expected[row][col]}
     * @param actual values under test, indexed as {@code actual[row][col]}
     */
    private static void assertEqual2D(double[][] expected, double[][] actual) {
        Assert.assertEquals("row count", expected.length, actual.length);
        for (int row = 0; row < expected.length; row++) {
            // Check every row, not just the first, so ragged arrays are caught too.
            Assert.assertEquals("column count of row " + row, expected[row].length, actual[row].length);
            for (int col = 0; col < expected[row].length; col++) {
                Assert.assertEquals(
                        "element (" + row + ", " + col + ")", expected[row][col], actual[row][col], TOL);
            }
        }
    }

    /**
     * Naive triple-loop matrix product, used as a reference for {@link DenseMatrix} multiplication.
     *
     * @param a left operand with m rows and k columns, indexed as {@code a[row][col]}
     * @param b right operand with k rows and n columns, indexed as {@code b[row][col]}
     * @return the m x n product {@code a * b}
     */
    private static double[][] simpleMM(double[][] a, double[][] b) {
        int m = a.length;
        int k = a[0].length;
        int n = b[0].length;
        Assert.assertEquals("inner dimensions must agree", k, b.length);
        double[][] product = new double[m][n];
        for (int i = 0; i < m; i++) {
            for (int j = 0; j < n; j++) {
                double sum = 0.0d;
                for (int p = 0; p < k; p++) {
                    sum += a[i][p] * b[p][j];
                }
                product[i][j] = sum;
            }
        }
        return product;
    }

    /**
     * Naive matrix-vector product, used as a reference for {@link DenseMatrix} multiplication.
     *
     * @param a matrix with m rows and n columns, indexed as {@code a[row][col]}
     * @param x vector of length n
     * @return the length-m vector {@code a * x}
     */
    private static double[] simpleMV(double[][] a, double[] x) {
        int m = a.length;
        int n = a[0].length;
        Assert.assertEquals("matrix column count must match vector length", n, x.length);
        double[] result = new double[m];
        for (int i = 0; i < m; i++) {
            double sum = 0.0d;
            for (int j = 0; j < n; j++) {
                sum += a[i][j] * x[j];
            }
            result[i] = sum;
        }
        return result;
    }

    /** In-place addition of a matrix and of a scalar. */
    @Test
    public void testPlusEquals() throws Exception {
        DenseMatrix matA = new DenseMatrix(new double[][]{{1, 3, 5}, {2, 4, 6}});
        matA.plusEquals(DenseMatrix.ones(2, 3));
        Assert.assertArrayEquals(new double[]{2, 3, 4, 5, 6, 7}, matA.getData(), TOL);
        matA.plusEquals(1.0d);
        Assert.assertArrayEquals(new double[]{3, 4, 5, 6, 7, 8}, matA.getData(), TOL);
    }

    /** In-place subtraction of a matrix. */
    @Test
    public void testMinusEquals() throws Exception {
        DenseMatrix matA = new DenseMatrix(new double[][]{{1, 3, 5}, {2, 4, 6}});
        matA.minusEquals(DenseMatrix.ones(2, 3));
        Assert.assertArrayEquals(new double[]{0, 1, 2, 3, 4, 5}, matA.getData(), TOL);
    }

    /** Out-of-place addition; the second check also confirms the receiver was left unchanged. */
    @Test
    public void testPlus() throws Exception {
        DenseMatrix matA = new DenseMatrix(new double[][]{{1, 3, 5}, {2, 4, 6}});
        Assert.assertArrayEquals(
                new double[]{2, 3, 4, 5, 6, 7}, matA.plus(DenseMatrix.ones(2, 3)).getData(), TOL);
        Assert.assertArrayEquals(new double[]{2, 3, 4, 5, 6, 7}, matA.plus(1.0d).getData(), TOL);
    }

    /** Out-of-place subtraction of a matrix. */
    @Test
    public void testMinus() throws Exception {
        DenseMatrix matA = new DenseMatrix(new double[][]{{1, 3, 5}, {2, 4, 6}});
        Assert.assertArrayEquals(
                new double[]{0, 1, 2, 3, 4, 5}, matA.minus(DenseMatrix.ones(2, 3)).getData(), TOL);
    }

    /** Matrix-matrix product via {@code multiplies} and via {@code BLAS.gemm} with transposes. */
    @Test
    public void testMM() throws Exception {
        DenseMatrix matA = DenseMatrix.rand(4, 3);
        DenseMatrix matB = DenseMatrix.rand(3, 5);
        DenseMatrix product = matA.multiplies(matB);
        assertEqual2D(simpleMM(matA.getArrayCopy2D(), matB.getArrayCopy2D()), product.getArrayCopy2D());

        // gemm on (B^T, A^T) should produce (A * B)^T; its transpose must match the product above.
        DenseMatrix productT = new DenseMatrix(5, 4);
        BLAS.gemm(1.0d, matB, true, matA, true, 0.0d, productT);
        Assert.assertArrayEquals(product.getData(), productT.transpose().getData(), TOL);
    }

    /** Matrix-vector product for dense and sparse vectors holding the same values. */
    @Test
    public void testMV() throws Exception {
        DenseMatrix matA = DenseMatrix.rand(4, 3);
        DenseVector ones = DenseVector.ones(3);
        DenseVector denseResult = matA.multiplies(ones);
        Assert.assertArrayEquals(
                simpleMV(matA.getArrayCopy2D(), ones.getData()), denseResult.getData(), TOL);

        SparseVector sparseOnes = new SparseVector(3, new int[]{0, 1, 2}, new double[]{1, 1, 1});
        Assert.assertArrayEquals(denseResult.getData(), matA.multiplies(sparseOnes).getData(), TOL);
    }

    /** Row selection, sub-matrix extraction, and single row/column access. */
    @Test
    public void testDataSelection() throws Exception {
        DenseMatrix matA = new DenseMatrix(new double[][]{{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
        DenseMatrix selectedRows = matA.selectRows(new int[]{1});
        DenseMatrix subMatrix = matA.getSubMatrix(1, 2, 1, 2);
        Assert.assertEquals(1L, selectedRows.numRows());
        Assert.assertEquals(3L, selectedRows.numCols());
        Assert.assertEquals(1L, subMatrix.numRows());
        Assert.assertEquals(1L, subMatrix.numCols());
        Assert.assertArrayEquals(new double[]{4, 5, 6}, selectedRows.getData(), TOL);
        Assert.assertArrayEquals(new double[]{5}, subMatrix.getData(), TOL);
        Assert.assertArrayEquals(new double[]{4, 5, 6}, matA.getRow(1), 0.0d);
        Assert.assertArrayEquals(new double[]{2, 5, 8}, matA.getColumn(1), 0.0d);
    }

    /** Sum of all elements. */
    @Test
    public void testSum() throws Exception {
        Assert.assertEquals(6.0d, DenseMatrix.ones(3, 2).sum(), TOL);
    }

    /**
     * Construction from row-major data. The assertions below pin the current behavior: the
     * caller's array itself is rearranged into column-major order, and the matrix exposes that
     * same layout.
     */
    @Test
    public void testRowMajorFormat() throws Exception {
        double[] data = {1, 2, 3, 4, 5, 6};
        DenseMatrix matA = new DenseMatrix(2, 3, data, true);
        Assert.assertArrayEquals(new double[]{1, 4, 2, 5, 3, 6}, data, 0.0d);
        Assert.assertArrayEquals(new double[]{1, 4, 2, 5, 3, 6}, matA.getData(), 0.0d);

        double[] squareData = {1, 2, 3, 4};
        DenseMatrix matB = new DenseMatrix(2, 2, squareData, true);
        Assert.assertArrayEquals(new double[]{1, 3, 2, 4}, squareData, 0.0d);
        Assert.assertArrayEquals(new double[]{1, 3, 2, 4}, matB.getData(), 0.0d);
    }
}
