package org.apache.crunch.lib.join;

import java.io.Serializable;
import java.util.Random;
import org.apache.crunch.DoFn;
import org.apache.crunch.Emitter;
import org.apache.crunch.MapFn;
import org.apache.crunch.PTable;
import org.apache.crunch.Pair;
import org.apache.crunch.types.PTableType;
import org.apache.crunch.types.PTypeFamily;

/* loaded from: input_file:lib/crunch-core-0.8.2.jar:org/apache/crunch/lib/join/ShardedJoinStrategy.class */
/**
 * A {@link JoinStrategy} that splits each join key into several shards, so that the records for a
 * single key are joined in several smaller {@code (key, shard)} groups instead of one group.
 *
 * <p>How it works:
 * <ul>
 *   <li>Every record of the left table is emitted once per shard, keyed by {@code (key, shard)},
 *       for each shard index in {@code [0, numShards)}. As a result the left table's data volume is
 *       multiplied by the shard count.</li>
 *   <li>Every record of the right table is emitted exactly once, keyed by {@code (key, shard)} with a
 *       pseudo-randomly chosen shard index in {@code [0, numShards)}.</li>
 *   <li>The two sharded tables are joined with a {@link DefaultJoinStrategy}. The shard index is then
 *       dropped from the output key.</li>
 * </ul>
 *
 * <p>Left outer and full outer joins are rejected. Left records are replicated into every shard, so
 * a left record would produce a null-joined row for each shard that received no matching right
 * record, even when the key does have matches in other shards.
 */
public class ShardedJoinStrategy<K, U, V> implements JoinStrategy<K, U, V> {
    private final JoinStrategy<Pair<K, Integer>, U, V> wrappedJoinStrategy;
    private final ShardingStrategy<K> shardingStrategy;

    /**
     * Decides how many shards the records of a given key are spread over.
     *
     * <p>Both the left and the right side consult the strategy independently for every record. It
     * must therefore return the same value every time it is called for a given key. Otherwise some
     * right records could land in shards that hold no copy of the matching left records.
     * Implementations must be serializable (this interface extends {@link Serializable}).
     */
    public interface ShardingStrategy<K> extends Serializable {
        /**
         * Returns the number of shards for {@code k}.
         *
         * @param k the join key
         * @return the shard count. It must be at least 1; smaller values cause an
         *     {@link IllegalArgumentException} when a record with that key is processed.
         */
        int getNumShards(K k);
    }

    /** A {@link ShardingStrategy} that uses the same shard count for every key. */
    private static class ConstantShardingStrategy<K> implements ShardingStrategy<K> {
        private final int numShards;

        ConstantShardingStrategy(int numShards) {
            // Fail fast at construction instead of on every record at execution time.
            if (numShards < 1) {
                throw new IllegalArgumentException("Num shards must be > 0, got " + numShards);
            }
            this.numShards = numShards;
        }

        @Override
        public int getNumShards(K k) {
            return numShards;
        }
    }

    /** Replicates each left record into every shard of its key. */
    private static class PreShardLeftSideFn<K, U> extends DoFn<Pair<K, U>, Pair<Pair<K, Integer>, U>> {
        private final ShardingStrategy<K> shardingStrategy;

        PreShardLeftSideFn(ShardingStrategy<K> shardingStrategy) {
            this.shardingStrategy = shardingStrategy;
        }

        @Override
        public void process(Pair<K, U> pair, Emitter<Pair<Pair<K, Integer>, U>> emitter) {
            K key = pair.first();
            int numShards = validNumShards(shardingStrategy, key);
            for (int shard = 0; shard < numShards; shard++) {
                emitter.emit(Pair.of(Pair.of(key, Integer.valueOf(shard)), pair.second()));
            }
        }
    }

    /** Assigns each right record to a single pseudo-randomly chosen shard of its key. */
    private static class PreShardRightSideFn<K, V> extends MapFn<Pair<K, V>, Pair<Pair<K, Integer>, V>> {
        private final ShardingStrategy<K> shardingStrategy;
        private transient Random random;

        PreShardRightSideFn(ShardingStrategy<K> shardingStrategy) {
            this.shardingStrategy = shardingStrategy;
        }

        @Override
        public void initialize() {
            // The seed is the task ID, not the attempt ID. Presumably this is so that a re-executed
            // attempt of the same task reproduces the same shard assignment. TODO confirm.
            this.random = new Random(getTaskAttemptID().getTaskID().getId());
        }

        @Override
        public Pair<Pair<K, Integer>, V> map(Pair<K, V> pair) {
            K key = pair.first();
            int numShards = validNumShards(shardingStrategy, key);
            return Pair.of(Pair.of(key, Integer.valueOf(random.nextInt(numShards))), pair.second());
        }
    }

    /** Drops the shard index from the key of a joined record. */
    private static class UnshardFn<K, U, V> extends MapFn<Pair<Pair<K, Integer>, Pair<U, V>>, Pair<K, Pair<U, V>>> {
        @Override
        public Pair<K, Pair<U, V>> map(Pair<Pair<K, Integer>, Pair<U, V>> pair) {
            return Pair.of(pair.first().first(), pair.second());
        }
    }

    /**
     * Asks {@code strategy} for the shard count of {@code key} and checks that it is positive.
     *
     * @throws IllegalArgumentException if the strategy returns a value smaller than 1
     */
    private static <T> int validNumShards(ShardingStrategy<T> strategy, T key) {
        int numShards = strategy.getNumShards(key);
        if (numShards < 1) {
            throw new IllegalArgumentException("Num shards must be > 0, got " + numShards + " for " + key);
        }
        return numShards;
    }

    /**
     * Creates a strategy that uses the same shard count for every key.
     *
     * @param i the number of shards per key; must be at least 1
     * @throws IllegalArgumentException if {@code i < 1}
     */
    public ShardedJoinStrategy(int i) {
        this(new ConstantShardingStrategy<K>(i));
    }

    /**
     * Creates a strategy that asks {@code shardingStrategy} for the shard count of each key.
     *
     * @param shardingStrategy the per-key shard-count policy; must not be null
     * @throws NullPointerException if {@code shardingStrategy} is null
     */
    public ShardedJoinStrategy(ShardingStrategy<K> shardingStrategy) {
        if (shardingStrategy == null) {
            throw new NullPointerException("shardingStrategy");
        }
        this.wrappedJoinStrategy = new DefaultJoinStrategy<Pair<K, Integer>, U, V>();
        this.shardingStrategy = shardingStrategy;
    }

    /**
     * Joins {@code pTable} (left) with {@code pTable2} (right) on their keys, using sharding.
     *
     * @param pTable the left table; each of its records is replicated once per shard
     * @param pTable2 the right table; each of its records goes to exactly one shard
     * @param joinType the join type
     * @return the joined table, keyed by the original (unsharded) key
     * @throws UnsupportedOperationException for full outer and left outer joins
     */
    @Override
    public PTable<K, Pair<U, V>> join(PTable<K, U> pTable, PTable<K, V> pTable2, JoinType joinType) {
        if (joinType == JoinType.FULL_OUTER_JOIN || joinType == JoinType.LEFT_OUTER_JOIN) {
            throw new UnsupportedOperationException("Join type " + joinType + " not supported by ShardedJoinStrategy");
        }
        PTypeFamily typeFamily = pTable.getTypeFamily();
        PTableType<Pair<K, Integer>, U> shardedLeftType =
                typeFamily.tableOf(typeFamily.pairs(pTable.getKeyType(), typeFamily.ints()), pTable.getValueType());
        PTableType<Pair<K, Integer>, V> shardedRightType =
                typeFamily.tableOf(typeFamily.pairs(pTable2.getKeyType(), typeFamily.ints()), pTable2.getValueType());
        PTableType<K, Pair<U, V>> outputType =
                typeFamily.tableOf(pTable.getKeyType(), typeFamily.pairs(pTable.getValueType(), pTable2.getValueType()));

        PTable<Pair<K, Integer>, U> shardedLeft =
                pTable.parallelDo("Pre-shard left", new PreShardLeftSideFn<K, U>(shardingStrategy), shardedLeftType);
        PTable<Pair<K, Integer>, V> shardedRight =
                pTable2.parallelDo("Pre-shard right", new PreShardRightSideFn<K, V>(shardingStrategy), shardedRightType);
        PTable<Pair<K, Integer>, Pair<U, V>> shardedJoined =
                wrappedJoinStrategy.join(shardedLeft, shardedRight, joinType);
        return shardedJoined.parallelDo("Unshard", new UnshardFn<K, U, V>(), outputType);
    }
}
