package org.apache.flink.state.api;

import org.apache.flink.api.java.DataSet;
import org.apache.flink.api.java.ExecutionEnvironment;
import org.apache.flink.api.java.functions.KeySelector;
import org.apache.flink.api.java.operators.Operator;
import org.apache.flink.core.fs.Path;
import org.apache.flink.runtime.state.FunctionInitializationContext;
import org.apache.flink.runtime.state.FunctionSnapshotContext;
import org.apache.flink.runtime.state.memory.MemoryStateBackend;
import org.apache.flink.state.api.functions.BroadcastStateBootstrapFunction;
import org.apache.flink.state.api.functions.KeyedStateBootstrapFunction;
import org.apache.flink.state.api.functions.StateBootstrapFunction;
import org.apache.flink.state.api.runtime.OperatorIDGenerator;
import org.apache.flink.streaming.api.operators.StreamOperator;
import org.apache.flink.test.util.AbstractTestBase;
import org.junit.Assert;
import org.junit.Test;

/* loaded from: input_file:org/apache/flink/state/api/BootstrapTransformationTest.class */
/**
 * Tests for {@link BootstrapTransformation}.
 *
 * <p>Covers two things:
 * <ul>
 *   <li>the parallelism of the data set produced by {@code writeOperatorSubtaskStates};</li>
 *   <li>the key selector forwarded into the stream operator's config.</li>
 * </ul>
 */
public class BootstrapTransformationTest extends AbstractTestBase {

    /** Uid from which every transformation under test derives its operator id. */
    private static final String UID = "uid";

    @Test
    public void testBroadcastStateTransformationParallelism() {
        ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
        env.setParallelism(10);

        DataSet<Integer> input = env.fromElements(0);
        BootstrapTransformation<Integer> transformation = OperatorTransformation
            .bootstrapWith(input)
            .transform(new ExampleBroadcastStateBootstrapFunction());

        Assert.assertEquals(
            "Broadcast transformations should always be run at parallelism 1",
            1,
            parallelismOfWrittenState(transformation, 4));
    }

    @Test
    public void testDefaultParallelismRespectedWhenLessThanMaxParallelism() {
        ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
        env.setParallelism(4);

        DataSet<Integer> input = env.fromElements(0);
        BootstrapTransformation<Integer> transformation = OperatorTransformation
            .bootstrapWith(input)
            .transform(new ExampleStateBootstrapFunction());

        // -1 means the operator's parallelism was left unset rather than overridden.
        // NOTE(review): this reading of -1 is inferred from the assertion message; confirm.
        Assert.assertEquals(
            "The parallelism of a data set should not change when less than the max parallelism of the savepoint",
            -1,
            parallelismOfWrittenState(transformation, 10));
    }

    @Test
    public void testMaxParallelismRespected() {
        ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
        env.setParallelism(10);

        DataSet<Integer> input = env.fromElements(0);
        BootstrapTransformation<Integer> transformation = OperatorTransformation
            .bootstrapWith(input)
            .transform(new ExampleStateBootstrapFunction());

        Assert.assertEquals(
            "The parallelism of a data set should be constrained by the savepoint max parallelism",
            4,
            parallelismOfWrittenState(transformation, 4));
    }

    @Test
    public void testOperatorSpecificMaxParallelismRespected() {
        ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
        env.setParallelism(4);

        DataSet<Integer> input = env.fromElements(0);
        BootstrapTransformation<Integer> transformation = OperatorTransformation
            .bootstrapWith(input)
            .setMaxParallelism(1)
            .transform(new ExampleStateBootstrapFunction());

        Assert.assertEquals(
            "The parallelism of a data set should be constrained by the savepoint max parallelism",
            1,
            parallelismOfWrittenState(transformation, 4));
    }

    @Test
    public void testStreamConfig() {
        DataSet<String> input = ExecutionEnvironment.getExecutionEnvironment().fromElements("");

        BootstrapTransformation<String> transformation = OperatorTransformation
            .bootstrapWith(input)
            .keyBy(new CustomKeySelector())
            .transform(new ExampleKeyedStateBootstrapFunction());

        KeySelector<?, ?> partitioner = transformation
            .getConfig(OperatorIDGenerator.fromUid(UID), new MemoryStateBackend(), (StreamOperator) null)
            .getStatePartitioner(0, Thread.currentThread().getContextClassLoader());

        Assert.assertEquals(
            "Incorrect key selector forwarded to stream operator",
            CustomKeySelector.class,
            partitioner.getClass());
    }

    /**
     * Writes the transformation's operator subtask states and reports the resulting parallelism.
     *
     * @param transformation the transformation under test
     * @param savepointMaxParallelism value passed to {@code getMaxParallelism}; its result is
     *     forwarded to {@code writeOperatorSubtaskStates}
     * @return the parallelism configured on the data set returned by {@code writeOperatorSubtaskStates}
     */
    private static int parallelismOfWrittenState(
            BootstrapTransformation<?> transformation, int savepointMaxParallelism) {
        int maxParallelism = transformation.getMaxParallelism(savepointMaxParallelism);
        DataSet<?> result = transformation.writeOperatorSubtaskStates(
            OperatorIDGenerator.fromUid(UID), new MemoryStateBackend(), new Path(), maxParallelism);
        return getParallelism(result);
    }

    /**
     * Returns the parallelism configured on a data set.
     *
     * <p>Assumes the data set is an {@link Operator}; a {@link ClassCastException} is thrown otherwise.
     */
    private static <T> int getParallelism(DataSet<T> dataSet) {
        return ((Operator<?, ?>) dataSet).getParallelism();
    }

    /** Key selector that uses each element itself as its key. */
    private static class CustomKeySelector implements KeySelector<String, String> {
        @Override
        public String getKey(String value) throws Exception {
            return value;
        }
    }

    /** No-op broadcast bootstrap function; only its type matters to the tests. */
    private static class ExampleBroadcastStateBootstrapFunction extends BroadcastStateBootstrapFunction<Integer> {
        @Override
        public void processElement(Integer value, BroadcastStateBootstrapFunction.Context ctx) throws Exception {
        }
    }

    /**
     * No-op keyed bootstrap function; only its type matters to the tests.
     *
     * <p>Only the typed {@code processElement} is declared. The erased {@code (Object, Context)} bridge
     * is generated by javac and must not be written by hand, because it would clash on erasure.
     */
    private static class ExampleKeyedStateBootstrapFunction extends KeyedStateBootstrapFunction<String, String> {
        @Override
        public void processElement(String value, KeyedStateBootstrapFunction<String, String>.Context ctx) throws Exception {
        }
    }

    /** No-op non-keyed bootstrap function; all callbacks intentionally do nothing. */
    private static class ExampleStateBootstrapFunction extends StateBootstrapFunction<Integer> {
        @Override
        public void processElement(Integer value, StateBootstrapFunction.Context ctx) throws Exception {
        }

        @Override
        public void snapshotState(FunctionSnapshotContext context) throws Exception {
        }

        @Override
        public void initializeState(FunctionInitializationContext context) throws Exception {
        }
    }
}
