package org.apache.beam.sdk.schemas.transforms;

import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.TreeMap;
import java.util.stream.Collectors;
import javax.annotation.Nullable;
import org.apache.beam.repackaged.beam_sdks_java_core.com.google.common.collect.ImmutableMap;
import org.apache.beam.repackaged.beam_sdks_java_core.com.google.common.collect.Lists;
import org.apache.beam.repackaged.beam_sdks_java_core.com.google.common.collect.Maps;
import org.apache.beam.sdk.coders.KvCoder;
import org.apache.beam.sdk.schemas.FieldAccessDescriptor;
import org.apache.beam.sdk.schemas.Schema;
import org.apache.beam.sdk.schemas.SchemaCoder;
import org.apache.beam.sdk.transforms.DoFn;
import org.apache.beam.sdk.transforms.PTransform;
import org.apache.beam.sdk.transforms.ParDo;
import org.apache.beam.sdk.transforms.SerializableFunction;
import org.apache.beam.sdk.transforms.join.CoGbkResult;
import org.apache.beam.sdk.transforms.join.CoGroupByKey;
import org.apache.beam.sdk.transforms.join.KeyedPCollectionTuple;
import org.apache.beam.sdk.values.KV;
import org.apache.beam.sdk.values.PCollection;
import org.apache.beam.sdk.values.PCollectionTuple;
import org.apache.beam.sdk.values.Row;
import org.apache.beam.sdk.values.TupleTag;

/**
 * Entry points for building a schema-aware co-group transform over a {@link PCollectionTuple}.
 *
 * <p>The static factories either select the same key fields for every input
 * ({@link #byFieldNames(String...)}, {@link #byFieldIds(Integer...)},
 * {@link #byFieldAccessDescriptor(FieldAccessDescriptor)}) or start a per-tag configuration
 * ({@link #byFieldNames(TupleTag, String...)} and friends) that can be extended by chaining the
 * instance methods of the same name on {@link Inner}.
 */
public class CoGroup {

    /**
     * The transform produced by the {@link CoGroup} factories.
     *
     * <p>Expansion extracts a key row from every input, co-groups the inputs by that key and
     * emits one {@code KV<Row, Row>} per key, where the value row has one array-of-rows field per
     * input, named after the input's tag id and ordered by tag id.
     *
     * <p>Instances are immutable: the per-tag builder methods return a new {@code Inner}.
     */
    public static class Inner extends PTransform<PCollectionTuple, PCollection<KV<Row, Row>>> {

        /** Key fields applied to every input, or {@code null} when keys are configured per tag. */
        @Nullable
        private final FieldAccessDescriptor allInputsFieldAccessDescriptor;
        /** Key fields configured per input tag; empty when a global descriptor is used. */
        private final Map<TupleTag<?>, FieldAccessDescriptor> fieldAccessDescriptorMap;

        /**
         * Converts the grouped result for one key into a single row of the joined schema.
         *
         * <p>For each tag in {@code sortedTags} (in order) the grouped elements are passed through
         * the to-row function registered under the tag's id and collected into a list; the lists
         * become the field values of the output row, in the same order.
         */
        public static class ConvertToRow extends DoFn<KV<Row, CoGbkResult>, KV<Row, Row>> {
            List<TupleTag<Row>> sortedTags;
            Map<String, SerializableFunction<Object, Row>> toRows;
            Schema joinedSchema;

            /**
             * @param list tags to read from the grouped result, in output-field order
             * @param map to-row functions keyed by tag id (the ids of the tags in {@code list})
             * @param schema schema of the value row that is emitted
             */
            public ConvertToRow(List<TupleTag<Row>> list, Map<String, SerializableFunction<Object, Row>> map, Schema schema) {
                this.sortedTags = list;
                this.toRows = map;
                this.joinedSchema = schema;
            }

            /**
             * Emits {@code KV(key, joinedRow)} for one grouped element.
             *
             * @param kv the key row together with the co-grouped values for that key
             * @param outputReceiver receiver for the converted element
             */
            @DoFn.ProcessElement
            public void process(@DoFn.Element KV<Row, CoGbkResult> kv, DoFn.OutputReceiver<KV<Row, Row>> outputReceiver) {
                Row key = kv.getKey();
                CoGbkResult grouped = kv.getValue();
                // Deliberately List<Object>: this keeps addValues(...) bound to the List overload so
                // that each per-tag list becomes one field value of the output row.
                List<Object> fieldValues = Lists.newArrayListWithExpectedSize(this.sortedTags.size());
                for (TupleTag<Row> tag : this.sortedTags) {
                    SerializableFunction<Object, Row> toRow = this.toRows.get(tag.getId());
                    List<Row> rowsForTag = Lists.newArrayList();
                    // Iterate as Object: the function accepts Object, and no element cast is needed.
                    for (Object item : grouped.getAll(tag)) {
                        rowsForTag.add(toRow.apply(item));
                    }
                    fieldValues.add(rowsForTag);
                }
                outputReceiver.output(KV.of(key, Row.withSchema(this.joinedSchema).addValues(fieldValues).build()));
            }
        }

        /** Creates a transform with no key fields configured yet (per-tag mode). */
        private Inner() {
            // Explicit type arguments: a cast would not give emptyMap() a target type to infer from.
            this(Collections.<TupleTag<?>, FieldAccessDescriptor>emptyMap());
        }

        /** Creates a transform whose key fields are configured per input tag. */
        private Inner(Map<TupleTag<?>, FieldAccessDescriptor> map) {
            this.allInputsFieldAccessDescriptor = null;
            this.fieldAccessDescriptorMap = map;
        }

        /** Creates a transform that uses the same key fields for every input. */
        private Inner(FieldAccessDescriptor fieldAccessDescriptor) {
            this.allInputsFieldAccessDescriptor = fieldAccessDescriptor;
            this.fieldAccessDescriptorMap = Collections.emptyMap();
        }

        /**
         * Returns a new transform that additionally keys the input {@code tupleTag} by the named fields.
         *
         * @throws IllegalStateException if this transform was created with a global key selection
         */
        public Inner byFieldNames(TupleTag<?> tupleTag, String... strArr) {
            return byFieldAccessDescriptor(tupleTag, FieldAccessDescriptor.withFieldNames(strArr));
        }

        /**
         * Returns a new transform that additionally keys the input {@code tupleTag} by the given field ids.
         *
         * @throws IllegalStateException if this transform was created with a global key selection
         */
        public Inner byFieldIds(TupleTag<?> tupleTag, Integer... numArr) {
            return byFieldAccessDescriptor(tupleTag, FieldAccessDescriptor.withFieldIds(numArr));
        }

        /**
         * Returns a new transform that additionally keys the input {@code tupleTag} by the fields
         * selected by {@code fieldAccessDescriptor}. This instance is left unchanged.
         *
         * @throws IllegalStateException if this transform was created with a global key selection
         */
        public Inner byFieldAccessDescriptor(TupleTag<?> tupleTag, FieldAccessDescriptor fieldAccessDescriptor) {
            if (this.allInputsFieldAccessDescriptor != null) {
                throw new IllegalStateException("Cannot set both a global and per-tag fields.");
            }
            Map<TupleTag<?>, FieldAccessDescriptor> extended =
                    new ImmutableMap.Builder<TupleTag<?>, FieldAccessDescriptor>()
                            .putAll(this.fieldAccessDescriptorMap)
                            .put(tupleTag, fieldAccessDescriptor)
                            .build();
            return new Inner(extended);
        }

        /**
         * Returns the key selection for {@code tupleTag}: the global descriptor when one is set,
         * otherwise the per-tag entry, which is {@code null} when the tag was never configured.
         */
        @Nullable
        private FieldAccessDescriptor getFieldAccessDescriptor(TupleTag<?> tupleTag) {
            if (this.allInputsFieldAccessDescriptor != null) {
                return this.allInputsFieldAccessDescriptor;
            }
            return this.fieldAccessDescriptorMap.get(tupleTag);
        }

        /**
         * Keys every input by its configured fields, co-groups them and converts each group to a row.
         *
         * @param pCollectionTuple the tagged inputs to co-group
         * @return one {@code KV<Row, Row>} per key; the value row has one array field per input tag id
         * @throws IllegalStateException if an input has no key fields configured, or if the key
         *     schemas of the inputs do not have equal types
         */
        @Override // org.apache.beam.sdk.transforms.PTransform
        public PCollection<KV<Row, Row>> expand(PCollectionTuple pCollectionTuple) {
            KeyedPCollectionTuple<Row> keyedInputs = KeyedPCollectionTuple.empty(pCollectionTuple.getPipeline());

            // Tags used to read the grouped result, ordered by the original tag id. This matches the
            // field order of the joined schema below, which is also sorted by tag id.
            List<TupleTag<Row>> sortedTags = pCollectionTuple.getAll().keySet().stream()
                    .sorted(Comparator.comparing(TupleTag::getId))
                    .map(tag -> new TupleTag<Row>(tag.getId() + "_ROW"))
                    .collect(Collectors.toList());

            TreeMap<String, Schema> componentSchemas = Maps.newTreeMap();
            Map<String, SerializableFunction<Object, Row>> toRows = Maps.newHashMap();
            Schema keySchema = null;
            for (Map.Entry<TupleTag<?>, PCollection<?>> entry : pCollectionTuple.getAll().entrySet()) {
                TupleTag<?> tag = entry.getKey();
                PCollection<?> input = entry.getValue();
                Schema inputSchema = input.getSchema();
                componentSchemas.put(tag.getId(), inputSchema);

                TupleTag<Row> rowTag = new TupleTag<>(tag.getId() + "_ROW");
                // The element type parameter is erased at runtime; ConvertToRow stores the function
                // as one accepting Object.
                @SuppressWarnings("unchecked")
                SerializableFunction<Object, Row> toRow = (SerializableFunction<Object, Row>) input.getToRowFunction();
                toRows.put(rowTag.getId(), toRow);

                FieldAccessDescriptor keyFields = getFieldAccessDescriptor(tag);
                if (keyFields == null) {
                    throw new IllegalStateException("No fields were set for input " + tag);
                }
                FieldAccessDescriptor resolvedKeyFields = keyFields.withOrderByFieldInsertionOrder().resolve(inputSchema);
                Schema currentKeySchema = Select.getOutputSchema(inputSchema, resolvedKeyFields);
                if (keySchema == null) {
                    // The first input processed defines the key schema used for all inputs.
                    keySchema = currentKeySchema;
                } else if (!currentKeySchema.typesEqual(keySchema)) {
                    throw new IllegalStateException("All keys must have the same schema");
                }
                keyedInputs = keyedInputs.and(rowTag, extractKey(input, inputSchema, keySchema, resolvedKeyFields, tag.getId()));
            }

            Schema.Builder joinedSchemaBuilder = Schema.builder();
            for (Map.Entry<String, Schema> component : componentSchemas.entrySet()) {
                joinedSchemaBuilder.addArrayField(component.getKey(), Schema.FieldType.row(component.getValue()));
            }
            Schema joinedSchema = joinedSchemaBuilder.build();

            return keyedInputs
                    .apply("CoGroupByKey", CoGroupByKey.<Row>create())
                    .apply("ConvertToRow", ParDo.of(new ConvertToRow(sortedTags, toRows, joinedSchema)))
                    .setCoder(KvCoder.of(SchemaCoder.of(keySchema), SchemaCoder.of(joinedSchema)));
        }

        /**
         * Pairs every element of {@code pCollection}, viewed as a row, with its key row.
         *
         * @param pCollection the input to key
         * @param schema schema of the input rows
         * @param schema2 schema of the extracted key rows
         * @param fieldAccessDescriptor the fields of {@code schema} that make up the key
         * @param str suffix for the step name, keeping it distinct per input
         * @return {@code KV(keyRow, row)} for each input element
         */
        private static <T> PCollection<KV<Row, Row>> extractKey(PCollection<T> pCollection, final Schema schema, final Schema schema2, final FieldAccessDescriptor fieldAccessDescriptor, String str) {
            return pCollection
                    .apply("extractKey" + str, ParDo.of(new DoFn<T, KV<Row, Row>>() {
                        @DoFn.ProcessElement
                        public void process(@DoFn.Element Row row, DoFn.OutputReceiver<KV<Row, Row>> outputReceiver) {
                            Row keyRow = Select.selectRow(row, fieldAccessDescriptor, schema, schema2);
                            outputReceiver.output(KV.of(keyRow, row));
                        }
                    }))
                    .setCoder(KvCoder.of(SchemaCoder.of(schema2), SchemaCoder.of(schema)));
        }
    }

    /** Returns a transform that keys every input by the fields with the given names. */
    public static Inner byFieldNames(String... strArr) {
        return byFieldAccessDescriptor(FieldAccessDescriptor.withFieldNames(strArr));
    }

    /** Returns a transform that keys every input by the fields with the given ids. */
    public static Inner byFieldIds(Integer... numArr) {
        return byFieldAccessDescriptor(FieldAccessDescriptor.withFieldIds(numArr));
    }

    /** Returns a transform that keys every input by the fields selected by {@code fieldAccessDescriptor}. */
    public static Inner byFieldAccessDescriptor(FieldAccessDescriptor fieldAccessDescriptor) {
        return new Inner(fieldAccessDescriptor);
    }

    /** Returns a transform that keys the input {@code tupleTag} by the named fields; chain to add more inputs. */
    public static Inner byFieldNames(TupleTag<?> tupleTag, String... strArr) {
        return byFieldAccessDescriptor(tupleTag, FieldAccessDescriptor.withFieldNames(strArr));
    }

    /** Returns a transform that keys the input {@code tupleTag} by the given field ids; chain to add more inputs. */
    public static Inner byFieldIds(TupleTag<?> tupleTag, Integer... numArr) {
        return byFieldAccessDescriptor(tupleTag, FieldAccessDescriptor.withFieldIds(numArr));
    }

    /** Returns a transform that keys the input {@code tupleTag} by the given descriptor; chain to add more inputs. */
    public static Inner byFieldAccessDescriptor(TupleTag<?> tupleTag, FieldAccessDescriptor fieldAccessDescriptor) {
        return new Inner().byFieldAccessDescriptor(tupleTag, fieldAccessDescriptor);
    }
}
