001/**
002 * Licensed to the Apache Software Foundation (ASF) under one or more
003 * contributor license agreements.  See the NOTICE file distributed with
004 * this work for additional information regarding copyright ownership.
005 * The ASF licenses this file to You under the Apache License, Version 2.0
006 * (the "License"); you may not use this file except in compliance with
007 * the License.  You may obtain a copy of the License at
008 *
009 *      http://www.apache.org/licenses/LICENSE-2.0
010 *
011 * Unless required by applicable law or agreed to in writing, software
012 * distributed under the License is distributed on an "AS IS" BASIS,
013 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
014 * See the License for the specific language governing permissions and
015 * limitations under the License.
016 */
017
018package org.apache.activemq.transport.nio;
019
020import java.io.DataInputStream;
021import java.io.DataOutputStream;
022import java.io.EOFException;
023import java.io.IOException;
024import java.net.Socket;
025import java.net.SocketTimeoutException;
026import java.net.URI;
027import java.net.UnknownHostException;
028import java.nio.ByteBuffer;
029import java.nio.channels.SelectionKey;
030import java.nio.channels.Selector;
031import java.security.cert.X509Certificate;
032import java.util.concurrent.CountDownLatch;
033
034import javax.net.SocketFactory;
035import javax.net.ssl.SSLContext;
036import javax.net.ssl.SSLEngine;
037import javax.net.ssl.SSLEngineResult;
038import javax.net.ssl.SSLEngineResult.HandshakeStatus;
039import javax.net.ssl.SSLParameters;
040import javax.net.ssl.SSLPeerUnverifiedException;
041import javax.net.ssl.SSLSession;
042
043import org.apache.activemq.command.ConnectionInfo;
044import org.apache.activemq.openwire.OpenWireFormat;
045import org.apache.activemq.thread.TaskRunnerFactory;
046import org.apache.activemq.util.IOExceptionSupport;
047import org.apache.activemq.util.ServiceStopper;
048import org.apache.activemq.wireformat.WireFormat;
049import org.slf4j.Logger;
050import org.slf4j.LoggerFactory;
051
052public class NIOSSLTransport extends NIOTransport {
053
054    private static final Logger LOG = LoggerFactory.getLogger(NIOSSLTransport.class);
055
056    protected boolean needClientAuth;
057    protected boolean wantClientAuth;
058    protected String[] enabledCipherSuites;
059    protected String[] enabledProtocols;
060    protected boolean verifyHostName = false;
061
062    protected SSLContext sslContext;
063    protected SSLEngine sslEngine;
064    protected SSLSession sslSession;
065
066    protected volatile boolean handshakeInProgress = false;
067    protected SSLEngineResult.Status status = null;
068    protected SSLEngineResult.HandshakeStatus handshakeStatus = null;
069    protected TaskRunnerFactory taskRunnerFactory;
070
071    public NIOSSLTransport(WireFormat wireFormat, SocketFactory socketFactory, URI remoteLocation, URI localLocation) throws UnknownHostException, IOException {
072        super(wireFormat, socketFactory, remoteLocation, localLocation);
073    }
074
075    public NIOSSLTransport(WireFormat wireFormat, Socket socket) throws IOException {
076        super(wireFormat, socket);
077    }
078
079    public void setSslContext(SSLContext sslContext) {
080        this.sslContext = sslContext;
081    }
082
083    @Override
084    protected void initializeStreams() throws IOException {
085        NIOOutputStream outputStream = null;
086        try {
087            channel = socket.getChannel();
088            channel.configureBlocking(false);
089
090            if (sslContext == null) {
091                sslContext = SSLContext.getDefault();
092            }
093
094            String remoteHost = null;
095            int remotePort = -1;
096
097            try {
098                URI remoteAddress = new URI(this.getRemoteAddress());
099                remoteHost = remoteAddress.getHost();
100                remotePort = remoteAddress.getPort();
101            } catch (Exception e) {
102            }
103
104            // initialize engine, the initial sslSession we get will need to be
105            // updated once the ssl handshake process is completed.
106            if (remoteHost != null && remotePort != -1) {
107                sslEngine = sslContext.createSSLEngine(remoteHost, remotePort);
108            } else {
109                sslEngine = sslContext.createSSLEngine();
110            }
111
112            if (verifyHostName) {
113                SSLParameters sslParams = new SSLParameters();
114                sslParams.setEndpointIdentificationAlgorithm("HTTPS");
115                sslEngine.setSSLParameters(sslParams);
116            }
117
118            sslEngine.setUseClientMode(false);
119            if (enabledCipherSuites != null) {
120                sslEngine.setEnabledCipherSuites(enabledCipherSuites);
121            }
122
123            if (enabledProtocols != null) {
124                sslEngine.setEnabledProtocols(enabledProtocols);
125            }
126
127            if (wantClientAuth) {
128                sslEngine.setWantClientAuth(wantClientAuth);
129            }
130
131            if (needClientAuth) {
132                sslEngine.setNeedClientAuth(needClientAuth);
133            }
134
135            sslSession = sslEngine.getSession();
136
137            inputBuffer = ByteBuffer.allocate(sslSession.getPacketBufferSize());
138            inputBuffer.clear();
139
140            outputStream = new NIOOutputStream(channel);
141            outputStream.setEngine(sslEngine);
142            this.dataOut = new DataOutputStream(outputStream);
143            this.buffOut = outputStream;
144            sslEngine.beginHandshake();
145            handshakeStatus = sslEngine.getHandshakeStatus();
146            doHandshake();
147
148            selection = SelectorManager.getInstance().register(channel, new SelectorManager.Listener() {
149                @Override
150                public void onSelect(SelectorSelection selection) {
151                    try {
152                        initialized.await();
153                    } catch (InterruptedException error) {
154                        onException(IOExceptionSupport.create(error));
155                    }
156                    serviceRead();
157                }
158
159                @Override
160                public void onError(SelectorSelection selection, Throwable error) {
161                    if (error instanceof IOException) {
162                        onException((IOException) error);
163                    } else {
164                        onException(IOExceptionSupport.create(error));
165                    }
166                }});
167
168            doInit();
169
170        } catch (Exception e) {
171            try {
172                if(outputStream != null) {
173                    outputStream.close();
174                }
175                super.closeStreams();
176            } catch (Exception ex) {}
177            throw new IOException(e);
178        }
179    }
180
181
182    final protected CountDownLatch initialized = new CountDownLatch(1);
183
184    protected void doInit() throws Exception {
185        taskRunnerFactory.execute(new Runnable() {
186
187            @Override
188            public void run() {
189                //Need to start in new thread to let startup finish first
190                //We can trigger a read because we know the channel is ready since the SSL handshake
191                //already happened
192                serviceRead();
193                initialized.countDown();
194            }
195        });
196    }
197
198    protected void finishHandshake() throws Exception {
199        if (handshakeInProgress) {
200            handshakeInProgress = false;
201            nextFrameSize = -1;
202
203            // Once handshake completes we need to ask for the now real sslSession
204            // otherwise the session would return 'SSL_NULL_WITH_NULL_NULL' for the
205            // cipher suite.
206            sslSession = sslEngine.getSession();
207
208            // listen for events telling us when the socket is readable.
209            selection = SelectorManager.getInstance().register(channel, new SelectorManager.Listener() {
210                @Override
211                public void onSelect(SelectorSelection selection) {
212                    serviceRead();
213                }
214
215                @Override
216                public void onError(SelectorSelection selection, Throwable error) {
217                    if (error instanceof IOException) {
218                        onException((IOException) error);
219                    } else {
220                        onException(IOExceptionSupport.create(error));
221                    }
222                }
223            });
224        }
225    }
226
227    @Override
228    protected void serviceRead() {
229        try {
230            if (handshakeInProgress) {
231                doHandshake();
232            }
233
234            ByteBuffer plain = ByteBuffer.allocate(sslSession.getApplicationBufferSize());
235            plain.position(plain.limit());
236
237            while (true) {
238                if (!plain.hasRemaining()) {
239
240                    int readCount = secureRead(plain);
241
242                    if (readCount == 0) {
243                        break;
244                    }
245
246                    // channel is closed, cleanup
247                    if (readCount == -1) {
248                        onException(new EOFException());
249                        selection.close();
250                        break;
251                    }
252
253                    receiveCounter += readCount;
254                }
255
256                if (status == SSLEngineResult.Status.OK && handshakeStatus != SSLEngineResult.HandshakeStatus.NEED_UNWRAP) {
257                    processCommand(plain);
258                }
259            }
260        } catch (IOException e) {
261            onException(e);
262        } catch (Throwable e) {
263            onException(IOExceptionSupport.create(e));
264        }
265    }
266
267    protected void processCommand(ByteBuffer plain) throws Exception {
268
269        // Are we waiting for the next Command or are we building on the current one
270        if (nextFrameSize == -1) {
271
272            // We can get small packets that don't give us enough for the frame size
273            // so allocate enough for the initial size value and
274            if (plain.remaining() < Integer.SIZE) {
275                if (currentBuffer == null) {
276                    currentBuffer = ByteBuffer.allocate(4);
277                }
278
279                // Go until we fill the integer sized current buffer.
280                while (currentBuffer.hasRemaining() && plain.hasRemaining()) {
281                    currentBuffer.put(plain.get());
282                }
283
284                // Didn't we get enough yet to figure out next frame size.
285                if (currentBuffer.hasRemaining()) {
286                    return;
287                } else {
288                    currentBuffer.flip();
289                    nextFrameSize = currentBuffer.getInt();
290                }
291
292            } else {
293
294                // Either we are completing a previous read of the next frame size or its
295                // fully contained in plain already.
296                if (currentBuffer != null) {
297
298                    // Finish the frame size integer read and get from the current buffer.
299                    while (currentBuffer.hasRemaining()) {
300                        currentBuffer.put(plain.get());
301                    }
302
303                    currentBuffer.flip();
304                    nextFrameSize = currentBuffer.getInt();
305
306                } else {
307                    nextFrameSize = plain.getInt();
308                }
309            }
310
311            if (wireFormat instanceof OpenWireFormat) {
312                long maxFrameSize = ((OpenWireFormat) wireFormat).getMaxFrameSize();
313                if (nextFrameSize > maxFrameSize) {
314                    throw new IOException("Frame size of " + (nextFrameSize / (1024 * 1024)) +
315                                          " MB larger than max allowed " + (maxFrameSize / (1024 * 1024)) + " MB");
316                }
317            }
318
319            // now we got the data, lets reallocate and store the size for the marshaler.
320            // if there's more data in plain, then the next call will start processing it.
321            currentBuffer = ByteBuffer.allocate(nextFrameSize + 4);
322            currentBuffer.putInt(nextFrameSize);
323
324        } else {
325
326            // If its all in one read then we can just take it all, otherwise take only
327            // the current frame size and the next iteration starts a new command.
328            if (currentBuffer.remaining() >= plain.remaining()) {
329                currentBuffer.put(plain);
330            } else {
331                byte[] fill = new byte[currentBuffer.remaining()];
332                plain.get(fill);
333                currentBuffer.put(fill);
334            }
335
336            // Either we have enough data for a new command or we have to wait for some more.
337            if (currentBuffer.hasRemaining()) {
338                return;
339            } else {
340                currentBuffer.flip();
341                Object command = wireFormat.unmarshal(new DataInputStream(new NIOInputStream(currentBuffer)));
342                doConsume(command);
343                nextFrameSize = -1;
344                currentBuffer = null;
345            }
346        }
347    }
348
349    protected int secureRead(ByteBuffer plain) throws Exception {
350
351        if (!(inputBuffer.position() != 0 && inputBuffer.hasRemaining()) || status == SSLEngineResult.Status.BUFFER_UNDERFLOW) {
352            int bytesRead = channel.read(inputBuffer);
353
354            if (bytesRead == 0 && !(sslEngine.getHandshakeStatus().equals(SSLEngineResult.HandshakeStatus.NEED_UNWRAP))) {
355                return 0;
356            }
357
358            if (bytesRead == -1) {
359                sslEngine.closeInbound();
360                if (inputBuffer.position() == 0 || status == SSLEngineResult.Status.BUFFER_UNDERFLOW) {
361                    return -1;
362                }
363            }
364        }
365
366        plain.clear();
367
368        inputBuffer.flip();
369        SSLEngineResult res;
370        do {
371            res = sslEngine.unwrap(inputBuffer, plain);
372        } while (res.getStatus() == SSLEngineResult.Status.OK && res.getHandshakeStatus() == SSLEngineResult.HandshakeStatus.NEED_UNWRAP
373                && res.bytesProduced() == 0);
374
375        if (res.getHandshakeStatus() == SSLEngineResult.HandshakeStatus.FINISHED) {
376            finishHandshake();
377        }
378
379        status = res.getStatus();
380        handshakeStatus = res.getHandshakeStatus();
381
382        // TODO deal with BUFFER_OVERFLOW
383
384        if (status == SSLEngineResult.Status.CLOSED) {
385            sslEngine.closeInbound();
386            return -1;
387        }
388
389        inputBuffer.compact();
390        plain.flip();
391
392        return plain.remaining();
393    }
394
395    protected void doHandshake() throws Exception {
396        handshakeInProgress = true;
397        Selector selector = null;
398        SelectionKey key = null;
399        boolean readable = true;
400        try {
401            while (true) {
402                HandshakeStatus handshakeStatus = sslEngine.getHandshakeStatus();
403                switch (handshakeStatus) {
404                    case NEED_UNWRAP:
405                        if (readable) {
406                            secureRead(ByteBuffer.allocate(sslSession.getApplicationBufferSize()));
407                        }
408                        if (this.status == SSLEngineResult.Status.BUFFER_UNDERFLOW) {
409                            long now = System.currentTimeMillis();
410                            if (selector == null) {
411                                selector = Selector.open();
412                                key = channel.register(selector, SelectionKey.OP_READ);
413                            } else {
414                                key.interestOps(SelectionKey.OP_READ);
415                            }
416                            int keyCount = selector.select(this.getSoTimeout());
417                            if (keyCount == 0 && this.getSoTimeout() > 0 && ((System.currentTimeMillis() - now) >= this.getSoTimeout())) {
418                                throw new SocketTimeoutException("Timeout during handshake");
419                            }
420                            readable = key.isReadable();
421                        }
422                        break;
423                    case NEED_TASK:
424                        Runnable task;
425                        while ((task = sslEngine.getDelegatedTask()) != null) {
426                            task.run();
427                        }
428                        break;
429                    case NEED_WRAP:
430                        ((NIOOutputStream) buffOut).write(ByteBuffer.allocate(0));
431                        break;
432                    case FINISHED:
433                    case NOT_HANDSHAKING:
434                        finishHandshake();
435                        return;
436                }
437            }
438        } finally {
439            if (key!=null) try {key.cancel();} catch (Exception ignore) {}
440            if (selector!=null) try {selector.close();} catch (Exception ignore) {}
441        }
442    }
443
444    @Override
445    protected void doStart() throws Exception {
446        taskRunnerFactory = new TaskRunnerFactory("ActiveMQ NIOSSLTransport Task");
447        // no need to init as we can delay that until demand (eg in doHandshake)
448        super.doStart();
449    }
450
451    @Override
452    protected void doStop(ServiceStopper stopper) throws Exception {
453        initialized.countDown();
454
455        if (taskRunnerFactory != null) {
456            taskRunnerFactory.shutdownNow();
457            taskRunnerFactory = null;
458        }
459        if (channel != null) {
460            channel.close();
461            channel = null;
462        }
463        super.doStop(stopper);
464    }
465
466    /**
467     * Overriding in order to add the client's certificates to ConnectionInfo Commands.
468     *
469     * @param command
470     *            The Command coming in.
471     */
472    @Override
473    public void doConsume(Object command) {
474        if (command instanceof ConnectionInfo) {
475            ConnectionInfo connectionInfo = (ConnectionInfo) command;
476            connectionInfo.setTransportContext(getPeerCertificates());
477        }
478        super.doConsume(command);
479    }
480
481    /**
482     * @return peer certificate chain associated with the ssl socket
483     */
484    public X509Certificate[] getPeerCertificates() {
485
486        X509Certificate[] clientCertChain = null;
487        try {
488            if (sslEngine.getSession() != null) {
489                clientCertChain = (X509Certificate[]) sslEngine.getSession().getPeerCertificates();
490            }
491        } catch (SSLPeerUnverifiedException e) {
492            if (LOG.isTraceEnabled()) {
493                LOG.trace("Failed to get peer certificates.", e);
494            }
495        }
496
497        return clientCertChain;
498    }
499
500    public boolean isNeedClientAuth() {
501        return needClientAuth;
502    }
503
504    public void setNeedClientAuth(boolean needClientAuth) {
505        this.needClientAuth = needClientAuth;
506    }
507
508    public boolean isWantClientAuth() {
509        return wantClientAuth;
510    }
511
512    public void setWantClientAuth(boolean wantClientAuth) {
513        this.wantClientAuth = wantClientAuth;
514    }
515
516    public String[] getEnabledCipherSuites() {
517        return enabledCipherSuites;
518    }
519
520    public void setEnabledCipherSuites(String[] enabledCipherSuites) {
521        this.enabledCipherSuites = enabledCipherSuites;
522    }
523
524    public String[] getEnabledProtocols() {
525        return enabledProtocols;
526    }
527
528    public void setEnabledProtocols(String[] enabledProtocols) {
529        this.enabledProtocols = enabledProtocols;
530    }
531
532    public boolean isVerifyHostName() {
533        return verifyHostName;
534    }
535
536    public void setVerifyHostName(boolean verifyHostName) {
537        this.verifyHostName = verifyHostName;
538    }
539}