package org.apache.flink.streaming.runtime.tasks;

import java.io.EOFException;
import java.io.IOException;
import java.io.Serializable;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.Executor;
import org.apache.flink.api.common.ExecutionConfig;
import org.apache.flink.api.common.JobID;
import org.apache.flink.api.common.typeutils.base.IntSerializer;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.core.fs.FSDataInputStream;
import org.apache.flink.core.testutils.OneShotLatch;
import org.apache.flink.runtime.broadcast.BroadcastVariableManager;
import org.apache.flink.runtime.clusterframework.types.AllocationID;
import org.apache.flink.runtime.execution.ExecutionState;
import org.apache.flink.runtime.execution.librarycache.FallbackLibraryCacheManager;
import org.apache.flink.runtime.executiongraph.ExecutionAttemptID;
import org.apache.flink.runtime.executiongraph.JobInformation;
import org.apache.flink.runtime.executiongraph.TaskInformation;
import org.apache.flink.runtime.filecache.FileCache;
import org.apache.flink.runtime.io.disk.iomanager.IOManager;
import org.apache.flink.runtime.io.network.NetworkEnvironment;
import org.apache.flink.runtime.io.network.netty.PartitionProducerStateChecker;
import org.apache.flink.runtime.io.network.partition.ResultPartitionConsumableNotifier;
import org.apache.flink.runtime.jobgraph.JobVertexID;
import org.apache.flink.runtime.jobgraph.tasks.InputSplitProvider;
import org.apache.flink.runtime.memory.MemoryManager;
import org.apache.flink.runtime.operators.testutils.UnregisteredTaskMetricsGroup;
import org.apache.flink.runtime.query.TaskKvStateRegistry;
import org.apache.flink.runtime.state.ChainedStateHandle;
import org.apache.flink.runtime.state.FunctionInitializationContext;
import org.apache.flink.runtime.state.FunctionSnapshotContext;
import org.apache.flink.runtime.state.KeyGroupRange;
import org.apache.flink.runtime.state.KeyGroupRangeOffsets;
import org.apache.flink.runtime.state.KeyGroupsStateHandle;
import org.apache.flink.runtime.state.OperatorStateHandle;
import org.apache.flink.runtime.state.StateInitializationContext;
import org.apache.flink.runtime.state.StatePartitionStreamProvider;
import org.apache.flink.runtime.state.StreamStateHandle;
import org.apache.flink.runtime.state.TaskStateHandles;
import org.apache.flink.runtime.taskmanager.CheckpointResponder;
import org.apache.flink.runtime.taskmanager.Task;
import org.apache.flink.runtime.taskmanager.TaskManagerActions;
import org.apache.flink.runtime.util.EnvironmentInformation;
import org.apache.flink.runtime.util.TestingTaskManagerRuntimeInfo;
import org.apache.flink.streaming.api.TimeCharacteristic;
import org.apache.flink.streaming.api.checkpoint.Checkpointed;
import org.apache.flink.streaming.api.checkpoint.CheckpointedFunction;
import org.apache.flink.streaming.api.functions.source.SourceFunction;
import org.apache.flink.streaming.api.graph.StreamConfig;
import org.apache.flink.streaming.api.operators.StreamSource;
import org.apache.flink.util.SerializedValue;
import org.junit.Assert;
import org.junit.Test;
import org.mockito.Mockito;

/* loaded from: input_file:org/apache/flink/streaming/runtime/tasks/InterruptSensitiveRestoreTest.class */
public class InterruptSensitiveRestoreTest {
    private static final OneShotLatch IN_RESTORE_LATCH = new OneShotLatch();
    private static final int OPERATOR_MANAGED = 0;
    private static final int OPERATOR_RAW = 1;
    private static final int KEYED_MANAGED = 2;
    private static final int KEYED_RAW = 3;
    private static final int LEGACY = 4;

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:org/apache/flink/streaming/runtime/tasks/InterruptSensitiveRestoreTest$InterruptLockingStateHandle.class */
    public static class InterruptLockingStateHandle implements StreamStateHandle {
        private static final long serialVersionUID = 1;
        private volatile boolean closed;

        private InterruptLockingStateHandle() {
        }

        public FSDataInputStream openInputStream() throws IOException {
            this.closed = false;
            return new FSDataInputStream() { // from class: org.apache.flink.streaming.runtime.tasks.InterruptSensitiveRestoreTest.InterruptLockingStateHandle.1
                public void seek(long j) throws IOException {
                }

                public long getPos() throws IOException {
                    return 0L;
                }

                public int read() throws IOException {
                    InterruptLockingStateHandle.this.block();
                    throw new EOFException();
                }

                public void close() throws IOException {
                    super.close();
                    InterruptLockingStateHandle.this.closed = true;
                }
            };
        }

        /* JADX INFO: Access modifiers changed from: private */
        public void block() {
            InterruptSensitiveRestoreTest.IN_RESTORE_LATCH.trigger();
            try {
                synchronized (this) {
                    wait();
                }
            } catch (InterruptedException e) {
                while (!this.closed) {
                    synchronized (this) {
                        wait();
                    }
                }
            }
        }

        public void discardState() throws Exception {
        }

        public long getStateSize() {
            return 0L;
        }
    }

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:org/apache/flink/streaming/runtime/tasks/InterruptSensitiveRestoreTest$TestSource.class */
    public static class TestSource implements SourceFunction<Object>, CheckpointedFunction {
        private static final long serialVersionUID = 1;

        private TestSource() {
        }

        public void run(SourceFunction.SourceContext<Object> sourceContext) throws Exception {
            Assert.fail("should never be called");
        }

        public void cancel() {
        }

        public void snapshotState(FunctionSnapshotContext functionSnapshotContext) throws Exception {
            Assert.fail("should never be called");
        }

        public void initializeState(FunctionInitializationContext functionInitializationContext) throws Exception {
            ((StatePartitionStreamProvider) ((StateInitializationContext) functionInitializationContext).getRawOperatorStateInputs().iterator().next()).getStream().read();
        }
    }

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:org/apache/flink/streaming/runtime/tasks/InterruptSensitiveRestoreTest$TestSourceLegacy.class */
    public static class TestSourceLegacy implements SourceFunction<Object>, Checkpointed<Serializable> {
        private static final long serialVersionUID = 1;

        private TestSourceLegacy() {
        }

        public void run(SourceFunction.SourceContext<Object> sourceContext) throws Exception {
            Assert.fail("should never be called");
        }

        public void cancel() {
        }

        public Serializable snapshotState(long j, long j2) throws Exception {
            Assert.fail("should never be called");
            return null;
        }

        public void restoreState(Serializable serializable) throws Exception {
            Assert.fail("should never be called");
        }
    }

    @Test
    public void testRestoreWithInterruptLegacy() throws Exception {
        testRestoreWithInterrupt(LEGACY);
    }

    @Test
    public void testRestoreWithInterruptOperatorManaged() throws Exception {
        testRestoreWithInterrupt(OPERATOR_MANAGED);
    }

    @Test
    public void testRestoreWithInterruptOperatorRaw() throws Exception {
        testRestoreWithInterrupt(OPERATOR_RAW);
    }

    @Test
    public void testRestoreWithInterruptKeyedManaged() throws Exception {
        testRestoreWithInterrupt(KEYED_MANAGED);
    }

    @Test
    public void testRestoreWithInterruptKeyedRaw() throws Exception {
        testRestoreWithInterrupt(KEYED_RAW);
    }

    private void testRestoreWithInterrupt(int i) throws Exception {
        IN_RESTORE_LATCH.reset();
        Configuration configuration = new Configuration();
        StreamConfig streamConfig = new StreamConfig(configuration);
        streamConfig.setTimeCharacteristic(TimeCharacteristic.ProcessingTime);
        switch (i) {
            case OPERATOR_MANAGED /* 0 */:
            case OPERATOR_RAW /* 1 */:
            case KEYED_MANAGED /* 2 */:
            case KEYED_RAW /* 3 */:
                streamConfig.setStateKeySerializer(IntSerializer.INSTANCE);
                streamConfig.setStreamOperator(new StreamSource(new TestSource()));
                break;
            case LEGACY /* 4 */:
                streamConfig.setStreamOperator(new StreamSource(new TestSourceLegacy()));
                break;
            default:
                throw new IllegalArgumentException();
        }
        Task createTask = createTask(configuration, new InterruptLockingStateHandle(), i);
        createTask.startTaskThread();
        IN_RESTORE_LATCH.await();
        createTask.cancelExecution();
        createTask.getExecutingThread().join(30000L);
        if (createTask.getExecutionState() == ExecutionState.CANCELING) {
            Assert.fail("Task is stuck and not canceling");
        }
        Assert.assertEquals(ExecutionState.CANCELED, createTask.getExecutionState());
        Assert.assertNull(createTask.getFailureCause());
    }

    private static Task createTask(Configuration configuration, StreamStateHandle streamStateHandle, int i) throws IOException {
        NetworkEnvironment networkEnvironment = (NetworkEnvironment) Mockito.mock(NetworkEnvironment.class);
        Mockito.when(networkEnvironment.createKvStateTaskRegistry((JobID) Mockito.any(JobID.class), (JobVertexID) Mockito.any(JobVertexID.class))).thenReturn(Mockito.mock(TaskKvStateRegistry.class));
        ChainedStateHandle chainedStateHandle = OPERATOR_MANAGED;
        List emptyList = Collections.emptyList();
        List emptyList2 = Collections.emptyList();
        List emptyList3 = Collections.emptyList();
        List emptyList4 = Collections.emptyList();
        HashMap hashMap = new HashMap(OPERATOR_RAW);
        hashMap.put("_default_", new OperatorStateHandle.StateMetaInfo(new long[]{0}, OperatorStateHandle.Mode.SPLIT_DISTRIBUTE));
        KeyGroupRangeOffsets keyGroupRangeOffsets = new KeyGroupRangeOffsets(new KeyGroupRange(OPERATOR_MANAGED, OPERATOR_MANAGED));
        List singletonList = Collections.singletonList(new OperatorStateHandle(hashMap, streamStateHandle));
        List singletonList2 = Collections.singletonList(new KeyGroupsStateHandle(keyGroupRangeOffsets, streamStateHandle));
        switch (i) {
            case OPERATOR_MANAGED /* 0 */:
                emptyList3 = Collections.singletonList(singletonList);
                break;
            case OPERATOR_RAW /* 1 */:
                emptyList4 = Collections.singletonList(singletonList);
                break;
            case KEYED_MANAGED /* 2 */:
                emptyList = singletonList2;
                break;
            case KEYED_RAW /* 3 */:
                emptyList2 = singletonList2;
                break;
            case LEGACY /* 4 */:
                chainedStateHandle = new ChainedStateHandle(Collections.singletonList(streamStateHandle));
                break;
            default:
                throw new IllegalArgumentException();
        }
        return new Task(new JobInformation(new JobID(), "test job name", new SerializedValue(new ExecutionConfig()), new Configuration(), Collections.emptyList(), Collections.emptyList()), new TaskInformation(new JobVertexID(), "test task name", OPERATOR_RAW, OPERATOR_RAW, SourceStreamTask.class.getName(), configuration), new ExecutionAttemptID(), new AllocationID(), OPERATOR_MANAGED, OPERATOR_MANAGED, Collections.emptyList(), Collections.emptyList(), OPERATOR_MANAGED, new TaskStateHandles(chainedStateHandle, emptyList3, emptyList4, emptyList, emptyList2), (MemoryManager) Mockito.mock(MemoryManager.class), (IOManager) Mockito.mock(IOManager.class), networkEnvironment, (BroadcastVariableManager) Mockito.mock(BroadcastVariableManager.class), (TaskManagerActions) Mockito.mock(TaskManagerActions.class), (InputSplitProvider) Mockito.mock(InputSplitProvider.class), (CheckpointResponder) Mockito.mock(CheckpointResponder.class), new FallbackLibraryCacheManager(), new FileCache(new String[]{EnvironmentInformation.getTemporaryFileDirectory()}), new TestingTaskManagerRuntimeInfo(), new UnregisteredTaskMetricsGroup(), (ResultPartitionConsumableNotifier) Mockito.mock(ResultPartitionConsumableNotifier.class), (PartitionProducerStateChecker) Mockito.mock(PartitionProducerStateChecker.class), (Executor) Mockito.mock(Executor.class));
    }
}
