package org.apache.flink.ml.common.broadcast;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.function.Function;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import org.apache.flink.api.common.functions.AbstractRichFunction;
import org.apache.flink.api.common.restartstrategy.RestartStrategies;
import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.configuration.RestOptions;
import org.apache.flink.iteration.config.IterationOptions;
import org.apache.flink.ml.common.broadcast.operator.TestOneInputOp;
import org.apache.flink.ml.common.broadcast.operator.TestTwoInputOp;
import org.apache.flink.ml.util.TestUtils;
import org.apache.flink.runtime.jobgraph.JobGraph;
import org.apache.flink.runtime.minicluster.MiniCluster;
import org.apache.flink.runtime.minicluster.MiniClusterConfiguration;
import org.apache.flink.streaming.api.CheckpointingMode;
import org.apache.flink.streaming.api.datastream.DataStream;
import org.apache.flink.streaming.api.datastream.DataStreamSource;
import org.apache.flink.streaming.api.environment.ExecutionCheckpointingOptions;
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.TemporaryFolder;

/**
 * Tests {@link BroadcastUtils#withBroadcastStream} by running small jobs on a local
 * {@link MiniCluster}.
 *
 * <p>Each job registers two broadcast streams under {@link #BROADCAST_NAMES} and wraps either a
 * one-input or a two-input test operator with them. The operator output is sent to a
 * {@code TestSink} constructed with the expected records.
 */
public class BroadcastUtilsTest {

    /** Number of records requested from every {@code TestSource} instance. */
    private static final int NUM_RECORDS_PER_PARTITION = 10;

    /** Number of task managers started in the mini cluster. */
    private static final int NUM_TM = 2;

    /** Number of slots offered by each task manager. */
    private static final int NUM_SLOT = 2;

    /** Job parallelism: one subtask per available slot. */
    private static final int PARALLELISM = NUM_TM * NUM_SLOT;

    /** Size of the expected integer range {@code [0, NUM_TOTAL_RECORDS)}. */
    private static final int NUM_TOTAL_RECORDS = NUM_RECORDS_PER_PARTITION * PARALLELISM;

    /** Names under which the two broadcast streams are registered. */
    private static final String[] BROADCAST_NAMES = {"source1", "source2"};

    /** The integers {@code 0 .. NUM_TOTAL_RECORDS - 1}, handed to the test operators per broadcast name. */
    private static final List<Integer> BROADCAST_INPUT =
            IntStream.range(0, NUM_TOTAL_RECORDS).boxed().collect(Collectors.toList());

    /** Provides a fresh directory for the iteration data cache of every test. */
    @Rule
    public TemporaryFolder tempFolder = new TemporaryFolder();

    /**
     * Builds the configuration of the mini cluster used by the tests.
     *
     * @return a configuration with {@link #NUM_TM} task managers of {@link #NUM_SLOT} slots each
     * @throws IOException if the temporary data cache folder cannot be created
     */
    private MiniClusterConfiguration createMiniClusterConfiguration() throws IOException {
        Configuration configuration = new Configuration();
        // NOTE(review): the REST port is fixed, so two clusters built from this configuration
        // presumably cannot run at the same time on one host - confirm before parallelizing tests.
        configuration.set(RestOptions.PORT, 18082);
        configuration.set(
                IterationOptions.DATA_CACHE_PATH,
                "file://" + this.tempFolder.newFolder().getAbsolutePath());
        configuration.set(ExecutionCheckpointingOptions.ENABLE_CHECKPOINTS_AFTER_TASKS_FINISH, true);
        return new MiniClusterConfiguration.Builder()
                .setConfiguration(configuration)
                .setNumTaskManagers(NUM_TM)
                .setNumSlotsPerTaskManager(NUM_SLOT)
                .build();
    }

    /** Runs the job whose wrapped operator has a single non-broadcast input. */
    @Test
    public void testOneInputGraph() throws Exception {
        runJob(1);
    }

    /** Runs the job whose wrapped operator has two non-broadcast inputs. */
    @Test
    public void testTwoInputGraph() throws Exception {
        runJob(2);
    }

    /**
     * Starts a mini cluster, executes the test job and blocks until it terminates.
     *
     * @param numNonBroadcastInputs number of non-broadcast inputs of the wrapped operator
     * @throws Exception if the cluster cannot be started or closed, or the job fails
     */
    private void runJob(int numNonBroadcastInputs) throws Exception {
        // try-with-resources closes the cluster on success and on failure, and attaches any
        // exception thrown by close() as suppressed to the primary failure.
        try (MiniCluster miniCluster = new MiniCluster(createMiniClusterConfiguration())) {
            miniCluster.start();
            miniCluster.executeJobBlocking(getJobGraph(numNonBroadcastInputs));
        }
    }

    /**
     * Builds the job graph of the test job.
     *
     * @param i number of non-broadcast inputs of the wrapped operator (1 or 2)
     * @return the job graph to submit to the mini cluster
     * @throws IllegalArgumentException if {@code i} is not a supported input count
     */
    private JobGraph getJobGraph(int i) {
        StreamExecutionEnvironment env = TestUtils.getExecutionEnvironment();
        env.setRestartStrategy(RestartStrategies.fallBackRestart());
        env.enableCheckpointing(500L, CheckpointingMode.EXACTLY_ONCE);
        env.setParallelism(PARALLELISM);

        DataStreamSource<?> firstSource = env.addSource(new TestSource(NUM_RECORDS_PER_PARTITION));
        DataStreamSource<?> secondSource = env.addSource(new TestSource(NUM_RECORDS_PER_PARTITION));

        HashMap<String, DataStream<?>> broadcastStreams = new HashMap<>();
        broadcastStreams.put(BROADCAST_NAMES[0], firstSource);
        broadcastStreams.put(BROADCAST_NAMES[1], secondSource);

        // The first source is used both as a broadcast stream and as the first regular input;
        // every additional regular input gets a source of its own.
        List<DataStream<?>> inputs = new ArrayList<>(i);
        inputs.add(firstSource);
        for (int extra = 1; extra < i; extra++) {
            inputs.add(env.addSource(new TestSource(NUM_RECORDS_PER_PARTITION)));
        }

        DataStream<Integer> result =
                BroadcastUtils.withBroadcastStream(inputs, broadcastStreams, getFunc(i));

        // Expected records: every value of the range once per regular input.
        List<Integer> expected = new ArrayList<>(NUM_TOTAL_RECORDS * i);
        for (int value = 0; value < NUM_TOTAL_RECORDS; value++) {
            for (int input = 0; input < i; input++) {
                expected.add(value);
            }
        }
        result.addSink(new TestSink(expected)).setParallelism(1);
        return env.getStreamGraph().getJobGraph();
    }

    /**
     * Creates the function that wires the regular inputs into the test operator.
     *
     * @param i number of non-broadcast inputs; only 1 and 2 are supported
     * @return a function applying a {@code TestOneInputOp} to the first input, or a
     *     {@code TestTwoInputOp} to the first two connected inputs, named {@code "broadcast"}
     * @throws IllegalArgumentException if {@code i} is neither 1 nor 2
     */
    // Raw DataStream is used on purpose: the inputs arrive as DataStream<?> and the operator
    // type parameters are not expressible through the wildcard.
    @SuppressWarnings({"unchecked", "rawtypes"})
    private static Function<List<DataStream<?>>, DataStream<Integer>> getFunc(int i) {
        switch (i) {
            case 1:
                return inputs -> {
                    DataStream input = inputs.get(0);
                    return input.transform(
                                    "one-input",
                                    BasicTypeInfo.INT_TYPE_INFO,
                                    new TestOneInputOp(
                                            new AbstractRichFunction() {},
                                            BROADCAST_NAMES,
                                            Arrays.asList(BROADCAST_INPUT, BROADCAST_INPUT)))
                            .name("broadcast");
                };
            case 2:
                return inputs -> {
                    DataStream first = inputs.get(0);
                    DataStream second = inputs.get(1);
                    return first.connect(second)
                            .transform(
                                    "two-input",
                                    BasicTypeInfo.INT_TYPE_INFO,
                                    new TestTwoInputOp(
                                            new AbstractRichFunction() {},
                                            BROADCAST_NAMES,
                                            Arrays.asList(BROADCAST_INPUT, BROADCAST_INPUT)))
                            .name("broadcast");
                };
            default:
                // Fail fast instead of handing a null function to BroadcastUtils.
                throw new IllegalArgumentException(
                        "Unsupported number of non-broadcast inputs: " + i);
        }
    }
}
