/*
 * Decompiled with CFR 0.152.
 */
package org.apache.flink.test.checkpointing;

import java.io.IOException;
import java.util.Map;
import java.util.Random;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicBoolean;
import org.apache.flink.api.common.functions.MapFunction;
import org.apache.flink.api.common.functions.RichMapFunction;
import org.apache.flink.api.common.state.OperatorState;
import org.apache.flink.api.java.functions.KeySelector;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.streaming.api.checkpoint.Checkpointed;
import org.apache.flink.streaming.api.datastream.DataStream;
import org.apache.flink.streaming.api.datastream.DataStreamSource;
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
import org.apache.flink.streaming.api.functions.sink.RichSinkFunction;
import org.apache.flink.streaming.api.functions.sink.SinkFunction;
import org.apache.flink.streaming.api.functions.source.RichParallelSourceFunction;
import org.apache.flink.streaming.api.functions.source.SourceFunction;
import org.apache.flink.test.checkpointing.StreamFaultToleranceTestBase;
import org.junit.Assert;

public class PartitionedStateCheckpointingITCase
extends StreamFaultToleranceTestBase {
    final long NUM_STRINGS = 10000000L;

    @Override
    public void testProgram(StreamExecutionEnvironment env) {
        Assert.assertTrue((String)"Broken test setup", (boolean)true);
        DataStreamSource stream1 = env.addSource((SourceFunction)new IntGeneratingSourceFunction(5000000L));
        DataStreamSource stream2 = env.addSource((SourceFunction)new IntGeneratingSourceFunction(5000000L));
        stream1.union(new DataStream[]{stream2}).keyBy(new IdentityKeySelector()).map((MapFunction)new OnceFailingPartitionedSum(10000000L)).keyBy(new int[]{0}).addSink((SinkFunction)new CounterSink());
    }

    @Override
    public void postSubmit() {
        for (Map.Entry sum : OnceFailingPartitionedSum.allSums.entrySet()) {
            Assert.assertEquals((Object)new Long((long)((Integer)sum.getKey()).intValue() * 10000000L / 40L), sum.getValue());
        }
        for (Long count : CounterSink.allCounts.values()) {
            Assert.assertEquals((Object)new Long(250000L), (Object)count);
        }
        Assert.assertEquals((long)40L, (long)CounterSink.allCounts.size());
        Assert.assertEquals((long)40L, (long)OnceFailingPartitionedSum.allSums.size());
    }

    private static class IdentityKeySelector<T>
    implements KeySelector<T, T> {
        private IdentityKeySelector() {
        }

        public T getKey(T value) throws Exception {
            return value;
        }
    }

    private static class NonSerializableLong {
        public Long value;

        private NonSerializableLong(long value) {
            this.value = value;
        }

        public static NonSerializableLong of(long value) {
            return new NonSerializableLong(value);
        }
    }

    private static class CounterSink
    extends RichSinkFunction<Tuple2<Integer, Long>> {
        private static Map<Integer, Long> allCounts = new ConcurrentHashMap<Integer, Long>();
        private OperatorState<NonSerializableLong> aCounts;
        private OperatorState<Long> bCounts;

        private CounterSink() {
        }

        public void open(Configuration parameters) throws IOException {
            this.aCounts = this.getRuntimeContext().getKeyValueState("a", NonSerializableLong.class, (Object)NonSerializableLong.of(0L));
            this.bCounts = this.getRuntimeContext().getKeyValueState("b", Long.class, (Object)0L);
        }

        public void invoke(Tuple2<Integer, Long> value) throws Exception {
            long ac = ((NonSerializableLong)this.aCounts.value()).value;
            long bc = (Long)this.bCounts.value();
            Assert.assertEquals((long)ac, (long)bc);
            long currentCount = ac + 1L;
            this.aCounts.update((Object)NonSerializableLong.of(currentCount));
            this.bCounts.update((Object)currentCount);
            allCounts.put((Integer)value.f0, currentCount);
        }
    }

    private static class OnceFailingPartitionedSum
    extends RichMapFunction<Integer, Tuple2<Integer, Long>> {
        private static Map<Integer, Long> allSums = new ConcurrentHashMap<Integer, Long>();
        private static volatile boolean hasFailed = false;
        private final long numElements;
        private long failurePos;
        private long count;
        private OperatorState<Long> sum;

        OnceFailingPartitionedSum(long numElements) {
            this.numElements = numElements;
        }

        public void open(Configuration parameters) throws IOException {
            long failurePosMin = (long)(0.4 * (double)this.numElements / (double)this.getRuntimeContext().getNumberOfParallelSubtasks());
            long failurePosMax = (long)(0.7 * (double)this.numElements / (double)this.getRuntimeContext().getNumberOfParallelSubtasks());
            this.failurePos = new Random().nextLong() % (failurePosMax - failurePosMin) + failurePosMin;
            this.count = 0L;
            this.sum = this.getRuntimeContext().getKeyValueState("my_state", Long.class, (Object)0L);
        }

        public Tuple2<Integer, Long> map(Integer value) throws Exception {
            ++this.count;
            if (!hasFailed && this.count >= this.failurePos) {
                hasFailed = true;
                throw new Exception("Test Failure");
            }
            long currentSum = (Long)this.sum.value() + (long)value.intValue();
            this.sum.update((Object)currentSum);
            allSums.put(value, currentSum);
            return new Tuple2((Object)value, (Object)currentSum);
        }
    }

    private static class IntGeneratingSourceFunction
    extends RichParallelSourceFunction<Integer>
    implements Checkpointed<Integer> {
        private final long numElements;
        private int index;
        private int step;
        private volatile boolean isRunning = true;
        static final long[] counts = new long[6];

        public void close() throws IOException {
            IntGeneratingSourceFunction.counts[this.getRuntimeContext().getIndexOfThisSubtask()] = this.index;
        }

        IntGeneratingSourceFunction(long numElements) {
            this.numElements = numElements;
        }

        public void open(Configuration parameters) throws IOException {
            this.step = this.getRuntimeContext().getNumberOfParallelSubtasks();
            if (this.index == 0) {
                this.index = this.getRuntimeContext().getIndexOfThisSubtask();
            }
        }

        /*
         * WARNING - Removed try catching itself - possible behaviour change.
         */
        public void run(SourceFunction.SourceContext<Integer> ctx) throws Exception {
            Object lockingObject = ctx.getCheckpointLock();
            while (this.isRunning && (long)this.index < this.numElements) {
                Object object = lockingObject;
                synchronized (object) {
                    this.index += this.step;
                    ctx.collect((Object)(this.index % 40));
                }
            }
        }

        public void cancel() {
            this.isRunning = false;
        }

        public Integer snapshotState(long checkpointId, long checkpointTimestamp) {
            return this.index;
        }

        public void restoreState(Integer state) {
            this.index = state;
        }
    }
}

