package org.apache.beam.sdk.extensions.ml;

import com.google.privacy.dlp.v2.Table;
import java.util.ArrayList;
import java.util.List;
import java.util.Objects;
import java.util.concurrent.atomic.AtomicInteger;
import org.apache.beam.sdk.annotations.Experimental;
import org.apache.beam.sdk.metrics.Counter;
import org.apache.beam.sdk.metrics.Metrics;
import org.apache.beam.sdk.state.BagState;
import org.apache.beam.sdk.state.StateSpec;
import org.apache.beam.sdk.state.StateSpecs;
import org.apache.beam.sdk.state.TimeDomain;
import org.apache.beam.sdk.state.Timer;
import org.apache.beam.sdk.state.TimerSpec;
import org.apache.beam.sdk.state.TimerSpecs;
import org.apache.beam.sdk.transforms.DoFn;
import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
import org.apache.beam.sdk.values.KV;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Stateful {@link DoFn} that groups {@link Table.Row}s per key into batches bounded by their
 * total serialized size.
 *
 * <p>Each incoming element is appended to per-key bag state, and an event-time timer is set to the
 * end of the element's window. When the timer fires, the buffered rows are emitted in order as one
 * or more {@code KV<key, rows>} batches. A new batch starts whenever adding the next row would
 * push the running serialized size above {@code batchSizeBytes}. A single row that is larger than
 * the limit on its own is emitted as a one-row batch rather than being dropped.
 */
@Experimental
class BatchRequestForDLP extends DoFn<KV<String, Table.Row>, KV<String, Iterable<Table.Row>>> {
    public static final Logger LOG = LoggerFactory.getLogger(BatchRequestForDLP.class);

    /** Upper bound, in serialized bytes, on the rows accumulated into a single emitted batch. */
    private final Integer batchSizeBytes;

    /** Counts every row emitted as part of some batch. */
    private final Counter numberOfRowsBagged = Metrics.counter(BatchRequestForDLP.class, "numberOfRowsBagged");

    @DoFn.StateId("elementsBag")
    private final StateSpec<BagState<KV<String, Table.Row>>> elementsBag = StateSpecs.bag();

    @DoFn.TimerId("eventTimer")
    private final TimerSpec eventTimer = TimerSpecs.timer(TimeDomain.EVENT_TIME);

    /**
     * Creates the batching DoFn.
     *
     * @param num maximum total serialized size, in bytes, of the rows in one batch
     * @throws NullPointerException if {@code num} is null (fail at construction rather than when
     *     the timer first fires)
     */
    public BatchRequestForDLP(Integer num) {
        this.batchSizeBytes = Objects.requireNonNull(num, "batchSizeBytes");
    }

    /**
     * Buffers the element in bag state and (re)arms the event-time timer for the end of the
     * element's window.
     */
    @DoFn.ProcessElement
    public void process(
            @DoFn.Element KV<String, Table.Row> kv,
            @DoFn.StateId("elementsBag") BagState<KV<String, Table.Row>> bagState,
            @DoFn.TimerId("eventTimer") Timer timer,
            BoundedWindow boundedWindow) {
        bagState.add(kv);
        timer.set(boundedWindow.maxTimestamp());
    }

    /**
     * Flushes the buffered rows as size-bounded batches, then clears the buffer.
     *
     * <p>Each batch is a freshly allocated list that is never touched after it is output. Beam
     * forbids mutating an element after output, and the original code reused and cleared one list.
     */
    @DoFn.OnTimer("eventTimer")
    public void onTimer(
            @DoFn.StateId("elementsBag") BagState<KV<String, Table.Row>> bagState,
            DoFn.OutputReceiver<KV<String, Iterable<Table.Row>>> outputReceiver) {
        List<Table.Row> batch = new ArrayList<>();
        String key = null;
        // long so that adding one more row's size to the running total cannot overflow.
        long batchBytes = 0L;

        for (KV<String, Table.Row> element : bagState.read()) {
            key = element.getKey();
            Table.Row row = element.getValue();
            int rowBytes = row.getSerializedSize();
            // Only flush a non-empty batch: an oversized first row must not yield an empty output.
            if (!batch.isEmpty() && batchBytes + rowBytes > batchSizeBytes) {
                LOG.debug("Clear buffer of {} bytes, Key {}", batchBytes, key);
                emitBatch(key, batch, outputReceiver);
                batch = new ArrayList<>();
                batchBytes = 0L;
            }
            batch.add(row);
            batchBytes += rowBytes;
        }

        if (!batch.isEmpty()) {
            LOG.debug("Outputting remaining {} rows.", batch.size());
            emitBatch(key, batch, outputReceiver);
        }
        // Drop emitted rows so a later firing of this timer (e.g. triggered by late data) does not
        // re-emit them.
        bagState.clear();
    }

    /** Records the batch size in the metric and outputs the batch under {@code key}. */
    private void emitBatch(
            String key,
            List<Table.Row> batch,
            DoFn.OutputReceiver<KV<String, Iterable<Table.Row>>> outputReceiver) {
        numberOfRowsBagged.inc(batch.size());
        outputReceiver.output(KV.of(key, batch));
    }
}
