/*
 * Decompiled with CFR 0.152.
 */
package org.apache.beam.sdk.io.aws2.sns;

import java.io.IOException;
import java.io.Serializable;
import java.util.HashSet;
import java.util.Set;
import java.util.stream.StreamSupport;
import org.apache.beam.sdk.Pipeline;
import org.apache.beam.sdk.coders.Coder;
import org.apache.beam.sdk.coders.StringUtf8Coder;
import org.apache.beam.sdk.io.aws2.sns.MockSnsAsyncClient;
import org.apache.beam.sdk.io.aws2.sns.MockSnsAsyncExceptionClient;
import org.apache.beam.sdk.io.aws2.sns.SnsAsyncClientProvider;
import org.apache.beam.sdk.io.aws2.sns.SnsIO;
import org.apache.beam.sdk.io.aws2.sns.SnsResponse;
import org.apache.beam.sdk.testing.PAssert;
import org.apache.beam.sdk.testing.TestPipeline;
import org.apache.beam.sdk.transforms.Create;
import org.apache.beam.sdk.transforms.PTransform;
import org.apache.beam.sdk.transforms.SerializableFunction;
import org.apache.beam.sdk.values.PCollection;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableSet;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Sets;
import org.junit.Assert;
import org.junit.Rule;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
import software.amazon.awssdk.services.sns.model.PublishRequest;

/**
 * Tests for {@link SnsIO#writeAsync()} driven by mock asynchronous SNS clients.
 *
 * <p>Each test turns a few string messages into {@link PublishRequest}s for {@link #TOPIC},
 * publishes them through a mock client, and then inspects either the emitted {@link SnsResponse}s
 * or the pipeline failure.
 */
@RunWith(JUnit4.class)
public class SnsIOWriteTest implements Serializable {
    private static final String TOPIC = "test";
    private static final int FAILURE_STATUS_CODE = 400;
    private static final int SUCCESS_STATUS_CODE = 200;

    @Rule
    public transient TestPipeline pipeline = TestPipeline.create();

    /** A client that answers 200 yields one 200-status response for each published message. */
    @Test
    public void shouldReturnResponseOnPublishSuccess() {
        final String testMessage1 = "test1";
        final String testMessage2 = "test2";
        final String testMessage3 = "test3";
        PCollection<String> messages =
                pipeline.apply(
                        Create.of(testMessage1, testMessage2, testMessage3)
                                .withCoder(StringUtf8Coder.of()));
        PCollection<SnsResponse<String>> result =
                publishWith(messages, () -> MockSnsAsyncClient.withStatusCode(SUCCESS_STATUS_CODE));

        PAssert.that(result)
                .satisfies(
                        responses -> {
                            ImmutableSet<String> messagesInResponse =
                                    StreamSupport.stream(responses.spliterator(), false)
                                            .filter(
                                                    response ->
                                                            response.statusCode().getAsInt()
                                                                    == SUCCESS_STATUS_CODE)
                                            .map(SnsResponse::element)
                                            .collect(ImmutableSet.toImmutableSet());
                            HashSet<String> originalMessages =
                                    Sets.newHashSet(testMessage1, testMessage2, testMessage3);
                            Set<String> unexpectedMessages =
                                    Sets.difference(messagesInResponse, originalMessages);
                            // Three distinct inputs: a size of 3 with no unexpected element means
                            // the successful responses cover exactly the original messages.
                            Assert.assertEquals(3L, messagesInResponse.size());
                            Assert.assertEquals(0L, unexpectedMessages.size());
                            return null;
                        });
        pipeline.run().waitUntilFinish();
    }

    /** A client that answers 400 yields one non-200 response for each published message. */
    @Test
    public void shouldReturnResponseOnPublishFailure() {
        final String testMessage1 = "test1";
        final String testMessage2 = "test2";
        PCollection<String> messages =
                pipeline.apply(
                        Create.of(testMessage1, testMessage2).withCoder(StringUtf8Coder.of()));
        PCollection<SnsResponse<String>> result =
                publishWith(messages, () -> MockSnsAsyncClient.withStatusCode(FAILURE_STATUS_CODE));

        PAssert.that(result)
                .satisfies(
                        responses -> {
                            ImmutableSet<String> messagesInResponse =
                                    StreamSupport.stream(responses.spliterator(), false)
                                            .filter(
                                                    response ->
                                                            response.statusCode().getAsInt()
                                                                    != SUCCESS_STATUS_CODE)
                                            .map(SnsResponse::element)
                                            .collect(ImmutableSet.toImmutableSet());
                            HashSet<String> originalMessages =
                                    Sets.newHashSet(testMessage1, testMessage2);
                            Set<String> unexpectedMessages =
                                    Sets.difference(messagesInResponse, originalMessages);
                            Assert.assertEquals(2L, messagesInResponse.size());
                            Assert.assertEquals(0L, unexpectedMessages.size());
                            return null;
                        });
        pipeline.run().waitUntilFinish();
    }

    /**
     * Publishes through a client that answers 400 and checks the type of any pipeline failure.
     *
     * <p>NOTE(review): no "throw error" option is configured here, and the same 400 mock yields
     * ordinary responses in {@link #shouldReturnResponseOnPublishFailure()}. This test may
     * therefore pass without any failure occurring. Confirm the intended configuration.
     */
    @Test
    public void shouldThrowIfThrowErrorOptionSet() {
        publishWith(
                pipeline.apply(Create.of("test1")),
                () -> MockSnsAsyncClient.withStatusCode(FAILURE_STATUS_CODE));
        runAndCheckFailureCause();
    }

    /**
     * Publishes through a client whose calls fail internally and checks the type of any pipeline
     * failure.
     *
     * <p>NOTE(review): as with {@link #shouldThrowIfThrowErrorOptionSet()}, a run that completes
     * normally also passes.
     */
    @Test
    public void shouldThrowIfThrowErrorOptionSetOnInternalException() {
        publishWith(pipeline.apply(Create.of("test1")), MockSnsAsyncExceptionClient::create);
        runAndCheckFailureCause();
    }

    /**
     * Applies {@link SnsIO#writeAsync()} to {@code messages}.
     *
     * <p>The transform uses a UTF-8 string coder, {@link #createPublishRequestFn()}, and the given
     * client provider.
     */
    private PCollection<SnsResponse<String>> publishWith(
            PCollection<String> messages, SnsAsyncClientProvider clientProvider) {
        return messages.apply(
                SnsIO.<String>writeAsync()
                        .withCoder(StringUtf8Coder.of())
                        .withPublishRequestFn(createPublishRequestFn())
                        .withSnsClientProvider(clientProvider));
    }

    /**
     * Runs the pipeline and, if it fails with a {@link Pipeline.PipelineExecutionException},
     * asserts that the failure's cause is an {@link IOException}.
     */
    private void runAndCheckFailureCause() {
        try {
            pipeline.run().waitUntilFinish();
        } catch (Pipeline.PipelineExecutionException e) {
            // Rethrow the cause so assertThrows really inspects its type. The previous lambda
            // only called getClass(), which never throws, so the assertion could never pass.
            Assert.assertThrows(
                    IOException.class,
                    () -> {
                        throw e.getCause();
                    });
        }
    }

    /** Builds a publish-request function that sends each input string to {@link #TOPIC}. */
    private SerializableFunction<String, PublishRequest> createPublishRequestFn() {
        return input -> PublishRequest.builder().topicArn(TOPIC).message(input).build();
    }
}

