Skip to content

Commit

Permalink
Merge pull request #48 from eth-cscs/unit-test-fix
Browse files Browse the repository at this point in the history
Improve unit tests
  • Loading branch information
rsarm authored Nov 26, 2024
2 parents d6bed1a + 8902722 commit 4915ec2
Show file tree
Hide file tree
Showing 4 changed files with 137 additions and 67 deletions.
3 changes: 2 additions & 1 deletion firecrestspawner/spawner.py
Original file line number Diff line number Diff line change
Expand Up @@ -348,7 +348,8 @@ async def submit_batch_script(self):
subvars.update(self.user_options)

job_env = self.get_env()
job_env.pop("PATH")
if "PATH" in job_env:
job_env.pop("PATH")

# FIXME: These two variables may have quotes in their values.
# We encoded as base64 since quotes are not allowed
Expand Down
5 changes: 3 additions & 2 deletions tests/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))

from firecrestspawner.spawner import (
SlurmSpawner,
format_template
AuthorizationCodeFlowAuth,
format_template,
SlurmSpawner
)
39 changes: 35 additions & 4 deletions tests/fc_handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,39 @@
from werkzeug.wrappers import Response


def keycloak_handler(request):
if "VALID_REFRESH_TOKEN" not in request.data.decode("utf-8"):
return Response(
json.dumps({"message": "Bad token; invalid JSON"}),
status=401,
content_type="application/json",
)


ret = {
'access_token': 'VALID_ACCESS_TOKEN',
'expires_in': 300,
'refresh_expires_in': 1800,
'refresh_token': 'VALID_REFRESH_TOKEN',
'token_type': 'Bearer',
'id_token': 'ID_TOKEN',
'not-before-policy': 0,
'session_state': 'fc347b39-8dd0-44f5-99d5-c3e5237eabac',
'scope': 'openid firecrest profile email'
}
extra_headers = None
status_code=200

return Response(
json.dumps(ret),
status=status_code,
headers=extra_headers,
content_type="application/json",
)


def whoami_handler(request):
if request.headers["Authorization"] != "Bearer VALID_TOKEN":
if request.headers["Authorization"] != "Bearer VALID_ACCESS_TOKEN":
return Response(
json.dumps({"message": "Bad token; invalid JSON"}),
status=401,
Expand Down Expand Up @@ -274,7 +305,7 @@ def tasks_handler(request):


def submit_upload_handler(request):
if request.headers["Authorization"] != "Bearer VALID_TOKEN":
if request.headers["Authorization"] != "Bearer VALID_ACCESS_TOKEN":
return Response(
json.dumps({"message": "Bad token; invalid JSON"}),
status=401,
Expand Down Expand Up @@ -331,7 +362,7 @@ def submit_upload_handler(request):


def systems_handler(request):
if request.headers["Authorization"] != "Bearer VALID_TOKEN":
if request.headers["Authorization"] != "Bearer VALID_ACCESS_TOKEN":
return Response(
json.dumps({"message": "Bad token; invalid JSON"}),
status=401,
Expand All @@ -358,7 +389,7 @@ def systems_handler(request):


def sacct_handler(request):
if request.headers["Authorization"] != "Bearer VALID_TOKEN":
if request.headers["Authorization"] != "Bearer VALID_ACCESS_TOKEN":
return Response(
json.dumps({"message": "Bad token; invalid JSON"}),
status=401,
Expand Down
157 changes: 97 additions & 60 deletions tests/test_spawner.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,19 @@
import getpass
import pytest
from werkzeug.wrappers import Response
from context import SlurmSpawner, format_template
from context import (
AuthorizationCodeFlowAuth,
format_template,
SlurmSpawner
)
from fc_handlers import (
keycloak_handler,
tasks_handler,
submit_upload_handler,
systems_handler,
sacct_handler,
cancel_handler,
whoami_handler
whoami_handler,
)
from jupyterhub.tests.conftest import db
from jupyterhub.user import User
Expand All @@ -24,50 +29,26 @@
testport = random_port()


class DummyOAuthenticator(GenericOAuthenticator):
async def refresh_user(self, user, handler=None):
auth_state = {"access_token": "VALID_TOKEN"}
return {"auth_state": auth_state}


class FirecrestAccessTokenAuth:
"""Utility class to provide an object with the
`get_access_token()` attribute needed by PyFirecREST's
authenticator"""

_access_token: str = None

def __init__(self, access_token):
self._access_token = access_token

def get_access_token(self):
return self._access_token


async def get_firecrest_client(spawner):
auth_state_refreshed = await spawner.user.authenticator.refresh_user(spawner.user) # noqa E501
access_token = auth_state_refreshed['auth_state']['access_token']

client = firecrest.AsyncFirecrest(
firecrest_url=spawner.firecrest_url,
authorization=FirecrestAccessTokenAuth(access_token)
)

return client


# FIXME: Setup the auth state in the unit tests
# Since the auth state is not setup for the unit tests,
# the spawner's get_firecrest_client method will fail
# when trying to get a key from a none `auth_state`
SlurmSpawner.get_firecrest_client = get_firecrest_client
async def get_auth_state():
"""Function to monkey patch `user.authenticator.get_auth_state`
to simulate a hub where the user is already logged in
"""
auth_state = {
"access_token": "VALID_ACCESS_TOKEN",
"refresh_token": "VALID_REFRESH_TOKEN"
}
return auth_state


def new_spawner(db, spawner_class=SlurmSpawner, **kwargs):
user = db.query(orm.User).first()
hub = Hub()
user = User(user, {"authenticator": DummyOAuthenticator()})

user = db.query(orm.User).first()
user = User(user, {"authenticator": GenericOAuthenticator()})
# Monkey patch the `get_auth_state` function to return an
# auth state containing accesss tokens without having login
user.get_auth_state = get_auth_state
user.authenticator.client_id = "client-id"
user.authenticator.client_secret = "client-secret"
_spawner = user._new_spawner(
"",
spawner_class=spawner_class,
Expand All @@ -77,7 +58,7 @@ def new_spawner(db, spawner_class=SlurmSpawner, **kwargs):
req_host="cluster1",
port=testport,
node_name_template="{}.cluster1.ch",
enable_aux_fc_client=False
enable_aux_fc_client=False,
)
return _spawner

Expand All @@ -104,9 +85,12 @@ def fc_server(httpserver):
re.compile("^/compute/jobs.*"), method="DELETE"
).respond_with_handler(cancel_handler)

httpserver.expect_request("/utilities/whoami",
method="GET").respond_with_handler(whoami_handler)

httpserver.expect_request(
"/utilities/whoami", method="GET"
).respond_with_handler(whoami_handler)
"/auth/realms/kcrealm/protocol/openid-connect/token", method="POST"
).respond_with_handler(keycloak_handler)

return httpserver

Expand All @@ -121,9 +105,20 @@ def test_format_template():
assert templated == "value_1 and value_2"


def test_get_access_token():
auth = FirecrestAccessTokenAuth("access_token")
assert auth.get_access_token() == "access_token"
def test_get_access_token(db, fc_server):
spawner = new_spawner(db=db)
spawner.firecrest_url = fc_server.url_for("/")
spawner.user.authenticator.token_url = "".join([
fc_server.url_for("/") ,
"auth/realms/kcrealm/protocol/openid-connect/token"
])
auth = AuthorizationCodeFlowAuth(
client_id=spawner.user.authenticator.client_id,
client_secret=spawner.user.authenticator.client_secret,
refresh_token="VALID_REFRESH_TOKEN",
token_url=spawner.user.authenticator.token_url
)
assert auth.get_access_token() == "VALID_ACCESS_TOKEN"


@pytest.mark.asyncio
Expand Down Expand Up @@ -249,19 +244,15 @@ async def test_get_batch_script_subvars(db):
async def test_get_firecrest_client(db, fc_server):
spawner = new_spawner(db=db)
spawner.firecrest_url = fc_server.url_for("/")
spawner.user.authenticator.token_url = "".join([
fc_server.url_for("/") ,
"auth/realms/kcrealm/protocol/openid-connect/token"
])
client = await spawner.get_firecrest_client()
systems = await client.all_systems()
ref_systems = [
{
"description": "System ready",
"status": "available",
"system": "cluster1"
},
{
"description": "System ready",
"status": "available",
"system": "cluster2"
},
{"description": "System ready", "status": "available", "system": "cluster1"},
{"description": "System ready", "status": "available", "system": "cluster2"},
]
assert systems == ref_systems

Expand All @@ -270,6 +261,10 @@ async def test_get_firecrest_client(db, fc_server):
async def test_query_job_status_completed(db, fc_server):
spawner = new_spawner(db=db)
spawner.firecrest_url = fc_server.url_for("/")
spawner.user.authenticator.token_url = "".join([
fc_server.url_for("/") ,
"auth/realms/kcrealm/protocol/openid-connect/token"
])
# force setting `host` and `job_id` since they
# are set only set when calling `spawner.start()`
spawner.host = "cluster1"
Expand All @@ -284,6 +279,10 @@ async def test_query_job_status_completed(db, fc_server):
async def test_query_job_status_running(db, fc_server):
spawner = new_spawner(db=db)
spawner.firecrest_url = fc_server.url_for("/")
spawner.user.authenticator.token_url = "".join([
fc_server.url_for("/") ,
"auth/realms/kcrealm/protocol/openid-connect/token"
])
# force setting `host` and `job_id` since they
# are set only set when calling `spawner.start()`
spawner.host = "cluster1"
Expand All @@ -298,6 +297,10 @@ async def test_query_job_status_running(db, fc_server):
async def test_query_job_status_pending(db, fc_server):
spawner = new_spawner(db=db)
spawner.firecrest_url = fc_server.url_for("/")
spawner.user.authenticator.token_url = "".join([
fc_server.url_for("/") ,
"auth/realms/kcrealm/protocol/openid-connect/token"
])
# force setting `host` and `job_id` since they
# are set only set when calling `spawner.start()`
spawner.host = "cluster1"
Expand All @@ -310,9 +313,13 @@ async def test_query_job_status_pending(db, fc_server):

@pytest.mark.asyncio
async def _test_query_job_status_fail(db, fc_server):
# TODO: Test the case where the job failed afte start
# TODO: Test the case where the job failed after start
spawner = new_spawner(db=db)
spawner.firecrest_url = fc_server.url_for("/")
spawner.user.authenticator.token_url = "".join([
fc_server.url_for("/") ,
"auth/realms/kcrealm/protocol/openid-connect/token"
])
# force setting `host` and `job_id` since they
# are set only set when calling `spawner.start()`
spawner.host = "cluster1"
Expand All @@ -327,6 +334,10 @@ async def _test_query_job_status_fail(db, fc_server):
async def test_cancel_batch_job(db, fc_server):
spawner = new_spawner(db=db)
spawner.firecrest_url = fc_server.url_for("/")
spawner.user.authenticator.token_url = "".join([
fc_server.url_for("/") ,
"auth/realms/kcrealm/protocol/openid-connect/token"
])
# force setting `host` and `job_id` since they
# are set only set when calling `spawner.start()`
spawner.host = "cluster1"
Expand Down Expand Up @@ -368,6 +379,10 @@ def test_load_state_nostate(db):
async def test_start_job_fail(db, fc_server):
spawner = new_spawner(db=db)
spawner.firecrest_url = fc_server.url_for("/")
spawner.user.authenticator.token_url = "".join([
fc_server.url_for("/") ,
"auth/realms/kcrealm/protocol/openid-connect/token"
])
spawner.set_trait("req_partition", "job_failed")
with pytest.raises(RuntimeError) as excinfo:
await spawner.start()
Expand All @@ -384,6 +399,10 @@ async def test_start_job_fail(db, fc_server):
async def test_start_no_jobid(db, fc_server):
spawner = new_spawner(db=db)
spawner.firecrest_url = fc_server.url_for("/")
spawner.user.authenticator.token_url = "".join([
fc_server.url_for("/") ,
"auth/realms/kcrealm/protocol/openid-connect/token"
])
spawner.set_trait("req_partition", "no_jobid")
with pytest.raises(RuntimeError) as excinfo:
await spawner.start()
Expand All @@ -398,13 +417,19 @@ async def test_start_no_jobid(db, fc_server):
async def test_start(db, fc_server):
spawner = new_spawner(db=db)
spawner.firecrest_url = fc_server.url_for("/")
spawner.user.authenticator.token_url = "".join([
fc_server.url_for("/") ,
"auth/realms/kcrealm/protocol/openid-connect/token"
])
ip, port = await spawner.start()
assert spawner.job_id == "353"
assert port == testport
assert ip == "nid02357.cluster1.ch"
assert spawner.job_status == "RUNNING nid02357"
env = spawner.get_env()
assert env["JUPYTERHUB_SERVICE_URL"] == f"http://nid02357.cluster1.ch:{testport}/" # noqa 505
assert (
env["JUPYTERHUB_SERVICE_URL"] == f"http://nid02357.cluster1.ch:{testport}/"
) # noqa 505

# Since the job 353 has status RUNNING, to stop the job,
# we have to trick the spawner into using a the job 352 that
Expand All @@ -418,6 +443,10 @@ async def test_start(db, fc_server):
async def test_submit_batch_script(db, fc_server):
spawner = new_spawner(db=db)
spawner.firecrest_url = fc_server.url_for("/")
spawner.user.authenticator.token_url = "".join([
fc_server.url_for("/") ,
"auth/realms/kcrealm/protocol/openid-connect/token"
])
await spawner.submit_batch_script()
assert spawner.job_id == "353"

Expand All @@ -426,6 +455,10 @@ async def test_submit_batch_script(db, fc_server):
async def test_state_gethost(db, fc_server):
spawner = new_spawner(db=db)
spawner.firecrest_url = fc_server.url_for("/")
spawner.user.authenticator.token_url = "".join([
fc_server.url_for("/") ,
"auth/realms/kcrealm/protocol/openid-connect/token"
])
# force setting `host` and `job_id` since they
# are set only set when calling `spawner.start()`
spawner.host = "cluster1"
Expand All @@ -438,6 +471,10 @@ async def test_state_gethost(db, fc_server):
async def test_stop_fail(db, fc_server):
spawner = new_spawner(db=db)
spawner.firecrest_url = fc_server.url_for("/")
spawner.user.authenticator.token_url = "".join([
fc_server.url_for("/") ,
"auth/realms/kcrealm/protocol/openid-connect/token"
])
spawner.host = "cluster1"
spawner.job_id = "353" # returns 'RUNNING'
# the spawner retries many poll calls, but
Expand Down

0 comments on commit 4915ec2

Please sign in to comment.