diff --git a/tests/test_extensions.py b/tests/test_extensions.py index 7f76f681..ec243583 100644 --- a/tests/test_extensions.py +++ b/tests/test_extensions.py @@ -172,3 +172,36 @@ def dummy_patch(self, vqip, *args, **kwargs): assert node.pull_check() == "dummy_patch" assert node.pull_set({"volume": 1}) == "dummy_node - 1" + + +def test_custom_class(): + """Test a custom class.""" + + import tempfile + + from wsimod.nodes.nodes import Node, NODES_REGISTRY + from wsimod.orchestration.model import Model, to_datetime + + class CustomNode(Node): + def __init__(self, name): + super().__init__(name) + self.custom_attr = 1 + + def end_timestep(self): + self.custom_attr += 1 + super().end_timestep() + + NODES_REGISTRY["CustomNode"] = CustomNode + + with tempfile.TemporaryDirectory() as temp_dir: + model = Model() + model.nodes["node_name"] = CustomNode("node_name") + model.save(temp_dir) + + del model + model = Model() + model.load(temp_dir) + model.river_dishcarge_order = [] + assert model.nodes["node_name"].custom_attr == 1 + model.run(dates=[to_datetime("2000-01-01")]) + assert model.nodes["node_name"].custom_attr == 2 diff --git a/wsimod/orchestration/model.py b/wsimod/orchestration/model.py index 5b3347f1..df6b9858 100644 --- a/wsimod/orchestration/model.py +++ b/wsimod/orchestration/model.py @@ -193,6 +193,7 @@ def load(self, address, config_name="config.yml", overrides={}): FLAG: E.G. ADDITION FOR NEW ORCHESTRATION """ + load_extension_files(data.get("extensions", [])) if "orchestration" in data.keys(): # Update orchestration @@ -220,7 +221,6 @@ def load(self, address, config_name="config.yml", overrides={}): if "dates" in data.keys(): self.dates = [to_datetime(x) for x in data["dates"]] - load_extension_files(data.get("extensions", [])) apply_patches(self) def save(self, address, config_name="config.yml", compress=False): @@ -497,6 +497,7 @@ def add_arcs(self, arclist): ]: river_arcs[name] = self.arcs[name] + self.river_discharge_order = [] if any(river_arcs): upstreamness = ( {x: 0 for x in self.nodes_type["Waste"].keys()} @@ -505,7 +506,6 @@ def add_arcs(self, arclist): ) upstreamness = self.assign_upstream(river_arcs, upstreamness) - self.river_discharge_order = [] if "River" in self.nodes_type: for node in sorted( upstreamness.items(), key=lambda item: item[1], reverse=True