diff options
author | RaNaN <Mast3rRaNaN@hotmail.de> | 2013-12-15 14:36:38 +0100 |
---|---|---|
committer | RaNaN <Mast3rRaNaN@hotmail.de> | 2013-12-15 14:36:38 +0100 |
commit | 367172406d1382f2fdd39936bde31ebcbcd1c56f (patch) | |
tree | 6abd7b2adffdd7fbb4eb35a9171c25b0e8eee4a0 /pyload | |
parent | more options to get webUI through proxy working (diff) | |
download | pyload-367172406d1382f2fdd39936bde31ebcbcd1c56f.tar.xz |
updated pywebsocket
Diffstat (limited to 'pyload')
-rw-r--r-- | pyload/lib/mod_pywebsocket/_stream_base.py | 46 | ||||
-rw-r--r-- | pyload/lib/mod_pywebsocket/_stream_hybi.py | 24 | ||||
-rw-r--r-- | pyload/lib/mod_pywebsocket/deflate_stream_extension.py | 69 | ||||
-rw-r--r-- | pyload/lib/mod_pywebsocket/dispatch.py | 10 | ||||
-rw-r--r-- | pyload/lib/mod_pywebsocket/extensions.py | 171 | ||||
-rw-r--r-- | pyload/lib/mod_pywebsocket/handshake/_base.py | 72 | ||||
-rw-r--r-- | pyload/lib/mod_pywebsocket/handshake/hybi.py | 64 | ||||
-rw-r--r-- | pyload/lib/mod_pywebsocket/handshake/hybi00.py | 63 | ||||
-rw-r--r-- | pyload/lib/mod_pywebsocket/headerparserhandler.py | 28 | ||||
-rw-r--r-- | pyload/lib/mod_pywebsocket/mux.py | 470 | ||||
-rwxr-xr-x | pyload/lib/mod_pywebsocket/standalone.py | 249 | ||||
-rw-r--r-- | pyload/lib/mod_pywebsocket/util.py | 43 | ||||
-rw-r--r-- | pyload/remote/WebSocketBackend.py | 5 | ||||
-rw-r--r-- | pyload/remote/wsbackend/Server.py | 263 |
14 files changed, 1153 insertions, 424 deletions
diff --git a/pyload/lib/mod_pywebsocket/_stream_base.py b/pyload/lib/mod_pywebsocket/_stream_base.py index 60fb33d2c..8235666bb 100644 --- a/pyload/lib/mod_pywebsocket/_stream_base.py +++ b/pyload/lib/mod_pywebsocket/_stream_base.py @@ -39,6 +39,8 @@ # writing/reading. +import socket + from mod_pywebsocket import util @@ -109,20 +111,34 @@ class StreamBase(object): ConnectionTerminatedException: when read returns empty string. """ - bytes = self._request.connection.read(length) - if not bytes: + try: + read_bytes = self._request.connection.read(length) + if not read_bytes: + raise ConnectionTerminatedException( + 'Receiving %d byte failed. Peer (%r) closed connection' % + (length, (self._request.connection.remote_addr,))) + return read_bytes + except socket.error, e: + # Catch a socket.error. Because it's not a child class of the + # IOError prior to Python 2.6, we cannot omit this except clause. + # Use %s rather than %r for the exception to use human friendly + # format. + raise ConnectionTerminatedException( + 'Receiving %d byte failed. socket.error (%s) occurred' % + (length, e)) + except IOError, e: + # Also catch an IOError because mod_python throws it. raise ConnectionTerminatedException( - 'Receiving %d byte failed. Peer (%r) closed connection' % - (length, (self._request.connection.remote_addr,))) - return bytes + 'Receiving %d byte failed. IOError (%s) occurred' % + (length, e)) - def _write(self, bytes): + def _write(self, bytes_to_write): """Writes given bytes to connection. In case we catch any exception, prepends remote address to the exception message and raise again. """ try: - self._request.connection.write(bytes) + self._request.connection.write(bytes_to_write) except Exception, e: util.prepend_message_to_exception( 'Failed to send message to %r: ' % @@ -138,12 +154,12 @@ class StreamBase(object): ConnectionTerminatedException: when read returns empty string. """ - bytes = [] + read_bytes = [] while length > 0: - new_bytes = self._read(length) - bytes.append(new_bytes) - length -= len(new_bytes) - return ''.join(bytes) + new_read_bytes = self._read(length) + read_bytes.append(new_read_bytes) + length -= len(new_read_bytes) + return ''.join(read_bytes) def _read_until(self, delim_char): """Reads bytes until we encounter delim_char. The result will not @@ -153,13 +169,13 @@ class StreamBase(object): ConnectionTerminatedException: when read returns empty string. """ - bytes = [] + read_bytes = [] while True: ch = self._read(1) if ch == delim_char: break - bytes.append(ch) - return ''.join(bytes) + read_bytes.append(ch) + return ''.join(read_bytes) # vi:sts=4 sw=4 et diff --git a/pyload/lib/mod_pywebsocket/_stream_hybi.py b/pyload/lib/mod_pywebsocket/_stream_hybi.py index bd158fa6b..1c43249a4 100644 --- a/pyload/lib/mod_pywebsocket/_stream_hybi.py +++ b/pyload/lib/mod_pywebsocket/_stream_hybi.py @@ -280,7 +280,7 @@ def parse_frame(receive_bytes, logger=None, if logger.isEnabledFor(common.LOGLEVEL_FINE): unmask_start = time.time() - bytes = masker.mask(raw_payload_bytes) + unmasked_bytes = masker.mask(raw_payload_bytes) if logger.isEnabledFor(common.LOGLEVEL_FINE): logger.log( @@ -288,7 +288,7 @@ def parse_frame(receive_bytes, logger=None, 'Done unmasking payload data at %s MB/s', payload_length / (time.time() - unmask_start) / 1000 / 1000) - return opcode, bytes, fin, rsv1, rsv2, rsv3 + return opcode, unmasked_bytes, fin, rsv1, rsv2, rsv3 class FragmentedFrameBuilder(object): @@ -403,9 +403,6 @@ class StreamOptions(object): self.encode_text_message_to_utf8 = True self.mask_send = False self.unmask_receive = True - # RFC6455 disallows fragmented control frames, but mux extension - # relaxes the restriction. - self.allow_fragmented_control_frame = False class Stream(StreamBase): @@ -463,10 +460,10 @@ class Stream(StreamBase): unmask_receive=self._options.unmask_receive) def _receive_frame_as_frame_object(self): - opcode, bytes, fin, rsv1, rsv2, rsv3 = self._receive_frame() + opcode, unmasked_bytes, fin, rsv1, rsv2, rsv3 = self._receive_frame() return Frame(fin=fin, rsv1=rsv1, rsv2=rsv2, rsv3=rsv3, - opcode=opcode, payload=bytes) + opcode=opcode, payload=unmasked_bytes) def receive_filtered_frame(self): """Receives a frame and applies frame filters and message filters. @@ -602,8 +599,7 @@ class Stream(StreamBase): else: # Start of fragmentation frame - if (not self._options.allow_fragmented_control_frame and - common.is_control_opcode(frame.opcode)): + if common.is_control_opcode(frame.opcode): raise InvalidFrameException( 'Control frames must not be fragmented') @@ -672,7 +668,7 @@ class Stream(StreamBase): reason = '' self._send_closing_handshake(code, reason) self._logger.debug( - 'Sent ack for client-initiated closing handshake ' + 'Acknowledged closing handshake initiated by the peer ' '(code=%r, reason=%r)', code, reason) def _process_ping_message(self, message): @@ -815,13 +811,15 @@ class Stream(StreamBase): self._write(frame) - def close_connection(self, code=common.STATUS_NORMAL_CLOSURE, reason=''): + def close_connection(self, code=common.STATUS_NORMAL_CLOSURE, reason='', + wait_response=True): """Closes a WebSocket connection. Args: code: Status code for close frame. If code is None, a close frame with empty body will be sent. reason: string representing close reason. + wait_response: True when caller want to wait the response. Raises: BadOperationException: when reason is specified with code None or reason is not an instance of both str and unicode. @@ -844,11 +842,11 @@ class Stream(StreamBase): self._send_closing_handshake(code, reason) self._logger.debug( - 'Sent server-initiated closing handshake (code=%r, reason=%r)', + 'Initiated closing handshake (code=%r, reason=%r)', code, reason) if (code == common.STATUS_GOING_AWAY or - code == common.STATUS_PROTOCOL_ERROR): + code == common.STATUS_PROTOCOL_ERROR) or not wait_response: # It doesn't make sense to wait for a close frame if the reason is # protocol error or that the server is going away. For some of # other reasons, it might not make sense to wait for a close frame, diff --git a/pyload/lib/mod_pywebsocket/deflate_stream_extension.py b/pyload/lib/mod_pywebsocket/deflate_stream_extension.py new file mode 100644 index 000000000..d2ba477c4 --- /dev/null +++ b/pyload/lib/mod_pywebsocket/deflate_stream_extension.py @@ -0,0 +1,69 @@ +# Copyright 2013, Google Inc. +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: +# +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above +# copyright notice, this list of conditions and the following disclaimer +# in the documentation and/or other materials provided with the +# distribution. +# * Neither the name of Google Inc. nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + + +from mod_pywebsocket import common +from mod_pywebsocket.extensions import _available_processors +from mod_pywebsocket.extensions import ExtensionProcessorInterface +from mod_pywebsocket import util + + +class DeflateStreamExtensionProcessor(ExtensionProcessorInterface): + """WebSocket DEFLATE stream extension processor. + + Specification: + Section 9.2.1 in + http://tools.ietf.org/html/draft-ietf-hybi-thewebsocketprotocol-10 + """ + + def __init__(self, request): + ExtensionProcessorInterface.__init__(self, request) + self._logger = util.get_class_logger(self) + + def name(self): + return common.DEFLATE_STREAM_EXTENSION + + def _get_extension_response_internal(self): + if len(self._request.get_parameter_names()) != 0: + return None + + self._logger.debug( + 'Enable %s extension', common.DEFLATE_STREAM_EXTENSION) + + return common.ExtensionParameter(common.DEFLATE_STREAM_EXTENSION) + + def _setup_stream_options_internal(self, stream_options): + stream_options.deflate_stream = True + + +_available_processors[common.DEFLATE_STREAM_EXTENSION] = ( + DeflateStreamExtensionProcessor) + + +# vi:sts=4 sw=4 et diff --git a/pyload/lib/mod_pywebsocket/dispatch.py b/pyload/lib/mod_pywebsocket/dispatch.py index 25905f180..96c91e0c9 100644 --- a/pyload/lib/mod_pywebsocket/dispatch.py +++ b/pyload/lib/mod_pywebsocket/dispatch.py @@ -255,6 +255,9 @@ class Dispatcher(object): try: do_extra_handshake_(request) except handshake.AbortedByUserException, e: + # Re-raise to tell the caller of this function to finish this + # connection without sending any error. + self._logger.debug('%s', util.get_stack_trace()) raise except Exception, e: util.prepend_message_to_exception( @@ -294,11 +297,12 @@ class Dispatcher(object): request.ws_stream.close_connection() # Catch non-critical exceptions the handler didn't handle. except handshake.AbortedByUserException, e: - self._logger.debug('%s', e) + self._logger.debug('%s', util.get_stack_trace()) raise except msgutil.BadOperationException, e: self._logger.debug('%s', e) - request.ws_stream.close_connection(common.STATUS_ABNORMAL_CLOSURE) + request.ws_stream.close_connection( + common.STATUS_INTERNAL_ENDPOINT_ERROR) except msgutil.InvalidFrameException, e: # InvalidFrameException must be caught before # ConnectionTerminatedException that catches InvalidFrameException. @@ -314,6 +318,8 @@ class Dispatcher(object): except msgutil.ConnectionTerminatedException, e: self._logger.debug('%s', e) except Exception, e: + # Any other exceptions are forwarded to the caller of this + # function. util.prepend_message_to_exception( '%s raised exception for %s: ' % ( _TRANSFER_DATA_HANDLER_NAME, request.ws_resource), diff --git a/pyload/lib/mod_pywebsocket/extensions.py b/pyload/lib/mod_pywebsocket/extensions.py index 03dbf9ee1..18841ed92 100644 --- a/pyload/lib/mod_pywebsocket/extensions.py +++ b/pyload/lib/mod_pywebsocket/extensions.py @@ -38,47 +38,42 @@ _available_processors = {} class ExtensionProcessorInterface(object): - def name(self): - return None + def __init__(self, request): + self._request = request + self._active = True - def get_extension_response(self): + def request(self): + return self._request + + def name(self): return None - def setup_stream_options(self, stream_options): + def check_consistency_with_other_processors(self, processors): pass + def set_active(self, active): + self._active = active -class DeflateStreamExtensionProcessor(ExtensionProcessorInterface): - """WebSocket DEFLATE stream extension processor. - - Specification: - Section 9.2.1 in - http://tools.ietf.org/html/draft-ietf-hybi-thewebsocketprotocol-10 - """ - - def __init__(self, request): - self._logger = util.get_class_logger(self) - - self._request = request + def is_active(self): + return self._active - def name(self): - return common.DEFLATE_STREAM_EXTENSION + def _get_extension_response_internal(self): + return None def get_extension_response(self): - if len(self._request.get_parameter_names()) != 0: - return None - - self._logger.debug( - 'Enable %s extension', common.DEFLATE_STREAM_EXTENSION) + if self._active: + response = self._get_extension_response_internal() + if response is None: + self._active = False + return response + return None - return common.ExtensionParameter(common.DEFLATE_STREAM_EXTENSION) + def _setup_stream_options_internal(self, stream_options): + pass def setup_stream_options(self, stream_options): - stream_options.deflate_stream = True - - -_available_processors[common.DEFLATE_STREAM_EXTENSION] = ( - DeflateStreamExtensionProcessor) + if self._active: + self._setup_stream_options_internal(stream_options) def _log_compression_ratio(logger, original_bytes, total_original_bytes, @@ -109,6 +104,17 @@ def _log_decompression_ratio(logger, received_bytes, total_received_bytes, (ratio, average_ratio)) +def _validate_window_bits(bits): + if bits is not None: + try: + bits = int(bits) + except ValueError, e: + return False + if bits < 8 or bits > 15: + return False + return True + + class DeflateFrameExtensionProcessor(ExtensionProcessorInterface): """WebSocket Per-frame DEFLATE extension processor. @@ -120,10 +126,9 @@ class DeflateFrameExtensionProcessor(ExtensionProcessorInterface): _NO_CONTEXT_TAKEOVER_PARAM = 'no_context_takeover' def __init__(self, request): + ExtensionProcessorInterface.__init__(self, request) self._logger = util.get_class_logger(self) - self._request = request - self._response_window_bits = None self._response_no_context_takeover = False self._bfinal = False @@ -143,7 +148,7 @@ class DeflateFrameExtensionProcessor(ExtensionProcessorInterface): def name(self): return common.DEFLATE_FRAME_EXTENSION - def get_extension_response(self): + def _get_extension_response_internal(self): # Any unknown parameter will be just ignored. window_bits = self._request.get_parameter_value( @@ -155,13 +160,8 @@ class DeflateFrameExtensionProcessor(ExtensionProcessorInterface): self._NO_CONTEXT_TAKEOVER_PARAM) is not None): return None - if window_bits is not None: - try: - window_bits = int(window_bits) - except ValueError, e: - return None - if window_bits < 8 or window_bits > 15: - return None + if not _validate_window_bits(window_bits): + return None self._deflater = util._RFC1979Deflater( window_bits, no_context_takeover) @@ -191,7 +191,7 @@ class DeflateFrameExtensionProcessor(ExtensionProcessorInterface): return response - def setup_stream_options(self, stream_options): + def _setup_stream_options_internal(self, stream_options): class _OutgoingFilter(object): @@ -311,8 +311,8 @@ class CompressionExtensionProcessorBase(ExtensionProcessorInterface): _METHOD_PARAM = 'method' def __init__(self, request): + ExtensionProcessorInterface.__init__(self, request) self._logger = util.get_class_logger(self) - self._request = request self._compression_method_name = None self._compression_processor = None self._compression_processor_hook = None @@ -357,7 +357,7 @@ class CompressionExtensionProcessorBase(ExtensionProcessorInterface): self._compression_processor = compression_processor return processor_response - def get_extension_response(self): + def _get_extension_response_internal(self): processor_response = self._get_compression_processor_response() if processor_response is None: return None @@ -372,7 +372,7 @@ class CompressionExtensionProcessorBase(ExtensionProcessorInterface): (self._request.name(), self._compression_method_name)) return response - def setup_stream_options(self, stream_options): + def _setup_stream_options_internal(self, stream_options): if self._compression_processor is None: return self._compression_processor.setup_stream_options(stream_options) @@ -418,7 +418,7 @@ class DeflateMessageProcessor(ExtensionProcessorInterface): _C2S_NO_CONTEXT_TAKEOVER_PARAM = 'c2s_no_context_takeover' def __init__(self, request): - self._request = request + ExtensionProcessorInterface.__init__(self, request) self._logger = util.get_class_logger(self) self._c2s_max_window_bits = None @@ -445,18 +445,13 @@ class DeflateMessageProcessor(ExtensionProcessorInterface): def name(self): return 'deflate' - def get_extension_response(self): + def _get_extension_response_internal(self): # Any unknown parameter will be just ignored. s2c_max_window_bits = self._request.get_parameter_value( self._S2C_MAX_WINDOW_BITS_PARAM) - if s2c_max_window_bits is not None: - try: - s2c_max_window_bits = int(s2c_max_window_bits) - except ValueError, e: - return None - if s2c_max_window_bits < 8 or s2c_max_window_bits > 15: - return None + if not _validate_window_bits(s2c_max_window_bits): + return None s2c_no_context_takeover = self._request.has_parameter( self._S2C_NO_CONTEXT_TAKEOVER_PARAM) @@ -502,7 +497,7 @@ class DeflateMessageProcessor(ExtensionProcessorInterface): return response - def setup_stream_options(self, stream_options): + def _setup_stream_options_internal(self, stream_options): class _OutgoingMessageFilter(object): def __init__(self, parent): @@ -676,42 +671,72 @@ class MuxExtensionProcessor(ExtensionProcessorInterface): _QUOTA_PARAM = 'quota' def __init__(self, request): - self._request = request + ExtensionProcessorInterface.__init__(self, request) + self._quota = 0 + self._extensions = [] def name(self): return common.MUX_EXTENSION - def get_extension_response(self, ws_request, - logical_channel_extensions): - # Mux extension cannot be used after extensions that depend on - # frame boundary, extension data field, or any reserved bits - # which are attributed to each frame. - for extension in logical_channel_extensions: - name = extension.name() - if (name == common.PERFRAME_COMPRESSION_EXTENSION or - name == common.DEFLATE_FRAME_EXTENSION or - name == common.X_WEBKIT_DEFLATE_FRAME_EXTENSION): - return None - + def check_consistency_with_other_processors(self, processors): + before_mux = True + for processor in processors: + name = processor.name() + if name == self.name(): + before_mux = False + continue + if not processor.is_active(): + continue + if before_mux: + # Mux extension cannot be used after extensions + # that depend on frame boundary, extension data field, or any + # reserved bits which are attributed to each frame. + if (name == common.PERFRAME_COMPRESSION_EXTENSION or + name == common.DEFLATE_FRAME_EXTENSION or + name == common.X_WEBKIT_DEFLATE_FRAME_EXTENSION): + self.set_active(False) + return + else: + # Mux extension should not be applied before any history-based + # compression extension. + if (name == common.PERFRAME_COMPRESSION_EXTENSION or + name == common.DEFLATE_FRAME_EXTENSION or + name == common.X_WEBKIT_DEFLATE_FRAME_EXTENSION or + name == common.PERMESSAGE_COMPRESSION_EXTENSION or + name == common.X_WEBKIT_PERMESSAGE_COMPRESSION_EXTENSION): + self.set_active(False) + return + + def _get_extension_response_internal(self): + self._active = False quota = self._request.get_parameter_value(self._QUOTA_PARAM) - if quota is None: - ws_request.mux_quota = 0 - else: + if quota is not None: try: quota = int(quota) except ValueError, e: return None if quota < 0 or quota >= 2 ** 32: return None - ws_request.mux_quota = quota + self._quota = quota - ws_request.mux = True - ws_request.mux_extensions = logical_channel_extensions + self._active = True return common.ExtensionParameter(common.MUX_EXTENSION) - def setup_stream_options(self, stream_options): + def _setup_stream_options_internal(self, stream_options): pass + def set_quota(self, quota): + self._quota = quota + + def quota(self): + return self._quota + + def set_extensions(self, extensions): + self._extensions = extensions + + def extensions(self): + return self._extensions + _available_processors[common.MUX_EXTENSION] = MuxExtensionProcessor diff --git a/pyload/lib/mod_pywebsocket/handshake/_base.py b/pyload/lib/mod_pywebsocket/handshake/_base.py index e5c94ca90..c993a584b 100644 --- a/pyload/lib/mod_pywebsocket/handshake/_base.py +++ b/pyload/lib/mod_pywebsocket/handshake/_base.py @@ -84,42 +84,29 @@ def get_default_port(is_secure): return common.DEFAULT_WEB_SOCKET_PORT -def validate_subprotocol(subprotocol, hixie): +def validate_subprotocol(subprotocol): """Validate a value in the Sec-WebSocket-Protocol field. - See - - RFC 6455: Section 4.1., 4.2.2., and 4.3. - - HyBi 00: Section 4.1. Opening handshake - - Args: - hixie: if True, checks if characters in subprotocol are in range - between U+0020 and U+007E. It's required by HyBi 00 but not by - RFC 6455. + See the Section 4.1., 4.2.2., and 4.3. of RFC 6455. """ if not subprotocol: raise HandshakeException('Invalid subprotocol name: empty') - if hixie: - # Parameter should be in the range U+0020 to U+007E. - for c in subprotocol: - if not 0x20 <= ord(c) <= 0x7e: - raise HandshakeException( - 'Illegal character in subprotocol name: %r' % c) - else: - # Parameter should be encoded HTTP token. - state = http_header_util.ParsingState(subprotocol) - token = http_header_util.consume_token(state) - rest = http_header_util.peek(state) - # If |rest| is not None, |subprotocol| is not one token or invalid. If - # |rest| is None, |token| must not be None because |subprotocol| is - # concatenation of |token| and |rest| and is not None. - if rest is not None: - raise HandshakeException('Invalid non-token string in subprotocol ' - 'name: %r' % rest) + + # Parameter should be encoded HTTP token. + state = http_header_util.ParsingState(subprotocol) + token = http_header_util.consume_token(state) + rest = http_header_util.peek(state) + # If |rest| is not None, |subprotocol| is not one token or invalid. If + # |rest| is None, |token| must not be None because |subprotocol| is + # concatenation of |token| and |rest| and is not None. + if rest is not None: + raise HandshakeException('Invalid non-token string in subprotocol ' + 'name: %r' % rest) def parse_host_header(request): - fields = request.headers_in['Host'].split(':', 1) + fields = request.headers_in[common.HOST_HEADER].split(':', 1) if len(fields) == 1: return fields[0], get_default_port(request.is_https()) try: @@ -132,27 +119,6 @@ def format_header(name, value): return '%s: %s\r\n' % (name, value) -def build_location(request): - """Build WebSocket location for request.""" - location_parts = [] - if request.is_https(): - location_parts.append(common.WEB_SOCKET_SECURE_SCHEME) - else: - location_parts.append(common.WEB_SOCKET_SCHEME) - location_parts.append('://') - host, port = parse_host_header(request) - connection_port = request.connection.local_addr[1] - if port != connection_port: - raise HandshakeException('Header/connection port mismatch: %d/%d' % - (port, connection_port)) - location_parts.append(host) - if (port != get_default_port(request.is_https())): - location_parts.append(':') - location_parts.append(str(port)) - location_parts.append(request.uri) - return ''.join(location_parts) - - def get_mandatory_header(request, key): value = request.headers_in.get(key) if value is None: @@ -180,16 +146,6 @@ def check_request_line(request): request.protocol) -def check_header_lines(request, mandatory_headers): - check_request_line(request) - - # The expected field names, and the meaning of their corresponding - # values, are as follows. - # |Upgrade| and |Connection| - for key, expected_value in mandatory_headers: - validate_mandatory_header(request, key, expected_value) - - def parse_token_list(data): """Parses a header value which follows 1#token and returns parsed elements as a list of strings. diff --git a/pyload/lib/mod_pywebsocket/handshake/hybi.py b/pyload/lib/mod_pywebsocket/handshake/hybi.py index fc0e2a096..669097d77 100644 --- a/pyload/lib/mod_pywebsocket/handshake/hybi.py +++ b/pyload/lib/mod_pywebsocket/handshake/hybi.py @@ -48,6 +48,7 @@ import os import re from mod_pywebsocket import common +from mod_pywebsocket import deflate_stream_extension from mod_pywebsocket.extensions import get_extension_processor from mod_pywebsocket.handshake._base import check_request_line from mod_pywebsocket.handshake._base import format_header @@ -180,44 +181,57 @@ class Handshaker(object): processors.append(processor) self._request.ws_extension_processors = processors + # List of extra headers. The extra handshake handler may add header + # data as name/value pairs to this list and pywebsocket appends + # them to the WebSocket handshake. + self._request.extra_headers = [] + # Extra handshake handler may modify/remove processors. self._dispatcher.do_extra_handshake(self._request) processors = filter(lambda processor: processor is not None, self._request.ws_extension_processors) + # Ask each processor if there are extensions on the request which + # cannot co-exist. When processor decided other processors cannot + # co-exist with it, the processor marks them (or itself) as + # "inactive". The first extension processor has the right to + # make the final call. + for processor in reversed(processors): + if processor.is_active(): + processor.check_consistency_with_other_processors( + processors) + processors = filter(lambda processor: processor.is_active(), + processors) + accepted_extensions = [] - # We need to take care of mux extension here. Extensions that - # are placed before mux should be applied to logical channels. + # We need to take into account of mux extension here. + # If mux extension exists: + # - Remove processors of extensions for logical channel, + # which are processors located before the mux processor + # - Pass extension requests for logical channel to mux processor + # - Attach the mux processor to the request. It will be referred + # by dispatcher to see whether the dispatcher should use mux + # handler or not. mux_index = -1 for i, processor in enumerate(processors): if processor.name() == common.MUX_EXTENSION: mux_index = i break if mux_index >= 0: - mux_processor = processors[mux_index] - logical_channel_processors = processors[:mux_index] - processors = processors[mux_index+1:] - - for processor in logical_channel_processors: - extension_response = processor.get_extension_response() - if extension_response is None: - # Rejected. - continue - accepted_extensions.append(extension_response) - # Pass a shallow copy of accepted_extensions as extensions for - # logical channels. - mux_response = mux_processor.get_extension_response( - self._request, accepted_extensions[:]) - if mux_response is not None: - accepted_extensions.append(mux_response) + logical_channel_extensions = [] + for processor in processors[:mux_index]: + logical_channel_extensions.append(processor.request()) + processor.set_active(False) + self._request.mux_processor = processors[mux_index] + self._request.mux_processor.set_extensions( + logical_channel_extensions) + processors = filter(lambda processor: processor.is_active(), + processors) stream_options = StreamOptions() - # When there is mux extension, here, |processors| contain only - # prosessors for extensions placed after mux. for processor in processors: - extension_response = processor.get_extension_response() if extension_response is None: # Rejected. @@ -242,7 +256,7 @@ class Handshaker(object): raise HandshakeException( 'do_extra_handshake must choose one subprotocol from ' 'ws_requested_protocols and set it to ws_protocol') - validate_subprotocol(self._request.ws_protocol, hixie=False) + validate_subprotocol(self._request.ws_protocol) self._logger.debug( 'Subprotocol accepted: %r', @@ -375,6 +389,7 @@ class Handshaker(object): response.append('HTTP/1.1 101 Switching Protocols\r\n') + # WebSocket headers response.append(format_header( common.UPGRADE_HEADER, common.WEBSOCKET_UPGRADE_TYPE)) response.append(format_header( @@ -390,6 +405,11 @@ class Handshaker(object): response.append(format_header( common.SEC_WEBSOCKET_EXTENSIONS_HEADER, common.format_extensions(self._request.ws_extensions))) + + # Headers not specific for WebSocket + for name, value in self._request.extra_headers: + response.append(format_header(name, value)) + response.append('\r\n') return ''.join(response) diff --git a/pyload/lib/mod_pywebsocket/handshake/hybi00.py b/pyload/lib/mod_pywebsocket/handshake/hybi00.py index cc6f8dc43..8757717a6 100644 --- a/pyload/lib/mod_pywebsocket/handshake/hybi00.py +++ b/pyload/lib/mod_pywebsocket/handshake/hybi00.py @@ -51,11 +51,12 @@ from mod_pywebsocket import common from mod_pywebsocket.stream import StreamHixie75 from mod_pywebsocket import util from mod_pywebsocket.handshake._base import HandshakeException -from mod_pywebsocket.handshake._base import build_location -from mod_pywebsocket.handshake._base import check_header_lines +from mod_pywebsocket.handshake._base import check_request_line from mod_pywebsocket.handshake._base import format_header +from mod_pywebsocket.handshake._base import get_default_port from mod_pywebsocket.handshake._base import get_mandatory_header -from mod_pywebsocket.handshake._base import validate_subprotocol +from mod_pywebsocket.handshake._base import parse_host_header +from mod_pywebsocket.handshake._base import validate_mandatory_header _MANDATORY_HEADERS = [ @@ -65,6 +66,56 @@ _MANDATORY_HEADERS = [ ] +def _validate_subprotocol(subprotocol): + """Checks if characters in subprotocol are in range between U+0020 and + U+007E. A value in the Sec-WebSocket-Protocol field need to satisfy this + requirement. + + See the Section 4.1. Opening handshake of the spec. + """ + + if not subprotocol: + raise HandshakeException('Invalid subprotocol name: empty') + + # Parameter should be in the range U+0020 to U+007E. + for c in subprotocol: + if not 0x20 <= ord(c) <= 0x7e: + raise HandshakeException( + 'Illegal character in subprotocol name: %r' % c) + + +def _check_header_lines(request, mandatory_headers): + check_request_line(request) + + # The expected field names, and the meaning of their corresponding + # values, are as follows. + # |Upgrade| and |Connection| + for key, expected_value in mandatory_headers: + validate_mandatory_header(request, key, expected_value) + + +def _build_location(request): + """Build WebSocket location for request.""" + + location_parts = [] + if request.is_https(): + location_parts.append(common.WEB_SOCKET_SECURE_SCHEME) + else: + location_parts.append(common.WEB_SOCKET_SCHEME) + location_parts.append('://') + host, port = parse_host_header(request) + connection_port = request.connection.local_addr[1] + if port != connection_port: + raise HandshakeException('Header/connection port mismatch: %d/%d' % + (port, connection_port)) + location_parts.append(host) + if (port != get_default_port(request.is_https())): + location_parts.append(':') + location_parts.append(str(port)) + location_parts.append(request.unparsed_uri) + return ''.join(location_parts) + + class Handshaker(object): """Opening handshake processor for the WebSocket protocol version HyBi 00. """ @@ -101,7 +152,7 @@ class Handshaker(object): # 5.1 Reading the client's opening handshake. # dispatcher sets it in self._request. - check_header_lines(self._request, _MANDATORY_HEADERS) + _check_header_lines(self._request, _MANDATORY_HEADERS) self._set_resource() self._set_subprotocol() self._set_location() @@ -121,14 +172,14 @@ class Handshaker(object): subprotocol = self._request.headers_in.get( common.SEC_WEBSOCKET_PROTOCOL_HEADER) if subprotocol is not None: - validate_subprotocol(subprotocol, hixie=True) + _validate_subprotocol(subprotocol) self._request.ws_protocol = subprotocol def _set_location(self): # |Host| host = self._request.headers_in.get(common.HOST_HEADER) if host is not None: - self._request.ws_location = build_location(self._request) + self._request.ws_location = _build_location(self._request) # TODO(ukai): check host is this host. def _set_origin(self): diff --git a/pyload/lib/mod_pywebsocket/headerparserhandler.py b/pyload/lib/mod_pywebsocket/headerparserhandler.py index 2cc62de04..c244421cf 100644 --- a/pyload/lib/mod_pywebsocket/headerparserhandler.py +++ b/pyload/lib/mod_pywebsocket/headerparserhandler.py @@ -167,7 +167,9 @@ def _create_dispatcher(): handler_root, handler_scan, allow_handlers_outside_root) for warning in dispatcher.source_warnings(): - apache.log_error('mod_pywebsocket: %s' % warning, apache.APLOG_WARNING) + apache.log_error( + 'mod_pywebsocket: Warning in source loading: %s' % warning, + apache.APLOG_WARNING) return dispatcher @@ -191,12 +193,16 @@ def headerparserhandler(request): # Fallback to default http handler for request paths for which # we don't have request handlers. if not _dispatcher.get_handler_suite(request.uri): - request.log_error('No handler for resource: %r' % request.uri, - apache.APLOG_INFO) - request.log_error('Fallback to Apache', apache.APLOG_INFO) + request.log_error( + 'mod_pywebsocket: No handler for resource: %r' % request.uri, + apache.APLOG_INFO) + request.log_error( + 'mod_pywebsocket: Fallback to Apache', apache.APLOG_INFO) return apache.DECLINED except dispatch.DispatchException, e: - request.log_error('mod_pywebsocket: %s' % e, apache.APLOG_INFO) + request.log_error( + 'mod_pywebsocket: Dispatch failed for error: %s' % e, + apache.APLOG_INFO) if not handshake_is_done: return e.status @@ -210,26 +216,30 @@ def headerparserhandler(request): handshake.do_handshake( request, _dispatcher, allowDraft75=allow_draft75) except handshake.VersionException, e: - request.log_error('mod_pywebsocket: %s' % e, apache.APLOG_INFO) + request.log_error( + 'mod_pywebsocket: Handshake failed for version error: %s' % e, + apache.APLOG_INFO) request.err_headers_out.add(common.SEC_WEBSOCKET_VERSION_HEADER, e.supported_versions) return apache.HTTP_BAD_REQUEST except handshake.HandshakeException, e: # Handshake for ws/wss failed. # Send http response with error status. - request.log_error('mod_pywebsocket: %s' % e, apache.APLOG_INFO) + request.log_error( + 'mod_pywebsocket: Handshake failed for error: %s' % e, + apache.APLOG_INFO) return e.status handshake_is_done = True request._dispatcher = _dispatcher _dispatcher.transfer_data(request) except handshake.AbortedByUserException, e: - request.log_error('mod_pywebsocket: %s' % e, apache.APLOG_INFO) + request.log_error('mod_pywebsocket: Aborted: %s' % e, apache.APLOG_INFO) except Exception, e: # DispatchException can also be thrown if something is wrong in # pywebsocket code. It's caught here, then. - request.log_error('mod_pywebsocket: %s\n%s' % + request.log_error('mod_pywebsocket: Exception occurred: %s\n%s' % (e, util.get_stack_trace()), apache.APLOG_ERR) # Unknown exceptions before handshake mean Apache must handle its diff --git a/pyload/lib/mod_pywebsocket/mux.py b/pyload/lib/mod_pywebsocket/mux.py index f0bdd2461..7923fb211 100644 --- a/pyload/lib/mod_pywebsocket/mux.py +++ b/pyload/lib/mod_pywebsocket/mux.py @@ -50,6 +50,7 @@ from mod_pywebsocket import handshake from mod_pywebsocket import util from mod_pywebsocket._stream_base import BadOperationException from mod_pywebsocket._stream_base import ConnectionTerminatedException +from mod_pywebsocket._stream_base import InvalidFrameException from mod_pywebsocket._stream_hybi import Frame from mod_pywebsocket._stream_hybi import Stream from mod_pywebsocket._stream_hybi import StreamOptions @@ -94,10 +95,12 @@ _DROP_CODE_UNKNOWN_MUX_OPCODE = 2004 _DROP_CODE_INVALID_MUX_CONTROL_BLOCK = 2005 _DROP_CODE_CHANNEL_ALREADY_EXISTS = 2006 _DROP_CODE_NEW_CHANNEL_SLOT_VIOLATION = 2007 +_DROP_CODE_UNKNOWN_REQUEST_ENCODING = 2010 -_DROP_CODE_UNKNOWN_REQUEST_ENCODING = 3002 _DROP_CODE_SEND_QUOTA_VIOLATION = 3005 +_DROP_CODE_SEND_QUOTA_OVERFLOW = 3006 _DROP_CODE_ACKNOWLEDGED = 3008 +_DROP_CODE_BAD_FRAGMENTATION = 3009 class MuxUnexpectedException(Exception): @@ -158,8 +161,7 @@ def _encode_number(number): def _create_add_channel_response(channel_id, encoded_handshake, - encoding=0, rejected=False, - outer_frame_mask=False): + encoding=0, rejected=False): if encoding != 0 and encoding != 1: raise ValueError('Invalid encoding %d' % encoding) @@ -169,12 +171,10 @@ def _create_add_channel_response(channel_id, encoded_handshake, _encode_channel_id(channel_id) + _encode_number(len(encoded_handshake)) + encoded_handshake) - payload = _encode_channel_id(_CONTROL_CHANNEL_ID) + block - return create_binary_frame(payload, mask=outer_frame_mask) + return block -def _create_drop_channel(channel_id, code=None, message='', - outer_frame_mask=False): +def _create_drop_channel(channel_id, code=None, message=''): if len(message) > 0 and code is None: raise ValueError('Code must be specified if message is specified') @@ -187,36 +187,31 @@ def _create_drop_channel(channel_id, code=None, message='', reason_size = _encode_number(len(reason)) block += reason_size + reason - payload = _encode_channel_id(_CONTROL_CHANNEL_ID) + block - return create_binary_frame(payload, mask=outer_frame_mask) + return block -def _create_flow_control(channel_id, replenished_quota, - outer_frame_mask=False): +def _create_flow_control(channel_id, replenished_quota): first_byte = _MUX_OPCODE_FLOW_CONTROL << 5 block = (chr(first_byte) + _encode_channel_id(channel_id) + _encode_number(replenished_quota)) - payload = _encode_channel_id(_CONTROL_CHANNEL_ID) + block - return create_binary_frame(payload, mask=outer_frame_mask) + return block -def _create_new_channel_slot(slots, send_quota, outer_frame_mask=False): +def _create_new_channel_slot(slots, send_quota): if slots < 0 or send_quota < 0: raise ValueError('slots and send_quota must be non-negative.') first_byte = _MUX_OPCODE_NEW_CHANNEL_SLOT << 5 block = (chr(first_byte) + _encode_number(slots) + _encode_number(send_quota)) - payload = _encode_channel_id(_CONTROL_CHANNEL_ID) + block - return create_binary_frame(payload, mask=outer_frame_mask) + return block -def _create_fallback_new_channel_slot(outer_frame_mask=False): +def _create_fallback_new_channel_slot(): first_byte = (_MUX_OPCODE_NEW_CHANNEL_SLOT << 5) | 1 # Set the F flag block = (chr(first_byte) + _encode_number(0) + _encode_number(0)) - payload = _encode_channel_id(_CONTROL_CHANNEL_ID) + block - return create_binary_frame(payload, mask=outer_frame_mask) + return block def _parse_request_text(request_text): @@ -318,44 +313,34 @@ class _MuxFramePayloadParser(object): def _read_number(self): if self._read_position + 1 > len(self._data): - raise PhysicalConnectionError( - _DROP_CODE_INVALID_MUX_CONTROL_BLOCK, + raise ValueError( 'Cannot read the first byte of number field') number = ord(self._data[self._read_position]) if number & 0x80 == 0x80: - raise PhysicalConnectionError( - _DROP_CODE_INVALID_MUX_CONTROL_BLOCK, + raise ValueError( 'The most significant bit of the first byte of number should ' 'be unset') self._read_position += 1 pos = self._read_position if number == 127: if pos + 8 > len(self._data): - raise PhysicalConnectionError( - _DROP_CODE_INVALID_MUX_CONTROL_BLOCK, - 'Invalid number field') + raise ValueError('Invalid number field') self._read_position += 8 number = struct.unpack('!Q', self._data[pos:pos+8])[0] if number > 0x7FFFFFFFFFFFFFFF: - raise PhysicalConnectionError( - _DROP_CODE_INVALID_MUX_CONTROL_BLOCK, - 'Encoded number >= 2^63') + raise ValueError('Encoded number(%d) >= 2^63' % number) if number <= 0xFFFF: - raise PhysicalConnectionError( - _DROP_CODE_INVALID_MUX_CONTROL_BLOCK, + raise ValueError( '%d should not be encoded by 9 bytes encoding' % number) return number if number == 126: if pos + 2 > len(self._data): - raise PhysicalConnectionError( - _DROP_CODE_INVALID_MUX_CONTROL_BLOCK, - 'Invalid number field') + raise ValueError('Invalid number field') self._read_position += 2 number = struct.unpack('!H', self._data[pos:pos+2])[0] if number <= 125: - raise PhysicalConnectionError( - _DROP_CODE_INVALID_MUX_CONTROL_BLOCK, + raise ValueError( '%d should not be encoded by 3 bytes encoding' % number) return number @@ -366,7 +351,11 @@ class _MuxFramePayloadParser(object): - the contents. """ - size = self._read_number() + try: + size = self._read_number() + except ValueError, e: + raise PhysicalConnectionError(_DROP_CODE_INVALID_MUX_CONTROL_BLOCK, + str(e)) pos = self._read_position if pos + size > len(self._data): raise PhysicalConnectionError( @@ -419,9 +408,11 @@ class _MuxFramePayloadParser(object): try: control_block.channel_id = self.read_channel_id() + control_block.send_quota = self._read_number() except ValueError, e: - raise PhysicalConnectionError(_DROP_CODE_INVALID_MUX_CONTROL_BLOCK) - control_block.send_quota = self._read_number() + raise PhysicalConnectionError(_DROP_CODE_INVALID_MUX_CONTROL_BLOCK, + str(e)) + return control_block def _read_drop_channel(self, first_byte, control_block): @@ -455,8 +446,12 @@ class _MuxFramePayloadParser(object): _DROP_CODE_INVALID_MUX_CONTROL_BLOCK, 'Reserved bits must be unset') control_block.fallback = first_byte & 1 - control_block.slots = self._read_number() - control_block.send_quota = self._read_number() + try: + control_block.slots = self._read_number() + control_block.send_quota = self._read_number() + except ValueError, e: + raise PhysicalConnectionError(_DROP_CODE_INVALID_MUX_CONTROL_BLOCK, + str(e)) return control_block def read_control_blocks(self): @@ -549,8 +544,12 @@ class _LogicalConnection(object): self._mux_handler = mux_handler self._channel_id = channel_id self._incoming_data = '' + + # - Protects _waiting_write_completion + # - Signals the thread waiting for completion of write by mux handler self._write_condition = threading.Condition() self._waiting_write_completion = False + self._read_condition = threading.Condition() self._read_state = self.STATE_ACTIVE @@ -594,6 +593,7 @@ class _LogicalConnection(object): self._waiting_write_completion = True self._mux_handler.send_data(self._channel_id, data) self._write_condition.wait() + # TODO(tyoshino): Raise an exception if woke up by on_writer_done. finally: self._write_condition.release() @@ -607,20 +607,31 @@ class _LogicalConnection(object): self._mux_handler.send_control_data(data) - def notify_write_done(self): + def on_write_data_done(self): """Called when sending data is completed.""" try: self._write_condition.acquire() if not self._waiting_write_completion: raise MuxUnexpectedException( - 'Invalid call of notify_write_done for logical connection' - ' %d' % self._channel_id) + 'Invalid call of on_write_data_done for logical ' + 'connection %d' % self._channel_id) + self._waiting_write_completion = False + self._write_condition.notify() + finally: + self._write_condition.release() + + def on_writer_done(self): + """Called by the mux handler when the writer thread has finished.""" + + try: + self._write_condition.acquire() self._waiting_write_completion = False self._write_condition.notify() finally: self._write_condition.release() + def append_frame_data(self, frame_data): """Appends incoming frame data. Called when mux_handler dispatches frame data to the corresponding application. @@ -686,37 +697,162 @@ class _LogicalConnection(object): self._read_condition.release() +class _InnerMessage(object): + """Holds the result of _InnerMessageBuilder.build(). + """ + + def __init__(self, opcode, payload): + self.opcode = opcode + self.payload = payload + + +class _InnerMessageBuilder(object): + """A class that holds the context of inner message fragmentation and + builds a message from fragmented inner frame(s). + """ + + def __init__(self): + self._control_opcode = None + self._pending_control_fragments = [] + self._message_opcode = None + self._pending_message_fragments = [] + self._frame_handler = self._handle_first + + def _handle_first(self, frame): + if frame.opcode == common.OPCODE_CONTINUATION: + raise InvalidFrameException('Sending invalid continuation opcode') + + if common.is_control_opcode(frame.opcode): + return self._process_first_fragmented_control(frame) + else: + return self._process_first_fragmented_message(frame) + + def _process_first_fragmented_control(self, frame): + self._control_opcode = frame.opcode + self._pending_control_fragments.append(frame.payload) + if not frame.fin: + self._frame_handler = self._handle_fragmented_control + return None + return self._reassemble_fragmented_control() + + def _process_first_fragmented_message(self, frame): + self._message_opcode = frame.opcode + self._pending_message_fragments.append(frame.payload) + if not frame.fin: + self._frame_handler = self._handle_fragmented_message + return None + return self._reassemble_fragmented_message() + + def _handle_fragmented_control(self, frame): + if frame.opcode != common.OPCODE_CONTINUATION: + raise InvalidFrameException( + 'Sending invalid opcode %d while sending fragmented control ' + 'message' % frame.opcode) + self._pending_control_fragments.append(frame.payload) + if not frame.fin: + return None + return self._reassemble_fragmented_control() + + def _reassemble_fragmented_control(self): + opcode = self._control_opcode + payload = ''.join(self._pending_control_fragments) + self._control_opcode = None + self._pending_control_fragments = [] + if self._message_opcode is not None: + self._frame_handler = self._handle_fragmented_message + else: + self._frame_handler = self._handle_first + return _InnerMessage(opcode, payload) + + def _handle_fragmented_message(self, frame): + # Sender can interleave a control message while sending fragmented + # messages. + if common.is_control_opcode(frame.opcode): + if self._control_opcode is not None: + raise MuxUnexpectedException( + 'Should not reach here(Bug in builder)') + return self._process_first_fragmented_control(frame) + + if frame.opcode != common.OPCODE_CONTINUATION: + raise InvalidFrameException( + 'Sending invalid opcode %d while sending fragmented message' % + frame.opcode) + self._pending_message_fragments.append(frame.payload) + if not frame.fin: + return None + return self._reassemble_fragmented_message() + + def _reassemble_fragmented_message(self): + opcode = self._message_opcode + payload = ''.join(self._pending_message_fragments) + self._message_opcode = None + self._pending_message_fragments = [] + self._frame_handler = self._handle_first + return _InnerMessage(opcode, payload) + + def build(self, frame): + """Build an inner message. Returns an _InnerMessage instance when + the given frame is the last fragmented frame. Returns None otherwise. + + Args: + frame: an inner frame. + Raises: + InvalidFrameException: when received invalid opcode. (e.g. + receiving non continuation data opcode but the fin flag of + the previous inner frame was not set.) + """ + + return self._frame_handler(frame) + + class _LogicalStream(Stream): """Mimics the Stream class. This class interprets multiplexed WebSocket frames. """ - def __init__(self, request, send_quota, receive_quota): + def __init__(self, request, stream_options, send_quota, receive_quota): """Constructs an instance. Args: request: _LogicalRequest instance. + stream_options: StreamOptions instance. send_quota: Initial send quota. receive_quota: Initial receive quota. """ - # TODO(bashi): Support frame filters. - stream_options = StreamOptions() # Physical stream is responsible for masking. stream_options.unmask_receive = False - # Control frames can be fragmented on logical channel. - stream_options.allow_fragmented_control_frame = True Stream.__init__(self, request, stream_options) + + self._send_closed = False self._send_quota = send_quota - self._send_quota_condition = threading.Condition() + # - Protects _send_closed and _send_quota + # - Signals the thread waiting for send quota replenished + self._send_condition = threading.Condition() + + # The opcode of the first frame in messages. + self._message_opcode = common.OPCODE_TEXT + # True when the last message was fragmented. + self._last_message_was_fragmented = False + self._receive_quota = receive_quota self._write_inner_frame_semaphore = threading.Semaphore() + self._inner_message_builder = _InnerMessageBuilder() + def _create_inner_frame(self, opcode, payload, end=True): - # TODO(bashi): Support extensions that use reserved bits. - first_byte = (end << 7) | opcode - return (_encode_channel_id(self._request.channel_id) + - chr(first_byte) + payload) + frame = Frame(fin=end, opcode=opcode, payload=payload) + for frame_filter in self._options.outgoing_frame_filters: + frame_filter.filter(frame) + + if len(payload) != len(frame.payload): + raise MuxUnexpectedException( + 'Mux extension must not be used after extensions which change ' + ' frame boundary') + + first_byte = ((frame.fin << 7) | (frame.rsv1 << 6) | + (frame.rsv2 << 5) | (frame.rsv3 << 4) | frame.opcode) + return chr(first_byte) + frame.payload def _write_inner_frame(self, opcode, payload, end=True): payload_length = len(payload) @@ -730,14 +866,36 @@ class _LogicalStream(Stream): # multiplexing control blocks can be inserted between fragmented # inner frames on the physical channel. self._write_inner_frame_semaphore.acquire() + + # Consume an octet quota when this is the first fragmented frame. + if opcode != common.OPCODE_CONTINUATION: + try: + self._send_condition.acquire() + while (not self._send_closed) and self._send_quota == 0: + self._send_condition.wait() + + if self._send_closed: + raise BadOperationException( + 'Logical connection %d is closed' % + self._request.channel_id) + + self._send_quota -= 1 + finally: + self._send_condition.release() + while write_position < payload_length: try: - self._send_quota_condition.acquire() - while self._send_quota == 0: + self._send_condition.acquire() + while (not self._send_closed) and self._send_quota == 0: self._logger.debug( 'No quota. Waiting FlowControl message for %d.' % self._request.channel_id) - self._send_quota_condition.wait() + self._send_condition.wait() + + if self._send_closed: + raise BadOperationException( + 'Logical connection %d is closed' % + self.request._channel_id) remaining = payload_length - write_position write_length = min(self._send_quota, remaining) @@ -749,18 +907,16 @@ class _LogicalStream(Stream): opcode, payload[write_position:write_position+write_length], inner_frame_end) - frame_data = self._writer.build( - inner_frame, end=True, binary=True) self._send_quota -= write_length self._logger.debug('Consumed quota=%d, remaining=%d' % (write_length, self._send_quota)) finally: - self._send_quota_condition.release() + self._send_condition.release() # Writing data will block the worker so we need to release - # _send_quota_condition before writing. - self._logger.debug('Sending inner frame: %r' % frame_data) - self._request.connection.write(frame_data) + # _send_condition before writing. + self._logger.debug('Sending inner frame: %r' % inner_frame) + self._request.connection.write(inner_frame) write_position += write_length opcode = common.OPCODE_CONTINUATION @@ -773,12 +929,18 @@ class _LogicalStream(Stream): def replenish_send_quota(self, send_quota): """Replenish send quota.""" - self._send_quota_condition.acquire() - self._send_quota += send_quota - self._logger.debug('Replenished send quota for channel id %d: %d' % - (self._request.channel_id, self._send_quota)) - self._send_quota_condition.notify() - self._send_quota_condition.release() + try: + self._send_condition.acquire() + if self._send_quota + send_quota > 0x7FFFFFFFFFFFFFFF: + self._send_quota = 0 + raise LogicalChannelError( + self._request.channel_id, _DROP_CODE_SEND_QUOTA_OVERFLOW) + self._send_quota += send_quota + self._logger.debug('Replenished send quota for channel id %d: %d' % + (self._request.channel_id, self._send_quota)) + finally: + self._send_condition.notify() + self._send_condition.release() def consume_receive_quota(self, amount): """Consumes receive quota. Returns False on failure.""" @@ -808,7 +970,19 @@ class _LogicalStream(Stream): opcode = common.OPCODE_TEXT message = message.encode('utf-8') + for message_filter in self._options.outgoing_message_filters: + message = message_filter.filter(message, end, binary) + + if self._last_message_was_fragmented: + if opcode != self._message_opcode: + raise BadOperationException('Message types are different in ' + 'frames for the same message') + opcode = common.OPCODE_CONTINUATION + else: + self._message_opcode = opcode + self._write_inner_frame(opcode, message, end) + self._last_message_was_fragmented = not end def _receive_frame(self): """Overrides Stream._receive_frame. @@ -821,6 +995,9 @@ class _LogicalStream(Stream): opcode, payload, fin, rsv1, rsv2, rsv3 = Stream._receive_frame(self) amount = len(payload) + # Replenish extra one octet when receiving the first fragmented frame. + if opcode != common.OPCODE_CONTINUATION: + amount += 1 self._receive_quota += amount frame_data = _create_flow_control(self._request.channel_id, amount) @@ -829,6 +1006,21 @@ class _LogicalStream(Stream): self._request.connection.write_control_data(frame_data) return opcode, payload, fin, rsv1, rsv2, rsv3 + def _get_message_from_frame(self, frame): + """Overrides Stream._get_message_from_frame. + """ + + try: + inner_message = self._inner_message_builder.build(frame) + except InvalidFrameException: + raise LogicalChannelError( + self._request.channel_id, _DROP_CODE_BAD_FRAGMENTATION) + + if inner_message is None: + return None + self._original_opcode = inner_message.opcode + return inner_message.payload + def receive_message(self): """Overrides Stream.receive_message.""" @@ -882,6 +1074,14 @@ class _LogicalStream(Stream): pass + def stop_sending(self): + """Stops accepting new send operation (_write_inner_frame).""" + + self._send_condition.acquire() + self._send_closed = True + self._send_condition.notify() + self._send_condition.release() + class _OutgoingData(object): """A structure that holds data to be sent via physical connection and @@ -911,8 +1111,17 @@ class _PhysicalConnectionWriter(threading.Thread): self._logger = util.get_class_logger(self) self._mux_handler = mux_handler self.setDaemon(True) + + # When set, make this thread stop accepting new data, flush pending + # data and exit. self._stop_requested = False + # The close code of the physical connection. + self._close_code = common.STATUS_NORMAL_CLOSURE + # Deque for passing write data. It's protected by _deque_condition + # until _stop_requested is set. self._deque = collections.deque() + # - Protects _deque, _stop_requested and _close_code + # - Signals threads waiting for them to be available self._deque_condition = threading.Condition() def put_outgoing_data(self, data): @@ -937,8 +1146,11 @@ class _PhysicalConnectionWriter(threading.Thread): self._deque_condition.release() def _write_data(self, outgoing_data): + message = (_encode_channel_id(outgoing_data.channel_id) + + outgoing_data.data) try: - self._mux_handler.physical_connection.write(outgoing_data.data) + self._mux_handler.physical_stream.send_message( + message=message, end=True, binary=True) except Exception, e: util.prepend_message_to_exception( 'Failed to send message to %r: ' % @@ -948,33 +1160,51 @@ class _PhysicalConnectionWriter(threading.Thread): # TODO(bashi): It would be better to block the thread that sends # control data as well. if outgoing_data.channel_id != _CONTROL_CHANNEL_ID: - self._mux_handler.notify_write_done(outgoing_data.channel_id) + self._mux_handler.notify_write_data_done(outgoing_data.channel_id) def run(self): - self._deque_condition.acquire() - while not self._stop_requested: - if len(self._deque) == 0: - self._deque_condition.wait() - continue - - outgoing_data = self._deque.popleft() - self._deque_condition.release() - self._write_data(outgoing_data) + try: self._deque_condition.acquire() + while not self._stop_requested: + if len(self._deque) == 0: + self._deque_condition.wait() + continue - # Flush deque - try: - while len(self._deque) > 0: outgoing_data = self._deque.popleft() + + self._deque_condition.release() self._write_data(outgoing_data) + self._deque_condition.acquire() + + # Flush deque. + # + # At this point, self._deque_condition is always acquired. + try: + while len(self._deque) > 0: + outgoing_data = self._deque.popleft() + self._write_data(outgoing_data) + finally: + self._deque_condition.release() + + # Close physical connection. + try: + # Don't wait the response here. The response will be read + # by the reader thread. + self._mux_handler.physical_stream.close_connection( + self._close_code, wait_response=False) + except Exception, e: + util.prepend_message_to_exception( + 'Failed to close the physical connection: %r' % e) + raise finally: - self._deque_condition.release() + self._mux_handler.notify_writer_done() - def stop(self): + def stop(self, close_code=common.STATUS_NORMAL_CLOSURE): """Stops the writer thread.""" self._deque_condition.acquire() self._stop_requested = True + self._close_code = close_code self._deque_condition.notify() self._deque_condition.release() @@ -1055,6 +1285,9 @@ class _Worker(threading.Thread): try: # Non-critical exceptions will be handled by dispatcher. self._mux_handler.dispatcher.transfer_data(self._request) + except LogicalChannelError, e: + self._mux_handler.fail_logical_channel( + e.channel_id, e.drop_code, e.message) finally: self._mux_handler.notify_worker_done(self._request.channel_id) @@ -1083,8 +1316,6 @@ class _MuxHandshaker(hybi.Handshaker): # these headers are included already. request.headers_in[common.UPGRADE_HEADER] = ( common.WEBSOCKET_UPGRADE_TYPE) - request.headers_in[common.CONNECTION_HEADER] = ( - common.UPGRADE_CONNECTION_TYPE) request.headers_in[common.SEC_WEBSOCKET_VERSION_HEADER] = ( str(common.VERSION_HYBI_LATEST)) request.headers_in[common.SEC_WEBSOCKET_KEY_HEADER] = ( @@ -1095,8 +1326,9 @@ class _MuxHandshaker(hybi.Handshaker): self._logger.debug('Creating logical stream for %d' % self._request.channel_id) - return _LogicalStream(self._request, self._send_quota, - self._receive_quota) + return _LogicalStream( + self._request, stream_options, self._send_quota, + self._receive_quota) def _create_handshake_response(self, accept): """Override hybi._create_handshake_response.""" @@ -1105,7 +1337,9 @@ class _MuxHandshaker(hybi.Handshaker): response.append('HTTP/1.1 101 Switching Protocols\r\n') - # Upgrade, Connection and Sec-WebSocket-Accept should be excluded. + # Upgrade and Sec-WebSocket-Accept should be excluded. + response.append('%s: %s\r\n' % ( + common.CONNECTION_HEADER, common.UPGRADE_CONNECTION_TYPE)) if self._request.ws_protocol is not None: response.append('%s: %s\r\n' % ( common.SEC_WEBSOCKET_PROTOCOL_HEADER, @@ -1169,8 +1403,6 @@ class _HandshakeDeltaBase(object): del headers[key] else: headers[key] = value - # TODO(bashi): Support extensions - headers['Sec-WebSocket-Extensions'] = '' return headers @@ -1232,8 +1464,12 @@ class _MuxHandler(object): # Create "Implicitly Opened Connection". logical_connection = _LogicalConnection(self, _DEFAULT_CHANNEL_ID) - self._handshake_base = _HandshakeDeltaBase( - self.original_request.headers_in) + headers = copy.copy(self.original_request.headers_in) + # Add extensions for logical channel. + headers[common.SEC_WEBSOCKET_EXTENSIONS_HEADER] = ( + common.format_extensions( + self.original_request.mux_processor.extensions())) + self._handshake_base = _HandshakeDeltaBase(headers) logical_request = _LogicalRequest( _DEFAULT_CHANNEL_ID, self.original_request.method, @@ -1245,8 +1481,9 @@ class _MuxHandler(object): # but we will send FlowControl later so set the initial quota to # _INITIAL_QUOTA_FOR_CLIENT. self._channel_slots.append(_INITIAL_QUOTA_FOR_CLIENT) + send_quota = self.original_request.mux_processor.quota() if not self._do_handshake_for_logical_request( - logical_request, send_quota=self.original_request.mux_quota): + logical_request, send_quota=send_quota): raise MuxUnexpectedException( 'Failed handshake on the default channel id') self._add_logical_channel(logical_request) @@ -1287,7 +1524,6 @@ class _MuxHandler(object): if not self._worker_done_notify_received: self._logger.debug('Waiting worker(s) timed out') return False - finally: self._logical_channels_condition.release() @@ -1297,7 +1533,7 @@ class _MuxHandler(object): return True - def notify_write_done(self, channel_id): + def notify_write_data_done(self, channel_id): """Called by the writer thread when a write operation has done. Args: @@ -1308,7 +1544,7 @@ class _MuxHandler(object): self._logical_channels_condition.acquire() if channel_id in self._logical_channels: channel_data = self._logical_channels[channel_id] - channel_data.request.connection.notify_write_done() + channel_data.request.connection.on_write_data_done() else: self._logger.debug('Seems that logical channel for %d has gone' % channel_id) @@ -1469,9 +1705,11 @@ class _MuxHandler(object): return channel_data = self._logical_channels[block.channel_id] channel_data.drop_code = _DROP_CODE_ACKNOWLEDGED + # Close the logical channel channel_data.request.connection.set_read_state( _LogicalConnection.STATE_TERMINATED) + channel_data.request.ws_stream.stop_sending() finally: self._logical_channels_condition.release() @@ -1506,8 +1744,11 @@ class _MuxHandler(object): return channel_data = self._logical_channels[channel_id] fin, rsv1, rsv2, rsv3, opcode, payload = parser.read_inner_frame() + consuming_byte = len(payload) + if opcode != common.OPCODE_CONTINUATION: + consuming_byte += 1 if not channel_data.request.ws_stream.consume_receive_quota( - len(payload)): + consuming_byte): # The client violates quota. Close logical channel. raise LogicalChannelError( channel_id, _DROP_CODE_SEND_QUOTA_VIOLATION) @@ -1569,15 +1810,32 @@ class _MuxHandler(object): finished. """ - # Terminate all logical connections - self._logger.debug('termiating all logical connections...') + self._logger.debug( + 'Termiating all logical connections waiting for incoming data ' + '...') self._logical_channels_condition.acquire() for channel_data in self._logical_channels.values(): try: channel_data.request.connection.set_read_state( _LogicalConnection.STATE_TERMINATED) except Exception: - pass + self._logger.debug(traceback.format_exc()) + self._logical_channels_condition.release() + + def notify_writer_done(self): + """This method is called by the writer thread when the writer has + finished. + """ + + self._logger.debug( + 'Termiating all logical connections waiting for write ' + 'completion ...') + self._logical_channels_condition.acquire() + for channel_data in self._logical_channels.values(): + try: + channel_data.request.connection.on_writer_done() + except Exception: + self._logger.debug(traceback.format_exc()) self._logical_channels_condition.release() def fail_physical_connection(self, code, message): @@ -1590,8 +1848,7 @@ class _MuxHandler(object): self._logger.debug('Failing the physical connection...') self._send_drop_channel(_CONTROL_CHANNEL_ID, code, message) - self.physical_stream.close_connection( - common.STATUS_INTERNAL_ENDPOINT_ERROR) + self._writer.stop(common.STATUS_INTERNAL_ENDPOINT_ERROR) def fail_logical_channel(self, channel_id, code, message): """Fail a logical channel. @@ -1611,8 +1868,10 @@ class _MuxHandler(object): # called later and it will send DropChannel. channel_data.drop_code = code channel_data.drop_message = message + channel_data.request.connection.set_read_state( _LogicalConnection.STATE_TERMINATED) + channel_data.request.ws_stream.stop_sending() else: self._send_drop_channel(channel_id, code, message) finally: @@ -1620,7 +1879,8 @@ class _MuxHandler(object): def use_mux(request): - return hasattr(request, 'mux') and request.mux + return hasattr(request, 'mux_processor') and ( + request.mux_processor.is_active()) def start(request, dispatcher): diff --git a/pyload/lib/mod_pywebsocket/standalone.py b/pyload/lib/mod_pywebsocket/standalone.py index 07a33d9c9..e9f083753 100755 --- a/pyload/lib/mod_pywebsocket/standalone.py +++ b/pyload/lib/mod_pywebsocket/standalone.py @@ -76,6 +76,9 @@ SUPPORTING TLS To support TLS, run standalone.py with -t, -k, and -c options. +Note that when ssl module is used and the key/cert location is incorrect, +TLS connection silently fails while pyOpenSSL fails on startup. + SUPPORTING CLIENT AUTHENTICATION @@ -140,18 +143,6 @@ import sys import threading import time -_HAS_SSL = False -_HAS_OPEN_SSL = False -try: - import ssl - _HAS_SSL = True -except ImportError: - try: - import OpenSSL.SSL - _HAS_OPEN_SSL = True - except ImportError: - pass - from mod_pywebsocket import common from mod_pywebsocket import dispatch from mod_pywebsocket import handshake @@ -168,6 +159,10 @@ _DEFAULT_REQUEST_QUEUE_SIZE = 128 # 1024 is practically large enough to contain WebSocket handshake lines. _MAX_MEMORIZED_LINES = 1024 +# Constants for the --tls_module flag. +_TLS_BY_STANDARD_MODULE = 'ssl' +_TLS_BY_PYOPENSSL = 'pyopenssl' + class _StandaloneConnection(object): """Mimic mod_python mp_conn.""" @@ -231,11 +226,23 @@ class _StandaloneRequest(object): self.headers_in = request_handler.headers def get_uri(self): - """Getter to mimic request.uri.""" + """Getter to mimic request.uri. + + This method returns the raw data at the Request-URI part of the + Request-Line, while the uri method on the request object of mod_python + returns the path portion after parsing the raw data. This behavior is + kept for compatibility. + """ return self._request_handler.path uri = property(get_uri) + def get_unparsed_uri(self): + """Getter to mimic request.unparsed_uri.""" + + return self._request_handler.path + unparsed_uri = property(get_unparsed_uri) + def get_method(self): """Getter to mimic request.method.""" @@ -266,26 +273,67 @@ class _StandaloneRequest(object): 'Drained data following close frame: %r', drained_data) +def _import_ssl(): + global ssl + try: + import ssl + return True + except ImportError: + return False + + +def _import_pyopenssl(): + global OpenSSL + try: + import OpenSSL.SSL + return True + except ImportError: + return False + + class _StandaloneSSLConnection(object): - """A wrapper class for OpenSSL.SSL.Connection to provide makefile method - which is not supported by the class. + """A wrapper class for OpenSSL.SSL.Connection to + - provide makefile method which is not supported by the class + - tweak shutdown method since OpenSSL.SSL.Connection.shutdown doesn't + accept the "how" argument. + - convert SysCallError exceptions that its recv method may raise into a + return value of '', meaning EOF. We cannot overwrite the recv method on + self._connection since it's immutable. """ + _OVERRIDDEN_ATTRIBUTES = ['_connection', 'makefile', 'shutdown', 'recv'] + def __init__(self, connection): self._connection = connection def __getattribute__(self, name): - if name in ('_connection', 'makefile'): + if name in _StandaloneSSLConnection._OVERRIDDEN_ATTRIBUTES: return object.__getattribute__(self, name) return self._connection.__getattribute__(name) def __setattr__(self, name, value): - if name in ('_connection', 'makefile'): + if name in _StandaloneSSLConnection._OVERRIDDEN_ATTRIBUTES: return object.__setattr__(self, name, value) return self._connection.__setattr__(name, value) def makefile(self, mode='r', bufsize=-1): - return socket._fileobject(self._connection, mode, bufsize) + return socket._fileobject(self, mode, bufsize) + + def shutdown(self, unused_how): + self._connection.shutdown() + + def recv(self, bufsize, flags=0): + if flags != 0: + raise ValueError('Non-zero flags not allowed') + + try: + return self._connection.recv(bufsize) + except OpenSSL.SSL.SysCallError, (err, message): + if err == -1: + # Suppress "unexpected EOF" exception. See the OpenSSL document + # for SSL_get_error. + return '' + raise def _alias_handlers(dispatcher, websock_handlers_map_file): @@ -340,7 +388,7 @@ class WebSocketServer(SocketServer.ThreadingMixIn, BaseHTTPServer.HTTPServer): warnings = options.dispatcher.source_warnings() if warnings: for warning in warnings: - logging.warning('mod_pywebsocket: %s' % warning) + logging.warning('Warning in source loading: %s' % warning) self._logger = util.get_class_logger(self) @@ -387,25 +435,25 @@ class WebSocketServer(SocketServer.ThreadingMixIn, BaseHTTPServer.HTTPServer): except Exception, e: self._logger.info('Skip by failure: %r', e) continue - if self.websocket_server_options.use_tls: - if _HAS_SSL: - if self.websocket_server_options.tls_client_auth: - client_cert_ = ssl.CERT_REQUIRED + server_options = self.websocket_server_options + if server_options.use_tls: + # For the case of _HAS_OPEN_SSL, we do wrapper setup after + # accept. + if server_options.tls_module == _TLS_BY_STANDARD_MODULE: + if server_options.tls_client_auth: + if server_options.tls_client_cert_optional: + client_cert_ = ssl.CERT_OPTIONAL + else: + client_cert_ = ssl.CERT_REQUIRED else: client_cert_ = ssl.CERT_NONE socket_ = ssl.wrap_socket(socket_, - keyfile=self.websocket_server_options.private_key, - certfile=self.websocket_server_options.certificate, + keyfile=server_options.private_key, + certfile=server_options.certificate, ssl_version=ssl.PROTOCOL_SSLv23, - ca_certs=self.websocket_server_options.tls_client_ca, - cert_reqs=client_cert_) - if _HAS_OPEN_SSL: - ctx = OpenSSL.SSL.Context(OpenSSL.SSL.SSLv23_METHOD) - ctx.use_privatekey_file( - self.websocket_server_options.private_key) - ctx.use_certificate_file( - self.websocket_server_options.certificate) - socket_ = OpenSSL.SSL.Connection(ctx, socket_) + ca_certs=server_options.tls_client_ca, + cert_reqs=client_cert_, + do_handshake_on_connect=False) self._sockets.append((socket_, addrinfo)) def server_bind(self): @@ -479,7 +527,7 @@ class WebSocketServer(SocketServer.ThreadingMixIn, BaseHTTPServer.HTTPServer): self._logger.critical('Not supported: fileno') return self._sockets[0][0].fileno() - def handle_error(self, rquest, client_address): + def handle_error(self, request, client_address): """Override SocketServer.handle_error.""" self._logger.error( @@ -496,8 +544,63 @@ class WebSocketServer(SocketServer.ThreadingMixIn, BaseHTTPServer.HTTPServer): """ accepted_socket, client_address = self.socket.accept() - if self.websocket_server_options.use_tls and _HAS_OPEN_SSL: - accepted_socket = _StandaloneSSLConnection(accepted_socket) + + server_options = self.websocket_server_options + if server_options.use_tls: + if server_options.tls_module == _TLS_BY_STANDARD_MODULE: + try: + accepted_socket.do_handshake() + except ssl.SSLError, e: + self._logger.debug('%r', e) + raise + + # Print cipher in use. Handshake is done on accept. + self._logger.debug('Cipher: %s', accepted_socket.cipher()) + self._logger.debug('Client cert: %r', + accepted_socket.getpeercert()) + elif server_options.tls_module == _TLS_BY_PYOPENSSL: + # We cannot print the cipher in use. pyOpenSSL doesn't provide + # any method to fetch that. + + ctx = OpenSSL.SSL.Context(OpenSSL.SSL.SSLv23_METHOD) + ctx.use_privatekey_file(server_options.private_key) + ctx.use_certificate_file(server_options.certificate) + + def default_callback(conn, cert, errnum, errdepth, ok): + return ok == 1 + + # See the OpenSSL document for SSL_CTX_set_verify. + if server_options.tls_client_auth: + verify_mode = OpenSSL.SSL.VERIFY_PEER + if not server_options.tls_client_cert_optional: + verify_mode |= OpenSSL.SSL.VERIFY_FAIL_IF_NO_PEER_CERT + ctx.set_verify(verify_mode, default_callback) + ctx.load_verify_locations(server_options.tls_client_ca, + None) + else: + ctx.set_verify(OpenSSL.SSL.VERIFY_NONE, default_callback) + + accepted_socket = OpenSSL.SSL.Connection(ctx, accepted_socket) + accepted_socket.set_accept_state() + + # Convert SSL related error into socket.error so that + # SocketServer ignores them and keeps running. + # + # TODO(tyoshino): Convert all kinds of errors. + try: + accepted_socket.do_handshake() + except OpenSSL.SSL.Error, e: + # Set errno part to 1 (SSL_ERROR_SSL) like the ssl module + # does. + self._logger.debug('%r', e) + raise socket.error(1, '%r' % e) + cert = accepted_socket.get_peer_certificate() + self._logger.debug('Client cert subject: %r', + cert.get_subject().get_components()) + accepted_socket = _StandaloneSSLConnection(accepted_socket) + else: + raise ValueError('No TLS support module is available') + return accepted_socket, client_address def serve_forever(self, poll_interval=0.5): @@ -636,7 +739,7 @@ class WebSocketRequestHandler(CGIHTTPServer.CGIHTTPRequestHandler): self._logger.info('Fallback to CGIHTTPRequestHandler') return True except dispatch.DispatchException, e: - self._logger.info('%s', e) + self._logger.info('Dispatch failed for error: %s', e) self.send_error(e.status) return False @@ -652,7 +755,7 @@ class WebSocketRequestHandler(CGIHTTPServer.CGIHTTPRequestHandler): allowDraft75=self._options.allow_draft75, strict=self._options.strict) except handshake.VersionException, e: - self._logger.info('%s', e) + self._logger.info('Handshake failed for version error: %s', e) self.send_response(common.HTTP_STATUS_BAD_REQUEST) self.send_header(common.SEC_WEBSOCKET_VERSION_HEADER, e.supported_versions) @@ -660,14 +763,14 @@ class WebSocketRequestHandler(CGIHTTPServer.CGIHTTPRequestHandler): return False except handshake.HandshakeException, e: # Handshake for ws(s) failed. - self._logger.info('%s', e) + self._logger.info('Handshake failed for error: %s', e) self.send_error(e.status) return False request._dispatcher = self._options.dispatcher self._options.dispatcher.transfer_data(request) except handshake.AbortedByUserException, e: - self._logger.info('%s', e) + self._logger.info('Aborted: %s', e) return False def log_request(self, code='-', size='-'): @@ -799,6 +902,12 @@ def _build_option_parser(): 'as CGI programs. Must be executable.')) parser.add_option('-t', '--tls', dest='use_tls', action='store_true', default=False, help='use TLS (wss://)') + parser.add_option('--tls-module', '--tls_module', dest='tls_module', + type='choice', + choices = [_TLS_BY_STANDARD_MODULE, _TLS_BY_PYOPENSSL], + help='Use ssl module if "%s" is specified. ' + 'Use pyOpenSSL module if "%s" is specified' % + (_TLS_BY_STANDARD_MODULE, _TLS_BY_PYOPENSSL)) parser.add_option('-k', '--private-key', '--private_key', dest='private_key', default='', help='TLS private key file.') @@ -806,7 +915,12 @@ def _build_option_parser(): default='', help='TLS certificate file.') parser.add_option('--tls-client-auth', dest='tls_client_auth', action='store_true', default=False, - help='Requires TLS client auth on every connection.') + help='Requests TLS client auth on every connection.') + parser.add_option('--tls-client-cert-optional', + dest='tls_client_cert_optional', + action='store_true', default=False, + help=('Makes client certificate optional even though ' + 'TLS client auth is enabled.')) parser.add_option('--tls-client-ca', dest='tls_client_ca', default='', help=('Specifies a pem file which contains a set of ' 'concatenated CA certificates which are used to ' @@ -933,6 +1047,12 @@ def _main(args=None): _configure_logging(options) + if options.allow_draft75: + logging.warning('--allow_draft75 option is obsolete.') + + if options.strict: + logging.warning('--strict option is obsolete.') + # TODO(tyoshino): Clean up initialization of CGI related values. Move some # of code here to WebSocketRequestHandler class if it's better. options.cgi_directories = [] @@ -955,20 +1075,53 @@ def _main(args=None): options.is_executable_method = __check_script if options.use_tls: - if not (_HAS_SSL or _HAS_OPEN_SSL): - logging.critical('TLS support requires ssl or pyOpenSSL module.') + if options.tls_module is None: + if _import_ssl(): + options.tls_module = _TLS_BY_STANDARD_MODULE + logging.debug('Using ssl module') + elif _import_pyopenssl(): + options.tls_module = _TLS_BY_PYOPENSSL + logging.debug('Using pyOpenSSL module') + else: + logging.critical( + 'TLS support requires ssl or pyOpenSSL module.') + sys.exit(1) + elif options.tls_module == _TLS_BY_STANDARD_MODULE: + if not _import_ssl(): + logging.critical('ssl module is not available') + sys.exit(1) + elif options.tls_module == _TLS_BY_PYOPENSSL: + if not _import_pyopenssl(): + logging.critical('pyOpenSSL module is not available') + sys.exit(1) + else: + logging.critical('Invalid --tls-module option: %r', + options.tls_module) sys.exit(1) + if not options.private_key or not options.certificate: logging.critical( 'To use TLS, specify private_key and certificate.') sys.exit(1) - if options.tls_client_auth: - if not options.use_tls: + if (options.tls_client_cert_optional and + not options.tls_client_auth): + logging.critical('Client authentication must be enabled to ' + 'specify tls_client_cert_optional') + sys.exit(1) + else: + if options.tls_module is not None: + logging.critical('Use --tls-module option only together with ' + '--use-tls option.') + sys.exit(1) + + if options.tls_client_auth: + logging.critical('TLS must be enabled for client authentication.') + sys.exit(1) + + if options.tls_client_cert_optional: logging.critical('TLS must be enabled for client authentication.') sys.exit(1) - if not _HAS_SSL: - logging.critical('Client authentication requires ssl module.') if not options.scan_dir: options.scan_dir = options.websock_handlers diff --git a/pyload/lib/mod_pywebsocket/util.py b/pyload/lib/mod_pywebsocket/util.py index 7bb0b5d9e..fc8451be7 100644 --- a/pyload/lib/mod_pywebsocket/util.py +++ b/pyload/lib/mod_pywebsocket/util.py @@ -56,6 +56,11 @@ import socket import traceback import zlib +try: + from mod_pywebsocket import fast_masking +except ImportError: + pass + def get_stack_trace(): """Get the current stack trace as string. @@ -169,26 +174,40 @@ class RepeatedXorMasker(object): ended and resumes from that point on the next mask method call. """ - def __init__(self, mask): - self._mask = map(ord, mask) - self._mask_size = len(self._mask) - self._count = 0 + def __init__(self, masking_key): + self._masking_key = masking_key + self._masking_key_index = 0 - def mask(self, s): + def _mask_using_swig(self, s): + masked_data = fast_masking.mask( + s, self._masking_key, self._masking_key_index) + self._masking_key_index = ( + (self._masking_key_index + len(s)) % len(self._masking_key)) + return masked_data + + def _mask_using_array(self, s): result = array.array('B') result.fromstring(s) + # Use temporary local variables to eliminate the cost to access # attributes - count = self._count - mask = self._mask - mask_size = self._mask_size + masking_key = map(ord, self._masking_key) + masking_key_size = len(masking_key) + masking_key_index = self._masking_key_index + for i in xrange(len(result)): - result[i] ^= mask[count] - count = (count + 1) % mask_size - self._count = count + result[i] ^= masking_key[masking_key_index] + masking_key_index = (masking_key_index + 1) % masking_key_size + + self._masking_key_index = masking_key_index return result.tostring() + if 'fast_masking' in globals(): + mask = _mask_using_swig + else: + mask = _mask_using_array + class DeflateRequest(object): """A wrapper class for request object to intercept send and recv to perform @@ -252,6 +271,7 @@ class _Deflater(object): self._logger.debug('Compress result %r', compressed_bytes) return compressed_bytes + class _Inflater(object): def __init__(self): @@ -346,6 +366,7 @@ class _RFC1979Deflater(object): return self._deflater.compress_and_flush(bytes)[:-4] return self._deflater.compress(bytes) + class _RFC1979Inflater(object): """A decompressor class for byte sequence compressed and flushed following the algorithm described in the RFC1979 section 2.1. diff --git a/pyload/remote/WebSocketBackend.py b/pyload/remote/WebSocketBackend.py index 7238af679..55edee50e 100644 --- a/pyload/remote/WebSocketBackend.py +++ b/pyload/remote/WebSocketBackend.py @@ -48,10 +48,11 @@ class WebSocketBackend(BackendBase): # tls is needed when requested or webUI is also on tls if self.core.api.isWSSecure(): from wsbackend.Server import import_ssl - if import_ssl(): + tls_module = import_ssl() + if tls_module: options.use_tls = True + options.tls_module = tls_module options.certificate = self.core.config['ssl']['cert'] - options.ca_certificate = options.certificate options.private_key = self.core.config['ssl']['key'] self.core.log.info(_('Using secure WebSocket')) else: diff --git a/pyload/remote/wsbackend/Server.py b/pyload/remote/wsbackend/Server.py index 9a6649ca9..3ffe198eb 100644 --- a/pyload/remote/wsbackend/Server.py +++ b/pyload/remote/wsbackend/Server.py @@ -37,6 +37,7 @@ import BaseHTTPServer import CGIHTTPServer import SocketServer +import base64 import httplib import logging import os @@ -46,9 +47,6 @@ import socket import sys import threading -_HAS_SSL = False -_HAS_OPEN_SSL = False - from mod_pywebsocket import common from mod_pywebsocket import dispatch from mod_pywebsocket import handshake @@ -65,20 +63,9 @@ _DEFAULT_REQUEST_QUEUE_SIZE = 128 # 1024 is practically large enough to contain WebSocket handshake lines. _MAX_MEMORIZED_LINES = 1024 -def import_ssl(): - global _HAS_SSL, _HAS_OPEN_SSL - global ssl, OpenSSL - try: - import ssl - _HAS_SSL = True - except ImportError: - try: - import OpenSSL.SSL - _HAS_OPEN_SSL = True - except ImportError: - pass - - return _HAS_OPEN_SSL or _HAS_SSL +# Constants for the --tls_module flag. +_TLS_BY_STANDARD_MODULE = 'ssl' +_TLS_BY_PYOPENSSL = 'pyopenssl' class _StandaloneConnection(object): @@ -143,11 +130,23 @@ class _StandaloneRequest(object): self.headers_in = request_handler.headers def get_uri(self): - """Getter to mimic request.uri.""" + """Getter to mimic request.uri. + + This method returns the raw data at the Request-URI part of the + Request-Line, while the uri method on the request object of mod_python + returns the path portion after parsing the raw data. This behavior is + kept for compatibility. + """ return self._request_handler.path uri = property(get_uri) + def get_unparsed_uri(self): + """Getter to mimic request.unparsed_uri.""" + + return self._request_handler.path + unparsed_uri = property(get_unparsed_uri) + def get_method(self): """Getter to mimic request.method.""" @@ -178,26 +177,67 @@ class _StandaloneRequest(object): 'Drained data following close frame: %r', drained_data) +def _import_ssl(): + global ssl + try: + import ssl + return True + except ImportError: + return False + + +def _import_pyopenssl(): + global OpenSSL + try: + import OpenSSL.SSL + return True + except ImportError: + return False + + class _StandaloneSSLConnection(object): - """A wrapper class for OpenSSL.SSL.Connection to provide makefile method - which is not supported by the class. + """A wrapper class for OpenSSL.SSL.Connection to + - provide makefile method which is not supported by the class + - tweak shutdown method since OpenSSL.SSL.Connection.shutdown doesn't + accept the "how" argument. + - convert SysCallError exceptions that its recv method may raise into a + return value of '', meaning EOF. We cannot overwrite the recv method on + self._connection since it's immutable. """ + _OVERRIDDEN_ATTRIBUTES = ['_connection', 'makefile', 'shutdown', 'recv'] + def __init__(self, connection): self._connection = connection def __getattribute__(self, name): - if name in ('_connection', 'makefile'): + if name in _StandaloneSSLConnection._OVERRIDDEN_ATTRIBUTES: return object.__getattribute__(self, name) return self._connection.__getattribute__(name) def __setattr__(self, name, value): - if name in ('_connection', 'makefile'): + if name in _StandaloneSSLConnection._OVERRIDDEN_ATTRIBUTES: return object.__setattr__(self, name, value) return self._connection.__setattr__(name, value) def makefile(self, mode='r', bufsize=-1): - return socket._fileobject(self._connection, mode, bufsize) + return socket._fileobject(self, mode, bufsize) + + def shutdown(self, unused_how): + self._connection.shutdown() + + def recv(self, bufsize, flags=0): + if flags != 0: + raise ValueError('Non-zero flags not allowed') + + try: + return self._connection.recv(bufsize) + except OpenSSL.SSL.SysCallError, (err, message): + if err == -1: + # Suppress "unexpected EOF" exception. See the OpenSSL document + # for SSL_get_error. + return '' + raise def _alias_handlers(dispatcher, websock_handlers_map_file): @@ -284,25 +324,25 @@ class WebSocketServer(SocketServer.ThreadingMixIn, BaseHTTPServer.HTTPServer): except Exception, e: self._logger.info('Skip by failure: %r', e) continue - if self.websocket_server_options.use_tls: - if _HAS_SSL: - if self.websocket_server_options.tls_client_auth: - client_cert_ = ssl.CERT_REQUIRED + server_options = self.websocket_server_options + if server_options.use_tls: + # For the case of _HAS_OPEN_SSL, we do wrapper setup after + # accept. + if server_options.tls_module == _TLS_BY_STANDARD_MODULE: + if server_options.tls_client_auth: + if server_options.tls_client_cert_optional: + client_cert_ = ssl.CERT_OPTIONAL + else: + client_cert_ = ssl.CERT_REQUIRED else: client_cert_ = ssl.CERT_NONE socket_ = ssl.wrap_socket(socket_, - keyfile=self.websocket_server_options.private_key, - certfile=self.websocket_server_options.certificate, - ssl_version=ssl.PROTOCOL_SSLv23, - ca_certs=self.websocket_server_options.tls_client_ca, - cert_reqs=client_cert_) - if _HAS_OPEN_SSL: - ctx = OpenSSL.SSL.Context(OpenSSL.SSL.SSLv23_METHOD) - ctx.use_privatekey_file( - self.websocket_server_options.private_key) - ctx.use_certificate_file( - self.websocket_server_options.certificate) - socket_ = OpenSSL.SSL.Connection(ctx, socket_) + keyfile=server_options.private_key, + certfile=server_options.certificate, + ssl_version=ssl.PROTOCOL_SSLv23, + ca_certs=server_options.tls_client_ca, + cert_reqs=client_cert_, + do_handshake_on_connect=False) self._sockets.append((socket_, addrinfo)) def server_bind(self): @@ -375,7 +415,7 @@ class WebSocketServer(SocketServer.ThreadingMixIn, BaseHTTPServer.HTTPServer): self._logger.critical('Not supported: fileno') return self._sockets[0][0].fileno() - def handle_error(self, rquest, client_address): + def handle_error(self, request, client_address): """Override SocketServer.handle_error.""" self._logger.error( @@ -392,8 +432,63 @@ class WebSocketServer(SocketServer.ThreadingMixIn, BaseHTTPServer.HTTPServer): """ accepted_socket, client_address = self.socket.accept() - if self.websocket_server_options.use_tls and _HAS_OPEN_SSL: - accepted_socket = _StandaloneSSLConnection(accepted_socket) + + server_options = self.websocket_server_options + if server_options.use_tls: + if server_options.tls_module == _TLS_BY_STANDARD_MODULE: + try: + accepted_socket.do_handshake() + except ssl.SSLError, e: + self._logger.debug('%r', e) + raise + + # Print cipher in use. Handshake is done on accept. + self._logger.debug('Cipher: %s', accepted_socket.cipher()) + self._logger.debug('Client cert: %r', + accepted_socket.getpeercert()) + elif server_options.tls_module == _TLS_BY_PYOPENSSL: + # We cannot print the cipher in use. pyOpenSSL doesn't provide + # any method to fetch that. + + ctx = OpenSSL.SSL.Context(OpenSSL.SSL.SSLv23_METHOD) + ctx.use_privatekey_file(server_options.private_key) + ctx.use_certificate_file(server_options.certificate) + + def default_callback(conn, cert, errnum, errdepth, ok): + return ok == 1 + + # See the OpenSSL document for SSL_CTX_set_verify. + if server_options.tls_client_auth: + verify_mode = OpenSSL.SSL.VERIFY_PEER + if not server_options.tls_client_cert_optional: + verify_mode |= OpenSSL.SSL.VERIFY_FAIL_IF_NO_PEER_CERT + ctx.set_verify(verify_mode, default_callback) + ctx.load_verify_locations(server_options.tls_client_ca, + None) + else: + ctx.set_verify(OpenSSL.SSL.VERIFY_NONE, default_callback) + + accepted_socket = OpenSSL.SSL.Connection(ctx, accepted_socket) + accepted_socket.set_accept_state() + + # Convert SSL related error into socket.error so that + # SocketServer ignores them and keeps running. + # + # TODO(tyoshino): Convert all kinds of errors. + try: + accepted_socket.do_handshake() + except OpenSSL.SSL.Error, e: + # Set errno part to 1 (SSL_ERROR_SSL) like the ssl module + # does. + self._logger.debug('%r', e) + raise socket.error(1, '%r' % e) + cert = accepted_socket.get_peer_certificate() + self._logger.debug('Client cert subject: %r', + cert.get_subject().get_components()) + accepted_socket = _StandaloneSSLConnection(accepted_socket) + else: + raise ValueError('No TLS support module is available') + return accepted_socket, client_address def serve_forever(self, poll_interval=0.5): @@ -474,8 +569,6 @@ class WebSocketRequestHandler(CGIHTTPServer.CGIHTTPRequestHandler): raise self._logger.debug("WS: Broken pipe") - - def parse_request(self): """Override BaseHTTPServer.BaseHTTPRequestHandler.parse_request. @@ -545,7 +638,7 @@ class WebSocketRequestHandler(CGIHTTPServer.CGIHTTPRequestHandler): self._logger.info('Fallback to CGIHTTPRequestHandler') return False except dispatch.DispatchException, e: - self._logger.info('%s', e) + self._logger.info('Dispatch failed for error: %s', e) self.send_error(e.status) return False @@ -561,7 +654,7 @@ class WebSocketRequestHandler(CGIHTTPServer.CGIHTTPRequestHandler): allowDraft75=self._options.allow_draft75, strict=self._options.strict) except handshake.VersionException, e: - self._logger.info('%s', e) + self._logger.info('Handshake failed for version error: %s', e) self.send_response(common.HTTP_STATUS_BAD_REQUEST) self.send_header(common.SEC_WEBSOCKET_VERSION_HEADER, e.supported_versions) @@ -569,14 +662,14 @@ class WebSocketRequestHandler(CGIHTTPServer.CGIHTTPRequestHandler): return False except handshake.HandshakeException, e: # Handshake for ws(s) failed. - self._logger.info('%s', e) + self._logger.info('Handshake failed for error: %s', e) self.send_error(e.status) return False request._dispatcher = self._options.dispatcher self._options.dispatcher.transfer_data(request) except handshake.AbortedByUserException, e: - self._logger.info('%s', e) + self._logger.info('Aborted: %s', e) return False def log_request(self, code='-', size='-'): @@ -606,7 +699,7 @@ class WebSocketRequestHandler(CGIHTTPServer.CGIHTTPRequestHandler): if CGIHTTPServer.CGIHTTPRequestHandler.is_cgi(self): if '..' in self.path: return False - # strip query parameter from request path + # strip query parameter from request path resource_name = self.path.split('?', 2)[0] # convert resource_name into real path name in filesystem. scriptfile = self.translate_path(resource_name) @@ -629,11 +722,11 @@ def _configure_logging(options): logger.setLevel(logging.getLevelName(options.log_level.upper())) if options.log_file: handler = logging.handlers.RotatingFileHandler( - options.log_file, 'a', options.log_max, options.log_count) + options.log_file, 'a', options.log_max, options.log_count) else: handler = logging.StreamHandler() formatter = logging.Formatter( - '[%(asctime)s] [%(levelname)s] %(name)s: %(message)s') + '[%(asctime)s] [%(levelname)s] %(name)s: %(message)s') handler.setFormatter(formatter) logger.addHandler(handler) @@ -650,9 +743,10 @@ class DefaultOptions: use_tls = False private_key = '' certificate = '' - ca_certificate = '' - tls_client_ca = '' + tls_client_ca = None tls_client_auth = False + tls_client_cert_optional = False + tls_module = _TLS_BY_STANDARD_MODULE dispatcher = None request_queue_size = _DEFAULT_REQUEST_QUEUE_SIZE use_basic_auth = False @@ -664,6 +758,16 @@ class DefaultOptions: cgi_directories = '' is_executable_method = False + +def import_ssl(): + if _import_ssl(): + return _TLS_BY_STANDARD_MODULE + + elif _import_pyopenssl(): + return _TLS_BY_PYOPENSSL + + return False + def _main(args=None): """You can call this function from your own program, but please note that this function has some side-effects that might affect your program. For @@ -677,6 +781,12 @@ def _main(args=None): _configure_logging(options) + if options.allow_draft75: + logging.warning('--allow_draft75 option is obsolete.') + + if options.strict: + logging.warning('--strict option is obsolete.') + # TODO(tyoshino): Clean up initialization of CGI related values. Move some # of code here to WebSocketRequestHandler class if it's better. options.cgi_directories = [] @@ -699,20 +809,53 @@ def _main(args=None): options.is_executable_method = __check_script if options.use_tls: - if not (_HAS_SSL or _HAS_OPEN_SSL): - logging.critical('TLS support requires ssl or pyOpenSSL module.') + if options.tls_module is None: + if _import_ssl(): + options.tls_module = _TLS_BY_STANDARD_MODULE + logging.debug('Using ssl module') + elif _import_pyopenssl(): + options.tls_module = _TLS_BY_PYOPENSSL + logging.debug('Using pyOpenSSL module') + else: + logging.critical( + 'TLS support requires ssl or pyOpenSSL module.') + sys.exit(1) + elif options.tls_module == _TLS_BY_STANDARD_MODULE: + if not _import_ssl(): + logging.critical('ssl module is not available') + sys.exit(1) + elif options.tls_module == _TLS_BY_PYOPENSSL: + if not _import_pyopenssl(): + logging.critical('pyOpenSSL module is not available') + sys.exit(1) + else: + logging.critical('Invalid --tls-module option: %r', + options.tls_module) sys.exit(1) + if not options.private_key or not options.certificate: logging.critical( - 'To use TLS, specify private_key and certificate.') + 'To use TLS, specify private_key and certificate.') + sys.exit(1) + + if (options.tls_client_cert_optional and + not options.tls_client_auth): + logging.critical('Client authentication must be enabled to ' + 'specify tls_client_cert_optional') + sys.exit(1) + else: + if options.tls_module is not None: + logging.critical('Use --tls-module option only together with ' + '--use-tls option.') + sys.exit(1) + + if options.tls_client_auth: + logging.critical('TLS must be enabled for client authentication.') sys.exit(1) - if options.tls_client_auth: - if not options.use_tls: + if options.tls_client_cert_optional: logging.critical('TLS must be enabled for client authentication.') sys.exit(1) - if not _HAS_SSL: - logging.critical('Client authentication requires ssl module.') if not options.scan_dir: options.scan_dir = options.websock_handlers |