diff options
author | mkaay <mkaay@mkaay.de> | 2011-02-15 18:06:39 +0100 |
---|---|---|
committer | mkaay <mkaay@mkaay.de> | 2011-02-15 18:06:39 +0100 |
commit | 8b45dddde21dd2fbe24b492823085e1870c735bc (patch) | |
tree | 36d68d27246e8298129b16a40f7ca3230a378663 /module/remote | |
parent | removed debug (diff) | |
download | pyload-8b45dddde21dd2fbe24b492823085e1870c735bc.tar.xz |
thrift ssl connection fix (closes #242)
Diffstat (limited to 'module/remote')
-rw-r--r-- | module/remote/thriftbackend/Socket.py | 25 |
1 files changed, 22 insertions, 3 deletions
diff --git a/module/remote/thriftbackend/Socket.py b/module/remote/thriftbackend/Socket.py index 33daab4c0..a492be59e 100644 --- a/module/remote/thriftbackend/Socket.py +++ b/module/remote/thriftbackend/Socket.py @@ -6,6 +6,8 @@ import errno from thrift.transport.TSocket import TSocket, TServerSocket, TTransportException +WantReadError = Exception #overwritten when ssl is used + class SecureSocketConnection: def __init__(self, connection): self.__dict__["connection"] = connection @@ -22,6 +24,18 @@ class SecureSocketConnection: def accept(self): connection, address = self.__dict__["connection"].accept() return SecureSocketConnection(connection), address + + def send(self, buff): + try: + return self.__dict__["connection"].send(buff) + except WantReadError: + return self.send(buff) + + def recv(self, buff): + try: + return self.__dict__["connection"].recv(buff) + except WantReadError: + return self.recv(buff) class Socket(TSocket): def __init__(self, host='localhost', port=7228, ssl=False): @@ -31,8 +45,11 @@ class Socket(TSocket): def open(self): if self.ssl: SSL = __import__("OpenSSL", globals(), locals(), "SSL", -1).SSL + WantReadError = SSL.WantReadError ctx = SSL.Context(SSL.SSLv23_METHOD) - self.handle = SecureSocketConnection(SSL.Connection(ctx, socket.socket(socket.AF_INET, socket.SOCK_STREAM))) + c = SSL.Connection(ctx, socket.socket(socket.AF_INET, socket.SOCK_STREAM)) + c.set_connect_state() + self.handle = SecureSocketConnection(c) else: self.handle = socket.socket(socket.AF_INET, socket.SOCK_STREAM) @@ -67,7 +84,7 @@ class Socket(TSocket): raise if not len(buff): - raise TTransportException(type=TTransportException.END_OF_FILE, message='TSocket read 0 bytes') + raise TTransportException(type=TTransportException.END_OF_FILE, message='TSocket read 0 bytes') return buff @@ -82,11 +99,13 @@ class ServerSocket(TServerSocket, Socket): def listen(self): if self.cert and self.key: SSL = __import__("OpenSSL", globals(), locals(), "SSL", -1).SSL + WantReadError = SSL.WantReadError ctx = SSL.Context(SSL.SSLv23_METHOD) ctx.use_privatekey_file(self.key) ctx.use_certificate_file(self.cert) tmpConnection = SSL.Connection(ctx, socket.socket(socket.AF_INET, socket.SOCK_STREAM)) + tmpConnection.set_accept_state() self.handle = SecureSocketConnection(tmpConnection) else: @@ -95,7 +114,7 @@ class ServerSocket(TServerSocket, Socket): self.handle.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) if hasattr(self.handle, 'set_timeout'): - self.handle.set_timeout(None) + self.handle.set_timeout(None) self.handle.bind((self.host, self.port)) self.handle.listen(128) |