package org.apache.hadoop.hive.ql.optimizer.spark;

import io.prestosql.hive.$internal.org.slf4j.Logger;
import io.prestosql.hive.$internal.org.slf4j.LoggerFactory;
import java.util.Collection;
import java.util.EnumSet;
import java.util.Iterator;
import java.util.List;
import java.util.Set;
import java.util.Stack;
import org.apache.hadoop.hive.common.ObjectPair;
import org.apache.hadoop.hive.conf.HiveConf;
import org.apache.hadoop.hive.metastore.api.hive_metastoreConstants;
import org.apache.hadoop.hive.ql.exec.FileSinkOperator;
import org.apache.hadoop.hive.ql.exec.LimitOperator;
import org.apache.hadoop.hive.ql.exec.Operator;
import org.apache.hadoop.hive.ql.exec.OperatorUtils;
import org.apache.hadoop.hive.ql.exec.ReduceSinkOperator;
import org.apache.hadoop.hive.ql.exec.TableScanOperator;
import org.apache.hadoop.hive.ql.exec.TerminalOperator;
import org.apache.hadoop.hive.ql.exec.Utilities;
import org.apache.hadoop.hive.ql.exec.spark.SparkUtilities;
import org.apache.hadoop.hive.ql.exec.spark.session.SparkSession;
import org.apache.hadoop.hive.ql.exec.spark.session.SparkSessionManagerImpl;
import org.apache.hadoop.hive.ql.lib.Node;
import org.apache.hadoop.hive.ql.lib.NodeProcessor;
import org.apache.hadoop.hive.ql.lib.NodeProcessorCtx;
import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.hive.ql.parse.SemanticException;
import org.apache.hadoop.hive.ql.parse.spark.GenSparkUtils;
import org.apache.hadoop.hive.ql.parse.spark.OptimizeSparkProcContext;
import org.apache.hadoop.hive.ql.plan.ExprNodeDesc;
import org.apache.hadoop.hive.ql.plan.FileSinkDesc;
import org.apache.hadoop.hive.ql.plan.OperatorDesc;
import org.apache.hadoop.hive.ql.plan.ReduceSinkDesc;
import org.apache.hadoop.hive.ql.stats.StatsUtils;

/* loaded from: input_file:org/apache/hadoop/hive/ql/optimizer/spark/SetSparkReducerParallelism.class */
/**
 * {@link NodeProcessor} that assigns a reducer count to each {@link ReduceSinkOperator} it is
 * invoked on.
 *
 * <p>For a sink that still needs a value, the count is chosen from the first source that applies:
 * <ol>
 *   <li>the user-configured constant ({@code HiveConf.ConfVars.HADOOPNUMREDUCERS}), if positive;</li>
 *   <li>the bucket count of a downstream {@link FileSinkOperator}'s table, if positive;</li>
 *   <li>an estimate derived from statistics (operator stats or upstream table-scan stats), raised
 *       to the number of Spark cores when that is known and capped by
 *       {@code HiveConf.ConfVars.MAXREDUCERS};</li>
 *   <li>otherwise the maximum reducer count among the upstream reduce sinks.</li>
 * </ol>
 *
 * <p>The Spark memory/cores pair is looked up at most once per instance and cached in a plain
 * field; nothing in this class synchronizes access to it.
 */
public class SetSparkReducerParallelism implements NodeProcessor {
    private static final Logger LOG = LoggerFactory.getLogger(SetSparkReducerParallelism.class.getName());
    private static final String SPARK_DYNAMIC_ALLOCATION_ENABLED = "spark.dynamicAllocation.enabled";

    /** Cached (memory, cores) pair from the Spark session; {@code null} until successfully fetched. */
    private ObjectPair<Long, Integer> sparkMemoryAndCores = null;

    /** When true, sizes come from sibling operator statistics instead of upstream table scans. */
    private final boolean useOpStats;

    /**
     * @param hiveConf configuration from which {@code SPARK_USE_OP_STATS} is read once
     */
    public SetSparkReducerParallelism(HiveConf hiveConf) {
        this.useOpStats = hiveConf.getBoolVar(HiveConf.ConfVars.SPARK_USE_OP_STATS);
    }

    /**
     * Decides and stores the number of reducers for the given reduce sink.
     *
     * @param node             the operator being visited; must be a {@link ReduceSinkOperator}
     * @param stack            walker stack (unused here)
     * @param nodeProcessorCtx must be an {@link OptimizeSparkProcContext}
     * @param objArr           outputs of previously processed nodes (unused here)
     * @return {@code Boolean.TRUE} if this sink had already been processed, {@code Boolean.FALSE}
     *         in every other case (NOTE(review): how the walker interprets the value is not
     *         visible in this file)
     * @throws SemanticException if a Spark session cannot be obtained while estimating parallelism
     */
    @Override // org.apache.hadoop.hive.ql.lib.NodeProcessor
    public Object process(Node node, Stack<Node> stack, NodeProcessorCtx nodeProcessorCtx, Object... objArr) throws SemanticException {
        OptimizeSparkProcContext context = (OptimizeSparkProcContext) nodeProcessorCtx;
        ReduceSinkOperator sink = (ReduceSinkOperator) node;
        ReduceSinkDesc desc = sink.getConf();
        int maxReducers = context.getConf().getIntVar(HiveConf.ConfVars.MAXREDUCERS);
        int constantReducers = context.getConf().getIntVar(HiveConf.ConfVars.HADOOPNUMREDUCERS);

        // Only populated when operator stats are NOT used; in that mode a sink may inherit its
        // parallelism from its upstream sinks, so all of them must have been visited first.
        Set<ReduceSinkOperator> parentSinks = null;
        if (!useOpStats) {
            parentSinks = OperatorUtils.findOperatorsUpstream(sink, ReduceSinkOperator.class);
            parentSinks.remove(sink);
            if (!context.getVisitedReduceSinks().containsAll(parentSinks)) {
                LOG.debug("Skipping sink {} for now as we haven't seen all its parents.", sink);
                return false;
            }
        }

        if (context.getVisitedReduceSinks().contains(sink)) {
            LOG.debug("Already processed reduce sink: {}", sink.getName());
            return true;
        }
        context.getVisitedReduceSinks().add(sink);

        if (!needSetParallelism(sink, context.getConf())) {
            LOG.info("Number of reducers for sink {} was already determined to be: {}", sink, desc.getNumReducers());
            return false;
        }

        if (constantReducers > 0) {
            LOG.info("Parallelism for reduce sink {} set by user to {}", sink, constantReducers);
            desc.setNumReducers(constantReducers);
            return false;
        }

        // A downstream file sink whose table declares a positive bucket count dictates the
        // reducer count directly.
        FileSinkOperator fileSink = GenSparkUtils.getChildOperator(sink, FileSinkOperator.class);
        if (fileSink != null) {
            String bucketCount = fileSink.getConf().getTableInfo().getProperties()
                    .getProperty(hive_metastoreConstants.BUCKET_COUNT);
            int numBuckets = bucketCount == null ? 0 : Integer.parseInt(bucketCount);
            if (numBuckets > 0) {
                LOG.info("Set parallelism for reduce sink {} to: {} (buckets)", sink, numBuckets);
                desc.setNumReducers(numBuckets);
                return false;
            }
        }

        if (useOpStats || parentSinks.isEmpty()) {
            int numReducers = computeReducersFromStats(context, estimateInputBytes(sink), maxReducers);
            LOG.info("Set parallelism for reduce sink {} to: {} (calculated)", sink, numReducers);
            desc.setNumReducers(numReducers);
        } else {
            int numReducers = maxParentReducers(parentSinks);
            desc.setNumReducers(numReducers);
            LOG.debug("Set parallelism for sink {} to {} based on its parents", sink, numReducers);
        }

        markUniformIfKeysMatchPartitions(desc);
        return false;
    }

    /** Sums the data size reported by statistics for the inputs feeding the sink's first child. */
    private long estimateInputBytes(ReduceSinkOperator sink) {
        long numberOfBytes = 0;
        List<Operator<? extends OperatorDesc>> siblings = sink.getChildOperators().get(0).getParentOperators();

        if (useOpStats) {
            for (Operator<? extends OperatorDesc> sibling : siblings) {
                if (sibling.getStatistics() != null) {
                    numberOfBytes = StatsUtils.safeAdd(numberOfBytes, sibling.getStatistics().getDataSize());
                    LOG.debug("Sibling {} has stats: {}", sibling, sibling.getStatistics());
                } else {
                    LOG.warn("No stats available from: {}", sibling);
                }
            }
            return numberOfBytes;
        }

        for (Operator<? extends OperatorDesc> sibling : siblings) {
            for (TableScanOperator source : OperatorUtils.findOperatorsUpstream(sibling, TableScanOperator.class)) {
                if (source.getStatistics() != null) {
                    numberOfBytes = StatsUtils.safeAdd(numberOfBytes, source.getStatistics().getDataSize());
                    LOG.debug("Table source {} has stats: {}", source, source.getStatistics());
                } else {
                    LOG.warn("No stats available from table source: {}", source);
                }
            }
        }
        LOG.debug("Gathered stats for sink {}. Total size is {} bytes.", sink, numberOfBytes);
        return numberOfBytes;
    }

    /** Turns a byte estimate into a reducer count, bounded below by Spark cores and above by maxReducers. */
    private int computeReducersFromStats(OptimizeSparkProcContext context, long numberOfBytes, int maxReducers)
            throws SemanticException {
        // NOTE(review): the configured bytes-per-reducer is halved before use; the reason is not
        // visible in this file.
        long bytesPerReducer = context.getConf().getLongVar(HiveConf.ConfVars.BYTESPERREDUCER) / 2;
        int numReducers = Utilities.estimateReducers(numberOfBytes, bytesPerReducer, maxReducers, false);

        getSparkMemoryAndCores(context);
        if (sparkMemoryAndCores != null) {
            long availableMemory = sparkMemoryAndCores.getFirst().longValue();
            int cores = sparkMemoryAndCores.getSecond().intValue();
            if (availableMemory > 0 && cores > 0) {
                // The ratio must be computed in floating point: with long division the quotient
                // is truncated, so "< 0.5" would only ever match a quotient of exactly zero.
                if ((double) availableMemory / bytesPerReducer < 0.5d) {
                    LOG.warn("Average load of a reducer is much larger than its available memory. Consider decreasing hive.exec.reducers.bytes.per.reducer");
                }
                numReducers = Math.max(numReducers, cores);
            }
        }
        return Math.min(numReducers, maxReducers);
    }

    /** Returns the largest reducer count configured on any of the given sinks (0 if none). */
    private static int maxParentReducers(Set<ReduceSinkOperator> parentSinks) {
        int numReducers = 0;
        for (ReduceSinkOperator parent : parentSinks) {
            numReducers = Math.max(numReducers, parent.getConf().getNumReducers());
        }
        return numReducers;
    }

    /** Flags the sink as UNIFORM when its key columns and partition columns compare equal. */
    private static void markUniformIfKeysMatchPartitions(ReduceSinkDesc desc) {
        Collection<ExprNodeDesc.ExprNodeDescEqualityWrapper> keyCols =
                ExprNodeDesc.ExprNodeDescEqualityWrapper.transform(desc.getKeyCols());
        Collection<ExprNodeDesc.ExprNodeDescEqualityWrapper> partitionCols =
                ExprNodeDesc.ExprNodeDescEqualityWrapper.transform(desc.getPartitionCols());
        if (keyCols != null && keyCols.equals(partitionCols)) {
            desc.setReducerTraits(EnumSet.of(ReduceSinkDesc.ReducerTraits.UNIFORM));
        }
    }

    /**
     * Tells whether this processor should (re)compute the reducer count of a sink.
     *
     * <p>A sink with no positive reducer count always needs one. A sink already fixed at exactly
     * one reducer is reconsidered only when it has an order-by, order-by sampling is enabled, it
     * is not deduplicated, and no {@link LimitOperator} is reachable below it (the search does not
     * descend past {@link TerminalOperator}s).
     *
     * @param reduceSinkOperator the sink to inspect
     * @param hiveConf           configuration providing {@code HIVESAMPLINGFORORDERBY}
     * @return {@code true} if the reducer count should be set by this processor
     */
    private boolean needSetParallelism(ReduceSinkOperator reduceSinkOperator, HiveConf hiveConf) {
        ReduceSinkDesc desc = reduceSinkOperator.getConf();
        if (desc.getNumReducers() <= 0) {
            return true;
        }

        boolean sampledSingleReducerOrderBy = desc.getNumReducers() == 1
                && desc.hasOrderBy()
                && hiveConf.getBoolVar(HiveConf.ConfVars.HIVESAMPLINGFORORDERBY)
                && !desc.isDeduplicated();
        if (!sampledSingleReducerOrderBy) {
            return false;
        }

        // Depth-first search below the sink for a limit operator.
        Stack<Operator<? extends OperatorDesc>> pending = new Stack<>();
        List<Operator<? extends OperatorDesc>> children = reduceSinkOperator.getChildOperators();
        if (children != null) {
            pending.addAll(children);
        }
        while (!pending.isEmpty()) {
            Operator<? extends OperatorDesc> current = pending.pop();
            if (current instanceof LimitOperator) {
                return false;
            }
            if (!(current instanceof TerminalOperator)) {
                List<Operator<? extends OperatorDesc>> grandChildren = current.getChildOperators();
                if (grandChildren != null) {
                    pending.addAll(grandChildren);
                }
            }
        }
        return true;
    }

    /**
     * Populates {@link #sparkMemoryAndCores} from a Spark session unless it is already cached or
     * {@code spark.dynamicAllocation.enabled} is set, in which case the field is left untouched.
     *
     * <p>The session is handed back to the session manager exactly once, whether or not the
     * lookup succeeded.
     *
     * @param optimizeSparkProcContext context supplying the Hive configuration
     * @throws SemanticException if a {@link HiveException} occurs while obtaining the session or
     *                           querying it; other exceptions are logged and swallowed so the
     *                           caller falls back to the statistics-only estimate
     */
    private void getSparkMemoryAndCores(OptimizeSparkProcContext optimizeSparkProcContext) throws SemanticException {
        if (this.sparkMemoryAndCores != null) {
            return;
        }
        if (optimizeSparkProcContext.getConf().getBoolean(SPARK_DYNAMIC_ALLOCATION_ENABLED, false)) {
            // Field is already null here; nothing to fetch when dynamic allocation is on.
            return;
        }

        SparkSessionManagerImpl sessionManager = null;
        SparkSession session = null;
        try {
            sessionManager = SparkSessionManagerImpl.getInstance();
            session = SparkUtilities.getSparkSession(optimizeSparkProcContext.getConf(), sessionManager);
            this.sparkMemoryAndCores = session.getMemoryAndCores();
        } catch (HiveException e) {
            // Must precede the generic handler, otherwise it would be shadowed by it.
            throw new SemanticException("Failed to get a spark session: " + e, e);
        } catch (Exception e) {
            LOG.warn("Failed to get spark memory/core info", (Throwable) e);
        } finally {
            if (session != null && sessionManager != null) {
                try {
                    sessionManager.returnSession(session);
                } catch (HiveException e) {
                    LOG.error("Failed to return the session to SessionManager: " + e, (Throwable) e);
                }
            }
        }
    }
}
