package org.apache.beam.runners.flink;

import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.UUID;
import org.apache.beam.runners.flink.FlinkStreamingPipelineTranslator;
import org.apache.beam.runners.flink.FlinkStreamingPipelineTranslator.FlinkAutoBalancedShardKeyShardingFunction;
import org.apache.beam.sdk.Pipeline;
import org.apache.beam.sdk.coders.KvCoder;
import org.apache.beam.sdk.coders.StringUtf8Coder;
import org.apache.beam.sdk.coders.VarLongCoder;
import org.apache.beam.sdk.options.PipelineOptions;
import org.apache.beam.sdk.options.PipelineOptionsFactory;
import org.apache.beam.sdk.state.TimeDomain;
import org.apache.beam.sdk.state.TimerSpec;
import org.apache.beam.sdk.state.TimerSpecs;
import org.apache.beam.sdk.transforms.Count;
import org.apache.beam.sdk.transforms.Create;
import org.apache.beam.sdk.transforms.DoFn;
import org.apache.beam.sdk.transforms.GroupByKey;
import org.apache.beam.sdk.transforms.ParDo;
import org.apache.beam.sdk.util.SerializableUtils;
import org.apache.beam.sdk.values.KV;
import org.apache.beam.sdk.values.PCollection;
import org.apache.beam.sdk.values.ShardedKey;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Iterables;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Lists;
import org.apache.flink.runtime.jobgraph.JobGraph;
import org.apache.flink.runtime.state.KeyGroupRangeAssignment;
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
import org.hamcrest.MatcherAssert;
import org.hamcrest.Matchers;
import org.junit.Assert;
import org.junit.Test;

/**
 * Tests for {@link FlinkStreamingPipelineTranslator}.
 *
 * <p>Covers two areas:
 *
 * <ul>
 *   <li>the auto-balancing shard-key function, checking how it resolves max parallelism, how its
 *       cache is bounded, how it behaves across serialization, and whether it is stable;
 *   <li>how many job vertices are produced when a stateful {@link ParDo} follows an aggregation
 *       ({@link Count} or {@link GroupByKey}) directly, compared with having an element-wise
 *       {@link ParDo} in between.
 * </ul>
 */
public class FlinkStreamingPipelineTranslatorTest {

    /**
     * A {@link DoFn} that declares an event-time timer, which makes it stateful, and otherwise does
     * nothing.
     */
    public static class StatefulNoopDoFn<KeyT, ValueT> extends DoFn<KV<KeyT, ValueT>, Void> {

        @DoFn.TimerId("my-timer")
        private final TimerSpec myTimer = TimerSpecs.timer(TimeDomain.EVENT_TIME);

        private StatefulNoopDoFn() {
        }

        @DoFn.ProcessElement
        public void processElement() {
            // Intentionally empty: only the timer declaration matters for these tests.
        }

        @DoFn.OnTimer("my-timer")
        public void onMyTimer() {
            // Intentionally empty.
        }
    }

    /** A stateless {@link DoFn} that emits every input element unchanged. */
    public static class StatelessIdentityDoFn<KeyT, ValueT>
            extends DoFn<KV<KeyT, ValueT>, KV<KeyT, ValueT>> {

        private StatelessIdentityDoFn() {
        }

        @DoFn.ProcessElement
        public void processElement(ProcessContext processContext) {
            processContext.output(processContext.element());
        }
    }

    /**
     * A {@code maxParallelism} of -1 or 0 should resolve to Flink's default max parallelism for the
     * given operator parallelism.
     */
    @Test
    public void testAutoBalanceShardKeyResolvesMaxParallelism() {
        final int parallelism = 3;
        final int expected = KeyGroupRangeAssignment.computeDefaultMaxParallelism(parallelism);

        MatcherAssert.assertThat(
                new FlinkAutoBalancedShardKeyShardingFunction<>(
                                parallelism, -1, StringUtf8Coder.of())
                        .getMaxParallelism(),
                Matchers.equalTo(expected));
        MatcherAssert.assertThat(
                new FlinkAutoBalancedShardKeyShardingFunction<>(
                                parallelism, 0, StringUtf8Coder.of())
                        .getMaxParallelism(),
                Matchers.equalTo(expected));
    }

    /**
     * The shard-key cache is absent until the first assignment, then holds one entry per
     * destination. A serialized copy must come back without the cache.
     */
    @Test
    public void testAutoBalanceShardKeyCacheIsNotSerialized() throws Exception {
        final FlinkAutoBalancedShardKeyShardingFunction<String, String> fn =
                new FlinkAutoBalancedShardKeyShardingFunction<>(2, 2, StringUtf8Coder.of());

        Assert.assertNull(fn.getCache());

        fn.assignShardKey("target/destination1", "one", 10);
        fn.assignShardKey("target/destination2", "two", 10);
        MatcherAssert.assertThat(fn.getCache().size(), Matchers.equalTo(2));

        MatcherAssert.assertThat(SerializableUtils.clone(fn).getCache(), Matchers.nullValue());
    }

    /**
     * Two independent sharding-function instances, fed the same inputs in different orders, must
     * produce the same {@link ShardedKey} for each (destination, shard number) pair.
     */
    @Test
    public void testAutoBalanceShardKeyCacheIsStable() throws Exception {
        final int numShards = 50;

        final List<KV<String, String>> inputs = new ArrayList<>(3 * numShards * 100);
        for (int i = 0; i < numShards * 100; i++) {
            inputs.add(KV.of("target/destination/1", UUID.randomUUID().toString()));
            inputs.add(KV.of("target/destination/2", UUID.randomUUID().toString()));
            inputs.add(KV.of("target/destination/3", UUID.randomUUID().toString()));
        }

        // First pass: record the key produced for every (destination, shard number) pair.
        final FlinkAutoBalancedShardKeyShardingFunction<String, String> first =
                new FlinkAutoBalancedShardKeyShardingFunction<>(
                        numShards / 2, numShards * 2, StringUtf8Coder.of());
        final Map<KV<String, Integer>, ShardedKey<Integer>> keysByDestinationAndShard =
                new HashMap<>();
        for (KV<String, String> input : inputs) {
            final ShardedKey<Integer> key =
                    first.assignShardKey(input.getKey(), input.getValue(), numShards);
            keysByDestinationAndShard.put(KV.of(input.getKey(), key.getShardNumber()), key);
        }

        // Second pass: a fresh instance that sees the inputs in shuffled order must agree with the
        // first pass on every pair the first pass produced.
        final FlinkAutoBalancedShardKeyShardingFunction<String, String> second =
                new FlinkAutoBalancedShardKeyShardingFunction<>(
                        numShards / 2, numShards * 2, StringUtf8Coder.of());
        Collections.shuffle(inputs);
        for (KV<String, String> input : inputs) {
            final ShardedKey<Integer> key =
                    second.assignShardKey(input.getKey(), input.getValue(), numShards);
            final ShardedKey<Integer> expected =
                    keysByDestinationAndShard.get(KV.of(input.getKey(), key.getShardNumber()));
            if (expected != null) {
                MatcherAssert.assertThat(key, Matchers.equalTo(expected));
            }
        }
    }

    /** After 200 distinct destinations are assigned, the cache must hold exactly 100 entries. */
    @Test
    public void testAutoBalanceShardKeyCacheMaxSize() throws Exception {
        final FlinkAutoBalancedShardKeyShardingFunction<String, String> fn =
                new FlinkAutoBalancedShardKeyShardingFunction<>(2, 2, StringUtf8Coder.of());

        for (int i = 0; i < 200; i++) {
            fn.assignShardKey(UUID.randomUUID().toString(), "one", 2);
        }

        MatcherAssert.assertThat(fn.getCache().size(), Matchers.equalTo(100));
    }

    /**
     * Placing an element-wise {@link ParDo} between {@link Count#perElement()} and a stateful
     * {@link ParDo} should add exactly one job vertex.
     */
    @Test
    public void testStatefulParDoAfterCombineChaining() {
        // Built in the same order as before: unstable partitioning first, then stable.
        final JobGraph unstable = getStatefulParDoAfterCombineChainingJobGraph(false);
        final JobGraph stable = getStatefulParDoAfterCombineChainingJobGraph(true);
        Assert.assertEquals(
                1L,
                Iterables.size(unstable.getVertices()) - Iterables.size(stable.getVertices()));
    }

    /**
     * Translates {@code Create -> Count.perElement() -> [identity ParDo] -> stateful ParDo} and
     * returns the resulting job graph.
     *
     * @param stablePartitioning if {@code false}, a {@link StatelessIdentityDoFn} is inserted
     *     between the aggregation and the stateful {@link ParDo}
     * @return the job graph of the translated pipeline
     */
    private JobGraph getStatefulParDoAfterCombineChainingJobGraph(boolean stablePartitioning) {
        final StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
        final FlinkStreamingPipelineTranslator translator =
                new FlinkStreamingPipelineTranslator(env, PipelineOptionsFactory.create(), true);
        final Pipeline pipeline = createFlinkPipeline();

        PCollection<KV<String, Long>> aggregate =
                pipeline
                        .apply(Create.of("foo", "bar").withCoder(StringUtf8Coder.of()))
                        .apply(Count.perElement());
        if (!stablePartitioning) {
            aggregate = aggregate.apply(ParDo.of(new StatelessIdentityDoFn<>()));
        }
        aggregate.apply(ParDo.of(new StatefulNoopDoFn<>()));

        translator.translate(pipeline);
        return env.getStreamGraph().getJobGraph();
    }

    /**
     * Placing an element-wise {@link ParDo} between {@link GroupByKey} and a stateful
     * {@link ParDo} should add exactly one job vertex.
     */
    @Test
    public void testStatefulParDoAfterGroupByKeyChaining() {
        // Built in the same order as before: unstable partitioning first, then stable.
        final JobGraph unstable = getStatefulParDoAfterGroupByKeyChainingJobGraph(false);
        final JobGraph stable = getStatefulParDoAfterGroupByKeyChainingJobGraph(true);
        Assert.assertEquals(
                1L,
                Iterables.size(unstable.getVertices()) - Iterables.size(stable.getVertices()));
    }

    /**
     * Translates {@code Create -> GroupByKey -> [identity ParDo] -> stateful ParDo} and returns
     * the resulting job graph.
     *
     * @param stablePartitioning if {@code false}, a {@link StatelessIdentityDoFn} is inserted
     *     between the {@link GroupByKey} and the stateful {@link ParDo}
     * @return the job graph of the translated pipeline
     */
    private JobGraph getStatefulParDoAfterGroupByKeyChainingJobGraph(boolean stablePartitioning) {
        final StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
        final FlinkStreamingPipelineTranslator translator =
                new FlinkStreamingPipelineTranslator(env, PipelineOptionsFactory.create(), true);
        final Pipeline pipeline = createFlinkPipeline();

        PCollection<KV<String, Iterable<Long>>> aggregate =
                pipeline
                        .apply(
                                Create.of(KV.of("foo", 1L), KV.of("bar", 1L))
                                        .withCoder(
                                                KvCoder.of(
                                                        StringUtf8Coder.of(), VarLongCoder.of())))
                        .apply(GroupByKey.create());
        if (!stablePartitioning) {
            aggregate = aggregate.apply(ParDo.of(new StatelessIdentityDoFn<>()));
        }
        aggregate.apply(ParDo.of(new StatefulNoopDoFn<>()));

        translator.translate(pipeline);
        return env.getStreamGraph().getJobGraph();
    }

    /**
     * Creates an empty {@link Pipeline} whose options name {@link FlinkRunner} as the runner.
     *
     * @return a new pipeline
     */
    private static Pipeline createFlinkPipeline() {
        final PipelineOptions options = PipelineOptionsFactory.create();
        options.setRunner(FlinkRunner.class);
        return Pipeline.create(options);
    }
}
