package org.apache.hudi.org.apache.hadoop.hive.ql.optimizer.spark;

import com.google.common.collect.Maps;
import com.google.common.collect.Sets;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.HashSet;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.Stack;
import java.util.TreeSet;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.hive.ql.exec.MapJoinOperator;
import org.apache.hadoop.hive.ql.exec.Operator;
import org.apache.hadoop.hive.ql.exec.spark.SparkTask;
import org.apache.hadoop.hive.ql.lib.Dispatcher;
import org.apache.hadoop.hive.ql.lib.Node;
import org.apache.hadoop.hive.ql.lib.TaskGraphWalker;
import org.apache.hadoop.hive.ql.parse.SemanticException;
import org.apache.hadoop.hive.ql.plan.BaseWork;
import org.apache.hadoop.hive.ql.plan.MapJoinDesc;
import org.apache.hadoop.hive.ql.plan.MapWork;
import org.apache.hadoop.hive.ql.plan.OperatorDesc;
import org.apache.hadoop.hive.ql.plan.PartitionDesc;
import org.apache.hadoop.hive.ql.plan.SparkEdgeProperty;
import org.apache.hadoop.hive.ql.plan.SparkWork;
import org.apache.hudi.org.apache.hadoop.hive.ql.optimizer.OperatorComparatorFactory;
import org.apache.hudi.org.apache.hadoop.hive.ql.optimizer.physical.PhysicalContext;
import org.apache.hudi.org.apache.hadoop.hive.ql.optimizer.physical.PhysicalPlanResolver;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/apache/hudi/org/apache/hadoop/hive/ql/optimizer/spark/CombineEquivalentWorkResolver.class */
/**
 * Physical plan resolver that merges equivalent {@link BaseWork}s inside the {@link SparkWork} of
 * every {@link SparkTask}.
 *
 * <p>Starting at the roots of each work graph, a set of works is partitioned into groups of works
 * that {@code compareWork} considers equivalent. In every group, all works except the one with the
 * smallest name are removed and their child edges are re-attached to that surviving work. The
 * comparison then continues with the children of each work that was not removed.
 */
public class CombineEquivalentWorkResolver implements PhysicalPlanResolver {
    protected static transient Logger LOG = LoggerFactory.getLogger(CombineEquivalentWorkResolver.class);

    /**
     * Dispatcher called by the {@link TaskGraphWalker} for each task; only {@link SparkTask}s are
     * processed, all other nodes are ignored.
     */
    class EquivalentWorkMatcher implements Dispatcher {
        /** Orders works by name, so the surviving work of a group is always the smallest-named one. */
        private final Comparator<BaseWork> baseWorkComparator = new Comparator<BaseWork>() {
            @Override
            public int compare(BaseWork first, BaseWork second) {
                return first.getName().compareTo(second.getName());
            }
        };

        EquivalentWorkMatcher() {
        }

        /**
         * Combines equivalent works of a {@link SparkTask}'s work graph; always returns {@code null}.
         */
        @Override
        public Object dispatch(Node node, Stack<Node> stack, Object... nodeOutputs) throws SemanticException {
            if (node instanceof SparkTask) {
                SparkWork sparkWork = ((SparkTask) node).getWork();
                compareWorksRecursively(sparkWork.getRoots(), sparkWork);
            }
            return null;
        }

        /**
         * Combines equivalent works among {@code works}, then recurses into the children of every
         * work that survived the combination.
         */
        private void compareWorksRecursively(Set<BaseWork> works, SparkWork sparkWork) {
            List<Set<BaseWork>> equivalentGroups = compareChildWorks(works, sparkWork);
            Set<BaseWork> removedWorks = combineEquivalentWorks(equivalentGroups, sparkWork);
            for (BaseWork work : works) {
                if (removedWorks.contains(work)) {
                    continue;
                }
                Set<BaseWork> children = Sets.newHashSet(sparkWork.getChildren(work));
                if (!children.isEmpty()) {
                    compareWorksRecursively(children, sparkWork);
                }
            }
        }

        /**
         * Partitions {@code works} into groups of equivalent works. A work joins the first group
         * whose smallest-named member it matches; otherwise it starts a new group. Returns an empty
         * list when there are fewer than two works, since nothing can be combined then.
         */
        private List<Set<BaseWork>> compareChildWorks(Set<BaseWork> works, SparkWork sparkWork) {
            // Groups are kept in a List rather than a HashSet: each group is mutated after it has
            // been recorded, which would silently invalidate its hash code inside a HashSet.
            List<Set<BaseWork>> groups = new ArrayList<Set<BaseWork>>();
            if (works.size() < 2) {
                return groups;
            }
            for (BaseWork work : works) {
                Set<BaseWork> matchingGroup = null;
                for (Set<BaseWork> group : groups) {
                    if (belongToSet(group, work, sparkWork)) {
                        matchingGroup = group;
                        break;
                    }
                }
                if (matchingGroup == null) {
                    // Sorted by name so the surviving work is chosen consistently across runs.
                    matchingGroup = Sets.newTreeSet(baseWorkComparator);
                    groups.add(matchingGroup);
                }
                matchingGroup.add(work);
            }
            return groups;
        }

        /** Returns whether {@code work} is equivalent to the first member of {@code group} (or the group is empty). */
        private boolean belongToSet(Set<BaseWork> group, BaseWork work, SparkWork sparkWork) {
            return group.isEmpty() || compareWork(group.iterator().next(), work, sparkWork);
        }

        /**
         * Replaces every non-first member of each group with the group's first member.
         *
         * @return the works that were removed from {@code sparkWork}
         */
        private Set<BaseWork> combineEquivalentWorks(List<Set<BaseWork>> groups, SparkWork sparkWork) {
            Set<BaseWork> removedWorks = Sets.newHashSet();
            for (Set<BaseWork> group : groups) {
                if (group.size() < 2) {
                    continue;
                }
                Iterator<BaseWork> members = group.iterator();
                BaseWork survivor = members.next();
                while (members.hasNext()) {
                    BaseWork duplicate = members.next();
                    replaceWork(duplicate, survivor, sparkWork);
                    removedWorks.add(duplicate);
                }
            }
            return removedWorks;
        }

        /**
         * Removes {@code previous} from {@code sparkWork}: redirects MapJoin input references to
         * {@code current}, drops its parent edges, moves its child edges (with their edge
         * properties) onto {@code current}, and finally removes the work itself.
         */
        private void replaceWork(BaseWork previous, BaseWork current, SparkWork sparkWork) {
            updateReference(previous, current, sparkWork);
            List<BaseWork> parents = sparkWork.getParents(previous);
            List<BaseWork> children = sparkWork.getChildren(previous);
            if (parents != null) {
                // Works are only grouped after hasSameParent succeeds, so these parent edges are
                // simply dropped rather than re-created on current.
                for (BaseWork parent : parents) {
                    sparkWork.disconnect(parent, previous);
                }
            }
            if (children != null) {
                for (BaseWork child : children) {
                    SparkEdgeProperty edgeProperty = sparkWork.getEdgeProperty(previous, child);
                    sparkWork.disconnect(previous, child);
                    sparkWork.connect(current, child, edgeProperty);
                }
            }
            sparkWork.remove(previous);
        }

        /**
         * Rewrites every MapJoin parent-to-input entry, in every work of {@code sparkWork}, that
         * names {@code previous} so that it names {@code current} instead.
         */
        private void updateReference(BaseWork previous, BaseWork current, SparkWork sparkWork) {
            String previousName = previous.getName();
            String currentName = current.getName();
            for (BaseWork work : sparkWork.getAllWork()) {
                for (Operator<?> operator : work.getAllOperators()) {
                    if (!(operator instanceof MapJoinOperator)) {
                        continue;
                    }
                    Map<Integer, String> parentToInput =
                            ((MapJoinDesc) ((MapJoinOperator) operator).getConf()).getParentToInput();
                    for (Map.Entry<Integer, String> entry : parentToInput.entrySet()) {
                        String input = entry.getValue();
                        // Null-safe: an entry without an input name is left untouched instead of
                        // throwing a NullPointerException.
                        if (input != null && input.equals(previousName)) {
                            entry.setValue(currentName);
                        }
                    }
                }
            }
        }

        /**
         * Returns whether two works are equivalent: same class, same parents, not both leaves,
         * equal path/partition info for MapWorks, and pairwise-equal root operator chains.
         */
        private boolean compareWork(BaseWork first, BaseWork second, SparkWork sparkWork) {
            if (!first.getClass().getName().equals(second.getClass().getName())
                    || !hasSameParent(first, second, sparkWork)) {
                return false;
            }
            // Two leaf works are never combined. NOTE(review): presumably because a leaf's output
            // may be read outside this SparkWork — confirm.
            if (sparkWork.getLeaves().contains(first) && sparkWork.getLeaves().contains(second)) {
                return false;
            }
            // Same class name was checked above, so the MapWork cast of second is safe.
            if (first instanceof MapWork && !compareMapWork((MapWork) first, (MapWork) second)) {
                return false;
            }
            Set<Operator<? extends OperatorDesc>> firstRoots = first.getAllRootOperators();
            Set<Operator<? extends OperatorDesc>> secondRoots = second.getAllRootOperators();
            if (firstRoots.size() != secondRoots.size()) {
                return false;
            }
            // Root operators are paired positionally by iteration order. NOTE(review): this assumes
            // getAllRootOperators returns both sets in corresponding order — confirm.
            Iterator<Operator<? extends OperatorDesc>> firstIterator = firstRoots.iterator();
            Iterator<Operator<? extends OperatorDesc>> secondIterator = secondRoots.iterator();
            while (firstIterator.hasNext()) {
                if (!compareOperatorChain(firstIterator.next(), secondIterator.next())) {
                    return false;
                }
            }
            return true;
        }

        /** Returns whether both MapWorks map the same paths to equal partition descriptors. */
        private boolean compareMapWork(MapWork first, MapWork second) {
            // Map.equals checks size and every key/value pair (order-insensitive) and, unlike a
            // manual value.equals(...) loop, does not throw on a null PartitionDesc.
            Map<Path, PartitionDesc> firstPaths = first.getPathToPartitionInfo();
            Map<Path, PartitionDesc> secondPaths = second.getPathToPartitionInfo();
            return firstPaths.equals(secondPaths);
        }

        /** Returns whether both works have parent lists of equal size with every parent of {@code first} present in {@code second}'s. */
        private boolean hasSameParent(BaseWork first, BaseWork second, SparkWork sparkWork) {
            List<BaseWork> firstParents = sparkWork.getParents(first);
            List<BaseWork> secondParents = sparkWork.getParents(second);
            return firstParents.size() == secondParents.size() && secondParents.containsAll(firstParents);
        }

        /**
         * Recursively compares two operator trees: the operators themselves, then their child
         * lists position by position.
         */
        private boolean compareOperatorChain(Operator<?> first, Operator<?> second) {
            if (!compareCurrentOperator(first, second)) {
                return false;
            }
            List<Operator<? extends OperatorDesc>> firstChildren = first.getChildOperators();
            List<Operator<? extends OperatorDesc>> secondChildren = second.getChildOperators();
            if (firstChildren == null || secondChildren == null) {
                // Equal only when neither operator has a child list.
                return firstChildren == secondChildren;
            }
            if (firstChildren.size() != secondChildren.size()) {
                return false;
            }
            for (int i = 0; i < firstChildren.size(); i++) {
                if (!compareOperatorChain(firstChildren.get(i), secondChildren.get(i))) {
                    return false;
                }
            }
            return true;
        }

        /** Compares two single operators of the same class using the registered operator comparator. */
        private boolean compareCurrentOperator(Operator<?> first, Operator<?> second) {
            if (!first.getClass().getName().equals(second.getClass().getName())) {
                return false;
            }
            return OperatorComparatorFactory.getOperatorComparator(first.getClass()).equals(first, second);
        }
    }

    /**
     * Walks all root tasks of {@code physicalContext} and combines equivalent works in each Spark
     * task; the context itself is returned unchanged.
     */
    @Override
    public PhysicalContext resolve(PhysicalContext physicalContext) throws SemanticException {
        List<Node> topNodes = new ArrayList<Node>(physicalContext.getRootTasks());
        new TaskGraphWalker(new EquivalentWorkMatcher()).startWalking(topNodes, Maps.<Node, Object>newHashMap());
        return physicalContext;
    }
}
