/*
 * Decompiled with CFR 0.152.
 */
package org.apache.flink.statefun.flink.state.processor.union;

import java.util.ArrayList;
import java.util.List;
import java.util.Objects;
import java.util.stream.Collectors;
import org.apache.flink.api.common.functions.RichFlatMapFunction;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.java.DataSet;
import org.apache.flink.api.java.operators.UnionOperator;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.statefun.flink.state.processor.BootstrapDataRouterProvider;
import org.apache.flink.statefun.flink.state.processor.union.BootstrapDataset;
import org.apache.flink.statefun.flink.state.processor.union.TaggedBootstrapData;
import org.apache.flink.statefun.flink.state.processor.union.TaggedBootstrapDataTypeInfo;
import org.apache.flink.statefun.sdk.Address;
import org.apache.flink.statefun.sdk.FunctionType;
import org.apache.flink.statefun.sdk.io.Router;
import org.apache.flink.util.Collector;
import org.apache.flink.util.Preconditions;

/**
 * Combines several registered {@link BootstrapDataset}s into a single Flink
 * {@link DataSet} of {@link TaggedBootstrapData}.
 *
 * <p>Every input dataset is passed through its own router; each routed message is wrapped in a
 * {@link TaggedBootstrapData} that carries the target {@link Address}, the payload, and the
 * position ("union index") of the originating dataset in the input list. The per-dataset results
 * are then unioned into one dataset.
 */
public final class BootstrapDatasetUnion {
    /**
     * Tags and unions all given bootstrap datasets.
     *
     * @param bootstrapDatasets the datasets to union; the position of a dataset in this list
     *     becomes the union index attached to every element routed from it
     * @return a single dataset containing the tagged elements of all input datasets
     * @throws NullPointerException if {@code bootstrapDatasets} is {@code null}
     * @throws IllegalArgumentException if {@code bootstrapDatasets} is empty
     */
    public static DataSet<TaggedBootstrapData> apply(List<BootstrapDataset<?>> bootstrapDatasets) {
        Objects.requireNonNull(bootstrapDatasets, "bootstrapDatasets");
        Preconditions.checkArgument(!bootstrapDatasets.isEmpty(), "At least one bootstrap dataset is required.");

        List<DataSet<TaggedBootstrapData>> unionBootstrapDataset = new ArrayList<>(bootstrapDatasets.size());
        // One shared type info, built from the payload types of ALL datasets, is attached to every
        // tagged dataset so that they share the same element type when unioned.
        TypeInformation<TaggedBootstrapData> unionTypeInfo = createUnionTypeInfo(bootstrapDatasets);

        int unionIndex = 0;
        for (BootstrapDataset<?> bootstrapDataset : bootstrapDatasets) {
            unionBootstrapDataset.add(toTaggedFlinkDataSet(bootstrapDataset, unionIndex, unionTypeInfo));
            ++unionIndex;
        }
        return unionTaggedBootstrapDataSets(unionBootstrapDataset);
    }

    /**
     * Builds the type information of the unioned dataset from the element type of each input
     * dataset, preserving list order (which is also the union index order used in {@link #apply}).
     */
    private static TypeInformation<TaggedBootstrapData> createUnionTypeInfo(List<BootstrapDataset<?>> bootstrapDatasets) {
        List<TypeInformation<?>> payloadTypeInfos = bootstrapDatasets.stream().map(bootstrapDataset -> bootstrapDataset.getDataSet().getType()).collect(Collectors.toList());
        return new TaggedBootstrapDataTypeInfo(payloadTypeInfos);
    }

    /**
     * Routes every element of a single bootstrap dataset and tags the routed messages with
     * {@code unionIndex}, declaring {@code unionTypeInfo} as the result type.
     */
    private static <T> DataSet<TaggedBootstrapData> toTaggedFlinkDataSet(BootstrapDataset<T> bootstrapDataset, int unionIndex, TypeInformation<TaggedBootstrapData> unionTypeInfo) {
        return bootstrapDataset.getDataSet().flatMap(new BootstrapRouterFlatMap<T>(bootstrapDataset.getRouterProvider(), unionIndex)).returns(unionTypeInfo);
    }

    /**
     * Unions the given datasets from left to right.
     *
     * @return the union of all datasets, the single dataset itself if the list has one element, or
     *     {@code null} if the list is empty ({@link #apply} rejects empty input before calling this)
     */
    private static DataSet<TaggedBootstrapData> unionTaggedBootstrapDataSets(List<DataSet<TaggedBootstrapData>> taggedBootstrapDatasets) {
        DataSet<TaggedBootstrapData> result = null;
        for (DataSet<TaggedBootstrapData> taggedBootstrapDataset : taggedBootstrapDatasets) {
            if (result == null) {
                // The first dataset seeds the accumulator.
                result = taggedBootstrapDataset;
            } else {
                result = result.union(taggedBootstrapDataset);
            }
        }
        return result;
    }

    /**
     * {@link Router.Downstream} that wraps every forwarded message in a {@link TaggedBootstrapData}
     * carrying the fixed union index, and emits it to the wrapped Flink {@link Collector}.
     */
    private static class TaggingBootstrapDataCollector<T>
    implements Router.Downstream<T> {
        private final Collector<TaggedBootstrapData> out;
        private final int unionIndex;

        TaggingBootstrapDataCollector(Collector<TaggedBootstrapData> out, int unionIndex) {
            this.out = Objects.requireNonNull(out, "out");
            this.unionIndex = unionIndex;
        }

        @Override
        public void forward(FunctionType functionType, String id, T message) {
            this.out.collect(new TaggedBootstrapData(new Address(functionType, id), message, this.unionIndex));
        }

        @Override
        public void forward(Address to, T message) {
            this.out.collect(new TaggedBootstrapData(to, message, this.unionIndex));
        }
    }

    /**
     * Flat map that feeds every input element to a {@link Router} and emits whatever the router
     * forwards as {@link TaggedBootstrapData} tagged with this function's union index.
     */
    private static class BootstrapRouterFlatMap<T>
    extends RichFlatMapFunction<T, TaggedBootstrapData> {
        private static final long serialVersionUID = 1L;
        private final BootstrapDataRouterProvider<T> routerProvider;
        private final int unionIndex;
        // Not serialized with the function; created from the provider in open().
        private transient Router<T> router;

        /**
         * @throws NullPointerException if {@code routerProvider} is {@code null}
         * @throws IllegalArgumentException if {@code unionIndex} is negative
         */
        BootstrapRouterFlatMap(BootstrapDataRouterProvider<T> routerProvider, int unionIndex) {
            this.routerProvider = Objects.requireNonNull(routerProvider, "routerProvider");
            Preconditions.checkArgument(unionIndex >= 0, "unionIndex must be non-negative, but was " + unionIndex);
            this.unionIndex = unionIndex;
        }

        @Override
        public void open(Configuration parameters) throws Exception {
            super.open(parameters);
            this.router = this.routerProvider.provide();
        }

        @Override
        public void flatMap(T data, Collector<TaggedBootstrapData> collector) throws Exception {
            this.router.route(data, new TaggingBootstrapDataCollector<>(collector, this.unionIndex));
        }
    }
}

