Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

mnt: nodes context settings are no longer a regular input #547

Merged
merged 2 commits into from
Mar 8, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 12 additions & 20 deletions docs/nodes/nodes_intro.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@
"id": "62e9a61b",
"metadata": {},
"source": [
"You can also set a node's special `automatic_recalc` input to `True`. In this way, each time the input is changed, the node will recompute."
"A node's context defines how it behaves. One of the context keys is `lazy`, which determines whether the node should be recomputed each time its inputs change. By default it is `True`, meaning the node waits until its output is actually needed before computing. It can be set to `False` so that the node recomputes immediately whenever an input changes."
]
},
{
Expand All @@ -169,7 +169,9 @@
"metadata": {},
"outputs": [],
"source": [
"auto_result = my_sum(2, 5, automatic_recalc=True)\n",
"auto_result = my_sum(2, 5)\n",
"\n",
"auto_result.context.update(lazy=False)\n",
" \n",
"auto_result.get()\n",
"auto_result.update_inputs(a=8)"
Expand Down Expand Up @@ -261,7 +263,8 @@
"first_val = my_sum(2, 5)\n",
"# Use the first value to compute our final value, which we want to\n",
"# automatically recompute when there are changes.\n",
"final_val = my_sum(first_val, 5, automatic_recalc=True)\n",
"final_val = my_sum(first_val, 5)\n",
"final_val.context.update(lazy=False)\n",
"\n",
"# Get the value\n",
"final_val.get()"
Expand Down Expand Up @@ -516,14 +519,13 @@
"metadata": {},
"outputs": [],
"source": [
"@Node.from_func\n",
"def alert_change(val: int, automatic_recalc=True):\n",
"@Node.from_func(context={\"lazy\": False})\n",
"def alert_change(val: int):\n",
" print(f\"VALUE CHANGED, it now is {val}\")\n",
" \n",
"# We feed the node that produces the intermediate value into our alert node \n",
"my_alert = alert_change(result.nodes.first_val)\n",
" \n",
" \n",
"# Now when we update the inputs of the workflow, the node will propagate the information through\n",
"# our new node.\n",
"result.update_inputs(a=10)"
Expand All @@ -534,7 +536,7 @@
"id": "7b266dbf",
"metadata": {},
"source": [
"It sometimes might be useful to provide methods for a workflow. For that case, workflows can also be defined with class syntax, defining the workflow as a static method named `function`."
"It sometimes might be useful to provide methods for a workflow. For that case, workflows can also be defined with class syntax, passing the workflow as a static method in the `function` method."
]
},
{
Expand All @@ -548,7 +550,7 @@
" \n",
" # Define the function that runs the workflow, exactly as we did before.\n",
" @staticmethod\n",
" def _workflow(a: int, b: int, c: int):\n",
" def function(a: int, b: int, c: int):\n",
" first_val = my_sum(a, b)\n",
" return my_sum(first_val, c)\n",
" \n",
Expand Down Expand Up @@ -624,16 +626,6 @@
"sum_triple.network.visualize(notebook=True, )"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b554e051",
"metadata": {},
"outputs": [],
"source": [
"sum_triple.network.to_pyvis(notebook=True)"
]
},
{
"cell_type": "markdown",
"id": "21ccbd4c",
Expand Down Expand Up @@ -663,7 +655,7 @@
{
"cell_type": "code",
"execution_count": null,
"id": "9a5b336b",
"id": "39c74b3c",
"metadata": {},
"outputs": [],
"source": []
Expand All @@ -685,7 +677,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.9"
"version": "3.8.12"
}
},
"nbformat": 4,
Expand Down
2 changes: 1 addition & 1 deletion sisl/nodes/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
from .node import Node
from .workflow import Workflow
from .context import lazy_context, set_lazy_computation
from .context import NodeContext, SISL_NODES_CONTEXT, temporal_context
115 changes: 91 additions & 24 deletions sisl/nodes/context.py
Original file line number Diff line number Diff line change
@@ -1,36 +1,103 @@
import contextlib
from collections import ChainMap
from typing import Any, Union

def set_lazy_computation(nodes: bool = True, workflows: bool = True):
"""Set the lazy computation mode for nodes and workflows.
# The main sisl nodes context that all nodes will use by default as their base.
SISL_NODES_CONTEXT = dict(
# Whether the nodes should compute lazily or immediately when inputs are updated.
lazy=True,
# On initialization, should the node compute? If None, defaults to `lazy`.
lazy_init=None,
# Debugging options
debug=False,
debug_show_inputs=False
)

# Temporal contexts stack. It should not be used directly by users, the aim of this
# stack is to populate it when context managers are used. This is a chainmap and
# not a simple dict because we might have nested context managers.
_TEMPORAL_CONTEXTS = ChainMap()

class NodeContext(ChainMap):
"""Extension of Chainmap that always checks on the temporal context first.

Parameters
----------
nodes: bool, optional
Whether lazy computation is turned on for nodes.
workflows: bool, optional
Whether lazy computation is turned on for workflows.
Using this class is equivalent to forcing users to have the temporal context
always in the first position of the chainmap. Since this is not a very nice
thing to force on users, we use this class instead.

Keys:
lazy: bool
If `False`, nodes will automatically recompute if any of their inputs
have changed, even if no other node needs their output yet.
lazy_init: bool or None
Whether the node should compute on initialization. If None, defaults to
`lazy`.
debug: bool
Whether to print debugging information.
debug_show_inputs: bool
Whether to print the inputs of the node when debugging.
"""
from .node import Node
from .workflow import Workflow

Node._lazy_computation = nodes
Workflow._lazy_computation = workflows
def __getitem__(self, key: str):
if key in _TEMPORAL_CONTEXTS:
return _TEMPORAL_CONTEXTS[key]
else:
return super().__getitem__(key)

@contextlib.contextmanager
def lazy_context(nodes: bool = True, workflows: bool = True):
from .node import Node
from .workflow import Workflow
def temporal_context(context: Union[dict, ChainMap, None] = None, **context_keys: Any):
"""Sets a context temporarily (until the context manager is exited).

Parameters
----------
context: dict or ChainMap, optional
The context that should be updated temporarily. This could for example be
sisl's main context or the context of a specific node class.

If None, the keys and values are forced on all nodes.
**context_keys: Any
The keys and values that should be used for the nodes context.

old_lazy = {
"nodes": Node._lazy_computation,
"workflows": Workflow._lazy_computation,
}
Examples
--------
Forcing a certain context on all nodes:

set_lazy_computation(nodes, workflows)
>>> from sisl.nodes import temporal_context
>>> with temporal_context(lazy=False):
>>> # If a node class is called here, the computation will be performed
>>> # immediately and the result returned.

Switching off lazy behavior for workflows:

>>> from sisl.nodes import Workflow, temporal_context
>>> with temporal_context(context=Workflow.context, lazy=False):
>>> # If a workflow is called here, the computation will be performed
>>> # immediately and the result returned, unless that specific workflow
>>> # class overwrites the lazy behavior.

"""
if context is not None:
# We have to temporally update a context dictionary. We keep a copy of the
# original context so that we can restore it later.
old_context = {k: context[k] for k in context_keys}
context.update(context_keys)

def _restore():
# Restore the original context.
context.update(old_context)
else:
# Add this temporal context on top of the temporal contexts stack.
_TEMPORAL_CONTEXTS.maps.insert(0, context_keys)

def _restore():
# Remove the temporal context from the stack.
del _TEMPORAL_CONTEXTS.maps[0]

# We have entered the context, execute whatever code is inside the "with" block.
try:
yield
_restore()
except Exception as e:
set_lazy_computation(**old_lazy)
raise e

set_lazy_computation(**old_lazy)
# The block has raised an exception, restore the context and re-raise.
_restore()
raise e
Loading