/*
 * Decompiled with CFR 0.152.
 */
package org.apache.flink.test.checkpointing;

import java.io.IOException;
import java.io.Serializable;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.Executor;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicLong;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import java.util.stream.StreamSupport;
import org.apache.commons.collections.CollectionUtils;
import org.apache.flink.api.common.functions.MapFunction;
import org.apache.flink.api.common.functions.OpenContext;
import org.apache.flink.api.common.functions.RichMapFunction;
import org.apache.flink.api.common.state.ListState;
import org.apache.flink.api.common.state.ListStateDescriptor;
import org.apache.flink.api.common.state.ValueState;
import org.apache.flink.api.common.state.ValueStateDescriptor;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.java.functions.KeySelector;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.client.program.ClusterClient;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.configuration.ExternalizedCheckpointRetention;
import org.apache.flink.configuration.HighAvailabilityOptions;
import org.apache.flink.configuration.JobManagerOptions;
import org.apache.flink.core.execution.CheckpointingMode;
import org.apache.flink.runtime.checkpoint.CheckpointRecoveryFactory;
import org.apache.flink.runtime.checkpoint.CheckpointsCleaner;
import org.apache.flink.runtime.checkpoint.CompletedCheckpoint;
import org.apache.flink.runtime.checkpoint.PerJobCheckpointRecoveryFactory;
import org.apache.flink.runtime.checkpoint.StandaloneCompletedCheckpointStore;
import org.apache.flink.runtime.highavailability.HighAvailabilityServices;
import org.apache.flink.runtime.highavailability.HighAvailabilityServicesFactory;
import org.apache.flink.runtime.highavailability.nonha.embedded.EmbeddedHaServicesWithLeadershipControl;
import org.apache.flink.runtime.jobgraph.JobGraph;
import org.apache.flink.runtime.state.FunctionInitializationContext;
import org.apache.flink.runtime.state.FunctionSnapshotContext;
import org.apache.flink.runtime.state.KeyGroupRangeAssignment;
import org.apache.flink.runtime.testutils.MiniClusterResourceConfiguration;
import org.apache.flink.streaming.api.checkpoint.CheckpointedFunction;
import org.apache.flink.streaming.api.checkpoint.ListCheckpointed;
import org.apache.flink.streaming.api.datastream.DataStream;
import org.apache.flink.streaming.api.datastream.DataStreamUtils;
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
import org.apache.flink.streaming.api.functions.sink.legacy.RichSinkFunction;
import org.apache.flink.streaming.api.functions.sink.legacy.SinkFunction;
import org.apache.flink.streaming.api.functions.source.legacy.RichParallelSourceFunction;
import org.apache.flink.streaming.api.functions.source.legacy.SourceFunction;
import org.apache.flink.test.util.MiniClusterWithClientResource;
import org.apache.flink.test.util.TestUtils;
import org.apache.flink.util.TestLogger;
import org.junit.AfterClass;
import org.junit.Assert;
import org.junit.Before;
import org.junit.ClassRule;
import org.junit.Test;
import org.junit.rules.TemporaryFolder;

/**
 * Integration test for the {@code region} failover strategy.
 *
 * <p>The job consists of two disconnected pipelines:
 *
 * <ul>
 *   <li>A multi-region source → mapper → sink pipeline, each at parallelism {@link #NUM_OF_REGIONS}.
 *   <li>A single-region source → map pipeline at parallelism 1.
 * </ul>
 *
 * <p>The {@link FailingMapperFunction} throws on subtasks 0 and 2 to trigger failovers. During
 * restore, the sources verify their union/list operator state. Subtask 2 also checks that its
 * restored index matches the one recorded for the last completed checkpoint. After the job
 * finishes, the sink output is validated.
 *
 * <p>Coordination between the test thread and the tasks goes through static fields. This
 * presumably relies on the MiniCluster running tasks inside the test JVM.
 */
public class RegionFailoverITCase extends TestLogger {

    /** Element-index step above which the mapper injects its next failure. */
    private static final int FAIL_BASE = 1000;

    /** Parallelism of the multi-region source, mapper and sink. */
    private static final int NUM_OF_REGIONS = 3;

    private static final int MAX_PARALLELISM = 6;

    private static final Set<Integer> EXPECTED_INDICES_MULTI_REGION =
            IntStream.range(0, NUM_OF_REGIONS).boxed().collect(Collectors.toSet());

    private static final Set<Integer> EXPECTED_INDICES_SINGLE_REGION = Collections.singleton(0);

    /** Total number of failures the mapper injects (counted by {@link #jobFailedCnt}). */
    private static final int NUM_OF_RESTARTS = 3;

    /** Number of completed checkpoints the sources throttle themselves for. */
    private static final int MIN_COMPLETED_CHECKPOINTS = 3;

    private static final int NUM_ELEMENTS = 10000;

    private static final String SINGLE_REGION_SOURCE_NAME = "single-source";
    private static final String MULTI_REGION_SOURCE_NAME = "multi-source";

    /** Id of the checkpoint most recently added to {@link TestingCompletedCheckpointStore}. */
    private static final AtomicLong lastCompletedCheckpointId = new AtomicLong(0L);

    private static final AtomicInteger numCompletedCheckpoints = new AtomicInteger(0);

    private static final AtomicInteger jobFailedCnt = new AtomicInteger(0);

    /** Checkpoint id → source index snapshotted by multi-region source subtask 2. */
    private static final Map<Long, Integer> snapshotIndicesOfSubTask = new HashMap<>();

    private static MiniClusterWithClientResource cluster;

    /** Set by a source task on restore, read by the test thread; hence volatile. */
    private static volatile boolean restoredState = false;

    @ClassRule public static final TemporaryFolder TEMPORARY_FOLDER = new TemporaryFolder();

    /**
     * Starts a fresh MiniCluster and resets the shared counters.
     *
     * <p>The cluster uses region failover and {@link TestingHAFactory}, with 2 TaskManagers × 2
     * slots.
     */
    @Before
    public void setup() throws Exception {
        // Teardown runs only once per class (@AfterClass), so release any cluster left over
        // from a previous test method before creating a new one.
        shutDownExistingCluster();

        Configuration configuration = new Configuration();
        configuration.set(JobManagerOptions.EXECUTION_FAILOVER_STRATEGY, "region");
        configuration.set(HighAvailabilityOptions.HA_MODE, TestingHAFactory.class.getName());

        cluster =
                new MiniClusterWithClientResource(
                        new MiniClusterResourceConfiguration.Builder()
                                .setConfiguration(configuration)
                                .setNumberTaskManagers(2)
                                .setNumberSlotsPerTaskManager(2)
                                .build());
        cluster.before();

        jobFailedCnt.set(0);
        numCompletedCheckpoints.set(0);
        restoredState = false;
    }

    /** Shuts down the MiniCluster if one is running. */
    @AfterClass
    public static void shutDownExistingCluster() {
        if (cluster != null) {
            cluster.after();
            cluster = null;
        }
    }

    /**
     * Runs the job to completion and validates restore and sink results.
     *
     * <p>Any failure propagates with its full cause.
     */
    @Test(timeout = 60000L)
    public void testMultiRegionFailover() throws Exception {
        JobGraph jobGraph = createJobGraph();
        ClusterClient<?> client = cluster.getClusterClient();
        TestUtils.submitJobAndWaitForResult(client, jobGraph, getClass().getClassLoader());
        verifyAfterJobExecuted();
    }

    /**
     * Checks that state was restored at least once and that every sink subtask produced correct
     * per-key values.
     *
     * <p>Key {@code k} receives the indices {@code 2k} and {@code 2k + 1}, because the source
     * uses {@code key = index / 2}. The mapper emits the first value unchanged and then the
     * stored first value plus the current one. The sink keeps the max per key, so each key must
     * end at {@code 4k + 1}.
     */
    private void verifyAfterJobExecuted() {
        Assert.assertTrue(
                "The test multi-region job has never ever restored state.", restoredState);

        int keyCount = 0;
        for (int subtask = 0; subtask < ValidatingSink.maps.length; subtask++) {
            Map<Integer, Integer> map = ValidatingSink.maps[subtask];
            Assert.assertNotNull("No result recorded for sink subtask " + subtask, map);
            for (Map.Entry<Integer, Integer> entry : map.entrySet()) {
                Assert.assertEquals(4 * entry.getKey() + 1, (int) entry.getValue());
                keyCount++;
            }
        }
        Assert.assertEquals(NUM_ELEMENTS / 2, keyCount);
    }

    /**
     * Builds the job graph: a keyed multi-region pipeline plus a disconnected single-region
     * pipeline, with exactly-once checkpoints every 200 ms and operator chaining disabled.
     */
    private JobGraph createJobGraph() {
        StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
        env.setParallelism(NUM_OF_REGIONS);
        env.setMaxParallelism(MAX_PARALLELISM);
        env.enableCheckpointing(200L, CheckpointingMode.EXACTLY_ONCE);
        env.getCheckpointConfig()
                .setExternalizedCheckpointRetention(
                        ExternalizedCheckpointRetention.RETAIN_ON_CANCELLATION);
        env.disableOperatorChaining();

        // Sources slow down until enough checkpoints exist once they pass this index.
        long checkpointLatestAt = NUM_ELEMENTS / NUM_OF_RESTARTS;

        DataStream<Tuple2<Integer, Integer>> multiRegionSource =
                env.addSource(new StringGeneratingSourceFunction(NUM_ELEMENTS, checkpointLatestAt))
                        .name(MULTI_REGION_SOURCE_NAME)
                        .setParallelism(NUM_OF_REGIONS);

        // The source already emits each key only from the subtask that owns its key group, so
        // the stream can be reinterpreted as keyed without a shuffle.
        DataStreamUtils.reinterpretAsKeyedStream(
                        multiRegionSource,
                        (KeySelector<Tuple2<Integer, Integer>, Integer> & Serializable)
                                value -> value.f0,
                        TypeInformation.of(Integer.class))
                .map(new FailingMapperFunction(NUM_OF_RESTARTS))
                .setParallelism(NUM_OF_REGIONS)
                .addSink(new ValidatingSink())
                .setParallelism(NUM_OF_REGIONS);

        // A second pipeline, completely disconnected from the one above.
        env.addSource(new StringGeneratingSourceFunction(NUM_ELEMENTS, checkpointLatestAt))
                .name(SINGLE_REGION_SOURCE_NAME)
                .setParallelism(1)
                .map((MapFunction<Tuple2<Integer, Integer>, Object> & Serializable) value -> value)
                .setParallelism(1);

        return env.getStreamGraph().getJobGraph();
    }

    /**
     * HA services factory (selected via {@code HA_MODE}).
     *
     * <p>It builds embedded HA services whose checkpoint store is a
     * {@link TestingCompletedCheckpointStore}.
     */
    public static class TestingHAFactory implements HighAvailabilityServicesFactory {
        @Override
        public HighAvailabilityServices createHAServices(
                Configuration configuration, Executor executor) {
            CheckpointRecoveryFactory checkpointRecoveryFactory =
                    PerJobCheckpointRecoveryFactory.withoutCheckpointStoreRecovery(
                            maxCheckpoints -> new TestingCompletedCheckpointStore());
            return new EmbeddedHaServicesWithLeadershipControl(executor, checkpointRecoveryFactory);
        }
    }

    /**
     * Standalone checkpoint store that retains one checkpoint.
     *
     * <p>It publishes the id and count of added checkpoints to the test's static counters.
     */
    private static class TestingCompletedCheckpointStore extends StandaloneCompletedCheckpointStore {
        TestingCompletedCheckpointStore() {
            super(1);
        }

        @Override
        public CompletedCheckpoint addCheckpointAndSubsumeOldestOne(
                CompletedCheckpoint checkpoint,
                CheckpointsCleaner checkpointsCleaner,
                Runnable postCleanup)
                throws Exception {
            CompletedCheckpoint subsumedCheckpoint =
                    super.addCheckpointAndSubsumeOldestOne(
                            checkpoint, checkpointsCleaner, postCleanup);
            lastCompletedCheckpointId.set(checkpoint.getCheckpointID());
            numCompletedCheckpoints.incrementAndGet();
            return subsumedCheckpoint;
        }
    }

    /** Exception injected by {@link FailingMapperFunction} to trigger a failover. */
    private static class TestException extends IOException {
        private static final long serialVersionUID = 1L;

        private TestException() {}
    }

    /**
     * Sink that keeps the max value per key.
     *
     * <p>On close, it publishes its map into {@link #maps} at its subtask index.
     */
    private static class ValidatingSink extends RichSinkFunction<Tuple2<Integer, Integer>>
            implements ListCheckpointed<HashMap<Integer, Integer>> {

        @SuppressWarnings("unchecked")
        private static final Map<Integer, Integer>[] maps =
                (Map<Integer, Integer>[]) new Map[NUM_OF_REGIONS];

        private final HashMap<Integer, Integer> counts = new HashMap<>();

        private ValidatingSink() {}

        @Override
        public void invoke(Tuple2<Integer, Integer> input) {
            counts.merge(input.f0, input.f1, Math::max);
        }

        @Override
        public void close() throws Exception {
            maps[getRuntimeContext().getTaskInfo().getIndexOfThisSubtask()] = counts;
        }

        @Override
        public List<HashMap<Integer, Integer>> snapshotState(long checkpointId, long timestamp)
                throws Exception {
            return Collections.singletonList(counts);
        }

        @Override
        public void restoreState(List<HashMap<Integer, Integer>> state) throws Exception {
            if (state.size() != 1) {
                throw new RuntimeException(
                        "Test failed due to unexpected recovered state size " + state.size());
            }
            counts.putAll(state.get(0));
        }
    }

    /**
     * Keyed mapper that injects failures and combines each key's value with the first value
     * seen for it.
     *
     * <p>Failures are only considered once {@code input.f1} exceeds
     * {@code FAIL_BASE * (jobFailedCnt + 1)}:
     *
     * <ul>
     *   <li>Subtask 0 throws once, while no failure has been counted yet.
     *   <li>Subtask 2 throws while fewer than {@code restartTimes} failures have been counted.
     * </ul>
     */
    private static class FailingMapperFunction
            extends RichMapFunction<Tuple2<Integer, Integer>, Tuple2<Integer, Integer>> {
        private final int restartTimes;
        private ValueState<Integer> valueState;

        FailingMapperFunction(int restartTimes) {
            this.restartTimes = restartTimes;
        }

        @Override
        public void open(OpenContext openContext) throws Exception {
            super.open(openContext);
            valueState =
                    getRuntimeContext()
                            .getState(new ValueStateDescriptor<>("value", Integer.class));
        }

        @Override
        public Tuple2<Integer, Integer> map(Tuple2<Integer, Integer> input) throws Exception {
            int indexOfThisSubtask = getRuntimeContext().getTaskInfo().getIndexOfThisSubtask();
            if (input.f1 > FAIL_BASE * (jobFailedCnt.get() + 1)) {
                if (jobFailedCnt.get() < 1 && indexOfThisSubtask == 0) {
                    jobFailedCnt.incrementAndGet();
                    throw new TestException();
                }
                if (jobFailedCnt.get() < restartTimes && indexOfThisSubtask == 2) {
                    jobFailedCnt.incrementAndGet();
                    throw new TestException();
                }
            }

            Integer value = valueState.value();
            if (value == null) {
                // First element for this key: remember it and pass it through.
                valueState.update(input.f1);
                return input;
            }
            return Tuple2.of(input.f0, value + input.f1);
        }
    }

    /**
     * Parallel source that emits {@code (index / 2, index)} for
     * {@code index in [0, numElements)}.
     *
     * <p>Each element is emitted only on the subtask that
     * {@link KeyGroupRangeAssignment#assignKeyToParallelOperator} assigns the key to, computed
     * for {@link #NUM_OF_REGIONS} parallel operators.
     *
     * <p>The source records its own state, which it verifies on restore:
     *
     * <ul>
     *   <li>Its current index in list state (subtasks other than 0).
     *   <li>Its subtask index in union list state.
     * </ul>
     */
    private static class StringGeneratingSourceFunction
            extends RichParallelSourceFunction<Tuple2<Integer, Integer>>
            implements CheckpointedFunction {
        private static final long serialVersionUID = 1L;

        private static final ListStateDescriptor<Integer> stateDescriptor =
                new ListStateDescriptor<>("list-1", Integer.class);
        private static final ListStateDescriptor<Integer> unionStateDescriptor =
                new ListStateDescriptor<>("list-2", Integer.class);

        private final long numElements;
        private final long checkpointLatestAt;

        private int index = -1;
        private int lastRegionIndex = -1;
        private volatile boolean isRunning = true;

        private ListState<Integer> listState;
        private ListState<Integer> unionListState;

        StringGeneratingSourceFunction(long numElements, long checkpointLatestAt) {
            this.numElements = numElements;
            this.checkpointLatestAt = checkpointLatestAt;
        }

        @Override
        public void run(SourceFunction.SourceContext<Tuple2<Integer, Integer>> ctx)
                throws Exception {
            if (index < 0) {
                // Nothing restored: start from the first element.
                index = 0;
            }
            int subTaskIndex = getRuntimeContext().getTaskInfo().getIndexOfThisSubtask();
            while (isRunning && index < numElements) {
                synchronized (ctx.getCheckpointLock()) {
                    int key = index / 2;
                    int forwardTaskIndex =
                            KeyGroupRangeAssignment.assignKeyToParallelOperator(
                                    key, MAX_PARALLELISM, NUM_OF_REGIONS);
                    if (forwardTaskIndex == subTaskIndex) {
                        ctx.collect(Tuple2.of(key, index));
                    }
                    index++;
                }

                if (numCompletedCheckpoints.get() < MIN_COMPLETED_CHECKPOINTS) {
                    if (index < checkpointLatestAt) {
                        // Mild slow-down while checkpoints accumulate.
                        Thread.sleep(1L);
                    } else {
                        // Past the threshold: block until enough checkpoints have completed.
                        while (isRunning
                                && numCompletedCheckpoints.get() < MIN_COMPLETED_CHECKPOINTS) {
                            Thread.sleep(300L);
                        }
                    }
                }
                if (jobFailedCnt.get() < NUM_OF_RESTARTS) {
                    // Keep slowing down until all injected failures have happened.
                    Thread.sleep(1L);
                }
            }
        }

        @Override
        public void cancel() {
            isRunning = false;
        }

        @Override
        public void snapshotState(FunctionSnapshotContext context) throws Exception {
            int indexOfThisSubtask = getRuntimeContext().getTaskInfo().getIndexOfThisSubtask();
            if (indexOfThisSubtask != 0) {
                listState.update(Collections.singletonList(index));
                if (indexOfThisSubtask == 2) {
                    lastRegionIndex = index;
                    snapshotIndicesOfSubTask.put(context.getCheckpointId(), lastRegionIndex);
                }
            }
            unionListState.update(Collections.singletonList(indexOfThisSubtask));
        }

        @Override
        public void initializeState(FunctionInitializationContext context) throws Exception {
            int indexOfThisSubtask = getRuntimeContext().getTaskInfo().getIndexOfThisSubtask();

            if (!context.isRestored()) {
                unionListState =
                        context.getOperatorStateStore().getUnionListState(unionStateDescriptor);
                if (indexOfThisSubtask != 0) {
                    listState = context.getOperatorStateStore().getListState(stateDescriptor);
                }
                return;
            }

            restoredState = true;
            unionListState = context.getOperatorStateStore().getUnionListState(unionStateDescriptor);

            // Union state is redistributed to every subtask, so it must hold all subtask indices
            // of this source.
            Set<Integer> actualIndices =
                    StreamSupport.stream(unionListState.get().spliterator(), false)
                            .collect(Collectors.toSet());
            Set<Integer> expectedIndices =
                    getRuntimeContext()
                                    .getTaskInfo()
                                    .getTaskName()
                                    .contains(SINGLE_REGION_SOURCE_NAME)
                            ? EXPECTED_INDICES_SINGLE_REGION
                            : EXPECTED_INDICES_MULTI_REGION;
            Assert.assertTrue(
                    "unexpected union list state " + actualIndices + ", expected " + expectedIndices,
                    CollectionUtils.isEqualCollection(expectedIndices, actualIndices));

            listState = context.getOperatorStateStore().getListState(stateDescriptor);
            boolean hasListState = listState.get().iterator().hasNext();
            if (indexOfThisSubtask == 0) {
                Assert.assertFalse("list state should be empty for subtask-0", hasListState);
                return;
            }

            Assert.assertTrue(
                    "list state should not be empty for subtask-" + indexOfThisSubtask,
                    hasListState);
            if (indexOfThisSubtask == 2) {
                // Subtask 2 must resume from the index it snapshotted for the last completed
                // checkpoint. Read the checkpoint id once so the check and the message agree.
                index = listState.get().iterator().next();
                long checkpointId = lastCompletedCheckpointId.get();
                Integer expectedIndex = snapshotIndicesOfSubTask.get(checkpointId);
                if (expectedIndex == null || index != expectedIndex) {
                    throw new RuntimeException(
                            "Test failed due to unexpected recovered index: "
                                    + index
                                    + ", while last completed checkpoint record index: "
                                    + expectedIndex);
                }
            }
        }
    }
}

