package org.apache.mahout.classifier.naivebayes.test;

import com.google.common.base.Preconditions;
import java.io.IOException;
import java.util.Iterator;
import java.util.Map;
import java.util.regex.Pattern;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.hdfs.DFSConfigKeys;
import org.apache.hadoop.io.SequenceFile;
import org.apache.hadoop.io.Text;
import org.apache.hadoop.io.Writable;
import org.apache.hadoop.mapreduce.Job;
import org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat;
import org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat;
import org.apache.hadoop.util.ToolRunner;
import org.apache.lucene.analysis.wikipedia.WikipediaTokenizer;
import org.apache.mahout.classifier.AbstractVectorClassifier;
import org.apache.mahout.classifier.ClassifierResult;
import org.apache.mahout.classifier.ResultAnalyzer;
import org.apache.mahout.classifier.naivebayes.BayesUtils;
import org.apache.mahout.classifier.naivebayes.ComplementaryNaiveBayesClassifier;
import org.apache.mahout.classifier.naivebayes.NaiveBayesModel;
import org.apache.mahout.classifier.naivebayes.StandardNaiveBayesClassifier;
import org.apache.mahout.clustering.fuzzykmeans.FuzzyKMeansDriver;
import org.apache.mahout.common.AbstractJob;
import org.apache.mahout.common.HadoopUtil;
import org.apache.mahout.common.Pair;
import org.apache.mahout.common.commandline.DefaultOptionCreator;
import org.apache.mahout.common.iterator.sequencefile.PathFilters;
import org.apache.mahout.common.iterator.sequencefile.PathType;
import org.apache.mahout.common.iterator.sequencefile.SequenceFileDirIterable;
import org.apache.mahout.math.Vector;
import org.apache.mahout.math.VectorWritable;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/apache/mahout/classifier/naivebayes/test/TestNaiveBayesDriver.class */
/**
 * Command-line job that classifies a directory of test vectors with a previously trained naive
 * Bayes model and logs a {@link ResultAnalyzer} summary of the outcome.
 *
 * <p>Classification runs either in-process ({@code --runSequential}) or as a Hadoop job that uses
 * {@link BayesTestMapper}. In both cases the per-label score vectors are written as sequence files
 * under the output path and are then read back to pick the best-scoring label for each instance.
 */
public class TestNaiveBayesDriver extends AbstractJob {
    /** Job configuration key carrying the value of the {@code testComplementary} flag. */
    public static final String COMPLEMENTARY = "class";

    private static final Logger log = LoggerFactory.getLogger(TestNaiveBayesDriver.class);
    private static final Pattern SLASH = Pattern.compile("/");

    // Long option names, shared between option registration and option lookup.
    private static final String OVERWRITE_OPTION = "overwrite";
    private static final String MODEL_OPTION = "model";
    private static final String TEST_COMPLEMENTARY_OPTION = "testComplementary";
    private static final String RUN_SEQUENTIAL_OPTION = "runSequential";
    private static final String LABEL_INDEX_OPTION = "labelIndex";

    /** Label handed to {@link ResultAnalyzer} as its default category. */
    private static final String DEFAULT_CATEGORY = "DEFAULT";

    /** Name of the single sequence file produced by the sequential code path. */
    private static final String SEQUENTIAL_OUTPUT_FILE = "part-r-00000";

    /**
     * Runs the driver through {@link ToolRunner} with a fresh {@link Configuration}.
     *
     * @param args command-line arguments
     * @throws Exception if the job fails
     */
    public static void main(String[] args) throws Exception {
        ToolRunner.run(new Configuration(), new TestNaiveBayesDriver(), args);
    }

    /**
     * Parses the command line, classifies the input and logs the analysed results.
     *
     * @param args command-line arguments
     * @return {@code 0} on success; {@code -1} if the arguments could not be parsed or the
     *     map-reduce job did not complete successfully
     * @throws Exception if reading the model, the input or the label index fails
     */
    @Override // org.apache.hadoop.util.Tool
    public int run(String[] args) throws Exception {
        addInputOption();
        addOutputOption();
        // Register the overwrite option exactly once (it used to be added twice).
        addOption(DefaultOptionCreator.overwriteOption().create());
        addOption(MODEL_OPTION, "m", "The path to the model built during training", true);
        addOption(buildOption(TEST_COMPLEMENTARY_OPTION, "c", "test complementary?", false, false, String.valueOf(false)));
        addOption(buildOption(RUN_SEQUENTIAL_OPTION, "seq", "run sequential?", false, false, String.valueOf(false)));
        addOption(LABEL_INDEX_OPTION, "l", "The path to the location of the label index", true);
        if (parseArguments(args) == null) {
            return -1;
        }

        if (hasOption(OVERWRITE_OPTION)) {
            HadoopUtil.delete(getConf(), getOutputPath());
        }

        if (hasOption(RUN_SEQUENTIAL_OPTION)) {
            runSequential();
        } else if (!runMapReduce()) {
            return -1;
        }

        // Both code paths leave their score vectors in part files under the output path.
        Map<Integer, String> labelIndex =
                BayesUtils.readLabelIndex(getConf(), new Path(getOption(LABEL_INDEX_OPTION)));
        SequenceFileDirIterable<Text, VectorWritable> classified =
                new SequenceFileDirIterable<>(getOutputPath(), PathType.LIST, PathFilters.partFilter(), getConf());
        ResultAnalyzer analyzer = new ResultAnalyzer(labelIndex.values(), DEFAULT_CATEGORY);
        analyzeResults(labelIndex, classified, analyzer);

        log.info("{} Results: {}", hasOption(TEST_COMPLEMENTARY_OPTION) ? "Complementary" : "Standard NB", analyzer);
        return 0;
    }

    /**
     * Classifies every input vector in this JVM and writes the score vectors to a single
     * sequence file named {@code part-r-00000} under the output path.
     *
     * @throws IOException if the model, the input or the output cannot be read or written
     * @throws IllegalArgumentException if complementary testing was requested but the model is
     *     not complementary
     */
    private void runSequential() throws IOException {
        boolean complementary = hasOption(TEST_COMPLEMENTARY_OPTION);
        FileSystem fs = FileSystem.get(getConf());
        NaiveBayesModel model = NaiveBayesModel.materialize(new Path(getOption(MODEL_OPTION)), getConf());

        // A complementary test only makes sense against a complementary model.
        if (complementary) {
            Preconditions.checkArgument(model.isComplemtary(), "Complementary mode in model is different from test mode");
        }

        AbstractVectorClassifier classifier = complementary
                ? new ComplementaryNaiveBayesClassifier(model)
                : new StandardNaiveBayesClassifier(model);

        Path outputFile = new Path(getOutputPath(), SEQUENTIAL_OUTPUT_FILE);
        // try-with-resources closes the writer on every path and keeps close() failures as
        // suppressed exceptions of the primary one.
        try (SequenceFile.Writer writer =
                SequenceFile.createWriter(fs, getConf(), outputFile, Text.class, VectorWritable.class)) {
            SequenceFileDirIterable<Text, VectorWritable> testVectors =
                    new SequenceFileDirIterable<>(getInputPath(), PathType.LIST, PathFilters.partFilter(), getConf());
            for (Pair<Text, VectorWritable> entry : testVectors) {
                // The label is the second "/"-separated component of the key
                // (presumably keys look like "/label/docId" -- confirm against the vectorizer).
                String label = SLASH.split(entry.getFirst().toString())[1];
                Vector scores = classifier.classifyFull(entry.getSecond().get());
                writer.append(new Text(label), new VectorWritable(scores));
            }
        }
    }

    /**
     * Classifies the input with a Hadoop job built around {@link BayesTestMapper}.
     *
     * @return {@code true} if the job completed successfully
     * @throws IOException if the job cannot be set up or fails with an I/O error
     * @throws InterruptedException if waiting for the job is interrupted
     * @throws ClassNotFoundException if a job class cannot be resolved
     */
    private boolean runMapReduce() throws IOException, InterruptedException, ClassNotFoundException {
        // Make the model available to the tasks through the configuration's cached files.
        HadoopUtil.cacheFiles(new Path(getOption(MODEL_OPTION)), getConf());

        Job testJob = prepareJob(getInputPath(), getOutputPath(), SequenceFileInputFormat.class, BayesTestMapper.class,
                Text.class, VectorWritable.class, SequenceFileOutputFormat.class);
        testJob.getConfiguration().set(COMPLEMENTARY, String.valueOf(hasOption(TEST_COMPLEMENTARY_OPTION)));
        return testJob.waitForCompletion(true);
    }

    /**
     * Feeds the best-scoring label of every classified instance into {@code resultAnalyzer}.
     *
     * @param map label index to label name
     * @param sequenceFileDirIterable pairs of (correct label, per-label score vector)
     * @param resultAnalyzer accumulator receiving one instance per pair that has a usable score
     */
    private static void analyzeResults(Map<Integer, String> map,
                                       SequenceFileDirIterable<Text, VectorWritable> sequenceFileDirIterable,
                                       ResultAnalyzer resultAnalyzer) {
        for (Pair<Text, VectorWritable> classified : sequenceFileDirIterable) {
            int bestIndex = Integer.MIN_VALUE;
            // Score floor: only scores strictly greater than Long.MIN_VALUE can win.
            double bestScore = Long.MIN_VALUE;
            for (Vector.Element element : classified.getSecond().get().all()) {
                if (element.get() > bestScore) {
                    bestScore = element.get();
                    bestIndex = element.index();
                }
            }
            // Skip instances for which no element beat the floor (e.g. an empty vector).
            if (bestIndex != Integer.MIN_VALUE) {
                ClassifierResult result = new ClassifierResult(map.get(bestIndex), bestScore);
                resultAnalyzer.addInstance(classified.getFirst().toString(), result);
            }
        }
    }
}
