package org.apache.iceberg.spark.extensions;

import java.util.List;
import java.util.Map;
import java.util.function.Predicate;
import org.apache.iceberg.DistributionMode;
import org.apache.iceberg.MetadataColumns;
import org.apache.iceberg.RowLevelOperationMode;
import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList;
import org.apache.iceberg.spark.SparkCatalogConfig;
import org.apache.spark.sql.Encoders;
import org.apache.spark.sql.catalyst.expressions.ApplyFunctionExpression;
import org.apache.spark.sql.catalyst.expressions.Expression;
import org.apache.spark.sql.catalyst.expressions.objects.StaticInvoke;
import org.apache.spark.sql.execution.CommandResultExec;
import org.apache.spark.sql.execution.datasources.v2.V2TableWriteExec;
import org.assertj.core.api.Assertions;
import org.junit.After;
import org.junit.Before;
import org.junit.Test;
import org.junit.runners.Parameterized;

/**
 * Tests row-level DELETE, UPDATE and MERGE commands whose conditions call Iceberg system functions
 * ({@code system.bucket}, {@code system.years}, {@code system.months}, {@code system.days},
 * {@code system.hours}, {@code system.truncate}), in both copy-on-write and merge-on-read modes.
 *
 * <p>Each check first collects the locations of data files that contain no rows matching the
 * condition. It then runs the command inside {@code withUnavailableLocations} with those locations.
 * NOTE(review): presumably that base-class helper makes the files unreadable, so the command only
 * succeeds if the condition prunes them; confirm against the base class. Finally, the check
 * asserts how many function-call expressions ({@link StaticInvoke} /
 * {@link ApplyFunctionExpression}) remain in the executed write plan, and that the table contents
 * reflect the command.
 */
public class TestSystemFunctionPushDownInRowLevelOperations extends SparkExtensionsTestBase {

    /** Name of the auxiliary table holding the ids targeted by each row-level command. */
    private static final String CHANGES_TABLE_NAME = "changes";

    /**
     * Supplies the parameter sets for the JUnit {@code Parameterized} runner.
     *
     * @return a single row of {catalog name, implementation, properties} for the Hive catalog
     */
    @Parameterized.Parameters(name = "catalogName = {0}, implementation = {1}, config = {2}")
    public static Object[][] parameters() {
        // Must be a two-dimensional array: one inner array per parameter set.
        return new Object[][] {
            {
                SparkCatalogConfig.HIVE.catalogName(),
                SparkCatalogConfig.HIVE.implementation(),
                SparkCatalogConfig.HIVE.properties()
            }
        };
    }

    public TestSystemFunctionPushDownInRowLevelOperations(
            String catalogName, String implementation, Map<String, String> config) {
        super(catalogName, implementation, config);
    }

    /** Switches the session to the catalog under test so unqualified {@code system.*} calls resolve. */
    @Before
    public void beforeEach() {
        sql("USE %s", catalogName);
    }

    /** Drops the table under test and the auxiliary changes table. */
    @After
    public void removeTables() {
        sql("DROP TABLE IF EXISTS %s PURGE", tableName);
        sql("DROP TABLE IF EXISTS %s PURGE", tableName(CHANGES_TABLE_NAME));
    }

    @Test
    public void testCopyOnWriteDeleteBucketTransformInPredicate() {
        initTable("bucket(4, dep)");
        checkDelete(RowLevelOperationMode.COPY_ON_WRITE, "system.bucket(4, dep) IN (2, 3)");
    }

    @Test
    public void testMergeOnReadDeleteBucketTransformInPredicate() {
        initTable("bucket(4, dep)");
        checkDelete(RowLevelOperationMode.MERGE_ON_READ, "system.bucket(4, dep) IN (2, 3)");
    }

    @Test
    public void testCopyOnWriteDeleteBucketTransformEqPredicate() {
        initTable("bucket(4, dep)");
        checkDelete(RowLevelOperationMode.COPY_ON_WRITE, "system.bucket(4, dep) = 2");
    }

    @Test
    public void testMergeOnReadDeleteBucketTransformEqPredicate() {
        initTable("bucket(4, dep)");
        checkDelete(RowLevelOperationMode.MERGE_ON_READ, "system.bucket(4, dep) = 2");
    }

    @Test
    public void testCopyOnWriteDeleteYearsTransform() {
        initTable("years(ts)");
        checkDelete(RowLevelOperationMode.COPY_ON_WRITE, "system.years(ts) > 30");
    }

    @Test
    public void testMergeOnReadDeleteYearsTransform() {
        initTable("years(ts)");
        checkDelete(RowLevelOperationMode.MERGE_ON_READ, "system.years(ts) <= 30");
    }

    @Test
    public void testCopyOnWriteDeleteMonthsTransform() {
        initTable("months(ts)");
        checkDelete(RowLevelOperationMode.COPY_ON_WRITE, "system.months(ts) <= 250");
    }

    @Test
    public void testMergeOnReadDeleteMonthsTransform() {
        initTable("months(ts)");
        checkDelete(RowLevelOperationMode.MERGE_ON_READ, "system.months(ts) > 250");
    }

    @Test
    public void testCopyOnWriteDeleteDaysTransform() {
        initTable("days(ts)");
        checkDelete(
                RowLevelOperationMode.COPY_ON_WRITE,
                "system.days(ts) <= date('2000-01-03 00:00:00')");
    }

    @Test
    public void testMergeOnReadDeleteDaysTransform() {
        initTable("days(ts)");
        checkDelete(
                RowLevelOperationMode.MERGE_ON_READ,
                "system.days(ts) > date('2000-01-03 00:00:00')");
    }

    @Test
    public void testCopyOnWriteDeleteHoursTransform() {
        initTable("hours(ts)");
        checkDelete(RowLevelOperationMode.COPY_ON_WRITE, "system.hours(ts) <= 100000");
    }

    @Test
    public void testMergeOnReadDeleteHoursTransform() {
        initTable("hours(ts)");
        checkDelete(RowLevelOperationMode.MERGE_ON_READ, "system.hours(ts) > 100000");
    }

    @Test
    public void testCopyOnWriteDeleteTruncateTransform() {
        initTable("truncate(1, dep)");
        checkDelete(RowLevelOperationMode.COPY_ON_WRITE, "system.truncate(1, dep) = 'i'");
    }

    @Test
    public void testMergeOnReadDeleteTruncateTransform() {
        initTable("truncate(1, dep)");
        checkDelete(RowLevelOperationMode.MERGE_ON_READ, "system.truncate(1, dep) = 'i'");
    }

    @Test
    public void testCopyOnWriteUpdateBucketTransform() {
        initTable("bucket(4, dep)");
        checkUpdate(RowLevelOperationMode.COPY_ON_WRITE, "system.bucket(4, dep) IN (2, 3)");
    }

    @Test
    public void testMergeOnReadUpdateBucketTransform() {
        initTable("bucket(4, dep)");
        checkUpdate(RowLevelOperationMode.MERGE_ON_READ, "system.bucket(4, dep) = 2");
    }

    @Test
    public void testCopyOnWriteUpdateYearsTransform() {
        initTable("years(ts)");
        checkUpdate(RowLevelOperationMode.COPY_ON_WRITE, "system.years(ts) > 30");
    }

    @Test
    public void testMergeOnReadUpdateYearsTransform() {
        initTable("years(ts)");
        checkUpdate(RowLevelOperationMode.MERGE_ON_READ, "system.years(ts) <= 30");
    }

    @Test
    public void testCopyOnWriteMergeBucketTransform() {
        initTable("bucket(4, dep)");
        checkMerge(RowLevelOperationMode.COPY_ON_WRITE, "system.bucket(4, dep) IN (2, 3)");
    }

    @Test
    public void testMergeOnReadMergeBucketTransform() {
        initTable("bucket(4, dep)");
        checkMerge(RowLevelOperationMode.MERGE_ON_READ, "system.bucket(4, dep) = 2");
    }

    @Test
    public void testCopyOnWriteMergeYearsTransform() {
        initTable("years(ts)");
        checkMerge(RowLevelOperationMode.COPY_ON_WRITE, "system.years(ts) > 30");
    }

    @Test
    public void testMergeOnReadMergeYearsTransform() {
        initTable("years(ts)");
        checkMerge(RowLevelOperationMode.MERGE_ON_READ, "system.years(ts) <= 30");
    }

    @Test
    public void testCopyOnWriteMergeTruncateTransform() {
        initTable("truncate(1, dep)");
        checkMerge(RowLevelOperationMode.COPY_ON_WRITE, "system.truncate(1, dep) = 'i'");
    }

    @Test
    public void testMergeOnReadMergeTruncateTransform() {
        initTable("truncate(1, dep)");
        checkMerge(RowLevelOperationMode.MERGE_ON_READ, "system.truncate(1, dep) = 'i'");
    }

    /**
     * Runs a DELETE restricted by {@code cond}. It asserts the function-call count in the write
     * plan and that the targeted rows are gone.
     *
     * @param mode row-level operation mode written to the delete-mode table property
     * @param cond SQL condition containing a {@code system.*} function call
     */
    private void checkDelete(RowLevelOperationMode mode, String cond) {
        withUnavailableLocations(findIrrelevantFileLocations(cond), () -> {
            setRowLevelWriteMode("write.delete.mode", "write.delete.distribution-mode", mode);

            // Up to two matching ids become the DELETE targets.
            spark.table(tableName).where(cond).limit(2).select("id")
                    .coalesce(1)
                    .writeTo(tableName(CHANGES_TABLE_NAME))
                    .create();

            List<Expression> calls = executeAndCollectFunctionCalls(
                    "DELETE FROM %s t WHERE %s AND t.id IN (SELECT id FROM %s)",
                    tableName, cond, tableName(CHANGES_TABLE_NAME));
            // Copy-on-write plans are expected to keep one function call; merge-on-read none.
            int expectedCallCount = mode == RowLevelOperationMode.COPY_ON_WRITE ? 1 : 0;
            Assertions.assertThat(calls).hasSize(expectedCallCount);

            assertEquals(
                    "Should have no matching rows",
                    ImmutableList.of(),
                    sql("SELECT * FROM %s WHERE %s AND id IN (SELECT * FROM %s)",
                            tableName, cond, tableName(CHANGES_TABLE_NAME)));
        });
    }

    /**
     * Runs an UPDATE restricted by {@code cond}. It asserts the function-call count in the write
     * plan and that exactly the targeted rows were set to {@code salary = -1}.
     *
     * @param mode row-level operation mode written to the update-mode table property
     * @param cond SQL condition containing a {@code system.*} function call
     */
    private void checkUpdate(RowLevelOperationMode mode, String cond) {
        withUnavailableLocations(findIrrelevantFileLocations(cond), () -> {
            setRowLevelWriteMode("write.update.mode", "write.update.distribution-mode", mode);

            // Up to two matching ids become the UPDATE targets.
            spark.table(tableName).where(cond).limit(2).select("id")
                    .coalesce(1)
                    .writeTo(tableName(CHANGES_TABLE_NAME))
                    .create();

            List<Expression> calls = executeAndCollectFunctionCalls(
                    "UPDATE %s t SET t.salary = -1 WHERE %s AND t.id IN (SELECT id FROM %s)",
                    tableName, cond, tableName(CHANGES_TABLE_NAME));
            // Copy-on-write plans are expected to keep two function calls; merge-on-read none.
            int expectedCallCount = mode == RowLevelOperationMode.COPY_ON_WRITE ? 2 : 0;
            Assertions.assertThat(calls).hasSize(expectedCallCount);

            assertEquals(
                    "Should have correct updates",
                    sql("SELECT id FROM %s", tableName(CHANGES_TABLE_NAME)),
                    sql("SELECT id FROM %s WHERE %s AND salary = -1", tableName, cond));
        });
    }

    /**
     * Runs a MERGE whose ON clause includes {@code cond}. It asserts that no function calls remain
     * in the write plan and that the matched rows were set to {@code salary = -1}.
     *
     * @param mode row-level operation mode written to the merge-mode table property
     * @param cond SQL condition containing a {@code system.*} function call
     */
    private void checkMerge(RowLevelOperationMode mode, String cond) {
        withUnavailableLocations(findIrrelevantFileLocations(cond), () -> {
            setRowLevelWriteMode("write.merge.mode", "write.merge.distribution-mode", mode);

            // Source ids are shifted by one, so some rows hit WHEN MATCHED and others WHEN NOT MATCHED.
            spark.table(tableName).where(cond).limit(2).selectExpr("id + 1 as id")
                    .coalesce(1)
                    .writeTo(tableName(CHANGES_TABLE_NAME))
                    .create();

            List<Expression> calls = executeAndCollectFunctionCalls(
                    "MERGE INTO %s t USING %s s "
                            + "ON t.id == s.id AND %s "
                            + "WHEN MATCHED THEN "
                            + "  UPDATE SET salary = -1 "
                            + "WHEN NOT MATCHED AND s.id = 2 THEN "
                            + "  INSERT (id, salary, dep, ts) VALUES (100, -1, 'hr', null)",
                    tableName, tableName(CHANGES_TABLE_NAME), cond);
            Assertions.assertThat(calls).isEmpty();

            assertEquals(
                    "Should have correct updates",
                    sql("SELECT id FROM %s", tableName(CHANGES_TABLE_NAME)),
                    sql("SELECT id FROM %s WHERE %s AND salary = -1", tableName, cond));
        });
    }

    /**
     * Sets the write mode of one row-level command type and sets its distribution mode to
     * {@link DistributionMode#NONE} on the table under test.
     *
     * @param modeProperty table property holding the row-level operation mode
     * @param distributionModeProperty table property holding the distribution mode
     * @param mode row-level operation mode to apply
     */
    private void setRowLevelWriteMode(
            String modeProperty, String distributionModeProperty, RowLevelOperationMode mode) {
        sql("ALTER TABLE %s SET TBLPROPERTIES ('%s' '%s', '%s' '%s')",
                tableName,
                modeProperty, mode.modeName(),
                distributionModeProperty, DistributionMode.NONE.modeName());
    }

    /**
     * Executes a row-level command. It collects the function-call expressions found in the query
     * feeding the physical write node.
     *
     * @param query SQL format string
     * @param args format arguments for {@code query}
     * @return every {@link StaticInvoke} or {@link ApplyFunctionExpression} in the write's query plan
     */
    private List<Expression> executeAndCollectFunctionCalls(String query, Object... args) {
        // The executed plan wraps the write node; both casts are required to reach V2TableWriteExec.
        CommandResultExec command = (CommandResultExec) executeAndKeepPlan(query, args);
        V2TableWriteExec write = (V2TableWriteExec) command.commandPhysicalPlan();

        Predicate<Expression> isFunctionCall =
                expr -> expr instanceof StaticInvoke || expr instanceof ApplyFunctionExpression;
        return SparkPlanUtil.collectExprs(write.query(), isFunctionCall);
    }

    /**
     * Lists the distinct data-file paths that contain rows NOT matching {@code cond}.
     *
     * @param cond SQL condition to negate
     * @return distinct values of the file-path metadata column for rows where {@code NOT cond}
     */
    private List<String> findIrrelevantFileLocations(String cond) {
        return spark.table(tableName)
                .where("NOT " + cond)
                .select(MetadataColumns.FILE_PATH.name())
                .distinct()
                .as(Encoders.STRING())
                .collectAsList();
    }

    /**
     * Creates the table under test, partitioned by {@code transform} with fanout writes enabled,
     * and loads six rows: three {@code hr} rows in 1975 and three {@code it} rows in 2020.
     *
     * @param transform partition transform expression, e.g. {@code bucket(4, dep)}
     */
    private void initTable(String transform) {
        sql("CREATE TABLE %s (id BIGINT, salary INT, dep STRING, ts TIMESTAMP)"
                        + "USING iceberg PARTITIONED BY (%s) TBLPROPERTIES ('%s' 'true')",
                tableName, transform, "write.spark.fanout.enabled");

        append(tableName,
                "{ \"id\": 1, \"salary\": 100, \"dep\": \"hr\", \"ts\": \"1975-01-01 06:00:00\" }",
                "{ \"id\": 2, \"salary\": 200, \"dep\": \"hr\", \"ts\": \"1975-01-01 06:00:00\" }",
                "{ \"id\": 3, \"salary\": 300, \"dep\": \"hr\", \"ts\": \"1975-01-01 06:00:00\" }",
                "{ \"id\": 4, \"salary\": 400, \"dep\": \"it\", \"ts\": \"2020-01-01 10:00:00\" }",
                "{ \"id\": 5, \"salary\": 500, \"dep\": \"it\", \"ts\": \"2020-01-01 10:00:00\" }",
                "{ \"id\": 6, \"salary\": 600, \"dep\": \"it\", \"ts\": \"2020-01-01 10:00:00\" }");
    }
}
