Skip to content

Commit

Permalink
Merge pull request #101 from ImperialCollegeLondon/extensions_patch
Browse files Browse the repository at this point in the history
Adds the extensions patch functionality
  • Loading branch information
dalonsoa authored Sep 11, 2024
2 parents 05ead2f + 566976e commit e74c6e7
Show file tree
Hide file tree
Showing 3 changed files with 295 additions and 0 deletions.
174 changes: 174 additions & 0 deletions tests/test_extensions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,174 @@
from typing import Optional

import pytest


@pytest.fixture
def temp_extension_registry():
from wsimod.extensions import extensions_registry

bkp = extensions_registry.copy()
extensions_registry.clear()
yield
extensions_registry.clear()
extensions_registry.update(bkp)


def test_register_node_patch(temp_extension_registry):
from wsimod.extensions import extensions_registry, register_node_patch

# Define a dummy function to patch a node method
@register_node_patch("node_name", "method_name")
def dummy_patch():
print("Patched method")

# Check if the patch is registered correctly
assert extensions_registry[("node_name", "method_name", None, False)] == dummy_patch

# Another function with other arguments
@register_node_patch("node_name", "method_name", item="default", is_attr=True)
def another_dummy_patch():
print("Another patched method")

# Check if this other patch is registered correctly
assert (
extensions_registry[("node_name", "method_name", "default", True)]
== another_dummy_patch
)


def test_apply_patches(temp_extension_registry):
from wsimod.arcs.arcs import Arc
from wsimod.extensions import (
apply_patches,
extensions_registry,
register_node_patch,
)
from wsimod.nodes import Node
from wsimod.orchestration.model import Model

# Create a dummy model
node = Node("dummy_node")
node.dummy_arc = Arc("dummy_arc", in_port=node, out_port=node)
model = Model()
model.nodes[node.name] = node

# 1. Patch a method
@register_node_patch("dummy_node", "apply_overrides")
def dummy_patch():
pass

# 2. Patch an attribute
@register_node_patch("dummy_node", "t", is_attr=True)
def another_dummy_patch(node):
return f"A pathced attribute for {node.name}"

# 3. Patch a method with an item
@register_node_patch("dummy_node", "pull_set_handler", item="default")
def yet_another_dummy_patch():
pass

# 4. Path a method of an attribute
@register_node_patch("dummy_node", "dummy_arc.arc_mass_balance")
def arc_dummy_patch():
pass

# Check if all patches are registered
assert len(extensions_registry) == 4

# Apply the patches
apply_patches(model)

# Verify that the patches are applied correctly
assert (
model.nodes[node.name].apply_overrides.__qualname__ == dummy_patch.__qualname__
)
assert (
model.nodes[node.name]._patched_apply_overrides.__qualname__
== "Node.apply_overrides"
)
assert model.nodes[node.name].t == another_dummy_patch(node)
assert model.nodes[node.name]._patched_t == None
assert (
model.nodes[node.name].pull_set_handler["default"].__qualname__
== yet_another_dummy_patch.__qualname__
)
assert (
model.nodes[node.name].dummy_arc.arc_mass_balance.__qualname__
== arc_dummy_patch.__qualname__
)
assert (
model.nodes[node.name].dummy_arc._patched_arc_mass_balance.__qualname__
== "Arc.arc_mass_balance"
)


def assert_dict_almost_equal(d1: dict, d2: dict, tol: Optional[float] = None):
"""Check if two dictionaries are almost equal.
Args:
d1 (dict): The first dictionary.
d2 (dict): The second dictionary.
tol (float | None, optional): Relative tolerance. Defaults to 1e-6,
`pytest.approx` default.
"""
for key in d1.keys():
assert d1[key] == pytest.approx(d2[key], rel=tol)


def test_path_method_with_reuse(temp_extension_registry):
from wsimod.arcs.arcs import Arc
from wsimod.extensions import apply_patches, register_node_patch
from wsimod.nodes.storage import Reservoir
from wsimod.orchestration.model import Model

# Create a dummy model
node = Reservoir(name="dummy_node", initial_storage=10, capacity=10)
node.dummy_arc = Arc("dummy_arc", in_port=node, out_port=node)

vq = node.pull_distributed({"volume": 5})
assert_dict_almost_equal(vq, node.v_change_vqip(node.empty_vqip(), 5))

model = Model()
model.nodes[node.name] = node

@register_node_patch("dummy_node", "pull_distributed")
def new_pull_distributed(self, vqip, of_type=None, tag="default"):
return self._patched_pull_distributed(vqip, of_type=["Node"], tag=tag)

# Apply the patches
apply_patches(model)

# Check appropriate result
assert node.tank.storage["volume"] == 5
vq = model.nodes[node.name].pull_distributed({"volume": 5})
assert_dict_almost_equal(vq, node.empty_vqip())
assert node.tank.storage["volume"] == 5


def test_handler_extensions(temp_extension_registry):
from wsimod.arcs.arcs import Arc
from wsimod.extensions import apply_patches, register_node_patch
from wsimod.nodes import Node
from wsimod.orchestration.model import Model

# Create a dummy model
node = Node("dummy_node")
node.dummy_arc = Arc("dummy_arc", in_port=node, out_port=node)
model = Model()
model.nodes[node.name] = node

# 1. Patch a handler
@register_node_patch("dummy_node", "pull_check_handler", item="default")
def dummy_patch(self, *args, **kwargs):
return "dummy_patch"

# 2. Patch a handler with access to self
@register_node_patch("dummy_node", "pull_set_handler", item="default")
def dummy_patch(self, vqip, *args, **kwargs):
return f"{self.name} - {vqip['volume']}"

apply_patches(model)

assert node.pull_check() == "dummy_patch"
assert node.pull_set({"volume": 1}) == "dummy_node - 1"
117 changes: 117 additions & 0 deletions wsimod/extensions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
"""This module contains the utilities to extend WSMOD with new features.
The `register_node_patch` decorator is used to register a function that will be used
instead of a method or attribute of a node. The `apply_patches` function applies all
registered patches to a model.
Example of patching a method:
`empty_distributed` will be called instead of `pull_distributed` of "my_node":
>>> from wsimod.extensions import register_node_patch, apply_patches
>>> @register_node_patch("my_node", "pull_distributed")
>>> def empty_distributed(self, vqip):
>>> return {}
Attributes, methods of the node, and sub-attributes can be patched. Also, an item of a
list or a dictionary can be patched if the item argument is provided.
Example of patching an attribute:
`10` will be assigned to `t`:
>>> @register_node_patch("my_node", "t", is_attr=True)
>>> def patch_t(node):
>>> return 10
Example of patching an attribute item:
`patch_default_pull_set_handler` will be assigned to
`pull_set_handler["default"]`:
>>> @register_node_patch("my_node", "pull_set_handler", item="default")
>>> def patch_default_pull_set_handler(self, vqip):
>>> return {}
If patching a method of an attribute, the `is_attr` argument should be set to `True` and
the target should include the attribute name and the method name, all separated by
periods, eg. `attribute_name.method_name`.
It should be noted that the patched function should have the same signature as the
original method or attribute, and the return type should be the same as well, otherwise
there will be a runtime error. In particular, the first argument of the patched function
should be the node object itself, which will typically be named `self`.
The overridden method or attribute can be accessed within the patched function using the
`_patched_{method_name}` attribute of the object, eg. `self._patched_pull_distributed`.
The exception to this is when patching an item, in which case the original item is no
available to be used within the overriding function.
Finally, the `apply_patches` is called within the `Model.load` method and will apply all
patches in the order they were registered. This means that users need to be careful with
the order of the patches in their extensions files, as they may have interdependencies.
TODO: Update documentation on extensions files.
"""
from typing import Callable, Hashable

from .orchestration.model import Model

extensions_registry: dict[tuple[str, Hashable, bool], Callable] = {}


def register_node_patch(
node_name: str, target: str, item: Hashable = None, is_attr: bool = False
) -> Callable:
"""Register a function to patch a node method or any of its attributes.
Args:
node_name (str): The name of the node to patch.
target (str): The target of the object to patch in the form of a string with the
attribute, sub-attribute, etc. and finally method (or attribute) to replace,
sepparated with period, eg. `make_discharge` or
`sewer_tank.pull_storage_exact`.
item (Hashable): Typically a string or an integer indicating the item to replace
in the selected attribue, which should be a list or a dictionary.
is_attr (bool): If True, the decorated function will be called when applying
the patch and the result assigned to the target, instead of assigning the
function itself. In this case, the only argument passed to the function is
the node object.
"""
target_id = (node_name, target, item, is_attr)
if target_id in extensions_registry:
raise ValueError(f"Patch for {target} already registered.")

def decorator(func):
extensions_registry[target_id] = func
return func

return decorator


def apply_patches(model: Model) -> None:
"""Apply all registered patches to the model.
TODO: Validate signature of the patched methods and type of patched attributes.
Args:
model (Model): The model to apply the patches to.
"""
for (node_name, target, item, is_attr), func in extensions_registry.items():
starget = target.split(".")
method = starget.pop()

# Get the member to patch
node = obj = model.nodes[node_name]
for attr in starget:
obj = getattr(obj, attr)

# Apply the patch
if item is not None:
obj = getattr(obj, method)
obj[item] = func(node) if is_attr else func.__get__(node, node.__class__)
else:
setattr(obj, f"_patched_{method}", getattr(obj, method))
setattr(
obj, method, func(node) if is_attr else func.__get__(obj, obj.__class__)
)
4 changes: 4 additions & 0 deletions wsimod/orchestration/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,8 @@ def load(self, address, config_name="config.yml", overrides={}):
config_name:
overrides:
"""
from ..extensions import apply_patches

with open(os.path.join(address, config_name), "r") as file:
data = yaml.safe_load(file)

Expand Down Expand Up @@ -191,6 +193,8 @@ def load(self, address, config_name="config.yml", overrides={}):
if "dates" in data.keys():
self.dates = [to_datetime(x) for x in data["dates"]]

apply_patches(self)

def save(self, address, config_name="config.yml", compress=False):
"""Save the model object to a yaml file and input data to csv.gz format in the
directory specified.
Expand Down

0 comments on commit e74c6e7

Please sign in to comment.