001/** 002 * Licensed to the Apache Software Foundation (ASF) under one or more 003 * contributor license agreements. See the NOTICE file distributed with 004 * this work for additional information regarding copyright ownership. 005 * The ASF licenses this file to You under the Apache License, Version 2.0 006 * (the "License"); you may not use this file except in compliance with 007 * the License. You may obtain a copy of the License at 008 * 009 * http://www.apache.org/licenses/LICENSE-2.0 010 * 011 * Unless required by applicable law or agreed to in writing, software 012 * distributed under the License is distributed on an "AS IS" BASIS, 013 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 014 * See the License for the specific language governing permissions and 015 * limitations under the License. 016 */ 017 018package org.apache.activemq.transport.nio; 019 020import java.io.DataInputStream; 021import java.io.DataOutputStream; 022import java.io.EOFException; 023import java.io.IOException; 024import java.net.Socket; 025import java.net.SocketTimeoutException; 026import java.net.URI; 027import java.net.UnknownHostException; 028import java.nio.ByteBuffer; 029import java.nio.channels.SelectionKey; 030import java.nio.channels.Selector; 031import java.security.cert.X509Certificate; 032import java.util.concurrent.CountDownLatch; 033 034import javax.net.SocketFactory; 035import javax.net.ssl.SSLContext; 036import javax.net.ssl.SSLEngine; 037import javax.net.ssl.SSLEngineResult; 038import javax.net.ssl.SSLEngineResult.HandshakeStatus; 039import javax.net.ssl.SSLParameters; 040import javax.net.ssl.SSLPeerUnverifiedException; 041import javax.net.ssl.SSLSession; 042 043import org.apache.activemq.command.ConnectionInfo; 044import org.apache.activemq.openwire.OpenWireFormat; 045import org.apache.activemq.thread.TaskRunnerFactory; 046import org.apache.activemq.util.IOExceptionSupport; 047import org.apache.activemq.util.ServiceStopper; 048import org.apache.activemq.wireformat.WireFormat; 049import org.slf4j.Logger; 050import org.slf4j.LoggerFactory; 051 052public class NIOSSLTransport extends NIOTransport { 053 054 private static final Logger LOG = LoggerFactory.getLogger(NIOSSLTransport.class); 055 056 protected boolean needClientAuth; 057 protected boolean wantClientAuth; 058 protected String[] enabledCipherSuites; 059 protected String[] enabledProtocols; 060 protected boolean verifyHostName = false; 061 062 protected SSLContext sslContext; 063 protected SSLEngine sslEngine; 064 protected SSLSession sslSession; 065 066 protected volatile boolean handshakeInProgress = false; 067 protected SSLEngineResult.Status status = null; 068 protected SSLEngineResult.HandshakeStatus handshakeStatus = null; 069 protected TaskRunnerFactory taskRunnerFactory; 070 071 public NIOSSLTransport(WireFormat wireFormat, SocketFactory socketFactory, URI remoteLocation, URI localLocation) throws UnknownHostException, IOException { 072 super(wireFormat, socketFactory, remoteLocation, localLocation); 073 } 074 075 public NIOSSLTransport(WireFormat wireFormat, Socket socket) throws IOException { 076 super(wireFormat, socket); 077 } 078 079 public void setSslContext(SSLContext sslContext) { 080 this.sslContext = sslContext; 081 } 082 083 @Override 084 protected void initializeStreams() throws IOException { 085 NIOOutputStream outputStream = null; 086 try { 087 channel = socket.getChannel(); 088 channel.configureBlocking(false); 089 090 if (sslContext == null) { 091 sslContext = SSLContext.getDefault(); 092 } 093 094 String remoteHost = null; 095 int remotePort = -1; 096 097 try { 098 URI remoteAddress = new URI(this.getRemoteAddress()); 099 remoteHost = remoteAddress.getHost(); 100 remotePort = remoteAddress.getPort(); 101 } catch (Exception e) { 102 } 103 104 // initialize engine, the initial sslSession we get will need to be 105 // updated once the ssl handshake process is completed. 106 if (remoteHost != null && remotePort != -1) { 107 sslEngine = sslContext.createSSLEngine(remoteHost, remotePort); 108 } else { 109 sslEngine = sslContext.createSSLEngine(); 110 } 111 112 if (verifyHostName) { 113 SSLParameters sslParams = new SSLParameters(); 114 sslParams.setEndpointIdentificationAlgorithm("HTTPS"); 115 sslEngine.setSSLParameters(sslParams); 116 } 117 118 sslEngine.setUseClientMode(false); 119 if (enabledCipherSuites != null) { 120 sslEngine.setEnabledCipherSuites(enabledCipherSuites); 121 } 122 123 if (enabledProtocols != null) { 124 sslEngine.setEnabledProtocols(enabledProtocols); 125 } 126 127 if (wantClientAuth) { 128 sslEngine.setWantClientAuth(wantClientAuth); 129 } 130 131 if (needClientAuth) { 132 sslEngine.setNeedClientAuth(needClientAuth); 133 } 134 135 sslSession = sslEngine.getSession(); 136 137 inputBuffer = ByteBuffer.allocate(sslSession.getPacketBufferSize()); 138 inputBuffer.clear(); 139 140 outputStream = new NIOOutputStream(channel); 141 outputStream.setEngine(sslEngine); 142 this.dataOut = new DataOutputStream(outputStream); 143 this.buffOut = outputStream; 144 sslEngine.beginHandshake(); 145 handshakeStatus = sslEngine.getHandshakeStatus(); 146 doHandshake(); 147 148 selection = SelectorManager.getInstance().register(channel, new SelectorManager.Listener() { 149 @Override 150 public void onSelect(SelectorSelection selection) { 151 try { 152 initialized.await(); 153 } catch (InterruptedException error) { 154 onException(IOExceptionSupport.create(error)); 155 } 156 serviceRead(); 157 } 158 159 @Override 160 public void onError(SelectorSelection selection, Throwable error) { 161 if (error instanceof IOException) { 162 onException((IOException) error); 163 } else { 164 onException(IOExceptionSupport.create(error)); 165 } 166 }}); 167 168 doInit(); 169 170 } catch (Exception e) { 171 try { 172 if(outputStream != null) { 173 outputStream.close(); 174 } 175 super.closeStreams(); 176 } catch (Exception ex) {} 177 throw new IOException(e); 178 } 179 } 180 181 182 final protected CountDownLatch initialized = new CountDownLatch(1); 183 184 protected void doInit() throws Exception { 185 taskRunnerFactory.execute(new Runnable() { 186 187 @Override 188 public void run() { 189 //Need to start in new thread to let startup finish first 190 //We can trigger a read because we know the channel is ready since the SSL handshake 191 //already happened 192 serviceRead(); 193 initialized.countDown(); 194 } 195 }); 196 } 197 198 protected void finishHandshake() throws Exception { 199 if (handshakeInProgress) { 200 handshakeInProgress = false; 201 nextFrameSize = -1; 202 203 // Once handshake completes we need to ask for the now real sslSession 204 // otherwise the session would return 'SSL_NULL_WITH_NULL_NULL' for the 205 // cipher suite. 206 sslSession = sslEngine.getSession(); 207 } 208 } 209 210 @Override 211 protected void serviceRead() { 212 try { 213 if (handshakeInProgress) { 214 doHandshake(); 215 } 216 217 ByteBuffer plain = ByteBuffer.allocate(sslSession.getApplicationBufferSize()); 218 plain.position(plain.limit()); 219 220 while (true) { 221 if (!plain.hasRemaining()) { 222 223 int readCount = secureRead(plain); 224 225 if (readCount == 0) { 226 break; 227 } 228 229 // channel is closed, cleanup 230 if (readCount == -1) { 231 onException(new EOFException()); 232 selection.close(); 233 break; 234 } 235 236 receiveCounter += readCount; 237 } 238 239 if (status == SSLEngineResult.Status.OK && handshakeStatus != SSLEngineResult.HandshakeStatus.NEED_UNWRAP) { 240 processCommand(plain); 241 } 242 } 243 } catch (IOException e) { 244 onException(e); 245 } catch (Throwable e) { 246 onException(IOExceptionSupport.create(e)); 247 } 248 } 249 250 protected void processCommand(ByteBuffer plain) throws Exception { 251 252 // Are we waiting for the next Command or are we building on the current one 253 if (nextFrameSize == -1) { 254 255 // We can get small packets that don't give us enough for the frame size 256 // so allocate enough for the initial size value and 257 if (plain.remaining() < Integer.SIZE) { 258 if (currentBuffer == null) { 259 currentBuffer = ByteBuffer.allocate(4); 260 } 261 262 // Go until we fill the integer sized current buffer. 263 while (currentBuffer.hasRemaining() && plain.hasRemaining()) { 264 currentBuffer.put(plain.get()); 265 } 266 267 // Didn't we get enough yet to figure out next frame size. 268 if (currentBuffer.hasRemaining()) { 269 return; 270 } else { 271 currentBuffer.flip(); 272 nextFrameSize = currentBuffer.getInt(); 273 } 274 275 } else { 276 277 // Either we are completing a previous read of the next frame size or its 278 // fully contained in plain already. 279 if (currentBuffer != null) { 280 281 // Finish the frame size integer read and get from the current buffer. 282 while (currentBuffer.hasRemaining()) { 283 currentBuffer.put(plain.get()); 284 } 285 286 currentBuffer.flip(); 287 nextFrameSize = currentBuffer.getInt(); 288 289 } else { 290 nextFrameSize = plain.getInt(); 291 } 292 } 293 294 if (wireFormat instanceof OpenWireFormat) { 295 long maxFrameSize = ((OpenWireFormat) wireFormat).getMaxFrameSize(); 296 if (nextFrameSize > maxFrameSize) { 297 throw new IOException("Frame size of " + (nextFrameSize / (1024 * 1024)) + 298 " MB larger than max allowed " + (maxFrameSize / (1024 * 1024)) + " MB"); 299 } 300 } 301 302 // now we got the data, lets reallocate and store the size for the marshaler. 303 // if there's more data in plain, then the next call will start processing it. 304 currentBuffer = ByteBuffer.allocate(nextFrameSize + 4); 305 currentBuffer.putInt(nextFrameSize); 306 307 } else { 308 309 // If its all in one read then we can just take it all, otherwise take only 310 // the current frame size and the next iteration starts a new command. 311 if (currentBuffer.remaining() >= plain.remaining()) { 312 currentBuffer.put(plain); 313 } else { 314 byte[] fill = new byte[currentBuffer.remaining()]; 315 plain.get(fill); 316 currentBuffer.put(fill); 317 } 318 319 // Either we have enough data for a new command or we have to wait for some more. 320 if (currentBuffer.hasRemaining()) { 321 return; 322 } else { 323 currentBuffer.flip(); 324 Object command = wireFormat.unmarshal(new DataInputStream(new NIOInputStream(currentBuffer))); 325 doConsume(command); 326 nextFrameSize = -1; 327 currentBuffer = null; 328 } 329 } 330 } 331 332 protected int secureRead(ByteBuffer plain) throws Exception { 333 334 if (!(inputBuffer.position() != 0 && inputBuffer.hasRemaining()) || status == SSLEngineResult.Status.BUFFER_UNDERFLOW) { 335 int bytesRead = channel.read(inputBuffer); 336 337 if (bytesRead == 0 && !(sslEngine.getHandshakeStatus().equals(SSLEngineResult.HandshakeStatus.NEED_UNWRAP))) { 338 return 0; 339 } 340 341 if (bytesRead == -1) { 342 sslEngine.closeInbound(); 343 if (inputBuffer.position() == 0 || status == SSLEngineResult.Status.BUFFER_UNDERFLOW) { 344 return -1; 345 } 346 } 347 } 348 349 plain.clear(); 350 351 inputBuffer.flip(); 352 SSLEngineResult res; 353 do { 354 res = sslEngine.unwrap(inputBuffer, plain); 355 } while (res.getStatus() == SSLEngineResult.Status.OK && res.getHandshakeStatus() == SSLEngineResult.HandshakeStatus.NEED_UNWRAP 356 && res.bytesProduced() == 0); 357 358 if (res.getHandshakeStatus() == SSLEngineResult.HandshakeStatus.FINISHED) { 359 finishHandshake(); 360 } 361 362 status = res.getStatus(); 363 handshakeStatus = res.getHandshakeStatus(); 364 365 // TODO deal with BUFFER_OVERFLOW 366 367 if (status == SSLEngineResult.Status.CLOSED) { 368 sslEngine.closeInbound(); 369 return -1; 370 } 371 372 inputBuffer.compact(); 373 plain.flip(); 374 375 return plain.remaining(); 376 } 377 378 protected void doHandshake() throws Exception { 379 handshakeInProgress = true; 380 Selector selector = null; 381 SelectionKey key = null; 382 boolean readable = true; 383 try { 384 while (true) { 385 HandshakeStatus handshakeStatus = sslEngine.getHandshakeStatus(); 386 switch (handshakeStatus) { 387 case NEED_UNWRAP: 388 if (readable) { 389 secureRead(ByteBuffer.allocate(sslSession.getApplicationBufferSize())); 390 } 391 if (this.status == SSLEngineResult.Status.BUFFER_UNDERFLOW) { 392 long now = System.currentTimeMillis(); 393 if (selector == null) { 394 selector = Selector.open(); 395 key = channel.register(selector, SelectionKey.OP_READ); 396 } else { 397 key.interestOps(SelectionKey.OP_READ); 398 } 399 int keyCount = selector.select(this.getSoTimeout()); 400 if (keyCount == 0 && this.getSoTimeout() > 0 && ((System.currentTimeMillis() - now) >= this.getSoTimeout())) { 401 throw new SocketTimeoutException("Timeout during handshake"); 402 } 403 readable = key.isReadable(); 404 } 405 break; 406 case NEED_TASK: 407 Runnable task; 408 while ((task = sslEngine.getDelegatedTask()) != null) { 409 task.run(); 410 } 411 break; 412 case NEED_WRAP: 413 ((NIOOutputStream) buffOut).write(ByteBuffer.allocate(0)); 414 break; 415 case FINISHED: 416 case NOT_HANDSHAKING: 417 finishHandshake(); 418 return; 419 } 420 } 421 } finally { 422 if (key!=null) try {key.cancel();} catch (Exception ignore) {} 423 if (selector!=null) try {selector.close();} catch (Exception ignore) {} 424 } 425 } 426 427 @Override 428 protected void doStart() throws Exception { 429 taskRunnerFactory = new TaskRunnerFactory("ActiveMQ NIOSSLTransport Task"); 430 // no need to init as we can delay that until demand (eg in doHandshake) 431 super.doStart(); 432 } 433 434 @Override 435 protected void doStop(ServiceStopper stopper) throws Exception { 436 initialized.countDown(); 437 438 if (taskRunnerFactory != null) { 439 taskRunnerFactory.shutdownNow(); 440 taskRunnerFactory = null; 441 } 442 if (channel != null) { 443 channel.close(); 444 channel = null; 445 } 446 super.doStop(stopper); 447 } 448 449 /** 450 * Overriding in order to add the client's certificates to ConnectionInfo Commands. 451 * 452 * @param command 453 * The Command coming in. 454 */ 455 @Override 456 public void doConsume(Object command) { 457 if (command instanceof ConnectionInfo) { 458 ConnectionInfo connectionInfo = (ConnectionInfo) command; 459 connectionInfo.setTransportContext(getPeerCertificates()); 460 } 461 super.doConsume(command); 462 } 463 464 /** 465 * @return peer certificate chain associated with the ssl socket 466 */ 467 public X509Certificate[] getPeerCertificates() { 468 469 X509Certificate[] clientCertChain = null; 470 try { 471 if (sslEngine.getSession() != null) { 472 clientCertChain = (X509Certificate[]) sslEngine.getSession().getPeerCertificates(); 473 } 474 } catch (SSLPeerUnverifiedException e) { 475 if (LOG.isTraceEnabled()) { 476 LOG.trace("Failed to get peer certificates.", e); 477 } 478 } 479 480 return clientCertChain; 481 } 482 483 public boolean isNeedClientAuth() { 484 return needClientAuth; 485 } 486 487 public void setNeedClientAuth(boolean needClientAuth) { 488 this.needClientAuth = needClientAuth; 489 } 490 491 public boolean isWantClientAuth() { 492 return wantClientAuth; 493 } 494 495 public void setWantClientAuth(boolean wantClientAuth) { 496 this.wantClientAuth = wantClientAuth; 497 } 498 499 public String[] getEnabledCipherSuites() { 500 return enabledCipherSuites; 501 } 502 503 public void setEnabledCipherSuites(String[] enabledCipherSuites) { 504 this.enabledCipherSuites = enabledCipherSuites; 505 } 506 507 public String[] getEnabledProtocols() { 508 return enabledProtocols; 509 } 510 511 public void setEnabledProtocols(String[] enabledProtocols) { 512 this.enabledProtocols = enabledProtocols; 513 } 514 515 public boolean isVerifyHostName() { 516 return verifyHostName; 517 } 518 519 public void setVerifyHostName(boolean verifyHostName) { 520 this.verifyHostName = verifyHostName; 521 } 522}