diff options
Diffstat (limited to 'pyload/lib/mod_pywebsocket/mux.py')
-rw-r--r-- | pyload/lib/mod_pywebsocket/mux.py | 470 |
1 files changed, 365 insertions, 105 deletions
diff --git a/pyload/lib/mod_pywebsocket/mux.py b/pyload/lib/mod_pywebsocket/mux.py index f0bdd2461..7923fb211 100644 --- a/pyload/lib/mod_pywebsocket/mux.py +++ b/pyload/lib/mod_pywebsocket/mux.py @@ -50,6 +50,7 @@ from mod_pywebsocket import handshake from mod_pywebsocket import util from mod_pywebsocket._stream_base import BadOperationException from mod_pywebsocket._stream_base import ConnectionTerminatedException +from mod_pywebsocket._stream_base import InvalidFrameException from mod_pywebsocket._stream_hybi import Frame from mod_pywebsocket._stream_hybi import Stream from mod_pywebsocket._stream_hybi import StreamOptions @@ -94,10 +95,12 @@ _DROP_CODE_UNKNOWN_MUX_OPCODE = 2004 _DROP_CODE_INVALID_MUX_CONTROL_BLOCK = 2005 _DROP_CODE_CHANNEL_ALREADY_EXISTS = 2006 _DROP_CODE_NEW_CHANNEL_SLOT_VIOLATION = 2007 +_DROP_CODE_UNKNOWN_REQUEST_ENCODING = 2010 -_DROP_CODE_UNKNOWN_REQUEST_ENCODING = 3002 _DROP_CODE_SEND_QUOTA_VIOLATION = 3005 +_DROP_CODE_SEND_QUOTA_OVERFLOW = 3006 _DROP_CODE_ACKNOWLEDGED = 3008 +_DROP_CODE_BAD_FRAGMENTATION = 3009 class MuxUnexpectedException(Exception): @@ -158,8 +161,7 @@ def _encode_number(number): def _create_add_channel_response(channel_id, encoded_handshake, - encoding=0, rejected=False, - outer_frame_mask=False): + encoding=0, rejected=False): if encoding != 0 and encoding != 1: raise ValueError('Invalid encoding %d' % encoding) @@ -169,12 +171,10 @@ def _create_add_channel_response(channel_id, encoded_handshake, _encode_channel_id(channel_id) + _encode_number(len(encoded_handshake)) + encoded_handshake) - payload = _encode_channel_id(_CONTROL_CHANNEL_ID) + block - return create_binary_frame(payload, mask=outer_frame_mask) + return block -def _create_drop_channel(channel_id, code=None, message='', - outer_frame_mask=False): +def _create_drop_channel(channel_id, code=None, message=''): if len(message) > 0 and code is None: raise ValueError('Code must be specified if message is specified') @@ -187,36 +187,31 @@ def _create_drop_channel(channel_id, code=None, message='', reason_size = _encode_number(len(reason)) block += reason_size + reason - payload = _encode_channel_id(_CONTROL_CHANNEL_ID) + block - return create_binary_frame(payload, mask=outer_frame_mask) + return block -def _create_flow_control(channel_id, replenished_quota, - outer_frame_mask=False): +def _create_flow_control(channel_id, replenished_quota): first_byte = _MUX_OPCODE_FLOW_CONTROL << 5 block = (chr(first_byte) + _encode_channel_id(channel_id) + _encode_number(replenished_quota)) - payload = _encode_channel_id(_CONTROL_CHANNEL_ID) + block - return create_binary_frame(payload, mask=outer_frame_mask) + return block -def _create_new_channel_slot(slots, send_quota, outer_frame_mask=False): +def _create_new_channel_slot(slots, send_quota): if slots < 0 or send_quota < 0: raise ValueError('slots and send_quota must be non-negative.') first_byte = _MUX_OPCODE_NEW_CHANNEL_SLOT << 5 block = (chr(first_byte) + _encode_number(slots) + _encode_number(send_quota)) - payload = _encode_channel_id(_CONTROL_CHANNEL_ID) + block - return create_binary_frame(payload, mask=outer_frame_mask) + return block -def _create_fallback_new_channel_slot(outer_frame_mask=False): +def _create_fallback_new_channel_slot(): first_byte = (_MUX_OPCODE_NEW_CHANNEL_SLOT << 5) | 1 # Set the F flag block = (chr(first_byte) + _encode_number(0) + _encode_number(0)) - payload = _encode_channel_id(_CONTROL_CHANNEL_ID) + block - return create_binary_frame(payload, mask=outer_frame_mask) + return block def _parse_request_text(request_text): @@ -318,44 +313,34 @@ class _MuxFramePayloadParser(object): def _read_number(self): if self._read_position + 1 > len(self._data): - raise PhysicalConnectionError( - _DROP_CODE_INVALID_MUX_CONTROL_BLOCK, + raise ValueError( 'Cannot read the first byte of number field') number = ord(self._data[self._read_position]) if number & 0x80 == 0x80: - raise PhysicalConnectionError( - _DROP_CODE_INVALID_MUX_CONTROL_BLOCK, + raise ValueError( 'The most significant bit of the first byte of number should ' 'be unset') self._read_position += 1 pos = self._read_position if number == 127: if pos + 8 > len(self._data): - raise PhysicalConnectionError( - _DROP_CODE_INVALID_MUX_CONTROL_BLOCK, - 'Invalid number field') + raise ValueError('Invalid number field') self._read_position += 8 number = struct.unpack('!Q', self._data[pos:pos+8])[0] if number > 0x7FFFFFFFFFFFFFFF: - raise PhysicalConnectionError( - _DROP_CODE_INVALID_MUX_CONTROL_BLOCK, - 'Encoded number >= 2^63') + raise ValueError('Encoded number(%d) >= 2^63' % number) if number <= 0xFFFF: - raise PhysicalConnectionError( - _DROP_CODE_INVALID_MUX_CONTROL_BLOCK, + raise ValueError( '%d should not be encoded by 9 bytes encoding' % number) return number if number == 126: if pos + 2 > len(self._data): - raise PhysicalConnectionError( - _DROP_CODE_INVALID_MUX_CONTROL_BLOCK, - 'Invalid number field') + raise ValueError('Invalid number field') self._read_position += 2 number = struct.unpack('!H', self._data[pos:pos+2])[0] if number <= 125: - raise PhysicalConnectionError( - _DROP_CODE_INVALID_MUX_CONTROL_BLOCK, + raise ValueError( '%d should not be encoded by 3 bytes encoding' % number) return number @@ -366,7 +351,11 @@ class _MuxFramePayloadParser(object): - the contents. """ - size = self._read_number() + try: + size = self._read_number() + except ValueError, e: + raise PhysicalConnectionError(_DROP_CODE_INVALID_MUX_CONTROL_BLOCK, + str(e)) pos = self._read_position if pos + size > len(self._data): raise PhysicalConnectionError( @@ -419,9 +408,11 @@ class _MuxFramePayloadParser(object): try: control_block.channel_id = self.read_channel_id() + control_block.send_quota = self._read_number() except ValueError, e: - raise PhysicalConnectionError(_DROP_CODE_INVALID_MUX_CONTROL_BLOCK) - control_block.send_quota = self._read_number() + raise PhysicalConnectionError(_DROP_CODE_INVALID_MUX_CONTROL_BLOCK, + str(e)) + return control_block def _read_drop_channel(self, first_byte, control_block): @@ -455,8 +446,12 @@ class _MuxFramePayloadParser(object): _DROP_CODE_INVALID_MUX_CONTROL_BLOCK, 'Reserved bits must be unset') control_block.fallback = first_byte & 1 - control_block.slots = self._read_number() - control_block.send_quota = self._read_number() + try: + control_block.slots = self._read_number() + control_block.send_quota = self._read_number() + except ValueError, e: + raise PhysicalConnectionError(_DROP_CODE_INVALID_MUX_CONTROL_BLOCK, + str(e)) return control_block def read_control_blocks(self): @@ -549,8 +544,12 @@ class _LogicalConnection(object): self._mux_handler = mux_handler self._channel_id = channel_id self._incoming_data = '' + + # - Protects _waiting_write_completion + # - Signals the thread waiting for completion of write by mux handler self._write_condition = threading.Condition() self._waiting_write_completion = False + self._read_condition = threading.Condition() self._read_state = self.STATE_ACTIVE @@ -594,6 +593,7 @@ class _LogicalConnection(object): self._waiting_write_completion = True self._mux_handler.send_data(self._channel_id, data) self._write_condition.wait() + # TODO(tyoshino): Raise an exception if woke up by on_writer_done. finally: self._write_condition.release() @@ -607,20 +607,31 @@ class _LogicalConnection(object): self._mux_handler.send_control_data(data) - def notify_write_done(self): + def on_write_data_done(self): """Called when sending data is completed.""" try: self._write_condition.acquire() if not self._waiting_write_completion: raise MuxUnexpectedException( - 'Invalid call of notify_write_done for logical connection' - ' %d' % self._channel_id) + 'Invalid call of on_write_data_done for logical ' + 'connection %d' % self._channel_id) + self._waiting_write_completion = False + self._write_condition.notify() + finally: + self._write_condition.release() + + def on_writer_done(self): + """Called by the mux handler when the writer thread has finished.""" + + try: + self._write_condition.acquire() self._waiting_write_completion = False self._write_condition.notify() finally: self._write_condition.release() + def append_frame_data(self, frame_data): """Appends incoming frame data. Called when mux_handler dispatches frame data to the corresponding application. @@ -686,37 +697,162 @@ class _LogicalConnection(object): self._read_condition.release() +class _InnerMessage(object): + """Holds the result of _InnerMessageBuilder.build(). + """ + + def __init__(self, opcode, payload): + self.opcode = opcode + self.payload = payload + + +class _InnerMessageBuilder(object): + """A class that holds the context of inner message fragmentation and + builds a message from fragmented inner frame(s). + """ + + def __init__(self): + self._control_opcode = None + self._pending_control_fragments = [] + self._message_opcode = None + self._pending_message_fragments = [] + self._frame_handler = self._handle_first + + def _handle_first(self, frame): + if frame.opcode == common.OPCODE_CONTINUATION: + raise InvalidFrameException('Sending invalid continuation opcode') + + if common.is_control_opcode(frame.opcode): + return self._process_first_fragmented_control(frame) + else: + return self._process_first_fragmented_message(frame) + + def _process_first_fragmented_control(self, frame): + self._control_opcode = frame.opcode + self._pending_control_fragments.append(frame.payload) + if not frame.fin: + self._frame_handler = self._handle_fragmented_control + return None + return self._reassemble_fragmented_control() + + def _process_first_fragmented_message(self, frame): + self._message_opcode = frame.opcode + self._pending_message_fragments.append(frame.payload) + if not frame.fin: + self._frame_handler = self._handle_fragmented_message + return None + return self._reassemble_fragmented_message() + + def _handle_fragmented_control(self, frame): + if frame.opcode != common.OPCODE_CONTINUATION: + raise InvalidFrameException( + 'Sending invalid opcode %d while sending fragmented control ' + 'message' % frame.opcode) + self._pending_control_fragments.append(frame.payload) + if not frame.fin: + return None + return self._reassemble_fragmented_control() + + def _reassemble_fragmented_control(self): + opcode = self._control_opcode + payload = ''.join(self._pending_control_fragments) + self._control_opcode = None + self._pending_control_fragments = [] + if self._message_opcode is not None: + self._frame_handler = self._handle_fragmented_message + else: + self._frame_handler = self._handle_first + return _InnerMessage(opcode, payload) + + def _handle_fragmented_message(self, frame): + # Sender can interleave a control message while sending fragmented + # messages. + if common.is_control_opcode(frame.opcode): + if self._control_opcode is not None: + raise MuxUnexpectedException( + 'Should not reach here(Bug in builder)') + return self._process_first_fragmented_control(frame) + + if frame.opcode != common.OPCODE_CONTINUATION: + raise InvalidFrameException( + 'Sending invalid opcode %d while sending fragmented message' % + frame.opcode) + self._pending_message_fragments.append(frame.payload) + if not frame.fin: + return None + return self._reassemble_fragmented_message() + + def _reassemble_fragmented_message(self): + opcode = self._message_opcode + payload = ''.join(self._pending_message_fragments) + self._message_opcode = None + self._pending_message_fragments = [] + self._frame_handler = self._handle_first + return _InnerMessage(opcode, payload) + + def build(self, frame): + """Build an inner message. Returns an _InnerMessage instance when + the given frame is the last fragmented frame. Returns None otherwise. + + Args: + frame: an inner frame. + Raises: + InvalidFrameException: when received invalid opcode. (e.g. + receiving non continuation data opcode but the fin flag of + the previous inner frame was not set.) + """ + + return self._frame_handler(frame) + + class _LogicalStream(Stream): """Mimics the Stream class. This class interprets multiplexed WebSocket frames. """ - def __init__(self, request, send_quota, receive_quota): + def __init__(self, request, stream_options, send_quota, receive_quota): """Constructs an instance. Args: request: _LogicalRequest instance. + stream_options: StreamOptions instance. send_quota: Initial send quota. receive_quota: Initial receive quota. """ - # TODO(bashi): Support frame filters. - stream_options = StreamOptions() # Physical stream is responsible for masking. stream_options.unmask_receive = False - # Control frames can be fragmented on logical channel. - stream_options.allow_fragmented_control_frame = True Stream.__init__(self, request, stream_options) + + self._send_closed = False self._send_quota = send_quota - self._send_quota_condition = threading.Condition() + # - Protects _send_closed and _send_quota + # - Signals the thread waiting for send quota replenished + self._send_condition = threading.Condition() + + # The opcode of the first frame in messages. + self._message_opcode = common.OPCODE_TEXT + # True when the last message was fragmented. + self._last_message_was_fragmented = False + self._receive_quota = receive_quota self._write_inner_frame_semaphore = threading.Semaphore() + self._inner_message_builder = _InnerMessageBuilder() + def _create_inner_frame(self, opcode, payload, end=True): - # TODO(bashi): Support extensions that use reserved bits. - first_byte = (end << 7) | opcode - return (_encode_channel_id(self._request.channel_id) + - chr(first_byte) + payload) + frame = Frame(fin=end, opcode=opcode, payload=payload) + for frame_filter in self._options.outgoing_frame_filters: + frame_filter.filter(frame) + + if len(payload) != len(frame.payload): + raise MuxUnexpectedException( + 'Mux extension must not be used after extensions which change ' + ' frame boundary') + + first_byte = ((frame.fin << 7) | (frame.rsv1 << 6) | + (frame.rsv2 << 5) | (frame.rsv3 << 4) | frame.opcode) + return chr(first_byte) + frame.payload def _write_inner_frame(self, opcode, payload, end=True): payload_length = len(payload) @@ -730,14 +866,36 @@ class _LogicalStream(Stream): # multiplexing control blocks can be inserted between fragmented # inner frames on the physical channel. self._write_inner_frame_semaphore.acquire() + + # Consume an octet quota when this is the first fragmented frame. + if opcode != common.OPCODE_CONTINUATION: + try: + self._send_condition.acquire() + while (not self._send_closed) and self._send_quota == 0: + self._send_condition.wait() + + if self._send_closed: + raise BadOperationException( + 'Logical connection %d is closed' % + self._request.channel_id) + + self._send_quota -= 1 + finally: + self._send_condition.release() + while write_position < payload_length: try: - self._send_quota_condition.acquire() - while self._send_quota == 0: + self._send_condition.acquire() + while (not self._send_closed) and self._send_quota == 0: self._logger.debug( 'No quota. Waiting FlowControl message for %d.' % self._request.channel_id) - self._send_quota_condition.wait() + self._send_condition.wait() + + if self._send_closed: + raise BadOperationException( + 'Logical connection %d is closed' % + self.request._channel_id) remaining = payload_length - write_position write_length = min(self._send_quota, remaining) @@ -749,18 +907,16 @@ class _LogicalStream(Stream): opcode, payload[write_position:write_position+write_length], inner_frame_end) - frame_data = self._writer.build( - inner_frame, end=True, binary=True) self._send_quota -= write_length self._logger.debug('Consumed quota=%d, remaining=%d' % (write_length, self._send_quota)) finally: - self._send_quota_condition.release() + self._send_condition.release() # Writing data will block the worker so we need to release - # _send_quota_condition before writing. - self._logger.debug('Sending inner frame: %r' % frame_data) - self._request.connection.write(frame_data) + # _send_condition before writing. + self._logger.debug('Sending inner frame: %r' % inner_frame) + self._request.connection.write(inner_frame) write_position += write_length opcode = common.OPCODE_CONTINUATION @@ -773,12 +929,18 @@ class _LogicalStream(Stream): def replenish_send_quota(self, send_quota): """Replenish send quota.""" - self._send_quota_condition.acquire() - self._send_quota += send_quota - self._logger.debug('Replenished send quota for channel id %d: %d' % - (self._request.channel_id, self._send_quota)) - self._send_quota_condition.notify() - self._send_quota_condition.release() + try: + self._send_condition.acquire() + if self._send_quota + send_quota > 0x7FFFFFFFFFFFFFFF: + self._send_quota = 0 + raise LogicalChannelError( + self._request.channel_id, _DROP_CODE_SEND_QUOTA_OVERFLOW) + self._send_quota += send_quota + self._logger.debug('Replenished send quota for channel id %d: %d' % + (self._request.channel_id, self._send_quota)) + finally: + self._send_condition.notify() + self._send_condition.release() def consume_receive_quota(self, amount): """Consumes receive quota. Returns False on failure.""" @@ -808,7 +970,19 @@ class _LogicalStream(Stream): opcode = common.OPCODE_TEXT message = message.encode('utf-8') + for message_filter in self._options.outgoing_message_filters: + message = message_filter.filter(message, end, binary) + + if self._last_message_was_fragmented: + if opcode != self._message_opcode: + raise BadOperationException('Message types are different in ' + 'frames for the same message') + opcode = common.OPCODE_CONTINUATION + else: + self._message_opcode = opcode + self._write_inner_frame(opcode, message, end) + self._last_message_was_fragmented = not end def _receive_frame(self): """Overrides Stream._receive_frame. @@ -821,6 +995,9 @@ class _LogicalStream(Stream): opcode, payload, fin, rsv1, rsv2, rsv3 = Stream._receive_frame(self) amount = len(payload) + # Replenish extra one octet when receiving the first fragmented frame. + if opcode != common.OPCODE_CONTINUATION: + amount += 1 self._receive_quota += amount frame_data = _create_flow_control(self._request.channel_id, amount) @@ -829,6 +1006,21 @@ class _LogicalStream(Stream): self._request.connection.write_control_data(frame_data) return opcode, payload, fin, rsv1, rsv2, rsv3 + def _get_message_from_frame(self, frame): + """Overrides Stream._get_message_from_frame. + """ + + try: + inner_message = self._inner_message_builder.build(frame) + except InvalidFrameException: + raise LogicalChannelError( + self._request.channel_id, _DROP_CODE_BAD_FRAGMENTATION) + + if inner_message is None: + return None + self._original_opcode = inner_message.opcode + return inner_message.payload + def receive_message(self): """Overrides Stream.receive_message.""" @@ -882,6 +1074,14 @@ class _LogicalStream(Stream): pass + def stop_sending(self): + """Stops accepting new send operation (_write_inner_frame).""" + + self._send_condition.acquire() + self._send_closed = True + self._send_condition.notify() + self._send_condition.release() + class _OutgoingData(object): """A structure that holds data to be sent via physical connection and @@ -911,8 +1111,17 @@ class _PhysicalConnectionWriter(threading.Thread): self._logger = util.get_class_logger(self) self._mux_handler = mux_handler self.setDaemon(True) + + # When set, make this thread stop accepting new data, flush pending + # data and exit. self._stop_requested = False + # The close code of the physical connection. + self._close_code = common.STATUS_NORMAL_CLOSURE + # Deque for passing write data. It's protected by _deque_condition + # until _stop_requested is set. self._deque = collections.deque() + # - Protects _deque, _stop_requested and _close_code + # - Signals threads waiting for them to be available self._deque_condition = threading.Condition() def put_outgoing_data(self, data): @@ -937,8 +1146,11 @@ class _PhysicalConnectionWriter(threading.Thread): self._deque_condition.release() def _write_data(self, outgoing_data): + message = (_encode_channel_id(outgoing_data.channel_id) + + outgoing_data.data) try: - self._mux_handler.physical_connection.write(outgoing_data.data) + self._mux_handler.physical_stream.send_message( + message=message, end=True, binary=True) except Exception, e: util.prepend_message_to_exception( 'Failed to send message to %r: ' % @@ -948,33 +1160,51 @@ class _PhysicalConnectionWriter(threading.Thread): # TODO(bashi): It would be better to block the thread that sends # control data as well. if outgoing_data.channel_id != _CONTROL_CHANNEL_ID: - self._mux_handler.notify_write_done(outgoing_data.channel_id) + self._mux_handler.notify_write_data_done(outgoing_data.channel_id) def run(self): - self._deque_condition.acquire() - while not self._stop_requested: - if len(self._deque) == 0: - self._deque_condition.wait() - continue - - outgoing_data = self._deque.popleft() - self._deque_condition.release() - self._write_data(outgoing_data) + try: self._deque_condition.acquire() + while not self._stop_requested: + if len(self._deque) == 0: + self._deque_condition.wait() + continue - # Flush deque - try: - while len(self._deque) > 0: outgoing_data = self._deque.popleft() + + self._deque_condition.release() self._write_data(outgoing_data) + self._deque_condition.acquire() + + # Flush deque. + # + # At this point, self._deque_condition is always acquired. + try: + while len(self._deque) > 0: + outgoing_data = self._deque.popleft() + self._write_data(outgoing_data) + finally: + self._deque_condition.release() + + # Close physical connection. + try: + # Don't wait the response here. The response will be read + # by the reader thread. + self._mux_handler.physical_stream.close_connection( + self._close_code, wait_response=False) + except Exception, e: + util.prepend_message_to_exception( + 'Failed to close the physical connection: %r' % e) + raise finally: - self._deque_condition.release() + self._mux_handler.notify_writer_done() - def stop(self): + def stop(self, close_code=common.STATUS_NORMAL_CLOSURE): """Stops the writer thread.""" self._deque_condition.acquire() self._stop_requested = True + self._close_code = close_code self._deque_condition.notify() self._deque_condition.release() @@ -1055,6 +1285,9 @@ class _Worker(threading.Thread): try: # Non-critical exceptions will be handled by dispatcher. self._mux_handler.dispatcher.transfer_data(self._request) + except LogicalChannelError, e: + self._mux_handler.fail_logical_channel( + e.channel_id, e.drop_code, e.message) finally: self._mux_handler.notify_worker_done(self._request.channel_id) @@ -1083,8 +1316,6 @@ class _MuxHandshaker(hybi.Handshaker): # these headers are included already. request.headers_in[common.UPGRADE_HEADER] = ( common.WEBSOCKET_UPGRADE_TYPE) - request.headers_in[common.CONNECTION_HEADER] = ( - common.UPGRADE_CONNECTION_TYPE) request.headers_in[common.SEC_WEBSOCKET_VERSION_HEADER] = ( str(common.VERSION_HYBI_LATEST)) request.headers_in[common.SEC_WEBSOCKET_KEY_HEADER] = ( @@ -1095,8 +1326,9 @@ class _MuxHandshaker(hybi.Handshaker): self._logger.debug('Creating logical stream for %d' % self._request.channel_id) - return _LogicalStream(self._request, self._send_quota, - self._receive_quota) + return _LogicalStream( + self._request, stream_options, self._send_quota, + self._receive_quota) def _create_handshake_response(self, accept): """Override hybi._create_handshake_response.""" @@ -1105,7 +1337,9 @@ class _MuxHandshaker(hybi.Handshaker): response.append('HTTP/1.1 101 Switching Protocols\r\n') - # Upgrade, Connection and Sec-WebSocket-Accept should be excluded. + # Upgrade and Sec-WebSocket-Accept should be excluded. + response.append('%s: %s\r\n' % ( + common.CONNECTION_HEADER, common.UPGRADE_CONNECTION_TYPE)) if self._request.ws_protocol is not None: response.append('%s: %s\r\n' % ( common.SEC_WEBSOCKET_PROTOCOL_HEADER, @@ -1169,8 +1403,6 @@ class _HandshakeDeltaBase(object): del headers[key] else: headers[key] = value - # TODO(bashi): Support extensions - headers['Sec-WebSocket-Extensions'] = '' return headers @@ -1232,8 +1464,12 @@ class _MuxHandler(object): # Create "Implicitly Opened Connection". logical_connection = _LogicalConnection(self, _DEFAULT_CHANNEL_ID) - self._handshake_base = _HandshakeDeltaBase( - self.original_request.headers_in) + headers = copy.copy(self.original_request.headers_in) + # Add extensions for logical channel. + headers[common.SEC_WEBSOCKET_EXTENSIONS_HEADER] = ( + common.format_extensions( + self.original_request.mux_processor.extensions())) + self._handshake_base = _HandshakeDeltaBase(headers) logical_request = _LogicalRequest( _DEFAULT_CHANNEL_ID, self.original_request.method, @@ -1245,8 +1481,9 @@ class _MuxHandler(object): # but we will send FlowControl later so set the initial quota to # _INITIAL_QUOTA_FOR_CLIENT. self._channel_slots.append(_INITIAL_QUOTA_FOR_CLIENT) + send_quota = self.original_request.mux_processor.quota() if not self._do_handshake_for_logical_request( - logical_request, send_quota=self.original_request.mux_quota): + logical_request, send_quota=send_quota): raise MuxUnexpectedException( 'Failed handshake on the default channel id') self._add_logical_channel(logical_request) @@ -1287,7 +1524,6 @@ class _MuxHandler(object): if not self._worker_done_notify_received: self._logger.debug('Waiting worker(s) timed out') return False - finally: self._logical_channels_condition.release() @@ -1297,7 +1533,7 @@ class _MuxHandler(object): return True - def notify_write_done(self, channel_id): + def notify_write_data_done(self, channel_id): """Called by the writer thread when a write operation has done. Args: @@ -1308,7 +1544,7 @@ class _MuxHandler(object): self._logical_channels_condition.acquire() if channel_id in self._logical_channels: channel_data = self._logical_channels[channel_id] - channel_data.request.connection.notify_write_done() + channel_data.request.connection.on_write_data_done() else: self._logger.debug('Seems that logical channel for %d has gone' % channel_id) @@ -1469,9 +1705,11 @@ class _MuxHandler(object): return channel_data = self._logical_channels[block.channel_id] channel_data.drop_code = _DROP_CODE_ACKNOWLEDGED + # Close the logical channel channel_data.request.connection.set_read_state( _LogicalConnection.STATE_TERMINATED) + channel_data.request.ws_stream.stop_sending() finally: self._logical_channels_condition.release() @@ -1506,8 +1744,11 @@ class _MuxHandler(object): return channel_data = self._logical_channels[channel_id] fin, rsv1, rsv2, rsv3, opcode, payload = parser.read_inner_frame() + consuming_byte = len(payload) + if opcode != common.OPCODE_CONTINUATION: + consuming_byte += 1 if not channel_data.request.ws_stream.consume_receive_quota( - len(payload)): + consuming_byte): # The client violates quota. Close logical channel. raise LogicalChannelError( channel_id, _DROP_CODE_SEND_QUOTA_VIOLATION) @@ -1569,15 +1810,32 @@ class _MuxHandler(object): finished. """ - # Terminate all logical connections - self._logger.debug('termiating all logical connections...') + self._logger.debug( + 'Termiating all logical connections waiting for incoming data ' + '...') self._logical_channels_condition.acquire() for channel_data in self._logical_channels.values(): try: channel_data.request.connection.set_read_state( _LogicalConnection.STATE_TERMINATED) except Exception: - pass + self._logger.debug(traceback.format_exc()) + self._logical_channels_condition.release() + + def notify_writer_done(self): + """This method is called by the writer thread when the writer has + finished. + """ + + self._logger.debug( + 'Termiating all logical connections waiting for write ' + 'completion ...') + self._logical_channels_condition.acquire() + for channel_data in self._logical_channels.values(): + try: + channel_data.request.connection.on_writer_done() + except Exception: + self._logger.debug(traceback.format_exc()) self._logical_channels_condition.release() def fail_physical_connection(self, code, message): @@ -1590,8 +1848,7 @@ class _MuxHandler(object): self._logger.debug('Failing the physical connection...') self._send_drop_channel(_CONTROL_CHANNEL_ID, code, message) - self.physical_stream.close_connection( - common.STATUS_INTERNAL_ENDPOINT_ERROR) + self._writer.stop(common.STATUS_INTERNAL_ENDPOINT_ERROR) def fail_logical_channel(self, channel_id, code, message): """Fail a logical channel. @@ -1611,8 +1868,10 @@ class _MuxHandler(object): # called later and it will send DropChannel. channel_data.drop_code = code channel_data.drop_message = message + channel_data.request.connection.set_read_state( _LogicalConnection.STATE_TERMINATED) + channel_data.request.ws_stream.stop_sending() else: self._send_drop_channel(channel_id, code, message) finally: @@ -1620,7 +1879,8 @@ class _MuxHandler(object): def use_mux(request): - return hasattr(request, 'mux') and request.mux + return hasattr(request, 'mux_processor') and ( + request.mux_processor.is_active()) def start(request, dispatcher): |