Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adding bulk task creation server side #250

Merged
merged 13 commits into from
Mar 19, 2024
35 changes: 30 additions & 5 deletions alchemiscale/interface/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -353,11 +353,36 @@ def create_tasks(
sk = ScopedKey.from_str(transformation_scoped_key)
validate_scopes(sk.scope, token)

task_sks = []
for i in range(count):
task_sks.append(
n4js.create_task(transformation=sk, extends=extends, creator=token.entity)
)
task_sks = n4js.create_tasks([sk] * count, [extends] * count)
return [str(sk) for sk in task_sks]
ianmkenney marked this conversation as resolved.
Show resolved Hide resolved


@router.post("/bulk/transformations/tasks/create")
def create_transformations_tasks(
*,
transformations: List[str] = Body(embed=True),
extends: Optional[List[Optional[str]]] = None,
n4js: Neo4jStore = Depends(get_n4js_depends),
token: TokenData = Depends(get_token_data_depends),
):
transformation_sks = [
ScopedKey.from_str(transformation_string)
for transformation_string in transformations
]

for transformation_sk in transformation_sks:
validate_scopes(transformation_sk.scope, token)

if extends is not None:
extends = [
None if not extends_str else ScopedKey.from_str(extends_str)
for extends_str in extends
]

try:
task_sks = n4js.create_tasks(transformation_sks, extends)
except ValueError as e:
raise HTTPException(status_code=http_status.HTTP_400_BAD_REQUEST, detail=str(e))

return [str(sk) for sk in task_sks]

Expand Down
58 changes: 58 additions & 0 deletions alchemiscale/interface/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -469,6 +469,64 @@ def create_tasks(
task_sks = self._post_resource(f"/transformations/{transformation}/tasks", data)
return [ScopedKey.from_str(i) for i in task_sks]

def create_transformations_tasks(
self,
transformations: List[ScopedKey],
extends: Optional[List[Optional[ScopedKey]]] = None,
) -> List[ScopedKey]:
"""Create Tasks for multiple Transformations.

Unlike `create_tasks`, this method can create Tasks for many
Transformations. This method should be used instead of `create_tasks`
whenever creating Tasks for more than one unique Transformation since it
minimizes the number of API requests to the alchemiscale server.

Parameters
----------
transformations
A list of ScopedKeys of Transformations to create Tasks for. The
same ScopedKey can be repeated to create multiple Tasks for the
same Transformation.
extends
A list of ScopedKeys for the Tasks to be extended. When not `None`,
`extends` must be a list of the same length as `transformations`. If
a transformation in `transformations` should not extend a Task, use
a `None` as a placeholder in the `extends` list.

Returns
-------
List[ScopedKey]
A list giving the ScopedKeys of the new Tasks created.

Examples
--------

Instead of looping over Transformations and calling `create_tasks`, make
one call to `create_transformations_tasks`.

>>> client.create_transformations_tasks([transformation_1_sk, transformation_2_sk])

The behavior of the `count` keyword argument from `create_tasks` can be
recreated by repeating the same transformation in the list while also
allowing the addition of other transformtions.

>>> client.create_transformations_tasks([transformation_1_sk] * 3 + [transformation_2_sk] * 2)

"""

data = dict(
transformations=[str(transformation) for transformation in transformations],
extends=(
None
if not extends
else [
str(task_sk) if task_sk is not None else None for task_sk in extends
]
),
)
task_sks = self._post_resource("/bulk/transformations/tasks/create", data)
return [ScopedKey.from_str(i) for i in task_sks]

def query_tasks(
self,
scope: Optional[Scope] = None,
Expand Down
216 changes: 159 additions & 57 deletions alchemiscale/storage/statestore.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
import numpy as np

import networkx as nx
from gufe import AlchemicalNetwork, Transformation, Settings
from gufe import AlchemicalNetwork, Transformation, NonTransformation, Settings
from gufe.tokenization import GufeTokenizable, GufeKey, JSON_HANDLER

from neo4j import Transaction, GraphDatabase, Driver
Expand Down Expand Up @@ -1544,88 +1544,190 @@ def task_count(task_dict: dict):

## tasks

def create_task(
def _validate_extends_tasks(self, task_list) -> Dict[str, Tuple[Node, str]]:

if not task_list:
return {}

q = f"""
UNWIND {cypher_list_from_scoped_keys(task_list)} as task
MATCH (t:Task {{`_scoped_key`: task}})-[PERFORMS]->(tf:Transformation)
return t, tf._scoped_key as tf_sk
"""

results = self.execute_query(q)

nodes = {}

for record in results.records:
node = record_data_to_node(record["t"])
transformation_sk = record["tf_sk"]

status = node.get("status")

if status in ("invalid", "deleted"):
invalid_task_scoped_key = node["_scoped_key"]
raise ValueError(
f"Cannot extend a `deleted` or `invalid` Task: {invalid_task_scoped_key}"
)

nodes[node["_scoped_key"]] = (node, transformation_sk)

return nodes

def create_tasks(
self,
transformation: ScopedKey,
extends: Optional[ScopedKey] = None,
transformations: List[ScopedKey],
extends: Optional[List[Optional[ScopedKey]]] = None,
creator: Optional[str] = None,
) -> ScopedKey:
"""Add a compute Task to a Transformation.
) -> List[ScopedKey]:
"""Create Tasks for the given Transformations.

Note: this creates a compute Task, but does not add it to any TaskHubs.
Note: this creates Tasks; it does not action them.

Parameters
----------
transformation
The Transformation to compute.
scope
The scope the Transformation is in; ignored if `transformation` is a ScopedKey.
transformations
The Transformations to create Tasks for.
One Task is created for each Transformation ScopedKey given; to
create multiple Tasks for a given Transformation, provide its
ScopedKey multiple times.
extends
The ScopedKey of the Task to use as a starting point for this Task.
The ScopedKeys of the Tasks to use as a starting point for the
created Tasks, in the same order as `transformations`. If ``None``
given for a given Task, it will not extend any other Task.
Will use the `ProtocolDAGResult` from the given Task as the
`extends` input for the Task's eventual call to `Protocol.create`.

creator (optional)
The creator of the Tasks.
"""
if transformation.qualname not in ["Transformation", "NonTransformation"]:
allowed_types = [Transformation.__qualname__, NonTransformation.__qualname__]

# reshape data to a standard form
if extends is None:
extends = [None] * len(transformations)
elif len(extends) != len(transformations):
raise ValueError(
"`transformation` ScopedKey does not correspond to a `Transformation`"
"`extends` must either be `None` or have the same length as `transformations`"
)

if extends is not None and extends.qualname != "Task":
raise ValueError("`extends` ScopedKey does not correspond to a `Task`")

scope = transformation.scope
transformation_node = self._get_node(transformation)
for i, _extends in enumerate(extends):
if _extends is not None:
if not (extended_task_qualname := getattr(_extends, "qualname", None)):
raise ValueError(
f"`extends` entry for `Task` {transformations[i]} is not valid"
)
if extended_task_qualname != "Task":
raise ValueError(
f"`extends` ScopedKey ({_extends}) does not correspond to a `Task`"
)

# create a new task for the supplied transformation
# use a PERFORMS relationship
task = Task(
creator=creator, extends=str(extends) if extends is not None else None
)
transformation_map = {
transformation_type: [[], []] for transformation_type in allowed_types
}
for i, transformation in enumerate(transformations):
if transformation.qualname not in allowed_types:
raise ValueError(
f"Got an unsupported `Transformation` type: {transformation.qualname}"
)
transformation_map[transformation.qualname][0].append(transformation)
transformation_map[transformation.qualname][1].append(extends[i])

_, task_node, scoped_key = self._gufe_to_subgraph(
task.to_shallow_dict(),
labels=["GufeTokenizable", task.__class__.__name__],
gufe_key=task.key,
scope=scope,
extends_nodes = self._validate_extends_tasks(
[_extends for _extends in extends if _extends is not None]
)

subgraph = Subgraph()

if extends is not None:
previous_task_node = self._get_node(extends)
stat = previous_task_node.get("status")
# do not allow creation of a task that extends an invalid or deleted task.
if (stat == "invalid") or (stat == "deleted"):
# py2neo Node doesn't like the neo4j datetime object
# manually cast since we're raising anyways
# and the results are ephemeral
previous_task_node["datetime_created"] = str(
previous_task_node["datetime_created"]
sks = []
# iterate over all allowed types, unpacking the transformations and extends subsets
for node_type, (
transformation_subset,
extends_subset,
) in transformation_map.items():

if not transformation_subset:
continue

q = f"""
UNWIND {cypher_list_from_scoped_keys(transformation_subset)} as sk
MATCH (n:{node_type} {{`_scoped_key`: sk}})
RETURN n
"""

results = self.execute_query(q)

transformation_nodes = {}
for record in results.records:
node = record_data_to_node(record["n"])
transformation_nodes[node["_scoped_key"]] = node

for _transformation, _extends in zip(transformation_subset, extends_subset):

scope = transformation.scope

_task = Task(
creator=creator,
extends=str(_extends) if _extends is not None else None,
)
raise ValueError(
f"Cannot extend a `deleted` or `invalid` Task: {previous_task_node}"
_, task_node, scoped_key = self._gufe_to_subgraph(
_task.to_shallow_dict(),
labels=["GufeTokenizable", _task.__class__.__name__],
gufe_key=_task.key,
scope=scope,
)
subgraph = subgraph | Relationship.type("EXTENDS")(
task_node,
previous_task_node,
_org=scope.org,
_campaign=scope.campaign,
_project=scope.project,
)

subgraph = subgraph | Relationship.type("PERFORMS")(
task_node,
transformation_node,
_org=scope.org,
_campaign=scope.campaign,
_project=scope.project,
)
sks.append(scoped_key)

if _extends is not None:
ianmkenney marked this conversation as resolved.
Show resolved Hide resolved

extends_task_node, extends_transformation_sk = extends_nodes[
str(_extends)
]

if extends_transformation_sk != str(_transformation):
raise ValueError(
f"{_extends} extends a Transformation other than {_transformation}"
)

subgraph |= Relationship.type("EXTENDS")(
task_node,
extends_task_node,
_org=scope.org,
_campaign=scope.campaign,
_project=scope.project,
)

subgraph |= Relationship.type("PERFORMS")(
task_node,
transformation_nodes[str(_transformation)],
_org=scope.org,
_campaign=scope.campaign,
_project=scope.project,
)

with self.transaction() as tx:
merge_subgraph(tx, subgraph, "GufeTokenizable", "_scoped_key")

return scoped_key
return sks

def create_task(
self,
transformation: ScopedKey,
extends: Optional[ScopedKey] = None,
creator: Optional[str] = None,
) -> ScopedKey:
"""Create a single Task for a Transformation.

This is a convenience method that wraps around the more general
`create_tasks` method.

"""
return self.create_tasks(
[transformation],
extends=[extends] if extends is not None else [None],
creator=creator,
)[0]

def query_tasks(self, *, status=None, key=None, scope: Scope = Scope()):
"""Query for `Task`\s matching given attributes."""
Expand Down
Loading
Loading