package org.apache.mahout.math;

import com.google.common.io.Closeables;
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.IOException;
import org.apache.hadoop.io.Writable;
import org.junit.Test;

/* loaded from: input_file:org/apache/mahout/math/VectorWritableTest.class */
/**
 * Round-trip tests for {@link VectorWritable}.
 *
 * <p>Each test builds a small vector, serializes it through {@link VectorWritable#write},
 * deserializes it into a fresh {@link VectorWritable} via {@link VectorWritable#readFields},
 * and asserts that the result equals the original. Named vectors additionally have their
 * name checked.
 */
public final class VectorWritableTest extends MahoutTestCase {

    @Test
    public void testSequentialAccessSparseVectorWritable() throws Exception {
        Vector vector = new SequentialAccessSparseVector(5);
        vector.set(1, 3.0d);
        vector.set(3, 5.0d);
        doTestVectorWritableEquals(vector);
    }

    @Test
    public void testRandomAccessSparseVectorWritable() throws Exception {
        Vector vector = new RandomAccessSparseVector(5);
        vector.set(1, 3.0d);
        vector.set(3, 5.0d);
        doTestVectorWritableEquals(vector);
    }

    @Test
    public void testDenseVectorWritable() throws Exception {
        Vector vector = new DenseVector(5);
        vector.set(1, 3.0d);
        vector.set(3, 5.0d);
        doTestVectorWritableEquals(vector);
    }

    @Test
    public void testNamedVectorWritable() throws Exception {
        Vector vector = new NamedVector(new DenseVector(5), "Victor");
        vector.set(1, 3.0d);
        vector.set(3, 5.0d);
        doTestVectorWritableEquals(vector);
    }

    /**
     * Serializes {@code original} through a {@link VectorWritable}, reads it back, and asserts
     * that the deserialized vector equals the original. When {@code original} is a
     * {@link NamedVector}, also asserts that the round-tripped vector is a {@link NamedVector}
     * carrying the same name, and that the name is the one the named-vector test uses.
     *
     * @param original the vector to round-trip
     * @throws IOException if serialization or deserialization fails
     */
    private static void doTestVectorWritableEquals(Vector original) throws IOException {
        VectorWritable source = new VectorWritable(original);
        VectorWritable target = new VectorWritable();
        writeAndRead(source, target);

        // get() yields a plain Vector; only narrow to NamedVector after checking the type,
        // otherwise the instanceof assertion below would be meaningless.
        Vector roundTripped = target.get();
        if (original instanceof NamedVector) {
            assertTrue(roundTripped instanceof NamedVector);
            NamedVector namedOriginal = (NamedVector) original;
            NamedVector namedRoundTripped = (NamedVector) roundTripped;
            assertEquals(namedOriginal.getName(), namedRoundTripped.getName());
            assertEquals("Victor", namedOriginal.getName());
        }
        assertEquals(original, roundTripped);
    }

    /**
     * Writes {@code toWrite} into an in-memory buffer and then populates {@code toRead} from
     * those bytes.
     *
     * @param toWrite the writable to serialize
     * @param toRead  the writable whose fields are overwritten from the serialized bytes
     * @throws IOException if writing or reading fails
     */
    private static void writeAndRead(Writable toWrite, Writable toRead) throws IOException {
        ByteArrayOutputStream buffer = new ByteArrayOutputStream();
        // Close (and therefore flush) the data stream before snapshotting the buffer's bytes.
        try (DataOutputStream out = new DataOutputStream(buffer)) {
            toWrite.write(out);
        }
        try (DataInputStream in =
                new DataInputStream(new ByteArrayInputStream(buffer.toByteArray()))) {
            toRead.readFields(in);
        }
    }
}
