package org.apache.nemo.runtime.common.message.ncs;

import java.net.SocketAddress;
import java.util.Map;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;
import java.util.concurrent.Future;
import javax.inject.Inject;
import org.apache.nemo.runtime.common.ReplyFutureMap;
import org.apache.nemo.runtime.common.comm.ControlMessage;
import org.apache.nemo.runtime.common.message.MessageEnvironment;
import org.apache.nemo.runtime.common.message.MessageListener;
import org.apache.nemo.runtime.common.message.MessageParameters;
import org.apache.nemo.runtime.common.message.MessageSender;
import org.apache.reef.exception.evaluator.NetworkException;
import org.apache.reef.io.network.Connection;
import org.apache.reef.io.network.ConnectionFactory;
import org.apache.reef.io.network.Message;
import org.apache.reef.io.network.NetworkConnectionService;
import org.apache.reef.tang.annotations.Parameter;
import org.apache.reef.wake.EventHandler;
import org.apache.reef.wake.IdentifierFactory;
import org.apache.reef.wake.remote.transport.LinkListener;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * {@link MessageEnvironment} implementation built on REEF's {@link NetworkConnectionService} (NCS).
 *
 * <p>A single NCS connection factory, registered under the id {@code "NCS_CONN_FACTORY_ID"}, carries
 * {@link ControlMessage.Message} traffic. Each incoming message is classified as a one-way send, a request
 * (delivered together with an {@code NcsMessageContext} so the listener can reply) or a reply (which
 * completes the pending future in {@link ReplyFutureMap}), and is dispatched to the listener registered
 * under the message's listener id.
 *
 * <p>NOTE(review): this source was recovered with a decompiler (JADX). {@code getMsgType} and
 * {@code getRequestId} still switch on the synthetic {@code AnonymousClass1.$SwitchMap...} array that javac
 * generates for enum switches. That class has no source form, so these switches will not compile as-is and
 * should be rewritten against named {@code ControlMessage.MessageType} constants, as {@code getExecutorId}
 * already is.
 */
public final class NcsMessageEnvironment implements MessageEnvironment {
    private static final Logger LOG = LoggerFactory.getLogger(NcsMessageEnvironment.class.getName());

    /** Identifier under which the control-message connection factory is registered with NCS. */
    private static final String NCS_CONN_FACTORY_ID = "NCS_CONN_FACTORY_ID";

    private final NetworkConnectionService networkConnectionService;
    private final IdentifierFactory idFactory;

    /** Id of this endpoint; also used as the local end-point id of the connection factory. */
    private final String senderId;

    /** Pending reply futures; shared with every sender and completed by incoming Reply messages. */
    private final ReplyFutureMap<ControlMessage.Message> replyFutureMap = new ReplyFutureMap<>();

    /** Listeners keyed by listener id. Raw-typed because listeners with different type parameters share it. */
    private final ConcurrentMap<String, MessageListener> listenerConcurrentMap = new ConcurrentHashMap<>();

    /**
     * Connections keyed by receiver id.
     *
     * <p>NOTE(review): nothing in this class ever inserts into this map, so {@link #asyncConnect} always
     * opens a new connection. Populating it would also require deciding who owns/closes a shared
     * connection (e.g. whether closing a sender closes it) — confirm the intended behavior.
     */
    private final Map<String, Connection<ControlMessage.Message>> receiverToConnectionMap =
        new ConcurrentHashMap<>();

    private final ConnectionFactory<ControlMessage.Message> connectionFactory;

    /**
     * Coarse classification of incoming control messages, used to choose a dispatch path.
     *
     * <p>NOTE(review): the decompiler reports this enum was package-private originally; it is kept public
     * here so the visible interface does not shrink.
     */
    public enum MessageType {
        /** One-way message, delivered to the listener via {@code onMessage}. */
        Send,
        /** Message expecting a reply, delivered via {@code onMessageWithContext}. */
        Request,
        /** Reply to an earlier request; completes the matching future in the reply map. */
        Reply
    }

    /** Link listener that only reports transport failures; successful deliveries need no handling. */
    private static final class NcsLinkListener implements LinkListener<Message<ControlMessage.Message>> {
        @Override
        public void onSuccess(final Message<ControlMessage.Message> message) {
            // Nothing to do on successful delivery.
        }

        @Override
        public void onException(final Throwable cause,
                                final SocketAddress remoteAddress,
                                final Message<ControlMessage.Message> message) {
            // Pass the cause to the logger; previously it was discarded and the stack trace was lost.
            LOG.error("NCS Exception", cause);
        }
    }

    /** Receives control messages from NCS and dispatches them according to {@link MessageType}. */
    private final class NcsMessageHandler implements EventHandler<Message<ControlMessage.Message>> {
        @Override
        public void onNext(final Message<ControlMessage.Message> message) {
            final ControlMessage.Message controlMessage = extractSingleMessage(message);
            switch (getMsgType(controlMessage)) {
                case Send:
                    processSendMessage(controlMessage);
                    return;
                case Request:
                    processRequestMessage(controlMessage);
                    return;
                case Reply:
                    processReplyMessage(controlMessage);
                    return;
                default:
                    throw new IllegalArgumentException(controlMessage.toString());
            }
        }

        /** Delivers a one-way message to its listener. */
        @SuppressWarnings("unchecked") // listeners are stored raw-typed
        private void processSendMessage(final ControlMessage.Message controlMessage) {
            listenerFor(controlMessage.getListenerId()).onMessage(controlMessage);
        }

        /** Delivers a request to its listener together with a context the listener can reply through. */
        @SuppressWarnings("unchecked") // listeners are stored raw-typed
        private void processRequestMessage(final ControlMessage.Message controlMessage) {
            final MessageListener listener = listenerFor(controlMessage.getListenerId());
            listener.onMessageWithContext(controlMessage,
                new NcsMessageContext(getExecutorId(controlMessage), connectionFactory, idFactory));
        }

        /** Completes the pending future registered under the reply's request id. */
        private void processReplyMessage(final ControlMessage.Message controlMessage) {
            replyFutureMap.onSuccessMessage(getRequestId(controlMessage), controlMessage);
        }

        /**
         * Looks up the listener registered for {@code listenerId}.
         *
         * @throws IllegalStateException if no listener is registered (never set up, or already removed)
         */
        private MessageListener listenerFor(final String listenerId) {
            final MessageListener listener = listenerConcurrentMap.get(listenerId);
            if (listener == null) {
                throw new IllegalStateException("No listener is registered for " + listenerId);
            }
            return listener;
        }
    }

    /**
     * Registers the control-message connection factory with NCS.
     *
     * @param networkConnectionService NCS instance used for all connections and closed by {@link #close()}
     * @param idFactory                factory for NCS identifiers
     * @param senderId                 id of this endpoint
     */
    @Inject
    private NcsMessageEnvironment(final NetworkConnectionService networkConnectionService,
                                  final IdentifierFactory idFactory,
                                  @Parameter(MessageParameters.SenderId.class) final String senderId) {
        this.networkConnectionService = networkConnectionService;
        this.idFactory = idFactory;
        this.senderId = senderId;
        // NOTE(review): the handler is an inner class handed to NCS before connectionFactory is assigned;
        // a request arriving during registration would observe a null connectionFactory.
        this.connectionFactory = networkConnectionService.registerConnectionFactory(
            idFactory.getNewInstance(NCS_CONN_FACTORY_ID),
            new ControlMessageCodec(),
            new NcsMessageHandler(),
            new NcsLinkListener(),
            idFactory.getNewInstance(senderId));
    }

    /**
     * Registers {@code listener} under {@code listenerId}.
     *
     * @throws RuntimeException if a listener is already registered under that id
     */
    @Override
    public <T> void setupListener(final String listenerId, final MessageListener<T> listener) {
        if (listenerConcurrentMap.putIfAbsent(listenerId, listener) != null) {
            throw new RuntimeException("A listener for " + listenerId + " was already setup");
        }
    }

    /** Removes the listener registered under {@code listenerId}, if any. */
    @Override
    public void removeListener(final String listenerId) {
        listenerConcurrentMap.remove(listenerId);
    }

    /**
     * Opens a connection to {@code receiverId} and wraps it in a sender.
     *
     * @return an already-completed future holding the sender, or a future completed exceptionally with the
     *         {@link NetworkException} thrown while opening the connection
     */
    @Override
    @SuppressWarnings({"unchecked", "rawtypes"})
    public <T> Future<MessageSender<T>> asyncConnect(final String receiverId, final String listenerId) {
        // Single lookup rather than containsKey()+get(), which is a check-then-act race.
        Connection<ControlMessage.Message> connection = receiverToConnectionMap.get(receiverId);
        if (connection == null) {
            connection = connectionFactory.newConnection(idFactory.getNewInstance(receiverId));
            try {
                connection.open();
            } catch (final NetworkException e) {
                try {
                    connection.close();
                } catch (final NetworkException closeException) {
                    LOG.info("Can't close the broken connection.", closeException);
                }
                final CompletableFuture<MessageSender<T>> failed = new CompletableFuture<>();
                failed.completeExceptionally(e);
                return failed;
            }
        }
        // Raw cast so the control-message sender can be returned as MessageSender<T>.
        return CompletableFuture.completedFuture((MessageSender) new NcsMessageSender(connection, replyFutureMap));
    }

    /** Returns the id this environment was created with. */
    @Override
    public String getId() {
        return senderId;
    }

    /** Closes the underlying {@link NetworkConnectionService}. */
    @Override
    public void close() throws Exception {
        networkConnectionService.close();
    }

    /**
     * Returns the first payload element of an NCS message.
     *
     * <p>Only the first element is read; an empty payload makes {@code iterator().next()} throw
     * {@link java.util.NoSuchElementException}. Public only as a decompilation artifact (originally private).
     */
    public ControlMessage.Message extractSingleMessage(final Message<ControlMessage.Message> message) {
        return message.getData().iterator().next();
    }

    /**
     * Classifies a control message as Send, Request or Reply.
     *
     * <p>Public only as a decompilation artifact (originally private).
     *
     * @throws IllegalArgumentException for message types outside the known cases
     */
    public MessageType getMsgType(final ControlMessage.Message message) {
        // NOTE(review): decompiler artifact. The int labels index javac's synthetic enum-switch map, not the
        // enum ordinals. Cases 11-13 presumably are the request types named in getExecutorId, and 14-16 the
        // reply types read in getRequestId. Rewrite with named ControlMessage.MessageType constants.
        switch (AnonymousClass1.$SwitchMap$org$apache$nemo$runtime$common$comm$ControlMessage$MessageType[message.getType().ordinal()]) {
            case 1:
            case 2:
            case 3:
            case 4:
            case 5:
            case 6:
            case 7:
            case 8:
            case 9:
            case 10:
                return MessageType.Send;
            case 11:
            case 12:
            case 13:
                return MessageType.Request;
            case 14:
            case 15:
            case 16: // the decompiler rendered this as PIPEINITMSG_FIELD_NUMBER, an unrelated constant equal to 16
                return MessageType.Reply;
            default:
                throw new IllegalArgumentException(message.toString());
        }
    }

    /**
     * Returns the id of the executor that sent a request message.
     *
     * <p>Public only as a decompilation artifact (originally private).
     *
     * @throws IllegalArgumentException if {@code message} is not one of the handled request types
     */
    public String getExecutorId(final ControlMessage.Message message) {
        switch (message.getType()) {
            case RequestBlockLocation:
                return message.getRequestBlockLocationMsg().getExecutorId();
            case RequestBroadcastVariable:
                return message.getRequestbroadcastVariableMsg().getExecutorId();
            case RequestPipeLoc:
                return message.getRequestPipeLocMsg().getExecutorId();
            default:
                throw new IllegalArgumentException(message.toString());
        }
    }

    /**
     * Returns the request id carried by a reply message.
     *
     * <p>Public only as a decompilation artifact (originally private).
     *
     * @throws IllegalArgumentException if {@code message} is not one of the handled reply types
     */
    public long getRequestId(final ControlMessage.Message message) {
        // NOTE(review): same synthetic switch-map artifact as getMsgType; replace with named enum constants.
        switch (AnonymousClass1.$SwitchMap$org$apache$nemo$runtime$common$comm$ControlMessage$MessageType[message.getType().ordinal()]) {
            case 14:
                return message.getBlockLocationInfoMsg().getRequestId();
            case 15:
                return message.getBroadcastVariableMsg().getRequestId();
            case 16: // decompiler rendered this as PIPEINITMSG_FIELD_NUMBER (value 16)
                return message.getPipeLocInfoMsg().getRequestId();
            default:
                throw new IllegalArgumentException(message.toString());
        }
    }
}
