package org.apache.nemo.compiler.optimizer.pass.runtime;

import java.util.ArrayList;
import java.util.Comparator;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;
import org.apache.nemo.common.HashRange;
import org.apache.nemo.common.Pair;
import org.apache.nemo.common.ir.IRDAG;
import org.apache.nemo.common.ir.edge.IREdge;
import org.apache.nemo.common.ir.edge.executionproperty.PartitionSetProperty;
import org.apache.nemo.common.ir.edge.executionproperty.PartitionerProperty;
import org.apache.nemo.common.ir.executionproperty.EdgeExecutionProperty;
import org.apache.nemo.common.ir.executionproperty.VertexExecutionProperty;
import org.apache.nemo.common.ir.vertex.executionproperty.ParallelismProperty;
import org.apache.nemo.common.ir.vertex.executionproperty.ResourceAntiAffinityProperty;
import org.apache.nemo.common.partitioner.HashPartitioner;
import org.apache.nemo.common.partitioner.Partitioner;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/apache/nemo/compiler/optimizer/pass/runtime/SkewRunTimePass.class */
/**
 * Run-time pass that rebalances a skewed shuffle.
 *
 * <p>The message value maps each key to a size (a {@code Long}). These sizes are summed per hash
 * partition. The partitions are then cut into contiguous key ranges, one per destination task,
 * so that every task receives roughly the same total size. Tasks whose range contains one of the
 * {@code numSkewedKeys} largest partition sizes are recorded in a
 * {@link ResourceAntiAffinityProperty}.
 */
public final class SkewRunTimePass extends RunTimePass<Map<Object, Long>> {
    private static final Logger LOG = LoggerFactory.getLogger(SkewRunTimePass.class.getName());
    private static final int DEFAULT_NUM_SKEWED_TASKS = 1;

    /** Number of largest partition sizes that are treated as skewed. */
    private final int numSkewedKeys;

    /**
     * Creates a pass that treats only the single largest partition size as skewed.
     */
    public SkewRunTimePass() {
        this(DEFAULT_NUM_SKEWED_TASKS);
    }

    /**
     * Creates a pass that treats the given number of largest partition sizes as skewed.
     *
     * @param numOfSkewedKeys how many of the largest partition sizes are treated as skewed.
     */
    public SkewRunTimePass(final int numOfSkewedKeys) {
        this.numSkewedKeys = numOfSkewedKeys;
    }

    /**
     * Rewrites the examined edges of the message so that their destination tasks read balanced key ranges.
     *
     * <p>Each examined edge permanently receives the computed {@link PartitionSetProperty}. Its destination
     * vertex permanently receives the computed {@link ResourceAntiAffinityProperty}. Every descendant of that
     * vertex also receives the {@link ResourceAntiAffinityProperty}, via a non-permanent {@code setProperty}.
     *
     * @param irdag   the DAG; it is returned after the examined edges and vertices are updated.
     * @param message the run-time message carrying the examined edges and the per-key sizes.
     * @return {@code irdag}.
     * @throws IllegalStateException if the representative edge lacks a partitioner property or its
     *                               destination lacks a parallelism property.
     */
    @Override
    public IRDAG apply(final IRDAG irdag, final Message<Map<Object, Long>> message) {
        final Set<IREdge> examinedEdges = message.getExaminedEdges();
        LOG.info("Examined edges {}", examinedEdges.stream().map(IREdge::getId).collect(Collectors.toList()));

        // NOTE(review): only the first examined edge is consulted. This presumes all examined edges share
        // the same partitioner and destination parallelism settings.
        final IREdge representativeEdge = examinedEdges.iterator().next();
        final Pair<?, Integer> partitionerProperty = representativeEdge.getPropertyValue(PartitionerProperty.class)
            .orElseThrow(IllegalStateException::new);
        final int dstParallelism = representativeEdge.getDst().getPropertyValue(ParallelismProperty.class)
            .orElseThrow(IllegalStateException::new);
        // A partition count of 0 is replaced by the destination parallelism.
        final int numOfPartitions = partitionerProperty.right() == 0 ? dstParallelism : partitionerProperty.right();

        // The cast assumes the edge is hash-partitioned; any other partitioner fails here.
        final HashPartitioner partitioner = (HashPartitioner) Partitioner.getPartitioner(
            representativeEdge.getExecutionProperties(), representativeEdge.getDst().getExecutionProperties());

        final Pair<PartitionSetProperty, ResourceAntiAffinityProperty> analysis =
            analyzeMessage(message.getMessageValue(), partitioner, numOfPartitions, dstParallelism);
        LOG.info("Result of analysis: {}", analysis);

        final EdgeExecutionProperty<?> partitionSet = analysis.left();
        final VertexExecutionProperty<?> antiAffinity = analysis.right();
        examinedEdges.forEach(edge -> {
            edge.setPropertyPermanently(partitionSet);
            edge.getDst().setPropertyPermanently(antiAffinity);
            irdag.getDescendants(edge.getDst().getId())
                .forEach(descendant -> descendant.setProperty(antiAffinity));
        });
        return irdag;
    }

    /**
     * Splits hash partitions into one contiguous key range per destination task.
     *
     * <p>Each range is chosen so that each task receives roughly the same total size. The method also
     * records which tasks receive a skewed partition.
     *
     * @param keyToCountMap   size observed for each key; sizes of keys mapping to the same partition are summed.
     * @param partitioner     maps each key to its partition index.
     * @param numOfPartitions number of hash partitions; indices {@code 0 .. numOfPartitions - 1} are split.
     * @param dstParallelism  number of destination tasks (one key range is produced per task).
     * @return the per-task key ranges (position = task index), and the indices of tasks whose range contains a
     *         partition whose size is among the {@code numSkewedKeys} largest.
     */
    Pair<PartitionSetProperty, ResourceAntiAffinityProperty> analyzeMessage(final Map<Object, Long> keyToCountMap,
                                                                           final HashPartitioner partitioner,
                                                                           final int numOfPartitions,
                                                                           final int dstParallelism) {
        final int lastKey = numOfPartitions - 1;

        // Aggregate the per-key sizes into per-partition sizes.
        final Map<Integer, Long> partitionKeyToSize = new HashMap<>();
        for (final Map.Entry<Object, Long> entry : keyToCountMap.entrySet()) {
            final int partitionKey = partitioner.partition(entry.getKey());
            partitionKeyToSize.merge(partitionKey, entry.getValue(), Long::sum);
        }

        // Dense list of partition sizes; a partition that received no key has size 0.
        final List<Long> partitionSizeList = new ArrayList<>(numOfPartitions);
        for (int partitionKey = 0; partitionKey <= lastKey; partitionKey++) {
            partitionSizeList.add(partitionKeyToSize.getOrDefault(partitionKey, 0L));
        }

        final List<Long> topNSizes = getTopNLargeKeySizes(partitionSizeList);
        LOG.info("Top {} sizes: {}", numSkewedKeys, topNSizes);

        // Ideal (floored) size per destination task.
        final long totalSize = partitionSizeList.stream().mapToLong(Long::longValue).sum();
        final long idealSizePerTask = totalSize / dstParallelism;

        final List<HashRange> keyRanges = new ArrayList<>();
        final HashSet<Integer> skewedTaskIndices = new HashSet<>();
        int startingKey = 0;   // inclusive start of the range being built
        int finishingKey = 1;  // exclusive end of the range being built
        long currentAccumulatedSize = partitionSizeList.get(0);
        long prevAccumulatedSize = 0L;
        for (int taskIndex = 0; taskIndex < dstParallelism; taskIndex++) {
            if (taskIndex < dstParallelism - 1) {
                final long idealAccumulatedSize = idealSizePerTask * (taskIndex + 1);
                // Extend the range until the accumulated size reaches the ideal accumulated size.
                while (currentAccumulatedSize < idealAccumulatedSize) {
                    currentAccumulatedSize += partitionSizeList.get(finishingKey);
                    finishingKey++;
                }
                // Step back one partition if that leaves the accumulated size closer to the ideal.
                final long overshoot = currentAccumulatedSize - idealAccumulatedSize;
                final long undershootOneStepBack =
                    idealAccumulatedSize - (currentAccumulatedSize - partitionSizeList.get(finishingKey - 1));
                if (overshoot > undershootOneStepBack) {
                    finishingKey--;
                    currentAccumulatedSize -= partitionSizeList.get(finishingKey);
                }

                if (containsSkewedSize(partitionSizeList, topNSizes, startingKey, finishingKey)) {
                    skewedTaskIndices.add(taskIndex);
                }
                keyRanges.add(HashRange.of(startingKey, finishingKey));
                LOG.debug("KeyRange {}~{}, Size {}", startingKey, finishingKey - 1,
                    currentAccumulatedSize - prevAccumulatedSize);
                prevAccumulatedSize = currentAccumulatedSize;
                startingKey = finishingKey;
            } else {
                // The last task takes every remaining partition, so check that whole range for skew
                // (startingKey == finishingKey here, so [startingKey, finishingKey) would be empty).
                if (containsSkewedSize(partitionSizeList, topNSizes, startingKey, lastKey + 1)) {
                    skewedTaskIndices.add(taskIndex);
                }
                keyRanges.add(HashRange.of(startingKey, lastKey + 1));
                while (finishingKey <= lastKey) {
                    currentAccumulatedSize += partitionSizeList.get(finishingKey);
                    finishingKey++;
                }
                LOG.debug("KeyRange {}~{}, Size {}", startingKey, lastKey,
                    currentAccumulatedSize - prevAccumulatedSize);
            }
        }

        // The diamond lets the list's element type be inferred from PartitionSetProperty.of's parameter.
        return Pair.of(PartitionSetProperty.of(new ArrayList<>(keyRanges)),
            ResourceAntiAffinityProperty.of(skewedTaskIndices));
    }

    /**
     * Returns the {@code numSkewedKeys} largest values of the list, in descending order.
     *
     * @param partitionSizes partition sizes to rank.
     * @return up to {@code numSkewedKeys} largest sizes, largest first.
     */
    private List<Long> getTopNLargeKeySizes(final List<Long> partitionSizes) {
        return partitionSizes.stream()
            .sorted(Comparator.reverseOrder())
            .limit(numSkewedKeys)
            .collect(Collectors.toList());
    }

    /**
     * Tells whether any partition in {@code [startingKey, finishingKey)} has a size equal to one of the skewed sizes.
     *
     * @param partitionSizes size of each partition, indexed by partition key.
     * @param skewedSizes    sizes considered skewed.
     * @param startingKey    inclusive start of the key range.
     * @param finishingKey   exclusive end of the key range.
     * @return {@code true} if the range contains a skewed size; {@code false} otherwise (including an empty range).
     */
    private boolean containsSkewedSize(final List<Long> partitionSizes, final List<Long> skewedSizes,
                                       final int startingKey, final int finishingKey) {
        for (int key = startingKey; key < finishingKey; key++) {
            if (skewedSizes.contains(partitionSizes.get(key))) {
                return true;
            }
        }
        return false;
    }
}
