From 899159508c24355903e37b31a009168a0129ca0d Mon Sep 17 00:00:00 2001
From: RaNaN <Mast3rRaNaN@hotmail.de>
Date: Fri, 11 Feb 2011 21:30:41 +0100
Subject: Thrift with SSL

---
 module/remote/RemoteManager.py            |  1 +
 module/remote/ThriftBackend.py            | 19 ++++++--
 module/remote/thriftbackend/Socket.py     | 78 +++++++++++++++++++++++++++++--
 module/remote/thriftbackend/ThriftTest.py | 10 ++--
 4 files changed, 96 insertions(+), 12 deletions(-)

(limited to 'module/remote')

diff --git a/module/remote/RemoteManager.py b/module/remote/RemoteManager.py
index 5edc2fbdf..b66ed75e5 100644
--- a/module/remote/RemoteManager.py
+++ b/module/remote/RemoteManager.py
@@ -69,6 +69,7 @@ class RemoteManager():
             else:
                 backend.start()
                 self.backends.append(backend)
+
     def checkAuth(self, user, password, remoteip=None):
         if self.core.config["remote"]["nolocalauth"] and remoteip == "127.0.0.1":
             return True
diff --git a/module/remote/ThriftBackend.py b/module/remote/ThriftBackend.py
index d7e59f7fa..ab262cf76 100644
--- a/module/remote/ThriftBackend.py
+++ b/module/remote/ThriftBackend.py
@@ -13,8 +13,10 @@
     You should have received a copy of the GNU General Public License
     along with this program; if not, see <http://www.gnu.org/licenses/>.
 
-    @author: mkaay
+    @author: mkaay, RaNaN
 """
+from os.path import exists
+
 from module.remote.RemoteManager import BackendBase
 
 from thriftbackend.Handler import Handler
@@ -23,19 +25,28 @@ from thriftbackend.Protocol import ProtocolFactory
 from thriftbackend.Socket import ServerSocket
 
 from thrift.transport import TTransport
-
 from thrift.server import TServer
 
 class ThriftBackend(BackendBase):
     def setup(self):
         handler = Handler(self)
         processor = Processor(handler)
-        transport = ServerSocket(7228)
+
+        key = None
+        cert = None
+
+        if self.core.config['ssl']['activated']:
+            if exists(self.core.config['ssl']['cert']) and exists(self.core.config['ssl']['key']):
+                self.core.log.info(_("Using SSL ThriftBackend"))
+                key = self.core.config['ssl']['key']
+                cert = self.core.config['ssl']['cert']
+
+        transport = ServerSocket(7228, self.core.config["remote"]["listenaddr"], key, cert)
 
         tfactory = TTransport.TBufferedTransportFactory()
         pfactory = ProtocolFactory()
         
-        self.server = TServer.TSimpleServer(processor, transport, tfactory, pfactory)
+        self.server = TServer.TThreadedServer(processor, transport, tfactory, pfactory)
         #self.server = TNonblockingServer.TNonblockingServer(processor, transport, tfactory, pfactory)
         
         #server = TServer.TThreadPoolServer(processor, transport, tfactory, pfactory)
diff --git a/module/remote/thriftbackend/Socket.py b/module/remote/thriftbackend/Socket.py
index 6ee850d07..cfb8b08c9 100644
--- a/module/remote/thriftbackend/Socket.py
+++ b/module/remote/thriftbackend/Socket.py
@@ -1,9 +1,27 @@
 # -*- coding: utf-8 -*-
 
+import sys
 import socket
+import errno
 
 from thrift.transport.TSocket import TSocket, TServerSocket, TTransportException
 
+class SecureSocketConnection:
+    def __init__(self, connection):
+        self.__dict__["connection"] = connection
+
+    def __getattr__(self, name):
+        return getattr(self.__dict__["connection"], name)
+
+    def __setattr__(self, name, value):
+        setattr(self.__dict__["connection"], name, value)
+
+    def shutdown(self, how=1):
+        self.__dict__["connection"].shutdown()
+
+    def accept(self):
+        connection, address = self.__dict__["connection"].accept()
+        return SecureSocketConnection(connection), address
 
 class Socket(TSocket):
     def __init__(self, host='localhost', port=7228, ssl=False):
@@ -11,21 +29,75 @@ class Socket(TSocket):
         self.ssl = ssl
 
     def open(self):
-        self.handle = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+        if self.ssl:
+            SSL = __import__("OpenSSL", globals(), locals(), "SSL", -1).SSL
+            ctx = SSL.Context(SSL.SSLv23_METHOD)
+            self.handle = SecureSocketConnection(SSL.Connection(ctx, socket.socket(socket.AF_INET, socket.SOCK_STREAM)))
+        else:
+            self.handle = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+
+        #errno 104 connection reset
+
         self.handle.settimeout(self._timeout)
         self.handle.connect((self.host, self.port))
 
+    def read(self, sz):
+        try:
+            buff = self.handle.recv(sz)
+        except socket.error, e:
+            if (e.args[0] == errno.ECONNRESET and
+                (sys.platform == 'darwin' or sys.platform.startswith('freebsd'))):
+                # freebsd and Mach don't follow POSIX semantic of recv
+                # and fail with ECONNRESET if peer performed shutdown.
+                # See corresponding comment and code in TSocket::read()
+                # in lib/cpp/src/transport/TSocket.cpp.
+                self.close()
+                # Trigger the check to raise the END_OF_FILE exception below.
+                buff = ''
+            else:
+                raise
+        except Exception, e:
+            # SSL connection was closed
+            if e.args == (-1, 'Unexpected EOF'):
+                buff = ''
+            else:
+                raise
+            
+        if not len(buff):
+          raise TTransportException(type=TTransportException.END_OF_FILE, message='TSocket read 0 bytes')
+        return buff
+
 
 class ServerSocket(TServerSocket, Socket):
     def __init__(self, port=7228, host="0.0.0.0", key="", cert=""):
         self.host = host
         self.port = port
+        self.key = key
+        self.cert = cert
         self.handle = None
 
     def listen(self):
-        self.handle = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+        if self.cert and self.key:
+            SSL = __import__("OpenSSL", globals(), locals(), "SSL", -1).SSL
+            ctx = SSL.Context(SSL.SSLv23_METHOD)
+            ctx.use_privatekey_file(self.key)
+            ctx.use_certificate_file(self.cert)
+
+            tmpConnection = SSL.Connection(ctx, socket.socket(socket.AF_INET, socket.SOCK_STREAM))
+            self.handle = SecureSocketConnection(tmpConnection)
+
+        else:
+            self.handle = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+
+
         self.handle.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
         if hasattr(self.handle, 'set_timeout'):
           self.handle.set_timeout(None)
         self.handle.bind((self.host, self.port))
-        self.handle.listen(128)
\ No newline at end of file
+        self.handle.listen(128)
+
+    def accept(self):
+        client, addr = self.handle.accept()
+        result = Socket()
+        result.setHandle(client)
+        return result
\ No newline at end of file
diff --git a/module/remote/thriftbackend/ThriftTest.py b/module/remote/thriftbackend/ThriftTest.py
index 587ca184a..8cfeb68d5 100644
--- a/module/remote/thriftbackend/ThriftTest.py
+++ b/module/remote/thriftbackend/ThriftTest.py
@@ -16,13 +16,13 @@ from thrift.transport import TTransport
 
 from Protocol import Protocol
 
-from time import sleep, time
+from time import time
 
 import xmlrpclib
 
 def bench(f, *args, **kwargs):
     s = time()
-    ret = [f(*args, **kwargs) for i in range(0,250)]
+    ret = [f(*args, **kwargs) for i in range(0,100)]
     e = time()
     try:
         print "%s: %f s" % (f._Method__name, e-s)
@@ -48,7 +48,7 @@ print
 try:
 
     # Make socket
-    transport = Socket('localhost', 7228)
+    transport = Socket('localhost', 7228, False)
 
     # Buffering is critical. Raw sockets are very slow
     transport = TTransport.TBufferedTransport(transport)
@@ -74,7 +74,7 @@ try:
     print client.getServerVersion()
     print client.statusServer()
     print client.statusDownloads()
-    q =  client.getQueue()
+    q = client.getQueue()
 
     for p in q:
       data = client.getPackageData(p.pid)
@@ -85,4 +85,4 @@ try:
     transport.close()
 
 except Thrift.TException, tx:
-    print 'ThriftExpection: %s' % (tx.message)
+    print 'ThriftExpection: %s' % tx.message
-- 
cgit v1.2.3