package org.apache.flink.table.runtime.runners.python.scalar.arrow;

import java.io.OutputStream;
import java.util.Map;
import org.apache.beam.runners.fnexecution.control.OutputReceiverFactory;
import org.apache.beam.sdk.fn.data.FnDataReceiver;
import org.apache.beam.sdk.util.WindowedValue;
import org.apache.flink.annotation.Internal;
import org.apache.flink.annotation.VisibleForTesting;
import org.apache.flink.api.python.shaded.org.apache.arrow.memory.BufferAllocator;
import org.apache.flink.api.python.shaded.org.apache.arrow.vector.VectorSchemaRoot;
import org.apache.flink.api.python.shaded.org.apache.arrow.vector.dictionary.DictionaryProvider;
import org.apache.flink.api.python.shaded.org.apache.arrow.vector.ipc.ArrowStreamWriter;
import org.apache.flink.python.env.PythonEnvironmentManager;
import org.apache.flink.python.metric.FlinkMetricContainer;
import org.apache.flink.table.functions.python.PythonFunctionInfo;
import org.apache.flink.table.runtime.arrow.ArrowUtils;
import org.apache.flink.table.runtime.arrow.ArrowWriter;
import org.apache.flink.table.runtime.runners.python.scalar.AbstractPythonScalarFunctionRunner;
import org.apache.flink.table.types.logical.RowType;

/**
 * Base class for Python scalar function runners that exchange data with the Python worker in the
 * Arrow format.
 *
 * <p>Input elements are written into an Arrow {@link VectorSchemaRoot} through an {@link
 * ArrowWriter}. When {@code maxArrowBatchSize} elements have been buffered, or when the bundle
 * finishes, the buffered rows are serialized as one Arrow batch into {@code baos}. The batch is
 * then handed to {@code mainInputReceiver}. Results produced by the Python side are forwarded, as
 * raw bytes, to {@code resultReceiver}.
 *
 * @param <IN> type of the input elements processed by this runner
 */
@Internal
public abstract class AbstractArrowPythonScalarFunctionRunner<IN>
        extends AbstractPythonScalarFunctionRunner<IN> {

    private static final String SCHEMA_ARROW_CODER_URN = "flink:coder:schema:scalar_function:arrow:v1";

    /** Maximum number of elements buffered before the current batch is flushed. */
    private final int maxArrowBatchSize;

    /** Holds the Arrow vectors the current batch is written into. */
    protected transient VectorSchemaRoot root;

    /** Child allocator backing {@link #root}; created in {@link #open()}, released in {@link #close()}. */
    private transient BufferAllocator allocator;

    /** Writes input elements into {@link #root}; created via {@link #createArrowWriter()}. */
    @VisibleForTesting
    transient ArrowWriter<IN> arrowWriter;

    /** Serializes the content of {@link #root} into {@code baos}. */
    private transient ArrowStreamWriter arrowStreamWriter;

    /** Number of elements written into the current, not yet flushed, batch. */
    private transient int currentBatchCount;

    /**
     * Creates the runner.
     *
     * @param taskName name passed through to the parent runner
     * @param resultReceiver receives the raw result bytes produced by the Python side
     * @param scalarFunctions the Python scalar functions to execute
     * @param environmentManager manages the Python execution environment
     * @param inputType row type of the input elements; also used to derive the Arrow schema
     * @param outputType row type of the results
     * @param maxArrowBatchSize number of elements after which a batch is flushed eagerly
     * @param jobOptions options passed through to the parent runner
     * @param flinkMetricContainer metric container passed through to the parent runner
     */
    public AbstractArrowPythonScalarFunctionRunner(
            String taskName,
            FnDataReceiver<byte[]> resultReceiver,
            PythonFunctionInfo[] scalarFunctions,
            PythonEnvironmentManager environmentManager,
            RowType inputType,
            RowType outputType,
            int maxArrowBatchSize,
            Map<String, String> jobOptions,
            FlinkMetricContainer flinkMetricContainer) {
        super(
                taskName,
                resultReceiver,
                scalarFunctions,
                environmentManager,
                inputType,
                outputType,
                jobOptions,
                flinkMetricContainer);
        this.maxArrowBatchSize = maxArrowBatchSize;
    }

    /**
     * Opens the parent runner and then sets up the Arrow machinery: a child allocator, a schema
     * root matching {@link #getInputType()}, the element writer and the stream writer targeting
     * {@code baos}.
     */
    @Override
    public void open() throws Exception {
        super.open();
        this.allocator = ArrowUtils.getRootAllocator().newChildAllocator("writer", 0L, Long.MAX_VALUE);
        this.root = VectorSchemaRoot.create(ArrowUtils.toArrowSchema(getInputType()), this.allocator);
        this.arrowWriter = createArrowWriter();
        this.arrowStreamWriter =
                new ArrowStreamWriter(this.root, (DictionaryProvider) null, (OutputStream) this.baos);
        this.arrowStreamWriter.start();
        this.currentBatchCount = 0;
    }

    /**
     * Closes the parent runner, ends the Arrow stream, and releases the Arrow memory.
     *
     * <p>The schema root and the allocator are released even if an earlier step fails. The
     * allocator is released even if closing the root fails. Fields that were never initialized,
     * for example because {@link #open()} failed part-way, are skipped. Skipping them keeps a
     * {@code NullPointerException} from masking the original error, and makes repeated calls safe.
     */
    @Override
    public void close() throws Exception {
        try {
            super.close();
            if (this.arrowStreamWriter != null) {
                this.arrowStreamWriter.end();
            }
        } finally {
            // Nothing else must be written to the stream once the resources are released.
            this.arrowStreamWriter = null;
            try {
                if (this.root != null) {
                    this.root.close();
                    this.root = null;
                }
            } finally {
                // Release the allocator even if closing the root failed.
                if (this.allocator != null) {
                    this.allocator.close();
                    this.allocator = null;
                }
            }
        }
    }

    /**
     * Buffers one element into the current Arrow batch, flushing the batch once it reaches
     * {@code maxArrowBatchSize} elements.
     *
     * @throws RuntimeException wrapping any failure while writing or flushing
     */
    @Override
    public void processElement(IN in) {
        try {
            this.arrowWriter.write(in);
            this.currentBatchCount++;
            if (this.currentBatchCount >= this.maxArrowBatchSize) {
                finishCurrentBatch();
            }
        } catch (Throwable th) {
            throw new RuntimeException("Failed to process element.", th);
        }
    }

    /** Flushes any partially filled batch before the parent runner finishes the bundle. */
    @Override
    public void finishBundle() throws Exception {
        finishCurrentBatch();
        super.finishBundle();
    }

    /** Returns a factory whose receivers forward the raw result bytes to {@code resultReceiver}. */
    @Override
    public OutputReceiverFactory createOutputReceiverFactory() {
        // Kept as an anonymous class rather than a lambda: the factory method is overridden with
        // a concrete WindowedValue<byte[]> receiver type.
        return new OutputReceiverFactory() {
            @Override
            public FnDataReceiver<WindowedValue<byte[]>> create(String pCollectionId) {
                return windowedValue ->
                        AbstractArrowPythonScalarFunctionRunner.this.resultReceiver.accept(
                                windowedValue.getValue());
            }
        };
    }

    /** Returns the coder URN identifying the Arrow-based scalar function schema coder. */
    @Override
    public String getInputOutputCoderUrn() {
        return SCHEMA_ARROW_CODER_URN;
    }

    /** Creates the writer that converts input elements of type {@code IN} into {@link #root}. */
    public abstract ArrowWriter<IN> createArrowWriter();

    /**
     * Serializes the buffered elements as one Arrow batch and hands the bytes to {@code
     * mainInputReceiver}. This is a no-op, apart from resetting the counter, when nothing has
     * been buffered.
     */
    private void finishCurrentBatch() throws Exception {
        if (this.currentBatchCount > 0) {
            this.arrowWriter.finish();
            this.arrowStreamWriter.writeBatch();
            // Prepare the writer for the next batch before shipping the serialized bytes.
            this.arrowWriter.reset();
            this.mainInputReceiver.accept(WindowedValue.valueInGlobalWindow(this.baos.toByteArray()));
            // Drop the bytes that were just sent so the next batch starts from an empty buffer.
            this.baos.reset();
        }
        this.currentBatchCount = 0;
    }

    static {
        // Verify at class-initialization time that Arrow can be used (see ArrowUtils#checkArrowUsable).
        ArrowUtils.checkArrowUsable();
    }
}
