package org.apache.mahout.clustering.spectral.common;

import java.util.List;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.io.IntWritable;
import org.apache.hadoop.mapreduce.Mapper;
import org.apache.mahout.clustering.spectral.common.VectorMatrixMultiplicationJob;
import org.apache.mahout.common.DummyRecordWriter;
import org.apache.mahout.common.MahoutTestCase;
import org.apache.mahout.math.DenseVector;
import org.apache.mahout.math.RandomAccessSparseVector;
import org.apache.mahout.math.Vector;
import org.apache.mahout.math.VectorWritable;
import org.junit.Test;

/**
 * Unit test for {@link VectorMatrixMultiplicationJob.VectorMatrixMultiplicationMapper}.
 *
 * <p>The mapper is seeded with {@link #VECTOR} through its {@code setup(Vector)} hook. Each row of
 * {@link #MATRIX} is then mapped under its row index. The test asserts that exactly one vector is
 * emitted per row. It also asserts that element {@code j} of row {@code i} equals
 * {@code sqrt(VECTOR[i]) * sqrt(VECTOR[j]) * MATRIX[i][j]}.
 */
public class TestVectorMatrixMultiplicationJob extends MahoutTestCase {

    /** Input matrix; row {@code i} is passed to the mapper under key {@code i}. */
    private static final double[][] MATRIX = {{1, 1}, {2, 3}};

    /** Vector handed to the mapper before any rows are mapped. */
    private static final double[] VECTOR = {9, 16};

    /** Absolute tolerance used when comparing computed and expected matrix elements. */
    private static final double TOLERANCE = 1.0e-6;

    @Test
    public void testVectorMatrixMultiplicationMapper() throws Exception {
        VectorMatrixMultiplicationJob.VectorMatrixMultiplicationMapper mapper =
            new VectorMatrixMultiplicationJob.VectorMatrixMultiplicationMapper();
        Configuration conf = new Configuration();

        // Capture emitted (key, value) pairs in memory so they can be inspected afterwards.
        DummyRecordWriter<IntWritable, VectorWritable> writer =
            new DummyRecordWriter<IntWritable, VectorWritable>();
        Mapper<IntWritable, VectorWritable, IntWritable, VectorWritable>.Context context =
            DummyRecordWriter.build(mapper, conf, writer);
        mapper.setup(new DenseVector(VECTOR));

        // Feed every matrix row to the mapper, keyed by its row index.
        for (int row = 0; row < MATRIX.length; row++) {
            Vector rowVector = new RandomAccessSparseVector(MATRIX[row].length);
            rowVector.assign(MATRIX[row]);
            mapper.map(new IntWritable(row), new VectorWritable(rowVector), context);
        }

        // One output vector per input row, each element scaled by both vector entries.
        assertEquals("Number of map results", MATRIX.length, writer.getData().size());
        for (int row = 0; row < MATRIX.length; row++) {
            List<VectorWritable> values = writer.getValue(new IntWritable(row));
            assertEquals("Only one vector per key", 1, values.size());
            Vector result = values.get(0).get();
            for (int col = 0; col < MATRIX[row].length; col++) {
                double expected = Math.sqrt(VECTOR[row]) * Math.sqrt(VECTOR[col]) * MATRIX[row][col];
                assertEquals("Product matrix elements", expected, result.get(col), TOLERANCE);
            }
        }
    }
}
