package org.apache.beam.runners.flink;

import java.io.IOException;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.concurrent.ThreadLocalRandom;
import java.util.concurrent.TimeUnit;
import org.apache.beam.runners.core.construction.PTransformReplacements;
import org.apache.beam.runners.core.construction.PTransformTranslation;
import org.apache.beam.runners.core.construction.ReplacementOutputs;
import org.apache.beam.runners.core.construction.UnconsumedReads;
import org.apache.beam.runners.core.construction.WriteFilesTranslation;
import org.apache.beam.runners.flink.translation.wrappers.streaming.FlinkKeyUtils;
import org.apache.beam.sdk.Pipeline;
import org.apache.beam.sdk.coders.Coder;
import org.apache.beam.sdk.coders.ShardedKeyCoder;
import org.apache.beam.sdk.coders.VarIntCoder;
import org.apache.beam.sdk.io.FileBasedSink;
import org.apache.beam.sdk.io.ShardingFunction;
import org.apache.beam.sdk.io.WriteFiles;
import org.apache.beam.sdk.io.WriteFilesResult;
import org.apache.beam.sdk.options.PipelineOptions;
import org.apache.beam.sdk.runners.AppliedPTransform;
import org.apache.beam.sdk.runners.PTransformMatcher;
import org.apache.beam.sdk.runners.PTransformOverrideFactory;
import org.apache.beam.sdk.runners.TransformHierarchy;
import org.apache.beam.sdk.transforms.PTransform;
import org.apache.beam.sdk.util.CoderUtils;
import org.apache.beam.sdk.values.PCollection;
import org.apache.beam.sdk.values.PCollectionView;
import org.apache.beam.sdk.values.POutput;
import org.apache.beam.sdk.values.PValue;
import org.apache.beam.sdk.values.ShardedKey;
import org.apache.beam.sdk.values.TupleTag;
import org.apache.beam.vendor.guava.v20_0.com.google.common.annotations.VisibleForTesting;
import org.apache.beam.vendor.guava.v20_0.com.google.common.cache.Cache;
import org.apache.beam.vendor.guava.v20_0.com.google.common.cache.CacheBuilder;
import org.apache.flink.runtime.state.KeyGroupRangeAssignment;
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
import org.apache.flink.util.Preconditions;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* JADX INFO: Access modifiers changed from: package-private */
/* loaded from: input_file:org/apache/beam/runners/flink/FlinkStreamingPipelineTranslator.class */
public class FlinkStreamingPipelineTranslator extends FlinkPipelineTranslator {
    /** Logger shared by every instance of this translator. */
    private static final Logger LOG = LoggerFactory.getLogger(FlinkStreamingPipelineTranslator.class);

    /** Context handed to each transform translator; built once in the constructor. */
    private final FlinkStreamingTranslationContext streamingContext;

    /** Nesting level of the composite transform currently being visited; feeds {@code genSpaces} in log calls. */
    private int depth;

    /**
     * {@link ShardingFunction} that hands out pre-computed {@link ShardedKey}s chosen so that shard
     * {@code s} is assigned by Flink to parallel operator index {@code s % parallelism}.
     *
     * <p>Keys are computed per destination (identified by the hash of the destination's encoded
     * bytes) and kept in a small, lazily created cache. Shard numbers are handed out round-robin,
     * starting from a random shard on the first call of each instance.
     *
     * <p>NOTE(review): instances keep mutable, unsynchronized state ({@code shardNumber},
     * {@code cache}); this looks intended for use by a single thread per instance - confirm.
     */
    @VisibleForTesting
    static class FlinkAutoBalancedShardKeyShardingFunction<UserT, DestinationT> implements ShardingFunction<UserT, DestinationT> {

        /** Maximum number of destinations whose sharded keys are kept in the cache. */
        @VisibleForTesting
        static final int CACHE_MAX_SIZE = 100;

        /** Cache entries are dropped after this many seconds without access. */
        private static final long CACHE_EXPIRE_SECONDS = 600;

        private final int parallelism;
        private final int maxParallelism;
        private final Coder<DestinationT> destinationCoder;

        // Created on first use in assignShardKey; transient, presumably because the function is
        // serialized before being shipped to workers - TODO confirm.
        private transient Cache<Integer, Map<Integer, ShardedKey<Integer>>> cache;

        private final ShardedKeyCoder<Integer> shardedKeyCoder = ShardedKeyCoder.of(VarIntCoder.of());

        // -1 means "no shard handed out yet"; see assignShardKey.
        private int shardNumber = -1;

        /**
         * Returns a map view of the key cache, or {@code null} if no key has been assigned yet.
         */
        @VisibleForTesting
        Map<Integer, Map<Integer, ShardedKey<Integer>>> getCache() {
            if (this.cache == null) {
                return null;
            }
            return this.cache.asMap();
        }

        /** Returns the effective max parallelism used for key-group computation. */
        @VisibleForTesting
        int getMaxParallelism() {
            return this.maxParallelism;
        }

        /**
         * @param parallelism number of parallel operator instances keys are balanced across
         * @param maxParallelism Flink max parallelism; when not positive, Flink's default for the
         *     given parallelism is computed instead
         * @param destinationCoder coder used to derive a stable hash for each destination
         */
        FlinkAutoBalancedShardKeyShardingFunction(int parallelism, int maxParallelism, Coder<DestinationT> destinationCoder) {
            this.parallelism = parallelism;
            this.maxParallelism = maxParallelism > 0 ? maxParallelism : KeyGroupRangeAssignment.computeDefaultMaxParallelism(parallelism);
            this.destinationCoder = destinationCoder;
        }

        /**
         * Picks the next shard (round-robin) and returns the pre-computed key for it.
         *
         * @param destination destination of the element; only its encoded bytes are used
         * @param element the element being written (not inspected)
         * @param shardCount number of shards to distribute over; must be positive
         * @return the sharded key for the chosen shard of this destination
         * @throws Exception if the destination cannot be encoded
         */
        @Override // org.apache.beam.sdk.io.ShardingFunction
        public ShardedKey<Integer> assignShardKey(DestinationT destination, UserT element, int shardCount) throws Exception {
            if (this.shardNumber == -1) {
                // First call on this instance: start at a random shard so that separate instances
                // do not all begin with the same shard.
                this.shardNumber = ThreadLocalRandom.current().nextInt(shardCount);
            } else {
                this.shardNumber = (this.shardNumber + 1) % shardCount;
            }

            int destinationKey = Arrays.hashCode(CoderUtils.encodeToByteArray(this.destinationCoder, destination));

            if (this.cache == null) {
                this.cache = CacheBuilder.newBuilder().maximumSize(CACHE_MAX_SIZE).expireAfterAccess(CACHE_EXPIRE_SECONDS, TimeUnit.SECONDS).build();
            }

            // Look the entry up once and keep a local reference. Regenerate when the cached map was
            // built for fewer shards than requested now; otherwise get(shardNumber) would be null.
            // Keys depend only on (destinationKey, shard), so a larger map is a superset of a
            // smaller one and can safely replace it.
            Map<Integer, ShardedKey<Integer>> shardedKeys = this.cache.getIfPresent(destinationKey);
            if (shardedKeys == null || shardedKeys.size() < shardCount) {
                shardedKeys = generateShardedKeys(destinationKey, shardCount);
                this.cache.put(destinationKey, shardedKeys);
            }
            return shardedKeys.get(this.shardNumber);
        }

        /**
         * Builds, for every shard in {@code [0, shardCount)}, a key that Flink assigns to operator
         * index {@code shard % parallelism}.
         *
         * <p>For each shard, salts {@code 0, 1, 2, ...} are tried in order; the first salt whose
         * encoded key maps to the target operator wins.
         *
         * @throws RuntimeException if no salt in the whole non-negative int range matches
         */
        private Map<Integer, ShardedKey<Integer>> generateShardedKeys(int key, int shardCount) {
            Map<Integer, ShardedKey<Integer>> shardedKeys = new HashMap<>();
            for (int shard = 0; shard < shardCount; shard++) {
                int targetPartition = shard % this.parallelism;
                int salt = -1;
                while (true) {
                    // Post-increment: the comparison sees the previous salt, so the loop stops
                    // after Integer.MAX_VALUE itself has been tried and before overflow.
                    if (salt++ == Integer.MAX_VALUE) {
                        throw new RuntimeException("Failed to find sharded key in [ " + Integer.MAX_VALUE + " ] iterations");
                    }
                    ShardedKey<Integer> candidate = ShardedKey.of(Objects.hash(key, salt), shard);
                    // Encode the candidate via FlinkKeyUtils and ask Flink which operator it maps to.
                    int partition = KeyGroupRangeAssignment.assignKeyToParallelOperator(FlinkKeyUtils.encodeKey(candidate, this.shardedKeyCoder), this.maxParallelism, this.parallelism);
                    if (partition == targetPartition) {
                        shardedKeys.put(shard, candidate);
                        break;
                    }
                }
            }
            return shardedKeys;
        }
    }

    /* JADX INFO: Access modifiers changed from: package-private */
    /* loaded from: input_file:org/apache/beam/runners/flink/FlinkStreamingPipelineTranslator$StreamTransformTranslator.class */
    /**
     * Base class for translators that turn one kind of {@link PTransform} into its streaming
     * representation using a {@link FlinkStreamingTranslationContext}.
     *
     * @param <T> the transform type handled by the translator
     */
    public abstract static class StreamTransformTranslator<T extends PTransform> {

        /**
         * Translates {@code transform} using the supplied context.
         *
         * @param transform the transform to translate
         * @param context the translation context of the current pipeline
         */
        abstract void translateNode(T transform, FlinkStreamingTranslationContext context);

        /**
         * Tells whether this translator is able to handle {@code transform}. The default
         * implementation accepts every transform; subclasses may override it to be selective.
         *
         * @param transform the candidate transform
         * @param context the translation context of the current pipeline
         * @return {@code true} unless overridden
         */
        boolean canTranslate(T transform, FlinkStreamingTranslationContext context) {
            return true;
        }
    }

    /* JADX INFO: Access modifiers changed from: package-private */
    /**
     * Override factory that rebuilds a {@link WriteFiles} transform for streaming execution: it
     * fixes the shard count when sharding is runner-determined and, when auto-balancing is enabled
     * in the options, installs a {@link FlinkAutoBalancedShardKeyShardingFunction}.
     *
     * <p>NOTE: per the decompiler, this class, its constructor and
     * {@link #writeFilesNeedsOverrides()} were originally package-private.
     */
    @VisibleForTesting
    public static class StreamingShardedWriteFactory<UserT, DestinationT, OutputT> implements PTransformOverrideFactory<PCollection<UserT>, WriteFilesResult<DestinationT>, WriteFiles<UserT, DestinationT, OutputT>> {
        FlinkPipelineOptions options;

        /**
         * Returns a matcher selecting {@link WriteFiles} applications that need this override:
         * those with runner-determined sharding, or - when auto-balancing is enabled - those that
         * have no user-supplied sharding function.
         *
         * @throws RuntimeException (from the returned matcher) wrapping the {@link IOException}
         *     raised when the transform cannot be parsed
         */
        // The matcher only sees a wildcard-typed application; the raw cast mirrors what the
        // translation helper needs and cannot be expressed with the class's type parameters
        // because this method is static.
        @SuppressWarnings({"unchecked", "rawtypes"})
        public static PTransformMatcher writeFilesNeedsOverrides() {
            return application -> {
                if (!PTransformTranslation.WRITE_FILES_TRANSFORM_URN.equals(PTransformTranslation.urnForTransformOrNull((PTransform<?, ?>) application.getTransform()))) {
                    return false;
                }
                try {
                    FlinkPipelineOptions flinkOptions = application.getPipeline().getOptions().as(FlinkPipelineOptions.class);
                    ShardingFunction<?, ?> shardingFunction = ((WriteFiles<?, ?, ?>) application.getTransform()).getShardingFunction();
                    return WriteFilesTranslation.isRunnerDeterminedSharding((AppliedPTransform) application)
                            || (flinkOptions.isAutoBalanceWriteFilesShardingEnabled() && shardingFunction == null);
                } catch (IOException e) {
                    throw new RuntimeException(String.format("Transform with URN %s failed to parse: %s", PTransformTranslation.WRITE_FILES_TRANSFORM_URN, application.getTransform()), e);
                }
            };
        }

        /**
         * @param pipelineOptions pipeline options; viewed as {@link FlinkPipelineOptions}
         */
        public StreamingShardedWriteFactory(PipelineOptions pipelineOptions) {
            this.options = pipelineOptions.as(FlinkPipelineOptions.class);
        }

        /**
         * Builds the replacement {@link WriteFiles}, copying sink, side inputs, windowing and
         * sharding settings from the original application.
         *
         * @param appliedPTransform the original {@link WriteFiles} application
         * @return the replacement, applied to the original's single main input
         * @throws IllegalArgumentException if the configured job parallelism is not positive
         * @throws RuntimeException wrapping any exception raised while rebuilding the transform
         */
        @Override // org.apache.beam.sdk.runners.PTransformOverrideFactory
        @SuppressWarnings({"unchecked", "rawtypes"})
        public PTransformOverrideFactory.PTransformReplacement<PCollection<UserT>, WriteFilesResult<DestinationT>> getReplacementTransform(AppliedPTransform<PCollection<UserT>, WriteFilesResult<DestinationT>, WriteFiles<UserT, DestinationT, OutputT>> appliedPTransform) {
            Integer jobParallelism = this.options.getParallelism();
            Preconditions.checkArgument(jobParallelism.intValue() > 0, "Parallelism of a job should be greater than 0. Currently set: %s", jobParallelism);
            // Runner-determined sharding uses twice the job parallelism.
            // NOTE(review): presumably a trade-off between parallel writes and file count - confirm.
            int numShards = jobParallelism.intValue() * 2;
            try {
                List<PCollectionView<?>> sideInputs = WriteFilesTranslation.getDynamicDestinationSideInputs(appliedPTransform);
                FileBasedSink sink = WriteFilesTranslation.getSink(appliedPTransform);
                WriteFiles<UserT, DestinationT, OutputT> replacement = WriteFiles.to(sink).withSideInputs(sideInputs);
                if (WriteFilesTranslation.isWindowedWrites(appliedPTransform)) {
                    replacement = replacement.withWindowedWrites();
                }
                if (WriteFilesTranslation.isRunnerDeterminedSharding(appliedPTransform)) {
                    replacement = replacement.withNumShards(numShards);
                } else {
                    // Carry over whatever explicit sharding the user configured.
                    WriteFiles<UserT, DestinationT, OutputT> original = appliedPTransform.getTransform();
                    if (original.getNumShardsProvider() != null) {
                        replacement = replacement.withNumShards(original.getNumShardsProvider());
                    }
                    if (original.getComputeNumShards() != null) {
                        replacement = replacement.withSharding(original.getComputeNumShards());
                    }
                }
                if (this.options.isAutoBalanceWriteFilesShardingEnabled()) {
                    replacement = replacement.withShardingFunction(new FlinkAutoBalancedShardKeyShardingFunction<UserT, DestinationT>(jobParallelism.intValue(), this.options.getMaxParallelism().intValue(), sink.getDynamicDestinations().getDestinationCoder()));
                }
                return PTransformOverrideFactory.PTransformReplacement.of(PTransformReplacements.getSingletonMainInput(appliedPTransform), replacement);
            } catch (Exception e) {
                throw new RuntimeException(e);
            }
        }

        /**
         * Maps the outputs of the original transform onto those of the replacement by tag.
         *
         * @param map outputs of the original transform, keyed by tag
         * @param writeFilesResult output of the replacement transform
         */
        @Override // org.apache.beam.sdk.runners.PTransformOverrideFactory
        public Map<PValue, PTransformOverrideFactory.ReplacementOutput> mapOutputs(Map<TupleTag<?>, PValue> map, WriteFilesResult<DestinationT> writeFilesResult) {
            return ReplacementOutputs.tagged(map, writeFilesResult);
        }
    }

    /**
     * Creates a translator whose translation context is built from the given execution
     * environment and pipeline options.
     *
     * @param env the Flink streaming execution environment passed to the context
     * @param options the pipeline options passed to the context
     */
    public FlinkStreamingPipelineTranslator(StreamExecutionEnvironment env, PipelineOptions options) {
        streamingContext = new FlinkStreamingTranslationContext(env, options);
    }

    /**
     * Translates the given pipeline by first running
     * {@link UnconsumedReads#ensureAllReadsConsumed(Pipeline)} on it and then delegating to the
     * superclass implementation.
     *
     * @param pipeline the pipeline to translate
     */
    @Override // org.apache.beam.runners.flink.FlinkPipelineTranslator
    public void translate(Pipeline pipeline) {
        // Runs before the superclass translation. NOTE(review): judging by its name this makes
        // sure every read has a consumer before traversal - confirm against UnconsumedReads.
        UnconsumedReads.ensureAllReadsConsumed(pipeline);
        super.translate(pipeline);
    }

    /**
     * Called when the traversal enters a composite transform. If a translator exists for the
     * composite and accepts it, the composite is translated as a whole and its children are
     * skipped; otherwise the traversal descends into it.
     *
     * @param node the composite node being entered
     * @return {@code DO_NOT_ENTER_TRANSFORM} when the composite was translated here,
     *     {@code ENTER_TRANSFORM} otherwise
     */
    @Override // org.apache.beam.sdk.Pipeline.PipelineVisitor.Defaults, org.apache.beam.sdk.Pipeline.PipelineVisitor
    public Pipeline.PipelineVisitor.CompositeBehavior enterCompositeTransform(TransformHierarchy.Node node) {
        LOG.info("{} enterCompositeTransform- {}", genSpaces(this.depth), node.getFullName());
        this.depth++;

        PTransform<?, ?> transform = node.getTransform();
        if (transform != null) {
            StreamTransformTranslator<?> translator = FlinkStreamingTransformTranslators.getTranslator(transform);
            if (translator != null && applyCanTranslate(transform, node, translator)) {
                applyStreamingTransform(transform, node, translator);
                LOG.info("{} translated- {}", genSpaces(this.depth), node.getFullName());
                return Pipeline.PipelineVisitor.CompositeBehavior.DO_NOT_ENTER_TRANSFORM;
            }
        }
        // No transform, no translator, or the translator declined: visit the children instead.
        return Pipeline.PipelineVisitor.CompositeBehavior.ENTER_TRANSFORM;
    }

    /**
     * Called when the traversal leaves a composite transform; undoes the depth increment made on
     * entry and logs the event at the restored depth.
     *
     * @param node the composite node being left
     */
    @Override // org.apache.beam.sdk.Pipeline.PipelineVisitor.Defaults, org.apache.beam.sdk.Pipeline.PipelineVisitor
    public void leaveCompositeTransform(TransformHierarchy.Node node) {
        depth -= 1;
        String indent = genSpaces(depth);
        LOG.info("{} leaveCompositeTransform- {}", indent, node.getFullName());
    }

    /**
     * Translates a primitive transform with its registered translator.
     *
     * @param node the primitive node to translate
     * @throws UnsupportedOperationException if no translator is registered for the transform or
     *     the translator reports that it cannot translate it
     */
    @Override // org.apache.beam.sdk.Pipeline.PipelineVisitor.Defaults, org.apache.beam.sdk.Pipeline.PipelineVisitor
    public void visitPrimitiveTransform(TransformHierarchy.Node node) {
        LOG.info("{} visitPrimitiveTransform- {}", genSpaces(this.depth), node.getFullName());

        PTransform<?, ?> transform = node.getTransform();
        StreamTransformTranslator<?> translator = FlinkStreamingTransformTranslators.getTranslator(transform);

        // Guard: bail out when there is no usable translator.
        if (translator == null || !applyCanTranslate(transform, node, translator)) {
            String urn = PTransformTranslation.urnForTransform(transform);
            LOG.info(urn);
            throw new UnsupportedOperationException("The transform " + urn + " is currently not supported.");
        }

        applyStreamingTransform(transform, node, translator);
    }

    /**
     * Visitor callback for values; this translator does nothing here.
     *
     * @param value the visited value (ignored)
     * @param producer the node associated with the value (ignored)
     */
    @Override // org.apache.beam.sdk.Pipeline.PipelineVisitor.Defaults, org.apache.beam.sdk.Pipeline.PipelineVisitor
    public void visitValue(PValue value, TransformHierarchy.Node producer) {
        // Intentionally empty.
    }

    /**
     * Registers {@code node} as the current transform on the translation context and lets the
     * translator translate it.
     *
     * @param pTransform the transform to translate
     * @param node the hierarchy node the transform belongs to
     * @param streamTransformTranslator the translator looked up for {@code pTransform}
     * @param <T> the transform type the translator is treated as handling
     */
    // The translator comes from a registry as a wildcard type. Both casts are unchecked and rely
    // on the registry having paired this translator with this transform.
    @SuppressWarnings("unchecked")
    private <T extends PTransform<?, ?>> void applyStreamingTransform(PTransform<?, ?> pTransform, TransformHierarchy.Node node, StreamTransformTranslator<?> streamTransformTranslator) {
        T typedTransform = (T) pTransform;
        StreamTransformTranslator<T> typedTranslator = (StreamTransformTranslator<T>) streamTransformTranslator;

        // The context must know the current applied transform before the translator runs.
        this.streamingContext.setCurrentTransform(node.toAppliedPTransform(getPipeline()));
        typedTranslator.translateNode(typedTransform, this.streamingContext);
    }

    /**
     * Registers {@code node} as the current transform on the translation context and asks the
     * translator whether it can translate the transform.
     *
     * @param pTransform the candidate transform
     * @param node the hierarchy node the transform belongs to
     * @param streamTransformTranslator the translator looked up for {@code pTransform}
     * @param <T> the transform type the translator is treated as handling
     * @return the translator's {@code canTranslate} answer
     */
    // The translator comes from a registry as a wildcard type. Both casts are unchecked and rely
    // on the registry having paired this translator with this transform.
    @SuppressWarnings("unchecked")
    private <T extends PTransform<?, ?>> boolean applyCanTranslate(PTransform<?, ?> pTransform, TransformHierarchy.Node node, StreamTransformTranslator<?> streamTransformTranslator) {
        T typedTransform = (T) pTransform;
        StreamTransformTranslator<T> typedTranslator = (StreamTransformTranslator<T>) streamTransformTranslator;

        // The context must know the current applied transform before the translator is consulted.
        this.streamingContext.setCurrentTransform(node.toAppliedPTransform(getPipeline()));
        return typedTranslator.canTranslate(typedTransform, this.streamingContext);
    }
}
