package org.apache.mahout.math.hadoop;

import com.google.common.base.Preconditions;
import com.google.common.io.Closeables;
import java.io.IOException;
import java.net.URI;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.filecache.DistributedCache;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.io.IntWritable;
import org.apache.hadoop.io.NullWritable;
import org.apache.hadoop.io.SequenceFile;
import org.apache.hadoop.io.Writable;
import org.apache.hadoop.io.WritableComparable;
import org.apache.hadoop.mapreduce.Job;
import org.apache.hadoop.mapreduce.Mapper;
import org.apache.hadoop.mapreduce.Reducer;
import org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat;
import org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat;
import org.apache.mahout.common.HadoopUtil;
import org.apache.mahout.common.iterator.sequencefile.SequenceFileValueIterator;
import org.apache.mahout.math.DenseVector;
import org.apache.mahout.math.RandomAccessSparseVector;
import org.apache.mahout.math.Vector;
import org.apache.mahout.math.VectorWritable;
import org.apache.mahout.math.function.Functions;

/* loaded from: input_file:org/apache/mahout/math/hadoop/TimesSquaredJob.class */
/**
 * Builds (and reads back the result of) a map/reduce job that multiplies a row-wise matrix,
 * stored as a SequenceFile of {@link VectorWritable} rows, by a single vector {@code x}.
 *
 * <p>The vector {@code x} is written to a SequenceFile and registered in the
 * {@link DistributedCache}; every mapper loads it in {@code setup}. With
 * {@link TimesSquaredMapper} each task accumulates {@code sum_i (a_i . x) a_i}, i.e. {@code A'A x};
 * with {@link TimesMapper} it stores {@code a_i . x} at index {@code i}, i.e. {@code A x}. Per-task
 * vectors are summed by {@link VectorSummingReducer}, which is also installed as the combiner.
 */
public final class TimesSquaredJob {

    /**
     * Configuration key holding the URI of the input-vector SequenceFile; also used as the name of
     * the directory (under the output base path) that the vector file is written into.
     */
    public static final String INPUT_VECTOR = "DistributedMatrix.times.inputVector";
    /** Configuration key: {@code true} when the output vector should be a sparse vector. */
    public static final String IS_SPARSE_OUTPUT = "DistributedMatrix.times.outputVector.sparse";
    /** Configuration key holding the cardinality of the output vector. */
    public static final String OUTPUT_VECTOR_DIMENSION = "DistributedMatrix.times.output.dimension";
    /** Name of the job output directory created under the output base path. */
    public static final String OUTPUT_VECTOR_FILENAME = "DistributedMatrix.times.outputVector";

    /** Initial capacity hint passed to {@link RandomAccessSparseVector} for sparse outputs. */
    private static final int SPARSE_INITIAL_CAPACITY = 10;

    private TimesSquaredJob() {
    }

    /**
     * Creates an empty accumulator vector as described by the job configuration: sparse when
     * {@link #IS_SPARSE_OUTPUT} is set, dense otherwise, with cardinality
     * {@link #OUTPUT_VECTOR_DIMENSION}.
     *
     * @param conf the task configuration
     * @return a new, all-zero output vector
     */
    private static Vector createOutputVector(Configuration conf) {
        // Integer.MAX_VALUE is only a fallback; createTimesSquaredJob always sets the dimension.
        int dimension = conf.getInt(OUTPUT_VECTOR_DIMENSION, Integer.MAX_VALUE);
        return conf.getBoolean(IS_SPARSE_OUTPUT, false)
                ? new RandomAccessSparseVector(dimension, SPARSE_INITIAL_CAPACITY)
                : new DenseVector(dimension);
    }

    /**
     * Mapper computing {@code A x}: for the row with index {@code i} it stores {@code a_i . x} at
     * position {@code i} of the output vector, instead of accumulating {@code (a_i . x) a_i}.
     *
     * <p>Row indices are used directly as output positions, so they must lie within the configured
     * output dimension (see {@link TimesSquaredJob#createTimesJob}).
     */
    public static class TimesMapper extends TimesSquaredMapper<IntWritable> {

        /**
         * Writes the dot product of {@code row} with the input vector at index {@code rowIndex}.
         *
         * @param rowIndex the row number, used as the output position
         * @param row the matrix row
         * @param context the task context (unused; output is emitted in {@code cleanup})
         */
        @Override
        public void map(IntWritable rowIndex, VectorWritable row, Context context)
                throws IOException, InterruptedException {
            double dot = scale(row);
            // Zero products are skipped so that a sparse output vector stays sparse.
            if (dot != 0.0d) {
                getOutputVector().setQuick(rowIndex.get(), dot);
            }
        }
    }

    /**
     * Mapper computing a task-local partial of {@code A'A x}: each row {@code a_i} is added to the
     * accumulator scaled by {@code a_i . x}. The accumulated vector is emitted once, in
     * {@link #cleanup}, under the {@link NullWritable} key.
     *
     * @param <T> the (ignored) row key type
     */
    public static class TimesSquaredMapper<T extends WritableComparable>
            extends Mapper<T, VectorWritable, NullWritable, VectorWritable> {

        private Vector outputVector;
        private Vector inputVector;

        /** Returns the task-local accumulator that {@link #cleanup} writes out. */
        Vector getOutputVector() {
            return outputVector;
        }

        /**
         * Loads the input vector from the DistributedCache and allocates the output accumulator.
         *
         * @throws IllegalArgumentException if no file was registered in the DistributedCache
         * @throws IllegalStateException wrapping any {@link IOException} raised while reading
         */
        @Override
        public void setup(Context context) throws IOException, InterruptedException {
            try {
                Configuration conf = context.getConfiguration();
                Path[] localFiles = DistributedCache.getLocalCacheFiles(conf);
                Preconditions.checkArgument(localFiles != null && localFiles.length >= 1,
                        "missing paths from the DistributedCache");
                inputVector = readInputVector(conf);
                outputVector = createOutputVector(conf);
            } catch (IOException e) {
                throw new IllegalStateException(e);
            }
        }

        /**
         * Reads the first value of the single cached SequenceFile, closing the reader afterwards
         * (close failures are swallowed, as the data has already been read).
         */
        private static Vector readInputVector(Configuration conf) throws IOException {
            SequenceFileValueIterator<VectorWritable> values = new SequenceFileValueIterator<VectorWritable>(
                    HadoopUtil.getSingleCachedFile(conf), true, conf);
            try {
                return values.next().get();
            } finally {
                Closeables.close(values, true);
            }
        }

        /**
         * Adds {@code (row . x) * row} to the accumulator; rows orthogonal to {@code x} are skipped.
         */
        @Override
        public void map(T key, VectorWritable row, Context context) throws IOException, InterruptedException {
            double factor = scale(row);
            if (factor == 1.0d) {
                // Plain addition avoids building a plusMult function for the common unit factor.
                outputVector.assign(row.get(), Functions.PLUS);
            } else if (factor != 0.0d) {
                outputVector.assign(row.get(), Functions.plusMult(factor));
            }
        }

        /**
         * Returns the dot product of {@code row} with the input vector loaded in {@link #setup}.
         * Subclasses use this as the per-row coefficient.
         */
        protected double scale(VectorWritable row) {
            return row.get().dot(inputVector);
        }

        /** Emits the task-local accumulator as a single record keyed by {@link NullWritable}. */
        @Override
        public void cleanup(Context context) throws IOException, InterruptedException {
            context.write(NullWritable.get(), new VectorWritable(outputVector));
        }
    }

    /**
     * Sums all partial vectors it receives and writes the total under the {@link NullWritable} key.
     * Used both as combiner and as reducer.
     *
     * <p>The accumulator is allocated once in {@link #setup} and not reset between {@code reduce}
     * calls; since every map output carries the {@link NullWritable} key, a task is expected to see
     * a single group.
     */
    public static class VectorSummingReducer
            extends Reducer<NullWritable, VectorWritable, NullWritable, VectorWritable> {

        private Vector outputVector;

        /** Allocates the accumulator according to the job configuration. */
        @Override
        public void setup(Context context) throws IOException, InterruptedException {
            outputVector = createOutputVector(context.getConfiguration());
        }

        /** Adds every non-null partial vector to the accumulator and writes the running total. */
        @Override
        public void reduce(NullWritable key, Iterable<VectorWritable> partials, Context context)
                throws IOException, InterruptedException {
            for (VectorWritable partial : partials) {
                if (partial != null) {
                    outputVector.assign(partial.get(), Functions.PLUS);
                }
            }
            context.write(NullWritable.get(), new VectorWritable(outputVector));
        }
    }

    /**
     * Creates an {@code A'A x} job with a fresh {@link Configuration}.
     *
     * @see #createTimesSquaredJob(Configuration, Vector, int, Path, Path, Class, Class)
     */
    public static Job createTimesSquaredJob(Vector v, Path matrixInputPath, Path outputVectorPathBase)
            throws IOException {
        return createTimesSquaredJob(new Configuration(), v, matrixInputPath, outputVectorPathBase);
    }

    /**
     * Creates an {@code A'A x} job using {@link TimesSquaredMapper} and {@link VectorSummingReducer}.
     *
     * @see #createTimesSquaredJob(Configuration, Vector, int, Path, Path, Class, Class)
     */
    public static Job createTimesSquaredJob(Configuration initialConf, Vector v, Path matrixInputPath,
            Path outputVectorPathBase) throws IOException {
        return createTimesSquaredJob(initialConf, v, matrixInputPath, outputVectorPathBase,
                TimesSquaredMapper.class, VectorSummingReducer.class);
    }

    /**
     * Creates an {@code A x} job with a fresh {@link Configuration}.
     *
     * @see #createTimesJob(Configuration, Vector, int, Path, Path)
     */
    public static Job createTimesJob(Vector v, int outDim, Path matrixInputPath, Path outputVectorPathBase)
            throws IOException {
        return createTimesJob(new Configuration(), v, outDim, matrixInputPath, outputVectorPathBase);
    }

    /**
     * Creates an {@code A x} job using {@link TimesMapper}; {@code outDim} should be the number of
     * matrix rows, since row indices become output positions.
     *
     * @see #createTimesSquaredJob(Configuration, Vector, int, Path, Path, Class, Class)
     */
    public static Job createTimesJob(Configuration initialConf, Vector v, int outDim, Path matrixInputPath,
            Path outputVectorPathBase) throws IOException {
        return createTimesSquaredJob(initialConf, v, outDim, matrixInputPath, outputVectorPathBase,
                TimesMapper.class, VectorSummingReducer.class);
    }

    /**
     * Creates a job with custom mapper/reducer classes, a fresh {@link Configuration}, and an output
     * dimension equal to {@code v.size()}.
     */
    public static Job createTimesSquaredJob(Vector v, Path matrixInputPath, Path outputVectorPathBase,
            Class<? extends TimesSquaredMapper> mapClass, Class<? extends VectorSummingReducer> redClass)
            throws IOException {
        return createTimesSquaredJob(new Configuration(), v, matrixInputPath, outputVectorPathBase,
                mapClass, redClass);
    }

    /** Creates a job with custom mapper/reducer classes and an output dimension of {@code v.size()}. */
    public static Job createTimesSquaredJob(Configuration initialConf, Vector v, Path matrixInputPath,
            Path outputVectorPathBase, Class<? extends TimesSquaredMapper> mapClass,
            Class<? extends VectorSummingReducer> redClass) throws IOException {
        return createTimesSquaredJob(initialConf, v, v.size(), matrixInputPath, outputVectorPathBase,
                mapClass, redClass);
    }

    /** Creates a job with custom mapper/reducer classes and a fresh {@link Configuration}. */
    public static Job createTimesSquaredJob(Vector v, int outDim, Path matrixInputPath, Path outputVectorPathBase,
            Class<? extends TimesSquaredMapper> mapClass, Class<? extends VectorSummingReducer> redClass)
            throws IOException {
        return createTimesSquaredJob(new Configuration(), v, outDim, matrixInputPath, outputVectorPathBase,
                mapClass, redClass);
    }

    /**
     * Writes {@code v} to HDFS, registers it in the DistributedCache and prepares the job.
     *
     * @param initialConf configuration to build on; the cache file is registered on it
     * @param v the vector multiplied with the matrix
     * @param outDim cardinality of the output vector
     * @param matrixInputPath SequenceFile input holding the matrix rows
     * @param outputVectorPathBase base directory for the input-vector file and the job output
     * @param mapClass mapper implementation
     * @param redClass reducer implementation, also used as the combiner
     * @return the configured, not yet submitted, job
     * @throws IOException if the vector cannot be written or the job cannot be prepared
     */
    public static Job createTimesSquaredJob(Configuration initialConf, Vector v, int outDim, Path matrixInputPath,
            Path outputVectorPathBase, Class<? extends TimesSquaredMapper> mapClass,
            Class<? extends VectorSummingReducer> redClass) throws IOException {
        FileSystem fs = FileSystem.get(matrixInputPath.toUri(), initialConf);
        Path inputPath = fs.makeQualified(matrixInputPath);
        Path outputBase = fs.makeQualified(outputVectorPathBase);

        // Timestamped file name; presumably to avoid reusing a vector file from an earlier job.
        Path inputVectorPath = new Path(outputBase, INPUT_VECTOR + '/' + System.nanoTime());
        writeInputVector(fs, initialConf, inputVectorPath, v);

        // Registered on the configuration before it is handed to prepareJob (original ordering).
        URI inputVectorUri = inputVectorPath.toUri();
        DistributedCache.setCacheFiles(new URI[] {inputVectorUri}, initialConf);

        Job job = HadoopUtil.prepareJob(inputPath, new Path(outputBase, OUTPUT_VECTOR_FILENAME),
                SequenceFileInputFormat.class, mapClass, NullWritable.class, VectorWritable.class,
                redClass, NullWritable.class, VectorWritable.class, SequenceFileOutputFormat.class, initialConf);
        job.setCombinerClass(redClass);
        job.setJobName("TimesSquaredJob: " + inputPath);

        Configuration jobConf = job.getConfiguration();
        jobConf.set(INPUT_VECTOR, inputVectorUri.toString());
        jobConf.setBoolean(IS_SPARSE_OUTPUT, !v.isDense());
        jobConf.setInt(OUTPUT_VECTOR_DIMENSION, outDim);
        return job;
    }

    /**
     * Writes {@code v} as the single record of a new SequenceFile at {@code path}.
     *
     * <p>A close failure is swallowed only if {@code append} already failed, so it cannot mask that
     * error; on the success path a close failure propagates to the caller.
     */
    private static void writeInputVector(FileSystem fs, Configuration conf, Path path, Vector v)
            throws IOException {
        SequenceFile.Writer writer =
                new SequenceFile.Writer(fs, conf, path, NullWritable.class, VectorWritable.class);
        boolean threw = true;
        try {
            writer.append(NullWritable.get(), new VectorWritable(v));
            threw = false;
        } finally {
            Closeables.close(writer, threw);
        }
    }

    /**
     * Reads the result vector of a finished job.
     *
     * <p>Only {@code part-r-00000} is read, so this assumes the job ran with a single reducer.
     *
     * @param outputVectorTmpPath the {@code outputVectorPathBase} the job was created with
     * @param conf configuration used to open the file
     * @return the first vector stored in the job output
     * @throws IOException if the output file cannot be opened
     */
    public static Vector retrieveTimesSquaredOutputVector(Path outputVectorTmpPath, Configuration conf)
            throws IOException {
        Path outputFile = new Path(outputVectorTmpPath, OUTPUT_VECTOR_FILENAME + "/part-r-00000");
        SequenceFileValueIterator<VectorWritable> values =
                new SequenceFileValueIterator<VectorWritable>(outputFile, true, conf);
        try {
            return values.next().get();
        } finally {
            Closeables.close(values, true);
        }
    }
}
