001/**
002 * Licensed to the Apache Software Foundation (ASF) under one or more
003 * contributor license agreements.  See the NOTICE file distributed with
004 * this work for additional information regarding copyright ownership.
005 * The ASF licenses this file to You under the Apache License, Version 2.0
006 * (the "License"); you may not use this file except in compliance with
007 * the License.  You may obtain a copy of the License at
008 *
009 *      http://www.apache.org/licenses/LICENSE-2.0
010 *
011 * Unless required by applicable law or agreed to in writing, software
012 * distributed under the License is distributed on an "AS IS" BASIS,
013 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
014 * See the License for the specific language governing permissions and
015 * limitations under the License.
016 */
017package org.apache.activemq.transport.udp;
018
019import java.io.EOFException;
020import java.io.IOException;
021import java.net.BindException;
022import java.net.DatagramSocket;
023import java.net.InetAddress;
024import java.net.InetSocketAddress;
025import java.net.SocketAddress;
026import java.net.SocketException;
027import java.net.URI;
028import java.net.UnknownHostException;
029import java.nio.channels.AsynchronousCloseException;
030import java.nio.channels.DatagramChannel;
031
032import org.apache.activemq.Service;
033import org.apache.activemq.command.Command;
034import org.apache.activemq.command.Endpoint;
035import org.apache.activemq.openwire.OpenWireFormat;
036import org.apache.activemq.transport.Transport;
037import org.apache.activemq.transport.TransportThreadSupport;
038import org.apache.activemq.transport.reliable.ExceptionIfDroppedReplayStrategy;
039import org.apache.activemq.transport.reliable.ReplayBuffer;
040import org.apache.activemq.transport.reliable.ReplayStrategy;
041import org.apache.activemq.transport.reliable.Replayer;
042import org.apache.activemq.util.InetAddressUtil;
043import org.apache.activemq.util.IntSequenceGenerator;
044import org.apache.activemq.util.ServiceStopper;
045import org.slf4j.Logger;
046import org.slf4j.LoggerFactory;
047
048/**
049 * An implementation of the {@link Transport} interface using raw UDP
050 * 
051 * 
052 */
053public class UdpTransport extends TransportThreadSupport implements Transport, Service, Runnable {
054    private static final Logger LOG = LoggerFactory.getLogger(UdpTransport.class);
055
056    private static final int MAX_BIND_ATTEMPTS = 50;
057    private static final long BIND_ATTEMPT_DELAY = 100;
058
059    private CommandChannel commandChannel;
060    private OpenWireFormat wireFormat;
061    private ByteBufferPool bufferPool;
062    private ReplayStrategy replayStrategy = new ExceptionIfDroppedReplayStrategy();
063    private ReplayBuffer replayBuffer;
064    private int datagramSize = 4 * 1024;
065    private SocketAddress targetAddress;
066    private SocketAddress originalTargetAddress;
067    private DatagramChannel channel;
068    private boolean trace;
069    private boolean useLocalHost = false;
070    private int port;
071    private int minmumWireFormatVersion;
072    private String description;
073    private IntSequenceGenerator sequenceGenerator;
074    private boolean replayEnabled = true;
075
076    protected UdpTransport(OpenWireFormat wireFormat) throws IOException {
077        this.wireFormat = wireFormat;
078    }
079
080    public UdpTransport(OpenWireFormat wireFormat, URI remoteLocation) throws UnknownHostException, IOException {
081        this(wireFormat);
082        this.targetAddress = createAddress(remoteLocation);
083        description = remoteLocation.toString() + "@";
084    }
085
086    public UdpTransport(OpenWireFormat wireFormat, SocketAddress socketAddress) throws IOException {
087        this(wireFormat);
088        this.targetAddress = socketAddress;
089        this.description = getProtocolName() + "ServerConnection@";
090    }
091
092    /**
093     * Used by the server transport
094     */
095    public UdpTransport(OpenWireFormat wireFormat, int port) throws UnknownHostException, IOException {
096        this(wireFormat);
097        this.port = port;
098        this.targetAddress = null;
099        this.description = getProtocolName() + "Server@";
100    }
101
102    /**
103     * Creates a replayer for working with the reliable transport
104     */
105    public Replayer createReplayer() throws IOException {
106        if (replayEnabled) {
107            return getCommandChannel();
108        }
109        return null;
110    }
111
112    /**
113     * A one way asynchronous send
114     */
115    public void oneway(Object command) throws IOException {
116        oneway(command, targetAddress);
117    }
118
119    /**
120     * A one way asynchronous send to a given address
121     */
122    public void oneway(Object command, SocketAddress address) throws IOException {
123        if (LOG.isDebugEnabled()) {
124            LOG.debug("Sending oneway from: " + this + " to target: " + targetAddress + " command: " + command);
125        }
126        checkStarted();
127        commandChannel.write((Command)command, address);
128    }
129
130    /**
131     * @return pretty print of 'this'
132     */
133    public String toString() {
134        if (description != null) {
135            return description + port;
136        } else {
137            return getProtocolUriScheme() + targetAddress + "@" + port;
138        }
139    }
140
141    /**
142     * reads packets from a Socket
143     */
144    public void run() {
145        LOG.trace("Consumer thread starting for: " + toString());
146        while (!isStopped()) {
147            try {
148                Command command = commandChannel.read();
149                doConsume(command);
150            } catch (AsynchronousCloseException e) {
151                // DatagramChannel closed
152                try {
153                    stop();
154                } catch (Exception e2) {
155                    LOG.warn("Caught in: " + this + " while closing: " + e2 + ". Now Closed", e2);
156                }
157            } catch (SocketException e) {
158                // DatagramSocket closed
159                LOG.debug("Socket closed: " + e, e);
160                try {
161                    stop();
162                } catch (Exception e2) {
163                    LOG.warn("Caught in: " + this + " while closing: " + e2 + ". Now Closed", e2);
164                }
165            } catch (EOFException e) {
166                // DataInputStream closed
167                LOG.debug("Socket closed: " + e, e);
168                try {
169                    stop();
170                } catch (Exception e2) {
171                    LOG.warn("Caught in: " + this + " while closing: " + e2 + ". Now Closed", e2);
172                }
173            } catch (Exception e) {
174                try {
175                    stop();
176                } catch (Exception e2) {
177                    LOG.warn("Caught in: " + this + " while closing: " + e2 + ". Now Closed", e2);
178                }
179                if (e instanceof IOException) {
180                    onException((IOException)e);
181                } else {
182                    LOG.error("Caught: " + e, e);
183                    e.printStackTrace();
184                }
185            }
186        }
187    }
188
189    /**
190     * We have received the WireFormatInfo from the server on the actual channel
191     * we should use for all future communication with the server, so lets set
192     * the target to be the actual channel that the server has chosen for us to
193     * talk on.
194     */
195    public void setTargetEndpoint(Endpoint newTarget) {
196        if (newTarget instanceof DatagramEndpoint) {
197            DatagramEndpoint endpoint = (DatagramEndpoint)newTarget;
198            SocketAddress address = endpoint.getAddress();
199            if (address != null) {
200                if (originalTargetAddress == null) {
201                    originalTargetAddress = targetAddress;
202                }
203                targetAddress = address;
204                commandChannel.setTargetAddress(address);
205            }
206        }
207    }
208
209    // Properties
210    // -------------------------------------------------------------------------
211    public boolean isTrace() {
212        return trace;
213    }
214
215    public void setTrace(boolean trace) {
216        this.trace = trace;
217    }
218
219    public int getDatagramSize() {
220        return datagramSize;
221    }
222
223    public void setDatagramSize(int datagramSize) {
224        this.datagramSize = datagramSize;
225    }
226
227    public boolean isUseLocalHost() {
228        return useLocalHost;
229    }
230
231    /**
232     * Sets whether 'localhost' or the actual local host name should be used to
233     * make local connections. On some operating systems such as Macs its not
234     * possible to connect as the local host name so localhost is better.
235     */
236    public void setUseLocalHost(boolean useLocalHost) {
237        this.useLocalHost = useLocalHost;
238    }
239
240    public CommandChannel getCommandChannel() throws IOException {
241        if (commandChannel == null) {
242            commandChannel = createCommandChannel();
243        }
244        return commandChannel;
245    }
246
247    /**
248     * Sets the implementation of the command channel to use.
249     */
250    public void setCommandChannel(CommandDatagramChannel commandChannel) {
251        this.commandChannel = commandChannel;
252    }
253
254    public ReplayStrategy getReplayStrategy() {
255        return replayStrategy;
256    }
257
258    /**
259     * Sets the strategy used to replay missed datagrams
260     */
261    public void setReplayStrategy(ReplayStrategy replayStrategy) {
262        this.replayStrategy = replayStrategy;
263    }
264
265    public int getPort() {
266        return port;
267    }
268
269    /**
270     * Sets the port to connect on
271     */
272    public void setPort(int port) {
273        this.port = port;
274    }
275
276    public int getMinmumWireFormatVersion() {
277        return minmumWireFormatVersion;
278    }
279
280    public void setMinmumWireFormatVersion(int minmumWireFormatVersion) {
281        this.minmumWireFormatVersion = minmumWireFormatVersion;
282    }
283
284    public OpenWireFormat getWireFormat() {
285        return wireFormat;
286    }
287
288    public IntSequenceGenerator getSequenceGenerator() {
289        if (sequenceGenerator == null) {
290            sequenceGenerator = new IntSequenceGenerator();
291        }
292        return sequenceGenerator;
293    }
294
295    public void setSequenceGenerator(IntSequenceGenerator sequenceGenerator) {
296        this.sequenceGenerator = sequenceGenerator;
297    }
298
299    public boolean isReplayEnabled() {
300        return replayEnabled;
301    }
302
303    /**
304     * Sets whether or not replay should be enabled when using the reliable
305     * transport. i.e. should we maintain a buffer of messages that can be
306     * replayed?
307     */
308    public void setReplayEnabled(boolean replayEnabled) {
309        this.replayEnabled = replayEnabled;
310    }
311
312    public ByteBufferPool getBufferPool() {
313        if (bufferPool == null) {
314            bufferPool = new DefaultBufferPool();
315        }
316        return bufferPool;
317    }
318
319    public void setBufferPool(ByteBufferPool bufferPool) {
320        this.bufferPool = bufferPool;
321    }
322
323    public ReplayBuffer getReplayBuffer() {
324        return replayBuffer;
325    }
326
327    public void setReplayBuffer(ReplayBuffer replayBuffer) throws IOException {
328        this.replayBuffer = replayBuffer;
329        getCommandChannel().setReplayBuffer(replayBuffer);
330    }
331
332    // Implementation methods
333    // -------------------------------------------------------------------------
334
335    /**
336     * Creates an address from the given URI
337     */
338    protected InetSocketAddress createAddress(URI remoteLocation) throws UnknownHostException, IOException {
339        String host = resolveHostName(remoteLocation.getHost());
340        return new InetSocketAddress(host, remoteLocation.getPort());
341    }
342
343    protected String resolveHostName(String host) throws UnknownHostException {
344        String localName = InetAddressUtil.getLocalHostName();
345        if (localName != null && isUseLocalHost()) {
346            if (localName.equals(host)) {
347                return "localhost";
348            }
349        }
350        return host;
351    }
352
353    protected void doStart() throws Exception {
354        getCommandChannel().start();
355
356        super.doStart();
357    }
358
359    protected CommandChannel createCommandChannel() throws IOException {
360        SocketAddress localAddress = createLocalAddress();
361        channel = DatagramChannel.open();
362
363        channel = connect(channel, targetAddress);
364
365        DatagramSocket socket = channel.socket();
366        bind(socket, localAddress);
367        if (port == 0) {
368            port = socket.getLocalPort();
369        }
370
371        return createCommandDatagramChannel();
372    }
373
374    protected CommandChannel createCommandDatagramChannel() {
375        return new CommandDatagramChannel(this, getWireFormat(), getDatagramSize(), getTargetAddress(), createDatagramHeaderMarshaller(), getChannel(), getBufferPool());
376    }
377
378    protected void bind(DatagramSocket socket, SocketAddress localAddress) throws IOException {
379        channel.configureBlocking(true);
380
381        if (LOG.isDebugEnabled()) {
382            LOG.debug("Binding to address: " + localAddress);
383        }
384
385        //
386        // We have noticed that on some platfoms like linux, after you close
387        // down
388        // a previously bound socket, it can take a little while before we can
389        // bind it again.
390        // 
391        for (int i = 0; i < MAX_BIND_ATTEMPTS; i++) {
392            try {
393                socket.bind(localAddress);
394                return;
395            } catch (BindException e) {
396                if (i + 1 == MAX_BIND_ATTEMPTS) {
397                    throw e;
398                }
399                try {
400                    Thread.sleep(BIND_ATTEMPT_DELAY);
401                } catch (InterruptedException e1) {
402                    Thread.currentThread().interrupt();
403                    throw e;
404                }
405            }
406        }
407
408    }
409
410    protected DatagramChannel connect(DatagramChannel channel, SocketAddress targetAddress2) throws IOException {
411        // TODO
412        // connect to default target address to avoid security checks each time
413        // channel = channel.connect(targetAddress);
414
415        return channel;
416    }
417
418    protected SocketAddress createLocalAddress() {
419        return new InetSocketAddress(port);
420    }
421
422    protected void doStop(ServiceStopper stopper) throws Exception {
423        if (channel != null) {
424            channel.close();
425        }
426    }
427
428    protected DatagramHeaderMarshaller createDatagramHeaderMarshaller() {
429        return new DatagramHeaderMarshaller();
430    }
431
432    protected String getProtocolName() {
433        return "Udp";
434    }
435
436    protected String getProtocolUriScheme() {
437        return "udp://";
438    }
439
440    protected SocketAddress getTargetAddress() {
441        return targetAddress;
442    }
443
444    protected DatagramChannel getChannel() {
445        return channel;
446    }
447
448    protected void setChannel(DatagramChannel channel) {
449        this.channel = channel;
450    }
451
452    public InetSocketAddress getLocalSocketAddress() {
453        if (channel == null) {
454            return null;
455        } else {
456            return (InetSocketAddress)channel.socket().getLocalSocketAddress();
457        }
458    }
459
460    public String getRemoteAddress() {
461        if (targetAddress != null) {
462            return "" + targetAddress;
463        }
464        return null;
465    }
466
467    public int getReceiveCounter() {
468        if (commandChannel == null) {
469            return 0;
470        }
471        return commandChannel.getReceiveCounter();
472    }
473}