/*
 * Decompiled with CFR 0.152.
 */
package org.apache.flink.streaming.runtime.tasks;

import java.util.HashMap;
import java.util.concurrent.Executor;
import java.util.concurrent.Future;
import java.util.concurrent.RunnableFuture;
import java.util.concurrent.atomic.AtomicBoolean;
import javax.annotation.Nonnegative;
import javax.annotation.Nullable;
import org.apache.flink.api.common.ExecutionConfig;
import org.apache.flink.api.common.JobID;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.runtime.checkpoint.CheckpointMetaData;
import org.apache.flink.runtime.checkpoint.CheckpointMetrics;
import org.apache.flink.runtime.checkpoint.CheckpointMetricsBuilder;
import org.apache.flink.runtime.checkpoint.OperatorSubtaskState;
import org.apache.flink.runtime.checkpoint.StateObjectCollection;
import org.apache.flink.runtime.checkpoint.TaskStateSnapshot;
import org.apache.flink.runtime.clusterframework.types.AllocationID;
import org.apache.flink.runtime.executiongraph.ExecutionAttemptID;
import org.apache.flink.runtime.jobgraph.JobVertexID;
import org.apache.flink.runtime.jobgraph.OperatorID;
import org.apache.flink.runtime.operators.testutils.MockInputSplitProvider;
import org.apache.flink.runtime.state.DoneFuture;
import org.apache.flink.runtime.state.InputChannelStateHandle;
import org.apache.flink.runtime.state.KeyedStateHandle;
import org.apache.flink.runtime.state.LocalRecoveryConfig;
import org.apache.flink.runtime.state.LocalRecoveryDirectoryProvider;
import org.apache.flink.runtime.state.LocalRecoveryDirectoryProviderImpl;
import org.apache.flink.runtime.state.OperatorStateHandle;
import org.apache.flink.runtime.state.ResultSubpartitionStateHandle;
import org.apache.flink.runtime.state.SnapshotResult;
import org.apache.flink.runtime.state.StateObject;
import org.apache.flink.runtime.state.TaskLocalStateStore;
import org.apache.flink.runtime.state.TaskLocalStateStoreImpl;
import org.apache.flink.runtime.state.TaskStateManager;
import org.apache.flink.runtime.state.TaskStateManagerImpl;
import org.apache.flink.runtime.state.TestTaskStateManager;
import org.apache.flink.runtime.state.changelog.StateChangelogStorage;
import org.apache.flink.runtime.state.changelog.inmemory.InMemoryStateChangelogStorage;
import org.apache.flink.runtime.taskmanager.CheckpointResponder;
import org.apache.flink.runtime.taskmanager.TestCheckpointResponder;
import org.apache.flink.streaming.api.operators.OperatorSnapshotFutures;
import org.apache.flink.streaming.runtime.tasks.AsyncCheckpointRunnable;
import org.apache.flink.streaming.runtime.tasks.StreamMockEnvironment;
import org.apache.flink.streaming.runtime.tasks.StreamTaskTest;
import org.apache.flink.util.TestLogger;
import org.apache.flink.util.concurrent.Executors;
import org.junit.Assert;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.TemporaryFolder;
import org.mockito.Mockito;

public class LocalStateForwardingTest
extends TestLogger {
    @Rule
    public TemporaryFolder temporaryFolder = new TemporaryFolder();

    @Test
    public void testReportingFromSnapshotToTaskStateManager() throws Exception {
        TestTaskStateManager taskStateManager = new TestTaskStateManager();
        StreamMockEnvironment streamMockEnvironment = new StreamMockEnvironment(new Configuration(), new Configuration(), new ExecutionConfig(), 0x100000L, new MockInputSplitProvider(), 0, (TaskStateManager)taskStateManager);
        StreamTaskTest.NoOpStreamTask testStreamTask = new StreamTaskTest.NoOpStreamTask(streamMockEnvironment);
        CheckpointMetaData checkpointMetaData = new CheckpointMetaData(0L, 0L);
        CheckpointMetricsBuilder checkpointMetrics = new CheckpointMetricsBuilder();
        HashMap<OperatorID, OperatorSnapshotFutures> snapshots = new HashMap<OperatorID, OperatorSnapshotFutures>(1);
        OperatorSnapshotFutures osFuture = new OperatorSnapshotFutures();
        osFuture.setKeyedStateManagedFuture(LocalStateForwardingTest.createSnapshotResult(KeyedStateHandle.class));
        osFuture.setKeyedStateRawFuture(LocalStateForwardingTest.createSnapshotResult(KeyedStateHandle.class));
        osFuture.setOperatorStateManagedFuture(LocalStateForwardingTest.createSnapshotResult(OperatorStateHandle.class));
        osFuture.setOperatorStateRawFuture(LocalStateForwardingTest.createSnapshotResult(OperatorStateHandle.class));
        osFuture.setInputChannelStateFuture(LocalStateForwardingTest.createSnapshotCollectionResult(InputChannelStateHandle.class));
        osFuture.setResultSubpartitionStateFuture(LocalStateForwardingTest.createSnapshotCollectionResult(ResultSubpartitionStateHandle.class));
        OperatorID operatorID = new OperatorID();
        snapshots.put(operatorID, osFuture);
        AsyncCheckpointRunnable checkpointRunnable = new AsyncCheckpointRunnable(snapshots, checkpointMetaData, checkpointMetrics, 0L, testStreamTask.getName(), asyncCheckpointRunnable -> {}, testStreamTask.getEnvironment(), testStreamTask, false, false, () -> true);
        checkpointMetrics.setAlignmentDurationNanos(0L);
        checkpointMetrics.setBytesProcessedDuringAlignment(0L);
        checkpointRunnable.run();
        TaskStateSnapshot lastJobManagerTaskStateSnapshot = taskStateManager.getLastJobManagerTaskStateSnapshot();
        TaskStateSnapshot lastTaskManagerTaskStateSnapshot = taskStateManager.getLastTaskManagerTaskStateSnapshot();
        OperatorSubtaskState jmState = lastJobManagerTaskStateSnapshot.getSubtaskStateByOperatorID(operatorID);
        OperatorSubtaskState tmState = lastTaskManagerTaskStateSnapshot.getSubtaskStateByOperatorID(operatorID);
        LocalStateForwardingTest.performCheck(osFuture.getKeyedStateManagedFuture(), jmState.getManagedKeyedState(), tmState.getManagedKeyedState());
        LocalStateForwardingTest.performCheck(osFuture.getKeyedStateRawFuture(), jmState.getRawKeyedState(), tmState.getRawKeyedState());
        LocalStateForwardingTest.performCheck(osFuture.getOperatorStateManagedFuture(), jmState.getManagedOperatorState(), tmState.getManagedOperatorState());
        LocalStateForwardingTest.performCheck(osFuture.getOperatorStateRawFuture(), jmState.getRawOperatorState(), tmState.getRawOperatorState());
        LocalStateForwardingTest.performCollectionCheck(osFuture.getInputChannelStateFuture(), jmState.getInputChannelState(), tmState.getInputChannelState());
        LocalStateForwardingTest.performCollectionCheck(osFuture.getResultSubpartitionStateFuture(), jmState.getResultSubpartitionState(), tmState.getResultSubpartitionState());
    }

    @Test
    public void testReportingFromTaskStateManagerToResponderAndTaskLocalStateStore() throws Exception {
        final JobID jobID = new JobID();
        AllocationID allocationID = new AllocationID();
        final ExecutionAttemptID executionAttemptID = new ExecutionAttemptID();
        final CheckpointMetaData checkpointMetaData = new CheckpointMetaData(42L, 4711L);
        final CheckpointMetrics checkpointMetrics = new CheckpointMetrics();
        int subtaskIdx = 42;
        JobVertexID jobVertexID = new JobVertexID();
        TaskStateSnapshot jmSnapshot = new TaskStateSnapshot();
        final TaskStateSnapshot tmSnapshot = new TaskStateSnapshot();
        final AtomicBoolean jmReported = new AtomicBoolean(false);
        final AtomicBoolean tmReported = new AtomicBoolean(false);
        TestCheckpointResponder checkpointResponder = new TestCheckpointResponder(){

            public void acknowledgeCheckpoint(JobID lJobID, ExecutionAttemptID lExecutionAttemptID, long lCheckpointId, CheckpointMetrics lCheckpointMetrics, TaskStateSnapshot lSubtaskState) {
                Assert.assertEquals((Object)jobID, (Object)lJobID);
                Assert.assertEquals((Object)executionAttemptID, (Object)lExecutionAttemptID);
                Assert.assertEquals((long)checkpointMetaData.getCheckpointId(), (long)lCheckpointId);
                Assert.assertEquals((Object)checkpointMetrics, (Object)lCheckpointMetrics);
                jmReported.set(true);
            }
        };
        Executor executor = Executors.directExecutor();
        LocalRecoveryDirectoryProviderImpl directoryProvider = new LocalRecoveryDirectoryProviderImpl(this.temporaryFolder.newFolder(), jobID, jobVertexID, 42);
        LocalRecoveryConfig localRecoveryConfig = new LocalRecoveryConfig((LocalRecoveryDirectoryProvider)directoryProvider);
        TaskLocalStateStoreImpl taskLocalStateStore = new TaskLocalStateStoreImpl(jobID, allocationID, jobVertexID, 42, localRecoveryConfig, executor){

            public void storeLocalState(@Nonnegative long checkpointId, @Nullable TaskStateSnapshot localState) {
                Assert.assertEquals((Object)tmSnapshot, (Object)localState);
                tmReported.set(true);
            }
        };
        InMemoryStateChangelogStorage stateChangelogStorage = new InMemoryStateChangelogStorage();
        TaskStateManagerImpl taskStateManager = new TaskStateManagerImpl(jobID, executionAttemptID, (TaskLocalStateStore)taskLocalStateStore, (StateChangelogStorage)stateChangelogStorage, null, (CheckpointResponder)checkpointResponder);
        taskStateManager.reportTaskStateSnapshots(checkpointMetaData, checkpointMetrics, jmSnapshot, tmSnapshot);
        Assert.assertTrue((String)"Reporting for JM state was not called.", (boolean)jmReported.get());
        Assert.assertTrue((String)"Reporting for TM state was not called.", (boolean)tmReported.get());
    }

    private static <T extends StateObject> void performCheck(Future<SnapshotResult<T>> resultFuture, StateObjectCollection<T> jmState, StateObjectCollection<T> tmState) {
        SnapshotResult<T> snapshotResult;
        try {
            snapshotResult = resultFuture.get();
        }
        catch (Exception e) {
            throw new RuntimeException(e);
        }
        Assert.assertEquals((Object)snapshotResult.getJobManagerOwnedSnapshot(), jmState.iterator().next());
        Assert.assertEquals((Object)snapshotResult.getTaskLocalSnapshot(), tmState.iterator().next());
    }

    private static <T extends StateObject> void performCollectionCheck(Future<SnapshotResult<StateObjectCollection<T>>> resultFuture, StateObjectCollection<T> jmState, StateObjectCollection<T> tmState) {
        SnapshotResult<StateObjectCollection<T>> snapshotResult;
        try {
            snapshotResult = resultFuture.get();
        }
        catch (Exception e) {
            throw new RuntimeException(e);
        }
        Assert.assertEquals((Object)snapshotResult.getJobManagerOwnedSnapshot(), jmState);
        Assert.assertEquals((Object)snapshotResult.getTaskLocalSnapshot(), tmState);
    }

    private static <T extends StateObject> RunnableFuture<SnapshotResult<T>> createSnapshotResult(Class<T> clazz) {
        return DoneFuture.of((Object)SnapshotResult.withLocalState((StateObject)((StateObject)Mockito.mock(clazz)), (StateObject)((StateObject)Mockito.mock(clazz))));
    }

    private static <T extends StateObject> RunnableFuture<SnapshotResult<StateObjectCollection<T>>> createSnapshotCollectionResult(Class<T> clazz) {
        return DoneFuture.of((Object)SnapshotResult.withLocalState((StateObject)StateObjectCollection.singleton((StateObject)((StateObject)Mockito.mock(clazz))), (StateObject)StateObjectCollection.singleton((StateObject)((StateObject)Mockito.mock(clazz)))));
    }
}

