package net.digital_alexandria.lvm4j.decomposition;

import net.digital_alexandria.lvm4j.Decomposition;
import net.digital_alexandria.lvm4j.util.Math;
import net.digital_alexandria.lvm4j.util.Matrix;
import org.ejml.simple.SimpleMatrix;
import org.ejml.simple.SimpleSVD;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.inverse.InvertMatrix;
import org.nd4j.linalg.ops.transforms.Transforms;

/* loaded from: input_file:net/digital_alexandria/lvm4j/decomposition/FactorAnalysis.class */
/**
 * Factor analysis fitted by an iterative SVD-based scheme.
 *
 * <p>Each iteration rescales the data by the current noise standard deviations, takes an SVD,
 * derives new loadings from the leading singular values/vectors, and re-estimates the diagonal
 * noise covariance. Iteration stops when the change of a log-likelihood-like objective drops to
 * {@code _THRESHOLD} or after {@code _MAXIT + 1} passes.
 *
 * <p>Instances are stateful: {@link #run(int)} overwrites the stored loadings and noise estimates.
 */
public final class FactorAnalysis implements Decomposition {
    /** Convergence tolerance on the absolute change of the objective between iterations. */
    private static final double _THRESHOLD = 1.0E-4d;
    /** Iteration limit; the fitting loop runs at most {@code _MAXIT + 1} passes. */
    private static final int _MAXIT = 10000;
    /** Small constant keeping noise variances (and their square roots) strictly positive. */
    private static final double PSEUDO_COUNT = 1.0E-12d;

    /** Data after {@code Matrix.scale(x, true, false)}; rows are observations, columns variables. */
    private final INDArray _X;
    /** Number of rows (observations) of {@link #_X}. */
    private final int _N;
    /** Number of columns (variables) of {@link #_X}. */
    private final int _P;
    /** Loadings matrix (k x P) from the most recent fit. */
    private INDArray _f;
    /** Diagonal noise covariance matrix (P x P) from the most recent fit. */
    private INDArray _psi;

    /**
     * Creates a factor analysis over the given data.
     *
     * <p>NOTE(review): the decompiler reports this constructor was originally package-private.
     *
     * @param data observations in rows, variables in columns
     */
    public FactorAnalysis(double[][] data) {
        this(Nd4j.create(data));
    }

    /**
     * Creates a factor analysis over the given data.
     *
     * <p>The data is passed through {@code Matrix.scale(data, true, false)} (presumably centring
     * the columns without scaling them — confirm against {@code Matrix.scale}).
     *
     * <p>NOTE(review): the decompiler reports this constructor was originally package-private.
     *
     * @param data observations in rows, variables in columns
     */
    public FactorAnalysis(INDArray data) {
        this._X = Matrix.scale(data, true, false);
        this._N = this._X.rows();
        this._P = this._X.columns();
    }

    /**
     * Fits a model with {@code k} latent factors and returns the factor scores.
     *
     * @param k number of latent factors, between 1 and {@code min(rows, columns)}
     * @return an N x k matrix of factor scores
     * @throws IllegalArgumentException if {@code k} is outside {@code [1, min(rows, columns)]}
     */
    @Override // net.digital_alexandria.lvm4j.Decomposition
    public final INDArray run(int k) {
        // The compact SVD yields min(N, P) singular values; asking for more factors cannot work.
        int maxFactors = Integer.min(this._N, this._P);
        if (k < 1 || k > maxFactors) {
            throw new IllegalArgumentException(
                "number of factors must be between 1 and " + maxFactors + ", got " + k);
        }
        fit(k);
        return decomp(k);
    }

    /**
     * Iterates the SVD-based updates until the objective stabilises or the iteration limit is hit.
     * Stores the final loadings in {@link #_f} and the noise covariance in {@link #_psi}.
     */
    private void fit(int k) {
        INDArray variances = this._X.var(new int[]{0});
        INDArray psi = Nd4j.eye(this._P);
        INDArray loadings;
        double objective = Double.NaN;
        boolean converged;
        int iteration = 0;
        do {
            double previousObjective = objective;
            INDArray sqrtPsi = sqrtPsis(psi);
            SimpleSVD svd = svd(this._X.dup(), sqrtPsi.data().asDouble());
            INDArray squaredSingularValues = getSingularValues(svd.getW(), k);
            INDArray rightSingularVectors = getRightSingularVectors(svd.getV(), k);
            // The objective is evaluated with the psi used for this pass, before psi is updated.
            objective = proploglik(squaredSingularValues, unexplainedVariance(svd.getW(), k), psi);
            loadings = factorUpdate(squaredSingularValues, rightSingularVectors, sqrtPsi);
            psi = vcovUpdate(variances, loadings);
            // No previous objective exists on the first pass, so never stop there. Written as
            // !(diff > threshold) so that a NaN objective still terminates the loop.
            converged = iteration > 0 && !(Math.abs(objective - previousObjective) > _THRESHOLD);
        } while (!converged && iteration++ < _MAXIT);
        this._f = loadings;
        this._psi = psi;
    }

    /** Returns {@code sqrt(diag(psi)) + PSEUDO_COUNT} as a vector. */
    private INDArray sqrtPsis(INDArray psi) {
        return Transforms.sqrt(Nd4j.diag(psi)).add(Double.valueOf(PSEUDO_COUNT));
    }

    /**
     * Divides column j of {@code x} by {@code sqrtPsi[j] * sqrt(N)} in place and returns the
     * compact SVD of the result.
     *
     * <p>NOTE(review): copies {@code x.data()} into a row-major EJML matrix, which assumes
     * {@code x} is stored in row-major ('c') order without an offset — confirm for
     * {@code Matrix.scale}'s output.
     */
    private SimpleSVD svd(INDArray x, double[] sqrtPsi) {
        double sqrtN = Math.sqrt(this._N);
        for (int j = 0; j < sqrtPsi.length; j++) {
            x.getColumn(j).assign(x.getColumn(j).div(Double.valueOf(sqrtPsi[j] * sqrtN)));
        }
        return new SimpleMatrix(this._N, this._P, true, x.data().asDouble()).svd(true);
    }

    /**
     * Returns the first {@code k} diagonal entries of {@code w}, squared, as a length-k vector.
     *
     * <p>NOTE(review): "first k" equals "largest k" only if {@code SimpleSVD} returns singular
     * values in descending order — confirm for the EJML version in use.
     */
    private INDArray getSingularValues(SimpleMatrix w, int k) {
        double[] singularValues = w.extractDiag().getMatrix().data;
        INDArray squared = Nd4j.create(k);
        for (int j = 0; j < k; j++) {
            squared.getColumn(j).assign(Double.valueOf(Math.pow(singularValues[j], 2.0d)));
        }
        return squared;
    }

    /** Returns the first {@code k} rows of {@code v}'s transpose as a k x P array. */
    private INDArray getRightSingularVectors(SimpleMatrix v, int k) {
        SimpleMatrix vt = v.transpose();
        // Use Vt's own width: for a compact SVD with N < P, V is P x N and v.numCols() != P.
        int width = vt.numCols();
        // EJML stores data row-major, which is ND4J ordering 'c'.
        return Nd4j.create(vt.extractMatrix(0, k, 0, width).getMatrix().data, new int[]{k, width}, 'c');
    }

    /** Returns the sum of squares of the diagonal entries of {@code w} from index {@code k} on. */
    private double unexplainedVariance(SimpleMatrix w, int k) {
        double total = 0.0d;
        for (int j = k; j < w.numCols(); j++) {
            double value = w.get(j, j);
            total += value * value;
        }
        return total;
    }

    /**
     * Computes new loadings: each row r of {@code rightSingularVectors} is scaled by
     * {@code sqrt(max(squaredSingularValues[r] - 1, 0))}, then column j by {@code sqrtPsi[j]}.
     */
    private INDArray factorUpdate(INDArray squaredSingularValues, INDArray rightSingularVectors,
                                  INDArray sqrtPsi) {
        int k = squaredSingularValues.columns();
        INDArray rowScale = Nd4j.create(k);
        for (int r = 0; r < k; r++) {
            rowScale.getColumn(r).assign(
                Double.valueOf(Math.sqrt(Math.max(squaredSingularValues.getDouble(r) - 1.0d, 0.0d))));
        }
        INDArray loadings = rightSingularVectors.transpose().mmul(Nd4j.diag(rowScale)).transpose();
        for (int j = 0; j < loadings.columns(); j++) {
            loadings.getColumn(j).assign(loadings.getColumn(j).mul(Double.valueOf(sqrtPsi.getDouble(j))));
        }
        return loadings;
    }

    /**
     * Returns the diagonal noise covariance:
     * {@code diag(max(variances[j] - sum_r loadings[r][j]^2, PSEUDO_COUNT))}.
     */
    private INDArray vcovUpdate(INDArray variances, INDArray loadings) {
        INDArray noise = Transforms.pow(loadings, 2).sum(new int[]{0});
        for (int j = 0; j < noise.columns(); j++) {
            noise.getColumn(j).assign(
                Double.valueOf(Double.max(variances.getDouble(j) - noise.getDouble(j), PSEUDO_COUNT)));
        }
        return Nd4j.diag(noise);
    }

    /**
     * Returns {@code sum(log s) + unexplained + sum(log diag(psi))}, used only through its change
     * between iterations. Presumably proportional to the negative log-likelihood up to a constant
     * — confirm.
     */
    private double proploglik(INDArray squaredSingularValues, double unexplained, INDArray psi) {
        return Math.sum(Math.log(squaredSingularValues.data().asDouble()))
            + unexplained
            + Math.sum(Math.log(Nd4j.diag(psi).data().asDouble()));
    }

    /**
     * Projects the data onto the fitted factors:
     * {@code X * Wpsi^T * (I_k + Wpsi * F^T)^-1}.
     */
    private INDArray decomp(int k) {
        INDArray identity = Nd4j.eye(k);
        INDArray wpsi = wpsi();
        return this._X.mmul(wpsi.transpose()).mmul(coz(k, identity, wpsi.mmul(this._f.transpose())));
    }

    /** Returns a copy of the loadings with each row divided element-wise by {@code diag(_psi)}. */
    private INDArray wpsi() {
        INDArray scaled = this._f.dup();
        INDArray noise = Nd4j.diag(this._psi);
        for (int r = 0; r < scaled.rows(); r++) {
            scaled.getRow(r).assign(scaled.getRow(r).div(noise));
        }
        return scaled;
    }

    /**
     * Returns {@code (identity + gram)^-1}. For a single factor this is the element-wise
     * reciprocal; otherwise it is a full matrix inverse.
     */
    private INDArray coz(int k, INDArray identity, INDArray gram) {
        return k == 1
            ? Nd4j.ones(1).div(identity.add(gram))
            : InvertMatrix.invert(identity.add(gram), false);
    }
}
