/*
 * Decompiled with CFR 0.152.
 */
package com.github.chen0040.glm.solvers;

import Jama.CholeskyDecomposition;
import Jama.Matrix;
import Jama.SingularValueDecomposition;
import com.github.chen0040.glm.enums.GlmDistributionFamily;
import com.github.chen0040.glm.links.LinkFunction;
import com.github.chen0040.glm.maths.Mean;
import com.github.chen0040.glm.maths.StdDev;
import com.github.chen0040.glm.maths.Variance;
import com.github.chen0040.glm.metrics.GlmStatistics;
import com.github.chen0040.glm.solvers.GlmAlgorithm;

/**
 * Generalized linear model solver that runs iteratively reweighted least squares (IRLS)
 * in the coordinates of a singular value decomposition of the model matrix.
 *
 * <p>With {@code A = U * Sigma * V^T}, the linear predictor is kept in the column space of
 * {@code U} ({@code t = U * s}), so each IRLS step only solves the small system
 * {@code (U^T W U) s = U^T W z} through a Cholesky factorization. The coefficients are
 * recovered at the end as {@code x = V * Sigma^-1 * U^T * t} (which satisfies {@code A x = t}).
 *
 * <p>NOTE(review): Jama's SVD is documented for matrices with at least as many rows as
 * columns; this solver presumably assumes {@code A} has at least as many observations as
 * coefficients — confirm with callers.
 */
public class GlmAlgorithmIrlsSvdNewton
extends GlmAlgorithm {
    /** Singular values below this abort the solve; IRLS weights below twice this trigger a warning. */
    private static final double EPSILON = 1.0E-34;
    /** Model matrix, one row per observation (its row count matches {@code b}). */
    private Matrix A;
    /** Observed responses stored as an m-by-1 column vector. */
    private Matrix b;
    /** Cached transpose of {@link #A}. */
    private Matrix At;

    /**
     * Copies the base-class state and deep-copies this solver's matrices from {@code rhs}.
     *
     * @param rhs source solver; must be a {@code GlmAlgorithmIrlsSvdNewton}
     */
    @Override
    public void copy(GlmAlgorithm rhs) {
        super.copy(rhs);
        GlmAlgorithmIrlsSvdNewton rhs2 = (GlmAlgorithmIrlsSvdNewton)rhs;
        this.A = rhs2.A == null ? null : (Matrix)rhs2.A.clone();
        this.b = rhs2.b == null ? null : (Matrix)rhs2.b.clone();
        this.At = rhs2.At == null ? null : (Matrix)rhs2.At.clone();
    }

    /** Returns a new solver holding deep copies of this solver's state. */
    @Override
    public GlmAlgorithm makeCopy() {
        GlmAlgorithmIrlsSvdNewton clone = new GlmAlgorithmIrlsSvdNewton();
        clone.copy(this);
        return clone;
    }

    /** Creates an empty solver, typically populated afterwards through {@link #copy}. */
    public GlmAlgorithmIrlsSvdNewton() {
    }

    /**
     * Creates a solver with an explicit link function.
     *
     * @param distribution distribution family of the response
     * @param linkFunc     link function used by IRLS
     * @param A            model matrix, one row per observation
     * @param b            observed responses, one per row of {@code A}
     */
    public GlmAlgorithmIrlsSvdNewton(GlmDistributionFamily distribution, LinkFunction linkFunc, double[][] A, double[] b) {
        super(distribution, linkFunc, (double[][])null, null, null);
        this.A = new Matrix(A);
        this.b = GlmAlgorithmIrlsSvdNewton.columnVector(b);
        this.At = this.A.transpose();
        this.mStats = new GlmStatistics(A[0].length, b.length);
    }

    /**
     * Creates a solver whose link function is chosen by the base class from {@code distribution}.
     *
     * @param distribution distribution family of the response
     * @param A            model matrix, one row per observation
     * @param b            observed responses, one per row of {@code A}
     */
    public GlmAlgorithmIrlsSvdNewton(GlmDistributionFamily distribution, double[][] A, double[] b) {
        super(distribution);
        this.A = new Matrix(A);
        this.b = GlmAlgorithmIrlsSvdNewton.columnVector(b);
        this.At = this.A.transpose();
        this.mStats = new GlmStatistics(A[0].length, b.length);
    }

    /** Copies {@code b} into a new {@code b.length}-by-1 matrix. */
    private static Matrix columnVector(double[] b) {
        int m = b.length;
        Matrix B = new Matrix(m, 1);
        for (int i = 0; i < m; ++i) {
            B.set(i, 0, b[i]);
        }
        return B;
    }

    /** Returns a zero-filled {@code n}-by-1 matrix. */
    private static Matrix columnVector(int n) {
        return new Matrix(n, 1);
    }

    /**
     * Fits the model by IRLS and updates the fit statistics.
     *
     * <p>Stops after {@code maxIters} iterations, or earlier once the 2-norm of the change in
     * the SVD-space iterate {@code s} falls below {@code mTol}.
     *
     * @return the fitted coefficients (length = number of columns of {@code A}), or
     *         {@code null} when a singular value of {@code A} is below {@link #EPSILON}
     */
    @Override
    public double[] solve() {
        int m = this.A.getRowDimension();
        int n = this.A.getColumnDimension();
        int m2 = Math.min(m, n);

        SingularValueDecomposition svd = this.A.svd();
        Matrix U = svd.getU();
        Matrix V = svd.getV();
        Matrix Sigma = svd.getS();
        Matrix Ut = U.transpose();

        // Sigma is diagonal, so its inverse is the element-wise reciprocal of the diagonal.
        Matrix SigmaInv = new Matrix(m2, m2);
        for (int i = 0; i < m2; ++i) {
            double sigma_i = Sigma.get(i, i);
            if (sigma_i < EPSILON) {
                System.out.println("Near rank-deficient model matrix");
                return null;
            }
            SigmaInv.set(i, i, 1.0 / sigma_i);
        }

        // t is the linear predictor (A x) and s its coordinates in the basis U; both start at 0.
        Matrix t = GlmAlgorithmIrlsSvdNewton.columnVector(m);
        Matrix s = GlmAlgorithmIrlsSvdNewton.columnVector(m2);
        // Diagonal of the IRLS weight matrix; the last iteration's values feed updateStatistics.
        double[] W = new double[m];
        for (int iter = 0; iter < this.maxIters; ++iter) {
            // Working response z = t + (b - g(t)) / g'(t), where g is the inverse link.
            Matrix z = GlmAlgorithmIrlsSvdNewton.columnVector(m);
            double[] g = new double[m];
            double[] gprime = new double[m];
            for (int k = 0; k < m; ++k) {
                double eta = t.get(k, 0);
                g[k] = this.linkFunc.GetInvLink(eta);
                gprime[k] = this.linkFunc.GetInvLinkDerivative(eta);
                z.set(k, 0, eta + (this.b.get(k, 0) - g[k]) / gprime[k]);
            }

            // Weights W_kk = g'(t_k)^2 / Var(g(t_k)).
            int tiny_weight_count = 0;
            for (int k = 0; k < m; ++k) {
                W[k] = gprime[k] * gprime[k] / this.getVariance(g[k]);
                if (W[k] < 2.0 * EPSILON) {
                    ++tiny_weight_count;
                }
            }
            if (tiny_weight_count > 0) {
                System.out.println("Warning: tiny weights encountered, (diag(W)) is too small");
            }

            Matrix s_old = s;
            // U^T W: scale column k of U^T by W[k] instead of materializing diag(W).
            Matrix UtW = new Matrix(m2, m);
            for (int r = 0; r < m2; ++r) {
                for (int c = 0; c < m; ++c) {
                    UtW.set(r, c, Ut.get(r, c) * W[c]);
                }
            }
            Matrix UtWU = UtW.times(U);
            CholeskyDecomposition cholesky = UtWU.chol();
            Matrix UtWz = UtW.times(z);
            // Solve (U^T W U) s = U^T W z using the factor L with U^T W U = L L^T.
            s = GlmAlgorithmIrlsSvdNewton.solveWithCholeskyFactor(cholesky.getL(), UtWz);
            t = U.times(s);
            if (s_old.minus(s).norm2() < this.mTol) break;
        }

        // Map the linear predictor back to coefficients: x = V Sigma^-1 U^T t.
        Matrix x = V.times(SigmaInv).times(Ut).times(t);
        this.glmCoefficients = new double[n];
        for (int i = 0; i < n; ++i) {
            this.glmCoefficients[i] = x.get(i, 0);
        }
        this.updateStatistics(W);
        return this.getCoefficients();
    }

    /**
     * Solves {@code (L L^T) x = rhs} by forward substitution ({@code L y = rhs}) followed by
     * back substitution ({@code L^T x = y}).
     *
     * <p>No SPD check is made: a zero pivot on the diagonal of {@code L} yields
     * infinite/NaN entries rather than an exception.
     *
     * @param L   lower-triangular Cholesky factor
     * @param rhs right-hand side column vector with as many rows as {@code L}
     * @return the solution column vector
     */
    private static Matrix solveWithCholeskyFactor(Matrix L, Matrix rhs) {
        int dim = L.getRowDimension();
        Matrix y = new Matrix(dim, 1);
        for (int i = 0; i < dim; ++i) {
            double cross_prod = 0.0;
            for (int k = 0; k < i; ++k) {
                cross_prod += L.get(i, k) * y.get(k, 0);
            }
            y.set(i, 0, (rhs.get(i, 0) - cross_prod) / L.get(i, i));
        }
        Matrix x = new Matrix(dim, 1);
        for (int i = dim - 1; i >= 0; --i) {
            double cross_prod = 0.0;
            // (L^T)(i, k) == L(k, i)
            for (int k = i + 1; k < dim; ++k) {
                cross_prod += L.get(k, i) * x.get(k, 0);
            }
            x.set(i, 0, (y.get(i, 0) - cross_prod) / L.get(i, i));
        }
        return x;
    }

    /** Returns the coefficients computed by the last {@link #solve()} call (the internal array, not a copy). */
    @Override
    public double[] getCoefficients() {
        return this.glmCoefficients;
    }

    /**
     * Returns {@code M * diag(w)}: column {@code j} of {@code M} multiplied by {@code w[j]}.
     *
     * @throws IllegalArgumentException if {@code w.length} differs from the column count of {@code M}
     */
    private static Matrix scaleColumns(Matrix M, double[] w) {
        int rows = M.getRowDimension();
        int cols = M.getColumnDimension();
        if (w.length != cols) {
            throw new IllegalArgumentException("weight vector length " + w.length + " does not match column count " + cols);
        }
        Matrix C = new Matrix(rows, cols);
        for (int i = 0; i < rows; ++i) {
            for (int j = 0; j < cols; ++j) {
                C.set(i, j, M.get(i, j) * w[j]);
            }
        }
        return C;
    }

    /**
     * Fills {@code mStats} from the fitted coefficients and the final IRLS weights.
     *
     * <ul>
     *   <li>The variance-covariance matrix is {@code (A^T W A)^-1}.</li>
     *   <li>The standard errors are the square roots of its diagonal.</li>
     *   <li>The residuals are {@code b - g(A x)}, with {@code g} the inverse link.</li>
     *   <li>R^2 and adjusted R^2 are derived from the residual standard deviation and the
     *       response variance.</li>
     * </ul>
     *
     * @param W diagonal of the IRLS weight matrix, one entry per observation
     */
    protected void updateStatistics(double[] W) {
        // A^T W A: At is n-by-m, so each observation weight scales a COLUMN of At.
        Matrix AtWA = GlmAlgorithmIrlsSvdNewton.scaleColumns(this.At, W).times(this.A);
        Matrix AtWAInv = AtWA.inverse();
        int n = AtWAInv.getRowDimension();
        int m = this.b.getRowDimension();
        double[] stdErrors = this.mStats.getStandardErrors();
        double[][] VCovMatrix = this.mStats.getVCovMatrix();
        double[] residuals = this.mStats.getResiduals();
        for (int i = 0; i < n; ++i) {
            stdErrors[i] = Math.sqrt(AtWAInv.get(i, i));
            for (int j = 0; j < n; ++j) {
                VCovMatrix[i][j] = AtWAInv.get(i, j);
            }
        }
        double[] outcomes = new double[m];
        for (int i = 0; i < m; ++i) {
            double eta = 0.0;
            for (int j = 0; j < n; ++j) {
                eta += this.A.get(i, j) * this.glmCoefficients[j];
            }
            outcomes[i] = this.b.get(i, 0);
            residuals[i] = outcomes[i] - this.linkFunc.GetInvLink(eta);
        }
        this.mStats.setResidualStdDev(StdDev.apply(residuals, 0.0));
        this.mStats.setResponseMean(Mean.apply(outcomes));
        this.mStats.setResponseVariance(Variance.apply(outcomes, this.mStats.getResponseMean()));
        double residualStdDev = this.mStats.getResidualStdDev();
        double unexplained = residualStdDev * residualStdDev / this.mStats.getResponseVariance();
        this.mStats.setR2(1.0 - unexplained);
        // Adjusted R^2 = 1 - (1 - R^2) (m - 1) / (m - p - 1), with m observations and p
        // coefficients. Degrees of freedom must use the observation count m, not n (= p).
        int p = this.glmCoefficients.length;
        this.mStats.setAdjustedR2(1.0 - unexplained * (double)(m - 1) / (double)(m - p - 1));
    }
}

