Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add missing features, polish to SynchronousComputeService #98

Merged
merged 46 commits into from
Mar 31, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
46 commits
Select commit Hold shift + click to select a range
668c9ab
Add missing features, polish to `SynchronousComputeService`
dotsdl Mar 4, 2023
da5fb7a
ComputeServiceID additions
dotsdl Mar 10, 2023
d1c5255
Merge branch 'main' into synchronous-compute-service
dotsdl Mar 10, 2023
fbd7884
Added ComputeServiceID handling to state store
dotsdl Mar 11, 2023
f6b3c35
Finished converting claimant -> CLAIMS relationships
dotsdl Mar 14, 2023
a6c1e83
Black and bugfix
dotsdl Mar 14, 2023
cf90622
Added heartbeats for compute services
dotsdl Mar 14, 2023
56aa0f8
Black!
dotsdl Mar 14, 2023
a274647
Merge branch 'main' into synchronous-compute-service
dotsdl Mar 15, 2023
66066bf
Fixing broken tests
dotsdl Mar 15, 2023
ac5778a
Test suite appears fixed.
dotsdl Mar 16, 2023
05b4c19
Black!
dotsdl Mar 16, 2023
9d5a8e4
Unblackify versioneer.py
dotsdl Mar 16, 2023
b0e67de
Switching to WIP branch for gufe changes used here
dotsdl Mar 16, 2023
1800c6f
Added state store tests for computeserviceregistration
dotsdl Mar 16, 2023
e722c83
Build container with temporary gufe branch
dotsdl Mar 16, 2023
6ae1df5
Merge branch 'main' into synchronous-compute-service
dotsdl Mar 16, 2023
2c554b8
Black!
dotsdl Mar 16, 2023
d5e81b3
Added ComputeServiceRegistration tests to ComputeClient tests
dotsdl Mar 16, 2023
a8e6e00
Black!
dotsdl Mar 16, 2023
a097a1a
Merge branch 'main' into synchronous-compute-service
dotsdl Mar 17, 2023
1ecaa16
CLI entrypoint in place for SynchronousComputeService
dotsdl Mar 18, 2023
35a70cb
SynchronousComputeService now hits multiple taskhubs when claiming if…
dotsdl Mar 21, 2023
a1b668a
Added check for Transformation, Task in create_task
dotsdl Mar 22, 2023
4ab9d4c
Dropped use of sched in SynchronousComputeService; using thread for h…
dotsdl Mar 23, 2023
64d3b40
Thread.run -> Thread.start :P
dotsdl Mar 23, 2023
38d5bbc
Remove thread join; not necessary if daemon
dotsdl Mar 23, 2023
c2e94e0
Added expiry periodic to compute API
dotsdl Mar 23, 2023
bb55018
Merge branch 'main' into synchronous-compute-service
dotsdl Mar 23, 2023
7fe4c12
Expiration test in place for state store
dotsdl Mar 23, 2023
8b337eb
Attempt at periodic expiry failed; causing API service to hang on req…
dotsdl Mar 23, 2023
d1d8610
Black!
dotsdl Mar 23, 2023
15704e3
Updated heartbeat interval default, along with expiry
dotsdl Mar 24, 2023
2c4418b
Small convenience fix to user client set_task_status
dotsdl Mar 24, 2023
172b54e
Black!
dotsdl Mar 24, 2023
cf5ca0d
Update alchemiscale/cli.py
dotsdl Mar 24, 2023
b271549
Review fixes
dotsdl Mar 24, 2023
e180ca9
Added CLI test for synchronous compute service
dotsdl Mar 25, 2023
20770f3
Black!
dotsdl Mar 25, 2023
09fb527
Added additional tests for SynchronousComputeService
dotsdl Mar 27, 2023
bdbd223
Black!
dotsdl Mar 27, 2023
c2c850c
Fix tests
dotsdl Mar 28, 2023
e4b8559
Black!
dotsdl Mar 28, 2023
5fca226
Update alchemiscale/compute/service.py
dotsdl Mar 30, 2023
af7f7eb
Additions from @hmacdope review
dotsdl Mar 31, 2023
83a72bc
Black!
dotsdl Mar 31, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
78 changes: 67 additions & 11 deletions alchemiscale/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,12 @@
"""

import click
import yaml
import signal
import gunicorn.app.base
from typing import Type

from .models import Scope
from .security.auth import hash_key, authenticate, AuthenticationError
from .security.auth import hash_key
from .security.models import (
CredentialedEntity,
CredentialedUserIdentity,
Expand Down Expand Up @@ -228,7 +229,9 @@ def cli():
name="api",
help="Start the user-facing API service",
)
@api_starting_params("FA_API_HOST", "FA_API_PORT", "FA_API_LOGLEVEL")
@api_starting_params(
"ALCHEMISCALE_API_HOST", "ALCHEMISCALE_API_PORT", "ALCHEMISCALE_API_LOGLEVEL"
)
@db_params
@s3os_params
@jwt_params
Expand Down Expand Up @@ -268,23 +271,35 @@ def get_settings_override():

app.dependency_overrides[get_base_api_settings] = get_settings_override

start_api(app, workers, host["FA_API_HOST"], port["FA_API_PORT"])
start_api(
app, workers, host["ALCHEMISCALE_API_HOST"], port["ALCHEMISCALE_API_PORT"]
)


@cli.group(help="Subcommands for the compute service")
@cli.group(help="Subcommands for compute services")
def compute():
...


@compute.command(help="Start the compute API service.")
@api_starting_params(
"FA_COMPUTE_API_HOST", "FA_COMPUTE_API_PORT", "FA_COMPUTE_API_LOGLEVEL"
"ALCHEMISCALE_COMPUTE_API_HOST",
"ALCHEMISCALE_COMPUTE_API_PORT",
"ALCHEMISCALE_COMPUTE_API_LOGLEVEL",
)
@click.option(
"--registration-expire-seconds",
type=int,
default=1800,
help="number of seconds since last heartbeat at which to expire a compute service registration",
envvar="ALCHEMISCALE_COMPUTE_API_REGISTRATION_EXPIRE_SECONDS",
**SETTINGS_OPTION_KWARGS,
)
@db_params
@s3os_params
@jwt_params
def api(
workers, host, port, loglevel, # API
workers, host, port, loglevel, registration_expire_seconds, # API
url, user, password, dbname, # DB
jwt_secret, jwt_expire_seconds, jwt_algorithm, #JWT
access_key_id, secret_access_key, session_token, s3_bucket, s3_prefix, default_region # AWS
Expand All @@ -299,7 +314,7 @@ def api(

def get_settings_override():
# inject settings from CLI arguments
api_dict = host | port | loglevel
api_dict = host | port | loglevel | registration_expire_seconds
jwt_dict = jwt_secret | jwt_expire_seconds | jwt_algorithm
db_dict = url | user | password | dbname
s3_dict = (
Expand All @@ -316,12 +331,51 @@ def get_settings_override():

app.dependency_overrides[get_base_api_settings] = get_settings_override

start_api(app, workers, host["FA_COMPUTE_API_HOST"], port["FA_COMPUTE_API_PORT"])
start_api(
app,
workers,
host["ALCHEMISCALE_COMPUTE_API_HOST"],
port["ALCHEMISCALE_COMPUTE_API_PORT"],
)


@compute.command(help="Start the synchronous compute service.")
def synchronous():
...
@click.option(
"--config-file",
"-c",
type=click.File(),
help="YAML-based configuration file giving the settings for this service",
required=True,
)
def synchronous(config_file):
from alchemiscale.models import Scope
from alchemiscale.compute.service import SynchronousComputeService

params = yaml.safe_load(config_file)

params_init = params.get("init", {})
params_start = params.get("start", {})

if "scopes" in params_init:
params_init["scopes"] = [
Scope.from_str(scope) for scope in params_init["scopes"]
]

service = SynchronousComputeService(**params_init)

# add signal handling
for signame in {"SIGHUP", "SIGINT", "SIGTERM"}:

def stop(*args, **kwargs):
service.stop()
raise KeyboardInterrupt()

signal.signal(getattr(signal, signame), stop)

try:
service.start(**params_start)
except KeyboardInterrupt:
pass


@cli.group(help="Subcommands for the database")
Expand Down Expand Up @@ -491,6 +545,7 @@ def remove(url, user, password, dbname, identity_type, identifier):
@scope
def add_scope(url, user, password, dbname, identity_type, identifier, scope):
"""Add a scope for the given identity."""
from .models import Scope
from .storage.statestore import get_n4js
from .settings import Neo4jStoreSettings

Expand Down Expand Up @@ -532,6 +587,7 @@ def list_scope(url, user, password, dbname, identity_type, identifier):
@scope
def remove_scope(url, user, password, dbname, identity_type, identifier, scope):
"""Remove a scope for the given identity(s)."""
from .models import Scope
from .storage.statestore import get_n4js
from .settings import Neo4jStoreSettings

Expand Down
89 changes: 70 additions & 19 deletions alchemiscale/compute/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,11 @@

"""


import asyncio
from typing import Any, Dict, List
import os
import json
from datetime import datetime, timedelta

from fastapi import FastAPI, APIRouter, Body, Depends, HTTPException, status
from gufe.tokenization import GufeTokenizable, JSON_HANDLER
Expand All @@ -26,10 +27,19 @@
_check_store_connectivity,
gufe_to_json,
)
from ..settings import get_base_api_settings, get_compute_api_settings
from ..storage.statestore import Neo4jStore
from ..settings import (
get_base_api_settings,
get_compute_api_settings,
ComputeAPISettings,
)
from ..storage.statestore import Neo4jStore, get_n4js
from ..storage.objectstore import S3ObjectStore
from ..storage.models import ProtocolDAGResultRef, TaskStatusEnum
from ..storage.models import (
ProtocolDAGResultRef,
ComputeServiceID,
ComputeServiceRegistration,
TaskStatusEnum,
)
from ..models import Scope, ScopedKey
from ..security.auth import get_token_data, oauth2_scheme
from ..security.models import (
Expand All @@ -39,11 +49,6 @@
)


# TODO:
# - add periodic removal of task claims from compute services that are no longer alive
# - can be done with an asyncio.sleeping task added to event loop: https://stackoverflow.com/questions/67154839/fastapi-best-way-to-run-continuous-get-requests-in-the-background
# - on startup,

app = FastAPI(title="AlchemiscaleComputeAPI")
app.dependency_overrides[get_base_api_settings] = get_compute_api_settings
app.include_router(base_router)
Expand Down Expand Up @@ -92,6 +97,53 @@ async def list_scopes(
return [str(scope) for scope in scopes]


@router.post("/computeservice/{compute_service_id}/register")
async def register_computeservice(
compute_service_id,
n4js: Neo4jStore = Depends(get_n4js_depends),
):
now = datetime.utcnow()
csreg = ComputeServiceRegistration(
identifier=compute_service_id, registered=now, heartbeat=now
)

compute_service_id_ = n4js.register_computeservice(csreg)

return compute_service_id_


@router.post("/computeservice/{compute_service_id}/deregister")
async def deregister_computeservice(
compute_service_id,
n4js: Neo4jStore = Depends(get_n4js_depends),
):
compute_service_id_ = n4js.deregister_computeservice(
ComputeServiceID(compute_service_id)
)

return compute_service_id_


@router.post("/computeservice/{compute_service_id}/heartbeat")
async def heartbeat_computeservice(
compute_service_id,
n4js: Neo4jStore = Depends(get_n4js_depends),
settings: ComputeAPISettings = Depends(get_base_api_settings),
):
now = datetime.utcnow()

# expire any stale registrations, along with their claims
expire_delta = timedelta(
seconds=settings.ALCHEMISCALE_COMPUTE_API_REGISTRATION_EXPIRE_SECONDS
)
expire_time = now - expire_delta
n4js.expire_registrations(expire_time)
dotsdl marked this conversation as resolved.
Show resolved Hide resolved

compute_service_id_ = n4js.heartbeat_computeservice(compute_service_id, now)

return compute_service_id_


@router.get("/taskhubs")
async def query_taskhubs(
*,
Expand All @@ -117,18 +169,11 @@ async def query_taskhubs(
return taskhubs_handler.format_return()


# @app.get("/taskhubs/{scoped_key}")
# async def get_taskhub(scoped_key: str,
# *,
# n4js: Neo4jStore = Depends(get_n4js_depends)):
# return


@router.post("/taskhubs/{taskhub_scoped_key}/claim")
async def claim_taskhub_tasks(
taskhub_scoped_key,
*,
claimant: str = Body(),
compute_service_id: str = Body(),
count: int = Body(),
n4js: Neo4jStore = Depends(get_n4js_depends),
token: TokenData = Depends(get_token_data_depends),
Expand All @@ -137,7 +182,9 @@ async def claim_taskhub_tasks(
validate_scopes(sk.scope, token)

tasks = n4js.claim_taskhub_tasks(
taskhub=taskhub_scoped_key, claimant=claimant, count=count
taskhub=taskhub_scoped_key,
compute_service_id=ComputeServiceID(compute_service_id),
count=count,
)

return [str(t) if t is not None else None for t in tasks]
Expand Down Expand Up @@ -197,8 +244,12 @@ def set_task_result(
task=task_sk, protocoldagresultref=protocoldagresultref
)

# TODO: if success, set task complete, remove from all hubs
# if success, set task complete, remove from all hubs
# otherwise, set as errored, leave in hubs
if protocoldagresultref.ok:
n4js.set_task_complete(tasks=[task_sk])
else:
n4js.set_task_error(tasks=[task_sk])

return result_sk

Expand Down
24 changes: 17 additions & 7 deletions alchemiscale/compute/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
json_to_gufe,
)
from ..models import Scope, ScopedKey
from ..storage.models import TaskHub, Task, TaskStatusEnum
from ..storage.models import TaskHub, Task, ComputeServiceID, TaskStatusEnum


class AlchemiscaleComputeClientError(AlchemiscaleBaseClientError):
Expand All @@ -35,14 +35,26 @@ class AlchemiscaleComputeClient(AlchemiscaleBaseClient):

_exception = AlchemiscaleComputeClientError

def register(self, compute_service_id: ComputeServiceID):
res = self._post_resource(f"computeservice/{compute_service_id}/register", {})
return ComputeServiceID(res)

def deregister(self, compute_service_id: ComputeServiceID):
res = self._post_resource(f"computeservice/{compute_service_id}/deregister", {})
return ComputeServiceID(res)

def heartbeat(self, compute_service_id: ComputeServiceID):
res = self._post_resource(f"computeservice/{compute_service_id}/heartbeat", {})
return ComputeServiceID(res)

def list_scopes(self) -> List[Scope]:
scopes = self._get_resource(
f"/identities/{self.identifier}/scopes",
)
return [Scope.from_str(s) for s in scopes]

def query_taskhubs(
self, scopes: List[Scope], return_gufe=False, limit=None, skip=None
self, scopes: List[Scope], return_gufe=False
) -> Union[List[ScopedKey], Dict[ScopedKey, TaskHub]]:
"""Return all `TaskHub`s corresponding to given `Scope`."""
if return_gufe:
Expand All @@ -51,9 +63,7 @@ def query_taskhubs(
taskhubs = []

for scope in scopes:
params = dict(
return_gufe=return_gufe, limit=limit, skip=skip, **scope.dict()
)
params = dict(return_gufe=return_gufe, **scope.dict())
if return_gufe:
taskhubs.update(self._query_resource("/taskhubs", params=params))
else:
Expand All @@ -62,10 +72,10 @@ def query_taskhubs(
return taskhubs

def claim_taskhub_tasks(
self, taskhub: ScopedKey, claimant: str, count: int = 1
self, taskhub: ScopedKey, compute_service_id: ComputeServiceID, count: int = 1
) -> Task:
"""Claim a `Task` from the specified `TaskHub`"""
data = dict(claimant=claimant, count=count)
data = dict(compute_service_id=str(compute_service_id), count=count)
tasks = self._post_resource(f"taskhubs/{taskhub}/claim", data)

return [ScopedKey.from_str(t) if t is not None else None for t in tasks]
Expand Down
Loading