package org.apache.calcite.rel.rules;

import com.sun.xml.bind.v2.runtime.reflect.opt.Const;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.ListIterator;
import java.util.Map;
import org.apache.calcite.plan.RelOptCost;
import org.apache.calcite.plan.RelOptTable;
import org.apache.calcite.plan.RelOptUtil;
import org.apache.calcite.plan.ViewExpanders;
import org.apache.calcite.rel.RelNode;
import org.apache.calcite.rel.core.JoinInfo;
import org.apache.calcite.rel.core.SemiJoin;
import org.apache.calcite.rel.metadata.RelColumnOrigin;
import org.apache.calcite.rel.metadata.RelMdUtil;
import org.apache.calcite.rel.metadata.RelMetadataQuery;
import org.apache.calcite.rex.RexBuilder;
import org.apache.calcite.rex.RexCall;
import org.apache.calcite.rex.RexInputRef;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.rex.RexUtil;
import org.apache.calcite.schema.Wrapper;
import org.apache.calcite.sql.SqlKind;
import org.apache.calcite.sql.fun.SqlStdOperatorTable;
import org.apache.calcite.util.ImmutableBitSet;
import org.apache.calcite.util.ImmutableIntList;
import org.apache.calcite.util.Util;
import shaded.com.google.common.collect.Lists;
import shaded.com.google.common.collect.Ordering;

/**
 * Chooses semijoins that can be used to pre-filter the factors of a
 * {@link LoptMultiJoin}.
 *
 * <p>{@link #makePossibleSemiJoins} records, for each factor, candidate
 * {@link SemiJoin}s against other factors it is joined to through simple
 * column-to-column equalities. Each call to {@link #chooseBestSemiJoin} scores
 * those candidates and applies at most one of them. The expression currently
 * chosen for each factor is available through {@link #getChosenSemiJoin}.
 *
 * <p>NOTE(review): the LucidDB-specific nested classes are placeholders;
 * {@link LcsIndexOptimizer#findSemiJoinIndexByCost} always returns
 * {@code null}, so as written no candidate semijoin is ever produced.
 */
public class LoptSemiJoinOptimizer {
    /** A candidate semijoin must score strictly above this value to be chosen. */
    private static final int THRESHOLD_SCORE = 10;

    private final RexBuilder rexBuilder;
    private final RelMetadataQuery mq;

    /**
     * For each factor index, the expression currently used for that factor:
     * initially the factor itself, replaced by a {@link SemiJoin} once one has
     * been chosen for it.
     */
    private RelNode[] chosenSemiJoins;

    /**
     * Candidate semijoins keyed by fact factor index, then by dimension factor
     * index. Rebuilt by {@link #makePossibleSemiJoins}.
     */
    private Map<Integer, Map<Integer, SemiJoin>> possibleSemiJoins;

    /** Orders factor indices by the cumulative cost of their chosen expression. */
    private final Ordering<Integer> factorCostOrdering =
        Ordering.from(new FactorCostComparator());

    /**
     * Compares factor indices by the cumulative cost of the expression currently
     * chosen for each factor, cheapest first.
     *
     * <p>Factors whose cost is unknown ({@code null}) sort after every factor
     * with a known cost. Two costs where neither is less than the other compare
     * as equal. Both rules keep the comparison symmetric, as the
     * {@link Comparator} contract requires.
     */
    private class FactorCostComparator implements Comparator<Integer> {
        @Override
        public int compare(Integer factor1, Integer factor2) {
            RelOptCost cost1 = mq.getCumulativeCost(chosenSemiJoins[factor1]);
            RelOptCost cost2 = mq.getCumulativeCost(chosenSemiJoins[factor2]);
            if (cost1 == null) {
                return cost2 == null ? 0 : 1;
            }
            if (cost2 == null) {
                return -1;
            }
            if (cost1.isLt(cost2)) {
                return -1;
            }
            if (cost2.isLt(cost1)) {
                return 1;
            }
            return 0;
        }
    }

    /** Placeholder for a LucidDB index descriptor; carries no state. */
    public static class FemLocalIndex {
        private FemLocalIndex() {
        }
    }

    /** Placeholder for the LucidDB index optimizer; it never finds an index. */
    public static class LcsIndexOptimizer {
        LcsIndexOptimizer(LcsTableScan lcsTableScan) {
        }

        /**
         * Placeholder index search.
         *
         * @return always {@code null}
         */
        public FemLocalIndex findSemiJoinIndexByCost(RelNode relNode, List<Integer> list,
                List<Integer> list2, List<Integer> list3) {
            return null;
        }
    }

    /** Placeholder for a LucidDB column-store table. */
    public static abstract class LcsTable implements RelOptTable {
        private LcsTable() {
        }
    }

    /** Placeholder for a scan over a LucidDB column-store table. */
    public static class LcsTableScan {
        private LcsTableScan() {
        }
    }

    /** Placeholder for LucidDB special-column helpers. */
    public static class LucidDbSpecialOperators {
        private LucidDbSpecialOperators() {
        }

        /**
         * Placeholder row-id column check.
         *
         * @return always {@code false}
         */
        public static boolean isLcsRidColumnId(int i) {
            return false;
        }
    }

    /**
     * Creates an optimizer whose chosen expression for each factor starts out
     * as the factor itself.
     *
     * @param relMetadataQuery metadata used for cost and selectivity estimates
     * @param loptMultiJoin    the join whose factors are considered
     * @param rexBuilder       builder for rewritten join conditions
     */
    public LoptSemiJoinOptimizer(RelMetadataQuery relMetadataQuery, LoptMultiJoin loptMultiJoin,
            RexBuilder rexBuilder) {
        this.mq = relMetadataQuery;
        int factorCount = loptMultiJoin.getNumJoinFactors();
        this.chosenSemiJoins = new RelNode[factorCount];
        for (int i = 0; i < factorCount; i++) {
            this.chosenSemiJoins[i] = loptMultiJoin.getJoinFactor(i);
        }
        this.rexBuilder = rexBuilder;
    }

    /**
     * Rebuilds the map of candidate semijoins.
     *
     * <p>For every factor, usable join filters are grouped by the other factor
     * they reference. Each group is then turned into a candidate semijoin when
     * {@link #findSemiJoinIndexByCost} produces one. Nothing is recorded if the
     * multi-join is a full outer join. Filters are also skipped when either side
     * is null-generating.
     *
     * @param loptMultiJoin the join whose factors are considered
     */
    public void makePossibleSemiJoins(LoptMultiJoin loptMultiJoin) {
        possibleSemiJoins = new HashMap<Integer, Map<Integer, SemiJoin>>();
        if (loptMultiJoin.getMultiJoinRel().isFullOuterJoin()) {
            return;
        }

        int factorCount = loptMultiJoin.getNumJoinFactors();
        for (int factIdx = 0; factIdx < factorCount; factIdx++) {
            // Group this factor's usable equality filters by the dimension they reference.
            Map<Integer, List<RexNode>> filtersByDim = new HashMap<Integer, List<RexNode>>();
            for (RexNode joinFilter : loptMultiJoin.getJoinFilters()) {
                int dimIdx = isSuitableFilter(loptMultiJoin, joinFilter, factIdx);
                if (dimIdx == -1
                        || loptMultiJoin.isNullGenerating(factIdx)
                        || loptMultiJoin.isNullGenerating(dimIdx)) {
                    continue;
                }
                List<RexNode> dimFilters = filtersByDim.get(dimIdx);
                if (dimFilters == null) {
                    dimFilters = new ArrayList<RexNode>();
                    filtersByDim.put(dimIdx, dimFilters);
                }
                dimFilters.add(joinFilter);
            }

            Map<Integer, SemiJoin> possibleDimensions = new HashMap<Integer, SemiJoin>();
            for (Map.Entry<Integer, List<RexNode>> entry : filtersByDim.entrySet()) {
                int dimIdx = entry.getKey();
                SemiJoin semiJoin =
                    findSemiJoinIndexByCost(loptMultiJoin, entry.getValue(), factIdx, dimIdx);
                if (semiJoin != null) {
                    possibleDimensions.put(dimIdx, semiJoin);
                }
            }
            if (!possibleDimensions.isEmpty()) {
                possibleSemiJoins.put(factIdx, possibleDimensions);
            }
        }
    }

    /**
     * Decides whether a join filter can drive a semijoin for a factor.
     *
     * @param loptMultiJoin the join containing the filter
     * @param rexNode       the join filter
     * @param i             index of the factor being filtered
     * @return the index of the other factor when the filter is an equality
     *         between two column references, one of which belongs to factor
     *         {@code i}; otherwise -1
     */
    private int isSuitableFilter(LoptMultiJoin loptMultiJoin, RexNode rexNode, int i) {
        if (rexNode.getKind() != SqlKind.EQUALS) {
            return -1;
        }
        List<RexNode> operands = ((RexCall) rexNode).getOperands();
        if (!(operands.get(0) instanceof RexInputRef)
                || !(operands.get(1) instanceof RexInputRef)) {
            return -1;
        }
        ImmutableBitSet referencedFactors = loptMultiJoin.getFactorsRefByJoinFilter(rexNode);
        // Expected: the two column references come from two distinct factors.
        assert referencedFactors.cardinality() == 2;
        int firstFactor = referencedFactors.nextSetBit(0);
        int secondFactor = referencedFactors.nextSetBit(firstFactor + 1);
        if (firstFactor == i) {
            return secondFactor;
        }
        if (secondFactor == i) {
            return firstFactor;
        }
        return -1;
    }

    /**
     * Builds a candidate semijoin of fact factor {@code i} filtered by
     * dimension factor {@code i2}.
     *
     * @param loptMultiJoin the join containing both factors
     * @param list          equality filters between the two factors
     * @param i             fact factor index
     * @param i2            dimension factor index
     * @return the candidate, or {@code null} if no usable fact keys remain or
     *         no index is found
     */
    private SemiJoin findSemiJoinIndexByCost(LoptMultiJoin loptMultiJoin, List<RexNode> list,
            int i, int i2) {
        RexNode semiJoinCondition = RexUtil.composeConjunction(rexBuilder, list, true);

        // Shift the fact factor's fields so they start at 0.
        int leftAdjustment = 0;
        for (int factor = 0; factor < i; factor++) {
            leftAdjustment -= loptMultiJoin.getNumFieldsInJoinFactor(factor);
        }
        semiJoinCondition =
            adjustSemiJoinCondition(loptMultiJoin, leftAdjustment, semiJoinCondition, i, i2);

        RelNode factRel = loptMultiJoin.getJoinFactor(i);
        RelNode dimRel = loptMultiJoin.getJoinFactor(i2);
        JoinInfo joinInfo = JoinInfo.of(factRel, dimRel, semiJoinCondition);
        assert joinInfo.leftKeys.size() > 0;

        // Mutable copies: validateKeys prunes unusable key pairs from them in place.
        List<Integer> leftKeys = Lists.newArrayList(joinInfo.leftKeys);
        List<Integer> rightKeys = Lists.newArrayList(joinInfo.rightKeys);
        List<Integer> actualLeftKeys = new ArrayList<Integer>();
        LcsTable factTable = validateKeys(factRel, leftKeys, rightKeys, actualLeftKeys);
        if (factTable == null) {
            return null;
        }

        LcsTableScan tableScan =
            (LcsTableScan) factTable.toRel(ViewExpanders.simpleContext(factRel.getCluster()));
        List<Integer> bestKeyOrder = new ArrayList<Integer>();
        FemLocalIndex bestIndex = new LcsIndexOptimizer(tableScan)
            .findSemiJoinIndexByCost(dimRel, actualLeftKeys, rightKeys, bestKeyOrder);
        if (bestIndex == null) {
            return null;
        }

        List<Integer> semiJoinLeftKeys;
        List<Integer> semiJoinRightKeys;
        if (actualLeftKeys.size() == bestKeyOrder.size()) {
            semiJoinLeftKeys = leftKeys;
            semiJoinRightKeys = rightKeys;
        } else {
            // The index covers only some keys: keep those pairs and drop the other filters.
            semiJoinLeftKeys = new ArrayList<Integer>();
            semiJoinRightKeys = new ArrayList<Integer>();
            for (int keyPos : bestKeyOrder) {
                semiJoinLeftKeys.add(leftKeys.get(keyPos));
                semiJoinRightKeys.add(rightKeys.get(keyPos));
            }
            semiJoinCondition = removeExtraFilters(semiJoinLeftKeys,
                loptMultiJoin.getNumFieldsInJoinFactor(i), semiJoinCondition);
        }
        return SemiJoin.create(factRel, dimRel, semiJoinCondition,
            ImmutableIntList.copyOf(semiJoinLeftKeys), ImmutableIntList.copyOf(semiJoinRightKeys));
    }

    /**
     * Rewrites field references in a condition expressed over the whole
     * multi-join so that the left factor's fields start at 0 and the right
     * factor's fields immediately follow them.
     *
     * @param loptMultiJoin the join supplying field counts
     * @param i             amount added to the left factor's field indices
     * @param rexNode       the condition to rewrite
     * @param i2            left factor index
     * @param i3            right factor index
     * @return the rewritten condition, or {@code rexNode} itself when no shift is
     *         needed
     */
    private RexNode adjustSemiJoinCondition(LoptMultiJoin loptMultiJoin, int i, RexNode rexNode,
            int i2, int i3) {
        int rightStart = 0;
        for (int factor = 0; factor < i3; factor++) {
            rightStart += loptMultiJoin.getNumFieldsInJoinFactor(factor);
        }
        int leftFieldCount = loptMultiJoin.getNumFieldsInJoinFactor(i2);
        int rightFieldCount = loptMultiJoin.getNumFieldsInJoinFactor(i3);
        int rightAdjustment = leftFieldCount - rightStart;
        if (i == 0 && rightAdjustment == 0) {
            return rexNode;
        }

        int[] adjustments = new int[loptMultiJoin.getNumTotalFields()];
        if (i != 0) {
            int leftStart = -i;
            for (int field = leftStart; field < leftStart + leftFieldCount; field++) {
                adjustments[field] = i;
            }
        }
        if (rightAdjustment != 0) {
            for (int field = rightStart; field < rightStart + rightFieldCount; field++) {
                adjustments[field] = rightAdjustment;
            }
        }
        return rexNode.accept(new RelOptUtil.RexInputConverter(
            rexBuilder, loptMultiJoin.getMultiJoinFields(), adjustments));
    }

    /**
     * Keeps only the fact-side keys that trace back to a column of a single
     * {@link LcsTable}. Other key pairs are removed in place from both
     * {@code list} and {@code list2}.
     *
     * @param relNode the fact factor
     * @param list    fact-side key indices; pruned in place
     * @param list2   dimension-side key indices, parallel to {@code list}; pruned in place
     * @param list3   receives the origin column ordinal of each retained fact key
     * @return the origin table, or {@code null} if no key was retained
     */
    private LcsTable validateKeys(RelNode relNode, List<Integer> list, List<Integer> list2,
            List<Integer> list3) {
        int keyPos = 0;
        RelOptTable originTable = null;
        ListIterator<Integer> keyIter = list.listIterator();
        while (keyIter.hasNext()) {
            boolean removeKey = false;
            RelColumnOrigin colOrigin = mq.getColumnOrigin(relNode, keyIter.next());
            if (colOrigin == null
                    || LucidDbSpecialOperators.isLcsRidColumnId(colOrigin.getOriginColumnOrdinal())) {
                removeKey = true;
            } else {
                RelOptTable table = colOrigin.getOriginTable();
                if (originTable == null) {
                    if (table instanceof LcsTable) {
                        originTable = table;
                    } else {
                        removeKey = true;
                    }
                } else {
                    // every retained key must come from the same table
                    assert table == originTable;
                }
            }
            if (removeKey) {
                keyIter.remove();
                list2.remove(keyPos);  // index-based removal keeps the lists parallel
            } else {
                list3.add(colOrigin.getOriginColumnOrdinal());
                keyPos++;
            }
        }
        return list3.isEmpty() ? null : (LcsTable) originTable;
    }

    /**
     * Drops equality conjuncts whose fact-side column is not in {@code list}.
     *
     * @param list    fact-side key indices to keep
     * @param i       number of fields in the fact factor; references below it
     *                are fact-side
     * @param rexNode an AND of equalities between column references, or a
     *                single such equality
     * @return the remaining condition, or {@code null} if nothing remains
     */
    private RexNode removeExtraFilters(List<Integer> list, int i, RexNode rexNode) {
        assert rexNode instanceof RexCall;
        RexCall call = (RexCall) rexNode;
        if (rexNode.isA(SqlKind.AND)) {
            // The AND may be n-ary (RexUtil.composeConjunction builds a flattened
            // one), so visit every operand, not just the first two.
            List<RexNode> kept = new ArrayList<RexNode>();
            for (RexNode operand : call.getOperands()) {
                RexNode remaining = removeExtraFilters(list, i, operand);
                if (remaining != null) {
                    kept.add(remaining);
                }
            }
            // null when nothing survives; the lone survivor when exactly one does
            return RexUtil.composeConjunction(rexBuilder, kept, true);
        }

        assert call.getOperator() == SqlStdOperatorTable.EQUALS;
        List<RexNode> operands = call.getOperands();
        assert operands.get(0) instanceof RexInputRef;
        assert operands.get(1) instanceof RexInputRef;
        // Check whichever operand refers to the fact factor.
        int factField = ((RexInputRef) operands.get(0)).getIndex();
        if (factField >= i) {
            factField = ((RexInputRef) operands.get(1)).getIndex();
        }
        return list.contains(factField) ? rexNode : null;
    }

    /**
     * Applies at most one candidate semijoin.
     *
     * <p>Factors are visited cheapest first, using the cost of their current
     * chosen expression. The first factor with a candidate scoring above
     * {@link #THRESHOLD_SCORE} gets its best-scoring candidate. That candidate
     * replaces the factor's chosen expression, and the pair is removed from the
     * candidates in both directions.
     *
     * @param loptMultiJoin the join whose factors are considered
     * @return whether a semijoin was applied
     */
    public boolean chooseBestSemiJoin(LoptMultiJoin loptMultiJoin) {
        if (possibleSemiJoins.isEmpty()) {
            return false;  // nothing to choose from; skip costing every factor
        }
        int factorCount = loptMultiJoin.getNumJoinFactors();
        List<Integer> factorsByCost = factorCostOrdering.immutableSortedCopy(Util.range(factorCount));
        for (int factIdx : factorsByCost) {
            Map<Integer, SemiJoin> candidates = possibleSemiJoins.get(factIdx);
            if (candidates == null) {
                continue;
            }
            RelNode factRel = chosenSemiJoins[factIdx];
            double bestScore = 0.0;
            int bestDimIdx = -1;
            for (Map.Entry<Integer, SemiJoin> entry : candidates.entrySet()) {
                SemiJoin candidate = entry.getValue();
                if (candidate == null) {
                    continue;
                }
                int dimIdx = entry.getKey();
                double score = computeScore(factRel, chosenSemiJoins[dimIdx], candidate);
                if (score > THRESHOLD_SCORE && score > bestScore) {
                    bestDimIdx = dimIdx;
                    bestScore = score;
                }
            }
            if (bestDimIdx == -1) {
                continue;
            }

            // Rebuild on top of the current expressions, which may themselves be semijoins.
            SemiJoin template = candidates.get(bestDimIdx);
            SemiJoin chosen = SemiJoin.create(factRel, chosenSemiJoins[bestDimIdx],
                template.getCondition(), template.getLeftKeys(), template.getRightKeys());
            chosenSemiJoins[factIdx] = chosen;
            removeJoin(loptMultiJoin, chosen, factIdx, bestDimIdx);
            removePossibleSemiJoin(candidates, factIdx, bestDimIdx);
            removePossibleSemiJoin(possibleSemiJoins.get(bestDimIdx), bestDimIdx, factIdx);
            return true;
        }
        return false;
    }

    /**
     * Scores applying {@code semiJoin} to {@code relNode}. Higher is better.
     *
     * <p>The savings are {@code (1 - sqrt(selectivity)) * max(1, fact rows)},
     * doubled when the dimension keys are unique once nulls are filtered. The
     * score is those savings divided by {@code max(1, dimension rows)}.
     *
     * @return the score; 0 if the selectivity exceeds 0.5 or a required
     *         estimate is unavailable
     */
    private double computeScore(RelNode relNode, RelNode relNode2, SemiJoin semiJoin) {
        ImmutableBitSet dimKeys = ImmutableBitSet.of(semiJoin.getRightKeys());
        double selectivity = RelMdUtil.computeSemiJoinSelectivity(mq, relNode, relNode2, semiJoin);
        if (selectivity > 0.5d) {
            return 0.0d;
        }
        RelOptCost factCost = mq.getCumulativeCost(relNode);
        if (factCost == null) {
            return 0.0d;
        }
        double savings = (1.0d - Math.sqrt(selectivity)) * Math.max(1.0d, factCost.getRows());
        if (RelMdUtil.areColumnsDefinitelyUniqueWhenNullsFiltered(mq, relNode2, dimKeys)) {
            savings *= 2.0d;
        }

        // Both estimates must be known. The row count is checked before any
        // unboxing; the original unboxed it first and threw an NPE.
        Double dimRowCount = mq.getRowCount(relNode2);
        RelOptCost dimCost = mq.getCumulativeCost(relNode2);
        if (dimRowCount == null || dimCost == null) {
            return 0.0d;
        }
        return savings / Math.max(1.0d, dimCost.getRows());
    }

    /**
     * Marks dimension factor {@code i2} as removable through {@code semiJoin}
     * when its keys are unique once nulls are filtered and only key columns of
     * it are referenced by projections and join conditions. If the dimension
     * also projects nothing and each key is referenced at most once, the fact
     * factor's join reference counts for the semijoin's left keys are
     * decremented.
     *
     * <p>Does nothing if factor {@code i2} is already marked removable.
     *
     * @param loptMultiJoin the join being updated
     * @param semiJoin      the applied semijoin
     * @param i             fact factor index
     * @param i2            dimension factor index
     */
    private void removeJoin(LoptMultiJoin loptMultiJoin, SemiJoin semiJoin, int i, int i2) {
        if (loptMultiJoin.getJoinRemovalFactor(i2) != null) {
            return;
        }
        ImmutableBitSet dimKeys = ImmutableBitSet.of(semiJoin.getRightKeys());
        if (!RelMdUtil.areColumnsDefinitelyUniqueWhenNullsFiltered(
                mq, loptMultiJoin.getJoinFactor(i2), dimKeys)) {
            return;
        }

        ImmutableBitSet dimProjRefs = loptMultiJoin.getProjFields(i2);
        if (dimProjRefs == null) {
            dimProjRefs = ImmutableBitSet.range(0, loptMultiJoin.getNumFieldsInJoinFactor(i2));
        }
        if (!dimKeys.contains(dimProjRefs)) {
            return;
        }
        int[] dimJoinRefCounts = loptMultiJoin.getJoinFieldRefCounts(i2);
        for (int field = 0; field < dimJoinRefCounts.length; field++) {
            if (dimJoinRefCounts[field] > 0 && !dimKeys.get(field)) {
                return;
            }
        }

        loptMultiJoin.setJoinRemovalFactor(i2, i);
        loptMultiJoin.setJoinRemovalSemiJoin(i2, semiJoin);

        if (dimProjRefs.cardinality() != 0) {
            return;
        }
        for (int field = 0; field < dimJoinRefCounts.length; field++) {
            if (dimJoinRefCounts[field] > 1) {
                return;
            }
            if (dimJoinRefCounts[field] == 1 && !dimKeys.get(field)) {
                return;
            }
        }
        int[] factJoinRefCounts = loptMultiJoin.getJoinFieldRefCounts(i);
        for (int key : semiJoin.getLeftKeys()) {
            factJoinRefCounts[key]--;
        }
    }

    /**
     * Removes candidate {@code num2} from {@code map} (the candidates of factor
     * {@code num}). Drops {@code num}'s entry entirely once it has no
     * candidates left. A {@code null} map is ignored.
     */
    private void removePossibleSemiJoin(Map<Integer, SemiJoin> map, Integer num, Integer num2) {
        if (map == null) {
            return;
        }
        map.remove(num2);
        if (map.isEmpty()) {
            possibleSemiJoins.remove(num);
        } else {
            possibleSemiJoins.put(num, map);
        }
    }

    /**
     * Returns the expression currently chosen for factor {@code i}: either the
     * original factor or a semijoin applied to it.
     */
    public RelNode getChosenSemiJoin(int i) {
        return chosenSemiJoins[i];
    }
}
