Skip to content

Commit

Permalink
Added automatic access_token refresh when they expire.
Browse files Browse the repository at this point in the history
  • Loading branch information
Steven committed Nov 21, 2024
1 parent a75df53 commit 8e90738
Show file tree
Hide file tree
Showing 3 changed files with 134 additions and 50 deletions.
16 changes: 1 addition & 15 deletions gazu/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,21 +88,7 @@ def log_out(client=raw.default_client):


def refresh_token(client=raw.default_client):
headers = {"User-Agent": "CGWire Gazu %s" % __version__}
if "refresh_token" in client.tokens:
headers["Authorization"] = "Bearer %s" % client.tokens["refresh_token"]

response = client.session.get(
raw.get_full_url("auth/refresh-token", client=client),
headers=headers,
)
raw.check_status(response, "auth/refresh-token")

tokens = response.json()

client.tokens["access_token"] = tokens["access_token"]

return tokens
return client.refresh_authentication_tokens()


def get_event_host(client=raw.default_client):
Expand Down
77 changes: 60 additions & 17 deletions gazu/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
import shutil
import urllib
import os
import jwt
from datetime import datetime

from .encoder import CustomJSONEncoder

Expand Down Expand Up @@ -47,6 +49,60 @@ def __init__(
self.automatic_refresh_token = automatic_refresh_token
self.callback_not_authenticated = callback_not_authenticated

@property
def access_token(self):
return self.tokens.get("access_token", None)

@access_token.setter
def access_token(self, token):
self.tokens["access_token"] = token

@property
def refresh_token(self):
return self.tokens.get("refresh_token", None)

@property
def access_token_has_expired(self):
""" Returns: Whether this client's access token needs to be refreshed. """
if not self.access_token:
# No access token is present, refresh only when able with a refresh token.
return True if self.refresh_token else False

# Decode the access token.
decoded_token = jwt.decode(jwt=self.access_token, options={'verify_signature': False})
expiration_datetime = datetime.fromtimestamp(decoded_token['exp'])

# NOTE - Due to shenanigans caused by possible timezone differences
# between server and client, the access token is considered
# stale when less than 24 hours remain to its expiration.

return bool((expiration_datetime - datetime.now()).days < 1)

def refresh_authentication_tokens(self):
""" Refresh access tokens for this client."""
response = self.session.get(
get_full_url("auth/refresh-token", client=self),
headers={"User-Agent": "CGWire Gazu " + __version__,
"Authorization": "Bearer " + self.refresh_token})
check_status(response, "auth/refresh-token")
tokens = response.json()

self.access_token = tokens["access_token"]

return tokens

def make_auth_header(self):
""" Returns: Headers required to authenticate. """
headers = {"User-Agent": "CGWire Gazu " + __version__}

if self.access_token:
if self.access_token_has_expired and self.:
self.refresh_authentication_tokens()

headers["Authorization"] = "Bearer " + self.access_token

return headers


def create_client(
host,
Expand Down Expand Up @@ -158,14 +214,7 @@ def set_tokens(new_tokens, client=default_client):


def make_auth_header(client=default_client):
"""
Returns:
Headers required to authenticate.
"""
headers = {"User-Agent": "CGWire Gazu %s" % __version__}
if "access_token" in client.tokens:
headers["Authorization"] = "Bearer %s" % client.tokens["access_token"]
return headers
return client.make_auth_header()


def url_path_join(*items):
Expand Down Expand Up @@ -365,19 +414,13 @@ def check_status(request, path, client=None):
)
elif status_code in [401, 422]:
try:
if client is not None and client.automatic_refresh_token:
from . import refresh_token

refresh_token(client=client)

if client and client.automatic_refresh_token:
client.refresh_authentication_tokens()
return status_code, True
else:
raise NotAuthenticatedException(path)
except NotAuthenticatedException:
if (
client is not None
and client.callback_not_authenticated is not None
):
if client and client.callback_not_authenticated:
retry = client.callback_not_authenticated(client, path)
if retry:
return status_code, True
Expand Down
91 changes: 73 additions & 18 deletions tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import json
import random
import string
import jwt

import unittest
import requests_mock
Expand Down Expand Up @@ -70,7 +71,8 @@ def test_set_tokens(self):
pass

def test_make_auth_header(self):
pass
self.assertEqual(first=raw.default_client.make_auth_header(),
second=raw.make_auth_header())

def test_url_path_join(self):
root = raw.get_host()
Expand Down Expand Up @@ -254,15 +256,72 @@ def test_version(self):
self.assertEqual(raw.get_api_version(), "0.2.0")

def test_make_auth_token(self):
tokens = {"access_token": "token_test"}
tokens = {"access_token": jwt.encode(
payload={'exp': (datetime.datetime.now() + datetime.timedelta(days=30)).timestamp()},
key='secretkey')}
raw.set_tokens(tokens)
self.assertEqual(
raw.make_auth_header(),
{
"Authorization": "Bearer token_test",
"User-Agent": "CGWire Gazu %s" % __version__,
},
)
self.assertEqual(raw.make_auth_header(),
{"Authorization": "Bearer " + tokens["access_token"],
"User-Agent": "CGWire Gazu " + __version__})

def test_access_token_has_expired(self):
client = raw.KitsuClient(host='http://localhost')
test_cases = {'fresh': (datetime.timedelta(days=30), False),
'expired': (datetime.timedelta(days=-1), True)}
for testcase in test_cases.items():
client.access_token = jwt.encode(
payload={'exp': (datetime.datetime.now() + testcase[-1][0]).timestamp()},
key='secretkey')

self.assertEqual(
first=client.access_token_has_expired, second=testcase[-1][-1],
msg=testcase[0] + ' Access Token correctly detected.')

client.access_token = None
self.assertEqual(first=client.access_token_has_expired, second=False,
msg='')

client.tokens["refresh_token"] = 'placeholder'
self.assertEqual(first=client.access_token_has_expired, second=True)

def test_automatic_token_refresh(self):
def encode(timestamp):
return jwt.encode(payload={'exp': timestamp}, key='secretkey')

expired_access_token = encode((datetime.datetime.now() + datetime.timedelta(days=-30)).timestamp())
fresh_access_token = encode((datetime.datetime.now() + datetime.timedelta(days=30)).timestamp())
new_access_token = encode((datetime.datetime.now() + datetime.timedelta(days=90)).timestamp())

client = raw.KitsuClient(host='http://localhost')
client.tokens["refresh_token"] = 'placeholder'

with requests_mock.Mocker() as mock:
mock_route(mock, "GET", "http://localhost/auth/refresh-token",
text={'access_token': new_access_token})
mock_route(mock, "GET", "http://localhost/test", text={})

client.automatic_refresh_token = False
client.access_token = expired_access_token
client.session.get(raw.get_full_url(path='test', client=client), headers=client.make_auth_header())
# Expired tokens are not refreshed if automatic_refresh is False.
self.assertEqual(client.access_token, expired_access_token)

client.access_token = fresh_access_token
client.session.get(raw.get_full_url(path='test', client=client), headers=client.make_auth_header())
# Fresh tokens are not changed.
self.assertEqual(client.access_token, fresh_access_token)

client.automatic_refresh_token = True
client.access_token = expired_access_token
client.session.get(raw.get_full_url(path='test', client=client), headers=client.make_auth_header())
# Expired tokens are updated if automatic_refresh is True
self.assertEqual(client.access_token, new_access_token)

client.automatic_refresh_token = True
client.access_token = fresh_access_token
client.session.get(raw.get_full_url(path='test', client=client), headers=client.make_auth_header())
# Fresh tokens are not changed.
self.assertEqual(client.access_token, fresh_access_token)

def test_upload(self):
with open("./tests/fixtures/v1.png", "rb") as test_file:
Expand Down Expand Up @@ -429,18 +488,14 @@ def test_init_send_email_otp(self):
self.assertEqual(success, {"success": True})

def test_init_refresh_token(self):
access_token = jwt.encode(payload={'exp': datetime.datetime.now()}, key='secretkey')

with requests_mock.mock() as mock:
raw.default_client.tokens["refresh_token"] = "refresh_token1"
mock_route(
mock,
"GET",
"auth/refresh-token",
text={"access_token": "tokentest1"},
)
mock_route(mock, "GET", "auth/refresh-token", text={"access_token": access_token})
gazu.refresh_token()
self.assertEqual(
raw.default_client.tokens["access_token"], "tokentest1"
)

self.assertEqual(raw.default_client.access_token, access_token)

def test_init_log_in_fail(self):
with requests_mock.mock() as mock:
Expand Down

0 comments on commit 8e90738

Please sign in to comment.