package org.apache.beam.runners.direct;

import java.util.concurrent.ThreadLocalRandom;
import org.apache.beam.runners.direct.repackaged.com.google.common.base.Preconditions;
import org.apache.beam.sdk.transforms.PTransform;
import org.apache.beam.sdk.transforms.Partition;
import org.apache.beam.sdk.values.PCollection;
import org.apache.beam.sdk.values.PCollectionList;
import org.apache.beam.sdk.values.PDone;

/* loaded from: input_file:org/apache/beam/runners/direct/ShardControlledWrite.class */
/**
 * A write transform that splits its input into a fixed number of shards and applies a
 * per-shard write transform to each one.
 *
 * <p>Subclasses supply the shard count via {@link #getNumShards()} and the transform used to
 * write one shard via {@link #getSingleShardTransform(int)}.
 *
 * @param <InputT> the element type of the input collection
 */
abstract class ShardControlledWrite<InputT>
        extends ForwardingPTransform<PCollection<InputT>, PDone> {

    /**
     * Assigns elements to partitions in round-robin order, starting from a random partition.
     *
     * <p>The starting offset is chosen lazily on the first call to {@link #partitionFor}. It is
     * drawn uniformly from {@code [0, numPartitions)}. The first element then goes to the
     * partition after that offset, and each later element goes to the next partition, wrapping
     * at {@code numPartitions}.
     *
     * <p>The counter is per-instance mutable state. NOTE(review): presumably each deserialized
     * copy of this fn keeps its own counter, so the distribution is round-robin per copy rather
     * than globally. Confirm this against the runner's fn lifecycle.
     *
     * @param <T> the element type being partitioned
     */
    public static class RandomSeedPartitionFn<T> implements Partition.PartitionFn<T> {
        /** Index of the last partition returned; -1 until the first element is seen. */
        private int nextPartition;

        private RandomSeedPartitionFn() {
            this.nextPartition = -1;
        }

        /**
         * Returns the partition for {@code elem}, advancing the round-robin counter.
         *
         * @param elem the element being assigned (its value is not inspected)
         * @param numPartitions the total number of partitions; must be positive
         * @return a partition index in {@code [0, numPartitions)}
         */
        @Override
        public int partitionFor(T elem, int numPartitions) {
            if (nextPartition < 0) {
                // Random start so that several small writes don't all pile onto shard 0.
                nextPartition = ThreadLocalRandom.current().nextInt(numPartitions);
            }
            // The modulo keeps the counter in range, so it can never overflow.
            nextPartition = (nextPartition + 1) % numPartitions;
            return nextPartition;
        }
    }

    /**
     * Partitions {@code input} into {@link #getNumShards()} shards and applies
     * {@link #getSingleShardTransform(int)} to each shard.
     *
     * @param input the collection to write
     * @return a {@link PDone} for the input's pipeline
     * @throws IllegalArgumentException if {@link #getNumShards()} returns a value less than 1
     */
    @Override // org.apache.beam.runners.direct.ForwardingPTransform
    public PDone apply(PCollection<InputT> input) {
        // Read the shard count once so the check, the transform name and the partitioning agree.
        final int numShards = getNumShards();
        Preconditions.checkArgument(
                numShards >= 1,
                "%s should only be applied if the output has a controlled number of shards (>= 1); got %s",
                getClass().getSimpleName(),
                numShards);

        PCollectionList<InputT> shards =
                input.apply(
                        "PartitionInto" + numShards + "Shards",
                        Partition.of(numShards, new RandomSeedPartitionFn<InputT>()));

        for (int shardIndex = 0; shardIndex < shards.size(); shardIndex++) {
            PCollection<InputT> shard = shards.get(shardIndex);
            PTransform<? super PCollection<InputT>, PDone> singleShardTransform =
                    getSingleShardTransform(shardIndex);
            shard.apply(
                    String.format("%s(Shard:%s)", singleShardTransform.getName(), shardIndex),
                    singleShardTransform);
        }
        return PDone.in(input.getPipeline());
    }

    /**
     * Returns the number of shards to write. Must be at least 1.
     *
     * @return the shard count
     */
    abstract int getNumShards();

    /**
     * Returns the transform that writes the shard with the given index.
     *
     * @param i the zero-based shard index, in {@code [0, getNumShards())}
     * @return the write transform for that shard
     */
    abstract PTransform<? super PCollection<InputT>, PDone> getSingleShardTransform(int i);
}
