/**
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.activemq.transport.ws.jetty9;

import java.io.IOException;
import java.util.*;
import java.util.concurrent.ConcurrentHashMap;

import javax.servlet.ServletException;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;

import org.apache.activemq.jms.pool.IntrospectionSupport;
import org.apache.activemq.transport.Transport;
import org.apache.activemq.transport.TransportAcceptListener;
import org.apache.activemq.transport.util.HttpTransportUtils;
import org.eclipse.jetty.websocket.api.WebSocketListener;
import org.eclipse.jetty.websocket.servlet.ServletUpgradeRequest;
import org.eclipse.jetty.websocket.servlet.ServletUpgradeResponse;
import org.eclipse.jetty.websocket.servlet.WebSocketCreator;
import org.eclipse.jetty.websocket.servlet.WebSocketServlet;
import org.eclipse.jetty.websocket.servlet.WebSocketServletFactory;

/**
 * Handles connection upgrade requests and creates web sockets.
 * <p>
 * An upgrade request that offers at least one sub-protocol starting with
 * {@code "mqtt"} gets an {@link MQTTSocket}; every other request gets a
 * {@link StompSocket}. The created socket is handed to the
 * {@link TransportAcceptListener} found in the servlet context.
 */
public class WSServlet extends WebSocketServlet {

    private static final long serialVersionUID = -4716657876092884139L;

    /** Servlet context attribute under which the accept listener is looked up. */
    private static final String ACCEPT_LISTENER_ATTRIBUTE = "acceptListener";

    /** Listener notified of every socket created by this servlet; resolved in {@link #init()}. */
    private transient TransportAcceptListener listener;

    /** Supported STOMP sub-protocol names mapped to their priority (higher wins). */
    private static final Map<String, Integer> stompProtocols = new ConcurrentHashMap<>();

    /** Supported MQTT sub-protocol names mapped to their priority (higher wins). */
    private static final Map<String, Integer> mqttProtocols = new ConcurrentHashMap<>();

    /** Options copied into each MQTT socket; may be {@code null} if never set. */
    private Map<String, Object> transportOptions;

    static {
        stompProtocols.put("v12.stomp", 3);
        stompProtocols.put("v11.stomp", 2);
        stompProtocols.put("v10.stomp", 1);
        stompProtocols.put("stomp", 0);

        mqttProtocols.put("mqttv3.1", 1);
        mqttProtocols.put("mqtt", 0);
    }

    /**
     * Resolves the accept listener from the servlet context.
     *
     * @throws ServletException if the context has no {@code acceptListener} attribute
     */
    @Override
    public void init() throws ServletException {
        super.init();
        listener = (TransportAcceptListener) getServletContext().getAttribute(ACCEPT_LISTENER_ATTRIBUTE);
        if (listener == null) {
            throw new ServletException("No such attribute 'acceptListener' available in the ServletContext");
        }
    }

    /**
     * Forwards plain GET requests to the servlet registered under the name {@code "default"}.
     */
    @Override
    protected void doGet(HttpServletRequest request, HttpServletResponse response) throws ServletException, IOException {
        // NOTE(review): getNamedDispatcher returns null when no servlet named "default"
        // is registered - confirm the embedding container always provides one.
        getServletContext().getNamedDispatcher("default").forward(request, response);
    }

    /**
     * Installs the creator that builds an MQTT or STOMP socket for each upgrade request.
     */
    @Override
    public void configure(WebSocketServletFactory factory) {
        factory.setCreator(new WebSocketCreator() {
            @Override
            public Object createWebSocket(ServletUpgradeRequest req, ServletUpgradeResponse resp) {
                List<String> requested = req.getSubProtocols();
                WebSocketListener socket;
                if (offersMqtt(requested)) {
                    MQTTSocket mqttSocket =
                        new MQTTSocket(HttpTransportUtils.generateWsRemoteAddress(req.getHttpServletRequest()));
                    resp.setAcceptedSubProtocol(getAcceptedSubProtocol(mqttProtocols, requested, "mqtt"));
                    // Hand each socket its own copy so it cannot mutate the shared options;
                    // an unset option map is treated as "no options" instead of failing.
                    Map<String, Object> configured = transportOptions;
                    mqttSocket.setTransportOptions(configured == null
                        ? new HashMap<String, Object>()
                        : new HashMap<String, Object>(configured));
                    mqttSocket.setPeerCertificates(req.getCertificates());
                    socket = mqttSocket;
                } else {
                    StompSocket stompSocket =
                        new StompSocket(HttpTransportUtils.generateWsRemoteAddress(req.getHttpServletRequest()));
                    stompSocket.setCertificates(req.getCertificates());
                    resp.setAcceptedSubProtocol(getAcceptedSubProtocol(stompProtocols, requested, "stomp"));
                    socket = stompSocket;
                }
                listener.onAccept((Transport) socket);
                return socket;
            }
        });
    }

    /**
     * Tells whether any offered sub-protocol name starts with {@code "mqtt"}.
     *
     * @param subProtocols sub-protocols offered by the client; may be {@code null}
     * @return {@code true} if at least one non-null entry starts with {@code "mqtt"}
     */
    private static boolean offersMqtt(List<String> subProtocols) {
        if (subProtocols == null) {
            return false;
        }
        for (String subProtocol : subProtocols) {
            if (subProtocol != null && subProtocol.startsWith("mqtt")) {
                return true;
            }
        }
        return false;
    }

    /**
     * Picks the offered sub-protocol with the highest priority among the supported ones.
     *
     * @param protocols supported sub-protocol names mapped to their priority
     * @param subProtocols sub-protocols offered by the client; may be {@code null} or empty
     * @param defaultProtocol value returned when none of the offered sub-protocols is supported
     * @return the supported sub-protocol with the highest priority (the first one offered
     *         wins a tie), or {@code defaultProtocol} if there is no match
     */
    private String getAcceptedSubProtocol(final Map<String, Integer> protocols,
            List<String> subProtocols, String defaultProtocol) {
        if (subProtocols == null || subProtocols.isEmpty()) {
            return defaultProtocol;
        }
        String accepted = null;
        int acceptedPriority = Integer.MIN_VALUE;
        for (String subProtocol : subProtocols) {
            // The null test must precede the lookup: ConcurrentHashMap.get(null) throws.
            if (subProtocol == null) {
                continue;
            }
            Integer priority = protocols.get(subProtocol);
            if (priority != null && (accepted == null || priority > acceptedPriority)) {
                accepted = subProtocol;
                acceptedPriority = priority;
            }
        }
        return accepted != null ? accepted : defaultProtocol;
    }

    /**
     * Sets the options that are copied into every MQTT socket created afterwards.
     *
     * @param transportOptions option map; {@code null} is treated as no options
     */
    public void setTransportOptions(Map<String, Object> transportOptions) {
        this.transportOptions = transportOptions;
    }
}