package org.apache.mahout.classifier.sgd;

import com.google.common.collect.HashMultiset;
import com.google.common.collect.Ordering;
import java.io.File;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.fs.PathFilter;
import org.apache.hadoop.io.Text;
import org.apache.lucene.analysis.wikipedia.WikipediaTokenizer;
import org.apache.mahout.common.AbstractJob;
import org.apache.mahout.common.Pair;
import org.apache.mahout.common.iterator.sequencefile.PathType;
import org.apache.mahout.common.iterator.sequencefile.SequenceFileDirIterator;
import org.apache.mahout.math.VectorWritable;
import org.apache.mahout.vectorizer.EncodingMapper;
import org.apache.mahout.vectorizer.encoders.Dictionary;

/* loaded from: input_file:org/apache/mahout/classifier/sgd/TrainASFEmail.class */
public final class TrainASFEmail extends AbstractJob {
    private TrainASFEmail() {
    }

    @Override // org.apache.hadoop.util.Tool
    public int run(String[] strArr) throws Exception {
        long j;
        addInputOption();
        addOutputOption();
        addOption("categories", "nc", "The number of categories to train on", true);
        addOption(EncodingMapper.CARDINALITY, WikipediaTokenizer.CATEGORY, "The size of the vectors to use", "100000");
        addOption("threads", "t", "The number of threads to use in the learner", "20");
        addOption("poolSize", "p", "The number of CrossFoldLearners to use in the AdaptiveLogisticRegression. Higher values require more memory.", "5");
        if (parseArguments(strArr) == null) {
            return -1;
        }
        File file = new File(getInputPath().toString());
        HashMultiset create = HashMultiset.create();
        File file2 = new File(getOutputPath().toString());
        file2.mkdirs();
        int parseInt = Integer.parseInt(getOption("categories"));
        int parseInt2 = Integer.parseInt(getOption(EncodingMapper.CARDINALITY, "100000"));
        int parseInt3 = Integer.parseInt(getOption("threads", "20"));
        int parseInt4 = Integer.parseInt(getOption("poolSize", "5"));
        Dictionary dictionary = new Dictionary();
        AdaptiveLogisticRegression adaptiveLogisticRegression = new AdaptiveLogisticRegression(parseInt, parseInt2, new L1(), parseInt3, parseInt4);
        adaptiveLogisticRegression.setInterval(800);
        adaptiveLogisticRegression.setAveragingWindow(500);
        Configuration configuration = new Configuration();
        PathFilter pathFilter = new PathFilter() { // from class: org.apache.mahout.classifier.sgd.TrainASFEmail.1
            @Override // org.apache.hadoop.fs.PathFilter
            public boolean accept(Path path) {
                return path.getName().contains("training");
            }
        };
        SequenceFileDirIterator sequenceFileDirIterator = new SequenceFileDirIterator(new Path(file.toString()), PathType.LIST, pathFilter, null, true, configuration);
        long j2 = 0;
        while (true) {
            j = j2;
            if (!sequenceFileDirIterator.hasNext()) {
                break;
            }
            dictionary.intern(((Text) sequenceFileDirIterator.next().getFirst()).toString());
            j2 = j + 1;
        }
        System.out.println(j + " training files");
        SGDInfo sGDInfo = new SGDInfo();
        SequenceFileDirIterator sequenceFileDirIterator2 = new SequenceFileDirIterator(new Path(file.toString()), PathType.LIST, pathFilter, null, true, configuration);
        int i = 0;
        while (sequenceFileDirIterator2.hasNext()) {
            Pair next = sequenceFileDirIterator2.next();
            adaptiveLogisticRegression.train(dictionary.intern(((Text) next.getFirst()).toString()), ((VectorWritable) next.getSecond()).get());
            i++;
            SGDHelper.analyzeState(sGDInfo, 0, i, adaptiveLogisticRegression.getBest());
        }
        adaptiveLogisticRegression.close();
        System.out.println("exiting main, writing model to " + file2);
        ModelSerializer.writeBinary(file2 + "/asf.model", adaptiveLogisticRegression.getBest().getPayload().getLearner().getModels().get(0));
        ArrayList arrayList = new ArrayList();
        System.out.println("Word counts");
        Iterator it = create.elementSet().iterator();
        while (it.hasNext()) {
            arrayList.add(Integer.valueOf(create.count((String) it.next())));
        }
        Collections.sort(arrayList, Ordering.natural().reverse());
        int i2 = 0;
        Iterator it2 = arrayList.iterator();
        while (it2.hasNext()) {
            System.out.println(i2 + "\t" + ((Integer) it2.next()));
            i2++;
            if (i2 > 1000) {
                return 0;
            }
        }
        return 0;
    }

    public static void main(String[] strArr) throws Exception {
        new TrainASFEmail().run(strArr);
    }
}
