package org.apache.mahout.math.hadoop.solver;

import java.io.IOException;
import java.util.List;
import java.util.Map;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.io.IntWritable;
import org.apache.hadoop.io.SequenceFile;
import org.apache.hadoop.util.Tool;
import org.apache.hadoop.util.ToolRunner;
import org.apache.lucene.analysis.wikipedia.WikipediaTokenizer;
import org.apache.mahout.common.AbstractJob;
import org.apache.mahout.common.commandline.DefaultOptionCreator;
import org.apache.mahout.math.Vector;
import org.apache.mahout.math.VectorWritable;
import org.apache.mahout.math.hadoop.DistributedRowMatrix;
import org.apache.mahout.math.solver.ConjugateGradientSolver;
import org.apache.mahout.math.solver.Preconditioner;

/* loaded from: input_file:org/apache/mahout/math/hadoop/solver/DistributedConjugateGradientSolver.class */
/**
 * Conjugate-gradient solver whose matrix is stored as a {@link DistributedRowMatrix} on a Hadoop
 * file system. The iteration itself is inherited from {@link ConjugateGradientSolver#solve}.
 *
 * <p>For command-line use, obtain the option-parsing front end via {@link #job()}. That front end
 * parses the arguments and then calls {@link #run(String[])} on this instance.
 */
public class DistributedConjugateGradientSolver extends ConjugateGradientSolver implements Tool {

    /** Residual-error threshold used when {@code --maxError} is not supplied. */
    private static final double FALLBACK_MAX_ERROR = 1.0e-9;

    /** Hadoop configuration; also handed to the {@link DistributedRowMatrix} built in {@link #runJob}. */
    private Configuration conf;

    /** Options parsed by {@link DistributedConjugateGradientSolverJob}, keyed like {@code "--input"}. */
    private Map<String, List<String>> parsedArgs;

    /**
     * Command-line front end. It declares and parses the solver's options, stores them on the
     * enclosing solver, and then delegates to {@link DistributedConjugateGradientSolver#run(String[])}.
     * Its configuration is the enclosing solver's configuration, not a separate copy.
     */
    public class DistributedConjugateGradientSolverJob extends AbstractJob {

        @Override
        public void setConf(Configuration configuration) {
            DistributedConjugateGradientSolver.this.setConf(configuration);
        }

        @Override
        public Configuration getConf() {
            return DistributedConjugateGradientSolver.this.getConf();
        }

        /**
         * Registers and parses the solver options, then runs the enclosing solver.
         *
         * @param strArr command-line arguments
         * @return {@code -1} if argument parsing fails; otherwise the result of
         *     {@link DistributedConjugateGradientSolver#run(String[])}
         * @throws Exception anything thrown while running the solver
         */
        @Override
        public int run(String[] strArr) throws Exception {
            addInputOption();
            addOutputOption();
            addOption("numRows", "nr", "Number of rows in the input matrix", true);
            addOption("numCols", "nc", "Number of columns in the input matrix", true);
            addOption("vector", "b", "Vector to solve against", true);
            // NOTE(review): --lambda and --symmetric are accepted here but never read by
            // DistributedConjugateGradientSolver.run(String[]); confirm whether they should be wired in.
            addOption("lambda", "l", "Scalar in A + lambda * I [default = 0]", "0.0");
            addOption("symmetric", "sym", "Is the input matrix square and symmetric?", "true");
            addOption(DefaultOptionCreator.MAX_ITERATIONS_OPTION, "x", "Maximum number of iterations to run");
            addOption("maxError", "err", "Maximum residual error to allow before stopping");

            DistributedConjugateGradientSolver.this.parsedArgs = parseArguments(strArr);
            if (DistributedConjugateGradientSolver.this.parsedArgs == null) {
                return -1;
            }
            // Keep whatever configuration the caller installed (ToolRunner sets one that carries the
            // generic options, e.g. -D). Only create a fresh one when nothing was set.
            if (DistributedConjugateGradientSolver.this.getConf() == null) {
                DistributedConjugateGradientSolver.this.setConf(new Configuration());
            }
            return DistributedConjugateGradientSolver.this.run(strArr);
        }
    }

    /**
     * Solves against the matrix stored at {@code path}.
     *
     * @param path location of the matrix rows
     * @param path2 temporary working directory for the {@link DistributedRowMatrix}
     * @param i number of rows
     * @param i2 number of columns
     * @param vector right-hand-side vector passed to {@link #solve}
     * @param preconditioner preconditioner passed to {@link #solve}; may be {@code null} (as
     *     {@link #run(String[])} does)
     * @param i3 maximum number of iterations
     * @param d maximum residual error before stopping
     * @return the vector returned by {@link #solve}
     */
    public Vector runJob(Path path, Path path2, int i, int i2, Vector vector, Preconditioner preconditioner, int i3, double d) {
        DistributedRowMatrix matrix = new DistributedRowMatrix(path, path2, i, i2);
        matrix.setConf(conf);
        return solve(matrix, vector, preconditioner, i3, d);
    }

    @Override
    public Configuration getConf() {
        return conf;
    }

    @Override
    public void setConf(Configuration configuration) {
        this.conf = configuration;
    }

    /**
     * Runs the solver using the options previously stored by {@link DistributedConjugateGradientSolverJob}.
     *
     * <p>It reads the vector in {@code --vector}, solves against the matrix in {@code --input}, and
     * writes the result to {@code --output}. It then deletes {@code --tempDir}, but only when the run
     * completes without an exception. If {@code --maxIter} is absent, the column count is used. If
     * {@code --maxError} is absent, {@value #FALLBACK_MAX_ERROR} is used.
     *
     * @param strArr unused; options come from the previously parsed arguments
     * @return {@code 0} on success
     * @throws Exception on I/O failure, malformed numeric options, or a solver failure
     */
    @Override
    public int run(String[] strArr) throws Exception {
        Path inputPath = new Path(AbstractJob.getOption(parsedArgs, "--input"));
        Path outputPath = new Path(AbstractJob.getOption(parsedArgs, "--output"));
        // NOTE(review): --tempDir is not registered above; this presumably relies on
        // AbstractJob.parseArguments adding it (with a default). Confirm for the Mahout version in use.
        Path tempPath = new Path(AbstractJob.getOption(parsedArgs, "--tempDir"));
        Path vectorPath = new Path(AbstractJob.getOption(parsedArgs, "--vector"));
        int numRows = Integer.parseInt(AbstractJob.getOption(parsedArgs, "--numRows"));
        int numCols = Integer.parseInt(AbstractJob.getOption(parsedArgs, "--numCols"));

        String maxIterKey = "--" + DefaultOptionCreator.MAX_ITERATIONS_OPTION;
        int maxIterations = parsedArgs.containsKey(maxIterKey)
                ? Integer.parseInt(AbstractJob.getOption(parsedArgs, maxIterKey))
                : numCols;
        double maxError = parsedArgs.containsKey("--maxError")
                ? Double.parseDouble(AbstractJob.getOption(parsedArgs, "--maxError"))
                : FALLBACK_MAX_ERROR;

        Vector rhs = loadInputVector(vectorPath);
        Vector solution = runJob(inputPath, tempPath, numRows, numCols, rhs, null, maxIterations, maxError);
        saveOutputVector(outputPath, solution);
        tempPath.getFileSystem(conf).delete(tempPath, true);
        return 0;
    }

    /**
     * Creates the command-line front end bound to this solver instance.
     *
     * @return a new job wrapper that shares this solver's configuration
     */
    public DistributedConjugateGradientSolverJob job() {
        return new DistributedConjugateGradientSolverJob();
    }

    /**
     * Reads the first record of a SequenceFile with {@link IntWritable} keys and
     * {@link VectorWritable} values, and returns that record's vector.
     *
     * @throws IOException if the file cannot be read or contains no records
     */
    private Vector loadInputVector(Path path) throws IOException {
        SequenceFile.Reader reader = new SequenceFile.Reader(path.getFileSystem(conf), path, conf);
        try {
            VectorWritable value = new VectorWritable();
            if (!reader.next(new IntWritable(), value)) {
                throw new IOException("Input vector file is empty.");
            }
            return value.get();
        } finally {
            reader.close();
        }
    }

    /**
     * Writes {@code vector} as a single record, with key {@code 0}, to a SequenceFile at {@code path}.
     *
     * @throws IOException if the file cannot be written
     */
    private void saveOutputVector(Path path, Vector vector) throws IOException {
        SequenceFile.Writer writer =
                new SequenceFile.Writer(path.getFileSystem(conf), conf, path, IntWritable.class, VectorWritable.class);
        try {
            writer.append(new IntWritable(0), new VectorWritable(vector));
        } finally {
            writer.close();
        }
    }

    /**
     * Command-line entry point.
     *
     * <p>NOTE(review): the status returned by {@link ToolRunner#run} is discarded, so argument
     * errors do not produce a non-zero process exit code.
     */
    public static void main(String[] strArr) throws Exception {
        ToolRunner.run(new DistributedConjugateGradientSolver().job(), strArr);
    }
}
