diff options
| author | 2011-06-27 23:45:45 +0200 | |
|---|---|---|
| committer | 2011-06-27 23:45:45 +0200 | |
| commit | 0c1a92dcfa6d9775d5d0da8ef5fc8d9cc40d77b9 (patch) | |
| tree | 0fb4ed212d1e9e8d48c30d69644a774eb2424b23 /module/lib/thrift/protocol | |
| parent | little cli improvement (diff) | |
| download | pyload-0c1a92dcfa6d9775d5d0da8ef5fc8d9cc40d77b9.tar.xz | |
thrift 0.7.0 from trunk, patched for low mem usage
Diffstat (limited to 'module/lib/thrift/protocol')
| -rw-r--r-- | module/lib/thrift/protocol/TBase.py | 298 | ||||
| -rw-r--r-- | module/lib/thrift/protocol/TCompactProtocol.py | 21 | ||||
| -rw-r--r-- | module/lib/thrift/protocol/TProtocol.py | 199 | ||||
| -rw-r--r-- | module/lib/thrift/protocol/__init__.py | 2 | 
4 files changed, 511 insertions, 9 deletions
| diff --git a/module/lib/thrift/protocol/TBase.py b/module/lib/thrift/protocol/TBase.py new file mode 100644 index 000000000..dfe0d79ce --- /dev/null +++ b/module/lib/thrift/protocol/TBase.py @@ -0,0 +1,298 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +#   http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +from thrift.Thrift import * +from thrift.protocol import TBinaryProtocol +from thrift.transport import TTransport + +try: +  from thrift.protocol import fastbinary +except: +  fastbinary = None + +class TBase(object): +  __slots__ = [] + +  def __repr__(self): +    L = ['%s=%r' % (key, getattr(self, key)) +              for key in self.__slots__ ] +    return '%s(%s)' % (self.__class__.__name__, ', '.join(L)) + +  def __eq__(self, other): +    if not isinstance(other, self.__class__): +      return False +    for attr in self.__slots__: +      my_val = getattr(self, attr) +      other_val = getattr(other, attr) +      if my_val != other_val: +        return False +    return True +     +  def __ne__(self, other): +    return not (self == other) +   +  def read(self, iprot): +    if iprot.__class__ == TBinaryProtocol.TBinaryProtocolAccelerated and isinstance(iprot.trans, TTransport.CReadableTransport) and self.thrift_spec is not None and fastbinary is not None: +      fastbinary.decode_binary(self, iprot.trans, (self.__class__, self.thrift_spec)) +      return +    iprot.readStruct(self, self.thrift_spec) + +  def write(self, oprot): +    if oprot.__class__ == TBinaryProtocol.TBinaryProtocolAccelerated and self.thrift_spec is not None and fastbinary is not None: +      oprot.trans.write(fastbinary.encode_binary(self, (self.__class__, self.thrift_spec))) +      return +    oprot.writeStruct(self, self.thrift_spec) + +class TExceptionBase(Exception): +  # old style class so python2.4 can raise exceptions derived from this +  #  This can't inherit from TBase because of that limitation. +  __slots__ = [] +   +  __repr__ = TBase.__repr__.im_func +  __eq__ = TBase.__eq__.im_func +  __ne__ = TBase.__ne__.im_func +  read = TBase.read.im_func +  write = TBase.write.im_func + +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +#   http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +from thrift.Thrift import * +from thrift.protocol import TBinaryProtocol +from thrift.transport import TTransport + +try: +  from thrift.protocol import fastbinary +except: +  fastbinary = None + +def read(iprot, types, ftype, spec): +    try: +        return types[ftype][0]() +    except KeyError: +        if ftype == TType.LIST: +            ltype, lsize = iprot.readListBegin() + +            value = [read(iprot, types, spec[0], spec[1]) for i in range(lsize)] + +            iprot.readListEnd() +            return value +         +        elif ftype == TType.SET: +            ltype, lsize = iprot.readSetBegin() + +            value = set([read(iprot, types, spec[0], spec[1]) for i in range(lsize)]) + +            iprot.readSetEnd() +            return value + +        elif ftype == TType.MAP: +            key_type, key_spec = spec[0], spec[1] +            val_type, val_spec = spec[2], spec[3] + +            ktype, vtype, mlen = iprot.readMapBegin() +            res = dict() + +            for i in xrange(mlen): +                key = read(iprot, types, key_type, key_spec) +                res[key] = read(iprot, types, val_type, val_spec) + +            iprot.readMapEnd() +            return res + +        elif ftype == TType.STRUCT: +            return spec[0]().read(iprot) + + + + +def write(oprot, types, ftype, spec, value): +    try: +        types[ftype][1](value) +    except KeyError: +        if ftype == TType.LIST: +            oprot.writeListBegin(spec[0], len(value)) + +            for elem in value: +                write(oprot, types, spec[0], spec[1], elem) + +            oprot.writeListEnd() +        elif ftype == TType.SET: +            oprot.writeSetBegin(spec[0], len(value)) + +            for elem in value: +                write(oprot, types, spec[0], spec[1], elem) +                 +            oprot.writeSetEnd() +        elif ftype == TType.MAP: +            key_type, key_spec = spec[0], spec[1] +            val_type, val_spec = spec[2], spec[3] + +            oprot.writeMapBegin(key_type, val_type, len(value)) +            for key, val in value.iteritems(): +                write(oprot, types, key_type, key_spec, key) +                write(oprot, types, val_type, val_spec, val) + +            oprot.writeMapEnd() +        elif ftype == TType.STRUCT: +            value.write(oprot) + + +class TBase2(object): +    __slots__ = ("thrift_spec") + +    #subclasses provides this information +    thrift_spec = () + +    def __repr__(self): +        L = ['%s=%r' % (key, getattr(self, key)) +              for key in self.__slots__ ] +        return '%s(%s)' % (self.__class__.__name__, ', '.join(L)) + +    def __eq__(self, other): +        if not isinstance(other, self.__class__): +            return False +        for attr in self.__slots__: +            my_val = getattr(self, attr) +            other_val = getattr(other, attr) +            if my_val != other_val: +                return False +        return True + +    def __ne__(self, other): +        return not (self == other) + +    def read(self, iprot): +        if iprot.__class__ == TBinaryProtocol.TBinaryProtocolAccelerated and isinstance(iprot.trans, TTransport.CReadableTransport) and self.thrift_spec is not None and fastbinary is not None: +            fastbinary.decode_binary(self, iprot.trans, (self.__class__, self.thrift_spec)) +            return + +        #local copies for faster access +        thrift_spec = self.thrift_spec +        setter = self.__setattr__ + +        iprot.readStructBegin() +        while True: +            (fname, ftype, fid) = iprot.readFieldBegin() +            if ftype == TType.STOP: +                break + +            try: +                specs = thrift_spec[fid] +                if not specs or specs[1] != ftype: +                    iprot.skip(ftype) + +                else: +                    pos, etype, ename, espec, unk = specs +                    value = read(iprot, iprot.primTypes, etype, espec) +                    setter(ename, value) +                     +            except IndexError: +                iprot.skip() + +            iprot.readFieldEnd() + +        iprot.readStructEnd() + +    def write(self, oprot): +        if oprot.__class__ == TBinaryProtocol.TBinaryProtocolAccelerated and self.thrift_spec is not None and fastbinary is not None: +            oprot.trans.write(fastbinary.encode_binary(self, (self.__class__, self.thrift_spec))) +            return + +        #local copies for faster access +        oprot.writeStructBegin(self.__class__.__name__) +        getter = self.__getattribute__ + +        for spec in self.thrift_spec: +            if spec is None: continue +            # element attributes +            pos, etype, ename, espec, unk = spec +            value = getter(ename) +            if value is None: continue + +            oprot.writeFieldBegin(ename, etype, pos) +            write(oprot, oprot.primTypes, etype, espec, value) +            oprot.writeFieldEnd() + +        oprot.writeFieldStop() +        oprot.writeStructEnd() + +class TBase(object): +  __slots__ = ('thrift_spec',) + +  #provides by subclasses +  thrift_spec = () + +  def __repr__(self): +    L = ['%s=%r' % (key, getattr(self, key)) +              for key in self.__slots__ ] +    return '%s(%s)' % (self.__class__.__name__, ', '.join(L)) + +  def __eq__(self, other): +    if not isinstance(other, self.__class__): +      return False +    for attr in self.__slots__: +      my_val = getattr(self, attr) +      other_val = getattr(other, attr) +      if my_val != other_val: +        return False +    return True +     +  def __ne__(self, other): +    return not (self == other) +   +  def read(self, iprot): +    if iprot.__class__ == TBinaryProtocol.TBinaryProtocolAccelerated and isinstance(iprot.trans, TTransport.CReadableTransport) and self.thrift_spec is not None and fastbinary is not None: +      fastbinary.decode_binary(self, iprot.trans, (self.__class__, self.thrift_spec)) +      return +    iprot.readStruct(self, self.thrift_spec) + +  def write(self, oprot): +    if oprot.__class__ == TBinaryProtocol.TBinaryProtocolAccelerated and self.thrift_spec is not None and fastbinary is not None: +      oprot.trans.write(fastbinary.encode_binary(self, (self.__class__, self.thrift_spec))) +      return +    oprot.writeStruct(self, self.thrift_spec) + +class TExceptionBase(Exception): +  # old style class so python2.4 can raise exceptions derived from this +  #  This can't inherit from TBase because of that limitation. +  __slots__ = [] +   +  __repr__ = TBase.__repr__.im_func +  __eq__ = TBase.__eq__.im_func +  __ne__ = TBase.__ne__.im_func +  read = TBase.read.im_func +  write = TBase.write.im_func + diff --git a/module/lib/thrift/protocol/TCompactProtocol.py b/module/lib/thrift/protocol/TCompactProtocol.py index fbc156a8f..280b54f0f 100644 --- a/module/lib/thrift/protocol/TCompactProtocol.py +++ b/module/lib/thrift/protocol/TCompactProtocol.py @@ -52,8 +52,9 @@ def readVarint(trans):      shift += 7  class CompactType: -  TRUE = 1 -  FALSE = 2 +  STOP = 0x00 +  TRUE = 0x01 +  FALSE = 0x02    BYTE = 0x03    I16 = 0x04    I32 = 0x05 @@ -65,7 +66,8 @@ class CompactType:    MAP = 0x0B    STRUCT = 0x0C -CTYPES = {TType.BOOL: CompactType.TRUE, # used for collection +CTYPES = {TType.STOP: CompactType.STOP, +          TType.BOOL: CompactType.TRUE, # used for collection            TType.BYTE: CompactType.BYTE,            TType.I16: CompactType.I16,            TType.I32: CompactType.I32, @@ -75,7 +77,7 @@ CTYPES = {TType.BOOL: CompactType.TRUE, # used for collection            TType.STRUCT: CompactType.STRUCT,            TType.LIST: CompactType.LIST,            TType.SET: CompactType.SET, -          TType.MAP: CompactType.MAP, +          TType.MAP: CompactType.MAP            }  TTYPES = {} @@ -196,11 +198,15 @@ class TCompactProtocol(TProtocolBase):    def writeBool(self, bool):      if self.state == BOOL_WRITE: -      self.__writeFieldHeader(types[bool], self.__bool_fid) +        if bool: +            ctype = CompactType.TRUE +        else: +            ctype = CompactType.FALSE +        self.__writeFieldHeader(ctype, self.__bool_fid)      elif self.state == CONTAINER_WRITE:        self.__writeByte(int(bool))      else: -      raise AssertetionError, "Invalid state in compact protocol" +      raise AssertionError, "Invalid state in compact protocol"    writeByte = writer(__writeByte)    writeI16 = writer(__writeI16) @@ -285,9 +291,8 @@ class TCompactProtocol(TProtocolBase):      return (name, type, seqid)    def readMessageEnd(self): -    assert self.state == VALUE_READ +    assert self.state == CLEAR      assert len(self.__structs) == 0 -    self.state = CLEAR    def readStructBegin(self):      assert self.state in (CLEAR, CONTAINER_READ, VALUE_READ), self.state diff --git a/module/lib/thrift/protocol/TProtocol.py b/module/lib/thrift/protocol/TProtocol.py index be3cb1403..beb6bea16 100644 --- a/module/lib/thrift/protocol/TProtocol.py +++ b/module/lib/thrift/protocol/TProtocol.py @@ -200,6 +200,205 @@ class TProtocolBase:          self.skip(etype)        self.readListEnd() +  # tuple of: ( 'reader method' name, is_container boolean, 'writer_method' name ) +  _TTYPE_HANDLERS = ( +       (None, None, False), # 0 == TType,STOP +       (None, None, False), # 1 == TType.VOID # TODO: handle void? +       ('readBool', 'writeBool', False), # 2 == TType.BOOL +       ('readByte',  'writeByte', False), # 3 == TType.BYTE and I08 +       ('readDouble', 'writeDouble', False), # 4 == TType.DOUBLE +       (None, None, False), # 5, undefined +       ('readI16', 'writeI16', False), # 6 == TType.I16 +       (None, None, False), # 7, undefined +       ('readI32', 'writeI32', False), # 8 == TType.I32 +       (None, None, False), # 9, undefined +       ('readI64', 'writeI64', False), # 10 == TType.I64 +       ('readString', 'writeString', False), # 11 == TType.STRING and UTF7 +       ('readContainerStruct', 'writeContainerStruct', True), # 12 == TType.STRUCT +       ('readContainerMap', 'writeContainerMap', True), # 13 == TType.MAP +       ('readContainerSet', 'writeContainerSet', True), # 14 == TType.SET +       ('readContainerList', 'writeContainerList', True), # 15 == TType.LIST +       (None, None, False), # 16 == TType.UTF8 # TODO: handle utf8 types? +       (None, None, False)# 17 == TType.UTF16 # TODO: handle utf16 types? +      ) + +  def readFieldByTType(self, ttype, spec): +    try: +      (r_handler, w_handler, is_container) = self._TTYPE_HANDLERS[ttype] +    except IndexError: +      raise TProtocolException(type=TProtocolException.INVALID_DATA, +                               message='Invalid field type %d' % (ttype)) +    if r_handler is None: +      raise TProtocolException(type=TProtocolException.INVALID_DATA, +                               message='Invalid field type %d' % (ttype)) +    reader = getattr(self, r_handler) +    if not is_container: +      return reader() +    return reader(spec) + +  def readContainerList(self, spec): +    results = [] +    ttype, tspec = spec[0], spec[1] +    r_handler = self._TTYPE_HANDLERS[ttype][0] +    reader = getattr(self, r_handler) +    (list_type, list_len) = self.readListBegin() +    if tspec is None: +      # list values are simple types +      for idx in xrange(list_len): +        results.append(reader()) +    else: +      (elem_class, elem_spec) = tspec +      for idx in xrange(list_len): +        val = elem_class() +        val.read(self) +        results.append(val) +    self.readListEnd() +    return results + +  def readContainerSet(self, spec): +    results = set() +    ttype, tspec = spec[0], spec[1] +    r_handler = self._TTYPE_HANDLERS[ttype][0] +    reader = getattr(self, r_handler) +    (list_type, set_len) = self.readSetBegin() +    if tspec is None: +      # list values are simple types +      for idx in xrange(set_len): +        results.add(reader()) +    else: +      (elem_class, elem_spec) = tspec +      for idx in xrange(set_len): +        val = elem_class() +        val.read(self) +        results.add(val) +    self.readSetEnd() +    return results + +  def readContainerStruct(self, spec): +    (obj_class, obj_spec) = spec +    obj = obj_class() +    obj.read(self) +    return obj +   +  def readContainerMap(self, spec): +    results = dict() +    key_ttype, key_spec = spec[0], spec[1] +    val_ttype, val_spec = spec[2], spec[3] +    (map_ktype, map_vtype, map_len) = self.readMapBegin() +    # TODO: compare types we just decoded with thrift_spec and abort/skip if types disagree +    key_reader = getattr(self, self._TTYPE_HANDLERS[key_ttype][0]) +    val_reader = getattr(self, self._TTYPE_HANDLERS[val_ttype][0]) +    # list values are simple types +    for idx in xrange(map_len): +      if key_spec is None: +        k_val = key_reader() +      else: +        k_val = self.readFieldByTType(key_ttype, key_spec) +      if val_spec is None: +        v_val = val_reader() +      else: +        v_val = self.readFieldByTType(val_ttype, val_spec) +      # this raises a TypeError with unhashable keys types. i.e. d=dict(); d[[0,1]] = 2 fails +      results[k_val] = v_val +    self.readMapEnd() +    return results + +  def readStruct(self, obj, thrift_spec): +    self.readStructBegin() +    while True: +      (fname, ftype, fid) = self.readFieldBegin() +      if ftype == TType.STOP: +        break +      try: +        field = thrift_spec[fid] +      except IndexError: +        self.skip(ftype) +      else: +        if field is not None and ftype == field[1]: +          fname = field[2] +          fspec = field[3] +          val = self.readFieldByTType(ftype, fspec) +          setattr(obj, fname, val) +        else: +          self.skip(ftype) +      self.readFieldEnd() +    self.readStructEnd() + +  def writeContainerStruct(self, val, spec): +    val.write(self) + +  def writeContainerList(self, val, spec): +    self.writeListBegin(spec[0], len(val)) +    r_handler, w_handler, is_container  = self._TTYPE_HANDLERS[spec[0]] +    e_writer = getattr(self, w_handler) +    if not is_container: +      for elem in val: +        e_writer(elem) +    else: +      for elem in val: +        e_writer(elem, spec) +    self.writeListEnd() + +  def writeContainerSet(self, val, spec): +    self.writeSetBegin(spec[0], len(val)) +    r_handler, w_handler, is_container = self._TTYPE_HANDLERS[spec[0]] +    e_writer = getattr(self, w_handler) +    if not is_container: +      for elem in val: +        e_writer(elem) +    else: +      for elem in val: +        e_writer(elem, spec) +    self.writeSetEnd() + +  def writeContainerMap(self, val, spec): +    k_type = spec[0] +    v_type = spec[2] +    ignore, ktype_name, k_is_container = self._TTYPE_HANDLERS[k_type] +    ignore, vtype_name, v_is_container = self._TTYPE_HANDLERS[v_type] +    k_writer = getattr(self, ktype_name) +    v_writer = getattr(self, vtype_name) +    self.writeMapBegin(k_type, v_type, len(val)) +    for m_key, m_val in val.iteritems(): +      if not k_is_container: +        k_writer(m_key) +      else: +        k_writer(m_key, spec[1]) +      if not v_is_container: +        v_writer(m_val) +      else: +        v_writer(m_val, spec[3]) +    self.writeMapEnd() + +  def writeStruct(self, obj, thrift_spec): +    self.writeStructBegin(obj.__class__.__name__) +    for field in thrift_spec: +      if field is None: +        continue +      fname = field[2] +      val = getattr(obj, fname) +      if val is None: +        # skip writing out unset fields +        continue +      fid = field[0] +      ftype = field[1] +      fspec = field[3] +      # get the writer method for this value +      self.writeFieldBegin(fname, ftype, fid) +      self.writeFieldByTType(ftype, val, fspec) +      self.writeFieldEnd() +    self.writeFieldStop() +    self.writeStructEnd() + +  def writeFieldByTType(self, ttype, val, spec): +    r_handler, w_handler, is_container = self._TTYPE_HANDLERS[ttype] +    writer = getattr(self, w_handler) +    if is_container: +      writer(val, spec) +    else: +      writer(val) +  class TProtocolFactory:    def getProtocol(self, trans):      pass + diff --git a/module/lib/thrift/protocol/__init__.py b/module/lib/thrift/protocol/__init__.py index 01bfe18e5..d53359b28 100644 --- a/module/lib/thrift/protocol/__init__.py +++ b/module/lib/thrift/protocol/__init__.py @@ -17,4 +17,4 @@  # under the License.  # -__all__ = ['TProtocol', 'TBinaryProtocol', 'fastbinary'] +__all__ = ['TProtocol', 'TBinaryProtocol', 'fastbinary', 'TBase'] | 
