package org.apache.mahout.classifier.sgd;

import com.google.common.collect.Lists;
import com.google.common.collect.Maps;
import com.google.common.collect.Multiset;
import java.io.File;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Random;
import java.util.Set;
import java.util.TreeMap;
import org.apache.mahout.classifier.NewsgroupHelper;
import org.apache.mahout.classifier.sgd.AdaptiveLogisticRegression;
import org.apache.mahout.classifier.sgd.ModelDissector;
import org.apache.mahout.ep.State;
import org.apache.mahout.math.Matrix;
import org.apache.mahout.math.function.DoubleFunction;
import org.apache.mahout.math.function.Functions;
import org.apache.mahout.vectorizer.encoders.Dictionary;

/* loaded from: input_file:org/apache/mahout/classifier/sgd/SGDHelper.class */
public final class SGDHelper {
    /** Labels printed by {@link #analyzeState}, selected by {@code leakType % LEAK_LABELS.length}. */
    private static final String[] LEAK_LABELS = {"none", "month-year", "day-month-year"};

    /** Upper bound on the number of randomly chosen files that {@link #dissect} feeds to the dissector. */
    private static final int DISSECTION_SAMPLE_SIZE = 500;

    /** Number of summary weights that {@link #dissect} requests from the dissector and prints. */
    private static final int DISSECTION_SUMMARY_SIZE = 100;

    /** Magnitude above which {@link #analyzeState} counts a coefficient as non-zero. */
    private static final double NON_ZERO_THRESHOLD = 1.0e-6;

    /** Maps a coefficient to 1 when its magnitude exceeds {@link #NON_ZERO_THRESHOLD}, otherwise 0. */
    private static final DoubleFunction NON_ZERO_INDICATOR = new DoubleFunction() {
        @Override
        public double apply(double value) {
            return Math.abs(value) > NON_ZERO_THRESHOLD ? 1.0 : 0.0;
        }
    };

    /** Maps a coefficient to 1 when it is strictly positive, otherwise 0. */
    private static final DoubleFunction POSITIVE_INDICATOR = new DoubleFunction() {
        @Override
        public double apply(double value) {
            return value > 0.0 ? 1.0 : 0.0;
        }
    };

    private SGDHelper() {
        // Static utility class; not instantiable.
    }

    /**
     * Prints a "Model Dissection" report for the current best learner of {@code learningAlgorithm}.
     *
     * <p>The best learner is closed, and then a random sample of up to 500 files is used. Each file is
     * encoded with a fresh {@link NewsgroupHelper} whose encoders share a trace dictionary. The encoded
     * vector is fed to a {@link ModelDissector}. Finally the top 100 summary weights are printed to
     * standard output.
     *
     * @param leakType      passed through to {@link NewsgroupHelper#encodeFeatureVector}
     * @param dictionary    category dictionary; each file's category is interned from its parent
     *                      directory name
     * @param learningAlgorithm source of the best {@link CrossFoldLearner}
     * @param files         candidate input files
     * @param overallCounts passed through to {@link NewsgroupHelper#encodeFeatureVector}
     * @throws IOException if encoding a file fails
     */
    public static void dissect(int leakType, Dictionary dictionary, AdaptiveLogisticRegression learningAlgorithm,
                               Iterable<File> files, Multiset<String> overallCounts) throws IOException {
        CrossFoldLearner model = learningAlgorithm.getBest().getPayload().getLearner();
        model.close();

        Map<String, Set<Integer>> traceDictionary = Maps.newTreeMap();
        ModelDissector dissector = new ModelDissector();

        NewsgroupHelper helper = new NewsgroupHelper();
        helper.getEncoder().setTraceDictionary(traceDictionary);
        helper.getBias().setTraceDictionary(traceDictionary);

        List<File> shuffled = permute(files, helper.getRandom());
        // subList(0, 500) alone throws IndexOutOfBoundsException for fewer than 500 files.
        int sampleSize = Math.min(DISSECTION_SAMPLE_SIZE, shuffled.size());
        for (File file : shuffled.subList(0, sampleSize)) {
            int actual = dictionary.intern(file.getParentFile().getName());
            // Reset the shared trace dictionary before encoding this file. It is presumably filled by the
            // encoders during encodeFeatureVector.
            traceDictionary.clear();
            dissector.update(helper.encodeFeatureVector(file, actual, leakType, overallCounts),
                    traceDictionary, model);
        }

        List<String> categoryNames = Lists.newArrayList(dictionary.values());
        List<ModelDissector.Weight> weights = dissector.summary(DISSECTION_SUMMARY_SIZE);
        System.out.println("============");
        System.out.println("Model Dissection");
        for (ModelDissector.Weight w : weights) {
            // NOTE(review): the "+ 1" offset into the category names and the pairing of %.1f with
            // getCategory(..) / %s with getWeight(..) are kept unchanged. Confirm both are intended.
            System.out.printf("%s\t%.1f\t%s\t%.1f\t%s\t%.1f\t%s\n",
                    w.getFeature(), w.getWeight(), categoryNames.get(w.getMaxImpact() + 1),
                    w.getCategory(1), w.getWeight(1), w.getCategory(2), w.getWeight(2));
        }
    }

    /**
     * Returns a new list containing every element of {@code files} in random order.
     *
     * <p>This is an inside-out Fisher-Yates shuffle. Each incoming element either goes at the end or
     * swaps with a uniformly chosen earlier slot. The input is consumed in a single pass.
     *
     * @param files  elements to shuffle; iterated exactly once
     * @param random source of randomness
     * @return a new mutable list holding the shuffled elements
     */
    public static List<File> permute(Iterable<File> files, Random random) {
        List<File> result = Lists.newArrayList();
        for (File file : files) {
            int slot = random.nextInt(result.size() + 1);
            if (slot == result.size()) {
                result.add(file);
            } else {
                // Move the displaced element to the end, then put the new element in its place.
                result.add(result.get(slot));
                result.set(slot, file);
            }
        }
        return result;
    }

    /**
     * Records the best learner's accuracy and log-likelihood into {@code info}. It also periodically
     * checkpoints the model and prints a progress line.
     *
     * <p>When {@code best} is non-null, this method first updates {@code info}. It then closes the
     * learner's first model and computes four statistics of that model's beta matrix: the maximum
     * absolute value, the count of non-zero entries, the count of positive entries, and the sum of
     * absolute values.
     *
     * <p>When {@code k} is a multiple of {@code bump * scale}, the method:
     * <ul>
     *   <li>writes the first model to {@code /tmp/news-group-<k>.model} (when {@code best} is non-null);</li>
     *   <li>advances {@code info}'s step by 0.25;</li>
     *   <li>prints the statistics and the first two mapped parameters.</li>
     * </ul>
     *
     * <p>{@code bump} and {@code scale} are derived from {@code info}'s current step and bump table.
     *
     * <p>Decompiler metadata indicates this method was originally package-private. It is kept public
     * for compatibility.
     *
     * @param info     progress state that is read and updated
     * @param leakType selects the printed label via {@code leakType % LEAK_LABELS.length}
     * @param k        progress counter used for the reporting cadence and the checkpoint file name
     *                 (presumably the number of training examples seen; confirm with callers)
     * @param best     current best evolutionary state; may be {@code null}
     * @throws IOException if writing the model checkpoint fails
     */
    public static void analyzeState(SGDInfo info, int leakType, int k,
                                    State<AdaptiveLogisticRegression.Wrapper, CrossFoldLearner> best)
            throws IOException {
        int bumpCount = info.getBumps().length;
        int bump = info.getBumps()[((int) Math.floor(info.getStep())) % bumpCount];
        int scale = (int) Math.pow(10.0, Math.floor(info.getStep() / bumpCount));

        double maxBeta = 0.0;
        double nonZeros = 0.0;
        double positive = 0.0;
        double norm = 0.0;
        double mappedParam0 = 0.0;
        double mappedParam1 = 0.0;

        if (best != null) {
            CrossFoldLearner learner = best.getPayload().getLearner();
            info.setAverageCorrect(learner.percentCorrect());
            info.setAverageLL(learner.logLikelihood());

            OnlineLogisticRegression model = learner.getModels().get(0);
            // close() runs before beta is read. Presumably this flushes pending regularization; confirm
            // against OnlineLogisticRegression.
            model.close();

            Matrix beta = model.getBeta();
            maxBeta = beta.aggregate(Functions.MAX, Functions.ABS);
            nonZeros = beta.aggregate(Functions.PLUS, NON_ZERO_INDICATOR);
            positive = beta.aggregate(Functions.PLUS, POSITIVE_INDICATOR);
            norm = beta.aggregate(Functions.PLUS, Functions.ABS);

            mappedParam0 = best.getMappedParams()[0];
            mappedParam1 = best.getMappedParams()[1];
        }

        if (k % (bump * scale) == 0) {
            if (best != null) {
                ModelSerializer.writeBinary("/tmp/news-group-" + k + ".model",
                        best.getPayload().getLearner().getModels().get(0));
            }
            info.setStep(info.getStep() + 0.25);
            System.out.printf("%.2f\t%.2f\t%.2f\t%.2f\t%.8g\t%.8g\t",
                    maxBeta, nonZeros, positive, norm, mappedParam0, mappedParam1);
            System.out.printf("%d\t%.3f\t%.2f\t%s\n",
                    k, info.getAverageLL(), info.getAverageCorrect() * 100.0,
                    LEAK_LABELS[leakType % LEAK_LABELS.length]);
        }
    }
}
