package org.apache.flink.runtime.executiongraph;

import java.lang.reflect.Field;
import java.net.InetAddress;
import java.util.Iterator;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.TimeUnit;
import org.apache.flink.api.common.JobID;
import org.apache.flink.api.common.time.Time;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.metrics.groups.UnregisteredMetricsGroup;
import org.apache.flink.runtime.blob.VoidBlobWriter;
import org.apache.flink.runtime.checkpoint.StandaloneCheckpointRecoveryFactory;
import org.apache.flink.runtime.checkpoint.TaskStateSnapshot;
import org.apache.flink.runtime.clusterframework.types.AllocationID;
import org.apache.flink.runtime.clusterframework.types.ResourceID;
import org.apache.flink.runtime.clusterframework.types.ResourceProfile;
import org.apache.flink.runtime.execution.ExecutionState;
import org.apache.flink.runtime.executiongraph.restart.FixedDelayRestartStrategy;
import org.apache.flink.runtime.instance.SimpleSlot;
import org.apache.flink.runtime.instance.SlotProvider;
import org.apache.flink.runtime.io.network.partition.ResultPartitionType;
import org.apache.flink.runtime.jobgraph.DistributionPattern;
import org.apache.flink.runtime.jobgraph.JobGraph;
import org.apache.flink.runtime.jobgraph.JobVertex;
import org.apache.flink.runtime.jobgraph.JobVertexID;
import org.apache.flink.runtime.jobmanager.slots.AllocatedSlot;
import org.apache.flink.runtime.jobmanager.slots.SlotOwner;
import org.apache.flink.runtime.jobmanager.slots.TaskManagerGateway;
import org.apache.flink.runtime.taskmanager.TaskManagerLocation;
import org.apache.flink.runtime.testingUtils.TestingUtils;
import org.apache.flink.runtime.testtasks.NoOpInvokable;
import org.apache.flink.util.FlinkException;
import org.apache.flink.util.TestLogger;
import org.junit.Assert;
import org.junit.Test;
import org.mockito.Mockito;

/**
 * Tests for the preferred locations reported by {@link ExecutionVertex#getPreferredLocations()}:
 * input-based locality for a POINTWISE connection, the absence of input-based locality for a
 * large ALL_TO_ALL connection, and state-based locality after the graph has been reset and the
 * consuming tasks carry restored state.
 */
public class ExecutionVertexLocalityTest extends TestLogger {

    private final JobID jobId = new JobID();

    private final JobVertexID sourceVertexId = new JobVertexID();

    private final JobVertexID targetVertexId = new JobVertexID();

    /**
     * With a POINTWISE connection and equal parallelism on both sides, each target subtask must
     * prefer exactly the location of its corresponding source subtask.
     */
    @Test
    public void testLocalityInputBasedForward() throws Exception {
        final int parallelism = 10;
        final TaskManagerLocation[] sourceLocations = new TaskManagerLocation[parallelism];

        final ExecutionGraph graph = createTestGraph(parallelism, false);
        final ExecutionVertex[] sources = getTaskVertices(graph, sourceVertexId);
        final ExecutionVertex[] targets = getTaskVertices(graph, targetVertexId);

        // Place every source subtask on its own TaskManager (distinct data ports).
        for (int i = 0; i < parallelism; i++) {
            sourceLocations[i] = createLocation(10000 + i);
            initializeLocation(sources[i], sourceLocations[i]);
        }

        for (int i = 0; i < parallelism; i++) {
            assertSinglePreferredLocation(targets[i], sourceLocations[i]);
        }
    }

    /**
     * With an ALL_TO_ALL connection over 100 producers, no target subtask may report any
     * input-based preferred location.
     */
    @Test
    public void testNoLocalityInputLargeAllToAll() throws Exception {
        final int parallelism = 100;

        final ExecutionGraph graph = createTestGraph(parallelism, true);
        final ExecutionVertex[] sources = getTaskVertices(graph, sourceVertexId);
        final ExecutionVertex[] targets = getTaskVertices(graph, targetVertexId);

        for (int i = 0; i < parallelism; i++) {
            initializeLocation(sources[i], createLocation(10000 + i));
        }

        for (ExecutionVertex target : targets) {
            Assert.assertFalse(target.getPreferredLocations().iterator().hasNext());
        }
    }

    /**
     * After a first execution is canceled and the graph is reset, target subtasks that carry an
     * initial state must prefer the location of their prior execution, even though their inputs
     * have moved to new locations.
     */
    @Test
    public void testLocalityBasedOnState() throws Exception {
        final int parallelism = 10;
        final TaskManagerLocation[] priorTargetLocations = new TaskManagerLocation[parallelism];

        final ExecutionGraph graph = createTestGraph(parallelism, false);
        final ExecutionVertex[] sources = getTaskVertices(graph, sourceVertexId);
        final ExecutionVertex[] targets = getTaskVertices(graph, targetVertexId);

        // First execution: deploy sources and targets to distinct locations, then mark them canceled.
        for (int i = 0; i < parallelism; i++) {
            final TaskManagerLocation sourceLocation = createLocation(10000 + i);
            final TaskManagerLocation targetLocation = createLocation(20000 + i);
            priorTargetLocations[i] = targetLocation;

            initializeLocation(sources[i], sourceLocation);
            initializeLocation(targets[i], targetLocation);

            // Forced via reflection so that the vertices can be reset below.
            setState(sources[i].getCurrentExecutionAttempt(), ExecutionState.CANCELED);
            setState(targets[i].getCurrentExecutionAttempt(), ExecutionState.CANCELED);
        }

        for (ExecutionJobVertex jobVertex : graph.getVerticesTopologically()) {
            jobVertex.resetForNewExecution(System.currentTimeMillis(), graph.getGlobalModVersion());
        }

        // Second execution: sources move to new locations, targets receive (mocked) restored state.
        for (int i = 0; i < parallelism; i++) {
            initializeLocation(sources[i], createLocation(30000 + i));

            final TaskStateSnapshot restoredState = Mockito.mock(TaskStateSnapshot.class);
            targets[i].getCurrentExecutionAttempt().setInitialState(restoredState);
        }

        // The prior (state) location must win over the new input locations.
        for (int i = 0; i < parallelism; i++) {
            assertSinglePreferredLocation(targets[i], priorTargetLocations[i]);
        }
    }

    /**
     * Builds a two-vertex graph, "source" -> "target", where both vertices run a
     * {@link NoOpInvokable} with the same parallelism.
     *
     * @param parallelism parallelism of both the source and the target vertex
     * @param allToAll {@code true} for an ALL_TO_ALL connection, {@code false} for POINTWISE
     * @return the execution graph built from that job graph
     * @throws Exception if the graph cannot be built
     */
    private ExecutionGraph createTestGraph(int parallelism, boolean allToAll) throws Exception {
        final JobVertex source = new JobVertex("source", sourceVertexId);
        source.setParallelism(parallelism);
        source.setInvokableClass(NoOpInvokable.class);

        // Previously this vertex was also (mis)named "source".
        final JobVertex target = new JobVertex("target", targetVertexId);
        target.setParallelism(parallelism);
        target.setInvokableClass(NoOpInvokable.class);

        final DistributionPattern pattern = allToAll ? DistributionPattern.ALL_TO_ALL : DistributionPattern.POINTWISE;
        target.connectNewDataSetAsInput(source, pattern, ResultPartitionType.PIPELINED);

        final JobGraph jobGraph = new JobGraph(jobId, "test job", new JobVertex[]{source, target});
        final SlotProvider slotProvider = Mockito.mock(SlotProvider.class);

        return ExecutionGraphBuilder.buildGraph(
            (ExecutionGraph) null,
            jobGraph,
            new Configuration(),
            TestingUtils.defaultExecutor(),
            TestingUtils.defaultExecutor(),
            slotProvider,
            getClass().getClassLoader(),
            new StandaloneCheckpointRecoveryFactory(),
            Time.of(10L, TimeUnit.SECONDS),
            new FixedDelayRestartStrategy(10, 0L),
            new UnregisteredMetricsGroup(),
            1,
            VoidBlobWriter.getInstance(),
            log);
    }

    /** Returns the parallel subtask vertices of the job vertex with the given id. */
    private static ExecutionVertex[] getTaskVertices(ExecutionGraph graph, JobVertexID vertexId) {
        final ExecutionJobVertex jobVertex = graph.getAllVertices().get(vertexId);
        return jobVertex.getTaskVertices();
    }

    /** Creates a fresh loopback TaskManager location with a random resource id and the given data port. */
    private static TaskManagerLocation createLocation(int dataPort) {
        return new TaskManagerLocation(ResourceID.generate(), InetAddress.getLoopbackAddress(), dataPort);
    }

    /**
     * Asserts that the vertex reports exactly one preferred location and that it completes to
     * {@code expected}.
     */
    private static void assertSinglePreferredLocation(ExecutionVertex vertex, TaskManagerLocation expected) throws Exception {
        final Iterator<? extends CompletableFuture<?>> preferred = vertex.getPreferredLocations().iterator();
        Assert.assertTrue(preferred.hasNext());
        Assert.assertEquals(expected, preferred.next().get());
        Assert.assertFalse(preferred.hasNext());
    }

    /**
     * Assigns a {@link SimpleSlot} at the given location (backed by mocked gateway and owner) to
     * the vertex's current execution attempt.
     *
     * @throws FlinkException if the execution attempt rejects the slot
     */
    private void initializeLocation(ExecutionVertex vertex, TaskManagerLocation location) throws Exception {
        final TaskManagerGateway gateway = Mockito.mock(TaskManagerGateway.class);
        final AllocatedSlot allocatedSlot = new AllocatedSlot(new AllocationID(), jobId, location, 0, ResourceProfile.UNKNOWN, gateway);

        final SlotOwner slotOwner = Mockito.mock(SlotOwner.class);
        final SimpleSlot slot = new SimpleSlot(allocatedSlot, slotOwner, 0);

        if (!vertex.getCurrentExecutionAttempt().tryAssignResource(slot)) {
            throw new FlinkException("Could not assign resource.");
        }
    }

    /**
     * Overwrites the private {@code state} field of the given {@link Execution} through
     * reflection, without going through any of {@code Execution}'s own methods.
     */
    private static void setState(Execution execution, ExecutionState state) throws Exception {
        final Field stateField = Execution.class.getDeclaredField("state");
        stateField.setAccessible(true);
        stateField.set(execution, state);
    }
}
