/*
 * Decompiled with CFR 0.152.
 */
package io.kyligence.kap.query.optrule;

import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.stream.Collectors;
import org.apache.calcite.plan.RelOptRule;
import org.apache.calcite.plan.RelOptRuleCall;
import org.apache.calcite.plan.RelOptRuleOperand;
import org.apache.calcite.plan.RelOptRuleOperandChildren;
import org.apache.calcite.plan.RelOptUtil;
import org.apache.calcite.rel.AbstractRelNode;
import org.apache.calcite.rel.RelNode;
import org.apache.calcite.rel.core.Aggregate;
import org.apache.calcite.rel.core.EquiJoin;
import org.apache.calcite.rel.core.Filter;
import org.apache.calcite.rel.core.Join;
import org.apache.calcite.rel.core.JoinRelType;
import org.apache.calcite.rel.core.RelFactories;
import org.apache.calcite.rel.type.RelDataType;
import org.apache.calcite.rel.type.RelDataTypeField;
import org.apache.calcite.rex.RexBuilder;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.rex.RexPermuteInputsShuttle;
import org.apache.calcite.rex.RexUtil;
import org.apache.calcite.rex.RexVisitor;
import org.apache.calcite.sql.SqlKind;
import org.apache.calcite.tools.RelBuilder;
import org.apache.calcite.tools.RelBuilderFactory;
import org.apache.calcite.util.mapping.Mappings;
import org.apache.kylin.guava30.shaded.common.collect.ImmutableList;
import org.apache.kylin.guava30.shaded.common.collect.Lists;
import org.apache.kylin.guava30.shaded.common.collect.Sets;
import org.apache.kylin.query.util.RexUtils;

public class KapFilterJoinRule
extends RelOptRule {
    public static final KapFilterJoinRule KAP_FILTER_ON_JOIN_JOIN = new KapFilterJoinRule(KapFilterJoinRule.operand(Filter.class, (RelOptRuleOperand)KapFilterJoinRule.operand(Join.class, (RelOptRuleOperand)KapFilterJoinRule.operand(Join.class, (RelOptRuleOperand)KapFilterJoinRule.operand(RelNode.class, (RelOptRuleOperandChildren)RelOptRule.any()), (RelOptRuleOperand[])new RelOptRuleOperand[]{KapFilterJoinRule.operand(RelNode.class, (RelOptRuleOperandChildren)RelOptRule.any())}), (RelOptRuleOperand[])new RelOptRuleOperand[]{KapFilterJoinRule.operand(RelNode.class, null, input -> !(input instanceof Join), (RelOptRuleOperandChildren)RelOptRule.any())}), (RelOptRuleOperand[])new RelOptRuleOperand[0]), RelFactories.LOGICAL_BUILDER, true, "KapFilterJoinRule:filter-join-join");
    public static final KapFilterJoinRule KAP_FILTER_ON_JOIN_SCAN = new KapFilterJoinRule(KapFilterJoinRule.operand(Filter.class, (RelOptRuleOperand)KapFilterJoinRule.operand(Join.class, (RelOptRuleOperand)KapFilterJoinRule.operand(RelNode.class, null, input -> !(input instanceof Join), (RelOptRuleOperandChildren)RelOptRule.any()), (RelOptRuleOperand[])new RelOptRuleOperand[]{KapFilterJoinRule.operand(RelNode.class, null, input -> !(input instanceof Join), (RelOptRuleOperandChildren)RelOptRule.any())}), (RelOptRuleOperand[])new RelOptRuleOperand[0]), RelFactories.LOGICAL_BUILDER, false, "KapFilterJoinRule:filter-join-scan");
    private boolean needTranspose;

    private KapFilterJoinRule(RelOptRuleOperand relOptRuleOperand, RelBuilderFactory relBuilderFactory, boolean needTranspose, String discription) {
        super(relOptRuleOperand, relBuilderFactory, discription);
        this.needTranspose = needTranspose;
    }

    public void onMatch(RelOptRuleCall call) {
        RuleMatchHandler handler = new RuleMatchHandler(call);
        handler.perform();
    }

    private class RuleMatchHandler {
        private Filter filterRel;
        private Join topJoinRel;
        private AbstractRelNode bottomJoin;
        private RelNode relA;
        private RelNode relB;
        private RelNode relC;
        private RelBuilder relBuilder;
        private RelOptRuleCall call;

        public RuleMatchHandler(RelOptRuleCall call) {
            this.call = call;
            this.filterRel = (Filter)call.rel(0);
            this.topJoinRel = (Join)call.rel(1);
            this.bottomJoin = (AbstractRelNode)call.rel(2);
            this.relC = call.rel(3);
            this.relB = null;
            this.relA = null;
            if (call.rels.length > 4) {
                this.relB = call.rel(4);
                this.relA = call.rel(5);
            }
            this.relBuilder = call.builder();
        }

        protected void perform() {
            List joinFilters = RelOptUtil.conjunctions((RexNode)this.topJoinRel.getCondition());
            if (!joinFilters.isEmpty() || this.filterRel == null) {
                return;
            }
            List aboveFilters = RelOptUtil.conjunctions((RexNode)this.filterRel.getCondition());
            aboveFilters = aboveFilters.stream().map(RexUtils::stripOffCastInColumnEqualPredicate).collect(Collectors.toList());
            ImmutableList origAboveFilters = ImmutableList.copyOf(aboveFilters);
            JoinRelType joinType = this.topJoinRel.getJoinType();
            ArrayList leftFilters = new ArrayList();
            ArrayList rightFilters = new ArrayList();
            boolean filterPushed = this.pushDownFilter(aboveFilters, leftFilters, rightFilters, joinFilters);
            if (!filterPushed && joinType == this.topJoinRel.getJoinType() || joinFilters.isEmpty() && leftFilters.isEmpty() && rightFilters.isEmpty()) {
                return;
            }
            boolean isNeedProject = false;
            if (KapFilterJoinRule.this.needTranspose && this.bottomJoin instanceof Join && RelOptUtil.conjunctions((RexNode)((Join)this.bottomJoin).getCondition()).isEmpty() && joinFilters.size() > 1 && !(this.relA instanceof Aggregate)) {
                int originFilterSize = joinFilters.size();
                Join originTopJoin = this.topJoinRel.copy(this.topJoinRel.getTraitSet(), this.topJoinRel.getInputs());
                Filter newFilter = (Filter)this.transposeJoinRel();
                ArrayList newLeftFilters = Lists.newArrayList();
                ArrayList newRightFilters = Lists.newArrayList();
                ArrayList newJoinFilters = Lists.newArrayList();
                List newAboveFilter = RelOptUtil.conjunctions((RexNode)newFilter.getCondition());
                this.pushDownFilter(newAboveFilter, newLeftFilters, newRightFilters, newJoinFilters);
                if (newJoinFilters.size() < originFilterSize) {
                    this.filterRel = newFilter;
                    leftFilters = newLeftFilters;
                    rightFilters = newRightFilters;
                    joinFilters = newJoinFilters;
                    aboveFilters = newAboveFilter;
                    isNeedProject = true;
                } else {
                    this.topJoinRel = originTopJoin;
                }
            }
            RexBuilder rexBuilder = this.topJoinRel.getCluster().getRexBuilder();
            RelNode leftRel = this.relBuilder.push(this.topJoinRel.getLeft()).filter(leftFilters).build();
            RelNode rightRel = this.relBuilder.push(this.topJoinRel.getRight()).filter(rightFilters).build();
            ImmutableList fieldTypes = ImmutableList.builder().addAll((Iterable)RelOptUtil.getFieldTypeList((RelDataType)leftRel.getRowType())).addAll((Iterable)RelOptUtil.getFieldTypeList((RelDataType)rightRel.getRowType())).build();
            RexNode joinFilter = RexUtil.composeConjunction((RexBuilder)rexBuilder, (Iterable)RexUtil.fixUp((RexBuilder)rexBuilder, (List)joinFilters, (List)fieldTypes), (boolean)false);
            if (joinFilter.isAlwaysTrue() && leftFilters.isEmpty() && rightFilters.isEmpty() && joinType == this.topJoinRel.getJoinType()) {
                return;
            }
            Join newJoinRel = this.topJoinRel.copy(this.topJoinRel.getTraitSet(), joinFilter, leftRel, rightRel, joinType, this.topJoinRel.isSemiJoinDone());
            this.call.getPlanner().onCopy((RelNode)this.topJoinRel, (RelNode)newJoinRel);
            if (!leftFilters.isEmpty()) {
                this.call.getPlanner().onCopy((RelNode)this.filterRel, leftRel);
            }
            if (!rightFilters.isEmpty()) {
                this.call.getPlanner().onCopy((RelNode)this.filterRel, rightRel);
            }
            this.relBuilder.push((RelNode)newJoinRel);
            this.relBuilder.convert(this.topJoinRel.getRowType(), false);
            this.relBuilder.filter((Iterable)RexUtil.fixUp((RexBuilder)rexBuilder, (List)aboveFilters, (List)RelOptUtil.getFieldTypeList((RelDataType)this.relBuilder.peek().getRowType())));
            if (isNeedProject) {
                int aCount = this.relA.getRowType().getFieldCount();
                int bCount = this.relB.getRowType().getFieldCount();
                int cCount = this.relC.getRowType().getFieldCount();
                Mappings.TargetMapping originJoinFieldsToNew = Mappings.createShiftMapping((int)(aCount + bCount + cCount), (int[])new int[]{0, 0, cCount, cCount + aCount, cCount, bCount, cCount, cCount + bCount, aCount});
                this.relBuilder.project((Iterable)this.relBuilder.fields(originJoinFieldsToNew));
            }
            this.call.transformTo(this.relBuilder.build());
        }

        private boolean pushDownFilter(List<RexNode> aboveFilters, List<RexNode> leftFilters, List<RexNode> rightFilters, List<RexNode> joinFilters) {
            JoinRelType joinType = this.topJoinRel.getJoinType();
            ImmutableList origJoinFilters = ImmutableList.copyOf(joinFilters);
            boolean filterPushed = false;
            if (RelOptUtil.classifyFilters((RelNode)this.topJoinRel, aboveFilters, (JoinRelType)joinType, (!(this.topJoinRel instanceof EquiJoin) ? 1 : 0) != 0, (!joinType.generatesNullsOnLeft() ? 1 : 0) != 0, (!joinType.generatesNullsOnRight() ? 1 : 0) != 0, joinFilters, leftFilters, rightFilters)) {
                filterPushed = true;
            }
            this.pullUpNonEquiFilters(joinFilters, false, this.topJoinRel.getRowType().getFieldList(), aboveFilters);
            this.pullUpNonEquiFilters(leftFilters, false, this.topJoinRel.getInput(0).getRowType().getFieldList(), aboveFilters);
            this.pullUpNonEquiFilters(rightFilters, true, this.topJoinRel.getInput(1).getRowType().getFieldList(), aboveFilters);
            if (leftFilters.isEmpty() && rightFilters.isEmpty() && joinFilters.size() == origJoinFilters.size() && Sets.newHashSet(joinFilters).equals(Sets.newHashSet((Iterable)origJoinFilters))) {
                filterPushed = false;
            }
            if (RelOptUtil.classifyFilters((RelNode)this.topJoinRel, joinFilters, (JoinRelType)joinType, (boolean)false, (!joinType.generatesNullsOnRight() ? 1 : 0) != 0, (!joinType.generatesNullsOnLeft() ? 1 : 0) != 0, joinFilters, leftFilters, rightFilters)) {
                filterPushed = true;
            }
            return filterPushed;
        }

        private RelNode transposeJoinRel() {
            RexBuilder rexBuilder = this.topJoinRel.getCluster().getRexBuilder();
            int aCount = this.relA.getRowType().getFieldCount();
            int bCount = this.relB.getRowType().getFieldCount();
            int cCount = this.relC.getRowType().getFieldCount();
            Mappings.TargetMapping originJoinFieldsToNew = Mappings.createShiftMapping((int)(aCount + bCount + cCount), (int[])new int[]{0, 0, cCount, cCount + aCount, cCount, bCount, cCount, cCount + bCount, aCount});
            RelNode newRightRel = this.relBuilder.push(((Join)this.bottomJoin).getRight()).build();
            Join oldLeft = (Join)this.bottomJoin;
            Join newLeftRel = oldLeft.copy(oldLeft.getTraitSet(), (RexNode)rexBuilder.makeLiteral(true), this.relBuilder.push(oldLeft.getLeft()).build(), this.relBuilder.push(this.topJoinRel.getRight()).build(), oldLeft.getJoinType(), oldLeft.isSemiJoinDone());
            this.topJoinRel = this.topJoinRel.copy(this.topJoinRel.getTraitSet(), (RexNode)rexBuilder.makeLiteral(true), (RelNode)newLeftRel, newRightRel, this.topJoinRel.getJoinType(), this.topJoinRel.isSemiJoinDone());
            ArrayList newFilterList = Lists.newArrayList();
            new RexPermuteInputsShuttle(originJoinFieldsToNew, new RelNode[]{this.topJoinRel}).visitList(RelOptUtil.conjunctions((RexNode)this.filterRel.getCondition()), (List)newFilterList);
            return this.relBuilder.push((RelNode)this.topJoinRel).filter((Iterable)newFilterList).build();
        }

        private void pullUpNonEquiFilters(List<RexNode> filters, boolean isFromRight, List<RelDataTypeField> srcFields, List<RexNode> aboveFilters) {
            RexBuilder rexBuilder = this.topJoinRel.getCluster().getRexBuilder();
            int[] offsets = new int[srcFields.size()];
            for (int i = 0; i < srcFields.size(); ++i) {
                offsets[i] = isFromRight ? this.topJoinRel.getRowType().getFieldCount() - srcFields.size() : 0;
            }
            Iterator<RexNode> itr = filters.iterator();
            while (itr.hasNext()) {
                RexNode filter = itr.next();
                RelOptUtil.InputFinder inputFinder = RelOptUtil.InputFinder.analyze((RexNode)filter);
                if (inputFinder.inputBitSet.build().asList().size() == 2 && SqlKind.EQUALS == filter.getKind()) continue;
                aboveFilters.add((RexNode)filter.accept((RexVisitor)new RelOptUtil.RexInputConverter(rexBuilder, srcFields, this.topJoinRel.getRowType().getFieldList(), offsets)));
                itr.remove();
            }
        }
    }
}

