[Checkins] SVN: persistent/trunk/ Test cPersistence handling of derived classes w/ slots.

Tres Seaver cvs-admin at zope.org
Thu Jun 28 18:50:18 UTC 2012


Log message for revision 127148:
  Test cPersistence handling of derived classes w/ slots.
  
  Make pyPersistence match cPersistence's behavior.

Changed:
  _U  persistent/trunk/
  U   persistent/trunk/persistent/pyPersistence.py
  U   persistent/trunk/persistent/tests/test_pyPersistence.py

-=-
Modified: persistent/trunk/persistent/pyPersistence.py
===================================================================
--- persistent/trunk/persistent/pyPersistence.py	2012-06-28 18:50:11 UTC (rev 127147)
+++ persistent/trunk/persistent/pyPersistence.py	2012-06-28 18:50:15 UTC (rev 127148)
@@ -12,6 +12,7 @@
 #
 ##############################################################################
 from copy_reg import __newobj__
+from copy_reg import _slotnames
 import sys
 
 from zope.interface import implements
@@ -252,34 +253,51 @@
                     _OGA(self, '_p_register')()
         object.__delattr__(self, name)
 
+    def _slotnames(self):
+        slotnames = _slotnames(type(self))
+        return [x for x in slotnames
+                   if not x.startswith('_p_') and
+                      not x.startswith('_v_') and
+                      not x.startswith('_Persistent__') and
+                      x not in Persistent.__slots__]
+
     def __getstate__(self):
         """ See IPersistent.
         """
         idict = getattr(self, '__dict__', None)
+        slotnames = self._slotnames()
         if idict is not None:
-            return dict([x for x in idict.items()
-                            if not x[0].startswith('_p_') and
-                               not x[0].startswith('_v_')])
-        slots = getattr(type(self), '__slots__', None)
-        if slots is not None:
-            slots = [x for x in slots
-                            if not x.startswith('_p_') and
-                               not x.startswith('_v_') and
-                               x not in Persistent.__slots__]
-            if slots:
-                return None, dict([(x, getattr(self, x)) for x in slots])
-        return None
+            d = dict([x for x in idict.items()
+                         if not x[0].startswith('_p_') and
+                            not x[0].startswith('_v_')])
+        else:
+            d = None
+        if slotnames:
+            s = {}
+            for slotname in slotnames:
+                value = getattr(self, slotname, self)
+                if value is not self:
+                    s[slotname] = value
+            return d, s
+        return d
 
     def __setstate__(self, state):
         """ See IPersistent.
         """
+        try:
+            inst_dict, slots = state
+        except:
+            inst_dict, slots = state, ()
         idict = getattr(self, '__dict__', None)
-        if idict is not None:
+        if inst_dict is not None:
+            if idict is None:
+                raise TypeError('No instance dict')
             idict.clear()
-            idict.update(state)
-        else:
-            if state != None:
-                raise ValueError('No state allowed on base Persistent class')
+            idict.update(inst_dict)
+        slotnames = self._slotnames()
+        if slotnames:
+            for k, v in slots.items():
+                setattr(self, k, v)
 
     def __reduce__(self):
         """ See IPersistent.

Modified: persistent/trunk/persistent/tests/test_pyPersistence.py
===================================================================
--- persistent/trunk/persistent/tests/test_pyPersistence.py	2012-06-28 18:50:11 UTC (rev 127147)
+++ persistent/trunk/persistent/tests/test_pyPersistence.py	2012-06-28 18:50:15 UTC (rev 127148)
@@ -704,6 +704,30 @@
         inst._v_qux = 'spam'
         self.assertEqual(inst.__getstate__(), (None, {'foo': 'bar'}))
 
+    def test___getstate___derived_w_slots_in_base_and_derived(self):
+        class Base(self._getTargetClass()):
+            __slots__ = ('foo',)
+        class Derived(Base):
+            __slots__ = ('baz', 'qux',)
+        inst = Derived()
+        inst.foo = 'bar'
+        inst.baz = 'bam'
+        inst.qux = 'spam'
+        self.assertEqual(inst.__getstate__(),
+                         (None, {'foo': 'bar', 'baz': 'bam', 'qux': 'spam'}))
+
+    def test___getstate___derived_w_slots_in_base_but_not_derived(self):
+        class Base(self._getTargetClass()):
+            __slots__ = ('foo',)
+        class Derived(Base):
+            pass
+        inst = Derived()
+        inst.foo = 'bar'
+        inst.baz = 'bam'
+        inst.qux = 'spam'
+        self.assertEqual(inst.__getstate__(),
+                         ({'baz': 'bam', 'qux': 'spam'}, {'foo': 'bar'}))
+
     def test___setstate___empty(self):
         inst = self._makeOne()
         inst.__setstate__(None) # doesn't raise, but doesn't change anything
@@ -727,6 +751,35 @@
         inst.__setstate__({'baz': 'bam'})
         self.assertEqual(inst.__dict__, {'baz': 'bam'})
 
+    def test___setstate___derived_w_slots(self):
+        class Derived(self._getTargetClass()):
+            __slots__ = ('foo', '_p_baz', '_v_qux')
+        inst = Derived()
+        inst.__setstate__((None, {'foo': 'bar'}))
+        self.assertEqual(inst.foo, 'bar')
+
+    def test___setstate___derived_w_slots_in_base_classes(self):
+        class Base(self._getTargetClass()):
+            __slots__ = ('foo',)
+        class Derived(Base):
+            __slots__ = ('baz', 'qux',)
+        inst = Derived()
+        inst.__setstate__((None, {'foo': 'bar', 'baz': 'bam', 'qux': 'spam'}))
+        self.assertEqual(inst.foo, 'bar')
+        self.assertEqual(inst.baz, 'bam')
+        self.assertEqual(inst.qux, 'spam')
+
+    def test___setstate___derived_w_slots_in_base_but_not_derived(self):
+        class Base(self._getTargetClass()):
+            __slots__ = ('foo',)
+        class Derived(Base):
+            pass
+        inst = Derived()
+        inst.__setstate__(({'baz': 'bam', 'qux': 'spam'}, {'foo': 'bar'}))
+        self.assertEqual(inst.foo, 'bar')
+        self.assertEqual(inst.baz, 'bam')
+        self.assertEqual(inst.qux, 'spam')
+
     def test___reduce__(self):
         from copy_reg import __newobj__
         inst = self._makeOne()
@@ -1067,8 +1120,11 @@
         jar._cache._mru[:] = []
         
 
-import os
-if os.environ.get('run_C_tests'):
+try:
+    from persistent import cPersistence
+except ImportError:
+    pass
+else:
     class CPersistentTests(unittest.TestCase, _Persistent_Base):
 
         def _getTargetClass(self):



More information about the checkins mailing list