[Checkins] SVN: Products.PluggableAuthService/trunk/ Add 'csrf_only' decorator for post-handling methods.

Tres Seaver cvs-admin at zope.org
Fri Nov 16 00:54:58 UTC 2012


Log message for revision 128302:
  Add 'csrf_only' decorator for post-handling methods.

Changed:
  _U  Products.PluggableAuthService/trunk/
  U   Products.PluggableAuthService/trunk/Products/PluggableAuthService/tests/test_utils.py
  U   Products.PluggableAuthService/trunk/Products/PluggableAuthService/utils.py

-=-
Modified: Products.PluggableAuthService/trunk/Products/PluggableAuthService/tests/test_utils.py
===================================================================
--- Products.PluggableAuthService/trunk/Products/PluggableAuthService/tests/test_utils.py	2012-11-16 00:54:57 UTC (rev 128301)
+++ Products.PluggableAuthService/trunk/Products/PluggableAuthService/tests/test_utils.py	2012-11-16 00:54:57 UTC (rev 128302)
@@ -202,6 +202,48 @@
         self.assertEqual(token(), 'deadbeef')
 
 
+class Test_csrf_only(unittest.TestCase):  # Tests for utils.csrf_only. NOTE(review): every call below passes REQUEST by keyword; positional calls and bound-method use are untested.
+
+    def _callFUT(self, *args, **kw):  # FUT = function under test; csrf_only is imported at call time.
+        from Products.PluggableAuthService.utils import csrf_only
+        return csrf_only(*args, **kw)
+
+    def test_w_function_no_REQUEST(self):  # No parameter named REQUEST -> ValueError at decoration time.
+        def no_request(foo, bar, **kw):  # **kw does not count as naming REQUEST.
+            "I haz no REQUEST"
+        self.assertRaises(ValueError, self._callFUT, no_request)
+
+    def test_w_function_w_positional_REQUEST(self):  # REQUEST is a required positional parameter of the wrapped function.
+        from ZPublisher import BadRequest
+        def w_positional_request(foo, bar, REQUEST):
+            "I haz REQUEST as positional arg"
+            return 42
+        wrapped = self._callFUT(w_positional_request)
+        self.assertEqual(wrapped.__name__, w_positional_request.__name__)  # Name, module and docstring must survive wrapping.
+        self.assertEqual(wrapped.__module__, w_positional_request.__module__)
+        self.assertEqual(wrapped.__doc__, w_positional_request.__doc__)
+        self.assertRaises(BadRequest, wrapped, foo=None, bar=None,  # Assumes a bare request carries no valid token -> BadRequest.
+                          REQUEST=_makeRequestWSession())
+        req = _makeRequestWSession(_csrft_='deadbeef')  # Presumably seeds the session-side token; helper not visible here.
+        req.form['csrf_token'] = 'deadbeef'  # Submitted form token matches the session token.
+        self.assertEqual(wrapped(foo=None, bar=None, REQUEST=req), 42)  # Valid token -> wrapped function's result passes through.
+
+    def test_w_function_w_optional_REQUEST(self):  # REQUEST has a default (None) in the wrapped function.
+        from ZPublisher import BadRequest
+        def w_optional_request(foo, bar, REQUEST=None):
+            "I haz REQUEST as kw arg"
+            return 42
+        wrapped = self._callFUT(w_optional_request)
+        self.assertEqual(wrapped.__name__, w_optional_request.__name__)  # Same metadata checks as the positional case.
+        self.assertEqual(wrapped.__module__, w_optional_request.__module__)
+        self.assertEqual(wrapped.__doc__, w_optional_request.__doc__)
+        self.assertRaises(BadRequest,  # Assumes a bare request carries no valid token -> BadRequest.
+                         wrapped, foo=None, bar=None,
+                                  REQUEST=_makeRequestWSession())
+        req = _makeRequestWSession(_csrft_='deadbeef')  # Presumably seeds the session-side token; helper not visible here.
+        req.form['csrf_token'] = 'deadbeef'  # Submitted form token matches the session token.
+        self.assertEqual(wrapped(foo=None, bar=None, REQUEST=req), 42)  # Valid token -> wrapped function's result passes through.
+
 def _createHashedValue(items):
     try:
         from hashlib import sha1 as sha
@@ -227,4 +269,5 @@
         unittest.makeSuite(Test_getCSRFToken),
         unittest.makeSuite(Test_checkCSRFToken),
         unittest.makeSuite(CSRFTokenTests),
+        unittest.makeSuite(Test_csrf_only),
     ))

Modified: Products.PluggableAuthService/trunk/Products/PluggableAuthService/utils.py
===================================================================
--- Products.PluggableAuthService/trunk/Products/PluggableAuthService/utils.py	2012-11-16 00:54:57 UTC (rev 128301)
+++ Products.PluggableAuthService/trunk/Products/PluggableAuthService/utils.py	2012-11-16 00:54:57 UTC (rev 128302)
@@ -12,8 +12,9 @@
 #
 ##############################################################################
 import binascii
+import functools
+import inspect
 import os
-import unittest
 try:
     from hashlib import sha1 as sha
 except:
@@ -124,6 +125,7 @@
     """
         Retrieve a TestSuite from 'file'.
     """
+    import unittest
     module_name = module_name_from_path( file )
     loader = unittest.defaultTestLoader
     try:
@@ -141,6 +143,7 @@
     """
         Walk the product and build a unittest.TestSuite aggregating tests.
     """
+    import unittest
     os.path.walk( from_dir, remove_stale_bytecode, None )
     test_files = find_unit_test_files( from_dir, test_prefix )
     test_files.sort()
@@ -215,6 +218,11 @@
 
 class CSRFToken(object):
     """ View helper for rendering CSRF token in templates.
+
+    E.g., in every protected form, add this::
+
+      <input type="hidden" name="csrf_token"
+             tal:attributes="value context/@@csrf_token" />
     """
     security = ClassSecurityInfo()
     security.declareObjectPublic()
@@ -223,3 +231,15 @@
         self.request = request
     def __call__(self):
         return getCSRFToken(self.request)
+
+
+def csrf_only(wrapped):  # Decorator: run checkCSRFToken(REQUEST) before calling 'wrapped'; 'wrapped' must name a REQUEST parameter.
+    args, varargs, kwargs, defaults = inspect.getargspec(wrapped)  # Only explicitly named parameters appear in 'args'.
+    if 'REQUEST' in args:  # REQUEST may be required or defaulted; a **kw catch-all does not qualify.
+        def wrapper(REQUEST, *a, **kw):  # NOTE(review): REQUEST is the wrapper's FIRST positional param, so 'self' (on a method) or any positional call binds to it; consider locating it via args.index('REQUEST').
+            checkCSRFToken(REQUEST)  # Token check happens before 'wrapped' runs; the tests expect BadRequest on a missing/mismatched token.
+            return wrapped(REQUEST=REQUEST, *a, **kw)  # REQUEST is always forwarded by keyword; remaining args pass through unchanged.
+    else:
+        raise ValueError("Method doesn't name request")  # Fail at decoration time rather than at call time.
+    functools.update_wrapper(wrapper, wrapped)  # Copies __name__/__module__/__doc__; the wrapper's own parameter list stays (REQUEST, *a, **kw).
+    return wrapper



More information about the checkins mailing list