package org.apache.beam.fn.harness.state;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashSet;
import java.util.List;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.atomic.AtomicInteger;
import org.apache.beam.fn.harness.Cache;
import org.apache.beam.fn.harness.Caches;
import org.apache.beam.fn.harness.state.StateFetchingIterators;
import org.apache.beam.model.fnexecution.v1.BeamFnApi;
import org.apache.beam.sdk.coders.BigEndianIntegerCoder;
import org.apache.beam.sdk.fn.stream.PrefetchableIterator;
import org.apache.beam.sdk.transforms.reflect.ByteBuddyDoFnInvokerFactory;
import org.apache.beam.vendor.grpc.v1p43p2.com.google.protobuf.ByteString;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableMap;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.primitives.Ints;
import org.hamcrest.Matcher;
import org.hamcrest.MatcherAssert;
import org.hamcrest.Matchers;
import org.junit.Assert;
import org.junit.Test;
import org.junit.experimental.runners.Enclosed;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;

@RunWith(Enclosed.class)
/* Tests for StateFetchingIterators; each nested test class is run by the Enclosed runner. */
public class StateFetchingIteratorsTest {

    @RunWith(JUnit4.class)
    /* loaded from: input_file:org/apache/beam/fn/harness/state/StateFetchingIteratorsTest$CachingStateIterableTest.class */
    public static class CachingStateIterableTest {
        @Test
        public void testEmpty() throws Exception {
            testFetchAndClear(4, new int[0]);
        }

        @Test
        public void testNonEmpty() throws Exception {
            testFetchAndClear(4, 0);
        }

        @Test
        public void testMultipleElementsPerChunk() throws Exception {
            testFetchAndClear(8, 0, 1, 2, 3, 4, 5);
        }

        @Test
        public void testSingleElementPerChunk() throws Exception {
            testFetchAndClear(4, 0, 1, 2, 3, 4, 5);
        }

        @Test
        public void testChunkSmallerThenElementSize() throws Exception {
            testFetchAndClear(3, 0, 1, 2, 3, 4, 5);
        }

        @Test
        public void testChunkLargerThenElementSize() throws Exception {
            testFetchAndClear(5, 0, 1, 2, 3, 4, 5);
        }

        @Test
        public void testAppend() throws Exception {
            int[] iArr = {0, 1, 2, 3, 4, 5};
            StateFetchingIterators.CachingStateIterable<Integer> create = create(5, iArr);
            create.append(Ints.asList(42, 43));
            verifyFetch(create.iterator(), iArr);
            StateFetchingIterators.CachingStateIterable<Integer> create2 = create(5, iArr);
            create2.append(Ints.asList(42, 43));
            Boolean.valueOf(create2.iterator().hasNext());
            verifyFetch(create2.iterator(), iArr);
            StateFetchingIterators.CachingStateIterable<Integer> create3 = create(5, iArr);
            verifyFetch(create3.iterator(), iArr);
            create3.append(Ints.asList(42, 43));
            verifyFetch(create3.iterator(), 0, 1, 2, 3, 4, 5, 42, 43);
        }

        @Test
        public void testRemove() throws Exception {
            int[] iArr = {0, 1, 2, 3, 4, 5};
            HashSet hashSet = new HashSet();
            hashSet.add(BigEndianIntegerCoder.of().structuralValue(2));
            hashSet.add(BigEndianIntegerCoder.of().structuralValue(4));
            StateFetchingIterators.CachingStateIterable<Integer> create = create(5, iArr);
            create.remove(hashSet);
            verifyFetch(create.iterator(), iArr);
            StateFetchingIterators.CachingStateIterable<Integer> create2 = create(5, iArr);
            create2.remove(hashSet);
            Boolean.valueOf(create2.iterator().hasNext());
            verifyFetch(create2.iterator(), iArr);
            StateFetchingIterators.CachingStateIterable<Integer> create3 = create(5, iArr);
            verifyFetch(create3.iterator(), iArr);
            create3.remove(hashSet);
            verifyFetch(create3.iterator(), 0, 1, 3, 5);
        }

        /* JADX WARN: Multi-variable type inference failed */
        @Test
        public void testCacheEvictionOrphansIteratorAndAllowsForIteratorToRejoin() throws Exception {
            int[] iArr = {0, 1, 2, 3, 4, 5};
            BeamFnApi.StateRequest build = BeamFnApi.StateRequest.newBuilder().setStateKey(BeamFnApi.StateKey.newBuilder().setBagUserState(BeamFnApi.StateKey.BagUserState.newBuilder().setTransformId("transformId").setUserStateId("stateId").setKey(ByteString.copyFromUtf8(ByteBuddyDoFnInvokerFactory.KEY_PARAMETER_METHOD)).setWindow(ByteString.copyFromUtf8(ByteBuddyDoFnInvokerFactory.WINDOW_PARAMETER_METHOD)))).setGet(BeamFnApi.StateGetRequest.getDefaultInstance()).build();
            FakeBeamFnStateClient fakeBeamFnStateClient = new FakeBeamFnStateClient(BigEndianIntegerCoder.of(), ImmutableMap.of(build.getStateKey(), Ints.asList(iArr)), 4);
            Cache eternal = Caches.eternal();
            StateFetchingIterators.CachingStateIterable cachingStateIterable = new StateFetchingIterators.CachingStateIterable(eternal, fakeBeamFnStateClient, build, BigEndianIntegerCoder.of());
            verifyFetch(cachingStateIterable.iterator(), iArr);
            MatcherAssert.assertThat((StateFetchingIterators.CachingStateIterable.Blocks) eternal.peek(StateFetchingIterators.IterableCacheKey.INSTANCE), (Matcher<? super StateFetchingIterators.CachingStateIterable.Blocks>) Matchers.is(Matchers.instanceOf(StateFetchingIterators.CachingStateIterable.BlocksPrefix.class)));
            int callCount = fakeBeamFnStateClient.getCallCount();
            PrefetchableIterator it = cachingStateIterable.iterator();
            Assert.assertEquals(0L, ((Integer) it.next()).intValue());
            Assert.assertEquals(callCount, fakeBeamFnStateClient.getCallCount());
            int callCount2 = fakeBeamFnStateClient.getCallCount();
            eternal.remove(StateFetchingIterators.IterableCacheKey.INSTANCE);
            Assert.assertEquals(1L, ((Integer) it.next()).intValue());
            Assert.assertEquals(callCount2 + 2, fakeBeamFnStateClient.getCallCount());
            Assert.assertEquals(2L, ((Integer) it.next()).intValue());
            Assert.assertEquals(r13 + 1, fakeBeamFnStateClient.getCallCount());
            Assert.assertNull(eternal.peek(StateFetchingIterators.IterableCacheKey.INSTANCE));
            int callCount3 = fakeBeamFnStateClient.getCallCount();
            PrefetchableIterator it2 = cachingStateIterable.iterator();
            Assert.assertEquals(0L, ((Integer) it2.next()).intValue());
            MatcherAssert.assertThat((StateFetchingIterators.CachingStateIterable.Blocks) eternal.peek(StateFetchingIterators.IterableCacheKey.INSTANCE), (Matcher<? super StateFetchingIterators.CachingStateIterable.Blocks>) Matchers.is(Matchers.instanceOf(StateFetchingIterators.CachingStateIterable.BlocksPrefix.class)));
            Assert.assertTrue(callCount3 < fakeBeamFnStateClient.getCallCount());
            Assert.assertEquals(1L, ((Integer) it2.next()).intValue());
            Assert.assertEquals(2L, ((Integer) it2.next()).intValue());
            Assert.assertEquals(3L, ((Integer) it2.next()).intValue());
            Assert.assertEquals(4L, ((Integer) it2.next()).intValue());
            int callCount4 = fakeBeamFnStateClient.getCallCount();
            Assert.assertEquals(3L, ((Integer) it.next()).intValue());
            Assert.assertEquals(callCount4, fakeBeamFnStateClient.getCallCount());
        }

        @Test
        public void testBlocksPrefixShrinkage() throws Exception {
            List asList = Arrays.asList(StateFetchingIterators.CachingStateIterable.Block.fromValues(Arrays.asList("A"), (ByteString) null), StateFetchingIterators.CachingStateIterable.Block.fromValues(Arrays.asList("B"), (ByteString) null), StateFetchingIterators.CachingStateIterable.Block.fromValues(Arrays.asList("C"), (ByteString) null), StateFetchingIterators.CachingStateIterable.Block.fromValues(Arrays.asList("D"), (ByteString) null), StateFetchingIterators.CachingStateIterable.Block.fromValues(Arrays.asList("E"), (ByteString) null), StateFetchingIterators.CachingStateIterable.Block.fromValues(Arrays.asList("F"), (ByteString) null));
            StateFetchingIterators.CachingStateIterable.BlocksPrefix shrink = new StateFetchingIterators.CachingStateIterable.BlocksPrefix(asList).shrink();
            StateFetchingIterators.CachingStateIterable.BlocksPrefix shrink2 = shrink.shrink();
            MatcherAssert.assertThat(shrink.getBlocks(), (Matcher<? super List>) Matchers.contains((StateFetchingIterators.CachingStateIterable.Block) asList.get(0), (StateFetchingIterators.CachingStateIterable.Block) asList.get(1), (StateFetchingIterators.CachingStateIterable.Block) asList.get(2)));
            MatcherAssert.assertThat(shrink2.getBlocks(), (Matcher<? super List>) Matchers.contains((StateFetchingIterators.CachingStateIterable.Block) asList.get(0)));
            Assert.assertNull(shrink2.shrink());
        }

        @Test
        public void testBlocksWeight() throws Exception {
            List asList = Arrays.asList(StateFetchingIterators.CachingStateIterable.Block.mutatedBlock(Arrays.asList("A"), 10L), StateFetchingIterators.CachingStateIterable.Block.mutatedBlock(Arrays.asList("B"), 4611686018427387903L), StateFetchingIterators.CachingStateIterable.Block.mutatedBlock(Arrays.asList("C"), 4611686018427387903L), StateFetchingIterators.CachingStateIterable.Block.mutatedBlock(Arrays.asList("D"), 5L));
            Assert.assertEquals(4611686018427387913L, new StateFetchingIterators.CachingStateIterable.BlocksPrefix(asList.subList(0, 2)).getWeight());
            Assert.assertEquals(Long.MAX_VALUE, new StateFetchingIterators.CachingStateIterable.BlocksPrefix(asList).getWeight());
        }

        private StateFetchingIterators.CachingStateIterable<Integer> create(int i, int... iArr) {
            BeamFnApi.StateRequest build = BeamFnApi.StateRequest.newBuilder().setStateKey(BeamFnApi.StateKey.newBuilder().setBagUserState(BeamFnApi.StateKey.BagUserState.newBuilder().setTransformId("transformId").setUserStateId("stateId").setKey(ByteString.copyFromUtf8(ByteBuddyDoFnInvokerFactory.KEY_PARAMETER_METHOD)).setWindow(ByteString.copyFromUtf8(ByteBuddyDoFnInvokerFactory.WINDOW_PARAMETER_METHOD)))).setGet(BeamFnApi.StateGetRequest.getDefaultInstance()).build();
            StateFetchingIterators.CachingStateIterable<Integer> cachingStateIterable = new StateFetchingIterators.CachingStateIterable<>(Caches.eternal(), new FakeBeamFnStateClient(BigEndianIntegerCoder.of(), ImmutableMap.of(build.getStateKey(), Ints.asList(iArr)), i), build, BigEndianIntegerCoder.of());
            cachingStateIterable.iterator();
            Assert.assertEquals(0L, r0.getCallCount());
            return cachingStateIterable;
        }

        private void testFetchAndClear(int i, int... iArr) throws Exception {
            PrefetchableIterator<Integer> it = create(i, iArr).iterator();
            Assert.assertFalse(it.isReady());
            verifyFetch(it, iArr);
            verifyClear(i, iArr);
        }

        private void verifyFetch(PrefetchableIterator<Integer> prefetchableIterator, int... iArr) {
            ArrayList arrayList = new ArrayList();
            for (int i = 0; i < iArr.length; i++) {
                Assert.assertTrue(prefetchableIterator.hasNext());
                arrayList.add(prefetchableIterator.next());
            }
            Assert.assertFalse(prefetchableIterator.hasNext());
            Assert.assertTrue(prefetchableIterator.isReady());
            Assert.assertEquals(Ints.asList(iArr), arrayList);
        }

        private void verifyClear(int i, int... iArr) throws Exception {
            StateFetchingIterators.CachingStateIterable<Integer> create = create(i, iArr);
            create.clearAndAppend(Ints.asList(42, 43));
            Assert.assertTrue(create.iterator().isReady());
            verifyFetch(create.iterator(), 42, 43);
            StateFetchingIterators.CachingStateIterable<Integer> create2 = create(i, iArr);
            Boolean.valueOf(create2.iterator().hasNext());
            create2.clearAndAppend(Ints.asList(42, 43));
            Assert.assertTrue(create2.iterator().isReady());
            verifyFetch(create2.iterator(), 42, 43);
            StateFetchingIterators.CachingStateIterable<Integer> create3 = create(i, iArr);
            verifyFetch(create3.iterator(), iArr);
            create3.clearAndAppend(Ints.asList(42, 43));
            Assert.assertTrue(create3.iterator().isReady());
            verifyFetch(create3.iterator(), 42, 43);
        }
    }

    @RunWith(JUnit4.class)
    /* loaded from: input_file:org/apache/beam/fn/harness/state/StateFetchingIteratorsTest$LazyBlockingStateFetchingIteratorTest.class */
    public static class LazyBlockingStateFetchingIteratorTest {
        @Test
        public void testEmpty() throws Exception {
            testFetch(ByteString.EMPTY);
        }

        @Test
        public void testNonEmpty() throws Exception {
            testFetch(ByteString.copyFromUtf8("A"));
        }

        @Test
        public void testWithLastByteStringBeingEmpty() throws Exception {
            testFetch(ByteString.copyFromUtf8("A"), ByteString.EMPTY);
        }

        @Test
        public void testMulti() throws Exception {
            testFetch(ByteString.copyFromUtf8("BC"), ByteString.copyFromUtf8("DEF"));
        }

        @Test
        public void testMultiWithEmptyByteStrings() throws Exception {
            testFetch(ByteString.EMPTY, ByteString.copyFromUtf8("BC"), ByteString.EMPTY, ByteString.EMPTY, ByteString.copyFromUtf8("DEF"), ByteString.EMPTY);
        }

        @Test
        public void testPrefetchIgnoredWhenExistingPrefetchOngoing() throws Exception {
            final AtomicInteger atomicInteger = new AtomicInteger();
            StateFetchingIterators.LazyBlockingStateFetchingIterator lazyBlockingStateFetchingIterator = new StateFetchingIterators.LazyBlockingStateFetchingIterator(new BeamFnStateClient() { // from class: org.apache.beam.fn.harness.state.StateFetchingIteratorsTest.LazyBlockingStateFetchingIteratorTest.1
                public CompletableFuture<BeamFnApi.StateResponse> handle(BeamFnApi.StateRequest.Builder builder) {
                    atomicInteger.incrementAndGet();
                    return new CompletableFuture<>();
                }
            }, BeamFnApi.StateRequest.getDefaultInstance());
            Assert.assertEquals(0L, atomicInteger.get());
            lazyBlockingStateFetchingIterator.prefetch();
            Assert.assertEquals(1L, atomicInteger.get());
            lazyBlockingStateFetchingIterator.prefetch();
            Assert.assertEquals(1L, atomicInteger.get());
        }

        @Test
        public void testSeekToContinuationToken() throws Exception {
            StateFetchingIterators.LazyBlockingStateFetchingIterator lazyBlockingStateFetchingIterator = new StateFetchingIterators.LazyBlockingStateFetchingIterator(new BeamFnStateClient() { // from class: org.apache.beam.fn.harness.state.StateFetchingIteratorsTest.LazyBlockingStateFetchingIteratorTest.2
                public CompletableFuture<BeamFnApi.StateResponse> handle(BeamFnApi.StateRequest.Builder builder) {
                    int i = 0;
                    if (!ByteString.EMPTY.equals(builder.getGet().getContinuationToken())) {
                        i = Integer.parseInt(builder.getGet().getContinuationToken().toStringUtf8());
                    }
                    return CompletableFuture.completedFuture(BeamFnApi.StateResponse.newBuilder().setGet(BeamFnApi.StateGetResponse.newBuilder().setData(ByteString.copyFromUtf8("value" + i)).setContinuationToken(ByteString.copyFromUtf8(Integer.toString(i + 1)))).build());
                }
            }, BeamFnApi.StateRequest.getDefaultInstance());
            Assert.assertEquals(ByteString.copyFromUtf8("value0"), lazyBlockingStateFetchingIterator.next());
            Assert.assertEquals(ByteString.copyFromUtf8("value1"), lazyBlockingStateFetchingIterator.next());
            Assert.assertEquals(ByteString.copyFromUtf8("value2"), lazyBlockingStateFetchingIterator.next());
            lazyBlockingStateFetchingIterator.seekToContinuationToken(ByteString.EMPTY);
            Assert.assertEquals(ByteString.copyFromUtf8("value0"), lazyBlockingStateFetchingIterator.next());
            Assert.assertEquals(ByteString.copyFromUtf8("value1"), lazyBlockingStateFetchingIterator.next());
            Assert.assertEquals(ByteString.copyFromUtf8("value2"), lazyBlockingStateFetchingIterator.next());
            lazyBlockingStateFetchingIterator.seekToContinuationToken(ByteString.copyFromUtf8("42"));
            Assert.assertEquals(ByteString.copyFromUtf8("value42"), lazyBlockingStateFetchingIterator.next());
            Assert.assertEquals(ByteString.copyFromUtf8("value43"), lazyBlockingStateFetchingIterator.next());
            Assert.assertEquals(ByteString.copyFromUtf8("value44"), lazyBlockingStateFetchingIterator.next());
        }

        private void testFetch(ByteString... byteStringArr) {
            StateFetchingIterators.LazyBlockingStateFetchingIterator lazyBlockingStateFetchingIterator = new StateFetchingIterators.LazyBlockingStateFetchingIterator(StateFetchingIteratorsTest.fakeStateClient(new AtomicInteger(), byteStringArr), BeamFnApi.StateRequest.getDefaultInstance());
            Assert.assertEquals(0L, r0.get());
            Assert.assertFalse(lazyBlockingStateFetchingIterator.isReady());
            ArrayList arrayList = new ArrayList();
            for (int i = 0; i < byteStringArr.length; i++) {
                if (i % 2 == 0) {
                    lazyBlockingStateFetchingIterator.prefetch();
                    Assert.assertEquals(i + 1, r0.get());
                    Assert.assertTrue(lazyBlockingStateFetchingIterator.isReady());
                }
                Assert.assertTrue(lazyBlockingStateFetchingIterator.hasNext());
                arrayList.add((ByteString) lazyBlockingStateFetchingIterator.next());
            }
            Assert.assertFalse(lazyBlockingStateFetchingIterator.hasNext());
            Assert.assertTrue(lazyBlockingStateFetchingIterator.isReady());
            Assert.assertEquals(Arrays.asList(byteStringArr), arrayList);
        }
    }

    /**
     * Returns a fake {@link BeamFnStateClient} that serves {@code chunks} one per response.
     *
     * <p>A request's continuation token is parsed as the decimal index of the chunk to return, and
     * an empty token means index 0. The response echoes the request id. It carries the next index
     * as its continuation token, or an empty token for the last chunk. With no chunks, every
     * response is an empty GET response. Every request increments {@code callCounter}.
     */
    public static BeamFnStateClient fakeStateClient(
            AtomicInteger callCounter, ByteString... chunks) {
        return requestBuilder -> {
            callCounter.incrementAndGet();
            BeamFnApi.StateGetResponse.Builder getResponse = BeamFnApi.StateGetResponse.newBuilder();
            if (chunks.length > 0) {
                ByteString requestToken = requestBuilder.getGet().getContinuationToken();
                int index = requestToken.isEmpty() ? 0 : Integer.parseInt(requestToken.toStringUtf8());
                boolean isLastChunk = index == chunks.length - 1;
                ByteString nextToken =
                        isLastChunk
                                ? ByteString.EMPTY
                                : ByteString.copyFromUtf8(Integer.toString(index + 1));
                getResponse.setData(chunks[index]).setContinuationToken(nextToken);
            }
            return CompletableFuture.completedFuture(
                    BeamFnApi.StateResponse.newBuilder()
                            .setId(requestBuilder.getId())
                            .setGet(getResponse)
                            .build());
        };
    }
}
