# Copyright 2011, Google Inc. # All rights reserved. # # Redistribution and use in source and binary forms, with or without # modification, are permitted provided that the following conditions are # met: # # * Redistributions of source code must retain the above copyright # notice, this list of conditions and the following disclaimer. # * Redistributions in binary form must reproduce the above # copyright notice, this list of conditions and the following disclaimer # in the documentation and/or other materials provided with the # distribution. # * Neither the name of Google Inc. nor the names of its # contributors may be used to endorse or promote products derived from # this software without specific prior written permission. # # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS # "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT # LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR # A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT # OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, # SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT # LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, # DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY # THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. """WebSocket utilities. """ import errno # Import hash classes from a module available and recommended for each Python # version and re-export those symbol. Use sha and md5 module in Python 2.4, and # hashlib module in Python 2.6. try: import hashlib md5_hash = hashlib.md5 sha1_hash = hashlib.sha1 except ImportError: import md5 import sha md5_hash = md5.md5 sha1_hash = sha.sha import StringIO import logging import os import re import socket import traceback import zlib def get_stack_trace(): """Get the current stack trace as string. This is needed to support Python 2.3. TODO: Remove this when we only support Python 2.4 and above. Use traceback.format_exc instead. """ out = StringIO.StringIO() traceback.print_exc(file=out) return out.getvalue() def prepend_message_to_exception(message, exc): """Prepend message to the exception.""" exc.args = (message + str(exc),) return def __translate_interp(interp, cygwin_path): """Translate interp program path for Win32 python to run cygwin program (e.g. perl). Note that it doesn't support path that contains space, which is typically true for Unix, where #!-script is written. For Win32 python, cygwin_path is a directory of cygwin binaries. Args: interp: interp command line cygwin_path: directory name of cygwin binary, or None Returns: translated interp command line. """ if not cygwin_path: return interp m = re.match('^[^ ]*/([^ ]+)( .*)?', interp) if m: cmd = os.path.join(cygwin_path, m.group(1)) return cmd + m.group(2) return interp def get_script_interp(script_path, cygwin_path=None): """Gets #!-interpreter command line from the script. It also fixes command path. When Cygwin Python is used, e.g. in WebKit, it could run "/usr/bin/perl -wT hello.pl". When Win32 Python is used, e.g. in Chromium, it couldn't. So, fix "/usr/bin/perl" to "\perl.exe". Args: script_path: pathname of the script cygwin_path: directory name of cygwin binary, or None Returns: #!-interpreter command line, or None if it is not #!-script. """ fp = open(script_path) line = fp.readline() fp.close() m = re.match('^#!(.*)', line) if m: return __translate_interp(m.group(1), cygwin_path) return None def wrap_popen3_for_win(cygwin_path): """Wrap popen3 to support #!-script on Windows. Args: cygwin_path: path for cygwin binary if command path is needed to be translated. None if no translation required. """ __orig_popen3 = os.popen3 def __wrap_popen3(cmd, mode='t', bufsize=-1): cmdline = cmd.split(' ') interp = get_script_interp(cmdline[0], cygwin_path) if interp: cmd = interp + ' ' + cmd return __orig_popen3(cmd, mode, bufsize) os.popen3 = __wrap_popen3 def hexify(s): return ' '.join(map(lambda x: '%02x' % ord(x), s)) def get_class_logger(o): return logging.getLogger( '%s.%s' % (o.__class__.__module__, o.__class__.__name__)) class NoopMasker(object): """A masking object that has the same interface as RepeatedXorMasker but just returns the string passed in without making any change. """ def __init__(self): pass def mask(self, s): return s class DeflateRequest(object): """A wrapper class for request object to intercept send and recv to perform deflate compression and decompression transparently. """ def __init__(self, request): self._request = request self.connection = DeflateConnection(request.connection) def __getattribute__(self, name): if name in ('_request', 'connection'): return object.__getattribute__(self, name) return self._request.__getattribute__(name) def __setattr__(self, name, value): if name in ('_request', 'connection'): return object.__setattr__(self, name, value) return self._request.__setattr__(name, value) # By making wbits option negative, we can suppress CMF/FLG (2 octet) and # ADLER32 (4 octet) fields of zlib so that we can use zlib module just as # deflate library. DICTID won't be added as far as we don't set dictionary. # LZ77 window of 32K will be used for both compression and decompression. # For decompression, we can just use 32K to cover any windows size. For # compression, we use 32K so receivers must use 32K. # # Compression level is Z_DEFAULT_COMPRESSION. We don't have to match level # to decode. # # See zconf.h, deflate.cc, inflate.cc of zlib library, and zlibmodule.c of # Python. See also RFC1950 (ZLIB 3.3). class _Deflater(object): def __init__(self, window_bits): self._logger = get_class_logger(self) self._compress = zlib.compressobj( zlib.Z_DEFAULT_COMPRESSION, zlib.DEFLATED, -window_bits) def compress_and_flush(self, bytes): compressed_bytes = self._compress.compress(bytes) compressed_bytes += self._compress.flush(zlib.Z_SYNC_FLUSH) self._logger.debug('Compress input %r', bytes) self._logger.debug('Compress result %r', compressed_bytes) return compressed_bytes class _Inflater(object): def __init__(self): self._logger = get_class_logger(self) self._unconsumed = '' self.reset() def decompress(self, size): if not (size == -1 or size > 0): raise Exception('size must be -1 or positive') data = '' while True: if size == -1: data += self._decompress.decompress(self._unconsumed) # See Python bug http://bugs.python.org/issue12050 to # understand why the same code cannot be used for updating # self._unconsumed for here and else block. self._unconsumed = '' else: data += self._decompress.decompress( self._unconsumed, size - len(data)) self._unconsumed = self._decompress.unconsumed_tail if self._decompress.unused_data: # Encountered a last block (i.e. a block with BFINAL = 1) and # found a new stream (unused_data). We cannot use the same # zlib.Decompress object for the new stream. Create a new # Decompress object to decompress the new one. # # It's fine to ignore unconsumed_tail if unused_data is not # empty. self._unconsumed = self._decompress.unused_data self.reset() if size >= 0 and len(data) == size: # data is filled. Don't call decompress again. break else: # Re-invoke Decompress.decompress to try to decompress all # available bytes before invoking read which blocks until # any new byte is available. continue else: # Here, since unused_data is empty, even if unconsumed_tail is # not empty, bytes of requested length are already in data. We # don't have to "continue" here. break if data: self._logger.debug('Decompressed %r', data) return data def append(self, data): self._logger.debug('Appended %r', data) self._unconsumed += data def reset(self): self._logger.debug('Reset') self._decompress = zlib.decompressobj(-zlib.MAX_WBITS) # Compresses/decompresses given octets using the method introduced in RFC1979. class _RFC1979Deflater(object): """A compressor class that applies DEFLATE to given byte sequence and flushes using the algorithm described in the RFC1979 section 2.1. """ def __init__(self, window_bits, no_context_takeover): self._deflater = None if window_bits is None: window_bits = zlib.MAX_WBITS self._window_bits = window_bits self._no_context_takeover = no_context_takeover def filter(self, bytes): if self._deflater is None or self._no_context_takeover: self._deflater = _Deflater(self._window_bits) # Strip last 4 octets which is LEN and NLEN field of a non-compressed # block added for Z_SYNC_FLUSH. return self._deflater.compress_and_flush(bytes)[:-4] class _RFC1979Inflater(object): """A decompressor class for byte sequence compressed and flushed following the algorithm described in the RFC1979 section 2.1. """ def __init__(self): self._inflater = _Inflater() def filter(self, bytes): # Restore stripped LEN and NLEN field of a non-compressed block added # for Z_SYNC_FLUSH. self._inflater.append(bytes + '\x00\x00\xff\xff') return self._inflater.decompress(-1) class DeflateSocket(object): """A wrapper class for socket object to intercept send and recv to perform deflate compression and decompression transparently. """ # Size of the buffer passed to recv to receive compressed data. _RECV_SIZE = 4096 def __init__(self, socket): self._socket = socket self._logger = get_class_logger(self) self._deflater = _Deflater(zlib.MAX_WBITS) self._inflater = _Inflater() def recv(self, size): """Receives data from the socket specified on the construction up to the specified size. Once any data is available, returns it even if it's smaller than the specified size. """ # TODO(tyoshino): Allow call with size=0. It should block until any # decompressed data is available. if size <= 0: raise Exception('Non-positive size passed') while True: data = self._inflater.decompress(size) if len(data) != 0: return data read_data = self._socket.recv(DeflateSocket._RECV_SIZE) if not read_data: return '' self._inflater.append(read_data) def sendall(self, bytes): self.send(bytes) def send(self, bytes): self._socket.sendall(self._deflater.compress_and_flush(bytes)) return len(bytes) class DeflateConnection(object): """A wrapper class for request object to intercept write and read to perform deflate compression and decompression transparently. """ def __init__(self, connection): self._connection = connection self._logger = get_class_logger(self) self._deflater = _Deflater(zlib.MAX_WBITS) self._inflater = _Inflater() def get_remote_addr(self): return self._connection.remote_addr remote_addr = property(get_remote_addr) def put_bytes(self, bytes): self.write(bytes) def read(self, size=-1): """Reads at most size bytes. Blocks until there's at least one byte available. """ # TODO(tyoshino): Allow call with size=0. if not (size == -1 or size > 0): raise Exception('size must be -1 or positive') data = '' while True: if size == -1: data += self._inflater.decompress(-1) else: data += self._inflater.decompress(size - len(data)) if size >= 0 and len(data) != 0: break # TODO(tyoshino): Make this read efficient by some workaround. # # In 3.0.3 and prior of mod_python, read blocks until length bytes # was read. We don't know the exact size to read while using # deflate, so read byte-by-byte. # # _StandaloneRequest.read that ultimately performs # socket._fileobject.read also blocks until length bytes was read read_data = self._connection.read(1) if not read_data: break self._inflater.append(read_data) return data def write(self, bytes): self._connection.write(self._deflater.compress_and_flush(bytes)) def _is_ewouldblock_errno(error_number): """Returns True iff error_number indicates that receive operation would block. To make this portable, we check availability of errno and then compare them. """ for error_name in ['WSAEWOULDBLOCK', 'EWOULDBLOCK', 'EAGAIN']: if (error_name in dir(errno) and error_number == getattr(errno, error_name)): return True return False def drain_received_data(raw_socket): # Set the socket non-blocking. original_timeout = raw_socket.gettimeout() raw_socket.settimeout(0.0) drained_data = [] # Drain until the socket is closed or no data is immediately # available for read. while True: try: data = raw_socket.recv(1) if not data: break drained_data.append(data) except socket.error, e: # e can be either a pair (errno, string) or just a string (or # something else) telling what went wrong. We suppress only # the errors that indicates that the socket blocks. Those # exceptions can be parsed as a pair (errno, string). try: error_number, message = e except: # Failed to parse socket.error. raise e if _is_ewouldblock_errno(error_number): break else: raise e # Rollback timeout value. raw_socket.settimeout(original_timeout) return ''.join(drained_data) # vi:sts=4 sw=4 et