package org.apache.iceberg.spark.extensions;

import java.math.BigDecimal;
import java.sql.Timestamp;
import java.time.Instant;
import java.util.Iterator;
import org.apache.iceberg.relocated.com.google.common.collect.Lists;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.catalyst.expressions.Expression;
import org.apache.spark.sql.catalyst.expressions.Literal;
import org.apache.spark.sql.catalyst.expressions.Literal$;
import org.apache.spark.sql.catalyst.parser.ParseException;
import org.apache.spark.sql.catalyst.parser.ParserInterface;
import org.apache.spark.sql.catalyst.parser.extensions.IcebergParseException;
import org.apache.spark.sql.catalyst.plans.logical.CallArgument;
import org.apache.spark.sql.catalyst.plans.logical.CallStatement;
import org.apache.spark.sql.catalyst.plans.logical.NamedArgument;
import org.apache.spark.sql.catalyst.plans.logical.PositionalArgument;
import org.apache.spark.sql.types.DataType;
import org.apache.spark.sql.types.DataTypes;
import org.assertj.core.api.Assertions;
import org.junit.jupiter.api.AfterAll;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.Test;
import scala.collection.JavaConverters;

/**
 * Tests for parsing {@code CALL} statements with the SQL parser installed by {@link
 * IcebergSparkSessionExtensions}.
 *
 * <p>Each test parses a statement and checks the resulting {@link CallStatement}'s procedure name
 * and its arguments (positional or named) against expected Spark literals.
 */
public class TestCallStatementParser {
    private static SparkSession spark = null;
    private static ParserInterface parser = null;

    /** Starts a local Spark session with the Iceberg extensions and captures its SQL parser. */
    @BeforeAll
    public static void startSpark() {
        spark =
            SparkSession.builder()
                .master("local[2]")
                .config("spark.sql.extensions", IcebergSparkSessionExtensions.class.getName())
                // Referenced as ${spark.extra.prop} by the variable-substitution tests below.
                .config("spark.extra.prop", "value")
                .getOrCreate();
        parser = spark.sessionState().sqlParser();
    }

    /**
     * Stops the shared Spark session and clears the static state.
     *
     * <p>JUnit runs {@code @AfterAll} even when {@code @BeforeAll} fails, so the session may be
     * {@code null} here. The guard keeps a {@link NullPointerException} from masking the original
     * setup failure.
     */
    @AfterAll
    public static void stopSpark() {
        SparkSession currentSpark = spark;
        spark = null;
        parser = null;
        if (currentSpark != null) {
            currentSpark.stop();
        }
    }

    @Test
    public void testCallWithPositionalArgs() throws ParseException {
        CallStatement call = parseCall("CALL c.n.func(1, '2', 3L, true, 1.0D, 9.0e1, 900e-1BD)");
        assertCall(call, 7, "c", "n", "func");
        checkArg(call, 0, 1, DataTypes.IntegerType);
        checkArg(call, 1, "2", DataTypes.StringType);
        checkArg(call, 2, 3L, DataTypes.LongType);
        checkArg(call, 3, true, DataTypes.BooleanType);
        checkArg(call, 4, 1.0D, DataTypes.DoubleType);
        checkArg(call, 5, 90.0D, DataTypes.DoubleType);
        checkArg(call, 6, new BigDecimal("900e-1"), DataTypes.createDecimalType(3, 1));
    }

    @Test
    public void testCallWithNamedArgs() throws ParseException {
        CallStatement call = parseCall("CALL cat.system.func(c1 => 1, c2 => '2', c3 => true)");
        assertCall(call, 3, "cat", "system", "func");
        checkArg(call, 0, "c1", 1, DataTypes.IntegerType);
        checkArg(call, 1, "c2", "2", DataTypes.StringType);
        checkArg(call, 2, "c3", true, DataTypes.BooleanType);
    }

    @Test
    public void testCallWithMixedArgs() throws ParseException {
        CallStatement call = parseCall("CALL cat.system.func(c1 => 1, '2')");
        assertCall(call, 2, "cat", "system", "func");
        checkArg(call, 0, "c1", 1, DataTypes.IntegerType);
        checkArg(call, 1, "2", DataTypes.StringType);
    }

    @Test
    public void testCallWithTimestampArg() throws ParseException {
        CallStatement call = parseCall("CALL cat.system.func(TIMESTAMP '2017-02-03T10:37:30.00Z')");
        assertCall(call, 1, "cat", "system", "func");
        checkArg(
            call,
            0,
            Timestamp.from(Instant.parse("2017-02-03T10:37:30.00Z")),
            DataTypes.TimestampType);
    }

    @Test
    public void testCallWithVarSubstitution() throws ParseException {
        CallStatement call = parseCall("CALL cat.system.func('${spark.extra.prop}')");
        assertCall(call, 1, "cat", "system", "func");
        checkArg(call, 0, "value", DataTypes.StringType);
    }

    @Test
    public void testCallParseError() {
        Assertions.assertThatThrownBy(() -> parser.parsePlan("CALL cat.system radish kebab"))
            .isInstanceOf(IcebergParseException.class)
            .hasMessageContaining("missing '(' at 'radish'");
    }

    @Test
    public void testCallStripsComments() throws ParseException {
        for (String sqlText :
            Lists.newArrayList(
                "/* bracketed comment */  CALL cat.system.func('${spark.extra.prop}')",
                "/**/  CALL cat.system.func('${spark.extra.prop}')",
                "-- single line comment \n CALL cat.system.func('${spark.extra.prop}')",
                "-- multiple \n-- single line \n-- comments \n CALL cat.system.func('${spark.extra.prop}')",
                "/* select * from multiline_comment \n where x like '%sql%'; */ CALL cat.system.func('${spark.extra.prop}')",
                "/* {\"app\": \"dbt\", \"dbt_version\": \"1.0.1\", \"profile_name\": \"profile1\", \"target_name\": \"dev\", \"node_id\": \"model.profile1.stg_users\"} \n*/ CALL cat.system.func('${spark.extra.prop}')",
                "/* Some multi-line comment \n*/ CALL /* inline comment */ cat.system.func('${spark.extra.prop}') -- ending comment",
                "CALL -- a line ending comment\ncat.system.func('${spark.extra.prop}')")) {
            CallStatement call = parseCall(sqlText);
            assertCall(call, 1, "cat", "system", "func");
            checkArg(call, 0, "value", DataTypes.StringType);
        }
    }

    /** Parses {@code sql} and asserts that the resulting plan is a {@link CallStatement}. */
    private CallStatement parseCall(String sql) throws ParseException {
        return checkCast(parser.parsePlan(sql), CallStatement.class);
    }

    /**
     * Asserts that {@code call} targets exactly {@code expectedName} (multipart identifier) and
     * carries {@code expectedArgCount} arguments.
     */
    private void assertCall(CallStatement call, int expectedArgCount, String... expectedName) {
        Assertions.assertThat(JavaConverters.seqAsJavaList(call.name()))
            .containsExactly(expectedName);
        Assertions.assertThat(JavaConverters.seqAsJavaList(call.args()))
            .hasSize(expectedArgCount);
    }

    /** Asserts that argument {@code index} is positional and equals the given literal. */
    private void checkArg(
        CallStatement call, int index, Object expectedValue, DataType expectedType) {
        checkArg(call, index, null, expectedValue, expectedType);
    }

    /**
     * Asserts the shape and value of argument {@code index}.
     *
     * @param expectedName the expected argument name, or {@code null} if the argument must be
     *     positional
     * @param expectedValue Java value converted to a Spark {@link Literal} for comparison
     * @param expectedType Spark type of the expected literal
     */
    private void checkArg(
        CallStatement call,
        int index,
        String expectedName,
        Object expectedValue,
        DataType expectedType) {
        CallArgument arg = call.args().apply(index);
        if (expectedName != null) {
            NamedArgument namedArg = checkCast(arg, NamedArgument.class);
            Assertions.assertThat(namedArg.name()).isEqualTo(expectedName);
        } else {
            checkCast(arg, PositionalArgument.class);
        }

        Literal expectedLiteral = toSparkLiteral(expectedValue, expectedType);
        Expression actualExpr = arg.expr();
        // Check the type first so a type mismatch gets a more specific failure message.
        Assertions.assertThat(actualExpr.dataType())
            .as("Arg types must match")
            .isEqualTo(expectedLiteral.dataType());
        Assertions.assertThat(actualExpr).as("Arg must match").isEqualTo(expectedLiteral);
    }

    /** Builds a Spark literal from a Java value via the Scala {@code Literal} companion object. */
    private Literal toSparkLiteral(Object value, DataType dataType) {
        return Literal$.MODULE$.create(value, dataType);
    }

    /** Asserts that {@code value} is an instance of {@code expectedClass} and returns it cast. */
    private <T> T checkCast(Object value, Class<T> expectedClass) {
        Assertions.assertThat(value).isInstanceOf(expectedClass);
        return expectedClass.cast(value);
    }
}
