/*
 * Decompiled with CFR 0.152.
 */
package org.apache.flink.test.checkpointing;

import java.io.File;
import java.time.Duration;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Objects;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.TimeUnit;
import org.apache.flink.api.common.JobID;
import org.apache.flink.api.common.functions.FlatMapFunction;
import org.apache.flink.api.common.functions.RichFlatMapFunction;
import org.apache.flink.api.common.functions.RuntimeContext;
import org.apache.flink.api.common.state.ListState;
import org.apache.flink.api.common.state.ListStateDescriptor;
import org.apache.flink.api.common.state.ValueState;
import org.apache.flink.api.common.state.ValueStateDescriptor;
import org.apache.flink.api.common.time.Deadline;
import org.apache.flink.api.common.typeutils.TypeSerializer;
import org.apache.flink.api.common.typeutils.base.IntSerializer;
import org.apache.flink.api.connector.sink2.Sink;
import org.apache.flink.api.java.functions.KeySelector;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.client.program.ClusterClient;
import org.apache.flink.client.program.rest.RestClusterClient;
import org.apache.flink.configuration.CheckpointingOptions;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.configuration.JobManagerOptions;
import org.apache.flink.configuration.StateBackendOptions;
import org.apache.flink.configuration.StateRecoveryOptions;
import org.apache.flink.configuration.WebOptions;
import org.apache.flink.contrib.streaming.state.RocksDBConfigurableOptions;
import org.apache.flink.core.execution.CheckpointingMode;
import org.apache.flink.core.testutils.OneShotLatch;
import org.apache.flink.runtime.client.JobExecutionException;
import org.apache.flink.runtime.jobgraph.JobGraph;
import org.apache.flink.runtime.jobgraph.JobResourceRequirements;
import org.apache.flink.runtime.jobgraph.JobVertex;
import org.apache.flink.runtime.minicluster.MiniCluster;
import org.apache.flink.runtime.state.FunctionInitializationContext;
import org.apache.flink.runtime.state.FunctionSnapshotContext;
import org.apache.flink.runtime.state.KeyGroupRangeAssignment;
import org.apache.flink.runtime.testutils.CommonTestUtils;
import org.apache.flink.runtime.testutils.MiniClusterResourceConfiguration;
import org.apache.flink.streaming.api.checkpoint.CheckpointedFunction;
import org.apache.flink.streaming.api.checkpoint.ListCheckpointed;
import org.apache.flink.streaming.api.datastream.DataStreamSource;
import org.apache.flink.streaming.api.datastream.KeyedStream;
import org.apache.flink.streaming.api.datastream.SingleOutputStreamOperator;
import org.apache.flink.streaming.api.environment.CheckpointConfig;
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
import org.apache.flink.streaming.api.functions.sink.legacy.SinkFunction;
import org.apache.flink.streaming.api.functions.sink.v2.DiscardingSink;
import org.apache.flink.streaming.api.functions.source.legacy.RichParallelSourceFunction;
import org.apache.flink.streaming.api.functions.source.legacy.SourceFunction;
import org.apache.flink.streaming.util.RestartStrategyUtils;
import org.apache.flink.test.scheduling.UpdateJobResourceRequirementsITCase;
import org.apache.flink.test.util.MiniClusterWithClientResource;
import org.apache.flink.testutils.TestingUtils;
import org.apache.flink.testutils.executor.TestExecutorResource;
import org.apache.flink.util.Collector;
import org.apache.flink.util.TestLogger;
import org.junit.AfterClass;
import org.junit.Assert;
import org.junit.Before;
import org.junit.ClassRule;
import org.junit.Test;
import org.junit.rules.TemporaryFolder;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;

@RunWith(value=Parameterized.class)
public class AutoRescalingITCase
extends TestLogger {
    @ClassRule
    public static final TestExecutorResource<ScheduledExecutorService> EXECUTOR_RESOURCE =
            TestingUtils.defaultExecutorResource();

    // Cluster sizing: numTaskManagers TaskManagers with slotsPerTaskManager slots each.
    private static final int numTaskManagers = 2;
    private static final int slotsPerTaskManager = 2;
    // Compile-time constant (2 * 2 = 4), the total number of slots the tests size jobs against.
    private static final int totalSlots = numTaskManagers * slotsPerTaskManager;

    // Parameterized inputs, supplied through the constructor.
    private final String backend;
    private final boolean useIngestDB;

    // Backend name the current cluster was built for; compared against `backend` in setup().
    private String currentBackend;

    // Cluster shared between tests; rebuilt in setup() and torn down by shutDownExistingCluster().
    private static MiniClusterWithClientResource cluster;
    private static RestClusterClient<?> restClusterClient;

    // Assigned in the class's static initializer; provides checkpoint and savepoint directories.
    @ClassRule
    public static TemporaryFolder temporaryFolder;

    @Parameterized.Parameters(name="backend = {0}, useIngestDB = {1}")
    public static Collection<Object[]> data() {
        return Arrays.asList({"rocksdb", false}, {"rocksdb", true}, {"hashmap", false});
    }

    /**
     * Creates one parameterized instance of the test.
     *
     * @param backend state backend name written into {@code StateBackendOptions.STATE_BACKEND}
     * @param useIngestDB value written into {@code USE_INGEST_DB_RESTORE_MODE}
     */
    public AutoRescalingITCase(String backend, boolean useIngestDB) {
        this.useIngestDB = useIngestDB;
        this.backend = backend;
    }

    /**
     * (Re)creates the shared MiniCluster when the configured state backend differs from the one
     * the running cluster was built for.
     *
     * <p>The cluster uses the adaptive scheduler with no rescaling cooldown, incremental
     * checkpoints, local recovery and fresh checkpoint/savepoint directories from {@link
     * #temporaryFolder}.
     *
     * <p>NOTE(review): {@code currentBackend} is an instance field. With JUnit 4's
     * instance-per-test lifecycle it is {@code null} here on every call, so the cluster is rebuilt
     * before each test. If it is ever made static to reuse the cluster, {@code useIngestDB} must
     * become part of the comparison, or the two "rocksdb" parameterizations would share one
     * cluster.
     */
    @Before
    public void setup() throws Exception {
        if (Objects.equals(currentBackend, backend)) {
            return;
        }
        shutDownExistingCluster();
        currentBackend = backend;

        Configuration config = new Configuration();
        File checkpointDir = temporaryFolder.newFolder();
        File savepointDir = temporaryFolder.newFolder();

        // Typed values: Configuration.set(ConfigOption<T>, T) rejects values cast to Object.
        config.set(StateBackendOptions.STATE_BACKEND, currentBackend);
        config.set(RocksDBConfigurableOptions.USE_INGEST_DB_RESTORE_MODE, useIngestDB);
        config.set(CheckpointingOptions.INCREMENTAL_CHECKPOINTS, true);
        config.set(StateRecoveryOptions.LOCAL_RECOVERY, true);
        config.set(CheckpointingOptions.CHECKPOINTS_DIRECTORY, checkpointDir.toURI().toString());
        config.set(CheckpointingOptions.SAVEPOINT_DIRECTORY, savepointDir.toURI().toString());
        config.set(JobManagerOptions.SCHEDULER, JobManagerOptions.SchedulerType.Adaptive);
        // No cooldown so the rescale requested by each test is applied right away.
        config.set(
                JobManagerOptions.SCHEDULER_EXECUTING_COOLDOWN_AFTER_RESCALING,
                Duration.ofMillis(0L));
        config.set(WebOptions.REFRESH_INTERVAL, Duration.ofMillis(50L));
        config.set(JobManagerOptions.SLOT_IDLE_TIMEOUT, Duration.ofMillis(50L));

        cluster =
                new MiniClusterWithClientResource(
                        new MiniClusterResourceConfiguration.Builder()
                                .setConfiguration(config)
                                .setNumberTaskManagers(numTaskManagers)
                                .setNumberSlotsPerTaskManager(slotsPerTaskManager)
                                .build());
        cluster.before();
        restClusterClient = cluster.getRestClusterClient();
    }

    /**
     * Stops the shared MiniCluster, if one is running, and clears the reference to it.
     *
     * <p>Called after all tests of the class and whenever {@code setup()} rebuilds the cluster.
     */
    @AfterClass
    public static void shutDownExistingCluster() {
        if (cluster == null) {
            return;
        }
        cluster.after();
        cluster = null;
    }

    /** Keyed-state rescaling in the scale-in direction (all slots down to half of them). */
    @Test
    public void testCheckpointRescalingInKeyedState() throws Exception {
        final boolean scaleOut = false;
        testCheckpointRescalingKeyedState(scaleOut);
    }

    /** Keyed-state rescaling in the scale-out direction (half of the slots up to all of them). */
    @Test
    public void testCheckpointRescalingOutKeyedState() throws Exception {
        final boolean scaleOut = true;
        testCheckpointRescalingKeyedState(scaleOut);
    }

    /**
     * Runs a keyed-state job, waits for a checkpoint, rescales it through the adaptive scheduler's
     * resource-requirements endpoint and checks that every key's state moved to the subtask that
     * owns its key group under the new parallelism.
     *
     * <p>Each key accumulates {@code numberElements * key} before the rescale. After the rescale the
     * source emits one more batch, so the sum doubles.
     *
     * @param scaleOut {@code true} to go from half the slots to all slots, {@code false} for the
     *     opposite direction
     */
    public void testCheckpointRescalingKeyedState(boolean scaleOut) throws Exception {
        final int numberKeys = 42;
        final int numberElements = 1000;
        final int parallelism = scaleOut ? totalSlots / 2 : totalSlots;
        final int parallelism2 = scaleOut ? totalSlots : totalSlots / 2;
        final int maxParallelism = 13;

        Duration timeout = Duration.ofMinutes(3L);
        Deadline deadline = Deadline.now().plus(timeout);

        ClusterClient<?> client = cluster.getClusterClient();

        try {
            JobGraph jobGraph =
                    createJobGraphWithKeyedState(
                            cluster.getMiniCluster().getConfiguration().clone(),
                            parallelism,
                            maxParallelism,
                            numberKeys,
                            numberElements);
            JobID jobID = jobGraph.getJobID();

            client.submitJob(jobGraph).get();
            SubtaskIndexSource.SOURCE_LATCH.trigger();

            // Every key has to reach numberElements before the first result set is complete.
            Assert.assertTrue(
                    SubtaskIndexFlatMapper.workCompletedLatch.await(
                            deadline.timeLeft().toMillis(), TimeUnit.MILLISECONDS));

            Set<Tuple2<Integer, Integer>> actualResult = CollectionSink.getElementsSet();
            Set<Tuple2<Integer, Integer>> expectedResult = new HashSet<>();
            for (int key = 0; key < numberKeys; ++key) {
                int keyGroupIndex = KeyGroupRangeAssignment.assignToKeyGroup(key, maxParallelism);
                expectedResult.add(
                        Tuple2.of(
                                KeyGroupRangeAssignment.computeOperatorIndexForKeyGroup(
                                        maxParallelism, parallelism, keyGroupIndex),
                                numberElements * key));
            }
            Assert.assertEquals(expectedResult, actualResult);

            CollectionSink.clearElementsSet();

            CommonTestUtils.waitForAllTaskRunning(cluster.getMiniCluster(), jobID, false);
            CommonTestUtils.waitForNewCheckpoint(jobID, cluster.getMiniCluster());

            // Hold the sources after the restore until the rescale has fully taken effect.
            SubtaskIndexSource.SOURCE_LATCH.reset();

            JobResourceRequirements.Builder builder = JobResourceRequirements.newBuilder();
            for (JobVertex vertex : jobGraph.getVertices()) {
                builder.setParallelismForJobVertex(vertex.getID(), parallelism2, parallelism2);
            }
            restClusterClient.updateJobResourceRequirements(jobID, builder.build()).join();

            UpdateJobResourceRequirementsITCase.waitForRunningTasks(
                    restClusterClient, jobID, 2 * parallelism2);
            UpdateJobResourceRequirementsITCase.waitForAvailableSlots(
                    restClusterClient, totalSlots - parallelism2);

            SubtaskIndexSource.SOURCE_LATCH.trigger();
            client.requestJobResult(jobID).get();

            Set<Tuple2<Integer, Integer>> actualResult2 = CollectionSink.getElementsSet();
            Set<Tuple2<Integer, Integer>> expectedResult2 = new HashSet<>();
            for (int key = 0; key < numberKeys; ++key) {
                int keyGroupIndex = KeyGroupRangeAssignment.assignToKeyGroup(key, maxParallelism);
                expectedResult2.add(
                        Tuple2.of(
                                KeyGroupRangeAssignment.computeOperatorIndexForKeyGroup(
                                        maxParallelism, parallelism2, keyGroupIndex),
                                key * 2 * numberElements));
            }
            Assert.assertEquals(expectedResult2, actualResult2);
        } finally {
            // The sink's collection is static, so it must not leak into the next test.
            CollectionSink.clearElementsSet();
        }
    }

    /**
     * Rescales a job whose source checkpoints through {@link ListCheckpointed}.
     *
     * <p>A {@link JobExecutionException} caused by an {@link IllegalStateException} is tolerated;
     * any other {@link JobExecutionException} is rethrown.
     *
     * <p>NOTE(review): the test also passes when the job finishes without any exception, and the
     * result returned by {@code requestJobResult} is not inspected. Confirm this is intended.
     */
    @Test
    public void testCheckpointRescalingNonPartitionedStateCausesException() throws Exception {
        final int parallelism = totalSlots / 2;
        final int parallelism2 = totalSlots;
        final int maxParallelism = 13;

        ClusterClient<?> client = cluster.getClusterClient();

        try {
            JobGraph jobGraph =
                    createJobGraphWithOperatorState(
                            parallelism, maxParallelism, OperatorCheckpointMethod.NON_PARTITIONED);
            // Keep the sources alive until the rescale has happened.
            StateSourceBase.canFinishLatch = new CountDownLatch(1);

            JobID jobID = jobGraph.getJobID();
            client.submitJob(jobGraph).get();

            CommonTestUtils.waitForAllTaskRunning(cluster.getMiniCluster(), jobID, false);
            StateSourceBase.workStartedLatch.await();
            CommonTestUtils.waitForNewCheckpoint(jobID, cluster.getMiniCluster());

            JobResourceRequirements.Builder builder = JobResourceRequirements.newBuilder();
            for (JobVertex vertex : jobGraph.getVertices()) {
                builder.setParallelismForJobVertex(vertex.getID(), parallelism2, parallelism2);
            }
            restClusterClient.updateJobResourceRequirements(jobID, builder.build()).join();

            UpdateJobResourceRequirementsITCase.waitForRunningTasks(
                    restClusterClient, jobID, 2 * parallelism2);
            UpdateJobResourceRequirementsITCase.waitForAvailableSlots(
                    restClusterClient, totalSlots - parallelism2);

            StateSourceBase.canFinishLatch.countDown();
            client.requestJobResult(jobID).get();
        } catch (JobExecutionException exception) {
            if (!(exception.getCause() instanceof IllegalStateException)) {
                throw exception;
            }
        }
    }

    /**
     * Rescales a job whose {@link ListCheckpointed} source is pinned to a fixed parallelism, while
     * the downstream keyed flat-map scales from half the slots to all slots.
     *
     * <p>The source's max parallelism equals its parallelism, so each vertex is moved to {@code
     * min(parallelism2, maxParallelism)}. After the restore the source runs another {@code
     * numberElements} emission rounds, so every key's sum doubles.
     */
    @Test
    public void testCheckpointRescalingWithKeyedAndNonPartitionedState() throws Exception {
        final int numberKeys = 42;
        final int numberElements = 1000;
        final int parallelism = totalSlots / 2;
        final int parallelism2 = totalSlots;
        final int maxParallelism = 13;

        Duration timeout = Duration.ofMinutes(3L);
        Deadline deadline = Deadline.now().plus(timeout);

        ClusterClient<?> client = cluster.getClusterClient();

        try {
            JobGraph jobGraph =
                    createJobGraphWithKeyedAndNonPartitionedOperatorState(
                            parallelism,
                            maxParallelism,
                            parallelism,
                            numberKeys,
                            numberElements,
                            numberElements);
            JobID jobID = jobGraph.getJobID();

            client.submitJob(jobGraph).get();
            SubtaskIndexSource.SOURCE_LATCH.trigger();

            Assert.assertTrue(
                    SubtaskIndexFlatMapper.workCompletedLatch.await(
                            deadline.timeLeft().toMillis(), TimeUnit.MILLISECONDS));

            Set<Tuple2<Integer, Integer>> actualResult = CollectionSink.getElementsSet();
            Set<Tuple2<Integer, Integer>> expectedResult = new HashSet<>();
            for (int key = 0; key < numberKeys; ++key) {
                int keyGroupIndex = KeyGroupRangeAssignment.assignToKeyGroup(key, maxParallelism);
                expectedResult.add(
                        Tuple2.of(
                                KeyGroupRangeAssignment.computeOperatorIndexForKeyGroup(
                                        maxParallelism, parallelism, keyGroupIndex),
                                numberElements * key));
            }
            Assert.assertEquals(expectedResult, actualResult);

            CollectionSink.clearElementsSet();

            CommonTestUtils.waitForNewCheckpoint(jobID, cluster.getMiniCluster());

            SubtaskIndexSource.SOURCE_LATCH.reset();

            JobResourceRequirements.Builder builder = JobResourceRequirements.newBuilder();
            for (JobVertex vertex : jobGraph.getVertices()) {
                // Vertices with a lower max parallelism (the pinned source) stay at their maximum.
                int target = Math.min(parallelism2, vertex.getMaxParallelism());
                builder.setParallelismForJobVertex(vertex.getID(), target, target);
            }
            restClusterClient.updateJobResourceRequirements(jobID, builder.build()).join();

            UpdateJobResourceRequirementsITCase.waitForRunningTasks(
                    restClusterClient, jobID, parallelism + parallelism2);
            UpdateJobResourceRequirementsITCase.waitForAvailableSlots(
                    restClusterClient, totalSlots - parallelism2);

            SubtaskIndexSource.SOURCE_LATCH.trigger();
            client.requestJobResult(jobID).get();

            Set<Tuple2<Integer, Integer>> actualResult2 = CollectionSink.getElementsSet();
            Set<Tuple2<Integer, Integer>> expectedResult2 = new HashSet<>();
            for (int key = 0; key < numberKeys; ++key) {
                int keyGroupIndex = KeyGroupRangeAssignment.assignToKeyGroup(key, maxParallelism);
                expectedResult2.add(
                        Tuple2.of(
                                KeyGroupRangeAssignment.computeOperatorIndexForKeyGroup(
                                        maxParallelism, parallelism2, keyGroupIndex),
                                key * 2 * numberElements));
            }
            Assert.assertEquals(expectedResult2, actualResult2);
        } finally {
            CollectionSink.clearElementsSet();
        }
    }

    /** Partitioned (non-broadcast) CheckpointedFunction source, run with {@code scaleOut = false}. */
    @Test
    public void testCheckpointRescalingInPartitionedOperatorState() throws Exception {
        final OperatorCheckpointMethod method = OperatorCheckpointMethod.CHECKPOINTED_FUNCTION;
        testCheckpointRescalingPartitionedOperatorState(false, method);
    }

    /** Partitioned (non-broadcast) CheckpointedFunction source, run with {@code scaleOut = true}. */
    @Test
    public void testCheckpointRescalingOutPartitionedOperatorState() throws Exception {
        final OperatorCheckpointMethod method = OperatorCheckpointMethod.CHECKPOINTED_FUNCTION;
        testCheckpointRescalingPartitionedOperatorState(true, method);
    }

    /** Union-list-state (broadcast) CheckpointedFunction source, run with {@code scaleOut = false}. */
    @Test
    public void testCheckpointRescalingInBroadcastOperatorState() throws Exception {
        final OperatorCheckpointMethod method =
                OperatorCheckpointMethod.CHECKPOINTED_FUNCTION_BROADCAST;
        testCheckpointRescalingPartitionedOperatorState(false, method);
    }

    /** Union-list-state (broadcast) CheckpointedFunction source, run with {@code scaleOut = true}. */
    @Test
    public void testCheckpointRescalingOutBroadcastOperatorState() throws Exception {
        final OperatorCheckpointMethod method =
                OperatorCheckpointMethod.CHECKPOINTED_FUNCTION_BROADCAST;
        testCheckpointRescalingPartitionedOperatorState(true, method);
    }

    /**
     * Rescales a job whose source keeps its counter in operator list state and checks that no
     * counter value is lost or duplicated on restore.
     *
     * <p>For plain list state, the restored counters must sum to the snapshot total. For union
     * list state ({@code CHECKPOINTED_FUNCTION_BROADCAST}), every restored subtask sees all
     * entries, so the restored sum is the snapshot total times the new parallelism.
     *
     * @param scaleOut {@code true} to go from half the slots to all slots, {@code false} for the
     *     opposite direction; this matches {@code testCheckpointRescalingKeyedState}
     * @param checkpointMethod {@code CHECKPOINTED_FUNCTION} or {@code
     *     CHECKPOINTED_FUNCTION_BROADCAST}
     * @throws UnsupportedOperationException for any other checkpoint method
     */
    public void testCheckpointRescalingPartitionedOperatorState(
            boolean scaleOut, OperatorCheckpointMethod checkpointMethod) throws Exception {
        if (checkpointMethod != OperatorCheckpointMethod.CHECKPOINTED_FUNCTION
                && checkpointMethod != OperatorCheckpointMethod.CHECKPOINTED_FUNCTION_BROADCAST) {
            throw new UnsupportedOperationException("Unsupported method:" + checkpointMethod);
        }

        // Previously scaleOut=true started at all slots and shrank (i.e. scaled *in*), inverting
        // the meaning used by the keyed-state test and the test names. Both directions are
        // still exercised; only the mapping is corrected.
        final int parallelism = scaleOut ? totalSlots / 2 : totalSlots;
        final int parallelism2 = scaleOut ? totalSlots : totalSlots / 2;
        final int maxParallelism = 13;
        final int counterSize = Math.max(parallelism, parallelism2);

        ClusterClient<?> client = cluster.getClusterClient();

        PartitionedStateSource.checkCorrectSnapshot = new int[counterSize];
        PartitionedStateSource.checkCorrectRestore = new int[counterSize];
        PartitionedStateSource.checkCorrectSnapshots.clear();

        JobGraph jobGraph =
                createJobGraphWithOperatorState(parallelism, maxParallelism, checkpointMethod);
        // Keep the sources alive until the rescale has happened.
        StateSourceBase.canFinishLatch = new CountDownLatch(1);

        JobID jobID = jobGraph.getJobID();
        client.submitJob(jobGraph).get();

        CommonTestUtils.waitForAllTaskRunning(cluster.getMiniCluster(), jobID, false);
        StateSourceBase.workStartedLatch.await();
        CommonTestUtils.waitForNewCheckpoint(jobID, cluster.getMiniCluster());

        JobResourceRequirements.Builder builder = JobResourceRequirements.newBuilder();
        for (JobVertex vertex : jobGraph.getVertices()) {
            builder.setParallelismForJobVertex(vertex.getID(), parallelism2, parallelism2);
        }
        restClusterClient.updateJobResourceRequirements(jobID, builder.build()).join();

        UpdateJobResourceRequirementsITCase.waitForRunningTasks(
                restClusterClient, jobID, 2 * parallelism2);
        UpdateJobResourceRequirementsITCase.waitForAvailableSlots(
                restClusterClient, totalSlots - parallelism2);

        StateSourceBase.canFinishLatch.countDown();
        client.requestJobResult(jobID).get();

        // checkCorrectSnapshot now points at the snapshot of the checkpoint that was restored.
        int sumExp = Arrays.stream(PartitionedStateSource.checkCorrectSnapshot).sum();
        int sumAct = Arrays.stream(PartitionedStateSource.checkCorrectRestore).sum();
        if (checkpointMethod == OperatorCheckpointMethod.CHECKPOINTED_FUNCTION_BROADCAST) {
            sumExp *= parallelism2;
        }

        Assert.assertEquals(sumExp, sumAct);
    }

    /**
     * Applies the checkpoint settings shared by all job graphs in this test.
     *
     * <p>Settings: a 100 ms interval, exactly-once mode and unaligned checkpoints.
     *
     * @param config the environment's checkpoint configuration, mutated in place
     */
    private static void configureCheckpointing(CheckpointConfig config) {
        final long checkpointIntervalMillis = 100L;
        config.setCheckpointInterval(checkpointIntervalMillis);
        config.setCheckpointingConsistencyMode(CheckpointingMode.EXACTLY_ONCE);
        config.enableUnalignedCheckpoints(true);
    }

    /**
     * Builds a source-to-discarding-sink job whose source keeps operator state.
     *
     * <p>Resets {@code StateSourceBase.workStartedLatch} to one count per source subtask before
     * the source is chosen.
     *
     * @param parallelism initial job parallelism
     * @param maxParallelism job max parallelism
     * @param checkpointMethod which {@link StateSourceBase} implementation to use
     * @return the job graph
     * @throws IllegalArgumentException for {@code LIST_CHECKPOINTED}, which has no source here
     */
    private static JobGraph createJobGraphWithOperatorState(
            int parallelism, int maxParallelism, OperatorCheckpointMethod checkpointMethod) {
        StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
        configureCheckpointing(env.getCheckpointConfig());
        env.setParallelism(parallelism);
        env.getConfig().setMaxParallelism(maxParallelism);
        RestartStrategyUtils.configureNoRestartStrategy(env);

        StateSourceBase.workStartedLatch = new CountDownLatch(parallelism);

        StateSourceBase src;
        switch (checkpointMethod) {
            case CHECKPOINTED_FUNCTION:
                src = new PartitionedStateSource(false);
                break;
            case CHECKPOINTED_FUNCTION_BROADCAST:
                src = new PartitionedStateSource(true);
                break;
            case NON_PARTITIONED:
                src = new NonPartitionedStateSource();
                break;
            default:
                throw new IllegalArgumentException(checkpointMethod.name());
        }

        DataStreamSource<Integer> input = env.addSource(src);
        input.sinkTo(new DiscardingSink<>());

        return env.getStreamGraph().getJobGraph();
    }

    /**
     * Builds the keyed-state job: a {@link SubtaskIndexSource}, keyed by identity, feeding a {@link
     * SubtaskIndexFlatMapper} whose results go to a {@link CollectionSink}.
     *
     * <p>Resets {@code SubtaskIndexFlatMapper.workCompletedLatch} to {@code numberKeys}.
     *
     * @param configuration configuration for the execution environment
     * @param parallelism job parallelism
     * @param maxParallelism max parallelism; values {@code <= 0} leave the default untouched
     * @param numberKeys number of distinct keys the source emits
     * @param numberElements emission rounds per source subtask, and also the per-key count at
     *     which the flat-map emits a result
     * @return the job graph
     */
    public static JobGraph createJobGraphWithKeyedState(
            Configuration configuration,
            int parallelism,
            int maxParallelism,
            int numberKeys,
            int numberElements) {
        StreamExecutionEnvironment env =
                StreamExecutionEnvironment.getExecutionEnvironment(configuration);
        env.setParallelism(parallelism);
        if (0 < maxParallelism) {
            env.getConfig().setMaxParallelism(maxParallelism);
        }
        configureCheckpointing(env.getCheckpointConfig());
        RestartStrategyUtils.configureNoRestartStrategy(env);
        env.getConfig().setUseSnapshotCompression(true);

        KeyedStream<Integer, Integer> input =
                env.addSource(new SubtaskIndexSource(numberKeys, numberElements, parallelism))
                        .keyBy(
                                new KeySelector<Integer, Integer>() {
                                    private static final long serialVersionUID =
                                            -7952298871120320940L;

                                    @Override
                                    public Integer getKey(Integer value) {
                                        return value;
                                    }
                                });

        SubtaskIndexFlatMapper.workCompletedLatch = new CountDownLatch(numberKeys);

        SingleOutputStreamOperator<Tuple2<Integer, Integer>> result =
                input.flatMap(new SubtaskIndexFlatMapper(numberElements));
        result.addSink(new CollectionSink<>());

        return env.getStreamGraph().getJobGraph();
    }

    /**
     * Builds a keyed-state job whose source additionally keeps {@link ListCheckpointed} state and
     * is pinned to {@code fixedParallelism} (both parallelism and max parallelism).
     *
     * <p>Resets {@code SubtaskIndexFlatMapper.workCompletedLatch} to {@code numberKeys}.
     *
     * @param parallelism default job parallelism (used by the flat-map and sink)
     * @param maxParallelism job max parallelism
     * @param fixedParallelism parallelism and max parallelism of the source
     * @param numberKeys number of distinct keys the source emits
     * @param numberElements emission rounds per source subtask before a restore
     * @param numberElementsAfterRestart extra emission rounds granted by each restore
     * @return the job graph
     */
    private static JobGraph createJobGraphWithKeyedAndNonPartitionedOperatorState(
            int parallelism,
            int maxParallelism,
            int fixedParallelism,
            int numberKeys,
            int numberElements,
            int numberElementsAfterRestart) {
        StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
        env.setParallelism(parallelism);
        env.getConfig().setMaxParallelism(maxParallelism);
        configureCheckpointing(env.getCheckpointConfig());
        RestartStrategyUtils.configureNoRestartStrategy(env);

        KeyedStream<Integer, Integer> input =
                env.addSource(
                                new SubtaskIndexNonPartitionedStateSource(
                                        numberKeys,
                                        numberElements,
                                        numberElementsAfterRestart,
                                        parallelism))
                        .setParallelism(fixedParallelism)
                        .setMaxParallelism(fixedParallelism)
                        .keyBy(
                                new KeySelector<Integer, Integer>() {
                                    private static final long serialVersionUID =
                                            -7952298871120320940L;

                                    @Override
                                    public Integer getKey(Integer value) {
                                        return value;
                                    }
                                });

        SubtaskIndexFlatMapper.workCompletedLatch = new CountDownLatch(numberKeys);

        SingleOutputStreamOperator<Tuple2<Integer, Integer>> result =
                input.flatMap(new SubtaskIndexFlatMapper(numberElements));
        result.addSink(new CollectionSink<>());

        return env.getStreamGraph().getJobGraph();
    }

    static {
        // Assigned eagerly so the @ClassRule field temporaryFolder is non-null before any rule or
        // test touches it.
        TemporaryFolder folder = new TemporaryFolder();
        temporaryFolder = folder;
    }

    private static class PartitionedStateSource
    extends StateSourceBase
    implements CheckpointedFunction {
        private static final long serialVersionUID = -359715965103593462L;
        private static final int NUM_PARTITIONS = 7;
        private transient ListState<Integer> counterPartitions;
        private final boolean broadcast;
        private static final ConcurrentHashMap<Long, int[]> checkCorrectSnapshots = new ConcurrentHashMap();
        private static int[] checkCorrectSnapshot;
        private static int[] checkCorrectRestore;

        public PartitionedStateSource(boolean broadcast) {
            this.broadcast = broadcast;
        }

        public void snapshotState(FunctionSnapshotContext context) throws Exception {
            if (this.getRuntimeContext().getTaskInfo().getAttemptNumber() == 0) {
                int[] snapshot = checkCorrectSnapshots.computeIfAbsent(context.getCheckpointId(), x -> new int[checkCorrectRestore.length]);
                snapshot[this.getRuntimeContext().getTaskInfo().getIndexOfThisSubtask()] = this.counter;
            }
            this.counterPartitions.clear();
            int div = this.counter / 7;
            int mod = this.counter % 7;
            for (int i = 0; i < 7; ++i) {
                int partitionValue = div;
                if (mod > 0) {
                    --mod;
                    ++partitionValue;
                }
                this.counterPartitions.add((Object)partitionValue);
            }
        }

        public void initializeState(FunctionInitializationContext context) throws Exception {
            this.counterPartitions = this.broadcast ? context.getOperatorStateStore().getUnionListState(new ListStateDescriptor("counter_partitions", (TypeSerializer)IntSerializer.INSTANCE)) : context.getOperatorStateStore().getListState(new ListStateDescriptor("counter_partitions", (TypeSerializer)IntSerializer.INSTANCE));
            if (context.isRestored()) {
                Iterator iterator = ((Iterable)this.counterPartitions.get()).iterator();
                while (iterator.hasNext()) {
                    int v = (Integer)iterator.next();
                    this.counter += v;
                }
                PartitionedStateSource.checkCorrectRestore[this.getRuntimeContext().getTaskInfo().getIndexOfThisSubtask()] = this.counter;
                context.getRestoredCheckpointId().ifPresent(id -> {
                    checkCorrectSnapshot = checkCorrectSnapshots.get(id);
                });
            }
        }
    }

    private static class NonPartitionedStateSource
    extends StateSourceBase
    implements ListCheckpointed<Integer> {
        private static final long serialVersionUID = -8108185918123186841L;

        private NonPartitionedStateSource() {
        }

        public List<Integer> snapshotState(long checkpointId, long timestamp) {
            return Collections.singletonList(this.counter);
        }

        public void restoreState(List<Integer> state) {
            if (!state.isEmpty()) {
                this.counter = state.get(0);
            }
        }
    }

    private static class StateSourceBase
    extends RichParallelSourceFunction<Integer> {
        private static final long serialVersionUID = 7512206069681177940L;
        private static CountDownLatch workStartedLatch = new CountDownLatch(1);
        private static CountDownLatch canFinishLatch = new CountDownLatch(0);
        protected volatile int counter = 0;
        protected volatile boolean running = true;

        private StateSourceBase() {
        }

        /*
         * WARNING - Removed try catching itself - possible behaviour change.
         */
        public void run(SourceFunction.SourceContext<Integer> ctx) throws Exception {
            while (this.running) {
                Object object = ctx.getCheckpointLock();
                synchronized (object) {
                    ++this.counter;
                    ctx.collect((Object)1);
                }
                Thread.sleep(2L);
                if (this.counter == 10) {
                    workStartedLatch.countDown();
                }
                if (this.counter < 500) continue;
            }
            canFinishLatch.await();
        }

        public void cancel() {
            this.running = false;
        }
    }

    private static class CollectionSink<IN>
    implements SinkFunction<IN> {
        private static final Set<Object> elements = Collections.newSetFromMap(new ConcurrentHashMap());
        private static final long serialVersionUID = -1652452958040267745L;

        private CollectionSink() {
        }

        public static <IN> Set<IN> getElementsSet() {
            return elements;
        }

        public static void clearElementsSet() {
            elements.clear();
        }

        public void invoke(IN value) {
            elements.add(value);
        }
    }

    private static class SubtaskIndexFlatMapper
    extends RichFlatMapFunction<Integer, Tuple2<Integer, Integer>>
    implements CheckpointedFunction {
        private static final long serialVersionUID = 5273172591283191348L;
        private static CountDownLatch workCompletedLatch = new CountDownLatch(1);
        private transient ValueState<Integer> counter;
        private transient ValueState<Integer> sum;
        private final int numberElements;

        SubtaskIndexFlatMapper(int numberElements) {
            this.numberElements = numberElements;
        }

        public void flatMap(Integer value, Collector<Tuple2<Integer, Integer>> out) throws Exception {
            int count = (Integer)this.counter.value() + 1;
            this.counter.update((Object)count);
            int s = (Integer)this.sum.value() + value;
            this.sum.update((Object)s);
            if (count % this.numberElements == 0) {
                out.collect((Object)Tuple2.of((Object)this.getRuntimeContext().getTaskInfo().getIndexOfThisSubtask(), (Object)s));
                workCompletedLatch.countDown();
            }
        }

        public void snapshotState(FunctionSnapshotContext context) {
        }

        public void initializeState(FunctionInitializationContext context) {
            this.counter = context.getKeyedStateStore().getState(new ValueStateDescriptor("counter", Integer.class, (Object)0));
            this.sum = context.getKeyedStateStore().getState(new ValueStateDescriptor("sum", Integer.class, (Object)0));
        }
    }

    private static class SubtaskIndexNonPartitionedStateSource
    extends SubtaskIndexSource
    implements ListCheckpointed<Integer> {
        private static final long serialVersionUID = 8388073059042040203L;
        private final int numElementsAfterRestart;

        SubtaskIndexNonPartitionedStateSource(int numberKeys, int numberElements, int numElementsAfterRestart, int originalParallelism) {
            super(numberKeys, numberElements, originalParallelism);
            this.numElementsAfterRestart = numElementsAfterRestart;
        }

        public List<Integer> snapshotState(long checkpointId, long timestamp) {
            return Collections.singletonList(this.counter);
        }

        public void restoreState(List<Integer> state) {
            if (state.size() != 1) {
                throw new RuntimeException("Test failed due to unexpected recovered state size " + state.size());
            }
            this.counter = state.get(0);
            this.numberElements += this.numElementsAfterRestart;
        }
    }

    private static class SubtaskIndexSource
    extends RichParallelSourceFunction<Integer> {
        private static final long serialVersionUID = -400066323594122516L;
        private final int numberKeys;
        private final int originalParallelism;
        protected int numberElements;
        protected int counter = 0;
        private boolean running = true;
        private static final OneShotLatch SOURCE_LATCH = new OneShotLatch();

        SubtaskIndexSource(int numberKeys, int numberElements, int originalParallelism) {
            this.numberKeys = numberKeys;
            this.numberElements = numberElements;
            this.originalParallelism = originalParallelism;
        }

        /*
         * WARNING - Removed try catching itself - possible behaviour change.
         */
        public void run(SourceFunction.SourceContext<Integer> ctx) throws Exception {
            boolean isRestartedOrRescaled;
            RuntimeContext runtimeContext = this.getRuntimeContext();
            int subtaskIndex = runtimeContext.getTaskInfo().getIndexOfThisSubtask();
            boolean bl = isRestartedOrRescaled = runtimeContext.getTaskInfo().getNumberOfParallelSubtasks() != this.originalParallelism || runtimeContext.getTaskInfo().getAttemptNumber() > 0;
            while (this.running) {
                SOURCE_LATCH.await();
                if (this.counter < this.numberElements) {
                    Object object = ctx.getCheckpointLock();
                    synchronized (object) {
                        for (int value = subtaskIndex; value < this.numberKeys; value += runtimeContext.getTaskInfo().getNumberOfParallelSubtasks()) {
                            ctx.collect((Object)value);
                        }
                        ++this.counter;
                        continue;
                    }
                }
                if (isRestartedOrRescaled) {
                    this.running = false;
                    continue;
                }
                Thread.sleep(100L);
            }
        }

        public void cancel() {
            this.running = false;
        }
    }

    /**
     * Selects the operator-state source used by {@code createJobGraphWithOperatorState}.
     *
     * <p>{@code LIST_CHECKPOINTED} has no source there and is rejected with an {@code
     * IllegalArgumentException}.
     */
    enum OperatorCheckpointMethod {
        NON_PARTITIONED,
        CHECKPOINTED_FUNCTION,
        CHECKPOINTED_FUNCTION_BROADCAST,
        LIST_CHECKPOINTED
    }
}

