/**
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */
package org.apache.reef.wake.remote.impl;

import org.apache.reef.wake.EStage;
import org.apache.reef.wake.EventHandler;
import org.apache.reef.wake.WakeParameters;
import org.apache.reef.wake.impl.DefaultThreadFactory;
import org.apache.reef.wake.impl.ThreadPoolStage;
import org.apache.reef.wake.remote.exception.RemoteRuntimeException;

import java.net.SocketAddress;
import java.util.List;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.PriorityBlockingQueue;
import java.util.concurrent.TimeUnit;
import java.util.logging.Level;
import java.util.logging.Logger;

/**
 * Receives incoming events and dispatches them to the handler in order, per remote address.
 *
 * <p>Events flow through two stages. The push stage decodes each {@link TransportEvent} and appends
 * it to the {@link OrderedEventStream} of its remote address. The pull stage then drains that
 * stream, delivering only events whose sequence numbers are consecutive.
 */
public class OrderedRemoteReceiverStage implements EStage<TransportEvent> {

  private static final Logger LOG = Logger.getLogger(OrderedRemoteReceiverStage.class.getName());

  /** Milliseconds to wait for each executor to terminate in {@link #close()}. */
  private final long shutdownTimeout = WakeParameters.REMOTE_EXECUTOR_SHUTDOWN_TIMEOUT;

  /** One ordered stream per remote address. */
  private final ConcurrentMap<SocketAddress, OrderedEventStream> streamMap;
  private final ExecutorService pushExecutor;
  private final ExecutorService pullExecutor;

  private final ThreadPoolStage<TransportEvent> pushStage;
  private final ThreadPoolStage<OrderedEventStream> pullStage;

  /**
   * Constructs an ordered remote receiver stage.
   *
   * @param handler      the handler of remote events
   * @param errorHandler the exception handler
   */
  public OrderedRemoteReceiverStage(final EventHandler<RemoteEvent<byte[]>> handler,
                                    final EventHandler<Throwable> errorHandler) {
    this.streamMap = new ConcurrentHashMap<SocketAddress, OrderedEventStream>();

    this.pushExecutor = Executors.newCachedThreadPool(
        new DefaultThreadFactory(OrderedRemoteReceiverStage.class.getName() + "_Push"));
    this.pullExecutor = Executors.newCachedThreadPool(
        new DefaultThreadFactory(OrderedRemoteReceiverStage.class.getName() + "_Pull"));

    this.pullStage = new ThreadPoolStage<OrderedEventStream>(
        new OrderedPullEventHandler(handler), this.pullExecutor, errorHandler);
    // For decoupling: onNext only hands the event to this stage; decoding and ordering happen here.
    this.pushStage = new ThreadPoolStage<TransportEvent>(
        new OrderedPushEventHandler(streamMap, pullStage), this.pushExecutor, errorHandler);
  }

  /**
   * Hands a received transport event to the push stage.
   *
   * @param value the received transport event
   */
  @Override
  public void onNext(final TransportEvent value) {
    LOG.log(Level.FINEST, "{0}", value);
    pushStage.onNext(value);
  }

  /**
   * Shuts down the push executor and then the pull executor.
   *
   * <p>Waits up to the shutdown timeout for each executor, and forcibly stops any executor that
   * does not terminate in time.
   *
   * @throws RemoteRuntimeException if interrupted while waiting. Both executors are then stopped
   *                                with {@code shutdownNow()` and the thread's interrupt status is
   *                                restored.
   */
  @Override
  public void close() throws Exception {
    LOG.log(Level.FINE, "close");

    try {
      // Push tasks forward streams to pullStage, so the pull executor must keep accepting work
      // until the push side has drained.
      shutdownAndAwait(pushExecutor);
      shutdownAndAwait(pullExecutor);
    } catch (final InterruptedException e) {
      LOG.log(Level.WARNING, "Close interrupted", e);
      // Do not leak the threads of whichever executor had not been shut down yet.
      pushExecutor.shutdownNow();
      pullExecutor.shutdownNow();
      Thread.currentThread().interrupt();
      throw new RemoteRuntimeException(e);
    }
  }

  /**
   * Initiates an orderly shutdown of {@code executor} and waits up to {@code shutdownTimeout} ms.
   * If it has not terminated by then, cancels the remaining tasks with {@code shutdownNow()}.
   *
   * @param executor the executor to shut down
   * @throws InterruptedException if interrupted while waiting for termination
   */
  private void shutdownAndAwait(final ExecutorService executor) throws InterruptedException {
    executor.shutdown();
    if (!executor.awaitTermination(shutdownTimeout, TimeUnit.MILLISECONDS)) {
      LOG.log(Level.WARNING, "Executor did not terminate in " + shutdownTimeout + "ms.");
      final List<Runnable> droppedRunnables = executor.shutdownNow();
      LOG.log(Level.WARNING, "Executor dropped " + droppedRunnables.size() + " tasks.");
    }
  }
}

/**
 * Decodes transport events into remote events and appends each one to the stream of its remote
 * address. It then schedules that stream on the pull stage.
 */
class OrderedPushEventHandler implements EventHandler<TransportEvent> {

  private static final Logger LOG = Logger.getLogger(OrderedPushEventHandler.class.getName());

  private final RemoteEventCodec<byte[]> codec;
  private final ConcurrentMap<SocketAddress, OrderedEventStream> streamMap; // per remote address
  private final ThreadPoolStage<OrderedEventStream> pullStage;

  OrderedPushEventHandler(final ConcurrentMap<SocketAddress, OrderedEventStream> streamMap,
                          final ThreadPoolStage<OrderedEventStream> pullStage) {
    this.codec = new RemoteEventCodec<byte[]>(new ByteCodec());
    this.streamMap = streamMap;
    this.pullStage = pullStage;
  }

  /**
   * Decodes {@code value}, stamps it with its local and remote addresses, and enqueues it on the
   * stream for its remote address. Every enqueue is followed by a pull request, so every added
   * event gets at least one delivery attempt.
   *
   * @param value the received transport event
   */
  @Override
  public void onNext(final TransportEvent value) {
    final RemoteEvent<byte[]> re = codec.decode(value.getData());
    re.setLocalAddress(value.getLocalAddress());
    re.setRemoteAddress(value.getRemoteAddress());

    if (LOG.isLoggable(Level.FINER)) {
      LOG.log(Level.FINER, "{0} {1}", new Object[]{value, re});
      LOG.log(Level.FINER, "Value length is {0}", value.getData().length);
    }

    final OrderedEventStream stream = getOrCreateStream(re.remoteAddress());
    stream.add(re);
    pullStage.onNext(stream);
  }

  /**
   * Returns the stream registered for {@code addr}, atomically registering a new one if absent.
   *
   * @param addr the remote address of the event
   * @return the single stream instance the map holds for {@code addr}
   */
  private OrderedEventStream getOrCreateStream(final SocketAddress addr) {
    final OrderedEventStream existing = streamMap.get(addr);
    if (existing != null) {
      return existing;
    }
    final OrderedEventStream created = new OrderedEventStream();
    final OrderedEventStream raced = streamMap.putIfAbsent(addr, created);
    // Another thread may have registered a stream concurrently; always use the map's instance.
    return raced == null ? created : raced;
  }
}

/**
 * Drains an {@link OrderedEventStream}, handing every in-sequence event to the wrapped handler.
 */
class OrderedPullEventHandler implements EventHandler<OrderedEventStream> {

  private static final Logger LOG = Logger.getLogger(OrderedPullEventHandler.class.getName());

  private final EventHandler<RemoteEvent<byte[]>> handler;

  OrderedPullEventHandler(final EventHandler<RemoteEvent<byte[]>> handler) {
    this.handler = handler;
  }

  /**
   * Delivers events from {@code stream} until its head is missing or out of sequence.
   *
   * @param stream the stream to drain
   */
  @Override
  public void onNext(final OrderedEventStream stream) {
    if (LOG.isLoggable(Level.FINER)) {
      LOG.log(Level.FINER, "{0}", stream);
    }

    // Holding the stream's monitor across delivery stops concurrent pull tasks for the same stream
    // from interleaving, so events reach the handler in sequence order. OrderedEventStream.add
    // synchronizes on the same monitor, so producers for this stream block while the handler runs.
    synchronized (stream) {
      RemoteEvent<byte[]> event;
      while ((event = stream.consume()) != null) {
        handler.onNext(event);
      }
    }
  }
}

/**
 * Per-remote-address buffer that releases events strictly in sequence-number order, starting at 0.
 *
 * <p>Events may be added out of order. {@link #consume()} returns the queue head only once its
 * sequence number equals the next expected one.
 *
 * <p>NOTE(review): this assumes {@code RemoteEventComparator} orders events by ascending sequence
 * number, so the head is always the smallest pending sequence. Confirm against that class.
 */
class OrderedEventStream {
  private static final Logger LOG = Logger.getLogger(OrderedEventStream.class.getName());
  private final BlockingQueue<RemoteEvent<byte[]>> queue; // a queue of remote events
  private long nextSeq; // the number of the next event to consume

  OrderedEventStream() {
    queue = new PriorityBlockingQueue<RemoteEvent<byte[]>>(11, new RemoteEventComparator<byte[]>());
    nextSeq = 0;
  }

  /**
   * Enqueues an event, in any order.
   *
   * @param event the event to buffer
   */
  synchronized void add(final RemoteEvent<byte[]> event) {
    queue.add(event);
  }

  /**
   * Removes and returns the head event if it is the next one in sequence.
   *
   * @return the next in-sequence event, or {@code null} if the queue is empty or the head's
   *         sequence number is not the expected one yet
   */
  synchronized RemoteEvent<byte[]> consume() {
    final RemoteEvent<byte[]> head = queue.peek();
    if (head == null) {
      LOG.log(Level.FINER, "Event is null");
      return null;
    }
    if (head.getSeq() != nextSeq) {
      if (LOG.isLoggable(Level.FINER)) {
        LOG.log(Level.FINER, "Event sequence is {0} does not match expected {1}",
            new Object[]{head.getSeq(), nextSeq});
      }
      return null;
    }
    ++nextSeq;
    return queue.poll();
  }
}