/*
 * Decompiled with CFR 0.152.
 */
package io.kyligence.kap.query.optrule;

import java.util.ArrayList;
import java.util.List;
import org.apache.calcite.plan.RelOptRule;
import org.apache.calcite.plan.RelOptRuleCall;
import org.apache.calcite.plan.RelOptRuleOperand;
import org.apache.calcite.plan.RelOptRuleOperandChildren;
import org.apache.calcite.plan.RelOptUtil;
import org.apache.calcite.rel.RelNode;
import org.apache.calcite.rel.core.Filter;
import org.apache.calcite.rel.core.Join;
import org.apache.calcite.rel.core.JoinRelType;
import org.apache.calcite.rel.core.RelFactories;
import org.apache.calcite.rel.type.RelDataTypeField;
import org.apache.calcite.rex.RexBuilder;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.rex.RexVisitor;
import org.apache.calcite.tools.RelBuilder;
import org.apache.calcite.tools.RelBuilderFactory;
import org.apache.kylin.guava30.shaded.common.base.Preconditions;
import org.apache.kylin.guava30.shaded.common.base.Predicate;
import org.apache.kylin.guava30.shaded.common.collect.Lists;

/**
 * Planner rule that pulls the {@link Filter} input(s) of a {@link Join} above the join, i.e.
 * {@code Join(Filter(a), b)} becomes {@code Filter(Join(a, b))} (and symmetrically for the right
 * input). The join condition and join type are kept unchanged; only the inputs are swapped for
 * the filters' own inputs.
 *
 * <p>Inner joins may have their left, right or both filters pulled up. Left joins only have the
 * left filter pulled up: a filter on the left (preserved) input commutes with the join, whereas
 * pulling a filter off the right (null-generating) input would discard null-padded rows.
 */
public class JoinFilterRule extends RelOptRule {
    /** Whether the left join input is a {@link Filter} to be pulled above the join. */
    private final boolean pullLeft;
    /** Whether the right join input is a {@link Filter} to be pulled above the join. */
    private final boolean pullRight;

    /** Matches only joins of type {@link JoinRelType#INNER}. */
    private static final Predicate<Join> INNER_JOIN_PREDICATE = join -> {
        Preconditions.checkArgument(join != null, "join MUST NOT be null");
        return join.getJoinType() == JoinRelType.INNER;
    };

    /** Matches only joins of type {@link JoinRelType#LEFT}. */
    private static final Predicate<Join> LEFT_JOIN_PREDICATE = join -> {
        Preconditions.checkArgument(join != null, "join MUST NOT be null");
        return join.getJoinType() == JoinRelType.LEFT;
    };

    /** Inner join whose left input is a Filter: pulls the left filter up. */
    public static final JoinFilterRule JOIN_LEFT_FILTER = new JoinFilterRule(
            operand(Join.class, null, join -> INNER_JOIN_PREDICATE.apply(join),
                    operand(Filter.class, any()), operand(RelNode.class, any())),
            RelFactories.LOGICAL_BUILDER, true, false);

    /** Inner join whose right input is a Filter: pulls the right filter up. */
    public static final JoinFilterRule JOIN_RIGHT_FILTER = new JoinFilterRule(
            operand(Join.class, null, join -> INNER_JOIN_PREDICATE.apply(join),
                    operand(RelNode.class, any()), operand(Filter.class, any())),
            RelFactories.LOGICAL_BUILDER, false, true);

    /** Inner join whose inputs are both Filters: pulls both filters up. */
    public static final JoinFilterRule JOIN_BOTH_FILTER = new JoinFilterRule(
            operand(Join.class, null, join -> INNER_JOIN_PREDICATE.apply(join),
                    operand(Filter.class, any()), operand(Filter.class, any())),
            RelFactories.LOGICAL_BUILDER, true, true);

    /**
     * Left join whose left input is a Filter: pulls the left filter up.
     *
     * <p>Uses its own description: with the default one it would share
     * {@code "JoinFilterRule:true:false"} with {@link #JOIN_LEFT_FILTER}. Calcite considers rules of
     * the same class with equal descriptions and structurally equal operands (operand equality
     * does not look at the match predicate) to be the same rule, so a planner could silently drop
     * one of them.
     */
    public static final JoinFilterRule LEFT_JOIN_LEFT_FILTER = new JoinFilterRule(
            operand(Join.class, null, join -> LEFT_JOIN_PREDICATE.apply(join),
                    operand(Filter.class, any()), operand(RelNode.class, any())),
            RelFactories.LOGICAL_BUILDER, true, false, "JoinFilterRule:LEFT:true:false");

    /**
     * Creates a rule with the default description {@code "JoinFilterRule:<pullLeft>:<pullRight>"}.
     *
     * @param operand   root operand; must match a Join whose inputs are ordinals 1 (left) and 2 (right)
     * @param builder   factory for the RelBuilder used to build the replacement
     * @param pullLeft  whether the left input (call.rel(1)) is a Filter to pull up
     * @param pullRight whether the right input (call.rel(2)) is a Filter to pull up
     */
    public JoinFilterRule(RelOptRuleOperand operand, RelBuilderFactory builder, boolean pullLeft, boolean pullRight) {
        this(operand, builder, pullLeft, pullRight, "JoinFilterRule:" + pullLeft + ":" + pullRight);
    }

    /**
     * Creates a rule with an explicit description, which must be unique among registered rules.
     *
     * @param operand     root operand; must match a Join whose inputs are ordinals 1 (left) and 2 (right)
     * @param builder     factory for the RelBuilder used to build the replacement
     * @param pullLeft    whether the left input (call.rel(1)) is a Filter to pull up
     * @param pullRight   whether the right input (call.rel(2)) is a Filter to pull up
     * @param description rule description passed to {@link RelOptRule}
     */
    public JoinFilterRule(RelOptRuleOperand operand, RelBuilderFactory builder, boolean pullLeft, boolean pullRight,
            String description) {
        super(operand, builder, description);
        this.pullLeft = pullLeft;
        this.pullRight = pullRight;
    }

    /**
     * Replaces the matched join with a copy whose filtered inputs are replaced by the filters'
     * inputs, topped by a single filter on the conjunction of the pulled-up conditions.
     */
    @Override
    public void onMatch(RelOptRuleCall call) {
        Join join = (Join) call.rel(0);
        RelNode joinLeft = call.rel(1);
        RelNode joinRight = call.rel(2);
        int leftCount = joinLeft.getRowType().getFieldCount();
        int rightCount = joinRight.getRowType().getFieldCount();

        RelNode newJoinLeft = joinLeft;
        RelNode newJoinRight = joinRight;
        // Our own list: the result of RelOptUtil.conjunctions is not mutated.
        List<RexNode> pulledFilters = new ArrayList<>();

        if (pullLeft) {
            Filter leftFilter = (Filter) joinLeft;
            newJoinLeft = leftFilter.getInput();
            // Left input fields occupy the same ordinals in the join row type; no shifting needed.
            pulledFilters.addAll(RelOptUtil.conjunctions(leftFilter.getCondition()));
        }
        if (pullRight) {
            Filter rightFilter = (Filter) joinRight;
            newJoinRight = rightFilter.getInput();
            RexBuilder rexBuilder = joinRight.getCluster().getRexBuilder();
            List<RelDataTypeField> rightFields = joinRight.getRowType().getFieldList();
            List<RelDataTypeField> joinFields = join.getRowType().getFieldList();
            // Right input field i becomes field (leftCount + i) in the join row type.
            for (RexNode conjunct : RelOptUtil.conjunctions(rightFilter.getCondition())) {
                pulledFilters.add(shiftFilter(0, rightCount, leftCount, rexBuilder, rightFields, rightCount,
                        joinFields, conjunct));
            }
        }

        Join newJoin = join.copy(join.getTraitSet(), Lists.newArrayList(newJoinLeft, newJoinRight));
        RelNode finalFilter = call.builder().push(newJoin).filter(pulledFilters).build();
        call.transformTo(finalFilter);
    }

    /**
     * Rewrites input references in {@code filter} from the source row type to the destination row
     * type by adding {@code offset} to every field ordinal in {@code [start, end)}.
     *
     * @param start       first source ordinal to shift (inclusive)
     * @param end         last source ordinal to shift (exclusive)
     * @param offset      amount added to each shifted ordinal
     * @param rexBuilder  builder used to create the rewritten expressions
     * @param srcFields   fields the filter currently references
     * @param nSrcFields  size of the adjustment array (number of source fields)
     * @param destFields  fields of the row type the rewritten filter references
     * @param filter      expression to rewrite
     * @return the rewritten expression
     */
    private static RexNode shiftFilter(int start, int end, int offset, RexBuilder rexBuilder,
            List<RelDataTypeField> srcFields, int nSrcFields, List<RelDataTypeField> destFields, RexNode filter) {
        int[] adjustments = new int[nSrcFields];
        for (int i = start; i < end; ++i) {
            adjustments[i] = offset;
        }
        return filter.accept(new RelOptUtil.RexInputConverter(rexBuilder, srcFields, destFields, adjustments));
    }
}

