package com.cezerilab.openjazarilibrary.ml.dl4j;

import java.util.HashMap;
import java.util.Map;
import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator;
import org.deeplearning4j.nn.api.OptimizationAlgorithm;
import org.deeplearning4j.nn.conf.LearningRatePolicy;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.Updater;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.ConvolutionLayer;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.conf.layers.SubsamplingLayer;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.optimize.api.IterationListener;
import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.lossfunctions.LossFunctions;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:com/cezerilab/openjazarilibrary/ml/dl4j/LenetMnistExample.class */
public class LenetMnistExample {
    private static final Logger log = LoggerFactory.getLogger(LenetMnistExample.class);

    public static void main(String[] strArr) throws Exception {
        log.info("Load data....");
        MnistDataSetIterator mnistDataSetIterator = new MnistDataSetIterator(64, true, 12345);
        MnistDataSetIterator mnistDataSetIterator2 = new MnistDataSetIterator(64, false, 12345);
        log.info("Build model....");
        HashMap hashMap = new HashMap();
        hashMap.put(0, Double.valueOf(0.01d));
        hashMap.put(1000, Double.valueOf(0.005d));
        hashMap.put(3000, Double.valueOf(0.001d));
        MultiLayerNetwork multiLayerNetwork = new MultiLayerNetwork(new NeuralNetConfiguration.Builder().seed(123).iterations(1).regularization(true).l2(5.0E-4d).learningRate(0.01d).learningRateDecayPolicy(LearningRatePolicy.Schedule).learningRateSchedule(hashMap).weightInit(WeightInit.XAVIER).optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).updater(Updater.NESTEROVS).list().layer(0, new ConvolutionLayer.Builder(new int[]{5, 5}).nIn(1).stride(new int[]{1, 1}).nOut(20).activation(Activation.IDENTITY).build()).layer(1, new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX).kernelSize(new int[]{2, 2}).stride(new int[]{2, 2}).build()).layer(2, new ConvolutionLayer.Builder(new int[]{5, 5}).stride(new int[]{1, 1}).nOut(50).activation(Activation.IDENTITY).build()).layer(3, new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX).kernelSize(new int[]{2, 2}).stride(new int[]{2, 2}).build()).layer(4, new DenseLayer.Builder().activation(Activation.RELU).nOut(500).build()).layer(5, new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD).nOut(10).activation(Activation.SOFTMAX).build()).setInputType(InputType.convolutionalFlat(28, 28, 1)).backprop(true).pretrain(false).build());
        multiLayerNetwork.init();
        log.info("Train model....");
        multiLayerNetwork.setListeners(new IterationListener[]{new ScoreIterationListener(1)});
        for (int i = 0; i < 1; i++) {
            multiLayerNetwork.fit(mnistDataSetIterator);
            log.info("*** Completed epoch {} ***", Integer.valueOf(i));
            log.info("Evaluate model....");
            log.info(multiLayerNetwork.evaluate(mnistDataSetIterator2).stats());
            mnistDataSetIterator2.reset();
        }
        log.info("****************Example finished********************");
    }
}
