/*
 * Decompiled with CFR 0.152.
 */
package org.apache.flink.state.api;

import org.apache.flink.api.java.DataSet;
import org.apache.flink.api.java.ExecutionEnvironment;
import org.apache.flink.api.java.functions.KeySelector;
import org.apache.flink.api.java.operators.DataSource;
import org.apache.flink.api.java.operators.MapPartitionOperator;
import org.apache.flink.api.java.operators.Operator;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.core.fs.Path;
import org.apache.flink.runtime.state.FunctionInitializationContext;
import org.apache.flink.runtime.state.FunctionSnapshotContext;
import org.apache.flink.runtime.state.StateBackend;
import org.apache.flink.runtime.state.memory.MemoryStateBackend;
import org.apache.flink.state.api.BootstrapTransformation;
import org.apache.flink.state.api.OperatorTransformation;
import org.apache.flink.state.api.functions.BroadcastStateBootstrapFunction;
import org.apache.flink.state.api.functions.KeyedStateBootstrapFunction;
import org.apache.flink.state.api.functions.StateBootstrapFunction;
import org.apache.flink.state.api.runtime.OperatorIDGenerator;
import org.apache.flink.streaming.api.graph.StreamConfig;
import org.apache.flink.test.util.AbstractTestBase;
import org.junit.Assert;
import org.junit.Test;

/**
 * Tests for {@link BootstrapTransformation}: the parallelism chosen when writing operator subtask
 * states, and the key selector that ends up in the generated {@link StreamConfig}.
 */
public class BootstrapTransformationTest extends AbstractTestBase {

    /** Uid handed to {@link OperatorIDGenerator#fromUid(String)} in every test. */
    private static final String OPERATOR_UID = "uid";

    /** A broadcast bootstrap function must be written at parallelism 1, whatever the environment says. */
    @Test
    public void testBroadcastStateTransformationParallelism() {
        ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
        env.setParallelism(10);

        DataSource<Integer> input = env.fromElements(0);
        BootstrapTransformation<Integer> transformation =
                OperatorTransformation.bootstrapWith(input)
                        .transform(new ExampleBroadcastStateBootstrapFunction());

        Assert.assertEquals(
                "Broadcast transformations should always be run at parallelism 1",
                1,
                writeAndGetParallelism(transformation, 4));
    }

    /** With environment parallelism (4) below the max parallelism (10), no explicit parallelism is set. */
    @Test
    public void testDefaultParallelismRespectedWhenLessThanMaxParallelism() {
        ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
        env.setParallelism(4);

        DataSource<Integer> input = env.fromElements(0);
        BootstrapTransformation<Integer> transformation =
                OperatorTransformation.bootstrapWith(input)
                        .transform(new ExampleStateBootstrapFunction());

        // NOTE(review): -1 is presumably Flink's "use the default parallelism" sentinel
        // (ExecutionConfig.PARALLELISM_DEFAULT) - confirm and consider referencing the constant.
        Assert.assertEquals(
                "The parallelism of a data set should not change when less than the max parallelism of the savepoint",
                -1,
                writeAndGetParallelism(transformation, 10));
    }

    /** With environment parallelism (10) above the max parallelism (4), the result is capped at 4. */
    @Test
    public void testMaxParallelismRespected() {
        ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
        env.setParallelism(10);

        DataSource<Integer> input = env.fromElements(0);
        BootstrapTransformation<Integer> transformation =
                OperatorTransformation.bootstrapWith(input)
                        .transform(new ExampleStateBootstrapFunction());

        Assert.assertEquals(
                "The parallelism of a data set should be constrained by the savepoint max parallelism",
                4,
                writeAndGetParallelism(transformation, 4));
    }

    /** A max parallelism of 1 set on the operator itself wins over the value of 4 passed by the caller. */
    @Test
    public void testOperatorSpecificMaxParallelismRespected() {
        ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
        env.setParallelism(4);

        DataSource<Integer> input = env.fromElements(0);
        BootstrapTransformation<Integer> transformation =
                OperatorTransformation.bootstrapWith(input)
                        .setMaxParallelism(1)
                        .transform(new ExampleStateBootstrapFunction());

        Assert.assertEquals(
                "The parallelism of a data set should be constrained by the operator-specific max parallelism",
                1,
                writeAndGetParallelism(transformation, 4));
    }

    /** The key selector given to {@code keyBy} must be the state partitioner stored in the stream config. */
    @Test
    public void testStreamConfig() {
        ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();

        DataSource<String> input = env.fromElements("");
        BootstrapTransformation<String> transformation =
                OperatorTransformation.bootstrapWith(input)
                        .keyBy(new CustomKeySelector())
                        .transform(new ExampleKeyedStateBootstrapFunction());

        // The last argument (the stream operator) is deliberately null, as in the original test.
        StreamConfig config =
                transformation.getConfig(
                        OperatorIDGenerator.fromUid(OPERATOR_UID),
                        new MemoryStateBackend(),
                        new Configuration(),
                        null);
        KeySelector<?, ?> selector =
                config.getStatePartitioner(0, Thread.currentThread().getContextClassLoader());

        Assert.assertEquals(
                "Incorrect key selector forwarded to stream operator",
                CustomKeySelector.class,
                selector.getClass());
    }

    /**
     * Resolves the transformation's max parallelism from {@code globalMaxParallelism}, writes the
     * operator subtask states with it, and returns the parallelism of the resulting operator.
     *
     * @param transformation the bootstrap transformation under test
     * @param globalMaxParallelism the value passed to {@code getMaxParallelism}
     * @return the parallelism reported by the operator that writes the subtask states
     */
    private static <T> int writeAndGetParallelism(
            BootstrapTransformation<T> transformation, int globalMaxParallelism) {
        int maxParallelism = transformation.getMaxParallelism(globalMaxParallelism);
        MapPartitionOperator<?, ?> result =
                transformation.writeOperatorSubtaskStates(
                        OperatorIDGenerator.fromUid(OPERATOR_UID),
                        new MemoryStateBackend(),
                        new Path(),
                        maxParallelism);
        return getParallelism(result);
    }

    /**
     * Returns the parallelism of a data set by viewing it as an {@link Operator}.
     *
     * @param dataSet a data set that is expected to be an {@link Operator} instance
     * @return the value of {@link Operator#getParallelism()}
     */
    private static <T> int getParallelism(DataSet<T> dataSet) {
        return ((Operator<T, ?>) dataSet).getParallelism();
    }

    /** Keyed bootstrap function that ignores every element. */
    private static class ExampleKeyedStateBootstrapFunction
            extends KeyedStateBootstrapFunction<String, String> {

        @Override
        public void processElement(String value, KeyedStateBootstrapFunction.Context ctx) throws Exception {
        }
    }

    /** Non-keyed bootstrap function whose element and state hooks are all no-ops. */
    private static class ExampleStateBootstrapFunction extends StateBootstrapFunction<Integer> {

        @Override
        public void processElement(Integer value, StateBootstrapFunction.Context ctx) throws Exception {
        }

        @Override
        public void snapshotState(FunctionSnapshotContext context) throws Exception {
        }

        @Override
        public void initializeState(FunctionInitializationContext context) throws Exception {
        }
    }

    /** Broadcast bootstrap function that ignores every element. */
    private static class ExampleBroadcastStateBootstrapFunction
            extends BroadcastStateBootstrapFunction<Integer> {

        @Override
        public void processElement(Integer value, BroadcastStateBootstrapFunction.Context ctx) throws Exception {
        }
    }

    /** Key selector that uses the element itself as the key. */
    private static class CustomKeySelector implements KeySelector<String, String> {

        @Override
        public String getKey(String value) throws Exception {
            return value;
        }
    }
}

