/*
 * Decompiled with CFR 0.152.
 */
package de.jungblut.classification.nn;

import com.google.common.base.Preconditions;
import de.jungblut.classification.AbstractClassifier;
import de.jungblut.classification.nn.MultilayerPerceptronCostFunction;
import de.jungblut.classification.nn.TrainingType;
import de.jungblut.classification.nn.WeightMatrix;
import de.jungblut.math.DoubleMatrix;
import de.jungblut.math.DoubleVector;
import de.jungblut.math.activation.ActivationFunction;
import de.jungblut.math.activation.LinearActivationFunction;
import de.jungblut.math.activation.SigmoidActivationFunction;
import de.jungblut.math.dense.DenseDoubleVector;
import de.jungblut.math.loss.LossFunction;
import de.jungblut.math.minimize.CostFunction;
import de.jungblut.math.minimize.DenseMatrixFolder;
import de.jungblut.math.minimize.Minimizer;
import de.jungblut.writable.MatrixWritable;
import java.io.DataInput;
import java.io.DataOutput;
import java.io.IOException;

/**
 * A fully connected feed-forward network (multilayer perceptron).
 *
 * <p>For every layer {@code i >= 1}, the activations of layer {@code i - 1} (with a leading bias
 * unit of value {@code 1.0}) are multiplied by {@code weights[i - 1]} and passed through
 * {@code activations[i]}. The output layer does not receive a bias unit.
 *
 * <p>Instances are created through {@link MultilayerPerceptronBuilder} or restored with
 * {@link #deserialize(DataInput)}.
 */
public final class MultilayerPerceptron extends AbstractClassifier {

    /**
     * Seed value initialised from the wall clock when the class is loaded. It is not read anywhere
     * in this class; presumably other components (e.g. weight initialisation) use it — confirm
     * before relying on it.
     */
    public static long SEED = System.currentTimeMillis();

    private final WeightMatrix[] weights;
    /** Minimizer used by {@link #train(DoubleVector[], DoubleVector[])}; null for deserialized models. */
    private final Minimizer minimizer;
    /** Iteration budget for the minimizer; -1 for deserialized models. */
    private final int maxIterations;
    /** Number of units per layer, input layer first. */
    private final int[] layers;
    /** One activation function per layer; index 0 belongs to the input layer. */
    private final ActivationFunction[] activations;
    /**
     * Regularization parameter exposed to the cost function via {@link #getLambda()}. It is
     * non-final because
     * {@link #train(DoubleVector[], DoubleVector[], Minimizer, int, double, boolean)} stores the
     * value passed to it.
     */
    private double lambda;
    private final double hiddenDropoutProbability;
    private final double visibleDropoutProbability;
    private final TrainingType type;
    private final boolean verbose;
    private final LossFunction error;
    private final boolean stochastic;
    private final int miniBatchSize;
    private final int batchParallelism;

    /**
     * Creates a network from a builder configuration, validating layer, activation and weight
     * dimensions.
     *
     * @param conf the builder holding the configuration.
     * @throws NullPointerException if no layer sizes were supplied.
     * @throws IllegalArgumentException if the layer array is empty, if the number of activations
     *     does not match the number of layers, or if supplied weights have the wrong count or shape.
     */
    private MultilayerPerceptron(MultilayerPerceptronBuilder conf) {
        Preconditions.checkNotNull(conf.layer, "Layer sizes must be supplied!");
        Preconditions.checkArgument(conf.layer.length > 0, "At least one layer must be supplied!");
        this.layers = conf.layer;
        this.maxIterations = conf.maxIterations;
        this.minimizer = conf.minimizer;
        this.lambda = conf.lambda;
        this.type = conf.type;
        this.hiddenDropoutProbability = conf.hiddenDropoutProbability;
        this.visibleDropoutProbability = conf.visibleDropoutProbability;
        this.verbose = conf.verbose;
        this.error = conf.error;
        this.stochastic = conf.stochastic;
        this.miniBatchSize = conf.miniBatchSize;
        this.batchParallelism = conf.batchParallelism;

        this.activations = conf.activationFunctions != null
                ? conf.activationFunctions
                : defaultActivations(this.layers.length);
        Preconditions.checkArgument(this.layers.length == this.activations.length,
                "Size of layers and activations must match!");

        if (conf.weights == null) {
            this.weights = new WeightMatrix[this.layers.length - 1];
            for (int i = 0; i < this.weights.length; ++i) {
                this.weights[i] = new WeightMatrix(this.layers[i], this.layers[i + 1]);
            }
        } else {
            checkWeightDimensions(conf.weights, this.layers);
            this.weights = conf.weights;
        }
    }

    /**
     * Restores a network from serialized parts. The training configuration is reset to defaults
     * (no minimizer, no regularization, no dropout, CPU, non-stochastic).
     */
    private MultilayerPerceptron(int[] layers, WeightMatrix[] weights, ActivationFunction[] activations, LossFunction error) {
        this.layers = layers;
        this.weights = weights;
        this.activations = activations;
        this.error = error;
        this.minimizer = null;
        this.maxIterations = -1;
        this.lambda = 0.0;
        this.hiddenDropoutProbability = 0.0;
        this.visibleDropoutProbability = 0.0;
        this.type = TrainingType.CPU;
        this.verbose = false;
        this.stochastic = false;
        this.miniBatchSize = 0;
        this.batchParallelism = Runtime.getRuntime().availableProcessors();
    }

    /**
     * Builds the default activations: linear for the input layer, sigmoid for every other layer.
     */
    private static ActivationFunction[] defaultActivations(int numLayers) {
        ActivationFunction[] result = new ActivationFunction[numLayers];
        result[0] = new LinearActivationFunction();
        for (int i = 1; i < numLayers; ++i) {
            result[i] = new SigmoidActivationFunction();
        }
        return result;
    }

    /**
     * Checks that there is exactly one weight matrix between each pair of adjacent layers. Each
     * matrix must have {@code layers[i + 1]} rows and {@code layers[i] + 1} columns (the extra
     * column is for the bias unit).
     */
    private static void checkWeightDimensions(WeightMatrix[] weights, int[] layers) {
        Preconditions.checkArgument(weights.length == layers.length - 1,
                "Number of weight matrices must be one less than the number of layers. Given: %s. Expected: %s",
                weights.length, layers.length - 1);
        for (int i = 0; i < weights.length; ++i) {
            DoubleMatrix w = weights[i].getWeights();
            Preconditions.checkArgument(w.getRowCount() == layers[i + 1],
                    "Number of rows must match the layer size of the following layer. Given: %s. Expected: %s",
                    w.getRowCount(), layers[i + 1]);
            Preconditions.checkArgument(w.getColumnCount() == layers[i] + 1,
                    "Number of columns must match the layer size of the current layer. Given: %s. Expected: %s",
                    w.getColumnCount(), layers[i] + 1);
        }
    }

    /**
     * Runs a forward pass through the network.
     *
     * @param xi the input vector, without a bias unit.
     * @return the activations of the output layer (without a bias unit).
     */
    @Override
    public DoubleVector predict(DoubleVector xi) {
        DoubleVector activation = addBias(xi);
        int outputLayer = this.layers.length - 1;
        for (int layer = 1; layer <= outputLayer; ++layer) {
            DoubleVector netInput = this.weights[layer - 1].getWeights().multiplyVectorRow(activation);
            activation = this.activations[layer].apply(netInput);
            // Only hidden layers feed another layer, so only they need a bias unit.
            if (layer < outputLayer) {
                activation = addBias(activation);
            }
        }
        return activation;
    }

    /**
     * Runs a forward pass and binarizes the output: entries strictly greater than
     * {@code threshold} become {@code 1.0}, all others {@code 0.0}.
     *
     * @param xi the input vector, without a bias unit.
     * @param threshold the decision threshold.
     * @return the binarized output vector. This is the vector returned by
     *     {@link #predict(DoubleVector)}, modified in place.
     */
    public DoubleVector predict(DoubleVector xi, double threshold) {
        DoubleVector prediction = this.predict(xi);
        for (int i = 0; i < prediction.getLength(); ++i) {
            prediction.set(i, prediction.get(i) > threshold ? 1.0 : 0.0);
        }
        return prediction;
    }

    /** Returns a new vector consisting of {@code 1.0} followed by the entries of {@code activations}. */
    private static DoubleVector addBias(DoubleVector activations) {
        return new DenseDoubleVector(1.0, activations.toArray());
    }

    /**
     * Trains the network with the minimizer, iteration budget, lambda and verbosity configured on
     * the builder.
     */
    @Override
    public void train(DoubleVector[] features, DoubleVector[] outcome) {
        this.train(features, outcome, this.minimizer, this.maxIterations, this.lambda, this.verbose);
    }

    /**
     * Trains the network with explicitly supplied training parameters. The resulting weights are
     * stored in this model.
     *
     * @param features the training inputs.
     * @param outcome the training targets, parallel to {@code features}.
     * @param minimizer the optimizer to use; must not be null.
     * @param maxIterations the iteration budget handed to the minimizer.
     * @param lambda the regularization parameter. It is stored on this model and applies to this
     *     and subsequent training runs.
     * @param verbose whether the minimizer should report progress.
     * @return the cost of the final parameters.
     * @throws NullPointerException if {@code minimizer} is null.
     */
    public final double train(DoubleVector[] features, DoubleVector[] outcome, Minimizer minimizer, int maxIterations, double lambda, boolean verbose) {
        // The cost function only receives this model, so the requested lambda must be stored on
        // the model before the cost function is built. Previously this argument was ignored.
        this.lambda = lambda;
        MultilayerPerceptronCostFunction costFunction = new MultilayerPerceptronCostFunction(this, features, outcome);
        return this.trainInternal(minimizer, maxIterations, verbose, costFunction, this.getFoldedThetaVector());
    }

    /**
     * Minimizes {@code costFunction} starting at {@code initialTheta}, unfolds the optimum back
     * into this model's weight matrices, and returns the cost at that optimum.
     */
    private double trainInternal(Minimizer minimizer, int maxIterations, boolean verbose, CostFunction costFunction, DoubleVector initialTheta) {
        Preconditions.checkNotNull(minimizer, "Minimizer must be supplied!");
        DoubleVector theta = minimizer.minimize(costFunction, initialTheta, maxIterations, verbose);
        int[][] unfoldParameters = MultilayerPerceptronCostFunction.computeUnfoldParameters(this.layers);
        DoubleMatrix[] unfoldMatrices = DenseMatrixFolder.unfoldMatrices(theta, unfoldParameters);
        for (int i = 0; i < unfoldMatrices.length; ++i) {
            this.weights[i].setWeights(unfoldMatrices[i]);
        }
        return costFunction.evaluateCost(theta).getCost();
    }

    /**
     * Folds all weight matrices, in layer order, into a single parameter vector using
     * {@link DenseMatrixFolder#foldMatrices}.
     */
    public DoubleVector getFoldedThetaVector() {
        DoubleMatrix[] weightMatrices = new DoubleMatrix[this.weights.length];
        for (int i = 0; i < weightMatrices.length; ++i) {
            weightMatrices[i] = this.weights[i].getWeights();
        }
        return DenseMatrixFolder.foldMatrices(weightMatrices);
    }

    /** Returns the internal (mutable) array of weight matrices, one per pair of adjacent layers. */
    public WeightMatrix[] getWeights() {
        return this.weights;
    }

    /** Returns the internal (mutable) array of layer sizes, input layer first. */
    public int[] getLayers() {
        return this.layers;
    }

    /** Returns the internal (mutable) array of per-layer activation functions. */
    public ActivationFunction[] getActivations() {
        return this.activations;
    }

    double getHiddenDropoutProbability() {
        return this.hiddenDropoutProbability;
    }

    double getVisibleDropoutProbability() {
        return this.visibleDropoutProbability;
    }

    LossFunction getErrorFunction() {
        return this.error;
    }

    TrainingType getTrainingType() {
        return this.type;
    }

    double getLambda() {
        return this.lambda;
    }

    int getBatchParallelism() {
        return this.batchParallelism;
    }

    int getMiniBatchSize() {
        return this.miniBatchSize;
    }

    boolean isStochastic() {
        return this.stochastic;
    }

    /**
     * Reads a model previously written by {@link #serialize(MultilayerPerceptron, DataOutput)}.
     *
     * <p>Activation and loss functions are recreated from the class names stored in the stream.
     * Each name is checked against the expected type before the class is instantiated. Only
     * deserialize models from trusted sources nonetheless.
     *
     * @param in the input to read from.
     * @return the restored model. It has no minimizer configured, so it cannot be trained through
     *     {@link #train(DoubleVector[], DoubleVector[])}.
     * @throws IOException on read failure, if the stored layer count is not positive, or if a
     *     stored class name does not denote the expected function type.
     * @throws RuntimeException if a stored function class cannot be found or instantiated.
     */
    public static MultilayerPerceptron deserialize(DataInput in) throws IOException {
        int numLayers = in.readInt();
        if (numLayers <= 0) {
            throw new IOException("Invalid number of layers in serialized model: " + numLayers);
        }
        int[] layers = new int[numLayers];
        for (int i = 0; i < numLayers; ++i) {
            layers[i] = in.readInt();
        }
        WeightMatrix[] weights = new WeightMatrix[numLayers - 1];
        for (int i = 0; i < weights.length; ++i) {
            MatrixWritable wm = new MatrixWritable();
            wm.readFields(in);
            weights[i] = new WeightMatrix(wm.getMatrix());
        }
        ActivationFunction[] funcs = new ActivationFunction[numLayers];
        for (int i = 0; i < numLayers; ++i) {
            funcs[i] = instantiate(in.readUTF(), ActivationFunction.class);
        }
        LossFunction error = instantiate(in.readUTF(), LossFunction.class);
        return new MultilayerPerceptron(layers, weights, funcs, error);
    }

    /**
     * Instantiates {@code className} through its no-arg constructor. The class is first checked
     * to be a subtype of {@code expectedType}. It is loaded without initialisation, so static
     * initialisers of unrelated classes named in a stream never run.
     */
    private static <T> T instantiate(String className, Class<T> expectedType) throws IOException {
        Class<?> clazz;
        try {
            clazz = Class.forName(className, false, MultilayerPerceptron.class.getClassLoader());
        } catch (ClassNotFoundException e) {
            throw new RuntimeException(e);
        }
        if (!expectedType.isAssignableFrom(clazz)) {
            throw new IOException("Class " + className + " is not a " + expectedType.getName());
        }
        try {
            return expectedType.cast(clazz.getDeclaredConstructor().newInstance());
        } catch (ReflectiveOperationException e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Writes the layer sizes, weight matrices, activation function class names and loss function
     * class name of {@code model} to {@code out}.
     *
     * @param model the model to write; it must have a loss function.
     * @param out the output to write to.
     * @throws IOException on write failure.
     * @throws NullPointerException if the model has no loss function. This is checked before
     *     anything is written, so no partial output is produced.
     */
    public static void serialize(MultilayerPerceptron model, DataOutput out) throws IOException {
        Preconditions.checkNotNull(model.error, "Model has no loss function and cannot be serialized!");
        out.writeInt(model.layers.length);
        for (int size : model.layers) {
            out.writeInt(size);
        }
        for (WeightMatrix mat : model.weights) {
            new MatrixWritable(mat.getWeights()).write(out);
        }
        for (ActivationFunction func : model.activations) {
            out.writeUTF(func.getClass().getName());
        }
        out.writeUTF(model.error.getClass().getName());
    }

    /**
     * Fluent configuration for {@link MultilayerPerceptron}. Obtain an instance through
     * {@link #create(int[], ActivationFunction[], LossFunction, Minimizer, int)}.
     */
    public static final class MultilayerPerceptronBuilder {
        private final Minimizer minimizer;
        private final int maxIterations;
        private final int[] layer;
        private final ActivationFunction[] activationFunctions;
        private final LossFunction error;
        private TrainingType type = TrainingType.CPU;
        private double lambda = 0.0;
        private boolean verbose = false;
        private double hiddenDropoutProbability = 0.0;
        private double visibleDropoutProbability = 0.0;
        private WeightMatrix[] weights;
        private boolean stochastic = false;
        private int miniBatchSize;
        private int batchParallelism = Runtime.getRuntime().availableProcessors();

        private MultilayerPerceptronBuilder(int[] layer, ActivationFunction[] activations, Minimizer minimizer, int maxIterations, LossFunction error) {
            this.layer = layer;
            this.minimizer = minimizer;
            this.error = error;
            this.maxIterations = maxIterations;
            this.activationFunctions = activations;
        }

        /** Sets the training backend (default {@link TrainingType#CPU}). */
        public MultilayerPerceptronBuilder trainingType(TrainingType type) {
            this.type = type;
            return this;
        }

        /** Sets the regularization parameter (default {@code 0.0}). */
        public MultilayerPerceptronBuilder lambda(double lambda) {
            this.lambda = lambda;
            return this;
        }

        /** Enables verbose output during training. */
        public MultilayerPerceptronBuilder verbose() {
            return this.verbose(true);
        }

        /** Sets whether training output is verbose (default {@code false}). */
        public MultilayerPerceptronBuilder verbose(boolean verbose) {
            this.verbose = verbose;
            return this;
        }

        /** Sets the dropout probability for hidden layers (default {@code 0.0}). */
        public MultilayerPerceptronBuilder hiddenLayerDropout(double d) {
            this.hiddenDropoutProbability = d;
            return this;
        }

        /** Sets the dropout probability for the input layer (default {@code 0.0}). */
        public MultilayerPerceptronBuilder inputLayerDropout(double d) {
            this.visibleDropoutProbability = d;
            return this;
        }

        /**
         * Supplies initial weights instead of freshly created ones. Counts and shapes are validated
         * in {@link #build()}.
         */
        public MultilayerPerceptronBuilder withWeights(WeightMatrix[] weights) {
            this.weights = weights;
            return this;
        }

        /**
         * Sets the mini-batch size handed to the cost function (default {@code 0}).
         *
         * @throws IllegalArgumentException if {@code size} is negative.
         */
        public MultilayerPerceptronBuilder miniBatchSize(int size) {
            Preconditions.checkArgument(size >= 0, "Mini batch size must not be negative! Given: %s", size);
            this.miniBatchSize = size;
            return this;
        }

        /**
         * Sets the batch parallelism handed to the cost function (default: number of available
         * processors).
         *
         * @throws IllegalArgumentException if {@code numThreads} is not positive.
         */
        public MultilayerPerceptronBuilder batchParallelism(int numThreads) {
            Preconditions.checkArgument(numThreads > 0, "Batch parallelism must be positive! Given: %s", numThreads);
            this.batchParallelism = numThreads;
            return this;
        }

        /** Enables stochastic training. */
        public MultilayerPerceptronBuilder stochastic() {
            return this.stochastic(true);
        }

        /** Sets whether training is stochastic (default {@code false}). */
        public MultilayerPerceptronBuilder stochastic(boolean stochastic) {
            this.stochastic = stochastic;
            return this;
        }

        /**
         * Builds the network.
         *
         * @throws NullPointerException if no layer sizes were supplied.
         * @throws IllegalArgumentException if the configuration is dimensionally inconsistent.
         */
        public MultilayerPerceptron build() {
            return new MultilayerPerceptron(this);
        }

        /**
         * Creates a builder.
         *
         * @param layer the number of units per layer, input layer first.
         * @param activations one activation function per layer, or null for the defaults (linear
         *     input layer, sigmoid elsewhere).
         * @param errorFunction the loss function used for training.
         * @param minimizer the optimizer used by the no-argument training entry point.
         * @param maxIteration the iteration budget for {@code minimizer}.
         * @return a new builder.
         */
        public static MultilayerPerceptronBuilder create(int[] layer, ActivationFunction[] activations, LossFunction errorFunction, Minimizer minimizer, int maxIteration) {
            return new MultilayerPerceptronBuilder(layer, activations, minimizer, maxIteration, errorFunction);
        }
    }
}

