package org.apache.mahout.clustering.classify;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.FileStatus;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.io.IntWritable;
import org.apache.hadoop.io.WritableComparable;
import org.apache.hadoop.mapreduce.Mapper;
import org.apache.mahout.clustering.Cluster;
import org.apache.mahout.clustering.iterator.ClusterWritable;
import org.apache.mahout.common.iterator.sequencefile.PathFilters;
import org.apache.mahout.common.iterator.sequencefile.PathType;
import org.apache.mahout.common.iterator.sequencefile.SequenceFileDirValueIterator;
import org.apache.mahout.math.Vector;
import org.apache.mahout.math.VectorWritable;

/* loaded from: input_file:org/apache/mahout/clustering/classify/ClusterClassificationMapper.class */
public class ClusterClassificationMapper extends Mapper<WritableComparable<?>, VectorWritable, IntWritable, WeightedVectorWritable> {
    private static double threshold;
    private List<Cluster> clusterModels;
    private ClusterClassifier clusterClassifier;
    private IntWritable clusterId;
    private WeightedVectorWritable weightedVW;
    private boolean emitMostLikely;

    protected void setup(Mapper<WritableComparable<?>, VectorWritable, IntWritable, WeightedVectorWritable>.Context context) throws IOException, InterruptedException {
        super.setup(context);
        Configuration configuration = context.getConfiguration();
        String str = configuration.get(ClusterClassificationConfigKeys.CLUSTERS_IN);
        threshold = configuration.getFloat(ClusterClassificationConfigKeys.OUTLIER_REMOVAL_THRESHOLD, 0.0f);
        this.emitMostLikely = configuration.getBoolean(ClusterClassificationConfigKeys.EMIT_MOST_LIKELY, false);
        this.clusterModels = new ArrayList();
        if (str != null && !str.isEmpty()) {
            Path path = new Path(str);
            this.clusterModels = populateClusterModels(path, configuration);
            this.clusterClassifier = new ClusterClassifier(this.clusterModels, ClusterClassifier.readPolicy(finalClustersPath(path)));
        }
        this.clusterId = new IntWritable();
        this.weightedVW = new WeightedVectorWritable(1.0d, null);
    }

    protected void map(WritableComparable<?> writableComparable, VectorWritable vectorWritable, Mapper<WritableComparable<?>, VectorWritable, IntWritable, WeightedVectorWritable>.Context context) throws IOException, InterruptedException {
        if (this.clusterModels.isEmpty()) {
            return;
        }
        Vector classify = this.clusterClassifier.classify(vectorWritable.get());
        if (shouldClassify(classify)) {
            if (this.emitMostLikely) {
                write(vectorWritable, context, classify.maxValueIndex());
            } else {
                writeAllAboveThreshold(vectorWritable, context, classify);
            }
        }
    }

    private void writeAllAboveThreshold(VectorWritable vectorWritable, Mapper<WritableComparable<?>, VectorWritable, IntWritable, WeightedVectorWritable>.Context context, Vector vector) throws IOException, InterruptedException {
        Iterator<Vector.Element> iterateNonZero = vector.iterateNonZero();
        while (iterateNonZero.hasNext()) {
            Vector.Element next = iterateNonZero.next();
            if (next.get() >= threshold) {
                write(vectorWritable, context, next.index());
            }
        }
    }

    private void write(VectorWritable vectorWritable, Mapper<WritableComparable<?>, VectorWritable, IntWritable, WeightedVectorWritable>.Context context, int i) throws IOException, InterruptedException {
        this.clusterId.set(this.clusterModels.get(i).getId());
        this.weightedVW.setVector(vectorWritable.get());
        context.write(this.clusterId, this.weightedVW);
    }

    /* JADX WARN: Multi-variable type inference failed */
    public static List<Cluster> populateClusterModels(Path path, Configuration configuration) throws IOException {
        ArrayList arrayList = new ArrayList();
        SequenceFileDirValueIterator sequenceFileDirValueIterator = new SequenceFileDirValueIterator(path.getFileSystem(configuration).listStatus(path, PathFilters.finalPartFilter())[0].getPath(), PathType.LIST, PathFilters.partFilter(), null, false, configuration);
        while (sequenceFileDirValueIterator.hasNext()) {
            Cluster value = ((ClusterWritable) sequenceFileDirValueIterator.next()).getValue();
            value.configure(configuration);
            arrayList.add(value);
        }
        return arrayList;
    }

    private static boolean shouldClassify(Vector vector) {
        return vector.maxValue() >= threshold;
    }

    private static Path finalClustersPath(Path path) throws IOException {
        return path.getFileSystem(new Configuration()).listStatus(path, PathFilters.finalPartFilter())[0].getPath();
    }

    protected /* bridge */ /* synthetic */ void map(Object obj, Object obj2, Mapper.Context context) throws IOException, InterruptedException {
        map((WritableComparable<?>) obj, (VectorWritable) obj2, (Mapper<WritableComparable<?>, VectorWritable, IntWritable, WeightedVectorWritable>.Context) context);
    }
}
