Skip to content

Commit

Permalink
Merge pull request #579 from Wykiki/feat/refresh-user
Browse files Browse the repository at this point in the history
feat: Implement `OAuthenticator.refresh_user`
  • Loading branch information
minrk authored Nov 27, 2024
2 parents 9caf311 + f2a8566 commit 3669734
Show file tree
Hide file tree
Showing 4 changed files with 268 additions and 54 deletions.
179 changes: 139 additions & 40 deletions oauthenticator/oauth2.py
Original file line number Diff line number Diff line change
Expand Up @@ -270,7 +270,7 @@ class OAuthenticator(Authenticator):
- Override the constant `user_auth_state_key`
- Override various config's default values, such as
`authorize_url`, `token_url`, `userdata_url`, and `login_service`.
- Override various methods called by the `authenticate` method, which
- Override various methods called by :meth:`authenticate`, which
subclasses should not override.
- Override handler classes such as `login_handler`, `callback_handler`, and
`logout_handler`.
Expand Down Expand Up @@ -919,7 +919,8 @@ def get_handlers(self, app):
def build_userdata_request_headers(self, access_token, token_type):
"""
Builds and returns the headers to be used in the userdata request.
Called by the :meth:`oauthenticator.OAuthenticator.token_to_user`
Called by :meth:`.token_to_user`.
"""

# token_type is case-insensitive, but the headers are case-sensitive
Expand All @@ -937,7 +938,8 @@ def build_userdata_request_headers(self, access_token, token_type):
def build_token_info_request_headers(self):
"""
Builds and returns the headers to be used in the access token request.
Called by the :meth:`oauthenticator.OAuthenticator.get_token_info`.
Called by :meth:`.get_token_info`.
The Content-Type header is specified by the OAuth 2.0 RFC in
https://www.rfc-editor.org/rfc/rfc6749#section-4.1.3. utf-8 is also
Expand Down Expand Up @@ -971,7 +973,7 @@ def user_info_to_username(self, user_info):
Returns:
user_info["self.username_claim"] or raises an error if such value isn't found.
Called by the :meth:`oauthenticator.OAuthenticator.authenticate`
Called by :meth:`.authenticate` and :meth:`.refresh_user`.
"""

if callable(self.username_claim):
Expand All @@ -987,25 +989,12 @@ def user_info_to_username(self, user_info):

return username

# Originally a GoogleOAuthenticator only feature
async def get_prev_refresh_token(self, handler, username):
"""
Retrieves the `refresh_token` from previous encrypted auth state.
Called by the :meth:`oauthenticator.OAuthenticator.authenticate`
"""
user = handler.find_user(username)
if not user:
return None
auth_state = await user.get_auth_state()
if not auth_state:
return None
return auth_state.get("refresh_token", None)

def build_access_tokens_request_params(self, handler, data=None):
"""
Builds the parameters that should be passed to the URL request
that exchanges the OAuth code for the Access Token.
Called by the :meth:`oauthenticator.OAuthenticator.authenticate`.
Called by :meth:`.authenticate`.
"""
code = handler.get_argument("code")
if not code:
Expand Down Expand Up @@ -1042,14 +1031,36 @@ def build_access_tokens_request_params(self, handler, data=None):

return params

def build_refresh_token_request_params(self, refresh_token):
"""
Builds the parameters that should be passed to the URL request
to renew the Access Token based on the Refresh Token
Called by :meth:`.refresh_user`.
"""
params = {
"grant_type": "refresh_token",
"refresh_token": refresh_token,
}

# the client_id and client_secret should not be included in the access token request params
# when basic authentication is used
# ref: https://www.rfc-editor.org/rfc/rfc6749#section-2.3.1
if not self.basic_auth:
params["client_id"] = self.client_id
params["client_secret"] = self.client_secret

return params

async def get_token_info(self, handler, params):
"""
Makes a "POST" request to `self.token_url`, with the parameters received as argument.
Returns:
the JSON response to the `token_url` the request.
the JSON response to the `token_url` the request as described in
https://www.rfc-editor.org/rfc/rfc6749#section-5.1
Called by the :meth:`oauthenticator.OAuthenticator.authenticate`
Called by :meth:`.authenticate` and :meth:`.refresh_user`.
"""

token_info = await self.httpfetch(
Expand All @@ -1073,9 +1084,9 @@ async def get_token_info(self, handler, params):
async def token_to_user(self, token_info):
"""
Determines who the logged-in user by sending a "GET" request to
:data:`oauthenticator.OAuthenticator.userdata_url` using the `access_token`.
:attr:`.userdata_url` using the `access_token`.
If :data:`oauthenticator.OAuthenticator.userdata_from_id_token` is set then
If :attr:`.userdata_from_id_token` is set then
extracts the corresponding info from an `id_token` instead.
Args:
Expand All @@ -1084,7 +1095,7 @@ async def token_to_user(self, token_info):
Returns:
the JSON response to the `userdata_url` request.
Called by the :meth:`oauthenticator.OAuthenticator.authenticate`
Called by :meth:`.authenticate` and :meth:`.refresh_user`.
"""
if self.userdata_from_id_token:
# Use id token instead of exchanging access token with userinfo endpoint.
Expand Down Expand Up @@ -1134,7 +1145,7 @@ async def token_to_user(self, token_info):

def build_auth_state_dict(self, token_info, user_info):
"""
Builds the `auth_state` dict that will be returned by a succesfull `authenticate` method call.
Builds the `auth_state` dict that will be returned by a successful `authenticate` method call.
May be async (requires oauthenticator >= 17.0).
Args:
Expand All @@ -1150,13 +1161,13 @@ def build_auth_state_dict(self, token_info, user_info):
- "token_response": the full token_info response
- self.user_auth_state_key: the full user_info response
Called by the :meth:`oauthenticator.OAuthenticator.authenticate`
Called by :meth:`.authenticate` and :meth:`.refresh_user`.
.. versionchanged:: 17.0
This method may be async.
"""

# We know for sure the `access_token` key exists, oterwise we would have errored out already
# We know for sure the `access_token` key exists, otherwise we would have errored out already
access_token = token_info["access_token"]

refresh_token = token_info.get("refresh_token", None)
Expand Down Expand Up @@ -1221,9 +1232,9 @@ async def update_auth_model(self, auth_model):
- `admin`: the admin status (True/False/None), where None means it
should be unchanged.
- `auth_state`: the auth state dictionary,
returned by :meth:`oauthenticator.OAuthenticator.build_auth_state_dict`
returned by :meth:`.build_auth_state_dict`
Called by the :meth:`oauthenticator.OAuthenticator.authenticate`
Called by :meth:`.authenticate` and :meth:`.refresh_user`.
"""
# NOTE: this base implementation should _not_ be updated to do anything
# subclasses should have full control without calling super()
Expand Down Expand Up @@ -1276,24 +1287,112 @@ async def authenticate(self, handler, data=None, **kwargs):
"""
# build the parameters to be used in the request exchanging the oauth code for the access token
access_token_params = self.build_access_tokens_request_params(handler, data)
# exchange the oauth code for an access token and get the JSON with info about it
token_info = await self.get_token_info(handler, access_token_params)
# call the oauth endpoints
return await self._token_to_auth_model(token_info)

async def refresh_user(self, user, handler=None, **kwargs):
"""
Refresh user authentication
If auth_state is enabled, constructs a fresh user model
(the same as `authenticate`)
using the access_token in auth_state.
If requests with the access token fail
(e.g. because the token has expired)
and a refresh token is found, attempts to exchange
the refresh token for a new access token to store in auth_state.
If the access token still fails after refresh,
return False to require the user to login via oauth again.
Set `Authenticator.auth_refresh_age = 0` to disable.
Returns
-------
True:
If auth info is up-to-date and needs no changes
(always if `enable_auth_state` is False)
False:
If the user needs to login again
(e.g. tokens in `auth_state` unavailable or expired)
auth_model: dict
The same dict as `authenticate`, updating any fields that should change.
Can include things like group membership,
but in OAuthenticator this mainly refreshes
the token fields in `auth_state`.
"""
if not self.enable_auth_state:
# auth state not enabled, can't refresh
return True
auth_state = await user.get_auth_state()
if not auth_state:
self.log.info(
f"No auth_state found for user {user.name} refresh, need full authentication",
)
return False

token_info = auth_state.get("token_response")
auth_model = None
try:
auth_model = await self._token_to_auth_model(token_info)
except HTTPClientError as e:
# assume any client error means an expired token
# most likely 401 or 403 for well-behaved providers
if 400 <= e.code < 500:
self.log.info(
f"Error refreshing auth with current access_token for {user.name}: {e}. Will try to refresh, if possible."
)
else:
raise
refresh_token = auth_state.get("refresh_token", None)
if refresh_token and not auth_model:
self.log.info(f"Refreshing oauth access token for {user.name}")
# access_token expired, try refreshing with refresh_token
refresh_token_params = self.build_refresh_token_request_params(
refresh_token
)
try:
token_info = await self.get_token_info(handler, refresh_token_params)
except Exception as e:
self.log.info(
f"Error using refresh_token for {user.name}: {e}. Requiring fresh login."
)
return False
else:
self.log.debug(
f"Received fresh access_token for {user.name} via refresh_token"
)
# refresh_token may not be returned when refreshing a token
# in which case, keep the current one
if not token_info.get("refresh_token"):
token_info["refresh_token"] = refresh_token
try:
auth_model = await self._token_to_auth_model(token_info)
except Exception as e:
# this means we were issued a fresh access token,
# but it didn't work! Fail harder?
self.log.error(
f"Error refreshing auth with fresh access_token for {user.name}: {e}. Requiring fresh login."
)
return False

# return False if auth_model is None for "needs new login"
return auth_model or False

async def _token_to_auth_model(self, token_info):
"""
Turn a token into the user's `auth_model` to be returned by :meth:`.authenticate`.
Common logic shared by :meth:`.authenticate` and :meth:`.refresh_user`.
"""

# use the access_token to get userdata info
user_info = await self.token_to_user(token_info)
# extract the username out of the user_info dict and normalize it
username = self.user_info_to_username(user_info)
username = self.normalize_username(username)

# check if there any refresh_token in the token_info dict
refresh_token = token_info.get("refresh_token", None)
if self.enable_auth_state and not refresh_token:
self.log.debug(
"Refresh token was empty, will try to pull refresh_token from previous auth_state"
)
refresh_token = await self.get_prev_refresh_token(handler, username)
if refresh_token:
token_info["refresh_token"] = refresh_token

auth_state = self.build_auth_state_dict(token_info, user_info)
if isawaitable(auth_state):
auth_state = await auth_state
Expand Down
52 changes: 41 additions & 11 deletions oauthenticator/tests/mocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,7 @@ def setup_oauth_mock(
user_path=None,
token_type='Bearer',
token_request_style='post',
enable_refresh_tokens=False,
scope="",
):
"""setup the mock client for OAuth
Expand Down Expand Up @@ -134,6 +135,8 @@ def setup_oauth_mock(

client.oauth_codes = oauth_codes = {}
client.access_tokens = access_tokens = {}
client.refresh_tokens = refresh_tokens = {}
client.enable_refresh_tokens = enable_refresh_tokens

def access_token(request):
"""Handler for access token endpoint
Expand All @@ -146,26 +149,53 @@ def access_token(request):
if not query:
query = request.body.decode('utf8')
query = parse_qs(query)
if 'code' not in query:
grant_type = query.get("grant_type", [""])[0]
if grant_type == 'authorization_code':
if 'code' not in query:
return HTTPResponse(
request=request,
code=400,
reason=f"No code in access token request: url={request.url}, body={request.body}",
)
code = query['code'][0]
if code not in oauth_codes:
return HTTPResponse(
request=request, code=403, reason=f"No such code: {code}"
)
user = oauth_codes.pop(code)
elif grant_type == 'refresh_token':
if 'refresh_token' not in query:
return HTTPResponse(
request=request,
code=400,
reason=f"No refresh_token in access token request: url={request.url}, body={request.body}",
)
refresh_token = query['refresh_token'][0]
if refresh_token not in refresh_token:
return HTTPResponse(
request=request,
code=403,
reason=f"No such refresh_toekn: {refresh_token}",
)
user = refresh_tokens[refresh_token]
else:
return HTTPResponse(
request=request,
code=400,
reason=f"No code in access token request: url={request.url}, body={request.body}",
)
code = query['code'][0]
if code not in oauth_codes:
return HTTPResponse(
request=request, code=403, reason=f"No such code: {code}"
reason=f"Invalid grant_type={grant_type}: url={request.url}, body={request.body}",
)

# consume code, allocate token
token = uuid.uuid4().hex
user = oauth_codes.pop(code)
access_tokens[token] = user
access_token = uuid.uuid4().hex
access_tokens[access_token] = user
model = {
'access_token': token,
'access_token': access_token,
'token_type': token_type,
}
if client.enable_refresh_tokens:
refresh_token = uuid.uuid4().hex
refresh_tokens[refresh_token] = user
model['refresh_token'] = refresh_token
if scope:
model['scope'] = scope
if 'id_token' in user:
Expand Down
Loading

0 comments on commit 3669734

Please sign in to comment.