package hex;

import java.io.File;
import java.io.IOException;
import java.util.Arrays;
import org.junit.Assert;
import org.junit.BeforeClass;
import org.junit.Test;
import water.Key;
import water.Scope;
import water.TestUtil;
import water.fvec.Frame;
import water.util.ArrayUtils;
import water.util.FileUtils;
import water.util.FrameUtils;
import water.util.VecUtils;

/**
 * Tests for {@link ConfusionMatrix#buildCM}.
 *
 * <p>Each case builds a confusion matrix from an "actual" and a "predicted" single-column frame.
 * It then checks the merged domain and the per-cell counts. Expected matrices are written as
 * {@code cm[actualLevel][predictedLevel]} over the expected domain. For example, in
 * {@link #testBadModelPrect()} two actual "A"s predicted as "B" land in row 0, column 1.
 */
public class ConfusionMatrixTest extends TestUtil {
    /** When true, each case dumps its domains and the computed matrix to stderr. */
    final boolean debug = false;

    @BeforeClass
    public static void stall() {
        // NOTE(review): presumably blocks until a 5-node H2O cloud is up; see TestUtil.
        stall_till_cloudsize(5);
    }

    @Test
    public void testIdenticalVectors() {
        Scope.enter();
        try {
            simpleCMTest("smalldata/junit/cm/v1.csv", "smalldata/junit/cm/v1.csv",
                    ar("A", "B", "C"), ar("A", "B", "C"), ar("A", "B", "C"),
                    ard(ard(2.0, 0.0, 0.0),
                        ard(0.0, 2.0, 0.0),
                        ard(0.0, 0.0, 1.0)),
                    debug);
        } finally {
            Scope.exit();
        }
    }

    @Test
    public void testVectorAlignment() {
        simpleCMTest("smalldata/junit/cm/v1.csv", "smalldata/junit/cm/v2.csv",
                ar("A", "B", "C"), ar("A", "B", "C"), ar("A", "B", "C"),
                ard(ard(1.0, 1.0, 0.0),
                    ard(0.0, 1.0, 1.0),
                    ard(0.0, 0.0, 1.0)),
                debug);
    }

    /** Inputs of different lengths must be rejected rather than silently aligned. */
    @Test(expected = IllegalArgumentException.class)
    public void testDifferentLengthVectors() {
        simpleCMTest("smalldata/junit/cm/v1.csv", "smalldata/junit/cm/v3.csv",
                ar("A", "B", "C"), ar("A", "B", "C"), ar("A", "B", "C"),
                ard(ard(1.0, 1.0, 0.0),
                    ard(0.0, 1.0, 1.0),
                    ard(0.0, 0.0, 1.0)),
                debug);
    }

    /** The predicted column has a strict subset of the actual domain; the CM uses the union. */
    @Test
    public void testDifferentDomains() {
        simpleCMTest("smalldata/junit/cm/v1.csv", "smalldata/junit/cm/v4.csv",
                ar("A", "B", "C"), ar("B", "C"), ar("A", "B", "C"),
                ard(ard(0.0, 2.0, 0.0),
                    ard(0.0, 0.0, 2.0),
                    ard(0.0, 0.0, 1.0)),
                debug);
        simpleCMTest("smalldata/junit/cm/v2.csv", "smalldata/junit/cm/v4.csv",
                ar("A", "B", "C"), ar("B", "C"), ar("A", "B", "C"),
                ard(ard(0.0, 1.0, 0.0),
                    ard(0.0, 1.0, 1.0),
                    ard(0.0, 0.0, 2.0)),
                debug);
    }

    @Test
    public void testSimpleNumericVectors() {
        simpleCMTest("smalldata/junit/cm/v1n.csv", "smalldata/junit/cm/v1n.csv",
                ar("0", "1", "2"), ar("0", "1", "2"), ar("0", "1", "2"),
                ard(ard(2.0, 0.0, 0.0),
                    ard(0.0, 2.0, 0.0),
                    ard(0.0, 0.0, 1.0)),
                debug);
        simpleCMTest("smalldata/junit/cm/v1n.csv", "smalldata/junit/cm/v2n.csv",
                ar("0", "1", "2"), ar("0", "1", "2"), ar("0", "1", "2"),
                ard(ard(1.0, 1.0, 0.0),
                    ard(0.0, 1.0, 1.0),
                    ard(0.0, 0.0, 1.0)),
                debug);
    }

    @Test
    public void testDifferentDomainsNumericVectors() {
        simpleCMTest("smalldata/junit/cm/v1n.csv", "smalldata/junit/cm/v4n.csv",
                ar("0", "1", "2"), ar("1", "2"), ar("0", "1", "2"),
                ard(ard(0.0, 2.0, 0.0),
                    ard(0.0, 0.0, 2.0),
                    ard(0.0, 0.0, 1.0)),
                debug);
        simpleCMTest("smalldata/junit/cm/v2n.csv", "smalldata/junit/cm/v4n.csv",
                ar("0", "1", "2"), ar("1", "2"), ar("0", "1", "2"),
                ard(ard(0.0, 1.0, 0.0),
                    ard(0.0, 1.0, 1.0),
                    ard(0.0, 0.0, 2.0)),
                debug);
    }

    /** Actual = A,A,B,B,C; predicted = B,B,C,C,C (every A/B is mispredicted). */
    @Test
    public void testBadModelPrect() {
        simpleCMTest(
                ArrayUtils.frame("v1", vec(ar("A", "B", "C"), ari(0, 0, 1, 1, 2))),
                ArrayUtils.frame("v1", vec(ar("A", "B", "C"), ari(1, 1, 2, 2, 2))),
                ar("A", "B", "C"), ar("A", "B", "C"), ar("A", "B", "C"),
                ard(ard(0.0, 2.0, 0.0),
                    ard(0.0, 0.0, 2.0),
                    ard(0.0, 0.0, 1.0)),
                debug);
    }

    /**
     * Actual = -1,-1,0,0,1 over domain {-1,0,1}; predicted = 0,0,1,1,1 over domain {0,1}.
     * The predicted levels must be re-mapped into the actual domain by label, not by index.
     */
    @Test
    public void testBadModelPrect2() {
        simpleCMTest(
                ArrayUtils.frame("v1", vec(ar("-1", "0", "1"), ari(0, 0, 1, 1, 2))),
                ArrayUtils.frame("v1", vec(ar("0", "1"), ari(0, 0, 1, 1, 1))),
                ar("-1", "0", "1"), ar("0", "1"), ar("-1", "0", "1"),
                ard(ard(0.0, 2.0, 0.0),
                    ard(0.0, 0.0, 2.0),
                    ard(0.0, 0.0, 1.0)),
                debug);
    }

    /**
     * Parses two single-column test files and runs {@link #simpleCMTest(Frame, Frame, String[],
     * String[], String[], double[][], boolean)} on them.
     *
     * <p>If the frames are not compatible, the predicted frame is replaced by a copy built from
     * {@code actual.makeCompatible(predicted)}. The original predicted frame is then deleted.
     *
     * @param actualFile      path of the file holding the actual labels
     * @param predictedFile   path of the file holding the predicted labels
     * @param actualDomain    domain of the actual column (only printed in verbose mode)
     * @param predictedDomain domain of the predicted column (only printed in verbose mode)
     * @param expectedDomain  expected domain of the resulting confusion matrix
     * @param expectedCM      expected counts, indexed {@code [actual][predicted]}
     * @param verbose         dump intermediate values to stderr
     * @throws RuntimeException wrapping the {@link IOException} if a file cannot be parsed
     */
    private void simpleCMTest(String actualFile, String predictedFile, String[] actualDomain,
                              String[] predictedDomain, String[] expectedDomain,
                              double[][] expectedCM, boolean verbose) {
        Frame actual = parseTestFrame("v1.hex", actualFile);
        Frame predicted = parseTestFrame("v2.hex", predictedFile);
        if (!actual.isCompatible(predicted)) {
            // Keep a handle on the parsed frame. It is the one to delete, not the new
            // compatible frame that is about to be tested.
            Frame original = predicted;
            try {
                predicted = new Frame(actual.makeCompatible(original));
            } catch (IllegalArgumentException e) {
                // NOTE(review): the different-length test expects this exception; presumably it
                // comes from makeCompatible. Release both parsed frames before propagating.
                actual.delete();
                original.delete();
                throw e;
            }
            original.delete();
        }
        simpleCMTest(actual, predicted, actualDomain, predictedDomain, expectedDomain,
                expectedCM, verbose);
    }

    /**
     * Parses {@code path} (resolved with {@link FileUtils#getFile}) into a frame stored under
     * {@code key}.
     *
     * <p>Parse failures are rethrown rather than logged, so a missing fixture cannot let a test
     * pass vacuously.
     */
    private static Frame parseTestFrame(String key, String path) {
        try {
            return FrameUtils.parseFrame(Key.make(key), new File[]{FileUtils.getFile(path)});
        } catch (IOException e) {
            throw new RuntimeException("Failed to parse test file " + path, e);
        }
    }

    /**
     * Builds a confusion matrix from the first column of each frame and checks it.
     *
     * <p>Both columns are first converted with {@link VecUtils#toCategoricalVec}. Both frames are
     * always deleted afterwards, and the {@link Scope} opened here is always exited.
     *
     * @param actual          frame whose first column holds the actual labels
     * @param predicted       frame whose first column holds the predicted labels
     * @param actualDomain    domain of the actual column (only printed in verbose mode)
     * @param predictedDomain domain of the predicted column (only printed in verbose mode)
     * @param expectedDomain  expected domain of the resulting confusion matrix
     * @param expectedCM      expected counts, indexed {@code [actual][predicted]}
     * @param verbose         dump intermediate values to stderr
     */
    private void simpleCMTest(Frame actual, Frame predicted, String[] actualDomain,
                              String[] predictedDomain, String[] expectedDomain,
                              double[][] expectedCM, boolean verbose) {
        Scope.enter();
        try {
            ConfusionMatrix cm = ConfusionMatrix.buildCM(
                    VecUtils.toCategoricalVec(actual.vecs()[0]),
                    VecUtils.toCategoricalVec(predicted.vecs()[0]));
            if (verbose) {
                dump(actualDomain, predictedDomain, expectedDomain, cm);
            }
            assertCMEqual(expectedDomain, expectedCM, cm);
        } finally {
            if (actual != null) {
                actual.delete();
            }
            if (predicted != null) {
                predicted.delete();
            }
            Scope.exit();
        }
    }

    /** Prints the input domains, the expected domain and the computed matrix to stderr. */
    private static void dump(String[] actualDomain, String[] predictedDomain,
                             String[] expectedDomain, ConfusionMatrix cm) {
        System.err.println("actual            : " + Arrays.toString(actualDomain));
        System.err.println("predicted         : " + Arrays.toString(predictedDomain));
        System.err.println("CM domain         : " + Arrays.toString(cm._domain));
        System.err.println("expected CM domain: " + Arrays.toString(expectedDomain) + "\n");
        for (double[] row : cm._cm) {
            System.err.println(Arrays.toString(row));
        }
        System.err.println("");
        System.err.println(cm.toASCII());
    }

    /**
     * Asserts that {@code cm} has exactly {@code expectedDomain}, the same number of rows as
     * {@code expectedCM}, and row-wise equal counts (within 1e-10).
     */
    private void assertCMEqual(String[] expectedDomain, double[][] expectedCM,
                               ConfusionMatrix cm) {
        Assert.assertArrayEquals("Expected domain differs", expectedDomain, cm._domain);
        double[][] computed = cm._cm;
        Assert.assertEquals("CM dimension differs", expectedCM.length, computed.length);
        for (int row = 0; row < computed.length; row++) {
            Assert.assertArrayEquals("CM row " + row + " differs!", expectedCM[row], computed[row],
                    1.0E-10d);
        }
    }
}
