[Checkins] SVN: Products.PluggableAuthService/trunk/ Add 'getCSRFToken' and 'checkCSRFToken' helpers + 'CSRFToken' view.
Tres Seaver
cvs-admin at zope.org
Fri Nov 16 00:54:57 UTC 2012
Log message for revision 128301:
Add 'getCSRFToken' and 'checkCSRFToken' helpers + 'CSRFToken' view.
Changed:
_U Products.PluggableAuthService/trunk/
U Products.PluggableAuthService/trunk/Products/PluggableAuthService/configure.zcml
U Products.PluggableAuthService/trunk/Products/PluggableAuthService/tests/test_utils.py
U Products.PluggableAuthService/trunk/Products/PluggableAuthService/utils.py
-=-
Modified: Products.PluggableAuthService/trunk/Products/PluggableAuthService/configure.zcml
===================================================================
--- Products.PluggableAuthService/trunk/Products/PluggableAuthService/configure.zcml 2012-11-16 00:54:56 UTC (rev 128300)
+++ Products.PluggableAuthService/trunk/Products/PluggableAuthService/configure.zcml 2012-11-16 00:54:57 UTC (rev 128301)
@@ -1,7 +1,14 @@
<configure
xmlns="http://namespaces.zope.org/zope"
- >
+ xmlns:browser="http://namespaces.zope.org/browser">
+ <!-- Registers utils.CSRFToken as the 'csrf_token' view on any object
+ (for="*"), accessible under the zope.Public permission, so templates
+ can render the session's CSRF token through it. -->
+ <browser:page
+ for="*"
+ name="csrf_token"
+ class=".utils.CSRFToken"
+ permission="zope.Public"
+ />
+
<include file="exportimport.zcml" />
<include file="events.zcml" />
Modified: Products.PluggableAuthService/trunk/Products/PluggableAuthService/tests/test_utils.py
===================================================================
--- Products.PluggableAuthService/trunk/Products/PluggableAuthService/tests/test_utils.py 2012-11-16 00:54:56 UTC (rev 128300)
+++ Products.PluggableAuthService/trunk/Products/PluggableAuthService/tests/test_utils.py 2012-11-16 00:54:57 UTC (rev 128301)
@@ -11,82 +11,197 @@
# FOR A PARTICULAR PURPOSE.
#
##############################################################################
-
import unittest
-from Products.PluggableAuthService.utils import createViewName
-from Products.PluggableAuthService.utils import createKeywords
+class Test_createViewName(unittest.TestCase):
+ """ Tests for 'utils.createViewName'.
+
+ Expected results: the method name and user handle are joined with '-';
+ a None handle yields the bare method name; byte strings pass through
+ unchanged, while unicode input comes back UTF-8 encoded.
+ """
-class UtilityTests(unittest.TestCase):
+ # Import the function under test lazily and forward all arguments to it.
+ def _callFUT(self, *args, **kw):
+ from Products.PluggableAuthService.utils import createViewName
+ return createViewName(*args, **kw)
- def test_createViewName(self):
- self.assertEqual(createViewName('foo', 'bar'), 'foo-bar')
+ def test_simple(self):
+ self.assertEqual(self._callFUT('foo', 'bar'), 'foo-bar')
- def test_createViewName_no_user_handle(self):
- self.assertEqual(createViewName('foo', None), 'foo')
+ def test_no_user_handle(self):
+ self.assertEqual(self._callFUT('foo', None), 'foo')
- def test_createViewName_latin1_umlaut_in_method(self):
- self.assertEqual(createViewName('f\366o'), 'f\366o')
+ def test_latin1_umlaut_in_method(self):
+ self.assertEqual(self._callFUT('f\366o'), 'f\366o')
- def test_createViewName_utf8_umlaut_in_method(self):
- self.assertEqual(createViewName('f\303\266o'), 'f\303\266o')
+ def test_utf8_umlaut_in_method(self):
+ self.assertEqual(self._callFUT('f\303\266o'), 'f\303\266o')
- def test_createViewName_unicode_umlaut_in_method(self):
- self.assertEqual(createViewName(u'f\366o'), 'f\303\266o')
+ def test_unicode_umlaut_in_method(self):
+ self.assertEqual(self._callFUT(u'f\366o'), 'f\303\266o')
- def test_createViewName_latin1_umlaut_in_handle(self):
- self.assertEqual(createViewName('foo', 'b\344r'), 'foo-b\344r')
+ def test_latin1_umlaut_in_handle(self):
+ self.assertEqual(self._callFUT('foo', 'b\344r'), 'foo-b\344r')
- def test_createViewName_utf8_umlaut_in_handle(self):
- self.assertEqual(createViewName('foo', 'b\303\244r'), 'foo-b\303\244r')
+ def test_utf8_umlaut_in_handle(self):
+ self.assertEqual(self._callFUT('foo', 'b\303\244r'), 'foo-b\303\244r')
- def test_createViewName_unicode_umlaut_in_handle(self):
- self.assertEqual(createViewName('foo', u'b\344r'), 'foo-b\303\244r')
+ def test_unicode_umlaut_in_handle(self):
+ self.assertEqual(self._callFUT('foo', u'b\344r'), 'foo-b\303\244r')
- def test_createKeywords(self):
+
+class Test_createKeywords(unittest.TestCase):
+ """ Tests for 'utils.createKeywords'.
+
+ Each case hashes the expected (key, value) items with the module-level
+ '_createHashedValue' helper and expects createKeywords(**kw) to return
+ {'keywords': <that digest>}.
+ """
+
+ def _callFUT(self, *args, **kw):
+ from Products.PluggableAuthService.utils import createKeywords
+ return createKeywords(*args, **kw)
+
+ def test_simple(self):
_ITEMS = (('foo', 'bar'),)
hashed = _createHashedValue(_ITEMS)
- self.assertEqual(createKeywords(foo='bar'),
+ self.assertEqual(self._callFUT(foo='bar'),
{'keywords': hashed})
+ # NOTE(review): the methods below are unchanged from the old UtilityTests
+ # class and keep its 'test_createKeywords_' prefix; consider renaming them
+ # (e.g. 'test_multiple') to match 'test_simple'.
def test_createKeywords_multiple(self):
_ITEMS = (('foo', 'bar'), ('baz', 'peng'))
hashed = _createHashedValue(_ITEMS)
- self.assertEqual(createKeywords(foo='bar', baz='peng'),
+ self.assertEqual(self._callFUT(foo='bar', baz='peng'),
{'keywords': hashed})
def test_createKeywords_latin1_umlaut(self):
_ITEMS = (('foo', 'bar'), ('baz', 'M\344dchen'))
hashed = _createHashedValue(_ITEMS)
- self.assertEqual(createKeywords(foo='bar', baz='M\344dchen'),
+ self.assertEqual(self._callFUT(foo='bar', baz='M\344dchen'),
{'keywords': hashed})
def test_createKeywords_utf8_umlaut(self):
_ITEMS = (('foo', 'bar'), ('baz', 'M\303\244dchen'))
hashed = _createHashedValue(_ITEMS)
- self.assertEqual(createKeywords(foo='bar', baz='M\303\244dchen'),
+ self.assertEqual(self._callFUT(foo='bar', baz='M\303\244dchen'),
{'keywords': hashed})
def test_createKeywords_unicode_umlaut(self):
_ITEMS = (('foo', 'bar'), ('baz', u'M\344dchen'))
hashed = _createHashedValue(_ITEMS)
- self.assertEqual(createKeywords(foo='bar', baz=u'M\344dchen'),
+ self.assertEqual(self._callFUT(foo='bar', baz=u'M\344dchen'),
{'keywords': hashed})
def test_createKeywords_utf16_umlaut(self):
_ITEMS = (('foo', 'bar'), ('baz', u'M\344dchen'.encode('utf-16')))
hashed = _createHashedValue(_ITEMS)
- self.assertEqual(createKeywords(foo='bar',
+ self.assertEqual(self._callFUT(foo='bar',
baz=u'M\344dchen'.encode('utf-16')),
{'keywords': hashed})
+ # NOTE(review): u'\u03a4\u03b6' is Greek (capital tau, small zeta), not
+ # Chinese; the test name is misleading.
def test_createKeywords_unicode_chinese(self):
_ITEMS = (('foo', 'bar'), ('baz', u'\u03a4\u03b6'))
hashed = _createHashedValue(_ITEMS)
- self.assertEqual(createKeywords(foo='bar', baz=u'\u03a4\u03b6'),
+ self.assertEqual(self._callFUT(foo='bar', baz=u'\u03a4\u03b6'),
{'keywords': hashed})
+
+def _makeRequestWSession(**session):
+    """ Return a minimal stand-in for a Zope request.
+
+    The fake is an empty mapping that carries a 'SESSION' dict (a copy of
+    the keyword arguments) and an empty 'form' dict.
+    """
+    class _FakeRequest(dict):
+        pass
+    fake = _FakeRequest()
+    fake.form = {}
+    fake.SESSION = dict(session)
+    return fake
+
+
+class Test_getCSRFToken(unittest.TestCase):
+    """ Tests for 'utils.getCSRFToken'.
+
+    The token lives in the session ('_csrft_' key), not on the request.
+    """
+
+    def _callFUT(self, *args, **kw):
+        from Products.PluggableAuthService.utils import getCSRFToken
+        return getCSRFToken(*args, **kw)
+
+    def test_wo_token_in_session(self):
+        request = _makeRequestWSession()
+        token = self._callFUT(request)
+        self.assertTrue(isinstance(token, str))
+        # 20 random bytes, hex-encoded.
+        self.assertEqual(len(token), 40)
+        self.assertFalse(set(token) - set('0123456789abcdef'))
+        # The freshly generated token must be persisted in the session ...
+        self.assertEqual(request.SESSION['_csrft_'], token)
+        # ... so that later calls hand back the same value.
+        self.assertEqual(self._callFUT(request), token)
+
+    def test_w_token_in_session(self):
+        request = _makeRequestWSession()
+        request.SESSION['_csrft_'] = 'deadbeef'
+        token = self._callFUT(request)
+        self.assertEqual(token, 'deadbeef')
+
+
+class Test_checkCSRFToken(unittest.TestCase):
+    """ Tests for 'utils.checkCSRFToken'.
+
+    Test names encode the scenario: whether a token is present in the
+    session and/or form ('w' = with, 'wo' = without), whether the submitted
+    value hits or misses, and whether the call should raise ('w_raises') or
+    return a boolean ('wo_raises').
+    """
+
+    def _callFUT(self, *args, **kw):
+        from Products.PluggableAuthService.utils import checkCSRFToken
+        return checkCSRFToken(*args, **kw)
+
+    def test_wo_token_in_session_or_form_w_raises(self):
+        from ZPublisher import BadRequest
+        request = _makeRequestWSession()
+        self.assertRaises(BadRequest, self._callFUT, request)
+
+    def test_wo_token_in_session_or_form_wo_raises(self):
+        request = _makeRequestWSession()
+        self.assertFalse(self._callFUT(request, raises=False))
+
+    def test_wo_token_in_session_w_token_in_form_w_raises(self):
+        from ZPublisher import BadRequest
+        request = _makeRequestWSession()
+        request.form['csrf_token'] = 'deadbeef'
+        self.assertRaises(BadRequest, self._callFUT, request)
+
+    def test_wo_token_in_session_w_token_in_form_wo_raises(self):
+        request = _makeRequestWSession()
+        request.form['csrf_token'] = 'deadbeef'
+        self.assertFalse(self._callFUT(request, raises=False))
+
+    def test_w_token_in_session_wo_token_in_form_w_raises(self):
+        from ZPublisher import BadRequest
+        request = _makeRequestWSession(_csrft_='deadbeef')
+        self.assertRaises(BadRequest, self._callFUT, request)
+
+    def test_w_token_in_session_wo_token_in_form_wo_raises(self):
+        request = _makeRequestWSession(_csrft_='deadbeef')
+        self.assertFalse(self._callFUT(request, raises=False))
+
+    def test_w_token_in_session_w_token_in_form_miss_w_raises(self):
+        from ZPublisher import BadRequest
+        request = _makeRequestWSession(_csrft_='deadbeef')
+        request.form['csrf_token'] = 'bab3l0f'
+        self.assertRaises(BadRequest, self._callFUT, request)
+
+    def test_w_token_in_session_w_token_in_form_miss_wo_raises(self):
+        request = _makeRequestWSession(_csrft_='deadbeef')
+        request.form['csrf_token'] = 'bab3l0f'
+        self.assertFalse(self._callFUT(request, raises=False))
+
+    def test_w_token_in_session_w_token_in_form_hit(self):
+        request = _makeRequestWSession(_csrft_='deadbeef')
+        request.form['csrf_token'] = 'deadbeef'
+        self.assertTrue(self._callFUT(request))
+
+    def test_w_custom_token_name_hit(self):
+        # 'token' selects which form field holds the submitted value.
+        request = _makeRequestWSession(_csrft_='deadbeef')
+        request.form['my_token'] = 'deadbeef'
+        self.assertTrue(self._callFUT(request, token='my_token'))
+
+    def test_w_custom_token_name_ignores_default_field(self):
+        request = _makeRequestWSession(_csrft_='deadbeef')
+        request.form['csrf_token'] = 'deadbeef'
+        self.assertFalse(self._callFUT(request, token='my_token',
+                                       raises=False))
+
+    def test_w_token_in_session_w_repeated_form_field_wo_raises(self):
+        # A field submitted more than once may arrive as a list; it must
+        # never be accepted as a match.
+        request = _makeRequestWSession(_csrft_='deadbeef')
+        request.form['csrf_token'] = ['deadbeef', 'deadbeef']
+        self.assertFalse(self._callFUT(request, raises=False))
+
+
+class CSRFTokenTests(unittest.TestCase):
+    """ Tests for the 'CSRFToken' browser view.
+    """
+
+    def _getTargetClass(self):
+        from Products.PluggableAuthService.utils import CSRFToken
+        return CSRFToken
+
+    def _makeOne(self, context=None, request=None):
+        if context is None:
+            context = object()
+        if request is None:
+            request = _makeRequestWSession()
+        return self._getTargetClass()(context, request)
+
+    def test_ctor_stores_context_and_request(self):
+        context = object()
+        request = _makeRequestWSession()
+        view = self._makeOne(context=context, request=request)
+        self.assertTrue(view.context is context)
+        self.assertTrue(view.request is request)
+
+    def test_wo_token_in_session(self):
+        request = _makeRequestWSession()
+        view = self._makeOne(request=request)
+        value = view()
+        self.assertTrue(isinstance(value, str))
+        self.assertFalse(set(value) - set('0123456789abcdef'))
+        # The rendered token is stored in the session for later checks.
+        self.assertEqual(request.SESSION['_csrft_'], value)
+
+    def test_w_token_in_session(self):
+        request = _makeRequestWSession()
+        request.SESSION['_csrft_'] = 'deadbeef'
+        view = self._makeOne(request=request)
+        self.assertEqual(view(), 'deadbeef')
+
+
def _createHashedValue(items):
try:
from hashlib import sha1 as sha
@@ -107,8 +222,9 @@
def test_suite():
+ """ Collect the TestCase classes defined in this module into one suite.
+ """
return unittest.TestSuite((
- unittest.makeSuite(UtilityTests),
+ unittest.makeSuite(Test_createViewName),
+ unittest.makeSuite(Test_createKeywords),
+ unittest.makeSuite(Test_getCSRFToken),
+ unittest.makeSuite(Test_checkCSRFToken),
+ unittest.makeSuite(CSRFTokenTests),
))
-
-if __name__ == '__main__':
- unittest.main(defaultTest='test_suite')
Modified: Products.PluggableAuthService/trunk/Products/PluggableAuthService/utils.py
===================================================================
--- Products.PluggableAuthService/trunk/Products/PluggableAuthService/utils.py 2012-11-16 00:54:56 UTC (rev 128300)
+++ Products.PluggableAuthService/trunk/Products/PluggableAuthService/utils.py 2012-11-16 00:54:57 UTC (rev 128301)
@@ -11,6 +11,7 @@
# FOR A PARTICULAR PURPOSE.
#
##############################################################################
+import binascii
import os
import unittest
try:
@@ -19,7 +20,9 @@
from sha import new as sha
+from AccessControl import ClassSecurityInfo
from App.Common import package_home
+from ZPublisher import BadRequest
from zope import interface
@@ -187,3 +190,36 @@
return {'keywords': keywords.hexdigest()}
+def getCSRFToken(request):
+    """ Return the CSRF token kept in the request's session.
+
+    On first use (no '_csrft_' entry in 'request.SESSION') a new token --
+    20 bytes from 'os.urandom', hex-encoded to 40 characters -- is stored
+    under that key before being returned; later calls return the stored
+    value unchanged.
+    """
+    # NOTE(review): binascii.hexlify returns bytes on Python 3; revisit the
+    # comparison against form strings in checkCSRFToken if porting.
+    session = request.SESSION
+    existing = session.get('_csrft_', None)
+    if existing is not None:
+        return existing
+    fresh = binascii.hexlify(os.urandom(20))
+    session['_csrft_'] = fresh
+    return fresh
+
+def _constantTimeEquals(submitted, expected):
+    """ Compare two tokens without short-circuiting on the first mismatch.
+
+    For two strings (byte or unicode) every character is examined, so the
+    time taken reveals at most the lengths, not how long a matching prefix
+    an attacker has guessed. Any non-string value on either side (e.g. None
+    for a missing form field, or a list for a repeated one) falls back to
+    plain '==', which gives the same result as before.
+    """
+    # hmac.compare_digest offers the same guarantee on Python >= 2.7.7.
+    string_types = (str, type(u''))
+    if not (isinstance(submitted, string_types) and
+            isinstance(expected, string_types)):
+        return submitted == expected
+    if len(submitted) != len(expected):
+        return False
+    difference = 0
+    for left, right in zip(submitted, expected):
+        difference |= ord(left) ^ ord(right)
+    return difference == 0
+
+
+def checkCSRFToken(request, token='csrf_token', raises=True):
+    """ Check CSRF token in session against token formdata.
+
+    'token' is the name of the form field holding the submitted value.
+
+    If the values don't match, and 'raises' is True, raise a BadRequest.
+
+    If the values don't match, and 'raises' is False, return False.
+
+    If the values match, return True.
+
+    The session token is obtained via 'getCSRFToken', so when the session
+    holds none yet a fresh one is generated and stored, and the check fails.
+    The comparison runs in constant time for string values.
+    """
+    submitted = request.form.get(token)
+    expected = getCSRFToken(request)
+    if not _constantTimeEquals(submitted, expected):
+        if raises:
+            raise BadRequest('incorrect CSRF token')
+        return False
+    return True
+
+
+class CSRFToken(object):
+    """ Browser view that renders the current session's CSRF token.
+
+    Templates can call it to obtain the value to submit in the form field
+    that 'checkCSRFToken' reads ('csrf_token' by default).
+    """
+    security = ClassSecurityInfo()
+    security.declareObjectPublic()
+    # NOTE(review): ClassSecurityInfo declarations are normally applied by
+    # calling InitializeClass on the class; no such call is visible here.
+    # Confirm whether the ZCML page registration makes it unnecessary.
+
+    def __init__(self, context, request):
+        self.request = request
+        self.context = context
+
+    def __call__(self):
+        """ Return the session's CSRF token, creating one if needed. """
+        return getCSRFToken(self.request)
More information about the checkins
mailing list