package org.apache.flink.table.runtime.operators.python.scalar;

import java.util.Arrays;
import java.util.Collection;
import java.util.concurrent.ConcurrentLinkedQueue;
import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
import org.apache.flink.api.common.typeutils.TypeSerializer;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.configuration.MemorySize;
import org.apache.flink.configuration.TaskManagerOptions;
import org.apache.flink.core.memory.ManagedMemoryUseCase;
import org.apache.flink.python.PythonOptions;
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
import org.apache.flink.streaming.api.watermark.Watermark;
import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
import org.apache.flink.streaming.util.OneInputStreamOperatorTestHarness;
import org.apache.flink.table.api.Expressions;
import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
import org.apache.flink.table.expressions.Expression;
import org.apache.flink.table.functions.python.PythonEnv;
import org.apache.flink.table.functions.python.PythonFunction;
import org.apache.flink.table.functions.python.PythonFunctionInfo;
import org.apache.flink.table.planner.runtime.utils.JavaUserDefinedScalarFunctions;
import org.apache.flink.table.types.logical.BigIntType;
import org.apache.flink.table.types.logical.RowType;
import org.apache.flink.table.types.logical.VarCharType;
import org.junit.Assert;
import org.junit.Test;

/* loaded from: input_file:org/apache/flink/table/runtime/operators/python/scalar/PythonScalarFunctionOperatorTestBase.class */
/**
 * Shared test cases for Python scalar function operators.
 *
 * <p>Subclasses supply the concrete operator, the row representation ({@link #newRow}), the output
 * comparison ({@link #assertOutputEquals}) and the output serializer. The tests cover field
 * retention, bundle-finishing triggers (checkpoint, count, time, close), watermark ordering, and
 * operator chaining.
 *
 * @param <IN> the operator's input row type
 * @param <OUT> the operator's output row type
 * @param <UDFIN> not referenced in this class; kept for subclass compatibility
 */
public abstract class PythonScalarFunctionOperatorTestBase<IN, OUT, UDFIN> {

    /** Base timestamp for all records and watermarks emitted by the tests. */
    private static final long INITIAL_TIME = 0L;

    /** A no-op {@link PythonFunction} with an empty payload that runs in {@code PROCESS} mode. */
    public static class DummyPythonFunction implements PythonFunction {
        private static final long serialVersionUID = 1L;

        /** Shared instance; the class holds no state. */
        public static final PythonFunction INSTANCE = new DummyPythonFunction();

        @Override
        public byte[] getSerializedPythonFunction() {
            return new byte[0];
        }

        @Override
        public PythonEnv getPythonEnv() {
            return new PythonEnv(PythonEnv.ExecType.PROCESS);
        }
    }

    @Test
    public void testRetractionFieldKept() throws Exception {
        OneInputStreamOperatorTestHarness<IN, OUT> testHarness = getTestHarness(new Configuration());
        ConcurrentLinkedQueue<Object> expectedOutput = new ConcurrentLinkedQueue<>();

        testHarness.open();
        testHarness.processElement(new StreamRecord<>(newRow(true, "c1", "c2", 0L), INITIAL_TIME + 1));
        testHarness.processElement(new StreamRecord<>(newRow(false, "c3", "c4", 1L), INITIAL_TIME + 2));
        testHarness.processElement(new StreamRecord<>(newRow(false, "c5", "c6", 2L), INITIAL_TIME + 3));
        // Closing finishes the pending bundle (see testFinishBundleTriggeredByClose), so all
        // results are available before asserting.
        testHarness.close();

        expectedOutput.add(new StreamRecord<>(newRow(true, "c1", "c2", 0L)));
        expectedOutput.add(new StreamRecord<>(newRow(false, "c3", "c4", 1L)));
        expectedOutput.add(new StreamRecord<>(newRow(false, "c5", "c6", 2L)));

        assertOutputEquals("Output was not correct.", expectedOutput, testHarness.getOutput());
    }

    @Test
    public void testFinishBundleTriggeredOnCheckpoint() throws Exception {
        Configuration conf = new Configuration();
        conf.setInteger(PythonOptions.MAX_BUNDLE_SIZE, 10);
        OneInputStreamOperatorTestHarness<IN, OUT> testHarness = getTestHarness(conf);
        ConcurrentLinkedQueue<Object> expectedOutput = new ConcurrentLinkedQueue<>();

        testHarness.open();
        testHarness.processElement(new StreamRecord<>(newRow(true, "c1", "c2", 0L), INITIAL_TIME + 1));
        // The bundle (max size 10) is not full; the pre-barrier snapshot hook must flush it.
        testHarness.prepareSnapshotPreBarrier(0L);

        expectedOutput.add(new StreamRecord<>(newRow(true, "c1", "c2", 0L)));
        assertOutputEquals("Output was not correct.", expectedOutput, testHarness.getOutput());

        testHarness.close();
    }

    @Test
    public void testFinishBundleTriggeredByCount() throws Exception {
        Configuration conf = new Configuration();
        conf.setInteger(PythonOptions.MAX_BUNDLE_SIZE, 2);
        OneInputStreamOperatorTestHarness<IN, OUT> testHarness = getTestHarness(conf);
        ConcurrentLinkedQueue<Object> expectedOutput = new ConcurrentLinkedQueue<>();

        testHarness.open();
        testHarness.processElement(new StreamRecord<>(newRow(true, "c1", "c2", 0L), INITIAL_TIME + 1));
        assertOutputEquals("FinishBundle should not be triggered.", expectedOutput, testHarness.getOutput());

        // The second element reaches MAX_BUNDLE_SIZE = 2 and flushes both results.
        testHarness.processElement(new StreamRecord<>(newRow(true, "c1", "c2", 1L), INITIAL_TIME + 2));
        expectedOutput.add(new StreamRecord<>(newRow(true, "c1", "c2", 0L)));
        expectedOutput.add(new StreamRecord<>(newRow(true, "c1", "c2", 1L)));
        assertOutputEquals("Output was not correct.", expectedOutput, testHarness.getOutput());

        testHarness.close();
    }

    @Test
    public void testFinishBundleTriggeredByTime() throws Exception {
        Configuration conf = new Configuration();
        conf.setInteger(PythonOptions.MAX_BUNDLE_SIZE, 10);
        conf.setLong(PythonOptions.MAX_BUNDLE_TIME_MILLS, 1000L);
        OneInputStreamOperatorTestHarness<IN, OUT> testHarness = getTestHarness(conf);
        ConcurrentLinkedQueue<Object> expectedOutput = new ConcurrentLinkedQueue<>();

        testHarness.open();
        testHarness.processElement(new StreamRecord<>(newRow(true, "c1", "c2", 0L), INITIAL_TIME + 1));
        assertOutputEquals("FinishBundle should not be triggered.", expectedOutput, testHarness.getOutput());

        // Advancing processing time to MAX_BUNDLE_TIME_MILLS flushes the partial bundle.
        testHarness.setProcessingTime(1000L);
        expectedOutput.add(new StreamRecord<>(newRow(true, "c1", "c2", 0L)));
        assertOutputEquals("Output was not correct.", expectedOutput, testHarness.getOutput());

        testHarness.close();
    }

    @Test
    public void testFinishBundleTriggeredByClose() throws Exception {
        Configuration conf = new Configuration();
        conf.setInteger(PythonOptions.MAX_BUNDLE_SIZE, 10);
        OneInputStreamOperatorTestHarness<IN, OUT> testHarness = getTestHarness(conf);
        ConcurrentLinkedQueue<Object> expectedOutput = new ConcurrentLinkedQueue<>();

        testHarness.open();
        testHarness.processElement(new StreamRecord<>(newRow(true, "c1", "c2", 0L), INITIAL_TIME + 1));
        assertOutputEquals("FinishBundle should not be triggered.", expectedOutput, testHarness.getOutput());

        testHarness.close();
        expectedOutput.add(new StreamRecord<>(newRow(true, "c1", "c2", 0L)));
        assertOutputEquals("Output was not correct.", expectedOutput, testHarness.getOutput());
    }

    @Test
    public void testWatermarkProcessedOnFinishBundle() throws Exception {
        Configuration conf = new Configuration();
        conf.setInteger(PythonOptions.MAX_BUNDLE_SIZE, 10);
        OneInputStreamOperatorTestHarness<IN, OUT> testHarness = getTestHarness(conf);
        ConcurrentLinkedQueue<Object> expectedOutput = new ConcurrentLinkedQueue<>();

        testHarness.open();
        testHarness.processElement(new StreamRecord<>(newRow(true, "c1", "c2", 0L), INITIAL_TIME + 1));
        testHarness.processWatermark(INITIAL_TIME + 2);
        // The watermark must be held back until the bundle containing the earlier record is done.
        assertOutputEquals("Watermark has been processed", expectedOutput, testHarness.getOutput());

        testHarness.prepareSnapshotPreBarrier(0L);
        expectedOutput.add(new StreamRecord<>(newRow(true, "c1", "c2", 0L)));
        expectedOutput.add(new Watermark(INITIAL_TIME + 2));
        assertOutputEquals("Output was not correct.", expectedOutput, testHarness.getOutput());

        testHarness.close();
    }

    @Test
    public void testPythonScalarFunctionOperatorIsChainedByDefault() {
        StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
        env.setParallelism(1);
        StreamTableEnvironment tEnv = createTableEnvironment(env);
        tEnv.getConfig().set(TaskManagerOptions.TASK_OFF_HEAP_MEMORY, MemorySize.parse("80mb"));
        tEnv.registerFunction("pyFunc", new JavaUserDefinedScalarFunctions.PythonScalarFunction("pyFunc"));

        tEnv.toAppendStream(
                tEnv.fromDataStream(
                                env.fromElements(new Tuple2<>(1, 2)),
                                Expressions.$("a"),
                                Expressions.$("b"))
                        .select(Expressions.call("pyFunc", Expressions.$("a"), Expressions.$("b"))),
                BasicTypeInfo.INT_TYPE_INFO);

        // Source, Python operator and sink should all end up in one chained job vertex.
        Assert.assertEquals(
                1L,
                env.getStreamGraph().getJobGraph().getVerticesSortedTopologicallyFromSources().size());
    }

    /**
     * Builds and sets up (but does not open) a harness around the operator under test.
     *
     * <p>Input and output both use the schema {@code (f1 VARCHAR, f2 VARCHAR, f3 BIGINT)}. A single
     * {@link DummyPythonFunction} is applied, with Python memory given a 0.5 managed-memory
     * fraction.
     */
    private OneInputStreamOperatorTestHarness<IN, OUT> getTestHarness(Configuration config) throws Exception {
        RowType dataType =
                new RowType(
                        Arrays.asList(
                                new RowType.RowField("f1", new VarCharType()),
                                new RowType.RowField("f2", new VarCharType()),
                                new RowType.RowField("f3", new BigIntType())));

        AbstractPythonScalarFunctionOperator operator =
                mo22getTestOperator(
                        config,
                        new PythonFunctionInfo[] {
                            new PythonFunctionInfo(DummyPythonFunction.INSTANCE, new Integer[] {0})
                        },
                        dataType,
                        dataType,
                        new int[] {2},
                        new int[] {0, 1});

        // The operator is returned without IN/OUT type arguments, so the harness can only be
        // built through an unchecked conversion.
        @SuppressWarnings({"unchecked", "rawtypes"})
        OneInputStreamOperatorTestHarness<IN, OUT> testHarness =
                new OneInputStreamOperatorTestHarness(operator);
        testHarness.getStreamConfig().setManagedMemoryFractionOperatorOfUseCase(ManagedMemoryUseCase.PYTHON, 0.5d);
        testHarness.setup(getOutputTypeSerializer(dataType));
        return testHarness;
    }

    /**
     * Creates the operator under test.
     *
     * <p>Originally named {@code getTestOperator}; the decompiler-assigned name is kept so that
     * existing subclasses continue to override it.
     *
     * @param config job configuration (bundle size and time limits are set through it)
     * @param scalarFunctions the Python functions the operator applies
     * @param inputType presumably the operator's input row type (the harness passes the same type
     *     for input and output) — TODO confirm order
     * @param outputType presumably the operator's output row type — TODO confirm order
     * @param udfInputOffsets presumably the input field indexes passed to the UDF (the harness passes
     *     {@code {2}}) — TODO confirm
     * @param forwardedFields presumably the input field indexes copied through unchanged (the harness
     *     passes {@code {0, 1}}) — TODO confirm
     */
    public abstract AbstractPythonScalarFunctionOperator mo22getTestOperator(
            Configuration config,
            PythonFunctionInfo[] scalarFunctions,
            RowType inputType,
            RowType outputType,
            int[] udfInputOffsets,
            int[] forwardedFields);

    /**
     * Creates an input row.
     *
     * @param accumulateMsg presumably {@code true} for an accumulate message and {@code false} for a
     *     retraction — TODO confirm against subclasses
     * @param fields the row's field values, in schema order
     */
    public abstract IN newRow(boolean accumulateMsg, Object... fields);

    /**
     * Asserts that the harness output matches the expected elements.
     *
     * @param message failure message
     * @param expected expected records and watermarks
     * @param actual the harness output
     */
    public abstract void assertOutputEquals(String message, Collection<Object> expected, Collection<Object> actual);

    /** Creates the table environment used by the chaining test. */
    public abstract StreamTableEnvironment createTableEnvironment(StreamExecutionEnvironment env);

    /** Returns the serializer for the operator's output rows of the given type. */
    public abstract TypeSerializer<OUT> getOutputTypeSerializer(RowType dataType);
}
