package org.apache.giraph.comm;

import com.google.common.collect.Iterables;
import com.google.common.collect.Lists;
import com.google.common.collect.Sets;
import java.io.BufferedInputStream;
import java.io.BufferedOutputStream;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Iterator;
import java.util.Map;
import java.util.Random;
import java.util.SortedMap;
import java.util.TreeMap;
import java.util.TreeSet;
import org.apache.giraph.bsp.CentralizedServiceWorker;
import org.apache.giraph.comm.messages.ByteArrayMessagesPerVertexStore;
import org.apache.giraph.comm.messages.MessageEncodeAndStoreType;
import org.apache.giraph.comm.messages.MessageStore;
import org.apache.giraph.comm.messages.MessageStoreFactory;
import org.apache.giraph.conf.DefaultMessageClasses;
import org.apache.giraph.conf.GiraphConfiguration;
import org.apache.giraph.conf.ImmutableClassesGiraphConfiguration;
import org.apache.giraph.factories.DefaultMessageValueFactory;
import org.apache.giraph.factories.TestMessageValueFactory;
import org.apache.giraph.utils.ByteArrayVertexIdMessages;
import org.apache.giraph.utils.CollectionUtils;
import org.apache.giraph.utils.IntNoOpComputation;
import org.apache.giraph.utils.MockUtils;
import org.apache.giraph.worker.EdgeInputSplitsCallable;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.io.IntWritable;
import org.apache.hadoop.io.Writable;
import org.apache.hadoop.io.WritableComparable;
import org.junit.After;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;

/**
 * Tests {@link MessageStore} implementations. Each test fills a store with batches of random
 * messages and checks the store's contents against an in-memory expectation. It then writes
 * the store out partition by partition, reads it back into a fresh store, and checks that the
 * restored store holds the same contents.
 */
public class TestMessageStores {
    /** Seed for the message generator so every run sees the same data. */
    private static final long RANDOM_SEED = 101;

    /** Scratch directory for checkpoint files; created per test and removed afterwards. */
    private File directory;
    /** Configuration handed to the stores and message buffers under test. */
    private ImmutableClassesGiraphConfiguration<IntWritable, IntWritable, IntWritable> config;
    /** Sizing parameters for the generated messages. */
    private TestData testData;
    /** Mocked worker used to look up the partition owning a vertex id. */
    private CentralizedServiceWorker<IntWritable, IntWritable, IntWritable> service;
    /** Deterministic source of message counts, vertex ids and message values. */
    private final Random random = new Random(RANDOM_SEED);

    /** Sizing parameters for the randomly generated messages. */
    private static class TestData {
        /** Number of random batches added to a store. */
        int numTimes;
        /** Number of (vertex id, messages) draws per batch. */
        int numVertices;
        /** Exclusive upper bound on the number of messages drawn for one vertex. */
        int maxNumberOfMessages;
        /** Exclusive upper bound on generated vertex ids. */
        int maxId;
        /** Exclusive upper bound on generated message values. */
        int maxMessage;
        /** Number of partitions the mocked service maps vertices into. */
        int numOfPartitions;
        /** Not read anywhere in this class. */
        int maxMessagesInMemory;
    }

    @Before
    public void prepare() throws IOException {
        Configuration.addDefaultResource("giraph-site.xml");
        GiraphConfiguration giraphConfiguration = new GiraphConfiguration();
        giraphConfiguration.setComputationClass(IntNoOpComputation.class);
        config = new ImmutableClassesGiraphConfiguration<>(giraphConfiguration);

        testData = new TestData();
        // NOTE(review): EDGES_UPDATE_PERIOD is only borrowed here as a large bound for ids and
        // message values (possibly a decompiler substitution for an equal literal) - confirm.
        testData.maxId = EdgeInputSplitsCallable.EDGES_UPDATE_PERIOD;
        testData.maxMessage = EdgeInputSplitsCallable.EDGES_UPDATE_PERIOD;
        testData.maxNumberOfMessages = 100;
        testData.numVertices = 50;
        testData.numTimes = 10;
        testData.numOfPartitions = 5;
        testData.maxMessagesInMemory = 20;
        service = MockUtils.mockServiceGetVertexPartitionOwner(testData.numOfPartitions);

        // A private scratch directory keeps checkpoint files out of the working directory.
        directory = java.nio.file.Files.createTempDirectory("giraph-message-stores").toFile();
    }

    @After
    public void cleanUp() {
        if (directory != null) {
            deleteRecursively(directory);
            directory = null;
        }
    }

    /**
     * Best-effort recursive delete; failures are reported but do not fail the test.
     *
     * @param file file or directory to remove
     */
    private static void deleteRecursively(File file) {
        File[] children = file.listFiles();
        if (children != null) {
            for (File child : children) {
                deleteRecursively(child);
            }
        }
        if (!file.delete() && file.exists()) {
            System.err.println("cleanUp: could not delete " + file);
        }
    }

    /**
     * Builds the message classes used by every store created in this test.
     *
     * <p>The raw constructor call mirrors the original code. The value-factory parameter is
     * generic, and the raw {@code DefaultMessageValueFactory.class} literal only fits it through
     * an unchecked conversion.
     *
     * @return message classes for IntWritable messages stored as one byte array per partition
     */
    @SuppressWarnings({"unchecked", "rawtypes"})
    private static DefaultMessageClasses<IntWritable, IntWritable> newMessageClasses() {
        return new DefaultMessageClasses(IntWritable.class, DefaultMessageValueFactory.class,
            null, MessageEncodeAndStoreType.BYTEARRAY_PER_PARTITION);
    }

    /**
     * Draws one batch of random messages.
     *
     * <p>Vertex ids are drawn at random, so two draws in one batch may hit the same id. In that
     * case the later draw replaces the earlier one. The returned map is exactly what gets added
     * to the store, so the expectation stays consistent.
     *
     * @param data sizing parameters
     * @return vertex id to the messages destined for it; lists may be empty
     */
    private SortedMap<IntWritable, Collection<IntWritable>> createRandomMessages(TestData data) {
        SortedMap<IntWritable, Collection<IntWritable>> batch = new TreeMap<>();
        for (int draw = 0; draw < data.numVertices; draw++) {
            // Call order on `random` (count, values, id) is kept so the data set is unchanged.
            int numMessages = (int) (random.nextFloat() * data.maxNumberOfMessages);
            ArrayList<IntWritable> vertexMessages = Lists.newArrayListWithCapacity(numMessages);
            for (int m = 0; m < numMessages; m++) {
                vertexMessages.add(new IntWritable((int) (random.nextFloat() * data.maxMessage)));
            }
            batch.put(new IntWritable((int) (random.nextFloat() * data.maxId)), vertexMessages);
        }
        return batch;
    }

    /**
     * Adds the given messages to {@code messageStore}. Each vertex's messages are wrapped in
     * their own {@link ByteArrayVertexIdMessages} buffer and routed to the owning partition.
     *
     * @param messageStore store receiving the messages
     * @param service worker used to look up the partition owning each vertex
     * @param conf configuration applied to each message buffer
     * @param messages vertex id to the messages destined for it
     */
    private static void addMessages(MessageStore<IntWritable, IntWritable> messageStore,
            CentralizedServiceWorker<IntWritable, ?, ?> service,
            ImmutableClassesGiraphConfiguration<IntWritable, ?, ?> conf,
            Map<IntWritable, Collection<IntWritable>> messages) {
        for (Map.Entry<IntWritable, Collection<IntWritable>> entry : messages.entrySet()) {
            int partitionId = service.getVertexPartitionOwner(entry.getKey()).getPartitionId();
            ByteArrayVertexIdMessages<IntWritable, IntWritable> buffer =
                new ByteArrayVertexIdMessages<IntWritable, IntWritable>(
                    new TestMessageValueFactory<IntWritable>(IntWritable.class));
            buffer.setConf(conf);
            buffer.initialize();
            for (IntWritable message : entry.getValue()) {
                buffer.add(entry.getKey(), message);
            }
            messageStore.addPartitionMessages(partitionId, buffer);
        }
    }

    /**
     * Adds {@code data.numTimes} random batches to the store and merges each batch into
     * {@code expected}, which accumulates every message ever sent to each vertex.
     *
     * @param messageStore store receiving the messages
     * @param expected accumulated expectation; updated in place
     * @param data sizing parameters
     */
    private void putNTimes(MessageStore<IntWritable, IntWritable> messageStore,
            Map<IntWritable, Collection<IntWritable>> expected, TestData data) {
        for (int round = 0; round < data.numTimes; round++) {
            SortedMap<IntWritable, Collection<IntWritable>> batch = createRandomMessages(data);
            addMessages(messageStore, service, config, batch);
            for (Map.Entry<IntWritable, Collection<IntWritable>> entry : batch.entrySet()) {
                Collection<IntWritable> accumulated = expected.get(entry.getKey());
                if (accumulated == null) {
                    expected.put(entry.getKey(), entry.getValue());
                } else {
                    accumulated.addAll(entry.getValue());
                }
            }
        }
    }

    /**
     * Checks that the store holds exactly the expected messages. Two conditions must hold:
     * every vertex the store reports has the expected messages (compared via
     * {@link CollectionUtils#isEqual}), and the store reports exactly as many vertices as
     * {@code expected} has non-empty entries. The second condition means a store that lost
     * data (e.g. an empty one) no longer compares equal.
     *
     * @param messageStore store to inspect
     * @param expected vertex id to expected messages
     * @param data supplies the number of partitions to scan
     * @return true if the contents match; details of a mismatch go to stderr
     */
    private <I extends WritableComparable, M extends Writable> boolean equalMessages(
            MessageStore<I, M> messageStore, Map<I, Collection<M>> expected, TestData data) {
        int storedVertices = 0;
        for (int partitionId = 0; partitionId < data.numOfPartitions; partitionId++) {
            TreeSet<I> vertexIds = Sets.newTreeSet();
            Iterables.addAll(vertexIds, messageStore.getPartitionDestinationVertices(partitionId));
            for (I vertexId : vertexIds) {
                storedVertices++;
                Collection<M> expectedMessages = expected.get(vertexId);
                if (expectedMessages == null) {
                    System.err.println("equalMessages: Unexpected vertexId " + vertexId);
                    return false;
                }
                Iterable<M> actualMessages = messageStore.getVertexMessages(vertexId);
                if (!CollectionUtils.isEqual(expectedMessages, actualMessages)) {
                    System.err.println("equalMessages: For vertexId " + vertexId + " expected "
                        + expectedMessages + ", but got " + actualMessages);
                    return false;
                }
            }
        }
        // Vertices whose messages were all empty lists never reach the store.
        int expectedVertices = 0;
        for (Collection<M> messages : expected.values()) {
            if (!messages.isEmpty()) {
                expectedVertices++;
            }
        }
        if (storedVertices != expectedVertices) {
            System.err.println("equalMessages: expected " + expectedVertices
                + " vertices with messages, but store has " + storedVertices);
            return false;
        }
        return true;
    }

    /**
     * Writes every partition of {@code messageStore} to a scratch file, then reads the file back
     * into a newly created store. The scratch file is always deleted.
     *
     * @param messageStoreFactory factory for the restored store
     * @param messageStore store to checkpoint; not modified
     * @param data supplies the number of partitions to write and read
     * @return a new store populated from the checkpoint
     * @throws IOException if writing or reading the checkpoint fails
     */
    private <S extends MessageStore<IntWritable, IntWritable>> S doCheckpoint(
            MessageStoreFactory<IntWritable, IntWritable, S> messageStoreFactory,
            S messageStore, TestData data) throws IOException {
        File file = new File(directory, "messageStoreTest");
        try {
            // FileOutputStream truncates any leftover file from a previous run.
            try (DataOutputStream out = new DataOutputStream(
                    new BufferedOutputStream(new FileOutputStream(file)))) {
                for (int partitionId = 0; partitionId < data.numOfPartitions; partitionId++) {
                    messageStore.writePartition(out, partitionId);
                }
            }
            S restored = messageStoreFactory.newStore(newMessageClasses());
            try (DataInputStream in = new DataInputStream(
                    new BufferedInputStream(new FileInputStream(file)))) {
                for (int partitionId = 0; partitionId < data.numOfPartitions; partitionId++) {
                    restored.readFieldsForPartition(in, partitionId);
                }
            }
            return restored;
        } finally {
            // Best effort; cleanUp() removes the whole scratch directory anyway.
            file.delete();
        }
    }

    /**
     * Runs the full scenario against one store type: populate, verify, checkpoint round trip,
     * verify the restored copy.
     *
     * @param messageStoreFactory factory for the stores under test
     * @param data sizing parameters
     * @throws IOException if the checkpoint round trip fails
     */
    private <S extends MessageStore<IntWritable, IntWritable>> void testMessageStore(
            MessageStoreFactory<IntWritable, IntWritable, S> messageStoreFactory, TestData data)
            throws IOException {
        SortedMap<IntWritable, Collection<IntWritable>> expected = new TreeMap<>();
        S messageStore = messageStoreFactory.newStore(newMessageClasses());
        putNTimes(messageStore, expected, data);
        Assert.assertTrue("store contents differ from expected messages",
            equalMessages(messageStore, expected, data));

        // Checkpoint while the store still holds messages. Clearing it first would serialize
        // an empty store and make the round-trip assertion below meaningless.
        S restored = doCheckpoint(messageStoreFactory, messageStore, data);
        messageStore.clearAll();
        Assert.assertTrue("restored store differs from expected messages",
            equalMessages(restored, expected, data));
        restored.clearAll();
    }

    @Test
    public void testByteArrayMessagesPerVertexStore() throws IOException {
        // I/O failures must fail the test instead of being printed and ignored.
        testMessageStore(ByteArrayMessagesPerVertexStore.newFactory(service, config), testData);
    }
}
