/*
 * Decompiled with CFR 0.152.
 */
package io.prestosql.sql.planner.optimizations;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableListMultimap;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Iterables;
import com.google.common.collect.ListMultimap;
import com.google.common.collect.Sets;
import io.prestosql.Session;
import io.prestosql.execution.warnings.WarningCollector;
import io.prestosql.metadata.Metadata;
import io.prestosql.spi.connector.ColumnHandle;
import io.prestosql.sql.planner.OrderingScheme;
import io.prestosql.sql.planner.PartitioningScheme;
import io.prestosql.sql.planner.PlanNodeIdAllocator;
import io.prestosql.sql.planner.Symbol;
import io.prestosql.sql.planner.SymbolAllocator;
import io.prestosql.sql.planner.SymbolsExtractor;
import io.prestosql.sql.planner.TypeAnalyzer;
import io.prestosql.sql.planner.TypeProvider;
import io.prestosql.sql.planner.iterative.rule.PruneTableScanColumns;
import io.prestosql.sql.planner.optimizations.PlanOptimizer;
import io.prestosql.sql.planner.optimizations.QueryCardinalityUtil;
import io.prestosql.sql.planner.plan.AggregationNode;
import io.prestosql.sql.planner.plan.ApplyNode;
import io.prestosql.sql.planner.plan.AssignUniqueId;
import io.prestosql.sql.planner.plan.Assignments;
import io.prestosql.sql.planner.plan.CorrelatedJoinNode;
import io.prestosql.sql.planner.plan.DeleteNode;
import io.prestosql.sql.planner.plan.DistinctLimitNode;
import io.prestosql.sql.planner.plan.ExceptNode;
import io.prestosql.sql.planner.plan.ExchangeNode;
import io.prestosql.sql.planner.plan.ExplainAnalyzeNode;
import io.prestosql.sql.planner.plan.FilterNode;
import io.prestosql.sql.planner.plan.GroupIdNode;
import io.prestosql.sql.planner.plan.IndexJoinNode;
import io.prestosql.sql.planner.plan.IndexSourceNode;
import io.prestosql.sql.planner.plan.IntersectNode;
import io.prestosql.sql.planner.plan.JoinNode;
import io.prestosql.sql.planner.plan.LimitNode;
import io.prestosql.sql.planner.plan.MarkDistinctNode;
import io.prestosql.sql.planner.plan.OffsetNode;
import io.prestosql.sql.planner.plan.OutputNode;
import io.prestosql.sql.planner.plan.PlanNode;
import io.prestosql.sql.planner.plan.ProjectNode;
import io.prestosql.sql.planner.plan.RowNumberNode;
import io.prestosql.sql.planner.plan.SemiJoinNode;
import io.prestosql.sql.planner.plan.SetOperationNode;
import io.prestosql.sql.planner.plan.SimplePlanRewriter;
import io.prestosql.sql.planner.plan.SortNode;
import io.prestosql.sql.planner.plan.SpatialJoinNode;
import io.prestosql.sql.planner.plan.StatisticAggregations;
import io.prestosql.sql.planner.plan.StatisticsWriterNode;
import io.prestosql.sql.planner.plan.TableFinishNode;
import io.prestosql.sql.planner.plan.TableScanNode;
import io.prestosql.sql.planner.plan.TableWriterNode;
import io.prestosql.sql.planner.plan.TopNNode;
import io.prestosql.sql.planner.plan.TopNRowNumberNode;
import io.prestosql.sql.planner.plan.UnionNode;
import io.prestosql.sql.planner.plan.UnnestNode;
import io.prestosql.sql.planner.plan.ValuesNode;
import io.prestosql.sql.planner.plan.WindowNode;
import io.prestosql.sql.tree.BooleanLiteral;
import io.prestosql.sql.tree.Expression;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.function.Function;
import java.util.stream.Collectors;

public class PruneUnreferencedOutputs
implements PlanOptimizer {
    private final Metadata metadata;
    private final TypeAnalyzer typeAnalyzer;

    /**
     * Creates the optimizer.
     *
     * @param metadata metadata passed on to the per-plan rewriter
     * @param typeAnalyzer type analyzer passed on to the per-plan rewriter
     * @throws NullPointerException if {@code metadata} or {@code typeAnalyzer} is null
     */
    public PruneUnreferencedOutputs(Metadata metadata, TypeAnalyzer typeAnalyzer) {
        Metadata checkedMetadata = Objects.requireNonNull(metadata, "metadata is null");
        TypeAnalyzer checkedTypeAnalyzer = Objects.requireNonNull(typeAnalyzer, "typeAnalyzer is null");
        this.metadata = checkedMetadata;
        this.typeAnalyzer = checkedTypeAnalyzer;
    }

    /**
     * Rewrites {@code plan}, removing outputs that no ancestor references.
     * The root is visited with an empty set of required symbols.
     *
     * @throws NullPointerException if plan, session, types, symbolAllocator or idAllocator is null
     */
    @Override
    public PlanNode optimize(PlanNode plan, Session session, TypeProvider types, SymbolAllocator symbolAllocator, PlanNodeIdAllocator idAllocator, WarningCollector warningCollector) {
        Objects.requireNonNull(plan, "plan is null");
        Objects.requireNonNull(session, "session is null");
        Objects.requireNonNull(types, "types is null");
        Objects.requireNonNull(symbolAllocator, "symbolAllocator is null");
        Objects.requireNonNull(idAllocator, "idAllocator is null");

        Rewriter rewriter = new Rewriter(this.metadata, types, this.typeAnalyzer, symbolAllocator, session);
        Set<Symbol> nothingRequired = ImmutableSet.of();
        return SimplePlanRewriter.rewriteWith(rewriter, plan, nothingRequired);
    }

    private static class Rewriter
    extends SimplePlanRewriter<Set<Symbol>> {
        private final Metadata metadata;
        private final TypeProvider types;
        private final TypeAnalyzer typeAnalyzer;
        private final SymbolAllocator symbolAllocator;
        private final Session session;

        /**
         * Captures the per-query state the visit methods need.
         * Arguments are stored as given; no null checks are performed here.
         */
        public Rewriter(Metadata metadata, TypeProvider types, TypeAnalyzer typeAnalyzer, SymbolAllocator symbolAllocator, Session session) {
            this.session = session;
            this.symbolAllocator = symbolAllocator;
            this.typeAnalyzer = typeAnalyzer;
            this.types = types;
            this.metadata = metadata;
        }

        /**
         * Requests every output of the source, ignoring what the parent asked for,
         * then lets the default rewrite rebuild this node.
         */
        @Override
        public PlanNode visitExplainAnalyze(ExplainAnalyzeNode node, SimplePlanRewriter.RewriteContext<Set<Symbol>> context) {
            Set<Symbol> allSourceOutputs = ImmutableSet.copyOf(node.getSource().getOutputSymbols());
            return context.defaultRewrite(node, allSourceOutputs);
        }

        /**
         * Prunes exchange output columns that the parent does not need.
         * Columns the exchange itself relies on are always kept: the hash column,
         * the partitioning columns and the ordering columns. The matching input
         * column of every source is kept alongside each surviving output.
         */
        @Override
        public PlanNode visitExchange(ExchangeNode node, SimplePlanRewriter.RewriteContext<Set<Symbol>> context) {
            PartitioningScheme scheme = node.getPartitioningScheme();

            Set<Symbol> keep = new HashSet<>(context.get());
            scheme.getHashColumn().ifPresent(keep::add);
            keep.addAll(scheme.getPartitioning().getColumns());
            node.getOrderingScheme().ifPresent(ordering -> keep.addAll(ordering.getOrderBy()));

            List<List<Symbol>> inputs = node.getInputs();
            List<List<Symbol>> inputsBySource = new ArrayList<>(inputs.size());
            while (inputsBySource.size() < inputs.size()) {
                inputsBySource.add(new ArrayList<>());
            }

            // Walk the output layout by position so each source's input list stays column-aligned.
            List<Symbol> outputs = node.getOutputSymbols();
            List<Symbol> newOutputSymbols = new ArrayList<>(outputs.size());
            for (int column = 0; column < outputs.size(); column++) {
                Symbol output = outputs.get(column);
                if (keep.contains(output)) {
                    newOutputSymbols.add(output);
                    for (int source = 0; source < inputs.size(); source++) {
                        inputsBySource.get(source).add(inputs.get(source).get(column));
                    }
                }
            }

            PartitioningScheme prunedScheme = new PartitioningScheme(
                    scheme.getPartitioning(),
                    newOutputSymbols,
                    scheme.getHashColumn(),
                    scheme.isReplicateNullsAndAny(),
                    scheme.getBucketToPartition());

            List<PlanNode> sources = node.getSources();
            ImmutableList.Builder<PlanNode> rewrittenSources = ImmutableList.builder();
            for (int index = 0; index < sources.size(); index++) {
                Set<Symbol> expectedInputs = ImmutableSet.copyOf(inputsBySource.get(index));
                rewrittenSources.add(context.rewrite(sources.get(index), expectedInputs));
            }
            return new ExchangeNode(node.getId(), node.getType(), node.getScope(), prunedScheme, rewrittenSources.build(), inputsBySource, node.getOrderingScheme());
        }

        /**
         * Prunes both join inputs and the join's own output lists.
         *
         * Each side is asked for:
         * <ul>
         *   <li>everything the parent requires,</li>
         *   <li>its equi-join keys,</li>
         *   <li>its optional hash symbol, and</li>
         *   <li>when a filter exists, the filter's symbols.</li>
         * </ul>
         * The left and right output lists keep only symbols the parent requires,
         * de-duplicated in their original order.
         */
        @Override
        public PlanNode visitJoin(JoinNode node, SimplePlanRewriter.RewriteContext<Set<Symbol>> context) {
            Set<Symbol> required = context.get();

            // Typed Set (was ImmutableSet = new HashSet(), which does not compile);
            // empty when there is no filter.
            Set<Symbol> expectedFilterInputs = ImmutableSet.of();
            if (node.getFilter().isPresent()) {
                expectedFilterInputs = ImmutableSet.<Symbol>builder()
                        .addAll(SymbolsExtractor.extractUnique(node.getFilter().get()))
                        .addAll(required)
                        .build();
            }

            ImmutableSet.Builder<Symbol> leftInputsBuilder = ImmutableSet.<Symbol>builder().addAll(required);
            for (JoinNode.EquiJoinClause clause : node.getCriteria()) {
                leftInputsBuilder.add(clause.getLeft());
            }
            node.getLeftHashSymbol().ifPresent(leftInputsBuilder::add);
            leftInputsBuilder.addAll(expectedFilterInputs);
            Set<Symbol> leftInputs = leftInputsBuilder.build();

            ImmutableSet.Builder<Symbol> rightInputsBuilder = ImmutableSet.<Symbol>builder().addAll(required);
            for (JoinNode.EquiJoinClause clause : node.getCriteria()) {
                rightInputsBuilder.add(clause.getRight());
            }
            node.getRightHashSymbol().ifPresent(rightInputsBuilder::add);
            rightInputsBuilder.addAll(expectedFilterInputs);
            Set<Symbol> rightInputs = rightInputsBuilder.build();

            PlanNode left = context.rewrite(node.getLeft(), leftInputs);
            PlanNode right = context.rewrite(node.getRight(), rightInputs);

            List<Symbol> leftOutputSymbols = node.getLeftOutputSymbols().stream()
                    .filter(required::contains)
                    .distinct()
                    .collect(ImmutableList.toImmutableList());
            List<Symbol> rightOutputSymbols = node.getRightOutputSymbols().stream()
                    .filter(required::contains)
                    .distinct()
                    .collect(ImmutableList.toImmutableList());

            return new JoinNode(node.getId(), node.getType(), left, right, node.getCriteria(), leftOutputSymbols, rightOutputSymbols, node.getFilter(), node.getLeftHashSymbol(), node.getRightHashSymbol(), node.getDistributionType(), node.isSpillable(), node.getDynamicFilters(), node.getReorderJoinStatsAndCost());
        }

        /**
         * Prunes the inputs of a semi join.
         *
         * The source side is asked for the parent's required symbols, its join
         * symbol and its optional hash symbol. The filtering side is asked only
         * for its join symbol and optional hash symbol.
         */
        @Override
        public PlanNode visitSemiJoin(SemiJoinNode node, SimplePlanRewriter.RewriteContext<Set<Symbol>> context) {
            ImmutableSet.Builder<Symbol> sourceInputs = ImmutableSet.<Symbol>builder()
                    .addAll(context.get())
                    .add(node.getSourceJoinSymbol());
            node.getSourceHashSymbol().ifPresent(sourceInputs::add);

            ImmutableSet.Builder<Symbol> filteringSourceInputs = ImmutableSet.<Symbol>builder()
                    .add(node.getFilteringSourceJoinSymbol());
            node.getFilteringSourceHashSymbol().ifPresent(filteringSourceInputs::add);

            PlanNode rewrittenSource = context.rewrite(node.getSource(), sourceInputs.build());
            PlanNode rewrittenFilteringSource = context.rewrite(node.getFilteringSource(), filteringSourceInputs.build());
            return new SemiJoinNode(node.getId(), rewrittenSource, rewrittenFilteringSource, node.getSourceJoinSymbol(), node.getFilteringSourceJoinSymbol(), node.getSemiJoinOutput(), node.getSourceHashSymbol(), node.getFilteringSourceHashSymbol(), node.getDistributionType());
        }

        /**
         * Prunes both inputs of a spatial join and its output list.
         *
         * Each side is asked for its optional partition symbol, the filter's
         * symbols, and everything the parent requires. The output list keeps only
         * required symbols, de-duplicated in order.
         */
        @Override
        public PlanNode visitSpatialJoin(SpatialJoinNode node, SimplePlanRewriter.RewriteContext<Set<Symbol>> context) {
            Set<Symbol> required = context.get();
            Set<Symbol> requiredInputs = ImmutableSet.<Symbol>builder()
                    .addAll(SymbolsExtractor.extractUnique(node.getFilter()))
                    .addAll(required)
                    .build();

            // ifPresent (not map) because the value is only used for its side effect.
            ImmutableSet.Builder<Symbol> leftInputs = ImmutableSet.builder();
            node.getLeftPartitionSymbol().ifPresent(leftInputs::add);
            ImmutableSet.Builder<Symbol> rightInputs = ImmutableSet.builder();
            node.getRightPartitionSymbol().ifPresent(rightInputs::add);

            PlanNode left = context.rewrite(node.getLeft(), leftInputs.addAll(requiredInputs).build());
            PlanNode right = context.rewrite(node.getRight(), rightInputs.addAll(requiredInputs).build());

            List<Symbol> outputSymbols = node.getOutputSymbols().stream()
                    .filter(required::contains)
                    .distinct()
                    .collect(ImmutableList.toImmutableList());
            return new SpatialJoinNode(node.getId(), node.getType(), left, right, outputSymbols, node.getFilter(), node.getLeftPartitionSymbol(), node.getRightPartitionSymbol(), node.getKdbTree());
        }

        /**
         * Prunes both inputs of an index join.
         *
         * Each side is asked for the parent's required symbols, its side of every
         * equi-join clause, and its optional hash symbol. The join's criteria and
         * hash symbols are carried over unchanged.
         */
        @Override
        public PlanNode visitIndexJoin(IndexJoinNode node, SimplePlanRewriter.RewriteContext<Set<Symbol>> context) {
            Set<Symbol> required = context.get();

            ImmutableSet.Builder<Symbol> probeInputs = ImmutableSet.<Symbol>builder().addAll(required);
            for (IndexJoinNode.EquiJoinClause clause : node.getCriteria()) {
                probeInputs.add(clause.getProbe());
            }
            node.getProbeHashSymbol().ifPresent(probeInputs::add);

            ImmutableSet.Builder<Symbol> indexInputs = ImmutableSet.<Symbol>builder().addAll(required);
            for (IndexJoinNode.EquiJoinClause clause : node.getCriteria()) {
                indexInputs.add(clause.getIndex());
            }
            node.getIndexHashSymbol().ifPresent(indexInputs::add);

            Set<Symbol> probeSymbols = probeInputs.build();
            Set<Symbol> indexSymbols = indexInputs.build();
            PlanNode rewrittenProbe = context.rewrite(node.getProbeSource(), probeSymbols);
            PlanNode rewrittenIndex = context.rewrite(node.getIndexSource(), indexSymbols);
            return new IndexJoinNode(node.getId(), node.getType(), rewrittenProbe, rewrittenIndex, node.getCriteria(), node.getProbeHashSymbol(), node.getIndexHashSymbol());
        }

        /**
         * Restricts an index source to the symbols the parent requires.
         *
         * Output and lookup symbols are filtered against the required set, and
         * assignments are rebuilt for the surviving outputs only.
         */
        @Override
        public PlanNode visitIndexSource(IndexSourceNode node, SimplePlanRewriter.RewriteContext<Set<Symbol>> context) {
            Set<Symbol> required = context.get();
            // Typed (not raw) List so the stream below collects to a Map rather than Object.
            List<Symbol> newOutputSymbols = node.getOutputSymbols().stream()
                    .filter(required::contains)
                    .collect(ImmutableList.toImmutableList());
            Set<Symbol> newLookupSymbols = node.getLookupSymbols().stream()
                    .filter(required::contains)
                    .collect(ImmutableSet.toImmutableSet());
            Map<Symbol, ColumnHandle> assignments = node.getAssignments();
            Map<Symbol, ColumnHandle> newAssignments = newOutputSymbols.stream()
                    .collect(Collectors.toMap(Function.identity(), assignments::get));
            return new IndexSourceNode(node.getId(), node.getIndexHandle(), node.getTableHandle(), newLookupSymbols, newOutputSymbols, newAssignments, node.getCurrentConstraint());
        }

        /**
         * Drops aggregations whose output is not required.
         *
         * The source is asked for the grouping keys, the optional hash symbol and
         * the inputs of the retained aggregations. The node is rebuilt with an
         * empty pre-grouped symbol list.
         */
        @Override
        public PlanNode visitAggregation(AggregationNode node, SimplePlanRewriter.RewriteContext<Set<Symbol>> context) {
            Set<Symbol> required = context.get();
            ImmutableSet.Builder<Symbol> expectedInputs = ImmutableSet.<Symbol>builder().addAll(node.getGroupingKeys());
            node.getHashSymbol().ifPresent(expectedInputs::add);

            ImmutableMap.Builder<Symbol, AggregationNode.Aggregation> retained = ImmutableMap.builder();
            node.getAggregations().forEach((output, aggregation) -> {
                if (required.contains(output)) {
                    expectedInputs.addAll(SymbolsExtractor.extractUnique(aggregation));
                    retained.put(output, aggregation);
                }
            });

            PlanNode rewrittenSource = context.rewrite(node.getSource(), expectedInputs.build());
            return new AggregationNode(node.getId(), rewrittenSource, retained.build(), node.getGroupingSets(), ImmutableList.of(), node.getStep(), node.getHashSymbol(), node.getGroupIdSymbol());
        }

        /**
         * Drops window functions whose output is not required.
         *
         * The source is asked for:
         * <ul>
         *   <li>the parent's required symbols,</li>
         *   <li>the partition and ordering symbols,</li>
         *   <li>frame start/end value symbols,</li>
         *   <li>the optional hash symbol, and</li>
         *   <li>the inputs of the retained functions.</li>
         * </ul>
         * If no function is retained, the rewritten source is returned in place of
         * this node.
         */
        @Override
        public PlanNode visitWindow(WindowNode node, SimplePlanRewriter.RewriteContext<Set<Symbol>> context) {
            Set<Symbol> required = context.get();
            ImmutableSet.Builder<Symbol> expectedInputs = ImmutableSet.<Symbol>builder()
                    .addAll(required)
                    .addAll(node.getPartitionBy());
            node.getOrderingScheme().ifPresent(ordering -> expectedInputs.addAll(ordering.getOrderBy()));
            for (WindowNode.Frame frame : node.getFrames()) {
                frame.getStartValue().ifPresent(expectedInputs::add);
                frame.getEndValue().ifPresent(expectedInputs::add);
            }
            node.getHashSymbol().ifPresent(expectedInputs::add);

            ImmutableMap.Builder<Symbol, WindowNode.Function> retained = ImmutableMap.builder();
            node.getWindowFunctions().forEach((output, windowFunction) -> {
                if (required.contains(output)) {
                    expectedInputs.addAll(SymbolsExtractor.extractUnique(windowFunction));
                    retained.put(output, windowFunction);
                }
            });

            PlanNode rewrittenSource = context.rewrite(node.getSource(), expectedInputs.build());
            Map<Symbol, WindowNode.Function> functions = retained.build();
            return functions.isEmpty()
                    ? rewrittenSource
                    : new WindowNode(node.getId(), rewrittenSource, node.getSpecification(), functions, node.getHashSymbol(), node.getPrePartitionedInputs(), node.getPreSortedOrderPrefix());
        }

        /**
         * Delegates column pruning of a table scan to {@code PruneTableScanColumns}.
         * The original node is returned unchanged when that helper yields no result.
         */
        @Override
        public PlanNode visitTableScan(TableScanNode node, SimplePlanRewriter.RewriteContext<Set<Symbol>> context) {
            Set<Symbol> referenced = context.get();
            return PruneTableScanColumns
                    .pruneColumns(this.metadata, this.typeAnalyzer, this.types, this.session, node, referenced)
                    .orElse(node);
        }

        /**
         * Asks the source for the predicate's symbols plus everything the parent
         * requires, then rebuilds the filter unchanged on top.
         */
        @Override
        public PlanNode visitFilter(FilterNode node, SimplePlanRewriter.RewriteContext<Set<Symbol>> context) {
            Set<Symbol> predicateAndRequired = ImmutableSet.<Symbol>builder()
                    .addAll(SymbolsExtractor.extractUnique(node.getPredicate()))
                    .addAll(context.get())
                    .build();
            PlanNode rewrittenSource = context.rewrite(node.getSource(), predicateAndRequired);
            return new FilterNode(node.getId(), rewrittenSource, node.getPredicate());
        }

        /**
         * Prunes a GroupId node to the required aggregation arguments and grouping
         * outputs.
         *
         * Every grouping set is kept, but each is filtered to required outputs
         * (possibly leaving it empty). The source is asked for the retained
         * aggregation arguments and the grouping columns that back the retained
         * outputs.
         */
        @Override
        public PlanNode visitGroupId(GroupIdNode node, SimplePlanRewriter.RewriteContext<Set<Symbol>> context) {
            Set<Symbol> required = context.get();
            List<Symbol> newAggregationArguments = node.getAggregationArguments().stream()
                    .filter(required::contains)
                    .collect(Collectors.toList());
            ImmutableSet.Builder<Symbol> expectedInputs = ImmutableSet.<Symbol>builder().addAll(newAggregationArguments);

            Map<Symbol, Symbol> groupingColumns = node.getGroupingColumns();
            Map<Symbol, Symbol> newGroupingMapping = new HashMap<>();
            ImmutableList.Builder<List<Symbol>> newGroupingSets = ImmutableList.builder();
            for (List<Symbol> groupingSet : node.getGroupingSets()) {
                ImmutableList.Builder<Symbol> prunedSet = ImmutableList.builder();
                for (Symbol output : groupingSet) {
                    if (required.contains(output)) {
                        Symbol input = groupingColumns.get(output);
                        prunedSet.add(output);
                        newGroupingMapping.putIfAbsent(output, input);
                        expectedInputs.add(input);
                    }
                }
                newGroupingSets.add(prunedSet.build());
            }

            PlanNode rewrittenSource = context.rewrite(node.getSource(), expectedInputs.build());
            return new GroupIdNode(node.getId(), rewrittenSource, newGroupingSets.build(), newGroupingMapping, newAggregationArguments, node.getGroupIdSymbol());
        }

        /**
         * Removes the MarkDistinct node when its marker symbol is not required.
         *
         * Otherwise the source is asked for the distinct symbols, every required
         * symbol except the marker, and the optional hash symbol.
         */
        @Override
        public PlanNode visitMarkDistinct(MarkDistinctNode node, SimplePlanRewriter.RewriteContext<Set<Symbol>> context) {
            Set<Symbol> required = context.get();
            Symbol marker = node.getMarkerSymbol();
            if (!required.contains(marker)) {
                return context.rewrite(node.getSource(), required);
            }

            ImmutableSet.Builder<Symbol> expectedInputs = ImmutableSet.<Symbol>builder().addAll(node.getDistinctSymbols());
            for (Symbol symbol : required) {
                if (!symbol.equals(marker)) {
                    expectedInputs.add(symbol);
                }
            }
            node.getHashSymbol().ifPresent(expectedInputs::add);

            PlanNode rewrittenSource = context.rewrite(node.getSource(), expectedInputs.build());
            return new MarkDistinctNode(node.getId(), rewrittenSource, marker, node.getDistinctSymbols(), node.getHashSymbol());
        }

        /**
         * Prunes replicate symbols and the ordinality symbol that are not required.
         *
         * The source is asked for:
         * <ul>
         *   <li>the retained replicate symbols,</li>
         *   <li>every unnest input symbol, and</li>
         *   <li>filter symbols that are not produced by the unnest itself.</li>
         * </ul>
         */
        @Override
        public PlanNode visitUnnest(UnnestNode node, SimplePlanRewriter.RewriteContext<Set<Symbol>> context) {
            Set<Symbol> required = context.get();
            List<Symbol> replicateSymbols = node.getReplicateSymbols().stream()
                    .filter(required::contains)
                    .collect(ImmutableList.toImmutableList());
            Optional<Symbol> ordinalitySymbol = node.getOrdinalitySymbol().filter(required::contains);

            Map<Symbol, List<Symbol>> unnestSymbols = node.getUnnestSymbols();
            ImmutableSet.Builder<Symbol> producedByUnnest = ImmutableSet.builder();
            for (List<Symbol> produced : unnestSymbols.values()) {
                producedByUnnest.addAll(produced);
            }
            Set<Symbol> filterSymbols = SymbolsExtractor.extractUnique(node.getFilter().orElse(BooleanLiteral.TRUE_LITERAL));

            Set<Symbol> expectedInputs = ImmutableSet.<Symbol>builder()
                    .addAll(replicateSymbols)
                    .addAll(unnestSymbols.keySet())
                    .addAll(Sets.difference(filterSymbols, producedByUnnest.build()))
                    .build();

            PlanNode rewrittenSource = context.rewrite(node.getSource(), expectedInputs);
            return new UnnestNode(node.getId(), rewrittenSource, replicateSymbols, unnestSymbols, ordinalitySymbol, node.getJoinType(), node.getFilter());
        }

        /**
         * Keeps only the assignments whose target symbol is required.
         * The source is asked for the symbols those expressions reference.
         */
        @Override
        public PlanNode visitProject(ProjectNode node, SimplePlanRewriter.RewriteContext<Set<Symbol>> context) {
            Set<Symbol> required = context.get();
            ImmutableSet.Builder<Symbol> referencedInputs = ImmutableSet.builder();
            Assignments.Builder keptAssignments = Assignments.builder();
            node.getAssignments().forEach((target, expression) -> {
                if (required.contains(target)) {
                    referencedInputs.addAll(SymbolsExtractor.extractUnique(expression));
                    keptAssignments.put(target, expression);
                }
            });
            PlanNode rewrittenSource = context.rewrite(node.getSource(), referencedInputs.build());
            return new ProjectNode(node.getId(), rewrittenSource, keptAssignments.build());
        }

        /**
         * Ignores the parent's request and asks the source for exactly this node's
         * output symbols. Column names and output symbols are left unchanged.
         */
        @Override
        public PlanNode visitOutput(OutputNode node, SimplePlanRewriter.RewriteContext<Set<Symbol>> context) {
            List<Symbol> outputs = node.getOutputSymbols();
            Set<Symbol> expectedInputs = ImmutableSet.copyOf(outputs);
            return new OutputNode(node.getId(), context.rewrite(node.getSource(), expectedInputs), node.getColumnNames(), outputs);
        }

        /** Passes the parent's required symbols straight through to the source. */
        @Override
        public PlanNode visitOffset(OffsetNode node, SimplePlanRewriter.RewriteContext<Set<Symbol>> context) {
            Set<Symbol> passThrough = ImmutableSet.copyOf(context.get());
            PlanNode rewrittenSource = context.rewrite(node.getSource(), passThrough);
            return new OffsetNode(node.getId(), rewrittenSource, node.getCount());
        }

        /**
         * Asks the source for the parent's required symbols plus the ORDER BY
         * symbols of the optional ties-resolving scheme.
         *
         * The rebuilt node keeps that ties-resolving scheme. The previous 4-arg
         * constructor dropped it, silently turning a WITH TIES limit into a plain
         * limit.
         */
        @Override
        public PlanNode visitLimit(LimitNode node, SimplePlanRewriter.RewriteContext<Set<Symbol>> context) {
            ImmutableSet.Builder<Symbol> expectedInputs = ImmutableSet.<Symbol>builder().addAll(context.get());
            node.getTiesResolvingScheme().ifPresent(scheme -> expectedInputs.addAll(scheme.getOrderBy()));
            PlanNode source = context.rewrite(node.getSource(), expectedInputs.build());
            return new LimitNode(node.getId(), source, node.getCount(), node.getTiesResolvingScheme(), node.isPartial());
        }

        /**
         * Asks the source for the distinct symbols and the optional hash symbol
         * only; the parent's request is not forwarded.
         */
        @Override
        public PlanNode visitDistinctLimit(DistinctLimitNode node, SimplePlanRewriter.RewriteContext<Set<Symbol>> context) {
            ImmutableSet.Builder<Symbol> expectedInputs = ImmutableSet.<Symbol>builder().addAll(node.getDistinctSymbols());
            node.getHashSymbol().ifPresent(expectedInputs::add);
            PlanNode rewrittenSource = context.rewrite(node.getSource(), expectedInputs.build());
            return new DistinctLimitNode(node.getId(), rewrittenSource, node.getLimit(), node.isPartial(), node.getDistinctSymbols(), node.getHashSymbol());
        }

        /** Asks the source for the parent's required symbols plus the ORDER BY symbols. */
        @Override
        public PlanNode visitTopN(TopNNode node, SimplePlanRewriter.RewriteContext<Set<Symbol>> context) {
            OrderingScheme ordering = node.getOrderingScheme();
            Set<Symbol> expectedInputs = ImmutableSet.<Symbol>builder()
                    .addAll(context.get())
                    .addAll(ordering.getOrderBy())
                    .build();
            PlanNode rewrittenSource = context.rewrite(node.getSource(), expectedInputs);
            return new TopNNode(node.getId(), rewrittenSource, node.getCount(), ordering, node.getStep());
        }

        /**
         * Asks the source for the parent's required symbols, the partition
         * symbols and the optional hash symbol.
         */
        @Override
        public PlanNode visitRowNumber(RowNumberNode node, SimplePlanRewriter.RewriteContext<Set<Symbol>> context) {
            ImmutableSet.Builder<Symbol> expectedInputs = ImmutableSet.<Symbol>builder()
                    .addAll(context.get())
                    .addAll(node.getPartitionBy());
            node.getHashSymbol().ifPresent(expectedInputs::add);
            PlanNode rewrittenSource = context.rewrite(node.getSource(), expectedInputs.build());
            return new RowNumberNode(node.getId(), rewrittenSource, node.getPartitionBy(), node.getRowNumberSymbol(), node.getMaxRowCountPerPartition(), node.getHashSymbol());
        }

        /**
         * Asks the source for the parent's required symbols, the partition and
         * ORDER BY symbols, and the optional hash symbol.
         */
        @Override
        public PlanNode visitTopNRowNumber(TopNRowNumberNode node, SimplePlanRewriter.RewriteContext<Set<Symbol>> context) {
            ImmutableSet.Builder<Symbol> expectedInputs = ImmutableSet.<Symbol>builder()
                    .addAll(context.get())
                    .addAll(node.getPartitionBy())
                    .addAll(node.getOrderingScheme().getOrderBy());
            node.getHashSymbol().ifPresent(expectedInputs::add);
            PlanNode rewrittenSource = context.rewrite(node.getSource(), expectedInputs.build());
            return new TopNRowNumberNode(node.getId(), rewrittenSource, node.getSpecification(), node.getRowNumberSymbol(), node.getMaxRowCountPerPartition(), node.isPartial(), node.getHashSymbol());
        }

        /** Asks the source for the parent's required symbols plus the ORDER BY symbols. */
        @Override
        public PlanNode visitSort(SortNode node, SimplePlanRewriter.RewriteContext<Set<Symbol>> context) {
            OrderingScheme ordering = node.getOrderingScheme();
            Set<Symbol> expectedInputs = ImmutableSet.<Symbol>builder()
                    .addAll(context.get())
                    .addAll(ordering.getOrderBy())
                    .build();
            PlanNode rewrittenSource = context.rewrite(node.getSource(), expectedInputs);
            return new SortNode(node.getId(), rewrittenSource, ordering, node.isPartial());
        }

        /**
         * Asks the source for what the writer consumes:
         * <ul>
         *   <li>the written columns,</li>
         *   <li>the partitioning and hash columns of an optional partitioning scheme, and</li>
         *   <li>the grouping symbols and aggregation inputs of optional statistics aggregations.</li>
         * </ul>
         * The parent's request is not forwarded.
         */
        @Override
        public PlanNode visitTableWriter(TableWriterNode node, SimplePlanRewriter.RewriteContext<Set<Symbol>> context) {
            ImmutableSet.Builder<Symbol> expectedInputs = ImmutableSet.<Symbol>builder().addAll(node.getColumns());
            node.getPartitioningScheme().ifPresent(scheme -> {
                expectedInputs.addAll(scheme.getPartitioning().getColumns());
                scheme.getHashColumn().ifPresent(expectedInputs::add);
            });
            node.getStatisticsAggregation().ifPresent(statistics -> {
                expectedInputs.addAll(statistics.getGroupingSymbols());
                for (AggregationNode.Aggregation aggregation : statistics.getAggregations().values()) {
                    expectedInputs.addAll(SymbolsExtractor.extractUnique(aggregation));
                }
            });
            PlanNode rewrittenSource = context.rewrite(node.getSource(), expectedInputs.build());
            return new TableWriterNode(node.getId(), rewrittenSource, node.getTarget(), node.getRowCountSymbol(), node.getFragmentSymbol(), node.getColumns(), node.getColumnNames(), node.getPartitioningScheme(), node.getStatisticsAggregation(), node.getStatisticsAggregationDescriptor());
        }

        /** Asks the source for all of its outputs; nothing below this node is pruned on its behalf. */
        @Override
        public PlanNode visitStatisticsWriterNode(StatisticsWriterNode node, SimplePlanRewriter.RewriteContext<Set<Symbol>> context) {
            PlanNode original = node.getSource();
            Set<Symbol> allOutputs = ImmutableSet.copyOf(original.getOutputSymbols());
            PlanNode rewrittenSource = context.rewrite(original, allOutputs);
            return new StatisticsWriterNode(node.getId(), rewrittenSource, node.getTarget(), node.getRowCountSymbol(), node.isRowCountEnabled(), node.getDescriptor());
        }

        /** Asks the source for all of its outputs; nothing below this node is pruned on its behalf. */
        @Override
        public PlanNode visitTableFinish(TableFinishNode node, SimplePlanRewriter.RewriteContext<Set<Symbol>> context) {
            PlanNode original = node.getSource();
            Set<Symbol> allOutputs = ImmutableSet.copyOf(original.getOutputSymbols());
            PlanNode rewrittenSource = context.rewrite(original, allOutputs);
            return new TableFinishNode(node.getId(), rewrittenSource, node.getTarget(), node.getRowCountSymbol(), node.getStatisticsAggregation(), node.getStatisticsAggregationDescriptor());
        }

        /**
         * Asks the source only for the row-id symbol.
         *
         * The set is built with its element type inferred as Symbol. The previous
         * {@code (Set<Symbol>) ImmutableSet.of((Object) rowId)} cast an
         * {@code ImmutableSet<Object>} to a provably distinct {@code Set<Symbol>},
         * which does not compile.
         */
        @Override
        public PlanNode visitDelete(DeleteNode node, SimplePlanRewriter.RewriteContext<Set<Symbol>> context) {
            Set<Symbol> rowIdOnly = ImmutableSet.of(node.getRowId());
            PlanNode source = context.rewrite(node.getSource(), rowIdOnly);
            return new DeleteNode(node.getId(), source, node.getTarget(), node.getRowId(), node.getOutputSymbols());
        }

        /**
         * Keeps only the union outputs the parent requires.
         *
         * Source {@code i} is asked for the i-th input symbol mapped to each kept
         * output. The kept outputs, in original order, become the new output list.
         */
        @Override
        public PlanNode visitUnion(UnionNode node, SimplePlanRewriter.RewriteContext<Set<Symbol>> context) {
            // Typed multimaps: the decompiled raw version passed an Object key to
            // ListMultimap<Symbol, Symbol>.get and iterated raw values as Object,
            // neither of which compiles.
            ListMultimap<Symbol, Symbol> symbolMapping = node.getSymbolMapping();
            ImmutableListMultimap.Builder<Symbol, Symbol> prunedMappingBuilder = ImmutableListMultimap.builder();
            for (Symbol output : node.getOutputSymbols()) {
                if (context.get().contains(output)) {
                    prunedMappingBuilder.putAll(output, symbolMapping.get(output));
                }
            }
            ImmutableListMultimap<Symbol, Symbol> prunedSymbolMapping = prunedMappingBuilder.build();

            ImmutableList.Builder<PlanNode> rewrittenSources = ImmutableList.builder();
            for (int i = 0; i < node.getSources().size(); ++i) {
                ImmutableSet.Builder<Symbol> expectedSourceSymbols = ImmutableSet.builder();
                for (Collection<Symbol> inputs : prunedSymbolMapping.asMap().values()) {
                    expectedSourceSymbols.add(Iterables.get(inputs, i));
                }
                rewrittenSources.add(context.rewrite(node.getSources().get(i), expectedSourceSymbols.build()));
            }
            return new UnionNode(node.getId(), rewrittenSources.build(), prunedSymbolMapping, ImmutableList.copyOf(prunedSymbolMapping.keySet()));
        }

        /**
         * Handles an intersect node by rewriting each of its sources with that source's full output layout.
         */
        @Override
        public PlanNode visitIntersect(IntersectNode node, SimplePlanRewriter.RewriteContext<Set<Symbol>> context) {
            return rewriteSetOperationChildren(node, context);
        }

        /**
         * Handles an except node by rewriting each of its sources with that source's full output layout.
         */
        @Override
        public PlanNode visitExcept(ExceptNode node, SimplePlanRewriter.RewriteContext<Set<Symbol>> context) {
            return rewriteSetOperationChildren(node, context);
        }

        /**
         * Rewrites every source of a set operation while requesting that source's complete output layout
         * ({@code sourceOutputLayout(i)}), then rebuilds the node around the rewritten sources.
         * The set operation's own output symbols are not pruned here.
         *
         * @param node the set operation whose children are rewritten
         * @param context rewrite context (its required-symbol set is not consulted by this method)
         * @return {@code node} with its children replaced by their rewritten versions
         */
        private PlanNode rewriteSetOperationChildren(SetOperationNode node, SimplePlanRewriter.RewriteContext<Set<Symbol>> context) {
            ImmutableList.Builder<PlanNode> rewrittenSources = ImmutableList.builder();
            for (int i = 0; i < node.getSources().size(); i++) {
                Set<Symbol> sourceLayout = ImmutableSet.copyOf(node.sourceOutputLayout(i));
                rewrittenSources.add(context.rewrite(node.getSources().get(i), sourceLayout));
            }
            return node.replaceChildren(rewrittenSources.build());
        }

        /**
         * Removes the columns of a VALUES node that the parent does not require, dropping the matching
         * expression from every row. The relative order of the retained columns is preserved.
         */
        @Override
        public PlanNode visitValues(ValuesNode node, SimplePlanRewriter.RewriteContext<Set<Symbol>> context) {
            ImmutableList.Builder<Symbol> rewrittenOutputSymbolsBuilder = ImmutableList.builder();

            // One builder per row; each collects that row's values for the retained columns.
            ImmutableList.Builder<ImmutableList.Builder<Expression>> rowBuildersBuilder = ImmutableList.builder();
            for (int i = 0; i < node.getRows().size(); i++) {
                rowBuildersBuilder.add(ImmutableList.builder());
            }
            ImmutableList<ImmutableList.Builder<Expression>> rowBuilders = rowBuildersBuilder.build();

            for (int i = 0; i < node.getOutputSymbols().size(); i++) {
                Symbol outputSymbol = node.getOutputSymbols().get(i);
                if (!context.get().contains(outputSymbol)) {
                    continue;
                }
                rewrittenOutputSymbolsBuilder.add(outputSymbol);
                // Column i survives: copy its value from every row.
                for (int j = 0; j < node.getRows().size(); j++) {
                    rowBuilders.get(j).add(node.getRows().get(j).get(i));
                }
            }

            // Typed stream: a raw Stream cannot use ImmutableList.Builder::build as an unbound method reference.
            List<List<Expression>> rewrittenRows = rowBuilders.stream()
                    .map(ImmutableList.Builder::build)
                    .collect(ImmutableList.toImmutableList());
            return new ValuesNode(node.getId(), rewrittenOutputSymbolsBuilder.build(), rewrittenRows);
        }

        /**
         * Prunes an apply node.
         *
         * <p>If none of the subquery assignment outputs are required by the parent, the apply node is dropped
         * and only its input is rewritten. Otherwise:
         * <ul>
         *   <li>only the required assignments are kept;</li>
         *   <li>the subquery is asked for the symbols those assignment expressions reference;</li>
         *   <li>correlation symbols the rewritten subquery no longer references are dropped;</li>
         *   <li>the input is asked for the parent's symbols, the remaining correlation symbols and the
         *       assignment-expression symbols.</li>
         * </ul>
         */
        @Override
        public PlanNode visitApply(ApplyNode node, SimplePlanRewriter.RewriteContext<Set<Symbol>> context) {
            if (Sets.intersection(node.getSubqueryAssignments().getSymbols(), context.get()).isEmpty()) {
                return context.rewrite(node.getInput(), context.get());
            }

            // Keep the required assignments and collect the symbols their expressions reference.
            ImmutableSet.Builder<Symbol> subqueryAssignmentsSymbolsBuilder = ImmutableSet.builder();
            Assignments.Builder subqueryAssignments = Assignments.builder();
            for (Map.Entry<Symbol, Expression> entry : node.getSubqueryAssignments().getMap().entrySet()) {
                Symbol output = entry.getKey();
                Expression expression = entry.getValue();
                if (context.get().contains(output)) {
                    subqueryAssignmentsSymbolsBuilder.addAll(SymbolsExtractor.extractUnique(expression));
                    subqueryAssignments.put(output, expression);
                }
            }
            Set<Symbol> subqueryAssignmentsSymbols = subqueryAssignmentsSymbolsBuilder.build();
            PlanNode subquery = context.rewrite(node.getSubquery(), subqueryAssignmentsSymbols);

            // Drop correlation symbols the pruned subquery no longer references.
            Set<Symbol> subquerySymbols = SymbolsExtractor.extractUnique(subquery);
            List<Symbol> newCorrelation = node.getCorrelation().stream()
                    .filter(subquerySymbols::contains)
                    .collect(ImmutableList.toImmutableList());

            // NOTE(review): assignment expressions appear to be able to reference input-side symbols
            // (presumably e.g. the probe value of an IN predicate) — TODO confirm — so they are requested
            // from the input as well.
            Set<Symbol> inputContext = ImmutableSet.<Symbol>builder()
                    .addAll(context.get())
                    .addAll(newCorrelation)
                    .addAll(subqueryAssignmentsSymbols)
                    .build();
            PlanNode input = context.rewrite(node.getInput(), inputContext);
            return new ApplyNode(node.getId(), input, subquery, subqueryAssignments.build(), newCorrelation, node.getOriginSubquery());
        }

        /**
         * Keeps an AssignUniqueId node only when the parent requires its id column. In that case the node is
         * rewritten with the parent's required symbols. Otherwise the node is dropped and its source is
         * rewritten in its place.
         */
        @Override
        public PlanNode visitAssignUniqueId(AssignUniqueId node, SimplePlanRewriter.RewriteContext<Set<Symbol>> context) {
            Set<Symbol> requiredSymbols = context.get();
            boolean idColumnRequired = requiredSymbols.contains(node.getIdColumn());
            if (idColumnRequired) {
                return context.defaultRewrite(node, requiredSymbols);
            }
            return context.rewrite(node.getSource(), requiredSymbols);
        }

        /**
         * Prunes a correlated join.
         *
         * <p>The subquery is asked for the parent's symbols plus the symbols referenced by the join filter.
         * Correlation symbols the rewritten subquery no longer references are dropped. The input is asked for
         * the parent's symbols, the remaining correlation symbols and the filter symbols.
         *
         * <p>One side can be eliminated entirely:
         * <ul>
         *   <li>the subquery, when none of its outputs are in the parent's required set and either the join is
         *       INNER with a scalar subquery and a TRUE filter, or the join is LEFT with an at-most-scalar
         *       subquery;</li>
         *   <li>the input, when none of its outputs are in the parent's required set or the remaining
         *       correlation, and either the join is INNER with a scalar input and a TRUE filter, or the join is
         *       RIGHT with an at-most-scalar input.</li>
         * </ul>
         */
        @Override
        public PlanNode visitCorrelatedJoin(CorrelatedJoinNode node, SimplePlanRewriter.RewriteContext<Set<Symbol>> context) {
            Set<Symbol> expectedFilterSymbols = SymbolsExtractor.extractUnique(node.getFilter());
            Set<Symbol> expectedFilterAndContextSymbols = ImmutableSet.<Symbol>builder()
                    .addAll(expectedFilterSymbols)
                    .addAll(context.get())
                    .build();
            PlanNode subquery = context.rewrite(node.getSubquery(), expectedFilterAndContextSymbols);

            // Try to remove the subquery when the parent needs none of its outputs.
            if (Sets.intersection(ImmutableSet.copyOf(subquery.getOutputSymbols()), context.get()).isEmpty()) {
                if (node.getType() == CorrelatedJoinNode.Type.INNER
                        && QueryCardinalityUtil.isScalar(subquery)
                        && node.getFilter().equals(BooleanLiteral.TRUE_LITERAL)) {
                    return context.rewrite(node.getInput(), context.get());
                }
                if (node.getType() == CorrelatedJoinNode.Type.LEFT && QueryCardinalityUtil.isAtMostScalar(subquery)) {
                    return context.rewrite(node.getInput(), context.get());
                }
            }

            // Drop correlation symbols the pruned subquery no longer references.
            Set<Symbol> subquerySymbols = SymbolsExtractor.extractUnique(subquery);
            List<Symbol> newCorrelation = node.getCorrelation().stream()
                    .filter(subquerySymbols::contains)
                    .collect(ImmutableList.toImmutableList());

            Set<Symbol> expectedCorrelationAndContextSymbols = ImmutableSet.<Symbol>builder()
                    .addAll(newCorrelation)
                    .addAll(context.get())
                    .build();
            Set<Symbol> inputContext = ImmutableSet.<Symbol>builder()
                    .addAll(expectedCorrelationAndContextSymbols)
                    .addAll(expectedFilterSymbols)
                    .build();
            PlanNode input = context.rewrite(node.getInput(), inputContext);

            // Try to remove the input when neither the parent nor the remaining correlation needs its outputs.
            if (Sets.intersection(ImmutableSet.copyOf(input.getOutputSymbols()), expectedCorrelationAndContextSymbols).isEmpty()) {
                if (node.getType() == CorrelatedJoinNode.Type.INNER
                        && QueryCardinalityUtil.isScalar(input)
                        && node.getFilter().equals(BooleanLiteral.TRUE_LITERAL)) {
                    return subquery;
                }
                if (node.getType() == CorrelatedJoinNode.Type.RIGHT && QueryCardinalityUtil.isAtMostScalar(input)) {
                    return subquery;
                }
            }

            return new CorrelatedJoinNode(node.getId(), input, subquery, newCorrelation, node.getType(), node.getFilter(), node.getOriginSubquery());
        }
    }
}

