package org.apache.flink.test.checkpointing;

import java.io.IOException;
import java.util.Iterator;
import java.util.Map;
import java.util.Random;
import java.util.concurrent.ConcurrentHashMap;
import org.apache.flink.api.common.functions.RichMapFunction;
import org.apache.flink.api.common.state.OperatorState;
import org.apache.flink.api.java.functions.KeySelector;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.streaming.api.checkpoint.Checkpointed;
import org.apache.flink.streaming.api.datastream.DataStream;
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
import org.apache.flink.streaming.api.functions.sink.RichSinkFunction;
import org.apache.flink.streaming.api.functions.source.RichParallelSourceFunction;
import org.apache.flink.streaming.api.functions.source.SourceFunction;
import org.junit.Assert;

/* loaded from: input_file:org/apache/flink/test/checkpointing/PartitionedStateCheckpointingITCase.class */
public class PartitionedStateCheckpointingITCase extends StreamFaultToleranceTestBase {
    /**
     * Element-count constant with the value 10,000,000.
     *
     * <p>NOTE(review): this field is not referenced by name in the code shown here;
     * the literals {@code 10000000L} / {@code 5000000L} used when the job is built
     * and the expected values checked after the run look like inlined copies of it
     * (and of half of it) -- confirm, and keep them in sync if this value changes.
     */
    final long NUM_STRINGS = 10000000;

    /* loaded from: input_file:org/apache/flink/test/checkpointing/PartitionedStateCheckpointingITCase$CounterSink.class */
    /**
     * Sink that maintains two keyed counters ("a" and "b") and records the most
     * recent count per key in {@link #allCounts}.
     *
     * <p>For every incoming record the two counters are first asserted to be equal,
     * then both are advanced by one and the new value is stored under the record's
     * first tuple field.
     */
    private static class CounterSink extends RichSinkFunction<Tuple2<Integer, Long>> {
        /** Latest count per key; static, hence shared by every instance of this sink. */
        private static Map<Integer, Long> allCounts = new ConcurrentHashMap<>();

        /** Counter held in a {@link NonSerializableLong} wrapper. */
        private OperatorState<NonSerializableLong> aCounts;

        /** Counter held as a plain boxed {@code Long}. */
        private OperatorState<Long> bCounts;

        private CounterSink() {
        }

        /**
         * Looks up the two key/value states, both with an initial value of zero.
         *
         * @param configuration unused
         * @throws IOException declared for the state lookups
         */
        public void open(Configuration configuration) throws IOException {
            aCounts = getRuntimeContext().getKeyValueState("a", NonSerializableLong.class, NonSerializableLong.of(0L));
            bCounts = getRuntimeContext().getKeyValueState("b", Long.class, 0L);
        }

        /**
         * Checks that both counters agree, increments them and publishes the result.
         *
         * @param tuple2 incoming record; only field {@code f0} is read, as the map key
         * @throws Exception if a state access fails
         */
        public void invoke(Tuple2<Integer, Long> tuple2) throws Exception {
            final long previousA = aCounts.value().value.longValue();
            final long previousB = bCounts.value().longValue();
            // The two states are always updated together, so they must never diverge.
            Assert.assertEquals(previousA, previousB);

            final long next = previousA + 1;
            aCounts.update(NonSerializableLong.of(next));
            bCounts.update(Long.valueOf(next));
            allCounts.put(tuple2.f0, Long.valueOf(next));
        }
    }

    /* loaded from: input_file:org/apache/flink/test/checkpointing/PartitionedStateCheckpointingITCase$IdentityKeySelector.class */
    /**
     * Key selector that uses each element as its own key.
     *
     * @param <T> type of both the element and the key
     */
    private static class IdentityKeySelector<T> implements KeySelector<T, T> {
        private IdentityKeySelector() {
        }

        /**
         * Returns the element unchanged.
         *
         * @param element the element to extract the key from
         * @return {@code element} itself
         */
        public T getKey(T element) throws Exception {
            return element;
        }
    }

    /* loaded from: input_file:org/apache/flink/test/checkpointing/PartitionedStateCheckpointingITCase$IntGeneratingSourceFunction.class */
    /**
     * Parallel source that emits {@code index % 40} for a running {@code index}
     * which is advanced by the number of parallel subtasks before every emission.
     *
     * <p>The current {@code index} is the checkpointed state: it is returned by
     * {@link #snapshotState(long, long)} and reinstated by
     * {@link #restoreState(Integer)}.
     */
    private static class IntGeneratingSourceFunction extends RichParallelSourceFunction<Integer> implements Checkpointed<Integer> {
        /**
         * Final {@code index} of each subtask, written in {@link #close()} and
         * indexed by subtask index.
         * NOTE(review): the fixed length of 6 presumably matches the test
         * parallelism -- a larger parallelism would overflow this array; confirm.
         */
        static final long[] counts = new long[6];

        /** Upper bound: emission stops once {@code index} reaches this value. */
        private final long numElements;

        /** Running position; doubles as the checkpointed state. */
        private int index;

        /** Increment applied per emission; set to the number of parallel subtasks in {@link #open}. */
        private int step;

        /** Cleared by {@link #cancel()} to stop the emit loop. */
        private volatile boolean isRunning = true;

        /**
         * @param j the bound at which this source stops emitting
         */
        IntGeneratingSourceFunction(long j) {
            this.numElements = j;
        }

        /**
         * Sets the step width and, if {@code index} is still 0, starts at this
         * subtask's index.
         *
         * @param configuration unused
         */
        public void open(Configuration configuration) throws IOException {
            this.step = getRuntimeContext().getNumberOfParallelSubtasks();
            // An index of 0 is treated as "nothing restored" -- presumably to avoid
            // overwriting a position reinstated by restoreState().
            if (this.index == 0) {
                this.index = getRuntimeContext().getIndexOfThisSubtask();
            }
        }

        /** Records this subtask's final {@code index} in {@link #counts}. */
        public void close() throws IOException {
            counts[getRuntimeContext().getIndexOfThisSubtask()] = this.index;
        }

        /**
         * Emits elements until cancelled or until {@code index} reaches
         * {@code numElements}. The index update and the emission happen together
         * while holding the checkpoint lock obtained from the context.
         *
         * @param sourceContext context used to emit elements and obtain the lock
         */
        public void run(SourceFunction.SourceContext<Integer> sourceContext) throws Exception {
            final Object checkpointLock = sourceContext.getCheckpointLock();
            while (this.isRunning && this.index < this.numElements) {
                synchronized (checkpointLock) {
                    this.index += this.step;
                    sourceContext.collect(Integer.valueOf(this.index % 40));
                }
            }
        }

        /** Requests the emit loop in {@link #run} to stop. */
        public void cancel() {
            this.isRunning = false;
        }

        /**
         * Returns the current {@code index} as the state to checkpoint.
         *
         * <p>The method carries the name required by {@code Checkpointed}; the
         * decompiler-generated name {@code m543snapshotState} left the interface
         * method unimplemented.
         *
         * @param checkpointId unused
         * @param checkpointTimestamp unused
         * @return the current index
         */
        public Integer snapshotState(long checkpointId, long checkpointTimestamp) {
            return Integer.valueOf(this.index);
        }

        /**
         * Reinstates a previously snapshotted {@code index}.
         *
         * @param num the index returned by an earlier {@link #snapshotState(long, long)}
         */
        public void restoreState(Integer num) {
            this.index = num.intValue();
        }
    }

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:org/apache/flink/test/checkpointing/PartitionedStateCheckpointingITCase$NonSerializableLong.class */
    /**
     * Minimal holder for a boxed {@code Long} that does not implement
     * {@code java.io.Serializable}. Instances are created through {@link #of(long)}.
     */
    public static class NonSerializableLong {
        /** The wrapped value. */
        public Long value;

        private NonSerializableLong(long initial) {
            value = initial;
        }

        /**
         * Creates a new holder.
         *
         * @param wrapped the number to wrap
         * @return a fresh instance whose {@code value} equals {@code wrapped}
         */
        public static NonSerializableLong of(long wrapped) {
            return new NonSerializableLong(wrapped);
        }
    }

    /* loaded from: input_file:org/apache/flink/test/checkpointing/PartitionedStateCheckpointingITCase$OnceFailingPartitionedSum.class */
    /**
     * Map function that keeps a keyed running sum of its inputs and throws exactly
     * one artificial exception (across all instances) once an instance has seen
     * {@code failurePos} records.
     *
     * <p>The latest sum per key is also stored in {@link #allSums}.
     */
    private static class OnceFailingPartitionedSum extends RichMapFunction<Integer, Tuple2<Integer, Long>> {
        /** Latest sum per key; static, hence shared by every instance. */
        private static Map<Integer, Long> allSums = new ConcurrentHashMap<>();

        /**
         * Set once the artificial failure has been thrown. The check-then-set in
         * {@link #map} is not atomic, so this is a best-effort "fail once" flag.
         */
        private static volatile boolean hasFailed = false;

        /** Element count used to position the failure. */
        private final long numElements;

        /** Record count at which this instance throws, if no failure happened yet. */
        private long failurePos;

        /** Records seen by this instance since {@link #open}. */
        private long count;

        /** Keyed running sum. */
        private OperatorState<Long> sum;

        /**
         * @param j element count used to derive the failure position
         */
        OnceFailingPartitionedSum(long j) {
            this.numElements = j;
        }

        /**
         * Picks a random failure position in
         * {@code [0.4 * numElements / parallelism, 0.7 * numElements / parallelism)},
         * resets the record counter and looks up the keyed sum state.
         *
         * @param configuration unused
         * @throws IOException declared for the state lookup
         */
        public void open(Configuration configuration) throws IOException {
            final int parallelism = getRuntimeContext().getNumberOfParallelSubtasks();
            final long failurePosMin = (long) ((0.4d * this.numElements) / parallelism);
            final long failurePosMax = (long) ((0.7d * this.numElements) / parallelism);
            final long range = failurePosMax - failurePosMin;

            long offset = 0L;
            if (range > 0) {
                // Java's % keeps the sign of the dividend and nextLong() may be
                // negative; shift a negative remainder back into [0, range) so the
                // failure position never drops below failurePosMin.
                offset = new Random().nextLong() % range;
                if (offset < 0) {
                    offset += range;
                }
            }
            // With an empty range (very few elements per subtask) fall back to the
            // lower bound instead of dividing by zero.
            this.failurePos = failurePosMin + offset;

            this.count = 0L;
            this.sum = getRuntimeContext().getKeyValueState("my_state", Long.class, 0L);
        }

        /**
         * Adds {@code num} to the keyed sum and emits {@code (num, newSum)}.
         *
         * @param num the incoming value, also used as the key of {@link #allSums}
         * @return the input together with the updated sum
         * @throws Exception the one-time artificial "Test Failure", or a state access error
         */
        public Tuple2<Integer, Long> map(Integer num) throws Exception {
            this.count++;
            if (!hasFailed && this.count >= this.failurePos) {
                hasFailed = true;
                throw new Exception("Test Failure");
            }

            final long updatedSum = this.sum.value().longValue() + num.intValue();
            this.sum.update(Long.valueOf(updatedSum));
            allSums.put(num, Long.valueOf(updatedSum));
            return new Tuple2<>(num, Long.valueOf(updatedSum));
        }
    }

    /**
     * Builds the job under test: two integer sources are unioned, keyed by value,
     * summed per key by a map function that fails once, keyed again by tuple
     * field 0 and counted in the sink.
     *
     * @param streamExecutionEnvironment the environment the job is defined on
     */
    @Override // org.apache.flink.test.checkpointing.StreamFaultToleranceTestBase
    public void testProgram(StreamExecutionEnvironment streamExecutionEnvironment) {
        // Same value as the NUM_STRINGS field; kept local so all three element
        // counts below are derived from one number.
        final long totalElements = 10000000L;
        final long elementsPerSource = totalElements / 2;

        // Constant-true at runtime: 10,000,000 is divisible by the 40 distinct
        // values the sources emit.
        Assert.assertTrue("Broken test setup", totalElements % 40 == 0);

        DataStream<Integer> firstSource =
                streamExecutionEnvironment.addSource(new IntGeneratingSourceFunction(elementsPerSource));
        DataStream<Integer> secondSource =
                streamExecutionEnvironment.addSource(new IntGeneratingSourceFunction(elementsPerSource));

        firstSource.union(secondSource)
                .keyBy(new IdentityKeySelector<Integer>())
                .map(new OnceFailingPartitionedSum(totalElements))
                .keyBy(0)
                .addSink(new CounterSink());
    }

    /**
     * Verifies the results collected in the static maps after the job has run:
     * every key's sum must equal {@code key * 10,000,000 / 40}, every key's count
     * must equal {@code 10,000,000 / 40}, and both maps must hold exactly 40 keys.
     */
    @Override // org.apache.flink.test.checkpointing.StreamFaultToleranceTestBase
    public void postSubmit() {
        // Same value as the NUM_STRINGS field; long arithmetic throughout so that
        // key * totalElements cannot overflow an int.
        final long totalElements = 10000000L;
        final long distinctKeys = 40L;
        final long elementsPerKey = totalElements / distinctKeys;

        for (Map.Entry<Integer, Long> sumEntry : OnceFailingPartitionedSum.allSums.entrySet()) {
            final long expectedSum = sumEntry.getKey().longValue() * totalElements / distinctKeys;
            Assert.assertEquals(Long.valueOf(expectedSum), sumEntry.getValue());
        }

        for (Long observedCount : CounterSink.allCounts.values()) {
            Assert.assertEquals(Long.valueOf(elementsPerKey), observedCount);
        }

        Assert.assertEquals(distinctKeys, CounterSink.allCounts.size());
        Assert.assertEquals(distinctKeys, OnceFailingPartitionedSum.allSums.size());
    }
}
