/*
 * Decompiled with CFR 0.152.
 */
package com.github.chen0040.glm.solvers;

import Jama.CholeskyDecomposition;
import Jama.Matrix;
import Jama.QRDecomposition;
import com.github.chen0040.glm.enums.GlmDistributionFamily;
import com.github.chen0040.glm.links.LinkFunction;
import com.github.chen0040.glm.maths.Mean;
import com.github.chen0040.glm.maths.StdDev;
import com.github.chen0040.glm.maths.Variance;
import com.github.chen0040.glm.metrics.GlmStatistics;
import com.github.chen0040.glm.solvers.GlmAlgorithm;
import java.util.Random;

/**
 * Fits a generalized linear model by iteratively reweighted least squares (IRLS), working in the
 * orthonormal basis of a QR factorisation of the design matrix {@code A = Q R}.
 *
 * <p>Starting from a zero linear predictor {@code eta}, each iteration:
 * <ol>
 *   <li>evaluates {@code mu = g^-1(eta)} and {@code mu' = (g^-1)'(eta)} through the link function;</li>
 *   <li>forms the working response {@code z = eta + (y - mu) / mu'} and the weights
 *       {@code w = mu'^2 / getVariance(mu)};</li>
 *   <li>solves {@code (Q^T W Q) s = Q^T W z} via a Cholesky factorisation and sets {@code eta = Q s}.</li>
 * </ol>
 * Iteration stops after {@code maxIters} rounds or once the 2-norm of the change in {@code s} drops
 * below {@code mTol}. The coefficients are then recovered from {@code R beta = Q^T eta}.
 */
public class GlmAlgorithmIrlsQrNewton
extends GlmAlgorithm {
    /** Weights below {@code sqrt(EPSILON)} trigger a "tiny weights" warning. */
    private static final double EPSILON = 1.0E-20;

    /** Design matrix: one row per observation, one column per coefficient. */
    private Matrix A;
    /** Observed responses as an m x 1 column vector (one entry per row of {@link #A}). */
    private Matrix b;
    /** Cached transpose of {@link #A}. */
    private Matrix At;

    /**
     * Copies the state of {@code rhs} into this instance, deep-copying the matrices.
     *
     * @param rhs another {@code GlmAlgorithmIrlsQrNewton}; any other subtype triggers a
     *            {@link ClassCastException}
     */
    @Override
    public void copy(GlmAlgorithm rhs) {
        super.copy(rhs);
        GlmAlgorithmIrlsQrNewton that = (GlmAlgorithmIrlsQrNewton) rhs;
        this.A = that.A == null ? null : that.A.copy();
        this.b = that.b == null ? null : that.b.copy();
        this.At = that.At == null ? null : that.At.copy();
    }

    /** Returns a new solver holding a deep copy of this one's state. */
    @Override
    public GlmAlgorithm makeCopy() {
        GlmAlgorithmIrlsQrNewton clone = new GlmAlgorithmIrlsQrNewton();
        clone.copy(this);
        return clone;
    }

    /** Creates an empty solver; populate it via {@link #copy(GlmAlgorithm)} before calling {@link #solve()}. */
    public GlmAlgorithmIrlsQrNewton() {
    }

    /**
     * Creates a solver with an explicit link function.
     *
     * @param distribution distribution family passed to the base class
     * @param linkFunc     link function used for {@code g^-1} and its derivative
     * @param A            design matrix, one row per observation (rows must all have equal length)
     * @param b            responses, one per row of {@code A}
     * @throws IllegalArgumentException if {@code A} is empty/jagged or {@code b} does not match its row count
     */
    public GlmAlgorithmIrlsQrNewton(GlmDistributionFamily distribution, LinkFunction linkFunc, double[][] A, double[] b) {
        super(distribution, linkFunc, (double[][])null, null, null);
        this.setData(A, b);
    }

    /**
     * Creates a solver whose link function is chosen by the base class from {@code distribution}.
     *
     * @param distribution distribution family passed to the base class
     * @param A            design matrix, one row per observation (rows must all have equal length)
     * @param b            responses, one per row of {@code A}
     * @throws IllegalArgumentException if {@code A} is empty/jagged or {@code b} does not match its row count
     */
    public GlmAlgorithmIrlsQrNewton(GlmDistributionFamily distribution, double[][] A, double[] b) {
        super(distribution);
        this.setData(A, b);
    }

    /** Validates the raw inputs and initialises the matrices and statistics holder (shared by both constructors). */
    private void setData(double[][] A, double[] b) {
        if (A == null || A.length == 0 || A[0] == null || A[0].length == 0) {
            throw new IllegalArgumentException("Design matrix A must have at least one row and one column");
        }
        if (b == null || b.length != A.length) {
            throw new IllegalArgumentException("Response vector b must have exactly one entry per row of A ("
                    + A.length + " rows)");
        }
        this.A = GlmAlgorithmIrlsQrNewton.toMatrix(A);
        this.b = GlmAlgorithmIrlsQrNewton.columnVector(b);
        this.At = this.A.transpose();
        this.mStats = new GlmStatistics(A[0].length, b.length);
    }

    /** Copies a rectangular 2-D array into a Matrix; Jama rejects jagged rows with IllegalArgumentException. */
    private static Matrix toMatrix(double[][] A) {
        return Matrix.constructWithCopy(A);
    }

    /** Copies {@code b} into a new {@code b.length x 1} column vector. */
    private static Matrix columnVector(double[] b) {
        int m = b.length;
        Matrix B = new Matrix(m, 1);
        for (int i = 0; i < m; ++i) {
            B.set(i, 0, b[i]);
        }
        return B;
    }

    /** Returns a zero-filled {@code n x 1} column vector. */
    private static Matrix columnVector(int n) {
        return new Matrix(n, 1);
    }

    /**
     * Runs IRLS and stores the fitted coefficients and statistics.
     *
     * @return the fitted coefficients (as reported by {@code getCoefficients()})
     */
    @Override
    public double[] solve() {
        int m = this.A.getRowDimension();
        int n = this.A.getColumnDimension();

        QRDecomposition qr = this.A.qr();
        Matrix Q = qr.getQ();
        Matrix R = qr.getR();
        Matrix Qt = Q.transpose();

        Matrix s = GlmAlgorithmIrlsQrNewton.columnVector(n);   // coefficients in the Q basis
        Matrix t = GlmAlgorithmIrlsQrNewton.columnVector(m);   // linear predictor eta = Q s, starts at 0
        double[] g = new double[m];
        double[] gprime = new double[m];
        double[] W = null;

        for (int iter = 0; iter < this.maxIters; ++iter) {
            this.evaluateInverseLink(t, g, gprime);
            Matrix z = GlmAlgorithmIrlsQrNewton.columnVector(m);
            for (int k = 0; k < m; ++k) {
                z.set(k, 0, t.get(k, 0) + (this.b.get(k, 0) - g[k]) / gprime[k]);
            }
            W = this.irlsWeights(g, gprime);

            // Weighted normal equations projected onto Q: (Q^T W Q) s = Q^T W z.
            Matrix QtWQ = Qt.times(GlmAlgorithmIrlsQrNewton.scaleRows(Q, W));
            Matrix QtWz = Qt.times(GlmAlgorithmIrlsQrNewton.scaleRows(z, W));
            Matrix L = QtWQ.chol().getL();

            Matrix sOld = s;
            double[] y = GlmAlgorithmIrlsQrNewton.forwardSubstitute(L, QtWz.getColumnPackedCopy());
            s = GlmAlgorithmIrlsQrNewton.columnVector(GlmAlgorithmIrlsQrNewton.backSubstitute(L.transpose(), y));
            t = Q.times(s);
            if (sOld.minus(s).norm2() < this.mTol) {
                break;
            }
        }

        if (W == null) {
            // No iteration ran (maxIters <= 0): evaluate the weights at the initial predictor so the
            // statistics below can still be computed instead of failing on a null weight vector.
            this.evaluateInverseLink(t, g, gprime);
            W = this.irlsWeights(g, gprime);
        }

        // A = Q R  and  eta = A beta  =>  R beta = Q^T eta.
        this.glmCoefficients = GlmAlgorithmIrlsQrNewton.backSubstitute(R, Qt.times(t).getColumnPackedCopy());
        this.updateStatistics(W);
        return this.getCoefficients();
    }

    /** Fills {@code mu[k] = g^-1(eta_k)} and {@code dmu[k] = (g^-1)'(eta_k)} for every row of the column {@code eta}. */
    private void evaluateInverseLink(Matrix eta, double[] mu, double[] dmu) {
        for (int k = 0; k < mu.length; ++k) {
            double etaK = eta.get(k, 0);
            mu[k] = this.linkFunc.GetInvLink(etaK);
            dmu[k] = this.linkFunc.GetInvLinkDerivative(etaK);
        }
    }

    /**
     * Computes the IRLS weights {@code w_k = dmu_k^2 / getVariance(mu_k)} and prints a warning when
     * the smallest weight falls below {@code sqrt(EPSILON)}.
     */
    private double[] irlsWeights(double[] mu, double[] dmu) {
        double[] w = new double[mu.length];
        double wMin = Double.MAX_VALUE;
        for (int k = 0; k < mu.length; ++k) {
            w[k] = dmu[k] * dmu[k] / this.getVariance(mu[k]);
            wMin = Math.min(w[k], wMin);
        }
        if (wMin < Math.sqrt(EPSILON)) {
            System.out.println("Warning: Tiny weights encountered, min(diag(W)) is too small");
        }
        return w;
    }

    /**
     * Returns {@code diag(w) * M}, i.e. row {@code i} of {@code M} multiplied by {@code w[i]}.
     *
     * @throws IllegalArgumentException if {@code w.length} differs from the row count of {@code M}
     */
    private static Matrix scaleRows(Matrix M, double[] w) {
        int rows = M.getRowDimension();
        int cols = M.getColumnDimension();
        if (w.length != rows) {
            throw new IllegalArgumentException("Expected " + rows + " weights but got " + w.length);
        }
        Matrix out = new Matrix(rows, cols);
        for (int i = 0; i < rows; ++i) {
            for (int j = 0; j < cols; ++j) {
                out.set(i, j, M.get(i, j) * w[i]);
            }
        }
        return out;
    }

    /** Solves the lower-triangular system {@code L y = rhs} by forward substitution. */
    private static double[] forwardSubstitute(Matrix L, double[] rhs) {
        int n = rhs.length;
        double[] y = new double[n];
        for (int i = 0; i < n; ++i) {
            double acc = 0.0;
            for (int k = 0; k < i; ++k) {
                acc += L.get(i, k) * y[k];
            }
            y[i] = (rhs[i] - acc) / L.get(i, i);
        }
        return y;
    }

    /** Solves the upper-triangular system {@code U x = rhs} by back substitution. */
    private static double[] backSubstitute(Matrix U, double[] rhs) {
        int n = rhs.length;
        double[] x = new double[n];
        for (int i = n - 1; i >= 0; --i) {
            double acc = 0.0;
            for (int k = i + 1; k < n; ++k) {
                acc += U.get(i, k) * x[k];
            }
            x[i] = (rhs[i] - acc) / U.get(i, i);
        }
        return x;
    }

    /**
     * Populates {@code mStats} from the fitted coefficients.
     *
     * <p>The variance-covariance matrix is {@code (A^T W A)^-1}, and the standard errors are the square
     * roots of its diagonal. Residuals are {@code y - g^-1(A beta)}. R² is computed as
     * {@code 1 - residualStdDev^2 / responseVariance}. Adjusted R² scales that ratio by
     * {@code (m - 1) / (m - p - 1)}, with {@code m} observations and {@code p} coefficients.
     * NOTE(review): if {@code A} contains an intercept column, {@code p} counts it too — confirm this
     * is the intended convention.
     *
     * @param W IRLS weights, one per observation
     */
    protected void updateStatistics(double[] W) {
        // A^T diag(W) A, with W applied to the rows of A so there is no square-matrix ambiguity.
        Matrix AtWA = this.At.times(GlmAlgorithmIrlsQrNewton.scaleRows(this.A, W));
        Matrix AtWAInv = AtWA.inverse();
        int n = AtWAInv.getRowDimension();
        int m = this.b.getRowDimension();

        double[] stdErrors = this.mStats.getStandardErrors();
        double[][] VCovMatrix = this.mStats.getVCovMatrix();
        double[] residuals = this.mStats.getResiduals();
        for (int i = 0; i < n; ++i) {
            stdErrors[i] = Math.sqrt(AtWAInv.get(i, i));
            for (int j = 0; j < n; ++j) {
                VCovMatrix[i][j] = AtWAInv.get(i, j);
            }
        }

        double[] outcomes = new double[m];
        for (int i = 0; i < m; ++i) {
            double eta = 0.0;
            for (int j = 0; j < n; ++j) {
                eta += this.A.get(i, j) * this.glmCoefficients[j];
            }
            outcomes[i] = this.b.get(i, 0);
            residuals[i] = outcomes[i] - this.linkFunc.GetInvLink(eta);
        }

        this.mStats.setResidualStdDev(StdDev.apply(residuals, 0.0));
        this.mStats.setResponseMean(Mean.apply(outcomes));
        this.mStats.setResponseVariance(Variance.apply(outcomes, this.mStats.getResponseMean()));

        double residualSd = this.mStats.getResidualStdDev();
        double unexplained = residualSd * residualSd / this.mStats.getResponseVariance();
        int p = this.glmCoefficients.length;
        this.mStats.setR2(1.0 - unexplained);
        // Degrees-of-freedom correction uses the observation count m (previously the coefficient count
        // was used, which made the denominator always -1).
        this.mStats.setAdjustedR2(1.0 - unexplained * (double)(m - 1) / (double)(m - p - 1));
    }
}

