/*
 * Decompiled with CFR 0.152.
 */
package org.apache.flink.test.checkpointing;

import java.nio.file.Path;
import java.time.Duration;
import java.util.List;
import java.util.concurrent.CopyOnWriteArrayList;
import java.util.concurrent.CountDownLatch;
import org.apache.flink.api.common.JobID;
import org.apache.flink.api.common.functions.OpenContext;
import org.apache.flink.api.common.typeinfo.Types;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.configuration.JobManagerOptions;
import org.apache.flink.core.execution.SavepointFormatType;
import org.apache.flink.runtime.jobgraph.JobGraph;
import org.apache.flink.runtime.jobgraph.JobVertex;
import org.apache.flink.runtime.jobgraph.JobVertexID;
import org.apache.flink.runtime.jobgraph.SavepointRestoreSettings;
import org.apache.flink.runtime.minicluster.MiniCluster;
import org.apache.flink.runtime.minicluster.MiniClusterConfiguration;
import org.apache.flink.runtime.state.StateSnapshotContext;
import org.apache.flink.runtime.testutils.CommonTestUtils;
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
import org.apache.flink.streaming.api.functions.sink.legacy.RichSinkFunction;
import org.apache.flink.streaming.api.functions.sink.legacy.SinkFunction;
import org.apache.flink.streaming.api.functions.source.legacy.RichSourceFunction;
import org.apache.flink.streaming.api.functions.source.legacy.SourceFunction;
import org.apache.flink.streaming.api.graph.StreamGraph;
import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
import org.apache.flink.streaming.util.RestartStrategyUtils;
import org.apache.flink.test.util.AbstractTestBaseJUnit4;
import org.apache.flink.testutils.junit.SharedObjectsExtension;
import org.apache.flink.testutils.junit.SharedReference;
import org.assertj.core.api.Assertions;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.RegisterExtension;
import org.junit.jupiter.api.io.TempDir;

public class CheckpointAfterAllTasksFinishedITCase
extends AbstractTestBaseJUnit4 {
    private static final int SMALL_SOURCE_NUM_RECORDS = 20;
    private static final int BIG_SOURCE_NUM_RECORDS = 100;
    private StreamExecutionEnvironment env;
    private SharedReference<List<Integer>> smallResult;
    private SharedReference<List<Integer>> bigResult;
    @TempDir
    private Path tmpDir;
    @RegisterExtension
    private final SharedObjectsExtension sharedObjects = SharedObjectsExtension.create();

    @BeforeEach
    public void setUp() {
        this.env = StreamExecutionEnvironment.getExecutionEnvironment();
        this.env.setParallelism(4);
        this.smallResult = this.sharedObjects.add(new CopyOnWriteArrayList());
        this.bigResult = this.sharedObjects.add(new CopyOnWriteArrayList());
        IntegerStreamSource.failedBefore = false;
        IntegerStreamSource.latch = new CountDownLatch(1);
    }

    @Test
    public void testImmediateCheckpointing() throws Exception {
        RestartStrategyUtils.configureNoRestartStrategy((StreamExecutionEnvironment)this.env);
        this.env.enableCheckpointing(Duration.ofNanos(Long.MAX_VALUE).toMillis());
        StreamGraph streamGraph = this.getStreamGraph(this.env, false, false);
        this.env.execute(streamGraph);
        Assertions.assertThat((int)((List)this.smallResult.get()).size()).isEqualTo(20);
        Assertions.assertThat((int)((List)this.bigResult.get()).size()).isEqualTo(100);
    }

    @Test
    public void testRestoreAfterSomeTasksFinished() throws Exception {
        MiniClusterConfiguration cfg = new MiniClusterConfiguration.Builder().withRandomPorts().setNumTaskManagers(1).setNumSlotsPerTaskManager(4).build();
        try (MiniCluster miniCluster = new MiniCluster(cfg);){
            miniCluster.start();
            RestartStrategyUtils.configureNoRestartStrategy((StreamExecutionEnvironment)this.env);
            this.env.enableCheckpointing(100L);
            JobGraph jobGraph = this.getStreamGraph(this.env, true, false).getJobGraph();
            miniCluster.submitJob(jobGraph).get();
            CommonTestUtils.waitForSubtasksToFinish((MiniCluster)miniCluster, (JobID)jobGraph.getJobID(), (JobVertexID)this.findVertexByName(jobGraph, "passA -> Sink: sinkA").getID(), (boolean)false);
            String savepointPath = (String)miniCluster.triggerSavepoint(jobGraph.getJobID(), this.tmpDir.toFile().getAbsolutePath(), true, SavepointFormatType.CANONICAL).get();
            ((List)this.bigResult.get()).clear();
            this.env.enableCheckpointing(Duration.ofNanos(Long.MAX_VALUE).toMillis());
            JobGraph restoredJobGraph = this.getStreamGraph(this.env, true, false).getJobGraph();
            restoredJobGraph.setSavepointRestoreSettings(SavepointRestoreSettings.forPath((String)savepointPath, (boolean)false));
            miniCluster.submitJob(restoredJobGraph).get();
            IntegerStreamSource.latch.countDown();
            miniCluster.requestJobResult(restoredJobGraph.getJobID()).get();
            Assertions.assertThat((int)((List)this.smallResult.get()).size()).isEqualTo(20);
            Assertions.assertThat((int)((List)this.bigResult.get()).size()).isEqualTo(100);
        }
    }

    @Test
    public void testFailoverAfterSomeTasksFinished() throws Exception {
        Configuration config = new Configuration();
        config.set(JobManagerOptions.EXECUTION_FAILOVER_STRATEGY, (Object)"full");
        MiniClusterConfiguration cfg = new MiniClusterConfiguration.Builder().withRandomPorts().setNumTaskManagers(1).setNumSlotsPerTaskManager(4).setConfiguration(config).build();
        try (MiniCluster miniCluster = new MiniCluster(cfg);){
            miniCluster.start();
            this.env.enableCheckpointing(100L);
            JobGraph jobGraph = this.getStreamGraph(this.env, true, true).getJobGraph();
            miniCluster.submitJob(jobGraph).get();
            CommonTestUtils.waitForSubtasksToFinish((MiniCluster)miniCluster, (JobID)jobGraph.getJobID(), (JobVertexID)this.findVertexByName(jobGraph, "passA -> Sink: sinkA").getID(), (boolean)true);
            ((List)this.bigResult.get()).clear();
            IntegerStreamSource.latch.countDown();
            miniCluster.requestJobResult(jobGraph.getJobID()).get();
            Assertions.assertThat((int)((List)this.smallResult.get()).size()).isIn(new Object[]{20, 40});
            Assertions.assertThat((int)((List)this.bigResult.get()).size()).isEqualTo(100);
        }
    }

    private StreamGraph getStreamGraph(StreamExecutionEnvironment env, boolean block, boolean needFailover) {
        env.addSource((SourceFunction)new IntegerStreamSource(20, false, false)).transform("passA", Types.INT, (OneInputStreamOperator)new PassThroughOperator()).addSink((SinkFunction)new CollectSink(this.smallResult)).name("sinkA");
        env.addSource((SourceFunction)new IntegerStreamSource(100, block, needFailover)).transform("passB", Types.INT, (OneInputStreamOperator)new PassThroughOperator()).addSink((SinkFunction)new CollectSink(this.bigResult)).name("sinkB");
        return env.getStreamGraph();
    }

    private JobVertex findVertexByName(JobGraph jobGraph, String vertexName) {
        for (JobVertex vertex : jobGraph.getVerticesAsArray()) {
            if (!vertex.getName().equals(vertexName)) continue;
            return vertex;
        }
        return null;
    }

    private static class CollectSink
    extends RichSinkFunction<Integer> {
        private final SharedReference<List<Integer>> result;

        public CollectSink(SharedReference<List<Integer>> result) {
            this.result = result;
        }

        public void open(OpenContext openContext) throws Exception {
            super.open(openContext);
        }

        public void invoke(Integer value, SinkFunction.Context context) throws Exception {
            ((List)this.result.get()).add(value);
        }
    }

    private static class PassThroughOperator
    extends AbstractStreamOperator<Integer>
    implements OneInputStreamOperator<Integer, Integer> {
        private long checkpointID;

        private PassThroughOperator() {
        }

        public void snapshotState(StateSnapshotContext context) throws Exception {
            super.snapshotState(context);
            this.checkpointID = context.getCheckpointId();
        }

        public void processElement(StreamRecord<Integer> element) throws Exception {
            this.output.collect(element);
        }

        public void close() throws Exception {
            super.close();
            Assertions.assertThat((long)this.checkpointID).isGreaterThan(0L);
        }
    }

    private static final class IntegerStreamSource
    extends RichSourceFunction<Integer> {
        private static final long serialVersionUID = 1L;
        private static CountDownLatch latch;
        private static boolean failedBefore;
        private final int numRecords;
        private boolean block;
        private volatile boolean running;
        private int emittedCount;
        private boolean needFailover;

        public IntegerStreamSource(int numRecords, boolean block, boolean needFailover) {
            this.numRecords = numRecords;
            this.running = true;
            this.block = block;
            this.needFailover = needFailover;
            this.emittedCount = 0;
        }

        /*
         * WARNING - Removed try catching itself - possible behaviour change.
         */
        public void run(SourceFunction.SourceContext<Integer> ctx) throws Exception {
            while (this.running && this.emittedCount < this.numRecords) {
                Object object = ctx.getCheckpointLock();
                synchronized (object) {
                    ctx.collect((Object)this.emittedCount);
                }
                ++this.emittedCount;
            }
            if (this.block && latch != null) {
                latch.await();
            }
            if (this.needFailover && !failedBefore) {
                failedBefore = true;
                throw new RuntimeException("forced failure");
            }
        }

        public void cancel() {
            this.running = false;
        }
    }
}

