package org.apache.paimon.flink.sink;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import javax.annotation.Nullable;
import org.apache.flink.api.common.state.ListState;
import org.apache.flink.api.common.state.ListStateDescriptor;
import org.apache.flink.api.common.typeutils.TypeSerializer;
import org.apache.flink.api.common.typeutils.base.IntSerializer;
import org.apache.flink.api.common.typeutils.base.StringSerializer;
import org.apache.flink.api.common.typeutils.base.array.BytePrimitiveArraySerializer;
import org.apache.flink.api.java.tuple.Tuple5;
import org.apache.flink.api.java.typeutils.runtime.TupleSerializer;
import org.apache.flink.runtime.state.StateInitializationContext;
import org.apache.paimon.data.BinaryRow;
import org.apache.paimon.utils.SerializationUtils;

/* loaded from: input_file:org/apache/paimon/flink/sink/StoreSinkWriteState.class */
/**
 * In-memory view of the sink writer's Flink operator state.
 *
 * <p>Values are grouped under two string keys and stored as {@link StateValue}s (partition,
 * bucket, opaque bytes). On construction, the entries are restored from a Flink union list state
 * and only those accepted by the supplied {@link StateValueFilter} are kept. {@link
 * #snapshotState()} writes the current in-memory contents back into that list state.
 *
 * <p>NOTE(review): no synchronization is performed; presumably this is only used from the owning
 * operator's thread — confirm against callers.
 */
public class StoreSinkWriteState {

    /** Name under which the union list state is registered with the operator state store. */
    private static final String STATE_NAME = "paimon_store_sink_write_state";

    private final StateValueFilter stateValueFilter;

    /** Backing Flink state; each row is (key1, key2, serialized partition, bucket, value). */
    private final ListState<Tuple5<String, String, byte[], Integer, byte[]>> listState;

    /** In-memory contents: key1 -> key2 -> values. */
    private final Map<String, Map<String, List<StateValue>>> map = new HashMap<>();

    /**
     * Registers (or restores) the union list state and loads the entries accepted by {@code
     * stateValueFilter} into memory.
     *
     * @param context Flink state initialization context providing the operator state store
     * @param stateValueFilter decides, per (key1, partition, bucket), whether a restored entry is
     *     kept
     * @throws Exception if the state cannot be registered or read
     */
    @SuppressWarnings("unchecked")
    public StoreSinkWriteState(
            StateInitializationContext context, StateValueFilter stateValueFilter)
            throws Exception {
        this.stateValueFilter = stateValueFilter;

        // Tuple5.class is a raw Class literal; the double cast is the standard way to obtain a
        // properly parameterized Class for the generic TupleSerializer constructor.
        TupleSerializer<Tuple5<String, String, byte[], Integer, byte[]>> serializer =
                new TupleSerializer<>(
                        (Class<Tuple5<String, String, byte[], Integer, byte[]>>)
                                (Class<?>) Tuple5.class,
                        new TypeSerializer<?>[] {
                            StringSerializer.INSTANCE,
                            StringSerializer.INSTANCE,
                            BytePrimitiveArraySerializer.INSTANCE,
                            IntSerializer.INSTANCE,
                            BytePrimitiveArraySerializer.INSTANCE
                        });
        this.listState =
                context.getOperatorStateStore()
                        .getUnionListState(new ListStateDescriptor<>(STATE_NAME, serializer));

        for (Tuple5<String, String, byte[], Integer, byte[]> row : listState.get()) {
            BinaryRow partition = SerializationUtils.deserializeBinaryRow(row.f2);
            int bucket = row.f3;
            // Union list state hands every restored entry to this instance; the filter selects
            // the ones this instance should keep.
            if (stateValueFilter.filter(row.f0, partition, bucket)) {
                map.computeIfAbsent(row.f0, k -> new HashMap<>())
                        .computeIfAbsent(row.f1, k -> new ArrayList<>())
                        .add(new StateValue(partition, bucket, row.f4));
            }
        }
    }

    /** Returns the filter that was used to select restored entries. */
    public StateValueFilter stateValueFilter() {
        return stateValueFilter;
    }

    /**
     * Returns the values stored under the given key pair.
     *
     * @param key1 first-level key
     * @param key2 second-level key
     * @return the stored list itself (not a copy), or {@code null} if nothing is stored for the
     *     key pair
     */
    @Nullable
    public List<StateValue> get(String key1, String key2) {
        Map<String, List<StateValue>> inner = map.get(key1);
        return inner == null ? null : inner.get(key2);
    }

    /**
     * Stores {@code stateValues} under the given key pair, replacing any previous list. The list
     * is kept by reference and is read again in {@link #snapshotState()}.
     *
     * @param key1 first-level key
     * @param key2 second-level key
     * @param stateValues values to store
     */
    public void put(String key1, String key2, List<StateValue> stateValues) {
        map.computeIfAbsent(key1, k -> new HashMap<>()).put(key2, stateValues);
    }

    /**
     * Replaces the content of the backing list state with the current in-memory contents.
     *
     * @throws Exception if the list state cannot be updated
     */
    public void snapshotState() throws Exception {
        List<Tuple5<String, String, byte[], Integer, byte[]>> rows = new ArrayList<>();
        for (Map.Entry<String, Map<String, List<StateValue>>> outer : map.entrySet()) {
            for (Map.Entry<String, List<StateValue>> inner : outer.getValue().entrySet()) {
                for (StateValue stateValue : inner.getValue()) {
                    rows.add(
                            Tuple5.of(
                                    outer.getKey(),
                                    inner.getKey(),
                                    SerializationUtils.serializeBinaryRow(stateValue.partition()),
                                    stateValue.bucket(),
                                    stateValue.value()));
                }
            }
        }
        listState.update(rows);
    }

    /** A single state entry: a partition, a bucket within it, and an opaque serialized value. */
    public static class StateValue {

        private final BinaryRow partition;
        private final int bucket;
        private final byte[] value;

        /**
         * @param partition partition the value belongs to
         * @param bucket bucket within the partition
         * @param value opaque payload; stored by reference, not copied
         */
        public StateValue(BinaryRow partition, int bucket, byte[] value) {
            this.partition = partition;
            this.bucket = bucket;
            this.value = value;
        }

        public BinaryRow partition() {
            return partition;
        }

        public int bucket() {
            return bucket;
        }

        /** Returns the payload array itself (not a copy). */
        public byte[] value() {
            return value;
        }
    }

    /** Decides which restored state entries are kept by this instance. */
    @FunctionalInterface
    public interface StateValueFilter {

        /**
         * @param key1 first-level key of the entry
         * @param partition partition of the entry
         * @param bucket bucket of the entry
         * @return {@code true} to keep the entry
         */
        boolean filter(String key1, BinaryRow partition, int bucket);
    }
}
