package org.apache.flink.statefun.flink.core.reqreply;

import com.google.protobuf.MoreByteStrings;
import java.time.Duration;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import org.apache.flink.statefun.flink.core.types.remote.RemoteValueTypeMismatchException;
import org.apache.flink.statefun.sdk.TypeName;
import org.apache.flink.statefun.sdk.annotations.Persisted;
import org.apache.flink.statefun.sdk.reqreply.generated.FromFunction;
import org.apache.flink.statefun.sdk.reqreply.generated.ToFunction;
import org.apache.flink.statefun.sdk.reqreply.generated.TypedValue;
import org.apache.flink.statefun.sdk.state.Expiration;
import org.apache.flink.statefun.sdk.state.PersistedStateRegistry;
import org.apache.flink.statefun.sdk.state.RemotePersistedValue;

/* loaded from: input_file:org/apache/flink/statefun/flink/core/reqreply/PersistedRemoteFunctionValues.class */
public final class PersistedRemoteFunctionValues {

    /** Type used for states whose declared typename is the empty string. */
    private static final TypeName UNSET_STATE_TYPE = TypeName.parseFrom("io.statefun.types/unset");

    /** Registry that every managed state handle is registered with before it is tracked. */
    @Persisted
    private final PersistedStateRegistry stateRegistry = new PersistedStateRegistry();

    /**
     * State handles known to this instance, keyed by state name. A handle is only added here after
     * it was successfully registered with {@link #stateRegistry}.
     */
    private final Map<String, RemotePersistedValue> managedStates = new HashMap<>();

    /** Raised when a state cannot be registered or its declared type does not match. */
    public static class RemoteFunctionStateException extends RuntimeException {
        private static final long serialVersionUID = 1L;

        private RemoteFunctionStateException(String stateName, Throwable cause) {
            super("An error occurred for state [" + stateName + "].", cause);
        }
    }

    /**
     * Adds one {@code PersistedValue} entry per managed state to the given request builder.
     *
     * <p>Every managed state is attached by name. A {@code TypedValue} is attached only when the
     * state currently holds bytes ({@code get()} returned non-null); it carries the state's
     * registered type and {@code hasValue=true}.
     *
     * @param builder the invocation batch request being assembled
     */
    public void attachStateValues(ToFunction.InvocationBatchRequest.Builder builder) {
        for (Map.Entry<String, RemotePersistedValue> entry : managedStates.entrySet()) {
            final ToFunction.PersistedValue.Builder protocolState =
                    ToFunction.PersistedValue.newBuilder().setStateName(entry.getKey());
            final RemotePersistedValue state = entry.getValue();

            final byte[] stateBytes = state.get();
            if (stateBytes != null) {
                protocolState.setStateValue(
                        TypedValue.newBuilder()
                                .setTypename(state.type().canonicalTypenameString())
                                .setHasValue(true)
                                .setValue(MoreByteStrings.wrap(stateBytes))
                                .build());
            }
            builder.addState(protocolState);
        }
    }

    /**
     * Applies the state mutations returned by a remote function.
     *
     * <p>{@code DELETE} clears the named state. {@code MODIFY} validates the incoming typename
     * against the registered type and then overwrites the state's bytes.
     *
     * @param mutations mutations to apply, in order
     * @throws IllegalStateException if a mutation targets an unknown state or has an unknown or
     *     unrecognized mutation type
     * @throws RemoteFunctionStateException if a {@code MODIFY} carries a mismatching type
     */
    public void updateStateValues(List<FromFunction.PersistedValueMutation> mutations) {
        for (FromFunction.PersistedValueMutation mutation : mutations) {
            final String stateName = mutation.getStateName();
            switch (mutation.getMutationType()) {
                case DELETE:
                    getStateHandleOrThrow(stateName).clear();
                    break;
                case MODIFY:
                    {
                        final RemotePersistedValue state = getStateHandleOrThrow(stateName);
                        final TypedValue newValue = mutation.getStateValue();
                        validateType(state, newValue.getTypename());
                        state.set(newValue.getValue().toByteArray());
                        break;
                    }
                case UNRECOGNIZED:
                    throw new IllegalStateException("Received an UNRECOGNIZED PersistedValueMutation type. This may be caused by a mismatch or incompatibility with the remote function SDK version and the Stateful Functions version.");
                default:
                    throw new IllegalStateException("Unexpected value: " + mutation.getMutationType());
            }
        }
    }

    /**
     * Registers every given state spec that is not yet managed. Already-managed states are only
     * checked for a matching type.
     *
     * @param specs state specifications declared by the remote function
     * @throws RemoteFunctionStateException on registration failure or type mismatch
     */
    public void registerStates(List<FromFunction.PersistedValueSpec> specs) {
        specs.forEach(this::createAndRegisterValueStateIfAbsent);
    }

    /** Registers {@code spec} if its name is unknown, otherwise validates its declared type. */
    private void createAndRegisterValueStateIfAbsent(FromFunction.PersistedValueSpec spec) {
        final RemotePersistedValue existing = managedStates.get(spec.getStateName());
        if (existing == null) {
            registerValueState(spec);
        } else {
            validateType(existing, spec.getTypeTypename());
        }
    }

    /**
     * Creates a state handle for {@code spec}, registers it, and then starts tracking it.
     *
     * @throws RemoteFunctionStateException if the registry reports a type mismatch
     */
    private void registerValueState(FromFunction.PersistedValueSpec spec) {
        final String stateName = spec.getStateName();
        final RemotePersistedValue state =
                RemotePersistedValue.of(
                        stateName,
                        sdkStateType(spec.getTypeTypename()),
                        sdkTtlExpiration(spec.getExpirationSpec()));
        try {
            stateRegistry.registerRemoteValue(state);
        } catch (RemoteValueTypeMismatchException e) {
            throw new RemoteFunctionStateException(stateName, e);
        }
        // Track the handle only after registration succeeded. Otherwise a failed registration
        // would leave a handle in managedStates that later lookups treat as registered.
        managedStates.put(stateName, state);
    }

    /**
     * Ensures {@code typename} resolves to the same type the state was registered with.
     *
     * @throws RemoteFunctionStateException wrapping a {@link RemoteValueTypeMismatchException}
     */
    private void validateType(RemotePersistedValue state, String typename) {
        final TypeName actualType = sdkStateType(typename);
        if (!actualType.equals(state.type())) {
            throw new RemoteFunctionStateException(
                    state.name(), new RemoteValueTypeMismatchException(state.type(), actualType));
        }
    }

    /** Parses a protocol typename; an empty string maps to {@link #UNSET_STATE_TYPE}. */
    private static TypeName sdkStateType(String typename) {
        return typename.isEmpty() ? UNSET_STATE_TYPE : TypeName.parseFrom(typename);
    }

    /**
     * Translates a protocol expiration spec into an SDK {@link Expiration}. Modes other than
     * {@code AFTER_INVOKE} and {@code AFTER_WRITE} yield {@link Expiration#none()}.
     */
    private static Expiration sdkTtlExpiration(FromFunction.ExpirationSpec expirationSpec) {
        final Duration ttl = Duration.ofMillis(expirationSpec.getExpireAfterMillis());
        switch (expirationSpec.getMode()) {
            case AFTER_INVOKE:
                return Expiration.expireAfterReadingOrWriting(ttl);
            case AFTER_WRITE:
                return Expiration.expireAfterWriting(ttl);
            case NONE:
            default:
                return Expiration.none();
        }
    }

    /**
     * Returns the managed handle for {@code stateName}.
     *
     * @throws IllegalStateException if no state with that name has been registered
     */
    private RemotePersistedValue getStateHandleOrThrow(String stateName) {
        final RemotePersistedValue state = managedStates.get(stateName);
        if (state == null) {
            throw new IllegalStateException("Accessing a non-existing function state: " + stateName + ". This can happen if you forgot to declare this state using the language SDKs.");
        }
        return state;
    }
}
