package org.apache.giraph.comm.messages;

import com.google.common.collect.Iterables;
import com.google.common.collect.Lists;
import java.io.IOException;
import java.util.Iterator;
import java.util.List;
import junit.framework.Assert;
import org.apache.giraph.bsp.CentralizedServiceWorker;
import org.apache.giraph.combiner.DoubleSumMessageCombiner;
import org.apache.giraph.comm.messages.primitives.LongDoubleMessageStore;
import org.apache.giraph.comm.messages.primitives.long_id.LongByteArrayMessageStore;
import org.apache.giraph.conf.GiraphConfiguration;
import org.apache.giraph.conf.ImmutableClassesGiraphConfiguration;
import org.apache.giraph.factories.TestMessageValueFactory;
import org.apache.giraph.graph.BasicComputation;
import org.apache.giraph.graph.Vertex;
import org.apache.giraph.partition.Partition;
import org.apache.giraph.partition.PartitionStore;
import org.apache.giraph.utils.ByteArrayVertexIdMessages;
import org.apache.hadoop.io.DoubleWritable;
import org.apache.hadoop.io.LongWritable;
import org.apache.hadoop.io.NullWritable;
import org.apache.hadoop.io.Writable;
import org.apache.hadoop.io.WritableComparable;
import org.junit.Before;
import org.junit.Test;
import org.mockito.Mockito;
import org.mockito.invocation.InvocationOnMock;
import org.mockito.stubbing.Answer;

/**
 * Tests the primitive-keyed message stores for {@link LongWritable} vertex ids and
 * {@link DoubleWritable} messages.
 *
 * <p>{@link LongDoubleMessageStore} is exercised with a {@link DoubleSumMessageCombiner}, so each
 * vertex is expected to end up with a single summed message. {@link LongByteArrayMessageStore} is
 * expected to keep every message for a vertex in arrival order.
 *
 * <p>The worker service is mocked so that vertex {@code id} lives in partition
 * {@code id % NUM_PARTITIONS}.
 */
public class TestLongDoublePrimitiveMessageStores {
    /** Number of partitions the mocked worker exposes; vertex ids map to {@code id % NUM_PARTITIONS}. */
    private static final int NUM_PARTITIONS = 2;
    /** Tolerance for message-value comparisons; every expected value is an exact sum of small integers. */
    private static final double DELTA = 0.0d;
    /** Mocked worker service, rebuilt before every test by {@link #prepare()}. */
    private static CentralizedServiceWorker<LongWritable, Writable, Writable> service;

    /**
     * Computation that does nothing. It is only registered on the configuration built by
     * {@link #createLongDoubleConf()}; presumably the configuration derives the id/value/edge/message
     * types from its type parameters (TODO confirm against ImmutableClassesGiraphConfiguration).
     */
    private static class LongDoubleNoOpComputation extends
            BasicComputation<LongWritable, NullWritable, NullWritable, DoubleWritable> {
        @Override
        public void compute(Vertex<LongWritable, NullWritable, NullWritable> vertex,
                Iterable<DoubleWritable> messages) throws IOException {
            // Intentionally empty: these tests never run a superstep.
        }
    }

    /**
     * Builds a fresh mocked {@link CentralizedServiceWorker} before each test.
     *
     * <p>The mock maps vertex id {@code v} to partition {@code v % NUM_PARTITIONS}, reports partition
     * ids {@code 0 .. NUM_PARTITIONS - 1}, and returns the same mocked {@link Partition} (whose
     * vertex count is 1) from {@code getOrCreatePartition} for each of those ids.
     *
     * @throws IOException declared for JUnit lifecycle compatibility
     */
    @Before
    @SuppressWarnings("unchecked") // Mockito.mock(Class) cannot produce parameterized types.
    public void prepare() throws IOException {
        service = Mockito.mock(CentralizedServiceWorker.class);
        Mockito.when(service.getPartitionId(Mockito.any(LongWritable.class))).thenAnswer(
                new Answer<Integer>() {
                    @Override
                    public Integer answer(InvocationOnMock invocation) {
                        LongWritable vertexId = (LongWritable) invocation.getArguments()[0];
                        return (int) (vertexId.get() % NUM_PARTITIONS);
                    }
                });

        PartitionStore<LongWritable, Writable, Writable> partitionStore =
                Mockito.mock(PartitionStore.class);
        Mockito.when(service.getPartitionStore()).thenReturn(partitionStore);

        Partition<LongWritable, Writable, Writable> partition = Mockito.mock(Partition.class);
        Mockito.when(partition.getVertexCount()).thenReturn(1L);

        // Derive the partition ids from NUM_PARTITIONS so they agree with the id -> partition mapping.
        List<Integer> partitionIds = Lists.newArrayListWithCapacity(NUM_PARTITIONS);
        for (int partitionId = 0; partitionId < NUM_PARTITIONS; partitionId++) {
            partitionIds.add(partitionId);
            Mockito.when(partitionStore.getOrCreatePartition(partitionId)).thenReturn(partition);
        }
        Mockito.when(partitionStore.getPartitionIds()).thenReturn(partitionIds);
    }

    /**
     * Creates an immutable configuration whose computation class is
     * {@link LongDoubleNoOpComputation}.
     *
     * @return a new configuration wrapping a fresh {@link GiraphConfiguration}
     */
    private static ImmutableClassesGiraphConfiguration<LongWritable, Writable, Writable> createLongDoubleConf() {
        GiraphConfiguration giraphConf = new GiraphConfiguration();
        giraphConf.setComputationClass(LongDoubleNoOpComputation.class);
        return new ImmutableClassesGiraphConfiguration<>(giraphConf);
    }

    /**
     * Creates an empty, initialized message buffer for {@code (LongWritable, DoubleWritable)} pairs.
     *
     * @return a configured and initialized {@link ByteArrayVertexIdMessages}
     */
    private static ByteArrayVertexIdMessages<LongWritable, DoubleWritable> createLongDoubleMessages() {
        ByteArrayVertexIdMessages<LongWritable, DoubleWritable> messages =
                new ByteArrayVertexIdMessages<>(
                        new TestMessageValueFactory<DoubleWritable>(DoubleWritable.class));
        messages.setConf(createLongDoubleConf());
        messages.initialize();
        return messages;
    }

    /**
     * Feeds the same fixed set of messages into {@code messageStore}, in three batches.
     *
     * <ol>
     *   <li>Partition 0: vertex 0 gets 1.0, vertex 2 gets 3.0, vertex 0 gets 4.0.</li>
     *   <li>Partition 1: vertex 1 gets 1.0, 3.0 and 4.0.</li>
     *   <li>Partition 0 again: vertex 0 gets 5.0.</li>
     * </ol>
     *
     * <p>Vertex 3 never receives a message.
     *
     * @param messageStore the store under test
     * @throws IOException if the store fails to accept a batch
     */
    private static void insertLongDoubleMessages(MessageStore<LongWritable, DoubleWritable> messageStore) throws IOException {
        ByteArrayVertexIdMessages<LongWritable, DoubleWritable> firstBatch = createLongDoubleMessages();
        firstBatch.add(new LongWritable(0L), new DoubleWritable(1.0d));
        firstBatch.add(new LongWritable(2L), new DoubleWritable(3.0d));
        firstBatch.add(new LongWritable(0L), new DoubleWritable(4.0d));
        messageStore.addPartitionMessages(0, firstBatch);

        ByteArrayVertexIdMessages<LongWritable, DoubleWritable> secondBatch = createLongDoubleMessages();
        secondBatch.add(new LongWritable(1L), new DoubleWritable(1.0d));
        secondBatch.add(new LongWritable(1L), new DoubleWritable(3.0d));
        secondBatch.add(new LongWritable(1L), new DoubleWritable(4.0d));
        messageStore.addPartitionMessages(1, secondBatch);

        // A second batch for partition 0 checks that the store appends to or combines with existing messages.
        ByteArrayVertexIdMessages<LongWritable, DoubleWritable> thirdBatch = createLongDoubleMessages();
        thirdBatch.add(new LongWritable(0L), new DoubleWritable(5.0d));
        messageStore.addPartitionMessages(0, thirdBatch);
    }

    /**
     * Asserts that {@code messages} contains exactly the {@code expected} values, in iteration order.
     *
     * @param messages the messages returned by a store for one vertex
     * @param expected the expected message values, in order
     */
    private static void assertMessages(Iterable<? extends DoubleWritable> messages, double... expected) {
        Assert.assertEquals(expected.length, Iterables.size(messages));
        Iterator<? extends DoubleWritable> it = messages.iterator();
        for (double value : expected) {
            Assert.assertEquals(value, it.next().get(), DELTA);
        }
    }

    /**
     * With a sum combiner, each vertex should hold one message equal to the sum of its inputs.
     * A vertex that received nothing should have no messages.
     */
    @Test
    public void testLongDoubleMessageStore() throws IOException {
        LongDoubleMessageStore store =
                new LongDoubleMessageStore(service, new DoubleSumMessageCombiner());
        insertLongDoubleMessages(store);

        assertMessages(store.getVertexMessages(new LongWritable(0L)), 10.0d); // 1 + 4 + 5
        assertMessages(store.getVertexMessages(new LongWritable(1L)), 8.0d);  // 1 + 3 + 4
        assertMessages(store.getVertexMessages(new LongWritable(2L)), 3.0d);
        Assert.assertTrue(Iterables.isEmpty(store.getVertexMessages(new LongWritable(3L))));
    }

    /**
     * Without a combiner, each vertex should hold every message it received, in arrival order.
     * A vertex that received nothing should have no messages.
     */
    @Test
    public void testLongByteArrayMessageStore() throws IOException {
        LongByteArrayMessageStore<DoubleWritable> store = new LongByteArrayMessageStore<>(
                new TestMessageValueFactory<DoubleWritable>(DoubleWritable.class),
                service, createLongDoubleConf());
        insertLongDoubleMessages(store);

        assertMessages(store.getVertexMessages(new LongWritable(0L)), 1.0d, 4.0d, 5.0d);
        assertMessages(store.getVertexMessages(new LongWritable(1L)), 1.0d, 3.0d, 4.0d);
        assertMessages(store.getVertexMessages(new LongWritable(2L)), 3.0d);
        Assert.assertTrue(Iterables.isEmpty(store.getVertexMessages(new LongWritable(3L))));
    }
}
