package org.apache.nemo.runtime.master.scheduler;

import com.fasterxml.jackson.databind.JsonNode;
import com.google.common.collect.Streams;
import java.io.Serializable;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;
import java.util.concurrent.atomic.AtomicLong;
import java.util.stream.Collectors;
import org.apache.commons.lang3.SerializationUtils;
import org.apache.nemo.common.dag.DAG;
import org.apache.nemo.common.ir.vertex.IRVertex;
import org.apache.nemo.common.ir.vertex.executionproperty.ParallelismProperty;
import org.apache.nemo.runtime.common.RuntimeIdManager;
import org.apache.nemo.runtime.common.comm.ControlMessage;
import org.apache.nemo.runtime.common.message.MessageUtils;
import org.apache.nemo.runtime.common.metric.JobMetric;
import org.apache.nemo.runtime.common.metric.StateTransitionEvent;
import org.apache.nemo.runtime.common.metric.TaskMetric;
import org.apache.nemo.runtime.common.plan.RuntimeEdge;
import org.apache.nemo.runtime.common.plan.Task;
import org.apache.nemo.runtime.common.state.TaskState;
import org.apache.nemo.runtime.master.metric.MetricStore;
import org.apache.nemo.runtime.master.resource.ExecutorRepresenter;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/apache/nemo/runtime/master/scheduler/SimulatedTaskExecutor.class */
/**
 * Stand-in for a real task executor, used by the {@link SimulationScheduler}.
 *
 * <p>Instead of running a task, it estimates the task's duration from task metrics recorded in an
 * actual {@link MetricStore}, advances a simulated clock by that amount, and reports the resulting
 * metrics and the task's state change as if the task had been executed.
 *
 * <p>Only the simulated clock is an atomic value; the other mutable fields are plain fields, so this
 * class does not itself coordinate concurrent calls to {@link #onTaskReceived(Task)}.
 */
public final class SimulatedTaskExecutor {
    private static final Logger LOG = LoggerFactory.getLogger(SimulatedTaskExecutor.class);
    private static final String TASK_METRIC_ID = "TaskMetric";
    private static final String MASTER_LISTENER_ID = "RUNTIME_MASTER_MESSAGE_LISTENER_ID";
    private static final String PARALLELISM_PROPERTY_KEY =
        "org.apache.nemo.common.ir.vertex.executionproperty.ParallelismProperty";
    private static final String DYNAMIC_TASK_SIZING_PROPERTY_KEY =
        "org.apache.nemo.common.ir.vertex.executionproperty.EnableDynamicTaskSizingProperty";
    private static final int SAMPLING_PARALLELISM = 32;
    private static final int FALLBACK_PARALLELISM = 1;

    private final SimulationScheduler scheduler;
    private final ExecutorRepresenter executorRepresenter;
    private final MetricStore actualMetricStore;
    /** Cache of deserialized stage IR DAGs, keyed by stage id. */
    private final ConcurrentMap<String, DAG<IRVertex, RuntimeEdge<IRVertex>>> stageIDToStageIRDAG =
        new ConcurrentHashMap<>();
    /** Simulated clock in milliseconds; negative until the first task is received. */
    private final AtomicLong currentTime = new AtomicLong(-1);
    /** Wall-clock time of the first received task; negative until then. */
    private long executorInitializationTime = -1L;
    /** Wall-clock time at which the previous task finished being handled. */
    private long timeCheckpoint;

    /**
     * Creates a simulated executor.
     *
     * @param simulationScheduler scheduler that receives the metrics produced by this executor.
     * @param executorRepresenter representer used for the executor id and for sending control messages.
     * @param metricStore         store holding the recorded job and task metrics used for estimation.
     */
    public SimulatedTaskExecutor(SimulationScheduler simulationScheduler, ExecutorRepresenter executorRepresenter, MetricStore metricStore) {
        this.scheduler = simulationScheduler;
        this.executorRepresenter = executorRepresenter;
        this.actualMetricStore = metricStore;
    }

    /**
     * Estimates how long the given task would run.
     *
     * <p>The estimate is the rounded average of the positive durations of recorded tasks that belong
     * to a matching sampling stage and whose task size ratio equals the task's parallelism.
     *
     * @param task the task to estimate.
     * @return the estimated duration, or 0 when no job metric or no matching task metric is available.
     */
    private long calculateExpectedTaskDuration(final Task task) {
        final DAG<IRVertex, RuntimeEdge<IRVertex>> stageIRDAG = this.stageIDToStageIRDAG.computeIfAbsent(
            task.getStageId(), stageId -> SerializationUtils.deserialize(task.getSerializedIRDag()));

        final Map<String, Object> jobMetrics = this.actualMetricStore.getMetricMap(JobMetric.class);
        if (jobMetrics.isEmpty()) {
            // Without a job metric there is no stage DAG to match against; treat it like "no data".
            LOG.warn("MetricStore has no JobMetric. Using an expected duration of 0 for {}.", task.getTaskId());
            return 0L;
        }
        if (jobMetrics.size() > 1) {
            LOG.warn("MetricStore has more than one JobMetric. The results could be misleading.");
        }
        final JsonNode stageDAG = ((JobMetric) jobMetrics.values().iterator().next()).getStageDAG();
        final Set<String> samplingStageIds = findSamplingStageIds(stageDAG, stageIRDAG);

        final int parallelism = task.getPropertyValue(ParallelismProperty.class).orElse(FALLBACK_PARALLELISM);
        final Map<String, Object> taskMetrics = this.actualMetricStore.getMetricMap(TaskMetric.class);
        final double averageDuration = taskMetrics.entrySet().stream()
            .filter(entry -> samplingStageIds.contains(RuntimeIdManager.getStageIdFromTaskId(entry.getKey())))
            .map(entry -> (TaskMetric) entry.getValue())
            .filter(metric -> metric.getTaskSizeRatio() == parallelism)
            .mapToLong(TaskMetric::getTaskDuration)
            .filter(duration -> duration > 0)
            .average()
            .orElse(0.0d);
        return Math.round(averageDuration);
    }

    /**
     * Collects the ids of the recorded stages that look like a sampled run of the given stage.
     *
     * @param stageDAG   JSON form of the recorded stage DAG.
     * @param stageIRDAG IR DAG of the stage being simulated.
     * @return ids of the recorded stages accepted by {@link #isSamplingStageOf(JsonNode, int, int)}.
     */
    private static Set<String> findSamplingStageIds(final JsonNode stageDAG,
                                                    final DAG<IRVertex, RuntimeEdge<IRVertex>> stageIRDAG) {
        final int vertexCount = stageIRDAG.getVertices().size();
        final int edgeCount = stageIRDAG.getEdges().size();
        return Streams.stream(() -> stageDAG.get("vertices").iterator())
            .filter(stage -> isSamplingStageOf(stage, vertexCount, edgeCount))
            .map(stage -> stage.get("id").asText())
            .collect(Collectors.toSet());
    }

    /**
     * Checks whether a recorded stage has the expected IR DAG size, the sampling parallelism, and
     * dynamic task sizing enabled.
     *
     * @param stage       JSON node of one recorded stage.
     * @param vertexCount number of IR vertices the stage must have.
     * @param edgeCount   number of IR edges the stage must have.
     * @return {@code true} if all conditions hold.
     */
    private static boolean isSamplingStageOf(final JsonNode stage, final int vertexCount, final int edgeCount) {
        final JsonNode properties = stage.get("properties");
        final JsonNode irDag = properties.get("irDag");
        if (irDag.get("vertices").size() != vertexCount || irDag.get("edges").size() != edgeCount) {
            return false;
        }
        // path() yields a "missing" node rather than null, so a stage that lacks one of these
        // properties falls back to the defaults (and is rejected) instead of causing an NPE.
        final JsonNode executionProperties = properties.get("executionProperties");
        return executionProperties.path(PARALLELISM_PROPERTY_KEY).asInt(FALLBACK_PARALLELISM) == SAMPLING_PARALLELISM
            && executionProperties.path(DYNAMIC_TASK_SIZING_PROPERTY_KEY).asBoolean();
    }

    /**
     * Simulates the execution of a task: reports the scheduling overhead, advances the simulated
     * clock by the overhead and the expected duration, reports the task duration, and finally
     * reports the task as {@code COMPLETE}.
     *
     * @param task the task to simulate.
     */
    public void onTaskReceived(Task task) {
        if (this.executorInitializationTime < 0) {
            // First task: anchor the simulated clock and the checkpoint to the wall clock.
            this.executorInitializationTime = System.currentTimeMillis();
            this.currentTime.set(this.executorInitializationTime);
            this.timeCheckpoint = this.executorInitializationTime;
        }
        final String taskId = task.getTaskId();
        final int attemptIdx = task.getAttemptIdx();

        // Wall-clock time that passed since the previous task was handled counts as overhead.
        // A single clock read keeps the overhead and the new checkpoint consistent.
        final long now = System.currentTimeMillis();
        final long schedulingOverhead = now - this.timeCheckpoint;
        this.timeCheckpoint = now;
        sendMetric(TASK_METRIC_ID, taskId, "schedulingOverhead",
            SerializationUtils.serialize(Long.valueOf(schedulingOverhead)));

        // The start time is the clock value before the overhead is added, so the reported
        // duration below covers the overhead as well as the expected task duration.
        final long executionStartTime = this.currentTime.getAndAdd(schedulingOverhead);
        LOG.debug("{} started", taskId);
        this.currentTime.getAndAdd(calculateExpectedTaskDuration(task));
        sendMetric(TASK_METRIC_ID, taskId, "taskDuration",
            SerializationUtils.serialize(Long.valueOf(this.currentTime.get() - executionStartTime)));
        this.timeCheckpoint = System.currentTimeMillis();

        onTaskStateChanged(taskId, attemptIdx, TaskState.State.COMPLETE, Optional.empty(), Optional.empty());
        LOG.debug("{} completed", taskId);
    }

    /**
     * Returns the simulated time that has elapsed since the first task was received.
     *
     * @return the simulated clock minus the initialization time; 0 before any task was received.
     */
    public Long getElapsedTime() {
        return this.currentTime.get() - this.executorInitializationTime;
    }

    /**
     * Reports a task state change, both as a state transition metric and as a control message.
     *
     * @param taskId          id of the task whose state changed.
     * @param attemptIdx      attempt index of the task.
     * @param newState        the state the task moved to.
     * @param vertexPutOnHold id of the vertex put on hold; only used when the new state is ON_HOLD.
     * @param cause           cause of a recoverable failure, if any.
     */
    private void onTaskStateChanged(final String taskId,
                                    final int attemptIdx,
                                    final TaskState.State newState,
                                    final Optional<String> vertexPutOnHold,
                                    final Optional<TaskState.RecoverableTaskFailureCause> cause) {
        sendMetric(TASK_METRIC_ID, taskId, "stateTransitionEvent",
            SerializationUtils.serialize(new StateTransitionEvent(this.currentTime.get(), (Serializable) null, newState)));

        final ControlMessage.TaskStateChangedMsg.Builder stateChangedMsg = ControlMessage.TaskStateChangedMsg.newBuilder()
            .setExecutorId(this.executorRepresenter.getExecutorId())
            .setTaskId(taskId)
            .setAttemptIdx(attemptIdx)
            .setState(MessageUtils.convertState(newState));
        if (newState == TaskState.State.ON_HOLD) {
            vertexPutOnHold.ifPresent(stateChangedMsg::setVertexPutOnHoldId);
        }
        cause.ifPresent(failureCause ->
            stateChangedMsg.setFailureCause(MessageUtils.convertFailureCause(failureCause)));

        sendControlMessage(ControlMessage.Message.newBuilder()
            .setId(RuntimeIdManager.generateMessageId())
            .setListenerId(MASTER_LISTENER_ID)
            .setType(ControlMessage.MessageType.TaskStateChanged)
            .setTaskStateChangedMsg(stateChangedMsg.build())
            .build());
    }

    /**
     * Sends a control message through the executor representer.
     *
     * @param message the message to send.
     */
    private void sendControlMessage(final ControlMessage.Message message) {
        this.executorRepresenter.sendControlMessage(message);
    }

    /**
     * Forwards a metric to the simulation scheduler.
     *
     * @param metricType  type of the metric, e.g. {@code "TaskMetric"}.
     * @param metricId    id of the metric; this class passes the task id.
     * @param metricField name of the metric field being reported.
     * @param metricValue serialized value of the field.
     */
    public void sendMetric(String metricType, String metricId, String metricField, byte[] metricValue) {
        this.scheduler.handleMetricMessage(metricType, metricId, metricField, metricValue);
    }
}
