package org.apache.mahout.classifier;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Map;
import org.apache.mahout.common.MahoutTestCase;
import org.junit.Test;

/* loaded from: input_file:org/apache/mahout/classifier/ConfusionMatrixTest.class */
/**
 * Tests for {@link ConfusionMatrix}.
 *
 * <p>Each test builds a matrix over two explicit labels plus a default label, writes the
 * {@link #VALUES} counts into its top-left 2x2 corner, and then checks the raw counts, the
 * label set, the per-label accuracy and the matrix view returned by {@code getMatrix()}.
 */
public final class ConfusionMatrixTest extends MahoutTestCase {
    /** Counts written into the matrix; row/column {@code i} corresponds to {@code LABELS[i]}. */
    private static final int[][] VALUES = {{2, 3}, {10, 20}};
    /** Explicit labels handed to the {@link ConfusionMatrix} constructor. */
    private static final String[] LABELS = {"Label1", "Label2"};
    /** Default label handed to the constructor; it is expected to occupy the third row/column. */
    private static final String DEFAULT_LABEL = "other";
    /** Tolerance used when comparing accuracy percentages. */
    private static final double ACCURACY_DELTA = 1.0E-6d;

    /** A freshly filled matrix exposes the written counts, all three labels and the expected accuracies. */
    @Test
    public void testBuild() {
        ConfusionMatrix confusionMatrix = fillCM(VALUES, LABELS, DEFAULT_LABEL);
        checkValues(confusionMatrix);
        checkAccuracy(confusionMatrix);
    }

    /**
     * The matrix returned by {@code getMatrix()} has one column per label and row bindings for
     * every label, and {@code getCorrect} reports the diagonal count of each label.
     */
    @Test
    public void testGetMatrix() {
        // Keep the confusion matrix in its own local: the label/correct-count assertions below
        // are made against it, while the size and binding assertions are made against the
        // matrix view it produces.
        ConfusionMatrix confusionMatrix = fillCM(VALUES, LABELS, DEFAULT_LABEL);

        // NOTE(review): getMatrix() is called twice so this file does not need to name the
        // matrix type; this assumes each call returns an equivalent view - confirm.
        assertEquals(confusionMatrix.getLabels().size(), confusionMatrix.getMatrix().numCols());

        Map<?, ?> rowLabelBindings = confusionMatrix.getMatrix().getRowLabelBindings();
        assertTrue(rowLabelBindings.keySet().contains(LABELS[0]));
        assertTrue(rowLabelBindings.keySet().contains(LABELS[1]));
        assertTrue(rowLabelBindings.keySet().contains(DEFAULT_LABEL));

        assertEquals(VALUES[0][0], confusionMatrix.getCorrect(LABELS[0]));
        assertEquals(VALUES[1][1], confusionMatrix.getCorrect(LABELS[1]));
        // Nothing was ever written to the default label's row.
        assertEquals(0, confusionMatrix.getCorrect(DEFAULT_LABEL));
    }

    /**
     * Asserts that the raw count array is 3x3, holds {@link #VALUES} in its top-left corner and
     * zeros everywhere else, and that the label collection contains exactly the three labels.
     *
     * @param confusionMatrix matrix previously populated by {@link #fillCM}
     */
    private static void checkValues(ConfusionMatrix confusionMatrix) {
        int[][] counts = confusionMatrix.getConfusionMatrix();
        // Result intentionally discarded: this only verifies that rendering does not throw.
        confusionMatrix.toString();

        // Square, with one extra row/column beyond the two explicit labels.
        assertEquals(counts.length, counts[0].length);
        assertEquals(3, counts.length);

        assertEquals(VALUES[0][0], counts[0][0]);
        assertEquals(VALUES[0][1], counts[0][1]);
        assertEquals(VALUES[1][0], counts[1][0]);
        assertEquals(VALUES[1][1], counts[1][1]);

        // The third row and third column were never written and must still be zero.
        assertTrue(Arrays.equals(new int[3], counts[2]));
        assertEquals(0, counts[0][2]);
        assertEquals(0, counts[1][2]);

        assertEquals(3, confusionMatrix.getLabels().size());
        assertTrue(confusionMatrix.getLabels().contains(LABELS[0]));
        assertTrue(confusionMatrix.getLabels().contains(LABELS[1]));
        assertTrue(confusionMatrix.getLabels().contains(DEFAULT_LABEL));
    }

    /**
     * Asserts the per-label accuracy percentages for the {@link #VALUES} fixture.
     *
     * <p>The expected figures equal the diagonal count as a percentage of its row total:
     * 2 / (2 + 3) = 40% and 20 / (10 + 20) = 66.67%. The default label's row is all zeros, and
     * its accuracy is expected to be {@code NaN}.
     *
     * @param confusionMatrix matrix previously populated by {@link #fillCM}
     */
    private static void checkAccuracy(ConfusionMatrix confusionMatrix) {
        assertEquals(3, confusionMatrix.getLabels().size());
        assertEquals(40.0d, confusionMatrix.getAccuracy(LABELS[0]), ACCURACY_DELTA);
        assertEquals(66.666666667d, confusionMatrix.getAccuracy(LABELS[1]), ACCURACY_DELTA);
        assertTrue(Double.isNaN(confusionMatrix.getAccuracy(DEFAULT_LABEL)));
    }

    /**
     * Builds a {@link ConfusionMatrix} over {@code labels} plus {@code defaultLabel} and copies
     * {@code values} into its top-left corner.
     *
     * @param values       counts to write; row {@code r} is copied into row {@code r} of the matrix
     * @param labels       explicit labels passed to the constructor
     * @param defaultLabel default label passed to the constructor
     * @return the populated confusion matrix
     */
    private static ConfusionMatrix fillCM(int[][] values, String[] labels, String defaultLabel) {
        ConfusionMatrix confusionMatrix =
            new ConfusionMatrix(new ArrayList<String>(Arrays.asList(labels)), defaultLabel);

        // Writes go straight into the array returned by getConfusionMatrix(); the assertions in
        // this class only hold if that array is the matrix's live backing store, not a copy.
        int[][] counts = confusionMatrix.getConfusionMatrix();
        for (int row = 0; row < values.length; row++) {
            System.arraycopy(values[row], 0, counts[row], 0, values[row].length);
        }
        return confusionMatrix;
    }
}
