From 8b45dddde21dd2fbe24b492823085e1870c735bc Mon Sep 17 00:00:00 2001
From: mkaay <mkaay@mkaay.de>
Date: Tue, 15 Feb 2011 18:06:39 +0100
Subject: thrift ssl connection fix (closes #242)

---
 module/remote/thriftbackend/Socket.py | 25 ++++++++++++++++++++++---
 1 file changed, 22 insertions(+), 3 deletions(-)

(limited to 'module/remote/thriftbackend')

diff --git a/module/remote/thriftbackend/Socket.py b/module/remote/thriftbackend/Socket.py
index 33daab4c0..a492be59e 100644
--- a/module/remote/thriftbackend/Socket.py
+++ b/module/remote/thriftbackend/Socket.py
@@ -6,6 +6,8 @@ import errno
 
 from thrift.transport.TSocket import TSocket, TServerSocket, TTransportException
 
+WantReadError = Exception #overwritten when ssl is used
+
 class SecureSocketConnection:
     def __init__(self, connection):
         self.__dict__["connection"] = connection
@@ -22,6 +24,18 @@ class SecureSocketConnection:
     def accept(self):
         connection, address = self.__dict__["connection"].accept()
         return SecureSocketConnection(connection), address
+    
+    def send(self, buff):
+        try:
+            return self.__dict__["connection"].send(buff)
+        except WantReadError:
+            return self.send(buff)
+    
+    def recv(self, buff):
+        try:
+            return self.__dict__["connection"].recv(buff)
+        except WantReadError:
+            return self.recv(buff)
 
 class Socket(TSocket):
     def __init__(self, host='localhost', port=7228, ssl=False):
@@ -31,8 +45,11 @@ class Socket(TSocket):
     def open(self):
         if self.ssl:
             SSL = __import__("OpenSSL", globals(), locals(), "SSL", -1).SSL
+            WantReadError = SSL.WantReadError
             ctx = SSL.Context(SSL.SSLv23_METHOD)
-            self.handle = SecureSocketConnection(SSL.Connection(ctx, socket.socket(socket.AF_INET, socket.SOCK_STREAM)))
+            c = SSL.Connection(ctx, socket.socket(socket.AF_INET, socket.SOCK_STREAM))
+            c.set_connect_state()
+            self.handle = SecureSocketConnection(c)
         else:
             self.handle = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
 
@@ -67,7 +84,7 @@ class Socket(TSocket):
                 raise
             
         if not len(buff):
-          raise TTransportException(type=TTransportException.END_OF_FILE, message='TSocket read 0 bytes')
+            raise TTransportException(type=TTransportException.END_OF_FILE, message='TSocket read 0 bytes')
         return buff
 
 
@@ -82,11 +99,13 @@ class ServerSocket(TServerSocket, Socket):
     def listen(self):
         if self.cert and self.key:
             SSL = __import__("OpenSSL", globals(), locals(), "SSL", -1).SSL
+            WantReadError = SSL.WantReadError
             ctx = SSL.Context(SSL.SSLv23_METHOD)
             ctx.use_privatekey_file(self.key)
             ctx.use_certificate_file(self.cert)
 
             tmpConnection = SSL.Connection(ctx, socket.socket(socket.AF_INET, socket.SOCK_STREAM))
+            tmpConnection.set_accept_state()
             self.handle = SecureSocketConnection(tmpConnection)
 
         else:
@@ -95,7 +114,7 @@ class ServerSocket(TServerSocket, Socket):
 
         self.handle.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
         if hasattr(self.handle, 'set_timeout'):
-          self.handle.set_timeout(None)
+            self.handle.set_timeout(None)
         self.handle.bind((self.host, self.port))
         self.handle.listen(128)
 
-- 
cgit v1.2.3