package org.apache.flink.state.api;

import java.io.IOException;
import java.time.Duration;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Comparator;
import java.util.List;
import java.util.concurrent.TimeUnit;
import java.util.stream.Collectors;
import org.apache.flink.api.common.JobID;
import org.apache.flink.api.common.state.ListState;
import org.apache.flink.api.common.state.ListStateDescriptor;
import org.apache.flink.api.common.state.MapStateDescriptor;
import org.apache.flink.api.common.time.Deadline;
import org.apache.flink.api.java.DataSet;
import org.apache.flink.api.java.ExecutionEnvironment;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.client.program.ClusterClient;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.runtime.jobgraph.JobGraph;
import org.apache.flink.runtime.state.FunctionInitializationContext;
import org.apache.flink.runtime.state.FunctionSnapshotContext;
import org.apache.flink.runtime.state.memory.MemoryStateBackend;
import org.apache.flink.streaming.api.checkpoint.CheckpointedFunction;
import org.apache.flink.streaming.api.datastream.DataStream;
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
import org.apache.flink.streaming.api.functions.co.BroadcastProcessFunction;
import org.apache.flink.streaming.api.functions.sink.DiscardingSink;
import org.apache.flink.streaming.api.functions.source.SourceFunction;
import org.apache.flink.test.util.AbstractTestBase;
import org.apache.flink.util.AbstractID;
import org.apache.flink.util.Collector;
import org.junit.Assert;
import org.junit.Test;

/**
 * Base class for integration tests that read operator state back out of a savepoint.
 *
 * <p>The test runs a small streaming job whose {@link StatefulOperator} (registered under
 * {@link #UID}) stores the elements emitted by {@link SavepointSource} in a list state and a
 * union list state, and records every broadcast element in a broadcast state. Once the source
 * has emitted everything, a savepoint is taken. Each state is then read back through the
 * subclass-provided {@code read*State} methods and compared against the source's elements.
 */
public abstract class SavepointReaderITTestBase extends AbstractTestBase {
    static final String UID = "stateful-operator";
    static final String LIST_NAME = "list";
    static final String UNION_NAME = "union";
    static final String BROADCAST_NAME = "broadcast";

    private final ListStateDescriptor<Integer> list;
    private final ListStateDescriptor<Integer> union;
    private final MapStateDescriptor<Integer, String> broadcast;

    /**
     * Source that emits {@link #ELEMENTS} once and then idles until it is cancelled.
     *
     * <p>Completion is signalled through a static flag so the test thread can poll it. This
     * relies on the job running in the same JVM as the test.
     */
    public static class SavepointSource implements SourceFunction<Integer> {

        /**
         * Elements emitted by the source. Kept in ascending order because the verification
         * compares them against sorted read-back results.
         */
        private static final Integer[] ELEMENTS = {1, 2, 3};

        /** Set once every element has been emitted; reset by {@link #initializeForTest()}. */
        private static volatile boolean finished;

        private volatile boolean running = true;

        private SavepointSource() {}

        @Override
        public void run(SourceContext<Integer> ctx) {
            // Emit everything and raise the flag while holding the checkpoint lock, so the
            // emission and the "finished" signal happen together with respect to checkpoints.
            synchronized (ctx.getCheckpointLock()) {
                for (Integer element : ELEMENTS) {
                    ctx.collect(element);
                }
                finished = true;
            }

            // Keep the source alive so that the job is still running when the savepoint is taken.
            while (running) {
                try {
                    Thread.sleep(100L);
                } catch (InterruptedException ignored) {
                    // Cancellation is signalled through `running`; an interrupt only cuts the
                    // current sleep short and the loop re-checks the flag.
                }
            }
        }

        @Override
        public void cancel() {
            running = false;
        }

        /** Resets the static completion flag; must be called before each job submission. */
        public static void initializeForTest() {
            finished = false;
        }

        /** Returns whether a source instance has emitted all of its elements. */
        private static boolean isFinished() {
            return finished;
        }

        /** Returns a fresh, mutable copy of the emitted elements (ascending order). */
        private static List<Integer> getElements() {
            return new ArrayList<>(Arrays.asList(ELEMENTS));
        }
    }

    /**
     * Operator that records the non-broadcast input in list and union list operator state and
     * mirrors the broadcast input into broadcast state as {@code value -> value.toString()}.
     */
    private static class StatefulOperator extends BroadcastProcessFunction<Integer, Integer, Void>
            implements CheckpointedFunction {

        private final ListStateDescriptor<Integer> list;
        private final ListStateDescriptor<Integer> union;
        private final MapStateDescriptor<Integer, String> broadcast;

        /** Elements seen on the non-broadcast side since {@link #open}. */
        private List<Integer> elements;

        private ListState<Integer> listState;
        private ListState<Integer> unionState;

        private StatefulOperator(
                ListStateDescriptor<Integer> list,
                ListStateDescriptor<Integer> union,
                MapStateDescriptor<Integer, String> broadcast) {
            this.list = list;
            this.union = union;
            this.broadcast = broadcast;
        }

        @Override
        public void open(Configuration parameters) {
            elements = new ArrayList<>();
        }

        @Override
        public void processElement(Integer value, ReadOnlyContext ctx, Collector<Void> out) {
            elements.add(value);
        }

        @Override
        public void processBroadcastElement(Integer value, Context ctx, Collector<Void> out)
                throws Exception {
            ctx.getBroadcastState(broadcast).put(value, value.toString());
        }

        @Override
        public void snapshotState(FunctionSnapshotContext context) throws Exception {
            // Replace (rather than append to) the state with everything seen so far.
            listState.clear();
            listState.addAll(elements);

            unionState.clear();
            unionState.addAll(elements);
        }

        @Override
        public void initializeState(FunctionInitializationContext context) throws Exception {
            listState = context.getOperatorStateStore().getListState(list);
            unionState = context.getOperatorStateStore().getUnionListState(union);
        }
    }

    /**
     * Creates the test with the state descriptors used both when writing the savepoint and by
     * subclasses when reading it back.
     *
     * @param listStateDescriptor descriptor for the operator list state
     * @param listStateDescriptor2 descriptor for the operator union list state
     * @param mapStateDescriptor descriptor for the broadcast state
     */
    public SavepointReaderITTestBase(
            ListStateDescriptor<Integer> listStateDescriptor,
            ListStateDescriptor<Integer> listStateDescriptor2,
            MapStateDescriptor<Integer, String> mapStateDescriptor) {
        this.list = listStateDescriptor;
        this.union = listStateDescriptor2;
        this.broadcast = mapStateDescriptor;
    }

    @Test
    public void testOperatorStateInputFormat() throws Exception {
        StreamExecutionEnvironment streamEnv = StreamExecutionEnvironment.getExecutionEnvironment();
        streamEnv.setParallelism(4);

        DataStream<Integer> data = streamEnv.addSource(new SavepointSource()).rebalance();

        // The same stream feeds both the regular and the broadcast input of the operator.
        data.connect(data.broadcast(this.broadcast))
                .process(new StatefulOperator(this.list, this.union, this.broadcast))
                .uid(UID)
                .addSink(new DiscardingSink<>());

        String savepointPath = takeSavepoint(streamEnv.getStreamGraph().getJobGraph());

        ExecutionEnvironment batchEnv = ExecutionEnvironment.getExecutionEnvironment();
        verifyListState(savepointPath, batchEnv);
        verifyUnionState(savepointPath, batchEnv);
        verifyBroadcastState(savepointPath, batchEnv);
    }

    /** Reads the operator list state of {@link #UID} from the given savepoint. */
    abstract DataSet<Integer> readListState(ExistingSavepoint existingSavepoint) throws IOException;

    /** Reads the operator union list state of {@link #UID} from the given savepoint. */
    abstract DataSet<Integer> readUnionState(ExistingSavepoint existingSavepoint) throws IOException;

    /** Reads the broadcast state of {@link #UID} from the given savepoint as key/value pairs. */
    abstract DataSet<Tuple2<Integer, String>> readBroadcastState(ExistingSavepoint existingSavepoint)
            throws IOException;

    private void verifyListState(String path, ExecutionEnvironment env) throws Exception {
        List<Integer> actual = readListState(Savepoint.load(env, path, new MemoryStateBackend())).collect();
        assertSourceElements("Unexpected elements read from list state", actual);
    }

    private void verifyUnionState(String path, ExecutionEnvironment env) throws Exception {
        List<Integer> actual = readUnionState(Savepoint.load(env, path, new MemoryStateBackend())).collect();
        assertSourceElements("Unexpected elements read from union state", actual);
    }

    private void verifyBroadcastState(String path, ExecutionEnvironment env) throws Exception {
        List<Tuple2<Integer, String>> entries =
                readBroadcastState(Savepoint.load(env, path, new MemoryStateBackend())).collect();

        List<Integer> keys = entries.stream()
                .map(entry -> entry.f0)
                .sorted(Comparator.naturalOrder())
                .collect(Collectors.toList());
        List<String> values = entries.stream()
                .map(entry -> entry.f1)
                .sorted(Comparator.naturalOrder())
                .collect(Collectors.toList());

        List<String> expectedValues = SavepointSource.getElements().stream()
                .map(Object::toString)
                .sorted()
                .collect(Collectors.toList());

        Assert.assertEquals("Unexpected element in broadcast state keys", SavepointSource.getElements(), keys);
        Assert.assertEquals("Unexpected element in broadcast state values", expectedValues, values);
    }

    /** Asserts that {@code actual}, ignoring order, equals the elements emitted by the source. */
    private static void assertSourceElements(String message, List<Integer> actual) {
        List<Integer> sorted = new ArrayList<>(actual);
        sorted.sort(Comparator.naturalOrder());
        Assert.assertEquals(message, SavepointSource.getElements(), sorted);
    }

    /**
     * Submits the job, waits until the source has emitted all elements, takes a savepoint and
     * cancels the job.
     *
     * @return the path of the completed savepoint
     * @throws InterruptedException if the test thread is interrupted while waiting
     * @throws AssertionError if the source does not finish within the deadline
     */
    private String takeSavepoint(JobGraph jobGraph) throws Exception {
        SavepointSource.initializeForTest();

        ClusterClient<?> client = miniClusterResource.getClusterClient();
        Deadline deadline = Deadline.fromNow(Duration.ofMinutes(5L));
        String savepointDir = getTempDirPath(new AbstractID().toHexString());

        // Submit outside the try block: if submission fails there is no job to cancel, and a
        // failing cancel() in `finally` would otherwise mask the original submission error.
        JobID jobId = client.submitJob(jobGraph).get();
        try {
            awaitSourceFinished(deadline);
            return client.triggerSavepoint(jobId, savepointDir)
                    .get(deadline.timeLeft().toMillis(), TimeUnit.MILLISECONDS);
        } finally {
            client.cancel(jobId).get();
        }
    }

    /**
     * Polls until {@link SavepointSource} reports that it has emitted all of its elements.
     *
     * <p>An interrupt is propagated instead of being swallowed. Re-setting the interrupt flag
     * and sleeping again would make every later {@code sleep} throw at once and busy-spin until
     * the deadline.
     */
    private static void awaitSourceFinished(Deadline deadline) throws InterruptedException {
        while (!SavepointSource.isFinished()) {
            if (!deadline.hasTimeLeft()) {
                Assert.fail("Failed to initialize state within deadline");
            }
            Thread.sleep(2L);
        }
    }
}
