001/**
002 * Licensed to the Apache Software Foundation (ASF) under one
003 * or more contributor license agreements.  See the NOTICE file
004 * distributed with this work for additional information
005 * regarding copyright ownership.  The ASF licenses this file
006 * to you under the Apache License, Version 2.0 (the
007 * "License"); you may not use this file except in compliance
008 * with the License.  You may obtain a copy of the License at
009 *
010 *   http://www.apache.org/licenses/LICENSE-2.0
011 *
012 * Unless required by applicable law or agreed to in writing,
013 * software distributed under the License is distributed on an
014 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
015 * KIND, either express or implied.  See the License for the
016 * specific language governing permissions and limitations
017 * under the License.
018 */
019package org.apache.reef.wake.remote.impl;
020
021import org.apache.reef.wake.EStage;
022import org.apache.reef.wake.EventHandler;
023import org.apache.reef.wake.WakeParameters;
024import org.apache.reef.wake.impl.DefaultThreadFactory;
025import org.apache.reef.wake.impl.ThreadPoolStage;
026import org.apache.reef.wake.remote.exception.RemoteRuntimeException;
027
028import java.net.SocketAddress;
029import java.util.List;
030import java.util.concurrent.*;
031import java.util.logging.Level;
032import java.util.logging.Logger;
033
034/**
035 * Receive incoming events and dispatch to correct handlers in order
036 */
037public class OrderedRemoteReceiverStage implements EStage<TransportEvent> {
038
039  private static final Logger LOG = Logger.getLogger(OrderedRemoteReceiverStage.class.getName());
040  private final long shutdownTimeout = WakeParameters.REMOTE_EXECUTOR_SHUTDOWN_TIMEOUT;
041
042  private final ConcurrentMap<SocketAddress, OrderedEventStream> streamMap;
043  private final ExecutorService pushExecutor;
044  private final ExecutorService pullExecutor;
045
046  private final ThreadPoolStage<TransportEvent> pushStage;
047  private final ThreadPoolStage<OrderedEventStream> pullStage;
048
049  /**
050   * Constructs a ordered remote receiver stage
051   *
052   * @param handler      the handler of remote events
053   * @param errorHandler the exception handler
054   */
055  public OrderedRemoteReceiverStage(EventHandler<RemoteEvent<byte[]>> handler, EventHandler<Throwable> errorHandler) {
056    this.streamMap = new ConcurrentHashMap<SocketAddress, OrderedEventStream>();
057
058    this.pushExecutor = Executors.newCachedThreadPool(new DefaultThreadFactory(OrderedRemoteReceiverStage.class.getName() + "_Push"));
059    this.pullExecutor = Executors.newCachedThreadPool(new DefaultThreadFactory(OrderedRemoteReceiverStage.class.getName() + "_Pull"));
060
061    this.pullStage = new ThreadPoolStage<OrderedEventStream>(new OrderedPullEventHandler(handler), this.pullExecutor, errorHandler);
062    this.pushStage = new ThreadPoolStage<TransportEvent>(new OrderedPushEventHandler(streamMap, pullStage), this.pushExecutor, errorHandler); // for decoupling
063  }
064
065  @Override
066  public void onNext(TransportEvent value) {
067    LOG.log(Level.FINEST, "{0}", value);
068    pushStage.onNext(value);
069  }
070
071  @Override
072  public void close() throws Exception {
073    LOG.log(Level.FINE, "close");
074
075    if (pushExecutor != null) {
076      pushExecutor.shutdown();
077      try {
078        // wait for threads to finish for timeout
079        if (!pushExecutor.awaitTermination(shutdownTimeout, TimeUnit.MILLISECONDS)) {
080          LOG.log(Level.WARNING, "Executor did not terminate in " + shutdownTimeout + "ms.");
081          List<Runnable> droppedRunnables = pushExecutor.shutdownNow();
082          LOG.log(Level.WARNING, "Executor dropped " + droppedRunnables.size() + " tasks.");
083        }
084      } catch (InterruptedException e) {
085        LOG.log(Level.WARNING, "Close interrupted");
086        throw new RemoteRuntimeException(e);
087      }
088    }
089
090    if (pullExecutor != null) {
091      pullExecutor.shutdown();
092      try {
093        // wait for threads to finish for timeout
094        if (!pullExecutor.awaitTermination(shutdownTimeout, TimeUnit.MILLISECONDS)) {
095          LOG.log(Level.WARNING, "Executor did not terminate in " + shutdownTimeout + "ms.");
096          List<Runnable> droppedRunnables = pullExecutor.shutdownNow();
097          LOG.log(Level.WARNING, "Executor dropped " + droppedRunnables.size() + " tasks.");
098        }
099      } catch (InterruptedException e) {
100        LOG.log(Level.WARNING, "Close interrupted");
101        throw new RemoteRuntimeException(e);
102      }
103    }
104  }
105}
106
107class OrderedPushEventHandler implements EventHandler<TransportEvent> {
108
109  private static final Logger LOG = Logger.getLogger(OrderedPushEventHandler.class.getName());
110
111  private final RemoteEventCodec<byte[]> codec;
112  private final ConcurrentMap<SocketAddress, OrderedEventStream> streamMap; // per remote address
113  private final ThreadPoolStage<OrderedEventStream> pullStage;
114
115  OrderedPushEventHandler(ConcurrentMap<SocketAddress, OrderedEventStream> streamMap,
116                          ThreadPoolStage<OrderedEventStream> pullStage) {
117    this.codec = new RemoteEventCodec<byte[]>(new ByteCodec());
118    this.streamMap = streamMap;
119    this.pullStage = pullStage;
120  }
121
122  @Override
123  public void onNext(TransportEvent value) {
124    RemoteEvent<byte[]> re = codec.decode(value.getData());
125    re.setLocalAddress(value.getLocalAddress());
126    re.setRemoteAddress(value.getRemoteAddress());
127
128    if (LOG.isLoggable(Level.FINER))
129      LOG.log(Level.FINER, "{0} {1}", new Object[]{value, re});
130
131    LOG.log(Level.FINER, "Value length is {0}", value.getData().length);
132
133    SocketAddress addr = re.remoteAddress();
134    OrderedEventStream stream = streamMap.get(re.remoteAddress());
135    if (stream == null) {
136      stream = new OrderedEventStream();
137      if (streamMap.putIfAbsent(addr, stream) != null) {
138        stream = streamMap.get(addr);
139      }
140    }
141    stream.add(re);
142    pullStage.onNext(stream);
143  }
144}
145
146class OrderedPullEventHandler implements EventHandler<OrderedEventStream> {
147
148  private static final Logger LOG = Logger.getLogger(OrderedPullEventHandler.class.getName());
149
150  private final EventHandler<RemoteEvent<byte[]>> handler;
151
152  OrderedPullEventHandler(EventHandler<RemoteEvent<byte[]>> handler) {
153    this.handler = handler;
154  }
155
156  @Override
157  public void onNext(OrderedEventStream stream) {
158    if (LOG.isLoggable(Level.FINER))
159      LOG.log(Level.FINER, "{0}", stream);
160
161    synchronized (stream) {
162      RemoteEvent<byte[]> event;
163      while ((event = stream.consume()) != null) {
164        handler.onNext(event);
165      }
166    }
167  }
168}
169
170class OrderedEventStream {
171  private static final Logger LOG = Logger.getLogger(OrderedEventStream.class.getName());
172  private final BlockingQueue<RemoteEvent<byte[]>> queue; // a queue of remote events
173  private long nextSeq; // the number of the next event to consume
174
175  OrderedEventStream() {
176    queue = new PriorityBlockingQueue<RemoteEvent<byte[]>>(11, new RemoteEventComparator<byte[]>());
177    nextSeq = 0;
178  }
179
180  synchronized void add(RemoteEvent<byte[]> event) {
181    queue.add(event);
182  }
183
184  synchronized RemoteEvent<byte[]> consume() {
185    RemoteEvent<byte[]> event = queue.peek();
186    if (event != null) {
187
188      if (event.getSeq() == nextSeq) {
189        event = queue.poll();
190        ++nextSeq;
191        return event;
192      } else {
193        LOG.log(Level.FINER, "Event sequence is {0} does not match expected {1}", new Object[]{event.getSeq(), nextSeq});
194      }
195    } else {
196      LOG.log(Level.FINER, "Event is null");
197    }
198
199    return null;
200  }
201}