package org.apache.beam.fn.harness.state;

import com.google.auto.value.AutoValue;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.CompletableFuture;
import org.apache.beam.model.fnexecution.v1.BeamFnApi;
import org.apache.beam.vendor.grpc.v1p36p0.com.google.protobuf.ByteString;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.cache.LoadingCache;

/* loaded from: input_file:org/apache/beam/fn/harness/state/CachingBeamFnStateClient.class */
/**
 * A {@link BeamFnStateClient} decorator that answers {@code GET} requests from a per-{@link
 * BeamFnApi.StateKey} page cache whenever the runner supplied a cache token for that state. It also
 * keeps the cache consistent on {@code APPEND} and {@code CLEAR}.
 *
 * <p>Cache tokens come from the {@link BeamFnApi.ProcessBundleRequest.CacheToken}s passed to the
 * constructor:
 *
 * <ul>
 *   <li>a user-state token applies to all bag user state;
 *   <li>side-input tokens apply to the matching transform / side-input pair.
 * </ul>
 *
 * <p>Requests for state without a token, including runner state, are forwarded to the delegate
 * unchanged.
 */
public class CachingBeamFnStateClient implements BeamFnStateClient {
    /** Client that services every request the cache cannot answer by itself. */
    private final BeamFnStateClient beamFnStateClient;

    /** Cached GET pages per state key, addressed by (cache token, continuation token). */
    private final LoadingCache<BeamFnApi.StateKey, Map<StateCacheKey, BeamFnApi.StateGetResponse>> stateCache;

    /** Cache token for every side input the bundle request declared as cacheable. */
    private final Map<BeamFnApi.ProcessBundleRequest.CacheToken.SideInput, ByteString> sideInputCacheTokens =
            new HashMap<>();

    /** Cache token for bag user state, or {@link ByteString#EMPTY} when none was supplied. */
    private final ByteString userStateToken;

    /** Identifies a single cached page of a state key's GET response. */
    @AutoValue
    public static abstract class StateCacheKey {
        /** Cache token the page was fetched under. */
        public abstract ByteString getCacheToken();

        /** Continuation token of the GET request that produced the page. */
        public abstract ByteString getContinuationToken();

        static StateCacheKey create(ByteString cacheToken, ByteString continuationToken) {
            return new AutoValue_CachingBeamFnStateClient_StateCacheKey(cacheToken, continuationToken);
        }
    }

    /**
     * Creates a caching client.
     *
     * @param beamFnStateClient client used for cache misses, writes and uncacheable state
     * @param loadingCache shared page cache; entries are read via {@code getUnchecked} and replaced on
     *     {@code CLEAR}
     * @param list cache tokens from the bundle request. When several user-state tokens are present,
     *     the last one wins. Tokens that are neither user-state nor side-input tokens are ignored.
     */
    public CachingBeamFnStateClient(
            BeamFnStateClient beamFnStateClient,
            LoadingCache<BeamFnApi.StateKey, Map<StateCacheKey, BeamFnApi.StateGetResponse>> loadingCache,
            List<BeamFnApi.ProcessBundleRequest.CacheToken> list) {
        this.beamFnStateClient = beamFnStateClient;
        this.stateCache = loadingCache;
        ByteString userToken = ByteString.EMPTY;
        for (BeamFnApi.ProcessBundleRequest.CacheToken cacheToken : list) {
            if (cacheToken.hasUserState()) {
                userToken = cacheToken.getToken();
            } else if (cacheToken.hasSideInput()) {
                this.sideInputCacheTokens.put(cacheToken.getSideInput(), cacheToken.getToken());
            }
        }
        this.userStateToken = userToken;
    }

    /**
     * Handles a state request, consulting or updating the page cache when the state is cacheable.
     *
     * @throws IllegalStateException if the request is cacheable but not a GET, APPEND or CLEAR
     */
    @Override // org.apache.beam.fn.harness.state.BeamFnStateClient
    public void handle(BeamFnApi.StateRequest.Builder builder, CompletableFuture<BeamFnApi.StateResponse> completableFuture) {
        BeamFnApi.StateKey stateKey = builder.getStateKey();
        ByteString cacheToken = getCacheToken(stateKey);

        // No cache token means the runner did not mark this state as cacheable.
        if (ByteString.EMPTY.equals(cacheToken)) {
            this.beamFnStateClient.handle(builder, completableFuture);
            return;
        }

        switch (builder.getRequestCase()) {
            case GET:
                handleGet(builder, completableFuture, stateKey, cacheToken);
                return;
            case APPEND:
                handleAppend(builder, completableFuture, stateKey);
                return;
            case CLEAR:
                handleClear(builder, completableFuture, stateKey, cacheToken);
                return;
            default:
                throw new IllegalStateException(String.format("Unknown request type %s", builder.getRequestCase()));
        }
    }

    /** Serves a GET from the cache, or forwards it and caches the returned page. */
    private void handleGet(
            BeamFnApi.StateRequest.Builder builder,
            CompletableFuture<BeamFnApi.StateResponse> completableFuture,
            BeamFnApi.StateKey stateKey,
            ByteString cacheToken) {
        StateCacheKey pageKey = StateCacheKey.create(cacheToken, builder.getGet().getContinuationToken());
        BeamFnApi.StateGetResponse cachedPage = this.stateCache.getUnchecked(stateKey).get(pageKey);
        if (cachedPage != null) {
            completableFuture.complete(
                    BeamFnApi.StateResponse.newBuilder().setId(builder.getId()).setGet(cachedPage).build());
            return;
        }
        // Cache miss: record the page once the delegate answers. Responses without a GET payload
        // are not cached. getGet() would return the default instance (no data, no continuation
        // token), which later reads would mistake for a final, empty page.
        // NOTE(review): this callback runs on whichever thread completes the future; confirm the
        // per-key maps produced by the cache loader tolerate that.
        completableFuture.thenAccept(stateResponse -> {
            if (stateResponse.hasGet()) {
                this.stateCache.getUnchecked(stateKey).put(pageKey, stateResponse.getGet());
            }
        });
        this.beamFnStateClient.handle(builder, completableFuture);
    }

    /** Forwards an APPEND and invalidates the cached last page (empty continuation token). */
    private void handleAppend(
            BeamFnApi.StateRequest.Builder builder,
            CompletableFuture<BeamFnApi.StateResponse> completableFuture,
            BeamFnApi.StateKey stateKey) {
        this.beamFnStateClient.handle(builder, completableFuture);
        // Only the last cached page (the one with no continuation token) is dropped. Pages that
        // still point at a following page are kept.
        this.stateCache.getUnchecked(stateKey).entrySet()
                .removeIf(entry -> entry.getValue().getContinuationToken().equals(ByteString.EMPTY));
    }

    /** Forwards a CLEAR and replaces the key's cached pages with a single empty page. */
    private void handleClear(
            BeamFnApi.StateRequest.Builder builder,
            CompletableFuture<BeamFnApi.StateResponse> completableFuture,
            BeamFnApi.StateKey stateKey,
            ByteString cacheToken) {
        this.beamFnStateClient.handle(builder, completableFuture);
        Map<StateCacheKey, BeamFnApi.StateGetResponse> clearedPages = new HashMap<>();
        clearedPages.put(
                StateCacheKey.create(cacheToken, ByteString.EMPTY),
                BeamFnApi.StateGetResponse.getDefaultInstance());
        this.stateCache.put(stateKey, clearedPages);
    }

    /**
     * Returns the cache token for {@code stateKey}, or {@link ByteString#EMPTY} when it is not cacheable.
     *
     * <p>Bag user state uses the user-state token. Runner state is never cached. Side-input keys are
     * looked up by their transform id and side-input id.
     */
    private ByteString getCacheToken(BeamFnApi.StateKey stateKey) {
        if (stateKey.hasBagUserState()) {
            return this.userStateToken;
        }
        if (stateKey.hasRunner()) {
            return ByteString.EMPTY;
        }
        BeamFnApi.ProcessBundleRequest.CacheToken.SideInput.Builder sideInput =
                BeamFnApi.ProcessBundleRequest.CacheToken.SideInput.newBuilder();
        if (stateKey.hasIterableSideInput()) {
            BeamFnApi.StateKey.IterableSideInput iterable = stateKey.getIterableSideInput();
            sideInput.setTransformId(iterable.getTransformId()).setSideInputId(iterable.getSideInputId());
        } else if (stateKey.hasMultimapSideInput()) {
            BeamFnApi.StateKey.MultimapSideInput multimap = stateKey.getMultimapSideInput();
            sideInput.setTransformId(multimap.getTransformId()).setSideInputId(multimap.getSideInputId());
        } else if (stateKey.hasMultimapKeysSideInput()) {
            BeamFnApi.StateKey.MultimapKeysSideInput multimapKeys = stateKey.getMultimapKeysSideInput();
            sideInput.setTransformId(multimapKeys.getTransformId()).setSideInputId(multimapKeys.getSideInputId());
        }
        return this.sideInputCacheTokens.getOrDefault(sideInput.build(), ByteString.EMPTY);
    }
}
