-
Notifications
You must be signed in to change notification settings - Fork 2
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge branch 'main' into fwtw-overrides
- Loading branch information
Showing
14 changed files
with
1,321 additions
and
921 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
75 changes: 75 additions & 0 deletions
75
docs/demo/examples/test_customise_orchestration_example.yaml
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,75 @@ | ||
orchestration: | ||
- Groundwater: infiltrate | ||
- Sewer: make_discharge | ||
|
||
nodes: | ||
Sewer: | ||
type_: Sewer | ||
name: my_sewer | ||
capacity: 0.04 | ||
|
||
Groundwater: | ||
type_: Groundwater | ||
name: my_groundwater | ||
capacity: 100 | ||
area: 100 | ||
|
||
River: | ||
type_: Node | ||
name: my_river | ||
|
||
Waste: | ||
type_: Waste | ||
name: my_outlet | ||
|
||
arcs: | ||
storm_outflow: | ||
type_: Arc | ||
name: storm_outflow | ||
in_port: my_sewer | ||
out_port: my_river | ||
|
||
baseflow: | ||
type_: Arc | ||
name: baseflow | ||
in_port: my_groundwater | ||
out_port: my_river | ||
|
||
catchment_outflow: | ||
type_: Arc | ||
name: catchment_outflow | ||
in_port: my_river | ||
out_port: my_outlet | ||
|
||
pollutants: | ||
- org-phosphorus | ||
- phosphate | ||
- ammonia | ||
- solids | ||
- temperature | ||
- nitrate | ||
- nitrite | ||
- org-nitrogen | ||
additive_pollutants: | ||
- org-phosphorus | ||
- phosphate | ||
- ammonia | ||
- solids | ||
- nitrate | ||
- nitrite | ||
- org-nitrogen | ||
non_additive_pollutants: | ||
- temperature | ||
float_accuracy: 1.0e-06 | ||
|
||
dates: | ||
- '2000-01-01' | ||
- '2000-01-02' | ||
- '2000-01-03' | ||
- '2000-01-04' | ||
- '2000-01-05' | ||
- '2000-01-06' | ||
- '2000-01-07' | ||
- '2000-01-08' | ||
- '2000-01-09' | ||
- '2000-01-10' |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,174 @@ | ||
from typing import Optional | ||
|
||
import pytest | ||
|
||
|
||
@pytest.fixture | ||
def temp_extension_registry(): | ||
from wsimod.extensions import extensions_registry | ||
|
||
bkp = extensions_registry.copy() | ||
extensions_registry.clear() | ||
yield | ||
extensions_registry.clear() | ||
extensions_registry.update(bkp) | ||
|
||
|
||
def test_register_node_patch(temp_extension_registry): | ||
from wsimod.extensions import extensions_registry, register_node_patch | ||
|
||
# Define a dummy function to patch a node method | ||
@register_node_patch("node_name", "method_name") | ||
def dummy_patch(): | ||
print("Patched method") | ||
|
||
# Check if the patch is registered correctly | ||
assert extensions_registry[("node_name", "method_name", None, False)] == dummy_patch | ||
|
||
# Another function with other arguments | ||
@register_node_patch("node_name", "method_name", item="default", is_attr=True) | ||
def another_dummy_patch(): | ||
print("Another patched method") | ||
|
||
# Check if this other patch is registered correctly | ||
assert ( | ||
extensions_registry[("node_name", "method_name", "default", True)] | ||
== another_dummy_patch | ||
) | ||
|
||
|
||
def test_apply_patches(temp_extension_registry): | ||
from wsimod.arcs.arcs import Arc | ||
from wsimod.extensions import ( | ||
apply_patches, | ||
extensions_registry, | ||
register_node_patch, | ||
) | ||
from wsimod.nodes import Node | ||
from wsimod.orchestration.model import Model | ||
|
||
# Create a dummy model | ||
node = Node("dummy_node") | ||
node.dummy_arc = Arc("dummy_arc", in_port=node, out_port=node) | ||
model = Model() | ||
model.nodes[node.name] = node | ||
|
||
# 1. Patch a method | ||
@register_node_patch("dummy_node", "apply_overrides") | ||
def dummy_patch(): | ||
pass | ||
|
||
# 2. Patch an attribute | ||
@register_node_patch("dummy_node", "t", is_attr=True) | ||
def another_dummy_patch(node): | ||
return f"A pathced attribute for {node.name}" | ||
|
||
# 3. Patch a method with an item | ||
@register_node_patch("dummy_node", "pull_set_handler", item="default") | ||
def yet_another_dummy_patch(): | ||
pass | ||
|
||
# 4. Path a method of an attribute | ||
@register_node_patch("dummy_node", "dummy_arc.arc_mass_balance") | ||
def arc_dummy_patch(): | ||
pass | ||
|
||
# Check if all patches are registered | ||
assert len(extensions_registry) == 4 | ||
|
||
# Apply the patches | ||
apply_patches(model) | ||
|
||
# Verify that the patches are applied correctly | ||
assert ( | ||
model.nodes[node.name].apply_overrides.__qualname__ == dummy_patch.__qualname__ | ||
) | ||
assert ( | ||
model.nodes[node.name]._patched_apply_overrides.__qualname__ | ||
== "Node.apply_overrides" | ||
) | ||
assert model.nodes[node.name].t == another_dummy_patch(node) | ||
assert model.nodes[node.name]._patched_t == None | ||
assert ( | ||
model.nodes[node.name].pull_set_handler["default"].__qualname__ | ||
== yet_another_dummy_patch.__qualname__ | ||
) | ||
assert ( | ||
model.nodes[node.name].dummy_arc.arc_mass_balance.__qualname__ | ||
== arc_dummy_patch.__qualname__ | ||
) | ||
assert ( | ||
model.nodes[node.name].dummy_arc._patched_arc_mass_balance.__qualname__ | ||
== "Arc.arc_mass_balance" | ||
) | ||
|
||
|
||
def assert_dict_almost_equal(d1: dict, d2: dict, tol: Optional[float] = None): | ||
"""Check if two dictionaries are almost equal. | ||
Args: | ||
d1 (dict): The first dictionary. | ||
d2 (dict): The second dictionary. | ||
tol (float | None, optional): Relative tolerance. Defaults to 1e-6, | ||
`pytest.approx` default. | ||
""" | ||
for key in d1.keys(): | ||
assert d1[key] == pytest.approx(d2[key], rel=tol) | ||
|
||
|
||
def test_path_method_with_reuse(temp_extension_registry): | ||
from wsimod.arcs.arcs import Arc | ||
from wsimod.extensions import apply_patches, register_node_patch | ||
from wsimod.nodes.storage import Reservoir | ||
from wsimod.orchestration.model import Model | ||
|
||
# Create a dummy model | ||
node = Reservoir(name="dummy_node", initial_storage=10, capacity=10) | ||
node.dummy_arc = Arc("dummy_arc", in_port=node, out_port=node) | ||
|
||
vq = node.pull_distributed({"volume": 5}) | ||
assert_dict_almost_equal(vq, node.v_change_vqip(node.empty_vqip(), 5)) | ||
|
||
model = Model() | ||
model.nodes[node.name] = node | ||
|
||
@register_node_patch("dummy_node", "pull_distributed") | ||
def new_pull_distributed(self, vqip, of_type=None, tag="default"): | ||
return self._patched_pull_distributed(vqip, of_type=["Node"], tag=tag) | ||
|
||
# Apply the patches | ||
apply_patches(model) | ||
|
||
# Check appropriate result | ||
assert node.tank.storage["volume"] == 5 | ||
vq = model.nodes[node.name].pull_distributed({"volume": 5}) | ||
assert_dict_almost_equal(vq, node.empty_vqip()) | ||
assert node.tank.storage["volume"] == 5 | ||
|
||
|
||
def test_handler_extensions(temp_extension_registry): | ||
from wsimod.arcs.arcs import Arc | ||
from wsimod.extensions import apply_patches, register_node_patch | ||
from wsimod.nodes import Node | ||
from wsimod.orchestration.model import Model | ||
|
||
# Create a dummy model | ||
node = Node("dummy_node") | ||
node.dummy_arc = Arc("dummy_arc", in_port=node, out_port=node) | ||
model = Model() | ||
model.nodes[node.name] = node | ||
|
||
# 1. Patch a handler | ||
@register_node_patch("dummy_node", "pull_check_handler", item="default") | ||
def dummy_patch(self, *args, **kwargs): | ||
return "dummy_patch" | ||
|
||
# 2. Patch a handler with access to self | ||
@register_node_patch("dummy_node", "pull_set_handler", item="default") | ||
def dummy_patch(self, vqip, *args, **kwargs): | ||
return f"{self.name} - {vqip['volume']}" | ||
|
||
apply_patches(model) | ||
|
||
assert node.pull_check() == "dummy_patch" | ||
assert node.pull_set({"volume": 1}) == "dummy_node - 1" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.