package org.apache.iceberg.spark.source;

import java.io.File;
import java.io.IOException;
import java.lang.invoke.SerializedLambda;
import java.net.URI;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Random;
import org.apache.avro.generic.GenericData;
import org.apache.hadoop.conf.Configuration;
import org.apache.iceberg.Files;
import org.apache.iceberg.PartitionSpec;
import org.apache.iceberg.Schema;
import org.apache.iceberg.Snapshot;
import org.apache.iceberg.Table;
import org.apache.iceberg.avro.Avro;
import org.apache.iceberg.avro.AvroIterable;
import org.apache.iceberg.hadoop.HadoopTables;
import org.apache.iceberg.io.CloseableIterator;
import org.apache.iceberg.io.FileAppender;
import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap;
import org.apache.iceberg.relocated.com.google.common.collect.Lists;
import org.apache.iceberg.spark.SparkSchemaUtil;
import org.apache.iceberg.spark.data.AvroDataTest;
import org.apache.iceberg.spark.data.RandomData;
import org.apache.iceberg.spark.data.SparkAvroReader;
import org.apache.iceberg.spark.data.TestHelpers;
import org.apache.iceberg.types.Types;
import org.apache.spark.SparkException;
import org.apache.spark.TaskContext;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SaveMode;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.catalyst.InternalRow;
import org.apache.spark.sql.catalyst.encoders.RowEncoder;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.Metadata;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;
import org.junit.AfterClass;
import org.junit.Assert;
import org.junit.Assume;
import org.junit.BeforeClass;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;

/**
 * Tests that write Spark DataFrames into Iceberg tables created through {@link HadoopTables} and
 * verify what is read back. The suite is run once per file format returned by {@link #parameters()}.
 *
 * <p>{@link #writeAndValidate(Schema)} overrides the {@link AvroDataTest} hook so that each schema
 * supplied by the base class is written through the "iceberg" Spark source and read back.
 */
@RunWith(Parameterized.class)
public class TestDataFrameWrites extends AvroDataTest {
    // File format stored in the "write.format.default" table property by the round-trip tests.
    private final String format;
    // Properties handed to HadoopTables#create by the nullability tests.
    private Map<String, String> tableProperties;

    // Spark and Iceberg descriptions of the same two-column schema used by the nullability tests.
    private StructType sparkSchema = new StructType(new StructField[] {
        new StructField("optionalField", DataTypes.StringType, true, Metadata.empty()),
        new StructField("requiredField", DataTypes.StringType, false, Metadata.empty())
    });
    private Schema icebergSchema = new Schema(new Types.NestedField[] {
        Types.NestedField.optional(1, "optionalField", Types.StringType.get()),
        Types.NestedField.required(2, "requiredField", Types.StringType.get())
    });

    // Rows appended directly to the Iceberg table (2 rows).
    private List<String> data0 = Arrays.asList(
        "{\"optionalField\": \"a1\", \"requiredField\": \"bid_001\"}",
        "{\"optionalField\": \"a2\", \"requiredField\": \"bid_002\"}");
    // Rows first written to plain parquet, then appended to the Iceberg table (4 rows).
    private List<String> data1 = Arrays.asList(
        "{\"optionalField\": \"d1\", \"requiredField\": \"bid_101\"}",
        "{\"optionalField\": \"d2\", \"requiredField\": \"bid_102\"}",
        "{\"optionalField\": \"d3\", \"requiredField\": \"bid_103\"}",
        "{\"optionalField\": \"d4\", \"requiredField\": \"bid_104\"}");

    private static final Configuration CONF = new Configuration();
    private static SparkSession spark = null;
    private static JavaSparkContext sc = null;

    /** Returns the file formats the whole suite is parameterized over. */
    @Parameterized.Parameters(name = "format = {0}")
    public static Object[] parameters() {
        return new Object[] {"parquet", "avro", "orc"};
    }

    /**
     * @param format file format to use as the table's default write format
     */
    public TestDataFrameWrites(String format) {
        this.format = format;
    }

    /** Starts a local Spark session shared by every test in the class. */
    @BeforeClass
    public static void startSpark() {
        spark = SparkSession.builder().master("local[2]").getOrCreate();
        sc = JavaSparkContext.fromSparkContext(spark.sparkContext());
    }

    /** Stops the shared Spark session and clears the static references to it. */
    @AfterClass
    public static void stopSpark() {
        SparkSession currentSpark = spark;
        spark = null;
        sc = null;
        // Guard against startSpark() having failed, so its error is not masked by an NPE here.
        if (currentSpark != null) {
            currentSpark.stop();
        }
    }

    @Override
    protected void writeAndValidate(Schema schema) throws IOException {
        File location = createTableFolder();
        Table table = createTable(schema, location);
        writeAndValidateWithLocations(table, location, new File(location, "data"));
    }

    /** Data files must land under the directory configured by the "write.data.path" property. */
    @Test
    public void testWriteWithCustomDataLocation() throws IOException {
        File location = createTableFolder();
        File tablePropertyDataLocation = temp.newFolder("test-table-property-data-dir");
        Table table = createTable(new Schema(SUPPORTED_PRIMITIVES.fields()), location);
        table.updateProperties()
            .set("write.data.path", tablePropertyDataLocation.getAbsolutePath())
            .commit();
        writeAndValidateWithLocations(table, location, tablePropertyDataLocation);
    }

    /** Creates a fresh, empty directory to hold a table and returns it. */
    private File createTableFolder() throws IOException {
        File parent = temp.newFolder("parquet");
        File location = new File(parent, "test");
        Assert.assertTrue("Mkdir should succeed", location.mkdirs());
        return location;
    }

    /** Creates an unpartitioned Hadoop table with the given schema at {@code location}. */
    private Table createTable(Schema schema, File location) {
        HadoopTables tables = new HadoopTables(CONF);
        return tables.create(schema, PartitionSpec.unpartitioned(), location.toString());
    }

    /**
     * Writes 100 random records in the current format, reads them back, and checks that the values
     * match and that every added data file is under {@code expectedDataDir}.
     *
     * @param table table to write to
     * @param location table location used for the Spark write and read
     * @param expectedDataDir directory every newly added data file must be located under
     */
    private void writeAndValidateWithLocations(Table table, File location, File expectedDataDir)
            throws IOException {
        Schema tableSchema = table.schema();

        table.updateProperties().set("write.format.default", format).commit();

        Iterable<GenericData.Record> expected = RandomData.generate(tableSchema, 100, 0L);
        writeData(expected, tableSchema, location.toString());

        table.refresh();

        List<Row> actual = readTable(location.toString());

        Iterator<GenericData.Record> expectedIter = expected.iterator();
        Iterator<Row> actualIter = actual.iterator();
        while (expectedIter.hasNext() && actualIter.hasNext()) {
            TestHelpers.assertEqualsSafe(tableSchema.asStruct(), expectedIter.next(), actualIter.next());
        }
        Assert.assertEquals("Both iterators should be exhausted", expectedIter.hasNext(), actualIter.hasNext());

        String expectedDir = expectedDataDir.getAbsolutePath();
        table.currentSnapshot().addedFiles().forEach(dataFile ->
            Assert.assertTrue(
                String.format("File should have the parent directory %s, but has: %s.", expectedDir, dataFile.path()),
                URI.create(dataFile.path().toString()).getPath().startsWith(expectedDir)));
    }

    /** Reads every row of the Iceberg table at {@code location} through the Spark "iceberg" source. */
    private List<Row> readTable(String location) {
        Dataset<Row> result = spark.read().format("iceberg").load(location);
        return result.collectAsList();
    }

    /** Appends {@code records} to the Iceberg table at {@code location}. */
    private void writeData(Iterable<GenericData.Record> records, Schema schema, String location)
            throws IOException {
        Dataset<Row> df = createDataset(records, schema);
        df.write().format("iceberg").mode("append").save(location);
    }

    /**
     * Appends {@code records} to the table at {@code location} using a job in which one randomly
     * chosen partition (out of 10) deliberately throws a {@link SparkException}.
     */
    private void writeDataWithFailOnPartition(Iterable<GenericData.Record> records, Schema schema, String location)
            throws IOException, SparkException {
        final int numPartitions = 10;
        final int partitionToFail = new Random().nextInt(numPartitions);
        Dataset<Row> df = createDataset(records, schema)
            .repartition(numPartitions)
            .mapPartitions(input -> {
                int partitionId = TaskContext.getPartitionId();
                if (partitionId == partitionToFail) {
                    throw new SparkException(String.format("Intended exception in partition %d !", partitionId));
                }
                return input;
            }, RowEncoder.apply(SparkSchemaUtil.convert(schema)));
        // Rebuild the DataFrame with the schema converted from the Iceberg schema. This presumably
        // restores column nullability that RowEncoder may have changed; verify before removing.
        Dataset<Row> convertedDf = df.sqlContext().createDataFrame(df.rdd(), SparkSchemaUtil.convert(schema));
        convertedDf.write().format("iceberg").mode("append").save(location);
    }

    /**
     * Builds a DataFrame holding {@code records}. The records are written to a temporary Avro file
     * and read back with {@link SparkAvroReader}. Along the way the method asserts that the rows
     * read back match the input.
     */
    private Dataset<Row> createDataset(Iterable<GenericData.Record> records, Schema schema) throws IOException {
        File testFile = temp.newFile();
        Assert.assertTrue("Delete should succeed", testFile.delete());

        try (FileAppender<GenericData.Record> writer = Avro.write(Files.localOutput(testFile))
                .schema(schema)
                .named("test")
                .build()) {
            for (GenericData.Record rec : records) {
                writer.add(rec);
            }
        }

        // Make sure the rows read back match the records before building the DataFrame.
        List<InternalRow> rows = Lists.newArrayList();
        try (AvroIterable<InternalRow> reader = Avro.read(Files.localInput(testFile))
                .createReaderFunc(SparkAvroReader::new)
                .project(schema)
                .build()) {
            Iterator<GenericData.Record> recordIter = records.iterator();
            Iterator<InternalRow> readIter = reader.iterator();
            while (recordIter.hasNext() && readIter.hasNext()) {
                InternalRow row = readIter.next();
                TestHelpers.assertEqualsUnsafe(schema.asStruct(), recordIter.next(), row);
                rows.add(row);
            }
            Assert.assertEquals("Both iterators should be exhausted", recordIter.hasNext(), readIter.hasNext());
        }

        JavaRDD<InternalRow> rdd = sc.parallelize(rows);
        return spark.internalCreateDataFrame(JavaRDD.toRDD(rdd), SparkSchemaUtil.convert(schema), false);
    }

    /**
     * Appends parquet data (read with the Iceberg-derived schema) to a table with a required column.
     * The nullability check is disabled through the per-write "check-nullability" option.
     */
    @Test
    public void testNullableWithWriteOption() throws IOException {
        Assume.assumeTrue("Spark 3.0 rejects writing nulls to a required column", spark.version().startsWith("2"));

        File location = new File(temp.newFolder("parquet"), "test");
        String sourcePath = String.format("%s/nullable_poc/sourceFolder/", location.toString());
        String targetPath = String.format("%s/nullable_poc/targetFolder/", location.toString());

        tableProperties = ImmutableMap.of("write.data.path", targetPath);

        // Source data that will later be appended to the Iceberg table.
        spark.read().schema(sparkSchema).json(sc.parallelize(data1)).write().parquet(sourcePath);

        // The Iceberg table the data is appended to.
        new HadoopTables(spark.sessionState().newHadoopConf()).create(
            icebergSchema,
            PartitionSpec.builderFor(icebergSchema).identity("requiredField").build(),
            tableProperties,
            targetPath);

        // Initial contents of the Iceberg table.
        spark.read().schema(sparkSchema).json(sc.parallelize(data0))
            .write().format("iceberg").mode(SaveMode.Append).save(targetPath);

        // Read the parquet source and append it with the nullability check disabled.
        spark.read().schema(SparkSchemaUtil.convert(icebergSchema)).parquet(sourcePath)
            .write().format("iceberg").option("check-nullability", false).mode(SaveMode.Append).save(targetPath);

        List<Row> rows = spark.read().format("iceberg").load(targetPath).collectAsList();
        Assert.assertEquals("Should contain 6 rows", 6L, rows.size());
    }

    /**
     * Same scenario as {@link #testNullableWithWriteOption()}, except the nullability check is
     * disabled through the "spark.sql.iceberg.check-nullability" session configuration.
     */
    @Test
    public void testNullableWithSparkSqlOption() throws IOException {
        Assume.assumeTrue("Spark 3.0 rejects writing nulls to a required column", spark.version().startsWith("2"));

        File location = new File(temp.newFolder("parquet"), "test");
        String sourcePath = String.format("%s/nullable_poc/sourceFolder/", location.toString());
        String targetPath = String.format("%s/nullable_poc/targetFolder/", location.toString());

        tableProperties = ImmutableMap.of("write.data.path", targetPath);

        // Source data that will later be appended to the Iceberg table.
        spark.read().schema(sparkSchema).json(sc.parallelize(data1)).write().parquet(sourcePath);

        SparkSession newSparkSession = SparkSession.builder()
            .master("local[2]")
            .appName("NullableTest")
            .config("spark.sql.iceberg.check-nullability", false)
            .getOrCreate();
        try {
            // The Iceberg table the data is appended to.
            new HadoopTables(newSparkSession.sessionState().newHadoopConf()).create(
                icebergSchema,
                PartitionSpec.builderFor(icebergSchema).identity("requiredField").build(),
                tableProperties,
                targetPath);

            // Initial contents of the Iceberg table.
            newSparkSession.read().schema(sparkSchema).json(sc.parallelize(data0))
                .write().format("iceberg").mode(SaveMode.Append).save(targetPath);

            // Read the parquet source and append it; the nullability check is off via session conf.
            newSparkSession.read().schema(SparkSchemaUtil.convert(icebergSchema)).parquet(sourcePath)
                .write().format("iceberg").mode(SaveMode.Append).save(targetPath);

            List<Row> rows = newSparkSession.read().format("iceberg").load(targetPath).collectAsList();
            Assert.assertEquals("Should contain 6 rows", 6L, rows.size());
        } finally {
            // getOrCreate() may return the shared class-level session with this option applied to it;
            // unset it so the relaxed nullability check does not leak into other tests.
            newSparkSession.conf().unset("spark.sql.iceberg.check-nullability");
        }
    }

    /** A write job that fails in one partition must leave the table's snapshot and contents unchanged. */
    @Test
    public void testFaultToleranceOnWrite() throws IOException {
        File location = createTableFolder();
        Schema schema = new Schema(SUPPORTED_PRIMITIVES.fields());
        Table table = createTable(schema, location);

        Iterable<GenericData.Record> records = RandomData.generate(schema, 100, 0L);
        writeData(records, schema, location.toString());

        table.refresh();

        Snapshot snapshotBeforeFailingWrite = table.currentSnapshot();
        List<Row> resultBeforeFailingWrite = readTable(location.toString());

        try {
            Iterable<GenericData.Record> records2 = RandomData.generate(schema, 100, 0L);
            writeDataWithFailOnPartition(records2, schema, location.toString());
            Assert.fail("The query should fail");
        } catch (SparkException ignored) {
            // Expected: one partition throws on purpose, so the write job is expected to fail.
        }

        table.refresh();

        Snapshot snapshotAfterFailingWrite = table.currentSnapshot();
        List<Row> resultAfterFailingWrite = readTable(location.toString());

        Assert.assertEquals("Failed write should not create a new snapshot",
            snapshotBeforeFailingWrite, snapshotAfterFailingWrite);
        Assert.assertEquals("Failed write should not change table contents",
            resultBeforeFailingWrite, resultAfterFailingWrite);
    }
}
