/*
 * Decompiled with CFR 0.152.
 */
package org.apache.flink.table.runtime.arrow.sources;

import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.OutputStream;
import java.nio.channels.Channels;
import java.nio.channels.ReadableByteChannel;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.List;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.Predicate;
import org.apache.arrow.vector.VectorSchemaRoot;
import org.apache.arrow.vector.ipc.ArrowStreamWriter;
import org.apache.flink.api.common.functions.RuntimeContext;
import org.apache.flink.api.common.typeutils.TypeSerializer;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.core.testutils.MultiShotLatch;
import org.apache.flink.core.testutils.OneShotLatch;
import org.apache.flink.runtime.checkpoint.OperatorSubtaskState;
import org.apache.flink.shaded.guava31.com.google.common.collect.Lists;
import org.apache.flink.streaming.api.functions.source.SourceFunction;
import org.apache.flink.streaming.api.operators.StreamOperator;
import org.apache.flink.streaming.api.operators.StreamSource;
import org.apache.flink.streaming.api.watermark.Watermark;
import org.apache.flink.streaming.util.AbstractStreamOperatorTestHarness;
import org.apache.flink.streaming.util.MockStreamingRuntimeContext;
import org.apache.flink.table.data.RowData;
import org.apache.flink.table.runtime.arrow.ArrowUtils;
import org.apache.flink.table.runtime.arrow.ArrowWriter;
import org.apache.flink.table.runtime.arrow.sources.ArrowSourceFunction;
import org.apache.flink.testutils.CustomEqualityMatcher;
import org.apache.flink.testutils.DeeplyEqualsChecker;
import org.apache.flink.util.Preconditions;
import org.assertj.core.api.Assertions;
import org.junit.jupiter.api.Test;

/**
 * Base class for tests of {@link ArrowSourceFunction}.
 *
 * <p>Subclasses provide the following:
 *
 * <ul>
 *   <li>the Arrow {@link VectorSchemaRoot} that the test data is encoded through;
 *   <li>a {@link TypeSerializer} used to take defensive copies of emitted rows;
 *   <li>a {@link Comparator} that brings actual and expected rows into a common order;
 *   <li>optionally, a {@link DeeplyEqualsChecker} for row comparison;
 *   <li>the factory methods {@link #getTestData()}, {@link #createArrowWriter()} and {@link
 *       #createArrowSourceFunction(byte[][])}.
 * </ul>
 */
abstract class ArrowSourceFunctionTestBase {

    /** Message of the exception injected by {@link #testRestore()} to abort the first run. */
    private static final String INJECTED_FAILURE_MESSAGE = "Fail the arrow source";

    /** Number of rows the first run in {@link #testRestore()} emits before it is failed. */
    private static final int ROWS_BEFORE_FAILURE = 2;

    /** Root that {@link #createTestArrowSourceFunction} writes the Arrow record batches from. */
    final VectorSchemaRoot root;

    private final TypeSerializer<RowData> typeSerializer;
    private final Comparator<RowData> comparator;
    private final DeeplyEqualsChecker checker;

    ArrowSourceFunctionTestBase(
            VectorSchemaRoot root,
            TypeSerializer<RowData> typeSerializer,
            Comparator<RowData> comparator) {
        this(root, typeSerializer, comparator, new DeeplyEqualsChecker());
    }

    ArrowSourceFunctionTestBase(
            VectorSchemaRoot root,
            TypeSerializer<RowData> typeSerializer,
            Comparator<RowData> comparator,
            DeeplyEqualsChecker checker) {
        this.root = Preconditions.checkNotNull(root);
        this.typeSerializer = Preconditions.checkNotNull(typeSerializer);
        this.comparator = Preconditions.checkNotNull(comparator);
        this.checker = Preconditions.checkNotNull(checker);
    }

    /**
     * Fails the source after {@link #ROWS_BEFORE_FAILURE} rows and snapshots its state.
     *
     * <p>It then restores a fresh source from that snapshot. Finally, it verifies that both runs
     * together emitted every test row exactly once.
     */
    @Test
    void testRestore() throws Exception {
        final Tuple2<List<RowData>, Integer> testData = getTestData();
        final List<RowData> expected = testData.f0;
        final Throwable[] error = new Throwable[1];
        final AtomicInteger numOfEmittedElements = new AtomicInteger(0);
        // Written by one runner thread at a time; each runner is joined before the next starts.
        final List<RowData> results = new ArrayList<>();

        // Phase 1: emit a few rows, then fail the source from inside collect() and snapshot.
        final OperatorSubtaskState snapshot;
        final ArrowSourceFunction arrowSourceFunction =
                createTestArrowSourceFunction(expected, testData.f1);
        final AbstractStreamOperatorTestHarness<RowData> testHarness =
                new AbstractStreamOperatorTestHarness<>(
                        new StreamSource<>(arrowSourceFunction), 1, 1, 0);
        try {
            testHarness.open();
            final OneShotLatch stopped = new OneShotLatch();
            final DummySourceContext<RowData> sourceContext =
                    new DummySourceContext<RowData>() {
                        @Override
                        public void collect(RowData element) {
                            if (numOfEmittedElements.get() == ROWS_BEFORE_FAILURE) {
                                stopped.trigger();
                                throw new RuntimeException(INJECTED_FAILURE_MESSAGE);
                            }
                            results.add(typeSerializer.copy(element));
                            numOfEmittedElements.incrementAndGet();
                        }
                    };
            final Thread runner =
                    new Thread(
                            () -> {
                                try {
                                    arrowSourceFunction.run(sourceContext);
                                } catch (Throwable t) {
                                    // Only the injected failure is expected. Anything else,
                                    // including exceptions without a message, is a real error.
                                    if (!INJECTED_FAILURE_MESSAGE.equals(t.getMessage())) {
                                        error[0] = t;
                                    }
                                } finally {
                                    // Release the main thread even if the source finished or
                                    // failed before reaching the injected failure.
                                    stopped.trigger();
                                }
                            });
            runner.start();
            stopped.await();
            // Snapshot under the source context's checkpoint lock, as a real checkpoint would.
            synchronized (sourceContext.getCheckpointLock()) {
                snapshot = testHarness.snapshot(0L, 0L);
            }
            runner.join();
        } finally {
            testHarness.close();
        }
        // Fail fast: restoring makes no sense if the first run broke unexpectedly.
        Assertions.assertThat(error[0]).isNull();

        // Phase 2: restore a fresh source from the snapshot and let it emit the remaining rows.
        final ArrowSourceFunction restoredSourceFunction =
                createTestArrowSourceFunction(expected, testData.f1);
        final AbstractStreamOperatorTestHarness<RowData> restoredHarness =
                new AbstractStreamOperatorTestHarness<>(
                        new StreamSource<>(restoredSourceFunction), 1, 1, 0);
        try {
            restoredHarness.initializeState(snapshot);
            restoredHarness.open();
            final Thread runner =
                    startCollectingThread(
                            restoredSourceFunction, results, numOfEmittedElements, error, 0);
            // join() waits for run() to return, successfully or not, so this cannot hang on a
            // failure the way waiting for a row-count latch could.
            runner.join();
        } finally {
            restoredHarness.close();
        }

        Assertions.assertThat(error[0]).isNull();
        Assertions.assertThat(expected).hasSize(numOfEmittedElements.get());
        checkElementsEquals(results, expected);
    }

    /**
     * Runs two subtasks (parallelism 2) of the source concurrently over the same Arrow batches.
     *
     * <p>Verifies that, together, the subtasks emit every test row exactly once.
     */
    @Test
    void testParallelProcessing() throws Exception {
        final Tuple2<List<RowData>, Integer> testData = getTestData();
        final List<RowData> expected = testData.f0;
        final Throwable[] error = new Throwable[2];
        final AtomicInteger numOfEmittedElements = new AtomicInteger(0);
        final List<RowData> results = Collections.synchronizedList(new ArrayList<>());

        // Build both source functions up front. Encoding the test data goes through the shared
        // root, so this keeps the encoding from overlapping with a running source.
        final ArrowSourceFunction arrowSourceFunction =
                createTestArrowSourceFunction(expected, testData.f1);
        final ArrowSourceFunction arrowSourceFunction2 =
                createTestArrowSourceFunction(expected, testData.f1);
        final AbstractStreamOperatorTestHarness<RowData> testHarness =
                new AbstractStreamOperatorTestHarness<>(
                        new StreamSource<>(arrowSourceFunction), 2, 2, 0);
        final AbstractStreamOperatorTestHarness<RowData> testHarness2 =
                new AbstractStreamOperatorTestHarness<>(
                        new StreamSource<>(arrowSourceFunction2), 2, 2, 1);
        try {
            testHarness.open();
            final Thread runner =
                    startCollectingThread(
                            arrowSourceFunction, results, numOfEmittedElements, error, 0);
            testHarness2.open();
            final Thread runner2 =
                    startCollectingThread(
                            arrowSourceFunction2, results, numOfEmittedElements, error, 1);
            // join() returns once each run() has finished, successfully or not. Unlike waiting
            // for the total row count, it cannot hang when one subtask fails early.
            runner.join();
            runner2.join();
        } finally {
            try {
                testHarness.close();
            } finally {
                testHarness2.close();
            }
        }

        Assertions.assertThat(error[0]).isNull();
        Assertions.assertThat(error[1]).isNull();
        Assertions.assertThat(expected).hasSize(numOfEmittedElements.get());
        checkElementsEquals(results, expected);
    }

    /**
     * Returns the test rows (f0) and the number of batches (f1) to split them into.
     *
     * @return the test rows and the batch count
     */
    public abstract Tuple2<List<RowData>, Integer> getTestData();

    /**
     * Creates the writer used to fill the Arrow vectors with test rows.
     *
     * @return a new Arrow writer
     */
    public abstract ArrowWriter<RowData> createArrowWriter();

    /**
     * Creates the source function under test from serialized Arrow record batches.
     *
     * @param var1 the serialized Arrow record batches
     * @return the source function under test
     */
    public abstract ArrowSourceFunction createArrowSourceFunction(byte[][] var1);

    /**
     * Starts a thread that runs {@code sourceFunction} and collects every emitted row.
     *
     * <p>Each emitted row is copied into {@code results` and counted in {@code
     * numOfEmittedElements}. Any throwable from {@code run} is stored in {@code
     * error[errorIndex]}.
     *
     * @param sourceFunction the source function to run
     * @param results the list that receives copies of the emitted rows
     * @param numOfEmittedElements the counter of emitted rows
     * @param error the array that records failures
     * @param errorIndex the slot of {@code error} this runner writes to
     * @return the started thread
     */
    private Thread startCollectingThread(
            ArrowSourceFunction sourceFunction,
            List<RowData> results,
            AtomicInteger numOfEmittedElements,
            Throwable[] error,
            int errorIndex) {
        final Thread thread =
                new Thread(
                        () -> {
                            try {
                                sourceFunction.run(
                                        new DummySourceContext<RowData>() {
                                            @Override
                                            public void collect(RowData element) {
                                                results.add(typeSerializer.copy(element));
                                                numOfEmittedElements.incrementAndGet();
                                            }
                                        });
                            } catch (Throwable t) {
                                error[errorIndex] = t;
                            }
                        });
        thread.start();
        return thread;
    }

    /**
     * Asserts that {@code actual} and {@code expected} contain deeply-equal rows, ignoring order.
     *
     * <p>Both lists are compared after sorting with the configured comparator. Sorted copies are
     * used, so neither argument is modified; in particular, the test data list may be immutable.
     *
     * @param actual the rows that were emitted
     * @param expected the rows that should have been emitted
     */
    private void checkElementsEquals(List<RowData> actual, List<RowData> expected) {
        Assertions.assertThat(actual).hasSize(expected.size());
        final List<RowData> sortedActual = new ArrayList<>(actual);
        final List<RowData> sortedExpected = new ArrayList<>(expected);
        sortedActual.sort(comparator);
        sortedExpected.sort(comparator);
        for (int i = 0; i < sortedExpected.size(); ++i) {
            Assertions.assertThat(sortedActual.get(i))
                    .matches(
                            CustomEqualityMatcher.deeplyEquals(sortedExpected.get(i))
                                    .withChecker(checker));
        }
    }

    /**
     * Encodes {@code testData} as an Arrow IPC stream and builds a source function over it.
     *
     * <p>The rows are split into chunks of {@code testData.size() / batches + 1} rows. Each chunk
     * is written as one Arrow record batch. The returned source function has a mock runtime
     * context (parallelism 1, subtask index 0).
     *
     * @param testData the rows to encode
     * @param batches the divisor used to size the chunks; must be positive
     * @return a source function built by {@link #createArrowSourceFunction(byte[][])}
     * @throws IOException if writing or reading the Arrow stream fails
     * @throws IllegalArgumentException if {@code batches} is not positive
     */
    private ArrowSourceFunction createTestArrowSourceFunction(List<RowData> testData, int batches)
            throws IOException {
        Preconditions.checkArgument(
                batches > 0, "batches must be positive, but was %s", batches);
        final ArrowWriter<RowData> arrowWriter = createArrowWriter();
        final ByteArrayOutputStream baos = new ByteArrayOutputStream();
        final ArrowStreamWriter arrowStreamWriter = new ArrowStreamWriter(root, null, baos);
        arrowStreamWriter.start();
        final List<List<RowData>> subLists =
                Lists.partition(testData, testData.size() / batches + 1);
        for (List<RowData> subList : subLists) {
            for (RowData value : subList) {
                arrowWriter.write(value);
            }
            arrowWriter.finish();
            arrowStreamWriter.writeBatch();
            arrowWriter.reset();
        }
        final ArrowSourceFunction arrowSourceFunction =
                createArrowSourceFunction(
                        ArrowUtils.readArrowBatches(
                                Channels.newChannel(
                                        new ByteArrayInputStream(baos.toByteArray()))));
        arrowSourceFunction.setRuntimeContext(new MockStreamingRuntimeContext(false, 1, 0));
        return arrowSourceFunction;
    }

    /**
     * Minimal {@link SourceFunction.SourceContext} for driving a source function directly.
     *
     * <p>Subclasses implement {@link #collect}. Timestamped elements, watermarks and idleness
     * signals are silently ignored. The checkpoint lock is a private per-instance object.
     *
     * @param <T> the element type
     */
    private abstract static class DummySourceContext<T>
            implements SourceFunction.SourceContext<T> {
        private final Object lock = new Object();

        private DummySourceContext() {}

        @Override
        public void collectWithTimestamp(T element, long timestamp) {}

        @Override
        public void emitWatermark(Watermark mark) {}

        @Override
        public void markAsTemporarilyIdle() {}

        @Override
        public Object getCheckpointLock() {
            return lock;
        }

        @Override
        public void close() {}
    }
}

