/*
 * Decompiled with CFR 0.152.
 */
package org.apache.beam.sdk.io.aws2.sqs;

import java.io.Serializable;
import java.nio.charset.Charset;
import java.nio.charset.StandardCharsets;
import java.util.Arrays;
import java.util.Collection;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.CompletableFuture;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import java.util.stream.Stream;
import org.apache.beam.sdk.Pipeline;
import org.apache.beam.sdk.io.aws2.MockClientBuilderFactory;
import org.apache.beam.sdk.io.aws2.common.AsyncBatchWriteHandler;
import org.apache.beam.sdk.io.aws2.common.ClientConfiguration;
import org.apache.beam.sdk.io.aws2.common.RetryConfiguration;
import org.apache.beam.sdk.io.aws2.sqs.SqsIO;
import org.apache.beam.sdk.schemas.SchemaRegistry;
import org.apache.beam.sdk.testing.ExpectedLogs;
import org.apache.beam.sdk.testing.TestPipeline;
import org.apache.beam.sdk.transforms.Create;
import org.apache.beam.sdk.transforms.DoFn;
import org.apache.beam.sdk.transforms.PTransform;
import org.apache.beam.sdk.transforms.ParDo;
import org.apache.beam.sdk.values.PCollection;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Streams;
import org.apache.commons.lang3.RandomUtils;
import org.assertj.core.api.AbstractThrowableAssert;
import org.assertj.core.api.Assertions;
import org.joda.time.Duration;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Rule;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.mockito.ArgumentCaptor;
import org.mockito.ArgumentMatchers;
import org.mockito.Mock;
import org.mockito.Mockito;
import org.mockito.junit.MockitoJUnitRunner;
import org.mockito.verification.VerificationMode;
import software.amazon.awssdk.core.SdkBytes;
import software.amazon.awssdk.services.sqs.SqsAsyncClient;
import software.amazon.awssdk.services.sqs.SqsAsyncClientBuilder;
import software.amazon.awssdk.services.sqs.model.BatchResultErrorEntry;
import software.amazon.awssdk.services.sqs.model.MessageAttributeValue;
import software.amazon.awssdk.services.sqs.model.MessageSystemAttributeValue;
import software.amazon.awssdk.services.sqs.model.SendMessageBatchRequest;
import software.amazon.awssdk.services.sqs.model.SendMessageBatchRequestEntry;
import software.amazon.awssdk.services.sqs.model.SendMessageBatchResponse;
import software.amazon.awssdk.services.sqs.model.SendMessageRequest;

@RunWith(value=MockitoJUnitRunner.class)
public class SqsIOWriteBatchesTest {
    private static final SqsIO.WriteBatches.EntryMapperFn.Builder<String> SET_MESSAGE_BODY = SendMessageBatchRequestEntry.Builder::messageBody;
    private static final SendMessageBatchResponse SUCCESS = (SendMessageBatchResponse)SendMessageBatchResponse.builder().build();
    @Rule
    public TestPipeline p = TestPipeline.create();
    @Mock
    public SqsAsyncClient sqs;
    @Rule
    public ExpectedLogs logs = ExpectedLogs.none(AsyncBatchWriteHandler.class);

    @Before
    public void configureClientBuilderFactory() {
        MockClientBuilderFactory.set(this.p, SqsAsyncClientBuilder.class, this.sqs);
    }

    @Test
    public void testSchemaEntryMapper() throws Exception {
        SchemaRegistry registry = this.p.getSchemaRegistry();
        ImmutableMap attributes = ImmutableMap.of((Object)"key", (Object)((MessageAttributeValue)MessageAttributeValue.builder().stringValue("value").build()));
        ImmutableMap systemAttributes = ImmutableMap.of((Object)"key", (Object)((MessageSystemAttributeValue)MessageSystemAttributeValue.builder().binaryValue(SdkBytes.fromString((String)"bytes", (Charset)StandardCharsets.UTF_8)).build()));
        SendMessageRequest input = (SendMessageRequest)SendMessageRequest.builder().messageBody("body").delaySeconds(Integer.valueOf(3)).messageAttributes((Map)attributes).messageSystemAttributesWithStrings((Map)systemAttributes).build();
        SqsIO.WriteBatches.SchemaEntryMapper mapper = new SqsIO.WriteBatches.SchemaEntryMapper(registry.getSchema(SendMessageRequest.class), registry.getSchema(SendMessageBatchRequestEntry.class), registry.getToRowFunction(SendMessageRequest.class), registry.getFromRowFunction(SendMessageBatchRequestEntry.class));
        Assertions.assertThat((Object)((SendMessageBatchRequestEntry)mapper.apply((Object)"1", (Object)input))).isEqualTo(SendMessageBatchRequestEntry.builder().id("1").messageBody("body").delaySeconds(Integer.valueOf(3)).messageAttributes((Map)attributes).messageSystemAttributesWithStrings((Map)systemAttributes).build());
    }

    @Test
    public void testWrite() {
        Mockito.when((Object)this.sqs.sendMessageBatch(this.anyRequest())).thenReturn(CompletableFuture.completedFuture(SUCCESS));
        SendMessageRequest.Builder msgBuilder = SendMessageRequest.builder().queueUrl("queue");
        Set messages = IntStream.range(0, 100).mapToObj(i -> (SendMessageRequest)msgBuilder.messageBody("test" + i).build()).collect(Collectors.toSet());
        ((PCollection)this.p.apply((PTransform)Create.of(messages))).apply((PTransform)SqsIO.write());
        this.p.run().waitUntilFinish();
        ArgumentCaptor captor = ArgumentCaptor.forClass(SendMessageBatchRequest.class);
        ((SqsAsyncClient)Mockito.verify((Object)this.sqs, (VerificationMode)Mockito.times((int)100))).sendMessageBatch((SendMessageBatchRequest)captor.capture());
        for (SendMessageBatchRequest req : captor.getAllValues()) {
            Assertions.assertThat((String)req.queueUrl()).isEqualTo((Object)"queue");
            Assertions.assertThat((List)req.entries()).hasSize(1);
            for (SendMessageBatchRequestEntry entry : req.entries()) {
                Assert.assertTrue((boolean)messages.remove(msgBuilder.messageBody(entry.messageBody()).build()));
            }
        }
        Assert.assertTrue((boolean)messages.isEmpty());
    }

    @Test
    public void testWriteBatches() {
        Mockito.when((Object)this.sqs.sendMessageBatch(this.anyRequest())).thenReturn(CompletableFuture.completedFuture(SUCCESS));
        ((PCollection)((PCollection)this.p.apply((PTransform)Create.of((Object)23, (Object[])new Integer[0]))).apply((PTransform)ParDo.of((DoFn)new CreateMessages()))).apply((PTransform)SqsIO.writeBatches().withEntryMapper(SET_MESSAGE_BODY).to("queue"));
        this.p.run().waitUntilFinish();
        ((SqsAsyncClient)Mockito.verify((Object)this.sqs)).sendMessageBatch(this.request("queue", IntStream.range(0, 10)));
        ((SqsAsyncClient)Mockito.verify((Object)this.sqs)).sendMessageBatch(this.request("queue", IntStream.range(10, 20)));
        ((SqsAsyncClient)Mockito.verify((Object)this.sqs)).sendMessageBatch(this.request("queue", IntStream.range(20, 23)));
        ((SqsAsyncClient)Mockito.verify((Object)this.sqs)).close();
        Mockito.verifyNoMoreInteractions((Object[])new Object[]{this.sqs});
    }

    @Test
    public void testWriteBatchesFailure() {
        Mockito.when((Object)this.sqs.sendMessageBatch(this.anyRequest())).thenReturn(CompletableFuture.completedFuture(SUCCESS), (Object[])new CompletableFuture[]{CompletableFuture.supplyAsync(() -> (SendMessageBatchResponse)Preconditions.checkNotNull(null, (Object)"sendMessageBatch failed")), CompletableFuture.completedFuture(SUCCESS)});
        ((PCollection)((PCollection)this.p.apply((PTransform)Create.of((Object)23, (Object[])new Integer[0]))).apply((PTransform)ParDo.of((DoFn)new CreateMessages()))).apply((PTransform)SqsIO.writeBatches().withEntryMapper(SET_MESSAGE_BODY).to("queue"));
        ((AbstractThrowableAssert)Assertions.assertThatThrownBy(() -> this.p.run().waitUntilFinish()).isInstanceOf(Pipeline.PipelineExecutionException.class)).hasMessageContaining("sendMessageBatch failed");
    }

    @Test
    public void testWriteBatchesPartialSuccess() {
        SendMessageBatchRequestEntry[] entries = this.entries(IntStream.range(0, 10));
        Mockito.when((Object)this.sqs.sendMessageBatch(this.anyRequest())).thenReturn(CompletableFuture.completedFuture(this.partialSuccessResponse(entries[2].id(), entries[3].id())), (Object[])new CompletableFuture[]{CompletableFuture.completedFuture(this.partialSuccessResponse(entries[3].id())), CompletableFuture.completedFuture(SUCCESS)});
        ((PCollection)((PCollection)this.p.apply((PTransform)Create.of((Object)23, (Object[])new Integer[0]))).apply((PTransform)ParDo.of((DoFn)new CreateMessages()))).apply((PTransform)SqsIO.writeBatches().withEntryMapper(SET_MESSAGE_BODY).to("queue"));
        this.p.run().waitUntilFinish();
        ((SqsAsyncClient)Mockito.verify((Object)this.sqs)).sendMessageBatch(this.request("queue", entries));
        ((SqsAsyncClient)Mockito.verify((Object)this.sqs)).sendMessageBatch(this.request("queue", entries[2], entries[3]));
        ((SqsAsyncClient)Mockito.verify((Object)this.sqs)).sendMessageBatch(this.request("queue", entries[3]));
        ((SqsAsyncClient)Mockito.verify((Object)this.sqs)).sendMessageBatch(this.request("queue", IntStream.range(10, 20)));
        ((SqsAsyncClient)Mockito.verify((Object)this.sqs)).sendMessageBatch(this.request("queue", IntStream.range(20, 23)));
        ((SqsAsyncClient)Mockito.verify((Object)this.sqs)).close();
        Mockito.verifyNoMoreInteractions((Object[])new Object[]{this.sqs});
        this.logs.verifyInfo("retry after partial failure: code REASON for 2 record(s)");
        this.logs.verifyInfo("retry after partial failure: code REASON for 1 record(s)");
    }

    @Test
    public void testWriteCustomBatches() {
        Mockito.when((Object)this.sqs.sendMessageBatch(this.anyRequest())).thenReturn(CompletableFuture.completedFuture(SUCCESS));
        ((PCollection)((PCollection)this.p.apply((PTransform)Create.of((Object)8, (Object[])new Integer[0]))).apply((PTransform)ParDo.of((DoFn)new CreateMessages()))).apply((PTransform)SqsIO.writeBatches().withEntryMapper(SET_MESSAGE_BODY).withBatchSize(3).to("queue"));
        this.p.run().waitUntilFinish();
        ((SqsAsyncClient)Mockito.verify((Object)this.sqs)).sendMessageBatch(this.request("queue", IntStream.range(0, 3)));
        ((SqsAsyncClient)Mockito.verify((Object)this.sqs)).sendMessageBatch(this.request("queue", IntStream.range(3, 6)));
        ((SqsAsyncClient)Mockito.verify((Object)this.sqs)).sendMessageBatch(this.request("queue", IntStream.range(6, 8)));
        ((SqsAsyncClient)Mockito.verify((Object)this.sqs)).close();
        Mockito.verifyNoMoreInteractions((Object[])new Object[]{this.sqs});
    }

    @Test
    public void testWriteBatchesWithTimeout() {
        Mockito.when((Object)this.sqs.sendMessageBatch(this.anyRequest())).thenReturn(CompletableFuture.completedFuture(SUCCESS));
        ((PCollection)((PCollection)this.p.apply((PTransform)Create.of((Object)5, (Object[])new Integer[0]))).apply((PTransform)ParDo.of((DoFn)new CreateMessages()))).apply((PTransform)SqsIO.writeBatches().withEntryMapper(SqsIOWriteBatchesTest.withDelay(Duration.millis((long)100L), SET_MESSAGE_BODY)).withBatchTimeout(Duration.millis((long)150L)).to("queue"));
        this.p.run().waitUntilFinish();
        SendMessageBatchRequestEntry[] entries = this.entries(IntStream.range(0, 5));
        ((SqsAsyncClient)Mockito.verify((Object)this.sqs)).sendMessageBatch(this.request("queue", entries[0], entries[1], entries[2]));
        ((SqsAsyncClient)Mockito.verify((Object)this.sqs)).sendMessageBatch(this.request("queue", entries[3], entries[4]));
    }

    @Test
    public void testWriteBatchesWithStrictTimeout() {
        Mockito.when((Object)this.sqs.sendMessageBatch((SendMessageBatchRequest)ArgumentMatchers.any(SendMessageBatchRequest.class))).thenReturn(CompletableFuture.completedFuture((SendMessageBatchResponse)SendMessageBatchResponse.builder().build()));
        ((PCollection)((PCollection)this.p.apply((PTransform)Create.of((Object)5, (Object[])new Integer[0]))).apply((PTransform)ParDo.of((DoFn)new CreateMessages()))).apply((PTransform)SqsIO.writeBatches().withEntryMapper(SqsIOWriteBatchesTest.withDelay(Duration.millis((long)100L), SET_MESSAGE_BODY)).withBatchTimeout(Duration.millis((long)150L), true).to("queue"));
        this.p.run().waitUntilFinish();
        SendMessageBatchRequestEntry[] entries = this.entries(IntStream.range(0, 5));
        ((SqsAsyncClient)Mockito.verify((Object)this.sqs)).sendMessageBatch(this.request("queue", entries[0], entries[1]));
        ((SqsAsyncClient)Mockito.verify((Object)this.sqs)).sendMessageBatch(this.request("queue", entries[2], entries[3]));
        ((SqsAsyncClient)Mockito.verify((Object)this.sqs)).sendMessageBatch(this.request("queue", entries[4]));
    }

    @Test
    public void testWriteBatchesToDynamic() {
        Mockito.when((Object)this.sqs.sendMessageBatch(this.anyRequest())).thenReturn(CompletableFuture.completedFuture(SUCCESS));
        RetryConfiguration retry = RetryConfiguration.builder().maxBackoff(Duration.millis((long)1L)).build();
        ((PCollection)((PCollection)this.p.apply((PTransform)Create.of((Object)10, (Object[])new Integer[0]))).apply((PTransform)ParDo.of((DoFn)new CreateMessages()))).apply((PTransform)SqsIO.writeBatches().withEntryMapper(SET_MESSAGE_BODY).withClientConfiguration(ClientConfiguration.builder().retry(retry).build()).withBatchSize(3).to((SqsIO.WriteBatches.DynamicDestination & Serializable)msg -> Integer.valueOf(msg) % 2 == 0 ? "even" : "uneven"));
        this.p.run().waitUntilFinish();
        SendMessageBatchRequestEntry[] entries = this.entries(IntStream.range(0, 9), IntStream.range(9, 10));
        ((SqsAsyncClient)Mockito.verify((Object)this.sqs)).sendMessageBatch(this.request("even", entries[0], entries[2], entries[4]));
        ((SqsAsyncClient)Mockito.verify((Object)this.sqs)).sendMessageBatch(this.request("uneven", entries[1], entries[3], entries[5]));
        ((SqsAsyncClient)Mockito.verify((Object)this.sqs)).sendMessageBatch(this.request("even", entries[6], entries[8]));
        ((SqsAsyncClient)Mockito.verify((Object)this.sqs)).sendMessageBatch(this.request("uneven", entries[7], entries[9]));
        ((SqsAsyncClient)Mockito.verify((Object)this.sqs)).close();
        Mockito.verifyNoMoreInteractions((Object[])new Object[]{this.sqs});
    }

    @Test
    public void testWriteBatchesToDynamicWithTimeout() {
        Mockito.when((Object)this.sqs.sendMessageBatch(this.anyRequest())).thenReturn(CompletableFuture.completedFuture(SUCCESS));
        ((PCollection)((PCollection)this.p.apply((PTransform)Create.of((Object)5, (Object[])new Integer[0]))).apply((PTransform)ParDo.of((DoFn)new CreateMessages()))).apply((PTransform)SqsIO.writeBatches().withEntryMapper(SqsIOWriteBatchesTest.withDelay(Duration.millis((long)100L), SET_MESSAGE_BODY)).withBatchTimeout(Duration.millis((long)150L)).to((SqsIO.WriteBatches.DynamicDestination & Serializable)msg -> Integer.valueOf(msg) % 2 == 0 ? "even" : "uneven"));
        this.p.run().waitUntilFinish();
        SendMessageBatchRequestEntry[] entries = this.entries(IntStream.range(0, 5));
        ((SqsAsyncClient)Mockito.verify((Object)this.sqs)).sendMessageBatch(this.request("even", entries[0], entries[2]));
        ((SqsAsyncClient)Mockito.verify((Object)this.sqs)).sendMessageBatch(this.request("uneven", entries[1], entries[3]));
        ((SqsAsyncClient)Mockito.verify((Object)this.sqs)).sendMessageBatch(this.request("even", entries[4]));
    }

    @Test
    public void testWriteBatchesToDynamicWithStrictTimeout() {
        Mockito.when((Object)this.sqs.sendMessageBatch((SendMessageBatchRequest)ArgumentMatchers.any(SendMessageBatchRequest.class))).thenReturn(CompletableFuture.completedFuture((SendMessageBatchResponse)SendMessageBatchResponse.builder().build()));
        ((PCollection)((PCollection)this.p.apply((PTransform)Create.of((Object)5, (Object[])new Integer[0]))).apply((PTransform)ParDo.of((DoFn)new CreateMessages()))).apply((PTransform)SqsIO.writeBatches().withEntryMapper(SqsIOWriteBatchesTest.withDelay(Duration.millis((long)100L), SET_MESSAGE_BODY)).withBatchTimeout(Duration.millis((long)150L), true).to((SqsIO.WriteBatches.DynamicDestination & Serializable)msg -> Integer.valueOf(msg) % 2 == 0 ? "even" : "uneven"));
        this.p.run().waitUntilFinish();
        SendMessageBatchRequestEntry[] entries = this.entries(IntStream.range(0, 5));
        ((SqsAsyncClient)Mockito.verify((Object)this.sqs)).sendMessageBatch(this.request("even", entries[0]));
        ((SqsAsyncClient)Mockito.verify((Object)this.sqs)).sendMessageBatch(this.request("uneven", entries[1]));
        ((SqsAsyncClient)Mockito.verify((Object)this.sqs)).sendMessageBatch(this.request("even", entries[2]));
        ((SqsAsyncClient)Mockito.verify((Object)this.sqs)).sendMessageBatch(this.request("uneven", entries[3]));
        ((SqsAsyncClient)Mockito.verify((Object)this.sqs)).sendMessageBatch(this.request("even", entries[4]));
    }

    @Test
    public void testWriteBatchesToDynamicWithStrictTimeoutAtHighVolume() {
        Mockito.when((Object)this.sqs.sendMessageBatch((SendMessageBatchRequest)ArgumentMatchers.any(SendMessageBatchRequest.class))).thenReturn(CompletableFuture.completedFuture((SendMessageBatchResponse)SendMessageBatchResponse.builder().build()));
        SqsIO.WriteBatches.DynamicDestination & Serializable dynamicDestination = (SqsIO.WriteBatches.DynamicDestination & Serializable)msg -> String.valueOf(RandomUtils.nextInt((int)0, (int)((int)(1.0 + Math.sqrt(Integer.valueOf(msg).intValue())))));
        ((PCollection)((PCollection)this.p.apply((PTransform)Create.of((Object)100000, (Object[])new Integer[0]))).apply((PTransform)ParDo.of((DoFn)new CreateMessages()))).apply((PTransform)SqsIO.writeBatches().withEntryMapper(SET_MESSAGE_BODY).withBatchTimeout(Duration.millis((long)10L), true).to((SqsIO.WriteBatches.DynamicDestination)dynamicDestination));
        this.p.run().waitUntilFinish();
        ArgumentCaptor reqCaptor = ArgumentCaptor.forClass(SendMessageBatchRequest.class);
        ((SqsAsyncClient)Mockito.verify((Object)this.sqs, (VerificationMode)Mockito.atLeastOnce())).sendMessageBatch((SendMessageBatchRequest)reqCaptor.capture());
        HashSet<String> capturedMessages = new HashSet<String>();
        for (SendMessageBatchRequest req : reqCaptor.getAllValues()) {
            for (SendMessageBatchRequestEntry entry : req.entries()) {
                Assert.assertTrue((String)"duplicate message", (boolean)capturedMessages.add(entry.messageBody()));
            }
        }
        Assert.assertEquals((String)"Invalid message count", (long)100000L, (long)capturedMessages.size());
    }

    private SendMessageBatchRequest anyRequest() {
        return (SendMessageBatchRequest)ArgumentMatchers.any();
    }

    private SendMessageBatchRequest request(String queue, SendMessageBatchRequestEntry ... entries) {
        return (SendMessageBatchRequest)SendMessageBatchRequest.builder().queueUrl(queue).entries(Arrays.asList(entries)).build();
    }

    private SendMessageBatchRequest request(String queue, IntStream msgs) {
        return this.request(queue, this.entries(msgs));
    }

    private SendMessageBatchRequestEntry[] entries(IntStream ... msgStreams) {
        return (SendMessageBatchRequestEntry[])Arrays.stream(msgStreams).flatMap(msgs -> Streams.mapWithIndex((IntStream)msgs, this::entry)).toArray(SendMessageBatchRequestEntry[]::new);
    }

    private SendMessageBatchRequestEntry entry(int msg, long id) {
        return (SendMessageBatchRequestEntry)SendMessageBatchRequestEntry.builder().id(Long.toString(id)).messageBody(Integer.toString(msg)).build();
    }

    private SendMessageBatchResponse partialSuccessResponse(String ... failedIds) {
        Stream<BatchResultErrorEntry> errors = Arrays.stream(failedIds).map(arg_0 -> ((BatchResultErrorEntry.Builder)BatchResultErrorEntry.builder()).id(arg_0)).map(b -> (BatchResultErrorEntry)b.code("REASON").build());
        return (SendMessageBatchResponse)SendMessageBatchResponse.builder().failed((Collection)errors.collect(Collectors.toList())).build();
    }

    private static <T> SqsIO.WriteBatches.EntryMapperFn.Builder<T> withDelay(Duration delay, SqsIO.WriteBatches.EntryMapperFn.Builder<T> builder) {
        return (SqsIO.WriteBatches.EntryMapperFn.Builder & Serializable)(t1, t2) -> {
            builder.accept(t1, t2);
            try {
                Thread.sleep(delay.getMillis());
            }
            catch (InterruptedException interruptedException) {
                // empty catch block
            }
        };
    }

    private static class CreateMessages
    extends DoFn<Integer, String> {
        private CreateMessages() {
        }

        @DoFn.ProcessElement
        public void processElement(@DoFn.Element Integer count, DoFn.OutputReceiver<String> out) {
            for (int i = 0; i < count; ++i) {
                out.output((Object)Integer.toString(i));
            }
        }
    }
}

