package org.apache.beam.sdk.extensions.sql.impl.rule;

import java.util.List;
import java.util.stream.Collectors;
import java.util.stream.StreamSupport;
import org.apache.beam.sdk.extensions.sql.impl.BeamSqlEnv;
import org.apache.beam.sdk.extensions.sql.impl.planner.BeamRuleSets;
import org.apache.beam.sdk.extensions.sql.meta.provider.test.TestTableProvider;
import org.apache.beam.sdk.options.PipelineOptions;
import org.apache.beam.sdk.options.PipelineOptionsFactory;
import org.apache.beam.sdk.values.Row;
import org.apache.beam.vendor.calcite.v1_28_0.org.apache.calcite.adapter.enumerable.EnumerableConvention;
import org.apache.beam.vendor.calcite.v1_28_0.org.apache.calcite.adapter.enumerable.EnumerableRules;
import org.apache.beam.vendor.calcite.v1_28_0.org.apache.calcite.plan.ConventionTraitDef;
import org.apache.beam.vendor.calcite.v1_28_0.org.apache.calcite.plan.RelOptRule;
import org.apache.beam.vendor.calcite.v1_28_0.org.apache.calcite.plan.RelTraitDef;
import org.apache.beam.vendor.calcite.v1_28_0.org.apache.calcite.rel.RelCollationTraitDef;
import org.apache.beam.vendor.calcite.v1_28_0.org.apache.calcite.rel.RelNode;
import org.apache.beam.vendor.calcite.v1_28_0.org.apache.calcite.rel.core.Join;
import org.apache.beam.vendor.calcite.v1_28_0.org.apache.calcite.rel.core.TableScan;
import org.apache.beam.vendor.calcite.v1_28_0.org.apache.calcite.rel.rules.CoreRules;
import org.apache.beam.vendor.calcite.v1_28_0.org.apache.calcite.rel.rules.JoinCommuteRule;
import org.apache.beam.vendor.calcite.v1_28_0.org.apache.calcite.sql.parser.SqlParser;
import org.apache.beam.vendor.calcite.v1_28_0.org.apache.calcite.tools.Frameworks;
import org.apache.beam.vendor.calcite.v1_28_0.org.apache.calcite.tools.Planner;
import org.apache.beam.vendor.calcite.v1_28_0.org.apache.calcite.tools.Program;
import org.apache.beam.vendor.calcite.v1_28_0.org.apache.calcite.tools.Programs;
import org.apache.beam.vendor.calcite.v1_28_0.org.apache.calcite.tools.RuleSet;
import org.apache.beam.vendor.calcite.v1_28_0.org.apache.calcite.tools.RuleSets;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableList;
import org.junit.Assert;
import org.junit.Test;

/* loaded from: input_file:org/apache/beam/sdk/extensions/sql/impl/rule/JoinReorderingTest.class */
/**
 * Tests for the join-reordering planner rules ({@link BeamJoinAssociateRule} and
 * {@link BeamJoinPushThroughJoinRule}).
 *
 * <p>Every test works on three tables of different sizes (1, 3 and 100 rows) and inspects which
 * table scan ends up on the "top" side of the first join found in the plan; see
 * {@link #assertTopTableInJoins(RelNode, String)} for the exact inspection performed.
 */
public class JoinReorderingTest {
    private static final String SMALL_TABLE = "small_table";
    private static final String MEDIUM_TABLE = "medium_table";
    private static final String LARGE_TABLE = "large_table";

    /** Tolerance used when comparing the (floating point) row-count estimates. */
    private static final double ROW_COUNT_DELTA = 0.01d;

    /** Calcite-side query over the {@code "tt"} schema, joining large, then medium, then small. */
    private static final String QUOTED_LARGE_MEDIUM_SMALL =
            "select * from \"tt\".\"large_table\" as large_table "
                    + " JOIN \"tt\".\"medium_table\" as medium_table on large_table.\"medium_key\" = medium_table.\"large_key\" "
                    + " JOIN \"tt\".\"small_table\" as small_table on medium_table.\"small_key\" = small_table.\"medium_key\" ";

    /** Calcite-side query over the {@code "tt"} schema, joining medium, then large, then small. */
    private static final String QUOTED_MEDIUM_LARGE_SMALL =
            "select * from \"tt\".\"medium_table\" as medium_table "
                    + " JOIN \"tt\".\"large_table\" as large_table on large_table.\"medium_key\" = medium_table.\"large_key\" "
                    + " JOIN \"tt\".\"small_table\" as small_table on medium_table.\"small_key\" = small_table.\"medium_key\" ";

    /** Beam-side query joining medium, then large, then small. */
    private static final String MEDIUM_LARGE_SMALL =
            "select * from medium_table "
                    + " JOIN large_table on large_table.medium_key = medium_table.large_key "
                    + " JOIN small_table on medium_table.small_key = small_table.medium_key ";

    /** Checks that the row-count statistics of the three test tables are 1, 3 and 100. */
    @Test
    public void testTableSizes() {
        TestTableProvider tableProvider = new TestTableProvider();
        createThreeTables(tableProvider);

        Assert.assertEquals(1.0d, rowCount(tableProvider, SMALL_TABLE), ROW_COUNT_DELTA);
        Assert.assertEquals(3.0d, rowCount(tableProvider, MEDIUM_TABLE), ROW_COUNT_DELTA);
        Assert.assertEquals(100.0d, rowCount(tableProvider, LARGE_TABLE), ROW_COUNT_DELTA);
    }

    /** Adding {@link BeamJoinAssociateRule#INSTANCE} changes the top table from small to large. */
    @Test
    public void testBeamJoinAssociationRule() throws Exception {
        assertRuleChangesTopTable(QUOTED_LARGE_MEDIUM_SMALL, BeamJoinAssociateRule.INSTANCE);
    }

    /** Adding {@link BeamJoinPushThroughJoinRule#LEFT} changes the top table from small to large. */
    @Test
    public void testBeamJoinPushThroughJoinRuleLeft() throws Exception {
        assertRuleChangesTopTable(QUOTED_LARGE_MEDIUM_SMALL, BeamJoinPushThroughJoinRule.LEFT);
    }

    /** Adding {@link BeamJoinPushThroughJoinRule#RIGHT} changes the top table from small to large. */
    @Test
    public void testBeamJoinPushThroughJoinRuleRight() throws Exception {
        assertRuleChangesTopTable(QUOTED_MEDIUM_LARGE_SMALL, BeamJoinPushThroughJoinRule.RIGHT);
    }

    @Test
    public void testSystemReorderingLargeMediumSmall() {
        RelNode plan =
                planWithDefaultRules(
                        "select * from large_table "
                                + " JOIN medium_table on large_table.medium_key = medium_table.large_key "
                                + " JOIN small_table on medium_table.small_key = small_table.medium_key ");
        assertTopTableInJoins(plan, LARGE_TABLE);
    }

    @Test
    public void testSystemReorderingMediumLargeSmall() {
        assertTopTableInJoins(planWithDefaultRules(MEDIUM_LARGE_SMALL), LARGE_TABLE);
    }

    /**
     * With the push-through, associate and commute join rules filtered out of the Beam rule sets,
     * the small table is expected to stay on top.
     */
    @Test
    public void testSystemNotReorderingWithoutRules() {
        TestTableProvider tableProvider = new TestTableProvider();
        createThreeTables(tableProvider);

        // All Beam rules, flattened into one list, minus the rules that can reorder joins.
        List<RelOptRule> rulesWithoutReordering =
                BeamRuleSets.getRuleSets().stream()
                        .flatMap(ruleSet -> StreamSupport.stream(ruleSet.spliterator(), false))
                        .filter(rule -> !(rule instanceof BeamJoinPushThroughJoinRule))
                        .filter(rule -> !(rule instanceof BeamJoinAssociateRule))
                        .filter(rule -> !(rule instanceof JoinCommuteRule))
                        .collect(Collectors.toList());

        BeamSqlEnv env =
                BeamSqlEnv.builder(tableProvider)
                        .setPipelineOptions(PipelineOptionsFactory.create())
                        .setRuleSets(ImmutableList.of(RuleSets.ofList(rulesWithoutReordering)))
                        .build();

        assertTopTableInJoins(env.parseQuery(MEDIUM_LARGE_SMALL), SMALL_TABLE);
    }

    @Test
    public void testSystemNotReorderingMediumSmallLarge() {
        RelNode plan =
                planWithDefaultRules(
                        "select * from medium_table "
                                + " JOIN small_table on medium_table.small_key = small_table.medium_key "
                                + " JOIN large_table on large_table.medium_key = medium_table.large_key ");
        assertTopTableInJoins(plan, LARGE_TABLE);
    }

    @Test
    public void testSystemNotReorderingSmallMediumLarge() {
        RelNode plan =
                planWithDefaultRules(
                        "select * from small_table "
                                + " JOIN medium_table on medium_table.small_key = small_table.medium_key "
                                + " JOIN large_table on large_table.medium_key = medium_table.large_key ");
        assertTopTableInJoins(plan, LARGE_TABLE);
    }

    /**
     * Plans {@code sql} twice, once with a base set of enumerable rules and once with
     * {@code reorderingRule} added to that set, and asserts that the top table is
     * {@code small_table} in the first plan and {@code large_table} in the second.
     *
     * @param sql query against the {@code "tt"} schema
     * @param reorderingRule the join-reordering rule under test
     * @throws Exception if parsing, validation or planning fails
     */
    private void assertRuleChangesTopTable(String sql, RelOptRule reorderingRule) throws Exception {
        RuleSet baseRules =
                RuleSets.ofList(
                        CoreRules.SORT_PROJECT_TRANSPOSE,
                        EnumerableRules.ENUMERABLE_JOIN_RULE,
                        EnumerableRules.ENUMERABLE_PROJECT_RULE,
                        EnumerableRules.ENUMERABLE_SORT_RULE,
                        EnumerableRules.ENUMERABLE_TABLE_SCAN_RULE);
        // The explicit <RelOptRule> witness is required: without it the builder is inferred as
        // Builder<Object> and the resulting list is not accepted by RuleSets.ofList.
        RuleSet rulesWithReordering =
                RuleSets.ofList(
                        ImmutableList.<RelOptRule>builder()
                                .addAll(baseRules)
                                .add(reorderingRule)
                                .build());

        RelNode originalPlan = transform(sql, baseRules);
        RelNode optimizedPlan = transform(sql, rulesWithReordering);

        assertTopTableInJoins(originalPlan, SMALL_TABLE);
        assertTopTableInJoins(optimizedPlan, LARGE_TABLE);
    }

    /**
     * Creates a fresh provider holding the three test tables and parses {@code sql} with a
     * {@link BeamSqlEnv} built directly from that provider (no custom rule sets).
     */
    private RelNode planWithDefaultRules(String sql) {
        TestTableProvider tableProvider = new TestTableProvider();
        createThreeTables(tableProvider);
        return BeamSqlEnv.withTableProvider(tableProvider).parseQuery(sql);
    }

    /** Returns the row-count statistic reported for {@code tableName}. */
    private static double rowCount(TestTableProvider tableProvider, String tableName) {
        return tableProvider
                .buildBeamSqlTable(tableProvider.getTable(tableName))
                // The cast only selects the overload; no pipeline options are supplied.
                .getTableStatistics((PipelineOptions) null)
                .getRowCount()
                .doubleValue();
    }

    /**
     * Parses, validates and converts {@code str} against a root schema that exposes a
     * {@code ThreeTablesSchema} under the name {@code "tt"}, then runs the single planner program
     * built from {@code ruleSet}, requesting the enumerable convention.
     *
     * @param str the SQL text to plan
     * @param ruleSet the rules the planner program is built from
     * @return the transformed relational plan
     * @throws Exception if parsing, validation or planning fails
     */
    private RelNode transform(String str, RuleSet ruleSet) throws Exception {
        Planner planner =
                Frameworks.getPlanner(
                        Frameworks.newConfigBuilder()
                                .parserConfig(SqlParser.Config.DEFAULT)
                                .defaultSchema(
                                        Frameworks.createRootSchema(true)
                                                .add("tt", new ThreeTablesSchema()))
                                .traitDefs(
                                        new RelTraitDef[] {
                                            ConventionTraitDef.INSTANCE, RelCollationTraitDef.INSTANCE
                                        })
                                .programs(new Program[] {Programs.of(ruleSet)})
                                .build());

        RelNode logicalPlan = planner.rel(planner.validate(planner.parse(str))).rel;
        return planner.transform(
                0, logicalPlan.getTraitSet().replace(EnumerableConvention.INSTANCE), logicalPlan);
    }

    /**
     * Asserts that the "top" table scan of the plan mentions {@code str} in its description.
     *
     * <p>The plan is walked down through first inputs until the first {@link Join} is reached. If
     * the right side of that join leads (again through first inputs) to a {@link TableScan} before
     * any other join, that scan is the top table. Otherwise the top table is the first scan reached
     * by following first inputs from the join's left side.
     *
     * @param relNode root of the plan to inspect; expected to contain at least one join
     * @param str text that the top table scan's description must contain
     */
    private void assertTopTableInJoins(RelNode relNode, String str) {
        RelNode firstJoin = relNode;
        while (!(firstJoin instanceof Join)) {
            firstJoin = firstJoin.getInput(0);
        }

        RelNode topRight = ((Join) firstJoin).getRight();
        while (!(topRight instanceof Join) && !(topRight instanceof TableScan)) {
            topRight = topRight.getInput(0);
        }

        RelNode topTable;
        if (topRight instanceof TableScan) {
            topTable = topRight;
        } else {
            // The right side is itself a join, so look for the top table on the left side.
            topTable = ((Join) firstJoin).getLeft();
            while (!(topTable instanceof TableScan)) {
                topTable = topTable.getInput(0);
            }
        }

        String description = topTable.getDescription();
        Assert.assertTrue(
                "Expected top table scan to reference '" + str + "' but was: " + description,
                description.contains(str));
    }

    /**
     * Registers {@code small_table}, {@code medium_table} and {@code large_table} with the given
     * provider and fills them with 1, 3 and 100 rows respectively.
     */
    private void createThreeTables(TestTableProvider testTableProvider) {
        BeamSqlEnv env = BeamSqlEnv.withTableProvider(testTableProvider);
        env.executeDdl("CREATE EXTERNAL TABLE small_table (id INTEGER, medium_key INTEGER) TYPE text");
        env.executeDdl("CREATE EXTERNAL TABLE medium_table (id INTEGER,small_key INTEGER,large_key INTEGER) TYPE text");
        env.executeDdl("CREATE EXTERNAL TABLE large_table (id INTEGER,medium_key INTEGER) TYPE text");

        testTableProvider.addRows(
                SMALL_TABLE,
                Row.withSchema(testTableProvider.getTable(SMALL_TABLE).getSchema())
                        .addValues(1, 1)
                        .build());

        for (int id = 0; id < 3; id++) {
            testTableProvider.addRows(
                    MEDIUM_TABLE,
                    Row.withSchema(testTableProvider.getTable(MEDIUM_TABLE).getSchema())
                            .addValues(id, 1, 2)
                            .build());
        }

        for (int id = 0; id < 100; id++) {
            testTableProvider.addRows(
                    LARGE_TABLE,
                    Row.withSchema(testTableProvider.getTable(LARGE_TABLE).getSchema())
                            .addValues(id, 2)
                            .build());
        }
    }
}
