package org.apache.beam.runners.flink.batch;

import java.lang.invoke.SerializedLambda;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.UUID;
import org.apache.beam.runners.flink.FlinkPipelineOptions;
import org.apache.beam.runners.flink.FlinkTestPipeline;
import org.apache.beam.sdk.testing.PAssert;
import org.apache.beam.sdk.transforms.Create;
import org.apache.beam.sdk.transforms.DoFn;
import org.apache.beam.sdk.transforms.ParDo;
import org.apache.beam.sdk.transforms.Reshuffle;
import org.junit.Assert;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.Timeout;

/**
 * Tests {@link Reshuffle#viaRandomKey()} on the Flink batch runner across several consecutive
 * stages.
 *
 * <p>Each stage tags elements with the id of the bundle that processed them. The test then checks
 * that every upstream bundle id was connected to every downstream bundle id by each reshuffle.
 */
public class ReshuffleTest {

    /** Parallelism configured on the test pipeline. */
    private static final int PARALLELISM = 3;

    /** Number of input elements fed into the pipeline. */
    private static final int NUM_ELEMENTS = 10_000;

    @Rule
    public transient Timeout globalTimeout = Timeout.seconds(1200);

    /**
     * Appends {@code "@<bundle id>"} to every element. The bundle id is a random UUID that is
     * regenerated in {@link StartBundle}, so all elements of one bundle share the same id.
     */
    private static class WithBundleIdFn extends DoFn<String, String> {
        private String uuid;

        @StartBundle
        public void startBundle() {
            uuid = UUID.randomUUID().toString();
        }

        @ProcessElement
        public void processElement(ProcessContext c) {
            c.output(c.element() + "@" + uuid);
        }
    }

    @Test
    public void testEqualDistributionOnReshuffleAcrossMultipleStages() {
        FlinkTestPipeline pipeline = FlinkTestPipeline.createForBatch();
        pipeline.getOptions().as(FlinkPipelineOptions.class).setParallelism(PARALLELISM);

        List<String> elements = new ArrayList<>(NUM_ELEMENTS);
        for (int i = 0; i < NUM_ELEMENTS; i++) {
            elements.add("el_" + i);
        }

        // Three bundle-tagging stages separated by two reshuffles. Every output element therefore
        // has the form "el_<i>@<stage1 bundle>@<stage2 bundle>@<stage3 bundle>".
        PAssert.that(
                        pipeline.apply(Create.of(elements))
                                .apply(ParDo.of(new WithBundleIdFn()))
                                .apply(Reshuffle.viaRandomKey())
                                .apply(ParDo.of(new WithBundleIdFn()))
                                .apply(Reshuffle.viaRandomKey())
                                .apply(ParDo.of(new WithBundleIdFn())))
                .satisfies(
                        output -> {
                            // Counts how often each "<upstream bundle>-><downstream bundle>" hop occurred.
                            Map<String, Integer> hopCounts = new HashMap<>();
                            for (String tagged : output) {
                                String[] parts = tagged.split("@");
                                Assert.assertEquals(4, parts.length);
                                // Hop across the first reshuffle (stage 1 -> stage 2).
                                hopCounts.merge(
                                        String.join("->", Arrays.copyOfRange(parts, 1, 3)),
                                        1,
                                        Integer::sum);
                                // Hop across the second reshuffle (stage 2 -> stage 3).
                                hopCounts.merge(
                                        String.join("->", Arrays.copyOfRange(parts, 2, 4)),
                                        1,
                                        Integer::sum);
                            }
                            // Every element contributes exactly two hops.
                            Assert.assertEquals(
                                    2 * NUM_ELEMENTS,
                                    hopCounts.values().stream().mapToInt(Integer::intValue).sum());
                            // Expects PARALLELISM bundle ids per stage, with each reshuffle
                            // connecting every upstream id to every downstream id.
                            Assert.assertEquals(2 * PARALLELISM * PARALLELISM, hopCounts.size());
                            return null;
                        });

        pipeline.run().waitUntilFinish();
    }
}
