package org.apache.iceberg.spark.source;

import java.io.File;
import java.io.IOException;
import java.lang.invoke.SerializedLambda;
import java.net.URI;
import java.time.Instant;
import java.util.Arrays;
import java.util.List;
import java.util.Objects;
import java.util.Random;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.TimeUnit;
import java.util.function.Predicate;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.FSDataInputStream;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.fs.RawLocalFileSystem;
import org.apache.iceberg.PartitionSpec;
import org.apache.iceberg.Schema;
import org.apache.iceberg.Table;
import org.apache.iceberg.expressions.Literal;
import org.apache.iceberg.hadoop.HadoopTables;
import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList;
import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap;
import org.apache.iceberg.relocated.com.google.common.collect.Sets;
import org.apache.iceberg.spark.SparkSchemaUtil;
import org.apache.iceberg.transforms.Transform;
import org.apache.iceberg.transforms.Transforms;
import org.apache.iceberg.types.Types;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.catalyst.expressions.GenericInternalRow;
import org.apache.spark.sql.catalyst.util.DateTimeUtils;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.unsafe.types.UTF8String;
import org.junit.AfterClass;
import org.junit.Assert;
import org.junit.BeforeClass;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.TemporaryFolder;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;

@RunWith(Parameterized.class)
/* loaded from: input_file:org/apache/iceberg/spark/source/TestPartitionPruning.class */
public class TestPartitionPruning {
    // Data file format under test; written into the table's "write.format.default" property.
    private final String format;
    // Whether the Iceberg read is issued with "vectorization-enabled" set to true.
    private final boolean vectorized;

    // Supplies a fresh folder for each runTest invocation (see createTempDir).
    @Rule
    public TemporaryFolder temp = new TemporaryFolder();
    // Partitioning of the test table: identity(date), identity(level), bucket(id, 3),
    // truncate(message, 5) and hour(timestamp).
    private PartitionSpec spec = PartitionSpec.builderFor(LOG_SCHEMA).identity("date").identity("level").bucket("id", 3).truncate("message", 5).hour("timestamp").build();
    // Hadoop configuration shared with TABLES; startSpark registers the counting file system in it.
    private static final Configuration CONF = new Configuration();
    private static final HadoopTables TABLES = new HadoopTables(CONF);
    // Assigned in startSpark and cleared in stopSpark.
    private static SparkSession spark = null;
    private static JavaSparkContext sparkContext = null;
    // Transforms backing the "bucket3", "truncate5" and "hour" UDFs registered in startSpark;
    // their parameters (3 buckets, width 5, hourly) mirror the partition spec above.
    private static Transform<Object, Integer> bucketTransform = Transforms.bucket(Types.IntegerType.get(), 3);
    private static Transform<Object, Object> truncateTransform = Transforms.truncate(Types.StringType.get(), 5);
    // NOTE(review): built from TimestampType.withoutZone() although the schema column is withZone() — confirm intended.
    private static Transform<Object, Integer> hourTransform = Transforms.hour(Types.TimestampType.withoutZone());
    // Table schema: five optional columns — id (int), date, level, message (strings), timestamp (with zone).
    private static final Schema LOG_SCHEMA = new Schema(new Types.NestedField[]{Types.NestedField.optional(1, "id", Types.IntegerType.get()), Types.NestedField.optional(2, "date", Types.StringType.get()), Types.NestedField.optional(3, "level", Types.StringType.get()), Types.NestedField.optional(4, "message", Types.StringType.get()), Types.NestedField.optional(5, "timestamp", Types.TimestampType.withZone())});
    // Fixture rows: ten log messages dated 2020-02-02 through 2020-02-04 with debug/info/error/warn levels.
    private static final List<LogMessage> LOGS = ImmutableList.of(LogMessage.debug("2020-02-02", "debug event 1", getInstant("2020-02-02T00:00:00")), LogMessage.info("2020-02-02", "info event 1", getInstant("2020-02-02T01:00:00")), LogMessage.debug("2020-02-02", "debug event 2", getInstant("2020-02-02T02:00:00")), LogMessage.info("2020-02-03", "info event 2", getInstant("2020-02-03T00:00:00")), LogMessage.debug("2020-02-03", "debug event 3", getInstant("2020-02-03T01:00:00")), LogMessage.info("2020-02-03", "info event 3", getInstant("2020-02-03T02:00:00")), LogMessage.error("2020-02-03", "error event 1", getInstant("2020-02-03T03:00:00")), LogMessage.debug("2020-02-04", "debug event 4", getInstant("2020-02-04T01:00:00")), LogMessage.warn("2020-02-04", "warn event 1", getInstant("2020-02-04T02:00:00")), LogMessage.debug("2020-02-04", "debug event 5", getInstant("2020-02-04T03:00:00")));

    /* loaded from: input_file:org/apache/iceberg/spark/source/TestPartitionPruning$CountOpenLocalFileSystem.class */
    public static class CountOpenLocalFileSystem extends RawLocalFileSystem {
        public static String scheme = String.format("TestIdentityPartitionData%dfs", Integer.valueOf(new Random().nextInt()));
        public static ConcurrentHashMap<String, Long> pathToNumOpenCalled = new ConcurrentHashMap<>();

        public static String convertPath(String str) {
            return scheme + "://" + str;
        }

        public static String convertPath(File file) {
            return convertPath(file.getAbsolutePath());
        }

        public static String stripScheme(String str) {
            if (!str.startsWith(scheme + ":")) {
                throw new IllegalArgumentException("Received unexpected path: " + str);
            }
            int length = scheme.length() + 1;
            while (str.charAt(length) == '/') {
                length++;
            }
            return str.substring(length - 1);
        }

        public static void resetRecordsInPathPrefix(String str) {
            pathToNumOpenCalled.keySet().stream().filter(str2 -> {
                return str2.startsWith(str);
            }).forEach(str3 -> {
                pathToNumOpenCalled.remove(str3);
            });
        }

        public URI getUri() {
            return URI.create(scheme + ":///");
        }

        public String getScheme() {
            return scheme;
        }

        public FSDataInputStream open(Path path, int i) throws IOException {
            pathToNumOpenCalled.compute(path.toUri().getPath(), (str, l) -> {
                if (l == null) {
                    return 1L;
                }
                return Long.valueOf(l.longValue() + 1);
            });
            return super.open(path, i);
        }
    }

    /**
     * Parameter sets for the runner: each row is {@code {format, vectorized}} and is passed to
     * the constructor in that order.
     *
     * @return one row per format / vectorization combination under test
     */
    @Parameterized.Parameters(name = "format = {0}, vectorized = {1}")
    public static Object[][] parameters() {
        // Must be an Object[][]: a plain Object[] is not assignable to the declared return type.
        return new Object[][] {
            {"parquet", false},
            {"parquet", true},
            {"avro", false},
            {"orc", false},
            {"orc", true}
        };
    }

    /**
     * Creates one parameterized instance of the test.
     *
     * @param str data file format written into the table's default write format
     * @param z whether the Iceberg read runs with vectorization enabled
     */
    public TestPartitionPruning(String str, boolean z) {
        this.vectorized = z;
        this.format = str;
    }

    /**
     * Starts the shared local Spark session, registers {@link CountOpenLocalFileSystem} under its
     * scheme in both the Hadoop and the Spark configuration, pins the session time zone to UTC
     * and registers the UDFs used to pre-compute partition values.
     */
    @BeforeClass
    public static void startSpark() {
        spark = SparkSession.builder().master("local[2]").getOrCreate();
        sparkContext = JavaSparkContext.fromSparkContext(spark.sparkContext());

        String fsImplKey = String.format("fs.%s.impl", CountOpenLocalFileSystem.scheme);
        String fsImplClass = CountOpenLocalFileSystem.class.getName();
        CONF.set(fsImplKey, fsImplClass);
        spark.conf().set(fsImplKey, fsImplClass);
        spark.conf().set("spark.sql.session.timeZone", "UTC");

        // The lambda parameters are typed explicitly: the register overload takes UDF1<?, ?>,
        // so an implicitly typed parameter would be Object and the Timestamp conversion below
        // would not compile.
        spark.udf().register("bucket3", (Integer num) -> bucketTransform.apply(num), DataTypes.IntegerType);
        spark.udf().register("truncate5", (String str) -> truncateTransform.apply(str), DataTypes.StringType);
        // The hour transform is applied to a long, so the Timestamp is converted first.
        spark.udf().register(
                "hour",
                (java.sql.Timestamp timestamp) -> hourTransform.apply(DateTimeUtils.fromJavaTimestamp(timestamp)),
                DataTypes.IntegerType);
    }

    /**
     * Stops the shared Spark session and clears the static references to it. Safe to call when
     * {@link #startSpark()} failed before the session was assigned.
     */
    @AfterClass
    public static void stopSpark() {
        SparkSession currentSpark = spark;
        spark = null;
        // The context belongs to the session being stopped; do not keep a stale reference.
        sparkContext = null;
        if (currentSpark != null) {
            currentSpark.stop();
        }
    }

    /**
     * Converts a timestamp string into an {@link Instant} by way of an Iceberg
     * timestamp-without-zone literal, whose value is read as microseconds and reduced to
     * millisecond precision.
     *
     * @param str timestamp text such as {@code 2020-02-02T00:00:00}
     * @return the corresponding instant, truncated to milliseconds
     */
    private static Instant getInstant(String str) {
        Long micros = (Long) Literal.of(str).to(Types.TimestampType.withoutZone()).value();
        long millis = TimeUnit.MICROSECONDS.toMillis(micros.longValue());
        return Instant.ofEpochMilli(millis);
    }

    /** Filters on the two identity-partitioned string columns (date and level). */
    @Test
    public void testPartitionPruningIdentityString() {
        String filterCond = "date >= '2020-02-03' AND level = 'DEBUG'";
        Predicate<Row> partCondition = partition -> {
            // Partition struct field 0 is compared as the date, field 1 as the level.
            if (partition.getString(0).compareTo("2020-02-03") < 0) {
                return false;
            }
            return partition.getString(1).equals("DEBUG");
        };
        runTest(filterCond, partCondition);
    }

    /**
     * Filters on two ids taken from the fixture data and expects only data files whose bucket
     * partition value (partition struct field 2) equals the bucket of one of those ids to be read.
     */
    @Test
    public void testPartitionPruningBucketingInteger() {
        int[] ids = {LOGS.get(3).getId(), LOGS.get(7).getId()};
        String idList = Arrays.stream(ids).mapToObj(String::valueOf).collect(Collectors.joining(",", "(", ")"));

        // The expected buckets depend only on the ids, so compute them once instead of on every
        // predicate call.
        Set<Integer> expectedBuckets =
                Arrays.stream(ids).map(bucketTransform::apply).boxed().collect(Collectors.toSet());

        runTest("id in " + idList, row -> expectedBuckets.contains(row.getInt(2)));
    }

    /**
     * Filters with a LIKE prefix longer than the truncation width; matching partitions carry the
     * truncated value {@code "info "} in partition struct field 3.
     */
    @Test
    public void testPartitionPruningTruncatedString() {
        String filterCond = "message like 'info event%'";
        Predicate<Row> partCondition = partition -> partition.getString(3).equals("info ");
        runTest(filterCond, partCondition);
    }

    /**
     * Filters with a LIKE prefix shorter than the truncation width; matching partitions are those
     * whose truncated value (partition struct field 3) starts with that prefix.
     */
    @Test
    public void testPartitionPruningTruncatedStringComparingValueShorterThanPartitionValue() {
        String filterCond = "message like 'inf%'";
        Predicate<Row> partCondition = partition -> partition.getString(3).startsWith("inf");
        runTest(filterCond, partCondition);
    }

    /**
     * Filters on the timestamp column; matching partitions are those whose hour value (partition
     * struct field 4) is at or after the hour of the filter's lower bound. The filter text differs
     * depending on whether the Spark version string starts with "2".
     */
    @Test
    public void testPartitionPruningHourlyPartition() {
        String filterCond;
        if (spark.version().startsWith("2")) {
            filterCond = "timestamp >= to_timestamp('2020-02-03T01:00:00')";
        } else {
            filterCond = "timestamp >= '2020-02-03T01:00:00'";
        }

        Predicate<Row> partCondition = partition -> {
            int hourValue = partition.getInt(4);
            Instant lowerBound = getInstant("2020-02-03T01:00:00");
            long lowerBoundMicros = TimeUnit.MILLISECONDS.toMicros(lowerBound.toEpochMilli());
            Integer hourValueToFilter = hourTransform.apply(Long.valueOf(lowerBoundMicros));
            return hourValue >= hourValueToFilter.intValue();
        };

        runTest(filterCond, partCondition);
    }

    /**
     * Writes the fixture data into a fresh table, runs the same filter against the in-memory
     * dataset and against the Iceberg table, checks both return the same non-empty rows, and then
     * verifies which data files were opened during the table read.
     *
     * @param str SQL filter expression applied to both datasets
     * @param predicate condition on a data file's partition struct that tells whether the file
     *     is inside the filtered partition range
     */
    private void runTest(String str, Predicate<Row> predicate) {
        File tableDir = createTempDir();
        Assert.assertTrue("Temp folder should exist", tableDir.exists());

        Table table = createTable(tableDir);
        Dataset<Row> source = createTestDataset();
        saveTestDatasetToTable(source, table);

        List<Row> expected = source
                .select("id", "date", "level", "message", "timestamp")
                .filter(str)
                .orderBy("id")
                .collectAsList();
        Assert.assertFalse("Expected rows should be not empty", expected.isEmpty());

        // Drop the opens recorded while writing so that only the read below is counted.
        CountOpenLocalFileSystem.resetRecordsInPathPrefix(tableDir.getAbsolutePath());

        Dataset<Row> tableData = spark.read()
                .format("iceberg")
                .option("vectorization-enabled", String.valueOf(this.vectorized))
                .load(table.location());
        List<Row> actual = tableData
                .select("id", "date", "level", "message", "timestamp")
                .filter(str)
                .orderBy("id")
                .collectAsList();
        Assert.assertFalse("Actual rows should not be empty", actual.isEmpty());

        Assert.assertEquals("Rows should match", expected, actual);
        assertAccessOnDataFiles(tableDir, table, predicate);
    }

    /**
     * Creates a new folder under the {@code temp} rule.
     *
     * @return the newly created folder
     * @throws RuntimeException wrapping any exception raised while creating the folder
     */
    private File createTempDir() {
        File folder;
        try {
            folder = this.temp.newFolder();
        } catch (Exception e) {
            throw new RuntimeException(e);
        }
        return folder;
    }

    /**
     * Creates the partitioned log table in {@code file}, addressed through the counting file
     * system's scheme and using this instance's format as the default write format.
     *
     * @param file directory that will hold the table
     * @return the created table
     */
    private Table createTable(File file) {
        String location = CountOpenLocalFileSystem.convertPath(file);
        ImmutableMap<String, String> properties = ImmutableMap.of("write.format.default", this.format);
        return TABLES.create(LOG_SCHEMA, this.spec, properties, location);
    }

    /**
     * Builds a DataFrame from the fixture log messages and appends the derived columns
     * {@code bucket_id}, {@code truncated_message} and {@code ts_hour}, computed with the
     * registered UDFs.
     *
     * @return the fixture dataset including the derived partition-value columns
     */
    private Dataset<Row> createTestDataset() {
        List<org.apache.spark.sql.catalyst.InternalRow> rows = LOGS.stream()
                .map(logMessage -> new GenericInternalRow(new Object[] {
                    logMessage.getId(),
                    UTF8String.fromString(logMessage.getDate()),
                    UTF8String.fromString(logMessage.getLevel()),
                    UTF8String.fromString(logMessage.getMessage()),
                    // The timestamp is stored as microseconds derived from epoch milliseconds.
                    TimeUnit.MILLISECONDS.toMicros(logMessage.getTimestamp().toEpochMilli())
                }))
                .collect(Collectors.toList());

        JavaRDD<org.apache.spark.sql.catalyst.InternalRow> rdd = sparkContext.parallelize(rows);
        Dataset<Row> df = spark.internalCreateDataFrame(JavaRDD.toRDD(rdd), SparkSchemaUtil.convert(LOG_SCHEMA), false);

        return df
                .selectExpr("id", "date", "level", "message", "timestamp")
                .selectExpr(
                        "id",
                        "date",
                        "level",
                        "message",
                        "timestamp",
                        "bucket3(id) AS bucket_id",
                        "truncate5(message) AS truncated_message",
                        "hour(timestamp) AS ts_hour");
    }

    /**
     * Appends the dataset to the table after ordering it by the partition-value columns; only the
     * five schema columns are written.
     *
     * @param dataset fixture dataset including the derived partition-value columns
     * @param table destination table
     */
    private void saveTestDatasetToTable(Dataset<Row> dataset, Table table) {
        Dataset<Row> ordered = dataset.orderBy("date", "level", "bucket_id", "truncated_message", "ts_hour");
        Dataset<Row> tableColumns = ordered.select("id", "date", "level", "message", "timestamp");
        tableColumns.write()
                .format("iceberg")
                .mode("append")
                .save(table.location());
    }

    /**
     * Checks the opens recorded under {@code file} against the table's {@code #files} listing:
     * some files must fall outside the partition range, at least one file inside the range must
     * have been opened, and no file outside the range may have been opened.
     *
     * @param file directory holding the table
     * @param table table whose data files are listed
     * @param predicate condition on a data file's partition struct marking it as in range
     */
    private void assertAccessOnDataFiles(File file, Table table, Predicate<Row> predicate) {
        String tableDirPath = file.getAbsolutePath();
        Set<String> readFilesInQuery = CountOpenLocalFileSystem.pathToNumOpenCalled.keySet().stream()
                .filter(path -> path.startsWith(tableDirPath))
                .collect(Collectors.toSet());

        List<Row> files = spark.read().format("iceberg").load(table.location() + "#files").collectAsList();

        Set<String> filesToRead = extractFilePathsMatchingConditionOnPartition(files, predicate);
        Set<String> filesToNotRead = extractFilePathsNotIn(files, filesToRead);

        // Sanity check: the two groups must be disjoint.
        Assert.assertTrue(Sets.intersection(filesToRead, filesToNotRead).isEmpty());

        Assert.assertFalse("The query should prune some data files.", filesToNotRead.isEmpty());

        Assert.assertFalse(
                "Some of data files in partition range should be read. Read files in query: " + readFilesInQuery
                        + " / data files in partition range: " + filesToRead,
                Sets.intersection(filesToRead, readFilesInQuery).isEmpty());

        Assert.assertTrue(
                "Data files outside of partition range should not be read. Read files in query: " + readFilesInQuery
                        + " / data files outside of partition range: " + filesToNotRead,
                Sets.intersection(filesToNotRead, readFilesInQuery).isEmpty());
    }

    /**
     * Collects the scheme-less paths of the listed data files whose partition struct satisfies
     * {@code predicate}.
     *
     * @param list rows of the table's files listing; column 4 is read as the partition struct and
     *     column 1 as the file path
     * @param predicate condition evaluated on each file's partition struct
     * @return paths of the matching files, with the counting file system's scheme stripped
     */
    private Set<String> extractFilePathsMatchingConditionOnPartition(List<Row> list, Predicate<Row> predicate) {
        Set<String> matchingPaths = Sets.newHashSet();
        for (Row fileRow : list) {
            if (predicate.test(fileRow.getStruct(4))) {
                matchingPaths.add(CountOpenLocalFileSystem.stripScheme(fileRow.getString(1)));
            }
        }
        return matchingPaths;
    }

    /**
     * Returns the scheme-less paths of the listed data files that are not contained in
     * {@code set}.
     *
     * @param list rows of the table's files listing; column 1 is read as the file path
     * @param set paths to exclude
     * @return a new set holding every listed file path that is absent from {@code set}
     */
    private Set<String> extractFilePathsNotIn(List<Row> list, Set<String> set) {
        Set<String> allFilePaths = list.stream()
                .map(row -> CountOpenLocalFileSystem.stripScheme(row.getString(1)))
                .collect(Collectors.toSet());
        // Plain difference, not symmetric difference: entries of `set` that are missing from the
        // listing must never be reported as files of the table.
        return Sets.newHashSet(Sets.difference(allFilePaths, set));
    }

    /*
     * Decompiled form of the lambda-deserialization hook for this class. It selects one of the
     * three UDF lambdas created in startSpark by the implementation method name stored in the
     * SerializedLambda, verifies that the recorded functional interface is UDF1#call and that the
     * implementation class is this test class, and returns the matching lambda. Any other input
     * ends in an IllegalArgumentException.
     *
     * NOTE(review): this method is marked synthetic, i.e. it was emitted by the compiler rather
     * than written by hand. As decompiled it is not valid Java (`boolean z = -1`, `z = 2`,
     * duplicated `case true` labels, lambdas returned as Object). Presumably it should be dropped
     * from source and left to the compiler to regenerate — confirm before recompiling.
     */
    private static /* synthetic */ Object $deserializeLambda$(SerializedLambda serializedLambda) {
        String implMethodName = serializedLambda.getImplMethodName();
        // First switch: map the implementation method name (via its hash code) to a selector.
        boolean z = -1;
        switch (implMethodName.hashCode()) {
            case 526434542:
                if (implMethodName.equals("lambda$startSpark$45ca9450$1")) {
                    z = true;
                    break;
                }
                break;
            case 526434543:
                if (implMethodName.equals("lambda$startSpark$45ca9450$2")) {
                    z = 2;
                    break;
                }
                break;
            case 526434544:
                if (implMethodName.equals("lambda$startSpark$45ca9450$3")) {
                    z = false;
                    break;
                }
                break;
        }
        // Second switch: re-validate the serialized metadata, then rebuild the lambda.
        switch (z) {
            case false:
                // Implementation signature takes a java.sql.Timestamp: the "hour" UDF.
                if (serializedLambda.getImplMethodKind() == 6 && serializedLambda.getFunctionalInterfaceClass().equals("org/apache/spark/sql/api/java/UDF1") && serializedLambda.getFunctionalInterfaceMethodName().equals("call") && serializedLambda.getFunctionalInterfaceMethodSignature().equals("(Ljava/lang/Object;)Ljava/lang/Object;") && serializedLambda.getImplClass().equals("org/apache/iceberg/spark/source/TestPartitionPruning") && serializedLambda.getImplMethodSignature().equals("(Ljava/sql/Timestamp;)Ljava/lang/Object;")) {
                    return timestamp -> {
                        return (Integer) hourTransform.apply(Long.valueOf(DateTimeUtils.fromJavaTimestamp(timestamp)));
                    };
                }
                break;
            case true:
                // Implementation signature takes an Integer: the "bucket3" UDF.
                if (serializedLambda.getImplMethodKind() == 6 && serializedLambda.getFunctionalInterfaceClass().equals("org/apache/spark/sql/api/java/UDF1") && serializedLambda.getFunctionalInterfaceMethodName().equals("call") && serializedLambda.getFunctionalInterfaceMethodSignature().equals("(Ljava/lang/Object;)Ljava/lang/Object;") && serializedLambda.getImplClass().equals("org/apache/iceberg/spark/source/TestPartitionPruning") && serializedLambda.getImplMethodSignature().equals("(Ljava/lang/Integer;)Ljava/lang/Object;")) {
                    return num -> {
                        return (Integer) bucketTransform.apply(num);
                    };
                }
                break;
            case true:
                // Implementation signature takes a String: the "truncate5" UDF.
                if (serializedLambda.getImplMethodKind() == 6 && serializedLambda.getFunctionalInterfaceClass().equals("org/apache/spark/sql/api/java/UDF1") && serializedLambda.getFunctionalInterfaceMethodName().equals("call") && serializedLambda.getFunctionalInterfaceMethodSignature().equals("(Ljava/lang/Object;)Ljava/lang/Object;") && serializedLambda.getImplClass().equals("org/apache/iceberg/spark/source/TestPartitionPruning") && serializedLambda.getImplMethodSignature().equals("(Ljava/lang/String;)Ljava/lang/Object;")) {
                    return str -> {
                        return truncateTransform.apply(str);
                    };
                }
                break;
        }
        throw new IllegalArgumentException("Invalid lambda deserialization");
    }
}
