package org.apache.mahout.classifier.discriminative;

import org.apache.mahout.common.MahoutTestCase;
import org.apache.mahout.math.DenseVector;
import org.junit.Before;
import org.junit.Test;

/**
 * Unit tests for {@link LinearModel}.
 *
 * <p>Each test starts from a model whose hyperplane is {@code (0, 1, 0, 1, 0)}. The model is
 * constructed with the scalar arguments {@code 0.1} and {@code 0.5}. These are presumably a
 * bias/displacement and a decision threshold; confirm against {@code LinearModel}'s constructor.
 *
 * <p>Two complementary probe points are used throughout:
 * <ul>
 *   <li>{@code (1, 0, 1, 0, 1)}: the components at even indices are set;</li>
 *   <li>{@code (0, 1, 0, 1, 0)}: the components at odd indices are set (identical to the initial
 *       hyperplane).</li>
 * </ul>
 */
public final class LinearModelTest extends MahoutTestCase {

    /** Number of components in every vector used by these tests. */
    private static final int DIMENSION = 5;

    /** Model under test; rebuilt before every test method. */
    private LinearModel model;

    @Override
    @Before
    public void setUp() throws Exception {
        super.setUp();
        DenseVector initialHyperplane = vector(0.0d, 1.0d, 0.0d, 1.0d, 0.0d);
        model = new LinearModel(initialHyperplane, 0.1d, 0.5d);
    }

    /** The untouched model rejects the even-index probe and accepts the odd-index probe. */
    @Test
    public void testClassify() {
        boolean evenAccepted = model.classify(evenIndexProbe());
        boolean oddAccepted = model.classify(oddIndexProbe());

        assertFalse(evenAccepted);
        assertTrue(oddAccepted);
    }

    /**
     * Adding {@code (1, -1, 1, -1, 1)} should presumably move the hyperplane to
     * {@code (1, 0, 1, 0, 1)}, which flips both probe decisions.
     */
    @Test
    public void testAddDelta() {
        model.addDelta(vector(1.0d, -1.0d, 1.0d, -1.0d, 1.0d));

        boolean evenAccepted = model.classify(evenIndexProbe());
        boolean oddAccepted = model.classify(oddIndexProbe());

        assertTrue(evenAccepted);
        assertFalse(oddAccepted);
    }

    /**
     * Shifts the hyperplane by {@code -1} in every component, then scales each component by
     * {@code -1}, one index at a time from index 0 upwards. The net effect should presumably match
     * {@link #testAddDelta()}: both probe decisions are flipped.
     */
    @Test
    public void testTimesDelta() {
        model.addDelta(vector(-1.0d, -1.0d, -1.0d, -1.0d, -1.0d));

        int index = 0;
        while (index < DIMENSION) {
            model.timesDelta(index, -1.0d);
            index++;
        }

        boolean evenAccepted = model.classify(evenIndexProbe());
        boolean oddAccepted = model.classify(oddIndexProbe());

        assertTrue(evenAccepted);
        assertFalse(oddAccepted);
    }

    /** Returns a new probe with ones at even indices {@code (1, 0, 1, 0, 1)}. */
    private static DenseVector evenIndexProbe() {
        return vector(1.0d, 0.0d, 1.0d, 0.0d, 1.0d);
    }

    /** Returns a new probe with ones at odd indices {@code (0, 1, 0, 1, 0)}. */
    private static DenseVector oddIndexProbe() {
        return vector(0.0d, 1.0d, 0.0d, 1.0d, 0.0d);
    }

    /** Builds a fresh dense vector from the given components (a new array on every call). */
    private static DenseVector vector(double... components) {
        return new DenseVector(components);
    }
}
