/**
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.activemq.transport.nio;

import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.EOFException;
import java.io.IOException;
import java.net.Socket;
import java.net.SocketTimeoutException;
import java.net.URI;
import java.net.UnknownHostException;
import java.nio.ByteBuffer;
import java.nio.channels.SelectionKey;
import java.nio.channels.Selector;
import java.security.cert.X509Certificate;
import java.util.concurrent.CountDownLatch;

import javax.net.SocketFactory;
import javax.net.ssl.SSLContext;
import javax.net.ssl.SSLEngine;
import javax.net.ssl.SSLEngineResult;
import javax.net.ssl.SSLEngineResult.HandshakeStatus;
import javax.net.ssl.SSLParameters;
import javax.net.ssl.SSLPeerUnverifiedException;
import javax.net.ssl.SSLSession;

import org.apache.activemq.command.ConnectionInfo;
import org.apache.activemq.openwire.OpenWireFormat;
import org.apache.activemq.thread.TaskRunnerFactory;
import org.apache.activemq.util.IOExceptionSupport;
import org.apache.activemq.util.ServiceStopper;
import org.apache.activemq.wireformat.WireFormat;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * NIO transport that encrypts traffic with an {@link SSLEngine} layered over a
 * non-blocking channel.
 * <p>
 * The engine is always put into server mode (see {@link #initializeStreams()}).
 * Inbound bytes are unwrapped by {@link #secureRead(ByteBuffer)} and then cut
 * into length-prefixed frames by {@link #processCommand(ByteBuffer)}; each
 * complete frame is unmarshalled with the configured wire format and handed to
 * {@link #doConsume(Object)}.
 */
public class NIOSSLTransport extends NIOTransport {

    private static final Logger LOG = LoggerFactory.getLogger(NIOSSLTransport.class);

    /**
     * Size in bytes of the integer length prefix that precedes each frame.
     * Note that {@link Integer#SIZE} is a bit count (32) and is therefore not
     * the right value to compare against {@link ByteBuffer#remaining()}.
     */
    private static final int FRAME_SIZE_PREFIX_LENGTH = 4;

    protected boolean needClientAuth;
    protected boolean wantClientAuth;
    protected String[] enabledCipherSuites;
    protected String[] enabledProtocols;
    protected boolean verifyHostName = false;

    protected SSLContext sslContext;
    protected SSLEngine sslEngine;
    protected SSLSession sslSession;

    protected volatile boolean handshakeInProgress = false;
    protected SSLEngineResult.Status status = null;
    protected SSLEngineResult.HandshakeStatus handshakeStatus = null;
    protected TaskRunnerFactory taskRunnerFactory;

    /**
     * Released once the initial read scheduled by {@link #doInit()} has run
     * (or when the transport is stopped); selector callbacks registered in
     * {@link #initializeStreams()} wait on it before reading.
     */
    protected final CountDownLatch initialized = new CountDownLatch(1);

    public NIOSSLTransport(WireFormat wireFormat, SocketFactory socketFactory, URI remoteLocation, URI localLocation) throws UnknownHostException, IOException {
        super(wireFormat, socketFactory, remoteLocation, localLocation);
    }

    public NIOSSLTransport(WireFormat wireFormat, Socket socket) throws IOException {
        super(wireFormat, socket);
    }

    /**
     * @param sslContext
     *        the context used to create the {@link SSLEngine}; when left
     *        {@code null} the JVM default context is used.
     */
    public void setSslContext(SSLContext sslContext) {
        this.sslContext = sslContext;
    }

    /**
     * Switches the channel to non-blocking mode, creates and configures the
     * SSL engine, performs the handshake, registers the channel for read
     * selection and finally schedules the first read via {@link #doInit()}.
     *
     * @throws IOException
     *         wrapping any failure; the output stream and the parent's streams
     *         are closed on a best-effort basis before it is thrown.
     */
    @Override
    protected void initializeStreams() throws IOException {
        NIOOutputStream outputStream = null;
        try {
            channel = socket.getChannel();
            channel.configureBlocking(false);

            if (sslContext == null) {
                sslContext = SSLContext.getDefault();
            }

            String remoteHost = null;
            int remotePort = -1;

            try {
                URI remoteAddress = new URI(this.getRemoteAddress());
                remoteHost = remoteAddress.getHost();
                remotePort = remoteAddress.getPort();
            } catch (Exception e) {
                // Best effort only: without a usable address the engine is
                // simply created without peer host/port hints below.
                if (LOG.isTraceEnabled()) {
                    LOG.trace("Could not determine remote host and port for the SSL engine.", e);
                }
            }

            // initialize engine, the initial sslSession we get will need to be
            // updated once the ssl handshake process is completed.
            if (remoteHost != null && remotePort != -1) {
                sslEngine = sslContext.createSSLEngine(remoteHost, remotePort);
            } else {
                sslEngine = sslContext.createSSLEngine();
            }

            if (verifyHostName) {
                SSLParameters sslParams = new SSLParameters();
                sslParams.setEndpointIdentificationAlgorithm("HTTPS");
                sslEngine.setSSLParameters(sslParams);
            }

            sslEngine.setUseClientMode(false);
            if (enabledCipherSuites != null) {
                sslEngine.setEnabledCipherSuites(enabledCipherSuites);
            }

            if (enabledProtocols != null) {
                sslEngine.setEnabledProtocols(enabledProtocols);
            }

            if (wantClientAuth) {
                sslEngine.setWantClientAuth(wantClientAuth);
            }

            if (needClientAuth) {
                sslEngine.setNeedClientAuth(needClientAuth);
            }

            sslSession = sslEngine.getSession();

            inputBuffer = ByteBuffer.allocate(sslSession.getPacketBufferSize());
            inputBuffer.clear();

            outputStream = new NIOOutputStream(channel);
            outputStream.setEngine(sslEngine);
            this.dataOut = new DataOutputStream(outputStream);
            this.buffOut = outputStream;
            sslEngine.beginHandshake();
            handshakeStatus = sslEngine.getHandshakeStatus();
            doHandshake();

            selection = SelectorManager.getInstance().register(channel, new SelectorManager.Listener() {
                @Override
                public void onSelect(SelectorSelection selection) {
                    try {
                        // Do not read until the initial read triggered by doInit() is done.
                        initialized.await();
                    } catch (InterruptedException error) {
                        onException(IOExceptionSupport.create(error));
                    }
                    serviceRead();
                }

                @Override
                public void onError(SelectorSelection selection, Throwable error) {
                    reportSelectorError(error);
                }
            });

            doInit();

        } catch (Exception e) {
            try {
                if (outputStream != null) {
                    outputStream.close();
                }
                super.closeStreams();
            } catch (Exception ex) {
                // Cleanup is best effort; the original failure is the one reported.
                if (LOG.isTraceEnabled()) {
                    LOG.trace("Error while cleaning up after failed stream initialization.", ex);
                }
            }
            throw new IOException(e);
        }
    }

    /**
     * Schedules the first {@link #serviceRead()} on the task runner and then
     * releases {@link #initialized}.
     *
     * @throws Exception
     *         if the task cannot be scheduled.
     */
    protected void doInit() throws Exception {
        taskRunnerFactory.execute(new Runnable() {

            @Override
            public void run() {
                //Need to start in new thread to let startup finish first
                //We can trigger a read because we know the channel is ready since the SSL handshake
                //already happened
                serviceRead();
                initialized.countDown();
            }
        });
    }

    /**
     * Marks the handshake as complete. Does nothing unless a handshake is
     * currently flagged as in progress; otherwise it resets the frame state,
     * refreshes {@link #sslSession} and registers the channel for read
     * selection.
     *
     * @throws Exception
     *         if registering with the selector manager fails.
     */
    protected void finishHandshake() throws Exception {
        if (handshakeInProgress) {
            handshakeInProgress = false;
            nextFrameSize = -1;

            // Once handshake completes we need to ask for the now real sslSession
            // otherwise the session would return 'SSL_NULL_WITH_NULL_NULL' for the
            // cipher suite.
            sslSession = sslEngine.getSession();

            // listen for events telling us when the socket is readable.
            selection = SelectorManager.getInstance().register(channel, new SelectorManager.Listener() {
                @Override
                public void onSelect(SelectorSelection selection) {
                    serviceRead();
                }

                @Override
                public void onError(SelectorSelection selection, Throwable error) {
                    reportSelectorError(error);
                }
            });
        }
    }

    /**
     * Forwards a selector error to {@link #onException(IOException)}, wrapping
     * it first when it is not already an {@link IOException}.
     */
    private void reportSelectorError(Throwable error) {
        if (error instanceof IOException) {
            onException((IOException) error);
        } else {
            onException(IOExceptionSupport.create(error));
        }
    }

    /**
     * Reads and decrypts whatever is available on the channel and feeds the
     * plain bytes to {@link #processCommand(ByteBuffer)} until no more data
     * can be read. Any failure is reported through
     * {@link #onException(IOException)} rather than thrown.
     */
    @Override
    protected void serviceRead() {
        try {
            if (handshakeInProgress) {
                doHandshake();
            }

            ByteBuffer plain = ByteBuffer.allocate(sslSession.getApplicationBufferSize());
            // Start out "empty" so the first loop iteration performs a read.
            plain.position(plain.limit());

            while (true) {
                if (!plain.hasRemaining()) {

                    int readCount = secureRead(plain);

                    if (readCount == 0) {
                        break;
                    }

                    // channel is closed, cleanup
                    if (readCount == -1) {
                        onException(new EOFException());
                        selection.close();
                        break;
                    }

                    receiveCounter += readCount;
                }

                if (status == SSLEngineResult.Status.OK && handshakeStatus != SSLEngineResult.HandshakeStatus.NEED_UNWRAP) {
                    processCommand(plain);
                }
            }
        } catch (IOException e) {
            onException(e);
        } catch (Throwable e) {
            onException(IOExceptionSupport.create(e));
        }
    }

    /**
     * Consumes decrypted bytes from {@code plain}, assembling one
     * length-prefixed frame at a time. A call either advances the length
     * prefix, or copies payload bytes into the current frame; once a frame is
     * complete it is unmarshalled and passed to {@link #doConsume(Object)}.
     *
     * @param plain
     *        decrypted bytes, positioned at the next unread byte.
     * @throws IOException
     *         if the announced frame size is negative, too large to allocate,
     *         or exceeds the OpenWire maximum frame size.
     * @throws Exception
     *         if unmarshalling or consuming the command fails.
     */
    protected void processCommand(ByteBuffer plain) throws Exception {

        // Are we waiting for the next Command or are we building on the current one
        if (nextFrameSize == -1) {

            // We can get small packets that don't give us enough for the frame size
            // so accumulate the bytes of the length prefix across calls.
            if (plain.remaining() < FRAME_SIZE_PREFIX_LENGTH) {
                if (currentBuffer == null) {
                    currentBuffer = ByteBuffer.allocate(FRAME_SIZE_PREFIX_LENGTH);
                }

                // Go until we fill the integer sized current buffer.
                while (currentBuffer.hasRemaining() && plain.hasRemaining()) {
                    currentBuffer.put(plain.get());
                }

                // Did we get enough yet to figure out next frame size?
                if (currentBuffer.hasRemaining()) {
                    return;
                } else {
                    currentBuffer.flip();
                    nextFrameSize = currentBuffer.getInt();
                }

            } else {

                // Either we are completing a previous read of the next frame size or its
                // fully contained in plain already.
                if (currentBuffer != null) {

                    // Finish the frame size integer read and get from the current buffer.
                    while (currentBuffer.hasRemaining()) {
                        currentBuffer.put(plain.get());
                    }

                    currentBuffer.flip();
                    nextFrameSize = currentBuffer.getInt();

                } else {
                    nextFrameSize = plain.getInt();
                }
            }

            if (wireFormat instanceof OpenWireFormat) {
                long maxFrameSize = ((OpenWireFormat) wireFormat).getMaxFrameSize();
                if (nextFrameSize > maxFrameSize) {
                    throw new IOException("Frame size of " + (nextFrameSize / (1024 * 1024)) +
                                          " MB larger than max allowed " + (maxFrameSize / (1024 * 1024)) + " MB");
                }
            }

            // The size comes straight off the wire: reject values that cannot be
            // allocated (negative, or overflowing once the prefix is added) with a
            // clear error instead of failing inside ByteBuffer.
            if (nextFrameSize < 0 || nextFrameSize > Integer.MAX_VALUE - FRAME_SIZE_PREFIX_LENGTH) {
                throw new IOException("Invalid frame size of " + nextFrameSize + " bytes received");
            }

            // now we got the data, lets reallocate and store the size for the marshaler.
            // if there's more data in plain, then the next call will start processing it.
            currentBuffer = ByteBuffer.allocate(nextFrameSize + FRAME_SIZE_PREFIX_LENGTH);
            currentBuffer.putInt(nextFrameSize);

        } else {

            // If its all in one read then we can just take it all, otherwise take only
            // the current frame size and the next iteration starts a new command.
            if (currentBuffer.remaining() >= plain.remaining()) {
                currentBuffer.put(plain);
            } else {
                byte[] fill = new byte[currentBuffer.remaining()];
                plain.get(fill);
                currentBuffer.put(fill);
            }

            // Either we have enough data for a new command or we have to wait for some more.
            if (currentBuffer.hasRemaining()) {
                return;
            } else {
                currentBuffer.flip();
                Object command = wireFormat.unmarshal(new DataInputStream(new NIOInputStream(currentBuffer)));
                doConsume(command);
                nextFrameSize = -1;
                currentBuffer = null;
            }
        }
    }

    /**
     * Reads encrypted bytes from the channel when needed and unwraps them into
     * {@code plain}. Updates {@link #status} and {@link #handshakeStatus} from
     * the engine result.
     *
     * @param plain
     *        destination for decrypted bytes; it is cleared before unwrapping
     *        and flipped for reading before returning.
     * @return the number of decrypted bytes available in {@code plain},
     *         {@code 0} when nothing could be read, or {@code -1} when the
     *         channel or the SSL engine has been closed.
     * @throws Exception
     *         if reading from the channel or unwrapping fails.
     */
    protected int secureRead(ByteBuffer plain) throws Exception {

        if (!(inputBuffer.position() != 0 && inputBuffer.hasRemaining()) || status == SSLEngineResult.Status.BUFFER_UNDERFLOW) {
            int bytesRead = channel.read(inputBuffer);

            if (bytesRead == 0 && !(sslEngine.getHandshakeStatus().equals(SSLEngineResult.HandshakeStatus.NEED_UNWRAP))) {
                return 0;
            }

            if (bytesRead == -1) {
                sslEngine.closeInbound();
                if (inputBuffer.position() == 0 || status == SSLEngineResult.Status.BUFFER_UNDERFLOW) {
                    return -1;
                }
            }
        }

        plain.clear();

        inputBuffer.flip();
        SSLEngineResult res;
        do {
            res = sslEngine.unwrap(inputBuffer, plain);
        } while (res.getStatus() == SSLEngineResult.Status.OK && res.getHandshakeStatus() == SSLEngineResult.HandshakeStatus.NEED_UNWRAP
                && res.bytesProduced() == 0);

        if (res.getHandshakeStatus() == SSLEngineResult.HandshakeStatus.FINISHED) {
            finishHandshake();
        }

        status = res.getStatus();
        handshakeStatus = res.getHandshakeStatus();

        // TODO deal with BUFFER_OVERFLOW

        if (status == SSLEngineResult.Status.CLOSED) {
            sslEngine.closeInbound();
            return -1;
        }

        inputBuffer.compact();
        plain.flip();

        return plain.remaining();
    }

    /**
     * Drives the SSL handshake until the engine reports {@code FINISHED} or
     * {@code NOT_HANDSHAKING}. While the engine needs more inbound data a
     * temporary selector is used to wait for the channel to become readable,
     * bounded by {@code getSoTimeout()}.
     *
     * @throws SocketTimeoutException
     *         if a positive socket timeout elapses while waiting for data.
     * @throws Exception
     *         if any handshake step fails.
     */
    protected void doHandshake() throws Exception {
        handshakeInProgress = true;
        Selector selector = null;
        SelectionKey key = null;
        boolean readable = true;
        try {
            while (true) {
                // Named differently from the handshakeStatus field to avoid shadowing it.
                HandshakeStatus engineStatus = sslEngine.getHandshakeStatus();
                switch (engineStatus) {
                    case NEED_UNWRAP:
                        if (readable) {
                            secureRead(ByteBuffer.allocate(sslSession.getApplicationBufferSize()));
                        }
                        if (this.status == SSLEngineResult.Status.BUFFER_UNDERFLOW) {
                            long now = System.currentTimeMillis();
                            if (selector == null) {
                                selector = Selector.open();
                                key = channel.register(selector, SelectionKey.OP_READ);
                            } else {
                                key.interestOps(SelectionKey.OP_READ);
                            }
                            int keyCount = selector.select(this.getSoTimeout());
                            if (keyCount == 0 && this.getSoTimeout() > 0 && ((System.currentTimeMillis() - now) >= this.getSoTimeout())) {
                                throw new SocketTimeoutException("Timeout during handshake");
                            }
                            readable = key.isReadable();
                        }
                        break;
                    case NEED_TASK:
                        Runnable task;
                        while ((task = sslEngine.getDelegatedTask()) != null) {
                            task.run();
                        }
                        break;
                    case NEED_WRAP:
                        // An empty write lets the output stream emit pending handshake data.
                        ((NIOOutputStream) buffOut).write(ByteBuffer.allocate(0));
                        break;
                    case FINISHED:
                    case NOT_HANDSHAKING:
                        finishHandshake();
                        return;
                }
            }
        } finally {
            if (key != null) {
                try {
                    key.cancel();
                } catch (Exception ignore) {
                }
            }
            if (selector != null) {
                try {
                    selector.close();
                } catch (Exception ignore) {
                }
            }
        }
    }

    /**
     * Creates the task runner used by {@link #doInit()} and then starts the
     * parent transport.
     */
    @Override
    protected void doStart() throws Exception {
        taskRunnerFactory = new TaskRunnerFactory("ActiveMQ NIOSSLTransport Task");
        // no need to init as we can delay that until demand (eg in doHandshake)
        super.doStart();
    }

    /**
     * Releases any selector callback waiting on {@link #initialized}, shuts
     * down the task runner, closes the channel and stops the parent transport.
     */
    @Override
    protected void doStop(ServiceStopper stopper) throws Exception {
        initialized.countDown();

        if (taskRunnerFactory != null) {
            taskRunnerFactory.shutdownNow();
            taskRunnerFactory = null;
        }
        if (channel != null) {
            channel.close();
            channel = null;
        }
        super.doStop(stopper);
    }

    /**
     * Overriding in order to add the client's certificates to ConnectionInfo Commands.
     *
     * @param command
     *        The Command coming in.
     */
    @Override
    public void doConsume(Object command) {
        if (command instanceof ConnectionInfo) {
            ConnectionInfo connectionInfo = (ConnectionInfo) command;
            connectionInfo.setTransportContext(getPeerCertificates());
        }
        super.doConsume(command);
    }

    /**
     * @return peer certificate chain associated with the ssl socket, or
     *         {@code null} when no SSL engine or session exists yet or the
     *         peer has not been verified.
     */
    public X509Certificate[] getPeerCertificates() {

        X509Certificate[] clientCertChain = null;
        SSLEngine engine = sslEngine;
        if (engine == null) {
            // Streams have not been initialized yet, so there is no peer to report.
            return clientCertChain;
        }

        try {
            SSLSession session = engine.getSession();
            if (session != null) {
                clientCertChain = (X509Certificate[]) session.getPeerCertificates();
            }
        } catch (SSLPeerUnverifiedException e) {
            if (LOG.isTraceEnabled()) {
                LOG.trace("Failed to get peer certificates.", e);
            }
        }

        return clientCertChain;
    }

    public boolean isNeedClientAuth() {
        return needClientAuth;
    }

    public void setNeedClientAuth(boolean needClientAuth) {
        this.needClientAuth = needClientAuth;
    }

    public boolean isWantClientAuth() {
        return wantClientAuth;
    }

    public void setWantClientAuth(boolean wantClientAuth) {
        this.wantClientAuth = wantClientAuth;
    }

    public String[] getEnabledCipherSuites() {
        return enabledCipherSuites;
    }

    public void setEnabledCipherSuites(String[] enabledCipherSuites) {
        this.enabledCipherSuites = enabledCipherSuites;
    }

    public String[] getEnabledProtocols() {
        return enabledProtocols;
    }

    public void setEnabledProtocols(String[] enabledProtocols) {
        this.enabledProtocols = enabledProtocols;
    }

    public boolean isVerifyHostName() {
        return verifyHostName;
    }

    public void setVerifyHostName(boolean verifyHostName) {
        this.verifyHostName = verifyHostName;
    }
}