[Zope3-checkins] CVS: Zope3/lib/python/ZODB - BaseStorage.py:1.21 Connection.py:1.72 DB.py:1.50 ExportImport.py:1.16 TmpStore.py:1.9

Jeremy Hylton jeremy@zope.com
Wed, 24 Jul 2002 19:13:16 -0400


Update of /cvs-repository/Zope3/lib/python/ZODB
In directory cvs.zope.org:/tmp/cvs-serv3960/ZODB

Modified Files:
	BaseStorage.py Connection.py DB.py ExportImport.py TmpStore.py 
Log Message:
Implement new Transaction APIs.

Add ZODB.ZTransaction.Transaction, which extends the basic transaction
object with user, description, and extended metadata.

XXX The rollback() part of the savepoint() protocol doesn't work yet.
An XXX comment in the code (on the Rollback class in Connection.py)
explains why, and the test for it is not enabled.

Connection.py:

Also change the registerDB() protocol to omit the limit argument,
which was not used in any implementation.

Refactor the ExportImport hooks to use a single hook method,
importHook(), which does all the work.  Remove the onCommitAction()
framework.

Delete all the code related to resetting connection caches after
changes to persistence-based classes, since it was complex and unused
by the new Persistence.Class.  Also remove the exchange() method,
which was used for ZClass support.

Add a small Set(dict) class and use it to replace several of the lists
(and one dict) that tracked the oids of modified, created, and
invalidated objects.

Remove the cacheGC attribute from Connection; call the cache's
incrgc() method directly instead.

DB.py:

Fix up the AbortVersion, CommitVersion, and TransactionalUndo data
managers.  Each is now a subclass of the new base class
SimpleDataManager.

ExportImport.py:

Changes for importHook(), as mentioned above.

Also, lots of small code cleanups.

TmpStore.py:

Refactor enough to support rollback() at the storage level.



=== Zope3/lib/python/ZODB/BaseStorage.py 1.20 => 1.21 ===
             if d < 255: return last[:-1]+chr(d+1)+'\0'*(8-len(last))
             else:       return self.new_oid(last[:-1])
 
-    def registerDB(self, db, limit):
+    def registerDB(self, db):
         pass # we don't care
 
     def isReadOnly(self):


=== Zope3/lib/python/ZODB/Connection.py 1.71 => 1.72 ===
 $Id$
 """
 
-from ZODB import ExportImport, TmpStore
+from ZODB import ExportImport
+from ZODB.TmpStore import TmpStore
 from ZODB.ConflictResolution import ResolvedSerial
 from ZODB.IConnection import IConnection
 from ZODB.POSException import ConflictError
 from ZODB.utils import U64
+from Transaction import get_transaction
 
 from Persistence.Cache import Cache
-from Transaction import get_transaction
 from zLOG import LOG, ERROR, BLATHER, INFO
 
 from cPickle import Unpickler, Pickler
@@ -59,23 +60,13 @@
 import time
 from types import StringType, ClassType, TupleType
 
-# XXX Does this need to be an actual timestamp or would a counter be
-# sufficient? 
-_cache_timestamp = 0
-
-def resetPersistentCaches():
-    """Call to reset caches for all Connections.
-    
-    This function should be called after changes are made to
-    persistence-based classes.  It causes all connection caches to be
-    re-created as the connections are reopened.
-    """
-    # WARNING: It is only safe to call this method if the application
-    # never keeps references to ZODB objects after the connection is
-    # closed.  If this rule is violated, ZODB will load inconsistent
-    # duplicates of those objects and you may see silent corruption.
-    global _cache_timestamp
-    _cache_timestamp = time.time()
+class Set(dict):
+    def add(self, o):
+        self[o] = 1
+        
+    def addmany(self, L):
+        for o in L:
+            self[o] = 1
 
 class Connection(ExportImport.ExportImport):
     """Object managers for individual object space.
@@ -106,14 +97,13 @@
         self.new_oid = db._storage.new_oid
         self._version = version
         self._cache = cache = Cache(cache_size, cache_deactivate_after)
-        self.cacheGC = cache.incrgc
 
         # _invalidated queues invalidate messages delivered from the DB
-        self._invalidated = d = {}
-        self._invalid = d.has_key
+        self._invalidated = Set()
         self._committed = []
         
-        self._cache_timestamp = _cache_timestamp # set from global
+        # track which objects are involved with a transaction
+        self._txns = {}
 
     def getVersion(self):
         return self._version
@@ -246,10 +236,10 @@
             # to avoid time-of-check to time-of-use race.
             p, serial = self._storage.load(oid, self._version)
 
-            if self._invalid(oid):
+            if oid in self._invalidated:
                 if not (hasattr(object, '_p_independent')
                         and object._p_independent()):
-                    get_transaction().register(self)
+                    get_transaction().join(self)
                     raise ConflictError(object=object)
                 invalid = 1
             else:
@@ -273,7 +263,7 @@
                     except KeyError:
                         pass
                 else:
-                    get_transaction().register(self)
+                    get_transaction().join(self)
                     raise ConflictError(object=object)
 
         except ConflictError:
@@ -287,24 +277,24 @@
             self._cache.setstate(oid, object)
 
     def register(self, object):
-        get_transaction().register(object)
+        txn = get_transaction()
+        L = self._txns.get(txn)
+        if L is None:
+            L = self._txns[txn] = []
+            txn.join(self)
+        L.append(object)
 
     def mtime(self, object):
         # required by the IPersistentDataManager interface, but unimplemented
         return None
 
     def reset(self):
-        if self._cache_timestamp != _cache_timestamp:
-            # New code is in place.  Start a new cache.
-            self._resetCache()
-        else:
-            # XXX race condition?
-            self._cache.invalidateMany(self._invalidated.iterkeys())
-            self._invalidated.clear()
+        self._cache.invalidateMany(self._invalidated)
+        self._invalidated.clear()
         self._opened = time.time()
 
     def close(self):
-        self.cacheGC() # This is a good time to do some GC
+        self._cache.incrgc()
         self.applyCloseCallbacks()
         self._opened = None
         # Return the connection to the pool.
@@ -331,31 +321,8 @@
                         error=sys.exc_info())
             self.__onCloseCallbacks = None
 
-    # the commit actions are used by ExportImport
-
-    __onCommitActions = None
-    
-    def onCommitAction(self, method_name, *args, **kw):
-        if self.__onCommitActions is None:
-            self.__onCommitActions = []
-        self.__onCommitActions.append((method_name, args, kw))
-        get_transaction().register(self)
-
     # some cache-related methods
 
-    def _resetCache(self):
-        "Creates a new cache, discarding the old."
-        # WARNING: This method is not part of the public API!
-        # If you call this method when anything besides the
-        # cache has references to ZODB objects, ZODB will start loading
-        # inconsistent duplicates of the objects, leading to silent
-        # corruption and gnashing of teeth.  And don't call this method
-        # while objects are pending in the transaction queue.  Ouch.
-        self._cache_timestamp = _cache_timestamp
-        self._invalidated.clear()
-        orig_cache = self._cache
-        self._cache = Cache(orig_cache._size, orig_cache._inactive)
-
     def cacheFullSweep(self, dt=0):
         self._cache.full_sweep(dt)
         
@@ -373,46 +340,21 @@
         transaction boundaries.
         """
         assert oid is not None
-        # XXX race condition?
-        self._invalidated[oid] = 1
-
-    ######################################################################
-    # Transaction.IDataManager
-    # requires the next 8 methods:
-    # abort(), tpc_begin(), commit(), tpc_vote(), tpc_finish(),
-    # tpc_abort(), abort_sub(), commit_sub()
-    
-    def abort(self, object, transaction):
-        """Invalidate the object (or all objects if None)."""
-        if object is self:
-            # XXX race condition?
-            self._cache.invalidateMany(self._invalidated.iterkeys())
-            self._invalidated.clear()
-        else:
-            self._cache.invalidate(object._p_oid)
+        self._invalidated.add(oid)
 
-    def commit(self, object, transaction):
-        if object is self:
-            # If None was registered and the transaction is
-            # committing, then execute a commit action.
-            if self.__onCommitActions is not None:
-                method_name, args, kw = self.__onCommitActions.pop(0)
-                getattr(self, method_name)(transaction, *args, **kw)
-            return
-        
+    def objcommit(self, object, transaction):
         oid = object._p_oid
+        
         if oid is None or object._p_jar is not self:
-            # new object
             oid = self.new_oid()
             object._p_jar = self
             object._p_oid = oid
-            self._created.append(oid)
+            self._created.add(oid)
         elif object._p_changed:
-            # XXX Is it kosher to raise a ConflictError on commit?
-            if (self._invalid(oid)
-                and not hasattr(object, '_p_resolveConflict')):
+            if (oid in self._invalidated and
+                not hasattr(object, '_p_resolveConflict')):
                 raise ConflictError(object=object)
-            self._modified.append(oid)
+            self._modified.add(oid)
         else:
             return # Nothing to do
 
@@ -430,16 +372,15 @@
         # Maybe just create new ones each time?  Except I'm not sure
         # how that interacts with the persistent_id attribute.
         oid = pobject._p_oid
-        serial = getattr(pobject, '_p_serial', '\0\0\0\0\0\0\0\0')
-        if serial == '\0\0\0\0\0\0\0\0':
-            # new object
-            self._created.append(oid)
+        serial = getattr(pobject, '_p_serial', None)
+        if serial is None:
+            self._created.add(oid)
         else:
-            # We should only get here for the original object passed to commit
+            # XXX this seems to duplicate code on objcommit()
             if (oid in self._invalidated and
                 not hasattr(pobject, '_p_resolveConflict')):
                 raise ConflictError(oid=oid)
-            self._modified.append(oid)
+            self._modified.add(oid)
 
         klass = pobject.__class__
 
@@ -467,45 +408,42 @@
         self._cache[oid] = pobject
         self._handle_serial(s, oid)
 
-    def commit_sub(self, t):
+    def commit_sub(self, txn):
         """Commit all work done in subtransactions"""
-        if self._tmp is None:
-            return
+        assert self._tmp is not None
         
-        src = self._storage
-        tmp = self._storage = self._tmp
+        tmp = self._storage
+        self._storage = self._tmp
         self._tmp = None
 
-        tmp.tpc_begin(t)
+        self._storage.tpc_begin(txn)
         
-        oids = src._index.keys()
         # Copy invalidating and creating info from temporary storage:
-        self._modified.extend(oids)
-        self._created.extend(src._created)
+        self._modified.addmany(tmp._index)
+        self._created.addmany(tmp._created)
         
-        for oid in oids:
-            data, serial = src.load(oid, src)
-            s = tmp.store(oid, serial, data, self._version, t)
+        for oid in tmp._index:
+            data, serial = tmp.load(oid, tmp._bver)
+            s = self._storage.store(oid, serial, data, self._version, txn)
             self._handle_serial(s, oid, change=0)
 
-    def abort_sub(self, t):
+    def abort_sub(self):
         """Abort work done in subtransactions"""
-        if self._tmp is None:
-            return
-        src = self._storage
+        assert self._tmp is not None
+
+        tmp = self._storage
         self._storage = self._tmp
         self._tmp = None
 
-        self._cache.invalidateMany(src._index.iterkeys())
-        src._index.clear()
-        self._invalidate_created(src._created)
+        self._cache.invalidateMany(tmp._index)
+        self._invalidate_created(tmp._created)
 
     def _invalidate_created(self, created=None):
         """Dissown any objects newly saved in an uncommitted transaction.
         """
         if created is None:
             created = self._created
-            self._created = []
+            self._created = Set()
 
         for oid in created:
             o = self._cache.get(oid)
@@ -514,6 +452,9 @@
                 del o._p_oid
                 del self._cache[oid]
 
+    ######################################################################
+    # Transaction.IDataManager
+    
     def oldstate(self, object, serial):
         """Return the state of an object as of serial.
 
@@ -522,43 +463,60 @@
         p = self._storage.loadSerial(object._p_oid, serial)
         return self._unpickle_object(p)
 
-    def tpc_abort(self, transaction):
-        if self.__onCommitActions is not None:
-            del self.__onCommitActions
-        self._storage.tpc_abort(transaction)
-        # XXX race condition?
-        self._cache.invalidateMany(self._invalidated.iterkeys())
-        self._invalidated.clear()
+    def savepoint(self, txn):
+        if self._tmp is None:
+            tmp = TmpStore(self._version)
+            self._tmp = self._storage
+            self._storage = tmp
+            tmp.registerDB(self._db)
+        self._modified = Set()
+        self._created = Set()
+        self._storage.tpc_begin(txn)
+        
+        for obj in self._txns.get(txn, ()):
+            self.objcommit(obj, txn)
+        self.importHook(txn) # hook for ExportImport
+
+        undo = self._storage.tpc_finish(txn)
+        self._storage._created = self._created
+        self._created = Set()
+        return Rollback(self, undo)
+
+    def abort(self, txn):
+        if self._tmp is not None:
+            self.abort_sub()
+        self._storage.tpc_abort(txn)
+
+        objs = self._txns.get(txn)
+        if objs is not None:
+            self._cache.invalidateMany([obj._p_oid for obj in objs])
+            del self._txns[txn]
+        self._cache.invalidateMany(self._invalidated)
         self._cache.invalidateMany(self._modified)
-        del self._modified[:]
         self._invalidate_created()
 
-    def tpc_begin(self, transaction, sub=None):
-        # _modified is a list of the oids of the objects modified
-        # by this transaction.
-        self._modified = []
-        self._created = []
-
-        if sub:
-            # Sub-transaction!
-            tmp = self._tmp
-            if tmp is None:
-                tmp = TmpStore.TmpStore(self._version)
-                self._tmp = self._storage
-                self._storage = tmp
-                tmp.registerDB(self._db, 0)
-
-        self._storage.tpc_begin(transaction)
-
-    def tpc_vote(self, transaction):
-        if self.__onCommitActions is not None:
-            del self.__onCommitActions
+        self._invalidated.clear()
+        self._modified.clear()
+
+    def prepare(self, txn):
+        self._modified = Set()
+        self._created = Set()
+        if self._tmp is not None:
+            # commit_sub() will call tpc_begin() on the real storage
+            self.commit_sub(txn)
+        else:
+            self._storage.tpc_begin(txn)
+
+        for obj in self._txns.get(txn, ()):
+            self.objcommit(obj, txn)
+
         try:
-            vote = self._storage.tpc_vote
-        except AttributeError:
-            return
-        s = vote(transaction)
-        self._handle_serial(s)
+            s = self._storage.tpc_vote(txn)
+            self._handle_serial(s)
+        except Exception, err:
+            print "Error during tpc_vote", err
+            return False
+        return True
 
     def _handle_serial(self, store_return, oid=None, change=1):
         """Handle the returns from store() and tpc_vote() calls."""
@@ -602,29 +560,23 @@
                         obj._p_changed = 0
                     obj._p_serial = serial
 
-
-    def tpc_finish(self, transaction):
+    def commit(self, txn):
         # It's important that the storage call the function we pass
         # (self._invalidate_modified) while it still has it's
         # lock.  We don't want another thread to be able to read any
         # updated data until we've had a chance to send an
         # invalidation message to all of the other connections!
 
-        if self._tmp is not None:
-            # Commiting a subtransaction!
-            # There is no need to invalidate anything.
-            self._storage.tpc_finish(transaction)
-            self._storage._created[:0] = self._created
-            del self._created[:]
-        else:
-            self._db.begin_invalidation()
-            self._storage.tpc_finish(transaction,
-                                     self._invalidate_modified)
+        self._db.begin_invalidation()
+        self._storage.tpc_finish(txn, self._invalidate_modified)
+        try:
+            del self._txns[txn]
+        except KeyError:
+            pass
 
-        # XXX race condition?
-        self._cache.invalidateMany(self._invalidated.iterkeys())
+        self._cache.invalidateMany(self._invalidated)
         self._invalidated.clear()
-        self.cacheGC() # This is a good time to do some GC
+        self._cache.incrgc() 
 
     def _invalidate_modified(self):
         for oid in self._modified:
@@ -637,20 +589,9 @@
         if sync is not None:
             sync()
         # XXX race condition?
-        self._cache.invalidateMany(self._invalidated.iterkeys())
+        self._cache.invalidateMany(self._invalidated)
         self._invalidated.clear()
-        self.cacheGC() # This is a good time to do some GC
-
-    def exchange(self, old, new):
-        # Replace an existing object with a new one.
-        # This is used by ZClasses to support some deep
-        # hacking to allow base classes to change.
-        oid = old._p_oid
-        new._p_oid = oid
-        new._p_jar = self
-        new._p_changed = 1
-        get_transaction().register(new)
-        self._cache[oid] = new
+        self._cache.incrgc() # This is a good time to do some GC
         
 def new_persistent_id(self, stack):
     # XXX need a doc string.  not sure if the one for persistent_id()
@@ -703,3 +644,23 @@
         return oid, klass
     
     return persistent_id
+
+class Rollback:
+    """Rollback changes associated with savepoint"""
+
+    # XXX This doesn't work yet.
+
+    # It needs to invalidate objects modified after the previous
+    # savepoint or the start of the transaction if it is the first
+    # savepoint.
+
+    def __init__(self, conn, tmp_undo):
+        self._conn = conn
+        self._tmp_undo = tmp_undo # undo info from the storage
+
+    def rollback(self):
+        if not self._tmp_undo.current(self._conn._storage):
+            # need better error
+            raise RuntimeError, "savepoint has already been committed"
+        self._tmp_undo.rollback()
+        


=== Zope3/lib/python/ZODB/DB.py 1.49 => 1.50 ===
 import POSException
 from Connection import Connection
 from threading import Lock
-from Transaction import Transaction
 from referencesf import referencesf
 from time import time, ctime
 from zLOG import LOG, ERROR
-
+from ZODB.ZTransaction import Transaction
 from Transaction import get_transaction
 
+from Transaction.IDataManager import IDataManager
+
 from types import StringType
 
 class DB:
@@ -72,7 +73,7 @@
 
         # Setup storage
         self._storage = storage
-        storage.registerDB(self, None)
+        storage.registerDB(self)
         try:
             storage.load('\0\0\0\0\0\0\0\0', '')
         except KeyError:
@@ -539,60 +540,66 @@
     def versionEmpty(self, version):
         return self._storage.versionEmpty(version)
 
-class CommitVersion:
-    """An object that will see to version commit
+class SimpleDataManager:
 
-    in cooperation with a transaction manager.
-    """
-    def __init__(self, db, version, dest=''):
-        self._db=db
-        s=db._storage
-        self._version=version
-        self._dest=dest
-        self.tpc_abort=s.tpc_abort
-        self.tpc_begin=s.tpc_begin
-        self.tpc_vote=s.tpc_vote
-        self.tpc_finish=s.tpc_finish
-        get_transaction().register(self)
+    __implements__ = IDataManager
+    
+    def __init__(self, db):
+        self._db = db
+        self._storage = db._storage
+        get_transaction().join(self)
+
+    def prepare(self, txn):
+        self._storage.tpc_begin(txn)
+        if self._storage.tpc_vote(txn):
+            return True
+        else:
+            return False
 
-    def abort(self, reallyme, t): pass
+    def abort(self, txn):
+        pass
 
-    def commit(self, reallyme, t):
-        db=self._db
-        dest=self._dest
-        oids=db._storage.commitVersion(self._version, dest, t)
-        for oid in oids: db.invalidate(oid, version=dest)
-        if dest:
-            # the code above just invalidated the dest version.
-            # now we need to invalidate the source!
-            for oid in oids: db.invalidate(oid, version=self._version)
-    
-class AbortVersion(CommitVersion):
-    """An object that will see to version abortion
+    def commit(self, txn):
+        pass
 
-    in cooperation with a transaction manager.
-    """
+class CommitVersion(SimpleDataManager):
+    """An object that will see to version commit."""
 
-    def commit(self, reallyme, t):
-        db=self._db
-        version=self._version
-        oids = db._storage.abortVersion(version, t)
-        for oid in oids:
-            db.invalidate(oid, version=version)
+    def __init__(self, db, version, dest=''):
+        super(CommitVersion, self).__init__(db)
+        self._version = version
+        self._dest = dest
 
+    def commit(self, txn):
+        oids = db._storage.commitVersion(self._version, dest, txn)
+        for oid in oids:
+            self._db.invalidate(oid, version=self._dest)
+        if self._dest:
+            # the code above just invalidated the dest version.
+            # now we need to invalidate the source!
+            for oid in oids:
+                self._db.invalidate(oid, version=self._version)
+                
+class AbortVersion(SimpleDataManager):
+    """An object that will see to version abortion."""
 
-class TransactionalUndo(CommitVersion):
-    """An object that will see to transactional undo
+    def __init__(self, db, version):
+        super(CommitVersion, self).__init__(db)
+        self._version = version
+        
+    def commit(self, txn):
+        oids = self._db._storage.abortVersion(version, txn)
+        for oid in oids:
+            self._db.invalidate(oid, version=self._version)
 
-    in cooperation with a transaction manager.
-    """
+class TransactionalUndo(SimpleDataManager):
+    """An object that will see to transactional undo."""
     
-    # I'm lazy. I'm reusing __init__ and abort and reusing the
-    # version attr for the transavtion id. There's such a strong
-    # similarity of rythm, that I think it's justified.
-
+    def __init__(self, db, tid):
+        super(CommitVersion, self).__init__(db)
+        self._tid = tid
+        
     def commit(self, reallyme, t):
-        db=self._db
-        oids=db._storage.transactionalUndo(self._version, t)
+        oids = self._db._storage.transactionalUndo(self._tid, t)
         for oid in oids:
-            db.invalidate(oid)
+            self._db.invalidate(oid)


=== Zope3/lib/python/ZODB/ExportImport.py 1.15 => 1.16 ===
 # FOR A PARTICULAR PURPOSE.
 # 
 ##############################################################################
-"""Support for database export and import.
-"""
+"""Support for database export and import."""
 
-from Transaction import get_transaction
 from ZODB import POSException
+from ZODB.utils import p64, u64
+from ZODB.referencesf import referencesf
+from Transaction import get_transaction
 
-from utils import p64, u64
-from referencesf import referencesf
 from cStringIO import StringIO
 from cPickle import Pickler, Unpickler
+from tempfile import TemporaryFile
 from types import StringType, TupleType
 
 class ExportImport:
+    # a mixin for use with ZODB.Connection.Connection
 
-    def exportFile(self, oid, file=None):
+    __hooks = None
 
+    def exportFile(self, oid, file=None):
         if file is None:
             file = TemporaryFile()
         elif isinstance(file, StringType):
             file = open(file, 'w+b')
         file.write('ZEXP')
-        version=self._version
-        ref=referencesf
-        oids=[oid]
-        done_oids={}
-        done=done_oids.has_key
-        load=self._storage.load
+        oids = [oid]
+        done_oids = {}
         while oids:
-            oid=oids[0]
-            del oids[0]
-            if done(oid): continue
-            done_oids[oid]=1
-            try: p, serial = load(oid, version)
-            except: pass # Ick, a broken reference
+            oid = oids.pop(0)
+            if oid in done_oids:
+                continue
+            done_oids[oid] = 1
+            try:
+                p, serial = self._storage.load(oid, self._version)
+            except:
+                # XXX what exception is expected?
+                pass # Ick, a broken reference
             else:
-                ref(p, oids)
+                referencesf(p, oids)
                 file.write(oid)
                 file.write(p64(len(p)))
                 file.write(p)
         file.write(export_end_marker)
         return file
 
-    def importFile(self, file, clue='', customImporters=None):
+    def importFile(self, file, clue=None, customImporters=None):
         # This is tricky, because we need to work in a transaction!
+        # XXX I think this needs to work in a transaction, because it
+        # needs to write pickles into the storage, which only allows
+        # store() calls between tpc_begin() and tpc_vote().
 
         if isinstance(file, StringType):
             file = open(file,'rb')
-        read = file.read
-
-        magic = read(4)
+        magic = file.read(4)
 
         if magic != 'ZEXP':
-            if customImporters and customImporters.has_key(magic):
+            if customImporters is not None and customImporters.has_key(magic):
                 file.seek(0)
                 return customImporters[magic](self, file, clue)
             raise POSException.ExportError, 'Invalid export header'
 
         t = get_transaction()
-        if clue:
+        if clue is not None:
             t.note(clue)
 
-        return_oid_list = []
-        self.onCommitAction('_importDuringCommit', file, return_oid_list)
-        t.commit(1)
+        L = []
+        if self.__hooks is None:
+            self.__hooks = []
+        self.__hooks.append((file, L))
+        t.join(self)
+        t.savepoint()
         # Return the root imported object.
-        if return_oid_list:
-            return self[return_oid_list[0]]
+        if L:
+            return self[L[0]]
         else:
             return None
 
+    def importHook(self, txn):
+        if self.__hooks is None:
+            return
+        for file, L in self.__hooks:
+            self._importDuringCommit(txn, file, L)
+        del self.__hooks
+
     def _importDuringCommit(self, transaction, file, return_oid_list):
         """Invoked by the transaction manager mid commit.
         
         Appends one item, the OID of the first object created,
         to return_oid_list.
         """
-
         oids = {}
-        storage = self._storage
-        new_oid = storage.new_oid
-        store = storage.store
-        read = file.read
-
-        def persistent_load(ooid,
-                            Ghost=Ghost, 
-                            oids=oids, wrote_oid=oids.has_key,
-                            new_oid=storage.new_oid):
-        
+
+        def persistent_load(ooid):
             "Remap a persistent id to a new ID and create a ghost for it."
 
             if isinstance(ooid, TupleType):
@@ -106,65 +109,64 @@
             else:
                 klass = None
 
-            if wrote_oid(ooid): oid=oids[ooid]
-            else:
-                if klass is None: oid=new_oid()
-                else: oid=new_oid(), klass
-                oids[ooid]=oid
-
-            Ghost=Ghost()
-            Ghost.oid=oid
-            return Ghost
+            oid = oids.get(ooid)
+            if oid is None:
+                if klass is None:
+                    oid = self._storage.new_oid()
+                    self._created.add(oid)
+                else:
+                    oid = self._storage.new_oid(), klass
+                    self._created.add(oid[0])
+                oids[ooid] = oid
+
+            g = Placeholder()
+            g.oid = oid
+            return g
 
         version = self._version
 
         while 1:
-            h=read(16)
-            if h==export_end_marker: break
+            h = file.read(16)
+            if h == export_end_marker:
+                break
             if len(h) != 16:
                 raise POSException.ExportError, 'Truncated export file'
-            l=u64(h[8:16])
-            p=read(l)
+            l = u64(h[8:16])
+            p = file.read(l)
             if len(p) != l:
                 raise POSException.ExportError, 'Truncated export file'
 
-            ooid=h[:8]
+            # XXX what does the tuple in oids mean?
+            ooid = h[:8]
             if oids:
-                oid=oids[ooid]
+                oid = oids[ooid]
                 if isinstance(oid, TupleType):
                     oid = oid[0]
             else:
-                oids[ooid] = oid = storage.new_oid()
+                oids[ooid] = oid = self._storage.new_oid()
                 return_oid_list.append(oid)
+                self._created.add(oid)
 
-            pfile=StringIO(p)
-            unpickler=Unpickler(pfile)
-            unpickler.persistent_load=persistent_load
-
-            newp=StringIO()
-            pickler=Pickler(newp,1)
-            pickler.persistent_id=persistent_id
+            pfile = StringIO(p)
+            unpickler = Unpickler(pfile)
+            unpickler.persistent_load = persistent_load
+
+            newp = StringIO()
+            pickler = Pickler(newp, 1)
+            pickler.persistent_id = persistent_id
 
             pickler.dump(unpickler.load())
             pickler.dump(unpickler.load())
-            p=newp.getvalue()
-
-            store(oid, None, p, version, transaction)
-
+            p = newp.getvalue()
 
-def TemporaryFile():
-    # This is sneaky suicide
-    global TemporaryFile
-    import tempfile
-    TemporaryFile=tempfile.TemporaryFile
-    return TemporaryFile()
+            self._storage.store(oid, None, p, version, transaction)
 
-export_end_marker='\377'*16
+export_end_marker = '\377' * 16
 
-class Ghost:
+class Placeholder(object):
     pass
 
-def persistent_id(object, Ghost=Ghost):
-    if getattr(object, '__class__', None) is Ghost:
+def persistent_id(object):
+    if isinstance(object, Placeholder):
         return object.oid
 


=== Zope3/lib/python/ZODB/TmpStore.py 1.8 => 1.9 ===
         # _tindex: map oid to pos for new updates
         self._tindex = {}
         self._db = None
-        # XXX what is this for?
-        self._created = []
 
     def close(self):
         self._file.close()
@@ -68,7 +66,7 @@
     def new_oid(self):
         return self._db._storage.new_oid()
 
-    def registerDB(self, db, limit):
+    def registerDB(self, db):
         self._db = db
         self._storage = db._storage
 
@@ -107,9 +105,11 @@
             return
         if f is not None:
             f()
+        undo = UndoInfo(self, self._tpos, self._index.copy())
         self._index.update(self._tindex)
         self._tindex.clear()
         self._tpos = self._pos
+        return undo
 
     def undoLog(self, first, last, filter=None):
         return ()
@@ -118,3 +118,25 @@
         # XXX what is this supposed to do?
         if version == self._bver:
             return len(self._index)
+
+    def rollback(self, pos, index):
+        if not (pos < self._tpos <= self._pos):
+            # XXX need to make this pos exception
+            raise RuntimeError("transaction rolled back to early point")
+        self._tpos = self._pos = pos
+        self._index = index
+        self._tindex.clear()
+
+class UndoInfo:
+
+    def __init__(self, store, pos, index):
+        self._store = store
+        self._pos = pos
+        self._index = index
+
+    def current(self, cur_store):
+        """Return true if the UndoInfo is for cur_store."""
+        return self._store is cur_store
+        
+    def rollback(self):
+        self._store.rollback(self._pos, self._index)