/*
 * Decompiled with CFR 0.152.
 */
package org.apache.flink.streaming.runtime.tasks;

import java.io.IOException;
import java.io.Serializable;
import java.util.Collections;
import java.util.concurrent.TimeUnit;
import org.apache.flink.api.common.ExecutionConfig;
import org.apache.flink.api.common.JobID;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.core.testutils.OneShotLatch;
import org.apache.flink.runtime.broadcast.BroadcastVariableManager;
import org.apache.flink.runtime.execution.ExecutionState;
import org.apache.flink.runtime.execution.librarycache.FallbackLibraryCacheManager;
import org.apache.flink.runtime.execution.librarycache.LibraryCacheManager;
import org.apache.flink.runtime.executiongraph.ExecutionAttemptID;
import org.apache.flink.runtime.executiongraph.JobInformation;
import org.apache.flink.runtime.executiongraph.TaskInformation;
import org.apache.flink.runtime.filecache.FileCache;
import org.apache.flink.runtime.instance.ActorGateway;
import org.apache.flink.runtime.io.disk.iomanager.IOManager;
import org.apache.flink.runtime.io.network.NetworkEnvironment;
import org.apache.flink.runtime.jobgraph.JobVertexID;
import org.apache.flink.runtime.memory.MemoryManager;
import org.apache.flink.runtime.metrics.groups.TaskMetricGroup;
import org.apache.flink.runtime.operators.testutils.UnregisteredTaskMetricsGroup;
import org.apache.flink.runtime.state.StateHandle;
import org.apache.flink.runtime.taskmanager.Task;
import org.apache.flink.runtime.taskmanager.TaskManagerRuntimeInfo;
import org.apache.flink.runtime.util.EnvironmentInformation;
import org.apache.flink.runtime.util.SerializableObject;
import org.apache.flink.streaming.api.TimeCharacteristic;
import org.apache.flink.streaming.api.checkpoint.Checkpointed;
import org.apache.flink.streaming.api.functions.source.SourceFunction;
import org.apache.flink.streaming.api.graph.StreamConfig;
import org.apache.flink.streaming.api.operators.StreamOperator;
import org.apache.flink.streaming.api.operators.StreamSource;
import org.apache.flink.streaming.runtime.tasks.SourceStreamTask;
import org.apache.flink.streaming.runtime.tasks.StreamTaskState;
import org.apache.flink.streaming.runtime.tasks.StreamTaskStateList;
import org.apache.flink.util.SerializedValue;
import org.junit.Assert;
import org.junit.Test;
import org.mockito.Mockito;
import scala.concurrent.duration.FiniteDuration;

/**
 * Tests that a stream task whose state restore blocks in an interrupt-insensitive way can still be
 * canceled.
 *
 * <p>The restore is simulated by {@link InterruptLockingStateHandle}, whose {@code getState}
 * keeps blocking after being interrupted until the handle is closed. The test never closes the
 * handle itself, so it only passes if canceling the task leads to the handle being closed in
 * addition to the task thread being interrupted.
 */
public class InterruptSensitiveRestoreTest {

    /** Triggered once the task thread has entered the (blocking) state restore. */
    private static final OneShotLatch IN_RESTORE_LATCH = new OneShotLatch();

    /**
     * Starts a source task whose restore blocks, waits until the task thread is inside the
     * restore, cancels the task and verifies that it ends up {@code CANCELED} without a failure
     * cause.
     */
    @Test
    public void testRestoreWithInterrupt() throws Exception {
        Configuration taskConfig = new Configuration();
        StreamConfig cfg = new StreamConfig(taskConfig);
        cfg.setTimeCharacteristic(TimeCharacteristic.ProcessingTime);
        cfg.setStreamOperator(new StreamSource<>(new TestSource()));

        // the function state of the single operator is the handle that blocks during restore
        InterruptLockingStateHandle lockingHandle = new InterruptLockingStateHandle();
        StreamTaskState opState = new StreamTaskState();
        opState.setFunctionState(lockingHandle);
        StreamTaskStateList taskState = new StreamTaskStateList(new StreamTaskState[] {opState});

        Task task = createTask(taskConfig, taskState);
        task.startTaskThread();

        // wait until the task thread is blocked inside the restore, then cancel the task
        IN_RESTORE_LATCH.await();
        task.cancelExecution();

        // bounded wait, so that a stuck task fails the test below instead of hanging it
        task.getExecutingThread().join(30000L);

        if (task.getExecutionState() == ExecutionState.CANCELING) {
            Assert.fail("Task is stuck and not canceling");
        }
        Assert.assertEquals(ExecutionState.CANCELED, task.getExecutionState());
        Assert.assertNull(task.getFailureCause());
    }

    /**
     * Creates a {@link SourceStreamTask} with mocked runtime services that will restore the
     * given state.
     *
     * @param taskConfig the task configuration (carrying the {@link StreamConfig})
     * @param state the state handed to the task for restore
     * @return the created, not yet started task
     * @throws IOException if the execution config or the state cannot be serialized
     */
    private static Task createTask(Configuration taskConfig, StateHandle<?> state) throws IOException {
        JobInformation jobInformation = new JobInformation(
                new JobID(),
                "test job name",
                new SerializedValue<>(new ExecutionConfig()),
                new Configuration(),
                Collections.emptyList(),
                Collections.emptyList());

        TaskInformation taskInformation = new TaskInformation(
                new JobVertexID(),
                "test task name",
                1,
                SourceStreamTask.class.getName(),
                taskConfig);

        return new Task(
                jobInformation,
                taskInformation,
                new ExecutionAttemptID(),
                0,
                0,
                Collections.emptyList(),
                Collections.emptyList(),
                0,
                new SerializedValue<StateHandle<?>>(state),
                Mockito.mock(MemoryManager.class),
                Mockito.mock(IOManager.class),
                Mockito.mock(NetworkEnvironment.class),
                Mockito.mock(BroadcastVariableManager.class),
                Mockito.mock(ActorGateway.class),
                Mockito.mock(ActorGateway.class),
                new FiniteDuration(10L, TimeUnit.SECONDS),
                new FallbackLibraryCacheManager(),
                new FileCache(new Configuration()),
                new TaskManagerRuntimeInfo(
                        "localhost", new Configuration(), EnvironmentInformation.getTemporaryFileDirectory()),
                new UnregisteredTaskMetricsGroup());
    }

    // ------------------------------------------------------------------------

    /**
     * Checkpointed source whose methods must never run. The task is canceled while its state is
     * still being restored, so reaching any of these methods fails the test.
     */
    private static class TestSource implements SourceFunction<Object>, Checkpointed<Serializable> {
        private static final long serialVersionUID = 1L;

        @Override
        public void run(SourceFunction.SourceContext<Object> ctx) throws Exception {
            Assert.fail("should never be called");
        }

        @Override
        public void cancel() {
        }

        @Override
        public Serializable snapshotState(long checkpointId, long checkpointTimestamp) throws Exception {
            Assert.fail("should never be called");
            return null;
        }

        @Override
        public void restoreState(Serializable state) throws Exception {
            Assert.fail("should never be called");
        }
    }

    /**
     * State handle whose {@link #getState} blocks and ignores interrupts, as some blocking client
     * code does. It returns only after the restoring thread has been interrupted at least once AND
     * the handle has been {@link #close() closed}.
     */
    private static class InterruptLockingStateHandle implements StateHandle<Serializable> {
        private static final long serialVersionUID = 1L;

        // transient: not serialized, so a deserialized copy starts out not closed;
        // all accesses also happen under this object's monitor
        private volatile transient boolean closed;

        /**
         * Triggers {@code IN_RESTORE_LATCH}. It then blocks until the calling thread has been
         * interrupted at least once and this handle has been closed. Further interrupts are
         * swallowed, and the thread's interrupt status is not restored.
         *
         * @param userCodeClassLoader ignored
         * @return a dummy state object once unblocked
         */
        @Override
        public Serializable getState(ClassLoader userCodeClassLoader) {
            IN_RESTORE_LATCH.trigger();

            boolean interrupted = false;
            synchronized (this) {
                // re-check the condition after every wake-up: wait() may also return spuriously
                while (!interrupted || !closed) {
                    try {
                        wait();
                    } catch (InterruptedException e) {
                        // deliberately swallowed so that only closing the handle releases this thread
                        interrupted = true;
                    }
                }
            }
            return new SerializableObject();
        }

        /** Nothing to discard. */
        @Override
        public void discardState() throws Exception {
        }

        /** The handle holds no real state. */
        @Override
        public long getStateSize() throws Exception {
            return 0L;
        }

        /**
         * Marks this handle as closed and wakes up any thread waiting in {@link #getState}.
         * A woken thread then re-checks whether it may return.
         */
        @Override
        public void close() throws IOException {
            synchronized (this) {
                closed = true;
                notifyAll();
            }
        }
    }
}

