/*
 * Decompiled with CFR 0.152.
 */
package org.apache.giraph.master;

import com.google.common.collect.Lists;
import com.google.common.collect.Sets;
import java.io.DataInput;
import java.io.DataOutput;
import java.io.IOException;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.Map;
import junit.framework.Assert;
import org.apache.giraph.combiner.MessageCombiner;
import org.apache.giraph.conf.GiraphConfiguration;
import org.apache.giraph.graph.AbstractComputation;
import org.apache.giraph.graph.Vertex;
import org.apache.giraph.master.DefaultMasterCompute;
import org.apache.giraph.utils.InternalVertexRunner;
import org.apache.giraph.utils.TestGraph;
import org.apache.hadoop.io.DoubleWritable;
import org.apache.hadoop.io.IntWritable;
import org.apache.hadoop.io.Writable;
import org.apache.hadoop.io.WritableComparable;
import org.junit.Test;

public class TestSwitchClasses {
    /**
     * Runs a two-vertex graph (ids 1 and 2) under
     * {@link SwitchingClassesMasterCompute}, which changes the computation
     * class, message combiner and message types between supersteps.
     * The detailed per-vertex expectations are asserted from inside
     * {@link Computation1} on superstep 4. This method only checks that the job
     * completes with both vertices present.
     *
     * @throws Exception if the in-memory job fails
     */
    @Test
    public void testSwitchingClasses() throws Exception {
        GiraphConfiguration conf = new GiraphConfiguration();
        // Only the initial class: the master compute sets Computation1 on superstep 0.
        conf.setComputationClass(Computation3.class);
        conf.setMasterComputeClass(SwitchingClassesMasterCompute.class);

        TestGraph<IntWritable, StatusValue, IntWritable> graph =
            new TestGraph<IntWritable, StatusValue, IntWritable>(conf);
        graph.addVertex(new IntWritable(1), new StatusValue());
        graph.addVertex(new IntWritable(2), new StatusValue());
        graph = InternalVertexRunner.runWithInMemoryOutput(conf, graph);

        Assert.assertEquals(2, graph.getVertices().size());
    }

    /**
     * Verifies the history a vertex recorded over supersteps 0-4.
     * Computation1 calls this on superstep 4.
     *
     * <p>Expected computations per superstep are 1, 1, 2, 3, 1.
     *
     * <p>Expected message sets for vertex {@code v} (the other vertex is {@code 3 - v}):
     * <ul>
     * <li>superstep 0: nothing</li>
     * <li>supersteps 1 and 2: the minimum of {@code v + 10} and {@code v + 20}</li>
     * <li>superstep 3: the two uncombined doubles {@code v + 100.5} and {@code v + 200.5}</li>
     * <li>superstep 4: the sum of {@code v + 1000} and {@code v + 2000}</li>
     * </ul>
     *
     * @param vertex vertex to check; its id must be 1 or 2
     * @throws IllegalStateException if the vertex id is neither 1 nor 2
     */
    private static void checkVerticesOnFinalSuperstep(
            Vertex<IntWritable, StatusValue, IntWritable> vertex) {
        StatusValue status = vertex.getValue();
        ArrayList<Integer> expectedComputations = Lists.newArrayList(1, 1, 2, 3, 1);
        checkComputations(expectedComputations, status.computations);
        switch (vertex.getId().get()) {
            case 1: {
                ArrayList<HashSet<Double>> messages1 = Lists.newArrayList(
                    Sets.<Double>newHashSet(),
                    Sets.newHashSet(11.0),
                    Sets.newHashSet(11.0),
                    Sets.newHashSet(101.5, 201.5),
                    Sets.newHashSet(3002.0));
                checkMessages(messages1, status.messagesReceived);
                break;
            }
            case 2: {
                ArrayList<HashSet<Double>> messages2 = Lists.newArrayList(
                    Sets.<Double>newHashSet(),
                    Sets.newHashSet(12.0),
                    Sets.newHashSet(12.0),
                    Sets.newHashSet(102.5, 202.5),
                    Sets.newHashSet(3004.0));
                checkMessages(messages2, status.messagesReceived);
                break;
            }
            default: {
                throw new IllegalStateException("checkVertices: Illegal vertex " + vertex);
            }
        }
    }

    /**
     * Asserts that {@code actual} records the same computation id as
     * {@code expected} on every superstep.
     *
     * @param expected computation id expected per superstep
     * @param actual computation id recorded per superstep
     */
    private static void checkComputations(ArrayList<Integer> expected, ArrayList<Integer> actual) {
        final int supersteps = expected.size();
        Assert.assertEquals("Incorrect number of supersteps", supersteps, actual.size());
        int superstep = 0;
        while (superstep < supersteps) {
            final int want = expected.get(superstep);
            final int got = actual.get(superstep);
            Assert.assertEquals("Incorrect computation on superstep " + superstep, want, got);
            superstep++;
        }
    }

    /**
     * Asserts that, on every superstep, the set of message values a vertex
     * received equals the expected set.
     *
     * @param expected expected message values per superstep
     * @param actual recorded message values per superstep
     */
    private static void checkMessages(ArrayList<HashSet<Double>> expected, ArrayList<HashSet<Double>> actual) {
        Assert.assertEquals("Incorrect number of supersteps", expected.size(), actual.size());
        for (int i = 0; i < expected.size(); ++i) {
            // Set equality (same size + containsAll) matches the earlier
            // size-then-contains checks, but reports the superstep and both sets on failure.
            Assert.assertEquals("Incorrect messages on superstep " + i, expected.get(i), actual.get(i));
        }
    }

    public static class StatusValue
    implements Writable {
        private ArrayList<Integer> computations = new ArrayList();
        private ArrayList<HashSet<Double>> messagesReceived = new ArrayList();

        public void addIntMessages(Iterable<IntWritable> messages) {
            HashSet<Double> messagesList = new HashSet<Double>();
            for (IntWritable message : messages) {
                messagesList.add(Double.valueOf(message.get()));
            }
            this.messagesReceived.add(messagesList);
        }

        public void addDoubleMessages(Iterable<Writable> messages) {
            HashSet<Double> messagesList = new HashSet<Double>();
            for (Writable message : messages) {
                messagesList.add(((DoubleWritable)message).get());
            }
            this.messagesReceived.add(messagesList);
        }

        public String toString() {
            return "(computations=" + this.computations + ",messagesReceived=" + this.messagesReceived + ")";
        }

        public void write(DataOutput dataOutput) throws IOException {
            dataOutput.writeInt(this.computations.size());
            for (Integer n : this.computations) {
                dataOutput.writeInt(n);
            }
            dataOutput.writeInt(this.messagesReceived.size());
            for (HashSet hashSet : this.messagesReceived) {
                dataOutput.writeInt(hashSet.size());
                for (Double message : hashSet) {
                    dataOutput.writeDouble(message);
                }
            }
        }

        public void readFields(DataInput dataInput) throws IOException {
            int i;
            int size = dataInput.readInt();
            this.computations = new ArrayList(size);
            for (i = 0; i < size; ++i) {
                this.computations.add(dataInput.readInt());
            }
            size = dataInput.readInt();
            this.messagesReceived = new ArrayList(size);
            for (i = 0; i < size; ++i) {
                int size2 = dataInput.readInt();
                HashSet<Double> messages = new HashSet<Double>(size2);
                for (int j = 0; j < size2; ++j) {
                    messages.add(dataInput.readDouble());
                }
                this.messagesReceived.add(messages);
            }
        }
    }

    public static class SumMessageCombiner
    extends MessageCombiner<IntWritable, IntWritable> {
        public void combine(IntWritable vertexIndex, IntWritable originalMessage, IntWritable messageToCombine) {
            originalMessage.set(originalMessage.get() + messageToCombine.get());
        }

        public IntWritable createInitialMessage() {
            return new IntWritable(0);
        }
    }

    public static class MinimumMessageCombiner
    extends MessageCombiner<IntWritable, IntWritable> {
        public void combine(IntWritable vertexIndex, IntWritable originalMessage, IntWritable messageToCombine) {
            originalMessage.set(Math.min(originalMessage.get(), messageToCombine.get()));
        }

        public IntWritable createInitialMessage() {
            return new IntWritable(Integer.MAX_VALUE);
        }
    }

    public static class Computation3
    extends AbstractComputation<IntWritable, StatusValue, IntWritable, Writable, Writable> {
        public void compute(Vertex<IntWritable, StatusValue, IntWritable> vertex, Iterable<Writable> messages) throws IOException {
            ((StatusValue)vertex.getValue()).computations.add(3);
            ((StatusValue)vertex.getValue()).addDoubleMessages(messages);
            IntWritable otherId = new IntWritable(3 - ((IntWritable)vertex.getId()).get());
            this.sendMessage((WritableComparable)otherId, (Writable)new IntWritable(otherId.get() + 1000));
            this.sendMessage((WritableComparable)otherId, (Writable)new IntWritable(otherId.get() + 2000));
        }
    }

    public static class Computation2
    extends AbstractComputation<IntWritable, StatusValue, IntWritable, IntWritable, DoubleWritable> {
        public void compute(Vertex<IntWritable, StatusValue, IntWritable> vertex, Iterable<IntWritable> messages) throws IOException {
            ((StatusValue)vertex.getValue()).computations.add(2);
            ((StatusValue)vertex.getValue()).addIntMessages(messages);
            IntWritable otherId = new IntWritable(3 - ((IntWritable)vertex.getId()).get());
            this.sendMessage((WritableComparable)otherId, (Writable)new DoubleWritable((double)otherId.get() + 100.5));
            this.sendMessage((WritableComparable)otherId, (Writable)new DoubleWritable((double)otherId.get() + 200.5));
        }
    }

    public static class Computation1
    extends AbstractComputation<IntWritable, StatusValue, IntWritable, IntWritable, IntWritable> {
        public void compute(Vertex<IntWritable, StatusValue, IntWritable> vertex, Iterable<IntWritable> messages) throws IOException {
            ((StatusValue)vertex.getValue()).computations.add(1);
            ((StatusValue)vertex.getValue()).addIntMessages(messages);
            IntWritable otherId = new IntWritable(3 - ((IntWritable)vertex.getId()).get());
            this.sendMessage((WritableComparable)otherId, (Writable)new IntWritable(otherId.get() + 10));
            this.sendMessage((WritableComparable)otherId, (Writable)new IntWritable(otherId.get() + 20));
            if (this.getSuperstep() == 4L) {
                TestSwitchClasses.checkVerticesOnFinalSuperstep((Vertex<IntWritable, StatusValue, IntWritable>)vertex);
            }
        }
    }

    public static class SwitchingClassesMasterCompute
    extends DefaultMasterCompute {
        public void compute() {
            switch ((int)this.getSuperstep()) {
                case 0: {
                    this.setComputation(Computation1.class);
                    this.setMessageCombiner(MinimumMessageCombiner.class);
                    break;
                }
                case 1: {
                    break;
                }
                case 2: {
                    this.setComputation(Computation2.class);
                    this.setMessageCombiner(null);
                    break;
                }
                case 3: {
                    this.setComputation(Computation3.class);
                    this.setMessageCombiner(SumMessageCombiner.class);
                    this.setIncomingMessage(DoubleWritable.class);
                    this.setOutgoingMessage(IntWritable.class);
                    break;
                }
                case 4: {
                    this.setComputation(Computation1.class);
                    this.setIncomingMessage(null);
                    this.setOutgoingMessage(null);
                    break;
                }
                default: {
                    this.haltComputation();
                }
            }
        }
    }
}

