package org.apache.mahout.math.hadoop;

import java.io.IOException;
import java.net.URI;
import java.util.Iterator;
import org.apache.hadoop.filecache.DistributedCache;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.io.IntWritable;
import org.apache.hadoop.io.NullWritable;
import org.apache.hadoop.io.SequenceFile;
import org.apache.hadoop.io.WritableComparable;
import org.apache.hadoop.mapred.FileInputFormat;
import org.apache.hadoop.mapred.FileOutputFormat;
import org.apache.hadoop.mapred.JobConf;
import org.apache.hadoop.mapred.MapReduceBase;
import org.apache.hadoop.mapred.Mapper;
import org.apache.hadoop.mapred.OutputCollector;
import org.apache.hadoop.mapred.Reducer;
import org.apache.hadoop.mapred.Reporter;
import org.apache.hadoop.mapred.SequenceFileInputFormat;
import org.apache.hadoop.mapred.SequenceFileOutputFormat;
import org.apache.mahout.math.DenseVector;
import org.apache.mahout.math.RandomAccessSparseVector;
import org.apache.mahout.math.SequentialAccessSparseVector;
import org.apache.mahout.math.Vector;
import org.apache.mahout.math.VectorWritable;
import org.apache.mahout.math.function.Functions;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Helpers for building Hadoop (old {@code mapred} API) jobs that multiply a row matrix, stored as a
 * {@link SequenceFile} of {@link VectorWritable} rows, by an in-memory {@link Vector}.
 *
 * <ul>
 *   <li>{@link #createTimesSquaredJobConf(Vector, Path, Path)} installs {@link TimesSquaredMapper},
 *       which accumulates {@code sum_i (row_i . v) * row_i}, i.e. {@code A^T (A v)}.</li>
 *   <li>{@link #createTimesJobConf(Vector, int, Path, Path)} installs {@link TimesMapper}, which
 *       stores {@code row_i . v} at index {@code i} (the row's {@link IntWritable} key), i.e.
 *       {@code A v}.</li>
 * </ul>
 *
 * <p>In both cases the per-task partial vectors are added together by {@link VectorSummingReducer}
 * (installed as both reducer and combiner), and the caller reads the result back with
 * {@link #retrieveTimesSquaredOutputVector(JobConf)} after the job has run.
 */
public class TimesSquaredJob {
    private static final Logger log = LoggerFactory.getLogger(TimesSquaredJob.class);

    /** Job configuration key holding the URI of the file the input vector was written to. */
    public static final String INPUT_VECTOR = "DistributedMatrix.times.inputVector";
    /** Job configuration key: {@code true} when tasks should build a sparse output vector. */
    public static final String IS_SPARSE_OUTPUT = "DistributedMatrix.times.outputVector.sparse";
    /** Job configuration key holding the cardinality of the output vector. */
    public static final String OUTPUT_VECTOR_DIMENSION = "DistributedMatrix.times.output.dimension";
    /** Name of the job output directory, created under the caller-supplied output path. */
    public static final String OUTPUT_VECTOR_FILENAME = "DistributedMatrix.times.outputVector";

    /**
     * Mapper computing {@code A v}: for each row it stores {@code row . v} at the index given by the
     * row's {@link IntWritable} key. Zero products are not written. The task's vector is emitted once,
     * from {@link TimesSquaredMapper#close()}.
     */
    public static class TimesMapper extends TimesSquaredMapper<IntWritable> {

        @Override
        public void map(IntWritable intWritable, VectorWritable vectorWritable,
                        OutputCollector<NullWritable, VectorWritable> outputCollector, Reporter reporter) {
            // Remember the collector so close() can emit the accumulated vector.
            this.out = outputCollector;
            double scale = scale(vectorWritable);
            if (scale != 0.0d) {
                this.outputVector.setQuick(intWritable.get(), scale);
            }
        }

        /**
         * Alias of {@link #map(IntWritable, VectorWritable, OutputCollector, Reporter)}, kept only so
         * existing callers of this name keep compiling.
         *
         * @deprecated call {@code map} instead.
         */
        @Deprecated
        public void map2(IntWritable intWritable, VectorWritable vectorWritable,
                         OutputCollector<NullWritable, VectorWritable> outputCollector, Reporter reporter) {
            map(intWritable, vectorWritable, outputCollector, reporter);
        }
    }

    /**
     * Mapper computing this task's share of {@code A^T (A v)}. For every row it adds
     * {@code (row . v) * row} into a task-local output vector, and emits that vector once in
     * {@link #close()}.
     *
     * @param <T> row key type (ignored by this mapper)
     */
    public static class TimesSquaredMapper<T extends WritableComparable> extends MapReduceBase
            implements Mapper<T, VectorWritable, NullWritable, VectorWritable> {
        /** The vector {@code v}, loaded from the DistributedCache in {@link #configure(JobConf)}. */
        private Vector inputVector;
        /** Task-local accumulator, sized from {@link TimesSquaredJob#OUTPUT_VECTOR_DIMENSION}. */
        protected Vector outputVector;
        /** Collector captured in map(); stays null if this task received no rows. */
        protected OutputCollector<NullWritable, VectorWritable> out;

        /**
         * Loads the input vector from the first DistributedCache file and allocates the output vector.
         *
         * @throws IllegalArgumentException if no cache file is registered
         * @throws IllegalStateException if the cached vector cannot be read
         */
        @Override
        public void configure(JobConf jobConf) {
            try {
                URI[] cacheFiles = DistributedCache.getCacheFiles(jobConf);
                if (cacheFiles == null || cacheFiles.length < 1) {
                    throw new IllegalArgumentException("missing paths from the DistributedCache");
                }
                Path path = new Path(cacheFiles[0].getPath());
                this.inputVector = readFirstVector(path.getFileSystem(jobConf), path, jobConf);
                // Any other representation is copied into a SequentialAccessSparseVector
                // (presumably for cheaper dot products -- confirm).
                if (!(this.inputVector instanceof SequentialAccessSparseVector) && !(this.inputVector instanceof DenseVector)) {
                    this.inputVector = new SequentialAccessSparseVector(this.inputVector);
                }
                int i = jobConf.getInt(TimesSquaredJob.OUTPUT_VECTOR_DIMENSION, Integer.MAX_VALUE);
                this.outputVector = jobConf.getBoolean(TimesSquaredJob.IS_SPARSE_OUTPUT, false)
                        ? new RandomAccessSparseVector(i, 10)
                        : new DenseVector(i);
            } catch (IOException e) {
                throw new IllegalStateException(e);
            }
        }

        /** Adds {@code (row . v) * row} to the output vector; rows orthogonal to {@code v} are skipped. */
        @Override
        public void map(T t, VectorWritable vectorWritable, OutputCollector<NullWritable, VectorWritable> outputCollector,
                        Reporter reporter) throws IOException {
            this.out = outputCollector;
            double scale = scale(vectorWritable);
            if (scale == 1.0d) {
                this.outputVector.assign(vectorWritable.get(), Functions.plus);
            } else if (scale != 0.0d) {
                this.outputVector.assign(vectorWritable.get(), Functions.plusMult(scale));
            }
        }

        /** Returns the dot product of the given row with the input vector. */
        protected double scale(VectorWritable vectorWritable) {
            return vectorWritable.get().dot(this.inputVector);
        }

        /** Emits the accumulated vector, unless map() was never called (empty input split). */
        @Override
        public void close() throws IOException {
            // Without this guard a task with no rows would throw a NullPointerException here.
            // Skipping the emit is safe: an empty split contributes nothing to the sum.
            if (this.out != null) {
                this.out.collect(NullWritable.get(), new VectorWritable(this.outputVector));
            }
        }
    }

    /**
     * Reducer (also used as combiner) that adds all incoming vectors into one vector and emits it
     * under a {@link NullWritable} key.
     */
    public static class VectorSummingReducer extends MapReduceBase
            implements Reducer<NullWritable, VectorWritable, NullWritable, VectorWritable> {
        /** Running sum; kept per instance, so each reduce() call emits the total so far. */
        private Vector outputVector;

        @Override
        public void configure(JobConf jobConf) {
            int i = jobConf.getInt(TimesSquaredJob.OUTPUT_VECTOR_DIMENSION, Integer.MAX_VALUE);
            this.outputVector = jobConf.getBoolean(TimesSquaredJob.IS_SPARSE_OUTPUT, false)
                    ? new RandomAccessSparseVector(i, 10)
                    : new DenseVector(i);
        }

        /** Adds every non-null incoming vector to the running sum, then emits the sum. */
        @Override
        public void reduce(NullWritable nullWritable, Iterator<VectorWritable> it,
                           OutputCollector<NullWritable, VectorWritable> outputCollector, Reporter reporter)
                throws IOException {
            while (it.hasNext()) {
                VectorWritable next = it.next();
                if (next != null) {
                    next.get().addTo(this.outputVector);
                }
            }
            outputCollector.collect(NullWritable.get(), new VectorWritable(this.outputVector));
        }
    }

    private TimesSquaredJob() {
    }

    /**
     * Creates a job computing {@code A^T (A vector)}, where the output has the cardinality of
     * {@code vector}.
     *
     * @param vector the vector {@code v}
     * @param path input path of the matrix rows
     * @param path2 base output path (also holds the shipped input vector)
     */
    public static JobConf createTimesSquaredJobConf(Vector vector, Path path, Path path2) throws IOException {
        return createTimesSquaredJobConf(vector, path, path2, TimesSquaredMapper.class, VectorSummingReducer.class);
    }

    /**
     * Creates a job computing {@code A vector} with an output vector of cardinality {@code i}.
     *
     * @param vector the vector {@code v}
     * @param i cardinality of the output vector (row-index space of the matrix)
     * @param path input path of the matrix rows
     * @param path2 base output path (also holds the shipped input vector)
     */
    public static JobConf createTimesJobConf(Vector vector, int i, Path path, Path path2) throws IOException {
        return createTimesSquaredJobConf(vector, i, path, path2, TimesMapper.class, VectorSummingReducer.class);
    }

    /**
     * Same as {@link #createTimesSquaredJobConf(Vector, int, Path, Path, Class, Class)} with the
     * output cardinality set to {@code vector.size()}.
     */
    public static JobConf createTimesSquaredJobConf(Vector vector, Path path, Path path2,
                                                    Class<? extends TimesSquaredMapper> cls,
                                                    Class<? extends VectorSummingReducer> cls2) throws IOException {
        return createTimesSquaredJobConf(vector, vector.size(), path, path2, cls, cls2);
    }

    /**
     * Builds the job configuration.
     *
     * <p>The vector is written to {@code path2/DistributedMatrix.times.inputVector/<nanoTime>}, which is
     * registered in the DistributedCache and marked delete-on-exit. Results go to
     * {@code path2/}{@value #OUTPUT_VECTOR_FILENAME}.
     *
     * @param vector the vector shipped to every mapper
     * @param i cardinality of the output vector
     * @param path input path of the matrix rows (SequenceFile of {@link VectorWritable})
     * @param path2 base output path
     * @param cls mapper class
     * @param cls2 reducer class, also installed as combiner
     * @return the configured job, not yet submitted
     * @throws IOException if the input vector cannot be written
     */
    public static JobConf createTimesSquaredJobConf(Vector vector, int i, Path path, Path path2,
                                                    Class<? extends TimesSquaredMapper> cls,
                                                    Class<? extends VectorSummingReducer> cls2) throws IOException {
        JobConf jobConf = new JobConf(TimesSquaredJob.class);
        jobConf.setJobName("TimesSquaredJob: " + path + " timesSquared(" + vector.getName() + ')');
        FileSystem fileSystem = FileSystem.get(jobConf);
        Path inputPath = fileSystem.makeQualified(path);
        Path outputBasePath = fileSystem.makeQualified(path2);
        // A timestamped file name keeps concurrent jobs under the same output path apart.
        Path inputVectorPath = new Path(outputBasePath, "DistributedMatrix.times.inputVector/" + System.nanoTime());
        SequenceFile.Writer writer =
                new SequenceFile.Writer(fileSystem, jobConf, inputVectorPath, NullWritable.class, VectorWritable.class);
        try {
            writer.append(NullWritable.get(), new VectorWritable(vector));
        } finally {
            writer.close();
        }
        URI uri = inputVectorPath.toUri();
        DistributedCache.setCacheFiles(new URI[]{uri}, jobConf);
        fileSystem.deleteOnExit(inputVectorPath);
        jobConf.set(INPUT_VECTOR, uri.toString());
        jobConf.setBoolean(IS_SPARSE_OUTPUT, !(vector instanceof DenseVector));
        jobConf.setInt(OUTPUT_VECTOR_DIMENSION, i);
        FileInputFormat.addInputPath(jobConf, inputPath);
        jobConf.setInputFormat(SequenceFileInputFormat.class);
        FileOutputFormat.setOutputPath(jobConf, new Path(outputBasePath, OUTPUT_VECTOR_FILENAME));
        jobConf.setMapperClass(cls);
        jobConf.setMapOutputKeyClass(NullWritable.class);
        jobConf.setMapOutputValueClass(VectorWritable.class);
        jobConf.setReducerClass(cls2);
        jobConf.setCombinerClass(cls2);
        jobConf.setOutputFormat(SequenceFileOutputFormat.class);
        jobConf.setOutputKeyClass(NullWritable.class);
        jobConf.setOutputValueClass(VectorWritable.class);
        return jobConf;
    }

    /**
     * Reads the result vector of a finished job built by this class, and marks its part file for
     * deletion on exit.
     *
     * @param jobConf the configuration the job ran with
     * @return the result vector
     * @throws IOException if the part file cannot be read or contains no record
     */
    public static Vector retrieveTimesSquaredOutputVector(JobConf jobConf) throws IOException {
        Path outputPath = FileOutputFormat.getOutputPath(jobConf);
        FileSystem fileSystem = FileSystem.get(jobConf);
        // Only part-00000 is read. This assumes a single reducer (createTimesSquaredJobConf does not
        // set the reducer count) -- confirm for custom confs.
        Path path = new Path(outputPath, "part-00000");
        Vector vector = readFirstVector(fileSystem, path, jobConf);
        fileSystem.deleteOnExit(path);
        return vector;
    }

    /**
     * Reads the first {@link VectorWritable} value from a {@link SequenceFile} keyed by
     * {@link NullWritable}, closing the reader even when reading fails.
     *
     * @throws IOException if the file cannot be read or holds no record
     */
    private static Vector readFirstVector(FileSystem fileSystem, Path path, JobConf jobConf) throws IOException {
        SequenceFile.Reader reader = new SequenceFile.Reader(fileSystem, path, jobConf);
        try {
            VectorWritable vectorWritable = new VectorWritable();
            if (!reader.next(NullWritable.get(), vectorWritable)) {
                throw new IOException("no vector record found in " + path);
            }
            return vectorWritable.get();
        } finally {
            reader.close();
        }
    }
}
