# Copyright 2012, Google Inc.
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are
# met:
#
# * Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above
# copyright notice, this list of conditions and the following disclaimer
# in the documentation and/or other materials provided with the
# distribution.
# * Neither the name of Google Inc. nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.


"""This file provides classes and helper functions for the multiplexing
extension.

Specification:
http://tools.ietf.org/html/draft-ietf-hybi-websocket-multiplexing-06
"""


import collections
import copy
import email
import email.parser
import logging
import math
import struct
import threading
import traceback

from mod_pywebsocket import common
from mod_pywebsocket import handshake
from mod_pywebsocket import util
from mod_pywebsocket._stream_base import BadOperationException
from mod_pywebsocket._stream_base import ConnectionTerminatedException
from mod_pywebsocket._stream_hybi import Frame
from mod_pywebsocket._stream_hybi import Stream
from mod_pywebsocket._stream_hybi import StreamOptions
from mod_pywebsocket._stream_hybi import create_binary_frame
from mod_pywebsocket._stream_hybi import create_closing_handshake_body
from mod_pywebsocket._stream_hybi import create_header
from mod_pywebsocket._stream_hybi import create_length_header
from mod_pywebsocket._stream_hybi import parse_frame
from mod_pywebsocket.handshake import hybi


# Channel id on which multiplexing control blocks (AddChannel, FlowControl,
# DropChannel, NewChannelSlot) are carried.
_CONTROL_CHANNEL_ID = 0
# Default logical channel. The mux handshaker does not send an
# AddChannelResponse for this channel.
_DEFAULT_CHANNEL_ID = 1

# Multiplexing control block opcodes. An opcode occupies the top three bits
# of a control block's first byte (encoded as `opcode << 5`).
_MUX_OPCODE_ADD_CHANNEL_REQUEST = 0
_MUX_OPCODE_ADD_CHANNEL_RESPONSE = 1
_MUX_OPCODE_FLOW_CONTROL = 2
_MUX_OPCODE_DROP_CHANNEL = 3
_MUX_OPCODE_NEW_CHANNEL_SLOT = 4

# Largest channel id. The widest (4-byte) channel id encoding only holds
# ids below 2 ** 29.
_MAX_CHANNEL_ID = 2 ** 29 - 1

# Initial number of channel slots and initial quota for the client.
# NOTE(review): their consumers are outside this part of the file. The quota
# is presumably in bytes, like the other send/receive quotas. Confirm.
_INITIAL_NUMBER_OF_CHANNEL_SLOTS = 64
_INITIAL_QUOTA_FOR_CLIENT = 8 * 1024

# Values of the encoding field of AddChannelRequest/AddChannelResponse:
# identity, or delta-encoded against a base handshake.
_HANDSHAKE_ENCODING_IDENTITY = 0
_HANDSHAKE_ENCODING_DELTA = 1

# HTTP reason phrases keyed by status code. We need only these status codes
# for now.
_HTTP_BAD_RESPONSE_MESSAGES = {
    common.HTTP_STATUS_BAD_REQUEST: 'Bad Request',
}

# DropChannel reason codes.
# TODO(bashi): Define all reason codes defined in the -05 draft.
_DROP_CODE_NORMAL_CLOSURE = 1000 _DROP_CODE_INVALID_ENCAPSULATING_MESSAGE = 2001 _DROP_CODE_CHANNEL_ID_TRUNCATED = 2002 _DROP_CODE_ENCAPSULATED_FRAME_IS_TRUNCATED = 2003 _DROP_CODE_UNKNOWN_MUX_OPCODE = 2004 _DROP_CODE_INVALID_MUX_CONTROL_BLOCK = 2005 _DROP_CODE_CHANNEL_ALREADY_EXISTS = 2006 _DROP_CODE_NEW_CHANNEL_SLOT_VIOLATION = 2007 _DROP_CODE_UNKNOWN_REQUEST_ENCODING = 3002 _DROP_CODE_SEND_QUOTA_VIOLATION = 3005 _DROP_CODE_ACKNOWLEDGED = 3008 class MuxUnexpectedException(Exception): """Exception in handling multiplexing extension.""" pass # Temporary class MuxNotImplementedException(Exception): """Raised when a flow enters unimplemented code path.""" pass class LogicalConnectionClosedException(Exception): """Raised when logical connection is gracefully closed.""" pass class PhysicalConnectionError(Exception): """Raised when there is a physical connection error.""" def __init__(self, drop_code, message=''): super(PhysicalConnectionError, self).__init__( 'code=%d, message=%r' % (drop_code, message)) self.drop_code = drop_code self.message = message class LogicalChannelError(Exception): """Raised when there is a logical channel error.""" def __init__(self, channel_id, drop_code, message=''): super(LogicalChannelError, self).__init__( 'channel_id=%d, code=%d, message=%r' % ( channel_id, drop_code, message)) self.channel_id = channel_id self.drop_code = drop_code self.message = message def _encode_channel_id(channel_id): if channel_id < 0: raise ValueError('Channel id %d must not be negative' % channel_id) if channel_id < 2 ** 7: return chr(channel_id) if channel_id < 2 ** 14: return struct.pack('!H', 0x8000 + channel_id) if channel_id < 2 ** 21: first = chr(0xc0 + (channel_id >> 16)) return first + struct.pack('!H', channel_id & 0xffff) if channel_id < 2 ** 29: return struct.pack('!L', 0xe0000000 + channel_id) raise ValueError('Channel id %d is too large' % channel_id) def _encode_number(number): return create_length_header(number, False) def 
_create_add_channel_response(channel_id, encoded_handshake, encoding=0, rejected=False, outer_frame_mask=False): if encoding != 0 and encoding != 1: raise ValueError('Invalid encoding %d' % encoding) first_byte = ((_MUX_OPCODE_ADD_CHANNEL_RESPONSE << 5) | (rejected << 4) | encoding) block = (chr(first_byte) + _encode_channel_id(channel_id) + _encode_number(len(encoded_handshake)) + encoded_handshake) payload = _encode_channel_id(_CONTROL_CHANNEL_ID) + block return create_binary_frame(payload, mask=outer_frame_mask) def _create_drop_channel(channel_id, code=None, message='', outer_frame_mask=False): if len(message) > 0 and code is None: raise ValueError('Code must be specified if message is specified') first_byte = _MUX_OPCODE_DROP_CHANNEL << 5 block = chr(first_byte) + _encode_channel_id(channel_id) if code is None: block += _encode_number(0) # Reason size else: reason = struct.pack('!H', code) + message reason_size = _encode_number(len(reason)) block += reason_size + reason payload = _encode_channel_id(_CONTROL_CHANNEL_ID) + block return create_binary_frame(payload, mask=outer_frame_mask) def _create_flow_control(channel_id, replenished_quota, outer_frame_mask=False): first_byte = _MUX_OPCODE_FLOW_CONTROL << 5 block = (chr(first_byte) + _encode_channel_id(channel_id) + _encode_number(replenished_quota)) payload = _encode_channel_id(_CONTROL_CHANNEL_ID) + block return create_binary_frame(payload, mask=outer_frame_mask) def _create_new_channel_slot(slots, send_quota, outer_frame_mask=False): if slots < 0 or send_quota < 0: raise ValueError('slots and send_quota must be non-negative.') first_byte = _MUX_OPCODE_NEW_CHANNEL_SLOT << 5 block = (chr(first_byte) + _encode_number(slots) + _encode_number(send_quota)) payload = _encode_channel_id(_CONTROL_CHANNEL_ID) + block return create_binary_frame(payload, mask=outer_frame_mask) def _create_fallback_new_channel_slot(outer_frame_mask=False): first_byte = (_MUX_OPCODE_NEW_CHANNEL_SLOT << 5) | 1 # Set the F flag block = 
(chr(first_byte) + _encode_number(0) + _encode_number(0)) payload = _encode_channel_id(_CONTROL_CHANNEL_ID) + block return create_binary_frame(payload, mask=outer_frame_mask) def _parse_request_text(request_text): request_line, header_lines = request_text.split('\r\n', 1) words = request_line.split(' ') if len(words) != 3: raise ValueError('Bad Request-Line syntax %r' % request_line) [command, path, version] = words if version != 'HTTP/1.1': raise ValueError('Bad request version %r' % version) # email.parser.Parser() parses RFC 2822 (RFC 822) style headers. # RFC 6455 refers RFC 2616 for handshake parsing, and RFC 2616 refers # RFC 822. headers = email.parser.Parser().parsestr(header_lines) return command, path, version, headers class _ControlBlock(object): """A structure that holds parsing result of multiplexing control block. Control block specific attributes will be added by _MuxFramePayloadParser. (e.g. encoded_handshake will be added for AddChannelRequest and AddChannelResponse) """ def __init__(self, opcode): self.opcode = opcode class _MuxFramePayloadParser(object): """A class that parses multiplexed frame payload.""" def __init__(self, payload): self._data = payload self._read_position = 0 self._logger = util.get_class_logger(self) def read_channel_id(self): """Reads channel id. Raises: ValueError: when the payload doesn't contain valid channel id. 
""" remaining_length = len(self._data) - self._read_position pos = self._read_position if remaining_length == 0: raise ValueError('Invalid channel id format') channel_id = ord(self._data[pos]) channel_id_length = 1 if channel_id & 0xe0 == 0xe0: if remaining_length < 4: raise ValueError('Invalid channel id format') channel_id = struct.unpack('!L', self._data[pos:pos+4])[0] & 0x1fffffff channel_id_length = 4 elif channel_id & 0xc0 == 0xc0: if remaining_length < 3: raise ValueError('Invalid channel id format') channel_id = (((channel_id & 0x1f) << 16) + struct.unpack('!H', self._data[pos+1:pos+3])[0]) channel_id_length = 3 elif channel_id & 0x80 == 0x80: if remaining_length < 2: raise ValueError('Invalid channel id format') channel_id = struct.unpack('!H', self._data[pos:pos+2])[0] & 0x3fff channel_id_length = 2 self._read_position += channel_id_length return channel_id def read_inner_frame(self): """Reads an inner frame. Raises: PhysicalConnectionError: when the inner frame is invalid. """ if len(self._data) == self._read_position: raise PhysicalConnectionError( _DROP_CODE_ENCAPSULATED_FRAME_IS_TRUNCATED) bits = ord(self._data[self._read_position]) self._read_position += 1 fin = (bits & 0x80) == 0x80 rsv1 = (bits & 0x40) == 0x40 rsv2 = (bits & 0x20) == 0x20 rsv3 = (bits & 0x10) == 0x10 opcode = bits & 0xf payload = self.remaining_data() # Consume rest of the message which is payload data of the original # frame. 
self._read_position = len(self._data) return fin, rsv1, rsv2, rsv3, opcode, payload def _read_number(self): if self._read_position + 1 > len(self._data): raise PhysicalConnectionError( _DROP_CODE_INVALID_MUX_CONTROL_BLOCK, 'Cannot read the first byte of number field') number = ord(self._data[self._read_position]) if number & 0x80 == 0x80: raise PhysicalConnectionError( _DROP_CODE_INVALID_MUX_CONTROL_BLOCK, 'The most significant bit of the first byte of number should ' 'be unset') self._read_position += 1 pos = self._read_position if number == 127: if pos + 8 > len(self._data): raise PhysicalConnectionError( _DROP_CODE_INVALID_MUX_CONTROL_BLOCK, 'Invalid number field') self._read_position += 8 number = struct.unpack('!Q', self._data[pos:pos+8])[0] if number > 0x7FFFFFFFFFFFFFFF: raise PhysicalConnectionError( _DROP_CODE_INVALID_MUX_CONTROL_BLOCK, 'Encoded number >= 2^63') if number <= 0xFFFF: raise PhysicalConnectionError( _DROP_CODE_INVALID_MUX_CONTROL_BLOCK, '%d should not be encoded by 9 bytes encoding' % number) return number if number == 126: if pos + 2 > len(self._data): raise PhysicalConnectionError( _DROP_CODE_INVALID_MUX_CONTROL_BLOCK, 'Invalid number field') self._read_position += 2 number = struct.unpack('!H', self._data[pos:pos+2])[0] if number <= 125: raise PhysicalConnectionError( _DROP_CODE_INVALID_MUX_CONTROL_BLOCK, '%d should not be encoded by 3 bytes encoding' % number) return number def _read_size_and_contents(self): """Reads data that consists of followings: - the size of the contents encoded the same way as payload length of the WebSocket Protocol with 1 bit padding at the head. - the contents. 
""" size = self._read_number() pos = self._read_position if pos + size > len(self._data): raise PhysicalConnectionError( _DROP_CODE_INVALID_MUX_CONTROL_BLOCK, 'Cannot read %d bytes data' % size) self._read_position += size return self._data[pos:pos+size] def _read_add_channel_request(self, first_byte, control_block): reserved = (first_byte >> 2) & 0x7 if reserved != 0: raise PhysicalConnectionError( _DROP_CODE_INVALID_MUX_CONTROL_BLOCK, 'Reserved bits must be unset') # Invalid encoding will be handled by MuxHandler. encoding = first_byte & 0x3 try: control_block.channel_id = self.read_channel_id() except ValueError, e: raise PhysicalConnectionError(_DROP_CODE_INVALID_MUX_CONTROL_BLOCK) control_block.encoding = encoding encoded_handshake = self._read_size_and_contents() control_block.encoded_handshake = encoded_handshake return control_block def _read_add_channel_response(self, first_byte, control_block): reserved = (first_byte >> 2) & 0x3 if reserved != 0: raise PhysicalConnectionError( _DROP_CODE_INVALID_MUX_CONTROL_BLOCK, 'Reserved bits must be unset') control_block.accepted = (first_byte >> 4) & 1 control_block.encoding = first_byte & 0x3 try: control_block.channel_id = self.read_channel_id() except ValueError, e: raise PhysicalConnectionError(_DROP_CODE_INVALID_MUX_CONTROL_BLOCK) control_block.encoded_handshake = self._read_size_and_contents() return control_block def _read_flow_control(self, first_byte, control_block): reserved = first_byte & 0x1f if reserved != 0: raise PhysicalConnectionError( _DROP_CODE_INVALID_MUX_CONTROL_BLOCK, 'Reserved bits must be unset') try: control_block.channel_id = self.read_channel_id() except ValueError, e: raise PhysicalConnectionError(_DROP_CODE_INVALID_MUX_CONTROL_BLOCK) control_block.send_quota = self._read_number() return control_block def _read_drop_channel(self, first_byte, control_block): reserved = first_byte & 0x1f if reserved != 0: raise PhysicalConnectionError( _DROP_CODE_INVALID_MUX_CONTROL_BLOCK, 'Reserved bits 
must be unset') try: control_block.channel_id = self.read_channel_id() except ValueError, e: raise PhysicalConnectionError(_DROP_CODE_INVALID_MUX_CONTROL_BLOCK) reason = self._read_size_and_contents() if len(reason) == 0: control_block.drop_code = None control_block.drop_message = '' elif len(reason) >= 2: control_block.drop_code = struct.unpack('!H', reason[:2])[0] control_block.drop_message = reason[2:] else: raise PhysicalConnectionError( _DROP_CODE_INVALID_MUX_CONTROL_BLOCK, 'Received DropChannel that conains only 1-byte reason') return control_block def _read_new_channel_slot(self, first_byte, control_block): reserved = first_byte & 0x1e if reserved != 0: raise PhysicalConnectionError( _DROP_CODE_INVALID_MUX_CONTROL_BLOCK, 'Reserved bits must be unset') control_block.fallback = first_byte & 1 control_block.slots = self._read_number() control_block.send_quota = self._read_number() return control_block def read_control_blocks(self): """Reads control block(s). Raises: PhysicalConnectionError: when the payload contains invalid control block(s). StopIteration: when no control blocks left. 
""" while self._read_position < len(self._data): first_byte = ord(self._data[self._read_position]) self._read_position += 1 opcode = (first_byte >> 5) & 0x7 control_block = _ControlBlock(opcode=opcode) if opcode == _MUX_OPCODE_ADD_CHANNEL_REQUEST: yield self._read_add_channel_request(first_byte, control_block) elif opcode == _MUX_OPCODE_ADD_CHANNEL_RESPONSE: yield self._read_add_channel_response( first_byte, control_block) elif opcode == _MUX_OPCODE_FLOW_CONTROL: yield self._read_flow_control(first_byte, control_block) elif opcode == _MUX_OPCODE_DROP_CHANNEL: yield self._read_drop_channel(first_byte, control_block) elif opcode == _MUX_OPCODE_NEW_CHANNEL_SLOT: yield self._read_new_channel_slot(first_byte, control_block) else: raise PhysicalConnectionError( _DROP_CODE_UNKNOWN_MUX_OPCODE, 'Invalid opcode %d' % opcode) assert self._read_position == len(self._data) raise StopIteration def remaining_data(self): """Returns remaining data.""" return self._data[self._read_position:] class _LogicalRequest(object): """Mimics mod_python request.""" def __init__(self, channel_id, command, path, protocol, headers, connection): """Constructs an instance. Args: channel_id: the channel id of the logical channel. command: HTTP request command. path: HTTP request path. headers: HTTP headers. connection: _LogicalConnection instance. """ self.channel_id = channel_id self.method = command self.uri = path self.protocol = protocol self.headers_in = headers self.connection = connection self.server_terminated = False self.client_terminated = False def is_https(self): """Mimics request.is_https(). Returns False because this method is used only by old protocols (hixie and hybi00). """ return False class _LogicalConnection(object): """Mimics mod_python mp_conn.""" # For details, see the comment of set_read_state(). STATE_ACTIVE = 1 STATE_GRACEFULLY_CLOSED = 2 STATE_TERMINATED = 3 def __init__(self, mux_handler, channel_id): """Constructs an instance. Args: mux_handler: _MuxHandler instance. 
channel_id: channel id of this connection. """ self._mux_handler = mux_handler self._channel_id = channel_id self._incoming_data = '' self._write_condition = threading.Condition() self._waiting_write_completion = False self._read_condition = threading.Condition() self._read_state = self.STATE_ACTIVE def get_local_addr(self): """Getter to mimic mp_conn.local_addr.""" return self._mux_handler.physical_connection.get_local_addr() local_addr = property(get_local_addr) def get_remote_addr(self): """Getter to mimic mp_conn.remote_addr.""" return self._mux_handler.physical_connection.get_remote_addr() remote_addr = property(get_remote_addr) def get_memorized_lines(self): """Gets memorized lines. Not supported.""" raise MuxUnexpectedException('_LogicalConnection does not support ' 'get_memorized_lines') def write(self, data): """Writes data. mux_handler sends data asynchronously. The caller will be suspended until write done. Args: data: data to be written. Raises: MuxUnexpectedException: when called before finishing the previous write. """ try: self._write_condition.acquire() if self._waiting_write_completion: raise MuxUnexpectedException( 'Logical connection %d is already waiting the completion ' 'of write' % self._channel_id) self._waiting_write_completion = True self._mux_handler.send_data(self._channel_id, data) self._write_condition.wait() finally: self._write_condition.release() def write_control_data(self, data): """Writes data via the control channel. Don't wait finishing write because this method can be called by mux dispatcher. Args: data: data to be written. 
""" self._mux_handler.send_control_data(data) def notify_write_done(self): """Called when sending data is completed.""" try: self._write_condition.acquire() if not self._waiting_write_completion: raise MuxUnexpectedException( 'Invalid call of notify_write_done for logical connection' ' %d' % self._channel_id) self._waiting_write_completion = False self._write_condition.notify() finally: self._write_condition.release() def append_frame_data(self, frame_data): """Appends incoming frame data. Called when mux_handler dispatches frame data to the corresponding application. Args: frame_data: incoming frame data. """ self._read_condition.acquire() self._incoming_data += frame_data self._read_condition.notify() self._read_condition.release() def read(self, length): """Reads data. Blocks until enough data has arrived via physical connection. Args: length: length of data to be read. Raises: LogicalConnectionClosedException: when closing handshake for this logical channel has been received. ConnectionTerminatedException: when the physical connection has closed, or an error is caused on the reader thread. """ self._read_condition.acquire() while (self._read_state == self.STATE_ACTIVE and len(self._incoming_data) < length): self._read_condition.wait() try: if self._read_state == self.STATE_GRACEFULLY_CLOSED: raise LogicalConnectionClosedException( 'Logical channel %d has closed.' % self._channel_id) elif self._read_state == self.STATE_TERMINATED: raise ConnectionTerminatedException( 'Receiving %d byte failed. Logical channel (%d) closed' % (length, self._channel_id)) value = self._incoming_data[:length] self._incoming_data = self._incoming_data[length:] finally: self._read_condition.release() return value def set_read_state(self, new_state): """Sets the state of this connection. Called when an event for this connection has occurred. Args: new_state: state to be set. 
new_state must be one of followings: - STATE_GRACEFULLY_CLOSED: when closing handshake for this connection has been received. - STATE_TERMINATED: when the physical connection has closed or DropChannel of this connection has received. """ self._read_condition.acquire() self._read_state = new_state self._read_condition.notify() self._read_condition.release() class _LogicalStream(Stream): """Mimics the Stream class. This class interprets multiplexed WebSocket frames. """ def __init__(self, request, send_quota, receive_quota): """Constructs an instance. Args: request: _LogicalRequest instance. send_quota: Initial send quota. receive_quota: Initial receive quota. """ # TODO(bashi): Support frame filters. stream_options = StreamOptions() # Physical stream is responsible for masking. stream_options.unmask_receive = False # Control frames can be fragmented on logical channel. stream_options.allow_fragmented_control_frame = True Stream.__init__(self, request, stream_options) self._send_quota = send_quota self._send_quota_condition = threading.Condition() self._receive_quota = receive_quota self._write_inner_frame_semaphore = threading.Semaphore() def _create_inner_frame(self, opcode, payload, end=True): # TODO(bashi): Support extensions that use reserved bits. first_byte = (end << 7) | opcode return (_encode_channel_id(self._request.channel_id) + chr(first_byte) + payload) def _write_inner_frame(self, opcode, payload, end=True): payload_length = len(payload) write_position = 0 try: # An inner frame will be fragmented if there is no enough send # quota. This semaphore ensures that fragmented inner frames are # sent in order on the logical channel. # Note that frames that come from other logical channels or # multiplexing control blocks can be inserted between fragmented # inner frames on the physical channel. 
self._write_inner_frame_semaphore.acquire() while write_position < payload_length: try: self._send_quota_condition.acquire() while self._send_quota == 0: self._logger.debug( 'No quota. Waiting FlowControl message for %d.' % self._request.channel_id) self._send_quota_condition.wait() remaining = payload_length - write_position write_length = min(self._send_quota, remaining) inner_frame_end = ( end and (write_position + write_length == payload_length)) inner_frame = self._create_inner_frame( opcode, payload[write_position:write_position+write_length], inner_frame_end) frame_data = self._writer.build( inner_frame, end=True, binary=True) self._send_quota -= write_length self._logger.debug('Consumed quota=%d, remaining=%d' % (write_length, self._send_quota)) finally: self._send_quota_condition.release() # Writing data will block the worker so we need to release # _send_quota_condition before writing. self._logger.debug('Sending inner frame: %r' % frame_data) self._request.connection.write(frame_data) write_position += write_length opcode = common.OPCODE_CONTINUATION except ValueError, e: raise BadOperationException(e) finally: self._write_inner_frame_semaphore.release() def replenish_send_quota(self, send_quota): """Replenish send quota.""" self._send_quota_condition.acquire() self._send_quota += send_quota self._logger.debug('Replenished send quota for channel id %d: %d' % (self._request.channel_id, self._send_quota)) self._send_quota_condition.notify() self._send_quota_condition.release() def consume_receive_quota(self, amount): """Consumes receive quota. 
Returns False on failure.""" if self._receive_quota < amount: self._logger.debug('Violate quota on channel id %d: %d < %d' % (self._request.channel_id, self._receive_quota, amount)) return False self._receive_quota -= amount return True def send_message(self, message, end=True, binary=False): """Override Stream.send_message.""" if self._request.server_terminated: raise BadOperationException( 'Requested send_message after sending out a closing handshake') if binary and isinstance(message, unicode): raise BadOperationException( 'Message for binary frame must be instance of str') if binary: opcode = common.OPCODE_BINARY else: opcode = common.OPCODE_TEXT message = message.encode('utf-8') self._write_inner_frame(opcode, message, end) def _receive_frame(self): """Overrides Stream._receive_frame. In addition to call Stream._receive_frame, this method adds the amount of payload to receiving quota and sends FlowControl to the client. We need to do it here because Stream.receive_message() handles control frames internally. """ opcode, payload, fin, rsv1, rsv2, rsv3 = Stream._receive_frame(self) amount = len(payload) self._receive_quota += amount frame_data = _create_flow_control(self._request.channel_id, amount) self._logger.debug('Sending flow control for %d, replenished=%d' % (self._request.channel_id, amount)) self._request.connection.write_control_data(frame_data) return opcode, payload, fin, rsv1, rsv2, rsv3 def receive_message(self): """Overrides Stream.receive_message.""" # Just call Stream.receive_message(), but catch # LogicalConnectionClosedException, which is raised when the logical # connection has closed gracefully. 
try: return Stream.receive_message(self) except LogicalConnectionClosedException, e: self._logger.debug('%s', e) return None def _send_closing_handshake(self, code, reason): """Overrides Stream._send_closing_handshake.""" body = create_closing_handshake_body(code, reason) self._logger.debug('Sending closing handshake for %d: (%r, %r)' % (self._request.channel_id, code, reason)) self._write_inner_frame(common.OPCODE_CLOSE, body, end=True) self._request.server_terminated = True def send_ping(self, body=''): """Overrides Stream.send_ping""" self._logger.debug('Sending ping on logical channel %d: %r' % (self._request.channel_id, body)) self._write_inner_frame(common.OPCODE_PING, body, end=True) self._ping_queue.append(body) def _send_pong(self, body): """Overrides Stream._send_pong""" self._logger.debug('Sending pong on logical channel %d: %r' % (self._request.channel_id, body)) self._write_inner_frame(common.OPCODE_PONG, body, end=True) def close_connection(self, code=common.STATUS_NORMAL_CLOSURE, reason=''): """Overrides Stream.close_connection.""" # TODO(bashi): Implement self._logger.debug('Closing logical connection %d' % self._request.channel_id) self._request.server_terminated = True def _drain_received_data(self): """Overrides Stream._drain_received_data. Nothing need to be done for logical channel. """ pass class _OutgoingData(object): """A structure that holds data to be sent via physical connection and origin of the data. """ def __init__(self, channel_id, data): self.channel_id = channel_id self.data = data class _PhysicalConnectionWriter(threading.Thread): """A thread that is responsible for writing data to physical connection. TODO(bashi): Make sure there is no thread-safety problem when the reader thread reads data from the same socket at a time. """ def __init__(self, mux_handler): """Constructs an instance. Args: mux_handler: _MuxHandler instance. 
""" threading.Thread.__init__(self) self._logger = util.get_class_logger(self) self._mux_handler = mux_handler self.setDaemon(True) self._stop_requested = False self._deque = collections.deque() self._deque_condition = threading.Condition() def put_outgoing_data(self, data): """Puts outgoing data. Args: data: _OutgoingData instance. Raises: BadOperationException: when the thread has been requested to terminate. """ try: self._deque_condition.acquire() if self._stop_requested: raise BadOperationException('Cannot write data anymore') self._deque.append(data) self._deque_condition.notify() finally: self._deque_condition.release() def _write_data(self, outgoing_data): try: self._mux_handler.physical_connection.write(outgoing_data.data) except Exception, e: util.prepend_message_to_exception( 'Failed to send message to %r: ' % (self._mux_handler.physical_connection.remote_addr,), e) raise # TODO(bashi): It would be better to block the thread that sends # control data as well. if outgoing_data.channel_id != _CONTROL_CHANNEL_ID: self._mux_handler.notify_write_done(outgoing_data.channel_id) def run(self): self._deque_condition.acquire() while not self._stop_requested: if len(self._deque) == 0: self._deque_condition.wait() continue outgoing_data = self._deque.popleft() self._deque_condition.release() self._write_data(outgoing_data) self._deque_condition.acquire() # Flush deque try: while len(self._deque) > 0: outgoing_data = self._deque.popleft() self._write_data(outgoing_data) finally: self._deque_condition.release() def stop(self): """Stops the writer thread.""" self._deque_condition.acquire() self._stop_requested = True self._deque_condition.notify() self._deque_condition.release() class _PhysicalConnectionReader(threading.Thread): """A thread that is responsible for reading data from physical connection. """ def __init__(self, mux_handler): """Constructs an instance. Args: mux_handler: _MuxHandler instance. 
""" threading.Thread.__init__(self) self._logger = util.get_class_logger(self) self._mux_handler = mux_handler self.setDaemon(True) def run(self): while True: try: physical_stream = self._mux_handler.physical_stream message = physical_stream.receive_message() if message is None: break # Below happens only when a data message is received. opcode = physical_stream.get_last_received_opcode() if opcode != common.OPCODE_BINARY: self._mux_handler.fail_physical_connection( _DROP_CODE_INVALID_ENCAPSULATING_MESSAGE, 'Received a text message on physical connection') break except ConnectionTerminatedException, e: self._logger.debug('%s', e) break try: self._mux_handler.dispatch_message(message) except PhysicalConnectionError, e: self._mux_handler.fail_physical_connection( e.drop_code, e.message) break except LogicalChannelError, e: self._mux_handler.fail_logical_channel( e.channel_id, e.drop_code, e.message) except Exception, e: self._logger.debug(traceback.format_exc()) break self._mux_handler.notify_reader_done() class _Worker(threading.Thread): """A thread that is responsible for running the corresponding application handler. """ def __init__(self, mux_handler, request): """Constructs an instance. Args: mux_handler: _MuxHandler instance. request: _LogicalRequest instance. """ threading.Thread.__init__(self) self._logger = util.get_class_logger(self) self._mux_handler = mux_handler self._request = request self.setDaemon(True) def run(self): self._logger.debug('Logical channel worker started. (id=%d)' % self._request.channel_id) try: # Non-critical exceptions will be handled by dispatcher. self._mux_handler.dispatcher.transfer_data(self._request) finally: self._mux_handler.notify_worker_done(self._request.channel_id) class _MuxHandshaker(hybi.Handshaker): """Opening handshake processor for multiplexing.""" _DUMMY_WEBSOCKET_KEY = 'dGhlIHNhbXBsZSBub25jZQ==' def __init__(self, request, dispatcher, send_quota, receive_quota): """Constructs an instance. 
Args: request: _LogicalRequest instance. dispatcher: Dispatcher instance (dispatch.Dispatcher). send_quota: Initial send quota. receive_quota: Initial receive quota. """ hybi.Handshaker.__init__(self, request, dispatcher) self._send_quota = send_quota self._receive_quota = receive_quota # Append headers which should not be included in handshake field of # AddChannelRequest. # TODO(bashi): Make sure whether we should raise exception when # these headers are included already. request.headers_in[common.UPGRADE_HEADER] = ( common.WEBSOCKET_UPGRADE_TYPE) request.headers_in[common.CONNECTION_HEADER] = ( common.UPGRADE_CONNECTION_TYPE) request.headers_in[common.SEC_WEBSOCKET_VERSION_HEADER] = ( str(common.VERSION_HYBI_LATEST)) request.headers_in[common.SEC_WEBSOCKET_KEY_HEADER] = ( self._DUMMY_WEBSOCKET_KEY) def _create_stream(self, stream_options): """Override hybi.Handshaker._create_stream.""" self._logger.debug('Creating logical stream for %d' % self._request.channel_id) return _LogicalStream(self._request, self._send_quota, self._receive_quota) def _create_handshake_response(self, accept): """Override hybi._create_handshake_response.""" response = [] response.append('HTTP/1.1 101 Switching Protocols\r\n') # Upgrade, Connection and Sec-WebSocket-Accept should be excluded. 
if self._request.ws_protocol is not None: response.append('%s: %s\r\n' % ( common.SEC_WEBSOCKET_PROTOCOL_HEADER, self._request.ws_protocol)) if (self._request.ws_extensions is not None and len(self._request.ws_extensions) != 0): response.append('%s: %s\r\n' % ( common.SEC_WEBSOCKET_EXTENSIONS_HEADER, common.format_extensions(self._request.ws_extensions))) response.append('\r\n') return ''.join(response) def _send_handshake(self, accept): """Override hybi.Handshaker._send_handshake.""" # Don't send handshake response for the default channel if self._request.channel_id == _DEFAULT_CHANNEL_ID: return handshake_response = self._create_handshake_response(accept) frame_data = _create_add_channel_response( self._request.channel_id, handshake_response) self._logger.debug('Sending handshake response for %d: %r' % (self._request.channel_id, frame_data)) self._request.connection.write_control_data(frame_data) class _LogicalChannelData(object): """A structure that holds information about logical channel. """ def __init__(self, request, worker): self.request = request self.worker = worker self.drop_code = _DROP_CODE_NORMAL_CLOSURE self.drop_message = '' class _HandshakeDeltaBase(object): """A class that holds information for delta-encoded handshake.""" def __init__(self, headers): self._headers = headers def create_headers(self, delta=None): """Creates request headers for an AddChannelRequest that has delta-encoded handshake. Args: delta: headers should be overridden. """ headers = copy.copy(self._headers) if delta: for key, value in delta.items(): # The spec requires that a header with an empty value is # removed from the delta base. if len(value) == 0 and headers.has_key(key): del headers[key] else: headers[key] = value # TODO(bashi): Support extensions headers['Sec-WebSocket-Extensions'] = '' return headers class _MuxHandler(object): """Multiplexing handler. When a handler starts, it launches three threads; the reader thread, the writer thread, and a worker thread. 
The reader thread reads data from the physical stream, i.e., the ws_stream object of the underlying websocket connection. The reader thread interprets multiplexed frames and dispatches them to logical channels. Methods of this class are mostly called by the reader thread. The writer thread sends multiplexed frames which are created by logical channels via the physical connection. The worker thread launched at the starting point handles the "Implicitly Opened Connection". If multiplexing handler receives an AddChannelRequest and accepts it, the handler will launch a new worker thread and dispatch the request to it. """ def __init__(self, request, dispatcher): """Constructs an instance. Args: request: mod_python request of the physical connection. dispatcher: Dispatcher instance (dispatch.Dispatcher). """ self.original_request = request self.dispatcher = dispatcher self.physical_connection = request.connection self.physical_stream = request.ws_stream self._logger = util.get_class_logger(self) self._logical_channels = {} self._logical_channels_condition = threading.Condition() # Holds client's initial quota self._channel_slots = collections.deque() self._handshake_base = None self._worker_done_notify_received = False self._reader = None self._writer = None def start(self): """Starts the handler. Raises: MuxUnexpectedException: when the handler already started, or when opening handshake of the default channel fails. """ if self._reader or self._writer: raise MuxUnexpectedException('MuxHandler already started') self._reader = _PhysicalConnectionReader(self) self._writer = _PhysicalConnectionWriter(self) self._reader.start() self._writer.start() # Create "Implicitly Opened Connection". 
logical_connection = _LogicalConnection(self, _DEFAULT_CHANNEL_ID) self._handshake_base = _HandshakeDeltaBase( self.original_request.headers_in) logical_request = _LogicalRequest( _DEFAULT_CHANNEL_ID, self.original_request.method, self.original_request.uri, self.original_request.protocol, self._handshake_base.create_headers(), logical_connection) # Client's send quota for the implicitly opened connection is zero, # but we will send FlowControl later so set the initial quota to # _INITIAL_QUOTA_FOR_CLIENT. self._channel_slots.append(_INITIAL_QUOTA_FOR_CLIENT) if not self._do_handshake_for_logical_request( logical_request, send_quota=self.original_request.mux_quota): raise MuxUnexpectedException( 'Failed handshake on the default channel id') self._add_logical_channel(logical_request) # Send FlowControl for the implicitly opened connection. frame_data = _create_flow_control(_DEFAULT_CHANNEL_ID, _INITIAL_QUOTA_FOR_CLIENT) logical_request.connection.write_control_data(frame_data) def add_channel_slots(self, slots, send_quota): """Adds channel slots. Args: slots: number of slots to be added. send_quota: initial send quota for slots. """ self._channel_slots.extend([send_quota] * slots) # Send NewChannelSlot to client. frame_data = _create_new_channel_slot(slots, send_quota) self.send_control_data(frame_data) def wait_until_done(self, timeout=None): """Waits until all workers are done. Returns False when timeout has occurred. Returns True on success. Args: timeout: timeout in sec. """ self._logical_channels_condition.acquire() try: while len(self._logical_channels) > 0: self._logger.debug('Waiting workers(%d)...' 
% len(self._logical_channels)) self._worker_done_notify_received = False self._logical_channels_condition.wait(timeout) if not self._worker_done_notify_received: self._logger.debug('Waiting worker(s) timed out') return False finally: self._logical_channels_condition.release() # Flush pending outgoing data self._writer.stop() self._writer.join() return True def notify_write_done(self, channel_id): """Called by the writer thread when a write operation has done. Args: channel_id: objective channel id. """ try: self._logical_channels_condition.acquire() if channel_id in self._logical_channels: channel_data = self._logical_channels[channel_id] channel_data.request.connection.notify_write_done() else: self._logger.debug('Seems that logical channel for %d has gone' % channel_id) finally: self._logical_channels_condition.release() def send_control_data(self, data): """Sends data via the control channel. Args: data: data to be sent. """ self._writer.put_outgoing_data(_OutgoingData( channel_id=_CONTROL_CHANNEL_ID, data=data)) def send_data(self, channel_id, data): """Sends data via given logical channel. This method is called by worker threads. Args: channel_id: objective channel id. data: data to be sent. """ self._writer.put_outgoing_data(_OutgoingData( channel_id=channel_id, data=data)) def _send_drop_channel(self, channel_id, code=None, message=''): frame_data = _create_drop_channel(channel_id, code, message) self._logger.debug( 'Sending drop channel for channel id %d' % channel_id) self.send_control_data(frame_data) def _send_error_add_channel_response(self, channel_id, status=None): if status is None: status = common.HTTP_STATUS_BAD_REQUEST if status in _HTTP_BAD_RESPONSE_MESSAGES: message = _HTTP_BAD_RESPONSE_MESSAGES[status] else: self._logger.debug('Response message for %d is not found' % status) message = '???' 
response = 'HTTP/1.1 %d %s\r\n\r\n' % (status, message) frame_data = _create_add_channel_response(channel_id, encoded_handshake=response, encoding=0, rejected=True) self.send_control_data(frame_data) def _create_logical_request(self, block): if block.channel_id == _CONTROL_CHANNEL_ID: # TODO(bashi): Raise PhysicalConnectionError with code 2006 # instead of MuxUnexpectedException. raise MuxUnexpectedException( 'Received the control channel id (0) as objective channel ' 'id for AddChannel') if block.encoding > _HANDSHAKE_ENCODING_DELTA: raise PhysicalConnectionError( _DROP_CODE_UNKNOWN_REQUEST_ENCODING) method, path, version, headers = _parse_request_text( block.encoded_handshake) if block.encoding == _HANDSHAKE_ENCODING_DELTA: headers = self._handshake_base.create_headers(headers) connection = _LogicalConnection(self, block.channel_id) request = _LogicalRequest(block.channel_id, method, path, version, headers, connection) return request def _do_handshake_for_logical_request(self, request, send_quota=0): try: receive_quota = self._channel_slots.popleft() except IndexError: raise LogicalChannelError( request.channel_id, _DROP_CODE_NEW_CHANNEL_SLOT_VIOLATION) handshaker = _MuxHandshaker(request, self.dispatcher, send_quota, receive_quota) try: handshaker.do_handshake() except handshake.VersionException, e: self._logger.info('%s', e) self._send_error_add_channel_response( request.channel_id, status=common.HTTP_STATUS_BAD_REQUEST) return False except handshake.HandshakeException, e: # TODO(bashi): Should we _Fail the Logical Channel_ with 3001 # instead? 
self._logger.info('%s', e) self._send_error_add_channel_response(request.channel_id, status=e.status) return False except handshake.AbortedByUserException, e: self._logger.info('%s', e) self._send_error_add_channel_response(request.channel_id) return False return True def _add_logical_channel(self, logical_request): try: self._logical_channels_condition.acquire() if logical_request.channel_id in self._logical_channels: self._logger.debug('Channel id %d already exists' % logical_request.channel_id) raise PhysicalConnectionError( _DROP_CODE_CHANNEL_ALREADY_EXISTS, 'Channel id %d already exists' % logical_request.channel_id) worker = _Worker(self, logical_request) channel_data = _LogicalChannelData(logical_request, worker) self._logical_channels[logical_request.channel_id] = channel_data worker.start() finally: self._logical_channels_condition.release() def _process_add_channel_request(self, block): try: logical_request = self._create_logical_request(block) except ValueError, e: self._logger.debug('Failed to create logical request: %r' % e) self._send_error_add_channel_response( block.channel_id, status=common.HTTP_STATUS_BAD_REQUEST) return if self._do_handshake_for_logical_request(logical_request): if block.encoding == _HANDSHAKE_ENCODING_IDENTITY: # Update handshake base. # TODO(bashi): Make sure this is the right place to update # handshake base. 
self._handshake_base = _HandshakeDeltaBase( logical_request.headers_in) self._add_logical_channel(logical_request) else: self._send_error_add_channel_response( block.channel_id, status=common.HTTP_STATUS_BAD_REQUEST) def _process_flow_control(self, block): try: self._logical_channels_condition.acquire() if not block.channel_id in self._logical_channels: return channel_data = self._logical_channels[block.channel_id] channel_data.request.ws_stream.replenish_send_quota( block.send_quota) finally: self._logical_channels_condition.release() def _process_drop_channel(self, block): self._logger.debug( 'DropChannel received for %d: code=%r, reason=%r' % (block.channel_id, block.drop_code, block.drop_message)) try: self._logical_channels_condition.acquire() if not block.channel_id in self._logical_channels: return channel_data = self._logical_channels[block.channel_id] channel_data.drop_code = _DROP_CODE_ACKNOWLEDGED # Close the logical channel channel_data.request.connection.set_read_state( _LogicalConnection.STATE_TERMINATED) finally: self._logical_channels_condition.release() def _process_control_blocks(self, parser): for control_block in parser.read_control_blocks(): opcode = control_block.opcode self._logger.debug('control block received, opcode: %d' % opcode) if opcode == _MUX_OPCODE_ADD_CHANNEL_REQUEST: self._process_add_channel_request(control_block) elif opcode == _MUX_OPCODE_ADD_CHANNEL_RESPONSE: raise PhysicalConnectionError( _DROP_CODE_INVALID_MUX_CONTROL_BLOCK, 'Received AddChannelResponse') elif opcode == _MUX_OPCODE_FLOW_CONTROL: self._process_flow_control(control_block) elif opcode == _MUX_OPCODE_DROP_CHANNEL: self._process_drop_channel(control_block) elif opcode == _MUX_OPCODE_NEW_CHANNEL_SLOT: raise PhysicalConnectionError( _DROP_CODE_INVALID_MUX_CONTROL_BLOCK, 'Received NewChannelSlot') else: raise MuxUnexpectedException( 'Unexpected opcode %r' % opcode) def _process_logical_frame(self, channel_id, parser): self._logger.debug('Received a frame. 
channel id=%d' % channel_id) try: self._logical_channels_condition.acquire() if not channel_id in self._logical_channels: # We must ignore the message for an inactive channel. return channel_data = self._logical_channels[channel_id] fin, rsv1, rsv2, rsv3, opcode, payload = parser.read_inner_frame() if not channel_data.request.ws_stream.consume_receive_quota( len(payload)): # The client violates quota. Close logical channel. raise LogicalChannelError( channel_id, _DROP_CODE_SEND_QUOTA_VIOLATION) header = create_header(opcode, len(payload), fin, rsv1, rsv2, rsv3, mask=False) frame_data = header + payload channel_data.request.connection.append_frame_data(frame_data) finally: self._logical_channels_condition.release() def dispatch_message(self, message): """Dispatches message. The reader thread calls this method. Args: message: a message that contains encapsulated frame. Raises: PhysicalConnectionError: if the message contains physical connection level errors. LogicalChannelError: if the message contains logical channel level errors. """ parser = _MuxFramePayloadParser(message) try: channel_id = parser.read_channel_id() except ValueError, e: raise PhysicalConnectionError(_DROP_CODE_CHANNEL_ID_TRUNCATED) if channel_id == _CONTROL_CHANNEL_ID: self._process_control_blocks(parser) else: self._process_logical_frame(channel_id, parser) def notify_worker_done(self, channel_id): """Called when a worker has finished. Args: channel_id: channel id corresponded with the worker. 
""" self._logger.debug('Worker for channel id %d terminated' % channel_id) try: self._logical_channels_condition.acquire() if not channel_id in self._logical_channels: raise MuxUnexpectedException( 'Channel id %d not found' % channel_id) channel_data = self._logical_channels.pop(channel_id) finally: self._worker_done_notify_received = True self._logical_channels_condition.notify() self._logical_channels_condition.release() if not channel_data.request.server_terminated: self._send_drop_channel( channel_id, code=channel_data.drop_code, message=channel_data.drop_message) def notify_reader_done(self): """This method is called by the reader thread when the reader has finished. """ # Terminate all logical connections self._logger.debug('termiating all logical connections...') self._logical_channels_condition.acquire() for channel_data in self._logical_channels.values(): try: channel_data.request.connection.set_read_state( _LogicalConnection.STATE_TERMINATED) except Exception: pass self._logical_channels_condition.release() def fail_physical_connection(self, code, message): """Fail the physical connection. Args: code: drop reason code. message: drop message. """ self._logger.debug('Failing the physical connection...') self._send_drop_channel(_CONTROL_CHANNEL_ID, code, message) self.physical_stream.close_connection( common.STATUS_INTERNAL_ENDPOINT_ERROR) def fail_logical_channel(self, channel_id, code, message): """Fail a logical channel. Args: channel_id: channel id. code: drop reason code. message: drop message. """ self._logger.debug('Failing logical channel %d...' % channel_id) try: self._logical_channels_condition.acquire() if channel_id in self._logical_channels: channel_data = self._logical_channels[channel_id] # Close the logical channel. notify_worker_done() will be # called later and it will send DropChannel. 
channel_data.drop_code = code channel_data.drop_message = message channel_data.request.connection.set_read_state( _LogicalConnection.STATE_TERMINATED) else: self._send_drop_channel(channel_id, code, message) finally: self._logical_channels_condition.release() def use_mux(request): return hasattr(request, 'mux') and request.mux def start(request, dispatcher): mux_handler = _MuxHandler(request, dispatcher) mux_handler.start() mux_handler.add_channel_slots(_INITIAL_NUMBER_OF_CHANNEL_SLOTS, _INITIAL_QUOTA_FOR_CLIENT) mux_handler.wait_until_done() # vi:sts=4 sw=4 et