package org.apache.flink.ml.common.lossfunc;

import org.apache.flink.ml.common.feature.LabeledPointWithWeight;
import org.apache.flink.ml.linalg.DenseVector;
import org.apache.flink.ml.linalg.Vectors;
import org.junit.Assert;
import org.junit.Test;

/**
 * Tests for {@link BinaryLogisticLoss}.
 *
 * <p>Fixture: features x = (1, 2, 3) evaluated at coefficient w = (1, 1, 1), so the margin w·x is
 * 6. The expected values are consistent with loss = 2 * log(1 + exp(-6)) and gradient multiplier
 * -2 / (1 + exp(6)), i.e. label 1.0 and weight 2.0 (assumed meaning of the trailing constructor
 * arguments of {@code LabeledPointWithWeight} — confirm against its definition).
 */
public class BinaryLogisticLossTest {
    private static final LabeledPointWithWeight dataPoint =
            new LabeledPointWithWeight(Vectors.dense(new double[] {1.0d, 2.0d, 3.0d}), 1.0d, 2.0d);
    private static final DenseVector coefficient = Vectors.dense(new double[] {1.0d, 1.0d, 1.0d});
    private static final double TOLERANCE = 1.0E-7d;

    @Test
    public void computeLoss() {
        Assert.assertEquals(
                0.0049513d, BinaryLogisticLoss.INSTANCE.computeLoss(dataPoint, coefficient), TOLERANCE);
    }

    @Test
    public void computeGradient() {
        // computeGradient adds into the supplied vector in place (the second call below doubles the
        // values). Use a fresh accumulator on every run rather than a shared static vector, so the
        // result does not depend on how many times this test has already executed in the JVM.
        DenseVector cumGradient = Vectors.dense(new double[] {0.0d, 0.0d, 0.0d});

        BinaryLogisticLoss.INSTANCE.computeGradient(dataPoint, coefficient, cumGradient);
        Assert.assertArrayEquals(
                new double[] {-0.0049452d, -0.0098904d, -0.0148357d}, cumGradient.values, TOLERANCE);

        // A second call must accumulate on top of the first gradient, not overwrite it.
        BinaryLogisticLoss.INSTANCE.computeGradient(dataPoint, coefficient, cumGradient);
        Assert.assertArrayEquals(
                new double[] {-0.0098904d, -0.0197809d, -0.0296714d}, cumGradient.values, TOLERANCE);
    }
}
