package org.apache.flink.ml.common.statistics.basicstatistic;

import com.github.fommil.netlib.BLAS;
import com.github.fommil.netlib.F2jBLAS;
import com.github.fommil.netlib.LAPACK;
import org.apache.flink.ml.common.linalg.DenseMatrix;
import org.apache.flink.ml.common.linalg.DenseVector;
import org.apache.flink.ml.common.linalg.SparseVector;
import org.apache.flink.ml.common.linalg.Vector;
import org.apache.flink.shaded.guava18.com.google.common.primitives.Doubles;
import org.netlib.util.intW;

/**
 * Multivariate Gaussian (normal) distribution defined by a mean vector and a covariance matrix.
 *
 * <p>The covariance is eigen-decomposed once at construction time. Eigenvalues at or below a
 * numerical tolerance are treated as zero, so a singular covariance matrix is handled through its
 * pseudo-inverse and pseudo-determinant (the same approach as Spark MLlib's
 * {@code MultivariateGaussian}).
 *
 * <p>Instances are not thread-safe: {@link #logpdf(Vector)} and {@link #pdf(Vector)} write into
 * internal scratch buffers that are shared across calls.
 */
public class MultivariateGaussian {
    private static final LAPACK LAPACK_INST = LAPACK.getInstance();
    private static final BLAS F2J_BLAS_INST = F2jBLAS.getInstance();

    /**
     * Double-precision machine epsilon (2^-52): the smallest power of two {@code e} for which
     * {@code 1.0 + e != 1.0}. Used to derive the eigenvalue cutoff for singular directions.
     */
    private static final double EPSILON = Math.ulp(1.0d);

    private final DenseVector mean;
    private final DenseMatrix cov;

    /** {@code U * D^(-1/2)} where {@code cov = U * D * U^T}; singular directions are zeroed. */
    private DenseMatrix rootSigmaInv;

    /** Constant term {@code -0.5 * (k * log(2 * pi) + log(pseudoDet(cov)))} of the log-density. */
    private double u;

    /** Scratch buffer holding {@code x - mean}; reused by every {@link #logpdf(Vector)} call. */
    private final DenseVector delta;

    /** Scratch buffer holding {@code rootSigmaInv^T * (x - mean)}; reused across calls. */
    private final DenseVector v;

    /**
     * Creates a multivariate Gaussian distribution.
     *
     * @param denseVector mean vector of the distribution, of length {@code k >= 1}
     * @param denseMatrix covariance matrix of the distribution; its first {@code k * k} data
     *     entries are read and the upper triangle is used by the eigen-decomposition
     * @throws IllegalArgumentException if the mean is empty or the eigen-decomposition fails
     */
    public MultivariateGaussian(DenseVector denseVector, DenseMatrix denseMatrix) {
        this.mean = denseVector;
        this.cov = denseMatrix;
        this.delta = DenseVector.zeros(denseVector.size());
        this.v = DenseVector.zeros(denseVector.size());
        calculateCovarianceConstants();
    }

    /**
     * Evaluates the probability density function at the given point.
     *
     * @param vector the point at which to evaluate the density (dense or sparse)
     * @return {@code exp(logpdf(vector))}
     * @throws IllegalArgumentException if {@code vector} is neither dense nor sparse
     */
    public double pdf(Vector vector) {
        return Math.exp(logpdf(vector));
    }

    /**
     * Evaluates the natural logarithm of the probability density function at the given point.
     *
     * @param vector the point at which to evaluate the log-density (dense or sparse)
     * @return the log-density at {@code vector}
     * @throws IllegalArgumentException if {@code vector} is {@code null} or neither a
     *     {@link DenseVector} nor a {@link SparseVector}
     */
    public double logpdf(Vector vector) {
        // delta = vector - mean, computed in place as (-mean) + 1.0 * vector.
        System.arraycopy(this.mean.getData(), 0, this.delta.getData(), 0, this.mean.size());
        org.apache.flink.ml.common.linalg.BLAS.scal(-1.0d, this.delta);
        if (vector instanceof DenseVector) {
            org.apache.flink.ml.common.linalg.BLAS.axpy(1.0d, (DenseVector) vector, this.delta);
        } else if (vector instanceof SparseVector) {
            org.apache.flink.ml.common.linalg.BLAS.axpy(1.0d, (SparseVector) vector, this.delta);
        } else {
            // Previously this fell through silently and returned the density at the origin.
            String type = vector == null ? "null" : vector.getClass().getName();
            throw new IllegalArgumentException("Unsupported vector type: " + type);
        }
        // v = (U * D^(-1/2))^T * delta, so v . v is the Mahalanobis distance squared.
        org.apache.flink.ml.common.linalg.BLAS.gemv(1.0d, this.rootSigmaInv, true, this.delta, 0.0d, this.v);
        return this.u - 0.5d * org.apache.flink.ml.common.linalg.BLAS.dot(this.v, this.v);
    }

    /**
     * Eigen-decomposes the covariance matrix and derives {@link #rootSigmaInv} and {@link #u}.
     *
     * @throws IllegalArgumentException if the mean is empty or LAPACK {@code dsyev} reports an
     *     error or fails to converge
     */
    private void calculateCovarianceConstants() {
        final int k = this.mean.size();
        if (k < 1) {
            throw new IllegalArgumentException("The mean vector must have at least one element.");
        }
        // 3k - 1 is the minimum workspace length dsyev accepts for an order-k matrix.
        final int lwork = 3 * k - 1;
        final double[] eigenVectors = new double[k * k];
        final double[] work = new double[lwork];
        final double[] eigenValues = new double[k];
        final intW info = new intW(0);

        // dsyev overwrites its input with the eigenvectors, so decompose a copy of the covariance.
        System.arraycopy(this.cov.getData(), 0, eigenVectors, 0, k * k);
        LAPACK_INST.dsyev("V", "U", k, eigenVectors, k, eigenValues, work, lwork, info);
        if (info.val != 0) {
            throw new IllegalArgumentException(
                "Eigen-decomposition of the covariance matrix failed (LAPACK dsyev info = "
                    + info.val + ").");
        }

        // Eigenvalues at or below this tolerance are treated as zero (singular directions).
        final double tol = EPSILON * k * Doubles.max(eigenValues);

        // Log of the pseudo-determinant: sum of logs of the non-negligible eigenvalues.
        double logPseudoDet = 0.0d;
        for (double lambda : eigenValues) {
            if (lambda > tol) {
                logPseudoDet += Math.log(lambda);
            }
        }

        // LAPACK returns eigenvector i contiguously at offset i * k (column-major). Scale each one
        // by 1/sqrt(lambda_i), or zero it for singular directions, to build U * D^(-1/2).
        // Assumes DenseMatrix also uses column-major storage, as LAPACK does.
        for (int i = 0; i < k; i++) {
            final double scale = eigenValues[i] > tol ? Math.sqrt(1.0d / eigenValues[i]) : 0.0d;
            F2J_BLAS_INST.dscal(k, scale, eigenVectors, i * k, 1);
        }

        this.rootSigmaInv = new DenseMatrix(k, k, eigenVectors);
        this.u = -0.5d * (k * Math.log(2.0d * Math.PI) + logPseudoDet);
    }
}
