/*
 * Decompiled with CFR 0.152.
 */
package com.github.chen0040.glm.solvers;

import com.github.chen0040.data.frame.Coefficients;
import com.github.chen0040.data.frame.DataFrame;
import com.github.chen0040.data.frame.DataRow;
import com.github.chen0040.data.utils.CollectionUtils;
import com.github.chen0040.glm.enums.GlmDistributionFamily;
import com.github.chen0040.glm.enums.GlmSolverType;
import com.github.chen0040.glm.metrics.GlmStatistics;
import com.github.chen0040.glm.solvers.GlmAlgorithm;
import com.github.chen0040.glm.solvers.GlmAlgorithmIrls;
import com.github.chen0040.glm.solvers.GlmAlgorithmIrlsQrNewton;
import com.github.chen0040.glm.solvers.GlmAlgorithmIrlsSvdNewton;
import com.github.chen0040.glm.solvers.OneVsOneGlmClassifier;
import java.util.function.Supplier;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Generalized linear model front-end.
 *
 * <p>Builds a design matrix from a {@link DataFrame}, delegates fitting to a {@link GlmAlgorithm}
 * chosen by {@link GlmSolverType}, and stores the solution in {@link Coefficients}. Every input
 * row gets a leading constant {@code 1.0} column (the intercept term), both when fitting and
 * when predicting.
 */
public class Glm {
    private static final Logger logger = LoggerFactory.getLogger(Glm.class);
    private GlmAlgorithm solver;
    private GlmDistributionFamily distributionFamily;
    private GlmSolverType solverType;
    private Coefficients coefficients;
    private String name;

    /**
     * Copies the state of {@code that} into this model.
     *
     * <p>The solver and coefficients are copied via their {@code makeCopy()} methods; the
     * distribution family, solver type and name are copied by reference.
     *
     * @param that the model to copy from; must not be {@code null}
     */
    public void copy(Glm that) {
        this.solver = that.solver == null ? null : that.solver.makeCopy();
        this.distributionFamily = that.distributionFamily;
        this.solverType = that.solverType;
        this.coefficients = that.coefficients == null ? null : that.coefficients.makeCopy();
        // Previously omitted, so copies silently lost their identifier.
        this.name = that.name;
    }

    /**
     * Returns a new model holding a copy of this model's state.
     *
     * @return a copy of this model
     */
    public Glm makeCopy() {
        Glm clone = new Glm();
        clone.copy(this);
        return clone;
    }

    /**
     * Creates an unfitted model.
     *
     * @param solverType which solver implementation {@link #fit(DataFrame)} will use
     * @param distributionFamily the distribution family handed to the solver
     */
    public Glm(GlmSolverType solverType, GlmDistributionFamily distributionFamily) {
        this.solverType = solverType;
        this.distributionFamily = distributionFamily;
        this.coefficients = new Coefficients();
    }

    /** Creates an unfitted model using {@code GlmIrls} and the {@code Normal} family. */
    public Glm() {
        this(GlmSolverType.GlmIrls, GlmDistributionFamily.Normal);
    }

    /** @return the distribution family passed to the solver */
    public GlmDistributionFamily getDistributionFamily() {
        return this.distributionFamily;
    }

    /** @param distributionFamily the distribution family to use on the next {@link #fit} */
    public void setDistributionFamily(GlmDistributionFamily distributionFamily) {
        this.distributionFamily = distributionFamily;
    }

    /** @return the solver type used by {@link #fit} */
    public GlmSolverType getSolverType() {
        return this.solverType;
    }

    /** @param solverType the solver type to use on the next {@link #fit} */
    public void setSolverType(GlmSolverType solverType) {
        this.solverType = solverType;
    }

    /**
     * Predicts a value for a single row.
     *
     * <p>Prepends the intercept column and delegates to the current solver's {@code predict}.
     *
     * @param tuple the input row
     * @return the solver's prediction for {@code tuple}
     * @throws IllegalStateException if no solver is available (model not fitted and none set)
     */
    public double transform(DataRow tuple) {
        if (this.solver == null) {
            throw new IllegalStateException("The model has not been fitted; call fit() or setSolver() first");
        }
        return this.solver.predict(withIntercept(tuple.toArray()));
    }

    /**
     * Instantiates the solver matching {@link #getSolverType()}.
     *
     * @param A the design matrix, one row per sample (intercept column included)
     * @param b the target values, one per sample
     * @return the solver, or {@code null} if the solver type is not recognised
     */
    protected GlmAlgorithm createSolver(double[][] A, double[] b) {
        if (this.solverType == GlmSolverType.GlmNaive) {
            return new GlmAlgorithm(this.distributionFamily, A, b);
        }
        if (this.solverType == GlmSolverType.GlmIrlsQr) {
            return new GlmAlgorithmIrlsQrNewton(this.distributionFamily, A, b);
        }
        if (this.solverType == GlmSolverType.GlmIrls) {
            return new GlmAlgorithmIrls(this.distributionFamily, A, b);
        }
        if (this.solverType == GlmSolverType.GlmIrlsSvd) {
            return new GlmAlgorithmIrlsSvdNewton(this.distributionFamily, A, b);
        }
        return null;
    }

    /**
     * Fits the model to every row of {@code dataFrame}.
     *
     * <p>Records the frame's input columns as coefficient descriptors, builds the design matrix
     * (intercept column first) and target vector, solves, and stores the solution as the
     * coefficient values.
     *
     * @param dataFrame the training data
     * @throws IllegalStateException if the solver type is unsupported or the solver returns no
     *     solution
     */
    public void fit(DataFrame dataFrame) {
        int m = dataFrame.rowCount();
        if (this.coefficients == null) {
            // Possible after setCoefficients(null) or copying such a model.
            this.coefficients = new Coefficients();
        }
        this.coefficients.setDescriptors(dataFrame.getInputColumns());

        double[][] X = new double[m][];
        double[] y = new double[m];
        for (int i = 0; i < m; ++i) {
            DataRow tuple = dataFrame.row(i);
            X[i] = withIntercept(tuple.toArray());
            y[i] = tuple.target();
        }

        GlmAlgorithm created = this.createSolver(X, y);
        if (created == null) {
            // Fail with a clear message instead of an NPE, and keep any previous solver intact.
            throw new IllegalStateException("Unsupported solver type: " + this.solverType);
        }
        this.solver = created;
        double[] x_best = this.solver.solve();
        if (x_best == null) {
            throw new IllegalStateException("The solver failed");
        }
        this.coefficients.setValues(CollectionUtils.toList(x_best));
    }

    /**
     * Returns a new array holding {@code 1.0} (the intercept term) followed by {@code features}.
     */
    private static double[] withIntercept(double[] features) {
        double[] row = new double[features.length + 1];
        row[0] = 1.0;
        System.arraycopy(features, 0, row, 1, features.length);
        return row;
    }

    /** @return the current solver's statistics, or {@code null} if there is no solver */
    public GlmStatistics showStatistics() {
        return this.solver != null ? this.solver.getStatistics() : null;
    }

    /** @return the coefficients object (same instance, not a copy) */
    public Coefficients getCoefficients() {
        return this.coefficients;
    }

    /** @return a new model with the {@code Binomial} family and the default solver type */
    public static Glm logistic() {
        Glm glm = new Glm();
        glm.setDistributionFamily(GlmDistributionFamily.Binomial);
        return glm;
    }

    /** @return a new model with the {@code Normal} family and the default solver type */
    public static Glm linear() {
        Glm glm = new Glm();
        glm.setDistributionFamily(GlmDistributionFamily.Normal);
        return glm;
    }

    /** @return the model name, or {@code null} if none was set */
    public String getName() {
        return this.name;
    }

    /** @return a new {@link OneVsOneGlmClassifier} built with its no-argument constructor */
    public static OneVsOneGlmClassifier oneVsOne() {
        return new OneVsOneGlmClassifier();
    }

    /**
     * @param binaryClassifierGenerator supplier passed to the {@link OneVsOneGlmClassifier}
     * @return a new {@link OneVsOneGlmClassifier}
     */
    public static OneVsOneGlmClassifier oneVsOne(Supplier<Glm> binaryClassifierGenerator) {
        return new OneVsOneGlmClassifier(binaryClassifierGenerator);
    }

    /** @return the current solver, or {@code null} if none */
    public GlmAlgorithm getSolver() {
        return this.solver;
    }

    /** @param solver the solver to use for {@link #transform} and {@link #showStatistics} */
    public void setSolver(GlmAlgorithm solver) {
        this.solver = solver;
    }

    /** @param coefficients the coefficients object to store (not copied) */
    public void setCoefficients(Coefficients coefficients) {
        this.coefficients = coefficients;
    }

    /** @param name the model name */
    public void setName(String name) {
        this.name = name;
    }
}

