package mikera.matrixx.algo;

import mikera.matrixx.AMatrix;
import mikera.matrixx.Matrix;
import mikera.matrixx.impl.DiagonalMatrix;
import mikera.vectorz.AVector;
import mikera.vectorz.Vector;
import mikera.vectorz.impl.ADenseArrayVector;
import mikera.vectorz.impl.AStridedVector;

/* loaded from: input_file:mikera/matrixx/algo/PLS.class */
/**
 * Partial least squares (PLS) regression of a response matrix Y on a predictor matrix X.
 *
 * <p>Components are extracted one at a time by an alternating iteration between X-side and
 * Y-side vectors (this looks like the NIPALS scheme - confirm against the upstream project),
 * after which the working copy of X is deflated by the extracted component.
 *
 * <p>Instances are created only through {@link #calculate(AMatrix, AMatrix, int)}; all results
 * are computed eagerly and exposed through the {@link IPLSResult} getters.
 *
 * <p>Dimensions used below: {@code n} = rows of X and Y, {@code m} = columns of X,
 * {@code p} = columns of Y, {@code l} = number of extracted components.
 */
public class PLS implements IPLSResult {
    /**
     * Highest pass index of the inner iteration. Passes are numbered from 0, so at most
     * {@code MAX_ITERATIONS + 1} passes are run per component.
     */
    private static final int MAX_ITERATIONS = 10;

    /** The inner iteration stops once the score vector moves less than this between passes. */
    private static final double CONVERGENCE_TOLERANCE = 1.0E-11d;

    /** The X matrix exactly as supplied by the caller; never modified by this class. */
    private final AMatrix origX;
    /** Working copy of X (n x m); column-centred and then deflated in place. */
    private final Matrix X;
    /** Copy of the supplied Y (n x p); only read during the calculation. */
    private final Matrix Y;
    /** X loadings (m x l), one column per component. */
    private final Matrix P;
    /** Y weights (p x l), one column per component. */
    private final Matrix Q;
    /** X scores (n x l), one column per component. */
    private final Matrix T;
    /** Y scores (n x l), one column per component. */
    private final Matrix U;
    /** X weights (m x l), one column per component. */
    private final Matrix W;
    /** Per-component value t.u (length l); copied onto the diagonal of {@link #B}. */
    private final Vector b;
    /** Diagonal matrix (l x l) whose leading diagonal is {@link #b}. */
    private final DiagonalMatrix B;
    /** Coefficient matrix (m x p). */
    private final Matrix coefficients;
    /** Constant term (length p). */
    private final Vector constant;
    /** Number of components to extract. */
    private final int l;
    /** Number of rows in X and Y. */
    private final int n;
    /** Number of columns in X. */
    private final int m;
    /** Number of columns in Y. */
    private final int p;

    /** Returns the X matrix as originally supplied (not the centred/deflated working copy). */
    @Override // mikera.matrixx.algo.IPLSResult
    public AMatrix getX() {
        return this.origX;
    }

    /** Returns the internal copy of the supplied Y matrix. */
    @Override // mikera.matrixx.algo.IPLSResult
    public AMatrix getY() {
        return this.Y;
    }

    /** Returns the n x l matrix of X scores. */
    @Override // mikera.matrixx.algo.IPLSResult
    public AMatrix getT() {
        return this.T;
    }

    /** Returns the m x l matrix of X loadings. */
    @Override // mikera.matrixx.algo.IPLSResult
    public AMatrix getP() {
        return this.P;
    }

    /** Returns the p x l matrix of Y weights. */
    @Override // mikera.matrixx.algo.IPLSResult
    public AMatrix getQ() {
        return this.Q;
    }

    /** Returns the m x l matrix of X weights. */
    @Override // mikera.matrixx.algo.IPLSResult
    public AMatrix getW() {
        return this.W;
    }

    /** Returns the l x l diagonal matrix holding the per-component t.u values. */
    @Override // mikera.matrixx.algo.IPLSResult
    public AMatrix getB() {
        return this.B;
    }

    /** Returns the m x p coefficient matrix. */
    @Override // mikera.matrixx.algo.IPLSResult
    public AMatrix getCoefficients() {
        return this.coefficients;
    }

    /** Returns the constant term, a vector of length p. */
    @Override // mikera.matrixx.algo.IPLSResult
    public AVector getConstant() {
        return this.constant;
    }

    /**
     * Validates the inputs and allocates all result storage; does not run the calculation.
     *
     * @param x predictor matrix (n x m); kept by reference and also copied for internal work
     * @param y response matrix (n x p); copied
     * @param components number of components to extract, at least 1
     * @throws IllegalArgumentException if the row counts differ or {@code components < 1}
     */
    private PLS(AMatrix x, AMatrix y, int components) {
        // Validate before copying so that bad input does not pay for two matrix copies.
        int rows = x.rowCount();
        if (y.rowCount() != rows) {
            throw new IllegalArgumentException("PLS regression requires equal number of rows in X and Y matrices");
        }
        // calcResult() reads column 0 of Q and P, so at least one component must exist.
        if (components < 1) {
            throw new IllegalArgumentException("PLS regression requires at least one component, got " + components);
        }

        this.origX = x;
        this.X = Matrix.create(x);
        this.Y = Matrix.create(y);
        this.n = rows;
        this.m = x.columnCount();
        this.p = y.columnCount();
        this.l = components;

        this.T = Matrix.create(this.n, this.l);
        this.U = Matrix.create(this.n, this.l);
        this.P = Matrix.create(this.m, this.l);
        this.Q = Matrix.create(this.p, this.l);
        this.W = Matrix.create(this.m, this.l);
        this.b = Vector.createLength(this.l);
        this.B = DiagonalMatrix.createDimensions(this.l);
        this.coefficients = Matrix.create(this.m, this.p);
        this.constant = Vector.createLength(this.p);
    }

    /**
     * Runs a PLS regression and returns the populated result.
     *
     * @param aMatrix predictor matrix X (n x m)
     * @param aMatrix2 response matrix Y (n x p); must have the same number of rows as X
     * @param i number of components to extract, at least 1
     * @return the completed regression result
     * @throws IllegalArgumentException if the row counts differ or {@code i < 1}
     */
    public static IPLSResult calculate(AMatrix aMatrix, AMatrix aMatrix2, int i) {
        PLS pls = new PLS(aMatrix, aMatrix2, i);
        pls.calcResult();
        return pls;
    }

    /**
     * Finds the column with the largest sum of squared elements among the first {@code m}
     * columns of the given matrix.
     *
     * @param matrix matrix to scan; expected to have at least {@code m} columns
     * @return index of the first column with the strictly largest squared sum, or 0 when no
     *         column has a squared sum greater than zero
     */
    private int selectMaxSSColumn(AMatrix matrix) {
        int bestColumn = 0;
        double bestSquaredSum = 0.0d;
        for (int column = 0; column < this.m; column++) {
            double squaredSum = matrix.getColumn(column).elementSquaredSum();
            if (squaredSum > bestSquaredSum) {
                bestColumn = column;
                bestSquaredSum = squaredSum;
            }
        }
        return bestColumn;
    }

    /**
     * Performs the whole calculation: centres the columns of the working X, extracts
     * {@code l} components, then derives the coefficient matrix and the constant term.
     */
    private void calcResult() {
        // Scratch vectors reused across components. Declared types and the explicit casts
        // further down are kept as-is so the same library overloads are selected.
        ADenseArrayVector u = Vector.createLength(this.n);
        AVector w = Vector.createLength(this.m);
        ADenseArrayVector t = Vector.createLength(this.n);
        Vector tOld = Vector.createLength(this.n);
        AVector q = Vector.createLength(this.p);
        AVector loading = Vector.createLength(this.m);

        // Centre each column of the working copy of X by subtracting its mean.
        for (int column = 0; column < this.m; column++) {
            AStridedVector columnView = this.X.getColumnView(column);
            columnView.add(-(columnView.elementSum() / this.n));
        }

        for (int component = 0; component < this.l; component++) {
            // Start from the column of (deflated) X with the largest sum of squares.
            u.set(this.X.getColumn(selectMaxSSColumn(this.X)));

            // Bounded iteration: previously nothing ended the loop once the pass limit was
            // exceeded without convergence, so a non-converging input would hang.
            for (int pass = 0; pass <= MAX_ITERATIONS; pass++) {
                w.setInnerProduct((AVector) u, this.X);
                w.normalise();
                t.setInnerProduct(this.X, w);
                t.normalise();
                q.setInnerProduct((AVector) t, this.Y);
                if (q.normalise() == 0.0d) {
                    // Nothing of Y is left in this direction; keep the previous u and stop.
                    break;
                }
                u.setInnerProduct(this.Y, q);
                if (t.distance(tOld) < CONVERGENCE_TOLERANCE) {
                    break;
                }
                tOld.set(t);
            }

            this.U.setColumn(component, u);
            this.W.setColumn(component, w);
            this.T.setColumn(component, t);
            this.Q.setColumn(component, q);
            this.b.set(component, t.dotProduct(u));

            // Loading for this component, then deflate X by t (outer) loading.
            loading.setInnerProduct((AVector) t, this.X);
            this.P.setColumn(component, loading);
            loading.negate();
            this.X.addOuterProduct((AVector) t, loading);
        }

        // coefficients = pseudo-inverse(P^T) * B * Q^T
        this.B.getLeadingDiagonal().set((ADenseArrayVector) this.b);
        this.coefficients.setInnerProduct(PseudoInverse.calculate(this.P.getTranspose()), this.B.innerProduct((AMatrix) this.Q.getTranspose()));

        // Constant term built from the first columns of Q and P.
        // NOTE(review): Y is never centred here and the X column means are discarded after
        // centring, so confirm that this constant is the intended intercept.
        this.constant.set((AVector) this.Q.getColumn(0));
        this.constant.addInnerProduct(this.P.getColumn(0), this.coefficients, -1.0d);
    }
}
