diff options
Diffstat (limited to 'pyload/lib/mod_pywebsocket/handshake/hybi00.py')
| -rw-r--r-- | pyload/lib/mod_pywebsocket/handshake/hybi00.py | 63 | 
1 files changed, 57 insertions, 6 deletions
| diff --git a/pyload/lib/mod_pywebsocket/handshake/hybi00.py b/pyload/lib/mod_pywebsocket/handshake/hybi00.py index cc6f8dc43..8757717a6 100644 --- a/pyload/lib/mod_pywebsocket/handshake/hybi00.py +++ b/pyload/lib/mod_pywebsocket/handshake/hybi00.py @@ -51,11 +51,12 @@ from mod_pywebsocket import common  from mod_pywebsocket.stream import StreamHixie75  from mod_pywebsocket import util  from mod_pywebsocket.handshake._base import HandshakeException -from mod_pywebsocket.handshake._base import build_location -from mod_pywebsocket.handshake._base import check_header_lines +from mod_pywebsocket.handshake._base import check_request_line  from mod_pywebsocket.handshake._base import format_header +from mod_pywebsocket.handshake._base import get_default_port  from mod_pywebsocket.handshake._base import get_mandatory_header -from mod_pywebsocket.handshake._base import validate_subprotocol +from mod_pywebsocket.handshake._base import parse_host_header +from mod_pywebsocket.handshake._base import validate_mandatory_header  _MANDATORY_HEADERS = [ @@ -65,6 +66,56 @@ _MANDATORY_HEADERS = [  ] +def _validate_subprotocol(subprotocol): +    """Checks if characters in subprotocol are in range between U+0020 and +    U+007E. A value in the Sec-WebSocket-Protocol field need to satisfy this +    requirement. + +    See the Section 4.1. Opening handshake of the spec. +    """ + +    if not subprotocol: +        raise HandshakeException('Invalid subprotocol name: empty') + +    # Parameter should be in the range U+0020 to U+007E. +    for c in subprotocol: +        if not 0x20 <= ord(c) <= 0x7e: +            raise HandshakeException( +                'Illegal character in subprotocol name: %r' % c) + + +def _check_header_lines(request, mandatory_headers): +    check_request_line(request) + +    # The expected field names, and the meaning of their corresponding +    # values, are as follows. +    #  |Upgrade| and |Connection| +    for key, expected_value in mandatory_headers: +        validate_mandatory_header(request, key, expected_value) + + +def _build_location(request): +    """Build WebSocket location for request.""" + +    location_parts = [] +    if request.is_https(): +        location_parts.append(common.WEB_SOCKET_SECURE_SCHEME) +    else: +        location_parts.append(common.WEB_SOCKET_SCHEME) +    location_parts.append('://') +    host, port = parse_host_header(request) +    connection_port = request.connection.local_addr[1] +    if port != connection_port: +        raise HandshakeException('Header/connection port mismatch: %d/%d' % +                                 (port, connection_port)) +    location_parts.append(host) +    if (port != get_default_port(request.is_https())): +        location_parts.append(':') +        location_parts.append(str(port)) +    location_parts.append(request.unparsed_uri) +    return ''.join(location_parts) + +  class Handshaker(object):      """Opening handshake processor for the WebSocket protocol version HyBi 00.      """ @@ -101,7 +152,7 @@ class Handshaker(object):          # 5.1 Reading the client's opening handshake.          # dispatcher sets it in self._request. -        check_header_lines(self._request, _MANDATORY_HEADERS) +        _check_header_lines(self._request, _MANDATORY_HEADERS)          self._set_resource()          self._set_subprotocol()          self._set_location() @@ -121,14 +172,14 @@ class Handshaker(object):          subprotocol = self._request.headers_in.get(              common.SEC_WEBSOCKET_PROTOCOL_HEADER)          if subprotocol is not None: -            validate_subprotocol(subprotocol, hixie=True) +            _validate_subprotocol(subprotocol)          self._request.ws_protocol = subprotocol      def _set_location(self):          # |Host|          host = self._request.headers_in.get(common.HOST_HEADER)          if host is not None: -            self._request.ws_location = build_location(self._request) +            self._request.ws_location = _build_location(self._request)          # TODO(ukai): check host is this host.      def _set_origin(self): | 
