/*
 * Decompiled with CFR 0.152.
 */
package de.jungblut.math.loss;

import de.jungblut.math.DoubleMatrix;
import de.jungblut.math.DoubleVector;
import de.jungblut.math.MathUtils;
import de.jungblut.math.loss.LossFunction;

/**
 * Logistic (binary cross-entropy) loss.
 *
 * <p>Every element contributes the term {@code -y * log(h) - (1 - y) * log(1 - h)}, where
 * {@code y} is the target value and {@code h} the hypothesis (prediction). Logarithms are taken
 * through {@link MathUtils}; how those helpers treat non-positive arguments (i.e. {@code h == 0}
 * or {@code h == 1}) is not visible from this class — NOTE(review): confirm before relying on
 * finite results at the extremes.
 */
public final class LogLoss implements LossFunction {

    /**
     * Computes the log loss over a matrix of targets and predictions.
     *
     * @param y the target values
     * @param hypothesis the predicted values; expected to have the same layout as {@code y}
     * @return the element-wise loss terms summed over the whole matrix, divided by the row count
     *     of {@code y} (presumably one example per row — confirm against callers)
     */
    @Override
    public double calculateLoss(DoubleMatrix y, DoubleMatrix hypothesis) {
        // NOTE(review): subtractBy(1.0) is assumed to yield 1 - x, mirroring subtractFrom in the
        // vector overload — confirm against DoubleMatrix.
        final DoubleMatrix oneMinusTarget = y.subtractBy(1.0);
        final DoubleMatrix negatedTarget = y.multiply(-1.0);
        final DoubleMatrix logOneMinusHypo = MathUtils.logMatrix(hypothesis.subtractBy(1.0));
        final DoubleMatrix logHypo = MathUtils.logMatrix(hypothesis);

        // (1 - y) * log(1 - h), element by element.
        final DoubleMatrix falseClassTerm = oneMinusTarget.multiplyElementWise(logOneMinusHypo);
        // -y * log(h) - (1 - y) * log(1 - h), summed over every element.
        final double total = negatedTarget.multiplyElementWise(logHypo)
                .subtract(falseClassTerm)
                .sum();
        return total / y.getRowCount();
    }

    /**
     * Computes the log loss for a single vector of targets and predictions.
     *
     * <p>Unlike the matrix overload, the summed loss is not divided by any count.
     *
     * @param y the target values
     * @param hypothesis the predicted values; expected to have the same length as {@code y}
     * @return the sum of the element-wise loss terms
     */
    @Override
    public double calculateLoss(DoubleVector y, DoubleVector hypothesis) {
        // NOTE(review): subtractFrom(1.0) is assumed to yield 1 - x, and multiply(DoubleVector)
        // to be an element-wise product — confirm against DoubleVector.
        final DoubleVector oneMinusTarget = y.subtractFrom(1.0);
        final DoubleVector negatedTarget = y.multiply(-1.0);
        final DoubleVector logOneMinusHypo = MathUtils.logVector(hypothesis.subtractFrom(1.0));
        final DoubleVector logHypo = MathUtils.logVector(hypothesis);

        // (1 - y) * log(1 - h)
        final DoubleVector falseClassTerm = oneMinusTarget.multiply(logOneMinusHypo);
        // -y * log(h) - (1 - y) * log(1 - h), summed over every element.
        return negatedTarget.multiply(logHypo)
                .subtract(falseClassTerm)
                .sum();
    }

    /**
     * Computes the gradient contribution of a single example.
     *
     * <p>Only the first component of {@code hypothesis - y} is used; any further components are
     * ignored, so this presumably targets single-output models — confirm against callers.
     *
     * @param feature the input features of the example
     * @param y the target values
     * @param hypothesis the predicted values
     * @return {@code feature} scaled by {@code (hypothesis - y)[0]}
     */
    @Override
    public DoubleVector calculateGradient(DoubleVector feature, DoubleVector y, DoubleVector hypothesis) {
        final double residual = hypothesis.subtract(y).get(0);
        return feature.multiply(residual);
    }
}

