package org.apache.beam.runners.fnexecution.state;

import java.util.EnumMap;
import org.apache.beam.model.fnexecution.v1.BeamFnApi;
import org.apache.beam.model.pipeline.v1.Endpoints;
import org.apache.beam.model.pipeline.v1.RunnerApi;
import org.apache.beam.runners.core.construction.ModelCoders;
import org.apache.beam.runners.core.construction.graph.ExecutableStage;
import org.apache.beam.runners.fnexecution.control.ProcessBundleDescriptors;
import org.apache.beam.sdk.transforms.windowing.GlobalWindow;
import org.apache.beam.sdk.util.CoderUtils;
import org.apache.beam.vendor.grpc.v1p43p2.com.google.protobuf.ByteString;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Iterables;
import org.hamcrest.MatcherAssert;
import org.hamcrest.Matchers;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
import org.mockito.Mockito;

/**
 * Tests for {@link StateRequestHandlers}: type-based delegation of state requests and cache token
 * reporting for bag user state handlers.
 */
@RunWith(JUnit4.class)
public class StateRequestHandlersTest {

    /**
     * Verifies that the delegating handler forwards each request to the handler registered for the
     * request's state key type, and to no other handler.
     */
    @Test
    public void testDelegatingStateHandlerDelegates() throws Exception {
        StateRequestHandler unsetTypeHandler = Mockito.mock(StateRequestHandler.class);
        StateRequestHandler sideInputHandler = Mockito.mock(StateRequestHandler.class);

        // Key is the state key type, value is the handler; a typed map needs no casts on put.
        EnumMap<BeamFnApi.StateKey.TypeCase, StateRequestHandler> handlers =
                new EnumMap<>(BeamFnApi.StateKey.TypeCase.class);
        handlers.put(BeamFnApi.StateKey.TypeCase.TYPE_NOT_SET, unsetTypeHandler);
        handlers.put(BeamFnApi.StateKey.TypeCase.MULTIMAP_SIDE_INPUT, sideInputHandler);

        // The default request has no state key set; the second one carries a multimap side input key.
        BeamFnApi.StateRequest unsetTypeRequest = BeamFnApi.StateRequest.getDefaultInstance();
        BeamFnApi.StateRequest sideInputRequest =
                BeamFnApi.StateRequest.newBuilder()
                        .setStateKey(
                                BeamFnApi.StateKey.newBuilder()
                                        .setMultimapSideInput(
                                                BeamFnApi.StateKey.MultimapSideInput.getDefaultInstance()))
                        .build();

        StateRequestHandlers.delegateBasedUponType(handlers).handle(unsetTypeRequest);
        StateRequestHandlers.delegateBasedUponType(handlers).handle(sideInputRequest);

        Mockito.verify(unsetTypeHandler).handle(unsetTypeRequest);
        Mockito.verify(sideInputHandler).handle(sideInputRequest);
        Mockito.verifyNoMoreInteractions(unsetTypeHandler, sideInputHandler);
    }

    /**
     * Sends a default request through a delegating handler that has no handlers registered.
     *
     * <p>NOTE(review): the test name suggests a failure is expected, but nothing here asserts one;
     * the test passes as long as {@code handle} returns normally. Confirm the intended behavior
     * before adding an expectation.
     */
    @Test
    public void testDelegatingStateHandlerThrowsWhenNotFound() throws Exception {
        EnumMap<BeamFnApi.StateKey.TypeCase, StateRequestHandler> noHandlers =
                new EnumMap<>(BeamFnApi.StateKey.TypeCase.class);
        StateRequestHandlers.delegateBasedUponType(noHandlers)
                .handle(BeamFnApi.StateRequest.getDefaultInstance());
    }

    /**
     * Verifies that a bag user state handler reports exactly one cache token, and that the token is
     * unchanged while the factory's handler count grows from 0 to 2 as each user state is first
     * requested.
     */
    @Test
    public void testUserStateCacheTokenGeneration() throws Exception {
        ExecutableStage stage = buildExecutableStage("state1", "state2");
        ProcessBundleDescriptors.ExecutableProcessBundleDescriptor descriptor =
                ProcessBundleDescriptors.fromExecutableStage(
                        "id", stage, Endpoints.ApiServiceDescriptor.getDefaultInstance());

        InMemoryBagUserStateFactory stateFactory = new InMemoryBagUserStateFactory();
        MatcherAssert.assertThat(stateFactory.handlers.size(), Matchers.is(0));

        StateRequestHandler handler =
                StateRequestHandlers.forBagUserStateHandlerFactory(descriptor, stateFactory);
        BeamFnApi.ProcessBundleRequest.CacheToken initialToken = assertSingleCacheToken(handler);

        sendGetRequest(handler, "state1");
        MatcherAssert.assertThat(stateFactory.handlers.size(), Matchers.is(1));
        MatcherAssert.assertThat(assertSingleCacheToken(handler), Matchers.is(initialToken));

        sendGetRequest(handler, "state2");
        MatcherAssert.assertThat(stateFactory.handlers.size(), Matchers.is(2));
        MatcherAssert.assertThat(assertSingleCacheToken(handler), Matchers.is(initialToken));
    }

    /**
     * Asserts that the handler exposes exactly one cache token, that the token value is non-null and
     * that its user state field equals the default instance.
     *
     * @param stateRequestHandler handler whose cache tokens are inspected
     * @return the single cache token reported by the handler
     */
    private static BeamFnApi.ProcessBundleRequest.CacheToken assertSingleCacheToken(StateRequestHandler stateRequestHandler) {
        Iterable<BeamFnApi.ProcessBundleRequest.CacheToken> cacheTokens =
                stateRequestHandler.getCacheTokens();
        MatcherAssert.assertThat(Iterables.size(cacheTokens), Matchers.is(1));

        BeamFnApi.ProcessBundleRequest.CacheToken token = Iterables.getOnlyElement(cacheTokens);
        MatcherAssert.assertThat(token.getToken(), Matchers.is(Matchers.notNullValue()));
        MatcherAssert.assertThat(
                token.getUserState(),
                Matchers.is(BeamFnApi.ProcessBundleRequest.CacheToken.UserState.getDefaultInstance()));
        return token;
    }

    /**
     * Sends a bag user state "get" request for key {@code "key"} in the global window of transform
     * {@code "transform"} and blocks until the handler's response future completes.
     *
     * @param stateRequestHandler handler that receives the request
     * @param str user state id to request
     * @throws Exception if encoding the window, handling the request or awaiting the result fails
     */
    private static void sendGetRequest(StateRequestHandler stateRequestHandler, String str) throws Exception {
        ByteString encodedWindow =
                ByteString.copyFrom(
                        CoderUtils.encodeToByteArray(GlobalWindow.Coder.INSTANCE, GlobalWindow.INSTANCE));
        BeamFnApi.StateKey stateKey =
                BeamFnApi.StateKey.newBuilder()
                        .setBagUserState(
                                BeamFnApi.StateKey.BagUserState.newBuilder()
                                        .setKey(ByteString.copyFromUtf8("key"))
                                        .setWindow(encodedWindow)
                                        .setTransformId("transform")
                                        .setUserStateId(str))
                        .build();
        BeamFnApi.StateRequest request =
                BeamFnApi.StateRequest.newBuilder()
                        .setGet(BeamFnApi.StateGetRequest.getDefaultInstance())
                        .setStateKey(stateKey)
                        .build();
        // Block on the returned future so the request is fully processed before returning.
        stateRequestHandler.handle(request).toCompletableFuture().get();
    }

    /**
     * Builds a minimal executable stage: one ParDo transform named {@code "transform"} reading the
     * PCollection {@code "input"}, which uses a KV coder and a windowing strategy whose window coder
     * is the global window coder, plus one user state entry per supplied name.
     *
     * @param strArr local names of the user states to register on {@code "transform"}
     * @return the executable stage created from the assembled payload
     */
    private static ExecutableStage buildExecutableStage(String... strArr) {
        RunnerApi.Components components =
                RunnerApi.Components.newBuilder()
                        .putWindowingStrategies(
                                "window",
                                RunnerApi.WindowingStrategy.newBuilder().setWindowCoderId("windowCoder").build())
                        .putPcollections(
                                "input",
                                RunnerApi.PCollection.newBuilder()
                                        .setWindowingStrategyId("window")
                                        .setCoderId("coder")
                                        .build())
                        .putCoders(
                                "windowCoder",
                                RunnerApi.Coder.newBuilder()
                                        .setSpec(
                                                RunnerApi.FunctionSpec.newBuilder()
                                                        .setUrn(ModelCoders.GLOBAL_WINDOW_CODER_URN)
                                                        .build())
                                        .build())
                        .putCoders(
                                "coder",
                                RunnerApi.Coder.newBuilder()
                                        .setSpec(
                                                RunnerApi.FunctionSpec.newBuilder()
                                                        .setUrn(ModelCoders.KV_CODER_URN)
                                                        .build())
                                        .addComponentCoderIds("keyCoder")
                                        .addComponentCoderIds("valueCoder")
                                        .build())
                        .putCoders("keyCoder", RunnerApi.Coder.getDefaultInstance())
                        .putCoders("valueCoder", RunnerApi.Coder.getDefaultInstance())
                        .putTransforms(
                                "transform",
                                RunnerApi.PTransform.newBuilder()
                                        .setSpec(
                                                RunnerApi.FunctionSpec.newBuilder()
                                                        .setUrn("beam:transform:pardo:v1")
                                                        .build())
                                        .putInputs("input", "input")
                                        .build())
                        .build();

        RunnerApi.ExecutableStagePayload.Builder payload =
                RunnerApi.ExecutableStagePayload.newBuilder()
                        .setInput("input")
                        .setComponents(components);
        for (String userStateName : strArr) {
            payload.addUserStates(
                    RunnerApi.ExecutableStagePayload.UserStateId.newBuilder()
                            .setTransformId("transform")
                            .setLocalName(userStateName)
                            .build());
        }
        return ExecutableStage.fromPayload(payload.build());
    }
}
