[Zodb-checkins] CVS: ZODB4/src/transaction - manager.py:1.4

Jeremy Hylton jeremy@zope.com
Thu, 6 Mar 2003 19:17:55 -0500


Update of /cvs-repository/ZODB4/src/transaction
In directory cvs.zope.org:/tmp/cvs-serv3400/transaction

Modified Files:
	manager.py 
Log Message:
Several small fixes.

Make txn_factory an attribute of the base class.

Raise AbortError when a resource's prepare() returns False, rather
than automatically aborting the transaction.

Pass the transaction object to Rollback() so that its rollback()
method can check the state of the transaction.

Add checks that raise IllegalStateError, to prevent assertions from failing.
XXX Should the manager duplicate these checks?

Add suspend() and resume() to the non-threaded transaction manager.

Fix a bug that caused the threaded suspend() to fail with KeyError
for a thread with no current transaction.


=== ZODB4/src/transaction/manager.py 1.3 => 1.4 ===
--- ZODB4/src/transaction/manager.py:1.3	Wed Mar  5 17:12:38 2003
+++ ZODB4/src/transaction/manager.py	Thu Mar  6 19:17:55 2003
@@ -11,6 +11,12 @@
 class AbstractTransactionManager(object):
     # base class to provide commit logic
     # concrete class must provide logger attribute
+
+    txn_factory = Transaction
+
+    # XXX the methods below use assertions, but perhaps they should
+    # check errors.  on the other hand, the transaction instances
+    # do raise exceptions.
     
     def commit(self, txn):
         # commit calls _finishCommit() or abort()
@@ -21,16 +27,14 @@
         try:
             for r in txn._resources:
                 if prepare_ok and not r.prepare(txn):
-                    prepare_ok = False
+                    raise AbortError(r)
         except:
             txn._status = Status.FAILED
             raise
         txn._status = Status.PREPARED
         # XXX An error below is intolerable.  What state to use?
-        if prepare_ok:
-            self._finishCommit(txn)
-        else:
-            self.abort(txn)
+        # Need code to handle this case.
+        self._finishCommit(txn)
 
     def _finishCommit(self, txn):
         self.logger.debug("%s: commit", txn)
@@ -48,18 +52,18 @@
         txn._status = Status.ABORTED
 
     def savepoint(self, txn):
+        assert txn._status == Status.ACTIVE
         self.logger.debug("%s: savepoint", txn)
-        return Rollback([r.savepoint(txn) for r in txn._resources])
+        return Rollback(txn, [r.savepoint(txn) for r in txn._resources])
 
 class TransactionManager(AbstractTransactionManager):
 
-    txn_factory = Transaction
-
     __implements__ = ITransactionManager
 
     def __init__(self):
         self.logger = logging.getLogger("txn")
         self._current = None
+        self._suspended = Set()
 
     def get(self):
         if self._current is None:
@@ -67,9 +71,11 @@
         return self._current
 
     def begin(self):
-        txn = self.txn_factory(self)
-        self.logger.debug("%s: begin", txn)
-        return txn
+        if self._current is not None:
+            self._current.abort()
+        self._current = self.txn_factory(self)
+        self.logger.debug("%s: begin", self._current)
+        return self._current
 
     def commit(self, txn):
         super(TransactionManager, self).commit(txn)
@@ -79,16 +85,31 @@
         super(TransactionManager, self).abort(txn)
         self._current = None
 
-    # XXX need suspend and resume
+    def suspend(self, txn):
+        if self._current != txn:
+            raise TransactionError("Can't suspend transaction because "
+                                   "it is not active")
+        self._suspended.add(txn)
+        self._current = None
 
+    def resume(self, txn):
+        if self._current is not None:
+            raise TransactionError("Can't resume while other "
+                                   "transaction is active")
+        self._suspended.remove(txn)
+        self._current = txn
+            
 class Rollback(object):
 
     __implements__ = IRollback
 
-    def __init__(self, resources):
+    def __init__(self, txn, resources):
+        self._txn = txn
         self._resources = resources
 
     def rollback(self):
+        if self._txn.status() != Status.ACTIVE:
+            raise IllegalStateError("rollback", self._txn.status())
         for r in self._resources:
             r.rollback()
 
@@ -150,7 +171,7 @@
 
     def suspend(self, txn):
         tid = thread.get_ident()
-        if self._pool[tid] is txn:
+        if self._pool.get(tid) is txn:
             self._suspend.add(txn)
             del self._pool[tid]
         else:
@@ -164,5 +185,5 @@
                                    tid)
         if txn not in self._suspend:
             raise TransactionError("unknown transaction: %s" % txn)
-        del self._suspend[txn]
+        self._suspend.remove(txn)
         self._pool[tid] = txn