package com.linkedin.coral.spark;

import com.linkedin.coral.com.google.common.collect.ImmutableList;
import com.linkedin.coral.hive.hive2rel.functions.GenericProjectFunction;
import com.linkedin.coral.hive.hive2rel.functions.HiveNamedStructFunction;
import com.linkedin.coral.hive.hive2rel.functions.VersionedSqlUserDefinedFunction;
import com.linkedin.coral.spark.containers.SparkRelInfo;
import com.linkedin.coral.spark.containers.SparkUDFInfo;
import com.linkedin.coral.spark.exceptions.UnsupportedUDFException;
import com.linkedin.coral.spark.utils.RelDataTypeToSparkDataTypeStringConverter;
import java.math.BigDecimal;
import java.net.URI;
import java.util.ArrayList;
import java.util.List;
import java.util.Optional;
import java.util.stream.Collectors;
import org.apache.calcite.rel.RelNode;
import org.apache.calcite.rel.RelShuttleImpl;
import org.apache.calcite.rel.core.TableFunctionScan;
import org.apache.calcite.rel.core.TableScan;
import org.apache.calcite.rel.logical.LogicalAggregate;
import org.apache.calcite.rel.logical.LogicalCorrelate;
import org.apache.calcite.rel.logical.LogicalExchange;
import org.apache.calcite.rel.logical.LogicalFilter;
import org.apache.calcite.rel.logical.LogicalIntersect;
import org.apache.calcite.rel.logical.LogicalJoin;
import org.apache.calcite.rel.logical.LogicalMatch;
import org.apache.calcite.rel.logical.LogicalMinus;
import org.apache.calcite.rel.logical.LogicalProject;
import org.apache.calcite.rel.logical.LogicalSort;
import org.apache.calcite.rel.logical.LogicalUnion;
import org.apache.calcite.rel.logical.LogicalValues;
import org.apache.calcite.rel.type.RelDataType;
import org.apache.calcite.rel.type.RelDataTypeField;
import org.apache.calcite.rel.type.RelRecordType;
import org.apache.calcite.rex.RexBuilder;
import org.apache.calcite.rex.RexCall;
import org.apache.calcite.rex.RexLiteral;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.rex.RexShuttle;
import org.apache.calcite.schema.Function;
import org.apache.calcite.sql.SqlIdentifier;
import org.apache.calcite.sql.SqlKind;
import org.apache.calcite.sql.SqlOperator;
import org.apache.calcite.sql.fun.SqlStdOperatorTable;
import org.apache.calcite.sql.parser.SqlParserPos;
import org.apache.calcite.sql.type.ArraySqlType;
import org.apache.calcite.sql.type.SqlOperandTypeChecker;
import org.apache.calcite.sql.type.SqlOperandTypeInference;
import org.apache.calcite.sql.type.SqlReturnTypeInference;
import org.apache.calcite.sql.type.SqlTypeName;
import org.apache.calcite.sql.validate.SqlUserDefinedFunction;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Rewrites a {@link RelNode} tree by running {@link SparkRexConverter} over the row expressions of
 * every node, and collects the {@link SparkUDFInfo} entries recorded during that rewrite.
 *
 * <p>This is a static utility class and cannot be instantiated.
 */
class IRRelToSparkRelTransformer {

    private IRRelToSparkRelTransformer() {
    }

    /**
     * Walks {@code relNode} with a {@link RelShuttleImpl}. For each visited node the children are
     * processed first (via {@code super.visit}), then the node's own expressions are rewritten by a
     * {@link SparkRexConverter} created from the rex builder of the original (pre-visit) node's cluster.
     *
     * @param relNode root of the relational tree to transform
     * @return a {@link SparkRelInfo} holding the rewritten tree and the list of UDFs recorded
     *         while rewriting it
     */
    public static SparkRelInfo transform(RelNode relNode) {
        // Shared, mutable accumulator: every SparkRexConverter created below appends to this list.
        final List<SparkUDFInfo> sparkUDFInfos = new ArrayList<>();

        RelShuttleImpl relConverter = new RelShuttleImpl() {
            @Override
            public RelNode visit(LogicalProject project) {
                return convertExpressions(project, super.visit(project));
            }

            @Override
            public RelNode visit(LogicalFilter filter) {
                return convertExpressions(filter, super.visit(filter));
            }

            @Override
            public RelNode visit(LogicalAggregate aggregate) {
                return convertExpressions(aggregate, super.visit(aggregate));
            }

            @Override
            public RelNode visit(LogicalMatch match) {
                return convertExpressions(match, super.visit(match));
            }

            @Override
            public RelNode visit(TableScan scan) {
                return convertExpressions(scan, super.visit(scan));
            }

            @Override
            public RelNode visit(TableFunctionScan scan) {
                return convertExpressions(scan, super.visit(scan));
            }

            @Override
            public RelNode visit(LogicalValues values) {
                return convertExpressions(values, super.visit(values));
            }

            @Override
            public RelNode visit(LogicalJoin join) {
                return convertExpressions(join, super.visit(join));
            }

            @Override
            public RelNode visit(LogicalCorrelate correlate) {
                return convertExpressions(correlate, super.visit(correlate));
            }

            @Override
            public RelNode visit(LogicalUnion union) {
                return convertExpressions(union, super.visit(union));
            }

            @Override
            public RelNode visit(LogicalIntersect intersect) {
                return convertExpressions(intersect, super.visit(intersect));
            }

            @Override
            public RelNode visit(LogicalMinus minus) {
                return convertExpressions(minus, super.visit(minus));
            }

            @Override
            public RelNode visit(LogicalSort sort) {
                return convertExpressions(sort, super.visit(sort));
            }

            @Override
            public RelNode visit(LogicalExchange exchange) {
                return convertExpressions(exchange, super.visit(exchange));
            }

            @Override
            public RelNode visit(RelNode other) {
                return convertExpressions(other, super.visit(other));
            }

            /** Applies a fresh {@link SparkRexConverter} (built from {@code original}'s cluster) to {@code visited}. */
            private RelNode convertExpressions(RelNode original, RelNode visited) {
                return visited.accept(new SparkRexConverter(original.getCluster().getRexBuilder(), sparkUDFInfos));
            }
        };

        return new SparkRelInfo(relNode.accept(relConverter), sparkUDFInfos);
    }

    /**
     * {@link RexShuttle} that rewrites individual {@link RexCall}s and appends a {@link SparkUDFInfo}
     * to the supplied list for each transportable or Hive-fallback UDF it rewrites.
     */
    public static class SparkRexConverter extends RexShuttle {
        private static final Logger LOG = LoggerFactory.getLogger(SparkRexConverter.class);

        private final RexBuilder rexBuilder;
        private final List<SparkUDFInfo> sparkUDFInfos;

        /**
         * @param rexBuilder builder used to create every replacement expression
         * @param list accumulator that receives the UDFs recorded during conversion; it is stored
         *             by reference and mutated, not copied
         */
        SparkRexConverter(RexBuilder rexBuilder, List<SparkUDFInfo> list) {
            this.sparkUDFInfos = list;
            this.rexBuilder = rexBuilder;
        }

        /**
         * Rewrites a call after its operands have been visited by the superclass.
         *
         * <p>The conversions below are tried in order and the first one that applies wins; later
         * ones are not evaluated (which matters because some of them record UDFs as a side effect).
         * If none applies, the operand-visited call is returned as is.
         *
         * @param rexCall call to rewrite; may be {@code null}
         * @return the rewritten expression, or {@code null} when {@code rexCall} is {@code null}
         * @throws UnsupportedUDFException if the call falls back to a Hive UDF that is listed in
         *         {@code UnsupportedHiveUDFsInSpark}
         */
        @Override
        public RexNode visitCall(RexCall rexCall) {
            if (rexCall == null) {
                return null;
            }
            RexCall visited = (RexCall) super.visitCall(rexCall);

            Optional<RexNode> converted = convertToZeroBasedArrayIndex(visited);
            if (!converted.isPresent()) {
                converted = convertToNamedStruct(visited);
            }
            if (!converted.isPresent()) {
                converted = convertFuzzyUnionGenericProject(visited);
            }
            if (!converted.isPresent()) {
                converted = convertDaliUDF(visited);
            }
            if (!converted.isPresent()) {
                converted = convertBuiltInUDF(visited);
            }
            if (!converted.isPresent()) {
                converted = fallbackToHiveUdf(visited);
            }
            if (!converted.isPresent()) {
                converted = removeExtractUnionFunction(visited);
            }
            return converted.orElse(visited);
        }

        /**
         * Former name of {@link #visitCall(RexCall)}; kept only so existing direct callers keep
         * compiling. Under this name the method did not override {@link RexShuttle#visitCall}, so
         * the shuttle never invoked it.
         *
         * @deprecated use {@link #visitCall(RexCall)}
         */
        @Deprecated
        public RexNode m2visitCall(RexCall rexCall) {
            return visitCall(rexCall);
        }

        /**
         * If the operator name is found in {@code TransportableUDFMap}, records a
         * {@link SparkUDFInfo} under the operator's view-dependent function name and returns a call
         * to a UDF with that name.
         *
         * @throws ClassCastException if the matched operator is not a
         *         {@link VersionedSqlUserDefinedFunction}
         */
        private Optional<RexNode> convertDaliUDF(RexCall rexCall) {
            Optional<SparkUDFInfo> lookup = TransportableUDFMap.lookup(rexCall.getOperator().getName());
            if (!lookup.isPresent()) {
                return Optional.empty();
            }
            // getViewDependentFunctionName() is declared on the versioned UDF type, not on SqlOperator.
            VersionedSqlUserDefinedFunction versionedFunction = (VersionedSqlUserDefinedFunction) rexCall.getOperator();
            String viewDependentFunctionName = versionedFunction.getViewDependentFunctionName();
            SparkUDFInfo template = lookup.get();
            this.sparkUDFInfos.add(new SparkUDFInfo(template.getClassName(), viewDependentFunctionName,
                    template.getArtifactoryUrls(), template.getUdfType()));
            SqlOperator sparkUdf = createUDF(viewDependentFunctionName, rexCall.getOperator().getReturnTypeInference());
            return Optional.of(this.rexBuilder.makeCall(sparkUdf, rexCall.getOperands()));
        }

        /**
         * If the operator name is found in {@code BuiltinUDFMap}, returns a call to a UDF with the
         * mapped name, keeping the original operands and return-type inference.
         */
        private Optional<RexNode> convertBuiltInUDF(RexCall rexCall) {
            return BuiltinUDFMap.lookup(rexCall.getOperator().getName())
                    .map(sparkName -> this.rexBuilder.makeCall(
                            createUDF(sparkName, rexCall.getOperator().getReturnTypeInference()),
                            rexCall.getOperands()));
        }

        /**
         * Handles operators whose name contains a {@code '.'}: records the operator as a
         * {@code HIVE_CUSTOM_UDF} together with its ivy dependencies and returns a call to a UDF
         * named after its view-dependent function name. Operators without a dot in their name are
         * left alone.
         *
         * @throws UnsupportedUDFException if the name is listed in {@code UnsupportedHiveUDFsInSpark}
         * @throws ClassCastException if a dotted-name operator is not a
         *         {@link VersionedSqlUserDefinedFunction}
         */
        private Optional<RexNode> fallbackToHiveUdf(RexCall rexCall) {
            SqlOperator operator = rexCall.getOperator();
            String name = operator.getName();
            if (name.indexOf('.') < 0) {
                return Optional.empty();
            }
            if (UnsupportedHiveUDFsInSpark.contains(name).booleanValue()) {
                throw new UnsupportedUDFException(name);
            }
            // Cast only here: operators without a dotted name are not expected to be versioned UDFs.
            VersionedSqlUserDefinedFunction versionedFunction = (VersionedSqlUserDefinedFunction) operator;
            String viewDependentFunctionName = versionedFunction.getViewDependentFunctionName();
            List<String> ivyDependencies = versionedFunction.getIvyDependencies();
            List<URI> dependencyUris = ivyDependencies.stream().map(URI::create).collect(Collectors.toList());
            SparkUDFInfo hiveUdfInfo = new SparkUDFInfo(name, viewDependentFunctionName, dependencyUris,
                    SparkUDFInfo.UDFTYPE.HIVE_CUSTOM_UDF);
            LOG.info("Function: {} is not a Builtin UDF or Transportable UDF.  We fall back to its Hive function with ivy dependency: {}", name, String.join(",", ivyDependencies));
            this.sparkUDFInfos.add(hiveUdfInfo);
            SqlOperator sparkUdf = createUDF(hiveUdfInfo.getFunctionName(), rexCall.getOperator().getReturnTypeInference());
            return Optional.of(this.rexBuilder.makeCall(sparkUdf, rexCall.getOperands()));
        }

        /**
         * For an {@code ITEM} call whose first operand has an array type, returns the same call
         * with the index decreased by one. An {@code INTEGER} literal index is replaced by a new
         * literal; any other index expression is wrapped in {@code index - 1}.
         */
        private Optional<RexNode> convertToZeroBasedArrayIndex(RexCall rexCall) {
            if (!rexCall.getOperator().equals(SqlStdOperatorTable.ITEM)) {
                return Optional.empty();
            }
            RexNode collection = rexCall.getOperands().get(0);
            RexNode index = rexCall.getOperands().get(1);
            if (!(collection.getType() instanceof ArraySqlType)) {
                return Optional.empty();
            }
            if (index.isA(SqlKind.LITERAL) && index.getType().getSqlTypeName().equals(SqlTypeName.INTEGER)) {
                int oneBasedIndex = ((RexLiteral) index).getValueAs(Integer.class).intValue();
                RexNode shiftedLiteral = this.rexBuilder.makeExactLiteral(new BigDecimal(oneBasedIndex - 1), index.getType());
                return Optional.of(this.rexBuilder.makeCall(rexCall.op, collection, shiftedLiteral));
            }
            RexNode shiftedIndex = this.rexBuilder.makeCall(SqlStdOperatorTable.MINUS, index,
                    this.rexBuilder.makeExactLiteral(BigDecimal.ONE));
            return Optional.of(this.rexBuilder.makeCall(rexCall.op, collection, shiftedIndex));
        }

        /**
         * For {@code CAST(ROW(...) AS <record type>)}, returns a {@link HiveNamedStructFunction}
         * call whose operands alternate between a field-name literal and the corresponding row
         * operand cast to that field's type.
         *
         * @throws ClassCastException if the cast target type is not a {@link RelRecordType}
         */
        private Optional<RexNode> convertToNamedStruct(RexCall rexCall) {
            if (!rexCall.getOperator().equals(SqlStdOperatorTable.CAST)) {
                return Optional.empty();
            }
            RexNode castOperand = rexCall.getOperands().get(0);
            if (!(castOperand instanceof RexCall)
                    || !((RexCall) castOperand).getOperator().equals(SqlStdOperatorTable.ROW)) {
                return Optional.empty();
            }
            RelRecordType recordType = (RelRecordType) rexCall.getType();
            List<RexNode> rowOperands = ((RexCall) castOperand).getOperands();
            int fieldCount = recordType.getFieldCount();
            // Two operands per field: the name literal followed by the casted value.
            List<RexNode> namedStructOperands = new ArrayList<>(fieldCount * 2);
            for (int i = 0; i < fieldCount; i++) {
                RelDataTypeField field = recordType.getFieldList().get(i);
                namedStructOperands.add(this.rexBuilder.makeLiteral(field.getKey()));
                namedStructOperands.add(this.rexBuilder.makeCast(field.getType(), rowOperands.get(i)));
            }
            return Optional.of(this.rexBuilder.makeCall(rexCall.getType(), new HiveNamedStructFunction(), namedStructOperands));
        }

        /**
         * For a {@link GenericProjectFunction} call, returns a new {@link GenericProjectFunction}
         * call that keeps the first operand and adds, as a second operand, a string literal
         * holding the call's return type as rendered by
         * {@link RelDataTypeToSparkDataTypeStringConverter}.
         */
        private Optional<RexNode> convertFuzzyUnionGenericProject(RexCall rexCall) {
            if (!(rexCall.getOperator() instanceof GenericProjectFunction)) {
                return Optional.empty();
            }
            RelDataType returnType = rexCall.getType();
            String sparkTypeString = RelDataTypeToSparkDataTypeStringConverter.convertRelDataType(returnType);
            List<RexNode> operands = new ArrayList<>();
            operands.add(rexCall.getOperands().get(0));
            operands.add(this.rexBuilder.makeLiteral(sparkTypeString));
            return Optional.of(this.rexBuilder.makeCall(returnType, new GenericProjectFunction(returnType), operands));
        }

        /**
         * Strips a call to {@code extract_union} (name matched case-insensitively): with one
         * operand the operand itself is returned; with two operands a field access on the first
         * operand is returned, using the integer value of the second (literal) operand as the
         * field index. Any other operand count is left untouched.
         */
        private Optional<RexNode> removeExtractUnionFunction(RexCall rexCall) {
            if (!rexCall.getOperator().getName().equalsIgnoreCase("extract_union")) {
                return Optional.empty();
            }
            List<RexNode> operands = rexCall.getOperands();
            if (operands.size() == 1) {
                return Optional.of(operands.get(0));
            }
            if (operands.size() == 2) {
                int fieldIndex = ((RexLiteral) operands.get(1)).getValueAs(Integer.class).intValue();
                return Optional.of(this.rexBuilder.makeFieldAccess(operands.get(0), fieldIndex));
            }
            return Optional.empty();
        }

        /**
         * Creates a {@link SqlUserDefinedFunction} with the given single-part name and return-type
         * inference; operand inference, operand checker, parameter types and the backing function
         * are all passed as {@code null}.
         */
        private static SqlOperator createUDF(String str, SqlReturnTypeInference sqlReturnTypeInference) {
            SqlIdentifier udfName = new SqlIdentifier(ImmutableList.of(str), SqlParserPos.ZERO);
            return new SqlUserDefinedFunction(udfName, sqlReturnTypeInference, (SqlOperandTypeInference) null,
                    (SqlOperandTypeChecker) null, (List<RelDataType>) null, (Function) null);
        }
    }
}
