Skip to content
This repository has been archived by the owner on Nov 5, 2019. It is now read-only.

Commit

Permalink
Getting to 100% coverage for xsrf_util module.
Browse files Browse the repository at this point in the history
  • Loading branch information
dhermes committed Aug 24, 2015
1 parent b821f6e commit ad048c7
Show file tree
Hide file tree
Showing 2 changed files with 199 additions and 22 deletions.
27 changes: 10 additions & 17 deletions oauth2client/xsrfutil.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,12 @@
"""Helper methods for creating & verifying XSRF tokens."""

import base64
import binascii
import hmac
import six
import time

import six
from oauth2client._helpers import _to_bytes
from oauth2client import util

__authors__ = [
Expand All @@ -31,20 +33,11 @@
DELIMITER = b':'

# 1 hour in seconds
DEFAULT_TIMEOUT_SECS = 1 * 60 * 60


def _force_bytes(s):
if isinstance(s, bytes):
return s
s = str(s)
if isinstance(s, six.text_type):
return s.encode('utf-8')
return s
DEFAULT_TIMEOUT_SECS = 60 * 60


@util.positional(2)
def generate_token(key, user_id, action_id="", when=None):
def generate_token(key, user_id, action_id='', when=None):
"""Generates a URL-safe token for the given user, action, time tuple.
Args:
Expand All @@ -58,12 +51,12 @@ def generate_token(key, user_id, action_id="", when=None):
Returns:
A string XSRF protection token.
"""
when = _force_bytes(when or int(time.time()))
digester = hmac.new(_force_bytes(key))
digester.update(_force_bytes(user_id))
digester = hmac.new(_to_bytes(key, encoding='utf-8'))
digester.update(_to_bytes(str(user_id), encoding='utf-8'))
digester.update(DELIMITER)
digester.update(_force_bytes(action_id))
digester.update(_to_bytes(action_id, encoding='utf-8'))
digester.update(DELIMITER)
when = _to_bytes(str(when or int(time.time())), encoding='utf-8')
digester.update(when)
digest = digester.digest()

Expand Down Expand Up @@ -94,7 +87,7 @@ def validate_token(key, token, user_id, action_id="", current_time=None):
try:
decoded = base64.urlsafe_b64decode(token)
token_time = int(decoded.split(DELIMITER)[-1])
except (TypeError, ValueError):
except (TypeError, ValueError, binascii.Error):
return False
if current_time is None:
current_time = time.time()
Expand Down
194 changes: 189 additions & 5 deletions tests/test_xsrfutil.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,24 +16,208 @@
Unit tests for oauth2client.xsrfutil.
"""

import base64
import unittest

import mock

from oauth2client._helpers import _to_bytes
from oauth2client import xsrfutil

# Jan 17 2008, 5:40PM
TEST_KEY = 'test key'
TEST_KEY = b'test key'
# Jan. 17, 2008 22:40:32.081230 UTC
TEST_TIME = 1200609642081230
TEST_USER_ID_1 = 123832983
TEST_USER_ID_2 = 938297432
TEST_ACTION_ID_1 = 'some_action'
TEST_ACTION_ID_2 = 'some_other_action'
TEST_EXTRA_INFO_1 = 'extra_info_1'
TEST_EXTRA_INFO_2 = 'more_extra_info'
TEST_ACTION_ID_1 = b'some_action'
TEST_ACTION_ID_2 = b'some_other_action'
TEST_EXTRA_INFO_1 = b'extra_info_1'
TEST_EXTRA_INFO_2 = b'more_extra_info'


__author__ = '[email protected] (Joe Gregorio)'


class Test_generate_token(unittest.TestCase):

def test_bad_positional(self):
# Need 2 positional arguments.
self.assertRaises(TypeError, xsrfutil.generate_token, None)
# At most 2 positional arguments.
self.assertRaises(TypeError, xsrfutil.generate_token, None, None, None)

def test_it(self):
digest = b'foobar'
curr_time = 1440449755.74
digester = mock.MagicMock()
digester.digest = mock.MagicMock(name='digest', return_value=digest)
with mock.patch('oauth2client.xsrfutil.hmac') as hmac:
hmac.new = mock.MagicMock(name='new', return_value=digester)
token = xsrfutil.generate_token(TEST_KEY,
TEST_USER_ID_1,
action_id=TEST_ACTION_ID_1,
when=TEST_TIME)
hmac.new.assert_called_once_with(TEST_KEY)
digester.digest.assert_called_once_with()

expected_digest_calls = [
mock.call.update(_to_bytes(str(TEST_USER_ID_1))),
mock.call.update(xsrfutil.DELIMITER),
mock.call.update(TEST_ACTION_ID_1),
mock.call.update(xsrfutil.DELIMITER),
mock.call.update(_to_bytes(str(TEST_TIME))),
]
self.assertEqual(digester.method_calls, expected_digest_calls)

expected_token_as_bytes = (digest + xsrfutil.DELIMITER +
_to_bytes(str(TEST_TIME)))
expected_token = base64.urlsafe_b64encode(
expected_token_as_bytes)
self.assertEqual(token, expected_token)

def test_with_system_time(self):
digest = b'foobar'
curr_time = 1440449755.74
digester = mock.MagicMock()
digester.digest = mock.MagicMock(name='digest', return_value=digest)
with mock.patch('oauth2client.xsrfutil.hmac') as hmac:
hmac.new = mock.MagicMock(name='new', return_value=digester)

with mock.patch('oauth2client.xsrfutil.time') as time:
time.time = mock.MagicMock(name='time', return_value=curr_time)
# when= is omitted
token = xsrfutil.generate_token(TEST_KEY,
TEST_USER_ID_1,
action_id=TEST_ACTION_ID_1)

hmac.new.assert_called_once_with(TEST_KEY)
time.time.assert_called_once_with()
digester.digest.assert_called_once_with()

expected_digest_calls = [
mock.call.update(_to_bytes(str(TEST_USER_ID_1))),
mock.call.update(xsrfutil.DELIMITER),
mock.call.update(TEST_ACTION_ID_1),
mock.call.update(xsrfutil.DELIMITER),
mock.call.update(_to_bytes(str(int(curr_time)))),
]
self.assertEqual(digester.method_calls, expected_digest_calls)

expected_token_as_bytes = (digest + xsrfutil.DELIMITER +
_to_bytes(str(int(curr_time))))
expected_token = base64.urlsafe_b64encode(
expected_token_as_bytes)
self.assertEqual(token, expected_token)


class Test_validate_token(unittest.TestCase):

def test_bad_positional(self):
# Need 3 positional arguments.
self.assertRaises(TypeError, xsrfutil.validate_token, None, None)
# At most 3 positional arguments.
self.assertRaises(TypeError, xsrfutil.validate_token,
None, None, None, None)

def test_no_token(self):
key = token = user_id = None
self.assertFalse(xsrfutil.validate_token(key, token, user_id))

def test_token_not_valid_base64(self):
key = user_id = None
token = b'a' # Bad padding
self.assertFalse(xsrfutil.validate_token(key, token, user_id))

def test_token_non_integer(self):
key = user_id = None
token = base64.b64encode(b'abc' + xsrfutil.DELIMITER + b'xyz')
self.assertFalse(xsrfutil.validate_token(key, token, user_id))

def test_token_too_old_implicit_current_time(self):
token_time = 123456789
curr_time = token_time + xsrfutil.DEFAULT_TIMEOUT_SECS + 1

key = user_id = None
token = base64.b64encode(_to_bytes(str(token_time)))
with mock.patch('oauth2client.xsrfutil.time') as time:
time.time = mock.MagicMock(name='time', return_value=curr_time)
self.assertFalse(xsrfutil.validate_token(key, token, user_id))
time.time.assert_called_once_with()

def test_token_too_old_explicit_current_time(self):
token_time = 123456789
curr_time = token_time + xsrfutil.DEFAULT_TIMEOUT_SECS + 1

key = user_id = None
token = base64.b64encode(_to_bytes(str(token_time)))
self.assertFalse(xsrfutil.validate_token(key, token, user_id,
current_time=curr_time))

def test_token_length_differs_from_generated(self):
token_time = 123456789
# Make sure it isn't too old.
curr_time = token_time + xsrfutil.DEFAULT_TIMEOUT_SECS - 1

key = object()
user_id = object()
action_id = object()
token = base64.b64encode(_to_bytes(str(token_time)))
generated_token = b'a'
# Make sure the token length comparison will fail.
self.assertNotEqual(len(token), len(generated_token))

with mock.patch('oauth2client.xsrfutil.generate_token',
return_value=generated_token) as gen_tok:
self.assertFalse(xsrfutil.validate_token(key, token, user_id,
current_time=curr_time,
action_id=action_id))
gen_tok.assert_called_once_with(key, user_id, action_id=action_id,
when=token_time)

def test_token_differs_from_generated_but_same_length(self):
token_time = 123456789
# Make sure it isn't too old.
curr_time = token_time + xsrfutil.DEFAULT_TIMEOUT_SECS - 1

key = object()
user_id = object()
action_id = object()
token = base64.b64encode(_to_bytes(str(token_time)))
# It is encoded as b'MTIzNDU2Nzg5', which has length 12.
# b'MMMMMMMMMMMM'
generated_token = b'M' * 12
# Make sure the token length comparison will succeed, but the token
# comparison will fail.
self.assertEqual(len(token), len(generated_token))
self.assertNotEqual(token, generated_token)

with mock.patch('oauth2client.xsrfutil.generate_token',
return_value=generated_token) as gen_tok:
self.assertFalse(xsrfutil.validate_token(key, token, user_id,
current_time=curr_time,
action_id=action_id))
gen_tok.assert_called_once_with(key, user_id, action_id=action_id,
when=token_time)

def test_success(self):
token_time = 123456789
# Make sure it isn't too old.
curr_time = token_time + xsrfutil.DEFAULT_TIMEOUT_SECS - 1

key = object()
user_id = object()
action_id = object()
token = base64.b64encode(_to_bytes(str(token_time)))
with mock.patch('oauth2client.xsrfutil.generate_token',
return_value=token) as gen_tok:
self.assertTrue(xsrfutil.validate_token(key, token, user_id,
current_time=curr_time,
action_id=action_id))
gen_tok.assert_called_once_with(key, user_id, action_id=action_id,
when=token_time)


class XsrfUtilTests(unittest.TestCase):
"""Test xsrfutil functions."""

Expand Down

0 comments on commit ad048c7

Please sign in to comment.