/*
 * Decompiled with CFR 0.152.
 */
package org.apache.flink.runtime.operators.chaining;

import java.util.ArrayList;
import java.util.List;
import org.apache.flink.api.common.functions.ReduceFunction;
import org.apache.flink.api.common.operators.util.UserCodeClassWrapper;
import org.apache.flink.api.common.operators.util.UserCodeWrapper;
import org.apache.flink.api.common.typeutils.TypeComparatorFactory;
import org.apache.flink.api.common.typeutils.TypeSerializerFactory;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.runtime.io.network.api.writer.ResultPartitionWriter;
import org.apache.flink.runtime.jobgraph.tasks.AbstractInvokable;
import org.apache.flink.runtime.operators.BatchTask;
import org.apache.flink.runtime.operators.DriverStrategy;
import org.apache.flink.runtime.operators.FlatMapDriver;
import org.apache.flink.runtime.operators.FlatMapTaskTest;
import org.apache.flink.runtime.operators.chaining.ChainedAllReduceDriver;
import org.apache.flink.runtime.operators.shipping.ShipStrategyType;
import org.apache.flink.runtime.operators.testutils.TaskTestBase;
import org.apache.flink.runtime.operators.testutils.UniformRecordGenerator;
import org.apache.flink.runtime.operators.util.TaskConfig;
import org.apache.flink.runtime.taskmanager.Task;
import org.apache.flink.runtime.testutils.recordutils.RecordComparatorFactory;
import org.apache.flink.runtime.testutils.recordutils.RecordSerializerFactory;
import org.apache.flink.types.IntValue;
import org.apache.flink.types.Record;
import org.apache.flink.types.Value;
import org.junit.Assert;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.powermock.core.classloader.annotations.PrepareForTest;
import org.powermock.modules.junit4.PowerMockRunner;

@RunWith(value=PowerMockRunner.class)
@PrepareForTest(value={Task.class, ResultPartitionWriter.class})
public class ChainedAllReduceDriverTest
extends TaskTestBase {
    private static final int MEMORY_MANAGER_SIZE = 0x300000;
    private static final int NETWORK_BUFFER_SIZE = 1024;
    private final List<Record> outList = new ArrayList<Record>();
    private final RecordComparatorFactory compFact = new RecordComparatorFactory(new int[]{0}, new Class[]{IntValue.class}, new boolean[]{true});
    private final RecordSerializerFactory serFact = RecordSerializerFactory.get();

    @Test
    public void testMapTask() {
        int keyCnt = 100;
        int valCnt = 20;
        double memoryFraction = 1.0;
        try {
            this.initEnvironment(0x300000L, 1024);
            this.mockEnv.getExecutionConfig().enableObjectReuse();
            this.addInput(new UniformRecordGenerator(100, 20, false), 0);
            this.addOutput(this.outList);
            TaskConfig reduceConfig = new TaskConfig(new Configuration());
            reduceConfig.addInputToGroup(0);
            reduceConfig.setInputSerializer((TypeSerializerFactory)this.serFact, 0);
            reduceConfig.addOutputShipStrategy(ShipStrategyType.FORWARD);
            reduceConfig.setOutputSerializer((TypeSerializerFactory)this.serFact);
            reduceConfig.setDriverStrategy(DriverStrategy.ALL_REDUCE);
            reduceConfig.setDriverComparator((TypeComparatorFactory)this.compFact, 0);
            reduceConfig.setDriverComparator((TypeComparatorFactory)this.compFact, 1);
            reduceConfig.setRelativeMemoryDriver(1.0);
            reduceConfig.setStubWrapper((UserCodeWrapper)new UserCodeClassWrapper(MockReduceStub.class));
            this.getTaskConfig().addChainedTask(ChainedAllReduceDriver.class, reduceConfig, "reduce");
            BatchTask testTask = new BatchTask();
            this.registerTask((AbstractInvokable)testTask, FlatMapDriver.class, FlatMapTaskTest.MockMapStub.class);
            try {
                testTask.invoke();
            }
            catch (Exception e) {
                e.printStackTrace();
                Assert.fail((String)"Invoke method caused exception.");
            }
            int sumTotal = 99000;
            Assert.assertEquals((long)1L, (long)this.outList.size());
            Assert.assertEquals((long)sumTotal, (long)((IntValue)this.outList.get(0).getField(0, IntValue.class)).getValue());
        }
        catch (Exception e) {
            e.printStackTrace();
            Assert.fail((String)e.getMessage());
        }
    }

    private static class MockReduceStub
    implements ReduceFunction<Record> {
        private static final long serialVersionUID = 1047525105526690165L;

        private MockReduceStub() {
        }

        public Record reduce(Record value1, Record value2) throws Exception {
            IntValue v1 = (IntValue)value1.getField(0, IntValue.class);
            IntValue v2 = (IntValue)value2.getField(0, IntValue.class);
            v1.setValue(v1.getValue() + v2.getValue());
            value1.setField(0, (Value)v1);
            value1.updateBinaryRepresenation();
            return value1;
        }
    }
}

