/*
 * Decompiled with CFR 0.152.
 */
package io.prestosql.sql.planner.iterative.rule;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableListMultimap;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ListMultimap;
import io.prestosql.Session;
import io.prestosql.SystemSessionProperties;
import io.prestosql.matching.Capture;
import io.prestosql.matching.Captures;
import io.prestosql.matching.Pattern;
import io.prestosql.sql.planner.Symbol;
import io.prestosql.sql.planner.iterative.Rule;
import io.prestosql.sql.planner.optimizations.SymbolMapper;
import io.prestosql.sql.planner.plan.Patterns;
import io.prestosql.sql.planner.plan.PlanNode;
import io.prestosql.sql.planner.plan.TableWriterNode;
import io.prestosql.sql.planner.plan.UnionNode;
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import java.util.Map;

public class PushTableWriteThroughUnion
implements Rule<TableWriterNode> {
    private static final Capture<UnionNode> CHILD = Capture.newCapture();
    private static final Pattern<TableWriterNode> PATTERN = Patterns.tableWriterNode().matching(tableWriter -> !tableWriter.getPartitioningScheme().isPresent()).with(Patterns.source().matching(Patterns.union().capturedAs(CHILD)));

    @Override
    public Pattern<TableWriterNode> getPattern() {
        return PATTERN;
    }

    @Override
    public boolean isEnabled(Session session) {
        return SystemSessionProperties.isPushTableWriteThroughUnion(session);
    }

    @Override
    public Rule.Result apply(TableWriterNode writerNode, Captures captures, Rule.Context context) {
        UnionNode unionNode = (UnionNode)captures.get(CHILD);
        ImmutableList.Builder rewrittenSources = ImmutableList.builder();
        ArrayList<Map<Symbol, Symbol>> sourceMappings = new ArrayList<Map<Symbol, Symbol>>();
        for (int source = 0; source < unionNode.getSources().size(); ++source) {
            rewrittenSources.add((Object)PushTableWriteThroughUnion.rewriteSource(writerNode, unionNode, source, sourceMappings, context));
        }
        ImmutableListMultimap.Builder unionMappings = ImmutableListMultimap.builder();
        sourceMappings.forEach(mappings -> mappings.forEach((arg_0, arg_1) -> ((ImmutableListMultimap.Builder)unionMappings).put(arg_0, arg_1)));
        return Rule.Result.ofPlanNode(new UnionNode(context.getIdAllocator().getNextId(), (List<PlanNode>)rewrittenSources.build(), (ListMultimap<Symbol, Symbol>)unionMappings.build(), (List<Symbol>)ImmutableList.copyOf((Collection)unionMappings.build().keySet())));
    }

    private static TableWriterNode rewriteSource(TableWriterNode writerNode, UnionNode unionNode, int source, List<Map<Symbol, Symbol>> sourceMappings, Rule.Context context) {
        Map<Symbol, Symbol> inputMappings = PushTableWriteThroughUnion.getInputSymbolMapping(unionNode, source);
        ImmutableMap.Builder mappings = ImmutableMap.builder();
        mappings.putAll(inputMappings);
        ImmutableMap.Builder outputMappings = ImmutableMap.builder();
        for (Symbol outputSymbol : writerNode.getOutputSymbols()) {
            if (inputMappings.containsKey(outputSymbol)) {
                outputMappings.put((Object)outputSymbol, (Object)inputMappings.get(outputSymbol));
                continue;
            }
            Symbol newSymbol = context.getSymbolAllocator().newSymbol(outputSymbol);
            outputMappings.put((Object)outputSymbol, (Object)newSymbol);
            mappings.put((Object)outputSymbol, (Object)newSymbol);
        }
        sourceMappings.add((Map<Symbol, Symbol>)outputMappings.build());
        SymbolMapper symbolMapper = new SymbolMapper((Map<Symbol, Symbol>)mappings.build());
        return symbolMapper.map(writerNode, unionNode.getSources().get(source), context.getIdAllocator().getNextId());
    }

    private static Map<Symbol, Symbol> getInputSymbolMapping(UnionNode node, int source) {
        return (Map)node.getSymbolMapping().keySet().stream().collect(ImmutableMap.toImmutableMap(key -> key, key -> (Symbol)node.getSymbolMapping().get(key).get(source)));
    }
}

