package org.apache.flink.table.planner.plan.rules.logical;

import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.function.Supplier;
import org.apache.calcite.plan.RelOptCost;
import org.apache.calcite.plan.RelOptRuleCall;
import org.apache.calcite.plan.RelOptUtil;
import org.apache.calcite.plan.RelRule;
import org.apache.calcite.rel.RelNode;
import org.apache.calcite.rel.core.Join;
import org.apache.calcite.rel.core.JoinRelType;
import org.apache.calcite.rel.core.RelFactories;
import org.apache.calcite.rel.metadata.RelMetadataQuery;
import org.apache.calcite.rel.rules.LoptMultiJoin;
import org.apache.calcite.rel.rules.MultiJoin;
import org.apache.calcite.rel.rules.TransformationRule;
import org.apache.calcite.rel.type.RelDataTypeField;
import org.apache.calcite.rex.RexBuilder;
import org.apache.calcite.rex.RexCall;
import org.apache.calcite.rex.RexInputRef;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.rex.RexShuttle;
import org.apache.calcite.tools.RelBuilder;
import org.apache.calcite.tools.RelBuilderFactory;
import org.apache.calcite.util.ImmutableBitSet;
import org.apache.flink.table.planner.plan.rules.logical.ImmutableFlinkBushyJoinReorderRule;
import org.immutables.value.Value;

@Value.Enclosing
/**
 * Planner rule that reorders the factors of a {@link MultiJoin} into a (possibly bushy) join tree.
 *
 * <p>Inner-joinable factors (those that are not null-generating) are combined level by level:
 * level {@code k} holds, for every set of {@code k + 1} factors that can be connected through join
 * conditions, the plan with the lowest cumulative cost found so far. The cheapest plan of the
 * highest reachable level is then extended with the outer-joined factors and, finally, with any
 * factors that could not be connected by a join condition. A projection restores the original
 * field order of the {@link MultiJoin}, and its post-join filter (if any) is applied on top.
 */
public class FlinkBushyJoinReorderRule extends RelRule<FlinkBushyJoinReorderRule.Config>
        implements TransformationRule {

    /** Rule configuration; matches any {@link MultiJoin}. */
    @Value.Immutable(singleton = false)
    public interface Config extends RelRule.Config {
        Config DEFAULT =
                ImmutableFlinkBushyJoinReorderRule.Config.builder()
                        .build()
                        .withOperandSupplier(b -> b.operand(MultiJoin.class).anyInputs())
                        .as(Config.class);

        @Override
        default FlinkBushyJoinReorderRule toRule() {
            return new FlinkBushyJoinReorderRule(this);
        }
    }

    /**
     * Rewrites the input references of a join condition from the {@link MultiJoin} field numbering
     * to the field numbering of a join whose inputs are the factors in {@code leftFactorIds}
     * followed by the factors in {@code rightFactorIds}, each list in iteration order.
     *
     * <p>References to factors contained in neither list are returned unchanged.
     */
    public static class JoinConditionShuttle extends RexShuttle {
        private final LoptMultiJoin multiJoin;
        private final List<Integer> leftFactorIds;
        private final List<Integer> rightFactorIds;

        public JoinConditionShuttle(
                LoptMultiJoin multiJoin, List<Integer> leftFactorIds, List<Integer> rightFactorIds) {
            this.multiJoin = multiJoin;
            this.leftFactorIds = leftFactorIds;
            this.rightFactorIds = rightFactorIds;
        }

        @Override
        public RexNode visitInputRef(RexInputRef inputRef) {
            int index = inputRef.getIndex();
            int factorId = multiJoin.findRef(index);
            // Position of the referenced field inside its own factor.
            int offsetInFactor = index - multiJoin.getJoinStart(factorId);

            // Walk the new input layout (left factors, then right factors) accumulating field
            // counts until the referenced factor is reached.
            int factorStart = 0;
            for (int id : leftFactorIds) {
                if (id == factorId) {
                    return new RexInputRef(factorStart + offsetInFactor, inputRef.getType());
                }
                factorStart += multiJoin.getNumFieldsInJoinFactor(id);
            }
            for (int id : rightFactorIds) {
                if (id == factorId) {
                    return new RexInputRef(factorStart + offsetInFactor, inputRef.getType());
                }
                factorStart += multiJoin.getNumFieldsInJoinFactor(id);
            }
            return inputRef;
        }
    }

    /** A candidate plan: the join factors it covers and the relational tree that joins them. */
    public static class JoinPlan {
        /**
         * Factor ids covered by this plan. Iteration order matches the order in which the factors'
         * fields appear in the output of {@link #relNode}.
         */
        final Set<Integer> factorIds;

        final RelNode relNode;

        JoinPlan(Set<Integer> factorIds, RelNode relNode) {
            this.factorIds = factorIds;
            this.relNode = relNode;
        }

        /**
         * Returns whether this plan has a strictly lower cumulative cost than {@code other}.
         * Returns {@code false} when either cost is unknown.
         */
        public boolean betterThan(JoinPlan other) {
            RelMetadataQuery mq = relNode.getCluster().getMetadataQuery();
            RelOptCost thisCost = mq.getCumulativeCost(relNode);
            RelOptCost otherCost = mq.getCumulativeCost(other.relNode);
            if (thisCost == null || otherCost == null) {
                return false;
            }
            return thisCost.isLt(otherCost);
        }
    }

    protected FlinkBushyJoinReorderRule(Config config) {
        super(config);
    }

    /** @deprecated Use {@link Config#DEFAULT} and {@link Config#toRule()} instead. */
    @Deprecated
    public FlinkBushyJoinReorderRule(RelBuilderFactory relBuilderFactory) {
        this(Config.DEFAULT.withRelBuilderFactory(relBuilderFactory).as(Config.class));
    }

    /** @deprecated Use {@link Config#DEFAULT} and {@link Config#toRule()} instead. */
    @Deprecated
    public FlinkBushyJoinReorderRule(
            RelFactories.JoinFactory joinFactory,
            RelFactories.ProjectFactory projectFactory,
            RelFactories.FilterFactory filterFactory) {
        this(RelBuilder.proto(joinFactory, projectFactory, filterFactory));
    }

    @Override
    public void onMatch(RelOptRuleCall call) {
        MultiJoin multiJoinRel = (MultiJoin) call.rel(0);
        LoptMultiJoin multiJoin = new LoptMultiJoin(multiJoinRel);
        call.transformTo(findBestOrder(call.builder(), multiJoin));
    }

    /**
     * Builds the final reordered tree: cheapest inner-join plan, then outer joins, then any
     * remaining unconnected factors, topped by a projection restoring the original row type.
     */
    private static RelNode findBestOrder(RelBuilder relBuilder, LoptMultiJoin multiJoin) {
        List<Map<Set<Integer>, JoinPlan>> foundPlans = reorderInnerJoin(relBuilder, multiJoin);
        JoinPlan finalPlan =
                Objects.requireNonNull(
                        getBestPlan(foundPlans.get(foundPlans.size() - 1)),
                        "No join plan found for the non-null-generating factors of the MultiJoin");

        if (multiJoin.getMultiJoinRel().isFullOuterJoin() || outerJoinConditionExists(multiJoin)) {
            finalPlan = addOuterJoinToTop(finalPlan, multiJoin, relBuilder);
        }
        if (finalPlan.factorIds.size() != multiJoin.getNumJoinFactors()) {
            // Factors still missing have no usable join condition; attach them unconditionally.
            finalPlan = addCrossJoinToTop(finalPlan, multiJoin, relBuilder);
        }
        return createTopProject(
                relBuilder,
                multiJoin,
                finalPlan,
                multiJoin.getMultiJoinRel().getRowType().getFieldNames());
    }

    /**
     * Enumerates inner-join plans level by level. Level {@code k} of the returned list maps each
     * reachable set of {@code k + 1} factors to the cheapest plan found for it. Enumeration stops
     * once every factor is covered or no larger connected set can be built.
     */
    private static List<Map<Set<Integer>, JoinPlan>> reorderInnerJoin(
            RelBuilder relBuilder, LoptMultiJoin multiJoin) {
        int numJoinFactors = multiJoin.getNumJoinFactors();
        List<Map<Set<Integer>, JoinPlan>> foundPlans = new ArrayList<>();

        // Level 0: one single-factor plan per factor that is not null-generating.
        Map<Set<Integer>, JoinPlan> firstLevel = new LinkedHashMap<>();
        for (int factorId = 0; factorId < numJoinFactors; factorId++) {
            if (multiJoin.isNullGenerating(factorId)) {
                continue;
            }
            Set<Integer> factorIds = new LinkedHashSet<>();
            factorIds.add(factorId);
            firstLevel.put(
                    new HashSet<>(factorIds),
                    new JoinPlan(factorIds, multiJoin.getJoinFactor(factorId)));
        }
        foundPlans.add(firstLevel);

        // A full outer multi-join is not reordered here; findBestOrder adds the outer join on top.
        if (multiJoin.getMultiJoinRel().isFullOuterJoin()) {
            return foundPlans;
        }

        while (foundPlans.size() < numJoinFactors) {
            Map<Set<Integer>, JoinPlan> nextLevel =
                    foundNextLevel(relBuilder, foundPlans, multiJoin);
            if (nextLevel.isEmpty()) {
                break;
            }
            foundPlans.add(nextLevel);
        }
        return foundPlans;
    }

    /** Returns whether any factor carries a non-empty outer-join condition. */
    private static boolean outerJoinConditionExists(LoptMultiJoin multiJoin) {
        for (int i = 0; i < multiJoin.getNumJoinFactors(); i++) {
            RexNode outerJoinCond = multiJoin.getOuterJoinCond(i);
            if (outerJoinCond != null && !RelOptUtil.conjunctions(outerJoinCond).isEmpty()) {
                return true;
            }
        }
        return false;
    }

    /**
     * Returns the cheapest plan of a level; on ties the earliest plan wins. Returns {@code null}
     * for an empty level.
     */
    private static JoinPlan getBestPlan(Map<Set<Integer>, JoinPlan> levelPlans) {
        JoinPlan best = null;
        for (JoinPlan candidate : levelPlans.values()) {
            if (best == null || candidate.betterThan(best)) {
                best = candidate;
            }
        }
        return best;
    }

    /**
     * Joins every factor not yet in {@code finalPlan} that has an applicable outer-join condition
     * onto the top of the plan (LEFT joins, or a single FULL join for a full outer multi-join).
     * Factors without an applicable condition are left out and handled by {@link
     * #addCrossJoinToTop}.
     */
    private static JoinPlan addOuterJoinToTop(
            JoinPlan finalPlan, LoptMultiJoin multiJoin, RelBuilder relBuilder) {
        List<Integer> remainIndexes =
                getRemainIndexes(multiJoin.getNumJoinFactors(), finalPlan.factorIds);
        boolean isFullOuterJoin = multiJoin.getMultiJoinRel().isFullOuterJoin();
        RelNode leftNode = finalPlan.relNode;
        Set<Integer> leftFactorIds = new LinkedHashSet<>(finalPlan.factorIds);

        for (int factorId : remainIndexes) {
            // Match conditions against everything joined so far (not only the inner-join core), so
            // a condition referencing a previously outer-joined factor is found instead of being
            // dropped in favour of an unconditional join.
            Optional<List<RexCall>> joinConditions =
                    getJoinConditions(
                            leftFactorIds, Collections.singleton(factorId), multiJoin, true);
            if (!joinConditions.isPresent()) {
                continue;
            }
            List<RexCall> newConditions =
                    convertToNewCondition(
                            new ArrayList<>(leftFactorIds),
                            Collections.singletonList(factorId),
                            joinConditions.get(),
                            multiJoin);

            JoinRelType joinType = JoinRelType.LEFT;
            if (isFullOuterJoin) {
                assert remainIndexes.size() == 1
                        : "A full outer MultiJoin is expected to leave exactly one factor to join";
                joinType = JoinRelType.FULL;
            }
            relBuilder.clear();
            leftNode =
                    relBuilder
                            .push(leftNode)
                            .push(multiJoin.getJoinFactor(factorId))
                            .join(joinType, newConditions)
                            .build();
            leftFactorIds.add(factorId);
        }
        return new JoinPlan(leftFactorIds, leftNode);
    }

    /**
     * Joins every factor not yet in {@code finalPlan} onto the top of the plan with a TRUE
     * condition, using the join type the {@link MultiJoin} records for that factor.
     */
    private static JoinPlan addCrossJoinToTop(
            JoinPlan finalPlan, LoptMultiJoin multiJoin, RelBuilder relBuilder) {
        RexBuilder rexBuilder = multiJoin.getMultiJoinRel().getCluster().getRexBuilder();
        List<Integer> remainIndexes =
                getRemainIndexes(multiJoin.getNumJoinFactors(), finalPlan.factorIds);
        RelNode leftNode = finalPlan.relNode;
        Set<Integer> leftFactorIds = new LinkedHashSet<>(finalPlan.factorIds);

        for (int factorId : remainIndexes) {
            relBuilder.clear();
            leftFactorIds.add(factorId);
            leftNode =
                    relBuilder
                            .push(leftNode)
                            .push(multiJoin.getJoinFactor(factorId))
                            .join(
                                    multiJoin.getMultiJoinRel().getJoinTypes().get(factorId),
                                    rexBuilder.makeLiteral(true))
                            .build();
        }
        return new JoinPlan(leftFactorIds, leftNode);
    }

    /**
     * Projects the reordered join back to the {@link MultiJoin}'s original field order and names,
     * then applies the post-join filter, if any.
     *
     * <p>Assumes {@code finalPlan} covers every join factor.
     */
    private static RelNode createTopProject(
            RelBuilder relBuilder,
            LoptMultiJoin multiJoin,
            JoinPlan finalPlan,
            List<String> fieldNames) {
        RexBuilder rexBuilder = multiJoin.getMultiJoinRel().getCluster().getRexBuilder();
        List<RelDataTypeField> multiJoinFields = multiJoin.getMultiJoinFields();
        int numJoinFactors = multiJoin.getNumJoinFactors();

        // Offset of each factor's first field in the output of finalPlan.relNode.
        Map<Integer, Integer> factorToOffsetMap = new HashMap<>();
        int fieldStart = 0;
        for (int factorId : finalPlan.factorIds) {
            factorToOffsetMap.put(factorId, fieldStart);
            fieldStart += multiJoin.getNumFieldsInJoinFactor(factorId);
        }

        List<RexNode> projects = new ArrayList<>();
        for (int currFactor = 0; currFactor < numJoinFactors; currFactor++) {
            // For the right factor of a removable self-join, fields that map to a column of the
            // left factor are read from that left-factor column instead.
            Integer leftFactor =
                    multiJoin.isRightFactorInRemovableSelfJoin(currFactor)
                            ? multiJoin.getOtherSelfJoinFactor(currFactor)
                            : null;
            for (int fieldPos = 0;
                    fieldPos < multiJoin.getNumFieldsInJoinFactor(currFactor);
                    fieldPos++) {
                int newOffset =
                        Objects.requireNonNull(
                                        factorToOffsetMap.get(currFactor),
                                        "factorToOffsetMap.get(currFactor)")
                                + fieldPos;
                if (leftFactor != null) {
                    Integer leftOffset = multiJoin.getRightColumnMapping(currFactor, fieldPos);
                    if (leftOffset != null) {
                        newOffset =
                                Objects.requireNonNull(
                                                factorToOffsetMap.get(leftFactor),
                                                "factorToOffsetMap.get(leftFactor)")
                                        + leftOffset;
                    }
                }
                projects.add(
                        rexBuilder.makeInputRef(
                                multiJoinFields.get(projects.size()).getType(), newOffset));
            }
        }

        relBuilder.clear();
        relBuilder.push(finalPlan.relNode);
        relBuilder.project(projects, fieldNames);
        RexNode postJoinFilter = multiJoin.getMultiJoinRel().getPostJoinFilter();
        if (postJoinFilter != null) {
            relBuilder.filter(postJoinFilter);
        }
        return relBuilder.build();
    }

    /**
     * Builds the next enumeration level by pairing plans of level {@code l} with plans of level
     * {@code (foundPlans.size() - 1 - l)}, so every new plan covers {@code foundPlans.size() + 1}
     * factors. For each factor set only the cheapest plan is kept.
     */
    private static Map<Set<Integer>, JoinPlan> foundNextLevel(
            RelBuilder relBuilder,
            List<Map<Set<Integer>, JoinPlan>> foundPlans,
            LoptMultiJoin multiJoin) {
        Map<Set<Integer>, JoinPlan> nextLevel = new LinkedHashMap<>();

        int leftLevel = 0;
        int rightLevel = foundPlans.size() - 1;
        while (leftLevel <= rightLevel) {
            List<JoinPlan> leftPlans = new ArrayList<>(foundPlans.get(leftLevel).values());
            List<JoinPlan> otherLevelPlans =
                    leftLevel == rightLevel
                            ? leftPlans
                            : new ArrayList<>(foundPlans.get(rightLevel).values());

            for (int i = 0; i < leftPlans.size(); i++) {
                JoinPlan leftPlan = leftPlans.get(i);
                // Within a single level only pair a plan with itself and later plans, so that
                // (A, B) and (B, A) are not both built.
                List<JoinPlan> rightPlans =
                        leftLevel == rightLevel
                                ? otherLevelPlans.subList(i, otherLevelPlans.size())
                                : otherLevelPlans;

                for (JoinPlan rightPlan : rightPlans) {
                    Optional<JoinPlan> candidate =
                            buildInnerJoin(relBuilder, leftPlan, rightPlan, multiJoin);
                    if (!candidate.isPresent()) {
                        continue;
                    }
                    JoinPlan newPlan = candidate.get();
                    JoinPlan existing = nextLevel.get(newPlan.factorIds);
                    if (existing == null || newPlan.betterThan(existing)) {
                        nextLevel.put(newPlan.factorIds, newPlan);
                    }
                }
            }
            leftLevel++;
            rightLevel--;
        }
        return nextLevel;
    }

    /**
     * Inner-joins two plans if they cover disjoint factor sets and at least one join condition
     * connects them. The plan covering more factors (the first one on ties) becomes the left
     * input.
     */
    private static Optional<JoinPlan> buildInnerJoin(
            RelBuilder relBuilder, JoinPlan leftPlan, JoinPlan rightPlan, LoptMultiJoin multiJoin) {
        if (!Collections.disjoint(leftPlan.factorIds, rightPlan.factorIds)) {
            return Optional.empty();
        }
        Optional<List<RexCall>> joinConditions =
                getJoinConditions(leftPlan.factorIds, rightPlan.factorIds, multiJoin, false);
        if (!joinConditions.isPresent()) {
            return Optional.empty();
        }

        JoinPlan newLeft;
        JoinPlan newRight;
        if (leftPlan.factorIds.size() >= rightPlan.factorIds.size()) {
            newLeft = leftPlan;
            newRight = rightPlan;
        } else {
            newLeft = rightPlan;
            newRight = leftPlan;
        }

        Set<Integer> combinedFactorIds = new LinkedHashSet<>(newLeft.factorIds);
        combinedFactorIds.addAll(newRight.factorIds);
        List<RexCall> newConditions =
                convertToNewCondition(
                        new ArrayList<>(newLeft.factorIds),
                        new ArrayList<>(newRight.factorIds),
                        joinConditions.get(),
                        multiJoin);

        relBuilder.clear();
        RelNode join =
                relBuilder
                        .push(newLeft.relNode)
                        .push(newRight.relNode)
                        .join(JoinRelType.INNER, newConditions)
                        .build();
        return Optional.of(new JoinPlan(combinedFactorIds, join));
    }

    /**
     * Rewrites each condition's operands into the field numbering of a join of {@code
     * leftFactorIds} followed by {@code rightFactorIds}; see {@link JoinConditionShuttle}.
     */
    private static List<RexCall> convertToNewCondition(
            List<Integer> leftFactorIds,
            List<Integer> rightFactorIds,
            List<RexCall> conditions,
            LoptMultiJoin multiJoin) {
        RexBuilder rexBuilder = multiJoin.getMultiJoinRel().getCluster().getRexBuilder();
        JoinConditionShuttle shuttle =
                new JoinConditionShuttle(multiJoin, leftFactorIds, rightFactorIds);
        List<RexCall> newConditions = new ArrayList<>(conditions.size());
        for (RexCall condition : conditions) {
            List<RexNode> newOperands = new ArrayList<>(condition.getOperands().size());
            for (RexNode operand : condition.getOperands()) {
                newOperands.add(operand.accept(shuttle));
            }
            newConditions.add((RexCall) rexBuilder.makeCall(condition.getOperator(), newOperands));
        }
        return newConditions;
    }

    /**
     * Collects the conditions that reference at least one factor on each side and no factor
     * outside {@code leftFactorIds} and {@code rightFactorIds}.
     *
     * <p>Inner joins and full outer multi-joins take candidates from the multi-join's join
     * filters; left outer joins ({@code isOuterJoin} on a non-full multi-join) take them from the
     * per-factor outer-join conditions.
     *
     * @return the matching conditions, or empty if none match or any candidate is not a call
     */
    private static Optional<List<RexCall>> getJoinConditions(
            Set<Integer> leftFactorIds,
            Set<Integer> rightFactorIds,
            LoptMultiJoin multiJoin,
            boolean isOuterJoin) {
        List<RexNode> candidates;
        if (!isOuterJoin || multiJoin.getMultiJoinRel().isFullOuterJoin()) {
            candidates = multiJoin.getJoinFilters();
        } else {
            candidates = new ArrayList<>();
            for (int i = 0; i < multiJoin.getNumJoinFactors(); i++) {
                candidates.addAll(RelOptUtil.conjunctions(multiJoin.getOuterJoinCond(i)));
            }
        }

        List<RexCall> result = new ArrayList<>();
        for (RexNode candidate : candidates) {
            // NOTE(review): a single non-call conjunct (e.g. a bare boolean column) makes the
            // whole lookup fail, not just that conjunct — confirm this is intended.
            if (!(candidate instanceof RexCall)) {
                return Optional.empty();
            }
            RexCall call = (RexCall) candidate;
            ImmutableBitSet referencedFactors = multiJoin.getFactorsRefByJoinFilter(call);
            int leftRefs = countReferencedFactors(referencedFactors, leftFactorIds);
            int rightRefs = countReferencedFactors(referencedFactors, rightFactorIds);
            if (leftRefs > 0
                    && rightRefs > 0
                    && leftRefs + rightRefs == referencedFactors.cardinality()) {
                result.add(call);
            }
        }
        return result.isEmpty() ? Optional.empty() : Optional.of(result);
    }

    /** Counts how many of {@code factorIds} are set in {@code referencedFactors}. */
    private static int countReferencedFactors(
            ImmutableBitSet referencedFactors, Set<Integer> factorIds) {
        int count = 0;
        for (int factorId : factorIds) {
            if (referencedFactors.get(factorId)) {
                count++;
            }
        }
        return count;
    }

    /** Returns, in ascending order, the factor ids in {@code [0, numJoinFactors)} not in {@code used}. */
    private static List<Integer> getRemainIndexes(int numJoinFactors, Set<Integer> used) {
        List<Integer> remain = new ArrayList<>();
        for (int factorId = 0; factorId < numJoinFactors; factorId++) {
            if (!used.contains(factorId)) {
                remain.add(factorId);
            }
        }
        return remain;
    }
}
