diff options
Diffstat (limited to 'module/lib/mod_pywebsocket')
-rw-r--r-- | module/lib/mod_pywebsocket/__init__.py | 41 | ||||
-rw-r--r-- | module/lib/mod_pywebsocket/_stream_hixie75.py | 3 | ||||
-rw-r--r-- | module/lib/mod_pywebsocket/_stream_hybi.py | 686 | ||||
-rw-r--r-- | module/lib/mod_pywebsocket/common.py | 5 | ||||
-rw-r--r-- | module/lib/mod_pywebsocket/dispatch.py | 16 | ||||
-rw-r--r-- | module/lib/mod_pywebsocket/extensions.py | 437 | ||||
-rw-r--r-- | module/lib/mod_pywebsocket/handshake/__init__.py | 10 | ||||
-rw-r--r-- | module/lib/mod_pywebsocket/handshake/_base.py | 15 | ||||
-rw-r--r-- | module/lib/mod_pywebsocket/handshake/draft75.py | 190 | ||||
-rw-r--r-- | module/lib/mod_pywebsocket/handshake/hybi.py | 64 | ||||
-rw-r--r-- | module/lib/mod_pywebsocket/headerparserhandler.py | 5 | ||||
-rw-r--r-- | module/lib/mod_pywebsocket/msgutil.py | 16 | ||||
-rwxr-xr-x | module/lib/mod_pywebsocket/standalone.py | 206 | ||||
-rw-r--r-- | module/lib/mod_pywebsocket/stream.py | 1 | ||||
-rw-r--r-- | module/lib/mod_pywebsocket/util.py | 61 |
15 files changed, 1140 insertions, 616 deletions
diff --git a/module/lib/mod_pywebsocket/__init__.py b/module/lib/mod_pywebsocket/__init__.py index c154da4a1..454ae0c45 100644 --- a/module/lib/mod_pywebsocket/__init__.py +++ b/module/lib/mod_pywebsocket/__init__.py @@ -34,7 +34,8 @@ mod_pywebsocket is a WebSocket extension for Apache HTTP Server intended for testing or experimental purposes. mod_python is required. -Installation: +Installation +============ 0. Prepare an Apache HTTP Server for which mod_python is enabled. @@ -60,11 +61,6 @@ Installation: <scan_dir> is useful in saving scan time when <websock_handlers> contains many non-WebSocket handler files. - If you want to support old handshake based on - draft-hixie-thewebsocketprotocol-75: - - PythonOption mod_pywebsocket.allow_draft75 On - If you want to allow handlers whose canonical path is not under the root directory (i.e. symbolic link is in root directory but its target is not), configure as follows: @@ -89,7 +85,8 @@ Installation: 3. Verify installation. You can use example/console.html to poke the server. -Writing WebSocket handlers: +Writing WebSocket handlers +========================== When a WebSocket request comes in, the resource name specified in the handshake is considered as if it is a file path under @@ -118,28 +115,36 @@ extra handshake (web_socket_do_extra_handshake): - ws_resource - ws_origin - ws_version -- ws_location (Hixie 75 and HyBi 00 only) -- ws_extensions (Hybi 06 and later) +- ws_location (HyBi 00 only) +- ws_extensions (HyBi 06 and later) - ws_deflate (HyBi 06 and later) - ws_protocol - ws_requested_protocols (HyBi 06 and later) -The last two are a bit tricky. +The last two are a bit tricky. See the next subsection. + + +Subprotocol Negotiation +----------------------- For HyBi 06 and later, ws_protocol is always set to None when web_socket_do_extra_handshake is called. If ws_requested_protocols is not None, you must choose one subprotocol from this list and set it to ws_protocol. -For Hixie 75 and HyBi 00, when web_socket_do_extra_handshake is called, +For HyBi 00, when web_socket_do_extra_handshake is called, ws_protocol is set to the value given by the client in -Sec-WebSocket-Protocol (WebSocket-Protocol for Hixie 75) header or None if +Sec-WebSocket-Protocol header or None if such header was not found in the opening handshake request. Finish extra handshake with ws_protocol untouched to accept the request subprotocol. -Then, Sec-WebSocket-Protocol (or WebSocket-Protocol) header will be sent to +Then, Sec-WebSocket-Protocol header will be sent to the client in response with the same value as requested. Raise an exception in web_socket_do_extra_handshake to reject the requested subprotocol. + +Data Transfer +------------- + web_socket_transfer_data is called after the handshake completed successfully. A handler can receive/send messages from/to the client using request. mod_pywebsocket.msgutil module provides utilities @@ -159,12 +164,16 @@ You can send a message by the following statement. request.ws_stream.send_message(message) + +Closing Connection +------------------ + Executing the following statement or just return-ing from web_socket_transfer_data cause connection close. request.ws_stream.close_connection() -When you're using IETF HyBi 00 or later protocol, close_connection will wait +close_connection will wait for closing handshake acknowledgement coming from the client. When it couldn't receive a valid acknowledgement, raises an exception. @@ -176,6 +185,10 @@ use in web_socket_passive_closing_handshake. - ws_close_code - ws_close_reason + +Threading +--------- + A WebSocket handler must be thread-safe if the server (Apache or standalone.py) is configured to use threads. """ diff --git a/module/lib/mod_pywebsocket/_stream_hixie75.py b/module/lib/mod_pywebsocket/_stream_hixie75.py index c84ca6e07..94cf5b31b 100644 --- a/module/lib/mod_pywebsocket/_stream_hixie75.py +++ b/module/lib/mod_pywebsocket/_stream_hixie75.py @@ -32,7 +32,8 @@ protocol version HyBi 00 and Hixie 75. Specification: -http://tools.ietf.org/html/draft-ietf-hybi-thewebsocketprotocol-00 +- HyBi 00 http://tools.ietf.org/html/draft-ietf-hybi-thewebsocketprotocol-00 +- Hixie 75 http://tools.ietf.org/html/draft-hixie-thewebsocketprotocol-75 """ diff --git a/module/lib/mod_pywebsocket/_stream_hybi.py b/module/lib/mod_pywebsocket/_stream_hybi.py index 34fa7a60e..bd158fa6b 100644 --- a/module/lib/mod_pywebsocket/_stream_hybi.py +++ b/module/lib/mod_pywebsocket/_stream_hybi.py @@ -37,6 +37,7 @@ http://tools.ietf.org/html/rfc6455 from collections import deque +import logging import os import struct import time @@ -162,14 +163,145 @@ def create_text_frame( frame_filters) +def parse_frame(receive_bytes, logger=None, + ws_version=common.VERSION_HYBI_LATEST, + unmask_receive=True): + """Parses a frame. Returns a tuple containing each header field and + payload. + + Args: + receive_bytes: a function that reads frame data from a stream or + something similar. The function takes length of the bytes to be + read. The function must raise ConnectionTerminatedException if + there is not enough data to be read. + logger: a logging object. + ws_version: the version of WebSocket protocol. + unmask_receive: unmask received frames. When received unmasked + frame, raises InvalidFrameException. + + Raises: + ConnectionTerminatedException: when receive_bytes raises it. + InvalidFrameException: when the frame contains invalid data. + """ + + if not logger: + logger = logging.getLogger() + + logger.log(common.LOGLEVEL_FINE, 'Receive the first 2 octets of a frame') + + received = receive_bytes(2) + + first_byte = ord(received[0]) + fin = (first_byte >> 7) & 1 + rsv1 = (first_byte >> 6) & 1 + rsv2 = (first_byte >> 5) & 1 + rsv3 = (first_byte >> 4) & 1 + opcode = first_byte & 0xf + + second_byte = ord(received[1]) + mask = (second_byte >> 7) & 1 + payload_length = second_byte & 0x7f + + logger.log(common.LOGLEVEL_FINE, + 'FIN=%s, RSV1=%s, RSV2=%s, RSV3=%s, opcode=%s, ' + 'Mask=%s, Payload_length=%s', + fin, rsv1, rsv2, rsv3, opcode, mask, payload_length) + + if (mask == 1) != unmask_receive: + raise InvalidFrameException( + 'Mask bit on the received frame did\'nt match masking ' + 'configuration for received frames') + + # The HyBi and later specs disallow putting a value in 0x0-0xFFFF + # into the 8-octet extended payload length field (or 0x0-0xFD in + # 2-octet field). + valid_length_encoding = True + length_encoding_bytes = 1 + if payload_length == 127: + logger.log(common.LOGLEVEL_FINE, + 'Receive 8-octet extended payload length') + + extended_payload_length = receive_bytes(8) + payload_length = struct.unpack( + '!Q', extended_payload_length)[0] + if payload_length > 0x7FFFFFFFFFFFFFFF: + raise InvalidFrameException( + 'Extended payload length >= 2^63') + if ws_version >= 13 and payload_length < 0x10000: + valid_length_encoding = False + length_encoding_bytes = 8 + + logger.log(common.LOGLEVEL_FINE, + 'Decoded_payload_length=%s', payload_length) + elif payload_length == 126: + logger.log(common.LOGLEVEL_FINE, + 'Receive 2-octet extended payload length') + + extended_payload_length = receive_bytes(2) + payload_length = struct.unpack( + '!H', extended_payload_length)[0] + if ws_version >= 13 and payload_length < 126: + valid_length_encoding = False + length_encoding_bytes = 2 + + logger.log(common.LOGLEVEL_FINE, + 'Decoded_payload_length=%s', payload_length) + + if not valid_length_encoding: + logger.warning( + 'Payload length is not encoded using the minimal number of ' + 'bytes (%d is encoded using %d bytes)', + payload_length, + length_encoding_bytes) + + if mask == 1: + logger.log(common.LOGLEVEL_FINE, 'Receive mask') + + masking_nonce = receive_bytes(4) + masker = util.RepeatedXorMasker(masking_nonce) + + logger.log(common.LOGLEVEL_FINE, 'Mask=%r', masking_nonce) + else: + masker = _NOOP_MASKER + + logger.log(common.LOGLEVEL_FINE, 'Receive payload data') + if logger.isEnabledFor(common.LOGLEVEL_FINE): + receive_start = time.time() + + raw_payload_bytes = receive_bytes(payload_length) + + if logger.isEnabledFor(common.LOGLEVEL_FINE): + logger.log( + common.LOGLEVEL_FINE, + 'Done receiving payload data at %s MB/s', + payload_length / (time.time() - receive_start) / 1000 / 1000) + logger.log(common.LOGLEVEL_FINE, 'Unmask payload data') + + if logger.isEnabledFor(common.LOGLEVEL_FINE): + unmask_start = time.time() + + bytes = masker.mask(raw_payload_bytes) + + if logger.isEnabledFor(common.LOGLEVEL_FINE): + logger.log( + common.LOGLEVEL_FINE, + 'Done unmasking payload data at %s MB/s', + payload_length / (time.time() - unmask_start) / 1000 / 1000) + + return opcode, bytes, fin, rsv1, rsv2, rsv3 + + class FragmentedFrameBuilder(object): """A stateful class to send a message as fragments.""" - def __init__(self, mask, frame_filters=[]): + def __init__(self, mask, frame_filters=[], encode_utf8=True): """Constructs an instance.""" self._mask = mask self._frame_filters = frame_filters + # This is for skipping UTF-8 encoding when building text type frames + # from compressed data. + self._encode_utf8 = encode_utf8 self._started = False @@ -177,7 +309,7 @@ class FragmentedFrameBuilder(object): # frames in the message are all the same. self._opcode = common.OPCODE_TEXT - def build(self, message, end, binary): + def build(self, payload_data, end, binary): if binary: frame_type = common.OPCODE_BINARY else: @@ -198,12 +330,12 @@ class FragmentedFrameBuilder(object): self._started = True fin = 0 - if binary: + if binary or not self._encode_utf8: return create_binary_frame( - message, opcode, fin, self._mask, self._frame_filters) + payload_data, opcode, fin, self._mask, self._frame_filters) else: return create_text_frame( - message, opcode, fin, self._mask, self._frame_filters) + payload_data, opcode, fin, self._mask, self._frame_filters) def _create_control_frame(opcode, body, mask, frame_filters): @@ -235,6 +367,22 @@ def create_close_frame(body, mask=False, frame_filters=[]): common.OPCODE_CLOSE, body, mask, frame_filters) +def create_closing_handshake_body(code, reason): + body = '' + if code is not None: + if (code > common.STATUS_USER_PRIVATE_MAX or + code < common.STATUS_NORMAL_CLOSURE): + raise BadOperationException('Status code is out of range') + if (code == common.STATUS_NO_STATUS_RECEIVED or + code == common.STATUS_ABNORMAL_CLOSURE or + code == common.STATUS_TLS_HANDSHAKE): + raise BadOperationException('Status code is reserved pseudo ' + 'code') + encoded_reason = reason.encode('utf-8') + body = struct.pack('!H', code) + encoded_reason + return body + + class StreamOptions(object): """Holds option values to configure Stream objects.""" @@ -248,8 +396,16 @@ class StreamOptions(object): self.outgoing_frame_filters = [] self.incoming_frame_filters = [] + # Filters applied to messages. Control frames are not affected by them. + self.outgoing_message_filters = [] + self.incoming_message_filters = [] + + self.encode_text_message_to_utf8 = True self.mask_send = False self.unmask_receive = True + # RFC6455 disallows fragmented control frames, but mux extension + # relaxes the restriction. + self.allow_fragmented_control_frame = False class Stream(StreamBase): @@ -283,7 +439,8 @@ class Stream(StreamBase): self._original_opcode = None self._writer = FragmentedFrameBuilder( - self._options.mask_send, self._options.outgoing_frame_filters) + self._options.mask_send, self._options.outgoing_frame_filters, + self._options.encode_text_message_to_utf8) self._ping_queue = deque() @@ -297,109 +454,13 @@ class Stream(StreamBase): InvalidFrameException: when the frame contains invalid data. """ - self._logger.log(common.LOGLEVEL_FINE, - 'Receive the first 2 octets of a frame') - - received = self.receive_bytes(2) - - first_byte = ord(received[0]) - fin = (first_byte >> 7) & 1 - rsv1 = (first_byte >> 6) & 1 - rsv2 = (first_byte >> 5) & 1 - rsv3 = (first_byte >> 4) & 1 - opcode = first_byte & 0xf - - second_byte = ord(received[1]) - mask = (second_byte >> 7) & 1 - payload_length = second_byte & 0x7f - - self._logger.log(common.LOGLEVEL_FINE, - 'FIN=%s, RSV1=%s, RSV2=%s, RSV3=%s, opcode=%s, ' - 'Mask=%s, Payload_length=%s', - fin, rsv1, rsv2, rsv3, opcode, mask, payload_length) - - if (mask == 1) != self._options.unmask_receive: - raise InvalidFrameException( - 'Mask bit on the received frame did\'nt match masking ' - 'configuration for received frames') - - # The Hybi-13 and later specs disallow putting a value in 0x0-0xFFFF - # into the 8-octet extended payload length field (or 0x0-0xFD in - # 2-octet field). - valid_length_encoding = True - length_encoding_bytes = 1 - if payload_length == 127: - self._logger.log(common.LOGLEVEL_FINE, - 'Receive 8-octet extended payload length') - - extended_payload_length = self.receive_bytes(8) - payload_length = struct.unpack( - '!Q', extended_payload_length)[0] - if payload_length > 0x7FFFFFFFFFFFFFFF: - raise InvalidFrameException( - 'Extended payload length >= 2^63') - if self._request.ws_version >= 13 and payload_length < 0x10000: - valid_length_encoding = False - length_encoding_bytes = 8 - - self._logger.log(common.LOGLEVEL_FINE, - 'Decoded_payload_length=%s', payload_length) - elif payload_length == 126: - self._logger.log(common.LOGLEVEL_FINE, - 'Receive 2-octet extended payload length') - - extended_payload_length = self.receive_bytes(2) - payload_length = struct.unpack( - '!H', extended_payload_length)[0] - if self._request.ws_version >= 13 and payload_length < 126: - valid_length_encoding = False - length_encoding_bytes = 2 - - self._logger.log(common.LOGLEVEL_FINE, - 'Decoded_payload_length=%s', payload_length) - - if not valid_length_encoding: - self._logger.warning( - 'Payload length is not encoded using the minimal number of ' - 'bytes (%d is encoded using %d bytes)', - payload_length, - length_encoding_bytes) - - if mask == 1: - self._logger.log(common.LOGLEVEL_FINE, 'Receive mask') - - masking_nonce = self.receive_bytes(4) - masker = util.RepeatedXorMasker(masking_nonce) - - self._logger.log(common.LOGLEVEL_FINE, 'Mask=%r', masking_nonce) - else: - masker = _NOOP_MASKER - - self._logger.log(common.LOGLEVEL_FINE, 'Receive payload data') - if self._logger.isEnabledFor(common.LOGLEVEL_FINE): - receive_start = time.time() - - raw_payload_bytes = self.receive_bytes(payload_length) - - if self._logger.isEnabledFor(common.LOGLEVEL_FINE): - self._logger.log( - common.LOGLEVEL_FINE, - 'Done receiving payload data at %s MB/s', - payload_length / (time.time() - receive_start) / 1000 / 1000) - self._logger.log(common.LOGLEVEL_FINE, 'Unmask payload data') - - if self._logger.isEnabledFor(common.LOGLEVEL_FINE): - unmask_start = time.time() - - bytes = masker.mask(raw_payload_bytes) + def _receive_bytes(length): + return self.receive_bytes(length) - if self._logger.isEnabledFor(common.LOGLEVEL_FINE): - self._logger.log( - common.LOGLEVEL_FINE, - 'Done unmasking payload data at %s MB/s', - payload_length / (time.time() - unmask_start) / 1000 / 1000) - - return opcode, bytes, fin, rsv1, rsv2, rsv3 + return parse_frame(receive_bytes=_receive_bytes, + logger=self._logger, + ws_version=self._request.ws_version, + unmask_receive=self._options.unmask_receive) def _receive_frame_as_frame_object(self): opcode, bytes, fin, rsv1, rsv2, rsv3 = self._receive_frame() @@ -407,6 +468,32 @@ class Stream(StreamBase): return Frame(fin=fin, rsv1=rsv1, rsv2=rsv2, rsv3=rsv3, opcode=opcode, payload=bytes) + def receive_filtered_frame(self): + """Receives a frame and applies frame filters and message filters. + The frame to be received must satisfy following conditions: + - The frame is not fragmented. + - The opcode of the frame is TEXT or BINARY. + + DO NOT USE this method except for testing purpose. + """ + + frame = self._receive_frame_as_frame_object() + if not frame.fin: + raise InvalidFrameException( + 'Segmented frames must not be received via ' + 'receive_filtered_frame()') + if (frame.opcode != common.OPCODE_TEXT and + frame.opcode != common.OPCODE_BINARY): + raise InvalidFrameException( + 'Control frames must not be received via ' + 'receive_filtered_frame()') + + for frame_filter in self._options.incoming_frame_filters: + frame_filter.filter(frame) + for message_filter in self._options.incoming_message_filters: + frame.payload = message_filter.filter(frame.payload) + return frame + def send_message(self, message, end=True, binary=False): """Send message. @@ -428,11 +515,219 @@ class Stream(StreamBase): raise BadOperationException( 'Message for binary frame must be instance of str') + for message_filter in self._options.outgoing_message_filters: + message = message_filter.filter(message, end, binary) + try: - self._write(self._writer.build(message, end, binary)) + # Set this to any positive integer to limit maximum size of data in + # payload data of each frame. + MAX_PAYLOAD_DATA_SIZE = -1 + + if MAX_PAYLOAD_DATA_SIZE <= 0: + self._write(self._writer.build(message, end, binary)) + return + + bytes_written = 0 + while True: + end_for_this_frame = end + bytes_to_write = len(message) - bytes_written + if (MAX_PAYLOAD_DATA_SIZE > 0 and + bytes_to_write > MAX_PAYLOAD_DATA_SIZE): + end_for_this_frame = False + bytes_to_write = MAX_PAYLOAD_DATA_SIZE + + frame = self._writer.build( + message[bytes_written:bytes_written + bytes_to_write], + end_for_this_frame, + binary) + self._write(frame) + + bytes_written += bytes_to_write + + # This if must be placed here (the end of while block) so that + # at least one frame is sent. + if len(message) <= bytes_written: + break except ValueError, e: raise BadOperationException(e) + def _get_message_from_frame(self, frame): + """Gets a message from frame. If the message is composed of fragmented + frames and the frame is not the last fragmented frame, this method + returns None. The whole message will be returned when the last + fragmented frame is passed to this method. + + Raises: + InvalidFrameException: when the frame doesn't match defragmentation + context, or the frame contains invalid data. + """ + + if frame.opcode == common.OPCODE_CONTINUATION: + if not self._received_fragments: + if frame.fin: + raise InvalidFrameException( + 'Received a termination frame but fragmentation ' + 'not started') + else: + raise InvalidFrameException( + 'Received an intermediate frame but ' + 'fragmentation not started') + + if frame.fin: + # End of fragmentation frame + self._received_fragments.append(frame.payload) + message = ''.join(self._received_fragments) + self._received_fragments = [] + return message + else: + # Intermediate frame + self._received_fragments.append(frame.payload) + return None + else: + if self._received_fragments: + if frame.fin: + raise InvalidFrameException( + 'Received an unfragmented frame without ' + 'terminating existing fragmentation') + else: + raise InvalidFrameException( + 'New fragmentation started without terminating ' + 'existing fragmentation') + + if frame.fin: + # Unfragmented frame + + self._original_opcode = frame.opcode + return frame.payload + else: + # Start of fragmentation frame + + if (not self._options.allow_fragmented_control_frame and + common.is_control_opcode(frame.opcode)): + raise InvalidFrameException( + 'Control frames must not be fragmented') + + self._original_opcode = frame.opcode + self._received_fragments.append(frame.payload) + return None + + def _process_close_message(self, message): + """Processes close message. + + Args: + message: close message. + + Raises: + InvalidFrameException: when the message is invalid. + """ + + self._request.client_terminated = True + + # Status code is optional. We can have status reason only if we + # have status code. Status reason can be empty string. So, + # allowed cases are + # - no application data: no code no reason + # - 2 octet of application data: has code but no reason + # - 3 or more octet of application data: both code and reason + if len(message) == 0: + self._logger.debug('Received close frame (empty body)') + self._request.ws_close_code = ( + common.STATUS_NO_STATUS_RECEIVED) + elif len(message) == 1: + raise InvalidFrameException( + 'If a close frame has status code, the length of ' + 'status code must be 2 octet') + elif len(message) >= 2: + self._request.ws_close_code = struct.unpack( + '!H', message[0:2])[0] + self._request.ws_close_reason = message[2:].decode( + 'utf-8', 'replace') + self._logger.debug( + 'Received close frame (code=%d, reason=%r)', + self._request.ws_close_code, + self._request.ws_close_reason) + + # Drain junk data after the close frame if necessary. + self._drain_received_data() + + if self._request.server_terminated: + self._logger.debug( + 'Received ack for server-initiated closing handshake') + return + + self._logger.debug( + 'Received client-initiated closing handshake') + + code = common.STATUS_NORMAL_CLOSURE + reason = '' + if hasattr(self._request, '_dispatcher'): + dispatcher = self._request._dispatcher + code, reason = dispatcher.passive_closing_handshake( + self._request) + if code is None and reason is not None and len(reason) > 0: + self._logger.warning( + 'Handler specified reason despite code being None') + reason = '' + if reason is None: + reason = '' + self._send_closing_handshake(code, reason) + self._logger.debug( + 'Sent ack for client-initiated closing handshake ' + '(code=%r, reason=%r)', code, reason) + + def _process_ping_message(self, message): + """Processes ping message. + + Args: + message: ping message. + """ + + try: + handler = self._request.on_ping_handler + if handler: + handler(self._request, message) + return + except AttributeError, e: + pass + self._send_pong(message) + + def _process_pong_message(self, message): + """Processes pong message. + + Args: + message: pong message. + """ + + # TODO(tyoshino): Add ping timeout handling. + + inflight_pings = deque() + + while True: + try: + expected_body = self._ping_queue.popleft() + if expected_body == message: + # inflight_pings contains pings ignored by the + # other peer. Just forget them. + self._logger.debug( + 'Ping %r is acked (%d pings were ignored)', + expected_body, len(inflight_pings)) + break + else: + inflight_pings.append(expected_body) + except IndexError, e: + # The received pong was unsolicited pong. Keep the + # ping queue as is. + self._ping_queue = inflight_pings + self._logger.debug('Received a unsolicited pong') + break + + try: + handler = self._request.on_pong_handler + if handler: + handler(self._request, message) + except AttributeError, e: + pass + def receive_message(self): """Receive a WebSocket frame and return its payload as a text in unicode or a binary in str. @@ -482,52 +777,12 @@ class Stream(StreamBase): 'Unsupported flag is set (rsv = %d%d%d)' % (frame.rsv1, frame.rsv2, frame.rsv3)) - if frame.opcode == common.OPCODE_CONTINUATION: - if not self._received_fragments: - if frame.fin: - raise InvalidFrameException( - 'Received a termination frame but fragmentation ' - 'not started') - else: - raise InvalidFrameException( - 'Received an intermediate frame but ' - 'fragmentation not started') - - if frame.fin: - # End of fragmentation frame - self._received_fragments.append(frame.payload) - message = ''.join(self._received_fragments) - self._received_fragments = [] - else: - # Intermediate frame - self._received_fragments.append(frame.payload) - continue - else: - if self._received_fragments: - if frame.fin: - raise InvalidFrameException( - 'Received an unfragmented frame without ' - 'terminating existing fragmentation') - else: - raise InvalidFrameException( - 'New fragmentation started without terminating ' - 'existing fragmentation') - - if frame.fin: - # Unfragmented frame - - self._original_opcode = frame.opcode - message = frame.payload - else: - # Start of fragmentation frame - - if common.is_control_opcode(frame.opcode): - raise InvalidFrameException( - 'Control frames must not be fragmented') + message = self._get_message_from_frame(frame) + if message is None: + continue - self._original_opcode = frame.opcode - self._received_fragments.append(frame.payload) - continue + for message_filter in self._options.incoming_message_filters: + message = message_filter.filter(message) if self._original_opcode == common.OPCODE_TEXT: # The WebSocket protocol section 4.4 specifies that invalid @@ -540,124 +795,21 @@ class Stream(StreamBase): elif self._original_opcode == common.OPCODE_BINARY: return message elif self._original_opcode == common.OPCODE_CLOSE: - self._request.client_terminated = True - - # Status code is optional. We can have status reason only if we - # have status code. Status reason can be empty string. So, - # allowed cases are - # - no application data: no code no reason - # - 2 octet of application data: has code but no reason - # - 3 or more octet of application data: both code and reason - if len(message) == 0: - self._logger.debug('Received close frame (empty body)') - self._request.ws_close_code = ( - common.STATUS_NO_STATUS_RECEIVED) - elif len(message) == 1: - raise InvalidFrameException( - 'If a close frame has status code, the length of ' - 'status code must be 2 octet') - elif len(message) >= 2: - self._request.ws_close_code = struct.unpack( - '!H', message[0:2])[0] - self._request.ws_close_reason = message[2:].decode( - 'utf-8', 'replace') - self._logger.debug( - 'Received close frame (code=%d, reason=%r)', - self._request.ws_close_code, - self._request.ws_close_reason) - - # Drain junk data after the close frame if necessary. - self._drain_received_data() - - if self._request.server_terminated: - self._logger.debug( - 'Received ack for server-initiated closing handshake') - return None - - self._logger.debug( - 'Received client-initiated closing handshake') - - code = common.STATUS_NORMAL_CLOSURE - reason = '' - if hasattr(self._request, '_dispatcher'): - dispatcher = self._request._dispatcher - code, reason = dispatcher.passive_closing_handshake( - self._request) - if code is None and reason is not None and len(reason) > 0: - self._logger.warning( - 'Handler specified reason despite code being None') - reason = '' - if reason is None: - reason = '' - self._send_closing_handshake(code, reason) - self._logger.debug( - 'Sent ack for client-initiated closing handshake ' - '(code=%r, reason=%r)', code, reason) + self._process_close_message(message) return None elif self._original_opcode == common.OPCODE_PING: - try: - handler = self._request.on_ping_handler - if handler: - handler(self._request, message) - continue - except AttributeError, e: - pass - self._send_pong(message) + self._process_ping_message(message) elif self._original_opcode == common.OPCODE_PONG: - # TODO(tyoshino): Add ping timeout handling. - - inflight_pings = deque() - - while True: - try: - expected_body = self._ping_queue.popleft() - if expected_body == message: - # inflight_pings contains pings ignored by the - # other peer. Just forget them. - self._logger.debug( - 'Ping %r is acked (%d pings were ignored)', - expected_body, len(inflight_pings)) - break - else: - inflight_pings.append(expected_body) - except IndexError, e: - # The received pong was unsolicited pong. Keep the - # ping queue as is. - self._ping_queue = inflight_pings - self._logger.debug('Received a unsolicited pong') - break - - try: - handler = self._request.on_pong_handler - if handler: - handler(self._request, message) - continue - except AttributeError, e: - pass - - continue + self._process_pong_message(message) else: raise UnsupportedFrameException( 'Opcode %d is not supported' % self._original_opcode) def _send_closing_handshake(self, code, reason): - body = '' - if code is not None: - if (code > common.STATUS_USER_PRIVATE_MAX or - code < common.STATUS_NORMAL_CLOSURE): - raise BadOperationException('Status code is out of range') - if (code == common.STATUS_NO_STATUS_RECEIVED or - code == common.STATUS_ABNORMAL_CLOSURE or - code == common.STATUS_TLS_HANDSHAKE): - raise BadOperationException('Status code is reserved pseudo ' - 'code') - encoded_reason = reason.encode('utf-8') - body = struct.pack('!H', code) + encoded_reason - + body = create_closing_handshake_body(code, reason) frame = create_close_frame( - body, - self._options.mask_send, - self._options.outgoing_frame_filters) + body, mask=self._options.mask_send, + frame_filters=self._options.outgoing_frame_filters) self._request.server_terminated = True @@ -731,6 +883,14 @@ class Stream(StreamBase): self._options.outgoing_frame_filters) self._write(frame) + def get_last_received_opcode(self): + """Returns the opcode of the WebSocket message which the last received + frame belongs to. The return value is valid iff immediately after + receive_message call. + """ + + return self._original_opcode + def _drain_received_data(self): """Drains unread data in the receive buffer to avoid sending out TCP RST packet. This is because when deflate-stream is enabled, some diff --git a/module/lib/mod_pywebsocket/common.py b/module/lib/mod_pywebsocket/common.py index 710967c80..2388379c0 100644 --- a/module/lib/mod_pywebsocket/common.py +++ b/module/lib/mod_pywebsocket/common.py @@ -104,7 +104,10 @@ SEC_WEBSOCKET_LOCATION_HEADER = 'Sec-WebSocket-Location' DEFLATE_STREAM_EXTENSION = 'deflate-stream' DEFLATE_FRAME_EXTENSION = 'deflate-frame' PERFRAME_COMPRESSION_EXTENSION = 'perframe-compress' +PERMESSAGE_COMPRESSION_EXTENSION = 'permessage-compress' X_WEBKIT_DEFLATE_FRAME_EXTENSION = 'x-webkit-deflate-frame' +X_WEBKIT_PERMESSAGE_COMPRESSION_EXTENSION = 'x-webkit-permessage-compress' +MUX_EXTENSION = 'mux_DO_NOT_USE' # Status codes # Code STATUS_NO_STATUS_RECEIVED, STATUS_ABNORMAL_CLOSURE, and @@ -125,7 +128,7 @@ STATUS_INVALID_FRAME_PAYLOAD_DATA = 1007 STATUS_POLICY_VIOLATION = 1008 STATUS_MESSAGE_TOO_BIG = 1009 STATUS_MANDATORY_EXTENSION = 1010 -STATUS_INTERNAL_SERVER_ERROR = 1011 +STATUS_INTERNAL_ENDPOINT_ERROR = 1011 STATUS_TLS_HANDSHAKE = 1015 STATUS_USER_REGISTERED_BASE = 3000 STATUS_USER_REGISTERED_MAX = 3999 diff --git a/module/lib/mod_pywebsocket/dispatch.py b/module/lib/mod_pywebsocket/dispatch.py index ab1eb4fb3..25905f180 100644 --- a/module/lib/mod_pywebsocket/dispatch.py +++ b/module/lib/mod_pywebsocket/dispatch.py @@ -39,6 +39,7 @@ import re from mod_pywebsocket import common from mod_pywebsocket import handshake from mod_pywebsocket import msgutil +from mod_pywebsocket import mux from mod_pywebsocket import stream from mod_pywebsocket import util @@ -277,13 +278,18 @@ class Dispatcher(object): AbortedByUserException: when user handler abort connection """ - handler_suite = self.get_handler_suite(request.ws_resource) - if handler_suite is None: - raise DispatchException('No handler for: %r' % request.ws_resource) - transfer_data_ = handler_suite.transfer_data # TODO(tyoshino): Terminate underlying TCP connection if possible. try: - transfer_data_(request) + if mux.use_mux(request): + mux.start(request, self) + else: + handler_suite = self.get_handler_suite(request.ws_resource) + if handler_suite is None: + raise DispatchException('No handler for: %r' % + request.ws_resource) + transfer_data_ = handler_suite.transfer_data + transfer_data_(request) + if not request.server_terminated: request.ws_stream.close_connection() # Catch non-critical exceptions the handler didn't handle. diff --git a/module/lib/mod_pywebsocket/extensions.py b/module/lib/mod_pywebsocket/extensions.py index 52b7a4a19..03dbf9ee1 100644 --- a/module/lib/mod_pywebsocket/extensions.py +++ b/module/lib/mod_pywebsocket/extensions.py @@ -38,6 +38,9 @@ _available_processors = {} class ExtensionProcessorInterface(object): + def name(self): + return None + def get_extension_response(self): return None @@ -46,13 +49,21 @@ class ExtensionProcessorInterface(object): class DeflateStreamExtensionProcessor(ExtensionProcessorInterface): - """WebSocket DEFLATE stream extension processor.""" + """WebSocket DEFLATE stream extension processor. + + Specification: + Section 9.2.1 in + http://tools.ietf.org/html/draft-ietf-hybi-thewebsocketprotocol-10 + """ def __init__(self, request): self._logger = util.get_class_logger(self) self._request = request + def name(self): + return common.DEFLATE_STREAM_EXTENSION + def get_extension_response(self): if len(self._request.get_parameter_names()) != 0: return None @@ -70,8 +81,40 @@ _available_processors[common.DEFLATE_STREAM_EXTENSION] = ( DeflateStreamExtensionProcessor) +def _log_compression_ratio(logger, original_bytes, total_original_bytes, + filtered_bytes, total_filtered_bytes): + # Print inf when ratio is not available. + ratio = float('inf') + average_ratio = float('inf') + if original_bytes != 0: + ratio = float(filtered_bytes) / original_bytes + if total_original_bytes != 0: + average_ratio = ( + float(total_filtered_bytes) / total_original_bytes) + logger.debug('Outgoing compress ratio: %f (average: %f)' % + (ratio, average_ratio)) + + +def _log_decompression_ratio(logger, received_bytes, total_received_bytes, + filtered_bytes, total_filtered_bytes): + # Print inf when ratio is not available. + ratio = float('inf') + average_ratio = float('inf') + if received_bytes != 0: + ratio = float(received_bytes) / filtered_bytes + if total_filtered_bytes != 0: + average_ratio = ( + float(total_received_bytes) / total_filtered_bytes) + logger.debug('Incoming compress ratio: %f (average: %f)' % + (ratio, average_ratio)) + + class DeflateFrameExtensionProcessor(ExtensionProcessorInterface): - """WebSocket Per-frame DEFLATE extension processor.""" + """WebSocket Per-frame DEFLATE extension processor. + + Specification: + http://tools.ietf.org/html/draft-tyoshino-hybi-websocket-perframe-deflate + """ _WINDOW_BITS_PARAM = 'max_window_bits' _NO_CONTEXT_TAKEOVER_PARAM = 'no_context_takeover' @@ -83,6 +126,7 @@ class DeflateFrameExtensionProcessor(ExtensionProcessorInterface): self._response_window_bits = None self._response_no_context_takeover = False + self._bfinal = False # Counters for statistics. @@ -96,6 +140,9 @@ class DeflateFrameExtensionProcessor(ExtensionProcessorInterface): # Total number of incoming bytes obtained after applying this filter. self._total_filtered_incoming_payload_bytes = 0 + def name(self): + return common.DEFLATE_FRAME_EXTENSION + def get_extension_response(self): # Any unknown parameter will be just ignored. @@ -173,6 +220,9 @@ class DeflateFrameExtensionProcessor(ExtensionProcessorInterface): def set_response_no_context_takeover(self, value): self._response_no_context_takeover = value + def set_bfinal(self, value): + self._bfinal = value + def enable_outgoing_compression(self): self._compress_outgoing = True @@ -193,24 +243,17 @@ class DeflateFrameExtensionProcessor(ExtensionProcessorInterface): original_payload_size) return - frame.payload = self._deflater.filter(frame.payload) + frame.payload = self._deflater.filter( + frame.payload, bfinal=self._bfinal) frame.rsv1 = 1 filtered_payload_size = len(frame.payload) self._total_filtered_outgoing_payload_bytes += filtered_payload_size - # Print inf when ratio is not available. - ratio = float('inf') - average_ratio = float('inf') - if original_payload_size != 0: - ratio = float(filtered_payload_size) / original_payload_size - if self._total_outgoing_payload_bytes != 0: - average_ratio = ( - float(self._total_filtered_outgoing_payload_bytes) / - self._total_outgoing_payload_bytes) - self._logger.debug( - 'Outgoing compress ratio: %f (average: %f)' % - (ratio, average_ratio)) + _log_compression_ratio(self._logger, original_payload_size, + self._total_outgoing_payload_bytes, + filtered_payload_size, + self._total_filtered_outgoing_payload_bytes) def _incoming_filter(self, frame): """Transform incoming frames. This method is called only by @@ -231,18 +274,10 @@ class DeflateFrameExtensionProcessor(ExtensionProcessorInterface): filtered_payload_size = len(frame.payload) self._total_filtered_incoming_payload_bytes += filtered_payload_size - # Print inf when ratio is not available. - ratio = float('inf') - average_ratio = float('inf') - if received_payload_size != 0: - ratio = float(received_payload_size) / filtered_payload_size - if self._total_filtered_incoming_payload_bytes != 0: - average_ratio = ( - float(self._total_incoming_payload_bytes) / - self._total_filtered_incoming_payload_bytes) - self._logger.debug( - 'Incoming compress ratio: %f (average: %f)' % - (ratio, average_ratio)) + _log_decompression_ratio(self._logger, received_payload_size, + self._total_incoming_payload_bytes, + filtered_payload_size, + self._total_filtered_incoming_payload_bytes) _available_processors[common.DEFLATE_FRAME_EXTENSION] = ( @@ -250,7 +285,7 @@ _available_processors[common.DEFLATE_FRAME_EXTENSION] = ( # Adding vendor-prefixed deflate-frame extension. -# TODO(bashi): Remove this after WebKit stops using vender prefix. +# TODO(bashi): Remove this after WebKit stops using vendor prefix. _available_processors[common.X_WEBKIT_DEFLATE_FRAME_EXTENSION] = ( DeflateFrameExtensionProcessor) @@ -270,21 +305,22 @@ def _create_accepted_method_desc(method_name, method_params): return common.format_extension(extension) -class PerFrameCompressionExtensionProcessor(ExtensionProcessorInterface): - """WebSocket Per-frame compression extension processor.""" +class CompressionExtensionProcessorBase(ExtensionProcessorInterface): + """Base class for Per-frame and Per-message compression extension.""" _METHOD_PARAM = 'method' - _DEFLATE_METHOD = 'deflate' def __init__(self, request): self._logger = util.get_class_logger(self) self._request = request self._compression_method_name = None self._compression_processor = None + self._compression_processor_hook = None + + def name(self): + return '' def _lookup_compression_processor(self, method_desc): - if method_desc.name() == self._DEFLATE_METHOD: - return DeflateFrameExtensionProcessor(method_desc) return None def _get_compression_processor_response(self): @@ -311,6 +347,10 @@ class PerFrameCompressionExtensionProcessor(ExtensionProcessorInterface): break if compression_processor is None: return None + + if self._compression_processor_hook: + self._compression_processor_hook(compression_processor) + processor_response = compression_processor.get_extension_response() if processor_response is None: return None @@ -337,14 +377,345 @@ class PerFrameCompressionExtensionProcessor(ExtensionProcessorInterface): return self._compression_processor.setup_stream_options(stream_options) + def set_compression_processor_hook(self, hook): + self._compression_processor_hook = hook + def get_compression_processor(self): return self._compression_processor +class PerFrameCompressionExtensionProcessor(CompressionExtensionProcessorBase): + """WebSocket Per-frame compression extension processor. + + Specification: + http://tools.ietf.org/html/draft-ietf-hybi-websocket-perframe-compression + """ + + _DEFLATE_METHOD = 'deflate' + + def __init__(self, request): + CompressionExtensionProcessorBase.__init__(self, request) + + def name(self): + return common.PERFRAME_COMPRESSION_EXTENSION + + def _lookup_compression_processor(self, method_desc): + if method_desc.name() == self._DEFLATE_METHOD: + return DeflateFrameExtensionProcessor(method_desc) + return None + + _available_processors[common.PERFRAME_COMPRESSION_EXTENSION] = ( PerFrameCompressionExtensionProcessor) +class DeflateMessageProcessor(ExtensionProcessorInterface): + """Per-message deflate processor.""" + + _S2C_MAX_WINDOW_BITS_PARAM = 's2c_max_window_bits' + _S2C_NO_CONTEXT_TAKEOVER_PARAM = 's2c_no_context_takeover' + _C2S_MAX_WINDOW_BITS_PARAM = 'c2s_max_window_bits' + _C2S_NO_CONTEXT_TAKEOVER_PARAM = 'c2s_no_context_takeover' + + def __init__(self, request): + self._request = request + self._logger = util.get_class_logger(self) + + self._c2s_max_window_bits = None + self._c2s_no_context_takeover = False + self._bfinal = False + + self._compress_outgoing_enabled = False + + # True if a message is fragmented and compression is ongoing. + self._compress_ongoing = False + + # Counters for statistics. + + # Total number of outgoing bytes supplied to this filter. + self._total_outgoing_payload_bytes = 0 + # Total number of bytes sent to the network after applying this filter. + self._total_filtered_outgoing_payload_bytes = 0 + + # Total number of bytes received from the network. + self._total_incoming_payload_bytes = 0 + # Total number of incoming bytes obtained after applying this filter. + self._total_filtered_incoming_payload_bytes = 0 + + def name(self): + return 'deflate' + + def get_extension_response(self): + # Any unknown parameter will be just ignored. + + s2c_max_window_bits = self._request.get_parameter_value( + self._S2C_MAX_WINDOW_BITS_PARAM) + if s2c_max_window_bits is not None: + try: + s2c_max_window_bits = int(s2c_max_window_bits) + except ValueError, e: + return None + if s2c_max_window_bits < 8 or s2c_max_window_bits > 15: + return None + + s2c_no_context_takeover = self._request.has_parameter( + self._S2C_NO_CONTEXT_TAKEOVER_PARAM) + if (s2c_no_context_takeover and + self._request.get_parameter_value( + self._S2C_NO_CONTEXT_TAKEOVER_PARAM) is not None): + return None + + self._deflater = util._RFC1979Deflater( + s2c_max_window_bits, s2c_no_context_takeover) + + self._inflater = util._RFC1979Inflater() + + self._compress_outgoing_enabled = True + + response = common.ExtensionParameter(self._request.name()) + + if s2c_max_window_bits is not None: + response.add_parameter( + self._S2C_MAX_WINDOW_BITS_PARAM, str(s2c_max_window_bits)) + + if s2c_no_context_takeover: + response.add_parameter( + self._S2C_NO_CONTEXT_TAKEOVER_PARAM, None) + + if self._c2s_max_window_bits is not None: + response.add_parameter( + self._C2S_MAX_WINDOW_BITS_PARAM, + str(self._c2s_max_window_bits)) + if self._c2s_no_context_takeover: + response.add_parameter( + self._C2S_NO_CONTEXT_TAKEOVER_PARAM, None) + + self._logger.debug( + 'Enable %s extension (' + 'request: s2c_max_window_bits=%s; s2c_no_context_takeover=%r, ' + 'response: c2s_max_window_bits=%s; c2s_no_context_takeover=%r)' % + (self._request.name(), + s2c_max_window_bits, + s2c_no_context_takeover, + self._c2s_max_window_bits, + self._c2s_no_context_takeover)) + + return response + + def setup_stream_options(self, stream_options): + class _OutgoingMessageFilter(object): + + def __init__(self, parent): + self._parent = parent + + def filter(self, message, end=True, binary=False): + return self._parent._process_outgoing_message( + message, end, binary) + + class _IncomingMessageFilter(object): + + def __init__(self, parent): + self._parent = parent + self._decompress_next_message = False + + def decompress_next_message(self): + self._decompress_next_message = True + + def filter(self, message): + message = self._parent._process_incoming_message( + message, self._decompress_next_message) + self._decompress_next_message = False + return message + + self._outgoing_message_filter = _OutgoingMessageFilter(self) + self._incoming_message_filter = _IncomingMessageFilter(self) + stream_options.outgoing_message_filters.append( + self._outgoing_message_filter) + stream_options.incoming_message_filters.append( + self._incoming_message_filter) + + class _OutgoingFrameFilter(object): + + def __init__(self, parent): + self._parent = parent + self._set_compression_bit = False + + def set_compression_bit(self): + self._set_compression_bit = True + + def filter(self, frame): + self._parent._process_outgoing_frame( + frame, self._set_compression_bit) + self._set_compression_bit = False + + class _IncomingFrameFilter(object): + + def __init__(self, parent): + self._parent = parent + + def filter(self, frame): + self._parent._process_incoming_frame(frame) + + self._outgoing_frame_filter = _OutgoingFrameFilter(self) + self._incoming_frame_filter = _IncomingFrameFilter(self) + stream_options.outgoing_frame_filters.append( + self._outgoing_frame_filter) + stream_options.incoming_frame_filters.append( + self._incoming_frame_filter) + + stream_options.encode_text_message_to_utf8 = False + + def set_c2s_max_window_bits(self, value): + self._c2s_max_window_bits = value + + def set_c2s_no_context_takeover(self, value): + self._c2s_no_context_takeover = value + + def set_bfinal(self, value): + self._bfinal = value + + def enable_outgoing_compression(self): + self._compress_outgoing_enabled = True + + def disable_outgoing_compression(self): + self._compress_outgoing_enabled = False + + def _process_incoming_message(self, message, decompress): + if not decompress: + return message + + received_payload_size = len(message) + self._total_incoming_payload_bytes += received_payload_size + + message = self._inflater.filter(message) + + filtered_payload_size = len(message) + self._total_filtered_incoming_payload_bytes += filtered_payload_size + + _log_decompression_ratio(self._logger, received_payload_size, + self._total_incoming_payload_bytes, + filtered_payload_size, + self._total_filtered_incoming_payload_bytes) + + return message + + def _process_outgoing_message(self, message, end, binary): + if not binary: + message = message.encode('utf-8') + + if not self._compress_outgoing_enabled: + return message + + original_payload_size = len(message) + self._total_outgoing_payload_bytes += original_payload_size + + message = self._deflater.filter( + message, flush=end, bfinal=self._bfinal) + + filtered_payload_size = len(message) + self._total_filtered_outgoing_payload_bytes += filtered_payload_size + + _log_compression_ratio(self._logger, original_payload_size, + self._total_outgoing_payload_bytes, + filtered_payload_size, + self._total_filtered_outgoing_payload_bytes) + + if not self._compress_ongoing: + self._outgoing_frame_filter.set_compression_bit() + self._compress_ongoing = not end + return message + + def _process_incoming_frame(self, frame): + if frame.rsv1 == 1 and not common.is_control_opcode(frame.opcode): + self._incoming_message_filter.decompress_next_message() + frame.rsv1 = 0 + + def _process_outgoing_frame(self, frame, compression_bit): + if (not compression_bit or + common.is_control_opcode(frame.opcode)): + return + + frame.rsv1 = 1 + + +class PerMessageCompressionExtensionProcessor( + CompressionExtensionProcessorBase): + """WebSocket Per-message compression extension processor. + + Specification: + http://tools.ietf.org/html/draft-ietf-hybi-permessage-compression + """ + + _DEFLATE_METHOD = 'deflate' + + def __init__(self, request): + CompressionExtensionProcessorBase.__init__(self, request) + + def name(self): + return common.PERMESSAGE_COMPRESSION_EXTENSION + + def _lookup_compression_processor(self, method_desc): + if method_desc.name() == self._DEFLATE_METHOD: + return DeflateMessageProcessor(method_desc) + return None + + +_available_processors[common.PERMESSAGE_COMPRESSION_EXTENSION] = ( + PerMessageCompressionExtensionProcessor) + + +# Adding vendor-prefixed permessage-compress extension. +# TODO(bashi): Remove this after WebKit stops using vendor prefix. +_available_processors[common.X_WEBKIT_PERMESSAGE_COMPRESSION_EXTENSION] = ( + PerMessageCompressionExtensionProcessor) + + +class MuxExtensionProcessor(ExtensionProcessorInterface): + """WebSocket multiplexing extension processor.""" + + _QUOTA_PARAM = 'quota' + + def __init__(self, request): + self._request = request + + def name(self): + return common.MUX_EXTENSION + + def get_extension_response(self, ws_request, + logical_channel_extensions): + # Mux extension cannot be used after extensions that depend on + # frame boundary, extension data field, or any reserved bits + # which are attributed to each frame. + for extension in logical_channel_extensions: + name = extension.name() + if (name == common.PERFRAME_COMPRESSION_EXTENSION or + name == common.DEFLATE_FRAME_EXTENSION or + name == common.X_WEBKIT_DEFLATE_FRAME_EXTENSION): + return None + + quota = self._request.get_parameter_value(self._QUOTA_PARAM) + if quota is None: + ws_request.mux_quota = 0 + else: + try: + quota = int(quota) + except ValueError, e: + return None + if quota < 0 or quota >= 2 ** 32: + return None + ws_request.mux_quota = quota + + ws_request.mux = True + ws_request.mux_extensions = logical_channel_extensions + return common.ExtensionParameter(common.MUX_EXTENSION) + + def setup_stream_options(self, stream_options): + pass + + +_available_processors[common.MUX_EXTENSION] = MuxExtensionProcessor + + def get_extension_processor(extension_request): global _available_processors processor_class = _available_processors.get(extension_request.name()) diff --git a/module/lib/mod_pywebsocket/handshake/__init__.py b/module/lib/mod_pywebsocket/handshake/__init__.py index 10a178314..194f6b395 100644 --- a/module/lib/mod_pywebsocket/handshake/__init__.py +++ b/module/lib/mod_pywebsocket/handshake/__init__.py @@ -37,7 +37,6 @@ successfully established. import logging from mod_pywebsocket import common -from mod_pywebsocket.handshake import draft75 from mod_pywebsocket.handshake import hybi00 from mod_pywebsocket.handshake import hybi # Export AbortedByUserException, HandshakeException, and VersionException @@ -56,10 +55,8 @@ def do_handshake(request, dispatcher, allowDraft75=False, strict=False): Args: request: mod_python request. dispatcher: Dispatcher (dispatch.Dispatcher). - allowDraft75: allow draft 75 handshake protocol. - strict: Strictly check handshake request in draft 75. - Default: False. If True, request.connection must provide - get_memorized_lines method. + allowDraft75: obsolete argument. ignored. + strict: obsolete argument. ignored. Handshaker will add attributes such as ws_resource in performing handshake. @@ -86,9 +83,6 @@ def do_handshake(request, dispatcher, allowDraft75=False, strict=False): ('RFC 6455', hybi.Handshaker(request, dispatcher))) handshakers.append( ('HyBi 00', hybi00.Handshaker(request, dispatcher))) - if allowDraft75: - handshakers.append( - ('Hixie 75', draft75.Handshaker(request, dispatcher, strict))) for name, handshaker in handshakers: _LOGGER.debug('Trying protocol version %s', name) diff --git a/module/lib/mod_pywebsocket/handshake/_base.py b/module/lib/mod_pywebsocket/handshake/_base.py index bc095b129..e5c94ca90 100644 --- a/module/lib/mod_pywebsocket/handshake/_base.py +++ b/module/lib/mod_pywebsocket/handshake/_base.py @@ -85,13 +85,16 @@ def get_default_port(is_secure): def validate_subprotocol(subprotocol, hixie): - """Validate a value in subprotocol fields such as WebSocket-Protocol, - Sec-WebSocket-Protocol. + """Validate a value in the Sec-WebSocket-Protocol field. See - RFC 6455: Section 4.1., 4.2.2., and 4.3. - HyBi 00: Section 4.1. Opening handshake - - Hixie 75: Section 4.1. Handshake + + Args: + hixie: if True, checks if characters in subprotocol are in range + between U+0020 and U+007E. It's required by HyBi 00 but not by + RFC 6455. """ if not subprotocol: @@ -170,7 +173,11 @@ def check_request_line(request): # 5.1 1. The three character UTF-8 string "GET". # 5.1 2. A UTF-8-encoded U+0020 SPACE character (0x20 byte). if request.method != 'GET': - raise HandshakeException('Method is not GET') + raise HandshakeException('Method is not GET: %r' % request.method) + + if request.protocol != 'HTTP/1.1': + raise HandshakeException('Version is not HTTP/1.1: %r' % + request.protocol) def check_header_lines(request, mandatory_headers): diff --git a/module/lib/mod_pywebsocket/handshake/draft75.py b/module/lib/mod_pywebsocket/handshake/draft75.py deleted file mode 100644 index 802a31c9a..000000000 --- a/module/lib/mod_pywebsocket/handshake/draft75.py +++ /dev/null @@ -1,190 +0,0 @@ -# Copyright 2011, Google Inc. -# All rights reserved. -# -# Redistribution and use in source and binary forms, with or without -# modification, are permitted provided that the following conditions are -# met: -# -# * Redistributions of source code must retain the above copyright -# notice, this list of conditions and the following disclaimer. -# * Redistributions in binary form must reproduce the above -# copyright notice, this list of conditions and the following disclaimer -# in the documentation and/or other materials provided with the -# distribution. -# * Neither the name of Google Inc. nor the names of its -# contributors may be used to endorse or promote products derived from -# this software without specific prior written permission. -# -# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS -# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT -# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR -# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT -# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, -# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT -# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, -# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY -# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT -# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - - -"""WebSocket handshaking defined in draft-hixie-thewebsocketprotocol-75.""" - - -# Note: request.connection.write is used in this module, even though mod_python -# document says that it should be used only in connection handlers. -# Unfortunately, we have no other options. For example, request.write is not -# suitable because it doesn't allow direct raw bytes writing. - - -import logging -import re - -from mod_pywebsocket import common -from mod_pywebsocket.stream import StreamHixie75 -from mod_pywebsocket import util -from mod_pywebsocket.handshake._base import HandshakeException -from mod_pywebsocket.handshake._base import build_location -from mod_pywebsocket.handshake._base import validate_subprotocol - - -_MANDATORY_HEADERS = [ - # key, expected value or None - ['Upgrade', 'WebSocket'], - ['Connection', 'Upgrade'], - ['Host', None], - ['Origin', None], -] - -_FIRST_FIVE_LINES = map(re.compile, [ - r'^GET /[\S]* HTTP/1.1\r\n$', - r'^Upgrade: WebSocket\r\n$', - r'^Connection: Upgrade\r\n$', - r'^Host: [\S]+\r\n$', - r'^Origin: [\S]+\r\n$', -]) - -_SIXTH_AND_LATER = re.compile( - r'^' - r'(WebSocket-Protocol: [\x20-\x7e]+\r\n)?' - r'(Cookie: [^\r]*\r\n)*' - r'(Cookie2: [^\r]*\r\n)?' - r'(Cookie: [^\r]*\r\n)*' - r'\r\n') - - -class Handshaker(object): - """This class performs WebSocket handshake.""" - - def __init__(self, request, dispatcher, strict=False): - """Construct an instance. - - Args: - request: mod_python request. - dispatcher: Dispatcher (dispatch.Dispatcher). - strict: Strictly check handshake request. Default: False. - If True, request.connection must provide get_memorized_lines - method. - - Handshaker will add attributes such as ws_resource in performing - handshake. - """ - - self._logger = util.get_class_logger(self) - - self._request = request - self._dispatcher = dispatcher - self._strict = strict - - def do_handshake(self): - """Perform WebSocket Handshake. - - On _request, we set - ws_resource, ws_origin, ws_location, ws_protocol - ws_challenge_md5: WebSocket handshake information. - ws_stream: Frame generation/parsing class. - ws_version: Protocol version. - """ - - self._check_header_lines() - self._set_resource() - self._set_origin() - self._set_location() - self._set_subprotocol() - self._set_protocol_version() - - self._dispatcher.do_extra_handshake(self._request) - - self._send_handshake() - - self._logger.debug('Sent opening handshake response') - - def _set_resource(self): - self._request.ws_resource = self._request.uri - - def _set_origin(self): - self._request.ws_origin = self._request.headers_in['Origin'] - - def _set_location(self): - self._request.ws_location = build_location(self._request) - - def _set_subprotocol(self): - subprotocol = self._request.headers_in.get('WebSocket-Protocol') - if subprotocol is not None: - validate_subprotocol(subprotocol, hixie=True) - self._request.ws_protocol = subprotocol - - def _set_protocol_version(self): - self._logger.debug('IETF Hixie 75 protocol') - self._request.ws_version = common.VERSION_HIXIE75 - self._request.ws_stream = StreamHixie75(self._request) - - def _sendall(self, data): - self._request.connection.write(data) - - def _send_handshake(self): - self._sendall('HTTP/1.1 101 Web Socket Protocol Handshake\r\n') - self._sendall('Upgrade: WebSocket\r\n') - self._sendall('Connection: Upgrade\r\n') - self._sendall('WebSocket-Origin: %s\r\n' % self._request.ws_origin) - self._sendall('WebSocket-Location: %s\r\n' % self._request.ws_location) - if self._request.ws_protocol: - self._sendall( - 'WebSocket-Protocol: %s\r\n' % self._request.ws_protocol) - self._sendall('\r\n') - - def _check_header_lines(self): - for key, expected_value in _MANDATORY_HEADERS: - actual_value = self._request.headers_in.get(key) - if not actual_value: - raise HandshakeException('Header %s is not defined' % key) - if expected_value: - if actual_value != expected_value: - raise HandshakeException( - 'Expected %r for header %s but found %r' % - (expected_value, key, actual_value)) - if self._strict: - try: - lines = self._request.connection.get_memorized_lines() - except AttributeError, e: - raise AttributeError( - 'Strict handshake is specified but the connection ' - 'doesn\'t provide get_memorized_lines()') - self._check_first_lines(lines) - - def _check_first_lines(self, lines): - if len(lines) < len(_FIRST_FIVE_LINES): - raise HandshakeException('Too few header lines: %d' % len(lines)) - for line, regexp in zip(lines, _FIRST_FIVE_LINES): - if not regexp.search(line): - raise HandshakeException( - 'Unexpected header: %r doesn\'t match %r' - % (line, regexp.pattern)) - sixth_and_later = ''.join(lines[5:]) - if not _SIXTH_AND_LATER.search(sixth_and_later): - raise HandshakeException( - 'Unexpected header: %r doesn\'t match %r' - % (sixth_and_later, _SIXTH_AND_LATER.pattern)) - - -# vi:sts=4 sw=4 et diff --git a/module/lib/mod_pywebsocket/handshake/hybi.py b/module/lib/mod_pywebsocket/handshake/hybi.py index 2883acbf8..fc0e2a096 100644 --- a/module/lib/mod_pywebsocket/handshake/hybi.py +++ b/module/lib/mod_pywebsocket/handshake/hybi.py @@ -182,34 +182,60 @@ class Handshaker(object): # Extra handshake handler may modify/remove processors. self._dispatcher.do_extra_handshake(self._request) + processors = filter(lambda processor: processor is not None, + self._request.ws_extension_processors) + + accepted_extensions = [] + + # We need to take care of mux extension here. Extensions that + # are placed before mux should be applied to logical channels. + mux_index = -1 + for i, processor in enumerate(processors): + if processor.name() == common.MUX_EXTENSION: + mux_index = i + break + if mux_index >= 0: + mux_processor = processors[mux_index] + logical_channel_processors = processors[:mux_index] + processors = processors[mux_index+1:] + + for processor in logical_channel_processors: + extension_response = processor.get_extension_response() + if extension_response is None: + # Rejected. + continue + accepted_extensions.append(extension_response) + # Pass a shallow copy of accepted_extensions as extensions for + # logical channels. + mux_response = mux_processor.get_extension_response( + self._request, accepted_extensions[:]) + if mux_response is not None: + accepted_extensions.append(mux_response) stream_options = StreamOptions() - self._request.ws_extensions = None - for processor in self._request.ws_extension_processors: - if processor is None: - # Some processors may be removed by extra handshake - # handler. - continue + # When there is mux extension, here, |processors| contain only + # prosessors for extensions placed after mux. + for processor in processors: extension_response = processor.get_extension_response() if extension_response is None: # Rejected. continue - if self._request.ws_extensions is None: - self._request.ws_extensions = [] - self._request.ws_extensions.append(extension_response) + accepted_extensions.append(extension_response) processor.setup_stream_options(stream_options) - if self._request.ws_extensions is not None: + if len(accepted_extensions) > 0: + self._request.ws_extensions = accepted_extensions self._logger.debug( 'Extensions accepted: %r', - map(common.ExtensionParameter.name, - self._request.ws_extensions)) + map(common.ExtensionParameter.name, accepted_extensions)) + else: + self._request.ws_extensions = None - self._request.ws_stream = Stream(self._request, stream_options) + self._request.ws_stream = self._create_stream(stream_options) if self._request.ws_requested_protocols is not None: if self._request.ws_protocol is None: @@ -268,7 +294,7 @@ class Handshaker(object): protocol_header = self._request.headers_in.get( common.SEC_WEBSOCKET_PROTOCOL_HEADER) - if not protocol_header: + if protocol_header is None: self._request.ws_requested_protocols = None return @@ -341,7 +367,10 @@ class Handshaker(object): return key - def _send_handshake(self, accept): + def _create_stream(self, stream_options): + return Stream(self._request, stream_options) + + def _create_handshake_response(self, accept): response = [] response.append('HTTP/1.1 101 Switching Protocols\r\n') @@ -363,7 +392,10 @@ class Handshaker(object): common.format_extensions(self._request.ws_extensions))) response.append('\r\n') - raw_response = ''.join(response) + return ''.join(response) + + def _send_handshake(self, accept): + raw_response = self._create_handshake_response(accept) self._request.connection.write(raw_response) self._logger.debug('Sent server\'s opening handshake: %r', raw_response) diff --git a/module/lib/mod_pywebsocket/headerparserhandler.py b/module/lib/mod_pywebsocket/headerparserhandler.py index b68c240e1..2cc62de04 100644 --- a/module/lib/mod_pywebsocket/headerparserhandler.py +++ b/module/lib/mod_pywebsocket/headerparserhandler.py @@ -63,8 +63,9 @@ _PYOPT_ALLOW_HANDLERS_OUTSIDE_ROOT = ( _PYOPT_ALLOW_HANDLERS_OUTSIDE_ROOT_DEFINITION = { 'off': False, 'no': False, 'on': True, 'yes': True} -# PythonOption to specify to allow draft75 handshake. -# The default is None (Off) +# (Obsolete option. Ignored.) +# PythonOption to specify to allow handshake defined in Hixie 75 version +# protocol. The default is None (Off) _PYOPT_ALLOW_DRAFT75 = 'mod_pywebsocket.allow_draft75' # Map from values to their meanings. _PYOPT_ALLOW_DRAFT75_DEFINITION = {'off': False, 'on': True} diff --git a/module/lib/mod_pywebsocket/msgutil.py b/module/lib/mod_pywebsocket/msgutil.py index 21ffdacf6..4c1a0114b 100644 --- a/module/lib/mod_pywebsocket/msgutil.py +++ b/module/lib/mod_pywebsocket/msgutil.py @@ -59,20 +59,20 @@ def close_connection(request): request.ws_stream.close_connection() -def send_message(request, message, end=True, binary=False): - """Send message. +def send_message(request, payload_data, end=True, binary=False): + """Send a message (or part of a message). Args: request: mod_python request. - message: unicode text or str binary to send. - end: False to send message as a fragment. All messages until the - first call with end=True (inclusive) will be delivered to the - client in separate frames but as one WebSocket message. - binary: send message as binary frame. + payload_data: unicode text or str binary to send. + end: True to terminate a message. + False to send payload_data as part of a message that is to be + terminated by next or later send_message call with end=True. + binary: send payload_data as binary frame(s). Raises: BadOperationException: when server already terminated. """ - request.ws_stream.send_message(message, end, binary) + request.ws_stream.send_message(payload_data, end, binary) def receive_message(request): diff --git a/module/lib/mod_pywebsocket/standalone.py b/module/lib/mod_pywebsocket/standalone.py index 850aa5cd4..07a33d9c9 100755 --- a/module/lib/mod_pywebsocket/standalone.py +++ b/module/lib/mod_pywebsocket/standalone.py @@ -32,27 +32,44 @@ """Standalone WebSocket server. +Use this file to launch pywebsocket without Apache HTTP Server. + + BASIC USAGE -Use this server to run mod_pywebsocket without Apache HTTP Server. +Go to the src directory and run -Usage: - python standalone.py [-p <ws_port>] [-w <websock_handlers>] - [-s <scan_dir>] - [-d <document_root>] - [-m <websock_handlers_map_file>] - ... for other options, see _main below ... + $ python mod_pywebsocket/standalone.py [-p <ws_port>] + [-w <websock_handlers>] + [-d <document_root>] <ws_port> is the port number to use for ws:// connection. <document_root> is the path to the root directory of HTML files. <websock_handlers> is the path to the root directory of WebSocket handlers. -See __init__.py for details of <websock_handlers> and how to write WebSocket -handlers. If this path is relative, <document_root> is used as the base. +If not specified, <document_root> will be used. See __init__.py (or +run $ pydoc mod_pywebsocket) for how to write WebSocket handlers. + +For more detail and other options, run + + $ python mod_pywebsocket/standalone.py --help + +or see _build_option_parser method below. + +For trouble shooting, adding "--log_level debug" might help you. + -<scan_dir> is a path under the root directory. If specified, only the -handlers under scan_dir are scanned. This is useful in saving scan time. +TRY DEMO + +Go to the src directory and run + + $ python standalone.py -d example + +to launch pywebsocket with the sample handler and html on port 80. Open +http://localhost/console.html, click the connect button, type something into +the text box next to the send button and click the send button. If everything +is working, you'll see the message you typed echoed by the server. SUPPORTING TLS @@ -63,10 +80,10 @@ To support TLS, run standalone.py with -t, -k, and -c options. SUPPORTING CLIENT AUTHENTICATION To support client authentication with TLS, run standalone.py with -t, -k, -c, -and --ca-certificate options. +and --tls-client-auth, and --tls-client-ca options. E.g., $./standalone.py -d ../example -p 10443 -t -c ../test/cert/cert.pem -k -../test/cert/key.pem --ca-certificate=../test/cert/cacert.pem +../test/cert/key.pem --tls-client-auth --tls-client-ca=../test/cert/cacert.pem CONFIGURATION FILE @@ -110,6 +127,7 @@ import CGIHTTPServer import SimpleHTTPServer import SocketServer import ConfigParser +import base64 import httplib import logging import logging.handlers @@ -224,6 +242,12 @@ class _StandaloneRequest(object): return self._request_handler.command method = property(get_method) + def get_protocol(self): + """Getter to mimic request.protocol.""" + + return self._request_handler.request_version + protocol = property(get_protocol) + def is_https(self): """Mimic request.is_https().""" @@ -264,6 +288,32 @@ class _StandaloneSSLConnection(object): return socket._fileobject(self._connection, mode, bufsize) +def _alias_handlers(dispatcher, websock_handlers_map_file): + """Set aliases specified in websock_handler_map_file in dispatcher. + + Args: + dispatcher: dispatch.Dispatcher instance + websock_handler_map_file: alias map file + """ + + fp = open(websock_handlers_map_file) + try: + for line in fp: + if line[0] == '#' or line.isspace(): + continue + m = re.match('(\S+)\s+(\S+)', line) + if not m: + logging.warning('Wrong format in map file:' + line) + continue + try: + dispatcher.add_resource_path_alias( + m.group(1), m.group(2)) + except dispatch.DispatchException, e: + logging.error(str(e)) + finally: + fp.close() + + class WebSocketServer(SocketServer.ThreadingMixIn, BaseHTTPServer.HTTPServer): """HTTPServer specialized for WebSocket.""" @@ -278,6 +328,20 @@ class WebSocketServer(SocketServer.ThreadingMixIn, BaseHTTPServer.HTTPServer): if necessary. """ + # Share a Dispatcher among request handlers to save time for + # instantiation. Dispatcher can be shared because it is thread-safe. + options.dispatcher = dispatch.Dispatcher( + options.websock_handlers, + options.scan_dir, + options.allow_handlers_outside_root_dir) + if options.websock_handlers_map_file: + _alias_handlers(options.dispatcher, + options.websock_handlers_map_file) + warnings = options.dispatcher.source_warnings() + if warnings: + for warning in warnings: + logging.warning('mod_pywebsocket: %s' % warning) + self._logger = util.get_class_logger(self) self.request_queue_size = options.request_queue_size @@ -325,7 +389,7 @@ class WebSocketServer(SocketServer.ThreadingMixIn, BaseHTTPServer.HTTPServer): continue if self.websocket_server_options.use_tls: if _HAS_SSL: - if self.websocket_server_options.ca_certificate: + if self.websocket_server_options.tls_client_auth: client_cert_ = ssl.CERT_REQUIRED else: client_cert_ = ssl.CERT_NONE @@ -333,7 +397,7 @@ class WebSocketServer(SocketServer.ThreadingMixIn, BaseHTTPServer.HTTPServer): keyfile=self.websocket_server_options.private_key, certfile=self.websocket_server_options.certificate, ssl_version=ssl.PROTOCOL_SSLv23, - ca_certs=self.websocket_server_options.ca_certificate, + ca_certs=self.websocket_server_options.tls_client_ca, cert_reqs=client_cert_) if _HAS_OPEN_SSL: ctx = OpenSSL.SSL.Context(OpenSSL.SSL.SSLv23_METHOD) @@ -362,6 +426,15 @@ class WebSocketServer(SocketServer.ThreadingMixIn, BaseHTTPServer.HTTPServer): self._logger.info('Skip by failure: %r', e) socket_.close() failed_sockets.append(socketinfo) + if self.server_address[1] == 0: + # The operating system assigns the actual port number for port + # number 0. This case, the second and later sockets should use + # the same port number. Also self.server_port is rewritten + # because it is exported, and will be used by external code. + self.server_address = ( + self.server_name, socket_.getsockname()[1]) + self.server_port = self.server_address[1] + self._logger.info('Port %r is assigned', self.server_port) for socketinfo in failed_sockets: self._sockets.remove(socketinfo) @@ -386,6 +459,10 @@ class WebSocketServer(SocketServer.ThreadingMixIn, BaseHTTPServer.HTTPServer): for socketinfo in failed_sockets: self._sockets.remove(socketinfo) + if len(self._sockets) == 0: + self._logger.critical( + 'No sockets activated. Use info log level to see the reason.') + def server_close(self): """Override SocketServer.TCPServer.server_close to enable multiple sockets close. @@ -513,6 +590,17 @@ class WebSocketRequestHandler(CGIHTTPServer.CGIHTTPRequestHandler): # attributes). if not CGIHTTPServer.CGIHTTPRequestHandler.parse_request(self): return False + + if self._options.use_basic_auth: + auth = self.headers.getheader('Authorization') + if auth != self._options.basic_auth_credential: + self.send_response(401) + self.send_header('WWW-Authenticate', + 'Basic realm="Pywebsocket"') + self.end_headers() + self._logger.info('Request basic authentication') + return True + host, port, resource = http_header_util.parse_uri(self.path) if resource is None: self._logger.info('Invalid URI: %r', self.path) @@ -648,32 +736,6 @@ def _configure_logging(options): deflate_log_level_name) -def _alias_handlers(dispatcher, websock_handlers_map_file): - """Set aliases specified in websock_handler_map_file in dispatcher. - - Args: - dispatcher: dispatch.Dispatcher instance - websock_handler_map_file: alias map file - """ - - fp = open(websock_handlers_map_file) - try: - for line in fp: - if line[0] == '#' or line.isspace(): - continue - m = re.match('(\S+)\s+(\S+)', line) - if not m: - logging.warning('Wrong format in map file:' + line) - continue - try: - dispatcher.add_resource_path_alias( - m.group(1), m.group(2)) - except dispatch.DispatchException, e: - logging.error(str(e)) - finally: - fp.close() - - def _build_option_parser(): parser = optparse.OptionParser() @@ -700,7 +762,9 @@ def _build_option_parser(): parser.add_option('-w', '--websock-handlers', '--websock_handlers', dest='websock_handlers', default='.', - help='WebSocket handlers root directory.') + help=('The root directory of WebSocket handler files. ' + 'If the path is relative, --document-root is used ' + 'as the base.')) parser.add_option('-m', '--websock-handlers-map-file', '--websock_handlers_map_file', dest='websock_handlers_map_file', @@ -710,15 +774,20 @@ def _build_option_parser(): 'existing_resource_path, separated by spaces.')) parser.add_option('-s', '--scan-dir', '--scan_dir', dest='scan_dir', default=None, - help=('WebSocket handlers scan directory. ' - 'Must be a directory under websock_handlers.')) + help=('Must be a directory under --websock-handlers. ' + 'Only handlers under this directory are scanned ' + 'and registered to the server. ' + 'Useful for saving scan time when the handler ' + 'root directory contains lots of files that are ' + 'not handler file or are handler files but you ' + 'don\'t want them to be registered. ')) parser.add_option('--allow-handlers-outside-root-dir', '--allow_handlers_outside_root_dir', dest='allow_handlers_outside_root_dir', action='store_true', default=False, help=('Scans WebSocket handlers even if their canonical ' - 'path is not under websock_handlers.')) + 'path is not under --websock-handlers.')) parser.add_option('-d', '--document-root', '--document_root', dest='document_root', default='.', help='Document root directory.') @@ -735,9 +804,20 @@ def _build_option_parser(): default='', help='TLS private key file.') parser.add_option('-c', '--certificate', dest='certificate', default='', help='TLS certificate file.') - parser.add_option('--ca-certificate', dest='ca_certificate', default='', - help=('TLS CA certificate file for client ' - 'authentication.')) + parser.add_option('--tls-client-auth', dest='tls_client_auth', + action='store_true', default=False, + help='Requires TLS client auth on every connection.') + parser.add_option('--tls-client-ca', dest='tls_client_ca', default='', + help=('Specifies a pem file which contains a set of ' + 'concatenated CA certificates which are used to ' + 'validate certificates passed from clients')) + parser.add_option('--basic-auth', dest='use_basic_auth', + action='store_true', default=False, + help='Requires Basic authentication.') + parser.add_option('--basic-auth-credential', + dest='basic_auth_credential', default='test:test', + help='Specifies the credential of basic authentication ' + 'by username:password pair (e.g. test:test).') parser.add_option('-l', '--log-file', '--log_file', dest='log_file', default='', help='Log file.') # Custom log level: @@ -771,9 +851,9 @@ def _build_option_parser(): help='Log backup count') parser.add_option('--allow-draft75', dest='allow_draft75', action='store_true', default=False, - help='Allow draft 75 handshake') + help='Obsolete option. Ignored.') parser.add_option('--strict', dest='strict', action='store_true', - default=False, help='Strictly check handshake request') + default=False, help='Obsolete option. Ignored.') parser.add_option('-q', '--queue', dest='request_queue_size', type='int', default=_DEFAULT_REQUEST_QUEUE_SIZE, help='request queue size') @@ -841,6 +921,12 @@ def _parse_args_and_config(args): def _main(args=None): + """You can call this function from your own program, but please note that + this function has some side-effects that might affect your program. For + example, util.wrap_popen3_for_win use in this method replaces implementation + of os.popen3. + """ + options, args = _parse_args_and_config(args=args) os.chdir(options.document_root) @@ -877,7 +963,7 @@ def _main(args=None): 'To use TLS, specify private_key and certificate.') sys.exit(1) - if options.ca_certificate: + if options.tls_client_auth: if not options.use_tls: logging.critical('TLS must be enabled for client authentication.') sys.exit(1) @@ -887,26 +973,16 @@ def _main(args=None): if not options.scan_dir: options.scan_dir = options.websock_handlers + if options.use_basic_auth: + options.basic_auth_credential = 'Basic ' + base64.b64encode( + options.basic_auth_credential) + try: if options.thread_monitor_interval_in_sec > 0: # Run a thread monitor to show the status of server threads for # debugging. ThreadMonitor(options.thread_monitor_interval_in_sec).start() - # Share a Dispatcher among request handlers to save time for - # instantiation. Dispatcher can be shared because it is thread-safe. - options.dispatcher = dispatch.Dispatcher( - options.websock_handlers, - options.scan_dir, - options.allow_handlers_outside_root_dir) - if options.websock_handlers_map_file: - _alias_handlers(options.dispatcher, - options.websock_handlers_map_file) - warnings = options.dispatcher.source_warnings() - if warnings: - for warning in warnings: - logging.warning('mod_pywebsocket: %s' % warning) - server = WebSocketServer(options) server.serve_forever() except Exception, e: diff --git a/module/lib/mod_pywebsocket/stream.py b/module/lib/mod_pywebsocket/stream.py index d051eee20..edc533279 100644 --- a/module/lib/mod_pywebsocket/stream.py +++ b/module/lib/mod_pywebsocket/stream.py @@ -51,6 +51,7 @@ from mod_pywebsocket._stream_hybi import create_ping_frame from mod_pywebsocket._stream_hybi import create_pong_frame from mod_pywebsocket._stream_hybi import create_binary_frame from mod_pywebsocket._stream_hybi import create_text_frame +from mod_pywebsocket._stream_hybi import create_closing_handshake_body # vi:sts=4 sw=4 et diff --git a/module/lib/mod_pywebsocket/util.py b/module/lib/mod_pywebsocket/util.py index d60b53f75..7bb0b5d9e 100644 --- a/module/lib/mod_pywebsocket/util.py +++ b/module/lib/mod_pywebsocket/util.py @@ -31,6 +31,8 @@ """WebSocket utilities. """ + +import array import errno # Import hash classes from a module available and recommended for each Python @@ -160,6 +162,34 @@ class NoopMasker(object): return s +class RepeatedXorMasker(object): + """A masking object that applies XOR on the string given to mask method + with the masking bytes given to the constructor repeatedly. This object + remembers the position in the masking bytes the last mask method call + ended and resumes from that point on the next mask method call. + """ + + def __init__(self, mask): + self._mask = map(ord, mask) + self._mask_size = len(self._mask) + self._count = 0 + + def mask(self, s): + result = array.array('B') + result.fromstring(s) + # Use temporary local variables to eliminate the cost to access + # attributes + count = self._count + mask = self._mask + mask_size = self._mask_size + for i in xrange(len(result)): + result[i] ^= mask[count] + count = (count + 1) % mask_size + self._count = count + + return result.tostring() + + class DeflateRequest(object): """A wrapper class for request object to intercept send and recv to perform deflate compression and decompression transparently. @@ -202,6 +232,12 @@ class _Deflater(object): self._compress = zlib.compressobj( zlib.Z_DEFAULT_COMPRESSION, zlib.DEFLATED, -window_bits) + def compress(self, bytes): + compressed_bytes = self._compress.compress(bytes) + self._logger.debug('Compress input %r', bytes) + self._logger.debug('Compress result %r', compressed_bytes) + return compressed_bytes + def compress_and_flush(self, bytes): compressed_bytes = self._compress.compress(bytes) compressed_bytes += self._compress.flush(zlib.Z_SYNC_FLUSH) @@ -209,6 +245,12 @@ class _Deflater(object): self._logger.debug('Compress result %r', compressed_bytes) return compressed_bytes + def compress_and_finish(self, bytes): + compressed_bytes = self._compress.compress(bytes) + compressed_bytes += self._compress.flush(zlib.Z_FINISH) + self._logger.debug('Compress input %r', bytes) + self._logger.debug('Compress result %r', compressed_bytes) + return compressed_bytes class _Inflater(object): @@ -288,14 +330,21 @@ class _RFC1979Deflater(object): self._window_bits = window_bits self._no_context_takeover = no_context_takeover - def filter(self, bytes): - if self._deflater is None or self._no_context_takeover: + def filter(self, bytes, flush=True, bfinal=False): + if self._deflater is None or (self._no_context_takeover and flush): self._deflater = _Deflater(self._window_bits) - # Strip last 4 octets which is LEN and NLEN field of a non-compressed - # block added for Z_SYNC_FLUSH. - return self._deflater.compress_and_flush(bytes)[:-4] - + if bfinal: + result = self._deflater.compress_and_finish(bytes) + # Add a padding block with BFINAL = 0 and BTYPE = 0. + result = result + chr(0) + self._deflater = None + return result + if flush: + # Strip last 4 octets which is LEN and NLEN field of a + # non-compressed block added for Z_SYNC_FLUSH. + return self._deflater.compress_and_flush(bytes)[:-4] + return self._deflater.compress(bytes) class _RFC1979Inflater(object): """A decompressor class for byte sequence compressed and flushed following |