summaryrefslogtreecommitdiffstats
path: root/module
diff options
context:
space:
mode:
Diffstat (limited to 'module')
-rw-r--r--module/lib/thrift/TSCons.py8
-rw-r--r--module/lib/thrift/TSerialization.py10
-rw-r--r--module/lib/thrift/TTornado.py153
-rw-r--r--module/lib/thrift/Thrift.py56
-rw-r--r--module/lib/thrift/protocol/TBase.py27
-rw-r--r--module/lib/thrift/protocol/TBinaryProtocol.py13
-rw-r--r--module/lib/thrift/protocol/TCompactProtocol.py34
-rw-r--r--module/lib/thrift/protocol/TJSONProtocol.py550
-rw-r--r--module/lib/thrift/protocol/TProtocol.py102
-rw-r--r--module/lib/thrift/protocol/__init__.py2
-rw-r--r--module/lib/thrift/protocol/fastbinary.c1219
-rw-r--r--module/lib/thrift/server/THttpServer.py21
-rw-r--r--module/lib/thrift/server/TNonblockingServer.py120
-rw-r--r--module/lib/thrift/server/TProcessPoolServer.py29
-rw-r--r--module/lib/thrift/server/TServer.py37
-rw-r--r--module/lib/thrift/transport/THttpClient.py47
-rw-r--r--module/lib/thrift/transport/TSSLSocket.py214
-rw-r--r--module/lib/thrift/transport/TSocket.py33
-rw-r--r--module/lib/thrift/transport/TTransport.py25
-rw-r--r--module/lib/thrift/transport/TTwisted.py6
-rw-r--r--module/lib/thrift/transport/TZlibTransport.py109
-rw-r--r--module/lib/thrift/transport/__init__.py2
22 files changed, 2524 insertions, 293 deletions
diff --git a/module/lib/thrift/TSCons.py b/module/lib/thrift/TSCons.py
index 24046256c..da8d2833b 100644
--- a/module/lib/thrift/TSCons.py
+++ b/module/lib/thrift/TSCons.py
@@ -20,14 +20,16 @@
from os import path
from SCons.Builder import Builder
+
def scons_env(env, add=''):
opath = path.dirname(path.abspath('$TARGET'))
lstr = 'thrift --gen cpp -o ' + opath + ' ' + add + ' $SOURCE'
- cppbuild = Builder(action = lstr)
- env.Append(BUILDERS = {'ThriftCpp' : cppbuild})
+ cppbuild = Builder(action=lstr)
+ env.Append(BUILDERS={'ThriftCpp': cppbuild})
+
def gen_cpp(env, dir, file):
scons_env(env)
suffixes = ['_types.h', '_types.cpp']
targets = map(lambda s: 'gen-cpp/' + file + s, suffixes)
- return env.ThriftCpp(targets, dir+file+'.thrift')
+ return env.ThriftCpp(targets, dir + file + '.thrift')
diff --git a/module/lib/thrift/TSerialization.py b/module/lib/thrift/TSerialization.py
index b19f98aa8..8a58d89df 100644
--- a/module/lib/thrift/TSerialization.py
+++ b/module/lib/thrift/TSerialization.py
@@ -20,15 +20,19 @@
from protocol import TBinaryProtocol
from transport import TTransport
-def serialize(thrift_object, protocol_factory = TBinaryProtocol.TBinaryProtocolFactory()):
+
+def serialize(thrift_object,
+ protocol_factory=TBinaryProtocol.TBinaryProtocolFactory()):
transport = TTransport.TMemoryBuffer()
protocol = protocol_factory.getProtocol(transport)
thrift_object.write(protocol)
return transport.getvalue()
-def deserialize(base, buf, protocol_factory = TBinaryProtocol.TBinaryProtocolFactory()):
+
+def deserialize(base,
+ buf,
+ protocol_factory=TBinaryProtocol.TBinaryProtocolFactory()):
transport = TTransport.TMemoryBuffer(buf)
protocol = protocol_factory.getProtocol(transport)
base.read(protocol)
return base
-
diff --git a/module/lib/thrift/TTornado.py b/module/lib/thrift/TTornado.py
new file mode 100644
index 000000000..af309c3d9
--- /dev/null
+++ b/module/lib/thrift/TTornado.py
@@ -0,0 +1,153 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+
+from cStringIO import StringIO
+import logging
+import socket
+import struct
+
+from thrift.transport import TTransport
+from thrift.transport.TTransport import TTransportException
+
+from tornado import gen
+from tornado import iostream
+from tornado import netutil
+
+
+class TTornadoStreamTransport(TTransport.TTransportBase):
+ """a framed, buffered transport over a Tornado stream"""
+ def __init__(self, host, port, stream=None):
+ self.host = host
+ self.port = port
+ self.is_queuing_reads = False
+ self.read_queue = []
+ self.__wbuf = StringIO()
+
+ # servers provide a ready-to-go stream
+ self.stream = stream
+ if self.stream is not None:
+ self._set_close_callback()
+
+ # not the same number of parameters as TTransportBase.open
+ def open(self, callback):
+ logging.debug('socket connecting')
+ sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM, 0)
+ self.stream = iostream.IOStream(sock)
+
+ def on_close_in_connect(*_):
+ message = 'could not connect to {}:{}'.format(self.host, self.port)
+ raise TTransportException(
+ type=TTransportException.NOT_OPEN,
+ message=message)
+ self.stream.set_close_callback(on_close_in_connect)
+
+ def finish(*_):
+ self._set_close_callback()
+ callback()
+
+ self.stream.connect((self.host, self.port), callback=finish)
+
+ def _set_close_callback(self):
+ def on_close():
+ raise TTransportException(
+ type=TTransportException.END_OF_FILE,
+ message='socket closed')
+ self.stream.set_close_callback(self.close)
+
+ def close(self):
+ # don't raise if we intend to close
+ self.stream.set_close_callback(None)
+ self.stream.close()
+
+ def read(self, _):
+ # The generated code for Tornado shouldn't do individual reads -- only
+ # frames at a time
+ assert "you're doing it wrong" is True
+
+ @gen.engine
+ def readFrame(self, callback):
+ self.read_queue.append(callback)
+ logging.debug('read queue: %s', self.read_queue)
+
+ if self.is_queuing_reads:
+ # If a read is already in flight, then the while loop below should
+ # pull it from self.read_queue
+ return
+
+ self.is_queuing_reads = True
+ while self.read_queue:
+ next_callback = self.read_queue.pop()
+ result = yield gen.Task(self._readFrameFromStream)
+ next_callback(result)
+ self.is_queuing_reads = False
+
+ @gen.engine
+ def _readFrameFromStream(self, callback):
+ logging.debug('_readFrameFromStream')
+ frame_header = yield gen.Task(self.stream.read_bytes, 4)
+ frame_length, = struct.unpack('!i', frame_header)
+ logging.debug('received frame header, frame length = %i', frame_length)
+ frame = yield gen.Task(self.stream.read_bytes, frame_length)
+ logging.debug('received frame payload')
+ callback(frame)
+
+ def write(self, buf):
+ self.__wbuf.write(buf)
+
+ def flush(self, callback=None):
+ wout = self.__wbuf.getvalue()
+ wsz = len(wout)
+ # reset wbuf before write/flush to preserve state on underlying failure
+ self.__wbuf = StringIO()
+ # N.B.: Doing this string concatenation is WAY cheaper than making
+ # two separate calls to the underlying socket object. Socket writes in
+ # Python turn out to be REALLY expensive, but it seems to do a pretty
+ # good job of managing string buffer operations without excessive copies
+ buf = struct.pack("!i", wsz) + wout
+
+ logging.debug('writing frame length = %i', wsz)
+ self.stream.write(buf, callback)
+
+
+class TTornadoServer(netutil.TCPServer):
+ def __init__(self, processor, iprot_factory, oprot_factory=None,
+ *args, **kwargs):
+ super(TTornadoServer, self).__init__(*args, **kwargs)
+
+ self._processor = processor
+ self._iprot_factory = iprot_factory
+ self._oprot_factory = (oprot_factory if oprot_factory is not None
+ else iprot_factory)
+
+ def handle_stream(self, stream, address):
+ try:
+ host, port = address
+ trans = TTornadoStreamTransport(host=host, port=port, stream=stream)
+ oprot = self._oprot_factory.getProtocol(trans)
+
+ def next_pass():
+ if not trans.stream.closed():
+ self._processor.process(trans, self._iprot_factory, oprot,
+ callback=next_pass)
+
+ next_pass()
+
+ except Exception:
+ logging.exception('thrift exception in handle_stream')
+ trans.close()
diff --git a/module/lib/thrift/Thrift.py b/module/lib/thrift/Thrift.py
index 1d271fcff..9890af7e1 100644
--- a/module/lib/thrift/Thrift.py
+++ b/module/lib/thrift/Thrift.py
@@ -19,6 +19,7 @@
import sys
+
class TType:
STOP = 0
VOID = 1
@@ -38,7 +39,7 @@ class TType:
UTF8 = 16
UTF16 = 17
- _VALUES_TO_NAMES = ( 'STOP',
+ _VALUES_TO_NAMES = ('STOP',
'VOID',
'BOOL',
'BYTE',
@@ -48,46 +49,48 @@ class TType:
None,
'I32',
None,
- 'I64',
- 'STRING',
- 'STRUCT',
- 'MAP',
- 'SET',
- 'LIST',
- 'UTF8',
- 'UTF16' )
+ 'I64',
+ 'STRING',
+ 'STRUCT',
+ 'MAP',
+ 'SET',
+ 'LIST',
+ 'UTF8',
+ 'UTF16')
+
class TMessageType:
- CALL = 1
+ CALL = 1
REPLY = 2
EXCEPTION = 3
ONEWAY = 4
-class TProcessor:
+class TProcessor:
"""Base class for procsessor, which works on two streams."""
def process(iprot, oprot):
pass
-class TException(Exception):
+class TException(Exception):
"""Base class for all thrift exceptions."""
# BaseException.message is deprecated in Python v[2.6,3.0)
- if (2,6,0) <= sys.version_info < (3,0):
+ if (2, 6, 0) <= sys.version_info < (3, 0):
def _get_message(self):
- return self._message
+ return self._message
+
def _set_message(self, message):
- self._message = message
+ self._message = message
message = property(_get_message, _set_message)
def __init__(self, message=None):
Exception.__init__(self, message)
self.message = message
-class TApplicationException(TException):
+class TApplicationException(TException):
"""Application level thrift exceptions."""
UNKNOWN = 0
@@ -98,6 +101,9 @@ class TApplicationException(TException):
MISSING_RESULT = 5
INTERNAL_ERROR = 6
PROTOCOL_ERROR = 7
+ INVALID_TRANSFORM = 8
+ INVALID_PROTOCOL = 9
+ UNSUPPORTED_CLIENT_TYPE = 10
def __init__(self, type=UNKNOWN, message=None):
TException.__init__(self, message)
@@ -116,6 +122,16 @@ class TApplicationException(TException):
return 'Bad sequence ID'
elif self.type == self.MISSING_RESULT:
return 'Missing result'
+ elif self.type == self.INTERNAL_ERROR:
+ return 'Internal error'
+ elif self.type == self.PROTOCOL_ERROR:
+ return 'Protocol error'
+ elif self.type == self.INVALID_TRANSFORM:
+ return 'Invalid transform'
+ elif self.type == self.INVALID_PROTOCOL:
+ return 'Invalid protocol'
+ elif self.type == self.UNSUPPORTED_CLIENT_TYPE:
+ return 'Unsupported client type'
else:
return 'Default (unknown) TApplicationException'
@@ -127,12 +143,12 @@ class TApplicationException(TException):
break
if fid == 1:
if ftype == TType.STRING:
- self.message = iprot.readString();
+ self.message = iprot.readString()
else:
iprot.skip(ftype)
elif fid == 2:
if ftype == TType.I32:
- self.type = iprot.readI32();
+ self.type = iprot.readI32()
else:
iprot.skip(ftype)
else:
@@ -142,11 +158,11 @@ class TApplicationException(TException):
def write(self, oprot):
oprot.writeStructBegin('TApplicationException')
- if self.message != None:
+ if self.message is not None:
oprot.writeFieldBegin('message', TType.STRING, 1)
oprot.writeString(self.message)
oprot.writeFieldEnd()
- if self.type != None:
+ if self.type is not None:
oprot.writeFieldBegin('type', TType.I32, 2)
oprot.writeI32(self.type)
oprot.writeFieldEnd()
diff --git a/module/lib/thrift/protocol/TBase.py b/module/lib/thrift/protocol/TBase.py
index e675c7dc0..6cbd5f39a 100644
--- a/module/lib/thrift/protocol/TBase.py
+++ b/module/lib/thrift/protocol/TBase.py
@@ -26,12 +26,13 @@ try:
except:
fastbinary = None
+
class TBase(object):
__slots__ = []
def __repr__(self):
L = ['%s=%r' % (key, getattr(self, key))
- for key in self.__slots__ ]
+ for key in self.__slots__]
return '%s(%s)' % (self.__class__.__name__, ', '.join(L))
def __eq__(self, other):
@@ -43,30 +44,38 @@ class TBase(object):
if my_val != other_val:
return False
return True
-
+
def __ne__(self, other):
return not (self == other)
-
+
def read(self, iprot):
- if iprot.__class__ == TBinaryProtocol.TBinaryProtocolAccelerated and isinstance(iprot.trans, TTransport.CReadableTransport) and self.thrift_spec is not None and fastbinary is not None:
- fastbinary.decode_binary(self, iprot.trans, (self.__class__, self.thrift_spec))
+ if (iprot.__class__ == TBinaryProtocol.TBinaryProtocolAccelerated and
+ isinstance(iprot.trans, TTransport.CReadableTransport) and
+ self.thrift_spec is not None and
+ fastbinary is not None):
+ fastbinary.decode_binary(self,
+ iprot.trans,
+ (self.__class__, self.thrift_spec))
return
iprot.readStruct(self, self.thrift_spec)
def write(self, oprot):
- if oprot.__class__ == TBinaryProtocol.TBinaryProtocolAccelerated and self.thrift_spec is not None and fastbinary is not None:
- oprot.trans.write(fastbinary.encode_binary(self, (self.__class__, self.thrift_spec)))
+ if (oprot.__class__ == TBinaryProtocol.TBinaryProtocolAccelerated and
+ self.thrift_spec is not None and
+ fastbinary is not None):
+ oprot.trans.write(
+ fastbinary.encode_binary(self, (self.__class__, self.thrift_spec)))
return
oprot.writeStruct(self, self.thrift_spec)
+
class TExceptionBase(Exception):
# old style class so python2.4 can raise exceptions derived from this
# This can't inherit from TBase because of that limitation.
__slots__ = []
-
+
__repr__ = TBase.__repr__.im_func
__eq__ = TBase.__eq__.im_func
__ne__ = TBase.__ne__.im_func
read = TBase.read.im_func
write = TBase.write.im_func
-
diff --git a/module/lib/thrift/protocol/TBinaryProtocol.py b/module/lib/thrift/protocol/TBinaryProtocol.py
index 50c6aa896..6fdd08c26 100644
--- a/module/lib/thrift/protocol/TBinaryProtocol.py
+++ b/module/lib/thrift/protocol/TBinaryProtocol.py
@@ -20,8 +20,8 @@
from TProtocol import *
from struct import pack, unpack
-class TBinaryProtocol(TProtocolBase):
+class TBinaryProtocol(TProtocolBase):
"""Binary implementation of the Thrift protocol driver."""
# NastyHaxx. Python 2.4+ on 32-bit machines forces hex constants to be
@@ -68,7 +68,7 @@ class TBinaryProtocol(TProtocolBase):
pass
def writeFieldStop(self):
- self.writeByte(TType.STOP);
+ self.writeByte(TType.STOP)
def writeMapBegin(self, ktype, vtype, size):
self.writeByte(ktype)
@@ -127,13 +127,16 @@ class TBinaryProtocol(TProtocolBase):
if sz < 0:
version = sz & TBinaryProtocol.VERSION_MASK
if version != TBinaryProtocol.VERSION_1:
- raise TProtocolException(type=TProtocolException.BAD_VERSION, message='Bad version in readMessageBegin: %d' % (sz))
+ raise TProtocolException(
+ type=TProtocolException.BAD_VERSION,
+ message='Bad version in readMessageBegin: %d' % (sz))
type = sz & TBinaryProtocol.TYPE_MASK
name = self.readString()
seqid = self.readI32()
else:
if self.strictRead:
- raise TProtocolException(type=TProtocolException.BAD_VERSION, message='No protocol version header')
+ raise TProtocolException(type=TProtocolException.BAD_VERSION,
+ message='No protocol version header')
name = self.trans.readAll(sz)
type = self.readByte()
seqid = self.readI32()
@@ -231,7 +234,6 @@ class TBinaryProtocolFactory:
class TBinaryProtocolAccelerated(TBinaryProtocol):
-
"""C-Accelerated version of TBinaryProtocol.
This class does not override any of TBinaryProtocol's methods,
@@ -250,7 +252,6 @@ class TBinaryProtocolAccelerated(TBinaryProtocol):
Please feel free to report bugs and/or success stories
to the public mailing list.
"""
-
pass
diff --git a/module/lib/thrift/protocol/TCompactProtocol.py b/module/lib/thrift/protocol/TCompactProtocol.py
index 016a33171..cdec60773 100644
--- a/module/lib/thrift/protocol/TCompactProtocol.py
+++ b/module/lib/thrift/protocol/TCompactProtocol.py
@@ -32,6 +32,7 @@ CONTAINER_READ = 6
VALUE_READ = 7
BOOL_READ = 8
+
def make_helper(v_from, container):
def helper(func):
def nested(self, *args, **kwargs):
@@ -42,12 +43,15 @@ def make_helper(v_from, container):
writer = make_helper(VALUE_WRITE, CONTAINER_WRITE)
reader = make_helper(VALUE_READ, CONTAINER_READ)
+
def makeZigZag(n, bits):
return (n << 1) ^ (n >> (bits - 1))
+
def fromZigZag(n):
return (n >> 1) ^ -(n & 1)
+
def writeVarint(trans, n):
out = []
while True:
@@ -59,6 +63,7 @@ def writeVarint(trans, n):
n = n >> 7
trans.write(''.join(map(chr, out)))
+
def readVarint(trans):
result = 0
shift = 0
@@ -70,6 +75,7 @@ def readVarint(trans):
return result
shift += 7
+
class CompactType:
STOP = 0x00
TRUE = 0x01
@@ -86,7 +92,7 @@ class CompactType:
STRUCT = 0x0C
CTYPES = {TType.STOP: CompactType.STOP,
- TType.BOOL: CompactType.TRUE, # used for collection
+ TType.BOOL: CompactType.TRUE, # used for collection
TType.BYTE: CompactType.BYTE,
TType.I16: CompactType.I16,
TType.I32: CompactType.I32,
@@ -106,8 +112,9 @@ TTYPES[CompactType.FALSE] = TType.BOOL
del k
del v
+
class TCompactProtocol(TProtocolBase):
- "Compact implementation of the Thrift protocol driver."
+ """Compact implementation of the Thrift protocol driver."""
PROTOCOL_ID = 0x82
VERSION = 1
@@ -217,18 +224,18 @@ class TCompactProtocol(TProtocolBase):
def writeBool(self, bool):
if self.state == BOOL_WRITE:
- if bool:
- ctype = CompactType.TRUE
- else:
- ctype = CompactType.FALSE
- self.__writeFieldHeader(ctype, self.__bool_fid)
+ if bool:
+ ctype = CompactType.TRUE
+ else:
+ ctype = CompactType.FALSE
+ self.__writeFieldHeader(ctype, self.__bool_fid)
elif self.state == CONTAINER_WRITE:
- if bool:
- self.__writeByte(CompactType.TRUE)
- else:
- self.__writeByte(CompactType.FALSE)
+ if bool:
+ self.__writeByte(CompactType.TRUE)
+ else:
+ self.__writeByte(CompactType.FALSE)
else:
- raise AssertionError, "Invalid state in compact protocol"
+ raise AssertionError("Invalid state in compact protocol")
writeByte = writer(__writeByte)
writeI16 = writer(__writeI16)
@@ -364,7 +371,8 @@ class TCompactProtocol(TProtocolBase):
elif self.state == CONTAINER_READ:
return self.__readByte() == CompactType.TRUE
else:
- raise AssertionError, "Invalid state in compact protocol: %d" % self.state
+ raise AssertionError("Invalid state in compact protocol: %d" %
+ self.state)
readByte = reader(__readByte)
__readI16 = __readZigZag
diff --git a/module/lib/thrift/protocol/TJSONProtocol.py b/module/lib/thrift/protocol/TJSONProtocol.py
new file mode 100644
index 000000000..3048197d4
--- /dev/null
+++ b/module/lib/thrift/protocol/TJSONProtocol.py
@@ -0,0 +1,550 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+
+from TProtocol import TType, TProtocolBase, TProtocolException
+import base64
+import json
+import math
+
+__all__ = ['TJSONProtocol',
+ 'TJSONProtocolFactory',
+ 'TSimpleJSONProtocol',
+ 'TSimpleJSONProtocolFactory']
+
+VERSION = 1
+
+COMMA = ','
+COLON = ':'
+LBRACE = '{'
+RBRACE = '}'
+LBRACKET = '['
+RBRACKET = ']'
+QUOTE = '"'
+BACKSLASH = '\\'
+ZERO = '0'
+
+ESCSEQ = '\\u00'
+ESCAPE_CHAR = '"\\bfnrt'
+ESCAPE_CHAR_VALS = ['"', '\\', '\b', '\f', '\n', '\r', '\t']
+NUMERIC_CHAR = '+-.0123456789Ee'
+
+CTYPES = {TType.BOOL: 'tf',
+ TType.BYTE: 'i8',
+ TType.I16: 'i16',
+ TType.I32: 'i32',
+ TType.I64: 'i64',
+ TType.DOUBLE: 'dbl',
+ TType.STRING: 'str',
+ TType.STRUCT: 'rec',
+ TType.LIST: 'lst',
+ TType.SET: 'set',
+ TType.MAP: 'map'}
+
+JTYPES = {}
+for key in CTYPES.keys():
+ JTYPES[CTYPES[key]] = key
+
+
+class JSONBaseContext(object):
+
+ def __init__(self, protocol):
+ self.protocol = protocol
+ self.first = True
+
+ def doIO(self, function):
+ pass
+
+ def write(self):
+ pass
+
+ def read(self):
+ pass
+
+ def escapeNum(self):
+ return False
+
+ def __str__(self):
+ return self.__class__.__name__
+
+
+class JSONListContext(JSONBaseContext):
+
+ def doIO(self, function):
+ if self.first is True:
+ self.first = False
+ else:
+ function(COMMA)
+
+ def write(self):
+ self.doIO(self.protocol.trans.write)
+
+ def read(self):
+ self.doIO(self.protocol.readJSONSyntaxChar)
+
+
+class JSONPairContext(JSONBaseContext):
+
+ def __init__(self, protocol):
+ super(JSONPairContext, self).__init__(protocol)
+ self.colon = True
+
+ def doIO(self, function):
+ if self.first:
+ self.first = False
+ self.colon = True
+ else:
+ function(COLON if self.colon else COMMA)
+ self.colon = not self.colon
+
+ def write(self):
+ self.doIO(self.protocol.trans.write)
+
+ def read(self):
+ self.doIO(self.protocol.readJSONSyntaxChar)
+
+ def escapeNum(self):
+ return self.colon
+
+ def __str__(self):
+ return '%s, colon=%s' % (self.__class__.__name__, self.colon)
+
+
+class LookaheadReader():
+ hasData = False
+ data = ''
+
+ def __init__(self, protocol):
+ self.protocol = protocol
+
+ def read(self):
+ if self.hasData is True:
+ self.hasData = False
+ else:
+ self.data = self.protocol.trans.read(1)
+ return self.data
+
+ def peek(self):
+ if self.hasData is False:
+ self.data = self.protocol.trans.read(1)
+ self.hasData = True
+ return self.data
+
+class TJSONProtocolBase(TProtocolBase):
+
+ def __init__(self, trans):
+ TProtocolBase.__init__(self, trans)
+ self.resetWriteContext()
+ self.resetReadContext()
+
+ def resetWriteContext(self):
+ self.context = JSONBaseContext(self)
+ self.contextStack = [self.context]
+
+ def resetReadContext(self):
+ self.resetWriteContext()
+ self.reader = LookaheadReader(self)
+
+ def pushContext(self, ctx):
+ self.contextStack.append(ctx)
+ self.context = ctx
+
+ def popContext(self):
+ self.contextStack.pop()
+ if self.contextStack:
+ self.context = self.contextStack[-1]
+ else:
+ self.context = JSONBaseContext(self)
+
+ def writeJSONString(self, string):
+ self.context.write()
+ self.trans.write(json.dumps(string))
+
+ def writeJSONNumber(self, number):
+ self.context.write()
+ jsNumber = str(number)
+ if self.context.escapeNum():
+ jsNumber = "%s%s%s" % (QUOTE, jsNumber, QUOTE)
+ self.trans.write(jsNumber)
+
+ def writeJSONBase64(self, binary):
+ self.context.write()
+ self.trans.write(QUOTE)
+ self.trans.write(base64.b64encode(binary))
+ self.trans.write(QUOTE)
+
+ def writeJSONObjectStart(self):
+ self.context.write()
+ self.trans.write(LBRACE)
+ self.pushContext(JSONPairContext(self))
+
+ def writeJSONObjectEnd(self):
+ self.popContext()
+ self.trans.write(RBRACE)
+
+ def writeJSONArrayStart(self):
+ self.context.write()
+ self.trans.write(LBRACKET)
+ self.pushContext(JSONListContext(self))
+
+ def writeJSONArrayEnd(self):
+ self.popContext()
+ self.trans.write(RBRACKET)
+
+ def readJSONSyntaxChar(self, character):
+ current = self.reader.read()
+ if character != current:
+ raise TProtocolException(TProtocolException.INVALID_DATA,
+ "Unexpected character: %s" % current)
+
+ def readJSONString(self, skipContext):
+ string = []
+ if skipContext is False:
+ self.context.read()
+ self.readJSONSyntaxChar(QUOTE)
+ while True:
+ character = self.reader.read()
+ if character == QUOTE:
+ break
+ if character == ESCSEQ[0]:
+ character = self.reader.read()
+ if character == ESCSEQ[1]:
+ self.readJSONSyntaxChar(ZERO)
+ self.readJSONSyntaxChar(ZERO)
+ character = json.JSONDecoder().decode('"\u00%s"' % self.trans.read(2))
+ else:
+ off = ESCAPE_CHAR.find(character)
+ if off == -1:
+ raise TProtocolException(TProtocolException.INVALID_DATA,
+ "Expected control char")
+ character = ESCAPE_CHAR_VALS[off]
+ string.append(character)
+ return ''.join(string)
+
+ def isJSONNumeric(self, character):
+ return (True if NUMERIC_CHAR.find(character) != - 1 else False)
+
+ def readJSONQuotes(self):
+ if (self.context.escapeNum()):
+ self.readJSONSyntaxChar(QUOTE)
+
+ def readJSONNumericChars(self):
+ numeric = []
+ while True:
+ character = self.reader.peek()
+ if self.isJSONNumeric(character) is False:
+ break
+ numeric.append(self.reader.read())
+ return ''.join(numeric)
+
+ def readJSONInteger(self):
+ self.context.read()
+ self.readJSONQuotes()
+ numeric = self.readJSONNumericChars()
+ self.readJSONQuotes()
+ try:
+ return int(numeric)
+ except ValueError:
+ raise TProtocolException(TProtocolException.INVALID_DATA,
+ "Bad data encounted in numeric data")
+
+ def readJSONDouble(self):
+ self.context.read()
+ if self.reader.peek() == QUOTE:
+ string = self.readJSONString(True)
+ try:
+ double = float(string)
+ if (self.context.escapeNum is False and
+ not math.isinf(double) and
+ not math.isnan(double)):
+ raise TProtocolException(TProtocolException.INVALID_DATA,
+ "Numeric data unexpectedly quoted")
+ return double
+ except ValueError:
+ raise TProtocolException(TProtocolException.INVALID_DATA,
+ "Bad data encounted in numeric data")
+ else:
+ if self.context.escapeNum() is True:
+ self.readJSONSyntaxChar(QUOTE)
+ try:
+ return float(self.readJSONNumericChars())
+ except ValueError:
+ raise TProtocolException(TProtocolException.INVALID_DATA,
+ "Bad data encounted in numeric data")
+
+ def readJSONBase64(self):
+ string = self.readJSONString(False)
+ return base64.b64decode(string)
+
+ def readJSONObjectStart(self):
+ self.context.read()
+ self.readJSONSyntaxChar(LBRACE)
+ self.pushContext(JSONPairContext(self))
+
+ def readJSONObjectEnd(self):
+ self.readJSONSyntaxChar(RBRACE)
+ self.popContext()
+
+ def readJSONArrayStart(self):
+ self.context.read()
+ self.readJSONSyntaxChar(LBRACKET)
+ self.pushContext(JSONListContext(self))
+
+ def readJSONArrayEnd(self):
+ self.readJSONSyntaxChar(RBRACKET)
+ self.popContext()
+
+
+class TJSONProtocol(TJSONProtocolBase):
+
+ def readMessageBegin(self):
+ self.resetReadContext()
+ self.readJSONArrayStart()
+ if self.readJSONInteger() != VERSION:
+ raise TProtocolException(TProtocolException.BAD_VERSION,
+ "Message contained bad version.")
+ name = self.readJSONString(False)
+ typen = self.readJSONInteger()
+ seqid = self.readJSONInteger()
+ return (name, typen, seqid)
+
+ def readMessageEnd(self):
+ self.readJSONArrayEnd()
+
+ def readStructBegin(self):
+ self.readJSONObjectStart()
+
+ def readStructEnd(self):
+ self.readJSONObjectEnd()
+
+ def readFieldBegin(self):
+ character = self.reader.peek()
+ ttype = 0
+ id = 0
+ if character == RBRACE:
+ ttype = TType.STOP
+ else:
+ id = self.readJSONInteger()
+ self.readJSONObjectStart()
+ ttype = JTYPES[self.readJSONString(False)]
+ return (None, ttype, id)
+
+ def readFieldEnd(self):
+ self.readJSONObjectEnd()
+
+ def readMapBegin(self):
+ self.readJSONArrayStart()
+ keyType = JTYPES[self.readJSONString(False)]
+ valueType = JTYPES[self.readJSONString(False)]
+ size = self.readJSONInteger()
+ self.readJSONObjectStart()
+ return (keyType, valueType, size)
+
+ def readMapEnd(self):
+ self.readJSONObjectEnd()
+ self.readJSONArrayEnd()
+
+ def readCollectionBegin(self):
+ self.readJSONArrayStart()
+ elemType = JTYPES[self.readJSONString(False)]
+ size = self.readJSONInteger()
+ return (elemType, size)
+ readListBegin = readCollectionBegin
+ readSetBegin = readCollectionBegin
+
+ def readCollectionEnd(self):
+ self.readJSONArrayEnd()
+ readSetEnd = readCollectionEnd
+ readListEnd = readCollectionEnd
+
+ def readBool(self):
+ return (False if self.readJSONInteger() == 0 else True)
+
+ def readNumber(self):
+ return self.readJSONInteger()
+ readByte = readNumber
+ readI16 = readNumber
+ readI32 = readNumber
+ readI64 = readNumber
+
+ def readDouble(self):
+ return self.readJSONDouble()
+
+ def readString(self):
+ return self.readJSONString(False)
+
+ def readBinary(self):
+ return self.readJSONBase64()
+
+ def writeMessageBegin(self, name, request_type, seqid):
+ self.resetWriteContext()
+ self.writeJSONArrayStart()
+ self.writeJSONNumber(VERSION)
+ self.writeJSONString(name)
+ self.writeJSONNumber(request_type)
+ self.writeJSONNumber(seqid)
+
+ def writeMessageEnd(self):
+ self.writeJSONArrayEnd()
+
+ def writeStructBegin(self, name):
+ self.writeJSONObjectStart()
+
+ def writeStructEnd(self):
+ self.writeJSONObjectEnd()
+
+ def writeFieldBegin(self, name, ttype, id):
+ self.writeJSONNumber(id)
+ self.writeJSONObjectStart()
+ self.writeJSONString(CTYPES[ttype])
+
+ def writeFieldEnd(self):
+ self.writeJSONObjectEnd()
+
+ def writeFieldStop(self):
+ pass
+
+ def writeMapBegin(self, ktype, vtype, size):
+ self.writeJSONArrayStart()
+ self.writeJSONString(CTYPES[ktype])
+ self.writeJSONString(CTYPES[vtype])
+ self.writeJSONNumber(size)
+ self.writeJSONObjectStart()
+
+ def writeMapEnd(self):
+ self.writeJSONObjectEnd()
+ self.writeJSONArrayEnd()
+
+ def writeListBegin(self, etype, size):
+ self.writeJSONArrayStart()
+ self.writeJSONString(CTYPES[etype])
+ self.writeJSONNumber(size)
+
+ def writeListEnd(self):
+ self.writeJSONArrayEnd()
+
+ def writeSetBegin(self, etype, size):
+ self.writeJSONArrayStart()
+ self.writeJSONString(CTYPES[etype])
+ self.writeJSONNumber(size)
+
+ def writeSetEnd(self):
+ self.writeJSONArrayEnd()
+
+ def writeBool(self, boolean):
+ self.writeJSONNumber(1 if boolean is True else 0)
+
+ def writeInteger(self, integer):
+ self.writeJSONNumber(integer)
+ writeByte = writeInteger
+ writeI16 = writeInteger
+ writeI32 = writeInteger
+ writeI64 = writeInteger
+
+ def writeDouble(self, dbl):
+ self.writeJSONNumber(dbl)
+
+ def writeString(self, string):
+ self.writeJSONString(string)
+
+ def writeBinary(self, binary):
+ self.writeJSONBase64(binary)
+
+
+class TJSONProtocolFactory:
+
+ def getProtocol(self, trans):
+ return TJSONProtocol(trans)
+
+
+class TSimpleJSONProtocol(TJSONProtocolBase):
+ """Simple, readable, write-only JSON protocol.
+
+ Useful for interacting with scripting languages.
+ """
+
+ def readMessageBegin(self):
+ raise NotImplementedError()
+
+ def readMessageEnd(self):
+ raise NotImplementedError()
+
+ def readStructBegin(self):
+ raise NotImplementedError()
+
+ def readStructEnd(self):
+ raise NotImplementedError()
+
+ def writeMessageBegin(self, name, request_type, seqid):
+ self.resetWriteContext()
+
+ def writeMessageEnd(self):
+ pass
+
+ def writeStructBegin(self, name):
+ self.writeJSONObjectStart()
+
+ def writeStructEnd(self):
+ self.writeJSONObjectEnd()
+
+ def writeFieldBegin(self, name, ttype, fid):
+ self.writeJSONString(name)
+
+ def writeFieldEnd(self):
+ pass
+
+ def writeMapBegin(self, ktype, vtype, size):
+ self.writeJSONObjectStart()
+
+ def writeMapEnd(self):
+ self.writeJSONObjectEnd()
+
+ def _writeCollectionBegin(self, etype, size):
+ self.writeJSONArrayStart()
+
+ def _writeCollectionEnd(self):
+ self.writeJSONArrayEnd()
+ writeListBegin = _writeCollectionBegin
+ writeListEnd = _writeCollectionEnd
+ writeSetBegin = _writeCollectionBegin
+ writeSetEnd = _writeCollectionEnd
+
+ def writeInteger(self, integer):
+ self.writeJSONNumber(integer)
+ writeByte = writeInteger
+ writeI16 = writeInteger
+ writeI32 = writeInteger
+ writeI64 = writeInteger
+
+ def writeBool(self, boolean):
+ self.writeJSONNumber(1 if boolean is True else 0)
+
+ def writeDouble(self, dbl):
+ self.writeJSONNumber(dbl)
+
+ def writeString(self, string):
+ self.writeJSONString(string)
+
+ def writeBinary(self, binary):
+ self.writeJSONBase64(binary)
+
+
+class TSimpleJSONProtocolFactory(object):
+
+ def getProtocol(self, trans):
+ return TSimpleJSONProtocol(trans)
diff --git a/module/lib/thrift/protocol/TProtocol.py b/module/lib/thrift/protocol/TProtocol.py
index 7338ff68a..dc2b095de 100644
--- a/module/lib/thrift/protocol/TProtocol.py
+++ b/module/lib/thrift/protocol/TProtocol.py
@@ -19,8 +19,8 @@
from thrift.Thrift import *
-class TProtocolException(TException):
+class TProtocolException(TException):
"""Custom Protocol Exception class"""
UNKNOWN = 0
@@ -33,14 +33,14 @@ class TProtocolException(TException):
TException.__init__(self, message)
self.type = type
-class TProtocolBase:
+class TProtocolBase:
"""Base class for Thrift protocol driver."""
def __init__(self, trans):
self.trans = trans
- def writeMessageBegin(self, name, type, seqid):
+ def writeMessageBegin(self, name, ttype, seqid):
pass
def writeMessageEnd(self):
@@ -52,7 +52,7 @@ class TProtocolBase:
def writeStructEnd(self):
pass
- def writeFieldBegin(self, name, type, id):
+ def writeFieldBegin(self, name, ttype, fid):
pass
def writeFieldEnd(self):
@@ -79,7 +79,7 @@ class TProtocolBase:
def writeSetEnd(self):
pass
- def writeBool(self, bool):
+ def writeBool(self, bool_val):
pass
def writeByte(self, byte):
@@ -97,7 +97,7 @@ class TProtocolBase:
def writeDouble(self, dub):
pass
- def writeString(self, str):
+ def writeString(self, str_val):
pass
def readMessageBegin(self):
@@ -157,69 +157,69 @@ class TProtocolBase:
def readString(self):
pass
- def skip(self, type):
- if type == TType.STOP:
+ def skip(self, ttype):
+ if ttype == TType.STOP:
return
- elif type == TType.BOOL:
+ elif ttype == TType.BOOL:
self.readBool()
- elif type == TType.BYTE:
+ elif ttype == TType.BYTE:
self.readByte()
- elif type == TType.I16:
+ elif ttype == TType.I16:
self.readI16()
- elif type == TType.I32:
+ elif ttype == TType.I32:
self.readI32()
- elif type == TType.I64:
+ elif ttype == TType.I64:
self.readI64()
- elif type == TType.DOUBLE:
+ elif ttype == TType.DOUBLE:
self.readDouble()
- elif type == TType.STRING:
+ elif ttype == TType.STRING:
self.readString()
- elif type == TType.STRUCT:
+ elif ttype == TType.STRUCT:
name = self.readStructBegin()
while True:
- (name, type, id) = self.readFieldBegin()
- if type == TType.STOP:
+ (name, ttype, id) = self.readFieldBegin()
+ if ttype == TType.STOP:
break
- self.skip(type)
+ self.skip(ttype)
self.readFieldEnd()
self.readStructEnd()
- elif type == TType.MAP:
+ elif ttype == TType.MAP:
(ktype, vtype, size) = self.readMapBegin()
- for i in range(size):
+ for i in xrange(size):
self.skip(ktype)
self.skip(vtype)
self.readMapEnd()
- elif type == TType.SET:
+ elif ttype == TType.SET:
(etype, size) = self.readSetBegin()
- for i in range(size):
+ for i in xrange(size):
self.skip(etype)
self.readSetEnd()
- elif type == TType.LIST:
+ elif ttype == TType.LIST:
(etype, size) = self.readListBegin()
- for i in range(size):
+ for i in xrange(size):
self.skip(etype)
self.readListEnd()
- # tuple of: ( 'reader method' name, is_container boolean, 'writer_method' name )
+ # tuple of: ( 'reader method' name, is_container bool, 'writer_method' name )
_TTYPE_HANDLERS = (
- (None, None, False), # 0 == TType,STOP
- (None, None, False), # 1 == TType.VOID # TODO: handle void?
- ('readBool', 'writeBool', False), # 2 == TType.BOOL
- ('readByte', 'writeByte', False), # 3 == TType.BYTE and I08
- ('readDouble', 'writeDouble', False), # 4 == TType.DOUBLE
- (None, None, False), # 5, undefined
- ('readI16', 'writeI16', False), # 6 == TType.I16
- (None, None, False), # 7, undefined
- ('readI32', 'writeI32', False), # 8 == TType.I32
- (None, None, False), # 9, undefined
- ('readI64', 'writeI64', False), # 10 == TType.I64
- ('readString', 'writeString', False), # 11 == TType.STRING and UTF7
- ('readContainerStruct', 'writeContainerStruct', True), # 12 == TType.STRUCT
- ('readContainerMap', 'writeContainerMap', True), # 13 == TType.MAP
- ('readContainerSet', 'writeContainerSet', True), # 14 == TType.SET
- ('readContainerList', 'writeContainerList', True), # 15 == TType.LIST
- (None, None, False), # 16 == TType.UTF8 # TODO: handle utf8 types?
- (None, None, False)# 17 == TType.UTF16 # TODO: handle utf16 types?
+ (None, None, False), # 0 TType.STOP
+ (None, None, False), # 1 TType.VOID # TODO: handle void?
+ ('readBool', 'writeBool', False), # 2 TType.BOOL
+ ('readByte', 'writeByte', False), # 3 TType.BYTE and I08
+ ('readDouble', 'writeDouble', False), # 4 TType.DOUBLE
+ (None, None, False), # 5 undefined
+ ('readI16', 'writeI16', False), # 6 TType.I16
+ (None, None, False), # 7 undefined
+ ('readI32', 'writeI32', False), # 8 TType.I32
+ (None, None, False), # 9 undefined
+ ('readI64', 'writeI64', False), # 10 TType.I64
+ ('readString', 'writeString', False), # 11 TType.STRING and UTF7
+ ('readContainerStruct', 'writeContainerStruct', True), # 12 *.STRUCT
+ ('readContainerMap', 'writeContainerMap', True), # 13 TType.MAP
+ ('readContainerSet', 'writeContainerSet', True), # 14 TType.SET
+ ('readContainerList', 'writeContainerList', True), # 15 TType.LIST
+ (None, None, False), # 16 TType.UTF8 # TODO: handle utf8 types?
+ (None, None, False) # 17 TType.UTF16 # TODO: handle utf16 types?
)
def readFieldByTType(self, ttype, spec):
@@ -270,7 +270,7 @@ class TProtocolBase:
container_reader = self._TTYPE_HANDLERS[set_type][0]
val_reader = getattr(self, container_reader)
for idx in xrange(set_len):
- results.add(val_reader(tspec))
+ results.add(val_reader(tspec))
self.readSetEnd()
return results
@@ -279,13 +279,14 @@ class TProtocolBase:
obj = obj_class()
obj.read(self)
return obj
-
+
def readContainerMap(self, spec):
results = dict()
key_ttype, key_spec = spec[0], spec[1]
val_ttype, val_spec = spec[2], spec[3]
(map_ktype, map_vtype, map_len) = self.readMapBegin()
- # TODO: compare types we just decoded with thrift_spec and abort/skip if types disagree
+ # TODO: compare types we just decoded with thrift_spec and
+ # abort/skip if types disagree
key_reader = getattr(self, self._TTYPE_HANDLERS[key_ttype][0])
val_reader = getattr(self, self._TTYPE_HANDLERS[val_ttype][0])
# list values are simple types
@@ -298,7 +299,8 @@ class TProtocolBase:
v_val = val_reader()
else:
v_val = self.readFieldByTType(val_ttype, val_spec)
- # this raises a TypeError with unhashable keys types. i.e. d=dict(); d[[0,1]] = 2 fails
+ # this raises a TypeError with unhashable keys types
+ # i.e. this fails: d=dict(); d[[0,1]] = 2
results[k_val] = v_val
self.readMapEnd()
return results
@@ -329,7 +331,7 @@ class TProtocolBase:
def writeContainerList(self, val, spec):
self.writeListBegin(spec[0], len(val))
- r_handler, w_handler, is_container = self._TTYPE_HANDLERS[spec[0]]
+ r_handler, w_handler, is_container = self._TTYPE_HANDLERS[spec[0]]
e_writer = getattr(self, w_handler)
if not is_container:
for elem in val:
@@ -398,7 +400,7 @@ class TProtocolBase:
else:
writer(val)
+
class TProtocolFactory:
def getProtocol(self, trans):
pass
-
diff --git a/module/lib/thrift/protocol/__init__.py b/module/lib/thrift/protocol/__init__.py
index d53359b28..7eefb458a 100644
--- a/module/lib/thrift/protocol/__init__.py
+++ b/module/lib/thrift/protocol/__init__.py
@@ -17,4 +17,4 @@
# under the License.
#
-__all__ = ['TProtocol', 'TBinaryProtocol', 'fastbinary', 'TBase']
+__all__ = ['fastbinary', 'TBase', 'TBinaryProtocol', 'TCompactProtocol', 'TJSONProtocol', 'TProtocol']
diff --git a/module/lib/thrift/protocol/fastbinary.c b/module/lib/thrift/protocol/fastbinary.c
new file mode 100644
index 000000000..2ce56603c
--- /dev/null
+++ b/module/lib/thrift/protocol/fastbinary.c
@@ -0,0 +1,1219 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#include <Python.h>
+#include "cStringIO.h"
+#include <stdint.h>
+#ifndef _WIN32
+# include <stdbool.h>
+# include <netinet/in.h>
+#else
+# include <WinSock2.h>
+# pragma comment (lib, "ws2_32.lib")
+# define BIG_ENDIAN (4321)
+# define LITTLE_ENDIAN (1234)
+# define BYTE_ORDER LITTLE_ENDIAN
+# if defined(_MSC_VER) && _MSC_VER < 1600
+ typedef int _Bool;
+# define bool _Bool
+# define false 0
+# define true 1
+# endif
+# define inline __inline
+#endif
+
+/* Fix endianness issues on Solaris */
+#if defined (__SVR4) && defined (__sun)
+ #if defined(__i386) && !defined(__i386__)
+ #define __i386__
+ #endif
+
+ #ifndef BIG_ENDIAN
+ #define BIG_ENDIAN (4321)
+ #endif
+ #ifndef LITTLE_ENDIAN
+ #define LITTLE_ENDIAN (1234)
+ #endif
+
+ /* I386 is LE, even on Solaris */
+ #if !defined(BYTE_ORDER) && defined(__i386__)
+ #define BYTE_ORDER LITTLE_ENDIAN
+ #endif
+#endif
+
+// TODO(dreiss): defval appears to be unused. Look into removing it.
+// TODO(dreiss): Make parse_spec_args recursive, and cache the output
+// permanently in the object. (Malloc and orphan.)
+// TODO(dreiss): Why do we need cStringIO for reading, why not just char*?
+// Can cStringIO let us work with a BufferedTransport?
+// TODO(dreiss): Don't ignore the rv from cwrite (maybe).
+
+/* ====== BEGIN UTILITIES ====== */
+
+#define INIT_OUTBUF_SIZE 128
+
+// Stolen out of TProtocol.h.
+// It would be a huge pain to have both get this from one place.
+typedef enum TType {
+ T_STOP = 0,
+ T_VOID = 1,
+ T_BOOL = 2,
+ T_BYTE = 3,
+ T_I08 = 3,
+ T_I16 = 6,
+ T_I32 = 8,
+ T_U64 = 9,
+ T_I64 = 10,
+ T_DOUBLE = 4,
+ T_STRING = 11,
+ T_UTF7 = 11,
+ T_STRUCT = 12,
+ T_MAP = 13,
+ T_SET = 14,
+ T_LIST = 15,
+ T_UTF8 = 16,
+ T_UTF16 = 17
+} TType;
+
+#ifndef __BYTE_ORDER
+# if defined(BYTE_ORDER) && defined(LITTLE_ENDIAN) && defined(BIG_ENDIAN)
+# define __BYTE_ORDER BYTE_ORDER
+# define __LITTLE_ENDIAN LITTLE_ENDIAN
+# define __BIG_ENDIAN BIG_ENDIAN
+# else
+# error "Cannot determine endianness"
+# endif
+#endif
+
+// Same comment as the enum. Sorry.
+#if __BYTE_ORDER == __BIG_ENDIAN
+# define ntohll(n) (n)
+# define htonll(n) (n)
+#elif __BYTE_ORDER == __LITTLE_ENDIAN
+# if defined(__GNUC__) && defined(__GLIBC__)
+# include <byteswap.h>
+# define ntohll(n) bswap_64(n)
+# define htonll(n) bswap_64(n)
+# else /* GNUC & GLIBC */
+# define ntohll(n) ( (((unsigned long long)ntohl(n)) << 32) + ntohl(n >> 32) )
+# define htonll(n) ( (((unsigned long long)htonl(n)) << 32) + htonl(n >> 32) )
+# endif /* GNUC & GLIBC */
+#else /* __BYTE_ORDER */
+# error "Can't define htonll or ntohll!"
+#endif
+
+// Doing a benchmark shows that interning actually makes a difference, amazingly.
+#define INTERN_STRING(value) _intern_ ## value
+
+#define INT_CONV_ERROR_OCCURRED(v) ( ((v) == -1) && PyErr_Occurred() )
+#define CHECK_RANGE(v, min, max) ( ((v) <= (max)) && ((v) >= (min)) )
+
+// Py_ssize_t was not defined before Python 2.5
+#if (PY_VERSION_HEX < 0x02050000)
+typedef int Py_ssize_t;
+#endif
+
+/**
+ * A cache of the spec_args for a set or list,
+ * so we don't have to keep calling PyTuple_GET_ITEM.
+ */
+typedef struct {
+ TType element_type;
+ PyObject* typeargs;
+} SetListTypeArgs;
+
+/**
+ * A cache of the spec_args for a map,
+ * so we don't have to keep calling PyTuple_GET_ITEM.
+ */
+typedef struct {
+ TType ktag;
+ TType vtag;
+ PyObject* ktypeargs;
+ PyObject* vtypeargs;
+} MapTypeArgs;
+
+/**
+ * A cache of the spec_args for a struct,
+ * so we don't have to keep calling PyTuple_GET_ITEM.
+ */
+typedef struct {
+ PyObject* klass;
+ PyObject* spec;
+} StructTypeArgs;
+
+/**
+ * A cache of the item spec from a struct specification,
+ * so we don't have to keep calling PyTuple_GET_ITEM.
+ */
+typedef struct {
+ int tag;
+ TType type;
+ PyObject* attrname;
+ PyObject* typeargs;
+ PyObject* defval;
+} StructItemSpec;
+
+/**
+ * A cache of the two key attributes of a CReadableTransport,
+ * so we don't have to keep calling PyObject_GetAttr.
+ */
+typedef struct {
+ PyObject* stringiobuf;
+ PyObject* refill_callable;
+} DecodeBuffer;
+
+/** Pointer to interned string to speed up attribute lookup. */
+static PyObject* INTERN_STRING(cstringio_buf);
+/** Pointer to interned string to speed up attribute lookup. */
+static PyObject* INTERN_STRING(cstringio_refill);
+
+static inline bool
+check_ssize_t_32(Py_ssize_t len) {
+ // error from getting the int
+ if (INT_CONV_ERROR_OCCURRED(len)) {
+ return false;
+ }
+ if (!CHECK_RANGE(len, 0, INT32_MAX)) {
+ PyErr_SetString(PyExc_OverflowError, "string size out of range");
+ return false;
+ }
+ return true;
+}
+
+static inline bool
+parse_pyint(PyObject* o, int32_t* ret, int32_t min, int32_t max) {
+ long val = PyInt_AsLong(o);
+
+ if (INT_CONV_ERROR_OCCURRED(val)) {
+ return false;
+ }
+ if (!CHECK_RANGE(val, min, max)) {
+ PyErr_SetString(PyExc_OverflowError, "int out of range");
+ return false;
+ }
+
+ *ret = (int32_t) val;
+ return true;
+}
+
+
+/* --- FUNCTIONS TO PARSE STRUCT SPECIFICATOINS --- */
+
+static bool
+parse_set_list_args(SetListTypeArgs* dest, PyObject* typeargs) {
+ if (PyTuple_Size(typeargs) != 2) {
+ PyErr_SetString(PyExc_TypeError, "expecting tuple of size 2 for list/set type args");
+ return false;
+ }
+
+ dest->element_type = PyInt_AsLong(PyTuple_GET_ITEM(typeargs, 0));
+ if (INT_CONV_ERROR_OCCURRED(dest->element_type)) {
+ return false;
+ }
+
+ dest->typeargs = PyTuple_GET_ITEM(typeargs, 1);
+
+ return true;
+}
+
+static bool
+parse_map_args(MapTypeArgs* dest, PyObject* typeargs) {
+ if (PyTuple_Size(typeargs) != 4) {
+ PyErr_SetString(PyExc_TypeError, "expecting 4 arguments for typeargs to map");
+ return false;
+ }
+
+ dest->ktag = PyInt_AsLong(PyTuple_GET_ITEM(typeargs, 0));
+ if (INT_CONV_ERROR_OCCURRED(dest->ktag)) {
+ return false;
+ }
+
+ dest->vtag = PyInt_AsLong(PyTuple_GET_ITEM(typeargs, 2));
+ if (INT_CONV_ERROR_OCCURRED(dest->vtag)) {
+ return false;
+ }
+
+ dest->ktypeargs = PyTuple_GET_ITEM(typeargs, 1);
+ dest->vtypeargs = PyTuple_GET_ITEM(typeargs, 3);
+
+ return true;
+}
+
+static bool
+parse_struct_args(StructTypeArgs* dest, PyObject* typeargs) {
+ if (PyTuple_Size(typeargs) != 2) {
+ PyErr_SetString(PyExc_TypeError, "expecting tuple of size 2 for struct args");
+ return false;
+ }
+
+ dest->klass = PyTuple_GET_ITEM(typeargs, 0);
+ dest->spec = PyTuple_GET_ITEM(typeargs, 1);
+
+ return true;
+}
+
+static int
+parse_struct_item_spec(StructItemSpec* dest, PyObject* spec_tuple) {
+
+ // i'd like to use ParseArgs here, but it seems to be a bottleneck.
+ if (PyTuple_Size(spec_tuple) != 5) {
+ PyErr_SetString(PyExc_TypeError, "expecting 5 arguments for spec tuple");
+ return false;
+ }
+
+ dest->tag = PyInt_AsLong(PyTuple_GET_ITEM(spec_tuple, 0));
+ if (INT_CONV_ERROR_OCCURRED(dest->tag)) {
+ return false;
+ }
+
+ dest->type = PyInt_AsLong(PyTuple_GET_ITEM(spec_tuple, 1));
+ if (INT_CONV_ERROR_OCCURRED(dest->type)) {
+ return false;
+ }
+
+ dest->attrname = PyTuple_GET_ITEM(spec_tuple, 2);
+ dest->typeargs = PyTuple_GET_ITEM(spec_tuple, 3);
+ dest->defval = PyTuple_GET_ITEM(spec_tuple, 4);
+ return true;
+}
+
+/* ====== END UTILITIES ====== */
+
+
+/* ====== BEGIN WRITING FUNCTIONS ====== */
+
+/* --- LOW-LEVEL WRITING FUNCTIONS --- */
+
+static void writeByte(PyObject* outbuf, int8_t val) {
+ int8_t net = val;
+ PycStringIO->cwrite(outbuf, (char*)&net, sizeof(int8_t));
+}
+
+static void writeI16(PyObject* outbuf, int16_t val) {
+ int16_t net = (int16_t)htons(val);
+ PycStringIO->cwrite(outbuf, (char*)&net, sizeof(int16_t));
+}
+
+static void writeI32(PyObject* outbuf, int32_t val) {
+ int32_t net = (int32_t)htonl(val);
+ PycStringIO->cwrite(outbuf, (char*)&net, sizeof(int32_t));
+}
+
+static void writeI64(PyObject* outbuf, int64_t val) {
+ int64_t net = (int64_t)htonll(val);
+ PycStringIO->cwrite(outbuf, (char*)&net, sizeof(int64_t));
+}
+
+static void writeDouble(PyObject* outbuf, double dub) {
+ // Unfortunately, bitwise_cast doesn't work in C. Bad C!
+ union {
+ double f;
+ int64_t t;
+ } transfer;
+ transfer.f = dub;
+ writeI64(outbuf, transfer.t);
+}
+
+
+/* --- MAIN RECURSIVE OUTPUT FUCNTION -- */
+
+static int
+output_val(PyObject* output, PyObject* value, TType type, PyObject* typeargs) {
+ /*
+ * Refcounting Strategy:
+ *
+ * We assume that elements of the thrift_spec tuple are not going to be
+ * mutated, so we don't ref count those at all. Other than that, we try to
+ * keep a reference to all the user-created objects while we work with them.
+ * output_val assumes that a reference is already held. The *caller* is
+ * responsible for handling references
+ */
+
+ switch (type) {
+
+ case T_BOOL: {
+ int v = PyObject_IsTrue(value);
+ if (v == -1) {
+ return false;
+ }
+
+ writeByte(output, (int8_t) v);
+ break;
+ }
+ case T_I08: {
+ int32_t val;
+
+ if (!parse_pyint(value, &val, INT8_MIN, INT8_MAX)) {
+ return false;
+ }
+
+ writeByte(output, (int8_t) val);
+ break;
+ }
+ case T_I16: {
+ int32_t val;
+
+ if (!parse_pyint(value, &val, INT16_MIN, INT16_MAX)) {
+ return false;
+ }
+
+ writeI16(output, (int16_t) val);
+ break;
+ }
+ case T_I32: {
+ int32_t val;
+
+ if (!parse_pyint(value, &val, INT32_MIN, INT32_MAX)) {
+ return false;
+ }
+
+ writeI32(output, val);
+ break;
+ }
+ case T_I64: {
+ int64_t nval = PyLong_AsLongLong(value);
+
+ if (INT_CONV_ERROR_OCCURRED(nval)) {
+ return false;
+ }
+
+ if (!CHECK_RANGE(nval, INT64_MIN, INT64_MAX)) {
+ PyErr_SetString(PyExc_OverflowError, "int out of range");
+ return false;
+ }
+
+ writeI64(output, nval);
+ break;
+ }
+
+ case T_DOUBLE: {
+ double nval = PyFloat_AsDouble(value);
+ if (nval == -1.0 && PyErr_Occurred()) {
+ return false;
+ }
+
+ writeDouble(output, nval);
+ break;
+ }
+
+ case T_STRING: {
+ Py_ssize_t len = PyString_Size(value);
+
+ if (!check_ssize_t_32(len)) {
+ return false;
+ }
+
+ writeI32(output, (int32_t) len);
+ PycStringIO->cwrite(output, PyString_AsString(value), (int32_t) len);
+ break;
+ }
+
+ case T_LIST:
+ case T_SET: {
+ Py_ssize_t len;
+ SetListTypeArgs parsedargs;
+ PyObject *item;
+ PyObject *iterator;
+
+ if (!parse_set_list_args(&parsedargs, typeargs)) {
+ return false;
+ }
+
+ len = PyObject_Length(value);
+
+ if (!check_ssize_t_32(len)) {
+ return false;
+ }
+
+ writeByte(output, parsedargs.element_type);
+ writeI32(output, (int32_t) len);
+
+ iterator = PyObject_GetIter(value);
+ if (iterator == NULL) {
+ return false;
+ }
+
+ while ((item = PyIter_Next(iterator))) {
+ if (!output_val(output, item, parsedargs.element_type, parsedargs.typeargs)) {
+ Py_DECREF(item);
+ Py_DECREF(iterator);
+ return false;
+ }
+ Py_DECREF(item);
+ }
+
+ Py_DECREF(iterator);
+
+ if (PyErr_Occurred()) {
+ return false;
+ }
+
+ break;
+ }
+
+ case T_MAP: {
+ PyObject *k, *v;
+ Py_ssize_t pos = 0;
+ Py_ssize_t len;
+
+ MapTypeArgs parsedargs;
+
+ len = PyDict_Size(value);
+ if (!check_ssize_t_32(len)) {
+ return false;
+ }
+
+ if (!parse_map_args(&parsedargs, typeargs)) {
+ return false;
+ }
+
+ writeByte(output, parsedargs.ktag);
+ writeByte(output, parsedargs.vtag);
+ writeI32(output, len);
+
+ // TODO(bmaurer): should support any mapping, not just dicts
+ while (PyDict_Next(value, &pos, &k, &v)) {
+ // TODO(dreiss): Think hard about whether these INCREFs actually
+ // turn any unsafe scenarios into safe scenarios.
+ Py_INCREF(k);
+ Py_INCREF(v);
+
+ if (!output_val(output, k, parsedargs.ktag, parsedargs.ktypeargs)
+ || !output_val(output, v, parsedargs.vtag, parsedargs.vtypeargs)) {
+ Py_DECREF(k);
+ Py_DECREF(v);
+ return false;
+ }
+ Py_DECREF(k);
+ Py_DECREF(v);
+ }
+ break;
+ }
+
+ // TODO(dreiss): Consider breaking this out as a function
+ // the way we did for decode_struct.
+ case T_STRUCT: {
+ StructTypeArgs parsedargs;
+ Py_ssize_t nspec;
+ Py_ssize_t i;
+
+ if (!parse_struct_args(&parsedargs, typeargs)) {
+ return false;
+ }
+
+ nspec = PyTuple_Size(parsedargs.spec);
+
+ if (nspec == -1) {
+ return false;
+ }
+
+ for (i = 0; i < nspec; i++) {
+ StructItemSpec parsedspec;
+ PyObject* spec_tuple;
+ PyObject* instval = NULL;
+
+ spec_tuple = PyTuple_GET_ITEM(parsedargs.spec, i);
+ if (spec_tuple == Py_None) {
+ continue;
+ }
+
+ if (!parse_struct_item_spec (&parsedspec, spec_tuple)) {
+ return false;
+ }
+
+ instval = PyObject_GetAttr(value, parsedspec.attrname);
+
+ if (!instval) {
+ return false;
+ }
+
+ if (instval == Py_None) {
+ Py_DECREF(instval);
+ continue;
+ }
+
+ writeByte(output, (int8_t) parsedspec.type);
+ writeI16(output, parsedspec.tag);
+
+ if (!output_val(output, instval, parsedspec.type, parsedspec.typeargs)) {
+ Py_DECREF(instval);
+ return false;
+ }
+
+ Py_DECREF(instval);
+ }
+
+ writeByte(output, (int8_t)T_STOP);
+ break;
+ }
+
+ case T_STOP:
+ case T_VOID:
+ case T_UTF16:
+ case T_UTF8:
+ case T_U64:
+ default:
+ PyErr_SetString(PyExc_TypeError, "Unexpected TType");
+ return false;
+
+ }
+
+ return true;
+}
+
+
+/* --- TOP-LEVEL WRAPPER FOR OUTPUT -- */
+
+static PyObject *
+encode_binary(PyObject *self, PyObject *args) {
+ PyObject* enc_obj;
+ PyObject* type_args;
+ PyObject* buf;
+ PyObject* ret = NULL;
+
+ if (!PyArg_ParseTuple(args, "OO", &enc_obj, &type_args)) {
+ return NULL;
+ }
+
+ buf = PycStringIO->NewOutput(INIT_OUTBUF_SIZE);
+ if (output_val(buf, enc_obj, T_STRUCT, type_args)) {
+ ret = PycStringIO->cgetvalue(buf);
+ }
+
+ Py_DECREF(buf);
+ return ret;
+}
+
+/* ====== END WRITING FUNCTIONS ====== */
+
+
+/* ====== BEGIN READING FUNCTIONS ====== */
+
+/* --- LOW-LEVEL READING FUNCTIONS --- */
+
+static void
+free_decodebuf(DecodeBuffer* d) {
+ Py_XDECREF(d->stringiobuf);
+ Py_XDECREF(d->refill_callable);
+}
+
+static bool
+decode_buffer_from_obj(DecodeBuffer* dest, PyObject* obj) {
+ dest->stringiobuf = PyObject_GetAttr(obj, INTERN_STRING(cstringio_buf));
+ if (!dest->stringiobuf) {
+ return false;
+ }
+
+ if (!PycStringIO_InputCheck(dest->stringiobuf)) {
+ free_decodebuf(dest);
+ PyErr_SetString(PyExc_TypeError, "expecting stringio input");
+ return false;
+ }
+
+ dest->refill_callable = PyObject_GetAttr(obj, INTERN_STRING(cstringio_refill));
+
+ if(!dest->refill_callable) {
+ free_decodebuf(dest);
+ return false;
+ }
+
+ if (!PyCallable_Check(dest->refill_callable)) {
+ free_decodebuf(dest);
+ PyErr_SetString(PyExc_TypeError, "expecting callable");
+ return false;
+ }
+
+ return true;
+}
+
+static bool readBytes(DecodeBuffer* input, char** output, int len) {
+ int read;
+
+ // TODO(dreiss): Don't fear the malloc. Think about taking a copy of
+ // the partial read instead of forcing the transport
+ // to prepend it to its buffer.
+
+ read = PycStringIO->cread(input->stringiobuf, output, len);
+
+ if (read == len) {
+ return true;
+ } else if (read == -1) {
+ return false;
+ } else {
+ PyObject* newiobuf;
+
+ // using building functions as this is a rare codepath
+ newiobuf = PyObject_CallFunction(
+ input->refill_callable, "s#i", *output, read, len, NULL);
+ if (newiobuf == NULL) {
+ return false;
+ }
+
+ // must do this *AFTER* the call so that we don't deref the io buffer
+ Py_CLEAR(input->stringiobuf);
+ input->stringiobuf = newiobuf;
+
+ read = PycStringIO->cread(input->stringiobuf, output, len);
+
+ if (read == len) {
+ return true;
+ } else if (read == -1) {
+ return false;
+ } else {
+ // TODO(dreiss): This could be a valid code path for big binary blobs.
+ PyErr_SetString(PyExc_TypeError,
+ "refill claimed to have refilled the buffer, but didn't!!");
+ return false;
+ }
+ }
+}
+
+static int8_t readByte(DecodeBuffer* input) {
+ char* buf;
+ if (!readBytes(input, &buf, sizeof(int8_t))) {
+ return -1;
+ }
+
+ return *(int8_t*) buf;
+}
+
+static int16_t readI16(DecodeBuffer* input) {
+ char* buf;
+ if (!readBytes(input, &buf, sizeof(int16_t))) {
+ return -1;
+ }
+
+ return (int16_t) ntohs(*(int16_t*) buf);
+}
+
+static int32_t readI32(DecodeBuffer* input) {
+ char* buf;
+ if (!readBytes(input, &buf, sizeof(int32_t))) {
+ return -1;
+ }
+ return (int32_t) ntohl(*(int32_t*) buf);
+}
+
+
+static int64_t readI64(DecodeBuffer* input) {
+ char* buf;
+ if (!readBytes(input, &buf, sizeof(int64_t))) {
+ return -1;
+ }
+
+ return (int64_t) ntohll(*(int64_t*) buf);
+}
+
+static double readDouble(DecodeBuffer* input) {
+ union {
+ int64_t f;
+ double t;
+ } transfer;
+
+ transfer.f = readI64(input);
+ if (transfer.f == -1) {
+ return -1;
+ }
+ return transfer.t;
+}
+
+static bool
+checkTypeByte(DecodeBuffer* input, TType expected) {
+ TType got = readByte(input);
+ if (INT_CONV_ERROR_OCCURRED(got)) {
+ return false;
+ }
+
+ if (expected != got) {
+ PyErr_SetString(PyExc_TypeError, "got wrong ttype while reading field");
+ return false;
+ }
+ return true;
+}
+
+static bool
+skip(DecodeBuffer* input, TType type) {
+#define SKIPBYTES(n) \
+ do { \
+ if (!readBytes(input, &dummy_buf, (n))) { \
+ return false; \
+ } \
+ } while(0)
+
+ char* dummy_buf;
+
+ switch (type) {
+
+ case T_BOOL:
+ case T_I08: SKIPBYTES(1); break;
+ case T_I16: SKIPBYTES(2); break;
+ case T_I32: SKIPBYTES(4); break;
+ case T_I64:
+ case T_DOUBLE: SKIPBYTES(8); break;
+
+ case T_STRING: {
+ // TODO(dreiss): Find out if these check_ssize_t32s are really necessary.
+ int len = readI32(input);
+ if (!check_ssize_t_32(len)) {
+ return false;
+ }
+ SKIPBYTES(len);
+ break;
+ }
+
+ case T_LIST:
+ case T_SET: {
+ TType etype;
+ int len, i;
+
+ etype = readByte(input);
+ if (etype == -1) {
+ return false;
+ }
+
+ len = readI32(input);
+ if (!check_ssize_t_32(len)) {
+ return false;
+ }
+
+ for (i = 0; i < len; i++) {
+ if (!skip(input, etype)) {
+ return false;
+ }
+ }
+ break;
+ }
+
+ case T_MAP: {
+ TType ktype, vtype;
+ int len, i;
+
+ ktype = readByte(input);
+ if (ktype == -1) {
+ return false;
+ }
+
+ vtype = readByte(input);
+ if (vtype == -1) {
+ return false;
+ }
+
+ len = readI32(input);
+ if (!check_ssize_t_32(len)) {
+ return false;
+ }
+
+ for (i = 0; i < len; i++) {
+ if (!(skip(input, ktype) && skip(input, vtype))) {
+ return false;
+ }
+ }
+ break;
+ }
+
+ case T_STRUCT: {
+ while (true) {
+ TType type;
+
+ type = readByte(input);
+ if (type == -1) {
+ return false;
+ }
+
+ if (type == T_STOP)
+ break;
+
+ SKIPBYTES(2); // tag
+ if (!skip(input, type)) {
+ return false;
+ }
+ }
+ break;
+ }
+
+ case T_STOP:
+ case T_VOID:
+ case T_UTF16:
+ case T_UTF8:
+ case T_U64:
+ default:
+ PyErr_SetString(PyExc_TypeError, "Unexpected TType");
+ return false;
+
+ }
+
+ return true;
+
+#undef SKIPBYTES
+}
+
+
+/* --- HELPER FUNCTION FOR DECODE_VAL --- */
+
+static PyObject*
+decode_val(DecodeBuffer* input, TType type, PyObject* typeargs);
+
+static bool
+decode_struct(DecodeBuffer* input, PyObject* output, PyObject* spec_seq) {
+ int spec_seq_len = PyTuple_Size(spec_seq);
+ if (spec_seq_len == -1) {
+ return false;
+ }
+
+ while (true) {
+ TType type;
+ int16_t tag;
+ PyObject* item_spec;
+ PyObject* fieldval = NULL;
+ StructItemSpec parsedspec;
+
+ type = readByte(input);
+ if (type == -1) {
+ return false;
+ }
+ if (type == T_STOP) {
+ break;
+ }
+ tag = readI16(input);
+ if (INT_CONV_ERROR_OCCURRED(tag)) {
+ return false;
+ }
+ if (tag >= 0 && tag < spec_seq_len) {
+ item_spec = PyTuple_GET_ITEM(spec_seq, tag);
+ } else {
+ item_spec = Py_None;
+ }
+
+ if (item_spec == Py_None) {
+ if (!skip(input, type)) {
+ return false;
+ } else {
+ continue;
+ }
+ }
+
+ if (!parse_struct_item_spec(&parsedspec, item_spec)) {
+ return false;
+ }
+ if (parsedspec.type != type) {
+ if (!skip(input, type)) {
+ PyErr_SetString(PyExc_TypeError, "struct field had wrong type while reading and can't be skipped");
+ return false;
+ } else {
+ continue;
+ }
+ }
+
+ fieldval = decode_val(input, parsedspec.type, parsedspec.typeargs);
+ if (fieldval == NULL) {
+ return false;
+ }
+
+ if (PyObject_SetAttr(output, parsedspec.attrname, fieldval) == -1) {
+ Py_DECREF(fieldval);
+ return false;
+ }
+ Py_DECREF(fieldval);
+ }
+ return true;
+}
+
+
+/* --- MAIN RECURSIVE INPUT FUCNTION --- */
+
+// Returns a new reference.
+static PyObject*
+decode_val(DecodeBuffer* input, TType type, PyObject* typeargs) {
+ switch (type) {
+
+ case T_BOOL: {
+ int8_t v = readByte(input);
+ if (INT_CONV_ERROR_OCCURRED(v)) {
+ return NULL;
+ }
+
+ switch (v) {
+ case 0: Py_RETURN_FALSE;
+ case 1: Py_RETURN_TRUE;
+ // Don't laugh. This is a potentially serious issue.
+ default: PyErr_SetString(PyExc_TypeError, "boolean out of range"); return NULL;
+ }
+ break;
+ }
+ case T_I08: {
+ int8_t v = readByte(input);
+ if (INT_CONV_ERROR_OCCURRED(v)) {
+ return NULL;
+ }
+
+ return PyInt_FromLong(v);
+ }
+ case T_I16: {
+ int16_t v = readI16(input);
+ if (INT_CONV_ERROR_OCCURRED(v)) {
+ return NULL;
+ }
+ return PyInt_FromLong(v);
+ }
+ case T_I32: {
+ int32_t v = readI32(input);
+ if (INT_CONV_ERROR_OCCURRED(v)) {
+ return NULL;
+ }
+ return PyInt_FromLong(v);
+ }
+
+ case T_I64: {
+ int64_t v = readI64(input);
+ if (INT_CONV_ERROR_OCCURRED(v)) {
+ return NULL;
+ }
+ // TODO(dreiss): Find out if we can take this fastpath always when
+ // sizeof(long) == sizeof(long long).
+ if (CHECK_RANGE(v, LONG_MIN, LONG_MAX)) {
+ return PyInt_FromLong((long) v);
+ }
+
+ return PyLong_FromLongLong(v);
+ }
+
+ case T_DOUBLE: {
+ double v = readDouble(input);
+ if (v == -1.0 && PyErr_Occurred()) {
+ return false;
+ }
+ return PyFloat_FromDouble(v);
+ }
+
+ case T_STRING: {
+ Py_ssize_t len = readI32(input);
+ char* buf;
+ if (!readBytes(input, &buf, len)) {
+ return NULL;
+ }
+
+ return PyString_FromStringAndSize(buf, len);
+ }
+
+ case T_LIST:
+ case T_SET: {
+ SetListTypeArgs parsedargs;
+ int32_t len;
+ PyObject* ret = NULL;
+ int i;
+
+ if (!parse_set_list_args(&parsedargs, typeargs)) {
+ return NULL;
+ }
+
+ if (!checkTypeByte(input, parsedargs.element_type)) {
+ return NULL;
+ }
+
+ len = readI32(input);
+ if (!check_ssize_t_32(len)) {
+ return NULL;
+ }
+
+ ret = PyList_New(len);
+ if (!ret) {
+ return NULL;
+ }
+
+ for (i = 0; i < len; i++) {
+ PyObject* item = decode_val(input, parsedargs.element_type, parsedargs.typeargs);
+ if (!item) {
+ Py_DECREF(ret);
+ return NULL;
+ }
+ PyList_SET_ITEM(ret, i, item);
+ }
+
+ // TODO(dreiss): Consider biting the bullet and making two separate cases
+ // for list and set, avoiding this post facto conversion.
+ if (type == T_SET) {
+ PyObject* setret;
+#if (PY_VERSION_HEX < 0x02050000)
+ // hack needed for older versions
+ setret = PyObject_CallFunctionObjArgs((PyObject*)&PySet_Type, ret, NULL);
+#else
+ // official version
+ setret = PySet_New(ret);
+#endif
+ Py_DECREF(ret);
+ return setret;
+ }
+ return ret;
+ }
+
+ case T_MAP: {
+ int32_t len;
+ int i;
+ MapTypeArgs parsedargs;
+ PyObject* ret = NULL;
+
+ if (!parse_map_args(&parsedargs, typeargs)) {
+ return NULL;
+ }
+
+ if (!checkTypeByte(input, parsedargs.ktag)) {
+ return NULL;
+ }
+ if (!checkTypeByte(input, parsedargs.vtag)) {
+ return NULL;
+ }
+
+ len = readI32(input);
+ if (!check_ssize_t_32(len)) {
+ return false;
+ }
+
+ ret = PyDict_New();
+ if (!ret) {
+ goto error;
+ }
+
+ for (i = 0; i < len; i++) {
+ PyObject* k = NULL;
+ PyObject* v = NULL;
+ k = decode_val(input, parsedargs.ktag, parsedargs.ktypeargs);
+ if (k == NULL) {
+ goto loop_error;
+ }
+ v = decode_val(input, parsedargs.vtag, parsedargs.vtypeargs);
+ if (v == NULL) {
+ goto loop_error;
+ }
+ if (PyDict_SetItem(ret, k, v) == -1) {
+ goto loop_error;
+ }
+
+ Py_DECREF(k);
+ Py_DECREF(v);
+ continue;
+
+ // Yuck! Destructors, anyone?
+ loop_error:
+ Py_XDECREF(k);
+ Py_XDECREF(v);
+ goto error;
+ }
+
+ return ret;
+
+ error:
+ Py_XDECREF(ret);
+ return NULL;
+ }
+
+ case T_STRUCT: {
+ StructTypeArgs parsedargs;
+ PyObject* ret;
+ if (!parse_struct_args(&parsedargs, typeargs)) {
+ return NULL;
+ }
+
+ ret = PyObject_CallObject(parsedargs.klass, NULL);
+ if (!ret) {
+ return NULL;
+ }
+
+ if (!decode_struct(input, ret, parsedargs.spec)) {
+ Py_DECREF(ret);
+ return NULL;
+ }
+
+ return ret;
+ }
+
+ case T_STOP:
+ case T_VOID:
+ case T_UTF16:
+ case T_UTF8:
+ case T_U64:
+ default:
+ PyErr_SetString(PyExc_TypeError, "Unexpected TType");
+ return NULL;
+ }
+}
+
+
+/* --- TOP-LEVEL WRAPPER FOR INPUT -- */
+
+static PyObject*
+decode_binary(PyObject *self, PyObject *args) {
+ PyObject* output_obj = NULL;
+ PyObject* transport = NULL;
+ PyObject* typeargs = NULL;
+ StructTypeArgs parsedargs;
+ DecodeBuffer input = {0, 0};
+
+ if (!PyArg_ParseTuple(args, "OOO", &output_obj, &transport, &typeargs)) {
+ return NULL;
+ }
+
+ if (!parse_struct_args(&parsedargs, typeargs)) {
+ return NULL;
+ }
+
+ if (!decode_buffer_from_obj(&input, transport)) {
+ return NULL;
+ }
+
+ if (!decode_struct(&input, output_obj, parsedargs.spec)) {
+ free_decodebuf(&input);
+ return NULL;
+ }
+
+ free_decodebuf(&input);
+
+ Py_RETURN_NONE;
+}
+
+/* ====== END READING FUNCTIONS ====== */
+
+
+/* -- PYTHON MODULE SETUP STUFF --- */
+
+static PyMethodDef ThriftFastBinaryMethods[] = {
+
+ {"encode_binary", encode_binary, METH_VARARGS, ""},
+ {"decode_binary", decode_binary, METH_VARARGS, ""},
+
+ {NULL, NULL, 0, NULL} /* Sentinel */
+};
+
+PyMODINIT_FUNC
+initfastbinary(void) {
+#define INIT_INTERN_STRING(value) \
+ do { \
+ INTERN_STRING(value) = PyString_InternFromString(#value); \
+ if(!INTERN_STRING(value)) return; \
+ } while(0)
+
+ INIT_INTERN_STRING(cstringio_buf);
+ INIT_INTERN_STRING(cstringio_refill);
+#undef INIT_INTERN_STRING
+
+ PycString_IMPORT;
+ if (PycStringIO == NULL) return;
+
+ (void) Py_InitModule("thrift.protocol.fastbinary", ThriftFastBinaryMethods);
+}
diff --git a/module/lib/thrift/server/THttpServer.py b/module/lib/thrift/server/THttpServer.py
index 3047d9c00..be54bab94 100644
--- a/module/lib/thrift/server/THttpServer.py
+++ b/module/lib/thrift/server/THttpServer.py
@@ -22,6 +22,7 @@ import BaseHTTPServer
from thrift.server import TServer
from thrift.transport import TTransport
+
class ResponseException(Exception):
"""Allows handlers to override the HTTP response
@@ -39,16 +40,19 @@ class THttpServer(TServer.TServer):
"""A simple HTTP-based Thrift server
This class is not very performant, but it is useful (for example) for
- acting as a mock version of an Apache-based PHP Thrift endpoint."""
-
- def __init__(self, processor, server_address,
- inputProtocolFactory, outputProtocolFactory = None,
- server_class = BaseHTTPServer.HTTPServer):
+ acting as a mock version of an Apache-based PHP Thrift endpoint.
+ """
+ def __init__(self,
+ processor,
+ server_address,
+ inputProtocolFactory,
+ outputProtocolFactory=None,
+ server_class=BaseHTTPServer.HTTPServer):
"""Set up protocol factories and HTTP server.
See BaseHTTPServer for server_address.
- See TServer for protocol factories."""
-
+ See TServer for protocol factories.
+ """
if outputProtocolFactory is None:
outputProtocolFactory = inputProtocolFactory
@@ -62,7 +66,8 @@ class THttpServer(TServer.TServer):
# Don't care about the request path.
itrans = TTransport.TFileObjectTransport(self.rfile)
otrans = TTransport.TFileObjectTransport(self.wfile)
- itrans = TTransport.TBufferedTransport(itrans, int(self.headers['Content-Length']))
+ itrans = TTransport.TBufferedTransport(
+ itrans, int(self.headers['Content-Length']))
otrans = TTransport.TMemoryBuffer()
iprot = thttpserver.inputProtocolFactory.getProtocol(itrans)
oprot = thttpserver.outputProtocolFactory.getProtocol(otrans)
diff --git a/module/lib/thrift/server/TNonblockingServer.py b/module/lib/thrift/server/TNonblockingServer.py
index ea348a0b6..fa478d01f 100644
--- a/module/lib/thrift/server/TNonblockingServer.py
+++ b/module/lib/thrift/server/TNonblockingServer.py
@@ -18,10 +18,11 @@
#
"""Implementation of non-blocking server.
-The main idea of the server is reciving and sending requests
-only from main thread.
+The main idea of the server is to receive and send requests
+only from the main thread.
-It also makes thread pool server in tasks terms, not connections.
+The thread poool should be sized for concurrent tasks, not
+maximum connections
"""
import threading
import socket
@@ -35,8 +36,10 @@ from thrift.protocol.TBinaryProtocol import TBinaryProtocolFactory
__all__ = ['TNonblockingServer']
+
class Worker(threading.Thread):
"""Worker is a small helper to process incoming connection."""
+
def __init__(self, queue):
threading.Thread.__init__(self)
self.queue = queue
@@ -60,8 +63,9 @@ WAIT_PROCESS = 2
SEND_ANSWER = 3
CLOSED = 4
+
def locked(func):
- "Decorator which locks self.lock."
+ """Decorator which locks self.lock."""
def nested(self, *args, **kwargs):
self.lock.acquire()
try:
@@ -70,8 +74,9 @@ def locked(func):
self.lock.release()
return nested
+
def socket_exception(func):
- "Decorator close object on socket.error."
+ """Decorator close object on socket.error."""
def read(self, *args, **kwargs):
try:
return func(self, *args, **kwargs)
@@ -79,16 +84,17 @@ def socket_exception(func):
self.close()
return read
+
class Connection:
"""Basic class is represented connection.
-
+
It can be in state:
WAIT_LEN --- connection is reading request len.
WAIT_MESSAGE --- connection is reading request.
- WAIT_PROCESS --- connection has just read whole request and
- waits for call ready routine.
+ WAIT_PROCESS --- connection has just read whole request and
+ waits for call ready routine.
SEND_ANSWER --- connection is sending answer string (including length
- of answer).
+ of answer).
CLOSED --- socket was closed and connection should be deleted.
"""
def __init__(self, new_socket, wake_up):
@@ -102,13 +108,13 @@ class Connection:
def _read_len(self):
"""Reads length of request.
-
- It's really paranoic routine and it may be replaced by
- self.socket.recv(4)."""
+
+ It's a safer alternative to self.socket.recv(4)
+ """
read = self.socket.recv(4 - len(self.message))
if len(read) == 0:
- # if we read 0 bytes and self.message is empty, it means client close
- # connection
+ # if we read 0 bytes and self.message is empty, then
+ # the client closed the connection
if len(self.message) != 0:
logging.error("can't read frame size from socket")
self.close()
@@ -117,8 +123,8 @@ class Connection:
if len(self.message) == 4:
self.len, = struct.unpack('!i', self.message)
if self.len < 0:
- logging.error("negative frame size, it seems client"\
- " doesn't use FramedTransport")
+ logging.error("negative frame size, it seems client "
+ "doesn't use FramedTransport")
self.close()
elif self.len == 0:
logging.error("empty frame, it's really strange")
@@ -139,8 +145,8 @@ class Connection:
elif self.status == WAIT_MESSAGE:
read = self.socket.recv(self.len - len(self.message))
if len(read) == 0:
- logging.error("can't read frame from socket (get %d of %d bytes)" %
- (len(self.message), self.len))
+ logging.error("can't read frame from socket (get %d of "
+ "%d bytes)" % (len(self.message), self.len))
self.close()
return
self.message += read
@@ -162,14 +168,14 @@ class Connection:
@locked
def ready(self, all_ok, message):
"""Callback function for switching state and waking up main thread.
-
+
This function is the only function witch can be called asynchronous.
-
+
The ready can switch Connection to three states:
WAIT_LEN if request was oneway.
SEND_ANSWER if request was processed in normal way.
CLOSED if request throws unexpected exception.
-
+
The one wakes up main thread.
"""
assert self.status == WAIT_PROCESS
@@ -189,33 +195,39 @@ class Connection:
@locked
def is_writeable(self):
- "Returns True if connection should be added to write list of select."
+ """Return True if connection should be added to write list of select"""
return self.status == SEND_ANSWER
# it's not necessary, but...
@locked
def is_readable(self):
- "Returns True if connection should be added to read list of select."
+ """Return True if connection should be added to read list of select"""
return self.status in (WAIT_LEN, WAIT_MESSAGE)
@locked
def is_closed(self):
- "Returns True if connection is closed."
+ """Returns True if connection is closed."""
return self.status == CLOSED
def fileno(self):
- "Returns the file descriptor of the associated socket."
+ """Returns the file descriptor of the associated socket."""
return self.socket.fileno()
def close(self):
- "Closes connection"
+ """Closes connection"""
self.status = CLOSED
self.socket.close()
+
class TNonblockingServer:
"""Non-blocking server."""
- def __init__(self, processor, lsocket, inputProtocolFactory=None,
- outputProtocolFactory=None, threads=10):
+
+ def __init__(self,
+ processor,
+ lsocket,
+ inputProtocolFactory=None,
+ outputProtocolFactory=None,
+ threads=10):
self.processor = processor
self.socket = lsocket
self.in_protocol = inputProtocolFactory or TBinaryProtocolFactory()
@@ -225,15 +237,18 @@ class TNonblockingServer:
self.tasks = Queue.Queue()
self._read, self._write = socket.socketpair()
self.prepared = False
+ self._stop = False
def setNumThreads(self, num):
"""Set the number of worker threads that should be created."""
# implement ThreadPool interface
- assert not self.prepared, "You can't change number of threads for working server"
+ assert not self.prepared, "Can't change number of threads after start"
self.threads = num
def prepare(self):
"""Prepares server for serve requests."""
+ if self.prepared:
+ return
self.socket.listen()
for _ in xrange(self.threads):
thread = Worker(self.tasks)
@@ -243,16 +258,32 @@ class TNonblockingServer:
def wake_up(self):
"""Wake up main thread.
-
+
The server usualy waits in select call in we should terminate one.
The simplest way is using socketpair.
-
+
Select always wait to read from the first socket of socketpair.
-
+
In this case, we can just write anything to the second socket from
- socketpair."""
+ socketpair.
+ """
self._write.send('1')
+ def stop(self):
+ """Stop the server.
+
+ This method causes the serve() method to return. stop() may be invoked
+ from within your handler, or from another thread.
+
+ After stop() is called, serve() will return but the server will still
+ be listening on the socket. serve() may then be called again to resume
+ processing requests. Alternatively, close() may be called after
+ serve() returns to close the server socket and shutdown all worker
+ threads.
+ """
+ self._stop = True
+ self.wake_up()
+
def _select(self):
"""Does select on open connections."""
readable = [self.socket.handle.fileno(), self._read.fileno()]
@@ -265,21 +296,22 @@ class TNonblockingServer:
if connection.is_closed():
del self.clients[i]
return select.select(readable, writable, readable)
-
+
def handle(self):
"""Handle requests.
-
- WARNING! You must call prepare BEFORE calling handle.
+
+ WARNING! You must call prepare() BEFORE calling handle()
"""
assert self.prepared, "You have to call prepare before handle"
rset, wset, xset = self._select()
for readable in rset:
if readable == self._read.fileno():
# don't care i just need to clean readable flag
- self._read.recv(1024)
+ self._read.recv(1024)
elif readable == self.socket.handle.fileno():
client = self.socket.accept().handle
- self.clients[client.fileno()] = Connection(client, self.wake_up)
+ self.clients[client.fileno()] = Connection(client,
+ self.wake_up)
else:
connection = self.clients[readable]
connection.read()
@@ -288,7 +320,7 @@ class TNonblockingServer:
otransport = TTransport.TMemoryBuffer()
iprot = self.in_protocol.getProtocol(itransport)
oprot = self.out_protocol.getProtocol(otransport)
- self.tasks.put([self.processor, iprot, oprot,
+ self.tasks.put([self.processor, iprot, oprot,
otransport, connection.ready])
for writeable in wset:
self.clients[writeable].write()
@@ -302,9 +334,13 @@ class TNonblockingServer:
self.tasks.put([None, None, None, None, None])
self.socket.close()
self.prepared = False
-
+
def serve(self):
- """Serve forever."""
+ """Serve requests.
+
+ Serve requests forever, or until stop() is called.
+ """
+ self._stop = False
self.prepare()
- while True:
+ while not self._stop:
self.handle()
diff --git a/module/lib/thrift/server/TProcessPoolServer.py b/module/lib/thrift/server/TProcessPoolServer.py
index 7ed814a88..7a695a883 100644
--- a/module/lib/thrift/server/TProcessPoolServer.py
+++ b/module/lib/thrift/server/TProcessPoolServer.py
@@ -24,15 +24,14 @@ from multiprocessing import Process, Value, Condition, reduction
from TServer import TServer
from thrift.transport.TTransport import TTransportException
+
class TProcessPoolServer(TServer):
+ """Server with a fixed size pool of worker subprocesses to service requests
- """
- Server with a fixed size pool of worker subprocesses which service requests.
Note that if you need shared state between the handlers - it's up to you!
Written by Dvir Volk, doat.com
"""
-
- def __init__(self, * args):
+ def __init__(self, *args):
TServer.__init__(self, *args)
self.numWorkers = 10
self.workers = []
@@ -50,12 +49,11 @@ class TProcessPoolServer(TServer):
self.numWorkers = num
def workerProcess(self):
- """Loop around getting clients from the shared queue and process them."""
-
+ """Loop getting clients from the shared queue and process them"""
if self.postForkCallback:
self.postForkCallback()
- while self.isRunning.value == True:
+ while self.isRunning.value:
try:
client = self.serverTransport.accept()
self.serveClient(client)
@@ -82,17 +80,15 @@ class TProcessPoolServer(TServer):
itrans.close()
otrans.close()
-
def serve(self):
- """Start a fixed number of worker threads and put client into a queue"""
-
- #this is a shared state that can tell the workers to exit when set as false
+ """Start workers and put into queue"""
+ # this is a shared state that can tell the workers to exit when False
self.isRunning.value = True
- #first bind and listen to the port
+ # first bind and listen to the port
self.serverTransport.listen()
- #fork the children
+ # fork the children
for i in range(self.numWorkers):
try:
w = Process(target=self.workerProcess)
@@ -102,16 +98,14 @@ class TProcessPoolServer(TServer):
except Exception, x:
logging.exception(x)
- #wait until the condition is set by stop()
-
+ # wait until the condition is set by stop()
while True:
-
self.stopCondition.acquire()
try:
self.stopCondition.wait()
break
except (SystemExit, KeyboardInterrupt):
- break
+ break
except Exception, x:
logging.exception(x)
@@ -122,4 +116,3 @@ class TProcessPoolServer(TServer):
self.stopCondition.acquire()
self.stopCondition.notify()
self.stopCondition.release()
-
diff --git a/module/lib/thrift/server/TServer.py b/module/lib/thrift/server/TServer.py
index 8456e2d40..2f24842c4 100644
--- a/module/lib/thrift/server/TServer.py
+++ b/module/lib/thrift/server/TServer.py
@@ -17,27 +17,28 @@
# under the License.
#
+import Queue
import logging
-import sys
import os
-import traceback
+import sys
import threading
-import Queue
+import traceback
from thrift.Thrift import TProcessor
-from thrift.transport import TTransport
from thrift.protocol import TBinaryProtocol
+from thrift.transport import TTransport
-class TServer:
- """Base interface for a server, which must have a serve method."""
+class TServer:
+ """Base interface for a server, which must have a serve() method.
- """ 3 constructors for all servers:
+ Three constructors for all servers:
1) (processor, serverTransport)
2) (processor, serverTransport, transportFactory, protocolFactory)
3) (processor, serverTransport,
inputTransportFactory, outputTransportFactory,
- inputProtocolFactory, outputProtocolFactory)"""
+ inputProtocolFactory, outputProtocolFactory)
+ """
def __init__(self, *args):
if (len(args) == 2):
self.__initArgs__(args[0], args[1],
@@ -63,8 +64,8 @@ class TServer:
def serve(self):
pass
-class TSimpleServer(TServer):
+class TSimpleServer(TServer):
"""Simple single-threaded server that just pumps around one transport."""
def __init__(self, *args):
@@ -89,8 +90,8 @@ class TSimpleServer(TServer):
itrans.close()
otrans.close()
-class TThreadedServer(TServer):
+class TThreadedServer(TServer):
"""Threaded server that spawns a new thread per each connection."""
def __init__(self, *args, **kwargs):
@@ -102,7 +103,7 @@ class TThreadedServer(TServer):
while True:
try:
client = self.serverTransport.accept()
- t = threading.Thread(target = self.handle, args=(client,))
+ t = threading.Thread(target=self.handle, args=(client,))
t.setDaemon(self.daemon)
t.start()
except KeyboardInterrupt:
@@ -126,8 +127,8 @@ class TThreadedServer(TServer):
itrans.close()
otrans.close()
-class TThreadPoolServer(TServer):
+class TThreadPoolServer(TServer):
"""Server with a fixed size pool of threads which service requests."""
def __init__(self, *args, **kwargs):
@@ -170,7 +171,7 @@ class TThreadPoolServer(TServer):
"""Start a fixed number of worker threads and put client into a queue"""
for i in range(self.threads):
try:
- t = threading.Thread(target = self.serveThread)
+ t = threading.Thread(target=self.serveThread)
t.setDaemon(self.daemon)
t.start()
except Exception, x:
@@ -187,9 +188,8 @@ class TThreadPoolServer(TServer):
class TForkingServer(TServer):
+ """A Thrift server that forks a new process for each request
- """A Thrift server that forks a new process for each request"""
- """
This is more scalable than the threaded server as it does not cause
GIL contention.
@@ -200,7 +200,6 @@ class TForkingServer(TServer):
This code is heavily inspired by SocketServer.ForkingMixIn in the
Python stdlib.
"""
-
def __init__(self, *args):
TServer.__init__(self, *args)
self.children = []
@@ -212,14 +211,13 @@ class TForkingServer(TServer):
except IOError, e:
logging.warning(e, exc_info=True)
-
self.serverTransport.listen()
while True:
client = self.serverTransport.accept()
try:
pid = os.fork()
- if pid: # parent
+ if pid: # parent
# add before collect, otherwise you race w/ waitpid
self.children.append(pid)
self.collect_children()
@@ -258,7 +256,6 @@ class TForkingServer(TServer):
except Exception, x:
logging.exception(x)
-
def collect_children(self):
while self.children:
try:
@@ -270,5 +267,3 @@ class TForkingServer(TServer):
self.children.remove(pid)
else:
break
-
-
diff --git a/module/lib/thrift/transport/THttpClient.py b/module/lib/thrift/transport/THttpClient.py
index 50269785c..ea80a1ae8 100644
--- a/module/lib/thrift/transport/THttpClient.py
+++ b/module/lib/thrift/transport/THttpClient.py
@@ -17,16 +17,20 @@
# under the License.
#
-from TTransport import *
-from cStringIO import StringIO
-
-import urlparse
import httplib
-import warnings
+import os
import socket
+import sys
+import urllib
+import urlparse
+import warnings
-class THttpClient(TTransportBase):
+from cStringIO import StringIO
+from TTransport import *
+
+
+class THttpClient(TTransportBase):
"""Http implementation of TTransport base."""
def __init__(self, uri_or_host, port=None, path=None):
@@ -35,10 +39,13 @@ class THttpClient(TTransportBase):
THttpClient(host, port, path) - deprecated
THttpClient(uri)
- Only the second supports https."""
-
+ Only the second supports https.
+ """
if port is not None:
- warnings.warn("Please use the THttpClient('http://host:port/path') syntax", DeprecationWarning, stacklevel=2)
+ warnings.warn(
+ "Please use the THttpClient('http://host:port/path') syntax",
+ DeprecationWarning,
+ stacklevel=2)
self.host = uri_or_host
self.port = port
assert path
@@ -59,6 +66,7 @@ class THttpClient(TTransportBase):
self.__wbuf = StringIO()
self.__http = None
self.__timeout = None
+ self.__custom_headers = None
def open(self):
if self.scheme == 'http':
@@ -71,7 +79,7 @@ class THttpClient(TTransportBase):
self.__http = None
def isOpen(self):
- return self.__http != None
+ return self.__http is not None
def setTimeout(self, ms):
if not hasattr(socket, 'getdefaulttimeout'):
@@ -80,7 +88,10 @@ class THttpClient(TTransportBase):
if ms is None:
self.__timeout = None
else:
- self.__timeout = ms/1000.0
+ self.__timeout = ms / 1000.0
+
+ def setCustomHeaders(self, headers):
+ self.__custom_headers = headers
def read(self, sz):
return self.__http.file.read(sz)
@@ -100,7 +111,7 @@ class THttpClient(TTransportBase):
def flush(self):
if self.isOpen():
self.close()
- self.open();
+ self.open()
# Pull data out of buffer
data = self.__wbuf.getvalue()
@@ -113,6 +124,18 @@ class THttpClient(TTransportBase):
self.__http.putheader('Host', self.host)
self.__http.putheader('Content-Type', 'application/x-thrift')
self.__http.putheader('Content-Length', str(len(data)))
+
+ if not self.__custom_headers or 'User-Agent' not in self.__custom_headers:
+ user_agent = 'Python/THttpClient'
+ script = os.path.basename(sys.argv[0])
+ if script:
+ user_agent = '%s (%s)' % (user_agent, urllib.quote(script))
+ self.__http.putheader('User-Agent', user_agent)
+
+ if self.__custom_headers:
+ for key, val in self.__custom_headers.iteritems():
+ self.__http.putheader(key, val)
+
self.__http.endheaders()
# Write payload
diff --git a/module/lib/thrift/transport/TSSLSocket.py b/module/lib/thrift/transport/TSSLSocket.py
new file mode 100644
index 000000000..81e098426
--- /dev/null
+++ b/module/lib/thrift/transport/TSSLSocket.py
@@ -0,0 +1,214 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+
+import os
+import socket
+import ssl
+
+from thrift.transport import TSocket
+from thrift.transport.TTransport import TTransportException
+
+
+class TSSLSocket(TSocket.TSocket):
+ """
+ SSL implementation of client-side TSocket
+
+ This class creates outbound sockets wrapped using the
+ python standard ssl module for encrypted connections.
+
+ The protocol used is set using the class variable
+ SSL_VERSION, which must be one of ssl.PROTOCOL_* and
+ defaults to ssl.PROTOCOL_TLSv1 for greatest security.
+ """
+ SSL_VERSION = ssl.PROTOCOL_TLSv1
+
+ def __init__(self,
+ host='localhost',
+ port=9090,
+ validate=True,
+ ca_certs=None,
+ keyfile=None,
+ certfile=None,
+ unix_socket=None):
+ """Create SSL TSocket
+
+ @param validate: Set to False to disable SSL certificate validation
+ @type validate: bool
+ @param ca_certs: Filename to the Certificate Authority pem file, possibly a
+ file downloaded from: http://curl.haxx.se/ca/cacert.pem This is passed to
+ the ssl_wrap function as the 'ca_certs' parameter.
+ @type ca_certs: str
+ @param keyfile: The private key
+ @type keyfile: str
+ @param certfile: The cert file
+ @type certfile: str
+
+ Raises an IOError exception if validate is True and the ca_certs file is
+ None, not present or unreadable.
+ """
+ self.validate = validate
+ self.is_valid = False
+ self.peercert = None
+ if not validate:
+ self.cert_reqs = ssl.CERT_NONE
+ else:
+ self.cert_reqs = ssl.CERT_REQUIRED
+ self.ca_certs = ca_certs
+ self.keyfile = keyfile
+ self.certfile = certfile
+ if validate:
+ if ca_certs is None or not os.access(ca_certs, os.R_OK):
+ raise IOError('Certificate Authority ca_certs file "%s" '
+ 'is not readable, cannot validate SSL '
+ 'certificates.' % (ca_certs))
+ TSocket.TSocket.__init__(self, host, port, unix_socket)
+
+ def open(self):
+ try:
+ res0 = self._resolveAddr()
+ for res in res0:
+ sock_family, sock_type = res[0:2]
+ ip_port = res[4]
+ plain_sock = socket.socket(sock_family, sock_type)
+ self.handle = ssl.wrap_socket(plain_sock,
+ ssl_version=self.SSL_VERSION,
+ do_handshake_on_connect=True,
+ ca_certs=self.ca_certs,
+ keyfile=self.keyfile,
+ certfile=self.certfile,
+ cert_reqs=self.cert_reqs)
+ self.handle.settimeout(self._timeout)
+ try:
+ self.handle.connect(ip_port)
+ except socket.error, e:
+ if res is not res0[-1]:
+ continue
+ else:
+ raise e
+ break
+ except socket.error, e:
+ if self._unix_socket:
+ message = 'Could not connect to secure socket %s: %s' \
+ % (self._unix_socket, e)
+ else:
+ message = 'Could not connect to %s:%d: %s' % (self.host, self.port, e)
+ raise TTransportException(type=TTransportException.NOT_OPEN,
+ message=message)
+ if self.validate:
+ self._validate_cert()
+
+ def _validate_cert(self):
+ """internal method to validate the peer's SSL certificate, and to check the
+ commonName of the certificate to ensure it matches the hostname we
+ used to make this connection. Does not support subjectAltName records
+ in certificates.
+
+ raises TTransportException if the certificate fails validation.
+ """
+ cert = self.handle.getpeercert()
+ self.peercert = cert
+ if 'subject' not in cert:
+ raise TTransportException(
+ type=TTransportException.NOT_OPEN,
+ message='No SSL certificate found from %s:%s' % (self.host, self.port))
+ fields = cert['subject']
+ for field in fields:
+ # ensure structure we get back is what we expect
+ if not isinstance(field, tuple):
+ continue
+ cert_pair = field[0]
+ if len(cert_pair) < 2:
+ continue
+ cert_key, cert_value = cert_pair[0:2]
+ if cert_key != 'commonName':
+ continue
+ certhost = cert_value
+ # this check should be performed by some sort of Access Manager
+ if certhost == self.host:
+ # success, cert commonName matches desired hostname
+ self.is_valid = True
+ return
+ else:
+ raise TTransportException(
+ type=TTransportException.UNKNOWN,
+ message='Hostname we connected to "%s" doesn\'t match certificate '
+ 'provided commonName "%s"' % (self.host, certhost))
+ raise TTransportException(
+ type=TTransportException.UNKNOWN,
+ message='Could not validate SSL certificate from '
+ 'host "%s". Cert=%s' % (self.host, cert))
+
+
+class TSSLServerSocket(TSocket.TServerSocket):
+ """SSL implementation of TServerSocket
+
+ This uses the ssl module's wrap_socket() method to provide SSL
+ negotiated encryption.
+ """
+ SSL_VERSION = ssl.PROTOCOL_TLSv1
+
+ def __init__(self,
+ host=None,
+ port=9090,
+ certfile='cert.pem',
+ unix_socket=None):
+ """Initialize a TSSLServerSocket
+
+ @param certfile: filename of the server certificate, defaults to cert.pem
+ @type certfile: str
+ @param host: The hostname or IP to bind the listen socket to,
+ i.e. 'localhost' for only allowing local network connections.
+ Pass None to bind to all interfaces.
+ @type host: str
+ @param port: The port to listen on for inbound connections.
+ @type port: int
+ """
+ self.setCertfile(certfile)
+ TSocket.TServerSocket.__init__(self, host, port)
+
+ def setCertfile(self, certfile):
+ """Set or change the server certificate file used to wrap new connections.
+
+ @param certfile: The filename of the server certificate,
+ i.e. '/etc/certs/server.pem'
+ @type certfile: str
+
+ Raises an IOError exception if the certfile is not present or unreadable.
+ """
+ if not os.access(certfile, os.R_OK):
+ raise IOError('No such certfile found: %s' % (certfile))
+ self.certfile = certfile
+
+ def accept(self):
+ plain_client, addr = self.handle.accept()
+ try:
+ client = ssl.wrap_socket(plain_client, certfile=self.certfile,
+ server_side=True, ssl_version=self.SSL_VERSION)
+ except ssl.SSLError, ssl_exc:
+ # failed handshake/ssl wrap, close socket to client
+ plain_client.close()
+ # raise ssl_exc
+ # We can't raise the exception, because it kills most TServer derived
+ # serve() methods.
+ # Instead, return None, and let the TServer instance deal with it in
+ # other exception handling. (but TSimpleServer dies anyway)
+ return None
+ result = TSocket.TSocket()
+ result.setHandle(client)
+ return result
diff --git a/module/lib/thrift/transport/TSocket.py b/module/lib/thrift/transport/TSocket.py
index 4e0e1874f..9e2b3849b 100644
--- a/module/lib/thrift/transport/TSocket.py
+++ b/module/lib/thrift/transport/TSocket.py
@@ -17,24 +17,33 @@
# under the License.
#
-from TTransport import *
-import os
import errno
+import os
import socket
import sys
+from TTransport import *
+
+
class TSocketBase(TTransportBase):
def _resolveAddr(self):
if self._unix_socket is not None:
- return [(socket.AF_UNIX, socket.SOCK_STREAM, None, None, self._unix_socket)]
+ return [(socket.AF_UNIX, socket.SOCK_STREAM, None, None,
+ self._unix_socket)]
else:
- return socket.getaddrinfo(self.host, self.port, socket.AF_UNSPEC, socket.SOCK_STREAM, 0, socket.AI_PASSIVE | socket.AI_ADDRCONFIG)
+ return socket.getaddrinfo(self.host,
+ self.port,
+ socket.AF_UNSPEC,
+ socket.SOCK_STREAM,
+ 0,
+ socket.AI_PASSIVE | socket.AI_ADDRCONFIG)
def close(self):
if self.handle:
self.handle.close()
self.handle = None
+
class TSocket(TSocketBase):
"""Socket implementation of TTransport base."""
@@ -46,7 +55,6 @@ class TSocket(TSocketBase):
@param unix_socket(str) The filename of a unix socket to connect to.
(host and port will be ignored.)
"""
-
self.host = host
self.port = port
self.handle = None
@@ -63,7 +71,7 @@ class TSocket(TSocketBase):
if ms is None:
self._timeout = None
else:
- self._timeout = ms/1000.0
+ self._timeout = ms / 1000.0
if self.handle is not None:
self.handle.settimeout(self._timeout)
@@ -87,7 +95,8 @@ class TSocket(TSocketBase):
message = 'Could not connect to socket %s' % self._unix_socket
else:
message = 'Could not connect to %s:%d' % (self.host, self.port)
- raise TTransportException(type=TTransportException.NOT_OPEN, message=message)
+ raise TTransportException(type=TTransportException.NOT_OPEN,
+ message=message)
def read(self, sz):
try:
@@ -105,24 +114,28 @@ class TSocket(TSocketBase):
else:
raise
if len(buff) == 0:
- raise TTransportException(type=TTransportException.END_OF_FILE, message='TSocket read 0 bytes')
+ raise TTransportException(type=TTransportException.END_OF_FILE,
+ message='TSocket read 0 bytes')
return buff
def write(self, buff):
if not self.handle:
- raise TTransportException(type=TTransportException.NOT_OPEN, message='Transport not open')
+ raise TTransportException(type=TTransportException.NOT_OPEN,
+ message='Transport not open')
sent = 0
have = len(buff)
while sent < have:
plus = self.handle.send(buff)
if plus == 0:
- raise TTransportException(type=TTransportException.END_OF_FILE, message='TSocket sent 0 bytes')
+ raise TTransportException(type=TTransportException.END_OF_FILE,
+ message='TSocket sent 0 bytes')
sent += plus
buff = buff[plus:]
def flush(self):
pass
+
class TServerSocket(TSocketBase, TServerTransportBase):
"""Socket implementation of TServerTransport base."""
diff --git a/module/lib/thrift/transport/TTransport.py b/module/lib/thrift/transport/TTransport.py
index 12e51a9bf..4481371a6 100644
--- a/module/lib/thrift/transport/TTransport.py
+++ b/module/lib/thrift/transport/TTransport.py
@@ -18,11 +18,11 @@
#
from cStringIO import StringIO
-from struct import pack,unpack
+from struct import pack, unpack
from thrift.Thrift import TException
-class TTransportException(TException):
+class TTransportException(TException):
"""Custom Transport Exception class"""
UNKNOWN = 0
@@ -35,8 +35,8 @@ class TTransportException(TException):
TException.__init__(self, message)
self.type = type
-class TTransportBase:
+class TTransportBase:
"""Base class for Thrift transport layer."""
def isOpen(self):
@@ -55,7 +55,7 @@ class TTransportBase:
buff = ''
have = 0
while (have < sz):
- chunk = self.read(sz-have)
+ chunk = self.read(sz - have)
have += len(chunk)
buff += chunk
@@ -70,6 +70,7 @@ class TTransportBase:
def flush(self):
pass
+
# This class should be thought of as an interface.
class CReadableTransport:
"""base class for transports that are readable from C"""
@@ -98,8 +99,8 @@ class CReadableTransport:
"""
pass
-class TServerTransportBase:
+class TServerTransportBase:
"""Base class for Thrift server transports."""
def listen(self):
@@ -111,15 +112,15 @@ class TServerTransportBase:
def close(self):
pass
-class TTransportFactoryBase:
+class TTransportFactoryBase:
"""Base class for a Transport Factory"""
def getTransport(self, trans):
return trans
-class TBufferedTransportFactory:
+class TBufferedTransportFactory:
"""Factory transport that builds buffered transports"""
def getTransport(self, trans):
@@ -127,17 +128,15 @@ class TBufferedTransportFactory:
return buffered
-class TBufferedTransport(TTransportBase,CReadableTransport):
-
+class TBufferedTransport(TTransportBase, CReadableTransport):
"""Class that wraps another transport and buffers its I/O.
The implementation uses a (configurable) fixed-size read buffer
but buffers all writes until a flush is performed.
"""
-
DEFAULT_BUFFER = 4096
- def __init__(self, trans, rbuf_size = DEFAULT_BUFFER):
+ def __init__(self, trans, rbuf_size=DEFAULT_BUFFER):
self.__trans = trans
self.__wbuf = StringIO()
self.__rbuf = StringIO("")
@@ -188,6 +187,7 @@ class TBufferedTransport(TTransportBase,CReadableTransport):
self.__rbuf = StringIO(retstring)
return self.__rbuf
+
class TMemoryBuffer(TTransportBase, CReadableTransport):
"""Wraps a cStringIO object as a TTransport.
@@ -237,8 +237,8 @@ class TMemoryBuffer(TTransportBase, CReadableTransport):
# only one shot at reading...
raise EOFError()
-class TFramedTransportFactory:
+class TFramedTransportFactory:
"""Factory transport that builds framed transports"""
def getTransport(self, trans):
@@ -247,7 +247,6 @@ class TFramedTransportFactory:
class TFramedTransport(TTransportBase, CReadableTransport):
-
"""Class that wraps another transport and frames its I/O when writing."""
def __init__(self, trans,):
diff --git a/module/lib/thrift/transport/TTwisted.py b/module/lib/thrift/transport/TTwisted.py
index b6dcb4e0b..3ce3eb220 100644
--- a/module/lib/thrift/transport/TTwisted.py
+++ b/module/lib/thrift/transport/TTwisted.py
@@ -16,6 +16,9 @@
# specific language governing permissions and limitations
# under the License.
#
+
+from cStringIO import StringIO
+
from zope.interface import implements, Interface, Attribute
from twisted.internet.protocol import Protocol, ServerFactory, ClientFactory, \
connectionDone
@@ -25,7 +28,6 @@ from twisted.python import log
from twisted.web import server, resource, http
from thrift.transport import TTransport
-from cStringIO import StringIO
class TMessageSenderTransport(TTransport.TTransportBase):
@@ -79,7 +81,7 @@ class ThriftClientProtocol(basic.Int32StringReceiver):
self.started.callback(self.client)
def connectionLost(self, reason=connectionDone):
- for k,v in self.client._reqs.iteritems():
+ for k, v in self.client._reqs.iteritems():
tex = TTransport.TTransportException(
type=TTransport.TTransportException.END_OF_FILE,
message='Connection closed')
diff --git a/module/lib/thrift/transport/TZlibTransport.py b/module/lib/thrift/transport/TZlibTransport.py
index 784d4e1e0..a2f42a5d2 100644
--- a/module/lib/thrift/transport/TZlibTransport.py
+++ b/module/lib/thrift/transport/TZlibTransport.py
@@ -16,50 +16,49 @@
# specific language governing permissions and limitations
# under the License.
#
-'''
-TZlibTransport provides a compressed transport and transport factory
+
+"""TZlibTransport provides a compressed transport and transport factory
class, using the python standard library zlib module to implement
data compression.
-'''
+"""
from __future__ import division
import zlib
from cStringIO import StringIO
from TTransport import TTransportBase, CReadableTransport
+
class TZlibTransportFactory(object):
- '''
- Factory transport that builds zlib compressed transports.
-
+ """Factory transport that builds zlib compressed transports.
+
This factory caches the last single client/transport that it was passed
and returns the same TZlibTransport object that was created.
-
+
This caching means the TServer class will get the _same_ transport
object for both input and output transports from this factory.
(For non-threaded scenarios only, since the cache only holds one object)
-
+
The purpose of this caching is to allocate only one TZlibTransport where
only one is really needed (since it must have separate read/write buffers),
and makes the statistics from getCompSavings() and getCompRatio()
easier to understand.
- '''
-
+ """
# class scoped cache of last transport given and zlibtransport returned
_last_trans = None
_last_z = None
def getTransport(self, trans, compresslevel=9):
- '''Wrap a transport , trans, with the TZlibTransport
+ """Wrap a transport, trans, with the TZlibTransport
compressed transport class, returning a new
transport to the caller.
-
+
@param compresslevel: The zlib compression level, ranging
from 0 (no compression) to 9 (best compression). Defaults to 9.
@type compresslevel: int
-
+
This method returns a TZlibTransport which wraps the
passed C{trans} TTransport derived instance.
- '''
+ """
if trans == self._last_trans:
return self._last_z
ztrans = TZlibTransport(trans, compresslevel)
@@ -69,27 +68,24 @@ class TZlibTransportFactory(object):
class TZlibTransport(TTransportBase, CReadableTransport):
- '''
- Class that wraps a transport with zlib, compressing writes
+ """Class that wraps a transport with zlib, compressing writes
and decompresses reads, using the python standard
library zlib module.
- '''
-
+ """
# Read buffer size for the python fastbinary C extension,
# the TBinaryProtocolAccelerated class.
DEFAULT_BUFFSIZE = 4096
def __init__(self, trans, compresslevel=9):
- '''
- Create a new TZlibTransport, wrapping C{trans}, another
+ """Create a new TZlibTransport, wrapping C{trans}, another
TTransport derived object.
-
+
@param trans: A thrift transport object, i.e. a TSocket() object.
@type trans: TTransport
@param compresslevel: The zlib compression level, ranging
from 0 (no compression) to 9 (best compression). Default is 9.
@type compresslevel: int
- '''
+ """
self.__trans = trans
self.compresslevel = compresslevel
self.__rbuf = StringIO()
@@ -98,49 +94,45 @@ class TZlibTransport(TTransportBase, CReadableTransport):
self._init_stats()
def _reinit_buffers(self):
- '''
- Internal method to initialize/reset the internal StringIO objects
+ """Internal method to initialize/reset the internal StringIO objects
for read and write buffers.
- '''
+ """
self.__rbuf = StringIO()
self.__wbuf = StringIO()
def _init_stats(self):
- '''
- Internal method to reset the internal statistics counters
+ """Internal method to reset the internal statistics counters
for compression ratios and bandwidth savings.
- '''
+ """
self.bytes_in = 0
self.bytes_out = 0
self.bytes_in_comp = 0
self.bytes_out_comp = 0
def _init_zlib(self):
- '''
- Internal method for setting up the zlib compression and
+ """Internal method for setting up the zlib compression and
decompression objects.
- '''
+ """
self._zcomp_read = zlib.decompressobj()
self._zcomp_write = zlib.compressobj(self.compresslevel)
def getCompRatio(self):
- '''
- Get the current measured compression ratios (in,out) from
+ """Get the current measured compression ratios (in,out) from
this transport.
-
- Returns a tuple of:
+
+ Returns a tuple of:
(inbound_compression_ratio, outbound_compression_ratio)
-
+
The compression ratios are computed as:
compressed / uncompressed
E.g., data that compresses by 10x will have a ratio of: 0.10
and data that compresses to half of ts original size will
have a ratio of 0.5
-
+
None is returned if no bytes have yet been processed in
a particular direction.
- '''
+ """
r_percent, w_percent = (None, None)
if self.bytes_in > 0:
r_percent = self.bytes_in_comp / self.bytes_in
@@ -149,23 +141,22 @@ class TZlibTransport(TTransportBase, CReadableTransport):
return (r_percent, w_percent)
def getCompSavings(self):
- '''
- Get the current count of saved bytes due to data
+ """Get the current count of saved bytes due to data
compression.
-
+
Returns a tuple of:
(inbound_saved_bytes, outbound_saved_bytes)
-
+
Note: if compression is actually expanding your
data (only likely with very tiny thrift objects), then
the values returned will be negative.
- '''
+ """
r_saved = self.bytes_in - self.bytes_in_comp
w_saved = self.bytes_out - self.bytes_out_comp
return (r_saved, w_saved)
def isOpen(self):
- '''Return the underlying transport's open status'''
+ """Return the underlying transport's open status"""
return self.__trans.isOpen()
def open(self):
@@ -174,25 +165,24 @@ class TZlibTransport(TTransportBase, CReadableTransport):
return self.__trans.open()
def listen(self):
- '''Invoke the underlying transport's listen() method'''
+ """Invoke the underlying transport's listen() method"""
self.__trans.listen()
def accept(self):
- '''Accept connections on the underlying transport'''
+ """Accept connections on the underlying transport"""
return self.__trans.accept()
def close(self):
- '''Close the underlying transport,'''
+ """Close the underlying transport,"""
self._reinit_buffers()
self._init_zlib()
return self.__trans.close()
def read(self, sz):
- '''
- Read up to sz bytes from the decompressed bytes buffer, and
+ """Read up to sz bytes from the decompressed bytes buffer, and
read from the underlying transport if the decompression
buffer is empty.
- '''
+ """
ret = self.__rbuf.read(sz)
if len(ret) > 0:
return ret
@@ -204,10 +194,9 @@ class TZlibTransport(TTransportBase, CReadableTransport):
return ret
def readComp(self, sz):
- '''
- Read compressed data from the underlying transport, then
+ """Read compressed data from the underlying transport, then
decompress it and append it to the internal StringIO read buffer
- '''
+ """
zbuf = self.__trans.read(sz)
zbuf = self._zcomp_read.unconsumed_tail + zbuf
buf = self._zcomp_read.decompress(zbuf)
@@ -220,17 +209,15 @@ class TZlibTransport(TTransportBase, CReadableTransport):
return True
def write(self, buf):
- '''
- Write some bytes, putting them into the internal write
+ """Write some bytes, putting them into the internal write
buffer for eventual compression.
- '''
+ """
self.__wbuf.write(buf)
def flush(self):
- '''
- Flush any queued up data in the write buffer and ensure the
+ """Flush any queued up data in the write buffer and ensure the
compression buffer is flushed out to the underlying transport
- '''
+ """
wout = self.__wbuf.getvalue()
if len(wout) > 0:
zbuf = self._zcomp_write.compress(wout)
@@ -247,11 +234,11 @@ class TZlibTransport(TTransportBase, CReadableTransport):
@property
def cstringio_buf(self):
- '''Implement the CReadableTransport interface'''
+ """Implement the CReadableTransport interface"""
return self.__rbuf
def cstringio_refill(self, partialread, reqlen):
- '''Implement the CReadableTransport interface for refill'''
+ """Implement the CReadableTransport interface for refill"""
retstring = partialread
if reqlen < self.DEFAULT_BUFFSIZE:
retstring += self.read(self.DEFAULT_BUFFSIZE)
diff --git a/module/lib/thrift/transport/__init__.py b/module/lib/thrift/transport/__init__.py
index 46e54fe6b..c9596d9a6 100644
--- a/module/lib/thrift/transport/__init__.py
+++ b/module/lib/thrift/transport/__init__.py
@@ -17,4 +17,4 @@
# under the License.
#
-__all__ = ['TTransport', 'TSocket', 'THttpClient','TZlibTransport']
+__all__ = ['TTransport', 'TSocket', 'THttpClient', 'TZlibTransport']