package org.apache.beam.sdk.io.aws2.kinesis;

import java.net.URI;
import java.nio.charset.StandardCharsets;
import java.util.List;
import java.util.function.BiFunction;
import java.util.function.Function;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import org.apache.beam.sdk.Pipeline;
import org.apache.beam.sdk.io.aws2.MockClientBuilderFactory;
import org.apache.beam.sdk.io.aws2.StaticSupplier;
import org.apache.beam.sdk.io.aws2.common.ClientConfiguration;
import org.apache.beam.sdk.io.aws2.kinesis.KinesisIO;
import org.apache.beam.sdk.testing.PAssert;
import org.apache.beam.sdk.testing.TestPipeline;
import org.apache.beam.sdk.transforms.DoFn;
import org.apache.beam.sdk.transforms.ParDo;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableList;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Iterables;
import org.assertj.core.api.Assertions;
import org.joda.time.DateTime;
import org.joda.time.Duration;
import org.joda.time.Instant;
import org.junit.Before;
import org.junit.Rule;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.mockito.ArgumentMatcher;
import org.mockito.ArgumentMatchers;
import org.mockito.Mock;
import org.mockito.Mockito;
import org.mockito.junit.MockitoJUnitRunner;
import software.amazon.awssdk.auth.credentials.AwsBasicCredentials;
import software.amazon.awssdk.auth.credentials.DefaultCredentialsProvider;
import software.amazon.awssdk.auth.credentials.StaticCredentialsProvider;
import software.amazon.awssdk.core.SdkBytes;
import software.amazon.awssdk.regions.Region;
import software.amazon.awssdk.services.cloudwatch.CloudWatchClient;
import software.amazon.awssdk.services.cloudwatch.CloudWatchClientBuilder;
import software.amazon.awssdk.services.kinesis.KinesisClient;
import software.amazon.awssdk.services.kinesis.KinesisClientBuilder;
import software.amazon.awssdk.services.kinesis.model.GetRecordsRequest;
import software.amazon.awssdk.services.kinesis.model.GetRecordsResponse;
import software.amazon.awssdk.services.kinesis.model.GetShardIteratorRequest;
import software.amazon.awssdk.services.kinesis.model.GetShardIteratorResponse;
import software.amazon.awssdk.services.kinesis.model.LimitExceededException;
import software.amazon.awssdk.services.kinesis.model.ListShardsRequest;
import software.amazon.awssdk.services.kinesis.model.ListShardsResponse;
import software.amazon.awssdk.services.kinesis.model.Record;
import software.amazon.awssdk.services.kinesis.model.Shard;
import software.amazon.kinesis.common.InitialPositionInStream;

/**
 * Tests for {@code KinesisIO.read()} that run against a Mockito-mocked {@link KinesisClient}.
 *
 * <p>The pipeline tests stub {@code listShards}, {@code getShardIterator} and {@code getRecords}
 * on the mock so that every shard serves a fixed list of generated records in pages. Shard
 * iterators are encoded as {@code "<shard>:<offset>"}; the iterator {@code "done"} always yields
 * an empty page that points back to itself.
 */
@RunWith(MockitoJUnitRunner.class)
public class KinesisIOReadTest {
    private static final String KEY = "key";
    private static final String SECRET = "secret";
    /** Number of shards reported by the mocked {@code listShards} call. */
    private static final int SHARDS = 3;
    /** Number of records generated per shard. */
    private static final int SHARD_EVENTS = 100;

    @Rule
    public final transient TestPipeline p = TestPipeline.create();

    @Mock
    public KinesisClient client;

    /**
     * Legacy-style {@link AWSClientsProvider} that hands out a pre-built {@link KinesisClient}
     * and a fresh CloudWatch mock on every request.
     */
    static class Provider extends StaticSupplier<KinesisClient, Provider> implements AWSClientsProvider {
        Provider() {
        }

        /** Creates a provider that always returns {@code kinesisClient}. */
        static AWSClientsProvider of(KinesisClient kinesisClient) {
            return new Provider().withObject(kinesisClient);
        }

        public KinesisClient getKinesisClient() {
            return get();
        }

        public CloudWatchClient getCloudWatchClient() {
            return Mockito.mock(CloudWatchClient.class);
        }
    }

    /**
     * Converts each {@link KinesisRecord} back into an SDK {@link Record} (arrival timestamp,
     * data bytes and sequence number) so that pipeline output can be compared with the
     * generated test records.
     */
    public static class ToRecord extends DoFn<KinesisRecord, Record> {
        ToRecord() {
        }

        @DoFn.ProcessElement
        public void processElement(@DoFn.Element KinesisRecord kinesisRecord, DoFn.OutputReceiver<Record> outputReceiver) {
            outputReceiver.output(
                    KinesisIOReadTest.record(
                            kinesisRecord.getApproximateArrivalTimestamp(),
                            kinesisRecord.getDataAsBytes(),
                            kinesisRecord.getSequenceNumber()));
        }
    }

    /** Registers the mocked Kinesis client and a CloudWatch mock with the test pipeline. */
    @Before
    public void configureClientBuilderFactory() {
        MockClientBuilderFactory.set(this.p, KinesisClientBuilder.class, this.client);
        MockClientBuilderFactory.set(this.p, CloudWatchClientBuilder.class, Mockito.mock(CloudWatchClient.class));
    }

    @Test
    public void testReadFromShards() {
        List<List<Record>> records = testRecords(SHARDS, SHARD_EVENTS);
        mockShards(SHARDS);
        mockShardIterators(records);
        mockRecords(records, 10);

        readFromShards(Function.identity(), Iterables.concat(records));
    }

    @Test
    public void testReadFromShardsWithLegacyProvider() {
        List<List<Record>> records = testRecords(SHARDS, SHARD_EVENTS);
        mockShards(SHARDS);
        mockShardIterators(records);
        mockRecords(records, 10);

        // Unset the Kinesis client registered in @Before; this test supplies the client
        // through the explicitly configured AWSClientsProvider instead.
        MockClientBuilderFactory.set(this.p, KinesisClientBuilder.class, null);

        readFromShards(read -> read.withAWSClientsProvider(Provider.of(this.client)), Iterables.concat(records));
    }

    @Test(expected = Pipeline.PipelineExecutionException.class)
    public void testReadWithLimitExceeded() {
        Mockito.when(this.client.listShards(ArgumentMatchers.any(ListShardsRequest.class)))
                .thenThrow(LimitExceededException.builder().message("ListShards rate limit exceeded").build());

        readFromShards(Function.identity(), ImmutableList.of());
    }

    /**
     * Runs a read of stream {@code "stream"} from {@code TRIM_HORIZON}, capped at
     * {@code SHARDS * SHARD_EVENTS} records, and asserts that the output matches the expected
     * records in any order.
     *
     * @param function customization applied to the base read transform before it is used
     * @param iterable the records the pipeline is expected to emit
     */
    private void readFromShards(Function<KinesisIO.Read, KinesisIO.Read> function, Iterable<Record> iterable) {
        KinesisIO.Read baseRead =
                KinesisIO.read()
                        .withStreamName("stream")
                        .withInitialPositionInStream(InitialPositionInStream.TRIM_HORIZON)
                        .withArrivalTimeWatermarkPolicy()
                        .withMaxNumRecords((long) SHARDS * SHARD_EVENTS);
        KinesisIO.Read read = function.apply(baseRead);

        PAssert.that(this.p.apply(read).apply(ParDo.of(new ToRecord()))).containsInAnyOrder(iterable);
        this.p.run();
    }

    @Test
    public void testBuildWithBasicCredentials() {
        Region region = Region.US_EAST_1;
        StaticCredentialsProvider credentials = StaticCredentialsProvider.create(AwsBasicCredentials.create(KEY, SECRET));

        Assertions.assertThat(KinesisIO.read().withAWSClientsProvider(KEY, SECRET, region).getClientConfiguration())
                .isEqualTo(ClientConfiguration.create(credentials, region, (URI) null));
    }

    @Test
    public void testBuildWithCredentialsProvider() {
        Region region = Region.US_EAST_1;
        DefaultCredentialsProvider credentials = DefaultCredentialsProvider.create();

        Assertions.assertThat(KinesisIO.read().withAWSClientsProvider(credentials, region).getClientConfiguration())
                .isEqualTo(ClientConfiguration.create(credentials, region, (URI) null));
    }

    @Test
    public void testBuildWithBasicCredentialsAndCustomEndpoint() {
        Region region = Region.US_WEST_1;
        StaticCredentialsProvider credentials = StaticCredentialsProvider.create(AwsBasicCredentials.create(KEY, SECRET));

        Assertions.assertThat(KinesisIO.read().withAWSClientsProvider(KEY, SECRET, region, "localhost:9999").getClientConfiguration())
                .isEqualTo(ClientConfiguration.create(credentials, region, URI.create("localhost:9999")));
    }

    @Test
    public void testBuildWithCredentialsProviderAndCustomEndpoint() {
        Region region = Region.US_WEST_1;
        DefaultCredentialsProvider credentials = DefaultCredentialsProvider.create();

        Assertions.assertThat(KinesisIO.read().withAWSClientsProvider(credentials, region, "localhost:9999").getClientConfiguration())
                .isEqualTo(ClientConfiguration.create(credentials, region, URI.create("localhost:9999")));
    }

    /**
     * Matches a {@link GetShardIteratorRequest} whose shard id is the decimal form of
     * {@code i}. Null requests never match.
     */
    private static ArgumentMatcher<GetShardIteratorRequest> hasShardId(int i) {
        String expectedShardId = Integer.toString(i);
        return request -> request != null && expectedShardId.equals(request.shardId());
    }

    /**
     * Matches a {@link GetRecordsRequest} carrying exactly the shard iterator {@code str}.
     * Null requests never match.
     */
    private static ArgumentMatcher<GetRecordsRequest> hasShardIterator(String str) {
        return request -> request != null && str.equals(request.shardIterator());
    }

    /**
     * Stubs {@code getShardIterator} so that the shard with id {@code n} starts at iterator
     * {@code "n:0"}, for every shard index in {@code testRecords}.
     */
    private void mockShardIterators(List<List<Record>> testRecords) {
        for (int shard = 0; shard < testRecords.size(); shard++) {
            GetShardIteratorResponse response = GetShardIteratorResponse.builder().shardIterator(shard + ":0").build();
            Mockito.when(this.client.getShardIterator(ArgumentMatchers.argThat(hasShardId(shard)))).thenReturn(response);
        }
    }

    /**
     * Stubs {@code getRecords} so that each shard serves its records in pages of at most
     * {@code limit} entries.
     *
     * <p>The page starting at offset {@code k} of shard {@code s} is served for iterator
     * {@code "s:k"} and points to {@code "s:<k + page size>"}; the last page of a shard points
     * to {@code "done"}, which is stubbed to return an empty page forever.
     *
     * @param testRecords per-shard record lists, indexed by shard number
     * @param limit maximum number of records per page; must be positive
     * @throws IllegalArgumentException if {@code limit} is not positive
     */
    private void mockRecords(List<List<Record>> testRecords, int limit) {
        if (limit <= 0) {
            // A non-positive step would never advance the paging loop below.
            throw new IllegalArgumentException("limit must be positive, but was " + limit);
        }
        BiFunction<List<Record>, String, GetRecordsResponse.Builder> response =
                (records, nextIterator) ->
                        GetRecordsResponse.builder().millisBehindLatest(0L).records(records).nextShardIterator(nextIterator);

        for (int shard = 0; shard < testRecords.size(); shard++) {
            List<Record> shardRecords = testRecords.get(shard);
            for (int from = 0; from < shardRecords.size(); from += limit) {
                // Clamp the page end to the list size; taking the larger of the two values
                // would ignore the limit and could run past the end of the list.
                int to = Math.min(from + limit, shardRecords.size());
                String nextIterator = to == shardRecords.size() ? "done" : shard + ":" + to;
                Mockito.when(this.client.getRecords(ArgumentMatchers.argThat(hasShardIterator(shard + ":" + from))))
                        .thenReturn(response.apply(shardRecords.subList(from, to), nextIterator).build());
            }
        }
        Mockito.when(this.client.getRecords(ArgumentMatchers.argThat(hasShardIterator("done"))))
                .thenReturn(response.apply(ImmutableList.of(), "done").build());
    }

    /** Stubs {@code listShards} to report {@code count} shards with ids {@code "0"} to {@code "count - 1"}. */
    private void mockShards(int count) {
        List<Shard> shards =
                IntStream.range(0, count)
                        .mapToObj(id -> Shard.builder().shardId(Integer.toString(id)).build())
                        .collect(Collectors.toList());
        Mockito.when(this.client.listShards(ArgumentMatchers.any(ListShardsRequest.class)))
                .thenReturn(ListShardsResponse.builder().shards(shards).build());
    }

    /**
     * Generates {@code eventsPerShard} records for each of {@code shards} shards, all based on
     * the same start instant taken when this method is called.
     *
     * @return one list of records per shard, indexed by shard number
     */
    private List<List<Record>> testRecords(int shards, int eventsPerShard) {
        Instant start = DateTime.now().toInstant();
        return IntStream.range(0, shards)
                .mapToObj(
                        shard ->
                                IntStream.range(0, eventsPerShard)
                                        .mapToObj(event -> record(start, shard, event))
                                        .collect(Collectors.toList()))
                .collect(Collectors.toList());
    }

    /**
     * Builds the record for event {@code i2} of shard {@code i}: its arrival time is
     * {@code instant} plus {@code i2} seconds, and both payload and sequence number are the
     * decimal string of {@code i * SHARD_EVENTS + i2}.
     */
    public static Record record(Instant instant, int i, int i2) {
        String sequenceNumber = Integer.toString((i * SHARD_EVENTS) + i2);
        Instant arrival = instant.plus(Duration.standardSeconds(i2));
        return record(arrival, sequenceNumber.getBytes(StandardCharsets.UTF_8), sequenceNumber);
    }

    /**
     * Builds an SDK {@link Record} with the given arrival time, payload and sequence number,
     * and an empty partition key.
     */
    public static Record record(Instant instant, byte[] bArr, String str) {
        return Record.builder()
                .approximateArrivalTimestamp(TimeUtil.toJava(instant))
                .data(SdkBytes.fromByteArray(bArr))
                .sequenceNumber(str)
                .partitionKey("")
                .build();
    }
}
