Skip to content
This repository has been archived by the owner on Oct 26, 2021. It is now read-only.

Commit

Permalink
Documentation updates and unittest for 'aud' claims with endpoints
Browse files Browse the repository at this point in the history
(refs #1)
  • Loading branch information
vimalloc committed Sep 19, 2017
1 parent 0bec806 commit 5dacfab
Show file tree
Hide file tree
Showing 2 changed files with 83 additions and 7 deletions.
4 changes: 3 additions & 1 deletion docs/options.rst
Original file line number Diff line number Diff line change
Expand Up @@ -26,5 +26,7 @@ You can change many options for how this extension works via
such as ``RS*`` or ``ES*``. PEM format expected.
``JWT_IDENTITY_CLAIM`` Which claim the `get_jwt_identity()` function will use to get
the identity out of a JWT. Defaults to ``'sub'``.
``JWT_DECODE_AUDIENCE`` The audience expected to be set in the JWT token when it is decoded.
``JWT_DECODE_AUDIENCE`` The audience you expect in a JWT when decoding it. Defaults
to ``None``. If this option differs from the 'aud' claim
in a JWT, the ``invalid_token_callback`` is invoked.
================================= =========================================
86 changes: 80 additions & 6 deletions tests/test_protected_endpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ def app(request):
JWTManager(app)

@app.route('/jwt', methods=['POST'])
def fresh_access_jwt():
def create_token_endpoint():
access_token = create_jwt('username')
return jsonify(jwt=access_token)

Expand All @@ -88,9 +88,9 @@ def protected():
@jwt_optional
def optional():
if get_jwt_identity():
return jsonify(foo='baz')
else:
return jsonify(foo='bar')
else:
return jsonify(foo='baz')

return app

Expand Down Expand Up @@ -139,7 +139,7 @@ def test_optional_without_jwt(app):
json_data = json.loads(response.get_data(as_text=True))

assert response.status_code == 200
assert json_data == {'foo': 'bar'}
assert json_data == {'foo': 'baz'}


def test_optional_with_jwt(app):
Expand All @@ -149,7 +149,7 @@ def test_optional_with_jwt(app):
json_data = json.loads(response.get_data(as_text=True))

assert response.status_code == 200
assert json_data == {'foo': 'baz'}
assert json_data == {'foo': 'bar'}


@pytest.mark.parametrize("header_name", ['Authorization', 'Foo'])
Expand Down Expand Up @@ -189,7 +189,7 @@ def test_with_bad_header(app, endpoint, header_type):
expected_results = (
(422, {'msg': "Bad Authorization header. Expected value '<JWT>'"}),
(422, {'msg': "Bad Authorization header. Expected value 'Foo <JWT>'"}),
(200, {'foo': "bar"}) # Returns this if unauthorized in jwt_optional test endpoint
(200, {'foo': "baz"}) # Returns this if unauthorized in jwt_optional test endpoint
)
assert (response.status_code, json_data) in expected_results

Expand Down Expand Up @@ -227,3 +227,77 @@ def test_expired_token(app, endpoint):

assert json_data == {'msg': 'Token has expired'}
assert response.status_code == 401


@pytest.mark.parametrize("endpoint", [
'/protected',
'/optional',
])
def test_valid_aud(app, endpoint):
app.config['JWT_DECODE_AUDIENCE'] = 'foo'
jwt_manager = app.extensions['flask-jwt-simple']

@jwt_manager.jwt_data_loader
def change_claims(identity):
now = datetime.datetime.utcnow()
identity_claim = app.config['JWT_IDENTITY_CLAIM']
return {
'exp': now + app.config['JWT_EXPIRES'],
'iat': now,
'nbf': now,
identity_claim: identity,
'aud': 'foo'
}

test_client = app.test_client()
jwt = _get_jwt(test_client)
response = _make_jwt_request(test_client, jwt, endpoint)
json_data = json.loads(response.get_data(as_text=True))

assert response.status_code == 200
assert json_data == {'foo': 'bar'}


@pytest.mark.parametrize("endpoint", [
'/protected',
'/optional',
])
def test_invalid_aud(app, endpoint):
app.config['JWT_DECODE_AUDIENCE'] = 'bar'
jwt_manager = app.extensions['flask-jwt-simple']

@jwt_manager.jwt_data_loader
def change_claims(identity):
now = datetime.datetime.utcnow()
identity_claim = app.config['JWT_IDENTITY_CLAIM']
return {
'exp': now + app.config['JWT_EXPIRES'],
'iat': now,
'nbf': now,
identity_claim: identity,
'aud': 'foo'
}

test_client = app.test_client()
jwt = _get_jwt(test_client)
response = _make_jwt_request(test_client, jwt, endpoint)
json_data = json.loads(response.get_data(as_text=True))

assert response.status_code == 422
assert json_data == {'msg': 'Invalid audience'}


@pytest.mark.parametrize("endpoint", [
'/protected',
'/optional',
])
def test_missing_aud(app, endpoint):
app.config['JWT_DECODE_AUDIENCE'] = 'bar'

test_client = app.test_client()
jwt = _get_jwt(test_client)
response = _make_jwt_request(test_client, jwt, endpoint)
json_data = json.loads(response.get_data(as_text=True))

assert response.status_code == 422
assert json_data == {'msg': 'Token is missing the "aud" claim'}

0 comments on commit 5dacfab

Please sign in to comment.