package org.apache.beam.runners.spark.structuredstreaming.translation.batch;

import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.stream.Collectors;
import org.apache.beam.runners.spark.structuredstreaming.translation.helpers.EncoderHelpers;
import org.apache.beam.sdk.coders.Coder;
import org.apache.beam.sdk.coders.CoderException;
import org.apache.beam.sdk.coders.IterableCoder;
import org.apache.beam.sdk.transforms.Combine;
import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
import org.apache.beam.sdk.transforms.windowing.PaneInfo;
import org.apache.beam.sdk.transforms.windowing.TimestampCombiner;
import org.apache.beam.sdk.transforms.windowing.WindowFn;
import org.apache.beam.sdk.util.CoderUtils;
import org.apache.beam.sdk.util.WindowedValue;
import org.apache.beam.sdk.values.KV;
import org.apache.beam.sdk.values.WindowingStrategy;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Iterables;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Lists;
import org.apache.spark.sql.Encoder;
import org.apache.spark.sql.expressions.Aggregator;
import org.joda.time.Instant;
import scala.Tuple2;

/* loaded from: input_file:org/apache/beam/runners/spark/structuredstreaming/translation/batch/AggregatorCombiner.class */
/**
 * Spark {@link Aggregator} that applies a Beam {@link Combine.CombineFn} while honouring the Beam
 * {@link WindowingStrategy}.
 *
 * <p>The aggregation buffer is a collection of {@link WindowedValue windowed accumulators}. After
 * {@link #merge} it holds exactly one accumulator per (possibly merged) window. {@link #reduce}
 * turns each input element into a single-element buffer and folds it in with the same merge logic
 * used to combine partial buffers.
 *
 * @param <K> key type of the input {@link KV}s (only the KV value is read here)
 * @param <InputT> value type fed to the CombineFn
 * @param <AccumT> CombineFn accumulator type
 * @param <OutputT> CombineFn output type
 * @param <W> window type of the windowing strategy
 */
class AggregatorCombiner<K, InputT, AccumT, OutputT, W extends BoundedWindow>
        extends Aggregator<
                WindowedValue<KV<K, InputT>>,
                Iterable<WindowedValue<AccumT>>,
                Iterable<WindowedValue<OutputT>>> {
    private final Combine.CombineFn<InputT, AccumT, OutputT> combineFn;
    private final WindowingStrategy<InputT, W> windowingStrategy;
    private final TimestampCombiner timestampCombiner;
    private final Coder<AccumT> accumulatorCoder;
    private final IterableCoder<WindowedValue<AccumT>> bufferEncoder;
    private final IterableCoder<WindowedValue<OutputT>> outputCoder;

    /**
     * Records the decisions made by {@link WindowFn#mergeWindows} into a map from each merged
     * window to the window it was merged into.
     */
    private class MergeContextImpl extends WindowFn<InputT, W>.MergeContext {
        private final Set<W> windows;
        private final Map<W, W> windowToMergeResult;

        /**
         * @param windowFn the window function whose merge context this is
         * @param set the windows offered to the window function for merging
         * @param map output map; receives one entry per window that gets merged
         */
        MergeContextImpl(WindowFn<InputT, W> windowFn, Set<W> set, Map<W, W> map) {
            // MergeContext is an inner class of WindowFn, so it needs a qualified super call.
            windowFn.super();
            this.windows = set;
            this.windowToMergeResult = map;
        }

        @Override
        public Collection<W> windows() {
            return windows;
        }

        @Override
        public void merge(Collection<W> collection, W w) throws Exception {
            for (W toBeMerged : collection) {
                windowToMergeResult.put(toBeMerged, w);
            }
        }
    }

    /**
     * @param combineFn the Beam combine function to apply
     * @param windowingStrategy windowing strategy of the input; its window type must be {@code W}
     * @param coder coder of the accumulators, used for the buffer encoder and accumulator copies
     * @param coder2 coder of the combine output, used for the output encoder
     */
    @SuppressWarnings("unchecked") // callers supply a strategy whose types match InputT / W
    public AggregatorCombiner(
            Combine.CombineFn<InputT, AccumT, OutputT> combineFn,
            WindowingStrategy<?, ?> windowingStrategy,
            Coder<AccumT> coder,
            Coder<OutputT> coder2) {
        this.combineFn = combineFn;
        this.windowingStrategy = (WindowingStrategy<InputT, W>) windowingStrategy;
        this.timestampCombiner = windowingStrategy.getTimestampCombiner();
        this.accumulatorCoder = coder;
        Coder<W> windowCoder = this.windowingStrategy.getWindowFn().windowCoder();
        this.bufferEncoder =
                IterableCoder.of(WindowedValue.FullWindowedValueCoder.of(coder, windowCoder));
        this.outputCoder =
                IterableCoder.of(WindowedValue.FullWindowedValueCoder.of(coder2, windowCoder));
    }

    /** Returns a new, empty aggregation buffer. */
    @Override
    public Iterable<WindowedValue<AccumT>> zero() {
        return new ArrayList<>();
    }

    /**
     * Decompiler-generated alias of {@link #zero()}.
     *
     * @deprecated use {@link #zero()}; kept only for source compatibility.
     */
    @Deprecated
    public Iterable<WindowedValue<AccumT>> m65zero() {
        return zero();
    }

    /**
     * Wraps one input element into a single-entry buffer. The entry holds a fresh accumulator with
     * the element's value added, and keeps the element's timestamp, windows and pane.
     */
    private Iterable<WindowedValue<AccumT>> createAccumulator(
            WindowedValue<KV<K, InputT>> windowedValue) {
        AccumT accumulator =
                combineFn.addInput(
                        combineFn.createAccumulator(), windowedValue.getValue().getValue());
        return Collections.singletonList(
                WindowedValue.of(
                        accumulator,
                        windowedValue.getTimestamp(),
                        windowedValue.getWindows(),
                        windowedValue.getPane()));
    }

    /** Folds one input element into the buffer by merging it as a single-entry buffer. */
    @Override
    public Iterable<WindowedValue<AccumT>> reduce(
            Iterable<WindowedValue<AccumT>> iterable, WindowedValue<KV<K, InputT>> windowedValue) {
        return merge(iterable, createAccumulator(windowedValue));
    }

    /**
     * Merges two buffers into one that holds a single accumulator per (merged) window.
     *
     * <p>Steps:
     * <ol>
     *   <li>Windows of all accumulators are merged with the strategy's {@link WindowFn}.
     *   <li>Each accumulator is assigned to the merged window of every window it belongs to.
     *   <li>Per merged window, the accumulators are combined with
     *       {@link Combine.CombineFn#mergeAccumulators} and the timestamps with the
     *       {@link TimestampCombiner}.
     * </ol>
     *
     * @throws RuntimeException if window merging fails, or if an accumulator cannot be
     *     encoded/decoded with the accumulator coder
     */
    @Override
    public Iterable<WindowedValue<AccumT>> merge(
            Iterable<WindowedValue<AccumT>> iterable, Iterable<WindowedValue<AccumT>> iterable2) {
        Iterable<WindowedValue<AccumT>> accumulators = Iterables.concat(iterable, iterable2);

        Map<W, W> windowToMergeResult;
        try {
            windowToMergeResult =
                    mergeWindows(windowingStrategy, collectAccumulatorsWindows(accumulators));
        } catch (Exception e) {
            throw new RuntimeException("Unable to merge accumulators windows", e);
        }

        Map<W, List<Tuple2<AccumT, Instant>>> byMergedWindow = new HashMap<>();
        for (WindowedValue<AccumT> accumulatorWv : accumulators) {
            // A value in several windows contributes to each of them. Each window gets its own
            // copy (a coder round-trip) instead of sharing one AccumT instance across windows.
            byte[] encoded =
                    accumulatorWv.getWindows().size() > 1
                            ? encodeAccumulator(accumulatorWv.getValue())
                            : null;
            for (BoundedWindow window : accumulatorWv.getWindows()) {
                W mergedWindow = mergedWindowFor(windowToMergeResult, window);
                AccumT accumulator =
                        encoded == null
                                ? accumulatorWv.getValue()
                                : decodeAccumulator(encoded, accumulatorWv.getValue());
                Instant timestamp =
                        timestampCombiner.assign(mergedWindow, accumulatorWv.getTimestamp());
                byMergedWindow
                        .computeIfAbsent(mergedWindow, ignored -> new ArrayList<>())
                        .add(new Tuple2<>(accumulator, timestamp));
            }
        }

        List<WindowedValue<AccumT>> result = new ArrayList<>(byMergedWindow.size());
        for (Map.Entry<W, List<Tuple2<AccumT, Instant>>> entry : byMergedWindow.entrySet()) {
            result.add(combineWindow(entry.getKey(), entry.getValue()));
        }
        return result;
    }

    /**
     * Returns the window {@code window} was merged into. If it was not merged, returns
     * {@code window} itself.
     */
    @SuppressWarnings("unchecked") // windows of the accumulators are of the strategy's type W
    private W mergedWindowFor(Map<W, W> windowToMergeResult, BoundedWindow window) {
        W merged = windowToMergeResult.get(window);
        return merged == null ? (W) window : merged;
    }

    /** Combines all accumulators and timestamps collected for one merged window. */
    private WindowedValue<AccumT> combineWindow(
            W window, List<Tuple2<AccumT, Instant>> accumsAndTimestamps) {
        List<AccumT> accums = new ArrayList<>(accumsAndTimestamps.size() + 1);
        // Lead with a fresh accumulator. Per the CombineFn.mergeAccumulators contract, only the
        // first argument may be modified/returned, and it must not be a buffered accumulator.
        accums.add(combineFn.createAccumulator());
        List<Instant> timestamps = new ArrayList<>(accumsAndTimestamps.size());
        for (Tuple2<AccumT, Instant> accumAndTimestamp : accumsAndTimestamps) {
            accums.add(accumAndTimestamp._1());
            timestamps.add(accumAndTimestamp._2());
        }
        return WindowedValue.of(
                combineFn.mergeAccumulators(accums),
                timestampCombiner.combine(timestamps),
                window,
                PaneInfo.NO_FIRING);
    }

    /** Serializes an accumulator with the accumulator coder. */
    private byte[] encodeAccumulator(AccumT accumulator) {
        try {
            return CoderUtils.encodeToByteArray(accumulatorCoder, accumulator);
        } catch (CoderException e) {
            throw new RuntimeException(
                    String.format(
                            "Unable to encode accumulator %s with coder %s.",
                            accumulator, accumulatorCoder),
                    e);
        }
    }

    /**
     * Deserializes a copy of an accumulator. {@code original} is used only in the error message.
     */
    private AccumT decodeAccumulator(byte[] bytes, AccumT original) {
        try {
            return CoderUtils.decodeFromByteArray(accumulatorCoder, bytes);
        } catch (CoderException e) {
            throw new RuntimeException(
                    String.format(
                            "Unable to decode accumulator %s with coder %s.",
                            original, accumulatorCoder),
                    e);
        }
    }

    /**
     * Extracts the combine output from every windowed accumulator, keeping its windowing
     * metadata.
     */
    @Override
    public Iterable<WindowedValue<OutputT>> finish(Iterable<WindowedValue<AccumT>> iterable) {
        List<WindowedValue<OutputT>> outputs = new ArrayList<>();
        for (WindowedValue<AccumT> accumulatorWv : iterable) {
            outputs.add(
                    accumulatorWv.withValue(combineFn.extractOutput(accumulatorWv.getValue())));
        }
        return outputs;
    }

    @Override
    public Encoder<Iterable<WindowedValue<AccumT>>> bufferEncoder() {
        return EncoderHelpers.fromBeamCoder(bufferEncoder);
    }

    @Override
    public Encoder<Iterable<WindowedValue<OutputT>>> outputEncoder() {
        return EncoderHelpers.fromBeamCoder(outputCoder);
    }

    /** Returns the set of all windows that any of the given accumulators belongs to. */
    @SuppressWarnings("unchecked") // windows of the accumulators are of the strategy's type W
    private Set<W> collectAccumulatorsWindows(Iterable<WindowedValue<AccumT>> iterable) {
        Set<W> windows = new HashSet<>();
        for (WindowedValue<AccumT> accumulatorWv : iterable) {
            for (BoundedWindow window : accumulatorWv.getWindows()) {
                windows.add((W) window);
            }
        }
        return windows;
    }

    /**
     * Asks the strategy's {@link WindowFn} to merge {@code set}.
     *
     * @return a map from each merged window to its merge result. It is empty when the strategy
     *     does not need merging; callers then treat every window as mapping to itself.
     * @throws Exception propagated from {@link WindowFn#mergeWindows}
     */
    private Map<W, W> mergeWindows(WindowingStrategy<InputT, W> windowingStrategy, Set<W> set)
            throws Exception {
        if (!windowingStrategy.needsMerge()) {
            return Collections.emptyMap();
        }
        WindowFn<InputT, W> windowFn = windowingStrategy.getWindowFn();
        Map<W, W> windowToMergeResult = new HashMap<>();
        windowFn.mergeWindows(new MergeContextImpl(windowFn, set, windowToMergeResult));
        return windowToMergeResult;
    }
}
