package com.cezerilab.openjazarilibrary.ml.classifiers.deeplearning_blas;

import com.cezerilab.openjazarilibrary.ml.classifiers.deeplearning_blas.NNParams;

/* loaded from: input_file:com/cezerilab/openjazarilibrary/ml/classifiers/deeplearning_blas/NNClassificationExample.class */
/**
 * Example driver that trains {@link DeepNeuralNetwork} classifiers on bundled Kaggle CSV samples
 * and prints the training time plus training-set and cross-validation-set accuracy.
 *
 * <p>For every example, column 0 of the loaded data is used as the class label and columns
 * {@code 1..-1} (as passed to {@code Matrix.getColumns}) as the feature inputs.
 */
public class NNClassificationExample {

    /**
     * Trains and evaluates a network on {@code example_data/Kaggle_Digits_1000.csv}.
     *
     * @param z {@code true} to use the two-layer convolutional configuration (10 iterations,
     *     learning rate 0.01); {@code false} to use a single hidden layer built with
     *     {@code new NNParams.NNLayerParams(100)} (200 iterations, learning rate 0.0)
     * @throws Exception propagated from data loading, training or prediction
     */
    public static void runKaggleDigitsClassification(boolean z) throws Exception {
        if (z) {
            System.out.println("Running classification on Kaggle Digits dataset, with convolution...\n");
        } else {
            System.out.println("Running classification on Kaggle Digits dataset...\n");
        }
        Matrix[] split = loadAndSplit("example_data/Kaggle_Digits_1000.csv");

        NNParams params = new NNParams();
        params.numClasses = 10;
        if (z) {
            params.hiddenLayerParams = new NNParams.NNLayerParams[] {
                new NNParams.NNLayerParams(20, 5, 5, 2, 2),
                new NNParams.NNLayerParams(100, 5, 5, 2, 2)
            };
            params.maxIterations = 10;
            params.learningRate = 0.01d;
        } else {
            params.hiddenLayerParams = new NNParams.NNLayerParams[] {new NNParams.NNLayerParams(100)};
            params.maxIterations = 200;
            // NOTE(review): a learning rate of 0.0 presumably selects a library default or a
            // different optimizer rather than "no learning" — confirm in DeepNeuralNetwork.
            params.learningRate = 0.0d;
        }

        trainAndEvaluate(params, split[0], split[1]);
    }

    /**
     * Trains and evaluates a binary classifier on {@code example_data/Kaggle_Titanic_Cleaned.csv}.
     *
     * @throws Exception propagated from data loading, training or prediction
     */
    public static void runKaggleTitanicClassification() throws Exception {
        System.out.println("Running classification on Kaggle Titanic dataset...\n");
        Matrix[] split = loadAndSplit("example_data/Kaggle_Titanic_Cleaned.csv");

        NNParams params = new NNParams();
        // NOTE(review): presumably one entry per feature column giving its number of categories
        // (1 for non-categorical columns?) — confirm against NNParams.numCategories.
        params.numCategories = new int[] {3, 2, 1, 1, 1, 1, 3};
        params.numClasses = 2;

        trainAndEvaluate(params, split[0], split[1]);
    }

    /** Runs the Digits (dense), Digits (convolutional) and Titanic examples in sequence. */
    public static void main(String[] strArr) throws Exception {
        runKaggleDigitsClassification(false);
        System.out.println("\n\n\n");
        runKaggleDigitsClassification(true);
        System.out.println("\n\n\n");
        runKaggleTitanicClassification();
    }

    /**
     * Reads a comma-separated file and splits it with {@code MatrixUtils.split(data, 33.0f, 0.0f)}.
     * Callers use element 0 as the training set and element 1 as the cross-validation set.
     */
    // NOTE(review): the readCSV argument 1 presumably skips one header row, and the split
    // arguments presumably are hold-out percentages — confirm against MatrixUtils.
    private static Matrix[] loadAndSplit(String path) throws Exception {
        Matrix data = MatrixUtils.readCSV(path, ',', 1);
        return MatrixUtils.split(data, 33.0f, 0.0f);
    }

    /**
     * Trains a network on {@code training}, then prints the training time and the accuracy on
     * both {@code training} and {@code crossValidation}. In both matrices, column 0 is the label
     * and the remaining columns are the features.
     */
    private static void trainAndEvaluate(NNParams params, Matrix training, Matrix crossValidation)
            throws Exception {
        Matrix trainX = training.getColumns(1, -1);
        Matrix trainY = training.getColumns(0, 0);
        Matrix cvX = crossValidation.getColumns(1, -1);
        Matrix cvY = crossValidation.getColumns(0, 0);

        // Timing includes network construction as well as training.
        long start = System.currentTimeMillis();
        DeepNeuralNetwork network = new DeepNeuralNetwork(params);
        network.train(trainX, trainY);
        double seconds = (System.currentTimeMillis() - start) / 1000.0d;
        System.out.println("\nTraining time: " + String.format("%.3g", seconds) + "s");

        double trainAccuracy = accuracyPercent(network.getPredictedClasses(trainX), trainY);
        System.out.println("Training set accuracy: " + String.format("%.3g", trainAccuracy) + "%");

        double cvAccuracy = accuracyPercent(network.getPredictedClasses(cvX), cvY);
        System.out.println("Crossvalidation set accuracy: " + String.format("%.3g", cvAccuracy) + "%");
    }

    /**
     * Returns the percentage (0-100) of rows {@code i} where {@code predicted[i]} equals
     * {@code labels.get(i, 0)}. An empty {@code predicted} array yields {@code NaN}.
     */
    private static double accuracyPercent(int[] predicted, Matrix labels) {
        int correct = 0;
        for (int row = 0; row < predicted.length; row++) {
            if (predicted[row] == labels.get(row, 0)) {
                correct++;
            }
        }
        // Multiply by a double first: the former (correct / total) * 100.0 used int division
        // and could only ever report 0% or 100%.
        return 100.0d * correct / predicted.length;
    }
}
