package org.apache.flink.table.runtime.runners.python.scalar.arrow;

import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
import org.apache.beam.runners.core.construction.graph.ExecutableStage;
import org.apache.beam.runners.fnexecution.control.BundleProgressHandler;
import org.apache.beam.runners.fnexecution.control.JobBundleFactory;
import org.apache.beam.runners.fnexecution.control.OutputReceiverFactory;
import org.apache.beam.runners.fnexecution.control.RemoteBundle;
import org.apache.beam.runners.fnexecution.control.StageBundleFactory;
import org.apache.beam.runners.fnexecution.state.StateRequestHandler;
import org.apache.beam.sdk.fn.data.FnDataReceiver;
import org.apache.flink.table.functions.python.PythonFunctionInfo;
import org.apache.flink.table.runtime.arrow.ArrowUtils;
import org.apache.flink.table.runtime.arrow.ArrowWriter;
import org.apache.flink.table.runtime.arrow.writers.ArrowFieldWriter;
import org.apache.flink.table.runtime.arrow.writers.RowBigIntWriter;
import org.apache.flink.table.runtime.runners.python.scalar.AbstractPythonScalarFunctionRunnerTest;
import org.apache.flink.table.runtime.utils.PassThroughArrowPythonScalarFunctionRunner;
import org.apache.flink.table.runtime.utils.PythonTestUtils;
import org.apache.flink.table.types.logical.BigIntType;
import org.apache.flink.table.types.logical.RowType;
import org.apache.flink.types.Row;
import org.junit.Assert;
import org.junit.Test;
import org.mockito.ArgumentMatchers;
import org.mockito.Mockito;

/* loaded from: input_file:org/apache/flink/table/runtime/runners/python/scalar/arrow/ArrowPythonScalarFunctionRunnerTest.class */
/**
 * Tests for {@link AbstractArrowPythonScalarFunctionRunner}, driven through
 * {@link PassThroughArrowPythonScalarFunctionRunner} with {@link Row} input.
 *
 * <p>The {@code testArrowWriterConstructedProperly*} tests check the number and type of field
 * writers the runner's Arrow writer holds after {@code open()}. The final test walks one bundle
 * through its lifecycle against mocked Beam bundle factories.
 */
public class ArrowPythonScalarFunctionRunnerTest extends AbstractPythonScalarFunctionRunnerTest<Row> {

    @Test
    public void testArrowWriterConstructedProperlyForSingleUDF() throws Exception {
        openAndAssertBigIntFieldWriters(
            (AbstractArrowPythonScalarFunctionRunner<Row>) createSingleUDFRunner(), 1);
    }

    @Test
    public void testArrowWriterConstructedProperlyForMultipleUDFs() throws Exception {
        openAndAssertBigIntFieldWriters(
            (AbstractArrowPythonScalarFunctionRunner<Row>) createMultipleUDFRunner(), 3);
    }

    @Test
    public void testArrowWriterConstructedProperlyForChainedUDFs() throws Exception {
        openAndAssertBigIntFieldWriters(
            (AbstractArrowPythonScalarFunctionRunner<Row>) createChainedUDFRunner(), 5);
    }

    @Test
    public void testArrowPythonScalarFunctionRunner() throws Exception {
        final JobBundleFactory jobBundleFactory = Mockito.spy(JobBundleFactory.class);
        // Mockito.spy(Class) yields a raw FnDataReceiver; the runner is constructed with a
        // FnDataReceiver<byte[]>, so narrow it here.
        @SuppressWarnings("unchecked")
        final FnDataReceiver<byte[]> resultReceiver = Mockito.spy(FnDataReceiver.class);

        final PythonFunctionInfo[] pythonFunctionInfos = {
            new PythonFunctionInfo(
                AbstractPythonScalarFunctionRunnerTest.DummyPythonFunction.INSTANCE, new Integer[]{0})
        };
        final RowType rowType =
            new RowType(Collections.singletonList(new RowType.RowField("f1", new BigIntType())));
        final AbstractArrowPythonScalarFunctionRunner<Row> runner =
            createPassThroughArrowPythonScalarFunctionRunner(
                resultReceiver, pythonFunctionInfos, rowType, rowType, 2, jobBundleFactory);

        final StageBundleFactory stageBundleFactory = Mockito.spy(StageBundleFactory.class);
        Mockito.when(jobBundleFactory.forStage(ArgumentMatchers.<ExecutableStage>any()))
            .thenReturn(stageBundleFactory);

        final RemoteBundle remoteBundle = Mockito.spy(RemoteBundle.class);
        Mockito.when(stageBundleFactory.getBundle(
                ArgumentMatchers.<OutputReceiverFactory>any(),
                ArgumentMatchers.<StateRequestHandler>any(),
                ArgumentMatchers.<BundleProgressHandler>any()))
            .thenReturn(remoteBundle);

        final Map<String, FnDataReceiver<?>> inputReceivers = new HashMap<>();
        inputReceivers.put("input", Mockito.spy(FnDataReceiver.class));
        // doReturn is the stubbing form Mockito recommends for spies. It is also untyped, so it
        // does not depend on the exact generic value type RemoteBundle declares for this map.
        Mockito.doReturn(inputReceivers).when(remoteBundle).getInputReceivers();

        // Expected: open() obtains the stage bundle factory but does not request a bundle yet.
        runner.open();
        Mockito.verify(jobBundleFactory, Mockito.times(1))
            .forStage(ArgumentMatchers.<ExecutableStage>any());
        verifyGetBundleCalls(stageBundleFactory, 0);

        // Expected: startBundle() requests exactly one bundle.
        runner.startBundle();
        verifyGetBundleCalls(stageBundleFactory, 1);

        // Expected: a single element does not close the bundle.
        runner.processElement(Row.of(1L));
        Mockito.verify(remoteBundle, Mockito.never()).close();

        // Expected: finishBundle() closes the bundle and forwards one result to the receiver.
        runner.finishBundle();
        Mockito.verify(remoteBundle, Mockito.times(1)).close();
        Mockito.verify(resultReceiver, Mockito.times(1)).accept(ArgumentMatchers.any());
    }

    @Override
    public AbstractArrowPythonScalarFunctionRunner<Row> createPythonScalarFunctionRunner(
            PythonFunctionInfo[] pythonFunctionInfoArr, RowType inputType, RowType outputType) {
        // Results are deliberately discarded; the tests using this factory only inspect runner state.
        final FnDataReceiver<byte[]> discardingReceiver = output -> {
        };
        return createPassThroughArrowPythonScalarFunctionRunner(
            discardingReceiver,
            pythonFunctionInfoArr,
            inputType,
            outputType,
            1,
            Mockito.spy(JobBundleFactory.class));
    }

    /**
     * Opens {@code runner} and asserts that its Arrow writer holds exactly {@code expectedCount}
     * field writers, each of which is a {@link RowBigIntWriter}.
     */
    private static void openAndAssertBigIntFieldWriters(
            AbstractArrowPythonScalarFunctionRunner<Row> runner, int expectedCount) throws Exception {
        runner.open();
        final ArrowFieldWriter<Row>[] fieldWriters = runner.arrowWriter.getFieldWriters();
        Assert.assertEquals(expectedCount, fieldWriters.length);
        for (int i = 0; i < fieldWriters.length; i++) {
            Assert.assertTrue(
                "field writer " + i + " is not a RowBigIntWriter: " + fieldWriters[i],
                fieldWriters[i] instanceof RowBigIntWriter);
        }
    }

    /** Verifies that the three-argument {@code getBundle} was invoked exactly {@code expected} times. */
    private static void verifyGetBundleCalls(StageBundleFactory stageBundleFactory, int expected)
            throws Exception {
        Mockito.verify(stageBundleFactory, Mockito.times(expected)).getBundle(
            ArgumentMatchers.<OutputReceiverFactory>any(),
            ArgumentMatchers.<StateRequestHandler>any(),
            ArgumentMatchers.<BundleProgressHandler>any());
    }

    /**
     * Creates a {@link PassThroughArrowPythonScalarFunctionRunner} whose Arrow writer is built with
     * {@link ArrowUtils#createRowArrowWriter} from the runner's {@code root} and input type.
     *
     * @param resultReceiver receiver handed to the runner for its {@code byte[]} output
     * @param pythonFunctionInfoArr the Python functions to run
     * @param inputType row type passed as the runner's input type
     * @param outputType row type passed as the runner's output type
     * @param maxArrowBatchSize forwarded unchanged to the runner constructor; presumably the maximum
     *     number of rows per Arrow batch (TODO confirm against PassThroughArrowPythonScalarFunctionRunner)
     * @param jobBundleFactory factory the runner uses to obtain stage bundles
     */
    private AbstractArrowPythonScalarFunctionRunner<Row> createPassThroughArrowPythonScalarFunctionRunner(
            FnDataReceiver<byte[]> resultReceiver,
            PythonFunctionInfo[] pythonFunctionInfoArr,
            RowType inputType,
            RowType outputType,
            int maxArrowBatchSize,
            JobBundleFactory jobBundleFactory) {
        return new PassThroughArrowPythonScalarFunctionRunner<Row>(
                "testPythonRunner",
                resultReceiver,
                pythonFunctionInfoArr,
                PythonTestUtils.createTestEnvironmentManager(),
                inputType,
                outputType,
                maxArrowBatchSize,
                Collections.emptyMap(),
                jobBundleFactory,
                PythonTestUtils.createMockFlinkMetricContainer()) {
            @Override
            public ArrowWriter<Row> createArrowWriter() {
                return ArrowUtils.createRowArrowWriter(this.root, getInputType());
            }
        };
    }
}
