/*
 * Decompiled with CFR 0.152.
 */
package de.jungblut.classification.nn;

import de.jungblut.classification.nn.RBMCostFunction;
import de.jungblut.classification.nn.TrainingType;
import de.jungblut.classification.nn.WeightMatrix;
import de.jungblut.math.DoubleMatrix;
import de.jungblut.math.DoubleVector;
import de.jungblut.math.activation.ActivationFunction;
import de.jungblut.math.activation.ActivationFunctionSelector;
import de.jungblut.math.dense.DenseDoubleMatrix;
import de.jungblut.math.dense.DenseDoubleVector;
import de.jungblut.math.minimize.DenseMatrixFolder;
import de.jungblut.math.minimize.GradientDescent;
import de.jungblut.math.minimize.Minimizer;
import de.jungblut.writable.MatrixWritable;
import java.io.DataInputStream;
import java.io.DataOutput;
import java.io.IOException;
import java.util.Arrays;
import java.util.Random;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;

public final class RBM {
    private static final Logger LOG = LogManager.getLogger(RBM.class);
    private final int[] layerSizes;
    private final DoubleMatrix[] weights;
    private final ActivationFunction activationFunction;
    private TrainingType type = TrainingType.CPU;
    private double lambda;
    private boolean stochastic;
    private boolean verbose;
    private int miniBatchSize = 0;
    private int batchParallelism = 1;
    private long seed;

    private RBM(int[] stackedHiddenLayerSizes, ActivationFunction activationFunction, TrainingType type) {
        this.layerSizes = stackedHiddenLayerSizes;
        this.activationFunction = activationFunction;
        this.weights = new DenseDoubleMatrix[this.layerSizes.length];
        this.type = type;
        this.seed = System.currentTimeMillis();
    }

    private RBM(RBMBuilder rbmBuilder) {
        this(rbmBuilder.layerSizes, rbmBuilder.function, rbmBuilder.type);
        this.lambda = rbmBuilder.lambda;
        this.verbose = rbmBuilder.verbose;
        this.miniBatchSize = rbmBuilder.miniBatchSize;
        this.batchParallelism = rbmBuilder.batchParallelism;
        this.stochastic = rbmBuilder.stochastic;
    }

    public void train(DoubleVector[] trainingSet, double alpha, int numIterations) {
        this.train(trainingSet, new GradientDescent(alpha, 0.0), numIterations);
    }

    public void train(DoubleVector[] trainingSet, Minimizer minimizer, int numIterations) {
        DoubleVector[] currentTrainingSet = Arrays.copyOf(trainingSet, trainingSet.length);
        for (int i = 0; i < this.layerSizes.length; ++i) {
            DoubleMatrix thetaMat;
            if (this.verbose) {
                LOG.info("Training stack at height: " + i);
            }
            DenseDoubleMatrix start = new DenseDoubleMatrix(this.layerSizes[i] + 1, currentTrainingSet[0].getDimension() + 1, new Random(this.seed)).multiply(0.1);
            DoubleVector folded = DenseMatrixFolder.foldMatrices(new DoubleMatrix[]{start});
            start = null;
            RBMCostFunction fnc = new RBMCostFunction(currentTrainingSet, this.miniBatchSize, this.batchParallelism, this.layerSizes[i], this.activationFunction, this.type, this.lambda, this.seed, this.stochastic);
            DoubleVector theta = minimizer.minimize(fnc, folded, numIterations, this.verbose);
            this.weights[i] = thetaMat = DenseMatrixFolder.unfoldMatrices(theta, fnc.getUnfoldParameters())[0];
            if (i + 1 == this.layerSizes.length) continue;
            for (int row = 0; row < currentTrainingSet.length; ++row) {
                currentTrainingSet[row] = this.computeHiddenActivations(currentTrainingSet[row], this.weights[i]);
                currentTrainingSet[row] = currentTrainingSet[row].slice(1, currentTrainingSet[row].getDimension());
                if (!this.verbose || row % 100 != 0) continue;
                LOG.info("Predicting row " + row + " / " + currentTrainingSet.length);
            }
        }
    }

    public DoubleVector predict(DoubleVector input) {
        DoubleVector lastOutput = input;
        for (int i = 0; i < this.layerSizes.length; ++i) {
            lastOutput = this.computeHiddenActivations(lastOutput, this.weights[i]);
        }
        return lastOutput.slice(1, lastOutput.getDimension());
    }

    public DoubleVector reconstructInput(DoubleVector hiddenActivations) {
        DoubleVector lastOutput = hiddenActivations;
        for (int i = this.weights.length - 1; i >= 0; --i) {
            lastOutput = this.computeHiddenActivations(lastOutput, this.weights[i].transpose());
        }
        return lastOutput.slice(1, lastOutput.getDimension());
    }

    public DoubleMatrix[] getWeights() {
        return this.weights;
    }

    public WeightMatrix[] getNeuralNetworkWeights(int outputLayerSize) {
        WeightMatrix[] toReturn = new WeightMatrix[this.weights.length + 1];
        for (int i = 0; i < this.weights.length; ++i) {
            toReturn[i] = new WeightMatrix(this.weights[i].slice(1, this.weights[i].getRowCount(), 0, this.weights[i].getColumnCount()));
        }
        toReturn[toReturn.length - 1] = new WeightMatrix(toReturn[toReturn.length - 2].getWeights().getRowCount(), outputLayerSize);
        return toReturn;
    }

    public void setSeed(long seed) {
        this.seed = seed;
    }

    private DoubleVector computeHiddenActivations(DoubleVector input, DoubleMatrix theta) {
        DenseDoubleVector biased = new DenseDoubleVector(1.0, input.toArray());
        return this.activationFunction.apply(theta.multiplyVectorRow((DoubleVector)biased));
    }

    public static void serialize(RBM model, DataOutput out) throws IOException {
        out.writeInt(model.layerSizes.length);
        for (int layer : model.layerSizes) {
            out.writeInt(layer);
        }
        for (DoubleMatrix mat : model.weights) {
            new MatrixWritable(mat).write(out);
        }
        out.writeUTF(model.activationFunction.getClass().getName());
    }

    public static RBM deserialize(DataInputStream in) throws IOException {
        int layers = in.readInt();
        int[] sizes = new int[layers];
        for (int i = 0; i < layers; ++i) {
            sizes[i] = in.readInt();
        }
        DoubleMatrix[] array = new DoubleMatrix[layers];
        for (int i = 0; i < layers; ++i) {
            MatrixWritable mv = new MatrixWritable();
            mv.readFields(in);
            array[i] = mv.getMatrix();
        }
        ActivationFunction func = null;
        try {
            func = (ActivationFunction)Class.forName(in.readUTF()).newInstance();
        }
        catch (ClassNotFoundException | IllegalAccessException | InstantiationException e) {
            throw new RuntimeException(e);
        }
        RBM model = new RBM(sizes, func, TrainingType.CPU);
        for (int i = 0; i < layers; ++i) {
            model.weights[i] = array[i];
        }
        return model;
    }

    public static RBM single(int numHiddenNodes, ActivationFunction func) {
        return new RBM(new int[]{numHiddenNodes}, func, TrainingType.CPU);
    }

    public static RBM stacked(ActivationFunction func, int ... numHiddenNodes) {
        return new RBM(numHiddenNodes, func, TrainingType.CPU);
    }

    public static RBM single(int numHiddenNodes) {
        return new RBM(new int[]{numHiddenNodes}, ActivationFunctionSelector.SIGMOID.get(), TrainingType.CPU);
    }

    public static RBM stacked(int ... numHiddenNodes) {
        return new RBM(numHiddenNodes, ActivationFunctionSelector.SIGMOID.get(), TrainingType.CPU);
    }

    public static RBM singleGPU(int numHiddenNodes, ActivationFunction func) {
        return new RBM(new int[]{numHiddenNodes}, func, TrainingType.GPU);
    }

    public static RBM stackedGPU(ActivationFunction func, int ... numHiddenNodes) {
        return new RBM(numHiddenNodes, func, TrainingType.GPU);
    }

    public static class RBMBuilder {
        private final int[] layerSizes;
        private final ActivationFunction function;
        private TrainingType type = TrainingType.CPU;
        private double lambda;
        private boolean verbose = false;
        private boolean stochastic = false;
        private int miniBatchSize;
        private int batchParallelism = Runtime.getRuntime().availableProcessors();

        private RBMBuilder(int[] layer, ActivationFunction activation) {
            this.layerSizes = layer;
            this.function = activation;
        }

        public RBMBuilder trainingType(TrainingType type) {
            this.type = type;
            return this;
        }

        public RBMBuilder lambda(double lambda) {
            this.lambda = lambda;
            return this;
        }

        public RBMBuilder miniBatchSize(int size) {
            this.miniBatchSize = size;
            return this;
        }

        public RBMBuilder batchParallelism(int numThreads) {
            this.batchParallelism = numThreads;
            return this;
        }

        public RBMBuilder verbose() {
            return this.verbose(true);
        }

        public RBMBuilder stochastic() {
            return this.stochastic(true);
        }

        public RBMBuilder stochastic(boolean stochastic) {
            this.stochastic = stochastic;
            return this;
        }

        public RBMBuilder verbose(boolean verbose) {
            this.verbose = verbose;
            return this;
        }

        public static RBMBuilder create(ActivationFunction activation, int ... layer) {
            return new RBMBuilder(layer, activation);
        }

        public RBM build() {
            return new RBM(this);
        }
    }
}

