/*
 * Decompiled with CFR 0.152.
 */
package org.apache.flink.test.checkpointing;

import java.util.Collection;
import java.util.HashMap;
import java.util.LinkedList;
import java.util.Map;
import java.util.function.BiFunction;
import java.util.stream.Collectors;
import javax.annotation.Nullable;
import org.apache.flink.FlinkVersion;
import org.apache.flink.api.common.functions.OpenContext;
import org.apache.flink.api.common.state.MapStateDescriptor;
import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.java.functions.KeySelector;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.streaming.api.checkpoint.CheckpointedFunction;
import org.apache.flink.streaming.api.datastream.BroadcastStream;
import org.apache.flink.streaming.api.datastream.KeyedStream;
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
import org.apache.flink.streaming.api.functions.co.KeyedBroadcastProcessFunction;
import org.apache.flink.streaming.api.functions.source.legacy.SourceFunction;
import org.apache.flink.streaming.util.RestartStrategyUtils;
import org.apache.flink.streaming.util.StateBackendUtils;
import org.apache.flink.test.checkpointing.utils.MigrationTestUtils;
import org.apache.flink.test.checkpointing.utils.SnapshotMigrationTestBase;
import org.apache.flink.test.util.MigrationTest;
import org.apache.flink.util.Collector;
import org.junit.Assert;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;

@RunWith(value=Parameterized.class)
public class StatefulJobWBroadcastStateMigrationITCase
extends SnapshotMigrationTestBase
implements MigrationTest {
    private static final int NUM_SOURCE_ELEMENTS = 4;
    private final SnapshotMigrationTestBase.SnapshotSpec snapshotSpec;

    /**
     * Supplies the snapshot specs for the parameterized test run.
     *
     * @return the result of {@code internalParameters} invoked without a target version
     */
    @Parameterized.Parameters(name = "Test snapshot: {0}")
    public static Collection<SnapshotMigrationTestBase.SnapshotSpec> createSpecsForTestRuns() {
        return internalParameters(null);
    }

    /**
     * Supplies the snapshot specs used when generating test data for one Flink version.
     *
     * @param targetVersion the version forwarded to {@code internalParameters}
     * @return the specs computed by {@code internalParameters} for {@code targetVersion}
     */
    public static Collection<SnapshotMigrationTestBase.SnapshotSpec> createSpecsForTestDataGeneration(
            FlinkVersion targetVersion) {
        return internalParameters(targetVersion);
    }

    /**
     * Collects the snapshot specs for every state backend / snapshot type combination this test
     * covers.
     *
     * <p>For each combination a version range is computed with {@code FlinkVersion.rangeOf} and
     * handed to {@code SnapshotSpec.withVersions}. When {@code targetGeneratingVersion} is
     * non-null, each range is first filtered down to the versions equal to it, so a range that
     * does not contain the target contributes an empty version collection.
     *
     * @param targetGeneratingVersion the only version to keep, or {@code null} to keep all
     * @return the accumulated specs, in the order the combinations are listed below
     */
    private static Collection<SnapshotMigrationTestBase.SnapshotSpec> internalParameters(
            @Nullable FlinkVersion targetGeneratingVersion) {
        BiFunction<FlinkVersion, FlinkVersion, Collection<FlinkVersion>> getFlinkVersions =
                (minInclVersion, maxInclVersion) -> {
                    Collection<FlinkVersion> range =
                            FlinkVersion.rangeOf(minInclVersion, maxInclVersion);
                    if (targetGeneratingVersion == null) {
                        return range;
                    }
                    return range.stream()
                            .filter(v -> v.equals(targetGeneratingVersion))
                            .collect(Collectors.toList());
                };

        Collection<SnapshotMigrationTestBase.SnapshotSpec> parameters = new LinkedList<>();

        // Canonical savepoints. Note the rocksdb range starts at an older version (v1_8).
        parameters.addAll(
                SnapshotMigrationTestBase.SnapshotSpec.withVersions(
                        "hashmap",
                        SnapshotMigrationTestBase.SnapshotType.SAVEPOINT_CANONICAL,
                        getFlinkVersions.apply(
                                FlinkVersion.v1_15,
                                MigrationTest.getMostRecentlyPublishedVersion())));
        parameters.addAll(
                SnapshotMigrationTestBase.SnapshotSpec.withVersions(
                        "rocksdb",
                        SnapshotMigrationTestBase.SnapshotType.SAVEPOINT_CANONICAL,
                        getFlinkVersions.apply(
                                FlinkVersion.v1_8,
                                MigrationTest.getMostRecentlyPublishedVersion())));

        // Native savepoints.
        parameters.addAll(
                SnapshotMigrationTestBase.SnapshotSpec.withVersions(
                        "hashmap",
                        SnapshotMigrationTestBase.SnapshotType.SAVEPOINT_NATIVE,
                        getFlinkVersions.apply(
                                FlinkVersion.v1_15,
                                MigrationTest.getMostRecentlyPublishedVersion())));
        parameters.addAll(
                SnapshotMigrationTestBase.SnapshotSpec.withVersions(
                        "rocksdb",
                        SnapshotMigrationTestBase.SnapshotType.SAVEPOINT_NATIVE,
                        getFlinkVersions.apply(
                                FlinkVersion.v1_15,
                                MigrationTest.getMostRecentlyPublishedVersion())));

        // Checkpoints.
        parameters.addAll(
                SnapshotMigrationTestBase.SnapshotSpec.withVersions(
                        "hashmap",
                        SnapshotMigrationTestBase.SnapshotType.CHECKPOINT,
                        getFlinkVersions.apply(
                                FlinkVersion.v1_15,
                                MigrationTest.getMostRecentlyPublishedVersion())));
        parameters.addAll(
                SnapshotMigrationTestBase.SnapshotSpec.withVersions(
                        "rocksdb",
                        SnapshotMigrationTestBase.SnapshotType.CHECKPOINT,
                        getFlinkVersions.apply(
                                FlinkVersion.v1_15,
                                MigrationTest.getMostRecentlyPublishedVersion())));
        return parameters;
    }

    /**
     * Creates a test instance bound to one snapshot spec.
     *
     * @param snapshotSpec the spec stored on this instance; {@code testSavepoint} verifies the
     *     snapshot it describes
     */
    public StatefulJobWBroadcastStateMigrationITCase(SnapshotMigrationTestBase.SnapshotSpec snapshotSpec) throws Exception {
        this.snapshotSpec = snapshotSpec;
    }

    /**
     * Runs the test job in {@code CREATE_SNAPSHOT} mode for the given spec.
     *
     * @param snapshotSpec the spec describing the snapshot to create
     * @throws Exception if building or running the job fails
     */
    @MigrationTest.ParameterizedSnapshotsGenerator("createSpecsForTestDataGeneration")
    public void generateSnapshots(SnapshotMigrationTestBase.SnapshotSpec snapshotSpec)
            throws Exception {
        testOrCreateSavepoint(
                SnapshotMigrationTestBase.ExecutionMode.CREATE_SNAPSHOT, snapshotSpec);
    }

    /**
     * Runs the test job in {@code VERIFY_SNAPSHOT} mode against the spec this instance was
     * constructed with.
     *
     * @throws Exception if building or running the job fails
     */
    @Test
    public void testSavepoint() throws Exception {
        testOrCreateSavepoint(
                SnapshotMigrationTestBase.ExecutionMode.VERIFY_SNAPSHOT, snapshotSpec);
    }

    /**
     * Builds the broadcast-state job and either snapshots it or restores and verifies it.
     *
     * <p>The job has two keyed streams (one non-parallel source, one parallel source), each
     * connected to its own broadcast stream. The first pair uses the broadcast states
     * {@code "broadcast-state-1"} and {@code "broadcast-state-2"}; the second pair uses
     * {@code "broadcast-state-3"}. In {@code CREATE_SNAPSHOT} mode the checkpointing variants of
     * the sources and functions are wired in and {@code executeAndSnapshot} is called; in
     * {@code VERIFY_SNAPSHOT} mode the checking variants are wired in and
     * {@code restoreAndExecute} is called.
     *
     * @param executionMode whether to create a snapshot or verify an existing one
     * @param snapshotSpec selects the state backend, the snapshot type and the snapshot path
     * @throws UnsupportedOperationException if the spec names a state backend other than
     *     {@code "rocksdb"} or {@code "hashmap"}
     * @throws IllegalStateException if {@code executionMode} is neither of the two known modes
     * @throws Exception if job construction or execution fails
     */
    private void testOrCreateSavepoint(
            SnapshotMigrationTestBase.ExecutionMode executionMode,
            SnapshotMigrationTestBase.SnapshotSpec snapshotSpec)
            throws Exception {
        final int parallelism = 4;

        final StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
        RestartStrategyUtils.configureNoRestartStrategy(env);

        final String stateBackendType = snapshotSpec.getStateBackendType();
        switch (stateBackendType) {
            case "rocksdb":
                StateBackendUtils.configureRocksDBStateBackend(env);
                // The changelog state backend is switched off explicitly, but only while
                // generating snapshots. NOTE(review): presumably to keep generated test data
                // deterministic -- confirm against the upstream history.
                if (executionMode == SnapshotMigrationTestBase.ExecutionMode.CREATE_SNAPSHOT) {
                    env.enableChangelogStateBackend(false);
                }
                break;
            case "hashmap":
                StateBackendUtils.configureHashMapStateBackend(env);
                break;
            default:
                throw new UnsupportedOperationException(
                        "Unsupported state backend type: " + stateBackendType);
        }

        env.enableCheckpointing(500L);
        env.setParallelism(parallelism);
        env.setMaxParallelism(parallelism);

        // Expected broadcast state contents, checked by the verifying functions.
        final Map<Long, Long> expectedFirstState = new HashMap<>();
        expectedFirstState.put(0L, 0L);
        expectedFirstState.put(1L, 1L);
        expectedFirstState.put(2L, 2L);
        expectedFirstState.put(3L, 3L);

        final Map<String, Long> expectedSecondState = new HashMap<>();
        expectedSecondState.put("0", 0L);
        expectedSecondState.put("1", 1L);
        expectedSecondState.put("2", 2L);
        expectedSecondState.put("3", 3L);

        final Map<Long, String> expectedThirdState = new HashMap<>();
        expectedThirdState.put(0L, "0");
        expectedThirdState.put(1L, "1");
        expectedThirdState.put(2L, "2");
        expectedThirdState.put(3L, "3");

        final SourceFunction<Tuple2<Long, Long>> nonParallelSource;
        final SourceFunction<Tuple2<Long, Long>> nonParallelSourceB;
        final SourceFunction<Tuple2<Long, Long>> parallelSource;
        final SourceFunction<Tuple2<Long, Long>> parallelSourceB;
        final KeyedBroadcastProcessFunction<
                        Long, Tuple2<Long, Long>, Tuple2<Long, Long>, Tuple2<Long, Long>>
                firstBroadcastFunction;
        final KeyedBroadcastProcessFunction<
                        Long, Tuple2<Long, Long>, Tuple2<Long, Long>, Tuple2<Long, Long>>
                secondBroadcastFunction;

        if (executionMode == SnapshotMigrationTestBase.ExecutionMode.CREATE_SNAPSHOT) {
            nonParallelSource =
                    new MigrationTestUtils.CheckpointingNonParallelSourceWithListState(
                            NUM_SOURCE_ELEMENTS);
            nonParallelSourceB =
                    new MigrationTestUtils.CheckpointingNonParallelSourceWithListState(
                            NUM_SOURCE_ELEMENTS);
            parallelSource =
                    new MigrationTestUtils.CheckpointingParallelSourceWithUnionListState(
                            NUM_SOURCE_ELEMENTS);
            parallelSourceB =
                    new MigrationTestUtils.CheckpointingParallelSourceWithUnionListState(
                            NUM_SOURCE_ELEMENTS);
            firstBroadcastFunction = new CheckpointingKeyedBroadcastFunction();
            secondBroadcastFunction = new CheckpointingKeyedSingleBroadcastFunction();
        } else if (executionMode == SnapshotMigrationTestBase.ExecutionMode.VERIFY_SNAPSHOT) {
            nonParallelSource =
                    new MigrationTestUtils.CheckingNonParallelSourceWithListState(
                            NUM_SOURCE_ELEMENTS);
            nonParallelSourceB =
                    new MigrationTestUtils.CheckingNonParallelSourceWithListState(
                            NUM_SOURCE_ELEMENTS);
            parallelSource =
                    new MigrationTestUtils.CheckingParallelSourceWithUnionListState(
                            NUM_SOURCE_ELEMENTS);
            parallelSourceB =
                    new MigrationTestUtils.CheckingParallelSourceWithUnionListState(
                            NUM_SOURCE_ELEMENTS);
            firstBroadcastFunction =
                    new CheckingKeyedBroadcastFunction(expectedFirstState, expectedSecondState);
            secondBroadcastFunction = new CheckingKeyedSingleBroadcastFunction(expectedThirdState);
        } else {
            throw new IllegalStateException("Unknown ExecutionMode " + executionMode);
        }

        // Both keyed streams are keyed by the first tuple field.
        KeyedStream<Tuple2<Long, Long>, Long> npStream =
                env.addSource(nonParallelSource)
                        .uid("CheckpointingSource1")
                        .keyBy(
                                new KeySelector<Tuple2<Long, Long>, Long>() {
                                    private static final long serialVersionUID =
                                            -4514793867774977152L;

                                    @Override
                                    public Long getKey(Tuple2<Long, Long> value) throws Exception {
                                        return value.f0;
                                    }
                                });

        KeyedStream<Tuple2<Long, Long>, Long> pStream =
                env.addSource(parallelSource)
                        .uid("CheckpointingSource2")
                        .keyBy(
                                new KeySelector<Tuple2<Long, Long>, Long>() {
                                    private static final long serialVersionUID =
                                            4940496713319948104L;

                                    @Override
                                    public Long getKey(Tuple2<Long, Long> value) throws Exception {
                                        return value.f0;
                                    }
                                });

        final MapStateDescriptor<Long, Long> firstBroadcastStateDesc =
                new MapStateDescriptor<>(
                        "broadcast-state-1",
                        BasicTypeInfo.LONG_TYPE_INFO,
                        BasicTypeInfo.LONG_TYPE_INFO);
        final MapStateDescriptor<String, Long> secondBroadcastStateDesc =
                new MapStateDescriptor<>(
                        "broadcast-state-2",
                        BasicTypeInfo.STRING_TYPE_INFO,
                        BasicTypeInfo.LONG_TYPE_INFO);
        final MapStateDescriptor<Long, String> thirdBroadcastStateDesc =
                new MapStateDescriptor<>(
                        "broadcast-state-3",
                        BasicTypeInfo.LONG_TYPE_INFO,
                        BasicTypeInfo.STRING_TYPE_INFO);

        BroadcastStream<Tuple2<Long, Long>> npBroadcastStream =
                env.addSource(nonParallelSourceB)
                        .uid("BrCheckpointingSource1")
                        .broadcast(firstBroadcastStateDesc, secondBroadcastStateDesc);
        BroadcastStream<Tuple2<Long, Long>> pBroadcastStream =
                env.addSource(parallelSourceB)
                        .uid("BrCheckpointingSource2")
                        .broadcast(thirdBroadcastStateDesc);

        npStream.connect(npBroadcastStream)
                .process(firstBroadcastFunction)
                .uid("BrProcess1")
                .addSink(new MigrationTestUtils.AccumulatorCountingSink<>());
        pStream.connect(pBroadcastStream)
                .process(secondBroadcastFunction)
                .uid("BrProcess2")
                .addSink(new MigrationTestUtils.AccumulatorCountingSink<>());

        if (executionMode == SnapshotMigrationTestBase.ExecutionMode.CREATE_SNAPSHOT) {
            executeAndSnapshot(
                    env,
                    "src/test/resources/" + getSnapshotPath(snapshotSpec),
                    snapshotSpec.getSnapshotType(),
                    // two keyed pipelines, each fed NUM_SOURCE_ELEMENTS per source setting
                    new Tuple2<>(
                            MigrationTestUtils.AccumulatorCountingSink.NUM_ELEMENTS_ACCUMULATOR,
                            2 * NUM_SOURCE_ELEMENTS));
        } else {
            restoreAndExecute(
                    env,
                    getResourceFilename(getSnapshotPath(snapshotSpec)),
                    // two non-parallel checking sources
                    new Tuple2<>(
                            MigrationTestUtils.CheckingNonParallelSourceWithListState
                                    .SUCCESSFUL_RESTORE_CHECK_ACCUMULATOR,
                            2),
                    // two parallel checking sources running at the job parallelism
                    new Tuple2<>(
                            MigrationTestUtils.CheckingParallelSourceWithUnionListState
                                    .SUCCESSFUL_RESTORE_CHECK_ACCUMULATOR,
                            2 * parallelism),
                    new Tuple2<>(
                            MigrationTestUtils.AccumulatorCountingSink.NUM_ELEMENTS_ACCUMULATOR,
                            2 * NUM_SOURCE_ELEMENTS));
        }
    }

    /**
     * Derives the snapshot resource name for a spec.
     *
     * @param snapshotSpec the spec whose string form is appended to the fixed prefix
     * @return the prefix {@code "new-stateful-broadcast-udf-migration-itcase-"} followed by the
     *     spec's string representation
     */
    private String getSnapshotPath(SnapshotMigrationTestBase.SnapshotSpec snapshotSpec) {
        final String prefix = "new-stateful-broadcast-udf-migration-itcase-";
        // String.valueOf mirrors string concatenation: "null" for null, toString() otherwise.
        return prefix.concat(String.valueOf(snapshotSpec));
    }

    private static class CheckingKeyedSingleBroadcastFunction
    extends KeyedBroadcastProcessFunction<Long, Tuple2<Long, Long>, Tuple2<Long, Long>, Tuple2<Long, Long>> {
        private static final long serialVersionUID = 1333992081671604521L;
        private final Map<Long, String> expectedState;
        private MapStateDescriptor<Long, String> stateDesc;

        CheckingKeyedSingleBroadcastFunction(Map<Long, String> state) {
            this.expectedState = state;
        }

        public void open(OpenContext openContext) throws Exception {
            super.open(openContext);
            this.stateDesc = new MapStateDescriptor("broadcast-state-3", (TypeInformation)BasicTypeInfo.LONG_TYPE_INFO, (TypeInformation)BasicTypeInfo.STRING_TYPE_INFO);
        }

        public void processElement(Tuple2<Long, Long> value, KeyedBroadcastProcessFunction.ReadOnlyContext ctx, Collector<Tuple2<Long, Long>> out) throws Exception {
            HashMap<Long, String> actualState = new HashMap<Long, String>();
            for (Map.Entry entry : ctx.getBroadcastState(this.stateDesc).immutableEntries()) {
                actualState.put((Long)entry.getKey(), (String)entry.getValue());
            }
            Assert.assertEquals(this.expectedState, actualState);
            out.collect(value);
        }

        public void processBroadcastElement(Tuple2<Long, Long> value, KeyedBroadcastProcessFunction.Context ctx, Collector<Tuple2<Long, Long>> out) throws Exception {
        }
    }

    private static class CheckingKeyedBroadcastFunction
    extends KeyedBroadcastProcessFunction<Long, Tuple2<Long, Long>, Tuple2<Long, Long>, Tuple2<Long, Long>> {
        private static final long serialVersionUID = 1333992081671604521L;
        private final Map<Long, Long> expectedFirstState;
        private final Map<String, Long> expectedSecondState;
        private MapStateDescriptor<Long, Long> firstStateDesc;
        private MapStateDescriptor<String, Long> secondStateDesc;

        CheckingKeyedBroadcastFunction(Map<Long, Long> firstState, Map<String, Long> secondState) {
            this.expectedFirstState = firstState;
            this.expectedSecondState = secondState;
        }

        public void open(OpenContext openContext) throws Exception {
            super.open(openContext);
            this.firstStateDesc = new MapStateDescriptor("broadcast-state-1", (TypeInformation)BasicTypeInfo.LONG_TYPE_INFO, (TypeInformation)BasicTypeInfo.LONG_TYPE_INFO);
            this.secondStateDesc = new MapStateDescriptor("broadcast-state-2", (TypeInformation)BasicTypeInfo.STRING_TYPE_INFO, (TypeInformation)BasicTypeInfo.LONG_TYPE_INFO);
        }

        public void processElement(Tuple2<Long, Long> value, KeyedBroadcastProcessFunction.ReadOnlyContext ctx, Collector<Tuple2<Long, Long>> out) throws Exception {
            HashMap<Long, Long> actualFirstState = new HashMap<Long, Long>();
            for (Map.Entry entry : ctx.getBroadcastState(this.firstStateDesc).immutableEntries()) {
                actualFirstState.put((Long)entry.getKey(), (Long)entry.getValue());
            }
            Assert.assertEquals(this.expectedFirstState, actualFirstState);
            HashMap<String, Long> actualSecondState = new HashMap<String, Long>();
            for (Map.Entry entry : ctx.getBroadcastState(this.secondStateDesc).immutableEntries()) {
                actualSecondState.put((String)entry.getKey(), (Long)entry.getValue());
            }
            Assert.assertEquals(this.expectedSecondState, actualSecondState);
            out.collect(value);
        }

        public void processBroadcastElement(Tuple2<Long, Long> value, KeyedBroadcastProcessFunction.Context ctx, Collector<Tuple2<Long, Long>> out) throws Exception {
        }
    }

    private static class CheckpointingKeyedSingleBroadcastFunction
    extends KeyedBroadcastProcessFunction<Long, Tuple2<Long, Long>, Tuple2<Long, Long>, Tuple2<Long, Long>> {
        private static final long serialVersionUID = 1333992081671604521L;
        private MapStateDescriptor<Long, String> stateDesc;

        private CheckpointingKeyedSingleBroadcastFunction() {
        }

        public void open(OpenContext openContext) throws Exception {
            super.open(openContext);
            this.stateDesc = new MapStateDescriptor("broadcast-state-3", (TypeInformation)BasicTypeInfo.LONG_TYPE_INFO, (TypeInformation)BasicTypeInfo.STRING_TYPE_INFO);
        }

        public void processElement(Tuple2<Long, Long> value, KeyedBroadcastProcessFunction.ReadOnlyContext ctx, Collector<Tuple2<Long, Long>> out) throws Exception {
            out.collect(value);
        }

        public void processBroadcastElement(Tuple2<Long, Long> value, KeyedBroadcastProcessFunction.Context ctx, Collector<Tuple2<Long, Long>> out) throws Exception {
            ctx.getBroadcastState(this.stateDesc).put((Object)((Long)value.f0), (Object)Long.toString((Long)value.f1));
        }
    }

    private static class CheckpointingKeyedBroadcastFunction
    extends KeyedBroadcastProcessFunction<Long, Tuple2<Long, Long>, Tuple2<Long, Long>, Tuple2<Long, Long>> {
        private static final long serialVersionUID = 1333992081671604521L;
        private MapStateDescriptor<Long, Long> firstStateDesc;
        private MapStateDescriptor<String, Long> secondStateDesc;

        private CheckpointingKeyedBroadcastFunction() {
        }

        public void open(OpenContext openContext) throws Exception {
            super.open(openContext);
            this.firstStateDesc = new MapStateDescriptor("broadcast-state-1", (TypeInformation)BasicTypeInfo.LONG_TYPE_INFO, (TypeInformation)BasicTypeInfo.LONG_TYPE_INFO);
            this.secondStateDesc = new MapStateDescriptor("broadcast-state-2", (TypeInformation)BasicTypeInfo.STRING_TYPE_INFO, (TypeInformation)BasicTypeInfo.LONG_TYPE_INFO);
        }

        public void processElement(Tuple2<Long, Long> value, KeyedBroadcastProcessFunction.ReadOnlyContext ctx, Collector<Tuple2<Long, Long>> out) throws Exception {
            out.collect(value);
        }

        public void processBroadcastElement(Tuple2<Long, Long> value, KeyedBroadcastProcessFunction.Context ctx, Collector<Tuple2<Long, Long>> out) throws Exception {
            ctx.getBroadcastState(this.firstStateDesc).put((Object)((Long)value.f0), (Object)((Long)value.f1));
            ctx.getBroadcastState(this.secondStateDesc).put((Object)Long.toString((Long)value.f0), (Object)((Long)value.f1));
        }
    }
}

