package org.apache.flink.ml.common.datastream;

import java.lang.invoke.SerializedLambda;
import java.util.Arrays;
import java.util.Iterator;
import java.util.List;
import org.apache.commons.collections.IteratorUtils;
import org.apache.flink.api.common.functions.AggregateFunction;
import org.apache.flink.api.common.functions.CoGroupFunction;
import org.apache.flink.api.common.functions.RichMapPartitionFunction;
import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
import org.apache.flink.api.common.typeinfo.Types;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.ml.util.TestUtils;
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
import org.apache.flink.util.Collector;
import org.apache.flink.util.NumberSequenceIterator;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;

/* loaded from: input_file:org/apache/flink/ml/common/datastream/DataStreamUtilsTest.class */
public class DataStreamUtilsTest {
    private StreamExecutionEnvironment env;

    /* loaded from: input_file:org/apache/flink/ml/common/datastream/DataStreamUtilsTest$TestAggregateFunc.class */
    /**
     * Aggregate function used by the aggregate tests: the accumulator is a running {@code Long}
     * sum of the inputs and the result is that sum rendered as a decimal string.
     */
    private static class TestAggregateFunc implements AggregateFunction<Long, Long, String> {
        private TestAggregateFunc() {
        }

        /**
         * Returns the initial accumulator, {@code 0}.
         *
         * <p>The method must be named {@code createAccumulator} to implement the
         * {@link AggregateFunction} contract; a mangled name leaves the interface method
         * unimplemented and makes {@code @Override} invalid.
         */
        @Override
        public Long createAccumulator() {
            return 0L;
        }

        /** Adds one input value to the accumulator and returns the new accumulator. */
        @Override
        public Long add(Long value, Long accumulator) {
            return value + accumulator;
        }

        /** Renders the accumulated sum as a string. */
        @Override
        public String getResult(Long accumulator) {
            return String.valueOf(accumulator);
        }

        /** Combines two partial accumulators by adding them. */
        @Override
        public Long merge(Long left, Long right) {
            return left + right;
        }
    }

    /**
     * Variant of {@link TestAggregateFunc} whose initial accumulator is {@code 1} rather than
     * {@code 0}, so every accumulator that gets created adds an extra 1 to the merged total.
     */
    private static class TestAggregateFuncWithNonNeutralInitialAccumulator extends TestAggregateFunc {
        private TestAggregateFuncWithNonNeutralInitialAccumulator() {
            super();
        }

        /** Returns the non-neutral initial accumulator, {@code 1}. */
        @Override
        public Long createAccumulator() {
            return 1L;
        }
    }

    /* loaded from: input_file:org/apache/flink/ml/common/datastream/DataStreamUtilsTest$TestMapPartitionFunc.class */
    private static class TestMapPartitionFunc extends RichMapPartitionFunction<Long, Integer> {
        private TestMapPartitionFunc() {
        }

        public void mapPartition(Iterable<Long> iterable, Collector<Integer> collector) {
            Assert.assertNotNull(getRuntimeContext());
            int i = 0;
            Iterator<Long> it = iterable.iterator();
            while (it.hasNext()) {
                it.next().longValue();
                i++;
            }
            collector.collect(Integer.valueOf(i));
        }
    }

    /**
     * Stores the environment returned by {@code TestUtils.getExecutionEnvironment()} in
     * {@code env} before each test method runs.
     */
    @Before
    public void before() {
        env = TestUtils.getExecutionEnvironment();
    }

    /**
     * Co-groups a stream of {@code (key, int)} tuples with a stream of {@code (key, double)}
     * tuples on their first field and checks that each group yields the sum of all second
     * fields from both sides.
     *
     * <p>Expected sums by key: 1 -> 1 + 1.5 + 2.5 = 5.0, 2 -> 2.0, 3 -> 3 + 3.5 = 6.5,
     * 5 -> 5.5.
     *
     * @throws Exception if executing the job or collecting its output fails
     */
    @Test
    public void testCoGroupWithSingleParallelism() throws Exception {
        CoGroupFunction<Tuple2<Integer, Integer>, Tuple2<Integer, Double>, Double> sumSecondFields =
                new CoGroupFunction<Tuple2<Integer, Integer>, Tuple2<Integer, Double>, Double>() {
                    @Override
                    public void coGroup(
                            Iterable<Tuple2<Integer, Integer>> left,
                            Iterable<Tuple2<Integer, Double>> right,
                            Collector<Double> out) {
                        double sum = 0.0d;
                        // Walk each side exactly once; re-creating the iterator on every
                        // hasNext() check would never advance past the first element.
                        for (Tuple2<Integer, Integer> tuple : left) {
                            sum += tuple.f1;
                        }
                        for (Tuple2<Integer, Double> tuple : right) {
                            sum += tuple.f1;
                        }
                        out.collect(sum);
                    }
                };

        // IteratorUtils.toList returns a raw List; give it its element type so the stream
        // below is typed.
        @SuppressWarnings("unchecked")
        List<Double> results =
                IteratorUtils.toList(
                        DataStreamUtils.coGroup(
                                        this.env.fromCollection(
                                                Arrays.asList(
                                                        Tuple2.of(1, 1), Tuple2.of(2, 2), Tuple2.of(3, 3))),
                                        this.env.fromCollection(
                                                Arrays.asList(
                                                        Tuple2.of(1, 1.5d),
                                                        Tuple2.of(5, 5.5d),
                                                        Tuple2.of(3, 3.5d),
                                                        Tuple2.of(1, 2.5d))),
                                        tuple -> tuple.f0,
                                        tuple -> tuple.f0,
                                        BasicTypeInfo.DOUBLE_TYPE_INFO,
                                        sumSecondFields)
                                .executeAndCollect());

        double[] actual = results.stream().mapToDouble(Double::doubleValue).toArray();
        Assert.assertArrayEquals(new double[]{5.0d, 2.0d, 6.5d, 5.5d}, actual, 1.0E-5d);
    }

    /**
     * Co-groups the sequences 0..10 and 6..16 on {@code value / 2} and checks that each group
     * yields the sum of every value from both sides.
     *
     * <p>The collected sums are sorted before comparison, so the assertion does not depend on
     * the order in which groups are emitted.
     *
     * @throws Exception if executing the job or collecting its output fails
     */
    @Test
    public void testCoGroupWithMultiParallelism() throws Exception {
        CoGroupFunction<Long, Long, Long> sumBothSides =
                new CoGroupFunction<Long, Long, Long>() {
                    @Override
                    public void coGroup(Iterable<Long> left, Iterable<Long> right, Collector<Long> out) {
                        long sum = 0;
                        for (long value : left) {
                            sum += value;
                        }
                        for (long value : right) {
                            sum += value;
                        }
                        out.collect(sum);
                    }
                };

        // IteratorUtils.toList returns a raw List; give it its element type so the stream
        // below is typed.
        @SuppressWarnings("unchecked")
        List<Long> results =
                IteratorUtils.toList(
                        DataStreamUtils.coGroup(
                                        this.env.fromParallelCollection(
                                                new NumberSequenceIterator(0L, 10L), Types.LONG),
                                        this.env.fromParallelCollection(
                                                new NumberSequenceIterator(6L, 16L), Types.LONG),
                                        value -> value / 2,
                                        value -> value / 2,
                                        BasicTypeInfo.LONG_TYPE_INFO,
                                        sumBothSides)
                                .executeAndCollect());

        long[] actual = results.stream().mapToLong(Long::longValue).toArray();
        Arrays.sort(actual);
        Assert.assertArrayEquals(new long[]{1, 5, 9, 16, 25, 26, 29, 31, 34}, actual);
    }

    /**
     * Runs {@code TestMapPartitionFunc} over the sequence 0..19 and checks the emitted counts.
     *
     * <p>The expected output is four counts of five. NOTE(review): this presumably relies on
     * the test environment running with parallelism 4 — confirm against {@code TestUtils}.
     *
     * @throws Exception if executing the job or collecting its output fails
     */
    @Test
    public void testMapPartition() throws Exception {
        // IteratorUtils.toList returns a raw List; give it its element type so the stream
        // below is typed.
        @SuppressWarnings("unchecked")
        List<Integer> counts =
                IteratorUtils.toList(
                        DataStreamUtils.mapPartition(
                                        this.env.fromParallelCollection(
                                                new NumberSequenceIterator(0L, 19L), Types.LONG),
                                        new TestMapPartitionFunc())
                                .executeAndCollect());

        int[] actual = counts.stream().mapToInt(Integer::intValue).toArray();
        Assert.assertArrayEquals(new int[]{5, 5, 5, 5}, actual);
    }

    /**
     * Reduces the sequence 0..19 with {@code Long::sum} and checks that exactly one value,
     * 190, is produced.
     *
     * @throws Exception if executing the job or collecting its output fails
     */
    @Test
    public void testReduce() throws Exception {
        // IteratorUtils.toList returns a raw List; give it its element type so the stream
        // below is typed.
        @SuppressWarnings("unchecked")
        List<Long> reduced =
                IteratorUtils.toList(
                        DataStreamUtils.reduce(
                                        this.env.fromParallelCollection(
                                                new NumberSequenceIterator(0L, 19L), Types.LONG),
                                        Long::sum)
                                .executeAndCollect());

        long[] actual = reduced.stream().mapToLong(Long::longValue).toArray();
        Assert.assertArrayEquals(new long[]{190}, actual);
    }

    /**
     * Aggregates the sequence 0..19 with {@code TestAggregateFunc} and expects the single
     * result {@code "190"}.
     *
     * @throws Exception if executing the job or collecting its output fails
     */
    @Test
    public void testAggregate() throws Exception {
        NumberSequenceIterator numbers = new NumberSequenceIterator(0L, 19L);
        List<?> results =
                IteratorUtils.toList(
                        DataStreamUtils.aggregate(
                                        env.fromParallelCollection(numbers, Types.LONG),
                                        new TestAggregateFunc())
                                .executeAndCollect());

        Assert.assertEquals(1L, results.size());
        Assert.assertEquals("190", results.get(0));
    }

    /**
     * Aggregates the sequence 0..19 with an aggregate function whose initial accumulator is 1,
     * and expects the result to be 190 plus the environment's parallelism. The check is then
     * repeated after raising the parallelism by one.
     *
     * @throws Exception if executing either job or collecting its output fails
     */
    @Test
    public void testAggregateWithNonNeutralInitialAccumulator() throws Exception {
        // First run, at whatever parallelism the environment currently has.
        List<?> firstRun =
                IteratorUtils.toList(
                        DataStreamUtils.aggregate(
                                        env.fromParallelCollection(
                                                new NumberSequenceIterator(0L, 19L), Types.LONG),
                                        new TestAggregateFuncWithNonNeutralInitialAccumulator())
                                .executeAndCollect());
        Assert.assertEquals(1L, firstRun.size());
        Assert.assertEquals(String.valueOf(190 + env.getParallelism()), firstRun.get(0));

        // Second run, one parallel instance more; the expected total moves with it.
        env.setParallelism(env.getParallelism() + 1);
        List<?> secondRun =
                IteratorUtils.toList(
                        DataStreamUtils.aggregate(
                                        env.fromParallelCollection(
                                                new NumberSequenceIterator(0L, 19L), Types.LONG),
                                        new TestAggregateFuncWithNonNeutralInitialAccumulator())
                                .executeAndCollect());
        Assert.assertEquals(1L, secondRun.size());
        Assert.assertEquals(String.valueOf(190 + env.getParallelism()), secondRun.get(0));
    }

    /**
     * Samples 10 elements (seed argument 0) from sequences starting at 0 with various upper
     * bounds, and expects {@code min(10, upperBound + 1)} elements back each time.
     *
     * @throws Exception if executing any job or collecting its output fails
     */
    @Test
    public void testSample() throws Exception {
        final int numSamples = 10;
        int[] upperBounds = {0, 5, 9, 10, 11, 20, 30, 40, 200};
        for (int idx = 0; idx < upperBounds.length; idx++) {
            int upperBound = upperBounds[idx];
            int expectedSize = Math.min(numSamples, upperBound + 1);
            List<?> sampled =
                    IteratorUtils.toList(
                            DataStreamUtils.sample(
                                            env.fromParallelCollection(
                                                    new NumberSequenceIterator(0L, upperBound),
                                                    Types.LONG),
                                            numSamples,
                                            0L)
                                    .executeAndCollect());
            Assert.assertEquals(expectedSize, sampled.size());
        }
    }

    /**
     * Calls {@code DataStreamUtils.generateBatchData} on the sequence 0..19 with arguments
     * {@code (2, 4)} and checks that every emitted array has length 2 and that 10 arrays are
     * emitted in total.
     *
     * @throws Exception if executing the job or collecting its output fails
     */
    @Test
    public void testGenerateBatchData() throws Exception {
        // Keep the collected list in a variable: both the per-batch check and the final
        // size check need it.
        @SuppressWarnings("unchecked")
        List<Long[]> batches =
                IteratorUtils.toList(
                        DataStreamUtils.generateBatchData(
                                        this.env.fromParallelCollection(
                                                new NumberSequenceIterator(0L, 19L), Types.LONG),
                                        2,
                                        4)
                                .executeAndCollect());

        for (Long[] batch : batches) {
            Assert.assertEquals(2L, batch.length);
        }
        Assert.assertEquals(10L, batches.size());
    }

    /*
     * Synthetic lambda-deserialization hook, reproduced here by a decompiler.
     *
     * It first maps the implementation-method name carried by the SerializedLambda to a
     * selector (via hashCode() followed by an equals() check), then verifies the remaining
     * SerializedLambda fields for that selector and returns the matching lambda. Any
     * combination that does not match falls through to the IllegalArgumentException at the end.
     *
     * NOTE(review): this text is not valid Java as written — the selector `z` is declared
     * boolean but is assigned -1, 2, 3 and 4, and the second switch repeats `case true`.
     * Presumably `z` was an int selector with values 0..4 before decompilation — confirm.
     * NOTE(review): methods of this name are normally emitted by the compiler rather than
     * written by hand; consider deleting this method from the source instead of repairing it —
     * TODO confirm.
     */
    private static /* synthetic */ Object $deserializeLambda$(SerializedLambda serializedLambda) {
        String implMethodName = serializedLambda.getImplMethodName();
        // NOTE(review): selector for the second switch; -1 means "no known lambda matched".
        boolean z = -1;
        switch (implMethodName.hashCode()) {
            case 114251:
                if (implMethodName.equals("sum")) {
                    z = 4;
                    break;
                }
                break;
            case 316907539:
                if (implMethodName.equals("lambda$testCoGroupWithSingleParallelism$9747a172$1")) {
                    z = false;
                    break;
                }
                break;
            case 316907540:
                if (implMethodName.equals("lambda$testCoGroupWithSingleParallelism$9747a172$2")) {
                    z = 3;
                    break;
                }
                break;
            case 1814096560:
                if (implMethodName.equals("lambda$testCoGroupWithMultiParallelism$9747a172$1")) {
                    z = true;
                    break;
                }
                break;
            case 1814096561:
                if (implMethodName.equals("lambda$testCoGroupWithMultiParallelism$9747a172$2")) {
                    z = 2;
                    break;
                }
                break;
        }
        // NOTE(review): the case labels below presumably stand for selectors 0, 1, 2, 3, 4 in
        // that order (matching the assignments above) — confirm.
        switch (z) {
            case false:
                if (serializedLambda.getImplMethodKind() == 6 && serializedLambda.getFunctionalInterfaceClass().equals("org/apache/flink/api/java/functions/KeySelector") && serializedLambda.getFunctionalInterfaceMethodName().equals("getKey") && serializedLambda.getFunctionalInterfaceMethodSignature().equals("(Ljava/lang/Object;)Ljava/lang/Object;") && serializedLambda.getImplClass().equals("org/apache/flink/ml/common/datastream/DataStreamUtilsTest") && serializedLambda.getImplMethodSignature().equals("(Lorg/apache/flink/api/java/tuple/Tuple2;)Ljava/lang/Integer;")) {
                    return tuple2 -> {
                        return (Integer) tuple2.f0;
                    };
                }
                break;
            case true:
                if (serializedLambda.getImplMethodKind() == 6 && serializedLambda.getFunctionalInterfaceClass().equals("org/apache/flink/api/java/functions/KeySelector") && serializedLambda.getFunctionalInterfaceMethodName().equals("getKey") && serializedLambda.getFunctionalInterfaceMethodSignature().equals("(Ljava/lang/Object;)Ljava/lang/Object;") && serializedLambda.getImplClass().equals("org/apache/flink/ml/common/datastream/DataStreamUtilsTest") && serializedLambda.getImplMethodSignature().equals("(Ljava/lang/Long;)Ljava/lang/Long;")) {
                    return l -> {
                        return Long.valueOf(l.longValue() / 2);
                    };
                }
                break;
            case true:
                if (serializedLambda.getImplMethodKind() == 6 && serializedLambda.getFunctionalInterfaceClass().equals("org/apache/flink/api/java/functions/KeySelector") && serializedLambda.getFunctionalInterfaceMethodName().equals("getKey") && serializedLambda.getFunctionalInterfaceMethodSignature().equals("(Ljava/lang/Object;)Ljava/lang/Object;") && serializedLambda.getImplClass().equals("org/apache/flink/ml/common/datastream/DataStreamUtilsTest") && serializedLambda.getImplMethodSignature().equals("(Ljava/lang/Long;)Ljava/lang/Long;")) {
                    return l2 -> {
                        return Long.valueOf(l2.longValue() / 2);
                    };
                }
                break;
            case true:
                if (serializedLambda.getImplMethodKind() == 6 && serializedLambda.getFunctionalInterfaceClass().equals("org/apache/flink/api/java/functions/KeySelector") && serializedLambda.getFunctionalInterfaceMethodName().equals("getKey") && serializedLambda.getFunctionalInterfaceMethodSignature().equals("(Ljava/lang/Object;)Ljava/lang/Object;") && serializedLambda.getImplClass().equals("org/apache/flink/ml/common/datastream/DataStreamUtilsTest") && serializedLambda.getImplMethodSignature().equals("(Lorg/apache/flink/api/java/tuple/Tuple2;)Ljava/lang/Integer;")) {
                    return tuple22 -> {
                        return (Integer) tuple22.f0;
                    };
                }
                break;
            case true:
                if (serializedLambda.getImplMethodKind() == 6 && serializedLambda.getFunctionalInterfaceClass().equals("org/apache/flink/api/common/functions/ReduceFunction") && serializedLambda.getFunctionalInterfaceMethodName().equals("reduce") && serializedLambda.getFunctionalInterfaceMethodSignature().equals("(Ljava/lang/Object;Ljava/lang/Object;)Ljava/lang/Object;") && serializedLambda.getImplClass().equals("java/lang/Long") && serializedLambda.getImplMethodSignature().equals("(JJ)J")) {
                    return (v0, v1) -> {
                        return Long.sum(v0, v1);
                    };
                }
                break;
        }
        throw new IllegalArgumentException("Invalid lambda deserialization");
    }
}
