package org.apache.nemo.runtime.executor.bytetransfer;

import io.netty.channel.Channel;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.SimpleChannelInboundHandler;
import io.netty.channel.group.ChannelGroup;
import java.util.Iterator;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.Function;
import org.apache.nemo.runtime.common.comm.ControlMessage;
import org.apache.nemo.runtime.executor.bytetransfer.ByteTransferContext;
import org.apache.nemo.runtime.executor.data.BlockManagerWorker;
import org.apache.nemo.runtime.executor.data.PipeManagerWorker;

/* JADX INFO: Access modifiers changed from: package-private */
/* loaded from: input_file:org/apache/nemo/runtime/executor/bytetransfer/ContextManager.class */
/**
 * Netty inbound handler that manages the byte transfer contexts of a single {@link Channel}.
 *
 * <p>Contexts are kept in four maps keyed by transfer index, split by role (input / output) and by
 * which side initiated the context (local / remote). Incoming
 * {@link ControlMessage.ByteTransferContextSetupMessage}s create remotely initiated contexts;
 * {@link #newInputContext} and {@link #newOutputContext} create locally initiated ones.
 *
 * <p>NOTE(review): the maps are concurrent and the index counters atomic, but
 * {@code remoteExecutorId} is assigned with a non-atomic check-then-set; confirm the intended
 * threading model with the callers.
 */
public final class ContextManager extends SimpleChannelInboundHandler<ControlMessage.ByteTransferContextSetupMessage> {
    private final PipeManagerWorker pipeManagerWorker;
    private final BlockManagerWorker blockManagerWorker;
    private final ByteTransfer byteTransfer;
    private final ChannelGroup channelGroup;
    private final String localExecutorId;
    private final Channel channel;
    // Set on first use (first setup message received or first locally created context).
    private volatile String remoteExecutorId = null;
    private final ConcurrentMap<Integer, ByteInputContext> inputContextsInitiatedByLocal = new ConcurrentHashMap<>();
    private final ConcurrentMap<Integer, ByteOutputContext> outputContextsInitiatedByLocal = new ConcurrentHashMap<>();
    private final ConcurrentMap<Integer, ByteInputContext> inputContextsInitiatedByRemote = new ConcurrentHashMap<>();
    private final ConcurrentMap<Integer, ByteOutputContext> outputContextsInitiatedByRemote = new ConcurrentHashMap<>();
    private final AtomicInteger nextInputTransferIndex = new AtomicInteger(0);
    private final AtomicInteger nextOutputTransferIndex = new AtomicInteger(0);

    /**
     * Creates a context manager bound to the given channel.
     *
     * @param pipeManagerWorker  receives newly created remote-initiated contexts flagged as pipe
     * @param blockManagerWorker receives newly created remote-initiated contexts not flagged as pipe
     * @param byteTransfer       notified whenever a remote executor sets up a context
     * @param channelGroup       group the channel is added to on activation and removed from on inactivation
     * @param str                id of the local executor
     * @param channel            channel this manager writes new contexts to
     */
    public ContextManager(PipeManagerWorker pipeManagerWorker, BlockManagerWorker blockManagerWorker, ByteTransfer byteTransfer, ChannelGroup channelGroup, String str, Channel channel) {
        this.pipeManagerWorker = pipeManagerWorker;
        this.blockManagerWorker = blockManagerWorker;
        this.byteTransfer = byteTransfer;
        this.channelGroup = channelGroup;
        this.localExecutorId = str;
        this.channel = channel;
    }

    /**
     * @return the channel this manager was constructed with
     */
    public Channel getChannel() {
        return this.channel;
    }

    /**
     * Looks up an input context by data direction and transfer index.
     *
     * @param byteTransferDataDirection data direction of the context; {@code INITIATOR_SENDS_DATA}
     *                                  selects the remote-initiated inputs, anything else the
     *                                  local-initiated inputs
     * @param i                         transfer index of the context
     * @return the registered context, or {@code null} if no context is mapped to that index
     */
    public ByteInputContext getInputContext(ControlMessage.ByteTransferDataDirection byteTransferDataDirection, int i) {
        final ConcurrentMap<Integer, ByteInputContext> contexts =
            byteTransferDataDirection == ControlMessage.ByteTransferDataDirection.INITIATOR_SENDS_DATA
                ? this.inputContextsInitiatedByRemote
                : this.inputContextsInitiatedByLocal;
        return contexts.get(i);
    }

    /**
     * Handles a context setup message sent by the remote executor: records the remote executor id,
     * notifies {@code byteTransfer}, registers a new input or output context under the message's
     * transfer index, and hands the context to the pipe or block manager worker.
     *
     * @param channelHandlerContext          the Netty handler context (unused)
     * @param byteTransferContextSetupMessage the setup message
     * @throws RuntimeException if the initiator id differs from the one already recorded, or if a
     *                          context is already registered under the same transfer index
     */
    @Override
    public void channelRead0(ChannelHandlerContext channelHandlerContext, ControlMessage.ByteTransferContextSetupMessage byteTransferContextSetupMessage) throws Exception {
        setRemoteExecutorId(byteTransferContextSetupMessage.getInitiatorExecutorId());
        this.byteTransfer.onNewContextByRemoteExecutor(byteTransferContextSetupMessage.getInitiatorExecutorId(), this.channel);

        final ControlMessage.ByteTransferDataDirection dataDirection = byteTransferContextSetupMessage.getDataDirection();
        final int transferIndex = byteTransferContextSetupMessage.getTransferIndex();
        final boolean isPipe = byteTransferContextSetupMessage.getIsPipe();
        // The remote side is the initiator here, so it comes first in the context id.
        final ByteTransferContext.ContextId contextId =
            new ByteTransferContext.ContextId(this.remoteExecutorId, this.localExecutorId, dataDirection, transferIndex, isPipe);
        final byte[] contextDescriptor = byteTransferContextSetupMessage.getContextDescriptor().toByteArray();

        if (dataDirection == ControlMessage.ByteTransferDataDirection.INITIATOR_SENDS_DATA) {
            // The initiator (remote) sends, so the local side receives.
            final ByteInputContext inputContext = this.inputContextsInitiatedByRemote.compute(transferIndex, (index, existing) -> {
                if (existing != null) {
                    throw duplicateContextId(contextId);
                }
                return new ByteInputContext(this.remoteExecutorId, contextId, contextDescriptor, this);
            });
            if (isPipe) {
                this.pipeManagerWorker.onInputContext(inputContext);
            } else {
                this.blockManagerWorker.onInputContext(inputContext);
            }
        } else {
            final ByteOutputContext outputContext = this.outputContextsInitiatedByRemote.compute(transferIndex, (index, existing) -> {
                if (existing != null) {
                    throw duplicateContextId(contextId);
                }
                return new ByteOutputContext(this.remoteExecutorId, contextId, contextDescriptor, this);
            });
            if (isPipe) {
                this.pipeManagerWorker.onOutputContext(outputContext);
            } else {
                this.blockManagerWorker.onOutputContext(outputContext);
            }
        }
    }

    /**
     * Unregisters a context. The entry is removed only if the given context is the one currently
     * mapped to its transfer index.
     *
     * @param byteTransferContext the context to remove
     */
    public void onContextExpired(ByteTransferContext byteTransferContext) {
        final ByteTransferContext.ContextId contextId = byteTransferContext.getContextId();
        final boolean initiatorSendsData =
            contextId.getDataDirection() == ControlMessage.ByteTransferDataDirection.INITIATOR_SENDS_DATA;

        final ConcurrentMap<Integer, ? extends ByteTransferContext> contexts;
        if (byteTransferContext instanceof ByteInputContext) {
            // An input whose initiator sends data was initiated by the remote side.
            contexts = initiatorSendsData ? this.inputContextsInitiatedByRemote : this.inputContextsInitiatedByLocal;
        } else {
            // An output whose initiator sends data was initiated by the local side.
            contexts = initiatorSendsData ? this.outputContextsInitiatedByLocal : this.outputContextsInitiatedByRemote;
        }
        contexts.remove(contextId.getTransferIndex(), byteTransferContext);
    }

    /**
     * Creates a locally initiated context, registers it under a freshly allocated transfer index,
     * and writes it to the channel with the context's own write listener attached.
     *
     * @param concurrentMap             map the new context is registered in
     * @param atomicInteger             counter the transfer index is taken from
     * @param byteTransferDataDirection data direction recorded in the context id
     * @param function                  builds the context from its context id
     * @param str                       id of the remote executor
     * @param z                         pipe flag recorded in the context id
     * @param <T>                       concrete context type
     * @return the newly registered context
     * @throws RuntimeException if {@code str} differs from the remote executor id already recorded,
     *                          or if a context is already registered under the allocated index
     */
    <T extends ByteTransferContext> T newContext(ConcurrentMap<Integer, T> concurrentMap, AtomicInteger atomicInteger, ControlMessage.ByteTransferDataDirection byteTransferDataDirection, Function<ByteTransferContext.ContextId, T> function, String str, boolean z) {
        setRemoteExecutorId(str);
        final int transferIndex = atomicInteger.getAndIncrement();
        // The local side is the initiator here, so it comes first in the context id.
        final ByteTransferContext.ContextId contextId =
            new ByteTransferContext.ContextId(this.localExecutorId, str, byteTransferDataDirection, transferIndex, z);
        final T context = concurrentMap.compute(transferIndex, (index, existing) -> {
            if (existing != null) {
                throw duplicateContextId(contextId);
            }
            return function.apply(contextId);
        });
        this.channel.writeAndFlush(context).addListener(context.getChannelWriteListener());
        return context;
    }

    /**
     * Creates a locally initiated input context (direction {@code INITIATOR_RECEIVES_DATA}).
     *
     * @param str  id of the remote executor
     * @param bArr context descriptor passed to the new context
     * @param z    pipe flag recorded in the context id
     * @return the newly registered input context
     */
    public ByteInputContext newInputContext(String str, byte[] bArr, boolean z) {
        return newContext(
            this.inputContextsInitiatedByLocal,
            this.nextInputTransferIndex,
            ControlMessage.ByteTransferDataDirection.INITIATOR_RECEIVES_DATA,
            contextId -> new ByteInputContext(str, contextId, bArr, this),
            str,
            z);
    }

    /**
     * Creates a locally initiated output context (direction {@code INITIATOR_SENDS_DATA}).
     *
     * @param str  id of the remote executor
     * @param bArr context descriptor passed to the new context
     * @param z    pipe flag recorded in the context id
     * @return the newly registered output context
     */
    public ByteOutputContext newOutputContext(String str, byte[] bArr, boolean z) {
        return newContext(
            this.outputContextsInitiatedByLocal,
            this.nextOutputTransferIndex,
            ControlMessage.ByteTransferDataDirection.INITIATOR_SENDS_DATA,
            contextId -> new ByteOutputContext(str, contextId, bArr, this),
            str,
            z);
    }

    /**
     * Records the remote executor id on first call; on later calls verifies that the given id
     * matches the recorded one.
     *
     * @param str the remote executor id
     * @throws RuntimeException if an id is already recorded and {@code str} differs from it
     */
    private void setRemoteExecutorId(String str) {
        if (this.remoteExecutorId == null) {
            this.remoteExecutorId = str;
        } else if (!str.equals(this.remoteExecutorId)) {
            throw new RuntimeException(String.format("Wrong ContextManager: (%s != %s)", str, this.remoteExecutorId));
        }
    }

    /**
     * Adds the activated channel to the channel group.
     *
     * @param channelHandlerContext the Netty handler context
     */
    @Override
    public void channelActive(ChannelHandlerContext channelHandlerContext) {
        this.channelGroup.add(channelHandlerContext.channel());
    }

    /**
     * Removes the closed channel from the channel group and reports a "Channel closed" error to
     * every context still registered in any of the four maps.
     *
     * @param channelHandlerContext the Netty handler context
     */
    @Override
    public void channelInactive(ChannelHandlerContext channelHandlerContext) {
        this.channelGroup.remove(channelHandlerContext.channel());
        final Throwable cause = new Exception("Channel closed");
        throwChannelErrorOnContexts(this.inputContextsInitiatedByLocal, cause);
        throwChannelErrorOnContexts(this.outputContextsInitiatedByLocal, cause);
        throwChannelErrorOnContexts(this.inputContextsInitiatedByRemote, cause);
        throwChannelErrorOnContexts(this.outputContextsInitiatedByRemote, cause);
    }

    /**
     * Invokes {@code onChannelError} on every context currently in the given map.
     *
     * @param concurrentMap the contexts to notify
     * @param th            the error to report
     * @param <T>           concrete context type
     */
    private <T extends ByteTransferContext> void throwChannelErrorOnContexts(ConcurrentMap<Integer, T> concurrentMap, Throwable th) {
        for (final T context : concurrentMap.values()) {
            context.onChannelError(th);
        }
    }

    /**
     * Builds the exception thrown when a transfer index is already occupied.
     *
     * @param contextId id of the context that could not be registered
     * @return the exception to throw
     */
    private static RuntimeException duplicateContextId(ByteTransferContext.ContextId contextId) {
        return new RuntimeException(String.format("Duplicate ContextId: %s", contextId));
    }
}
