package org.apache.flink.ml.common.lossfunc;

import org.apache.flink.ml.common.feature.LabeledPointWithWeight;
import org.apache.flink.ml.linalg.DenseVector;
import org.apache.flink.ml.linalg.Vectors;
import org.junit.Assert;
import org.junit.Test;

/* loaded from: input_file:org/apache/flink/ml/common/lossfunc/HingeLossTest.class */
/**
 * Unit tests for {@link HingeLoss}.
 *
 * <p>The expected numbers below are consistent with a weighted hinge loss of the form
 * {@code weight * max(0, 1 - y * (x . w))} with {@code y = 2 * label - 1}. That formula is
 * inferred from the asserted values, not taken from HingeLoss itself, so treat it as a reading
 * aid (NOTE(review): confirm against the HingeLoss implementation).
 */
public class HingeLossTest {
    /** x . coefficient = 1 - 1 - 1 = -1, so this point lies inside the margin and contributes loss. */
    private static final LabeledPointWithWeight dataPoint1 =
            new LabeledPointWithWeight(Vectors.dense(new double[] {1.0d, -1.0d, -1.0d}), 1.0d, 2.0d);

    /** x . coefficient = 1 - 1 + 1 = 1, so this point sits exactly on the margin and contributes nothing. */
    private static final LabeledPointWithWeight dataPoint2 =
            new LabeledPointWithWeight(Vectors.dense(new double[] {1.0d, -1.0d, 1.0d}), 1.0d, 2.0d);

    /** Model coefficients shared by all tests; none of the tests write to it. */
    private static final DenseVector coefficient = Vectors.dense(new double[] {1.0d, 1.0d, 1.0d});

    /** Absolute tolerance for floating-point comparisons. */
    private static final double TOLERANCE = 1.0E-7d;

    @Test
    public void computeLoss() {
        // dataPoint1: weight 2 * (1 - (-1)) = 4.
        Assert.assertEquals(4.0d, HingeLoss.INSTANCE.computeLoss(dataPoint1, coefficient), TOLERANCE);
        // dataPoint2: margin is exactly 1, so the hinge term is 0.
        Assert.assertEquals(0.0d, HingeLoss.INSTANCE.computeLoss(dataPoint2, coefficient), TOLERANCE);
    }

    @Test
    public void computeGradient() {
        // Allocate the accumulator per test run. computeGradient writes into it in place, and a
        // shared static vector would keep values from earlier runs and make the test order-dependent.
        DenseVector cumGradient = Vectors.dense(new double[] {0.0d, 0.0d, 0.0d});

        // dataPoint1 is inside the margin: its contribution is -weight * x = [-2, 2, 2].
        HingeLoss.INSTANCE.computeGradient(dataPoint1, coefficient, cumGradient);
        Assert.assertArrayEquals(new double[] {-2.0d, 2.0d, 2.0d}, cumGradient.values, TOLERANCE);

        // dataPoint2 is on the margin, so the accumulated gradient must be left unchanged.
        HingeLoss.INSTANCE.computeGradient(dataPoint2, coefficient, cumGradient);
        Assert.assertArrayEquals(new double[] {-2.0d, 2.0d, 2.0d}, cumGradient.values, TOLERANCE);
    }
}
