package org.apache.mahout.classifier.sgd;

import java.io.IOException;
import org.apache.mahout.common.RandomUtils;
import org.apache.mahout.math.Vector;
import org.junit.Test;

/**
 * Exercises {@link GradientMachine} against the standard data set provided by
 * {@link OnlineBaseTest}: the model is configured, its weights are seeded
 * deterministically, it is trained on the base-class input, and then evaluated.
 */
public final class GradientMachineTest extends OnlineBaseTest {

    /**
     * Trains and evaluates a {@code GradientMachine(8, 4, 2)} with learning rate 0.1
     * and regularization 0.01.
     *
     * <p>NOTE(review): the constructor arguments are presumably the input, hidden and
     * output layer sizes, and the trailing {@code 0.05} / {@code 1.0} passed to
     * {@code test(...)} are presumably acceptance thresholds. Confirm both against
     * {@code GradientMachine} and {@code OnlineBaseTest#test}.
     *
     * @throws IOException if the standard data set cannot be read
     */
    @Test
    public void testGradientmachine() throws IOException {
        // Target values for the standard data set; the matching features come from getInput().
        Vector target = readStandardData();

        GradientMachine machine = new GradientMachine(8, 4, 2)
            .learningRate(0.1)
            .regularization(0.01);

        // Switch to the deterministic test seed before drawing the initial weights,
        // so that training starts from reproducible weights.
        RandomUtils.useTestSeed();
        machine.initWeights(RandomUtils.getRandom());

        train(getInput(), target, machine);
        test(getInput(), target, machine, 0.05, 1.0);
    }
}
