/*
 * Decompiled with CFR 0.152.
 */
package io.prestosql.execution;

import com.google.common.collect.Multimap;
import io.prestosql.Session;
import io.prestosql.execution.NodeTaskMap;
import io.prestosql.execution.QueryStateMachine;
import io.prestosql.execution.RemoteTask;
import io.prestosql.execution.RemoteTaskFactory;
import io.prestosql.execution.StateMachine;
import io.prestosql.execution.TaskId;
import io.prestosql.execution.TaskStatus;
import io.prestosql.execution.buffer.OutputBuffers;
import io.prestosql.metadata.InternalNode;
import io.prestosql.metadata.Split;
import io.prestosql.sql.planner.PlanFragment;
import io.prestosql.sql.planner.plan.PlanNodeId;
import java.util.Objects;
import java.util.OptionalInt;

/**
 * {@link RemoteTaskFactory} decorator that delegates task creation to another factory and
 * attaches a memory-accounting listener to every task it returns.
 *
 * <p>Each task gets its own listener. On every task status change, that listener reports the
 * task's user, system and revocable memory reservations to the {@link QueryStateMachine}. It
 * reports both as deltas against the previously seen status and as current absolute values.
 */
public class MemoryTrackingRemoteTaskFactory
implements RemoteTaskFactory {
    private final RemoteTaskFactory remoteTaskFactory;
    private final QueryStateMachine stateMachine;

    /**
     * @param remoteTaskFactory the factory that actually creates tasks; must not be null
     * @param stateMachine the query state machine that receives memory usage updates; must not be null
     * @throws NullPointerException if either argument is null
     */
    public MemoryTrackingRemoteTaskFactory(RemoteTaskFactory remoteTaskFactory, QueryStateMachine stateMachine) {
        Objects.requireNonNull(remoteTaskFactory, "remoteTaskFactory is null");
        Objects.requireNonNull(stateMachine, "stateMachine is null");
        this.remoteTaskFactory = remoteTaskFactory;
        this.stateMachine = stateMachine;
    }

    /**
     * Creates a task through the wrapped factory, then registers a fresh memory listener on it.
     * All arguments are passed through unchanged.
     *
     * @return the task produced by the wrapped factory, with the listener attached
     */
    @Override
    public RemoteTask createRemoteTask(Session session, TaskId taskId, InternalNode node, PlanFragment fragment, Multimap<PlanNodeId, Split> initialSplits, OptionalInt totalPartitions, OutputBuffers outputBuffers, NodeTaskMap.PartitionedSplitCountTracker partitionedSplitCountTracker, boolean summarizeTaskInfo) {
        RemoteTask created = remoteTaskFactory.createRemoteTask(
                session,
                taskId,
                node,
                fragment,
                initialSplits,
                totalPartitions,
                outputBuffers,
                partitionedSplitCountTracker,
                summarizeTaskInfo);
        // A new listener per task: it remembers this task's last-reported reservations
        // so that it can report per-task deltas.
        UpdatePeakMemory listener = new UpdatePeakMemory(stateMachine);
        created.addStateChangeListener(listener);
        return created;
    }

    /**
     * Per-task listener that turns the absolute reservations in each {@link TaskStatus} into
     * deltas against the previous status and passes them to
     * {@link QueryStateMachine#updateMemoryUsage}. The class name suggests the state machine
     * uses these values for peak tracking; that behavior is not visible here.
     *
     * <p>{@code stateChanged} is {@code synchronized}. Reading and replacing the stored
     * "previous" reservations is therefore never interleaved between two calls on the same
     * listener.
     */
    private static final class UpdatePeakMemory
    implements StateMachine.StateChangeListener<TaskStatus> {
        private final QueryStateMachine stateMachine;
        // Byte counts from the last status this listener saw; all zero before the first update.
        private long previousUserMemory;
        private long previousSystemMemory;
        private long previousRevocableMemory;

        public UpdatePeakMemory(QueryStateMachine stateMachine) {
            this.stateMachine = stateMachine;
        }

        @Override
        public synchronized void stateChanged(TaskStatus newStatus) {
            long user = newStatus.getMemoryReservation().toBytes();
            long system = newStatus.getSystemMemoryReservation().toBytes();
            long revocable = newStatus.getRevocableMemoryReservation().toBytes();

            long total = user + system + revocable;
            long previousTotal = previousUserMemory + previousSystemMemory + previousRevocableMemory;

            // Deltas must be taken before the stored snapshot is overwritten below.
            long userDelta = user - previousUserMemory;
            long revocableDelta = revocable - previousRevocableMemory;
            long totalDelta = total - previousTotal;

            previousUserMemory = user;
            previousSystemMemory = system;
            previousRevocableMemory = revocable;

            stateMachine.updateMemoryUsage(userDelta, revocableDelta, totalDelta, user, revocable, total);
        }
    }
}

