package org.apache.flink.ml.common.linalg;

import java.util.TreeMap;
import org.junit.Assert;
import org.junit.Test;

/**
 * Unit tests for {@link SparseVector}.
 *
 * <p>All assertions use JUnit's {@code (expected, actual)} argument order so that failure
 * messages report the expected and actual values the right way round.
 */
public class SparseVectorTest {
    /** Absolute tolerance for comparing computed floating-point results. */
    private static final double TOL = 1.0E-6d;

    /** Size-8 fixture with value 2.0 at indices 1, 3, 5, 7. */
    private final SparseVector v1 = new SparseVector(8, new int[]{1, 3, 5, 7}, new double[]{2.0d, 2.0d, 2.0d, 2.0d});

    /** Size-8 fixture with value 1.0 at indices 3, 4, 5. */
    private final SparseVector v2 = new SparseVector(8, new int[]{3, 4, 5}, new double[]{1.0d, 1.0d, 1.0d});

    /** Building from a map yields indices (and matching values) in ascending index order. */
    @Test
    public void testConstructor() throws Exception {
        int[] indices = {3, 7, 2, 1};
        double[] values = {3.0d, 7.0d, 2.0d, 1.0d};
        TreeMap<Integer, Double> kv = new TreeMap<>();
        for (int i = 0; i < indices.length; i++) {
            kv.put(indices[i], values[i]);
        }
        SparseVector vector = new SparseVector(8, kv);
        Assert.assertArrayEquals(new int[]{1, 2, 3, 7}, vector.getIndices());
        Assert.assertArrayEquals(new double[]{1.0d, 2.0d, 3.0d, 7.0d}, vector.getValues(), TOL);
    }

    @Test
    public void testSize() throws Exception {
        Assert.assertEquals(8L, this.v1.size());
    }

    /** {@code set} overwrites existing entries and creates missing ones; works on a clone. */
    @Test
    public void testSet() throws Exception {
        SparseVector clone = this.v1.clone();
        clone.set(2, 2.0d);
        clone.set(3, 3.0d);
        Assert.assertEquals(2.0d, clone.get(2), TOL);
        Assert.assertEquals(3.0d, clone.get(3), TOL);
    }

    /** {@code add} accumulates onto existing entries (2 + 3 = 5) and creates missing ones. */
    @Test
    public void testAdd() throws Exception {
        SparseVector clone = this.v1.clone();
        clone.add(2, 2.0d);
        clone.add(3, 3.0d);
        Assert.assertEquals(2.0d, clone.get(2), TOL);
        Assert.assertEquals(5.0d, clone.get(3), TOL);
    }

    /** {@code prefix} inserts the value at index 0 and shifts every existing index by one. */
    @Test
    public void testPrefix() throws Exception {
        SparseVector prefix = this.v1.prefix(0.2d);
        Assert.assertArrayEquals(new int[]{0, 2, 4, 6, 8}, prefix.getIndices());
        Assert.assertArrayEquals(new double[]{0.2d, 2.0d, 2.0d, 2.0d, 2.0d}, prefix.getValues(), 0.0d);
    }

    /** {@code append} places the new value at index {@code size()} of the original vector. */
    @Test
    public void testAppend() throws Exception {
        SparseVector append = this.v1.append(0.2d);
        Assert.assertArrayEquals(new int[]{1, 3, 5, 7, 8}, append.getIndices());
        Assert.assertArrayEquals(new double[]{2.0d, 2.0d, 2.0d, 2.0d, 0.2d}, append.getValues(), 0.0d);
    }

    /**
     * The constructor sorts unsorted indices; per the assertions below, it also reorders the
     * caller-supplied arrays in place.
     */
    @Test
    public void testSortIndices() throws Exception {
        int[] indices = {7, 5, 3, 1};
        double[] values = {7.0d, 5.0d, 3.0d, 1.0d};
        // Use a local so the shared fixture v1 is not reassigned.
        SparseVector vector = new SparseVector(8, indices, values);
        Assert.assertArrayEquals(new double[]{1.0d, 3.0d, 5.0d, 7.0d}, values, 0.0d);
        Assert.assertArrayEquals(new double[]{1.0d, 3.0d, 5.0d, 7.0d}, vector.getValues(), 0.0d);
        Assert.assertArrayEquals(new int[]{1, 3, 5, 7}, indices);
        Assert.assertArrayEquals(new int[]{1, 3, 5, 7}, vector.getIndices());
    }

    @Test
    public void testNormL2Square() throws Exception {
        Assert.assertEquals(3.0d, this.v2.normL2Square(), TOL);
    }

    /** Element-wise {@code v2 - v1} over the first five indices. */
    @Test
    public void testMinus() throws Exception {
        Vector minus = this.v2.minus(this.v1);
        Assert.assertEquals(0.0d, minus.get(0), TOL);
        Assert.assertEquals(-2.0d, minus.get(1), TOL);
        Assert.assertEquals(0.0d, minus.get(2), TOL);
        Assert.assertEquals(-1.0d, minus.get(3), TOL);
        Assert.assertEquals(1.0d, minus.get(4), TOL);
    }

    /** Sparse + sparse, and dense + sparse, element-wise addition. */
    @Test
    public void testPlus() throws Exception {
        Vector plus = this.v1.plus(this.v2);
        Assert.assertEquals(0.0d, plus.get(0), TOL);
        Assert.assertEquals(2.0d, plus.get(1), TOL);
        Assert.assertEquals(0.0d, plus.get(2), TOL);
        Assert.assertEquals(3.0d, plus.get(3), TOL);
        Assert.assertArrayEquals(
            new double[]{1.0d, 1.0d, 1.0d, 2.0d, 2.0d, 2.0d, 1.0d, 1.0d},
            DenseVector.ones(8).plus(this.v2).getData(),
            TOL);
    }

    /** Only indices 3 and 5 overlap: 2*1 + 2*1 = 4. */
    @Test
    public void testDot() throws Exception {
        Assert.assertEquals(4.0d, this.v1.dot(this.v2), TOL);
    }

    /** Stored indices return their value; absent indices read as zero. */
    @Test
    public void testGet() throws Exception {
        Assert.assertEquals(2.0d, this.v1.get(5), TOL);
        Assert.assertEquals(0.0d, this.v1.get(6), TOL);
    }

    /**
     * {@code slice} produces a vector whose size equals the number of requested positions,
     * keeping only positions that map to stored entries. Repeated and absent positions are
     * also covered.
     */
    @Test
    public void testSlice() throws Exception {
        SparseVector vector = new SparseVector(8, new int[]{1, 3, 5, 7}, new double[]{2.0d, 3.0d, 4.0d, 5.0d});

        SparseVector slice = vector.slice(new int[]{5, 4, 3});
        Assert.assertEquals(3L, slice.size());
        Assert.assertArrayEquals(new int[]{0, 2}, slice.getIndices());
        Assert.assertArrayEquals(new double[]{4.0d, 3.0d}, slice.getValues(), 0.0d);

        SparseVector slice2 = vector.slice(new int[]{3, 5});
        Assert.assertArrayEquals(new int[]{0, 1}, slice2.getIndices());
        Assert.assertArrayEquals(new double[]{3.0d, 4.0d}, slice2.getValues(), 0.0d);

        // Only absent positions: the result has the requested size but no stored entries.
        SparseVector slice3 = vector.slice(new int[]{2, 4});
        Assert.assertEquals(2L, slice3.size());
        Assert.assertArrayEquals(new int[0], slice3.getIndices());
        Assert.assertArrayEquals(new double[0], slice3.getValues(), 0.0d);

        // Repeated absent positions still count toward the result size.
        SparseVector slice4 = vector.slice(new int[]{2, 2, 4, 4});
        Assert.assertEquals(4L, slice4.size());
        Assert.assertArrayEquals(new int[0], slice4.getIndices());
        Assert.assertArrayEquals(new double[0], slice4.getValues(), 0.0d);
    }

    /** With size -1, the dense length is inferred as (max index + 1). */
    @Test
    public void testToDenseVector() throws Exception {
        DenseVector denseVector = new SparseVector(-1, new int[]{1, 3, 5}, new double[]{1.0d, 3.0d, 5.0d}).toDenseVector();
        Assert.assertEquals(6L, denseVector.size());
        Assert.assertArrayEquals(new double[]{0.0d, 1.0d, 0.0d, 3.0d, 0.0d, 5.0d}, denseVector.getData(), TOL);
    }

    /** Explicitly stored zeros are dropped together with their indices. */
    @Test
    public void testRemoveZeroValues() throws Exception {
        SparseVector vector = new SparseVector(6, new int[]{1, 3, 5}, new double[]{0.0d, 3.0d, 0.0d});
        vector.removeZeroValues();
        Assert.assertArrayEquals(new int[]{3}, vector.getIndices());
        Assert.assertArrayEquals(new double[]{3.0d}, vector.getValues(), TOL);
    }

    /** Outer product v1 * v2^T: odd rows equal 2 * v2, even rows are all zero. */
    @Test
    public void testOuter() throws Exception {
        DenseMatrix outer = this.v1.outer(this.v2);
        Assert.assertEquals(8L, outer.numRows());
        Assert.assertEquals(8L, outer.numCols());
        double[] zeroRow = new double[8];
        double[] scaledV2Row = {0.0d, 0.0d, 0.0d, 2.0d, 2.0d, 2.0d, 0.0d, 0.0d};
        for (int row = 0; row < 8; row++) {
            double[] expected = (row % 2 == 1) ? scaledV2Row : zeroRow;
            Assert.assertArrayEquals("row " + row, expected, outer.getRow(row), TOL);
        }
    }

    /** The iterator visits the stored entries in ascending index order, then is exhausted. */
    @Test
    public void testIterator() throws Exception {
        VectorIterator it = this.v1.iterator();
        for (int expectedIndex : new int[]{1, 3, 5, 7}) {
            Assert.assertTrue(it.hasNext());
            Assert.assertEquals(expectedIndex, it.getIndex());
            Assert.assertEquals(2.0d, it.getValue(), 0.0d);
            it.next();
        }
        Assert.assertFalse(it.hasNext());
    }
}
