diff options
author | RaNaN <Mast3rRaNaN@hotmail.de> | 2011-06-27 23:45:45 +0200 |
---|---|---|
committer | RaNaN <Mast3rRaNaN@hotmail.de> | 2011-06-27 23:45:45 +0200 |
commit | 0c1a92dcfa6d9775d5d0da8ef5fc8d9cc40d77b9 (patch) | |
tree | 0fb4ed212d1e9e8d48c30d69644a774eb2424b23 /module/lib | |
parent | little cli improvement (diff) | |
download | pyload-0c1a92dcfa6d9775d5d0da8ef5fc8d9cc40d77b9.tar.xz |
thrift 0.7.0 from trunk, patched for low mem usage
Diffstat (limited to 'module/lib')
-rw-r--r-- | module/lib/thrift/Thrift.py | 19 | ||||
-rw-r--r-- | module/lib/thrift/protocol/TBase.py | 298 | ||||
-rw-r--r-- | module/lib/thrift/protocol/TCompactProtocol.py | 21 | ||||
-rw-r--r-- | module/lib/thrift/protocol/TProtocol.py | 199 | ||||
-rw-r--r-- | module/lib/thrift/protocol/__init__.py | 2 | ||||
-rw-r--r-- | module/lib/thrift/server/TProcessPoolServer.py | 125 | ||||
-rw-r--r-- | module/lib/thrift/transport/TSocket.py | 8 | ||||
-rw-r--r-- | module/lib/thrift/transport/TZlibTransport.py | 261 | ||||
-rw-r--r-- | module/lib/thrift/transport/__init__.py | 2 |
9 files changed, 921 insertions, 14 deletions
diff --git a/module/lib/thrift/Thrift.py b/module/lib/thrift/Thrift.py index 91728a776..a96351276 100644 --- a/module/lib/thrift/Thrift.py +++ b/module/lib/thrift/Thrift.py @@ -38,6 +38,25 @@ class TType: UTF8 = 16 UTF16 = 17 + _VALUES_TO_NAMES = ( 'STOP', + 'VOID', + 'BOOL', + 'BYTE', + 'DOUBLE', + None, + 'I16', + None, + 'I32', + None, + 'I64', + 'STRING', + 'STRUCT', + 'MAP', + 'SET', + 'LIST', + 'UTF8', + 'UTF16' ) + class TMessageType: CALL = 1 REPLY = 2 diff --git a/module/lib/thrift/protocol/TBase.py b/module/lib/thrift/protocol/TBase.py new file mode 100644 index 000000000..dfe0d79ce --- /dev/null +++ b/module/lib/thrift/protocol/TBase.py @@ -0,0 +1,298 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +from thrift.Thrift import * +from thrift.protocol import TBinaryProtocol +from thrift.transport import TTransport + +try: + from thrift.protocol import fastbinary +except: + fastbinary = None + +class TBase(object): + __slots__ = [] + + def __repr__(self): + L = ['%s=%r' % (key, getattr(self, key)) + for key in self.__slots__ ] + return '%s(%s)' % (self.__class__.__name__, ', '.join(L)) + + def __eq__(self, other): + if not isinstance(other, self.__class__): + return False + for attr in self.__slots__: + my_val = getattr(self, attr) + other_val = getattr(other, attr) + if my_val != other_val: + return False + return True + + def __ne__(self, other): + return not (self == other) + + def read(self, iprot): + if iprot.__class__ == TBinaryProtocol.TBinaryProtocolAccelerated and isinstance(iprot.trans, TTransport.CReadableTransport) and self.thrift_spec is not None and fastbinary is not None: + fastbinary.decode_binary(self, iprot.trans, (self.__class__, self.thrift_spec)) + return + iprot.readStruct(self, self.thrift_spec) + + def write(self, oprot): + if oprot.__class__ == TBinaryProtocol.TBinaryProtocolAccelerated and self.thrift_spec is not None and fastbinary is not None: + oprot.trans.write(fastbinary.encode_binary(self, (self.__class__, self.thrift_spec))) + return + oprot.writeStruct(self, self.thrift_spec) + +class TExceptionBase(Exception): + # old style class so python2.4 can raise exceptions derived from this + # This can't inherit from TBase because of that limitation. + __slots__ = [] + + __repr__ = TBase.__repr__.im_func + __eq__ = TBase.__eq__.im_func + __ne__ = TBase.__ne__.im_func + read = TBase.read.im_func + write = TBase.write.im_func + +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +from thrift.Thrift import * +from thrift.protocol import TBinaryProtocol +from thrift.transport import TTransport + +try: + from thrift.protocol import fastbinary +except: + fastbinary = None + +def read(iprot, types, ftype, spec): + try: + return types[ftype][0]() + except KeyError: + if ftype == TType.LIST: + ltype, lsize = iprot.readListBegin() + + value = [read(iprot, types, spec[0], spec[1]) for i in range(lsize)] + + iprot.readListEnd() + return value + + elif ftype == TType.SET: + ltype, lsize = iprot.readSetBegin() + + value = set([read(iprot, types, spec[0], spec[1]) for i in range(lsize)]) + + iprot.readSetEnd() + return value + + elif ftype == TType.MAP: + key_type, key_spec = spec[0], spec[1] + val_type, val_spec = spec[2], spec[3] + + ktype, vtype, mlen = iprot.readMapBegin() + res = dict() + + for i in xrange(mlen): + key = read(iprot, types, key_type, key_spec) + res[key] = read(iprot, types, val_type, val_spec) + + iprot.readMapEnd() + return res + + elif ftype == TType.STRUCT: + return spec[0]().read(iprot) + + + + +def write(oprot, types, ftype, spec, value): + try: + types[ftype][1](value) + except KeyError: + if ftype == TType.LIST: + oprot.writeListBegin(spec[0], len(value)) + + for elem in value: + write(oprot, types, spec[0], spec[1], elem) + + oprot.writeListEnd() + elif ftype == TType.SET: + oprot.writeSetBegin(spec[0], len(value)) + + for elem in value: + write(oprot, types, spec[0], spec[1], elem) + + oprot.writeSetEnd() + elif ftype == TType.MAP: + key_type, key_spec = spec[0], spec[1] + val_type, val_spec = spec[2], spec[3] + + oprot.writeMapBegin(key_type, val_type, len(value)) + for key, val in value.iteritems(): + write(oprot, types, key_type, key_spec, key) + write(oprot, types, val_type, val_spec, val) + + oprot.writeMapEnd() + elif ftype == TType.STRUCT: + value.write(oprot) + + +class TBase2(object): + __slots__ = ("thrift_spec") + + #subclasses provides this information + thrift_spec = () + + def __repr__(self): + L = ['%s=%r' % (key, getattr(self, key)) + for key in self.__slots__ ] + return '%s(%s)' % (self.__class__.__name__, ', '.join(L)) + + def __eq__(self, other): + if not isinstance(other, self.__class__): + return False + for attr in self.__slots__: + my_val = getattr(self, attr) + other_val = getattr(other, attr) + if my_val != other_val: + return False + return True + + def __ne__(self, other): + return not (self == other) + + def read(self, iprot): + if iprot.__class__ == TBinaryProtocol.TBinaryProtocolAccelerated and isinstance(iprot.trans, TTransport.CReadableTransport) and self.thrift_spec is not None and fastbinary is not None: + fastbinary.decode_binary(self, iprot.trans, (self.__class__, self.thrift_spec)) + return + + #local copies for faster access + thrift_spec = self.thrift_spec + setter = self.__setattr__ + + iprot.readStructBegin() + while True: + (fname, ftype, fid) = iprot.readFieldBegin() + if ftype == TType.STOP: + break + + try: + specs = thrift_spec[fid] + if not specs or specs[1] != ftype: + iprot.skip(ftype) + + else: + pos, etype, ename, espec, unk = specs + value = read(iprot, iprot.primTypes, etype, espec) + setter(ename, value) + + except IndexError: + iprot.skip() + + iprot.readFieldEnd() + + iprot.readStructEnd() + + def write(self, oprot): + if oprot.__class__ == TBinaryProtocol.TBinaryProtocolAccelerated and self.thrift_spec is not None and fastbinary is not None: + oprot.trans.write(fastbinary.encode_binary(self, (self.__class__, self.thrift_spec))) + return + + #local copies for faster access + oprot.writeStructBegin(self.__class__.__name__) + getter = self.__getattribute__ + + for spec in self.thrift_spec: + if spec is None: continue + # element attributes + pos, etype, ename, espec, unk = spec + value = getter(ename) + if value is None: continue + + oprot.writeFieldBegin(ename, etype, pos) + write(oprot, oprot.primTypes, etype, espec, value) + oprot.writeFieldEnd() + + oprot.writeFieldStop() + oprot.writeStructEnd() + +class TBase(object): + __slots__ = ('thrift_spec',) + + #provides by subclasses + thrift_spec = () + + def __repr__(self): + L = ['%s=%r' % (key, getattr(self, key)) + for key in self.__slots__ ] + return '%s(%s)' % (self.__class__.__name__, ', '.join(L)) + + def __eq__(self, other): + if not isinstance(other, self.__class__): + return False + for attr in self.__slots__: + my_val = getattr(self, attr) + other_val = getattr(other, attr) + if my_val != other_val: + return False + return True + + def __ne__(self, other): + return not (self == other) + + def read(self, iprot): + if iprot.__class__ == TBinaryProtocol.TBinaryProtocolAccelerated and isinstance(iprot.trans, TTransport.CReadableTransport) and self.thrift_spec is not None and fastbinary is not None: + fastbinary.decode_binary(self, iprot.trans, (self.__class__, self.thrift_spec)) + return + iprot.readStruct(self, self.thrift_spec) + + def write(self, oprot): + if oprot.__class__ == TBinaryProtocol.TBinaryProtocolAccelerated and self.thrift_spec is not None and fastbinary is not None: + oprot.trans.write(fastbinary.encode_binary(self, (self.__class__, self.thrift_spec))) + return + oprot.writeStruct(self, self.thrift_spec) + +class TExceptionBase(Exception): + # old style class so python2.4 can raise exceptions derived from this + # This can't inherit from TBase because of that limitation. + __slots__ = [] + + __repr__ = TBase.__repr__.im_func + __eq__ = TBase.__eq__.im_func + __ne__ = TBase.__ne__.im_func + read = TBase.read.im_func + write = TBase.write.im_func + diff --git a/module/lib/thrift/protocol/TCompactProtocol.py b/module/lib/thrift/protocol/TCompactProtocol.py index fbc156a8f..280b54f0f 100644 --- a/module/lib/thrift/protocol/TCompactProtocol.py +++ b/module/lib/thrift/protocol/TCompactProtocol.py @@ -52,8 +52,9 @@ def readVarint(trans): shift += 7 class CompactType: - TRUE = 1 - FALSE = 2 + STOP = 0x00 + TRUE = 0x01 + FALSE = 0x02 BYTE = 0x03 I16 = 0x04 I32 = 0x05 @@ -65,7 +66,8 @@ class CompactType: MAP = 0x0B STRUCT = 0x0C -CTYPES = {TType.BOOL: CompactType.TRUE, # used for collection +CTYPES = {TType.STOP: CompactType.STOP, + TType.BOOL: CompactType.TRUE, # used for collection TType.BYTE: CompactType.BYTE, TType.I16: CompactType.I16, TType.I32: CompactType.I32, @@ -75,7 +77,7 @@ CTYPES = {TType.BOOL: CompactType.TRUE, # used for collection TType.STRUCT: CompactType.STRUCT, TType.LIST: CompactType.LIST, TType.SET: CompactType.SET, - TType.MAP: CompactType.MAP, + TType.MAP: CompactType.MAP } TTYPES = {} @@ -196,11 +198,15 @@ class TCompactProtocol(TProtocolBase): def writeBool(self, bool): if self.state == BOOL_WRITE: - self.__writeFieldHeader(types[bool], self.__bool_fid) + if bool: + ctype = CompactType.TRUE + else: + ctype = CompactType.FALSE + self.__writeFieldHeader(ctype, self.__bool_fid) elif self.state == CONTAINER_WRITE: self.__writeByte(int(bool)) else: - raise AssertetionError, "Invalid state in compact protocol" + raise AssertionError, "Invalid state in compact protocol" writeByte = writer(__writeByte) writeI16 = writer(__writeI16) @@ -285,9 +291,8 @@ class TCompactProtocol(TProtocolBase): return (name, type, seqid) def readMessageEnd(self): - assert self.state == VALUE_READ + assert self.state == CLEAR assert len(self.__structs) == 0 - self.state = CLEAR def readStructBegin(self): assert self.state in (CLEAR, CONTAINER_READ, VALUE_READ), self.state diff --git a/module/lib/thrift/protocol/TProtocol.py b/module/lib/thrift/protocol/TProtocol.py index be3cb1403..beb6bea16 100644 --- a/module/lib/thrift/protocol/TProtocol.py +++ b/module/lib/thrift/protocol/TProtocol.py @@ -200,6 +200,205 @@ class TProtocolBase: self.skip(etype) self.readListEnd() + # tuple of: ( 'reader method' name, is_container boolean, 'writer_method' name ) + _TTYPE_HANDLERS = ( + (None, None, False), # 0 == TType,STOP + (None, None, False), # 1 == TType.VOID # TODO: handle void? + ('readBool', 'writeBool', False), # 2 == TType.BOOL + ('readByte', 'writeByte', False), # 3 == TType.BYTE and I08 + ('readDouble', 'writeDouble', False), # 4 == TType.DOUBLE + (None, None, False), # 5, undefined + ('readI16', 'writeI16', False), # 6 == TType.I16 + (None, None, False), # 7, undefined + ('readI32', 'writeI32', False), # 8 == TType.I32 + (None, None, False), # 9, undefined + ('readI64', 'writeI64', False), # 10 == TType.I64 + ('readString', 'writeString', False), # 11 == TType.STRING and UTF7 + ('readContainerStruct', 'writeContainerStruct', True), # 12 == TType.STRUCT + ('readContainerMap', 'writeContainerMap', True), # 13 == TType.MAP + ('readContainerSet', 'writeContainerSet', True), # 14 == TType.SET + ('readContainerList', 'writeContainerList', True), # 15 == TType.LIST + (None, None, False), # 16 == TType.UTF8 # TODO: handle utf8 types? + (None, None, False)# 17 == TType.UTF16 # TODO: handle utf16 types? + ) + + def readFieldByTType(self, ttype, spec): + try: + (r_handler, w_handler, is_container) = self._TTYPE_HANDLERS[ttype] + except IndexError: + raise TProtocolException(type=TProtocolException.INVALID_DATA, + message='Invalid field type %d' % (ttype)) + if r_handler is None: + raise TProtocolException(type=TProtocolException.INVALID_DATA, + message='Invalid field type %d' % (ttype)) + reader = getattr(self, r_handler) + if not is_container: + return reader() + return reader(spec) + + def readContainerList(self, spec): + results = [] + ttype, tspec = spec[0], spec[1] + r_handler = self._TTYPE_HANDLERS[ttype][0] + reader = getattr(self, r_handler) + (list_type, list_len) = self.readListBegin() + if tspec is None: + # list values are simple types + for idx in xrange(list_len): + results.append(reader()) + else: + (elem_class, elem_spec) = tspec + for idx in xrange(list_len): + val = elem_class() + val.read(self) + results.append(val) + self.readListEnd() + return results + + def readContainerSet(self, spec): + results = set() + ttype, tspec = spec[0], spec[1] + r_handler = self._TTYPE_HANDLERS[ttype][0] + reader = getattr(self, r_handler) + (list_type, set_len) = self.readSetBegin() + if tspec is None: + # list values are simple types + for idx in xrange(set_len): + results.add(reader()) + else: + (elem_class, elem_spec) = tspec + for idx in xrange(set_len): + val = elem_class() + val.read(self) + results.add(val) + self.readSetEnd() + return results + + def readContainerStruct(self, spec): + (obj_class, obj_spec) = spec + obj = obj_class() + obj.read(self) + return obj + + def readContainerMap(self, spec): + results = dict() + key_ttype, key_spec = spec[0], spec[1] + val_ttype, val_spec = spec[2], spec[3] + (map_ktype, map_vtype, map_len) = self.readMapBegin() + # TODO: compare types we just decoded with thrift_spec and abort/skip if types disagree + key_reader = getattr(self, self._TTYPE_HANDLERS[key_ttype][0]) + val_reader = getattr(self, self._TTYPE_HANDLERS[val_ttype][0]) + # list values are simple types + for idx in xrange(map_len): + if key_spec is None: + k_val = key_reader() + else: + k_val = self.readFieldByTType(key_ttype, key_spec) + if val_spec is None: + v_val = val_reader() + else: + v_val = self.readFieldByTType(val_ttype, val_spec) + # this raises a TypeError with unhashable keys types. i.e. d=dict(); d[[0,1]] = 2 fails + results[k_val] = v_val + self.readMapEnd() + return results + + def readStruct(self, obj, thrift_spec): + self.readStructBegin() + while True: + (fname, ftype, fid) = self.readFieldBegin() + if ftype == TType.STOP: + break + try: + field = thrift_spec[fid] + except IndexError: + self.skip(ftype) + else: + if field is not None and ftype == field[1]: + fname = field[2] + fspec = field[3] + val = self.readFieldByTType(ftype, fspec) + setattr(obj, fname, val) + else: + self.skip(ftype) + self.readFieldEnd() + self.readStructEnd() + + def writeContainerStruct(self, val, spec): + val.write(self) + + def writeContainerList(self, val, spec): + self.writeListBegin(spec[0], len(val)) + r_handler, w_handler, is_container = self._TTYPE_HANDLERS[spec[0]] + e_writer = getattr(self, w_handler) + if not is_container: + for elem in val: + e_writer(elem) + else: + for elem in val: + e_writer(elem, spec) + self.writeListEnd() + + def writeContainerSet(self, val, spec): + self.writeSetBegin(spec[0], len(val)) + r_handler, w_handler, is_container = self._TTYPE_HANDLERS[spec[0]] + e_writer = getattr(self, w_handler) + if not is_container: + for elem in val: + e_writer(elem) + else: + for elem in val: + e_writer(elem, spec) + self.writeSetEnd() + + def writeContainerMap(self, val, spec): + k_type = spec[0] + v_type = spec[2] + ignore, ktype_name, k_is_container = self._TTYPE_HANDLERS[k_type] + ignore, vtype_name, v_is_container = self._TTYPE_HANDLERS[v_type] + k_writer = getattr(self, ktype_name) + v_writer = getattr(self, vtype_name) + self.writeMapBegin(k_type, v_type, len(val)) + for m_key, m_val in val.iteritems(): + if not k_is_container: + k_writer(m_key) + else: + k_writer(m_key, spec[1]) + if not v_is_container: + v_writer(m_val) + else: + v_writer(m_val, spec[3]) + self.writeMapEnd() + + def writeStruct(self, obj, thrift_spec): + self.writeStructBegin(obj.__class__.__name__) + for field in thrift_spec: + if field is None: + continue + fname = field[2] + val = getattr(obj, fname) + if val is None: + # skip writing out unset fields + continue + fid = field[0] + ftype = field[1] + fspec = field[3] + # get the writer method for this value + self.writeFieldBegin(fname, ftype, fid) + self.writeFieldByTType(ftype, val, fspec) + self.writeFieldEnd() + self.writeFieldStop() + self.writeStructEnd() + + def writeFieldByTType(self, ttype, val, spec): + r_handler, w_handler, is_container = self._TTYPE_HANDLERS[ttype] + writer = getattr(self, w_handler) + if is_container: + writer(val, spec) + else: + writer(val) + class TProtocolFactory: def getProtocol(self, trans): pass + diff --git a/module/lib/thrift/protocol/__init__.py b/module/lib/thrift/protocol/__init__.py index 01bfe18e5..d53359b28 100644 --- a/module/lib/thrift/protocol/__init__.py +++ b/module/lib/thrift/protocol/__init__.py @@ -17,4 +17,4 @@ # under the License. # -__all__ = ['TProtocol', 'TBinaryProtocol', 'fastbinary'] +__all__ = ['TProtocol', 'TBinaryProtocol', 'fastbinary', 'TBase'] diff --git a/module/lib/thrift/server/TProcessPoolServer.py b/module/lib/thrift/server/TProcessPoolServer.py new file mode 100644 index 000000000..7ed814a88 --- /dev/null +++ b/module/lib/thrift/server/TProcessPoolServer.py @@ -0,0 +1,125 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + + +import logging +from multiprocessing import Process, Value, Condition, reduction + +from TServer import TServer +from thrift.transport.TTransport import TTransportException + +class TProcessPoolServer(TServer): + + """ + Server with a fixed size pool of worker subprocesses which service requests. + Note that if you need shared state between the handlers - it's up to you! + Written by Dvir Volk, doat.com + """ + + def __init__(self, * args): + TServer.__init__(self, *args) + self.numWorkers = 10 + self.workers = [] + self.isRunning = Value('b', False) + self.stopCondition = Condition() + self.postForkCallback = None + + def setPostForkCallback(self, callback): + if not callable(callback): + raise TypeError("This is not a callback!") + self.postForkCallback = callback + + def setNumWorkers(self, num): + """Set the number of worker threads that should be created""" + self.numWorkers = num + + def workerProcess(self): + """Loop around getting clients from the shared queue and process them.""" + + if self.postForkCallback: + self.postForkCallback() + + while self.isRunning.value == True: + try: + client = self.serverTransport.accept() + self.serveClient(client) + except (KeyboardInterrupt, SystemExit): + return 0 + except Exception, x: + logging.exception(x) + + def serveClient(self, client): + """Process input/output from a client for as long as possible""" + itrans = self.inputTransportFactory.getTransport(client) + otrans = self.outputTransportFactory.getTransport(client) + iprot = self.inputProtocolFactory.getProtocol(itrans) + oprot = self.outputProtocolFactory.getProtocol(otrans) + + try: + while True: + self.processor.process(iprot, oprot) + except TTransportException, tx: + pass + except Exception, x: + logging.exception(x) + + itrans.close() + otrans.close() + + + def serve(self): + """Start a fixed number of worker threads and put client into a queue""" + + #this is a shared state that can tell the workers to exit when set as false + self.isRunning.value = True + + #first bind and listen to the port + self.serverTransport.listen() + + #fork the children + for i in range(self.numWorkers): + try: + w = Process(target=self.workerProcess) + w.daemon = True + w.start() + self.workers.append(w) + except Exception, x: + logging.exception(x) + + #wait until the condition is set by stop() + + while True: + + self.stopCondition.acquire() + try: + self.stopCondition.wait() + break + except (SystemExit, KeyboardInterrupt): + break + except Exception, x: + logging.exception(x) + + self.isRunning.value = False + + def stop(self): + self.isRunning.value = False + self.stopCondition.acquire() + self.stopCondition.notify() + self.stopCondition.release() + diff --git a/module/lib/thrift/transport/TSocket.py b/module/lib/thrift/transport/TSocket.py index d77e358a2..be6167802 100644 --- a/module/lib/thrift/transport/TSocket.py +++ b/module/lib/thrift/transport/TSocket.py @@ -57,7 +57,7 @@ class TSocket(TSocketBase): self.handle = h def isOpen(self): - return self.handle != None + return self.handle is not None def setTimeout(self, ms): if ms is None: @@ -65,7 +65,7 @@ class TSocket(TSocketBase): else: self._timeout = ms/1000.0 - if (self.handle != None): + if self.handle is not None: self.handle.settimeout(self._timeout) def open(self): @@ -126,8 +126,8 @@ class TSocket(TSocketBase): class TServerSocket(TSocketBase, TServerTransportBase): """Socket implementation of TServerTransport base.""" - def __init__(self, port=9090, unix_socket=None): - self.host = None + def __init__(self, host=None, port=9090, unix_socket=None): + self.host = host self.port = port self._unix_socket = unix_socket self.handle = None diff --git a/module/lib/thrift/transport/TZlibTransport.py b/module/lib/thrift/transport/TZlibTransport.py new file mode 100644 index 000000000..784d4e1e0 --- /dev/null +++ b/module/lib/thrift/transport/TZlibTransport.py @@ -0,0 +1,261 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +''' +TZlibTransport provides a compressed transport and transport factory +class, using the python standard library zlib module to implement +data compression. +''' + +from __future__ import division +import zlib +from cStringIO import StringIO +from TTransport import TTransportBase, CReadableTransport + +class TZlibTransportFactory(object): + ''' + Factory transport that builds zlib compressed transports. + + This factory caches the last single client/transport that it was passed + and returns the same TZlibTransport object that was created. + + This caching means the TServer class will get the _same_ transport + object for both input and output transports from this factory. + (For non-threaded scenarios only, since the cache only holds one object) + + The purpose of this caching is to allocate only one TZlibTransport where + only one is really needed (since it must have separate read/write buffers), + and makes the statistics from getCompSavings() and getCompRatio() + easier to understand. + ''' + + # class scoped cache of last transport given and zlibtransport returned + _last_trans = None + _last_z = None + + def getTransport(self, trans, compresslevel=9): + '''Wrap a transport , trans, with the TZlibTransport + compressed transport class, returning a new + transport to the caller. + + @param compresslevel: The zlib compression level, ranging + from 0 (no compression) to 9 (best compression). Defaults to 9. + @type compresslevel: int + + This method returns a TZlibTransport which wraps the + passed C{trans} TTransport derived instance. + ''' + if trans == self._last_trans: + return self._last_z + ztrans = TZlibTransport(trans, compresslevel) + self._last_trans = trans + self._last_z = ztrans + return ztrans + + +class TZlibTransport(TTransportBase, CReadableTransport): + ''' + Class that wraps a transport with zlib, compressing writes + and decompresses reads, using the python standard + library zlib module. + ''' + + # Read buffer size for the python fastbinary C extension, + # the TBinaryProtocolAccelerated class. + DEFAULT_BUFFSIZE = 4096 + + def __init__(self, trans, compresslevel=9): + ''' + Create a new TZlibTransport, wrapping C{trans}, another + TTransport derived object. + + @param trans: A thrift transport object, i.e. a TSocket() object. + @type trans: TTransport + @param compresslevel: The zlib compression level, ranging + from 0 (no compression) to 9 (best compression). Default is 9. + @type compresslevel: int + ''' + self.__trans = trans + self.compresslevel = compresslevel + self.__rbuf = StringIO() + self.__wbuf = StringIO() + self._init_zlib() + self._init_stats() + + def _reinit_buffers(self): + ''' + Internal method to initialize/reset the internal StringIO objects + for read and write buffers. + ''' + self.__rbuf = StringIO() + self.__wbuf = StringIO() + + def _init_stats(self): + ''' + Internal method to reset the internal statistics counters + for compression ratios and bandwidth savings. + ''' + self.bytes_in = 0 + self.bytes_out = 0 + self.bytes_in_comp = 0 + self.bytes_out_comp = 0 + + def _init_zlib(self): + ''' + Internal method for setting up the zlib compression and + decompression objects. + ''' + self._zcomp_read = zlib.decompressobj() + self._zcomp_write = zlib.compressobj(self.compresslevel) + + def getCompRatio(self): + ''' + Get the current measured compression ratios (in,out) from + this transport. + + Returns a tuple of: + (inbound_compression_ratio, outbound_compression_ratio) + + The compression ratios are computed as: + compressed / uncompressed + + E.g., data that compresses by 10x will have a ratio of: 0.10 + and data that compresses to half of ts original size will + have a ratio of 0.5 + + None is returned if no bytes have yet been processed in + a particular direction. + ''' + r_percent, w_percent = (None, None) + if self.bytes_in > 0: + r_percent = self.bytes_in_comp / self.bytes_in + if self.bytes_out > 0: + w_percent = self.bytes_out_comp / self.bytes_out + return (r_percent, w_percent) + + def getCompSavings(self): + ''' + Get the current count of saved bytes due to data + compression. + + Returns a tuple of: + (inbound_saved_bytes, outbound_saved_bytes) + + Note: if compression is actually expanding your + data (only likely with very tiny thrift objects), then + the values returned will be negative. + ''' + r_saved = self.bytes_in - self.bytes_in_comp + w_saved = self.bytes_out - self.bytes_out_comp + return (r_saved, w_saved) + + def isOpen(self): + '''Return the underlying transport's open status''' + return self.__trans.isOpen() + + def open(self): + """Open the underlying transport""" + self._init_stats() + return self.__trans.open() + + def listen(self): + '''Invoke the underlying transport's listen() method''' + self.__trans.listen() + + def accept(self): + '''Accept connections on the underlying transport''' + return self.__trans.accept() + + def close(self): + '''Close the underlying transport,''' + self._reinit_buffers() + self._init_zlib() + return self.__trans.close() + + def read(self, sz): + ''' + Read up to sz bytes from the decompressed bytes buffer, and + read from the underlying transport if the decompression + buffer is empty. + ''' + ret = self.__rbuf.read(sz) + if len(ret) > 0: + return ret + # keep reading from transport until something comes back + while True: + if self.readComp(sz): + break + ret = self.__rbuf.read(sz) + return ret + + def readComp(self, sz): + ''' + Read compressed data from the underlying transport, then + decompress it and append it to the internal StringIO read buffer + ''' + zbuf = self.__trans.read(sz) + zbuf = self._zcomp_read.unconsumed_tail + zbuf + buf = self._zcomp_read.decompress(zbuf) + self.bytes_in += len(zbuf) + self.bytes_in_comp += len(buf) + old = self.__rbuf.read() + self.__rbuf = StringIO(old + buf) + if len(old) + len(buf) == 0: + return False + return True + + def write(self, buf): + ''' + Write some bytes, putting them into the internal write + buffer for eventual compression. + ''' + self.__wbuf.write(buf) + + def flush(self): + ''' + Flush any queued up data in the write buffer and ensure the + compression buffer is flushed out to the underlying transport + ''' + wout = self.__wbuf.getvalue() + if len(wout) > 0: + zbuf = self._zcomp_write.compress(wout) + self.bytes_out += len(wout) + self.bytes_out_comp += len(zbuf) + else: + zbuf = '' + ztail = self._zcomp_write.flush(zlib.Z_SYNC_FLUSH) + self.bytes_out_comp += len(ztail) + if (len(zbuf) + len(ztail)) > 0: + self.__wbuf = StringIO() + self.__trans.write(zbuf + ztail) + self.__trans.flush() + + @property + def cstringio_buf(self): + '''Implement the CReadableTransport interface''' + return self.__rbuf + + def cstringio_refill(self, partialread, reqlen): + '''Implement the CReadableTransport interface for refill''' + retstring = partialread + if reqlen < self.DEFAULT_BUFFSIZE: + retstring += self.read(self.DEFAULT_BUFFSIZE) + while len(retstring) < reqlen: + retstring += self.read(reqlen - len(retstring)) + self.__rbuf = StringIO(retstring) + return self.__rbuf diff --git a/module/lib/thrift/transport/__init__.py b/module/lib/thrift/transport/__init__.py index 02c6048a9..46e54fe6b 100644 --- a/module/lib/thrift/transport/__init__.py +++ b/module/lib/thrift/transport/__init__.py @@ -17,4 +17,4 @@ # under the License. # -__all__ = ['TTransport', 'TSocket', 'THttpClient'] +__all__ = ['TTransport', 'TSocket', 'THttpClient','TZlibTransport'] |