package org.apache.flink.ml.common.linalg;

import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;

/**
 * Unit tests for {@link MatVecOp}, covering all four dense/sparse operand combinations for each
 * operation under test.
 *
 * <p>Fixtures (rebuilt before every test):
 * <ul>
 *   <li>{@code dv}: dense vector {@code [1, 2, 3, 4]};</li>
 *   <li>{@code sv}: sparse vector of size 4 with value 1 at indices 0 and 2, i.e.
 *       {@code [1, 0, 1, 0]} in dense form.</li>
 * </ul>
 *
 * <p>All assertions use JUnit's {@code (expected, actual, delta)} argument order so that failure
 * messages label the two values correctly.
 */
public class MatVecOpTest {
    /** Absolute tolerance used for all floating-point comparisons. */
    private static final double TOL = 1.0E-6d;

    private DenseVector dv;
    private SparseVector sv;

    @Before
    public void setUp() throws Exception {
        this.dv = new DenseVector(new double[]{1.0d, 2.0d, 3.0d, 4.0d});
        this.sv = new SparseVector(4, new int[]{0, 2}, new double[]{1.0d, 1.0d});
    }

    /**
     * Tests {@link MatVecOp#plus}. Any combination involving a dense operand yields a
     * {@link DenseVector}; sparse + sparse yields a {@link SparseVector}. The result types are
     * enforced at compile time by the declared variable types below.
     */
    @Test
    public void testPlus() throws Exception {
        DenseVector plus = MatVecOp.plus(this.dv, this.sv);
        DenseVector plus2 = MatVecOp.plus(this.sv, this.dv);
        SparseVector plus3 = MatVecOp.plus(this.sv, this.sv);
        DenseVector plus4 = MatVecOp.plus(this.dv, this.dv);

        // [1, 2, 3, 4] + [1, 0, 1, 0]; addition is commutative, so both orders agree.
        Assert.assertArrayEquals(new double[]{2.0d, 2.0d, 4.0d, 4.0d}, plus.getData(), TOL);
        Assert.assertArrayEquals(new double[]{2.0d, 2.0d, 4.0d, 4.0d}, plus2.getData(), TOL);
        // Sparse + sparse keeps the shared index set {0, 2}.
        Assert.assertArrayEquals(new int[]{0, 2}, plus3.getIndices());
        Assert.assertArrayEquals(new double[]{2.0d, 2.0d}, plus3.getValues(), TOL);
        Assert.assertArrayEquals(new double[]{2.0d, 4.0d, 6.0d, 8.0d}, plus4.getData(), TOL);
    }

    /**
     * Tests {@link MatVecOp#minus}. Any combination involving a dense operand yields a
     * {@link DenseVector}; sparse - sparse yields a {@link SparseVector}.
     */
    @Test
    public void testMinus() throws Exception {
        DenseVector minus = MatVecOp.minus(this.dv, this.sv);
        DenseVector minus2 = MatVecOp.minus(this.sv, this.dv);
        SparseVector minus3 = MatVecOp.minus(this.sv, this.sv);
        DenseVector minus4 = MatVecOp.minus(this.dv, this.dv);

        // [1, 2, 3, 4] - [1, 0, 1, 0] and its negation for the swapped operand order.
        Assert.assertArrayEquals(new double[]{0.0d, 2.0d, 2.0d, 4.0d}, minus.getData(), TOL);
        Assert.assertArrayEquals(new double[]{0.0d, -2.0d, -2.0d, -4.0d}, minus2.getData(), TOL);
        // Sparse - sparse keeps the index set {0, 2} with explicit zero values (no pruning).
        Assert.assertArrayEquals(new int[]{0, 2}, minus3.getIndices());
        Assert.assertArrayEquals(new double[]{0.0d, 0.0d}, minus3.getValues(), TOL);
        Assert.assertArrayEquals(new double[]{0.0d, 0.0d, 0.0d, 0.0d}, minus4.getData(), TOL);
    }

    /** Tests {@link MatVecOp#dot} for every operand combination. */
    @Test
    public void testDot() throws Exception {
        // 1*1 + 3*1 = 4
        Assert.assertEquals(4.0d, MatVecOp.dot(this.dv, this.sv), TOL);
        Assert.assertEquals(4.0d, MatVecOp.dot(this.sv, this.dv), TOL);
        // 1*1 + 1*1 = 2
        Assert.assertEquals(2.0d, MatVecOp.dot(this.sv, this.sv), TOL);
        // 1 + 4 + 9 + 16 = 30
        Assert.assertEquals(30.0d, MatVecOp.dot(this.dv, this.dv), TOL);
    }

    /** Tests {@link MatVecOp#sumAbsDiff} (L1 distance) for every operand combination. */
    @Test
    public void testSumAbsDiff() throws Exception {
        // |1-1| + |2-0| + |3-1| + |4-0| = 8; symmetric in its arguments.
        Assert.assertEquals(8.0d, MatVecOp.sumAbsDiff(this.dv, this.sv), TOL);
        Assert.assertEquals(8.0d, MatVecOp.sumAbsDiff(this.sv, this.dv), TOL);
        Assert.assertEquals(0.0d, MatVecOp.sumAbsDiff(this.sv, this.sv), TOL);
        Assert.assertEquals(0.0d, MatVecOp.sumAbsDiff(this.dv, this.dv), TOL);
    }

    /** Tests {@link MatVecOp#sumSquaredDiff} (squared L2 distance) for every operand combination. */
    @Test
    public void testSumSquaredDiff() throws Exception {
        // 0^2 + 2^2 + 2^2 + 4^2 = 24; symmetric in its arguments.
        Assert.assertEquals(24.0d, MatVecOp.sumSquaredDiff(this.dv, this.sv), TOL);
        Assert.assertEquals(24.0d, MatVecOp.sumSquaredDiff(this.sv, this.dv), TOL);
        Assert.assertEquals(0.0d, MatVecOp.sumSquaredDiff(this.sv, this.sv), TOL);
        Assert.assertEquals(0.0d, MatVecOp.sumSquaredDiff(this.dv, this.dv), TOL);
    }
}
