[Checkins] SVN: Products.PluggableAuthService/trunk/ Add 'getCSRFToken' and 'checkCSRFToken' helpers + 'CSRFToken' view.

Tres Seaver cvs-admin at zope.org
Fri Nov 16 00:54:57 UTC 2012


Log message for revision 128301:
  Add 'getCSRFToken' and 'checkCSRFToken' helpers + 'CSRFToken' view.

Changed:
  _U  Products.PluggableAuthService/trunk/
  U   Products.PluggableAuthService/trunk/Products/PluggableAuthService/configure.zcml
  U   Products.PluggableAuthService/trunk/Products/PluggableAuthService/tests/test_utils.py
  U   Products.PluggableAuthService/trunk/Products/PluggableAuthService/utils.py

-=-
Modified: Products.PluggableAuthService/trunk/Products/PluggableAuthService/configure.zcml
===================================================================
--- Products.PluggableAuthService/trunk/Products/PluggableAuthService/configure.zcml	2012-11-16 00:54:56 UTC (rev 128300)
+++ Products.PluggableAuthService/trunk/Products/PluggableAuthService/configure.zcml	2012-11-16 00:54:57 UTC (rev 128301)
@@ -1,7 +1,14 @@
 <configure
     xmlns="http://namespaces.zope.org/zope"
-    >
+    xmlns:browser="http://namespaces.zope.org/browser">
 
+  <browser:page
+      for="*"
+      name="csrf_token"
+      class=".utils.CSRFToken"
+      permission="zope.Public"
+      />
+
   <include file="exportimport.zcml" />
 
   <include file="events.zcml" />

Modified: Products.PluggableAuthService/trunk/Products/PluggableAuthService/tests/test_utils.py
===================================================================
--- Products.PluggableAuthService/trunk/Products/PluggableAuthService/tests/test_utils.py	2012-11-16 00:54:56 UTC (rev 128300)
+++ Products.PluggableAuthService/trunk/Products/PluggableAuthService/tests/test_utils.py	2012-11-16 00:54:57 UTC (rev 128301)
@@ -11,82 +11,197 @@
 # FOR A PARTICULAR PURPOSE.
 #
 ##############################################################################
-
 import unittest
 
-from Products.PluggableAuthService.utils import createViewName
-from Products.PluggableAuthService.utils import createKeywords
 
+class Test_createViewName(unittest.TestCase):
 
-class UtilityTests(unittest.TestCase):
+    def _callFUT(self, *args, **kw):
+        from Products.PluggableAuthService.utils import createViewName
+        return createViewName(*args, **kw)
 
-    def test_createViewName(self):
-        self.assertEqual(createViewName('foo', 'bar'), 'foo-bar')
+    def test_simple(self):
+        self.assertEqual(self._callFUT('foo', 'bar'), 'foo-bar')
 
-    def test_createViewName_no_user_handle(self):
-        self.assertEqual(createViewName('foo', None), 'foo')
+    def test_no_user_handle(self):
+        self.assertEqual(self._callFUT('foo', None), 'foo')
 
-    def test_createViewName_latin1_umlaut_in_method(self):
-        self.assertEqual(createViewName('f\366o'), 'f\366o')
+    def test_latin1_umlaut_in_method(self):
+        self.assertEqual(self._callFUT('f\366o'), 'f\366o')
 
-    def test_createViewName_utf8_umlaut_in_method(self):
-        self.assertEqual(createViewName('f\303\266o'), 'f\303\266o')
+    def test_utf8_umlaut_in_method(self):
+        self.assertEqual(self._callFUT('f\303\266o'), 'f\303\266o')
 
-    def test_createViewName_unicode_umlaut_in_method(self):
-        self.assertEqual(createViewName(u'f\366o'), 'f\303\266o')
+    def test_unicode_umlaut_in_method(self):
+        self.assertEqual(self._callFUT(u'f\366o'), 'f\303\266o')
 
-    def test_createViewName_latin1_umlaut_in_handle(self):
-        self.assertEqual(createViewName('foo', 'b\344r'), 'foo-b\344r')
+    def test_latin1_umlaut_in_handle(self):
+        self.assertEqual(self._callFUT('foo', 'b\344r'), 'foo-b\344r')
 
-    def test_createViewName_utf8_umlaut_in_handle(self):
-        self.assertEqual(createViewName('foo', 'b\303\244r'), 'foo-b\303\244r')
+    def test_utf8_umlaut_in_handle(self):
+        self.assertEqual(self._callFUT('foo', 'b\303\244r'), 'foo-b\303\244r')
 
-    def test_createViewName_unicode_umlaut_in_handle(self):
-        self.assertEqual(createViewName('foo', u'b\344r'), 'foo-b\303\244r')
+    def test_unicode_umlaut_in_handle(self):
+        self.assertEqual(self._callFUT('foo', u'b\344r'), 'foo-b\303\244r')
 
-    def test_createKeywords(self):
+
+class Test_createKeywords(unittest.TestCase):
+
+    def _callFUT(self, *args, **kw):
+        from Products.PluggableAuthService.utils import createKeywords
+        return createKeywords(*args, **kw)
+
+    def test_simple(self):
         _ITEMS = (('foo', 'bar'),)
         hashed = _createHashedValue(_ITEMS)
-        self.assertEqual(createKeywords(foo='bar'),
+        self.assertEqual(self._callFUT(foo='bar'),
                          {'keywords': hashed})
 
     def test_createKeywords_multiple(self):
         _ITEMS = (('foo', 'bar'), ('baz', 'peng'))
         hashed = _createHashedValue(_ITEMS)
-        self.assertEqual(createKeywords(foo='bar', baz='peng'),
+        self.assertEqual(self._callFUT(foo='bar', baz='peng'),
                          {'keywords': hashed})
 
     def test_createKeywords_latin1_umlaut(self):
         _ITEMS = (('foo', 'bar'), ('baz', 'M\344dchen'))
         hashed = _createHashedValue(_ITEMS)
-        self.assertEqual(createKeywords(foo='bar', baz='M\344dchen'),
+        self.assertEqual(self._callFUT(foo='bar', baz='M\344dchen'),
                          {'keywords': hashed})
 
     def test_createKeywords_utf8_umlaut(self):
         _ITEMS = (('foo', 'bar'), ('baz', 'M\303\244dchen'))
         hashed = _createHashedValue(_ITEMS)
-        self.assertEqual(createKeywords(foo='bar', baz='M\303\244dchen'),
+        self.assertEqual(self._callFUT(foo='bar', baz='M\303\244dchen'),
                          {'keywords': hashed})
 
     def test_createKeywords_unicode_umlaut(self):
         _ITEMS = (('foo', 'bar'), ('baz', u'M\344dchen'))
         hashed = _createHashedValue(_ITEMS)
-        self.assertEqual(createKeywords(foo='bar', baz=u'M\344dchen'),
+        self.assertEqual(self._callFUT(foo='bar', baz=u'M\344dchen'),
                          {'keywords': hashed})
 
     def test_createKeywords_utf16_umlaut(self):
         _ITEMS = (('foo', 'bar'), ('baz', u'M\344dchen'.encode('utf-16')))
         hashed = _createHashedValue(_ITEMS)
-        self.assertEqual(createKeywords(foo='bar',
+        self.assertEqual(self._callFUT(foo='bar',
                                         baz=u'M\344dchen'.encode('utf-16')),
                          {'keywords': hashed})
 
     def test_createKeywords_unicode_chinese(self):
         _ITEMS = (('foo', 'bar'), ('baz', u'\u03a4\u03b6'))
         hashed = _createHashedValue(_ITEMS)
-        self.assertEqual(createKeywords(foo='bar', baz=u'\u03a4\u03b6'),
+        self.assertEqual(self._callFUT(foo='bar', baz=u'\u03a4\u03b6'),
                 {'keywords': hashed})
 
+
+def _makeRequestWSession(**session):
+    class _Request(dict):
+        pass
+    request = _Request()
+    request.SESSION = session.copy()
+    request.form = {}
+    return request
+
+
+class Test_getCSRFToken(unittest.TestCase):
+
+    def _callFUT(self, *args, **kw):
+        from Products.PluggableAuthService.utils import getCSRFToken
+        return getCSRFToken(*args, **kw)
+
+    def test_wo_token_in_request(self):
+        request = _makeRequestWSession()
+        token = self._callFUT(request)
+        self.assertTrue(isinstance(token, str))
+        self.assertFalse(set(token) - set('0123456789abcdef'))
+
+    def test_w_token_in_request(self):
+        request = _makeRequestWSession()
+        request.SESSION['_csrft_'] = 'deadbeef'
+        token = self._callFUT(request)
+        self.assertEqual(token, 'deadbeef')
+
+
+class Test_checkCSRFToken(unittest.TestCase):
+
+    def _callFUT(self, *args, **kw):
+        from Products.PluggableAuthService.utils import checkCSRFToken
+        return checkCSRFToken(*args, **kw)
+
+    def test_wo_token_in_session_or_form_w_raises(self):
+        from ZPublisher import BadRequest
+        request = _makeRequestWSession()
+        self.assertRaises(BadRequest, self._callFUT, request)
+
+    def test_wo_token_in_session_or_form_wo_raises(self):
+        request = _makeRequestWSession()
+        self.assertFalse(self._callFUT(request, raises=False))
+
+    def test_wo_token_in_session_w_token_in_form_w_raises(self):
+        from ZPublisher import BadRequest
+        request = _makeRequestWSession()
+        request.form['csrf_token'] = 'deadbeef'
+        self.assertRaises(BadRequest, self._callFUT, request)
+
+    def test_wo_token_in_session_w_token_in_form_wo_raises(self):
+        request = _makeRequestWSession()
+        request.form['csrf_token'] = 'deadbeef'
+        self.assertFalse(self._callFUT(request, raises=False))
+
+    def test_w_token_in_session_wo_token_in_form_w_raises(self):
+        from ZPublisher import BadRequest
+        request = _makeRequestWSession(_csrft_='deadbeef')
+        self.assertRaises(BadRequest, self._callFUT, request)
+
+    def test_w_token_in_session_wo_token_in_form_wo_raises(self):
+        request = _makeRequestWSession(_csrft_='deadbeef')
+        self.assertFalse(self._callFUT(request, raises=False))
+
+    def test_w_token_in_session_w_token_in_form_miss_w_raises(self):
+        from ZPublisher import BadRequest
+        request = _makeRequestWSession(_csrft_='deadbeef')
+        request.form['csrf_token'] = 'bab3l0f'
+        self.assertRaises(BadRequest, self._callFUT, request)
+
+    def test_w_token_in_session_w_token_in_form_miss_wo_raises(self):
+        request = _makeRequestWSession(_csrft_='deadbeef')
+        request.form['csrf_token'] = 'bab3l0f'
+        self.assertFalse(self._callFUT(request, raises=False))
+
+    def test_w_token_in_session_w_token_in_form_hit(self):
+        request = _makeRequestWSession(_csrft_='deadbeef')
+        request.form['csrf_token'] = 'deadbeef'
+        self.assertTrue(self._callFUT(request))
+
+
+class CSRFTokenTests(unittest.TestCase):
+
+    def _getTargetClass(self):
+        from Products.PluggableAuthService.utils import CSRFToken
+        return CSRFToken
+
+    def _makeOne(self, context=None, request=None):
+        if context is None:
+            context = object()
+        if request is None:
+            request = _makeRequestWSession()
+        return self._getTargetClass()(context, request)
+
+    def test_wo_token_in_request(self):
+        request = _makeRequestWSession()
+        token = self._makeOne(request=request)
+        value = token()
+        self.assertTrue(isinstance(value, str))
+        self.assertFalse(set(value) - set('0123456789abcdef'))
+
+    def test_w_token_in_request(self):
+        request = _makeRequestWSession()
+        request.SESSION['_csrft_'] = 'deadbeef'
+        token = self._makeOne(request=request)
+        self.assertEqual(token(), 'deadbeef')
+
+
 def _createHashedValue(items):
     try:
         from hashlib import sha1 as sha
@@ -107,8 +222,9 @@
 
 def test_suite():
     return unittest.TestSuite((
-        unittest.makeSuite(UtilityTests),
+        unittest.makeSuite(Test_createViewName),
+        unittest.makeSuite(Test_createKeywords),
+        unittest.makeSuite(Test_getCSRFToken),
+        unittest.makeSuite(Test_checkCSRFToken),
+        unittest.makeSuite(CSRFTokenTests),
     ))
-
-if __name__ == '__main__':
-    unittest.main(defaultTest='test_suite')

Modified: Products.PluggableAuthService/trunk/Products/PluggableAuthService/utils.py
===================================================================
--- Products.PluggableAuthService/trunk/Products/PluggableAuthService/utils.py	2012-11-16 00:54:56 UTC (rev 128300)
+++ Products.PluggableAuthService/trunk/Products/PluggableAuthService/utils.py	2012-11-16 00:54:57 UTC (rev 128301)
@@ -11,6 +11,7 @@
 # FOR A PARTICULAR PURPOSE.
 #
 ##############################################################################
+import binascii
 import os
 import unittest
 try:
@@ -19,7 +20,9 @@
     from sha import new as sha
 
 
+from AccessControl import ClassSecurityInfo
 from App.Common import package_home
+from ZPublisher import BadRequest
 
 
 from zope import interface
@@ -187,3 +190,36 @@
 
     return {'keywords': keywords.hexdigest()}
 
+def getCSRFToken(request):
+    session = request.SESSION
+    token = session.get('_csrft_', None)
+    if token is None:
+        token = session['_csrft_'] = binascii.hexlify(os.urandom(20))
+    return token
+
+def checkCSRFToken(request, token='csrf_token', raises=True):
+    """ Check CSRF token in session against token formdata.
+    
+    If the values don't match, and 'raises' is True, raise a BadRequest.
+    
+    If the values don't match, and 'raises' is False, return False.
+    
+    If the values match, return True.
+    """
+    if request.form.get(token) != getCSRFToken(request):
+        if raises:
+            raise BadRequest('incorrect CSRF token')
+        return False
+    return True
+
+
+class CSRFToken(object):
+    """ View helper for rendering CSRF token in templates.
+    """
+    security = ClassSecurityInfo()
+    security.declareObjectPublic()
+    def __init__(self, context, request):
+        self.context = context
+        self.request = request
+    def __call__(self):
+        return getCSRFToken(self.request)



More information about the checkins mailing list