package org.apache.flink.test.checkpointing;

import java.time.Duration;
import java.util.Collections;
import java.util.Iterator;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicLong;
import java.util.concurrent.locks.LockSupport;
import org.apache.flink.api.common.functions.RichMapFunction;
import org.apache.flink.api.common.restartstrategy.RestartStrategies;
import org.apache.flink.api.common.state.ListState;
import org.apache.flink.api.common.state.ListStateDescriptor;
import org.apache.flink.api.common.typeinfo.Types;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.configuration.MemorySize;
import org.apache.flink.configuration.TaskManagerOptions;
import org.apache.flink.core.execution.CheckpointingMode;
import org.apache.flink.core.testutils.OneShotLatch;
import org.apache.flink.runtime.operators.testutils.ExpectedTestException;
import org.apache.flink.runtime.state.FunctionInitializationContext;
import org.apache.flink.runtime.state.FunctionSnapshotContext;
import org.apache.flink.runtime.testutils.MiniClusterResourceConfiguration;
import org.apache.flink.streaming.api.checkpoint.CheckpointedFunction;
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
import org.apache.flink.streaming.api.functions.sink.SinkFunction;
import org.apache.flink.streaming.api.functions.source.SourceFunction;
import org.apache.flink.test.util.MiniClusterWithClientResource;
import org.apache.flink.testutils.junit.SharedObjects;
import org.apache.flink.testutils.junit.SharedReference;
import org.apache.flink.util.TestLogger;
import org.hamcrest.MatcherAssert;
import org.hamcrest.Matchers;
import org.junit.Assert;
import org.junit.ClassRule;
import org.junit.Rule;
import org.junit.Test;

/* loaded from: input_file:org/apache/flink/test/checkpointing/IgnoreInFlightDataITCase.class */
public class IgnoreInFlightDataITCase extends TestLogger {

    // Mini cluster shared by all tests of this class: 2 task managers with 2 slots each,
    // using the configuration built by getConfiguration().
    @ClassRule
    public static final MiniClusterWithClientResource CLUSTER = new MiniClusterWithClientResource(new MiniClusterResourceConfiguration.Builder().setConfiguration(getConfiguration()).setNumberTaskManagers(2).setNumberSlotsPerTaskManager(2).build());
    // Job parallelism; NumberSource also uses it as the size of the batch it emits on a fresh start.
    private static final int PARALLELISM = 3;
    // Triggered by SumFailSink when it snapshots checkpoint 1; awaited by SlowMap subtasks with index > 0.
    private SharedReference<OneShotLatch> checkpointReachSinkLatch;
    // Copy of 'result' taken by the sink when it snapshots checkpoint 1; the sink resets 'result' to it
    // in initializeState().
    private SharedReference<AtomicLong> resultBeforeFail;
    // Running sum of every value the sink has received.
    private SharedReference<AtomicLong> result;
    // Value held in the source's list state when the source snapshotted a checkpoint with id below 2.
    private SharedReference<AtomicInteger> lastCheckpointValue;

    // Rule that hands out the SharedReference handles above (see setupSharedObjects()).
    @Rule
    public final SharedObjects sharedObjects = SharedObjects.create();
    // Doubled at the start of every execution attempt and then passed to enableCheckpointing(),
    // so the first attempt runs with an interval of 10.
    private int checkpointInterval = 5;

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:org/apache/flink/test/checkpointing/IgnoreInFlightDataITCase$NumberSource.class */
    public static class NumberSource implements SourceFunction<Integer>, CheckpointedFunction {
        private static final long serialVersionUID = 1;
        private final SharedReference<AtomicInteger> lastCheckpointValue;
        private ListState<Integer> valueState;
        private volatile boolean isRunning = true;

        public NumberSource(SharedReference<AtomicInteger> sharedReference) {
            this.lastCheckpointValue = sharedReference;
        }

        public void run(SourceFunction.SourceContext<Integer> sourceContext) throws Exception {
            Iterator it = ((Iterable) this.valueState.get()).iterator();
            if (it.hasNext()) {
                synchronized (sourceContext.getCheckpointLock()) {
                    Integer num = (Integer) it.next();
                    Assert.assertEquals(((AtomicInteger) this.lastCheckpointValue.get()).intValue(), num.intValue());
                    sourceContext.collect(Integer.valueOf(num.intValue() + 1));
                }
                return;
            }
            int i = 0;
            synchronized (sourceContext.getCheckpointLock()) {
                do {
                    i++;
                    this.valueState.update(Collections.singletonList(Integer.valueOf(i)));
                    sourceContext.collect(Integer.valueOf(i));
                } while (i < IgnoreInFlightDataITCase.PARALLELISM);
            }
            while (this.isRunning) {
                LockSupport.parkNanos(100000L);
            }
        }

        public void cancel() {
            this.isRunning = false;
        }

        public void snapshotState(FunctionSnapshotContext functionSnapshotContext) throws Exception {
            Iterator it = ((Iterable) this.valueState.get()).iterator();
            if (!it.hasNext() || ((Integer) it.next()).intValue() < IgnoreInFlightDataITCase.PARALLELISM || (functionSnapshotContext.getCheckpointId() > 1 && ((AtomicInteger) this.lastCheckpointValue.get()).get() < IgnoreInFlightDataITCase.PARALLELISM)) {
                throw new RuntimeException("Not enough data to guarantee the in-flight data were generated before the first checkpoint");
            }
            if (functionSnapshotContext.getCheckpointId() > 2) {
                return;
            }
            if (functionSnapshotContext.getCheckpointId() == 2) {
                throw new ExpectedTestException("The planned fail on the second checkpoint");
            }
            ((AtomicInteger) this.lastCheckpointValue.get()).set(((Integer) ((Iterable) this.valueState.get()).iterator().next()).intValue());
        }

        public void initializeState(FunctionInitializationContext functionInitializationContext) throws Exception {
            this.valueState = functionInitializationContext.getOperatorStateStore().getListState(new ListStateDescriptor("state", Types.INT));
        }
    }

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:org/apache/flink/test/checkpointing/IgnoreInFlightDataITCase$SlowMap.class */
    public static class SlowMap extends RichMapFunction<Integer, Integer> {
        private final SharedReference<OneShotLatch> checkpointReachSinkLatch;

        public SlowMap(SharedReference<OneShotLatch> sharedReference) {
            this.checkpointReachSinkLatch = sharedReference;
        }

        public Integer map(Integer num) throws Exception {
            if (getRuntimeContext().getTaskInfo().getIndexOfThisSubtask() > 0) {
                ((OneShotLatch) this.checkpointReachSinkLatch.get()).await();
            }
            return num;
        }
    }

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:org/apache/flink/test/checkpointing/IgnoreInFlightDataITCase$SumFailSink.class */
    public static class SumFailSink implements SinkFunction<Integer>, CheckpointedFunction {
        private final SharedReference<OneShotLatch> checkpointReachSinkLatch;
        private final SharedReference<AtomicLong> resultBeforeFail;
        private final SharedReference<AtomicLong> result;

        public SumFailSink(SharedReference<OneShotLatch> sharedReference, SharedReference<AtomicLong> sharedReference2, SharedReference<AtomicLong> sharedReference3) {
            this.checkpointReachSinkLatch = sharedReference;
            this.resultBeforeFail = sharedReference2;
            this.result = sharedReference3;
        }

        public void invoke(Integer num, SinkFunction.Context context) throws Exception {
            ((AtomicLong) this.result.get()).addAndGet(num.intValue());
        }

        public void snapshotState(FunctionSnapshotContext functionSnapshotContext) throws Exception {
            if (functionSnapshotContext.getCheckpointId() == 1) {
                ((AtomicLong) this.resultBeforeFail.get()).set(((AtomicLong) this.result.get()).longValue());
                sinkCheckpointStarted();
            }
        }

        public void initializeState(FunctionInitializationContext functionInitializationContext) throws Exception {
            ((AtomicLong) this.result.get()).set(((AtomicLong) this.resultBeforeFail.get()).longValue());
        }

        public void sinkCheckpointStarted() {
            ((OneShotLatch) this.checkpointReachSinkLatch.get()).trigger();
        }
    }

    /**
     * Builds the configuration for the shared mini cluster.
     *
     * @return a configuration that sets the task manager managed memory size to {@code 48m}
     */
    private static Configuration getConfiguration() {
        final MemorySize managedMemory = MemorySize.parse("48m");
        final Configuration clusterConfig = new Configuration();
        clusterConfig.set(TaskManagerOptions.MANAGED_MEMORY_SIZE, managedMemory);
        return clusterConfig;
    }

    /**
     * Registers a fresh latch and fresh zero-valued counters with the {@code sharedObjects} rule
     * and stores the returned references in the corresponding fields.
     */
    public void setupSharedObjects() {
        final SharedObjects registry = sharedObjects;
        checkpointReachSinkLatch = registry.add(new OneShotLatch());
        resultBeforeFail = registry.add(new AtomicLong());
        result = registry.add(new AtomicLong());
        lastCheckpointValue = registry.add(new AtomicInteger());
    }

    /**
     * Repeats {@link #executeIgnoreInFlightDataDuringRecovery()} until one attempt reports
     * success. Every attempt doubles the checkpoint interval, and the number of attempts is not
     * bounded.
     */
    @Test
    public void testIgnoreInFlightDataDuringRecovery() {
        boolean succeeded = executeIgnoreInFlightDataDuringRecovery();
        while (!succeeded) {
            succeeded = executeIgnoreInFlightDataDuringRecovery();
        }
    }

    /**
     * Runs one attempt of the scenario: source -> slow map -> single sink, with unaligned
     * checkpoints, in-flight data of checkpoint 1 configured to be ignored, and a single allowed
     * restart.
     *
     * <p>The checkpoint interval is doubled before every attempt.
     *
     * <p>NOTE(review): any {@link Exception} raised inside the {@code try} is logged and turned
     * into a {@code false} result, so unrelated failures are retried by the caller as well.
     *
     * @return {@code true} if the job ran and the assertions passed, {@code false} if an
     *     exception was caught
     */
    private boolean executeIgnoreInFlightDataDuringRecovery() {
        setupSharedObjects();

        final StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
        env.setParallelism(PARALLELISM);

        // Widen the interval for this attempt (the first attempt ends up with 10).
        checkpointInterval *= 2;
        env.enableCheckpointing(checkpointInterval);
        env.disableOperatorChaining();

        env.getCheckpointConfig().enableUnalignedCheckpoints();
        env.getCheckpointConfig().setAlignmentTimeout(Duration.ZERO);
        env.getCheckpointConfig().setCheckpointingConsistencyMode(CheckpointingMode.EXACTLY_ONCE);
        env.getCheckpointConfig().setCheckpointIdOfIgnoredInFlightData(1L);
        env.setRestartStrategy(RestartStrategies.fixedDelayRestart(1, 0L));

        final NumberSource source = new NumberSource(lastCheckpointValue);
        final SlowMap slowMap = new SlowMap(checkpointReachSinkLatch);
        final SumFailSink sink = new SumFailSink(checkpointReachSinkLatch, resultBeforeFail, result);
        env.addSource(source)
                .map(slowMap)
                .addSink(sink)
                .setParallelism(1);

        try {
            env.execute("Total sum");
            verifyResultAfterRecovery();
            return true;
        } catch (Exception e) {
            log.error("Execution failed", e);
            return false;
        }
    }

    /**
     * Checks the sink's final sum against the values recorded during the run: it must be
     * strictly below the sum of {@code 0..n} and equal to the sum captured at checkpoint 1
     * plus {@code n}, where {@code n} is the recorded checkpoint value plus one.
     */
    private void verifyResultAfterRecovery() {
        final int valueAfterRecovery = lastCheckpointValue.get().intValue() + 1;

        // Sum of every integer up to and including the value emitted after recovery.
        long sumOfAllValues = 0;
        for (int value = 1; value <= valueAfterRecovery; value++) {
            sumOfAllValues += value;
        }

        final long actualSum = result.get().longValue();
        MatcherAssert.assertThat(Long.valueOf(actualSum), Matchers.lessThan(Long.valueOf(sumOfAllValues)));

        final long expectedSum = resultBeforeFail.get().longValue() + valueAfterRecovery;
        Assert.assertEquals(expectedSum, actualSum);
    }
}
