package com.linkedin.coral.spark.transformers;

import com.linkedin.coral.common.transformers.SqlCallTransformer;
import com.linkedin.coral.hive.hive2rel.functions.VersionedSqlUserDefinedFunction;
import com.linkedin.coral.spark.containers.SparkUDFInfo;
import com.linkedin.coral.spark.exceptions.UnsupportedUDFException;
import java.net.URI;
import java.util.Collections;
import java.util.Set;
import org.apache.calcite.sql.SqlCall;
import org.apache.spark.sql.SparkSession;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:com/linkedin/coral/spark/transformers/TransportUDFTransformer.class */
/**
 * {@link SqlCallTransformer} that rewrites a call to a Hive-registered Transport UDF into a call
 * named after the UDF's view-dependent function name, and records the Spark-side UDF class plus
 * the artifact URL needed to load it.
 *
 * <p>The artifact URL is selected by Scala version, which in turn is derived from the version of
 * the active Spark session (see {@link #getScalaVersionOfSpark()}).
 *
 * <p>Instances are stateful: {@link #condition(SqlCall)} stores the detected Scala version in a
 * field that {@link #transform(SqlCall)} reads afterwards, so {@code transform} is expected to be
 * invoked only after {@code condition} returned {@code true} for the same call.
 */
public class TransportUDFTransformer extends SqlCallTransformer {
    /** Ivy URL of the {@code standard-udfs-dali-udfs} 2.0.3 artifact with the {@code spark_2.11} classifier. */
    public static final String DALI_UDFS_IVY_URL_SPARK_2_11 = "ivy://com.linkedin.standard-udfs-dali-udfs:standard-udfs-dali-udfs:2.0.3?classifier=spark_2.11";
    /** Ivy URL of the {@code standard-udfs-dali-udfs} 2.0.3 artifact with the {@code spark_2.12} classifier. */
    public static final String DALI_UDFS_IVY_URL_SPARK_2_12 = "ivy://com.linkedin.standard-udfs-dali-udfs:standard-udfs-dali-udfs:2.0.3?classifier=spark_2.12";

    private static final Logger LOG = LoggerFactory.getLogger(TransportUDFTransformer.class);

    private final String hiveUDFClassName;
    private final String sparkUDFClassName;
    private final String artifactoryUrlSpark211;
    private final String artifactoryUrlSpark212;
    private final Set<SparkUDFInfo> sparkUDFInfos;

    /** Scala version detected by the most recent matching {@link #condition(SqlCall)} call. */
    private ScalaVersion scalaVersion;

    /** Scala binary versions for which a Transport UDF artifact may be available. */
    public enum ScalaVersion {
        SCALA_2_11,
        SCALA_2_12
    }

    /**
     * Creates a transformer for a single Transport UDF.
     *
     * @param hiveUDFClassName name that the operator of a matching call must carry
     *     (compared case-insensitively)
     * @param sparkUDFClassName class name recorded in the {@link SparkUDFInfo} produced for a match
     * @param artifactoryUrlSpark211 artifact URL used for Scala 2.11, or {@code null} if the UDF
     *     is not available for it
     * @param artifactoryUrlSpark212 artifact URL used for Scala 2.12, or {@code null} if the UDF
     *     is not available for it
     * @param sparkUDFInfos set to which {@link #transform(SqlCall)} adds a {@link SparkUDFInfo}
     *     for every call it rewrites
     */
    public TransportUDFTransformer(String hiveUDFClassName, String sparkUDFClassName, String artifactoryUrlSpark211,
            String artifactoryUrlSpark212, Set<SparkUDFInfo> sparkUDFInfos) {
        this.hiveUDFClassName = hiveUDFClassName;
        this.sparkUDFClassName = sparkUDFClassName;
        this.artifactoryUrlSpark211 = artifactoryUrlSpark211;
        this.artifactoryUrlSpark212 = artifactoryUrlSpark212;
        this.sparkUDFInfos = sparkUDFInfos;
    }

    /**
     * Decides whether the given call targets this transformer's UDF.
     *
     * @param sqlCall call to inspect
     * @return {@code true} if the call's operator is a {@link VersionedSqlUserDefinedFunction}
     *     whose name equals the configured Hive UDF class name (ignoring case) and an artifact URL
     *     is configured for the detected Scala version; {@code false} if the operator does not match
     * @throws UnsupportedUDFException if the operator matches but no artifact URL is configured
     *     for the detected Scala version
     */
    protected boolean condition(SqlCall sqlCall) {
        if (!(sqlCall.getOperator() instanceof VersionedSqlUserDefinedFunction)
                || !this.hiveUDFClassName.equalsIgnoreCase(sqlCall.getOperator().getName())) {
            return false;
        }
        // Detect the Scala version only for matching calls: the lookup queries the active Spark
        // session and logs a warning when none is available, which is wasted work (and log noise)
        // for the calls this transformer does not handle.
        this.scalaVersion = getScalaVersionOfSpark();
        if (artifactoryUrlFor(this.scalaVersion) == null) {
            throw new UnsupportedUDFException(String.format(
                    "Transport UDF for class '%s' is not supported for scala %s, please contact the UDF owner for upgrade",
                    this.hiveUDFClassName, this.scalaVersion));
        }
        return true;
    }

    /**
     * Registers the Spark UDF for the call and rewrites the call to use the view-dependent
     * function name.
     *
     * <p>Relies on the Scala version stored by a preceding {@link #condition(SqlCall)} call.
     *
     * @param sqlCall call whose operator is a {@link VersionedSqlUserDefinedFunction}
     * @return a new call with the same parser position and operands, whose operator is named after
     *     the view-dependent function name and keeps the original return type inference
     */
    protected SqlCall transform(SqlCall sqlCall) {
        // getOperator() is declared with the general operator type; condition() has already
        // verified the concrete type, so the downcast is safe on the normal call path.
        VersionedSqlUserDefinedFunction operator = (VersionedSqlUserDefinedFunction) sqlCall.getOperator();
        String viewDependentFunctionName = operator.getViewDependentFunctionName();
        String artifactoryUrl = this.scalaVersion == ScalaVersion.SCALA_2_11
                ? this.artifactoryUrlSpark211
                : this.artifactoryUrlSpark212;
        this.sparkUDFInfos.add(new SparkUDFInfo(this.sparkUDFClassName, viewDependentFunctionName,
                Collections.singletonList(URI.create(artifactoryUrl)), SparkUDFInfo.UDFTYPE.TRANSPORTABLE_UDF));
        return createSqlOperator(viewDependentFunctionName, operator.getReturnTypeInference())
                .createCall(sqlCall.getParserPosition(), sqlCall.getOperandList());
    }

    /**
     * Maps the version string of the active Spark session to a Scala version: versions starting
     * with {@code 2.} map to Scala 2.11 and versions starting with {@code 3.} map to Scala 2.12
     * (the remainder of the version string must consist of digits and dots only).
     *
     * @return the detected Scala version; {@link ScalaVersion#SCALA_2_11} when the version string
     *     is not recognized, or when the lookup fails with an {@link IllegalStateException} or a
     *     {@link NoClassDefFoundError}
     */
    public ScalaVersion getScalaVersionOfSpark() {
        try {
            String version = SparkSession.active().version();
            if (version.matches("2\\.[\\d\\.]*")) {
                return ScalaVersion.SCALA_2_11;
            }
            if (version.matches("3\\.[\\d\\.]*")) {
                return ScalaVersion.SCALA_2_12;
            }
            // Deliberately routed through the catch below so that an unrecognized version takes
            // the same logged fallback path as a failed lookup.
            throw new IllegalStateException(String.format("Unsupported Spark Version %s", version));
        } catch (IllegalStateException | NoClassDefFoundError e) {
            LOG.warn("Couldn't determine Spark version, falling back to scala_2.11: {}", e.getMessage());
            return ScalaVersion.SCALA_2_11;
        }
    }

    /** Returns the artifact URL configured for the given Scala version, or {@code null} if there is none. */
    private String artifactoryUrlFor(ScalaVersion version) {
        if (version == ScalaVersion.SCALA_2_11) {
            return this.artifactoryUrlSpark211;
        }
        if (version == ScalaVersion.SCALA_2_12) {
            return this.artifactoryUrlSpark212;
        }
        return null;
    }
}
