package org.apache.mahout.clustering.dirichlet;

import java.io.IOException;
import org.apache.commons.cli2.CommandLine;
import org.apache.commons.cli2.Group;
import org.apache.commons.cli2.Option;
import org.apache.commons.cli2.OptionException;
import org.apache.commons.cli2.builder.ArgumentBuilder;
import org.apache.commons.cli2.builder.DefaultOptionBuilder;
import org.apache.commons.cli2.builder.GroupBuilder;
import org.apache.commons.cli2.commandline.Parser;
import org.apache.commons.cli2.option.DefaultOption;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.io.SequenceFile;
import org.apache.hadoop.io.Text;
import org.apache.hadoop.mapred.FileInputFormat;
import org.apache.hadoop.mapred.FileOutputFormat;
import org.apache.hadoop.mapred.JobClient;
import org.apache.hadoop.mapred.JobConf;
import org.apache.hadoop.mapred.SequenceFileInputFormat;
import org.apache.hadoop.mapred.SequenceFileOutputFormat;
import org.apache.mahout.clustering.dirichlet.models.ModelDistribution;
import org.apache.mahout.clustering.kmeans.KMeansDriver;
import org.apache.mahout.common.CommandLineUtil;
import org.apache.mahout.common.commandline.DefaultOptionCreator;
import org.apache.mahout.matrix.SparseVector;
import org.apache.mahout.matrix.Vector;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:WEB-INF/lib/mahout-core-0.2.jar:org/apache/mahout/clustering/dirichlet/DirichletDriver.class */
/**
 * Command-line and programmatic driver for the map/reduce Dirichlet clustering job.
 *
 * <p>The driver writes an initial set of clusters below {@code <output>/state-0} and then runs a
 * fixed number of map/reduce iterations; iteration {@code n} reads the state written to
 * {@code <output>/state-n} and writes {@code <output>/state-(n+1)}.
 *
 * <p>This is a static utility class and cannot be instantiated.
 */
public class DirichletDriver {
    /** Job configuration key holding the path of the state directory read by an iteration. */
    public static final String STATE_IN_KEY = "org.apache.mahout.clustering.dirichlet.stateIn";
    /** Job configuration key holding the {@link ModelDistribution} class name. */
    public static final String MODEL_FACTORY_KEY = "org.apache.mahout.clustering.dirichlet.modelFactory";
    /** Job configuration key holding the number of clusters (models). */
    public static final String NUM_CLUSTERS_KEY = "org.apache.mahout.clustering.dirichlet.numClusters";
    /** Job configuration key holding the alpha_0 value. */
    public static final String ALPHA_0_KEY = "org.apache.mahout.clustering.dirichlet.alpha_0";
    private static final Logger log = LoggerFactory.getLogger(DirichletDriver.class);

    private DirichletDriver() {
    }

    /**
     * Command-line entry point: parses the options and delegates to
     * {@link #runJob(String, String, String, int, int, double, int)}.
     *
     * <p>If the help option is present, or the command line cannot be parsed (including
     * non-numeric values for the numeric options), usage help is printed and no job is run.
     *
     * @param strArr the raw command-line arguments
     * @throws InstantiationException if the model distribution class cannot be instantiated
     * @throws IllegalAccessException if the model distribution class or its constructor is not accessible
     * @throws ClassNotFoundException if the model distribution class cannot be loaded
     * @throws IOException if writing the initial state fails
     */
    public static void main(String[] strArr) throws InstantiationException, IllegalAccessException, ClassNotFoundException, IOException {
        DefaultOptionBuilder optionBuilder = new DefaultOptionBuilder();
        ArgumentBuilder argBuilder = new ArgumentBuilder();
        GroupBuilder groupBuilder = new GroupBuilder();

        // The builders are shared and stateful, so the options are created in a fixed order.
        DefaultOption inputOpt = DefaultOptionCreator.inputOption(optionBuilder, argBuilder).create();
        DefaultOption outputOpt = DefaultOptionCreator.outputOption(optionBuilder, argBuilder).create();
        DefaultOption maxIterOpt = DefaultOptionCreator.maxIterOption(optionBuilder, argBuilder).create();
        DefaultOption numModelsOpt = DefaultOptionCreator.kOption(optionBuilder, argBuilder).create();
        Option helpOpt = DefaultOptionCreator.helpOption(optionBuilder);
        DefaultOption alphaOpt = optionBuilder
                .withLongName("alpha")
                .withRequired(true)
                .withShortName("m")
                .withArgument(argBuilder.withName("alpha").withMinimum(1).withMaximum(1).create())
                .withDescription("The alpha0 value for the DirichletDistribution.")
                .create();
        DefaultOption modelClassOpt = optionBuilder
                .withLongName("modelClass")
                .withRequired(true)
                .withShortName("d")
                .withArgument(argBuilder.withName("modelClass").withMinimum(1).withMaximum(1).create())
                .withDescription("The ModelDistribution class name.")
                .create();
        DefaultOption maxRedOpt = optionBuilder
                .withLongName("maxRed")
                .withRequired(true)
                .withShortName("r")
                .withArgument(argBuilder.withName("maxRed").withMinimum(1).withMaximum(1).create())
                .withDescription("The number of reduce tasks.")
                .create();

        Group optionGroup = groupBuilder.withName("Options")
                .withOption(inputOpt)
                .withOption(outputOpt)
                .withOption(modelClassOpt)
                .withOption(maxIterOpt)
                .withOption(alphaOpt)
                .withOption(numModelsOpt)
                .withOption(helpOpt)
                .withOption(maxRedOpt)
                .create();

        CommandLine cmdLine;
        try {
            Parser parser = new Parser();
            parser.setGroup(optionGroup);
            cmdLine = parser.parse(strArr);
        } catch (OptionException e) {
            log.error("Exception parsing command line: ", e);
            CommandLineUtil.printHelp(optionGroup);
            return;
        }
        if (cmdLine.hasOption(helpOpt)) {
            CommandLineUtil.printHelp(optionGroup);
            return;
        }

        String input = cmdLine.getValue(inputOpt).toString();
        String output = cmdLine.getValue(outputOpt).toString();
        String modelFactory = cmdLine.getValue(modelClassOpt).toString();
        int numModels;
        int maxIterations;
        double alpha0;
        int numReducers;
        try {
            numModels = Integer.parseInt(cmdLine.getValue(numModelsOpt).toString());
            maxIterations = Integer.parseInt(cmdLine.getValue(maxIterOpt).toString());
            alpha0 = Double.parseDouble(cmdLine.getValue(alphaOpt).toString());
            numReducers = Integer.parseInt(cmdLine.getValue(maxRedOpt).toString());
        } catch (NumberFormatException e) {
            // A malformed number is a command-line error, so report it like any other parse failure.
            log.error("Exception parsing command line: ", e);
            CommandLineUtil.printHelp(optionGroup);
            return;
        }
        runJob(input, output, modelFactory, numModels, maxIterations, alpha0, numReducers);
    }

    /**
     * Writes the initial state and then runs the requested number of iterations.
     *
     * <p>The output directory is deleted recursively before the initial state is written.
     *
     * @param str the input path
     * @param str2 the output directory; state directories are created below it
     * @param str3 the {@link ModelDistribution} class name
     * @param i the number of clusters (models)
     * @param i2 the number of iterations to run
     * @param d the alpha_0 value
     * @param i3 the number of reduce tasks
     * @throws ClassNotFoundException if the model distribution class cannot be loaded
     * @throws InstantiationException if the model distribution class cannot be instantiated
     * @throws IllegalAccessException if the model distribution class or its constructor is not accessible
     * @throws IOException if writing the initial state fails
     */
    public static void runJob(String str, String str2, String str3, int i, int i2, double d, int i3) throws ClassNotFoundException, InstantiationException, IllegalAccessException, IOException {
        String stateIn = str2 + "/state-0";
        writeInitialState(str2, stateIn, str3, i, d);
        for (int iteration = 0; iteration < i2; iteration++) {
            log.info("Iteration {}", Integer.valueOf(iteration));
            String stateOut = str2 + "/state-" + (iteration + 1);
            runIteration(str, stateIn, stateOut, str3, i, d, i3);
            // The output of this iteration becomes the input state of the next one.
            stateIn = stateOut;
        }
    }

    /**
     * Clears the output directory and writes one sequence file per initial cluster below the
     * given state directory ({@code part-0}, {@code part-1}, ...).
     *
     * @param output the output directory, deleted recursively before writing
     * @param stateIn the directory receiving the initial cluster files
     * @param modelFactory the {@link ModelDistribution} class name
     * @param numModels the number of clusters to create and write
     * @param alpha0 the alpha_0 value
     */
    private static void writeInitialState(String output, String stateIn, String modelFactory, int numModels, double alpha0) throws ClassNotFoundException, InstantiationException, IllegalAccessException, IOException {
        DirichletState<Vector> state = createState(modelFactory, numModels, alpha0);
        // Use this driver's class, consistent with the job configurations built elsewhere in this class.
        JobConf conf = new JobConf(DirichletDriver.class);
        Path outputPath = new Path(output);
        FileSystem fs = FileSystem.get(outputPath.toUri(), conf);
        fs.delete(outputPath, true);
        for (int c = 0; c < numModels; c++) {
            Path clusterPath = new Path(stateIn + "/part-" + c);
            SequenceFile.Writer writer = new SequenceFile.Writer(fs, conf, clusterPath, Text.class, DirichletCluster.class);
            try {
                writer.append(new Text(Integer.toString(c)), state.getClusters().get(c));
            } finally {
                // Always release the file handle, even when append fails.
                writer.close();
            }
        }
    }

    /**
     * Instantiates the named model distribution through the current thread's context class
     * loader and wraps it in a new {@link DirichletState}.
     *
     * @param str the {@link ModelDistribution} class name; the class needs an accessible no-arg constructor
     * @param i the number of clusters (models)
     * @param d the alpha_0 value
     * @return a new state built from the model distribution, {@code i}, {@code d} and the
     *         constants {@code 1, 1} for the two trailing constructor arguments
     * @throws ClassNotFoundException if the class cannot be loaded
     * @throws InstantiationException if the class cannot be instantiated
     * @throws IllegalAccessException if the class or its constructor is not accessible
     */
    public static DirichletState<Vector> createState(String str, int i, double d) throws ClassNotFoundException, InstantiationException, IllegalAccessException {
        ClassLoader loader = Thread.currentThread().getContextClassLoader();
        Class<?> factoryClass = loader.loadClass(str);
        ModelDistribution factory = (ModelDistribution) factoryClass.newInstance();
        return new DirichletState<>(factory, i, d, 1, 1);
    }

    /**
     * Configures and runs a single map/reduce iteration.
     *
     * <p>An {@link IOException} raised while running the job is logged at warn level and not
     * propagated, so callers cannot detect a failed iteration from this method.
     *
     * @param str the input path
     * @param str2 the state directory to read, stored under {@link #STATE_IN_KEY}
     * @param str3 the job output path (the state directory to write)
     * @param str4 the {@link ModelDistribution} class name, stored under {@link #MODEL_FACTORY_KEY}
     * @param i the number of clusters, stored under {@link #NUM_CLUSTERS_KEY}
     * @param d the alpha_0 value, stored under {@link #ALPHA_0_KEY}
     * @param i2 the number of reduce tasks
     */
    public static void runIteration(String str, String str2, String str3, String str4, int i, double d, int i2) {
        JobConf conf = new JobConf(DirichletDriver.class);

        conf.setOutputKeyClass(Text.class);
        conf.setOutputValueClass(DirichletCluster.class);
        conf.setMapOutputKeyClass(Text.class);
        conf.setMapOutputValueClass(SparseVector.class);

        FileInputFormat.setInputPaths(conf, new Path[]{new Path(str)});
        FileOutputFormat.setOutputPath(conf, new Path(str3));

        conf.setMapperClass(DirichletMapper.class);
        conf.setReducerClass(DirichletReducer.class);
        conf.setNumReduceTasks(i2);
        conf.setInputFormat(SequenceFileInputFormat.class);
        conf.setOutputFormat(SequenceFileOutputFormat.class);

        conf.set(STATE_IN_KEY, str2);
        conf.set(MODEL_FACTORY_KEY, str4);
        conf.set(NUM_CLUSTERS_KEY, Integer.toString(i));
        conf.set(ALPHA_0_KEY, Double.toString(d));

        // JobClient.runJob is static; no JobClient instance is needed to submit the job.
        try {
            JobClient.runJob(conf);
        } catch (IOException e) {
            log.warn(e.toString(), e);
        }
    }

    /**
     * Configures and runs a map-only job (zero reduce tasks) using {@link DirichletMapper}.
     *
     * <p>An {@link IOException} raised while running the job is logged at warn level and not
     * propagated.
     *
     * @param str the input path
     * @param str2 accepted but not used by this method
     * @param str3 the job output path
     */
    public static void runClustering(String str, String str2, String str3) {
        JobConf conf = new JobConf(DirichletDriver.class);

        conf.setOutputKeyClass(Text.class);
        conf.setOutputValueClass(Text.class);

        FileInputFormat.setInputPaths(conf, new Path[]{new Path(str)});
        FileOutputFormat.setOutputPath(conf, new Path(str3));

        conf.setMapperClass(DirichletMapper.class);
        conf.setNumReduceTasks(0);

        // JobClient.runJob is static; no JobClient instance is needed to submit the job.
        try {
            JobClient.runJob(conf);
        } catch (IOException e) {
            log.warn(e.toString(), e);
        }
    }
}
