package org.apache.beam.sdk.schemas.transforms;

import com.google.auto.value.AutoValue;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.TreeMap;
import java.util.function.Function;
import java.util.stream.Collectors;
import net.bytebuddy.utility.JavaConstant;
import org.apache.beam.sdk.coders.KvCoder;
import org.apache.beam.sdk.schemas.FieldAccessDescriptor;
import org.apache.beam.sdk.schemas.Schema;
import org.apache.beam.sdk.schemas.SchemaCoder;
import org.apache.beam.sdk.schemas.SchemaUtils;
import org.apache.beam.sdk.schemas.transforms.AutoValue_CoGroup_By;
import org.apache.beam.sdk.schemas.utils.RowSelector;
import org.apache.beam.sdk.schemas.utils.SelectHelpers;
import org.apache.beam.sdk.transforms.DoFn;
import org.apache.beam.sdk.transforms.PTransform;
import org.apache.beam.sdk.transforms.ParDo;
import org.apache.beam.sdk.transforms.View;
import org.apache.beam.sdk.transforms.join.CoGbkResult;
import org.apache.beam.sdk.transforms.join.CoGroupByKey;
import org.apache.beam.sdk.transforms.join.KeyedPCollectionTuple;
import org.apache.beam.sdk.values.KV;
import org.apache.beam.sdk.values.PCollection;
import org.apache.beam.sdk.values.PCollectionTuple;
import org.apache.beam.sdk.values.PCollectionView;
import org.apache.beam.sdk.values.Row;
import org.apache.beam.sdk.values.TupleTag;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Iterables;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Lists;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Maps;

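/**
 * Schema-aware join of the {@link PCollection}s in a {@link PCollectionTuple}. Each input is keyed
 * by the fields selected with a {@link By}; {@link Impl} emits one row per key holding the key and
 * an iterable of the matching rows from each input, while {@link ExpandCrossProduct} emits the
 * cross product of the matching rows as flat rows.
 */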
public class CoGroup {
    /**
     * List holding a single {@code null} element (added by the static initializer). It stands in for
     * an optional input that has no rows for a key, so the expanded cross product still yields rows.
     */
    private static final List<Row> NULL_LIST = new ArrayList<>();

    /**
     * Per-input join specification: the fields that form the join key, whether the input may have no
     * rows for a key (optional participation), and whether the input is consumed as a side input.
     */
    @AutoValue
    public abstract static class By implements Serializable {

        /** AutoValue builder used by the factory methods and {@link #toBuilder()}. */
        @AutoValue.Builder
        public abstract static class Builder {
            abstract Builder setFieldAccessDescriptor(FieldAccessDescriptor fieldAccessDescriptor);

            abstract Builder setOptionalParticipation(boolean optionalParticipation);

            abstract Builder setSideInput(boolean sideInput);

            abstract By build();
        }

        /** Fields of the input that make up the join key. */
        public abstract FieldAccessDescriptor getFieldAccessDescriptor();

        /** Whether an input with no rows for a key still allows output for that key. */
        public abstract boolean getOptionalParticipation();

        /** Whether the input is materialized as a side input instead of being shuffled. */
        public abstract boolean getSideInput();

        abstract Builder toBuilder();

        /** Keys the input on the named fields. */
        public static By fieldNames(String... names) {
            FieldAccessDescriptor descriptor = FieldAccessDescriptor.withFieldNames(names);
            return fieldAccessDescriptor(descriptor);
        }

        /** Keys the input on the fields with the given ids. */
        public static By fieldIds(Integer... ids) {
            FieldAccessDescriptor descriptor = FieldAccessDescriptor.withFieldIds(ids);
            return fieldAccessDescriptor(descriptor);
        }

        /**
         * Keys the input on the fields selected by {@code fieldAccessDescriptor}; participation is
         * required and the input is not a side input.
         */
        public static By fieldAccessDescriptor(FieldAccessDescriptor fieldAccessDescriptor) {
            Builder builder = new AutoValue_CoGroup_By.Builder();
            builder.setFieldAccessDescriptor(fieldAccessDescriptor);
            builder.setOptionalParticipation(false);
            builder.setSideInput(false);
            return builder.build();
        }

        /** Returns a copy of this spec with optional participation enabled. */
        public By withOptionalParticipation() {
            Builder copy = toBuilder();
            copy.setOptionalParticipation(true);
            return copy.build();
        }

        /** Returns a copy of this spec that marks the input as a side input. */
        public By withSideInput() {
            Builder copy = toBuilder();
            copy.setSideInput(true);
            return copy.build();
        }
    }

    /**
     * Converts the {@link CoGbkResult} for one key into output rows. In {@link ConvertType#UNEXPANDED}
     * mode it emits a single row (key plus one iterable per input); in {@link ConvertType#EXPANDED}
     * mode it emits the cross product of the per-input rows.
     */
    public static class ConvertCoGbkResult extends DoFn<KV<Row, CoGbkResult>, Row> {
        private final JoinInformation joinInformation;
        private final JoinArguments joinArgs;
        private final Schema outputSchema;
        // Never reassigned after construction; final documents that and keeps the DoFn immutable.
        private final ConvertType convertType;

        /** Output shape produced by {@link ConvertCoGbkResult}. */
        public enum ConvertType {
            UNEXPANDED,
            EXPANDED
        }

        /**
         * @param joinInformation per-input keys, schemas and side-input views of the join
         * @param joinArguments per-input join arguments
         * @param convertType whether to emit one grouped row or the expanded cross product
         * @param schema schema of the emitted rows
         */
        public ConvertCoGbkResult(JoinInformation joinInformation, JoinArguments joinArguments, ConvertType convertType, Schema schema) {
            this.joinInformation = joinInformation;
            this.joinArgs = joinArguments;
            this.outputSchema = schema;
            this.convertType = convertType;
        }

        @DoFn.ProcessElement
        public void process(@DoFn.Element KV<Row, CoGbkResult> kv, DoFn<KV<Row, CoGbkResult>, Row>.ProcessContext processContext, DoFn.OutputReceiver<Row> outputReceiver) {
            Result result = Result.from(this.joinInformation, this.joinArgs, kv.getKey(), this.outputSchema, kv.getValue(), processContext);
            if (this.convertType == ConvertType.UNEXPANDED) {
                result.outputUnexpandedRow(this.outputSchema, outputReceiver);
            } else {
                result.outputExpandedRows(outputReceiver);
            }
        }
    }

    /* loaded from: input_file:org/apache/beam/sdk/schemas/transforms/CoGroup$ExpandCrossProduct.class */
    public static class ExpandCrossProduct extends PTransform<PCollectionTuple, PCollection<Row>> {
        private final JoinArguments joinArgs;

        ExpandCrossProduct(JoinArguments joinArguments) {
            this.joinArgs = joinArguments;
        }

        public ExpandCrossProduct join(String str, By by) {
            if (this.joinArgs.allInputsJoinArgs != null) {
                throw new IllegalStateException("Cannot set both a global and per-tag fields.");
            }
            return new ExpandCrossProduct(this.joinArgs.with(str, by));
        }

        @Override // org.apache.beam.sdk.transforms.PTransform
        public PCollection<Row> expand(PCollectionTuple pCollectionTuple) {
            CoGroup.verify(pCollectionTuple, this.joinArgs);
            JoinArguments joinArguments = this.joinArgs;
            Objects.requireNonNull(joinArguments);
            Function function = str -> {
                return joinArguments.getFieldAccessDescriptor(str);
            };
            JoinArguments joinArguments2 = this.joinArgs;
            Objects.requireNonNull(joinArguments2);
            JoinInformation from = JoinInformation.from(pCollectionTuple, function, str2 -> {
                return Boolean.valueOf(joinArguments2.getSideInputSource(str2));
            });
            Result.verifyExpandedArgs(from, this.joinArgs);
            Schema expandedOutputSchema = Result.getExpandedOutputSchema(from, this.joinArgs);
            Collection values = from.sideInputs.values();
            return (from.keyedPCollectionTuple.getKeyedCollections().size() > 1 ? (PCollection) ((PCollection) from.keyedPCollectionTuple.apply("CoGroupByKey", CoGroupByKey.create())).apply(ParDo.of(new ConvertCoGbkResult(from, this.joinArgs, ConvertCoGbkResult.ConvertType.EXPANDED, expandedOutputSchema)).withSideInputs(values)) : (PCollection) ((KeyedPCollectionTuple.TaggedKeyedPCollection) Iterables.getOnlyElement(from.keyedPCollectionTuple.getKeyedCollections())).getCollection().apply(ParDo.of(new ExpandRowResult(from, this.joinArgs, expandedOutputSchema)).withSideInputs(values))).setRowSchema(expandedOutputSchema);
        }
    }

    /**
     * Expands a single keyed row (the case with only one main input) into output rows, combining it
     * with any side-input rows for the same key.
     */
    public static class ExpandRowResult extends DoFn<KV<Row, Row>, Row> {
        private final JoinInformation joinInformation;
        private final JoinArguments joinArgs;
        private final Schema outputSchema;

        public ExpandRowResult(JoinInformation joinInformation, JoinArguments joinArguments, Schema schema) {
            this.joinInformation = joinInformation;
            this.joinArgs = joinArguments;
            this.outputSchema = schema;
        }

        @DoFn.ProcessElement
        public void process(@DoFn.Element KV<Row, Row> element, DoFn<KV<Row, Row>, Row>.ProcessContext context, DoFn.OutputReceiver<Row> receiver) {
            Row key = element.getKey();
            Row value = element.getValue();
            Result grouped = Result.from(joinInformation, joinArgs, key, outputSchema, value, context);
            grouped.outputExpandedRows(receiver);
        }
    }

    /**
     * Join transform that emits one row per key: a row field holding the key, followed by one
     * iterable field per input tag with that input's matching rows.
     */
    public static class Impl extends PTransform<PCollectionTuple, PCollection<Row>> {
        private final JoinArguments joinArgs;
        private final String keyFieldName;

        private Impl() {
            this(new JoinArguments((Map<String, By>) Collections.emptyMap()));
        }

        private Impl(JoinArguments joinArguments) {
            this(joinArguments, "key");
        }

        private Impl(JoinArguments joinArguments, String str) {
            this.joinArgs = joinArguments;
            this.keyFieldName = str;
        }

        /** Sets the name of the key field in the output schema (defaults to {@code "key"}). */
        public Impl withKeyField(String str) {
            return new Impl(this.joinArgs, str);
        }

        /**
         * Adds the join arguments for one more input tag.
         *
         * @throws IllegalStateException if a single {@link By} was already given for all inputs
         */
        public Impl join(String str, By by) {
            if (this.joinArgs.allInputsJoinArgs != null) {
                throw new IllegalStateException("Cannot set both a global and per-tag fields.");
            }
            return new Impl(this.joinArgs.with(str, by), this.keyFieldName);
        }

        /** Switches to the variant that expands each key's rows into their cross product. */
        public ExpandCrossProduct crossProductJoin() {
            return new ExpandCrossProduct(this.joinArgs);
        }

        @Override // org.apache.beam.sdk.transforms.PTransform
        public PCollection<Row> expand(PCollectionTuple pCollectionTuple) {
            CoGroup.verify(pCollectionTuple, joinArgs);
            // Method references keep the String parameter type; the raw-Function lambda did not compile.
            JoinInformation joinInformation = JoinInformation.from(
                    pCollectionTuple, joinArgs::getFieldAccessDescriptor, joinArgs::getSideInputSource);
            Schema outputSchema = Result.getUnexandedOutputSchema(keyFieldName, joinInformation);
            return joinInformation.keyedPCollectionTuple
                    .apply("CoGroupByKey", CoGroupByKey.create())
                    .apply(ParDo.of(new ConvertCoGbkResult(joinInformation, joinArgs, ConvertCoGbkResult.ConvertType.UNEXPANDED, outputSchema))
                            .withSideInputs(joinInformation.sideInputs.values()))
                    .setRowSchema(outputSchema);
        }
    }

    /**
     * Join arguments: either one {@link By} shared by all inputs ({@code allInputsJoinArgs}), or a
     * per-tag map of {@link By}s ({@code joinArgsMap}). Exactly one of the two is populated.
     */
    public static class JoinArguments implements Serializable {
        private final By allInputsJoinArgs;
        private final Map<String, By> joinArgsMap;

        JoinArguments(By by) {
            this.allInputsJoinArgs = by;
            this.joinArgsMap = Collections.emptyMap();
        }

        JoinArguments(Map<String, By> map) {
            this.allInputsJoinArgs = null;
            this.joinArgsMap = map;
        }

        /** Returns new per-tag arguments with {@code by} added for tag {@code str}. */
        JoinArguments with(String str, By by) {
            return new JoinArguments(ImmutableMap.<String, By>builder().putAll(this.joinArgsMap).put(str, by).build());
        }

        public FieldAccessDescriptor getFieldAccessDescriptor(String str) {
            return argsFor(str).getFieldAccessDescriptor();
        }

        public boolean getOptionalParticipation(String str) {
            return argsFor(str).getOptionalParticipation();
        }

        public boolean getSideInputSource(String str) {
            return argsFor(str).getSideInput();
        }

        /**
         * Returns the {@link By} that applies to {@code tag}.
         *
         * @throws IllegalStateException if per-tag arguments are in use and none were set for {@code tag}
         *     (previously this surfaced as a bare NullPointerException)
         */
        private By argsFor(String tag) {
            if (this.allInputsJoinArgs != null) {
                return this.allInputsJoinArgs;
            }
            By by = this.joinArgsMap.get(tag);
            Preconditions.checkState(by != null, "No join arguments were set for input %s", tag);
            return by;
        }
    }

    /**
     * Construction-time description of a join: the keyed main inputs, the side-input views, the
     * merged key schema, and the per-tag schemas and keyed tags.
     */
    public static class JoinInformation implements Serializable {
        // Main (non-side) inputs keyed by their join key; transient, so it is not serialized with DoFns.
        private final transient KeyedPCollectionTuple<Row> keyedPCollectionTuple;
        // Keyed tag -> multimap view of the inputs marked as side inputs.
        private final Map<String, PCollectionView<Map<Row, Iterable<Row>>>> sideInputs;
        // Key schema merged across all inputs with SchemaUtils.mergeWideningNullable.
        private final Schema keySchema;
        // Input tag -> input schema; a TreeMap in from(), so iteration is ordered by tag.
        private final Map<String, Schema> componentSchemas;
        private final List<String> sortedTags;
        // Index into sortedTags -> keyed tag used in keyedPCollectionTuple / sideInputs.
        private final Map<Integer, String> tagToKeyedTag;

        private JoinInformation(KeyedPCollectionTuple<Row> keyedPCollectionTuple, Map<String, PCollectionView<Map<Row, Iterable<Row>>>> map, Schema schema, Map<String, Schema> map2, List<String> list, Map<Integer, String> map3) {
            this.keyedPCollectionTuple = keyedPCollectionTuple;
            this.sideInputs = map;
            this.keySchema = schema;
            this.componentSchemas = map2;
            this.sortedTags = list;
            this.tagToKeyedTag = map3;
        }

        /**
         * Keys every input of {@code pCollectionTuple} by its join fields and routes it either into the
         * keyed tuple (main input) or into a multimap side-input view.
         *
         * @param getFieldAccessDescriptor join-key fields for a tag
         * @param getIsSideInput whether a tag is consumed as a side input
         * @throws IllegalStateException if no join fields are available for some input
         */
        public static JoinInformation from(PCollectionTuple pCollectionTuple, Function<String, FieldAccessDescriptor> getFieldAccessDescriptor, Function<String, Boolean> getIsSideInput) {
            List<String> sortedTags = pCollectionTuple.getAll().keySet().stream()
                    .map(TupleTag::getId)
                    .sorted()
                    .collect(Collectors.toList());

            // Pass 1: collect the input schemas (sorted by tag for a deterministic output schema) and
            // merge the key schema over all inputs before any key is extracted.
            TreeMap<String, Schema> componentSchemas = Maps.newTreeMap();
            Schema keySchema = null;
            for (Map.Entry<TupleTag<?>, PCollection<?>> entry : pCollectionTuple.getAll().entrySet()) {
                String tag = entry.getKey().getId();
                Schema inputSchema = entry.getValue().getSchema();
                componentSchemas.put(tag, inputSchema);
                FieldAccessDescriptor keyFields = getFieldAccessDescriptor.apply(tag);
                if (keyFields == null) {
                    throw new IllegalStateException("No fields were set for input " + tag);
                }
                Schema inputKeySchema = SelectHelpers.getOutputSchema(inputSchema, keyFields.resolve(inputSchema));
                keySchema = keySchema == null ? inputKeySchema : SchemaUtils.mergeWideningNullable(keySchema, inputKeySchema);
            }

            // Pass 2: key each input with the merged key schema and route it to a main or side input.
            KeyedPCollectionTuple<Row> keyedInputs = KeyedPCollectionTuple.empty(pCollectionTuple.getPipeline());
            Map<String, PCollectionView<Map<Row, Iterable<Row>>>> sideInputs = Maps.newHashMap();
            Map<Integer, String> tagToKeyedTag = Maps.newHashMap();
            for (Map.Entry<TupleTag<?>, PCollection<?>> entry : pCollectionTuple.getAll().entrySet()) {
                String tag = entry.getKey().getId();
                PCollection<?> input = entry.getValue();
                Schema inputSchema = input.getSchema();
                FieldAccessDescriptor resolvedKeyFields = getFieldAccessDescriptor.apply(tag).resolve(inputSchema);
                // A fresh TupleTag suffix makes the keyed tag distinct from the user-visible tag.
                String keyedTag = tag + "_" + new TupleTag<>();
                tagToKeyedTag.put(sortedTags.indexOf(tag), keyedTag);
                PCollection<KV<Row, Row>> keyed = extractKey(input, inputSchema, keySchema, resolvedKeyFields, tag);
                if (getIsSideInput.apply(tag)) {
                    sideInputs.put(keyedTag, keyed.apply("computeSideInputView" + tag, View.asMultimap()));
                } else {
                    keyedInputs = keyedInputs.and(keyedTag, keyed);
                }
            }
            return new JoinInformation(keyedInputs, sideInputs, keySchema, componentSchemas, sortedTags, tagToKeyedTag);
        }

        /**
         * Maps each row of {@code pCollection} to {@code KV(selected key fields, row)}.
         *
         * @param schema schema of the input rows
         * @param schema2 schema used for the key coder
         * @param fieldAccessDescriptor resolved key fields of {@code schema}
         * @param str input tag, used to name the step
         */
        private static <T> PCollection<KV<Row, Row>> extractKey(PCollection<T> pCollection, final Schema schema, Schema schema2, final FieldAccessDescriptor fieldAccessDescriptor, String str) {
            return pCollection
                    .apply("extractKey" + str, ParDo.of(new DoFn<T, KV<Row, Row>>() {
                        // Was "Schema.this" in the decompiled source, which does not compile; the
                        // selector is built over the input schema parameter.
                        private final RowSelector rowSelector = new SelectHelpers.RowSelectorContainer(schema, fieldAccessDescriptor, true);

                        @DoFn.ProcessElement
                        public void process(@DoFn.Element Row row, DoFn.OutputReceiver<KV<Row, Row>> outputReceiver) {
                            outputReceiver.output(KV.of(this.rowSelector.select(row), row));
                        }
                    }))
                    .setCoder(KvCoder.of(SchemaCoder.of(schema2), SchemaCoder.of(schema)));
        }
    }

    /**
     * The rows grouped under one join key: one iterable per input, in sorted tag order, plus what is
     * needed to turn them into output rows.
     */
    @AutoValue
    public static abstract class Result {
        /** Join key shared by all grouped rows. */
        public abstract Row getKey();

        /** Matching rows per input, positionally aligned with {@link #getTags()}. */
        public abstract List<Iterable<Row>> getIterables();

        /** Input tags in sorted order. */
        public abstract List<String> getTags();

        public abstract JoinArguments getJoinArguments();

        /** Schema of the rows emitted by {@link #outputExpandedRows}. */
        public abstract Schema getOutputSchema();

        /** Builds a result from a CoGroupByKey result; main inputs are looked up by keyed tag. */
        static Result from(JoinInformation joinInformation, JoinArguments joinArguments, Row row, Schema schema, CoGbkResult coGbkResult, DoFn<?, Row>.ProcessContext processContext) {
            // The bound method reference already throws NPE for a null coGbkResult.
            return from(joinInformation, joinArguments, row, schema, (Function<String, Iterable<Row>>) coGbkResult::getAll, processContext);
        }

        /** Builds a result when there is a single main input; every main-input tag yields {@code row2}. */
        static Result from(JoinInformation joinInformation, JoinArguments joinArguments, Row row, Schema schema, Row row2, DoFn<?, Row>.ProcessContext processContext) {
            return from(joinInformation, joinArguments, row, schema, (Function<String, Iterable<Row>>) tag -> Lists.newArrayList(row2), processContext);
        }

        /**
         * Gathers the rows for {@code row} (the key) from every input: side inputs are read from their
         * multimap view, main inputs from {@code function}. Missing entries become empty iterables.
         */
        private static Result from(JoinInformation joinInformation, JoinArguments joinArguments, Row row, Schema schema, Function<String, Iterable<Row>> function, DoFn<?, Row>.ProcessContext processContext) {
            int numTags = joinInformation.sortedTags.size();
            List<Iterable<Row>> iterables = Lists.newArrayListWithCapacity(numTags);
            List<String> tags = Lists.newArrayListWithCapacity(numTags);
            for (int i = 0; i < numTags; i++) {
                String keyedTag = joinInformation.tagToKeyedTag.get(i);
                PCollectionView<Map<Row, Iterable<Row>>> sideInputView = joinInformation.sideInputs.get(keyedTag);
                Iterable<Row> rows = sideInputView != null
                        ? processContext.sideInput(sideInputView).get(row)
                        : function.apply(keyedTag);
                if (rows == null) {
                    rows = Collections::emptyIterator;
                }
                iterables.add(rows);
                tags.add(joinInformation.sortedTags.get(i));
            }
            return new AutoValue_CoGroup_Result(row, iterables, tags, joinArguments, schema);
        }

        /** Schema with a row field {@code str} for the key and one iterable-of-row field per input tag. */
        static Schema getUnexandedOutputSchema(String str, JoinInformation joinInformation) {
            Schema.Builder builder = Schema.builder().addRowField(str, joinInformation.keySchema);
            for (Map.Entry<String, Schema> entry : joinInformation.componentSchemas.entrySet()) {
                builder.addIterableField(entry.getKey(), Schema.FieldType.row(entry.getValue()));
            }
            return builder.build();
        }

        /** Emits one row: the key followed by each input's iterable of rows. */
        void outputUnexpandedRow(Schema schema, DoFn.OutputReceiver<Row> outputReceiver) {
            List<Object> values = new ArrayList<>(getIterables().size() + 1);
            values.add(getKey());
            values.addAll(getIterables());
            outputReceiver.output(Row.withSchema(schema).attachValues(values));
        }

        /**
         * Rejects expanded joins that have a side input while every main input is optional.
         *
         * @throws IllegalArgumentException for that configuration
         */
        static void verifyExpandedArgs(JoinInformation joinInformation, JoinArguments joinArguments) {
            boolean hasSideInput = false;
            boolean allMainInputsOptional = true;
            for (int i = 0; i < joinInformation.sortedTags.size(); i++) {
                if (joinInformation.sideInputs.get(joinInformation.tagToKeyedTag.get(i)) != null) {
                    hasSideInput = true;
                } else if (!joinArguments.getOptionalParticipation(joinInformation.sortedTags.get(i))) {
                    allMainInputsOptional = false;
                }
            }
            Preconditions.checkArgument(!(hasSideInput && allMainInputsOptional), "Cannot perform join when all main inputs are optional and there is a side input.  consider removing the side input.");
        }

        /** Schema with one row field per input tag; fields of optional inputs are nullable. */
        static Schema getExpandedOutputSchema(JoinInformation joinInformation, JoinArguments joinArguments) {
            Schema.Builder builder = Schema.builder();
            for (Map.Entry<String, Schema> entry : joinInformation.componentSchemas.entrySet()) {
                Schema.FieldType fieldType = Schema.FieldType.row(entry.getValue());
                if (joinArguments.getOptionalParticipation(entry.getKey())) {
                    fieldType = fieldType.withNullable(true);
                }
                builder.addField(entry.getKey(), fieldType);
            }
            return builder.build();
        }

        /** Emits one row per combination in the cross product of the per-input rows. */
        void outputExpandedRows(DoFn.OutputReceiver<Row> outputReceiver) {
            List<Row> accumulated = Lists.newArrayListWithCapacity(getIterables().size());
            crossProduct(0, accumulated, extractIterables(), outputReceiver);
        }

        /** Replaces the empty iterable of an optional input with a single null row. */
        private List<Iterable<Row>> extractIterables() {
            List<Iterable<Row>> result = Lists.newArrayListWithCapacity(getIterables().size());
            for (int i = 0; i < getIterables().size(); i++) {
                Iterable<Row> iterable = getIterables().get(i);
                if (!iterable.iterator().hasNext() && getJoinArguments().getOptionalParticipation(getTags().get(i))) {
                    iterable = CoGroup.NULL_LIST;
                }
                result.add(iterable);
            }
            return result;
        }

        /** Recurses over the rows of input {@code i}, accumulating one row per input in {@code list}. */
        private void crossProduct(int i, List<Row> list, List<Iterable<Row>> list2, DoFn.OutputReceiver<Row> outputReceiver) {
            if (i >= list2.size()) {
                return;
            }
            for (Row row : list2.get(i)) {
                crossProductHelper(i, list, row, list2, outputReceiver);
            }
        }

        /** Adds {@code row}, emits when every input has contributed, otherwise recurses; then backtracks. */
        private void crossProductHelper(int i, List<Row> list, Row row, List<Iterable<Row>> list2, DoFn.OutputReceiver<Row> outputReceiver) {
            list.add(row);
            if (i == list2.size() - 1) {
                // Copy: list is reused and mutated as the recursion backtracks.
                outputReceiver.output(Row.withSchema(getOutputSchema()).attachValues(new ArrayList<Object>(list)));
            } else {
                crossProduct(i + 1, list, list2, outputReceiver);
            }
            list.remove(list.size() - 1);
        }
    }

    /** Starts a join in which every input is keyed with the same {@link By}. */
    public static Impl join(By by) {
        JoinArguments sharedArgs = new JoinArguments(by);
        return new Impl(sharedArgs);
    }

    /** Starts a join with per-tag arguments, beginning with tag {@code str}. */
    public static Impl join(String str, By by) {
        Map<String, By> perTagArgs = ImmutableMap.of(str, by);
        return new Impl(new JoinArguments(perTagArgs));
    }

    /**
     * When per-tag join arguments are used, checks that they cover exactly the tags of the input.
     *
     * @throws IllegalArgumentException if the input tags and the configured tags differ
     */
    static void verify(PCollectionTuple pCollectionTuple, JoinArguments joinArguments) {
        if (joinArguments.allInputsJoinArgs != null) {
            // One By applies to every input, so there are no per-tag arguments to match up.
            return;
        }
        Set<String> set = pCollectionTuple.getAll().keySet().stream()
                .map(TupleTag::getId)
                .collect(Collectors.toSet());
        Set<String> keySet = joinArguments.joinArgsMap.keySet();
        if (!set.equals(keySet)) {
            throw new IllegalArgumentException("The input PCollectionTuple has tags: " + set + " and the join was specified for tags " + keySet + ". These do not match.");
        }
    }

    static {
        // NULL_LIST must contain exactly one null, so iterating it yields a single null row.
        Row absentRow = null;
        NULL_LIST.add(absentRow);
    }
}
