001/*
002 * Licensed to the Apache Software Foundation (ASF) under one
003 * or more contributor license agreements.  See the NOTICE file
004 * distributed with this work for additional information
005 * regarding copyright ownership.  The ASF licenses this file
006 * to you under the Apache License, Version 2.0 (the
007 * "License"); you may not use this file except in compliance
008 * with the License.  You may obtain a copy of the License at
009 *
010 *   http://www.apache.org/licenses/LICENSE-2.0
011 *
012 * Unless required by applicable law or agreed to in writing,
013 * software distributed under the License is distributed on an
014 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
015 * KIND, either express or implied.  See the License for the
016 * specific language governing permissions and limitations
017 * under the License.
018 */
019package org.apache.reef.wake.remote.impl;
020
021import org.apache.reef.wake.EStage;
022import org.apache.reef.wake.EventHandler;
023import org.apache.reef.wake.WakeParameters;
024import org.apache.reef.wake.impl.DefaultThreadFactory;
025import org.apache.reef.wake.impl.ThreadPoolStage;
026import org.apache.reef.wake.remote.exception.RemoteRuntimeException;
027
028import java.net.SocketAddress;
029import java.util.List;
030import java.util.concurrent.*;
031import java.util.logging.Level;
032import java.util.logging.Logger;
033
034/**
035 * Receive incoming events and dispatch to correct handlers in order.
036 */
037public class OrderedRemoteReceiverStage implements EStage<TransportEvent> {
038
039  private static final Logger LOG = Logger.getLogger(OrderedRemoteReceiverStage.class.getName());
040  private final long shutdownTimeout = WakeParameters.REMOTE_EXECUTOR_SHUTDOWN_TIMEOUT;
041
042  private final ConcurrentMap<SocketAddress, OrderedEventStream> streamMap;
043  private final ExecutorService pushExecutor;
044  private final ExecutorService pullExecutor;
045
046  private final ThreadPoolStage<TransportEvent> pushStage;
047  private final ThreadPoolStage<OrderedEventStream> pullStage;
048
049  /**
050   * Constructs a ordered remote receiver stage.
051   *
052   * @param handler      the handler of remote events
053   * @param errorHandler the exception handler
054   */
055  public OrderedRemoteReceiverStage(
056      final EventHandler<RemoteEvent<byte[]>> handler, final EventHandler<Throwable> errorHandler) {
057    this.streamMap = new ConcurrentHashMap<SocketAddress, OrderedEventStream>();
058
059    this.pushExecutor = Executors.newCachedThreadPool(
060        new DefaultThreadFactory(OrderedRemoteReceiverStage.class.getName() + "_Push"));
061    this.pullExecutor = Executors.newCachedThreadPool(
062        new DefaultThreadFactory(OrderedRemoteReceiverStage.class.getName() + "_Pull"));
063
064    this.pullStage = new ThreadPoolStage<OrderedEventStream>(
065        new OrderedPullEventHandler(handler), this.pullExecutor, errorHandler);
066    this.pushStage = new ThreadPoolStage<TransportEvent>(
067        new OrderedPushEventHandler(streamMap, pullStage), this.pushExecutor, errorHandler); // for decoupling
068  }
069
070  @Override
071  public void onNext(final TransportEvent value) {
072    LOG.log(Level.FINEST, "{0}", value);
073    pushStage.onNext(value);
074  }
075
076  @Override
077  public void close() throws Exception {
078    LOG.log(Level.FINE, "close");
079
080    if (pushExecutor != null) {
081      pushExecutor.shutdown();
082      try {
083        // wait for threads to finish for timeout
084        if (!pushExecutor.awaitTermination(shutdownTimeout, TimeUnit.MILLISECONDS)) {
085          LOG.log(Level.WARNING, "Executor did not terminate in " + shutdownTimeout + "ms.");
086          final List<Runnable> droppedRunnables = pushExecutor.shutdownNow();
087          LOG.log(Level.WARNING, "Executor dropped " + droppedRunnables.size() + " tasks.");
088        }
089      } catch (final InterruptedException e) {
090        LOG.log(Level.WARNING, "Close interrupted");
091        throw new RemoteRuntimeException(e);
092      }
093    }
094
095    if (pullExecutor != null) {
096      pullExecutor.shutdown();
097      try {
098        // wait for threads to finish for timeout
099        if (!pullExecutor.awaitTermination(shutdownTimeout, TimeUnit.MILLISECONDS)) {
100          LOG.log(Level.WARNING, "Executor did not terminate in " + shutdownTimeout + "ms.");
101          final List<Runnable> droppedRunnables = pullExecutor.shutdownNow();
102          LOG.log(Level.WARNING, "Executor dropped " + droppedRunnables.size() + " tasks.");
103        }
104      } catch (final InterruptedException e) {
105        LOG.log(Level.WARNING, "Close interrupted");
106        throw new RemoteRuntimeException(e);
107      }
108    }
109  }
110}
111
112class OrderedPushEventHandler implements EventHandler<TransportEvent> {
113
114  private static final Logger LOG = Logger.getLogger(OrderedPushEventHandler.class.getName());
115
116  private final RemoteEventCodec<byte[]> codec;
117  private final ConcurrentMap<SocketAddress, OrderedEventStream> streamMap; // per remote address
118  private final ThreadPoolStage<OrderedEventStream> pullStage;
119
120  OrderedPushEventHandler(final ConcurrentMap<SocketAddress, OrderedEventStream> streamMap,
121                          final ThreadPoolStage<OrderedEventStream> pullStage) {
122    this.codec = new RemoteEventCodec<byte[]>(new ByteCodec());
123    this.streamMap = streamMap;
124    this.pullStage = pullStage;
125  }
126
127  @Override
128  public void onNext(final TransportEvent value) {
129    final RemoteEvent<byte[]> re = codec.decode(value.getData());
130    re.setLocalAddress(value.getLocalAddress());
131    re.setRemoteAddress(value.getRemoteAddress());
132
133    if (LOG.isLoggable(Level.FINER)) {
134      LOG.log(Level.FINER, "{0} {1}", new Object[]{value, re});
135    }
136
137    LOG.log(Level.FINER, "Value length is {0}", value.getData().length);
138
139    final SocketAddress addr = re.remoteAddress();
140    OrderedEventStream stream = streamMap.get(re.remoteAddress());
141    if (stream == null) {
142      stream = new OrderedEventStream();
143      if (streamMap.putIfAbsent(addr, stream) != null) {
144        stream = streamMap.get(addr);
145      }
146    }
147    stream.add(re);
148    pullStage.onNext(stream);
149  }
150}
151
152class OrderedPullEventHandler implements EventHandler<OrderedEventStream> {
153
154  private static final Logger LOG = Logger.getLogger(OrderedPullEventHandler.class.getName());
155
156  private final EventHandler<RemoteEvent<byte[]>> handler;
157
158  OrderedPullEventHandler(final EventHandler<RemoteEvent<byte[]>> handler) {
159    this.handler = handler;
160  }
161
162  @Override
163  public void onNext(final OrderedEventStream stream) {
164    if (LOG.isLoggable(Level.FINER)) {
165      LOG.log(Level.FINER, "{0}", stream);
166    }
167
168    synchronized (stream) {
169      RemoteEvent<byte[]> event;
170      while ((event = stream.consume()) != null) {
171        handler.onNext(event);
172      }
173    }
174  }
175}
176
177class OrderedEventStream {
178  private static final Logger LOG = Logger.getLogger(OrderedEventStream.class.getName());
179  private final BlockingQueue<RemoteEvent<byte[]>> queue; // a queue of remote events
180  private long nextSeq; // the number of the next event to consume
181
182  OrderedEventStream() {
183    queue = new PriorityBlockingQueue<RemoteEvent<byte[]>>(11, new RemoteEventComparator<byte[]>());
184    nextSeq = 0;
185  }
186
187  synchronized void add(final RemoteEvent<byte[]> event) {
188    queue.add(event);
189  }
190
191  synchronized RemoteEvent<byte[]> consume() {
192    RemoteEvent<byte[]> event = queue.peek();
193    if (event != null) {
194
195      if (event.getSeq() == nextSeq) {
196        event = queue.poll();
197        ++nextSeq;
198        return event;
199      } else {
200        LOG.log(Level.FINER, "Event sequence is {0} does not match expected {1}",
201            new Object[]{event.getSeq(), nextSeq});
202      }
203    } else {
204      LOG.log(Level.FINER, "Event is null");
205    }
206
207    return null;
208  }
209}