package org.apache.flink.runtime.operators;

import java.util.ArrayList;
import java.util.Iterator;
import java.util.concurrent.atomic.AtomicBoolean;
import org.apache.flink.api.common.ExecutionConfig;
import org.apache.flink.api.common.functions.RichGroupReduceFunction;
import org.apache.flink.api.common.typeutils.record.RecordComparator;
import org.apache.flink.api.java.record.operators.ReduceOperator;
import org.apache.flink.runtime.operators.testutils.DelayingInfinitiveInputIterator;
import org.apache.flink.runtime.operators.testutils.DiscardingOutputCollector;
import org.apache.flink.runtime.operators.testutils.DriverTestBase;
import org.apache.flink.runtime.operators.testutils.ExpectedTestException;
import org.apache.flink.runtime.operators.testutils.TaskCancelThread;
import org.apache.flink.runtime.operators.testutils.UniformRecordGenerator;
import org.apache.flink.types.IntValue;
import org.apache.flink.types.Record;
import org.apache.flink.util.Collector;
import org.junit.Assert;
import org.junit.Test;

/* loaded from: input_file:org/apache/flink/runtime/operators/CombineTaskTest.class */
/**
 * Tests {@link GroupReduceCombineDriver} with the {@link DriverStrategy#SORTED_GROUP_COMBINE}
 * strategy. The three cases are: a successful combine, a combine whose user function throws,
 * and cancellation of a combine over an input that never ends.
 */
public class CombineTaskTest extends DriverTestBase<RichGroupReduceFunction<Record, ?>> {

    /** Memory (in bytes) handed to the driver test base: 3 MiB. */
    private static final long COMBINE_MEM = 3L * 1024 * 1024;

    /** Number of keys passed to the input generator; the test expects one output record per key. */
    private static final int NUM_KEYS = 100;

    /** Number of values per key passed to the input generator. */
    private static final int NUM_VALUES = 20;

    /** Fraction of the memory manager's total memory that {@link #COMBINE_MEM} represents. */
    private final double combineFrac;

    /** Collects the driver output in {@link #testCombineTask()}. */
    private final ArrayList<Record> outList = new ArrayList<>();

    /** Comparator on field 0, interpreted as an {@link IntValue}. */
    private final RecordComparator comparator =
            new RecordComparator(new int[] {0}, new Class[] {IntValue.class});

    /**
     * Group-reduce function that sums field 1 of every record in a group. It writes the sum back
     * into field 1 of the group's last record and emits that record. {@link #combine} delegates to
     * {@link #reduce}, so combining and reducing behave the same way.
     */
    @ReduceOperator.Combinable
    public static class MockCombiningReduceStub extends RichGroupReduceFunction<Record, Record> {
        private static final long serialVersionUID = 1;

        /** Reused holder for field 1, so no object is allocated per record. */
        private final IntValue theInteger = new IntValue();

        /**
         * Sums field 1 over the group and emits the last record carrying that sum.
         *
         * @param records the records of one group; assumed non-empty (an empty group would make
         *     {@code last} null and fail with a NullPointerException)
         * @param out collector receiving exactly one record per call
         */
        @Override
        public void reduce(Iterable<Record> records, Collector<Record> out) {
            Record last = null;
            int sum = 0;
            for (Record next : records) {
                last = next;
                // Deserialize field 1 into the reusable holder.
                last.getField(1, this.theInteger);
                sum += this.theInteger.getValue();
            }
            this.theInteger.setValue(sum);
            last.setField(1, this.theInteger);
            out.collect(last);
        }

        /** Combines by running the same summation as {@link #reduce}. */
        @Override
        public void combine(Iterable<Record> records, Collector<Record> out) throws Exception {
            reduce(records, out);
        }
    }

    /**
     * Group-reduce function whose {@link #combine} throws {@link ExpectedTestException} on its
     * tenth invocation. The earlier calls sum field 1 of the group, the same way
     * {@link MockCombiningReduceStub} does.
     */
    @ReduceOperator.Combinable
    public static final class MockFailingCombiningReduceStub
            extends RichGroupReduceFunction<Record, Record> {
        private static final long serialVersionUID = 1;

        /** Number of {@link #combine} invocations so far. */
        private int cnt = 0;

        private final IntValue key = new IntValue();
        private final IntValue value = new IntValue();
        private final IntValue combineValue = new IntValue();

        /**
         * Sums field 1 over the group, subtracts the key (field 0 of the last record) and emits the
         * last record with that result in field 1.
         */
        @Override
        public void reduce(Iterable<Record> records, Collector<Record> out) {
            Record last = null;
            int sum = 0;
            for (Record next : records) {
                last = next;
                last.getField(1, this.value);
                sum += this.value.getValue();
            }
            last.getField(0, this.key);
            this.value.setValue(sum - this.key.getValue());
            last.setField(1, this.value);
            out.collect(last);
        }

        /**
         * Sums field 1 over the group and emits the last record with the sum.
         *
         * @throws ExpectedTestException on the tenth and every later invocation; the exception is
         *     raised after the group has been consumed but before anything is emitted
         */
        @Override
        public void combine(Iterable<Record> records, Collector<Record> out) {
            Record last = null;
            int sum = 0;
            for (Record next : records) {
                last = next;
                last.getField(1, this.combineValue);
                sum += this.combineValue.getValue();
            }
            if (++this.cnt >= 10) {
                throw new ExpectedTestException();
            }
            this.combineValue.setValue(sum);
            last.setField(1, this.combineValue);
            out.collect(last);
        }
    }

    /**
     * Creates the test with {@link #COMBINE_MEM} bytes of managed memory.
     *
     * @param executionConfig execution config forwarded to {@link DriverTestBase}
     */
    public CombineTaskTest(ExecutionConfig executionConfig) {
        super(executionConfig, COMBINE_MEM, 0);
        this.combineFrac = (double) COMBINE_MEM / getMemoryManager().getMemorySize();
    }

    /**
     * Configures the task for a sorted group-combine that uses the full memory fraction and two
     * file handles.
     */
    private void configureSortedCombine() {
        addDriverComparator(this.comparator);
        addDriverComparator(this.comparator);
        getTaskConfig().setDriverStrategy(DriverStrategy.SORTED_GROUP_COMBINE);
        getTaskConfig().setRelativeMemoryDriver(this.combineFrac);
        getTaskConfig().setFilehandlesDriver(2);
    }

    @Test
    public void testCombineTask() {
        addInput(new UniformRecordGenerator(NUM_KEYS, NUM_VALUES, false));
        setOutput(this.outList);
        configureSortedCombine();

        try {
            testDriver(new GroupReduceCombineDriver(), MockCombiningReduceStub.class);
        } catch (Exception e) {
            e.printStackTrace();
            Assert.fail("Invoke method caused exception.");
        }

        // Expected per-key sum is 1 + 2 + ... + (NUM_VALUES - 1), i.e. 190 for NUM_VALUES = 20.
        final int expectedSum = (NUM_VALUES - 1) * NUM_VALUES / 2;

        Assert.assertEquals(
                "Resultset size was " + this.outList.size() + ". Expected was " + NUM_KEYS,
                NUM_KEYS,
                this.outList.size());
        for (Record record : this.outList) {
            Assert.assertEquals(
                    "Incorrect result", expectedSum, record.getField(1, IntValue.class).getValue());
        }
        this.outList.clear();
    }

    @Test
    public void testFailingCombineTask() {
        addInput(new UniformRecordGenerator(NUM_KEYS, NUM_VALUES, false));
        setOutput(new DiscardingOutputCollector());
        configureSortedCombine();

        try {
            testDriver(new GroupReduceCombineDriver(), MockFailingCombiningReduceStub.class);
            Assert.fail("Exception not forwarded.");
        } catch (ExpectedTestException ignored) {
            // Expected: the failing combine function's exception must reach the caller.
        } catch (Exception e) {
            e.printStackTrace();
            Assert.fail("Test failed due to an exception.");
        }
    }

    @Test
    public void testCancelCombineTaskSorting() {
        addInput(new DelayingInfinitiveInputIterator(100));
        setOutput(new DiscardingOutputCollector());
        configureSortedCombine();

        final GroupReduceCombineDriver driver = new GroupReduceCombineDriver();
        // Set only if testDriver returns normally, which is what a proper cancellation should cause.
        final AtomicBoolean success = new AtomicBoolean(false);

        Thread taskRunner = new Thread() {
            @Override
            public void run() {
                try {
                    testDriver(driver, MockFailingCombiningReduceStub.class);
                    success.set(true);
                } catch (Exception e) {
                    e.printStackTrace();
                }
            }
        };
        taskRunner.start();

        // NOTE(review): the first argument is presumably the cancel delay in seconds — confirm in TaskCancelThread.
        TaskCancelThread canceller = new TaskCancelThread(1, taskRunner, this);
        canceller.start();

        try {
            canceller.join();
            taskRunner.join();
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            Assert.fail("Joining threads failed");
        }

        Assert.assertTrue("Exception was thrown despite proper canceling.", success.get());
    }
}
