package org.apache.beam.sdk.io.snowflake.test.unit.write;

import java.io.IOException;
import java.nio.file.DirectoryStream;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import java.util.stream.LongStream;
import net.snowflake.client.jdbc.SnowflakeSQLException;
import org.apache.beam.sdk.coders.StringUtf8Coder;
import org.apache.beam.sdk.io.snowflake.SnowflakeIO;
import org.apache.beam.sdk.io.snowflake.services.SnowflakeService;
import org.apache.beam.sdk.io.snowflake.test.FakeSnowflakeBasicDataSource;
import org.apache.beam.sdk.io.snowflake.test.FakeSnowflakeDatabase;
import org.apache.beam.sdk.io.snowflake.test.FakeSnowflakeStreamingServiceImpl;
import org.apache.beam.sdk.io.snowflake.test.TestSnowflakePipelineOptions;
import org.apache.beam.sdk.io.snowflake.test.TestUtils;
import org.apache.beam.sdk.options.PipelineOptionsFactory;
import org.apache.beam.sdk.testing.TestPipeline;
import org.apache.beam.sdk.testing.TestStream;
import org.apache.beam.sdk.transforms.Create;
import org.apache.beam.sdk.transforms.windowing.FixedWindows;
import org.apache.beam.sdk.transforms.windowing.Window;
import org.apache.beam.sdk.values.TimestampedValue;
import org.hamcrest.CoreMatchers;
import org.hamcrest.MatcherAssert;
import org.joda.time.Duration;
import org.joda.time.Instant;
import org.junit.After;
import org.junit.BeforeClass;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.ExpectedException;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;

/**
 * Unit tests for {@link SnowflakeIO#write()} configured with a Snowpipe (streaming write path).
 *
 * <p>Writes go through {@link FakeSnowflakeStreamingServiceImpl}. Staged {@code *.gz} files are
 * read back from {@link #STAGING_BUCKET_NAME}, which is cleaned up via
 * {@code TestUtils.removeTempDir} after each test.
 */
@RunWith(JUnit4.class)
public class StreamingWriteTest {
    private static final String FAKE_TABLE = "TEST_TABLE";
    private static final String STAGING_BUCKET_NAME = "BUCKET/";
    private static final String STORAGE_INTEGRATION_NAME = "STORAGE_INTEGRATION";
    private static final String SNOW_PIPE = "Snowpipe";

    @Rule
    public final transient TestPipeline pipeline = TestPipeline.create();

    @Rule
    public ExpectedException exceptionRule = ExpectedException.none();

    private static SnowflakeIO.DataSourceConfiguration dataSourceConfiguration;
    private static SnowflakeService snowflakeService;
    private static TestSnowflakePipelineOptions options;
    private static List<Long> testData;

    private static final Instant START_TIME = new Instant(0);
    private static final List<String> SENTENCES =
            Arrays.asList(
                    "Snowflake window 1 1",
                    "Snowflake window 1 2",
                    "Snowflake window 1 3",
                    "Snowflake window 1 4",
                    "Snowflake window 2 1",
                    "Snowflake window 2 2");
    private static final List<String> FIRST_WIN_WORDS = SENTENCES.subList(0, 4);
    private static final List<String> SECOND_WIN_WORDS = SENTENCES.subList(4, 6);
    private static final Duration WINDOW_DURATION = Duration.standardMinutes(1);

    @BeforeClass
    public static void setup() {
        snowflakeService = new FakeSnowflakeStreamingServiceImpl();
        PipelineOptionsFactory.register(TestSnowflakePipelineOptions.class);
        options = TestPipeline.testingPipelineOptions().as(TestSnowflakePipelineOptions.class);
        options.setUsername("username");
        options.setServerName("NULL.snowflakecomputing.com");
        testData = LongStream.range(0L, 100L).boxed().collect(Collectors.toList());
        FakeSnowflakeDatabase.createTable(FAKE_TABLE);
        dataSourceConfiguration =
                withTestConnectionSettings(
                        SnowflakeIO.DataSourceConfiguration.create(new FakeSnowflakeBasicDataSource()));
    }

    @After
    public void tearDown() {
        TestUtils.removeTempDir(STAGING_BUCKET_NAME);
    }

    @Test
    public void streamWriteWithOAuthFails() {
        dataSourceConfiguration =
                withTestConnectionSettings(
                        SnowflakeIO.DataSourceConfiguration.create().withOAuth("token"));
        exceptionRule.expectMessage("KeyPair is required for authentication");
        applySnowpipeWriteOfTestData();
        pipeline.run(options);
    }

    @Test
    public void streamWriteWithUserPasswordFails() {
        dataSourceConfiguration =
                withTestConnectionSettings(
                        SnowflakeIO.DataSourceConfiguration.create()
                                .withUsernamePasswordAuth(options.getUsername(), "password"));
        exceptionRule.expectMessage("KeyPair is required for authentication");
        applySnowpipeWriteOfTestData();
        pipeline.run(options);
    }

    @Test
    public void streamWriteWithKey() throws SnowflakeSQLException {
        TestStream<String> stream = createTestStream();
        dataSourceConfiguration = keyPairDataSourceConfiguration();
        pipeline
                .apply(stream)
                .apply(Window.<String>into(FixedWindows.of(WINDOW_DURATION)))
                .apply(
                        SnowflakeIO.<String>write()
                                .withDataSourceConfiguration(dataSourceConfiguration)
                                .withStagingBucketName(STAGING_BUCKET_NAME)
                                .withStorageIntegrationName(STORAGE_INTEGRATION_NAME)
                                .withSnowPipe(SNOW_PIPE)
                                .withFlushRowLimit(4)
                                .withFlushTimeLimit(WINDOW_DURATION)
                                .withUserDataMapper(TestUtils.getStringCsvMapper())
                                .withSnowflakeService(snowflakeService));
        pipeline.run(options).waitUntilFinish();
        assertStagedFilesAndTable("'");
    }

    @Test
    public void streamWriteWithDoubleQuotation() throws SnowflakeSQLException {
        TestStream<String> stream = createTestStream();
        dataSourceConfiguration = keyPairDataSourceConfiguration();
        pipeline
                .apply(stream)
                .apply(Window.<String>into(FixedWindows.of(WINDOW_DURATION)))
                .apply(
                        SnowflakeIO.<String>write()
                                .withDataSourceConfiguration(dataSourceConfiguration)
                                .withStagingBucketName(STAGING_BUCKET_NAME)
                                .withStorageIntegrationName(STORAGE_INTEGRATION_NAME)
                                .withSnowPipe(SNOW_PIPE)
                                .withFlushRowLimit(4)
                                .withQuotationMark("\"")
                                .withFlushTimeLimit(WINDOW_DURATION)
                                .withUserDataMapper(TestUtils.getStringCsvMapper())
                                .withSnowflakeService(snowflakeService));
        pipeline.run(options).waitUntilFinish();
        assertStagedFilesAndTable("\"");
    }

    /** Applies the server/schema/database/warehouse settings shared by every test configuration. */
    private static SnowflakeIO.DataSourceConfiguration withTestConnectionSettings(
            SnowflakeIO.DataSourceConfiguration configuration) {
        return configuration
                .withServerName(options.getServerName())
                .withSchema("PUBLIC")
                .withDatabase("DATABASE")
                .withWarehouse("WAREHOUSE");
    }

    /** Builds a configuration using key-pair auth with the test private key and passphrase. */
    private SnowflakeIO.DataSourceConfiguration keyPairDataSourceConfiguration() {
        return withTestConnectionSettings(
                SnowflakeIO.DataSourceConfiguration.create()
                        .withKeyPairPathAuth(
                                options.getUsername(),
                                TestUtils.getValidPrivateKeyPath(getClass()),
                                TestUtils.getPrivateKeyPassphrase()));
    }

    /**
     * Applies a Snowpipe write of the bounded {@code testData} to {@code FAKE_TABLE}, using the
     * current {@code dataSourceConfiguration}.
     */
    private void applySnowpipeWriteOfTestData() {
        pipeline
                .apply(Create.of(testData))
                .apply(
                        SnowflakeIO.<Long>write()
                                .withDataSourceConfiguration(dataSourceConfiguration)
                                .to(FAKE_TABLE)
                                .withStagingBucketName(STAGING_BUCKET_NAME)
                                .withStorageIntegrationName(STORAGE_INTEGRATION_NAME)
                                .withSnowPipe(SNOW_PIPE)
                                .withUserDataMapper(TestUtils.getLongCsvMapper())
                                .withSnowflakeService(snowflakeService));
    }

    /**
     * Builds the TestStream shared by the streaming tests: all of {@code SENTENCES}, added in three
     * batches with watermark advances at 0s, 27s and 65s.
     *
     * <p>NOTE(review): {@link #event} offsets are milliseconds (Joda {@code new Duration(long)}),
     * so every element is timestamped within the first one-minute window even though the labels
     * say "window 2". Confirm whether seconds were intended before changing it, because that would
     * alter how rows are batched.
     */
    private TestStream<String> createTestStream() {
        return TestStream.create(StringUtf8Coder.of())
                .advanceWatermarkTo(START_TIME)
                .addElements(event(FIRST_WIN_WORDS.get(0), 2L))
                .advanceWatermarkTo(START_TIME.plus(Duration.standardSeconds(27L)))
                .addElements(
                        event(FIRST_WIN_WORDS.get(1), 25L),
                        event(FIRST_WIN_WORDS.get(2), 18L),
                        event(FIRST_WIN_WORDS.get(3), 26L))
                .advanceWatermarkTo(START_TIME.plus(Duration.standardSeconds(65L)))
                .addElements(
                        event(SECOND_WIN_WORDS.get(0), 67L),
                        event(SECOND_WIN_WORDS.get(1), 68L))
                .advanceWatermarkToInfinity();
    }

    /**
     * Asserts the staged files and the fake table after a streaming write. It checks that exactly
     * two files were staged, that file "0" holds the first four sentences, and that the table holds
     * all sentences.
     *
     * @param quotationMark quotation character to strip from each row before comparison
     */
    private void assertStagedFilesAndTable(String quotationMark) throws SnowflakeSQLException {
        List<String> tableRows =
                parseResults(FakeSnowflakeDatabase.getElements(FAKE_TABLE), quotationMark);
        Map<String, List<String>> stagedFiles = getMapOfFilesAndResults();
        // Check the file count before dereferencing "0" so a missing file fails as an assertion.
        MatcherAssert.assertThat(stagedFiles.size(), CoreMatchers.equalTo(2));
        List<String> firstFileRows = parseResults(stagedFiles.get("0"), quotationMark);
        MatcherAssert.assertThat(firstFileRows, CoreMatchers.equalTo(FIRST_WIN_WORDS));
        MatcherAssert.assertThat(tableRows, CoreMatchers.equalTo(SENTENCES));
    }

    /**
     * Removes every occurrence of {@code quotationMark} from each row.
     *
     * <p>Uses literal {@link String#replace(CharSequence, CharSequence)} rather than
     * {@code replaceAll}, so quote characters are never interpreted as regex syntax.
     */
    private List<String> parseResults(List<String> rows, String quotationMark) {
        return rows.stream()
                .map(row -> row.replace(quotationMark, ""))
                .collect(Collectors.toList());
    }

    /** Returns the staged {@code *.gz} files under the staging bucket, keyed as in {@link #getFiles}. */
    private Map<String, List<String>> getMapOfFilesAndResults() {
        return getFiles(Paths.get(STAGING_BUCKET_NAME));
    }

    /**
     * Reads every {@code *.gz} file directly under {@code path}.
     *
     * @return a new mutable map from the second {@code '-'}-separated token of each file name to
     *     that file's lines (as returned by {@code TestUtils.readGZIPFile})
     * @throws RuntimeException wrapping the {@link IOException} if the directory cannot be listed
     */
    private Map<String, List<String>> getFiles(Path path) {
        Map<String, List<String>> filesByKey = new HashMap<>();
        // try-with-resources guarantees the directory handle is closed even if reading a file fails.
        try (DirectoryStream<Path> files = Files.newDirectoryStream(path, "*.gz")) {
            for (Path file : files) {
                String key = file.getFileName().toString().split("-", -1)[1];
                filesByKey.put(key, TestUtils.readGZIPFile(file.toString()));
            }
        } catch (IOException e) {
            throw new RuntimeException("Failed to retrieve files", e);
        }
        return filesByKey;
    }

    /** Wraps {@code value} with a timestamp {@code offsetMillis} milliseconds after START_TIME. */
    private TimestampedValue<String> event(String value, long offsetMillis) {
        return TimestampedValue.of(value, START_TIME.plus(new Duration(offsetMillis)));
    }
}
