package org.apache.beam.sdk.transforms;

import javax.annotation.Nullable;
import org.apache.beam.sdk.coders.Coder;
import org.apache.beam.sdk.coders.KvCoder;
import org.apache.beam.sdk.io.range.OffsetRangeTracker;
import org.apache.beam.sdk.state.BagState;
import org.apache.beam.sdk.state.CombiningState;
import org.apache.beam.sdk.state.StateSpec;
import org.apache.beam.sdk.state.StateSpecs;
import org.apache.beam.sdk.state.TimeDomain;
import org.apache.beam.sdk.state.Timer;
import org.apache.beam.sdk.state.TimerSpec;
import org.apache.beam.sdk.state.TimerSpecs;
import org.apache.beam.sdk.state.ValueState;
import org.apache.beam.sdk.transforms.Combine;
import org.apache.beam.sdk.transforms.DoFn;
import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
import org.apache.beam.sdk.values.KV;
import org.apache.beam.sdk.values.PCollection;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.annotations.VisibleForTesting;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Preconditions;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Iterables;
import org.joda.time.Duration;
import org.joda.time.Instant;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/apache/beam/sdk/transforms/GroupIntoBatches.class */
public class GroupIntoBatches<K, InputT> extends PTransform<PCollection<KV<K, InputT>>, PCollection<KV<K, Iterable<InputT>>>> {
    private final long batchSize;

    @Nullable
    private final Duration maxBufferingDuration;

    /* JADX INFO: Access modifiers changed from: private */
    /**
     * Stateful {@link DoFn} that buffers incoming values and emits them as batches.
     *
     * <p>Each element's value is appended to a bag state and a combining state counts how many
     * values are currently buffered. The buffered values are emitted as a single
     * {@code KV<K, Iterable<InputT>>} when one of the following happens:
     *
     * <ul>
     *   <li>the count reaches {@code batchSize};
     *   <li>the processing-time buffering timer fires (only armed when a max buffering duration
     *       is configured);
     *   <li>the event-time timer set at {@code window.maxTimestamp() + allowedLateness} fires.
     * </ul>
     *
     * <p>NOTE(review): state and timers are presumably scoped per key and window by the runner;
     * that is not visible in this file.
     */
    @VisibleForTesting
    /* loaded from: input_file:org/apache/beam/sdk/transforms/GroupIntoBatches$GroupIntoBatchesDoFn.class */
    public static class GroupIntoBatchesDoFn<K, InputT> extends DoFn<KV<K, InputT>, KV<K, Iterable<InputT>>> {
        private static final Logger LOG = LoggerFactory.getLogger(GroupIntoBatchesDoFn.class);

        // Identifiers linking the state/timer specs declared below to the annotated handler
        // parameters. They are runtime identifiers, so the unusual "endOFWindow" spelling is kept.
        private static final String END_OF_WINDOW_ID = "endOFWindow";
        private static final String END_OF_BUFFERING_ID = "endOfBuffering";
        private static final String BATCH_ID = "batch";
        private static final String NUM_ELEMENTS_IN_BATCH_ID = "numElementsInBatch";
        private static final String KEY_ID = "key";

        /** Number of buffered elements at which a batch is emitted. */
        private final long batchSize;

        /** Added to the window's max timestamp to compute when the end-of-window timer fires. */
        private final Duration allowedLateness;

        /** Optional processing-time limit on buffering; {@code null} disables the buffering timer. */
        @Nullable
        private final Duration maxBufferingDuration;

        /** Bag holding the values buffered for the current batch. */
        @DoFn.StateId(BATCH_ID)
        private final StateSpec<BagState<InputT>> batchSpec;

        /** Last key written by {@code processElement}; used as the key of the emitted batch. */
        @DoFn.StateId(KEY_ID)
        private final StateSpec<ValueState<K>> keySpec;

        /** The batch bag is prefetched whenever the element count is a multiple of this value. */
        private final long prefetchFrequency;

        @DoFn.TimerId(END_OF_WINDOW_ID)
        private final TimerSpec windowTimer = TimerSpecs.timer(TimeDomain.EVENT_TIME);

        @DoFn.TimerId(END_OF_BUFFERING_ID)
        private final TimerSpec bufferingTimer = TimerSpecs.timer(TimeDomain.PROCESSING_TIME);

        /** Running count of buffered elements, maintained by summing the values added to it. */
        @DoFn.StateId(NUM_ELEMENTS_IN_BATCH_ID)
        private final StateSpec<CombiningState<Long, long[], Long>> numElementsInBatchSpec = StateSpecs.combining(new Combine.BinaryCombineLongFn() {
            @Override
            public long identity() {
                return 0L;
            }

            @Override
            public long apply(long left, long right) {
                return left + right;
            }
        });

        /**
         * @param batchSize element count at which a batch is flushed
         * @param allowedLateness added to the window's max timestamp for the end-of-window timer
         * @param maxBufferingDuration processing-time buffering limit, or {@code null} for none
         * @param keyCoder coder used for the key value state
         * @param valueCoder coder used for the batch bag state
         */
        GroupIntoBatchesDoFn(long batchSize, Duration allowedLateness, @Nullable Duration maxBufferingDuration, Coder<K> keyCoder, Coder<InputT> valueCoder) {
            this.batchSize = batchSize;
            this.allowedLateness = allowedLateness;
            this.maxBufferingDuration = maxBufferingDuration;
            this.batchSpec = StateSpecs.bag(valueCoder);
            this.keySpec = StateSpecs.value(keyCoder);
            // Prefetch every fifth of a batch. When a fifth is <= 1 element, use Long.MAX_VALUE so
            // the modulo test in processElement effectively never triggers a prefetch.
            long fifthOfBatch = batchSize / 5;
            this.prefetchFrequency = fifthOfBatch <= 1 ? Long.MAX_VALUE : fifthOfBatch;
        }

        /**
         * Buffers one element and flushes the batch once {@code batchSize} elements are buffered.
         *
         * <p>On every element the end-of-window timer is (re)set to
         * {@code window.maxTimestamp() + allowedLateness} and the element's key is written to the
         * key state. When the element is the first of a batch and a max buffering duration is
         * configured, the buffering timer is armed relative to the current processing time.
         */
        @DoFn.ProcessElement
        public void processElement(
                @DoFn.TimerId(END_OF_WINDOW_ID) Timer endOfWindowTimer,
                @DoFn.TimerId(END_OF_BUFFERING_ID) Timer endOfBufferingTimer,
                @DoFn.StateId(BATCH_ID) BagState<InputT> batch,
                @DoFn.StateId(NUM_ELEMENTS_IN_BATCH_ID) CombiningState<Long, long[], Long> numElementsInBatch,
                @DoFn.StateId(KEY_ID) ValueState<K> key,
                @DoFn.Element KV<K, InputT> element,
                BoundedWindow window,
                DoFn.OutputReceiver<KV<K, Iterable<InputT>>> receiver) {
            Instant windowExpiry = window.maxTimestamp().plus(allowedLateness);
            LOG.debug("*** SET TIMER *** to point in time {} for window {}", windowExpiry, window);
            endOfWindowTimer.set(windowExpiry);
            key.write(element.getKey());
            LOG.debug("*** BATCH *** Add element for window {} ", window);
            batch.add(element.getValue());
            numElementsInBatch.add(1L);

            long bufferedCount = numElementsInBatch.read();
            if (bufferedCount == 1 && maxBufferingDuration != null) {
                // First element of a new batch: start the buffering clock.
                endOfBufferingTimer.offset(maxBufferingDuration).setRelative();
            }
            if (bufferedCount % prefetchFrequency == 0) {
                // Hint that the bag will be read soon.
                batch.readLater();
            }
            if (bufferedCount >= batchSize) {
                // The window is passed as a placeholder argument so toString() is only evaluated
                // when debug logging is enabled.
                LOG.debug("*** END OF BATCH *** for window {}", window);
                flushBatch(receiver, key, batch, numElementsInBatch, endOfBufferingTimer);
            }
        }

        /**
         * Flushes whatever is buffered when the buffering timer fires. The timer is not re-armed
         * here ({@code null} is passed to {@code flushBatch}).
         */
        @DoFn.OnTimer(END_OF_BUFFERING_ID)
        public void onBufferingTimer(
                DoFn.OutputReceiver<KV<K, Iterable<InputT>>> receiver,
                @DoFn.Timestamp Instant timestamp,
                @DoFn.StateId(KEY_ID) ValueState<K> key,
                @DoFn.StateId(BATCH_ID) BagState<InputT> batch,
                @DoFn.StateId(NUM_ELEMENTS_IN_BATCH_ID) CombiningState<Long, long[], Long> numElementsInBatch,
                @DoFn.TimerId(END_OF_BUFFERING_ID) Timer endOfBufferingTimer) {
            LOG.debug("*** END OF BUFFERING *** for timer timestamp {} with buffering duration {}", timestamp, maxBufferingDuration);
            flushBatch(receiver, key, batch, numElementsInBatch, null);
        }

        /**
         * Flushes whatever is buffered when the end-of-window timer fires, handing the buffering
         * timer to {@code flushBatch}.
         */
        @DoFn.OnTimer(END_OF_WINDOW_ID)
        public void onWindowTimer(
                DoFn.OutputReceiver<KV<K, Iterable<InputT>>> receiver,
                @DoFn.Timestamp Instant timestamp,
                @DoFn.StateId(KEY_ID) ValueState<K> key,
                @DoFn.StateId(BATCH_ID) BagState<InputT> batch,
                @DoFn.StateId(NUM_ELEMENTS_IN_BATCH_ID) CombiningState<Long, long[], Long> numElementsInBatch,
                @DoFn.TimerId(END_OF_BUFFERING_ID) Timer endOfBufferingTimer,
                BoundedWindow window) {
            LOG.debug("*** END OF WINDOW *** for timer timestamp {} in windows {}", timestamp, window);
            flushBatch(receiver, key, batch, numElementsInBatch, endOfBufferingTimer);
        }

        /**
         * Emits the buffered values (if any) under the stored key, then clears the bag and the
         * element counter.
         *
         * @param endOfBufferingTimer when non-null and a max buffering duration is configured, the
         *     timer is re-armed relative to now after the flush; pass {@code null} to leave it alone
         */
        private void flushBatch(
                DoFn.OutputReceiver<KV<K, Iterable<InputT>>> receiver,
                ValueState<K> key,
                BagState<InputT> batch,
                CombiningState<Long, long[], Long> numElementsInBatch,
                @Nullable Timer endOfBufferingTimer) {
            Iterable<InputT> bufferedValues = batch.read();
            // A timer may fire with nothing buffered; emit nothing in that case.
            if (!Iterables.isEmpty(bufferedValues)) {
                receiver.output(KV.of(key.read(), bufferedValues));
            }
            batch.clear();
            LOG.debug("*** BATCH *** clear");
            numElementsInBatch.clear();
            if (endOfBufferingTimer != null && maxBufferingDuration != null) {
                endOfBufferingTimer.offset(maxBufferingDuration).setRelative();
            }
        }
    }

    /**
     * Stores the batching configuration.
     *
     * @param batchSize value later returned by {@link #getBatchSize()}
     * @param maxBufferingDuration buffering limit, or {@code null} when none has been configured
     */
    private GroupIntoBatches(long batchSize, @Nullable Duration maxBufferingDuration) {
        this.maxBufferingDuration = maxBufferingDuration;
        this.batchSize = batchSize;
    }

    /**
     * Creates a transform with the given batch size and no max buffering duration.
     *
     * @param batchSize the batch size to configure
     * @return a new {@code GroupIntoBatches} whose max buffering duration is {@code null}
     */
    public static <K, InputT> GroupIntoBatches<K, InputT> ofSize(long batchSize) {
        final Duration noBufferingLimit = null;
        return new GroupIntoBatches<>(batchSize, noBufferingLimit);
    }

    /**
     * @return the batch size this transform was created with
     */
    public long getBatchSize() {
        final long configuredSize = batchSize;
        return configuredSize;
    }

    /**
     * Returns a copy of this transform, with the same batch size, that uses the given max
     * buffering duration.
     *
     * @param duration the buffering limit; must be non-null and longer than {@link Duration#ZERO}
     * @return a new {@code GroupIntoBatches}; this instance is not modified
     * @throws IllegalArgumentException if {@code duration} is null or not longer than zero
     */
    public GroupIntoBatches<K, InputT> withMaxBufferingDuration(Duration duration) {
        // The null check is part of the precondition so a missing duration is reported with the
        // same descriptive argument error as a non-positive one, rather than a bare NPE.
        Preconditions.checkArgument(duration != null && duration.isLongerThan(Duration.ZERO), "max buffering duration should be a positive value");
        return new GroupIntoBatches<>(this.batchSize, duration);
    }

    /**
     * Applies a {@link GroupIntoBatchesDoFn} to the input.
     *
     * <p>The allowed lateness is read from the input's windowing strategy, and the key and value
     * coders are taken from the input's {@link KvCoder}.
     *
     * @param input the keyed input collection; its coder must be a {@link KvCoder}
     * @return the result of applying the batching {@code ParDo} to {@code input}
     * @throws IllegalArgumentException if the input's coder is not a {@link KvCoder}
     */
    @Override // org.apache.beam.sdk.transforms.PTransform
    public PCollection<KV<K, Iterable<InputT>>> expand(PCollection<KV<K, InputT>> input) {
        Duration allowedLateness = input.getWindowingStrategy().getAllowedLateness();
        Coder<KV<K, InputT>> inputCoder = input.getCoder();
        Preconditions.checkArgument(inputCoder instanceof KvCoder, "coder specified in the input PCollection is not a KvCoder");
        // Keep the coder parameterized and use its typed accessors so the component coders reach
        // the DoFn as Coder<K> / Coder<InputT> without raw types or unchecked conversions.
        KvCoder<K, InputT> kvCoder = (KvCoder<K, InputT>) inputCoder;
        Coder<K> keyCoder = kvCoder.getKeyCoder();
        Coder<InputT> valueCoder = kvCoder.getValueCoder();
        return input.apply(ParDo.of(new GroupIntoBatchesDoFn<>(this.batchSize, allowedLateness, this.maxBufferingDuration, keyCoder, valueCoder)));
    }
}
