001/**
002 * Licensed to the Apache Software Foundation (ASF) under one or more
003 * contributor license agreements.  See the NOTICE file distributed with
004 * this work for additional information regarding copyright ownership.
005 * The ASF licenses this file to You under the Apache License, Version 2.0
006 * (the "License"); you may not use this file except in compliance with
007 * the License.  You may obtain a copy of the License at
008 *
009 *      http://www.apache.org/licenses/LICENSE-2.0
010 *
011 * Unless required by applicable law or agreed to in writing, software
012 * distributed under the License is distributed on an "AS IS" BASIS,
013 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
014 * See the License for the specific language governing permissions and
015 * limitations under the License.
016 */
017
018package org.apache.activemq.transport.nio;
019
020import java.io.DataInputStream;
021import java.io.DataOutputStream;
022import java.io.EOFException;
023import java.io.IOException;
024import java.net.Socket;
025import java.net.SocketTimeoutException;
026import java.net.URI;
027import java.net.UnknownHostException;
028import java.nio.ByteBuffer;
029import java.nio.channels.SelectionKey;
030import java.nio.channels.Selector;
031import java.security.cert.X509Certificate;
032import java.util.concurrent.CountDownLatch;
033
034import javax.net.SocketFactory;
035import javax.net.ssl.SSLContext;
036import javax.net.ssl.SSLEngine;
037import javax.net.ssl.SSLEngineResult;
038import javax.net.ssl.SSLEngineResult.HandshakeStatus;
039import javax.net.ssl.SSLParameters;
040import javax.net.ssl.SSLPeerUnverifiedException;
041import javax.net.ssl.SSLSession;
042
043import org.apache.activemq.command.ConnectionInfo;
044import org.apache.activemq.openwire.OpenWireFormat;
045import org.apache.activemq.thread.TaskRunnerFactory;
046import org.apache.activemq.util.IOExceptionSupport;
047import org.apache.activemq.util.ServiceStopper;
048import org.apache.activemq.wireformat.WireFormat;
049import org.slf4j.Logger;
050import org.slf4j.LoggerFactory;
051
052public class NIOSSLTransport extends NIOTransport {
053
054    private static final Logger LOG = LoggerFactory.getLogger(NIOSSLTransport.class);
055
056    protected boolean needClientAuth;
057    protected boolean wantClientAuth;
058    protected String[] enabledCipherSuites;
059    protected String[] enabledProtocols;
060    protected boolean verifyHostName = false;
061
062    protected SSLContext sslContext;
063    protected SSLEngine sslEngine;
064    protected SSLSession sslSession;
065
066    protected volatile boolean handshakeInProgress = false;
067    protected SSLEngineResult.Status status = null;
068    protected SSLEngineResult.HandshakeStatus handshakeStatus = null;
069    protected TaskRunnerFactory taskRunnerFactory;
070
071    public NIOSSLTransport(WireFormat wireFormat, SocketFactory socketFactory, URI remoteLocation, URI localLocation) throws UnknownHostException, IOException {
072        super(wireFormat, socketFactory, remoteLocation, localLocation);
073    }
074
075    public NIOSSLTransport(WireFormat wireFormat, Socket socket) throws IOException {
076        super(wireFormat, socket);
077    }
078
079    public void setSslContext(SSLContext sslContext) {
080        this.sslContext = sslContext;
081    }
082
083    @Override
084    protected void initializeStreams() throws IOException {
085        NIOOutputStream outputStream = null;
086        try {
087            channel = socket.getChannel();
088            channel.configureBlocking(false);
089
090            if (sslContext == null) {
091                sslContext = SSLContext.getDefault();
092            }
093
094            String remoteHost = null;
095            int remotePort = -1;
096
097            try {
098                URI remoteAddress = new URI(this.getRemoteAddress());
099                remoteHost = remoteAddress.getHost();
100                remotePort = remoteAddress.getPort();
101            } catch (Exception e) {
102            }
103
104            // initialize engine, the initial sslSession we get will need to be
105            // updated once the ssl handshake process is completed.
106            if (remoteHost != null && remotePort != -1) {
107                sslEngine = sslContext.createSSLEngine(remoteHost, remotePort);
108            } else {
109                sslEngine = sslContext.createSSLEngine();
110            }
111
112            if (verifyHostName) {
113                SSLParameters sslParams = new SSLParameters();
114                sslParams.setEndpointIdentificationAlgorithm("HTTPS");
115                sslEngine.setSSLParameters(sslParams);
116            }
117
118            sslEngine.setUseClientMode(false);
119            if (enabledCipherSuites != null) {
120                sslEngine.setEnabledCipherSuites(enabledCipherSuites);
121            }
122
123            if (enabledProtocols != null) {
124                sslEngine.setEnabledProtocols(enabledProtocols);
125            }
126
127            if (wantClientAuth) {
128                sslEngine.setWantClientAuth(wantClientAuth);
129            }
130
131            if (needClientAuth) {
132                sslEngine.setNeedClientAuth(needClientAuth);
133            }
134
135            sslSession = sslEngine.getSession();
136
137            inputBuffer = ByteBuffer.allocate(sslSession.getPacketBufferSize());
138            inputBuffer.clear();
139
140            outputStream = new NIOOutputStream(channel);
141            outputStream.setEngine(sslEngine);
142            this.dataOut = new DataOutputStream(outputStream);
143            this.buffOut = outputStream;
144            sslEngine.beginHandshake();
145            handshakeStatus = sslEngine.getHandshakeStatus();
146            doHandshake();
147
148            selection = SelectorManager.getInstance().register(channel, new SelectorManager.Listener() {
149                @Override
150                public void onSelect(SelectorSelection selection) {
151                    try {
152                        initialized.await();
153                    } catch (InterruptedException error) {
154                        onException(IOExceptionSupport.create(error));
155                    }
156                    serviceRead();
157                }
158
159                @Override
160                public void onError(SelectorSelection selection, Throwable error) {
161                    if (error instanceof IOException) {
162                        onException((IOException) error);
163                    } else {
164                        onException(IOExceptionSupport.create(error));
165                    }
166                }});
167
168            doInit();
169
170        } catch (Exception e) {
171            try {
172                if(outputStream != null) {
173                    outputStream.close();
174                }
175                super.closeStreams();
176            } catch (Exception ex) {}
177            throw new IOException(e);
178        }
179    }
180
181
182    final protected CountDownLatch initialized = new CountDownLatch(1);
183
184    protected void doInit() throws Exception {
185        taskRunnerFactory.execute(new Runnable() {
186
187            @Override
188            public void run() {
189                //Need to start in new thread to let startup finish first
190                //We can trigger a read because we know the channel is ready since the SSL handshake
191                //already happened
192                serviceRead();
193                initialized.countDown();
194            }
195        });
196    }
197
198    protected void finishHandshake() throws Exception {
199        if (handshakeInProgress) {
200            handshakeInProgress = false;
201            nextFrameSize = -1;
202
203            // Once handshake completes we need to ask for the now real sslSession
204            // otherwise the session would return 'SSL_NULL_WITH_NULL_NULL' for the
205            // cipher suite.
206            sslSession = sslEngine.getSession();
207        }
208    }
209
210    @Override
211    protected void serviceRead() {
212        try {
213            if (handshakeInProgress) {
214                doHandshake();
215            }
216
217            ByteBuffer plain = ByteBuffer.allocate(sslSession.getApplicationBufferSize());
218            plain.position(plain.limit());
219
220            while (true) {
221                if (!plain.hasRemaining()) {
222
223                    int readCount = secureRead(plain);
224
225                    if (readCount == 0) {
226                        break;
227                    }
228
229                    // channel is closed, cleanup
230                    if (readCount == -1) {
231                        onException(new EOFException());
232                        selection.close();
233                        break;
234                    }
235
236                    receiveCounter += readCount;
237                }
238
239                if (status == SSLEngineResult.Status.OK && handshakeStatus != SSLEngineResult.HandshakeStatus.NEED_UNWRAP) {
240                    processCommand(plain);
241                }
242            }
243        } catch (IOException e) {
244            onException(e);
245        } catch (Throwable e) {
246            onException(IOExceptionSupport.create(e));
247        }
248    }
249
250    protected void processCommand(ByteBuffer plain) throws Exception {
251
252        // Are we waiting for the next Command or are we building on the current one
253        if (nextFrameSize == -1) {
254
255            // We can get small packets that don't give us enough for the frame size
256            // so allocate enough for the initial size value and
257            if (plain.remaining() < Integer.SIZE) {
258                if (currentBuffer == null) {
259                    currentBuffer = ByteBuffer.allocate(4);
260                }
261
262                // Go until we fill the integer sized current buffer.
263                while (currentBuffer.hasRemaining() && plain.hasRemaining()) {
264                    currentBuffer.put(plain.get());
265                }
266
267                // Didn't we get enough yet to figure out next frame size.
268                if (currentBuffer.hasRemaining()) {
269                    return;
270                } else {
271                    currentBuffer.flip();
272                    nextFrameSize = currentBuffer.getInt();
273                }
274
275            } else {
276
277                // Either we are completing a previous read of the next frame size or its
278                // fully contained in plain already.
279                if (currentBuffer != null) {
280
281                    // Finish the frame size integer read and get from the current buffer.
282                    while (currentBuffer.hasRemaining()) {
283                        currentBuffer.put(plain.get());
284                    }
285
286                    currentBuffer.flip();
287                    nextFrameSize = currentBuffer.getInt();
288
289                } else {
290                    nextFrameSize = plain.getInt();
291                }
292            }
293
294            if (wireFormat instanceof OpenWireFormat) {
295                long maxFrameSize = ((OpenWireFormat) wireFormat).getMaxFrameSize();
296                if (nextFrameSize > maxFrameSize) {
297                    throw new IOException("Frame size of " + (nextFrameSize / (1024 * 1024)) +
298                                          " MB larger than max allowed " + (maxFrameSize / (1024 * 1024)) + " MB");
299                }
300            }
301
302            // now we got the data, lets reallocate and store the size for the marshaler.
303            // if there's more data in plain, then the next call will start processing it.
304            currentBuffer = ByteBuffer.allocate(nextFrameSize + 4);
305            currentBuffer.putInt(nextFrameSize);
306
307        } else {
308
309            // If its all in one read then we can just take it all, otherwise take only
310            // the current frame size and the next iteration starts a new command.
311            if (currentBuffer.remaining() >= plain.remaining()) {
312                currentBuffer.put(plain);
313            } else {
314                byte[] fill = new byte[currentBuffer.remaining()];
315                plain.get(fill);
316                currentBuffer.put(fill);
317            }
318
319            // Either we have enough data for a new command or we have to wait for some more.
320            if (currentBuffer.hasRemaining()) {
321                return;
322            } else {
323                currentBuffer.flip();
324                Object command = wireFormat.unmarshal(new DataInputStream(new NIOInputStream(currentBuffer)));
325                doConsume(command);
326                nextFrameSize = -1;
327                currentBuffer = null;
328            }
329        }
330    }
331
332    protected int secureRead(ByteBuffer plain) throws Exception {
333
334        if (!(inputBuffer.position() != 0 && inputBuffer.hasRemaining()) || status == SSLEngineResult.Status.BUFFER_UNDERFLOW) {
335            int bytesRead = channel.read(inputBuffer);
336
337            if (bytesRead == 0 && !(sslEngine.getHandshakeStatus().equals(SSLEngineResult.HandshakeStatus.NEED_UNWRAP))) {
338                return 0;
339            }
340
341            if (bytesRead == -1) {
342                sslEngine.closeInbound();
343                if (inputBuffer.position() == 0 || status == SSLEngineResult.Status.BUFFER_UNDERFLOW) {
344                    return -1;
345                }
346            }
347        }
348
349        plain.clear();
350
351        inputBuffer.flip();
352        SSLEngineResult res;
353        do {
354            res = sslEngine.unwrap(inputBuffer, plain);
355        } while (res.getStatus() == SSLEngineResult.Status.OK && res.getHandshakeStatus() == SSLEngineResult.HandshakeStatus.NEED_UNWRAP
356                && res.bytesProduced() == 0);
357
358        if (res.getHandshakeStatus() == SSLEngineResult.HandshakeStatus.FINISHED) {
359            finishHandshake();
360        }
361
362        status = res.getStatus();
363        handshakeStatus = res.getHandshakeStatus();
364
365        // TODO deal with BUFFER_OVERFLOW
366
367        if (status == SSLEngineResult.Status.CLOSED) {
368            sslEngine.closeInbound();
369            return -1;
370        }
371
372        inputBuffer.compact();
373        plain.flip();
374
375        return plain.remaining();
376    }
377
378    protected void doHandshake() throws Exception {
379        handshakeInProgress = true;
380        Selector selector = null;
381        SelectionKey key = null;
382        boolean readable = true;
383        try {
384            while (true) {
385                HandshakeStatus handshakeStatus = sslEngine.getHandshakeStatus();
386                switch (handshakeStatus) {
387                    case NEED_UNWRAP:
388                        if (readable) {
389                            secureRead(ByteBuffer.allocate(sslSession.getApplicationBufferSize()));
390                        }
391                        if (this.status == SSLEngineResult.Status.BUFFER_UNDERFLOW) {
392                            long now = System.currentTimeMillis();
393                            if (selector == null) {
394                                selector = Selector.open();
395                                key = channel.register(selector, SelectionKey.OP_READ);
396                            } else {
397                                key.interestOps(SelectionKey.OP_READ);
398                            }
399                            int keyCount = selector.select(this.getSoTimeout());
400                            if (keyCount == 0 && this.getSoTimeout() > 0 && ((System.currentTimeMillis() - now) >= this.getSoTimeout())) {
401                                throw new SocketTimeoutException("Timeout during handshake");
402                            }
403                            readable = key.isReadable();
404                        }
405                        break;
406                    case NEED_TASK:
407                        Runnable task;
408                        while ((task = sslEngine.getDelegatedTask()) != null) {
409                            task.run();
410                        }
411                        break;
412                    case NEED_WRAP:
413                        ((NIOOutputStream) buffOut).write(ByteBuffer.allocate(0));
414                        break;
415                    case FINISHED:
416                    case NOT_HANDSHAKING:
417                        finishHandshake();
418                        return;
419                }
420            }
421        } finally {
422            if (key!=null) try {key.cancel();} catch (Exception ignore) {}
423            if (selector!=null) try {selector.close();} catch (Exception ignore) {}
424        }
425    }
426
427    @Override
428    protected void doStart() throws Exception {
429        taskRunnerFactory = new TaskRunnerFactory("ActiveMQ NIOSSLTransport Task");
430        // no need to init as we can delay that until demand (eg in doHandshake)
431        super.doStart();
432    }
433
434    @Override
435    protected void doStop(ServiceStopper stopper) throws Exception {
436        initialized.countDown();
437
438        if (taskRunnerFactory != null) {
439            taskRunnerFactory.shutdownNow();
440            taskRunnerFactory = null;
441        }
442        if (channel != null) {
443            channel.close();
444            channel = null;
445        }
446        super.doStop(stopper);
447    }
448
449    /**
450     * Overriding in order to add the client's certificates to ConnectionInfo Commands.
451     *
452     * @param command
453     *            The Command coming in.
454     */
455    @Override
456    public void doConsume(Object command) {
457        if (command instanceof ConnectionInfo) {
458            ConnectionInfo connectionInfo = (ConnectionInfo) command;
459            connectionInfo.setTransportContext(getPeerCertificates());
460        }
461        super.doConsume(command);
462    }
463
464    /**
465     * @return peer certificate chain associated with the ssl socket
466     */
467    public X509Certificate[] getPeerCertificates() {
468
469        X509Certificate[] clientCertChain = null;
470        try {
471            if (sslEngine.getSession() != null) {
472                clientCertChain = (X509Certificate[]) sslEngine.getSession().getPeerCertificates();
473            }
474        } catch (SSLPeerUnverifiedException e) {
475            if (LOG.isTraceEnabled()) {
476                LOG.trace("Failed to get peer certificates.", e);
477            }
478        }
479
480        return clientCertChain;
481    }
482
483    public boolean isNeedClientAuth() {
484        return needClientAuth;
485    }
486
487    public void setNeedClientAuth(boolean needClientAuth) {
488        this.needClientAuth = needClientAuth;
489    }
490
491    public boolean isWantClientAuth() {
492        return wantClientAuth;
493    }
494
495    public void setWantClientAuth(boolean wantClientAuth) {
496        this.wantClientAuth = wantClientAuth;
497    }
498
499    public String[] getEnabledCipherSuites() {
500        return enabledCipherSuites;
501    }
502
503    public void setEnabledCipherSuites(String[] enabledCipherSuites) {
504        this.enabledCipherSuites = enabledCipherSuites;
505    }
506
507    public String[] getEnabledProtocols() {
508        return enabledProtocols;
509    }
510
511    public void setEnabledProtocols(String[] enabledProtocols) {
512        this.enabledProtocols = enabledProtocols;
513    }
514
515    public boolean isVerifyHostName() {
516        return verifyHostName;
517    }
518
519    public void setVerifyHostName(boolean verifyHostName) {
520        this.verifyHostName = verifyHostName;
521    }
522}