/*
 * Decompiled with CFR 0.152.
 */
package com.github.chen0040.glm.solvers;

import Jama.Matrix;
import com.github.chen0040.glm.enums.GlmDistributionFamily;
import com.github.chen0040.glm.links.LinkFunction;
import com.github.chen0040.glm.maths.Mean;
import com.github.chen0040.glm.maths.StdDev;
import com.github.chen0040.glm.maths.Variance;
import com.github.chen0040.glm.metrics.GlmStatistics;
import com.github.chen0040.glm.solvers.GlmAlgorithm;

/**
 * Fits a generalized linear model by iteratively reweighted least squares (IRLS).
 *
 * <p>The design matrix {@code A} has one row per observation and one column per coefficient; the
 * response vector {@code b} has one entry per observation. Each iteration forms the working
 * response {@code z = eta + (b - mu) / mu'(eta)} and weights {@code w = mu'(eta)^2 / Var(mu)},
 * then solves the weighted least-squares system {@code x = (A^T W A)^{-1} A^T W z}.
 */
public class GlmAlgorithmIrls extends GlmAlgorithm {
    /** Floor substituted for a variance of exactly zero so the IRLS weight stays finite. */
    private static final double EPSILON = 1.0E-20;

    /** Design matrix, m observations by n coefficients. */
    private Matrix A;
    /** Response column vector, m by 1. */
    private Matrix b;
    /** Cached transpose of {@link #A}. */
    private Matrix At;

    /** Creates an empty solver; intended to be populated via {@link #copy(GlmAlgorithm)}. */
    public GlmAlgorithmIrls() {
    }

    /**
     * Creates a solver with an explicit link function.
     *
     * @param distribution the distribution family, passed to the superclass
     * @param linkFunc     the link function used for the inverse link and its derivative
     * @param A            design matrix; must be non-empty and rectangular
     * @param b            response vector; must have one entry per row of {@code A}
     * @throws IllegalArgumentException if {@code A} is empty or ragged, or its row count differs
     *                                  from {@code b.length}
     */
    public GlmAlgorithmIrls(GlmDistributionFamily distribution, LinkFunction linkFunc, double[][] A, double[] b) {
        super(distribution, linkFunc, (double[][]) null, null, null);
        checkDimensions(A, b);
        this.A = toMatrix(A);
        this.b = columnVector(b);
        this.At = this.A.transpose();
        this.mStats = new GlmStatistics(A[0].length, b.length);
    }

    /**
     * Creates a solver whose link function is chosen by the superclass from {@code distribution}
     * (presumably the family's default link; confirm in {@code GlmAlgorithm}).
     *
     * @param distribution the distribution family
     * @param A            design matrix; must be non-empty and rectangular
     * @param b            response vector; must have one entry per row of {@code A}
     * @throws IllegalArgumentException if {@code A} is empty or ragged, or its row count differs
     *                                  from {@code b.length}
     */
    public GlmAlgorithmIrls(GlmDistributionFamily distribution, double[][] A, double[] b) {
        super(distribution);
        checkDimensions(A, b);
        this.A = toMatrix(A);
        this.b = columnVector(b);
        this.At = this.A.transpose();
        this.mStats = new GlmStatistics(A[0].length, b.length);
    }

    /** Copies superclass state and deep-copies this solver's matrices from {@code rhs}. */
    @Override
    public void copy(GlmAlgorithm rhs) {
        super.copy(rhs);
        GlmAlgorithmIrls other = (GlmAlgorithmIrls) rhs;
        this.A = other.A == null ? null : other.A.copy();
        this.b = other.b == null ? null : other.b.copy();
        this.At = other.At == null ? null : other.At.copy();
    }

    /** Returns an independent deep copy of this solver. */
    @Override
    public GlmAlgorithm makeCopy() {
        GlmAlgorithmIrls clone = new GlmAlgorithmIrls();
        clone.copy(this);
        return clone;
    }

    /** Validates that the design matrix is non-empty and agrees in length with the response. */
    private static void checkDimensions(double[][] A, double[] b) {
        if (A.length == 0 || A[0].length == 0) {
            throw new IllegalArgumentException("design matrix A must have at least one row and one column");
        }
        if (A.length != b.length) {
            throw new IllegalArgumentException(
                    "A has " + A.length + " rows but b has " + b.length + " entries");
        }
    }

    /**
     * Copies {@code A} into a Jama matrix at full double precision.
     *
     * @throws IllegalArgumentException if the rows of {@code A} differ in length
     */
    private static Matrix toMatrix(double[][] A) {
        int m = A.length;
        int n = A[0].length;
        Matrix Am = new Matrix(m, n);
        for (int i = 0; i < m; ++i) {
            if (A[i].length != n) {
                throw new IllegalArgumentException(
                        "row " + i + " of A has length " + A[i].length + ", expected " + n);
            }
            for (int j = 0; j < n; ++j) {
                Am.set(i, j, A[i][j]);
            }
        }
        return Am;
    }

    /** Copies {@code b} into an m-by-1 column vector. */
    private static Matrix columnVector(double[] b) {
        int m = b.length;
        Matrix B = new Matrix(m, 1);
        for (int i = 0; i < m; ++i) {
            B.set(i, 0, b[i]);
        }
        return B;
    }

    /** Returns an n-by-1 zero column vector. */
    private static Matrix columnVector(int n) {
        return new Matrix(n, 1);
    }

    /**
     * Returns {@code M * diag(w)}, i.e. {@code M} with column {@code k} scaled by {@code w[k]},
     * without allocating the square diagonal matrix.
     */
    private static Matrix scaleColumns(Matrix M, double[] w) {
        int rows = M.getRowDimension();
        int cols = M.getColumnDimension();
        Matrix out = new Matrix(rows, cols);
        for (int i = 0; i < rows; ++i) {
            for (int k = 0; k < cols; ++k) {
                out.set(i, k, M.get(i, k) * w[k]);
            }
        }
        return out;
    }

    /**
     * Runs IRLS starting from the zero vector until the coefficient update has Euclidean norm below
     * {@code mTol} or {@code maxIters} iterations have run, then updates the fit statistics.
     *
     * @return the fitted coefficients (also stored in {@code glmCoefficients})
     * @throws IllegalStateException if the solver has no data or {@code maxIters} is not positive
     */
    @Override
    public double[] solve() {
        if (this.A == null || this.b == null) {
            throw new IllegalStateException("solver has no design matrix or response vector");
        }
        if (this.maxIters <= 0) {
            throw new IllegalStateException("maxIters must be positive but was " + this.maxIters);
        }
        int m = this.A.getRowDimension();
        int n = this.A.getColumnDimension();
        Matrix x = columnVector(n);
        Matrix AtWAInv = null;
        double[] weights = new double[m];
        for (int iter = 0; iter < this.maxIters; ++iter) {
            Matrix eta = this.A.times(x);
            Matrix z = columnVector(m);
            for (int k = 0; k < m; ++k) {
                double etaK = eta.get(k, 0);
                double mu = this.linkFunc.GetInvLink(etaK);
                double dmu = this.linkFunc.GetInvLinkDerivative(etaK);
                // Working response: linearization of the inverse link around the current fit.
                z.set(k, 0, etaK + (this.b.get(k, 0) - mu) / dmu);
                double variance = this.getVariance(mu);
                if (variance == 0.0) {
                    // Keep the weight finite when the variance evaluates to exactly zero.
                    variance = EPSILON;
                }
                weights[k] = dmu * dmu / variance;
            }
            // A^T W with W = diag(weights): O(n*m) instead of building an m-by-m matrix.
            Matrix AtW = scaleColumns(this.At, weights);
            AtWAInv = AtW.times(this.A).inverse();
            Matrix xOld = x;
            // Multiply the vector first to avoid forming the n-by-m product (A^T W A)^{-1} A^T W.
            x = AtWAInv.times(AtW.times(z));
            // normF of a column vector is its Euclidean norm.
            if (x.minus(xOld).normF() < this.mTol) {
                break;
            }
        }
        this.glmCoefficients = x.getColumnPackedCopy();
        this.updateStatistics(AtWAInv);
        return this.glmCoefficients;
    }

    /**
     * Fills {@code mStats} from the final iteration.
     *
     * <p>Standard errors are the square roots of the diagonal of {@code vcovmat}. NOTE(review): no
     * dispersion scaling is applied here; confirm this is intended for every family.
     *
     * @param vcovmat {@code (A^T W A)^{-1}} from the last IRLS iteration (p-by-p)
     */
    private void updateStatistics(Matrix vcovmat) {
        int p = vcovmat.getRowDimension(); // number of coefficients
        int m = this.b.getRowDimension();  // number of observations
        double[] stdErr = this.mStats.getStandardErrors();
        double[][] vcov = this.mStats.getVCovMatrix();
        double[] residuals = this.mStats.getResiduals();
        for (int i = 0; i < p; ++i) {
            stdErr[i] = Math.sqrt(vcovmat.get(i, i));
            for (int j = 0; j < p; ++j) {
                vcov[i][j] = vcovmat.get(i, j);
            }
        }
        double[] outcomes = new double[m];
        for (int i = 0; i < m; ++i) {
            double eta = 0.0;
            for (int j = 0; j < p; ++j) {
                eta += this.A.get(i, j) * this.glmCoefficients[j];
            }
            outcomes[i] = this.b.get(i, 0);
            residuals[i] = outcomes[i] - this.linkFunc.GetInvLink(eta);
        }
        double residualStdDev = StdDev.apply(residuals, 0.0);
        double responseMean = Mean.apply(outcomes);
        double responseVariance = Variance.apply(outcomes, responseMean);
        double unexplainedRatio = residualStdDev * residualStdDev / responseVariance;
        this.mStats.setResidualStdDev(residualStdDev);
        this.mStats.setResponseMean(responseMean);
        this.mStats.setResponseVariance(responseVariance);
        this.mStats.setR2(1.0 - unexplainedRatio);
        // Adjusted R^2 scales by (observations - 1) / (observations - coefficients - 1).
        this.mStats.setAdjustedR2(1.0 - unexplainedRatio * (double) (m - 1) / (double) (m - p - 1));
    }
}

