/*
 * Decompiled with CFR 0.152.
 */
package org.apache.tez.runtime.library.cartesianproduct;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import org.apache.tez.dag.api.EdgeManagerPluginDescriptor;
import org.apache.tez.dag.api.EdgeProperty;
import org.apache.tez.dag.api.TezReflectionException;
import org.apache.tez.dag.api.UserPayload;
import org.apache.tez.dag.api.VertexLocationHint;
import org.apache.tez.dag.api.VertexManagerPluginContext;
import org.apache.tez.dag.api.event.VertexState;
import org.apache.tez.dag.api.event.VertexStateUpdate;
import org.apache.tez.dag.records.TaskAttemptIdentifierImpl;
import org.apache.tez.dag.records.TezDAGID;
import org.apache.tez.dag.records.TezTaskAttemptID;
import org.apache.tez.dag.records.TezTaskID;
import org.apache.tez.dag.records.TezVertexID;
import org.apache.tez.runtime.api.TaskAttemptIdentifier;
import org.apache.tez.runtime.library.cartesianproduct.CartesianProductEdgeManager;
import org.apache.tez.runtime.library.cartesianproduct.CartesianProductFilter;
import org.apache.tez.runtime.library.cartesianproduct.CartesianProductUserPayload;
import org.apache.tez.runtime.library.cartesianproduct.CartesianProductVertexManagerPartitioned;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;
import org.mockito.ArgumentCaptor;
import org.mockito.Captor;
import org.mockito.Matchers;
import org.mockito.Mockito;
import org.mockito.MockitoAnnotations;
import org.mockito.verification.VerificationMode;

/**
 * Unit tests for {@link CartesianProductVertexManagerPartitioned}, driven through a mocked
 * {@link VertexManagerPluginContext}.
 *
 * <p>The fixture models a vertex named "cp" with three source vertices: "v0" and "v1" are
 * attached through {@link CartesianProductEdgeManager} edges, and "v2" through a BROADCAST edge.
 * Every source vertex is stubbed to report 4 tasks.
 */
public class TestCartesianProductVertexManagerPartitioned {
    /** Number of source vertices ("v0", "v1", "v2") for which completions are pre-built. */
    private static final int NUM_SOURCE_VERTICES = 3;
    /** Task count stubbed for every source vertex. */
    private static final int TASKS_PER_SOURCE = 4;

    @Captor
    private ArgumentCaptor<Map<String, EdgeProperty>> edgePropertiesCaptor;
    @Captor
    private ArgumentCaptor<List<VertexManagerPluginContext.ScheduleTaskRequest>> scheduleTaskRequestCaptor;
    private CartesianProductVertexManagerPartitioned vertexManager;
    private VertexManagerPluginContext context;
    /**
     * One completion per (source vertex, task) pair, ordered by vertex and then by task, so the
     * entry for task {@code j} of vertex {@code "v" + i} sits at index
     * {@code i * TASKS_PER_SOURCE + j} (0-3 belong to v0, 4-7 to v1, 8-11 to v2).
     */
    private List<TaskAttemptIdentifier> allCompletions;

    /**
     * Builds the default fixture: a partitioned cartesian product over "v0" and "v1" with
     * 2 partitions each and no filter.
     */
    @Before
    public void setup() throws TezReflectionException {
        CartesianProductUserPayload.CartesianProductConfigProto.Builder builder = CartesianProductUserPayload.CartesianProductConfigProto.newBuilder();
        builder.setIsPartitioned(true).addSources("v0").addSources("v1").addNumPartitions(2).addNumPartitions(2);
        this.setupWithConfig(builder.build());
    }

    /**
     * (Re)creates the mocked context, the vertex manager under test and the completion list.
     *
     * <p>Any previously created mock and vertex manager are discarded, so this can be called
     * several times within one test to start again from a clean state.
     *
     * @param config cartesian product configuration handed to the vertex manager
     */
    private void setupWithConfig(CartesianProductUserPayload.CartesianProductConfigProto config) throws TezReflectionException {
        MockitoAnnotations.initMocks(this);
        this.context = Mockito.mock(VertexManagerPluginContext.class);
        Mockito.when(this.context.getVertexName()).thenReturn("cp");
        Mockito.when(this.context.getVertexNumTasks("cp")).thenReturn(-1);
        this.vertexManager = new CartesianProductVertexManagerPartitioned(this.context);

        Map<String, EdgeProperty> edgePropertyMap = new HashMap<String, EdgeProperty>();
        edgePropertyMap.put("v0", EdgeProperty.create(EdgeManagerPluginDescriptor.create(CartesianProductEdgeManager.class.getName()), null, null, null, null));
        edgePropertyMap.put("v1", EdgeProperty.create(EdgeManagerPluginDescriptor.create(CartesianProductEdgeManager.class.getName()), null, null, null, null));
        edgePropertyMap.put("v2", EdgeProperty.create(EdgeProperty.DataMovementType.BROADCAST, null, null, null, null));
        Mockito.when(this.context.getInputVertexEdgeProperties()).thenReturn(edgePropertyMap);
        Mockito.when(this.context.getVertexNumTasks(Matchers.eq("v0"))).thenReturn(TASKS_PER_SOURCE);
        Mockito.when(this.context.getVertexNumTasks(Matchers.eq("v1"))).thenReturn(TASKS_PER_SOURCE);
        Mockito.when(this.context.getVertexNumTasks(Matchers.eq("v2"))).thenReturn(TASKS_PER_SOURCE);
        this.vertexManager.initialize(config);

        this.allCompletions = new ArrayList<TaskAttemptIdentifier>();
        for (int vertexIdx = 0; vertexIdx < NUM_SOURCE_VERTICES; ++vertexIdx) {
            TezVertexID vertexId = TezVertexID.getInstance(TezDAGID.getInstance("0", 0, 0), vertexIdx);
            for (int taskIdx = 0; taskIdx < TASKS_PER_SOURCE; ++taskIdx) {
                TezTaskAttemptID attemptId = TezTaskAttemptID.getInstance(TezTaskID.getInstance(vertexId, taskIdx), 0);
                this.allCompletions.add(new TaskAttemptIdentifierImpl("dag", "v" + vertexIdx, attemptId));
            }
        }
    }

    /**
     * Rebuilds the fixture with {@code config}, reports "v0" as CONFIGURED and verifies that the
     * vertex is reconfigured exactly once with the expected parallelism, a null location hint and
     * null edge properties.
     *
     * @param config      configuration under test
     * @param parallelism parallelism the vertex manager is expected to request
     */
    private void testReconfigureVertexHelper(CartesianProductUserPayload.CartesianProductConfigProto config, int parallelism) throws Exception {
        this.setupWithConfig(config);
        ArgumentCaptor<Integer> parallelismCaptor = ArgumentCaptor.forClass(Integer.class);

        this.vertexManager.onVertexStateUpdated(new VertexStateUpdate("v0", VertexState.CONFIGURED));

        Mockito.verify(this.context, Mockito.times(1)).reconfigureVertex(
            parallelismCaptor.capture(),
            Matchers.isNull(VertexLocationHint.class),
            this.edgePropertiesCaptor.capture());
        // Expected value goes first so that a failure message reports the values the right way round.
        Assert.assertEquals(parallelism, parallelismCaptor.getValue().intValue());
        Assert.assertNull(this.edgePropertiesCaptor.getValue());
    }

    /**
     * With 5 x 5 partitions, {@link TestFilter} (which keeps v0 &gt; v1) leaves 10 combinations;
     * once the filter is cleared all 25 are expected.
     */
    @Test(timeout=5000L)
    public void testReconfigureVertex() throws Exception {
        CartesianProductUserPayload.CartesianProductConfigProto.Builder builder = CartesianProductUserPayload.CartesianProductConfigProto.newBuilder();
        builder.setIsPartitioned(true).addSources("v0").addSources("v1").addNumPartitions(5).addNumPartitions(5).setFilterClassName(TestFilter.class.getName());
        this.testReconfigureVertexHelper(builder.build(), 10);
        builder.clearFilterClassName();
        this.testReconfigureVertexHelper(builder.build(), 25);
    }

    /**
     * Feeds source task completions one at a time and checks when, and with which task index,
     * {@code scheduleTasks} is invoked on the context.
     */
    @Test(timeout=5000L)
    public void testScheduling() throws Exception {
        this.vertexManager.onVertexStarted(null);
        this.vertexManager.onVertexStateUpdated(new VertexStateUpdate("v0", VertexState.CONFIGURED));
        this.vertexManager.onVertexStateUpdated(new VertexStateUpdate("v1", VertexState.CONFIGURED));

        this.vertexManager.onSourceTaskCompleted(this.allCompletions.get(0));
        this.vertexManager.onSourceTaskCompleted(this.allCompletions.get(1));
        Mockito.verify(this.context, Mockito.never()).scheduleTasks(Matchers.<List<VertexManagerPluginContext.ScheduleTaskRequest>>any());

        // Still nothing expected: the broadcast source "v2" has not been reported RUNNING yet.
        this.vertexManager.onSourceTaskCompleted(this.allCompletions.get(2));
        Mockito.verify(this.context, Mockito.never()).scheduleTasks(Matchers.<List<VertexManagerPluginContext.ScheduleTaskRequest>>any());

        this.vertexManager.onVertexStateUpdated(new VertexStateUpdate("v2", VertexState.RUNNING));
        Mockito.verify(this.context, Mockito.times(1)).scheduleTasks(this.scheduleTaskRequestCaptor.capture());
        List<VertexManagerPluginContext.ScheduleTaskRequest> scheduleTaskRequests = this.scheduleTaskRequestCaptor.getValue();
        Assert.assertEquals(1, scheduleTaskRequests.size());
        Assert.assertEquals(0, scheduleTaskRequests.get(0).getTaskIndex());

        // A completion coming from the broadcast source must not trigger another scheduling call.
        this.vertexManager.onSourceTaskCompleted(this.allCompletions.get(8));
        Mockito.verify(this.context, Mockito.times(1)).scheduleTasks(this.scheduleTaskRequestCaptor.capture());

        // Each of the next three completions is expected to schedule exactly one more task
        // (indices 1, 2, 3), bringing the cumulative number of calls to 2, 3 and 4.
        for (int i = 3; i < 6; ++i) {
            this.vertexManager.onSourceTaskCompleted(this.allCompletions.get(i));
            Mockito.verify(this.context, Mockito.times(i - 1)).scheduleTasks(this.scheduleTaskRequestCaptor.capture());
            scheduleTaskRequests = this.scheduleTaskRequestCaptor.getValue();
            Assert.assertEquals(1, scheduleTaskRequests.size());
            Assert.assertEquals(i - 2, scheduleTaskRequests.get(0).getTaskIndex());
        }

        // Further completions must not cause any additional scheduling: the count stays at 4.
        for (int i = 6; i < 8; ++i) {
            this.vertexManager.onSourceTaskCompleted(this.allCompletions.get(i));
            Mockito.verify(this.context, Mockito.times(4)).scheduleTasks(Matchers.<List<VertexManagerPluginContext.ScheduleTaskRequest>>any());
        }
    }

    /** Start-up scenario in which "v2" is reported RUNNING before the vertex starts. */
    @Test(timeout=5000L)
    public void testOnVertexStartWithBroadcastRunning() throws Exception {
        this.testOnVertexStartHelper(true);
    }

    /** Start-up scenario in which "v2" is reported RUNNING only after the vertex has started. */
    @Test(timeout=5000L)
    public void testOnVertexStartWithoutBroadcastRunning() throws Exception {
        this.testOnVertexStartHelper(false);
    }

    /**
     * Starts the vertex with a list of already-finished source tasks and checks that exactly one
     * scheduling call, carrying a single request for task 0, has been made once "v2" is RUNNING.
     *
     * @param broadcastRunning whether "v2" is reported RUNNING before {@code onVertexStarted};
     *                         when {@code false}, the test also verifies that nothing is scheduled
     *                         until that state update arrives
     */
    private void testOnVertexStartHelper(boolean broadcastRunning) throws Exception {
        this.vertexManager.onVertexStateUpdated(new VertexStateUpdate("v0", VertexState.CONFIGURED));
        this.vertexManager.onVertexStateUpdated(new VertexStateUpdate("v1", VertexState.CONFIGURED));
        if (broadcastRunning) {
            this.vertexManager.onVertexStateUpdated(new VertexStateUpdate("v2", VertexState.RUNNING));
        }

        // Completions reported at start: v0 tasks 0 and 1, v1 task 0, v2 task 0.
        List<TaskAttemptIdentifier> completions = new ArrayList<TaskAttemptIdentifier>();
        completions.add(this.allCompletions.get(0));
        completions.add(this.allCompletions.get(1));
        completions.add(this.allCompletions.get(4));
        completions.add(this.allCompletions.get(8));
        this.vertexManager.onVertexStarted(completions);

        if (!broadcastRunning) {
            Mockito.verify(this.context, Mockito.never()).scheduleTasks(Matchers.<List<VertexManagerPluginContext.ScheduleTaskRequest>>any());
            this.vertexManager.onVertexStateUpdated(new VertexStateUpdate("v2", VertexState.RUNNING));
        }

        Mockito.verify(this.context, Mockito.times(1)).scheduleTasks(this.scheduleTaskRequestCaptor.capture());
        List<VertexManagerPluginContext.ScheduleTaskRequest> scheduleTaskRequests = this.scheduleTaskRequestCaptor.getValue();
        Assert.assertEquals(1, scheduleTaskRequests.size());
        Assert.assertEquals(0, scheduleTaskRequests.get(0).getTaskIndex());
    }

    /**
     * Filter used by {@link #testReconfigureVertex()}: accepts a combination only when the value
     * mapped to "v0" is strictly greater than the value mapped to "v1".
     */
    public static class TestFilter
    extends CartesianProductFilter {
        public TestFilter(UserPayload payload) {
            super(payload);
        }

        /**
         * @param vertexPartitionMap map containing an entry for both "v0" and "v1"
         * @return {@code true} when the "v0" value exceeds the "v1" value
         */
        @Override
        public boolean isValidCombination(Map<String, Integer> vertexPartitionMap) {
            return vertexPartitionMap.get("v0") > vertexPartitionMap.get("v1");
        }
    }
}

