package org.apache.mahout.math.hadoop;

import com.google.common.base.Function;
import com.google.common.collect.Iterators;
import com.google.common.collect.Maps;
import java.io.IOException;
import java.util.HashMap;
import java.util.Iterator;
import java.util.Map;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.FileStatus;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
import org.apache.mahout.clustering.ClusteringTestUtils;
import org.apache.mahout.common.HadoopUtil;
import org.apache.mahout.common.MahoutTestCase;
import org.apache.mahout.common.iterator.sequencefile.PathFilters;
import org.apache.mahout.math.DenseVector;
import org.apache.mahout.math.Matrix;
import org.apache.mahout.math.MatrixSlice;
import org.apache.mahout.math.RandomAccessSparseVector;
import org.apache.mahout.math.Vector;
import org.apache.mahout.math.VectorIterable;
import org.apache.mahout.math.VectorWritable;
import org.apache.mahout.math.decomposer.SolverTest;
import org.apache.mahout.math.function.Functions;
import org.apache.mahout.math.hadoop.TimesSquaredJob;
import org.junit.Test;

/* loaded from: input_file:org/apache/mahout/math/hadoop/TestDistributedRowMatrix.class */
/**
 * Tests for {@link DistributedRowMatrix} and for the job-configuration builders of
 * {@link MatrixMultiplicationJob}, {@link TransposeJob} and {@link TimesSquaredJob}.
 *
 * <p>Several tests build an in-memory matrix through {@link SolverTest} and, separately, a
 * distributed matrix through {@link #randomDistributedMatrix(int, int, int, int, double, boolean)}
 * using the same arguments, then compare results computed from both.
 * NOTE(review): this presumably relies on the generator being deterministically seeded by the
 * test base class, so that both calls yield the same matrix - confirm against
 * {@link MahoutTestCase} / {@link SolverTest}.
 */
public final class TestDistributedRowMatrix extends MahoutTestCase {
    public static final String TEST_PROPERTY_KEY = "test.property.key";
    public static final String TEST_PROPERTY_VALUE = "test.property.value";

    /**
     * Asserts that two row collections describe the same matrix.
     *
     * <p>Rows are matched by index, not by iteration order. Each iterable is drained completely
     * and independently, and every index seen on either side is checked, so neither a shorter
     * iterable nor rows that exist on only one side can hide a mismatch.
     *
     * @param vectorIterable  the first (expected) rows
     * @param vectorIterable2 the second (actual) rows
     * @param d               exclusive upper bound on the squared distance between two rows
     *                        that are present on both sides
     */
    private static void assertEquals(VectorIterable vectorIterable, VectorIterable vectorIterable2, double d) {
        Map<Integer, Vector> expectedRows = collectRows(vectorIterable);
        Map<Integer, Vector> actualRows = collectRows(vectorIterable2);
        for (Map.Entry<Integer, Vector> expected : expectedRows.entrySet()) {
            Integer index = expected.getKey();
            assertRowsMatch(index, expected.getValue(), actualRows.get(index), d);
        }
        // Rows that only the second operand produced must not escape the comparison.
        for (Map.Entry<Integer, Vector> actual : actualRows.entrySet()) {
            if (!expectedRows.containsKey(actual.getKey())) {
                assertRowsMatch(actual.getKey(), null, actual.getValue(), d);
            }
        }
    }

    /** Drains {@code rows.iterateAll()} into a map from slice index to slice vector. */
    private static Map<Integer, Vector> collectRows(VectorIterable rows) {
        Map<Integer, Vector> rowsByIndex = Maps.newHashMap();
        Iterator<?> slices = rows.iterateAll();
        while (slices.hasNext()) {
            MatrixSlice slice = (MatrixSlice) slices.next();
            rowsByIndex.put(slice.index(), slice.vector());
        }
        return rowsByIndex;
    }

    /**
     * Compares one row from each side. A row that is absent ({@code null}) on one side is only
     * accepted when the other side is absent too or has a 2-norm of exactly zero; two present
     * rows must have a squared distance strictly below {@code tolerance}.
     */
    private static void assertRowsMatch(Integer index, Vector expected, Vector actual, double tolerance) {
        if (expected == null || actual == null) {
            assertTrue("row " + index + " is absent on one side, so the first operand's row must be absent or zero",
                expected == null || expected.norm(2.0d) == 0.0d);
            assertTrue("row " + index + " is absent on one side, so the second operand's row must be absent or zero",
                actual == null || actual.norm(2.0d) == 0.0d);
        } else {
            assertTrue("row " + index + " differs by at least the allowed squared distance " + tolerance,
                expected.getDistanceSquared(actual) < tolerance);
        }
    }

    /** Transposing a distributed matrix twice must give back the original rows. */
    @Test
    public void testTranspose() throws Exception {
        DistributedRowMatrix original = randomDistributedMatrix(10, 9, 5, 4, 1.0d, false);
        original.setConf(getConfiguration());
        DistributedRowMatrix transposed = original.transpose();
        transposed.setConf(getConfiguration());
        Path tempDir = getTestTempDirPath();
        original.setOutputTempPathString(tempDir.toString());
        // The child must be relative: Hadoop resolves a child that starts with '/' as an absolute
        // path, which would place (and delete) this directory at the filesystem root instead of
        // inside the per-test temp directory.
        Path transposeTempDir = new Path(tempDir, "tmpOutTranspose");
        transposed.setOutputTempPathString(transposeTempDir.toString());
        HadoopUtil.delete(getConfiguration(), new Path[]{transposeTempDir});
        assertEquals(original, transposed.transpose(), 1.0E-6d);
    }

    /**
     * The column means computed by the distributed job must match means accumulated row by row
     * from the in-memory matrix built with the same arguments.
     */
    @Test
    public void testMatrixColumnMeansJob() throws Exception {
        Matrix localMatrix = SolverTest.randomSequentialAccessSparseMatrix(100, 90, 50, 20, 1.0d);
        DistributedRowMatrix distributedMatrix = randomDistributedMatrix(100, 90, 50, 20, 1.0d, false);
        distributedMatrix.setConf(getConfiguration());
        Vector expectedMeans = new DenseVector(50);
        int rowCount = localMatrix.numRows();
        for (int row = 0; row < rowCount; row++) {
            expectedMeans.assign(localMatrix.viewRow(row), Functions.PLUS);
        }
        expectedMeans.assign(Functions.DIV, rowCount);
        Vector actualMeans = distributedMatrix.columnMeans("DenseVector");
        assertEquals(0.0d, expectedMeans.getDistanceSquared(actualMeans), 1.0E-6d);
    }

    /** Same check as {@link #testMatrixColumnMeansJob()} for a matrix created with zero columns. */
    @Test
    public void testNullMatrixColumnMeansJob() throws Exception {
        Matrix localMatrix = SolverTest.randomSequentialAccessSparseMatrix(100, 90, 0, 0, 1.0d);
        DistributedRowMatrix distributedMatrix = randomDistributedMatrix(100, 90, 0, 0, 1.0d, false);
        distributedMatrix.setConf(getConfiguration());
        Vector expectedMeans = new DenseVector(0);
        int rowCount = localMatrix.numRows();
        for (int row = 0; row < rowCount; row++) {
            expectedMeans.assign(localMatrix.viewRow(row), Functions.PLUS);
        }
        expectedMeans.assign(Functions.DIV, rowCount);
        Vector actualMeans = distributedMatrix.columnMeans();
        assertEquals(0.0d, expectedMeans.getDistanceSquared(actualMeans), 1.0E-6d);
    }

    /** {@code times(v)} must agree between the in-memory and the distributed matrix. */
    @Test
    public void testMatrixTimesVector() throws Exception {
        Vector ones = new RandomAccessSparseVector(50);
        ones.assign(1.0d);
        Matrix localMatrix = SolverTest.randomSequentialAccessSparseMatrix(100, 90, 50, 20, 1.0d);
        DistributedRowMatrix distributedMatrix = randomDistributedMatrix(100, 90, 50, 20, 1.0d, false);
        distributedMatrix.setConf(getConfiguration());
        Vector expected = localMatrix.times(ones);
        Vector actual = distributedMatrix.times(ones);
        assertEquals(0.0d, expected.getDistanceSquared(actual), 1.0E-6d);
    }

    /** {@code timesSquared(v)} must agree between the in-memory and the distributed matrix. */
    @Test
    public void testMatrixTimesSquaredVector() throws Exception {
        Vector ones = new RandomAccessSparseVector(50);
        ones.assign(1.0d);
        Matrix localMatrix = SolverTest.randomSequentialAccessSparseMatrix(100, 90, 50, 20, 1.0d);
        DistributedRowMatrix distributedMatrix = randomDistributedMatrix(100, 90, 50, 20, 1.0d, false);
        distributedMatrix.setConf(getConfiguration());
        Vector expected = localMatrix.timesSquared(ones);
        Vector actual = distributedMatrix.timesSquared(ones);
        assertEquals(0.0d, expected.getDistanceSquared(actual), 1.0E-9d);
    }

    /**
     * The distributed {@code times(other)} is compared against the in-memory product
     * {@code transpose(A).times(B)} of matrices built with the same arguments.
     */
    @Test
    public void testMatrixTimesMatrix() throws Exception {
        Matrix localA = SolverTest.randomSequentialAccessSparseMatrix(20, 19, 15, 5, 10.0d);
        Matrix localB = SolverTest.randomSequentialAccessSparseMatrix(20, 13, 25, 10, 5.0d);
        Matrix expectedProduct = localA.transpose().times(localB);
        DistributedRowMatrix distributedA = randomDistributedMatrix(20, 19, 15, 5, 10.0d, false, "distA");
        distributedA.setConf(getConfiguration());
        DistributedRowMatrix distributedB = randomDistributedMatrix(20, 13, 25, 10, 5.0d, false, "distB");
        distributedB.setConf(getConfiguration());
        assertEquals(expectedProduct, distributedA.times(distributedB), 1.0E-6d);
    }

    /**
     * The matrix-multiply job configuration must carry over properties of a supplied initial
     * configuration, and must not contain them when no initial configuration is supplied.
     */
    @Test
    public void testMatrixMultiplactionJobConfBuilder() throws Exception {
        Configuration initialConf = createInitialConf();
        Path baseDir = getTestTempDirPath("testpaths");
        Path aPath = new Path(baseDir, "a");
        Path bPath = new Path(baseDir, "b");
        Path outPath = new Path(baseDir, "out");
        Configuration withoutInitial = MatrixMultiplicationJob.createMatrixMultiplyJobConf(aPath, bPath, outPath, 10);
        Configuration withInitial =
            MatrixMultiplicationJob.createMatrixMultiplyJobConf(initialConf, aPath, bPath, outPath, 10);
        assertNull(withoutInitial.get(TEST_PROPERTY_KEY));
        assertEquals(TEST_PROPERTY_VALUE, withInitial.get(TEST_PROPERTY_KEY));
    }

    /** Same property-propagation check for the transpose job configuration builders. */
    @Test
    public void testTransposeJobConfBuilder() throws Exception {
        Configuration initialConf = createInitialConf();
        Path baseDir = getTestTempDirPath("testpaths");
        Path inputPath = new Path(baseDir, "input");
        Path outputPath = new Path(baseDir, "output");
        Configuration withoutInitial = TransposeJob.buildTransposeJobConf(inputPath, outputPath, 10);
        Configuration withInitial = TransposeJob.buildTransposeJobConf(initialConf, inputPath, outputPath, 10);
        assertNull(withoutInitial.get(TEST_PROPERTY_KEY));
        assertEquals(TEST_PROPERTY_VALUE, withInitial.get(TEST_PROPERTY_KEY));
    }

    /**
     * Same property-propagation check for every {@link TimesSquaredJob} configuration builder
     * overload exercised below, each with and without an initial configuration.
     */
    @Test
    public void testTimesSquaredJobConfBuilders() throws Exception {
        Configuration initialConf = createInitialConf();
        Path baseDir = getTestTempDirPath("testpaths");
        Path inputPath = new Path(baseDir, "input");
        Path outputPath = new Path(baseDir, "output");
        Vector ones = new RandomAccessSparseVector(50);
        ones.assign(1.0d);

        // createTimesSquaredJobConf(vector, in, out)
        Configuration timesSquared = TimesSquaredJob.createTimesSquaredJobConf(ones, inputPath, outputPath);
        Configuration timesSquaredWithInitial =
            TimesSquaredJob.createTimesSquaredJobConf(initialConf, ones, inputPath, outputPath);
        assertNull(timesSquared.get(TEST_PROPERTY_KEY));
        assertEquals(TEST_PROPERTY_VALUE, timesSquaredWithInitial.get(TEST_PROPERTY_KEY));

        // createTimesJobConf(vector, int, in, out)
        Configuration times = TimesSquaredJob.createTimesJobConf(ones, 50, inputPath, outputPath);
        Configuration timesWithInitial =
            TimesSquaredJob.createTimesJobConf(initialConf, ones, 50, inputPath, outputPath);
        assertNull(times.get(TEST_PROPERTY_KEY));
        assertEquals(TEST_PROPERTY_VALUE, timesWithInitial.get(TEST_PROPERTY_KEY));

        // createTimesSquaredJobConf(vector, in, out, mapperClass, reducerClass)
        Configuration customClasses = TimesSquaredJob.createTimesSquaredJobConf(ones, inputPath, outputPath,
            TimesSquaredJob.TimesSquaredMapper.class, TimesSquaredJob.VectorSummingReducer.class);
        Configuration customClassesWithInitial = TimesSquaredJob.createTimesSquaredJobConf(initialConf, ones,
            inputPath, outputPath, TimesSquaredJob.TimesSquaredMapper.class, TimesSquaredJob.VectorSummingReducer.class);
        assertNull(customClasses.get(TEST_PROPERTY_KEY));
        assertEquals(TEST_PROPERTY_VALUE, customClassesWithInitial.get(TEST_PROPERTY_KEY));

        // createTimesSquaredJobConf(vector, int, in, out, mapperClass, reducerClass)
        Configuration customClassesSized = TimesSquaredJob.createTimesSquaredJobConf(ones, 50, inputPath, outputPath,
            TimesSquaredJob.TimesSquaredMapper.class, TimesSquaredJob.VectorSummingReducer.class);
        Configuration customClassesSizedWithInitial = TimesSquaredJob.createTimesSquaredJobConf(initialConf, ones, 50,
            inputPath, outputPath, TimesSquaredJob.TimesSquaredMapper.class, TimesSquaredJob.VectorSummingReducer.class);
        assertNull(customClassesSized.get(TEST_PROPERTY_KEY));
        assertEquals(TEST_PROPERTY_VALUE, customClassesSizedWithInitial.get(TEST_PROPERTY_KEY));
    }

    /**
     * {@code times(v)} must leave the output temp directory empty by default, and must leave
     * exactly one directory holding one input-vector and one output-vector entry once
     * {@code DistributedMatrix.keep.temp.files} is set; both runs must return the same vector.
     */
    @Test
    public void testTimesVectorTempDirDeletion() throws Exception {
        Configuration configuration = getConfiguration();
        Vector ones = new RandomAccessSparseVector(50);
        ones.assign(1.0d);
        DistributedRowMatrix distributedMatrix = randomDistributedMatrix(100, 90, 50, 20, 1.0d, false);
        distributedMatrix.setConf(configuration);
        Path outputTempPath = distributedMatrix.getOutputTempPath();
        FileSystem fileSystem = outputTempPath.getFileSystem(configuration);

        // Start from an empty temp directory so leftovers cannot influence the counts below.
        deleteContentsOfPath(configuration, outputTempPath);
        assertEquals(0L, HadoopUtil.listStatus(fileSystem, outputTempPath).length);

        // Default behaviour: nothing is left behind after the multiplication.
        Vector resultWithCleanup = distributedMatrix.times(ones);
        assertEquals(0L, HadoopUtil.listStatus(fileSystem, outputTempPath).length);

        deleteContentsOfPath(configuration, outputTempPath);
        assertEquals(0L, HadoopUtil.listStatus(fileSystem, outputTempPath).length);

        // With the keep flag set, the temporary files must survive.
        configuration.setBoolean("DistributedMatrix.keep.temp.files", true);
        distributedMatrix.setConf(configuration);
        Vector resultKeepingFiles = distributedMatrix.times(ones);
        FileStatus[] keptDirs = fileSystem.listStatus(outputTempPath);
        assertEquals(1L, keptDirs.length);
        Path keptDir = keptDirs[0].getPath();
        Path inputVectorPath = new Path(keptDir, "DistributedMatrix.times.inputVector");
        Path outputVectorPath = new Path(keptDir, "DistributedMatrix.times.outputVector");
        assertEquals(1L, fileSystem.listStatus(inputVectorPath, PathFilters.logsCRCFilter()).length);
        assertEquals(1L, fileSystem.listStatus(outputVectorPath, PathFilters.logsCRCFilter()).length);
        assertEquals(0.0d, resultWithCleanup.getDistanceSquared(resultKeepingFiles), 1.0E-6d);
    }

    /**
     * Same temp-directory check as {@link #testTimesVectorTempDirDeletion()}, exercised through
     * {@code timesSquared(v)}.
     */
    @Test
    public void testTimesSquaredVectorTempDirDeletion() throws Exception {
        Configuration configuration = getConfiguration();
        Vector ones = new RandomAccessSparseVector(50);
        ones.assign(1.0d);
        DistributedRowMatrix distributedMatrix = randomDistributedMatrix(100, 90, 50, 20, 1.0d, false);
        // Use the same Configuration instance that is modified further down, matching the
        // times(v) variant of this test.
        distributedMatrix.setConf(configuration);
        Path outputTempPath = distributedMatrix.getOutputTempPath();
        FileSystem fileSystem = outputTempPath.getFileSystem(configuration);

        // Start from an empty temp directory so leftovers cannot influence the counts below.
        deleteContentsOfPath(configuration, outputTempPath);
        assertEquals(0L, HadoopUtil.listStatus(fileSystem, outputTempPath).length);

        // Default behaviour: nothing is left behind after the multiplication.
        Vector resultWithCleanup = distributedMatrix.timesSquared(ones);
        assertEquals(0L, HadoopUtil.listStatus(fileSystem, outputTempPath).length);

        deleteContentsOfPath(configuration, outputTempPath);
        assertEquals(0L, HadoopUtil.listStatus(fileSystem, outputTempPath).length);

        // With the keep flag set, the temporary files must survive.
        configuration.setBoolean("DistributedMatrix.keep.temp.files", true);
        distributedMatrix.setConf(configuration);
        Vector resultKeepingFiles = distributedMatrix.timesSquared(ones);
        FileStatus[] keptDirs = fileSystem.listStatus(outputTempPath);
        assertEquals(1L, keptDirs.length);
        Path keptDir = keptDirs[0].getPath();
        Path inputVectorPath = new Path(keptDir, "DistributedMatrix.times.inputVector");
        Path outputVectorPath = new Path(keptDir, "DistributedMatrix.times.outputVector");
        assertEquals(1L, fileSystem.listStatus(inputVectorPath, PathFilters.logsCRCFilter()).length);
        assertEquals(1L, fileSystem.listStatus(outputVectorPath, PathFilters.logsCRCFilter()).length);
        assertEquals(0.0d, resultWithCleanup.getDistanceSquared(resultKeepingFiles), 1.0E-6d);
    }

    /**
     * Returns the test configuration with {@link #TEST_PROPERTY_KEY} set to
     * {@link #TEST_PROPERTY_VALUE}, used to detect whether a job builder propagates it.
     *
     * @return the configuration obtained from {@code getConfiguration()} with the marker set
     * @throws IOException declared for callers; nothing in this method throws it directly
     */
    public Configuration createInitialConf() throws IOException {
        Configuration initialConf = getConfiguration();
        initialConf.set(TEST_PROPERTY_KEY, TEST_PROPERTY_VALUE);
        return initialConf;
    }

    /**
     * Recursively deletes every entry that {@code HadoopUtil.listStatus} reports for
     * {@code path}; {@code path} itself is not deleted.
     *
     * @param configuration configuration used to resolve the file system of {@code path}
     * @param path          directory whose contents are removed
     */
    private static void deleteContentsOfPath(Configuration configuration, Path path) throws Exception {
        FileSystem fileSystem = path.getFileSystem(configuration);
        FileStatus[] children = HadoopUtil.listStatus(fileSystem, path);
        for (FileStatus child : children) {
            fileSystem.delete(child.getPath(), true);
        }
    }

    /**
     * Creates a random distributed matrix stored under the test temp directory named
     * {@code "testdata"}. All arguments are forwarded unchanged to
     * {@link #randomDistributedMatrix(int, int, int, int, double, boolean, String)}.
     */
    public DistributedRowMatrix randomDistributedMatrix(int i, int i2, int i3, int i4, double d, boolean z) throws IOException {
        return randomDistributedMatrix(i, i2, i3, i4, d, z, "testdata");
    }

    /**
     * Stores the matrix returned by {@code SolverTest.randomHierarchicalMatrix(i, i2, z)} under
     * the test temp directory named {@code str} and returns it as a distributed matrix.
     */
    public DistributedRowMatrix randomDenseHierarchicalDistributedMatrix(int i, int i2, boolean z, String str) throws IOException {
        Matrix hierarchicalMatrix = SolverTest.randomHierarchicalMatrix(i, i2, z);
        return saveToFs(hierarchicalMatrix, getTestTempDirPath(str));
    }

    /**
     * Generates a matrix with {@code SolverTest.randomSequentialAccessSparseMatrix(i, i2, i3, i4, d)},
     * stores it under the test temp directory named {@code str}, and returns it as a
     * distributed matrix.
     *
     * <p>NOTE(review): {@code i}..{@code d} are passed through positionally; they presumably
     * mean rows, non-empty rows, columns, entries per row and entry mean - confirm against
     * {@link SolverTest}.
     *
     * @param z   when {@code true}, the generated matrix is replaced by its product with its
     *            own transpose before it is stored
     * @param str name of the test temp directory that receives the matrix
     */
    public DistributedRowMatrix randomDistributedMatrix(int i, int i2, int i3, int i4, double d, boolean z, String str) throws IOException {
        Path baseDir = getTestTempDirPath(str);
        Matrix generated = SolverTest.randomSequentialAccessSparseMatrix(i, i2, i3, i4, d);
        Matrix toStore = z ? generated.times(generated.transpose()) : generated;
        return saveToFs(toStore, baseDir);
    }

    /**
     * Writes the slices of {@code matrix} to {@code <path>/distMatrix/part-00000} and returns a
     * {@link DistributedRowMatrix} over {@code <path>/distMatrix} that uses {@code <path>/tmpOut}
     * as its temp path and is configured with a copy of the test configuration.
     */
    private DistributedRowMatrix saveToFs(final Matrix matrix, Path path) throws IOException {
        Configuration configuration = getConfiguration();
        Path rowDir = new Path(path, "distMatrix");
        Path partFile = new Path(path, "distMatrix/part-00000");
        // Lazily wraps each slice's vector in a VectorWritable as the writer iterates.
        Iterable<VectorWritable> rowWritables = new Iterable<VectorWritable>() {
            @Override
            public Iterator<VectorWritable> iterator() {
                return Iterators.transform(matrix.iterator(), new Function<MatrixSlice, VectorWritable>() {
                    @Override
                    public VectorWritable apply(MatrixSlice matrixSlice) {
                        return new VectorWritable(matrixSlice.vector());
                    }
                });
            }
        };
        FileSystem fileSystem = FileSystem.get(path.toUri(), configuration);
        ClusteringTestUtils.writePointsToFile(rowWritables, true, partFile, fileSystem, configuration);
        DistributedRowMatrix distributedRowMatrix =
            new DistributedRowMatrix(rowDir, new Path(path, "tmpOut"), matrix.numRows(), matrix.numCols());
        distributedRowMatrix.setConf(new Configuration(configuration));
        return distributedRowMatrix;
    }
}
