diff --git a/notifications_python_client/authentication.py b/notifications_python_client/authentication.py index 15bc5f9..0fb63e0 100644 --- a/notifications_python_client/authentication.py +++ b/notifications_python_client/authentication.py @@ -4,7 +4,11 @@ import jwt from notifications_python_client.errors import ( - TokenDecodeError, TokenExpiredError) + TokenDecodeError, + TokenExpiredError, + TokenIssuerError, + TokenIssuedAtError +) __algorithm__ = "HS256" __type__ = "JWT" @@ -52,14 +56,18 @@ def get_token_issuer(token): Does not check validity of the token :param token: signed JWT token :return issuer: iss field of the JWT token - :raises AssertionError: is iss field not present + :raises TokenIssuerError: if iss field not present + :raises TokenDecodeError: if token does not conform to JWT spec """ try: unverified = decode_token(token) - assert 'iss' in unverified - return unverified['iss'] + + if 'iss' not in unverified: + raise TokenIssuerError + + return unverified.get('iss') except jwt.DecodeError: - raise TokenDecodeError("Invalid token") + raise TokenDecodeError def decode_jwt_token(token, secret): @@ -72,7 +80,8 @@ def decode_jwt_token(token, secret): :param token: jwt token :param secret: client specific secret :return boolean: True if valid token, False otherwise - :raises AssertionError: If any required fields are not present + :raises TokenIssuerError: if iss field not present + :raises TokenIssuedAtError: if iat field not present :raises jwt.DecodeError: If signature validation fails """ try: @@ -84,10 +93,11 @@ def decode_jwt_token(token, secret): algorithms=[__algorithm__], leeway=__bound__ ) - # token has all the required fields - assert 'iss' in decoded_token, 'Missing iss field in token' - assert 'iat' in decoded_token, 'Missing iat field in token' + if 'iss' not in decoded_token: + raise TokenIssuerError + if 'iat' not in decoded_token: + raise TokenIssuedAtError # check iat time is within bounds now = epoch_seconds() @@ -100,7 +110,7 @@ def decode_jwt_token(token, secret): except jwt.InvalidIssuedAtError: raise TokenExpiredError("Token has invalid iat field", decode_token(token)) except jwt.DecodeError: - raise TokenDecodeError("Invalid token") + raise TokenDecodeError def decode_token(token): diff --git a/notifications_python_client/base.py b/notifications_python_client/base.py index 2b7bbc4..1ae5a16 100644 --- a/notifications_python_client/base.py +++ b/notifications_python_client/base.py @@ -1,5 +1,5 @@ from __future__ import absolute_import -from monotonic import monotonic +from time import monotonic from notifications_python_client.errors import HTTPError, InvalidResponse from notifications_python_client.authentication import create_jwt_token from notifications_python_client.version import __version__ diff --git a/notifications_python_client/errors.py b/notifications_python_client/errors.py index da1c3b3..1033fb4 100644 --- a/notifications_python_client/errors.py +++ b/notifications_python_client/errors.py @@ -10,13 +10,24 @@ def __init__(self, message, token=None): class TokenDecodeError(TokenError): - pass + def __init__(self, message=None): + super().__init__(message or 'Invalid token: signature') class TokenExpiredError(TokenError): pass +class TokenIssuerError(TokenDecodeError): + def __init__(self): + super().__init__('Invalid token: iss field not provided') + + +class TokenIssuedAtError(TokenDecodeError): + def __init__(self): + super().__init__('Invalid token: iat field not provided') + + class APIError(Exception): def __init__(self, response=None, message=None): self.response = response diff --git a/notifications_python_client/version.py b/notifications_python_client/version.py index 96e3ce8..afced14 100644 --- a/notifications_python_client/version.py +++ b/notifications_python_client/version.py @@ -1 +1 @@ -__version__ = '1.4.0' +__version__ = '2.0.0' diff --git a/requirements.txt b/requirements.txt index 066154b..e672575 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,3 @@ -monotonic==0.3 requests==2.7.0 PyJWT==1.4.0 -docopt==0.6.2 \ No newline at end of file +docopt==0.6.2 diff --git a/tests/notifications_python_client/test_authentication.py b/tests/notifications_python_client/test_authentication.py index aa96ab8..70bf910 100644 --- a/tests/notifications_python_client/test_authentication.py +++ b/tests/notifications_python_client/test_authentication.py @@ -9,7 +9,7 @@ from notifications_python_client.authentication import ( create_jwt_token, decode_jwt_token, get_token_issuer) from notifications_python_client.errors import ( - TokenExpiredError, TokenDecodeError) + TokenExpiredError, TokenDecodeError, TokenIssuerError, TokenIssuedAtError) # helper method to directly decode token @@ -65,7 +65,7 @@ def test_should_reject_token_with_invalid_key(): with pytest.raises(TokenDecodeError) as e: decode_jwt_token(token=token, secret="wrong-key") - assert e.value.message == "Invalid token" + assert e.value.message == "Invalid token: signature" def test_should_reject_token_that_is_too_old(): @@ -112,14 +112,43 @@ def test_should_handle_random_inputs(): with pytest.raises(TokenDecodeError) as e: decode_jwt_token("token", "key") - assert e.value.message == "Invalid token" + assert e.value.message == "Invalid token: signature" def test_should_handle_invalid_token_for_issuer_lookup(): with pytest.raises(TokenDecodeError) as e: get_token_issuer("token") - assert e.value.message == "Invalid token" + assert e.value.message == "Invalid token: signature" + + +def test_get_token_issuer_should_handle_invalid_token_with_no_iss(): + token = create_jwt_token("key", "client_id") + token = jwt.encode( + payload={'iat': 1234}, + key='1234', + headers={'typ': 'JWT', 'alg': 'HS256'} + ).decode() + + with pytest.raises(TokenIssuerError): + get_token_issuer(token) + + +@pytest.mark.parametrize('missing_field,exc_class', [ + ('iss', TokenIssuerError), + ('iat', TokenIssuedAtError), +]) +def test_decode_should_handle_invalid_token_with_missing_field(missing_field, exc_class): + payload = {'iss': '1234', 'iat': '1234'} + payload.pop(missing_field) + token = jwt.encode( + payload=payload, + key='bar', + headers={'typ': 'JWT', 'alg': 'HS256'} + ) + + with pytest.raises(exc_class): + decode_jwt_token(token, 'bar') def test_should_return_issuer_from_token():