package org.apache.iceberg.spark.source;

import java.io.File;
import java.util.ArrayList;
import java.util.List;
import java.util.Objects;
import org.apache.hadoop.conf.Configuration;
import org.apache.iceberg.PartitionSpec;
import org.apache.iceberg.Schema;
import org.apache.iceberg.Table;
import org.apache.iceberg.hadoop.HadoopTables;
import org.apache.iceberg.relocated.com.google.common.collect.Iterables;
import org.apache.iceberg.relocated.com.google.common.collect.Lists;
import org.apache.iceberg.types.Types;
import org.apache.spark.sql.Encoder;
import org.apache.spark.sql.Encoders;
import org.apache.spark.sql.SQLContext;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.execution.streaming.MemoryStream;
import org.apache.spark.sql.streaming.DataStreamWriter;
import org.apache.spark.sql.streaming.StreamingQuery;
import org.apache.spark.sql.streaming.StreamingQueryException;
import org.assertj.core.api.Assertions;
import org.junit.AfterClass;
import org.junit.Assert;
import org.junit.BeforeClass;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.TemporaryFolder;
import scala.Option;
import scala.collection.JavaConverters;

/* loaded from: input_file:org/apache/iceberg/spark/source/TestStructuredStreaming.class */
/**
 * Tests for writing to Iceberg tables from Spark Structured Streaming queries.
 *
 * <p>The append and complete mode tests follow the same scenario:
 * <ol>
 *   <li>Feed two micro-batches through a {@link MemoryStream}.</li>
 *   <li>Stop the query.</li>
 *   <li>Delete the checkpoint's commit marker for batch {@code 1}.</li>
 *   <li>Restart the query.</li>
 *   <li>Check both the table rows and the number of table snapshots.</li>
 * </ol>
 *
 * <p>The update mode test checks that the Iceberg sink rejects update mode.
 */
public class TestStructuredStreaming {
    private static final Configuration CONF = new Configuration();
    private static final Schema SCHEMA = new Schema(
        Types.NestedField.optional(1, "id", Types.IntegerType.get()),
        Types.NestedField.optional(2, "data", Types.StringType.get()));

    /** Shared local session, created in {@link #startSpark()} and stopped in {@link #stopSpark()}. */
    private static SparkSession spark = null;

    @Rule
    public TemporaryFolder temp = new TemporaryFolder();

    @BeforeClass
    public static void startSpark() {
        spark = SparkSession.builder()
            .master("local[2]")
            .config("spark.sql.shuffle.partitions", 4L)
            .getOrCreate();
    }

    @AfterClass
    public static void stopSpark() {
        SparkSession currentSpark = spark;
        spark = null;
        // If startSpark() failed there is no session to stop; an NPE here would hide the real failure.
        if (currentSpark != null) {
            currentSpark.stop();
        }
    }

    @Test
    public void testStreamingWriteAppendMode() throws Exception {
        File parent = temp.newFolder("parquet");
        File location = new File(parent, "test-table");
        File checkpoint = new File(parent, "checkpoint");
        Table table = new HadoopTables(CONF).create(
            SCHEMA, PartitionSpec.builderFor(SCHEMA).identity("data").build(), location.toString());

        List<SimpleRecord> expected = Lists.newArrayList(
            new SimpleRecord(1, "1"),
            new SimpleRecord(2, "2"),
            new SimpleRecord(3, "3"),
            new SimpleRecord(4, "4"));

        MemoryStream<Integer> inputStream = newMemoryStream(1, spark.sqlContext(), Encoders.INT());
        DataStreamWriter<?> writer = inputStream.toDF()
            .selectExpr("value AS id", "CAST (value AS STRING) AS data")
            .writeStream()
            .outputMode("append")
            .format("iceberg")
            .option("checkpointLocation", checkpoint.toString())
            .option("path", location.toString());

        try {
            runTwoBatchesThenRestart(
                writer, inputStream, Lists.newArrayList(1, 2), Lists.newArrayList(3, 4), checkpoint);
            assertTableContents(table, location, "id", expected, 2L);
        } finally {
            stopStreams();
        }
    }

    @Test
    public void testStreamingWriteCompleteMode() throws Exception {
        File parent = temp.newFolder("parquet");
        File location = new File(parent, "test-table");
        File checkpoint = new File(parent, "checkpoint");
        Table table = new HadoopTables(CONF).create(
            SCHEMA, PartitionSpec.builderFor(SCHEMA).identity("data").build(), location.toString());

        // Complete mode rewrites the whole aggregate: id holds the count of each input value,
        // and data holds the value itself.
        List<SimpleRecord> expected = Lists.newArrayList(
            new SimpleRecord(2, "1"),
            new SimpleRecord(3, "2"),
            new SimpleRecord(1, "3"));

        MemoryStream<Integer> inputStream = newMemoryStream(1, spark.sqlContext(), Encoders.INT());
        DataStreamWriter<?> writer = inputStream.toDF()
            .groupBy("value")
            .count()
            .selectExpr("CAST(count AS INT) AS id", "CAST (value AS STRING) AS data")
            .writeStream()
            .outputMode("complete")
            .format("iceberg")
            .option("checkpointLocation", checkpoint.toString())
            .option("path", location.toString());

        try {
            runTwoBatchesThenRestart(
                writer, inputStream, Lists.newArrayList(1, 2), Lists.newArrayList(1, 2, 2, 3), checkpoint);
            assertTableContents(table, location, "data", expected, 2L);
        } finally {
            stopStreams();
        }
    }

    @Test
    public void testStreamingWriteCompleteModeWithProjection() throws Exception {
        File parent = temp.newFolder("parquet");
        File location = new File(parent, "test-table");
        File checkpoint = new File(parent, "checkpoint");
        Table table = new HadoopTables(CONF).create(SCHEMA, PartitionSpec.unpartitioned(), location.toString());

        // Only the id column is written, so data is read back as null.
        List<SimpleRecord> expected = Lists.newArrayList(
            new SimpleRecord(1, null),
            new SimpleRecord(2, null),
            new SimpleRecord(3, null));

        MemoryStream<Integer> inputStream = newMemoryStream(1, spark.sqlContext(), Encoders.INT());
        DataStreamWriter<?> writer = inputStream.toDF()
            .groupBy("value")
            .count()
            .selectExpr("CAST(count AS INT) AS id")
            .writeStream()
            .outputMode("complete")
            .format("iceberg")
            .option("checkpointLocation", checkpoint.toString())
            .option("path", location.toString());

        try {
            runTwoBatchesThenRestart(
                writer, inputStream, Lists.newArrayList(1, 2), Lists.newArrayList(1, 2, 2, 3), checkpoint);
            assertTableContents(table, location, "id", expected, 2L);
        } finally {
            stopStreams();
        }
    }

    @Test
    public void testStreamingWriteUpdateMode() throws Exception {
        File parent = temp.newFolder("parquet");
        File location = new File(parent, "test-table");
        File checkpoint = new File(parent, "checkpoint");
        new HadoopTables(CONF).create(
            SCHEMA, PartitionSpec.builderFor(SCHEMA).identity("data").build(), location.toString());

        MemoryStream<Integer> inputStream = newMemoryStream(1, spark.sqlContext(), Encoders.INT());
        try {
            StreamingQuery query = inputStream.toDF()
                .selectExpr("value AS id", "CAST (value AS STRING) AS data")
                .writeStream()
                .outputMode("update")
                .format("iceberg")
                .option("checkpointLocation", checkpoint.toString())
                .option("path", location.toString())
                .start();
            send(Lists.newArrayList(1, 2), inputStream);

            // The query starts, but processing the first batch must fail because the sink rejects update mode.
            Assertions.assertThatThrownBy(query::processAllAvailable)
                .isInstanceOf(StreamingQueryException.class)
                .hasMessageContaining("does not support Update mode");
        } finally {
            stopStreams();
        }
    }

    /**
     * Runs {@code writer} over two micro-batches, stops it, removes batch 1's commit marker, and restarts.
     *
     * <p>After deleting {@code commits/1} from the checkpoint, the writer is started again and
     * processes whatever is available. Queries started here are left running; callers stop them
     * with {@link #stopStreams()}.
     *
     * @param writer configured streaming writer; {@code start()} is called on it twice
     * @param inputStream source that the two batches are pushed into
     * @param firstBatch values for the first micro-batch
     * @param secondBatch values for the second micro-batch
     * @param checkpoint checkpoint directory configured on {@code writer}
     */
    private void runTwoBatchesThenRestart(
            DataStreamWriter<?> writer,
            MemoryStream<Integer> inputStream,
            List<Integer> firstBatch,
            List<Integer> secondBatch,
            File checkpoint) throws Exception {
        StreamingQuery query = writer.start();
        send(firstBatch, inputStream);
        query.processAllAvailable();
        send(secondBatch, inputStream);
        query.processAllAvailable();
        query.stop();

        // Removing the commit marker leaves batch 1 uncommitted in the checkpoint, so the restarted
        // query is expected to run it again. The callers' snapshot-count check then verifies that the
        // re-run did not add an extra commit.
        Assert.assertTrue("The commit file must be deleted", new File(checkpoint, "commits/1").delete());
        writer.start().processAllAvailable();
    }

    /**
     * Reads the Iceberg table and checks its rows and snapshot count.
     *
     * @param table table handle used to count snapshots
     * @param location table location to read through the {@code iceberg} source
     * @param orderBy column the rows are sorted by before comparison
     * @param expected expected rows, in {@code orderBy} order
     * @param expectedSnapshots expected number of table snapshots
     */
    private static void assertTableContents(
            Table table, File location, String orderBy, List<SimpleRecord> expected, long expectedSnapshots) {
        List<SimpleRecord> actual = spark.read()
            .format("iceberg")
            .load(location.toString())
            .orderBy(orderBy)
            .as(Encoders.bean(SimpleRecord.class))
            .collectAsList();
        Assert.assertEquals("Number of rows should match", expected.size(), actual.size());
        Assert.assertEquals("Result rows should match", expected, actual);
        Assert.assertEquals("Number of snapshots should match", expectedSnapshots, Iterables.size(table.snapshots()));
    }

    /**
     * Stops every streaming query still active on the shared session so queries do not leak between tests.
     *
     * <p>Declared {@code throws Exception} because {@code StreamingQuery.stop()} may declare a checked
     * exception, depending on the Spark version.
     */
    private static void stopStreams() throws Exception {
        for (StreamingQuery query : spark.streams().active()) {
            query.stop();
        }
    }

    /** Creates an in-memory streaming source bound to {@code sqlContext}, using the default partitioning. */
    private <T> MemoryStream<T> newMemoryStream(int id, SQLContext sqlContext, Encoder<T> encoder) {
        return new MemoryStream<>(id, sqlContext, Option.empty(), encoder);
    }

    /** Appends {@code records} to {@code stream} as a new batch of available input. */
    private <T> void send(List<T> records, MemoryStream<T> stream) {
        stream.addData(JavaConverters.asScalaBuffer(records));
    }
}
