package org.apache.mahout.classifier.discriminative;

import org.apache.mahout.common.MahoutTestCase;
import org.apache.mahout.math.DenseMatrix;
import org.apache.mahout.math.DenseVector;
import org.junit.Before;
import org.junit.Test;

/**
 * Unit tests for {@link PerceptronTrainer}.
 */
public final class PerceptronTrainerTest extends MahoutTestCase {

    /** Trainer under test; a fresh instance is created before each test method. */
    private PerceptronTrainer trainer;

    @Override
    @Before
    public void setUp() throws Exception {
        super.setUp();
        // NOTE(review): constructor argument meanings are not visible in this file; presumably
        // (dimension, threshold, learningRate, init, initBias) -- confirm against PerceptronTrainer.
        // The leading 3 matches the three rows of the data matrix built in testUpdate.
        trainer = new PerceptronTrainer(3, 0.5d, 0.1d, 1.0d, 1.0d);
    }

    /**
     * Trains the perceptron on four examples, then checks that two of them are
     * classified consistently with their training labels.
     *
     * <p>The label vector has one entry per matrix column, and the columns are what
     * is handed to {@code classify} afterwards, so each column is one example.
     */
    @Test
    public void testUpdate() throws Exception {
        // One label per example (column): columns 0-2 are labelled 1.0, column 3 is labelled 0.0.
        DenseVector labels = new DenseVector(new double[] {1.0d, 1.0d, 1.0d, 0.0d});

        // 3 rows x 4 columns; each column is one training example.
        double[][] cells = {
            {1.0d, 1.0d, 1.0d, 1.0d},
            {0.0d, 0.0d, 1.0d, 1.0d},
            {0.0d, 1.0d, 0.0d, 1.0d},
        };
        DenseMatrix examples = new DenseMatrix(cells);

        trainer.train(labels, examples);

        // The example labelled 0.0 must be rejected; the first example labelled 1.0 must be accepted.
        assertFalse(trainer.getModel().classify(examples.viewColumn(3)));
        assertTrue(trainer.getModel().classify(examples.viewColumn(0)));
    }
}
