Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use calcfunction for getitem in For loop workgraph #283

Open
wants to merge 7 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 15 additions & 2 deletions aiida_workgraph/engine/workgraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from aiida.common.extendeddicts import AttributeDict
from aiida.common.lang import override
from aiida import orm
from aiida.orm import load_node, Node, ProcessNode, WorkChainNode
from aiida.orm import load_node, Node, ProcessNode, WorkChainNode, to_aiida_type
from aiida.orm.utils.serialize import deserialize_unsafe, serialize

from aiida.engine.processes.exit_code import ExitCode
Expand All @@ -33,6 +33,7 @@
from aiida.engine import run_get_node
from aiida_workgraph.utils import create_and_pause_process
from aiida_workgraph.task import Task
from aiida_workgraph.decorator import task
from aiida_workgraph.utils import get_nested_dict, update_nested_dict
from aiida_workgraph.executors.monitors import monitor

Expand Down Expand Up @@ -920,8 +921,20 @@ def check_for_conditions(self) -> bool:
if should_run:
self.reset()
self.set_tasks_state(condition_tasks, "SKIPPED")
self.ctx["i"] = self.ctx._sequence[self.ctx._count]
@task.calcfunction()
def __getitem__(iter, key):
#value = kwargs['iter'][kwargs['key'].value]
value = iter[key.value]
if isinstance(value, orm.Data):
return value
else:
return orm.to_aiida_type(value)

key = self.ctx._sequence_keys[self.ctx._count]
self.ctx["i"] = __getitem__(iter=self.ctx._sequence, key=to_aiida_type(key))
self.ctx._count += 1


return should_run

def remove_executed_task(self, name: str) -> None:
Expand Down
Empty file.
65 changes: 65 additions & 0 deletions aiida_workgraph/orm/data/iterator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
from collections.abc import Iterator

from aiida.orm.nodes.data.base import to_aiida_type
from aiida.orm.nodes.data.data import Data
import copy
from aiida.orm.nodes.data import to_aiida_type

__all__ = ('Iterator',)


class AiidaIterator(Data, Iterator):
"""`Data` sub class to represent a iterator."""

_ITERATOR_KEY = 'iter'

def __init__(self, value=None, **kwargs):
"""Initialise a ``Iterator`` node instance.

:param value: iterator to initialise the ``Iterator`` node from
"""
data = value or kwargs.pop('iter', [])
super().__init__(**kwargs)
self.set_iterator(data)

def __next__(self):
iterator = self.get_iterator()
iterator.__next__()
if not self._using_reference():
self.set_iterator(iterator)

def get_iterator(self):
"""Return the contents of this node.

:return: a iterator
"""
return self.base.attributes.get(self._ITERATOR_KEY)

def set_iterator(self, data):
"""Set the contents of this node.

:param data: the iterator to set
"""
if not hasattr(data, "__iter__"):
raise TypeError('Must supply type that implements __iter__')
self.base.attributes.set(self._ITERATOR_KEY, copy.deepcopy(data))

def _using_reference(self):
"""This function tells the class if we are using a iterator reference. This
means that calls to self.get_iterator return a reference rather than a copy
of the underlying iterator and therefore self.set_iterator need not be called.
This knwoledge is essential to make sure this class is performant.

Currently the implementation assumes that if the node needs to be
stored then it is using the attributes cache which is a reference.

:return: True if using self.get_iterator returns a reference to the
underlying sequence. False otherwise.
:rtype: bool
"""
return not self.is_stored

from _collections_abc import list_iterator
@to_aiida_type.register(list_iterator)
def _(value):
return AiidaIterator(value)
34 changes: 32 additions & 2 deletions aiida_workgraph/workgraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
link_deletion_hook,
)
from typing import Any, Dict, List, Optional, Union
from collections.abc import Sequence

if USE_WIDGET:
from aiida_workgraph.widget import NodeGraphWidget
Expand Down Expand Up @@ -48,7 +49,8 @@ def __init__(self, name: str = "WorkGraph", **kwargs) -> None:
super().__init__(name, **kwargs)
self.context = {}
self.workgraph_type = "NORMAL"
self.sequence = []
self._sequence = []
self._sequence_keys = range(0)
self.conditions = []
self.process = None
self.restart_process = None
Expand All @@ -68,6 +70,32 @@ def tasks(self) -> TaskCollection:
"""Add alias to `nodes` for WorkGraph"""
return self.nodes

@property
def sequence(self):
return self._sequence

@sequence.setter
def sequence(self, value):
# We need to store the keys for later use since iterators cannot be stored
# in the provenance as they have a mutable state (pointer to current element).
if isinstance(value, aiida.orm.Dict):
self._sequence = value
self._sequence_keys = value.keys()
elif isinstance(value, aiida.orm.List):
self._sequence = value
self._sequence_keys = range(len(value))
elif isinstance(value, dict) or isinstance(value, aiida.orm.Dict):
self._sequence = aiida.orm.Dict(value)
self._sequence_keys = value.keys()
elif isinstance(value, Sequence):
self._sequence = aiida.orm.List(list(value))
self._sequence_keys = range(len(value))
else:
raise TypeError(
f"Sequence of type {type(value)} is not "
"allowed. Please use a sequence."
)

def prepare_inputs(self, metadata: Optional[Dict[str, Any]]) -> Dict[str, Any]:
from aiida_workgraph.utils import (
merge_properties,
Expand Down Expand Up @@ -180,7 +208,9 @@ def to_dict(self, store_nodes=False) -> Dict[str, Any]:

wgdata = super().to_dict()
# save the sequence and context
self.context["_sequence"] = self.sequence
self.context["_sequence"] = self._sequence
self.context["_sequence_keys"] = self._sequence_keys

# only alphanumeric and underscores are allowed
wgdata["context"] = {
key.replace(".", "__"): value for key, value in self.context.items()
Expand Down
Loading