-
Notifications
You must be signed in to change notification settings - Fork 2
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Adds the extensions patch functionality #101
Changes from all commits
809a712
cf5a5ae
165d42a
d457baf
e3c5776
80416c2
dc5c216
8082f1f
d10bbb1
6c7168d
a6636ce
566976e
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,174 @@ | ||
from typing import Optional | ||
|
||
import pytest | ||
|
||
|
||
@pytest.fixture | ||
def temp_extension_registry(): | ||
from wsimod.extensions import extensions_registry | ||
|
||
bkp = extensions_registry.copy() | ||
extensions_registry.clear() | ||
yield | ||
extensions_registry.clear() | ||
extensions_registry.update(bkp) | ||
|
||
|
||
def test_register_node_patch(temp_extension_registry): | ||
from wsimod.extensions import extensions_registry, register_node_patch | ||
|
||
# Define a dummy function to patch a node method | ||
@register_node_patch("node_name", "method_name") | ||
def dummy_patch(): | ||
print("Patched method") | ||
|
||
# Check if the patch is registered correctly | ||
assert extensions_registry[("node_name", "method_name", None, False)] == dummy_patch | ||
|
||
# Another function with other arguments | ||
@register_node_patch("node_name", "method_name", item="default", is_attr=True) | ||
def another_dummy_patch(): | ||
print("Another patched method") | ||
|
||
# Check if this other patch is registered correctly | ||
assert ( | ||
extensions_registry[("node_name", "method_name", "default", True)] | ||
== another_dummy_patch | ||
) | ||
|
||
|
||
def test_apply_patches(temp_extension_registry): | ||
from wsimod.arcs.arcs import Arc | ||
from wsimod.extensions import ( | ||
apply_patches, | ||
extensions_registry, | ||
register_node_patch, | ||
) | ||
from wsimod.nodes import Node | ||
from wsimod.orchestration.model import Model | ||
|
||
# Create a dummy model | ||
node = Node("dummy_node") | ||
node.dummy_arc = Arc("dummy_arc", in_port=node, out_port=node) | ||
model = Model() | ||
model.nodes[node.name] = node | ||
|
||
# 1. Patch a method | ||
@register_node_patch("dummy_node", "apply_overrides") | ||
def dummy_patch(): | ||
pass | ||
|
||
# 2. Patch an attribute | ||
@register_node_patch("dummy_node", "t", is_attr=True) | ||
def another_dummy_patch(node): | ||
return f"A pathced attribute for {node.name}" | ||
|
||
# 3. Patch a method with an item | ||
@register_node_patch("dummy_node", "pull_set_handler", item="default") | ||
def yet_another_dummy_patch(): | ||
pass | ||
|
||
# 4. Path a method of an attribute | ||
@register_node_patch("dummy_node", "dummy_arc.arc_mass_balance") | ||
def arc_dummy_patch(): | ||
pass | ||
|
||
# Check if all patches are registered | ||
assert len(extensions_registry) == 4 | ||
|
||
# Apply the patches | ||
apply_patches(model) | ||
|
||
# Verify that the patches are applied correctly | ||
assert ( | ||
model.nodes[node.name].apply_overrides.__qualname__ == dummy_patch.__qualname__ | ||
) | ||
assert ( | ||
model.nodes[node.name]._patched_apply_overrides.__qualname__ | ||
== "Node.apply_overrides" | ||
) | ||
assert model.nodes[node.name].t == another_dummy_patch(node) | ||
assert model.nodes[node.name]._patched_t == None | ||
assert ( | ||
model.nodes[node.name].pull_set_handler["default"].__qualname__ | ||
== yet_another_dummy_patch.__qualname__ | ||
) | ||
assert ( | ||
model.nodes[node.name].dummy_arc.arc_mass_balance.__qualname__ | ||
== arc_dummy_patch.__qualname__ | ||
) | ||
assert ( | ||
model.nodes[node.name].dummy_arc._patched_arc_mass_balance.__qualname__ | ||
== "Arc.arc_mass_balance" | ||
) | ||
|
||
|
||
def assert_dict_almost_equal(d1: dict, d2: dict, tol: Optional[float] = None): | ||
"""Check if two dictionaries are almost equal. | ||
|
||
Args: | ||
d1 (dict): The first dictionary. | ||
d2 (dict): The second dictionary. | ||
tol (float | None, optional): Relative tolerance. Defaults to 1e-6, | ||
`pytest.approx` default. | ||
""" | ||
for key in d1.keys(): | ||
assert d1[key] == pytest.approx(d2[key], rel=tol) | ||
|
||
|
||
def test_path_method_with_reuse(temp_extension_registry): | ||
from wsimod.arcs.arcs import Arc | ||
from wsimod.extensions import apply_patches, register_node_patch | ||
from wsimod.nodes.storage import Reservoir | ||
from wsimod.orchestration.model import Model | ||
|
||
# Create a dummy model | ||
node = Reservoir(name="dummy_node", initial_storage=10, capacity=10) | ||
node.dummy_arc = Arc("dummy_arc", in_port=node, out_port=node) | ||
|
||
vq = node.pull_distributed({"volume": 5}) | ||
assert_dict_almost_equal(vq, node.v_change_vqip(node.empty_vqip(), 5)) | ||
|
||
model = Model() | ||
model.nodes[node.name] = node | ||
|
||
@register_node_patch("dummy_node", "pull_distributed") | ||
def new_pull_distributed(self, vqip, of_type=None, tag="default"): | ||
return self._patched_pull_distributed(vqip, of_type=["Node"], tag=tag) | ||
|
||
# Apply the patches | ||
apply_patches(model) | ||
|
||
# Check appropriate result | ||
assert node.tank.storage["volume"] == 5 | ||
vq = model.nodes[node.name].pull_distributed({"volume": 5}) | ||
assert_dict_almost_equal(vq, node.empty_vqip()) | ||
assert node.tank.storage["volume"] == 5 | ||
|
||
|
||
def test_handler_extensions(temp_extension_registry): | ||
from wsimod.arcs.arcs import Arc | ||
from wsimod.extensions import apply_patches, register_node_patch | ||
from wsimod.nodes import Node | ||
from wsimod.orchestration.model import Model | ||
|
||
# Create a dummy model | ||
node = Node("dummy_node") | ||
node.dummy_arc = Arc("dummy_arc", in_port=node, out_port=node) | ||
model = Model() | ||
model.nodes[node.name] = node | ||
|
||
# 1. Patch a handler | ||
@register_node_patch("dummy_node", "pull_check_handler", item="default") | ||
def dummy_patch(self, *args, **kwargs): | ||
return "dummy_patch" | ||
|
||
# 2. Patch a handler with access to self | ||
@register_node_patch("dummy_node", "pull_set_handler", item="default") | ||
def dummy_patch(self, vqip, *args, **kwargs): | ||
return f"{self.name} - {vqip['volume']}" | ||
|
||
apply_patches(model) | ||
|
||
assert node.pull_check() == "dummy_patch" | ||
assert node.pull_set({"volume": 1}) == "dummy_node - 1" |
Original file line number | Diff line number | Diff line change | ||||
---|---|---|---|---|---|---|
@@ -0,0 +1,117 @@ | ||||||
"""This module contains the utilities to extend WSMOD with new features. | ||||||
|
||||||
The `register_node_patch` decorator is used to register a function that will be used | ||||||
instead of a method or attribute of a node. The `apply_patches` function applies all | ||||||
registered patches to a model. | ||||||
|
||||||
Example of patching a method: | ||||||
|
||||||
`empty_distributed` will be called instead of `pull_distributed` of "my_node": | ||||||
|
||||||
>>> from wsimod.extensions import register_node_patch, apply_patches | ||||||
>>> @register_node_patch("my_node", "pull_distributed") | ||||||
>>> def empty_distributed(self, vqip): | ||||||
>>> return {} | ||||||
|
||||||
Attributes, methods of the node, and sub-attributes can be patched. Also, an item of a | ||||||
list or a dictionary can be patched if the item argument is provided. | ||||||
|
||||||
Example of patching an attribute: | ||||||
|
||||||
`10` will be assigned to `t`: | ||||||
|
||||||
>>> @register_node_patch("my_node", "t", is_attr=True) | ||||||
>>> def patch_t(node): | ||||||
>>> return 10 | ||||||
|
||||||
Example of patching an attribute item: | ||||||
|
||||||
`patch_default_pull_set_handler` will be assigned to | ||||||
`pull_set_handler["default"]`: | ||||||
|
||||||
>>> @register_node_patch("my_node", "pull_set_handler", item="default") | ||||||
>>> def patch_default_pull_set_handler(self, vqip): | ||||||
>>> return {} | ||||||
|
||||||
If patching a method of an attribute, the `is_attr` argument should be set to `True` and | ||||||
the target should include the attribute name and the method name, all separated by | ||||||
periods, eg. `attribute_name.method_name`. | ||||||
|
||||||
It should be noted that the patched function should have the same signature as the | ||||||
original method or attribute, and the return type should be the same as well, otherwise | ||||||
there will be a runtime error. In particular, the first argument of the patched function | ||||||
should be the node object itself, which will typically be named `self`. | ||||||
|
||||||
The overridden method or attribute can be accessed within the patched function using the | ||||||
`_patched_{method_name}` attribute of the object, eg. `self._patched_pull_distributed`. | ||||||
The exception to this is when patching an item, in which case the original item is no | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||
available to be used within the overriding function. | ||||||
|
||||||
Finally, the `apply_patches` is called within the `Model.load` method and will apply all | ||||||
patches in the order they were registered. This means that users need to be careful with | ||||||
the order of the patches in their extensions files, as they may have interdependencies. | ||||||
|
||||||
TODO: Update documentation on extensions files. | ||||||
""" | ||||||
from typing import Callable, Hashable | ||||||
|
||||||
from .orchestration.model import Model | ||||||
|
||||||
extensions_registry: dict[tuple[str, Hashable, bool], Callable] = {} | ||||||
|
||||||
|
||||||
def register_node_patch( | ||||||
node_name: str, target: str, item: Hashable = None, is_attr: bool = False | ||||||
) -> Callable: | ||||||
"""Register a function to patch a node method or any of its attributes. | ||||||
|
||||||
Args: | ||||||
node_name (str): The name of the node to patch. | ||||||
target (str): The target of the object to patch in the form of a string with the | ||||||
attribute, sub-attribute, etc. and finally method (or attribute) to replace, | ||||||
sepparated with period, eg. `make_discharge` or | ||||||
`sewer_tank.pull_storage_exact`. | ||||||
item (Hashable): Typically a string or an integer indicating the item to replace | ||||||
in the selected attribue, which should be a list or a dictionary. | ||||||
is_attr (bool): If True, the decorated function will be called when applying | ||||||
the patch and the result assigned to the target, instead of assigning the | ||||||
function itself. In this case, the only argument passed to the function is | ||||||
the node object. | ||||||
""" | ||||||
target_id = (node_name, target, item, is_attr) | ||||||
if target_id in extensions_registry: | ||||||
raise ValueError(f"Patch for {target} already registered.") | ||||||
|
||||||
def decorator(func): | ||||||
extensions_registry[target_id] = func | ||||||
return func | ||||||
|
||||||
return decorator | ||||||
|
||||||
|
||||||
def apply_patches(model: Model) -> None: | ||||||
"""Apply all registered patches to the model. | ||||||
|
||||||
TODO: Validate signature of the patched methods and type of patched attributes. | ||||||
|
||||||
Args: | ||||||
model (Model): The model to apply the patches to. | ||||||
""" | ||||||
for (node_name, target, item, is_attr), func in extensions_registry.items(): | ||||||
starget = target.split(".") | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. there isn't anything to stop nodes have a Perhaps could we have Not sure - what do you think? If it's too awful then at least we validate to ensure no There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Not a problem. That's an easy fix. I just put it all together in a single line because I felt it was easier to understand and to cover more cases - in particular the sub-attributes one - in one, consistent approach. About users using There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Actually, you're right. There's no point on artificially restricting what a node name can be. What about changing the decorator signature to? def register_node_patch(
node_name: str, target: str, item: Hashable = None, is_attr: bool = False
) -> Callable: So There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. works for me! |
||||||
method = starget.pop() | ||||||
|
||||||
# Get the member to patch | ||||||
node = obj = model.nodes[node_name] | ||||||
for attr in starget: | ||||||
obj = getattr(obj, attr) | ||||||
|
||||||
# Apply the patch | ||||||
if item is not None: | ||||||
obj = getattr(obj, method) | ||||||
obj[item] = func(node) if is_attr else func.__get__(node, node.__class__) | ||||||
else: | ||||||
setattr(obj, f"_patched_{method}", getattr(obj, method)) | ||||||
setattr( | ||||||
obj, method, func(node) if is_attr else func.__get__(obj, obj.__class__) | ||||||
) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@dalonsoa I've added a test to update handlers in different ways - if you are happy that these are suitable then could you update the docstring accordingly please