/*
 * Decompiled with CFR 0.152.
 */
package com.github.chen0040.glm.solvers;

import com.github.chen0040.glm.enums.GlmDistributionFamily;
import com.github.chen0040.glm.links.AbstractLinkFunction;
import com.github.chen0040.glm.links.IdentityLinkFunction;
import com.github.chen0040.glm.links.InverseLinkFunction;
import com.github.chen0040.glm.links.InverseSquaredLinkFunction;
import com.github.chen0040.glm.links.LinkFunction;
import com.github.chen0040.glm.links.LogLinkFunction;
import com.github.chen0040.glm.links.LogitLinkFunction;
import com.github.chen0040.glm.maths.MatrixOp;
import com.github.chen0040.glm.metrics.GlmStatistics;
import com.github.chen0040.glm.search.CostEvaluationMethod;
import com.github.chen0040.glm.search.GradientEvaluationMethod;
import com.github.chen0040.glm.search.LocalSearch;
import com.github.chen0040.glm.search.TerminationEvaluationMethod;
import com.github.chen0040.glm.search.methods.cgs.NonlinearCGSearch;
import com.github.chen0040.glm.search.solutions.NumericSolution;
import java.util.Random;
import sun.reflect.generics.reflectiveObjects.NotImplementedException;

/**
 * Fits a generalised linear model by numerically minimising a regularised squared-error cost.
 *
 * <p>The model predicts {@code g(A x)}, where {@code g} is the inverse of {@link #linkFunc},
 * {@code A} is the design matrix (one row per observation) and {@code x} is the coefficient
 * vector. {@link #solve()} minimises
 * <pre>
 *   J(x) = 1/(2m) * sum_i (g(A_i . x) - b_i)^2  +  lambda/(2m) * sum_{j>=1} x_j^2
 * </pre>
 * using the configured {@link LocalSearch}, starting from a random point. Coefficient 0 is
 * excluded from the L2 penalty (presumably the intercept column of {@code A} — confirm against
 * callers).
 */
public class GlmAlgorithm implements Cloneable {
    /** Shared source of random starting points for {@link #solve()}. */
    private static final Random random = new Random();

    /** Link function; its inverse maps the linear predictor onto the response scale. */
    protected LinkFunction linkFunc;
    /** Iteration cap enforced by the default termination criterion. */
    protected int maxIters = 25;
    /** Convergence tolerance: an improvement smaller than this ends the default search. */
    protected double mTol = 1.0E-6;
    /** L2 regularisation strength applied to every coefficient except index 0. */
    protected double mRegularizationLambda = 0.0;
    /** Distribution family; selects the default link function and the variance function. */
    protected GlmDistributionFamily mDistributionFamily;
    /** Fit statistics; replaced with freshly computed statistics after each {@link #solve()}. */
    protected GlmStatistics mStats = new GlmStatistics();
    /** Coefficients from the last {@link #solve()}, or {@code null} if not yet solved. */
    protected double[] glmCoefficients;
    private LocalSearch solver;
    private double[][] A;
    private double[] b;

    /**
     * Default termination criterion, bound to THIS instance's {@link #maxIters} and {@link #mTol}.
     * Kept separately so {@link #copy(GlmAlgorithm)} can tell whether {@link #shouldTerminate}
     * still holds the default (and must be re-bound) or a caller-supplied replacement.
     */
    private final TerminationEvaluationMethod defaultShouldTerminate = (state, iteration) -> {
        // Always honour the iteration cap, so a stalled search cannot run forever.
        if (iteration >= this.maxIters) {
            return true;
        }
        // An improving step whose gain is below the tolerance means the search has converged.
        return state.improved() && state.improvement() < this.mTol;
    };
    /** Termination criterion handed to the solver; defaults to {@link #defaultShouldTerminate}. */
    protected TerminationEvaluationMethod shouldTerminate = this.defaultShouldTerminate;

    /** Default cost callback; delegates to {@link #computeCost(double[])} on this instance. */
    private final CostEvaluationMethod defaultEvaluateCost = new CostEvaluationMethod() {
        @Override
        public double apply(double[] x, double[] lowerBounds, double[] upperBounds, Object constraint) {
            return GlmAlgorithm.this.computeCost(x);
        }
    };
    /** Cost callback handed to the solver; the bounds and constraint arguments are ignored. */
    protected CostEvaluationMethod evaluateCost = this.defaultEvaluateCost;

    /** Default gradient callback; delegates to {@link #computeGradient} on this instance. */
    private final GradientEvaluationMethod defaultEvaluateGradient = new GradientEvaluationMethod() {
        @Override
        public void apply(double[] x, double[] gradx, double[] lowerBounds, double[] upperBounds, Object constraint) {
            GlmAlgorithm.this.computeGradient(x, gradx);
        }
    };
    /** Gradient callback handed to the solver; the bounds and constraint arguments are ignored. */
    protected GradientEvaluationMethod evaluateGradient = this.defaultEvaluateGradient;

    /** Creates an unconfigured instance; intended to be filled in via {@link #copy(GlmAlgorithm)}. */
    public GlmAlgorithm() {
    }

    /**
     * Creates a model with an explicit link function.
     *
     * @param distribution distribution family of the response
     * @param linkFunc     link function to use instead of the family default
     * @param A            design matrix, one row per observation (stored by reference)
     * @param b            observed targets, one per row of {@code A} (stored by reference)
     * @param solver       local search used by {@link #solve()}
     */
    public GlmAlgorithm(GlmDistributionFamily distribution, LinkFunction linkFunc, double[][] A, double[] b, LocalSearch solver) {
        this.mDistributionFamily = distribution;
        this.solver = solver;
        this.linkFunc = linkFunc;
        this.A = A;
        this.b = b;
        this.mStats = new GlmStatistics(A[0].length, b.length);
    }

    /**
     * Creates a model using the default link function for {@code distribution}.
     *
     * @throws UnsupportedOperationException if the family has no default link function
     */
    public GlmAlgorithm(GlmDistributionFamily distribution, double[][] A, double[] b, LocalSearch solver) {
        this(distribution, GlmAlgorithm.getLinkFunction(distribution), A, b, solver);
    }

    /** Creates a model with the family's default link function and a {@link NonlinearCGSearch} solver. */
    public GlmAlgorithm(GlmDistributionFamily distribution, double[][] A, double[] b) {
        this(distribution, A, b, new NonlinearCGSearch());
    }

    /**
     * Creates a model with only a distribution family; no data or solver is set, so
     * {@link #solve()} cannot be used until {@link #copy(GlmAlgorithm)} supplies them.
     */
    public GlmAlgorithm(GlmDistributionFamily distribution) {
        this.linkFunc = GlmAlgorithm.getLinkFunction(distribution);
        this.mDistributionFamily = distribution;
    }

    /**
     * Creates a model that works on a private copy of {@code A} ({@code b} is still shared).
     *
     * @param maxIters iteration cap; ignored (default kept) unless positive
     */
    public GlmAlgorithm(GlmDistributionFamily distribution, double[][] A, double[] b, LocalSearch solver, int maxIters) {
        // Delegation also fixes the statistics shape: it is now (columns, rows) like every other
        // constructor, instead of (rows, rows).
        this(distribution, GlmAlgorithm.deepCopy(A), b, solver);
        if (maxIters > 0) {
            this.maxIters = maxIters;
        }
    }

    /**
     * Returns the canonical-style default link function for a distribution family.
     *
     * @throws UnsupportedOperationException if the family has no mapping here
     * @throws NullPointerException if {@code distribution} is {@code null} (enum switch on null)
     */
    public static LinkFunction getLinkFunction(GlmDistributionFamily distribution) {
        switch (distribution) {
            case Bernouli:
            case Binomial:
            case Categorical:
            case Multinomial:
                return new LogitLinkFunction();
            case Exponential:
            case Gamma:
                return new InverseLinkFunction();
            case InverseGaussian:
                return new InverseSquaredLinkFunction();
            case Normal:
                return new IdentityLinkFunction();
            case Poisson:
                return new LogLinkFunction();
        }
        throw new UnsupportedOperationException("Unsupported distribution family: " + distribution);
    }

    /**
     * Copies a link function when it is an {@link AbstractLinkFunction}. Otherwise returns it as is:
     * {@code null} stays {@code null}, and unknown implementations are shared rather than rejected.
     */
    private LinkFunction clone(LinkFunction rhs) {
        if (rhs instanceof AbstractLinkFunction) {
            return ((AbstractLinkFunction) rhs).makeCopy();
        }
        return rhs;
    }

    /** Deep-copies a matrix row by row; {@code null} matrix or rows are preserved as {@code null}. */
    private static double[][] deepCopy(double[][] matrix) {
        if (matrix == null) {
            return null;
        }
        double[][] copy = new double[matrix.length][];
        for (int i = 0; i < matrix.length; ++i) {
            copy[i] = matrix[i] == null ? null : matrix[i].clone();
        }
        return copy;
    }

    /**
     * Makes this instance an independent copy of {@code rhs}.
     *
     * <p>The default cost, gradient and termination callbacks are bound to the instance that
     * created them. If {@code rhs} still uses its defaults, this instance keeps its own, so the
     * copy evaluates its own data and settings. Callbacks that were explicitly replaced on
     * {@code rhs} are shared by reference.
     */
    public void copy(GlmAlgorithm rhs) {
        this.linkFunc = this.clone(rhs.linkFunc);
        this.maxIters = rhs.maxIters;
        this.mTol = rhs.mTol;
        this.mRegularizationLambda = rhs.mRegularizationLambda;
        this.mDistributionFamily = rhs.mDistributionFamily;
        this.mStats = rhs.mStats == null ? null : (GlmStatistics) rhs.mStats.clone();
        this.glmCoefficients = rhs.glmCoefficients == null ? null : rhs.glmCoefficients.clone();
        this.solver = rhs.solver == null ? null : rhs.solver.makeCopy();
        this.A = GlmAlgorithm.deepCopy(rhs.A);
        this.b = rhs.b == null ? null : rhs.b.clone();
        this.shouldTerminate = rhs.shouldTerminate == rhs.defaultShouldTerminate
                ? this.defaultShouldTerminate : rhs.shouldTerminate;
        this.evaluateCost = rhs.evaluateCost == rhs.defaultEvaluateCost
                ? this.defaultEvaluateCost : rhs.evaluateCost;
        this.evaluateGradient = rhs.evaluateGradient == rhs.defaultEvaluateGradient
                ? this.defaultEvaluateGradient : rhs.evaluateGradient;
    }

    /** Returns a new, independent copy of this model (see {@link #copy(GlmAlgorithm)}). */
    public GlmAlgorithm makeCopy() {
        GlmAlgorithm clone = new GlmAlgorithm();
        clone.copy(this);
        return clone;
    }

    public double getTol() {
        return this.mTol;
    }

    public void setTol(double value) {
        this.mTol = value;
    }

    public GlmDistributionFamily getDistributionFamily() {
        return this.mDistributionFamily;
    }

    /**
     * Predicts the response for one input row as {@code g(sum_i coef_i * input_i)}.
     *
     * <p>Uses {@code input_0.length} coefficients, so the input must not be longer than the
     * coefficient vector.
     *
     * @return the prediction, or {@link Double#NaN} if the model has not been solved yet
     */
    public double predict(double[] input_0) {
        if (this.glmCoefficients == null) {
            return Double.NaN;
        }
        double linearPredictor = 0.0;
        for (int i = 0; i < input_0.length; ++i) {
            linearPredictor += this.glmCoefficients[i] * input_0[i];
        }
        return this.linkFunc.GetInvLink(linearPredictor);
    }

    /**
     * Returns the family's variance function evaluated at mean {@code g}.
     *
     * @throws UnsupportedOperationException if the family has no mapping here
     */
    protected double getVariance(double g) {
        switch (this.mDistributionFamily) {
            case Bernouli:
            case Binomial:
            case Categorical:
            case Multinomial:
                return g * (1.0 - g);
            case Exponential:
            case Gamma:
                return g * g;
            case InverseGaussian:
                return g * g * g;
            case Normal:
                return 1.0;
            case Poisson:
                return g;
        }
        throw new UnsupportedOperationException("Unsupported distribution family: " + this.mDistributionFamily);
    }

    public int getMaxIters() {
        return this.maxIters;
    }

    public void setMaxIters(int value) {
        this.maxIters = value;
    }

    /** Returns the fitted coefficients (internal array, not a copy), or {@code null} before solving. */
    public double[] getCoefficients() {
        return this.glmCoefficients;
    }

    public GlmStatistics getStatistics() {
        return this.mStats;
    }

    /**
     * Fits the model from a random starting point in [0, 1)^n. The result is therefore not
     * deterministic across calls.
     *
     * @return the fitted coefficients (same array as {@link #getCoefficients()})
     */
    public double[] solve() {
        int n = this.A[0].length;
        double[] x0 = new double[n];
        for (int i = 0; i < n; ++i) {
            x0[i] = random.nextDouble();
        }
        NumericSolution s = this.solver.minimize(x0, this.evaluateCost, this.evaluateGradient, this.shouldTerminate, null);
        this.glmCoefficients = s.values();
        this.updateStatistics();
        return this.getCoefficients();
    }

    /** Recomputes {@link #mStats} for the current data and coefficients. */
    private void updateStatistics() {
        this.mStats = new GlmStatistics(this.A, this.b, this.glmCoefficients);
    }

    /**
     * Regularised squared-error cost J(x) (see class doc); coefficient 0 is not penalised.
     */
    private double computeCost(double[] x) {
        int m = this.b.length;
        double[] linearPredictor = MatrixOp.Multiply(this.A, x);
        double sumSquaredError = 0.0;
        for (int i = 0; i < m; ++i) {
            double residual = this.linkFunc.GetInvLink(linearPredictor[i]) - this.b[i];
            sumSquaredError += residual * residual;
        }
        double cost = sumSquaredError / (double) (2 * m);
        for (int j = 1; j < x.length; ++j) {
            cost += this.mRegularizationLambda * x[j] * x[j] / (double) (2 * m);
        }
        return cost;
    }

    /**
     * Writes dJ/dx into {@code gradx}:
     * {@code (1/m) sum_j (g_j - b_j) g'_j A_ji}, plus {@code lambda x_i / m} for {@code i >= 1}.
     */
    private void computeGradient(double[] x, double[] gradx) {
        int m = this.b.length;
        int n = this.A[0].length;
        double[] linearPredictor = MatrixOp.Multiply(this.A, x);
        // Per-row factor (g - b) * g' does not depend on the column, so compute it once.
        double[] rowWeight = new double[m];
        for (int j = 0; j < m; ++j) {
            double g = this.linkFunc.GetInvLink(linearPredictor[j]);
            double gprime = this.linkFunc.GetInvLinkDerivative(linearPredictor[j]);
            rowWeight[j] = (g - this.b[j]) * gprime;
        }
        for (int i = 0; i < n; ++i) {
            double crossprod = 0.0;
            for (int j = 0; j < m; ++j) {
                crossprod += rowWeight[j] * this.A[j][i];
            }
            gradx[i] = crossprod / (double) m;
            if (i != 0) {
                gradx[i] += this.mRegularizationLambda * x[i] / (double) m;
            }
        }
    }
}

