package org.apache.flink.connector.mongodb.table;

import com.mongodb.client.model.Filters;
import java.math.BigDecimal;
import java.sql.Timestamp;
import java.time.ZoneId;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Iterator;
import java.util.List;
import java.util.Optional;
import java.util.TimeZone;
import java.util.stream.Collectors;
import org.apache.calcite.plan.RelOptUtil;
import org.apache.calcite.rex.RexBuilder;
import org.apache.calcite.rex.RexNode;
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
import org.apache.flink.table.api.TableException;
import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
import org.apache.flink.table.api.bridge.java.internal.StreamTableEnvironmentImpl;
import org.apache.flink.table.catalog.CatalogManager;
import org.apache.flink.table.catalog.FunctionCatalog;
import org.apache.flink.table.catalog.ResolvedSchema;
import org.apache.flink.table.expressions.Expression;
import org.apache.flink.table.expressions.ResolvedExpression;
import org.apache.flink.table.expressions.resolver.ExpressionResolver;
import org.apache.flink.table.operations.QueryOperation;
import org.apache.flink.table.planner.calcite.FlinkContext;
import org.apache.flink.table.planner.calcite.FlinkTypeFactory;
import org.apache.flink.table.planner.calcite.FlinkTypeSystem;
import org.apache.flink.table.planner.delegation.PlannerBase;
import org.apache.flink.table.planner.expressions.RexNodeExpression;
import org.apache.flink.table.planner.plan.utils.FlinkRexUtil;
import org.apache.flink.table.planner.plan.utils.RexNodeToExpressionConverter;
import org.apache.flink.table.types.logical.LogicalType;
import org.apache.flink.table.types.logical.RowType;
import org.assertj.core.api.Assertions;
import org.bson.BsonBoolean;
import org.bson.BsonDateTime;
import org.bson.BsonDecimal128;
import org.bson.BsonDocument;
import org.bson.BsonInt32;
import org.bson.BsonNull;
import org.bson.BsonString;
import org.bson.conversions.Bson;
import org.bson.types.Decimal128;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import scala.Option;

/**
 * Tests for the MongoDB filter push-down translation performed by {@code
 * MongoDynamicTableSource.parseFilter}.
 *
 * <p>Each test parses a SQL predicate against the {@code mongo_source} table and converts it to
 * conjunctive normal form. Every conjunct the planner can turn into an expression is translated to
 * BSON, and the combined result is compared with the expected MongoDB filter document.
 */
class MongoFilterPushDownVisitorTest {

    private static final String INPUT_TABLE = "mongo_source";

    /** The filter document expected when nothing can be pushed down. */
    private static final BsonDocument EMPTY_FILTER = Filters.empty().toBsonDocument();

    private final ClassLoader classLoader = Thread.currentThread().getContextClassLoader();

    /** Recreated before every test; JUnit 5 uses a fresh test instance per test method. */
    private StreamTableEnvironment tEnv;

    @BeforeEach
    void before() {
        StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
        tEnv = StreamTableEnvironment.create(env);
        tEnv.getConfig().setLocalTimeZone(ZoneId.of("UTC"));
        tEnv.executeSql(
                "CREATE TABLE "
                        + INPUT_TABLE
                        + "("
                        + "id INTEGER,"
                        + "description VARCHAR(200),"
                        + "boolean_col BOOLEAN,"
                        + "timestamp_col TIMESTAMP(0),"
                        + "timestamp3_col TIMESTAMP(3),"
                        + "double_col DOUBLE,"
                        + "decimal_col DECIMAL(10, 4)"
                        + ") WITH ("
                        + "  'connector'='mongodb',"
                        + "  'uri'='mongodb://127.0.0.1:27017',"
                        + "  'database'='test_db',"
                        + "  'collection'='test_coll'"
                        + ")");
    }

    @Test
    void testSimpleExpressionPrimitiveType() {
        ResolvedSchema schema = sourceSchema();
        // Each entry is {SQL predicate, expected Bson filter}.
        List<Object[]> cases =
                Arrays.asList(
                        new Object[] {"id = 6", Filters.eq("id", new BsonInt32(6))},
                        new Object[] {"id >= 6", Filters.gte("id", new BsonInt32(6))},
                        new Object[] {"id > 6", Filters.gt("id", new BsonInt32(6))},
                        new Object[] {"id < 6", Filters.lt("id", new BsonInt32(6))},
                        new Object[] {"id <= 5", Filters.lte("id", new BsonInt32(5))},
                        new Object[] {
                            "description = 'Halo'",
                            Filters.eq("description", new BsonString("Halo"))
                        },
                        new Object[] {
                            "boolean_col = true", Filters.eq("boolean_col", new BsonBoolean(true))
                        },
                        new Object[] {
                            "boolean_col = false",
                            Filters.eq("boolean_col", new BsonBoolean(false))
                        },
                        // A bare 0.5 literal is a DECIMAL in Flink SQL, hence Decimal128 here.
                        new Object[] {"double_col > 0.5", Filters.gt("double_col", decimal("0.5"))},
                        new Object[] {
                            "decimal_col <= -0.3", Filters.lte("decimal_col", decimal("-0.3"))
                        });
        for (Object[] testCase : cases) {
            assertGeneratedFilter(
                    (String) testCase[0], schema, ((Bson) testCase[1]).toBsonDocument());
        }
    }

    @Test
    void testComplexExpressionDatetime() {
        ResolvedSchema schema = sourceSchema();
        // NOTE(review): Timestamp.valueOf interprets the literal in the JVM default time zone;
        // the expectation presumably mirrors how the source converts TIMESTAMP literals —
        // confirm if this test is run on a non-UTC machine.
        assertGeneratedFilter(
                "id = 6 AND timestamp_col = TIMESTAMP '2022-01-01 07:00:01'",
                schema,
                Filters.and(
                                Filters.eq("id", new BsonInt32(6)),
                                Filters.eq(
                                        "timestamp_col",
                                        new BsonDateTime(
                                                Timestamp.valueOf("2022-01-01 07:00:01")
                                                        .getTime())))
                        .toBsonDocument());
        assertGeneratedFilter(
                "timestamp3_col = TIMESTAMP '2022-01-01 07:00:01.333' OR description = 'Halo'",
                schema,
                Filters.or(
                                Filters.eq(
                                        "timestamp3_col",
                                        new BsonDateTime(
                                                Timestamp.valueOf("2022-01-01 07:00:01.333")
                                                        .getTime())),
                                Filters.eq("description", new BsonString("Halo")))
                        .toBsonDocument());
    }

    @Test
    void testExpressionWithNull() {
        ResolvedSchema schema = sourceSchema();
        assertGeneratedFilter(
                "id = NULL AND decimal_col <= 0.6",
                schema,
                Filters.and(
                                Filters.eq("id", BsonNull.VALUE),
                                Filters.lte("decimal_col", decimal("0.6")))
                        .toBsonDocument());
        assertGeneratedFilter(
                "id = 6 OR description = NULL",
                schema,
                Filters.or(
                                Filters.eq("id", new BsonInt32(6)),
                                Filters.eq("description", BsonNull.VALUE))
                        .toBsonDocument());
    }

    @Test
    void testExpressionIsNull() {
        ResolvedSchema schema = sourceSchema();
        assertGeneratedFilter(
                "id IS NULL AND decimal_col <= 0.6",
                schema,
                Filters.and(
                                Filters.eq("id", BsonNull.VALUE),
                                Filters.lte("decimal_col", decimal("0.6")))
                        .toBsonDocument());
        assertGeneratedFilter(
                "id = 6 OR description IS NOT NULL",
                schema,
                Filters.or(
                                Filters.eq("id", new BsonInt32(6)),
                                Filters.ne("description", BsonNull.VALUE))
                        .toBsonDocument());
    }

    @Test
    void testExpressionCannotBePushedDown() {
        ResolvedSchema schema = sourceSchema();
        assertGeneratedFilter("description LIKE '_bcd%'", schema, EMPTY_FILTER);
        assertGeneratedFilter("double_col = decimal_col", schema, EMPTY_FILTER);
        assertGeneratedFilter("boolean_col = (decimal_col > 2.0)", schema, EMPTY_FILTER);
        // Only the translatable conjunct of an AND survives.
        assertGeneratedFilter(
                "id IS NULL AND description LIKE '_bcd%'",
                schema,
                Filters.eq("id", BsonNull.VALUE).toBsonDocument());
        // An OR with an untranslatable branch cannot be pushed down at all.
        assertGeneratedFilter("id IS NOT NULL OR double_col = decimal_col", schema, EMPTY_FILTER);
    }

    /** Returns the resolved schema of the test source table. */
    private ResolvedSchema sourceSchema() {
        return tEnv.sqlQuery("SELECT * FROM " + INPUT_TABLE).getResolvedSchema();
    }

    /** Builds a {@link BsonDecimal128} from the exact decimal string {@code value}. */
    private static BsonDecimal128 decimal(String value) {
        return new BsonDecimal128(new Decimal128(new BigDecimal(value)));
    }

    /**
     * Translates {@code sqlExpression} into a Mongo filter and asserts it equals {@code expected}.
     *
     * <p>Conjuncts that {@code parseFilter} maps to an empty document are dropped. A single
     * remaining filter is used as-is, several are combined with {@code $and}, and none yields
     * {@link #EMPTY_FILTER}.
     *
     * @param sqlExpression SQL predicate over the columns of {@code schema}
     * @param schema schema the predicate is resolved against
     * @param expected expected pushed-down filter document
     */
    private void assertGeneratedFilter(
            String sqlExpression, ResolvedSchema schema, BsonDocument expected) {
        List<Bson> pushedFilters =
                resolveSQLFilterToExpression(sqlExpression, schema).stream()
                        .map(MongoDynamicTableSource::parseFilter)
                        .filter(filter -> !filter.isEmpty())
                        .collect(Collectors.toList());

        BsonDocument actual;
        if (pushedFilters.isEmpty()) {
            actual = EMPTY_FILTER;
        } else if (pushedFilters.size() == 1) {
            actual = pushedFilters.get(0).toBsonDocument();
        } else {
            actual = Filters.and(pushedFilters).toBsonDocument();
        }
        Assertions.assertThat(actual).isEqualTo(expected);
    }

    /**
     * Parses {@code sqlExp} with the table environment's parser, splits its CNF into conjuncts
     * and resolves the convertible ones into {@link ResolvedExpression}s.
     *
     * @param sqlExp SQL predicate to parse
     * @param schema schema whose source row type the predicate is parsed against
     * @return resolved expressions, one per conjunct the planner could convert; conjuncts the
     *     converter rejects are silently skipped
     */
    private List<ResolvedExpression> resolveSQLFilterToExpression(
            String sqlExp, ResolvedSchema schema) {
        // The planner internals used below are only reachable through the implementation types.
        StreamTableEnvironmentImpl tableEnvImpl = (StreamTableEnvironmentImpl) tEnv;
        FlinkContext context = ((PlannerBase) tableEnvImpl.getPlanner()).getFlinkContext();
        CatalogManager catalogManager = tableEnvImpl.getCatalogManager();
        FunctionCatalog functionCatalog = context.getFunctionCatalog();
        RowType sourceType = (RowType) schema.toSourceRowDataType().getLogicalType();

        RexBuilder rexBuilder =
                new RexBuilder(new FlinkTypeFactory(classLoader, FlinkTypeSystem.INSTANCE));
        RexNodeToExpressionConverter converter =
                new RexNodeToExpressionConverter(
                        rexBuilder,
                        sourceType.getFieldNames().toArray(new String[0]),
                        functionCatalog,
                        catalogManager,
                        TimeZone.getTimeZone(tEnv.getConfig().getLocalTimeZone()));

        RexNode parsed =
                ((RexNodeExpression)
                                tableEnvImpl.getParser().parseSqlExpression(sqlExp, sourceType, null))
                        .getRexNode();
        // -1 disables the CNF node-count limit.
        RexNode cnf = FlinkRexUtil.toCnf(rexBuilder, -1, parsed);
        List<RexNode> conjunctions = RelOptUtil.conjunctions(cnf);

        List<Expression> convertible =
                conjunctions.stream()
                        .map(conjunct -> conjunct.accept(converter))
                        .filter(Option::isDefined)
                        .map(Option::get)
                        .collect(Collectors.toList());

        ExpressionResolver resolver =
                ExpressionResolver.resolverFor(
                                tEnv.getConfig(),
                                classLoader,
                                tableName -> Optional.empty(),
                                functionCatalog.asLookup(
                                        name -> {
                                            throw new TableException(
                                                    "We should not need to lookup any expressions at this point");
                                        }),
                                catalogManager.getDataTypeFactory(),
                                (sqlExpression, inputRowType, outputType) -> {
                                    throw new TableException(
                                            "SQL expression parsing is not supported at this location.");
                                })
                        .build();
        return resolver.resolve(convertible);
    }
}
