/*
 * Decompiled with CFR 0.152.
 */
package de.jungblut.classification.regression;

import de.jungblut.math.DoubleMatrix;
import de.jungblut.math.DoubleVector;
import de.jungblut.math.activation.ActivationFunctionSelector;
import de.jungblut.math.dense.DenseDoubleMatrix;
import de.jungblut.math.loss.LogLoss;
import de.jungblut.math.loss.LossFunction;
import de.jungblut.math.minimize.CostFunction;
import de.jungblut.math.minimize.CostGradientTuple;
import java.util.Arrays;

/**
 * Cost function for logistic regression with optional L2 regularization,
 * suitable for use with a gradient-based minimizer.
 *
 * <p>For a parameter vector {@code theta} it returns the averaged loss reported
 * by {@link LogLoss} together with the gradient of that loss. When
 * {@code lambda} is non-zero, an L2 penalty that skips the first parameter
 * (index 0, treated as the bias) is added to both the cost and the gradient.
 *
 * <p>NOTE(review): the code reads row 0 of {@code hypothesis - y}, so {@code y}
 * is presumably a single-row matrix with one column per training example, and
 * {@code x} presumably holds one example per row with the bias feature in
 * column 0 -- confirm against callers.
 */
public final class LogisticRegressionCostFunction
implements CostFunction {
    private static final LossFunction ERROR_FUNCTION = new LogLoss();
    private final DoubleMatrix x;
    // Cached once in the constructor so the transpose is not recomputed on every evaluation.
    private final DoubleMatrix xTransposed;
    private final DoubleMatrix y;
    // Row count of x; used as the averaging denominator for cost and gradient.
    private final int m;
    private final double lambda;

    /**
     * Creates the cost function for a fixed training set.
     *
     * @param x feature matrix; its row count is used as the number of examples
     * @param y expected outcomes, compared against the sigmoid activations
     * @param lambda L2 regularization strength; {@code 0} disables regularization
     */
    public LogisticRegressionCostFunction(DoubleMatrix x, DoubleMatrix y, double lambda) {
        this.x = x;
        this.lambda = lambda;
        this.m = x.getRowCount();
        this.xTransposed = this.x.transpose();
        this.y = y;
    }

    /**
     * Evaluates cost and gradient at the given parameters.
     *
     * <p>The regularized cost is
     * {@code loss / m + lambda / (2 * m) * sum(theta_i^2 for i >= 1)} and the
     * gradient adds {@code (lambda / m) * theta_i} for {@code i >= 1}, so the
     * returned gradient is the derivative of the returned cost.
     *
     * @param theta current parameter vector; index 0 is excluded from the penalty
     * @return tuple holding the cost and its gradient with respect to {@code theta}
     */
    @Override
    public CostGradientTuple evaluateCost(DoubleVector theta) {
        DoubleVector activation = ActivationFunctionSelector.SIGMOID.get().apply(this.x.multiplyVectorRow(theta));
        // The loss function and subtract() work on matrices, so wrap the activations as a one-row matrix.
        DenseDoubleMatrix hypo = new DenseDoubleMatrix(Arrays.asList(activation));
        double j = ERROR_FUNCTION.calculateLoss(this.y, (DoubleMatrix)hypo) / (double)this.m;
        DoubleVector residual = hypo.subtract(this.y).getRowVector(0);
        DoubleVector gradient = this.xTransposed.multiplyVectorRow(residual).divide((double)this.m);
        if (this.lambda != 0.0) {
            // reg = (lambda / m) * theta with the bias slot zeroed: the gradient of the L2 penalty.
            DoubleVector reg = theta.multiply(this.lambda / (double)this.m);
            reg.set(0, 0.0);
            gradient = gradient.add(reg);
            // reg . theta = (lambda / m) * sum_{i>=1} theta_i^2, so half of it is the penalty
            // lambda / (2m) * sum_{i>=1} theta_i^2 whose derivative is exactly reg. This keeps
            // the reported cost consistent with the reported gradient and leaves the bias unpenalized.
            j += reg.dot(theta) / 2.0;
        }
        return new CostGradientTuple(j, gradient);
    }
}

