/*
 * Decompiled with CFR 0.152.
 */
package org.apache.flink.table.runtime.arrow.sources;

import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.OutputStream;
import java.nio.channels.Channels;
import java.nio.channels.ReadableByteChannel;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.List;
import java.util.concurrent.atomic.AtomicInteger;
import org.apache.arrow.vector.VectorSchemaRoot;
import org.apache.arrow.vector.ipc.ArrowStreamWriter;
import org.apache.flink.api.common.functions.RuntimeContext;
import org.apache.flink.api.common.typeutils.TypeSerializer;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.core.testutils.MultiShotLatch;
import org.apache.flink.core.testutils.OneShotLatch;
import org.apache.flink.runtime.checkpoint.OperatorSubtaskState;
import org.apache.flink.shaded.guava18.com.google.common.collect.Lists;
import org.apache.flink.streaming.api.functions.source.SourceFunction;
import org.apache.flink.streaming.api.operators.StreamOperator;
import org.apache.flink.streaming.api.operators.StreamSource;
import org.apache.flink.streaming.api.watermark.Watermark;
import org.apache.flink.streaming.util.AbstractStreamOperatorTestHarness;
import org.apache.flink.table.runtime.arrow.ArrowUtils;
import org.apache.flink.table.runtime.arrow.ArrowWriter;
import org.apache.flink.table.runtime.arrow.sources.AbstractArrowSourceFunction;
import org.apache.flink.testutils.CustomEqualityMatcher;
import org.apache.flink.testutils.DeeplyEqualsChecker;
import org.apache.flink.util.Preconditions;
import org.hamcrest.Matcher;
import org.junit.Assert;
import org.junit.Test;
import org.mockito.Mockito;

/**
 * Abstract base class for tests of {@link AbstractArrowSourceFunction} implementations.
 *
 * <p>Subclasses supply the test data, an {@link ArrowWriter} used to serialize that data into
 * Arrow batches, and a factory for the source function under test. The base class then checks
 * that the source emits every element, both after a snapshot/restore cycle and when two
 * subtasks run side by side.
 *
 * @param <T> the type of the elements emitted by the source function under test
 */
public abstract class ArrowSourceFunctionTestBase<T> {
    /** Message of the failure that {@link #testRestore()} injects into the first source run. */
    private static final String INJECTED_FAILURE_MESSAGE = "Fail the arrow source";

    /** Root handed to the {@link ArrowStreamWriter} that serializes the test data. */
    final VectorSchemaRoot root;
    /** Used to copy every collected element before it is stored in the result list. */
    private final TypeSerializer<T> typeSerializer;
    /** Ordering applied to expected and actual elements before they are compared pairwise. */
    private final Comparator<T> comparator;
    /** Equality checker used by the deep-equality matcher in {@link #checkElementsEquals}. */
    private final DeeplyEqualsChecker checker;

    /**
     * Creates a test base that compares elements with a default {@link DeeplyEqualsChecker}.
     *
     * @param root the root the Arrow stream writer reads batches from; must not be null
     * @param typeSerializer serializer used to copy collected elements; must not be null
     * @param comparator ordering used to sort elements before comparison; must not be null
     */
    ArrowSourceFunctionTestBase(VectorSchemaRoot root, TypeSerializer<T> typeSerializer, Comparator<T> comparator) {
        this(root, typeSerializer, comparator, new DeeplyEqualsChecker());
    }

    /**
     * Creates a test base with an explicit equality checker.
     *
     * @param root the root the Arrow stream writer reads batches from; must not be null
     * @param typeSerializer serializer used to copy collected elements; must not be null
     * @param comparator ordering used to sort elements before comparison; must not be null
     * @param checker equality checker used when comparing elements; must not be null
     */
    ArrowSourceFunctionTestBase(VectorSchemaRoot root, TypeSerializer<T> typeSerializer, Comparator<T> comparator, DeeplyEqualsChecker checker) {
        this.root = Preconditions.checkNotNull(root);
        this.typeSerializer = Preconditions.checkNotNull(typeSerializer);
        this.comparator = Preconditions.checkNotNull(comparator);
        this.checker = Preconditions.checkNotNull(checker);
    }

    /**
     * Runs the source until it is failed artificially, snapshots it, restores a fresh source
     * from that snapshot and verifies that both runs together emitted exactly the test data.
     *
     * <p>The first run is failed on the {@code collect} call that arrives after two elements
     * have been recorded, so the test data is expected to contain at least three elements.
     */
    @Test
    public void testRestore() throws Exception {
        final Tuple2<List<T>, Integer> testData = getTestData();
        final List<T> expected = testData.f0;

        final AbstractArrowSourceFunction<T> arrowSourceFunction =
                createTestArrowSourceFunction(expected, testData.f1);
        final AbstractStreamOperatorTestHarness<T> testHarness =
                new AbstractStreamOperatorTestHarness<>(new StreamSource<>(arrowSourceFunction), 1, 1, 0);
        testHarness.open();

        final Throwable[] error = new Throwable[1];
        final MultiShotLatch latch = new MultiShotLatch();
        final AtomicInteger numOfEmittedElements = new AtomicInteger(0);
        final List<T> results = new ArrayList<>();

        final DummySourceContext<T> sourceContext = new DummySourceContext<T>() {
            @Override
            public void collect(T element) {
                if (numOfEmittedElements.get() == 2) {
                    latch.trigger();
                    // Fail the source once two elements have been recorded.
                    throw new RuntimeException(INJECTED_FAILURE_MESSAGE);
                }
                results.add(typeSerializer.copy(element));
                numOfEmittedElements.incrementAndGet();
            }
        };

        // Run the source asynchronously.
        final Thread runner = new Thread(() -> {
            try {
                arrowSourceFunction.run(sourceContext);
            } catch (Throwable t) {
                // Constant-first comparison: a throwable without a message must be recorded
                // as an error instead of causing a NullPointerException in this thread.
                if (!INJECTED_FAILURE_MESSAGE.equals(t.getMessage())) {
                    error[0] = t;
                }
            } finally {
                // Always release the waiting test thread, so an unexpected failure (or a run
                // that never reaches the injected failure) cannot block the test forever.
                latch.trigger();
            }
        });
        runner.start();

        if (!latch.isTriggered()) {
            latch.await();
        }

        final OperatorSubtaskState snapshot;
        synchronized (sourceContext.getCheckpointLock()) {
            snapshot = testHarness.snapshot(0L, 0L);
        }

        runner.join();
        testHarness.close();

        final AbstractArrowSourceFunction<T> arrowSourceFunction2 =
                createTestArrowSourceFunction(expected, testData.f1);
        final AbstractStreamOperatorTestHarness<T> testHarnessCopy =
                new AbstractStreamOperatorTestHarness<>(new StreamSource<>(arrowSourceFunction2), 1, 1, 0);
        testHarnessCopy.initializeState(snapshot);
        testHarnessCopy.open();

        final DummySourceContext<T> restoredContext =
                newCollectingContext(results, numOfEmittedElements, expected.size(), latch::trigger);

        // Run the restored source asynchronously.
        final Thread runner2 = new Thread(() -> {
            try {
                arrowSourceFunction2.run(restoredContext);
            } catch (Throwable t) {
                error[0] = t;
            } finally {
                // See above: never leave the test thread waiting on a terminated runner.
                latch.trigger();
            }
        });
        runner2.start();

        if (!latch.isTriggered()) {
            latch.await();
        }
        runner2.join();
        // Close the restored harness as well, mirroring the handling of the first one.
        testHarnessCopy.close();

        Assert.assertNull(error[0]);
        Assert.assertEquals(expected.size(), numOfEmittedElements.get());
        checkElementsEquals(results, expected);
    }

    /**
     * Runs two source instances as subtasks 0 and 1 of a parallelism-2 operator and verifies
     * that together they emitted exactly the test data.
     */
    @Test
    public void testParallelProcessing() throws Exception {
        final Tuple2<List<T>, Integer> testData = getTestData();
        final List<T> expected = testData.f0;

        final AbstractArrowSourceFunction<T> arrowSourceFunction =
                createTestArrowSourceFunction(expected, testData.f1);
        final AbstractStreamOperatorTestHarness<T> testHarness =
                new AbstractStreamOperatorTestHarness<>(new StreamSource<>(arrowSourceFunction), 2, 2, 0);
        testHarness.open();

        final Throwable[] error = new Throwable[2];
        final OneShotLatch latch = new OneShotLatch();
        final AtomicInteger numOfEmittedElements = new AtomicInteger(0);
        // Both runner threads append to this list concurrently.
        final List<T> results = Collections.synchronizedList(new ArrayList<>());

        final DummySourceContext<T> context1 =
                newCollectingContext(results, numOfEmittedElements, expected.size(), latch::trigger);

        // Run the first subtask asynchronously.
        final Thread runner = new Thread(() -> {
            try {
                arrowSourceFunction.run(context1);
            } catch (Throwable t) {
                error[0] = t;
            } finally {
                // Release the test thread even if the expected element count is never reached;
                // the joins below still wait for both subtasks to terminate.
                latch.trigger();
            }
        });
        runner.start();

        final AbstractArrowSourceFunction<T> arrowSourceFunction2 =
                createTestArrowSourceFunction(expected, testData.f1);
        final AbstractStreamOperatorTestHarness<T> testHarness2 =
                new AbstractStreamOperatorTestHarness<>(new StreamSource<>(arrowSourceFunction2), 2, 2, 1);
        testHarness2.open();

        final DummySourceContext<T> context2 =
                newCollectingContext(results, numOfEmittedElements, expected.size(), latch::trigger);

        // Run the second subtask asynchronously.
        final Thread runner2 = new Thread(() -> {
            try {
                arrowSourceFunction2.run(context2);
            } catch (Throwable t) {
                error[1] = t;
            } finally {
                latch.trigger();
            }
        });
        runner2.start();

        if (!latch.isTriggered()) {
            latch.await();
        }
        runner.join();
        runner2.join();
        testHarness.close();
        testHarness2.close();

        Assert.assertNull(error[0]);
        Assert.assertNull(error[1]);
        Assert.assertEquals(expected.size(), numOfEmittedElements.get());
        checkElementsEquals(results, expected);
    }

    /**
     * Returns the test data.
     *
     * @return a tuple whose first field holds the elements the source is expected to emit and
     *     whose second field is the batch count used to split those elements into Arrow batches
     */
    public abstract Tuple2<List<T>, Integer> getTestData();

    /**
     * Creates the writer used to serialize the test elements.
     *
     * @return a writer for elements of type {@code T}
     */
    public abstract ArrowWriter<T> createArrowWriter();

    /**
     * Creates the source function under test.
     *
     * @param arrowData the serialized Arrow batches, as produced by
     *     {@code ArrowUtils.readArrowBatches}
     * @return the source function under test
     */
    public abstract AbstractArrowSourceFunction<T> createArrowSourceFunction(byte[][] arrowData);

    /**
     * Asserts that both lists hold deeply-equal elements, ignoring order.
     *
     * <p>Sorted copies are compared, so neither argument is modified.
     *
     * @param actual the elements that were collected
     * @param expected the elements that should have been collected
     */
    private void checkElementsEquals(List<T> actual, List<T> expected) {
        Assert.assertEquals(expected.size(), actual.size());

        final List<T> sortedActual = new ArrayList<>(actual);
        final List<T> sortedExpected = new ArrayList<>(expected);
        sortedActual.sort(comparator);
        sortedExpected.sort(comparator);

        for (int i = 0; i < sortedExpected.size(); i++) {
            Assert.assertThat(
                    sortedActual.get(i),
                    CustomEqualityMatcher.deeplyEquals(sortedExpected.get(i)).withChecker(checker));
        }
    }

    /**
     * Serializes the test data into Arrow batches and builds a source function from them.
     *
     * <p>The data is partitioned into sub-lists of {@code testData.size() / batches + 1}
     * elements; each sub-list is written as one Arrow batch. The returned function has a
     * mocked {@link RuntimeContext} set.
     *
     * @param testData the elements to serialize
     * @param batches the batch count used to derive the partition size; must be positive
     * @return the source function created by {@link #createArrowSourceFunction(byte[][])}
     * @throws IOException if writing or reading the Arrow stream fails
     */
    private AbstractArrowSourceFunction<T> createTestArrowSourceFunction(List<T> testData, int batches) throws IOException {
        Preconditions.checkArgument(batches > 0, "batches must be positive");

        final ArrowWriter<T> arrowWriter = createArrowWriter();
        final ByteArrayOutputStream baos = new ByteArrayOutputStream();
        final ArrowStreamWriter arrowStreamWriter = new ArrowStreamWriter(root, null, baos);
        arrowStreamWriter.start();

        final List<List<T>> subLists = Lists.partition(testData, testData.size() / batches + 1);
        for (List<T> subList : subLists) {
            for (T value : subList) {
                arrowWriter.write(value);
            }
            arrowWriter.finish();
            arrowStreamWriter.writeBatch();
            arrowWriter.reset();
        }

        final ReadableByteChannel channel =
                Channels.newChannel(new ByteArrayInputStream(baos.toByteArray()));
        final AbstractArrowSourceFunction<T> arrowSourceFunction =
                createArrowSourceFunction(ArrowUtils.readArrowBatches(channel));
        arrowSourceFunction.setRuntimeContext(Mockito.mock(RuntimeContext.class));
        return arrowSourceFunction;
    }

    /**
     * Creates a source context that stores a copy of every element and signals completion.
     *
     * @param results the list that receives a copy of each collected element
     * @param emitted counter incremented once per collected element
     * @param expectedTotal the counter value at which {@code onAllEmitted} is run
     * @param onAllEmitted action run when the counter reaches {@code expectedTotal}
     * @return the collecting source context
     */
    private DummySourceContext<T> newCollectingContext(
            final List<T> results,
            final AtomicInteger emitted,
            final int expectedTotal,
            final Runnable onAllEmitted) {
        return new DummySourceContext<T>() {
            @Override
            public void collect(T element) {
                results.add(typeSerializer.copy(element));
                if (emitted.incrementAndGet() == expectedTotal) {
                    onAllEmitted.run();
                }
            }
        };
    }

    /**
     * Minimal {@link SourceFunction.SourceContext} for tests: everything except
     * {@code collect} is a no-op, and a private object serves as the checkpoint lock.
     *
     * @param <T> the type of the collected elements
     */
    private static abstract class DummySourceContext<T>
            implements SourceFunction.SourceContext<T> {
        private final Object lock = new Object();

        private DummySourceContext() {
        }

        @Override
        public void collectWithTimestamp(T element, long timestamp) {
        }

        @Override
        public void emitWatermark(Watermark mark) {
        }

        @Override
        public void markAsTemporarilyIdle() {
        }

        @Override
        public Object getCheckpointLock() {
            return this.lock;
        }

        @Override
        public void close() {
        }
    }
}

