/*
 * Decompiled with CFR 0.152.
 */
package org.apache.beam.sdk.extensions.sql.impl.rule;

import java.util.Collection;
import java.util.List;
import java.util.stream.Collectors;
import java.util.stream.StreamSupport;
import org.apache.beam.sdk.extensions.sql.impl.BeamSqlEnv;
import org.apache.beam.sdk.extensions.sql.impl.planner.BeamRuleSets;
import org.apache.beam.sdk.extensions.sql.impl.rel.BeamRelNode;
import org.apache.beam.sdk.extensions.sql.impl.rule.BeamJoinAssociateRule;
import org.apache.beam.sdk.extensions.sql.impl.rule.BeamJoinPushThroughJoinRule;
import org.apache.beam.sdk.extensions.sql.impl.rule.ThreeTablesSchema;
import org.apache.beam.sdk.extensions.sql.meta.provider.TableProvider;
import org.apache.beam.sdk.extensions.sql.meta.provider.test.TestTableProvider;
import org.apache.beam.sdk.options.PipelineOptionsFactory;
import org.apache.beam.sdk.schemas.Schema;
import org.apache.beam.sdk.values.Row;
import org.apache.beam.vendor.calcite.v1_28_0.com.google.common.collect.ImmutableList;
import org.apache.beam.vendor.calcite.v1_28_0.org.apache.calcite.adapter.enumerable.EnumerableConvention;
import org.apache.beam.vendor.calcite.v1_28_0.org.apache.calcite.adapter.enumerable.EnumerableRules;
import org.apache.beam.vendor.calcite.v1_28_0.org.apache.calcite.plan.ConventionTraitDef;
import org.apache.beam.vendor.calcite.v1_28_0.org.apache.calcite.plan.RelOptRule;
import org.apache.beam.vendor.calcite.v1_28_0.org.apache.calcite.plan.RelTrait;
import org.apache.beam.vendor.calcite.v1_28_0.org.apache.calcite.plan.RelTraitDef;
import org.apache.beam.vendor.calcite.v1_28_0.org.apache.calcite.plan.RelTraitSet;
import org.apache.beam.vendor.calcite.v1_28_0.org.apache.calcite.rel.RelCollationTraitDef;
import org.apache.beam.vendor.calcite.v1_28_0.org.apache.calcite.rel.RelNode;
import org.apache.beam.vendor.calcite.v1_28_0.org.apache.calcite.rel.RelRoot;
import org.apache.beam.vendor.calcite.v1_28_0.org.apache.calcite.rel.core.Join;
import org.apache.beam.vendor.calcite.v1_28_0.org.apache.calcite.rel.core.TableScan;
import org.apache.beam.vendor.calcite.v1_28_0.org.apache.calcite.rel.rules.CoreRules;
import org.apache.beam.vendor.calcite.v1_28_0.org.apache.calcite.rel.rules.JoinCommuteRule;
import org.apache.beam.vendor.calcite.v1_28_0.org.apache.calcite.schema.SchemaPlus;
import org.apache.beam.vendor.calcite.v1_28_0.org.apache.calcite.sql.SqlNode;
import org.apache.beam.vendor.calcite.v1_28_0.org.apache.calcite.sql.parser.SqlParser;
import org.apache.beam.vendor.calcite.v1_28_0.org.apache.calcite.tools.FrameworkConfig;
import org.apache.beam.vendor.calcite.v1_28_0.org.apache.calcite.tools.Frameworks;
import org.apache.beam.vendor.calcite.v1_28_0.org.apache.calcite.tools.Planner;
import org.apache.beam.vendor.calcite.v1_28_0.org.apache.calcite.tools.Program;
import org.apache.beam.vendor.calcite.v1_28_0.org.apache.calcite.tools.Programs;
import org.apache.beam.vendor.calcite.v1_28_0.org.apache.calcite.tools.RuleSet;
import org.apache.beam.vendor.calcite.v1_28_0.org.apache.calcite.tools.RuleSets;
import org.junit.Assert;
import org.junit.Test;

/**
 * Tests for the join-reordering rules {@link BeamJoinAssociateRule} and
 * {@link BeamJoinPushThroughJoinRule}, and for the join order chosen by {@link BeamSqlEnv}.
 *
 * <p>The {@link BeamSqlEnv}-based tests use three text tables created by
 * {@link #createThreeTables}: {@code small_table} (1 row), {@code medium_table} (3 rows) and
 * {@code large_table} (100 rows). The rule-level tests plan against the Calcite schema
 * {@code ThreeTablesSchema} registered as {@code tt}; its statistics are defined outside this
 * class. Every test checks which table ends up at the "top" of the join tree, as defined by
 * {@link #assertTopTableInJoins}.
 */
public class JoinReorderingTest {
    /** large JOIN medium JOIN small, against the Calcite {@code tt} schema used by {@link #transform}. */
    private static final String TT_LARGE_MEDIUM_SMALL = "select * from \"tt\".\"large_table\" as large_table  JOIN \"tt\".\"medium_table\" as medium_table on large_table.\"medium_key\" = medium_table.\"large_key\"  JOIN \"tt\".\"small_table\" as small_table on medium_table.\"small_key\" = small_table.\"medium_key\" ";
    /** medium JOIN large JOIN small, against the Calcite {@code tt} schema used by {@link #transform}. */
    private static final String TT_MEDIUM_LARGE_SMALL = "select * from \"tt\".\"medium_table\" as medium_table  JOIN \"tt\".\"large_table\" as large_table on large_table.\"medium_key\" = medium_table.\"large_key\"  JOIN \"tt\".\"small_table\" as small_table on medium_table.\"small_key\" = small_table.\"medium_key\" ";
    /** large JOIN medium JOIN small, against the tables created by {@link #createThreeTables}. */
    private static final String LARGE_MEDIUM_SMALL = "select * from large_table  JOIN medium_table on large_table.medium_key = medium_table.large_key  JOIN small_table on medium_table.small_key = small_table.medium_key ";
    /** medium JOIN large JOIN small, against the tables created by {@link #createThreeTables}. */
    private static final String MEDIUM_LARGE_SMALL = "select * from medium_table  JOIN large_table on large_table.medium_key = medium_table.large_key  JOIN small_table on medium_table.small_key = small_table.medium_key ";
    /** medium JOIN small JOIN large, against the tables created by {@link #createThreeTables}. */
    private static final String MEDIUM_SMALL_LARGE = "select * from medium_table  JOIN small_table on medium_table.small_key = small_table.medium_key  JOIN large_table on large_table.medium_key = medium_table.large_key ";
    /** small JOIN medium JOIN large, against the tables created by {@link #createThreeTables}. */
    private static final String SMALL_MEDIUM_LARGE = "select * from small_table  JOIN medium_table on medium_table.small_key = small_table.medium_key  JOIN large_table on large_table.medium_key = medium_table.large_key ";

    @Test
    public void testTableSizes() {
        TestTableProvider tableProvider = new TestTableProvider();
        this.createThreeTables(tableProvider);
        Assert.assertEquals(1.0, rowCount(tableProvider, "small_table"), 0.01);
        Assert.assertEquals(3.0, rowCount(tableProvider, "medium_table"), 0.01);
        Assert.assertEquals(100.0, rowCount(tableProvider, "large_table"), 0.01);
    }

    @Test
    public void testBeamJoinAssociationRule() throws Exception {
        RuleSet prepareRules = prepareRules();
        RelNode originalPlan = this.transform(TT_LARGE_MEDIUM_SMALL, prepareRules);
        RelNode optimizedPlan = this.transform(TT_LARGE_MEDIUM_SMALL, withRule(prepareRules, BeamJoinAssociateRule.INSTANCE));
        this.assertTopTableInJoins(originalPlan, "small_table");
        this.assertTopTableInJoins(optimizedPlan, "large_table");
    }

    @Test
    public void testBeamJoinPushThroughJoinRuleLeft() throws Exception {
        RuleSet prepareRules = prepareRules();
        RelNode originalPlan = this.transform(TT_LARGE_MEDIUM_SMALL, prepareRules);
        RelNode optimizedPlan = this.transform(TT_LARGE_MEDIUM_SMALL, withRule(prepareRules, BeamJoinPushThroughJoinRule.LEFT));
        this.assertTopTableInJoins(originalPlan, "small_table");
        this.assertTopTableInJoins(optimizedPlan, "large_table");
    }

    @Test
    public void testBeamJoinPushThroughJoinRuleRight() throws Exception {
        RuleSet prepareRules = prepareRules();
        RelNode originalPlan = this.transform(TT_MEDIUM_LARGE_SMALL, prepareRules);
        RelNode optimizedPlan = this.transform(TT_MEDIUM_LARGE_SMALL, withRule(prepareRules, BeamJoinPushThroughJoinRule.RIGHT));
        this.assertTopTableInJoins(originalPlan, "small_table");
        this.assertTopTableInJoins(optimizedPlan, "large_table");
    }

    @Test
    public void testSystemReorderingLargeMediumSmall() {
        this.assertTopTableInJoins(this.parseWithDefaultRules(LARGE_MEDIUM_SMALL), "large_table");
    }

    @Test
    public void testSystemReorderingMediumLargeSmall() {
        this.assertTopTableInJoins(this.parseWithDefaultRules(MEDIUM_LARGE_SMALL), "large_table");
    }

    @Test
    public void testSystemNotReorderingWithoutRules() {
        TestTableProvider tableProvider = new TestTableProvider();
        this.createThreeTables(tableProvider);
        // Take every default Beam rule except the ones that can change join order.
        List<RelOptRule> ruleSet = BeamRuleSets.getRuleSets().stream()
                .flatMap(rules -> StreamSupport.stream(rules.spliterator(), false))
                .filter(rule -> !(rule instanceof BeamJoinPushThroughJoinRule)
                        && !(rule instanceof BeamJoinAssociateRule)
                        && !(rule instanceof JoinCommuteRule))
                .collect(Collectors.toList());
        BeamSqlEnv env = BeamSqlEnv.builder(tableProvider)
                .setPipelineOptions(PipelineOptionsFactory.create())
                .setRuleSets(ImmutableList.of(RuleSets.ofList(ruleSet)))
                .build();
        BeamRelNode parsedQuery = env.parseQuery(MEDIUM_LARGE_SMALL);
        this.assertTopTableInJoins(parsedQuery, "small_table");
    }

    @Test
    public void testSystemNotReorderingMediumSmallLarge() {
        this.assertTopTableInJoins(this.parseWithDefaultRules(MEDIUM_SMALL_LARGE), "large_table");
    }

    @Test
    public void testSystemNotReorderingSmallMediumLarge() {
        this.assertTopTableInJoins(this.parseWithDefaultRules(SMALL_MEDIUM_LARGE), "large_table");
    }

    /**
     * Returns the baseline rules used by the rule-level tests.
     *
     * <p>These rules convert the logical plan to the enumerable convention. The set contains no
     * join-reordering rule.
     */
    private static RuleSet prepareRules() {
        return RuleSets.ofList(
                CoreRules.SORT_PROJECT_TRANSPOSE,
                EnumerableRules.ENUMERABLE_JOIN_RULE,
                EnumerableRules.ENUMERABLE_PROJECT_RULE,
                EnumerableRules.ENUMERABLE_SORT_RULE,
                EnumerableRules.ENUMERABLE_TABLE_SCAN_RULE);
    }

    /** Returns a new rule set containing every rule of {@code base} followed by {@code extraRule}. */
    private static RuleSet withRule(RuleSet base, RelOptRule extraRule) {
        return RuleSets.ofList(ImmutableList.<RelOptRule>builder().addAll(base).add(extraRule).build());
    }

    /**
     * Creates the three test tables in a fresh provider and parses {@code sql} with the default
     * {@link BeamSqlEnv} rules.
     */
    private BeamRelNode parseWithDefaultRules(String sql) {
        TestTableProvider tableProvider = new TestTableProvider();
        this.createThreeTables(tableProvider);
        BeamSqlEnv env = BeamSqlEnv.withTableProvider(tableProvider);
        return env.parseQuery(sql);
    }

    /** Returns the row count that the table statistics report for {@code tableName}. */
    private static double rowCount(TestTableProvider tableProvider, String tableName) {
        return tableProvider.buildBeamSqlTable(tableProvider.getTable(tableName)).getTableStatistics(null).getRowCount();
    }

    /**
     * Parses, validates and plans {@code sql} against a root schema containing
     * {@code ThreeTablesSchema} as {@code tt}.
     *
     * @param sql the query to plan
     * @param prepareRules the only rules given to the planner program
     * @return the plan transformed to {@link EnumerableConvention}
     * @throws Exception if parsing, validation or planning fails
     */
    private RelNode transform(String sql, RuleSet prepareRules) throws Exception {
        SchemaPlus rootSchema = Frameworks.createRootSchema(true);
        SchemaPlus defSchema = rootSchema.add("tt", new ThreeTablesSchema());
        FrameworkConfig config = Frameworks.newConfigBuilder()
                .parserConfig(SqlParser.Config.DEFAULT)
                .defaultSchema(defSchema)
                .traitDefs(ConventionTraitDef.INSTANCE, RelCollationTraitDef.INSTANCE)
                .programs(Programs.of(prepareRules))
                .build();
        Planner planner = Frameworks.getPlanner(config);
        SqlNode parse = planner.parse(sql);
        SqlNode validate = planner.validate(parse);
        RelRoot planRoot = planner.rel(validate);
        RelNode planBefore = planRoot.rel;
        RelTraitSet desiredTraits = planBefore.getTraitSet().replace(EnumerableConvention.INSTANCE);
        // Program index 0 is the single program configured above.
        return planner.transform(0, desiredTraits, planBefore);
    }

    /**
     * Asserts that the "top" table of the join tree is {@code expectedTableName}.
     *
     * <p>The first {@link Join} is found by following each node's first input from the root. If
     * that join's right input (again following first inputs) reaches a {@link TableScan} before a
     * {@link Join}, that scan is the top table. Otherwise the top table is the left-most
     * {@link TableScan} under the join's left input. The check is a substring match on the
     * scan's description.
     */
    private void assertTopTableInJoins(RelNode parsedQuery, String expectedTableName) {
        RelNode firstJoin = parsedQuery;
        while (!(firstJoin instanceof Join)) {
            firstJoin = descend(firstJoin, "no Join found in plan");
        }
        Join join = (Join) firstJoin;
        RelNode topRight = join.getRight();
        while (!(topRight instanceof Join) && !(topRight instanceof TableScan)) {
            topRight = descend(topRight, "no Join or TableScan under the right join input");
        }
        RelNode topTable;
        if (topRight instanceof TableScan) {
            topTable = topRight;
        } else {
            topTable = join.getLeft();
            while (!(topTable instanceof TableScan)) {
                topTable = descend(topTable, "no TableScan under the left join input");
            }
        }
        String description = topTable.getDescription();
        Assert.assertTrue(
                "expected top table " + expectedTableName + " but found " + description,
                description.contains(expectedTableName));
    }

    /**
     * Returns the first input of {@code node}.
     *
     * <p>If {@code node} is a leaf, this fails with a descriptive {@link AssertionError} instead
     * of an index exception.
     */
    private static RelNode descend(RelNode node, String context) {
        if (node.getInputs().isEmpty()) {
            throw new AssertionError(context + ": reached leaf " + node.getDescription());
        }
        return node.getInput(0);
    }

    /**
     * Creates {@code small_table} (1 row), {@code medium_table} (3 rows) and {@code large_table}
     * (100 rows) in {@code tableProvider}.
     */
    private void createThreeTables(TestTableProvider tableProvider) {
        BeamSqlEnv env = BeamSqlEnv.withTableProvider(tableProvider);
        env.executeDdl("CREATE EXTERNAL TABLE small_table (id INTEGER, medium_key INTEGER) TYPE text");
        env.executeDdl("CREATE EXTERNAL TABLE medium_table (id INTEGER,small_key INTEGER,large_key INTEGER) TYPE text");
        env.executeDdl("CREATE EXTERNAL TABLE large_table (id INTEGER,medium_key INTEGER) TYPE text");

        Schema smallSchema = tableProvider.getTable("small_table").getSchema();
        tableProvider.addRows("small_table", Row.withSchema(smallSchema).addValues(1, 1).build());

        // Schemas are fixed after the DDL above, so look them up once rather than per row.
        Schema mediumSchema = tableProvider.getTable("medium_table").getSchema();
        for (int id = 0; id < 3; ++id) {
            tableProvider.addRows("medium_table", Row.withSchema(mediumSchema).addValues(id, 1, 2).build());
        }

        Schema largeSchema = tableProvider.getTable("large_table").getSchema();
        for (int id = 0; id < 100; ++id) {
            tableProvider.addRows("large_table", Row.withSchema(largeSchema).addValues(id, 2).build());
        }
    }
}

