Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow passing [] and {} as argument. #328

Merged
merged 1 commit into from
Sep 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion aiida_workgraph/engine/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def prepare_for_python_task(task: dict, kwargs: dict, var_kwargs: dict) -> dict:
import os

# get the names kwargs for the PythonJob, which are the inputs before _wait
function_kwargs = {}
function_kwargs = kwargs.pop("function_kwargs", {})
# TODO better way to find the function_kwargs
input_names = [
name
Expand Down
11 changes: 9 additions & 2 deletions aiida_workgraph/engine/workgraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -639,6 +639,9 @@ def reset_task(
self.ctx._tasks[name]["execution_count"] = 0
for child_task in self.ctx._tasks[name]["children"]:
self.reset_task(child_task, reset_process=False, recursive=False)
elif self.ctx._tasks[name]["metadata"]["node_type"].upper() in ["IF", "ZONE"]:
for child_task in self.ctx._tasks[name]["children"]:
self.reset_task(child_task, reset_process=False, recursive=False)
if recursive:
# reset its child tasks
names = self.ctx._connectivity["child_node"][name]
Expand Down Expand Up @@ -816,8 +819,8 @@ def update_zone_task_state(self, name: str) -> None:
finished, _ = self.are_childen_finished(name)
if finished:
self.set_task_state_info(name, "state", "FINISHED")
self.update_parent_task_state(name)
self.report(f"Task: {name} finished.")
self.update_parent_task_state(name)

def should_run_while_task(self, name: str) -> tuple[bool, t.Any]:
"""Check if the while task should run."""
Expand Down Expand Up @@ -949,6 +952,7 @@ def check_while_conditions(self) -> bool:
task_name, socket_name = c.split(".")
if "task_name" != "context":
condition_tasks.append(task_name)
self.reset_task(task_name)
self.run_tasks(condition_tasks, continue_workgraph=False)
conditions = []
for c in self.ctx._workgraph["conditions"]:
Expand Down Expand Up @@ -1018,7 +1022,10 @@ def run_tasks(self, names: t.List[str], continue_workgraph: bool = True) -> None
)
continue
# skip if the task is already executed
if name in self.ctx._executed_tasks:
# or if the task is in a skippped state
if name in self.ctx._executed_tasks or self.get_task_state_info(
name, "state"
) in ["SKIPPED"]:
continue
self.ctx._executed_tasks.append(name)
print("-" * 60)
Expand Down
51 changes: 37 additions & 14 deletions aiida_workgraph/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,18 @@ def get_nested_dict(d: Dict, name: str, **kwargs) -> Any:
return current


def merge_dicts(existing: Any, new: Any) -> Any:
"""Recursively merges two dictionaries."""
if isinstance(existing, dict) and isinstance(new, dict):
for k, v in new.items():
if k in existing and isinstance(existing[k], dict) and isinstance(v, dict):
merge_dicts(existing[k], v)
else:
existing[k] = v
else:
return new


def update_nested_dict(d: Optional[Dict[str, Any]], key: str, value: Any) -> None:
"""
Update or create a nested dictionary structure based on a dotted key path.
Expand Down Expand Up @@ -178,11 +190,21 @@ def update_nested_dict(d: Optional[Dict[str, Any]], key: str, value: Any) -> Non
If the resulting dictionary is empty after the update, it will be set to `None`.

"""

keys = key.split(".")
current = d if d is not None else {}
for k in keys[:-1]:
current = current.setdefault(k, {})
current[keys[-1]] = value
# Handle merging instead of overwriting
last_key = keys[-1]
if (
last_key in current
and isinstance(current[last_key], dict)
and isinstance(value, dict)
):
merge_dicts(current[last_key], value)
else:
current[last_key] = value
# if current is empty, set it to None
if not current:
current = None
Expand All @@ -200,26 +222,27 @@ def is_empty(value: Any) -> bool:
return False


def update_nested_dict_with_special_keys(d: Dict[str, Any]) -> Dict[str, Any]:
def update_nested_dict_with_special_keys(data: Dict[str, Any]) -> Dict[str, Any]:
"""Remove None and empty value"""
d = {k: v for k, v in d.items() if v is not None and not is_empty(v)}
# data = {k: v for k, v in data.items() if v is not None and not is_empty(v)}
data = {k: v for k, v in data.items() if v is not None}
#
special_keys = [k for k in d.keys() if "." in k]
special_keys = [k for k in data.keys() if "." in k]
for key in special_keys:
value = d.pop(key)
update_nested_dict(d, key, value)
return d
value = data.pop(key)
update_nested_dict(data, key, value)
return data


def merge_properties(wgdata: Dict[str, Any]) -> None:
"""Merge sub properties to the root properties.
{
"base.pw.parameters": 2,
"base.pw.code": 1,
}
after merge:
{"base": {"pw": {"parameters": 2,
"code": 1}}
{
"base.pw.parameters": 2,
"base.pw.code": 1,
}
after merge:
{"base": {"pw": {"parameters": 2,
"code": 1}}
So that no "." in the key name.
"""
for _, task in wgdata["tasks"].items():
Expand Down
3 changes: 1 addition & 2 deletions tests/test_python.py
Original file line number Diff line number Diff line change
Expand Up @@ -436,7 +436,7 @@ def add(x, y):
wg = WorkGraph("test_PythonJob_retrieve_files")
wg.add_task("PythonJob", function=add, name="add")
# ------------------------- Submit the calculation -------------------
wg.submit(
wg.run(
inputs={
"add": {
"x": 2,
Expand All @@ -450,7 +450,6 @@ def add(x, y):
},
},
},
wait=True,
)
assert (
"result.txt" in wg.tasks["add"].outputs["retrieved"].value.list_object_names()
Expand Down
2 changes: 1 addition & 1 deletion tests/test_while.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ def test_while_workgraph(decorated_add, decorated_multiply, decorated_compare):
wg.workgraph_type = "WHILE"
wg.conditions = ["compare1.result"]
wg.context = {"n": 1}
wg.max_iteration = 10
wg.max_iteration = 5
wg.add_task(decorated_compare, name="compare1", x="{{n}}", y=20)
multiply1 = wg.add_task(
decorated_multiply, name="multiply1", x="{{ n }}", y=orm.Int(2)
Expand Down
Loading