Diffstat (limited to 'pyload/lib')
56 files changed, 19226 insertions, 0 deletions
diff --git a/pyload/lib/Getch.py b/pyload/lib/Getch.py new file mode 100644 index 000000000..22b7ea7f8 --- /dev/null +++ b/pyload/lib/Getch.py @@ -0,0 +1,76 @@ +class Getch: + """ + Gets a single character from standard input. Does not echo to + the screen. + """ + + def __init__(self): + try: + self.impl = _GetchWindows() + except ImportError: + try: + self.impl = _GetchMacCarbon() + except(AttributeError, ImportError): + self.impl = _GetchUnix() + + def __call__(self): return self.impl() + + +class _GetchUnix: + def __init__(self): + import tty + import sys + + def __call__(self): + import sys + import tty + import termios + + fd = sys.stdin.fileno() + old_settings = termios.tcgetattr(fd) + try: + tty.setraw(sys.stdin.fileno()) + ch = sys.stdin.read(1) + finally: + termios.tcsetattr(fd, termios.TCSADRAIN, old_settings) + return ch + + +class _GetchWindows: + def __init__(self): + import msvcrt + + def __call__(self): + import msvcrt + + return msvcrt.getch() + +class _GetchMacCarbon: + """ + A function which returns the current ASCII key that is down; + if no ASCII key is down, the null string is returned. The + page http://www.mactech.com/macintosh-c/chap02-1.html was + very helpful in figuring out how to do this. + """ + + def __init__(self): + import Carbon + Carbon.Evt #see if it has this (in Unix, it doesn't) + + def __call__(self): + import Carbon + + if Carbon.Evt.EventAvail(0x0008)[0] == 0: # 0x0008 is the keyDownMask + return '' + else: + # + # The event contains the following info: + # (what,msg,when,where,mod)=Carbon.Evt.GetNextEvent(0x0008)[1] + # + # The message (msg) contains the ASCII char which is + # extracted with the 0x000000FF charCodeMask; this + # number is converted to an ASCII character with chr() and + # returned + # + (what, msg, when, where, mod) = Carbon.Evt.GetNextEvent(0x0008)[1] + return chr(msg)
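A minimal usage sketch (illustration only, not part of the committed file; assumes Getch.py is importable directly):

    from Getch import Getch

    getch = Getch()          # picks the Windows, Mac Carbon, or Unix backend
    ch = getch()             # blocks for a single keypress, returned without echo
    print "pressed: %r" % ch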
\ No newline at end of file diff --git a/pyload/lib/ReadWriteLock.py b/pyload/lib/ReadWriteLock.py new file mode 100644 index 000000000..cc82f3d48 --- /dev/null +++ b/pyload/lib/ReadWriteLock.py @@ -0,0 +1,232 @@ +# -*- coding: iso-8859-15 -*- +"""locks.py - Read-Write lock thread lock implementation + +See the class documentation for more info. + +Copyright (C) 2007, Heiko Wundram. +Released under the BSD-license. + +http://code.activestate.com/recipes/502283-read-write-lock-class-rlock-like/ +""" + +# Imports +# ------- + +from threading import Condition, Lock, currentThread +from time import time + + +# Read write lock +# --------------- + +class ReadWriteLock(object): + """Read-Write lock class. A read-write lock differs from a standard + threading.RLock() by allowing multiple threads to simultaneously hold a + read lock, while allowing only a single thread to hold a write lock at the + same point of time. + + When a read lock is requested while a write lock is held, the reader + is blocked; when a write lock is requested while another write lock is + held or there are read locks, the writer is blocked. + + Writers are always preferred by this implementation: if there are blocked + threads waiting for a write lock, current readers may request more read + locks (which they eventually should free, as they starve the waiting + writers otherwise), but a new thread requesting a read lock will not + be granted one, and block. This might mean starvation for readers if + two writer threads interweave their calls to acquireWrite() without + leaving a window only for readers. + + In case a current reader requests a write lock, this can and will be + satisfied without giving up the read locks first, but, only one thread + may perform this kind of lock upgrade, as a deadlock would otherwise + occur. After the write lock has been granted, the thread will hold a + full write lock, and not be downgraded after the upgrading call to + acquireWrite() has been match by a corresponding release(). + """ + + def __init__(self): + """Initialize this read-write lock.""" + + # Condition variable, used to signal waiters of a change in object + # state. + self.__condition = Condition(Lock()) + + # Initialize with no writers. + self.__writer = None + self.__upgradewritercount = 0 + self.__pendingwriters = [] + + # Initialize with no readers. + self.__readers = {} + + def acquire(self, blocking=True, timeout=None, shared=False): + if shared: + self.acquireRead(timeout) + else: + self.acquireWrite(timeout) + + def acquireRead(self, timeout=None): + """Acquire a read lock for the current thread, waiting at most + timeout seconds or doing a non-blocking check in case timeout is <= 0. + + In case timeout is None, the call to acquireRead blocks until the + lock request can be serviced. + + In case the timeout expires before the lock could be serviced, a + RuntimeError is thrown.""" + + if timeout is not None: + endtime = time() + timeout + me = currentThread() + self.__condition.acquire() + try: + if self.__writer is me: + # If we are the writer, grant a new read lock, always. + self.__writercount += 1 + return + while True: + if self.__writer is None: + # Only test anything if there is no current writer. + if self.__upgradewritercount or self.__pendingwriters: + if me in self.__readers: + # Only grant a read lock if we already have one + # in case writers are waiting for their turn. + # This means that writers can't easily get starved + # (but see below, readers can). 
+ self.__readers[me] += 1 + return + # No, we aren't a reader (yet), wait for our turn. + else: + # Grant a new read lock, always, in case there are + # no pending writers (and no writer). + self.__readers[me] = self.__readers.get(me, 0) + 1 + return + if timeout is not None: + remaining = endtime - time() + if remaining <= 0: + # Timeout has expired, signal caller of this. + raise RuntimeError("Acquiring read lock timed out") + self.__condition.wait(remaining) + else: + self.__condition.wait() + finally: + self.__condition.release() + + def acquireWrite(self, timeout=None): + """Acquire a write lock for the current thread, waiting at most + timeout seconds or doing a non-blocking check in case timeout is <= 0. + + In case the write lock cannot be serviced due to the deadlock + condition mentioned above, a ValueError is raised. + + In case timeout is None, the call to acquireWrite blocks until the + lock request can be serviced. + + In case the timeout expires before the lock could be serviced, a + RuntimeError is thrown.""" + + if timeout is not None: + endtime = time() + timeout + me, upgradewriter = currentThread(), False + self.__condition.acquire() + try: + if self.__writer is me: + # If we are the writer, grant a new write lock, always. + self.__writercount += 1 + return + elif me in self.__readers: + # If we are a reader, no need to add us to pendingwriters, + # we get the upgradewriter slot. + if self.__upgradewritercount: + # If we are a reader and want to upgrade, and someone + # else also wants to upgrade, there is no way we can do + # this except if one of us releases all his read locks. + # Signal this to user. + raise ValueError( + "Inevitable dead lock, denying write lock" + ) + upgradewriter = True + self.__upgradewritercount = self.__readers.pop(me) + else: + # We aren't a reader, so add us to the pending writers queue + # for synchronization with the readers. + self.__pendingwriters.append(me) + while True: + if not self.__readers and self.__writer is None: + # Only test anything if there are no readers and writers. + if self.__upgradewritercount: + if upgradewriter: + # There is a writer to upgrade, and it's us. Take + # the write lock. + self.__writer = me + self.__writercount = self.__upgradewritercount + 1 + self.__upgradewritercount = 0 + return + # There is a writer to upgrade, but it's not us. + # Always leave the upgrade writer the advance slot, + # because he presumes he'll get a write lock directly + # from a previously held read lock. + elif self.__pendingwriters[0] is me: + # If there are no readers and writers, it's always + # fine for us to take the writer slot, removing us + # from the pending writers queue. + # This might mean starvation for readers, though. + self.__writer = me + self.__writercount = 1 + self.__pendingwriters = self.__pendingwriters[1:] + return + if timeout is not None: + remaining = endtime - time() + if remaining <= 0: + # Timeout has expired, signal caller of this. + if upgradewriter: + # Put us back on the reader queue. No need to + # signal anyone of this change, because no other + # writer could've taken our spot before we got + # here (because of remaining readers), as the test + # for proper conditions is at the start of the + # loop, not at the end. + self.__readers[me] = self.__upgradewritercount + self.__upgradewritercount = 0 + else: + # We were a simple pending writer, just remove us + # from the FIFO list. 
+ self.__pendingwriters.remove(me) + raise RuntimeError("Acquiring write lock timed out") + self.__condition.wait(remaining) + else: + self.__condition.wait() + finally: + self.__condition.release() + + def release(self): + """Release the currently held lock. + + In case the current thread holds no lock, a ValueError is thrown.""" + + me = currentThread() + self.__condition.acquire() + try: + if self.__writer is me: + # We are the writer, take one nesting depth away. + self.__writercount -= 1 + if not self.__writercount: + # No more write locks; take our writer position away and + # notify waiters of the new circumstances. + self.__writer = None + self.__condition.notifyAll() + elif me in self.__readers: + # We are a reader currently, take one nesting depth away. + self.__readers[me] -= 1 + if not self.__readers[me]: + # No more read locks, take our reader position away. + del self.__readers[me] + if not self.__readers: + # No more readers, notify waiters of the new + # circumstances. + self.__condition.notifyAll() + else: + raise ValueError("Trying to release unheld lock") + finally: + self.__condition.release() diff --git a/pyload/lib/SafeEval.py b/pyload/lib/SafeEval.py new file mode 100644 index 000000000..8fc57f261 --- /dev/null +++ b/pyload/lib/SafeEval.py @@ -0,0 +1,47 @@ +## {{{ http://code.activestate.com/recipes/286134/ (r3) (modified) +import dis + +_const_codes = map(dis.opmap.__getitem__, [ + 'POP_TOP','ROT_TWO','ROT_THREE','ROT_FOUR','DUP_TOP', + 'BUILD_LIST','BUILD_MAP','BUILD_TUPLE', + 'LOAD_CONST','RETURN_VALUE','STORE_SUBSCR' + ]) + + +_load_names = ['False', 'True', 'null', 'true', 'false'] + +_locals = {'null': None, 'true': True, 'false': False} + +def _get_opcodes(codeobj): + i = 0 + opcodes = [] + s = codeobj.co_code + names = codeobj.co_names + while i < len(s): + code = ord(s[i]) + opcodes.append(code) + if code >= dis.HAVE_ARGUMENT: + i += 3 + else: + i += 1 + return opcodes, names + +def test_expr(expr, allowed_codes): + try: + c = compile(expr, "", "eval") + except: + raise ValueError, "%s is not a valid expression" % expr + codes, names = _get_opcodes(c) + for code in codes: + if code not in allowed_codes: + for n in names: + if n not in _load_names: + raise ValueError, "opcode %s not allowed" % dis.opname[code] + return c + + +def const_eval(expr): + c = test_expr(expr, _const_codes) + return eval(c, None, _locals) + +## end of http://code.activestate.com/recipes/286134/ }}} diff --git a/pyload/lib/__init__.py b/pyload/lib/__init__.py new file mode 100644 index 000000000..e69de29bb --- /dev/null +++ b/pyload/lib/__init__.py diff --git a/pyload/lib/beaker/__init__.py b/pyload/lib/beaker/__init__.py new file mode 100644 index 000000000..792d60054 --- /dev/null +++ b/pyload/lib/beaker/__init__.py @@ -0,0 +1 @@ +# diff --git a/pyload/lib/beaker/cache.py b/pyload/lib/beaker/cache.py new file mode 100644 index 000000000..4a96537ff --- /dev/null +++ b/pyload/lib/beaker/cache.py @@ -0,0 +1,459 @@ +"""Cache object + +The Cache object is used to manage a set of cache files and their +associated backend. The backends can be rotated on the fly by +specifying an alternate type when used. 
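A minimal sketch of that backend rotation (illustration only; the data_dir path is a placeholder):

    from beaker.cache import Cache

    mem = Cache('mynamespace', type='memory')
    mem.put('key', 'value')
    mem.get('key')                                    # 'value'

    # Same namespace, alternate backend selected via type:
    files = Cache('mynamespace', type='file', data_dir='/tmp/beaker')
    files.put('key', 'value')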
+ +Advanced users can add new backends in beaker.backends + +""" + +import warnings + +import beaker.container as container +import beaker.util as util +from beaker.exceptions import BeakerException, InvalidCacheBackendError + +import beaker.ext.memcached as memcached +import beaker.ext.database as database +import beaker.ext.sqla as sqla +import beaker.ext.google as google + +# Initialize the basic available backends +clsmap = { + 'memory':container.MemoryNamespaceManager, + 'dbm':container.DBMNamespaceManager, + 'file':container.FileNamespaceManager, + 'ext:memcached':memcached.MemcachedNamespaceManager, + 'ext:database':database.DatabaseNamespaceManager, + 'ext:sqla': sqla.SqlaNamespaceManager, + 'ext:google': google.GoogleNamespaceManager, + } + +# Initialize the cache region dict +cache_regions = {} +cache_managers = {} + +try: + import pkg_resources + + # Load up the additional entry point defined backends + for entry_point in pkg_resources.iter_entry_points('beaker.backends'): + try: + NamespaceManager = entry_point.load() + name = entry_point.name + if name in clsmap: + raise BeakerException("NamespaceManager name conflict,'%s' " + "already loaded" % name) + clsmap[name] = NamespaceManager + except (InvalidCacheBackendError, SyntaxError): + # Ignore invalid backends + pass + except: + import sys + from pkg_resources import DistributionNotFound + # Warn when there's a problem loading a NamespaceManager + if not isinstance(sys.exc_info()[1], DistributionNotFound): + import traceback + from StringIO import StringIO + tb = StringIO() + traceback.print_exc(file=tb) + warnings.warn("Unable to load NamespaceManager entry point: '%s': " + "%s" % (entry_point, tb.getvalue()), RuntimeWarning, + 2) +except ImportError: + pass + + + + +def cache_region(region, *deco_args): + """Decorate a function to cache itself using a cache region + + The region decorator requires arguments if there are more than + 2 of the same named function, in the same module. This is + because the namespace used for the functions cache is based on + the functions name and the module. + + + Example:: + + # Add cache region settings to beaker: + beaker.cache.cache_regions.update(dict_of_config_region_options)) + + @cache_region('short_term', 'some_data') + def populate_things(search_term, limit, offset): + return load_the_data(search_term, limit, offset) + + return load('rabbits', 20, 0) + + .. note:: + + The function being decorated must only be called with + positional arguments. + + """ + cache = [None] + + def decorate(func): + namespace = util.func_namespace(func) + def cached(*args): + reg = cache_regions[region] + if not reg.get('enabled', True): + return func(*args) + + if not cache[0]: + if region not in cache_regions: + raise BeakerException('Cache region not configured: %s' % region) + cache[0] = Cache._get_cache(namespace, reg) + + cache_key = " ".join(map(str, deco_args + args)) + def go(): + return func(*args) + + return cache[0].get_value(cache_key, createfunc=go) + cached._arg_namespace = namespace + cached._arg_region = region + return cached + return decorate + + +def region_invalidate(namespace, region, *args): + """Invalidate a cache region namespace or decorated function + + This function only invalidates cache spaces created with the + cache_region decorator. + + :param namespace: Either the namespace of the result to invalidate, or the + cached function reference + + :param region: The region the function was cached to. 
If the function was + cached to a single region then this argument can be None + + :param args: Arguments that were used to differentiate the cached + function as well as the arguments passed to the decorated + function + + Example:: + + # Add cache region settings to beaker: + beaker.cache.cache_regions.update(dict_of_config_region_options)) + + def populate_things(invalidate=False): + + @cache_region('short_term', 'some_data') + def load(search_term, limit, offset): + return load_the_data(search_term, limit, offset) + + # If the results should be invalidated first + if invalidate: + region_invalidate(load, None, 'some_data', + 'rabbits', 20, 0) + return load('rabbits', 20, 0) + + """ + if callable(namespace): + if not region: + region = namespace._arg_region + namespace = namespace._arg_namespace + + if not region: + raise BeakerException("Region or callable function " + "namespace is required") + else: + region = cache_regions[region] + + cache = Cache._get_cache(namespace, region) + cache_key = " ".join(str(x) for x in args) + cache.remove_value(cache_key) + + +class Cache(object): + """Front-end to the containment API implementing a data cache. + + :param namespace: the namespace of this Cache + + :param type: type of cache to use + + :param expire: seconds to keep cached data + + :param expiretime: seconds to keep cached data (legacy support) + + :param starttime: time when cache was cache was + + """ + def __init__(self, namespace, type='memory', expiretime=None, + starttime=None, expire=None, **nsargs): + try: + cls = clsmap[type] + if isinstance(cls, InvalidCacheBackendError): + raise cls + except KeyError: + raise TypeError("Unknown cache implementation %r" % type) + + self.namespace = cls(namespace, **nsargs) + self.expiretime = expiretime or expire + self.starttime = starttime + self.nsargs = nsargs + + @classmethod + def _get_cache(cls, namespace, kw): + key = namespace + str(kw) + try: + return cache_managers[key] + except KeyError: + cache_managers[key] = cache = cls(namespace, **kw) + return cache + + def put(self, key, value, **kw): + self._get_value(key, **kw).set_value(value) + set_value = put + + def get(self, key, **kw): + """Retrieve a cached value from the container""" + return self._get_value(key, **kw).get_value() + get_value = get + + def remove_value(self, key, **kw): + mycontainer = self._get_value(key, **kw) + if mycontainer.has_current_value(): + mycontainer.clear_value() + remove = remove_value + + def _get_value(self, key, **kw): + if isinstance(key, unicode): + key = key.encode('ascii', 'backslashreplace') + + if 'type' in kw: + return self._legacy_get_value(key, **kw) + + kw.setdefault('expiretime', self.expiretime) + kw.setdefault('starttime', self.starttime) + + return container.Value(key, self.namespace, **kw) + + @util.deprecated("Specifying a " + "'type' and other namespace configuration with cache.get()/put()/etc. " + "is deprecated. 
Specify 'type' and other namespace configuration to " + "cache_manager.get_cache() and/or the Cache constructor instead.") + def _legacy_get_value(self, key, type, **kw): + expiretime = kw.pop('expiretime', self.expiretime) + starttime = kw.pop('starttime', None) + createfunc = kw.pop('createfunc', None) + kwargs = self.nsargs.copy() + kwargs.update(kw) + c = Cache(self.namespace.namespace, type=type, **kwargs) + return c._get_value(key, expiretime=expiretime, createfunc=createfunc, + starttime=starttime) + + def clear(self): + """Clear all the values from the namespace""" + self.namespace.remove() + + # dict interface + def __getitem__(self, key): + return self.get(key) + + def __contains__(self, key): + return self._get_value(key).has_current_value() + + def has_key(self, key): + return key in self + + def __delitem__(self, key): + self.remove_value(key) + + def __setitem__(self, key, value): + self.put(key, value) + + +class CacheManager(object): + def __init__(self, **kwargs): + """Initialize a CacheManager object with a set of options + + Options should be parsed with the + :func:`~beaker.util.parse_cache_config_options` function to + ensure only valid options are used. + + """ + self.kwargs = kwargs + self.regions = kwargs.pop('cache_regions', {}) + + # Add these regions to the module global + cache_regions.update(self.regions) + + def get_cache(self, name, **kwargs): + kw = self.kwargs.copy() + kw.update(kwargs) + return Cache._get_cache(name, kw) + + def get_cache_region(self, name, region): + if region not in self.regions: + raise BeakerException('Cache region not configured: %s' % region) + kw = self.regions[region] + return Cache._get_cache(name, kw) + + def region(self, region, *args): + """Decorate a function to cache itself using a cache region + + The region decorator requires arguments if there are more than + 2 of the same named function, in the same module. This is + because the namespace used for the functions cache is based on + the functions name and the module. + + + Example:: + + # Assuming a cache object is available like: + cache = CacheManager(dict_of_config_options) + + + def populate_things(): + + @cache.region('short_term', 'some_data') + def load(search_term, limit, offset): + return load_the_data(search_term, limit, offset) + + return load('rabbits', 20, 0) + + .. note:: + + The function being decorated must only be called with + positional arguments. + + """ + return cache_region(region, *args) + + def region_invalidate(self, namespace, region, *args): + """Invalidate a cache region namespace or decorated function + + This function only invalidates cache spaces created with the + cache_region decorator. + + :param namespace: Either the namespace of the result to invalidate, or the + name of the cached function + + :param region: The region the function was cached to. 
If the function was + cached to a single region then this argument can be None + + :param args: Arguments that were used to differentiate the cached + function as well as the arguments passed to the decorated + function + + Example:: + + # Assuming a cache object is available like: + cache = CacheManager(dict_of_config_options) + + def populate_things(invalidate=False): + + @cache.region('short_term', 'some_data') + def load(search_term, limit, offset): + return load_the_data(search_term, limit, offset) + + # If the results should be invalidated first + if invalidate: + cache.region_invalidate(load, None, 'some_data', + 'rabbits', 20, 0) + return load('rabbits', 20, 0) + + + """ + return region_invalidate(namespace, region, *args) + if callable(namespace): + if not region: + region = namespace._arg_region + namespace = namespace._arg_namespace + + if not region: + raise BeakerException("Region or callable function " + "namespace is required") + else: + region = self.regions[region] + + cache = self.get_cache(namespace, **region) + cache_key = " ".join(str(x) for x in args) + cache.remove_value(cache_key) + + def cache(self, *args, **kwargs): + """Decorate a function to cache itself with supplied parameters + + :param args: Used to make the key unique for this function, as in region() + above. + + :param kwargs: Parameters to be passed to get_cache(), will override defaults + + Example:: + + # Assuming a cache object is available like: + cache = CacheManager(dict_of_config_options) + + + def populate_things(): + + @cache.cache('mycache', expire=15) + def load(search_term, limit, offset): + return load_the_data(search_term, limit, offset) + + return load('rabbits', 20, 0) + + .. note:: + + The function being decorated must only be called with + positional arguments. + + """ + cache = [None] + key = " ".join(str(x) for x in args) + + def decorate(func): + namespace = util.func_namespace(func) + def cached(*args): + if not cache[0]: + cache[0] = self.get_cache(namespace, **kwargs) + cache_key = key + " " + " ".join(str(x) for x in args) + def go(): + return func(*args) + return cache[0].get_value(cache_key, createfunc=go) + cached._arg_namespace = namespace + return cached + return decorate + + def invalidate(self, func, *args, **kwargs): + """Invalidate a cache decorated function + + This function only invalidates cache spaces created with the + cache decorator. + + :param func: Decorated function to invalidate + + :param args: Used to make the key unique for this function, as in region() + above. 
+ + :param kwargs: Parameters that were passed for use by get_cache(), note that + this is only required if a ``type`` was specified for the + function + + Example:: + + # Assuming a cache object is available like: + cache = CacheManager(dict_of_config_options) + + + def populate_things(invalidate=False): + + @cache.cache('mycache', type="file", expire=15) + def load(search_term, limit, offset): + return load_the_data(search_term, limit, offset) + + # If the results should be invalidated first + if invalidate: + cache.invalidate(load, 'mycache', 'rabbits', 20, 0, type="file") + return load('rabbits', 20, 0) + + """ + namespace = func._arg_namespace + + cache = self.get_cache(namespace, **kwargs) + cache_key = " ".join(str(x) for x in args) + cache.remove_value(cache_key) diff --git a/pyload/lib/beaker/container.py b/pyload/lib/beaker/container.py new file mode 100644 index 000000000..515e97af6 --- /dev/null +++ b/pyload/lib/beaker/container.py @@ -0,0 +1,633 @@ +"""Container and Namespace classes""" +import anydbm +import cPickle +import logging +import os +import time + +import beaker.util as util +from beaker.exceptions import CreationAbortedError, MissingCacheParameter +from beaker.synchronization import _threading, file_synchronizer, \ + mutex_synchronizer, NameLock, null_synchronizer + +__all__ = ['Value', 'Container', 'ContainerContext', + 'MemoryContainer', 'DBMContainer', 'NamespaceManager', + 'MemoryNamespaceManager', 'DBMNamespaceManager', 'FileContainer', + 'OpenResourceNamespaceManager', + 'FileNamespaceManager', 'CreationAbortedError'] + + +logger = logging.getLogger('beaker.container') +if logger.isEnabledFor(logging.DEBUG): + debug = logger.debug +else: + def debug(message, *args): + pass + + +class NamespaceManager(object): + """Handles dictionary operations and locking for a namespace of + values. + + The implementation for setting and retrieving the namespace data is + handled by subclasses. + + NamespaceManager may be used alone, or may be privately accessed by + one or more Container objects. Container objects provide per-key + services like expiration times and automatic recreation of values. + + Multiple NamespaceManagers created with a particular name will all + share access to the same underlying datasource and will attempt to + synchronize against a common mutex object. The scope of this + sharing may be within a single process or across multiple + processes, depending on the type of NamespaceManager used. + + The NamespaceManager itself is generally threadsafe, except in the + case of the DBMNamespaceManager in conjunction with the gdbm dbm + implementation. + + """ + + @classmethod + def _init_dependencies(cls): + pass + + def __init__(self, namespace): + self._init_dependencies() + self.namespace = namespace + + def get_creation_lock(self, key): + raise NotImplementedError() + + def do_remove(self): + raise NotImplementedError() + + def acquire_read_lock(self): + pass + + def release_read_lock(self): + pass + + def acquire_write_lock(self, wait=True): + return True + + def release_write_lock(self): + pass + + def has_key(self, key): + return self.__contains__(key) + + def __getitem__(self, key): + raise NotImplementedError() + + def __setitem__(self, key, value): + raise NotImplementedError() + + def set_value(self, key, value, expiretime=None): + """Optional set_value() method called by Value. + + Allows an expiretime to be passed, for namespace + implementations which can prune their collections + using expiretime. 
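An illustrative override (a sketch only; _schedule_prune is an assumed helper, not part of beaker):

    class PruningNSManager(MemoryNamespaceManager):
        def set_value(self, key, value, expiretime=None):
            self[key] = value
            if expiretime is not None:
                self._schedule_prune(key, expiretime)  # hypothetical pruning hook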
+ + """ + self[key] = value + + def __contains__(self, key): + raise NotImplementedError() + + def __delitem__(self, key): + raise NotImplementedError() + + def keys(self): + raise NotImplementedError() + + def remove(self): + self.do_remove() + + +class OpenResourceNamespaceManager(NamespaceManager): + """A NamespaceManager where read/write operations require opening/ + closing of a resource which is possibly mutexed. + + """ + def __init__(self, namespace): + NamespaceManager.__init__(self, namespace) + self.access_lock = self.get_access_lock() + self.openers = 0 + self.mutex = _threading.Lock() + + def get_access_lock(self): + raise NotImplementedError() + + def do_open(self, flags): + raise NotImplementedError() + + def do_close(self): + raise NotImplementedError() + + def acquire_read_lock(self): + self.access_lock.acquire_read_lock() + try: + self.open('r', checkcount = True) + except: + self.access_lock.release_read_lock() + raise + + def release_read_lock(self): + try: + self.close(checkcount = True) + finally: + self.access_lock.release_read_lock() + + def acquire_write_lock(self, wait=True): + r = self.access_lock.acquire_write_lock(wait) + try: + if (wait or r): + self.open('c', checkcount = True) + return r + except: + self.access_lock.release_write_lock() + raise + + def release_write_lock(self): + try: + self.close(checkcount=True) + finally: + self.access_lock.release_write_lock() + + def open(self, flags, checkcount=False): + self.mutex.acquire() + try: + if checkcount: + if self.openers == 0: + self.do_open(flags) + self.openers += 1 + else: + self.do_open(flags) + self.openers = 1 + finally: + self.mutex.release() + + def close(self, checkcount=False): + self.mutex.acquire() + try: + if checkcount: + self.openers -= 1 + if self.openers == 0: + self.do_close() + else: + if self.openers > 0: + self.do_close() + self.openers = 0 + finally: + self.mutex.release() + + def remove(self): + self.access_lock.acquire_write_lock() + try: + self.close(checkcount=False) + self.do_remove() + finally: + self.access_lock.release_write_lock() + +class Value(object): + __slots__ = 'key', 'createfunc', 'expiretime', 'expire_argument', 'starttime', 'storedtime',\ + 'namespace' + + def __init__(self, key, namespace, createfunc=None, expiretime=None, starttime=None): + self.key = key + self.createfunc = createfunc + self.expire_argument = expiretime + self.starttime = starttime + self.storedtime = -1 + self.namespace = namespace + + def has_value(self): + """return true if the container has a value stored. + + This is regardless of it being expired or not. 
+ + """ + self.namespace.acquire_read_lock() + try: + return self.namespace.has_key(self.key) + finally: + self.namespace.release_read_lock() + + def can_have_value(self): + return self.has_current_value() or self.createfunc is not None + + def has_current_value(self): + self.namespace.acquire_read_lock() + try: + has_value = self.namespace.has_key(self.key) + if has_value: + try: + stored, expired, value = self._get_value() + return not self._is_expired(stored, expired) + except KeyError: + pass + return False + finally: + self.namespace.release_read_lock() + + def _is_expired(self, storedtime, expiretime): + """Return true if this container's value is expired.""" + return ( + ( + self.starttime is not None and + storedtime < self.starttime + ) + or + ( + expiretime is not None and + time.time() >= expiretime + storedtime + ) + ) + + def get_value(self): + self.namespace.acquire_read_lock() + try: + has_value = self.has_value() + if has_value: + try: + stored, expired, value = self._get_value() + if not self._is_expired(stored, expired): + return value + except KeyError: + # guard against un-mutexed backends raising KeyError + has_value = False + + if not self.createfunc: + raise KeyError(self.key) + finally: + self.namespace.release_read_lock() + + has_createlock = False + creation_lock = self.namespace.get_creation_lock(self.key) + if has_value: + if not creation_lock.acquire(wait=False): + debug("get_value returning old value while new one is created") + return value + else: + debug("lock_creatfunc (didnt wait)") + has_createlock = True + + if not has_createlock: + debug("lock_createfunc (waiting)") + creation_lock.acquire() + debug("lock_createfunc (waited)") + + try: + # see if someone created the value already + self.namespace.acquire_read_lock() + try: + if self.has_value(): + try: + stored, expired, value = self._get_value() + if not self._is_expired(stored, expired): + return value + except KeyError: + # guard against un-mutexed backends raising KeyError + pass + finally: + self.namespace.release_read_lock() + + debug("get_value creating new value") + v = self.createfunc() + self.set_value(v) + return v + finally: + creation_lock.release() + debug("released create lock") + + def _get_value(self): + value = self.namespace[self.key] + try: + stored, expired, value = value + except ValueError: + if not len(value) == 2: + raise + # Old format: upgrade + stored, value = value + expired = self.expire_argument + debug("get_value upgrading time %r expire time %r", stored, self.expire_argument) + self.namespace.release_read_lock() + self.set_value(value, stored) + self.namespace.acquire_read_lock() + except TypeError: + # occurs when the value is None. 
memcached + # may yank the rug from under us in which case + # that's the result + raise KeyError(self.key) + return stored, expired, value + + def set_value(self, value, storedtime=None): + self.namespace.acquire_write_lock() + try: + if storedtime is None: + storedtime = time.time() + debug("set_value stored time %r expire time %r", storedtime, self.expire_argument) + self.namespace.set_value(self.key, (storedtime, self.expire_argument, value)) + finally: + self.namespace.release_write_lock() + + def clear_value(self): + self.namespace.acquire_write_lock() + try: + debug("clear_value") + if self.namespace.has_key(self.key): + try: + del self.namespace[self.key] + except KeyError: + # guard against un-mutexed backends raising KeyError + pass + self.storedtime = -1 + finally: + self.namespace.release_write_lock() + +class AbstractDictionaryNSManager(NamespaceManager): + """A subclassable NamespaceManager that places data in a dictionary. + + Subclasses should provide a "dictionary" attribute or descriptor + which returns a dict-like object. The dictionary will store keys + that are local to the "namespace" attribute of this manager, so + ensure that the dictionary will not be used by any other namespace. + + e.g.:: + + import collections + cached_data = collections.defaultdict(dict) + + class MyDictionaryManager(AbstractDictionaryNSManager): + def __init__(self, namespace): + AbstractDictionaryNSManager.__init__(self, namespace) + self.dictionary = cached_data[self.namespace] + + The above stores data in a global dictionary called "cached_data", + which is structured as a dictionary of dictionaries, keyed + first on namespace name to a sub-dictionary, then on actual + cache key to value. + + """ + + def get_creation_lock(self, key): + return NameLock( + identifier="memorynamespace/funclock/%s/%s" % (self.namespace, key), + reentrant=True + ) + + def __getitem__(self, key): + return self.dictionary[key] + + def __contains__(self, key): + return self.dictionary.__contains__(key) + + def has_key(self, key): + return self.dictionary.__contains__(key) + + def __setitem__(self, key, value): + self.dictionary[key] = value + + def __delitem__(self, key): + del self.dictionary[key] + + def do_remove(self): + self.dictionary.clear() + + def keys(self): + return self.dictionary.keys() + +class MemoryNamespaceManager(AbstractDictionaryNSManager): + namespaces = util.SyncDict() + + def __init__(self, namespace, **kwargs): + AbstractDictionaryNSManager.__init__(self, namespace) + self.dictionary = MemoryNamespaceManager.namespaces.get(self.namespace, + dict) + +class DBMNamespaceManager(OpenResourceNamespaceManager): + def __init__(self, namespace, dbmmodule=None, data_dir=None, + dbm_dir=None, lock_dir=None, digest_filenames=True, **kwargs): + self.digest_filenames = digest_filenames + + if not dbm_dir and not data_dir: + raise MissingCacheParameter("data_dir or dbm_dir is required") + elif dbm_dir: + self.dbm_dir = dbm_dir + else: + self.dbm_dir = data_dir + "/container_dbm" + util.verify_directory(self.dbm_dir) + + if not lock_dir and not data_dir: + raise MissingCacheParameter("data_dir or lock_dir is required") + elif lock_dir: + self.lock_dir = lock_dir + else: + self.lock_dir = data_dir + "/container_dbm_lock" + util.verify_directory(self.lock_dir) + + self.dbmmodule = dbmmodule or anydbm + + self.dbm = None + OpenResourceNamespaceManager.__init__(self, namespace) + + self.file = util.encoded_path(root= self.dbm_dir, + identifiers=[self.namespace], + extension='.dbm', + 
digest_filenames=self.digest_filenames) + + debug("data file %s", self.file) + self._checkfile() + + def get_access_lock(self): + return file_synchronizer(identifier=self.namespace, + lock_dir=self.lock_dir) + + def get_creation_lock(self, key): + return file_synchronizer( + identifier = "dbmcontainer/funclock/%s" % self.namespace, + lock_dir=self.lock_dir + ) + + def file_exists(self, file): + if os.access(file, os.F_OK): + return True + else: + for ext in ('db', 'dat', 'pag', 'dir'): + if os.access(file + os.extsep + ext, os.F_OK): + return True + + return False + + def _checkfile(self): + if not self.file_exists(self.file): + g = self.dbmmodule.open(self.file, 'c') + g.close() + + def get_filenames(self): + list = [] + if os.access(self.file, os.F_OK): + list.append(self.file) + + for ext in ('pag', 'dir', 'db', 'dat'): + if os.access(self.file + os.extsep + ext, os.F_OK): + list.append(self.file + os.extsep + ext) + return list + + def do_open(self, flags): + debug("opening dbm file %s", self.file) + try: + self.dbm = self.dbmmodule.open(self.file, flags) + except: + self._checkfile() + self.dbm = self.dbmmodule.open(self.file, flags) + + def do_close(self): + if self.dbm is not None: + debug("closing dbm file %s", self.file) + self.dbm.close() + + def do_remove(self): + for f in self.get_filenames(): + os.remove(f) + + def __getitem__(self, key): + return cPickle.loads(self.dbm[key]) + + def __contains__(self, key): + return self.dbm.has_key(key) + + def __setitem__(self, key, value): + self.dbm[key] = cPickle.dumps(value) + + def __delitem__(self, key): + del self.dbm[key] + + def keys(self): + return self.dbm.keys() + + +class FileNamespaceManager(OpenResourceNamespaceManager): + def __init__(self, namespace, data_dir=None, file_dir=None, lock_dir=None, + digest_filenames=True, **kwargs): + self.digest_filenames = digest_filenames + + if not file_dir and not data_dir: + raise MissingCacheParameter("data_dir or file_dir is required") + elif file_dir: + self.file_dir = file_dir + else: + self.file_dir = data_dir + "/container_file" + util.verify_directory(self.file_dir) + + if not lock_dir and not data_dir: + raise MissingCacheParameter("data_dir or lock_dir is required") + elif lock_dir: + self.lock_dir = lock_dir + else: + self.lock_dir = data_dir + "/container_file_lock" + util.verify_directory(self.lock_dir) + OpenResourceNamespaceManager.__init__(self, namespace) + + self.file = util.encoded_path(root=self.file_dir, + identifiers=[self.namespace], + extension='.cache', + digest_filenames=self.digest_filenames) + self.hash = {} + + debug("data file %s", self.file) + + def get_access_lock(self): + return file_synchronizer(identifier=self.namespace, + lock_dir=self.lock_dir) + + def get_creation_lock(self, key): + return file_synchronizer( + identifier = "filecontainer/funclock/%s" % self.namespace, + lock_dir = self.lock_dir + ) + + def file_exists(self, file): + return os.access(file, os.F_OK) + + def do_open(self, flags): + if self.file_exists(self.file): + fh = open(self.file, 'rb') + try: + self.hash = cPickle.load(fh) + except (IOError, OSError, EOFError, cPickle.PickleError, ValueError): + pass + fh.close() + + self.flags = flags + + def do_close(self): + if self.flags == 'c' or self.flags == 'w': + fh = open(self.file, 'wb') + cPickle.dump(self.hash, fh) + fh.close() + + self.hash = {} + self.flags = None + + def do_remove(self): + try: + os.remove(self.file) + except OSError, err: + # for instance, because we haven't yet used this cache, + # but client code has asked for a 
clear() operation... + pass + self.hash = {} + + def __getitem__(self, key): + return self.hash[key] + + def __contains__(self, key): + return self.hash.has_key(key) + + def __setitem__(self, key, value): + self.hash[key] = value + + def __delitem__(self, key): + del self.hash[key] + + def keys(self): + return self.hash.keys() + + +#### legacy stuff to support the old "Container" class interface + +namespace_classes = {} + +ContainerContext = dict + +class ContainerMeta(type): + def __init__(cls, classname, bases, dict_): + namespace_classes[cls] = cls.namespace_class + return type.__init__(cls, classname, bases, dict_) + def __call__(self, key, context, namespace, createfunc=None, + expiretime=None, starttime=None, **kwargs): + if namespace in context: + ns = context[namespace] + else: + nscls = namespace_classes[self] + context[namespace] = ns = nscls(namespace, **kwargs) + return Value(key, ns, createfunc=createfunc, + expiretime=expiretime, starttime=starttime) + +class Container(object): + __metaclass__ = ContainerMeta + namespace_class = NamespaceManager + +class FileContainer(Container): + namespace_class = FileNamespaceManager + +class MemoryContainer(Container): + namespace_class = MemoryNamespaceManager + +class DBMContainer(Container): + namespace_class = DBMNamespaceManager + +DbmContainer = DBMContainer diff --git a/pyload/lib/beaker/converters.py b/pyload/lib/beaker/converters.py new file mode 100644 index 000000000..f0ad34963 --- /dev/null +++ b/pyload/lib/beaker/converters.py @@ -0,0 +1,26 @@ +# (c) 2005 Ian Bicking and contributors; written for Paste (http://pythonpaste.org) +# Licensed under the MIT license: http://www.opensource.org/licenses/mit-license.php +def asbool(obj): + if isinstance(obj, (str, unicode)): + obj = obj.strip().lower() + if obj in ['true', 'yes', 'on', 'y', 't', '1']: + return True + elif obj in ['false', 'no', 'off', 'n', 'f', '0']: + return False + else: + raise ValueError( + "String is not true/false: %r" % obj) + return bool(obj) + +def aslist(obj, sep=None, strip=True): + if isinstance(obj, (str, unicode)): + lst = obj.split(sep) + if strip: + lst = [v.strip() for v in lst] + return lst + elif isinstance(obj, (list, tuple)): + return obj + elif obj is None: + return [] + else: + return [obj] diff --git a/pyload/lib/beaker/crypto/__init__.py b/pyload/lib/beaker/crypto/__init__.py new file mode 100644 index 000000000..3e26b0c13 --- /dev/null +++ b/pyload/lib/beaker/crypto/__init__.py @@ -0,0 +1,40 @@ +from warnings import warn + +from beaker.crypto.pbkdf2 import PBKDF2, strxor +from beaker.crypto.util import hmac, sha1, hmac_sha1, md5 +from beaker import util + +keyLength = None + +if util.jython: + try: + from beaker.crypto.jcecrypto import getKeyLength, aesEncrypt + keyLength = getKeyLength() + except ImportError: + pass +else: + try: + from beaker.crypto.pycrypto import getKeyLength, aesEncrypt, aesDecrypt + keyLength = getKeyLength() + except ImportError: + pass + +if not keyLength: + has_aes = False +else: + has_aes = True + +if has_aes and keyLength < 32: + warn('Crypto implementation only supports key lengths up to %d bits. ' + 'Generated session cookies may be incompatible with other ' + 'environments' % (keyLength * 8)) + + +def generateCryptoKeys(master_key, salt, iterations): + # NB: We XOR parts of the keystream into the randomly-generated parts, just + # in case os.urandom() isn't as random as it should be. Note that if + # os.urandom() returns truly random data, this will have no effect on the + # overall security. 
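    # An illustrative call (sketch only; the arguments are placeholders):
    #   cipher_key = generateCryptoKeys('master secret', 'per-session salt', 1000)
    #   assert len(cipher_key) == keyLength   # e.g. 32 bytes where AES-256 is available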
+ keystream = PBKDF2(master_key, salt, iterations=iterations) + cipher_key = keystream.read(keyLength) + return cipher_key diff --git a/pyload/lib/beaker/crypto/jcecrypto.py b/pyload/lib/beaker/crypto/jcecrypto.py new file mode 100644 index 000000000..4062d513e --- /dev/null +++ b/pyload/lib/beaker/crypto/jcecrypto.py @@ -0,0 +1,30 @@ +""" +Encryption module that uses the Java Cryptography Extensions (JCE). + +Note that in default installations of the Java Runtime Environment, the +maximum key length is limited to 128 bits due to US export +restrictions. This makes the generated keys incompatible with the ones +generated by pycryptopp, which has no such restrictions. To fix this, +download the "Unlimited Strength Jurisdiction Policy Files" from Sun, +which will allow encryption using 256 bit AES keys. +""" +from javax.crypto import Cipher +from javax.crypto.spec import SecretKeySpec, IvParameterSpec + +import jarray + +# Initialization vector filled with zeros +_iv = IvParameterSpec(jarray.zeros(16, 'b')) + +def aesEncrypt(data, key): + cipher = Cipher.getInstance('AES/CTR/NoPadding') + skeySpec = SecretKeySpec(key, 'AES') + cipher.init(Cipher.ENCRYPT_MODE, skeySpec, _iv) + return cipher.doFinal(data).tostring() + +# magic. +aesDecrypt = aesEncrypt + +def getKeyLength(): + maxlen = Cipher.getMaxAllowedKeyLength('AES/CTR/NoPadding') + return min(maxlen, 256) / 8 diff --git a/pyload/lib/beaker/crypto/pbkdf2.py b/pyload/lib/beaker/crypto/pbkdf2.py new file mode 100644 index 000000000..96dc5fbb2 --- /dev/null +++ b/pyload/lib/beaker/crypto/pbkdf2.py @@ -0,0 +1,342 @@ +#!/usr/bin/python +# -*- coding: ascii -*- +########################################################################### +# PBKDF2.py - PKCS#5 v2.0 Password-Based Key Derivation +# +# Copyright (C) 2007 Dwayne C. Litzenberger <dlitz@dlitz.net> +# All rights reserved. +# +# Permission to use, copy, modify, and distribute this software and its +# documentation for any purpose and without fee is hereby granted, +# provided that the above copyright notice appear in all copies and that +# both that copyright notice and this permission notice appear in +# supporting documentation. +# +# THE AUTHOR PROVIDES THIS SOFTWARE ``AS IS'' AND ANY EXPRESSED OR +# IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES +# OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. +# IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY DIRECT, INDIRECT, +# INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT +# NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +# Country of origin: Canada +# +########################################################################### +# Sample PBKDF2 usage: +# from Crypto.Cipher import AES +# from PBKDF2 import PBKDF2 +# import os +# +# salt = os.urandom(8) # 64-bit salt +# key = PBKDF2("This passphrase is a secret.", salt).read(32) # 256-bit key +# iv = os.urandom(16) # 128-bit IV +# cipher = AES.new(key, AES.MODE_CBC, iv) +# ... 
+# +# Sample crypt() usage: +# from PBKDF2 import crypt +# pwhash = crypt("secret") +# alleged_pw = raw_input("Enter password: ") +# if pwhash == crypt(alleged_pw, pwhash): +# print "Password good" +# else: +# print "Invalid password" +# +########################################################################### +# History: +# +# 2007-07-27 Dwayne C. Litzenberger <dlitz@dlitz.net> +# - Initial Release (v1.0) +# +# 2007-07-31 Dwayne C. Litzenberger <dlitz@dlitz.net> +# - Bugfix release (v1.1) +# - SECURITY: The PyCrypto XOR cipher (used, if available, in the _strxor +# function in the previous release) silently truncates all keys to 64 +# bytes. The way it was used in the previous release, this would only be +# problem if the pseudorandom function that returned values larger than +# 64 bytes (so SHA1, SHA256 and SHA512 are fine), but I don't like +# anything that silently reduces the security margin from what is +# expected. +# +########################################################################### + +__version__ = "1.1" + +from struct import pack +from binascii import b2a_hex +from random import randint + +from base64 import b64encode + +from beaker.crypto.util import hmac as HMAC, hmac_sha1 as SHA1 + +def strxor(a, b): + return "".join([chr(ord(x) ^ ord(y)) for (x, y) in zip(a, b)]) + +class PBKDF2(object): + """PBKDF2.py : PKCS#5 v2.0 Password-Based Key Derivation + + This implementation takes a passphrase and a salt (and optionally an + iteration count, a digest module, and a MAC module) and provides a + file-like object from which an arbitrarily-sized key can be read. + + If the passphrase and/or salt are unicode objects, they are encoded as + UTF-8 before they are processed. + + The idea behind PBKDF2 is to derive a cryptographic key from a + passphrase and a salt. + + PBKDF2 may also be used as a strong salted password hash. The + 'crypt' function is provided for that purpose. + + Remember: Keys generated using PBKDF2 are only as strong as the + passphrases they are derived from. + """ + + def __init__(self, passphrase, salt, iterations=1000, + digestmodule=SHA1, macmodule=HMAC): + if not callable(macmodule): + macmodule = macmodule.new + self.__macmodule = macmodule + self.__digestmodule = digestmodule + self._setup(passphrase, salt, iterations, self._pseudorandom) + + def _pseudorandom(self, key, msg): + """Pseudorandom function. e.g. HMAC-SHA1""" + return self.__macmodule(key=key, msg=msg, + digestmod=self.__digestmodule).digest() + + def read(self, bytes): + """Read the specified number of key bytes.""" + if self.closed: + raise ValueError("file-like object is closed") + + size = len(self.__buf) + blocks = [self.__buf] + i = self.__blockNum + while size < bytes: + i += 1 + if i > 0xffffffff: + # We could return "" here, but + raise OverflowError("derived key too long") + block = self.__f(i) + blocks.append(block) + size += len(block) + buf = "".join(blocks) + retval = buf[:bytes] + self.__buf = buf[bytes:] + self.__blockNum = i + return retval + + def __f(self, i): + # i must fit within 32 bits + assert (1 <= i <= 0xffffffff) + U = self.__prf(self.__passphrase, self.__salt + pack("!L", i)) + result = U + for j in xrange(2, 1+self.__iterations): + U = self.__prf(self.__passphrase, U) + result = strxor(result, U) + return result + + def hexread(self, octets): + """Read the specified number of octets. Return them as hexadecimal. + + Note that len(obj.hexread(n)) == 2*n. 
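    For example (a sketch only), two fresh streams derived from the same
    inputs satisfy:

        raw = PBKDF2("pass", "salt").read(16)
        hexed = PBKDF2("pass", "salt").hexread(16)
        assert hexed == b2a_hex(raw) and len(hexed) == 32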
+ """ + return b2a_hex(self.read(octets)) + + def _setup(self, passphrase, salt, iterations, prf): + # Sanity checks: + + # passphrase and salt must be str or unicode (in the latter + # case, we convert to UTF-8) + if isinstance(passphrase, unicode): + passphrase = passphrase.encode("UTF-8") + if not isinstance(passphrase, str): + raise TypeError("passphrase must be str or unicode") + if isinstance(salt, unicode): + salt = salt.encode("UTF-8") + if not isinstance(salt, str): + raise TypeError("salt must be str or unicode") + + # iterations must be an integer >= 1 + if not isinstance(iterations, (int, long)): + raise TypeError("iterations must be an integer") + if iterations < 1: + raise ValueError("iterations must be at least 1") + + # prf must be callable + if not callable(prf): + raise TypeError("prf must be callable") + + self.__passphrase = passphrase + self.__salt = salt + self.__iterations = iterations + self.__prf = prf + self.__blockNum = 0 + self.__buf = "" + self.closed = False + + def close(self): + """Close the stream.""" + if not self.closed: + del self.__passphrase + del self.__salt + del self.__iterations + del self.__prf + del self.__blockNum + del self.__buf + self.closed = True + +def crypt(word, salt=None, iterations=None): + """PBKDF2-based unix crypt(3) replacement. + + The number of iterations specified in the salt overrides the 'iterations' + parameter. + + The effective hash length is 192 bits. + """ + + # Generate a (pseudo-)random salt if the user hasn't provided one. + if salt is None: + salt = _makesalt() + + # salt must be a string or the us-ascii subset of unicode + if isinstance(salt, unicode): + salt = salt.encode("us-ascii") + if not isinstance(salt, str): + raise TypeError("salt must be a string") + + # word must be a string or unicode (in the latter case, we convert to UTF-8) + if isinstance(word, unicode): + word = word.encode("UTF-8") + if not isinstance(word, str): + raise TypeError("word must be a string or unicode") + + # Try to extract the real salt and iteration count from the salt + if salt.startswith("$p5k2$"): + (iterations, salt, dummy) = salt.split("$")[2:5] + if iterations == "": + iterations = 400 + else: + converted = int(iterations, 16) + if iterations != "%x" % converted: # lowercase hex, minimum digits + raise ValueError("Invalid salt") + iterations = converted + if not (iterations >= 1): + raise ValueError("Invalid salt") + + # Make sure the salt matches the allowed character set + allowed = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789./" + for ch in salt: + if ch not in allowed: + raise ValueError("Illegal character %r in salt" % (ch,)) + + if iterations is None or iterations == 400: + iterations = 400 + salt = "$p5k2$$" + salt + else: + salt = "$p5k2$%x$%s" % (iterations, salt) + rawhash = PBKDF2(word, salt, iterations).read(24) + return salt + "$" + b64encode(rawhash, "./") + +# Add crypt as a static method of the PBKDF2 class +# This makes it easier to do "from PBKDF2 import PBKDF2" and still use +# crypt. +PBKDF2.crypt = staticmethod(crypt) + +def _makesalt(): + """Return a 48-bit pseudorandom salt for crypt(). + + This function is not suitable for generating cryptographic secrets. 
+ """ + binarysalt = "".join([pack("@H", randint(0, 0xffff)) for i in range(3)]) + return b64encode(binarysalt, "./") + +def test_pbkdf2(): + """Module self-test""" + from binascii import a2b_hex + + # + # Test vectors from RFC 3962 + # + + # Test 1 + result = PBKDF2("password", "ATHENA.MIT.EDUraeburn", 1).read(16) + expected = a2b_hex("cdedb5281bb2f801565a1122b2563515") + if result != expected: + raise RuntimeError("self-test failed") + + # Test 2 + result = PBKDF2("password", "ATHENA.MIT.EDUraeburn", 1200).hexread(32) + expected = ("5c08eb61fdf71e4e4ec3cf6ba1f5512b" + "a7e52ddbc5e5142f708a31e2e62b1e13") + if result != expected: + raise RuntimeError("self-test failed") + + # Test 3 + result = PBKDF2("X"*64, "pass phrase equals block size", 1200).hexread(32) + expected = ("139c30c0966bc32ba55fdbf212530ac9" + "c5ec59f1a452f5cc9ad940fea0598ed1") + if result != expected: + raise RuntimeError("self-test failed") + + # Test 4 + result = PBKDF2("X"*65, "pass phrase exceeds block size", 1200).hexread(32) + expected = ("9ccad6d468770cd51b10e6a68721be61" + "1a8b4d282601db3b36be9246915ec82a") + if result != expected: + raise RuntimeError("self-test failed") + + # + # Other test vectors + # + + # Chunked read + f = PBKDF2("kickstart", "workbench", 256) + result = f.read(17) + result += f.read(17) + result += f.read(1) + result += f.read(2) + result += f.read(3) + expected = PBKDF2("kickstart", "workbench", 256).read(40) + if result != expected: + raise RuntimeError("self-test failed") + + # + # crypt() test vectors + # + + # crypt 1 + result = crypt("cloadm", "exec") + expected = '$p5k2$$exec$r1EWMCMk7Rlv3L/RNcFXviDefYa0hlql' + if result != expected: + raise RuntimeError("self-test failed") + + # crypt 2 + result = crypt("gnu", '$p5k2$c$u9HvcT4d$.....') + expected = '$p5k2$c$u9HvcT4d$Sd1gwSVCLZYAuqZ25piRnbBEoAesaa/g' + if result != expected: + raise RuntimeError("self-test failed") + + # crypt 3 + result = crypt("dcl", "tUsch7fU", iterations=13) + expected = "$p5k2$d$tUsch7fU$nqDkaxMDOFBeJsTSfABsyn.PYUXilHwL" + if result != expected: + raise RuntimeError("self-test failed") + + # crypt 4 (unicode) + result = crypt(u'\u0399\u03c9\u03b1\u03bd\u03bd\u03b7\u03c2', + '$p5k2$$KosHgqNo$9mjN8gqjt02hDoP0c2J0ABtLIwtot8cQ') + expected = '$p5k2$$KosHgqNo$9mjN8gqjt02hDoP0c2J0ABtLIwtot8cQ' + if result != expected: + raise RuntimeError("self-test failed") + +if __name__ == '__main__': + test_pbkdf2() + +# vim:set ts=4 sw=4 sts=4 expandtab: diff --git a/pyload/lib/beaker/crypto/pycrypto.py b/pyload/lib/beaker/crypto/pycrypto.py new file mode 100644 index 000000000..a3eb4d9db --- /dev/null +++ b/pyload/lib/beaker/crypto/pycrypto.py @@ -0,0 +1,31 @@ +"""Encryption module that uses pycryptopp or pycrypto""" +try: + # Pycryptopp is preferred over Crypto because Crypto has had + # various periods of not being maintained, and pycryptopp uses + # the Crypto++ library which is generally considered the 'gold standard' + # of crypto implementations + from pycryptopp.cipher import aes + + def aesEncrypt(data, key): + cipher = aes.AES(key) + return cipher.process(data) + + # magic. 
+ aesDecrypt = aesEncrypt + +except ImportError: + from Crypto.Cipher import AES + + def aesEncrypt(data, key): + cipher = AES.new(key) + + data = data + (" " * (16 - (len(data) % 16))) + return cipher.encrypt(data) + + def aesDecrypt(data, key): + cipher = AES.new(key) + + return cipher.decrypt(data).rstrip() + +def getKeyLength(): + return 32 diff --git a/pyload/lib/beaker/crypto/util.py b/pyload/lib/beaker/crypto/util.py new file mode 100644 index 000000000..d97e8ce6f --- /dev/null +++ b/pyload/lib/beaker/crypto/util.py @@ -0,0 +1,30 @@ +from warnings import warn +from beaker import util + + +try: + # Use PyCrypto (if available) + from Crypto.Hash import HMAC as hmac, SHA as hmac_sha1 + sha1 = hmac_sha1.new + +except ImportError: + + # PyCrypto not available. Use the Python standard library. + import hmac + + # When using the stdlib, we have to make sure the hmac version and sha + # version are compatible + if util.py24: + from sha import sha as sha1 + import sha as hmac_sha1 + else: + # NOTE: We have to use the callable with hashlib (hashlib.sha1), + # otherwise hmac only accepts the sha module object itself + from hashlib import sha1 + hmac_sha1 = sha1 + + +if util.py24: + from md5 import md5 +else: + from hashlib import md5 diff --git a/pyload/lib/beaker/exceptions.py b/pyload/lib/beaker/exceptions.py new file mode 100644 index 000000000..cc0eed286 --- /dev/null +++ b/pyload/lib/beaker/exceptions.py @@ -0,0 +1,24 @@ +"""Beaker exception classes""" + +class BeakerException(Exception): + pass + + +class CreationAbortedError(Exception): + """Deprecated.""" + + +class InvalidCacheBackendError(BeakerException, ImportError): + pass + + +class MissingCacheParameter(BeakerException): + pass + + +class LockError(BeakerException): + pass + + +class InvalidCryptoBackendError(BeakerException): + pass diff --git a/pyload/lib/beaker/ext/__init__.py b/pyload/lib/beaker/ext/__init__.py new file mode 100644 index 000000000..e69de29bb --- /dev/null +++ b/pyload/lib/beaker/ext/__init__.py diff --git a/pyload/lib/beaker/ext/database.py b/pyload/lib/beaker/ext/database.py new file mode 100644 index 000000000..701e6f7d2 --- /dev/null +++ b/pyload/lib/beaker/ext/database.py @@ -0,0 +1,165 @@ +import cPickle +import logging +import pickle +from datetime import datetime + +from beaker.container import OpenResourceNamespaceManager, Container +from beaker.exceptions import InvalidCacheBackendError, MissingCacheParameter +from beaker.synchronization import file_synchronizer, null_synchronizer +from beaker.util import verify_directory, SyncDict + +log = logging.getLogger(__name__) + +sa = None +pool = None +types = None + +class DatabaseNamespaceManager(OpenResourceNamespaceManager): + metadatas = SyncDict() + tables = SyncDict() + + @classmethod + def _init_dependencies(cls): + global sa, pool, types + if sa is not None: + return + try: + import sqlalchemy as sa + import sqlalchemy.pool as pool + from sqlalchemy import types + except ImportError: + raise InvalidCacheBackendError("Database cache backend requires " + "the 'sqlalchemy' library") + + def __init__(self, namespace, url=None, sa_opts=None, optimistic=False, + table_name='beaker_cache', data_dir=None, lock_dir=None, + **params): + """Creates a database namespace manager + + ``url`` + SQLAlchemy compliant db url + ``sa_opts`` + A dictionary of SQLAlchemy keyword options to initialize the engine + with. 
+ ``optimistic`` + Use optimistic session locking, note that this will result in an + additional select when updating a cache value to compare version + numbers. + ``table_name`` + The table name to use in the database for the cache. + """ + OpenResourceNamespaceManager.__init__(self, namespace) + + if sa_opts is None: + sa_opts = params + + if lock_dir: + self.lock_dir = lock_dir + elif data_dir: + self.lock_dir = data_dir + "/container_db_lock" + if self.lock_dir: + verify_directory(self.lock_dir) + + # Check to see if the table's been created before + url = url or sa_opts['sa.url'] + table_key = url + table_name + def make_cache(): + # Check to see if we have a connection pool open already + meta_key = url + table_name + def make_meta(): + # SQLAlchemy pops the url, this ensures it sticks around + # later + sa_opts['sa.url'] = url + engine = sa.engine_from_config(sa_opts, 'sa.') + meta = sa.MetaData() + meta.bind = engine + return meta + meta = DatabaseNamespaceManager.metadatas.get(meta_key, make_meta) + # Create the table object and cache it now + cache = sa.Table(table_name, meta, + sa.Column('id', types.Integer, primary_key=True), + sa.Column('namespace', types.String(255), nullable=False), + sa.Column('accessed', types.DateTime, nullable=False), + sa.Column('created', types.DateTime, nullable=False), + sa.Column('data', types.PickleType, nullable=False), + sa.UniqueConstraint('namespace') + ) + cache.create(checkfirst=True) + return cache + self.hash = {} + self._is_new = False + self.loaded = False + self.cache = DatabaseNamespaceManager.tables.get(table_key, make_cache) + + def get_access_lock(self): + return null_synchronizer() + + def get_creation_lock(self, key): + return file_synchronizer( + identifier ="databasecontainer/funclock/%s" % self.namespace, + lock_dir = self.lock_dir) + + def do_open(self, flags): + # If we already loaded the data, don't bother loading it again + if self.loaded: + self.flags = flags + return + + cache = self.cache + result = sa.select([cache.c.data], + cache.c.namespace==self.namespace + ).execute().fetchone() + if not result: + self._is_new = True + self.hash = {} + else: + self._is_new = False + try: + self.hash = result['data'] + except (IOError, OSError, EOFError, cPickle.PickleError, + pickle.PickleError): + log.debug("Couldn't load pickle data, creating new storage") + self.hash = {} + self._is_new = True + self.flags = flags + self.loaded = True + + def do_close(self): + if self.flags is not None and (self.flags == 'c' or self.flags == 'w'): + cache = self.cache + if self._is_new: + cache.insert().execute(namespace=self.namespace, data=self.hash, + accessed=datetime.now(), + created=datetime.now()) + self._is_new = False + else: + cache.update(cache.c.namespace==self.namespace).execute( + data=self.hash, accessed=datetime.now()) + self.flags = None + + def do_remove(self): + cache = self.cache + cache.delete(cache.c.namespace==self.namespace).execute() + self.hash = {} + + # We can retain the fact that we did a load attempt, but since the + # file is gone this will be a new namespace should it be saved.
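# Editor's note: a hedged sketch of how this backend is typically reached
# from beaker configuration (not part of the diff; 'ext:database' is beaker's
# type name, while the sqlite URL and lock_dir values are illustrative):
#
#     from beaker.cache import CacheManager
#     from beaker.util import parse_cache_config_options
#
#     manager = CacheManager(**parse_cache_config_options({
#         'cache.type': 'ext:database',
#         'cache.url': 'sqlite:///beaker.cache.db',
#         'cache.lock_dir': '/tmp/cache/lock',
#     }))
#     cache = manager.get_cache('my_namespace')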
+ self._is_new = True + + def __getitem__(self, key): + return self.hash[key] + + def __contains__(self, key): + return self.hash.has_key(key) + + def __setitem__(self, key, value): + self.hash[key] = value + + def __delitem__(self, key): + del self.hash[key] + + def keys(self): + return self.hash.keys() + +class DatabaseContainer(Container): + namespace_manager = DatabaseNamespaceManager diff --git a/pyload/lib/beaker/ext/google.py b/pyload/lib/beaker/ext/google.py new file mode 100644 index 000000000..dd8380d7f --- /dev/null +++ b/pyload/lib/beaker/ext/google.py @@ -0,0 +1,120 @@ +import cPickle +import logging +from datetime import datetime + +from beaker.container import OpenResourceNamespaceManager, Container +from beaker.exceptions import InvalidCacheBackendError +from beaker.synchronization import null_synchronizer + +log = logging.getLogger(__name__) + +db = None + +class GoogleNamespaceManager(OpenResourceNamespaceManager): + tables = {} + + @classmethod + def _init_dependencies(cls): + global db + if db is not None: + return + try: + db = __import__('google.appengine.ext.db').appengine.ext.db + except ImportError: + raise InvalidCacheBackendError("Datastore cache backend requires the " + "'google.appengine.ext' library") + + def __init__(self, namespace, table_name='beaker_cache', **params): + """Creates a datastore namespace manager""" + OpenResourceNamespaceManager.__init__(self, namespace) + + def make_cache(): + table_dict = dict(created=db.DateTimeProperty(), + accessed=db.DateTimeProperty(), + data=db.BlobProperty()) + table = type(table_name, (db.Model,), table_dict) + return table + self.table_name = table_name + self.cache = GoogleNamespaceManager.tables.setdefault(table_name, make_cache()) + self.hash = {} + self._is_new = False + self.loaded = False + self.log_debug = logging.DEBUG >= log.getEffectiveLevel() + + # Google wants namespaces to start with letters, change the namespace + # to start with a letter + self.namespace = 'p%s' % self.namespace + + def get_access_lock(self): + return null_synchronizer() + + def get_creation_lock(self, key): + # this is weird, should probably be present + return null_synchronizer() + + def do_open(self, flags): + # If we already loaded the data, don't bother loading it again + if self.loaded: + self.flags = flags + return + + item = self.cache.get_by_key_name(self.namespace) + + if not item: + self._is_new = True + self.hash = {} + else: + self._is_new = False + try: + self.hash = cPickle.loads(str(item.data)) + except (IOError, OSError, EOFError, cPickle.PickleError): + if self.log_debug: + log.debug("Couldn't load pickle data, creating new storage") + self.hash = {} + self._is_new = True + self.flags = flags + self.loaded = True + + def do_close(self): + if self.flags is not None and (self.flags == 'c' or self.flags == 'w'): + if self._is_new: + item = self.cache(key_name=self.namespace) + item.data = cPickle.dumps(self.hash) + item.created = datetime.now() + item.accessed = datetime.now() + item.put() + self._is_new = False + else: + item = self.cache.get_by_key_name(self.namespace) + item.data = cPickle.dumps(self.hash) + item.accessed = datetime.now() + item.put() + self.flags = None + + def do_remove(self): + item = self.cache.get_by_key_name(self.namespace) + item.delete() + self.hash = {} + + # We can retain the fact that we did a load attempt, but since the + # file is gone this will be a new namespace should it be saved.
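# Editor's note: on App Engine this manager is selected with the 'ext:google'
# type; a hedged sketch using beaker's own middleware (not part of the diff;
# the config keys follow beaker's 'session.' convention, and the table_name
# pass-through is an assumption):
#
#     from beaker.middleware import SessionMiddleware
#
#     app = SessionMiddleware(app, {'session.type': 'ext:google',
#                                   'session.table_name': 'beaker_cache'})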
+ self._is_new = True + + def __getitem__(self, key): + return self.hash[key] + + def __contains__(self, key): + return self.hash.has_key(key) + + def __setitem__(self, key, value): + self.hash[key] = value + + def __delitem__(self, key): + del self.hash[key] + + def keys(self): + return self.hash.keys() + + +class GoogleContainer(Container): + namespace_class = GoogleNamespaceManager diff --git a/pyload/lib/beaker/ext/memcached.py b/pyload/lib/beaker/ext/memcached.py new file mode 100644 index 000000000..96516953f --- /dev/null +++ b/pyload/lib/beaker/ext/memcached.py @@ -0,0 +1,82 @@ +from beaker.container import NamespaceManager, Container +from beaker.exceptions import InvalidCacheBackendError, MissingCacheParameter +from beaker.synchronization import file_synchronizer, null_synchronizer +from beaker.util import verify_directory, SyncDict +import warnings + +memcache = None + +class MemcachedNamespaceManager(NamespaceManager): + clients = SyncDict() + + @classmethod + def _init_dependencies(cls): + global memcache + if memcache is not None: + return + try: + import pylibmc as memcache + except ImportError: + try: + import cmemcache as memcache + warnings.warn("cmemcache is known to have serious " + "concurrency issues; consider using 'memcache' or 'pylibmc'") + except ImportError: + try: + import memcache + except ImportError: + raise InvalidCacheBackendError("Memcached cache backend requires either " + "the 'memcache' or 'cmemcache' library") + + def __init__(self, namespace, url=None, data_dir=None, lock_dir=None, **params): + NamespaceManager.__init__(self, namespace) + + if not url: + raise MissingCacheParameter("url is required") + + if lock_dir: + self.lock_dir = lock_dir + elif data_dir: + self.lock_dir = data_dir + "/container_mcd_lock" + if self.lock_dir: + verify_directory(self.lock_dir) + + self.mc = MemcachedNamespaceManager.clients.get(url, memcache.Client, url.split(';')) + + def get_creation_lock(self, key): + return file_synchronizer( + identifier="memcachedcontainer/funclock/%s" % self.namespace,lock_dir = self.lock_dir) + + def _format_key(self, key): + return self.namespace + '_' + key.replace(' ', '\302\267') + + def __getitem__(self, key): + return self.mc.get(self._format_key(key)) + + def __contains__(self, key): + value = self.mc.get(self._format_key(key)) + return value is not None + + def has_key(self, key): + return key in self + + def set_value(self, key, value, expiretime=None): + if expiretime: + self.mc.set(self._format_key(key), value, time=expiretime) + else: + self.mc.set(self._format_key(key), value) + + def __setitem__(self, key, value): + self.set_value(key, value) + + def __delitem__(self, key): + self.mc.delete(self._format_key(key)) + + def do_remove(self): + self.mc.flush_all() + + def keys(self): + raise NotImplementedError("Memcache caching does not support iteration of all cache keys") + +class MemcachedContainer(Container): + namespace_class = MemcachedNamespaceManager diff --git a/pyload/lib/beaker/ext/sqla.py b/pyload/lib/beaker/ext/sqla.py new file mode 100644 index 000000000..8c79633c1 --- /dev/null +++ b/pyload/lib/beaker/ext/sqla.py @@ -0,0 +1,133 @@ +import cPickle +import logging +import pickle +from datetime import datetime + +from beaker.container import OpenResourceNamespaceManager, Container +from beaker.exceptions import InvalidCacheBackendError, MissingCacheParameter +from beaker.synchronization import file_synchronizer, null_synchronizer +from beaker.util import verify_directory, SyncDict + + +log = logging.getLogger(__name__) 
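# Editor's note: like ext/database.py above, this module keeps SQLAlchemy out
# of import time: the module-level name starts as None and _init_dependencies()
# fills it in on first use. A minimal self-contained sketch of the same
# lazy-import pattern (illustrative names, not part of the diff):
#
#     lib = None
#
#     class Backend(object):
#         @classmethod
#         def _init_dependencies(cls):
#             global lib
#             if lib is None:
#                 import sqlite3 as lib   # binds the module global, not a local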
+ +sa = None + +class SqlaNamespaceManager(OpenResourceNamespaceManager): + binds = SyncDict() + tables = SyncDict() + + @classmethod + def _init_dependencies(cls): + global sa + if sa is not None: + return + try: + import sqlalchemy as sa + except ImportError: + raise InvalidCacheBackendError("SQLAlchemy, which is required by " + "this backend, is not installed") + + def __init__(self, namespace, bind, table, data_dir=None, lock_dir=None, + **kwargs): + """Create a namespace manager for use with a database table via + SQLAlchemy. + + ``bind`` + SQLAlchemy ``Engine`` or ``Connection`` object + + ``table`` + SQLAlchemy ``Table`` object in which to store namespace data. + This should usually be something created by ``make_cache_table``. + """ + OpenResourceNamespaceManager.__init__(self, namespace) + + if lock_dir: + self.lock_dir = lock_dir + elif data_dir: + self.lock_dir = data_dir + "/container_db_lock" + if self.lock_dir: + verify_directory(self.lock_dir) + + self.bind = self.__class__.binds.get(str(bind.url), lambda: bind) + self.table = self.__class__.tables.get('%s:%s' % (bind.url, table.name), + lambda: table) + self.hash = {} + self._is_new = False + self.loaded = False + + def get_access_lock(self): + return null_synchronizer() + + def get_creation_lock(self, key): + return file_synchronizer( + identifier ="databasecontainer/funclock/%s" % self.namespace, + lock_dir=self.lock_dir) + + def do_open(self, flags): + if self.loaded: + self.flags = flags + return + select = sa.select([self.table.c.data], + (self.table.c.namespace == self.namespace)) + result = self.bind.execute(select).fetchone() + if not result: + self._is_new = True + self.hash = {} + else: + self._is_new = False + try: + self.hash = result['data'] + except (IOError, OSError, EOFError, cPickle.PickleError, + pickle.PickleError): + log.debug("Couldn't load pickle data, creating new storage") + self.hash = {} + self._is_new = True + self.flags = flags + self.loaded = True + + def do_close(self): + if self.flags is not None and (self.flags == 'c' or self.flags == 'w'): + if self._is_new: + insert = self.table.insert() + self.bind.execute(insert, namespace=self.namespace, data=self.hash, + accessed=datetime.now(), created=datetime.now()) + self._is_new = False + else: + update = self.table.update(self.table.c.namespace == self.namespace) + self.bind.execute(update, data=self.hash, accessed=datetime.now()) + self.flags = None + + def do_remove(self): + delete = self.table.delete(self.table.c.namespace == self.namespace) + self.bind.execute(delete) + self.hash = {} + self._is_new = True + + def __getitem__(self, key): + return self.hash[key] + + def __contains__(self, key): + return self.hash.has_key(key) + + def __setitem__(self, key, value): + self.hash[key] = value + + def __delitem__(self, key): + del self.hash[key] + + def keys(self): + return self.hash.keys() + + +class SqlaContainer(Container): + namespace_manager = SqlaNamespaceManager + +def make_cache_table(metadata, table_name='beaker_cache'): + """Return a ``Table`` object suitable for storing cached values for the + namespace manager.
Do not create the table.""" + return sa.Table(table_name, metadata, + sa.Column('namespace', sa.String(255), primary_key=True), + sa.Column('accessed', sa.DateTime, nullable=False), + sa.Column('created', sa.DateTime, nullable=False), + sa.Column('data', sa.PickleType, nullable=False)) diff --git a/pyload/lib/beaker/middleware.py b/pyload/lib/beaker/middleware.py new file mode 100644 index 000000000..7ba88b37d --- /dev/null +++ b/pyload/lib/beaker/middleware.py @@ -0,0 +1,165 @@ +import warnings + +try: + from paste.registry import StackedObjectProxy + beaker_session = StackedObjectProxy(name="Beaker Session") + beaker_cache = StackedObjectProxy(name="Cache Manager") +except: + beaker_cache = None + beaker_session = None + +from beaker.cache import CacheManager +from beaker.session import Session, SessionObject +from beaker.util import coerce_cache_params, coerce_session_params, \ + parse_cache_config_options + + +class CacheMiddleware(object): + cache = beaker_cache + + def __init__(self, app, config=None, environ_key='beaker.cache', **kwargs): + """Initialize the Cache Middleware + + The Cache middleware will make a Cache instance available on + every request under the ``environ['beaker.cache']`` key by + default. The location in environ can be changed by setting + ``environ_key``. + + ``config`` + dict All settings should be prefixed by 'cache.'. This + method of passing variables is intended for Paste and other + setups that accumulate multiple component settings in a + single dictionary. If config contains *no cache. prefixed + args*, then *all* of the config options will be used to + initialize the Cache objects. + + ``environ_key`` + Location where the Cache instance will be keyed in the WSGI + environ + + ``**kwargs`` + All keyword arguments are assumed to be cache settings and + will override any settings found in ``config`` + + """ + self.app = app + config = config or {} + + self.options = {} + + # Update the options with the parsed config + self.options.update(parse_cache_config_options(config)) + + # Add any options from kwargs, but leave out the defaults this + # time + self.options.update( + parse_cache_config_options(kwargs, include_defaults=False)) + + # Assume all keys are intended for cache if none are prefixed with + # 'cache.' + if not self.options and config: + self.options = config + + self.options.update(kwargs) + self.cache_manager = CacheManager(**self.options) + self.environ_key = environ_key + + def __call__(self, environ, start_response): + if environ.get('paste.registry'): + if environ['paste.registry'].reglist: + environ['paste.registry'].register(self.cache, + self.cache_manager) + environ[self.environ_key] = self.cache_manager + return self.app(environ, start_response) + + +class SessionMiddleware(object): + session = beaker_session + + def __init__(self, wrap_app, config=None, environ_key='beaker.session', + **kwargs): + """Initialize the Session Middleware + + The Session middleware will make a lazy session instance + available on every request under the ``environ['beaker.session']`` + key by default. The location in environ can be changed by + setting ``environ_key``. + + ``config`` + dict All settings should be prefixed by 'session.'. This + method of passing variables is intended for Paste and other + setups that accumulate multiple component settings in a + single dictionary. If config contains *no session. prefixed + args*, then *all* of the config options will be used to + initialize the Session objects.
+ + ``environ_key`` + Location where the Session instance will be keyed in the WSGI + environ + + ``**kwargs`` + All keyword arguments are assumed to be session settings and + will override any settings found in ``config`` + + """ + config = config or {} + + # Load up the default params + self.options = dict(invalidate_corrupt=True, type=None, + data_dir=None, key='beaker.session.id', + timeout=None, secret=None, log_file=None) + + # Pull out any config args meant for beaker session, if there are any + for dct in [config, kwargs]: + for key, val in dct.iteritems(): + if key.startswith('beaker.session.'): + self.options[key[15:]] = val + if key.startswith('session.'): + self.options[key[8:]] = val + if key.startswith('session_'): + warnings.warn('Session options should start with session. ' + 'instead of session_.', DeprecationWarning, 2) + self.options[key[8:]] = val + + # Coerce and validate session params + coerce_session_params(self.options) + + # Assume all keys are intended for session if none are prefixed with + # 'session.' + if not self.options and config: + self.options = config + + self.options.update(kwargs) + self.wrap_app = wrap_app + self.environ_key = environ_key + + def __call__(self, environ, start_response): + session = SessionObject(environ, **self.options) + if environ.get('paste.registry'): + if environ['paste.registry'].reglist: + environ['paste.registry'].register(self.session, session) + environ[self.environ_key] = session + environ['beaker.get_session'] = self._get_session + + def session_start_response(status, headers, exc_info = None): + if session.accessed(): + session.persist() + if session.__dict__['_headers']['set_cookie']: + cookie = session.__dict__['_headers']['cookie_out'] + if cookie: + headers.append(('Set-cookie', cookie)) + return start_response(status, headers, exc_info) + return self.wrap_app(environ, session_start_response) + + def _get_session(self): + return Session({}, use_cookies=False, **self.options) + + +def session_filter_factory(global_conf, **kwargs): + def filter(app): + return SessionMiddleware(app, global_conf, **kwargs) + return filter + + +def session_filter_app_factory(app, global_conf, **kwargs): + return SessionMiddleware(app, global_conf, **kwargs) diff --git a/pyload/lib/beaker/session.py b/pyload/lib/beaker/session.py new file mode 100644 index 000000000..7d465530b --- /dev/null +++ b/pyload/lib/beaker/session.py @@ -0,0 +1,618 @@ +import Cookie +import os +import random +import time +from datetime import datetime, timedelta + +from beaker.crypto import hmac as HMAC, hmac_sha1 as SHA1, md5 +from beaker.util import pickle + +from beaker import crypto +from beaker.cache import clsmap +from beaker.exceptions import BeakerException, InvalidCryptoBackendError +from base64 import b64encode, b64decode + + +__all__ = ['SignedCookie', 'Session'] + +getpid = hasattr(os, 'getpid') and os.getpid or (lambda : '') + +class SignedCookie(Cookie.BaseCookie): + """Extends python cookie to give digital signature support""" + def __init__(self, secret, input=None): + self.secret = secret + Cookie.BaseCookie.__init__(self, input) + + def value_decode(self, val): + val = val.strip('"') + sig = HMAC.new(self.secret, val[40:], SHA1).hexdigest() + + # Avoid timing attacks + invalid_bits = 0 + input_sig = val[:40] + if len(sig) != len(input_sig): + return None, val + + for a, b in zip(sig, input_sig): + invalid_bits += a != b + + if invalid_bits: + return None, val + else: + return val[40:], val + + def value_encode(self, val): + sig = HMAC.new(self.secret,
val, SHA1).hexdigest() + return str(val), ("%s%s" % (sig, val)) + + +class Session(dict): + """Session object that uses container package for storage. + + ``key`` + The name the cookie should be set to. + ``timeout`` + How long session data is considered valid. This is used + regardless of the cookie being present or not to determine + whether session data is still valid. + ``cookie_domain`` + Domain to use for the cookie. + ``secure`` + Whether or not the cookie should only be sent over SSL. + """ + def __init__(self, request, id=None, invalidate_corrupt=False, + use_cookies=True, type=None, data_dir=None, + key='beaker.session.id', timeout=None, cookie_expires=True, + cookie_domain=None, secret=None, secure=False, + namespace_class=None, **namespace_args): + if not type: + if data_dir: + self.type = 'file' + else: + self.type = 'memory' + else: + self.type = type + + self.namespace_class = namespace_class or clsmap[self.type] + + self.namespace_args = namespace_args + + self.request = request + self.data_dir = data_dir + self.key = key + + self.timeout = timeout + self.use_cookies = use_cookies + self.cookie_expires = cookie_expires + + # Default cookie domain/path + self._domain = cookie_domain + self._path = '/' + self.was_invalidated = False + self.secret = secret + self.secure = secure + self.id = id + self.accessed_dict = {} + + if self.use_cookies: + cookieheader = request.get('cookie', '') + if secret: + try: + self.cookie = SignedCookie(secret, input=cookieheader) + except Cookie.CookieError: + self.cookie = SignedCookie(secret, input=None) + else: + self.cookie = Cookie.SimpleCookie(input=cookieheader) + + if not self.id and self.key in self.cookie: + self.id = self.cookie[self.key].value + + self.is_new = self.id is None + if self.is_new: + self._create_id() + self['_accessed_time'] = self['_creation_time'] = time.time() + else: + try: + self.load() + except: + if invalidate_corrupt: + self.invalidate() + else: + raise + + def _create_id(self): + self.id = md5( + md5("%f%s%f%s" % (time.time(), id({}), random.random(), + getpid())).hexdigest(), + ).hexdigest() + self.is_new = True + self.last_accessed = None + if self.use_cookies: + self.cookie[self.key] = self.id + if self._domain: + self.cookie[self.key]['domain'] = self._domain + if self.secure: + self.cookie[self.key]['secure'] = True + self.cookie[self.key]['path'] = self._path + if self.cookie_expires is not True: + if self.cookie_expires is False: + expires = datetime.fromtimestamp( 0x7FFFFFFF ) + elif isinstance(self.cookie_expires, timedelta): + expires = datetime.today() + self.cookie_expires + elif isinstance(self.cookie_expires, datetime): + expires = self.cookie_expires + else: + raise ValueError("Invalid argument for cookie_expires: %s" + % repr(self.cookie_expires)) + self.cookie[self.key]['expires'] = \ + expires.strftime("%a, %d-%b-%Y %H:%M:%S GMT" ) + self.request['cookie_out'] = self.cookie[self.key].output(header='') + self.request['set_cookie'] = False + + def created(self): + return self['_creation_time'] + created = property(created) + + def _set_domain(self, domain): + self['_domain'] = domain + self.cookie[self.key]['domain'] = domain + self.request['cookie_out'] = self.cookie[self.key].output(header='') + self.request['set_cookie'] = True + + def _get_domain(self): + return self._domain + + domain = property(_get_domain, _set_domain) + + def _set_path(self, path): + self['_path'] = path + self.cookie[self.key]['path'] = path + self.request['cookie_out'] = self.cookie[self.key].output(header='') + 
self.request['set_cookie'] = True + + def _get_path(self): + return self._path + + path = property(_get_path, _set_path) + + def _delete_cookie(self): + self.request['set_cookie'] = True + self.cookie[self.key] = self.id + if self._domain: + self.cookie[self.key]['domain'] = self._domain + if self.secure: + self.cookie[self.key]['secure'] = True + self.cookie[self.key]['path'] = '/' + expires = datetime.today().replace(year=2003) + self.cookie[self.key]['expires'] = \ + expires.strftime("%a, %d-%b-%Y %H:%M:%S GMT" ) + self.request['cookie_out'] = self.cookie[self.key].output(header='') + self.request['set_cookie'] = True + + def delete(self): + """Deletes the session from the persistent storage, and sends + an expired cookie out""" + if self.use_cookies: + self._delete_cookie() + self.clear() + + def invalidate(self): + """Invalidates this session, creates a new session id, returns + to the is_new state""" + self.clear() + self.was_invalidated = True + self._create_id() + self.load() + + def load(self): + "Loads the data from this session from persistent storage" + self.namespace = self.namespace_class(self.id, + data_dir=self.data_dir, digest_filenames=False, + **self.namespace_args) + now = time.time() + self.request['set_cookie'] = True + + self.namespace.acquire_read_lock() + timed_out = False + try: + self.clear() + try: + session_data = self.namespace['session'] + + # Memcached always returns a key, it's None when it's not + # present + if session_data is None: + session_data = { + '_creation_time':now, + '_accessed_time':now + } + self.is_new = True + except (KeyError, TypeError): + session_data = { + '_creation_time':now, + '_accessed_time':now + } + self.is_new = True + + if self.timeout is not None and \ + now - session_data['_accessed_time'] > self.timeout: + timed_out= True + else: + # Properly set the last_accessed time, which is different + # from the *currently* _accessed_time + if self.is_new or '_accessed_time' not in session_data: + self.last_accessed = None + else: + self.last_accessed = session_data['_accessed_time'] + + # Update the current _accessed_time + session_data['_accessed_time'] = now + self.update(session_data) + self.accessed_dict = session_data.copy() + finally: + self.namespace.release_read_lock() + if timed_out: + self.invalidate() + + def save(self, accessed_only=False): + """Saves the data for this session to persistent storage + + If accessed_only is True, then only the original data loaded + at the beginning of the request will be saved, with the updated + last accessed time. + + """ + # Look to see if it's a new session that was only accessed + # Don't save it in that case + if accessed_only and self.is_new: + return None + + if not hasattr(self, 'namespace'): + self.namespace = self.namespace_class( + self.id, + data_dir=self.data_dir, + digest_filenames=False, + **self.namespace_args) + + self.namespace.acquire_write_lock() + try: + if accessed_only: + data = dict(self.accessed_dict.items()) + else: + data = dict(self.items()) + + # Save the data + if not data and 'session' in self.namespace: + del self.namespace['session'] + else: + self.namespace['session'] = data + finally: + self.namespace.release_write_lock() + if self.is_new: + self.request['set_cookie'] = True + + def revert(self): + """Revert the session to its original state from its first + access in the request""" + self.clear() + self.update(self.accessed_dict) + + # TODO: I think both these methods should be removed.
They're from + # the original mod_python code I was ripping off but they really + # have no use here. + def lock(self): + """Locks this session against other processes/threads. This is + automatic when load/save is called. + + ***use with caution*** and always with a corresponding 'unlock' + inside a "finally:" block, as a stray lock typically cannot be + unlocked without shutting down the whole application. + + """ + self.namespace.acquire_write_lock() + + def unlock(self): + """Unlocks this session against other processes/threads. This + is automatic when load/save is called. + + ***use with caution*** and always within a "finally:" block, as + a stray lock typically cannot be unlocked without shutting down + the whole application. + + """ + self.namespace.release_write_lock() + +class CookieSession(Session): + """Pure cookie-based session + + Options recognized when using cookie-based sessions are slightly + more restricted than general sessions. + + ``key`` + The name the cookie should be set to. + ``timeout`` + How long session data is considered valid. This is used + regardless of the cookie being present or not to determine + whether session data is still valid. + ``encrypt_key`` + The key to use for the session encryption, if not provided the + session will not be encrypted. + ``validate_key`` + The key used to sign the encrypted session + ``cookie_domain`` + Domain to use for the cookie. + ``secure`` + Whether or not the cookie should only be sent over SSL. + + """ + def __init__(self, request, key='beaker.session.id', timeout=None, + cookie_expires=True, cookie_domain=None, encrypt_key=None, + validate_key=None, secure=False, **kwargs): + + if not crypto.has_aes and encrypt_key: + raise InvalidCryptoBackendError("No AES library is installed, can't generate " + "encrypted cookie-only Session.") + + self.request = request + self.key = key + self.timeout = timeout + self.cookie_expires = cookie_expires + self.encrypt_key = encrypt_key + self.validate_key = validate_key + self.request['set_cookie'] = False + self.secure = secure + self._domain = cookie_domain + self._path = '/' + + try: + cookieheader = request['cookie'] + except KeyError: + cookieheader = '' + + if validate_key is None: + raise BeakerException("No validate_key specified for Cookie only " + "Session.") + + try: + self.cookie = SignedCookie(validate_key, input=cookieheader) + except Cookie.CookieError: + self.cookie = SignedCookie(validate_key, input=None) + + self['_id'] = self._make_id() + self.is_new = True + + # If we have a cookie, load it + if self.key in self.cookie and self.cookie[self.key].value is not None: + self.is_new = False + try: + self.update(self._decrypt_data()) + except: + pass + if self.timeout is not None and time.time() - \ + self['_accessed_time'] > self.timeout: + self.clear() + self.accessed_dict = self.copy() + self._create_cookie() + + def created(self): + return self['_creation_time'] + created = property(created) + + def id(self): + return self['_id'] + id = property(id) + + def _set_domain(self, domain): + self['_domain'] = domain + self._domain = domain + + def _get_domain(self): + return self._domain + + domain = property(_get_domain, _set_domain) + + def _set_path(self, path): + self['_path'] = path + self._path = path + + def _get_path(self): + return self._path + + path = property(_get_path, _set_path) + + def _encrypt_data(self): + """Serialize, encipher, and base64 the session dict""" + if self.encrypt_key: + nonce = b64encode(os.urandom(40))[:8] + encrypt_key =
crypto.generateCryptoKeys(self.encrypt_key, + self.validate_key + nonce, 1) + data = pickle.dumps(self.copy(), 2) + return nonce + b64encode(crypto.aesEncrypt(data, encrypt_key)) + else: + data = pickle.dumps(self.copy(), 2) + return b64encode(data) + + def _decrypt_data(self): + """Base64-decode, decipher, then un-serialize the data for the session + dict""" + if self.encrypt_key: + nonce = self.cookie[self.key].value[:8] + encrypt_key = crypto.generateCryptoKeys(self.encrypt_key, + self.validate_key + nonce, 1) + payload = b64decode(self.cookie[self.key].value[8:]) + data = crypto.aesDecrypt(payload, encrypt_key) + return pickle.loads(data) + else: + data = b64decode(self.cookie[self.key].value) + return pickle.loads(data) + + def _make_id(self): + return md5(md5( + "%f%s%f%s" % (time.time(), id({}), random.random(), getpid()) + ).hexdigest() + ).hexdigest() + + def save(self, accessed_only=False): + """Saves the data for this session to persistent storage""" + if accessed_only and self.is_new: + return + if accessed_only: + self.clear() + self.update(self.accessed_dict) + self._create_cookie() + + def expire(self): + """Delete the 'expires' attribute on this Session, if any.""" + + self.pop('_expires', None) + + def _create_cookie(self): + if '_creation_time' not in self: + self['_creation_time'] = time.time() + if '_id' not in self: + self['_id'] = self._make_id() + self['_accessed_time'] = time.time() + + if self.cookie_expires is not True: + if self.cookie_expires is False: + expires = datetime.fromtimestamp( 0x7FFFFFFF ) + elif isinstance(self.cookie_expires, timedelta): + expires = datetime.today() + self.cookie_expires + elif isinstance(self.cookie_expires, datetime): + expires = self.cookie_expires + else: + raise ValueError("Invalid argument for cookie_expires: %s" + % repr(self.cookie_expires)) + self['_expires'] = expires + elif '_expires' in self: + expires = self['_expires'] + else: + expires = None + + val = self._encrypt_data() + if len(val) > 4064: + raise BeakerException("Cookie value is too long to store") + + self.cookie[self.key] = val + if '_domain' in self: + self.cookie[self.key]['domain'] = self['_domain'] + elif self._domain: + self.cookie[self.key]['domain'] = self._domain + if self.secure: + self.cookie[self.key]['secure'] = True + + self.cookie[self.key]['path'] = self.get('_path', '/') + + if expires: + self.cookie[self.key]['expires'] = \ + expires.strftime("%a, %d-%b-%Y %H:%M:%S GMT" ) + self.request['cookie_out'] = self.cookie[self.key].output(header='') + self.request['set_cookie'] = True + + def delete(self): + """Delete the cookie, and clear the session""" + # Send a delete cookie request + self._delete_cookie() + self.clear() + + def invalidate(self): + """Clear the contents and start a new session""" + self.delete() + self['_id'] = self._make_id() + + +class SessionObject(object): + """Session proxy/lazy creator + + This object proxies access to the actual session object, so that in + the case that the session hasn't been used before, it will be + set up. This avoids creating and loading the session from persistent + storage unless it's actually used during the request.
+ + """ + def __init__(self, environ, **params): + self.__dict__['_params'] = params + self.__dict__['_environ'] = environ + self.__dict__['_sess'] = None + self.__dict__['_headers'] = [] + + def _session(self): + """Lazy initial creation of session object""" + if self.__dict__['_sess'] is None: + params = self.__dict__['_params'] + environ = self.__dict__['_environ'] + self.__dict__['_headers'] = req = {'cookie_out':None} + req['cookie'] = environ.get('HTTP_COOKIE') + if params.get('type') == 'cookie': + self.__dict__['_sess'] = CookieSession(req, **params) + else: + self.__dict__['_sess'] = Session(req, use_cookies=True, + **params) + return self.__dict__['_sess'] + + def __getattr__(self, attr): + return getattr(self._session(), attr) + + def __setattr__(self, attr, value): + setattr(self._session(), attr, value) + + def __delattr__(self, name): + self._session().__delattr__(name) + + def __getitem__(self, key): + return self._session()[key] + + def __setitem__(self, key, value): + self._session()[key] = value + + def __delitem__(self, key): + self._session().__delitem__(key) + + def __repr__(self): + return self._session().__repr__() + + def __iter__(self): + """Only works for proxying to a dict""" + return iter(self._session().keys()) + + def __contains__(self, key): + return self._session().has_key(key) + + def get_by_id(self, id): + """Loads a session given a session ID""" + params = self.__dict__['_params'] + session = Session({}, use_cookies=False, id=id, **params) + if session.is_new: + return None + return session + + def save(self): + self.__dict__['_dirty'] = True + + def delete(self): + self.__dict__['_dirty'] = True + self._session().delete() + + def persist(self): + """Persist the session to the storage + + If it's set to autosave, then the entire session will be saved + regardless of whether save() has been called. Otherwise, just the + accessed time will be updated if save() was not called, or + the session will be saved if save() was called. + + """ + if self.__dict__['_params'].get('auto'): + self._session().save() + else: + if self.__dict__.get('_dirty'): + self._session().save() + else: + self._session().save(accessed_only=True) + + def dirty(self): + return self.__dict__.get('_dirty', False) + + def accessed(self): + """Returns whether or not the session has been accessed""" + return self.__dict__['_sess'] is not None diff --git a/pyload/lib/beaker/synchronization.py b/pyload/lib/beaker/synchronization.py new file mode 100644 index 000000000..761303707 --- /dev/null +++ b/pyload/lib/beaker/synchronization.py @@ -0,0 +1,381 @@ +"""Synchronization functions. + +File- and mutex-based mutual exclusion synchronizers are provided, +as well as a name-based mutex which locks within an application +based on a string name. + +""" + +import os +import sys +import tempfile + +try: + import threading as _threading +except ImportError: + import dummy_threading as _threading + +# check for fcntl module +try: + sys.getwindowsversion() + has_flock = False +except: + try: + import fcntl + has_flock = True + except ImportError: + has_flock = False + +from beaker import util +from beaker.exceptions import LockError + +__all__ = ["file_synchronizer", "mutex_synchronizer", "null_synchronizer", + "NameLock", "_threading"] + + +class NameLock(object): + """a proxy for an RLock object that is stored in a name based + registry. + + Multiple threads can get a reference to the same RLock based on the + name alone, and synchronize operations related to that name.
+ + """ + locks = util.WeakValuedRegistry() + + class NLContainer(object): + def __init__(self, reentrant): + if reentrant: + self.lock = _threading.RLock() + else: + self.lock = _threading.Lock() + def __call__(self): + return self.lock + + def __init__(self, identifier = None, reentrant = False): + if identifier is None: + self._lock = NameLock.NLContainer(reentrant) + else: + self._lock = NameLock.locks.get(identifier, NameLock.NLContainer, + reentrant) + + def acquire(self, wait = True): + return self._lock().acquire(wait) + + def release(self): + self._lock().release() + + +_synchronizers = util.WeakValuedRegistry() +def _synchronizer(identifier, cls, **kwargs): + return _synchronizers.sync_get((identifier, cls), cls, identifier, **kwargs) + + +def file_synchronizer(identifier, **kwargs): + if not has_flock or 'lock_dir' not in kwargs: + return mutex_synchronizer(identifier) + else: + return _synchronizer(identifier, FileSynchronizer, **kwargs) + + +def mutex_synchronizer(identifier, **kwargs): + return _synchronizer(identifier, ConditionSynchronizer, **kwargs) + + +class null_synchronizer(object): + def acquire_write_lock(self, wait=True): + return True + def acquire_read_lock(self): + pass + def release_write_lock(self): + pass + def release_read_lock(self): + pass + acquire = acquire_write_lock + release = release_write_lock + + +class SynchronizerImpl(object): + def __init__(self): + self._state = util.ThreadLocal() + + class SyncState(object): + __slots__ = 'reentrantcount', 'writing', 'reading' + + def __init__(self): + self.reentrantcount = 0 + self.writing = False + self.reading = False + + def state(self): + if not self._state.has(): + state = SynchronizerImpl.SyncState() + self._state.put(state) + return state + else: + return self._state.get() + state = property(state) + + def release_read_lock(self): + state = self.state + + if state.writing: + raise LockError("lock is in writing state") + if not state.reading: + raise LockError("lock is not in reading state") + + if state.reentrantcount == 1: + self.do_release_read_lock() + state.reading = False + + state.reentrantcount -= 1 + + def acquire_read_lock(self, wait = True): + state = self.state + + if state.writing: + raise LockError("lock is in writing state") + + if state.reentrantcount == 0: + x = self.do_acquire_read_lock(wait) + if (wait or x): + state.reentrantcount += 1 + state.reading = True + return x + elif state.reading: + state.reentrantcount += 1 + return True + + def release_write_lock(self): + state = self.state + + if state.reading: + raise LockError("lock is in reading state") + if not state.writing: + raise LockError("lock is not in writing state") + + if state.reentrantcount == 1: + self.do_release_write_lock() + state.writing = False + + state.reentrantcount -= 1 + + release = release_write_lock + + def acquire_write_lock(self, wait = True): + state = self.state + + if state.reading: + raise LockError("lock is in reading state") + + if state.reentrantcount == 0: + x = self.do_acquire_write_lock(wait) + if (wait or x): + state.reentrantcount += 1 + state.writing = True + return x + elif state.writing: + state.reentrantcount += 1 + return True + + acquire = acquire_write_lock + + def do_release_read_lock(self): + raise NotImplementedError() + + def do_acquire_read_lock(self): + raise NotImplementedError() + + def do_release_write_lock(self): + raise NotImplementedError() + + def do_acquire_write_lock(self): + raise NotImplementedError() + + +class FileSynchronizer(SynchronizerImpl): + """a synchronizer which 
locks using flock(). + + Adapted for Python/multithreads from Apache::Session::Lock::File, + http://search.cpan.org/src/CWEST/Apache-Session-1.81/Session/Lock/File.pm + + This module does not unlink temporary files, + because it interferes with proper locking. This can cause + problems on certain systems (Linux) whose file systems (ext2) do not + perform well with lots of files in one directory. To prevent this + you should use a script to clean out old files from your lock directory. + + """ + def __init__(self, identifier, lock_dir): + super(FileSynchronizer, self).__init__() + self._filedescriptor = util.ThreadLocal() + + if lock_dir is None: + lock_dir = tempfile.gettempdir() + else: + lock_dir = lock_dir + + self.filename = util.encoded_path( + lock_dir, + [identifier], + extension='.lock' + ) + + def _filedesc(self): + return self._filedescriptor.get() + _filedesc = property(_filedesc) + + def _open(self, mode): + filedescriptor = self._filedesc + if filedescriptor is None: + filedescriptor = os.open(self.filename, mode) + self._filedescriptor.put(filedescriptor) + return filedescriptor + + def do_acquire_read_lock(self, wait): + filedescriptor = self._open(os.O_CREAT | os.O_RDONLY) + if not wait: + try: + fcntl.flock(filedescriptor, fcntl.LOCK_SH | fcntl.LOCK_NB) + return True + except IOError: + os.close(filedescriptor) + self._filedescriptor.remove() + return False + else: + fcntl.flock(filedescriptor, fcntl.LOCK_SH) + return True + + def do_acquire_write_lock(self, wait): + filedescriptor = self._open(os.O_CREAT | os.O_WRONLY) + if not wait: + try: + fcntl.flock(filedescriptor, fcntl.LOCK_EX | fcntl.LOCK_NB) + return True + except IOError: + os.close(filedescriptor) + self._filedescriptor.remove() + return False + else: + fcntl.flock(filedescriptor, fcntl.LOCK_EX) + return True + + def do_release_read_lock(self): + self._release_all_locks() + + def do_release_write_lock(self): + self._release_all_locks() + + def _release_all_locks(self): + filedescriptor = self._filedesc + if filedescriptor is not None: + fcntl.flock(filedescriptor, fcntl.LOCK_UN) + os.close(filedescriptor) + self._filedescriptor.remove() + + +class ConditionSynchronizer(SynchronizerImpl): + """a synchronizer using a Condition.""" + + def __init__(self, identifier): + super(ConditionSynchronizer, self).__init__() + + # counts how many asynchronous methods are executing + self.async = 0 + + # pointer to thread that is the current sync operation + self.current_sync_operation = None + + # condition object to lock on + self.condition = _threading.Condition(_threading.Lock()) + + def do_acquire_read_lock(self, wait = True): + self.condition.acquire() + try: + # see if a synchronous operation is waiting to start + # or is already running, in which case we wait (or just + # give up and return) + if wait: + while self.current_sync_operation is not None: + self.condition.wait() + else: + if self.current_sync_operation is not None: + return False + + self.async += 1 + finally: + self.condition.release() + + if not wait: + return True + + def do_release_read_lock(self): + self.condition.acquire() + try: + self.async -= 1 + + # check if we are the last asynchronous reader thread + # out the door. + if self.async == 0: + # yes. 
so if a sync operation is waiting, notifyAll to wake + # it up + if self.current_sync_operation is not None: + self.condition.notifyAll() + elif self.async < 0: + raise LockError("Synchronizer error - too many " + "release_read_locks called") + finally: + self.condition.release() + + def do_acquire_write_lock(self, wait = True): + self.condition.acquire() + try: + # here, we are not a synchronous reader, and after returning, + # assuming waiting or immediate availability, we will be. + + if wait: + # if another sync is working, wait + while self.current_sync_operation is not None: + self.condition.wait() + else: + # if another sync is working, + # we don't want to wait, so forget it + if self.current_sync_operation is not None: + return False + + # establish ourselves as the current sync + # this indicates to other read/write operations + # that they should wait until this is None again + self.current_sync_operation = _threading.currentThread() + + # now wait again for asyncs to finish + if self.async > 0: + if wait: + # wait + self.condition.wait() + else: + # we don't want to wait, so forget it + self.current_sync_operation = None + return False + finally: + self.condition.release() + + if not wait: + return True + + def do_release_write_lock(self): + self.condition.acquire() + try: + if self.current_sync_operation is not _threading.currentThread(): + raise LockError("Synchronizer error - current thread doesn't " + "have the write lock") + + # reset the current sync operation so + # another can get it + self.current_sync_operation = None + + # tell everyone to get ready + self.condition.notifyAll() + finally: + # everyone go !! + self.condition.release() diff --git a/pyload/lib/beaker/util.py b/pyload/lib/beaker/util.py new file mode 100644 index 000000000..04c9617c5 --- /dev/null +++ b/pyload/lib/beaker/util.py @@ -0,0 +1,302 @@ +"""Beaker utilities""" + +try: + import thread as _thread + import threading as _threading +except ImportError: + import dummy_thread as _thread + import dummy_threading as _threading + +from datetime import datetime, timedelta +import os +import string +import types +import weakref +import warnings +import sys + +py3k = getattr(sys, 'py3kwarning', False) or sys.version_info >= (3, 0) +py24 = sys.version_info < (2,5) +jython = sys.platform.startswith('java') + +if py3k or jython: + import pickle +else: + import cPickle as pickle + +from beaker.converters import asbool +from threading import local as _tlocal + + +__all__ = ["ThreadLocal", "Registry", "WeakValuedRegistry", "SyncDict", + "encoded_path", "verify_directory"] + + +def verify_directory(dir): + """Verifies and creates a directory. Tries to + ignore collisions with other threads and processes.""" + + tries = 0 + while not os.access(dir, os.F_OK): + try: + tries += 1 + os.makedirs(dir) + except: + if tries > 5: + raise + + +def deprecated(message): + def wrapper(fn): + def deprecated_method(*args, **kargs): + warnings.warn(message, DeprecationWarning, 2) + return fn(*args, **kargs) + # TODO: use decorator ? functools.wrapper ?
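# Editor's note: on Python >= 2.5 the TODO's answer would be functools.wraps,
# which copies __name__ and __doc__ in one step (though it would not prepend
# the deprecation message to __doc__ as the manual code below does); the
# manual copying is presumably kept for Python 2.4 support. A hedged sketch:
#
#     import functools, warnings
#
#     def deprecated(message):
#         def wrapper(fn):
#             @functools.wraps(fn)
#             def deprecated_method(*args, **kargs):
#                 warnings.warn(message, DeprecationWarning, 2)
#                 return fn(*args, **kargs)
#             return deprecated_method
#         return wrapper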
+ deprecated_method.__name__ = fn.__name__ + deprecated_method.__doc__ = "%s\n\n%s" % (message, fn.__doc__) + return deprecated_method + return wrapper + +class ThreadLocal(object): + """stores a value on a per-thread basis""" + + __slots__ = '_tlocal' + + def __init__(self): + self._tlocal = _tlocal() + + def put(self, value): + self._tlocal.value = value + + def has(self): + return hasattr(self._tlocal, 'value') + + def get(self, default=None): + return getattr(self._tlocal, 'value', default) + + def remove(self): + del self._tlocal.value + +class SyncDict(object): + """ + An efficient/threadsafe singleton map algorithm, a.k.a. + "get a value based on this key, and create if not found or not + valid" paradigm: + + exists && isvalid ? get : create + + Designed to work with weakref dictionaries to expect items + to asynchronously disappear from the dictionary. + + Use python 2.3.3 or greater ! a major bug was just fixed in Nov. + 2003 that was driving me nuts with garbage collection/weakrefs in + this section. + + """ + def __init__(self): + self.mutex = _thread.allocate_lock() + self.dict = {} + + def get(self, key, createfunc, *args, **kwargs): + try: + if self.has_key(key): + return self.dict[key] + else: + return self.sync_get(key, createfunc, *args, **kwargs) + except KeyError: + return self.sync_get(key, createfunc, *args, **kwargs) + + def sync_get(self, key, createfunc, *args, **kwargs): + self.mutex.acquire() + try: + try: + if self.has_key(key): + return self.dict[key] + else: + return self._create(key, createfunc, *args, **kwargs) + except KeyError: + return self._create(key, createfunc, *args, **kwargs) + finally: + self.mutex.release() + + def _create(self, key, createfunc, *args, **kwargs): + self[key] = obj = createfunc(*args, **kwargs) + return obj + + def has_key(self, key): + return self.dict.has_key(key) + + def __contains__(self, key): + return self.dict.__contains__(key) + def __getitem__(self, key): + return self.dict.__getitem__(key) + def __setitem__(self, key, value): + self.dict.__setitem__(key, value) + def __delitem__(self, key): + return self.dict.__delitem__(key) + def clear(self): + self.dict.clear() + + +class WeakValuedRegistry(SyncDict): + def __init__(self): + self.mutex = _threading.RLock() + self.dict = weakref.WeakValueDictionary() + +sha1 = None +def encoded_path(root, identifiers, extension = ".enc", depth = 3, + digest_filenames=True): + + """Generate a unique file-accessible path from the given list of + identifiers starting at the given root directory.""" + ident = "_".join(identifiers) + + global sha1 + if sha1 is None: + from beaker.crypto import sha1 + + if digest_filenames: + if py3k: + ident = sha1(ident.encode('utf-8')).hexdigest() + else: + ident = sha1(ident).hexdigest() + + ident = os.path.basename(ident) + + tokens = [] + for d in range(1, depth): + tokens.append(ident[0:d]) + + dir = os.path.join(root, *tokens) + verify_directory(dir) + + return os.path.join(dir, ident + extension) + + +def verify_options(opt, types, error): + if not isinstance(opt, types): + if not isinstance(types, tuple): + types = (types,) + coerced = False + for typ in types: + try: + if typ in (list, tuple): + opt = [x.strip() for x in opt.split(',')] + else: + if typ == bool: + typ = asbool + opt = typ(opt) + coerced = True + except: + pass + if coerced: + break + if not coerced: + raise Exception(error) + elif isinstance(opt, str) and not opt.strip(): + raise Exception("Empty strings are invalid for: %s" % error) + return opt + + +def verify_rules(params, 
ruleset): + for key, types, message in ruleset: + if key in params: + params[key] = verify_options(params[key], types, message) + return params + + +def coerce_session_params(params): + rules = [ + ('data_dir', (str, types.NoneType), "data_dir must be a string " + "referring to a directory."), + ('lock_dir', (str, types.NoneType), "lock_dir must be a string referring to a " + "directory."), + ('type', (str, types.NoneType), "Session type must be a string."), + ('cookie_expires', (bool, datetime, timedelta), "Cookie expires was " + "not a boolean, datetime, or timedelta instance."), + ('cookie_domain', (str, types.NoneType), "Cookie domain must be a " + "string."), + ('id', (str,), "Session id must be a string."), + ('key', (str,), "Session key must be a string."), + ('secret', (str, types.NoneType), "Session secret must be a string."), + ('validate_key', (str, types.NoneType), "Session validate_key must be " + "a string."), + ('encrypt_key', (str, types.NoneType), "Session encrypt_key must be " + "a string."), + ('secure', (bool, types.NoneType), "Session secure must be a boolean."), + ('timeout', (int, types.NoneType), "Session timeout must be an " + "integer."), + ('auto', (bool, types.NoneType), "Session is created if accessed."), + ] + return verify_rules(params, rules) + + +def coerce_cache_params(params): + rules = [ + ('data_dir', (str, types.NoneType), "data_dir must be a string " + "referring to a directory."), + ('lock_dir', (str, types.NoneType), "lock_dir must be a string referring to a " + "directory."), + ('type', (str,), "Cache type must be a string."), + ('enabled', (bool, types.NoneType), "enabled must be true/false " + "if present."), + ('expire', (int, types.NoneType), "expire must be an integer representing " + "how many seconds the cache is valid for"), + ('regions', (list, tuple, types.NoneType), "Regions must be a " + "comma separated list of valid regions") + ] + return verify_rules(params, rules) + + +def parse_cache_config_options(config, include_defaults=True): + """Parse configuration options and validate for use with the + CacheManager""" + + # Load default cache options + if include_defaults: + options= dict(type='memory', data_dir=None, expire=None, + log_file=None) + else: + options = {} + for key, val in config.iteritems(): + if key.startswith('beaker.cache.'): + options[key[13:]] = val + if key.startswith('cache.'): + options[key[6:]] = val + coerce_cache_params(options) + + # Set cache to enabled if not turned off + if 'enabled' not in options: + options['enabled'] = True + + # Configure region dict if regions are available + regions = options.pop('regions', None) + if regions: + region_configs = {} + for region in regions: + # Setup the default cache options + region_options = dict(data_dir=options.get('data_dir'), + lock_dir=options.get('lock_dir'), + type=options.get('type'), + enabled=options['enabled'], + expire=options.get('expire')) + region_len = len(region) + 1 + for key in options.keys(): + if key.startswith('%s.'
% region): + region_options[key[region_len:]] = options.pop(key) + coerce_cache_params(region_options) + region_configs[region] = region_options + options['cache_regions'] = region_configs + return options + +def func_namespace(func): + """Generates a unique namespace for a function""" + kls = None + if hasattr(func, 'im_func'): + kls = func.im_class + func = func.im_func + + if kls: + return '%s.%s' % (kls.__module__, kls.__name__) + else: + return '%s.%s' % (func.__module__, func.__name__) diff --git a/pyload/lib/bottle.py b/pyload/lib/bottle.py new file mode 100644 index 000000000..b00bda1c9 --- /dev/null +++ b/pyload/lib/bottle.py @@ -0,0 +1,3251 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +""" +Bottle is a fast and simple micro-framework for small web applications. It +offers request dispatching (Routes) with url parameter support, templates, +a built-in HTTP Server and adapters for many third party WSGI/HTTP-server and +template engines - all in a single file and with no dependencies other than the +Python Standard Library. + +Homepage and documentation: http://bottlepy.org/ + +Copyright (c) 2012, Marcel Hellkamp. +License: MIT (see LICENSE for details) +""" + +from __future__ import with_statement + +__author__ = 'Marcel Hellkamp' +__version__ = '0.11.4' +__license__ = 'MIT' + +# The gevent server adapter needs to patch some modules before they are imported +# This is why we parse the commandline parameters here but handle them later +if __name__ == '__main__': + from optparse import OptionParser + _cmd_parser = OptionParser(usage="usage: %prog [options] package.module:app") + _opt = _cmd_parser.add_option + _opt("--version", action="store_true", help="show version number.") + _opt("-b", "--bind", metavar="ADDRESS", help="bind socket to ADDRESS.") + _opt("-s", "--server", default='wsgiref', help="use SERVER as backend.") + _opt("-p", "--plugin", action="append", help="install additional plugin/s.") + _opt("--debug", action="store_true", help="start server in debug mode.") + _opt("--reload", action="store_true", help="auto-reload on file changes.") + _cmd_options, _cmd_args = _cmd_parser.parse_args() + if _cmd_options.server and _cmd_options.server.startswith('gevent'): + import gevent.monkey; gevent.monkey.patch_all() + +import base64, cgi, email.utils, functools, hmac, imp, itertools, mimetypes,\ + os, re, subprocess, sys, tempfile, threading, time, urllib, warnings + +from datetime import date as datedate, datetime, timedelta +from tempfile import TemporaryFile +from traceback import format_exc, print_exc + +try: from json import dumps as json_dumps, loads as json_lds +except ImportError: # pragma: no cover + try: from simplejson import dumps as json_dumps, loads as json_lds + except ImportError: + try: from django.utils.simplejson import dumps as json_dumps, loads as json_lds + except ImportError: + def json_dumps(data): + raise ImportError("JSON support requires Python 2.6 or simplejson.") + json_lds = json_dumps + + + +# We now try to fix 2.5/2.6/3.1/3.2 incompatibilities. +# It ain't pretty but it works... Sorry for the mess. + +py = sys.version_info +py3k = py >= (3,0,0) +py25 = py < (2,6,0) +py31 = (3,1,0) <= py < (3,2,0) + +# Workaround for the missing "as" keyword in py3k. 
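# Editor's note: Python 2.5 accepts only "except X, e:" while Python 3 accepts
# only "except X as e:", so no spelling that names the exception parses on
# both. _e() below sidesteps this by pulling the live exception out of
# sys.exc_info() inside the handler; the calling pattern used throughout this
# file looks like this sketch (names illustrative):
#
#     try:
#         risky_operation()
#     except ValueError:
#         err = _e()   # the caught exception instance, no "as" or "," needed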
+def _e(): return sys.exc_info()[1]
+
+# Workaround for the "print is a keyword/function" Python 2/3 dilemma
+# and a fallback for mod_wsgi (restricts stdout/err attribute access)
+try:
+    _stdout, _stderr = sys.stdout.write, sys.stderr.write
+except IOError:
+    _stdout = lambda x: sys.stdout.write(x)
+    _stderr = lambda x: sys.stderr.write(x)
+
+# Lots of stdlib and builtin differences.
+if py3k:
+    import http.client as httplib
+    import _thread as thread
+    from urllib.parse import urljoin, SplitResult as UrlSplitResult
+    from urllib.parse import urlencode, quote as urlquote, unquote as urlunquote
+    urlunquote = functools.partial(urlunquote, encoding='latin1')
+    from http.cookies import SimpleCookie
+    from collections import MutableMapping as DictMixin
+    import pickle
+    from io import BytesIO
+    basestring = str
+    unicode = str
+    json_loads = lambda s: json_lds(touni(s))
+    callable = lambda x: hasattr(x, '__call__')
+    imap = map
+else: # 2.x
+    import httplib
+    import thread
+    from urlparse import urljoin, SplitResult as UrlSplitResult
+    from urllib import urlencode, quote as urlquote, unquote as urlunquote
+    from Cookie import SimpleCookie
+    from itertools import imap
+    import cPickle as pickle
+    from StringIO import StringIO as BytesIO
+    if py25:
+        from UserDict import DictMixin
+        def next(it): return it.next()
+        bytes = str
+    else: # 2.6, 2.7
+        from collections import MutableMapping as DictMixin
+    json_loads = json_lds
+
+# Some helpers for string/byte handling
+def tob(s, enc='utf8'):
+    return s.encode(enc) if isinstance(s, unicode) else bytes(s)
+def touni(s, enc='utf8', err='strict'):
+    return s.decode(enc, err) if isinstance(s, bytes) else unicode(s)
+tonat = touni if py3k else tob
+
+# 3.2 fixes cgi.FieldStorage to accept bytes (which makes a lot of sense).
+# 3.1 needs a workaround.
+if py31:
+    from io import TextIOWrapper
+    class NCTextIOWrapper(TextIOWrapper):
+        def close(self): pass # Keep wrapped buffer open.
+
+# File uploads (which are implemented as empty FieldStorage instances...)
+# have a negative truth value. That makes no sense, here is a fix.
+class FieldStorage(cgi.FieldStorage):
+    def __nonzero__(self): return bool(self.list or self.file)
+    if py3k: __bool__ = __nonzero__
+
+# A bug in functools causes it to break if the wrapper is an instance method
+def update_wrapper(wrapper, wrapped, *a, **ka):
+    try: functools.update_wrapper(wrapper, wrapped, *a, **ka)
+    except AttributeError: pass
+
+
+
+# These helpers are used at module level and need to be defined first.
+# And yes, I know PEP-8, but sometimes a lower-case classname makes more sense.
+
+def depr(message):
+    warnings.warn(message, DeprecationWarning, stacklevel=3)
+
+def makelist(data): # This is just too handy
+    if isinstance(data, (tuple, list, set, dict)): return list(data)
+    elif data: return [data]
+    else: return []
+
+
+class DictProperty(object):
+    ''' Property that maps to a key in a local dict-like attribute.
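+        Bottle uses this to cache computed request attributes inside the
+        WSGI environ dictionary (e.g. the ``bottle.request.*`` keys below).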
'''
+    def __init__(self, attr, key=None, read_only=False):
+        self.attr, self.key, self.read_only = attr, key, read_only
+
+    def __call__(self, func):
+        functools.update_wrapper(self, func, updated=[])
+        self.getter, self.key = func, self.key or func.__name__
+        return self
+
+    def __get__(self, obj, cls):
+        if obj is None: return self
+        key, storage = self.key, getattr(obj, self.attr)
+        if key not in storage: storage[key] = self.getter(obj)
+        return storage[key]
+
+    def __set__(self, obj, value):
+        if self.read_only: raise AttributeError("Read-Only property.")
+        getattr(obj, self.attr)[self.key] = value
+
+    def __delete__(self, obj):
+        if self.read_only: raise AttributeError("Read-Only property.")
+        del getattr(obj, self.attr)[self.key]
+
+
+class cached_property(object):
+    ''' A property that is only computed once per instance and then replaces
+        itself with an ordinary attribute. Deleting the attribute resets the
+        property. '''
+
+    def __init__(self, func):
+        self.func = func
+
+    def __get__(self, obj, cls):
+        if obj is None: return self
+        value = obj.__dict__[self.func.__name__] = self.func(obj)
+        return value
+
+
+class lazy_attribute(object):
+    ''' A property that caches itself to the class object. '''
+    def __init__(self, func):
+        functools.update_wrapper(self, func, updated=[])
+        self.getter = func
+
+    def __get__(self, obj, cls):
+        value = self.getter(cls)
+        setattr(cls, self.__name__, value)
+        return value
+
+
+
+
+
+
+###############################################################################
+# Exceptions and Events ########################################################
+###############################################################################
+
+
+class BottleException(Exception):
+    """ A base class for exceptions used by bottle. """
+    pass
+
+
+
+
+
+
+###############################################################################
+# Routing ######################################################################
+###############################################################################
+
+
+class RouteError(BottleException):
+    """ This is a base class for all routing related exceptions """
+
+
+class RouteReset(BottleException):
+    """ If raised by a plugin or request handler, the route is reset and all
+        plugins are re-applied. """
+
+class RouterUnknownModeError(RouteError): pass
+
+
+class RouteSyntaxError(RouteError):
+    """ The route parser found something not supported by this router """
+
+
+class RouteBuildError(RouteError):
+    """ The route could not be built """
+
+
+class Router(object):
+    ''' A Router is an ordered collection of route->target pairs. It is used to
+        efficiently match WSGI requests against a number of routes and return
+        the first target that satisfies the request. The target may be anything,
+        usually a string, ID or callable object. A route consists of a path-rule
+        and a HTTP method.
+
+        The path-rule is either a static path (e.g. `/contact`) or a dynamic
+        path that contains wildcards (e.g. `/wiki/<page>`). The wildcard syntax
+        and details on the matching order are described in docs:`routing`.
+    '''
+
+    default_pattern = '[^/]+'
+    default_filter  = 're'
+    #: Sorry for the mess. It works. Trust me.
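+    #: Accepts the old wildcard syntax (``:name`` or ``:name#regex#``) as well
+    #: as the new one (``<name>``, ``<name:filter>``, ``<name:filter:config>``),
+    #: plus backslash-escaped wildcards.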
+ rule_syntax = re.compile('(\\\\*)'\ + '(?:(?::([a-zA-Z_][a-zA-Z_0-9]*)?()(?:#(.*?)#)?)'\ + '|(?:<([a-zA-Z_][a-zA-Z_0-9]*)?(?::([a-zA-Z_]*)'\ + '(?::((?:\\\\.|[^\\\\>]+)+)?)?)?>))') + + def __init__(self, strict=False): + self.rules = {} # A {rule: Rule} mapping + self.builder = {} # A rule/name->build_info mapping + self.static = {} # Cache for static routes: {path: {method: target}} + self.dynamic = [] # Cache for dynamic routes. See _compile() + #: If true, static routes are no longer checked first. + self.strict_order = strict + self.filters = {'re': self.re_filter, 'int': self.int_filter, + 'float': self.float_filter, 'path': self.path_filter} + + def re_filter(self, conf): + return conf or self.default_pattern, None, None + + def int_filter(self, conf): + return r'-?\d+', int, lambda x: str(int(x)) + + def float_filter(self, conf): + return r'-?[\d.]+', float, lambda x: str(float(x)) + + def path_filter(self, conf): + return r'.+?', None, None + + def add_filter(self, name, func): + ''' Add a filter. The provided function is called with the configuration + string as parameter and must return a (regexp, to_python, to_url) tuple. + The first element is a string, the last two are callables or None. ''' + self.filters[name] = func + + def parse_rule(self, rule): + ''' Parses a rule into a (name, filter, conf) token stream. If mode is + None, name contains a static rule part. ''' + offset, prefix = 0, '' + for match in self.rule_syntax.finditer(rule): + prefix += rule[offset:match.start()] + g = match.groups() + if len(g[0])%2: # Escaped wildcard + prefix += match.group(0)[len(g[0]):] + offset = match.end() + continue + if prefix: yield prefix, None, None + name, filtr, conf = g[1:4] if not g[2] is None else g[4:7] + if not filtr: filtr = self.default_filter + yield name, filtr, conf or None + offset, prefix = match.end(), '' + if offset <= len(rule) or prefix: + yield prefix+rule[offset:], None, None + + def add(self, rule, method, target, name=None): + ''' Add a new route or replace the target for an existing route. ''' + if rule in self.rules: + self.rules[rule][method] = target + if name: self.builder[name] = self.builder[rule] + return + + target = self.rules[rule] = {method: target} + + # Build pattern and other structures for dynamic routes + anons = 0 # Number of anonymous wildcards + pattern = '' # Regular expression pattern + filters = [] # Lists of wildcard input filters + builder = [] # Data structure for the URL builder + is_static = True + for key, mode, conf in self.parse_rule(rule): + if mode: + is_static = False + mask, in_filter, out_filter = self.filters[mode](conf) + if key: + pattern += '(?P<%s>%s)' % (key, mask) + else: + pattern += '(?:%s)' % mask + key = 'anon%d' % anons; anons += 1 + if in_filter: filters.append((key, in_filter)) + builder.append((key, out_filter or str)) + elif key: + pattern += re.escape(key) + builder.append((None, key)) + self.builder[rule] = builder + if name: self.builder[name] = builder + + if is_static and not self.strict_order: + self.static[self.build(rule)] = target + return + + def fpat_sub(m): + return m.group(0) if len(m.group(1)) % 2 else m.group(1) + '(?:' + flat_pattern = re.sub(r'(\\*)(\(\?P<[^>]*>|\((?!\?))', fpat_sub, pattern) + + try: + re_match = re.compile('^(%s)$' % pattern).match + except re.error: + raise RouteSyntaxError("Could not add Route: %s (%s)" % (rule, _e())) + + def match(path): + """ Return an url-argument dictionary. 
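+                For a rule like ``/wiki/<page>``, matching ``/wiki/Linux``
+                yields ``{'page': 'Linux'}``.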
""" + url_args = re_match(path).groupdict() + for name, wildcard_filter in filters: + try: + url_args[name] = wildcard_filter(url_args[name]) + except ValueError: + raise HTTPError(400, 'Path has wrong format.') + return url_args + + try: + combined = '%s|(^%s$)' % (self.dynamic[-1][0].pattern, flat_pattern) + self.dynamic[-1] = (re.compile(combined), self.dynamic[-1][1]) + self.dynamic[-1][1].append((match, target)) + except (AssertionError, IndexError): # AssertionError: Too many groups + self.dynamic.append((re.compile('(^%s$)' % flat_pattern), + [(match, target)])) + return match + + def build(self, _name, *anons, **query): + ''' Build an URL by filling the wildcards in a rule. ''' + builder = self.builder.get(_name) + if not builder: raise RouteBuildError("No route with that name.", _name) + try: + for i, value in enumerate(anons): query['anon%d'%i] = value + url = ''.join([f(query.pop(n)) if n else f for (n,f) in builder]) + return url if not query else url+'?'+urlencode(query) + except KeyError: + raise RouteBuildError('Missing URL argument: %r' % _e().args[0]) + + def match(self, environ): + ''' Return a (target, url_agrs) tuple or raise HTTPError(400/404/405). ''' + path, targets, urlargs = environ['PATH_INFO'] or '/', None, {} + if path in self.static: + targets = self.static[path] + else: + for combined, rules in self.dynamic: + match = combined.match(path) + if not match: continue + getargs, targets = rules[match.lastindex - 1] + urlargs = getargs(path) if getargs else {} + break + + if not targets: + raise HTTPError(404, "Not found: " + repr(environ['PATH_INFO'])) + method = environ['REQUEST_METHOD'].upper() + if method in targets: + return targets[method], urlargs + if method == 'HEAD' and 'GET' in targets: + return targets['GET'], urlargs + if 'ANY' in targets: + return targets['ANY'], urlargs + allowed = [verb for verb in targets if verb != 'ANY'] + if 'GET' in allowed and 'HEAD' not in allowed: + allowed.append('HEAD') + raise HTTPError(405, "Method not allowed.", Allow=",".join(allowed)) + + +class Route(object): + ''' This class wraps a route callback along with route specific metadata and + configuration and applies Plugins on demand. It is also responsible for + turing an URL path rule into a regular expression usable by the Router. + ''' + + def __init__(self, app, rule, method, callback, name=None, + plugins=None, skiplist=None, **config): + #: The application this route is installed to. + self.app = app + #: The path-rule string (e.g. ``/wiki/:page``). + self.rule = rule + #: The HTTP method as a string (e.g. ``GET``). + self.method = method + #: The original callback with no plugins applied. Useful for introspection. + self.callback = callback + #: The name of the route (if specified) or ``None``. + self.name = name or None + #: A list of route-specific plugins (see :meth:`Bottle.route`). + self.plugins = plugins or [] + #: A list of plugins to not apply to this route (see :meth:`Bottle.route`). + self.skiplist = skiplist or [] + #: Additional keyword arguments passed to the :meth:`Bottle.route` + #: decorator are stored in this dictionary. Used for route-specific + #: plugin configuration and meta-data. + self.config = ConfigDict(config) + + def __call__(self, *a, **ka): + depr("Some APIs changed to return Route() instances instead of"\ + " callables. Make sure to use the Route.call method and not to"\ + " call Route instances directly.") + return self.call(*a, **ka) + + @cached_property + def call(self): + ''' The route callback with all plugins applied. 
This property is + created on demand and then cached to speed up subsequent requests.''' + return self._make_callback() + + def reset(self): + ''' Forget any cached values. The next time :attr:`call` is accessed, + all plugins are re-applied. ''' + self.__dict__.pop('call', None) + + def prepare(self): + ''' Do all on-demand work immediately (useful for debugging).''' + self.call + + @property + def _context(self): + depr('Switch to Plugin API v2 and access the Route object directly.') + return dict(rule=self.rule, method=self.method, callback=self.callback, + name=self.name, app=self.app, config=self.config, + apply=self.plugins, skip=self.skiplist) + + def all_plugins(self): + ''' Yield all Plugins affecting this route. ''' + unique = set() + for p in reversed(self.app.plugins + self.plugins): + if True in self.skiplist: break + name = getattr(p, 'name', False) + if name and (name in self.skiplist or name in unique): continue + if p in self.skiplist or type(p) in self.skiplist: continue + if name: unique.add(name) + yield p + + def _make_callback(self): + callback = self.callback + for plugin in self.all_plugins(): + try: + if hasattr(plugin, 'apply'): + api = getattr(plugin, 'api', 1) + context = self if api > 1 else self._context + callback = plugin.apply(callback, context) + else: + callback = plugin(callback) + except RouteReset: # Try again with changed configuration. + return self._make_callback() + if not callback is self.callback: + update_wrapper(callback, self.callback) + return callback + + def __repr__(self): + return '<%s %r %r>' % (self.method, self.rule, self.callback) + + + + + + +############################################################################### +# Application Object ########################################################### +############################################################################### + + +class Bottle(object): + """ Each Bottle object represents a single, distinct web application and + consists of routes, callbacks, plugins, resources and configuration. + Instances are callable WSGI applications. + + :param catchall: If true (default), handle all exceptions. Turn off to + let debugging middleware handle exceptions. + """ + + def __init__(self, catchall=True, autojson=True): + #: If true, most exceptions are caught and returned as :exc:`HTTPError` + self.catchall = catchall + + #: A :class:`ResourceManager` for application files + self.resources = ResourceManager() + + #: A :class:`ConfigDict` for app specific configuration. + self.config = ConfigDict() + self.config.autojson = autojson + + self.routes = [] # List of installed :class:`Route` instances. + self.router = Router() # Maps requests to :class:`Route` instances. + self.error_handler = {} + + # Core plugins + self.plugins = [] # List of installed plugins. + self.hooks = HooksPlugin() + self.install(self.hooks) + if self.config.autojson: + self.install(JSONPlugin()) + self.install(TemplatePlugin()) + + + def mount(self, prefix, app, **options): + ''' Mount an application (:class:`Bottle` or plain WSGI) to a specific + URL prefix. Example:: + + root_app.mount('/admin/', admin_app) + + :param prefix: path prefix or `mount-point`. If it ends in a slash, + that slash is mandatory. + :param app: an instance of :class:`Bottle` or a WSGI application. + + All other parameters are passed to the underlying :meth:`route` call. 
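+
+            Internally, a catch-all route (``/prefix/<:re:.*>`` with method
+            ``ANY``) is installed, so every sub-path is dispatched to the
+            mounted application.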
+ ''' + if isinstance(app, basestring): + prefix, app = app, prefix + depr('Parameter order of Bottle.mount() changed.') # 0.10 + + segments = [p for p in prefix.split('/') if p] + if not segments: raise ValueError('Empty path prefix.') + path_depth = len(segments) + + def mountpoint_wrapper(): + try: + request.path_shift(path_depth) + rs = BaseResponse([], 200) + def start_response(status, header): + rs.status = status + for name, value in header: rs.add_header(name, value) + return rs.body.append + body = app(request.environ, start_response) + body = itertools.chain(rs.body, body) + return HTTPResponse(body, rs.status_code, **rs.headers) + finally: + request.path_shift(-path_depth) + + options.setdefault('skip', True) + options.setdefault('method', 'ANY') + options.setdefault('mountpoint', {'prefix': prefix, 'target': app}) + options['callback'] = mountpoint_wrapper + + self.route('/%s/<:re:.*>' % '/'.join(segments), **options) + if not prefix.endswith('/'): + self.route('/' + '/'.join(segments), **options) + + def merge(self, routes): + ''' Merge the routes of another :class:`Bottle` application or a list of + :class:`Route` objects into this application. The routes keep their + 'owner', meaning that the :data:`Route.app` attribute is not + changed. ''' + if isinstance(routes, Bottle): + routes = routes.routes + for route in routes: + self.add_route(route) + + def install(self, plugin): + ''' Add a plugin to the list of plugins and prepare it for being + applied to all routes of this application. A plugin may be a simple + decorator or an object that implements the :class:`Plugin` API. + ''' + if hasattr(plugin, 'setup'): plugin.setup(self) + if not callable(plugin) and not hasattr(plugin, 'apply'): + raise TypeError("Plugins must be callable or implement .apply()") + self.plugins.append(plugin) + self.reset() + return plugin + + def uninstall(self, plugin): + ''' Uninstall plugins. Pass an instance to remove a specific plugin, a type + object to remove all plugins that match that type, a string to remove + all plugins with a matching ``name`` attribute or ``True`` to remove all + plugins. Return the list of removed plugins. ''' + removed, remove = [], plugin + for i, plugin in list(enumerate(self.plugins))[::-1]: + if remove is True or remove is plugin or remove is type(plugin) \ + or getattr(plugin, 'name', True) == remove: + removed.append(plugin) + del self.plugins[i] + if hasattr(plugin, 'close'): plugin.close() + if removed: self.reset() + return removed + + def run(self, **kwargs): + ''' Calls :func:`run` with the same parameters. ''' + run(self, **kwargs) + + def reset(self, route=None): + ''' Reset all routes (force plugins to be re-applied) and clear all + caches. If an ID or route object is given, only that specific route + is affected. ''' + if route is None: routes = self.routes + elif isinstance(route, Route): routes = [route] + else: routes = [self.routes[route]] + for route in routes: route.reset() + if DEBUG: + for route in routes: route.prepare() + self.hooks.trigger('app_reset') + + def close(self): + ''' Close the application and all installed plugins. ''' + for plugin in self.plugins: + if hasattr(plugin, 'close'): plugin.close() + self.stopped = True + + def match(self, environ): + """ Search for a matching route and return a (:class:`Route` , urlargs) + tuple. The second value is a dictionary with parameters extracted + from the URL. 
Raise :exc:`HTTPError` (404/405) on a non-match.""" + return self.router.match(environ) + + def get_url(self, routename, **kargs): + """ Return a string that matches a named route """ + scriptname = request.environ.get('SCRIPT_NAME', '').strip('/') + '/' + location = self.router.build(routename, **kargs).lstrip('/') + return urljoin(urljoin('/', scriptname), location) + + def add_route(self, route): + ''' Add a route object, but do not change the :data:`Route.app` + attribute.''' + self.routes.append(route) + self.router.add(route.rule, route.method, route, name=route.name) + if DEBUG: route.prepare() + + def route(self, path=None, method='GET', callback=None, name=None, + apply=None, skip=None, **config): + """ A decorator to bind a function to a request URL. Example:: + + @app.route('/hello/:name') + def hello(name): + return 'Hello %s' % name + + The ``:name`` part is a wildcard. See :class:`Router` for syntax + details. + + :param path: Request path or a list of paths to listen to. If no + path is specified, it is automatically generated from the + signature of the function. + :param method: HTTP method (`GET`, `POST`, `PUT`, ...) or a list of + methods to listen to. (default: `GET`) + :param callback: An optional shortcut to avoid the decorator + syntax. ``route(..., callback=func)`` equals ``route(...)(func)`` + :param name: The name for this route. (default: None) + :param apply: A decorator or plugin or a list of plugins. These are + applied to the route callback in addition to installed plugins. + :param skip: A list of plugins, plugin classes or names. Matching + plugins are not installed to this route. ``True`` skips all. + + Any additional keyword arguments are stored as route-specific + configuration and passed to plugins (see :meth:`Plugin.apply`). + """ + if callable(path): path, callback = None, path + plugins = makelist(apply) + skiplist = makelist(skip) + def decorator(callback): + # TODO: Documentation and tests + if isinstance(callback, basestring): callback = load(callback) + for rule in makelist(path) or yieldroutes(callback): + for verb in makelist(method): + verb = verb.upper() + route = Route(self, rule, verb, callback, name=name, + plugins=plugins, skiplist=skiplist, **config) + self.add_route(route) + return callback + return decorator(callback) if callback else decorator + + def get(self, path=None, method='GET', **options): + """ Equals :meth:`route`. """ + return self.route(path, method, **options) + + def post(self, path=None, method='POST', **options): + """ Equals :meth:`route` with a ``POST`` method parameter. """ + return self.route(path, method, **options) + + def put(self, path=None, method='PUT', **options): + """ Equals :meth:`route` with a ``PUT`` method parameter. """ + return self.route(path, method, **options) + + def delete(self, path=None, method='DELETE', **options): + """ Equals :meth:`route` with a ``DELETE`` method parameter. """ + return self.route(path, method, **options) + + def error(self, code=500): + """ Decorator: Register an output handler for a HTTP error code""" + def wrapper(handler): + self.error_handler[int(code)] = handler + return handler + return wrapper + + def hook(self, name): + """ Return a decorator that attaches a callback to a hook. Three hooks + are currently implemented: + + - before_request: Executed once before each request + - after_request: Executed once after each request + - app_reset: Called whenever :meth:`reset` is called. 
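+
+            Example (a small sketch; ``request`` is bottle's global request
+            object)::
+
+                @app.hook('before_request')
+                def log_path():
+                    print(request.path)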
+ """ + def wrapper(func): + self.hooks.add(name, func) + return func + return wrapper + + def handle(self, path, method='GET'): + """ (deprecated) Execute the first matching route callback and return + the result. :exc:`HTTPResponse` exceptions are caught and returned. + If :attr:`Bottle.catchall` is true, other exceptions are caught as + well and returned as :exc:`HTTPError` instances (500). + """ + depr("This method will change semantics in 0.10. Try to avoid it.") + if isinstance(path, dict): + return self._handle(path) + return self._handle({'PATH_INFO': path, 'REQUEST_METHOD': method.upper()}) + + def default_error_handler(self, res): + return tob(template(ERROR_PAGE_TEMPLATE, e=res)) + + def _handle(self, environ): + try: + environ['bottle.app'] = self + request.bind(environ) + response.bind() + route, args = self.router.match(environ) + environ['route.handle'] = route + environ['bottle.route'] = route + environ['route.url_args'] = args + return route.call(**args) + except HTTPResponse: + return _e() + except RouteReset: + route.reset() + return self._handle(environ) + except (KeyboardInterrupt, SystemExit, MemoryError): + raise + except Exception: + if not self.catchall: raise + stacktrace = format_exc() + environ['wsgi.errors'].write(stacktrace) + return HTTPError(500, "Internal Server Error", _e(), stacktrace) + + def _cast(self, out, peek=None): + """ Try to convert the parameter into something WSGI compatible and set + correct HTTP headers when possible. + Support: False, str, unicode, dict, HTTPResponse, HTTPError, file-like, + iterable of strings and iterable of unicodes + """ + + # Empty output is done here + if not out: + if 'Content-Length' not in response: + response['Content-Length'] = 0 + return [] + # Join lists of byte or unicode strings. Mixed lists are NOT supported + if isinstance(out, (tuple, list))\ + and isinstance(out[0], (bytes, unicode)): + out = out[0][0:0].join(out) # b'abc'[0:0] -> b'' + # Encode unicode strings + if isinstance(out, unicode): + out = out.encode(response.charset) + # Byte Strings are just returned + if isinstance(out, bytes): + if 'Content-Length' not in response: + response['Content-Length'] = len(out) + return [out] + # HTTPError or HTTPException (recursive, because they may wrap anything) + # TODO: Handle these explicitly in handle() or make them iterable. + if isinstance(out, HTTPError): + out.apply(response) + out = self.error_handler.get(out.status_code, self.default_error_handler)(out) + return self._cast(out) + if isinstance(out, HTTPResponse): + out.apply(response) + return self._cast(out.body) + + # File-like objects. + if hasattr(out, 'read'): + if 'wsgi.file_wrapper' in request.environ: + return request.environ['wsgi.file_wrapper'](out) + elif hasattr(out, 'close') or not hasattr(out, '__iter__'): + return WSGIFileWrapper(out) + + # Handle Iterables. We peek into them to detect their inner type. + try: + out = iter(out) + first = next(out) + while not first: + first = next(out) + except StopIteration: + return self._cast('') + except HTTPResponse: + first = _e() + except (KeyboardInterrupt, SystemExit, MemoryError): + raise + except Exception: + if not self.catchall: raise + first = HTTPError(500, 'Unhandled exception', _e(), format_exc()) + + # These are the inner types allowed in iterator or generator objects. 
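+        # (bytes pass through unchanged, unicode chunks are encoded lazily and
+        # nested HTTPResponse objects are cast recursively; anything else
+        # triggers a 500 error.)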
+        if isinstance(first, HTTPResponse):
+            return self._cast(first)
+        if isinstance(first, bytes):
+            return itertools.chain([first], out)
+        if isinstance(first, unicode):
+            return imap(lambda x: x.encode(response.charset),
+                        itertools.chain([first], out))
+        return self._cast(HTTPError(500, 'Unsupported response type: %s'\
+                                         % type(first)))
+
+    def wsgi(self, environ, start_response):
+        """ The bottle WSGI-interface. """
+        try:
+            out = self._cast(self._handle(environ))
+            # rfc2616 section 4.3
+            if response._status_code in (100, 101, 204, 304)\
+            or environ['REQUEST_METHOD'] == 'HEAD':
+                if hasattr(out, 'close'): out.close()
+                out = []
+            start_response(response._status_line, response.headerlist)
+            return out
+        except (KeyboardInterrupt, SystemExit, MemoryError):
+            raise
+        except Exception:
+            if not self.catchall: raise
+            err = '<h1>Critical error while processing request: %s</h1>' \
+                  % html_escape(environ.get('PATH_INFO', '/'))
+            if DEBUG:
+                err += '<h2>Error:</h2>\n<pre>\n%s\n</pre>\n' \
+                       '<h2>Traceback:</h2>\n<pre>\n%s\n</pre>\n' \
+                       % (html_escape(repr(_e())), html_escape(format_exc()))
+            environ['wsgi.errors'].write(err)
+            headers = [('Content-Type', 'text/html; charset=UTF-8')]
+            start_response('500 INTERNAL SERVER ERROR', headers)
+            return [tob(err)]
+
+    def __call__(self, environ, start_response):
+        ''' Each instance of :class:`Bottle` is a WSGI application. '''
+        return self.wsgi(environ, start_response)
+
+
+
+
+
+
+###############################################################################
+# HTTP and WSGI Tools ##########################################################
+###############################################################################
+
+
+class BaseRequest(object):
+    """ A wrapper for WSGI environment dictionaries that adds a lot of
+        convenient access methods and properties. Most of them are read-only.
+
+        Adding new attributes to a request actually adds them to the environ
+        dictionary (as 'bottle.request.ext.<name>'). This is the recommended
+        way to store and access request-specific data.
+    """
+
+    __slots__ = ('environ')
+
+    #: Maximum size of memory buffer for :attr:`body` in bytes.
+    MEMFILE_MAX = 102400
+    #: Maximum number of GET or POST parameters per request
+    MAX_PARAMS  = 100
+
+    def __init__(self, environ=None):
+        """ Wrap a WSGI environ dictionary. """
+        #: The wrapped WSGI environ dictionary. This is the only real attribute.
+        #: All other attributes actually are read-only properties.
+        self.environ = {} if environ is None else environ
+        self.environ['bottle.request'] = self
+
+    @DictProperty('environ', 'bottle.app', read_only=True)
+    def app(self):
+        ''' Bottle application handling this request. '''
+        raise RuntimeError('This request is not connected to an application.')
+
+    @property
+    def path(self):
+        ''' The value of ``PATH_INFO`` with exactly one prefixed slash (to fix
+            broken clients and avoid the "empty path" edge case). '''
+        return '/' + self.environ.get('PATH_INFO','').lstrip('/')
+
+    @property
+    def method(self):
+        ''' The ``REQUEST_METHOD`` value as an uppercase string. '''
+        return self.environ.get('REQUEST_METHOD', 'GET').upper()
+
+    @DictProperty('environ', 'bottle.request.headers', read_only=True)
+    def headers(self):
+        ''' A :class:`WSGIHeaderDict` that provides case-insensitive access to
+            HTTP request headers. '''
+        return WSGIHeaderDict(self.environ)
+
+    def get_header(self, name, default=None):
+        ''' Return the value of a request header, or a given default value.
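+            Example: ``request.get_header('User-Agent', 'unknown')``.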
'''
+        return self.headers.get(name, default)
+
+    @DictProperty('environ', 'bottle.request.cookies', read_only=True)
+    def cookies(self):
+        """ Cookies parsed into a :class:`FormsDict`. Signed cookies are NOT
+            decoded. Use :meth:`get_cookie` if you expect signed cookies. """
+        cookies = SimpleCookie(self.environ.get('HTTP_COOKIE',''))
+        cookies = list(cookies.values())[:self.MAX_PARAMS]
+        return FormsDict((c.key, c.value) for c in cookies)
+
+    def get_cookie(self, key, default=None, secret=None):
+        """ Return the content of a cookie. To read a `Signed Cookie`, the
+            `secret` must match the one used to create the cookie (see
+            :meth:`BaseResponse.set_cookie`). If anything goes wrong (missing
+            cookie or wrong signature), return a default value. """
+        value = self.cookies.get(key)
+        if secret and value:
+            dec = cookie_decode(value, secret) # (key, value) tuple or None
+            return dec[1] if dec and dec[0] == key else default
+        return value or default
+
+    @DictProperty('environ', 'bottle.request.query', read_only=True)
+    def query(self):
+        ''' The :attr:`query_string` parsed into a :class:`FormsDict`. These
+            values are sometimes called "URL arguments" or "GET parameters",
+            but should not be confused with "URL wildcards" as they are
+            provided by the :class:`Router`. '''
+        get = self.environ['bottle.get'] = FormsDict()
+        pairs = _parse_qsl(self.environ.get('QUERY_STRING', ''))
+        for key, value in pairs[:self.MAX_PARAMS]:
+            get[key] = value
+        return get
+
+    @DictProperty('environ', 'bottle.request.forms', read_only=True)
+    def forms(self):
+        """ Form values parsed from an `url-encoded` or `multipart/form-data`
+            encoded POST or PUT request body. The result is returned as a
+            :class:`FormsDict`. All keys and values are strings. File uploads
+            are stored separately in :attr:`files`. """
+        forms = FormsDict()
+        for name, item in self.POST.allitems():
+            if not hasattr(item, 'filename'):
+                forms[name] = item
+        return forms
+
+    @DictProperty('environ', 'bottle.request.params', read_only=True)
+    def params(self):
+        """ A :class:`FormsDict` with the combined values of :attr:`query` and
+            :attr:`forms`. File uploads are stored in :attr:`files`. """
+        params = FormsDict()
+        for key, value in self.query.allitems():
+            params[key] = value
+        for key, value in self.forms.allitems():
+            params[key] = value
+        return params
+
+    @DictProperty('environ', 'bottle.request.files', read_only=True)
+    def files(self):
+        """ File uploads parsed from an `url-encoded` or `multipart/form-data`
+            encoded POST or PUT request body. The values are instances of
+            :class:`cgi.FieldStorage`. The most important attributes are:
+
+            filename
+                The filename, if specified; otherwise None; this is the client
+                side filename, *not* the file name on which it is stored (that's
+                a temporary file you don't deal with)
+            file
+                The file(-like) object from which you can read the data.
+            value
+                The value as a *string*; for file uploads, this transparently
+                reads the file every time you request the value. Do not do this
+                on big files.
+        """
+        files = FormsDict()
+        for name, item in self.POST.allitems():
+            if hasattr(item, 'filename'):
+                files[name] = item
+        return files
+
+    @DictProperty('environ', 'bottle.request.json', read_only=True)
+    def json(self):
+        ''' If the ``Content-Type`` header is ``application/json``, this
+            property holds the parsed content of the request body. Only requests
+            smaller than :attr:`MEMFILE_MAX` are processed to avoid memory
+            exhaustion.
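+
+            Example (assuming a JSON request body)::
+
+                data = request.json or {}
+                name = data.get('name')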
''' + if 'application/json' in self.environ.get('CONTENT_TYPE', '') \ + and 0 < self.content_length < self.MEMFILE_MAX: + return json_loads(self.body.read(self.MEMFILE_MAX)) + return None + + @DictProperty('environ', 'bottle.request.body', read_only=True) + def _body(self): + maxread = max(0, self.content_length) + stream = self.environ['wsgi.input'] + body = BytesIO() if maxread < self.MEMFILE_MAX else TemporaryFile(mode='w+b') + while maxread > 0: + part = stream.read(min(maxread, self.MEMFILE_MAX)) + if not part: break + body.write(part) + maxread -= len(part) + self.environ['wsgi.input'] = body + body.seek(0) + return body + + @property + def body(self): + """ The HTTP request body as a seek-able file-like object. Depending on + :attr:`MEMFILE_MAX`, this is either a temporary file or a + :class:`io.BytesIO` instance. Accessing this property for the first + time reads and replaces the ``wsgi.input`` environ variable. + Subsequent accesses just do a `seek(0)` on the file object. """ + self._body.seek(0) + return self._body + + #: An alias for :attr:`query`. + GET = query + + @DictProperty('environ', 'bottle.request.post', read_only=True) + def POST(self): + """ The values of :attr:`forms` and :attr:`files` combined into a single + :class:`FormsDict`. Values are either strings (form values) or + instances of :class:`cgi.FieldStorage` (file uploads). + """ + post = FormsDict() + # We default to application/x-www-form-urlencoded for everything that + # is not multipart and take the fast path (also: 3.1 workaround) + if not self.content_type.startswith('multipart/'): + maxlen = max(0, min(self.content_length, self.MEMFILE_MAX)) + pairs = _parse_qsl(tonat(self.body.read(maxlen), 'latin1')) + for key, value in pairs[:self.MAX_PARAMS]: + post[key] = value + return post + + safe_env = {'QUERY_STRING':''} # Build a safe environment for cgi + for key in ('REQUEST_METHOD', 'CONTENT_TYPE', 'CONTENT_LENGTH'): + if key in self.environ: safe_env[key] = self.environ[key] + args = dict(fp=self.body, environ=safe_env, keep_blank_values=True) + if py31: + args['fp'] = NCTextIOWrapper(args['fp'], encoding='ISO-8859-1', + newline='\n') + elif py3k: + args['encoding'] = 'ISO-8859-1' + data = FieldStorage(**args) + for item in (data.list or [])[:self.MAX_PARAMS]: + post[item.name] = item if item.filename else item.value + return post + + @property + def COOKIES(self): + ''' Alias for :attr:`cookies` (deprecated). ''' + depr('BaseRequest.COOKIES was renamed to BaseRequest.cookies (lowercase).') + return self.cookies + + @property + def url(self): + """ The full request URI including hostname and scheme. If your app + lives behind a reverse proxy or load balancer and you get confusing + results, make sure that the ``X-Forwarded-Host`` header is set + correctly. """ + return self.urlparts.geturl() + + @DictProperty('environ', 'bottle.request.urlparts', read_only=True) + def urlparts(self): + ''' The :attr:`url` string as an :class:`urlparse.SplitResult` tuple. + The tuple contains (scheme, host, path, query_string and fragment), + but the fragment is always empty because it is not visible to the + server. ''' + env = self.environ + http = env.get('HTTP_X_FORWARDED_PROTO') or env.get('wsgi.url_scheme', 'http') + host = env.get('HTTP_X_FORWARDED_HOST') or env.get('HTTP_HOST') + if not host: + # HTTP 1.1 requires a Host-header. This is for HTTP/1.0 clients. 
+                host = env.get('SERVER_NAME', '127.0.0.1')
+                port = env.get('SERVER_PORT')
+                if port and port != ('80' if http == 'http' else '443'):
+                    host += ':' + port
+        path = urlquote(self.fullpath)
+        return UrlSplitResult(http, host, path, env.get('QUERY_STRING'), '')
+
+    @property
+    def fullpath(self):
+        """ Request path including :attr:`script_name` (if present). """
+        return urljoin(self.script_name, self.path.lstrip('/'))
+
+    @property
+    def query_string(self):
+        """ The raw :attr:`query` part of the URL (everything in between ``?``
+            and ``#``) as a string. """
+        return self.environ.get('QUERY_STRING', '')
+
+    @property
+    def script_name(self):
+        ''' The initial portion of the URL's `path` that was removed by a higher
+            level (server or routing middleware) before the application was
+            called. This script path is returned with leading and trailing
+            slashes. '''
+        script_name = self.environ.get('SCRIPT_NAME', '').strip('/')
+        return '/' + script_name + '/' if script_name else '/'
+
+    def path_shift(self, shift=1):
+        ''' Shift path segments from :attr:`path` to :attr:`script_name` and
+            vice versa.
+
+           :param shift: The number of path segments to shift. May be negative
+                         to change the shift direction. (default: 1)
+        '''
+        script = self.environ.get('SCRIPT_NAME','/')
+        self['SCRIPT_NAME'], self['PATH_INFO'] = path_shift(script, self.path, shift)
+
+    @property
+    def content_length(self):
+        ''' The request body length as an integer. The client is responsible
+            for setting this header. Otherwise, the real length of the body is
+            unknown and -1 is returned. In this case, :attr:`body` will be
+            empty. '''
+        return int(self.environ.get('CONTENT_LENGTH') or -1)
+
+    @property
+    def content_type(self):
+        ''' The Content-Type header as a lowercase-string (default: empty). '''
+        return self.environ.get('CONTENT_TYPE', '').lower()
+
+    @property
+    def is_xhr(self):
+        ''' True if the request was triggered by a XMLHttpRequest. This only
+            works with JavaScript libraries that support the `X-Requested-With`
+            header (most of the popular libraries do). '''
+        requested_with = self.environ.get('HTTP_X_REQUESTED_WITH','')
+        return requested_with.lower() == 'xmlhttprequest'
+
+    @property
+    def is_ajax(self):
+        ''' Alias for :attr:`is_xhr`. "Ajax" is not the right term. '''
+        return self.is_xhr
+
+    @property
+    def auth(self):
+        """ HTTP authentication data as a (user, password) tuple. This
+            implementation currently supports basic (not digest) authentication
+            only. If the authentication happened at a higher level (e.g. in the
+            front web-server or a middleware), the password field is None, but
+            the user field is looked up from the ``REMOTE_USER`` environ
+            variable. On any errors, None is returned. """
+        basic = parse_auth(self.environ.get('HTTP_AUTHORIZATION',''))
+        if basic: return basic
+        ruser = self.environ.get('REMOTE_USER')
+        if ruser: return (ruser, None)
+        return None
+
+    @property
+    def remote_route(self):
+        """ A list of all IPs that were involved in this request, starting with
+            the client IP and followed by zero or more proxies. This only
+            works if all proxies support the ``X-Forwarded-For`` header. Note
+            that this information can be forged by malicious clients. """
+        proxy = self.environ.get('HTTP_X_FORWARDED_FOR')
+        if proxy: return [ip.strip() for ip in proxy.split(',')]
+        remote = self.environ.get('REMOTE_ADDR')
+        return [remote] if remote else []
+
+    @property
+    def remote_addr(self):
+        """ The client IP as a string. Note that this information can be forged
+            by malicious clients.
""" + route = self.remote_route + return route[0] if route else None + + def copy(self): + """ Return a new :class:`Request` with a shallow :attr:`environ` copy. """ + return Request(self.environ.copy()) + + def get(self, value, default=None): return self.environ.get(value, default) + def __getitem__(self, key): return self.environ[key] + def __delitem__(self, key): self[key] = ""; del(self.environ[key]) + def __iter__(self): return iter(self.environ) + def __len__(self): return len(self.environ) + def keys(self): return self.environ.keys() + def __setitem__(self, key, value): + """ Change an environ value and clear all caches that depend on it. """ + + if self.environ.get('bottle.request.readonly'): + raise KeyError('The environ dictionary is read-only.') + + self.environ[key] = value + todelete = () + + if key == 'wsgi.input': + todelete = ('body', 'forms', 'files', 'params', 'post', 'json') + elif key == 'QUERY_STRING': + todelete = ('query', 'params') + elif key.startswith('HTTP_'): + todelete = ('headers', 'cookies') + + for key in todelete: + self.environ.pop('bottle.request.'+key, None) + + def __repr__(self): + return '<%s: %s %s>' % (self.__class__.__name__, self.method, self.url) + + def __getattr__(self, name): + ''' Search in self.environ for additional user defined attributes. ''' + try: + var = self.environ['bottle.request.ext.%s'%name] + return var.__get__(self) if hasattr(var, '__get__') else var + except KeyError: + raise AttributeError('Attribute %r not defined.' % name) + + def __setattr__(self, name, value): + if name == 'environ': return object.__setattr__(self, name, value) + self.environ['bottle.request.ext.%s'%name] = value + + + + +def _hkey(s): + return s.title().replace('_','-') + + +class HeaderProperty(object): + def __init__(self, name, reader=None, writer=str, default=''): + self.name, self.default = name, default + self.reader, self.writer = reader, writer + self.__doc__ = 'Current value of the %r header.' % name.title() + + def __get__(self, obj, cls): + if obj is None: return self + value = obj.headers.get(self.name, self.default) + return self.reader(value) if self.reader else value + + def __set__(self, obj, value): + obj.headers[self.name] = self.writer(value) + + def __delete__(self, obj): + del obj.headers[self.name] + + +class BaseResponse(object): + """ Storage class for a response body as well as headers and cookies. + + This class does support dict-like case-insensitive item-access to + headers, but is NOT a dict. Most notably, iterating over a response + yields parts of the body and not the headers. + """ + + default_status = 200 + default_content_type = 'text/html; charset=UTF-8' + + # Header blacklist for specific response codes + # (rfc2616 section 10.2.3 and 10.3.5) + bad_headers = { + 204: set(('Content-Type',)), + 304: set(('Allow', 'Content-Encoding', 'Content-Language', + 'Content-Length', 'Content-Range', 'Content-Type', + 'Content-Md5', 'Last-Modified'))} + + def __init__(self, body='', status=None, **headers): + self._cookies = None + self._headers = {'Content-Type': [self.default_content_type]} + self.body = body + self.status = status or self.default_status + if headers: + for name, value in headers.items(): + self[name] = value + + def copy(self): + ''' Returns a copy of self. 
''' + copy = Response() + copy.status = self.status + copy._headers = dict((k, v[:]) for (k, v) in self._headers.items()) + return copy + + def __iter__(self): + return iter(self.body) + + def close(self): + if hasattr(self.body, 'close'): + self.body.close() + + @property + def status_line(self): + ''' The HTTP status line as a string (e.g. ``404 Not Found``).''' + return self._status_line + + @property + def status_code(self): + ''' The HTTP status code as an integer (e.g. 404).''' + return self._status_code + + def _set_status(self, status): + if isinstance(status, int): + code, status = status, _HTTP_STATUS_LINES.get(status) + elif ' ' in status: + status = status.strip() + code = int(status.split()[0]) + else: + raise ValueError('String status line without a reason phrase.') + if not 100 <= code <= 999: raise ValueError('Status code out of range.') + self._status_code = code + self._status_line = str(status or ('%d Unknown' % code)) + + def _get_status(self): + return self._status_line + + status = property(_get_status, _set_status, None, + ''' A writeable property to change the HTTP response status. It accepts + either a numeric code (100-999) or a string with a custom reason + phrase (e.g. "404 Brain not found"). Both :data:`status_line` and + :data:`status_code` are updated accordingly. The return value is + always a status string. ''') + del _get_status, _set_status + + @property + def headers(self): + ''' An instance of :class:`HeaderDict`, a case-insensitive dict-like + view on the response headers. ''' + hdict = HeaderDict() + hdict.dict = self._headers + return hdict + + def __contains__(self, name): return _hkey(name) in self._headers + def __delitem__(self, name): del self._headers[_hkey(name)] + def __getitem__(self, name): return self._headers[_hkey(name)][-1] + def __setitem__(self, name, value): self._headers[_hkey(name)] = [str(value)] + + def get_header(self, name, default=None): + ''' Return the value of a previously defined header. If there is no + header with that name, return a default value. ''' + return self._headers.get(_hkey(name), [default])[-1] + + def set_header(self, name, value): + ''' Create a new response header, replacing any previously defined + headers with the same name. ''' + self._headers[_hkey(name)] = [str(value)] + + def add_header(self, name, value): + ''' Add an additional response header, not removing duplicates. ''' + self._headers.setdefault(_hkey(name), []).append(str(value)) + + def iter_headers(self): + ''' Yield (header, value) tuples, skipping headers that are not + allowed with the current response status code. ''' + return self.headerlist + + def wsgiheader(self): + depr('The wsgiheader method is deprecated. See headerlist.') #0.10 + return self.headerlist + + @property + def headerlist(self): + ''' WSGI conform list of (header, value) tuples. ''' + out = [] + headers = self._headers.items() + if self._status_code in self.bad_headers: + bad_headers = self.bad_headers[self._status_code] + headers = [h for h in headers if h[0] not in bad_headers] + out += [(name, val) for name, vals in headers for val in vals] + if self._cookies: + for c in self._cookies.values(): + out.append(('Set-Cookie', c.OutputString())) + return out + + content_type = HeaderProperty('Content-Type') + content_length = HeaderProperty('Content-Length', reader=int) + + @property + def charset(self): + """ Return the charset specified in the content-type header (default: utf8). 
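+            For ``text/html; charset=latin1`` this returns ``latin1``; without
+            an explicit charset parameter, ``UTF-8`` is assumed.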
""" + if 'charset=' in self.content_type: + return self.content_type.split('charset=')[-1].split(';')[0].strip() + return 'UTF-8' + + @property + def COOKIES(self): + """ A dict-like SimpleCookie instance. This should not be used directly. + See :meth:`set_cookie`. """ + depr('The COOKIES dict is deprecated. Use `set_cookie()` instead.') # 0.10 + if not self._cookies: + self._cookies = SimpleCookie() + return self._cookies + + def set_cookie(self, name, value, secret=None, **options): + ''' Create a new cookie or replace an old one. If the `secret` parameter is + set, create a `Signed Cookie` (described below). + + :param name: the name of the cookie. + :param value: the value of the cookie. + :param secret: a signature key required for signed cookies. + + Additionally, this method accepts all RFC 2109 attributes that are + supported by :class:`cookie.Morsel`, including: + + :param max_age: maximum age in seconds. (default: None) + :param expires: a datetime object or UNIX timestamp. (default: None) + :param domain: the domain that is allowed to read the cookie. + (default: current domain) + :param path: limits the cookie to a given path (default: current path) + :param secure: limit the cookie to HTTPS connections (default: off). + :param httponly: prevents client-side javascript to read this cookie + (default: off, requires Python 2.6 or newer). + + If neither `expires` nor `max_age` is set (default), the cookie will + expire at the end of the browser session (as soon as the browser + window is closed). + + Signed cookies may store any pickle-able object and are + cryptographically signed to prevent manipulation. Keep in mind that + cookies are limited to 4kb in most browsers. + + Warning: Signed cookies are not encrypted (the client can still see + the content) and not copy-protected (the client can restore an old + cookie). The main intention is to make pickling and unpickling + save, not to store secret information at client side. + ''' + if not self._cookies: + self._cookies = SimpleCookie() + + if secret: + value = touni(cookie_encode((name, value), secret)) + elif not isinstance(value, basestring): + raise TypeError('Secret key missing for non-string Cookie.') + + if len(value) > 4096: raise ValueError('Cookie value to long.') + self._cookies[name] = value + + for key, value in options.items(): + if key == 'max_age': + if isinstance(value, timedelta): + value = value.seconds + value.days * 24 * 3600 + if key == 'expires': + if isinstance(value, (datedate, datetime)): + value = value.timetuple() + elif isinstance(value, (int, float)): + value = time.gmtime(value) + value = time.strftime("%a, %d %b %Y %H:%M:%S GMT", value) + self._cookies[name][key.replace('_', '-')] = value + + def delete_cookie(self, key, **kwargs): + ''' Delete a cookie. Be sure to use the same `domain` and `path` + settings as used to create the cookie. ''' + kwargs['max_age'] = -1 + kwargs['expires'] = 0 + self.set_cookie(key, '', **kwargs) + + def __repr__(self): + out = '' + for name, value in self.headerlist: + out += '%s: %s\n' % (name.title(), value.strip()) + return out + +#: Thread-local storage for :class:`LocalRequest` and :class:`LocalResponse` +#: attributes. 
+_lctx = threading.local()
+
+def local_property(name):
+    def fget(self):
+        try:
+            return getattr(_lctx, name)
+        except AttributeError:
+            raise RuntimeError("Request context not initialized.")
+    def fset(self, value): setattr(_lctx, name, value)
+    def fdel(self): delattr(_lctx, name)
+    return property(fget, fset, fdel,
+        'Thread-local property stored in :data:`_lctx.%s`' % name)
+
+
+class LocalRequest(BaseRequest):
+    ''' A thread-local subclass of :class:`BaseRequest` with a different
+        set of attributes for each thread. There is usually only one global
+        instance of this class (:data:`request`). If accessed during a
+        request/response cycle, this instance always refers to the *current*
+        request (even on a multithreaded server). '''
+    bind = BaseRequest.__init__
+    environ = local_property('request_environ')
+
+
+class LocalResponse(BaseResponse):
+    ''' A thread-local subclass of :class:`BaseResponse` with a different
+        set of attributes for each thread. There is usually only one global
+        instance of this class (:data:`response`). Its attributes are used
+        to build the HTTP response at the end of the request/response cycle.
+    '''
+    bind = BaseResponse.__init__
+    _status_line = local_property('response_status_line')
+    _status_code = local_property('response_status_code')
+    _cookies     = local_property('response_cookies')
+    _headers     = local_property('response_headers')
+    body         = local_property('response_body')
+
+Request = BaseRequest
+Response = BaseResponse
+
+class HTTPResponse(Response, BottleException):
+    def __init__(self, body='', status=None, header=None, **headers):
+        if header or 'output' in headers:
+            depr('Call signature changed (for the better)')
+            if header: headers.update(header)
+            if 'output' in headers: body = headers.pop('output')
+        super(HTTPResponse, self).__init__(body, status, **headers)
+
+    def apply(self, response):
+        response._status_code = self._status_code
+        response._status_line = self._status_line
+        response._headers = self._headers
+        response._cookies = self._cookies
+        response.body = self.body
+
+    def _output(self, value=None):
+        depr('Use HTTPResponse.body instead of HTTPResponse.output')
+        if value is None: return self.body
+        self.body = value
+
+    output = property(_output, _output, doc='Alias for .body')
+
+class HTTPError(HTTPResponse):
+    default_status = 500
+    def __init__(self, status=None, body=None, exception=None, traceback=None, header=None, **headers):
+        self.exception = exception
+        self.traceback = traceback
+        super(HTTPError, self).__init__(body, status, header, **headers)
+
+
+
+
+
+###############################################################################
+# Plugins ######################################################################
+###############################################################################
+
+class PluginError(BottleException): pass
+
+class JSONPlugin(object):
+    name = 'json'
+    api  = 2
+
+    def __init__(self, json_dumps=json_dumps):
+        self.json_dumps = json_dumps
+
+    def apply(self, callback, route):
+        dumps = self.json_dumps
+        if not dumps: return callback
+        def wrapper(*a, **ka):
+            rv = callback(*a, **ka)
+            if isinstance(rv, dict):
+                # Attempt to serialize, raises exception on failure
+                json_response = dumps(rv)
+                # Set content type only if serialization successful
+                response.content_type = 'application/json'
+                return json_response
+            return rv
+        return wrapper
+
+
+class HooksPlugin(object):
+    name = 'hooks'
+    api  = 2
+
+    _names = 'before_request', 'after_request', 'app_reset'
+
+    def __init__(self):
+        self.hooks =
dict((name, []) for name in self._names) + self.app = None + + def _empty(self): + return not (self.hooks['before_request'] or self.hooks['after_request']) + + def setup(self, app): + self.app = app + + def add(self, name, func): + ''' Attach a callback to a hook. ''' + was_empty = self._empty() + self.hooks.setdefault(name, []).append(func) + if self.app and was_empty and not self._empty(): self.app.reset() + + def remove(self, name, func): + ''' Remove a callback from a hook. ''' + was_empty = self._empty() + if name in self.hooks and func in self.hooks[name]: + self.hooks[name].remove(func) + if self.app and not was_empty and self._empty(): self.app.reset() + + def trigger(self, name, *a, **ka): + ''' Trigger a hook and return a list of results. ''' + hooks = self.hooks[name] + if ka.pop('reversed', False): hooks = hooks[::-1] + return [hook(*a, **ka) for hook in hooks] + + def apply(self, callback, route): + if self._empty(): return callback + def wrapper(*a, **ka): + self.trigger('before_request') + rv = callback(*a, **ka) + self.trigger('after_request', reversed=True) + return rv + return wrapper + + +class TemplatePlugin(object): + ''' This plugin applies the :func:`view` decorator to all routes with a + `template` config parameter. If the parameter is a tuple, the second + element must be a dict with additional options (e.g. `template_engine`) + or default variables for the template. ''' + name = 'template' + api = 2 + + def apply(self, callback, route): + conf = route.config.get('template') + if isinstance(conf, (tuple, list)) and len(conf) == 2: + return view(conf[0], **conf[1])(callback) + elif isinstance(conf, str) and 'template_opts' in route.config: + depr('The `template_opts` parameter is deprecated.') #0.9 + return view(conf, **route.config['template_opts'])(callback) + elif isinstance(conf, str): + return view(conf)(callback) + else: + return callback + + +#: Not a plugin, but part of the plugin API. TODO: Find a better place. +class _ImportRedirect(object): + def __init__(self, name, impmask): + ''' Create a virtual package that redirects imports (see PEP 302). ''' + self.name = name + self.impmask = impmask + self.module = sys.modules.setdefault(name, imp.new_module(name)) + self.module.__dict__.update({'__file__': __file__, '__path__': [], + '__all__': [], '__loader__': self}) + sys.meta_path.append(self) + + def find_module(self, fullname, path=None): + if '.' not in fullname: return + packname, modname = fullname.rsplit('.', 1) + if packname != self.name: return + return self + + def load_module(self, fullname): + if fullname in sys.modules: return sys.modules[fullname] + packname, modname = fullname.rsplit('.', 1) + realname = self.impmask % modname + __import__(realname) + module = sys.modules[fullname] = sys.modules[realname] + setattr(self.module, modname, module) + module.__loader__ = self + return module + + + + + + +############################################################################### +# Common Utilities ############################################################# +############################################################################### + + +class MultiDict(DictMixin): + """ This dict stores multiple values per key, but behaves exactly like a + normal dict in that it returns only the newest value for any given key. + There are special methods available to access the full list of values. 
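+
+        Example::
+
+            d = MultiDict(a=1)
+            d['a'] = 2           # appends a second value; d['a'] is now 2
+            d.getall('a')        # -> [1, 2]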
+ """ + + def __init__(self, *a, **k): + self.dict = dict((k, [v]) for (k, v) in dict(*a, **k).items()) + + def __len__(self): return len(self.dict) + def __iter__(self): return iter(self.dict) + def __contains__(self, key): return key in self.dict + def __delitem__(self, key): del self.dict[key] + def __getitem__(self, key): return self.dict[key][-1] + def __setitem__(self, key, value): self.append(key, value) + def keys(self): return self.dict.keys() + + if py3k: + def values(self): return (v[-1] for v in self.dict.values()) + def items(self): return ((k, v[-1]) for k, v in self.dict.items()) + def allitems(self): + return ((k, v) for k, vl in self.dict.items() for v in vl) + iterkeys = keys + itervalues = values + iteritems = items + iterallitems = allitems + + else: + def values(self): return [v[-1] for v in self.dict.values()] + def items(self): return [(k, v[-1]) for k, v in self.dict.items()] + def iterkeys(self): return self.dict.iterkeys() + def itervalues(self): return (v[-1] for v in self.dict.itervalues()) + def iteritems(self): + return ((k, v[-1]) for k, v in self.dict.iteritems()) + def iterallitems(self): + return ((k, v) for k, vl in self.dict.iteritems() for v in vl) + def allitems(self): + return [(k, v) for k, vl in self.dict.iteritems() for v in vl] + + def get(self, key, default=None, index=-1, type=None): + ''' Return the most recent value for a key. + + :param default: The default value to be returned if the key is not + present or the type conversion fails. + :param index: An index for the list of available values. + :param type: If defined, this callable is used to cast the value + into a specific type. Exception are suppressed and result in + the default value to be returned. + ''' + try: + val = self.dict[key][index] + return type(val) if type else val + except Exception: + pass + return default + + def append(self, key, value): + ''' Add a new value to the list of values for this key. ''' + self.dict.setdefault(key, []).append(value) + + def replace(self, key, value): + ''' Replace the list of values with a single value. ''' + self.dict[key] = [value] + + def getall(self, key): + ''' Return a (possibly empty) list of values for a key. ''' + return self.dict.get(key) or [] + + #: Aliases for WTForms to mimic other multi-dict APIs (Django) + getone = get + getlist = getall + + + +class FormsDict(MultiDict): + ''' This :class:`MultiDict` subclass is used to store request form data. + Additionally to the normal dict-like item access methods (which return + unmodified data as native strings), this container also supports + attribute-like access to its values. Attributes are automatically de- + or recoded to match :attr:`input_encoding` (default: 'utf8'). Missing + attributes default to an empty string. ''' + + #: Encoding used for attribute values. + input_encoding = 'utf8' + #: If true (default), unicode strings are first encoded with `latin1` + #: and then decoded to match :attr:`input_encoding`. + recode_unicode = True + + def _fix(self, s, encoding=None): + if isinstance(s, unicode) and self.recode_unicode: # Python 3 WSGI + s = s.encode('latin1') + if isinstance(s, bytes): # Python 2 WSGI + return s.decode(encoding or self.input_encoding) + return s + + def decode(self, encoding=None): + ''' Returns a copy with all keys and values de- or recoded to match + :attr:`input_encoding`. Some libraries (e.g. WTForms) want a + unicode dictionary. 
''' + copy = FormsDict() + enc = copy.input_encoding = encoding or self.input_encoding + copy.recode_unicode = False + for key, value in self.allitems(): + copy.append(self._fix(key, enc), self._fix(value, enc)) + return copy + + def getunicode(self, name, default=None, encoding=None): + try: + return self._fix(self[name], encoding) + except (UnicodeError, KeyError): + return default + + def __getattr__(self, name, default=unicode()): + # Without this guard, pickle generates a cryptic TypeError: + if name.startswith('__') and name.endswith('__'): + return super(FormsDict, self).__getattr__(name) + return self.getunicode(name, default=default) + + +class HeaderDict(MultiDict): + """ A case-insensitive version of :class:`MultiDict` that defaults to + replace the old value instead of appending it. """ + + def __init__(self, *a, **ka): + self.dict = {} + if a or ka: self.update(*a, **ka) + + def __contains__(self, key): return _hkey(key) in self.dict + def __delitem__(self, key): del self.dict[_hkey(key)] + def __getitem__(self, key): return self.dict[_hkey(key)][-1] + def __setitem__(self, key, value): self.dict[_hkey(key)] = [str(value)] + def append(self, key, value): + self.dict.setdefault(_hkey(key), []).append(str(value)) + def replace(self, key, value): self.dict[_hkey(key)] = [str(value)] + def getall(self, key): return self.dict.get(_hkey(key)) or [] + def get(self, key, default=None, index=-1): + return MultiDict.get(self, _hkey(key), default, index) + def filter(self, names): + for name in [_hkey(n) for n in names]: + if name in self.dict: + del self.dict[name] + + +class WSGIHeaderDict(DictMixin): + ''' This dict-like class wraps a WSGI environ dict and provides convenient + access to HTTP_* fields. Keys and values are native strings + (2.x bytes or 3.x unicode) and keys are case-insensitive. If the WSGI + environment contains non-native string values, these are de- or encoded + using a lossless 'latin1' character set. + + The API will remain stable even on changes to the relevant PEPs. + Currently PEP 333, 444 and 3333 are supported. (PEP 444 is the only one + that uses non-native strings.) + ''' + #: List of keys that do not have a ``HTTP_`` prefix. + cgikeys = ('CONTENT_TYPE', 'CONTENT_LENGTH') + + def __init__(self, environ): + self.environ = environ + + def _ekey(self, key): + ''' Translate header field name to CGI/WSGI environ key. ''' + key = key.replace('-','_').upper() + if key in self.cgikeys: + return key + return 'HTTP_' + key + + def raw(self, key, default=None): + ''' Return the header value as is (may be bytes or unicode). ''' + return self.environ.get(self._ekey(key), default) + + def __getitem__(self, key): + return tonat(self.environ[self._ekey(key)], 'latin1') + + def __setitem__(self, key, value): + raise TypeError("%s is read-only." % self.__class__) + + def __delitem__(self, key): + raise TypeError("%s is read-only." % self.__class__) + + def __iter__(self): + for key in self.environ: + if key[:5] == 'HTTP_': + yield key[5:].replace('_', '-').title() + elif key in self.cgikeys: + yield key.replace('_', '-').title() + + def keys(self): return [x for x in self] + def __len__(self): return len(self.keys()) + def __contains__(self, key): return self._ekey(key) in self.environ + + +class ConfigDict(dict): + ''' A dict-subclass with some extras: You can access keys like attributes. + Uppercase attributes create new ConfigDicts and act as name-spaces. + Other missing attributes return None. Calling a ConfigDict updates its + values and returns itself. 
+ + >>> cfg = ConfigDict() + >>> cfg.Namespace.value = 5 + >>> cfg.OtherNamespace(a=1, b=2) + >>> cfg + {'Namespace': {'value': 5}, 'OtherNamespace': {'a': 1, 'b': 2}} + ''' + + def __getattr__(self, key): + if key not in self and key[0].isupper(): + self[key] = ConfigDict() + return self.get(key) + + def __setattr__(self, key, value): + if hasattr(dict, key): + raise AttributeError('Read-only attribute.') + if key in self and self[key] and isinstance(self[key], ConfigDict): + raise AttributeError('Non-empty namespace attribute.') + self[key] = value + + def __delattr__(self, key): + if key in self: del self[key] + + def __call__(self, *a, **ka): + for key, value in dict(*a, **ka).items(): setattr(self, key, value) + return self + + +class AppStack(list): + """ A stack-like list. Calling it returns the head of the stack. """ + + def __call__(self): + """ Return the current default application. """ + return self[-1] + + def push(self, value=None): + """ Add a new :class:`Bottle` instance to the stack """ + if not isinstance(value, Bottle): + value = Bottle() + self.append(value) + return value + + +class WSGIFileWrapper(object): + + def __init__(self, fp, buffer_size=1024*64): + self.fp, self.buffer_size = fp, buffer_size + for attr in ('fileno', 'close', 'read', 'readlines', 'tell', 'seek'): + if hasattr(fp, attr): setattr(self, attr, getattr(fp, attr)) + + def __iter__(self): + buff, read = self.buffer_size, self.read + while True: + part = read(buff) + if not part: return + yield part + + +class ResourceManager(object): + ''' This class manages a list of search paths and helps to find and open + application-bound resources (files). + + :param base: default value for :meth:`add_path` calls. + :param opener: callable used to open resources. + :param cachemode: controls which lookups are cached. One of 'all', + 'found' or 'none'. + ''' + + def __init__(self, base='./', opener=open, cachemode='all'): + self.opener = open + self.base = base + self.cachemode = cachemode + + #: A list of search paths. See :meth:`add_path` for details. + self.path = [] + #: A cache for resolved paths. ``res.cache.clear()`` clears the cache. + self.cache = {} + + def add_path(self, path, base=None, index=None, create=False): + ''' Add a new path to the list of search paths. Return False if the + path does not exist. + + :param path: The new search path. Relative paths are turned into + an absolute and normalized form. If the path looks like a file + (not ending in `/`), the filename is stripped off. + :param base: Path used to absolutize relative search paths. + Defaults to :attr:`base` which defaults to ``os.getcwd()``. + :param index: Position within the list of search paths. Defaults + to last index (appends to the list). + + The `base` parameter makes it easy to reference files installed + along with a python module or package:: + + res.add_path('./resources/', __file__) + ''' + base = os.path.abspath(os.path.dirname(base or self.base)) + path = os.path.abspath(os.path.join(base, os.path.dirname(path))) + path += os.sep + if path in self.path: + self.path.remove(path) + if create and not os.path.isdir(path): + os.makedirs(path) + if index is None: + self.path.append(path) + else: + self.path.insert(index, path) + self.cache.clear() + return os.path.exists(path) + + def __iter__(self): + ''' Iterate over all existing files in all registered paths. 
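+            Example (an editor's illustration; assumes a ./static/ directory)::
+
+                rm = ResourceManager()
+                rm.add_path('./static/')
+                for fname in rm:    # walks ./static/ recursively
+                    print(fname)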
+        '''
+        search = self.path[:]
+        while search:
+            path = search.pop()
+            if not os.path.isdir(path): continue
+            for name in os.listdir(path):
+                full = os.path.join(path, name)
+                if os.path.isdir(full): search.append(full)
+                else: yield full
+
+    def lookup(self, name):
+        ''' Search for a resource and return an absolute file path, or `None`.
+
+            The :attr:`path` list is searched in order. The first match is
+            returned. Symlinks are followed. The result is cached to speed up
+            future lookups. '''
+        if name not in self.cache or DEBUG:
+            for path in self.path:
+                fpath = os.path.join(path, name)
+                if os.path.isfile(fpath):
+                    if self.cachemode in ('all', 'found'):
+                        self.cache[name] = fpath
+                    return fpath
+            if self.cachemode == 'all':
+                self.cache[name] = None
+        return self.cache[name]
+
+    def open(self, name, mode='r', *args, **kwargs):
+        ''' Find a resource and return a file object, or raise IOError. '''
+        fname = self.lookup(name)
+        if not fname: raise IOError("Resource %r not found." % name)
+        return self.opener(fname, mode=mode, *args, **kwargs)
+
+
+
+
+
+
+###############################################################################
+# Application Helper ###########################################################
+###############################################################################
+
+
+def abort(code=500, text='Unknown Error: Application stopped.'):
+    """ Abort execution and cause an HTTP error. """
+    raise HTTPError(code, text)
+
+
+def redirect(url, code=None):
+    """ Abort execution and cause a 303 or 302 redirect, depending on
+        the HTTP protocol version. """
+    if code is None:
+        code = 303 if request.get('SERVER_PROTOCOL') == "HTTP/1.1" else 302
+    location = urljoin(request.url, url)
+    res = HTTPResponse("", status=code, Location=location)
+    if response._cookies:
+        res._cookies = response._cookies
+    raise res
+
+
+def _file_iter_range(fp, offset, bytes, maxread=1024*1024):
+    ''' Yield chunks from a range in a file. No chunk is bigger than maxread.'''
+    fp.seek(offset)
+    while bytes > 0:
+        part = fp.read(min(bytes, maxread))
+        if not part: break
+        bytes -= len(part)
+        yield part
+
+
+def static_file(filename, root, mimetype='auto', download=False):
+    """ Open a file in a safe way and return :exc:`HTTPResponse` with status
+        code 200, 206, 304, 403 or 404. Set the Content-Type, Content-Encoding,
+        Content-Length and Last-Modified headers. Obey the If-Modified-Since
+        header and HEAD requests.
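+
+        Typical usage (an editor's sketch; the route path and root directory
+        are assumptions)::
+
+            @route('/download/<filename:path>')
+            def download(filename):
+                return static_file(filename, root='/srv/files', download=True)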
+ """ + root = os.path.abspath(root) + os.sep + filename = os.path.abspath(os.path.join(root, filename.strip('/\\'))) + headers = dict() + + if not filename.startswith(root): + return HTTPError(403, "Access denied.") + if not os.path.exists(filename) or not os.path.isfile(filename): + return HTTPError(404, "File does not exist.") + if not os.access(filename, os.R_OK): + return HTTPError(403, "You do not have permission to access this file.") + + if mimetype == 'auto': + mimetype, encoding = mimetypes.guess_type(filename) + if mimetype: headers['Content-Type'] = mimetype + if encoding: headers['Content-Encoding'] = encoding + elif mimetype: + headers['Content-Type'] = mimetype + + if download: + download = os.path.basename(filename if download == True else download) + headers['Content-Disposition'] = 'attachment; filename="%s"' % download + + stats = os.stat(filename) + headers['Content-Length'] = clen = stats.st_size + lm = time.strftime("%a, %d %b %Y %H:%M:%S GMT", time.gmtime(stats.st_mtime)) + headers['Last-Modified'] = lm + + ims = request.environ.get('HTTP_IF_MODIFIED_SINCE') + if ims: + ims = parse_date(ims.split(";")[0].strip()) + if ims is not None and ims >= int(stats.st_mtime): + headers['Date'] = time.strftime("%a, %d %b %Y %H:%M:%S GMT", time.gmtime()) + return HTTPResponse(status=304, **headers) + + body = '' if request.method == 'HEAD' else open(filename, 'rb') + + headers["Accept-Ranges"] = "bytes" + ranges = request.environ.get('HTTP_RANGE') + if 'HTTP_RANGE' in request.environ: + ranges = list(parse_range_header(request.environ['HTTP_RANGE'], clen)) + if not ranges: + return HTTPError(416, "Requested Range Not Satisfiable") + offset, end = ranges[0] + headers["Content-Range"] = "bytes %d-%d/%d" % (offset, end-1, clen) + headers["Content-Length"] = str(end-offset) + if body: body = _file_iter_range(body, offset, end-offset) + return HTTPResponse(body, status=206, **headers) + return HTTPResponse(body, **headers) + + + + + + +############################################################################### +# HTTP Utilities and MISC (TODO) ############################################### +############################################################################### + + +def debug(mode=True): + """ Change the debug level. + There is only one debug level supported at the moment.""" + global DEBUG + DEBUG = bool(mode) + + +def parse_date(ims): + """ Parse rfc1123, rfc850 and asctime timestamps and return UTC epoch. """ + try: + ts = email.utils.parsedate_tz(ims) + return time.mktime(ts[:8] + (0,)) - (ts[9] or 0) - time.timezone + except (TypeError, ValueError, IndexError, OverflowError): + return None + + +def parse_auth(header): + """ Parse rfc2617 HTTP authentication header string (basic) and return (user,pass) tuple or None""" + try: + method, data = header.split(None, 1) + if method.lower() == 'basic': + user, pwd = touni(base64.b64decode(tob(data))).split(':',1) + return user, pwd + except (KeyError, ValueError): + return None + +def parse_range_header(header, maxlen=0): + ''' Yield (start, end) ranges parsed from a HTTP Range header. Skip + unsatisfiable ranges. 
+    '''
+    if not header or header[:6] != 'bytes=': return
+    ranges = [r.split('-', 1) for r in header[6:].split(',') if '-' in r]
+    for start, end in ranges:
+        try:
+            if not start:  # bytes=-100    -> last 100 bytes
+                start, end = max(0, maxlen-int(end)), maxlen
+            elif not end:  # bytes=100-    -> all but the first 100 bytes
+                start, end = int(start), maxlen
+            else:          # bytes=100-200 -> bytes 100-200 (inclusive)
+                start, end = int(start), min(int(end)+1, maxlen)
+            if 0 <= start < end <= maxlen:
+                yield start, end
+        except ValueError:
+            pass
+
+def _parse_qsl(qs):
+    r = []
+    for pair in qs.replace(';','&').split('&'):
+        if not pair: continue
+        nv = pair.split('=', 1)
+        if len(nv) != 2: nv.append('')
+        key = urlunquote(nv[0].replace('+', ' '))
+        value = urlunquote(nv[1].replace('+', ' '))
+        r.append((key, value))
+    return r
+
+def _lscmp(a, b):
+    ''' Compares two strings in a cryptographically safe way:
+        Runtime is not affected by the length of the common prefix. '''
+    return not sum(0 if x==y else 1 for x, y in zip(a, b)) and len(a) == len(b)
+
+
+def cookie_encode(data, key):
+    ''' Encode and sign a pickle-able object. Return a (byte) string '''
+    msg = base64.b64encode(pickle.dumps(data, -1))
+    sig = base64.b64encode(hmac.new(tob(key), msg).digest())
+    return tob('!') + sig + tob('?') + msg
+
+
+def cookie_decode(data, key):
+    ''' Verify and decode an encoded string. Return an object or None.'''
+    data = tob(data)
+    if cookie_is_encoded(data):
+        sig, msg = data.split(tob('?'), 1)
+        if _lscmp(sig[1:], base64.b64encode(hmac.new(tob(key), msg).digest())):
+            return pickle.loads(base64.b64decode(msg))
+    return None
+
+
+def cookie_is_encoded(data):
+    ''' Return True if the argument looks like an encoded cookie.'''
+    return bool(data.startswith(tob('!')) and tob('?') in data)
+
+
+def html_escape(string):
+    ''' Escape HTML special characters ``&<>`` and quotes ``'"``. '''
+    return string.replace('&','&amp;').replace('<','&lt;').replace('>','&gt;')\
+                 .replace('"','&quot;').replace("'",'&#039;')
+
+
+def html_quote(string):
+    ''' Escape and quote a string to be used as an HTTP attribute.'''
+    return '"%s"' % html_escape(string).replace('\n','&#10;')\
+                    .replace('\r','&#13;').replace('\t','&#9;')
+
+
+def yieldroutes(func):
+    """ Return a generator for routes that match the signature (name, args)
+        of the func parameter. This may yield more than one route if the function
+        takes optional keyword arguments. The output is best described by example::
+
+            a()         -> '/a'
+            b(x, y)     -> '/b/:x/:y'
+            c(x, y=5)   -> '/c/:x' and '/c/:x/:y'
+            d(x=5, y=6) -> '/d' and '/d/:x' and '/d/:x/:y'
+    """
+    import inspect  # Expensive module. Only import if necessary.
+    path = '/' + func.__name__.replace('__','/').lstrip('/')
+    spec = inspect.getargspec(func)
+    argc = len(spec[0]) - len(spec[3] or [])
+    path += ('/:%s' * argc) % tuple(spec[0][:argc])
+    yield path
+    for arg in spec[0][argc:]:
+        path += '/:%s' % arg
+        yield path
+
+
+def path_shift(script_name, path_info, shift=1):
+    ''' Shift path fragments from PATH_INFO to SCRIPT_NAME and vice versa.
+
+        :return: The modified paths.
+        :param script_name: The SCRIPT_NAME path.
+        :param path_info: The PATH_INFO path.
+        :param shift: The number of path fragments to shift. May be negative
+            to change the shift direction. (default: 1)
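+
+        Example (an editor's illustration)::
+
+            path_shift('/a', '/b/c')        # -> ('/a/b', '/c')
+            path_shift('/a/b', '/c', -1)    # -> ('/a', '/b/c')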
+    '''
+    if shift == 0: return script_name, path_info
+    pathlist = path_info.strip('/').split('/')
+    scriptlist = script_name.strip('/').split('/')
+    if pathlist and pathlist[0] == '': pathlist = []
+    if scriptlist and scriptlist[0] == '': scriptlist = []
+    if shift > 0 and shift <= len(pathlist):
+        moved = pathlist[:shift]
+        scriptlist = scriptlist + moved
+        pathlist = pathlist[shift:]
+    elif shift < 0 and shift >= -len(scriptlist):
+        moved = scriptlist[shift:]
+        pathlist = moved + pathlist
+        scriptlist = scriptlist[:shift]
+    else:
+        empty = 'SCRIPT_NAME' if shift < 0 else 'PATH_INFO'
+        raise AssertionError("Cannot shift. Nothing left from %s" % empty)
+    new_script_name = '/' + '/'.join(scriptlist)
+    new_path_info = '/' + '/'.join(pathlist)
+    if path_info.endswith('/') and pathlist: new_path_info += '/'
+    return new_script_name, new_path_info
+
+
+def validate(**vkargs):
+    """
+    Validates and manipulates keyword arguments by user-defined callables.
+    Handles ValueError and missing arguments by raising HTTPError(403).
+    """
+    depr('Use route wildcard filters instead.')
+    def decorator(func):
+        @functools.wraps(func)
+        def wrapper(*args, **kargs):
+            for key, value in vkargs.items():
+                if key not in kargs:
+                    abort(403, 'Missing parameter: %s' % key)
+                try:
+                    kargs[key] = value(kargs[key])
+                except ValueError:
+                    abort(403, 'Wrong parameter format for: %s' % key)
+            return func(*args, **kargs)
+        return wrapper
+    return decorator
+
+
+def auth_basic(check, realm="private", text="Access denied"):
+    ''' Callback decorator to require HTTP auth (basic).
+        TODO: Add route(check_auth=...) parameter. '''
+    def decorator(func):
+        def wrapper(*a, **ka):
+            user, password = request.auth or (None, None)
+            if user is None or not check(user, password):
+                response.headers['WWW-Authenticate'] = 'Basic realm="%s"' % realm
+                return HTTPError(401, text)
+            return func(*a, **ka)
+        return wrapper
+    return decorator
+
+
+# Shortcuts for common Bottle methods.
+# They all refer to the current default application.
+
+def make_default_app_wrapper(name):
+    ''' Return a callable that relays calls to the current default app.
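+
+        Example (an editor's sketch; ``route`` below is one such wrapper)::
+
+            @route('/hello/<name>')
+            def hello(name):
+                return 'Hello %s!' % name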
''' + @functools.wraps(getattr(Bottle, name)) + def wrapper(*a, **ka): + return getattr(app(), name)(*a, **ka) + return wrapper + +route = make_default_app_wrapper('route') +get = make_default_app_wrapper('get') +post = make_default_app_wrapper('post') +put = make_default_app_wrapper('put') +delete = make_default_app_wrapper('delete') +error = make_default_app_wrapper('error') +mount = make_default_app_wrapper('mount') +hook = make_default_app_wrapper('hook') +install = make_default_app_wrapper('install') +uninstall = make_default_app_wrapper('uninstall') +url = make_default_app_wrapper('get_url') + + + + + + + +############################################################################### +# Server Adapter ############################################################### +############################################################################### + + +class ServerAdapter(object): + quiet = False + def __init__(self, host='127.0.0.1', port=8080, **config): + self.options = config + self.host = host + self.port = int(port) + + def run(self, handler): # pragma: no cover + pass + + def __repr__(self): + args = ', '.join(['%s=%s'%(k,repr(v)) for k, v in self.options.items()]) + return "%s(%s)" % (self.__class__.__name__, args) + + +class CGIServer(ServerAdapter): + quiet = True + def run(self, handler): # pragma: no cover + from wsgiref.handlers import CGIHandler + def fixed_environ(environ, start_response): + environ.setdefault('PATH_INFO', '') + return handler(environ, start_response) + CGIHandler().run(fixed_environ) + + +class FlupFCGIServer(ServerAdapter): + def run(self, handler): # pragma: no cover + import flup.server.fcgi + self.options.setdefault('bindAddress', (self.host, self.port)) + flup.server.fcgi.WSGIServer(handler, **self.options).run() + + +class WSGIRefServer(ServerAdapter): + def run(self, handler): # pragma: no cover + from wsgiref.simple_server import make_server, WSGIRequestHandler + if self.quiet: + class QuietHandler(WSGIRequestHandler): + def log_request(*args, **kw): pass + self.options['handler_class'] = QuietHandler + srv = make_server(self.host, self.port, handler, **self.options) + srv.serve_forever() + + +class CherryPyServer(ServerAdapter): + def run(self, handler): # pragma: no cover + from cherrypy import wsgiserver + server = wsgiserver.CherryPyWSGIServer((self.host, self.port), handler) + try: + server.start() + finally: + server.stop() + + +class WaitressServer(ServerAdapter): + def run(self, handler): + from waitress import serve + serve(handler, host=self.host, port=self.port) + + +class PasteServer(ServerAdapter): + def run(self, handler): # pragma: no cover + from paste import httpserver + if not self.quiet: + from paste.translogger import TransLogger + handler = TransLogger(handler) + httpserver.serve(handler, host=self.host, port=str(self.port), + **self.options) + + +class MeinheldServer(ServerAdapter): + def run(self, handler): + from meinheld import server + server.listen((self.host, self.port)) + server.run(handler) + + +class FapwsServer(ServerAdapter): + """ Extremely fast webserver using libev. See http://www.fapws.org/ """ + def run(self, handler): # pragma: no cover + import fapws._evwsgi as evwsgi + from fapws import base, config + port = self.port + if float(config.SERVER_IDENT[-2:]) > 0.4: + # fapws3 silently changed its API in 0.5 + port = str(port) + evwsgi.start(self.host, port) + # fapws3 never releases the GIL. Complain upstream. I tried. No luck. 
+ if 'BOTTLE_CHILD' in os.environ and not self.quiet: + _stderr("WARNING: Auto-reloading does not work with Fapws3.\n") + _stderr(" (Fapws3 breaks python thread support)\n") + evwsgi.set_base_module(base) + def app(environ, start_response): + environ['wsgi.multiprocess'] = False + return handler(environ, start_response) + evwsgi.wsgi_cb(('', app)) + evwsgi.run() + + +class TornadoServer(ServerAdapter): + """ The super hyped asynchronous server by facebook. Untested. """ + def run(self, handler): # pragma: no cover + import tornado.wsgi, tornado.httpserver, tornado.ioloop + container = tornado.wsgi.WSGIContainer(handler) + server = tornado.httpserver.HTTPServer(container) + server.listen(port=self.port) + tornado.ioloop.IOLoop.instance().start() + + +class AppEngineServer(ServerAdapter): + """ Adapter for Google App Engine. """ + quiet = True + def run(self, handler): + from google.appengine.ext.webapp import util + # A main() function in the handler script enables 'App Caching'. + # Lets makes sure it is there. This _really_ improves performance. + module = sys.modules.get('__main__') + if module and not hasattr(module, 'main'): + module.main = lambda: util.run_wsgi_app(handler) + util.run_wsgi_app(handler) + + +class TwistedServer(ServerAdapter): + """ Untested. """ + def run(self, handler): + from twisted.web import server, wsgi + from twisted.python.threadpool import ThreadPool + from twisted.internet import reactor + thread_pool = ThreadPool() + thread_pool.start() + reactor.addSystemEventTrigger('after', 'shutdown', thread_pool.stop) + factory = server.Site(wsgi.WSGIResource(reactor, thread_pool, handler)) + reactor.listenTCP(self.port, factory, interface=self.host) + reactor.run() + + +class DieselServer(ServerAdapter): + """ Untested. """ + def run(self, handler): + from diesel.protocols.wsgi import WSGIApplication + app = WSGIApplication(handler, port=self.port) + app.run() + + +class GeventServer(ServerAdapter): + """ Untested. Options: + + * `fast` (default: False) uses libevent's http server, but has some + issues: No streaming, no pipelining, no SSL. + """ + def run(self, handler): + from gevent import wsgi, pywsgi, local + if not isinstance(_lctx, local.local): + msg = "Bottle requires gevent.monkey.patch_all() (before import)" + raise RuntimeError(msg) + if not self.options.get('fast'): wsgi = pywsgi + log = None if self.quiet else 'default' + wsgi.WSGIServer((self.host, self.port), handler, log=log).serve_forever() + + +class GunicornServer(ServerAdapter): + """ Untested. See http://gunicorn.org/configure.html for options. """ + def run(self, handler): + from gunicorn.app.base import Application + + config = {'bind': "%s:%d" % (self.host, int(self.port))} + config.update(self.options) + + class GunicornApplication(Application): + def init(self, parser, opts, args): + return config + + def load(self): + return handler + + GunicornApplication().run() + + +class EventletServer(ServerAdapter): + """ Untested """ + def run(self, handler): + from eventlet import wsgi, listen + try: + wsgi.server(listen((self.host, self.port)), handler, + log_output=(not self.quiet)) + except TypeError: + # Fallback, if we have old version of eventlet + wsgi.server(listen((self.host, self.port)), handler) + + +class RocketServer(ServerAdapter): + """ Untested. 
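+
+        Example (an editor's sketch; uses the ``server_names`` alias below)::
+
+            run(server='rocket', host='0.0.0.0', port=8080)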
""" + def run(self, handler): + from rocket import Rocket + server = Rocket((self.host, self.port), 'wsgi', { 'wsgi_app' : handler }) + server.start() + + +class BjoernServer(ServerAdapter): + """ Fast server written in C: https://github.com/jonashaag/bjoern """ + def run(self, handler): + from bjoern import run + run(handler, self.host, self.port) + + +class AutoServer(ServerAdapter): + """ Untested. """ + adapters = [WaitressServer, PasteServer, TwistedServer, CherryPyServer, WSGIRefServer] + def run(self, handler): + for sa in self.adapters: + try: + return sa(self.host, self.port, **self.options).run(handler) + except ImportError: + pass + +server_names = { + 'cgi': CGIServer, + 'flup': FlupFCGIServer, + 'wsgiref': WSGIRefServer, + 'waitress': WaitressServer, + 'cherrypy': CherryPyServer, + 'paste': PasteServer, + 'fapws3': FapwsServer, + 'tornado': TornadoServer, + 'gae': AppEngineServer, + 'twisted': TwistedServer, + 'diesel': DieselServer, + 'meinheld': MeinheldServer, + 'gunicorn': GunicornServer, + 'eventlet': EventletServer, + 'gevent': GeventServer, + 'rocket': RocketServer, + 'bjoern' : BjoernServer, + 'auto': AutoServer, +} + + + + + + +############################################################################### +# Application Control ########################################################## +############################################################################### + + +def load(target, **namespace): + """ Import a module or fetch an object from a module. + + * ``package.module`` returns `module` as a module object. + * ``pack.mod:name`` returns the module variable `name` from `pack.mod`. + * ``pack.mod:func()`` calls `pack.mod.func()` and returns the result. + + The last form accepts not only function calls, but any type of + expression. Keyword arguments passed to this function are available as + local variables. Example: ``import_string('re:compile(x)', x='[a-z]')`` + """ + module, target = target.split(":", 1) if ':' in target else (target, None) + if module not in sys.modules: __import__(module) + if not target: return sys.modules[module] + if target.isalnum(): return getattr(sys.modules[module], target) + package_name = module.split('.')[0] + namespace[package_name] = sys.modules[package_name] + return eval('%s.%s' % (module, target), namespace) + + +def load_app(target): + """ Load a bottle application from a module and make sure that the import + does not affect the current default application, but returns a separate + application object. See :func:`load` for the target parameter. """ + global NORUN; NORUN, nr_old = True, NORUN + try: + tmp = default_app.push() # Create a new "default application" + rv = load(target) # Import the target module + return rv if callable(rv) else tmp + finally: + default_app.remove(tmp) # Remove the temporary added default application + NORUN = nr_old + +_debug = debug +def run(app=None, server='wsgiref', host='127.0.0.1', port=8080, + interval=1, reloader=False, quiet=False, plugins=None, + debug=False, **kargs): + """ Start a server instance. This method blocks until the server terminates. + + :param app: WSGI application or target string supported by + :func:`load_app`. (default: :func:`default_app`) + :param server: Server adapter to use. See :data:`server_names` keys + for valid names or pass a :class:`ServerAdapter` subclass. + (default: `wsgiref`) + :param host: Server address to bind to. Pass ``0.0.0.0`` to listens on + all interfaces including the external one. (default: 127.0.0.1) + :param port: Server port to bind to. 
Values below 1024 require root + privileges. (default: 8080) + :param reloader: Start auto-reloading server? (default: False) + :param interval: Auto-reloader interval in seconds (default: 1) + :param quiet: Suppress output to stdout and stderr? (default: False) + :param options: Options passed to the server adapter. + """ + if NORUN: return + if reloader and not os.environ.get('BOTTLE_CHILD'): + try: + lockfile = None + fd, lockfile = tempfile.mkstemp(prefix='bottle.', suffix='.lock') + os.close(fd) # We only need this file to exist. We never write to it + while os.path.exists(lockfile): + args = [sys.executable] + sys.argv + environ = os.environ.copy() + environ['BOTTLE_CHILD'] = 'true' + environ['BOTTLE_LOCKFILE'] = lockfile + p = subprocess.Popen(args, env=environ) + while p.poll() is None: # Busy wait... + os.utime(lockfile, None) # I am alive! + time.sleep(interval) + if p.poll() != 3: + if os.path.exists(lockfile): os.unlink(lockfile) + sys.exit(p.poll()) + except KeyboardInterrupt: + pass + finally: + if os.path.exists(lockfile): + os.unlink(lockfile) + return + + try: + _debug(debug) + app = app or default_app() + if isinstance(app, basestring): + app = load_app(app) + if not callable(app): + raise ValueError("Application is not callable: %r" % app) + + for plugin in plugins or []: + app.install(plugin) + + if server in server_names: + server = server_names.get(server) + if isinstance(server, basestring): + server = load(server) + if isinstance(server, type): + server = server(host=host, port=port, **kargs) + if not isinstance(server, ServerAdapter): + raise ValueError("Unknown or unsupported server: %r" % server) + + server.quiet = server.quiet or quiet + if not server.quiet: + _stderr("Bottle v%s server starting up (using %s)...\n" % (__version__, repr(server))) + _stderr("Listening on http://%s:%d/\n" % (server.host, server.port)) + _stderr("Hit Ctrl-C to quit.\n\n") + + if reloader: + lockfile = os.environ.get('BOTTLE_LOCKFILE') + bgcheck = FileCheckerThread(lockfile, interval) + with bgcheck: + server.run(app) + if bgcheck.status == 'reload': + sys.exit(3) + else: + server.run(app) + except KeyboardInterrupt: + pass + except (SystemExit, MemoryError): + raise + except: + if not reloader: raise + if not getattr(server, 'quiet', quiet): + print_exc() + time.sleep(interval) + sys.exit(3) + + + +class FileCheckerThread(threading.Thread): + ''' Interrupt main-thread as soon as a changed module file is detected, + the lockfile gets deleted or gets to old. 
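+
+        Used as a context manager around the blocking server call, as in
+        :func:`run` above (an editor's illustration)::
+
+            bgcheck = FileCheckerThread(lockfile, interval)
+            with bgcheck:
+                server.run(app)
+            if bgcheck.status == 'reload':
+                sys.exit(3)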
''' + + def __init__(self, lockfile, interval): + threading.Thread.__init__(self) + self.lockfile, self.interval = lockfile, interval + #: Is one of 'reload', 'error' or 'exit' + self.status = None + + def run(self): + exists = os.path.exists + mtime = lambda path: os.stat(path).st_mtime + files = dict() + + for module in list(sys.modules.values()): + path = getattr(module, '__file__', '') + if path[-4:] in ('.pyo', '.pyc'): path = path[:-1] + if path and exists(path): files[path] = mtime(path) + + while not self.status: + if not exists(self.lockfile)\ + or mtime(self.lockfile) < time.time() - self.interval - 5: + self.status = 'error' + thread.interrupt_main() + for path, lmtime in list(files.items()): + if not exists(path) or mtime(path) > lmtime: + self.status = 'reload' + thread.interrupt_main() + break + time.sleep(self.interval) + + def __enter__(self): + self.start() + + def __exit__(self, exc_type, exc_val, exc_tb): + if not self.status: self.status = 'exit' # silent exit + self.join() + return exc_type is not None and issubclass(exc_type, KeyboardInterrupt) + + + + + +############################################################################### +# Template Adapters ############################################################ +############################################################################### + + +class TemplateError(HTTPError): + def __init__(self, message): + HTTPError.__init__(self, 500, message) + + +class BaseTemplate(object): + """ Base class and minimal API for template adapters """ + extensions = ['tpl','html','thtml','stpl'] + settings = {} #used in prepare() + defaults = {} #used in render() + + def __init__(self, source=None, name=None, lookup=[], encoding='utf8', **settings): + """ Create a new template. + If the source parameter (str or buffer) is missing, the name argument + is used to guess a template filename. Subclasses can assume that + self.source and/or self.filename are set. Both are strings. + The lookup, encoding and settings parameters are stored as instance + variables. + The lookup parameter stores a list containing directory paths. + The encoding parameter should be used to decode byte strings or files. + The settings parameter contains a dict for engine-specific settings. + """ + self.name = name + self.source = source.read() if hasattr(source, 'read') else source + self.filename = source.filename if hasattr(source, 'filename') else None + self.lookup = [os.path.abspath(x) for x in lookup] + self.encoding = encoding + self.settings = self.settings.copy() # Copy from class variable + self.settings.update(settings) # Apply + if not self.source and self.name: + self.filename = self.search(self.name, self.lookup) + if not self.filename: + raise TemplateError('Template %s not found.' % repr(name)) + if not self.source and not self.filename: + raise TemplateError('No template specified.') + self.prepare(**self.settings) + + @classmethod + def search(cls, name, lookup=[]): + """ Search name in all directories specified in lookup. + First without, then with common extensions. Return first hit. 
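+
+            Example (an editor's illustration; assumes ./views/index.tpl
+            exists)::
+
+                BaseTemplate.search('index', ['./views/'])
+                # -> absolute path to ./views/index.tpl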
""" + if not lookup: + depr('The template lookup path list should not be empty.') + lookup = ['.'] + + if os.path.isabs(name) and os.path.isfile(name): + depr('Absolute template path names are deprecated.') + return os.path.abspath(name) + + for spath in lookup: + spath = os.path.abspath(spath) + os.sep + fname = os.path.abspath(os.path.join(spath, name)) + if not fname.startswith(spath): continue + if os.path.isfile(fname): return fname + for ext in cls.extensions: + if os.path.isfile('%s.%s' % (fname, ext)): + return '%s.%s' % (fname, ext) + + @classmethod + def global_config(cls, key, *args): + ''' This reads or sets the global settings stored in class.settings. ''' + if args: + cls.settings = cls.settings.copy() # Make settings local to class + cls.settings[key] = args[0] + else: + return cls.settings[key] + + def prepare(self, **options): + """ Run preparations (parsing, caching, ...). + It should be possible to call this again to refresh a template or to + update settings. + """ + raise NotImplementedError + + def render(self, *args, **kwargs): + """ Render the template with the specified local variables and return + a single byte or unicode string. If it is a byte string, the encoding + must match self.encoding. This method must be thread-safe! + Local variables may be provided in dictionaries (*args) + or directly, as keywords (**kwargs). + """ + raise NotImplementedError + + +class MakoTemplate(BaseTemplate): + def prepare(self, **options): + from mako.template import Template + from mako.lookup import TemplateLookup + options.update({'input_encoding':self.encoding}) + options.setdefault('format_exceptions', bool(DEBUG)) + lookup = TemplateLookup(directories=self.lookup, **options) + if self.source: + self.tpl = Template(self.source, lookup=lookup, **options) + else: + self.tpl = Template(uri=self.name, filename=self.filename, lookup=lookup, **options) + + def render(self, *args, **kwargs): + for dictarg in args: kwargs.update(dictarg) + _defaults = self.defaults.copy() + _defaults.update(kwargs) + return self.tpl.render(**_defaults) + + +class CheetahTemplate(BaseTemplate): + def prepare(self, **options): + from Cheetah.Template import Template + self.context = threading.local() + self.context.vars = {} + options['searchList'] = [self.context.vars] + if self.source: + self.tpl = Template(source=self.source, **options) + else: + self.tpl = Template(file=self.filename, **options) + + def render(self, *args, **kwargs): + for dictarg in args: kwargs.update(dictarg) + self.context.vars.update(self.defaults) + self.context.vars.update(kwargs) + out = str(self.tpl) + self.context.vars.clear() + return out + + +class Jinja2Template(BaseTemplate): + def prepare(self, filters=None, tests=None, **kwargs): + from jinja2 import Environment, FunctionLoader + if 'prefix' in kwargs: # TODO: to be removed after a while + raise RuntimeError('The keyword argument `prefix` has been removed. 
' + 'Use the full jinja2 environment name line_statement_prefix instead.') + self.env = Environment(loader=FunctionLoader(self.loader), **kwargs) + if filters: self.env.filters.update(filters) + if tests: self.env.tests.update(tests) + if self.source: + self.tpl = self.env.from_string(self.source) + else: + self.tpl = self.env.get_template(self.filename) + + def render(self, *args, **kwargs): + for dictarg in args: kwargs.update(dictarg) + _defaults = self.defaults.copy() + _defaults.update(kwargs) + return self.tpl.render(**_defaults) + + def loader(self, name): + fname = self.search(name, self.lookup) + if not fname: return + with open(fname, "rb") as f: + return f.read().decode(self.encoding) + + +class SimpleTALTemplate(BaseTemplate): + ''' Deprecated, do not use. ''' + def prepare(self, **options): + depr('The SimpleTAL template handler is deprecated'\ + ' and will be removed in 0.12') + from simpletal import simpleTAL + if self.source: + self.tpl = simpleTAL.compileHTMLTemplate(self.source) + else: + with open(self.filename, 'rb') as fp: + self.tpl = simpleTAL.compileHTMLTemplate(tonat(fp.read())) + + def render(self, *args, **kwargs): + from simpletal import simpleTALES + for dictarg in args: kwargs.update(dictarg) + context = simpleTALES.Context() + for k,v in self.defaults.items(): + context.addGlobal(k, v) + for k,v in kwargs.items(): + context.addGlobal(k, v) + output = StringIO() + self.tpl.expand(context, output) + return output.getvalue() + + +class SimpleTemplate(BaseTemplate): + blocks = ('if', 'elif', 'else', 'try', 'except', 'finally', 'for', 'while', + 'with', 'def', 'class') + dedent_blocks = ('elif', 'else', 'except', 'finally') + + @lazy_attribute + def re_pytokens(cls): + ''' This matches comments and all kinds of quoted strings but does + NOT match comments (#...) within quoted strings. (trust me) ''' + return re.compile(r''' + (''(?!')|""(?!")|'{6}|"{6} # Empty strings (all 4 types) + |'(?:[^\\']|\\.)+?' # Single quotes (') + |"(?:[^\\"]|\\.)+?" # Double quotes (") + |'{3}(?:[^\\]|\\.|\n)+?'{3} # Triple-quoted strings (') + |"{3}(?:[^\\]|\\.|\n)+?"{3} # Triple-quoted strings (") + |\#.* # Comments + )''', re.VERBOSE) + + def prepare(self, escape_func=html_escape, noescape=False, **kwargs): + self.cache = {} + enc = self.encoding + self._str = lambda x: touni(x, enc) + self._escape = lambda x: escape_func(touni(x, enc)) + if noescape: + self._str, self._escape = self._escape, self._str + + @classmethod + def split_comment(cls, code): + """ Removes comments (#...) from python code. 
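+
+            Example (an editor's illustration)::
+
+                split_comment("x = 1 # a comment")   # -> "x = 1 "
+                split_comment("s = '# no comment'")  # unchanged (quoted string)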
""" + if '#' not in code: return code + #: Remove comments only (leave quoted strings as they are) + subf = lambda m: '' if m.group(0)[0]=='#' else m.group(0) + return re.sub(cls.re_pytokens, subf, code) + + @cached_property + def co(self): + return compile(self.code, self.filename or '<string>', 'exec') + + @cached_property + def code(self): + stack = [] # Current Code indentation + lineno = 0 # Current line of code + ptrbuffer = [] # Buffer for printable strings and token tuple instances + codebuffer = [] # Buffer for generated python code + multiline = dedent = oneline = False + template = self.source or open(self.filename, 'rb').read() + + def yield_tokens(line): + for i, part in enumerate(re.split(r'\{\{(.*?)\}\}', line)): + if i % 2: + if part.startswith('!'): yield 'RAW', part[1:] + else: yield 'CMD', part + else: yield 'TXT', part + + def flush(): # Flush the ptrbuffer + if not ptrbuffer: return + cline = '' + for line in ptrbuffer: + for token, value in line: + if token == 'TXT': cline += repr(value) + elif token == 'RAW': cline += '_str(%s)' % value + elif token == 'CMD': cline += '_escape(%s)' % value + cline += ', ' + cline = cline[:-2] + '\\\n' + cline = cline[:-2] + if cline[:-1].endswith('\\\\\\\\\\n'): + cline = cline[:-7] + cline[-1] # 'nobr\\\\\n' --> 'nobr' + cline = '_printlist([' + cline + '])' + del ptrbuffer[:] # Do this before calling code() again + code(cline) + + def code(stmt): + for line in stmt.splitlines(): + codebuffer.append(' ' * len(stack) + line.strip()) + + for line in template.splitlines(True): + lineno += 1 + line = touni(line, self.encoding) + sline = line.lstrip() + if lineno <= 2: + m = re.match(r"%\s*#.*coding[:=]\s*([-\w.]+)", sline) + if m: self.encoding = m.group(1) + if m: line = line.replace('coding','coding (removed)') + if sline and sline[0] == '%' and sline[:2] != '%%': + line = line.split('%',1)[1].lstrip() # Full line following the % + cline = self.split_comment(line).strip() + cmd = re.split(r'[^a-zA-Z0-9_]', cline)[0] + flush() # You are actually reading this? 
Good luck, it's a mess :) + if cmd in self.blocks or multiline: + cmd = multiline or cmd + dedent = cmd in self.dedent_blocks # "else:" + if dedent and not oneline and not multiline: + cmd = stack.pop() + code(line) + oneline = not cline.endswith(':') # "if 1: pass" + multiline = cmd if cline.endswith('\\') else False + if not oneline and not multiline: + stack.append(cmd) + elif cmd == 'end' and stack: + code('#end(%s) %s' % (stack.pop(), line.strip()[3:])) + elif cmd == 'include': + p = cline.split(None, 2)[1:] + if len(p) == 2: + code("_=_include(%s, _stdout, %s)" % (repr(p[0]), p[1])) + elif p: + code("_=_include(%s, _stdout)" % repr(p[0])) + else: # Empty %include -> reverse of %rebase + code("_printlist(_base)") + elif cmd == 'rebase': + p = cline.split(None, 2)[1:] + if len(p) == 2: + code("globals()['_rebase']=(%s, dict(%s))" % (repr(p[0]), p[1])) + elif p: + code("globals()['_rebase']=(%s, {})" % repr(p[0])) + else: + code(line) + else: # Line starting with text (not '%') or '%%' (escaped) + if line.strip().startswith('%%'): + line = line.replace('%%', '%', 1) + ptrbuffer.append(yield_tokens(line)) + flush() + return '\n'.join(codebuffer) + '\n' + + def subtemplate(self, _name, _stdout, *args, **kwargs): + for dictarg in args: kwargs.update(dictarg) + if _name not in self.cache: + self.cache[_name] = self.__class__(name=_name, lookup=self.lookup) + return self.cache[_name].execute(_stdout, kwargs) + + def execute(self, _stdout, *args, **kwargs): + for dictarg in args: kwargs.update(dictarg) + env = self.defaults.copy() + env.update({'_stdout': _stdout, '_printlist': _stdout.extend, + '_include': self.subtemplate, '_str': self._str, + '_escape': self._escape, 'get': env.get, + 'setdefault': env.setdefault, 'defined': env.__contains__}) + env.update(kwargs) + eval(self.co, env) + if '_rebase' in env: + subtpl, rargs = env['_rebase'] + rargs['_base'] = _stdout[:] #copy stdout + del _stdout[:] # clear stdout + return self.subtemplate(subtpl,_stdout,rargs) + return env + + def render(self, *args, **kwargs): + """ Render the template using keyword arguments as local variables. """ + for dictarg in args: kwargs.update(dictarg) + stdout = [] + self.execute(stdout, kwargs) + return ''.join(stdout) + + +def template(*args, **kwargs): + ''' + Get a rendered template as a string iterator. + You can use a name, a filename or a template string as first parameter. + Template rendering arguments can be passed as dictionaries + or directly (as keyword arguments). 
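+
+    Example (an editor's sketch; the second call assumes ./views/index.tpl
+    exists)::
+
+        template('Hello {{name}}!', name='World')   # -> u'Hello World!'
+        template('index', title='Home')             # renders ./views/index.tpl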
+ ''' + tpl = args[0] if args else None + adapter = kwargs.pop('template_adapter', SimpleTemplate) + lookup = kwargs.pop('template_lookup', TEMPLATE_PATH) + tplid = (id(lookup), tpl) + if tplid not in TEMPLATES or DEBUG: + settings = kwargs.pop('template_settings', {}) + if isinstance(tpl, adapter): + TEMPLATES[tplid] = tpl + if settings: TEMPLATES[tplid].prepare(**settings) + elif "\n" in tpl or "{" in tpl or "%" in tpl or '$' in tpl: + TEMPLATES[tplid] = adapter(source=tpl, lookup=lookup, **settings) + else: + TEMPLATES[tplid] = adapter(name=tpl, lookup=lookup, **settings) + if not TEMPLATES[tplid]: + abort(500, 'Template (%s) not found' % tpl) + for dictarg in args[1:]: kwargs.update(dictarg) + return TEMPLATES[tplid].render(kwargs) + +mako_template = functools.partial(template, template_adapter=MakoTemplate) +cheetah_template = functools.partial(template, template_adapter=CheetahTemplate) +jinja2_template = functools.partial(template, template_adapter=Jinja2Template) +simpletal_template = functools.partial(template, template_adapter=SimpleTALTemplate) + + +def view(tpl_name, **defaults): + ''' Decorator: renders a template for a handler. + The handler can control its behavior like that: + + - return a dict of template vars to fill out the template + - return something other than a dict and the view decorator will not + process the template, but return the handler result as is. + This includes returning a HTTPResponse(dict) to get, + for instance, JSON with autojson or other castfilters. + ''' + def decorator(func): + @functools.wraps(func) + def wrapper(*args, **kwargs): + result = func(*args, **kwargs) + if isinstance(result, (dict, DictMixin)): + tplvars = defaults.copy() + tplvars.update(result) + return template(tpl_name, **tplvars) + return result + return wrapper + return decorator + +mako_view = functools.partial(view, template_adapter=MakoTemplate) +cheetah_view = functools.partial(view, template_adapter=CheetahTemplate) +jinja2_view = functools.partial(view, template_adapter=Jinja2Template) +simpletal_view = functools.partial(view, template_adapter=SimpleTALTemplate) + + + + + + +############################################################################### +# Constants and Globals ######################################################## +############################################################################### + + +TEMPLATE_PATH = ['./', './views/'] +TEMPLATES = {} +DEBUG = False +NORUN = False # If set, run() does nothing. Used by load_app() + +#: A dict to map HTTP status codes (e.g. 404) to phrases (e.g. 'Not Found') +HTTP_CODES = httplib.responses +HTTP_CODES[418] = "I'm a teapot" # RFC 2324 +HTTP_CODES[428] = "Precondition Required" +HTTP_CODES[429] = "Too Many Requests" +HTTP_CODES[431] = "Request Header Fields Too Large" +HTTP_CODES[511] = "Network Authentication Required" +_HTTP_STATUS_LINES = dict((k, '%d %s'%(k,v)) for (k,v) in HTTP_CODES.items()) + +#: The default template used for error pages. 
Override with @error() +ERROR_PAGE_TEMPLATE = """ +%%try: + %%from %s import DEBUG, HTTP_CODES, request, touni + <!DOCTYPE HTML PUBLIC "-//IETF//DTD HTML 2.0//EN"> + <html> + <head> + <title>Error: {{e.status}}</title> + <style type="text/css"> + html {background-color: #eee; font-family: sans;} + body {background-color: #fff; border: 1px solid #ddd; + padding: 15px; margin: 15px;} + pre {background-color: #eee; border: 1px solid #ddd; padding: 5px;} + </style> + </head> + <body> + <h1>Error: {{e.status}}</h1> + <p>Sorry, the requested URL <tt>{{repr(request.url)}}</tt> + caused an error:</p> + <pre>{{e.body}}</pre> + %%if DEBUG and e.exception: + <h2>Exception:</h2> + <pre>{{repr(e.exception)}}</pre> + %%end + %%if DEBUG and e.traceback: + <h2>Traceback:</h2> + <pre>{{e.traceback}}</pre> + %%end + </body> + </html> +%%except ImportError: + <b>ImportError:</b> Could not generate the error page. Please add bottle to + the import path. +%%end +""" % __name__ + +#: A thread-safe instance of :class:`LocalRequest`. If accessed from within a +#: request callback, this instance always refers to the *current* request +#: (even on a multithreaded server). +request = LocalRequest() + +#: A thread-safe instance of :class:`LocalResponse`. It is used to change the +#: HTTP response for the *current* request. +response = LocalResponse() + +#: A thread-safe namespace. Not used by Bottle. +local = threading.local() + +# Initialize app stack (create first empty Bottle app) +# BC: 0.6.4 and needed for run() +app = default_app = AppStack() +app.push() + +#: A virtual package that redirects import statements. +#: Example: ``import bottle.ext.sqlite`` actually imports `bottle_sqlite`. +ext = _ImportRedirect('bottle.ext' if __name__ == '__main__' else __name__+".ext", 'bottle_%s').module + +if __name__ == '__main__': + opt, args, parser = _cmd_options, _cmd_args, _cmd_parser + if opt.version: + _stdout('Bottle %s\n'%__version__) + sys.exit(0) + if not args: + parser.print_help() + _stderr('\nError: No application specified.\n') + sys.exit(1) + + sys.path.insert(0, '.') + sys.modules.setdefault('bottle', sys.modules['__main__']) + + host, port = (opt.bind or 'localhost'), 8080 + if ':' in host: + host, port = host.rsplit(':', 1) + + run(args[0], host=host, port=port, server=opt.server, + reloader=opt.reload, plugins=opt.plugin, debug=opt.debug) + + + + +# THE END diff --git a/pyload/lib/forwarder.py b/pyload/lib/forwarder.py new file mode 100644 index 000000000..eacb33c2b --- /dev/null +++ b/pyload/lib/forwarder.py @@ -0,0 +1,73 @@ +# -*- coding: utf-8 -*- + +""" + This program is free software; you can redistribute it and/or modify + it under the terms of the GNU General Public License as published by + the Free Software Foundation; either version 3 of the License, + or (at your option) any later version. + + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. + See the GNU General Public License for more details. + + You should have received a copy of the GNU General Public License + along with this program; if not, see <http://www.gnu.org/licenses/>. 
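+
+    Command line usage (an editor's illustration; the address is an example)::
+
+        python forwarder.py 203.0.113.5 9666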
+
+    @author: RaNaN
+"""
+
+from sys import argv
+from sys import exit
+
+import socket
+import thread
+
+from traceback import print_exc
+
+
+class Forwarder(object):
+
+    def __init__(self, extip, extport=9666):
+        print "Starting port forwarding to %s:%s" % (extip, extport)
+        proxy(extip, extport, 9666)
+
+
+def proxy(*settings):
+    # Restart the listener whenever it dies (e.g. after a socket error).
+    while True:
+        server(*settings)
+
+
+def server(*settings):
+    try:
+        dock_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+        dock_socket.bind(("127.0.0.1", settings[2]))
+        dock_socket.listen(5)
+        while True:
+            client_socket = dock_socket.accept()[0]
+            server_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+            server_socket.connect((settings[0], settings[1]))
+            # Pump data in both directions, each direction in its own thread.
+            thread.start_new_thread(forward, (client_socket, server_socket))
+            thread.start_new_thread(forward, (server_socket, client_socket))
+    except Exception:
+        print_exc()
+
+
+def forward(source, destination):
+    string = ' '  # dummy value so the loop body runs at least once
+    while string:
+        string = source.recv(1024)
+        if string:
+            destination.sendall(string)
+        else:
+            # Peer closed its sending side; propagate the shutdown.
+            #source.shutdown(socket.SHUT_RD)
+            destination.shutdown(socket.SHUT_WR)
+
+
+if __name__ == "__main__":
+    args = argv[1:]
+    if not args:
+        print "Usage: forwarder.py <remote ip> <remote port>"
+        exit()
+    if len(args) == 1:
+        args.append("9666")
+
+    f = Forwarder(args[0], int(args[1]))
+
\ No newline at end of file diff --git a/pyload/lib/hg_tool.py b/pyload/lib/hg_tool.py new file mode 100644 index 000000000..cd97833df --- /dev/null +++ b/pyload/lib/hg_tool.py @@ -0,0 +1,133 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- + +import re +from subprocess import Popen, PIPE +from time import time, gmtime, strftime + +aliases = {"zoidber": "zoidberg", "zoidberg10": "zoidberg", "webmaster": "dhmh", "mast3rranan": "ranan", + "ranan2": "ranan"} +exclude = ["locale/*", "module/lib/*"] +date_format = "%Y-%m-%d" +line_re = re.compile(r" (\d+) \**", re.I) + +def add_exclude_flags(args): + for dir in exclude: + args.extend(["-X", dir]) + +# remove small percentages +def wipe(data, perc=1): + s = (sum(data.values()) * perc) / 100 + for k, v in data.items(): + if v < s: del data[k] + + return data + +# remove aliases +def de_alias(data): + for k, v in aliases.iteritems(): + if k not in data: continue + alias = aliases[k] + + if alias in data: data[alias] += data[k] + else: data[alias] = data[k] + + del data[k] + + return data + + +def output(data): + s = float(sum(data.values())) + print "Total Lines: %d" % s + for k, v in data.iteritems(): + print "%15s: %.1f%% | %d" % (k, (v * 100) / s, v) + print + + +def file_list(): + args = ["hg", "status", "-A"] + add_exclude_flags(args) + p = Popen(args, stdout=PIPE) + out, err = p.communicate() + return [x.split()[1] for x in out.splitlines() if x.split()[0] in "CMA"] + + +def hg_annotate(path): + args = ["hg", "annotate", "-u", path] + p = Popen(args, stdout=PIPE) + out, err = p.communicate() + + data = {} + + for line in out.splitlines(): + author, non, line = line.partition(":") + + # probably binary file + if author == path: return {} + + author = author.strip().lower() + if not line.strip(): continue # don't count blank lines + + if author in data: data[author] += 1 + else: data[author] = 1 + + return de_alias(data) + + +def hg_churn(days=None): + args = ["hg", "churn"] + if days: + args.append("-d") + t = time() - 60 * 60 * 24 * days + args.append("%s to %s" % (strftime(date_format, gmtime(t)), strftime(date_format))) + + add_exclude_flags(args) + p = Popen(args, stdout=PIPE) + out, err = p.communicate() + + data = {} + + for line in out.splitlines(): + m = line_re.search(line) + author = line.split()[0] + lines = int(m.group(1)) + + if "@" in author: + author, n, email = author.partition("@") + + author = author.strip().lower() + + if author in data: data[author] += lines + else: data[author] = lines + + return de_alias(data) + + +def complete_annotate(): + files = file_list() + data = {} + for f in files: + tmp = hg_annotate(f) + for k, v in tmp.iteritems(): + if k in data: data[k] += v + else: data[k] = v + + return data + + +if __name__ == "__main__": + for d in (30, 90, 180): + c = wipe(hg_churn(d)) + print "Changes in %d days:" % d + output(c) + + c = wipe(hg_churn()) + print "Total changes:" + output(c) + + print "Current source code version:" + data = wipe(complete_annotate()) + output(data) + + diff --git a/pyload/lib/mod_pywebsocket/COPYING b/pyload/lib/mod_pywebsocket/COPYING new file mode 100644 index 000000000..989d02e4c --- /dev/null +++ b/pyload/lib/mod_pywebsocket/COPYING @@ -0,0 +1,28 @@ +Copyright 2012, Google Inc. +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are +met: + + * Redistributions of source code must retain the above copyright +notice, this list of conditions and the following disclaimer. 
+ * Redistributions in binary form must reproduce the above +copyright notice, this list of conditions and the following disclaimer +in the documentation and/or other materials provided with the +distribution. + * Neither the name of Google Inc. nor the names of its +contributors may be used to endorse or promote products derived from +this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/pyload/lib/mod_pywebsocket/__init__.py b/pyload/lib/mod_pywebsocket/__init__.py new file mode 100644 index 000000000..454ae0c45 --- /dev/null +++ b/pyload/lib/mod_pywebsocket/__init__.py @@ -0,0 +1,197 @@ +# Copyright 2011, Google Inc. +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: +# +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above +# copyright notice, this list of conditions and the following disclaimer +# in the documentation and/or other materials provided with the +# distribution. +# * Neither the name of Google Inc. nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + + +"""WebSocket extension for Apache HTTP Server. + +mod_pywebsocket is a WebSocket extension for Apache HTTP Server +intended for testing or experimental purposes. mod_python is required. + + +Installation +============ + +0. Prepare an Apache HTTP Server for which mod_python is enabled. + +1. Specify the following Apache HTTP Server directives to suit your + configuration. + + If mod_pywebsocket is not in the Python path, specify the following. + <websock_lib> is the directory where mod_pywebsocket is installed. + + PythonPath "sys.path+['<websock_lib>']" + + Always specify the following. 
<websock_handlers> is the directory where + user-written WebSocket handlers are placed. + + PythonOption mod_pywebsocket.handler_root <websock_handlers> + PythonHeaderParserHandler mod_pywebsocket.headerparserhandler + + To limit the search for WebSocket handlers to a directory <scan_dir> + under <websock_handlers>, configure as follows: + + PythonOption mod_pywebsocket.handler_scan <scan_dir> + + <scan_dir> is useful in saving scan time when <websock_handlers> + contains many non-WebSocket handler files. + + If you want to allow handlers whose canonical path is not under the root + directory (i.e. symbolic link is in root directory but its target is not), + configure as follows: + + PythonOption mod_pywebsocket.allow_handlers_outside_root_dir On + + Example snippet of httpd.conf: + (mod_pywebsocket is in /websock_lib, WebSocket handlers are in + /websock_handlers, port is 80 for ws, 443 for wss.) + + <IfModule python_module> + PythonPath "sys.path+['/websock_lib']" + PythonOption mod_pywebsocket.handler_root /websock_handlers + PythonHeaderParserHandler mod_pywebsocket.headerparserhandler + </IfModule> + +2. Tune Apache parameters for serving WebSocket. We'd like to note that at + least TimeOut directive from core features and RequestReadTimeout + directive from mod_reqtimeout should be modified not to kill connections + in only a few seconds of idle time. + +3. Verify installation. You can use example/console.html to poke the server. + + +Writing WebSocket handlers +========================== + +When a WebSocket request comes in, the resource name +specified in the handshake is considered as if it is a file path under +<websock_handlers> and the handler defined in +<websock_handlers>/<resource_name>_wsh.py is invoked. + +For example, if the resource name is /example/chat, the handler defined in +<websock_handlers>/example/chat_wsh.py is invoked. + +A WebSocket handler is composed of the following three functions: + + web_socket_do_extra_handshake(request) + web_socket_transfer_data(request) + web_socket_passive_closing_handshake(request) + +where: + request: mod_python request. + +web_socket_do_extra_handshake is called during the handshake after the +headers are successfully parsed and WebSocket properties (ws_location, +ws_origin, and ws_resource) are added to request. A handler +can reject the request by raising an exception. + +A request object has the following properties that you can use during the +extra handshake (web_socket_do_extra_handshake): +- ws_resource +- ws_origin +- ws_version +- ws_location (HyBi 00 only) +- ws_extensions (HyBi 06 and later) +- ws_deflate (HyBi 06 and later) +- ws_protocol +- ws_requested_protocols (HyBi 06 and later) + +The last two are a bit tricky. See the next subsection. + + +Subprotocol Negotiation +----------------------- + +For HyBi 06 and later, ws_protocol is always set to None when +web_socket_do_extra_handshake is called. If ws_requested_protocols is not +None, you must choose one subprotocol from this list and set it to +ws_protocol. + +For HyBi 00, when web_socket_do_extra_handshake is called, +ws_protocol is set to the value given by the client in +Sec-WebSocket-Protocol header or None if +such header was not found in the opening handshake request. Finish extra +handshake with ws_protocol untouched to accept the request subprotocol. +Then, Sec-WebSocket-Protocol header will be sent to +the client in response with the same value as requested. Raise an exception +in web_socket_do_extra_handshake to reject the requested subprotocol. 
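To make the handler contract above concrete, here is a minimal sketch of a handler file (a hypothetical <websock_handlers>/echo_wsh.py) that picks the first subprotocol offered by HyBi 06 and later clients and echoes messages back until the client starts the closing handshake (receive_message and send_message are described under Data Transfer below):

    def web_socket_do_extra_handshake(request):
        # HyBi 06 and later: choose one of the offered subprotocols, or
        # leave ws_protocol as None when the client offered none.
        if request.ws_requested_protocols:
            request.ws_protocol = request.ws_requested_protocols[0]

    def web_socket_transfer_data(request):
        while True:
            message = request.ws_stream.receive_message()
            if message is None:
                # Client-initiated closing handshake.
                return
            request.ws_stream.send_message(message)

    def web_socket_passive_closing_handshake(request):
        # Close code and reason for the outgoing closing frame
        # (1000 = normal closure).
        return 1000, ''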
+
+
+Data Transfer
+-------------
+
+web_socket_transfer_data is called after the handshake has completed
+successfully. A handler can receive/send messages from/to the client
+using request. The mod_pywebsocket.msgutil module provides utilities
+for data transfer.
+
+You can receive a message with the following statement.
+
+    message = request.ws_stream.receive_message()
+
+This call blocks until any complete text frame arrives, and the payload data
+of the incoming frame will be stored into message. When you're using the
+IETF HyBi 00 or later protocol, receive_message() will return None on
+receiving a client-initiated closing handshake. When any error occurs,
+receive_message() will raise an exception.
+
+You can send a message with the following statement.
+
+    request.ws_stream.send_message(message)
+
+
+Closing Connection
+------------------
+
+Executing the following statement, or simply returning from
+web_socket_transfer_data, closes the connection.
+
+    request.ws_stream.close_connection()
+
+close_connection waits for a closing handshake acknowledgement from the
+client and raises an exception if it cannot receive a valid one.
+
+web_socket_passive_closing_handshake is called immediately after the server
+receives a closing frame from the client peer. You can specify the close
+code and reason via its return values; they are sent in an outgoing closing
+frame from the server. A request object has the following properties that
+you can use in web_socket_passive_closing_handshake.
+- ws_close_code
+- ws_close_reason
+
+
+Threading
+---------
+
+A WebSocket handler must be thread-safe if the server (Apache or
+standalone.py) is configured to use threads.
+"""
+
+
+# vi:sts=4 sw=4 et tw=72
diff --git a/pyload/lib/mod_pywebsocket/_stream_base.py b/pyload/lib/mod_pywebsocket/_stream_base.py
new file mode 100644
index 000000000..60fb33d2c
--- /dev/null
+++ b/pyload/lib/mod_pywebsocket/_stream_base.py
@@ -0,0 +1,165 @@
+# Copyright 2011, Google Inc.
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are
+# met:
+#
+#     * Redistributions of source code must retain the above copyright
+# notice, this list of conditions and the following disclaimer.
+#     * Redistributions in binary form must reproduce the above
+# copyright notice, this list of conditions and the following disclaimer
+# in the documentation and/or other materials provided with the
+# distribution.
+#     * Neither the name of Google Inc. nor the names of its
+# contributors may be used to endorse or promote products derived from
+# this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+
+"""Base stream class.
+""" + + +# Note: request.connection.write/read are used in this module, even though +# mod_python document says that they should be used only in connection +# handlers. Unfortunately, we have no other options. For example, +# request.write/read are not suitable because they don't allow direct raw bytes +# writing/reading. + + +from mod_pywebsocket import util + + +# Exceptions + + +class ConnectionTerminatedException(Exception): + """This exception will be raised when a connection is terminated + unexpectedly. + """ + + pass + + +class InvalidFrameException(ConnectionTerminatedException): + """This exception will be raised when we received an invalid frame we + cannot parse. + """ + + pass + + +class BadOperationException(Exception): + """This exception will be raised when send_message() is called on + server-terminated connection or receive_message() is called on + client-terminated connection. + """ + + pass + + +class UnsupportedFrameException(Exception): + """This exception will be raised when we receive a frame with flag, opcode + we cannot handle. Handlers can just catch and ignore this exception and + call receive_message() again to continue processing the next frame. + """ + + pass + + +class InvalidUTF8Exception(Exception): + """This exception will be raised when we receive a text frame which + contains invalid UTF-8 strings. + """ + + pass + + +class StreamBase(object): + """Base stream class.""" + + def __init__(self, request): + """Construct an instance. + + Args: + request: mod_python request. + """ + + self._logger = util.get_class_logger(self) + + self._request = request + + def _read(self, length): + """Reads length bytes from connection. In case we catch any exception, + prepends remote address to the exception message and raise again. + + Raises: + ConnectionTerminatedException: when read returns empty string. + """ + + bytes = self._request.connection.read(length) + if not bytes: + raise ConnectionTerminatedException( + 'Receiving %d byte failed. Peer (%r) closed connection' % + (length, (self._request.connection.remote_addr,))) + return bytes + + def _write(self, bytes): + """Writes given bytes to connection. In case we catch any exception, + prepends remote address to the exception message and raise again. + """ + + try: + self._request.connection.write(bytes) + except Exception, e: + util.prepend_message_to_exception( + 'Failed to send message to %r: ' % + (self._request.connection.remote_addr,), + e) + raise + + def receive_bytes(self, length): + """Receives multiple bytes. Retries read when we couldn't receive the + specified amount. + + Raises: + ConnectionTerminatedException: when read returns empty string. + """ + + bytes = [] + while length > 0: + new_bytes = self._read(length) + bytes.append(new_bytes) + length -= len(new_bytes) + return ''.join(bytes) + + def _read_until(self, delim_char): + """Reads bytes until we encounter delim_char. The result will not + contain delim_char. + + Raises: + ConnectionTerminatedException: when read returns empty string. + """ + + bytes = [] + while True: + ch = self._read(1) + if ch == delim_char: + break + bytes.append(ch) + return ''.join(bytes) + + +# vi:sts=4 sw=4 et diff --git a/pyload/lib/mod_pywebsocket/_stream_hixie75.py b/pyload/lib/mod_pywebsocket/_stream_hixie75.py new file mode 100644 index 000000000..94cf5b31b --- /dev/null +++ b/pyload/lib/mod_pywebsocket/_stream_hixie75.py @@ -0,0 +1,229 @@ +# Copyright 2011, Google Inc. +# All rights reserved. 
+# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: +# +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above +# copyright notice, this list of conditions and the following disclaimer +# in the documentation and/or other materials provided with the +# distribution. +# * Neither the name of Google Inc. nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + + +"""This file provides a class for parsing/building frames of the WebSocket +protocol version HyBi 00 and Hixie 75. + +Specification: +- HyBi 00 http://tools.ietf.org/html/draft-ietf-hybi-thewebsocketprotocol-00 +- Hixie 75 http://tools.ietf.org/html/draft-hixie-thewebsocketprotocol-75 +""" + + +from mod_pywebsocket import common +from mod_pywebsocket._stream_base import BadOperationException +from mod_pywebsocket._stream_base import ConnectionTerminatedException +from mod_pywebsocket._stream_base import InvalidFrameException +from mod_pywebsocket._stream_base import StreamBase +from mod_pywebsocket._stream_base import UnsupportedFrameException +from mod_pywebsocket import util + + +class StreamHixie75(StreamBase): + """A class for parsing/building frames of the WebSocket protocol version + HyBi 00 and Hixie 75. + """ + + def __init__(self, request, enable_closing_handshake=False): + """Construct an instance. + + Args: + request: mod_python request. + enable_closing_handshake: to let StreamHixie75 perform closing + handshake as specified in HyBi 00, set + this option to True. + """ + + StreamBase.__init__(self, request) + + self._logger = util.get_class_logger(self) + + self._enable_closing_handshake = enable_closing_handshake + + self._request.client_terminated = False + self._request.server_terminated = False + + def send_message(self, message, end=True, binary=False): + """Send message. + + Args: + message: unicode string to send. + binary: not used in hixie75. + + Raises: + BadOperationException: when called on a server-terminated + connection. 
+        """
+
+        if not end:
+            raise BadOperationException(
+                'StreamHixie75 doesn\'t support send_message with end=False')
+
+        if binary:
+            raise BadOperationException(
+                'StreamHixie75 doesn\'t support send_message with binary=True')
+
+        if self._request.server_terminated:
+            raise BadOperationException(
+                'Requested send_message after sending out a closing handshake')
+
+        self._write(''.join(['\x00', message.encode('utf-8'), '\xff']))
+
+    def _read_payload_length_hixie75(self):
+        """Reads a length header in a Hixie75 version frame with length.
+
+        Raises:
+            ConnectionTerminatedException: when read returns empty string.
+        """
+
+        length = 0
+        while True:
+            b_str = self._read(1)
+            b = ord(b_str)
+            length = length * 128 + (b & 0x7f)
+            if (b & 0x80) == 0:
+                break
+        return length
+
+    def receive_message(self):
+        """Receive a WebSocket frame and return its payload as a unicode
+        string.
+
+        Returns:
+            the payload unicode string of the WebSocket frame.
+
+        Raises:
+            ConnectionTerminatedException: when read returns empty
+                string.
+            BadOperationException: when called on a client-terminated
+                connection.
+        """
+
+        if self._request.client_terminated:
+            raise BadOperationException(
+                'Requested receive_message after receiving a closing '
+                'handshake')
+
+        while True:
+            # Read 1 byte.
+            # mp_conn.read will block if no bytes are available.
+            # Timeout is controlled by TimeOut directive of Apache.
+            frame_type_str = self.receive_bytes(1)
+            frame_type = ord(frame_type_str)
+            if (frame_type & 0x80) == 0x80:
+                # The payload length is specified in the frame.
+                # Read and discard.
+                length = self._read_payload_length_hixie75()
+                if length > 0:
+                    _ = self.receive_bytes(length)
+                # 5.3 3. 12. if /type/ is 0xFF and /length/ is 0, then set the
+                # /client terminated/ flag and abort these steps.
+                if not self._enable_closing_handshake:
+                    continue
+
+                if frame_type == 0xFF and length == 0:
+                    self._request.client_terminated = True
+
+                    if self._request.server_terminated:
+                        self._logger.debug(
+                            'Received ack for server-initiated closing '
+                            'handshake')
+                        return None
+
+                    self._logger.debug(
+                        'Received client-initiated closing handshake')
+
+                    self._send_closing_handshake()
+                    self._logger.debug(
+                        'Sent ack for client-initiated closing handshake')
+                    return None
+            else:
+                # The payload is delimited with \xff.
+                bytes = self._read_until('\xff')
+                # The WebSocket protocol section 4.4 specifies that invalid
+                # characters must be replaced with U+fffd REPLACEMENT
+                # CHARACTER.
+                message = bytes.decode('utf-8', 'replace')
+                if frame_type == 0x00:
+                    return message
+                # Discard data of other types.
+
+    def _send_closing_handshake(self):
+        if not self._enable_closing_handshake:
+            raise BadOperationException(
+                'Closing handshake is not supported in Hixie 75 protocol')
+
+        self._request.server_terminated = True
+
+        # 5.3 the server may decide to terminate the WebSocket connection by
+        # running through the following steps:
+        # 1. send a 0xFF byte and a 0x00 byte to the client to indicate the
+        # start of the closing handshake.
+        self._write('\xff\x00')
+
+    def close_connection(self, unused_code='', unused_reason=''):
+        """Closes a WebSocket connection.
+
+        Raises:
+            ConnectionTerminatedException: when closing handshake was
+                not successful.
+ """ + + if self._request.server_terminated: + self._logger.debug( + 'Requested close_connection but server is already terminated') + return + + if not self._enable_closing_handshake: + self._request.server_terminated = True + self._logger.debug('Connection closed') + return + + self._send_closing_handshake() + self._logger.debug('Sent server-initiated closing handshake') + + # TODO(ukai): 2. wait until the /client terminated/ flag has been set, + # or until a server-defined timeout expires. + # + # For now, we expect receiving closing handshake right after sending + # out closing handshake, and if we couldn't receive non-handshake + # frame, we take it as ConnectionTerminatedException. + message = self.receive_message() + if message is not None: + raise ConnectionTerminatedException( + 'Didn\'t receive valid ack for closing handshake') + # TODO: 3. close the WebSocket connection. + # note: mod_python Connection (mp_conn) doesn't have close method. + + def send_ping(self, body): + raise BadOperationException( + 'StreamHixie75 doesn\'t support send_ping') + + +# vi:sts=4 sw=4 et diff --git a/pyload/lib/mod_pywebsocket/_stream_hybi.py b/pyload/lib/mod_pywebsocket/_stream_hybi.py new file mode 100644 index 000000000..bd158fa6b --- /dev/null +++ b/pyload/lib/mod_pywebsocket/_stream_hybi.py @@ -0,0 +1,915 @@ +# Copyright 2012, Google Inc. +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: +# +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above +# copyright notice, this list of conditions and the following disclaimer +# in the documentation and/or other materials provided with the +# distribution. +# * Neither the name of Google Inc. nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + + +"""This file provides classes and helper functions for parsing/building frames +of the WebSocket protocol (RFC 6455). 
+ +Specification: +http://tools.ietf.org/html/rfc6455 +""" + + +from collections import deque +import logging +import os +import struct +import time + +from mod_pywebsocket import common +from mod_pywebsocket import util +from mod_pywebsocket._stream_base import BadOperationException +from mod_pywebsocket._stream_base import ConnectionTerminatedException +from mod_pywebsocket._stream_base import InvalidFrameException +from mod_pywebsocket._stream_base import InvalidUTF8Exception +from mod_pywebsocket._stream_base import StreamBase +from mod_pywebsocket._stream_base import UnsupportedFrameException + + +_NOOP_MASKER = util.NoopMasker() + + +class Frame(object): + + def __init__(self, fin=1, rsv1=0, rsv2=0, rsv3=0, + opcode=None, payload=''): + self.fin = fin + self.rsv1 = rsv1 + self.rsv2 = rsv2 + self.rsv3 = rsv3 + self.opcode = opcode + self.payload = payload + + +# Helper functions made public to be used for writing unittests for WebSocket +# clients. + + +def create_length_header(length, mask): + """Creates a length header. + + Args: + length: Frame length. Must be less than 2^63. + mask: Mask bit. Must be boolean. + + Raises: + ValueError: when bad data is given. + """ + + if mask: + mask_bit = 1 << 7 + else: + mask_bit = 0 + + if length < 0: + raise ValueError('length must be non negative integer') + elif length <= 125: + return chr(mask_bit | length) + elif length < (1 << 16): + return chr(mask_bit | 126) + struct.pack('!H', length) + elif length < (1 << 63): + return chr(mask_bit | 127) + struct.pack('!Q', length) + else: + raise ValueError('Payload is too big for one frame') + + +def create_header(opcode, payload_length, fin, rsv1, rsv2, rsv3, mask): + """Creates a frame header. + + Raises: + Exception: when bad data is given. + """ + + if opcode < 0 or 0xf < opcode: + raise ValueError('Opcode out of range') + + if payload_length < 0 or (1 << 63) <= payload_length: + raise ValueError('payload_length out of range') + + if (fin | rsv1 | rsv2 | rsv3) & ~1: + raise ValueError('FIN bit and Reserved bit parameter must be 0 or 1') + + header = '' + + first_byte = ((fin << 7) + | (rsv1 << 6) | (rsv2 << 5) | (rsv3 << 4) + | opcode) + header += chr(first_byte) + header += create_length_header(payload_length, mask) + + return header + + +def _build_frame(header, body, mask): + if not mask: + return header + body + + masking_nonce = os.urandom(4) + masker = util.RepeatedXorMasker(masking_nonce) + + return header + masking_nonce + masker.mask(body) + + +def _filter_and_format_frame_object(frame, mask, frame_filters): + for frame_filter in frame_filters: + frame_filter.filter(frame) + + header = create_header( + frame.opcode, len(frame.payload), frame.fin, + frame.rsv1, frame.rsv2, frame.rsv3, mask) + return _build_frame(header, frame.payload, mask) + + +def create_binary_frame( + message, opcode=common.OPCODE_BINARY, fin=1, mask=False, frame_filters=[]): + """Creates a simple binary frame with no extension, reserved bit.""" + + frame = Frame(fin=fin, opcode=opcode, payload=message) + return _filter_and_format_frame_object(frame, mask, frame_filters) + + +def create_text_frame( + message, opcode=common.OPCODE_TEXT, fin=1, mask=False, frame_filters=[]): + """Creates a simple text frame with no extension, reserved bit.""" + + encoded_message = message.encode('utf-8') + return create_binary_frame(encoded_message, opcode, fin, mask, + frame_filters) + + +def parse_frame(receive_bytes, logger=None, + ws_version=common.VERSION_HYBI_LATEST, + unmask_receive=True): + """Parses a frame. 
Returns a tuple containing each header field and
+    payload.
+
+    Args:
+        receive_bytes: a function that reads frame data from a stream or
+            something similar. The function takes the length of the bytes to
+            be read. The function must raise ConnectionTerminatedException
+            if there is not enough data to be read.
+        logger: a logging object.
+        ws_version: the version of the WebSocket protocol.
+        unmask_receive: unmask received frames. When an unmasked frame is
+            received while this is set, InvalidFrameException is raised.
+
+    Raises:
+        ConnectionTerminatedException: when receive_bytes raises it.
+        InvalidFrameException: when the frame contains invalid data.
+    """
+
+    if not logger:
+        logger = logging.getLogger()
+
+    logger.log(common.LOGLEVEL_FINE, 'Receive the first 2 octets of a frame')
+
+    received = receive_bytes(2)
+
+    first_byte = ord(received[0])
+    fin = (first_byte >> 7) & 1
+    rsv1 = (first_byte >> 6) & 1
+    rsv2 = (first_byte >> 5) & 1
+    rsv3 = (first_byte >> 4) & 1
+    opcode = first_byte & 0xf
+
+    second_byte = ord(received[1])
+    mask = (second_byte >> 7) & 1
+    payload_length = second_byte & 0x7f
+
+    logger.log(common.LOGLEVEL_FINE,
+               'FIN=%s, RSV1=%s, RSV2=%s, RSV3=%s, opcode=%s, '
+               'Mask=%s, Payload_length=%s',
+               fin, rsv1, rsv2, rsv3, opcode, mask, payload_length)
+
+    if (mask == 1) != unmask_receive:
+        raise InvalidFrameException(
+            'Mask bit on the received frame didn\'t match masking '
+            'configuration for received frames')
+
+    # The HyBi and later specs disallow putting a value in 0x0-0xFFFF
+    # into the 8-octet extended payload length field (or 0x0-0x7D in the
+    # 2-octet field).
+    valid_length_encoding = True
+    length_encoding_bytes = 1
+    if payload_length == 127:
+        logger.log(common.LOGLEVEL_FINE,
+                   'Receive 8-octet extended payload length')
+
+        extended_payload_length = receive_bytes(8)
+        payload_length = struct.unpack(
+            '!Q', extended_payload_length)[0]
+        if payload_length > 0x7FFFFFFFFFFFFFFF:
+            raise InvalidFrameException(
+                'Extended payload length >= 2^63')
+        if ws_version >= 13 and payload_length < 0x10000:
+            valid_length_encoding = False
+            length_encoding_bytes = 8
+
+        logger.log(common.LOGLEVEL_FINE,
+                   'Decoded_payload_length=%s', payload_length)
+    elif payload_length == 126:
+        logger.log(common.LOGLEVEL_FINE,
+                   'Receive 2-octet extended payload length')
+
+        extended_payload_length = receive_bytes(2)
+        payload_length = struct.unpack(
+            '!H', extended_payload_length)[0]
+        if ws_version >= 13 and payload_length < 126:
+            valid_length_encoding = False
+            length_encoding_bytes = 2
+
+        logger.log(common.LOGLEVEL_FINE,
+                   'Decoded_payload_length=%s', payload_length)
+
+    if not valid_length_encoding:
+        logger.warning(
+            'Payload length is not encoded using the minimal number of '
+            'bytes (%d is encoded using %d bytes)',
+            payload_length,
+            length_encoding_bytes)
+
+    if mask == 1:
+        logger.log(common.LOGLEVEL_FINE, 'Receive mask')
+
+        masking_nonce = receive_bytes(4)
+        masker = util.RepeatedXorMasker(masking_nonce)
+
+        logger.log(common.LOGLEVEL_FINE, 'Mask=%r', masking_nonce)
+    else:
+        masker = _NOOP_MASKER
+
+    logger.log(common.LOGLEVEL_FINE, 'Receive payload data')
+    if logger.isEnabledFor(common.LOGLEVEL_FINE):
+        receive_start = time.time()
+
+    raw_payload_bytes = receive_bytes(payload_length)
+
+    if logger.isEnabledFor(common.LOGLEVEL_FINE):
+        logger.log(
+            common.LOGLEVEL_FINE,
+            'Done receiving payload data at %s MB/s',
+            payload_length / (time.time() - receive_start) / 1000 / 1000)
+    logger.log(common.LOGLEVEL_FINE, 'Unmask payload data')
+
+    if logger.isEnabledFor(common.LOGLEVEL_FINE):
+        unmask_start = time.time()
+
+    bytes = masker.mask(raw_payload_bytes)
+
+    if logger.isEnabledFor(common.LOGLEVEL_FINE):
+        logger.log(
+            common.LOGLEVEL_FINE,
+            'Done unmasking payload data at %s MB/s',
+            payload_length / (time.time() - unmask_start) / 1000 / 1000)
+
+    return opcode, bytes, fin, rsv1, rsv2, rsv3
+
+
+class FragmentedFrameBuilder(object):
+    """A stateful class to send a message as fragments."""
+
+    def __init__(self, mask, frame_filters=[], encode_utf8=True):
+        """Constructs an instance."""
+
+        self._mask = mask
+        self._frame_filters = frame_filters
+        # This is for skipping UTF-8 encoding when building text type frames
+        # from compressed data.
+        self._encode_utf8 = encode_utf8
+
+        self._started = False
+
+        # Holds the opcode of the first frame of a message, used to verify
+        # that all other frames in the message have the same type.
+        self._opcode = common.OPCODE_TEXT
+
+    def build(self, payload_data, end, binary):
+        if binary:
+            frame_type = common.OPCODE_BINARY
+        else:
+            frame_type = common.OPCODE_TEXT
+        if self._started:
+            if self._opcode != frame_type:
+                raise ValueError('Message types are different in frames for '
+                                 'the same message')
+            opcode = common.OPCODE_CONTINUATION
+        else:
+            opcode = frame_type
+            self._opcode = frame_type
+
+        if end:
+            self._started = False
+            fin = 1
+        else:
+            self._started = True
+            fin = 0
+
+        if binary or not self._encode_utf8:
+            return create_binary_frame(
+                payload_data, opcode, fin, self._mask, self._frame_filters)
+        else:
+            return create_text_frame(
+                payload_data, opcode, fin, self._mask, self._frame_filters)
+
+
+def _create_control_frame(opcode, body, mask, frame_filters):
+    frame = Frame(opcode=opcode, payload=body)
+
+    for frame_filter in frame_filters:
+        frame_filter.filter(frame)
+
+    if len(frame.payload) > 125:
+        raise BadOperationException(
+            'Payload data size of control frames must be 125 bytes or less')
+
+    header = create_header(
+        frame.opcode, len(frame.payload), frame.fin,
+        frame.rsv1, frame.rsv2, frame.rsv3, mask)
+    return _build_frame(header, frame.payload, mask)
+
+
+def create_ping_frame(body, mask=False, frame_filters=[]):
+    return _create_control_frame(common.OPCODE_PING, body, mask, frame_filters)
+
+
+def create_pong_frame(body, mask=False, frame_filters=[]):
+    return _create_control_frame(common.OPCODE_PONG, body, mask, frame_filters)
+
+
+def create_close_frame(body, mask=False, frame_filters=[]):
+    return _create_control_frame(
+        common.OPCODE_CLOSE, body, mask, frame_filters)
+
+
+def create_closing_handshake_body(code, reason):
+    body = ''
+    if code is not None:
+        if (code > common.STATUS_USER_PRIVATE_MAX or
+            code < common.STATUS_NORMAL_CLOSURE):
+            raise BadOperationException('Status code is out of range')
+        if (code == common.STATUS_NO_STATUS_RECEIVED or
+            code == common.STATUS_ABNORMAL_CLOSURE or
+            code == common.STATUS_TLS_HANDSHAKE):
+            raise BadOperationException('Status code is reserved pseudo '
+                                        'code')
+        encoded_reason = reason.encode('utf-8')
+        body = struct.pack('!H', code) + encoded_reason
+    return body
+
+
+class StreamOptions(object):
+    """Holds option values to configure Stream objects."""
+
+    def __init__(self):
+        """Constructs StreamOptions."""
+
+        # Enables deflate-stream extension.
+        self.deflate_stream = False
+
+        # Filters applied to frames.
+        self.outgoing_frame_filters = []
+        self.incoming_frame_filters = []
+
+        # Filters applied to messages. Control frames are not affected by
+        # them.
+        self.outgoing_message_filters = []
+        self.incoming_message_filters = []
+
+        self.encode_text_message_to_utf8 = True
+        self.mask_send = False
+        self.unmask_receive = True
+        # RFC 6455 disallows fragmented control frames, but the mux extension
+        # relaxes the restriction.
+        self.allow_fragmented_control_frame = False
+
+
+class Stream(StreamBase):
+    """A class for parsing/building frames of the WebSocket protocol
+    (RFC 6455).
+    """
+
+    def __init__(self, request, options):
+        """Constructs an instance.
+
+        Args:
+            request: mod_python request.
+        """
+
+        StreamBase.__init__(self, request)
+
+        self._logger = util.get_class_logger(self)
+
+        self._options = options
+
+        if self._options.deflate_stream:
+            self._logger.debug('Setup filter for deflate-stream')
+            self._request = util.DeflateRequest(self._request)
+
+        self._request.client_terminated = False
+        self._request.server_terminated = False
+
+        # Holds body of received fragments.
+        self._received_fragments = []
+        # Holds the opcode of the first fragment.
+        self._original_opcode = None
+
+        self._writer = FragmentedFrameBuilder(
+            self._options.mask_send, self._options.outgoing_frame_filters,
+            self._options.encode_text_message_to_utf8)
+
+        self._ping_queue = deque()
+
+    def _receive_frame(self):
+        """Receives a frame and returns the data in the frame as a tuple
+        containing each header field and the payload separately.
+
+        Raises:
+            ConnectionTerminatedException: when read returns empty
+                string.
+            InvalidFrameException: when the frame contains invalid data.
+        """
+
+        def _receive_bytes(length):
+            return self.receive_bytes(length)
+
+        return parse_frame(receive_bytes=_receive_bytes,
+                           logger=self._logger,
+                           ws_version=self._request.ws_version,
+                           unmask_receive=self._options.unmask_receive)
+
+    def _receive_frame_as_frame_object(self):
+        opcode, bytes, fin, rsv1, rsv2, rsv3 = self._receive_frame()
+
+        return Frame(fin=fin, rsv1=rsv1, rsv2=rsv2, rsv3=rsv3,
+                     opcode=opcode, payload=bytes)
+
+    def receive_filtered_frame(self):
+        """Receives a frame and applies frame filters and message filters.
+        The frame to be received must satisfy the following conditions:
+        - The frame is not fragmented.
+        - The opcode of the frame is TEXT or BINARY.
+
+        DO NOT USE this method except for testing purposes.
+        """
+
+        frame = self._receive_frame_as_frame_object()
+        if not frame.fin:
+            raise InvalidFrameException(
+                'Segmented frames must not be received via '
+                'receive_filtered_frame()')
+        if (frame.opcode != common.OPCODE_TEXT and
+            frame.opcode != common.OPCODE_BINARY):
+            raise InvalidFrameException(
+                'Control frames must not be received via '
+                'receive_filtered_frame()')
+
+        for frame_filter in self._options.incoming_frame_filters:
+            frame_filter.filter(frame)
+        for message_filter in self._options.incoming_message_filters:
+            frame.payload = message_filter.filter(frame.payload)
+        return frame
+
+    def send_message(self, message, end=True, binary=False):
+        """Send message.
+
+        Args:
+            message: text in unicode or binary in str to send.
+            binary: send message as binary frame.
+
+        Raises:
+            BadOperationException: when called on a server-terminated
+                connection or called with inconsistent message type or
+                binary parameter.
+ """ + + if self._request.server_terminated: + raise BadOperationException( + 'Requested send_message after sending out a closing handshake') + + if binary and isinstance(message, unicode): + raise BadOperationException( + 'Message for binary frame must be instance of str') + + for message_filter in self._options.outgoing_message_filters: + message = message_filter.filter(message, end, binary) + + try: + # Set this to any positive integer to limit maximum size of data in + # payload data of each frame. + MAX_PAYLOAD_DATA_SIZE = -1 + + if MAX_PAYLOAD_DATA_SIZE <= 0: + self._write(self._writer.build(message, end, binary)) + return + + bytes_written = 0 + while True: + end_for_this_frame = end + bytes_to_write = len(message) - bytes_written + if (MAX_PAYLOAD_DATA_SIZE > 0 and + bytes_to_write > MAX_PAYLOAD_DATA_SIZE): + end_for_this_frame = False + bytes_to_write = MAX_PAYLOAD_DATA_SIZE + + frame = self._writer.build( + message[bytes_written:bytes_written + bytes_to_write], + end_for_this_frame, + binary) + self._write(frame) + + bytes_written += bytes_to_write + + # This if must be placed here (the end of while block) so that + # at least one frame is sent. + if len(message) <= bytes_written: + break + except ValueError, e: + raise BadOperationException(e) + + def _get_message_from_frame(self, frame): + """Gets a message from frame. If the message is composed of fragmented + frames and the frame is not the last fragmented frame, this method + returns None. The whole message will be returned when the last + fragmented frame is passed to this method. + + Raises: + InvalidFrameException: when the frame doesn't match defragmentation + context, or the frame contains invalid data. + """ + + if frame.opcode == common.OPCODE_CONTINUATION: + if not self._received_fragments: + if frame.fin: + raise InvalidFrameException( + 'Received a termination frame but fragmentation ' + 'not started') + else: + raise InvalidFrameException( + 'Received an intermediate frame but ' + 'fragmentation not started') + + if frame.fin: + # End of fragmentation frame + self._received_fragments.append(frame.payload) + message = ''.join(self._received_fragments) + self._received_fragments = [] + return message + else: + # Intermediate frame + self._received_fragments.append(frame.payload) + return None + else: + if self._received_fragments: + if frame.fin: + raise InvalidFrameException( + 'Received an unfragmented frame without ' + 'terminating existing fragmentation') + else: + raise InvalidFrameException( + 'New fragmentation started without terminating ' + 'existing fragmentation') + + if frame.fin: + # Unfragmented frame + + self._original_opcode = frame.opcode + return frame.payload + else: + # Start of fragmentation frame + + if (not self._options.allow_fragmented_control_frame and + common.is_control_opcode(frame.opcode)): + raise InvalidFrameException( + 'Control frames must not be fragmented') + + self._original_opcode = frame.opcode + self._received_fragments.append(frame.payload) + return None + + def _process_close_message(self, message): + """Processes close message. + + Args: + message: close message. + + Raises: + InvalidFrameException: when the message is invalid. + """ + + self._request.client_terminated = True + + # Status code is optional. We can have status reason only if we + # have status code. Status reason can be empty string. 
So,
+        # allowed cases are
+        # - no application data: no code no reason
+        # - 2 octets of application data: has code but no reason
+        # - 3 or more octets of application data: both code and reason
+        if len(message) == 0:
+            self._logger.debug('Received close frame (empty body)')
+            self._request.ws_close_code = (
+                common.STATUS_NO_STATUS_RECEIVED)
+        elif len(message) == 1:
+            raise InvalidFrameException(
+                'If a close frame has a status code, the length of the '
+                'status code must be 2 octets')
+        elif len(message) >= 2:
+            self._request.ws_close_code = struct.unpack(
+                '!H', message[0:2])[0]
+            self._request.ws_close_reason = message[2:].decode(
+                'utf-8', 'replace')
+            self._logger.debug(
+                'Received close frame (code=%d, reason=%r)',
+                self._request.ws_close_code,
+                self._request.ws_close_reason)
+
+        # Drain junk data after the close frame if necessary.
+        self._drain_received_data()
+
+        if self._request.server_terminated:
+            self._logger.debug(
+                'Received ack for server-initiated closing handshake')
+            return
+
+        self._logger.debug(
+            'Received client-initiated closing handshake')
+
+        code = common.STATUS_NORMAL_CLOSURE
+        reason = ''
+        if hasattr(self._request, '_dispatcher'):
+            dispatcher = self._request._dispatcher
+            code, reason = dispatcher.passive_closing_handshake(
+                self._request)
+            if code is None and reason is not None and len(reason) > 0:
+                self._logger.warning(
+                    'Handler specified reason despite code being None')
+                reason = ''
+            if reason is None:
+                reason = ''
+        self._send_closing_handshake(code, reason)
+        self._logger.debug(
+            'Sent ack for client-initiated closing handshake '
+            '(code=%r, reason=%r)', code, reason)
+
+    def _process_ping_message(self, message):
+        """Processes a ping message.
+
+        Args:
+            message: ping message.
+        """
+
+        try:
+            handler = self._request.on_ping_handler
+            if handler:
+                handler(self._request, message)
+                return
+        except AttributeError, e:
+            pass
+        self._send_pong(message)
+
+    def _process_pong_message(self, message):
+        """Processes a pong message.
+
+        Args:
+            message: pong message.
+        """
+
+        # TODO(tyoshino): Add ping timeout handling.
+
+        inflight_pings = deque()
+
+        while True:
+            try:
+                expected_body = self._ping_queue.popleft()
+                if expected_body == message:
+                    # inflight_pings contains pings ignored by the
+                    # other peer. Just forget them.
+                    self._logger.debug(
+                        'Ping %r is acked (%d pings were ignored)',
+                        expected_body, len(inflight_pings))
+                    break
+                else:
+                    inflight_pings.append(expected_body)
+            except IndexError, e:
+                # The received pong was an unsolicited pong. Keep the
+                # ping queue as is.
+                self._ping_queue = inflight_pings
+                self._logger.debug('Received an unsolicited pong')
+                break
+
+        try:
+            handler = self._request.on_pong_handler
+            if handler:
+                handler(self._request, message)
+        except AttributeError, e:
+            pass
+
+    def receive_message(self):
+        """Receive a WebSocket frame and return its payload as text in
+        unicode or binary in str.
+
+        Returns:
+            payload data of the frame
+            - as a unicode instance if a text frame was received
+            - as a str instance if a binary frame was received
+            or None iff a closing handshake was received.
+        Raises:
+            BadOperationException: when called on a client-terminated
+                connection.
+            ConnectionTerminatedException: when read returns empty
+                string.
+            InvalidFrameException: when the frame contains invalid
+                data.
+            UnsupportedFrameException: when the received frame has
+                a flag or opcode we cannot handle. You can ignore this
+                exception and continue receiving the next frame.
+ """ + + if self._request.client_terminated: + raise BadOperationException( + 'Requested receive_message after receiving a closing ' + 'handshake') + + while True: + # mp_conn.read will block if no bytes are available. + # Timeout is controlled by TimeOut directive of Apache. + + frame = self._receive_frame_as_frame_object() + + # Check the constraint on the payload size for control frames + # before extension processes the frame. + # See also http://tools.ietf.org/html/rfc6455#section-5.5 + if (common.is_control_opcode(frame.opcode) and + len(frame.payload) > 125): + raise InvalidFrameException( + 'Payload data size of control frames must be 125 bytes or ' + 'less') + + for frame_filter in self._options.incoming_frame_filters: + frame_filter.filter(frame) + + if frame.rsv1 or frame.rsv2 or frame.rsv3: + raise UnsupportedFrameException( + 'Unsupported flag is set (rsv = %d%d%d)' % + (frame.rsv1, frame.rsv2, frame.rsv3)) + + message = self._get_message_from_frame(frame) + if message is None: + continue + + for message_filter in self._options.incoming_message_filters: + message = message_filter.filter(message) + + if self._original_opcode == common.OPCODE_TEXT: + # The WebSocket protocol section 4.4 specifies that invalid + # characters must be replaced with U+fffd REPLACEMENT + # CHARACTER. + try: + return message.decode('utf-8') + except UnicodeDecodeError, e: + raise InvalidUTF8Exception(e) + elif self._original_opcode == common.OPCODE_BINARY: + return message + elif self._original_opcode == common.OPCODE_CLOSE: + self._process_close_message(message) + return None + elif self._original_opcode == common.OPCODE_PING: + self._process_ping_message(message) + elif self._original_opcode == common.OPCODE_PONG: + self._process_pong_message(message) + else: + raise UnsupportedFrameException( + 'Opcode %d is not supported' % self._original_opcode) + + def _send_closing_handshake(self, code, reason): + body = create_closing_handshake_body(code, reason) + frame = create_close_frame( + body, mask=self._options.mask_send, + frame_filters=self._options.outgoing_frame_filters) + + self._request.server_terminated = True + + self._write(frame) + + def close_connection(self, code=common.STATUS_NORMAL_CLOSURE, reason=''): + """Closes a WebSocket connection. + + Args: + code: Status code for close frame. If code is None, a close + frame with empty body will be sent. + reason: string representing close reason. + Raises: + BadOperationException: when reason is specified with code None + or reason is not an instance of both str and unicode. + """ + + if self._request.server_terminated: + self._logger.debug( + 'Requested close_connection but server is already terminated') + return + + if code is None: + if reason is not None and len(reason) > 0: + raise BadOperationException( + 'close reason must not be specified if code is None') + reason = '' + else: + if not isinstance(reason, str) and not isinstance(reason, unicode): + raise BadOperationException( + 'close reason must be an instance of str or unicode') + + self._send_closing_handshake(code, reason) + self._logger.debug( + 'Sent server-initiated closing handshake (code=%r, reason=%r)', + code, reason) + + if (code == common.STATUS_GOING_AWAY or + code == common.STATUS_PROTOCOL_ERROR): + # It doesn't make sense to wait for a close frame if the reason is + # protocol error or that the server is going away. For some of + # other reasons, it might not make sense to wait for a close frame, + # but it's not clear, yet. + return + + # TODO(ukai): 2. 
wait until the /client terminated/ flag has been set, + # or until a server-defined timeout expires. + # + # For now, we expect receiving closing handshake right after sending + # out closing handshake. + message = self.receive_message() + if message is not None: + raise ConnectionTerminatedException( + 'Didn\'t receive valid ack for closing handshake') + # TODO: 3. close the WebSocket connection. + # note: mod_python Connection (mp_conn) doesn't have close method. + + def send_ping(self, body=''): + frame = create_ping_frame( + body, + self._options.mask_send, + self._options.outgoing_frame_filters) + self._write(frame) + + self._ping_queue.append(body) + + def _send_pong(self, body): + frame = create_pong_frame( + body, + self._options.mask_send, + self._options.outgoing_frame_filters) + self._write(frame) + + def get_last_received_opcode(self): + """Returns the opcode of the WebSocket message which the last received + frame belongs to. The return value is valid iff immediately after + receive_message call. + """ + + return self._original_opcode + + def _drain_received_data(self): + """Drains unread data in the receive buffer to avoid sending out TCP + RST packet. This is because when deflate-stream is enabled, some + DEFLATE block for flushing data may follow a close frame. If any data + remains in the receive buffer of a socket when the socket is closed, + it sends out TCP RST packet to the other peer. + + Since mod_python's mp_conn object doesn't support non-blocking read, + we perform this only when pywebsocket is running in standalone mode. + """ + + # If self._options.deflate_stream is true, self._request is + # DeflateRequest, so we can get wrapped request object by + # self._request._request. + # + # Only _StandaloneRequest has _drain_received_data method. + if (self._options.deflate_stream and + ('_drain_received_data' in dir(self._request._request))): + self._request._request._drain_received_data() + + +# vi:sts=4 sw=4 et diff --git a/pyload/lib/mod_pywebsocket/common.py b/pyload/lib/mod_pywebsocket/common.py new file mode 100644 index 000000000..2388379c0 --- /dev/null +++ b/pyload/lib/mod_pywebsocket/common.py @@ -0,0 +1,307 @@ +# Copyright 2012, Google Inc. +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: +# +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above +# copyright notice, this list of conditions and the following disclaimer +# in the documentation and/or other materials provided with the +# distribution. +# * Neither the name of Google Inc. nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +# A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT +# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + + +"""This file must not depend on any module specific to the WebSocket protocol. +""" + + +from mod_pywebsocket import http_header_util + + +# Additional log level definitions. +LOGLEVEL_FINE = 9 + +# Constants indicating WebSocket protocol version. +VERSION_HIXIE75 = -1 +VERSION_HYBI00 = 0 +VERSION_HYBI01 = 1 +VERSION_HYBI02 = 2 +VERSION_HYBI03 = 2 +VERSION_HYBI04 = 4 +VERSION_HYBI05 = 5 +VERSION_HYBI06 = 6 +VERSION_HYBI07 = 7 +VERSION_HYBI08 = 8 +VERSION_HYBI09 = 8 +VERSION_HYBI10 = 8 +VERSION_HYBI11 = 8 +VERSION_HYBI12 = 8 +VERSION_HYBI13 = 13 +VERSION_HYBI14 = 13 +VERSION_HYBI15 = 13 +VERSION_HYBI16 = 13 +VERSION_HYBI17 = 13 + +# Constants indicating WebSocket protocol latest version. +VERSION_HYBI_LATEST = VERSION_HYBI13 + +# Port numbers +DEFAULT_WEB_SOCKET_PORT = 80 +DEFAULT_WEB_SOCKET_SECURE_PORT = 443 + +# Schemes +WEB_SOCKET_SCHEME = 'ws' +WEB_SOCKET_SECURE_SCHEME = 'wss' + +# Frame opcodes defined in the spec. +OPCODE_CONTINUATION = 0x0 +OPCODE_TEXT = 0x1 +OPCODE_BINARY = 0x2 +OPCODE_CLOSE = 0x8 +OPCODE_PING = 0x9 +OPCODE_PONG = 0xa + +# UUIDs used by HyBi 04 and later opening handshake and frame masking. +WEBSOCKET_ACCEPT_UUID = '258EAFA5-E914-47DA-95CA-C5AB0DC85B11' + +# Opening handshake header names and expected values. +UPGRADE_HEADER = 'Upgrade' +WEBSOCKET_UPGRADE_TYPE = 'websocket' +WEBSOCKET_UPGRADE_TYPE_HIXIE75 = 'WebSocket' +CONNECTION_HEADER = 'Connection' +UPGRADE_CONNECTION_TYPE = 'Upgrade' +HOST_HEADER = 'Host' +ORIGIN_HEADER = 'Origin' +SEC_WEBSOCKET_ORIGIN_HEADER = 'Sec-WebSocket-Origin' +SEC_WEBSOCKET_KEY_HEADER = 'Sec-WebSocket-Key' +SEC_WEBSOCKET_ACCEPT_HEADER = 'Sec-WebSocket-Accept' +SEC_WEBSOCKET_VERSION_HEADER = 'Sec-WebSocket-Version' +SEC_WEBSOCKET_PROTOCOL_HEADER = 'Sec-WebSocket-Protocol' +SEC_WEBSOCKET_EXTENSIONS_HEADER = 'Sec-WebSocket-Extensions' +SEC_WEBSOCKET_DRAFT_HEADER = 'Sec-WebSocket-Draft' +SEC_WEBSOCKET_KEY1_HEADER = 'Sec-WebSocket-Key1' +SEC_WEBSOCKET_KEY2_HEADER = 'Sec-WebSocket-Key2' +SEC_WEBSOCKET_LOCATION_HEADER = 'Sec-WebSocket-Location' + +# Extensions +DEFLATE_STREAM_EXTENSION = 'deflate-stream' +DEFLATE_FRAME_EXTENSION = 'deflate-frame' +PERFRAME_COMPRESSION_EXTENSION = 'perframe-compress' +PERMESSAGE_COMPRESSION_EXTENSION = 'permessage-compress' +X_WEBKIT_DEFLATE_FRAME_EXTENSION = 'x-webkit-deflate-frame' +X_WEBKIT_PERMESSAGE_COMPRESSION_EXTENSION = 'x-webkit-permessage-compress' +MUX_EXTENSION = 'mux_DO_NOT_USE' + +# Status codes +# Code STATUS_NO_STATUS_RECEIVED, STATUS_ABNORMAL_CLOSURE, and +# STATUS_TLS_HANDSHAKE are pseudo codes to indicate specific error cases. +# Could not be used for codes in actual closing frames. +# Application level errors must use codes in the range +# STATUS_USER_REGISTERED_BASE to STATUS_USER_PRIVATE_MAX. The codes in the +# range STATUS_USER_REGISTERED_BASE to STATUS_USER_REGISTERED_MAX are managed +# by IANA. Usually application must define user protocol level errors in the +# range STATUS_USER_PRIVATE_BASE to STATUS_USER_PRIVATE_MAX. 
+STATUS_NORMAL_CLOSURE = 1000 +STATUS_GOING_AWAY = 1001 +STATUS_PROTOCOL_ERROR = 1002 +STATUS_UNSUPPORTED_DATA = 1003 +STATUS_NO_STATUS_RECEIVED = 1005 +STATUS_ABNORMAL_CLOSURE = 1006 +STATUS_INVALID_FRAME_PAYLOAD_DATA = 1007 +STATUS_POLICY_VIOLATION = 1008 +STATUS_MESSAGE_TOO_BIG = 1009 +STATUS_MANDATORY_EXTENSION = 1010 +STATUS_INTERNAL_ENDPOINT_ERROR = 1011 +STATUS_TLS_HANDSHAKE = 1015 +STATUS_USER_REGISTERED_BASE = 3000 +STATUS_USER_REGISTERED_MAX = 3999 +STATUS_USER_PRIVATE_BASE = 4000 +STATUS_USER_PRIVATE_MAX = 4999 +# Following definitions are aliases to keep compatibility. Applications must +# not use these obsoleted definitions anymore. +STATUS_NORMAL = STATUS_NORMAL_CLOSURE +STATUS_UNSUPPORTED = STATUS_UNSUPPORTED_DATA +STATUS_CODE_NOT_AVAILABLE = STATUS_NO_STATUS_RECEIVED +STATUS_ABNORMAL_CLOSE = STATUS_ABNORMAL_CLOSURE +STATUS_INVALID_FRAME_PAYLOAD = STATUS_INVALID_FRAME_PAYLOAD_DATA +STATUS_MANDATORY_EXT = STATUS_MANDATORY_EXTENSION + +# HTTP status codes +HTTP_STATUS_BAD_REQUEST = 400 +HTTP_STATUS_FORBIDDEN = 403 +HTTP_STATUS_NOT_FOUND = 404 + + +def is_control_opcode(opcode): + return (opcode >> 3) == 1 + + +class ExtensionParameter(object): + """Holds information about an extension which is exchanged on extension + negotiation in opening handshake. + """ + + def __init__(self, name): + self._name = name + # TODO(tyoshino): Change the data structure to more efficient one such + # as dict when the spec changes to say like + # - Parameter names must be unique + # - The order of parameters is not significant + self._parameters = [] + + def name(self): + return self._name + + def add_parameter(self, name, value): + self._parameters.append((name, value)) + + def get_parameters(self): + return self._parameters + + def get_parameter_names(self): + return [name for name, unused_value in self._parameters] + + def has_parameter(self, name): + for param_name, param_value in self._parameters: + if param_name == name: + return True + return False + + def get_parameter_value(self, name): + for param_name, param_value in self._parameters: + if param_name == name: + return param_value + + +class ExtensionParsingException(Exception): + def __init__(self, name): + super(ExtensionParsingException, self).__init__(name) + + +def _parse_extension_param(state, definition, allow_quoted_string): + param_name = http_header_util.consume_token(state) + + if param_name is None: + raise ExtensionParsingException('No valid parameter name found') + + http_header_util.consume_lwses(state) + + if not http_header_util.consume_string(state, '='): + definition.add_parameter(param_name, None) + return + + http_header_util.consume_lwses(state) + + if allow_quoted_string: + # TODO(toyoshim): Add code to validate that parsed param_value is token + param_value = http_header_util.consume_token_or_quoted_string(state) + else: + param_value = http_header_util.consume_token(state) + if param_value is None: + raise ExtensionParsingException( + 'No valid parameter value found on the right-hand side of ' + 'parameter %r' % param_name) + + definition.add_parameter(param_name, param_value) + + +def _parse_extension(state, allow_quoted_string): + extension_token = http_header_util.consume_token(state) + if extension_token is None: + return None + + extension = ExtensionParameter(extension_token) + + while True: + http_header_util.consume_lwses(state) + + if not http_header_util.consume_string(state, ';'): + break + + http_header_util.consume_lwses(state) + + try: + _parse_extension_param(state, extension, allow_quoted_string) 
+ except ExtensionParsingException, e: + raise ExtensionParsingException( + 'Failed to parse parameter for %r (%r)' % + (extension_token, e)) + + return extension + + +def parse_extensions(data, allow_quoted_string=False): + """Parses Sec-WebSocket-Extensions header value returns a list of + ExtensionParameter objects. + + Leading LWSes must be trimmed. + """ + + state = http_header_util.ParsingState(data) + + extension_list = [] + while True: + extension = _parse_extension(state, allow_quoted_string) + if extension is not None: + extension_list.append(extension) + + http_header_util.consume_lwses(state) + + if http_header_util.peek(state) is None: + break + + if not http_header_util.consume_string(state, ','): + raise ExtensionParsingException( + 'Failed to parse Sec-WebSocket-Extensions header: ' + 'Expected a comma but found %r' % + http_header_util.peek(state)) + + http_header_util.consume_lwses(state) + + if len(extension_list) == 0: + raise ExtensionParsingException( + 'No valid extension entry found') + + return extension_list + + +def format_extension(extension): + """Formats an ExtensionParameter object.""" + + formatted_params = [extension.name()] + for param_name, param_value in extension.get_parameters(): + if param_value is None: + formatted_params.append(param_name) + else: + quoted_value = http_header_util.quote_if_necessary(param_value) + formatted_params.append('%s=%s' % (param_name, quoted_value)) + return '; '.join(formatted_params) + + +def format_extensions(extension_list): + """Formats a list of ExtensionParameter objects.""" + + formatted_extension_list = [] + for extension in extension_list: + formatted_extension_list.append(format_extension(extension)) + return ', '.join(formatted_extension_list) + + +# vi:sts=4 sw=4 et diff --git a/pyload/lib/mod_pywebsocket/dispatch.py b/pyload/lib/mod_pywebsocket/dispatch.py new file mode 100644 index 000000000..25905f180 --- /dev/null +++ b/pyload/lib/mod_pywebsocket/dispatch.py @@ -0,0 +1,387 @@ +# Copyright 2012, Google Inc. +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: +# +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above +# copyright notice, this list of conditions and the following disclaimer +# in the documentation and/or other materials provided with the +# distribution. +# * Neither the name of Google Inc. nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
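
A quick round-trip sketch of the parse/format helpers defined above
(hypothetical header value; assumes the module is importable as
mod_pywebsocket.common):

    from mod_pywebsocket import common

    # A hypothetical Sec-WebSocket-Extensions offer of two extensions.
    header = ('permessage-compress; '
              'method="deflate; s2c_max_window_bits=12", mux')
    extensions = common.parse_extensions(header, allow_quoted_string=True)
    for extension in extensions:
        print extension.name(), extension.get_parameters()
    # permessage-compress [('method', 'deflate; s2c_max_window_bits=12')]
    # mux []

    # format_extensions() is the inverse, re-quoting values as needed.
    print common.format_extensions(extensions)
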
+ + +"""Dispatch WebSocket request. +""" + + +import logging +import os +import re + +from mod_pywebsocket import common +from mod_pywebsocket import handshake +from mod_pywebsocket import msgutil +from mod_pywebsocket import mux +from mod_pywebsocket import stream +from mod_pywebsocket import util + + +_SOURCE_PATH_PATTERN = re.compile(r'(?i)_wsh\.py$') +_SOURCE_SUFFIX = '_wsh.py' +_DO_EXTRA_HANDSHAKE_HANDLER_NAME = 'web_socket_do_extra_handshake' +_TRANSFER_DATA_HANDLER_NAME = 'web_socket_transfer_data' +_PASSIVE_CLOSING_HANDSHAKE_HANDLER_NAME = ( + 'web_socket_passive_closing_handshake') + + +class DispatchException(Exception): + """Exception in dispatching WebSocket request.""" + + def __init__(self, name, status=common.HTTP_STATUS_NOT_FOUND): + super(DispatchException, self).__init__(name) + self.status = status + + +def _default_passive_closing_handshake_handler(request): + """Default web_socket_passive_closing_handshake handler.""" + + return common.STATUS_NORMAL_CLOSURE, '' + + +def _normalize_path(path): + """Normalize path. + + Args: + path: the path to normalize. + + Path is converted to the absolute path. + The input path can use either '\\' or '/' as the separator. + The normalized path always uses '/' regardless of the platform. + """ + + path = path.replace('\\', os.path.sep) + path = os.path.realpath(path) + path = path.replace('\\', '/') + return path + + +def _create_path_to_resource_converter(base_dir): + """Returns a function that converts the path of a WebSocket handler source + file to a resource string by removing the path to the base directory from + its head, removing _SOURCE_SUFFIX from its tail, and replacing path + separators in it with '/'. + + Args: + base_dir: the path to the base directory. + """ + + base_dir = _normalize_path(base_dir) + + base_len = len(base_dir) + suffix_len = len(_SOURCE_SUFFIX) + + def converter(path): + if not path.endswith(_SOURCE_SUFFIX): + return None + # _normalize_path must not be used because resolving symlink breaks + # following path check. + path = path.replace('\\', '/') + if not path.startswith(base_dir): + return None + return path[base_len:-suffix_len] + + return converter + + +def _enumerate_handler_file_paths(directory): + """Returns a generator that enumerates WebSocket Handler source file names + in the given directory. + """ + + for root, unused_dirs, files in os.walk(directory): + for base in files: + path = os.path.join(root, base) + if _SOURCE_PATH_PATTERN.search(path): + yield path + + +class _HandlerSuite(object): + """A handler suite holder class.""" + + def __init__(self, do_extra_handshake, transfer_data, + passive_closing_handshake): + self.do_extra_handshake = do_extra_handshake + self.transfer_data = transfer_data + self.passive_closing_handshake = passive_closing_handshake + + +def _source_handler_file(handler_definition): + """Source a handler definition string. + + Args: + handler_definition: a string containing Python statements that define + handler functions. 
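+            The string must define web_socket_do_extra_handshake and
+            web_socket_transfer_data. web_socket_passive_closing_handshake
+            is optional; when missing, a default handler that returns
+            (STATUS_NORMAL_CLOSURE, '') is used.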
+ """ + + global_dic = {} + try: + exec handler_definition in global_dic + except Exception: + raise DispatchException('Error in sourcing handler:' + + util.get_stack_trace()) + passive_closing_handshake_handler = None + try: + passive_closing_handshake_handler = _extract_handler( + global_dic, _PASSIVE_CLOSING_HANDSHAKE_HANDLER_NAME) + except Exception: + passive_closing_handshake_handler = ( + _default_passive_closing_handshake_handler) + return _HandlerSuite( + _extract_handler(global_dic, _DO_EXTRA_HANDSHAKE_HANDLER_NAME), + _extract_handler(global_dic, _TRANSFER_DATA_HANDLER_NAME), + passive_closing_handshake_handler) + + +def _extract_handler(dic, name): + """Extracts a callable with the specified name from the given dictionary + dic. + """ + + if name not in dic: + raise DispatchException('%s is not defined.' % name) + handler = dic[name] + if not callable(handler): + raise DispatchException('%s is not callable.' % name) + return handler + + +class Dispatcher(object): + """Dispatches WebSocket requests. + + This class maintains a map from resource name to handlers. + """ + + def __init__( + self, root_dir, scan_dir=None, + allow_handlers_outside_root_dir=True): + """Construct an instance. + + Args: + root_dir: The directory where handler definition files are + placed. + scan_dir: The directory where handler definition files are + searched. scan_dir must be a directory under root_dir, + including root_dir itself. If scan_dir is None, + root_dir is used as scan_dir. scan_dir can be useful + in saving scan time when root_dir contains many + subdirectories. + allow_handlers_outside_root_dir: Scans handler files even if their + canonical path is not under root_dir. + """ + + self._logger = util.get_class_logger(self) + + self._handler_suite_map = {} + self._source_warnings = [] + if scan_dir is None: + scan_dir = root_dir + if not os.path.realpath(scan_dir).startswith( + os.path.realpath(root_dir)): + raise DispatchException('scan_dir:%s must be a directory under ' + 'root_dir:%s.' % (scan_dir, root_dir)) + self._source_handler_files_in_dir( + root_dir, scan_dir, allow_handlers_outside_root_dir) + + def add_resource_path_alias(self, + alias_resource_path, existing_resource_path): + """Add resource path alias. + + Once added, request to alias_resource_path would be handled by + handler registered for existing_resource_path. + + Args: + alias_resource_path: alias resource path + existing_resource_path: existing resource path + """ + try: + handler_suite = self._handler_suite_map[existing_resource_path] + self._handler_suite_map[alias_resource_path] = handler_suite + except KeyError: + raise DispatchException('No handler for: %r' % + existing_resource_path) + + def source_warnings(self): + """Return warnings in sourcing handlers.""" + + return self._source_warnings + + def do_extra_handshake(self, request): + """Do extra checking in WebSocket handshake. + + Select a handler based on request.uri and call its + web_socket_do_extra_handshake function. + + Args: + request: mod_python request. 
+ + Raises: + DispatchException: when handler was not found + AbortedByUserException: when user handler abort connection + HandshakeException: when opening handshake failed + """ + + handler_suite = self.get_handler_suite(request.ws_resource) + if handler_suite is None: + raise DispatchException('No handler for: %r' % request.ws_resource) + do_extra_handshake_ = handler_suite.do_extra_handshake + try: + do_extra_handshake_(request) + except handshake.AbortedByUserException, e: + raise + except Exception, e: + util.prepend_message_to_exception( + '%s raised exception for %s: ' % ( + _DO_EXTRA_HANDSHAKE_HANDLER_NAME, + request.ws_resource), + e) + raise handshake.HandshakeException(e, common.HTTP_STATUS_FORBIDDEN) + + def transfer_data(self, request): + """Let a handler transfer_data with a WebSocket client. + + Select a handler based on request.ws_resource and call its + web_socket_transfer_data function. + + Args: + request: mod_python request. + + Raises: + DispatchException: when handler was not found + AbortedByUserException: when user handler abort connection + """ + + # TODO(tyoshino): Terminate underlying TCP connection if possible. + try: + if mux.use_mux(request): + mux.start(request, self) + else: + handler_suite = self.get_handler_suite(request.ws_resource) + if handler_suite is None: + raise DispatchException('No handler for: %r' % + request.ws_resource) + transfer_data_ = handler_suite.transfer_data + transfer_data_(request) + + if not request.server_terminated: + request.ws_stream.close_connection() + # Catch non-critical exceptions the handler didn't handle. + except handshake.AbortedByUserException, e: + self._logger.debug('%s', e) + raise + except msgutil.BadOperationException, e: + self._logger.debug('%s', e) + request.ws_stream.close_connection(common.STATUS_ABNORMAL_CLOSURE) + except msgutil.InvalidFrameException, e: + # InvalidFrameException must be caught before + # ConnectionTerminatedException that catches InvalidFrameException. + self._logger.debug('%s', e) + request.ws_stream.close_connection(common.STATUS_PROTOCOL_ERROR) + except msgutil.UnsupportedFrameException, e: + self._logger.debug('%s', e) + request.ws_stream.close_connection(common.STATUS_UNSUPPORTED_DATA) + except stream.InvalidUTF8Exception, e: + self._logger.debug('%s', e) + request.ws_stream.close_connection( + common.STATUS_INVALID_FRAME_PAYLOAD_DATA) + except msgutil.ConnectionTerminatedException, e: + self._logger.debug('%s', e) + except Exception, e: + util.prepend_message_to_exception( + '%s raised exception for %s: ' % ( + _TRANSFER_DATA_HANDLER_NAME, request.ws_resource), + e) + raise + + def passive_closing_handshake(self, request): + """Prepare code and reason for responding client initiated closing + handshake. + """ + + handler_suite = self.get_handler_suite(request.ws_resource) + if handler_suite is None: + return _default_passive_closing_handshake_handler(request) + return handler_suite.passive_closing_handshake(request) + + def get_handler_suite(self, resource): + """Retrieves two handlers (one for extra handshake processing, and one + for data transfer) for the given request as a HandlerSuite object. + """ + + fragment = None + if '#' in resource: + resource, fragment = resource.split('#', 1) + if '?' 
in resource: + resource = resource.split('?', 1)[0] + handler_suite = self._handler_suite_map.get(resource) + if handler_suite and fragment: + raise DispatchException('Fragment identifiers MUST NOT be used on ' + 'WebSocket URIs', + common.HTTP_STATUS_BAD_REQUEST) + return handler_suite + + def _source_handler_files_in_dir( + self, root_dir, scan_dir, allow_handlers_outside_root_dir): + """Source all the handler source files in the scan_dir directory. + + The resource path is determined relative to root_dir. + """ + + # We build a map from resource to handler code assuming that there's + # only one path from root_dir to scan_dir and it can be obtained by + # comparing realpath of them. + + # Here we cannot use abspath. See + # https://bugs.webkit.org/show_bug.cgi?id=31603 + + convert = _create_path_to_resource_converter(root_dir) + scan_realpath = os.path.realpath(scan_dir) + root_realpath = os.path.realpath(root_dir) + for path in _enumerate_handler_file_paths(scan_realpath): + if (not allow_handlers_outside_root_dir and + (not os.path.realpath(path).startswith(root_realpath))): + self._logger.debug( + 'Canonical path of %s is not under root directory' % + path) + continue + try: + handler_suite = _source_handler_file(open(path).read()) + except DispatchException, e: + self._source_warnings.append('%s: %s' % (path, e)) + continue + resource = convert(path) + if resource is None: + self._logger.debug( + 'Path to resource conversion on %s failed' % path) + else: + self._handler_suite_map[convert(path)] = handler_suite + + +# vi:sts=4 sw=4 et diff --git a/pyload/lib/mod_pywebsocket/extensions.py b/pyload/lib/mod_pywebsocket/extensions.py new file mode 100644 index 000000000..03dbf9ee1 --- /dev/null +++ b/pyload/lib/mod_pywebsocket/extensions.py @@ -0,0 +1,727 @@ +# Copyright 2012, Google Inc. +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: +# +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above +# copyright notice, this list of conditions and the following disclaimer +# in the documentation and/or other materials provided with the +# distribution. +# * Neither the name of Google Inc. nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
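
For reference, a minimal handler file for the Dispatcher above, modeled on
the project's echo example (hypothetical file; placed at
<root_dir>/echo_wsh.py it would be served as the resource /echo):

    # echo_wsh.py -- the two function names below are exactly the ones the
    # dispatcher extracts; web_socket_passive_closing_handshake is optional.

    def web_socket_do_extra_handshake(request):
        pass  # Accept every connection; raise an exception here to reject.

    def web_socket_transfer_data(request):
        while True:
            message = request.ws_stream.receive_message()
            if message is None:
                return
            # Echo text frames as text and binary frames as binary.
            request.ws_stream.send_message(
                message, binary=not isinstance(message, unicode))
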
+ + +from mod_pywebsocket import common +from mod_pywebsocket import util +from mod_pywebsocket.http_header_util import quote_if_necessary + + +_available_processors = {} + + +class ExtensionProcessorInterface(object): + + def name(self): + return None + + def get_extension_response(self): + return None + + def setup_stream_options(self, stream_options): + pass + + +class DeflateStreamExtensionProcessor(ExtensionProcessorInterface): + """WebSocket DEFLATE stream extension processor. + + Specification: + Section 9.2.1 in + http://tools.ietf.org/html/draft-ietf-hybi-thewebsocketprotocol-10 + """ + + def __init__(self, request): + self._logger = util.get_class_logger(self) + + self._request = request + + def name(self): + return common.DEFLATE_STREAM_EXTENSION + + def get_extension_response(self): + if len(self._request.get_parameter_names()) != 0: + return None + + self._logger.debug( + 'Enable %s extension', common.DEFLATE_STREAM_EXTENSION) + + return common.ExtensionParameter(common.DEFLATE_STREAM_EXTENSION) + + def setup_stream_options(self, stream_options): + stream_options.deflate_stream = True + + +_available_processors[common.DEFLATE_STREAM_EXTENSION] = ( + DeflateStreamExtensionProcessor) + + +def _log_compression_ratio(logger, original_bytes, total_original_bytes, + filtered_bytes, total_filtered_bytes): + # Print inf when ratio is not available. + ratio = float('inf') + average_ratio = float('inf') + if original_bytes != 0: + ratio = float(filtered_bytes) / original_bytes + if total_original_bytes != 0: + average_ratio = ( + float(total_filtered_bytes) / total_original_bytes) + logger.debug('Outgoing compress ratio: %f (average: %f)' % + (ratio, average_ratio)) + + +def _log_decompression_ratio(logger, received_bytes, total_received_bytes, + filtered_bytes, total_filtered_bytes): + # Print inf when ratio is not available. + ratio = float('inf') + average_ratio = float('inf') + if received_bytes != 0: + ratio = float(received_bytes) / filtered_bytes + if total_filtered_bytes != 0: + average_ratio = ( + float(total_received_bytes) / total_filtered_bytes) + logger.debug('Incoming compress ratio: %f (average: %f)' % + (ratio, average_ratio)) + + +class DeflateFrameExtensionProcessor(ExtensionProcessorInterface): + """WebSocket Per-frame DEFLATE extension processor. + + Specification: + http://tools.ietf.org/html/draft-tyoshino-hybi-websocket-perframe-deflate + """ + + _WINDOW_BITS_PARAM = 'max_window_bits' + _NO_CONTEXT_TAKEOVER_PARAM = 'no_context_takeover' + + def __init__(self, request): + self._logger = util.get_class_logger(self) + + self._request = request + + self._response_window_bits = None + self._response_no_context_takeover = False + self._bfinal = False + + # Counters for statistics. + + # Total number of outgoing bytes supplied to this filter. + self._total_outgoing_payload_bytes = 0 + # Total number of bytes sent to the network after applying this filter. + self._total_filtered_outgoing_payload_bytes = 0 + + # Total number of bytes received from the network. + self._total_incoming_payload_bytes = 0 + # Total number of incoming bytes obtained after applying this filter. + self._total_filtered_incoming_payload_bytes = 0 + + def name(self): + return common.DEFLATE_FRAME_EXTENSION + + def get_extension_response(self): + # Any unknown parameter will be just ignored. 
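+        # For example, a client offer such as
+        #     Sec-WebSocket-Extensions: deflate-frame; max_window_bits=10
+        # reaches this processor with 'max_window_bits' as the string '10';
+        # it is validated below to be an integer in the range [8, 15], and
+        # 'no_context_takeover', if present, must carry no value.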
+
+        window_bits = self._request.get_parameter_value(
+            self._WINDOW_BITS_PARAM)
+        no_context_takeover = self._request.has_parameter(
+            self._NO_CONTEXT_TAKEOVER_PARAM)
+        if (no_context_takeover and
+            self._request.get_parameter_value(
+                self._NO_CONTEXT_TAKEOVER_PARAM) is not None):
+            return None
+
+        if window_bits is not None:
+            try:
+                window_bits = int(window_bits)
+            except ValueError, e:
+                return None
+            if window_bits < 8 or window_bits > 15:
+                return None
+
+        self._deflater = util._RFC1979Deflater(
+            window_bits, no_context_takeover)
+
+        self._inflater = util._RFC1979Inflater()
+
+        self._compress_outgoing = True
+
+        response = common.ExtensionParameter(self._request.name())
+
+        if self._response_window_bits is not None:
+            response.add_parameter(
+                self._WINDOW_BITS_PARAM, str(self._response_window_bits))
+        if self._response_no_context_takeover:
+            response.add_parameter(
+                self._NO_CONTEXT_TAKEOVER_PARAM, None)
+
+        self._logger.debug(
+            'Enable %s extension ('
+            'request: window_bits=%s; no_context_takeover=%r, '
+            'response: window_bits=%s; no_context_takeover=%r)' %
+            (self._request.name(),
+             window_bits,
+             no_context_takeover,
+             self._response_window_bits,
+             self._response_no_context_takeover))
+
+        return response
+
+    def setup_stream_options(self, stream_options):
+
+        class _OutgoingFilter(object):
+
+            def __init__(self, parent):
+                self._parent = parent
+
+            def filter(self, frame):
+                self._parent._outgoing_filter(frame)
+
+        class _IncomingFilter(object):
+
+            def __init__(self, parent):
+                self._parent = parent
+
+            def filter(self, frame):
+                self._parent._incoming_filter(frame)
+
+        stream_options.outgoing_frame_filters.append(
+            _OutgoingFilter(self))
+        stream_options.incoming_frame_filters.insert(
+            0, _IncomingFilter(self))
+
+    def set_response_window_bits(self, value):
+        self._response_window_bits = value
+
+    def set_response_no_context_takeover(self, value):
+        self._response_no_context_takeover = value
+
+    def set_bfinal(self, value):
+        self._bfinal = value
+
+    def enable_outgoing_compression(self):
+        self._compress_outgoing = True
+
+    def disable_outgoing_compression(self):
+        self._compress_outgoing = False
+
+    def _outgoing_filter(self, frame):
+        """Transform outgoing frames. This method is called only by
+        an _OutgoingFilter instance.
+        """
+
+        original_payload_size = len(frame.payload)
+        self._total_outgoing_payload_bytes += original_payload_size
+
+        if (not self._compress_outgoing or
+            common.is_control_opcode(frame.opcode)):
+            self._total_filtered_outgoing_payload_bytes += (
+                original_payload_size)
+            return
+
+        frame.payload = self._deflater.filter(
+            frame.payload, bfinal=self._bfinal)
+        frame.rsv1 = 1
+
+        filtered_payload_size = len(frame.payload)
+        self._total_filtered_outgoing_payload_bytes += filtered_payload_size
+
+        _log_compression_ratio(self._logger, original_payload_size,
+                               self._total_outgoing_payload_bytes,
+                               filtered_payload_size,
+                               self._total_filtered_outgoing_payload_bytes)
+
+    def _incoming_filter(self, frame):
+        """Transform incoming frames. This method is called only by
+        an _IncomingFilter instance.
+        """
+
+        received_payload_size = len(frame.payload)
+        self._total_incoming_payload_bytes += received_payload_size
+
+        if frame.rsv1 != 1 or common.is_control_opcode(frame.opcode):
+            self._total_filtered_incoming_payload_bytes += (
+                received_payload_size)
+            return
+
+        frame.payload = self._inflater.filter(frame.payload)
+        frame.rsv1 = 0
+
+        filtered_payload_size = len(frame.payload)
+        self._total_filtered_incoming_payload_bytes += filtered_payload_size
+
+        _log_decompression_ratio(self._logger, received_payload_size,
+                                 self._total_incoming_payload_bytes,
+                                 filtered_payload_size,
+                                 self._total_filtered_incoming_payload_bytes)
+
+
+_available_processors[common.DEFLATE_FRAME_EXTENSION] = (
+    DeflateFrameExtensionProcessor)
+
+
+# Adding vendor-prefixed deflate-frame extension.
+# TODO(bashi): Remove this after WebKit stops using vendor prefix.
+_available_processors[common.X_WEBKIT_DEFLATE_FRAME_EXTENSION] = (
+    DeflateFrameExtensionProcessor)
+
+
+def _parse_compression_method(data):
+    """Parses the value of "method" extension parameter."""
+
+    return common.parse_extensions(data, allow_quoted_string=True)
+
+
+def _create_accepted_method_desc(method_name, method_params):
+    """Creates accepted-method-desc from given method name and parameters."""
+
+    extension = common.ExtensionParameter(method_name)
+    for name, value in method_params:
+        extension.add_parameter(name, value)
+    return common.format_extension(extension)
+
+
+class CompressionExtensionProcessorBase(ExtensionProcessorInterface):
+    """Base class for Per-frame and Per-message compression extension."""
+
+    _METHOD_PARAM = 'method'
+
+    def __init__(self, request):
+        self._logger = util.get_class_logger(self)
+        self._request = request
+        self._compression_method_name = None
+        self._compression_processor = None
+        self._compression_processor_hook = None
+
+    def name(self):
+        return ''
+
+    def _lookup_compression_processor(self, method_desc):
+        return None
+
+    def _get_compression_processor_response(self):
+        """Looks up the compression processor based on the self._request and
+        returns the compression processor's response.
+        """
+
+        method_list = self._request.get_parameter_value(self._METHOD_PARAM)
+        if method_list is None:
+            return None
+        methods = _parse_compression_method(method_list)
+        if methods is None:
+            return None
+        compression_processor = None
+        # The current implementation tries only the first method that matches
+        # a supported algorithm. The following methods aren't tried even if
+        # the first one is rejected.
+        # TODO(bashi): Need to clarify this behavior.
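+        # Here 'methods' is a list of ExtensionParameter objects parsed from
+        # the 'method' parameter value; e.g. method=deflate yields a single
+        # entry named 'deflate', and a quoted value such as
+        # method="deflate; s2c_max_window_bits=12" carries nested parameters.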
+ for method_desc in methods: + compression_processor = self._lookup_compression_processor( + method_desc) + if compression_processor is not None: + self._compression_method_name = method_desc.name() + break + if compression_processor is None: + return None + + if self._compression_processor_hook: + self._compression_processor_hook(compression_processor) + + processor_response = compression_processor.get_extension_response() + if processor_response is None: + return None + self._compression_processor = compression_processor + return processor_response + + def get_extension_response(self): + processor_response = self._get_compression_processor_response() + if processor_response is None: + return None + + response = common.ExtensionParameter(self._request.name()) + accepted_method_desc = _create_accepted_method_desc( + self._compression_method_name, + processor_response.get_parameters()) + response.add_parameter(self._METHOD_PARAM, accepted_method_desc) + self._logger.debug( + 'Enable %s extension (method: %s)' % + (self._request.name(), self._compression_method_name)) + return response + + def setup_stream_options(self, stream_options): + if self._compression_processor is None: + return + self._compression_processor.setup_stream_options(stream_options) + + def set_compression_processor_hook(self, hook): + self._compression_processor_hook = hook + + def get_compression_processor(self): + return self._compression_processor + + +class PerFrameCompressionExtensionProcessor(CompressionExtensionProcessorBase): + """WebSocket Per-frame compression extension processor. + + Specification: + http://tools.ietf.org/html/draft-ietf-hybi-websocket-perframe-compression + """ + + _DEFLATE_METHOD = 'deflate' + + def __init__(self, request): + CompressionExtensionProcessorBase.__init__(self, request) + + def name(self): + return common.PERFRAME_COMPRESSION_EXTENSION + + def _lookup_compression_processor(self, method_desc): + if method_desc.name() == self._DEFLATE_METHOD: + return DeflateFrameExtensionProcessor(method_desc) + return None + + +_available_processors[common.PERFRAME_COMPRESSION_EXTENSION] = ( + PerFrameCompressionExtensionProcessor) + + +class DeflateMessageProcessor(ExtensionProcessorInterface): + """Per-message deflate processor.""" + + _S2C_MAX_WINDOW_BITS_PARAM = 's2c_max_window_bits' + _S2C_NO_CONTEXT_TAKEOVER_PARAM = 's2c_no_context_takeover' + _C2S_MAX_WINDOW_BITS_PARAM = 'c2s_max_window_bits' + _C2S_NO_CONTEXT_TAKEOVER_PARAM = 'c2s_no_context_takeover' + + def __init__(self, request): + self._request = request + self._logger = util.get_class_logger(self) + + self._c2s_max_window_bits = None + self._c2s_no_context_takeover = False + self._bfinal = False + + self._compress_outgoing_enabled = False + + # True if a message is fragmented and compression is ongoing. + self._compress_ongoing = False + + # Counters for statistics. + + # Total number of outgoing bytes supplied to this filter. + self._total_outgoing_payload_bytes = 0 + # Total number of bytes sent to the network after applying this filter. + self._total_filtered_outgoing_payload_bytes = 0 + + # Total number of bytes received from the network. + self._total_incoming_payload_bytes = 0 + # Total number of incoming bytes obtained after applying this filter. + self._total_filtered_incoming_payload_bytes = 0 + + def name(self): + return 'deflate' + + def get_extension_response(self): + # Any unknown parameter will be just ignored. 
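+        # 's2c_*' parameters constrain this server-to-client deflater and are
+        # echoed in the response when honored; 'c2s_*' parameters are offered
+        # back to constrain the client-to-server direction. As validated
+        # below, s2c_max_window_bits must be an integer in the range [8, 15]
+        # and s2c_no_context_takeover must carry no value.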
+ + s2c_max_window_bits = self._request.get_parameter_value( + self._S2C_MAX_WINDOW_BITS_PARAM) + if s2c_max_window_bits is not None: + try: + s2c_max_window_bits = int(s2c_max_window_bits) + except ValueError, e: + return None + if s2c_max_window_bits < 8 or s2c_max_window_bits > 15: + return None + + s2c_no_context_takeover = self._request.has_parameter( + self._S2C_NO_CONTEXT_TAKEOVER_PARAM) + if (s2c_no_context_takeover and + self._request.get_parameter_value( + self._S2C_NO_CONTEXT_TAKEOVER_PARAM) is not None): + return None + + self._deflater = util._RFC1979Deflater( + s2c_max_window_bits, s2c_no_context_takeover) + + self._inflater = util._RFC1979Inflater() + + self._compress_outgoing_enabled = True + + response = common.ExtensionParameter(self._request.name()) + + if s2c_max_window_bits is not None: + response.add_parameter( + self._S2C_MAX_WINDOW_BITS_PARAM, str(s2c_max_window_bits)) + + if s2c_no_context_takeover: + response.add_parameter( + self._S2C_NO_CONTEXT_TAKEOVER_PARAM, None) + + if self._c2s_max_window_bits is not None: + response.add_parameter( + self._C2S_MAX_WINDOW_BITS_PARAM, + str(self._c2s_max_window_bits)) + if self._c2s_no_context_takeover: + response.add_parameter( + self._C2S_NO_CONTEXT_TAKEOVER_PARAM, None) + + self._logger.debug( + 'Enable %s extension (' + 'request: s2c_max_window_bits=%s; s2c_no_context_takeover=%r, ' + 'response: c2s_max_window_bits=%s; c2s_no_context_takeover=%r)' % + (self._request.name(), + s2c_max_window_bits, + s2c_no_context_takeover, + self._c2s_max_window_bits, + self._c2s_no_context_takeover)) + + return response + + def setup_stream_options(self, stream_options): + class _OutgoingMessageFilter(object): + + def __init__(self, parent): + self._parent = parent + + def filter(self, message, end=True, binary=False): + return self._parent._process_outgoing_message( + message, end, binary) + + class _IncomingMessageFilter(object): + + def __init__(self, parent): + self._parent = parent + self._decompress_next_message = False + + def decompress_next_message(self): + self._decompress_next_message = True + + def filter(self, message): + message = self._parent._process_incoming_message( + message, self._decompress_next_message) + self._decompress_next_message = False + return message + + self._outgoing_message_filter = _OutgoingMessageFilter(self) + self._incoming_message_filter = _IncomingMessageFilter(self) + stream_options.outgoing_message_filters.append( + self._outgoing_message_filter) + stream_options.incoming_message_filters.append( + self._incoming_message_filter) + + class _OutgoingFrameFilter(object): + + def __init__(self, parent): + self._parent = parent + self._set_compression_bit = False + + def set_compression_bit(self): + self._set_compression_bit = True + + def filter(self, frame): + self._parent._process_outgoing_frame( + frame, self._set_compression_bit) + self._set_compression_bit = False + + class _IncomingFrameFilter(object): + + def __init__(self, parent): + self._parent = parent + + def filter(self, frame): + self._parent._process_incoming_frame(frame) + + self._outgoing_frame_filter = _OutgoingFrameFilter(self) + self._incoming_frame_filter = _IncomingFrameFilter(self) + stream_options.outgoing_frame_filters.append( + self._outgoing_frame_filter) + stream_options.incoming_frame_filters.append( + self._incoming_frame_filter) + + stream_options.encode_text_message_to_utf8 = False + + def set_c2s_max_window_bits(self, value): + self._c2s_max_window_bits = value + + def set_c2s_no_context_takeover(self, value): + 
self._c2s_no_context_takeover = value + + def set_bfinal(self, value): + self._bfinal = value + + def enable_outgoing_compression(self): + self._compress_outgoing_enabled = True + + def disable_outgoing_compression(self): + self._compress_outgoing_enabled = False + + def _process_incoming_message(self, message, decompress): + if not decompress: + return message + + received_payload_size = len(message) + self._total_incoming_payload_bytes += received_payload_size + + message = self._inflater.filter(message) + + filtered_payload_size = len(message) + self._total_filtered_incoming_payload_bytes += filtered_payload_size + + _log_decompression_ratio(self._logger, received_payload_size, + self._total_incoming_payload_bytes, + filtered_payload_size, + self._total_filtered_incoming_payload_bytes) + + return message + + def _process_outgoing_message(self, message, end, binary): + if not binary: + message = message.encode('utf-8') + + if not self._compress_outgoing_enabled: + return message + + original_payload_size = len(message) + self._total_outgoing_payload_bytes += original_payload_size + + message = self._deflater.filter( + message, flush=end, bfinal=self._bfinal) + + filtered_payload_size = len(message) + self._total_filtered_outgoing_payload_bytes += filtered_payload_size + + _log_compression_ratio(self._logger, original_payload_size, + self._total_outgoing_payload_bytes, + filtered_payload_size, + self._total_filtered_outgoing_payload_bytes) + + if not self._compress_ongoing: + self._outgoing_frame_filter.set_compression_bit() + self._compress_ongoing = not end + return message + + def _process_incoming_frame(self, frame): + if frame.rsv1 == 1 and not common.is_control_opcode(frame.opcode): + self._incoming_message_filter.decompress_next_message() + frame.rsv1 = 0 + + def _process_outgoing_frame(self, frame, compression_bit): + if (not compression_bit or + common.is_control_opcode(frame.opcode)): + return + + frame.rsv1 = 1 + + +class PerMessageCompressionExtensionProcessor( + CompressionExtensionProcessorBase): + """WebSocket Per-message compression extension processor. + + Specification: + http://tools.ietf.org/html/draft-ietf-hybi-permessage-compression + """ + + _DEFLATE_METHOD = 'deflate' + + def __init__(self, request): + CompressionExtensionProcessorBase.__init__(self, request) + + def name(self): + return common.PERMESSAGE_COMPRESSION_EXTENSION + + def _lookup_compression_processor(self, method_desc): + if method_desc.name() == self._DEFLATE_METHOD: + return DeflateMessageProcessor(method_desc) + return None + + +_available_processors[common.PERMESSAGE_COMPRESSION_EXTENSION] = ( + PerMessageCompressionExtensionProcessor) + + +# Adding vendor-prefixed permessage-compress extension. +# TODO(bashi): Remove this after WebKit stops using vendor prefix. +_available_processors[common.X_WEBKIT_PERMESSAGE_COMPRESSION_EXTENSION] = ( + PerMessageCompressionExtensionProcessor) + + +class MuxExtensionProcessor(ExtensionProcessorInterface): + """WebSocket multiplexing extension processor.""" + + _QUOTA_PARAM = 'quota' + + def __init__(self, request): + self._request = request + + def name(self): + return common.MUX_EXTENSION + + def get_extension_response(self, ws_request, + logical_channel_extensions): + # Mux extension cannot be used after extensions that depend on + # frame boundary, extension data field, or any reserved bits + # which are attributed to each frame. 
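+        # The deflate-frame family, for example, rewrites frame payloads and
+        # sets RSV1 per frame, so it cannot be layered below mux and is
+        # rejected here.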
+ for extension in logical_channel_extensions: + name = extension.name() + if (name == common.PERFRAME_COMPRESSION_EXTENSION or + name == common.DEFLATE_FRAME_EXTENSION or + name == common.X_WEBKIT_DEFLATE_FRAME_EXTENSION): + return None + + quota = self._request.get_parameter_value(self._QUOTA_PARAM) + if quota is None: + ws_request.mux_quota = 0 + else: + try: + quota = int(quota) + except ValueError, e: + return None + if quota < 0 or quota >= 2 ** 32: + return None + ws_request.mux_quota = quota + + ws_request.mux = True + ws_request.mux_extensions = logical_channel_extensions + return common.ExtensionParameter(common.MUX_EXTENSION) + + def setup_stream_options(self, stream_options): + pass + + +_available_processors[common.MUX_EXTENSION] = MuxExtensionProcessor + + +def get_extension_processor(extension_request): + global _available_processors + processor_class = _available_processors.get(extension_request.name()) + if processor_class is None: + return None + return processor_class(extension_request) + + +# vi:sts=4 sw=4 et diff --git a/pyload/lib/mod_pywebsocket/handshake/__init__.py b/pyload/lib/mod_pywebsocket/handshake/__init__.py new file mode 100644 index 000000000..194f6b395 --- /dev/null +++ b/pyload/lib/mod_pywebsocket/handshake/__init__.py @@ -0,0 +1,110 @@ +# Copyright 2011, Google Inc. +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: +# +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above +# copyright notice, this list of conditions and the following disclaimer +# in the documentation and/or other materials provided with the +# distribution. +# * Neither the name of Google Inc. nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + + +"""WebSocket opening handshake processor. This class try to apply available +opening handshake processors for each protocol version until a connection is +successfully established. +""" + + +import logging + +from mod_pywebsocket import common +from mod_pywebsocket.handshake import hybi00 +from mod_pywebsocket.handshake import hybi +# Export AbortedByUserException, HandshakeException, and VersionException +# symbol from this module. 
+from mod_pywebsocket.handshake._base import AbortedByUserException +from mod_pywebsocket.handshake._base import HandshakeException +from mod_pywebsocket.handshake._base import VersionException + + +_LOGGER = logging.getLogger(__name__) + + +def do_handshake(request, dispatcher, allowDraft75=False, strict=False): + """Performs WebSocket handshake. + + Args: + request: mod_python request. + dispatcher: Dispatcher (dispatch.Dispatcher). + allowDraft75: obsolete argument. ignored. + strict: obsolete argument. ignored. + + Handshaker will add attributes such as ws_resource in performing + handshake. + """ + + _LOGGER.debug('Client\'s opening handshake resource: %r', request.uri) + # To print mimetools.Message as escaped one-line string, we converts + # headers_in to dict object. Without conversion, if we use %r, it just + # prints the type and address, and if we use %s, it prints the original + # header string as multiple lines. + # + # Both mimetools.Message and MpTable_Type of mod_python can be + # converted to dict. + # + # mimetools.Message.__str__ returns the original header string. + # dict(mimetools.Message object) returns the map from header names to + # header values. While MpTable_Type doesn't have such __str__ but just + # __repr__ which formats itself as well as dictionary object. + _LOGGER.debug( + 'Client\'s opening handshake headers: %r', dict(request.headers_in)) + + handshakers = [] + handshakers.append( + ('RFC 6455', hybi.Handshaker(request, dispatcher))) + handshakers.append( + ('HyBi 00', hybi00.Handshaker(request, dispatcher))) + + for name, handshaker in handshakers: + _LOGGER.debug('Trying protocol version %s', name) + try: + handshaker.do_handshake() + _LOGGER.info('Established (%s protocol)', name) + return + except HandshakeException, e: + _LOGGER.debug( + 'Failed to complete opening handshake as %s protocol: %r', + name, e) + if e.status: + raise e + except AbortedByUserException, e: + raise + except VersionException, e: + raise + + # TODO(toyoshim): Add a test to cover the case all handshakers fail. + raise HandshakeException( + 'Failed to complete opening handshake for all available protocols', + status=common.HTTP_STATUS_BAD_REQUEST) + + +# vi:sts=4 sw=4 et diff --git a/pyload/lib/mod_pywebsocket/handshake/_base.py b/pyload/lib/mod_pywebsocket/handshake/_base.py new file mode 100644 index 000000000..e5c94ca90 --- /dev/null +++ b/pyload/lib/mod_pywebsocket/handshake/_base.py @@ -0,0 +1,226 @@ +# Copyright 2012, Google Inc. +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: +# +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above +# copyright notice, this list of conditions and the following disclaimer +# in the documentation and/or other materials provided with the +# distribution. +# * Neither the name of Google Inc. nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +# A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT +# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + + +"""Common functions and exceptions used by WebSocket opening handshake +processors. +""" + + +from mod_pywebsocket import common +from mod_pywebsocket import http_header_util + + +class AbortedByUserException(Exception): + """Exception for aborting a connection intentionally. + + If this exception is raised in do_extra_handshake handler, the connection + will be abandoned. No other WebSocket or HTTP(S) handler will be invoked. + + If this exception is raised in transfer_data_handler, the connection will + be closed without closing handshake. No other WebSocket or HTTP(S) handler + will be invoked. + """ + + pass + + +class HandshakeException(Exception): + """This exception will be raised when an error occurred while processing + WebSocket initial handshake. + """ + + def __init__(self, name, status=None): + super(HandshakeException, self).__init__(name) + self.status = status + + +class VersionException(Exception): + """This exception will be raised when a version of client request does not + match with version the server supports. + """ + + def __init__(self, name, supported_versions=''): + """Construct an instance. + + Args: + supported_version: a str object to show supported hybi versions. + (e.g. '8, 13') + """ + super(VersionException, self).__init__(name) + self.supported_versions = supported_versions + + +def get_default_port(is_secure): + if is_secure: + return common.DEFAULT_WEB_SOCKET_SECURE_PORT + else: + return common.DEFAULT_WEB_SOCKET_PORT + + +def validate_subprotocol(subprotocol, hixie): + """Validate a value in the Sec-WebSocket-Protocol field. + + See + - RFC 6455: Section 4.1., 4.2.2., and 4.3. + - HyBi 00: Section 4.1. Opening handshake + + Args: + hixie: if True, checks if characters in subprotocol are in range + between U+0020 and U+007E. It's required by HyBi 00 but not by + RFC 6455. + """ + + if not subprotocol: + raise HandshakeException('Invalid subprotocol name: empty') + if hixie: + # Parameter should be in the range U+0020 to U+007E. + for c in subprotocol: + if not 0x20 <= ord(c) <= 0x7e: + raise HandshakeException( + 'Illegal character in subprotocol name: %r' % c) + else: + # Parameter should be encoded HTTP token. + state = http_header_util.ParsingState(subprotocol) + token = http_header_util.consume_token(state) + rest = http_header_util.peek(state) + # If |rest| is not None, |subprotocol| is not one token or invalid. If + # |rest| is None, |token| must not be None because |subprotocol| is + # concatenation of |token| and |rest| and is not None. 
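+        # For example, 'chat' parses as a single token and is accepted,
+        # while 'chat room' stops at the space, leaving |rest| non-None,
+        # and is rejected.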
+ if rest is not None: + raise HandshakeException('Invalid non-token string in subprotocol ' + 'name: %r' % rest) + + +def parse_host_header(request): + fields = request.headers_in['Host'].split(':', 1) + if len(fields) == 1: + return fields[0], get_default_port(request.is_https()) + try: + return fields[0], int(fields[1]) + except ValueError, e: + raise HandshakeException('Invalid port number format: %r' % e) + + +def format_header(name, value): + return '%s: %s\r\n' % (name, value) + + +def build_location(request): + """Build WebSocket location for request.""" + location_parts = [] + if request.is_https(): + location_parts.append(common.WEB_SOCKET_SECURE_SCHEME) + else: + location_parts.append(common.WEB_SOCKET_SCHEME) + location_parts.append('://') + host, port = parse_host_header(request) + connection_port = request.connection.local_addr[1] + if port != connection_port: + raise HandshakeException('Header/connection port mismatch: %d/%d' % + (port, connection_port)) + location_parts.append(host) + if (port != get_default_port(request.is_https())): + location_parts.append(':') + location_parts.append(str(port)) + location_parts.append(request.uri) + return ''.join(location_parts) + + +def get_mandatory_header(request, key): + value = request.headers_in.get(key) + if value is None: + raise HandshakeException('Header %s is not defined' % key) + return value + + +def validate_mandatory_header(request, key, expected_value, fail_status=None): + value = get_mandatory_header(request, key) + + if value.lower() != expected_value.lower(): + raise HandshakeException( + 'Expected %r for header %s but found %r (case-insensitive)' % + (expected_value, key, value), status=fail_status) + + +def check_request_line(request): + # 5.1 1. The three character UTF-8 string "GET". + # 5.1 2. A UTF-8-encoded U+0020 SPACE character (0x20 byte). + if request.method != 'GET': + raise HandshakeException('Method is not GET: %r' % request.method) + + if request.protocol != 'HTTP/1.1': + raise HandshakeException('Version is not HTTP/1.1: %r' % + request.protocol) + + +def check_header_lines(request, mandatory_headers): + check_request_line(request) + + # The expected field names, and the meaning of their corresponding + # values, are as follows. + # |Upgrade| and |Connection| + for key, expected_value in mandatory_headers: + validate_mandatory_header(request, key, expected_value) + + +def parse_token_list(data): + """Parses a header value which follows 1#token and returns parsed elements + as a list of strings. + + Leading LWSes must be trimmed. + """ + + state = http_header_util.ParsingState(data) + + token_list = [] + + while True: + token = http_header_util.consume_token(state) + if token is not None: + token_list.append(token) + + http_header_util.consume_lwses(state) + + if http_header_util.peek(state) is None: + break + + if not http_header_util.consume_string(state, ','): + raise HandshakeException( + 'Expected a comma but found %r' % http_header_util.peek(state)) + + http_header_util.consume_lwses(state) + + if len(token_list) == 0: + raise HandshakeException('No valid token found') + + return token_list + + +# vi:sts=4 sw=4 et diff --git a/pyload/lib/mod_pywebsocket/handshake/hybi.py b/pyload/lib/mod_pywebsocket/handshake/hybi.py new file mode 100644 index 000000000..fc0e2a096 --- /dev/null +++ b/pyload/lib/mod_pywebsocket/handshake/hybi.py @@ -0,0 +1,404 @@ +# Copyright 2012, Google Inc. +# All rights reserved. 
+# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: +# +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above +# copyright notice, this list of conditions and the following disclaimer +# in the documentation and/or other materials provided with the +# distribution. +# * Neither the name of Google Inc. nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + + +"""This file provides the opening handshake processor for the WebSocket +protocol (RFC 6455). + +Specification: +http://tools.ietf.org/html/rfc6455 +""" + + +# Note: request.connection.write is used in this module, even though mod_python +# document says that it should be used only in connection handlers. +# Unfortunately, we have no other options. For example, request.write is not +# suitable because it doesn't allow direct raw bytes writing. + + +import base64 +import logging +import os +import re + +from mod_pywebsocket import common +from mod_pywebsocket.extensions import get_extension_processor +from mod_pywebsocket.handshake._base import check_request_line +from mod_pywebsocket.handshake._base import format_header +from mod_pywebsocket.handshake._base import get_mandatory_header +from mod_pywebsocket.handshake._base import HandshakeException +from mod_pywebsocket.handshake._base import parse_token_list +from mod_pywebsocket.handshake._base import validate_mandatory_header +from mod_pywebsocket.handshake._base import validate_subprotocol +from mod_pywebsocket.handshake._base import VersionException +from mod_pywebsocket.stream import Stream +from mod_pywebsocket.stream import StreamOptions +from mod_pywebsocket import util + + +# Used to validate the value in the Sec-WebSocket-Key header strictly. RFC 4648 +# disallows non-zero padding, so the character right before == must be any of +# A, Q, g and w. +_SEC_WEBSOCKET_KEY_REGEX = re.compile('^[+/0-9A-Za-z]{21}[AQgw]==$') + +# Defining aliases for values used frequently. +_VERSION_HYBI08 = common.VERSION_HYBI08 +_VERSION_HYBI08_STRING = str(_VERSION_HYBI08) +_VERSION_LATEST = common.VERSION_HYBI_LATEST +_VERSION_LATEST_STRING = str(_VERSION_LATEST) +_SUPPORTED_VERSIONS = [ + _VERSION_LATEST, + _VERSION_HYBI08, +] + + +def compute_accept(key): + """Computes value for the Sec-WebSocket-Accept header from value of the + Sec-WebSocket-Key header. 
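+
+    For example, the key 'dGhlIHNhbXBsZSBub25jZQ==' from RFC 6455
+    section 1.3 maps to the accept value 's3pPLMBiTxaQ9kYGzzhZRbK+xOo='.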
+ """ + + accept_binary = util.sha1_hash( + key + common.WEBSOCKET_ACCEPT_UUID).digest() + accept = base64.b64encode(accept_binary) + + return (accept, accept_binary) + + +class Handshaker(object): + """Opening handshake processor for the WebSocket protocol (RFC 6455).""" + + def __init__(self, request, dispatcher): + """Construct an instance. + + Args: + request: mod_python request. + dispatcher: Dispatcher (dispatch.Dispatcher). + + Handshaker will add attributes such as ws_resource during handshake. + """ + + self._logger = util.get_class_logger(self) + + self._request = request + self._dispatcher = dispatcher + + def _validate_connection_header(self): + connection = get_mandatory_header( + self._request, common.CONNECTION_HEADER) + + try: + connection_tokens = parse_token_list(connection) + except HandshakeException, e: + raise HandshakeException( + 'Failed to parse %s: %s' % (common.CONNECTION_HEADER, e)) + + connection_is_valid = False + for token in connection_tokens: + if token.lower() == common.UPGRADE_CONNECTION_TYPE.lower(): + connection_is_valid = True + break + if not connection_is_valid: + raise HandshakeException( + '%s header doesn\'t contain "%s"' % + (common.CONNECTION_HEADER, common.UPGRADE_CONNECTION_TYPE)) + + def do_handshake(self): + self._request.ws_close_code = None + self._request.ws_close_reason = None + + # Parsing. + + check_request_line(self._request) + + validate_mandatory_header( + self._request, + common.UPGRADE_HEADER, + common.WEBSOCKET_UPGRADE_TYPE) + + self._validate_connection_header() + + self._request.ws_resource = self._request.uri + + unused_host = get_mandatory_header(self._request, common.HOST_HEADER) + + self._request.ws_version = self._check_version() + + # This handshake must be based on latest hybi. We are responsible to + # fallback to HTTP on handshake failure as latest hybi handshake + # specifies. + try: + self._get_origin() + self._set_protocol() + self._parse_extensions() + + # Key validation, response generation. + + key = self._get_key() + (accept, accept_binary) = compute_accept(key) + self._logger.debug( + '%s: %r (%s)', + common.SEC_WEBSOCKET_ACCEPT_HEADER, + accept, + util.hexify(accept_binary)) + + self._logger.debug('Protocol version is RFC 6455') + + # Setup extension processors. + + processors = [] + if self._request.ws_requested_extensions is not None: + for extension_request in self._request.ws_requested_extensions: + processor = get_extension_processor(extension_request) + # Unknown extension requests are just ignored. + if processor is not None: + processors.append(processor) + self._request.ws_extension_processors = processors + + # Extra handshake handler may modify/remove processors. + self._dispatcher.do_extra_handshake(self._request) + processors = filter(lambda processor: processor is not None, + self._request.ws_extension_processors) + + accepted_extensions = [] + + # We need to take care of mux extension here. Extensions that + # are placed before mux should be applied to logical channels. + mux_index = -1 + for i, processor in enumerate(processors): + if processor.name() == common.MUX_EXTENSION: + mux_index = i + break + if mux_index >= 0: + mux_processor = processors[mux_index] + logical_channel_processors = processors[:mux_index] + processors = processors[mux_index+1:] + + for processor in logical_channel_processors: + extension_response = processor.get_extension_response() + if extension_response is None: + # Rejected. 
+                        continue
+                    accepted_extensions.append(extension_response)
+                # Pass a shallow copy of accepted_extensions as extensions for
+                # logical channels.
+                mux_response = mux_processor.get_extension_response(
+                    self._request, accepted_extensions[:])
+                if mux_response is not None:
+                    accepted_extensions.append(mux_response)
+
+            stream_options = StreamOptions()
+
+            # When there is mux extension, here, |processors| contains only
+            # processors for extensions placed after mux.
+            for processor in processors:
+
+                extension_response = processor.get_extension_response()
+                if extension_response is None:
+                    # Rejected.
+                    continue
+
+                accepted_extensions.append(extension_response)
+
+                processor.setup_stream_options(stream_options)
+
+            if len(accepted_extensions) > 0:
+                self._request.ws_extensions = accepted_extensions
+                self._logger.debug(
+                    'Extensions accepted: %r',
+                    map(common.ExtensionParameter.name, accepted_extensions))
+            else:
+                self._request.ws_extensions = None
+
+            self._request.ws_stream = self._create_stream(stream_options)
+
+            if self._request.ws_requested_protocols is not None:
+                if self._request.ws_protocol is None:
+                    raise HandshakeException(
+                        'do_extra_handshake must choose one subprotocol from '
+                        'ws_requested_protocols and set it to ws_protocol')
+                validate_subprotocol(self._request.ws_protocol, hixie=False)
+
+                self._logger.debug(
+                    'Subprotocol accepted: %r',
+                    self._request.ws_protocol)
+            else:
+                if self._request.ws_protocol is not None:
+                    raise HandshakeException(
+                        'ws_protocol must be None when the client didn\'t '
+                        'request any subprotocol')
+
+            self._send_handshake(accept)
+        except HandshakeException, e:
+            if not e.status:
+                # Fallback to 400 bad request by default.
+                e.status = common.HTTP_STATUS_BAD_REQUEST
+            raise e
+
+    def _get_origin(self):
+        if self._request.ws_version is _VERSION_HYBI08:
+            origin_header = common.SEC_WEBSOCKET_ORIGIN_HEADER
+        else:
+            origin_header = common.ORIGIN_HEADER
+        origin = self._request.headers_in.get(origin_header)
+        if origin is None:
+            self._logger.debug('Client request does not have origin header')
+        self._request.ws_origin = origin
+
+    def _check_version(self):
+        version = get_mandatory_header(self._request,
+                                       common.SEC_WEBSOCKET_VERSION_HEADER)
+        if version == _VERSION_HYBI08_STRING:
+            return _VERSION_HYBI08
+        if version == _VERSION_LATEST_STRING:
+            return _VERSION_LATEST
+
+        if version.find(',') >= 0:
+            raise HandshakeException(
+                'Multiple versions (%r) are not allowed for header %s' %
+                (version, common.SEC_WEBSOCKET_VERSION_HEADER),
+                status=common.HTTP_STATUS_BAD_REQUEST)
+        raise VersionException(
+            'Unsupported version %r for header %s' %
+            (version, common.SEC_WEBSOCKET_VERSION_HEADER),
+            supported_versions=', '.join(map(str, _SUPPORTED_VERSIONS)))
+
+    def _set_protocol(self):
+        self._request.ws_protocol = None
+
+        protocol_header = self._request.headers_in.get(
+            common.SEC_WEBSOCKET_PROTOCOL_HEADER)
+
+        if protocol_header is None:
+            self._request.ws_requested_protocols = None
+            return
+
+        self._request.ws_requested_protocols = parse_token_list(
+            protocol_header)
+        self._logger.debug('Subprotocols requested: %r',
+                           self._request.ws_requested_protocols)
+
+    def _parse_extensions(self):
+        extensions_header = self._request.headers_in.get(
+            common.SEC_WEBSOCKET_EXTENSIONS_HEADER)
+        if not extensions_header:
+            self._request.ws_requested_extensions = None
+            return
+
+        if self._request.ws_version is common.VERSION_HYBI08:
+            allow_quoted_string = False
+        else:
+            allow_quoted_string = True
+        try:
+            self._request.ws_requested_extensions = 
common.parse_extensions( + extensions_header, allow_quoted_string=allow_quoted_string) + except common.ExtensionParsingException, e: + raise HandshakeException( + 'Failed to parse Sec-WebSocket-Extensions header: %r' % e) + + self._logger.debug( + 'Extensions requested: %r', + map(common.ExtensionParameter.name, + self._request.ws_requested_extensions)) + + def _validate_key(self, key): + if key.find(',') >= 0: + raise HandshakeException('Request has multiple %s header lines or ' + 'contains illegal character \',\': %r' % + (common.SEC_WEBSOCKET_KEY_HEADER, key)) + + # Validate + key_is_valid = False + try: + # Validate key by quick regex match before parsing by base64 + # module. Because base64 module skips invalid characters, we have + # to do this in advance to make this server strictly reject illegal + # keys. + if _SEC_WEBSOCKET_KEY_REGEX.match(key): + decoded_key = base64.b64decode(key) + if len(decoded_key) == 16: + key_is_valid = True + except TypeError, e: + pass + + if not key_is_valid: + raise HandshakeException( + 'Illegal value for header %s: %r' % + (common.SEC_WEBSOCKET_KEY_HEADER, key)) + + return decoded_key + + def _get_key(self): + key = get_mandatory_header( + self._request, common.SEC_WEBSOCKET_KEY_HEADER) + + decoded_key = self._validate_key(key) + + self._logger.debug( + '%s: %r (%s)', + common.SEC_WEBSOCKET_KEY_HEADER, + key, + util.hexify(decoded_key)) + + return key + + def _create_stream(self, stream_options): + return Stream(self._request, stream_options) + + def _create_handshake_response(self, accept): + response = [] + + response.append('HTTP/1.1 101 Switching Protocols\r\n') + + response.append(format_header( + common.UPGRADE_HEADER, common.WEBSOCKET_UPGRADE_TYPE)) + response.append(format_header( + common.CONNECTION_HEADER, common.UPGRADE_CONNECTION_TYPE)) + response.append(format_header( + common.SEC_WEBSOCKET_ACCEPT_HEADER, accept)) + if self._request.ws_protocol is not None: + response.append(format_header( + common.SEC_WEBSOCKET_PROTOCOL_HEADER, + self._request.ws_protocol)) + if (self._request.ws_extensions is not None and + len(self._request.ws_extensions) != 0): + response.append(format_header( + common.SEC_WEBSOCKET_EXTENSIONS_HEADER, + common.format_extensions(self._request.ws_extensions))) + response.append('\r\n') + + return ''.join(response) + + def _send_handshake(self, accept): + raw_response = self._create_handshake_response(accept) + self._request.connection.write(raw_response) + self._logger.debug('Sent server\'s opening handshake: %r', + raw_response) + + +# vi:sts=4 sw=4 et diff --git a/pyload/lib/mod_pywebsocket/handshake/hybi00.py b/pyload/lib/mod_pywebsocket/handshake/hybi00.py new file mode 100644 index 000000000..cc6f8dc43 --- /dev/null +++ b/pyload/lib/mod_pywebsocket/handshake/hybi00.py @@ -0,0 +1,242 @@ +# Copyright 2011, Google Inc. +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: +# +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above +# copyright notice, this list of conditions and the following disclaimer +# in the documentation and/or other materials provided with the +# distribution. +# * Neither the name of Google Inc. 
nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + + +"""This file provides the opening handshake processor for the WebSocket +protocol version HyBi 00. + +Specification: +http://tools.ietf.org/html/draft-ietf-hybi-thewebsocketprotocol-00 +""" + + +# Note: request.connection.write/read are used in this module, even though +# mod_python document says that they should be used only in connection +# handlers. Unfortunately, we have no other options. For example, +# request.write/read are not suitable because they don't allow direct raw bytes +# writing/reading. + + +import logging +import re +import struct + +from mod_pywebsocket import common +from mod_pywebsocket.stream import StreamHixie75 +from mod_pywebsocket import util +from mod_pywebsocket.handshake._base import HandshakeException +from mod_pywebsocket.handshake._base import build_location +from mod_pywebsocket.handshake._base import check_header_lines +from mod_pywebsocket.handshake._base import format_header +from mod_pywebsocket.handshake._base import get_mandatory_header +from mod_pywebsocket.handshake._base import validate_subprotocol + + +_MANDATORY_HEADERS = [ + # key, expected value or None + [common.UPGRADE_HEADER, common.WEBSOCKET_UPGRADE_TYPE_HIXIE75], + [common.CONNECTION_HEADER, common.UPGRADE_CONNECTION_TYPE], +] + + +class Handshaker(object): + """Opening handshake processor for the WebSocket protocol version HyBi 00. + """ + + def __init__(self, request, dispatcher): + """Construct an instance. + + Args: + request: mod_python request. + dispatcher: Dispatcher (dispatch.Dispatcher). + + Handshaker will add attributes such as ws_resource in performing + handshake. + """ + + self._logger = util.get_class_logger(self) + + self._request = request + self._dispatcher = dispatcher + + def do_handshake(self): + """Perform WebSocket Handshake. + + On _request, we set + ws_resource, ws_protocol, ws_location, ws_origin, ws_challenge, + ws_challenge_md5: WebSocket handshake information. + ws_stream: Frame generation/parsing class. + ws_version: Protocol version. + + Raises: + HandshakeException: when any error happened in parsing the opening + handshake request. + """ + + # 5.1 Reading the client's opening handshake. + # dispatcher sets it in self._request. 
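+        # For illustration of the key handling below (a hypothetical
+        # example): a client key value such as '12 3 4' reduces to
+        # key_number=1234 (digits only) and spaces=2, so part is
+        # 1234 / 2 = 617, which _get_challenge() packs big-endian into
+        # the challenge.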
+        check_header_lines(self._request, _MANDATORY_HEADERS)
+        self._set_resource()
+        self._set_subprotocol()
+        self._set_location()
+        self._set_origin()
+        self._set_challenge_response()
+        self._set_protocol_version()
+
+        self._dispatcher.do_extra_handshake(self._request)
+
+        self._send_handshake()
+
+    def _set_resource(self):
+        self._request.ws_resource = self._request.uri
+
+    def _set_subprotocol(self):
+        # |Sec-WebSocket-Protocol|
+        subprotocol = self._request.headers_in.get(
+            common.SEC_WEBSOCKET_PROTOCOL_HEADER)
+        if subprotocol is not None:
+            validate_subprotocol(subprotocol, hixie=True)
+            self._request.ws_protocol = subprotocol
+
+    def _set_location(self):
+        # |Host|
+        host = self._request.headers_in.get(common.HOST_HEADER)
+        if host is not None:
+            self._request.ws_location = build_location(self._request)
+        # TODO(ukai): check host is this host.
+
+    def _set_origin(self):
+        # |Origin|
+        origin = self._request.headers_in.get(common.ORIGIN_HEADER)
+        if origin is not None:
+            self._request.ws_origin = origin
+
+    def _set_protocol_version(self):
+        # |Sec-WebSocket-Draft|
+        draft = self._request.headers_in.get(common.SEC_WEBSOCKET_DRAFT_HEADER)
+        if draft is not None and draft != '0':
+            raise HandshakeException('Illegal value for %s: %s' %
+                                     (common.SEC_WEBSOCKET_DRAFT_HEADER,
+                                      draft))
+
+        self._logger.debug('Protocol version is HyBi 00')
+        self._request.ws_version = common.VERSION_HYBI00
+        self._request.ws_stream = StreamHixie75(self._request, True)
+
+    def _set_challenge_response(self):
+        # 5.2 4-8.
+        self._request.ws_challenge = self._get_challenge()
+        # 5.2 9. let /response/ be the MD5 fingerprint of /challenge/
+        self._request.ws_challenge_md5 = util.md5_hash(
+            self._request.ws_challenge).digest()
+        self._logger.debug(
+            'Challenge: %r (%s)',
+            self._request.ws_challenge,
+            util.hexify(self._request.ws_challenge))
+        self._logger.debug(
+            'Challenge response: %r (%s)',
+            self._request.ws_challenge_md5,
+            util.hexify(self._request.ws_challenge_md5))
+
+    def _get_key_value(self, key_field):
+        key_value = get_mandatory_header(self._request, key_field)
+
+        self._logger.debug('%s: %r', key_field, key_value)
+
+        # 5.2 4. let /key-number_n/ be the digits (characters in the range
+        # U+0030 DIGIT ZERO (0) to U+0039 DIGIT NINE (9)) in /key_n/,
+        # interpreted as a base ten integer, ignoring all other characters
+        # in /key_n/.
+        try:
+            key_number = int(re.sub("\\D", "", key_value))
+        except:
+            raise HandshakeException('%s field contains no digit' % key_field)
+        # 5.2 5. let /spaces_n/ be the number of U+0020 SPACE characters
+        # in /key_n/.
+        spaces = re.subn(" ", "", key_value)[1]
+        if spaces == 0:
+            raise HandshakeException('%s field contains no space' % key_field)
+
+        self._logger.debug(
+            '%s: Key-number is %d and number of spaces is %d',
+            key_field, key_number, spaces)
+
+        # 5.2 6. if /key-number_n/ is not an integral multiple of /spaces_n/
+        # then abort the WebSocket connection.
+        if key_number % spaces != 0:
+            raise HandshakeException(
+                '%s: Key-number (%d) is not an integral multiple of spaces '
+                '(%d)' % (key_field, key_number, spaces))
+        # 5.2 7. let /part_n/ be /key-number_n/ divided by /spaces_n/.
+        part = key_number / spaces
+        self._logger.debug('%s: Part is %d', key_field, part)
+        return part
+
+    def _get_challenge(self):
+        # 5.2 4-7.
+        key1 = self._get_key_value(common.SEC_WEBSOCKET_KEY1_HEADER)
+        key2 = self._get_key_value(common.SEC_WEBSOCKET_KEY2_HEADER)
+        # 5.2 8.
let /challenge/ be the concatenation of /part_1/, + challenge = '' + challenge += struct.pack('!I', key1) # network byteorder int + challenge += struct.pack('!I', key2) # network byteorder int + challenge += self._request.connection.read(8) + return challenge + + def _send_handshake(self): + response = [] + + # 5.2 10. send the following line. + response.append('HTTP/1.1 101 WebSocket Protocol Handshake\r\n') + + # 5.2 11. send the following fields to the client. + response.append(format_header( + common.UPGRADE_HEADER, common.WEBSOCKET_UPGRADE_TYPE_HIXIE75)) + response.append(format_header( + common.CONNECTION_HEADER, common.UPGRADE_CONNECTION_TYPE)) + response.append(format_header( + common.SEC_WEBSOCKET_LOCATION_HEADER, self._request.ws_location)) + response.append(format_header( + common.SEC_WEBSOCKET_ORIGIN_HEADER, self._request.ws_origin)) + if self._request.ws_protocol: + response.append(format_header( + common.SEC_WEBSOCKET_PROTOCOL_HEADER, + self._request.ws_protocol)) + # 5.2 12. send two bytes 0x0D 0x0A. + response.append('\r\n') + # 5.2 13. send /response/ + response.append(self._request.ws_challenge_md5) + + raw_response = ''.join(response) + self._request.connection.write(raw_response) + self._logger.debug('Sent server\'s opening handshake: %r', + raw_response) + + +# vi:sts=4 sw=4 et diff --git a/pyload/lib/mod_pywebsocket/headerparserhandler.py b/pyload/lib/mod_pywebsocket/headerparserhandler.py new file mode 100644 index 000000000..2cc62de04 --- /dev/null +++ b/pyload/lib/mod_pywebsocket/headerparserhandler.py @@ -0,0 +1,244 @@ +# Copyright 2011, Google Inc. +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: +# +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above +# copyright notice, this list of conditions and the following disclaimer +# in the documentation and/or other materials provided with the +# distribution. +# * Neither the name of Google Inc. nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + + +"""PythonHeaderParserHandler for mod_pywebsocket. + +Apache HTTP Server and mod_python must be configured such that this +function is called to handle WebSocket request. 
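+
+For example, a minimal configuration might look like the following
+sketch (the paths and handler root are placeholders, not part of this
+module):
+
+    PythonPath "sys.path + ['/path/to/pywebsocket/src']"
+    PythonOption mod_pywebsocket.handler_root /path/to/handlers
+    PythonHeaderParserHandler mod_pywebsocket.headerparserhandler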
+""" + + +import logging + +from mod_python import apache + +from mod_pywebsocket import common +from mod_pywebsocket import dispatch +from mod_pywebsocket import handshake +from mod_pywebsocket import util + + +# PythonOption to specify the handler root directory. +_PYOPT_HANDLER_ROOT = 'mod_pywebsocket.handler_root' + +# PythonOption to specify the handler scan directory. +# This must be a directory under the root directory. +# The default is the root directory. +_PYOPT_HANDLER_SCAN = 'mod_pywebsocket.handler_scan' + +# PythonOption to allow handlers whose canonical path is +# not under the root directory. It's disallowed by default. +# Set this option with value of 'yes' to allow. +_PYOPT_ALLOW_HANDLERS_OUTSIDE_ROOT = ( + 'mod_pywebsocket.allow_handlers_outside_root_dir') +# Map from values to their meanings. 'Yes' and 'No' are allowed just for +# compatibility. +_PYOPT_ALLOW_HANDLERS_OUTSIDE_ROOT_DEFINITION = { + 'off': False, 'no': False, 'on': True, 'yes': True} + +# (Obsolete option. Ignored.) +# PythonOption to specify to allow handshake defined in Hixie 75 version +# protocol. The default is None (Off) +_PYOPT_ALLOW_DRAFT75 = 'mod_pywebsocket.allow_draft75' +# Map from values to their meanings. +_PYOPT_ALLOW_DRAFT75_DEFINITION = {'off': False, 'on': True} + + +class ApacheLogHandler(logging.Handler): + """Wrapper logging.Handler to emit log message to apache's error.log.""" + + _LEVELS = { + logging.DEBUG: apache.APLOG_DEBUG, + logging.INFO: apache.APLOG_INFO, + logging.WARNING: apache.APLOG_WARNING, + logging.ERROR: apache.APLOG_ERR, + logging.CRITICAL: apache.APLOG_CRIT, + } + + def __init__(self, request=None): + logging.Handler.__init__(self) + self._log_error = apache.log_error + if request is not None: + self._log_error = request.log_error + + # Time and level will be printed by Apache. + self._formatter = logging.Formatter('%(name)s: %(message)s') + + def emit(self, record): + apache_level = apache.APLOG_DEBUG + if record.levelno in ApacheLogHandler._LEVELS: + apache_level = ApacheLogHandler._LEVELS[record.levelno] + + msg = self._formatter.format(record) + + # "server" parameter must be passed to have "level" parameter work. + # If only "level" parameter is passed, nothing shows up on Apache's + # log. However, at this point, we cannot get the server object of the + # virtual host which will process WebSocket requests. The only server + # object we can get here is apache.main_server. But Wherever (server + # configuration context or virtual host context) we put + # PythonHeaderParserHandler directive, apache.main_server just points + # the main server instance (not any of virtual server instance). Then, + # Apache follows LogLevel directive in the server configuration context + # to filter logs. So, we need to specify LogLevel in the server + # configuration context. Even if we specify "LogLevel debug" in the + # virtual host context which actually handles WebSocket connections, + # DEBUG level logs never show up unless "LogLevel debug" is specified + # in the server configuration context. + # + # TODO(tyoshino): Provide logging methods on request object. When + # request is mp_request object (when used together with Apache), the + # methods call request.log_error indirectly. When request is + # _StandaloneRequest, the methods call Python's logging facility which + # we create in standalone.py. 
+ self._log_error(msg, apache_level, apache.main_server) + + +def _configure_logging(): + logger = logging.getLogger() + # Logs are filtered by Apache based on LogLevel directive in Apache + # configuration file. We must just pass logs for all levels to + # ApacheLogHandler. + logger.setLevel(logging.DEBUG) + logger.addHandler(ApacheLogHandler()) + + +_configure_logging() + +_LOGGER = logging.getLogger(__name__) + + +def _parse_option(name, value, definition): + if value is None: + return False + + meaning = definition.get(value.lower()) + if meaning is None: + raise Exception('Invalid value for PythonOption %s: %r' % + (name, value)) + return meaning + + +def _create_dispatcher(): + _LOGGER.info('Initializing Dispatcher') + + options = apache.main_server.get_options() + + handler_root = options.get(_PYOPT_HANDLER_ROOT, None) + if not handler_root: + raise Exception('PythonOption %s is not defined' % _PYOPT_HANDLER_ROOT, + apache.APLOG_ERR) + + handler_scan = options.get(_PYOPT_HANDLER_SCAN, handler_root) + + allow_handlers_outside_root = _parse_option( + _PYOPT_ALLOW_HANDLERS_OUTSIDE_ROOT, + options.get(_PYOPT_ALLOW_HANDLERS_OUTSIDE_ROOT), + _PYOPT_ALLOW_HANDLERS_OUTSIDE_ROOT_DEFINITION) + + dispatcher = dispatch.Dispatcher( + handler_root, handler_scan, allow_handlers_outside_root) + + for warning in dispatcher.source_warnings(): + apache.log_error('mod_pywebsocket: %s' % warning, apache.APLOG_WARNING) + + return dispatcher + + +# Initialize +_dispatcher = _create_dispatcher() + + +def headerparserhandler(request): + """Handle request. + + Args: + request: mod_python request. + + This function is named headerparserhandler because it is the default + name for a PythonHeaderParserHandler. + """ + + handshake_is_done = False + try: + # Fallback to default http handler for request paths for which + # we don't have request handlers. + if not _dispatcher.get_handler_suite(request.uri): + request.log_error('No handler for resource: %r' % request.uri, + apache.APLOG_INFO) + request.log_error('Fallback to Apache', apache.APLOG_INFO) + return apache.DECLINED + except dispatch.DispatchException, e: + request.log_error('mod_pywebsocket: %s' % e, apache.APLOG_INFO) + if not handshake_is_done: + return e.status + + try: + allow_draft75 = _parse_option( + _PYOPT_ALLOW_DRAFT75, + apache.main_server.get_options().get(_PYOPT_ALLOW_DRAFT75), + _PYOPT_ALLOW_DRAFT75_DEFINITION) + + try: + handshake.do_handshake( + request, _dispatcher, allowDraft75=allow_draft75) + except handshake.VersionException, e: + request.log_error('mod_pywebsocket: %s' % e, apache.APLOG_INFO) + request.err_headers_out.add(common.SEC_WEBSOCKET_VERSION_HEADER, + e.supported_versions) + return apache.HTTP_BAD_REQUEST + except handshake.HandshakeException, e: + # Handshake for ws/wss failed. + # Send http response with error status. + request.log_error('mod_pywebsocket: %s' % e, apache.APLOG_INFO) + return e.status + + handshake_is_done = True + request._dispatcher = _dispatcher + _dispatcher.transfer_data(request) + except handshake.AbortedByUserException, e: + request.log_error('mod_pywebsocket: %s' % e, apache.APLOG_INFO) + except Exception, e: + # DispatchException can also be thrown if something is wrong in + # pywebsocket code. It's caught here, then. + + request.log_error('mod_pywebsocket: %s\n%s' % + (e, util.get_stack_trace()), + apache.APLOG_ERR) + # Unknown exceptions before handshake mean Apache must handle its + # request with another handler. 
+ if not handshake_is_done: + return apache.DECLINED + # Set assbackwards to suppress response header generation by Apache. + request.assbackwards = 1 + return apache.DONE # Return DONE such that no other handlers are invoked. + + +# vi:sts=4 sw=4 et diff --git a/pyload/lib/mod_pywebsocket/http_header_util.py b/pyload/lib/mod_pywebsocket/http_header_util.py new file mode 100644 index 000000000..b77465393 --- /dev/null +++ b/pyload/lib/mod_pywebsocket/http_header_util.py @@ -0,0 +1,263 @@ +# Copyright 2011, Google Inc. +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: +# +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above +# copyright notice, this list of conditions and the following disclaimer +# in the documentation and/or other materials provided with the +# distribution. +# * Neither the name of Google Inc. nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + + +"""Utilities for parsing and formatting headers that follow the grammar defined +in HTTP RFC http://www.ietf.org/rfc/rfc2616.txt. +""" + + +import urlparse + + +_SEPARATORS = '()<>@,;:\\"/[]?={} \t' + + +def _is_char(c): + """Returns true iff c is in CHAR as specified in HTTP RFC.""" + + return ord(c) <= 127 + + +def _is_ctl(c): + """Returns true iff c is in CTL as specified in HTTP RFC.""" + + return ord(c) <= 31 or ord(c) == 127 + + +class ParsingState(object): + + def __init__(self, data): + self.data = data + self.head = 0 + + +def peek(state, pos=0): + """Peeks the character at pos from the head of data.""" + + if state.head + pos >= len(state.data): + return None + + return state.data[state.head + pos] + + +def consume(state, amount=1): + """Consumes specified amount of bytes from the head and returns the + consumed bytes. If there's not enough bytes to consume, returns None. + """ + + if state.head + amount > len(state.data): + return None + + result = state.data[state.head:state.head + amount] + state.head = state.head + amount + return result + + +def consume_string(state, expected): + """Given a parsing state and a expected string, consumes the string from + the head. Returns True if consumed successfully. Otherwise, returns + False. + """ + + pos = 0 + + for c in expected: + if c != peek(state, pos): + return False + pos += 1 + + consume(state, pos) + return True + + +def consume_lws(state): + """Consumes a LWS from the head. Returns True if any LWS is consumed. 
+ Otherwise, returns False. + + LWS = [CRLF] 1*( SP | HT ) + """ + + original_head = state.head + + consume_string(state, '\r\n') + + pos = 0 + + while True: + c = peek(state, pos) + if c == ' ' or c == '\t': + pos += 1 + else: + if pos == 0: + state.head = original_head + return False + else: + consume(state, pos) + return True + + +def consume_lwses(state): + """Consumes *LWS from the head.""" + + while consume_lws(state): + pass + + +def consume_token(state): + """Consumes a token from the head. Returns the token or None if no token + was found. + """ + + pos = 0 + + while True: + c = peek(state, pos) + if c is None or c in _SEPARATORS or _is_ctl(c) or not _is_char(c): + if pos == 0: + return None + + return consume(state, pos) + else: + pos += 1 + + +def consume_token_or_quoted_string(state): + """Consumes a token or a quoted-string, and returns the token or unquoted + string. If no token or quoted-string was found, returns None. + """ + + original_head = state.head + + if not consume_string(state, '"'): + return consume_token(state) + + result = [] + + expect_quoted_pair = False + + while True: + if not expect_quoted_pair and consume_lws(state): + result.append(' ') + continue + + c = consume(state) + if c is None: + # quoted-string is not enclosed with double quotation + state.head = original_head + return None + elif expect_quoted_pair: + expect_quoted_pair = False + if _is_char(c): + result.append(c) + else: + # Non CHAR character found in quoted-pair + state.head = original_head + return None + elif c == '\\': + expect_quoted_pair = True + elif c == '"': + return ''.join(result) + elif _is_ctl(c): + # Invalid character %r found in qdtext + state.head = original_head + return None + else: + result.append(c) + + +def quote_if_necessary(s): + """Quotes arbitrary string into quoted-string.""" + + quote = False + if s == '': + return '""' + + result = [] + for c in s: + if c == '"' or c in _SEPARATORS or _is_ctl(c) or not _is_char(c): + quote = True + + if c == '"' or _is_ctl(c): + result.append('\\' + c) + else: + result.append(c) + + if quote: + return '"' + ''.join(result) + '"' + else: + return ''.join(result) + + +def parse_uri(uri): + """Parse absolute URI then return host, port and resource.""" + + parsed = urlparse.urlsplit(uri) + if parsed.scheme != 'wss' and parsed.scheme != 'ws': + # |uri| must be a relative URI. + # TODO(toyoshim): Should validate |uri|. + return None, None, uri + + if parsed.hostname is None: + return None, None, None + + port = None + try: + port = parsed.port + except ValueError, e: + # port property cause ValueError on invalid null port description like + # 'ws://host:/path'. + return None, None, None + + if port is None: + if parsed.scheme == 'ws': + port = 80 + else: + port = 443 + + path = parsed.path + if not path: + path += '/' + if parsed.query: + path += '?' + parsed.query + if parsed.fragment: + path += '#' + parsed.fragment + + return parsed.hostname, port, path + + +try: + urlparse.uses_netloc.index('ws') +except ValueError, e: + # urlparse in Python2.5.1 doesn't have 'ws' and 'wss' entries. + urlparse.uses_netloc.append('ws') + urlparse.uses_netloc.append('wss') + + +# vi:sts=4 sw=4 et diff --git a/pyload/lib/mod_pywebsocket/memorizingfile.py b/pyload/lib/mod_pywebsocket/memorizingfile.py new file mode 100644 index 000000000..4d4cd9585 --- /dev/null +++ b/pyload/lib/mod_pywebsocket/memorizingfile.py @@ -0,0 +1,99 @@ +#!/usr/bin/env python +# +# Copyright 2011, Google Inc. +# All rights reserved. 
+# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: +# +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above +# copyright notice, this list of conditions and the following disclaimer +# in the documentation and/or other materials provided with the +# distribution. +# * Neither the name of Google Inc. nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + + +"""Memorizing file. + +A memorizing file wraps a file and memorizes lines read by readline. +""" + + +import sys + + +class MemorizingFile(object): + """MemorizingFile wraps a file and memorizes lines read by readline. + + Note that data read by other methods are not memorized. This behavior + is good enough for memorizing lines SimpleHTTPServer reads before + the control reaches WebSocketRequestHandler. + """ + + def __init__(self, file_, max_memorized_lines=sys.maxint): + """Construct an instance. + + Args: + file_: the file object to wrap. + max_memorized_lines: the maximum number of lines to memorize. + Only the first max_memorized_lines are memorized. + Default: sys.maxint. + """ + + self._file = file_ + self._memorized_lines = [] + self._max_memorized_lines = max_memorized_lines + self._buffered = False + self._buffered_line = None + + def __getattribute__(self, name): + if name in ('_file', '_memorized_lines', '_max_memorized_lines', + '_buffered', '_buffered_line', 'readline', + 'get_memorized_lines'): + return object.__getattribute__(self, name) + return self._file.__getattribute__(name) + + def readline(self, size=-1): + """Override file.readline and memorize the line read. + + Note that even if size is specified and smaller than actual size, + the whole line will be read out from underlying file object by + subsequent readline calls. 
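+
+        A sketch of the buffering behavior (assumes StringIO from the
+        standard library):
+
+            from StringIO import StringIO
+            f = MemorizingFile(StringIO('hello\nworld\n'))
+            f.readline(2)            # -> 'he'; memorizes the whole line
+            f.readline()             # -> 'llo\n' from the internal buffer
+            f.get_memorized_lines()  # -> ['hello\n']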
+ """ + + if self._buffered: + line = self._buffered_line + self._buffered = False + else: + line = self._file.readline() + if line and len(self._memorized_lines) < self._max_memorized_lines: + self._memorized_lines.append(line) + if size >= 0 and size < len(line): + self._buffered = True + self._buffered_line = line[size:] + return line[:size] + return line + + def get_memorized_lines(self): + """Get lines memorized so far.""" + return self._memorized_lines + + +# vi:sts=4 sw=4 et diff --git a/pyload/lib/mod_pywebsocket/msgutil.py b/pyload/lib/mod_pywebsocket/msgutil.py new file mode 100644 index 000000000..4c1a0114b --- /dev/null +++ b/pyload/lib/mod_pywebsocket/msgutil.py @@ -0,0 +1,219 @@ +# Copyright 2011, Google Inc. +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: +# +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above +# copyright notice, this list of conditions and the following disclaimer +# in the documentation and/or other materials provided with the +# distribution. +# * Neither the name of Google Inc. nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + + +"""Message related utilities. + +Note: request.connection.write/read are used in this module, even though +mod_python document says that they should be used only in connection +handlers. Unfortunately, we have no other options. For example, +request.write/read are not suitable because they don't allow direct raw +bytes writing/reading. +""" + + +import Queue +import threading + + +# Export Exception symbols from msgutil for backward compatibility +from mod_pywebsocket._stream_base import ConnectionTerminatedException +from mod_pywebsocket._stream_base import InvalidFrameException +from mod_pywebsocket._stream_base import BadOperationException +from mod_pywebsocket._stream_base import UnsupportedFrameException + + +# An API for handler to send/receive WebSocket messages. +def close_connection(request): + """Close connection. + + Args: + request: mod_python request. + """ + request.ws_stream.close_connection() + + +def send_message(request, payload_data, end=True, binary=False): + """Send a message (or part of a message). + + Args: + request: mod_python request. + payload_data: unicode text or str binary to send. + end: True to terminate a message. + False to send payload_data as part of a message that is to be + terminated by next or later send_message call with end=True. 
+        binary: send payload_data as binary frame(s).
+    Raises:
+        BadOperationException: when the server has already terminated.
+    """
+    request.ws_stream.send_message(payload_data, end, binary)
+
+
+def receive_message(request):
+    """Receive a WebSocket frame and return its payload as unicode text
+    or a binary str.
+
+    Args:
+        request: mod_python request.
+    Raises:
+        InvalidFrameException: when the client sends an invalid frame.
+        UnsupportedFrameException: when the client sends an unsupported frame,
+                                   e.g. a reserved bit is set but no extension
+                                   can recognize it.
+        InvalidUTF8Exception: when the client sends a text frame containing an
+                              invalid UTF-8 string.
+        ConnectionTerminatedException: when the connection is closed
+                                       unexpectedly.
+        BadOperationException: when the client has already terminated.
+    """
+    return request.ws_stream.receive_message()
+
+
+def send_ping(request, body=''):
+    request.ws_stream.send_ping(body)
+
+
+class MessageReceiver(threading.Thread):
+    """This class receives messages from the client.
+
+    This class provides three ways to receive messages: blocking,
+    non-blocking, and via callback. Callback has the highest precedence.
+
+    Note: This class should not be used with the standalone server for wss
+    because pyOpenSSL used by the server raises a fatal error if the socket
+    is accessed from multiple threads.
+    """
+
+    def __init__(self, request, onmessage=None):
+        """Construct an instance.
+
+        Args:
+            request: mod_python request.
+            onmessage: a function to be called when a message is received.
+                       May be None. If not None, the function is called on
+                       another thread. In that case, MessageReceiver.receive
+                       and MessageReceiver.receive_nowait are useless
+                       because they will never return any messages.
+        """
+
+        threading.Thread.__init__(self)
+        self._request = request
+        self._queue = Queue.Queue()
+        self._onmessage = onmessage
+        self._stop_requested = False
+        self.setDaemon(True)
+        self.start()
+
+    def run(self):
+        try:
+            while not self._stop_requested:
+                message = receive_message(self._request)
+                if self._onmessage:
+                    self._onmessage(message)
+                else:
+                    self._queue.put(message)
+        finally:
+            close_connection(self._request)
+
+    def receive(self):
+        """Receive a message from the channel, blocking.
+
+        Returns:
+            message as a unicode string.
+        """
+        return self._queue.get()
+
+    def receive_nowait(self):
+        """Receive a message from the channel, non-blocking.
+
+        Returns:
+            message as a unicode string if available. None otherwise.
+        """
+        try:
+            message = self._queue.get_nowait()
+        except Queue.Empty:
+            message = None
+        return message
+
+    def stop(self):
+        """Request to stop this instance.
+
+        The instance will be stopped after receiving the next message.
+        This method may not be very useful, but there is no clean way
+        in Python to forcefully stop a running thread.
+        """
+        self._stop_requested = True
+
+
+class MessageSender(threading.Thread):
+    """This class sends messages to the client.
+
+    This class provides both synchronous and asynchronous ways to send
+    messages.
+
+    Note: This class should not be used with the standalone server for wss
+    because pyOpenSSL used by the server raises a fatal error if the socket
+    is accessed from multiple threads.
+    """
+
+    def __init__(self, request):
+        """Construct an instance.
+
+        Args:
+            request: mod_python request.
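+
+        A usage sketch, from inside a handler that has a request object:
+
+            sender = MessageSender(request)
+            sender.send(u'hello')         # blocks until the message is sent
+            sender.send_nowait(u'world')  # queues the message and returns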
+ """ + threading.Thread.__init__(self) + self._request = request + self._queue = Queue.Queue() + self.setDaemon(True) + self.start() + + def run(self): + while True: + message, condition = self._queue.get() + condition.acquire() + send_message(self._request, message) + condition.notify() + condition.release() + + def send(self, message): + """Send a message, blocking.""" + + condition = threading.Condition() + condition.acquire() + self._queue.put((message, condition)) + condition.wait() + + def send_nowait(self, message): + """Send a message, non-blocking.""" + + self._queue.put((message, threading.Condition())) + + +# vi:sts=4 sw=4 et diff --git a/pyload/lib/mod_pywebsocket/mux.py b/pyload/lib/mod_pywebsocket/mux.py new file mode 100644 index 000000000..f0bdd2461 --- /dev/null +++ b/pyload/lib/mod_pywebsocket/mux.py @@ -0,0 +1,1636 @@ +# Copyright 2012, Google Inc. +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: +# +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above +# copyright notice, this list of conditions and the following disclaimer +# in the documentation and/or other materials provided with the +# distribution. +# * Neither the name of Google Inc. nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + + +"""This file provides classes and helper functions for multiplexing extension. 
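+
+Logical channels share a single physical WebSocket connection. Each
+encapsulated frame is prefixed with a variable-length channel id of one
+to four bytes (see _encode_channel_id below), for example:
+
+    _encode_channel_id(5)       # -> '\x05'            (1 byte)
+    _encode_channel_id(1000)    # -> '\x83\xe8'        (2 bytes)
+    _encode_channel_id(100000)  # -> '\xc1\x86\xa0'    (3 bytes)
+
+Flow control is per logical channel: a sender consumes send quota that
+the peer replenishes with FlowControl blocks.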
+ +Specification: +http://tools.ietf.org/html/draft-ietf-hybi-websocket-multiplexing-06 +""" + + +import collections +import copy +import email +import email.parser +import logging +import math +import struct +import threading +import traceback + +from mod_pywebsocket import common +from mod_pywebsocket import handshake +from mod_pywebsocket import util +from mod_pywebsocket._stream_base import BadOperationException +from mod_pywebsocket._stream_base import ConnectionTerminatedException +from mod_pywebsocket._stream_hybi import Frame +from mod_pywebsocket._stream_hybi import Stream +from mod_pywebsocket._stream_hybi import StreamOptions +from mod_pywebsocket._stream_hybi import create_binary_frame +from mod_pywebsocket._stream_hybi import create_closing_handshake_body +from mod_pywebsocket._stream_hybi import create_header +from mod_pywebsocket._stream_hybi import create_length_header +from mod_pywebsocket._stream_hybi import parse_frame +from mod_pywebsocket.handshake import hybi + + +_CONTROL_CHANNEL_ID = 0 +_DEFAULT_CHANNEL_ID = 1 + +_MUX_OPCODE_ADD_CHANNEL_REQUEST = 0 +_MUX_OPCODE_ADD_CHANNEL_RESPONSE = 1 +_MUX_OPCODE_FLOW_CONTROL = 2 +_MUX_OPCODE_DROP_CHANNEL = 3 +_MUX_OPCODE_NEW_CHANNEL_SLOT = 4 + +_MAX_CHANNEL_ID = 2 ** 29 - 1 + +_INITIAL_NUMBER_OF_CHANNEL_SLOTS = 64 +_INITIAL_QUOTA_FOR_CLIENT = 8 * 1024 + +_HANDSHAKE_ENCODING_IDENTITY = 0 +_HANDSHAKE_ENCODING_DELTA = 1 + +# We need only these status code for now. +_HTTP_BAD_RESPONSE_MESSAGES = { + common.HTTP_STATUS_BAD_REQUEST: 'Bad Request', +} + +# DropChannel reason code +# TODO(bashi): Define all reason code defined in -05 draft. +_DROP_CODE_NORMAL_CLOSURE = 1000 + +_DROP_CODE_INVALID_ENCAPSULATING_MESSAGE = 2001 +_DROP_CODE_CHANNEL_ID_TRUNCATED = 2002 +_DROP_CODE_ENCAPSULATED_FRAME_IS_TRUNCATED = 2003 +_DROP_CODE_UNKNOWN_MUX_OPCODE = 2004 +_DROP_CODE_INVALID_MUX_CONTROL_BLOCK = 2005 +_DROP_CODE_CHANNEL_ALREADY_EXISTS = 2006 +_DROP_CODE_NEW_CHANNEL_SLOT_VIOLATION = 2007 + +_DROP_CODE_UNKNOWN_REQUEST_ENCODING = 3002 +_DROP_CODE_SEND_QUOTA_VIOLATION = 3005 +_DROP_CODE_ACKNOWLEDGED = 3008 + + +class MuxUnexpectedException(Exception): + """Exception in handling multiplexing extension.""" + pass + + +# Temporary +class MuxNotImplementedException(Exception): + """Raised when a flow enters unimplemented code path.""" + pass + + +class LogicalConnectionClosedException(Exception): + """Raised when logical connection is gracefully closed.""" + pass + + +class PhysicalConnectionError(Exception): + """Raised when there is a physical connection error.""" + def __init__(self, drop_code, message=''): + super(PhysicalConnectionError, self).__init__( + 'code=%d, message=%r' % (drop_code, message)) + self.drop_code = drop_code + self.message = message + + +class LogicalChannelError(Exception): + """Raised when there is a logical channel error.""" + def __init__(self, channel_id, drop_code, message=''): + super(LogicalChannelError, self).__init__( + 'channel_id=%d, code=%d, message=%r' % ( + channel_id, drop_code, message)) + self.channel_id = channel_id + self.drop_code = drop_code + self.message = message + + +def _encode_channel_id(channel_id): + if channel_id < 0: + raise ValueError('Channel id %d must not be negative' % channel_id) + + if channel_id < 2 ** 7: + return chr(channel_id) + if channel_id < 2 ** 14: + return struct.pack('!H', 0x8000 + channel_id) + if channel_id < 2 ** 21: + first = chr(0xc0 + (channel_id >> 16)) + return first + struct.pack('!H', channel_id & 0xffff) + if channel_id < 2 ** 29: + return struct.pack('!L', 
0xe0000000 + channel_id) + + raise ValueError('Channel id %d is too large' % channel_id) + + +def _encode_number(number): + return create_length_header(number, False) + + +def _create_add_channel_response(channel_id, encoded_handshake, + encoding=0, rejected=False, + outer_frame_mask=False): + if encoding != 0 and encoding != 1: + raise ValueError('Invalid encoding %d' % encoding) + + first_byte = ((_MUX_OPCODE_ADD_CHANNEL_RESPONSE << 5) | + (rejected << 4) | encoding) + block = (chr(first_byte) + + _encode_channel_id(channel_id) + + _encode_number(len(encoded_handshake)) + + encoded_handshake) + payload = _encode_channel_id(_CONTROL_CHANNEL_ID) + block + return create_binary_frame(payload, mask=outer_frame_mask) + + +def _create_drop_channel(channel_id, code=None, message='', + outer_frame_mask=False): + if len(message) > 0 and code is None: + raise ValueError('Code must be specified if message is specified') + + first_byte = _MUX_OPCODE_DROP_CHANNEL << 5 + block = chr(first_byte) + _encode_channel_id(channel_id) + if code is None: + block += _encode_number(0) # Reason size + else: + reason = struct.pack('!H', code) + message + reason_size = _encode_number(len(reason)) + block += reason_size + reason + + payload = _encode_channel_id(_CONTROL_CHANNEL_ID) + block + return create_binary_frame(payload, mask=outer_frame_mask) + + +def _create_flow_control(channel_id, replenished_quota, + outer_frame_mask=False): + first_byte = _MUX_OPCODE_FLOW_CONTROL << 5 + block = (chr(first_byte) + + _encode_channel_id(channel_id) + + _encode_number(replenished_quota)) + payload = _encode_channel_id(_CONTROL_CHANNEL_ID) + block + return create_binary_frame(payload, mask=outer_frame_mask) + + +def _create_new_channel_slot(slots, send_quota, outer_frame_mask=False): + if slots < 0 or send_quota < 0: + raise ValueError('slots and send_quota must be non-negative.') + first_byte = _MUX_OPCODE_NEW_CHANNEL_SLOT << 5 + block = (chr(first_byte) + + _encode_number(slots) + + _encode_number(send_quota)) + payload = _encode_channel_id(_CONTROL_CHANNEL_ID) + block + return create_binary_frame(payload, mask=outer_frame_mask) + + +def _create_fallback_new_channel_slot(outer_frame_mask=False): + first_byte = (_MUX_OPCODE_NEW_CHANNEL_SLOT << 5) | 1 # Set the F flag + block = (chr(first_byte) + _encode_number(0) + _encode_number(0)) + payload = _encode_channel_id(_CONTROL_CHANNEL_ID) + block + return create_binary_frame(payload, mask=outer_frame_mask) + + +def _parse_request_text(request_text): + request_line, header_lines = request_text.split('\r\n', 1) + + words = request_line.split(' ') + if len(words) != 3: + raise ValueError('Bad Request-Line syntax %r' % request_line) + [command, path, version] = words + if version != 'HTTP/1.1': + raise ValueError('Bad request version %r' % version) + + # email.parser.Parser() parses RFC 2822 (RFC 822) style headers. + # RFC 6455 refers RFC 2616 for handshake parsing, and RFC 2616 refers + # RFC 822. + headers = email.parser.Parser().parsestr(header_lines) + return command, path, version, headers + + +class _ControlBlock(object): + """A structure that holds parsing result of multiplexing control block. + Control block specific attributes will be added by _MuxFramePayloadParser. + (e.g. 
encoded_handshake will be added for AddChannelRequest and + AddChannelResponse) + """ + + def __init__(self, opcode): + self.opcode = opcode + + +class _MuxFramePayloadParser(object): + """A class that parses multiplexed frame payload.""" + + def __init__(self, payload): + self._data = payload + self._read_position = 0 + self._logger = util.get_class_logger(self) + + def read_channel_id(self): + """Reads channel id. + + Raises: + ValueError: when the payload doesn't contain + valid channel id. + """ + + remaining_length = len(self._data) - self._read_position + pos = self._read_position + if remaining_length == 0: + raise ValueError('Invalid channel id format') + + channel_id = ord(self._data[pos]) + channel_id_length = 1 + if channel_id & 0xe0 == 0xe0: + if remaining_length < 4: + raise ValueError('Invalid channel id format') + channel_id = struct.unpack('!L', + self._data[pos:pos+4])[0] & 0x1fffffff + channel_id_length = 4 + elif channel_id & 0xc0 == 0xc0: + if remaining_length < 3: + raise ValueError('Invalid channel id format') + channel_id = (((channel_id & 0x1f) << 16) + + struct.unpack('!H', self._data[pos+1:pos+3])[0]) + channel_id_length = 3 + elif channel_id & 0x80 == 0x80: + if remaining_length < 2: + raise ValueError('Invalid channel id format') + channel_id = struct.unpack('!H', + self._data[pos:pos+2])[0] & 0x3fff + channel_id_length = 2 + self._read_position += channel_id_length + + return channel_id + + def read_inner_frame(self): + """Reads an inner frame. + + Raises: + PhysicalConnectionError: when the inner frame is invalid. + """ + + if len(self._data) == self._read_position: + raise PhysicalConnectionError( + _DROP_CODE_ENCAPSULATED_FRAME_IS_TRUNCATED) + + bits = ord(self._data[self._read_position]) + self._read_position += 1 + fin = (bits & 0x80) == 0x80 + rsv1 = (bits & 0x40) == 0x40 + rsv2 = (bits & 0x20) == 0x20 + rsv3 = (bits & 0x10) == 0x10 + opcode = bits & 0xf + payload = self.remaining_data() + # Consume rest of the message which is payload data of the original + # frame. 
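+        # For illustration: if the bytes after the channel id are
+        # '\x81' 'H' 'i', this parses as fin=True, opcode=0x1 (text) and
+        # payload 'Hi'. Inner frames carry no length field; they extend
+        # to the end of the encapsulating message.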
+ self._read_position = len(self._data) + return fin, rsv1, rsv2, rsv3, opcode, payload + + def _read_number(self): + if self._read_position + 1 > len(self._data): + raise PhysicalConnectionError( + _DROP_CODE_INVALID_MUX_CONTROL_BLOCK, + 'Cannot read the first byte of number field') + + number = ord(self._data[self._read_position]) + if number & 0x80 == 0x80: + raise PhysicalConnectionError( + _DROP_CODE_INVALID_MUX_CONTROL_BLOCK, + 'The most significant bit of the first byte of number should ' + 'be unset') + self._read_position += 1 + pos = self._read_position + if number == 127: + if pos + 8 > len(self._data): + raise PhysicalConnectionError( + _DROP_CODE_INVALID_MUX_CONTROL_BLOCK, + 'Invalid number field') + self._read_position += 8 + number = struct.unpack('!Q', self._data[pos:pos+8])[0] + if number > 0x7FFFFFFFFFFFFFFF: + raise PhysicalConnectionError( + _DROP_CODE_INVALID_MUX_CONTROL_BLOCK, + 'Encoded number >= 2^63') + if number <= 0xFFFF: + raise PhysicalConnectionError( + _DROP_CODE_INVALID_MUX_CONTROL_BLOCK, + '%d should not be encoded by 9 bytes encoding' % number) + return number + if number == 126: + if pos + 2 > len(self._data): + raise PhysicalConnectionError( + _DROP_CODE_INVALID_MUX_CONTROL_BLOCK, + 'Invalid number field') + self._read_position += 2 + number = struct.unpack('!H', self._data[pos:pos+2])[0] + if number <= 125: + raise PhysicalConnectionError( + _DROP_CODE_INVALID_MUX_CONTROL_BLOCK, + '%d should not be encoded by 3 bytes encoding' % number) + return number + + def _read_size_and_contents(self): + """Reads data that consists of followings: + - the size of the contents encoded the same way as payload length + of the WebSocket Protocol with 1 bit padding at the head. + - the contents. + """ + + size = self._read_number() + pos = self._read_position + if pos + size > len(self._data): + raise PhysicalConnectionError( + _DROP_CODE_INVALID_MUX_CONTROL_BLOCK, + 'Cannot read %d bytes data' % size) + + self._read_position += size + return self._data[pos:pos+size] + + def _read_add_channel_request(self, first_byte, control_block): + reserved = (first_byte >> 2) & 0x7 + if reserved != 0: + raise PhysicalConnectionError( + _DROP_CODE_INVALID_MUX_CONTROL_BLOCK, + 'Reserved bits must be unset') + + # Invalid encoding will be handled by MuxHandler. 
+        encoding = first_byte & 0x3
+        try:
+            control_block.channel_id = self.read_channel_id()
+        except ValueError, e:
+            raise PhysicalConnectionError(_DROP_CODE_INVALID_MUX_CONTROL_BLOCK)
+        control_block.encoding = encoding
+        encoded_handshake = self._read_size_and_contents()
+        control_block.encoded_handshake = encoded_handshake
+        return control_block
+
+    def _read_add_channel_response(self, first_byte, control_block):
+        reserved = (first_byte >> 2) & 0x3
+        if reserved != 0:
+            raise PhysicalConnectionError(
+                _DROP_CODE_INVALID_MUX_CONTROL_BLOCK,
+                'Reserved bits must be unset')
+
+        control_block.accepted = (first_byte >> 4) & 1
+        control_block.encoding = first_byte & 0x3
+        try:
+            control_block.channel_id = self.read_channel_id()
+        except ValueError, e:
+            raise PhysicalConnectionError(_DROP_CODE_INVALID_MUX_CONTROL_BLOCK)
+        control_block.encoded_handshake = self._read_size_and_contents()
+        return control_block
+
+    def _read_flow_control(self, first_byte, control_block):
+        reserved = first_byte & 0x1f
+        if reserved != 0:
+            raise PhysicalConnectionError(
+                _DROP_CODE_INVALID_MUX_CONTROL_BLOCK,
+                'Reserved bits must be unset')
+
+        try:
+            control_block.channel_id = self.read_channel_id()
+        except ValueError, e:
+            raise PhysicalConnectionError(_DROP_CODE_INVALID_MUX_CONTROL_BLOCK)
+        control_block.send_quota = self._read_number()
+        return control_block
+
+    def _read_drop_channel(self, first_byte, control_block):
+        reserved = first_byte & 0x1f
+        if reserved != 0:
+            raise PhysicalConnectionError(
+                _DROP_CODE_INVALID_MUX_CONTROL_BLOCK,
+                'Reserved bits must be unset')
+
+        try:
+            control_block.channel_id = self.read_channel_id()
+        except ValueError, e:
+            raise PhysicalConnectionError(_DROP_CODE_INVALID_MUX_CONTROL_BLOCK)
+        reason = self._read_size_and_contents()
+        if len(reason) == 0:
+            control_block.drop_code = None
+            control_block.drop_message = ''
+        elif len(reason) >= 2:
+            control_block.drop_code = struct.unpack('!H', reason[:2])[0]
+            control_block.drop_message = reason[2:]
+        else:
+            raise PhysicalConnectionError(
+                _DROP_CODE_INVALID_MUX_CONTROL_BLOCK,
+                'Received DropChannel that contains only 1-byte reason')
+        return control_block
+
+    def _read_new_channel_slot(self, first_byte, control_block):
+        reserved = first_byte & 0x1e
+        if reserved != 0:
+            raise PhysicalConnectionError(
+                _DROP_CODE_INVALID_MUX_CONTROL_BLOCK,
+                'Reserved bits must be unset')
+        control_block.fallback = first_byte & 1
+        control_block.slots = self._read_number()
+        control_block.send_quota = self._read_number()
+        return control_block
+
+    def read_control_blocks(self):
+        """Reads control block(s).
+
+        Raises:
+            PhysicalConnectionError: when the payload contains invalid control
+                block(s).
+            StopIteration: when no control blocks left.
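+
+        A usage sketch (the channel id of the physical frame is read
+        first; control blocks only appear on the control channel):
+
+            parser = _MuxFramePayloadParser(payload)
+            if parser.read_channel_id() == _CONTROL_CHANNEL_ID:
+                for block in parser.read_control_blocks():
+                    ...  # dispatch on block.opcode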
+ """ + + while self._read_position < len(self._data): + first_byte = ord(self._data[self._read_position]) + self._read_position += 1 + opcode = (first_byte >> 5) & 0x7 + control_block = _ControlBlock(opcode=opcode) + if opcode == _MUX_OPCODE_ADD_CHANNEL_REQUEST: + yield self._read_add_channel_request(first_byte, control_block) + elif opcode == _MUX_OPCODE_ADD_CHANNEL_RESPONSE: + yield self._read_add_channel_response( + first_byte, control_block) + elif opcode == _MUX_OPCODE_FLOW_CONTROL: + yield self._read_flow_control(first_byte, control_block) + elif opcode == _MUX_OPCODE_DROP_CHANNEL: + yield self._read_drop_channel(first_byte, control_block) + elif opcode == _MUX_OPCODE_NEW_CHANNEL_SLOT: + yield self._read_new_channel_slot(first_byte, control_block) + else: + raise PhysicalConnectionError( + _DROP_CODE_UNKNOWN_MUX_OPCODE, + 'Invalid opcode %d' % opcode) + + assert self._read_position == len(self._data) + raise StopIteration + + def remaining_data(self): + """Returns remaining data.""" + + return self._data[self._read_position:] + + +class _LogicalRequest(object): + """Mimics mod_python request.""" + + def __init__(self, channel_id, command, path, protocol, headers, + connection): + """Constructs an instance. + + Args: + channel_id: the channel id of the logical channel. + command: HTTP request command. + path: HTTP request path. + headers: HTTP headers. + connection: _LogicalConnection instance. + """ + + self.channel_id = channel_id + self.method = command + self.uri = path + self.protocol = protocol + self.headers_in = headers + self.connection = connection + self.server_terminated = False + self.client_terminated = False + + def is_https(self): + """Mimics request.is_https(). Returns False because this method is + used only by old protocols (hixie and hybi00). + """ + + return False + + +class _LogicalConnection(object): + """Mimics mod_python mp_conn.""" + + # For details, see the comment of set_read_state(). + STATE_ACTIVE = 1 + STATE_GRACEFULLY_CLOSED = 2 + STATE_TERMINATED = 3 + + def __init__(self, mux_handler, channel_id): + """Constructs an instance. + + Args: + mux_handler: _MuxHandler instance. + channel_id: channel id of this connection. + """ + + self._mux_handler = mux_handler + self._channel_id = channel_id + self._incoming_data = '' + self._write_condition = threading.Condition() + self._waiting_write_completion = False + self._read_condition = threading.Condition() + self._read_state = self.STATE_ACTIVE + + def get_local_addr(self): + """Getter to mimic mp_conn.local_addr.""" + + return self._mux_handler.physical_connection.get_local_addr() + local_addr = property(get_local_addr) + + def get_remote_addr(self): + """Getter to mimic mp_conn.remote_addr.""" + + return self._mux_handler.physical_connection.get_remote_addr() + remote_addr = property(get_remote_addr) + + def get_memorized_lines(self): + """Gets memorized lines. Not supported.""" + + raise MuxUnexpectedException('_LogicalConnection does not support ' + 'get_memorized_lines') + + def write(self, data): + """Writes data. mux_handler sends data asynchronously. The caller will + be suspended until write done. + + Args: + data: data to be written. + + Raises: + MuxUnexpectedException: when called before finishing the previous + write. 
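+
+        The caller resumes once the mux handler reports completion via
+        notify_write_done() for this channel.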
+ """ + + try: + self._write_condition.acquire() + if self._waiting_write_completion: + raise MuxUnexpectedException( + 'Logical connection %d is already waiting the completion ' + 'of write' % self._channel_id) + + self._waiting_write_completion = True + self._mux_handler.send_data(self._channel_id, data) + self._write_condition.wait() + finally: + self._write_condition.release() + + def write_control_data(self, data): + """Writes data via the control channel. Don't wait finishing write + because this method can be called by mux dispatcher. + + Args: + data: data to be written. + """ + + self._mux_handler.send_control_data(data) + + def notify_write_done(self): + """Called when sending data is completed.""" + + try: + self._write_condition.acquire() + if not self._waiting_write_completion: + raise MuxUnexpectedException( + 'Invalid call of notify_write_done for logical connection' + ' %d' % self._channel_id) + self._waiting_write_completion = False + self._write_condition.notify() + finally: + self._write_condition.release() + + def append_frame_data(self, frame_data): + """Appends incoming frame data. Called when mux_handler dispatches + frame data to the corresponding application. + + Args: + frame_data: incoming frame data. + """ + + self._read_condition.acquire() + self._incoming_data += frame_data + self._read_condition.notify() + self._read_condition.release() + + def read(self, length): + """Reads data. Blocks until enough data has arrived via physical + connection. + + Args: + length: length of data to be read. + Raises: + LogicalConnectionClosedException: when closing handshake for this + logical channel has been received. + ConnectionTerminatedException: when the physical connection has + closed, or an error is caused on the reader thread. + """ + + self._read_condition.acquire() + while (self._read_state == self.STATE_ACTIVE and + len(self._incoming_data) < length): + self._read_condition.wait() + + try: + if self._read_state == self.STATE_GRACEFULLY_CLOSED: + raise LogicalConnectionClosedException( + 'Logical channel %d has closed.' % self._channel_id) + elif self._read_state == self.STATE_TERMINATED: + raise ConnectionTerminatedException( + 'Receiving %d byte failed. Logical channel (%d) closed' % + (length, self._channel_id)) + + value = self._incoming_data[:length] + self._incoming_data = self._incoming_data[length:] + finally: + self._read_condition.release() + + return value + + def set_read_state(self, new_state): + """Sets the state of this connection. Called when an event for this + connection has occurred. + + Args: + new_state: state to be set. new_state must be one of followings: + - STATE_GRACEFULLY_CLOSED: when closing handshake for this + connection has been received. + - STATE_TERMINATED: when the physical connection has closed or + DropChannel of this connection has received. + """ + + self._read_condition.acquire() + self._read_state = new_state + self._read_condition.notify() + self._read_condition.release() + + +class _LogicalStream(Stream): + """Mimics the Stream class. This class interprets multiplexed WebSocket + frames. + """ + + def __init__(self, request, send_quota, receive_quota): + """Constructs an instance. + + Args: + request: _LogicalRequest instance. + send_quota: Initial send quota. + receive_quota: Initial receive quota. + """ + + # TODO(bashi): Support frame filters. + stream_options = StreamOptions() + # Physical stream is responsible for masking. 
+ stream_options.unmask_receive = False + # Control frames can be fragmented on logical channel. + stream_options.allow_fragmented_control_frame = True + Stream.__init__(self, request, stream_options) + self._send_quota = send_quota + self._send_quota_condition = threading.Condition() + self._receive_quota = receive_quota + self._write_inner_frame_semaphore = threading.Semaphore() + + def _create_inner_frame(self, opcode, payload, end=True): + # TODO(bashi): Support extensions that use reserved bits. + first_byte = (end << 7) | opcode + return (_encode_channel_id(self._request.channel_id) + + chr(first_byte) + payload) + + def _write_inner_frame(self, opcode, payload, end=True): + payload_length = len(payload) + write_position = 0 + + try: + # An inner frame will be fragmented if there is no enough send + # quota. This semaphore ensures that fragmented inner frames are + # sent in order on the logical channel. + # Note that frames that come from other logical channels or + # multiplexing control blocks can be inserted between fragmented + # inner frames on the physical channel. + self._write_inner_frame_semaphore.acquire() + while write_position < payload_length: + try: + self._send_quota_condition.acquire() + while self._send_quota == 0: + self._logger.debug( + 'No quota. Waiting FlowControl message for %d.' % + self._request.channel_id) + self._send_quota_condition.wait() + + remaining = payload_length - write_position + write_length = min(self._send_quota, remaining) + inner_frame_end = ( + end and + (write_position + write_length == payload_length)) + + inner_frame = self._create_inner_frame( + opcode, + payload[write_position:write_position+write_length], + inner_frame_end) + frame_data = self._writer.build( + inner_frame, end=True, binary=True) + self._send_quota -= write_length + self._logger.debug('Consumed quota=%d, remaining=%d' % + (write_length, self._send_quota)) + finally: + self._send_quota_condition.release() + + # Writing data will block the worker so we need to release + # _send_quota_condition before writing. + self._logger.debug('Sending inner frame: %r' % frame_data) + self._request.connection.write(frame_data) + write_position += write_length + + opcode = common.OPCODE_CONTINUATION + + except ValueError, e: + raise BadOperationException(e) + finally: + self._write_inner_frame_semaphore.release() + + def replenish_send_quota(self, send_quota): + """Replenish send quota.""" + + self._send_quota_condition.acquire() + self._send_quota += send_quota + self._logger.debug('Replenished send quota for channel id %d: %d' % + (self._request.channel_id, self._send_quota)) + self._send_quota_condition.notify() + self._send_quota_condition.release() + + def consume_receive_quota(self, amount): + """Consumes receive quota. 
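+
+ Args:
+ amount: number of bytes to deduct from the remaining receive
+ quota of this logical channel.
+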
Returns False on failure.""" + + if self._receive_quota < amount: + self._logger.debug('Violate quota on channel id %d: %d < %d' % + (self._request.channel_id, + self._receive_quota, amount)) + return False + self._receive_quota -= amount + return True + + def send_message(self, message, end=True, binary=False): + """Override Stream.send_message.""" + + if self._request.server_terminated: + raise BadOperationException( + 'Requested send_message after sending out a closing handshake') + + if binary and isinstance(message, unicode): + raise BadOperationException( + 'Message for binary frame must be instance of str') + + if binary: + opcode = common.OPCODE_BINARY + else: + opcode = common.OPCODE_TEXT + message = message.encode('utf-8') + + self._write_inner_frame(opcode, message, end) + + def _receive_frame(self): + """Overrides Stream._receive_frame. + + In addition to call Stream._receive_frame, this method adds the amount + of payload to receiving quota and sends FlowControl to the client. + We need to do it here because Stream.receive_message() handles + control frames internally. + """ + + opcode, payload, fin, rsv1, rsv2, rsv3 = Stream._receive_frame(self) + amount = len(payload) + self._receive_quota += amount + frame_data = _create_flow_control(self._request.channel_id, + amount) + self._logger.debug('Sending flow control for %d, replenished=%d' % + (self._request.channel_id, amount)) + self._request.connection.write_control_data(frame_data) + return opcode, payload, fin, rsv1, rsv2, rsv3 + + def receive_message(self): + """Overrides Stream.receive_message.""" + + # Just call Stream.receive_message(), but catch + # LogicalConnectionClosedException, which is raised when the logical + # connection has closed gracefully. + try: + return Stream.receive_message(self) + except LogicalConnectionClosedException, e: + self._logger.debug('%s', e) + return None + + def _send_closing_handshake(self, code, reason): + """Overrides Stream._send_closing_handshake.""" + + body = create_closing_handshake_body(code, reason) + self._logger.debug('Sending closing handshake for %d: (%r, %r)' % + (self._request.channel_id, code, reason)) + self._write_inner_frame(common.OPCODE_CLOSE, body, end=True) + + self._request.server_terminated = True + + def send_ping(self, body=''): + """Overrides Stream.send_ping""" + + self._logger.debug('Sending ping on logical channel %d: %r' % + (self._request.channel_id, body)) + self._write_inner_frame(common.OPCODE_PING, body, end=True) + + self._ping_queue.append(body) + + def _send_pong(self, body): + """Overrides Stream._send_pong""" + + self._logger.debug('Sending pong on logical channel %d: %r' % + (self._request.channel_id, body)) + self._write_inner_frame(common.OPCODE_PONG, body, end=True) + + def close_connection(self, code=common.STATUS_NORMAL_CLOSURE, reason=''): + """Overrides Stream.close_connection.""" + + # TODO(bashi): Implement + self._logger.debug('Closing logical connection %d' % + self._request.channel_id) + self._request.server_terminated = True + + def _drain_received_data(self): + """Overrides Stream._drain_received_data. Nothing need to be done for + logical channel. + """ + + pass + + +class _OutgoingData(object): + """A structure that holds data to be sent via physical connection and + origin of the data. + """ + + def __init__(self, channel_id, data): + self.channel_id = channel_id + self.data = data + + +class _PhysicalConnectionWriter(threading.Thread): + """A thread that is responsible for writing data to physical connection. 
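+
+ Frames from logical channels are queued as _OutgoingData instances
+ via put_outgoing_data() and written to the socket in FIFO order by
+ run(). For illustration, a producer does:
+
+ writer.put_outgoing_data(
+ _OutgoingData(channel_id=channel_id, data=frame_data))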
+ + TODO(bashi): Make sure there is no thread-safety problem when the reader + thread reads data from the same socket at a time. + """ + + def __init__(self, mux_handler): + """Constructs an instance. + + Args: + mux_handler: _MuxHandler instance. + """ + + threading.Thread.__init__(self) + self._logger = util.get_class_logger(self) + self._mux_handler = mux_handler + self.setDaemon(True) + self._stop_requested = False + self._deque = collections.deque() + self._deque_condition = threading.Condition() + + def put_outgoing_data(self, data): + """Puts outgoing data. + + Args: + data: _OutgoingData instance. + + Raises: + BadOperationException: when the thread has been requested to + terminate. + """ + + try: + self._deque_condition.acquire() + if self._stop_requested: + raise BadOperationException('Cannot write data anymore') + + self._deque.append(data) + self._deque_condition.notify() + finally: + self._deque_condition.release() + + def _write_data(self, outgoing_data): + try: + self._mux_handler.physical_connection.write(outgoing_data.data) + except Exception, e: + util.prepend_message_to_exception( + 'Failed to send message to %r: ' % + (self._mux_handler.physical_connection.remote_addr,), e) + raise + + # TODO(bashi): It would be better to block the thread that sends + # control data as well. + if outgoing_data.channel_id != _CONTROL_CHANNEL_ID: + self._mux_handler.notify_write_done(outgoing_data.channel_id) + + def run(self): + self._deque_condition.acquire() + while not self._stop_requested: + if len(self._deque) == 0: + self._deque_condition.wait() + continue + + outgoing_data = self._deque.popleft() + self._deque_condition.release() + self._write_data(outgoing_data) + self._deque_condition.acquire() + + # Flush deque + try: + while len(self._deque) > 0: + outgoing_data = self._deque.popleft() + self._write_data(outgoing_data) + finally: + self._deque_condition.release() + + def stop(self): + """Stops the writer thread.""" + + self._deque_condition.acquire() + self._stop_requested = True + self._deque_condition.notify() + self._deque_condition.release() + + +class _PhysicalConnectionReader(threading.Thread): + """A thread that is responsible for reading data from physical connection. + """ + + def __init__(self, mux_handler): + """Constructs an instance. + + Args: + mux_handler: _MuxHandler instance. + """ + + threading.Thread.__init__(self) + self._logger = util.get_class_logger(self) + self._mux_handler = mux_handler + self.setDaemon(True) + + def run(self): + while True: + try: + physical_stream = self._mux_handler.physical_stream + message = physical_stream.receive_message() + if message is None: + break + # Below happens only when a data message is received. 
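+ # Logical frames and control blocks are encapsulated in binary
+ # messages only, so a text message on the physical connection
+ # is a protocol violation.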
+ opcode = physical_stream.get_last_received_opcode() + if opcode != common.OPCODE_BINARY: + self._mux_handler.fail_physical_connection( + _DROP_CODE_INVALID_ENCAPSULATING_MESSAGE, + 'Received a text message on physical connection') + break + + except ConnectionTerminatedException, e: + self._logger.debug('%s', e) + break + + try: + self._mux_handler.dispatch_message(message) + except PhysicalConnectionError, e: + self._mux_handler.fail_physical_connection( + e.drop_code, e.message) + break + except LogicalChannelError, e: + self._mux_handler.fail_logical_channel( + e.channel_id, e.drop_code, e.message) + except Exception, e: + self._logger.debug(traceback.format_exc()) + break + + self._mux_handler.notify_reader_done() + + +class _Worker(threading.Thread): + """A thread that is responsible for running the corresponding application + handler. + """ + + def __init__(self, mux_handler, request): + """Constructs an instance. + + Args: + mux_handler: _MuxHandler instance. + request: _LogicalRequest instance. + """ + + threading.Thread.__init__(self) + self._logger = util.get_class_logger(self) + self._mux_handler = mux_handler + self._request = request + self.setDaemon(True) + + def run(self): + self._logger.debug('Logical channel worker started. (id=%d)' % + self._request.channel_id) + try: + # Non-critical exceptions will be handled by dispatcher. + self._mux_handler.dispatcher.transfer_data(self._request) + finally: + self._mux_handler.notify_worker_done(self._request.channel_id) + + +class _MuxHandshaker(hybi.Handshaker): + """Opening handshake processor for multiplexing.""" + + _DUMMY_WEBSOCKET_KEY = 'dGhlIHNhbXBsZSBub25jZQ==' + + def __init__(self, request, dispatcher, send_quota, receive_quota): + """Constructs an instance. + Args: + request: _LogicalRequest instance. + dispatcher: Dispatcher instance (dispatch.Dispatcher). + send_quota: Initial send quota. + receive_quota: Initial receive quota. + """ + + hybi.Handshaker.__init__(self, request, dispatcher) + self._send_quota = send_quota + self._receive_quota = receive_quota + + # Append headers which should not be included in handshake field of + # AddChannelRequest. + # TODO(bashi): Make sure whether we should raise exception when + # these headers are included already. + request.headers_in[common.UPGRADE_HEADER] = ( + common.WEBSOCKET_UPGRADE_TYPE) + request.headers_in[common.CONNECTION_HEADER] = ( + common.UPGRADE_CONNECTION_TYPE) + request.headers_in[common.SEC_WEBSOCKET_VERSION_HEADER] = ( + str(common.VERSION_HYBI_LATEST)) + request.headers_in[common.SEC_WEBSOCKET_KEY_HEADER] = ( + self._DUMMY_WEBSOCKET_KEY) + + def _create_stream(self, stream_options): + """Override hybi.Handshaker._create_stream.""" + + self._logger.debug('Creating logical stream for %d' % + self._request.channel_id) + return _LogicalStream(self._request, self._send_quota, + self._receive_quota) + + def _create_handshake_response(self, accept): + """Override hybi._create_handshake_response.""" + + response = [] + + response.append('HTTP/1.1 101 Switching Protocols\r\n') + + # Upgrade, Connection and Sec-WebSocket-Accept should be excluded. 
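+ # For illustration, an accepted response negotiating a
+ # hypothetical subprotocol 'chat' would be serialized as:
+ #
+ # HTTP/1.1 101 Switching Protocols
+ # Sec-WebSocket-Protocol: chat
+ #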
+ if self._request.ws_protocol is not None: + response.append('%s: %s\r\n' % ( + common.SEC_WEBSOCKET_PROTOCOL_HEADER, + self._request.ws_protocol)) + if (self._request.ws_extensions is not None and + len(self._request.ws_extensions) != 0): + response.append('%s: %s\r\n' % ( + common.SEC_WEBSOCKET_EXTENSIONS_HEADER, + common.format_extensions(self._request.ws_extensions))) + response.append('\r\n') + + return ''.join(response) + + def _send_handshake(self, accept): + """Override hybi.Handshaker._send_handshake.""" + + # Don't send handshake response for the default channel + if self._request.channel_id == _DEFAULT_CHANNEL_ID: + return + + handshake_response = self._create_handshake_response(accept) + frame_data = _create_add_channel_response( + self._request.channel_id, + handshake_response) + self._logger.debug('Sending handshake response for %d: %r' % + (self._request.channel_id, frame_data)) + self._request.connection.write_control_data(frame_data) + + +class _LogicalChannelData(object): + """A structure that holds information about logical channel. + """ + + def __init__(self, request, worker): + self.request = request + self.worker = worker + self.drop_code = _DROP_CODE_NORMAL_CLOSURE + self.drop_message = '' + + +class _HandshakeDeltaBase(object): + """A class that holds information for delta-encoded handshake.""" + + def __init__(self, headers): + self._headers = headers + + def create_headers(self, delta=None): + """Creates request headers for an AddChannelRequest that has + delta-encoded handshake. + + Args: + delta: headers should be overridden. + """ + + headers = copy.copy(self._headers) + if delta: + for key, value in delta.items(): + # The spec requires that a header with an empty value is + # removed from the delta base. + if len(value) == 0 and headers.has_key(key): + del headers[key] + else: + headers[key] = value + # TODO(bashi): Support extensions + headers['Sec-WebSocket-Extensions'] = '' + return headers + + +class _MuxHandler(object): + """Multiplexing handler. When a handler starts, it launches three + threads; the reader thread, the writer thread, and a worker thread. + + The reader thread reads data from the physical stream, i.e., the + ws_stream object of the underlying websocket connection. The reader + thread interprets multiplexed frames and dispatches them to logical + channels. Methods of this class are mostly called by the reader thread. + + The writer thread sends multiplexed frames which are created by + logical channels via the physical connection. + + The worker thread launched at the starting point handles the + "Implicitly Opened Connection". If multiplexing handler receives + an AddChannelRequest and accepts it, the handler will launch a new worker + thread and dispatch the request to it. + """ + + def __init__(self, request, dispatcher): + """Constructs an instance. + + Args: + request: mod_python request of the physical connection. + dispatcher: Dispatcher instance (dispatch.Dispatcher). + """ + + self.original_request = request + self.dispatcher = dispatcher + self.physical_connection = request.connection + self.physical_stream = request.ws_stream + self._logger = util.get_class_logger(self) + self._logical_channels = {} + self._logical_channels_condition = threading.Condition() + # Holds client's initial quota + self._channel_slots = collections.deque() + self._handshake_base = None + self._worker_done_notify_received = False + self._reader = None + self._writer = None + + def start(self): + """Starts the handler. 
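+
+ Launches the reader and writer threads and performs the opening
+ handshake for the "Implicitly Opened Connection" (the default
+ channel).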
+ + Raises: + MuxUnexpectedException: when the handler already started, or when + opening handshake of the default channel fails. + """ + + if self._reader or self._writer: + raise MuxUnexpectedException('MuxHandler already started') + + self._reader = _PhysicalConnectionReader(self) + self._writer = _PhysicalConnectionWriter(self) + self._reader.start() + self._writer.start() + + # Create "Implicitly Opened Connection". + logical_connection = _LogicalConnection(self, _DEFAULT_CHANNEL_ID) + self._handshake_base = _HandshakeDeltaBase( + self.original_request.headers_in) + logical_request = _LogicalRequest( + _DEFAULT_CHANNEL_ID, + self.original_request.method, + self.original_request.uri, + self.original_request.protocol, + self._handshake_base.create_headers(), + logical_connection) + # Client's send quota for the implicitly opened connection is zero, + # but we will send FlowControl later so set the initial quota to + # _INITIAL_QUOTA_FOR_CLIENT. + self._channel_slots.append(_INITIAL_QUOTA_FOR_CLIENT) + if not self._do_handshake_for_logical_request( + logical_request, send_quota=self.original_request.mux_quota): + raise MuxUnexpectedException( + 'Failed handshake on the default channel id') + self._add_logical_channel(logical_request) + + # Send FlowControl for the implicitly opened connection. + frame_data = _create_flow_control(_DEFAULT_CHANNEL_ID, + _INITIAL_QUOTA_FOR_CLIENT) + logical_request.connection.write_control_data(frame_data) + + def add_channel_slots(self, slots, send_quota): + """Adds channel slots. + + Args: + slots: number of slots to be added. + send_quota: initial send quota for slots. + """ + + self._channel_slots.extend([send_quota] * slots) + # Send NewChannelSlot to client. + frame_data = _create_new_channel_slot(slots, send_quota) + self.send_control_data(frame_data) + + def wait_until_done(self, timeout=None): + """Waits until all workers are done. Returns False when timeout has + occurred. Returns True on success. + + Args: + timeout: timeout in sec. + """ + + self._logical_channels_condition.acquire() + try: + while len(self._logical_channels) > 0: + self._logger.debug('Waiting workers(%d)...' % + len(self._logical_channels)) + self._worker_done_notify_received = False + self._logical_channels_condition.wait(timeout) + if not self._worker_done_notify_received: + self._logger.debug('Waiting worker(s) timed out') + return False + + finally: + self._logical_channels_condition.release() + + # Flush pending outgoing data + self._writer.stop() + self._writer.join() + + return True + + def notify_write_done(self, channel_id): + """Called by the writer thread when a write operation has done. + + Args: + channel_id: objective channel id. + """ + + try: + self._logical_channels_condition.acquire() + if channel_id in self._logical_channels: + channel_data = self._logical_channels[channel_id] + channel_data.request.connection.notify_write_done() + else: + self._logger.debug('Seems that logical channel for %d has gone' + % channel_id) + finally: + self._logical_channels_condition.release() + + def send_control_data(self, data): + """Sends data via the control channel. + + Args: + data: data to be sent. + """ + + self._writer.put_outgoing_data(_OutgoingData( + channel_id=_CONTROL_CHANNEL_ID, data=data)) + + def send_data(self, channel_id, data): + """Sends data via given logical channel. This method is called by + worker threads. + + Args: + channel_id: objective channel id. + data: data to be sent. 
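+
+ The write is asynchronous: this call returns once the data is
+ queued, and completion is signaled back to the logical connection
+ via notify_write_done().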
+ """ + + self._writer.put_outgoing_data(_OutgoingData( + channel_id=channel_id, data=data)) + + def _send_drop_channel(self, channel_id, code=None, message=''): + frame_data = _create_drop_channel(channel_id, code, message) + self._logger.debug( + 'Sending drop channel for channel id %d' % channel_id) + self.send_control_data(frame_data) + + def _send_error_add_channel_response(self, channel_id, status=None): + if status is None: + status = common.HTTP_STATUS_BAD_REQUEST + + if status in _HTTP_BAD_RESPONSE_MESSAGES: + message = _HTTP_BAD_RESPONSE_MESSAGES[status] + else: + self._logger.debug('Response message for %d is not found' % status) + message = '???' + + response = 'HTTP/1.1 %d %s\r\n\r\n' % (status, message) + frame_data = _create_add_channel_response(channel_id, + encoded_handshake=response, + encoding=0, rejected=True) + self.send_control_data(frame_data) + + def _create_logical_request(self, block): + if block.channel_id == _CONTROL_CHANNEL_ID: + # TODO(bashi): Raise PhysicalConnectionError with code 2006 + # instead of MuxUnexpectedException. + raise MuxUnexpectedException( + 'Received the control channel id (0) as objective channel ' + 'id for AddChannel') + + if block.encoding > _HANDSHAKE_ENCODING_DELTA: + raise PhysicalConnectionError( + _DROP_CODE_UNKNOWN_REQUEST_ENCODING) + + method, path, version, headers = _parse_request_text( + block.encoded_handshake) + if block.encoding == _HANDSHAKE_ENCODING_DELTA: + headers = self._handshake_base.create_headers(headers) + + connection = _LogicalConnection(self, block.channel_id) + request = _LogicalRequest(block.channel_id, method, path, version, + headers, connection) + return request + + def _do_handshake_for_logical_request(self, request, send_quota=0): + try: + receive_quota = self._channel_slots.popleft() + except IndexError: + raise LogicalChannelError( + request.channel_id, _DROP_CODE_NEW_CHANNEL_SLOT_VIOLATION) + + handshaker = _MuxHandshaker(request, self.dispatcher, + send_quota, receive_quota) + try: + handshaker.do_handshake() + except handshake.VersionException, e: + self._logger.info('%s', e) + self._send_error_add_channel_response( + request.channel_id, status=common.HTTP_STATUS_BAD_REQUEST) + return False + except handshake.HandshakeException, e: + # TODO(bashi): Should we _Fail the Logical Channel_ with 3001 + # instead? 
+ self._logger.info('%s', e) + self._send_error_add_channel_response(request.channel_id, + status=e.status) + return False + except handshake.AbortedByUserException, e: + self._logger.info('%s', e) + self._send_error_add_channel_response(request.channel_id) + return False + + return True + + def _add_logical_channel(self, logical_request): + try: + self._logical_channels_condition.acquire() + if logical_request.channel_id in self._logical_channels: + self._logger.debug('Channel id %d already exists' % + logical_request.channel_id) + raise PhysicalConnectionError( + _DROP_CODE_CHANNEL_ALREADY_EXISTS, + 'Channel id %d already exists' % + logical_request.channel_id) + worker = _Worker(self, logical_request) + channel_data = _LogicalChannelData(logical_request, worker) + self._logical_channels[logical_request.channel_id] = channel_data + worker.start() + finally: + self._logical_channels_condition.release() + + def _process_add_channel_request(self, block): + try: + logical_request = self._create_logical_request(block) + except ValueError, e: + self._logger.debug('Failed to create logical request: %r' % e) + self._send_error_add_channel_response( + block.channel_id, status=common.HTTP_STATUS_BAD_REQUEST) + return + if self._do_handshake_for_logical_request(logical_request): + if block.encoding == _HANDSHAKE_ENCODING_IDENTITY: + # Update handshake base. + # TODO(bashi): Make sure this is the right place to update + # handshake base. + self._handshake_base = _HandshakeDeltaBase( + logical_request.headers_in) + self._add_logical_channel(logical_request) + else: + self._send_error_add_channel_response( + block.channel_id, status=common.HTTP_STATUS_BAD_REQUEST) + + def _process_flow_control(self, block): + try: + self._logical_channels_condition.acquire() + if not block.channel_id in self._logical_channels: + return + channel_data = self._logical_channels[block.channel_id] + channel_data.request.ws_stream.replenish_send_quota( + block.send_quota) + finally: + self._logical_channels_condition.release() + + def _process_drop_channel(self, block): + self._logger.debug( + 'DropChannel received for %d: code=%r, reason=%r' % + (block.channel_id, block.drop_code, block.drop_message)) + try: + self._logical_channels_condition.acquire() + if not block.channel_id in self._logical_channels: + return + channel_data = self._logical_channels[block.channel_id] + channel_data.drop_code = _DROP_CODE_ACKNOWLEDGED + # Close the logical channel + channel_data.request.connection.set_read_state( + _LogicalConnection.STATE_TERMINATED) + finally: + self._logical_channels_condition.release() + + def _process_control_blocks(self, parser): + for control_block in parser.read_control_blocks(): + opcode = control_block.opcode + self._logger.debug('control block received, opcode: %d' % opcode) + if opcode == _MUX_OPCODE_ADD_CHANNEL_REQUEST: + self._process_add_channel_request(control_block) + elif opcode == _MUX_OPCODE_ADD_CHANNEL_RESPONSE: + raise PhysicalConnectionError( + _DROP_CODE_INVALID_MUX_CONTROL_BLOCK, + 'Received AddChannelResponse') + elif opcode == _MUX_OPCODE_FLOW_CONTROL: + self._process_flow_control(control_block) + elif opcode == _MUX_OPCODE_DROP_CHANNEL: + self._process_drop_channel(control_block) + elif opcode == _MUX_OPCODE_NEW_CHANNEL_SLOT: + raise PhysicalConnectionError( + _DROP_CODE_INVALID_MUX_CONTROL_BLOCK, + 'Received NewChannelSlot') + else: + raise MuxUnexpectedException( + 'Unexpected opcode %r' % opcode) + + def _process_logical_frame(self, channel_id, parser): + self._logger.debug('Received a 
frame. channel id=%d' % channel_id)
+ try:
+ self._logical_channels_condition.acquire()
+ if not channel_id in self._logical_channels:
+ # We must ignore the message for an inactive channel.
+ return
+ channel_data = self._logical_channels[channel_id]
+ fin, rsv1, rsv2, rsv3, opcode, payload = parser.read_inner_frame()
+ if not channel_data.request.ws_stream.consume_receive_quota(
+ len(payload)):
+ # The client violated the quota. Close the logical channel.
+ raise LogicalChannelError(
+ channel_id, _DROP_CODE_SEND_QUOTA_VIOLATION)
+ header = create_header(opcode, len(payload), fin, rsv1, rsv2, rsv3,
+ mask=False)
+ frame_data = header + payload
+ channel_data.request.connection.append_frame_data(frame_data)
+ finally:
+ self._logical_channels_condition.release()
+
+ def dispatch_message(self, message):
+ """Dispatches message. The reader thread calls this method.
+
+ Args:
+ message: a message that contains an encapsulated frame.
+ Raises:
+ PhysicalConnectionError: if the message contains physical
+ connection level errors.
+ LogicalChannelError: if the message contains logical channel
+ level errors.
+ """
+
+ parser = _MuxFramePayloadParser(message)
+ try:
+ channel_id = parser.read_channel_id()
+ except ValueError, e:
+ raise PhysicalConnectionError(_DROP_CODE_CHANNEL_ID_TRUNCATED)
+ if channel_id == _CONTROL_CHANNEL_ID:
+ self._process_control_blocks(parser)
+ else:
+ self._process_logical_frame(channel_id, parser)
+
+ def notify_worker_done(self, channel_id):
+ """Called when a worker has finished.
+
+ Args:
+ channel_id: channel id corresponding to the worker.
+ """
+
+ self._logger.debug('Worker for channel id %d terminated' % channel_id)
+ try:
+ self._logical_channels_condition.acquire()
+ if not channel_id in self._logical_channels:
+ raise MuxUnexpectedException(
+ 'Channel id %d not found' % channel_id)
+ channel_data = self._logical_channels.pop(channel_id)
+ finally:
+ self._worker_done_notify_received = True
+ self._logical_channels_condition.notify()
+ self._logical_channels_condition.release()
+
+ if not channel_data.request.server_terminated:
+ self._send_drop_channel(
+ channel_id, code=channel_data.drop_code,
+ message=channel_data.drop_message)
+
+ def notify_reader_done(self):
+ """This method is called by the reader thread when the reader has
+ finished.
+ """
+
+ # Terminate all logical connections
+ self._logger.debug('terminating all logical connections...')
+ self._logical_channels_condition.acquire()
+ for channel_data in self._logical_channels.values():
+ try:
+ channel_data.request.connection.set_read_state(
+ _LogicalConnection.STATE_TERMINATED)
+ except Exception:
+ pass
+ self._logical_channels_condition.release()
+
+ def fail_physical_connection(self, code, message):
+ """Fail the physical connection.
+
+ Args:
+ code: drop reason code.
+ message: drop message.
+ """
+
+ self._logger.debug('Failing the physical connection...')
+ self._send_drop_channel(_CONTROL_CHANNEL_ID, code, message)
+ self.physical_stream.close_connection(
+ common.STATUS_INTERNAL_ENDPOINT_ERROR)
+
+ def fail_logical_channel(self, channel_id, code, message):
+ """Fail a logical channel.
+
+ Args:
+ channel_id: channel id.
+ code: drop reason code.
+ message: drop message.
+ """
+
+ self._logger.debug('Failing logical channel %d...' % channel_id)
+ try:
+ self._logical_channels_condition.acquire()
+ if channel_id in self._logical_channels:
+ channel_data = self._logical_channels[channel_id]
+ # Close the logical channel.
notify_worker_done() will be + # called later and it will send DropChannel. + channel_data.drop_code = code + channel_data.drop_message = message + channel_data.request.connection.set_read_state( + _LogicalConnection.STATE_TERMINATED) + else: + self._send_drop_channel(channel_id, code, message) + finally: + self._logical_channels_condition.release() + + +def use_mux(request): + return hasattr(request, 'mux') and request.mux + + +def start(request, dispatcher): + mux_handler = _MuxHandler(request, dispatcher) + mux_handler.start() + + mux_handler.add_channel_slots(_INITIAL_NUMBER_OF_CHANNEL_SLOTS, + _INITIAL_QUOTA_FOR_CLIENT) + + mux_handler.wait_until_done() + + +# vi:sts=4 sw=4 et diff --git a/pyload/lib/mod_pywebsocket/standalone.py b/pyload/lib/mod_pywebsocket/standalone.py new file mode 100755 index 000000000..07a33d9c9 --- /dev/null +++ b/pyload/lib/mod_pywebsocket/standalone.py @@ -0,0 +1,998 @@ +#!/usr/bin/env python +# +# Copyright 2012, Google Inc. +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: +# +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above +# copyright notice, this list of conditions and the following disclaimer +# in the documentation and/or other materials provided with the +# distribution. +# * Neither the name of Google Inc. nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + + +"""Standalone WebSocket server. + +Use this file to launch pywebsocket without Apache HTTP Server. + + +BASIC USAGE + +Go to the src directory and run + + $ python mod_pywebsocket/standalone.py [-p <ws_port>] + [-w <websock_handlers>] + [-d <document_root>] + +<ws_port> is the port number to use for ws:// connection. + +<document_root> is the path to the root directory of HTML files. + +<websock_handlers> is the path to the root directory of WebSocket handlers. +If not specified, <document_root> will be used. See __init__.py (or +run $ pydoc mod_pywebsocket) for how to write WebSocket handlers. + +For more detail and other options, run + + $ python mod_pywebsocket/standalone.py --help + +or see _build_option_parser method below. + +For trouble shooting, adding "--log_level debug" might help you. + + +TRY DEMO + +Go to the src directory and run + + $ python standalone.py -d example + +to launch pywebsocket with the sample handler and html on port 80. 
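+(If binding port 80 requires elevated privileges on your system, pass
+e.g. -p 8880 and adjust the URL below accordingly.)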
Open
+http://localhost/console.html, click the connect button, type something into
+the text box next to the send button and click the send button. If everything
+is working, you'll see the message you typed echoed by the server.
+
+
+SUPPORTING TLS
+
+To support TLS, run standalone.py with -t, -k, and -c options.
+
+
+SUPPORTING CLIENT AUTHENTICATION
+
+To support client authentication with TLS, run standalone.py with -t, -k, -c,
+--tls-client-auth, and --tls-client-ca options.
+
+E.g., $./standalone.py -d ../example -p 10443 -t -c ../test/cert/cert.pem -k
+../test/cert/key.pem --tls-client-auth --tls-client-ca=../test/cert/cacert.pem
+
+
+CONFIGURATION FILE
+
+You can also write a configuration file and use it by specifying the path to
+the configuration file with the --config option. Please write the
+configuration file following the documentation of the Python ConfigParser
+library. The name of each entry must be the long version of the argument
+name. E.g. to set the log level to debug, add the following line:
+
+log_level=debug
+
+For options which don't take a value, please add some fake value. E.g. for
+the --tls option, add the following line:
+
+tls=True
+
+Note that tls will be enabled even if you write tls=False, as the value part
+is ignored.
+
+When both a command line argument and a configuration file entry are set for
+the same configuration item, the command line value will override the one in
+the configuration file.
+
+
+THREADING
+
+This server is derived from SocketServer.ThreadingMixIn. Hence a thread is
+used for each request.
+
+
+SECURITY WARNING
+
+This uses CGIHTTPServer, and CGIHTTPServer is not secure.
+It may execute arbitrary Python code or external programs. It should not be
+used outside a firewall.
+"""
+
+import BaseHTTPServer
+import CGIHTTPServer
+import SimpleHTTPServer
+import SocketServer
+import ConfigParser
+import base64
+import httplib
+import logging
+import logging.handlers
+import optparse
+import os
+import re
+import select
+import socket
+import sys
+import threading
+import time
+
+_HAS_SSL = False
+_HAS_OPEN_SSL = False
+try:
+ import ssl
+ _HAS_SSL = True
+except ImportError:
+ try:
+ import OpenSSL.SSL
+ _HAS_OPEN_SSL = True
+ except ImportError:
+ pass
+
+from mod_pywebsocket import common
+from mod_pywebsocket import dispatch
+from mod_pywebsocket import handshake
+from mod_pywebsocket import http_header_util
+from mod_pywebsocket import memorizingfile
+from mod_pywebsocket import util
+
+
+_DEFAULT_LOG_MAX_BYTES = 1024 * 256
+_DEFAULT_LOG_BACKUP_COUNT = 5
+
+_DEFAULT_REQUEST_QUEUE_SIZE = 128
+
+# 1024 is practically large enough to contain WebSocket handshake lines.
+_MAX_MEMORIZED_LINES = 1024
+
+
+class _StandaloneConnection(object):
+ """Mimic mod_python mp_conn."""
+
+ def __init__(self, request_handler):
+ """Construct an instance.
+
+ Args:
+ request_handler: A WebSocketRequestHandler instance.
+ """
+
+ self._request_handler = request_handler
+
+ def get_local_addr(self):
+ """Getter to mimic mp_conn.local_addr."""
+
+ return (self._request_handler.server.server_name,
+ self._request_handler.server.server_port)
+ local_addr = property(get_local_addr)
+
+ def get_remote_addr(self):
+ """Getter to mimic mp_conn.remote_addr.
+ + Setting the property in __init__ won't work because the request + handler is not initialized yet there.""" + + return self._request_handler.client_address + remote_addr = property(get_remote_addr) + + def write(self, data): + """Mimic mp_conn.write().""" + + return self._request_handler.wfile.write(data) + + def read(self, length): + """Mimic mp_conn.read().""" + + return self._request_handler.rfile.read(length) + + def get_memorized_lines(self): + """Get memorized lines.""" + + return self._request_handler.rfile.get_memorized_lines() + + +class _StandaloneRequest(object): + """Mimic mod_python request.""" + + def __init__(self, request_handler, use_tls): + """Construct an instance. + + Args: + request_handler: A WebSocketRequestHandler instance. + """ + + self._logger = util.get_class_logger(self) + + self._request_handler = request_handler + self.connection = _StandaloneConnection(request_handler) + self._use_tls = use_tls + self.headers_in = request_handler.headers + + def get_uri(self): + """Getter to mimic request.uri.""" + + return self._request_handler.path + uri = property(get_uri) + + def get_method(self): + """Getter to mimic request.method.""" + + return self._request_handler.command + method = property(get_method) + + def get_protocol(self): + """Getter to mimic request.protocol.""" + + return self._request_handler.request_version + protocol = property(get_protocol) + + def is_https(self): + """Mimic request.is_https().""" + + return self._use_tls + + def _drain_received_data(self): + """Don't use this method from WebSocket handler. Drains unread data + in the receive buffer. + """ + + raw_socket = self._request_handler.connection + drained_data = util.drain_received_data(raw_socket) + + if drained_data: + self._logger.debug( + 'Drained data following close frame: %r', drained_data) + + +class _StandaloneSSLConnection(object): + """A wrapper class for OpenSSL.SSL.Connection to provide makefile method + which is not supported by the class. + """ + + def __init__(self, connection): + self._connection = connection + + def __getattribute__(self, name): + if name in ('_connection', 'makefile'): + return object.__getattribute__(self, name) + return self._connection.__getattribute__(name) + + def __setattr__(self, name, value): + if name in ('_connection', 'makefile'): + return object.__setattr__(self, name, value) + return self._connection.__setattr__(name, value) + + def makefile(self, mode='r', bufsize=-1): + return socket._fileobject(self._connection, mode, bufsize) + + +def _alias_handlers(dispatcher, websock_handlers_map_file): + """Set aliases specified in websock_handler_map_file in dispatcher. 
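+
+ Each non-comment, non-blank line of the map file has the form, for
+ illustration:
+
+ /alias/resource/path /existing/resource/path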
+
+ Args:
+ dispatcher: dispatch.Dispatcher instance
+ websock_handlers_map_file: alias map file
+ """
+
+ fp = open(websock_handlers_map_file)
+ try:
+ for line in fp:
+ if line[0] == '#' or line.isspace():
+ continue
+ m = re.match('(\S+)\s+(\S+)', line)
+ if not m:
+ logging.warning('Wrong format in map file: ' + line)
+ continue
+ try:
+ dispatcher.add_resource_path_alias(
+ m.group(1), m.group(2))
+ except dispatch.DispatchException, e:
+ logging.error(str(e))
+ finally:
+ fp.close()
+
+
+class WebSocketServer(SocketServer.ThreadingMixIn, BaseHTTPServer.HTTPServer):
+ """HTTPServer specialized for WebSocket."""
+
+ # Overrides SocketServer.ThreadingMixIn.daemon_threads
+ daemon_threads = True
+ # Overrides BaseHTTPServer.HTTPServer.allow_reuse_address
+ allow_reuse_address = True
+
+ def __init__(self, options):
+ """Override SocketServer.TCPServer.__init__ to set an SSL-enabled
+ socket object to self.socket before server_bind and server_activate,
+ if necessary.
+ """
+
+ # Share a Dispatcher among request handlers to save time for
+ # instantiation. Dispatcher can be shared because it is thread-safe.
+ options.dispatcher = dispatch.Dispatcher(
+ options.websock_handlers,
+ options.scan_dir,
+ options.allow_handlers_outside_root_dir)
+ if options.websock_handlers_map_file:
+ _alias_handlers(options.dispatcher,
+ options.websock_handlers_map_file)
+ warnings = options.dispatcher.source_warnings()
+ if warnings:
+ for warning in warnings:
+ logging.warning('mod_pywebsocket: %s' % warning)
+
+ self._logger = util.get_class_logger(self)
+
+ self.request_queue_size = options.request_queue_size
+ self.__ws_is_shut_down = threading.Event()
+ self.__ws_serving = False
+
+ SocketServer.BaseServer.__init__(
+ self, (options.server_host, options.port), WebSocketRequestHandler)
+
+ # Expose the options object to allow handler objects to access it.
+ # We name it with a websocket_ prefix to avoid conflict.
+ self.websocket_server_options = options
+
+ self._create_sockets()
+ self.server_bind()
+ self.server_activate()
+
+ def _create_sockets(self):
+ self.server_name, self.server_port = self.server_address
+ self._sockets = []
+ if not self.server_name:
+ # On platforms that don't support IPv6, the first bind fails.
+ # On platforms that support IPv6:
+ # - If it binds both IPv4 and IPv6 on a call with AF_INET6, the
+ # first bind succeeds and the second fails (we'll see an
+ # 'Address already in use' error).
+ # - If it binds only IPv6 on a call with AF_INET6, both calls
+ # are expected to succeed, listening on both protocols.
+ addrinfo_array = [
+ (socket.AF_INET6, socket.SOCK_STREAM, '', '', ''),
+ (socket.AF_INET, socket.SOCK_STREAM, '', '', '')]
+ else:
+ addrinfo_array = socket.getaddrinfo(self.server_name,
+ self.server_port,
+ socket.AF_UNSPEC,
+ socket.SOCK_STREAM,
+ socket.IPPROTO_TCP)
+ for addrinfo in addrinfo_array:
+ self._logger.info('Create socket on: %r', addrinfo)
+ family, socktype, proto, canonname, sockaddr = addrinfo
+ try:
+ socket_ = socket.socket(family, socktype)
+ except Exception, e:
+ self._logger.info('Skip by failure: %r', e)
+ continue
+ if self.websocket_server_options.use_tls:
+ if _HAS_SSL:
+ if self.websocket_server_options.tls_client_auth:
+ client_cert_ = ssl.CERT_REQUIRED
+ else:
+ client_cert_ = ssl.CERT_NONE
+ socket_ = ssl.wrap_socket(socket_,
+ keyfile=self.websocket_server_options.private_key,
+ certfile=self.websocket_server_options.certificate,
+ ssl_version=ssl.PROTOCOL_SSLv23,
+ ca_certs=self.websocket_server_options.tls_client_ca,
+ cert_reqs=client_cert_)
+ if _HAS_OPEN_SSL:
+ ctx = OpenSSL.SSL.Context(OpenSSL.SSL.SSLv23_METHOD)
+ ctx.use_privatekey_file(
+ self.websocket_server_options.private_key)
+ ctx.use_certificate_file(
+ self.websocket_server_options.certificate)
+ socket_ = OpenSSL.SSL.Connection(ctx, socket_)
+ self._sockets.append((socket_, addrinfo))
+
+ def server_bind(self):
+ """Override SocketServer.TCPServer.server_bind to bind multiple
+ sockets.
+ """
+
+ failed_sockets = []
+
+ for socketinfo in self._sockets:
+ socket_, addrinfo = socketinfo
+ self._logger.info('Bind on: %r', addrinfo)
+ if self.allow_reuse_address:
+ socket_.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
+ try:
+ socket_.bind(self.server_address)
+ except Exception, e:
+ self._logger.info('Skip by failure: %r', e)
+ socket_.close()
+ failed_sockets.append(socketinfo)
+ # Don't touch the closed socket below.
+ continue
+ if self.server_address[1] == 0:
+ # The operating system assigns the actual port number for port
+ # number 0. In this case, the second and later sockets should
+ # use the same port number. Also self.server_port is rewritten
+ # because it is exported, and will be used by external code.
+ self.server_address = (
+ self.server_name, socket_.getsockname()[1])
+ self.server_port = self.server_address[1]
+ self._logger.info('Port %r is assigned', self.server_port)
+
+ for socketinfo in failed_sockets:
+ self._sockets.remove(socketinfo)
+
+ def server_activate(self):
+ """Override SocketServer.TCPServer.server_activate to listen on
+ multiple sockets.
+ """
+
+ failed_sockets = []
+
+ for socketinfo in self._sockets:
+ socket_, addrinfo = socketinfo
+ self._logger.info('Listen on: %r', addrinfo)
+ try:
+ socket_.listen(self.request_queue_size)
+ except Exception, e:
+ self._logger.info('Skip by failure: %r', e)
+ socket_.close()
+ failed_sockets.append(socketinfo)
+
+ for socketinfo in failed_sockets:
+ self._sockets.remove(socketinfo)
+
+ if len(self._sockets) == 0:
+ self._logger.critical(
+ 'No sockets activated. Use info log level to see the reason.')
+
+ def server_close(self):
+ """Override SocketServer.TCPServer.server_close to close multiple
+ sockets.
+ """ + + for socketinfo in self._sockets: + socket_, addrinfo = socketinfo + self._logger.info('Close on: %r', addrinfo) + socket_.close() + + def fileno(self): + """Override SocketServer.TCPServer.fileno.""" + + self._logger.critical('Not supported: fileno') + return self._sockets[0][0].fileno() + + def handle_error(self, rquest, client_address): + """Override SocketServer.handle_error.""" + + self._logger.error( + 'Exception in processing request from: %r\n%s', + client_address, + util.get_stack_trace()) + # Note: client_address is a tuple. + + def get_request(self): + """Override TCPServer.get_request to wrap OpenSSL.SSL.Connection + object with _StandaloneSSLConnection to provide makefile method. We + cannot substitute OpenSSL.SSL.Connection.makefile since it's readonly + attribute. + """ + + accepted_socket, client_address = self.socket.accept() + if self.websocket_server_options.use_tls and _HAS_OPEN_SSL: + accepted_socket = _StandaloneSSLConnection(accepted_socket) + return accepted_socket, client_address + + def serve_forever(self, poll_interval=0.5): + """Override SocketServer.BaseServer.serve_forever.""" + + self.__ws_serving = True + self.__ws_is_shut_down.clear() + handle_request = self.handle_request + if hasattr(self, '_handle_request_noblock'): + handle_request = self._handle_request_noblock + else: + self._logger.warning('Fallback to blocking request handler') + try: + while self.__ws_serving: + r, w, e = select.select( + [socket_[0] for socket_ in self._sockets], + [], [], poll_interval) + for socket_ in r: + self.socket = socket_ + handle_request() + self.socket = None + finally: + self.__ws_is_shut_down.set() + + def shutdown(self): + """Override SocketServer.BaseServer.shutdown.""" + + self.__ws_serving = False + self.__ws_is_shut_down.wait() + + +class WebSocketRequestHandler(CGIHTTPServer.CGIHTTPRequestHandler): + """CGIHTTPRequestHandler specialized for WebSocket.""" + + # Use httplib.HTTPMessage instead of mimetools.Message. + MessageClass = httplib.HTTPMessage + + def setup(self): + """Override SocketServer.StreamRequestHandler.setup to wrap rfile + with MemorizingFile. + + This method will be called by BaseRequestHandler's constructor + before calling BaseHTTPRequestHandler.handle. + BaseHTTPRequestHandler.handle will call + BaseHTTPRequestHandler.handle_one_request and it will call + WebSocketRequestHandler.parse_request. + """ + + # Call superclass's setup to prepare rfile, wfile, etc. See setup + # definition on the root class SocketServer.StreamRequestHandler to + # understand what this does. + CGIHTTPServer.CGIHTTPRequestHandler.setup(self) + + self.rfile = memorizingfile.MemorizingFile( + self.rfile, + max_memorized_lines=_MAX_MEMORIZED_LINES) + + def __init__(self, request, client_address, server): + self._logger = util.get_class_logger(self) + + self._options = server.websocket_server_options + + # Overrides CGIHTTPServerRequestHandler.cgi_directories. + self.cgi_directories = self._options.cgi_directories + # Replace CGIHTTPRequestHandler.is_executable method. + if self._options.is_executable_method is not None: + self.is_executable = self._options.is_executable_method + + # This actually calls BaseRequestHandler.__init__. + CGIHTTPServer.CGIHTTPRequestHandler.__init__( + self, request, client_address, server) + + def parse_request(self): + """Override BaseHTTPServer.BaseHTTPRequestHandler.parse_request. + + Return True to continue processing for HTTP(S), False otherwise. 
+ + See BaseHTTPRequestHandler.handle_one_request method which calls + this method to understand how the return value will be handled. + """ + + # We hook parse_request method, but also call the original + # CGIHTTPRequestHandler.parse_request since when we return False, + # CGIHTTPRequestHandler.handle_one_request continues processing and + # it needs variables set by CGIHTTPRequestHandler.parse_request. + # + # Variables set by this method will be also used by WebSocket request + # handling (self.path, self.command, self.requestline, etc. See also + # how _StandaloneRequest's members are implemented using these + # attributes). + if not CGIHTTPServer.CGIHTTPRequestHandler.parse_request(self): + return False + + if self._options.use_basic_auth: + auth = self.headers.getheader('Authorization') + if auth != self._options.basic_auth_credential: + self.send_response(401) + self.send_header('WWW-Authenticate', + 'Basic realm="Pywebsocket"') + self.end_headers() + self._logger.info('Request basic authentication') + return True + + host, port, resource = http_header_util.parse_uri(self.path) + if resource is None: + self._logger.info('Invalid URI: %r', self.path) + self._logger.info('Fallback to CGIHTTPRequestHandler') + return True + server_options = self.server.websocket_server_options + if host is not None: + validation_host = server_options.validation_host + if validation_host is not None and host != validation_host: + self._logger.info('Invalid host: %r (expected: %r)', + host, + validation_host) + self._logger.info('Fallback to CGIHTTPRequestHandler') + return True + if port is not None: + validation_port = server_options.validation_port + if validation_port is not None and port != validation_port: + self._logger.info('Invalid port: %r (expected: %r)', + port, + validation_port) + self._logger.info('Fallback to CGIHTTPRequestHandler') + return True + self.path = resource + + request = _StandaloneRequest(self, self._options.use_tls) + + try: + # Fallback to default http handler for request paths for which + # we don't have request handlers. + if not self._options.dispatcher.get_handler_suite(self.path): + self._logger.info('No handler for resource: %r', + self.path) + self._logger.info('Fallback to CGIHTTPRequestHandler') + return True + except dispatch.DispatchException, e: + self._logger.info('%s', e) + self.send_error(e.status) + return False + + # If any Exceptions without except clause setup (including + # DispatchException) is raised below this point, it will be caught + # and logged by WebSocketServer. + + try: + try: + handshake.do_handshake( + request, + self._options.dispatcher, + allowDraft75=self._options.allow_draft75, + strict=self._options.strict) + except handshake.VersionException, e: + self._logger.info('%s', e) + self.send_response(common.HTTP_STATUS_BAD_REQUEST) + self.send_header(common.SEC_WEBSOCKET_VERSION_HEADER, + e.supported_versions) + self.end_headers() + return False + except handshake.HandshakeException, e: + # Handshake for ws(s) failed. 
+ self._logger.info('%s', e)
+ self.send_error(e.status)
+ return False
+
+ request._dispatcher = self._options.dispatcher
+ self._options.dispatcher.transfer_data(request)
+ except handshake.AbortedByUserException, e:
+ self._logger.info('%s', e)
+ return False
+
+ def log_request(self, code='-', size='-'):
+ """Override BaseHTTPServer.log_request."""
+
+ self._logger.info('"%s" %s %s',
+ self.requestline, str(code), str(size))
+
+ def log_error(self, *args):
+ """Override BaseHTTPServer.log_error."""
+
+ # Despite the name, this method is more for warnings than for
+ # errors. For example, HTTP status codes are logged by this method.
+ self._logger.warning('%s - %s',
+ self.address_string(),
+ args[0] % args[1:])
+
+ def is_cgi(self):
+ """Test whether self.path corresponds to a CGI script.
+
+ Adds an extra check that self.path doesn't contain '..'.
+ Also checks whether the file is executable. If the file is not
+ executable, it is handled as a static file or directory rather
+ than a CGI script.
+ """
+
+ if CGIHTTPServer.CGIHTTPRequestHandler.is_cgi(self):
+ if '..' in self.path:
+ return False
+ # Strip the query parameter from the request path.
+ resource_name = self.path.split('?', 2)[0]
+ # Convert resource_name into a real path name in the filesystem.
+ scriptfile = self.translate_path(resource_name)
+ if not os.path.isfile(scriptfile):
+ return False
+ if not self.is_executable(scriptfile):
+ return False
+ return True
+ return False
+
+
+def _get_logger_from_class(c):
+ return logging.getLogger('%s.%s' % (c.__module__, c.__name__))
+
+
+def _configure_logging(options):
+ logging.addLevelName(common.LOGLEVEL_FINE, 'FINE')
+
+ logger = logging.getLogger()
+ logger.setLevel(logging.getLevelName(options.log_level.upper()))
+ if options.log_file:
+ handler = logging.handlers.RotatingFileHandler(
+ options.log_file, 'a', options.log_max, options.log_count)
+ else:
+ handler = logging.StreamHandler()
+ formatter = logging.Formatter(
+ '[%(asctime)s] [%(levelname)s] %(name)s: %(message)s')
+ handler.setFormatter(formatter)
+ logger.addHandler(handler)
+
+ deflate_log_level_name = logging.getLevelName(
+ options.deflate_log_level.upper())
+ _get_logger_from_class(util._Deflater).setLevel(
+ deflate_log_level_name)
+ _get_logger_from_class(util._Inflater).setLevel(
+ deflate_log_level_name)
+
+
+def _build_option_parser():
+ parser = optparse.OptionParser()
+
+ parser.add_option('--config', dest='config_file', type='string',
+ default=None,
+ help=('Path to configuration file. See the file comment '
+ 'at the top of this file for the configuration '
+ 'file format'))
+ parser.add_option('-H', '--server-host', '--server_host',
+ dest='server_host',
+ default='',
+ help='server hostname to listen to')
+ parser.add_option('-V', '--validation-host', '--validation_host',
+ dest='validation_host',
+ default=None,
+ help='server hostname to validate in absolute path.')
+ parser.add_option('-p', '--port', dest='port', type='int',
+ default=common.DEFAULT_WEB_SOCKET_PORT,
+ help='port to listen to')
+ parser.add_option('-P', '--validation-port', '--validation_port',
+ dest='validation_port', type='int',
+ default=None,
+ help='server port to validate in absolute path.')
+ parser.add_option('-w', '--websock-handlers', '--websock_handlers',
+ dest='websock_handlers',
+ default='.',
+ help=('The root directory of WebSocket handler files.
'
+ 'If the path is relative, --document-root is used '
+ 'as the base.'))
+ parser.add_option('-m', '--websock-handlers-map-file',
+ '--websock_handlers_map_file',
+ dest='websock_handlers_map_file',
+ default=None,
+ help=('WebSocket handlers map file. '
+ 'Each line consists of alias_resource_path and '
+ 'existing_resource_path, separated by spaces.'))
+ parser.add_option('-s', '--scan-dir', '--scan_dir', dest='scan_dir',
+ default=None,
+ help=('Must be a directory under --websock-handlers. '
+ 'Only handlers under this directory are scanned '
+ 'and registered to the server. '
+ 'Useful for saving scan time when the handler '
+ 'root directory contains lots of files that are '
+ 'not handler files, or are handler files that '
+ 'you don\'t want to be registered. '))
+ parser.add_option('--allow-handlers-outside-root-dir',
+ '--allow_handlers_outside_root_dir',
+ dest='allow_handlers_outside_root_dir',
+ action='store_true',
+ default=False,
+ help=('Scans WebSocket handlers even if their canonical '
+ 'path is not under --websock-handlers.'))
+ parser.add_option('-d', '--document-root', '--document_root',
+ dest='document_root', default='.',
+ help='Document root directory.')
+ parser.add_option('-x', '--cgi-paths', '--cgi_paths', dest='cgi_paths',
+ default=None,
+ help=('CGI paths relative to document_root. '
+ 'Comma-separated. (e.g. -x /cgi,/htbin) '
+ 'Files under document_root/cgi_path are handled '
+ 'as CGI programs. Must be executable.'))
+ parser.add_option('-t', '--tls', dest='use_tls', action='store_true',
+ default=False, help='use TLS (wss://)')
+ parser.add_option('-k', '--private-key', '--private_key',
+ dest='private_key',
+ default='', help='TLS private key file.')
+ parser.add_option('-c', '--certificate', dest='certificate',
+ default='', help='TLS certificate file.')
+ parser.add_option('--tls-client-auth', dest='tls_client_auth',
+ action='store_true', default=False,
+ help='Requires TLS client auth on every connection.')
+ parser.add_option('--tls-client-ca', dest='tls_client_ca', default='',
+ help=('Specifies a PEM file which contains a set of '
+ 'concatenated CA certificates which are used to '
+ 'validate certificates passed from clients'))
+ parser.add_option('--basic-auth', dest='use_basic_auth',
+ action='store_true', default=False,
+ help='Requires Basic authentication.')
+ parser.add_option('--basic-auth-credential',
+ dest='basic_auth_credential', default='test:test',
+ help='Specifies the credential of basic authentication '
+ 'by username:password pair (e.g. test:test).')
+ parser.add_option('-l', '--log-file', '--log_file', dest='log_file',
+ default='', help='Log file.')
+ # Custom log level:
+ # - FINE: Prints status of each frame processing step
+ parser.add_option('--log-level', '--log_level', type='choice',
+ dest='log_level', default='warn',
+ choices=['fine',
+ 'debug', 'info', 'warning', 'warn', 'error',
+ 'critical'],
+ help='Log level.')
+ parser.add_option('--deflate-log-level', '--deflate_log_level',
+ type='choice',
+ dest='deflate_log_level', default='warn',
+ choices=['debug', 'info', 'warning', 'warn', 'error',
+ 'critical'],
+ help='Log level for _Deflater and _Inflater.')
+ parser.add_option('--thread-monitor-interval-in-sec',
+ '--thread_monitor_interval_in_sec',
+ dest='thread_monitor_interval_in_sec',
+ type='int', default=-1,
+ help=('If positive integer is specified, run a thread '
+ 'monitor to show the status of server threads '
+ 'periodically in the specified interval in '
+ 'seconds. 
If non-positive integer is specified, ' + 'disable the thread monitor.')) + parser.add_option('--log-max', '--log_max', dest='log_max', type='int', + default=_DEFAULT_LOG_MAX_BYTES, + help='Log maximum bytes') + parser.add_option('--log-count', '--log_count', dest='log_count', + type='int', default=_DEFAULT_LOG_BACKUP_COUNT, + help='Log backup count') + parser.add_option('--allow-draft75', dest='allow_draft75', + action='store_true', default=False, + help='Obsolete option. Ignored.') + parser.add_option('--strict', dest='strict', action='store_true', + default=False, help='Obsolete option. Ignored.') + parser.add_option('-q', '--queue', dest='request_queue_size', type='int', + default=_DEFAULT_REQUEST_QUEUE_SIZE, + help='request queue size') + + return parser + + +class ThreadMonitor(threading.Thread): + daemon = True + + def __init__(self, interval_in_sec): + threading.Thread.__init__(self, name='ThreadMonitor') + + self._logger = util.get_class_logger(self) + + self._interval_in_sec = interval_in_sec + + def run(self): + while True: + thread_name_list = [] + for thread in threading.enumerate(): + thread_name_list.append(thread.name) + self._logger.info( + "%d active threads: %s", + threading.active_count(), + ', '.join(thread_name_list)) + time.sleep(self._interval_in_sec) + + +def _parse_args_and_config(args): + parser = _build_option_parser() + + # First, parse options without configuration file. + temporary_options, temporary_args = parser.parse_args(args=args) + if temporary_args: + logging.critical( + 'Unrecognized positional arguments: %r', temporary_args) + sys.exit(1) + + if temporary_options.config_file: + try: + config_fp = open(temporary_options.config_file, 'r') + except IOError, e: + logging.critical( + 'Failed to open configuration file %r: %r', + temporary_options.config_file, + e) + sys.exit(1) + + config_parser = ConfigParser.SafeConfigParser() + config_parser.readfp(config_fp) + config_fp.close() + + args_from_config = [] + for name, value in config_parser.items('pywebsocket'): + args_from_config.append('--' + name) + args_from_config.append(value) + if args is None: + args = args_from_config + else: + args = args_from_config + args + return parser.parse_args(args=args) + else: + return temporary_options, temporary_args + + +def _main(args=None): + """You can call this function from your own program, but please note that + this function has some side-effects that might affect your program. For + example, util.wrap_popen3_for_win use in this method replaces implementation + of os.popen3. + """ + + options, args = _parse_args_and_config(args=args) + + os.chdir(options.document_root) + + _configure_logging(options) + + # TODO(tyoshino): Clean up initialization of CGI related values. Move some + # of code here to WebSocketRequestHandler class if it's better. + options.cgi_directories = [] + options.is_executable_method = None + if options.cgi_paths: + options.cgi_directories = options.cgi_paths.split(',') + if sys.platform in ('cygwin', 'win32'): + cygwin_path = None + # For Win32 Python, it is expected that CYGWIN_PATH + # is set to a directory of cygwin binaries. + # For example, websocket_server.py in Chromium sets CYGWIN_PATH to + # full path of third_party/cygwin/bin. 
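+            # Editor's note (illustrative, not upstream): a typical Win32
+            # invocation might first do
+            #   set CYGWIN_PATH=C:\cygwin\bin
+            # so that a handler's "#!/usr/bin/perl" line is rewritten to run
+            # the perl binary under C:\cygwin\bin via
+            # util.wrap_popen3_for_win below.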
+ if 'CYGWIN_PATH' in os.environ: + cygwin_path = os.environ['CYGWIN_PATH'] + util.wrap_popen3_for_win(cygwin_path) + + def __check_script(scriptpath): + return util.get_script_interp(scriptpath, cygwin_path) + + options.is_executable_method = __check_script + + if options.use_tls: + if not (_HAS_SSL or _HAS_OPEN_SSL): + logging.critical('TLS support requires ssl or pyOpenSSL module.') + sys.exit(1) + if not options.private_key or not options.certificate: + logging.critical( + 'To use TLS, specify private_key and certificate.') + sys.exit(1) + + if options.tls_client_auth: + if not options.use_tls: + logging.critical('TLS must be enabled for client authentication.') + sys.exit(1) + if not _HAS_SSL: + logging.critical('Client authentication requires ssl module.') + + if not options.scan_dir: + options.scan_dir = options.websock_handlers + + if options.use_basic_auth: + options.basic_auth_credential = 'Basic ' + base64.b64encode( + options.basic_auth_credential) + + try: + if options.thread_monitor_interval_in_sec > 0: + # Run a thread monitor to show the status of server threads for + # debugging. + ThreadMonitor(options.thread_monitor_interval_in_sec).start() + + server = WebSocketServer(options) + server.serve_forever() + except Exception, e: + logging.critical('mod_pywebsocket: %s' % e) + logging.critical('mod_pywebsocket: %s' % util.get_stack_trace()) + sys.exit(1) + + +if __name__ == '__main__': + _main(sys.argv[1:]) + + +# vi:sts=4 sw=4 et diff --git a/pyload/lib/mod_pywebsocket/stream.py b/pyload/lib/mod_pywebsocket/stream.py new file mode 100644 index 000000000..edc533279 --- /dev/null +++ b/pyload/lib/mod_pywebsocket/stream.py @@ -0,0 +1,57 @@ +# Copyright 2011, Google Inc. +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: +# +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above +# copyright notice, this list of conditions and the following disclaimer +# in the documentation and/or other materials provided with the +# distribution. +# * Neither the name of Google Inc. nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + + +"""This file exports public symbols. 
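+
+Illustrative usage (editor's sketch, not from upstream docs): the hybi Stream
+and StreamOptions re-exported below can be driven roughly like
+
+    from mod_pywebsocket.stream import Stream, StreamOptions
+
+    options = StreamOptions()
+    options.mask_send = False      # server-to-client frames are unmasked
+    options.unmask_receive = True  # client-to-server frames arrive masked
+    stream = Stream(handshaken_request, options)
+    stream.send_message(u'hello')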
+""" + + +from mod_pywebsocket._stream_base import BadOperationException +from mod_pywebsocket._stream_base import ConnectionTerminatedException +from mod_pywebsocket._stream_base import InvalidFrameException +from mod_pywebsocket._stream_base import InvalidUTF8Exception +from mod_pywebsocket._stream_base import UnsupportedFrameException +from mod_pywebsocket._stream_hixie75 import StreamHixie75 +from mod_pywebsocket._stream_hybi import Frame +from mod_pywebsocket._stream_hybi import Stream +from mod_pywebsocket._stream_hybi import StreamOptions + +# These methods are intended to be used by WebSocket client developers to have +# their implementations receive broken data in tests. +from mod_pywebsocket._stream_hybi import create_close_frame +from mod_pywebsocket._stream_hybi import create_header +from mod_pywebsocket._stream_hybi import create_length_header +from mod_pywebsocket._stream_hybi import create_ping_frame +from mod_pywebsocket._stream_hybi import create_pong_frame +from mod_pywebsocket._stream_hybi import create_binary_frame +from mod_pywebsocket._stream_hybi import create_text_frame +from mod_pywebsocket._stream_hybi import create_closing_handshake_body + + +# vi:sts=4 sw=4 et diff --git a/pyload/lib/mod_pywebsocket/util.py b/pyload/lib/mod_pywebsocket/util.py new file mode 100644 index 000000000..7bb0b5d9e --- /dev/null +++ b/pyload/lib/mod_pywebsocket/util.py @@ -0,0 +1,515 @@ +# Copyright 2011, Google Inc. +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: +# +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above +# copyright notice, this list of conditions and the following disclaimer +# in the documentation and/or other materials provided with the +# distribution. +# * Neither the name of Google Inc. nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + + +"""WebSocket utilities. +""" + + +import array +import errno + +# Import hash classes from a module available and recommended for each Python +# version and re-export those symbol. Use sha and md5 module in Python 2.4, and +# hashlib module in Python 2.6. 
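+#
+# Editor's note (illustrative): the try/except below lets the rest of the
+# package hash uniformly across Python versions, e.g. computing a
+# Sec-WebSocket-Accept style value as
+#
+#   accept = sha1_hash(key + guid).digest()
+#
+# regardless of whether hashlib or the legacy sha module provided sha1_hash.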
+try: + import hashlib + md5_hash = hashlib.md5 + sha1_hash = hashlib.sha1 +except ImportError: + import md5 + import sha + md5_hash = md5.md5 + sha1_hash = sha.sha + +import StringIO +import logging +import os +import re +import socket +import traceback +import zlib + + +def get_stack_trace(): + """Get the current stack trace as string. + + This is needed to support Python 2.3. + TODO: Remove this when we only support Python 2.4 and above. + Use traceback.format_exc instead. + """ + + out = StringIO.StringIO() + traceback.print_exc(file=out) + return out.getvalue() + + +def prepend_message_to_exception(message, exc): + """Prepend message to the exception.""" + + exc.args = (message + str(exc),) + return + + +def __translate_interp(interp, cygwin_path): + """Translate interp program path for Win32 python to run cygwin program + (e.g. perl). Note that it doesn't support path that contains space, + which is typically true for Unix, where #!-script is written. + For Win32 python, cygwin_path is a directory of cygwin binaries. + + Args: + interp: interp command line + cygwin_path: directory name of cygwin binary, or None + Returns: + translated interp command line. + """ + if not cygwin_path: + return interp + m = re.match('^[^ ]*/([^ ]+)( .*)?', interp) + if m: + cmd = os.path.join(cygwin_path, m.group(1)) + return cmd + m.group(2) + return interp + + +def get_script_interp(script_path, cygwin_path=None): + """Gets #!-interpreter command line from the script. + + It also fixes command path. When Cygwin Python is used, e.g. in WebKit, + it could run "/usr/bin/perl -wT hello.pl". + When Win32 Python is used, e.g. in Chromium, it couldn't. So, fix + "/usr/bin/perl" to "<cygwin_path>\perl.exe". + + Args: + script_path: pathname of the script + cygwin_path: directory name of cygwin binary, or None + Returns: + #!-interpreter command line, or None if it is not #!-script. + """ + fp = open(script_path) + line = fp.readline() + fp.close() + m = re.match('^#!(.*)', line) + if m: + return __translate_interp(m.group(1), cygwin_path) + return None + + +def wrap_popen3_for_win(cygwin_path): + """Wrap popen3 to support #!-script on Windows. + + Args: + cygwin_path: path for cygwin binary if command path is needed to be + translated. None if no translation required. + """ + + __orig_popen3 = os.popen3 + + def __wrap_popen3(cmd, mode='t', bufsize=-1): + cmdline = cmd.split(' ') + interp = get_script_interp(cmdline[0], cygwin_path) + if interp: + cmd = interp + ' ' + cmd + return __orig_popen3(cmd, mode, bufsize) + + os.popen3 = __wrap_popen3 + + +def hexify(s): + return ' '.join(map(lambda x: '%02x' % ord(x), s)) + + +def get_class_logger(o): + return logging.getLogger( + '%s.%s' % (o.__class__.__module__, o.__class__.__name__)) + + +class NoopMasker(object): + """A masking object that has the same interface as RepeatedXorMasker but + just returns the string passed in without making any change. + """ + + def __init__(self): + pass + + def mask(self, s): + return s + + +class RepeatedXorMasker(object): + """A masking object that applies XOR on the string given to mask method + with the masking bytes given to the constructor repeatedly. This object + remembers the position in the masking bytes the last mask method call + ended and resumes from that point on the next mask method call. 
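+
+    Illustrative example (editor's addition): with the 4-byte mask
+    "\\x00\\x01\\x00\\x01", masking "abcd" yields "acce" (the low bit of the
+    2nd and 4th bytes is flipped), and the next mask() call resumes at
+    offset 0 only because exactly one full mask cycle was consumed.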
+ """ + + def __init__(self, mask): + self._mask = map(ord, mask) + self._mask_size = len(self._mask) + self._count = 0 + + def mask(self, s): + result = array.array('B') + result.fromstring(s) + # Use temporary local variables to eliminate the cost to access + # attributes + count = self._count + mask = self._mask + mask_size = self._mask_size + for i in xrange(len(result)): + result[i] ^= mask[count] + count = (count + 1) % mask_size + self._count = count + + return result.tostring() + + +class DeflateRequest(object): + """A wrapper class for request object to intercept send and recv to perform + deflate compression and decompression transparently. + """ + + def __init__(self, request): + self._request = request + self.connection = DeflateConnection(request.connection) + + def __getattribute__(self, name): + if name in ('_request', 'connection'): + return object.__getattribute__(self, name) + return self._request.__getattribute__(name) + + def __setattr__(self, name, value): + if name in ('_request', 'connection'): + return object.__setattr__(self, name, value) + return self._request.__setattr__(name, value) + + +# By making wbits option negative, we can suppress CMF/FLG (2 octet) and +# ADLER32 (4 octet) fields of zlib so that we can use zlib module just as +# deflate library. DICTID won't be added as far as we don't set dictionary. +# LZ77 window of 32K will be used for both compression and decompression. +# For decompression, we can just use 32K to cover any windows size. For +# compression, we use 32K so receivers must use 32K. +# +# Compression level is Z_DEFAULT_COMPRESSION. We don't have to match level +# to decode. +# +# See zconf.h, deflate.cc, inflate.cc of zlib library, and zlibmodule.c of +# Python. See also RFC1950 (ZLIB 3.3). + + +class _Deflater(object): + + def __init__(self, window_bits): + self._logger = get_class_logger(self) + + self._compress = zlib.compressobj( + zlib.Z_DEFAULT_COMPRESSION, zlib.DEFLATED, -window_bits) + + def compress(self, bytes): + compressed_bytes = self._compress.compress(bytes) + self._logger.debug('Compress input %r', bytes) + self._logger.debug('Compress result %r', compressed_bytes) + return compressed_bytes + + def compress_and_flush(self, bytes): + compressed_bytes = self._compress.compress(bytes) + compressed_bytes += self._compress.flush(zlib.Z_SYNC_FLUSH) + self._logger.debug('Compress input %r', bytes) + self._logger.debug('Compress result %r', compressed_bytes) + return compressed_bytes + + def compress_and_finish(self, bytes): + compressed_bytes = self._compress.compress(bytes) + compressed_bytes += self._compress.flush(zlib.Z_FINISH) + self._logger.debug('Compress input %r', bytes) + self._logger.debug('Compress result %r', compressed_bytes) + return compressed_bytes + +class _Inflater(object): + + def __init__(self): + self._logger = get_class_logger(self) + + self._unconsumed = '' + + self.reset() + + def decompress(self, size): + if not (size == -1 or size > 0): + raise Exception('size must be -1 or positive') + + data = '' + + while True: + if size == -1: + data += self._decompress.decompress(self._unconsumed) + # See Python bug http://bugs.python.org/issue12050 to + # understand why the same code cannot be used for updating + # self._unconsumed for here and else block. + self._unconsumed = '' + else: + data += self._decompress.decompress( + self._unconsumed, size - len(data)) + self._unconsumed = self._decompress.unconsumed_tail + if self._decompress.unused_data: + # Encountered a last block (i.e. 
a block with BFINAL = 1) and + # found a new stream (unused_data). We cannot use the same + # zlib.Decompress object for the new stream. Create a new + # Decompress object to decompress the new one. + # + # It's fine to ignore unconsumed_tail if unused_data is not + # empty. + self._unconsumed = self._decompress.unused_data + self.reset() + if size >= 0 and len(data) == size: + # data is filled. Don't call decompress again. + break + else: + # Re-invoke Decompress.decompress to try to decompress all + # available bytes before invoking read which blocks until + # any new byte is available. + continue + else: + # Here, since unused_data is empty, even if unconsumed_tail is + # not empty, bytes of requested length are already in data. We + # don't have to "continue" here. + break + + if data: + self._logger.debug('Decompressed %r', data) + return data + + def append(self, data): + self._logger.debug('Appended %r', data) + self._unconsumed += data + + def reset(self): + self._logger.debug('Reset') + self._decompress = zlib.decompressobj(-zlib.MAX_WBITS) + + +# Compresses/decompresses given octets using the method introduced in RFC1979. + + +class _RFC1979Deflater(object): + """A compressor class that applies DEFLATE to given byte sequence and + flushes using the algorithm described in the RFC1979 section 2.1. + """ + + def __init__(self, window_bits, no_context_takeover): + self._deflater = None + if window_bits is None: + window_bits = zlib.MAX_WBITS + self._window_bits = window_bits + self._no_context_takeover = no_context_takeover + + def filter(self, bytes, flush=True, bfinal=False): + if self._deflater is None or (self._no_context_takeover and flush): + self._deflater = _Deflater(self._window_bits) + + if bfinal: + result = self._deflater.compress_and_finish(bytes) + # Add a padding block with BFINAL = 0 and BTYPE = 0. + result = result + chr(0) + self._deflater = None + return result + if flush: + # Strip last 4 octets which is LEN and NLEN field of a + # non-compressed block added for Z_SYNC_FLUSH. + return self._deflater.compress_and_flush(bytes)[:-4] + return self._deflater.compress(bytes) + +class _RFC1979Inflater(object): + """A decompressor class for byte sequence compressed and flushed following + the algorithm described in the RFC1979 section 2.1. + """ + + def __init__(self): + self._inflater = _Inflater() + + def filter(self, bytes): + # Restore stripped LEN and NLEN field of a non-compressed block added + # for Z_SYNC_FLUSH. + self._inflater.append(bytes + '\x00\x00\xff\xff') + return self._inflater.decompress(-1) + + +class DeflateSocket(object): + """A wrapper class for socket object to intercept send and recv to perform + deflate compression and decompression transparently. + """ + + # Size of the buffer passed to recv to receive compressed data. + _RECV_SIZE = 4096 + + def __init__(self, socket): + self._socket = socket + + self._logger = get_class_logger(self) + + self._deflater = _Deflater(zlib.MAX_WBITS) + self._inflater = _Inflater() + + def recv(self, size): + """Receives data from the socket specified on the construction up + to the specified size. Once any data is available, returns it even + if it's smaller than the specified size. + """ + + # TODO(tyoshino): Allow call with size=0. It should block until any + # decompressed data is available. 
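+        #
+        # Editor's note (illustrative): decompress() may legitimately return
+        # nothing even after fresh input was appended (e.g. the inflater has
+        # only seen part of a deflate block), which is why the loop below
+        # keeps calling recv() until at least one decompressed byte appears.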
+ if size <= 0: + raise Exception('Non-positive size passed') + while True: + data = self._inflater.decompress(size) + if len(data) != 0: + return data + + read_data = self._socket.recv(DeflateSocket._RECV_SIZE) + if not read_data: + return '' + self._inflater.append(read_data) + + def sendall(self, bytes): + self.send(bytes) + + def send(self, bytes): + self._socket.sendall(self._deflater.compress_and_flush(bytes)) + return len(bytes) + + +class DeflateConnection(object): + """A wrapper class for request object to intercept write and read to + perform deflate compression and decompression transparently. + """ + + def __init__(self, connection): + self._connection = connection + + self._logger = get_class_logger(self) + + self._deflater = _Deflater(zlib.MAX_WBITS) + self._inflater = _Inflater() + + def get_remote_addr(self): + return self._connection.remote_addr + remote_addr = property(get_remote_addr) + + def put_bytes(self, bytes): + self.write(bytes) + + def read(self, size=-1): + """Reads at most size bytes. Blocks until there's at least one byte + available. + """ + + # TODO(tyoshino): Allow call with size=0. + if not (size == -1 or size > 0): + raise Exception('size must be -1 or positive') + + data = '' + while True: + if size == -1: + data += self._inflater.decompress(-1) + else: + data += self._inflater.decompress(size - len(data)) + + if size >= 0 and len(data) != 0: + break + + # TODO(tyoshino): Make this read efficient by some workaround. + # + # In 3.0.3 and prior of mod_python, read blocks until length bytes + # was read. We don't know the exact size to read while using + # deflate, so read byte-by-byte. + # + # _StandaloneRequest.read that ultimately performs + # socket._fileobject.read also blocks until length bytes was read + read_data = self._connection.read(1) + if not read_data: + break + self._inflater.append(read_data) + return data + + def write(self, bytes): + self._connection.write(self._deflater.compress_and_flush(bytes)) + + +def _is_ewouldblock_errno(error_number): + """Returns True iff error_number indicates that receive operation would + block. To make this portable, we check availability of errno and then + compare them. + """ + + for error_name in ['WSAEWOULDBLOCK', 'EWOULDBLOCK', 'EAGAIN']: + if (error_name in dir(errno) and + error_number == getattr(errno, error_name)): + return True + return False + + +def drain_received_data(raw_socket): + # Set the socket non-blocking. + original_timeout = raw_socket.gettimeout() + raw_socket.settimeout(0.0) + + drained_data = [] + + # Drain until the socket is closed or no data is immediately + # available for read. + while True: + try: + data = raw_socket.recv(1) + if not data: + break + drained_data.append(data) + except socket.error, e: + # e can be either a pair (errno, string) or just a string (or + # something else) telling what went wrong. We suppress only + # the errors that indicates that the socket blocks. Those + # exceptions can be parsed as a pair (errno, string). + try: + error_number, message = e + except: + # Failed to parse socket.error. + raise e + + if _is_ewouldblock_errno(error_number): + break + else: + raise e + + # Rollback timeout value. 
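+    # (Editor's note: settimeout(0.0) above switched the socket to
+    # non-blocking mode; restoring the saved value, which may be None for
+    # fully blocking mode, returns the caller's original behavior.)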
+ raw_socket.settimeout(original_timeout) + + return ''.join(drained_data) + + +# vi:sts=4 sw=4 et diff --git a/pyload/lib/new_collections.py b/pyload/lib/new_collections.py new file mode 100644 index 000000000..12d05b4b9 --- /dev/null +++ b/pyload/lib/new_collections.py @@ -0,0 +1,375 @@ +## {{{ http://code.activestate.com/recipes/576693/ (r9) +# Backport of OrderedDict() class that runs on Python 2.4, 2.5, 2.6, 2.7 and pypy. +# Passes Python2.7's test suite and incorporates all the latest updates. + +try: + from thread import get_ident as _get_ident +except ImportError: + from dummy_thread import get_ident as _get_ident + +try: + from _abcoll import KeysView, ValuesView, ItemsView +except ImportError: + pass + + +class OrderedDict(dict): + 'Dictionary that remembers insertion order' + # An inherited dict maps keys to values. + # The inherited dict provides __getitem__, __len__, __contains__, and get. + # The remaining methods are order-aware. + # Big-O running times for all methods are the same as for regular dictionaries. + + # The internal self.__map dictionary maps keys to links in a doubly linked list. + # The circular doubly linked list starts and ends with a sentinel element. + # The sentinel element never gets deleted (this simplifies the algorithm). + # Each link is stored as a list of length three: [PREV, NEXT, KEY]. + + def __init__(self, *args, **kwds): + '''Initialize an ordered dictionary. Signature is the same as for + regular dictionaries, but keyword arguments are not recommended + because their insertion order is arbitrary. + + ''' + if len(args) > 1: + raise TypeError('expected at most 1 arguments, got %d' % len(args)) + try: + self.__root + except AttributeError: + self.__root = root = [] # sentinel node + root[:] = [root, root, None] + self.__map = {} + self.__update(*args, **kwds) + + def __setitem__(self, key, value, dict_setitem=dict.__setitem__): + 'od.__setitem__(i, y) <==> od[i]=y' + # Setting a new item creates a new link which goes at the end of the linked + # list, and the inherited dictionary is updated with the new key/value pair. + if key not in self: + root = self.__root + last = root[0] + last[1] = root[0] = self.__map[key] = [last, root, key] + dict_setitem(self, key, value) + + def __delitem__(self, key, dict_delitem=dict.__delitem__): + 'od.__delitem__(y) <==> del od[y]' + # Deleting an existing item uses self.__map to find the link which is + # then removed by updating the links in the predecessor and successor nodes. + dict_delitem(self, key) + link_prev, link_next, key = self.__map.pop(key) + link_prev[1] = link_next + link_next[0] = link_prev + + def __iter__(self): + 'od.__iter__() <==> iter(od)' + root = self.__root + curr = root[1] + while curr is not root: + yield curr[2] + curr = curr[1] + + def __reversed__(self): + 'od.__reversed__() <==> reversed(od)' + root = self.__root + curr = root[0] + while curr is not root: + yield curr[2] + curr = curr[0] + + def clear(self): + 'od.clear() -> None. Remove all items from od.' + try: + for node in self.__map.itervalues(): + del node[:] + root = self.__root + root[:] = [root, root, None] + self.__map.clear() + except AttributeError: + pass + dict.clear(self) + + def popitem(self, last=True): + '''od.popitem() -> (k, v), return and remove a (key, value) pair. + Pairs are returned in LIFO order if last is true or FIFO order if false. 
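+
+        Example (editor's addition):
+
+            od = OrderedDict([('a', 1), ('b', 2)])
+            od.popitem()            # ('b', 2)  -- LIFO
+            od.popitem(last=False)  # ('a', 1)  -- FIFO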
+ + ''' + if not self: + raise KeyError('dictionary is empty') + root = self.__root + if last: + link = root[0] + link_prev = link[0] + link_prev[1] = root + root[0] = link_prev + else: + link = root[1] + link_next = link[1] + root[1] = link_next + link_next[0] = root + key = link[2] + del self.__map[key] + value = dict.pop(self, key) + return key, value + + # -- the following methods do not depend on the internal structure -- + + def keys(self): + 'od.keys() -> list of keys in od' + return list(self) + + def values(self): + 'od.values() -> list of values in od' + return [self[key] for key in self] + + def items(self): + 'od.items() -> list of (key, value) pairs in od' + return [(key, self[key]) for key in self] + + def iterkeys(self): + 'od.iterkeys() -> an iterator over the keys in od' + return iter(self) + + def itervalues(self): + 'od.itervalues -> an iterator over the values in od' + for k in self: + yield self[k] + + def iteritems(self): + 'od.iteritems -> an iterator over the (key, value) items in od' + for k in self: + yield (k, self[k]) + + def update(*args, **kwds): + '''od.update(E, **F) -> None. Update od from dict/iterable E and F. + + If E is a dict instance, does: for k in E: od[k] = E[k] + If E has a .keys() method, does: for k in E.keys(): od[k] = E[k] + Or if E is an iterable of items, does: for k, v in E: od[k] = v + In either case, this is followed by: for k, v in F.items(): od[k] = v + + ''' + if len(args) > 2: + raise TypeError('update() takes at most 2 positional ' + 'arguments (%d given)' % (len(args),)) + elif not args: + raise TypeError('update() takes at least 1 argument (0 given)') + self = args[0] + # Make progressively weaker assumptions about "other" + other = () + if len(args) == 2: + other = args[1] + if isinstance(other, dict): + for key in other: + self[key] = other[key] + elif hasattr(other, 'keys'): + for key in other.keys(): + self[key] = other[key] + else: + for key, value in other: + self[key] = value + for key, value in kwds.items(): + self[key] = value + + __update = update # let subclasses override update without breaking __init__ + + __marker = object() + + def pop(self, key, default=__marker): + '''od.pop(k[,d]) -> v, remove specified key and return the corresponding value. + If key is not found, d is returned if given, otherwise KeyError is raised. + + ''' + if key in self: + result = self[key] + del self[key] + return result + if default is self.__marker: + raise KeyError(key) + return default + + def setdefault(self, key, default=None): + 'od.setdefault(k[,d]) -> od.get(k,d), also set od[k]=d if k not in od' + if key in self: + return self[key] + self[key] = default + return default + + def __repr__(self, _repr_running={}): + 'od.__repr__() <==> repr(od)' + call_key = id(self), _get_ident() + if call_key in _repr_running: + return '...' 
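+        # Editor's note (illustrative): the guard above stops infinite
+        # recursion on self-referencing dicts, e.g.
+        #   od = OrderedDict(); od['self'] = od
+        #   repr(od)  ->  "OrderedDict([('self', ...)])"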
+ _repr_running[call_key] = 1 + try: + if not self: + return '%s()' % (self.__class__.__name__,) + return '%s(%r)' % (self.__class__.__name__, self.items()) + finally: + del _repr_running[call_key] + + def __reduce__(self): + 'Return state information for pickling' + items = [[k, self[k]] for k in self] + inst_dict = vars(self).copy() + for k in vars(OrderedDict()): + inst_dict.pop(k, None) + if inst_dict: + return (self.__class__, (items,), inst_dict) + return self.__class__, (items,) + + def copy(self): + 'od.copy() -> a shallow copy of od' + return self.__class__(self) + + @classmethod + def fromkeys(cls, iterable, value=None): + '''OD.fromkeys(S[, v]) -> New ordered dictionary with keys from S + and values equal to v (which defaults to None). + + ''' + d = cls() + for key in iterable: + d[key] = value + return d + + def __eq__(self, other): + '''od.__eq__(y) <==> od==y. Comparison to another OD is order-sensitive + while comparison to a regular mapping is order-insensitive. + + ''' + if isinstance(other, OrderedDict): + return len(self)==len(other) and self.items() == other.items() + return dict.__eq__(self, other) + + def __ne__(self, other): + return not self == other + + # -- the following methods are only used in Python 2.7 -- + + def viewkeys(self): + "od.viewkeys() -> a set-like object providing a view on od's keys" + return KeysView(self) + + def viewvalues(self): + "od.viewvalues() -> an object providing a view on od's values" + return ValuesView(self) + + def viewitems(self): + "od.viewitems() -> a set-like object providing a view on od's items" + return ItemsView(self) +## end of http://code.activestate.com/recipes/576693/ }}} + +## {{{ http://code.activestate.com/recipes/500261/ (r15) +from operator import itemgetter as _itemgetter +from keyword import iskeyword as _iskeyword +import sys as _sys + +def namedtuple(typename, field_names, verbose=False, rename=False): + """Returns a new subclass of tuple with named fields. + + >>> Point = namedtuple('Point', 'x y') + >>> Point.__doc__ # docstring for the new class + 'Point(x, y)' + >>> p = Point(11, y=22) # instantiate with positional args or keywords + >>> p[0] + p[1] # indexable like a plain tuple + 33 + >>> x, y = p # unpack like a regular tuple + >>> x, y + (11, 22) + >>> p.x + p.y # fields also accessable by name + 33 + >>> d = p._asdict() # convert to a dictionary + >>> d['x'] + 11 + >>> Point(**d) # convert from a dictionary + Point(x=11, y=22) + >>> p._replace(x=100) # _replace() is like str.replace() but targets named fields + Point(x=100, y=22) + + """ + + # Parse and validate the field names. Validation serves two purposes, + # generating informative error messages and preventing template injection attacks. 
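+    # Editor's note (illustrative): a field name like
+    #   "x): pass\nimport os  # "
+    # would otherwise splice arbitrary code into the exec'd template below.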
+ if isinstance(field_names, basestring): + field_names = field_names.replace(',', ' ').split() # names separated by whitespace and/or commas + field_names = tuple(map(str, field_names)) + if rename: + names = list(field_names) + seen = set() + for i, name in enumerate(names): + if (not min(c.isalnum() or c=='_' for c in name) or _iskeyword(name) + or not name or name[0].isdigit() or name.startswith('_') + or name in seen): + names[i] = '_%d' % i + seen.add(name) + field_names = tuple(names) + for name in (typename,) + field_names: + if not min(c.isalnum() or c=='_' for c in name): + raise ValueError('Type names and field names can only contain alphanumeric characters and underscores: %r' % name) + if _iskeyword(name): + raise ValueError('Type names and field names cannot be a keyword: %r' % name) + if name[0].isdigit(): + raise ValueError('Type names and field names cannot start with a number: %r' % name) + seen_names = set() + for name in field_names: + if name.startswith('_') and not rename: + raise ValueError('Field names cannot start with an underscore: %r' % name) + if name in seen_names: + raise ValueError('Encountered duplicate field name: %r' % name) + seen_names.add(name) + + # Create and fill-in the class template + numfields = len(field_names) + argtxt = repr(field_names).replace("'", "")[1:-1] # tuple repr without parens or quotes + reprtxt = ', '.join('%s=%%r' % name for name in field_names) + template = '''class %(typename)s(tuple): + '%(typename)s(%(argtxt)s)' \n + __slots__ = () \n + _fields = %(field_names)r \n + def __new__(_cls, %(argtxt)s): + return _tuple.__new__(_cls, (%(argtxt)s)) \n + @classmethod + def _make(cls, iterable, new=tuple.__new__, len=len): + 'Make a new %(typename)s object from a sequence or iterable' + result = new(cls, iterable) + if len(result) != %(numfields)d: + raise TypeError('Expected %(numfields)d arguments, got %%d' %% len(result)) + return result \n + def __repr__(self): + return '%(typename)s(%(reprtxt)s)' %% self \n + def _asdict(self): + 'Return a new dict which maps field names to their values' + return dict(zip(self._fields, self)) \n + def _replace(_self, **kwds): + 'Return a new %(typename)s object replacing specified fields with new values' + result = _self._make(map(kwds.pop, %(field_names)r, _self)) + if kwds: + raise ValueError('Got unexpected field names: %%r' %% kwds.keys()) + return result \n + def __getnewargs__(self): + return tuple(self) \n\n''' % locals() + for i, name in enumerate(field_names): + template += ' %s = _property(_itemgetter(%d))\n' % (name, i) + if verbose: + print template + + # Execute the template string in a temporary namespace + namespace = dict(_itemgetter=_itemgetter, __name__='namedtuple_%s' % typename, + _property=property, _tuple=tuple) + try: + exec template in namespace + except SyntaxError, e: + raise SyntaxError(e.message + ':\n' + template) + result = namespace[typename] + + # For pickling to work, the __module__ variable needs to be set to the frame + # where the named tuple is created. Bypass this step in enviroments where + # sys._getframe is not defined (Jython for example) or sys._getframe is not + # defined for arguments greater than 0 (IronPython). 
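+    # Editor's note (illustrative): e.g. a namedtuple defined at the top
+    # level of module foo gets __module__ == 'foo', which is what pickle
+    # needs in order to relocate the class when unpickling.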
+ try: + result.__module__ = _sys._getframe(1).f_globals.get('__name__', '__main__') + except (AttributeError, ValueError): + pass + + return result +## end of http://code.activestate.com/recipes/500261/ }}} diff --git a/pyload/lib/rename_process.py b/pyload/lib/rename_process.py new file mode 100644 index 000000000..2527cef39 --- /dev/null +++ b/pyload/lib/rename_process.py @@ -0,0 +1,14 @@ +import sys + +def renameProcess(new_name): + """ Renames the process calling the function to the given name. """ + if sys.platform != 'linux2': + return False + try: + from ctypes import CDLL + libc = CDLL('libc.so.6') + libc.prctl(15, new_name, 0, 0, 0) + return True + except Exception, e: + #print "Rename process failed", e + return False diff --git a/pyload/lib/simplejson/__init__.py b/pyload/lib/simplejson/__init__.py new file mode 100644 index 000000000..ef5c0db48 --- /dev/null +++ b/pyload/lib/simplejson/__init__.py @@ -0,0 +1,466 @@ +r"""JSON (JavaScript Object Notation) <http://json.org> is a subset of +JavaScript syntax (ECMA-262 3rd edition) used as a lightweight data +interchange format. + +:mod:`simplejson` exposes an API familiar to users of the standard library +:mod:`marshal` and :mod:`pickle` modules. It is the externally maintained +version of the :mod:`json` library contained in Python 2.6, but maintains +compatibility with Python 2.4 and Python 2.5 and (currently) has +significant performance advantages, even without using the optional C +extension for speedups. + +Encoding basic Python object hierarchies:: + + >>> import simplejson as json + >>> json.dumps(['foo', {'bar': ('baz', None, 1.0, 2)}]) + '["foo", {"bar": ["baz", null, 1.0, 2]}]' + >>> print json.dumps("\"foo\bar") + "\"foo\bar" + >>> print json.dumps(u'\u1234') + "\u1234" + >>> print json.dumps('\\') + "\\" + >>> print json.dumps({"c": 0, "b": 0, "a": 0}, sort_keys=True) + {"a": 0, "b": 0, "c": 0} + >>> from StringIO import StringIO + >>> io = StringIO() + >>> json.dump(['streaming API'], io) + >>> io.getvalue() + '["streaming API"]' + +Compact encoding:: + + >>> import simplejson as json + >>> json.dumps([1,2,3,{'4': 5, '6': 7}], separators=(',',':')) + '[1,2,3,{"4":5,"6":7}]' + +Pretty printing:: + + >>> import simplejson as json + >>> s = json.dumps({'4': 5, '6': 7}, sort_keys=True, indent=' ') + >>> print '\n'.join([l.rstrip() for l in s.splitlines()]) + { + "4": 5, + "6": 7 + } + +Decoding JSON:: + + >>> import simplejson as json + >>> obj = [u'foo', {u'bar': [u'baz', None, 1.0, 2]}] + >>> json.loads('["foo", {"bar":["baz", null, 1.0, 2]}]') == obj + True + >>> json.loads('"\\"foo\\bar"') == u'"foo\x08ar' + True + >>> from StringIO import StringIO + >>> io = StringIO('["streaming API"]') + >>> json.load(io)[0] == 'streaming API' + True + +Specializing JSON object decoding:: + + >>> import simplejson as json + >>> def as_complex(dct): + ... if '__complex__' in dct: + ... return complex(dct['real'], dct['imag']) + ... return dct + ... + >>> json.loads('{"__complex__": true, "real": 1, "imag": 2}', + ... object_hook=as_complex) + (1+2j) + >>> from decimal import Decimal + >>> json.loads('1.1', parse_float=Decimal) == Decimal('1.1') + True + +Specializing JSON object encoding:: + + >>> import simplejson as json + >>> def encode_complex(obj): + ... if isinstance(obj, complex): + ... return [obj.real, obj.imag] + ... raise TypeError(repr(o) + " is not JSON serializable") + ... 
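+    >>> # Editor's illustrative addition: any complex instance goes through
+    >>> # the same default hook.
+    >>> json.dumps(complex(0, 1), default=encode_complex)
+    '[0.0, 1.0]'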
+ >>> json.dumps(2 + 1j, default=encode_complex) + '[2.0, 1.0]' + >>> json.JSONEncoder(default=encode_complex).encode(2 + 1j) + '[2.0, 1.0]' + >>> ''.join(json.JSONEncoder(default=encode_complex).iterencode(2 + 1j)) + '[2.0, 1.0]' + + +Using simplejson.tool from the shell to validate and pretty-print:: + + $ echo '{"json":"obj"}' | python -m simplejson.tool + { + "json": "obj" + } + $ echo '{ 1.2:3.4}' | python -m simplejson.tool + Expecting property name: line 1 column 2 (char 2) +""" +__version__ = '2.2.1' +__all__ = [ + 'dump', 'dumps', 'load', 'loads', + 'JSONDecoder', 'JSONDecodeError', 'JSONEncoder', + 'OrderedDict', +] + +__author__ = 'Bob Ippolito <bob@redivi.com>' + +from decimal import Decimal + +from decoder import JSONDecoder, JSONDecodeError +from encoder import JSONEncoder +def _import_OrderedDict(): + import collections + try: + return collections.OrderedDict + except AttributeError: + import ordered_dict + return ordered_dict.OrderedDict +OrderedDict = _import_OrderedDict() + +def _import_c_make_encoder(): + try: + from simplejson._speedups import make_encoder + return make_encoder + except ImportError: + return None + +_default_encoder = JSONEncoder( + skipkeys=False, + ensure_ascii=True, + check_circular=True, + allow_nan=True, + indent=None, + separators=None, + encoding='utf-8', + default=None, + use_decimal=True, + namedtuple_as_object=True, + tuple_as_array=True, +) + +def dump(obj, fp, skipkeys=False, ensure_ascii=True, check_circular=True, + allow_nan=True, cls=None, indent=None, separators=None, + encoding='utf-8', default=None, use_decimal=True, + namedtuple_as_object=True, tuple_as_array=True, + **kw): + """Serialize ``obj`` as a JSON formatted stream to ``fp`` (a + ``.write()``-supporting file-like object). + + If ``skipkeys`` is true then ``dict`` keys that are not basic types + (``str``, ``unicode``, ``int``, ``long``, ``float``, ``bool``, ``None``) + will be skipped instead of raising a ``TypeError``. + + If ``ensure_ascii`` is false, then the some chunks written to ``fp`` + may be ``unicode`` instances, subject to normal Python ``str`` to + ``unicode`` coercion rules. Unless ``fp.write()`` explicitly + understands ``unicode`` (as in ``codecs.getwriter()``) this is likely + to cause an error. + + If ``check_circular`` is false, then the circular reference check + for container types will be skipped and a circular reference will + result in an ``OverflowError`` (or worse). + + If ``allow_nan`` is false, then it will be a ``ValueError`` to + serialize out of range ``float`` values (``nan``, ``inf``, ``-inf``) + in strict compliance of the JSON specification, instead of using the + JavaScript equivalents (``NaN``, ``Infinity``, ``-Infinity``). + + If *indent* is a string, then JSON array elements and object members + will be pretty-printed with a newline followed by that string repeated + for each level of nesting. ``None`` (the default) selects the most compact + representation without any newlines. For backwards compatibility with + versions of simplejson earlier than 2.1.0, an integer is also accepted + and is converted to a string with that many spaces. + + If ``separators`` is an ``(item_separator, dict_separator)`` tuple + then it will be used instead of the default ``(', ', ': ')`` separators. + ``(',', ':')`` is the most compact JSON representation. + + ``encoding`` is the character encoding for str instances, default is UTF-8. + + ``default(obj)`` is a function that should return a serializable version + of obj or raise TypeError. 
The default simply raises TypeError. + + If *use_decimal* is true (default: ``True``) then decimal.Decimal + will be natively serialized to JSON with full precision. + + If *namedtuple_as_object* is true (default: ``True``), + :class:`tuple` subclasses with ``_asdict()`` methods will be encoded + as JSON objects. + + If *tuple_as_array* is true (default: ``True``), + :class:`tuple` (and subclasses) will be encoded as JSON arrays. + + To use a custom ``JSONEncoder`` subclass (e.g. one that overrides the + ``.default()`` method to serialize additional types), specify it with + the ``cls`` kwarg. + + """ + # cached encoder + if (not skipkeys and ensure_ascii and + check_circular and allow_nan and + cls is None and indent is None and separators is None and + encoding == 'utf-8' and default is None and use_decimal + and namedtuple_as_object and tuple_as_array and not kw): + iterable = _default_encoder.iterencode(obj) + else: + if cls is None: + cls = JSONEncoder + iterable = cls(skipkeys=skipkeys, ensure_ascii=ensure_ascii, + check_circular=check_circular, allow_nan=allow_nan, indent=indent, + separators=separators, encoding=encoding, + default=default, use_decimal=use_decimal, + namedtuple_as_object=namedtuple_as_object, + tuple_as_array=tuple_as_array, + **kw).iterencode(obj) + # could accelerate with writelines in some versions of Python, at + # a debuggability cost + for chunk in iterable: + fp.write(chunk) + + +def dumps(obj, skipkeys=False, ensure_ascii=True, check_circular=True, + allow_nan=True, cls=None, indent=None, separators=None, + encoding='utf-8', default=None, use_decimal=True, + namedtuple_as_object=True, + tuple_as_array=True, + **kw): + """Serialize ``obj`` to a JSON formatted ``str``. + + If ``skipkeys`` is false then ``dict`` keys that are not basic types + (``str``, ``unicode``, ``int``, ``long``, ``float``, ``bool``, ``None``) + will be skipped instead of raising a ``TypeError``. + + If ``ensure_ascii`` is false, then the return value will be a + ``unicode`` instance subject to normal Python ``str`` to ``unicode`` + coercion rules instead of being escaped to an ASCII ``str``. + + If ``check_circular`` is false, then the circular reference check + for container types will be skipped and a circular reference will + result in an ``OverflowError`` (or worse). + + If ``allow_nan`` is false, then it will be a ``ValueError`` to + serialize out of range ``float`` values (``nan``, ``inf``, ``-inf``) in + strict compliance of the JSON specification, instead of using the + JavaScript equivalents (``NaN``, ``Infinity``, ``-Infinity``). + + If ``indent`` is a string, then JSON array elements and object members + will be pretty-printed with a newline followed by that string repeated + for each level of nesting. ``None`` (the default) selects the most compact + representation without any newlines. For backwards compatibility with + versions of simplejson earlier than 2.1.0, an integer is also accepted + and is converted to a string with that many spaces. + + If ``separators`` is an ``(item_separator, dict_separator)`` tuple + then it will be used instead of the default ``(', ', ': ')`` separators. + ``(',', ':')`` is the most compact JSON representation. + + ``encoding`` is the character encoding for str instances, default is UTF-8. + + ``default(obj)`` is a function that should return a serializable version + of obj or raise TypeError. The default simply raises TypeError. 
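+
+    For example (editor's illustrative addition), the separator choice is
+    what distinguishes compact from default output:
+
+        dumps([1, 2, 3])                         # '[1, 2, 3]'
+        dumps([1, 2, 3], separators=(',', ':'))  # '[1,2,3]'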
+ + If *use_decimal* is true (default: ``True``) then decimal.Decimal + will be natively serialized to JSON with full precision. + + If *namedtuple_as_object* is true (default: ``True``), + :class:`tuple` subclasses with ``_asdict()`` methods will be encoded + as JSON objects. + + If *tuple_as_array* is true (default: ``True``), + :class:`tuple` (and subclasses) will be encoded as JSON arrays. + + To use a custom ``JSONEncoder`` subclass (e.g. one that overrides the + ``.default()`` method to serialize additional types), specify it with + the ``cls`` kwarg. + + """ + # cached encoder + if (not skipkeys and ensure_ascii and + check_circular and allow_nan and + cls is None and indent is None and separators is None and + encoding == 'utf-8' and default is None and use_decimal + and namedtuple_as_object and tuple_as_array and not kw): + return _default_encoder.encode(obj) + if cls is None: + cls = JSONEncoder + return cls( + skipkeys=skipkeys, ensure_ascii=ensure_ascii, + check_circular=check_circular, allow_nan=allow_nan, indent=indent, + separators=separators, encoding=encoding, default=default, + use_decimal=use_decimal, + namedtuple_as_object=namedtuple_as_object, + tuple_as_array=tuple_as_array, + **kw).encode(obj) + + +_default_decoder = JSONDecoder(encoding=None, object_hook=None, + object_pairs_hook=None) + + +def load(fp, encoding=None, cls=None, object_hook=None, parse_float=None, + parse_int=None, parse_constant=None, object_pairs_hook=None, + use_decimal=False, namedtuple_as_object=True, tuple_as_array=True, + **kw): + """Deserialize ``fp`` (a ``.read()``-supporting file-like object containing + a JSON document) to a Python object. + + *encoding* determines the encoding used to interpret any + :class:`str` objects decoded by this instance (``'utf-8'`` by + default). It has no effect when decoding :class:`unicode` objects. + + Note that currently only encodings that are a superset of ASCII work, + strings of other encodings should be passed in as :class:`unicode`. + + *object_hook*, if specified, will be called with the result of every + JSON object decoded and its return value will be used in place of the + given :class:`dict`. This can be used to provide custom + deserializations (e.g. to support JSON-RPC class hinting). + + *object_pairs_hook* is an optional function that will be called with + the result of any object literal decode with an ordered list of pairs. + The return value of *object_pairs_hook* will be used instead of the + :class:`dict`. This feature can be used to implement custom decoders + that rely on the order that the key and value pairs are decoded (for + example, :func:`collections.OrderedDict` will remember the order of + insertion). If *object_hook* is also defined, the *object_pairs_hook* + takes priority. + + *parse_float*, if specified, will be called with the string of every + JSON float to be decoded. By default, this is equivalent to + ``float(num_str)``. This can be used to use another datatype or parser + for JSON floats (e.g. :class:`decimal.Decimal`). + + *parse_int*, if specified, will be called with the string of every + JSON int to be decoded. By default, this is equivalent to + ``int(num_str)``. This can be used to use another datatype or parser + for JSON integers (e.g. :class:`float`). + + *parse_constant*, if specified, will be called with one of the + following strings: ``'-Infinity'``, ``'Infinity'``, ``'NaN'``. This + can be used to raise an exception if invalid JSON numbers are + encountered. 
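+
+    Note (editor's addition): ``load`` is implemented as ``loads(fp.read(),
+    ...)``, so the whole stream is buffered before decoding rather than
+    parsed incrementally.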
+ + If *use_decimal* is true (default: ``False``) then it implies + parse_float=decimal.Decimal for parity with ``dump``. + + To use a custom ``JSONDecoder`` subclass, specify it with the ``cls`` + kwarg. + + """ + return loads(fp.read(), + encoding=encoding, cls=cls, object_hook=object_hook, + parse_float=parse_float, parse_int=parse_int, + parse_constant=parse_constant, object_pairs_hook=object_pairs_hook, + use_decimal=use_decimal, **kw) + + +def loads(s, encoding=None, cls=None, object_hook=None, parse_float=None, + parse_int=None, parse_constant=None, object_pairs_hook=None, + use_decimal=False, **kw): + """Deserialize ``s`` (a ``str`` or ``unicode`` instance containing a JSON + document) to a Python object. + + *encoding* determines the encoding used to interpret any + :class:`str` objects decoded by this instance (``'utf-8'`` by + default). It has no effect when decoding :class:`unicode` objects. + + Note that currently only encodings that are a superset of ASCII work, + strings of other encodings should be passed in as :class:`unicode`. + + *object_hook*, if specified, will be called with the result of every + JSON object decoded and its return value will be used in place of the + given :class:`dict`. This can be used to provide custom + deserializations (e.g. to support JSON-RPC class hinting). + + *object_pairs_hook* is an optional function that will be called with + the result of any object literal decode with an ordered list of pairs. + The return value of *object_pairs_hook* will be used instead of the + :class:`dict`. This feature can be used to implement custom decoders + that rely on the order that the key and value pairs are decoded (for + example, :func:`collections.OrderedDict` will remember the order of + insertion). If *object_hook* is also defined, the *object_pairs_hook* + takes priority. + + *parse_float*, if specified, will be called with the string of every + JSON float to be decoded. By default, this is equivalent to + ``float(num_str)``. This can be used to use another datatype or parser + for JSON floats (e.g. :class:`decimal.Decimal`). + + *parse_int*, if specified, will be called with the string of every + JSON int to be decoded. By default, this is equivalent to + ``int(num_str)``. This can be used to use another datatype or parser + for JSON integers (e.g. :class:`float`). + + *parse_constant*, if specified, will be called with one of the + following strings: ``'-Infinity'``, ``'Infinity'``, ``'NaN'``. This + can be used to raise an exception if invalid JSON numbers are + encountered. + + If *use_decimal* is true (default: ``False``) then it implies + parse_float=decimal.Decimal for parity with ``dump``. + + To use a custom ``JSONDecoder`` subclass, specify it with the ``cls`` + kwarg. 
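+
+    Example (editor's illustrative addition):
+
+        loads('{"a": 1}')               # {u'a': 1}
+        loads('1.1', use_decimal=True)  # Decimal('1.1')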
+ + """ + if (cls is None and encoding is None and object_hook is None and + parse_int is None and parse_float is None and + parse_constant is None and object_pairs_hook is None + and not use_decimal and not kw): + return _default_decoder.decode(s) + if cls is None: + cls = JSONDecoder + if object_hook is not None: + kw['object_hook'] = object_hook + if object_pairs_hook is not None: + kw['object_pairs_hook'] = object_pairs_hook + if parse_float is not None: + kw['parse_float'] = parse_float + if parse_int is not None: + kw['parse_int'] = parse_int + if parse_constant is not None: + kw['parse_constant'] = parse_constant + if use_decimal: + if parse_float is not None: + raise TypeError("use_decimal=True implies parse_float=Decimal") + kw['parse_float'] = Decimal + return cls(encoding=encoding, **kw).decode(s) + + +def _toggle_speedups(enabled): + import simplejson.decoder as dec + import simplejson.encoder as enc + import simplejson.scanner as scan + c_make_encoder = _import_c_make_encoder() + if enabled: + dec.scanstring = dec.c_scanstring or dec.py_scanstring + enc.c_make_encoder = c_make_encoder + enc.encode_basestring_ascii = (enc.c_encode_basestring_ascii or + enc.py_encode_basestring_ascii) + scan.make_scanner = scan.c_make_scanner or scan.py_make_scanner + else: + dec.scanstring = dec.py_scanstring + enc.c_make_encoder = None + enc.encode_basestring_ascii = enc.py_encode_basestring_ascii + scan.make_scanner = scan.py_make_scanner + dec.make_scanner = scan.make_scanner + global _default_decoder + _default_decoder = JSONDecoder( + encoding=None, + object_hook=None, + object_pairs_hook=None, + ) + global _default_encoder + _default_encoder = JSONEncoder( + skipkeys=False, + ensure_ascii=True, + check_circular=True, + allow_nan=True, + indent=None, + separators=None, + encoding='utf-8', + default=None, + ) diff --git a/pyload/lib/simplejson/decoder.py b/pyload/lib/simplejson/decoder.py new file mode 100644 index 000000000..e5496d6e7 --- /dev/null +++ b/pyload/lib/simplejson/decoder.py @@ -0,0 +1,421 @@ +"""Implementation of JSONDecoder +""" +import re +import sys +import struct + +from simplejson.scanner import make_scanner +def _import_c_scanstring(): + try: + from simplejson._speedups import scanstring + return scanstring + except ImportError: + return None +c_scanstring = _import_c_scanstring() + +__all__ = ['JSONDecoder'] + +FLAGS = re.VERBOSE | re.MULTILINE | re.DOTALL + +def _floatconstants(): + _BYTES = '7FF80000000000007FF0000000000000'.decode('hex') + # The struct module in Python 2.4 would get frexp() out of range here + # when an endian is specified in the format string. 
Fixed in Python 2.5+ + if sys.byteorder != 'big': + _BYTES = _BYTES[:8][::-1] + _BYTES[8:][::-1] + nan, inf = struct.unpack('dd', _BYTES) + return nan, inf, -inf + +NaN, PosInf, NegInf = _floatconstants() + + +class JSONDecodeError(ValueError): + """Subclass of ValueError with the following additional properties: + + msg: The unformatted error message + doc: The JSON document being parsed + pos: The start index of doc where parsing failed + end: The end index of doc where parsing failed (may be None) + lineno: The line corresponding to pos + colno: The column corresponding to pos + endlineno: The line corresponding to end (may be None) + endcolno: The column corresponding to end (may be None) + + """ + def __init__(self, msg, doc, pos, end=None): + ValueError.__init__(self, errmsg(msg, doc, pos, end=end)) + self.msg = msg + self.doc = doc + self.pos = pos + self.end = end + self.lineno, self.colno = linecol(doc, pos) + if end is not None: + self.endlineno, self.endcolno = linecol(doc, end) + else: + self.endlineno, self.endcolno = None, None + + +def linecol(doc, pos): + lineno = doc.count('\n', 0, pos) + 1 + if lineno == 1: + colno = pos + else: + colno = pos - doc.rindex('\n', 0, pos) + return lineno, colno + + +def errmsg(msg, doc, pos, end=None): + # Note that this function is called from _speedups + lineno, colno = linecol(doc, pos) + if end is None: + #fmt = '{0}: line {1} column {2} (char {3})' + #return fmt.format(msg, lineno, colno, pos) + fmt = '%s: line %d column %d (char %d)' + return fmt % (msg, lineno, colno, pos) + endlineno, endcolno = linecol(doc, end) + #fmt = '{0}: line {1} column {2} - line {3} column {4} (char {5} - {6})' + #return fmt.format(msg, lineno, colno, endlineno, endcolno, pos, end) + fmt = '%s: line %d column %d - line %d column %d (char %d - %d)' + return fmt % (msg, lineno, colno, endlineno, endcolno, pos, end) + + +_CONSTANTS = { + '-Infinity': NegInf, + 'Infinity': PosInf, + 'NaN': NaN, +} + +STRINGCHUNK = re.compile(r'(.*?)(["\\\x00-\x1f])', FLAGS) +BACKSLASH = { + '"': u'"', '\\': u'\\', '/': u'/', + 'b': u'\b', 'f': u'\f', 'n': u'\n', 'r': u'\r', 't': u'\t', +} + +DEFAULT_ENCODING = "utf-8" + +def py_scanstring(s, end, encoding=None, strict=True, + _b=BACKSLASH, _m=STRINGCHUNK.match): + """Scan the string s for a JSON string. End is the index of the + character in s after the quote that started the JSON string. + Unescapes all valid JSON string escape sequences and raises ValueError + on attempt to decode an invalid string. If strict is False then literal + control characters are allowed in the string. 
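+
+    For example (editor's addition), scanning the document '"abc"' with
+    end=1 (the index just past the opening quote) returns (u'abc', 5).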
+ + Returns a tuple of the decoded string and the index of the character in s + after the end quote.""" + if encoding is None: + encoding = DEFAULT_ENCODING + chunks = [] + _append = chunks.append + begin = end - 1 + while 1: + chunk = _m(s, end) + if chunk is None: + raise JSONDecodeError( + "Unterminated string starting at", s, begin) + end = chunk.end() + content, terminator = chunk.groups() + # Content is contains zero or more unescaped string characters + if content: + if not isinstance(content, unicode): + content = unicode(content, encoding) + _append(content) + # Terminator is the end of string, a literal control character, + # or a backslash denoting that an escape sequence follows + if terminator == '"': + break + elif terminator != '\\': + if strict: + msg = "Invalid control character %r at" % (terminator,) + #msg = "Invalid control character {0!r} at".format(terminator) + raise JSONDecodeError(msg, s, end) + else: + _append(terminator) + continue + try: + esc = s[end] + except IndexError: + raise JSONDecodeError( + "Unterminated string starting at", s, begin) + # If not a unicode escape sequence, must be in the lookup table + if esc != 'u': + try: + char = _b[esc] + except KeyError: + msg = "Invalid \\escape: " + repr(esc) + raise JSONDecodeError(msg, s, end) + end += 1 + else: + # Unicode escape sequence + esc = s[end + 1:end + 5] + next_end = end + 5 + if len(esc) != 4: + msg = "Invalid \\uXXXX escape" + raise JSONDecodeError(msg, s, end) + uni = int(esc, 16) + # Check for surrogate pair on UCS-4 systems + if 0xd800 <= uni <= 0xdbff and sys.maxunicode > 65535: + msg = "Invalid \\uXXXX\\uXXXX surrogate pair" + if not s[end + 5:end + 7] == '\\u': + raise JSONDecodeError(msg, s, end) + esc2 = s[end + 7:end + 11] + if len(esc2) != 4: + raise JSONDecodeError(msg, s, end) + uni2 = int(esc2, 16) + uni = 0x10000 + (((uni - 0xd800) << 10) | (uni2 - 0xdc00)) + next_end += 6 + char = unichr(uni) + end = next_end + # Append the unescaped character + _append(char) + return u''.join(chunks), end + + +# Use speedup if available +scanstring = c_scanstring or py_scanstring + +WHITESPACE = re.compile(r'[ \t\n\r]*', FLAGS) +WHITESPACE_STR = ' \t\n\r' + +def JSONObject((s, end), encoding, strict, scan_once, object_hook, + object_pairs_hook, memo=None, + _w=WHITESPACE.match, _ws=WHITESPACE_STR): + # Backwards compatibility + if memo is None: + memo = {} + memo_get = memo.setdefault + pairs = [] + # Use a slice to prevent IndexError from being raised, the following + # check will raise a more specific ValueError if the string is empty + nextchar = s[end:end + 1] + # Normally we expect nextchar == '"' + if nextchar != '"': + if nextchar in _ws: + end = _w(s, end).end() + nextchar = s[end:end + 1] + # Trivial empty object + if nextchar == '}': + if object_pairs_hook is not None: + result = object_pairs_hook(pairs) + return result, end + 1 + pairs = {} + if object_hook is not None: + pairs = object_hook(pairs) + return pairs, end + 1 + elif nextchar != '"': + raise JSONDecodeError("Expecting property name", s, end) + end += 1 + while True: + key, end = scanstring(s, end, encoding, strict) + key = memo_get(key, key) + + # To skip some function call overhead we optimize the fast paths where + # the JSON key separator is ": " or just ":". 
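As a quick check of the separator fast path described in the comment above (a minimal sketch, assuming this `simplejson` package is importable), any legal JSON whitespace around the colon still decodes; only the common `":"` and `": "` forms skip the regex call:

    import simplejson
    # Unusual whitespace around the key separator takes the slow path
    # (the _w regex match) but decodes to the same result.
    assert simplejson.loads('{"a" \t:\n 1}') == {"a": 1}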
+ if s[end:end + 1] != ':': + end = _w(s, end).end() + if s[end:end + 1] != ':': + raise JSONDecodeError("Expecting : delimiter", s, end) + + end += 1 + + try: + if s[end] in _ws: + end += 1 + if s[end] in _ws: + end = _w(s, end + 1).end() + except IndexError: + pass + + try: + value, end = scan_once(s, end) + except StopIteration: + raise JSONDecodeError("Expecting object", s, end) + pairs.append((key, value)) + + try: + nextchar = s[end] + if nextchar in _ws: + end = _w(s, end + 1).end() + nextchar = s[end] + except IndexError: + nextchar = '' + end += 1 + + if nextchar == '}': + break + elif nextchar != ',': + raise JSONDecodeError("Expecting , delimiter", s, end - 1) + + try: + nextchar = s[end] + if nextchar in _ws: + end += 1 + nextchar = s[end] + if nextchar in _ws: + end = _w(s, end + 1).end() + nextchar = s[end] + except IndexError: + nextchar = '' + + end += 1 + if nextchar != '"': + raise JSONDecodeError("Expecting property name", s, end - 1) + + if object_pairs_hook is not None: + result = object_pairs_hook(pairs) + return result, end + pairs = dict(pairs) + if object_hook is not None: + pairs = object_hook(pairs) + return pairs, end + +def JSONArray((s, end), scan_once, _w=WHITESPACE.match, _ws=WHITESPACE_STR): + values = [] + nextchar = s[end:end + 1] + if nextchar in _ws: + end = _w(s, end + 1).end() + nextchar = s[end:end + 1] + # Look-ahead for trivial empty array + if nextchar == ']': + return values, end + 1 + _append = values.append + while True: + try: + value, end = scan_once(s, end) + except StopIteration: + raise JSONDecodeError("Expecting object", s, end) + _append(value) + nextchar = s[end:end + 1] + if nextchar in _ws: + end = _w(s, end + 1).end() + nextchar = s[end:end + 1] + end += 1 + if nextchar == ']': + break + elif nextchar != ',': + raise JSONDecodeError("Expecting , delimiter", s, end) + + try: + if s[end] in _ws: + end += 1 + if s[end] in _ws: + end = _w(s, end + 1).end() + except IndexError: + pass + + return values, end + +class JSONDecoder(object): + """Simple JSON <http://json.org> decoder + + Performs the following translations in decoding by default: + + +---------------+-------------------+ + | JSON | Python | + +===============+===================+ + | object | dict | + +---------------+-------------------+ + | array | list | + +---------------+-------------------+ + | string | unicode | + +---------------+-------------------+ + | number (int) | int, long | + +---------------+-------------------+ + | number (real) | float | + +---------------+-------------------+ + | true | True | + +---------------+-------------------+ + | false | False | + +---------------+-------------------+ + | null | None | + +---------------+-------------------+ + + It also understands ``NaN``, ``Infinity``, and ``-Infinity`` as + their corresponding ``float`` values, which is outside the JSON spec. + + """ + + def __init__(self, encoding=None, object_hook=None, parse_float=None, + parse_int=None, parse_constant=None, strict=True, + object_pairs_hook=None): + """ + *encoding* determines the encoding used to interpret any + :class:`str` objects decoded by this instance (``'utf-8'`` by + default). It has no effect when decoding :class:`unicode` objects. + + Note that currently only encodings that are a superset of ASCII work, + strings of other encodings should be passed in as :class:`unicode`. + + *object_hook*, if specified, will be called with the result of every + JSON object decoded and its return value will be used in place of the + given :class:`dict`. 
This can be used to provide custom + deserializations (e.g. to support JSON-RPC class hinting). + + *object_pairs_hook* is an optional function that will be called with + the result of any object literal decode with an ordered list of pairs. + The return value of *object_pairs_hook* will be used instead of the + :class:`dict`. This feature can be used to implement custom decoders + that rely on the order that the key and value pairs are decoded (for + example, :func:`collections.OrderedDict` will remember the order of + insertion). If *object_hook* is also defined, the *object_pairs_hook* + takes priority. + + *parse_float*, if specified, will be called with the string of every + JSON float to be decoded. By default, this is equivalent to + ``float(num_str)``. This can be used to use another datatype or parser + for JSON floats (e.g. :class:`decimal.Decimal`). + + *parse_int*, if specified, will be called with the string of every + JSON int to be decoded. By default, this is equivalent to + ``int(num_str)``. This can be used to use another datatype or parser + for JSON integers (e.g. :class:`float`). + + *parse_constant*, if specified, will be called with one of the + following strings: ``'-Infinity'``, ``'Infinity'``, ``'NaN'``. This + can be used to raise an exception if invalid JSON numbers are + encountered. + + *strict* controls the parser's behavior when it encounters an + invalid control character in a string. The default setting of + ``True`` means that unescaped control characters are parse errors, if + ``False`` then control characters will be allowed in strings. + + """ + self.encoding = encoding + self.object_hook = object_hook + self.object_pairs_hook = object_pairs_hook + self.parse_float = parse_float or float + self.parse_int = parse_int or int + self.parse_constant = parse_constant or _CONSTANTS.__getitem__ + self.strict = strict + self.parse_object = JSONObject + self.parse_array = JSONArray + self.parse_string = scanstring + self.memo = {} + self.scan_once = make_scanner(self) + + def decode(self, s, _w=WHITESPACE.match): + """Return the Python representation of ``s`` (a ``str`` or ``unicode`` + instance containing a JSON document) + + """ + obj, end = self.raw_decode(s, idx=_w(s, 0).end()) + end = _w(s, end).end() + if end != len(s): + raise JSONDecodeError("Extra data", s, end, len(s)) + return obj + + def raw_decode(self, s, idx=0): + """Decode a JSON document from ``s`` (a ``str`` or ``unicode`` + beginning with a JSON document) and return a 2-tuple of the Python + representation and the index in ``s`` where the document ended. + + This can be used to decode a JSON document from a string that may + have extraneous data at the end. 
+ + """ + try: + obj, end = self.scan_once(s, idx) + except StopIteration: + raise JSONDecodeError("No JSON object could be decoded", s, idx) + return obj, end diff --git a/pyload/lib/simplejson/encoder.py b/pyload/lib/simplejson/encoder.py new file mode 100644 index 000000000..5ec7440f1 --- /dev/null +++ b/pyload/lib/simplejson/encoder.py @@ -0,0 +1,534 @@ +"""Implementation of JSONEncoder +""" +import re +from decimal import Decimal + +def _import_speedups(): + try: + from simplejson import _speedups + return _speedups.encode_basestring_ascii, _speedups.make_encoder + except ImportError: + return None, None +c_encode_basestring_ascii, c_make_encoder = _import_speedups() + +from simplejson.decoder import PosInf + +ESCAPE = re.compile(ur'[\x00-\x1f\\"\b\f\n\r\t\u2028\u2029]') +ESCAPE_ASCII = re.compile(r'([\\"]|[^\ -~])') +HAS_UTF8 = re.compile(r'[\x80-\xff]') +ESCAPE_DCT = { + '\\': '\\\\', + '"': '\\"', + '\b': '\\b', + '\f': '\\f', + '\n': '\\n', + '\r': '\\r', + '\t': '\\t', + u'\u2028': '\\u2028', + u'\u2029': '\\u2029', +} +for i in range(0x20): + #ESCAPE_DCT.setdefault(chr(i), '\\u{0:04x}'.format(i)) + ESCAPE_DCT.setdefault(chr(i), '\\u%04x' % (i,)) + +FLOAT_REPR = repr + +def encode_basestring(s): + """Return a JSON representation of a Python string + + """ + if isinstance(s, str) and HAS_UTF8.search(s) is not None: + s = s.decode('utf-8') + def replace(match): + return ESCAPE_DCT[match.group(0)] + return u'"' + ESCAPE.sub(replace, s) + u'"' + + +def py_encode_basestring_ascii(s): + """Return an ASCII-only JSON representation of a Python string + + """ + if isinstance(s, str) and HAS_UTF8.search(s) is not None: + s = s.decode('utf-8') + def replace(match): + s = match.group(0) + try: + return ESCAPE_DCT[s] + except KeyError: + n = ord(s) + if n < 0x10000: + #return '\\u{0:04x}'.format(n) + return '\\u%04x' % (n,) + else: + # surrogate pair + n -= 0x10000 + s1 = 0xd800 | ((n >> 10) & 0x3ff) + s2 = 0xdc00 | (n & 0x3ff) + #return '\\u{0:04x}\\u{1:04x}'.format(s1, s2) + return '\\u%04x\\u%04x' % (s1, s2) + return '"' + str(ESCAPE_ASCII.sub(replace, s)) + '"' + + +encode_basestring_ascii = ( + c_encode_basestring_ascii or py_encode_basestring_ascii) + +class JSONEncoder(object): + """Extensible JSON <http://json.org> encoder for Python data structures. + + Supports the following objects and types by default: + + +-------------------+---------------+ + | Python | JSON | + +===================+===============+ + | dict, namedtuple | object | + +-------------------+---------------+ + | list, tuple | array | + +-------------------+---------------+ + | str, unicode | string | + +-------------------+---------------+ + | int, long, float | number | + +-------------------+---------------+ + | True | true | + +-------------------+---------------+ + | False | false | + +-------------------+---------------+ + | None | null | + +-------------------+---------------+ + + To extend this to recognize other objects, subclass and implement a + ``.default()`` method with another method that returns a serializable + object for ``o`` if possible, otherwise it should call the superclass + implementation (to raise ``TypeError``). + + """ + item_separator = ', ' + key_separator = ': ' + def __init__(self, skipkeys=False, ensure_ascii=True, + check_circular=True, allow_nan=True, sort_keys=False, + indent=None, separators=None, encoding='utf-8', default=None, + use_decimal=True, namedtuple_as_object=True, + tuple_as_array=True): + """Constructor for JSONEncoder, with sensible defaults. 
+
+        If skipkeys is false, then it is a TypeError to attempt
+        encoding of keys that are not str, int, long, float or None. If
+        skipkeys is True, such items are simply skipped.
+
+        If ensure_ascii is true, the output is guaranteed to be str
+        objects with all incoming unicode characters escaped. If
+        ensure_ascii is false, the output will be a unicode object.
+
+        If check_circular is true, then lists, dicts, and custom encoded
+        objects will be checked for circular references during encoding to
+        prevent an infinite recursion (which would cause an OverflowError).
+        Otherwise, no such check takes place.
+
+        If allow_nan is true, then NaN, Infinity, and -Infinity will be
+        encoded as such. This behavior is not JSON specification compliant,
+        but is consistent with most JavaScript based encoders and decoders.
+        Otherwise, it will be a ValueError to encode such floats.
+
+        If sort_keys is true, then the output of dictionaries will be
+        sorted by key; this is useful for regression tests to ensure
+        that JSON serializations can be compared on a day-to-day basis.
+
+        If indent is a string, then JSON array elements and object members
+        will be pretty-printed with a newline followed by that string repeated
+        for each level of nesting. ``None`` (the default) selects the most
+        compact representation without any newlines. For backwards
+        compatibility with versions of simplejson earlier than 2.1.0, an
+        integer is also accepted and is converted to a string with that many
+        spaces.
+
+        If specified, separators should be an (item_separator, key_separator)
+        tuple. The default is (', ', ': '). To get the most compact JSON
+        representation you should specify (',', ':') to eliminate whitespace.
+
+        If specified, default is a function that gets called for objects
+        that can't otherwise be serialized. It should return a JSON encodable
+        version of the object or raise a ``TypeError``.
+
+        If encoding is not None, then all input strings will be
+        transformed into unicode using that encoding prior to JSON-encoding.
+        The default is UTF-8.
+
+        If use_decimal is true (the default), ``decimal.Decimal`` will
+        be supported directly by the encoder. For the inverse, decode JSON
+        with ``parse_float=decimal.Decimal``.
+
+        If namedtuple_as_object is true (the default), tuple subclasses with
+        ``_asdict()`` methods will be encoded as JSON objects.
+
+        If tuple_as_array is true (the default), tuple (and subclasses) will
+        be encoded as JSON arrays.
+        """
+
+        self.skipkeys = skipkeys
+        self.ensure_ascii = ensure_ascii
+        self.check_circular = check_circular
+        self.allow_nan = allow_nan
+        self.sort_keys = sort_keys
+        self.use_decimal = use_decimal
+        self.namedtuple_as_object = namedtuple_as_object
+        self.tuple_as_array = tuple_as_array
+        if isinstance(indent, (int, long)):
+            indent = ' ' * indent
+        self.indent = indent
+        if separators is not None:
+            self.item_separator, self.key_separator = separators
+        elif indent is not None:
+            self.item_separator = ','
+        if default is not None:
+            self.default = default
+        self.encoding = encoding
+
+    def default(self, o):
+        """Implement this method in a subclass such that it returns
+        a serializable object for ``o``, or calls the base implementation
+        (to raise a ``TypeError``).
+ + For example, to support arbitrary iterators, you could + implement default like this:: + + def default(self, o): + try: + iterable = iter(o) + except TypeError: + pass + else: + return list(iterable) + return JSONEncoder.default(self, o) + + """ + raise TypeError(repr(o) + " is not JSON serializable") + + def encode(self, o): + """Return a JSON string representation of a Python data structure. + + >>> from simplejson import JSONEncoder + >>> JSONEncoder().encode({"foo": ["bar", "baz"]}) + '{"foo": ["bar", "baz"]}' + + """ + # This is for extremely simple cases and benchmarks. + if isinstance(o, basestring): + if isinstance(o, str): + _encoding = self.encoding + if (_encoding is not None + and not (_encoding == 'utf-8')): + o = o.decode(_encoding) + if self.ensure_ascii: + return encode_basestring_ascii(o) + else: + return encode_basestring(o) + # This doesn't pass the iterator directly to ''.join() because the + # exceptions aren't as detailed. The list call should be roughly + # equivalent to the PySequence_Fast that ''.join() would do. + chunks = self.iterencode(o, _one_shot=True) + if not isinstance(chunks, (list, tuple)): + chunks = list(chunks) + if self.ensure_ascii: + return ''.join(chunks) + else: + return u''.join(chunks) + + def iterencode(self, o, _one_shot=False): + """Encode the given object and yield each string + representation as available. + + For example:: + + for chunk in JSONEncoder().iterencode(bigobject): + mysocket.write(chunk) + + """ + if self.check_circular: + markers = {} + else: + markers = None + if self.ensure_ascii: + _encoder = encode_basestring_ascii + else: + _encoder = encode_basestring + if self.encoding != 'utf-8': + def _encoder(o, _orig_encoder=_encoder, _encoding=self.encoding): + if isinstance(o, str): + o = o.decode(_encoding) + return _orig_encoder(o) + + def floatstr(o, allow_nan=self.allow_nan, + _repr=FLOAT_REPR, _inf=PosInf, _neginf=-PosInf): + # Check for specials. Note that this type of test is processor + # and/or platform-specific, so do tests which don't depend on + # the internals. + + if o != o: + text = 'NaN' + elif o == _inf: + text = 'Infinity' + elif o == _neginf: + text = '-Infinity' + else: + return _repr(o) + + if not allow_nan: + raise ValueError( + "Out of range float values are not JSON compliant: " + + repr(o)) + + return text + + + key_memo = {} + if (_one_shot and c_make_encoder is not None + and self.indent is None): + _iterencode = c_make_encoder( + markers, self.default, _encoder, self.indent, + self.key_separator, self.item_separator, self.sort_keys, + self.skipkeys, self.allow_nan, key_memo, self.use_decimal, + self.namedtuple_as_object, self.tuple_as_array) + else: + _iterencode = _make_iterencode( + markers, self.default, _encoder, self.indent, floatstr, + self.key_separator, self.item_separator, self.sort_keys, + self.skipkeys, _one_shot, self.use_decimal, + self.namedtuple_as_object, self.tuple_as_array) + try: + return _iterencode(o, 0) + finally: + key_memo.clear() + + +class JSONEncoderForHTML(JSONEncoder): + """An encoder that produces JSON safe to embed in HTML. + + To embed JSON content in, say, a script tag on a web page, the + characters &, < and > should be escaped. They cannot be escaped + with the usual entities (e.g. &) because they are not expanded + within <script> tags. + """ + + def encode(self, o): + # Override JSONEncoder.encode because it has hacks for + # performance that make things more complicated. 
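A minimal usage sketch for the `JSONEncoderForHTML` class described above (assuming it is imported from this `simplejson.encoder` module); the HTML-sensitive characters come out as `\uXXXX` escapes, so the result can be inlined in a `<script>` block:

    from simplejson.encoder import JSONEncoderForHTML
    # '&', '<' and '>' are emitted as \u0026, \u003c and \u003e.
    encoded = JSONEncoderForHTML().encode({"markup": "<b>&</b>"})
    # e.g. '{"markup": "\u003cb\u003e\u0026\u003c/b\u003e"}'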
+ chunks = self.iterencode(o, True) + if self.ensure_ascii: + return ''.join(chunks) + else: + return u''.join(chunks) + + def iterencode(self, o, _one_shot=False): + chunks = super(JSONEncoderForHTML, self).iterencode(o, _one_shot) + for chunk in chunks: + chunk = chunk.replace('&', '\\u0026') + chunk = chunk.replace('<', '\\u003c') + chunk = chunk.replace('>', '\\u003e') + yield chunk + + +def _make_iterencode(markers, _default, _encoder, _indent, _floatstr, + _key_separator, _item_separator, _sort_keys, _skipkeys, _one_shot, + _use_decimal, _namedtuple_as_object, _tuple_as_array, + ## HACK: hand-optimized bytecode; turn globals into locals + False=False, + True=True, + ValueError=ValueError, + basestring=basestring, + Decimal=Decimal, + dict=dict, + float=float, + id=id, + int=int, + isinstance=isinstance, + list=list, + long=long, + str=str, + tuple=tuple, + ): + + def _iterencode_list(lst, _current_indent_level): + if not lst: + yield '[]' + return + if markers is not None: + markerid = id(lst) + if markerid in markers: + raise ValueError("Circular reference detected") + markers[markerid] = lst + buf = '[' + if _indent is not None: + _current_indent_level += 1 + newline_indent = '\n' + (_indent * _current_indent_level) + separator = _item_separator + newline_indent + buf += newline_indent + else: + newline_indent = None + separator = _item_separator + first = True + for value in lst: + if first: + first = False + else: + buf = separator + if isinstance(value, basestring): + yield buf + _encoder(value) + elif value is None: + yield buf + 'null' + elif value is True: + yield buf + 'true' + elif value is False: + yield buf + 'false' + elif isinstance(value, (int, long)): + yield buf + str(value) + elif isinstance(value, float): + yield buf + _floatstr(value) + elif _use_decimal and isinstance(value, Decimal): + yield buf + str(value) + else: + yield buf + if isinstance(value, list): + chunks = _iterencode_list(value, _current_indent_level) + elif (_namedtuple_as_object and isinstance(value, tuple) and + hasattr(value, '_asdict')): + chunks = _iterencode_dict(value._asdict(), + _current_indent_level) + elif _tuple_as_array and isinstance(value, tuple): + chunks = _iterencode_list(value, _current_indent_level) + elif isinstance(value, dict): + chunks = _iterencode_dict(value, _current_indent_level) + else: + chunks = _iterencode(value, _current_indent_level) + for chunk in chunks: + yield chunk + if newline_indent is not None: + _current_indent_level -= 1 + yield '\n' + (_indent * _current_indent_level) + yield ']' + if markers is not None: + del markers[markerid] + + def _iterencode_dict(dct, _current_indent_level): + if not dct: + yield '{}' + return + if markers is not None: + markerid = id(dct) + if markerid in markers: + raise ValueError("Circular reference detected") + markers[markerid] = dct + yield '{' + if _indent is not None: + _current_indent_level += 1 + newline_indent = '\n' + (_indent * _current_indent_level) + item_separator = _item_separator + newline_indent + yield newline_indent + else: + newline_indent = None + item_separator = _item_separator + first = True + if _sort_keys: + items = dct.items() + items.sort(key=lambda kv: kv[0]) + else: + items = dct.iteritems() + for key, value in items: + if isinstance(key, basestring): + pass + # JavaScript is weakly typed for these, so it makes sense to + # also allow them. Many encoders seem to do something like this. 
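The key-coercion branch that follows this comment is easy to observe from the public API (a minimal sketch; key order in the output depends on dict iteration order):

    import simplejson
    # Bool, None and numeric keys are written the way JavaScript would
    # print them; other key types raise TypeError unless skipkeys=True.
    simplejson.dumps({True: 1, None: 2, 3.5: 3})
    # e.g. '{"true": 1, "null": 2, "3.5": 3}' (ordering may vary)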
+ elif isinstance(key, float): + key = _floatstr(key) + elif key is True: + key = 'true' + elif key is False: + key = 'false' + elif key is None: + key = 'null' + elif isinstance(key, (int, long)): + key = str(key) + elif _skipkeys: + continue + else: + raise TypeError("key " + repr(key) + " is not a string") + if first: + first = False + else: + yield item_separator + yield _encoder(key) + yield _key_separator + if isinstance(value, basestring): + yield _encoder(value) + elif value is None: + yield 'null' + elif value is True: + yield 'true' + elif value is False: + yield 'false' + elif isinstance(value, (int, long)): + yield str(value) + elif isinstance(value, float): + yield _floatstr(value) + elif _use_decimal and isinstance(value, Decimal): + yield str(value) + else: + if isinstance(value, list): + chunks = _iterencode_list(value, _current_indent_level) + elif (_namedtuple_as_object and isinstance(value, tuple) and + hasattr(value, '_asdict')): + chunks = _iterencode_dict(value._asdict(), + _current_indent_level) + elif _tuple_as_array and isinstance(value, tuple): + chunks = _iterencode_list(value, _current_indent_level) + elif isinstance(value, dict): + chunks = _iterencode_dict(value, _current_indent_level) + else: + chunks = _iterencode(value, _current_indent_level) + for chunk in chunks: + yield chunk + if newline_indent is not None: + _current_indent_level -= 1 + yield '\n' + (_indent * _current_indent_level) + yield '}' + if markers is not None: + del markers[markerid] + + def _iterencode(o, _current_indent_level): + if isinstance(o, basestring): + yield _encoder(o) + elif o is None: + yield 'null' + elif o is True: + yield 'true' + elif o is False: + yield 'false' + elif isinstance(o, (int, long)): + yield str(o) + elif isinstance(o, float): + yield _floatstr(o) + elif isinstance(o, list): + for chunk in _iterencode_list(o, _current_indent_level): + yield chunk + elif (_namedtuple_as_object and isinstance(o, tuple) and + hasattr(o, '_asdict')): + for chunk in _iterencode_dict(o._asdict(), _current_indent_level): + yield chunk + elif (_tuple_as_array and isinstance(o, tuple)): + for chunk in _iterencode_list(o, _current_indent_level): + yield chunk + elif isinstance(o, dict): + for chunk in _iterencode_dict(o, _current_indent_level): + yield chunk + elif _use_decimal and isinstance(o, Decimal): + yield str(o) + else: + if markers is not None: + markerid = id(o) + if markerid in markers: + raise ValueError("Circular reference detected") + markers[markerid] = o + o = _default(o) + for chunk in _iterencode(o, _current_indent_level): + yield chunk + if markers is not None: + del markers[markerid] + + return _iterencode diff --git a/pyload/lib/simplejson/ordered_dict.py b/pyload/lib/simplejson/ordered_dict.py new file mode 100644 index 000000000..87ad88824 --- /dev/null +++ b/pyload/lib/simplejson/ordered_dict.py @@ -0,0 +1,119 @@ +"""Drop-in replacement for collections.OrderedDict by Raymond Hettinger + +http://code.activestate.com/recipes/576693/ + +""" +from UserDict import DictMixin + +# Modified from original to support Python 2.4, see +# http://code.google.com/p/simplejson/issues/detail?id=53 +try: + all +except NameError: + def all(seq): + for elem in seq: + if not elem: + return False + return True + +class OrderedDict(dict, DictMixin): + + def __init__(self, *args, **kwds): + if len(args) > 1: + raise TypeError('expected at most 1 arguments, got %d' % len(args)) + try: + self.__end + except AttributeError: + self.clear() + self.update(*args, **kwds) + + def clear(self): + 
self.__end = end = [] + end += [None, end, end] # sentinel node for doubly linked list + self.__map = {} # key --> [key, prev, next] + dict.clear(self) + + def __setitem__(self, key, value): + if key not in self: + end = self.__end + curr = end[1] + curr[2] = end[1] = self.__map[key] = [key, curr, end] + dict.__setitem__(self, key, value) + + def __delitem__(self, key): + dict.__delitem__(self, key) + key, prev, next = self.__map.pop(key) + prev[2] = next + next[1] = prev + + def __iter__(self): + end = self.__end + curr = end[2] + while curr is not end: + yield curr[0] + curr = curr[2] + + def __reversed__(self): + end = self.__end + curr = end[1] + while curr is not end: + yield curr[0] + curr = curr[1] + + def popitem(self, last=True): + if not self: + raise KeyError('dictionary is empty') + # Modified from original to support Python 2.4, see + # http://code.google.com/p/simplejson/issues/detail?id=53 + if last: + key = reversed(self).next() + else: + key = iter(self).next() + value = self.pop(key) + return key, value + + def __reduce__(self): + items = [[k, self[k]] for k in self] + tmp = self.__map, self.__end + del self.__map, self.__end + inst_dict = vars(self).copy() + self.__map, self.__end = tmp + if inst_dict: + return (self.__class__, (items,), inst_dict) + return self.__class__, (items,) + + def keys(self): + return list(self) + + setdefault = DictMixin.setdefault + update = DictMixin.update + pop = DictMixin.pop + values = DictMixin.values + items = DictMixin.items + iterkeys = DictMixin.iterkeys + itervalues = DictMixin.itervalues + iteritems = DictMixin.iteritems + + def __repr__(self): + if not self: + return '%s()' % (self.__class__.__name__,) + return '%s(%r)' % (self.__class__.__name__, self.items()) + + def copy(self): + return self.__class__(self) + + @classmethod + def fromkeys(cls, iterable, value=None): + d = cls() + for key in iterable: + d[key] = value + return d + + def __eq__(self, other): + if isinstance(other, OrderedDict): + return len(self)==len(other) and \ + all(p==q for p, q in zip(self.items(), other.items())) + return dict.__eq__(self, other) + + def __ne__(self, other): + return not self == other diff --git a/pyload/lib/simplejson/scanner.py b/pyload/lib/simplejson/scanner.py new file mode 100644 index 000000000..54593a371 --- /dev/null +++ b/pyload/lib/simplejson/scanner.py @@ -0,0 +1,77 @@ +"""JSON token scanner +""" +import re +def _import_c_make_scanner(): + try: + from simplejson._speedups import make_scanner + return make_scanner + except ImportError: + return None +c_make_scanner = _import_c_make_scanner() + +__all__ = ['make_scanner'] + +NUMBER_RE = re.compile( + r'(-?(?:0|[1-9]\d*))(\.\d+)?([eE][-+]?\d+)?', + (re.VERBOSE | re.MULTILINE | re.DOTALL)) + +def py_make_scanner(context): + parse_object = context.parse_object + parse_array = context.parse_array + parse_string = context.parse_string + match_number = NUMBER_RE.match + encoding = context.encoding + strict = context.strict + parse_float = context.parse_float + parse_int = context.parse_int + parse_constant = context.parse_constant + object_hook = context.object_hook + object_pairs_hook = context.object_pairs_hook + memo = context.memo + + def _scan_once(string, idx): + try: + nextchar = string[idx] + except IndexError: + raise StopIteration + + if nextchar == '"': + return parse_string(string, idx + 1, encoding, strict) + elif nextchar == '{': + return parse_object((string, idx + 1), encoding, strict, + _scan_once, object_hook, object_pairs_hook, memo) + elif nextchar == '[': + return 
parse_array((string, idx + 1), _scan_once) + elif nextchar == 'n' and string[idx:idx + 4] == 'null': + return None, idx + 4 + elif nextchar == 't' and string[idx:idx + 4] == 'true': + return True, idx + 4 + elif nextchar == 'f' and string[idx:idx + 5] == 'false': + return False, idx + 5 + + m = match_number(string, idx) + if m is not None: + integer, frac, exp = m.groups() + if frac or exp: + res = parse_float(integer + (frac or '') + (exp or '')) + else: + res = parse_int(integer) + return res, m.end() + elif nextchar == 'N' and string[idx:idx + 3] == 'NaN': + return parse_constant('NaN'), idx + 3 + elif nextchar == 'I' and string[idx:idx + 8] == 'Infinity': + return parse_constant('Infinity'), idx + 8 + elif nextchar == '-' and string[idx:idx + 9] == '-Infinity': + return parse_constant('-Infinity'), idx + 9 + else: + raise StopIteration + + def scan_once(string, idx): + try: + return _scan_once(string, idx) + finally: + memo.clear() + + return scan_once + +make_scanner = c_make_scanner or py_make_scanner diff --git a/pyload/lib/simplejson/tool.py b/pyload/lib/simplejson/tool.py new file mode 100644 index 000000000..73370db55 --- /dev/null +++ b/pyload/lib/simplejson/tool.py @@ -0,0 +1,39 @@ +r"""Command-line tool to validate and pretty-print JSON + +Usage:: + + $ echo '{"json":"obj"}' | python -m simplejson.tool + { + "json": "obj" + } + $ echo '{ 1.2:3.4}' | python -m simplejson.tool + Expecting property name: line 1 column 2 (char 2) + +""" +import sys +import simplejson as json + +def main(): + if len(sys.argv) == 1: + infile = sys.stdin + outfile = sys.stdout + elif len(sys.argv) == 2: + infile = open(sys.argv[1], 'rb') + outfile = sys.stdout + elif len(sys.argv) == 3: + infile = open(sys.argv[1], 'rb') + outfile = open(sys.argv[2], 'wb') + else: + raise SystemExit(sys.argv[0] + " [infile [outfile]]") + try: + obj = json.load(infile, + object_pairs_hook=json.OrderedDict, + use_decimal=True) + except ValueError, e: + raise SystemExit(e) + json.dump(obj, outfile, sort_keys=True, indent=' ', use_decimal=True) + outfile.write('\n') + + +if __name__ == '__main__': + main() diff --git a/pyload/lib/wsgiserver/LICENSE.txt b/pyload/lib/wsgiserver/LICENSE.txt new file mode 100644 index 000000000..a15165ee2 --- /dev/null +++ b/pyload/lib/wsgiserver/LICENSE.txt @@ -0,0 +1,25 @@ +Copyright (c) 2004-2007, CherryPy Team (team@cherrypy.org) +All rights reserved. + +Redistribution and use in source and binary forms, with or without modification, +are permitted provided that the following conditions are met: + + * Redistributions of source code must retain the above copyright notice, + this list of conditions and the following disclaimer. + * Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. + * Neither the name of the CherryPy Team nor the names of its contributors + may be used to endorse or promote products derived from this software + without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND +ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE
+FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
diff --git a/pyload/lib/wsgiserver/__init__.py b/pyload/lib/wsgiserver/__init__.py
new file mode 100644
index 000000000..c380e18b0
--- /dev/null
+++ b/pyload/lib/wsgiserver/__init__.py
@@ -0,0 +1,1794 @@
+"""A high-speed, production-ready, thread-pooled, generic WSGI server.
+
+The simplest example of how to use this module directly
+(without using CherryPy's application machinery):
+
+    from cherrypy import wsgiserver
+
+    def my_crazy_app(environ, start_response):
+        status = '200 OK'
+        response_headers = [('Content-type','text/plain')]
+        start_response(status, response_headers)
+        return ['Hello world!\n']
+
+    server = wsgiserver.CherryPyWSGIServer(
+                ('0.0.0.0', 8070), my_crazy_app,
+                server_name='www.cherrypy.example')
+
+The CherryPy WSGI server can serve as many WSGI applications
+as you want in one instance by using a WSGIPathInfoDispatcher:
+
+    d = WSGIPathInfoDispatcher({'/': my_crazy_app, '/blog': my_blog_app})
+    server = wsgiserver.CherryPyWSGIServer(('0.0.0.0', 80), d)
+
+Want SSL support? Just set these attributes:
+
+    server.ssl_certificate = <filename>
+    server.ssl_private_key = <filename>
+
+    if __name__ == '__main__':
+        try:
+            server.start()
+        except KeyboardInterrupt:
+            server.stop()
+
+This won't call the CherryPy engine (application side) at all, only the
+WSGI server, which is independent from the rest of CherryPy. Don't
+let the name "CherryPyWSGIServer" throw you; the name merely reflects
+its origin, not its coupling.
+
+For those of you wanting to understand the internals of this module, here's
+the basic call flow. The server's listening thread runs a very tight loop,
+sticking incoming connections onto a Queue:
+
+    server = CherryPyWSGIServer(...)
+    server.start()
+    while True:
+        tick()
+        # This blocks until a request comes in:
+        child = socket.accept()
+        conn = HTTPConnection(child, ...)
+        server.requests.put(conn)
+
+Worker threads are kept in a pool and poll the Queue, popping off and then
+handling each connection in turn. Each connection can consist of an arbitrary
+number of requests and their responses, so we run a nested loop:
+
+    while True:
+        conn = server.requests.get()
+        conn.communicate()
+        ->  while True:
+                req = HTTPRequest(...)
+                req.parse_request()
+                ->  # Read the Request-Line, e.g. "GET /page HTTP/1.1"
+                    req.rfile.readline()
+                    req.read_headers()
+                req.respond()
+                ->  response = wsgi_app(...)
+ try: + for chunk in response: + if chunk: + req.write(chunk) + finally: + if hasattr(response, "close"): + response.close() + if req.close_connection: + return +""" + + +import base64 +import os +import Queue +import re +quoted_slash = re.compile("(?i)%2F") +import rfc822 +import socket +try: + import cStringIO as StringIO +except ImportError: + import StringIO + +_fileobject_uses_str_type = isinstance(socket._fileobject(None)._rbuf, basestring) + +import sys +import threading +import time +import traceback +from urllib import unquote +from urlparse import urlparse +import warnings + +try: + from OpenSSL import SSL + from OpenSSL import crypto +except ImportError: + SSL = None + +import errno + +def plat_specific_errors(*errnames): + """Return error numbers for all errors in errnames on this platform. + + The 'errno' module contains different global constants depending on + the specific platform (OS). This function will return the list of + numeric values for a given list of potential names. + """ + errno_names = dir(errno) + nums = [getattr(errno, k) for k in errnames if k in errno_names] + # de-dupe the list + return dict.fromkeys(nums).keys() + +socket_error_eintr = plat_specific_errors("EINTR", "WSAEINTR") + +socket_errors_to_ignore = plat_specific_errors( + "EPIPE", + "EBADF", "WSAEBADF", + "ENOTSOCK", "WSAENOTSOCK", + "ETIMEDOUT", "WSAETIMEDOUT", + "ECONNREFUSED", "WSAECONNREFUSED", + "ECONNRESET", "WSAECONNRESET", + "ECONNABORTED", "WSAECONNABORTED", + "ENETRESET", "WSAENETRESET", + "EHOSTDOWN", "EHOSTUNREACH", + ) +socket_errors_to_ignore.append("timed out") + +socket_errors_nonblocking = plat_specific_errors( + 'EAGAIN', 'EWOULDBLOCK', 'WSAEWOULDBLOCK') + +comma_separated_headers = ['ACCEPT', 'ACCEPT-CHARSET', 'ACCEPT-ENCODING', + 'ACCEPT-LANGUAGE', 'ACCEPT-RANGES', 'ALLOW', 'CACHE-CONTROL', + 'CONNECTION', 'CONTENT-ENCODING', 'CONTENT-LANGUAGE', 'EXPECT', + 'IF-MATCH', 'IF-NONE-MATCH', 'PRAGMA', 'PROXY-AUTHENTICATE', 'TE', + 'TRAILER', 'TRANSFER-ENCODING', 'UPGRADE', 'VARY', 'VIA', 'WARNING', + 'WWW-AUTHENTICATE'] + + +class WSGIPathInfoDispatcher(object): + """A WSGI dispatcher for dispatch based on the PATH_INFO. + + apps: a dict or list of (path_prefix, app) pairs. + """ + + def __init__(self, apps): + try: + apps = apps.items() + except AttributeError: + pass + + # Sort the apps by len(path), descending + apps.sort() + apps.reverse() + + # The path_prefix strings must start, but not end, with a slash. + # Use "" instead of "/". + self.apps = [(p.rstrip("/"), a) for p, a in apps] + + def __call__(self, environ, start_response): + path = environ["PATH_INFO"] or "/" + for p, app in self.apps: + # The apps list should be sorted by length, descending. 
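To make the prefix matching below concrete (a minimal sketch, reusing the `my_crazy_app` and `my_blog_app` names from the module docstring):

    d = WSGIPathInfoDispatcher({'/': my_crazy_app, '/blog': my_blog_app})
    # A request for PATH_INFO '/blog/2010' matches the '/blog' prefix, so
    # my_blog_app is called with SCRIPT_NAME extended by '/blog' and
    # PATH_INFO rewritten to '/2010'; unmatched paths get a 404.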
+ if path.startswith(p + "/") or path == p: + environ = environ.copy() + environ["SCRIPT_NAME"] = environ["SCRIPT_NAME"] + p + environ["PATH_INFO"] = path[len(p):] + return app(environ, start_response) + + start_response('404 Not Found', [('Content-Type', 'text/plain'), + ('Content-Length', '0')]) + return [''] + + +class MaxSizeExceeded(Exception): + pass + +class SizeCheckWrapper(object): + """Wraps a file-like object, raising MaxSizeExceeded if too large.""" + + def __init__(self, rfile, maxlen): + self.rfile = rfile + self.maxlen = maxlen + self.bytes_read = 0 + + def _check_length(self): + if self.maxlen and self.bytes_read > self.maxlen: + raise MaxSizeExceeded() + + def read(self, size=None): + data = self.rfile.read(size) + self.bytes_read += len(data) + self._check_length() + return data + + def readline(self, size=None): + if size is not None: + data = self.rfile.readline(size) + self.bytes_read += len(data) + self._check_length() + return data + + # User didn't specify a size ... + # We read the line in chunks to make sure it's not a 100MB line ! + res = [] + while True: + data = self.rfile.readline(256) + self.bytes_read += len(data) + self._check_length() + res.append(data) + # See http://www.cherrypy.org/ticket/421 + if len(data) < 256 or data[-1:] == "\n": + return ''.join(res) + + def readlines(self, sizehint=0): + # Shamelessly stolen from StringIO + total = 0 + lines = [] + line = self.readline() + while line: + lines.append(line) + total += len(line) + if 0 < sizehint <= total: + break + line = self.readline() + return lines + + def close(self): + self.rfile.close() + + def __iter__(self): + return self + + def next(self): + data = self.rfile.next() + self.bytes_read += len(data) + self._check_length() + return data + + +class HTTPRequest(object): + """An HTTP Request (and response). + + A single HTTP connection may consist of multiple request/response pairs. + + send: the 'send' method from the connection's socket object. + wsgi_app: the WSGI application to call. + environ: a partial WSGI environ (server and connection entries). + The caller MUST set the following entries: + * All wsgi.* entries, including .input + * SERVER_NAME and SERVER_PORT + * Any SSL_* entries + * Any custom entries like REMOTE_ADDR and REMOTE_PORT + * SERVER_SOFTWARE: the value to write in the "Server" response header. + * ACTUAL_SERVER_PROTOCOL: the value to write in the Status-Line of + the response. From RFC 2145: "An HTTP server SHOULD send a + response version equal to the highest version for which the + server is at least conditionally compliant, and whose major + version is less than or equal to the one received in the + request. An HTTP server MUST NOT send a version for which + it is not at least conditionally compliant." + + outheaders: a list of header tuples to write in the response. + ready: when True, the request has been parsed and is ready to begin + generating the response. When False, signals the calling Connection + that the response should not be generated and the connection should + close. + close_connection: signals the calling Connection that the request + should close. This does not imply an error! The client and/or + server may each request that the connection be closed. + chunked_write: if True, output will be encoded with the "chunked" + transfer-coding. This value is set automatically inside + send_headers. 
+ """ + + max_request_header_size = 0 + max_request_body_size = 0 + + def __init__(self, wfile, environ, wsgi_app): + self.rfile = environ['wsgi.input'] + self.wfile = wfile + self.environ = environ.copy() + self.wsgi_app = wsgi_app + + self.ready = False + self.started_response = False + self.status = "" + self.outheaders = [] + self.sent_headers = False + self.close_connection = False + self.chunked_write = False + + def parse_request(self): + """Parse the next HTTP request start-line and message-headers.""" + self.rfile.maxlen = self.max_request_header_size + self.rfile.bytes_read = 0 + + try: + self._parse_request() + except MaxSizeExceeded: + self.simple_response("413 Request Entity Too Large") + return + + def _parse_request(self): + # HTTP/1.1 connections are persistent by default. If a client + # requests a page, then idles (leaves the connection open), + # then rfile.readline() will raise socket.error("timed out"). + # Note that it does this based on the value given to settimeout(), + # and doesn't need the client to request or acknowledge the close + # (although your TCP stack might suffer for it: cf Apache's history + # with FIN_WAIT_2). + request_line = self.rfile.readline() + if not request_line: + # Force self.ready = False so the connection will close. + self.ready = False + return + + if request_line == "\r\n": + # RFC 2616 sec 4.1: "...if the server is reading the protocol + # stream at the beginning of a message and receives a CRLF + # first, it should ignore the CRLF." + # But only ignore one leading line! else we enable a DoS. + request_line = self.rfile.readline() + if not request_line: + self.ready = False + return + + environ = self.environ + + try: + method, path, req_protocol = request_line.strip().split(" ", 2) + except ValueError: + self.simple_response(400, "Malformed Request-Line") + return + + environ["REQUEST_METHOD"] = method + + # path may be an abs_path (including "http://host.domain.tld"); + scheme, location, path, params, qs, frag = urlparse(path) + + if frag: + self.simple_response("400 Bad Request", + "Illegal #fragment in Request-URI.") + return + + if scheme: + environ["wsgi.url_scheme"] = scheme + if params: + path = path + ";" + params + + environ["SCRIPT_NAME"] = "" + + # Unquote the path+params (e.g. "/this%20path" -> "this path"). + # http://www.w3.org/Protocols/rfc2616/rfc2616-sec5.html#sec5.1.2 + # + # But note that "...a URI must be separated into its components + # before the escaped characters within those components can be + # safely decoded." http://www.ietf.org/rfc/rfc2396.txt, sec 2.4.2 + atoms = [unquote(x) for x in quoted_slash.split(path)] + path = "%2F".join(atoms) + environ["PATH_INFO"] = path + + # Note that, like wsgiref and most other WSGI servers, + # we unquote the path but not the query string. + environ["QUERY_STRING"] = qs + + # Compare request and server HTTP protocol versions, in case our + # server does not support the requested protocol. Limit our output + # to min(req, server). We want the following output: + # request server actual written supported response + # protocol protocol response protocol feature set + # a 1.0 1.0 1.0 1.0 + # b 1.0 1.1 1.1 1.0 + # c 1.1 1.0 1.0 1.0 + # d 1.1 1.1 1.1 1.1 + # Notice that, in (b), the response will be "HTTP/1.1" even though + # the client only understands 1.0. RFC 2616 10.5.6 says we should + # only return 505 if the _major_ version is different. 
+ rp = int(req_protocol[5]), int(req_protocol[7]) + server_protocol = environ["ACTUAL_SERVER_PROTOCOL"] + sp = int(server_protocol[5]), int(server_protocol[7]) + if sp[0] != rp[0]: + self.simple_response("505 HTTP Version Not Supported") + return + # Bah. "SERVER_PROTOCOL" is actually the REQUEST protocol. + environ["SERVER_PROTOCOL"] = req_protocol + self.response_protocol = "HTTP/%s.%s" % min(rp, sp) + + # If the Request-URI was an absoluteURI, use its location atom. + if location: + environ["SERVER_NAME"] = location + + # then all the http headers + try: + self.read_headers() + except ValueError, ex: + self.simple_response("400 Bad Request", repr(ex.args)) + return + + mrbs = self.max_request_body_size + if mrbs and int(environ.get("CONTENT_LENGTH", 0)) > mrbs: + self.simple_response("413 Request Entity Too Large") + return + + # Persistent connection support + if self.response_protocol == "HTTP/1.1": + # Both server and client are HTTP/1.1 + if environ.get("HTTP_CONNECTION", "") == "close": + self.close_connection = True + else: + # Either the server or client (or both) are HTTP/1.0 + if environ.get("HTTP_CONNECTION", "") != "Keep-Alive": + self.close_connection = True + + # Transfer-Encoding support + te = None + if self.response_protocol == "HTTP/1.1": + te = environ.get("HTTP_TRANSFER_ENCODING") + if te: + te = [x.strip().lower() for x in te.split(",") if x.strip()] + + self.chunked_read = False + + if te: + for enc in te: + if enc == "chunked": + self.chunked_read = True + else: + # Note that, even if we see "chunked", we must reject + # if there is an extension we don't recognize. + self.simple_response("501 Unimplemented") + self.close_connection = True + return + + # From PEP 333: + # "Servers and gateways that implement HTTP 1.1 must provide + # transparent support for HTTP 1.1's "expect/continue" mechanism. + # This may be done in any of several ways: + # 1. Respond to requests containing an Expect: 100-continue request + # with an immediate "100 Continue" response, and proceed normally. + # 2. Proceed with the request normally, but provide the application + # with a wsgi.input stream that will send the "100 Continue" + # response if/when the application first attempts to read from + # the input stream. The read request must then remain blocked + # until the client responds. + # 3. Wait until the client decides that the server does not support + # expect/continue, and sends the request body on its own. + # (This is suboptimal, and is not recommended.) + # + # We used to do 3, but are now doing 1. Maybe we'll do 2 someday, + # but it seems like it would be a big slowdown for such a rare case. + if environ.get("HTTP_EXPECT", "") == "100-continue": + self.simple_response(100) + + self.ready = True + + def read_headers(self): + """Read header lines from the incoming stream.""" + environ = self.environ + + while True: + line = self.rfile.readline() + if not line: + # No more data--illegal end of headers + raise ValueError("Illegal end of headers.") + + if line == '\r\n': + # Normal end of headers + break + + if line[0] in ' \t': + # It's a continuation line. 
+ v = line.strip() + else: + k, v = line.split(":", 1) + k, v = k.strip().upper(), v.strip() + envname = "HTTP_" + k.replace("-", "_") + + if k in comma_separated_headers: + existing = environ.get(envname) + if existing: + v = ", ".join((existing, v)) + environ[envname] = v + + ct = environ.pop("HTTP_CONTENT_TYPE", None) + if ct is not None: + environ["CONTENT_TYPE"] = ct + cl = environ.pop("HTTP_CONTENT_LENGTH", None) + if cl is not None: + environ["CONTENT_LENGTH"] = cl + + def decode_chunked(self): + """Decode the 'chunked' transfer coding.""" + cl = 0 + data = StringIO.StringIO() + while True: + line = self.rfile.readline().strip().split(";", 1) + chunk_size = int(line.pop(0), 16) + if chunk_size <= 0: + break +## if line: chunk_extension = line[0] + cl += chunk_size + data.write(self.rfile.read(chunk_size)) + crlf = self.rfile.read(2) + if crlf != "\r\n": + self.simple_response("400 Bad Request", + "Bad chunked transfer coding " + "(expected '\\r\\n', got %r)" % crlf) + return + + # Grab any trailer headers + self.read_headers() + + data.seek(0) + self.environ["wsgi.input"] = data + self.environ["CONTENT_LENGTH"] = str(cl) or "" + return True + + def respond(self): + """Call the appropriate WSGI app and write its iterable output.""" + # Set rfile.maxlen to ensure we don't read past Content-Length. + # This will also be used to read the entire request body if errors + # are raised before the app can read the body. + if self.chunked_read: + # If chunked, Content-Length will be 0. + self.rfile.maxlen = self.max_request_body_size + else: + cl = int(self.environ.get("CONTENT_LENGTH", 0)) + if self.max_request_body_size: + self.rfile.maxlen = min(cl, self.max_request_body_size) + else: + self.rfile.maxlen = cl + self.rfile.bytes_read = 0 + + try: + self._respond() + except MaxSizeExceeded: + if not self.sent_headers: + self.simple_response("413 Request Entity Too Large") + return + + def _respond(self): + if self.chunked_read: + if not self.decode_chunked(): + self.close_connection = True + return + + response = self.wsgi_app(self.environ, self.start_response) + try: + for chunk in response: + # "The start_response callable must not actually transmit + # the response headers. Instead, it must store them for the + # server or gateway to transmit only after the first + # iteration of the application return value that yields + # a NON-EMPTY string, or upon the application's first + # invocation of the write() callable." 
(PEP 333) + if chunk: + self.write(chunk) + finally: + if hasattr(response, "close"): + response.close() + + if (self.ready and not self.sent_headers): + self.sent_headers = True + self.send_headers() + if self.chunked_write: + self.wfile.sendall("0\r\n\r\n") + + def simple_response(self, status, msg=""): + """Write a simple response back to the client.""" + status = str(status) + buf = ["%s %s\r\n" % (self.environ['ACTUAL_SERVER_PROTOCOL'], status), + "Content-Length: %s\r\n" % len(msg), + "Content-Type: text/plain\r\n"] + + if status[:3] == "413" and self.response_protocol == 'HTTP/1.1': + # Request Entity Too Large + self.close_connection = True + buf.append("Connection: close\r\n") + + buf.append("\r\n") + if msg: + buf.append(msg) + + try: + self.wfile.sendall("".join(buf)) + except socket.error, x: + if x.args[0] not in socket_errors_to_ignore: + raise + + def start_response(self, status, headers, exc_info = None): + """WSGI callable to begin the HTTP response.""" + # "The application may call start_response more than once, + # if and only if the exc_info argument is provided." + if self.started_response and not exc_info: + raise AssertionError("WSGI start_response called a second " + "time with no exc_info.") + + # "if exc_info is provided, and the HTTP headers have already been + # sent, start_response must raise an error, and should raise the + # exc_info tuple." + if self.sent_headers: + try: + raise exc_info[0], exc_info[1], exc_info[2] + finally: + exc_info = None + + self.started_response = True + self.status = status + self.outheaders.extend(headers) + return self.write + + def write(self, chunk): + """WSGI callable to write unbuffered data to the client. + + This method is also used internally by start_response (to write + data from the iterable returned by the WSGI application). + """ + if not self.started_response: + raise AssertionError("WSGI write called before start_response.") + + if not self.sent_headers: + self.sent_headers = True + self.send_headers() + + if self.chunked_write and chunk: + buf = [hex(len(chunk))[2:], "\r\n", chunk, "\r\n"] + self.wfile.sendall("".join(buf)) + else: + self.wfile.sendall(chunk) + + def send_headers(self): + """Assert, process, and send the HTTP response message-headers.""" + hkeys = [key.lower() for key, value in self.outheaders] + status = int(self.status[:3]) + + if status == 413: + # Request Entity Too Large. Close conn to avoid garbage. + self.close_connection = True + elif "content-length" not in hkeys: + # "All 1xx (informational), 204 (no content), + # and 304 (not modified) responses MUST NOT + # include a message-body." So no point chunking. + if status < 200 or status in (204, 205, 304): + pass + else: + if (self.response_protocol == 'HTTP/1.1' + and self.environ["REQUEST_METHOD"] != 'HEAD'): + # Use the chunked transfer-coding + self.chunked_write = True + self.outheaders.append(("Transfer-Encoding", "chunked")) + else: + # Closing the conn is the only way to determine len. + self.close_connection = True + + if "connection" not in hkeys: + if self.response_protocol == 'HTTP/1.1': + # Both server and client are HTTP/1.1 or better + if self.close_connection: + self.outheaders.append(("Connection", "close")) + else: + # Server and/or client are HTTP/1.0 + if not self.close_connection: + self.outheaders.append(("Connection", "Keep-Alive")) + + if (not self.close_connection) and (not self.chunked_read): + # Read any remaining request body data on the socket. 
+ # "If an origin server receives a request that does not include an + # Expect request-header field with the "100-continue" expectation, + # the request includes a request body, and the server responds + # with a final status code before reading the entire request body + # from the transport connection, then the server SHOULD NOT close + # the transport connection until it has read the entire request, + # or until the client closes the connection. Otherwise, the client + # might not reliably receive the response message. However, this + # requirement is not be construed as preventing a server from + # defending itself against denial-of-service attacks, or from + # badly broken client implementations." + size = self.rfile.maxlen - self.rfile.bytes_read + if size > 0: + self.rfile.read(size) + + if "date" not in hkeys: + self.outheaders.append(("Date", rfc822.formatdate())) + + if "server" not in hkeys: + self.outheaders.append(("Server", self.environ['SERVER_SOFTWARE'])) + + buf = [self.environ['ACTUAL_SERVER_PROTOCOL'], " ", self.status, "\r\n"] + try: + buf += [k + ": " + v + "\r\n" for k, v in self.outheaders] + except TypeError: + if not isinstance(k, str): + raise TypeError("WSGI response header key %r is not a string.") + if not isinstance(v, str): + raise TypeError("WSGI response header value %r is not a string.") + else: + raise + buf.append("\r\n") + self.wfile.sendall("".join(buf)) + + +class NoSSLError(Exception): + """Exception raised when a client speaks HTTP to an HTTPS socket.""" + pass + + +class FatalSSLAlert(Exception): + """Exception raised when the SSL implementation signals a fatal alert.""" + pass + + +if not _fileobject_uses_str_type: + class CP_fileobject(socket._fileobject): + """Faux file object attached to a socket object.""" + + def sendall(self, data): + """Sendall for non-blocking sockets.""" + while data: + try: + bytes_sent = self.send(data) + data = data[bytes_sent:] + except socket.error, e: + if e.args[0] not in socket_errors_nonblocking: + raise + + def send(self, data): + return self._sock.send(data) + + def flush(self): + if self._wbuf: + buffer = "".join(self._wbuf) + self._wbuf = [] + self.sendall(buffer) + + def recv(self, size): + while True: + try: + return self._sock.recv(size) + except socket.error, e: + if (e.args[0] not in socket_errors_nonblocking + and e.args[0] not in socket_error_eintr): + raise + + def read(self, size=-1): + # Use max, disallow tiny reads in a loop as they are very inefficient. + # We never leave read() with any leftover data from a new recv() call + # in our internal buffer. + rbufsize = max(self._rbufsize, self.default_bufsize) + # Our use of StringIO rather than lists of string objects returned by + # recv() minimizes memory usage and fragmentation that occurs when + # rbufsize is large compared to the typical return value of recv(). + buf = self._rbuf + buf.seek(0, 2) # seek end + if size < 0: + # Read until EOF + self._rbuf = StringIO.StringIO() # reset _rbuf. we consume it via buf. + while True: + data = self.recv(rbufsize) + if not data: + break + buf.write(data) + return buf.getvalue() + else: + # Read until size bytes or EOF seen, whichever comes first + buf_len = buf.tell() + if buf_len >= size: + # Already have size bytes in our buffer? Extract and return. + buf.seek(0) + rv = buf.read(size) + self._rbuf = StringIO.StringIO() + self._rbuf.write(buf.read()) + return rv + + self._rbuf = StringIO.StringIO() # reset _rbuf. we consume it via buf. 
+ while True: + left = size - buf_len + # recv() will malloc the amount of memory given as its + # parameter even though it often returns much less data + # than that. The returned data string is short lived + # as we copy it into a StringIO and free it. This avoids + # fragmentation issues on many platforms. + data = self.recv(left) + if not data: + break + n = len(data) + if n == size and not buf_len: + # Shortcut. Avoid buffer data copies when: + # - We have no data in our buffer. + # AND + # - Our call to recv returned exactly the + # number of bytes we were asked to read. + return data + if n == left: + buf.write(data) + del data # explicit free + break + assert n <= left, "recv(%d) returned %d bytes" % (left, n) + buf.write(data) + buf_len += n + del data # explicit free + #assert buf_len == buf.tell() + return buf.getvalue() + + def readline(self, size=-1): + buf = self._rbuf + buf.seek(0, 2) # seek end + if buf.tell() > 0: + # check if we already have it in our buffer + buf.seek(0) + bline = buf.readline(size) + if bline.endswith('\n') or len(bline) == size: + self._rbuf = StringIO.StringIO() + self._rbuf.write(buf.read()) + return bline + del bline + if size < 0: + # Read until \n or EOF, whichever comes first + if self._rbufsize <= 1: + # Speed up unbuffered case + buf.seek(0) + buffers = [buf.read()] + self._rbuf = StringIO.StringIO() # reset _rbuf. we consume it via buf. + data = None + recv = self.recv + while data != "\n": + data = recv(1) + if not data: + break + buffers.append(data) + return "".join(buffers) + + buf.seek(0, 2) # seek end + self._rbuf = StringIO.StringIO() # reset _rbuf. we consume it via buf. + while True: + data = self.recv(self._rbufsize) + if not data: + break + nl = data.find('\n') + if nl >= 0: + nl += 1 + buf.write(data[:nl]) + self._rbuf.write(data[nl:]) + del data + break + buf.write(data) + return buf.getvalue() + else: + # Read until size bytes or \n or EOF seen, whichever comes first + buf.seek(0, 2) # seek end + buf_len = buf.tell() + if buf_len >= size: + buf.seek(0) + rv = buf.read(size) + self._rbuf = StringIO.StringIO() + self._rbuf.write(buf.read()) + return rv + self._rbuf = StringIO.StringIO() # reset _rbuf. we consume it via buf. + while True: + data = self.recv(self._rbufsize) + if not data: + break + left = size - buf_len + # did we just receive a newline? + nl = data.find('\n', 0, left) + if nl >= 0: + nl += 1 + # save the excess data to _rbuf + self._rbuf.write(data[nl:]) + if buf_len: + buf.write(data[:nl]) + break + else: + # Shortcut. Avoid data copy through buf when returning + # a substring of our first recv(). + return data[:nl] + n = len(data) + if n == size and not buf_len: + # Shortcut. Avoid data copy through buf when + # returning exactly all of our first recv(). 
+ return data + if n >= left: + buf.write(data[:left]) + self._rbuf.write(data[left:]) + break + buf.write(data) + buf_len += n + #assert buf_len == buf.tell() + return buf.getvalue() + +else: + class CP_fileobject(socket._fileobject): + """Faux file object attached to a socket object.""" + + def sendall(self, data): + """Sendall for non-blocking sockets.""" + while data: + try: + bytes_sent = self.send(data) + data = data[bytes_sent:] + except socket.error, e: + if e.args[0] not in socket_errors_nonblocking: + raise + + def send(self, data): + return self._sock.send(data) + + def flush(self): + if self._wbuf: + buffer = "".join(self._wbuf) + self._wbuf = [] + self.sendall(buffer) + + def recv(self, size): + while True: + try: + return self._sock.recv(size) + except socket.error, e: + if (e.args[0] not in socket_errors_nonblocking + and e.args[0] not in socket_error_eintr): + raise + + def read(self, size=-1): + if size < 0: + # Read until EOF + buffers = [self._rbuf] + self._rbuf = "" + if self._rbufsize <= 1: + recv_size = self.default_bufsize + else: + recv_size = self._rbufsize + + while True: + data = self.recv(recv_size) + if not data: + break + buffers.append(data) + return "".join(buffers) + else: + # Read until size bytes or EOF seen, whichever comes first + data = self._rbuf + buf_len = len(data) + if buf_len >= size: + self._rbuf = data[size:] + return data[:size] + buffers = [] + if data: + buffers.append(data) + self._rbuf = "" + while True: + left = size - buf_len + recv_size = max(self._rbufsize, left) + data = self.recv(recv_size) + if not data: + break + buffers.append(data) + n = len(data) + if n >= left: + self._rbuf = data[left:] + buffers[-1] = data[:left] + break + buf_len += n + return "".join(buffers) + + def readline(self, size=-1): + data = self._rbuf + if size < 0: + # Read until \n or EOF, whichever comes first + if self._rbufsize <= 1: + # Speed up unbuffered case + assert data == "" + buffers = [] + while data != "\n": + data = self.recv(1) + if not data: + break + buffers.append(data) + return "".join(buffers) + nl = data.find('\n') + if nl >= 0: + nl += 1 + self._rbuf = data[nl:] + return data[:nl] + buffers = [] + if data: + buffers.append(data) + self._rbuf = "" + while True: + data = self.recv(self._rbufsize) + if not data: + break + buffers.append(data) + nl = data.find('\n') + if nl >= 0: + nl += 1 + self._rbuf = data[nl:] + buffers[-1] = data[:nl] + break + return "".join(buffers) + else: + # Read until size bytes or \n or EOF seen, whichever comes first + nl = data.find('\n', 0, size) + if nl >= 0: + nl += 1 + self._rbuf = data[nl:] + return data[:nl] + buf_len = len(data) + if buf_len >= size: + self._rbuf = data[size:] + return data[:size] + buffers = [] + if data: + buffers.append(data) + self._rbuf = "" + while True: + data = self.recv(self._rbufsize) + if not data: + break + buffers.append(data) + left = size - buf_len + nl = data.find('\n', 0, left) + if nl >= 0: + nl += 1 + self._rbuf = data[nl:] + buffers[-1] = data[:nl] + break + n = len(data) + if n >= left: + self._rbuf = data[left:] + buffers[-1] = data[:left] + break + buf_len += n + return "".join(buffers) + + +class SSL_fileobject(CP_fileobject): + """SSL file object attached to a socket object.""" + + ssl_timeout = 3 + ssl_retry = .01 + + def _safe_call(self, is_reader, call, *args, **kwargs): + """Wrap the given call with SSL error-trapping. + + is_reader: if False EOF errors will be raised. If True, EOF errors + will return "" (to emulate normal sockets). 
+ """ + start = time.time() + while True: + try: + return call(*args, **kwargs) + except SSL.WantReadError: + # Sleep and try again. This is dangerous, because it means + # the rest of the stack has no way of differentiating + # between a "new handshake" error and "client dropped". + # Note this isn't an endless loop: there's a timeout below. + time.sleep(self.ssl_retry) + except SSL.WantWriteError: + time.sleep(self.ssl_retry) + except SSL.SysCallError, e: + if is_reader and e.args == (-1, 'Unexpected EOF'): + return "" + + errnum = e.args[0] + if is_reader and errnum in socket_errors_to_ignore: + return "" + raise socket.error(errnum) + except SSL.Error, e: + if is_reader and e.args == (-1, 'Unexpected EOF'): + return "" + + thirdarg = None + try: + thirdarg = e.args[0][0][2] + except IndexError: + pass + + if thirdarg == 'http request': + # The client is talking HTTP to an HTTPS server. + raise NoSSLError() + raise FatalSSLAlert(*e.args) + except: + raise + + if time.time() - start > self.ssl_timeout: + raise socket.timeout("timed out") + + def recv(self, *args, **kwargs): + buf = [] + r = super(SSL_fileobject, self).recv + while True: + data = self._safe_call(True, r, *args, **kwargs) + buf.append(data) + p = self._sock.pending() + if not p: + return "".join(buf) + + def sendall(self, *args, **kwargs): + return self._safe_call(False, super(SSL_fileobject, self).sendall, *args, **kwargs) + + def send(self, *args, **kwargs): + return self._safe_call(False, super(SSL_fileobject, self).send, *args, **kwargs) + + +class HTTPConnection(object): + """An HTTP connection (active socket). + + socket: the raw socket object (usually TCP) for this connection. + wsgi_app: the WSGI application for this server/connection. + environ: a WSGI environ template. This will be copied for each request. + + rfile: a fileobject for reading from the socket. + send: a function for writing (+ flush) to the socket. + """ + + rbufsize = -1 + RequestHandlerClass = HTTPRequest + environ = {"wsgi.version": (1, 0), + "wsgi.url_scheme": "http", + "wsgi.multithread": True, + "wsgi.multiprocess": False, + "wsgi.run_once": False, + "wsgi.errors": sys.stderr, + } + + def __init__(self, sock, wsgi_app, environ): + self.socket = sock + self.wsgi_app = wsgi_app + + # Copy the class environ into self. + self.environ = self.environ.copy() + self.environ.update(environ) + + if SSL and isinstance(sock, SSL.ConnectionType): + timeout = sock.gettimeout() + self.rfile = SSL_fileobject(sock, "rb", self.rbufsize) + self.rfile.ssl_timeout = timeout + self.wfile = SSL_fileobject(sock, "wb", -1) + self.wfile.ssl_timeout = timeout + else: + self.rfile = CP_fileobject(sock, "rb", self.rbufsize) + self.wfile = CP_fileobject(sock, "wb", -1) + + # Wrap wsgi.input but not HTTPConnection.rfile itself. + # We're also not setting maxlen yet; we'll do that separately + # for headers and body for each iteration of self.communicate + # (if maxlen is 0 the wrapper doesn't check length). + self.environ["wsgi.input"] = SizeCheckWrapper(self.rfile, 0) + + def communicate(self): + """Read each request and respond appropriately.""" + try: + while True: + # (re)set req to None so that if something goes wrong in + # the RequestHandlerClass constructor, the error doesn't + # get written to the previous request. + req = None + req = self.RequestHandlerClass(self.wfile, self.environ, + self.wsgi_app) + + # This order of operations should guarantee correct pipelining. 
+ req.parse_request() + if not req.ready: + return + + req.respond() + if req.close_connection: + return + + except socket.error, e: + errnum = e.args[0] + if errnum == 'timed out': + if req and not req.sent_headers: + req.simple_response("408 Request Timeout") + elif errnum not in socket_errors_to_ignore: + if req and not req.sent_headers: + req.simple_response("500 Internal Server Error", + format_exc()) + return + except (KeyboardInterrupt, SystemExit): + raise + except FatalSSLAlert, e: + # Close the connection. + return + except NoSSLError: + if req and not req.sent_headers: + # Unwrap our wfile + req.wfile = CP_fileobject(self.socket._sock, "wb", -1) + req.simple_response("400 Bad Request", + "The client sent a plain HTTP request, but " + "this server only speaks HTTPS on this port.") + self.linger = True + except Exception, e: + if req and not req.sent_headers: + req.simple_response("500 Internal Server Error", format_exc()) + + linger = False + + def close(self): + """Close the socket underlying this connection.""" + self.rfile.close() + + if not self.linger: + # Python's socket module does NOT call close on the kernel socket + # when you call socket.close(). We do so manually here because we + # want this server to send a FIN TCP segment immediately. Note this + # must be called *before* calling socket.close(), because the latter + # drops its reference to the kernel socket. + self.socket._sock.close() + self.socket.close() + else: + # On the other hand, sometimes we want to hang around for a bit + # to make sure the client has a chance to read our entire + # response. Skipping the close() calls here delays the FIN + # packet until the socket object is garbage-collected later. + # Someday, perhaps, we'll do the full lingering_close that + # Apache does, but not today. + pass + + +def format_exc(limit=None): + """Like print_exc() but return a string. Backport for Python 2.3.""" + try: + etype, value, tb = sys.exc_info() + return ''.join(traceback.format_exception(etype, value, tb, limit)) + finally: + etype = value = tb = None + + +_SHUTDOWNREQUEST = None + +class WorkerThread(threading.Thread): + """Thread which continuously polls a Queue for Connection objects. + + server: the HTTP Server which spawned this thread, and which owns the + Queue and is placing active connections into it. + ready: a simple flag for the calling server to know when this thread + has begun polling the Queue. + + Due to the timing issues of polling a Queue, a WorkerThread does not + check its own 'ready' flag after it has started. To stop the thread, + it is necessary to stick a _SHUTDOWNREQUEST object onto the Queue + (one for each running WorkerThread). + """ + + conn = None + + def __init__(self, server): + self.ready = False + self.server = server + threading.Thread.__init__(self) + + def run(self): + try: + self.ready = True + while True: + conn = self.server.requests.get() + if conn is _SHUTDOWNREQUEST: + return + + self.conn = conn + try: + conn.communicate() + finally: + conn.close() + self.conn = None + except (KeyboardInterrupt, SystemExit), exc: + self.server.interrupt = exc + + +class ThreadPool(object): + """A Request Queue for the CherryPyWSGIServer which pools threads. + + ThreadPool objects must provide min, get(), put(obj), start() + and stop(timeout) attributes. 
+ """ + + def __init__(self, server, min=10, max=-1): + self.server = server + self.min = min + self.max = max + self._threads = [] + self._queue = Queue.Queue() + self.get = self._queue.get + + def start(self): + """Start the pool of threads.""" + for i in xrange(self.min): + self._threads.append(WorkerThread(self.server)) + for worker in self._threads: + worker.setName("CP WSGIServer " + worker.getName()) + worker.start() + for worker in self._threads: + while not worker.ready: + time.sleep(.1) + + def _get_idle(self): + """Number of worker threads which are idle. Read-only.""" + return len([t for t in self._threads if t.conn is None]) + idle = property(_get_idle, doc=_get_idle.__doc__) + + def put(self, obj): + self._queue.put(obj) + if obj is _SHUTDOWNREQUEST: + return + + def grow(self, amount): + """Spawn new worker threads (not above self.max).""" + for i in xrange(amount): + if self.max > 0 and len(self._threads) >= self.max: + break + worker = WorkerThread(self.server) + worker.setName("CP WSGIServer " + worker.getName()) + self._threads.append(worker) + worker.start() + + def shrink(self, amount): + """Kill off worker threads (not below self.min).""" + # Grow/shrink the pool if necessary. + # Remove any dead threads from our list + for t in self._threads: + if not t.isAlive(): + self._threads.remove(t) + amount -= 1 + + if amount > 0: + for i in xrange(min(amount, len(self._threads) - self.min)): + # Put a number of shutdown requests on the queue equal + # to 'amount'. Once each of those is processed by a worker, + # that worker will terminate and be culled from our list + # in self.put. + self._queue.put(_SHUTDOWNREQUEST) + + def stop(self, timeout=5): + # Must shut down threads here so the code that calls + # this method can know when all threads are stopped. + for worker in self._threads: + self._queue.put(_SHUTDOWNREQUEST) + + # Don't join currentThread (when stop is called inside a request). + current = threading.currentThread() + while self._threads: + worker = self._threads.pop() + if worker is not current and worker.isAlive(): + try: + if timeout is None or timeout < 0: + worker.join() + else: + worker.join(timeout) + if worker.isAlive(): + # We exhausted the timeout. + # Forcibly shut down the socket. + c = worker.conn + if c and not c.rfile.closed: + if SSL and isinstance(c.socket, SSL.ConnectionType): + # pyOpenSSL.socket.shutdown takes no args + c.socket.shutdown() + else: + c.socket.shutdown(socket.SHUT_RD) + worker.join() + except (AssertionError, + # Ignore repeated Ctrl-C. + # See http://www.cherrypy.org/ticket/691. + KeyboardInterrupt), exc1: + pass + + + +class SSLConnection: + """A thread-safe wrapper for an SSL.Connection. + + *args: the arguments to create the wrapped SSL.Connection(*args). 
+ """ + + def __init__(self, *args): + self._ssl_conn = SSL.Connection(*args) + self._lock = threading.RLock() + + for f in ('get_context', 'pending', 'send', 'write', 'recv', 'read', + 'renegotiate', 'bind', 'listen', 'connect', 'accept', + 'setblocking', 'fileno', 'shutdown', 'close', 'get_cipher_list', + 'getpeername', 'getsockname', 'getsockopt', 'setsockopt', + 'makefile', 'get_app_data', 'set_app_data', 'state_string', + 'sock_shutdown', 'get_peer_certificate', 'want_read', + 'want_write', 'set_connect_state', 'set_accept_state', + 'connect_ex', 'sendall', 'settimeout'): + exec """def %s(self, *args): + self._lock.acquire() + try: + return self._ssl_conn.%s(*args) + finally: + self._lock.release() +""" % (f, f) + + +try: + import fcntl +except ImportError: + try: + from ctypes import windll, WinError + except ImportError: + def prevent_socket_inheritance(sock): + """Dummy function, since neither fcntl nor ctypes are available.""" + pass + else: + def prevent_socket_inheritance(sock): + """Mark the given socket fd as non-inheritable (Windows).""" + if not windll.kernel32.SetHandleInformation(sock.fileno(), 1, 0): + raise WinError() +else: + def prevent_socket_inheritance(sock): + """Mark the given socket fd as non-inheritable (POSIX).""" + fd = sock.fileno() + old_flags = fcntl.fcntl(fd, fcntl.F_GETFD) + fcntl.fcntl(fd, fcntl.F_SETFD, old_flags | fcntl.FD_CLOEXEC) + + +class CherryPyWSGIServer(object): + """An HTTP server for WSGI. + + bind_addr: The interface on which to listen for connections. + For TCP sockets, a (host, port) tuple. Host values may be any IPv4 + or IPv6 address, or any valid hostname. The string 'localhost' is a + synonym for '127.0.0.1' (or '::1', if your hosts file prefers IPv6). + The string '0.0.0.0' is a special IPv4 entry meaning "any active + interface" (INADDR_ANY), and '::' is the similar IN6ADDR_ANY for + IPv6. The empty string or None are not allowed. + + For UNIX sockets, supply the filename as a string. + wsgi_app: the WSGI 'application callable'; multiple WSGI applications + may be passed as (path_prefix, app) pairs. + numthreads: the number of worker threads to create (default 10). + server_name: the string to set for WSGI's SERVER_NAME environ entry. + Defaults to socket.gethostname(). + max: the maximum number of queued requests (defaults to -1 = no limit). + request_queue_size: the 'backlog' argument to socket.listen(); + specifies the maximum number of queued connections (default 5). + timeout: the timeout in seconds for accepted connections (default 10). + + nodelay: if True (the default since 3.1), sets the TCP_NODELAY socket + option. + + protocol: the version string to write in the Status-Line of all + HTTP responses. For example, "HTTP/1.1" (the default). This + also limits the supported features used in the response. + + + SSL/HTTPS + --------- + The OpenSSL module must be importable for SSL functionality. + You can obtain it from http://pyopenssl.sourceforge.net/ + + ssl_certificate: the filename of the server SSL certificate. + ssl_privatekey: the filename of the server's private key file. + + If either of these is None (both are None by default), this server + will not use SSL. If both are given and are valid, they will be read + on server start and used in the SSL context for the listening socket. 
+ """ + + protocol = "HTTP/1.1" + _bind_addr = "127.0.0.1" + version = "CherryPy/3.1.2" + ready = False + _interrupt = None + + nodelay = True + + ConnectionClass = HTTPConnection + environ = {} + + # Paths to certificate and private key files + ssl_certificate = None + ssl_private_key = None + + def __init__(self, bind_addr, wsgi_app, numthreads=10, server_name=None, + max=-1, request_queue_size=5, timeout=10, shutdown_timeout=5): + self.requests = ThreadPool(self, min=numthreads or 1, max=max) + + if callable(wsgi_app): + # We've been handed a single wsgi_app, in CP-2.1 style. + # Assume it's mounted at "". + self.wsgi_app = wsgi_app + else: + # We've been handed a list of (path_prefix, wsgi_app) tuples, + # so that the server can call different wsgi_apps, and also + # correctly set SCRIPT_NAME. + warnings.warn("The ability to pass multiple apps is deprecated " + "and will be removed in 3.2. You should explicitly " + "include a WSGIPathInfoDispatcher instead.", + DeprecationWarning) + self.wsgi_app = WSGIPathInfoDispatcher(wsgi_app) + + self.bind_addr = bind_addr + if not server_name: + server_name = socket.gethostname() + self.server_name = server_name + self.request_queue_size = request_queue_size + + self.timeout = timeout + self.shutdown_timeout = shutdown_timeout + + def _get_numthreads(self): + return self.requests.min + def _set_numthreads(self, value): + self.requests.min = value + numthreads = property(_get_numthreads, _set_numthreads) + + def __str__(self): + return "%s.%s(%r)" % (self.__module__, self.__class__.__name__, + self.bind_addr) + + def _get_bind_addr(self): + return self._bind_addr + def _set_bind_addr(self, value): + if isinstance(value, tuple) and value[0] in ('', None): + # Despite the socket module docs, using '' does not + # allow AI_PASSIVE to work. Passing None instead + # returns '0.0.0.0' like we want. In other words: + # host AI_PASSIVE result + # '' Y 192.168.x.y + # '' N 192.168.x.y + # None Y 0.0.0.0 + # None N 127.0.0.1 + # But since you can get the same effect with an explicit + # '0.0.0.0', we deny both the empty string and None as values. + raise ValueError("Host values of '' or None are not allowed. " + "Use '0.0.0.0' (IPv4) or '::' (IPv6) instead " + "to listen on all active interfaces.") + self._bind_addr = value + bind_addr = property(_get_bind_addr, _set_bind_addr, + doc="""The interface on which to listen for connections. + + For TCP sockets, a (host, port) tuple. Host values may be any IPv4 + or IPv6 address, or any valid hostname. The string 'localhost' is a + synonym for '127.0.0.1' (or '::1', if your hosts file prefers IPv6). + The string '0.0.0.0' is a special IPv4 entry meaning "any active + interface" (INADDR_ANY), and '::' is the similar IN6ADDR_ANY for + IPv6. The empty string or None are not allowed. + + For UNIX sockets, supply the filename as a string.""") + + def start(self): + """Run the server forever.""" + # We don't have to trap KeyboardInterrupt or SystemExit here, + # because cherrpy.server already does so, calling self.stop() for us. + # If you're using this server with another framework, you should + # trap those exceptions in whatever code block calls start(). + self._interrupt = None + + # Select the appropriate socket + if isinstance(self.bind_addr, basestring): + # AF_UNIX socket + + # So we can reuse the socket... + try: os.unlink(self.bind_addr) + except: pass + + # So everyone can access the socket... 
+ try: os.chmod(self.bind_addr, 0777) + except: pass + + info = [(socket.AF_UNIX, socket.SOCK_STREAM, 0, "", self.bind_addr)] + else: + # AF_INET or AF_INET6 socket + # Get the correct address family for our host (allows IPv6 addresses) + host, port = self.bind_addr + try: + info = socket.getaddrinfo(host, port, socket.AF_UNSPEC, + socket.SOCK_STREAM, 0, socket.AI_PASSIVE) + except socket.gaierror: + # Probably a DNS issue. Assume IPv4. + info = [(socket.AF_INET, socket.SOCK_STREAM, 0, "", self.bind_addr)] + + self.socket = None + msg = "No socket could be created" + for res in info: + af, socktype, proto, canonname, sa = res + try: + self.bind(af, socktype, proto) + except socket.error, msg: + if self.socket: + self.socket.close() + self.socket = None + continue + break + if not self.socket: + raise socket.error, msg + + # Timeout so KeyboardInterrupt can be caught on Win32 + self.socket.settimeout(1) + self.socket.listen(self.request_queue_size) + + # Create worker threads + self.requests.start() + + self.ready = True + while self.ready: + self.tick() + if self.interrupt: + while self.interrupt is True: + # Wait for self.stop() to complete. See _set_interrupt. + time.sleep(0.1) + if self.interrupt: + raise self.interrupt + + def bind(self, family, type, proto=0): + """Create (or recreate) the actual socket object.""" + self.socket = socket.socket(family, type, proto) + prevent_socket_inheritance(self.socket) + self.socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + if self.nodelay: + self.socket.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) + if self.ssl_certificate and self.ssl_private_key: + if SSL is None: + raise ImportError("You must install pyOpenSSL to use HTTPS.") + + # See http://aspn.activestate.com/ASPN/Cookbook/Python/Recipe/442473 + ctx = SSL.Context(SSL.SSLv23_METHOD) + ctx.use_privatekey_file(self.ssl_private_key) + ctx.use_certificate_file(self.ssl_certificate) + self.socket = SSLConnection(ctx, self.socket) + self.populate_ssl_environ() + + # If listening on the IPV6 any address ('::' = IN6ADDR_ANY), + # activate dual-stack. See http://www.cherrypy.org/ticket/871. + if (not isinstance(self.bind_addr, basestring) + and self.bind_addr[0] == '::' and family == socket.AF_INET6): + try: + self.socket.setsockopt(socket.IPPROTO_IPV6, socket.IPV6_V6ONLY, 0) + except (AttributeError, socket.error): + # Apparently, the socket option is not available in + # this machine's TCP stack + pass + + self.socket.bind(self.bind_addr) + + def tick(self): + """Accept a new connection and put it on the Queue.""" + try: + s, addr = self.socket.accept() + prevent_socket_inheritance(s) + if not self.ready: + return + if hasattr(s, 'settimeout'): + s.settimeout(self.timeout) + + environ = self.environ.copy() + # SERVER_SOFTWARE is common for IIS. It's also helpful for + # us to pass a default value for the "Server" response header. + if environ.get("SERVER_SOFTWARE") is None: + environ["SERVER_SOFTWARE"] = "%s WSGI Server" % self.version + # set a non-standard environ entry so the WSGI app can know what + # the *real* server protocol is (and what features to support). + # See http://www.faqs.org/rfcs/rfc2145.html. + environ["ACTUAL_SERVER_PROTOCOL"] = self.protocol + environ["SERVER_NAME"] = self.server_name + + if isinstance(self.bind_addr, basestring): + # AF_UNIX. This isn't really allowed by WSGI, which doesn't + # address unix domain sockets. But it's better than nothing. 
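+                # (So a UNIX socket reports an empty SERVER_PORT, while a
+                # TCP bind of, say, ('0.0.0.0', 8000) -- assumed values --
+                # yields SERVER_PORT "8000", plus the peer's REMOTE_ADDR
+                # and REMOTE_PORT below.)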
+ environ["SERVER_PORT"] = "" + else: + environ["SERVER_PORT"] = str(self.bind_addr[1]) + # optional values + # Until we do DNS lookups, omit REMOTE_HOST + environ["REMOTE_ADDR"] = addr[0] + environ["REMOTE_PORT"] = str(addr[1]) + + conn = self.ConnectionClass(s, self.wsgi_app, environ) + self.requests.put(conn) + except socket.timeout: + # The only reason for the timeout in start() is so we can + # notice keyboard interrupts on Win32, which don't interrupt + # accept() by default + return + except socket.error, x: + if x.args[0] in socket_error_eintr: + # I *think* this is right. EINTR should occur when a signal + # is received during the accept() call; all docs say retry + # the call, and I *think* I'm reading it right that Python + # will then go ahead and poll for and handle the signal + # elsewhere. See http://www.cherrypy.org/ticket/707. + return + if x.args[0] in socket_errors_nonblocking: + # Just try again. See http://www.cherrypy.org/ticket/479. + return + if x.args[0] in socket_errors_to_ignore: + # Our socket was closed. + # See http://www.cherrypy.org/ticket/686. + return + raise + + def _get_interrupt(self): + return self._interrupt + def _set_interrupt(self, interrupt): + self._interrupt = True + self.stop() + self._interrupt = interrupt + interrupt = property(_get_interrupt, _set_interrupt, + doc="Set this to an Exception instance to " + "interrupt the server.") + + def stop(self): + """Gracefully shutdown a server that is serving forever.""" + self.ready = False + + sock = getattr(self, "socket", None) + if sock: + if not isinstance(self.bind_addr, basestring): + # Touch our own socket to make accept() return immediately. + try: + host, port = sock.getsockname()[:2] + except socket.error, x: + if x.args[0] not in socket_errors_to_ignore: + raise + else: + # Note that we're explicitly NOT using AI_PASSIVE, + # here, because we want an actual IP to touch. + # localhost won't work if we've bound to a public IP, + # but it will if we bound to '0.0.0.0' (INADDR_ANY). + for res in socket.getaddrinfo(host, port, socket.AF_UNSPEC, + socket.SOCK_STREAM): + af, socktype, proto, canonname, sa = res + s = None + try: + s = socket.socket(af, socktype, proto) + # See http://groups.google.com/group/cherrypy-users/ + # browse_frm/thread/bbfe5eb39c904fe0 + s.settimeout(1.0) + s.connect((host, port)) + s.close() + except socket.error: + if s: + s.close() + if hasattr(sock, "close"): + sock.close() + self.socket = None + + self.requests.stop(self.shutdown_timeout) + + def populate_ssl_environ(self): + """Create WSGI environ entries to be merged into each request.""" + cert = open(self.ssl_certificate, 'rb').read() + cert = crypto.load_certificate(crypto.FILETYPE_PEM, cert) + ssl_environ = { + "wsgi.url_scheme": "https", + "HTTPS": "on", + # pyOpenSSL doesn't provide access to any of these AFAICT +## 'SSL_PROTOCOL': 'SSLv2', +## SSL_CIPHER string The cipher specification name +## SSL_VERSION_INTERFACE string The mod_ssl program version +## SSL_VERSION_LIBRARY string The OpenSSL program version + } + + # Server certificate attributes + ssl_environ.update({ + 'SSL_SERVER_M_VERSION': cert.get_version(), + 'SSL_SERVER_M_SERIAL': cert.get_serial_number(), +## 'SSL_SERVER_V_START': Validity of server's certificate (start time), +## 'SSL_SERVER_V_END': Validity of server's certificate (end time), + }) + + for prefix, dn in [("I", cert.get_issuer()), + ("S", cert.get_subject())]: + # X509Name objects don't seem to have a way to get the + # complete DN string. 
Use str() and slice it instead,
+            # because str(dn) == "<X509Name object '/C=US/ST=...'>"
+            dnstr = str(dn)[18:-2]
+
+            wsgikey = 'SSL_SERVER_%s_DN' % prefix
+            ssl_environ[wsgikey] = dnstr
+
+            # The DN should be of the form: /k1=v1/k2=v2, but we must allow
+            # for any value to contain slashes itself (in a URL).
+            while dnstr:
+                pos = dnstr.rfind("=")
+                dnstr, value = dnstr[:pos], dnstr[pos + 1:]
+                pos = dnstr.rfind("/")
+                dnstr, key = dnstr[:pos], dnstr[pos + 1:]
+                if key and value:
+                    wsgikey = 'SSL_SERVER_%s_DN_%s' % (prefix, key)
+                    ssl_environ[wsgikey] = value
+
+        self.environ.update(ssl_environ)
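+
+# A worked example of the DN parsing loop in populate_ssl_environ, using an
+# assumed subject string (values are illustrative only):
+#     dnstr = '/C=US/ST=CA/O=Example/CN=example.com'
+# splitting from the right on '=' and then '/' yields, in order:
+#     SSL_SERVER_S_DN_CN = 'example.com'
+#     SSL_SERVER_S_DN_O  = 'Example'
+#     SSL_SERVER_S_DN_ST = 'CA'
+#     SSL_SERVER_S_DN_C  = 'US'
+# Parsing right-to-left is what lets a *value* contain '/' (as in a URL)
+# without being split incorrectly.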
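+
+# A minimal usage sketch (not part of the upstream module; the app and the
+# bind address below are assumptions for illustration): serve one WSGI
+# callable on localhost and shut down cleanly on Ctrl-C.
+if __name__ == '__main__':
+    def demo_app(environ, start_response):
+        body = "Hello from CherryPyWSGIServer\n"
+        start_response("200 OK", [("Content-Type", "text/plain"),
+                                  ("Content-Length", str(len(body)))])
+        return [body]
+
+    server = CherryPyWSGIServer(('127.0.0.1', 8080), demo_app,
+                                server_name='localhost')
+    try:
+        server.start()
+    except KeyboardInterrupt:
+        server.stop()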