package org.apache.mahout.classifier.sgd;

import com.google.common.base.Charsets;
import com.google.common.collect.Lists;
import com.google.common.io.Resources;
import java.io.BufferedReader;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.io.PrintStream;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.Locale;
import org.apache.commons.cli2.CommandLine;
import org.apache.commons.cli2.Group;
import org.apache.commons.cli2.Option;
import org.apache.commons.cli2.builder.ArgumentBuilder;
import org.apache.commons.cli2.builder.DefaultOptionBuilder;
import org.apache.commons.cli2.builder.GroupBuilder;
import org.apache.commons.cli2.commandline.Parser;
import org.apache.commons.cli2.option.DefaultOption;
import org.apache.commons.cli2.util.HelpFormatter;
import org.apache.lucene.analysis.shingle.ShingleFilter;
import org.apache.mahout.math.RandomAccessSparseVector;
import org.jets3t.service.security.EncryptionUtil;

/* loaded from: input_file:org/apache/mahout/classifier/sgd/TrainLogistic.class */
public final class TrainLogistic {
    private static String inputFile;
    private static String outputFile;
    private static LogisticModelParameters lmp;
    private static int passes;
    private static boolean scores;
    private static OnlineLogisticRegression model;
    static PrintStream output = System.out;

    private TrainLogistic() {
    }

    public static void main(String[] strArr) throws IOException {
        if (parseArgs(strArr)) {
            double d = 0.0d;
            int i = 0;
            CsvRecordFactory csvRecordFactory = lmp.getCsvRecordFactory();
            OnlineLogisticRegression createRegression = lmp.createRegression();
            for (int i2 = 0; i2 < passes; i2++) {
                BufferedReader open = open(inputFile);
                csvRecordFactory.firstLine(open.readLine());
                String readLine = open.readLine();
                while (true) {
                    String str = readLine;
                    if (str != null) {
                        RandomAccessSparseVector randomAccessSparseVector = new RandomAccessSparseVector(lmp.getNumFeatures());
                        int processLine = csvRecordFactory.processLine(str, randomAccessSparseVector);
                        double logLikelihood = createRegression.logLikelihood(processLine, randomAccessSparseVector);
                        if (!Double.isInfinite(logLikelihood)) {
                            d = i < 20 ? ((i * d) + logLikelihood) / (i + 1) : (0.95d * d) + (0.05d * logLikelihood);
                            i++;
                        }
                        double classifyScalar = createRegression.classifyScalar(randomAccessSparseVector);
                        if (scores) {
                            output.printf(Locale.ENGLISH, "%10d %2d %10.2f %2.4f %10.4f %10.4f\n", Integer.valueOf(i), Integer.valueOf(processLine), Double.valueOf(createRegression.currentLearningRate()), Double.valueOf(classifyScalar), Double.valueOf(logLikelihood), Double.valueOf(d));
                        }
                        createRegression.train(processLine, randomAccessSparseVector);
                        readLine = open.readLine();
                    }
                }
                open.close();
            }
            FileOutputStream fileOutputStream = new FileOutputStream(outputFile);
            try {
                lmp.saveTo(fileOutputStream);
                fileOutputStream.close();
                output.printf(Locale.ENGLISH, "%d\n", Integer.valueOf(lmp.getNumFeatures()));
                output.printf(Locale.ENGLISH, "%s ~ ", lmp.getTargetVariable());
                Object obj = "";
                for (String str2 : csvRecordFactory.getTraceDictionary().keySet()) {
                    double predictorWeight = predictorWeight(createRegression, 0, csvRecordFactory, str2);
                    if (predictorWeight != 0.0d) {
                        output.printf(Locale.ENGLISH, "%s%.3f*%s", obj, Double.valueOf(predictorWeight), str2);
                        obj = " + ";
                    }
                }
                output.printf("\n", new Object[0]);
                model = createRegression;
                for (int i3 = 0; i3 < createRegression.getBeta().numRows(); i3++) {
                    for (String str3 : csvRecordFactory.getTraceDictionary().keySet()) {
                        double predictorWeight2 = predictorWeight(createRegression, i3, csvRecordFactory, str3);
                        if (predictorWeight2 != 0.0d) {
                            output.printf(Locale.ENGLISH, "%20s %.5f\n", str3, Double.valueOf(predictorWeight2));
                        }
                    }
                    for (int i4 = 0; i4 < createRegression.getBeta().numCols(); i4++) {
                        output.printf(Locale.ENGLISH, "%15.9f ", Double.valueOf(createRegression.getBeta().get(i3, i4)));
                    }
                    output.println();
                }
            } catch (Throwable th) {
                fileOutputStream.close();
                throw th;
            }
        }
    }

    private static double predictorWeight(OnlineLogisticRegression onlineLogisticRegression, int i, RecordFactory recordFactory, String str) {
        double d = 0.0d;
        Iterator<Integer> it = recordFactory.getTraceDictionary().get(str).iterator();
        while (it.hasNext()) {
            d += onlineLogisticRegression.getBeta().get(i, it.next().intValue());
        }
        return d;
    }

    private static boolean parseArgs(String[] strArr) {
        DefaultOptionBuilder defaultOptionBuilder = new DefaultOptionBuilder();
        DefaultOption create = defaultOptionBuilder.withLongName("help").withDescription("print this list").create();
        DefaultOption create2 = defaultOptionBuilder.withLongName("quiet").withDescription("be extra quiet").create();
        DefaultOption create3 = defaultOptionBuilder.withLongName("scores").withDescription("output score diagnostics during training").create();
        ArgumentBuilder argumentBuilder = new ArgumentBuilder();
        DefaultOption create4 = defaultOptionBuilder.withLongName("input").withRequired(true).withArgument(argumentBuilder.withName("input").withMaximum(1).create()).withDescription("where to get training data").create();
        DefaultOption create5 = defaultOptionBuilder.withLongName("output").withRequired(true).withArgument(argumentBuilder.withName("output").withMaximum(1).create()).withDescription("where to get training data").create();
        DefaultOption create6 = defaultOptionBuilder.withLongName("predictors").withRequired(true).withArgument(argumentBuilder.withName("p").create()).withDescription("a list of predictor variables").create();
        DefaultOption create7 = defaultOptionBuilder.withLongName("types").withRequired(true).withArgument(argumentBuilder.withName("t").create()).withDescription("a list of predictor variable types (numeric, word, or text)").create();
        DefaultOption create8 = defaultOptionBuilder.withLongName("target").withRequired(true).withArgument(argumentBuilder.withName("target").withMaximum(1).create()).withDescription("the name of the target variable").create();
        DefaultOption create9 = defaultOptionBuilder.withLongName("features").withArgument(argumentBuilder.withName("numFeatures").withDefault("1000").withMaximum(1).create()).withDescription("the number of internal hashed features to use").create();
        DefaultOption create10 = defaultOptionBuilder.withLongName("passes").withArgument(argumentBuilder.withName("passes").withDefault(EncryptionUtil.DEFAULT_VERSION).withMaximum(1).create()).withDescription("the number of times to pass over the input data").create();
        DefaultOption create11 = defaultOptionBuilder.withLongName("lambda").withArgument(argumentBuilder.withName("lambda").withDefault("1e-4").withMaximum(1).create()).withDescription("the amount of coefficient decay to use").create();
        DefaultOption create12 = defaultOptionBuilder.withLongName("rate").withArgument(argumentBuilder.withName("learningRate").withDefault("1e-3").withMaximum(1).create()).withDescription("the learning rate").create();
        DefaultOption create13 = defaultOptionBuilder.withLongName("noBias").withDescription("don't include a bias term").create();
        DefaultOption create14 = defaultOptionBuilder.withLongName("categories").withRequired(true).withArgument(argumentBuilder.withName("number").withMaximum(1).create()).withDescription("the number of target categories to be considered").create();
        Group create15 = new GroupBuilder().withOption(create).withOption(create2).withOption(create4).withOption(create5).withOption(create8).withOption(create14).withOption(create6).withOption(create7).withOption(create10).withOption(create11).withOption(create12).withOption(create13).withOption(create9).create();
        Parser parser = new Parser();
        parser.setHelpOption(create);
        parser.setHelpTrigger("--help");
        parser.setGroup(create15);
        parser.setHelpFormatter(new HelpFormatter(ShingleFilter.TOKEN_SEPARATOR, "", ShingleFilter.TOKEN_SEPARATOR, 130));
        CommandLine parseAndHelp = parser.parseAndHelp(strArr);
        if (parseAndHelp == null) {
            return false;
        }
        inputFile = getStringArgument(parseAndHelp, create4);
        outputFile = getStringArgument(parseAndHelp, create5);
        ArrayList newArrayList = Lists.newArrayList();
        Iterator it = parseAndHelp.getValues(create7).iterator();
        while (it.hasNext()) {
            newArrayList.add(it.next().toString());
        }
        ArrayList newArrayList2 = Lists.newArrayList();
        Iterator it2 = parseAndHelp.getValues(create6).iterator();
        while (it2.hasNext()) {
            newArrayList2.add(it2.next().toString());
        }
        lmp = new LogisticModelParameters();
        lmp.setTargetVariable(getStringArgument(parseAndHelp, create8));
        lmp.setMaxTargetCategories(getIntegerArgument(parseAndHelp, create14));
        lmp.setNumFeatures(getIntegerArgument(parseAndHelp, create9));
        lmp.setUseBias(!getBooleanArgument(parseAndHelp, create13));
        lmp.setTypeMap(newArrayList2, newArrayList);
        lmp.setLambda(getDoubleArgument(parseAndHelp, create11));
        lmp.setLearningRate(getDoubleArgument(parseAndHelp, create12));
        scores = getBooleanArgument(parseAndHelp, create3);
        passes = getIntegerArgument(parseAndHelp, create10);
        return true;
    }

    private static String getStringArgument(CommandLine commandLine, Option option) {
        return (String) commandLine.getValue(option);
    }

    private static boolean getBooleanArgument(CommandLine commandLine, Option option) {
        return commandLine.hasOption(option);
    }

    private static int getIntegerArgument(CommandLine commandLine, Option option) {
        return Integer.parseInt((String) commandLine.getValue(option));
    }

    private static double getDoubleArgument(CommandLine commandLine, Option option) {
        return Double.parseDouble((String) commandLine.getValue(option));
    }

    public static OnlineLogisticRegression getModel() {
        return model;
    }

    public static LogisticModelParameters getParameters() {
        return lmp;
    }

    /* JADX INFO: Access modifiers changed from: package-private */
    /* JADX WARN: Multi-variable type inference failed */
    /* JADX WARN: Type inference failed for: r0v4, types: [java.io.InputStream] */
    public static BufferedReader open(String str) throws IOException {
        FileInputStream fileInputStream;
        try {
            fileInputStream = Resources.getResource(str).openStream();
        } catch (IllegalArgumentException e) {
            fileInputStream = new FileInputStream(new File(str));
        }
        return new BufferedReader(new InputStreamReader(fileInputStream, Charsets.UTF_8));
    }
}
