package org.apache.mahout.math.hadoop.solver;

import java.io.IOException;
import java.util.Random;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.io.IntWritable;
import org.apache.hadoop.io.SequenceFile;
import org.apache.mahout.common.MahoutTestCase;
import org.apache.mahout.common.RandomUtils;
import org.apache.mahout.math.DenseVector;
import org.apache.mahout.math.Vector;
import org.apache.mahout.math.VectorWritable;
import org.apache.mahout.math.hadoop.DistributedRowMatrix;
import org.apache.mahout.math.hadoop.TestDistributedRowMatrix;
import org.junit.Test;

/**
 * Exercises {@link DistributedConjugateGradientSolver} through its command-line job entry point.
 *
 * <p>A random distributed matrix {@code A} and a random right-hand-side vector {@code b} are written
 * to temporary storage, the solver job is run on them, and the vector {@code x} it writes out is
 * multiplied back through the matrix to check that the residual {@code b - A x} is numerically zero.
 */
public final class TestDistributedConjugateGradientSolverCLI extends MahoutTestCase {

    /** Largest Euclidean norm of the residual {@code b - A x} that {@link #testSolver()} accepts. */
    private static final double EPSILON = 1.0e-6;

    /**
     * Builds a dense vector whose entries are Gaussian samples multiplied by {@code scale}.
     *
     * @param size number of entries in the vector
     * @param scale multiplier applied to each {@link Random#nextGaussian()} sample
     * @return a new {@link DenseVector} of length {@code size}
     */
    private static Vector randomVector(int size, double scale) {
        Vector vector = new DenseVector(size);
        Random random = RandomUtils.getRandom();
        for (int i = 0; i < size; i++) {
            vector.setQuick(i, random.nextGaussian() * scale);
        }
        return vector;
    }

    /**
     * Writes {@code vector} to a sequence file at {@code path} as a single record with key 0.
     *
     * @param configuration Hadoop configuration used to resolve the file system
     * @param path destination file
     * @param vector vector to store
     * @return {@code path}, unchanged
     * @throws IOException if the file cannot be created, written or closed
     */
    private static Path saveVector(Configuration configuration, Path path, Vector vector)
            throws IOException {
        // try-with-resources closes the writer exactly once; a failure in close() is attached as a
        // suppressed exception instead of masking an exception thrown by append().
        try (SequenceFile.Writer writer = new SequenceFile.Writer(
                path.getFileSystem(configuration), configuration, path,
                IntWritable.class, VectorWritable.class)) {
            writer.append(new IntWritable(0), new VectorWritable(vector));
        }
        return path;
    }

    /**
     * Reads the first {@link VectorWritable} record from the sequence file at {@code path}.
     *
     * @param configuration Hadoop configuration used to resolve the file system
     * @param path sequence file with {@link IntWritable} keys and {@link VectorWritable} values
     * @return the vector stored in the first record
     * @throws IOException if the file cannot be read or contains no records
     */
    private static Vector loadVector(Configuration configuration, Path path) throws IOException {
        try (SequenceFile.Reader reader =
                new SequenceFile.Reader(path.getFileSystem(configuration), path, configuration)) {
            IntWritable key = new IntWritable();
            VectorWritable value = new VectorWritable();
            if (!reader.next(key, value)) {
                throw new IOException("Input vector file is empty.");
            }
            return value.get();
        }
    }

    /**
     * Runs the conjugate-gradient solver job on a random matrix and checks {@code A x == b}.
     */
    @Test
    public void testSolver() throws Exception {
        Configuration configuration = new Configuration();
        // NOTE(review): the final `true` appears to request a symmetric matrix (matching the
        // "--symmetric true" flag below); confirm against TestDistributedRowMatrix.
        DistributedRowMatrix matrix = new TestDistributedRowMatrix().randomDistributedMatrix(
                10, 10, 10, 10, 10.0d, true, getTestTempDirPath("testdata").toString());
        matrix.setConf(configuration);

        Path outputPath = getTestTempFilePath("output");
        Path vectorPath = getTestTempFilePath("vector");
        Path tempDirPath = getTestTempDirPath("tmp");

        Vector rhs = randomVector(matrix.numCols(), 10.0d);
        saveVector(configuration, vectorPath, rhs);

        // Pass the dimensions the generated matrix actually has instead of repeating literals.
        String[] args = {
            "-i", matrix.getRowPath().toString(),
            "-o", outputPath.toString(),
            "--tempDir", tempDirPath.toString(),
            "--vector", vectorPath.toString(),
            "--numRows", String.valueOf(matrix.numRows()),
            "--numCols", String.valueOf(matrix.numCols()),
            "--symmetric", "true",
        };
        new DistributedConjugateGradientSolver().job().run(args);

        Vector solution = loadVector(configuration, outputPath);
        double residualNorm = Math.sqrt(rhs.getDistanceSquared(matrix.times(solution)));
        assertEquals(0.0d, residualNorm, EPSILON);
    }
}
