Skip to content

Commit 7689d68

Browse files
Merge pull request googleapis#281 from dhermes/100-percent-coverage-xsrfutil
100% test coverage for xsrfutil module.
2 parents 9025e23 + 4274129 commit 7689d68

File tree

2 files changed

+198
-22
lines changed

2 files changed

+198
-22
lines changed

oauth2client/xsrfutil.py

Lines changed: 10 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -16,10 +16,12 @@
1616
"""Helper methods for creating & verifying XSRF tokens."""
1717

1818
import base64
19+
import binascii
1920
import hmac
21+
import six
2022
import time
2123

22-
import six
24+
from oauth2client._helpers import _to_bytes
2325
from oauth2client import util
2426

2527
__authors__ = [
@@ -31,20 +33,11 @@
3133
DELIMITER = b':'
3234

3335
# 1 hour in seconds
34-
DEFAULT_TIMEOUT_SECS = 1 * 60 * 60
35-
36-
37-
def _force_bytes(s):
38-
if isinstance(s, bytes):
39-
return s
40-
s = str(s)
41-
if isinstance(s, six.text_type):
42-
return s.encode('utf-8')
43-
return s
36+
DEFAULT_TIMEOUT_SECS = 60 * 60
4437

4538

4639
@util.positional(2)
47-
def generate_token(key, user_id, action_id="", when=None):
40+
def generate_token(key, user_id, action_id='', when=None):
4841
"""Generates a URL-safe token for the given user, action, time tuple.
4942
5043
Args:
@@ -58,12 +51,12 @@ def generate_token(key, user_id, action_id="", when=None):
5851
Returns:
5952
A string XSRF protection token.
6053
"""
61-
when = _force_bytes(when or int(time.time()))
62-
digester = hmac.new(_force_bytes(key))
63-
digester.update(_force_bytes(user_id))
54+
digester = hmac.new(_to_bytes(key, encoding='utf-8'))
55+
digester.update(_to_bytes(str(user_id), encoding='utf-8'))
6456
digester.update(DELIMITER)
65-
digester.update(_force_bytes(action_id))
57+
digester.update(_to_bytes(action_id, encoding='utf-8'))
6658
digester.update(DELIMITER)
59+
when = _to_bytes(str(when or int(time.time())), encoding='utf-8')
6760
digester.update(when)
6861
digest = digester.digest()
6962

@@ -94,7 +87,7 @@ def validate_token(key, token, user_id, action_id="", current_time=None):
9487
try:
9588
decoded = base64.urlsafe_b64decode(token)
9689
token_time = int(decoded.split(DELIMITER)[-1])
97-
except (TypeError, ValueError):
90+
except (TypeError, ValueError, binascii.Error):
9891
return False
9992
if current_time is None:
10093
current_time = time.time()

tests/test_xsrfutil.py

Lines changed: 188 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -16,24 +16,207 @@
1616
Unit tests for oauth2client.xsrfutil.
1717
"""
1818

19+
import base64
1920
import unittest
2021

22+
import mock
23+
24+
from oauth2client._helpers import _to_bytes
2125
from oauth2client import xsrfutil
2226

2327
# Jan 17 2008, 5:40PM
24-
TEST_KEY = 'test key'
28+
TEST_KEY = b'test key'
29+
# Jan. 17, 2008 22:40:32.081230 UTC
2530
TEST_TIME = 1200609642081230
2631
TEST_USER_ID_1 = 123832983
2732
TEST_USER_ID_2 = 938297432
28-
TEST_ACTION_ID_1 = 'some_action'
29-
TEST_ACTION_ID_2 = 'some_other_action'
30-
TEST_EXTRA_INFO_1 = 'extra_info_1'
31-
TEST_EXTRA_INFO_2 = 'more_extra_info'
33+
TEST_ACTION_ID_1 = b'some_action'
34+
TEST_ACTION_ID_2 = b'some_other_action'
35+
TEST_EXTRA_INFO_1 = b'extra_info_1'
36+
TEST_EXTRA_INFO_2 = b'more_extra_info'
3237

3338

3439
__author__ = '[email protected] (Joe Gregorio)'
3540

3641

42+
class Test_generate_token(unittest.TestCase):
43+
44+
def test_bad_positional(self):
45+
# Need 2 positional arguments.
46+
self.assertRaises(TypeError, xsrfutil.generate_token, None)
47+
# At most 2 positional arguments.
48+
self.assertRaises(TypeError, xsrfutil.generate_token, None, None, None)
49+
50+
def test_it(self):
51+
digest = b'foobar'
52+
curr_time = 1440449755.74
53+
digester = mock.MagicMock()
54+
digester.digest = mock.MagicMock(name='digest', return_value=digest)
55+
with mock.patch('oauth2client.xsrfutil.hmac') as hmac:
56+
hmac.new = mock.MagicMock(name='new', return_value=digester)
57+
token = xsrfutil.generate_token(TEST_KEY,
58+
TEST_USER_ID_1,
59+
action_id=TEST_ACTION_ID_1,
60+
when=TEST_TIME)
61+
hmac.new.assert_called_once_with(TEST_KEY)
62+
digester.digest.assert_called_once_with()
63+
64+
expected_digest_calls = [
65+
mock.call.update(_to_bytes(str(TEST_USER_ID_1))),
66+
mock.call.update(xsrfutil.DELIMITER),
67+
mock.call.update(TEST_ACTION_ID_1),
68+
mock.call.update(xsrfutil.DELIMITER),
69+
mock.call.update(_to_bytes(str(TEST_TIME))),
70+
]
71+
self.assertEqual(digester.method_calls, expected_digest_calls)
72+
73+
expected_token_as_bytes = (digest + xsrfutil.DELIMITER +
74+
_to_bytes(str(TEST_TIME)))
75+
expected_token = base64.urlsafe_b64encode(
76+
expected_token_as_bytes)
77+
self.assertEqual(token, expected_token)
78+
79+
def test_with_system_time(self):
80+
digest = b'foobar'
81+
curr_time = 1440449755.74
82+
digester = mock.MagicMock()
83+
digester.digest = mock.MagicMock(name='digest', return_value=digest)
84+
with mock.patch('oauth2client.xsrfutil.hmac') as hmac:
85+
hmac.new = mock.MagicMock(name='new', return_value=digester)
86+
87+
with mock.patch('oauth2client.xsrfutil.time') as time:
88+
time.time = mock.MagicMock(name='time', return_value=curr_time)
89+
# when= is omitted
90+
token = xsrfutil.generate_token(TEST_KEY,
91+
TEST_USER_ID_1,
92+
action_id=TEST_ACTION_ID_1)
93+
94+
hmac.new.assert_called_once_with(TEST_KEY)
95+
time.time.assert_called_once_with()
96+
digester.digest.assert_called_once_with()
97+
98+
expected_digest_calls = [
99+
mock.call.update(_to_bytes(str(TEST_USER_ID_1))),
100+
mock.call.update(xsrfutil.DELIMITER),
101+
mock.call.update(TEST_ACTION_ID_1),
102+
mock.call.update(xsrfutil.DELIMITER),
103+
mock.call.update(_to_bytes(str(int(curr_time)))),
104+
]
105+
self.assertEqual(digester.method_calls, expected_digest_calls)
106+
107+
expected_token_as_bytes = (digest + xsrfutil.DELIMITER +
108+
_to_bytes(str(int(curr_time))))
109+
expected_token = base64.urlsafe_b64encode(
110+
expected_token_as_bytes)
111+
self.assertEqual(token, expected_token)
112+
113+
114+
class Test_validate_token(unittest.TestCase):
115+
116+
def test_bad_positional(self):
117+
# Need 3 positional arguments.
118+
self.assertRaises(TypeError, xsrfutil.validate_token, None, None)
119+
# At most 3 positional arguments.
120+
self.assertRaises(TypeError, xsrfutil.validate_token,
121+
None, None, None, None)
122+
123+
def test_no_token(self):
124+
key = token = user_id = None
125+
self.assertFalse(xsrfutil.validate_token(key, token, user_id))
126+
127+
def test_token_not_valid_base64(self):
128+
key = user_id = None
129+
token = b'a' # Bad padding
130+
self.assertFalse(xsrfutil.validate_token(key, token, user_id))
131+
132+
def test_token_non_integer(self):
133+
key = user_id = None
134+
token = base64.b64encode(b'abc' + xsrfutil.DELIMITER + b'xyz')
135+
self.assertFalse(xsrfutil.validate_token(key, token, user_id))
136+
137+
def test_token_too_old_implicit_current_time(self):
138+
token_time = 123456789
139+
curr_time = token_time + xsrfutil.DEFAULT_TIMEOUT_SECS + 1
140+
141+
key = user_id = None
142+
token = base64.b64encode(_to_bytes(str(token_time)))
143+
with mock.patch('oauth2client.xsrfutil.time') as time:
144+
time.time = mock.MagicMock(name='time', return_value=curr_time)
145+
self.assertFalse(xsrfutil.validate_token(key, token, user_id))
146+
time.time.assert_called_once_with()
147+
148+
def test_token_too_old_explicit_current_time(self):
149+
token_time = 123456789
150+
curr_time = token_time + xsrfutil.DEFAULT_TIMEOUT_SECS + 1
151+
152+
key = user_id = None
153+
token = base64.b64encode(_to_bytes(str(token_time)))
154+
self.assertFalse(xsrfutil.validate_token(key, token, user_id,
155+
current_time=curr_time))
156+
157+
def test_token_length_differs_from_generated(self):
158+
token_time = 123456789
159+
# Make sure it isn't too old.
160+
curr_time = token_time + xsrfutil.DEFAULT_TIMEOUT_SECS - 1
161+
162+
key = object()
163+
user_id = object()
164+
action_id = object()
165+
token = base64.b64encode(_to_bytes(str(token_time)))
166+
generated_token = b'a'
167+
# Make sure the token length comparison will fail.
168+
self.assertNotEqual(len(token), len(generated_token))
169+
170+
with mock.patch('oauth2client.xsrfutil.generate_token',
171+
return_value=generated_token) as gen_tok:
172+
self.assertFalse(xsrfutil.validate_token(key, token, user_id,
173+
current_time=curr_time,
174+
action_id=action_id))
175+
gen_tok.assert_called_once_with(key, user_id, action_id=action_id,
176+
when=token_time)
177+
178+
def test_token_differs_from_generated_but_same_length(self):
179+
token_time = 123456789
180+
# Make sure it isn't too old.
181+
curr_time = token_time + xsrfutil.DEFAULT_TIMEOUT_SECS - 1
182+
183+
key = object()
184+
user_id = object()
185+
action_id = object()
186+
token = base64.b64encode(_to_bytes(str(token_time)))
187+
# It is encoded as b'MTIzNDU2Nzg5', which has length 12.
188+
generated_token = b'M' * 12
189+
# Make sure the token length comparison will succeed, but the token
190+
# comparison will fail.
191+
self.assertEqual(len(token), len(generated_token))
192+
self.assertNotEqual(token, generated_token)
193+
194+
with mock.patch('oauth2client.xsrfutil.generate_token',
195+
return_value=generated_token) as gen_tok:
196+
self.assertFalse(xsrfutil.validate_token(key, token, user_id,
197+
current_time=curr_time,
198+
action_id=action_id))
199+
gen_tok.assert_called_once_with(key, user_id, action_id=action_id,
200+
when=token_time)
201+
202+
def test_success(self):
203+
token_time = 123456789
204+
# Make sure it isn't too old.
205+
curr_time = token_time + xsrfutil.DEFAULT_TIMEOUT_SECS - 1
206+
207+
key = object()
208+
user_id = object()
209+
action_id = object()
210+
token = base64.b64encode(_to_bytes(str(token_time)))
211+
with mock.patch('oauth2client.xsrfutil.generate_token',
212+
return_value=token) as gen_tok:
213+
self.assertTrue(xsrfutil.validate_token(key, token, user_id,
214+
current_time=curr_time,
215+
action_id=action_id))
216+
gen_tok.assert_called_once_with(key, user_id, action_id=action_id,
217+
when=token_time)
218+
219+
37220
class XsrfUtilTests(unittest.TestCase):
38221
"""Test xsrfutil functions."""
39222

0 commit comments

Comments
 (0)