package org.apache.mahout.classifier.sgd;

import java.io.IOException;
import java.util.Random;
import org.apache.mahout.common.RandomUtils;
import org.apache.mahout.math.DenseVector;
import org.apache.mahout.math.Matrix;
import org.apache.mahout.math.Vector;
import org.junit.Test;

/* loaded from: input_file:org/apache/mahout/classifier/sgd/OnlineLogisticRegressionTest.class */
/**
 * Tests for {@link OnlineLogisticRegression} and {@link CrossFoldLearner}.
 *
 * <p>The data-loading, training and scoring helpers ({@code readStandardData},
 * {@code getInput}, {@code readCsv}, {@code permute}, {@code train}, {@code test})
 * are not declared in this class; they are presumably inherited from
 * {@link OnlineBaseTest}.
 */
public final class OnlineLogisticRegressionTest extends OnlineBaseTest {
    /** Tolerance for probabilities that are compared against a closed-form value. */
    private static final double TIGHT = 1.0E-8d;

    /** Wider tolerance, used by the same assertions that used it originally. */
    private static final double LOOSE = 0.001d;

    /** Builds a fresh two-element dense vector so no input instance is shared between calls. */
    private static DenseVector point(double first, double second) {
        return new DenseVector(new double[]{first, second});
    }

    /**
     * Trains a {@link CrossFoldLearner} on the standard data set, prints its AUC and
     * log-likelihood, and then hands it to the shared {@code test} helper.
     *
     * <p>NOTE(review): the meaning of the constructor arguments (5, 2, 8) and of the
     * thresholds 0.05 / 0.3 is defined outside this file - confirm against the
     * learner and {@code OnlineBaseTest} before changing them.
     *
     * @throws IOException declared because the data-loading helpers may fail to read
     *     their resources (presumably; the helpers are not visible here)
     */
    @Test
    public void crossValidation() throws IOException {
        Vector target = readStandardData();

        CrossFoldLearner learner = new CrossFoldLearner(5, 2, 8, new L1())
                .lambda(0.001d)
                .learningRate(50.0d);

        train(getInput(), target, learner);

        // Evaluate both metrics in the same order as the format string consumes them.
        double auc = learner.auc();
        double logLikelihood = learner.logLikelihood();
        System.out.printf("%.2f %.5f\n", auc, logLikelihood);

        test(getInput(), target, learner, 0.05d, 0.3d);
    }

    /**
     * Feeds the rows of {@code cancer.csv} to a {@link CrossFoldLearner} for 100 passes
     * in a fixed pseudo-random order, printing the running AUC after every example.
     * The AUC must be within 0.2 of 1.0 after each pass and within 0.1 of 1.0 at the end.
     *
     * @throws IOException declared because reading the CSV may fail (presumably; the
     *     helper is not visible here)
     */
    @Test
    public void crossValidatedAuc() throws IOException {
        // Seed first: everything created afterwards must see the deterministic generator.
        RandomUtils.useTestSeed();
        Random generator = RandomUtils.getRandom();

        Matrix data = readCsv("cancer.csv");

        CrossFoldLearner learner = new CrossFoldLearner(5, 2, 10, new L1())
                .stepOffset(10)
                .decayExponent(0.7d)
                .lambda(0.001d)
                .learningRate(5.0d);

        int examplesSeen = 0;
        int[] rowOrder = permute(generator, data.numRows());

        for (int pass = 0; pass < 100; pass++) {
            for (int row : rowOrder) {
                // Column 9 is passed as the actual category - presumably the label column.
                int label = (int) data.get(row, 9);
                learner.train(row, label, data.viewRow(row));
                System.out.printf("%d,%d,%.3f\n", pass, examplesSeen++, learner.auc());
            }
            assertEquals(1.0d, learner.auc(), 0.2d);
        }

        assertEquals(1.0d, learner.auc(), 0.1d);
    }

    /**
     * Checks {@code classify} and {@code classifyFull} against hand-computed values for
     * a model whose coefficients are set directly through {@code setBeta}.
     *
     * <p>The expected values below are those of a softmax over the scores
     * {@code (0, row0 . x, row1 . x)}, where {@code row0} and {@code row1} are the two
     * coefficient rows. {@code classifyFull} is checked on three entries that sum to one;
     * {@code classify} is checked on two entries that match entries 1 and 2 of the full
     * result.
     */
    @Test
    public void testClassify() {
        OnlineLogisticRegression model = new OnlineLogisticRegression(3, 2, new L2(1.0d));
        model.setBeta(0, 0, -1.0d);
        model.setBeta(1, 0, -2.0d);

        double third = 1.0d / 3.0d;

        // Input (0, 0): every score is zero, so all categories are equally likely.
        Vector scores = model.classify(point(0.0d, 0.0d));
        assertEquals(third, scores.get(0), TIGHT);
        assertEquals(third, scores.get(1), TIGHT);

        Vector fullScores = model.classifyFull(point(0.0d, 0.0d));
        assertEquals(1.0d, fullScores.zSum(), TIGHT);
        assertEquals(third, fullScores.get(0), TIGHT);
        assertEquals(third, fullScores.get(1), TIGHT);
        assertEquals(third, fullScores.get(2), TIGHT);

        // Input (0, 1): only the second feature is active and its coefficients are
        // still unset here, so the same uniform result is expected.
        scores = model.classify(point(0.0d, 1.0d));
        assertEquals(third, scores.get(0), LOOSE);
        assertEquals(third, scores.get(1), LOOSE);

        fullScores = model.classifyFull(point(0.0d, 1.0d));
        assertEquals(1.0d, fullScores.zSum(), TIGHT);
        assertEquals(third, fullScores.get(0), LOOSE);
        assertEquals(third, fullScores.get(1), LOOSE);
        assertEquals(third, fullScores.get(2), LOOSE);

        // Input (1, 0): scores are (0, -1, -2).
        double expMinusOne = Math.exp(-1.0d);
        double expMinusTwo = Math.exp(-2.0d);
        double normalizer = (1.0d + expMinusOne) + expMinusTwo;

        scores = model.classify(point(1.0d, 0.0d));
        assertEquals(expMinusOne / normalizer, scores.get(0), TIGHT);
        assertEquals(expMinusTwo / normalizer, scores.get(1), TIGHT);

        fullScores = model.classifyFull(point(1.0d, 0.0d));
        assertEquals(1.0d, fullScores.zSum(), TIGHT);
        assertEquals(1.0d / normalizer, fullScores.get(0), TIGHT);
        assertEquals(expMinusOne / normalizer, fullScores.get(1), TIGHT);
        assertEquals(expMinusTwo / normalizer, fullScores.get(2), TIGHT);

        // Raise one coefficient; for input (1, 1) the scores become (0, 0, -2).
        model.setBeta(0, 1, 1.0d);
        double expZero = Math.exp(0.0d);
        normalizer = (1.0d + expZero) + expMinusTwo;

        fullScores = model.classifyFull(point(1.0d, 1.0d));
        assertEquals(1.0d, fullScores.zSum(), TIGHT);
        assertEquals(expZero / normalizer, fullScores.get(1), LOOSE);
        assertEquals(expMinusTwo / normalizer, fullScores.get(2), LOOSE);
        assertEquals(1.0d / normalizer, fullScores.get(0), LOOSE);

        // Raise the other row as well; for input (1, 1) the scores become (0, 0, 1).
        model.setBeta(1, 1, 3.0d);
        double expOne = Math.exp(1.0d);
        normalizer = (1.0d + expZero) + expOne;

        fullScores = model.classifyFull(point(1.0d, 1.0d));
        assertEquals(1.0d, fullScores.zSum(), TIGHT);
        assertEquals(expZero / normalizer, fullScores.get(1), TIGHT);
        assertEquals(expOne / normalizer, fullScores.get(2), TIGHT);
        assertEquals(1.0d / normalizer, fullScores.get(0), TIGHT);
    }

    /**
     * Trains a plain {@link OnlineLogisticRegression} on the standard data set and
     * scores it with the shared {@code test} helper, using the same regularization,
     * learning rate and thresholds as {@link #crossValidation()}.
     *
     * @throws Exception if any of the shared helpers fails
     */
    @Test
    public void testTrain() throws Exception {
        Vector target = readStandardData();

        OnlineLogisticRegression model = new OnlineLogisticRegression(2, 8, new L1())
                .lambda(0.001d)
                .learningRate(50.0d);

        train(getInput(), target, model);
        test(getInput(), target, model, 0.05d, 0.3d);
    }
}
