001/*
002 * Licensed to the Apache Software Foundation (ASF) under one
003 * or more contributor license agreements.  See the NOTICE file
004 * distributed with this work for additional information
005 * regarding copyright ownership.  The ASF licenses this file
006 * to you under the Apache License, Version 2.0 (the
007 * "License"); you may not use this file except in compliance
008 * with the License.  You may obtain a copy of the License at
009 *
010 *   http://www.apache.org/licenses/LICENSE-2.0
011 *
012 * Unless required by applicable law or agreed to in writing,
013 * software distributed under the License is distributed on an
014 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
015 * KIND, either express or implied.  See the License for the
016 * specific language governing permissions and limitations
017 * under the License.
018 */
019package org.apache.reef.wake.remote.transport.netty;
020
021import io.netty.bootstrap.Bootstrap;
022import io.netty.bootstrap.ServerBootstrap;
023import io.netty.channel.Channel;
024import io.netty.channel.ChannelFuture;
025import io.netty.channel.ChannelOption;
026import io.netty.channel.EventLoopGroup;
027import io.netty.channel.group.ChannelGroup;
028import io.netty.channel.group.DefaultChannelGroup;
029import io.netty.channel.nio.NioEventLoopGroup;
030import io.netty.channel.socket.nio.NioServerSocketChannel;
031import io.netty.channel.socket.nio.NioSocketChannel;
032import io.netty.util.concurrent.GlobalEventExecutor;
033import org.apache.reef.tang.annotations.Parameter;
034import org.apache.reef.wake.EStage;
035import org.apache.reef.wake.EventHandler;
036import org.apache.reef.wake.impl.DefaultThreadFactory;
037import org.apache.reef.wake.remote.Encoder;
038import org.apache.reef.wake.remote.RemoteConfiguration;
039import org.apache.reef.wake.remote.address.LocalAddressProvider;
040import org.apache.reef.wake.remote.exception.RemoteRuntimeException;
041import org.apache.reef.wake.remote.impl.TransportEvent;
042import org.apache.reef.wake.remote.ports.TcpPortProvider;
043import org.apache.reef.wake.remote.transport.Link;
044import org.apache.reef.wake.remote.transport.LinkListener;
045import org.apache.reef.wake.remote.transport.Transport;
046import org.apache.reef.wake.remote.transport.exception.TransportRuntimeException;
047
048import javax.inject.Inject;
049import java.io.IOException;
050import java.net.BindException;
051import java.net.ConnectException;
052import java.net.InetSocketAddress;
053import java.net.SocketAddress;
054import java.util.Iterator;
055import java.util.concurrent.ConcurrentHashMap;
056import java.util.concurrent.ConcurrentMap;
057import java.util.concurrent.atomic.AtomicInteger;
058import java.util.logging.Level;
059import java.util.logging.Logger;
060
061/**
062 * Messaging transport implementation with Netty.
063 */
064public class NettyMessagingTransport implements Transport {
065
066  private static final String CLASS_NAME = NettyMessagingTransport.class.getName();
067  private static final Logger LOG = Logger.getLogger(CLASS_NAME);
068
069  private static final int SERVER_BOSS_NUM_THREADS = 3;
070  private static final int SERVER_WORKER_NUM_THREADS = 20;
071  private static final int CLIENT_WORKER_NUM_THREADS = 10;
072
073  private final ConcurrentMap<SocketAddress, LinkReference> addrToLinkRefMap = new ConcurrentHashMap<>();
074
075  private final EventLoopGroup clientWorkerGroup;
076  private final EventLoopGroup serverBossGroup;
077  private final EventLoopGroup serverWorkerGroup;
078
079  private final Bootstrap clientBootstrap;
080  private final ServerBootstrap serverBootstrap;
081  private final Channel acceptor;
082
083  private final ChannelGroup clientChannelGroup = new DefaultChannelGroup(GlobalEventExecutor.INSTANCE);
084  private final ChannelGroup serverChannelGroup = new DefaultChannelGroup(GlobalEventExecutor.INSTANCE);
085
086  private final int serverPort;
087  private final SocketAddress localAddress;
088
089  private final NettyClientEventListener clientEventListener;
090  private final NettyServerEventListener serverEventListener;
091
092  private final int numberOfTries;
093  private final int retryTimeout;
094  /**
095   * Indicates a hostname that isn't set or known.
096   */
097  public static final String UNKNOWN_HOST_NAME = "##UNKNOWN##";
098
099  /**
100   * Constructs a messaging transport.
101   *
102   * @param hostAddress   the server host address
103   * @param port          the server listening port; when it is 0, randomly assign a port number
104   * @param clientStage   the client-side stage that handles transport events
105   * @param serverStage   the server-side stage that handles transport events
106   * @param numberOfTries the number of tries of connection
107   * @param retryTimeout  the timeout of reconnection
108   * @param tcpPortProvider  gives an iterator that produces random tcp ports in a range
109   */
110  @Inject
111  NettyMessagingTransport(
112      @Parameter(RemoteConfiguration.HostAddress.class) final String hostAddress,
113      @Parameter(RemoteConfiguration.Port.class) final int port,
114      @Parameter(RemoteConfiguration.RemoteClientStage.class) final EStage<TransportEvent> clientStage,
115      @Parameter(RemoteConfiguration.RemoteServerStage.class) final EStage<TransportEvent> serverStage,
116      @Parameter(RemoteConfiguration.NumberOfTries.class) final int numberOfTries,
117      @Parameter(RemoteConfiguration.RetryTimeout.class) final int retryTimeout,
118      final TcpPortProvider tcpPortProvider,
119      final LocalAddressProvider localAddressProvider) {
120
121    int p = port;
122    if (p < 0) {
123      throw new RemoteRuntimeException("Invalid server port: " + p);
124    }
125
126    final String host = UNKNOWN_HOST_NAME.equals(hostAddress) ? localAddressProvider.getLocalAddress() : hostAddress;
127
128    this.numberOfTries = numberOfTries;
129    this.retryTimeout = retryTimeout;
130    this.clientEventListener = new NettyClientEventListener(this.addrToLinkRefMap, clientStage);
131    this.serverEventListener = new NettyServerEventListener(this.addrToLinkRefMap, serverStage);
132
133    this.serverBossGroup = new NioEventLoopGroup(SERVER_BOSS_NUM_THREADS,
134        new DefaultThreadFactory(CLASS_NAME + "ServerBoss"));
135    this.serverWorkerGroup = new NioEventLoopGroup(SERVER_WORKER_NUM_THREADS,
136        new DefaultThreadFactory(CLASS_NAME + "ServerWorker"));
137    this.clientWorkerGroup = new NioEventLoopGroup(CLIENT_WORKER_NUM_THREADS,
138        new DefaultThreadFactory(CLASS_NAME + "ClientWorker"));
139
140    this.clientBootstrap = new Bootstrap();
141    this.clientBootstrap.group(this.clientWorkerGroup)
142        .channel(NioSocketChannel.class)
143        .handler(new NettyChannelInitializer(new NettyDefaultChannelHandlerFactory("client",
144            this.clientChannelGroup, this.clientEventListener)))
145        .option(ChannelOption.SO_REUSEADDR, true)
146        .option(ChannelOption.SO_KEEPALIVE, true);
147
148    this.serverBootstrap = new ServerBootstrap();
149    this.serverBootstrap.group(this.serverBossGroup, this.serverWorkerGroup)
150        .channel(NioServerSocketChannel.class)
151        .childHandler(new NettyChannelInitializer(new NettyDefaultChannelHandlerFactory("server",
152            this.serverChannelGroup, this.serverEventListener)))
153        .option(ChannelOption.SO_BACKLOG, 128)
154        .option(ChannelOption.SO_REUSEADDR, true)
155        .childOption(ChannelOption.SO_KEEPALIVE, true);
156
157    LOG.log(Level.FINE, "Binding to {0}", p);
158
159    Channel acceptorFound = null;
160    try {
161      if (p > 0) {
162        acceptorFound = this.serverBootstrap.bind(new InetSocketAddress(host, p)).sync().channel();
163      } else {
164        final Iterator<Integer> ports = tcpPortProvider.iterator();
165        while (acceptorFound == null) {
166          if (!ports.hasNext()) {
167            break;
168          }
169          p = ports.next();
170          LOG.log(Level.FINEST, "Try port {0}", p);
171          try {
172            acceptorFound = this.serverBootstrap.bind(new InetSocketAddress(host, p)).sync().channel();
173          } catch (final Exception ex) {
174            if (ex instanceof BindException) {
175              LOG.log(Level.FINEST, "The port {0} is already bound. Try again", p);
176            } else {
177              throw ex;
178            }
179          }
180        }
181      }
182    } catch (final Exception ex) {
183      final RuntimeException transportException =
184          new TransportRuntimeException("Cannot bind to port " + p);
185      LOG.log(Level.SEVERE, "Cannot bind to port " + p, ex);
186
187      this.clientWorkerGroup.shutdownGracefully();
188      this.serverBossGroup.shutdownGracefully();
189      this.serverWorkerGroup.shutdownGracefully();
190      throw transportException;
191    }
192
193    this.acceptor = acceptorFound;
194    this.serverPort = p;
195    this.localAddress = new InetSocketAddress(host, this.serverPort);
196
197    LOG.log(Level.FINE, "Starting netty transport socket address: {0}", this.localAddress);
198  }
199
200  /**
201   * Closes all channels and releases all resources.
202   */
203  @Override
204  public void close() throws Exception {
205
206    LOG.log(Level.FINE, "Closing netty transport socket address: {0}", this.localAddress);
207
208    this.clientChannelGroup.close().awaitUninterruptibly();
209    this.serverChannelGroup.close().awaitUninterruptibly();
210    this.acceptor.close().sync();
211    this.clientWorkerGroup.shutdownGracefully();
212    this.serverBossGroup.shutdownGracefully();
213    this.serverWorkerGroup.shutdownGracefully();
214
215    LOG.log(Level.FINE, "Closing netty transport socket address: {0} done", this.localAddress);
216  }
217
218  /**
219   * Returns a link for the remote address if cached; otherwise opens, caches and returns.
220   * When it opens a link for the remote address, only one attempt for the address is made at a given time
221   *
222   * @param remoteAddr the remote socket address
223   * @param encoder    the encoder
224   * @param listener   the link listener
225   * @return a link associated with the address
226   */
227  @Override
228  public <T> Link<T> open(final SocketAddress remoteAddr, final Encoder<? super T> encoder,
229                          final LinkListener<? super T> listener) throws IOException {
230
231    Link<T> link = null;
232
233    for (int i = 0; i <= this.numberOfTries; ++i) {
234      LinkReference linkRef = this.addrToLinkRefMap.get(remoteAddr);
235
236      if (linkRef != null) {
237        link = (Link<T>) linkRef.getLink();
238        if (LOG.isLoggable(Level.FINE)) {
239          LOG.log(Level.FINE, "Link {0} for {1} found", new Object[]{link, remoteAddr});
240        }
241        if (link != null) {
242          return link;
243        }
244      }
245      
246      if (i == this.numberOfTries) {
247        // Connection failure 
248        throw new ConnectException("Connection to " + remoteAddr + " refused");
249      }
250
251      LOG.log(Level.FINE, "No cached link for {0} thread {1}",
252          new Object[]{remoteAddr, Thread.currentThread()});
253
254      // no linkRef
255      final LinkReference newLinkRef = new LinkReference();
256      final LinkReference prior = this.addrToLinkRefMap.putIfAbsent(remoteAddr, newLinkRef);
257      final AtomicInteger flag = prior != null ?
258          prior.getConnectInProgress() : newLinkRef.getConnectInProgress();
259
260      synchronized (flag) {
261        if (!flag.compareAndSet(0, 1)) {
262          while (flag.get() == 1) {
263            try {
264              flag.wait();
265            } catch (final InterruptedException ex) {
266              LOG.log(Level.WARNING, "Wait interrupted", ex);
267            }
268          }
269        }
270      }
271
272      linkRef = this.addrToLinkRefMap.get(remoteAddr);
273      link = (Link<T>) linkRef.getLink();
274
275      if (link != null) {
276        return link;
277      }
278
279      ChannelFuture connectFuture = null;
280      try {
281        connectFuture = this.clientBootstrap.connect(remoteAddr);
282        connectFuture.syncUninterruptibly();
283
284        link = new NettyLink<>(connectFuture.channel(), encoder, listener);
285        linkRef.setLink(link);
286
287        synchronized (flag) {
288          flag.compareAndSet(1, 2);
289          flag.notifyAll();
290        }
291        break;
292      } catch (final Exception e) {
293        if (e.getClass().getSimpleName().compareTo("ConnectException") == 0) {
294          LOG.log(Level.WARNING, "Connection refused. Retry {0} of {1}",
295              new Object[]{i + 1, this.numberOfTries});
296          synchronized (flag) {
297            flag.compareAndSet(1, 0);
298            flag.notifyAll();
299          }
300
301          if (i < this.numberOfTries) {
302            try {
303              Thread.sleep(retryTimeout);
304            } catch (final InterruptedException interrupt) {
305              LOG.log(Level.WARNING, "Thread {0} interrupted while sleeping", Thread.currentThread());
306            }
307          }
308        } else {
309          throw e;
310        }
311      }
312    }
313    
314    return link;
315  }
316
317  /**
318   * Returns a link for the remote address if already cached; otherwise, returns null.
319   *
320   * @param remoteAddr the remote address
321   * @return a link if already cached; otherwise, null
322   */
323  public <T> Link<T> get(final SocketAddress remoteAddr) {
324    final LinkReference linkRef = this.addrToLinkRefMap.get(remoteAddr);
325    return linkRef != null ? (Link<T>) linkRef.getLink() : null;
326  }
327
328  /**
329   * Gets a server local socket address of this transport.
330   *
331   * @return a server local socket address
332   */
333  @Override
334  public SocketAddress getLocalAddress() {
335    return this.localAddress;
336  }
337
338  /**
339   * Gets a server listening port of this transport.
340   *
341   * @return a listening port number
342   */
343  @Override
344  public int getListeningPort() {
345    return this.serverPort;
346  }
347
348  /**
349   * Registers the exception event handler.
350   *
351   * @param handler the exception event handler
352   */
353  @Override
354  public void registerErrorHandler(final EventHandler<Exception> handler) {
355    this.clientEventListener.registerErrorHandler(handler);
356    this.serverEventListener.registerErrorHandler(handler);
357  }
358}