package org.apache.mahout.math;

import org.apache.mahout.common.RandomUtils;
import org.apache.mahout.common.RandomWrapper;
import org.apache.mahout.math.Vector;
import org.apache.mahout.math.function.Functions;
import org.apache.mahout.math.jet.random.Gamma;
import org.junit.Assert;
import org.junit.Test;

/**
 * Tests for {@link SparseRowMatrix}: runs the generic {@link MatrixTest} suite through
 * {@link #matrixFactory(double[][])} and adds timing/correctness checks for {@code times}
 * with sparse, dense and diagonal right-hand operands.
 */
public final class TestSparseRowMatrix extends MatrixTest {

    /** Number of gamma draws histogrammed into each row of a random test matrix. */
    private static final int SAMPLES_PER_ROW = 1000;

    /** Maximum number of leading columns a histogram row may populate. */
    private static final int HISTOGRAM_WIDTH = 1000;

    /** Mean of the exponential distribution used to pick spot-checked rows/columns. */
    private static final double SPOT_CHECK_SCALE = 10.0;

    /** Absolute tolerance for element-wise product comparisons. */
    private static final double TOLERANCE = 1.0e-12;

    @Override
    public Matrix matrixFactory(double[][] data) {
        SparseRowMatrix matrix = new SparseRowMatrix(data.length, data[0].length);
        for (int row = 0; row < matrix.rowSize(); row++) {
            for (int col = 0; col < matrix.columnSize(); col++) {
                matrix.setQuick(row, col, data[row][col]);
            }
        }
        return matrix;
    }

    @Test(timeout = 50000)
    public void testTimesSparseEfficiency() {
        RandomWrapper random = RandomUtils.getRandom();
        Gamma gamma = new Gamma(0.1, 0.1, random);

        // build two large sparse matrices and multiply them
        SparseRowMatrix x = gammaHistogramMatrix(gamma, 1000, 2000);
        SparseRowMatrix y = gammaHistogramMatrix(gamma, 2000, 1000);

        long start = System.nanoTime();
        Matrix product = x.times(y);
        printElapsed(start);

        // spot-check entries; indices are exponentially distributed, so low indices dominate
        for (int k = 0; k < 1000; k++) {
            int row = randomIndex(random, x.rowSize());
            int column = randomIndex(random, y.columnSize());
            Assert.assertEquals(x.viewRow(row).dot(y.viewColumn(column)),
                product.get(row, column), TOLERANCE);
        }
    }

    @Test(timeout = 50000)
    public void testTimesDenseEfficiency() {
        RandomWrapper random = RandomUtils.getRandom();
        Gamma gamma = new Gamma(0.1, 0.1, random);

        SparseRowMatrix x = gammaHistogramMatrix(gamma, 1000, 2000);
        DenseMatrix y = new DenseMatrix(2000, 20);
        for (int row = 0; row < y.rowSize(); row++) {
            for (int col = 0; col < y.columnSize(); col++) {
                y.set(row, col, random.nextDouble());
            }
        }

        long start = System.nanoTime();
        Matrix product = x.times(y);
        printElapsed(start);

        // every entry of the product is checked against an explicit row . column dot product
        for (int row = 0; row < x.rowSize(); row++) {
            for (int col = 0; col < y.columnSize(); col++) {
                Assert.assertEquals(x.viewRow(row).dot(y.viewColumn(col)),
                    product.get(row, col), TOLERANCE);
            }
        }
    }

    @Test(timeout = 50000)
    public void testTimesOtherSparseEfficiency() {
        Gamma gamma = new Gamma(0.1, 0.1, RandomUtils.getRandom());
        SparseRowMatrix x = gammaHistogramMatrix(gamma, 1000, 2000);

        Vector diagonal = new DenseVector(2000).assign(Functions.random());
        DiagonalMatrix d = new DiagonalMatrix(diagonal);

        long start = System.nanoTime();
        // Matrix is not a generic type; it is iterable over MatrixSlice directly
        Matrix product = x.times(d);
        printElapsed(start);

        // (X * diag(v))[i][j] must equal X[i][j] * v[j] for every stored entry
        for (MatrixSlice slice : product) {
            for (Vector.Element element : slice.nonZeroes()) {
                int column = element.index();
                assertEquals(x.get(slice.index(), column) * diagonal.get(column),
                    element.get(), TOLERANCE);
            }
        }
    }

    @Test(timeout = 50000)
    public void testTimesCorrect() {
        // NOTE(review): return value is unused; kept in case getRandom() has seeding side
        // effects in test mode -- confirm and drop if not.
        RandomUtils.getRandom();

        Matrix sparseA = new SparseRowMatrix(100, 2000, false).assign(Functions.random());
        Matrix sparseB = new SparseRowMatrix(2000, 100, false).assign(Functions.random());
        Matrix denseA = new DenseMatrix(100, 2000).assign(sparseA);
        Matrix denseB = new DenseMatrix(2000, 100).assign(sparseB);

        // the sparse x sparse product is the reference every other combination is compared to
        Matrix reference = sparseA.times(sparseB);
        assertEquals(0.0, sumAbsDifference(denseA.times(denseB), reference), 1.0e-15);
        assertEquals(0.0, sumAbsDifference(sparseA.times(denseB), reference), 1.0e-15);
        assertEquals(0.0, sumAbsDifference(denseA.times(sparseB), reference), 1.0e-15);
    }

    /**
     * Builds a {@code rows x columns} {@link SparseRowMatrix} in which each row is a histogram
     * of {@link #SAMPLES_PER_ROW} gamma draws. Each draw is truncated to an int and clamped to
     * the last histogram bin, so only the first {@code min(columns, HISTOGRAM_WIDTH)} columns
     * can be non-zero.
     *
     * @param gamma   source of the draws; consumed row by row, in order
     * @param rows    number of rows of the result
     * @param columns number of columns of the result
     * @return the populated matrix
     */
    private static SparseRowMatrix gammaHistogramMatrix(Gamma gamma, int rows, int columns) {
        SparseRowMatrix matrix = new SparseRowMatrix(rows, columns, false);
        int width = Math.min(columns, HISTOGRAM_WIDTH);
        for (int row = 0; row < rows; row++) {
            int[] counts = new int[width];
            for (int k = 0; k < SAMPLES_PER_ROW; k++) {
                // clamp to the last valid bin so an extreme draw cannot index past the array
                counts[(int) Math.min(width - 1, gamma.nextDouble())]++;
            }
            for (int col = 0; col < width; col++) {
                if (counts[col] > 0) {
                    matrix.set(row, col, counts[col]);
                }
            }
        }
        return matrix;
    }

    /**
     * Draws an index from an exponential distribution with mean {@link #SPOT_CHECK_SCALE},
     * truncated to an int and clamped to {@code [0, limit)}.
     */
    private static int randomIndex(RandomWrapper random, int limit) {
        // nextDouble() may return exactly 0.0 (log -> -Infinity); 1 - nextDouble() is in (0, 1]
        int index = (int) (-SPOT_CHECK_SCALE * Math.log(1.0 - random.nextDouble()));
        return Math.min(index, limit - 1);
    }

    /** Prints the time elapsed since {@code startNanos}, a {@link System#nanoTime()} reading. */
    private static void printElapsed(long startNanos) {
        System.out.printf("done in %.1f ms\n", (System.nanoTime() - startNanos) * 1.0e-6);
    }

    /** Returns the sum of absolute element-wise differences {@code sum |a - b|}. */
    private static double sumAbsDifference(Matrix a, Matrix b) {
        return a.minus(b).aggregate(Functions.PLUS, Functions.ABS);
    }
}
