package org.apache.mahout.math.hadoop.stochasticsvd;

import java.util.LinkedList;
import java.util.Random;
import junit.framework.Assert;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.io.IntWritable;
import org.apache.hadoop.io.SequenceFile;
import org.apache.hadoop.io.compress.DefaultCodec;
import org.apache.mahout.common.MahoutTestCase;
import org.apache.mahout.common.RandomUtils;
import org.apache.mahout.math.DenseMatrix;
import org.apache.mahout.math.DenseVector;
import org.apache.mahout.math.SingularValueDecomposition;
import org.apache.mahout.math.VectorWritable;
import org.junit.Test;

/* loaded from: input_file:org/apache/mahout/math/hadoop/stochasticsvd/LocalSSVDSolverDenseTest.class */
/**
 * Local-mode end-to-end test for {@link SSVDSolver} on a dense, randomly generated input matrix.
 *
 * <p>The test writes a random dense matrix to a sequence file and runs the stochastic SVD solver on it.
 * It compares the solver's singular values against an in-memory {@link SingularValueDecomposition}
 * of the same matrix. It then asserts that the solver's Q, U and V outputs are orthonormal.
 */
public class LocalSSVDSolverDenseTest extends MahoutTestCase {

    /** Absolute tolerance used for singular-value and orthonormality comparisons. */
    private static final double S_EPSILON = 1.0E-10d;

    /**
     * Generates a 1000 x 100 matrix with entries uniform in [-25, 25), runs {@link SSVDSolver} on it in
     * local mode and checks the results against an exact dense SVD.
     */
    @Test
    public void testSSVDSolver() throws Exception {
        Configuration conf = new Configuration();
        conf.set("mapred.job.tracker", "local");
        conf.set("fs.default.name", "file:///");
        Random rnd = RandomUtils.getRandom();
        conf.set("hadoop.tmp.dir", getTestTempDir("svdtmp").getAbsolutePath());

        // Shape of the generated input matrix A (m rows, n columns).
        int m = 1000;
        int n = 100;
        // Entries are amplitude * (U[0,1) - 0.5), i.e. uniform in [-amplitude/2, amplitude/2).
        double amplitude = 50.0d;

        Path aLocPath = new Path(getTestTempDirPath("svdtmp/A"), "A.seq");
        SequenceFile.Writer writer = SequenceFile.createWriter(FileSystem.getLocal(conf), conf, aLocPath,
                IntWritable.class, VectorWritable.class, SequenceFile.CompressionType.BLOCK, new DefaultCodec());
        // Close the writer even if an append fails, so the file handle is never leaked.
        try {
            double[] row = new double[n];
            // NOTE: the `true` flag presumably makes the vector wrap `row` without copying, so
            // refilling `row` below changes what each append writes.
            VectorWritable rowWritable = new VectorWritable(new DenseVector(row, true));
            IntWritable rowKey = new IntWritable();
            for (int i = 0; i < m; i++) {
                for (int j = 0; j < n; j++) {
                    row[j] = amplitude * (rnd.nextDouble() - 0.5d);
                }
                rowKey.set(i);
                writer.append(rowKey, rowWritable);
            }
        } finally {
            writer.close();
        }

        FileSystem fs = FileSystem.get(conf);
        Path tempDirPath = getTestTempDirPath("svd-proc");
        Path aPath = new Path(tempDirPath, "A/A.seq");
        fs.copyFromLocalFile(aLocPath, aPath);
        Path svdOutPath = new Path(tempDirPath, "SSVD-out");
        // Wipe out the results of any previous run.
        fs.delete(svdOutPath, true);

        // Solver parameters, in SSVDSolver constructor order. Assumed meanings (verify against
        // SSVDSolver): A-block rows, decomposition rank k, oversampling p, reducer count.
        int ablockRows = 251;
        int k = 40;
        int p = 60;
        int reduceTasks = 3;
        SSVDSolver ssvd = new SSVDSolver(conf, new Path[]{aPath}, svdOutPath, ablockRows, k, p, reduceTasks);
        ssvd.setOverwrite(true);
        ssvd.run();

        double[] stochasticSValues = ssvd.getSingularValues();
        System.out.println("--SSVD solver singular values:");
        dumpSv(stochasticSValues);
        System.out.println("--Colt SVD solver singular values:");
        double[][] a = SSVDSolver.loadDistributedRowMatrix(fs, aPath, conf);
        double[] exactSValues = new SingularValueDecomposition(new DenseMatrix(a)).getSingularValues();
        dumpSv(exactSValues);

        // Here k + p == n. That is presumably why the stochastic result is expected to match the
        // exact decomposition to within S_EPSILON.
        Assert.assertTrue("SSVD returned " + stochasticSValues.length + " singular values, expected at least "
                + (k + p), stochasticSValues.length >= k + p);
        for (int i = 0; i < k + p; i++) {
            Assert.assertTrue("singular value #" + i + " differs: exact=" + exactSValues[i]
                    + ", stochastic=" + stochasticSValues[i],
                    Math.abs(exactSValues[i] - stochasticSValues[i]) <= S_EPSILON);
        }

        SSVDPrototypeTest.assertOrthonormality(new DenseMatrix(SSVDSolver.loadDistributedRowMatrix(fs,
                new Path(svdOutPath, "Bt-job/Q-*"), conf)), false, S_EPSILON);
        SSVDPrototypeTest.assertOrthonormality(new DenseMatrix(SSVDSolver.loadDistributedRowMatrix(fs,
                new Path(svdOutPath, "U/[^_]*"), conf)), false, S_EPSILON);
        SSVDPrototypeTest.assertOrthonormality(new DenseMatrix(SSVDSolver.loadDistributedRowMatrix(fs,
                new Path(svdOutPath, "V/[^_]*"), conf)), false, S_EPSILON);
    }

    /**
     * Prints singular values to stdout on one line, prefixed with {@code "svs: "}.
     *
     * @param svs values to print; each is formatted with {@code %f} followed by two spaces
     */
    static void dumpSv(double[] svs) {
        System.out.print("svs: ");
        for (double sv : svs) {
            System.out.printf("%f  ", sv);
        }
        System.out.println();
    }

    /**
     * Prints a 2-D array to stdout, one row per line, each value formatted with {@code %f}.
     *
     * @param matrix rows to print
     */
    static void dump(double[][] matrix) {
        for (double[] row : matrix) {
            for (double value : row) {
                System.out.printf("%f  ", value);
            }
            System.out.println();
        }
    }
}
