/*
 * Decompiled with CFR 0.152.
 */
package org.apache.flink.state.api.input;

import java.io.IOException;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Comparator;
import java.util.List;
import java.util.Set;
import javax.annotation.Nonnull;
import org.apache.flink.api.common.functions.FlatMapFunction;
import org.apache.flink.api.common.functions.RichFlatMapFunction;
import org.apache.flink.api.common.functions.RuntimeContext;
import org.apache.flink.api.common.state.ValueState;
import org.apache.flink.api.common.state.ValueStateDescriptor;
import org.apache.flink.api.common.typeinfo.Types;
import org.apache.flink.api.common.typeutils.TypeSerializer;
import org.apache.flink.api.common.typeutils.base.VoidSerializer;
import org.apache.flink.api.java.functions.KeySelector;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.runtime.checkpoint.OperatorState;
import org.apache.flink.runtime.checkpoint.OperatorSubtaskState;
import org.apache.flink.runtime.jobgraph.OperatorID;
import org.apache.flink.runtime.state.StateBackend;
import org.apache.flink.runtime.state.memory.MemoryStateBackend;
import org.apache.flink.state.api.functions.KeyedStateReaderFunction;
import org.apache.flink.state.api.input.KeyedStateInputFormat;
import org.apache.flink.state.api.input.operator.KeyedStateReaderOperator;
import org.apache.flink.state.api.input.operator.StateReaderOperator;
import org.apache.flink.state.api.input.splits.KeyGroupRangeInputSplit;
import org.apache.flink.state.api.runtime.OperatorIDGenerator;
import org.apache.flink.streaming.api.functions.KeyedProcessFunction;
import org.apache.flink.streaming.api.operators.KeyedProcessOperator;
import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
import org.apache.flink.streaming.api.operators.StreamFlatMap;
import org.apache.flink.streaming.util.KeyedOneInputStreamOperatorTestHarness;
import org.apache.flink.streaming.util.MockStreamingRuntimeContext;
import org.apache.flink.util.Collector;
import org.junit.Assert;
import org.junit.Test;

/**
 * Tests for {@link KeyedStateInputFormat}: partitioning keyed operator state into input splits and
 * reading that state (and its registered timers) back through {@link KeyedStateReaderFunction}s.
 */
public class KeyedStateInputFormatTest {

    /** Uid from which the {@link OperatorID} of the snapshotted operator is derived. */
    private static final String UID = "uid";

    /** Max parallelism used when writing the snapshot and when describing the operator state. */
    private static final int MAX_PARALLELISM = 128;

    /** The single {@code ValueState<Integer>} written by the stateful functions and read back. */
    private static final ValueStateDescriptor<Integer> STATE_DESCRIPTOR =
            new ValueStateDescriptor<>("state", Types.INT);

    @Test
    public void testCreatePartitionedInputSplits() throws Exception {
        OperatorState operatorState =
                createOperatorState(new StreamFlatMap<>(new StatefulFunction()));
        KeyedStateInputFormat<Integer, ?, Integer> format =
                createInputFormat(operatorState, new ReaderFunction());

        KeyGroupRangeInputSplit[] splits = format.createInputSplits(4);
        Assert.assertEquals(
                "Failed to properly partition operator state into input splits", 4, splits.length);
    }

    @Test
    public void testMaxParallelismRespected() throws Exception {
        OperatorState operatorState =
                createOperatorState(new StreamFlatMap<>(new StatefulFunction()));
        KeyedStateInputFormat<Integer, ?, Integer> format =
                createInputFormat(operatorState, new ReaderFunction());

        // Requesting more splits than key groups must be capped at the max parallelism.
        KeyGroupRangeInputSplit[] splits = format.createInputSplits(MAX_PARALLELISM + 1);
        Assert.assertEquals(
                "Failed to properly partition operator state into input splits",
                MAX_PARALLELISM,
                splits.length);
    }

    @Test
    public void testReadState() throws Exception {
        OperatorState operatorState =
                createOperatorState(new StreamFlatMap<>(new StatefulFunction()));
        KeyGroupRangeInputSplit split =
                createInputFormat(operatorState, new ReaderFunction()).createInputSplits(1)[0];

        List<Integer> data = readInputSplit(split, new ReaderFunction());
        Assert.assertEquals("Incorrect data read from input split", Arrays.asList(1, 2, 3), data);
    }

    @Test
    public void testReadMultipleOutputPerKey() throws Exception {
        OperatorState operatorState =
                createOperatorState(new StreamFlatMap<>(new StatefulFunction()));
        KeyGroupRangeInputSplit split =
                createInputFormat(operatorState, new ReaderFunction()).createInputSplits(1)[0];

        List<Integer> data = readInputSplit(split, new DoubleReaderFunction());
        Assert.assertEquals(
                "Incorrect data read from input split", Arrays.asList(1, 1, 2, 2, 3, 3), data);
    }

    @Test(expected = IOException.class)
    public void testInvalidProcessReaderFunctionFails() throws Exception {
        OperatorState operatorState =
                createOperatorState(new StreamFlatMap<>(new StatefulFunction()));
        KeyGroupRangeInputSplit split =
                createInputFormat(operatorState, new ReaderFunction()).createInputSplits(1)[0];

        readInputSplit(split, new InvalidReaderFunction());
        Assert.fail("KeyedStateReaderFunction did not fail on invalid RuntimeContext use");
    }

    @Test
    public void testReadTime() throws Exception {
        OperatorState operatorState =
                createOperatorState(new KeyedProcessOperator<>(new StatefulFunctionWithTime()));
        KeyGroupRangeInputSplit split =
                createInputFormat(operatorState, new TimerReaderFunction())
                        .createInputSplits(1)[0];

        List<Integer> data = readInputSplit(split, new TimerReaderFunction());
        Assert.assertEquals(
                "Incorrect data read from input split", Arrays.asList(1, 1, 2, 2, 3, 3), data);
    }

    /**
     * Reads every record of {@code split} using {@code userFunction}.
     *
     * <p>The format is always closed, even when reading fails, so a failing reader function does
     * not leak the opened format.
     *
     * @param split the key-group split to read
     * @param userFunction the reader function applied to every key in the split
     * @return all emitted records, sorted ascending so assertions do not depend on key order
     * @throws IOException if opening, reading or closing the format fails
     */
    @Nonnull
    private List<Integer> readInputSplit(
            KeyGroupRangeInputSplit split, KeyedStateReaderFunction<Integer, Integer> userFunction)
            throws IOException {
        KeyedStateInputFormat<Integer, ?, Integer> format =
                createInputFormat(
                        new OperatorState(OperatorIDGenerator.fromUid(UID), 1, 4), userFunction);
        List<Integer> data = new ArrayList<>();

        format.setRuntimeContext(new MockStreamingRuntimeContext(false, 1, 0));
        format.openInputFormat();
        try {
            format.open(split);
            try {
                while (!format.reachedEnd()) {
                    data.add(format.nextRecord(0));
                }
            } finally {
                format.close();
            }
        } finally {
            format.closeInputFormat();
        }

        data.sort(Comparator.naturalOrder());
        return data;
    }

    /**
     * Builds a single-subtask {@link OperatorState} whose subtask 0 holds a snapshot of {@code
     * operator} after processing the elements 1, 2 and 3.
     */
    private OperatorState createOperatorState(OneInputStreamOperator<Integer, Void> operator)
            throws Exception {
        OperatorState operatorState =
                new OperatorState(OperatorIDGenerator.fromUid(UID), 1, MAX_PARALLELISM);
        operatorState.putState(0, createOperatorSubtaskState(operator));
        return operatorState;
    }

    /** Creates an input format that reads {@code operatorState} with the given reader function. */
    private static KeyedStateInputFormat<Integer, ?, Integer> createInputFormat(
            OperatorState operatorState, KeyedStateReaderFunction<Integer, Integer> function) {
        return new KeyedStateInputFormat<>(
                operatorState,
                new MemoryStateBackend(),
                new Configuration(),
                new KeyedStateReaderOperator<>(function, Types.INT));
    }

    /**
     * Runs {@code operator} keyed by identity, feeds it the elements 1, 2 and 3 at timestamp 0, and
     * returns its snapshot for checkpoint 0.
     */
    private OperatorSubtaskState createOperatorSubtaskState(
            OneInputStreamOperator<Integer, Void> operator) throws Exception {
        try (KeyedOneInputStreamOperatorTestHarness<Integer, Integer, Void> testHarness =
                new KeyedOneInputStreamOperatorTestHarness<>(
                        operator, id -> id, Types.INT, MAX_PARALLELISM, 1, 0)) {
            testHarness.setup(VoidSerializer.INSTANCE);
            testHarness.open();

            testHarness.processElement(1, 0L);
            testHarness.processElement(2, 0L);
            testHarness.processElement(3, 0L);

            return testHarness.snapshot(0L, 0L);
        }
    }

    /** Emits, for every key, its single event-time timer and then its single processing-time timer. */
    static class TimerReaderFunction extends KeyedStateReaderFunction<Integer, Integer> {
        ValueState<Integer> state;

        @Override
        public void open(Configuration parameters) {
            state = getRuntimeContext().getState(STATE_DESCRIPTOR);
        }

        @Override
        public void readKey(
                Integer key, KeyedStateReaderFunction.Context ctx, Collector<Integer> out)
                throws Exception {
            Set<Long> eventTimers = ctx.registeredEventTimeTimers();
            Assert.assertEquals(
                    "Each key should have exactly one event timer for key " + key,
                    1,
                    eventTimers.size());
            out.collect(eventTimers.iterator().next().intValue());

            Set<Long> procTimers = ctx.registeredProcessingTimeTimers();
            Assert.assertEquals(
                    "Each key should have exactly one processing timer for key " + key,
                    1,
                    procTimers.size());
            out.collect(procTimers.iterator().next().intValue());
        }
    }

    /**
     * Stores each element in keyed state and registers an event-time and a processing-time timer
     * at a timestamp equal to the element's value.
     */
    static class StatefulFunctionWithTime extends KeyedProcessFunction<Integer, Integer, Void> {
        ValueState<Integer> state;

        @Override
        public void open(Configuration parameters) {
            state = getRuntimeContext().getState(STATE_DESCRIPTOR);
        }

        @Override
        public void processElement(Integer value, Context ctx, Collector<Void> out)
                throws Exception {
            state.update(value);
            ctx.timerService().registerEventTimeTimer(value);
            ctx.timerService().registerProcessingTimeTimer(value);
        }
    }

    /** Stores each element in keyed state and emits nothing. */
    static class StatefulFunction extends RichFlatMapFunction<Integer, Void> {
        ValueState<Integer> state;

        @Override
        public void open(Configuration parameters) {
            state = getRuntimeContext().getState(STATE_DESCRIPTOR);
        }

        @Override
        public void flatMap(Integer value, Collector<Void> out) throws Exception {
            state.update(value);
        }
    }

    /**
     * Registers state from inside {@code readKey}, outside of {@code open}. The test using this
     * function expects reading to fail with an {@link IOException}.
     */
    static class InvalidReaderFunction extends KeyedStateReaderFunction<Integer, Integer> {

        @Override
        public void open(Configuration parameters) {
            getRuntimeContext().getState(STATE_DESCRIPTOR);
        }

        @Override
        public void readKey(
                Integer key, KeyedStateReaderFunction.Context ctx, Collector<Integer> out)
                throws Exception {
            ValueState<Integer> state = getRuntimeContext().getState(STATE_DESCRIPTOR);
            out.collect(state.value());
        }
    }

    /** Emits the stored value of every key twice. */
    static class DoubleReaderFunction extends KeyedStateReaderFunction<Integer, Integer> {
        ValueState<Integer> state;

        @Override
        public void open(Configuration parameters) {
            state = getRuntimeContext().getState(STATE_DESCRIPTOR);
        }

        @Override
        public void readKey(
                Integer key, KeyedStateReaderFunction.Context ctx, Collector<Integer> out)
                throws Exception {
            out.collect(state.value());
            out.collect(state.value());
        }
    }

    /** Emits the stored value of every key once. */
    static class ReaderFunction extends KeyedStateReaderFunction<Integer, Integer> {
        ValueState<Integer> state;

        @Override
        public void open(Configuration parameters) {
            state = getRuntimeContext().getState(STATE_DESCRIPTOR);
        }

        @Override
        public void readKey(
                Integer key, KeyedStateReaderFunction.Context ctx, Collector<Integer> out)
                throws Exception {
            out.collect(state.value());
        }
    }
}

