package org.apache.nemo.runtime.master.scheduler;

import com.google.common.collect.Sets;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.function.Function;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import org.apache.nemo.common.Pair;
import org.apache.nemo.common.dag.DAG;
import org.apache.nemo.common.exception.UnknownFailureCauseException;
import org.apache.nemo.common.ir.edge.executionproperty.CommunicationPatternProperty;
import org.apache.nemo.common.ir.edge.executionproperty.MessageIdEdgeProperty;
import org.apache.nemo.common.ir.vertex.executionproperty.IgnoreSchedulingTempDataReceiverProperty;
import org.apache.nemo.common.ir.vertex.executionproperty.MessageIdVertexProperty;
import org.apache.nemo.runtime.common.RuntimeIdManager;
import org.apache.nemo.runtime.common.plan.PhysicalPlan;
import org.apache.nemo.runtime.common.plan.PlanRewriter;
import org.apache.nemo.runtime.common.plan.Stage;
import org.apache.nemo.runtime.common.plan.StageEdge;
import org.apache.nemo.runtime.common.plan.Task;
import org.apache.nemo.runtime.common.state.BlockState;
import org.apache.nemo.runtime.common.state.StageState;
import org.apache.nemo.runtime.common.state.TaskState;
import org.apache.nemo.runtime.master.BlockManagerMaster;
import org.apache.nemo.runtime.master.PlanStateManager;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/apache/nemo/runtime/master/scheduler/BatchSchedulerUtils.class */
public final class BatchSchedulerUtils {
    private static final Logger LOG = LoggerFactory.getLogger(BatchSchedulerUtils.class.getName());

    /* JADX INFO: Access modifiers changed from: package-private */
    /* renamed from: org.apache.nemo.runtime.master.scheduler.BatchSchedulerUtils$1, reason: invalid class name */
    /* loaded from: input_file:org/apache/nemo/runtime/master/scheduler/BatchSchedulerUtils$1.class */
    /**
     * Lookup tables backing the two {@code switch} statements in this file that dispatch on an enum value
     * (one on {@link TaskState.RecoverableTaskFailureCause}, one on {@link CommunicationPatternProperty.Value}).
     *
     * <p>Each table is indexed by the enum constant's {@code ordinal()} and holds the integer case label
     * that the corresponding {@code switch} tests. Slots that are never assigned keep the array default
     * of {@code 0}, which matches no case label and therefore selects the {@code default} branch.
     *
     * <p>NOTE(review): the {@code synthetic} markers and the generated field names suggest this class was
     * emitted by the compiler/decompiler rather than written by hand — confirm before editing it manually.
     */
    public static /* synthetic */ class AnonymousClass1 {
        // Case labels for TaskState.RecoverableTaskFailureCause; allocated and filled in the static block below.
        static final /* synthetic */ int[] $SwitchMap$org$apache$nemo$runtime$common$state$TaskState$RecoverableTaskFailureCause;
        // Case labels for CommunicationPatternProperty.Value: SHUFFLE -> 1, BROADCAST -> 2, ONE_TO_ONE -> 3.
        static final /* synthetic */ int[] $SwitchMap$org$apache$nemo$common$ir$edge$executionproperty$CommunicationPatternProperty$Value = new int[CommunicationPatternProperty.Value.values().length];

        static {
            // Every assignment is wrapped separately so that a NoSuchFieldError on one constant leaves only
            // that slot at 0 and does not prevent the remaining slots from being filled.
            // Presumably this guards against an enum constant that is absent at run time — TODO confirm.
            try {
                $SwitchMap$org$apache$nemo$common$ir$edge$executionproperty$CommunicationPatternProperty$Value[CommunicationPatternProperty.Value.SHUFFLE.ordinal()] = 1;
            } catch (NoSuchFieldError e) {
                // Deliberately ignored: the slot stays 0.
            }
            try {
                $SwitchMap$org$apache$nemo$common$ir$edge$executionproperty$CommunicationPatternProperty$Value[CommunicationPatternProperty.Value.BROADCAST.ordinal()] = 2;
            } catch (NoSuchFieldError e2) {
                // Deliberately ignored: the slot stays 0.
            }
            try {
                $SwitchMap$org$apache$nemo$common$ir$edge$executionproperty$CommunicationPatternProperty$Value[CommunicationPatternProperty.Value.ONE_TO_ONE.ordinal()] = 3;
            } catch (NoSuchFieldError e3) {
                // Deliberately ignored: the slot stays 0.
            }
            // Failure-cause table: INPUT_READ_FAILURE -> 1, OUTPUT_WRITE_FAILURE -> 2.
            $SwitchMap$org$apache$nemo$runtime$common$state$TaskState$RecoverableTaskFailureCause = new int[TaskState.RecoverableTaskFailureCause.values().length];
            try {
                $SwitchMap$org$apache$nemo$runtime$common$state$TaskState$RecoverableTaskFailureCause[TaskState.RecoverableTaskFailureCause.INPUT_READ_FAILURE.ordinal()] = 1;
            } catch (NoSuchFieldError e4) {
                // Deliberately ignored: the slot stays 0.
            }
            try {
                $SwitchMap$org$apache$nemo$runtime$common$state$TaskState$RecoverableTaskFailureCause[TaskState.RecoverableTaskFailureCause.OUTPUT_WRITE_FAILURE.ordinal()] = 2;
            } catch (NoSuchFieldError e5) {
                // Deliberately ignored: the slot stays 0.
            }
        }
    }

    /**
     * Private constructor: this class exposes only static members and is not meant to be instantiated.
     */
    private BatchSchedulerUtils() {
    }

    /* JADX INFO: Access modifiers changed from: package-private */
    /**
     * Returns the first group, in the iteration order of {@code list}, that contains at least one stage
     * whose state equals {@link StageState.State#INCOMPLETE}.
     *
     * @param list             the candidate groups of stages; may be {@code null}
     * @param planStateManager the source of each stage's current state, looked up by stage id
     * @return the first group with an incomplete stage, or {@link Optional#empty()} when {@code list} is
     *         {@code null} or no group qualifies
     * @throws NullPointerException if {@code planStateManager} is {@code null} and at least one group is
     *                              inspected
     */
    public static Optional<List<Stage>> selectEarliestSchedulableGroup(List<List<Stage>> list, PlanStateManager planStateManager) {
        if (list == null) {
            return Optional.empty();
        }
        return list.stream()
            // Keep the pipeline fully typed: a raw Stream would erase the element type and break
            // both the method reference and the state comparison below.
            .filter(group -> group.stream()
                .map(Stage::getId)
                .map(planStateManager::getStageState)
                .anyMatch(state -> state.equals(StageState.State.INCOMPLETE)))
            .findFirst();
    }

    /* JADX INFO: Access modifiers changed from: package-private */
    /**
     * Builds the {@link Task}s to launch for every task attempt of {@code stage} that the plan state
     * manager reports as ready to schedule.
     *
     * <p>If the stage carries {@link IgnoreSchedulingTempDataReceiverProperty} with value {@code true},
     * no task is created: each pending attempt is moved to {@code EXECUTING} and then straight to
     * {@code COMPLETE}, and an empty list is returned.
     *
     * <p>Otherwise, for each pending attempt the block manager is told which output blocks the attempt
     * will produce (see {@link #getOutputBlockIds}) before the task object is created.
     *
     * @param planStateManager   supplies the physical plan and the task attempts to schedule
     * @param blockManagerMaster notified via {@code onProducerTaskScheduled} for each created task
     * @param stage              the stage whose attempts are to be scheduled
     * @return the tasks to launch, in the order the attempts were reported; empty for an ignored stage
     */
    public static List<Task> selectSchedulableTasks(PlanStateManager planStateManager, BlockManagerMaster blockManagerMaster, Stage stage) {
        if (stage.getPropertyValue(IgnoreSchedulingTempDataReceiverProperty.class).orElse(false)) {
            // Nothing to run for this stage: walk each attempt through EXECUTING to COMPLETE.
            for (final String taskId : planStateManager.getTaskAttemptsToSchedule(stage.getId())) {
                planStateManager.onTaskStateChanged(taskId, TaskState.State.EXECUTING);
                planStateManager.onTaskStateChanged(taskId, TaskState.State.COMPLETE);
            }
            return Collections.emptyList();
        }

        final List<StageEdge> incomingEdges =
            planStateManager.getPhysicalPlan().getStageDAG().getIncomingEdgesOf(stage.getId());
        final List<StageEdge> outgoingEdges =
            planStateManager.getPhysicalPlan().getStageDAG().getOutgoingEdgesOf(stage.getId());
        // Left raw on purpose: the map's value type is not imported in this file, and the list is
        // fetched once up front exactly as before. It is indexed by task index below.
        @SuppressWarnings("rawtypes")
        final List readablesByTaskIndex = stage.getVertexIdToReadables();

        final List<String> taskIds = planStateManager.getTaskAttemptsToSchedule(stage.getId());
        final List<Task> tasks = new ArrayList<>(taskIds.size());
        for (final String taskId : taskIds) {
            // Register the output blocks before the task itself is created.
            blockManagerMaster.onProducerTaskScheduled(taskId, getOutputBlockIds(planStateManager, taskId));
            tasks.add(new Task(
                planStateManager.getPhysicalPlan().getPlanId(),
                taskId,
                stage.getExecutionProperties(),
                stage.getSerializedIRDAG(),
                incomingEdges,
                outgoingEdges,
                (Map) readablesByTaskIndex.get(RuntimeIdManager.getIndexFromTaskId(taskId))));
        }
        return tasks;
    }

    /* JADX INFO: Access modifiers changed from: package-private */
    /**
     * Records that a task finished on an executor: logs the event at debug level and forwards the
     * completion to the executor's entry in the registry.
     *
     * @param executorRegistry the registry whose entry for {@code str} is updated
     * @param str              the id passed to {@code updateExecutor} to select the executor
     * @param str2             the id of the completed task
     */
    public static void onTaskExecutionComplete(ExecutorRegistry executorRegistry, String str, String str2) {
        final String executorId = str;
        final String taskId = str2;
        LOG.debug("{} completed in {}", taskId, executorId);
        executorRegistry.updateExecutor(executorId, (executor, state) -> {
            executor.onTaskExecutionComplete(taskId);
            // The executor and its state are handed back unchanged.
            return Pair.of(executor, state);
        });
    }

    /**
     * Collects the stage edges associated with the message id of the stage that owns the given task.
     *
     * <p>Steps:
     * <ol>
     *   <li>Find the stage whose id equals the stage id derived from {@code str}.</li>
     *   <li>Read the {@link MessageIdVertexProperty} values of that stage's IR vertices; exactly one
     *       vertex must carry the property.</li>
     *   <li>Scan the outgoing edges of every stage in the DAG and keep those whose
     *       {@link MessageIdEdgeProperty} set contains that message id.</li>
     * </ol>
     *
     * @param planStateManager supplies the physical plan and its stage DAG
     * @param str              the id of the task whose stage is inspected
     * @return the matching edges; possibly empty
     * @throws RuntimeException      if no stage in the DAG matches the task's stage id
     * @throws IllegalStateException if the stage does not have exactly one vertex with a message id
     */
    static Set<StageEdge> getEdgesToOptimize(PlanStateManager planStateManager, String str) {
        final DAG<Stage, StageEdge> stageDag = planStateManager.getPhysicalPlan().getStageDAG();
        // Invariant across the search below, so resolve it once.
        final String stageId = RuntimeIdManager.getStageIdFromTaskId(str);

        final Stage stageOfTask = stageDag.getVertices().stream()
            .filter(candidate -> candidate.getId().equals(stageId))
            .findFirst()
            .orElseThrow(RuntimeException::new);

        final List<Integer> messageIds = stageOfTask.getIRDAG().getVertices().stream()
            .filter(vertex -> vertex.getPropertyValue(MessageIdVertexProperty.class).isPresent())
            .map(vertex -> vertex.getPropertyValue(MessageIdVertexProperty.class).get())
            .collect(Collectors.toList());
        if (messageIds.size() != 1) {
            throw new IllegalStateException("Must be exactly one vertex with the message id: " + messageIds.toString());
        }
        final int messageId = messageIds.get(0);

        final Set<StageEdge> targetEdges = new HashSet<>();
        for (final Stage stage : stageDag.getVertices()) {
            for (final StageEdge edge : stageDag.getOutgoingEdgesOf(stage)) {
                final Optional<HashSet<Integer>> edgeMessageIds = edge.getPropertyValue(MessageIdEdgeProperty.class);
                if (edgeMessageIds.isPresent() && edgeMessageIds.get().contains(messageId)) {
                    targetEdges.add(edge);
                }
            }
        }
        return targetEdges;
    }

    /* JADX INFO: Access modifiers changed from: package-private */
    /**
     * Handles a task failure that can be recovered from by retrying.
     *
     * <p>The failure is logged and forwarded to the executor's registry entry. For both
     * {@code INPUT_READ_FAILURE} and {@code OUTPUT_WRITE_FAILURE} the block manager is told that the
     * producer task failed, and the task is then handed to {@link #retryTasksAndRequiredParents}.
     *
     * @param planStateManager            tracks task states; used when marking tasks for retry
     * @param blockManagerMaster          notified via {@code onProducerTaskFailed}
     * @param executorRegistry            the registry whose entry for {@code str} is updated
     * @param str                         the id passed to {@code updateExecutor} to select the executor
     * @param str2                        the id of the failed task
     * @param recoverableTaskFailureCause why the task failed
     * @throws UnknownFailureCauseException if the cause is neither of the two handled values
     */
    public static void onTaskExecutionFailedRecoverable(PlanStateManager planStateManager, BlockManagerMaster blockManagerMaster, ExecutorRegistry executorRegistry, String str, String str2, TaskState.RecoverableTaskFailureCause recoverableTaskFailureCause) {
        LOG.info("{} failed in {} by {}", str2, str, recoverableTaskFailureCause);
        executorRegistry.updateExecutor(str, (executor, state) -> {
            executor.onTaskExecutionFailed(str2);
            return Pair.of(executor, state);
        });

        // Switch on the enum itself instead of going through an ordinal-indexed lookup table.
        switch (recoverableTaskFailureCause) {
            case INPUT_READ_FAILURE:
            case OUTPUT_WRITE_FAILURE:
                blockManagerMaster.onProducerTaskFailed(str2);
                retryTasksAndRequiredParents(planStateManager, blockManagerMaster, Collections.singleton(str2));
                return;
            default:
                throw new UnknownFailureCauseException(new Throwable("Unknown cause: " + recoverableTaskFailureCause));
        }
    }

    /* JADX INFO: Access modifiers changed from: package-private */
    /**
     * Handles a task that was put on hold.
     *
     * <p>The executor's registry entry is notified through the same {@code onTaskExecutionComplete}
     * call used for a finished task. The state of the task's stage is read first, then the edges to
     * optimize are resolved; a plan rewrite is requested only if that stage was already
     * {@code COMPLETE}.
     *
     * @param planStateManager supplies the stage state and the physical plan
     * @param executorRegistry the registry whose entry for {@code str} is updated
     * @param planRewriter     produces the rewritten plan for a message id
     * @param str              the id passed to {@code updateExecutor} to select the executor
     * @param str2             the id of the task put on hold
     * @return the rewritten plan when the task's stage is complete, otherwise {@link Optional#empty()}
     * @throws RuntimeException if no edge to optimize is found for the task
     */
    public static Optional<PhysicalPlan> onTaskExecutionOnHold(PlanStateManager planStateManager, ExecutorRegistry executorRegistry, PlanRewriter planRewriter, String str, String str2) {
        LOG.info("{} put on hold in {}", str2, str);
        executorRegistry.updateExecutor(str, (executor, state) -> {
            executor.onTaskExecutionComplete(str2);
            return Pair.of(executor, state);
        });

        // Read the stage state before resolving the edges, matching the original evaluation order.
        final String stageId = RuntimeIdManager.getStageIdFromTaskId(str2);
        final boolean stageComplete = planStateManager.getStageState(stageId).equals(StageState.State.COMPLETE);

        final Set<StageEdge> targetEdges = getEdgesToOptimize(planStateManager, str2);
        if (targetEdges.isEmpty()) {
            throw new RuntimeException("No edges specified for data skew optimization");
        }

        if (!stageComplete) {
            return Optional.empty();
        }
        return Optional.of(planRewriter.rewrite(getMessageId(targetEdges)));
    }

    /**
     * Forwards run-time data for a task to the plan rewriter.
     *
     * <p>The edges to optimize are resolved for the task, their message id is extracted, and both are
     * passed to {@code planRewriter.accumulate} together with {@code obj}.
     *
     * @param planStateManager supplies the physical plan used to resolve the edges
     * @param planRewriter     receives the accumulated data
     * @param str              the id of the task the data belongs to
     * @param obj              the data to accumulate; passed through untouched
     */
    public static void onRunTimePassMessage(PlanStateManager planStateManager, PlanRewriter planRewriter, String str, Object obj) {
        final Set<StageEdge> targetEdges = getEdgesToOptimize(planStateManager, str);
        final int messageId = getMessageId(targetEdges);
        planRewriter.accumulate(messageId, targetEdges, obj);
    }

    /**
     * Extracts a message id from a set of stage edges.
     *
     * <p>Only the first edge the stream yields is examined: its {@link MessageIdEdgeProperty} value is
     * read and the first element produced by that value's iterator is returned.
     *
     * @param set the edges to read the message id from
     * @return the first message id of the first edge
     * @throws IllegalArgumentException         if {@code set} is empty, or the examined edge has no
     *                                          {@link MessageIdEdgeProperty} (the message is the edge id)
     * @throws java.util.NoSuchElementException if the property value of the examined edge is empty
     */
    static int getMessageId(Set<StageEdge> set) {
        final Set<Integer> messageIds = set.stream()
            .map(edge -> edge.getExecutionProperties()
                .get(MessageIdEdgeProperty.class)
                .<IllegalArgumentException>orElseThrow(() -> new IllegalArgumentException(edge.getId())))
            .findFirst()
            .<IllegalArgumentException>orElseThrow(IllegalArgumentException::new);
        return messageIds.iterator().next();
    }

    /* JADX INFO: Access modifiers changed from: package-private */
    /**
     * Marks the given tasks, plus every parent task found by
     * {@link #recursivelyGetParentTasksForLostBlocks}, as {@code SHOULD_RETRY}.
     *
     * @param planStateManager   receives the state change for each task
     * @param blockManagerMaster consulted while searching for parents whose blocks are lost
     * @param set                the ids of the tasks to retry
     */
    public static void retryTasksAndRequiredParents(PlanStateManager planStateManager, BlockManagerMaster blockManagerMaster, Set<String> set) {
        final Set<String> requiredParents =
            recursivelyGetParentTasksForLostBlocks(planStateManager, blockManagerMaster, set);
        final Set<String> tasksToRetry = Sets.union(set, requiredParents);
        LOG.info("Will be retried: {}", tasksToRetry);
        for (final String taskId : tasksToRetry) {
            planStateManager.onTaskStateChanged(taskId, TaskState.State.SHOULD_RETRY);
        }
    }

    /**
     * Finds, transitively, the completed parent task attempts whose output blocks are needed by the
     * given tasks but are no longer available.
     *
     * <p>For one level of the search:
     * <ol>
     *   <li>Index the incoming edges of the given tasks' stages by edge id (duplicates keep the first).</li>
     *   <li>Turn every input block id of the given tasks into its wildcard form and de-duplicate.</li>
     *   <li>Keep the wildcards for which the block manager reports no {@code AVAILABLE} handler.</li>
     *   <li>For each such wildcard, select the attempts of the edge's source stage that have the
     *       wildcard's task index and are currently {@code COMPLETE}.</li>
     * </ol>
     * The method then recurses on the selected attempts and returns the union of both levels.
     *
     * <p>NOTE(review): the recursion stops only when a level selects no attempts; this presumably
     * relies on the stage DAG being acyclic — confirm.
     *
     * @param planStateManager   supplies the stage DAG, the task attempts and their states
     * @param blockManagerMaster queried for the handlers of each block wildcard
     * @param set                the ids of the tasks whose inputs are examined
     * @return the ids of the parent attempts found at every level, as a Guava {@code Sets.union} view;
     *         empty when {@code set} is empty
     */
    static Set<String> recursivelyGetParentTasksForLostBlocks(PlanStateManager planStateManager, BlockManagerMaster blockManagerMaster, Set<String> set) {
        if (set.isEmpty()) {
            return Collections.emptySet();
        }
        final DAG<Stage, StageEdge> stageDag = planStateManager.getPhysicalPlan().getStageDAG();

        // Several tasks can share a stage, so the same edge can appear more than once; keep the first.
        final Map<String, StageEdge> incomingEdgesById = set.stream()
            .map(RuntimeIdManager::getStageIdFromTaskId)
            .flatMap(stageId -> stageDag.getIncomingEdgesOf(stageId).stream())
            .collect(Collectors.toMap(StageEdge::getId, Function.identity(), (first, duplicate) -> first));

        // Collect into a set first so that each wildcard is checked against the block manager once.
        final Set<String> inputBlockWildcards = set.stream()
            .flatMap(taskId -> getInputBlockIds(planStateManager, taskId).stream())
            .map(RuntimeIdManager::getWildCardFromBlockId)
            .collect(Collectors.toSet());

        final Set<String> parentsWithLostBlocks = inputBlockWildcards.stream()
            .filter(wildcard -> blockManagerMaster.getBlockHandlers(wildcard, BlockState.State.AVAILABLE).isEmpty())
            .flatMap(lostWildcard -> {
                final String edgeId = RuntimeIdManager.getRuntimeEdgeIdFromBlockId(lostWildcard);
                final String parentStageId = incomingEdgesById.get(edgeId).getSrc().getId();
                final int parentTaskIndex = RuntimeIdManager.getTaskIndexFromBlockId(lostWildcard);
                // The inner lambda parameters need their own name: reusing the enclosing lambda's
                // parameter name is a compile-time error.
                return planStateManager.getAllTaskAttemptsOfStage(parentStageId).stream()
                    .filter(attemptId -> RuntimeIdManager.getStageIdFromTaskId(attemptId).equals(parentStageId)
                        && RuntimeIdManager.getIndexFromTaskId(attemptId) == parentTaskIndex)
                    .filter(attemptId -> planStateManager.getTaskState(attemptId).equals(TaskState.State.COMPLETE));
            })
            .collect(Collectors.toSet());

        return Sets.union(parentsWithLostBlocks,
            recursivelyGetParentTasksForLostBlocks(planStateManager, blockManagerMaster, parentsWithLostBlocks));
    }

    /**
     * Generates one block id per outgoing edge of the stage that owns the given task.
     *
     * @param planStateManager supplies the physical plan and its stage DAG
     * @param str              the id of the producing task
     * @return the block ids generated from each outgoing edge id paired with {@code str}
     */
    static Set<String> getOutputBlockIds(PlanStateManager planStateManager, String str) {
        final List<StageEdge> outgoingEdges = planStateManager.getPhysicalPlan().getStageDAG()
            .getOutgoingEdgesOf(RuntimeIdManager.getStageIdFromTaskId(str));
        final Set<String> blockIds = new HashSet<>();
        for (final StageEdge edge : outgoingEdges) {
            blockIds.add(RuntimeIdManager.generateBlockId(edge.getId(), str));
        }
        return blockIds;
    }

    /**
     * Generates the ids of the blocks the given task reads, one incoming stage edge at a time.
     *
     * <p>For each incoming edge of the task's stage, the task attempts of the edge's source stage are
     * looked up, and the edge's communication pattern decides which of them contribute a block id:
     * <ul>
     *   <li>{@code SHUFFLE} and {@code BROADCAST}: every attempt of the source stage.</li>
     *   <li>{@code ONE_TO_ONE}: only the attempts whose task index equals the index of {@code str}.</li>
     * </ul>
     *
     * @param planStateManager supplies the stage DAG and the task attempts of each source stage
     * @param str              the id of the consuming task
     * @return the generated block ids
     * @throws IllegalStateException if an incoming edge has any other communication pattern
     */
    static Set<String> getInputBlockIds(PlanStateManager planStateManager, String str) {
        return planStateManager.getPhysicalPlan().getStageDAG()
            .getIncomingEdgesOf(RuntimeIdManager.getStageIdFromTaskId(str))
            .stream()
            .flatMap(inEdge -> {
                final Set<String> parentAttempts = planStateManager.getAllTaskAttemptsOfStage(inEdge.getSrc().getId());
                // Switch on the enum itself instead of going through an ordinal-indexed lookup table.
                switch (inEdge.getDataCommunicationPattern()) {
                    case SHUFFLE:
                    case BROADCAST:
                        return parentAttempts.stream()
                            .map(parentId -> RuntimeIdManager.generateBlockId(inEdge.getId(), parentId));
                    case ONE_TO_ONE:
                        // The child's index does not change per parent attempt, so compute it once.
                        final int childIndex = RuntimeIdManager.getIndexFromTaskId(str);
                        return parentAttempts.stream()
                            .filter(parentId -> RuntimeIdManager.getIndexFromTaskId(parentId) == childIndex)
                            .map(parentId -> RuntimeIdManager.generateBlockId(inEdge.getId(), parentId));
                    default:
                        throw new IllegalStateException(inEdge.toString());
                }
            })
            .collect(Collectors.toSet());
    }
}
