diff --git a/alchemiscale/cli.py b/alchemiscale/cli.py
index 1999ef57..c6d096ae 100644
--- a/alchemiscale/cli.py
+++ b/alchemiscale/cli.py
@@ -362,6 +362,7 @@ def get_settings_override():
 def synchronous(config_file):
     from alchemiscale.models import Scope
     from alchemiscale.compute.service import SynchronousComputeService
+    from alchemiscale.compute.settings import ComputeServiceSettings
 
     params = yaml.safe_load(config_file)
 
@@ -373,7 +374,7 @@ def synchronous(config_file):
             Scope.from_str(scope) for scope in params_init["scopes"]
         ]
 
-    service = SynchronousComputeService(**params_init)
+    service = SynchronousComputeService(ComputeServiceSettings(**params_init))
 
     # add signal handling
     for signame in {"SIGHUP", "SIGINT", "SIGTERM"}:
diff --git a/alchemiscale/compute/api.py b/alchemiscale/compute/api.py
index db21d5b8..9337055b 100644
--- a/alchemiscale/compute/api.py
+++ b/alchemiscale/compute/api.py
@@ -8,6 +8,7 @@
 import os
 import json
 from datetime import datetime, timedelta
+import random
 
 from fastapi import FastAPI, APIRouter, Body, Depends
 from fastapi.middleware.gzip import GZipMiddleware
@@ -23,6 +24,7 @@
     get_cred_entity,
     validate_scopes,
     validate_scopes_query,
+    minimize_scope_space,
    _check_store_connectivity,
     gufe_to_json,
     GzipRoute,
@@ -177,6 +179,7 @@ def claim_taskhub_tasks(
     *,
     compute_service_id: str = Body(),
     count: int = Body(),
+    protocols: Optional[List[str]] = Body(None, embed=True),
     n4js: Neo4jStore = Depends(get_n4js_depends),
     token: TokenData = Depends(get_token_data_depends),
 ):
@@ -187,13 +190,91 @@ def claim_taskhub_tasks(
         taskhub=taskhub_scoped_key,
         compute_service_id=ComputeServiceID(compute_service_id),
         count=count,
+        protocols=protocols,
     )
 
     return [str(t) if t is not None else None for t in tasks]
 
 
+@router.post("/claim")
+def claim_tasks(
+    scopes: List[Scope] = Body(),
+    compute_service_id: str = Body(),
+    count: int = Body(),
+    protocols: Optional[List[str]] = Body(None, embed=True),
+    n4js: Neo4jStore = Depends(get_n4js_depends),
+    token: TokenData = Depends(get_token_data_depends),
+):
+    # intersect query scopes with accessible scopes in the token
+    scopes_reduced = minimize_scope_space(scopes)
+    query_scopes = []
+    for scope in scopes_reduced:
+        query_scopes.extend(validate_scopes_query(scope, token))
+
+    taskhubs = dict()
+    # query each scope for available taskhubs;
+    # this loop could be removed in the future with a Union-like operator on Scopes
+    for single_query_scope in set(query_scopes):
+        taskhubs.update(n4js.query_taskhubs(scope=single_query_scope, return_gufe=True))
+
+    # list of tasks to return
+    tasks = []
+
+    if len(taskhubs) == 0:
+        return []
+
+    # claim tasks from taskhubs based on weight; keep going until we hit our
+    # total desired task count, or we run out of taskhubs to draw from
+    while len(tasks) < count and len(taskhubs) > 0:
+        weights = [th.weight for th in taskhubs.values()]
+
+        if sum(weights) == 0:
+            break
+
+        # based on weights, choose taskhub to draw from
+        taskhub: ScopedKey = random.choices(list(taskhubs.keys()), weights=weights)[0]
+
+        # claim tasks from the taskhub
+        claimed_tasks = n4js.claim_taskhub_tasks(
+            taskhub,
+            compute_service_id=ComputeServiceID(compute_service_id),
+            count=(count - len(tasks)),
+            protocols=protocols,
+        )
+
+        # gather up claimed tasks, if present
+        for t in claimed_tasks:
+            if t is not None:
+                tasks.append(t)
+
+        # remove this taskhub from the options available; repeat
+        taskhubs.pop(taskhub)
+
+    return [str(t) for t in tasks] + [None] * (count - len(tasks))
+
+
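The new `/claim` endpoint adopts the weighted TaskHub draw that previously lived client-side in `SynchronousComputeService.claim_tasks`. A minimal sketch of just that selection loop, with hypothetical hub names and weights standing in for real `TaskHub` records:

```python
import random

# hypothetical TaskHub weights keyed by scoped-key string
taskhubs = {"th-A": 0.5, "th-B": 0.3, "th-C": 0.2}
count = 5
claimed = []

while len(claimed) < count and taskhubs:
    weights = list(taskhubs.values())
    if sum(weights) == 0:
        break  # only zero-weight hubs remain; nothing left to draw
    # pick a single hub, biased by its weight
    hub = random.choices(list(taskhubs), weights=weights)[0]
    claimed.append(hub)  # stands in for an n4js.claim_taskhub_tasks() call
    taskhubs.pop(hub)  # hub consumed for this round; redraw from the rest

print(claimed)
```

Zero-weight hubs are never drawn, and the endpoint pads its response with `None` up to `count`, so callers can rely on a fixed-length list.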
@router.get("/tasks/{task_scoped_key}/transformation") def get_task_transformation( + task_scoped_key, + *, + n4js: Neo4jStore = Depends(get_n4js_depends), + token: TokenData = Depends(get_token_data_depends), +): + sk = ScopedKey.from_str(task_scoped_key) + validate_scopes(sk.scope, token) + + transformation: ScopedKey + + transformation, _ = n4js.get_task_transformation( + task=task_scoped_key, + return_gufe=False, + ) + + return str(transformation) + + +@router.get("/tasks/{task_scoped_key}/transformation/gufe") +def retrieve_task_transformation( task_scoped_key, *, n4js: Neo4jStore = Depends(get_n4js_depends), diff --git a/alchemiscale/compute/client.py b/alchemiscale/compute/client.py index fcc870ca..901a7516 100644 --- a/alchemiscale/compute/client.py +++ b/alchemiscale/compute/client.py @@ -35,15 +35,17 @@ class AlchemiscaleComputeClient(AlchemiscaleBaseClient): _exception = AlchemiscaleComputeClientError def register(self, compute_service_id: ComputeServiceID): - res = self._post_resource(f"computeservice/{compute_service_id}/register", {}) + res = self._post_resource(f"/computeservice/{compute_service_id}/register", {}) return ComputeServiceID(res) def deregister(self, compute_service_id: ComputeServiceID): - res = self._post_resource(f"computeservice/{compute_service_id}/deregister", {}) + res = self._post_resource( + f"/computeservice/{compute_service_id}/deregister", {} + ) return ComputeServiceID(res) def heartbeat(self, compute_service_id: ComputeServiceID): - res = self._post_resource(f"computeservice/{compute_service_id}/heartbeat", {}) + res = self._post_resource(f"/computeservice/{compute_service_id}/heartbeat", {}) return ComputeServiceID(res) def list_scopes(self) -> List[Scope]: @@ -71,19 +73,48 @@ def query_taskhubs( return taskhubs def claim_taskhub_tasks( - self, taskhub: ScopedKey, compute_service_id: ComputeServiceID, count: int = 1 + self, + taskhub: ScopedKey, + compute_service_id: ComputeServiceID, + count: int = 1, + protocols: Optional[List[str]] = None, ) -> Task: """Claim a `Task` from the specified `TaskHub`""" - data = dict(compute_service_id=str(compute_service_id), count=count) - tasks = self._post_resource(f"taskhubs/{taskhub}/claim", data) + data = dict( + compute_service_id=str(compute_service_id), count=count, protocols=protocols + ) + tasks = self._post_resource(f"/taskhubs/{taskhub}/claim", data) + + return [ScopedKey.from_str(t) if t is not None else None for t in tasks] + + def claim_tasks( + self, + scopes: List[Scope], + compute_service_id: ComputeServiceID, + count: int = 1, + protocols: Optional[List[str]] = None, + ): + """Claim Tasks from TaskHubs within a list of Scopes.""" + data = dict( + scopes=[scope.dict() for scope in scopes], + compute_service_id=str(compute_service_id), + count=count, + protocols=protocols, + ) + tasks = self._post_resource("/claim", data) return [ScopedKey.from_str(t) if t is not None else None for t in tasks] - def get_task_transformation( + def get_task_transformation(self, task: ScopedKey) -> ScopedKey: + """Get the Transformation associated with the given Task.""" + transformation = self._get_resource(f"/tasks/{task}/transformation") + return ScopedKey.from_str(transformation) + + def retrieve_task_transformation( self, task: ScopedKey ) -> Tuple[Transformation, Optional[ProtocolDAGResult]]: transformation, protocoldagresult = self._get_resource( - f"tasks/{task}/transformation" + f"/tasks/{task}/transformation/gufe" ) return ( @@ -104,6 +135,6 @@ def set_task_result( 
             compute_service_id=str(compute_service_id),
         )
 
-        pdr_sk = self._post_resource(f"tasks/{task}/results", data)
+        pdr_sk = self._post_resource(f"/tasks/{task}/results", data)
 
         return ScopedKey.from_dict(pdr_sk)
diff --git a/alchemiscale/compute/service.py b/alchemiscale/compute/service.py
index 50897ce1..2955555d 100644
--- a/alchemiscale/compute/service.py
+++ b/alchemiscale/compute/service.py
@@ -24,6 +24,7 @@
 from gufe.protocols.protocoldag import execute_DAG, ProtocolDAG, ProtocolDAGResult
 
 from .client import AlchemiscaleComputeClient
+from .settings import ComputeServiceSettings
 from ..storage.models import Task, TaskHub, ComputeServiceID
 from ..models import Scope, ScopedKey
 
@@ -73,114 +74,38 @@ class SynchronousComputeService:
 
     """
 
-    def __init__(
-        self,
-        api_url: str,
-        identifier: str,
-        key: str,
-        name: str,
-        shared_basedir: os.PathLike,
-        scratch_basedir: os.PathLike,
-        keep_shared: bool = False,
-        keep_scratch: bool = False,
-        n_retries: int = 3,
-        sleep_interval: int = 30,
-        heartbeat_interval: int = 300,
-        scopes: Optional[List[Scope]] = None,
-        claim_limit: int = 1,
-        loglevel="WARN",
-        logfile: Optional[Path] = None,
-        client_max_retries=5,
-        client_retry_base_seconds=2.0,
-        client_retry_max_seconds=60.0,
-        client_verify=True,
-    ):
-        """Create a `SynchronousComputeService` instance.
+    def __init__(self, settings: ComputeServiceSettings):
+        """Create a `SynchronousComputeService` instance."""
+        self.settings = settings
 
-        Parameters
-        ----------
-        api_url
-            URL of the compute API to execute Tasks for.
-        identifier
-            Identifier for the compute identity used for authentication.
-        key
-            Credential for the compute identity used for authentication.
-        name
-            The name to give this compute service; used for Task provenance, so
-            typically set to a distinct value to distinguish different compute
-            resources, e.g. different hosts or HPC clusters.
-        shared_basedir
-            Filesystem path to use for `ProtocolDAG` `shared` space.
-        scratch_basedir
-            Filesystem path to use for `ProtocolUnit` `scratch` space.
-        keep_shared
-            If True, don't remove shared directories for `ProtocolDAG`s after
-            completion.
-        keep_scratch
-            If True, don't remove scratch directories for `ProtocolUnit`s after
-            completion.
-        n_retries
-            Number of times to attempt a given Task on failure.
-        sleep_interval
-            Time in seconds to sleep if no Tasks claimed from compute API.
-        heartbeat_interval
-            Frequency at which to send heartbeats to compute API.
-        scopes
-            Scopes to limit Task claiming to; defaults to all Scopes accessible
-            by compute identity.
-        claim_limit
-            Maximum number of Tasks to claim at a time from a TaskHub.
-        loglevel
-            The loglevel at which to report; see the :mod:`logging` docs for
-            available levels.
-        logfile
-            Path to file for logging output; if not set, logging will only go
-            to STDOUT.
-        client_max_retries
-            Maximum number of times to retry a request. In the case the API
-            service is unresponsive an expoenential backoff is applied with
-            retries until this number is reached. If set to -1, retries will
-            continue indefinitely until success.
-        client_retry_base_seconds
-            The base number of seconds to use for exponential backoff.
-            Must be greater than 1.0.
-        client_retry_max_seconds
-            Maximum number of seconds to sleep between retries; avoids runaway
-            exponential backoff while allowing for many retries.
-        client_verify
-            Whether to verify SSL certificate presented by the API server.
- - """ - self.api_url = api_url - self.name = name - self.sleep_interval = sleep_interval - self.heartbeat_interval = heartbeat_interval - self.claim_limit = claim_limit + self.api_url = self.settings.api_url + self.name = self.settings.name + self.sleep_interval = self.settings.sleep_interval + self.heartbeat_interval = self.settings.heartbeat_interval + self.claim_limit = self.settings.claim_limit self.client = AlchemiscaleComputeClient( - api_url, - identifier, - key, - max_retries=client_max_retries, - retry_base_seconds=client_retry_base_seconds, - retry_max_seconds=client_retry_max_seconds, - verify=client_verify, + self.settings.api_url, + self.settings.identifier, + self.settings.key, + max_retries=self.settings.client_max_retries, + retry_base_seconds=self.settings.client_retry_base_seconds, + retry_max_seconds=self.settings.client_retry_max_seconds, + verify=self.settings.client_verify, ) - if scopes is None: + if self.settings.scopes is None: self.scopes = [Scope()] else: - self.scopes = scopes + self.scopes = self.settings.scopes - self.shared_basedir = Path(shared_basedir).absolute() + self.shared_basedir = Path(self.settings.shared_basedir).absolute() self.shared_basedir.mkdir(exist_ok=True) - self.keep_shared = keep_shared + self.keep_shared = self.settings.keep_shared - self.scratch_basedir = Path(scratch_basedir).absolute() + self.scratch_basedir = Path(self.settings.scratch_basedir).absolute() self.scratch_basedir.mkdir(exist_ok=True) - self.keep_scratch = keep_scratch - - self.n_retries = n_retries + self.keep_scratch = self.settings.keep_scratch self.scheduler = sched.scheduler(time.monotonic, time.sleep) @@ -193,7 +118,7 @@ def __init__( # logging extra = {"compute_service_id": str(self.compute_service_id)} logger = logging.getLogger("AlchemiscaleSynchronousComputeService") - logger.setLevel(loglevel) + logger.setLevel(self.settings.loglevel) formatter = logging.Formatter( "[%(asctime)s] [%(compute_service_id)s] [%(levelname)s] %(message)s" @@ -204,8 +129,8 @@ def __init__( sh.setFormatter(formatter) logger.addHandler(sh) - if logfile is not None: - fh = logging.FileHandler(logfile) + if self.settings.logfile is not None: + fh = logging.FileHandler(self.settings.logfile) fh.setFormatter(formatter) logger.addHandler(fh) @@ -232,50 +157,30 @@ def heartbeat(self): self.beat() time.sleep(self.heartbeat_interval) - def claim_tasks(self, count=1) -> List[Optional[ScopedKey]]: + def claim_tasks( + self, count=1, protocols: Optional[List[str]] = None + ) -> List[Optional[ScopedKey]]: """Get a Task to execute from compute API. Returns `None` if no Task was available matching service configuration. + Parameters + ---------- + count + The maximum number of Tasks to claim. + protocols + Protocol names to restrict Task claiming to. `None` means no restriction. + Regex patterns are allowed. 
+ """ - # list of tasks to return - tasks = [] - taskhubs: Dict[ScopedKey, TaskHub] = self.client.query_taskhubs( - scopes=self.scopes, return_gufe=True + tasks = self.client.claim_tasks( + scopes=self.scopes, + compute_service_id=self.compute_service_id, + count=count, + protocols=protocols, ) - if len(taskhubs) == 0: - return [] - - # claim tasks from taskhubs based on weight; keep going till we hit our - # total desired task count, or we run out of taskhubs to draw from - while len(tasks) < count and len(taskhubs) > 0: - weights = [th.weight for th in taskhubs.values()] - - if sum(weights) == 0: - break - - # based on weights, choose taskhub to draw from - taskhub: List[ScopedKey] = random.choices( - list(taskhubs.keys()), weights=weights - )[0] - - # claim tasks from the taskhub - claimed_tasks = self.client.claim_taskhub_tasks( - taskhub, - compute_service_id=self.compute_service_id, - count=(count - len(tasks)), - ) - - # gather up claimed tasks, if present - for t in claimed_tasks: - if t is not None: - tasks.append(t) - - # remove this taskhub from the options available; repeat - taskhubs.pop(taskhub) - return tasks def task_to_protocoldag( @@ -289,9 +194,10 @@ def task_to_protocoldag( """ - transformation, extends_protocoldagresult = self.client.get_task_transformation( - task - ) + ( + transformation, + extends_protocoldagresult, + ) = self.client.retrieve_task_transformation(task) protocoldag = transformation.create( extends=extends_protocoldagresult, @@ -346,7 +252,7 @@ def execute(self, task: ScopedKey) -> ScopedKey: scratch_basedir=scratch, keep_scratch=self.keep_scratch, raise_error=False, - n_retries=self.n_retries, + n_retries=self.settings.n_retries, ) finally: if not self.keep_shared: diff --git a/alchemiscale/compute/settings.py b/alchemiscale/compute/settings.py new file mode 100644 index 00000000..87c80f97 --- /dev/null +++ b/alchemiscale/compute/settings.py @@ -0,0 +1,94 @@ +from pathlib import Path +from typing import Union, Optional, List, Dict, Tuple +from pydantic import BaseModel, Field + +from ..models import Scope, ScopedKey + + +class ComputeServiceSettings(BaseModel): + """Core settings schema for a compute service.""" + + class Config: + arbitrary_types_allowed = True + + api_url: str = Field( + ..., description="URL of the compute API to execute Tasks for." + ) + identifier: str = Field( + ..., description="Identifier for the compute identity used for authentication." + ) + key: str = Field( + ..., description="Credential for the compute identity used for authentication." + ) + name: str = Field( + ..., + description=( + "The name to give this compute service; used for Task provenance, so " + "typically set to a distinct value to distinguish different compute " + "resources, e.g. different hosts or HPC clusters." + ), + ) + shared_basedir: Path = Field( + ..., description="Filesystem path to use for `ProtocolDAG` `shared` space." + ) + scratch_basedir: Path = Field( + ..., description="Filesystem path to use for `ProtocolUnit` `scratch` space." + ) + keep_shared: bool = Field( + False, + description="If True, don't remove shared directories for `ProtocolDAG`s after completion.", + ) + keep_scratch: bool = Field( + False, + description="If True, don't remove scratch directories for `ProtocolUnit`s after completion.", + ) + n_retries: int = Field( + 3, + description="Number of times to attempt a given Task on failure.", + ) + sleep_interval: int = Field( + 30, description="Time in seconds to sleep if no Tasks claimed from compute API." 
+ ) + heartbeat_interval: int = Field( + 300, description="Frequency at which to send heartbeats to compute API." + ) + scopes: Optional[List[Scope]] = Field( + None, + description="Scopes to limit Task claiming to; defaults to all Scopes accessible by compute identity.", + ) + protocols: Optional[List[str]] = Field( + None, + description="Names of Protocols to run with this service; `None` means no restriction.", + ) + claim_limit: int = Field( + 1000, description="Maximum number of Tasks to claim at a time from a TaskHub." + ) + loglevel: str = Field( + "WARN", + description="The loglevel at which to report; see the :mod:`logging` docs for available levels.", + ) + logfile: Optional[Path] = Field( + None, + description="Path to file for logging output; if not set, logging will only go to STDOUT.", + ) + client_max_retries: int = Field( + 5, + description=( + "Maximum number of times to retry a request. " + "In the case the API service is unresponsive an expoenential backoff " + "is applied with retries until this number is reached. " + "If set to -1, retries will continue indefinitely until success." + ), + ) + client_retry_base_seconds: float = Field( + 2.0, + description="The base number of seconds to use for exponential backoff. Must be greater than 1.0.", + ) + client_retry_max_seconds: float = Field( + 60.0, + description="Maximum number of seconds to sleep between retries; avoids runaway exponential backoff while allowing for many retries.", + ) + client_verify: bool = Field( + True, + description="Whether to verify SSL certificate presented by the API server.", + ) diff --git a/alchemiscale/storage/cypher.py b/alchemiscale/storage/cypher.py index 5fda7b03..91d91152 100644 --- a/alchemiscale/storage/cypher.py +++ b/alchemiscale/storage/cypher.py @@ -24,3 +24,7 @@ def cypher_list_from_scoped_keys(scoped_keys: List[Optional[ScopedKey]]) -> str: if scoped_key: data.append('"' + str(scoped_key) + '"') return "[" + ", ".join(data) + "]" + + +def cypher_or(items): + return "|".join(items) diff --git a/alchemiscale/storage/statestore.py b/alchemiscale/storage/statestore.py index 1ffd4f4a..fc8b38b2 100644 --- a/alchemiscale/storage/statestore.py +++ b/alchemiscale/storage/statestore.py @@ -14,7 +14,13 @@ import numpy as np import networkx as nx -from gufe import AlchemicalNetwork, Transformation, NonTransformation, Settings +from gufe import ( + AlchemicalNetwork, + Transformation, + NonTransformation, + Settings, + Protocol, +) from gufe.tokenization import GufeTokenizable, GufeKey, JSON_HANDLER from neo4j import Transaction, GraphDatabase, Driver @@ -31,7 +37,7 @@ ) from ..strategies import Strategy from ..models import Scope, ScopedKey -from .cypher import cypher_list_from_scoped_keys +from .cypher import cypher_list_from_scoped_keys, cypher_or from ..security.models import CredentialedEntity from ..settings import Neo4jStoreSettings @@ -1655,7 +1661,11 @@ def get_taskhub_unclaimed_tasks( return [ScopedKey.from_str(t["_scoped_key"]) for t in tasks] def claim_taskhub_tasks( - self, taskhub: ScopedKey, compute_service_id: ComputeServiceID, count: int = 1 + self, + taskhub: ScopedKey, + compute_service_id: ComputeServiceID, + count: int = 1, + protocols: Optional[List[Union[Protocol, str]]] = None, ) -> List[Union[ScopedKey, None]]: """Claim a TaskHub Task. @@ -1676,8 +1686,13 @@ def claim_taskhub_tasks( Unique identifier for the compute service claiming the Tasks for execution. count Claim the given number of Tasks in a single transaction. 
+        protocols
+            Protocols to restrict Task claiming to. `None` means no restriction.
+            If an empty list is given, a ValueError is raised.
 
         """
+        if protocols is not None and len(protocols) == 0:
+            raise ValueError("`protocols` must be either `None` or not empty")
 
         q = f"""
         MATCH (th:TaskHub {{`_scoped_key`: '{taskhub}'}})-[actions:ACTIONS]-(task:Task)
@@ -1686,6 +1701,22 @@ def claim_taskhub_tasks(
         OPTIONAL MATCH (task)-[:EXTENDS]->(other_task:Task)
 
         WITH task, other_task, actions
+        """
+
+        # filter down to `protocols`, if specified
+        if protocols is not None:
+            # need to extract qualnames if given Protocol classes
+            protocols = [
+                protocol.__qualname__ if not isinstance(protocol, str) else protocol
+                for protocol in protocols
+            ]
+
+            q += f"""
+            MATCH (task)-[:PERFORMS]->(:Transformation|NonTransformation)-[:DEPENDS_ON]->(protocol:{cypher_or(protocols)})
+            WITH task, other_task, actions
+            """
+
+        q += f"""
         WHERE other_task.status = '{TaskStatusEnum.complete.value}'
            OR other_task IS NULL
 
         RETURN task.`_scoped_key`, task.priority, actions.weight
diff --git a/alchemiscale/tests/__init__.py b/alchemiscale/tests/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/alchemiscale/tests/integration/__init__.py b/alchemiscale/tests/integration/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/alchemiscale/tests/integration/compute/__init__.py b/alchemiscale/tests/integration/compute/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/alchemiscale/tests/integration/compute/client/__init__.py b/alchemiscale/tests/integration/compute/client/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/alchemiscale/tests/integration/compute/client/conftest.py b/alchemiscale/tests/integration/compute/client/conftest.py
index f4c92f8a..aebc5fe1 100644
--- a/alchemiscale/tests/integration/compute/client/conftest.py
+++ b/alchemiscale/tests/integration/compute/client/conftest.py
@@ -3,7 +3,6 @@
 from time import sleep
 
 import uvicorn
-import requests
 
 from alchemiscale.settings import get_base_api_settings
 from alchemiscale.base.api import get_n4js_depends, get_s3os_depends
diff --git a/alchemiscale/tests/integration/compute/client/test_compute_client.py b/alchemiscale/tests/integration/compute/client/test_compute_client.py
index f324a9ec..99777c41 100644
--- a/alchemiscale/tests/integration/compute/client/test_compute_client.py
+++ b/alchemiscale/tests/integration/compute/client/test_compute_client.py
@@ -189,6 +189,36 @@ def test_claim_taskhub_task(
         assert task_sks2[0] in remaining_tasks
         assert task_sks2[1] in remaining_tasks
 
+    def test_claim_tasks(
+        self,
+        scope_test,
+        n4js_preloaded,
+        compute_client: client.AlchemiscaleComputeClient,
+        compute_service_id,
+        uvicorn_server,
+    ):
+        # register compute service id
+        compute_client.register(compute_service_id)
+
+        # claim a single task; should get highest priority task
+        task_sks = compute_client.claim_tasks(
+            scopes=[scope_test],
+            compute_service_id=compute_service_id,
+        )
+        all_tasks = n4js_preloaded.query_tasks(scope=scope_test)
+        priorities = {
+            task_sk: priority
+            for task_sk, priority in zip(
+                all_tasks, n4js_preloaded.get_task_priority(all_tasks)
+            )
+        }
+
+        assert len(task_sks) == 1
+        assert task_sks[0] in all_tasks
+        assert [t.gufe_key for t in task_sks] == [
+            t.gufe_key for t in all_tasks if priorities[t] == 1
+        ]
+
     def test_get_task_transformation(
         self,
         scope_test,
@@ -215,7 +245,7 @@ def test_get_task_transformation(
         (
             transformation_,
             extends_protocoldagresult,
-        ) = compute_client.get_task_transformation(task_sks[0])
+        ) = compute_client.retrieve_task_transformation(task_sks[0])
 
         assert transformation_ == transformation
         assert extends_protocoldagresult is None
@@ -249,7 +279,7 @@ def test_set_task_result(
         (
             transformation_,
             extends_protocoldagresult,
-        ) = compute_client.get_task_transformation(task_sks[0])
+        ) = compute_client.retrieve_task_transformation(task_sks[0])
 
         assert transformation_ == transformation
         assert extends_protocoldagresult is None
@@ -265,7 +295,7 @@ def test_set_task_result(
         (
             transformation2,
             extends_protocoldagresult2,
-        ) = compute_client.get_task_transformation(task_sk2)
+        ) = compute_client.retrieve_task_transformation(task_sk2)
 
         assert transformation2 == transformation_
         assert extends_protocoldagresult2 == protocoldagresults[0]
diff --git a/alchemiscale/tests/integration/compute/client/test_compute_service.py b/alchemiscale/tests/integration/compute/client/test_compute_service.py
index 9ae4d738..bb097257 100644
--- a/alchemiscale/tests/integration/compute/client/test_compute_service.py
+++ b/alchemiscale/tests/integration/compute/client/test_compute_service.py
@@ -11,6 +11,7 @@
 from alchemiscale.storage.statestore import Neo4jStore
 from alchemiscale.storage.objectstore import S3ObjectStore
 from alchemiscale.compute.service import SynchronousComputeService
+from alchemiscale.compute.settings import ComputeServiceSettings
 
 
 class TestSynchronousComputeService:
@@ -20,14 +21,16 @@ class TestSynchronousComputeService:
     def service(self, n4js_preloaded, compute_client, tmpdir):
         with tmpdir.as_cwd():
             return SynchronousComputeService(
-                api_url=compute_client.api_url,
-                identifier=compute_client.identifier,
-                key=compute_client.key,
-                name="test_compute_service",
-                shared_basedir=Path("shared").absolute(),
-                scratch_basedir=Path("scratch").absolute(),
-                heartbeat_interval=1,
-                sleep_interval=1,
+                ComputeServiceSettings(
+                    api_url=compute_client.api_url,
+                    identifier=compute_client.identifier,
+                    key=compute_client.key,
+                    name="test_compute_service",
+                    shared_basedir=Path("shared").absolute(),
+                    scratch_basedir=Path("scratch").absolute(),
+                    heartbeat_interval=1,
+                    sleep_interval=1,
+                )
             )
 
     def test_heartbeat(self, n4js_preloaded, service):
diff --git a/alchemiscale/tests/integration/compute/conftest.py b/alchemiscale/tests/integration/compute/conftest.py
index e75a55ab..d66b20c5 100644
--- a/alchemiscale/tests/integration/compute/conftest.py
+++ b/alchemiscale/tests/integration/compute/conftest.py
@@ -140,7 +140,9 @@ def get_token_data_depends_override():
 
 
 @pytest.fixture
-def compute_api_no_auth(s3os, scope_consistent_token_data_depends_override):
+def compute_api_no_auth(
+    n4js_preloaded, s3os, scope_consistent_token_data_depends_override
+):
     def get_s3os_override():
         return s3os
diff --git a/alchemiscale/tests/integration/compute/test_compute_api.py b/alchemiscale/tests/integration/compute/test_compute_api.py
index 8ab1c7a5..19cda547 100644
--- a/alchemiscale/tests/integration/compute/test_compute_api.py
+++ b/alchemiscale/tests/integration/compute/test_compute_api.py
@@ -63,13 +63,15 @@ def out_of_scoped_keys(self, n4js_preloaded, network_tyk2, multiple_scopes):
         assert len(task_sks) > 0
         return {"network": network_sk, "taskhub": tq_sk, "tasks": task_sks}
 
-    def test_get_task_transformation(
+    def test_retrieve_task_transformation(
         self,
         n4js_preloaded,
         test_client,
         scoped_keys,
     ):
-        response = test_client.get(f"/tasks/{scoped_keys['tasks'][0]}/transformation")
+        response = test_client.get(
f"/tasks/{scoped_keys['tasks'][0]}/transformation/gufe" + ) assert response.status_code == 200 data = response.json() assert len(data) == 2 diff --git a/alchemiscale/tests/integration/conftest.py b/alchemiscale/tests/integration/conftest.py index 1875981e..c05a7374 100644 --- a/alchemiscale/tests/integration/conftest.py +++ b/alchemiscale/tests/integration/conftest.py @@ -221,6 +221,20 @@ def s3os(s3objectstore_settings): # test alchemical networks + +## define varying protocols to simulate protocol variety +class DummyProtocolA(DummyProtocol): + pass + + +class DummyProtocolB(DummyProtocol): + pass + + +class DummyProtocolC(DummyProtocol): + pass + + # TODO: add in atom mapping once `gufe`#35 is settled @@ -251,7 +265,7 @@ def network_tyk2(): Transformation( stateA=complexes[edge[0]], stateB=complexes[edge[1]], - protocol=DummyProtocol(settings=DummyProtocol.default_settings()), + protocol=DummyProtocolA(settings=DummyProtocolA.default_settings()), name=f"{edge[0]}_to_{edge[1]}_complex", ) for edge in tyk2s.connections @@ -260,7 +274,7 @@ def network_tyk2(): Transformation( stateA=solvated[edge[0]], stateB=solvated[edge[1]], - protocol=DummyProtocol(settings=DummyProtocol.default_settings()), + protocol=DummyProtocolB(settings=DummyProtocolB.default_settings()), name=f"{edge[0]}_to_{edge[1]}_solvent", ) for edge in tyk2s.connections @@ -270,7 +284,7 @@ def network_tyk2(): for cs in list(solvated.values()) + list(complexes.values()): nt = NonTransformation( system=cs, - protocol=DummyProtocol(DummyProtocol.default_settings()), + protocol=DummyProtocolC(DummyProtocolC.default_settings()), name=f"f{cs.name}_nt", ) nontransformations.append(nt) diff --git a/alchemiscale/tests/integration/interface/__init__.py b/alchemiscale/tests/integration/interface/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/alchemiscale/tests/integration/interface/client/__init__.py b/alchemiscale/tests/integration/interface/client/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/alchemiscale/tests/integration/storage/__init__.py b/alchemiscale/tests/integration/storage/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/alchemiscale/tests/integration/storage/test_statestore.py b/alchemiscale/tests/integration/storage/test_statestore.py index 2632524b..802eab94 100644 --- a/alchemiscale/tests/integration/storage/test_statestore.py +++ b/alchemiscale/tests/integration/storage/test_statestore.py @@ -3,6 +3,7 @@ from typing import List, Dict from pathlib import Path from itertools import chain +from functools import reduce import pytest from gufe import AlchemicalNetwork @@ -27,6 +28,8 @@ ) from alchemiscale.security.auth import hash_key +from ..conftest import DummyProtocolA, DummyProtocolB, DummyProtocolC + class TestStateStore: ... 
@@ -1324,6 +1327,52 @@ def test_claim_taskhub_tasks(self, n4js: Neo4jStore, network_tyk2, scope_test):
         claimed6 = n4js.claim_taskhub_tasks(taskhub_sk, csid, count=2)
         assert claimed6 == [None] * 2
 
+    def test_claim_taskhub_tasks_protocol_split(
+        self, n4js: Neo4jStore, network_tyk2, scope_test
+    ):
+        an = network_tyk2
+        network_sk, taskhub_sk, _ = n4js.assemble_network(an, scope_test)
+
+        def reducer(collection, transformation):
+            protocol = transformation.protocol.__class__
+            if len(collection[protocol]) >= 3:
+                return collection
+            sk = n4js.get_scoped_key(transformation, scope_test)
+            collection[protocol].append(sk)
+            return collection
+
+        transformations = reduce(
+            reducer,
+            an.edges,
+            {DummyProtocolA: [], DummyProtocolB: [], DummyProtocolC: []},
+        )
+
+        transformation_sks = [
+            value for _, values in transformations.items() for value in values
+        ]
+
+        task_sks = n4js.create_tasks(transformation_sks)
+        assert len(task_sks) == 9
+
+        # action the tasks
+        n4js.action_tasks(task_sks, taskhub_sk)
+        assert len(n4js.get_taskhub_unclaimed_tasks(taskhub_sk)) == 9
+
+        csid = ComputeServiceID("another task handler")
+        n4js.register_computeservice(ComputeServiceRegistration.from_now(csid))
+
+        claimedA = n4js.claim_taskhub_tasks(
+            taskhub_sk, csid, protocols=["DummyProtocolA"], count=9
+        )
+
+        assert len([sk for sk in claimedA if sk]) == 3
+
+        claimedBC = n4js.claim_taskhub_tasks(
+            taskhub_sk, csid, protocols=["DummyProtocolB", "DummyProtocolC"], count=9
+        )
+
+        assert len([sk for sk in claimedBC if sk]) == 6
+
     def test_claim_taskhub_tasks_deregister(
         self, n4js: Neo4jStore, network_tyk2, scope_test
     ):
diff --git a/alchemiscale/tests/unit/__init__.py b/alchemiscale/tests/unit/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/devtools/configs/synchronous-compute-settings.yaml b/devtools/configs/synchronous-compute-settings.yaml
index a9c29ab5..23a9d9f2 100644
--- a/devtools/configs/synchronous-compute-settings.yaml
+++ b/devtools/configs/synchronous-compute-settings.yaml
@@ -44,6 +44,9 @@ init:
   scopes:
     - '*-*-*'
 
+  # Names of Protocols to run with this service; `null` means no restriction
+  protocols: null
+
   # Maximum number of Tasks to claim at a time from a TaskHub.
   claim_limit: 1
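With `ComputeServiceSettings` in place, the CLI's `synchronous` command builds the service in two steps: validate the YAML-derived parameters into a settings object, then hand that object to the service. A minimal sketch mirroring the `init` section of the config above; every concrete value here is a hypothetical placeholder:

```python
from alchemiscale.compute.settings import ComputeServiceSettings
from alchemiscale.compute.service import SynchronousComputeService

# placeholder parameters standing in for the parsed YAML `init` section
params_init = {
    "api_url": "http://localhost:80",      # hypothetical compute API endpoint
    "identifier": "compute-identity",      # hypothetical compute identity
    "key": "compute-key",                  # hypothetical credential
    "name": "my-compute-service",
    "shared_basedir": "./shared",
    "scratch_basedir": "./scratch",
    "protocols": ["DummyProtocolA"],       # claim only Tasks for this Protocol
}

# pydantic validates the parameters and coerces types (e.g. str -> Path)
settings = ComputeServiceSettings(**params_init)
service = SynchronousComputeService(settings)
```

Centralizing the twenty-odd keyword arguments in one validated schema means the CLI, tests, and any programmatic callers all construct services the same way, and new options like `protocols` only need to be declared once.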