001/** 002 * Licensed to the Apache Software Foundation (ASF) under one 003 * or more contributor license agreements. See the NOTICE file 004 * distributed with this work for additional information 005 * regarding copyright ownership. The ASF licenses this file 006 * to you under the Apache License, Version 2.0 (the 007 * "License"); you may not use this file except in compliance 008 * with the License. You may obtain a copy of the License at 009 * 010 * http://www.apache.org/licenses/LICENSE-2.0 011 * 012 * Unless required by applicable law or agreed to in writing, 013 * software distributed under the License is distributed on an 014 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 015 * KIND, either express or implied. See the License for the 016 * specific language governing permissions and limitations 017 * under the License. 018 */ 019package org.apache.reef.wake.rx.impl; 020 021import org.apache.reef.tang.annotations.Parameter; 022import org.apache.reef.wake.StageConfiguration.NumberOfThreads; 023import org.apache.reef.wake.StageConfiguration.StageName; 024import org.apache.reef.wake.StageConfiguration.StageObserver; 025import org.apache.reef.wake.WakeParameters; 026import org.apache.reef.wake.exception.WakeRuntimeException; 027import org.apache.reef.wake.impl.DefaultThreadFactory; 028import org.apache.reef.wake.impl.StageManager; 029import org.apache.reef.wake.rx.AbstractRxStage; 030import org.apache.reef.wake.rx.Observer; 031 032import javax.inject.Inject; 033import java.util.List; 034import java.util.concurrent.*; 035import java.util.logging.Level; 036import java.util.logging.Logger; 037 038/** 039 * Stage that executes the observer with a thread pool. 040 * <p/> 041 * {@code onNext}'s will be arbitrarily subject to reordering, as with most stages. 042 * <p/> 043 * All {@code onNext}'s for which returning from the method call 044 * happens-before the call to {@code onComplete} will maintain 045 * this relationship when passed to the observer. 046 * <p/> 047 * Any {@code onNext} whose return is not ordered before 048 * {@code onComplete} may or may not get dropped. 049 * 050 * @param <T> type of event 051 */ 052public final class RxThreadPoolStage<T> extends AbstractRxStage<T> { 053 private static final Logger LOG = Logger.getLogger(RxThreadPoolStage.class.getName()); 054 055 private final Observer<T> observer; 056 private final ExecutorService executor; 057 private final long shutdownTimeout = WakeParameters.EXECUTOR_SHUTDOWN_TIMEOUT; 058 private ExecutorService completionExecutor; 059 private DefaultThreadFactory tf; 060 061 /** 062 * Constructs a Rx thread pool stage 063 * 064 * @param observer the observer to execute 065 * @param numThreads the number of threads 066 */ 067 @Inject 068 public RxThreadPoolStage(@Parameter(StageObserver.class) final Observer<T> observer, 069 @Parameter(NumberOfThreads.class) final int numThreads) { 070 this(observer.getClass().getName(), observer, numThreads); 071 } 072 073 /** 074 * Constructs a Rx thread pool stage 075 * 076 * @param name the stage name 077 * @param observer the observer to execute 078 * @param numThreads the number of threads 079 */ 080 @Inject 081 public RxThreadPoolStage(@Parameter(StageName.class) final String name, 082 @Parameter(StageObserver.class) final Observer<T> observer, 083 @Parameter(NumberOfThreads.class) final int numThreads) { 084 super(name); 085 this.observer = observer; 086 if (numThreads <= 0) 087 throw new WakeRuntimeException(name + " numThreads " + numThreads + " is less than or equal to 0"); 088 tf = new DefaultThreadFactory(name); 089 this.executor = Executors.newFixedThreadPool(numThreads, tf); 090 this.completionExecutor = Executors.newSingleThreadExecutor(tf); 091 StageManager.instance().register(this); 092 } 093 094 /** 095 * Provides the observer with the new value 096 * 097 * @param value the new value 098 */ 099 @Override 100 public void onNext(final T value) { 101 beforeOnNext(); 102 executor.submit(new Runnable() { 103 104 @Override 105 public void run() { 106 observer.onNext(value); 107 afterOnNext(); 108 } 109 }); 110 } 111 112 /** 113 * Notifies the observer that the provider has experienced an error 114 * condition. 115 * 116 * @param error the error 117 */ 118 @Override 119 public void onError(final Exception error) { 120 submitCompletion(new Runnable() { 121 122 @Override 123 public void run() { 124 observer.onError(error); 125 } 126 127 }); 128 } 129 130 /** 131 * Notifies the observer that the provider has finished sending push-based 132 * notifications. 133 */ 134 @Override 135 public void onCompleted() { 136 submitCompletion(new Runnable() { 137 138 @Override 139 public void run() { 140 observer.onCompleted(); 141 } 142 143 }); 144 } 145 146 private void submitCompletion(final Runnable r) { 147 executor.shutdown(); 148 completionExecutor.submit(new Runnable() { 149 150 @Override 151 public void run() { 152 try { 153 // no timeout for completion, only close() 154 if (!executor.awaitTermination(3153600000L, TimeUnit.SECONDS)) { 155 LOG.log(Level.SEVERE, "Executor terminated due to unrequired timeout"); 156 observer.onError(new TimeoutException()); 157 } 158 } catch (InterruptedException e) { 159 e.printStackTrace(); 160 observer.onError(e); 161 } 162 r.run(); 163 } 164 }); 165 } 166 167 /** 168 * Closes the stage 169 * 170 * @return Exception 171 */ 172 @Override 173 public void close() throws Exception { 174 if (closed.compareAndSet(false, true)) { 175 executor.shutdown(); 176 completionExecutor.shutdown(); 177 if (!executor.awaitTermination(shutdownTimeout, TimeUnit.MILLISECONDS)) { 178 LOG.log(Level.WARNING, "Executor did not terminate in " + shutdownTimeout + "ms."); 179 List<Runnable> droppedRunnables = executor.shutdownNow(); 180 LOG.log(Level.WARNING, "Executor dropped " + droppedRunnables.size() + " tasks."); 181 } 182 if (!completionExecutor.awaitTermination(shutdownTimeout, TimeUnit.MILLISECONDS)) { 183 LOG.log(Level.WARNING, "Executor did not terminate in " + shutdownTimeout + "ms."); 184 List<Runnable> droppedRunnables = completionExecutor.shutdownNow(); 185 LOG.log(Level.WARNING, "Completion executor dropped " + droppedRunnables.size() + " tasks."); 186 } 187 } 188 } 189 190 /** 191 * Gets the queue length of this stage 192 * 193 * @return the queue length 194 */ 195 public int getQueueLength() { 196 return ((ThreadPoolExecutor) executor).getQueue().size(); 197 } 198}