package org.apache.beam.sdk.extensions.sql.zetasql;

import com.alibaba.fastjson.JSON;
import org.apache.beam.sdk.extensions.sql.impl.BeamSqlEnv;
import org.apache.beam.sdk.extensions.sql.impl.JdbcConnection;
import org.apache.beam.sdk.extensions.sql.impl.JdbcDriver;
import org.apache.beam.sdk.extensions.sql.impl.planner.BeamCostModel;
import org.apache.beam.sdk.extensions.sql.impl.rel.BeamIOSourceRel;
import org.apache.beam.sdk.extensions.sql.impl.rel.BeamRelNode;
import org.apache.beam.sdk.extensions.sql.meta.Table;
import org.apache.beam.sdk.extensions.sql.meta.provider.test.TestTableProvider;
import org.apache.beam.sdk.options.PipelineOptionsFactory;
import org.apache.beam.sdk.schemas.Schema;
import org.apache.beam.sdk.testing.TestPipeline;
import org.apache.beam.sdk.values.Row;
import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.plan.Context;
import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.plan.Contexts;
import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.plan.ConventionTraitDef;
import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.schema.SchemaPlus;
import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.tools.FrameworkConfig;
import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.tools.Frameworks;
import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.tools.RuleSet;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableList;
import org.hamcrest.MatcherAssert;
import org.hamcrest.Matchers;
import org.joda.time.Duration;
import org.junit.Assert;
import org.junit.BeforeClass;
import org.junit.Rule;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;

/**
 * Compares the plans produced by {@link ZetaSQLQueryPlanner} with the plans produced by the
 * {@link BeamSqlEnv} built in {@link #setUp()} for queries against two in-memory test tables.
 *
 * <p>Two tables sharing {@link #BASIC_SCHEMA} are registered: {@code InMemoryTableProject}
 * (created with {@code PushDownOptions.PROJECT}) and {@code InMemoryTableBoth} (created with
 * {@code PushDownOptions.BOTH}). Every test plans the same SQL text with both planners and
 * asserts that a {@link BeamIOSourceRel} sits either at the root of the plan or directly below
 * it, and that the digests of those source nodes are equal.
 */
@RunWith(JUnit4.class)
public class ZetaSQLPushDownTest {
    /** Upper bound, in minutes, passed to {@code waitUntilFinish} after each test. */
    private static final long PIPELINE_EXECUTION_WAITTIME_MINUTES = 2L;

    /** Schema shared by both test tables: two columns the queries use plus two "unused" ones. */
    private static final Schema BASIC_SCHEMA =
            Schema.builder()
                    .addInt64Field("unused1")
                    .addInt64Field("id")
                    .addStringField("name")
                    .addInt64Field("unused2")
                    .build();

    // Shared fixtures, populated once in setUp().
    private static TestTableProvider tableProvider;
    private static FrameworkConfig config;
    private static ZetaSQLQueryPlanner zetaSQLQueryPlanner;
    private static BeamSqlEnv sqlEnv;

    @Rule
    public transient TestPipeline pipeline = TestPipeline.create();

    /**
     * Builds the shared fixtures. The table provider must exist before the Calcite config is
     * created, and the config must exist before the ZetaSQL planner is constructed from it.
     */
    @BeforeClass
    public static void setUp() {
        initializeBeamTableProvider();
        initializeCalciteEnvironment();
        zetaSQLQueryPlanner = new ZetaSQLQueryPlanner(config);
        sqlEnv = BeamSqlEnv.builder(tableProvider)
                .setPipelineOptions(PipelineOptionsFactory.create())
                .build();
    }

    /** Plain column selection on the project-only table: the IO source is the plan root. */
    @Test
    public void testProjectPushDown_withoutPredicate() {
        // One literal feeds both planners so the two plans are always built from identical SQL.
        String sql = "SELECT name, id, unused1 FROM InMemoryTableProject";

        BeamRelNode zetaSqlNode = zetaSQLQueryPlanner.convertToBeamRel(sql);
        BeamRelNode envNode = sqlEnv.parseQuery(sql);

        MatcherAssert.assertThat(zetaSqlNode, Matchers.instanceOf(BeamIOSourceRel.class));
        MatcherAssert.assertThat(envNode, Matchers.instanceOf(BeamIOSourceRel.class));
        Assert.assertEquals(envNode.getDigest(), zetaSqlNode.getDigest());

        runPipeline();
    }

    /** Computed expression on the project-only table: the IO source is the root's first input. */
    @Test
    public void testProjectPushDown_withoutPredicate_withComplexSelect() {
        String sql = "SELECT id+1 FROM InMemoryTableProject";

        BeamRelNode zetaSqlNode = zetaSQLQueryPlanner.convertToBeamRel(sql);
        BeamRelNode envNode = sqlEnv.parseQuery(sql);

        MatcherAssert.assertThat(zetaSqlNode.getInput(0), Matchers.instanceOf(BeamIOSourceRel.class));
        MatcherAssert.assertThat(envNode.getInput(0), Matchers.instanceOf(BeamIOSourceRel.class));
        Assert.assertEquals(envNode.getInput(0).getDigest(), zetaSqlNode.getInput(0).getDigest());

        runPipeline();
    }

    /** Predicate on the project-only table: the IO source is the root's first input. */
    @Test
    public void testProjectPushDown_withPredicate() {
        String sql = "SELECT name FROM InMemoryTableProject where id=2";

        BeamRelNode zetaSqlNode = zetaSQLQueryPlanner.convertToBeamRel(sql);
        BeamRelNode envNode = sqlEnv.parseQuery(sql);

        MatcherAssert.assertThat(zetaSqlNode.getInput(0), Matchers.instanceOf(BeamIOSourceRel.class));
        MatcherAssert.assertThat(envNode.getInput(0), Matchers.instanceOf(BeamIOSourceRel.class));
        Assert.assertEquals(envNode.getInput(0).getDigest(), zetaSqlNode.getInput(0).getDigest());

        runPipeline();
    }

    /** Plain column selection on the project+filter table: the IO source is the plan root. */
    @Test
    public void testProjectFilterPushDown_withoutPredicate() {
        String sql = "SELECT name, id, unused1 FROM InMemoryTableBoth";

        BeamRelNode zetaSqlNode = zetaSQLQueryPlanner.convertToBeamRel(sql);
        BeamRelNode envNode = sqlEnv.parseQuery(sql);

        MatcherAssert.assertThat(zetaSqlNode, Matchers.instanceOf(BeamIOSourceRel.class));
        MatcherAssert.assertThat(envNode, Matchers.instanceOf(BeamIOSourceRel.class));
        Assert.assertEquals(envNode.getDigest(), zetaSqlNode.getDigest());

        runPipeline();
    }

    /** {@code id=2} predicate on the project+filter table: the IO source is the plan root. */
    @Test
    public void testProjectFilterPushDown_withSupportedPredicate() {
        String sql = "SELECT name FROM InMemoryTableBoth where id=2";

        BeamRelNode zetaSqlNode = zetaSQLQueryPlanner.convertToBeamRel(sql);
        BeamRelNode envNode = sqlEnv.parseQuery(sql);

        MatcherAssert.assertThat(zetaSqlNode, Matchers.instanceOf(BeamIOSourceRel.class));
        MatcherAssert.assertThat(envNode, Matchers.instanceOf(BeamIOSourceRel.class));
        Assert.assertEquals(envNode.getDigest(), zetaSqlNode.getDigest());

        runPipeline();
    }

    /**
     * {@code OR} predicate on the project+filter table: the IO source is the root's first input.
     * NOTE(review): per the test name this predicate is presumably one the test table cannot
     * evaluate itself, which would explain the extra node above the source - confirm against
     * TestTableProvider.
     */
    @Test
    public void testProjectFilterPushDown_withUnsupportedPredicate() {
        String sql = "SELECT name FROM InMemoryTableBoth where id=2 or unused1=200";

        BeamRelNode zetaSqlNode = zetaSQLQueryPlanner.convertToBeamRel(sql);
        BeamRelNode envNode = sqlEnv.parseQuery(sql);

        MatcherAssert.assertThat(zetaSqlNode.getInput(0), Matchers.instanceOf(BeamIOSourceRel.class));
        MatcherAssert.assertThat(envNode.getInput(0), Matchers.instanceOf(BeamIOSourceRel.class));
        Assert.assertEquals(envNode.getInput(0).getDigest(), zetaSqlNode.getInput(0).getDigest());

        runPipeline();
    }

    /** Runs the rule-managed pipeline and waits up to the configured number of minutes. */
    private void runPipeline() {
        pipeline.run().waitUntilFinish(Duration.standardMinutes(PIPELINE_EXECUTION_WAITTIME_MINUTES));
    }

    /** Builds {@link #config} without any additional planner contexts. */
    private static void initializeCalciteEnvironment() {
        initializeCalciteEnvironmentWithContext(new Context[0]);
    }

    /**
     * Builds {@link #config} from a JDBC connection over {@link #tableProvider}.
     *
     * @param extraContexts contexts appended after the context derived from the connection config
     */
    private static void initializeCalciteEnvironmentWithContext(Context... extraContexts) {
        JdbcConnection jdbcConnection =
                JdbcDriver.connect(tableProvider, PipelineOptionsFactory.create());
        SchemaPlus defaultSchema = jdbcConnection.getCurrentSchemaPlus();

        // Connection-derived context first, then the caller's extras, in the given order.
        // The builder is typed explicitly so add(extraContexts) is unambiguously the varargs
        // overload that appends each element rather than the array itself.
        Object[] contexts =
                ImmutableList.<Context>builder()
                        .add(Contexts.of(jdbcConnection.config()))
                        .add(extraContexts)
                        .build()
                        .toArray();

        config = Frameworks.newConfigBuilder()
                .defaultSchema(defaultSchema)
                .traitDefs(ImmutableList.of(ConventionTraitDef.INSTANCE))
                .context(Contexts.of(contexts))
                .ruleSets((RuleSet[]) ZetaSQLQueryPlanner.getZetaSqlRuleSets().toArray(new RuleSet[0]))
                .costFactory(BeamCostModel.FACTORY)
                .typeSystem(jdbcConnection.getTypeFactory().getTypeSystem())
                .build();
    }

    /**
     * Creates {@link #tableProvider}, registers the two test tables and loads the same two rows
     * into each of them.
     */
    private static void initializeBeamTableProvider() {
        Table projectTable = getTable("InMemoryTableProject", TestTableProvider.PushDownOptions.PROJECT);
        Table bothTable = getTable("InMemoryTableBoth", TestTableProvider.PushDownOptions.BOTH);
        // Values follow BASIC_SCHEMA order: unused1, id, name, unused2.
        Row[] rows = {
            row(BASIC_SCHEMA, 100L, 1L, "one", 100L),
            row(BASIC_SCHEMA, 200L, 2L, "two", 200L)
        };

        tableProvider = new TestTableProvider();
        tableProvider.createTable(projectTable);
        tableProvider.createTable(bothTable);
        tableProvider.addRows(projectTable.getName(), rows);
        tableProvider.addRows(bothTable.getName(), rows);
    }

    /**
     * Builds a {@link Row} for the given schema.
     *
     * @param schema schema the row is created with
     * @param values field values handed to the row builder
     * @return the built row
     */
    private static Row row(Schema schema, Object... values) {
        return Row.withSchema(schema).addValues(values).build();
    }

    /**
     * Builds a table definition of type {@code "test"} over {@link #BASIC_SCHEMA}.
     *
     * @param tableName name of the table; also used as the prefix of its comment
     * @param pushDownOptions value written into the table's {@code push_down} property
     * @return the table definition
     */
    private static Table getTable(String tableName, TestTableProvider.PushDownOptions pushDownOptions) {
        // The property key is left unquoted exactly as before; this presumably relies on
        // fastjson's lenient parsing of unquoted field names - verify before switching parsers.
        String propertiesJson = "{ push_down: \"" + pushDownOptions.toString() + "\" }";
        return Table.builder()
                .name(tableName)
                .comment(tableName + " table")
                .schema(BASIC_SCHEMA)
                .properties(JSON.parseObject(propertiesJson))
                .type("test")
                .build();
    }
}
