package org.apache.kylin.query.util;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import org.apache.commons.collections.CollectionUtils;
import org.apache.kylin.common.util.Pair;
import org.apache.kylin.guava30.shaded.common.collect.Lists;
import org.apache.kylin.guava30.shaded.common.collect.Maps;
import org.apache.kylin.guava30.shaded.common.collect.Sets;
import org.apache.kylin.job.shaded.com.google.common.collect.ImmutableList;
import org.apache.kylin.job.shaded.com.google.common.collect.UnmodifiableIterator;
import org.apache.kylin.job.shaded.org.apache.calcite.plan.RelOptUtil;
import org.apache.kylin.job.shaded.org.apache.calcite.plan.hep.HepRelVertex;
import org.apache.kylin.job.shaded.org.apache.calcite.plan.volcano.RelSubset;
import org.apache.kylin.job.shaded.org.apache.calcite.rel.RelNode;
import org.apache.kylin.job.shaded.org.apache.calcite.rel.core.Aggregate;
import org.apache.kylin.job.shaded.org.apache.calcite.rel.core.AggregateCall;
import org.apache.kylin.job.shaded.org.apache.calcite.rel.core.Project;
import org.apache.kylin.job.shaded.org.apache.calcite.rel.type.RelDataType;
import org.apache.kylin.job.shaded.org.apache.calcite.rex.RexCall;
import org.apache.kylin.job.shaded.org.apache.calcite.rex.RexInputRef;
import org.apache.kylin.job.shaded.org.apache.calcite.rex.RexLiteral;
import org.apache.kylin.job.shaded.org.apache.calcite.rex.RexNode;
import org.apache.kylin.job.shaded.org.apache.calcite.rex.RexVisitorImpl;
import org.apache.kylin.job.shaded.org.apache.calcite.sql.SqlKind;
import org.apache.kylin.job.shaded.org.apache.calcite.util.ImmutableBitSet;
import org.apache.kylin.query.exception.SumExprUnSupportException;

/* loaded from: input_file:org/apache/kylin/query/util/AggExpressionUtil.class */
public class AggExpressionUtil {

    /**
     * Metadata collected for a single {@link AggregateCall} of an {@link Aggregate} that sits on top of a
     * {@link Project}.
     *
     * <p>Besides the aggregate call itself, it records:
     * <ul>
     * <li>the project expression feeding the call's argument;</li>
     * <li>flags describing which special shape was detected (sum-case-when, count-distinct-case-when,
     * count, sum of a constant);</li>
     * <li>for case-when shapes, the WHEN conditions and THEN/ELSE values;</li>
     * <li>several {@code int[]} index arrays.</li>
     * </ul>
     * The index arrays other than {@code bottomProjInput} / {@code bottomProjConditionsInput} are created
     * filled with {@code -1}. Presumably they are filled in by the caller while rewriting the plan; confirm
     * against the callers.
     */
    public static class AggExpression {
        private final AggregateCall agg;
        private RexNode expression;
        private RelDataType type;
        private List<RexNode> conditions;
        private int[] bottomProjConditionsInput;
        private int[] bottomAggConditionsInput;
        private int[] topProjConditionsInput;
        private List<RexNode> valuesList;
        private int[] bottomAggValuesInput;
        private int[] topProjValuesInput;
        private int[] bottomProjInput;
        private int[] bottomAggInput;
        private int[] topProjInput;
        private int[] topAggInput;
        private boolean isSumCase = false;
        private boolean isCountDistinctCase = false;
        private boolean isCount = false;
        private boolean isSumConst = false;

        public AggExpression(AggregateCall aggregateCall) {
            this.agg = aggregateCall;
        }

        public AggregateCall getAggCall() {
            return this.agg;
        }

        public RelDataType getType() {
            return this.type;
        }

        public void setType(RelDataType relDataType) {
            this.type = relDataType;
        }

        public void setExpression(RexNode rexNode) {
            this.expression = rexNode;
        }

        public RexNode getExpression() {
            return this.expression;
        }

        public void setSumCase() {
            this.isSumCase = true;
        }

        public boolean isSumCase() {
            return this.isSumCase;
        }

        public boolean isCountDistinctCase() {
            return this.isCountDistinctCase;
        }

        public void setCountDistinctCase(boolean z) {
            this.isCountDistinctCase = z;
        }

        /**
         * Stores the WHEN conditions and derives the condition index arrays.
         *
         * <p>{@code bottomProjConditionsInput} receives the sorted distinct input ordinals referenced by the
         * conditions. The bottom-agg and top-project arrays are sized to match and filled with {@code -1}.
         */
        public void setConditionsList(List<RexNode> list) {
            this.conditions = list;
            this.bottomProjConditionsInput = RelOptUtil.InputFinder.bits(this.conditions, null).toArray();
            this.bottomAggConditionsInput = AggExpressionUtil.newArray(this.bottomProjConditionsInput.length);
            this.topProjConditionsInput = AggExpressionUtil.newArray(this.bottomProjConditionsInput.length);
        }

        public List<RexNode> getConditions() {
            return this.conditions;
        }

        public int[] getBottomProjConditionsInput() {
            return this.bottomProjConditionsInput;
        }

        public int[] getBottomAggConditionsInput() {
            return this.bottomAggConditionsInput;
        }

        public int[] getTopProjConditionsInput() {
            return this.topProjConditionsInput;
        }

        /**
         * Stores the THEN/ELSE values and allocates one {@code -1}-filled slot per value in the bottom-agg
         * and top-project value index arrays.
         */
        public void setValuesList(List<RexNode> list) {
            this.valuesList = list;
            this.bottomAggValuesInput = AggExpressionUtil.newArray(list.size());
            this.topProjValuesInput = AggExpressionUtil.newArray(list.size());
        }

        public List<RexNode> getValuesList() {
            return this.valuesList;
        }

        public int[] getBottomAggValuesInput() {
            return this.bottomAggValuesInput;
        }

        public int[] getTopProjValuesInput() {
            return this.topProjValuesInput;
        }

        public void setCount() {
            this.isCount = true;
        }

        public boolean isCount() {
            return this.isCount;
        }

        public int[] getBottomProjInput() {
            return this.bottomProjInput;
        }

        public int[] getBottomAggInput() {
            return this.bottomAggInput;
        }

        public int[] getTopProjInput() {
            return this.topProjInput;
        }

        public int[] getTopAggInput() {
            return this.topAggInput;
        }

        public void setSumConst() {
            this.isSumConst = true;
        }

        public boolean isSumConst() {
            return this.isSumConst;
        }
    }

    /**
     * Visitor that records every {@link RexInputRef} and {@link RexLiteral} it encounters. It is used to
     * decide whether an expression consists of literals only.
     */
    public static class CollectRexVisitor extends RexVisitorImpl<RexCall> {
        private final Set<RexInputRef> rexInputRefs;
        private final Set<RexLiteral> rexLiterals;

        CollectRexVisitor(boolean z) {
            super(z);
            this.rexInputRefs = Sets.newHashSet();
            this.rexLiterals = Sets.newHashSet();
        }

        @Override
        public RexCall visitInputRef(RexInputRef rexInputRef) {
            this.rexInputRefs.add(rexInputRef);
            return null;
        }

        @Override
        public RexCall visitLiteral(RexLiteral rexLiteral) {
            this.rexLiterals.add(rexLiteral);
            return null;
        }

        /** Returns true when no input reference was seen but at least one literal was. */
        boolean isRexLiteral() {
            return CollectionUtils.isEmpty(this.rexInputRefs) && CollectionUtils.isNotEmpty(this.rexLiterals);
        }
    }

    /**
     * Metadata for one group-by key.
     *
     * <p>It records the project expression for the key and whether that expression is literal-only. It also
     * carries the index arrays used while rewriting. All arrays except {@code bottomProjInput} are created
     * filled with {@code -1} by {@link AggExpressionUtil#collectGroupExprAndGroup}.
     */
    public static class GroupExpression {
        private RexNode expression;
        private boolean isLiteral = false;
        private int[] bottomProjInput;
        private int[] bottomAggInput;
        private int[] topProjInput;
        private int[] topAggInput;

        public RexNode getExpression() {
            return this.expression;
        }

        public void setExpression(RexNode rexNode) {
            this.expression = rexNode;
        }

        public boolean isLiteral() {
            return this.isLiteral;
        }

        public void setLiteral() {
            this.isLiteral = true;
        }

        public int[] getBottomProjInput() {
            return this.bottomProjInput;
        }

        public void setBottomProjInput(int[] iArr) {
            this.bottomProjInput = iArr;
        }

        public int[] getBottomAggInput() {
            return this.bottomAggInput;
        }

        public void setBottomAggInput(int[] iArr) {
            this.bottomAggInput = iArr;
        }

        public int[] getTopProjInput() {
            return this.topProjInput;
        }

        public void setTopProjInput(int[] iArr) {
            this.topProjInput = iArr;
        }

        public int[] getTopAggInput() {
            return this.topAggInput;
        }

        public void setTopAggInput(int[] iArr) {
            this.topAggInput = iArr;
        }
    }

    /** Visitor that collects the distinct {@link RexInputRef}s referenced by an expression. */
    public static class InputRefCapacity extends RexVisitorImpl<RexCall> {
        Set<RexInputRef> rexInputRefs;

        protected InputRefCapacity(boolean z) {
            super(z);
            this.rexInputRefs = Sets.newHashSet();
        }

        @Override
        public RexCall visitInputRef(RexInputRef rexInputRef) {
            this.rexInputRefs.add(rexInputRef);
            return null;
        }

        /** Returns a new mutable list holding the collected input references (no particular order). */
        public List<RexInputRef> getRexInputRef() {
            return Lists.newArrayList(this.rexInputRefs);
        }
    }

    private AggExpressionUtil() {
        throw new IllegalStateException("Utility class");
    }

    /**
     * Checks whether an {@link Aggregate} appears somewhere along the first-input chain below {@code relNode}.
     *
     * <p>At each level, {@link HepRelVertex} is unwrapped to its current rel and {@link RelSubset} to its
     * original rel before testing.
     *
     * @param relNode node whose first-input chain is searched; may be null
     * @return true if an {@link Aggregate} is found along the chain; false on null or when no input is left
     */
    public static boolean hasAggInput(RelNode relNode) {
        if (relNode == null || relNode.getInputs().isEmpty()) {
            return false;
        }
        RelNode input = relNode.getInput(0);
        if (input instanceof HepRelVertex) {
            input = ((HepRelVertex) input).getCurrentRel();
        }
        if (input instanceof RelSubset) {
            input = ((RelSubset) input).getOriginal();
        }
        // A null input simply ends the search via the null check of the recursive call.
        return input instanceof Aggregate || hasAggInput(input);
    }

    /**
     * Returns true for the non-distinct aggregate kinds this utility handles: SUM, SUM0, COUNT, MAX and MIN.
     */
    public static boolean supportAggregateFunction(AggregateCall aggregateCall) {
        if (aggregateCall.isDistinct()) {
            return false;
        }
        SqlKind kind = aggregateCall.getAggregation().getKind();
        return SqlKind.SUM == kind || SqlKind.SUM0 == kind || SqlKind.COUNT == kind || SqlKind.MAX == kind
                || SqlKind.MIN == kind;
    }

    /** Returns true when the call is SUM/SUM0 and its argument expression is a CASE. */
    public static boolean hasSumCaseWhen(AggregateCall aggregateCall, RexNode rexNode) {
        return isSum(aggregateCall.getAggregation().getKind()) && SqlKind.CASE == rexNode.getKind();
    }

    /** Returns true when the call is COUNT(DISTINCT ...) and its argument expression is a CASE. */
    public static boolean hasCountDistinctCaseWhen(AggregateCall aggregateCall, RexNode rexNode) {
        return aggregateCall.getAggregation().getKind() == SqlKind.COUNT && aggregateCall.isDistinct()
                && rexNode.getKind() == SqlKind.CASE;
    }

    /** Returns true for {@link SqlKind#SUM} and {@link SqlKind#SUM0}. */
    public static boolean isSum(SqlKind sqlKind) {
        return SqlKind.SUM == sqlKind || SqlKind.SUM0 == sqlKind;
    }

    /**
     * Builds one {@link AggExpression} per aggregate call of {@code aggregate}, in call order.
     *
     * <p>Each call's single argument is resolved against {@code project}'s expressions. The result is then
     * classified as one of:
     * <ul>
     * <li>sum-case-when;</li>
     * <li>count-distinct-case-when;</li>
     * <li>sum of an expression that references no input (constant);</li>
     * <li>plain.</li>
     * </ul>
     *
     * @param aggregate the aggregate whose calls are collected
     * @param project   the project directly feeding {@code aggregate}; argument ordinals index its expressions
     * @return a new mutable list of collected expressions
     * @throws SumExprUnSupportException if a call has more than one argument, or a detected CASE expression
     *                                   has a malformed operand list
     */
    public static List<AggExpression> collectSumExpressions(Aggregate aggregate, Project project) {
        List<AggExpression> aggExpressions = Lists.newArrayList();
        for (AggregateCall aggCall : aggregate.getAggCallList()) {
            assertCondition(aggCall.getArgList().size() <= 1, "Only support aggregate with 0 or 1 argument");
            AggExpression aggExpression = new AggExpression(aggCall);
            aggExpressions.add(aggExpression);
            if (SqlKind.COUNT == aggCall.getAggregation().getKind()) {
                aggExpression.setCount();
            }
            if (aggCall.getArgList().isEmpty()) {
                // Zero-argument aggregate (e.g. COUNT(*)): no project input and no bottom-agg slot.
                setInputs(aggExpression, newArray(0), 0);
                continue;
            }

            int argIndex = aggCall.getArgList().get(0);
            RexNode argExpr = project.getChildExps().get(argIndex);
            int[] argInputs = RelOptUtil.InputFinder.bits(argExpr).toArray();
            aggExpression.setExpression(argExpr);

            if (hasSumCaseWhen(aggCall, argExpr)) {
                aggExpression.setSumCase();
                setCaseWhenParts(aggExpression, argExpr);
                setInputs(aggExpression, argInputs, 1);
            } else if (hasCountDistinctCaseWhen(aggCall, argExpr)) {
                aggExpression.setCountDistinctCase(true);
                setCaseWhenParts(aggExpression, argExpr);
                setInputs(aggExpression, argInputs, 1);
            } else if (isSum(aggCall.getAggregation().getKind()) && argInputs.length == 0) {
                // SUM over an expression that references no input column.
                aggExpression.setSumConst();
                setInputs(aggExpression, newArray(0), 1);
            } else {
                setInputs(aggExpression, argInputs, 1);
            }
        }
        return aggExpressions;
    }

    /**
     * Fills the four per-call index arrays of an {@link AggExpression}.
     *
     * <p>{@code bottomProjInput} is taken as given. {@code bottomAggInput} gets {@code bottomAggWidth}
     * {@code -1} slots. The top-project and top-agg arrays each get one {@code -1} slot.
     */
    private static void setInputs(AggExpression aggExpression, int[] bottomProjInput, int bottomAggWidth) {
        aggExpression.bottomProjInput = bottomProjInput;
        aggExpression.bottomAggInput = newArray(bottomAggWidth);
        aggExpression.topProjInput = newArray(1);
        aggExpression.topAggInput = newArray(1);
    }

    /** Splits a CASE expression into its WHEN conditions and THEN/ELSE values on the given expression. */
    private static void setCaseWhenParts(AggExpression aggExpression, RexNode caseExpr) {
        aggExpression.setConditionsList(extractCaseWhenConditions(caseExpr));
        aggExpression.setValuesList(extractCaseThenElseValues(caseExpr));
    }

    /**
     * Casts {@code rexNode} to a {@link RexCall} and checks its operand layout is
     * {@code when1, then1, ..., whenN, thenN, else}: an odd operand count of at least 3.
     *
     * @throws SumExprUnSupportException if either check fails
     */
    private static RexCall asCaseWhenCall(RexNode rexNode) {
        assertCondition(rexNode instanceof RexCall, rexNode + " is not a case-when expression");
        RexCall rexCall = (RexCall) rexNode;
        int size = rexCall.getOperands().size();
        assertCondition(size > 2 && size % 2 == 1, "case-when operands mismatch");
        return rexCall;
    }

    /** Returns the WHEN operands (even positions, excluding the trailing ELSE) of a CASE call. */
    private static List<RexNode> extractCaseWhenConditions(RexNode rexNode) {
        List<RexNode> operands = asCaseWhenCall(rexNode).getOperands();
        List<RexNode> conditions = Lists.newArrayList();
        for (int i = 0; i < operands.size() - 1; i += 2) {
            conditions.add(operands.get(i));
        }
        return conditions;
    }

    /** Returns the THEN operands (odd positions) followed by the trailing ELSE operand of a CASE call. */
    private static List<RexNode> extractCaseThenElseValues(RexNode rexNode) {
        List<RexNode> operands = asCaseWhenCall(rexNode).getOperands();
        int size = operands.size();
        List<RexNode> values = Lists.newArrayList();
        for (int i = 1; i < size - 1; i += 2) {
            values.add(operands.get(i));
        }
        values.add(operands.get(size - 1));
        return values;
    }

    /**
     * Builds one {@link GroupExpression} per group key of {@code aggregate}, in ascending key order.
     *
     * <p>When the aggregate has more than one grouping set, it also remaps every grouping set from
     * project ordinals to positions in the returned list.
     *
     * @return a pair of (group expressions, remapped grouping sets); the second element is {@code null}
     *         when {@code aggregate.getGroupSets()} has at most one entry
     */
    public static Pair<List<GroupExpression>, ImmutableList<ImmutableBitSet>> collectGroupExprAndGroup(
            Aggregate aggregate, Project project) {
        List<GroupExpression> groupExpressions = Lists.newArrayListWithCapacity(aggregate.getGroupCount());
        Map<Integer, Integer> groupIndexToPosition = Maps.newHashMap();
        for (int groupIndex : aggregate.getGroupSet()) {
            RexNode rexNode = project.getChildExps().get(groupIndex);
            CollectRexVisitor visitor = new CollectRexVisitor(true);
            rexNode.accept(visitor);

            GroupExpression groupExpression = new GroupExpression();
            if (visitor.isRexLiteral()) {
                groupExpression.setLiteral();
            }
            int[] inputs = RelOptUtil.InputFinder.bits(rexNode).toArray();
            groupExpression.setExpression(rexNode);
            groupExpression.setBottomProjInput(inputs);
            groupExpression.setBottomAggInput(newArray(inputs.length));
            groupExpression.setTopProjInput(newArray(inputs.length));
            groupExpression.setTopAggInput(newArray(1));

            groupIndexToPosition.put(groupIndex, groupExpressions.size());
            groupExpressions.add(groupExpression);
        }
        return Pair.newPair(groupExpressions,
                collectGroupSets(aggregate.getGroupSets(), groupIndexToPosition).orElse(null));
    }

    /**
     * Rewrites each grouping set by mapping its bits through {@code groupIndexToPosition}.
     *
     * @return empty when there is at most one grouping set; otherwise the remapped sets in original order
     */
    private static Optional<ImmutableList<ImmutableBitSet>> collectGroupSets(ImmutableList<ImmutableBitSet> groupSets,
            Map<Integer, Integer> groupIndexToPosition) {
        if (groupSets.size() <= 1) {
            return Optional.empty();
        }
        List<ImmutableBitSet> remapped = Lists.newArrayListWithCapacity(groupSets.size());
        for (ImmutableBitSet groupSet : groupSets) {
            ImmutableBitSet.Builder builder = ImmutableBitSet.builder();
            for (int groupIndex : groupSet) {
                builder.set(groupIndexToPosition.get(groupIndex));
            }
            remapped.add(builder.build());
        }
        return Optional.of(ImmutableList.copyOf(remapped));
    }

    /**
     * Throws {@link SumExprUnSupportException} with {@code str} as the message when {@code z} is false.
     */
    public static void assertCondition(boolean z, String str) {
        if (!z) {
            throw new SumExprUnSupportException(str);
        }
    }

    /**
     * Builds an offset table with {@code adjustments[src[i]] = dst[i] - src[i]}.
     *
     * <p>The table length is {@code max(src) + 1}, or 1 when {@code src} is empty. Positions not present
     * in {@code src} stay 0. If {@code src} repeats a value, the last occurrence wins.
     *
     * @param iArr source ordinals; must be non-negative, otherwise indexing throws
     *             {@link ArrayIndexOutOfBoundsException}
     * @param iArr2 target ordinals, parallel to {@code iArr}
     * @throws SumExprUnSupportException if the two arrays differ in length
     */
    public static int[] generateAdjustments(int[] iArr, int[] iArr2) {
        assertCondition(iArr.length == iArr2.length, "Failed to generate adjustments");
        int[] adjustments = new int[Arrays.stream(iArr).max().orElse(0) + 1];
        for (int i = 0; i < iArr.length; i++) {
            int source = iArr[i];
            adjustments[source] = iArr2[i] - source;
        }
        return adjustments;
    }

    /**
     * Returns a new array of length {@code i} with every element set to {@code -1}.
     *
     * <p>Presumably {@code -1} marks a slot not yet assigned.
     */
    public static int[] newArray(int i) {
        int[] filled = new int[i];
        Arrays.fill(filled, -1);
        return filled;
    }
}
