001/* 002 * Licensed to the Apache Software Foundation (ASF) under one 003 * or more contributor license agreements. See the NOTICE file 004 * distributed with this work for additional information 005 * regarding copyright ownership. The ASF licenses this file 006 * to you under the Apache License, Version 2.0 (the 007 * "License"); you may not use this file except in compliance 008 * with the License. You may obtain a copy of the License at 009 * 010 * http://www.apache.org/licenses/LICENSE-2.0 011 * 012 * Unless required by applicable law or agreed to in writing, 013 * software distributed under the License is distributed on an 014 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 015 * KIND, either express or implied. See the License for the 016 * specific language governing permissions and limitations 017 * under the License. 018 */ 019package org.apache.reef.wake.remote.impl; 020 021import org.apache.reef.wake.EStage; 022import org.apache.reef.wake.EventHandler; 023import org.apache.reef.wake.WakeParameters; 024import org.apache.reef.wake.impl.DefaultThreadFactory; 025import org.apache.reef.wake.impl.ThreadPoolStage; 026import org.apache.reef.wake.remote.exception.RemoteRuntimeException; 027 028import java.net.SocketAddress; 029import java.util.List; 030import java.util.concurrent.*; 031import java.util.logging.Level; 032import java.util.logging.Logger; 033 034/** 035 * Receive incoming events and dispatch to correct handlers in order. 036 */ 037public class OrderedRemoteReceiverStage implements EStage<TransportEvent> { 038 039 private static final Logger LOG = Logger.getLogger(OrderedRemoteReceiverStage.class.getName()); 040 private final long shutdownTimeout = WakeParameters.REMOTE_EXECUTOR_SHUTDOWN_TIMEOUT; 041 042 private final ConcurrentMap<SocketAddress, OrderedEventStream> streamMap; 043 private final ExecutorService pushExecutor; 044 private final ExecutorService pullExecutor; 045 046 private final ThreadPoolStage<TransportEvent> pushStage; 047 private final ThreadPoolStage<OrderedEventStream> pullStage; 048 049 /** 050 * Constructs a ordered remote receiver stage. 051 * 052 * @param handler the handler of remote events 053 * @param errorHandler the exception handler 054 */ 055 public OrderedRemoteReceiverStage( 056 final EventHandler<RemoteEvent<byte[]>> handler, final EventHandler<Throwable> errorHandler) { 057 this.streamMap = new ConcurrentHashMap<SocketAddress, OrderedEventStream>(); 058 059 this.pushExecutor = Executors.newCachedThreadPool( 060 new DefaultThreadFactory(OrderedRemoteReceiverStage.class.getName() + "_Push")); 061 this.pullExecutor = Executors.newCachedThreadPool( 062 new DefaultThreadFactory(OrderedRemoteReceiverStage.class.getName() + "_Pull")); 063 064 this.pullStage = new ThreadPoolStage<OrderedEventStream>( 065 new OrderedPullEventHandler(handler), this.pullExecutor, errorHandler); 066 this.pushStage = new ThreadPoolStage<TransportEvent>( 067 new OrderedPushEventHandler(streamMap, pullStage), this.pushExecutor, errorHandler); // for decoupling 068 } 069 070 @Override 071 public void onNext(final TransportEvent value) { 072 LOG.log(Level.FINEST, "{0}", value); 073 pushStage.onNext(value); 074 } 075 076 @Override 077 public void close() throws Exception { 078 LOG.log(Level.FINE, "close"); 079 080 if (pushExecutor != null) { 081 pushExecutor.shutdown(); 082 try { 083 // wait for threads to finish for timeout 084 if (!pushExecutor.awaitTermination(shutdownTimeout, TimeUnit.MILLISECONDS)) { 085 LOG.log(Level.WARNING, "Executor did not terminate in " + shutdownTimeout + "ms."); 086 final List<Runnable> droppedRunnables = pushExecutor.shutdownNow(); 087 LOG.log(Level.WARNING, "Executor dropped " + droppedRunnables.size() + " tasks."); 088 } 089 } catch (final InterruptedException e) { 090 LOG.log(Level.WARNING, "Close interrupted"); 091 throw new RemoteRuntimeException(e); 092 } 093 } 094 095 if (pullExecutor != null) { 096 pullExecutor.shutdown(); 097 try { 098 // wait for threads to finish for timeout 099 if (!pullExecutor.awaitTermination(shutdownTimeout, TimeUnit.MILLISECONDS)) { 100 LOG.log(Level.WARNING, "Executor did not terminate in " + shutdownTimeout + "ms."); 101 final List<Runnable> droppedRunnables = pullExecutor.shutdownNow(); 102 LOG.log(Level.WARNING, "Executor dropped " + droppedRunnables.size() + " tasks."); 103 } 104 } catch (final InterruptedException e) { 105 LOG.log(Level.WARNING, "Close interrupted"); 106 throw new RemoteRuntimeException(e); 107 } 108 } 109 } 110} 111 112class OrderedPushEventHandler implements EventHandler<TransportEvent> { 113 114 private static final Logger LOG = Logger.getLogger(OrderedPushEventHandler.class.getName()); 115 116 private final RemoteEventCodec<byte[]> codec; 117 private final ConcurrentMap<SocketAddress, OrderedEventStream> streamMap; // per remote address 118 private final ThreadPoolStage<OrderedEventStream> pullStage; 119 120 OrderedPushEventHandler(final ConcurrentMap<SocketAddress, OrderedEventStream> streamMap, 121 final ThreadPoolStage<OrderedEventStream> pullStage) { 122 this.codec = new RemoteEventCodec<byte[]>(new ByteCodec()); 123 this.streamMap = streamMap; 124 this.pullStage = pullStage; 125 } 126 127 @Override 128 public void onNext(final TransportEvent value) { 129 final RemoteEvent<byte[]> re = codec.decode(value.getData()); 130 re.setLocalAddress(value.getLocalAddress()); 131 re.setRemoteAddress(value.getRemoteAddress()); 132 133 if (LOG.isLoggable(Level.FINER)) { 134 LOG.log(Level.FINER, "{0} {1}", new Object[]{value, re}); 135 } 136 137 LOG.log(Level.FINER, "Value length is {0}", value.getData().length); 138 139 final SocketAddress addr = re.remoteAddress(); 140 OrderedEventStream stream = streamMap.get(re.remoteAddress()); 141 if (stream == null) { 142 stream = new OrderedEventStream(); 143 if (streamMap.putIfAbsent(addr, stream) != null) { 144 stream = streamMap.get(addr); 145 } 146 } 147 stream.add(re); 148 pullStage.onNext(stream); 149 } 150} 151 152class OrderedPullEventHandler implements EventHandler<OrderedEventStream> { 153 154 private static final Logger LOG = Logger.getLogger(OrderedPullEventHandler.class.getName()); 155 156 private final EventHandler<RemoteEvent<byte[]>> handler; 157 158 OrderedPullEventHandler(final EventHandler<RemoteEvent<byte[]>> handler) { 159 this.handler = handler; 160 } 161 162 @Override 163 public void onNext(final OrderedEventStream stream) { 164 if (LOG.isLoggable(Level.FINER)) { 165 LOG.log(Level.FINER, "{0}", stream); 166 } 167 168 synchronized (stream) { 169 RemoteEvent<byte[]> event; 170 while ((event = stream.consume()) != null) { 171 handler.onNext(event); 172 } 173 } 174 } 175} 176 177class OrderedEventStream { 178 private static final Logger LOG = Logger.getLogger(OrderedEventStream.class.getName()); 179 private final BlockingQueue<RemoteEvent<byte[]>> queue; // a queue of remote events 180 private long nextSeq; // the number of the next event to consume 181 182 OrderedEventStream() { 183 queue = new PriorityBlockingQueue<RemoteEvent<byte[]>>(11, new RemoteEventComparator<byte[]>()); 184 nextSeq = 0; 185 } 186 187 synchronized void add(final RemoteEvent<byte[]> event) { 188 queue.add(event); 189 } 190 191 synchronized RemoteEvent<byte[]> consume() { 192 RemoteEvent<byte[]> event = queue.peek(); 193 if (event != null) { 194 195 if (event.getSeq() == nextSeq) { 196 event = queue.poll(); 197 ++nextSeq; 198 return event; 199 } else { 200 LOG.log(Level.FINER, "Event sequence is {0} does not match expected {1}", 201 new Object[]{event.getSeq(), nextSeq}); 202 } 203 } else { 204 LOG.log(Level.FINER, "Event is null"); 205 } 206 207 return null; 208 } 209}