package org.apache.flink.ml.common.lossfunc;

import org.apache.flink.ml.common.feature.LabeledPointWithWeight;
import org.apache.flink.ml.linalg.DenseVector;
import org.apache.flink.ml.linalg.Vectors;
import org.junit.Assert;
import org.junit.Test;

/** Tests for {@link LeastSquareLoss}. */
public class LeastSquareLossTest {
    /**
     * Sample point with features (1, 2, 3) and the two scalars 1.0 and 2.0. The expected values
     * below are consistent with these being the label (1.0) and the weight (2.0), respectively.
     */
    private static final LabeledPointWithWeight dataPoint =
            new LabeledPointWithWeight(Vectors.dense(new double[] {1.0d, 2.0d, 3.0d}), 1.0d, 2.0d);

    /** Model coefficients; dot(features, coefficient) = 1 + 2 + 3 = 6. */
    private static final DenseVector coefficient = Vectors.dense(new double[] {1.0d, 1.0d, 1.0d});

    private static final double TOLERANCE = 1.0E-7d;

    /**
     * Checks the loss for {@code dataPoint}: 2.0 * 0.5 * (6 - 1)^2 = 25. This assumes the loss is
     * weight * (prediction - label)^2 / 2, which is the formula these expected values fit.
     */
    @Test
    public void computeLoss() {
        Assert.assertEquals(
                25.0d, LeastSquareLoss.INSTANCE.computeLoss(dataPoint, coefficient), TOLERANCE);
    }

    /**
     * Checks that the gradient 2.0 * (6 - 1) * (1, 2, 3) = (10, 20, 30) is added into the supplied
     * accumulator rather than overwriting it: a second call must double the accumulated values.
     */
    @Test
    public void computeGradient() {
        // The accumulator is mutated by computeGradient, so it must be fresh per test run. A shared
        // static instance would make the test fail on any re-execution within the same JVM.
        DenseVector cumGradient = Vectors.dense(new double[] {0.0d, 0.0d, 0.0d});

        LeastSquareLoss.INSTANCE.computeGradient(dataPoint, coefficient, cumGradient);
        Assert.assertArrayEquals(new double[] {10.0d, 20.0d, 30.0d}, cumGradient.values, TOLERANCE);

        LeastSquareLoss.INSTANCE.computeGradient(dataPoint, coefficient, cumGradient);
        Assert.assertArrayEquals(new double[] {20.0d, 40.0d, 60.0d}, cumGradient.values, TOLERANCE);
    }
}
