package org.apache.druid.sql.calcite.rule;

import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Preconditions;
import com.google.common.collect.Iterables;
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Deque;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Set;
import java.util.Stack;
import java.util.stream.Collectors;
import org.apache.calcite.plan.RelOptRule;
import org.apache.calcite.plan.RelOptRuleCall;
import org.apache.calcite.plan.RelOptRuleOperand;
import org.apache.calcite.plan.RelOptUtil;
import org.apache.calcite.rel.RelNode;
import org.apache.calcite.rel.core.Filter;
import org.apache.calcite.rel.core.Join;
import org.apache.calcite.rel.core.JoinRelType;
import org.apache.calcite.rel.core.Project;
import org.apache.calcite.rel.type.RelDataType;
import org.apache.calcite.rel.type.RelDataTypeField;
import org.apache.calcite.rex.RexBuilder;
import org.apache.calcite.rex.RexCall;
import org.apache.calcite.rex.RexInputRef;
import org.apache.calcite.rex.RexLiteral;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.rex.RexUtil;
import org.apache.calcite.sql.SqlBinaryOperator;
import org.apache.calcite.sql.SqlKind;
import org.apache.calcite.sql.fun.SqlStdOperatorTable;
import org.apache.calcite.sql.type.SqlTypeName;
import org.apache.calcite.tools.RelBuilder;
import org.apache.calcite.util.ImmutableBitSet;
import org.apache.druid.common.config.NullHandling;
import org.apache.druid.error.DruidException;
import org.apache.druid.query.LookupDataSource;
import org.apache.druid.sql.calcite.planner.PlannerContext;
import org.apache.druid.sql.calcite.rel.DruidJoinQueryRel;
import org.apache.druid.sql.calcite.rel.DruidQueryRel;
import org.apache.druid.sql.calcite.rel.DruidRel;
import org.apache.druid.sql.calcite.rel.PartialDruidQuery;

/**
 * Planner rule that matches a Calcite {@link Join} whose two inputs are both {@link DruidRel}s and replaces it with a
 * {@link DruidJoinQueryRel}.
 *
 * <p>When the join algorithm does not require subqueries, the rule lifts SELECT_PROJECT projections off the inputs and
 * re-applies them in a {@link Project} above the join, so each input becomes a bare scan. Join sub-conditions that
 * cannot be expressed natively are removed from the join condition and re-applied as a filter above the join.
 */
public class DruidJoinRule extends RelOptRule {
    /**
     * Value of the query-context flag {@code getEnableJoinLeftScanDirect()}. When true and the left input is a
     * {@link DruidQueryRel}, the left projection may be lifted even if the left side has a WHERE filter.
     */
    private final boolean enableLeftScanDirect;
    private final PlannerContext plannerContext;

    /**
     * Result of splitting a join condition into its AND-ed sub-conditions and classifying each of them as an equality
     * between a left-hand expression and a right-hand input ref, a literal, or unsupported.
     */
    public static class ConditionAnalysis {
        /** Field count of the join's left input; input refs at or above this index point into the right input. */
        private final int numLeftFields;
        /** Sub-conditions of the form {@code leftExpr = rightRef} or {@code leftExpr IS NOT DISTINCT FROM rightRef}. */
        private final List<RexEquality> equalitySubConditions;
        /** Sub-conditions that are plain literals (possibly unwrapped from a same-type CAST). */
        private final List<RexLiteral> literalSubConditions;
        /** Sub-conditions that cannot be part of the native join condition; they are applied after the join. */
        private final List<RexNode> unsupportedOnSubConditions;
        /** Right-hand input refs referenced by the equality sub-conditions found during analysis. */
        private final Set<RexInputRef> rightColumns;

        ConditionAnalysis(
            final int numLeftFields,
            final List<RexEquality> equalitySubConditions,
            final List<RexLiteral> literalSubConditions,
            final List<RexNode> unsupportedOnSubConditions,
            final Set<RexInputRef> rightColumns
        ) {
            this.numLeftFields = numLeftFields;
            this.equalitySubConditions = equalitySubConditions;
            this.literalSubConditions = literalSubConditions;
            this.unsupportedOnSubConditions = unsupportedOnSubConditions;
            this.rightColumns = rightColumns;
        }

        /**
         * Rewrites this analysis as if the left input's projection had been lifted above the join.
         *
         * <p>Left-hand expressions are rewritten in terms of the projection's input. Right-hand refs are shifted by the
         * difference between the projection's input width and output width, because the left side changes width.
         * Literal and unsupported sub-conditions are carried over unchanged.
         *
         * @param leftProject projection on the join's left input
         * @return a new analysis whose left side is the projection's input
         */
        public ConditionAnalysis pushThroughLeftProject(final Project leftProject) {
            final int leftInputFieldCount = leftProject.getInput().getRowType().getFieldCount();
            final int rhsShift = leftInputFieldCount - leftProject.getRowType().getFieldCount();

            final List<RexEquality> pushed = equalitySubConditions
                .stream()
                .map(equality -> new RexEquality(
                    RelOptUtil.pushPastProject(equality.left, leftProject),
                    // Shifting an input ref yields an input ref, so this cast is safe.
                    (RexInputRef) RexUtil.shift(equality.right, rhsShift),
                    equality.kind
                ))
                .collect(Collectors.toList());

            return new ConditionAnalysis(
                leftInputFieldCount,
                pushed,
                literalSubConditions,
                unsupportedOnSubConditions,
                rightColumns
            );
        }

        /**
         * Rewrites this analysis as if the right input's projection had been lifted above the join.
         *
         * @param rightProject projection on the join's right input
         * @return a new analysis whose right-hand refs point into the projection's input
         * @throws IllegalArgumentException if {@link #onlyUsesMappingsFromRightProject} is false for {@code rightProject}
         */
        public ConditionAnalysis pushThroughRightProject(final Project rightProject) {
            Preconditions.checkArgument(onlyUsesMappingsFromRightProject(rightProject), "Cannot push through");

            final List<RexEquality> pushed = equalitySubConditions
                .stream()
                .map(equality -> new RexEquality(
                    equality.left,
                    // The precondition guarantees the projected expression is an input ref, so the cast is safe.
                    (RexInputRef) RexUtil.shift(
                        RelOptUtil.pushPastProject(RexUtil.shift(equality.right, -numLeftFields), rightProject),
                        numLeftFields
                    ),
                    equality.kind
                ))
                .collect(Collectors.toList());

            return new ConditionAnalysis(
                numLeftFields,
                pushed,
                literalSubConditions,
                unsupportedOnSubConditions,
                rightColumns
            );
        }

        /**
         * Returns whether every equality's right-hand ref maps to a plain input ref in {@code rightProject}.
         *
         * @param rightProject projection on the join's right input
         * @return true if no equality depends on a computed right-side expression
         */
        public boolean onlyUsesMappingsFromRightProject(final Project rightProject) {
            for (final RexEquality equality : equalitySubConditions) {
                final int rightIndex = equality.right.getIndex() - numLeftFields;
                if (!rightProject.getProjects().get(rightIndex).isA(SqlKind.INPUT_REF)) {
                    return false;
                }
            }
            return true;
        }

        /**
         * Builds the conjunction of the literal and equality sub-conditions, omitting the unsupported ones.
         *
         * @param rexBuilder builder used to create the comparison calls
         * @return the native-join condition; never null, because {@code nullOnEmpty} is false
         */
        public RexNode getConditionWithUnsupportedSubConditionsIgnored(final RexBuilder rexBuilder) {
            final List<RexNode> equalityCalls = equalitySubConditions
                .stream()
                .map(equality -> equality.makeCall(rexBuilder))
                .collect(Collectors.toList());
            return RexUtil.composeConjunction(rexBuilder, Iterables.concat(literalSubConditions, equalityCalls), false);
        }

        /** Returns the sub-conditions that must be evaluated after the join rather than as part of it. */
        public List<RexNode> getUnsupportedOnSubConditions() {
            return unsupportedOnSubConditions;
        }

        @Override
        public String toString() {
            return "ConditionAnalysis{numLeftFields=" + numLeftFields
                   + ", equalitySubConditions=" + equalitySubConditions
                   + ", literalSubConditions=" + literalSubConditions
                   + ", unsupportedSubConditions=" + unsupportedOnSubConditions
                   + ", rightColumns=" + rightColumns + '}';
        }
    }

    /**
     * A comparison between an expression over the left input and an input ref into the right input, using either
     * {@link SqlKind#EQUALS} or {@link SqlKind#IS_NOT_DISTINCT_FROM}.
     */
    public static class RexEquality {
        private final RexNode left;
        private final RexInputRef right;
        private final SqlKind kind;

        public RexEquality(final RexNode left, final RexInputRef right, final SqlKind kind) {
            this.left = left;
            this.right = right;
            this.kind = kind;
        }

        /**
         * Re-creates the comparison as a Rex call.
         *
         * @param rexBuilder builder used to create the call
         * @return {@code left = right} or {@code left IS NOT DISTINCT FROM right}
         * @throws DruidException (defensive) if {@code kind} is neither supported comparison
         */
        public RexNode makeCall(final RexBuilder rexBuilder) {
            final SqlBinaryOperator operator;
            if (kind == SqlKind.EQUALS) {
                operator = SqlStdOperatorTable.EQUALS;
            } else if (kind == SqlKind.IS_NOT_DISTINCT_FROM) {
                operator = SqlStdOperatorTable.IS_NOT_DISTINCT_FROM;
            } else {
                throw DruidException.defensive("Unexpected operator kind[%s]", kind);
            }
            return rexBuilder.makeCall(operator, left, right);
        }

        @Override
        public String toString() {
            return "RexEquality{left=" + left + ", right=" + right + ", kind=" + kind + '}';
        }
    }

    private DruidJoinRule(final PlannerContext plannerContext) {
        super(operand(Join.class, operand(DruidRel.class, any()), new RelOptRuleOperand[]{operand(DruidRel.class, any())}));
        this.enableLeftScanDirect = plannerContext.queryContext().getEnableJoinLeftScanDirect();
        this.plannerContext = plannerContext;
    }

    /** Creates a rule instance bound to {@code plannerContext}. */
    public static DruidJoinRule instance(final PlannerContext plannerContext) {
        return new DruidJoinRule(plannerContext);
    }

    /**
     * The rule fires only when the join condition can be handled natively (see {@link #canHandleCondition}) and both
     * inputs carry a {@link PartialDruidQuery}.
     */
    @Override
    public boolean matches(final RelOptRuleCall call) {
        final Join join = call.rel(0);
        final DruidRel<?> left = call.rel(1);
        final DruidRel<?> right = call.rel(2);

        return canHandleCondition(
            join.getCondition(),
            join.getLeft().getRowType(),
            right,
            join.getJoinType(),
            join.getSystemFieldList(),
            join.getCluster().getRexBuilder()
        )
               && left.getPartialDruidQuery() != null
               && right.getPartialDruidQuery() != null;
    }

    /**
     * Replaces the matched {@link Join} with a {@link DruidJoinQueryRel}.
     *
     * <p>Input projections are lifted where possible, and a Project above the join restores the original output
     * columns. Unsupported sub-conditions become a filter above that Project. The result is converted to the original
     * join's row type.
     */
    @Override
    public void onMatch(final RelOptRuleCall call) {
        final Join join = call.rel(0);
        final DruidRel<?> left = call.rel(1);
        final DruidRel<?> right = call.rel(2);
        final RexBuilder rexBuilder = join.getCluster().getRexBuilder();

        // Expressions of the Project placed above the new join; together they rebuild the join's original output row.
        final List<RexNode> newProjectExprs = new ArrayList<>();
        final DruidRel<?> newLeft;
        final DruidRel<?> newRight;
        final Filter leftFilter;

        // Not final: it is rewritten each time a projection is lifted off one of the inputs.
        ConditionAnalysis conditionAnalysis =
            analyzeCondition(join.getCondition(), join.getLeft().getRowType(), rexBuilder);
        final boolean isLeftDirectAccessPossible = enableLeftScanDirect && (left instanceof DruidQueryRel);

        if (!plannerContext.getJoinAlgorithm().requiresSubquery()
            && left.getPartialDruidQuery().stage() == PartialDruidQuery.Stage.SELECT_PROJECT
            && (isLeftDirectAccessPossible || left.getPartialDruidQuery().getWhereFilter() == null)) {
            // Lift the left-side projection above the join so the left input becomes a bare scan.
            final RelNode leftScan = left.getPartialDruidQuery().getScan();
            final Project leftProject = left.getPartialDruidQuery().getSelectProject();
            // The left WHERE filter (if any) is handed to the join rel, since the new left input is only the scan.
            leftFilter = left.getPartialDruidQuery().getWhereFilter();
            newProjectExprs.addAll(leftProject.getProjects());
            newLeft = left.withPartialQuery(PartialDruidQuery.create(leftScan));
            conditionAnalysis = conditionAnalysis.pushThroughLeftProject(leftProject);
        } else {
            // Keep the left input as-is and pass its columns through with identity input refs.
            for (int i = 0; i < left.getRowType().getFieldCount(); i++) {
                newProjectExprs.add(rexBuilder.makeInputRef(join.getRowType().getFieldList().get(i).getType(), i));
            }
            newLeft = left;
            leftFilter = null;
        }

        if (!plannerContext.getJoinAlgorithm().requiresSubquery()
            && right.getPartialDruidQuery().stage() == PartialDruidQuery.Stage.SELECT_PROJECT
            && right.getPartialDruidQuery().getWhereFilter() == null
            && !right.getPartialDruidQuery().getSelectProject().isMapping()
            && conditionAnalysis.onlyUsesMappingsFromRightProject(right.getPartialDruidQuery().getSelectProject())) {
            // Lift the right-side projection above the join so the right input becomes a bare scan.
            final RelNode rightScan = right.getPartialDruidQuery().getScan();
            final Project rightProject = right.getPartialDruidQuery().getSelectProject();
            final boolean nullsOnRight = join.getJoinType().generatesNullsOnRight();

            for (final RexNode expr : RexUtil.shift(rightProject.getProjects(), newLeft.getRowType().getFieldCount())) {
                // A lifted literal must become nullable when the join can emit NULLs for the right side.
                newProjectExprs.add(nullsOnRight ? makeNullableIfLiteral(expr, rexBuilder) : expr);
            }

            newRight = right.withPartialQuery(PartialDruidQuery.create(rightScan));
            conditionAnalysis = conditionAnalysis.pushThroughRightProject(rightProject);
        } else {
            // Keep the right input as-is: types come from the original join row, indexes from the new left width.
            for (int i = 0; i < right.getRowType().getFieldCount(); i++) {
                newProjectExprs.add(
                    rexBuilder.makeInputRef(
                        join.getRowType().getFieldList().get(left.getRowType().getFieldCount() + i).getType(),
                        newLeft.getRowType().getFieldCount() + i
                    )
                );
            }
            newRight = right;
        }

        final DruidJoinQueryRel druidJoin = DruidJoinQueryRel.create(
            join.copy(
                join.getTraitSet(),
                conditionAnalysis.getConditionWithUnsupportedSubConditionsIgnored(rexBuilder),
                newLeft,
                newRight,
                join.getJoinType(),
                join.isSemiJoinDone()
            ),
            leftFilter,
            left.getPlannerContext()
        );

        final RelBuilder relBuilder = call.builder()
            .push(druidJoin)
            .project(RexUtil.fixUp(rexBuilder, newProjectExprs, RelOptUtil.getFieldTypeList(druidJoin.getRowType())));

        // Whatever could not be part of the native join condition is applied as a post-join filter.
        final RexNode postJoinFilter =
            RexUtil.composeConjunction(rexBuilder, conditionAnalysis.getUnsupportedOnSubConditions(), true);
        if (postJoinFilter != null) {
            relBuilder.filter(postJoinFilter);
        }

        relBuilder.convert(join.getRowType(), false);
        call.transformTo(relBuilder.build());
    }

    /** Returns {@code rexNode} re-typed as nullable if it is a literal; any other expression is returned unchanged. */
    private static RexNode makeNullableIfLiteral(final RexNode rexNode, final RexBuilder rexBuilder) {
        if (!rexNode.isA(SqlKind.LITERAL)) {
            return rexNode;
        }
        return rexBuilder.makeLiteral(
            RexLiteral.value(rexNode),
            rexBuilder.getTypeFactory().createTypeWithNullability(rexNode.getType(), true),
            true
        );
    }

    /**
     * Decides whether {@code condition} can be planned as a Druid join.
     *
     * @param condition       the join condition
     * @param leftRowType     row type of the left input
     * @param right           the right input; may be null, in which case the lookup check is skipped
     * @param joinType        the join type
     * @param systemFieldList the join's system fields
     * @param rexBuilder      builder used during analysis
     * @return false if the right side is a lookup (not requiring a subquery) and the condition references more than
     *     one distinct right-hand column. Otherwise true for an INNER join with no system fields when
     *     {@code NullHandling.replaceWithDefault()} is false, because unsupported sub-conditions can then become a
     *     post-join filter. Otherwise true only if no sub-condition is unsupported.
     */
    @VisibleForTesting
    public boolean canHandleCondition(
        final RexNode condition,
        final RelDataType leftRowType,
        final DruidRel<?> right,
        final JoinRelType joinType,
        final List<RelDataTypeField> systemFieldList,
        final RexBuilder rexBuilder
    ) {
        final ConditionAnalysis conditionAnalysis = analyzeCondition(condition, leftRowType, rexBuilder);

        if (right != null
            && !DruidJoinQueryRel.computeRightRequiresSubquery(plannerContext, DruidJoinQueryRel.getSomeDruidChild(right))
            && right instanceof DruidQueryRel
            && ((DruidQueryRel) right).getDruidTable().getDataSource() instanceof LookupDataSource) {
            final long distinctRightColumns =
                conditionAnalysis.rightColumns.stream().map(RexInputRef::getIndex).distinct().count();
            // Per the error below, a direct lookup join cannot use the lookup's value column in its condition.
            if (distinctRightColumns > 1) {
                plannerContext.setPlanningError(
                    "SQL is resulting in a join involving lookup where value column is used in the condition."
                );
                return false;
            }
        }

        if (joinType == JoinRelType.INNER && systemFieldList.isEmpty() && !NullHandling.replaceWithDefault()) {
            return true;
        }
        return conditionAnalysis.getUnsupportedOnSubConditions().isEmpty();
    }

    /**
     * Splits {@code condition} on AND and classifies each sub-condition.
     *
     * <ul>
     *   <li>Literals, and CASTs of literals to the same SQL type, are collected as literal sub-conditions.</li>
     *   <li>A bare BOOLEAN input ref {@code x} is treated as {@code TRUE = x}.</li>
     *   <li>For an EQUALS / IS_NOT_DISTINCT_FROM call, and for the bare-boolean case above: if one operand is an
     *       expression over the left input only and the other is a right-hand input ref, it becomes a
     *       {@link RexEquality} with the left expression first.</li>
     *   <li>Everything else is unsupported, and a planning error is recorded on the planner context.</li>
     * </ul>
     *
     * @param condition   the join condition
     * @param leftRowType row type of the left input; its field count splits left refs from right refs
     * @param rexBuilder  builder used to create the {@code TRUE} literal for bare boolean refs
     * @return the classification
     */
    public ConditionAnalysis analyzeCondition(
        final RexNode condition,
        final RelDataType leftRowType,
        final RexBuilder rexBuilder
    ) {
        final List<RexEquality> equalitySubConditions = new ArrayList<>();
        final List<RexLiteral> literalSubConditions = new ArrayList<>();
        final List<RexNode> unsupportedSubConditions = new ArrayList<>();
        final Set<RexInputRef> rightColumns = new HashSet<>();
        final int numLeftFields = leftRowType.getFieldCount();

        for (final RexNode subCondition : decomposeAnd(condition)) {
            if (RexUtil.isLiteral(subCondition, true)) {
                if (subCondition.isA(SqlKind.CAST)) {
                    final RexCall cast = (RexCall) subCondition;
                    final RexNode castOperand = cast.getOperands().get(0);
                    // Unwrapping the CAST is only meaning-preserving when it does not change the SQL type.
                    if (cast.getType().getSqlTypeName().equals(castOperand.getType().getSqlTypeName())) {
                        literalSubConditions.add((RexLiteral) castOperand);
                    } else {
                        unsupportedSubConditions.add(subCondition);
                    }
                } else {
                    literalSubConditions.add((RexLiteral) subCondition);
                }
                continue;
            }

            final RexNode firstOperand;
            final RexNode secondOperand;
            final SqlKind comparisonKind;

            if (subCondition.isA(SqlKind.INPUT_REF)) {
                firstOperand = rexBuilder.makeLiteral(true);
                secondOperand = subCondition;
                comparisonKind = SqlKind.EQUALS;

                if (!SqlTypeName.BOOLEAN_TYPES.contains(secondOperand.getType().getSqlTypeName())) {
                    plannerContext.setPlanningError(
                        "SQL requires a join with '%s' condition where the column is of the type %s, that is not supported",
                        subCondition.getKind(),
                        secondOperand.getType().getSqlTypeName()
                    );
                    unsupportedSubConditions.add(subCondition);
                    continue;
                }
            } else if (subCondition.isA(SqlKind.EQUALS) || subCondition.isA(SqlKind.IS_NOT_DISTINCT_FROM)) {
                final List<RexNode> operands = ((RexCall) subCondition).getOperands();
                Preconditions.checkState(operands.size() == 2, "Expected 2 operands, got[%s]", operands.size());
                firstOperand = operands.get(0);
                secondOperand = operands.get(1);
                comparisonKind = subCondition.getKind();
            } else {
                plannerContext.setPlanningError(
                    "SQL requires a join with '%s' condition that is not supported.",
                    subCondition.getKind()
                );
                unsupportedSubConditions.add(subCondition);
                continue;
            }

            // Normalize orientation so that the left-only expression comes first and the right-hand ref second.
            if (isLeftExpression(firstOperand, numLeftFields) && isRightInputRef(secondOperand, numLeftFields)) {
                final RexInputRef rightRef = (RexInputRef) secondOperand;
                equalitySubConditions.add(new RexEquality(firstOperand, rightRef, comparisonKind));
                rightColumns.add(rightRef);
            } else if (isRightInputRef(firstOperand, numLeftFields) && isLeftExpression(secondOperand, numLeftFields)) {
                final RexInputRef rightRef = (RexInputRef) firstOperand;
                equalitySubConditions.add(new RexEquality(secondOperand, rightRef, comparisonKind));
                rightColumns.add(rightRef);
            } else {
                plannerContext.setPlanningError("SQL is resulting in a join that has unsupported operand types.");
                unsupportedSubConditions.add(subCondition);
            }
        }

        return new ConditionAnalysis(
            numLeftFields,
            equalitySubConditions,
            literalSubConditions,
            unsupportedSubConditions,
            rightColumns
        );
    }

    /**
     * Flattens nested AND calls into their non-AND leaves, preserving left-to-right order.
     *
     * @param condition the expression to flatten
     * @return the conjuncts; a single-element list if {@code condition} is not an AND
     */
    @VisibleForTesting
    static List<RexNode> decomposeAnd(final RexNode condition) {
        final List<RexNode> conjuncts = new ArrayList<>();
        final Deque<RexNode> stack = new ArrayDeque<>();
        stack.push(condition);

        while (!stack.isEmpty()) {
            final RexNode current = stack.pop();
            if (current.isA(SqlKind.AND)) {
                final List<RexNode> operands = ((RexCall) current).getOperands();
                // Push right-to-left so the operands pop off in their original order.
                for (int i = operands.size() - 1; i >= 0; i--) {
                    stack.push(operands.get(i));
                }
            } else {
                conjuncts.add(current);
            }
        }
        return conjuncts;
    }

    /** Returns true if every input ref inside {@code rexNode} has an index below {@code numLeftFields}. */
    private static boolean isLeftExpression(final RexNode rexNode, final int numLeftFields) {
        return ImmutableBitSet.range(numLeftFields).contains(RelOptUtil.InputFinder.bits(rexNode));
    }

    /** Returns true if {@code rexNode} is an input ref whose index is at least {@code numLeftFields}. */
    private static boolean isRightInputRef(final RexNode rexNode, final int numLeftFields) {
        return rexNode.isA(SqlKind.INPUT_REF) && ((RexInputRef) rexNode).getIndex() >= numLeftFields;
    }
}
