package org.apache.pinot.core.query.aggregation.function;

import com.google.common.base.Preconditions;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Iterator;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.stream.Collectors;
import org.apache.calcite.sql.parser.SqlParseException;
import org.apache.datasketches.memory.Memory;
import org.apache.datasketches.theta.CompactSketch;
import org.apache.datasketches.theta.Intersection;
import org.apache.datasketches.theta.SetOperation;
import org.apache.datasketches.theta.SetOperationBuilder;
import org.apache.datasketches.theta.Sketch;
import org.apache.datasketches.theta.Union;
import org.apache.pinot.common.function.AggregationFunctionType;
import org.apache.pinot.common.request.Expression;
import org.apache.pinot.common.request.ExpressionType;
import org.apache.pinot.common.request.Function;
import org.apache.pinot.common.request.transform.TransformExpressionTree;
import org.apache.pinot.common.utils.DataSchema;
import org.apache.pinot.core.common.BlockValSet;
import org.apache.pinot.core.common.Predicate;
import org.apache.pinot.core.common.predicate.RangePredicate;
import org.apache.pinot.core.io.writer.impl.v1.VarByteChunkSingleValueWriter;
import org.apache.pinot.core.operator.filter.predicate.PredicateEvaluator;
import org.apache.pinot.core.operator.filter.predicate.PredicateEvaluatorProvider;
import org.apache.pinot.core.query.aggregation.AggregationResultHolder;
import org.apache.pinot.core.query.aggregation.ObjectAggregationResultHolder;
import org.apache.pinot.core.query.aggregation.ThetaSketchParams;
import org.apache.pinot.core.query.aggregation.function.customobject.QuantileDigest;
import org.apache.pinot.core.query.aggregation.groupby.GroupByResultHolder;
import org.apache.pinot.core.query.aggregation.groupby.ObjectGroupByResultHolder;
import org.apache.pinot.parsers.utils.ParserUtils;
import org.apache.pinot.pql.parsers.pql2.ast.FilterKind;
import org.apache.pinot.pql.parsers.pql2.ast.IdentifierAstNode;
import org.apache.pinot.spi.data.FieldSpec;
import org.apache.pinot.sql.parsers.CalciteSqlParser;

/* loaded from: input_file:org/apache/pinot/core/query/aggregation/function/DistinctCountThetaSketchAggregationFunction.class */
/**
 * Aggregation function {@code DISTINCTCOUNTTHETASKETCH}: estimates distinct counts from serialized Theta sketches,
 * optionally split by predicates and combined through a set-operation post-aggregation expression.
 *
 * <p>Arguments (see {@link #init(List)}):
 * <ol>
 *   <li>the column holding serialized sketches (read with {@code getBytesValuesSV()});</li>
 *   <li>a Theta sketch parameter string, parsed by {@code ThetaSketchParams.fromString};</li>
 *   <li>zero or more predicate strings;</li>
 *   <li>the post-aggregation expression (always the last argument).</li>
 * </ol>
 * When exactly three arguments are given, the predicates are instead extracted from the non-AND/OR leaves of the
 * post-aggregation expression.
 *
 * <p>For each predicate, the sketches of the rows satisfying it are unioned. The final result evaluates the
 * post-aggregation expression over the per-predicate sketches ({@code AND} as intersection, {@code OR} as union) and
 * returns the estimate rounded to an {@code int}.
 */
public class DistinctCountThetaSketchAggregationFunction implements AggregationFunction<Map<String, Sketch>, Integer> {
    private String _thetaSketchColumn;
    private TransformExpressionTree _thetaSketchIdentifier;
    // Predicate strings in declaration order; each one keys a per-predicate Union/Sketch in the result maps.
    private Set<String> _predicateStrings;
    private Expression _postAggregationExpression;
    private Set<PredicateInfo> _predicateInfoSet;
    // Maps a compiled predicate expression back to its key in _predicateStrings.
    private Map<Expression, String> _expressionMap;
    private ThetaSketchParams _thetaSketchParams;
    private List<TransformExpressionTree> _inputExpressions;

    /**
     * One predicate: its key string, the column expression it filters on, the parsed predicate, and a lazily
     * created evaluator.
     */
    private static final class PredicateInfo {
        private final String _stringVal;
        private final TransformExpressionTree _expression;
        private final Predicate _predicate;
        private PredicateEvaluator _predicateEvaluator;

        private PredicateInfo(String stringVal, TransformExpressionTree expression, Predicate predicate) {
            _stringVal = stringVal;
            _expression = expression;
            _predicate = predicate;
        }

        public String getStringVal() {
            return _stringVal;
        }

        public TransformExpressionTree getExpression() {
            return _expression;
        }

        public Predicate getPredicate() {
            return _predicate;
        }

        /**
         * Returns the evaluator for this predicate, creating it for {@code dataType} on first use. The evaluator is
         * cached, so later calls return it regardless of the {@code dataType} passed.
         */
        public PredicateEvaluator getPredicateEvaluator(FieldSpec.DataType dataType) {
            if (_predicateEvaluator == null) {
                _predicateEvaluator = PredicateEvaluatorProvider.getPredicateEvaluator(_predicate, null, dataType);
            }
            return _predicateEvaluator;
        }

        @Override
        public boolean equals(Object obj) {
            if (this == obj) {
                return true;
            }
            if (!(obj instanceof PredicateInfo)) {
                return false;
            }
            PredicateInfo that = (PredicateInfo) obj;
            return Objects.equals(_stringVal, that._stringVal) && Objects.equals(_expression, that._expression)
                && Objects.equals(_predicate, that._predicate);
        }

        @Override
        public int hashCode() {
            return Objects.hash(_stringVal, _expression, _predicate);
        }
    }

    /**
     * @param arguments sketch column, sketch params, optional predicate strings, post-aggregation expression
     * @throws IllegalArgumentException if fewer than three arguments are given
     * @throws SqlParseException if a predicate or the post-aggregation expression cannot be parsed
     */
    public DistinctCountThetaSketchAggregationFunction(List<String> arguments) throws SqlParseException {
        int numArguments = arguments.size();
        Preconditions.checkArgument(numArguments >= 3,
            "DistinctCountThetaSketch expects at least three arguments, got: %s", numArguments);
        init(arguments);
    }

    @Override
    public AggregationFunctionType getType() {
        return AggregationFunctionType.DISTINCTCOUNTTHETASKETCH;
    }

    @Override
    public String getColumnName() {
        return AggregationFunctionType.DISTINCTCOUNTTHETASKETCH.getName() + "_" + _thetaSketchColumn;
    }

    @Override
    public String getResultColumnName() {
        return AggregationFunctionType.DISTINCTCOUNTTHETASKETCH.getName().toLowerCase() + "(" + _thetaSketchColumn
            + ")";
    }

    @Override
    public List<TransformExpressionTree> getInputExpressions() {
        return _inputExpressions;
    }

    @Override
    public void accept(AggregationFunctionVisitorBase visitor) {
        visitor.visit(this);
    }

    @Override
    public AggregationResultHolder createAggregationResultHolder() {
        return new ObjectAggregationResultHolder();
    }

    @Override
    public GroupByResultHolder createGroupByResultHolder(int initialCapacity, int maxCapacity) {
        return new ObjectGroupByResultHolder(initialCapacity, maxCapacity);
    }

    /** Unions, per predicate, the sketches of the first {@code length} rows that satisfy that predicate. */
    @Override
    public void aggregate(int length, AggregationResultHolder aggregationResultHolder,
            Map<TransformExpressionTree, BlockValSet> blockValSetMap) {
        Map<String, Union> unions = getDefaultResult(aggregationResultHolder, _predicateStrings);
        Sketch[] sketches = deserializeSketches(blockValSetMap.get(_thetaSketchIdentifier).getBytesValuesSV(), length);
        for (PredicateInfo predicateInfo : _predicateInfoSet) {
            boolean[] matches =
                evaluatePredicate(predicateInfo, blockValSetMap.get(predicateInfo.getExpression()), length);
            Union union = unions.get(predicateInfo.getStringVal());
            for (int i = 0; i < length; i++) {
                if (matches[i]) {
                    union.update(sketches[i]);
                }
            }
        }
    }

    /** Like {@link #aggregate}, but each matching row's sketch goes into the union of its single group key. */
    @Override
    public void aggregateGroupBySV(int length, int[] groupKeyArray, GroupByResultHolder groupByResultHolder,
            Map<TransformExpressionTree, BlockValSet> blockValSetMap) {
        Sketch[] sketches = deserializeSketches(blockValSetMap.get(_thetaSketchIdentifier).getBytesValuesSV(), length);
        for (PredicateInfo predicateInfo : _predicateInfoSet) {
            String predicate = predicateInfo.getStringVal();
            boolean[] matches =
                evaluatePredicate(predicateInfo, blockValSetMap.get(predicateInfo.getExpression()), length);
            for (int i = 0; i < length; i++) {
                if (matches[i]) {
                    getDefaultResult(groupByResultHolder, groupKeyArray[i], _predicateStrings).get(predicate)
                        .update(sketches[i]);
                }
            }
        }
    }

    /**
     * Like {@link #aggregateGroupBySV}, but a row may belong to several groups: each matching row's sketch goes into
     * the union of every group key in {@code groupKeysArray[i]}.
     */
    @Override
    public void aggregateGroupByMV(int length, int[][] groupKeysArray, GroupByResultHolder groupByResultHolder,
            Map<TransformExpressionTree, BlockValSet> blockValSetMap) {
        Sketch[] sketches = deserializeSketches(blockValSetMap.get(_thetaSketchIdentifier).getBytesValuesSV(), length);
        for (PredicateInfo predicateInfo : _predicateInfoSet) {
            String predicate = predicateInfo.getStringVal();
            boolean[] matches =
                evaluatePredicate(predicateInfo, blockValSetMap.get(predicateInfo.getExpression()), length);
            for (int i = 0; i < length; i++) {
                if (!matches[i]) {
                    continue;
                }
                for (int groupKey : groupKeysArray[i]) {
                    getDefaultResult(groupByResultHolder, groupKey, _predicateStrings).get(predicate)
                        .update(sketches[i]);
                }
            }
        }
    }

    @Override
    public Map<String, Sketch> extractAggregationResult(AggregationResultHolder aggregationResultHolder) {
        @SuppressWarnings("unchecked")
        Map<String, Union> unions = (Map<String, Union>) aggregationResultHolder.getResult();
        if (unions == null) {
            unions = getDefaultResult(aggregationResultHolder, _predicateStrings);
        }
        return toSketchMap(unions);
    }

    @Override
    public Map<String, Sketch> extractGroupByResult(GroupByResultHolder groupByResultHolder, int groupKey) {
        @SuppressWarnings("unchecked")
        Map<String, Union> unions = (Map<String, Union>) groupByResultHolder.getResult(groupKey);
        if (unions == null) {
            unions = getDefaultResult(groupByResultHolder, groupKey, _predicateStrings);
        }
        return toSketchMap(unions);
    }

    /**
     * Merges {@code intermediateResult2} into {@code intermediateResult1} (which is mutated and returned) by unioning
     * the sketches stored under the same predicate key. Keys present only in the second map are carried over.
     * A {@code null} argument yields the other argument.
     */
    @Override
    public Map<String, Sketch> merge(Map<String, Sketch> intermediateResult1, Map<String, Sketch> intermediateResult2) {
        if (intermediateResult1 == null) {
            return intermediateResult2;
        }
        if (intermediateResult2 == null) {
            return intermediateResult1;
        }
        // setValue (not put) so the map is never structurally modified while iterating.
        for (Map.Entry<String, Sketch> entry : intermediateResult1.entrySet()) {
            Union union = getSetOperationBuilder().buildUnion();
            union.update(entry.getValue());
            Sketch other = intermediateResult2.get(entry.getKey());
            if (other != null) {
                union.update(other);
            }
            entry.setValue(union.getResult());
        }
        for (Map.Entry<String, Sketch> entry : intermediateResult2.entrySet()) {
            intermediateResult1.putIfAbsent(entry.getKey(), entry.getValue());
        }
        return intermediateResult1;
    }

    @Override
    public boolean isIntermediateResultComparable() {
        return false;
    }

    @Override
    public DataSchema.ColumnDataType getIntermediateResultColumnType() {
        return DataSchema.ColumnDataType.OBJECT;
    }

    @Override
    public DataSchema.ColumnDataType getFinalResultColumnType() {
        return DataSchema.ColumnDataType.INT;
    }

    /** Evaluates the post-aggregation expression over the per-predicate sketches; returns the rounded estimate. */
    @Override
    public Integer extractFinalResult(Map<String, Sketch> intermediateResult) {
        Sketch result = evalPostAggregationExpression(_postAggregationExpression, intermediateResult);
        return (int) Math.round(result.getEstimate());
    }

    /**
     * Evaluates the predicate against the first {@code length} values of {@code blockValSet}.
     *
     * @return {@code matches[i]} is true iff row {@code i} satisfies the predicate
     * @throws IllegalStateException if the column's value type is not INT/LONG/FLOAT/DOUBLE/STRING
     */
    private boolean[] evaluatePredicate(PredicateInfo predicateInfo, BlockValSet blockValSet, int length) {
        FieldSpec.DataType valueType = blockValSet.getValueType();
        PredicateEvaluator predicateEvaluator = predicateInfo.getPredicateEvaluator(valueType);
        boolean[] matches = new boolean[length];
        switch (valueType) {
            case INT: {
                int[] values = blockValSet.getIntValuesSV();
                for (int i = 0; i < length; i++) {
                    matches[i] = predicateEvaluator.applySV(values[i]);
                }
                break;
            }
            case LONG: {
                long[] values = blockValSet.getLongValuesSV();
                for (int i = 0; i < length; i++) {
                    matches[i] = predicateEvaluator.applySV(values[i]);
                }
                break;
            }
            case FLOAT: {
                float[] values = blockValSet.getFloatValuesSV();
                for (int i = 0; i < length; i++) {
                    matches[i] = predicateEvaluator.applySV(values[i]);
                }
                break;
            }
            case DOUBLE: {
                double[] values = blockValSet.getDoubleValuesSV();
                for (int i = 0; i < length; i++) {
                    matches[i] = predicateEvaluator.applySV(values[i]);
                }
                break;
            }
            case STRING: {
                String[] values = blockValSet.getStringValuesSV();
                for (int i = 0; i < length; i++) {
                    matches[i] = predicateEvaluator.applySV(values[i]);
                }
                break;
            }
            default:
                throw new IllegalStateException(
                    "Illegal data type for " + getType() + " aggregation function: " + valueType);
        }
        return matches;
    }

    /** Converts each per-predicate union into its result sketch. */
    private static Map<String, Sketch> toSketchMap(Map<String, Union> unions) {
        Map<String, Sketch> sketches = new HashMap<>();
        for (Map.Entry<String, Union> entry : unions.entrySet()) {
            sketches.put(entry.getKey(), entry.getValue().getResult());
        }
        return sketches;
    }

    /** Returns the holder's predicate-to-union map, creating it and any missing per-predicate unions as needed. */
    private Map<String, Union> getDefaultResult(AggregationResultHolder aggregationResultHolder,
            Set<String> predicateStrings) {
        @SuppressWarnings("unchecked")
        Map<String, Union> unions = (Map<String, Union>) aggregationResultHolder.getResult();
        if (unions == null) {
            unions = new HashMap<>();
            aggregationResultHolder.setValue(unions);
        }
        addMissingUnions(unions, predicateStrings);
        return unions;
    }

    /** Group-by counterpart of {@link #getDefaultResult(AggregationResultHolder, Set)} for {@code groupKey}. */
    private Map<String, Union> getDefaultResult(GroupByResultHolder groupByResultHolder, int groupKey,
            Set<String> predicateStrings) {
        @SuppressWarnings("unchecked")
        Map<String, Union> unions = (Map<String, Union>) groupByResultHolder.getResult(groupKey);
        if (unions == null) {
            unions = new HashMap<>();
            groupByResultHolder.setValueForKey(groupKey, unions);
        }
        addMissingUnions(unions, predicateStrings);
        return unions;
    }

    /**
     * Adds an empty union for every predicate key that has none. computeIfAbsent only builds a Union when one is
     * actually missing; this method runs once per matching row in group-by.
     */
    private void addMissingUnions(Map<String, Union> unions, Set<String> predicateStrings) {
        for (String predicate : predicateStrings) {
            unions.computeIfAbsent(predicate, key -> getSetOperationBuilder().buildUnion());
        }
    }

    /** Wraps the first {@code length} serialized sketches. */
    private Sketch[] deserializeSketches(byte[][] serializedSketches, int length) {
        Sketch[] sketches = new Sketch[length];
        for (int i = 0; i < length; i++) {
            sketches[i] = Sketch.wrap(Memory.wrap(serializedSketches[i]));
        }
        return sketches;
    }

    /**
     * Parses the arguments. The predicates come from arguments 2..n-2 if there are more than three arguments;
     * otherwise they come from the leaves of the post-aggregation expression.
     */
    private void init(List<String> arguments) throws SqlParseException {
        int numArguments = arguments.size();
        _thetaSketchColumn = arguments.get(0);
        _thetaSketchIdentifier = identifier(_thetaSketchColumn);
        _inputExpressions = new ArrayList<>();
        _inputExpressions.add(_thetaSketchIdentifier);
        _thetaSketchParams = ThetaSketchParams.fromString(arguments.get(1));
        _postAggregationExpression = CalciteSqlParser.compileToExpression(arguments.get(numArguments - 1));
        _predicateInfoSet = new LinkedHashSet<>();
        _predicateStrings = new LinkedHashSet<>(arguments.subList(2, numArguments - 1));
        _expressionMap = new HashMap<>();

        if (numArguments > 3) {
            for (String predicateString : _predicateStrings) {
                addPredicate(predicateString, CalciteSqlParser.compileToExpression(predicateString));
            }
            return;
        }

        Set<Expression> predicateExpressions = new LinkedHashSet<>();
        extractPredicatesFromExpression(_postAggregationExpression, predicateExpressions);
        for (Expression predicateExpression : predicateExpressions) {
            String predicateString = ParserUtils.standardizeExpression(predicateExpression, false);
            _predicateStrings.add(predicateString);
            addPredicate(predicateString, predicateExpression);
        }
    }

    /**
     * Registers one predicate. The same identifier tree is used as the PredicateInfo lookup key and as the input
     * expression, so the block-value-set lookup in the aggregate methods always matches.
     */
    private void addPredicate(String predicateString, Expression predicateExpression) {
        String filterColumn = ParserUtils.getFilterColumn(predicateExpression);
        TransformExpressionTree columnExpression = identifier(filterColumn);
        Predicate predicate = Predicate.newPredicate(ParserUtils.getFilterType(predicateExpression), filterColumn,
            ParserUtils.getFilterValues(predicateExpression));
        _predicateInfoSet.add(new PredicateInfo(predicateString, columnExpression, predicate));
        _expressionMap.put(predicateExpression, predicateString);
        _inputExpressions.add(columnExpression);
    }

    /** Builds an identifier expression tree for {@code column}. */
    private static TransformExpressionTree identifier(String column) {
        return new TransformExpressionTree(TransformExpressionTree.ExpressionType.IDENTIFIER, column,
            (List<TransformExpressionTree>) null);
    }

    /** Collects into {@code predicates} every function call that is not AND/OR, descending through AND/OR nodes. */
    private void extractPredicatesFromExpression(Expression expression, Set<Expression> predicates) {
        if (expression.getType() != ExpressionType.FUNCTION) {
            return;
        }
        Function functionCall = expression.getFunctionCall();
        FilterKind filterKind = FilterKind.valueOf(functionCall.getOperator());
        if (filterKind != FilterKind.AND && filterKind != FilterKind.OR) {
            predicates.add(expression);
            return;
        }
        for (Expression operand : functionCall.getOperands()) {
            extractPredicatesFromExpression(operand, predicates);
        }
    }

    /**
     * Recursively evaluates the post-aggregation expression. AND maps to a sketch intersection, OR to a union, and
     * any other node is looked up as a predicate in {@code sketches}.
     *
     * @throws IllegalStateException if a leaf predicate has no precomputed sketch
     */
    private Sketch evalPostAggregationExpression(Expression expression, Map<String, Sketch> sketches) {
        Function functionCall = expression.getFunctionCall();
        switch (FilterKind.valueOf(functionCall.getOperator())) {
            case AND: {
                Intersection intersection = getSetOperationBuilder().buildIntersection();
                for (Expression operand : functionCall.getOperands()) {
                    intersection.update(evalPostAggregationExpression(operand, sketches));
                }
                return intersection.getResult();
            }
            case OR: {
                Union union = getSetOperationBuilder().buildUnion();
                for (Expression operand : functionCall.getOperands()) {
                    union.update(evalPostAggregationExpression(operand, sketches));
                }
                return union.getResult();
            }
            default: {
                String predicate = _expressionMap.get(expression);
                Sketch sketch = sketches.get(predicate);
                Preconditions.checkState(sketch != null, "Precomputed sketch for predicate not provided: " + predicate);
                return sketch;
            }
        }
    }

    /** Returns a set-operation builder, using the configured nominal entries when sketch params were given. */
    private SetOperationBuilder getSetOperationBuilder() {
        return _thetaSketchParams == null ? SetOperation.builder()
            : SetOperation.builder().setNominalEntries(_thetaSketchParams.getNominalEntries());
    }
}
