package org.apache.flink.ml.common.datastream;

import java.lang.invoke.SerializedLambda;
import java.util.Arrays;
import java.util.Collection;
import org.apache.flink.api.common.functions.FlatMapFunction;
import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
import org.apache.flink.api.common.typeinfo.PrimitiveArrayTypeInfo;
import org.apache.flink.ml.util.TestUtils;
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
import org.apache.flink.streaming.api.functions.sink.SinkFunction;
import org.apache.flink.util.CloseableIterator;
import org.apache.flink.util.Collector;
import org.apache.flink.util.NumberSequenceIterator;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;
import org.junit.experimental.runners.Enclosed;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;

@RunWith(Enclosed.class)
/* loaded from: input_file:org/apache/flink/ml/common/datastream/AllReduceImplTest.class */
public class AllReduceImplTest {
    private static final int parallelism = 4;
    private static final int chunkSize = 4096;
    private static final double TOLERANCE = 1.0E-7d;

    /* loaded from: input_file:org/apache/flink/ml/common/datastream/AllReduceImplTest$NonParameterizedTest.class */
    public static class NonParameterizedTest {
        private StreamExecutionEnvironment env;

        @Before
        public void before() {
            this.env = TestUtils.getExecutionEnvironment();
        }

        @Test
        public void testAllReduceWithMoreThanOneArray() {
            try {
                DataStreamUtils.allReduceSum(this.env.fromParallelCollection(new NumberSequenceIterator(1L, 4L), BasicTypeInfo.LONG_TYPE_INFO).flatMap(new FlatMapFunction<Long, double[]>() { // from class: org.apache.flink.ml.common.datastream.AllReduceImplTest.NonParameterizedTest.1
                    public void flatMap(Long l, Collector<double[]> collector) {
                        collector.collect(new double[100]);
                        collector.collect(new double[100]);
                    }

                    public /* bridge */ /* synthetic */ void flatMap(Object obj, Collector collector) throws Exception {
                        flatMap((Long) obj, (Collector<double[]>) collector);
                    }
                })).addSink(new SinkFunction<double[]>() { // from class: org.apache.flink.ml.common.datastream.AllReduceImplTest.NonParameterizedTest.2
                });
                this.env.execute();
                Assert.fail();
            } catch (Exception e) {
                Assert.assertEquals("The input cannot contain more than one double array.", e.getCause().getCause().getMessage());
            }
        }

        @Test
        public void testAllReduceWithDifferentLength() {
            try {
                DataStreamUtils.allReduceSum(this.env.fromParallelCollection(new NumberSequenceIterator(1L, 4L), BasicTypeInfo.LONG_TYPE_INFO).map(l -> {
                    return new double[l.intValue()];
                })).addSink(new SinkFunction<double[]>() { // from class: org.apache.flink.ml.common.datastream.AllReduceImplTest.NonParameterizedTest.3
                });
                this.env.execute();
                Assert.fail();
            } catch (Exception e) {
                Assert.assertEquals("The input double array must have same length.", e.getCause().getCause().getMessage());
            }
        }

        @Test
        public void testAllReduceWithEmptyInput() throws Exception {
            Assert.assertFalse(DataStreamUtils.allReduceSum(this.env.fromParallelCollection(new NumberSequenceIterator(1L, 4L), BasicTypeInfo.LONG_TYPE_INFO).flatMap((l, collector) -> {
            }).returns(PrimitiveArrayTypeInfo.DOUBLE_PRIMITIVE_ARRAY_TYPE_INFO)).executeAndCollect().hasNext());
        }

        private static /* synthetic */ Object $deserializeLambda$(SerializedLambda serializedLambda) {
            String implMethodName = serializedLambda.getImplMethodName();
            boolean z = -1;
            switch (implMethodName.hashCode()) {
                case -996348112:
                    if (implMethodName.equals("lambda$testAllReduceWithEmptyInput$393bc85e$1")) {
                        z = true;
                        break;
                    }
                    break;
                case 1346786099:
                    if (implMethodName.equals("lambda$testAllReduceWithDifferentLength$bb6d68d7$1")) {
                        z = false;
                        break;
                    }
                    break;
            }
            switch (z) {
                case false:
                    if (serializedLambda.getImplMethodKind() == 6 && serializedLambda.getFunctionalInterfaceClass().equals("org/apache/flink/api/common/functions/MapFunction") && serializedLambda.getFunctionalInterfaceMethodName().equals("map") && serializedLambda.getFunctionalInterfaceMethodSignature().equals("(Ljava/lang/Object;)Ljava/lang/Object;") && serializedLambda.getImplClass().equals("org/apache/flink/ml/common/datastream/AllReduceImplTest$NonParameterizedTest") && serializedLambda.getImplMethodSignature().equals("(Ljava/lang/Long;)[D")) {
                        return l -> {
                            return new double[l.intValue()];
                        };
                    }
                    break;
                case true:
                    if (serializedLambda.getImplMethodKind() == 6 && serializedLambda.getFunctionalInterfaceClass().equals("org/apache/flink/api/common/functions/FlatMapFunction") && serializedLambda.getFunctionalInterfaceMethodName().equals("flatMap") && serializedLambda.getFunctionalInterfaceMethodSignature().equals("(Ljava/lang/Object;Lorg/apache/flink/util/Collector;)V") && serializedLambda.getImplClass().equals("org/apache/flink/ml/common/datastream/AllReduceImplTest$NonParameterizedTest") && serializedLambda.getImplMethodSignature().equals("(Ljava/lang/Long;Lorg/apache/flink/util/Collector;)V")) {
                        return (l2, collector) -> {
                        };
                    }
                    break;
            }
            throw new IllegalArgumentException("Invalid lambda deserialization");
        }
    }

    @RunWith(Parameterized.class)
    /* loaded from: input_file:org/apache/flink/ml/common/datastream/AllReduceImplTest$ParameterizedTest.class */
    public static class ParameterizedTest {
        private static int numElements;
        private StreamExecutionEnvironment env;

        @Before
        public void before() {
            this.env = TestUtils.getExecutionEnvironment();
        }

        @Parameterized.Parameters
        public static Collection<Object[]> params() {
            return Arrays.asList(new Object[]{0}, new Object[]{2048}, new Object[]{8192}, new Object[]{24576});
        }

        public ParameterizedTest(int i) {
            numElements = i;
        }

        @Test
        public void testAllReduce() throws Exception {
            DataStreamUtils.allReduceSum(this.env.fromParallelCollection(new NumberSequenceIterator(1L, 4L), BasicTypeInfo.LONG_TYPE_INFO).map(l -> {
                double[] dArr = new double[numElements];
                for (int i = 0; i < dArr.length; i++) {
                    dArr[i] = i;
                }
                return dArr;
            })).addSink(new SinkFunction<double[]>() { // from class: org.apache.flink.ml.common.datastream.AllReduceImplTest.ParameterizedTest.1
                public void invoke(double[] dArr, SinkFunction.Context context) {
                    Assert.assertEquals(ParameterizedTest.numElements, dArr.length);
                    for (int i = 0; i < dArr.length; i++) {
                        Assert.assertEquals(i * AllReduceImplTest.parallelism, dArr[i], AllReduceImplTest.TOLERANCE);
                    }
                }
            });
            this.env.execute();
        }

        private static /* synthetic */ Object $deserializeLambda$(SerializedLambda serializedLambda) {
            String implMethodName = serializedLambda.getImplMethodName();
            boolean z = -1;
            switch (implMethodName.hashCode()) {
                case 1002409692:
                    if (implMethodName.equals("lambda$testAllReduce$bb6d68d7$1")) {
                        z = false;
                        break;
                    }
                    break;
            }
            switch (z) {
                case false:
                    if (serializedLambda.getImplMethodKind() == 6 && serializedLambda.getFunctionalInterfaceClass().equals("org/apache/flink/api/common/functions/MapFunction") && serializedLambda.getFunctionalInterfaceMethodName().equals("map") && serializedLambda.getFunctionalInterfaceMethodSignature().equals("(Ljava/lang/Object;)Ljava/lang/Object;") && serializedLambda.getImplClass().equals("org/apache/flink/ml/common/datastream/AllReduceImplTest$ParameterizedTest") && serializedLambda.getImplMethodSignature().equals("(Ljava/lang/Long;)[D")) {
                        return l -> {
                            double[] dArr = new double[numElements];
                            for (int i = 0; i < dArr.length; i++) {
                                dArr[i] = i;
                            }
                            return dArr;
                        };
                    }
                    break;
            }
            throw new IllegalArgumentException("Invalid lambda deserialization");
        }
    }
}
