package org.apache.flink.ml.common.statistics.basicstatistic;

import org.apache.flink.ml.common.linalg.DenseMatrix;
import org.apache.flink.ml.common.linalg.DenseVector;
import org.junit.Assert;
import org.junit.Test;

/* loaded from: input_file:org/apache/flink/ml/common/statistics/basicstatistic/MultivariateGaussianTest.class */
/**
 * Unit tests for {@link MultivariateGaussian#pdf}, comparing the computed density against
 * precomputed reference values within an absolute tolerance of {@link #TOL}.
 *
 * <p>All assertions follow the JUnit convention {@code assertEquals(expected, actual, delta)} so
 * that failure messages report the reference value as "expected" and the computed density as
 * "actual".
 */
public class MultivariateGaussianTest {
    /** Absolute tolerance used when comparing computed densities with reference values. */
    private static final double TOL = 1.0E-5d;

    /** Checks the density of one-dimensional Gaussians with zero mean and covariance 1 and 4. */
    @Test
    public void testUnivariate() throws Exception {
        DenseVector atOrigin = new DenseVector(new double[]{0.0d});
        DenseVector atOnePointFive = new DenseVector(new double[]{1.5d});
        DenseVector mean = DenseVector.zeros(1);

        // Covariance is the 1x1 all-ones matrix; 0.39894 is 1 / sqrt(2 * pi).
        MultivariateGaussian unitGaussian = new MultivariateGaussian(mean, DenseMatrix.ones(1, 1));
        Assert.assertEquals(0.39894d, unitGaussian.pdf(atOrigin), TOL);
        Assert.assertEquals(0.12952d, unitGaussian.pdf(atOnePointFive), TOL);

        // Same mean, with the 1x1 covariance scaled by 4 before constructing the distribution.
        DenseMatrix scaledCovariance = DenseMatrix.ones(1, 1);
        scaledCovariance.scaleEqual(4.0d);
        MultivariateGaussian wideGaussian = new MultivariateGaussian(mean, scaledCovariance);
        Assert.assertEquals(0.19947d, wideGaussian.pdf(atOrigin), TOL);
        Assert.assertEquals(0.15057d, wideGaussian.pdf(atOnePointFive), TOL);
    }

    /** Checks the density of two-dimensional Gaussians with identity and full covariance. */
    @Test
    public void testMultivariate() throws Exception {
        DenseVector mean = DenseVector.zeros(2);

        // Identity covariance; 0.15915 is 1 / (2 * pi).
        MultivariateGaussian standardGaussian = new MultivariateGaussian(mean, DenseMatrix.eye(2));
        Assert.assertEquals(0.15915d, standardGaussian.pdf(DenseVector.zeros(2)), TOL);
        Assert.assertEquals(0.05855d, standardGaussian.pdf(DenseVector.ones(2)), TOL);

        // The covariance data is symmetric, so the result does not depend on whether the
        // backing array is read row-major or column-major.
        DenseMatrix correlatedCovariance =
            new DenseMatrix(2, 2, new double[]{4.0d, -1.0d, -1.0d, 2.0d});
        MultivariateGaussian correlatedGaussian = new MultivariateGaussian(mean, correlatedCovariance);
        Assert.assertEquals(0.060155d, correlatedGaussian.pdf(DenseVector.zeros(2)), TOL);
        Assert.assertEquals(0.033971d, correlatedGaussian.pdf(DenseVector.ones(2)), TOL);
    }

    /**
     * Checks the density when the covariance matrix is singular (a 2x2 matrix of all ones).
     *
     * <p>NOTE(review): the reference values presumably correspond to a pseudo-inverse /
     * pseudo-determinant treatment of the singular covariance; confirm against the
     * {@code MultivariateGaussian} implementation.
     */
    @Test
    public void testMultivariateDegenerate() throws Exception {
        DenseMatrix singularCovariance =
            new DenseMatrix(2, 2, new double[]{1.0d, 1.0d, 1.0d, 1.0d});
        MultivariateGaussian degenerateGaussian =
            new MultivariateGaussian(DenseVector.zeros(2), singularCovariance);
        Assert.assertEquals(0.11254d, degenerateGaussian.pdf(DenseVector.zeros(2)), TOL);
        Assert.assertEquals(0.068259d, degenerateGaussian.pdf(DenseVector.ones(2)), TOL);
    }
}
