package org.apache.flink.table.planner.plan.rules.logical;

import java.util.Collections;
import java.util.List;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import org.apache.calcite.plan.RelOptRule;
import org.apache.calcite.plan.RelOptRuleCall;
import org.apache.calcite.plan.RelOptRuleOperand;
import org.apache.calcite.rex.RexCall;
import org.apache.calcite.rex.RexFieldAccess;
import org.apache.calcite.rex.RexInputRef;
import org.apache.calcite.rex.RexLocalRef;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.rex.RexProgram;
import org.apache.calcite.rex.RexProgramBuilder;
import org.apache.flink.table.functions.python.PythonFunctionKind;
import org.apache.flink.table.planner.plan.nodes.logical.FlinkLogicalCalc;
import org.apache.flink.table.planner.plan.utils.PythonUtil;

/* loaded from: input_file:flink-table-store-codegen.jar:org/apache/flink/table/planner/plan/rules/logical/PythonMapMergeRule.class */
/**
 * Planner rule that collapses a stack of three {@link FlinkLogicalCalc} nodes into a single one.
 *
 * <p>The rule fires only when all of the following hold:
 *
 * <ul>
 *   <li>the top calc has exactly one projection, which is a Python call for which {@link
 *       PythonUtil#takesRowAsInput} holds;
 *   <li>the bottom calc has exactly one projection, which is a Python call;
 *   <li>the top and bottom calls are either both of kind {@link PythonFunctionKind#GENERAL} or
 *       both not;
 *   <li>none of the three calcs has a condition;
 *   <li>the middle calc only flattens the struct held in its first input column (projection
 *       {@code i} is a field access of field {@code i} on input {@code 0});
 *   <li>the operands of the top call are exactly the input references {@code 0..n-1}, i.e. every
 *       output of the middle calc, in order.
 * </ul>
 *
 * <p>When it fires, the top call is rewritten to take input {@code 0} of the bottom calc as its
 * single operand, and the resulting program is merged with the bottom calc's program.
 */
public class PythonMapMergeRule extends RelOptRule {

    public static final PythonMapMergeRule INSTANCE = new PythonMapMergeRule();

    private PythonMapMergeRule() {
        super(
                operand(
                        FlinkLogicalCalc.class,
                        operand(FlinkLogicalCalc.class, operand(FlinkLogicalCalc.class, none()))),
                "PythonMapMergeRule");
    }

    /**
     * Decides whether the matched calc stack has the shape described on the class.
     *
     * @param call the rule call; {@code rel(0)}, {@code rel(1)} and {@code rel(2)} are the top,
     *     middle and bottom calc respectively
     * @return {@code true} if the three calcs can be merged by {@link #onMatch}
     */
    @Override
    public boolean matches(RelOptRuleCall call) {
        FlinkLogicalCalc topCalc = call.rel(0);
        FlinkLogicalCalc middleCalc = call.rel(1);
        FlinkLogicalCalc bottomCalc = call.rel(2);

        RexProgram topProgram = topCalc.getProgram();
        List<RexNode> topProjects = expandProjects(topProgram);
        if (topProjects.size() != 1
                || !PythonUtil.isPythonCall(topProjects.get(0), null)
                || !PythonUtil.takesRowAsInput((RexCall) topProjects.get(0))) {
            return false;
        }

        RexProgram bottomProgram = bottomCalc.getProgram();
        List<RexNode> bottomProjects = expandProjects(bottomProgram);
        if (bottomProjects.size() != 1 || !PythonUtil.isPythonCall(bottomProjects.get(0), null)) {
            return false;
        }

        // The two calls must agree on being GENERAL: either both are, or neither is.
        boolean topIsGeneral =
                PythonUtil.isPythonCall(topProjects.get(0), PythonFunctionKind.GENERAL);
        boolean bottomIsGeneral =
                PythonUtil.isPythonCall(bottomProjects.get(0), PythonFunctionKind.GENERAL);
        if (topIsGeneral != bottomIsGeneral) {
            return false;
        }

        RexProgram middleProgram = middleCalc.getProgram();
        if (topProgram.getCondition() != null
                || middleProgram.getCondition() != null
                || bottomProgram.getCondition() != null) {
            return false;
        }

        List<RexNode> middleProjects = expandProjects(middleProgram);

        // Decline the match instead of failing when the first input column of the middle calc
        // is not a struct: a non-struct type has no field list to flatten.
        int nestedFieldCount = nestedFieldCountOfFirstInputColumn(middleProgram);
        if (nestedFieldCount < 0) {
            return false;
        }

        return isFlattenCalc(middleProjects, nestedFieldCount)
                && isTopCalcTakesWholeMiddleCalcAsInputs(
                        (RexCall) topProjects.get(0), middleProjects.size());
    }

    /**
     * Returns the projections of {@code program} with every local reference expanded into the
     * expression it refers to.
     */
    private static List<RexNode> expandProjects(RexProgram program) {
        return program.getProjectList().stream()
                .map(program::expandLocalRef)
                .collect(Collectors.toList());
    }

    /**
     * Returns the number of fields nested in the first input column of {@code program}, or
     * {@code -1} when that column is not a struct type.
     */
    private static int nestedFieldCountOfFirstInputColumn(RexProgram program) {
        if (!program.getInputRowType().getFieldList().get(0).getType().isStruct()) {
            return -1;
        }
        return program.getInputRowType().getFieldList().get(0).getType().getFieldList().size();
    }

    /**
     * Checks that {@code topCall} has exactly {@code middleProjectCount} operands and that operand
     * {@code i} is a plain reference to input {@code i}.
     */
    private static boolean isTopCalcTakesWholeMiddleCalcAsInputs(
            RexCall topCall, int middleProjectCount) {
        List<RexNode> operands = topCall.getOperands();
        if (operands.size() != middleProjectCount) {
            return false;
        }
        for (int index = 0; index < operands.size(); index++) {
            RexNode operand = operands.get(index);
            if (!(operand instanceof RexInputRef)
                    || ((RexInputRef) operand).getIndex() != index) {
                return false;
            }
        }
        return true;
    }

    /**
     * Checks that {@code projects} has exactly {@code nestedFieldCount} entries and that entry
     * {@code i} is a field access of field {@code i} on input reference {@code 0}.
     */
    private static boolean isFlattenCalc(List<RexNode> projects, int nestedFieldCount) {
        if (nestedFieldCount != projects.size()) {
            return false;
        }
        for (int index = 0; index < nestedFieldCount; index++) {
            RexNode project = projects.get(index);
            if (!(project instanceof RexFieldAccess)) {
                return false;
            }
            RexFieldAccess fieldAccess = (RexFieldAccess) project;
            if (fieldAccess.getField().getIndex() != index) {
                return false;
            }
            RexNode accessed = fieldAccess.getReferenceExpr();
            if (!(accessed instanceof RexInputRef) || ((RexInputRef) accessed).getIndex() != 0) {
                return false;
            }
        }
        return true;
    }

    /**
     * Replaces the matched calc stack with a single calc.
     *
     * @param call the rule call; {@code rel(0)}, {@code rel(1)} and {@code rel(2)} are the top,
     *     middle and bottom calc respectively
     */
    @Override
    public void onMatch(RelOptRuleCall call) {
        FlinkLogicalCalc topCalc = call.rel(0);
        FlinkLogicalCalc middleCalc = call.rel(1);
        FlinkLogicalCalc bottomCalc = call.rel(2);

        RexCall topPythonCall = (RexCall) expandProjects(topCalc.getProgram()).get(0);

        // Fold the top and middle calc together: the top call now reads input 0 of the bottom
        // calc directly instead of the columns produced by the middle calc.
        RexCall rewrittenCall =
                topPythonCall.clone(
                        topPythonCall.getType(),
                        Collections.singletonList(RexInputRef.of(0, bottomCalc.getRowType())));
        FlinkLogicalCalc topMiddleMergedCalc =
                new FlinkLogicalCalc(
                        middleCalc.getCluster(),
                        middleCalc.getTraitSet(),
                        bottomCalc,
                        RexProgram.create(
                                bottomCalc.getRowType(),
                                Collections.singletonList(rewrittenCall),
                                null,
                                Collections.singletonList("f0"),
                                call.builder().getRexBuilder()));

        // Fold in the bottom calc as well, so the final calc sits on the bottom calc's input.
        RexProgram mergedProgram =
                RexProgramBuilder.mergePrograms(
                        topMiddleMergedCalc.getProgram(),
                        bottomCalc.getProgram(),
                        call.builder().getRexBuilder());
        call.transformTo(
                topMiddleMergedCalc.copy(
                        topMiddleMergedCalc.getTraitSet(), bottomCalc.getInput(), mergedProgram));
    }
}
