/*
 * Decompiled with CFR 0.152.
 */
package de.jungblut.classification.nn;

import com.google.common.base.Preconditions;
import de.jungblut.classification.nn.MultilayerPerceptron;
import de.jungblut.classification.nn.TrainingType;
import de.jungblut.math.DoubleMatrix;
import de.jungblut.math.DoubleVector;
import de.jungblut.math.activation.ActivationFunction;
import de.jungblut.math.dense.DenseDoubleMatrix;
import de.jungblut.math.dense.DenseDoubleVector;
import de.jungblut.math.loss.LossFunction;
import de.jungblut.math.minimize.AbstractMiniBatchCostFunction;
import de.jungblut.math.minimize.CostGradientTuple;
import de.jungblut.math.minimize.DenseMatrixFolder;
import java.util.Random;

/**
 * Mini-batch cost function for a {@link MultilayerPerceptron}.
 *
 * <p>For a flattened parameter vector it runs forward propagation, backward propagation and
 * gradient computation over one batch. It returns the loss together with the flattened
 * gradient. Weight matrix {@code i} has shape {@code [layerSizes[i + 1], layerSizes[i] + 1]};
 * its column 0 holds the bias weights, which are not regularized.
 */
public final class MultilayerPerceptronCostFunction
extends AbstractMiniBatchCostFunction {
    private final NetworkConfiguration configuration = new NetworkConfiguration();

    /**
     * Creates a cost function that takes its hyper-parameters from the given network.
     *
     * @param network the network providing layer sizes, activations, loss, regularization,
     *        dropout, training type and mini-batch settings
     * @param features the training inputs
     * @param outcome the training targets, aligned with {@code features}
     */
    public MultilayerPerceptronCostFunction(MultilayerPerceptron network, DoubleVector[] features, DoubleVector[] outcome) {
        super(features, outcome, network.getMiniBatchSize(), network.getBatchParallelism(), network.isStochastic());
        this.configuration.lambda = network.getLambda();
        this.configuration.layerSizes = network.getLayers();
        this.configuration.unfoldParameters = computeUnfoldParameters(this.configuration.layerSizes);
        this.configuration.activations = network.getActivations();
        this.configuration.error = network.getErrorFunction();
        this.configuration.trainingType = network.getTrainingType();
        this.configuration.visibleDropoutProbability = network.getVisibleDropoutProbability();
        this.configuration.hiddenDropoutProbability = network.getHiddenDropoutProbability();
        this.configuration.rnd = new Random();
    }

    @Override
    protected CostGradientTuple evaluateBatch(DoubleVector theta, DoubleMatrix featureBatch, DoubleMatrix outcomeBatch) {
        return computeNextStep(theta, featureBatch, outcomeBatch, this.configuration);
    }

    /**
     * Computes the cost and gradient of the network for one batch.
     *
     * @param input the weights of all layers, flattened (unfolded via {@link DenseMatrixFolder}
     *        with {@code conf.unfoldParameters})
     * @param x the feature batch, one example per row. It must have exactly
     *        {@code layerSizes[0] + 1} columns; its column 0 is multiplied with the bias weights
     *        (presumably the caller supplies a column of ones there — confirm with caller).
     * @param y the target batch, one example per row
     * @param conf the network configuration
     * @return the loss (including L2 regularization) and the flattened gradient
     * @throws IllegalArgumentException if the column count of {@code x} does not match the
     *         input layer size
     */
    public static CostGradientTuple computeNextStep(DoubleVector input, DoubleMatrix x, DoubleMatrix y, NetworkConfiguration conf) {
        // Template form: the message is only built when the check fails, not on every batch.
        Preconditions.checkArgument(x.getColumnCount() - 1 == conf.layerSizes[0],
                "Input layer size must match the given vector dimension! Given: %s, expected: %s",
                x.getColumnCount() - 1, conf.layerSizes[0]);
        int m = x.getRowCount();
        DoubleMatrix[] thetas = DenseMatrixFolder.unfoldMatrices(input, conf.unfoldParameters);
        DoubleMatrix[] thetaGradients = new DoubleMatrix[thetas.length];
        DoubleMatrix[] ax = new DoubleMatrix[conf.layerSizes.length];
        DoubleMatrix[] zx = new DoubleMatrix[conf.layerSizes.length];
        dropoutVisibleLayer(x, ax, conf);
        forwardPropagate(thetas, ax, zx, conf);
        double regularization = calculateRegularization(thetas, m, conf);
        DoubleMatrix[] deltaX = backwardPropagate(y, thetas, ax, zx, conf);
        calculateGradients(thetas, thetaGradients, ax, deltaX, m, conf);
        double j = conf.error.calculateLoss(y, ax[conf.layerSizes.length - 1]) + regularization;
        return new CostGradientTuple(j, DenseMatrixFolder.foldMatrices(thetaGradients));
    }

    /**
     * Runs forward propagation and fills {@code zx[1..]} (pre-activations) and
     * {@code ax[1..]} (activations). {@code ax[0]} must already hold the input layer.
     *
     * <p>Each hidden layer's activation gets a column of ones prepended as its bias unit, and
     * hidden dropout is applied when configured. The output layer has no bias column and no
     * dropout.
     */
    public static void forwardPropagate(DoubleMatrix[] thetas, DoubleMatrix[] ax, DoubleMatrix[] zx, NetworkConfiguration conf) {
        int outputLayer = conf.layerSizes.length - 1;
        for (int i = 1; i <= outputLayer; ++i) {
            zx[i] = multiply(ax[i - 1], thetas[i - 1], false, true, conf);
            if (i == outputLayer) {
                ax[i] = conf.activations[i].apply(zx[i]);
            } else {
                ax[i] = new DenseDoubleMatrix(DenseDoubleVector.ones(zx[i].getRowCount()), conf.activations[i].apply(zx[i]));
                if (conf.hiddenDropoutProbability > 0.0) {
                    // NOTE(review): dropout runs over the whole matrix, including the bias
                    // column prepended above — verify that dropping bias units is intended.
                    dropout(conf.rnd, ax[i], conf.hiddenDropoutProbability);
                }
            }
        }
    }

    /**
     * Runs backward propagation and returns the per-layer error terms.
     *
     * <p>The output error is {@code output - y}. This is presumably the derivative for a matched
     * activation/loss pair (e.g. sigmoid/softmax with cross-entropy) — TODO confirm for the
     * configured loss. Hidden errors are propagated through the non-bias weights and multiplied
     * element-wise with the activation gradient.
     *
     * @return an array indexed by layer; entry 0 (the input layer) is left {@code null}
     */
    public static DoubleMatrix[] backwardPropagate(DoubleMatrix y, DoubleMatrix[] thetas, DoubleMatrix[] ax, DoubleMatrix[] zx, NetworkConfiguration conf) {
        int layers = conf.layerSizes.length;
        DoubleMatrix[] deltaX = new DoubleMatrix[layers];
        deltaX[layers - 1] = ax[layers - 1].subtract(y);
        for (int i = layers - 2; i > 0; --i) {
            // Column 0 holds the bias weights; the bias unit does not propagate error backwards.
            DoubleMatrix weightsWithoutBias = thetas[i].slice(0, thetas[i].getRowCount(), 1, thetas[i].getColumnCount());
            deltaX[i] = multiply(deltaX[i + 1], weightsWithoutBias, false, false, conf)
                    .multiplyElementWise(conf.activations[i].gradient(zx[i]));
        }
        return deltaX;
    }

    /**
     * Computes the gradient of every weight matrix and stores it in {@code thetaGradients}.
     *
     * <p>The gradient is {@code deltaX[i + 1]^T * ax[i] / m}. When {@code conf.lambda != 0}, the
     * weight-decay term {@code lambda / m * theta} is added to every column except column 0
     * (the bias weights). This is consistent with {@link #calculateRegularization}, which
     * excludes the bias weights from the cost.
     */
    public static void calculateGradients(DoubleMatrix[] thetas, DoubleMatrix[] thetaGradients, DoubleMatrix[] ax, DoubleMatrix[] deltaX, int m, NetworkConfiguration conf) {
        for (int i = 0; i < thetaGradients.length; ++i) {
            DoubleMatrix gradDXA = multiply(deltaX[i + 1], ax[i], true, false, conf);
            DoubleMatrix gradient = m != 1 ? gradDXA.divide((double) m) : gradDXA;
            if (conf.lambda != 0.0) {
                // Keep the unregularized bias gradient. The previous code overwrote column 0
                // with lambda/m * bias, which discarded the data gradient of the bias weights.
                DoubleVector biasGradient = gradient.getColumnVector(0);
                gradient = gradient.add(thetas[i].multiply(conf.lambda / (double) m));
                gradient.setColumnVector(0, biasGradient);
            }
            thetaGradients[i] = gradient;
        }
    }

    /**
     * Computes the L2 regularization term {@code lambda / (2m) * sum(theta^2)} over all
     * non-bias weights (columns 1..n of every matrix).
     *
     * @return the regularization term, or 0 when {@code conf.lambda == 0}
     */
    public static double calculateRegularization(DoubleMatrix[] thetas, int m, NetworkConfiguration conf) {
        double regularization = 0.0;
        if (conf.lambda != 0.0) {
            for (DoubleMatrix theta : thetas) {
                regularization += theta.slice(0, theta.getRowCount(), 1, theta.getColumnCount()).pow(2.0).sum();
            }
            regularization = conf.lambda / (2.0 * (double) m) * regularization;
        }
        return regularization;
    }

    /**
     * Sets {@code ax[0]} to the input layer.
     *
     * <p>When visible dropout is enabled, {@code ax[0]} is a deep copy of {@code x} with dropout
     * applied, so {@code x} itself is never modified. Otherwise {@code ax[0]} is {@code x}.
     */
    public static void dropoutVisibleLayer(DoubleMatrix x, DoubleMatrix[] ax, NetworkConfiguration conf) {
        if (conf.visibleDropoutProbability > 0.0) {
            ax[0] = x.deepCopy();
            dropout(conf.rnd, ax[0], conf.visibleDropoutProbability);
        } else {
            ax[0] = x;
        }
    }

    /**
     * Multiplies two matrices, optionally transposing either operand, using the configured
     * training backend.
     *
     * @throws IllegalArgumentException if {@code conf.trainingType} is not {@link TrainingType#CPU}
     */
    private static DoubleMatrix multiply(DoubleMatrix a1, DoubleMatrix a2, boolean a1Transpose, boolean a2Transpose, NetworkConfiguration conf) {
        if (conf.trainingType == TrainingType.CPU) {
            return multiplyCPU(a1, a2, a1Transpose, a2Transpose);
        }
        throw new IllegalArgumentException("No supported type " + conf.trainingType + " found!");
    }

    /** CPU implementation of {@link #multiply}. */
    private static DoubleMatrix multiplyCPU(DoubleMatrix a1, DoubleMatrix a2, boolean a1Transpose, boolean a2Transpose) {
        a2 = a2Transpose ? a2.transpose() : a2;
        a1 = a1Transpose ? a1.transpose() : a1;
        return a1.multiply(a2);
    }

    /**
     * Computes the {@code [rows, columns]} shape of each weight matrix for folding and
     * unfolding.
     *
     * @param layerSizes the number of units per layer, excluding bias units
     * @return for each consecutive layer pair {@code (i, i + 1)}, the shape
     *         {@code {layerSizes[i + 1], layerSizes[i] + 1}}; the {@code + 1} is the bias column
     */
    public static int[][] computeUnfoldParameters(int[] layerSizes) {
        int[][] unfoldParameters = new int[layerSizes.length - 1][];
        for (int i = 0; i < unfoldParameters.length; ++i) {
            unfoldParameters[i] = new int[]{layerSizes[i + 1], layerSizes[i] + 1};
        }
        return unfoldParameters;
    }

    /**
     * Applies dropout in place. Each entry is set to 0 when an independent uniform draw from
     * {@code rnd} is {@code <= p}, i.e. with probability about {@code p}. Surviving entries
     * are not rescaled.
     */
    public static void dropout(Random rnd, DoubleMatrix activations, double p) {
        for (int row = 0; row < activations.getRowCount(); ++row) {
            for (int col = 0; col < activations.getColumnCount(); ++col) {
                if (rnd.nextDouble() <= p) {
                    activations.set(row, col, 0.0);
                }
            }
        }
    }

    /** Mutable bundle of the settings needed to evaluate the network's cost and gradient. */
    public static class NetworkConfiguration {
        /** L2 regularization strength; 0 disables regularization. */
        public double lambda;
        /** Number of units per layer, excluding bias units; index 0 is the input layer. */
        public int[] layerSizes;
        /** Weight-matrix shapes, as produced by {@link #computeUnfoldParameters(int[])}. */
        public int[][] unfoldParameters;
        /** Activation function per layer; index 0 is not read by the propagation code here. */
        public ActivationFunction[] activations;
        /** Loss function used to compute the cost of the output layer. */
        public LossFunction error;
        /** Backend used for matrix multiplication; only {@link TrainingType#CPU} is supported here. */
        public TrainingType trainingType;
        /** Dropout probability for the input layer; values {@code <= 0} disable it. */
        public double visibleDropoutProbability;
        /** Dropout probability for hidden layers; values {@code <= 0} disable it. */
        public double hiddenDropoutProbability;
        /** Random source used for dropout. */
        public Random rnd;
    }
}

