package com.cezerilab.openjazarilibrary.ml.classifiers;

import com.cezerilab.openjazarilibrary.core.CMatrix;
import com.cezerilab.openjazarilibrary.factory.FactoryUtils;
import com.cezerilab.openjazarilibrary.types.TMatrixOperator;

/* loaded from: input_file:com/cezerilab/openjazarilibrary/ml/classifiers/C_NaiveBayes.class */
/**
 * Evaluation helpers for a two-class naive Bayes classifier built on {@link CMatrix}.
 *
 * <p>Data sets are read from comma-delimited files whose last column holds the class label and
 * whose remaining columns hold the features. Every evaluation prints its accuracy (in percent)
 * to standard output; nothing is returned.
 *
 * <p>Only two classes are modelled. A test row is assigned to the second class whenever the
 * first class does not win strictly, which includes the case where both scores are zero or
 * not-a-number.
 */
public class C_NaiveBayes {

    /** Kind of feature values contained in the data set. */
    public enum AttributeType {
        /** Discrete features, modelled with per-category relative frequencies. */
        Categorical,
        /** Continuous features, modelled with one normal distribution per class and feature. */
        Real
    }

    /** Field separator of the input files. */
    private static final String DELIMITER = ",";

    /** Selector string handed to CMatrix to address every row or every column. */
    private static final String ALL = ":";

    /** Normalisation constant of the normal density, sqrt(2 * pi). */
    private static final double SQRT_TWO_PI = Math.sqrt(2.0d * Math.PI);

    /** Class labels expected in the files used by the train/test evaluations. */
    private static final double POSITIVE_LABEL = 1.0d;
    private static final double NEGATIVE_LABEL = -1.0d;

    /** Category codes expected in the files used by the categorical train/test evaluation. */
    private static final double[] TRAIN_TEST_CATEGORIES = {2.0d, 3.0d, 4.0d};

    /**
     * Runs the categorical train/test evaluation on the bundled tic-tac-toe files.
     *
     * @param strArr command-line arguments (ignored)
     */
    public static void main(String[] strArr) {
        evaluateModelTrainTest(AttributeType.Categorical, "src\\cezeri\\classifiers\\tic-tac-toe_train.txt", "src\\cezeri\\classifiers\\tic-tac-toe_test.txt");
    }

    /**
     * Trains on one file, tests on another and prints {@code accuracy:<percent>}.
     *
     * @param attributeType kind of features in both files; must not be {@code null}
     * @param str path of the comma-delimited training file
     * @param str2 path of the comma-delimited test file
     * @throws NullPointerException if {@code attributeType} is {@code null}
     */
    public static void evaluateModelTrainTest(AttributeType attributeType, String str, String str2) {
        switch (attributeType) {
            case Categorical:
                evaluateModelCategoricalTrainTest(str, str2);
                break;
            case Real:
                evaluateModelRealTrainTest(str, str2);
                break;
            default:
                break;
        }
    }

    /**
     * Runs a cross-validation over one file and prints {@code average accuracy:<percent>}.
     *
     * @param attributeType kind of features in the file; must not be {@code null}
     * @param str path of the comma-delimited data file
     * @param i number of folds requested from {@code CMatrix.crossValidationSets}
     * @throws NullPointerException if {@code attributeType} is {@code null}
     */
    public static void evaluateModelCrossValidation(AttributeType attributeType, String str, int i) {
        switch (attributeType) {
            case Categorical:
                evaluateModelCategoricalCV(str, i);
                break;
            case Real:
                evaluateModelRealCV(str, i);
                break;
            default:
                break;
        }
    }

    /**
     * Categorical train/test evaluation with labels {@code 1}/{@code -1} and category codes
     * {@code 2}, {@code 3}, {@code 4}.
     */
    private static void evaluateModelCategoricalTrainTest(String str, String str2) {
        CMatrix train = CMatrix.getInstanceFromFile(str, DELIMITER).shuffleRows();
        CMatrix test = CMatrix.getInstanceFromFile(str2, DELIMITER).shuffleRows();
        double accuracy = categoricalAccuracy(train, test, TRAIN_TEST_CATEGORIES, POSITIVE_LABEL, NEGATIVE_LABEL);
        System.out.println("accuracy:" + accuracy);
    }

    /** Gaussian train/test evaluation with labels {@code 1}/{@code -1}. */
    private static void evaluateModelRealTrainTest(String str, String str2) {
        CMatrix train = CMatrix.getInstanceFromFile(str, DELIMITER).shuffleRows();
        CMatrix test = CMatrix.getInstanceFromFile(str2, DELIMITER).shuffleRows();
        double accuracy = gaussianAccuracy(train, test, POSITIVE_LABEL, NEGATIVE_LABEL);
        System.out.println("accuracy:" + accuracy);
    }

    /**
     * Categorical cross-validation. The two class labels are the first two distinct values of the
     * last column; the category codes are the distinct values of column 0 and are applied to
     * every feature column.
     */
    private static void evaluateModelCategoricalCV(String str, int i) {
        CMatrix data = CMatrix.getInstanceFromFile(str, DELIMITER).shuffleRows();
        // Each fold is read as [fold][0] = training part, [fold][1] = test part.
        CMatrix[][] folds = data.crossValidationSets(i);
        double[] classValues = FactoryUtils.getDistinctValues(data.getColumn(data.getColumnNumber() - 1));
        double[] categories = FactoryUtils.getDistinctValues(data.getColumn(0));
        double accuracySum = 0.0d;
        for (int fold = 0; fold < i; fold++) {
            accuracySum += categoricalAccuracy(folds[fold][0], folds[fold][1], categories, classValues[0], classValues[1]);
        }
        System.out.println("average accuracy:" + (accuracySum / i));
    }

    /**
     * Gaussian cross-validation. The two class labels are the first two distinct values of the
     * last column.
     */
    private static void evaluateModelRealCV(String str, int i) {
        CMatrix data = CMatrix.getInstanceFromFile(str, DELIMITER).shuffleRows();
        // Each fold is read as [fold][0] = training part, [fold][1] = test part.
        CMatrix[][] folds = data.crossValidationSets(i);
        double[] classValues = FactoryUtils.getDistinctValues(data.getColumn(data.getColumnNumber() - 1));
        double accuracySum = 0.0d;
        for (int fold = 0; fold < i; fold++) {
            accuracySum += gaussianAccuracy(folds[fold][0], folds[fold][1], classValues[0], classValues[1]);
        }
        System.out.println("average accuracy:" + (accuracySum / i));
    }

    /**
     * Fits per-category relative frequencies on {@code train} and scores {@code test}.
     *
     * <p>The label column index and the feature range are both taken from {@code train}; the
     * test matrix is read at the same column positions.
     *
     * @param train training rows, label in the last column
     * @param test rows to classify
     * @param categories category codes that may occur in any feature column
     * @param firstClass label predicted when the first class wins strictly
     * @param secondClass label predicted otherwise
     * @return percentage of test rows whose label equals the prediction
     * @throws IllegalArgumentException if a test feature value is not listed in {@code categories}
     */
    private static double categoricalAccuracy(CMatrix train, CMatrix test, double[] categories,
            double firstClass, double secondClass) {
        int labelColumn = train.getColumnNumber() - 1;
        // NOTE(review): find(...) is assumed to yield the indices of the matching rows — confirm.
        CMatrix firstMatches = train.find(TMatrixOperator.EQUALS, firstClass, ALL, "" + labelColumn);
        CMatrix secondMatches = train.find(TMatrixOperator.EQUALS, secondClass, ALL, "" + labelColumn);
        CMatrix firstRows = rowsOfClass(train, firstMatches);
        CMatrix secondRows = rowsOfClass(train, secondMatches);

        int trainCount = train.getRowNumber();
        int firstCount = firstRows.getRowNumber();
        int secondCount = secondRows.getRowNumber();
        double firstPrior = (1.0d * firstCount) / trainCount;
        double secondPrior = (1.0d * secondCount) / trainCount;

        // likelihood[feature][category][0 = first class, 1 = second class]
        double[][][] likelihood = new double[labelColumn][categories.length][2];
        for (int feature = 0; feature < labelColumn; feature++) {
            CMatrix firstColumn = firstRows.commandParser(ALL, "" + feature);
            CMatrix secondColumn = secondRows.commandParser(ALL, "" + feature);
            for (int category = 0; category < categories.length; category++) {
                likelihood[feature][category][0] =
                        (firstColumn.find(TMatrixOperator.EQUALS, categories[category]).getRowNumber() * 1.0d) / firstCount;
                likelihood[feature][category][1] =
                        (secondColumn.find(TMatrixOperator.EQUALS, categories[category]).getRowNumber() * 1.0d) / secondCount;
            }
        }

        int testCount = test.getRowNumber();
        String featureRange = "0:" + (labelColumn - 1);
        double correct = 0.0d;
        for (int row = 0; row < testCount; row++) {
            // Flattening makes the loop independent of whether the selection is a row or a column.
            double[] features = test.commandParser("" + row, featureRange).toDoubleArray1D();
            double firstProduct = 1.0d;
            double secondProduct = 1.0d;
            for (int feature = 0; feature < features.length; feature++) {
                int category = indexOfCategory(categories, features[feature]);
                firstProduct *= likelihood[feature][category][0];
                secondProduct *= likelihood[feature][category][1];
            }
            double firstScore = firstPrior * firstProduct;
            double secondScore = secondPrior * secondProduct;
            // Normalised to percent; NaN (both scores zero) falls through to the second class.
            double firstPercent = (firstScore / (firstScore + secondScore)) * 100.0d;
            double secondPercent = 100.0d - firstPercent;
            double predicted = firstPercent > secondPercent ? firstClass : secondClass;
            if (test.getValue(row, labelColumn) == predicted) {
                correct += 1.0d;
            }
        }
        return (correct / testCount) * 100.0d;
    }

    /**
     * Fits one normal distribution per class and feature on {@code train} and scores {@code test}.
     *
     * <p>For each test row the last value is taken as the label and all preceding values as
     * features.
     *
     * @param train training rows, label in the last column
     * @param test rows to classify, label in the last column
     * @param firstClass label predicted when the first class wins strictly
     * @param secondClass label predicted otherwise
     * @return percentage of test rows whose label equals the prediction
     */
    private static double gaussianAccuracy(CMatrix train, CMatrix test, double firstClass, double secondClass) {
        int labelColumn = train.getColumnNumber() - 1;
        // NOTE(review): find(...) is assumed to yield the indices of the matching rows — confirm.
        CMatrix firstMatches = train.find(TMatrixOperator.EQUALS, firstClass, ALL, "" + labelColumn);
        CMatrix secondMatches = train.find(TMatrixOperator.EQUALS, secondClass, ALL, "" + labelColumn);
        int[][] firstSelector = new int[][]{firstMatches.toIntArray1D()};
        int[][] secondSelector = new int[][]{secondMatches.toIntArray1D()};

        int trainCount = train.getRowNumber();
        int firstCount = train.matrix(firstSelector).getRowNumber();
        int secondCount = train.matrix(secondSelector).getRowNumber();
        double firstPrior = (1.0d * firstCount) / trainCount;
        double secondPrior = (1.0d * secondCount) / trainCount;

        double[] firstMean = new double[labelColumn];
        double[] secondMean = new double[labelColumn];
        double[] firstStd = new double[labelColumn];
        double[] secondStd = new double[labelColumn];
        for (int feature = 0; feature < labelColumn; feature++) {
            CMatrix column = train.commandParser(ALL, "" + feature);
            CMatrix firstValues = column.cmd(firstSelector);
            CMatrix secondValues = column.cmd(secondSelector);
            firstMean[feature] = firstValues.meanTotal();
            secondMean[feature] = secondValues.meanTotal();
            firstStd[feature] = firstValues.stdTotal();
            secondStd[feature] = secondValues.stdTotal();
        }

        int testCount = test.getRowNumber();
        double correct = 0.0d;
        for (int row = 0; row < testCount; row++) {
            double[] values = test.commandParser("" + row, ALL).toDoubleArray1D();
            int labelIndex = values.length - 1;
            double firstScore = firstPrior;
            double secondScore = secondPrior;
            for (int feature = 0; feature < labelIndex; feature++) {
                firstScore *= gaussianDensity(values[feature], firstMean[feature], firstStd[feature]);
                secondScore *= gaussianDensity(values[feature], secondMean[feature], secondStd[feature]);
            }
            // Compared as normalised posteriors; a NaN or zero total selects the second class.
            double total = firstScore + secondScore;
            double predicted = firstScore / total > secondScore / total ? firstClass : secondClass;
            if (values[labelIndex] == predicted) {
                correct += 1.0d;
            }
        }
        return (correct / testCount) * 100.0d;
    }

    /**
     * Picks the rows of {@code data} addressed by {@code matches}.
     *
     * <p>The selection is only performed when {@code matches} has exactly one column; otherwise a
     * fresh {@code CMatrix.getInstance()} is returned, so no state leaks between invocations.
     * Presumably a result that is not a single column means "no row matched" — verify against
     * {@code CMatrix.find}.
     */
    private static CMatrix rowsOfClass(CMatrix data, CMatrix matches) {
        if (matches.getSize().column == 1) {
            return data.matrix(new int[][]{matches.toIntArray1D()});
        }
        return CMatrix.getInstance();
    }

    /**
     * Returns the position of {@code value} within {@code categories}.
     *
     * @throws IllegalArgumentException if {@code value} is not one of the categories
     */
    private static int indexOfCategory(double[] categories, double value) {
        for (int index = 0; index < categories.length; index++) {
            if (categories[index] == value) {
                return index;
            }
        }
        throw new IllegalArgumentException("Unknown category value: " + value);
    }

    /**
     * Evaluates the normal probability density with the given mean and standard deviation at
     * {@code x}. A zero standard deviation yields NaN.
     */
    private static double gaussianDensity(double x, double mean, double std) {
        double deviation = x - mean;
        return (1.0d / (SQRT_TWO_PI * std)) * Math.exp((-1.0d) * ((deviation * deviation) / ((2.0d * std) * std)));
    }
}
