Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve control over identity_id extraction in experimental TokenStorage #1055

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
Fixed
~~~~~

.. rubric:: Experimental

- Fix the handling of Dependent Token and Refresh Token responses in
``TokenStorage`` and ``GlobusApp``'s internal ``ValidatingTokenStorage`` in
order to ensure that ``id_token`` is only parsed when appropriate. (:pr:`NUMBER`)
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,10 @@

import typing as t

from globus_sdk import AuthClient, Scope
from globus_sdk import AuthClient, OAuthRefreshTokenResponse, OAuthTokenResponse, Scope
from globus_sdk.experimental.tokenstorage import TokenStorage, TokenStorageData
from globus_sdk.scopes.consents import ConsentForest

from ..._types import UUIDLike
from .errors import (
IdentityMismatchError,
MissingIdentityError,
Expand Down Expand Up @@ -81,7 +80,7 @@ def __init__(

super().__init__(namespace=token_storage.namespace)

def _lookup_stored_identity_id(self) -> UUIDLike | None:
def _lookup_stored_identity_id(self) -> str | None:
"""
Attempts to extract an identity id from stored token data using the internal
token storage.
Expand Down Expand Up @@ -276,3 +275,15 @@ def _poll_and_cache_consents(self) -> ConsentForest | None:
# Cache the consent forest first.
self._cached_consent_forest = forest
return forest

def _extract_identity_id(self, token_response: OAuthTokenResponse) -> str | None:
"""
Override determination of the identity_id for a token response.

When handling a refresh token, use the stored identity ID.
Otherwise, call the inner token storage's method of lookup.
"""
if isinstance(token_response, OAuthRefreshTokenResponse):
return self.identity_id
else:
return self.token_storage._extract_identity_id(token_response)
33 changes: 26 additions & 7 deletions src/globus_sdk/experimental/tokenstorage/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import sys
import typing as t

from globus_sdk.services.auth import OAuthTokenResponse
from globus_sdk.services.auth import OAuthDependentTokenResponse, OAuthTokenResponse

from ... import GlobusSDKUsageError
from ..._types import UUIDLike
Expand Down Expand Up @@ -88,12 +88,7 @@ def store_token_response(self, token_response: OAuthTokenResponse) -> None:
"""
token_data_by_resource_server = {}

# get identity_id from id_token if available
if token_response.get("id_token"):
decoded_id_token = token_response.decode_id_token()
identity_id = decoded_id_token["sub"]
else:
identity_id = None
identity_id = self._extract_identity_id(token_response)

for resource_server, token_dict in token_response.by_resource_server.items():
token_data_by_resource_server[resource_server] = TokenStorageData(
Expand All @@ -107,6 +102,30 @@ def store_token_response(self, token_response: OAuthTokenResponse) -> None:
)
self.store_token_data_by_resource_server(token_data_by_resource_server)

def _extract_identity_id(self, token_response: OAuthTokenResponse) -> str | None:
"""
Get identity_id from id_token if available.

.. note::

This method is private, but is used in ValidatingTokenStorage to
override the extraction of ``identity_id`` information.

Generalizing customization of ``identity_id`` extraction will require
implementation of a user-facing mechanism for controlling calls to
``decode_id_token()``.
"""
# dependent token responses cannot contain an `id_token` field, as the
# top-level data is an array
if isinstance(token_response, OAuthDependentTokenResponse):
return None

if token_response.get("id_token"):
decoded_id_token = token_response.decode_id_token()
return decoded_id_token["sub"] # type: ignore[no-any-return]
else:
return None


class FileTokenStorage(TokenStorage, metaclass=abc.ABCMeta):
"""
Expand Down
103 changes: 103 additions & 0 deletions tests/functional/tokenstorage_v2/conftest.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,27 @@
import time
import uuid
from unittest import mock

import pytest

import globus_sdk
from globus_sdk._testing import RegisteredResponse
from globus_sdk.experimental.tokenstorage import TokenStorageData


@pytest.fixture
def id_token_sub():
return str(uuid.UUID(int=1))


@pytest.fixture
def cc_auth_client(no_retry_transport):
class CustomAuthClient(globus_sdk.ConfidentialAppAuthClient):
transport_class = no_retry_transport

return CustomAuthClient("dummy_id", "dummy_secret")


@pytest.fixture
def mock_token_data_by_resource_server():
expiration_time = int(time.time()) + 3600
Expand Down Expand Up @@ -57,3 +73,90 @@ def mock_response():
res.decode_id_token.return_value = {"sub": "user_id"}

return res


@pytest.fixture
def dependent_token_response(cc_auth_client):
expiration_time = int(time.time()) + 3600
RegisteredResponse(
service="auth",
path="/v2/oauth2/token",
method="POST",
json=[
{
"access_token": "access_token_1",
"expires_at_seconds": expiration_time,
"refresh_token": "refresh_token_1",
"resource_server": "resource_server_1",
"scope": "scope1",
"token_type": "Bearer",
},
{
"access_token": "access_token_2",
"expires_at_seconds": expiration_time,
"refresh_token": "refresh_token_2",
"resource_server": "resource_server_2",
"scope": "scope2 scope2:0 scope2:1",
"token_type": "Bearer",
},
],
).add()
return cc_auth_client.oauth2_get_dependent_tokens("dummy_tok")


@pytest.fixture
def authorization_code_response(cc_auth_client, id_token_sub):
cc_auth_client.oauth2_start_flow("https://example.com/redirect-uri", "dummy-scope")

expiration_time = int(time.time()) + 3600
RegisteredResponse(
service="auth",
path="/v2/oauth2/token",
method="POST",
json={
"access_token": "access_token_1",
"expires_at_seconds": expiration_time,
"refresh_token": "refresh_token_1",
"resource_server": "resource_server_1",
"scope": "scope1",
"token_type": "Bearer",
"id_token": "dummy_id_token",
"other_tokens": [
{
"access_token": "access_token_2",
"expires_at_seconds": expiration_time,
"refresh_token": "refresh_token_2",
"resource_server": "resource_server_2",
"scope": "scope2 scope2:0 scope2:1",
"token_type": "Bearer",
},
],
},
).add()

# because it's more difficult to mock the full decode_id_token() interaction in
# detail, directly mock the result of it to return the desired subject (identity_id)
# value
response = cc_auth_client.oauth2_exchange_code_for_tokens("dummy_code")
with mock.patch.object(response, "decode_id_token", lambda: {"sub": id_token_sub}):
yield response


@pytest.fixture
def refresh_token_response(cc_auth_client):
expiration_time = int(time.time()) + 3600
RegisteredResponse(
service="auth",
path="/v2/oauth2/token",
method="POST",
json={
"access_token": "access_token_1",
"expires_at_seconds": expiration_time,
"refresh_token": "refresh_token_1",
"resource_server": "resource_server_1",
"scope": "scope1",
"token_type": "Bearer",
"other_tokens": [],
},
).add()
return cc_auth_client.oauth2_refresh_token("dummy_token")
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
import pytest

from globus_sdk.experimental.tokenstorage import (
JSONTokenStorage,
MemoryTokenStorage,
SQLiteTokenStorage,
)


@pytest.fixture(params=["json", "sqlite", "memory"])
def storage(request, tmp_path):
if request.param == "json":
file = tmp_path / "mydata.json"
yield JSONTokenStorage(file)
elif request.param == "sqlite":
file = tmp_path / "mydata.db"
store = SQLiteTokenStorage(file)
yield store
store.close()
else:
yield MemoryTokenStorage()


def test_store_authorization_code_response(
storage, authorization_code_response, id_token_sub
):
storage.store_token_response(authorization_code_response)

tok_by_rs = authorization_code_response.by_resource_server

stored_data = storage.get_token_data_by_resource_server()

for resource_server in ["resource_server_1", "resource_server_2"]:
for fieldname in (
"resource_server",
"scope",
"access_token",
"refresh_token",
"expires_at_seconds",
"token_type",
):
assert tok_by_rs[resource_server][fieldname] == getattr(
stored_data[resource_server], fieldname
)
assert "identity_id" not in tok_by_rs[resource_server]
assert stored_data[resource_server].identity_id == id_token_sub


def test_store_dependent_token_response(storage, dependent_token_response):
"""
If a TokenStorage is asked to store dependent token data, it should work and
produce identity_id values of None (because there is no id_token to inspect)
"""
storage.store_token_response(dependent_token_response)

dep_tok_by_rs = dependent_token_response.by_resource_server

stored_data = storage.get_token_data_by_resource_server()

for resource_server in ["resource_server_1", "resource_server_2"]:
for fieldname in (
"resource_server",
"scope",
"access_token",
"refresh_token",
"expires_at_seconds",
"token_type",
):
assert dep_tok_by_rs[resource_server][fieldname] == getattr(
stored_data[resource_server], fieldname
)
assert stored_data[resource_server].identity_id is None
assert "identity_id" not in dep_tok_by_rs[resource_server]


def test_store_refresh_token_response(storage, refresh_token_response):
"""
If a TokenStorage is asked to store refresh token data, it should work and
produce identity_id values of None (because there is no id_token to inspect)
"""
storage.store_token_response(refresh_token_response)

refresh_tok_by_rs = refresh_token_response.by_resource_server

stored_data = storage.get_token_data_by_resource_server()

for fieldname in (
"resource_server",
"scope",
"access_token",
"refresh_token",
"expires_at_seconds",
"token_type",
):
assert refresh_tok_by_rs["resource_server_1"][fieldname] == getattr(
stored_data["resource_server_1"], fieldname
)
assert stored_data["resource_server_1"].identity_id is None
assert "identity_id" not in refresh_tok_by_rs["resource_server_1"]
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,13 @@
import pytest

import globus_sdk
from globus_sdk import MISSING, MissingType, OAuthTokenResponse, Scope
from globus_sdk import (
MISSING,
MissingType,
OAuthRefreshTokenResponse,
OAuthTokenResponse,
Scope,
)
from globus_sdk.experimental.globus_app import ValidatingTokenStorage
from globus_sdk.experimental.globus_app.errors import (
IdentityMismatchError,
Expand Down Expand Up @@ -117,6 +123,33 @@ def test_validating_token_storage_loads_identity_info_from_storage(
assert new_adapter.identity_id == identity_id


def test_validating_token_storage_stores_with_saved_identity_id_on_refresh_tokens(
make_token_response,
):
# Create an in memory storage adapter
storage = MemoryTokenStorage()
adapter = ValidatingTokenStorage(storage, {})

# Store an identifiable token response
identity_id = str(uuid.uuid4())
token_response = make_token_response(identity_id=identity_id)
adapter.store_token_response(token_response)

# now get and store a replacement token response, identified with a different user
# however, in this case make it a refresh token response
other_identity_id = str(uuid.uuid4())
refresh_token_response = make_token_response(
response_class=OAuthRefreshTokenResponse, identity_id=other_identity_id
)
adapter.store_token_response(refresh_token_response)

# read back the data, and verify that it contains tokens from the refresh, but the
# original identity_id
result = adapter.get_token_data("auth.globus.org")
assert result.access_token == refresh_token_response["access_token"]
assert result.identity_id == identity_id


def test_validating_token_storage_raises_error_when_no_token_data():
adapter = ValidatingTokenStorage(MemoryTokenStorage(), {})

Expand All @@ -129,6 +162,7 @@ def make_token_response(make_response):
def _make_token_response(
scopes: dict[str, str] | None = None,
identity_id: str | None | MissingType = MISSING,
response_class: type[OAuthTokenResponse] = OAuthTokenResponse,
):
"""
:param scopes: A dictionary of resource server to scope mappings to fill in
Expand Down Expand Up @@ -169,7 +203,7 @@ def _make_token_response(
# be a real JWT ID token.
data["id_token"] = _make_id_token()

response = make_response(response_class=OAuthTokenResponse, json_body=data)
response = make_response(response_class=response_class, json_body=data)

if identity_id is not None:
decoded_id_token = _decoded_id_token(identity_id)
Expand Down