package org.apache.mahout.classifier.sgd;

import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.IOException;
import java.util.Random;
import org.apache.hadoop.io.Writable;
import org.apache.mahout.classifier.OnlineLearner;
import org.apache.mahout.common.MahoutTestCase;
import org.apache.mahout.common.RandomUtils;
import org.apache.mahout.math.DenseVector;
import org.apache.mahout.math.Vector;
import org.apache.mahout.math.function.DoubleFunction;
import org.apache.mahout.math.function.Functions;
import org.apache.mahout.math.stats.GlobalOnlineAuc;
import org.apache.mahout.math.stats.OnlineAuc;
import org.junit.Test;

/* loaded from: input_file:org/apache/mahout/classifier/sgd/ModelSerializerTest.class */
public final class ModelSerializerTest extends MahoutTestCase {
    /**
     * Serializes {@code t} into an in-memory buffer with {@link PolymorphicWritable#write}
     * and immediately reads it back with {@link PolymorphicWritable#read}.
     *
     * @param t   the object to serialize
     * @param cls the type passed to the reader and used as the static type of the result
     * @return the object reconstructed from the serialized bytes
     * @throws IOException if writing to or reading from the in-memory streams fails
     */
    private static <T extends Writable> T roundTrip(T t, Class<T> cls) throws IOException {
        ByteArrayOutputStream buffer = new ByteArrayOutputStream(1000);
        DataOutputStream out = new DataOutputStream(buffer);
        try {
            PolymorphicWritable.write(out, t);
        } finally {
            // Close (and thereby flush) the stream even when serialization fails.
            out.close();
        }
        DataInputStream in = new DataInputStream(new ByteArrayInputStream(buffer.toByteArray()));
        // The reader is given cls, so the result is expected to be an instance of T.
        @SuppressWarnings("unchecked")
        T copy = (T) PolymorphicWritable.read(in, cls);
        return copy;
    }

    /**
     * Checks that an AUC estimator reports the same value after a serialization
     * round trip, and that the original and the restored copy stay close to each
     * other when both keep receiving samples.
     */
    @Test
    public void onlineAucRoundtrip() throws IOException {
        RandomUtils.useTestSeed();
        OnlineAuc original = new GlobalOnlineAuc();
        Random gen = RandomUtils.getRandom();

        // Class 0 scores are standard Gaussian draws; class 1 scores are shifted up by one.
        int warmUp = 10000;
        for (int n = 0; n < warmUp; n++) {
            original.addSample(0, gen.nextGaussian());
            original.addSample(1, gen.nextGaussian() + 1.0d);
        }
        assertEquals(0.76d, original.auc(), 0.01d);

        OnlineAuc restored = roundTrip(original, OnlineAuc.class);
        // Directly after the round trip the two estimates must match exactly.
        assertEquals(original.auc(), restored.auc(), 0.0d);

        // Both estimators now receive different draws from the same generator,
        // so only approximate agreement is asserted afterwards.
        int followUp = 1000;
        for (int n = 0; n < followUp; n++) {
            original.addSample(0, gen.nextGaussian());
            original.addSample(1, gen.nextGaussian() + 1.0d);
            restored.addSample(0, gen.nextGaussian());
            restored.addSample(1, gen.nextGaussian() + 1.0d);
        }
        assertEquals(original.auc(), restored.auc(), 0.01d);
    }

    /**
     * Checks that an {@link OnlineLogisticRegression} keeps its coefficients across a
     * serialization round trip, and that the original and the restored copy still
     * agree after each has been trained for another 100 examples.
     */
    @Test
    public void onlineLogisticRegressionRoundTrip() throws IOException {
        OnlineLogisticRegression original = new OnlineLogisticRegression(2, 5, new L1());
        train(original, 100);
        OnlineLogisticRegression restored = roundTrip(original, OnlineLogisticRegression.class);

        // Aggregate with ABS rather than IDENTITY: the maximum of the signed differences
        // can be zero even when some coefficients of the copy are larger than the original's,
        // whereas the maximum absolute difference is zero only when all coefficients match.
        double maxDeviation = original.getBeta().minus(restored.getBeta()).aggregate(Functions.MAX, Functions.ABS);
        assertEquals(0.0d, maxDeviation, 1.0E-6d);

        train(original, 100);
        train(restored, 100);

        maxDeviation = original.getBeta().minus(restored.getBeta()).aggregate(Functions.MAX, Functions.ABS);
        assertEquals(0.0d, maxDeviation, 1.0E-6d);
    }

    /**
     * Checks that a {@link CrossFoldLearner} reports the same AUC after a serialization
     * round trip and that the original and the restored copy remain close after
     * further training.
     */
    @Test
    public void crossFoldLearnerRoundTrip() throws IOException {
        CrossFoldLearner original = new CrossFoldLearner(5, 2, 5, new L1());
        train(original, 100);
        CrossFoldLearner restored = roundTrip(original, CrossFoldLearner.class);

        double baselineAuc = original.auc();
        assertTrue(baselineAuc > 0.85d);
        // Reading the AUC a second time must give the same value, as must the restored copy.
        assertEquals(baselineAuc, original.auc(), 1.0E-6d);
        assertEquals(baselineAuc, restored.auc(), 1.0E-6d);

        // NOTE(review): the original learner is trained twice here while the restored
        // copy is trained once; this looks like a leftover from a removed third
        // learner - confirm whether the asymmetry is intended.
        train(original, 100);
        train(original, 100);
        train(restored, 100);

        // NOTE(review): the first assertion compares the learner with itself.
        assertEquals(original.auc(), original.auc(), 0.02d);
        assertEquals(original.auc(), restored.auc(), 0.02d);
        // Additional training is expected to improve on the baseline AUC.
        assertTrue(original.auc() > baselineAuc);
    }

    /**
     * Checks that an {@link AdaptiveLogisticRegression} reports the same AUC after a
     * serialization round trip and that the original and the restored copy remain
     * close after further training.
     */
    @Test
    public void adaptiveLogisticRegressionRoundTrip() throws IOException {
        AdaptiveLogisticRegression original = new AdaptiveLogisticRegression(2, 5, new L1());
        original.setInterval(200);
        train(original, 400);
        AdaptiveLogisticRegression restored = roundTrip(original, AdaptiveLogisticRegression.class);

        double baselineAuc = original.auc();
        assertTrue(baselineAuc > 0.85d);
        // Reading the AUC a second time must give the same value, as must the restored copy.
        assertEquals(baselineAuc, original.auc(), 1.0E-6d);
        assertEquals(baselineAuc, restored.auc(), 1.0E-6d);

        // NOTE(review): the original learner is trained twice here while the restored
        // copy is trained once; this looks like a leftover from a removed third
        // learner - confirm whether the asymmetry is intended.
        train(original, 1000);
        train(original, 1000);
        train(restored, 1000);

        // NOTE(review): the first assertion compares the learner with itself.
        assertEquals(original.auc(), original.auc(), 0.005d);
        assertEquals(original.auc(), restored.auc(), 0.005d);

        // Additional training is expected to improve on the baseline AUC.
        double improvedAuc = original.auc();
        String failureMessage = String.format("%.3f > %.3f", improvedAuc, baselineAuc);
        assertTrue(failureMessage, improvedAuc > baselineAuc);
    }

    /**
     * Feeds {@code i} randomly generated examples to the given learner.
     *
     * <p>Each example is a 5-element Gaussian feature vector. Its label is 1 when a
     * uniform draw from the same generator is below the dot product of a fixed
     * weight vector with the features, and 0 otherwise.
     *
     * @param onlineLearner the learner to train
     * @param i             the number of examples to generate
     */
    private static void train(OnlineLearner onlineLearner, int i) {
        Vector weights = new DenseVector(new double[]{1.0d, -1.0d, 0.0d, 0.5d, -0.5d});
        Random gen = RandomUtils.getRandom();
        int remaining = i;
        while (remaining-- > 0) {
            // Draw the features first, then the uniform threshold, from the same generator.
            Vector features = randomVector(gen, 5);
            double draw = gen.nextDouble();
            int label;
            if (draw < weights.dot(features)) {
                label = 1;
            } else {
                label = 0;
            }
            onlineLearner.train(label, features);
        }
    }

    /**
     * Creates a dense vector whose elements are assigned Gaussian draws from
     * {@code random}.
     *
     * @param random the generator supplying the element values
     * @param i      the number of elements in the vector
     * @return the newly created vector
     */
    private static Vector randomVector(final Random random, int i) {
        // The incoming element value is ignored; every element is replaced by a fresh draw.
        DoubleFunction gaussianDraw = new DoubleFunction() {
            @Override
            public double apply(double ignored) {
                return random.nextGaussian();
            }
        };
        Vector result = new DenseVector(i);
        result.assign(gaussianDraw);
        return result;
    }
}
