diff options
Diffstat (limited to 'pyload/lib/thrift')
24 files changed, 5553 insertions, 0 deletions
diff --git a/pyload/lib/thrift/TSCons.py b/pyload/lib/thrift/TSCons.py new file mode 100644 index 000000000..da8d2833b --- /dev/null +++ b/pyload/lib/thrift/TSCons.py @@ -0,0 +1,35 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +from os import path +from SCons.Builder import Builder + + +def scons_env(env, add=''): + opath = path.dirname(path.abspath('$TARGET')) + lstr = 'thrift --gen cpp -o ' + opath + ' ' + add + ' $SOURCE' + cppbuild = Builder(action=lstr) + env.Append(BUILDERS={'ThriftCpp': cppbuild}) + + +def gen_cpp(env, dir, file): + scons_env(env) + suffixes = ['_types.h', '_types.cpp'] + targets = map(lambda s: 'gen-cpp/' + file + s, suffixes) + return env.ThriftCpp(targets, dir + file + '.thrift') diff --git a/pyload/lib/thrift/TSerialization.py b/pyload/lib/thrift/TSerialization.py new file mode 100644 index 000000000..8a58d89df --- /dev/null +++ b/pyload/lib/thrift/TSerialization.py @@ -0,0 +1,38 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +from protocol import TBinaryProtocol +from transport import TTransport + + +def serialize(thrift_object, + protocol_factory=TBinaryProtocol.TBinaryProtocolFactory()): + transport = TTransport.TMemoryBuffer() + protocol = protocol_factory.getProtocol(transport) + thrift_object.write(protocol) + return transport.getvalue() + + +def deserialize(base, + buf, + protocol_factory=TBinaryProtocol.TBinaryProtocolFactory()): + transport = TTransport.TMemoryBuffer(buf) + protocol = protocol_factory.getProtocol(transport) + base.read(protocol) + return base diff --git a/pyload/lib/thrift/TTornado.py b/pyload/lib/thrift/TTornado.py new file mode 100644 index 000000000..af309c3d9 --- /dev/null +++ b/pyload/lib/thrift/TTornado.py @@ -0,0 +1,153 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +from cStringIO import StringIO +import logging +import socket +import struct + +from thrift.transport import TTransport +from thrift.transport.TTransport import TTransportException + +from tornado import gen +from tornado import iostream +from tornado import netutil + + +class TTornadoStreamTransport(TTransport.TTransportBase): + """a framed, buffered transport over a Tornado stream""" + def __init__(self, host, port, stream=None): + self.host = host + self.port = port + self.is_queuing_reads = False + self.read_queue = [] + self.__wbuf = StringIO() + + # servers provide a ready-to-go stream + self.stream = stream + if self.stream is not None: + self._set_close_callback() + + # not the same number of parameters as TTransportBase.open + def open(self, callback): + logging.debug('socket connecting') + sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM, 0) + self.stream = iostream.IOStream(sock) + + def on_close_in_connect(*_): + message = 'could not connect to {}:{}'.format(self.host, self.port) + raise TTransportException( + type=TTransportException.NOT_OPEN, + message=message) + self.stream.set_close_callback(on_close_in_connect) + + def finish(*_): + self._set_close_callback() + callback() + + self.stream.connect((self.host, self.port), callback=finish) + + def _set_close_callback(self): + def on_close(): + raise TTransportException( + type=TTransportException.END_OF_FILE, + message='socket closed') + self.stream.set_close_callback(self.close) + + def close(self): + # don't raise if we intend to close + self.stream.set_close_callback(None) + self.stream.close() + + def read(self, _): + # The generated code for Tornado shouldn't do individual reads -- only + # frames at a time + assert "you're doing it wrong" is True + + @gen.engine + def readFrame(self, callback): + self.read_queue.append(callback) + logging.debug('read queue: %s', self.read_queue) + + if self.is_queuing_reads: + # If a read is already in flight, then the while loop below should + # pull it from self.read_queue + return + + self.is_queuing_reads = True + while self.read_queue: + next_callback = self.read_queue.pop() + result = yield gen.Task(self._readFrameFromStream) + next_callback(result) + self.is_queuing_reads = False + + @gen.engine + def _readFrameFromStream(self, callback): + logging.debug('_readFrameFromStream') + frame_header = yield gen.Task(self.stream.read_bytes, 4) + frame_length, = struct.unpack('!i', frame_header) + logging.debug('received frame header, frame length = %i', frame_length) + frame = yield gen.Task(self.stream.read_bytes, frame_length) + logging.debug('received frame payload') + callback(frame) + + def write(self, buf): + self.__wbuf.write(buf) + + def flush(self, callback=None): + wout = self.__wbuf.getvalue() + wsz = len(wout) + # reset wbuf before write/flush to preserve state on underlying failure + self.__wbuf = StringIO() + # N.B.: Doing this string concatenation is WAY cheaper than making + # two separate calls to the underlying socket object. Socket writes in + # Python turn out to be REALLY expensive, but it seems to do a pretty + # good job of managing string buffer operations without excessive copies + buf = struct.pack("!i", wsz) + wout + + logging.debug('writing frame length = %i', wsz) + self.stream.write(buf, callback) + + +class TTornadoServer(netutil.TCPServer): + def __init__(self, processor, iprot_factory, oprot_factory=None, + *args, **kwargs): + super(TTornadoServer, self).__init__(*args, **kwargs) + + self._processor = processor + self._iprot_factory = iprot_factory + self._oprot_factory = (oprot_factory if oprot_factory is not None + else iprot_factory) + + def handle_stream(self, stream, address): + try: + host, port = address + trans = TTornadoStreamTransport(host=host, port=port, stream=stream) + oprot = self._oprot_factory.getProtocol(trans) + + def next_pass(): + if not trans.stream.closed(): + self._processor.process(trans, self._iprot_factory, oprot, + callback=next_pass) + + next_pass() + + except Exception: + logging.exception('thrift exception in handle_stream') + trans.close() diff --git a/pyload/lib/thrift/Thrift.py b/pyload/lib/thrift/Thrift.py new file mode 100644 index 000000000..9890af7e1 --- /dev/null +++ b/pyload/lib/thrift/Thrift.py @@ -0,0 +1,170 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +import sys + + +class TType: + STOP = 0 + VOID = 1 + BOOL = 2 + BYTE = 3 + I08 = 3 + DOUBLE = 4 + I16 = 6 + I32 = 8 + I64 = 10 + STRING = 11 + UTF7 = 11 + STRUCT = 12 + MAP = 13 + SET = 14 + LIST = 15 + UTF8 = 16 + UTF16 = 17 + + _VALUES_TO_NAMES = ('STOP', + 'VOID', + 'BOOL', + 'BYTE', + 'DOUBLE', + None, + 'I16', + None, + 'I32', + None, + 'I64', + 'STRING', + 'STRUCT', + 'MAP', + 'SET', + 'LIST', + 'UTF8', + 'UTF16') + + +class TMessageType: + CALL = 1 + REPLY = 2 + EXCEPTION = 3 + ONEWAY = 4 + + +class TProcessor: + """Base class for procsessor, which works on two streams.""" + + def process(iprot, oprot): + pass + + +class TException(Exception): + """Base class for all thrift exceptions.""" + + # BaseException.message is deprecated in Python v[2.6,3.0) + if (2, 6, 0) <= sys.version_info < (3, 0): + def _get_message(self): + return self._message + + def _set_message(self, message): + self._message = message + message = property(_get_message, _set_message) + + def __init__(self, message=None): + Exception.__init__(self, message) + self.message = message + + +class TApplicationException(TException): + """Application level thrift exceptions.""" + + UNKNOWN = 0 + UNKNOWN_METHOD = 1 + INVALID_MESSAGE_TYPE = 2 + WRONG_METHOD_NAME = 3 + BAD_SEQUENCE_ID = 4 + MISSING_RESULT = 5 + INTERNAL_ERROR = 6 + PROTOCOL_ERROR = 7 + INVALID_TRANSFORM = 8 + INVALID_PROTOCOL = 9 + UNSUPPORTED_CLIENT_TYPE = 10 + + def __init__(self, type=UNKNOWN, message=None): + TException.__init__(self, message) + self.type = type + + def __str__(self): + if self.message: + return self.message + elif self.type == self.UNKNOWN_METHOD: + return 'Unknown method' + elif self.type == self.INVALID_MESSAGE_TYPE: + return 'Invalid message type' + elif self.type == self.WRONG_METHOD_NAME: + return 'Wrong method name' + elif self.type == self.BAD_SEQUENCE_ID: + return 'Bad sequence ID' + elif self.type == self.MISSING_RESULT: + return 'Missing result' + elif self.type == self.INTERNAL_ERROR: + return 'Internal error' + elif self.type == self.PROTOCOL_ERROR: + return 'Protocol error' + elif self.type == self.INVALID_TRANSFORM: + return 'Invalid transform' + elif self.type == self.INVALID_PROTOCOL: + return 'Invalid protocol' + elif self.type == self.UNSUPPORTED_CLIENT_TYPE: + return 'Unsupported client type' + else: + return 'Default (unknown) TApplicationException' + + def read(self, iprot): + iprot.readStructBegin() + while True: + (fname, ftype, fid) = iprot.readFieldBegin() + if ftype == TType.STOP: + break + if fid == 1: + if ftype == TType.STRING: + self.message = iprot.readString() + else: + iprot.skip(ftype) + elif fid == 2: + if ftype == TType.I32: + self.type = iprot.readI32() + else: + iprot.skip(ftype) + else: + iprot.skip(ftype) + iprot.readFieldEnd() + iprot.readStructEnd() + + def write(self, oprot): + oprot.writeStructBegin('TApplicationException') + if self.message is not None: + oprot.writeFieldBegin('message', TType.STRING, 1) + oprot.writeString(self.message) + oprot.writeFieldEnd() + if self.type is not None: + oprot.writeFieldBegin('type', TType.I32, 2) + oprot.writeI32(self.type) + oprot.writeFieldEnd() + oprot.writeFieldStop() + oprot.writeStructEnd() diff --git a/pyload/lib/thrift/__init__.py b/pyload/lib/thrift/__init__.py new file mode 100644 index 000000000..48d659c40 --- /dev/null +++ b/pyload/lib/thrift/__init__.py @@ -0,0 +1,20 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +__all__ = ['Thrift', 'TSCons'] diff --git a/pyload/lib/thrift/protocol/TBase.py b/pyload/lib/thrift/protocol/TBase.py new file mode 100644 index 000000000..6cbd5f39a --- /dev/null +++ b/pyload/lib/thrift/protocol/TBase.py @@ -0,0 +1,81 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +from thrift.Thrift import * +from thrift.protocol import TBinaryProtocol +from thrift.transport import TTransport + +try: + from thrift.protocol import fastbinary +except: + fastbinary = None + + +class TBase(object): + __slots__ = [] + + def __repr__(self): + L = ['%s=%r' % (key, getattr(self, key)) + for key in self.__slots__] + return '%s(%s)' % (self.__class__.__name__, ', '.join(L)) + + def __eq__(self, other): + if not isinstance(other, self.__class__): + return False + for attr in self.__slots__: + my_val = getattr(self, attr) + other_val = getattr(other, attr) + if my_val != other_val: + return False + return True + + def __ne__(self, other): + return not (self == other) + + def read(self, iprot): + if (iprot.__class__ == TBinaryProtocol.TBinaryProtocolAccelerated and + isinstance(iprot.trans, TTransport.CReadableTransport) and + self.thrift_spec is not None and + fastbinary is not None): + fastbinary.decode_binary(self, + iprot.trans, + (self.__class__, self.thrift_spec)) + return + iprot.readStruct(self, self.thrift_spec) + + def write(self, oprot): + if (oprot.__class__ == TBinaryProtocol.TBinaryProtocolAccelerated and + self.thrift_spec is not None and + fastbinary is not None): + oprot.trans.write( + fastbinary.encode_binary(self, (self.__class__, self.thrift_spec))) + return + oprot.writeStruct(self, self.thrift_spec) + + +class TExceptionBase(Exception): + # old style class so python2.4 can raise exceptions derived from this + # This can't inherit from TBase because of that limitation. + __slots__ = [] + + __repr__ = TBase.__repr__.im_func + __eq__ = TBase.__eq__.im_func + __ne__ = TBase.__ne__.im_func + read = TBase.read.im_func + write = TBase.write.im_func diff --git a/pyload/lib/thrift/protocol/TBinaryProtocol.py b/pyload/lib/thrift/protocol/TBinaryProtocol.py new file mode 100644 index 000000000..6fdd08c26 --- /dev/null +++ b/pyload/lib/thrift/protocol/TBinaryProtocol.py @@ -0,0 +1,260 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +from TProtocol import * +from struct import pack, unpack + + +class TBinaryProtocol(TProtocolBase): + """Binary implementation of the Thrift protocol driver.""" + + # NastyHaxx. Python 2.4+ on 32-bit machines forces hex constants to be + # positive, converting this into a long. If we hardcode the int value + # instead it'll stay in 32 bit-land. + + # VERSION_MASK = 0xffff0000 + VERSION_MASK = -65536 + + # VERSION_1 = 0x80010000 + VERSION_1 = -2147418112 + + TYPE_MASK = 0x000000ff + + def __init__(self, trans, strictRead=False, strictWrite=True): + TProtocolBase.__init__(self, trans) + self.strictRead = strictRead + self.strictWrite = strictWrite + + def writeMessageBegin(self, name, type, seqid): + if self.strictWrite: + self.writeI32(TBinaryProtocol.VERSION_1 | type) + self.writeString(name) + self.writeI32(seqid) + else: + self.writeString(name) + self.writeByte(type) + self.writeI32(seqid) + + def writeMessageEnd(self): + pass + + def writeStructBegin(self, name): + pass + + def writeStructEnd(self): + pass + + def writeFieldBegin(self, name, type, id): + self.writeByte(type) + self.writeI16(id) + + def writeFieldEnd(self): + pass + + def writeFieldStop(self): + self.writeByte(TType.STOP) + + def writeMapBegin(self, ktype, vtype, size): + self.writeByte(ktype) + self.writeByte(vtype) + self.writeI32(size) + + def writeMapEnd(self): + pass + + def writeListBegin(self, etype, size): + self.writeByte(etype) + self.writeI32(size) + + def writeListEnd(self): + pass + + def writeSetBegin(self, etype, size): + self.writeByte(etype) + self.writeI32(size) + + def writeSetEnd(self): + pass + + def writeBool(self, bool): + if bool: + self.writeByte(1) + else: + self.writeByte(0) + + def writeByte(self, byte): + buff = pack("!b", byte) + self.trans.write(buff) + + def writeI16(self, i16): + buff = pack("!h", i16) + self.trans.write(buff) + + def writeI32(self, i32): + buff = pack("!i", i32) + self.trans.write(buff) + + def writeI64(self, i64): + buff = pack("!q", i64) + self.trans.write(buff) + + def writeDouble(self, dub): + buff = pack("!d", dub) + self.trans.write(buff) + + def writeString(self, str): + self.writeI32(len(str)) + self.trans.write(str) + + def readMessageBegin(self): + sz = self.readI32() + if sz < 0: + version = sz & TBinaryProtocol.VERSION_MASK + if version != TBinaryProtocol.VERSION_1: + raise TProtocolException( + type=TProtocolException.BAD_VERSION, + message='Bad version in readMessageBegin: %d' % (sz)) + type = sz & TBinaryProtocol.TYPE_MASK + name = self.readString() + seqid = self.readI32() + else: + if self.strictRead: + raise TProtocolException(type=TProtocolException.BAD_VERSION, + message='No protocol version header') + name = self.trans.readAll(sz) + type = self.readByte() + seqid = self.readI32() + return (name, type, seqid) + + def readMessageEnd(self): + pass + + def readStructBegin(self): + pass + + def readStructEnd(self): + pass + + def readFieldBegin(self): + type = self.readByte() + if type == TType.STOP: + return (None, type, 0) + id = self.readI16() + return (None, type, id) + + def readFieldEnd(self): + pass + + def readMapBegin(self): + ktype = self.readByte() + vtype = self.readByte() + size = self.readI32() + return (ktype, vtype, size) + + def readMapEnd(self): + pass + + def readListBegin(self): + etype = self.readByte() + size = self.readI32() + return (etype, size) + + def readListEnd(self): + pass + + def readSetBegin(self): + etype = self.readByte() + size = self.readI32() + return (etype, size) + + def readSetEnd(self): + pass + + def readBool(self): + byte = self.readByte() + if byte == 0: + return False + return True + + def readByte(self): + buff = self.trans.readAll(1) + val, = unpack('!b', buff) + return val + + def readI16(self): + buff = self.trans.readAll(2) + val, = unpack('!h', buff) + return val + + def readI32(self): + buff = self.trans.readAll(4) + val, = unpack('!i', buff) + return val + + def readI64(self): + buff = self.trans.readAll(8) + val, = unpack('!q', buff) + return val + + def readDouble(self): + buff = self.trans.readAll(8) + val, = unpack('!d', buff) + return val + + def readString(self): + len = self.readI32() + str = self.trans.readAll(len) + return str + + +class TBinaryProtocolFactory: + def __init__(self, strictRead=False, strictWrite=True): + self.strictRead = strictRead + self.strictWrite = strictWrite + + def getProtocol(self, trans): + prot = TBinaryProtocol(trans, self.strictRead, self.strictWrite) + return prot + + +class TBinaryProtocolAccelerated(TBinaryProtocol): + """C-Accelerated version of TBinaryProtocol. + + This class does not override any of TBinaryProtocol's methods, + but the generated code recognizes it directly and will call into + our C module to do the encoding, bypassing this object entirely. + We inherit from TBinaryProtocol so that the normal TBinaryProtocol + encoding can happen if the fastbinary module doesn't work for some + reason. (TODO(dreiss): Make this happen sanely in more cases.) + + In order to take advantage of the C module, just use + TBinaryProtocolAccelerated instead of TBinaryProtocol. + + NOTE: This code was contributed by an external developer. + The internal Thrift team has reviewed and tested it, + but we cannot guarantee that it is production-ready. + Please feel free to report bugs and/or success stories + to the public mailing list. + """ + pass + + +class TBinaryProtocolAcceleratedFactory: + def getProtocol(self, trans): + return TBinaryProtocolAccelerated(trans) diff --git a/pyload/lib/thrift/protocol/TCompactProtocol.py b/pyload/lib/thrift/protocol/TCompactProtocol.py new file mode 100644 index 000000000..cdec60773 --- /dev/null +++ b/pyload/lib/thrift/protocol/TCompactProtocol.py @@ -0,0 +1,403 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +from TProtocol import * +from struct import pack, unpack + +__all__ = ['TCompactProtocol', 'TCompactProtocolFactory'] + +CLEAR = 0 +FIELD_WRITE = 1 +VALUE_WRITE = 2 +CONTAINER_WRITE = 3 +BOOL_WRITE = 4 +FIELD_READ = 5 +CONTAINER_READ = 6 +VALUE_READ = 7 +BOOL_READ = 8 + + +def make_helper(v_from, container): + def helper(func): + def nested(self, *args, **kwargs): + assert self.state in (v_from, container), (self.state, v_from, container) + return func(self, *args, **kwargs) + return nested + return helper +writer = make_helper(VALUE_WRITE, CONTAINER_WRITE) +reader = make_helper(VALUE_READ, CONTAINER_READ) + + +def makeZigZag(n, bits): + return (n << 1) ^ (n >> (bits - 1)) + + +def fromZigZag(n): + return (n >> 1) ^ -(n & 1) + + +def writeVarint(trans, n): + out = [] + while True: + if n & ~0x7f == 0: + out.append(n) + break + else: + out.append((n & 0xff) | 0x80) + n = n >> 7 + trans.write(''.join(map(chr, out))) + + +def readVarint(trans): + result = 0 + shift = 0 + while True: + x = trans.readAll(1) + byte = ord(x) + result |= (byte & 0x7f) << shift + if byte >> 7 == 0: + return result + shift += 7 + + +class CompactType: + STOP = 0x00 + TRUE = 0x01 + FALSE = 0x02 + BYTE = 0x03 + I16 = 0x04 + I32 = 0x05 + I64 = 0x06 + DOUBLE = 0x07 + BINARY = 0x08 + LIST = 0x09 + SET = 0x0A + MAP = 0x0B + STRUCT = 0x0C + +CTYPES = {TType.STOP: CompactType.STOP, + TType.BOOL: CompactType.TRUE, # used for collection + TType.BYTE: CompactType.BYTE, + TType.I16: CompactType.I16, + TType.I32: CompactType.I32, + TType.I64: CompactType.I64, + TType.DOUBLE: CompactType.DOUBLE, + TType.STRING: CompactType.BINARY, + TType.STRUCT: CompactType.STRUCT, + TType.LIST: CompactType.LIST, + TType.SET: CompactType.SET, + TType.MAP: CompactType.MAP + } + +TTYPES = {} +for k, v in CTYPES.items(): + TTYPES[v] = k +TTYPES[CompactType.FALSE] = TType.BOOL +del k +del v + + +class TCompactProtocol(TProtocolBase): + """Compact implementation of the Thrift protocol driver.""" + + PROTOCOL_ID = 0x82 + VERSION = 1 + VERSION_MASK = 0x1f + TYPE_MASK = 0xe0 + TYPE_SHIFT_AMOUNT = 5 + + def __init__(self, trans): + TProtocolBase.__init__(self, trans) + self.state = CLEAR + self.__last_fid = 0 + self.__bool_fid = None + self.__bool_value = None + self.__structs = [] + self.__containers = [] + + def __writeVarint(self, n): + writeVarint(self.trans, n) + + def writeMessageBegin(self, name, type, seqid): + assert self.state == CLEAR + self.__writeUByte(self.PROTOCOL_ID) + self.__writeUByte(self.VERSION | (type << self.TYPE_SHIFT_AMOUNT)) + self.__writeVarint(seqid) + self.__writeString(name) + self.state = VALUE_WRITE + + def writeMessageEnd(self): + assert self.state == VALUE_WRITE + self.state = CLEAR + + def writeStructBegin(self, name): + assert self.state in (CLEAR, CONTAINER_WRITE, VALUE_WRITE), self.state + self.__structs.append((self.state, self.__last_fid)) + self.state = FIELD_WRITE + self.__last_fid = 0 + + def writeStructEnd(self): + assert self.state == FIELD_WRITE + self.state, self.__last_fid = self.__structs.pop() + + def writeFieldStop(self): + self.__writeByte(0) + + def __writeFieldHeader(self, type, fid): + delta = fid - self.__last_fid + if 0 < delta <= 15: + self.__writeUByte(delta << 4 | type) + else: + self.__writeByte(type) + self.__writeI16(fid) + self.__last_fid = fid + + def writeFieldBegin(self, name, type, fid): + assert self.state == FIELD_WRITE, self.state + if type == TType.BOOL: + self.state = BOOL_WRITE + self.__bool_fid = fid + else: + self.state = VALUE_WRITE + self.__writeFieldHeader(CTYPES[type], fid) + + def writeFieldEnd(self): + assert self.state in (VALUE_WRITE, BOOL_WRITE), self.state + self.state = FIELD_WRITE + + def __writeUByte(self, byte): + self.trans.write(pack('!B', byte)) + + def __writeByte(self, byte): + self.trans.write(pack('!b', byte)) + + def __writeI16(self, i16): + self.__writeVarint(makeZigZag(i16, 16)) + + def __writeSize(self, i32): + self.__writeVarint(i32) + + def writeCollectionBegin(self, etype, size): + assert self.state in (VALUE_WRITE, CONTAINER_WRITE), self.state + if size <= 14: + self.__writeUByte(size << 4 | CTYPES[etype]) + else: + self.__writeUByte(0xf0 | CTYPES[etype]) + self.__writeSize(size) + self.__containers.append(self.state) + self.state = CONTAINER_WRITE + writeSetBegin = writeCollectionBegin + writeListBegin = writeCollectionBegin + + def writeMapBegin(self, ktype, vtype, size): + assert self.state in (VALUE_WRITE, CONTAINER_WRITE), self.state + if size == 0: + self.__writeByte(0) + else: + self.__writeSize(size) + self.__writeUByte(CTYPES[ktype] << 4 | CTYPES[vtype]) + self.__containers.append(self.state) + self.state = CONTAINER_WRITE + + def writeCollectionEnd(self): + assert self.state == CONTAINER_WRITE, self.state + self.state = self.__containers.pop() + writeMapEnd = writeCollectionEnd + writeSetEnd = writeCollectionEnd + writeListEnd = writeCollectionEnd + + def writeBool(self, bool): + if self.state == BOOL_WRITE: + if bool: + ctype = CompactType.TRUE + else: + ctype = CompactType.FALSE + self.__writeFieldHeader(ctype, self.__bool_fid) + elif self.state == CONTAINER_WRITE: + if bool: + self.__writeByte(CompactType.TRUE) + else: + self.__writeByte(CompactType.FALSE) + else: + raise AssertionError("Invalid state in compact protocol") + + writeByte = writer(__writeByte) + writeI16 = writer(__writeI16) + + @writer + def writeI32(self, i32): + self.__writeVarint(makeZigZag(i32, 32)) + + @writer + def writeI64(self, i64): + self.__writeVarint(makeZigZag(i64, 64)) + + @writer + def writeDouble(self, dub): + self.trans.write(pack('!d', dub)) + + def __writeString(self, s): + self.__writeSize(len(s)) + self.trans.write(s) + writeString = writer(__writeString) + + def readFieldBegin(self): + assert self.state == FIELD_READ, self.state + type = self.__readUByte() + if type & 0x0f == TType.STOP: + return (None, 0, 0) + delta = type >> 4 + if delta == 0: + fid = self.__readI16() + else: + fid = self.__last_fid + delta + self.__last_fid = fid + type = type & 0x0f + if type == CompactType.TRUE: + self.state = BOOL_READ + self.__bool_value = True + elif type == CompactType.FALSE: + self.state = BOOL_READ + self.__bool_value = False + else: + self.state = VALUE_READ + return (None, self.__getTType(type), fid) + + def readFieldEnd(self): + assert self.state in (VALUE_READ, BOOL_READ), self.state + self.state = FIELD_READ + + def __readUByte(self): + result, = unpack('!B', self.trans.readAll(1)) + return result + + def __readByte(self): + result, = unpack('!b', self.trans.readAll(1)) + return result + + def __readVarint(self): + return readVarint(self.trans) + + def __readZigZag(self): + return fromZigZag(self.__readVarint()) + + def __readSize(self): + result = self.__readVarint() + if result < 0: + raise TException("Length < 0") + return result + + def readMessageBegin(self): + assert self.state == CLEAR + proto_id = self.__readUByte() + if proto_id != self.PROTOCOL_ID: + raise TProtocolException(TProtocolException.BAD_VERSION, + 'Bad protocol id in the message: %d' % proto_id) + ver_type = self.__readUByte() + type = (ver_type & self.TYPE_MASK) >> self.TYPE_SHIFT_AMOUNT + version = ver_type & self.VERSION_MASK + if version != self.VERSION: + raise TProtocolException(TProtocolException.BAD_VERSION, + 'Bad version: %d (expect %d)' % (version, self.VERSION)) + seqid = self.__readVarint() + name = self.__readString() + return (name, type, seqid) + + def readMessageEnd(self): + assert self.state == CLEAR + assert len(self.__structs) == 0 + + def readStructBegin(self): + assert self.state in (CLEAR, CONTAINER_READ, VALUE_READ), self.state + self.__structs.append((self.state, self.__last_fid)) + self.state = FIELD_READ + self.__last_fid = 0 + + def readStructEnd(self): + assert self.state == FIELD_READ + self.state, self.__last_fid = self.__structs.pop() + + def readCollectionBegin(self): + assert self.state in (VALUE_READ, CONTAINER_READ), self.state + size_type = self.__readUByte() + size = size_type >> 4 + type = self.__getTType(size_type) + if size == 15: + size = self.__readSize() + self.__containers.append(self.state) + self.state = CONTAINER_READ + return type, size + readSetBegin = readCollectionBegin + readListBegin = readCollectionBegin + + def readMapBegin(self): + assert self.state in (VALUE_READ, CONTAINER_READ), self.state + size = self.__readSize() + types = 0 + if size > 0: + types = self.__readUByte() + vtype = self.__getTType(types) + ktype = self.__getTType(types >> 4) + self.__containers.append(self.state) + self.state = CONTAINER_READ + return (ktype, vtype, size) + + def readCollectionEnd(self): + assert self.state == CONTAINER_READ, self.state + self.state = self.__containers.pop() + readSetEnd = readCollectionEnd + readListEnd = readCollectionEnd + readMapEnd = readCollectionEnd + + def readBool(self): + if self.state == BOOL_READ: + return self.__bool_value == CompactType.TRUE + elif self.state == CONTAINER_READ: + return self.__readByte() == CompactType.TRUE + else: + raise AssertionError("Invalid state in compact protocol: %d" % + self.state) + + readByte = reader(__readByte) + __readI16 = __readZigZag + readI16 = reader(__readZigZag) + readI32 = reader(__readZigZag) + readI64 = reader(__readZigZag) + + @reader + def readDouble(self): + buff = self.trans.readAll(8) + val, = unpack('!d', buff) + return val + + def __readString(self): + len = self.__readSize() + return self.trans.readAll(len) + readString = reader(__readString) + + def __getTType(self, byte): + return TTYPES[byte & 0x0f] + + +class TCompactProtocolFactory: + def __init__(self): + pass + + def getProtocol(self, trans): + return TCompactProtocol(trans) diff --git a/pyload/lib/thrift/protocol/TJSONProtocol.py b/pyload/lib/thrift/protocol/TJSONProtocol.py new file mode 100644 index 000000000..3048197d4 --- /dev/null +++ b/pyload/lib/thrift/protocol/TJSONProtocol.py @@ -0,0 +1,550 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +from TProtocol import TType, TProtocolBase, TProtocolException +import base64 +import json +import math + +__all__ = ['TJSONProtocol', + 'TJSONProtocolFactory', + 'TSimpleJSONProtocol', + 'TSimpleJSONProtocolFactory'] + +VERSION = 1 + +COMMA = ',' +COLON = ':' +LBRACE = '{' +RBRACE = '}' +LBRACKET = '[' +RBRACKET = ']' +QUOTE = '"' +BACKSLASH = '\\' +ZERO = '0' + +ESCSEQ = '\\u00' +ESCAPE_CHAR = '"\\bfnrt' +ESCAPE_CHAR_VALS = ['"', '\\', '\b', '\f', '\n', '\r', '\t'] +NUMERIC_CHAR = '+-.0123456789Ee' + +CTYPES = {TType.BOOL: 'tf', + TType.BYTE: 'i8', + TType.I16: 'i16', + TType.I32: 'i32', + TType.I64: 'i64', + TType.DOUBLE: 'dbl', + TType.STRING: 'str', + TType.STRUCT: 'rec', + TType.LIST: 'lst', + TType.SET: 'set', + TType.MAP: 'map'} + +JTYPES = {} +for key in CTYPES.keys(): + JTYPES[CTYPES[key]] = key + + +class JSONBaseContext(object): + + def __init__(self, protocol): + self.protocol = protocol + self.first = True + + def doIO(self, function): + pass + + def write(self): + pass + + def read(self): + pass + + def escapeNum(self): + return False + + def __str__(self): + return self.__class__.__name__ + + +class JSONListContext(JSONBaseContext): + + def doIO(self, function): + if self.first is True: + self.first = False + else: + function(COMMA) + + def write(self): + self.doIO(self.protocol.trans.write) + + def read(self): + self.doIO(self.protocol.readJSONSyntaxChar) + + +class JSONPairContext(JSONBaseContext): + + def __init__(self, protocol): + super(JSONPairContext, self).__init__(protocol) + self.colon = True + + def doIO(self, function): + if self.first: + self.first = False + self.colon = True + else: + function(COLON if self.colon else COMMA) + self.colon = not self.colon + + def write(self): + self.doIO(self.protocol.trans.write) + + def read(self): + self.doIO(self.protocol.readJSONSyntaxChar) + + def escapeNum(self): + return self.colon + + def __str__(self): + return '%s, colon=%s' % (self.__class__.__name__, self.colon) + + +class LookaheadReader(): + hasData = False + data = '' + + def __init__(self, protocol): + self.protocol = protocol + + def read(self): + if self.hasData is True: + self.hasData = False + else: + self.data = self.protocol.trans.read(1) + return self.data + + def peek(self): + if self.hasData is False: + self.data = self.protocol.trans.read(1) + self.hasData = True + return self.data + +class TJSONProtocolBase(TProtocolBase): + + def __init__(self, trans): + TProtocolBase.__init__(self, trans) + self.resetWriteContext() + self.resetReadContext() + + def resetWriteContext(self): + self.context = JSONBaseContext(self) + self.contextStack = [self.context] + + def resetReadContext(self): + self.resetWriteContext() + self.reader = LookaheadReader(self) + + def pushContext(self, ctx): + self.contextStack.append(ctx) + self.context = ctx + + def popContext(self): + self.contextStack.pop() + if self.contextStack: + self.context = self.contextStack[-1] + else: + self.context = JSONBaseContext(self) + + def writeJSONString(self, string): + self.context.write() + self.trans.write(json.dumps(string)) + + def writeJSONNumber(self, number): + self.context.write() + jsNumber = str(number) + if self.context.escapeNum(): + jsNumber = "%s%s%s" % (QUOTE, jsNumber, QUOTE) + self.trans.write(jsNumber) + + def writeJSONBase64(self, binary): + self.context.write() + self.trans.write(QUOTE) + self.trans.write(base64.b64encode(binary)) + self.trans.write(QUOTE) + + def writeJSONObjectStart(self): + self.context.write() + self.trans.write(LBRACE) + self.pushContext(JSONPairContext(self)) + + def writeJSONObjectEnd(self): + self.popContext() + self.trans.write(RBRACE) + + def writeJSONArrayStart(self): + self.context.write() + self.trans.write(LBRACKET) + self.pushContext(JSONListContext(self)) + + def writeJSONArrayEnd(self): + self.popContext() + self.trans.write(RBRACKET) + + def readJSONSyntaxChar(self, character): + current = self.reader.read() + if character != current: + raise TProtocolException(TProtocolException.INVALID_DATA, + "Unexpected character: %s" % current) + + def readJSONString(self, skipContext): + string = [] + if skipContext is False: + self.context.read() + self.readJSONSyntaxChar(QUOTE) + while True: + character = self.reader.read() + if character == QUOTE: + break + if character == ESCSEQ[0]: + character = self.reader.read() + if character == ESCSEQ[1]: + self.readJSONSyntaxChar(ZERO) + self.readJSONSyntaxChar(ZERO) + character = json.JSONDecoder().decode('"\u00%s"' % self.trans.read(2)) + else: + off = ESCAPE_CHAR.find(character) + if off == -1: + raise TProtocolException(TProtocolException.INVALID_DATA, + "Expected control char") + character = ESCAPE_CHAR_VALS[off] + string.append(character) + return ''.join(string) + + def isJSONNumeric(self, character): + return (True if NUMERIC_CHAR.find(character) != - 1 else False) + + def readJSONQuotes(self): + if (self.context.escapeNum()): + self.readJSONSyntaxChar(QUOTE) + + def readJSONNumericChars(self): + numeric = [] + while True: + character = self.reader.peek() + if self.isJSONNumeric(character) is False: + break + numeric.append(self.reader.read()) + return ''.join(numeric) + + def readJSONInteger(self): + self.context.read() + self.readJSONQuotes() + numeric = self.readJSONNumericChars() + self.readJSONQuotes() + try: + return int(numeric) + except ValueError: + raise TProtocolException(TProtocolException.INVALID_DATA, + "Bad data encounted in numeric data") + + def readJSONDouble(self): + self.context.read() + if self.reader.peek() == QUOTE: + string = self.readJSONString(True) + try: + double = float(string) + if (self.context.escapeNum is False and + not math.isinf(double) and + not math.isnan(double)): + raise TProtocolException(TProtocolException.INVALID_DATA, + "Numeric data unexpectedly quoted") + return double + except ValueError: + raise TProtocolException(TProtocolException.INVALID_DATA, + "Bad data encounted in numeric data") + else: + if self.context.escapeNum() is True: + self.readJSONSyntaxChar(QUOTE) + try: + return float(self.readJSONNumericChars()) + except ValueError: + raise TProtocolException(TProtocolException.INVALID_DATA, + "Bad data encounted in numeric data") + + def readJSONBase64(self): + string = self.readJSONString(False) + return base64.b64decode(string) + + def readJSONObjectStart(self): + self.context.read() + self.readJSONSyntaxChar(LBRACE) + self.pushContext(JSONPairContext(self)) + + def readJSONObjectEnd(self): + self.readJSONSyntaxChar(RBRACE) + self.popContext() + + def readJSONArrayStart(self): + self.context.read() + self.readJSONSyntaxChar(LBRACKET) + self.pushContext(JSONListContext(self)) + + def readJSONArrayEnd(self): + self.readJSONSyntaxChar(RBRACKET) + self.popContext() + + +class TJSONProtocol(TJSONProtocolBase): + + def readMessageBegin(self): + self.resetReadContext() + self.readJSONArrayStart() + if self.readJSONInteger() != VERSION: + raise TProtocolException(TProtocolException.BAD_VERSION, + "Message contained bad version.") + name = self.readJSONString(False) + typen = self.readJSONInteger() + seqid = self.readJSONInteger() + return (name, typen, seqid) + + def readMessageEnd(self): + self.readJSONArrayEnd() + + def readStructBegin(self): + self.readJSONObjectStart() + + def readStructEnd(self): + self.readJSONObjectEnd() + + def readFieldBegin(self): + character = self.reader.peek() + ttype = 0 + id = 0 + if character == RBRACE: + ttype = TType.STOP + else: + id = self.readJSONInteger() + self.readJSONObjectStart() + ttype = JTYPES[self.readJSONString(False)] + return (None, ttype, id) + + def readFieldEnd(self): + self.readJSONObjectEnd() + + def readMapBegin(self): + self.readJSONArrayStart() + keyType = JTYPES[self.readJSONString(False)] + valueType = JTYPES[self.readJSONString(False)] + size = self.readJSONInteger() + self.readJSONObjectStart() + return (keyType, valueType, size) + + def readMapEnd(self): + self.readJSONObjectEnd() + self.readJSONArrayEnd() + + def readCollectionBegin(self): + self.readJSONArrayStart() + elemType = JTYPES[self.readJSONString(False)] + size = self.readJSONInteger() + return (elemType, size) + readListBegin = readCollectionBegin + readSetBegin = readCollectionBegin + + def readCollectionEnd(self): + self.readJSONArrayEnd() + readSetEnd = readCollectionEnd + readListEnd = readCollectionEnd + + def readBool(self): + return (False if self.readJSONInteger() == 0 else True) + + def readNumber(self): + return self.readJSONInteger() + readByte = readNumber + readI16 = readNumber + readI32 = readNumber + readI64 = readNumber + + def readDouble(self): + return self.readJSONDouble() + + def readString(self): + return self.readJSONString(False) + + def readBinary(self): + return self.readJSONBase64() + + def writeMessageBegin(self, name, request_type, seqid): + self.resetWriteContext() + self.writeJSONArrayStart() + self.writeJSONNumber(VERSION) + self.writeJSONString(name) + self.writeJSONNumber(request_type) + self.writeJSONNumber(seqid) + + def writeMessageEnd(self): + self.writeJSONArrayEnd() + + def writeStructBegin(self, name): + self.writeJSONObjectStart() + + def writeStructEnd(self): + self.writeJSONObjectEnd() + + def writeFieldBegin(self, name, ttype, id): + self.writeJSONNumber(id) + self.writeJSONObjectStart() + self.writeJSONString(CTYPES[ttype]) + + def writeFieldEnd(self): + self.writeJSONObjectEnd() + + def writeFieldStop(self): + pass + + def writeMapBegin(self, ktype, vtype, size): + self.writeJSONArrayStart() + self.writeJSONString(CTYPES[ktype]) + self.writeJSONString(CTYPES[vtype]) + self.writeJSONNumber(size) + self.writeJSONObjectStart() + + def writeMapEnd(self): + self.writeJSONObjectEnd() + self.writeJSONArrayEnd() + + def writeListBegin(self, etype, size): + self.writeJSONArrayStart() + self.writeJSONString(CTYPES[etype]) + self.writeJSONNumber(size) + + def writeListEnd(self): + self.writeJSONArrayEnd() + + def writeSetBegin(self, etype, size): + self.writeJSONArrayStart() + self.writeJSONString(CTYPES[etype]) + self.writeJSONNumber(size) + + def writeSetEnd(self): + self.writeJSONArrayEnd() + + def writeBool(self, boolean): + self.writeJSONNumber(1 if boolean is True else 0) + + def writeInteger(self, integer): + self.writeJSONNumber(integer) + writeByte = writeInteger + writeI16 = writeInteger + writeI32 = writeInteger + writeI64 = writeInteger + + def writeDouble(self, dbl): + self.writeJSONNumber(dbl) + + def writeString(self, string): + self.writeJSONString(string) + + def writeBinary(self, binary): + self.writeJSONBase64(binary) + + +class TJSONProtocolFactory: + + def getProtocol(self, trans): + return TJSONProtocol(trans) + + +class TSimpleJSONProtocol(TJSONProtocolBase): + """Simple, readable, write-only JSON protocol. + + Useful for interacting with scripting languages. + """ + + def readMessageBegin(self): + raise NotImplementedError() + + def readMessageEnd(self): + raise NotImplementedError() + + def readStructBegin(self): + raise NotImplementedError() + + def readStructEnd(self): + raise NotImplementedError() + + def writeMessageBegin(self, name, request_type, seqid): + self.resetWriteContext() + + def writeMessageEnd(self): + pass + + def writeStructBegin(self, name): + self.writeJSONObjectStart() + + def writeStructEnd(self): + self.writeJSONObjectEnd() + + def writeFieldBegin(self, name, ttype, fid): + self.writeJSONString(name) + + def writeFieldEnd(self): + pass + + def writeMapBegin(self, ktype, vtype, size): + self.writeJSONObjectStart() + + def writeMapEnd(self): + self.writeJSONObjectEnd() + + def _writeCollectionBegin(self, etype, size): + self.writeJSONArrayStart() + + def _writeCollectionEnd(self): + self.writeJSONArrayEnd() + writeListBegin = _writeCollectionBegin + writeListEnd = _writeCollectionEnd + writeSetBegin = _writeCollectionBegin + writeSetEnd = _writeCollectionEnd + + def writeInteger(self, integer): + self.writeJSONNumber(integer) + writeByte = writeInteger + writeI16 = writeInteger + writeI32 = writeInteger + writeI64 = writeInteger + + def writeBool(self, boolean): + self.writeJSONNumber(1 if boolean is True else 0) + + def writeDouble(self, dbl): + self.writeJSONNumber(dbl) + + def writeString(self, string): + self.writeJSONString(string) + + def writeBinary(self, binary): + self.writeJSONBase64(binary) + + +class TSimpleJSONProtocolFactory(object): + + def getProtocol(self, trans): + return TSimpleJSONProtocol(trans) diff --git a/pyload/lib/thrift/protocol/TProtocol.py b/pyload/lib/thrift/protocol/TProtocol.py new file mode 100644 index 000000000..dc2b095de --- /dev/null +++ b/pyload/lib/thrift/protocol/TProtocol.py @@ -0,0 +1,406 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +from thrift.Thrift import * + + +class TProtocolException(TException): + """Custom Protocol Exception class""" + + UNKNOWN = 0 + INVALID_DATA = 1 + NEGATIVE_SIZE = 2 + SIZE_LIMIT = 3 + BAD_VERSION = 4 + + def __init__(self, type=UNKNOWN, message=None): + TException.__init__(self, message) + self.type = type + + +class TProtocolBase: + """Base class for Thrift protocol driver.""" + + def __init__(self, trans): + self.trans = trans + + def writeMessageBegin(self, name, ttype, seqid): + pass + + def writeMessageEnd(self): + pass + + def writeStructBegin(self, name): + pass + + def writeStructEnd(self): + pass + + def writeFieldBegin(self, name, ttype, fid): + pass + + def writeFieldEnd(self): + pass + + def writeFieldStop(self): + pass + + def writeMapBegin(self, ktype, vtype, size): + pass + + def writeMapEnd(self): + pass + + def writeListBegin(self, etype, size): + pass + + def writeListEnd(self): + pass + + def writeSetBegin(self, etype, size): + pass + + def writeSetEnd(self): + pass + + def writeBool(self, bool_val): + pass + + def writeByte(self, byte): + pass + + def writeI16(self, i16): + pass + + def writeI32(self, i32): + pass + + def writeI64(self, i64): + pass + + def writeDouble(self, dub): + pass + + def writeString(self, str_val): + pass + + def readMessageBegin(self): + pass + + def readMessageEnd(self): + pass + + def readStructBegin(self): + pass + + def readStructEnd(self): + pass + + def readFieldBegin(self): + pass + + def readFieldEnd(self): + pass + + def readMapBegin(self): + pass + + def readMapEnd(self): + pass + + def readListBegin(self): + pass + + def readListEnd(self): + pass + + def readSetBegin(self): + pass + + def readSetEnd(self): + pass + + def readBool(self): + pass + + def readByte(self): + pass + + def readI16(self): + pass + + def readI32(self): + pass + + def readI64(self): + pass + + def readDouble(self): + pass + + def readString(self): + pass + + def skip(self, ttype): + if ttype == TType.STOP: + return + elif ttype == TType.BOOL: + self.readBool() + elif ttype == TType.BYTE: + self.readByte() + elif ttype == TType.I16: + self.readI16() + elif ttype == TType.I32: + self.readI32() + elif ttype == TType.I64: + self.readI64() + elif ttype == TType.DOUBLE: + self.readDouble() + elif ttype == TType.STRING: + self.readString() + elif ttype == TType.STRUCT: + name = self.readStructBegin() + while True: + (name, ttype, id) = self.readFieldBegin() + if ttype == TType.STOP: + break + self.skip(ttype) + self.readFieldEnd() + self.readStructEnd() + elif ttype == TType.MAP: + (ktype, vtype, size) = self.readMapBegin() + for i in xrange(size): + self.skip(ktype) + self.skip(vtype) + self.readMapEnd() + elif ttype == TType.SET: + (etype, size) = self.readSetBegin() + for i in xrange(size): + self.skip(etype) + self.readSetEnd() + elif ttype == TType.LIST: + (etype, size) = self.readListBegin() + for i in xrange(size): + self.skip(etype) + self.readListEnd() + + # tuple of: ( 'reader method' name, is_container bool, 'writer_method' name ) + _TTYPE_HANDLERS = ( + (None, None, False), # 0 TType.STOP + (None, None, False), # 1 TType.VOID # TODO: handle void? + ('readBool', 'writeBool', False), # 2 TType.BOOL + ('readByte', 'writeByte', False), # 3 TType.BYTE and I08 + ('readDouble', 'writeDouble', False), # 4 TType.DOUBLE + (None, None, False), # 5 undefined + ('readI16', 'writeI16', False), # 6 TType.I16 + (None, None, False), # 7 undefined + ('readI32', 'writeI32', False), # 8 TType.I32 + (None, None, False), # 9 undefined + ('readI64', 'writeI64', False), # 10 TType.I64 + ('readString', 'writeString', False), # 11 TType.STRING and UTF7 + ('readContainerStruct', 'writeContainerStruct', True), # 12 *.STRUCT + ('readContainerMap', 'writeContainerMap', True), # 13 TType.MAP + ('readContainerSet', 'writeContainerSet', True), # 14 TType.SET + ('readContainerList', 'writeContainerList', True), # 15 TType.LIST + (None, None, False), # 16 TType.UTF8 # TODO: handle utf8 types? + (None, None, False) # 17 TType.UTF16 # TODO: handle utf16 types? + ) + + def readFieldByTType(self, ttype, spec): + try: + (r_handler, w_handler, is_container) = self._TTYPE_HANDLERS[ttype] + except IndexError: + raise TProtocolException(type=TProtocolException.INVALID_DATA, + message='Invalid field type %d' % (ttype)) + if r_handler is None: + raise TProtocolException(type=TProtocolException.INVALID_DATA, + message='Invalid field type %d' % (ttype)) + reader = getattr(self, r_handler) + if not is_container: + return reader() + return reader(spec) + + def readContainerList(self, spec): + results = [] + ttype, tspec = spec[0], spec[1] + r_handler = self._TTYPE_HANDLERS[ttype][0] + reader = getattr(self, r_handler) + (list_type, list_len) = self.readListBegin() + if tspec is None: + # list values are simple types + for idx in xrange(list_len): + results.append(reader()) + else: + # this is like an inlined readFieldByTType + container_reader = self._TTYPE_HANDLERS[list_type][0] + val_reader = getattr(self, container_reader) + for idx in xrange(list_len): + val = val_reader(tspec) + results.append(val) + self.readListEnd() + return results + + def readContainerSet(self, spec): + results = set() + ttype, tspec = spec[0], spec[1] + r_handler = self._TTYPE_HANDLERS[ttype][0] + reader = getattr(self, r_handler) + (set_type, set_len) = self.readSetBegin() + if tspec is None: + # set members are simple types + for idx in xrange(set_len): + results.add(reader()) + else: + container_reader = self._TTYPE_HANDLERS[set_type][0] + val_reader = getattr(self, container_reader) + for idx in xrange(set_len): + results.add(val_reader(tspec)) + self.readSetEnd() + return results + + def readContainerStruct(self, spec): + (obj_class, obj_spec) = spec + obj = obj_class() + obj.read(self) + return obj + + def readContainerMap(self, spec): + results = dict() + key_ttype, key_spec = spec[0], spec[1] + val_ttype, val_spec = spec[2], spec[3] + (map_ktype, map_vtype, map_len) = self.readMapBegin() + # TODO: compare types we just decoded with thrift_spec and + # abort/skip if types disagree + key_reader = getattr(self, self._TTYPE_HANDLERS[key_ttype][0]) + val_reader = getattr(self, self._TTYPE_HANDLERS[val_ttype][0]) + # list values are simple types + for idx in xrange(map_len): + if key_spec is None: + k_val = key_reader() + else: + k_val = self.readFieldByTType(key_ttype, key_spec) + if val_spec is None: + v_val = val_reader() + else: + v_val = self.readFieldByTType(val_ttype, val_spec) + # this raises a TypeError with unhashable keys types + # i.e. this fails: d=dict(); d[[0,1]] = 2 + results[k_val] = v_val + self.readMapEnd() + return results + + def readStruct(self, obj, thrift_spec): + self.readStructBegin() + while True: + (fname, ftype, fid) = self.readFieldBegin() + if ftype == TType.STOP: + break + try: + field = thrift_spec[fid] + except IndexError: + self.skip(ftype) + else: + if field is not None and ftype == field[1]: + fname = field[2] + fspec = field[3] + val = self.readFieldByTType(ftype, fspec) + setattr(obj, fname, val) + else: + self.skip(ftype) + self.readFieldEnd() + self.readStructEnd() + + def writeContainerStruct(self, val, spec): + val.write(self) + + def writeContainerList(self, val, spec): + self.writeListBegin(spec[0], len(val)) + r_handler, w_handler, is_container = self._TTYPE_HANDLERS[spec[0]] + e_writer = getattr(self, w_handler) + if not is_container: + for elem in val: + e_writer(elem) + else: + for elem in val: + e_writer(elem, spec[1]) + self.writeListEnd() + + def writeContainerSet(self, val, spec): + self.writeSetBegin(spec[0], len(val)) + r_handler, w_handler, is_container = self._TTYPE_HANDLERS[spec[0]] + e_writer = getattr(self, w_handler) + if not is_container: + for elem in val: + e_writer(elem) + else: + for elem in val: + e_writer(elem, spec[1]) + self.writeSetEnd() + + def writeContainerMap(self, val, spec): + k_type = spec[0] + v_type = spec[2] + ignore, ktype_name, k_is_container = self._TTYPE_HANDLERS[k_type] + ignore, vtype_name, v_is_container = self._TTYPE_HANDLERS[v_type] + k_writer = getattr(self, ktype_name) + v_writer = getattr(self, vtype_name) + self.writeMapBegin(k_type, v_type, len(val)) + for m_key, m_val in val.iteritems(): + if not k_is_container: + k_writer(m_key) + else: + k_writer(m_key, spec[1]) + if not v_is_container: + v_writer(m_val) + else: + v_writer(m_val, spec[3]) + self.writeMapEnd() + + def writeStruct(self, obj, thrift_spec): + self.writeStructBegin(obj.__class__.__name__) + for field in thrift_spec: + if field is None: + continue + fname = field[2] + val = getattr(obj, fname) + if val is None: + # skip writing out unset fields + continue + fid = field[0] + ftype = field[1] + fspec = field[3] + # get the writer method for this value + self.writeFieldBegin(fname, ftype, fid) + self.writeFieldByTType(ftype, val, fspec) + self.writeFieldEnd() + self.writeFieldStop() + self.writeStructEnd() + + def writeFieldByTType(self, ttype, val, spec): + r_handler, w_handler, is_container = self._TTYPE_HANDLERS[ttype] + writer = getattr(self, w_handler) + if is_container: + writer(val, spec) + else: + writer(val) + + +class TProtocolFactory: + def getProtocol(self, trans): + pass diff --git a/pyload/lib/thrift/protocol/__init__.py b/pyload/lib/thrift/protocol/__init__.py new file mode 100644 index 000000000..7eefb458a --- /dev/null +++ b/pyload/lib/thrift/protocol/__init__.py @@ -0,0 +1,20 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +__all__ = ['fastbinary', 'TBase', 'TBinaryProtocol', 'TCompactProtocol', 'TJSONProtocol', 'TProtocol'] diff --git a/pyload/lib/thrift/protocol/fastbinary.c b/pyload/lib/thrift/protocol/fastbinary.c new file mode 100644 index 000000000..2ce56603c --- /dev/null +++ b/pyload/lib/thrift/protocol/fastbinary.c @@ -0,0 +1,1219 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include <Python.h> +#include "cStringIO.h" +#include <stdint.h> +#ifndef _WIN32 +# include <stdbool.h> +# include <netinet/in.h> +#else +# include <WinSock2.h> +# pragma comment (lib, "ws2_32.lib") +# define BIG_ENDIAN (4321) +# define LITTLE_ENDIAN (1234) +# define BYTE_ORDER LITTLE_ENDIAN +# if defined(_MSC_VER) && _MSC_VER < 1600 + typedef int _Bool; +# define bool _Bool +# define false 0 +# define true 1 +# endif +# define inline __inline +#endif + +/* Fix endianness issues on Solaris */ +#if defined (__SVR4) && defined (__sun) + #if defined(__i386) && !defined(__i386__) + #define __i386__ + #endif + + #ifndef BIG_ENDIAN + #define BIG_ENDIAN (4321) + #endif + #ifndef LITTLE_ENDIAN + #define LITTLE_ENDIAN (1234) + #endif + + /* I386 is LE, even on Solaris */ + #if !defined(BYTE_ORDER) && defined(__i386__) + #define BYTE_ORDER LITTLE_ENDIAN + #endif +#endif + +// TODO(dreiss): defval appears to be unused. Look into removing it. +// TODO(dreiss): Make parse_spec_args recursive, and cache the output +// permanently in the object. (Malloc and orphan.) +// TODO(dreiss): Why do we need cStringIO for reading, why not just char*? +// Can cStringIO let us work with a BufferedTransport? +// TODO(dreiss): Don't ignore the rv from cwrite (maybe). + +/* ====== BEGIN UTILITIES ====== */ + +#define INIT_OUTBUF_SIZE 128 + +// Stolen out of TProtocol.h. +// It would be a huge pain to have both get this from one place. +typedef enum TType { + T_STOP = 0, + T_VOID = 1, + T_BOOL = 2, + T_BYTE = 3, + T_I08 = 3, + T_I16 = 6, + T_I32 = 8, + T_U64 = 9, + T_I64 = 10, + T_DOUBLE = 4, + T_STRING = 11, + T_UTF7 = 11, + T_STRUCT = 12, + T_MAP = 13, + T_SET = 14, + T_LIST = 15, + T_UTF8 = 16, + T_UTF16 = 17 +} TType; + +#ifndef __BYTE_ORDER +# if defined(BYTE_ORDER) && defined(LITTLE_ENDIAN) && defined(BIG_ENDIAN) +# define __BYTE_ORDER BYTE_ORDER +# define __LITTLE_ENDIAN LITTLE_ENDIAN +# define __BIG_ENDIAN BIG_ENDIAN +# else +# error "Cannot determine endianness" +# endif +#endif + +// Same comment as the enum. Sorry. +#if __BYTE_ORDER == __BIG_ENDIAN +# define ntohll(n) (n) +# define htonll(n) (n) +#elif __BYTE_ORDER == __LITTLE_ENDIAN +# if defined(__GNUC__) && defined(__GLIBC__) +# include <byteswap.h> +# define ntohll(n) bswap_64(n) +# define htonll(n) bswap_64(n) +# else /* GNUC & GLIBC */ +# define ntohll(n) ( (((unsigned long long)ntohl(n)) << 32) + ntohl(n >> 32) ) +# define htonll(n) ( (((unsigned long long)htonl(n)) << 32) + htonl(n >> 32) ) +# endif /* GNUC & GLIBC */ +#else /* __BYTE_ORDER */ +# error "Can't define htonll or ntohll!" +#endif + +// Doing a benchmark shows that interning actually makes a difference, amazingly. +#define INTERN_STRING(value) _intern_ ## value + +#define INT_CONV_ERROR_OCCURRED(v) ( ((v) == -1) && PyErr_Occurred() ) +#define CHECK_RANGE(v, min, max) ( ((v) <= (max)) && ((v) >= (min)) ) + +// Py_ssize_t was not defined before Python 2.5 +#if (PY_VERSION_HEX < 0x02050000) +typedef int Py_ssize_t; +#endif + +/** + * A cache of the spec_args for a set or list, + * so we don't have to keep calling PyTuple_GET_ITEM. + */ +typedef struct { + TType element_type; + PyObject* typeargs; +} SetListTypeArgs; + +/** + * A cache of the spec_args for a map, + * so we don't have to keep calling PyTuple_GET_ITEM. + */ +typedef struct { + TType ktag; + TType vtag; + PyObject* ktypeargs; + PyObject* vtypeargs; +} MapTypeArgs; + +/** + * A cache of the spec_args for a struct, + * so we don't have to keep calling PyTuple_GET_ITEM. + */ +typedef struct { + PyObject* klass; + PyObject* spec; +} StructTypeArgs; + +/** + * A cache of the item spec from a struct specification, + * so we don't have to keep calling PyTuple_GET_ITEM. + */ +typedef struct { + int tag; + TType type; + PyObject* attrname; + PyObject* typeargs; + PyObject* defval; +} StructItemSpec; + +/** + * A cache of the two key attributes of a CReadableTransport, + * so we don't have to keep calling PyObject_GetAttr. + */ +typedef struct { + PyObject* stringiobuf; + PyObject* refill_callable; +} DecodeBuffer; + +/** Pointer to interned string to speed up attribute lookup. */ +static PyObject* INTERN_STRING(cstringio_buf); +/** Pointer to interned string to speed up attribute lookup. */ +static PyObject* INTERN_STRING(cstringio_refill); + +static inline bool +check_ssize_t_32(Py_ssize_t len) { + // error from getting the int + if (INT_CONV_ERROR_OCCURRED(len)) { + return false; + } + if (!CHECK_RANGE(len, 0, INT32_MAX)) { + PyErr_SetString(PyExc_OverflowError, "string size out of range"); + return false; + } + return true; +} + +static inline bool +parse_pyint(PyObject* o, int32_t* ret, int32_t min, int32_t max) { + long val = PyInt_AsLong(o); + + if (INT_CONV_ERROR_OCCURRED(val)) { + return false; + } + if (!CHECK_RANGE(val, min, max)) { + PyErr_SetString(PyExc_OverflowError, "int out of range"); + return false; + } + + *ret = (int32_t) val; + return true; +} + + +/* --- FUNCTIONS TO PARSE STRUCT SPECIFICATOINS --- */ + +static bool +parse_set_list_args(SetListTypeArgs* dest, PyObject* typeargs) { + if (PyTuple_Size(typeargs) != 2) { + PyErr_SetString(PyExc_TypeError, "expecting tuple of size 2 for list/set type args"); + return false; + } + + dest->element_type = PyInt_AsLong(PyTuple_GET_ITEM(typeargs, 0)); + if (INT_CONV_ERROR_OCCURRED(dest->element_type)) { + return false; + } + + dest->typeargs = PyTuple_GET_ITEM(typeargs, 1); + + return true; +} + +static bool +parse_map_args(MapTypeArgs* dest, PyObject* typeargs) { + if (PyTuple_Size(typeargs) != 4) { + PyErr_SetString(PyExc_TypeError, "expecting 4 arguments for typeargs to map"); + return false; + } + + dest->ktag = PyInt_AsLong(PyTuple_GET_ITEM(typeargs, 0)); + if (INT_CONV_ERROR_OCCURRED(dest->ktag)) { + return false; + } + + dest->vtag = PyInt_AsLong(PyTuple_GET_ITEM(typeargs, 2)); + if (INT_CONV_ERROR_OCCURRED(dest->vtag)) { + return false; + } + + dest->ktypeargs = PyTuple_GET_ITEM(typeargs, 1); + dest->vtypeargs = PyTuple_GET_ITEM(typeargs, 3); + + return true; +} + +static bool +parse_struct_args(StructTypeArgs* dest, PyObject* typeargs) { + if (PyTuple_Size(typeargs) != 2) { + PyErr_SetString(PyExc_TypeError, "expecting tuple of size 2 for struct args"); + return false; + } + + dest->klass = PyTuple_GET_ITEM(typeargs, 0); + dest->spec = PyTuple_GET_ITEM(typeargs, 1); + + return true; +} + +static int +parse_struct_item_spec(StructItemSpec* dest, PyObject* spec_tuple) { + + // i'd like to use ParseArgs here, but it seems to be a bottleneck. + if (PyTuple_Size(spec_tuple) != 5) { + PyErr_SetString(PyExc_TypeError, "expecting 5 arguments for spec tuple"); + return false; + } + + dest->tag = PyInt_AsLong(PyTuple_GET_ITEM(spec_tuple, 0)); + if (INT_CONV_ERROR_OCCURRED(dest->tag)) { + return false; + } + + dest->type = PyInt_AsLong(PyTuple_GET_ITEM(spec_tuple, 1)); + if (INT_CONV_ERROR_OCCURRED(dest->type)) { + return false; + } + + dest->attrname = PyTuple_GET_ITEM(spec_tuple, 2); + dest->typeargs = PyTuple_GET_ITEM(spec_tuple, 3); + dest->defval = PyTuple_GET_ITEM(spec_tuple, 4); + return true; +} + +/* ====== END UTILITIES ====== */ + + +/* ====== BEGIN WRITING FUNCTIONS ====== */ + +/* --- LOW-LEVEL WRITING FUNCTIONS --- */ + +static void writeByte(PyObject* outbuf, int8_t val) { + int8_t net = val; + PycStringIO->cwrite(outbuf, (char*)&net, sizeof(int8_t)); +} + +static void writeI16(PyObject* outbuf, int16_t val) { + int16_t net = (int16_t)htons(val); + PycStringIO->cwrite(outbuf, (char*)&net, sizeof(int16_t)); +} + +static void writeI32(PyObject* outbuf, int32_t val) { + int32_t net = (int32_t)htonl(val); + PycStringIO->cwrite(outbuf, (char*)&net, sizeof(int32_t)); +} + +static void writeI64(PyObject* outbuf, int64_t val) { + int64_t net = (int64_t)htonll(val); + PycStringIO->cwrite(outbuf, (char*)&net, sizeof(int64_t)); +} + +static void writeDouble(PyObject* outbuf, double dub) { + // Unfortunately, bitwise_cast doesn't work in C. Bad C! + union { + double f; + int64_t t; + } transfer; + transfer.f = dub; + writeI64(outbuf, transfer.t); +} + + +/* --- MAIN RECURSIVE OUTPUT FUCNTION -- */ + +static int +output_val(PyObject* output, PyObject* value, TType type, PyObject* typeargs) { + /* + * Refcounting Strategy: + * + * We assume that elements of the thrift_spec tuple are not going to be + * mutated, so we don't ref count those at all. Other than that, we try to + * keep a reference to all the user-created objects while we work with them. + * output_val assumes that a reference is already held. The *caller* is + * responsible for handling references + */ + + switch (type) { + + case T_BOOL: { + int v = PyObject_IsTrue(value); + if (v == -1) { + return false; + } + + writeByte(output, (int8_t) v); + break; + } + case T_I08: { + int32_t val; + + if (!parse_pyint(value, &val, INT8_MIN, INT8_MAX)) { + return false; + } + + writeByte(output, (int8_t) val); + break; + } + case T_I16: { + int32_t val; + + if (!parse_pyint(value, &val, INT16_MIN, INT16_MAX)) { + return false; + } + + writeI16(output, (int16_t) val); + break; + } + case T_I32: { + int32_t val; + + if (!parse_pyint(value, &val, INT32_MIN, INT32_MAX)) { + return false; + } + + writeI32(output, val); + break; + } + case T_I64: { + int64_t nval = PyLong_AsLongLong(value); + + if (INT_CONV_ERROR_OCCURRED(nval)) { + return false; + } + + if (!CHECK_RANGE(nval, INT64_MIN, INT64_MAX)) { + PyErr_SetString(PyExc_OverflowError, "int out of range"); + return false; + } + + writeI64(output, nval); + break; + } + + case T_DOUBLE: { + double nval = PyFloat_AsDouble(value); + if (nval == -1.0 && PyErr_Occurred()) { + return false; + } + + writeDouble(output, nval); + break; + } + + case T_STRING: { + Py_ssize_t len = PyString_Size(value); + + if (!check_ssize_t_32(len)) { + return false; + } + + writeI32(output, (int32_t) len); + PycStringIO->cwrite(output, PyString_AsString(value), (int32_t) len); + break; + } + + case T_LIST: + case T_SET: { + Py_ssize_t len; + SetListTypeArgs parsedargs; + PyObject *item; + PyObject *iterator; + + if (!parse_set_list_args(&parsedargs, typeargs)) { + return false; + } + + len = PyObject_Length(value); + + if (!check_ssize_t_32(len)) { + return false; + } + + writeByte(output, parsedargs.element_type); + writeI32(output, (int32_t) len); + + iterator = PyObject_GetIter(value); + if (iterator == NULL) { + return false; + } + + while ((item = PyIter_Next(iterator))) { + if (!output_val(output, item, parsedargs.element_type, parsedargs.typeargs)) { + Py_DECREF(item); + Py_DECREF(iterator); + return false; + } + Py_DECREF(item); + } + + Py_DECREF(iterator); + + if (PyErr_Occurred()) { + return false; + } + + break; + } + + case T_MAP: { + PyObject *k, *v; + Py_ssize_t pos = 0; + Py_ssize_t len; + + MapTypeArgs parsedargs; + + len = PyDict_Size(value); + if (!check_ssize_t_32(len)) { + return false; + } + + if (!parse_map_args(&parsedargs, typeargs)) { + return false; + } + + writeByte(output, parsedargs.ktag); + writeByte(output, parsedargs.vtag); + writeI32(output, len); + + // TODO(bmaurer): should support any mapping, not just dicts + while (PyDict_Next(value, &pos, &k, &v)) { + // TODO(dreiss): Think hard about whether these INCREFs actually + // turn any unsafe scenarios into safe scenarios. + Py_INCREF(k); + Py_INCREF(v); + + if (!output_val(output, k, parsedargs.ktag, parsedargs.ktypeargs) + || !output_val(output, v, parsedargs.vtag, parsedargs.vtypeargs)) { + Py_DECREF(k); + Py_DECREF(v); + return false; + } + Py_DECREF(k); + Py_DECREF(v); + } + break; + } + + // TODO(dreiss): Consider breaking this out as a function + // the way we did for decode_struct. + case T_STRUCT: { + StructTypeArgs parsedargs; + Py_ssize_t nspec; + Py_ssize_t i; + + if (!parse_struct_args(&parsedargs, typeargs)) { + return false; + } + + nspec = PyTuple_Size(parsedargs.spec); + + if (nspec == -1) { + return false; + } + + for (i = 0; i < nspec; i++) { + StructItemSpec parsedspec; + PyObject* spec_tuple; + PyObject* instval = NULL; + + spec_tuple = PyTuple_GET_ITEM(parsedargs.spec, i); + if (spec_tuple == Py_None) { + continue; + } + + if (!parse_struct_item_spec (&parsedspec, spec_tuple)) { + return false; + } + + instval = PyObject_GetAttr(value, parsedspec.attrname); + + if (!instval) { + return false; + } + + if (instval == Py_None) { + Py_DECREF(instval); + continue; + } + + writeByte(output, (int8_t) parsedspec.type); + writeI16(output, parsedspec.tag); + + if (!output_val(output, instval, parsedspec.type, parsedspec.typeargs)) { + Py_DECREF(instval); + return false; + } + + Py_DECREF(instval); + } + + writeByte(output, (int8_t)T_STOP); + break; + } + + case T_STOP: + case T_VOID: + case T_UTF16: + case T_UTF8: + case T_U64: + default: + PyErr_SetString(PyExc_TypeError, "Unexpected TType"); + return false; + + } + + return true; +} + + +/* --- TOP-LEVEL WRAPPER FOR OUTPUT -- */ + +static PyObject * +encode_binary(PyObject *self, PyObject *args) { + PyObject* enc_obj; + PyObject* type_args; + PyObject* buf; + PyObject* ret = NULL; + + if (!PyArg_ParseTuple(args, "OO", &enc_obj, &type_args)) { + return NULL; + } + + buf = PycStringIO->NewOutput(INIT_OUTBUF_SIZE); + if (output_val(buf, enc_obj, T_STRUCT, type_args)) { + ret = PycStringIO->cgetvalue(buf); + } + + Py_DECREF(buf); + return ret; +} + +/* ====== END WRITING FUNCTIONS ====== */ + + +/* ====== BEGIN READING FUNCTIONS ====== */ + +/* --- LOW-LEVEL READING FUNCTIONS --- */ + +static void +free_decodebuf(DecodeBuffer* d) { + Py_XDECREF(d->stringiobuf); + Py_XDECREF(d->refill_callable); +} + +static bool +decode_buffer_from_obj(DecodeBuffer* dest, PyObject* obj) { + dest->stringiobuf = PyObject_GetAttr(obj, INTERN_STRING(cstringio_buf)); + if (!dest->stringiobuf) { + return false; + } + + if (!PycStringIO_InputCheck(dest->stringiobuf)) { + free_decodebuf(dest); + PyErr_SetString(PyExc_TypeError, "expecting stringio input"); + return false; + } + + dest->refill_callable = PyObject_GetAttr(obj, INTERN_STRING(cstringio_refill)); + + if(!dest->refill_callable) { + free_decodebuf(dest); + return false; + } + + if (!PyCallable_Check(dest->refill_callable)) { + free_decodebuf(dest); + PyErr_SetString(PyExc_TypeError, "expecting callable"); + return false; + } + + return true; +} + +static bool readBytes(DecodeBuffer* input, char** output, int len) { + int read; + + // TODO(dreiss): Don't fear the malloc. Think about taking a copy of + // the partial read instead of forcing the transport + // to prepend it to its buffer. + + read = PycStringIO->cread(input->stringiobuf, output, len); + + if (read == len) { + return true; + } else if (read == -1) { + return false; + } else { + PyObject* newiobuf; + + // using building functions as this is a rare codepath + newiobuf = PyObject_CallFunction( + input->refill_callable, "s#i", *output, read, len, NULL); + if (newiobuf == NULL) { + return false; + } + + // must do this *AFTER* the call so that we don't deref the io buffer + Py_CLEAR(input->stringiobuf); + input->stringiobuf = newiobuf; + + read = PycStringIO->cread(input->stringiobuf, output, len); + + if (read == len) { + return true; + } else if (read == -1) { + return false; + } else { + // TODO(dreiss): This could be a valid code path for big binary blobs. + PyErr_SetString(PyExc_TypeError, + "refill claimed to have refilled the buffer, but didn't!!"); + return false; + } + } +} + +static int8_t readByte(DecodeBuffer* input) { + char* buf; + if (!readBytes(input, &buf, sizeof(int8_t))) { + return -1; + } + + return *(int8_t*) buf; +} + +static int16_t readI16(DecodeBuffer* input) { + char* buf; + if (!readBytes(input, &buf, sizeof(int16_t))) { + return -1; + } + + return (int16_t) ntohs(*(int16_t*) buf); +} + +static int32_t readI32(DecodeBuffer* input) { + char* buf; + if (!readBytes(input, &buf, sizeof(int32_t))) { + return -1; + } + return (int32_t) ntohl(*(int32_t*) buf); +} + + +static int64_t readI64(DecodeBuffer* input) { + char* buf; + if (!readBytes(input, &buf, sizeof(int64_t))) { + return -1; + } + + return (int64_t) ntohll(*(int64_t*) buf); +} + +static double readDouble(DecodeBuffer* input) { + union { + int64_t f; + double t; + } transfer; + + transfer.f = readI64(input); + if (transfer.f == -1) { + return -1; + } + return transfer.t; +} + +static bool +checkTypeByte(DecodeBuffer* input, TType expected) { + TType got = readByte(input); + if (INT_CONV_ERROR_OCCURRED(got)) { + return false; + } + + if (expected != got) { + PyErr_SetString(PyExc_TypeError, "got wrong ttype while reading field"); + return false; + } + return true; +} + +static bool +skip(DecodeBuffer* input, TType type) { +#define SKIPBYTES(n) \ + do { \ + if (!readBytes(input, &dummy_buf, (n))) { \ + return false; \ + } \ + } while(0) + + char* dummy_buf; + + switch (type) { + + case T_BOOL: + case T_I08: SKIPBYTES(1); break; + case T_I16: SKIPBYTES(2); break; + case T_I32: SKIPBYTES(4); break; + case T_I64: + case T_DOUBLE: SKIPBYTES(8); break; + + case T_STRING: { + // TODO(dreiss): Find out if these check_ssize_t32s are really necessary. + int len = readI32(input); + if (!check_ssize_t_32(len)) { + return false; + } + SKIPBYTES(len); + break; + } + + case T_LIST: + case T_SET: { + TType etype; + int len, i; + + etype = readByte(input); + if (etype == -1) { + return false; + } + + len = readI32(input); + if (!check_ssize_t_32(len)) { + return false; + } + + for (i = 0; i < len; i++) { + if (!skip(input, etype)) { + return false; + } + } + break; + } + + case T_MAP: { + TType ktype, vtype; + int len, i; + + ktype = readByte(input); + if (ktype == -1) { + return false; + } + + vtype = readByte(input); + if (vtype == -1) { + return false; + } + + len = readI32(input); + if (!check_ssize_t_32(len)) { + return false; + } + + for (i = 0; i < len; i++) { + if (!(skip(input, ktype) && skip(input, vtype))) { + return false; + } + } + break; + } + + case T_STRUCT: { + while (true) { + TType type; + + type = readByte(input); + if (type == -1) { + return false; + } + + if (type == T_STOP) + break; + + SKIPBYTES(2); // tag + if (!skip(input, type)) { + return false; + } + } + break; + } + + case T_STOP: + case T_VOID: + case T_UTF16: + case T_UTF8: + case T_U64: + default: + PyErr_SetString(PyExc_TypeError, "Unexpected TType"); + return false; + + } + + return true; + +#undef SKIPBYTES +} + + +/* --- HELPER FUNCTION FOR DECODE_VAL --- */ + +static PyObject* +decode_val(DecodeBuffer* input, TType type, PyObject* typeargs); + +static bool +decode_struct(DecodeBuffer* input, PyObject* output, PyObject* spec_seq) { + int spec_seq_len = PyTuple_Size(spec_seq); + if (spec_seq_len == -1) { + return false; + } + + while (true) { + TType type; + int16_t tag; + PyObject* item_spec; + PyObject* fieldval = NULL; + StructItemSpec parsedspec; + + type = readByte(input); + if (type == -1) { + return false; + } + if (type == T_STOP) { + break; + } + tag = readI16(input); + if (INT_CONV_ERROR_OCCURRED(tag)) { + return false; + } + if (tag >= 0 && tag < spec_seq_len) { + item_spec = PyTuple_GET_ITEM(spec_seq, tag); + } else { + item_spec = Py_None; + } + + if (item_spec == Py_None) { + if (!skip(input, type)) { + return false; + } else { + continue; + } + } + + if (!parse_struct_item_spec(&parsedspec, item_spec)) { + return false; + } + if (parsedspec.type != type) { + if (!skip(input, type)) { + PyErr_SetString(PyExc_TypeError, "struct field had wrong type while reading and can't be skipped"); + return false; + } else { + continue; + } + } + + fieldval = decode_val(input, parsedspec.type, parsedspec.typeargs); + if (fieldval == NULL) { + return false; + } + + if (PyObject_SetAttr(output, parsedspec.attrname, fieldval) == -1) { + Py_DECREF(fieldval); + return false; + } + Py_DECREF(fieldval); + } + return true; +} + + +/* --- MAIN RECURSIVE INPUT FUCNTION --- */ + +// Returns a new reference. +static PyObject* +decode_val(DecodeBuffer* input, TType type, PyObject* typeargs) { + switch (type) { + + case T_BOOL: { + int8_t v = readByte(input); + if (INT_CONV_ERROR_OCCURRED(v)) { + return NULL; + } + + switch (v) { + case 0: Py_RETURN_FALSE; + case 1: Py_RETURN_TRUE; + // Don't laugh. This is a potentially serious issue. + default: PyErr_SetString(PyExc_TypeError, "boolean out of range"); return NULL; + } + break; + } + case T_I08: { + int8_t v = readByte(input); + if (INT_CONV_ERROR_OCCURRED(v)) { + return NULL; + } + + return PyInt_FromLong(v); + } + case T_I16: { + int16_t v = readI16(input); + if (INT_CONV_ERROR_OCCURRED(v)) { + return NULL; + } + return PyInt_FromLong(v); + } + case T_I32: { + int32_t v = readI32(input); + if (INT_CONV_ERROR_OCCURRED(v)) { + return NULL; + } + return PyInt_FromLong(v); + } + + case T_I64: { + int64_t v = readI64(input); + if (INT_CONV_ERROR_OCCURRED(v)) { + return NULL; + } + // TODO(dreiss): Find out if we can take this fastpath always when + // sizeof(long) == sizeof(long long). + if (CHECK_RANGE(v, LONG_MIN, LONG_MAX)) { + return PyInt_FromLong((long) v); + } + + return PyLong_FromLongLong(v); + } + + case T_DOUBLE: { + double v = readDouble(input); + if (v == -1.0 && PyErr_Occurred()) { + return false; + } + return PyFloat_FromDouble(v); + } + + case T_STRING: { + Py_ssize_t len = readI32(input); + char* buf; + if (!readBytes(input, &buf, len)) { + return NULL; + } + + return PyString_FromStringAndSize(buf, len); + } + + case T_LIST: + case T_SET: { + SetListTypeArgs parsedargs; + int32_t len; + PyObject* ret = NULL; + int i; + + if (!parse_set_list_args(&parsedargs, typeargs)) { + return NULL; + } + + if (!checkTypeByte(input, parsedargs.element_type)) { + return NULL; + } + + len = readI32(input); + if (!check_ssize_t_32(len)) { + return NULL; + } + + ret = PyList_New(len); + if (!ret) { + return NULL; + } + + for (i = 0; i < len; i++) { + PyObject* item = decode_val(input, parsedargs.element_type, parsedargs.typeargs); + if (!item) { + Py_DECREF(ret); + return NULL; + } + PyList_SET_ITEM(ret, i, item); + } + + // TODO(dreiss): Consider biting the bullet and making two separate cases + // for list and set, avoiding this post facto conversion. + if (type == T_SET) { + PyObject* setret; +#if (PY_VERSION_HEX < 0x02050000) + // hack needed for older versions + setret = PyObject_CallFunctionObjArgs((PyObject*)&PySet_Type, ret, NULL); +#else + // official version + setret = PySet_New(ret); +#endif + Py_DECREF(ret); + return setret; + } + return ret; + } + + case T_MAP: { + int32_t len; + int i; + MapTypeArgs parsedargs; + PyObject* ret = NULL; + + if (!parse_map_args(&parsedargs, typeargs)) { + return NULL; + } + + if (!checkTypeByte(input, parsedargs.ktag)) { + return NULL; + } + if (!checkTypeByte(input, parsedargs.vtag)) { + return NULL; + } + + len = readI32(input); + if (!check_ssize_t_32(len)) { + return false; + } + + ret = PyDict_New(); + if (!ret) { + goto error; + } + + for (i = 0; i < len; i++) { + PyObject* k = NULL; + PyObject* v = NULL; + k = decode_val(input, parsedargs.ktag, parsedargs.ktypeargs); + if (k == NULL) { + goto loop_error; + } + v = decode_val(input, parsedargs.vtag, parsedargs.vtypeargs); + if (v == NULL) { + goto loop_error; + } + if (PyDict_SetItem(ret, k, v) == -1) { + goto loop_error; + } + + Py_DECREF(k); + Py_DECREF(v); + continue; + + // Yuck! Destructors, anyone? + loop_error: + Py_XDECREF(k); + Py_XDECREF(v); + goto error; + } + + return ret; + + error: + Py_XDECREF(ret); + return NULL; + } + + case T_STRUCT: { + StructTypeArgs parsedargs; + PyObject* ret; + if (!parse_struct_args(&parsedargs, typeargs)) { + return NULL; + } + + ret = PyObject_CallObject(parsedargs.klass, NULL); + if (!ret) { + return NULL; + } + + if (!decode_struct(input, ret, parsedargs.spec)) { + Py_DECREF(ret); + return NULL; + } + + return ret; + } + + case T_STOP: + case T_VOID: + case T_UTF16: + case T_UTF8: + case T_U64: + default: + PyErr_SetString(PyExc_TypeError, "Unexpected TType"); + return NULL; + } +} + + +/* --- TOP-LEVEL WRAPPER FOR INPUT -- */ + +static PyObject* +decode_binary(PyObject *self, PyObject *args) { + PyObject* output_obj = NULL; + PyObject* transport = NULL; + PyObject* typeargs = NULL; + StructTypeArgs parsedargs; + DecodeBuffer input = {0, 0}; + + if (!PyArg_ParseTuple(args, "OOO", &output_obj, &transport, &typeargs)) { + return NULL; + } + + if (!parse_struct_args(&parsedargs, typeargs)) { + return NULL; + } + + if (!decode_buffer_from_obj(&input, transport)) { + return NULL; + } + + if (!decode_struct(&input, output_obj, parsedargs.spec)) { + free_decodebuf(&input); + return NULL; + } + + free_decodebuf(&input); + + Py_RETURN_NONE; +} + +/* ====== END READING FUNCTIONS ====== */ + + +/* -- PYTHON MODULE SETUP STUFF --- */ + +static PyMethodDef ThriftFastBinaryMethods[] = { + + {"encode_binary", encode_binary, METH_VARARGS, ""}, + {"decode_binary", decode_binary, METH_VARARGS, ""}, + + {NULL, NULL, 0, NULL} /* Sentinel */ +}; + +PyMODINIT_FUNC +initfastbinary(void) { +#define INIT_INTERN_STRING(value) \ + do { \ + INTERN_STRING(value) = PyString_InternFromString(#value); \ + if(!INTERN_STRING(value)) return; \ + } while(0) + + INIT_INTERN_STRING(cstringio_buf); + INIT_INTERN_STRING(cstringio_refill); +#undef INIT_INTERN_STRING + + PycString_IMPORT; + if (PycStringIO == NULL) return; + + (void) Py_InitModule("thrift.protocol.fastbinary", ThriftFastBinaryMethods); +} diff --git a/pyload/lib/thrift/server/THttpServer.py b/pyload/lib/thrift/server/THttpServer.py new file mode 100644 index 000000000..be54bab94 --- /dev/null +++ b/pyload/lib/thrift/server/THttpServer.py @@ -0,0 +1,87 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +import BaseHTTPServer + +from thrift.server import TServer +from thrift.transport import TTransport + + +class ResponseException(Exception): + """Allows handlers to override the HTTP response + + Normally, THttpServer always sends a 200 response. If a handler wants + to override this behavior (e.g., to simulate a misconfigured or + overloaded web server during testing), it can raise a ResponseException. + The function passed to the constructor will be called with the + RequestHandler as its only argument. + """ + def __init__(self, handler): + self.handler = handler + + +class THttpServer(TServer.TServer): + """A simple HTTP-based Thrift server + + This class is not very performant, but it is useful (for example) for + acting as a mock version of an Apache-based PHP Thrift endpoint. + """ + def __init__(self, + processor, + server_address, + inputProtocolFactory, + outputProtocolFactory=None, + server_class=BaseHTTPServer.HTTPServer): + """Set up protocol factories and HTTP server. + + See BaseHTTPServer for server_address. + See TServer for protocol factories. + """ + if outputProtocolFactory is None: + outputProtocolFactory = inputProtocolFactory + + TServer.TServer.__init__(self, processor, None, None, None, + inputProtocolFactory, outputProtocolFactory) + + thttpserver = self + + class RequestHander(BaseHTTPServer.BaseHTTPRequestHandler): + def do_POST(self): + # Don't care about the request path. + itrans = TTransport.TFileObjectTransport(self.rfile) + otrans = TTransport.TFileObjectTransport(self.wfile) + itrans = TTransport.TBufferedTransport( + itrans, int(self.headers['Content-Length'])) + otrans = TTransport.TMemoryBuffer() + iprot = thttpserver.inputProtocolFactory.getProtocol(itrans) + oprot = thttpserver.outputProtocolFactory.getProtocol(otrans) + try: + thttpserver.processor.process(iprot, oprot) + except ResponseException, exn: + exn.handler(self) + else: + self.send_response(200) + self.send_header("content-type", "application/x-thrift") + self.end_headers() + self.wfile.write(otrans.getvalue()) + + self.httpd = server_class(server_address, RequestHander) + + def serve(self): + self.httpd.serve_forever() diff --git a/pyload/lib/thrift/server/TNonblockingServer.py b/pyload/lib/thrift/server/TNonblockingServer.py new file mode 100644 index 000000000..fa478d01f --- /dev/null +++ b/pyload/lib/thrift/server/TNonblockingServer.py @@ -0,0 +1,346 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +"""Implementation of non-blocking server. + +The main idea of the server is to receive and send requests +only from the main thread. + +The thread poool should be sized for concurrent tasks, not +maximum connections +""" +import threading +import socket +import Queue +import select +import struct +import logging + +from thrift.transport import TTransport +from thrift.protocol.TBinaryProtocol import TBinaryProtocolFactory + +__all__ = ['TNonblockingServer'] + + +class Worker(threading.Thread): + """Worker is a small helper to process incoming connection.""" + + def __init__(self, queue): + threading.Thread.__init__(self) + self.queue = queue + + def run(self): + """Process queries from task queue, stop if processor is None.""" + while True: + try: + processor, iprot, oprot, otrans, callback = self.queue.get() + if processor is None: + break + processor.process(iprot, oprot) + callback(True, otrans.getvalue()) + except Exception: + logging.exception("Exception while processing request") + callback(False, '') + +WAIT_LEN = 0 +WAIT_MESSAGE = 1 +WAIT_PROCESS = 2 +SEND_ANSWER = 3 +CLOSED = 4 + + +def locked(func): + """Decorator which locks self.lock.""" + def nested(self, *args, **kwargs): + self.lock.acquire() + try: + return func(self, *args, **kwargs) + finally: + self.lock.release() + return nested + + +def socket_exception(func): + """Decorator close object on socket.error.""" + def read(self, *args, **kwargs): + try: + return func(self, *args, **kwargs) + except socket.error: + self.close() + return read + + +class Connection: + """Basic class is represented connection. + + It can be in state: + WAIT_LEN --- connection is reading request len. + WAIT_MESSAGE --- connection is reading request. + WAIT_PROCESS --- connection has just read whole request and + waits for call ready routine. + SEND_ANSWER --- connection is sending answer string (including length + of answer). + CLOSED --- socket was closed and connection should be deleted. + """ + def __init__(self, new_socket, wake_up): + self.socket = new_socket + self.socket.setblocking(False) + self.status = WAIT_LEN + self.len = 0 + self.message = '' + self.lock = threading.Lock() + self.wake_up = wake_up + + def _read_len(self): + """Reads length of request. + + It's a safer alternative to self.socket.recv(4) + """ + read = self.socket.recv(4 - len(self.message)) + if len(read) == 0: + # if we read 0 bytes and self.message is empty, then + # the client closed the connection + if len(self.message) != 0: + logging.error("can't read frame size from socket") + self.close() + return + self.message += read + if len(self.message) == 4: + self.len, = struct.unpack('!i', self.message) + if self.len < 0: + logging.error("negative frame size, it seems client " + "doesn't use FramedTransport") + self.close() + elif self.len == 0: + logging.error("empty frame, it's really strange") + self.close() + else: + self.message = '' + self.status = WAIT_MESSAGE + + @socket_exception + def read(self): + """Reads data from stream and switch state.""" + assert self.status in (WAIT_LEN, WAIT_MESSAGE) + if self.status == WAIT_LEN: + self._read_len() + # go back to the main loop here for simplicity instead of + # falling through, even though there is a good chance that + # the message is already available + elif self.status == WAIT_MESSAGE: + read = self.socket.recv(self.len - len(self.message)) + if len(read) == 0: + logging.error("can't read frame from socket (get %d of " + "%d bytes)" % (len(self.message), self.len)) + self.close() + return + self.message += read + if len(self.message) == self.len: + self.status = WAIT_PROCESS + + @socket_exception + def write(self): + """Writes data from socket and switch state.""" + assert self.status == SEND_ANSWER + sent = self.socket.send(self.message) + if sent == len(self.message): + self.status = WAIT_LEN + self.message = '' + self.len = 0 + else: + self.message = self.message[sent:] + + @locked + def ready(self, all_ok, message): + """Callback function for switching state and waking up main thread. + + This function is the only function witch can be called asynchronous. + + The ready can switch Connection to three states: + WAIT_LEN if request was oneway. + SEND_ANSWER if request was processed in normal way. + CLOSED if request throws unexpected exception. + + The one wakes up main thread. + """ + assert self.status == WAIT_PROCESS + if not all_ok: + self.close() + self.wake_up() + return + self.len = '' + if len(message) == 0: + # it was a oneway request, do not write answer + self.message = '' + self.status = WAIT_LEN + else: + self.message = struct.pack('!i', len(message)) + message + self.status = SEND_ANSWER + self.wake_up() + + @locked + def is_writeable(self): + """Return True if connection should be added to write list of select""" + return self.status == SEND_ANSWER + + # it's not necessary, but... + @locked + def is_readable(self): + """Return True if connection should be added to read list of select""" + return self.status in (WAIT_LEN, WAIT_MESSAGE) + + @locked + def is_closed(self): + """Returns True if connection is closed.""" + return self.status == CLOSED + + def fileno(self): + """Returns the file descriptor of the associated socket.""" + return self.socket.fileno() + + def close(self): + """Closes connection""" + self.status = CLOSED + self.socket.close() + + +class TNonblockingServer: + """Non-blocking server.""" + + def __init__(self, + processor, + lsocket, + inputProtocolFactory=None, + outputProtocolFactory=None, + threads=10): + self.processor = processor + self.socket = lsocket + self.in_protocol = inputProtocolFactory or TBinaryProtocolFactory() + self.out_protocol = outputProtocolFactory or self.in_protocol + self.threads = int(threads) + self.clients = {} + self.tasks = Queue.Queue() + self._read, self._write = socket.socketpair() + self.prepared = False + self._stop = False + + def setNumThreads(self, num): + """Set the number of worker threads that should be created.""" + # implement ThreadPool interface + assert not self.prepared, "Can't change number of threads after start" + self.threads = num + + def prepare(self): + """Prepares server for serve requests.""" + if self.prepared: + return + self.socket.listen() + for _ in xrange(self.threads): + thread = Worker(self.tasks) + thread.setDaemon(True) + thread.start() + self.prepared = True + + def wake_up(self): + """Wake up main thread. + + The server usualy waits in select call in we should terminate one. + The simplest way is using socketpair. + + Select always wait to read from the first socket of socketpair. + + In this case, we can just write anything to the second socket from + socketpair. + """ + self._write.send('1') + + def stop(self): + """Stop the server. + + This method causes the serve() method to return. stop() may be invoked + from within your handler, or from another thread. + + After stop() is called, serve() will return but the server will still + be listening on the socket. serve() may then be called again to resume + processing requests. Alternatively, close() may be called after + serve() returns to close the server socket and shutdown all worker + threads. + """ + self._stop = True + self.wake_up() + + def _select(self): + """Does select on open connections.""" + readable = [self.socket.handle.fileno(), self._read.fileno()] + writable = [] + for i, connection in self.clients.items(): + if connection.is_readable(): + readable.append(connection.fileno()) + if connection.is_writeable(): + writable.append(connection.fileno()) + if connection.is_closed(): + del self.clients[i] + return select.select(readable, writable, readable) + + def handle(self): + """Handle requests. + + WARNING! You must call prepare() BEFORE calling handle() + """ + assert self.prepared, "You have to call prepare before handle" + rset, wset, xset = self._select() + for readable in rset: + if readable == self._read.fileno(): + # don't care i just need to clean readable flag + self._read.recv(1024) + elif readable == self.socket.handle.fileno(): + client = self.socket.accept().handle + self.clients[client.fileno()] = Connection(client, + self.wake_up) + else: + connection = self.clients[readable] + connection.read() + if connection.status == WAIT_PROCESS: + itransport = TTransport.TMemoryBuffer(connection.message) + otransport = TTransport.TMemoryBuffer() + iprot = self.in_protocol.getProtocol(itransport) + oprot = self.out_protocol.getProtocol(otransport) + self.tasks.put([self.processor, iprot, oprot, + otransport, connection.ready]) + for writeable in wset: + self.clients[writeable].write() + for oob in xset: + self.clients[oob].close() + del self.clients[oob] + + def close(self): + """Closes the server.""" + for _ in xrange(self.threads): + self.tasks.put([None, None, None, None, None]) + self.socket.close() + self.prepared = False + + def serve(self): + """Serve requests. + + Serve requests forever, or until stop() is called. + """ + self._stop = False + self.prepare() + while not self._stop: + self.handle() diff --git a/pyload/lib/thrift/server/TProcessPoolServer.py b/pyload/lib/thrift/server/TProcessPoolServer.py new file mode 100644 index 000000000..7a695a883 --- /dev/null +++ b/pyload/lib/thrift/server/TProcessPoolServer.py @@ -0,0 +1,118 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + + +import logging +from multiprocessing import Process, Value, Condition, reduction + +from TServer import TServer +from thrift.transport.TTransport import TTransportException + + +class TProcessPoolServer(TServer): + """Server with a fixed size pool of worker subprocesses to service requests + + Note that if you need shared state between the handlers - it's up to you! + Written by Dvir Volk, doat.com + """ + def __init__(self, *args): + TServer.__init__(self, *args) + self.numWorkers = 10 + self.workers = [] + self.isRunning = Value('b', False) + self.stopCondition = Condition() + self.postForkCallback = None + + def setPostForkCallback(self, callback): + if not callable(callback): + raise TypeError("This is not a callback!") + self.postForkCallback = callback + + def setNumWorkers(self, num): + """Set the number of worker threads that should be created""" + self.numWorkers = num + + def workerProcess(self): + """Loop getting clients from the shared queue and process them""" + if self.postForkCallback: + self.postForkCallback() + + while self.isRunning.value: + try: + client = self.serverTransport.accept() + self.serveClient(client) + except (KeyboardInterrupt, SystemExit): + return 0 + except Exception, x: + logging.exception(x) + + def serveClient(self, client): + """Process input/output from a client for as long as possible""" + itrans = self.inputTransportFactory.getTransport(client) + otrans = self.outputTransportFactory.getTransport(client) + iprot = self.inputProtocolFactory.getProtocol(itrans) + oprot = self.outputProtocolFactory.getProtocol(otrans) + + try: + while True: + self.processor.process(iprot, oprot) + except TTransportException, tx: + pass + except Exception, x: + logging.exception(x) + + itrans.close() + otrans.close() + + def serve(self): + """Start workers and put into queue""" + # this is a shared state that can tell the workers to exit when False + self.isRunning.value = True + + # first bind and listen to the port + self.serverTransport.listen() + + # fork the children + for i in range(self.numWorkers): + try: + w = Process(target=self.workerProcess) + w.daemon = True + w.start() + self.workers.append(w) + except Exception, x: + logging.exception(x) + + # wait until the condition is set by stop() + while True: + self.stopCondition.acquire() + try: + self.stopCondition.wait() + break + except (SystemExit, KeyboardInterrupt): + break + except Exception, x: + logging.exception(x) + + self.isRunning.value = False + + def stop(self): + self.isRunning.value = False + self.stopCondition.acquire() + self.stopCondition.notify() + self.stopCondition.release() diff --git a/pyload/lib/thrift/server/TServer.py b/pyload/lib/thrift/server/TServer.py new file mode 100644 index 000000000..2f24842c4 --- /dev/null +++ b/pyload/lib/thrift/server/TServer.py @@ -0,0 +1,269 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +import Queue +import logging +import os +import sys +import threading +import traceback + +from thrift.Thrift import TProcessor +from thrift.protocol import TBinaryProtocol +from thrift.transport import TTransport + + +class TServer: + """Base interface for a server, which must have a serve() method. + + Three constructors for all servers: + 1) (processor, serverTransport) + 2) (processor, serverTransport, transportFactory, protocolFactory) + 3) (processor, serverTransport, + inputTransportFactory, outputTransportFactory, + inputProtocolFactory, outputProtocolFactory) + """ + def __init__(self, *args): + if (len(args) == 2): + self.__initArgs__(args[0], args[1], + TTransport.TTransportFactoryBase(), + TTransport.TTransportFactoryBase(), + TBinaryProtocol.TBinaryProtocolFactory(), + TBinaryProtocol.TBinaryProtocolFactory()) + elif (len(args) == 4): + self.__initArgs__(args[0], args[1], args[2], args[2], args[3], args[3]) + elif (len(args) == 6): + self.__initArgs__(args[0], args[1], args[2], args[3], args[4], args[5]) + + def __initArgs__(self, processor, serverTransport, + inputTransportFactory, outputTransportFactory, + inputProtocolFactory, outputProtocolFactory): + self.processor = processor + self.serverTransport = serverTransport + self.inputTransportFactory = inputTransportFactory + self.outputTransportFactory = outputTransportFactory + self.inputProtocolFactory = inputProtocolFactory + self.outputProtocolFactory = outputProtocolFactory + + def serve(self): + pass + + +class TSimpleServer(TServer): + """Simple single-threaded server that just pumps around one transport.""" + + def __init__(self, *args): + TServer.__init__(self, *args) + + def serve(self): + self.serverTransport.listen() + while True: + client = self.serverTransport.accept() + itrans = self.inputTransportFactory.getTransport(client) + otrans = self.outputTransportFactory.getTransport(client) + iprot = self.inputProtocolFactory.getProtocol(itrans) + oprot = self.outputProtocolFactory.getProtocol(otrans) + try: + while True: + self.processor.process(iprot, oprot) + except TTransport.TTransportException, tx: + pass + except Exception, x: + logging.exception(x) + + itrans.close() + otrans.close() + + +class TThreadedServer(TServer): + """Threaded server that spawns a new thread per each connection.""" + + def __init__(self, *args, **kwargs): + TServer.__init__(self, *args) + self.daemon = kwargs.get("daemon", False) + + def serve(self): + self.serverTransport.listen() + while True: + try: + client = self.serverTransport.accept() + t = threading.Thread(target=self.handle, args=(client,)) + t.setDaemon(self.daemon) + t.start() + except KeyboardInterrupt: + raise + except Exception, x: + logging.exception(x) + + def handle(self, client): + itrans = self.inputTransportFactory.getTransport(client) + otrans = self.outputTransportFactory.getTransport(client) + iprot = self.inputProtocolFactory.getProtocol(itrans) + oprot = self.outputProtocolFactory.getProtocol(otrans) + try: + while True: + self.processor.process(iprot, oprot) + except TTransport.TTransportException, tx: + pass + except Exception, x: + logging.exception(x) + + itrans.close() + otrans.close() + + +class TThreadPoolServer(TServer): + """Server with a fixed size pool of threads which service requests.""" + + def __init__(self, *args, **kwargs): + TServer.__init__(self, *args) + self.clients = Queue.Queue() + self.threads = 10 + self.daemon = kwargs.get("daemon", False) + + def setNumThreads(self, num): + """Set the number of worker threads that should be created""" + self.threads = num + + def serveThread(self): + """Loop around getting clients from the shared queue and process them.""" + while True: + try: + client = self.clients.get() + self.serveClient(client) + except Exception, x: + logging.exception(x) + + def serveClient(self, client): + """Process input/output from a client for as long as possible""" + itrans = self.inputTransportFactory.getTransport(client) + otrans = self.outputTransportFactory.getTransport(client) + iprot = self.inputProtocolFactory.getProtocol(itrans) + oprot = self.outputProtocolFactory.getProtocol(otrans) + try: + while True: + self.processor.process(iprot, oprot) + except TTransport.TTransportException, tx: + pass + except Exception, x: + logging.exception(x) + + itrans.close() + otrans.close() + + def serve(self): + """Start a fixed number of worker threads and put client into a queue""" + for i in range(self.threads): + try: + t = threading.Thread(target=self.serveThread) + t.setDaemon(self.daemon) + t.start() + except Exception, x: + logging.exception(x) + + # Pump the socket for clients + self.serverTransport.listen() + while True: + try: + client = self.serverTransport.accept() + self.clients.put(client) + except Exception, x: + logging.exception(x) + + +class TForkingServer(TServer): + """A Thrift server that forks a new process for each request + + This is more scalable than the threaded server as it does not cause + GIL contention. + + Note that this has different semantics from the threading server. + Specifically, updates to shared variables will no longer be shared. + It will also not work on windows. + + This code is heavily inspired by SocketServer.ForkingMixIn in the + Python stdlib. + """ + def __init__(self, *args): + TServer.__init__(self, *args) + self.children = [] + + def serve(self): + def try_close(file): + try: + file.close() + except IOError, e: + logging.warning(e, exc_info=True) + + self.serverTransport.listen() + while True: + client = self.serverTransport.accept() + try: + pid = os.fork() + + if pid: # parent + # add before collect, otherwise you race w/ waitpid + self.children.append(pid) + self.collect_children() + + # Parent must close socket or the connection may not get + # closed promptly + itrans = self.inputTransportFactory.getTransport(client) + otrans = self.outputTransportFactory.getTransport(client) + try_close(itrans) + try_close(otrans) + else: + itrans = self.inputTransportFactory.getTransport(client) + otrans = self.outputTransportFactory.getTransport(client) + + iprot = self.inputProtocolFactory.getProtocol(itrans) + oprot = self.outputProtocolFactory.getProtocol(otrans) + + ecode = 0 + try: + try: + while True: + self.processor.process(iprot, oprot) + except TTransport.TTransportException, tx: + pass + except Exception, e: + logging.exception(e) + ecode = 1 + finally: + try_close(itrans) + try_close(otrans) + + os._exit(ecode) + + except TTransport.TTransportException, tx: + pass + except Exception, x: + logging.exception(x) + + def collect_children(self): + while self.children: + try: + pid, status = os.waitpid(0, os.WNOHANG) + except os.error: + pid = None + + if pid: + self.children.remove(pid) + else: + break diff --git a/pyload/lib/thrift/server/__init__.py b/pyload/lib/thrift/server/__init__.py new file mode 100644 index 000000000..1bf6e254e --- /dev/null +++ b/pyload/lib/thrift/server/__init__.py @@ -0,0 +1,20 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +__all__ = ['TServer', 'TNonblockingServer'] diff --git a/pyload/lib/thrift/transport/THttpClient.py b/pyload/lib/thrift/transport/THttpClient.py new file mode 100644 index 000000000..ea80a1ae8 --- /dev/null +++ b/pyload/lib/thrift/transport/THttpClient.py @@ -0,0 +1,149 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +import httplib +import os +import socket +import sys +import urllib +import urlparse +import warnings + +from cStringIO import StringIO + +from TTransport import * + + +class THttpClient(TTransportBase): + """Http implementation of TTransport base.""" + + def __init__(self, uri_or_host, port=None, path=None): + """THttpClient supports two different types constructor parameters. + + THttpClient(host, port, path) - deprecated + THttpClient(uri) + + Only the second supports https. + """ + if port is not None: + warnings.warn( + "Please use the THttpClient('http://host:port/path') syntax", + DeprecationWarning, + stacklevel=2) + self.host = uri_or_host + self.port = port + assert path + self.path = path + self.scheme = 'http' + else: + parsed = urlparse.urlparse(uri_or_host) + self.scheme = parsed.scheme + assert self.scheme in ('http', 'https') + if self.scheme == 'http': + self.port = parsed.port or httplib.HTTP_PORT + elif self.scheme == 'https': + self.port = parsed.port or httplib.HTTPS_PORT + self.host = parsed.hostname + self.path = parsed.path + if parsed.query: + self.path += '?%s' % parsed.query + self.__wbuf = StringIO() + self.__http = None + self.__timeout = None + self.__custom_headers = None + + def open(self): + if self.scheme == 'http': + self.__http = httplib.HTTP(self.host, self.port) + else: + self.__http = httplib.HTTPS(self.host, self.port) + + def close(self): + self.__http.close() + self.__http = None + + def isOpen(self): + return self.__http is not None + + def setTimeout(self, ms): + if not hasattr(socket, 'getdefaulttimeout'): + raise NotImplementedError + + if ms is None: + self.__timeout = None + else: + self.__timeout = ms / 1000.0 + + def setCustomHeaders(self, headers): + self.__custom_headers = headers + + def read(self, sz): + return self.__http.file.read(sz) + + def write(self, buf): + self.__wbuf.write(buf) + + def __withTimeout(f): + def _f(*args, **kwargs): + orig_timeout = socket.getdefaulttimeout() + socket.setdefaulttimeout(args[0].__timeout) + result = f(*args, **kwargs) + socket.setdefaulttimeout(orig_timeout) + return result + return _f + + def flush(self): + if self.isOpen(): + self.close() + self.open() + + # Pull data out of buffer + data = self.__wbuf.getvalue() + self.__wbuf = StringIO() + + # HTTP request + self.__http.putrequest('POST', self.path) + + # Write headers + self.__http.putheader('Host', self.host) + self.__http.putheader('Content-Type', 'application/x-thrift') + self.__http.putheader('Content-Length', str(len(data))) + + if not self.__custom_headers or 'User-Agent' not in self.__custom_headers: + user_agent = 'Python/THttpClient' + script = os.path.basename(sys.argv[0]) + if script: + user_agent = '%s (%s)' % (user_agent, urllib.quote(script)) + self.__http.putheader('User-Agent', user_agent) + + if self.__custom_headers: + for key, val in self.__custom_headers.iteritems(): + self.__http.putheader(key, val) + + self.__http.endheaders() + + # Write payload + self.__http.send(data) + + # Get reply to flush the request + self.code, self.message, self.headers = self.__http.getreply() + + # Decorate if we know how to timeout + if hasattr(socket, 'getdefaulttimeout'): + flush = __withTimeout(flush) diff --git a/pyload/lib/thrift/transport/TSSLSocket.py b/pyload/lib/thrift/transport/TSSLSocket.py new file mode 100644 index 000000000..81e098426 --- /dev/null +++ b/pyload/lib/thrift/transport/TSSLSocket.py @@ -0,0 +1,214 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +import os +import socket +import ssl + +from thrift.transport import TSocket +from thrift.transport.TTransport import TTransportException + + +class TSSLSocket(TSocket.TSocket): + """ + SSL implementation of client-side TSocket + + This class creates outbound sockets wrapped using the + python standard ssl module for encrypted connections. + + The protocol used is set using the class variable + SSL_VERSION, which must be one of ssl.PROTOCOL_* and + defaults to ssl.PROTOCOL_TLSv1 for greatest security. + """ + SSL_VERSION = ssl.PROTOCOL_TLSv1 + + def __init__(self, + host='localhost', + port=9090, + validate=True, + ca_certs=None, + keyfile=None, + certfile=None, + unix_socket=None): + """Create SSL TSocket + + @param validate: Set to False to disable SSL certificate validation + @type validate: bool + @param ca_certs: Filename to the Certificate Authority pem file, possibly a + file downloaded from: http://curl.haxx.se/ca/cacert.pem This is passed to + the ssl_wrap function as the 'ca_certs' parameter. + @type ca_certs: str + @param keyfile: The private key + @type keyfile: str + @param certfile: The cert file + @type certfile: str + + Raises an IOError exception if validate is True and the ca_certs file is + None, not present or unreadable. + """ + self.validate = validate + self.is_valid = False + self.peercert = None + if not validate: + self.cert_reqs = ssl.CERT_NONE + else: + self.cert_reqs = ssl.CERT_REQUIRED + self.ca_certs = ca_certs + self.keyfile = keyfile + self.certfile = certfile + if validate: + if ca_certs is None or not os.access(ca_certs, os.R_OK): + raise IOError('Certificate Authority ca_certs file "%s" ' + 'is not readable, cannot validate SSL ' + 'certificates.' % (ca_certs)) + TSocket.TSocket.__init__(self, host, port, unix_socket) + + def open(self): + try: + res0 = self._resolveAddr() + for res in res0: + sock_family, sock_type = res[0:2] + ip_port = res[4] + plain_sock = socket.socket(sock_family, sock_type) + self.handle = ssl.wrap_socket(plain_sock, + ssl_version=self.SSL_VERSION, + do_handshake_on_connect=True, + ca_certs=self.ca_certs, + keyfile=self.keyfile, + certfile=self.certfile, + cert_reqs=self.cert_reqs) + self.handle.settimeout(self._timeout) + try: + self.handle.connect(ip_port) + except socket.error, e: + if res is not res0[-1]: + continue + else: + raise e + break + except socket.error, e: + if self._unix_socket: + message = 'Could not connect to secure socket %s: %s' \ + % (self._unix_socket, e) + else: + message = 'Could not connect to %s:%d: %s' % (self.host, self.port, e) + raise TTransportException(type=TTransportException.NOT_OPEN, + message=message) + if self.validate: + self._validate_cert() + + def _validate_cert(self): + """internal method to validate the peer's SSL certificate, and to check the + commonName of the certificate to ensure it matches the hostname we + used to make this connection. Does not support subjectAltName records + in certificates. + + raises TTransportException if the certificate fails validation. + """ + cert = self.handle.getpeercert() + self.peercert = cert + if 'subject' not in cert: + raise TTransportException( + type=TTransportException.NOT_OPEN, + message='No SSL certificate found from %s:%s' % (self.host, self.port)) + fields = cert['subject'] + for field in fields: + # ensure structure we get back is what we expect + if not isinstance(field, tuple): + continue + cert_pair = field[0] + if len(cert_pair) < 2: + continue + cert_key, cert_value = cert_pair[0:2] + if cert_key != 'commonName': + continue + certhost = cert_value + # this check should be performed by some sort of Access Manager + if certhost == self.host: + # success, cert commonName matches desired hostname + self.is_valid = True + return + else: + raise TTransportException( + type=TTransportException.UNKNOWN, + message='Hostname we connected to "%s" doesn\'t match certificate ' + 'provided commonName "%s"' % (self.host, certhost)) + raise TTransportException( + type=TTransportException.UNKNOWN, + message='Could not validate SSL certificate from ' + 'host "%s". Cert=%s' % (self.host, cert)) + + +class TSSLServerSocket(TSocket.TServerSocket): + """SSL implementation of TServerSocket + + This uses the ssl module's wrap_socket() method to provide SSL + negotiated encryption. + """ + SSL_VERSION = ssl.PROTOCOL_TLSv1 + + def __init__(self, + host=None, + port=9090, + certfile='cert.pem', + unix_socket=None): + """Initialize a TSSLServerSocket + + @param certfile: filename of the server certificate, defaults to cert.pem + @type certfile: str + @param host: The hostname or IP to bind the listen socket to, + i.e. 'localhost' for only allowing local network connections. + Pass None to bind to all interfaces. + @type host: str + @param port: The port to listen on for inbound connections. + @type port: int + """ + self.setCertfile(certfile) + TSocket.TServerSocket.__init__(self, host, port) + + def setCertfile(self, certfile): + """Set or change the server certificate file used to wrap new connections. + + @param certfile: The filename of the server certificate, + i.e. '/etc/certs/server.pem' + @type certfile: str + + Raises an IOError exception if the certfile is not present or unreadable. + """ + if not os.access(certfile, os.R_OK): + raise IOError('No such certfile found: %s' % (certfile)) + self.certfile = certfile + + def accept(self): + plain_client, addr = self.handle.accept() + try: + client = ssl.wrap_socket(plain_client, certfile=self.certfile, + server_side=True, ssl_version=self.SSL_VERSION) + except ssl.SSLError, ssl_exc: + # failed handshake/ssl wrap, close socket to client + plain_client.close() + # raise ssl_exc + # We can't raise the exception, because it kills most TServer derived + # serve() methods. + # Instead, return None, and let the TServer instance deal with it in + # other exception handling. (but TSimpleServer dies anyway) + return None + result = TSocket.TSocket() + result.setHandle(client) + return result diff --git a/pyload/lib/thrift/transport/TSocket.py b/pyload/lib/thrift/transport/TSocket.py new file mode 100644 index 000000000..9e2b3849b --- /dev/null +++ b/pyload/lib/thrift/transport/TSocket.py @@ -0,0 +1,176 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +import errno +import os +import socket +import sys + +from TTransport import * + + +class TSocketBase(TTransportBase): + def _resolveAddr(self): + if self._unix_socket is not None: + return [(socket.AF_UNIX, socket.SOCK_STREAM, None, None, + self._unix_socket)] + else: + return socket.getaddrinfo(self.host, + self.port, + socket.AF_UNSPEC, + socket.SOCK_STREAM, + 0, + socket.AI_PASSIVE | socket.AI_ADDRCONFIG) + + def close(self): + if self.handle: + self.handle.close() + self.handle = None + + +class TSocket(TSocketBase): + """Socket implementation of TTransport base.""" + + def __init__(self, host='localhost', port=9090, unix_socket=None): + """Initialize a TSocket + + @param host(str) The host to connect to. + @param port(int) The (TCP) port to connect to. + @param unix_socket(str) The filename of a unix socket to connect to. + (host and port will be ignored.) + """ + self.host = host + self.port = port + self.handle = None + self._unix_socket = unix_socket + self._timeout = None + + def setHandle(self, h): + self.handle = h + + def isOpen(self): + return self.handle is not None + + def setTimeout(self, ms): + if ms is None: + self._timeout = None + else: + self._timeout = ms / 1000.0 + + if self.handle is not None: + self.handle.settimeout(self._timeout) + + def open(self): + try: + res0 = self._resolveAddr() + for res in res0: + self.handle = socket.socket(res[0], res[1]) + self.handle.settimeout(self._timeout) + try: + self.handle.connect(res[4]) + except socket.error, e: + if res is not res0[-1]: + continue + else: + raise e + break + except socket.error, e: + if self._unix_socket: + message = 'Could not connect to socket %s' % self._unix_socket + else: + message = 'Could not connect to %s:%d' % (self.host, self.port) + raise TTransportException(type=TTransportException.NOT_OPEN, + message=message) + + def read(self, sz): + try: + buff = self.handle.recv(sz) + except socket.error, e: + if (e.args[0] == errno.ECONNRESET and + (sys.platform == 'darwin' or sys.platform.startswith('freebsd'))): + # freebsd and Mach don't follow POSIX semantic of recv + # and fail with ECONNRESET if peer performed shutdown. + # See corresponding comment and code in TSocket::read() + # in lib/cpp/src/transport/TSocket.cpp. + self.close() + # Trigger the check to raise the END_OF_FILE exception below. + buff = '' + else: + raise + if len(buff) == 0: + raise TTransportException(type=TTransportException.END_OF_FILE, + message='TSocket read 0 bytes') + return buff + + def write(self, buff): + if not self.handle: + raise TTransportException(type=TTransportException.NOT_OPEN, + message='Transport not open') + sent = 0 + have = len(buff) + while sent < have: + plus = self.handle.send(buff) + if plus == 0: + raise TTransportException(type=TTransportException.END_OF_FILE, + message='TSocket sent 0 bytes') + sent += plus + buff = buff[plus:] + + def flush(self): + pass + + +class TServerSocket(TSocketBase, TServerTransportBase): + """Socket implementation of TServerTransport base.""" + + def __init__(self, host=None, port=9090, unix_socket=None): + self.host = host + self.port = port + self._unix_socket = unix_socket + self.handle = None + + def listen(self): + res0 = self._resolveAddr() + for res in res0: + if res[0] is socket.AF_INET6 or res is res0[-1]: + break + + # We need remove the old unix socket if the file exists and + # nobody is listening on it. + if self._unix_socket: + tmp = socket.socket(res[0], res[1]) + try: + tmp.connect(res[4]) + except socket.error, err: + eno, message = err.args + if eno == errno.ECONNREFUSED: + os.unlink(res[4]) + + self.handle = socket.socket(res[0], res[1]) + self.handle.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + if hasattr(self.handle, 'settimeout'): + self.handle.settimeout(None) + self.handle.bind(res[4]) + self.handle.listen(128) + + def accept(self): + client, addr = self.handle.accept() + result = TSocket() + result.setHandle(client) + return result diff --git a/pyload/lib/thrift/transport/TTransport.py b/pyload/lib/thrift/transport/TTransport.py new file mode 100644 index 000000000..4481371a6 --- /dev/null +++ b/pyload/lib/thrift/transport/TTransport.py @@ -0,0 +1,330 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +from cStringIO import StringIO +from struct import pack, unpack +from thrift.Thrift import TException + + +class TTransportException(TException): + """Custom Transport Exception class""" + + UNKNOWN = 0 + NOT_OPEN = 1 + ALREADY_OPEN = 2 + TIMED_OUT = 3 + END_OF_FILE = 4 + + def __init__(self, type=UNKNOWN, message=None): + TException.__init__(self, message) + self.type = type + + +class TTransportBase: + """Base class for Thrift transport layer.""" + + def isOpen(self): + pass + + def open(self): + pass + + def close(self): + pass + + def read(self, sz): + pass + + def readAll(self, sz): + buff = '' + have = 0 + while (have < sz): + chunk = self.read(sz - have) + have += len(chunk) + buff += chunk + + if len(chunk) == 0: + raise EOFError() + + return buff + + def write(self, buf): + pass + + def flush(self): + pass + + +# This class should be thought of as an interface. +class CReadableTransport: + """base class for transports that are readable from C""" + + # TODO(dreiss): Think about changing this interface to allow us to use + # a (Python, not c) StringIO instead, because it allows + # you to write after reading. + + # NOTE: This is a classic class, so properties will NOT work + # correctly for setting. + @property + def cstringio_buf(self): + """A cStringIO buffer that contains the current chunk we are reading.""" + pass + + def cstringio_refill(self, partialread, reqlen): + """Refills cstringio_buf. + + Returns the currently used buffer (which can but need not be the same as + the old cstringio_buf). partialread is what the C code has read from the + buffer, and should be inserted into the buffer before any more reads. The + return value must be a new, not borrowed reference. Something along the + lines of self._buf should be fine. + + If reqlen bytes can't be read, throw EOFError. + """ + pass + + +class TServerTransportBase: + """Base class for Thrift server transports.""" + + def listen(self): + pass + + def accept(self): + pass + + def close(self): + pass + + +class TTransportFactoryBase: + """Base class for a Transport Factory""" + + def getTransport(self, trans): + return trans + + +class TBufferedTransportFactory: + """Factory transport that builds buffered transports""" + + def getTransport(self, trans): + buffered = TBufferedTransport(trans) + return buffered + + +class TBufferedTransport(TTransportBase, CReadableTransport): + """Class that wraps another transport and buffers its I/O. + + The implementation uses a (configurable) fixed-size read buffer + but buffers all writes until a flush is performed. + """ + DEFAULT_BUFFER = 4096 + + def __init__(self, trans, rbuf_size=DEFAULT_BUFFER): + self.__trans = trans + self.__wbuf = StringIO() + self.__rbuf = StringIO("") + self.__rbuf_size = rbuf_size + + def isOpen(self): + return self.__trans.isOpen() + + def open(self): + return self.__trans.open() + + def close(self): + return self.__trans.close() + + def read(self, sz): + ret = self.__rbuf.read(sz) + if len(ret) != 0: + return ret + + self.__rbuf = StringIO(self.__trans.read(max(sz, self.__rbuf_size))) + return self.__rbuf.read(sz) + + def write(self, buf): + self.__wbuf.write(buf) + + def flush(self): + out = self.__wbuf.getvalue() + # reset wbuf before write/flush to preserve state on underlying failure + self.__wbuf = StringIO() + self.__trans.write(out) + self.__trans.flush() + + # Implement the CReadableTransport interface. + @property + def cstringio_buf(self): + return self.__rbuf + + def cstringio_refill(self, partialread, reqlen): + retstring = partialread + if reqlen < self.__rbuf_size: + # try to make a read of as much as we can. + retstring += self.__trans.read(self.__rbuf_size) + + # but make sure we do read reqlen bytes. + if len(retstring) < reqlen: + retstring += self.__trans.readAll(reqlen - len(retstring)) + + self.__rbuf = StringIO(retstring) + return self.__rbuf + + +class TMemoryBuffer(TTransportBase, CReadableTransport): + """Wraps a cStringIO object as a TTransport. + + NOTE: Unlike the C++ version of this class, you cannot write to it + then immediately read from it. If you want to read from a + TMemoryBuffer, you must either pass a string to the constructor. + TODO(dreiss): Make this work like the C++ version. + """ + + def __init__(self, value=None): + """value -- a value to read from for stringio + + If value is set, this will be a transport for reading, + otherwise, it is for writing""" + if value is not None: + self._buffer = StringIO(value) + else: + self._buffer = StringIO() + + def isOpen(self): + return not self._buffer.closed + + def open(self): + pass + + def close(self): + self._buffer.close() + + def read(self, sz): + return self._buffer.read(sz) + + def write(self, buf): + self._buffer.write(buf) + + def flush(self): + pass + + def getvalue(self): + return self._buffer.getvalue() + + # Implement the CReadableTransport interface. + @property + def cstringio_buf(self): + return self._buffer + + def cstringio_refill(self, partialread, reqlen): + # only one shot at reading... + raise EOFError() + + +class TFramedTransportFactory: + """Factory transport that builds framed transports""" + + def getTransport(self, trans): + framed = TFramedTransport(trans) + return framed + + +class TFramedTransport(TTransportBase, CReadableTransport): + """Class that wraps another transport and frames its I/O when writing.""" + + def __init__(self, trans,): + self.__trans = trans + self.__rbuf = StringIO() + self.__wbuf = StringIO() + + def isOpen(self): + return self.__trans.isOpen() + + def open(self): + return self.__trans.open() + + def close(self): + return self.__trans.close() + + def read(self, sz): + ret = self.__rbuf.read(sz) + if len(ret) != 0: + return ret + + self.readFrame() + return self.__rbuf.read(sz) + + def readFrame(self): + buff = self.__trans.readAll(4) + sz, = unpack('!i', buff) + self.__rbuf = StringIO(self.__trans.readAll(sz)) + + def write(self, buf): + self.__wbuf.write(buf) + + def flush(self): + wout = self.__wbuf.getvalue() + wsz = len(wout) + # reset wbuf before write/flush to preserve state on underlying failure + self.__wbuf = StringIO() + # N.B.: Doing this string concatenation is WAY cheaper than making + # two separate calls to the underlying socket object. Socket writes in + # Python turn out to be REALLY expensive, but it seems to do a pretty + # good job of managing string buffer operations without excessive copies + buf = pack("!i", wsz) + wout + self.__trans.write(buf) + self.__trans.flush() + + # Implement the CReadableTransport interface. + @property + def cstringio_buf(self): + return self.__rbuf + + def cstringio_refill(self, prefix, reqlen): + # self.__rbuf will already be empty here because fastbinary doesn't + # ask for a refill until the previous buffer is empty. Therefore, + # we can start reading new frames immediately. + while len(prefix) < reqlen: + self.readFrame() + prefix += self.__rbuf.getvalue() + self.__rbuf = StringIO(prefix) + return self.__rbuf + + +class TFileObjectTransport(TTransportBase): + """Wraps a file-like object to make it work as a Thrift transport.""" + + def __init__(self, fileobj): + self.fileobj = fileobj + + def isOpen(self): + return True + + def close(self): + self.fileobj.close() + + def read(self, sz): + return self.fileobj.read(sz) + + def write(self, buf): + self.fileobj.write(buf) + + def flush(self): + self.fileobj.flush() diff --git a/pyload/lib/thrift/transport/TTwisted.py b/pyload/lib/thrift/transport/TTwisted.py new file mode 100644 index 000000000..3ce3eb220 --- /dev/null +++ b/pyload/lib/thrift/transport/TTwisted.py @@ -0,0 +1,221 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +from cStringIO import StringIO + +from zope.interface import implements, Interface, Attribute +from twisted.internet.protocol import Protocol, ServerFactory, ClientFactory, \ + connectionDone +from twisted.internet import defer +from twisted.protocols import basic +from twisted.python import log +from twisted.web import server, resource, http + +from thrift.transport import TTransport + + +class TMessageSenderTransport(TTransport.TTransportBase): + + def __init__(self): + self.__wbuf = StringIO() + + def write(self, buf): + self.__wbuf.write(buf) + + def flush(self): + msg = self.__wbuf.getvalue() + self.__wbuf = StringIO() + self.sendMessage(msg) + + def sendMessage(self, message): + raise NotImplementedError + + +class TCallbackTransport(TMessageSenderTransport): + + def __init__(self, func): + TMessageSenderTransport.__init__(self) + self.func = func + + def sendMessage(self, message): + self.func(message) + + +class ThriftClientProtocol(basic.Int32StringReceiver): + + MAX_LENGTH = 2 ** 31 - 1 + + def __init__(self, client_class, iprot_factory, oprot_factory=None): + self._client_class = client_class + self._iprot_factory = iprot_factory + if oprot_factory is None: + self._oprot_factory = iprot_factory + else: + self._oprot_factory = oprot_factory + + self.recv_map = {} + self.started = defer.Deferred() + + def dispatch(self, msg): + self.sendString(msg) + + def connectionMade(self): + tmo = TCallbackTransport(self.dispatch) + self.client = self._client_class(tmo, self._oprot_factory) + self.started.callback(self.client) + + def connectionLost(self, reason=connectionDone): + for k, v in self.client._reqs.iteritems(): + tex = TTransport.TTransportException( + type=TTransport.TTransportException.END_OF_FILE, + message='Connection closed') + v.errback(tex) + + def stringReceived(self, frame): + tr = TTransport.TMemoryBuffer(frame) + iprot = self._iprot_factory.getProtocol(tr) + (fname, mtype, rseqid) = iprot.readMessageBegin() + + try: + method = self.recv_map[fname] + except KeyError: + method = getattr(self.client, 'recv_' + fname) + self.recv_map[fname] = method + + method(iprot, mtype, rseqid) + + +class ThriftServerProtocol(basic.Int32StringReceiver): + + MAX_LENGTH = 2 ** 31 - 1 + + def dispatch(self, msg): + self.sendString(msg) + + def processError(self, error): + self.transport.loseConnection() + + def processOk(self, _, tmo): + msg = tmo.getvalue() + + if len(msg) > 0: + self.dispatch(msg) + + def stringReceived(self, frame): + tmi = TTransport.TMemoryBuffer(frame) + tmo = TTransport.TMemoryBuffer() + + iprot = self.factory.iprot_factory.getProtocol(tmi) + oprot = self.factory.oprot_factory.getProtocol(tmo) + + d = self.factory.processor.process(iprot, oprot) + d.addCallbacks(self.processOk, self.processError, + callbackArgs=(tmo,)) + + +class IThriftServerFactory(Interface): + + processor = Attribute("Thrift processor") + + iprot_factory = Attribute("Input protocol factory") + + oprot_factory = Attribute("Output protocol factory") + + +class IThriftClientFactory(Interface): + + client_class = Attribute("Thrift client class") + + iprot_factory = Attribute("Input protocol factory") + + oprot_factory = Attribute("Output protocol factory") + + +class ThriftServerFactory(ServerFactory): + + implements(IThriftServerFactory) + + protocol = ThriftServerProtocol + + def __init__(self, processor, iprot_factory, oprot_factory=None): + self.processor = processor + self.iprot_factory = iprot_factory + if oprot_factory is None: + self.oprot_factory = iprot_factory + else: + self.oprot_factory = oprot_factory + + +class ThriftClientFactory(ClientFactory): + + implements(IThriftClientFactory) + + protocol = ThriftClientProtocol + + def __init__(self, client_class, iprot_factory, oprot_factory=None): + self.client_class = client_class + self.iprot_factory = iprot_factory + if oprot_factory is None: + self.oprot_factory = iprot_factory + else: + self.oprot_factory = oprot_factory + + def buildProtocol(self, addr): + p = self.protocol(self.client_class, self.iprot_factory, + self.oprot_factory) + p.factory = self + return p + + +class ThriftResource(resource.Resource): + + allowedMethods = ('POST',) + + def __init__(self, processor, inputProtocolFactory, + outputProtocolFactory=None): + resource.Resource.__init__(self) + self.inputProtocolFactory = inputProtocolFactory + if outputProtocolFactory is None: + self.outputProtocolFactory = inputProtocolFactory + else: + self.outputProtocolFactory = outputProtocolFactory + self.processor = processor + + def getChild(self, path, request): + return self + + def _cbProcess(self, _, request, tmo): + msg = tmo.getvalue() + request.setResponseCode(http.OK) + request.setHeader("content-type", "application/x-thrift") + request.write(msg) + request.finish() + + def render_POST(self, request): + request.content.seek(0, 0) + data = request.content.read() + tmi = TTransport.TMemoryBuffer(data) + tmo = TTransport.TMemoryBuffer() + + iprot = self.inputProtocolFactory.getProtocol(tmi) + oprot = self.outputProtocolFactory.getProtocol(tmo) + + d = self.processor.process(iprot, oprot) + d.addCallback(self._cbProcess, request, tmo) + return server.NOT_DONE_YET diff --git a/pyload/lib/thrift/transport/TZlibTransport.py b/pyload/lib/thrift/transport/TZlibTransport.py new file mode 100644 index 000000000..a2f42a5d2 --- /dev/null +++ b/pyload/lib/thrift/transport/TZlibTransport.py @@ -0,0 +1,248 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +"""TZlibTransport provides a compressed transport and transport factory +class, using the python standard library zlib module to implement +data compression. +""" + +from __future__ import division +import zlib +from cStringIO import StringIO +from TTransport import TTransportBase, CReadableTransport + + +class TZlibTransportFactory(object): + """Factory transport that builds zlib compressed transports. + + This factory caches the last single client/transport that it was passed + and returns the same TZlibTransport object that was created. + + This caching means the TServer class will get the _same_ transport + object for both input and output transports from this factory. + (For non-threaded scenarios only, since the cache only holds one object) + + The purpose of this caching is to allocate only one TZlibTransport where + only one is really needed (since it must have separate read/write buffers), + and makes the statistics from getCompSavings() and getCompRatio() + easier to understand. + """ + # class scoped cache of last transport given and zlibtransport returned + _last_trans = None + _last_z = None + + def getTransport(self, trans, compresslevel=9): + """Wrap a transport, trans, with the TZlibTransport + compressed transport class, returning a new + transport to the caller. + + @param compresslevel: The zlib compression level, ranging + from 0 (no compression) to 9 (best compression). Defaults to 9. + @type compresslevel: int + + This method returns a TZlibTransport which wraps the + passed C{trans} TTransport derived instance. + """ + if trans == self._last_trans: + return self._last_z + ztrans = TZlibTransport(trans, compresslevel) + self._last_trans = trans + self._last_z = ztrans + return ztrans + + +class TZlibTransport(TTransportBase, CReadableTransport): + """Class that wraps a transport with zlib, compressing writes + and decompresses reads, using the python standard + library zlib module. + """ + # Read buffer size for the python fastbinary C extension, + # the TBinaryProtocolAccelerated class. + DEFAULT_BUFFSIZE = 4096 + + def __init__(self, trans, compresslevel=9): + """Create a new TZlibTransport, wrapping C{trans}, another + TTransport derived object. + + @param trans: A thrift transport object, i.e. a TSocket() object. + @type trans: TTransport + @param compresslevel: The zlib compression level, ranging + from 0 (no compression) to 9 (best compression). Default is 9. + @type compresslevel: int + """ + self.__trans = trans + self.compresslevel = compresslevel + self.__rbuf = StringIO() + self.__wbuf = StringIO() + self._init_zlib() + self._init_stats() + + def _reinit_buffers(self): + """Internal method to initialize/reset the internal StringIO objects + for read and write buffers. + """ + self.__rbuf = StringIO() + self.__wbuf = StringIO() + + def _init_stats(self): + """Internal method to reset the internal statistics counters + for compression ratios and bandwidth savings. + """ + self.bytes_in = 0 + self.bytes_out = 0 + self.bytes_in_comp = 0 + self.bytes_out_comp = 0 + + def _init_zlib(self): + """Internal method for setting up the zlib compression and + decompression objects. + """ + self._zcomp_read = zlib.decompressobj() + self._zcomp_write = zlib.compressobj(self.compresslevel) + + def getCompRatio(self): + """Get the current measured compression ratios (in,out) from + this transport. + + Returns a tuple of: + (inbound_compression_ratio, outbound_compression_ratio) + + The compression ratios are computed as: + compressed / uncompressed + + E.g., data that compresses by 10x will have a ratio of: 0.10 + and data that compresses to half of ts original size will + have a ratio of 0.5 + + None is returned if no bytes have yet been processed in + a particular direction. + """ + r_percent, w_percent = (None, None) + if self.bytes_in > 0: + r_percent = self.bytes_in_comp / self.bytes_in + if self.bytes_out > 0: + w_percent = self.bytes_out_comp / self.bytes_out + return (r_percent, w_percent) + + def getCompSavings(self): + """Get the current count of saved bytes due to data + compression. + + Returns a tuple of: + (inbound_saved_bytes, outbound_saved_bytes) + + Note: if compression is actually expanding your + data (only likely with very tiny thrift objects), then + the values returned will be negative. + """ + r_saved = self.bytes_in - self.bytes_in_comp + w_saved = self.bytes_out - self.bytes_out_comp + return (r_saved, w_saved) + + def isOpen(self): + """Return the underlying transport's open status""" + return self.__trans.isOpen() + + def open(self): + """Open the underlying transport""" + self._init_stats() + return self.__trans.open() + + def listen(self): + """Invoke the underlying transport's listen() method""" + self.__trans.listen() + + def accept(self): + """Accept connections on the underlying transport""" + return self.__trans.accept() + + def close(self): + """Close the underlying transport,""" + self._reinit_buffers() + self._init_zlib() + return self.__trans.close() + + def read(self, sz): + """Read up to sz bytes from the decompressed bytes buffer, and + read from the underlying transport if the decompression + buffer is empty. + """ + ret = self.__rbuf.read(sz) + if len(ret) > 0: + return ret + # keep reading from transport until something comes back + while True: + if self.readComp(sz): + break + ret = self.__rbuf.read(sz) + return ret + + def readComp(self, sz): + """Read compressed data from the underlying transport, then + decompress it and append it to the internal StringIO read buffer + """ + zbuf = self.__trans.read(sz) + zbuf = self._zcomp_read.unconsumed_tail + zbuf + buf = self._zcomp_read.decompress(zbuf) + self.bytes_in += len(zbuf) + self.bytes_in_comp += len(buf) + old = self.__rbuf.read() + self.__rbuf = StringIO(old + buf) + if len(old) + len(buf) == 0: + return False + return True + + def write(self, buf): + """Write some bytes, putting them into the internal write + buffer for eventual compression. + """ + self.__wbuf.write(buf) + + def flush(self): + """Flush any queued up data in the write buffer and ensure the + compression buffer is flushed out to the underlying transport + """ + wout = self.__wbuf.getvalue() + if len(wout) > 0: + zbuf = self._zcomp_write.compress(wout) + self.bytes_out += len(wout) + self.bytes_out_comp += len(zbuf) + else: + zbuf = '' + ztail = self._zcomp_write.flush(zlib.Z_SYNC_FLUSH) + self.bytes_out_comp += len(ztail) + if (len(zbuf) + len(ztail)) > 0: + self.__wbuf = StringIO() + self.__trans.write(zbuf + ztail) + self.__trans.flush() + + @property + def cstringio_buf(self): + """Implement the CReadableTransport interface""" + return self.__rbuf + + def cstringio_refill(self, partialread, reqlen): + """Implement the CReadableTransport interface for refill""" + retstring = partialread + if reqlen < self.DEFAULT_BUFFSIZE: + retstring += self.read(self.DEFAULT_BUFFSIZE) + while len(retstring) < reqlen: + retstring += self.read(reqlen - len(retstring)) + self.__rbuf = StringIO(retstring) + return self.__rbuf diff --git a/pyload/lib/thrift/transport/__init__.py b/pyload/lib/thrift/transport/__init__.py new file mode 100644 index 000000000..c9596d9a6 --- /dev/null +++ b/pyload/lib/thrift/transport/__init__.py @@ -0,0 +1,20 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +__all__ = ['TTransport', 'TSocket', 'THttpClient', 'TZlibTransport'] |