/*
 * Decompiled with CFR 0.152.
 */
package de.jungblut.classification.regression;

import com.google.common.base.Preconditions;
import de.jungblut.classification.AbstractClassifier;
import de.jungblut.classification.regression.LogisticRegressionCostFunction;
import de.jungblut.math.DoubleMatrix;
import de.jungblut.math.DoubleVector;
import de.jungblut.math.activation.ActivationFunctionSelector;
import de.jungblut.math.dense.DenseDoubleMatrix;
import de.jungblut.math.dense.DenseDoubleVector;
import de.jungblut.math.minimize.Minimizer;
import de.jungblut.math.sparse.SparseDoubleRowMatrix;
import de.jungblut.math.sparse.SparseDoubleVector;
import java.util.Iterator;
import java.util.Random;

public final class LogisticRegression
extends AbstractClassifier {
    private final double lambda;
    private final Minimizer minimizer;
    private final int numIterations;
    private final boolean verbose;
    private DoubleVector theta;
    private Random random;

    public LogisticRegression(double lambda, Minimizer minimizer, int numIterations, boolean verbose) {
        this.lambda = lambda;
        this.minimizer = minimizer;
        this.numIterations = numIterations;
        this.verbose = verbose;
        this.random = new Random();
    }

    public LogisticRegression(DoubleVector theta) {
        this(0.0, null, 1, false);
        this.theta = theta;
    }

    @Override
    public void train(DoubleVector[] features, DoubleVector[] outcome) {
        Preconditions.checkArgument((features.length == outcome.length ? 1 : 0) != 0, (Object)"Features and Outcomes need to match in length!");
        Object x = null;
        Object y = null;
        x = features[0].isSparse() ? new SparseDoubleRowMatrix(DenseDoubleVector.ones((int)features.length), (DoubleMatrix)new SparseDoubleRowMatrix(features)) : new DenseDoubleMatrix(DenseDoubleVector.ones((int)features.length), (DoubleMatrix)new DenseDoubleMatrix(features));
        y = outcome[0].isSparse() ? new SparseDoubleRowMatrix(outcome) : new DenseDoubleMatrix(outcome);
        y = y.transpose();
        LogisticRegressionCostFunction cnf = new LogisticRegressionCostFunction((DoubleMatrix)x, (DoubleMatrix)y, this.lambda);
        this.theta = new DenseDoubleVector(x.getColumnCount() * y.getRowCount());
        for (int i = 0; i < this.theta.getDimension(); ++i) {
            this.theta.set(i, this.random.nextDouble() * 2.0 - 1.0);
        }
        this.theta = this.minimizer.minimize(cnf, this.theta, this.numIterations, this.verbose);
    }

    @Override
    public DoubleVector predict(DoubleVector features) {
        if (features.isSparse()) {
            SparseDoubleVector tmp = new SparseDoubleVector(features.getDimension() + 1);
            tmp.set(0, 1.0);
            Iterator iterateNonZero = features.iterateNonZero();
            while (iterateNonZero.hasNext()) {
                DoubleVector.DoubleVectorElement next = (DoubleVector.DoubleVectorElement)iterateNonZero.next();
                tmp.set(next.getIndex() + 1, next.getValue());
            }
            features = tmp;
        } else {
            features = new DenseDoubleVector(1.0, features.toArray());
        }
        return new DenseDoubleVector(new double[]{ActivationFunctionSelector.SIGMOID.get().apply(features.dot(this.theta))});
    }

    public DoubleVector getTheta() {
        return this.theta;
    }

    void setRandom(Random random) {
        this.random = random;
    }
}

