001/**
002 * Licensed to the Apache Software Foundation (ASF) under one or more
003 * contributor license agreements.  See the NOTICE file distributed with
004 * this work for additional information regarding copyright ownership.
005 * The ASF licenses this file to You under the Apache License, Version 2.0
006 * (the "License"); you may not use this file except in compliance with
007 * the License.  You may obtain a copy of the License at
008 *
009 *      http://www.apache.org/licenses/LICENSE-2.0
010 *
011 * Unless required by applicable law or agreed to in writing, software
012 * distributed under the License is distributed on an "AS IS" BASIS,
013 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
014 * See the License for the specific language governing permissions and
015 * limitations under the License.
016 */
017package org.apache.activemq.transport.stomp;
018
019import java.io.DataInput;
020import java.io.DataInputStream;
021import java.io.DataOutput;
022import java.io.DataOutputStream;
023import java.io.IOException;
024import java.io.InputStream;
025import java.io.PushbackInputStream;
026import java.util.HashMap;
027import java.util.Map;
028import java.util.concurrent.atomic.AtomicLong;
029
030import org.apache.activemq.util.ByteArrayInputStream;
031import org.apache.activemq.util.ByteArrayOutputStream;
032import org.apache.activemq.util.ByteSequence;
033import org.apache.activemq.wireformat.WireFormat;
034
035/**
036 * Implements marshalling and unmarsalling the <a
037 * href="http://stomp.codehaus.org/">Stomp</a> protocol.
038 */
039public class StompWireFormat implements WireFormat {
040
041    private static final byte[] NO_DATA = new byte[] {};
042    private static final byte[] END_OF_FRAME = new byte[] {0, '\n'};
043
044    private static final int MAX_COMMAND_LENGTH = 1024;
045    private static final int MAX_HEADER_LENGTH = 1024 * 10;
046    private static final int MAX_HEADERS = 1000;
047    private static final int MAX_DATA_LENGTH = 1024 * 1024 * 100;
048
049    public static final long DEFAULT_MAX_FRAME_SIZE = Long.MAX_VALUE;
050    public static final long DEFAULT_CONNECTION_TIMEOUT = 30000;
051
052    private int version = 1;
053    private int maxDataLength = MAX_DATA_LENGTH;
054    private long maxFrameSize = DEFAULT_MAX_FRAME_SIZE;
055    private String stompVersion = Stomp.DEFAULT_VERSION;
056    private long connectionAttemptTimeout = DEFAULT_CONNECTION_TIMEOUT;
057
058    //The current frame size as it is unmarshalled from the stream
059    private final AtomicLong frameSize = new AtomicLong();
060
061    @Override
062    public ByteSequence marshal(Object command) throws IOException {
063        ByteArrayOutputStream baos = new ByteArrayOutputStream();
064        DataOutputStream dos = new DataOutputStream(baos);
065        marshal(command, dos);
066        dos.close();
067        return baos.toByteSequence();
068    }
069
070    @Override
071    public Object unmarshal(ByteSequence packet) throws IOException {
072        ByteArrayInputStream stream = new ByteArrayInputStream(packet);
073        DataInputStream dis = new DataInputStream(stream);
074        return unmarshal(dis);
075    }
076
077    @Override
078    public void marshal(Object command, DataOutput os) throws IOException {
079        StompFrame stomp = (org.apache.activemq.transport.stomp.StompFrame)command;
080
081        if (stomp.getAction().equals(Stomp.Commands.KEEPALIVE)) {
082            os.write(Stomp.BREAK);
083            return;
084        }
085
086        StringBuilder buffer = new StringBuilder();
087        buffer.append(stomp.getAction());
088        buffer.append(Stomp.NEWLINE);
089
090        // Output the headers.
091        for (Map.Entry<String, String> entry : stomp.getHeaders().entrySet()) {
092            buffer.append(entry.getKey());
093            buffer.append(Stomp.Headers.SEPERATOR);
094            buffer.append(encodeHeader(entry.getValue()));
095            buffer.append(Stomp.NEWLINE);
096        }
097
098        // Add a newline to seperate the headers from the content.
099        buffer.append(Stomp.NEWLINE);
100
101        os.write(buffer.toString().getBytes("UTF-8"));
102        os.write(stomp.getContent());
103        os.write(END_OF_FRAME);
104    }
105
106    @Override
107    public Object unmarshal(DataInput in) throws IOException {
108
109        try {
110
111            // parse action
112            String action = parseAction(in, frameSize);
113
114            // Parse the headers
115            HashMap<String, String> headers = parseHeaders(in, frameSize);
116
117            // Read in the data part.
118            byte[] data = NO_DATA;
119            String contentLength = headers.get(Stomp.Headers.CONTENT_LENGTH);
120            if ((action.equals(Stomp.Commands.SEND) || action.equals(Stomp.Responses.MESSAGE)) && contentLength != null) {
121
122                // Bless the client, he's telling us how much data to read in.
123                int length = parseContentLength(contentLength, frameSize);
124
125                data = new byte[length];
126                in.readFully(data);
127
128                if (in.readByte() != 0) {
129                    throw new ProtocolException(Stomp.Headers.CONTENT_LENGTH + " bytes were read and " + "there was no trailing null byte", true);
130                }
131
132            } else {
133
134                // We don't know how much to read.. data ends when we hit a 0
135                byte b;
136                ByteArrayOutputStream baos = null;
137                while ((b = in.readByte()) != 0) {
138                    if (baos == null) {
139                        baos = new ByteArrayOutputStream();
140                    } else if (baos.size() > getMaxDataLength()) {
141                        throw new ProtocolException("The maximum data length was exceeded", true);
142                    } else {
143                        if (frameSize.incrementAndGet() > getMaxFrameSize()) {
144                            throw new ProtocolException("The maximum frame size was exceeded", true);
145                        }
146                    }
147
148                    baos.write(b);
149                }
150
151                if (baos != null) {
152                    baos.close();
153                    data = baos.toByteArray();
154                }
155            }
156
157            return new StompFrame(action, headers, data);
158
159        } catch (ProtocolException e) {
160            return new StompFrameError(e);
161        } finally {
162            frameSize.set(0);
163        }
164    }
165
166    private String readLine(DataInput in, int maxLength, String errorMessage) throws IOException {
167        ByteSequence sequence = readHeaderLine(in, maxLength, errorMessage);
168        return new String(sequence.getData(), sequence.getOffset(), sequence.getLength(), "UTF-8").trim();
169    }
170
171    private ByteSequence readHeaderLine(DataInput in, int maxLength, String errorMessage) throws IOException {
172        byte b;
173        ByteArrayOutputStream baos = new ByteArrayOutputStream(maxLength);
174        while ((b = in.readByte()) != '\n') {
175            if (baos.size() > maxLength) {
176                baos.close();
177                throw new ProtocolException(errorMessage, true);
178            }
179            baos.write(b);
180        }
181
182        baos.close();
183        ByteSequence line = baos.toByteSequence();
184
185        if (stompVersion.equals(Stomp.V1_0) || stompVersion.equals(Stomp.V1_2)) {
186            int lineLength = line.getLength();
187            if (lineLength > 0 && line.data[lineLength-1] == '\r') {
188                line.setLength(lineLength-1);
189            }
190        }
191
192        return line;
193    }
194
195    protected String parseAction(DataInput in, AtomicLong frameSize) throws IOException {
196        String action = null;
197
198        // skip white space to next real action line
199        while (true) {
200            action = readLine(in, MAX_COMMAND_LENGTH, "The maximum command length was exceeded");
201            if (action == null) {
202                throw new IOException("connection was closed");
203            } else {
204                action = action.trim();
205                if (action.length() > 0) {
206                    break;
207                }
208            }
209        }
210        frameSize.addAndGet(action.length());
211        return action;
212    }
213
214    protected HashMap<String, String> parseHeaders(DataInput in, AtomicLong frameSize) throws IOException {
215        HashMap<String, String> headers = new HashMap<String, String>(25);
216        while (true) {
217            ByteSequence line = readHeaderLine(in, MAX_HEADER_LENGTH, "The maximum header length was exceeded");
218            if (line != null && line.length > 1) {
219
220                if (headers.size() > MAX_HEADERS) {
221                    throw new ProtocolException("The maximum number of headers was exceeded", true);
222                }
223                frameSize.addAndGet(line.length);
224
225                try {
226
227                    ByteArrayInputStream headerLine = new ByteArrayInputStream(line);
228                    ByteArrayOutputStream stream = new ByteArrayOutputStream(line.length);
229
230                    // First complete the name
231                    int result = -1;
232                    while ((result = headerLine.read()) != -1) {
233                        if (result != ':') {
234                            stream.write(result);
235                        } else {
236                            break;
237                        }
238                    }
239
240                    ByteSequence nameSeq = stream.toByteSequence();
241
242                    String name = new String(nameSeq.getData(), nameSeq.getOffset(), nameSeq.getLength(), "UTF-8");
243                    String value = decodeHeader(headerLine);
244                    if (stompVersion.equals(Stomp.V1_0)) {
245                        value = value.trim();
246                    }
247
248                    if (!headers.containsKey(name)) {
249                        headers.put(name, value);
250                    }
251
252                    stream.close();
253
254                } catch (Exception e) {
255                    throw new ProtocolException("Unable to parser header line [" + line + "]", true);
256                }
257            } else {
258                break;
259            }
260        }
261        return headers;
262    }
263
264    protected int parseContentLength(String contentLength, AtomicLong frameSize) throws ProtocolException {
265        int length;
266        try {
267            length = Integer.parseInt(contentLength.trim());
268        } catch (NumberFormatException e) {
269            throw new ProtocolException("Specified content-length is not a valid integer", true);
270        }
271
272        if (length > getMaxDataLength()) {
273            throw new ProtocolException("The maximum data length was exceeded", true);
274        }
275
276        if (frameSize.addAndGet(length) > getMaxFrameSize()) {
277            throw new ProtocolException("The maximum frame size was exceeded", true);
278        }
279
280        return length;
281    }
282
283    private String encodeHeader(String header) throws IOException {
284        String result = header;
285        if (!stompVersion.equals(Stomp.V1_0)) {
286            byte[] utf8buf = header.getBytes("UTF-8");
287            ByteArrayOutputStream stream = new ByteArrayOutputStream(utf8buf.length);
288            for(byte val : utf8buf) {
289                switch(val) {
290                case Stomp.ESCAPE:
291                    stream.write(Stomp.ESCAPE_ESCAPE_SEQ);
292                    break;
293                case Stomp.BREAK:
294                    stream.write(Stomp.NEWLINE_ESCAPE_SEQ);
295                    break;
296                case Stomp.COLON:
297                    stream.write(Stomp.COLON_ESCAPE_SEQ);
298                    break;
299                default:
300                    stream.write(val);
301                }
302            }
303            result =  new String(stream.toByteArray(), "UTF-8");
304            stream.close();
305        }
306
307        return result;
308    }
309
310    private String decodeHeader(InputStream header) throws IOException {
311        ByteArrayOutputStream decoded = new ByteArrayOutputStream();
312        PushbackInputStream stream = new PushbackInputStream(header);
313
314        int value = -1;
315        while( (value = stream.read()) != -1) {
316            if (value == 92) {
317
318                int next = stream.read();
319                if (next != -1) {
320                    switch(next) {
321                    case 110:
322                        decoded.write(Stomp.BREAK);
323                        break;
324                    case 99:
325                        decoded.write(Stomp.COLON);
326                        break;
327                    case 92:
328                        decoded.write(Stomp.ESCAPE);
329                        break;
330                    default:
331                        stream.unread(next);
332                        decoded.write(value);
333                    }
334                } else {
335                    decoded.write(value);
336                }
337
338            } else {
339                decoded.write(value);
340            }
341        }
342
343        decoded.close();
344
345        return new String(decoded.toByteArray(), "UTF-8");
346    }
347
348    @Override
349    public int getVersion() {
350        return version;
351    }
352
353    @Override
354    public void setVersion(int version) {
355        this.version = version;
356    }
357
358    public String getStompVersion() {
359        return stompVersion;
360    }
361
362    public void setStompVersion(String stompVersion) {
363        this.stompVersion = stompVersion;
364    }
365
366    public void setMaxDataLength(int maxDataLength) {
367        this.maxDataLength = maxDataLength;
368    }
369
370    public int getMaxDataLength() {
371        return maxDataLength;
372    }
373
374    public long getMaxFrameSize() {
375        return maxFrameSize;
376    }
377
378    public void setMaxFrameSize(long maxFrameSize) {
379        this.maxFrameSize = maxFrameSize;
380    }
381
382    public long getConnectionAttemptTimeout() {
383        return connectionAttemptTimeout;
384    }
385
386    public void setConnectionAttemptTimeout(long connectionAttemptTimeout) {
387        this.connectionAttemptTimeout = connectionAttemptTimeout;
388    }
389}