package org.apache.mahout.math.hadoop;

import java.io.File;
import java.io.IOException;
import java.util.HashMap;
import java.util.Iterator;
import java.util.Map;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.mapred.JobConf;
import org.apache.mahout.clustering.ClusteringTestUtils;
import org.apache.mahout.clustering.canopy.TestCanopyCreation;
import org.apache.mahout.common.MahoutTestCase;
import org.apache.mahout.math.Matrix;
import org.apache.mahout.math.MatrixSlice;
import org.apache.mahout.math.RandomAccessSparseVector;
import org.apache.mahout.math.Vector;
import org.apache.mahout.math.VectorIterable;
import org.apache.mahout.math.VectorWritable;
import org.apache.mahout.math.decomposer.SolverTest;

/* loaded from: input_file:org/apache/mahout/math/hadoop/TestDistributedRowMatrix.class */
/**
 * Tests for {@link DistributedRowMatrix}.
 *
 * <p>Each test materialises a random sparse matrix as a file under the local
 * {@code testdata} directory, wraps it in a {@link DistributedRowMatrix} and compares
 * the distributed result with the same operation on a matrix produced by
 * {@link SolverTest#randomSequentialAccessSparseMatrix}. The comparisons only make sense
 * if that generator returns the same matrix for the same arguments
 * (NOTE(review): presumably it uses a fixed seed - confirm in SolverTest).
 */
public class TestDistributedRowMatrix extends MahoutTestCase {
    /** Scratch directory, relative to the working directory, recreated for every test. */
    private static final String TESTDATA = "testdata";

    /** Tolerance used by every comparison in this class. */
    private static final double EPSILON = 1.0E-9d;

    /**
     * Starts every test with an empty {@code testdata} directory.
     *
     * @throws IOException if the directory cannot be created
     */
    @Override
    public void setUp() throws Exception {
        super.setUp();
        File testData = new File(TESTDATA);
        if (testData.exists()) {
            // Clear leftovers of an earlier (possibly aborted) run.
            TestCanopyCreation.rmr(TESTDATA);
        }
        // Fail fast instead of letting a later file write fail with a less obvious error.
        if (!testData.mkdir() && !testData.isDirectory()) {
            throw new IOException("Could not create test data directory " + testData.getAbsolutePath());
        }
    }

    /**
     * Removes the {@code testdata} directory. The superclass tear-down runs even when the
     * clean-up itself throws.
     */
    @Override
    public void tearDown() throws Exception {
        try {
            TestCanopyCreation.rmr(TESTDATA);
        } finally {
            super.tearDown();
        }
    }

    /**
     * Asserts that two doubles differ by strictly less than a tolerance.
     *
     * @param d  expected value
     * @param d2 actual value
     * @param d3 exclusive upper bound on {@code |d - d2|}; a NaN difference fails the assertion
     */
    public static void assertEquals(double d, double d2, double d3) {
        assertTrue("expected <" + d + "> but was <" + d2 + "> (tolerance " + d3 + ')',
                Math.abs(d - d2) < d3);
    }

    /**
     * Asserts that two row-iterable matrices hold the same rows.
     *
     * <p>Both inputs are read to the end and compared row index by row index. A row that is
     * missing (or {@code null}) on one side is accepted only if the row on the other side is
     * missing too or has a zero 2-norm. Rows present on both sides must have a squared
     * distance strictly below {@code d}.
     *
     * @param vectorIterable  expected matrix
     * @param vectorIterable2 actual matrix
     * @param d               exclusive upper bound on the squared distance between matching rows
     */
    public static void assertEquals(VectorIterable vectorIterable, VectorIterable vectorIterable2, double d) {
        Map<Integer, Vector> expectedRows = collectRows(vectorIterable);
        Map<Integer, Vector> actualRows = collectRows(vectorIterable2);

        for (Map.Entry<Integer, Vector> expected : expectedRows.entrySet()) {
            Integer rowIndex = expected.getKey();
            assertRowsMatch(rowIndex, expected.getValue(), actualRows.get(rowIndex), d);
        }
        // Rows that only the actual matrix has must be empty as well.
        for (Map.Entry<Integer, Vector> actual : actualRows.entrySet()) {
            if (!expectedRows.containsKey(actual.getKey())) {
                assertRowsMatch(actual.getKey(), null, actual.getValue(), d);
            }
        }
    }

    /** Reads every slice offered by {@code iterateAll()} into a map keyed by row index. */
    private static Map<Integer, Vector> collectRows(VectorIterable rows) {
        Map<Integer, Vector> byIndex = new HashMap<Integer, Vector>();
        Iterator<MatrixSlice> slices = rows.iterateAll();
        while (slices.hasNext()) {
            MatrixSlice slice = slices.next();
            byIndex.put(Integer.valueOf(slice.index()), slice.vector());
        }
        return byIndex;
    }

    /** Compares one pair of rows; a {@code null} row is treated as equal to an all-zero row. */
    private static void assertRowsMatch(Integer rowIndex, Vector expected, Vector actual, double tolerance) {
        if (expected == null || actual == null) {
            assertTrue("row " + rowIndex + " is missing from the actual matrix but is non-zero in the expected one",
                    expected == null || expected.norm(2.0d) == 0.0d);
            assertTrue("row " + rowIndex + " is missing from the expected matrix but is non-zero in the actual one",
                    actual == null || actual.norm(2.0d) == 0.0d);
        } else {
            assertTrue("row " + rowIndex + " differs by more than the tolerance " + tolerance,
                    expected.getDistanceSquared(actual) < tolerance);
        }
    }

    /** Transposing a distributed matrix twice must give back the original matrix. */
    public void testTranspose() throws Exception {
        DistributedRowMatrix original = randomDistributedMatrix(10, 9, 5, 4, 1.0d, false);
        DistributedRowMatrix transposed = original.transpose();
        // The child name must not start with '/': Hadoop resolves such a child as an absolute
        // path, which would put the temp output at the filesystem root instead of inside
        // the test directory that tearDown() removes.
        Path transposeTmp = new Path(original.getOutputTempPath().getParent(), "tmpOutTranspose");
        transposed.setOutputTempPathString(transposeTmp.toString());
        DistributedRowMatrix roundTrip = transposed.transpose();
        assertEquals(original, roundTrip, EPSILON);
    }

    /** The distributed matrix-vector product must match the in-memory product. */
    public void testMatrixTimesVector() throws Exception {
        RandomAccessSparseVector ones = new RandomAccessSparseVector(50);
        ones.assign(1.0d);
        // The in-memory matrix is generated first, exactly as before, in case the
        // generator's output depends on call order.
        Matrix inMemory = SolverTest.randomSequentialAccessSparseMatrix(100, 90, 50, 20, 1.0d);
        Vector expected = inMemory.times(ones);
        DistributedRowMatrix distributed = randomDistributedMatrix(100, 90, 50, 20, 1.0d, false);
        Vector actual = distributed.times(ones);
        assertEquals(0.0d, expected.getDistanceSquared(actual), EPSILON);
    }

    /** The distributed {@code timesSquared} product must match the in-memory one. */
    public void testMatrixTimesSquaredVector() throws Exception {
        RandomAccessSparseVector ones = new RandomAccessSparseVector(50);
        ones.assign(1.0d);
        Matrix inMemory = SolverTest.randomSequentialAccessSparseMatrix(100, 90, 50, 20, 1.0d);
        Vector expected = inMemory.timesSquared(ones);
        DistributedRowMatrix distributed = randomDistributedMatrix(100, 90, 50, 20, 1.0d, false);
        Vector actual = distributed.timesSquared(ones);
        assertEquals(0.0d, expected.getDistanceSquared(actual), EPSILON);
    }

    /**
     * The distributed product {@code distA.times(distB)} is compared with the in-memory
     * {@code A.transpose().times(B)} built from the same generator arguments.
     */
    public void testMatrixTimesMatrix() throws Exception {
        Matrix inMemoryA = SolverTest.randomSequentialAccessSparseMatrix(20, 19, 15, 5, 10.0d);
        Matrix inMemoryATransposed = inMemoryA.transpose();
        Matrix inMemoryB = SolverTest.randomSequentialAccessSparseMatrix(20, 13, 25, 10, 5.0d);
        VectorIterable expected = inMemoryATransposed.times(inMemoryB);

        DistributedRowMatrix distributedA = randomDistributedMatrix(20, 19, 15, 5, 10.0d, false, "/distA");
        DistributedRowMatrix distributedB = randomDistributedMatrix(20, 13, 25, 10, 5.0d, false, "/distB");
        VectorIterable actual = distributedA.times(distributedB);

        assertEquals(expected, actual, EPSILON);
    }

    /**
     * Same as {@link #randomDistributedMatrix(int, int, int, int, double, boolean, String)}
     * with an empty base-directory suffix, i.e. the matrix is stored directly under
     * {@code testdata}.
     */
    public static DistributedRowMatrix randomDistributedMatrix(int i, int i2, int i3, int i4, double d, boolean z) throws Exception {
        return randomDistributedMatrix(i, i2, i3, i4, d, z, "");
    }

    /**
     * Generates a random sparse matrix, writes its rows to
     * {@code testdata<str>/distMatrix/part-00000} and returns a configured
     * {@link DistributedRowMatrix} over {@code testdata<str>/distMatrix} that uses
     * {@code testdata<str>/tmpOut} as its temp output path.
     *
     * <p>The five numeric arguments are forwarded unchanged, in order, to
     * {@link SolverTest#randomSequentialAccessSparseMatrix}. NOTE(review): they look like
     * (rows, non-empty rows, columns, entries per row, entry mean) - confirm against SolverTest.
     *
     * @param i   first generator argument
     * @param i2  second generator argument
     * @param i3  third generator argument
     * @param i4  fourth generator argument
     * @param d   fifth generator argument
     * @param z   if {@code true}, the generated matrix M is replaced by {@code M * M^T}
     *            before it is written
     * @param str suffix appended to {@code "testdata"} to form the base directory
     * @return the distributed view of the written matrix
     * @throws IOException if the file system cannot be obtained or the rows cannot be written
     */
    public static DistributedRowMatrix randomDistributedMatrix(int i, int i2, int i3, int i4, double d, boolean z, String str) throws IOException {
        String baseDir = TESTDATA + str;
        Matrix generated = SolverTest.randomSequentialAccessSparseMatrix(i, i2, i3, i4, d);
        if (z) {
            generated = generated.times(generated.transpose());
        }
        final Matrix matrix = generated;

        Configuration configuration = new Configuration();
        Iterable<VectorWritable> rows = new Iterable<VectorWritable>() {
            @Override
            public Iterator<VectorWritable> iterator() {
                final Iterator<MatrixSlice> slices = matrix.iterator();
                // One writable is reused for every row, so each element is only valid
                // until the next call to next().
                final VectorWritable holder = new VectorWritable();
                return new Iterator<VectorWritable>() {
                    @Override
                    public boolean hasNext() {
                        return slices.hasNext();
                    }

                    @Override
                    public VectorWritable next() {
                        holder.set(slices.next().vector());
                        return holder;
                    }

                    @Override
                    public void remove() {
                        slices.remove();
                    }
                };
            }
        };
        ClusteringTestUtils.writePointsToFile(rows, true, baseDir + "/distMatrix/part-00000",
                FileSystem.get(configuration), configuration);

        DistributedRowMatrix distributed =
                new DistributedRowMatrix(baseDir + "/distMatrix", baseDir + "/tmpOut", matrix.numRows(), matrix.numCols());
        distributed.configure(new JobConf(configuration));
        return distributed;
    }
}
