package io.kyligence.kap.query.optrule;

import java.math.BigDecimal;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.function.Function;
import org.apache.kylin.guava30.shaded.common.collect.ImmutableList;
import org.apache.kylin.guava30.shaded.common.collect.Lists;
import org.apache.kylin.guava30.shaded.common.collect.Maps;
import org.apache.kylin.job.shaded.org.apache.calcite.plan.RelOptRule;
import org.apache.kylin.job.shaded.org.apache.calcite.plan.RelOptRuleCall;
import org.apache.kylin.job.shaded.org.apache.calcite.plan.RelOptRuleOperand;
import org.apache.kylin.job.shaded.org.apache.calcite.rel.RelNode;
import org.apache.kylin.job.shaded.org.apache.calcite.rel.core.Aggregate;
import org.apache.kylin.job.shaded.org.apache.calcite.rel.core.AggregateCall;
import org.apache.kylin.job.shaded.org.apache.calcite.rel.core.RelFactories;
import org.apache.kylin.job.shaded.org.apache.calcite.rel.type.RelDataType;
import org.apache.kylin.job.shaded.org.apache.calcite.rel.type.RelDataTypeSystem;
import org.apache.kylin.job.shaded.org.apache.calcite.rex.RexBuilder;
import org.apache.kylin.job.shaded.org.apache.calcite.rex.RexLiteral;
import org.apache.kylin.job.shaded.org.apache.calcite.rex.RexNode;
import org.apache.kylin.job.shaded.org.apache.calcite.sql.fun.SqlStdOperatorTable;
import org.apache.kylin.job.shaded.org.apache.calcite.sql.type.SqlTypeFactoryImpl;
import org.apache.kylin.job.shaded.org.apache.calcite.sql.type.SqlTypeName;
import org.apache.kylin.job.shaded.org.apache.calcite.sql.type.SqlTypeUtil;
import org.apache.kylin.job.shaded.org.apache.calcite.tools.RelBuilder;
import org.apache.kylin.job.shaded.org.apache.calcite.tools.RelBuilderFactory;
import org.apache.kylin.job.shaded.org.apache.calcite.util.CompositeList;
import org.apache.kylin.job.shaded.org.apache.calcite.util.ImmutableIntList;

/* loaded from: input_file:io/kyligence/kap/query/optrule/CorrReduceFunctionRule.class */
/**
 * Planner rule that rewrites every {@code CORR(x, y)} aggregate call of an {@link Aggregate}
 * into plain {@code SUM} and {@code COUNT} aggregates plus a projection that combines them.
 *
 * <p>The expression produced for one CORR call is
 * <pre>
 *   (n * SUM(x*y) - SUM(x) * SUM(y))
 *   / (POWER(n * SUM(x*x) - SUM(x)^2, 0.5) * POWER(n * SUM(y*y) - SUM(y)^2, 0.5))
 * </pre>
 * where {@code n} is a {@code COUNT} with no arguments. All sums are cast to DOUBLE before
 * they are combined, and the final value is cast to the type of the original CORR call.
 * Aggregate calls other than CORR are carried over unchanged.
 */
public class CorrReduceFunctionRule extends RelOptRule {
    public static final CorrReduceFunctionRule INSTANCE = new CorrReduceFunctionRule(operand(Aggregate.class, any()), RelFactories.LOGICAL_BUILDER, "CorrReduceFunctionRule");
    private static final RelDataType DOUBLE_TYPE = new SqlTypeFactoryImpl(RelDataTypeSystem.DEFAULT).createSqlType(SqlTypeName.DOUBLE);
    /** Aggregation name this rule looks for. */
    private static final String CORR_NAME = "CORR";

    /**
     * Creates the rule.
     *
     * @param relOptRuleOperand operand describing the nodes the rule is fired on
     * @param relBuilderFactory factory of the builder used to assemble the replacement
     * @param str               rule description
     */
    public CorrReduceFunctionRule(RelOptRuleOperand relOptRuleOperand, RelBuilderFactory relBuilderFactory, String str) {
        super(relOptRuleOperand, relBuilderFactory, str);
    }

    /**
     * Accepts the call only when the matched aggregate holds at least one CORR call.
     */
    @Override // org.apache.kylin.job.shaded.org.apache.calcite.plan.RelOptRule
    public boolean matches(RelOptRuleCall relOptRuleCall) {
        final Aggregate aggregate = (Aggregate) relOptRuleCall.rels[0];
        return containsCorrCall(aggregate.getAggCallList());
    }

    /**
     * Replaces the matched aggregate by its reduced form.
     */
    @Override // org.apache.kylin.job.shaded.org.apache.calcite.plan.RelOptRule
    public void onMatch(RelOptRuleCall relOptRuleCall) {
        final Aggregate aggregate = (Aggregate) relOptRuleCall.rels[0];
        reduceAggs(relOptRuleCall, aggregate);
    }

    /**
     * Builds "project(aggregate(project(input)))" for the given aggregate and registers it
     * as the transformation result. The inner project is only emitted when reducing a CORR
     * call appended extra input expressions (the x*x, y*y, x*y products).
     */
    private void reduceAggs(RelOptRuleCall relOptRuleCall, Aggregate aggregate) {
        final RexBuilder rexBuilder = aggregate.getCluster().getRexBuilder();
        final List<AggregateCall> oldCalls = aggregate.getAggCallList();
        final int passThroughCount = aggregate.getGroupCount() + aggregate.getIndicatorCount();

        // Group (and indicator) columns keep their position in the outer projection.
        final List<RexNode> projections = new ArrayList<>();
        for (int i = 0; i < passThroughCount; i++) {
            projections.add(rexBuilder.makeInputRef(getFieldType(aggregate, i), i));
        }

        final RelBuilder builder = relOptRuleCall.builder();
        builder.push(aggregate.getInput());

        // inputExprs starts as the input fields; reduceCORR may append product expressions.
        final List<RexNode> inputExprs = new ArrayList<>(builder.fields());
        final List<AggregateCall> newCalls = new ArrayList<>();
        final Map<AggregateCall, RexNode> aggCallMapping = new HashMap<>();
        for (AggregateCall oldCall : oldCalls) {
            projections.add(reduceAgg(aggregate, oldCall, newCalls, aggCallMapping, inputExprs));
        }

        final int extraArgCount = inputExprs.size() - builder.peek().getRowType().getFieldCount();
        if (extraArgCount > 0) {
            // Original fields keep their names, appended expressions get a null name.
            builder.project(inputExprs, CompositeList.of(builder.peek().getRowType().getFieldNames(), Collections.<String>nCopies(extraArgCount, null)));
        }
        builder.aggregate(builder.groupKey(aggregate.getGroupSet(), aggregate.indicator, aggregate.getGroupSets()), newCalls);
        builder.project(projections, aggregate.getRowType().getFieldNames());
        relOptRuleCall.transformTo(builder.build());
    }

    /**
     * Tells whether any call in the list has the aggregation name {@code "CORR"}.
     */
    private boolean containsCorrCall(List<AggregateCall> list) {
        for (AggregateCall call : list) {
            if (CORR_NAME.equals(call.getAggregation().getName())) {
                return true;
            }
        }
        return false;
    }

    /**
     * Reduces one aggregate call: CORR calls are expanded by {@link #reduceCORR}, every
     * other call is registered as is and referenced through the returned expression.
     *
     * @param aggregate     aggregate being rewritten
     * @param aggregateCall call to reduce
     * @param list          new aggregate calls, appended to as needed
     * @param map           mapping used by the RexBuilder to de-duplicate aggregate calls
     * @param list2         input expressions of the new aggregate, appended to as needed
     * @return expression over the new aggregate's output that yields the call's value
     */
    private RexNode reduceAgg(Aggregate aggregate, AggregateCall aggregateCall, List<AggregateCall> list, Map<AggregateCall, RexNode> map, List<RexNode> list2) {
        if (CORR_NAME.equals(aggregateCall.getAggregation().getName())) {
            return reduceCORR(aggregate, aggregateCall, list, map, list2);
        }
        final List<RelDataType> argTypes = SqlTypeUtil.projectTypes(aggregate.getInput().getRowType(), aggregateCall.getArgList());
        return aggregate.getCluster().getRexBuilder().addAggCall(aggregateCall, aggregate.getGroupCount(), aggregate.indicator, list, map, argTypes);
    }

    /**
     * Expands a CORR(x, y) call into SUM/COUNT aggregates and the arithmetic combining them.
     *
     * <p>The result is {@code NULL} when the row count is zero or when the denominator is
     * zero; the latter guard keeps the division from running with a zero divisor.
     *
     * @param aggregate     aggregate being rewritten
     * @param aggregateCall the CORR call
     * @param list          new aggregate calls, appended to
     * @param map           mapping used by the RexBuilder to de-duplicate aggregate calls
     * @param list2         input expressions of the new aggregate; x*x, y*y and x*y are appended
     * @return expression computing the correlation, cast to the type of the CORR call
     * @throws IllegalArgumentException if the call has fewer than two arguments
     */
    private RexNode reduceCORR(Aggregate aggregate, AggregateCall aggregateCall, List<AggregateCall> list, Map<AggregateCall, RexNode> map, List<RexNode> list2) {
        final int groupCount = aggregate.getGroupCount();
        final RexBuilder rexBuilder = aggregate.getCluster().getRexBuilder();
        final Function<RexNode, RexNode> toDouble = node -> rexBuilder.makeCast(DOUBLE_TYPE, node);

        final List<Integer> argList = aggregateCall.getArgList();
        if (argList == null || argList.size() < 2) {
            throw new IllegalArgumentException("CORR must have 2 argument parameters");
        }
        final int xOrdinal = argList.get(0).intValue();
        final int yOrdinal = argList.get(1).intValue();
        final RelDataType xFieldType = getFieldType(aggregate.getInput(), xOrdinal);
        final RexNode x = list2.get(xOrdinal);
        final RexNode y = list2.get(yOrdinal);

        // SUM(x), SUM(y)
        final RexNode sumX = toDouble.apply(addSumCall(aggregate, aggregateCall, xOrdinal, list, map, list2));
        final RexNode sumY = toDouble.apply(addSumCall(aggregate, aggregateCall, yOrdinal, list, map, list2));

        // SUM(x)*SUM(x), SUM(x)*SUM(y), SUM(y)*SUM(y)
        final RexNode sumXSquared = rexBuilder.makeCall(SqlStdOperatorTable.MULTIPLY, sumX, sumX);
        final RexNode sumXTimesSumY = rexBuilder.makeCall(SqlStdOperatorTable.MULTIPLY, sumX, sumY);
        final RexNode sumYSquared = rexBuilder.makeCall(SqlStdOperatorTable.MULTIPLY, sumY, sumY);

        // SUM(x*x), SUM(y*y), SUM(x*y)
        final RexNode sumXX = toDouble.apply(buildMultiplyRexNode(aggregate, aggregateCall, list2, list, map, x, x));
        final RexNode sumYY = toDouble.apply(buildMultiplyRexNode(aggregate, aggregateCall, list2, list, map, y, y));
        final RexNode sumXY = toDouble.apply(buildMultiplyRexNode(aggregate, aggregateCall, list2, list, map, x, y));

        // COUNT with an empty argument list
        final AggregateCall countCall = AggregateCall.create(SqlStdOperatorTable.COUNT, aggregateCall.isDistinct(), Lists.<Integer>newArrayList(), aggregateCall.filterArg, aggregate.getGroupCount(), aggregate.getInput(), null, null);
        final RexNode count = rexBuilder.addAggCall(countCall, groupCount, aggregate.indicator, list, map, ImmutableList.of(xFieldType));

        // numerator: n * SUM(x*y) - SUM(x) * SUM(y)
        final RexNode numerator = rexBuilder.makeCall(SqlStdOperatorTable.MINUS, rexBuilder.makeCall(SqlStdOperatorTable.MULTIPLY, count, sumXY), sumXTimesSumY);
        // denominator: POWER(n * SUM(x*x) - SUM(x)^2, 0.5) * POWER(n * SUM(y*y) - SUM(y)^2, 0.5)
        final RexNode xTerm = rexBuilder.makeCall(SqlStdOperatorTable.MINUS, rexBuilder.makeCall(SqlStdOperatorTable.MULTIPLY, sumXX, count), sumXSquared);
        final RexNode yTerm = rexBuilder.makeCall(SqlStdOperatorTable.MINUS, rexBuilder.makeCall(SqlStdOperatorTable.MULTIPLY, sumYY, count), sumYSquared);
        final RexLiteral half = rexBuilder.makeExactLiteral(new BigDecimal("0.5"));
        final RexNode denominator = rexBuilder.makeCall(SqlStdOperatorTable.MULTIPLY, rexBuilder.makeCall(SqlStdOperatorTable.POWER, xTerm, half), rexBuilder.makeCall(SqlStdOperatorTable.POWER, yTerm, half));
        final RexNode quotient = rexBuilder.makeCall(SqlStdOperatorTable.DIVIDE, toDouble.apply(numerator), toDouble.apply(denominator));

        // CASE WHEN n = 0 THEN NULL WHEN denominator = 0 THEN NULL ELSE quotient END
        final List<RexNode> caseOperands = new ArrayList<>();
        caseOperands.add(rexBuilder.makeCall(SqlStdOperatorTable.EQUALS, count, rexBuilder.makeZeroLiteral(denominator.getType())));
        caseOperands.add(rexBuilder.makeNullLiteral(DOUBLE_TYPE));
        // A non-empty group can still have a zero denominator (e.g. all x equal), so the
        // divisor itself has to be tested as well before dividing.
        caseOperands.add(rexBuilder.makeCall(SqlStdOperatorTable.EQUALS, denominator, rexBuilder.makeZeroLiteral(denominator.getType())));
        caseOperands.add(rexBuilder.makeNullLiteral(DOUBLE_TYPE));
        caseOperands.add(quotient);
        return rexBuilder.makeCast(aggregateCall.getType(), rexBuilder.makeCall(DOUBLE_TYPE, SqlStdOperatorTable.CASE, caseOperands));
    }

    /**
     * Registers SUM over the input expression at {@code ordinal}, inheriting the distinct
     * flag and filter argument of {@code corrCall}, and returns the reference to it.
     */
    private RexNode addSumCall(Aggregate aggregate, AggregateCall corrCall, int ordinal, List<AggregateCall> newCalls, Map<AggregateCall, RexNode> aggCallMapping, List<RexNode> inputExprs) {
        final AggregateCall sumCall = AggregateCall.create(SqlStdOperatorTable.SUM, corrCall.isDistinct(), ImmutableIntList.of(ordinal), corrCall.filterArg, aggregate.getGroupCount(), aggregate.getInput(), null, null);
        return aggregate.getCluster().getRexBuilder().addAggCall(sumCall, aggregate.getGroupCount(), aggregate.indicator, newCalls, aggCallMapping, ImmutableList.of(inputExprs.get(ordinal).getType()));
    }

    /**
     * Returns the type of field {@code i} in the row type of {@code relNode}.
     */
    private RelDataType getFieldType(RelNode relNode, int i) {
        return relNode.getRowType().getFieldList().get(i).getType();
    }

    /**
     * Returns the index of {@code t} in {@code list}, appending it first when it is absent.
     */
    private static <T> int lookupOrAdd(List<T> list, T t) {
        final int existing = list.indexOf(t);
        if (existing != -1) {
            return existing;
        }
        list.add(t);
        return list.size() - 1;
    }

    /**
     * Registers {@code SUM(rexNode * rexNode2)}: the product is added to the input
     * expressions (or reused when already there) and a SUM over it is added to the new calls.
     *
     * @param aggregate     aggregate being rewritten
     * @param aggregateCall CORR call whose distinct flag and filter argument are inherited
     * @param list          input expressions of the new aggregate
     * @param list2         new aggregate calls
     * @param map           mapping used by the RexBuilder to de-duplicate aggregate calls
     * @param rexNode       left factor
     * @param rexNode2      right factor
     * @return reference to the SUM in the new aggregate's output
     */
    private RexNode buildMultiplyRexNode(Aggregate aggregate, AggregateCall aggregateCall, List<RexNode> list, List<AggregateCall> list2, Map<AggregateCall, RexNode> map, RexNode rexNode, RexNode rexNode2) {
        final RexBuilder rexBuilder = aggregate.getCluster().getRexBuilder();
        final RexNode product = rexBuilder.makeCall(resultType(), SqlStdOperatorTable.MULTIPLY, Lists.newArrayList(rexNode, rexNode2));
        final int productOrdinal = lookupOrAdd(list, product);
        final RelDataType productType = list.get(productOrdinal).getType();

        // The SUM return type is derived explicitly from the type of the product expression.
        final Aggregate.AggCallBinding binding = new Aggregate.AggCallBinding(aggregate.getCluster().getTypeFactory(), SqlStdOperatorTable.SUM, ImmutableList.of(productType), aggregate.getGroupCount(), aggregateCall.filterArg >= 0);
        final AggregateCall sumCall = AggregateCall.create(SqlStdOperatorTable.SUM, aggregateCall.isDistinct(), ImmutableIntList.of(productOrdinal), aggregateCall.filterArg, SqlStdOperatorTable.SUM.inferReturnType(binding), null);
        return rexBuilder.addAggCall(sumCall, aggregate.getGroupCount(), aggregate.indicator, list2, map, ImmutableList.of(productType));
    }

    /**
     * Type given to the product expressions built by {@link #buildMultiplyRexNode}.
     */
    private RelDataType resultType() {
        return DOUBLE_TYPE;
    }
}
