package com.cezerilab.openjazarilibrary.ml.classifiers.deeplearning_blas;

import com.cezerilab.openjazarilibrary.ml.classifiers.deeplearning_blas.NNParams;
import java.util.ArrayList;

/* loaded from: input_file:com/cezerilab/openjazarilibrary/ml/classifiers/deeplearning_blas/NeuralNetworkTest.class */
/**
 * Console driver for {@link DeepNeuralNetwork}: trains a small network on the Kaggle
 * digits sample and prints the training time and the accuracy on the training and
 * cross-validation splits.
 */
public class NeuralNetworkTest {

    /** CSV file read by {@link #runKaggleDigitsClassification()}. */
    private static final String DEFAULT_DATA_PATH = "data/Kaggle_Digits_1000.csv";

    /**
     * Runs the digit-classification experiment on the default data file.
     *
     * @throws Exception if reading the data or training the network fails
     */
    public static void runKaggleDigitsClassification() throws Exception {
        runKaggleDigitsClassification(DEFAULT_DATA_PATH);
    }

    /**
     * Reads a ';'-separated CSV, splits it into a training and a cross-validation set,
     * trains a network with two hidden layers on the training set, and prints the
     * training time plus the accuracy (in percent) on each set.
     *
     * @param csvPath path of the CSV file; column 0 is used as the label and the
     *                remaining columns as features
     * @throws Exception if reading the data or training the network fails
     */
    public static void runKaggleDigitsClassification(String csvPath) throws Exception {
        // NOTE(review): the trailing 1 passed to readCSV and the (33, 0) split arguments are
        // presumably "skip one header row" and "33% cross-validation / 0% test" -- confirm.
        Matrix[] split = MatrixUtils.split(MatrixUtils.readCSV(csvPath, ';', 1), 33.0f, 0.0f);
        Matrix trainSet = split[0];
        Matrix cvSet = split[1];

        // Column 0 is the label; columns 1..-1 are the features (-1 presumably = last column).
        Matrix trainX = trainSet.getColumns(1, -1);
        Matrix trainY = trainSet.getColumns(0, 0);
        Matrix cvX = cvSet.getColumns(1, -1);
        Matrix cvY = cvSet.getColumns(0, 0);

        NNParams params = new NNParams();
        params.numClasses = 10;
        params.hiddenLayerParams = new NNParams.NNLayerParams[]{
                new NNParams.NNLayerParams(50, 5, 5, 2, 2),
                new NNParams.NNLayerParams(200, 5, 5, 2, 2)};
        params.learningRate = 0.01d;
        params.maxIterations = 10;
        // NOTE(review): the meaning of 0 threads is defined by DeepNeuralNetwork -- confirm.
        params.numThreads = 0;

        // nanoTime is monotonic, so the measurement is unaffected by wall-clock adjustments.
        long startNanos = System.nanoTime();
        DeepNeuralNetwork network = new DeepNeuralNetwork(params);
        network.train(trainX, trainY);
        long elapsedMillis = (System.nanoTime() - startNanos) / 1000000L;
        System.out.println("Training time: " + (elapsedMillis / 1000.0d) + "s");

        System.out.println("Training set accuracy: "
                + accuracyPercent(network, params, trainX, trainY) + "%");
        System.out.println("Crossvalidation set accuracy: "
                + accuracyPercent(network, params, cvX, cvY) + "%");
    }

    /**
     * Computes the percentage of rows of {@code x} whose predicted class equals the value
     * in column 0 of the corresponding row of {@code y}. Prediction runs batch by batch,
     * using {@code params.batchSize} to cut the data.
     *
     * @param network trained network used for prediction
     * @param params  parameters supplying the batch size
     * @param x       feature rows
     * @param y       label rows, aligned with {@code x}
     * @return accuracy in percent
     */
    private static double accuracyPercent(DeepNeuralNetwork network, NNParams params,
                                          Matrix x, Matrix y) {
        ArrayList<Matrix> xBatches = new ArrayList<>();
        ArrayList<Matrix> yBatches = new ArrayList<>();
        MatrixUtils.split(x, y, params.batchSize, xBatches, yBatches);

        int correct = 0;
        for (int b = 0; b < xBatches.size(); b++) {
            int[] predicted = network.getPredictedClasses(xBatches.get(b));
            Matrix labels = yBatches.get(b);
            for (int r = 0; r < predicted.length; r++) {
                if (predicted[r] == labels.get(r, 0)) {
                    correct++;
                }
            }
        }
        // Cast before dividing: integer division would truncate every ratio to 0 or 1.
        return (double) correct / x.numRows() * 100.0d;
    }

    public static void main(String[] strArr) throws Exception {
        runKaggleDigitsClassification();
    }
}
