summaryrefslogtreecommitdiffstats
path: root/pyload
diff options
context:
space:
mode:
authorGravatar RaNaN <Mast3rRaNaN@hotmail.de> 2013-12-15 14:36:38 +0100
committerGravatar RaNaN <Mast3rRaNaN@hotmail.de> 2013-12-15 14:36:38 +0100
commit367172406d1382f2fdd39936bde31ebcbcd1c56f (patch)
tree6abd7b2adffdd7fbb4eb35a9171c25b0e8eee4a0 /pyload
parentmore options to get webUI through proxy working (diff)
downloadpyload-367172406d1382f2fdd39936bde31ebcbcd1c56f.tar.xz
updated pywebsocket
Diffstat (limited to 'pyload')
-rw-r--r--pyload/lib/mod_pywebsocket/_stream_base.py46
-rw-r--r--pyload/lib/mod_pywebsocket/_stream_hybi.py24
-rw-r--r--pyload/lib/mod_pywebsocket/deflate_stream_extension.py69
-rw-r--r--pyload/lib/mod_pywebsocket/dispatch.py10
-rw-r--r--pyload/lib/mod_pywebsocket/extensions.py171
-rw-r--r--pyload/lib/mod_pywebsocket/handshake/_base.py72
-rw-r--r--pyload/lib/mod_pywebsocket/handshake/hybi.py64
-rw-r--r--pyload/lib/mod_pywebsocket/handshake/hybi00.py63
-rw-r--r--pyload/lib/mod_pywebsocket/headerparserhandler.py28
-rw-r--r--pyload/lib/mod_pywebsocket/mux.py470
-rwxr-xr-xpyload/lib/mod_pywebsocket/standalone.py249
-rw-r--r--pyload/lib/mod_pywebsocket/util.py43
-rw-r--r--pyload/remote/WebSocketBackend.py5
-rw-r--r--pyload/remote/wsbackend/Server.py263
14 files changed, 1153 insertions, 424 deletions
diff --git a/pyload/lib/mod_pywebsocket/_stream_base.py b/pyload/lib/mod_pywebsocket/_stream_base.py
index 60fb33d2c..8235666bb 100644
--- a/pyload/lib/mod_pywebsocket/_stream_base.py
+++ b/pyload/lib/mod_pywebsocket/_stream_base.py
@@ -39,6 +39,8 @@
# writing/reading.
+import socket
+
from mod_pywebsocket import util
@@ -109,20 +111,34 @@ class StreamBase(object):
ConnectionTerminatedException: when read returns empty string.
"""
- bytes = self._request.connection.read(length)
- if not bytes:
+ try:
+ read_bytes = self._request.connection.read(length)
+ if not read_bytes:
+ raise ConnectionTerminatedException(
+ 'Receiving %d byte failed. Peer (%r) closed connection' %
+ (length, (self._request.connection.remote_addr,)))
+ return read_bytes
+ except socket.error, e:
+ # Catch a socket.error. Because it's not a child class of the
+ # IOError prior to Python 2.6, we cannot omit this except clause.
+ # Use %s rather than %r for the exception to use human friendly
+ # format.
+ raise ConnectionTerminatedException(
+ 'Receiving %d byte failed. socket.error (%s) occurred' %
+ (length, e))
+ except IOError, e:
+ # Also catch an IOError because mod_python throws it.
raise ConnectionTerminatedException(
- 'Receiving %d byte failed. Peer (%r) closed connection' %
- (length, (self._request.connection.remote_addr,)))
- return bytes
+ 'Receiving %d byte failed. IOError (%s) occurred' %
+ (length, e))
- def _write(self, bytes):
+ def _write(self, bytes_to_write):
"""Writes given bytes to connection. In case we catch any exception,
prepends remote address to the exception message and raise again.
"""
try:
- self._request.connection.write(bytes)
+ self._request.connection.write(bytes_to_write)
except Exception, e:
util.prepend_message_to_exception(
'Failed to send message to %r: ' %
@@ -138,12 +154,12 @@ class StreamBase(object):
ConnectionTerminatedException: when read returns empty string.
"""
- bytes = []
+ read_bytes = []
while length > 0:
- new_bytes = self._read(length)
- bytes.append(new_bytes)
- length -= len(new_bytes)
- return ''.join(bytes)
+ new_read_bytes = self._read(length)
+ read_bytes.append(new_read_bytes)
+ length -= len(new_read_bytes)
+ return ''.join(read_bytes)
def _read_until(self, delim_char):
"""Reads bytes until we encounter delim_char. The result will not
@@ -153,13 +169,13 @@ class StreamBase(object):
ConnectionTerminatedException: when read returns empty string.
"""
- bytes = []
+ read_bytes = []
while True:
ch = self._read(1)
if ch == delim_char:
break
- bytes.append(ch)
- return ''.join(bytes)
+ read_bytes.append(ch)
+ return ''.join(read_bytes)
# vi:sts=4 sw=4 et
diff --git a/pyload/lib/mod_pywebsocket/_stream_hybi.py b/pyload/lib/mod_pywebsocket/_stream_hybi.py
index bd158fa6b..1c43249a4 100644
--- a/pyload/lib/mod_pywebsocket/_stream_hybi.py
+++ b/pyload/lib/mod_pywebsocket/_stream_hybi.py
@@ -280,7 +280,7 @@ def parse_frame(receive_bytes, logger=None,
if logger.isEnabledFor(common.LOGLEVEL_FINE):
unmask_start = time.time()
- bytes = masker.mask(raw_payload_bytes)
+ unmasked_bytes = masker.mask(raw_payload_bytes)
if logger.isEnabledFor(common.LOGLEVEL_FINE):
logger.log(
@@ -288,7 +288,7 @@ def parse_frame(receive_bytes, logger=None,
'Done unmasking payload data at %s MB/s',
payload_length / (time.time() - unmask_start) / 1000 / 1000)
- return opcode, bytes, fin, rsv1, rsv2, rsv3
+ return opcode, unmasked_bytes, fin, rsv1, rsv2, rsv3
class FragmentedFrameBuilder(object):
@@ -403,9 +403,6 @@ class StreamOptions(object):
self.encode_text_message_to_utf8 = True
self.mask_send = False
self.unmask_receive = True
- # RFC6455 disallows fragmented control frames, but mux extension
- # relaxes the restriction.
- self.allow_fragmented_control_frame = False
class Stream(StreamBase):
@@ -463,10 +460,10 @@ class Stream(StreamBase):
unmask_receive=self._options.unmask_receive)
def _receive_frame_as_frame_object(self):
- opcode, bytes, fin, rsv1, rsv2, rsv3 = self._receive_frame()
+ opcode, unmasked_bytes, fin, rsv1, rsv2, rsv3 = self._receive_frame()
return Frame(fin=fin, rsv1=rsv1, rsv2=rsv2, rsv3=rsv3,
- opcode=opcode, payload=bytes)
+ opcode=opcode, payload=unmasked_bytes)
def receive_filtered_frame(self):
"""Receives a frame and applies frame filters and message filters.
@@ -602,8 +599,7 @@ class Stream(StreamBase):
else:
# Start of fragmentation frame
- if (not self._options.allow_fragmented_control_frame and
- common.is_control_opcode(frame.opcode)):
+ if common.is_control_opcode(frame.opcode):
raise InvalidFrameException(
'Control frames must not be fragmented')
@@ -672,7 +668,7 @@ class Stream(StreamBase):
reason = ''
self._send_closing_handshake(code, reason)
self._logger.debug(
- 'Sent ack for client-initiated closing handshake '
+ 'Acknowledged closing handshake initiated by the peer '
'(code=%r, reason=%r)', code, reason)
def _process_ping_message(self, message):
@@ -815,13 +811,15 @@ class Stream(StreamBase):
self._write(frame)
- def close_connection(self, code=common.STATUS_NORMAL_CLOSURE, reason=''):
+ def close_connection(self, code=common.STATUS_NORMAL_CLOSURE, reason='',
+ wait_response=True):
"""Closes a WebSocket connection.
Args:
code: Status code for close frame. If code is None, a close
frame with empty body will be sent.
reason: string representing close reason.
+ wait_response: True when caller want to wait the response.
Raises:
BadOperationException: when reason is specified with code None
or reason is not an instance of both str and unicode.
@@ -844,11 +842,11 @@ class Stream(StreamBase):
self._send_closing_handshake(code, reason)
self._logger.debug(
- 'Sent server-initiated closing handshake (code=%r, reason=%r)',
+ 'Initiated closing handshake (code=%r, reason=%r)',
code, reason)
if (code == common.STATUS_GOING_AWAY or
- code == common.STATUS_PROTOCOL_ERROR):
+ code == common.STATUS_PROTOCOL_ERROR) or not wait_response:
# It doesn't make sense to wait for a close frame if the reason is
# protocol error or that the server is going away. For some of
# other reasons, it might not make sense to wait for a close frame,
diff --git a/pyload/lib/mod_pywebsocket/deflate_stream_extension.py b/pyload/lib/mod_pywebsocket/deflate_stream_extension.py
new file mode 100644
index 000000000..d2ba477c4
--- /dev/null
+++ b/pyload/lib/mod_pywebsocket/deflate_stream_extension.py
@@ -0,0 +1,69 @@
+# Copyright 2013, Google Inc.
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are
+# met:
+#
+# * Redistributions of source code must retain the above copyright
+# notice, this list of conditions and the following disclaimer.
+# * Redistributions in binary form must reproduce the above
+# copyright notice, this list of conditions and the following disclaimer
+# in the documentation and/or other materials provided with the
+# distribution.
+# * Neither the name of Google Inc. nor the names of its
+# contributors may be used to endorse or promote products derived from
+# this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+
+from mod_pywebsocket import common
+from mod_pywebsocket.extensions import _available_processors
+from mod_pywebsocket.extensions import ExtensionProcessorInterface
+from mod_pywebsocket import util
+
+
+class DeflateStreamExtensionProcessor(ExtensionProcessorInterface):
+ """WebSocket DEFLATE stream extension processor.
+
+ Specification:
+ Section 9.2.1 in
+ http://tools.ietf.org/html/draft-ietf-hybi-thewebsocketprotocol-10
+ """
+
+ def __init__(self, request):
+ ExtensionProcessorInterface.__init__(self, request)
+ self._logger = util.get_class_logger(self)
+
+ def name(self):
+ return common.DEFLATE_STREAM_EXTENSION
+
+ def _get_extension_response_internal(self):
+ if len(self._request.get_parameter_names()) != 0:
+ return None
+
+ self._logger.debug(
+ 'Enable %s extension', common.DEFLATE_STREAM_EXTENSION)
+
+ return common.ExtensionParameter(common.DEFLATE_STREAM_EXTENSION)
+
+ def _setup_stream_options_internal(self, stream_options):
+ stream_options.deflate_stream = True
+
+
+_available_processors[common.DEFLATE_STREAM_EXTENSION] = (
+ DeflateStreamExtensionProcessor)
+
+
+# vi:sts=4 sw=4 et
diff --git a/pyload/lib/mod_pywebsocket/dispatch.py b/pyload/lib/mod_pywebsocket/dispatch.py
index 25905f180..96c91e0c9 100644
--- a/pyload/lib/mod_pywebsocket/dispatch.py
+++ b/pyload/lib/mod_pywebsocket/dispatch.py
@@ -255,6 +255,9 @@ class Dispatcher(object):
try:
do_extra_handshake_(request)
except handshake.AbortedByUserException, e:
+ # Re-raise to tell the caller of this function to finish this
+ # connection without sending any error.
+ self._logger.debug('%s', util.get_stack_trace())
raise
except Exception, e:
util.prepend_message_to_exception(
@@ -294,11 +297,12 @@ class Dispatcher(object):
request.ws_stream.close_connection()
# Catch non-critical exceptions the handler didn't handle.
except handshake.AbortedByUserException, e:
- self._logger.debug('%s', e)
+ self._logger.debug('%s', util.get_stack_trace())
raise
except msgutil.BadOperationException, e:
self._logger.debug('%s', e)
- request.ws_stream.close_connection(common.STATUS_ABNORMAL_CLOSURE)
+ request.ws_stream.close_connection(
+ common.STATUS_INTERNAL_ENDPOINT_ERROR)
except msgutil.InvalidFrameException, e:
# InvalidFrameException must be caught before
# ConnectionTerminatedException that catches InvalidFrameException.
@@ -314,6 +318,8 @@ class Dispatcher(object):
except msgutil.ConnectionTerminatedException, e:
self._logger.debug('%s', e)
except Exception, e:
+ # Any other exceptions are forwarded to the caller of this
+ # function.
util.prepend_message_to_exception(
'%s raised exception for %s: ' % (
_TRANSFER_DATA_HANDLER_NAME, request.ws_resource),
diff --git a/pyload/lib/mod_pywebsocket/extensions.py b/pyload/lib/mod_pywebsocket/extensions.py
index 03dbf9ee1..18841ed92 100644
--- a/pyload/lib/mod_pywebsocket/extensions.py
+++ b/pyload/lib/mod_pywebsocket/extensions.py
@@ -38,47 +38,42 @@ _available_processors = {}
class ExtensionProcessorInterface(object):
- def name(self):
- return None
+ def __init__(self, request):
+ self._request = request
+ self._active = True
- def get_extension_response(self):
+ def request(self):
+ return self._request
+
+ def name(self):
return None
- def setup_stream_options(self, stream_options):
+ def check_consistency_with_other_processors(self, processors):
pass
+ def set_active(self, active):
+ self._active = active
-class DeflateStreamExtensionProcessor(ExtensionProcessorInterface):
- """WebSocket DEFLATE stream extension processor.
-
- Specification:
- Section 9.2.1 in
- http://tools.ietf.org/html/draft-ietf-hybi-thewebsocketprotocol-10
- """
-
- def __init__(self, request):
- self._logger = util.get_class_logger(self)
-
- self._request = request
+ def is_active(self):
+ return self._active
- def name(self):
- return common.DEFLATE_STREAM_EXTENSION
+ def _get_extension_response_internal(self):
+ return None
def get_extension_response(self):
- if len(self._request.get_parameter_names()) != 0:
- return None
-
- self._logger.debug(
- 'Enable %s extension', common.DEFLATE_STREAM_EXTENSION)
+ if self._active:
+ response = self._get_extension_response_internal()
+ if response is None:
+ self._active = False
+ return response
+ return None
- return common.ExtensionParameter(common.DEFLATE_STREAM_EXTENSION)
+ def _setup_stream_options_internal(self, stream_options):
+ pass
def setup_stream_options(self, stream_options):
- stream_options.deflate_stream = True
-
-
-_available_processors[common.DEFLATE_STREAM_EXTENSION] = (
- DeflateStreamExtensionProcessor)
+ if self._active:
+ self._setup_stream_options_internal(stream_options)
def _log_compression_ratio(logger, original_bytes, total_original_bytes,
@@ -109,6 +104,17 @@ def _log_decompression_ratio(logger, received_bytes, total_received_bytes,
(ratio, average_ratio))
+def _validate_window_bits(bits):
+ if bits is not None:
+ try:
+ bits = int(bits)
+ except ValueError, e:
+ return False
+ if bits < 8 or bits > 15:
+ return False
+ return True
+
+
class DeflateFrameExtensionProcessor(ExtensionProcessorInterface):
"""WebSocket Per-frame DEFLATE extension processor.
@@ -120,10 +126,9 @@ class DeflateFrameExtensionProcessor(ExtensionProcessorInterface):
_NO_CONTEXT_TAKEOVER_PARAM = 'no_context_takeover'
def __init__(self, request):
+ ExtensionProcessorInterface.__init__(self, request)
self._logger = util.get_class_logger(self)
- self._request = request
-
self._response_window_bits = None
self._response_no_context_takeover = False
self._bfinal = False
@@ -143,7 +148,7 @@ class DeflateFrameExtensionProcessor(ExtensionProcessorInterface):
def name(self):
return common.DEFLATE_FRAME_EXTENSION
- def get_extension_response(self):
+ def _get_extension_response_internal(self):
# Any unknown parameter will be just ignored.
window_bits = self._request.get_parameter_value(
@@ -155,13 +160,8 @@ class DeflateFrameExtensionProcessor(ExtensionProcessorInterface):
self._NO_CONTEXT_TAKEOVER_PARAM) is not None):
return None
- if window_bits is not None:
- try:
- window_bits = int(window_bits)
- except ValueError, e:
- return None
- if window_bits < 8 or window_bits > 15:
- return None
+ if not _validate_window_bits(window_bits):
+ return None
self._deflater = util._RFC1979Deflater(
window_bits, no_context_takeover)
@@ -191,7 +191,7 @@ class DeflateFrameExtensionProcessor(ExtensionProcessorInterface):
return response
- def setup_stream_options(self, stream_options):
+ def _setup_stream_options_internal(self, stream_options):
class _OutgoingFilter(object):
@@ -311,8 +311,8 @@ class CompressionExtensionProcessorBase(ExtensionProcessorInterface):
_METHOD_PARAM = 'method'
def __init__(self, request):
+ ExtensionProcessorInterface.__init__(self, request)
self._logger = util.get_class_logger(self)
- self._request = request
self._compression_method_name = None
self._compression_processor = None
self._compression_processor_hook = None
@@ -357,7 +357,7 @@ class CompressionExtensionProcessorBase(ExtensionProcessorInterface):
self._compression_processor = compression_processor
return processor_response
- def get_extension_response(self):
+ def _get_extension_response_internal(self):
processor_response = self._get_compression_processor_response()
if processor_response is None:
return None
@@ -372,7 +372,7 @@ class CompressionExtensionProcessorBase(ExtensionProcessorInterface):
(self._request.name(), self._compression_method_name))
return response
- def setup_stream_options(self, stream_options):
+ def _setup_stream_options_internal(self, stream_options):
if self._compression_processor is None:
return
self._compression_processor.setup_stream_options(stream_options)
@@ -418,7 +418,7 @@ class DeflateMessageProcessor(ExtensionProcessorInterface):
_C2S_NO_CONTEXT_TAKEOVER_PARAM = 'c2s_no_context_takeover'
def __init__(self, request):
- self._request = request
+ ExtensionProcessorInterface.__init__(self, request)
self._logger = util.get_class_logger(self)
self._c2s_max_window_bits = None
@@ -445,18 +445,13 @@ class DeflateMessageProcessor(ExtensionProcessorInterface):
def name(self):
return 'deflate'
- def get_extension_response(self):
+ def _get_extension_response_internal(self):
# Any unknown parameter will be just ignored.
s2c_max_window_bits = self._request.get_parameter_value(
self._S2C_MAX_WINDOW_BITS_PARAM)
- if s2c_max_window_bits is not None:
- try:
- s2c_max_window_bits = int(s2c_max_window_bits)
- except ValueError, e:
- return None
- if s2c_max_window_bits < 8 or s2c_max_window_bits > 15:
- return None
+ if not _validate_window_bits(s2c_max_window_bits):
+ return None
s2c_no_context_takeover = self._request.has_parameter(
self._S2C_NO_CONTEXT_TAKEOVER_PARAM)
@@ -502,7 +497,7 @@ class DeflateMessageProcessor(ExtensionProcessorInterface):
return response
- def setup_stream_options(self, stream_options):
+ def _setup_stream_options_internal(self, stream_options):
class _OutgoingMessageFilter(object):
def __init__(self, parent):
@@ -676,42 +671,72 @@ class MuxExtensionProcessor(ExtensionProcessorInterface):
_QUOTA_PARAM = 'quota'
def __init__(self, request):
- self._request = request
+ ExtensionProcessorInterface.__init__(self, request)
+ self._quota = 0
+ self._extensions = []
def name(self):
return common.MUX_EXTENSION
- def get_extension_response(self, ws_request,
- logical_channel_extensions):
- # Mux extension cannot be used after extensions that depend on
- # frame boundary, extension data field, or any reserved bits
- # which are attributed to each frame.
- for extension in logical_channel_extensions:
- name = extension.name()
- if (name == common.PERFRAME_COMPRESSION_EXTENSION or
- name == common.DEFLATE_FRAME_EXTENSION or
- name == common.X_WEBKIT_DEFLATE_FRAME_EXTENSION):
- return None
-
+ def check_consistency_with_other_processors(self, processors):
+ before_mux = True
+ for processor in processors:
+ name = processor.name()
+ if name == self.name():
+ before_mux = False
+ continue
+ if not processor.is_active():
+ continue
+ if before_mux:
+ # Mux extension cannot be used after extensions
+ # that depend on frame boundary, extension data field, or any
+ # reserved bits which are attributed to each frame.
+ if (name == common.PERFRAME_COMPRESSION_EXTENSION or
+ name == common.DEFLATE_FRAME_EXTENSION or
+ name == common.X_WEBKIT_DEFLATE_FRAME_EXTENSION):
+ self.set_active(False)
+ return
+ else:
+ # Mux extension should not be applied before any history-based
+ # compression extension.
+ if (name == common.PERFRAME_COMPRESSION_EXTENSION or
+ name == common.DEFLATE_FRAME_EXTENSION or
+ name == common.X_WEBKIT_DEFLATE_FRAME_EXTENSION or
+ name == common.PERMESSAGE_COMPRESSION_EXTENSION or
+ name == common.X_WEBKIT_PERMESSAGE_COMPRESSION_EXTENSION):
+ self.set_active(False)
+ return
+
+ def _get_extension_response_internal(self):
+ self._active = False
quota = self._request.get_parameter_value(self._QUOTA_PARAM)
- if quota is None:
- ws_request.mux_quota = 0
- else:
+ if quota is not None:
try:
quota = int(quota)
except ValueError, e:
return None
if quota < 0 or quota >= 2 ** 32:
return None
- ws_request.mux_quota = quota
+ self._quota = quota
- ws_request.mux = True
- ws_request.mux_extensions = logical_channel_extensions
+ self._active = True
return common.ExtensionParameter(common.MUX_EXTENSION)
- def setup_stream_options(self, stream_options):
+ def _setup_stream_options_internal(self, stream_options):
pass
+ def set_quota(self, quota):
+ self._quota = quota
+
+ def quota(self):
+ return self._quota
+
+ def set_extensions(self, extensions):
+ self._extensions = extensions
+
+ def extensions(self):
+ return self._extensions
+
_available_processors[common.MUX_EXTENSION] = MuxExtensionProcessor
diff --git a/pyload/lib/mod_pywebsocket/handshake/_base.py b/pyload/lib/mod_pywebsocket/handshake/_base.py
index e5c94ca90..c993a584b 100644
--- a/pyload/lib/mod_pywebsocket/handshake/_base.py
+++ b/pyload/lib/mod_pywebsocket/handshake/_base.py
@@ -84,42 +84,29 @@ def get_default_port(is_secure):
return common.DEFAULT_WEB_SOCKET_PORT
-def validate_subprotocol(subprotocol, hixie):
+def validate_subprotocol(subprotocol):
"""Validate a value in the Sec-WebSocket-Protocol field.
- See
- - RFC 6455: Section 4.1., 4.2.2., and 4.3.
- - HyBi 00: Section 4.1. Opening handshake
-
- Args:
- hixie: if True, checks if characters in subprotocol are in range
- between U+0020 and U+007E. It's required by HyBi 00 but not by
- RFC 6455.
+ See the Section 4.1., 4.2.2., and 4.3. of RFC 6455.
"""
if not subprotocol:
raise HandshakeException('Invalid subprotocol name: empty')
- if hixie:
- # Parameter should be in the range U+0020 to U+007E.
- for c in subprotocol:
- if not 0x20 <= ord(c) <= 0x7e:
- raise HandshakeException(
- 'Illegal character in subprotocol name: %r' % c)
- else:
- # Parameter should be encoded HTTP token.
- state = http_header_util.ParsingState(subprotocol)
- token = http_header_util.consume_token(state)
- rest = http_header_util.peek(state)
- # If |rest| is not None, |subprotocol| is not one token or invalid. If
- # |rest| is None, |token| must not be None because |subprotocol| is
- # concatenation of |token| and |rest| and is not None.
- if rest is not None:
- raise HandshakeException('Invalid non-token string in subprotocol '
- 'name: %r' % rest)
+
+ # Parameter should be encoded HTTP token.
+ state = http_header_util.ParsingState(subprotocol)
+ token = http_header_util.consume_token(state)
+ rest = http_header_util.peek(state)
+ # If |rest| is not None, |subprotocol| is not one token or invalid. If
+ # |rest| is None, |token| must not be None because |subprotocol| is
+ # concatenation of |token| and |rest| and is not None.
+ if rest is not None:
+ raise HandshakeException('Invalid non-token string in subprotocol '
+ 'name: %r' % rest)
def parse_host_header(request):
- fields = request.headers_in['Host'].split(':', 1)
+ fields = request.headers_in[common.HOST_HEADER].split(':', 1)
if len(fields) == 1:
return fields[0], get_default_port(request.is_https())
try:
@@ -132,27 +119,6 @@ def format_header(name, value):
return '%s: %s\r\n' % (name, value)
-def build_location(request):
- """Build WebSocket location for request."""
- location_parts = []
- if request.is_https():
- location_parts.append(common.WEB_SOCKET_SECURE_SCHEME)
- else:
- location_parts.append(common.WEB_SOCKET_SCHEME)
- location_parts.append('://')
- host, port = parse_host_header(request)
- connection_port = request.connection.local_addr[1]
- if port != connection_port:
- raise HandshakeException('Header/connection port mismatch: %d/%d' %
- (port, connection_port))
- location_parts.append(host)
- if (port != get_default_port(request.is_https())):
- location_parts.append(':')
- location_parts.append(str(port))
- location_parts.append(request.uri)
- return ''.join(location_parts)
-
-
def get_mandatory_header(request, key):
value = request.headers_in.get(key)
if value is None:
@@ -180,16 +146,6 @@ def check_request_line(request):
request.protocol)
-def check_header_lines(request, mandatory_headers):
- check_request_line(request)
-
- # The expected field names, and the meaning of their corresponding
- # values, are as follows.
- # |Upgrade| and |Connection|
- for key, expected_value in mandatory_headers:
- validate_mandatory_header(request, key, expected_value)
-
-
def parse_token_list(data):
"""Parses a header value which follows 1#token and returns parsed elements
as a list of strings.
diff --git a/pyload/lib/mod_pywebsocket/handshake/hybi.py b/pyload/lib/mod_pywebsocket/handshake/hybi.py
index fc0e2a096..669097d77 100644
--- a/pyload/lib/mod_pywebsocket/handshake/hybi.py
+++ b/pyload/lib/mod_pywebsocket/handshake/hybi.py
@@ -48,6 +48,7 @@ import os
import re
from mod_pywebsocket import common
+from mod_pywebsocket import deflate_stream_extension
from mod_pywebsocket.extensions import get_extension_processor
from mod_pywebsocket.handshake._base import check_request_line
from mod_pywebsocket.handshake._base import format_header
@@ -180,44 +181,57 @@ class Handshaker(object):
processors.append(processor)
self._request.ws_extension_processors = processors
+ # List of extra headers. The extra handshake handler may add header
+ # data as name/value pairs to this list and pywebsocket appends
+ # them to the WebSocket handshake.
+ self._request.extra_headers = []
+
# Extra handshake handler may modify/remove processors.
self._dispatcher.do_extra_handshake(self._request)
processors = filter(lambda processor: processor is not None,
self._request.ws_extension_processors)
+ # Ask each processor if there are extensions on the request which
+ # cannot co-exist. When processor decided other processors cannot
+ # co-exist with it, the processor marks them (or itself) as
+ # "inactive". The first extension processor has the right to
+ # make the final call.
+ for processor in reversed(processors):
+ if processor.is_active():
+ processor.check_consistency_with_other_processors(
+ processors)
+ processors = filter(lambda processor: processor.is_active(),
+ processors)
+
accepted_extensions = []
- # We need to take care of mux extension here. Extensions that
- # are placed before mux should be applied to logical channels.
+ # We need to take into account of mux extension here.
+ # If mux extension exists:
+ # - Remove processors of extensions for logical channel,
+ # which are processors located before the mux processor
+ # - Pass extension requests for logical channel to mux processor
+ # - Attach the mux processor to the request. It will be referred
+ # by dispatcher to see whether the dispatcher should use mux
+ # handler or not.
mux_index = -1
for i, processor in enumerate(processors):
if processor.name() == common.MUX_EXTENSION:
mux_index = i
break
if mux_index >= 0:
- mux_processor = processors[mux_index]
- logical_channel_processors = processors[:mux_index]
- processors = processors[mux_index+1:]
-
- for processor in logical_channel_processors:
- extension_response = processor.get_extension_response()
- if extension_response is None:
- # Rejected.
- continue
- accepted_extensions.append(extension_response)
- # Pass a shallow copy of accepted_extensions as extensions for
- # logical channels.
- mux_response = mux_processor.get_extension_response(
- self._request, accepted_extensions[:])
- if mux_response is not None:
- accepted_extensions.append(mux_response)
+ logical_channel_extensions = []
+ for processor in processors[:mux_index]:
+ logical_channel_extensions.append(processor.request())
+ processor.set_active(False)
+ self._request.mux_processor = processors[mux_index]
+ self._request.mux_processor.set_extensions(
+ logical_channel_extensions)
+ processors = filter(lambda processor: processor.is_active(),
+ processors)
stream_options = StreamOptions()
- # When there is mux extension, here, |processors| contain only
- # prosessors for extensions placed after mux.
for processor in processors:
-
extension_response = processor.get_extension_response()
if extension_response is None:
# Rejected.
@@ -242,7 +256,7 @@ class Handshaker(object):
raise HandshakeException(
'do_extra_handshake must choose one subprotocol from '
'ws_requested_protocols and set it to ws_protocol')
- validate_subprotocol(self._request.ws_protocol, hixie=False)
+ validate_subprotocol(self._request.ws_protocol)
self._logger.debug(
'Subprotocol accepted: %r',
@@ -375,6 +389,7 @@ class Handshaker(object):
response.append('HTTP/1.1 101 Switching Protocols\r\n')
+ # WebSocket headers
response.append(format_header(
common.UPGRADE_HEADER, common.WEBSOCKET_UPGRADE_TYPE))
response.append(format_header(
@@ -390,6 +405,11 @@ class Handshaker(object):
response.append(format_header(
common.SEC_WEBSOCKET_EXTENSIONS_HEADER,
common.format_extensions(self._request.ws_extensions)))
+
+ # Headers not specific for WebSocket
+ for name, value in self._request.extra_headers:
+ response.append(format_header(name, value))
+
response.append('\r\n')
return ''.join(response)
diff --git a/pyload/lib/mod_pywebsocket/handshake/hybi00.py b/pyload/lib/mod_pywebsocket/handshake/hybi00.py
index cc6f8dc43..8757717a6 100644
--- a/pyload/lib/mod_pywebsocket/handshake/hybi00.py
+++ b/pyload/lib/mod_pywebsocket/handshake/hybi00.py
@@ -51,11 +51,12 @@ from mod_pywebsocket import common
from mod_pywebsocket.stream import StreamHixie75
from mod_pywebsocket import util
from mod_pywebsocket.handshake._base import HandshakeException
-from mod_pywebsocket.handshake._base import build_location
-from mod_pywebsocket.handshake._base import check_header_lines
+from mod_pywebsocket.handshake._base import check_request_line
from mod_pywebsocket.handshake._base import format_header
+from mod_pywebsocket.handshake._base import get_default_port
from mod_pywebsocket.handshake._base import get_mandatory_header
-from mod_pywebsocket.handshake._base import validate_subprotocol
+from mod_pywebsocket.handshake._base import parse_host_header
+from mod_pywebsocket.handshake._base import validate_mandatory_header
_MANDATORY_HEADERS = [
@@ -65,6 +66,56 @@ _MANDATORY_HEADERS = [
]
+def _validate_subprotocol(subprotocol):
+ """Checks if characters in subprotocol are in range between U+0020 and
+ U+007E. A value in the Sec-WebSocket-Protocol field need to satisfy this
+ requirement.
+
+ See the Section 4.1. Opening handshake of the spec.
+ """
+
+ if not subprotocol:
+ raise HandshakeException('Invalid subprotocol name: empty')
+
+ # Parameter should be in the range U+0020 to U+007E.
+ for c in subprotocol:
+ if not 0x20 <= ord(c) <= 0x7e:
+ raise HandshakeException(
+ 'Illegal character in subprotocol name: %r' % c)
+
+
+def _check_header_lines(request, mandatory_headers):
+ check_request_line(request)
+
+ # The expected field names, and the meaning of their corresponding
+ # values, are as follows.
+ # |Upgrade| and |Connection|
+ for key, expected_value in mandatory_headers:
+ validate_mandatory_header(request, key, expected_value)
+
+
+def _build_location(request):
+ """Build WebSocket location for request."""
+
+ location_parts = []
+ if request.is_https():
+ location_parts.append(common.WEB_SOCKET_SECURE_SCHEME)
+ else:
+ location_parts.append(common.WEB_SOCKET_SCHEME)
+ location_parts.append('://')
+ host, port = parse_host_header(request)
+ connection_port = request.connection.local_addr[1]
+ if port != connection_port:
+ raise HandshakeException('Header/connection port mismatch: %d/%d' %
+ (port, connection_port))
+ location_parts.append(host)
+ if (port != get_default_port(request.is_https())):
+ location_parts.append(':')
+ location_parts.append(str(port))
+ location_parts.append(request.unparsed_uri)
+ return ''.join(location_parts)
+
+
class Handshaker(object):
"""Opening handshake processor for the WebSocket protocol version HyBi 00.
"""
@@ -101,7 +152,7 @@ class Handshaker(object):
# 5.1 Reading the client's opening handshake.
# dispatcher sets it in self._request.
- check_header_lines(self._request, _MANDATORY_HEADERS)
+ _check_header_lines(self._request, _MANDATORY_HEADERS)
self._set_resource()
self._set_subprotocol()
self._set_location()
@@ -121,14 +172,14 @@ class Handshaker(object):
subprotocol = self._request.headers_in.get(
common.SEC_WEBSOCKET_PROTOCOL_HEADER)
if subprotocol is not None:
- validate_subprotocol(subprotocol, hixie=True)
+ _validate_subprotocol(subprotocol)
self._request.ws_protocol = subprotocol
def _set_location(self):
# |Host|
host = self._request.headers_in.get(common.HOST_HEADER)
if host is not None:
- self._request.ws_location = build_location(self._request)
+ self._request.ws_location = _build_location(self._request)
# TODO(ukai): check host is this host.
def _set_origin(self):
diff --git a/pyload/lib/mod_pywebsocket/headerparserhandler.py b/pyload/lib/mod_pywebsocket/headerparserhandler.py
index 2cc62de04..c244421cf 100644
--- a/pyload/lib/mod_pywebsocket/headerparserhandler.py
+++ b/pyload/lib/mod_pywebsocket/headerparserhandler.py
@@ -167,7 +167,9 @@ def _create_dispatcher():
handler_root, handler_scan, allow_handlers_outside_root)
for warning in dispatcher.source_warnings():
- apache.log_error('mod_pywebsocket: %s' % warning, apache.APLOG_WARNING)
+ apache.log_error(
+ 'mod_pywebsocket: Warning in source loading: %s' % warning,
+ apache.APLOG_WARNING)
return dispatcher
@@ -191,12 +193,16 @@ def headerparserhandler(request):
# Fallback to default http handler for request paths for which
# we don't have request handlers.
if not _dispatcher.get_handler_suite(request.uri):
- request.log_error('No handler for resource: %r' % request.uri,
- apache.APLOG_INFO)
- request.log_error('Fallback to Apache', apache.APLOG_INFO)
+ request.log_error(
+ 'mod_pywebsocket: No handler for resource: %r' % request.uri,
+ apache.APLOG_INFO)
+ request.log_error(
+ 'mod_pywebsocket: Fallback to Apache', apache.APLOG_INFO)
return apache.DECLINED
except dispatch.DispatchException, e:
- request.log_error('mod_pywebsocket: %s' % e, apache.APLOG_INFO)
+ request.log_error(
+ 'mod_pywebsocket: Dispatch failed for error: %s' % e,
+ apache.APLOG_INFO)
if not handshake_is_done:
return e.status
@@ -210,26 +216,30 @@ def headerparserhandler(request):
handshake.do_handshake(
request, _dispatcher, allowDraft75=allow_draft75)
except handshake.VersionException, e:
- request.log_error('mod_pywebsocket: %s' % e, apache.APLOG_INFO)
+ request.log_error(
+ 'mod_pywebsocket: Handshake failed for version error: %s' % e,
+ apache.APLOG_INFO)
request.err_headers_out.add(common.SEC_WEBSOCKET_VERSION_HEADER,
e.supported_versions)
return apache.HTTP_BAD_REQUEST
except handshake.HandshakeException, e:
# Handshake for ws/wss failed.
# Send http response with error status.
- request.log_error('mod_pywebsocket: %s' % e, apache.APLOG_INFO)
+ request.log_error(
+ 'mod_pywebsocket: Handshake failed for error: %s' % e,
+ apache.APLOG_INFO)
return e.status
handshake_is_done = True
request._dispatcher = _dispatcher
_dispatcher.transfer_data(request)
except handshake.AbortedByUserException, e:
- request.log_error('mod_pywebsocket: %s' % e, apache.APLOG_INFO)
+ request.log_error('mod_pywebsocket: Aborted: %s' % e, apache.APLOG_INFO)
except Exception, e:
# DispatchException can also be thrown if something is wrong in
# pywebsocket code. It's caught here, then.
- request.log_error('mod_pywebsocket: %s\n%s' %
+ request.log_error('mod_pywebsocket: Exception occurred: %s\n%s' %
(e, util.get_stack_trace()),
apache.APLOG_ERR)
# Unknown exceptions before handshake mean Apache must handle its
diff --git a/pyload/lib/mod_pywebsocket/mux.py b/pyload/lib/mod_pywebsocket/mux.py
index f0bdd2461..7923fb211 100644
--- a/pyload/lib/mod_pywebsocket/mux.py
+++ b/pyload/lib/mod_pywebsocket/mux.py
@@ -50,6 +50,7 @@ from mod_pywebsocket import handshake
from mod_pywebsocket import util
from mod_pywebsocket._stream_base import BadOperationException
from mod_pywebsocket._stream_base import ConnectionTerminatedException
+from mod_pywebsocket._stream_base import InvalidFrameException
from mod_pywebsocket._stream_hybi import Frame
from mod_pywebsocket._stream_hybi import Stream
from mod_pywebsocket._stream_hybi import StreamOptions
@@ -94,10 +95,12 @@ _DROP_CODE_UNKNOWN_MUX_OPCODE = 2004
_DROP_CODE_INVALID_MUX_CONTROL_BLOCK = 2005
_DROP_CODE_CHANNEL_ALREADY_EXISTS = 2006
_DROP_CODE_NEW_CHANNEL_SLOT_VIOLATION = 2007
+_DROP_CODE_UNKNOWN_REQUEST_ENCODING = 2010
-_DROP_CODE_UNKNOWN_REQUEST_ENCODING = 3002
_DROP_CODE_SEND_QUOTA_VIOLATION = 3005
+_DROP_CODE_SEND_QUOTA_OVERFLOW = 3006
_DROP_CODE_ACKNOWLEDGED = 3008
+_DROP_CODE_BAD_FRAGMENTATION = 3009
class MuxUnexpectedException(Exception):
@@ -158,8 +161,7 @@ def _encode_number(number):
def _create_add_channel_response(channel_id, encoded_handshake,
- encoding=0, rejected=False,
- outer_frame_mask=False):
+ encoding=0, rejected=False):
if encoding != 0 and encoding != 1:
raise ValueError('Invalid encoding %d' % encoding)
@@ -169,12 +171,10 @@ def _create_add_channel_response(channel_id, encoded_handshake,
_encode_channel_id(channel_id) +
_encode_number(len(encoded_handshake)) +
encoded_handshake)
- payload = _encode_channel_id(_CONTROL_CHANNEL_ID) + block
- return create_binary_frame(payload, mask=outer_frame_mask)
+ return block
-def _create_drop_channel(channel_id, code=None, message='',
- outer_frame_mask=False):
+def _create_drop_channel(channel_id, code=None, message=''):
if len(message) > 0 and code is None:
raise ValueError('Code must be specified if message is specified')
@@ -187,36 +187,31 @@ def _create_drop_channel(channel_id, code=None, message='',
reason_size = _encode_number(len(reason))
block += reason_size + reason
- payload = _encode_channel_id(_CONTROL_CHANNEL_ID) + block
- return create_binary_frame(payload, mask=outer_frame_mask)
+ return block
-def _create_flow_control(channel_id, replenished_quota,
- outer_frame_mask=False):
+def _create_flow_control(channel_id, replenished_quota):
first_byte = _MUX_OPCODE_FLOW_CONTROL << 5
block = (chr(first_byte) +
_encode_channel_id(channel_id) +
_encode_number(replenished_quota))
- payload = _encode_channel_id(_CONTROL_CHANNEL_ID) + block
- return create_binary_frame(payload, mask=outer_frame_mask)
+ return block
-def _create_new_channel_slot(slots, send_quota, outer_frame_mask=False):
+def _create_new_channel_slot(slots, send_quota):
if slots < 0 or send_quota < 0:
raise ValueError('slots and send_quota must be non-negative.')
first_byte = _MUX_OPCODE_NEW_CHANNEL_SLOT << 5
block = (chr(first_byte) +
_encode_number(slots) +
_encode_number(send_quota))
- payload = _encode_channel_id(_CONTROL_CHANNEL_ID) + block
- return create_binary_frame(payload, mask=outer_frame_mask)
+ return block
-def _create_fallback_new_channel_slot(outer_frame_mask=False):
+def _create_fallback_new_channel_slot():
first_byte = (_MUX_OPCODE_NEW_CHANNEL_SLOT << 5) | 1 # Set the F flag
block = (chr(first_byte) + _encode_number(0) + _encode_number(0))
- payload = _encode_channel_id(_CONTROL_CHANNEL_ID) + block
- return create_binary_frame(payload, mask=outer_frame_mask)
+ return block
def _parse_request_text(request_text):
@@ -318,44 +313,34 @@ class _MuxFramePayloadParser(object):
def _read_number(self):
if self._read_position + 1 > len(self._data):
- raise PhysicalConnectionError(
- _DROP_CODE_INVALID_MUX_CONTROL_BLOCK,
+ raise ValueError(
'Cannot read the first byte of number field')
number = ord(self._data[self._read_position])
if number & 0x80 == 0x80:
- raise PhysicalConnectionError(
- _DROP_CODE_INVALID_MUX_CONTROL_BLOCK,
+ raise ValueError(
'The most significant bit of the first byte of number should '
'be unset')
self._read_position += 1
pos = self._read_position
if number == 127:
if pos + 8 > len(self._data):
- raise PhysicalConnectionError(
- _DROP_CODE_INVALID_MUX_CONTROL_BLOCK,
- 'Invalid number field')
+ raise ValueError('Invalid number field')
self._read_position += 8
number = struct.unpack('!Q', self._data[pos:pos+8])[0]
if number > 0x7FFFFFFFFFFFFFFF:
- raise PhysicalConnectionError(
- _DROP_CODE_INVALID_MUX_CONTROL_BLOCK,
- 'Encoded number >= 2^63')
+ raise ValueError('Encoded number(%d) >= 2^63' % number)
if number <= 0xFFFF:
- raise PhysicalConnectionError(
- _DROP_CODE_INVALID_MUX_CONTROL_BLOCK,
+ raise ValueError(
'%d should not be encoded by 9 bytes encoding' % number)
return number
if number == 126:
if pos + 2 > len(self._data):
- raise PhysicalConnectionError(
- _DROP_CODE_INVALID_MUX_CONTROL_BLOCK,
- 'Invalid number field')
+ raise ValueError('Invalid number field')
self._read_position += 2
number = struct.unpack('!H', self._data[pos:pos+2])[0]
if number <= 125:
- raise PhysicalConnectionError(
- _DROP_CODE_INVALID_MUX_CONTROL_BLOCK,
+ raise ValueError(
'%d should not be encoded by 3 bytes encoding' % number)
return number
@@ -366,7 +351,11 @@ class _MuxFramePayloadParser(object):
- the contents.
"""
- size = self._read_number()
+ try:
+ size = self._read_number()
+ except ValueError, e:
+ raise PhysicalConnectionError(_DROP_CODE_INVALID_MUX_CONTROL_BLOCK,
+ str(e))
pos = self._read_position
if pos + size > len(self._data):
raise PhysicalConnectionError(
@@ -419,9 +408,11 @@ class _MuxFramePayloadParser(object):
try:
control_block.channel_id = self.read_channel_id()
+ control_block.send_quota = self._read_number()
except ValueError, e:
- raise PhysicalConnectionError(_DROP_CODE_INVALID_MUX_CONTROL_BLOCK)
- control_block.send_quota = self._read_number()
+ raise PhysicalConnectionError(_DROP_CODE_INVALID_MUX_CONTROL_BLOCK,
+ str(e))
+
return control_block
def _read_drop_channel(self, first_byte, control_block):
@@ -455,8 +446,12 @@ class _MuxFramePayloadParser(object):
_DROP_CODE_INVALID_MUX_CONTROL_BLOCK,
'Reserved bits must be unset')
control_block.fallback = first_byte & 1
- control_block.slots = self._read_number()
- control_block.send_quota = self._read_number()
+ try:
+ control_block.slots = self._read_number()
+ control_block.send_quota = self._read_number()
+ except ValueError, e:
+ raise PhysicalConnectionError(_DROP_CODE_INVALID_MUX_CONTROL_BLOCK,
+ str(e))
return control_block
def read_control_blocks(self):
@@ -549,8 +544,12 @@ class _LogicalConnection(object):
self._mux_handler = mux_handler
self._channel_id = channel_id
self._incoming_data = ''
+
+ # - Protects _waiting_write_completion
+ # - Signals the thread waiting for completion of write by mux handler
self._write_condition = threading.Condition()
self._waiting_write_completion = False
+
self._read_condition = threading.Condition()
self._read_state = self.STATE_ACTIVE
@@ -594,6 +593,7 @@ class _LogicalConnection(object):
self._waiting_write_completion = True
self._mux_handler.send_data(self._channel_id, data)
self._write_condition.wait()
+ # TODO(tyoshino): Raise an exception if woke up by on_writer_done.
finally:
self._write_condition.release()
@@ -607,20 +607,31 @@ class _LogicalConnection(object):
self._mux_handler.send_control_data(data)
- def notify_write_done(self):
+ def on_write_data_done(self):
"""Called when sending data is completed."""
try:
self._write_condition.acquire()
if not self._waiting_write_completion:
raise MuxUnexpectedException(
- 'Invalid call of notify_write_done for logical connection'
- ' %d' % self._channel_id)
+ 'Invalid call of on_write_data_done for logical '
+ 'connection %d' % self._channel_id)
+ self._waiting_write_completion = False
+ self._write_condition.notify()
+ finally:
+ self._write_condition.release()
+
+ def on_writer_done(self):
+ """Called by the mux handler when the writer thread has finished."""
+
+ try:
+ self._write_condition.acquire()
self._waiting_write_completion = False
self._write_condition.notify()
finally:
self._write_condition.release()
+
def append_frame_data(self, frame_data):
"""Appends incoming frame data. Called when mux_handler dispatches
frame data to the corresponding application.
@@ -686,37 +697,162 @@ class _LogicalConnection(object):
self._read_condition.release()
+class _InnerMessage(object):
+ """Holds the result of _InnerMessageBuilder.build().
+ """
+
+ def __init__(self, opcode, payload):
+ self.opcode = opcode
+ self.payload = payload
+
+
+class _InnerMessageBuilder(object):
+ """A class that holds the context of inner message fragmentation and
+ builds a message from fragmented inner frame(s).
+ """
+
+ def __init__(self):
+ self._control_opcode = None
+ self._pending_control_fragments = []
+ self._message_opcode = None
+ self._pending_message_fragments = []
+ self._frame_handler = self._handle_first
+
+ def _handle_first(self, frame):
+ if frame.opcode == common.OPCODE_CONTINUATION:
+ raise InvalidFrameException('Sending invalid continuation opcode')
+
+ if common.is_control_opcode(frame.opcode):
+ return self._process_first_fragmented_control(frame)
+ else:
+ return self._process_first_fragmented_message(frame)
+
+ def _process_first_fragmented_control(self, frame):
+ self._control_opcode = frame.opcode
+ self._pending_control_fragments.append(frame.payload)
+ if not frame.fin:
+ self._frame_handler = self._handle_fragmented_control
+ return None
+ return self._reassemble_fragmented_control()
+
+ def _process_first_fragmented_message(self, frame):
+ self._message_opcode = frame.opcode
+ self._pending_message_fragments.append(frame.payload)
+ if not frame.fin:
+ self._frame_handler = self._handle_fragmented_message
+ return None
+ return self._reassemble_fragmented_message()
+
+ def _handle_fragmented_control(self, frame):
+ if frame.opcode != common.OPCODE_CONTINUATION:
+ raise InvalidFrameException(
+ 'Sending invalid opcode %d while sending fragmented control '
+ 'message' % frame.opcode)
+ self._pending_control_fragments.append(frame.payload)
+ if not frame.fin:
+ return None
+ return self._reassemble_fragmented_control()
+
+ def _reassemble_fragmented_control(self):
+ opcode = self._control_opcode
+ payload = ''.join(self._pending_control_fragments)
+ self._control_opcode = None
+ self._pending_control_fragments = []
+ if self._message_opcode is not None:
+ self._frame_handler = self._handle_fragmented_message
+ else:
+ self._frame_handler = self._handle_first
+ return _InnerMessage(opcode, payload)
+
+ def _handle_fragmented_message(self, frame):
+ # Sender can interleave a control message while sending fragmented
+ # messages.
+ if common.is_control_opcode(frame.opcode):
+ if self._control_opcode is not None:
+ raise MuxUnexpectedException(
+ 'Should not reach here(Bug in builder)')
+ return self._process_first_fragmented_control(frame)
+
+ if frame.opcode != common.OPCODE_CONTINUATION:
+ raise InvalidFrameException(
+ 'Sending invalid opcode %d while sending fragmented message' %
+ frame.opcode)
+ self._pending_message_fragments.append(frame.payload)
+ if not frame.fin:
+ return None
+ return self._reassemble_fragmented_message()
+
+ def _reassemble_fragmented_message(self):
+ opcode = self._message_opcode
+ payload = ''.join(self._pending_message_fragments)
+ self._message_opcode = None
+ self._pending_message_fragments = []
+ self._frame_handler = self._handle_first
+ return _InnerMessage(opcode, payload)
+
+ def build(self, frame):
+ """Build an inner message. Returns an _InnerMessage instance when
+ the given frame is the last fragmented frame. Returns None otherwise.
+
+ Args:
+ frame: an inner frame.
+ Raises:
+ InvalidFrameException: when received invalid opcode. (e.g.
+ receiving non continuation data opcode but the fin flag of
+ the previous inner frame was not set.)
+ """
+
+ return self._frame_handler(frame)
+
+
class _LogicalStream(Stream):
"""Mimics the Stream class. This class interprets multiplexed WebSocket
frames.
"""
- def __init__(self, request, send_quota, receive_quota):
+ def __init__(self, request, stream_options, send_quota, receive_quota):
"""Constructs an instance.
Args:
request: _LogicalRequest instance.
+ stream_options: StreamOptions instance.
send_quota: Initial send quota.
receive_quota: Initial receive quota.
"""
- # TODO(bashi): Support frame filters.
- stream_options = StreamOptions()
# Physical stream is responsible for masking.
stream_options.unmask_receive = False
- # Control frames can be fragmented on logical channel.
- stream_options.allow_fragmented_control_frame = True
Stream.__init__(self, request, stream_options)
+
+ self._send_closed = False
self._send_quota = send_quota
- self._send_quota_condition = threading.Condition()
+ # - Protects _send_closed and _send_quota
+ # - Signals the thread waiting for send quota replenished
+ self._send_condition = threading.Condition()
+
+ # The opcode of the first frame in messages.
+ self._message_opcode = common.OPCODE_TEXT
+ # True when the last message was fragmented.
+ self._last_message_was_fragmented = False
+
self._receive_quota = receive_quota
self._write_inner_frame_semaphore = threading.Semaphore()
+ self._inner_message_builder = _InnerMessageBuilder()
+
def _create_inner_frame(self, opcode, payload, end=True):
- # TODO(bashi): Support extensions that use reserved bits.
- first_byte = (end << 7) | opcode
- return (_encode_channel_id(self._request.channel_id) +
- chr(first_byte) + payload)
+ frame = Frame(fin=end, opcode=opcode, payload=payload)
+ for frame_filter in self._options.outgoing_frame_filters:
+ frame_filter.filter(frame)
+
+ if len(payload) != len(frame.payload):
+ raise MuxUnexpectedException(
+ 'Mux extension must not be used after extensions which change '
+ ' frame boundary')
+
+ first_byte = ((frame.fin << 7) | (frame.rsv1 << 6) |
+ (frame.rsv2 << 5) | (frame.rsv3 << 4) | frame.opcode)
+ return chr(first_byte) + frame.payload
def _write_inner_frame(self, opcode, payload, end=True):
payload_length = len(payload)
@@ -730,14 +866,36 @@ class _LogicalStream(Stream):
# multiplexing control blocks can be inserted between fragmented
# inner frames on the physical channel.
self._write_inner_frame_semaphore.acquire()
+
+ # Consume an octet quota when this is the first fragmented frame.
+ if opcode != common.OPCODE_CONTINUATION:
+ try:
+ self._send_condition.acquire()
+ while (not self._send_closed) and self._send_quota == 0:
+ self._send_condition.wait()
+
+ if self._send_closed:
+ raise BadOperationException(
+ 'Logical connection %d is closed' %
+ self._request.channel_id)
+
+ self._send_quota -= 1
+ finally:
+ self._send_condition.release()
+
while write_position < payload_length:
try:
- self._send_quota_condition.acquire()
- while self._send_quota == 0:
+ self._send_condition.acquire()
+ while (not self._send_closed) and self._send_quota == 0:
self._logger.debug(
'No quota. Waiting FlowControl message for %d.' %
self._request.channel_id)
- self._send_quota_condition.wait()
+ self._send_condition.wait()
+
+ if self._send_closed:
+ raise BadOperationException(
+ 'Logical connection %d is closed' %
+ self.request._channel_id)
remaining = payload_length - write_position
write_length = min(self._send_quota, remaining)
@@ -749,18 +907,16 @@ class _LogicalStream(Stream):
opcode,
payload[write_position:write_position+write_length],
inner_frame_end)
- frame_data = self._writer.build(
- inner_frame, end=True, binary=True)
self._send_quota -= write_length
self._logger.debug('Consumed quota=%d, remaining=%d' %
(write_length, self._send_quota))
finally:
- self._send_quota_condition.release()
+ self._send_condition.release()
# Writing data will block the worker so we need to release
- # _send_quota_condition before writing.
- self._logger.debug('Sending inner frame: %r' % frame_data)
- self._request.connection.write(frame_data)
+ # _send_condition before writing.
+ self._logger.debug('Sending inner frame: %r' % inner_frame)
+ self._request.connection.write(inner_frame)
write_position += write_length
opcode = common.OPCODE_CONTINUATION
@@ -773,12 +929,18 @@ class _LogicalStream(Stream):
def replenish_send_quota(self, send_quota):
"""Replenish send quota."""
- self._send_quota_condition.acquire()
- self._send_quota += send_quota
- self._logger.debug('Replenished send quota for channel id %d: %d' %
- (self._request.channel_id, self._send_quota))
- self._send_quota_condition.notify()
- self._send_quota_condition.release()
+ try:
+ self._send_condition.acquire()
+ if self._send_quota + send_quota > 0x7FFFFFFFFFFFFFFF:
+ self._send_quota = 0
+ raise LogicalChannelError(
+ self._request.channel_id, _DROP_CODE_SEND_QUOTA_OVERFLOW)
+ self._send_quota += send_quota
+ self._logger.debug('Replenished send quota for channel id %d: %d' %
+ (self._request.channel_id, self._send_quota))
+ finally:
+ self._send_condition.notify()
+ self._send_condition.release()
def consume_receive_quota(self, amount):
"""Consumes receive quota. Returns False on failure."""
@@ -808,7 +970,19 @@ class _LogicalStream(Stream):
opcode = common.OPCODE_TEXT
message = message.encode('utf-8')
+ for message_filter in self._options.outgoing_message_filters:
+ message = message_filter.filter(message, end, binary)
+
+ if self._last_message_was_fragmented:
+ if opcode != self._message_opcode:
+ raise BadOperationException('Message types are different in '
+ 'frames for the same message')
+ opcode = common.OPCODE_CONTINUATION
+ else:
+ self._message_opcode = opcode
+
self._write_inner_frame(opcode, message, end)
+ self._last_message_was_fragmented = not end
def _receive_frame(self):
"""Overrides Stream._receive_frame.
@@ -821,6 +995,9 @@ class _LogicalStream(Stream):
opcode, payload, fin, rsv1, rsv2, rsv3 = Stream._receive_frame(self)
amount = len(payload)
+ # Replenish extra one octet when receiving the first fragmented frame.
+ if opcode != common.OPCODE_CONTINUATION:
+ amount += 1
self._receive_quota += amount
frame_data = _create_flow_control(self._request.channel_id,
amount)
@@ -829,6 +1006,21 @@ class _LogicalStream(Stream):
self._request.connection.write_control_data(frame_data)
return opcode, payload, fin, rsv1, rsv2, rsv3
+ def _get_message_from_frame(self, frame):
+ """Overrides Stream._get_message_from_frame.
+ """
+
+ try:
+ inner_message = self._inner_message_builder.build(frame)
+ except InvalidFrameException:
+ raise LogicalChannelError(
+ self._request.channel_id, _DROP_CODE_BAD_FRAGMENTATION)
+
+ if inner_message is None:
+ return None
+ self._original_opcode = inner_message.opcode
+ return inner_message.payload
+
def receive_message(self):
"""Overrides Stream.receive_message."""
@@ -882,6 +1074,14 @@ class _LogicalStream(Stream):
pass
+ def stop_sending(self):
+ """Stops accepting new send operation (_write_inner_frame)."""
+
+ self._send_condition.acquire()
+ self._send_closed = True
+ self._send_condition.notify()
+ self._send_condition.release()
+
class _OutgoingData(object):
"""A structure that holds data to be sent via physical connection and
@@ -911,8 +1111,17 @@ class _PhysicalConnectionWriter(threading.Thread):
self._logger = util.get_class_logger(self)
self._mux_handler = mux_handler
self.setDaemon(True)
+
+ # When set, make this thread stop accepting new data, flush pending
+ # data and exit.
self._stop_requested = False
+ # The close code of the physical connection.
+ self._close_code = common.STATUS_NORMAL_CLOSURE
+ # Deque for passing write data. It's protected by _deque_condition
+ # until _stop_requested is set.
self._deque = collections.deque()
+ # - Protects _deque, _stop_requested and _close_code
+ # - Signals threads waiting for them to be available
self._deque_condition = threading.Condition()
def put_outgoing_data(self, data):
@@ -937,8 +1146,11 @@ class _PhysicalConnectionWriter(threading.Thread):
self._deque_condition.release()
def _write_data(self, outgoing_data):
+ message = (_encode_channel_id(outgoing_data.channel_id) +
+ outgoing_data.data)
try:
- self._mux_handler.physical_connection.write(outgoing_data.data)
+ self._mux_handler.physical_stream.send_message(
+ message=message, end=True, binary=True)
except Exception, e:
util.prepend_message_to_exception(
'Failed to send message to %r: ' %
@@ -948,33 +1160,51 @@ class _PhysicalConnectionWriter(threading.Thread):
# TODO(bashi): It would be better to block the thread that sends
# control data as well.
if outgoing_data.channel_id != _CONTROL_CHANNEL_ID:
- self._mux_handler.notify_write_done(outgoing_data.channel_id)
+ self._mux_handler.notify_write_data_done(outgoing_data.channel_id)
def run(self):
- self._deque_condition.acquire()
- while not self._stop_requested:
- if len(self._deque) == 0:
- self._deque_condition.wait()
- continue
-
- outgoing_data = self._deque.popleft()
- self._deque_condition.release()
- self._write_data(outgoing_data)
+ try:
self._deque_condition.acquire()
+ while not self._stop_requested:
+ if len(self._deque) == 0:
+ self._deque_condition.wait()
+ continue
- # Flush deque
- try:
- while len(self._deque) > 0:
outgoing_data = self._deque.popleft()
+
+ self._deque_condition.release()
self._write_data(outgoing_data)
+ self._deque_condition.acquire()
+
+ # Flush deque.
+ #
+ # At this point, self._deque_condition is always acquired.
+ try:
+ while len(self._deque) > 0:
+ outgoing_data = self._deque.popleft()
+ self._write_data(outgoing_data)
+ finally:
+ self._deque_condition.release()
+
+ # Close physical connection.
+ try:
+ # Don't wait the response here. The response will be read
+ # by the reader thread.
+ self._mux_handler.physical_stream.close_connection(
+ self._close_code, wait_response=False)
+ except Exception, e:
+ util.prepend_message_to_exception(
+ 'Failed to close the physical connection: %r' % e)
+ raise
finally:
- self._deque_condition.release()
+ self._mux_handler.notify_writer_done()
- def stop(self):
+ def stop(self, close_code=common.STATUS_NORMAL_CLOSURE):
"""Stops the writer thread."""
self._deque_condition.acquire()
self._stop_requested = True
+ self._close_code = close_code
self._deque_condition.notify()
self._deque_condition.release()
@@ -1055,6 +1285,9 @@ class _Worker(threading.Thread):
try:
# Non-critical exceptions will be handled by dispatcher.
self._mux_handler.dispatcher.transfer_data(self._request)
+ except LogicalChannelError, e:
+ self._mux_handler.fail_logical_channel(
+ e.channel_id, e.drop_code, e.message)
finally:
self._mux_handler.notify_worker_done(self._request.channel_id)
@@ -1083,8 +1316,6 @@ class _MuxHandshaker(hybi.Handshaker):
# these headers are included already.
request.headers_in[common.UPGRADE_HEADER] = (
common.WEBSOCKET_UPGRADE_TYPE)
- request.headers_in[common.CONNECTION_HEADER] = (
- common.UPGRADE_CONNECTION_TYPE)
request.headers_in[common.SEC_WEBSOCKET_VERSION_HEADER] = (
str(common.VERSION_HYBI_LATEST))
request.headers_in[common.SEC_WEBSOCKET_KEY_HEADER] = (
@@ -1095,8 +1326,9 @@ class _MuxHandshaker(hybi.Handshaker):
self._logger.debug('Creating logical stream for %d' %
self._request.channel_id)
- return _LogicalStream(self._request, self._send_quota,
- self._receive_quota)
+ return _LogicalStream(
+ self._request, stream_options, self._send_quota,
+ self._receive_quota)
def _create_handshake_response(self, accept):
"""Override hybi._create_handshake_response."""
@@ -1105,7 +1337,9 @@ class _MuxHandshaker(hybi.Handshaker):
response.append('HTTP/1.1 101 Switching Protocols\r\n')
- # Upgrade, Connection and Sec-WebSocket-Accept should be excluded.
+ # Upgrade and Sec-WebSocket-Accept should be excluded.
+ response.append('%s: %s\r\n' % (
+ common.CONNECTION_HEADER, common.UPGRADE_CONNECTION_TYPE))
if self._request.ws_protocol is not None:
response.append('%s: %s\r\n' % (
common.SEC_WEBSOCKET_PROTOCOL_HEADER,
@@ -1169,8 +1403,6 @@ class _HandshakeDeltaBase(object):
del headers[key]
else:
headers[key] = value
- # TODO(bashi): Support extensions
- headers['Sec-WebSocket-Extensions'] = ''
return headers
@@ -1232,8 +1464,12 @@ class _MuxHandler(object):
# Create "Implicitly Opened Connection".
logical_connection = _LogicalConnection(self, _DEFAULT_CHANNEL_ID)
- self._handshake_base = _HandshakeDeltaBase(
- self.original_request.headers_in)
+ headers = copy.copy(self.original_request.headers_in)
+ # Add extensions for logical channel.
+ headers[common.SEC_WEBSOCKET_EXTENSIONS_HEADER] = (
+ common.format_extensions(
+ self.original_request.mux_processor.extensions()))
+ self._handshake_base = _HandshakeDeltaBase(headers)
logical_request = _LogicalRequest(
_DEFAULT_CHANNEL_ID,
self.original_request.method,
@@ -1245,8 +1481,9 @@ class _MuxHandler(object):
# but we will send FlowControl later so set the initial quota to
# _INITIAL_QUOTA_FOR_CLIENT.
self._channel_slots.append(_INITIAL_QUOTA_FOR_CLIENT)
+ send_quota = self.original_request.mux_processor.quota()
if not self._do_handshake_for_logical_request(
- logical_request, send_quota=self.original_request.mux_quota):
+ logical_request, send_quota=send_quota):
raise MuxUnexpectedException(
'Failed handshake on the default channel id')
self._add_logical_channel(logical_request)
@@ -1287,7 +1524,6 @@ class _MuxHandler(object):
if not self._worker_done_notify_received:
self._logger.debug('Waiting worker(s) timed out')
return False
-
finally:
self._logical_channels_condition.release()
@@ -1297,7 +1533,7 @@ class _MuxHandler(object):
return True
- def notify_write_done(self, channel_id):
+ def notify_write_data_done(self, channel_id):
"""Called by the writer thread when a write operation has done.
Args:
@@ -1308,7 +1544,7 @@ class _MuxHandler(object):
self._logical_channels_condition.acquire()
if channel_id in self._logical_channels:
channel_data = self._logical_channels[channel_id]
- channel_data.request.connection.notify_write_done()
+ channel_data.request.connection.on_write_data_done()
else:
self._logger.debug('Seems that logical channel for %d has gone'
% channel_id)
@@ -1469,9 +1705,11 @@ class _MuxHandler(object):
return
channel_data = self._logical_channels[block.channel_id]
channel_data.drop_code = _DROP_CODE_ACKNOWLEDGED
+
# Close the logical channel
channel_data.request.connection.set_read_state(
_LogicalConnection.STATE_TERMINATED)
+ channel_data.request.ws_stream.stop_sending()
finally:
self._logical_channels_condition.release()
@@ -1506,8 +1744,11 @@ class _MuxHandler(object):
return
channel_data = self._logical_channels[channel_id]
fin, rsv1, rsv2, rsv3, opcode, payload = parser.read_inner_frame()
+ consuming_byte = len(payload)
+ if opcode != common.OPCODE_CONTINUATION:
+ consuming_byte += 1
if not channel_data.request.ws_stream.consume_receive_quota(
- len(payload)):
+ consuming_byte):
# The client violates quota. Close logical channel.
raise LogicalChannelError(
channel_id, _DROP_CODE_SEND_QUOTA_VIOLATION)
@@ -1569,15 +1810,32 @@ class _MuxHandler(object):
finished.
"""
- # Terminate all logical connections
- self._logger.debug('termiating all logical connections...')
+ self._logger.debug(
+ 'Termiating all logical connections waiting for incoming data '
+ '...')
self._logical_channels_condition.acquire()
for channel_data in self._logical_channels.values():
try:
channel_data.request.connection.set_read_state(
_LogicalConnection.STATE_TERMINATED)
except Exception:
- pass
+ self._logger.debug(traceback.format_exc())
+ self._logical_channels_condition.release()
+
+ def notify_writer_done(self):
+ """This method is called by the writer thread when the writer has
+ finished.
+ """
+
+ self._logger.debug(
+ 'Termiating all logical connections waiting for write '
+ 'completion ...')
+ self._logical_channels_condition.acquire()
+ for channel_data in self._logical_channels.values():
+ try:
+ channel_data.request.connection.on_writer_done()
+ except Exception:
+ self._logger.debug(traceback.format_exc())
self._logical_channels_condition.release()
def fail_physical_connection(self, code, message):
@@ -1590,8 +1848,7 @@ class _MuxHandler(object):
self._logger.debug('Failing the physical connection...')
self._send_drop_channel(_CONTROL_CHANNEL_ID, code, message)
- self.physical_stream.close_connection(
- common.STATUS_INTERNAL_ENDPOINT_ERROR)
+ self._writer.stop(common.STATUS_INTERNAL_ENDPOINT_ERROR)
def fail_logical_channel(self, channel_id, code, message):
"""Fail a logical channel.
@@ -1611,8 +1868,10 @@ class _MuxHandler(object):
# called later and it will send DropChannel.
channel_data.drop_code = code
channel_data.drop_message = message
+
channel_data.request.connection.set_read_state(
_LogicalConnection.STATE_TERMINATED)
+ channel_data.request.ws_stream.stop_sending()
else:
self._send_drop_channel(channel_id, code, message)
finally:
@@ -1620,7 +1879,8 @@ class _MuxHandler(object):
def use_mux(request):
- return hasattr(request, 'mux') and request.mux
+ return hasattr(request, 'mux_processor') and (
+ request.mux_processor.is_active())
def start(request, dispatcher):
diff --git a/pyload/lib/mod_pywebsocket/standalone.py b/pyload/lib/mod_pywebsocket/standalone.py
index 07a33d9c9..e9f083753 100755
--- a/pyload/lib/mod_pywebsocket/standalone.py
+++ b/pyload/lib/mod_pywebsocket/standalone.py
@@ -76,6 +76,9 @@ SUPPORTING TLS
To support TLS, run standalone.py with -t, -k, and -c options.
+Note that when ssl module is used and the key/cert location is incorrect,
+TLS connection silently fails while pyOpenSSL fails on startup.
+
SUPPORTING CLIENT AUTHENTICATION
@@ -140,18 +143,6 @@ import sys
import threading
import time
-_HAS_SSL = False
-_HAS_OPEN_SSL = False
-try:
- import ssl
- _HAS_SSL = True
-except ImportError:
- try:
- import OpenSSL.SSL
- _HAS_OPEN_SSL = True
- except ImportError:
- pass
-
from mod_pywebsocket import common
from mod_pywebsocket import dispatch
from mod_pywebsocket import handshake
@@ -168,6 +159,10 @@ _DEFAULT_REQUEST_QUEUE_SIZE = 128
# 1024 is practically large enough to contain WebSocket handshake lines.
_MAX_MEMORIZED_LINES = 1024
+# Constants for the --tls_module flag.
+_TLS_BY_STANDARD_MODULE = 'ssl'
+_TLS_BY_PYOPENSSL = 'pyopenssl'
+
class _StandaloneConnection(object):
"""Mimic mod_python mp_conn."""
@@ -231,11 +226,23 @@ class _StandaloneRequest(object):
self.headers_in = request_handler.headers
def get_uri(self):
- """Getter to mimic request.uri."""
+ """Getter to mimic request.uri.
+
+ This method returns the raw data at the Request-URI part of the
+ Request-Line, while the uri method on the request object of mod_python
+ returns the path portion after parsing the raw data. This behavior is
+ kept for compatibility.
+ """
return self._request_handler.path
uri = property(get_uri)
+ def get_unparsed_uri(self):
+ """Getter to mimic request.unparsed_uri."""
+
+ return self._request_handler.path
+ unparsed_uri = property(get_unparsed_uri)
+
def get_method(self):
"""Getter to mimic request.method."""
@@ -266,26 +273,67 @@ class _StandaloneRequest(object):
'Drained data following close frame: %r', drained_data)
+def _import_ssl():
+ global ssl
+ try:
+ import ssl
+ return True
+ except ImportError:
+ return False
+
+
+def _import_pyopenssl():
+ global OpenSSL
+ try:
+ import OpenSSL.SSL
+ return True
+ except ImportError:
+ return False
+
+
class _StandaloneSSLConnection(object):
- """A wrapper class for OpenSSL.SSL.Connection to provide makefile method
- which is not supported by the class.
+ """A wrapper class for OpenSSL.SSL.Connection to
+ - provide makefile method which is not supported by the class
+ - tweak shutdown method since OpenSSL.SSL.Connection.shutdown doesn't
+ accept the "how" argument.
+ - convert SysCallError exceptions that its recv method may raise into a
+ return value of '', meaning EOF. We cannot overwrite the recv method on
+ self._connection since it's immutable.
"""
+ _OVERRIDDEN_ATTRIBUTES = ['_connection', 'makefile', 'shutdown', 'recv']
+
def __init__(self, connection):
self._connection = connection
def __getattribute__(self, name):
- if name in ('_connection', 'makefile'):
+ if name in _StandaloneSSLConnection._OVERRIDDEN_ATTRIBUTES:
return object.__getattribute__(self, name)
return self._connection.__getattribute__(name)
def __setattr__(self, name, value):
- if name in ('_connection', 'makefile'):
+ if name in _StandaloneSSLConnection._OVERRIDDEN_ATTRIBUTES:
return object.__setattr__(self, name, value)
return self._connection.__setattr__(name, value)
def makefile(self, mode='r', bufsize=-1):
- return socket._fileobject(self._connection, mode, bufsize)
+ return socket._fileobject(self, mode, bufsize)
+
+ def shutdown(self, unused_how):
+ self._connection.shutdown()
+
+ def recv(self, bufsize, flags=0):
+ if flags != 0:
+ raise ValueError('Non-zero flags not allowed')
+
+ try:
+ return self._connection.recv(bufsize)
+ except OpenSSL.SSL.SysCallError, (err, message):
+ if err == -1:
+ # Suppress "unexpected EOF" exception. See the OpenSSL document
+ # for SSL_get_error.
+ return ''
+ raise
def _alias_handlers(dispatcher, websock_handlers_map_file):
@@ -340,7 +388,7 @@ class WebSocketServer(SocketServer.ThreadingMixIn, BaseHTTPServer.HTTPServer):
warnings = options.dispatcher.source_warnings()
if warnings:
for warning in warnings:
- logging.warning('mod_pywebsocket: %s' % warning)
+ logging.warning('Warning in source loading: %s' % warning)
self._logger = util.get_class_logger(self)
@@ -387,25 +435,25 @@ class WebSocketServer(SocketServer.ThreadingMixIn, BaseHTTPServer.HTTPServer):
except Exception, e:
self._logger.info('Skip by failure: %r', e)
continue
- if self.websocket_server_options.use_tls:
- if _HAS_SSL:
- if self.websocket_server_options.tls_client_auth:
- client_cert_ = ssl.CERT_REQUIRED
+ server_options = self.websocket_server_options
+ if server_options.use_tls:
+ # For the case of _HAS_OPEN_SSL, we do wrapper setup after
+ # accept.
+ if server_options.tls_module == _TLS_BY_STANDARD_MODULE:
+ if server_options.tls_client_auth:
+ if server_options.tls_client_cert_optional:
+ client_cert_ = ssl.CERT_OPTIONAL
+ else:
+ client_cert_ = ssl.CERT_REQUIRED
else:
client_cert_ = ssl.CERT_NONE
socket_ = ssl.wrap_socket(socket_,
- keyfile=self.websocket_server_options.private_key,
- certfile=self.websocket_server_options.certificate,
+ keyfile=server_options.private_key,
+ certfile=server_options.certificate,
ssl_version=ssl.PROTOCOL_SSLv23,
- ca_certs=self.websocket_server_options.tls_client_ca,
- cert_reqs=client_cert_)
- if _HAS_OPEN_SSL:
- ctx = OpenSSL.SSL.Context(OpenSSL.SSL.SSLv23_METHOD)
- ctx.use_privatekey_file(
- self.websocket_server_options.private_key)
- ctx.use_certificate_file(
- self.websocket_server_options.certificate)
- socket_ = OpenSSL.SSL.Connection(ctx, socket_)
+ ca_certs=server_options.tls_client_ca,
+ cert_reqs=client_cert_,
+ do_handshake_on_connect=False)
self._sockets.append((socket_, addrinfo))
def server_bind(self):
@@ -479,7 +527,7 @@ class WebSocketServer(SocketServer.ThreadingMixIn, BaseHTTPServer.HTTPServer):
self._logger.critical('Not supported: fileno')
return self._sockets[0][0].fileno()
- def handle_error(self, rquest, client_address):
+ def handle_error(self, request, client_address):
"""Override SocketServer.handle_error."""
self._logger.error(
@@ -496,8 +544,63 @@ class WebSocketServer(SocketServer.ThreadingMixIn, BaseHTTPServer.HTTPServer):
"""
accepted_socket, client_address = self.socket.accept()
- if self.websocket_server_options.use_tls and _HAS_OPEN_SSL:
- accepted_socket = _StandaloneSSLConnection(accepted_socket)
+
+ server_options = self.websocket_server_options
+ if server_options.use_tls:
+ if server_options.tls_module == _TLS_BY_STANDARD_MODULE:
+ try:
+ accepted_socket.do_handshake()
+ except ssl.SSLError, e:
+ self._logger.debug('%r', e)
+ raise
+
+ # Print cipher in use. Handshake is done on accept.
+ self._logger.debug('Cipher: %s', accepted_socket.cipher())
+ self._logger.debug('Client cert: %r',
+ accepted_socket.getpeercert())
+ elif server_options.tls_module == _TLS_BY_PYOPENSSL:
+ # We cannot print the cipher in use. pyOpenSSL doesn't provide
+ # any method to fetch that.
+
+ ctx = OpenSSL.SSL.Context(OpenSSL.SSL.SSLv23_METHOD)
+ ctx.use_privatekey_file(server_options.private_key)
+ ctx.use_certificate_file(server_options.certificate)
+
+ def default_callback(conn, cert, errnum, errdepth, ok):
+ return ok == 1
+
+ # See the OpenSSL document for SSL_CTX_set_verify.
+ if server_options.tls_client_auth:
+ verify_mode = OpenSSL.SSL.VERIFY_PEER
+ if not server_options.tls_client_cert_optional:
+ verify_mode |= OpenSSL.SSL.VERIFY_FAIL_IF_NO_PEER_CERT
+ ctx.set_verify(verify_mode, default_callback)
+ ctx.load_verify_locations(server_options.tls_client_ca,
+ None)
+ else:
+ ctx.set_verify(OpenSSL.SSL.VERIFY_NONE, default_callback)
+
+ accepted_socket = OpenSSL.SSL.Connection(ctx, accepted_socket)
+ accepted_socket.set_accept_state()
+
+ # Convert SSL related error into socket.error so that
+ # SocketServer ignores them and keeps running.
+ #
+ # TODO(tyoshino): Convert all kinds of errors.
+ try:
+ accepted_socket.do_handshake()
+ except OpenSSL.SSL.Error, e:
+ # Set errno part to 1 (SSL_ERROR_SSL) like the ssl module
+ # does.
+ self._logger.debug('%r', e)
+ raise socket.error(1, '%r' % e)
+ cert = accepted_socket.get_peer_certificate()
+ self._logger.debug('Client cert subject: %r',
+ cert.get_subject().get_components())
+ accepted_socket = _StandaloneSSLConnection(accepted_socket)
+ else:
+ raise ValueError('No TLS support module is available')
+
return accepted_socket, client_address
def serve_forever(self, poll_interval=0.5):
@@ -636,7 +739,7 @@ class WebSocketRequestHandler(CGIHTTPServer.CGIHTTPRequestHandler):
self._logger.info('Fallback to CGIHTTPRequestHandler')
return True
except dispatch.DispatchException, e:
- self._logger.info('%s', e)
+ self._logger.info('Dispatch failed for error: %s', e)
self.send_error(e.status)
return False
@@ -652,7 +755,7 @@ class WebSocketRequestHandler(CGIHTTPServer.CGIHTTPRequestHandler):
allowDraft75=self._options.allow_draft75,
strict=self._options.strict)
except handshake.VersionException, e:
- self._logger.info('%s', e)
+ self._logger.info('Handshake failed for version error: %s', e)
self.send_response(common.HTTP_STATUS_BAD_REQUEST)
self.send_header(common.SEC_WEBSOCKET_VERSION_HEADER,
e.supported_versions)
@@ -660,14 +763,14 @@ class WebSocketRequestHandler(CGIHTTPServer.CGIHTTPRequestHandler):
return False
except handshake.HandshakeException, e:
# Handshake for ws(s) failed.
- self._logger.info('%s', e)
+ self._logger.info('Handshake failed for error: %s', e)
self.send_error(e.status)
return False
request._dispatcher = self._options.dispatcher
self._options.dispatcher.transfer_data(request)
except handshake.AbortedByUserException, e:
- self._logger.info('%s', e)
+ self._logger.info('Aborted: %s', e)
return False
def log_request(self, code='-', size='-'):
@@ -799,6 +902,12 @@ def _build_option_parser():
'as CGI programs. Must be executable.'))
parser.add_option('-t', '--tls', dest='use_tls', action='store_true',
default=False, help='use TLS (wss://)')
+ parser.add_option('--tls-module', '--tls_module', dest='tls_module',
+ type='choice',
+ choices = [_TLS_BY_STANDARD_MODULE, _TLS_BY_PYOPENSSL],
+ help='Use ssl module if "%s" is specified. '
+ 'Use pyOpenSSL module if "%s" is specified' %
+ (_TLS_BY_STANDARD_MODULE, _TLS_BY_PYOPENSSL))
parser.add_option('-k', '--private-key', '--private_key',
dest='private_key',
default='', help='TLS private key file.')
@@ -806,7 +915,12 @@ def _build_option_parser():
default='', help='TLS certificate file.')
parser.add_option('--tls-client-auth', dest='tls_client_auth',
action='store_true', default=False,
- help='Requires TLS client auth on every connection.')
+ help='Requests TLS client auth on every connection.')
+ parser.add_option('--tls-client-cert-optional',
+ dest='tls_client_cert_optional',
+ action='store_true', default=False,
+ help=('Makes client certificate optional even though '
+ 'TLS client auth is enabled.'))
parser.add_option('--tls-client-ca', dest='tls_client_ca', default='',
help=('Specifies a pem file which contains a set of '
'concatenated CA certificates which are used to '
@@ -933,6 +1047,12 @@ def _main(args=None):
_configure_logging(options)
+ if options.allow_draft75:
+ logging.warning('--allow_draft75 option is obsolete.')
+
+ if options.strict:
+ logging.warning('--strict option is obsolete.')
+
# TODO(tyoshino): Clean up initialization of CGI related values. Move some
# of code here to WebSocketRequestHandler class if it's better.
options.cgi_directories = []
@@ -955,20 +1075,53 @@ def _main(args=None):
options.is_executable_method = __check_script
if options.use_tls:
- if not (_HAS_SSL or _HAS_OPEN_SSL):
- logging.critical('TLS support requires ssl or pyOpenSSL module.')
+ if options.tls_module is None:
+ if _import_ssl():
+ options.tls_module = _TLS_BY_STANDARD_MODULE
+ logging.debug('Using ssl module')
+ elif _import_pyopenssl():
+ options.tls_module = _TLS_BY_PYOPENSSL
+ logging.debug('Using pyOpenSSL module')
+ else:
+ logging.critical(
+ 'TLS support requires ssl or pyOpenSSL module.')
+ sys.exit(1)
+ elif options.tls_module == _TLS_BY_STANDARD_MODULE:
+ if not _import_ssl():
+ logging.critical('ssl module is not available')
+ sys.exit(1)
+ elif options.tls_module == _TLS_BY_PYOPENSSL:
+ if not _import_pyopenssl():
+ logging.critical('pyOpenSSL module is not available')
+ sys.exit(1)
+ else:
+ logging.critical('Invalid --tls-module option: %r',
+ options.tls_module)
sys.exit(1)
+
if not options.private_key or not options.certificate:
logging.critical(
'To use TLS, specify private_key and certificate.')
sys.exit(1)
- if options.tls_client_auth:
- if not options.use_tls:
+ if (options.tls_client_cert_optional and
+ not options.tls_client_auth):
+ logging.critical('Client authentication must be enabled to '
+ 'specify tls_client_cert_optional')
+ sys.exit(1)
+ else:
+ if options.tls_module is not None:
+ logging.critical('Use --tls-module option only together with '
+ '--use-tls option.')
+ sys.exit(1)
+
+ if options.tls_client_auth:
+ logging.critical('TLS must be enabled for client authentication.')
+ sys.exit(1)
+
+ if options.tls_client_cert_optional:
logging.critical('TLS must be enabled for client authentication.')
sys.exit(1)
- if not _HAS_SSL:
- logging.critical('Client authentication requires ssl module.')
if not options.scan_dir:
options.scan_dir = options.websock_handlers
diff --git a/pyload/lib/mod_pywebsocket/util.py b/pyload/lib/mod_pywebsocket/util.py
index 7bb0b5d9e..fc8451be7 100644
--- a/pyload/lib/mod_pywebsocket/util.py
+++ b/pyload/lib/mod_pywebsocket/util.py
@@ -56,6 +56,11 @@ import socket
import traceback
import zlib
+try:
+ from mod_pywebsocket import fast_masking
+except ImportError:
+ pass
+
def get_stack_trace():
"""Get the current stack trace as string.
@@ -169,26 +174,40 @@ class RepeatedXorMasker(object):
ended and resumes from that point on the next mask method call.
"""
- def __init__(self, mask):
- self._mask = map(ord, mask)
- self._mask_size = len(self._mask)
- self._count = 0
+ def __init__(self, masking_key):
+ self._masking_key = masking_key
+ self._masking_key_index = 0
- def mask(self, s):
+ def _mask_using_swig(self, s):
+ masked_data = fast_masking.mask(
+ s, self._masking_key, self._masking_key_index)
+ self._masking_key_index = (
+ (self._masking_key_index + len(s)) % len(self._masking_key))
+ return masked_data
+
+ def _mask_using_array(self, s):
result = array.array('B')
result.fromstring(s)
+
# Use temporary local variables to eliminate the cost to access
# attributes
- count = self._count
- mask = self._mask
- mask_size = self._mask_size
+ masking_key = map(ord, self._masking_key)
+ masking_key_size = len(masking_key)
+ masking_key_index = self._masking_key_index
+
for i in xrange(len(result)):
- result[i] ^= mask[count]
- count = (count + 1) % mask_size
- self._count = count
+ result[i] ^= masking_key[masking_key_index]
+ masking_key_index = (masking_key_index + 1) % masking_key_size
+
+ self._masking_key_index = masking_key_index
return result.tostring()
+ if 'fast_masking' in globals():
+ mask = _mask_using_swig
+ else:
+ mask = _mask_using_array
+
class DeflateRequest(object):
"""A wrapper class for request object to intercept send and recv to perform
@@ -252,6 +271,7 @@ class _Deflater(object):
self._logger.debug('Compress result %r', compressed_bytes)
return compressed_bytes
+
class _Inflater(object):
def __init__(self):
@@ -346,6 +366,7 @@ class _RFC1979Deflater(object):
return self._deflater.compress_and_flush(bytes)[:-4]
return self._deflater.compress(bytes)
+
class _RFC1979Inflater(object):
"""A decompressor class for byte sequence compressed and flushed following
the algorithm described in the RFC1979 section 2.1.
diff --git a/pyload/remote/WebSocketBackend.py b/pyload/remote/WebSocketBackend.py
index 7238af679..55edee50e 100644
--- a/pyload/remote/WebSocketBackend.py
+++ b/pyload/remote/WebSocketBackend.py
@@ -48,10 +48,11 @@ class WebSocketBackend(BackendBase):
# tls is needed when requested or webUI is also on tls
if self.core.api.isWSSecure():
from wsbackend.Server import import_ssl
- if import_ssl():
+ tls_module = import_ssl()
+ if tls_module:
options.use_tls = True
+ options.tls_module = tls_module
options.certificate = self.core.config['ssl']['cert']
- options.ca_certificate = options.certificate
options.private_key = self.core.config['ssl']['key']
self.core.log.info(_('Using secure WebSocket'))
else:
diff --git a/pyload/remote/wsbackend/Server.py b/pyload/remote/wsbackend/Server.py
index 9a6649ca9..3ffe198eb 100644
--- a/pyload/remote/wsbackend/Server.py
+++ b/pyload/remote/wsbackend/Server.py
@@ -37,6 +37,7 @@
import BaseHTTPServer
import CGIHTTPServer
import SocketServer
+import base64
import httplib
import logging
import os
@@ -46,9 +47,6 @@ import socket
import sys
import threading
-_HAS_SSL = False
-_HAS_OPEN_SSL = False
-
from mod_pywebsocket import common
from mod_pywebsocket import dispatch
from mod_pywebsocket import handshake
@@ -65,20 +63,9 @@ _DEFAULT_REQUEST_QUEUE_SIZE = 128
# 1024 is practically large enough to contain WebSocket handshake lines.
_MAX_MEMORIZED_LINES = 1024
-def import_ssl():
- global _HAS_SSL, _HAS_OPEN_SSL
- global ssl, OpenSSL
- try:
- import ssl
- _HAS_SSL = True
- except ImportError:
- try:
- import OpenSSL.SSL
- _HAS_OPEN_SSL = True
- except ImportError:
- pass
-
- return _HAS_OPEN_SSL or _HAS_SSL
+# Constants for the --tls_module flag.
+_TLS_BY_STANDARD_MODULE = 'ssl'
+_TLS_BY_PYOPENSSL = 'pyopenssl'
class _StandaloneConnection(object):
@@ -143,11 +130,23 @@ class _StandaloneRequest(object):
self.headers_in = request_handler.headers
def get_uri(self):
- """Getter to mimic request.uri."""
+ """Getter to mimic request.uri.
+
+ This method returns the raw data at the Request-URI part of the
+ Request-Line, while the uri method on the request object of mod_python
+ returns the path portion after parsing the raw data. This behavior is
+ kept for compatibility.
+ """
return self._request_handler.path
uri = property(get_uri)
+ def get_unparsed_uri(self):
+ """Getter to mimic request.unparsed_uri."""
+
+ return self._request_handler.path
+ unparsed_uri = property(get_unparsed_uri)
+
def get_method(self):
"""Getter to mimic request.method."""
@@ -178,26 +177,67 @@ class _StandaloneRequest(object):
'Drained data following close frame: %r', drained_data)
+def _import_ssl():
+ global ssl
+ try:
+ import ssl
+ return True
+ except ImportError:
+ return False
+
+
+def _import_pyopenssl():
+ global OpenSSL
+ try:
+ import OpenSSL.SSL
+ return True
+ except ImportError:
+ return False
+
+
class _StandaloneSSLConnection(object):
- """A wrapper class for OpenSSL.SSL.Connection to provide makefile method
- which is not supported by the class.
+ """A wrapper class for OpenSSL.SSL.Connection to
+ - provide makefile method which is not supported by the class
+ - tweak shutdown method since OpenSSL.SSL.Connection.shutdown doesn't
+ accept the "how" argument.
+ - convert SysCallError exceptions that its recv method may raise into a
+ return value of '', meaning EOF. We cannot overwrite the recv method on
+ self._connection since it's immutable.
"""
+ _OVERRIDDEN_ATTRIBUTES = ['_connection', 'makefile', 'shutdown', 'recv']
+
def __init__(self, connection):
self._connection = connection
def __getattribute__(self, name):
- if name in ('_connection', 'makefile'):
+ if name in _StandaloneSSLConnection._OVERRIDDEN_ATTRIBUTES:
return object.__getattribute__(self, name)
return self._connection.__getattribute__(name)
def __setattr__(self, name, value):
- if name in ('_connection', 'makefile'):
+ if name in _StandaloneSSLConnection._OVERRIDDEN_ATTRIBUTES:
return object.__setattr__(self, name, value)
return self._connection.__setattr__(name, value)
def makefile(self, mode='r', bufsize=-1):
- return socket._fileobject(self._connection, mode, bufsize)
+ return socket._fileobject(self, mode, bufsize)
+
+ def shutdown(self, unused_how):
+ self._connection.shutdown()
+
+ def recv(self, bufsize, flags=0):
+ if flags != 0:
+ raise ValueError('Non-zero flags not allowed')
+
+ try:
+ return self._connection.recv(bufsize)
+ except OpenSSL.SSL.SysCallError, (err, message):
+ if err == -1:
+ # Suppress "unexpected EOF" exception. See the OpenSSL document
+ # for SSL_get_error.
+ return ''
+ raise
def _alias_handlers(dispatcher, websock_handlers_map_file):
@@ -284,25 +324,25 @@ class WebSocketServer(SocketServer.ThreadingMixIn, BaseHTTPServer.HTTPServer):
except Exception, e:
self._logger.info('Skip by failure: %r', e)
continue
- if self.websocket_server_options.use_tls:
- if _HAS_SSL:
- if self.websocket_server_options.tls_client_auth:
- client_cert_ = ssl.CERT_REQUIRED
+ server_options = self.websocket_server_options
+ if server_options.use_tls:
+ # For the case of _HAS_OPEN_SSL, we do wrapper setup after
+ # accept.
+ if server_options.tls_module == _TLS_BY_STANDARD_MODULE:
+ if server_options.tls_client_auth:
+ if server_options.tls_client_cert_optional:
+ client_cert_ = ssl.CERT_OPTIONAL
+ else:
+ client_cert_ = ssl.CERT_REQUIRED
else:
client_cert_ = ssl.CERT_NONE
socket_ = ssl.wrap_socket(socket_,
- keyfile=self.websocket_server_options.private_key,
- certfile=self.websocket_server_options.certificate,
- ssl_version=ssl.PROTOCOL_SSLv23,
- ca_certs=self.websocket_server_options.tls_client_ca,
- cert_reqs=client_cert_)
- if _HAS_OPEN_SSL:
- ctx = OpenSSL.SSL.Context(OpenSSL.SSL.SSLv23_METHOD)
- ctx.use_privatekey_file(
- self.websocket_server_options.private_key)
- ctx.use_certificate_file(
- self.websocket_server_options.certificate)
- socket_ = OpenSSL.SSL.Connection(ctx, socket_)
+ keyfile=server_options.private_key,
+ certfile=server_options.certificate,
+ ssl_version=ssl.PROTOCOL_SSLv23,
+ ca_certs=server_options.tls_client_ca,
+ cert_reqs=client_cert_,
+ do_handshake_on_connect=False)
self._sockets.append((socket_, addrinfo))
def server_bind(self):
@@ -375,7 +415,7 @@ class WebSocketServer(SocketServer.ThreadingMixIn, BaseHTTPServer.HTTPServer):
self._logger.critical('Not supported: fileno')
return self._sockets[0][0].fileno()
- def handle_error(self, rquest, client_address):
+ def handle_error(self, request, client_address):
"""Override SocketServer.handle_error."""
self._logger.error(
@@ -392,8 +432,63 @@ class WebSocketServer(SocketServer.ThreadingMixIn, BaseHTTPServer.HTTPServer):
"""
accepted_socket, client_address = self.socket.accept()
- if self.websocket_server_options.use_tls and _HAS_OPEN_SSL:
- accepted_socket = _StandaloneSSLConnection(accepted_socket)
+
+ server_options = self.websocket_server_options
+ if server_options.use_tls:
+ if server_options.tls_module == _TLS_BY_STANDARD_MODULE:
+ try:
+ accepted_socket.do_handshake()
+ except ssl.SSLError, e:
+ self._logger.debug('%r', e)
+ raise
+
+ # Print cipher in use. Handshake is done on accept.
+ self._logger.debug('Cipher: %s', accepted_socket.cipher())
+ self._logger.debug('Client cert: %r',
+ accepted_socket.getpeercert())
+ elif server_options.tls_module == _TLS_BY_PYOPENSSL:
+ # We cannot print the cipher in use. pyOpenSSL doesn't provide
+ # any method to fetch that.
+
+ ctx = OpenSSL.SSL.Context(OpenSSL.SSL.SSLv23_METHOD)
+ ctx.use_privatekey_file(server_options.private_key)
+ ctx.use_certificate_file(server_options.certificate)
+
+ def default_callback(conn, cert, errnum, errdepth, ok):
+ return ok == 1
+
+ # See the OpenSSL document for SSL_CTX_set_verify.
+ if server_options.tls_client_auth:
+ verify_mode = OpenSSL.SSL.VERIFY_PEER
+ if not server_options.tls_client_cert_optional:
+ verify_mode |= OpenSSL.SSL.VERIFY_FAIL_IF_NO_PEER_CERT
+ ctx.set_verify(verify_mode, default_callback)
+ ctx.load_verify_locations(server_options.tls_client_ca,
+ None)
+ else:
+ ctx.set_verify(OpenSSL.SSL.VERIFY_NONE, default_callback)
+
+ accepted_socket = OpenSSL.SSL.Connection(ctx, accepted_socket)
+ accepted_socket.set_accept_state()
+
+ # Convert SSL related error into socket.error so that
+ # SocketServer ignores them and keeps running.
+ #
+ # TODO(tyoshino): Convert all kinds of errors.
+ try:
+ accepted_socket.do_handshake()
+ except OpenSSL.SSL.Error, e:
+ # Set errno part to 1 (SSL_ERROR_SSL) like the ssl module
+ # does.
+ self._logger.debug('%r', e)
+ raise socket.error(1, '%r' % e)
+ cert = accepted_socket.get_peer_certificate()
+ self._logger.debug('Client cert subject: %r',
+ cert.get_subject().get_components())
+ accepted_socket = _StandaloneSSLConnection(accepted_socket)
+ else:
+ raise ValueError('No TLS support module is available')
+
return accepted_socket, client_address
def serve_forever(self, poll_interval=0.5):
@@ -474,8 +569,6 @@ class WebSocketRequestHandler(CGIHTTPServer.CGIHTTPRequestHandler):
raise
self._logger.debug("WS: Broken pipe")
-
-
def parse_request(self):
"""Override BaseHTTPServer.BaseHTTPRequestHandler.parse_request.
@@ -545,7 +638,7 @@ class WebSocketRequestHandler(CGIHTTPServer.CGIHTTPRequestHandler):
self._logger.info('Fallback to CGIHTTPRequestHandler')
return False
except dispatch.DispatchException, e:
- self._logger.info('%s', e)
+ self._logger.info('Dispatch failed for error: %s', e)
self.send_error(e.status)
return False
@@ -561,7 +654,7 @@ class WebSocketRequestHandler(CGIHTTPServer.CGIHTTPRequestHandler):
allowDraft75=self._options.allow_draft75,
strict=self._options.strict)
except handshake.VersionException, e:
- self._logger.info('%s', e)
+ self._logger.info('Handshake failed for version error: %s', e)
self.send_response(common.HTTP_STATUS_BAD_REQUEST)
self.send_header(common.SEC_WEBSOCKET_VERSION_HEADER,
e.supported_versions)
@@ -569,14 +662,14 @@ class WebSocketRequestHandler(CGIHTTPServer.CGIHTTPRequestHandler):
return False
except handshake.HandshakeException, e:
# Handshake for ws(s) failed.
- self._logger.info('%s', e)
+ self._logger.info('Handshake failed for error: %s', e)
self.send_error(e.status)
return False
request._dispatcher = self._options.dispatcher
self._options.dispatcher.transfer_data(request)
except handshake.AbortedByUserException, e:
- self._logger.info('%s', e)
+ self._logger.info('Aborted: %s', e)
return False
def log_request(self, code='-', size='-'):
@@ -606,7 +699,7 @@ class WebSocketRequestHandler(CGIHTTPServer.CGIHTTPRequestHandler):
if CGIHTTPServer.CGIHTTPRequestHandler.is_cgi(self):
if '..' in self.path:
return False
- # strip query parameter from request path
+ # strip query parameter from request path
resource_name = self.path.split('?', 2)[0]
# convert resource_name into real path name in filesystem.
scriptfile = self.translate_path(resource_name)
@@ -629,11 +722,11 @@ def _configure_logging(options):
logger.setLevel(logging.getLevelName(options.log_level.upper()))
if options.log_file:
handler = logging.handlers.RotatingFileHandler(
- options.log_file, 'a', options.log_max, options.log_count)
+ options.log_file, 'a', options.log_max, options.log_count)
else:
handler = logging.StreamHandler()
formatter = logging.Formatter(
- '[%(asctime)s] [%(levelname)s] %(name)s: %(message)s')
+ '[%(asctime)s] [%(levelname)s] %(name)s: %(message)s')
handler.setFormatter(formatter)
logger.addHandler(handler)
@@ -650,9 +743,10 @@ class DefaultOptions:
use_tls = False
private_key = ''
certificate = ''
- ca_certificate = ''
- tls_client_ca = ''
+ tls_client_ca = None
tls_client_auth = False
+ tls_client_cert_optional = False
+ tls_module = _TLS_BY_STANDARD_MODULE
dispatcher = None
request_queue_size = _DEFAULT_REQUEST_QUEUE_SIZE
use_basic_auth = False
@@ -664,6 +758,16 @@ class DefaultOptions:
cgi_directories = ''
is_executable_method = False
+
+def import_ssl():
+ if _import_ssl():
+ return _TLS_BY_STANDARD_MODULE
+
+ elif _import_pyopenssl():
+ return _TLS_BY_PYOPENSSL
+
+ return False
+
def _main(args=None):
"""You can call this function from your own program, but please note that
this function has some side-effects that might affect your program. For
@@ -677,6 +781,12 @@ def _main(args=None):
_configure_logging(options)
+ if options.allow_draft75:
+ logging.warning('--allow_draft75 option is obsolete.')
+
+ if options.strict:
+ logging.warning('--strict option is obsolete.')
+
# TODO(tyoshino): Clean up initialization of CGI related values. Move some
# of code here to WebSocketRequestHandler class if it's better.
options.cgi_directories = []
@@ -699,20 +809,53 @@ def _main(args=None):
options.is_executable_method = __check_script
if options.use_tls:
- if not (_HAS_SSL or _HAS_OPEN_SSL):
- logging.critical('TLS support requires ssl or pyOpenSSL module.')
+ if options.tls_module is None:
+ if _import_ssl():
+ options.tls_module = _TLS_BY_STANDARD_MODULE
+ logging.debug('Using ssl module')
+ elif _import_pyopenssl():
+ options.tls_module = _TLS_BY_PYOPENSSL
+ logging.debug('Using pyOpenSSL module')
+ else:
+ logging.critical(
+ 'TLS support requires ssl or pyOpenSSL module.')
+ sys.exit(1)
+ elif options.tls_module == _TLS_BY_STANDARD_MODULE:
+ if not _import_ssl():
+ logging.critical('ssl module is not available')
+ sys.exit(1)
+ elif options.tls_module == _TLS_BY_PYOPENSSL:
+ if not _import_pyopenssl():
+ logging.critical('pyOpenSSL module is not available')
+ sys.exit(1)
+ else:
+ logging.critical('Invalid --tls-module option: %r',
+ options.tls_module)
sys.exit(1)
+
if not options.private_key or not options.certificate:
logging.critical(
- 'To use TLS, specify private_key and certificate.')
+ 'To use TLS, specify private_key and certificate.')
+ sys.exit(1)
+
+ if (options.tls_client_cert_optional and
+ not options.tls_client_auth):
+ logging.critical('Client authentication must be enabled to '
+ 'specify tls_client_cert_optional')
+ sys.exit(1)
+ else:
+ if options.tls_module is not None:
+ logging.critical('Use --tls-module option only together with '
+ '--use-tls option.')
+ sys.exit(1)
+
+ if options.tls_client_auth:
+ logging.critical('TLS must be enabled for client authentication.')
sys.exit(1)
- if options.tls_client_auth:
- if not options.use_tls:
+ if options.tls_client_cert_optional:
logging.critical('TLS must be enabled for client authentication.')
sys.exit(1)
- if not _HAS_SSL:
- logging.critical('Client authentication requires ssl module.')
if not options.scan_dir:
options.scan_dir = options.websock_handlers