001/**
002 * Licensed to the Apache Software Foundation (ASF) under one
003 * or more contributor license agreements.  See the NOTICE file
004 * distributed with this work for additional information
005 * regarding copyright ownership.  The ASF licenses this file
006 * to you under the Apache License, Version 2.0 (the
007 * "License"); you may not use this file except in compliance
008 * with the License.  You may obtain a copy of the License at
009 *
010 *   http://www.apache.org/licenses/LICENSE-2.0
011 *
012 * Unless required by applicable law or agreed to in writing,
013 * software distributed under the License is distributed on an
014 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
015 * KIND, either express or implied.  See the License for the
016 * specific language governing permissions and limitations
017 * under the License.
018 */
019package org.apache.reef.wake.remote.transport.netty;
020
021import io.netty.bootstrap.Bootstrap;
022import io.netty.bootstrap.ServerBootstrap;
023import io.netty.channel.Channel;
024import io.netty.channel.ChannelFuture;
025import io.netty.channel.ChannelOption;
026import io.netty.channel.EventLoopGroup;
027import io.netty.channel.group.ChannelGroup;
028import io.netty.channel.group.DefaultChannelGroup;
029import io.netty.channel.nio.NioEventLoopGroup;
030import io.netty.channel.socket.nio.NioServerSocketChannel;
031import io.netty.channel.socket.nio.NioSocketChannel;
032import io.netty.util.concurrent.GlobalEventExecutor;
033import org.apache.reef.wake.EStage;
034import org.apache.reef.wake.EventHandler;
035import org.apache.reef.wake.impl.DefaultThreadFactory;
036import org.apache.reef.wake.remote.Encoder;
037import org.apache.reef.wake.remote.exception.RemoteRuntimeException;
038import org.apache.reef.wake.remote.impl.TransportEvent;
039import org.apache.reef.wake.remote.ports.RangeTcpPortProvider;
040import org.apache.reef.wake.remote.ports.TcpPortProvider;
041import org.apache.reef.wake.remote.transport.Link;
042import org.apache.reef.wake.remote.transport.LinkListener;
043import org.apache.reef.wake.remote.transport.Transport;
044import org.apache.reef.wake.remote.transport.exception.TransportRuntimeException;
045
046import java.io.IOException;
047import java.net.BindException;
048import java.net.ConnectException;
049import java.net.InetSocketAddress;
050import java.net.SocketAddress;
051import java.util.Iterator;
052import java.util.concurrent.ConcurrentHashMap;
053import java.util.concurrent.ConcurrentMap;
054import java.util.concurrent.atomic.AtomicInteger;
055import java.util.logging.Level;
056import java.util.logging.Logger;
057
058/**
059 * Messaging transport implementation with Netty
060 */
061public class NettyMessagingTransport implements Transport {
062
063  private static final String CLASS_NAME = NettyMessagingTransport.class.getName();
064  private static final Logger LOG = Logger.getLogger(CLASS_NAME);
065
066  private static final int SERVER_BOSS_NUM_THREADS = 3;
067  private static final int SERVER_WORKER_NUM_THREADS = 20;
068  private static final int CLIENT_WORKER_NUM_THREADS = 10;
069
070  private final ConcurrentMap<SocketAddress, LinkReference> addrToLinkRefMap = new ConcurrentHashMap<>();
071
072  private final EventLoopGroup clientWorkerGroup;
073  private final EventLoopGroup serverBossGroup;
074  private final EventLoopGroup serverWorkerGroup;
075
076  private final Bootstrap clientBootstrap;
077  private final ServerBootstrap serverBootstrap;
078  private final Channel acceptor;
079
080  private final ChannelGroup clientChannelGroup = new DefaultChannelGroup(GlobalEventExecutor.INSTANCE);
081  private final ChannelGroup serverChannelGroup = new DefaultChannelGroup(GlobalEventExecutor.INSTANCE);
082
083  private final int serverPort;
084  private final SocketAddress localAddress;
085
086  private final NettyClientEventListener clientEventListener;
087  private final NettyServerEventListener serverEventListener;
088
089  private final int numberOfTries;
090  private final int retryTimeout;
091
092  /**
093   * Constructs a messaging transport
094   *
095   * @param hostAddress   the server host address
096   * @param port          the server listening port; when it is 0, randomly assign a port number
097   * @param clientStage   the client-side stage that handles transport events
098   * @param serverStage   the server-side stage that handles transport events
099   * @param numberOfTries the number of tries of connection
100   * @param retryTimeout  the timeout of reconnection
101   * @param tcpPortProvider  gives an iterator that produces random tcp ports in a range
102   *
103   */
104  public NettyMessagingTransport(final String hostAddress, int port,
105                                 final EStage<TransportEvent> clientStage,
106                                 final EStage<TransportEvent> serverStage,
107                                 final int numberOfTries,
108                                 final int retryTimeout,
109                                 final TcpPortProvider tcpPortProvider) {
110
111    if (port < 0) {
112      throw new RemoteRuntimeException("Invalid server port: " + port);
113    }
114
115    this.numberOfTries = numberOfTries;
116    this.retryTimeout = retryTimeout;
117    this.clientEventListener = new NettyClientEventListener(this.addrToLinkRefMap, clientStage);
118    this.serverEventListener = new NettyServerEventListener(this.addrToLinkRefMap, serverStage);
119
120    this.serverBossGroup = new NioEventLoopGroup(SERVER_BOSS_NUM_THREADS, new DefaultThreadFactory(CLASS_NAME + "ServerBoss"));
121    this.serverWorkerGroup = new NioEventLoopGroup(SERVER_WORKER_NUM_THREADS, new DefaultThreadFactory(CLASS_NAME + "ServerWorker"));
122    this.clientWorkerGroup = new NioEventLoopGroup(CLIENT_WORKER_NUM_THREADS, new DefaultThreadFactory(CLASS_NAME + "ClientWorker"));
123
124    this.clientBootstrap = new Bootstrap();
125    this.clientBootstrap.group(this.clientWorkerGroup)
126        .channel(NioSocketChannel.class)
127        .handler(new NettyChannelInitializer(new NettyDefaultChannelHandlerFactory("client",
128            this.clientChannelGroup, this.clientEventListener)))
129        .option(ChannelOption.SO_REUSEADDR, true)
130        .option(ChannelOption.SO_KEEPALIVE, true);
131
132    this.serverBootstrap = new ServerBootstrap();
133    this.serverBootstrap.group(this.serverBossGroup, this.serverWorkerGroup)
134        .channel(NioServerSocketChannel.class)
135        .childHandler(new NettyChannelInitializer(new NettyDefaultChannelHandlerFactory("server",
136            this.serverChannelGroup, this.serverEventListener)))
137        .option(ChannelOption.SO_BACKLOG, 128)
138        .option(ChannelOption.SO_REUSEADDR, true)
139        .childOption(ChannelOption.SO_KEEPALIVE, true);
140
141    LOG.log(Level.FINE, "Binding to {0}", port);
142
143  Channel acceptor = null;
144  try {
145    if (port > 0) {
146      acceptor = this.serverBootstrap.bind(new InetSocketAddress(hostAddress, port)).sync().channel();
147    } else {
148      Iterator<Integer> ports = tcpPortProvider.iterator();
149      while (acceptor == null) {
150        if (!ports.hasNext()) break;
151        port = ports.next();
152        LOG.log(Level.FINEST, "Try port {0}", port);
153        try {
154          acceptor = this.serverBootstrap.bind(new InetSocketAddress(hostAddress, port)).sync().channel();
155        } catch (final Exception ex) {
156          if (ex instanceof BindException) {
157            LOG.log(Level.FINEST, "The port {0} is already bound. Try again", port);
158          } else {
159            throw ex;
160          }
161        }
162      }
163    }
164  } catch (final Exception ex) {
165    final RuntimeException transportException =
166       new TransportRuntimeException("Cannot bind to port " + port);
167    LOG.log(Level.SEVERE, "Cannot bind to port " + port, ex);
168
169      this.clientWorkerGroup.shutdownGracefully();
170      this.serverBossGroup.shutdownGracefully();
171      this.serverWorkerGroup.shutdownGracefully();
172      throw transportException;
173    }
174
175    this.acceptor = acceptor;
176    this.serverPort = port;
177    this.localAddress = new InetSocketAddress(hostAddress, this.serverPort);
178
179    LOG.log(Level.FINE, "Starting netty transport socket address: {0}", this.localAddress);
180  }
181
182  /**
183   * Constructs a messaging transport
184   *
185   * @param hostAddress   the server host address
186   * @param port          the server listening port; when it is 0, randomly assign a port number
187   * @param clientStage   the client-side stage that handles transport events
188   * @param serverStage   the server-side stage that handles transport events
189   * @param numberOfTries the number of tries of connection
190   * @param retryTimeout  the timeout of reconnection
191   * @deprecated use the constructor that takes a TcpProvider instead
192   */
193  @Deprecated
194  public NettyMessagingTransport(final String hostAddress, int port,
195                                 final EStage<TransportEvent> clientStage,
196                                 final EStage<TransportEvent> serverStage,
197                                 final int numberOfTries,
198                                 final int retryTimeout) {
199    this(hostAddress, port, clientStage, serverStage, numberOfTries, retryTimeout,
200            RangeTcpPortProvider.Default);
201  }
202
203  /**
204   * Closes all channels and releases all resources
205   */
206  @Override
207  public void close() throws Exception {
208
209    LOG.log(Level.FINE, "Closing netty transport socket address: {0}", this.localAddress);
210
211    this.clientChannelGroup.close().awaitUninterruptibly();
212    this.serverChannelGroup.close().awaitUninterruptibly();
213    this.acceptor.close().sync();
214    this.clientWorkerGroup.shutdownGracefully();
215    this.serverBossGroup.shutdownGracefully();
216    this.serverWorkerGroup.shutdownGracefully();
217
218    LOG.log(Level.FINE, "Closing netty transport socket address: {0} done", this.localAddress);
219  }
220
221  /**
222   * Returns a link for the remote address if cached; otherwise opens, caches and returns
223   * When it opens a link for the remote address, only one attempt for the address is made at a given time
224   *
225   * @param remoteAddr the remote socket address
226   * @param encoder    the encoder
227   * @param listener   the link listener
228   * @return a link associated with the address
229   */
230  @Override
231  public <T> Link<T> open(final SocketAddress remoteAddr, final Encoder<? super T> encoder,
232                          final LinkListener<? super T> listener) throws IOException {
233
234    Link<T> link = null;
235
236    for (int i = 0; i <= this.numberOfTries; ++i) {
237      LinkReference linkRef = this.addrToLinkRefMap.get(remoteAddr);
238
239      if (linkRef != null) {
240        link = (Link<T>) linkRef.getLink();
241        if (LOG.isLoggable(Level.FINE)) {
242          LOG.log(Level.FINE, "Link {0} for {1} found", new Object[]{link, remoteAddr});
243        }
244        if (link != null) {
245          return link;
246        }
247      }
248      
249      if (i == this.numberOfTries) {
250        // Connection failure 
251        throw new ConnectException("Connection to " + remoteAddr + " refused");
252      }
253
254      LOG.log(Level.FINE, "No cached link for {0} thread {1}",
255          new Object[]{remoteAddr, Thread.currentThread()});
256
257      // no linkRef
258      final LinkReference newLinkRef = new LinkReference();
259      final LinkReference prior = this.addrToLinkRefMap.putIfAbsent(remoteAddr, newLinkRef);
260      final AtomicInteger flag = prior != null ?
261          prior.getConnectInProgress() : newLinkRef.getConnectInProgress();
262
263      synchronized (flag) {
264        if (!flag.compareAndSet(0, 1)) {
265          while (flag.get() == 1) {
266            try {
267              flag.wait();
268            } catch (final InterruptedException ex) {
269              LOG.log(Level.WARNING, "Wait interrupted", ex);
270            }
271          }
272        }
273      }
274
275      linkRef = this.addrToLinkRefMap.get(remoteAddr);
276      link = (Link<T>) linkRef.getLink();
277
278      if (link != null) {
279        return link;
280      }
281
282      ChannelFuture connectFuture = null;
283      try {
284        connectFuture = this.clientBootstrap.connect(remoteAddr);
285        connectFuture.syncUninterruptibly();
286
287        link = new NettyLink<>(connectFuture.channel(), encoder, listener);
288        linkRef.setLink(link);
289
290        synchronized (flag) {
291          flag.compareAndSet(1, 2);
292          flag.notifyAll();
293        }
294        break;
295      } catch (final Exception e) {
296        if (e.getClass().getSimpleName().compareTo("ConnectException") == 0) {
297          LOG.log(Level.WARNING, "Connection refused. Retry {0} of {1}",
298              new Object[]{i + 1, this.numberOfTries});
299          synchronized (flag) {
300            flag.compareAndSet(1, 0);
301            flag.notifyAll();
302          }
303
304          if (i < this.numberOfTries) {
305            try {
306              Thread.sleep(retryTimeout);
307            } catch (final InterruptedException interrupt) {
308              LOG.log(Level.WARNING, "Thread {0} interrupted while sleeping", Thread.currentThread());
309            }
310          }
311        } else {
312          throw e;
313        }
314      }
315    }
316    
317    return link;
318  }
319
320  /**
321   * Returns a link for the remote address if already cached; otherwise, returns null
322   *
323   * @param remoteAddr the remote address
324   * @return a link if already cached; otherwise, null
325   */
326  public <T> Link<T> get(final SocketAddress remoteAddr) {
327    final LinkReference linkRef = this.addrToLinkRefMap.get(remoteAddr);
328    return linkRef != null ? (Link<T>) linkRef.getLink() : null;
329  }
330
331  /**
332   * Gets a server local socket address of this transport
333   *
334   * @return a server local socket address
335   */
336  @Override
337  public SocketAddress getLocalAddress() {
338    return this.localAddress;
339  }
340
341  /**
342   * Gets a server listening port of this transport
343   *
344   * @return a listening port number
345   */
346  @Override
347  public int getListeningPort() {
348    return this.serverPort;
349  }
350
351  /**
352   * Registers the exception event handler
353   *
354   * @param handler the exception event handler
355   */
356  @Override
357  public void registerErrorHandler(final EventHandler<Exception> handler) {
358    this.clientEventListener.registerErrorHandler(handler);
359    this.serverEventListener.registerErrorHandler(handler);
360  }
361}