/*
 * Decompiled with CFR 0.152.
 */
package org.apache.beam.runners.flink.batch;

import java.io.Serializable;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.UUID;
import org.apache.beam.runners.flink.FlinkPipelineOptions;
import org.apache.beam.runners.flink.FlinkTestPipeline;
import org.apache.beam.sdk.testing.PAssert;
import org.apache.beam.sdk.transforms.Create;
import org.apache.beam.sdk.transforms.DoFn;
import org.apache.beam.sdk.transforms.PTransform;
import org.apache.beam.sdk.transforms.ParDo;
import org.apache.beam.sdk.transforms.Reshuffle;
import org.apache.beam.sdk.transforms.SerializableFunction;
import org.apache.beam.sdk.values.PCollection;
import org.checkerframework.checker.initialization.qual.Initialized;
import org.checkerframework.checker.nullness.qual.NonNull;
import org.checkerframework.checker.nullness.qual.UnknownKeyFor;
import org.junit.Assert;
import org.junit.Test;

/**
 * Tests for {@link Reshuffle#viaRandomKey()} on the Flink batch runner: after each reshuffle, the
 * elements produced by every upstream bundle should reach every downstream bundle.
 */
public class ReshuffleTest {

    /**
     * Tags every element with the id of the bundle that processed it, once before the first
     * reshuffle and once after each reshuffle. It then checks that, for every reshuffle, each
     * (upstream bundle id, downstream bundle id) pair is observed.
     */
    @Test
    public void testEqualDistributionOnReshuffleAcrossMultipleStages() {
        final int numElements = 10000;
        final int parallelism = 3;
        final int numReshuffles = 2;

        FlinkTestPipeline p = FlinkTestPipeline.createForBatch();
        p.getOptions().as(FlinkPipelineOptions.class).setParallelism(parallelism);

        List<String> input = new ArrayList<>(numElements);
        for (int i = 0; i < numElements; ++i) {
            input.add("el_" + i);
        }

        // The first tagging stage appends the initial bundle id. Each reshuffle is followed by
        // another tagging stage, so every output element looks like "el_i@id0@id1@...@idN"
        // with N == numReshuffles. Transforms get explicit names so they are unique in the graph.
        PCollection<String> result =
            p.apply("Create", Create.of(input))
                .apply("TagBundleId0", ParDo.of(new WithBundleIdFn()));
        for (int stage = 1; stage <= numReshuffles; ++stage) {
            result =
                result
                    .apply("Reshuffle" + stage, Reshuffle.viaRandomKey())
                    .apply("TagBundleId" + stage, ParDo.of(new WithBundleIdFn()));
        }

        PAssert.that(result)
            .satisfies(
                it -> {
                    // Counts how many elements moved along each "fromBundle->toBundle" edge.
                    Map<String, Integer> transitionCounts = new HashMap<>();
                    for (String item : it) {
                        String[] parts = item.split("@");
                        // Element value plus one bundle id per tagging stage.
                        Assert.assertEquals(numReshuffles + 2, parts.length);
                        for (int reshuffle = 1; reshuffle <= numReshuffles; ++reshuffle) {
                            String edge =
                                String.join("->", Arrays.copyOfRange(parts, reshuffle, reshuffle + 2));
                            transitionCounts.merge(edge, 1, Integer::sum);
                        }
                    }
                    // Every element crosses each reshuffle exactly once.
                    Assert.assertEquals(
                        numElements * numReshuffles,
                        transitionCounts.values().stream().mapToInt(Integer::intValue).sum());
                    // Assumes each stage runs as one bundle per parallel subtask, i.e.
                    // `parallelism` distinct ids per stage. Every upstream/downstream pair must
                    // then appear for each reshuffle.
                    Assert.assertEquals(
                        numReshuffles * parallelism * parallelism, transitionCounts.size());
                    return null;
                });

        p.run().waitUntilFinish();
    }

    /**
     * Appends {@code "@" + bundleId} to every element. A fresh random id is drawn at the start of
     * each bundle, so the suffix identifies the bundle that processed the element.
     */
    private static class WithBundleIdFn extends DoFn<String, String> {
        /** Random id of the bundle currently being processed; reassigned in {@link #startBundle}. */
        private String uuid;

        @DoFn.StartBundle
        public void startBundle() {
            uuid = UUID.randomUUID().toString();
        }

        @DoFn.ProcessElement
        public void processElement(ProcessContext ctx) {
            ctx.output(ctx.element() + "@" + uuid);
        }
    }
}

