package org.apache.mahout.math;

import com.google.common.collect.Maps;
import com.google.common.io.Closeables;
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.IOException;
import java.util.HashMap;
import java.util.Map;
import org.apache.hadoop.io.Writable;
import org.junit.Test;

/**
 * Round-trip tests for {@link MatrixWritable}.
 *
 * <p>Each test builds a small 5x5 matrix with two non-zero cells and with label bindings on both
 * axes. It serializes the matrix through a {@link MatrixWritable}, reads it back into a fresh
 * {@link MatrixWritable}, and then checks that the cell values and both label-binding maps
 * survived the round trip.
 */
public final class MatrixWritableTest extends MahoutTestCase {

    /** Labels bound to indices 0..4 (in array order) on both the row and the column axis. */
    private static final String[] LABELS = {"A", "B", "C", "D", "default"};

    @Test
    public void testSparseMatrixWritable() throws Exception {
        doTestMatrixWritableEquals(populate(new SparseMatrix(5, 5)));
    }

    @Test
    public void testSparseRowMatrixWritable() throws Exception {
        doTestMatrixWritableEquals(populate(new SparseRowMatrix(5, 5)));
    }

    @Test
    public void testDenseMatrixWritable() throws Exception {
        doTestMatrixWritableEquals(populate(new DenseMatrix(5, 5)));
    }

    /**
     * Sets two non-zero cells and attaches label bindings to both axes.
     *
     * <p>Row and column bindings are separate map instances. This ensures a serializer that
     * mixes up the two axes cannot pass just because both point at the same object.
     *
     * @param matrix a matrix with at least {@code LABELS.length} rows and columns
     * @return the same {@code matrix}, for call chaining
     */
    private static Matrix populate(Matrix matrix) {
        matrix.set(1, 2, 3.0);
        matrix.set(3, 4, 5.0);
        matrix.setRowLabelBindings(newBindings());
        matrix.setColumnLabelBindings(newBindings());
        return matrix;
    }

    /** Builds a fresh mutable map binding {@code LABELS[i]} to index {@code i}. */
    private static Map<String, Integer> newBindings() {
        Map<String, Integer> bindings = Maps.newHashMap();
        for (int i = 0; i < LABELS.length; i++) {
            bindings.put(LABELS[i], i);
        }
        return bindings;
    }

    /**
     * Serializes {@code matrix} via {@link MatrixWritable}, deserializes it into a new instance,
     * and asserts that the copy matches the original.
     */
    private static void doTestMatrixWritableEquals(Matrix matrix) throws IOException {
        MatrixWritable source = new MatrixWritable(matrix);
        MatrixWritable target = new MatrixWritable();
        writeAndRead(source, target);
        Matrix copy = target.get();
        compareMatrices(matrix, copy);
        doCheckBindings(copy.getRowLabelBindings());
        doCheckBindings(copy.getColumnLabelBindings());
    }

    /**
     * Asserts that two matrices have the same shape, the same cell values (within 1e-6), and
     * equal row and column label bindings.
     */
    private static void compareMatrices(Matrix expected, Matrix actual) {
        assertEquals("row count", expected.numRows(), actual.numRows());
        assertEquals("column count", expected.numCols(), actual.numCols());
        for (int row = 0; row < expected.numRows(); row++) {
            for (int col = 0; col < expected.numCols(); col++) {
                assertEquals(expected.get(row, col), actual.get(row, col), 1.0e-6);
            }
        }
        assertBindingsEqual("row", expected.getRowLabelBindings(),
                actual.getRowLabelBindings(), expected.numRows());
        assertBindingsEqual("column", expected.getColumnLabelBindings(),
                actual.getColumnLabelBindings(), expected.numCols());
    }

    /**
     * Asserts that two label-binding maps are either both null or contain the same entries.
     *
     * <p>When present, the expected map must also bind exactly {@code dimension} labels. That
     * holds for every matrix built by {@link #populate}.
     *
     * @param axis human-readable axis name used in failure messages
     * @param expected bindings from the original matrix, possibly null
     * @param actual bindings from the deserialized matrix, possibly null
     * @param dimension size of the corresponding matrix axis
     */
    private static void assertBindingsEqual(String axis, Map<String, Integer> expected,
            Map<String, Integer> actual, int dimension) {
        assertEquals(axis + " bindings presence", expected == null, actual == null);
        if (expected == null) {
            return;
        }
        assertEquals(axis + " bindings size vs dimension", dimension, expected.size());
        assertEquals(axis + " bindings size", expected.size(), actual.size());
        for (Map.Entry<String, Integer> entry : expected.entrySet()) {
            assertEquals(axis + " binding for " + entry.getKey(),
                    entry.getValue(), actual.get(entry.getKey()));
        }
    }

    /** Asserts that {@code map} is non-null and contains every label in {@code LABELS}. */
    private static void doCheckBindings(Map<String, Integer> map) {
        assertTrue("Missing label bindings", map != null);
        for (String label : LABELS) {
            assertTrue("Missing label " + label, map.containsKey(label));
        }
    }

    /**
     * Writes {@code toWrite} into an in-memory buffer, then reads that buffer into
     * {@code toRead}.
     *
     * <p>A failure while closing the output stream is propagated, because it could mean the
     * serialized bytes are incomplete. A failure while closing the input stream is swallowed
     * ({@code Closeables.close(..., true)}), because all data has already been consumed.
     */
    private static void writeAndRead(Writable toWrite, Writable toRead) throws IOException {
        ByteArrayOutputStream buffer = new ByteArrayOutputStream();
        DataOutputStream out = new DataOutputStream(buffer);
        try {
            toWrite.write(out);
        } finally {
            Closeables.close(out, false);
        }

        DataInputStream in = new DataInputStream(new ByteArrayInputStream(buffer.toByteArray()));
        try {
            toRead.readFields(in);
        } finally {
            Closeables.close(in, true);
        }
    }
}
