package org.apache.flink.table.runtime.runners.python.scalar;

import java.io.ByteArrayOutputStream;
import java.util.Collections;
import java.util.HashMap;
import java.util.Objects;
import org.apache.beam.runners.core.construction.graph.ExecutableStage;
import org.apache.beam.runners.fnexecution.control.BundleProgressHandler;
import org.apache.beam.runners.fnexecution.control.JobBundleFactory;
import org.apache.beam.runners.fnexecution.control.OutputReceiverFactory;
import org.apache.beam.runners.fnexecution.control.RemoteBundle;
import org.apache.beam.runners.fnexecution.control.StageBundleFactory;
import org.apache.beam.runners.fnexecution.state.StateRequestHandler;
import org.apache.beam.sdk.fn.data.FnDataReceiver;
import org.apache.flink.api.common.typeutils.TypeSerializer;
import org.apache.flink.core.memory.DataOutputViewStreamWrapper;
import org.apache.flink.fnexecution.v1.FlinkFnApi;
import org.apache.flink.table.functions.python.PythonFunctionInfo;
import org.apache.flink.table.runtime.runners.python.scalar.AbstractPythonScalarFunctionRunnerTest;
import org.apache.flink.table.runtime.typeutils.PythonTypeUtils;
import org.apache.flink.table.runtime.utils.PassThroughPythonScalarFunctionRunner;
import org.apache.flink.table.runtime.utils.PythonTestUtils;
import org.apache.flink.table.types.logical.BigIntType;
import org.apache.flink.table.types.logical.RowType;
import org.apache.flink.types.Row;
import org.junit.Assert;
import org.junit.Test;
import org.mockito.ArgumentMatchers;
import org.mockito.Mockito;

/* loaded from: input_file:org/apache/flink/table/runtime/runners/python/scalar/PythonScalarFunctionRunnerTest.class */
/**
 * Tests for the Python scalar function runner operating on {@link Row} elements.
 *
 * <p>The tests check three things: the arity reported by the runner's input serializer, the shape
 * of the generated {@link FlinkFnApi.UserDefinedFunctions} proto (input counts, input offsets and
 * nested UDFs), and the order in which a runner talks to mocked Beam bundle factories during
 * {@code open}, {@code startBundle}, {@code processElement} and {@code finishBundle}.
 */
public class PythonScalarFunctionRunnerTest extends AbstractPythonScalarFunctionRunnerTest<Row> {

    /** The single-UDF runner must expose an input serializer with arity 1. */
    @Test
    public void testInputOutputDataTypeConstructedProperlyForSingleUDF() throws Exception {
        final long expectedArity = 1L;
        Assert.assertEquals(expectedArity, createSingleUDFRunner().getInputTypeSerializer().getArity());
    }

    /** The multiple-UDF runner must expose an input serializer with arity 3. */
    @Test
    public void testInputOutputDataTypeConstructedProperlyForMultipleUDFs() throws Exception {
        final long expectedArity = 3L;
        Assert.assertEquals(expectedArity, createMultipleUDFRunner().getInputTypeSerializer().getArity());
    }

    /** The chained-UDF runner must expose an input serializer with arity 5. */
    @Test
    public void testInputOutputDataTypeConstructedProperlyForChainedUDFs() throws Exception {
        final long expectedArity = 5L;
        Assert.assertEquals(expectedArity, createChainedUDFRunner().getInputTypeSerializer().getArity());
    }

    /** A single UDF yields one proto entry with one input read from offset 0. */
    @Test
    public void testUDFProtoConstructedProperlyForSingleUDF() throws Exception {
        FlinkFnApi.UserDefinedFunctions proto = createSingleUDFRunner().getUserDefinedFunctionsProto();
        Assert.assertEquals(1L, proto.getUdfsCount());

        FlinkFnApi.UserDefinedFunction onlyUdf = proto.getUdfs(0);
        Assert.assertEquals(1L, onlyUdf.getInputsCount());
        Assert.assertEquals(0L, onlyUdf.getInputs(0).getInputOffset());
    }

    /** Two UDFs yield two proto entries, reading offsets (0, 1) and (0, 2) respectively. */
    @Test
    public void testUDFProtoConstructedProperlyForMultipleUDFs() throws Exception {
        FlinkFnApi.UserDefinedFunctions proto = createMultipleUDFRunner().getUserDefinedFunctionsProto();
        Assert.assertEquals(2L, proto.getUdfsCount());

        assertTwoInputOffsets(proto.getUdfs(0), 0L, 1L);
        assertTwoInputOffsets(proto.getUdfs(1), 0L, 2L);
    }

    /**
     * Chained UDFs yield three proto entries; the second and third entries carry nested UDFs as
     * inputs, each of which reads two offsets.
     */
    @Test
    public void testUDFProtoConstructedProperlyForChainedUDFs() throws Exception {
        FlinkFnApi.UserDefinedFunctions proto = createChainedUDFRunner().getUserDefinedFunctionsProto();
        Assert.assertEquals(3L, proto.getUdfsCount());

        // First entry: plain UDF over offsets 0 and 1.
        assertTwoInputOffsets(proto.getUdfs(0), 0L, 1L);

        // Second entry: offset 0 plus a nested UDF over offsets 1 and 2.
        FlinkFnApi.UserDefinedFunction second = proto.getUdfs(1);
        Assert.assertEquals(2L, second.getInputsCount());
        Assert.assertEquals(0L, second.getInputs(0).getInputOffset());
        assertTwoInputOffsets(second.getInputs(1).getUdf(), 1L, 2L);

        // Third entry: both inputs are nested UDFs, over offsets (1, 3) and (3, 4).
        FlinkFnApi.UserDefinedFunction third = proto.getUdfs(2);
        assertTwoInputOffsets(third.getInputs(0).getUdf(), 1L, 3L);
        assertTwoInputOffsets(third.getInputs(1).getUdf(), 3L, 4L);
    }

    /**
     * Drives a runner through open / startBundle / processElement / finishBundle against mocked
     * bundle factories and verifies when each collaborator is invoked, and that the result
     * receiver finally sees the serialized form of the processed row.
     */
    @Test
    public void testPythonScalarFunctionRunner() throws Exception {
        JobBundleFactory jobBundleFactoryMock = Mockito.spy(JobBundleFactory.class);
        FnDataReceiver<byte[]> resultReceiverMock = Mockito.spy(FnDataReceiver.class);
        AbstractGeneralPythonScalarFunctionRunner<Row> runner =
            createUDFRunner(jobBundleFactoryMock, resultReceiverMock);

        // Wire the mocks: job factory -> stage factory -> remote bundle -> input receivers.
        StageBundleFactory stageBundleFactoryMock = Mockito.spy(StageBundleFactory.class);
        Mockito.when(jobBundleFactoryMock.forStage(ArgumentMatchers.<ExecutableStage>any()))
            .thenReturn(stageBundleFactoryMock);
        RemoteBundle remoteBundleMock = Mockito.spy(RemoteBundle.class);
        Mockito.when(
                stageBundleFactoryMock.getBundle(
                    ArgumentMatchers.<OutputReceiverFactory>any(),
                    ArgumentMatchers.<StateRequestHandler>any(),
                    ArgumentMatchers.<BundleProgressHandler>any()))
            .thenReturn(remoteBundleMock);
        // Raw map, exactly as before: a single "input" entry holding a spied receiver.
        HashMap inputReceivers = new HashMap();
        inputReceivers.put("input", Mockito.spy(FnDataReceiver.class));
        Mockito.when(remoteBundleMock.getInputReceivers()).thenReturn(inputReceivers);

        // open() must resolve the stage once and must not request a bundle yet.
        runner.open();
        Mockito.verify(jobBundleFactoryMock, Mockito.times(1))
            .forStage(ArgumentMatchers.<ExecutableStage>any());
        Mockito.verify(stageBundleFactoryMock, Mockito.times(0))
            .getBundle(
                ArgumentMatchers.<OutputReceiverFactory>any(),
                ArgumentMatchers.<StateRequestHandler>any(),
                ArgumentMatchers.<BundleProgressHandler>any());

        // startBundle() is the call that requests the bundle.
        runner.startBundle();
        Mockito.verify(stageBundleFactoryMock, Mockito.times(1))
            .getBundle(
                ArgumentMatchers.<OutputReceiverFactory>any(),
                ArgumentMatchers.<StateRequestHandler>any(),
                ArgumentMatchers.<BundleProgressHandler>any());

        // Process one row and independently serialize an equal row as the expected bytes.
        runner.processElement(Row.of(1L));
        ByteArrayOutputStream expectedBytes = new ByteArrayOutputStream();
        DataOutputViewStreamWrapper expectedBytesView = new DataOutputViewStreamWrapper(expectedBytes);
        runner.getInputTypeSerializer().serialize(Row.of(1L), expectedBytesView);

        // The remote bundle is closed by finishBundle(), not earlier.
        Mockito.verify(remoteBundleMock, Mockito.times(0)).close();
        runner.finishBundle();
        Mockito.verify(remoteBundleMock, Mockito.times(1)).close();

        // The result receiver must have been handed exactly the expected serialized bytes.
        Mockito.verify(resultReceiverMock, Mockito.times(1))
            .accept(ArgumentMatchers.argThat(
                received -> Objects.deepEquals(received, expectedBytes.toByteArray())));
    }

    /**
     * Builds the runner under test with a result receiver that discards everything it is given.
     *
     * @param pythonFunctionInfoArr the Python functions the runner should execute
     * @param rowType the first row type, passed through to the runner constructor
     * @param rowType2 the second row type, passed through to the runner constructor
     * @return a new {@link PythonScalarFunctionRunner} named {@code "testPythonRunner"}
     */
    @Override
    public AbstractGeneralPythonScalarFunctionRunner<Row> createPythonScalarFunctionRunner(PythonFunctionInfo[] pythonFunctionInfoArr, RowType rowType, RowType rowType2) {
        return new PythonScalarFunctionRunner(
            "testPythonRunner",
            ignored -> {
                // Results are irrelevant for the tests that use this runner.
            },
            pythonFunctionInfoArr,
            PythonTestUtils.createTestEnvironmentManager(),
            rowType,
            rowType2,
            Collections.emptyMap(),
            PythonTestUtils.createMockFlinkMetricContainer());
    }

    /**
     * Builds a pass-through runner over a single BIGINT field ({@code f1}) wired to the supplied
     * bundle factory and result receiver.
     *
     * @param jobBundleFactory the (mocked) factory handed to the runner
     * @param fnDataReceiver the receiver that should observe the runner's results
     * @return a runner whose input serializer is derived from its input type via
     *     {@link PythonTypeUtils#toFlinkTypeSerializer}
     */
    private AbstractGeneralPythonScalarFunctionRunner<Row> createUDFRunner(JobBundleFactory jobBundleFactory, FnDataReceiver<byte[]> fnDataReceiver) {
        PythonFunctionInfo dummyFunction =
            new PythonFunctionInfo(AbstractPythonScalarFunctionRunnerTest.DummyPythonFunction.INSTANCE, new Integer[]{0});
        PythonFunctionInfo[] functions = {dummyFunction};

        RowType.RowField bigIntField = new RowType.RowField("f1", new BigIntType());
        RowType singleBigIntRow = new RowType(Collections.singletonList(bigIntField));

        // The same row type is used for both row-type constructor arguments.
        return new PassThroughPythonScalarFunctionRunner<Row>(
                "testPythonRunner",
                fnDataReceiver,
                functions,
                PythonTestUtils.createTestEnvironmentManager(),
                singleBigIntRow,
                singleBigIntRow,
                Collections.emptyMap(),
                jobBundleFactory,
                PythonTestUtils.createMockFlinkMetricContainer()) {
            @Override
            public TypeSerializer<Row> getInputTypeSerializer() {
                return PythonTypeUtils.toFlinkTypeSerializer(getInputType());
            }
        };
    }

    /**
     * Asserts that {@code udf} has exactly two inputs and that they read the given offsets.
     *
     * @param udf the proto function to inspect
     * @param firstOffset expected input offset of input 0
     * @param secondOffset expected input offset of input 1
     */
    private static void assertTwoInputOffsets(FlinkFnApi.UserDefinedFunction udf, long firstOffset, long secondOffset) {
        Assert.assertEquals(2L, udf.getInputsCount());
        Assert.assertEquals(firstOffset, udf.getInputs(0).getInputOffset());
        Assert.assertEquals(secondOffset, udf.getInputs(1).getInputOffset());
    }
}
