/*
 * Decompiled with CFR 0.152.
 */
package de.jungblut.classification.nn;

import de.jungblut.classification.nn.MultilayerPerceptronCostFunction;
import de.jungblut.classification.nn.TrainingType;
import de.jungblut.math.DoubleMatrix;
import de.jungblut.math.DoubleVector;
import de.jungblut.math.activation.ActivationFunction;
import de.jungblut.math.dense.DenseDoubleMatrix;
import de.jungblut.math.dense.DenseDoubleVector;
import de.jungblut.math.minimize.AbstractMiniBatchCostFunction;
import de.jungblut.math.minimize.CostGradientTuple;
import de.jungblut.math.minimize.DenseMatrixFolder;
import java.util.Random;

/**
 * Mini-batch cost function for a restricted Boltzmann machine (RBM), trained with one step of
 * contrastive divergence per batch.
 * <p>
 * The parameter vector passed to {@link #evaluateBatch} is unfolded into a single weight matrix
 * using the unfold parameters computed in the constructor. Column 0 of every hidden and
 * reconstructed-visible activation matrix is clamped to 1 so it acts as a bias unit.
 */
public final class RBMCostFunction
extends AbstractMiniBatchCostFunction {
    private final ActivationFunction activationFunction;
    private final int[][] unfoldParameters;
    private final TrainingType type;
    private final double lambda;
    // NOTE(review): shared by every evaluateBatch call; if the superclass evaluates batches on
    // several threads, the sampled hidden states are no longer reproducible from the seed alone.
    private final Random random;

    /**
     * Creates a new RBM cost function.
     *
     * @param currentTrainingSet the training vectors; the dimension of the first vector is used as
     *                           the visible layer size (the array must therefore be non-empty).
     * @param batchSize          mini-batch size, forwarded to the superclass.
     * @param numThreads         number of threads, forwarded to the superclass.
     * @param numHiddenUnits     number of hidden units; one extra unit is added (presumably the
     *                           bias unit that is clamped to 1 in {@link #evaluateBatch}).
     * @param activationFunction activation applied when computing hidden probabilities and
     *                           visible reconstructions.
     * @param type               training backend; only {@code TrainingType.CPU} is supported.
     * @param lambda             weight-decay strength; {@code 0} disables weight decay.
     * @param seed               seed of the random generator used to sample hidden states.
     * @param stochastic         forwarded to the superclass.
     */
    public RBMCostFunction(DoubleVector[] currentTrainingSet, int batchSize, int numThreads, int numHiddenUnits, ActivationFunction activationFunction, TrainingType type, double lambda, long seed, boolean stochastic) {
        // No outcomes: the constructor passes null, and evaluateBatch ignores outcomeBatch.
        super(currentTrainingSet, null, batchSize, numThreads, stochastic);
        this.activationFunction = activationFunction;
        this.type = type;
        this.lambda = lambda;
        this.random = new Random(seed);
        this.unfoldParameters = MultilayerPerceptronCostFunction.computeUnfoldParameters(new int[]{currentTrainingSet[0].getDimension(), numHiddenUnits + 1});
    }

    /**
     * Evaluates one mini-batch with a single contrastive-divergence step.
     * <p>
     * The steps are:
     * <ol>
     * <li>Positive phase: compute hidden probabilities from the data.</li>
     * <li>Sample binary hidden states from those probabilities.</li>
     * <li>Reconstruct the visible units from the sampled states.</li>
     * <li>Negative phase: recompute hidden probabilities from the reconstruction.</li>
     * </ol>
     *
     * @param input        the folded weight vector.
     * @param data         the batch, one example per row.
     * @param outcomeBatch unused.
     * @return the cost and the gradient. The cost is the squared reconstruction error summed over
     *         the batch. The gradient is the negated contrastive-divergence update, folded back
     *         into a vector for the minimizer.
     */
    @Override
    protected CostGradientTuple evaluateBatch(DoubleVector input, DoubleMatrix data, DoubleMatrix outcomeBatch) {
        DoubleMatrix theta = DenseMatrixFolder.unfoldMatrices(input, this.unfoldParameters)[0].transpose();
        double batchRows = (double) data.getRowCount();

        // Positive phase: hidden probabilities driven by the data.
        DoubleMatrix positiveHiddenProbs = this.activationFunction.apply(this.multiply(data, theta, false, false));
        RBMCostFunction.setBiasColumn(positiveHiddenProbs);
        DoubleMatrix positiveAssociations = this.multiply(data, positiveHiddenProbs, true, false);

        // Sample binary hidden states in place. The associations above were computed from the
        // probabilities first. The bias column stays 1 because Random.nextDouble() < 1.0.
        RBMCostFunction.binarize(this.random, positiveHiddenProbs);

        // Negative phase: reconstruct the visible units, then recompute hidden probabilities.
        DoubleMatrix negativeData = this.activationFunction.apply(this.multiply(positiveHiddenProbs, theta, false, true));
        RBMCostFunction.setBiasColumn(negativeData);
        DoubleMatrix negativeHiddenProbs = this.activationFunction.apply(this.multiply(negativeData, theta, false, false));
        RBMCostFunction.setBiasColumn(negativeHiddenProbs);
        DoubleMatrix negativeAssociations = this.multiply(negativeData, negativeHiddenProbs, true, false);

        double j = data.subtract(negativeData).pow(2.0).sum();
        DoubleMatrix thetaGradient = positiveAssociations.subtract(negativeAssociations).divide(batchRows);
        if (this.lambda != 0.0) {
            // L2 weight decay: penalize the weights themselves. The previous code subtracted a
            // multiple of the gradient, which only rescaled it and never regularized theta.
            // Column 0 (bias) is excluded from the decay.
            DoubleVector bias = thetaGradient.getColumnVector(0);
            thetaGradient = thetaGradient.subtract(theta.multiply(this.lambda / batchRows));
            thetaGradient.setColumnVector(0, bias);
        }
        // Negate so the ascent direction (positive minus negative associations) becomes a descent
        // gradient. Transpose back to the folded parameter orientation.
        return new CostGradientTuple(j, DenseMatrixFolder.foldMatrices(new DoubleMatrix[]{(DenseDoubleMatrix)thetaGradient.multiply(-1.0).transpose()}));
    }

    /** Clamps column 0 of {@code m} to ones in place, so that it acts as the bias unit. */
    private static void setBiasColumn(DoubleMatrix m) {
        m.setColumnVector(0, (DoubleVector)DenseDoubleVector.ones((int)m.getRowCount()));
    }

    /**
     * Multiplies two matrices, optionally transposing either operand first.
     *
     * @throws IllegalArgumentException if the configured training type is not CPU.
     */
    private DoubleMatrix multiply(DoubleMatrix a1, DoubleMatrix a2, boolean a1Transpose, boolean a2Transpose) {
        if (this.type == TrainingType.CPU) {
            return RBMCostFunction.multiplyCPU(a1, a2, a1Transpose, a2Transpose);
        }
        throw new IllegalArgumentException("Unsupported Trainingtype " + this.type);
    }

    /** CPU matrix product {@code op(a1) * op(a2)}, where {@code op} optionally transposes. */
    private static DoubleMatrix multiplyCPU(DoubleMatrix a1, DoubleMatrix a2, boolean a1Transpose, boolean a2Transpose) {
        a2 = a2Transpose ? a2.transpose() : a2;
        a1 = a1Transpose ? a1.transpose() : a1;
        return a1.multiply(a2);
    }

    /** Returns the unfold parameters used to turn the parameter vector into the weight matrix. */
    int[][] getUnfoldParameters() {
        return this.unfoldParameters;
    }

    /**
     * Binarizes every vector in place (see {@link #binarize(Random, DoubleVector)}).
     *
     * @return the same array, for chaining.
     */
    static DoubleVector[] binarize(Random r, DoubleVector[] hiddenActivations) {
        for (int i = 0; i < hiddenActivations.length; ++i) {
            RBMCostFunction.binarize(r, hiddenActivations[i]);
        }
        return hiddenActivations;
    }

    /**
     * Replaces each entry in place with 1.0 if it exceeds a fresh {@code r.nextDouble()} draw,
     * otherwise with 0.0. Entries are visited row by row.
     *
     * @return the same matrix, for chaining.
     */
    static DoubleMatrix binarize(Random r, DoubleMatrix hiddenActivations) {
        for (int i = 0; i < hiddenActivations.getRowCount(); ++i) {
            for (int j = 0; j < hiddenActivations.getColumnCount(); ++j) {
                hiddenActivations.set(i, j, hiddenActivations.get(i, j) > r.nextDouble() ? 1.0 : 0.0);
            }
        }
        return hiddenActivations;
    }

    /**
     * Replaces each entry in place with 1.0 if it exceeds a fresh {@code r.nextDouble()} draw,
     * otherwise with 0.0.
     *
     * @return the same vector, for chaining.
     */
    static DoubleVector binarize(Random r, DoubleVector v) {
        for (int j = 0; j < v.getDimension(); ++j) {
            v.set(j, v.get(j) > r.nextDouble() ? 1.0 : 0.0);
        }
        return v;
    }
}

