/*
 * Decompiled with CFR 0.152.
 */
package org.apache.hadoop.hive.ql.optimizer;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.Stack;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.hadoop.hive.ql.exec.JoinOperator;
import org.apache.hadoop.hive.ql.exec.Operator;
import org.apache.hadoop.hive.ql.exec.OperatorFactory;
import org.apache.hadoop.hive.ql.exec.ReduceSinkOperator;
import org.apache.hadoop.hive.ql.exec.RowSchema;
import org.apache.hadoop.hive.ql.exec.SelectOperator;
import org.apache.hadoop.hive.ql.exec.TableScanOperator;
import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
import org.apache.hadoop.hive.ql.lib.DefaultGraphWalker;
import org.apache.hadoop.hive.ql.lib.DefaultRuleDispatcher;
import org.apache.hadoop.hive.ql.lib.Node;
import org.apache.hadoop.hive.ql.lib.NodeProcessor;
import org.apache.hadoop.hive.ql.lib.NodeProcessorCtx;
import org.apache.hadoop.hive.ql.lib.Rule;
import org.apache.hadoop.hive.ql.lib.RuleRegExp;
import org.apache.hadoop.hive.ql.metadata.Table;
import org.apache.hadoop.hive.ql.optimizer.Transform;
import org.apache.hadoop.hive.ql.parse.ParseContext;
import org.apache.hadoop.hive.ql.parse.SemanticException;
import org.apache.hadoop.hive.ql.plan.ExprNodeColumnDesc;
import org.apache.hadoop.hive.ql.plan.ExprNodeConstantDesc;
import org.apache.hadoop.hive.ql.plan.ExprNodeDesc;
import org.apache.hadoop.hive.ql.plan.ExprNodeGenericFuncDesc;
import org.apache.hadoop.hive.ql.plan.FilterDesc;
import org.apache.hadoop.hive.ql.plan.JoinDesc;
import org.apache.hadoop.hive.ql.plan.OperatorDesc;
import org.apache.hadoop.hive.ql.plan.ReduceSinkDesc;
import org.apache.hadoop.hive.ql.plan.SelectDesc;
import org.apache.hadoop.hive.ql.plan.TableScanDesc;
import org.apache.hadoop.hive.ql.plan.UnionDesc;
import org.apache.hadoop.hive.ql.udf.generic.GenericUDFOPAnd;
import org.apache.hadoop.hive.ql.udf.generic.GenericUDFOPEqual;
import org.apache.hadoop.hive.ql.udf.generic.GenericUDFOPNot;
import org.apache.hadoop.hive.ql.udf.generic.GenericUDFOPOr;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorConverters;
import org.apache.hadoop.hive.serde2.typeinfo.TypeInfoFactory;
import org.apache.hadoop.hive.serde2.typeinfo.TypeInfoUtils;

/**
 * Rewrites joins whose inputs are read from tables that declare skewed columns in their metadata.
 *
 * <p>For each qualifying join, the join (together with its child, when that single child is a
 * {@link SelectOperator}) is cloned along with its input branches. A filter keeping only rows
 * whose join keys equal one of the declared skewed values is inserted above the original
 * branch's table scans, and the negation of that filter above the clone's table scans. The two
 * copies are then recombined through a Union followed by a Select, which takes over the join's
 * former children.
 */
public class SkewJoinOptimizer implements Transform {
    private static final Log LOG = LogFactory.getLog(SkewJoinOptimizer.class.getName());

    /**
     * Walks the operator graph from every top operator and applies {@link SkewJoinProc} to each
     * join matched by the pattern {@code TS%.*RS%JOIN%}.
     *
     * @param pctx the parse context whose operator graph is rewritten in place
     * @return the same {@code pctx} instance
     * @throws SemanticException if the graph walk fails
     */
    @Override
    public ParseContext transform(ParseContext pctx) throws SemanticException {
        Map<Rule, NodeProcessor> opRules = new LinkedHashMap<Rule, NodeProcessor>();
        opRules.put(new RuleRegExp("R1", "TS%.*RS%JOIN%"), getSkewJoinProc(pctx));
        SkewJoinOptProcCtx skewJoinOptProcCtx = new SkewJoinOptProcCtx(pctx);
        // Only rule R1 has a processor; no default processor is supplied.
        DefaultRuleDispatcher disp = new DefaultRuleDispatcher(null, opRules, skewJoinOptProcCtx);
        DefaultGraphWalker ogw = new DefaultGraphWalker(disp);

        List<Node> topNodes = new ArrayList<Node>(pctx.getTopOps().values());
        ogw.startWalking(topNodes, null);
        return pctx;
    }

    /** Creates the processor that rewrites a single skewed join. */
    private NodeProcessor getSkewJoinProc(ParseContext parseContext) {
        return new SkewJoinProc(parseContext);
    }

    /** State shared by all invocations of {@link SkewJoinProc} during one walk. */
    public static class SkewJoinOptProcCtx implements NodeProcessorCtx {
        private ParseContext pGraphContext;
        // Joins already handled; each join is rewritten at most once.
        private Set<JoinOperator> doneJoins;
        // Cloned TableScanOperator -> the original it was cloned from (filled by insertRowResolvers).
        private Map<TableScanOperator, TableScanOperator> cloneTSOpMap;

        public SkewJoinOptProcCtx(ParseContext pctx) {
            this.pGraphContext = pctx;
            this.doneJoins = new HashSet<JoinOperator>();
            this.cloneTSOpMap = new HashMap<TableScanOperator, TableScanOperator>();
        }

        public ParseContext getpGraphContext() {
            return pGraphContext;
        }

        public void setPGraphContext(ParseContext graphContext) {
            this.pGraphContext = graphContext;
        }

        public Set<JoinOperator> getDoneJoins() {
            return doneJoins;
        }

        public void setDoneJoins(Set<JoinOperator> doneJoins) {
            this.doneJoins = doneJoins;
        }

        public Map<TableScanOperator, TableScanOperator> getCloneTSOpMap() {
            return cloneTSOpMap;
        }

        public void setCloneTSOpMap(Map<TableScanOperator, TableScanOperator> cloneTSOpMap) {
            this.cloneTSOpMap = cloneTSOpMap;
        }
    }

    /** Rewrites one join into a skewed branch and a non-skewed branch joined by a Union. */
    public static class SkewJoinProc implements NodeProcessor {
        private ParseContext parseContext;

        public SkewJoinProc(ParseContext parseContext) {
            this.parseContext = parseContext;
        }

        /**
         * Applies the skew rewrite to the {@link JoinOperator} {@code nd}.
         *
         * <p>The method returns without changing the plan when the join was already handled, when
         * any operator between the join and its table scans does not support the optimization,
         * when no input table declares skewed values on a join key, or when the operator tree
         * cannot be cloned.
         *
         * @return always {@code null}
         */
        @Override
        public Object process(Node nd, Stack<Node> stack, NodeProcessorCtx procCtx, Object... nodeOutputs)
                throws SemanticException {
            SkewJoinOptProcCtx ctx = (SkewJoinOptProcCtx) procCtx;
            parseContext = ctx.getpGraphContext();

            JoinOperator joinOp = (JoinOperator) nd;
            // Set.add returns false when the join has already been handled.
            if (!ctx.getDoneJoins().add(joinOp)) {
                return null;
            }

            // If the join feeds exactly one Select, clone and union at the Select level instead.
            Operator<? extends OperatorDesc> currOp = joinOp;
            boolean processSelect = false;
            if (joinOp.getChildOperators().size() == 1
                    && joinOp.getChildOperators().get(0) instanceof SelectOperator) {
                currOp = joinOp.getChildOperators().get(0);
                processSelect = true;
            }

            List<TableScanOperator> tableScanOpsForJoin = new ArrayList<TableScanOperator>();
            if (!getTableScanOpsForJoin(joinOp, tableScanOpsForJoin) || tableScanOpsForJoin.isEmpty()) {
                return null;
            }

            Map<List<ExprNodeDesc>, List<List<String>>> skewedValues =
                    getSkewedValues(joinOp, tableScanOpsForJoin);
            if (skewedValues == null || skewedValues.isEmpty()) {
                return null;
            }

            Operator<? extends OperatorDesc> currOpClone;
            try {
                currOpClone = currOp.clone();
                insertRowResolvers(currOp, currOpClone, ctx);
            } catch (CloneNotSupportedException e) {
                LOG.debug("Operator tree could not be cloned");
                return null;
            }

            JoinOperator joinOpClone = processSelect
                    ? (JoinOperator) currOpClone.getParentOperators().get(0)
                    : (JoinOperator) currOpClone;

            List<TableScanOperator> tableScanCloneOpsForJoin = new ArrayList<TableScanOperator>();
            if (!getTableScanOpsForJoin(joinOpClone, tableScanCloneOpsForJoin)) {
                LOG.debug("Operator tree not properly cloned!");
                return null;
            }
            // Register the clone only after it is known to be usable; otherwise an abandoned
            // clone would remain in the parse context's join set.
            joinOpClone.getConf().cloneQBJoinTreeProps(joinOp.getConf());
            parseContext.getJoinOps().add(joinOpClone);

            // Original branch keeps the skewed keys; the clone keeps everything else.
            insertSkewFilter(tableScanOpsForJoin, skewedValues, true);
            insertSkewFilter(tableScanCloneOpsForJoin, skewedValues, false);

            // Register each cloned table scan under a fresh "subqueryN:<alias>" alias.
            Map<String, Operator<? extends OperatorDesc>> contextTopOps = parseContext.getTopOps();
            for (Operator<? extends OperatorDesc> topOp : getTopOps(joinOpClone).values()) {
                TableScanOperator tso = (TableScanOperator) topOp;
                String tabAlias = tso.getConf().getAlias();
                int initCnt = 1;
                String newAlias = "subquery" + initCnt + ":" + tabAlias;
                while (contextTopOps.containsKey(newAlias)) {
                    initCnt++;
                    newAlias = "subquery" + initCnt + ":" + tabAlias;
                }
                contextTopOps.put(newAlias, tso);
                setUpAlias(joinOp, joinOpClone, tabAlias, newAlias, tso);
            }

            // Detach both copies from the original children and recombine them:
            // currOp, currOpClone -> Union -> Select -> original children.
            List<Operator<? extends OperatorDesc>> finalOps = currOp.getChildOperators();
            currOp.setChildOperators(null);
            currOpClone.setChildOperators(null);

            List<Operator<? extends OperatorDesc>> oplist = new ArrayList<Operator<? extends OperatorDesc>>();
            oplist.add(currOp);
            oplist.add(currOpClone);
            Operator<UnionDesc> unionOp = OperatorFactory.getAndMakeChild(
                    new UnionDesc(), new RowSchema(currOp.getSchema().getSignature()), oplist);

            List<Operator<? extends OperatorDesc>> unionList = new ArrayList<Operator<? extends OperatorDesc>>();
            unionList.add(unionOp);
            Operator<SelectDesc> selectUnionOp = OperatorFactory.getAndMakeChild(
                    new SelectDesc(true), new RowSchema(unionOp.getSchema().getSignature()), unionList);

            selectUnionOp.setChildOperators(finalOps);
            for (Operator<? extends OperatorDesc> finalOp : finalOps) {
                finalOp.replaceParent(currOp, selectUnionOp);
            }
            return null;
        }

        /**
         * Collects the table scans feeding every parent of {@code op}.
         *
         * @return {@code false} if some operator on the way does not support the optimization
         */
        private boolean getTableScanOpsForJoin(JoinOperator op, List<TableScanOperator> tsOps) {
            for (Operator<? extends OperatorDesc> parent : op.getParentOperators()) {
                if (!getTableScanOps(parent, tsOps)) {
                    return false;
                }
            }
            return true;
        }

        /**
         * Recursively collects the table scans above {@code op} into {@code tsOps}.
         *
         * @return {@code false} as soon as an ancestor reports that it does not support the
         *     skew-join optimization
         */
        private boolean getTableScanOps(Operator<? extends OperatorDesc> op, List<TableScanOperator> tsOps) {
            for (Operator<? extends OperatorDesc> parent : op.getParentOperators()) {
                if (!parent.supportSkewJoinOptimization()) {
                    return false;
                }
                if (parent instanceof TableScanOperator) {
                    tsOps.add((TableScanOperator) parent);
                } else if (!getTableScanOps(parent, tsOps)) {
                    return false;
                }
            }
            return true;
        }

        /**
         * Determines, for the ReduceSink parents of {@code op}, which key columns are skewed and
         * the skewed values for those columns.
         *
         * @return a map from a list of key columns (with table alias cleared) to the distinct
         *     skewed value tuples for those columns. Each tuple is ordered like its key list. The
         *     map is empty when nothing is skewed.
         */
        private Map<List<ExprNodeDesc>, List<List<String>>> getSkewedValues(
                Operator<? extends OperatorDesc> op, List<TableScanOperator> tableScanOpsForJoin) {
            // Keyed by equality wrappers so that structurally equal key lists from different
            // ReduceSinks are merged.
            Map<List<ExprNodeDesc.ExprNodeDescEqualityWrapper>, List<List<String>>> skewData =
                    new HashMap<List<ExprNodeDesc.ExprNodeDescEqualityWrapper>, List<List<String>>>();

            for (Operator<? extends OperatorDesc> operator : op.getParentOperators()) {
                ReduceSinkDesc rsDesc = ((ReduceSinkOperator) operator).getConf();
                if (rsDesc.getKeyCols() == null) {
                    continue;
                }

                // Resolve the skew metadata once per ReduceSink. A table with no skewed columns
                // or no skewed values contributes nothing.
                Table table = getTable(parseContext, operator, tableScanOpsForJoin);
                List<String> skewedColumns = table == null ? null : table.getSkewedColNames();
                if (skewedColumns == null || skewedColumns.isEmpty()) {
                    continue;
                }
                List<List<String>> skewedValueList = table.getSkewedColValues();
                if (skewedValueList == null || skewedValueList.isEmpty()) {
                    continue;
                }

                List<ExprNodeDesc.ExprNodeDescEqualityWrapper> joinKeysSkewedCols =
                        new ArrayList<ExprNodeDesc.ExprNodeDescEqualityWrapper>();
                // Index into skewedColumns of each entry of joinKeysSkewedCols, in the same order.
                List<Integer> positionSkewedKeys = new ArrayList<Integer>();
                for (ExprNodeDesc keyColDesc : rsDesc.getKeyCols()) {
                    // Only plain column references can be matched against skewed column names.
                    if (!(keyColDesc instanceof ExprNodeColumnDesc)) {
                        continue;
                    }
                    ExprNodeColumnDesc keyCol = (ExprNodeColumnDesc) keyColDesc;
                    int pos = skewedColumns.indexOf(keyCol.getColumn());
                    if (pos < 0 || positionSkewedKeys.contains(pos)) {
                        continue;
                    }
                    positionSkewedKeys.add(pos);
                    ExprNodeColumnDesc keyColClone = (ExprNodeColumnDesc) keyCol.clone();
                    keyColClone.setTabAlias(null);
                    joinKeysSkewedCols.add(new ExprNodeDesc.ExprNodeDescEqualityWrapper(keyColClone));
                }
                if (joinKeysSkewedCols.isEmpty()) {
                    continue;
                }

                // Always project: even when every skewed column is a join key, the join keys may
                // appear in a different order than the table's skewed-column declaration, and
                // the value tuples must follow the key order.
                List<List<String>> skewedJoinValues = getSkewedJoinValues(skewedValueList, positionSkewedKeys);

                List<List<String>> merged = skewData.get(joinKeysSkewedCols);
                if (merged == null) {
                    merged = new ArrayList<List<String>>();
                    skewData.put(joinKeysSkewedCols, merged);
                }
                for (List<String> skewValue : skewedJoinValues) {
                    if (!merged.contains(skewValue)) {
                        merged.add(skewValue);
                    }
                }
            }

            Map<List<ExprNodeDesc>, List<List<String>>> skewDataReturn =
                    new HashMap<List<ExprNodeDesc>, List<List<String>>>();
            for (Map.Entry<List<ExprNodeDesc.ExprNodeDescEqualityWrapper>, List<List<String>>> entry
                    : skewData.entrySet()) {
                List<ExprNodeDesc> skewedKeyJoinCols = new ArrayList<ExprNodeDesc>();
                for (ExprNodeDesc.ExprNodeDescEqualityWrapper key : entry.getKey()) {
                    skewedKeyJoinCols.add(key.getExprNodeDesc());
                }
                skewDataReturn.put(skewedKeyJoinCols, entry.getValue());
            }
            return skewDataReturn;
        }

        /**
         * Follows single-parent links upward from {@code op} until reaching one of
         * {@code tableScanOpsForJoin}, and returns that scan's table metadata.
         *
         * @return the table, or {@code null} if the chain ends or branches (zero or more than one
         *     parent) before a matching table scan is found
         */
        private Table getTable(ParseContext parseContext, Operator<? extends OperatorDesc> op,
                List<TableScanOperator> tableScanOpsForJoin) {
            Operator<? extends OperatorDesc> current = op;
            while (true) {
                if (current instanceof TableScanOperator && tableScanOpsForJoin.contains(current)) {
                    return ((TableScanOperator) current).getConf().getTableMetadata();
                }
                List<? extends Operator<? extends OperatorDesc>> parents = current.getParentOperators();
                if (parents == null || parents.size() != 1) {
                    return null;
                }
                current = parents.get(0);
            }
        }

        /**
         * Projects every full skewed-value tuple onto the given column positions, in the order
         * the positions are listed.
         */
        private List<List<String>> getSkewedJoinValues(List<List<String>> skewedValueList,
                List<Integer> positionSkewedKeys) {
            List<List<String>> skewedJoinValues = new ArrayList<List<String>>();
            for (List<String> skewedValuesAllColumns : skewedValueList) {
                List<String> skewedValuesSpecifiedColumns = new ArrayList<String>();
                for (int pos : positionSkewedKeys) {
                    skewedValuesSpecifiedColumns.add(skewedValuesAllColumns.get(pos));
                }
                skewedJoinValues.add(skewedValuesSpecifiedColumns);
            }
            return skewedJoinValues;
        }

        /**
         * Inserts a filter directly below each given table scan.
         *
         * @param skewed {@code true} keeps rows matching a skewed value; {@code false} keeps the
         *     complement
         */
        private void insertSkewFilter(List<TableScanOperator> tableScanOpsForJoin,
                Map<List<ExprNodeDesc>, List<List<String>>> skewedValuesList, boolean skewed) {
            ExprNodeDesc filterExpr = constructFilterExpr(skewedValuesList, skewed);
            for (TableScanOperator tableScanOp : tableScanOpsForJoin) {
                insertFilterOnTop(tableScanOp, filterExpr);
            }
        }

        /**
         * Places a Filter operator between {@code tableScanOp} and its first child.
         *
         * <p>NOTE(review): only the first child is re-attached, and the scan's child list is
         * replaced by the filter alone. This assumes the scan has exactly one child; confirm that
         * {@code supportSkewJoinOptimization} excludes scans with multiple children.
         */
        private void insertFilterOnTop(TableScanOperator tableScanOp, ExprNodeDesc filterExpr) {
            Operator<? extends OperatorDesc> currChild = tableScanOp.getChildOperators().get(0);
            tableScanOp.setChildOperators(null);
            currChild.setParentOperators(null);
            Operator<FilterDesc> filter = OperatorFactory.getAndMakeChild(
                    new FilterDesc(filterExpr, false),
                    new RowSchema(tableScanOp.getSchema().getSignature()),
                    tableScanOp);
            OperatorFactory.makeChild(filter, currChild);
        }

        /**
         * Builds {@code OR} over every skewed tuple of {@code AND(col_i = value_i)}, negated with
         * {@code NOT} when {@code skewed} is {@code false}.
         *
         * @return the predicate, or {@code null} if {@code skewedValuesMap} holds no values
         */
        private ExprNodeDesc constructFilterExpr(Map<List<ExprNodeDesc>, List<List<String>>> skewedValuesMap,
                boolean skewed) {
            ExprNodeGenericFuncDesc finalExprNodeDesc = null;
            try {
                for (Map.Entry<List<ExprNodeDesc>, List<List<String>>> mapEntry : skewedValuesMap.entrySet()) {
                    List<ExprNodeDesc> keyCols = mapEntry.getKey();
                    for (List<String> skewedValues : mapEntry.getValue()) {
                        // Conjunction col_0 = v_0 AND col_1 = v_1 ... for this tuple.
                        ExprNodeGenericFuncDesc currExprNodeDesc = null;
                        int keyPos = 0;
                        for (String skewedValue : skewedValues) {
                            ExprNodeColumnDesc keyCol = (ExprNodeColumnDesc) keyCols.get(keyPos++).clone();
                            List<ExprNodeDesc> children = new ArrayList<ExprNodeDesc>();
                            children.add(keyCol);
                            children.add(createConstDesc(skewedValue, keyCol));
                            ExprNodeGenericFuncDesc expr =
                                    ExprNodeGenericFuncDesc.newInstance(new GenericUDFOPEqual(), children);
                            if (currExprNodeDesc == null) {
                                currExprNodeDesc = expr;
                            } else {
                                List<ExprNodeDesc> childrenAND = new ArrayList<ExprNodeDesc>();
                                childrenAND.add(currExprNodeDesc);
                                childrenAND.add(expr);
                                currExprNodeDesc =
                                        ExprNodeGenericFuncDesc.newInstance(new GenericUDFOPAnd(), childrenAND);
                            }
                        }
                        if (finalExprNodeDesc == null) {
                            finalExprNodeDesc = currExprNodeDesc;
                        } else {
                            List<ExprNodeDesc> childrenOR = new ArrayList<ExprNodeDesc>();
                            childrenOR.add(finalExprNodeDesc);
                            childrenOR.add(currExprNodeDesc);
                            finalExprNodeDesc = ExprNodeGenericFuncDesc.newInstance(new GenericUDFOPOr(), childrenOR);
                        }
                    }
                }
                if (!skewed) {
                    List<ExprNodeDesc> childrenNOT = new ArrayList<ExprNodeDesc>();
                    childrenNOT.add(finalExprNodeDesc);
                    finalExprNodeDesc = ExprNodeGenericFuncDesc.newInstance(new GenericUDFOPNot(), childrenNOT);
                }
            } catch (UDFArgumentException e) {
                // Not expected: each comparison is between a column and a constant already
                // converted to that column's type.
                assert false : e;
            }
            return finalExprNodeDesc;
        }

        /**
         * Converts the metadata string {@code skewedValue} to a constant of {@code keyCol}'s type.
         *
         * <p>NOTE(review): a value the converter cannot parse presumably becomes a {@code null}
         * constant, making the equality unknown for every row; confirm the metadata is type-consistent.
         */
        private ExprNodeConstantDesc createConstDesc(String skewedValue, ExprNodeColumnDesc keyCol) {
            ObjectInspector inputOI =
                    TypeInfoUtils.getStandardJavaObjectInspectorFromTypeInfo(TypeInfoFactory.stringTypeInfo);
            ObjectInspector outputOI = TypeInfoUtils.getStandardJavaObjectInspectorFromTypeInfo(keyCol.getTypeInfo());
            ObjectInspectorConverters.Converter converter = ObjectInspectorConverters.getConverter(inputOI, outputOI);
            Object skewedValueObject = converter.convert(skewedValue);
            return new ExprNodeConstantDesc(keyCol.getTypeInfo(), skewedValueObject);
        }

        /**
         * Collects the root operators above {@code op}, keyed by their table-scan alias. A
         * parentless operator is assumed to be a {@link TableScanOperator}.
         */
        private Map<String, Operator<? extends OperatorDesc>> getTopOps(Operator<? extends OperatorDesc> op) {
            Map<String, Operator<? extends OperatorDesc>> topOps =
                    new LinkedHashMap<String, Operator<? extends OperatorDesc>>();
            List<? extends Operator<? extends OperatorDesc>> parents = op.getParentOperators();
            if (parents == null || parents.isEmpty()) {
                topOps.put(((TableScanOperator) op).getConf().getAlias(), op);
            } else {
                for (Operator<? extends OperatorDesc> parent : parents) {
                    if (parent != null) {
                        topOps.putAll(getTopOps(parent));
                    }
                }
            }
            return topOps;
        }

        /**
         * Walks the original and cloned trees upward in lock-step, recording each cloned
         * {@link TableScanOperator} against its original in the context's clone map.
         */
        private void insertRowResolvers(Operator<? extends OperatorDesc> op,
                Operator<? extends OperatorDesc> opClone, SkewJoinOptProcCtx ctx) {
            if (op instanceof TableScanOperator) {
                ctx.getCloneTSOpMap().put((TableScanOperator) opClone, (TableScanOperator) op);
            }
            List<? extends Operator<? extends OperatorDesc>> parents = op.getParentOperators();
            List<? extends Operator<? extends OperatorDesc>> parentClones = opClone.getParentOperators();
            if (parents != null && !parents.isEmpty() && parentClones != null && !parentClones.isEmpty()) {
                for (int pos = 0; pos < parents.size(); ++pos) {
                    insertRowResolvers(parents.get(pos), parentClones.get(pos), ctx);
                }
            }
        }

        /**
         * Renames {@code origAlias} to {@code newAlias} in every alias field of the cloned join's
         * descriptor, using the original join's descriptor to locate the positions to change.
         */
        private static void setUpAlias(JoinOperator origin, JoinOperator cloned, String origAlias,
                String newAlias, Operator<? extends OperatorDesc> topOp) {
            JoinDesc originDesc = origin.getConf();
            JoinDesc clonedDesc = cloned.getConf();
            clonedDesc.getAliasToOpInfo().remove(origAlias);
            clonedDesc.getAliasToOpInfo().put(newAlias, topOp);
            // Null-safe: the left alias may be unset.
            if (Objects.equals(originDesc.getLeftAlias(), origAlias)) {
                clonedDesc.setLeftAlias(newAlias);
            }
            replaceAlias(originDesc.getLeftAliases(), clonedDesc.getLeftAliases(), origAlias, newAlias);
            replaceAlias(originDesc.getRightAliases(), clonedDesc.getRightAliases(), origAlias, newAlias);
            replaceAlias(originDesc.getBaseSrc(), clonedDesc.getBaseSrc(), origAlias, newAlias);
            replaceAlias(originDesc.getMapAliases(), clonedDesc.getMapAliases(), origAlias, newAlias);
            replaceAlias(originDesc.getStreamAliases(), clonedDesc.getStreamAliases(), origAlias, newAlias);
        }

        /**
         * Sets {@code cloned[i] = newAlias} wherever {@code origin[i]} equals {@code alias}. Does
         * nothing if either array is null or the lengths differ.
         */
        private static void replaceAlias(String[] origin, String[] cloned, String alias, String newAlias) {
            if (origin == null || cloned == null || origin.length != cloned.length) {
                return;
            }
            for (int i = 0; i < origin.length; ++i) {
                if (Objects.equals(origin[i], alias)) {
                    cloned[i] = newAlias;
                }
            }
        }

        /**
         * List counterpart of {@link #replaceAlias(String[], String[], String, String)}.
         */
        private static void replaceAlias(List<String> origin, List<String> cloned, String alias, String newAlias) {
            if (origin == null || cloned == null || origin.size() != cloned.size()) {
                return;
            }
            for (int i = 0; i < origin.size(); ++i) {
                if (Objects.equals(origin.get(i), alias)) {
                    cloned.set(i, newAlias);
                }
            }
        }
    }
}

