[Zodb-checkins] SVN: ZODB/trunk/src/ZEO/ Refactored the zrpc implementation to:

Jim Fulton jim at zope.com
Thu Jan 28 16:49:35 EST 2010


Log message for revision 108624:
  Refactored the zrpc implementation to:
  
  - Make most server methods return data to clients more quickly by
    writing to client sockets immediately, rather than waiting for the
    asyncore select loop to get around to it.
  
  - More clearly define client and server responsibilities. Machinery
    needed for just clients or just servers has been moved to the
    corresponding connection subclasses.
  
  - Replace the general "flags" argument of many methods with a single
    boolean async flag, since ASYNC was the only flag ever defined.
  

Changed:
  U   ZODB/trunk/src/ZEO/StorageServer.py
  U   ZODB/trunk/src/ZEO/tests/servertesting.py
  U   ZODB/trunk/src/ZEO/tests/testZEO2.py
  U   ZODB/trunk/src/ZEO/zrpc/connection.py

-=-
Modified: ZODB/trunk/src/ZEO/StorageServer.py
===================================================================
--- ZODB/trunk/src/ZEO/StorageServer.py	2010-01-28 21:49:32 UTC (rev 108623)
+++ ZODB/trunk/src/ZEO/StorageServer.py	2010-01-28 21:49:34 UTC (rev 108624)
@@ -1340,10 +1340,10 @@
         self.rpc.callAsyncNoPoll('invalidateTransaction', tid, args)
 
     def serialnos(self, arg):
-        self.rpc.callAsync('serialnos', arg)
+        self.rpc.callAsyncNoPoll('serialnos', arg)
 
     def info(self, arg):
-        self.rpc.callAsync('info', arg)
+        self.rpc.callAsyncNoPoll('info', arg)
 
     def storeBlob(self, oid, serial, blobfilename):
 

Modified: ZODB/trunk/src/ZEO/tests/servertesting.py
===================================================================
--- ZODB/trunk/src/ZEO/tests/servertesting.py	2010-01-28 21:49:32 UTC (rev 108623)
+++ ZODB/trunk/src/ZEO/tests/servertesting.py	2010-01-28 21:49:34 UTC (rev 108624)
@@ -56,3 +56,5 @@
 
     def callAsync(self, meth, *args):
         print self.name, 'callAsync', meth, repr(args)
+
+    callAsyncNoPoll = callAsync

Modified: ZODB/trunk/src/ZEO/tests/testZEO2.py
===================================================================
--- ZODB/trunk/src/ZEO/tests/testZEO2.py	2010-01-28 21:49:32 UTC (rev 108623)
+++ ZODB/trunk/src/ZEO/tests/testZEO2.py	2010-01-28 21:49:34 UTC (rev 108624)
@@ -78,9 +78,10 @@
     >>> zs2.storeBlobEnd(oid, serial, data, '1')
     >>> delay = zs2.vote('1')
 
-    >>> def send_reply(id, reply):
-    ...     print 'reply', id, reply
-    >>> delay.set_sender(1, send_reply, None)
+    >>> class Sender:
+    ...     def send_reply(self, id, reply):
+    ...         print 'reply', id, reply
+    >>> delay.set_sender(1, Sender())
 
     >>> logger = logging.getLogger('ZEO')
     >>> handler = logging.StreamHandler(sys.stdout)

Modified: ZODB/trunk/src/ZEO/zrpc/connection.py
===================================================================
--- ZODB/trunk/src/ZEO/zrpc/connection.py	2010-01-28 21:49:32 UTC (rev 108623)
+++ ZODB/trunk/src/ZEO/zrpc/connection.py	2010-01-28 21:49:34 UTC (rev 108624)
@@ -30,7 +30,6 @@
 import ZODB.POSException
 
 REPLY = ".reply" # message name used for replies
-ASYNC = 1
 
 exception_type_type = type(Exception)
 
@@ -180,34 +179,33 @@
     the mainloop from sending a response.
     """
 
-    def set_sender(self, msgid, send_reply, return_error):
+    def set_sender(self, msgid, conn):
         self.msgid = msgid
-        self.send_reply = send_reply
-        self.return_error = return_error
+        self.conn = conn
 
     def reply(self, obj):
-        self.send_reply(self.msgid, obj)
+        self.conn.send_reply(self.msgid, obj)
 
     def error(self, exc_info):
         log("Error raised in delayed method", logging.ERROR, exc_info=True)
-        self.return_error(self.msgid, 0, *exc_info[:2])
+        self.conn.return_error(self.msgid, *exc_info[:2])
 
 class MTDelay(Delay):
 
     def __init__(self):
         self.ready = threading.Event()
 
-    def set_sender(self, msgid, send_reply, return_error):
-        Delay.set_sender(self, msgid, send_reply, return_error)
+    def set_sender(self, *args):
+        Delay.set_sender(self, *args)
         self.ready.set()
 
     def reply(self, obj):
         self.ready.wait()
-        Delay.reply(self, obj)
+        self.conn.call_from_thread(self.conn.send_reply, self.msgid, obj)
 
     def error(self, exc_info):
         self.ready.wait()
-        Delay.error(self, exc_info)
+        self.conn.call_from_thread(Delay.error, self, exc_info)
 
 # PROTOCOL NEGOTIATION
 #
@@ -304,9 +302,7 @@
     client for that particular call.
 
     The protocol also supports asynchronous calls.  The client does
-    not wait for a return value for an asynchronous call.  The only
-    defined flag is ASYNC.  If a method call message has the ASYNC
-    flag set, the server will raise an exception.
+    not wait for a return value for an asynchronous call.
 
     If a method call raises an Exception, the exception is propagated
     back to the client via the REPLY message.  The client side will
@@ -428,15 +424,6 @@
         # The singleton dict is a socket map containing only this object.
         self._singleton = {self._fileno: self}
 
-        # msgid_lock guards access to msgid
-        self.msgid = 0
-        self.msgid_lock = threading.Lock()
-
-        # replies_cond is used to block when a synchronous call is
-        # waiting for a response
-        self.replies_cond = threading.Condition()
-        self.replies = {}
-
         # waiting_for_reply is used internally to indicate whether
         # a call is in progress.  setting a session key is deferred
         # until after the call returns.
@@ -488,9 +475,6 @@
         self.closed = True
         self.__super_close()
         self.trigger.pull_trigger()
-        self.replies_cond.acquire()
-        self.replies_cond.notifyAll()
-        self.replies_cond.release()
 
     def register_object(self, obj):
         """Register obj as the true object to invoke methods on."""
@@ -537,29 +521,19 @@
         # will raise an exception.  The exception will ultimately
         # result in asycnore calling handle_error(), which will
         # close the connection.
-        msgid, flags, name, args = self.marshal.decode(message)
+        msgid, async, name, args = self.marshal.decode(message)
 
         if debug_zrpc:
-            self.log("recv msg: %s, %s, %s, %s" % (msgid, flags, name,
+            self.log("recv msg: %s, %s, %s, %s" % (msgid, async, name,
                                                    short_repr(args)),
                      level=TRACE)
         if name == REPLY:
-            self.handle_reply(msgid, flags, args)
+            assert not async
+            self.handle_reply(msgid, args)
         else:
-            self.handle_request(msgid, flags, name, args)
+            self.handle_request(msgid, async, name, args)
 
-    def handle_reply(self, msgid, flags, args):
-        if debug_zrpc:
-            self.log("recv reply: %s, %s, %s"
-                     % (msgid, flags, short_repr(args)), level=TRACE)
-        self.replies_cond.acquire()
-        try:
-            self.replies[msgid] = flags, args
-            self.replies_cond.notifyAll()
-        finally:
-            self.replies_cond.release()
-
-    def handle_request(self, msgid, flags, name, args):
+    def handle_request(self, msgid, async, name, args):
         obj = self.obj
 
         if name.startswith('_') or not hasattr(obj, name):
@@ -590,9 +564,14 @@
                 self.log("%s() raised exception: %s" % (name, msg),
                          logging.ERROR, exc_info=True)
             error = sys.exc_info()[:2]
-            return self.return_error(msgid, flags, *error)
+            if async:
+                self.log("Asynchronous call raised exception: %s" % self,
+                         level=logging.ERROR, exc_info=True)
+            else:
+                self.return_error(msgid, *error)
+            return
 
-        if flags & ASYNC:
+        if async:
             if ret is not None:
                 raise ZRPCError("async method %s returned value %s" %
                                 (name, short_repr(ret)))
@@ -601,43 +580,19 @@
                 self.log("%s returns %s" % (name, short_repr(ret)),
                          logging.DEBUG)
             if isinstance(ret, Delay):
-                ret.set_sender(msgid, self.send_reply, self.return_error)
+                ret.set_sender(msgid, self)
             else:
-                self.send_reply(msgid, ret)
+                self.send_reply(msgid, ret, not self.delay_sesskey)
 
         if self.delay_sesskey:
             self.__super_setSessionKey(self.delay_sesskey)
             self.delay_sesskey = None
 
-    def handle_error(self):
-        if sys.exc_info()[0] == SystemExit:
-            raise sys.exc_info()
-        self.log("Error caught in asyncore",
-                 level=logging.ERROR, exc_info=True)
-        self.close()
+    def return_error(self, msgid, err_type, err_value):
+        # Note that, ideally, this should be defined soley for
+        # servers, but a test arranges to get it called on
+        # a client. Too much trouble to fix it now. :/
 
-    def send_reply(self, msgid, ret):
-        # encode() can pass on a wide variety of exceptions from cPickle.
-        # While a bare `except` is generally poor practice, in this case
-        # it's acceptable -- we really do want to catch every exception
-        # cPickle may raise.
-        try:
-            msg = self.marshal.encode(msgid, 0, REPLY, ret)
-        except: # see above
-            try:
-                r = short_repr(ret)
-            except:
-                r = "<unreprable>"
-            err = ZRPCError("Couldn't pickle return %.100s" % r)
-            msg = self.marshal.encode(msgid, 0, REPLY, (ZRPCError, err))
-        self.message_output(msg)
-        self.poll()
-
-    def return_error(self, msgid, flags, err_type, err_value):
-        if flags & ASYNC:
-            self.log("Asynchronous call raised exception: %s" % self,
-                     level=logging.ERROR, exc_info=True)
-            return
         if not isinstance(err_value, Exception):
             err_value = err_type, err_value
 
@@ -657,79 +612,37 @@
         self.message_output(msg)
         self.poll()
 
+    def handle_error(self):
+        if sys.exc_info()[0] == SystemExit:
+            raise sys.exc_info()
+        self.log("Error caught in asyncore",
+                 level=logging.ERROR, exc_info=True)
+        self.close()
+
     def setSessionKey(self, key):
         if self.waiting_for_reply:
             self.delay_sesskey = key
         else:
             self.__super_setSessionKey(key)
 
-    # The next two public methods (call and callAsync) are used by
-    # clients to invoke methods on remote objects
+    def send_call(self, method, args, async=False):
+        # send a message and return its msgid
+        if async:
+            msgid = 0
+        else:
+            msgid = self._new_msgid()
 
-    def __new_msgid(self):
-        self.msgid_lock.acquire()
-        try:
-            msgid = self.msgid
-            self.msgid = self.msgid + 1
-            return msgid
-        finally:
-            self.msgid_lock.release()
-
-    def __call_message(self, method, args, flags):
-        # compute a message and return it
-        msgid = self.__new_msgid()
         if debug_zrpc:
-            self.log("send msg: %d, %d, %s, ..." % (msgid, flags, method),
+            self.log("send msg: %d, %d, %s, ..." % (msgid, async, method),
                      level=TRACE)
-        return self.marshal.encode(msgid, flags, method, args)
-
-    def send_call(self, method, args, flags):
-        # send a message and return its msgid
-        msgid = self.__new_msgid()
-        if debug_zrpc:
-            self.log("send msg: %d, %d, %s, ..." % (msgid, flags, method),
-                     level=TRACE)
-        buf = self.marshal.encode(msgid, flags, method, args)
+        buf = self.marshal.encode(msgid, async, method, args)
         self.message_output(buf)
         return msgid
 
-    def call(self, method, *args):
-        if self.closed:
-            raise DisconnectedError()
-        msgid = self.send_call(method, args, 0)
-        r_flags, r_args = self.wait(msgid)
-        if (isinstance(r_args, tuple) and len(r_args) > 1
-            and type(r_args[0]) == exception_type_type
-            and issubclass(r_args[0], Exception)):
-            inst = r_args[1]
-            raise inst # error raised by server
-        else:
-            return r_args
-
-    # For testing purposes, it is useful to begin a synchronous call
-    # but not block waiting for its response.
-
-    def _deferred_call(self, method, *args):
-        if self.closed:
-            raise DisconnectedError()
-        msgid = self.send_call(method, args, 0)
-        self.trigger.pull_trigger()
-        return msgid
-
-    def _deferred_wait(self, msgid):
-        r_flags, r_args = self.wait(msgid)
-        if (isinstance(r_args, tuple)
-            and type(r_args[0]) == exception_type_type
-            and issubclass(r_args[0], Exception)):
-            inst = r_args[1]
-            raise inst # error raised by server
-        else:
-            return r_args
-
     def callAsync(self, method, *args):
         if self.closed:
             raise DisconnectedError()
-        self.send_call(method, args, ASYNC)
+        self.send_call(method, args, 1)
         self.poll()
 
     def callAsyncNoPoll(self, method, *args):
@@ -738,7 +651,7 @@
         # allowing any client to sneak in a load request.
         if self.closed:
             raise DisconnectedError()
-        self.send_call(method, args, ASYNC)
+        self.send_call(method, args, 1)
 
     def callAsyncIterator(self, iterator):
         """Queue a sequence of calls using an iterator
@@ -746,47 +659,12 @@
         The calls will not be interleaved with other calls from the same
         client.
         """
-        self.message_output(self.__outputIterator(iterator))
+        self.message_output(self.marshal.encode(0, 1, method, args)
+                            for method, args in iterator)
 
-    def __outputIterator(self, iterator):
-        for method, args in iterator:
-            yield self.__call_message(method, args, ASYNC)
+    def handle_reply(self, msgid, ret):
+        assert msgid == -1 and ret is None
 
-
-    def wait(self, msgid):
-        """Invoke asyncore mainloop and wait for reply."""
-        if debug_zrpc:
-            self.log("wait(%d)" % msgid, level=TRACE)
-
-        self.trigger.pull_trigger()
-
-        # Delay used when we call asyncore.poll() directly.
-        # Start with a 1 msec delay, double until 1 sec.
-        delay = 0.001
-
-        self.replies_cond.acquire()
-        try:
-            while 1:
-                if self.closed:
-                    raise DisconnectedError()
-                reply = self.replies.get(msgid)
-                if reply is not None:
-                    del self.replies[msgid]
-                    if debug_zrpc:
-                        self.log("wait(%d): reply=%s" %
-                                 (msgid, short_repr(reply)), level=TRACE)
-                    return reply
-                self.replies_cond.wait()
-        finally:
-            self.replies_cond.release()
-
-    def flush(self):
-        """Invoke poll() until the output buffer is empty."""
-        if debug_zrpc:
-            self.log("flush")
-        while self.writable():
-            self.poll()
-
     def poll(self):
         """Invoke asyncore mainloop to get pending message out."""
         if debug_zrpc:
@@ -794,7 +672,6 @@
         self.trigger.pull_trigger()
 
 
-
 class ManagedServerConnection(Connection):
     """Server-side Connection subclass."""
 
@@ -803,6 +680,7 @@
 
     # Servers use a shared server trigger that uses the asyncore socket map
     trigger = trigger()
+    call_from_thread = trigger.pull_trigger
 
     def __init__(self, sock, addr, obj, mgr):
         self.mgr = mgr
@@ -821,13 +699,33 @@
         self.obj.notifyDisconnected()
         Connection.close(self)
 
+    def send_reply(self, msgid, ret, immediately=True):
+        # encode() can pass on a wide variety of exceptions from cPickle.
+        # While a bare `except` is generally poor practice, in this case
+        # it's acceptable -- we really do want to catch every exception
+        # cPickle may raise.
+        try:
+            msg = self.marshal.encode(msgid, 0, REPLY, ret)
+        except: # see above
+            try:
+                r = short_repr(ret)
+            except:
+                r = "<unreprable>"
+            err = ZRPCError("Couldn't pickle return %.100s" % r)
+            msg = self.marshal.encode(msgid, 0, REPLY, (ZRPCError, err))
+        self.message_output(msg)
+        if immediately:
+            self.poll()
+
+    poll = smac.SizedMessageAsyncConnection.handle_write
+
 class ManagedClientConnection(Connection):
     """Client-side Connection subclass."""
     __super_init = Connection.__init__
-    __super_close = Connection.close
     base_message_output = Connection.message_output
 
     trigger = client_trigger
+    call_from_thread = trigger.pull_trigger
 
     def __init__(self, sock, addr, mgr):
         self.mgr = mgr
@@ -846,9 +744,24 @@
         self.queue_output = True
         self.queued_messages = []
 
+        # msgid_lock guards access to msgid
+        self.msgid = 0
+        self.msgid_lock = threading.Lock()
+
+        # replies_cond is used to block when a synchronous call is
+        # waiting for a response
+        self.replies_cond = threading.Condition()
+        self.replies = {}
+
         self.__super_init(sock, addr, None, tag='C', map=client_map)
         client_trigger.pull_trigger()
 
+    def close(self):
+        Connection.close(self)
+        self.replies_cond.acquire()
+        self.replies_cond.notifyAll()
+        self.replies_cond.release()
+
     # Our message_ouput() queues messages until recv_handshake() gets the
     # protocol handshake from the server.
     def message_output(self, message):
@@ -890,3 +803,88 @@
             self.queue_output = False
         finally:
             self.output_lock.release()
+
+    def _new_msgid(self):
+        self.msgid_lock.acquire()
+        try:
+            msgid = self.msgid
+            self.msgid = self.msgid + 1
+            return msgid
+        finally:
+            self.msgid_lock.release()
+
+    def call(self, method, *args):
+        if self.closed:
+            raise DisconnectedError()
+        msgid = self.send_call(method, args)
+        r_args = self.wait(msgid)
+        if (isinstance(r_args, tuple) and len(r_args) > 1
+            and type(r_args[0]) == exception_type_type
+            and issubclass(r_args[0], Exception)):
+            inst = r_args[1]
+            raise inst # error raised by server
+        else:
+            return r_args
+
+    def wait(self, msgid):
+        """Invoke asyncore mainloop and wait for reply."""
+        if debug_zrpc:
+            self.log("wait(%d)" % msgid, level=TRACE)
+
+        self.trigger.pull_trigger()
+
+        # Delay used when we call asyncore.poll() directly.
+        # Start with a 1 msec delay, double until 1 sec.
+        delay = 0.001
+
+        self.replies_cond.acquire()
+        try:
+            while 1:
+                if self.closed:
+                    raise DisconnectedError()
+                reply = self.replies.get(msgid, self)
+                if reply is not self:
+                    del self.replies[msgid]
+                    if debug_zrpc:
+                        self.log("wait(%d): reply=%s" %
+                                 (msgid, short_repr(reply)), level=TRACE)
+                    return reply
+                self.replies_cond.wait()
+        finally:
+            self.replies_cond.release()
+
+    # For testing purposes, it is useful to begin a synchronous call
+    # but not block waiting for its response.
+
+    def _deferred_call(self, method, *args):
+        if self.closed:
+            raise DisconnectedError()
+        msgid = self.send_call(method, args)
+        self.trigger.pull_trigger()
+        return msgid
+
+    def _deferred_wait(self, msgid):
+        r_args = self.wait(msgid)
+        if (isinstance(r_args, tuple)
+            and type(r_args[0]) == exception_type_type
+            and issubclass(r_args[0], Exception)):
+            inst = r_args[1]
+            raise inst # error raised by server
+        else:
+            return r_args
+
+    def handle_reply(self, msgid, args):
+        if debug_zrpc:
+            self.log("recv reply: %s, %s"
+                     % (msgid, short_repr(args)), level=TRACE)
+        self.replies_cond.acquire()
+        try:
+            self.replies[msgid] = args
+            self.replies_cond.notifyAll()
+        finally:
+            self.replies_cond.release()
+
+    def send_reply(self, msgid, ret):
+        # Whimper. Used to send heartbeat
+        assert msgid == -1 and ret is None
+        self.message_output('(J\xff\xff\xff\xffK\x00U\x06.replyNt.')



More information about the Zodb-checkins mailing list