001/** 002 * Licensed to the Apache Software Foundation (ASF) under one or more 003 * contributor license agreements. See the NOTICE file distributed with 004 * this work for additional information regarding copyright ownership. 005 * The ASF licenses this file to You under the Apache License, Version 2.0 006 * (the "License"); you may not use this file except in compliance with 007 * the License. You may obtain a copy of the License at 008 * 009 * http://www.apache.org/licenses/LICENSE-2.0 010 * 011 * Unless required by applicable law or agreed to in writing, software 012 * distributed under the License is distributed on an "AS IS" BASIS, 013 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 014 * See the License for the specific language governing permissions and 015 * limitations under the License. 016 */ 017 018package org.apache.activemq.transport.nio; 019 020import java.io.DataInputStream; 021import java.io.DataOutputStream; 022import java.io.EOFException; 023import java.io.IOException; 024import java.net.Socket; 025import java.net.SocketTimeoutException; 026import java.net.URI; 027import java.net.UnknownHostException; 028import java.nio.ByteBuffer; 029import java.nio.channels.SelectionKey; 030import java.nio.channels.Selector; 031import java.security.cert.X509Certificate; 032 033import javax.net.SocketFactory; 034import javax.net.ssl.SSLContext; 035import javax.net.ssl.SSLEngine; 036import javax.net.ssl.SSLEngineResult; 037import javax.net.ssl.SSLEngineResult.HandshakeStatus; 038import javax.net.ssl.SSLParameters; 039import javax.net.ssl.SSLPeerUnverifiedException; 040import javax.net.ssl.SSLSession; 041 042import org.apache.activemq.command.ConnectionInfo; 043import org.apache.activemq.openwire.OpenWireFormat; 044import org.apache.activemq.thread.TaskRunnerFactory; 045import org.apache.activemq.util.IOExceptionSupport; 046import org.apache.activemq.util.ServiceStopper; 047import org.apache.activemq.wireformat.WireFormat; 048import org.slf4j.Logger; 049import org.slf4j.LoggerFactory; 050 051public class NIOSSLTransport extends NIOTransport { 052 053 private static final Logger LOG = LoggerFactory.getLogger(NIOSSLTransport.class); 054 055 protected boolean needClientAuth; 056 protected boolean wantClientAuth; 057 protected String[] enabledCipherSuites; 058 protected String[] enabledProtocols; 059 protected boolean verifyHostName = false; 060 061 protected SSLContext sslContext; 062 protected SSLEngine sslEngine; 063 protected SSLSession sslSession; 064 065 protected volatile boolean handshakeInProgress = false; 066 protected SSLEngineResult.Status status = null; 067 protected SSLEngineResult.HandshakeStatus handshakeStatus = null; 068 protected TaskRunnerFactory taskRunnerFactory; 069 070 public NIOSSLTransport(WireFormat wireFormat, SocketFactory socketFactory, URI remoteLocation, URI localLocation) throws UnknownHostException, IOException { 071 super(wireFormat, socketFactory, remoteLocation, localLocation); 072 } 073 074 public NIOSSLTransport(WireFormat wireFormat, Socket socket) throws IOException { 075 super(wireFormat, socket); 076 } 077 078 public void setSslContext(SSLContext sslContext) { 079 this.sslContext = sslContext; 080 } 081 082 @Override 083 protected void initializeStreams() throws IOException { 084 NIOOutputStream outputStream = null; 085 try { 086 channel = socket.getChannel(); 087 channel.configureBlocking(false); 088 089 if (sslContext == null) { 090 sslContext = SSLContext.getDefault(); 091 } 092 093 String remoteHost = null; 094 int remotePort = -1; 095 096 try { 097 URI remoteAddress = new URI(this.getRemoteAddress()); 098 remoteHost = remoteAddress.getHost(); 099 remotePort = remoteAddress.getPort(); 100 } catch (Exception e) { 101 } 102 103 // initialize engine, the initial sslSession we get will need to be 104 // updated once the ssl handshake process is completed. 105 if (remoteHost != null && remotePort != -1) { 106 sslEngine = sslContext.createSSLEngine(remoteHost, remotePort); 107 } else { 108 sslEngine = sslContext.createSSLEngine(); 109 } 110 111 if (verifyHostName) { 112 SSLParameters sslParams = new SSLParameters(); 113 sslParams.setEndpointIdentificationAlgorithm("HTTPS"); 114 sslEngine.setSSLParameters(sslParams); 115 } 116 117 sslEngine.setUseClientMode(false); 118 if (enabledCipherSuites != null) { 119 sslEngine.setEnabledCipherSuites(enabledCipherSuites); 120 } 121 122 if (enabledProtocols != null) { 123 sslEngine.setEnabledProtocols(enabledProtocols); 124 } 125 126 if (wantClientAuth) { 127 sslEngine.setWantClientAuth(wantClientAuth); 128 } 129 130 if (needClientAuth) { 131 sslEngine.setNeedClientAuth(needClientAuth); 132 } 133 134 sslSession = sslEngine.getSession(); 135 136 inputBuffer = ByteBuffer.allocate(sslSession.getPacketBufferSize()); 137 inputBuffer.clear(); 138 139 outputStream = new NIOOutputStream(channel); 140 outputStream.setEngine(sslEngine); 141 this.dataOut = new DataOutputStream(outputStream); 142 this.buffOut = outputStream; 143 sslEngine.beginHandshake(); 144 handshakeStatus = sslEngine.getHandshakeStatus(); 145 doHandshake(); 146 147 if (inputBuffer.position() > 0 && inputBuffer.hasRemaining() && wireFormat instanceof OpenWireFormat) { 148 // can only deal with this pending openwire command after we have sent negotiation start command 149 // with default wireformat options from this thread. 150 // org.apache.activemq.transport.WireFormatNegotiator.sendWireFormat() 151 // read will block till wireformat is sent. 152 taskRunnerFactory.execute(new Runnable() { 153 @Override 154 public void run() { 155 serviceRead(); 156 } 157 }); 158 } 159 } catch (Exception e) { 160 try { 161 if(outputStream != null) { 162 outputStream.close(); 163 } 164 super.closeStreams(); 165 } catch (Exception ex) {} 166 throw new IOException(e); 167 } 168 } 169 170 protected void finishHandshake() throws Exception { 171 if (handshakeInProgress) { 172 handshakeInProgress = false; 173 nextFrameSize = -1; 174 175 // Once handshake completes we need to ask for the now real sslSession 176 // otherwise the session would return 'SSL_NULL_WITH_NULL_NULL' for the 177 // cipher suite. 178 sslSession = sslEngine.getSession(); 179 180 // listen for events telling us when the socket is readable. 181 selection = SelectorManager.getInstance().register(channel, new SelectorManager.Listener() { 182 @Override 183 public void onSelect(SelectorSelection selection) { 184 serviceRead(); 185 } 186 187 @Override 188 public void onError(SelectorSelection selection, Throwable error) { 189 if (error instanceof IOException) { 190 onException((IOException) error); 191 } else { 192 onException(IOExceptionSupport.create(error)); 193 } 194 } 195 }); 196 } 197 } 198 199 @Override 200 protected void serviceRead() { 201 try { 202 if (handshakeInProgress) { 203 doHandshake(); 204 } 205 206 ByteBuffer plain = ByteBuffer.allocate(sslSession.getApplicationBufferSize()); 207 plain.position(plain.limit()); 208 209 while (true) { 210 if (!plain.hasRemaining()) { 211 212 int readCount = secureRead(plain); 213 214 if (readCount == 0) { 215 break; 216 } 217 218 // channel is closed, cleanup 219 if (readCount == -1) { 220 onException(new EOFException()); 221 selection.close(); 222 break; 223 } 224 225 receiveCounter += readCount; 226 } 227 228 if (status == SSLEngineResult.Status.OK && handshakeStatus != SSLEngineResult.HandshakeStatus.NEED_UNWRAP) { 229 processCommand(plain); 230 } 231 } 232 } catch (IOException e) { 233 onException(e); 234 } catch (Throwable e) { 235 onException(IOExceptionSupport.create(e)); 236 } 237 } 238 239 protected void processCommand(ByteBuffer plain) throws Exception { 240 241 // Are we waiting for the next Command or are we building on the current one 242 if (nextFrameSize == -1) { 243 244 // We can get small packets that don't give us enough for the frame size 245 // so allocate enough for the initial size value and 246 if (plain.remaining() < Integer.SIZE) { 247 if (currentBuffer == null) { 248 currentBuffer = ByteBuffer.allocate(4); 249 } 250 251 // Go until we fill the integer sized current buffer. 252 while (currentBuffer.hasRemaining() && plain.hasRemaining()) { 253 currentBuffer.put(plain.get()); 254 } 255 256 // Didn't we get enough yet to figure out next frame size. 257 if (currentBuffer.hasRemaining()) { 258 return; 259 } else { 260 currentBuffer.flip(); 261 nextFrameSize = currentBuffer.getInt(); 262 } 263 264 } else { 265 266 // Either we are completing a previous read of the next frame size or its 267 // fully contained in plain already. 268 if (currentBuffer != null) { 269 270 // Finish the frame size integer read and get from the current buffer. 271 while (currentBuffer.hasRemaining()) { 272 currentBuffer.put(plain.get()); 273 } 274 275 currentBuffer.flip(); 276 nextFrameSize = currentBuffer.getInt(); 277 278 } else { 279 nextFrameSize = plain.getInt(); 280 } 281 } 282 283 if (wireFormat instanceof OpenWireFormat) { 284 long maxFrameSize = ((OpenWireFormat) wireFormat).getMaxFrameSize(); 285 if (nextFrameSize > maxFrameSize) { 286 throw new IOException("Frame size of " + (nextFrameSize / (1024 * 1024)) + 287 " MB larger than max allowed " + (maxFrameSize / (1024 * 1024)) + " MB"); 288 } 289 } 290 291 // now we got the data, lets reallocate and store the size for the marshaler. 292 // if there's more data in plain, then the next call will start processing it. 293 currentBuffer = ByteBuffer.allocate(nextFrameSize + 4); 294 currentBuffer.putInt(nextFrameSize); 295 296 } else { 297 298 // If its all in one read then we can just take it all, otherwise take only 299 // the current frame size and the next iteration starts a new command. 300 if (currentBuffer.remaining() >= plain.remaining()) { 301 currentBuffer.put(plain); 302 } else { 303 byte[] fill = new byte[currentBuffer.remaining()]; 304 plain.get(fill); 305 currentBuffer.put(fill); 306 } 307 308 // Either we have enough data for a new command or we have to wait for some more. 309 if (currentBuffer.hasRemaining()) { 310 return; 311 } else { 312 currentBuffer.flip(); 313 Object command = wireFormat.unmarshal(new DataInputStream(new NIOInputStream(currentBuffer))); 314 doConsume(command); 315 nextFrameSize = -1; 316 currentBuffer = null; 317 } 318 } 319 } 320 321 protected int secureRead(ByteBuffer plain) throws Exception { 322 323 if (!(inputBuffer.position() != 0 && inputBuffer.hasRemaining()) || status == SSLEngineResult.Status.BUFFER_UNDERFLOW) { 324 int bytesRead = channel.read(inputBuffer); 325 326 if (bytesRead == 0 && !(sslEngine.getHandshakeStatus().equals(SSLEngineResult.HandshakeStatus.NEED_UNWRAP))) { 327 return 0; 328 } 329 330 if (bytesRead == -1) { 331 sslEngine.closeInbound(); 332 if (inputBuffer.position() == 0 || status == SSLEngineResult.Status.BUFFER_UNDERFLOW) { 333 return -1; 334 } 335 } 336 } 337 338 plain.clear(); 339 340 inputBuffer.flip(); 341 SSLEngineResult res; 342 do { 343 res = sslEngine.unwrap(inputBuffer, plain); 344 } while (res.getStatus() == SSLEngineResult.Status.OK && res.getHandshakeStatus() == SSLEngineResult.HandshakeStatus.NEED_UNWRAP 345 && res.bytesProduced() == 0); 346 347 if (res.getHandshakeStatus() == SSLEngineResult.HandshakeStatus.FINISHED) { 348 finishHandshake(); 349 } 350 351 status = res.getStatus(); 352 handshakeStatus = res.getHandshakeStatus(); 353 354 // TODO deal with BUFFER_OVERFLOW 355 356 if (status == SSLEngineResult.Status.CLOSED) { 357 sslEngine.closeInbound(); 358 return -1; 359 } 360 361 inputBuffer.compact(); 362 plain.flip(); 363 364 return plain.remaining(); 365 } 366 367 protected void doHandshake() throws Exception { 368 handshakeInProgress = true; 369 Selector selector = null; 370 SelectionKey key = null; 371 boolean readable = true; 372 try { 373 while (true) { 374 HandshakeStatus handshakeStatus = sslEngine.getHandshakeStatus(); 375 switch (handshakeStatus) { 376 case NEED_UNWRAP: 377 if (readable) { 378 secureRead(ByteBuffer.allocate(sslSession.getApplicationBufferSize())); 379 } 380 if (this.status == SSLEngineResult.Status.BUFFER_UNDERFLOW) { 381 long now = System.currentTimeMillis(); 382 if (selector == null) { 383 selector = Selector.open(); 384 key = channel.register(selector, SelectionKey.OP_READ); 385 } else { 386 key.interestOps(SelectionKey.OP_READ); 387 } 388 int keyCount = selector.select(this.getSoTimeout()); 389 if (keyCount == 0 && this.getSoTimeout() > 0 && ((System.currentTimeMillis() - now) >= this.getSoTimeout())) { 390 throw new SocketTimeoutException("Timeout during handshake"); 391 } 392 readable = key.isReadable(); 393 } 394 break; 395 case NEED_TASK: 396 Runnable task; 397 while ((task = sslEngine.getDelegatedTask()) != null) { 398 task.run(); 399 } 400 break; 401 case NEED_WRAP: 402 ((NIOOutputStream) buffOut).write(ByteBuffer.allocate(0)); 403 break; 404 case FINISHED: 405 case NOT_HANDSHAKING: 406 finishHandshake(); 407 return; 408 } 409 } 410 } finally { 411 if (key!=null) try {key.cancel();} catch (Exception ignore) {} 412 if (selector!=null) try {selector.close();} catch (Exception ignore) {} 413 } 414 } 415 416 @Override 417 protected void doStart() throws Exception { 418 taskRunnerFactory = new TaskRunnerFactory("ActiveMQ NIOSSLTransport Task"); 419 // no need to init as we can delay that until demand (eg in doHandshake) 420 super.doStart(); 421 } 422 423 @Override 424 protected void doStop(ServiceStopper stopper) throws Exception { 425 if (taskRunnerFactory != null) { 426 taskRunnerFactory.shutdownNow(); 427 taskRunnerFactory = null; 428 } 429 if (channel != null) { 430 channel.close(); 431 channel = null; 432 } 433 super.doStop(stopper); 434 } 435 436 /** 437 * Overriding in order to add the client's certificates to ConnectionInfo Commands. 438 * 439 * @param command 440 * The Command coming in. 441 */ 442 @Override 443 public void doConsume(Object command) { 444 if (command instanceof ConnectionInfo) { 445 ConnectionInfo connectionInfo = (ConnectionInfo) command; 446 connectionInfo.setTransportContext(getPeerCertificates()); 447 } 448 super.doConsume(command); 449 } 450 451 /** 452 * @return peer certificate chain associated with the ssl socket 453 */ 454 public X509Certificate[] getPeerCertificates() { 455 456 X509Certificate[] clientCertChain = null; 457 try { 458 if (sslEngine.getSession() != null) { 459 clientCertChain = (X509Certificate[]) sslEngine.getSession().getPeerCertificates(); 460 } 461 } catch (SSLPeerUnverifiedException e) { 462 if (LOG.isTraceEnabled()) { 463 LOG.trace("Failed to get peer certificates.", e); 464 } 465 } 466 467 return clientCertChain; 468 } 469 470 public boolean isNeedClientAuth() { 471 return needClientAuth; 472 } 473 474 public void setNeedClientAuth(boolean needClientAuth) { 475 this.needClientAuth = needClientAuth; 476 } 477 478 public boolean isWantClientAuth() { 479 return wantClientAuth; 480 } 481 482 public void setWantClientAuth(boolean wantClientAuth) { 483 this.wantClientAuth = wantClientAuth; 484 } 485 486 public String[] getEnabledCipherSuites() { 487 return enabledCipherSuites; 488 } 489 490 public void setEnabledCipherSuites(String[] enabledCipherSuites) { 491 this.enabledCipherSuites = enabledCipherSuites; 492 } 493 494 public String[] getEnabledProtocols() { 495 return enabledProtocols; 496 } 497 498 public void setEnabledProtocols(String[] enabledProtocols) { 499 this.enabledProtocols = enabledProtocols; 500 } 501 502 public boolean isVerifyHostName() { 503 return verifyHostName; 504 } 505 506 public void setVerifyHostName(boolean verifyHostName) { 507 this.verifyHostName = verifyHostName; 508 } 509}