summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorGravatar mkaay <mkaay@mkaay.de> 2011-02-15 18:06:39 +0100
committerGravatar mkaay <mkaay@mkaay.de> 2011-02-15 18:06:39 +0100
commit8b45dddde21dd2fbe24b492823085e1870c735bc (patch)
tree36d68d27246e8298129b16a40f7ca3230a378663
parentremoved debug (diff)
downloadpyload-8b45dddde21dd2fbe24b492823085e1870c735bc.tar.xz
thrift ssl connection fix (closes #242)
-rw-r--r--module/remote/thriftbackend/Socket.py25
1 files changed, 22 insertions, 3 deletions
diff --git a/module/remote/thriftbackend/Socket.py b/module/remote/thriftbackend/Socket.py
index 33daab4c0..a492be59e 100644
--- a/module/remote/thriftbackend/Socket.py
+++ b/module/remote/thriftbackend/Socket.py
@@ -6,6 +6,8 @@ import errno
from thrift.transport.TSocket import TSocket, TServerSocket, TTransportException
+WantReadError = Exception #overwritten when ssl is used
+
class SecureSocketConnection:
def __init__(self, connection):
self.__dict__["connection"] = connection
@@ -22,6 +24,18 @@ class SecureSocketConnection:
def accept(self):
connection, address = self.__dict__["connection"].accept()
return SecureSocketConnection(connection), address
+
+ def send(self, buff):
+ try:
+ return self.__dict__["connection"].send(buff)
+ except WantReadError:
+ return self.send(buff)
+
+ def recv(self, buff):
+ try:
+ return self.__dict__["connection"].recv(buff)
+ except WantReadError:
+ return self.recv(buff)
class Socket(TSocket):
def __init__(self, host='localhost', port=7228, ssl=False):
@@ -31,8 +45,11 @@ class Socket(TSocket):
def open(self):
if self.ssl:
SSL = __import__("OpenSSL", globals(), locals(), "SSL", -1).SSL
+ WantReadError = SSL.WantReadError
ctx = SSL.Context(SSL.SSLv23_METHOD)
- self.handle = SecureSocketConnection(SSL.Connection(ctx, socket.socket(socket.AF_INET, socket.SOCK_STREAM)))
+ c = SSL.Connection(ctx, socket.socket(socket.AF_INET, socket.SOCK_STREAM))
+ c.set_connect_state()
+ self.handle = SecureSocketConnection(c)
else:
self.handle = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
@@ -67,7 +84,7 @@ class Socket(TSocket):
raise
if not len(buff):
- raise TTransportException(type=TTransportException.END_OF_FILE, message='TSocket read 0 bytes')
+ raise TTransportException(type=TTransportException.END_OF_FILE, message='TSocket read 0 bytes')
return buff
@@ -82,11 +99,13 @@ class ServerSocket(TServerSocket, Socket):
def listen(self):
if self.cert and self.key:
SSL = __import__("OpenSSL", globals(), locals(), "SSL", -1).SSL
+ WantReadError = SSL.WantReadError
ctx = SSL.Context(SSL.SSLv23_METHOD)
ctx.use_privatekey_file(self.key)
ctx.use_certificate_file(self.cert)
tmpConnection = SSL.Connection(ctx, socket.socket(socket.AF_INET, socket.SOCK_STREAM))
+ tmpConnection.set_accept_state()
self.handle = SecureSocketConnection(tmpConnection)
else:
@@ -95,7 +114,7 @@ class ServerSocket(TServerSocket, Socket):
self.handle.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
if hasattr(self.handle, 'set_timeout'):
- self.handle.set_timeout(None)
+ self.handle.set_timeout(None)
self.handle.bind((self.host, self.port))
self.handle.listen(128)