001/** 002 * Licensed to the Apache Software Foundation (ASF) under one or more 003 * contributor license agreements. See the NOTICE file distributed with 004 * this work for additional information regarding copyright ownership. 005 * The ASF licenses this file to You under the Apache License, Version 2.0 006 * (the "License"); you may not use this file except in compliance with 007 * the License. You may obtain a copy of the License at 008 * 009 * http://www.apache.org/licenses/LICENSE-2.0 010 * 011 * Unless required by applicable law or agreed to in writing, software 012 * distributed under the License is distributed on an "AS IS" BASIS, 013 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 014 * See the License for the specific language governing permissions and 015 * limitations under the License. 016 */ 017package org.apache.activemq.transport.auto; 018 019import java.io.IOException; 020import java.io.InputStream; 021import java.net.Socket; 022import java.net.URI; 023import java.net.URISyntaxException; 024import java.nio.ByteBuffer; 025import java.util.HashMap; 026import java.util.Map; 027import java.util.Set; 028import java.util.concurrent.ConcurrentHashMap; 029import java.util.concurrent.ConcurrentMap; 030import java.util.concurrent.ExecutorService; 031import java.util.concurrent.Executors; 032import java.util.concurrent.Future; 033import java.util.concurrent.LinkedBlockingQueue; 034import java.util.concurrent.ThreadPoolExecutor; 035import java.util.concurrent.TimeUnit; 036import java.util.concurrent.TimeoutException; 037import java.util.concurrent.atomic.AtomicInteger; 038 039import javax.net.ServerSocketFactory; 040 041import org.apache.activemq.broker.BrokerService; 042import org.apache.activemq.broker.BrokerServiceAware; 043import org.apache.activemq.openwire.OpenWireFormatFactory; 044import org.apache.activemq.transport.InactivityIOException; 045import org.apache.activemq.transport.Transport; 046import org.apache.activemq.transport.TransportFactory; 047import org.apache.activemq.transport.TransportServer; 048import org.apache.activemq.transport.protocol.AmqpProtocolVerifier; 049import org.apache.activemq.transport.protocol.MqttProtocolVerifier; 050import org.apache.activemq.transport.protocol.OpenWireProtocolVerifier; 051import org.apache.activemq.transport.protocol.ProtocolVerifier; 052import org.apache.activemq.transport.protocol.StompProtocolVerifier; 053import org.apache.activemq.transport.tcp.TcpTransport; 054import org.apache.activemq.transport.tcp.TcpTransport.InitBuffer; 055import org.apache.activemq.transport.tcp.TcpTransportFactory; 056import org.apache.activemq.transport.tcp.TcpTransportServer; 057import org.apache.activemq.util.FactoryFinder; 058import org.apache.activemq.util.IOExceptionSupport; 059import org.apache.activemq.util.IntrospectionSupport; 060import org.apache.activemq.util.ServiceStopper; 061import org.apache.activemq.wireformat.WireFormat; 062import org.apache.activemq.wireformat.WireFormatFactory; 063import org.slf4j.Logger; 064import org.slf4j.LoggerFactory; 065 066/** 067 * A TCP based implementation of {@link TransportServer} 068 */ 069public class AutoTcpTransportServer extends TcpTransportServer { 070 071 private static final Logger LOG = LoggerFactory.getLogger(AutoTcpTransportServer.class); 072 073 protected Map<String, Map<String, Object>> wireFormatOptions; 074 protected Map<String, Object> autoTransportOptions; 075 protected Set<String> enabledProtocols; 076 protected final Map<String, ProtocolVerifier> protocolVerifiers = new ConcurrentHashMap<String, ProtocolVerifier>(); 077 078 protected BrokerService brokerService; 079 080 protected int maxConnectionThreadPoolSize = Integer.MAX_VALUE; 081 protected int protocolDetectionTimeOut = 30000; 082 083 private static final FactoryFinder TRANSPORT_FACTORY_FINDER = new FactoryFinder("META-INF/services/org/apache/activemq/transport/"); 084 private final ConcurrentMap<String, TransportFactory> transportFactories = new ConcurrentHashMap<String, TransportFactory>(); 085 086 private static final FactoryFinder WIREFORMAT_FACTORY_FINDER = new FactoryFinder("META-INF/services/org/apache/activemq/wireformat/"); 087 088 public WireFormatFactory findWireFormatFactory(String scheme, Map<String, Map<String, Object>> options) throws IOException { 089 WireFormatFactory wff = null; 090 try { 091 wff = (WireFormatFactory)WIREFORMAT_FACTORY_FINDER.newInstance(scheme); 092 if (options != null) { 093 final Map<String, Object> wfOptions = new HashMap<>(); 094 if (options.get(AutoTransportUtils.ALL) != null) { 095 wfOptions.putAll(options.get(AutoTransportUtils.ALL)); 096 } 097 if (options.get(scheme) != null) { 098 wfOptions.putAll(options.get(scheme)); 099 } 100 IntrospectionSupport.setProperties(wff, wfOptions); 101 } 102 if (wff instanceof OpenWireFormatFactory) { 103 protocolVerifiers.put(AutoTransportUtils.OPENWIRE, new OpenWireProtocolVerifier((OpenWireFormatFactory) wff)); 104 } 105 return wff; 106 } catch (Throwable e) { 107 throw IOExceptionSupport.create("Could not create wire format factory for: " + scheme + ", reason: " + e, e); 108 } 109 } 110 111 public TransportFactory findTransportFactory(String scheme, Map<String, ?> options) throws IOException { 112 scheme = append(scheme, "nio"); 113 scheme = append(scheme, "ssl"); 114 115 if (scheme.isEmpty()) { 116 scheme = "tcp"; 117 } 118 119 TransportFactory tf = transportFactories.get(scheme); 120 if (tf == null) { 121 // Try to load if from a META-INF property. 122 try { 123 tf = (TransportFactory)TRANSPORT_FACTORY_FINDER.newInstance(scheme); 124 if (options != null) { 125 IntrospectionSupport.setProperties(tf, options); 126 } 127 transportFactories.put(scheme, tf); 128 } catch (Throwable e) { 129 throw IOExceptionSupport.create("Transport scheme NOT recognized: [" + scheme + "]", e); 130 } 131 } 132 return tf; 133 } 134 135 protected String append(String currentScheme, String scheme) { 136 if (this.getBindLocation().getScheme().contains(scheme)) { 137 if (!currentScheme.isEmpty()) { 138 currentScheme += "+"; 139 } 140 currentScheme += scheme; 141 } 142 return currentScheme; 143 } 144 145 /** 146 * @param transportFactory 147 * @param location 148 * @param serverSocketFactory 149 * @throws IOException 150 * @throws URISyntaxException 151 */ 152 public AutoTcpTransportServer(TcpTransportFactory transportFactory, 153 URI location, ServerSocketFactory serverSocketFactory, BrokerService brokerService, 154 Set<String> enabledProtocols) 155 throws IOException, URISyntaxException { 156 super(transportFactory, location, serverSocketFactory); 157 158 //Use an executor service here to handle new connections. Setting the max number 159 //of threads to the maximum number of connections the thread count isn't unbounded 160 service = new ThreadPoolExecutor(maxConnectionThreadPoolSize, 161 maxConnectionThreadPoolSize, 162 30L, TimeUnit.SECONDS, 163 new LinkedBlockingQueue<Runnable>()); 164 //allow the thread pool to shrink if the max number of threads isn't needed 165 service.allowCoreThreadTimeOut(true); 166 167 this.brokerService = brokerService; 168 this.enabledProtocols = enabledProtocols; 169 initProtocolVerifiers(); 170 } 171 172 public int getMaxConnectionThreadPoolSize() { 173 return maxConnectionThreadPoolSize; 174 } 175 176 public void setMaxConnectionThreadPoolSize(int maxConnectionThreadPoolSize) { 177 this.maxConnectionThreadPoolSize = maxConnectionThreadPoolSize; 178 service.setCorePoolSize(maxConnectionThreadPoolSize); 179 service.setMaximumPoolSize(maxConnectionThreadPoolSize); 180 } 181 182 public void setProtocolDetectionTimeOut(int protocolDetectionTimeOut) { 183 this.protocolDetectionTimeOut = protocolDetectionTimeOut; 184 } 185 186 @Override 187 public void setWireFormatFactory(WireFormatFactory factory) { 188 super.setWireFormatFactory(factory); 189 initOpenWireProtocolVerifier(); 190 } 191 192 protected void initProtocolVerifiers() { 193 initOpenWireProtocolVerifier(); 194 195 if (isAllProtocols() || enabledProtocols.contains(AutoTransportUtils.AMQP)) { 196 protocolVerifiers.put(AutoTransportUtils.AMQP, new AmqpProtocolVerifier()); 197 } 198 if (isAllProtocols() || enabledProtocols.contains(AutoTransportUtils.STOMP)) { 199 protocolVerifiers.put(AutoTransportUtils.STOMP, new StompProtocolVerifier()); 200 } 201 if (isAllProtocols()|| enabledProtocols.contains(AutoTransportUtils.MQTT)) { 202 protocolVerifiers.put(AutoTransportUtils.MQTT, new MqttProtocolVerifier()); 203 } 204 } 205 206 protected void initOpenWireProtocolVerifier() { 207 if (isAllProtocols() || enabledProtocols.contains(AutoTransportUtils.OPENWIRE)) { 208 OpenWireProtocolVerifier owpv; 209 if (wireFormatFactory instanceof OpenWireFormatFactory) { 210 owpv = new OpenWireProtocolVerifier((OpenWireFormatFactory) wireFormatFactory); 211 } else { 212 owpv = new OpenWireProtocolVerifier(new OpenWireFormatFactory()); 213 } 214 protocolVerifiers.put(AutoTransportUtils.OPENWIRE, owpv); 215 } 216 } 217 218 protected boolean isAllProtocols() { 219 return enabledProtocols == null || enabledProtocols.isEmpty(); 220 } 221 222 223 protected final ThreadPoolExecutor service; 224 225 @Override 226 protected void handleSocket(final Socket socket) { 227 final AutoTcpTransportServer server = this; 228 //This needs to be done in a new thread because 229 //the socket might be waiting on the client to send bytes 230 //doHandleSocket can't complete until the protocol can be detected 231 service.submit(new Runnable() { 232 @Override 233 public void run() { 234 server.doHandleSocket(socket); 235 } 236 }); 237 } 238 239 @Override 240 protected TransportInfo configureTransport(final TcpTransportServer server, final Socket socket) throws Exception { 241 final InputStream is = socket.getInputStream(); 242 ExecutorService executor = Executors.newSingleThreadExecutor(); 243 244 final AtomicInteger readBytes = new AtomicInteger(0); 245 final ByteBuffer data = ByteBuffer.allocate(8); 246 // We need to peak at the first 8 bytes of the buffer to detect the protocol 247 Future<?> future = executor.submit(new Runnable() { 248 @Override 249 public void run() { 250 try { 251 do { 252 int read = is.read(); 253 if (read == -1) { 254 throw new IOException("Connection failed, stream is closed."); 255 } 256 data.put((byte) read); 257 readBytes.incrementAndGet(); 258 } while (readBytes.get() < 8); 259 } catch (Exception e) { 260 throw new IllegalStateException(e); 261 } 262 } 263 }); 264 265 waitForProtocolDetectionFinish(future, readBytes); 266 data.flip(); 267 ProtocolInfo protocolInfo = detectProtocol(data.array()); 268 269 InitBuffer initBuffer = new InitBuffer(readBytes.get(), ByteBuffer.allocate(readBytes.get())); 270 initBuffer.buffer.put(data.array()); 271 272 if (protocolInfo.detectedTransportFactory instanceof BrokerServiceAware) { 273 ((BrokerServiceAware) protocolInfo.detectedTransportFactory).setBrokerService(brokerService); 274 } 275 276 WireFormat format = protocolInfo.detectedWireFormatFactory.createWireFormat(); 277 Transport transport = createTransport(socket, format, protocolInfo.detectedTransportFactory, initBuffer); 278 279 return new TransportInfo(format, transport, protocolInfo.detectedTransportFactory); 280 } 281 282 protected void waitForProtocolDetectionFinish(final Future<?> future, final AtomicInteger readBytes) throws Exception { 283 try { 284 //Wait for protocolDetectionTimeOut if defined 285 if (protocolDetectionTimeOut > 0) { 286 future.get(protocolDetectionTimeOut, TimeUnit.MILLISECONDS); 287 } else { 288 future.get(); 289 } 290 } catch (TimeoutException e) { 291 throw new InactivityIOException("Client timed out before wire format could be detected. " + 292 " 8 bytes are required to detect the protocol but only: " + readBytes.get() + " byte(s) were sent."); 293 } 294 } 295 296 /** 297 * @param socket 298 * @param format 299 * @param detectedTransportFactory 300 * @return 301 */ 302 protected TcpTransport createTransport(Socket socket, WireFormat format, 303 TcpTransportFactory detectedTransportFactory, InitBuffer initBuffer) throws IOException { 304 return new TcpTransport(format, socket, initBuffer); 305 } 306 307 public void setWireFormatOptions(Map<String, Map<String, Object>> wireFormatOptions) { 308 this.wireFormatOptions = wireFormatOptions; 309 } 310 311 public void setEnabledProtocols(Set<String> enabledProtocols) { 312 this.enabledProtocols = enabledProtocols; 313 } 314 315 public void setAutoTransportOptions(Map<String, Object> autoTransportOptions) { 316 this.autoTransportOptions = autoTransportOptions; 317 if (autoTransportOptions.get("protocols") != null) { 318 this.enabledProtocols = AutoTransportUtils.parseProtocols((String) autoTransportOptions.get("protocols")); 319 } 320 } 321 @Override 322 protected void doStop(ServiceStopper stopper) throws Exception { 323 if (service != null) { 324 service.shutdown(); 325 } 326 super.doStop(stopper); 327 } 328 329 protected ProtocolInfo detectProtocol(byte[] buffer) throws IOException { 330 TcpTransportFactory detectedTransportFactory = transportFactory; 331 WireFormatFactory detectedWireFormatFactory = wireFormatFactory; 332 333 boolean found = false; 334 for (String scheme : protocolVerifiers.keySet()) { 335 if (protocolVerifiers.get(scheme).isProtocol(buffer)) { 336 LOG.debug("Detected protocol " + scheme); 337 detectedWireFormatFactory = findWireFormatFactory(scheme, wireFormatOptions); 338 339 if (scheme.equals("default")) { 340 scheme = ""; 341 } 342 343 detectedTransportFactory = (TcpTransportFactory) findTransportFactory(scheme, transportOptions); 344 found = true; 345 break; 346 } 347 } 348 349 if (!found) { 350 throw new IllegalStateException("Could not detect the wire format"); 351 } 352 353 return new ProtocolInfo(detectedTransportFactory, detectedWireFormatFactory); 354 355 } 356 357 protected class ProtocolInfo { 358 public final TcpTransportFactory detectedTransportFactory; 359 public final WireFormatFactory detectedWireFormatFactory; 360 361 public ProtocolInfo(TcpTransportFactory detectedTransportFactory, 362 WireFormatFactory detectedWireFormatFactory) { 363 super(); 364 this.detectedTransportFactory = detectedTransportFactory; 365 this.detectedWireFormatFactory = detectedWireFormatFactory; 366 } 367 } 368 369}