/*
 * Decompiled with CFR 0.152.
 */
package org.apache.kylin.query.util;

import com.google.common.collect.ImmutableList;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import org.apache.calcite.plan.RelOptUtil;
import org.apache.calcite.plan.hep.HepRelVertex;
import org.apache.calcite.plan.volcano.RelSubset;
import org.apache.calcite.rel.RelNode;
import org.apache.calcite.rel.core.Aggregate;
import org.apache.calcite.rel.core.AggregateCall;
import org.apache.calcite.rel.core.Project;
import org.apache.calcite.rel.type.RelDataType;
import org.apache.calcite.rex.RexCall;
import org.apache.calcite.rex.RexInputRef;
import org.apache.calcite.rex.RexLiteral;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.rex.RexVisitor;
import org.apache.calcite.rex.RexVisitorImpl;
import org.apache.calcite.sql.SqlKind;
import org.apache.calcite.util.ImmutableBitSet;
import org.apache.commons.collections.CollectionUtils;
import org.apache.kylin.common.util.Pair;
import org.apache.kylin.guava30.shaded.common.collect.Lists;
import org.apache.kylin.guava30.shaded.common.collect.Maps;
import org.apache.kylin.guava30.shaded.common.collect.Sets;
import org.apache.kylin.query.exception.SumExprUnSupportException;

/**
 * Static helpers used when rewriting an {@link Aggregate} that sits directly on top of a {@link Project}.
 *
 * <p>{@link #collectSumExpressions} and {@link #collectGroupExprAndGroup} describe every aggregate call
 * and every group key of such a pair as an {@link AggExpression} or a {@link GroupExpression}. Each
 * descriptor keeps the project expression it came from together with several {@code int[]} index arrays.
 * The arrays are created here either from the expression's input references or as {@code -1}-filled
 * placeholders (see {@code newArray}). They are exposed through getters without copying, presumably so
 * that the calling rule can fill the placeholders in place (TODO confirm against callers).
 *
 * <p>Shapes that cannot be handled are rejected through {@link #assertCondition}, which throws
 * {@link SumExprUnSupportException}.
 */
public class AggExpressionUtil {
    private AggExpressionUtil() {
        throw new IllegalStateException("Utility class");
    }

    /**
     * Checks whether an {@link Aggregate} appears on the chain of first inputs below {@code current}.
     *
     * <p>The walk starts at {@code current} and repeatedly follows input 0. Planner wrappers are looked
     * through before each test: a {@link HepRelVertex} is replaced by its current rel, and a
     * {@link RelSubset} by its original rel. The node {@code current} itself is never tested.
     *
     * @param current node whose input chain is inspected; may be {@code null}
     * @return {@code true} if an aggregate is found on the first-input chain, otherwise {@code false}
     */
    public static boolean hasAggInput(RelNode current) {
        // Iterative instead of recursive, so very deep plans cannot overflow the stack.
        RelNode node = current;
        while (node != null && !node.getInputs().isEmpty()) {
            RelNode input = node.getInput(0);
            if (input instanceof HepRelVertex) {
                input = ((HepRelVertex) input).getCurrentRel();
            }
            if (input instanceof RelSubset) {
                input = ((RelSubset) input).getOriginal();
            }
            if (input instanceof Aggregate) {
                return true;
            }
            node = input;
        }
        return false;
    }

    /**
     * Tells whether an aggregate call is one of the supported kinds: non-distinct SUM, SUM0, COUNT, MAX
     * or MIN.
     *
     * @param call aggregate call to check
     * @return {@code false} for any DISTINCT call or any other aggregate kind
     */
    public static boolean supportAggregateFunction(AggregateCall call) {
        if (call.isDistinct()) {
            return false;
        }
        SqlKind kind = call.getAggregation().getKind();
        return isSum(kind) || SqlKind.COUNT == kind || SqlKind.MAX == kind || SqlKind.MIN == kind;
    }

    /**
     * Tells whether {@code call} is a SUM/SUM0 whose argument {@code expression} is a CASE expression.
     *
     * @param call       aggregate call
     * @param expression project expression feeding the call's argument
     * @return {@code true} for {@code SUM(CASE ...)} or {@code SUM0(CASE ...)}
     */
    public static boolean hasSumCaseWhen(AggregateCall call, RexNode expression) {
        return isSum(call.getAggregation().getKind()) && SqlKind.CASE == expression.getKind();
    }

    /**
     * Tells whether {@code call} is {@code COUNT(DISTINCT CASE ...)}.
     *
     * @param call       aggregate call
     * @param expression project expression feeding the call's argument
     * @return {@code true} when the call is a distinct COUNT over a CASE expression
     */
    public static boolean hasCountDistinctCaseWhen(AggregateCall call, RexNode expression) {
        return call.getAggregation().getKind() == SqlKind.COUNT && call.isDistinct()
                && expression.getKind() == SqlKind.CASE;
    }

    /**
     * Tells whether {@code kind} is SUM or SUM0.
     *
     * @param kind SQL kind to test
     * @return {@code true} for {@link SqlKind#SUM} and {@link SqlKind#SUM0}
     */
    public static boolean isSum(SqlKind kind) {
        return SqlKind.SUM == kind || SqlKind.SUM0 == kind;
    }

    /**
     * Builds one {@link AggExpression} per aggregate call of {@code oldAgg}, in call order.
     *
     * <p>Every call must have at most one argument. For a call with an argument, the argument is
     * resolved to the corresponding expression of {@code oldProject}, and that expression's input
     * references become {@code bottomProjInput}. The other index arrays are one-slot {@code -1}
     * placeholders. The call is then classified as one of the following:
     * <ul>
     *   <li>COUNT, whenever the aggregate kind is COUNT (this flag is independent of the others);</li>
     *   <li>SUM over CASE (conditions and values are recorded);</li>
     *   <li>COUNT DISTINCT over CASE (conditions and values are recorded);</li>
     *   <li>SUM over an expression without input references ("sum const").</li>
     * </ul>
     * A call without arguments gets empty {@code bottomProjInput}/{@code bottomAggInput} arrays and no
     * expression.
     *
     * @param oldAgg     aggregate whose calls are described
     * @param oldProject project directly below {@code oldAgg}, used to resolve argument indexes
     * @return a new mutable list of descriptors
     * @throws SumExprUnSupportException if a call has more than one argument or a CASE expression is
     *                                   malformed
     */
    public static List<AggExpression> collectSumExpressions(Aggregate oldAgg, Project oldProject) {
        List<AggExpression> aggExpressions = Lists.newArrayList();
        for (AggregateCall call : oldAgg.getAggCallList()) {
            assertCondition(call.getArgList().size() <= 1, "Only support aggregate with 0 or 1 argument");
            AggExpression aggExpression = new AggExpression(call);
            aggExpressions.add(aggExpression);
            if (SqlKind.COUNT == call.getAggregation().getKind()) {
                aggExpression.setCount();
            }
            if (call.getArgList().isEmpty()) {
                // Argument-less call (such as COUNT(*)): nothing is read from the project below.
                aggExpression.setInputs(newArray(0), newArray(0), newArray(1), newArray(1));
                continue;
            }
            int input = call.getArgList().get(0);
            RexNode expression = oldProject.getChildExps().get(input);
            int[] sourceInput = RelOptUtil.InputFinder.bits(expression).toArray();
            aggExpression.setExpression(expression);
            if (hasSumCaseWhen(call, expression)) {
                aggExpression.setSumCase();
                recordCaseWhen(aggExpression, expression);
            } else if (hasCountDistinctCaseWhen(call, expression)) {
                aggExpression.setCountDistinctCase(true);
                recordCaseWhen(aggExpression, expression);
            } else if (isSum(call.getAggregation().getKind()) && sourceInput.length == 0) {
                aggExpression.setSumConst();
            }
            // For a "sum const" call, sourceInput is empty, so bottomProjInput is empty as well.
            aggExpression.setInputs(sourceInput, newArray(1), newArray(1), newArray(1));
        }
        return aggExpressions;
    }

    /** Stores the WHEN conditions and THEN/ELSE values of {@code caseWhenExpr} on {@code aggExpression}. */
    private static void recordCaseWhen(AggExpression aggExpression, RexNode caseWhenExpr) {
        aggExpression.setConditionsList(extractCaseWhenConditions(caseWhenExpr));
        aggExpression.setValuesList(extractCaseThenElseValues(caseWhenExpr));
    }

    /**
     * Validates that {@code caseWhenExpr} is a call with an odd operand count of at least 3. That
     * count fits the layout WHEN/THEN pairs followed by one ELSE.
     *
     * @return the call's operands
     * @throws SumExprUnSupportException if the expression is not a {@link RexCall} or the operand count
     *                                   does not fit
     */
    private static List<RexNode> caseWhenOperands(RexNode caseWhenExpr) {
        assertCondition(caseWhenExpr instanceof RexCall, caseWhenExpr + " is not a case-when expression");
        List<RexNode> operands = ((RexCall) caseWhenExpr).getOperands();
        int operandsCnt = operands.size();
        assertCondition(operandsCnt > 2 && operandsCnt % 2 == 1, "case-when operands mismatch");
        return operands;
    }

    /** Returns the WHEN conditions (operands 0, 2, 4, ... excluding the final ELSE) of a CASE call. */
    private static List<RexNode> extractCaseWhenConditions(RexNode caseWhenExpr) {
        List<RexNode> operands = caseWhenOperands(caseWhenExpr);
        List<RexNode> conditions = Lists.newArrayList();
        for (int i = 0; i < operands.size() - 1; i += 2) {
            conditions.add(operands.get(i));
        }
        return conditions;
    }

    /** Returns the THEN values (operands 1, 3, 5, ...) of a CASE call followed by its ELSE value. */
    private static List<RexNode> extractCaseThenElseValues(RexNode caseWhenExpr) {
        List<RexNode> operands = caseWhenOperands(caseWhenExpr);
        List<RexNode> values = Lists.newArrayList();
        int elseIndex = operands.size() - 1;
        for (int i = 1; i < elseIndex; i += 2) {
            values.add(operands.get(i));
        }
        values.add(operands.get(elseIndex));
        return values;
    }

    /**
     * Builds one {@link GroupExpression} per group key of {@code oldAgg}, in ascending key order. It also
     * remaps the aggregate's grouping sets onto the positions of those descriptors.
     *
     * <p>For each key, {@code bottomProjInput} holds the input references of the project expression. The
     * bottom-aggregate and top-project arrays are {@code -1} placeholders of the same length, and the
     * top-aggregate array is a single {@code -1} placeholder. A key whose expression contains literals
     * but no input references is flagged as a literal.
     *
     * @param oldAgg     aggregate whose group keys are described
     * @param oldProject project directly below {@code oldAgg}, used to resolve group key indexes
     * @return the group descriptors, paired with the remapped grouping sets; the second element is
     *         {@code null} when the aggregate has at most one grouping set
     */
    public static Pair<List<GroupExpression>, ImmutableList<ImmutableBitSet>> collectGroupExprAndGroup(
            Aggregate oldAgg, Project oldProject) {
        List<GroupExpression> groupExpressions = Lists.newArrayListWithCapacity(oldAgg.getGroupCount());
        // Maps an old group key (project output index) to the position of its descriptor.
        Map<Integer, Integer> old2new = Maps.newHashMap();
        for (int groupBy : oldAgg.getGroupSet()) {
            RexNode projectExpr = oldProject.getChildExps().get(groupBy);
            CollectRexVisitor visitor = new CollectRexVisitor(true);
            projectExpr.accept(visitor);
            GroupExpression groupExpr = new GroupExpression();
            if (visitor.isRexLiteral()) {
                groupExpr.setLiteral();
            }
            int[] sourceInput = RelOptUtil.InputFinder.bits(projectExpr).toArray();
            groupExpr.setExpression(projectExpr);
            groupExpr.setBottomProjInput(sourceInput);
            groupExpr.setBottomAggInput(newArray(sourceInput.length));
            groupExpr.setTopProjInput(newArray(sourceInput.length));
            groupExpr.setTopAggInput(newArray(1));
            old2new.put(groupBy, groupExpressions.size());
            groupExpressions.add(groupExpr);
        }
        ImmutableList<ImmutableBitSet> groupSets = collectGroupSets(oldAgg.getGroupSets(), old2new).orElse(null);
        return Pair.newPair(groupExpressions, groupSets);
    }

    /**
     * Rewrites each grouping set through {@code old2new}.
     *
     * @return empty when there is at most one grouping set, otherwise the remapped sets in original order
     */
    private static Optional<ImmutableList<ImmutableBitSet>> collectGroupSets(List<ImmutableBitSet> oldGroupSets,
            Map<Integer, Integer> old2new) {
        if (oldGroupSets.size() <= 1) {
            return Optional.empty();
        }
        List<ImmutableBitSet> groupSets = Lists.newArrayListWithCapacity(oldGroupSets.size());
        for (ImmutableBitSet set : oldGroupSets) {
            ImmutableBitSet.Builder builder = ImmutableBitSet.builder();
            for (int oldIndex : set) {
                builder.set(old2new.get(oldIndex).intValue());
            }
            groupSets.add(builder.build());
        }
        return Optional.of(ImmutableList.copyOf(groupSets));
    }

    /**
     * Throws {@link SumExprUnSupportException} with {@code errorMsg} when {@code condition} is false.
     *
     * @param condition condition that must hold
     * @param errorMsg  message of the thrown exception
     */
    public static void assertCondition(boolean condition, String errorMsg) {
        if (!condition) {
            throw new SumExprUnSupportException(errorMsg);
        }
    }

    /**
     * Builds an offset table that moves index {@code src[i]} to {@code dst[i]}.
     *
     * <p>The result has length {@code max(src) + 1} (length 1 for an empty {@code src}). Slot
     * {@code src[i]} holds {@code dst[i] - src[i]}, and slots not named in {@code src} stay 0. If
     * {@code src} repeats an index, the last pair wins.
     *
     * @param src source indexes; must be non-negative, otherwise an
     *            {@link ArrayIndexOutOfBoundsException} is raised (for example for unfilled {@code -1}
     *            placeholders)
     * @param dst destination indexes, same length as {@code src}
     * @return the offset table
     * @throws SumExprUnSupportException if the arrays differ in length
     */
    public static int[] generateAdjustments(int[] src, int[] dst) {
        assertCondition(src.length == dst.length, "Failed to generate adjustments");
        int maxRange = Arrays.stream(src).max().orElse(0);
        int[] adjustments = new int[maxRange + 1];
        for (int i = 0; i < src.length; ++i) {
            adjustments[src[i]] = dst[i] - src[i];
        }
        return adjustments;
    }

    /** Returns a new array of {@code size} elements, each set to the placeholder {@code -1}. */
    private static int[] newArray(int size) {
        int[] arr = new int[size];
        Arrays.fill(arr, -1);
        return arr;
    }

    /** Rex visitor that records every distinct {@link RexInputRef} it encounters. */
    public static class InputRefCapacity extends RexVisitorImpl<RexCall> {
        Set<RexInputRef> rexInputRefs = Sets.newHashSet();

        /** @param deep whether calls are descended into (passed to {@link RexVisitorImpl}) */
        protected InputRefCapacity(boolean deep) {
            super(deep);
        }

        @Override
        public RexCall visitInputRef(RexInputRef inputRef) {
            this.rexInputRefs.add(inputRef);
            return null;
        }

        /** @return a new list holding the collected input references, in no particular order */
        public List<RexInputRef> getRexInputRef() {
            return Lists.newArrayList(this.rexInputRefs);
        }
    }

    /** Rex visitor that records the input references and literals found in an expression. */
    public static class CollectRexVisitor extends RexVisitorImpl<RexCall> {
        private final Set<RexInputRef> rexInputRefs = Sets.newHashSet();
        private final Set<RexLiteral> rexLiterals = Sets.newHashSet();

        /** @param deep whether calls are descended into (passed to {@link RexVisitorImpl}) */
        CollectRexVisitor(boolean deep) {
            super(deep);
        }

        @Override
        public RexCall visitInputRef(RexInputRef inputRef) {
            this.rexInputRefs.add(inputRef);
            return null;
        }

        @Override
        public RexCall visitLiteral(RexLiteral literal) {
            this.rexLiterals.add(literal);
            return null;
        }

        /** @return {@code true} if at least one literal and no input reference was visited */
        boolean isRexLiteral() {
            return CollectionUtils.isEmpty(this.rexInputRefs) && CollectionUtils.isNotEmpty(this.rexLiterals);
        }
    }

    /**
     * Describes one group key of an aggregate: its project expression, and index arrays whose names
     * suggest bottom/top project and aggregate positions in a rewritten plan (TODO confirm against the
     * consuming rule). The arrays are returned by reference, not copied.
     */
    public static class GroupExpression {
        private RexNode expression;
        private boolean isLiteral = false;
        private int[] bottomProjInput;
        private int[] bottomAggInput;
        private int[] topProjInput;
        private int[] topAggInput;

        public RexNode getExpression() {
            return this.expression;
        }

        public void setExpression(RexNode expression) {
            this.expression = expression;
        }

        /** @return whether the key expression was made of literals only */
        public boolean isLiteral() {
            return this.isLiteral;
        }

        public void setLiteral() {
            this.isLiteral = true;
        }

        public int[] getBottomProjInput() {
            return this.bottomProjInput;
        }

        public void setBottomProjInput(int[] bottomProjInput) {
            this.bottomProjInput = bottomProjInput;
        }

        public int[] getBottomAggInput() {
            return this.bottomAggInput;
        }

        public void setBottomAggInput(int[] bottomAggInput) {
            this.bottomAggInput = bottomAggInput;
        }

        public int[] getTopProjInput() {
            return this.topProjInput;
        }

        public void setTopProjInput(int[] topProjInput) {
            this.topProjInput = topProjInput;
        }

        public int[] getTopAggInput() {
            return this.topAggInput;
        }

        public void setTopAggInput(int[] topAggInput) {
            this.topAggInput = topAggInput;
        }
    }

    /**
     * Describes one aggregate call: its argument expression, how it was classified (COUNT, SUM over
     * CASE, COUNT DISTINCT over CASE, SUM over a constant), and index arrays whose names suggest
     * bottom/top project and aggregate positions in a rewritten plan (TODO confirm against the consuming
     * rule). The arrays are returned by reference, not copied.
     */
    public static class AggExpression {
        private final AggregateCall agg;
        private RexNode expression;
        private RelDataType type;
        private boolean isSumCase = false;
        private boolean isCountDistinctCase = false;
        private List<RexNode> conditions;
        private int[] bottomProjConditionsInput;
        private int[] bottomAggConditionsInput;
        private int[] topProjConditionsInput;
        private List<RexNode> valuesList;
        private int[] bottomAggValuesInput;
        private int[] topProjValuesInput;
        private boolean isCount = false;
        private int[] bottomProjInput;
        private int[] bottomAggInput;
        private int[] topProjInput;
        private int[] topAggInput;
        private boolean isSumConst = false;

        public AggExpression(AggregateCall agg) {
            this.agg = agg;
        }

        public AggregateCall getAggCall() {
            return this.agg;
        }

        public RelDataType getType() {
            return this.type;
        }

        public void setType(RelDataType type) {
            this.type = type;
        }

        public void setExpression(RexNode expression) {
            this.expression = expression;
        }

        public RexNode getExpression() {
            return this.expression;
        }

        public void setSumCase() {
            this.isSumCase = true;
        }

        public boolean isSumCase() {
            return this.isSumCase;
        }

        public boolean isCountDistinctCase() {
            return this.isCountDistinctCase;
        }

        public void setCountDistinctCase(boolean countDistinctCase) {
            this.isCountDistinctCase = countDistinctCase;
        }

        /**
         * Stores the CASE conditions. {@code bottomProjConditionsInput} becomes the input references
         * used by the conditions; the bottom-aggregate and top-project condition arrays become
         * {@code -1} placeholders of the same length.
         */
        public void setConditionsList(List<RexNode> conditionsList) {
            this.conditions = conditionsList;
            this.bottomProjConditionsInput = RelOptUtil.InputFinder.bits(this.conditions, null).toArray();
            this.bottomAggConditionsInput = AggExpressionUtil.newArray(this.bottomProjConditionsInput.length);
            this.topProjConditionsInput = AggExpressionUtil.newArray(this.bottomProjConditionsInput.length);
        }

        public List<RexNode> getConditions() {
            return this.conditions;
        }

        public int[] getBottomProjConditionsInput() {
            return this.bottomProjConditionsInput;
        }

        public int[] getBottomAggConditionsInput() {
            return this.bottomAggConditionsInput;
        }

        public int[] getTopProjConditionsInput() {
            return this.topProjConditionsInput;
        }

        /**
         * Stores the CASE THEN/ELSE values and allocates one {@code -1} placeholder per value in the
         * bottom-aggregate and top-project value arrays.
         */
        public void setValuesList(List<RexNode> valuesList) {
            this.valuesList = valuesList;
            this.bottomAggValuesInput = AggExpressionUtil.newArray(valuesList.size());
            this.topProjValuesInput = AggExpressionUtil.newArray(valuesList.size());
        }

        public List<RexNode> getValuesList() {
            return this.valuesList;
        }

        public int[] getBottomAggValuesInput() {
            return this.bottomAggValuesInput;
        }

        public int[] getTopProjValuesInput() {
            return this.topProjValuesInput;
        }

        public void setCount() {
            this.isCount = true;
        }

        public boolean isCount() {
            return this.isCount;
        }

        public int[] getBottomProjInput() {
            return this.bottomProjInput;
        }

        public int[] getBottomAggInput() {
            return this.bottomAggInput;
        }

        public int[] getTopProjInput() {
            return this.topProjInput;
        }

        public int[] getTopAggInput() {
            return this.topAggInput;
        }

        public void setSumConst() {
            this.isSumConst = true;
        }

        public boolean isSumConst() {
            return this.isSumConst;
        }

        /** Assigns the four main index arrays at once; used by {@link AggExpressionUtil#collectSumExpressions}. */
        private void setInputs(int[] bottomProjInput, int[] bottomAggInput, int[] topProjInput,
                int[] topAggInput) {
            this.bottomProjInput = bottomProjInput;
            this.bottomAggInput = bottomAggInput;
            this.topProjInput = topProjInput;
            this.topAggInput = topAggInput;
        }
    }
}

