diff --git a/tests/custom_class.py b/tests/custom_class.py new file mode 100644 index 00000000..6c510a9f --- /dev/null +++ b/tests/custom_class.py @@ -0,0 +1,12 @@ +from wsimod.nodes.nodes import Node + +class CustomNode(Node): + """A custom node.""" + def __init__(self, name): + """Initialise the node.""" + super().__init__(name) + self.custom_attr = 1 + + def end_timestep(self): + self.custom_attr += 1 + super().end_timestep() \ No newline at end of file diff --git a/tests/test_extensions.py b/tests/test_extensions.py index 65e8f779..16bfbb90 100644 --- a/tests/test_extensions.py +++ b/tests/test_extensions.py @@ -174,7 +174,35 @@ def dummy_patch(self, vqip, *args, **kwargs): assert node.pull_set({"volume": 1}) == "dummy_node - 1" -def test_custom_class(): +def test_custom_class_from_file(): + """Test a custom class.""" + from pathlib import Path + import yaml + import tempfile + + from wsimod.nodes.nodes import NODES_REGISTRY + from wsimod.orchestration.model import Model, to_datetime + + # Remove in case it was in there from previous test + NODES_REGISTRY.pop("CustomNode", None) + + with tempfile.TemporaryDirectory() as temp_dir: + config = { + "nodes": {"node_name": {"type_": "CustomNode", "name": "node_name"}}, + "extensions": [str(Path(__file__).parent / "custom_class.py")], + } + + with open(temp_dir + "/config.yml", "w") as f: + yaml.dump(config, f) + + model = Model() + model.load(temp_dir) + assert model.nodes["node_name"].custom_attr == 1 + model.run(dates=[to_datetime("2000-01-01")]) + assert model.nodes["node_name"].custom_attr == 2 + + +def test_custom_class_on_the_fly(): """Test a custom class.""" import tempfile @@ -182,6 +210,9 @@ def test_custom_class(): from wsimod.nodes.nodes import Node, NODES_REGISTRY from wsimod.orchestration.model import Model, to_datetime + # Remove in case it was in there from previous test + NODES_REGISTRY.pop("CustomNode", None) + class CustomNode(Node): def __init__(self, name): super().__init__(name) @@ -191,7 +222,6 @@ def end_timestep(self): self.custom_attr += 1 super().end_timestep() - with tempfile.TemporaryDirectory() as temp_dir: model = Model() model.nodes["node_name"] = CustomNode("node_name") @@ -200,7 +230,6 @@ def end_timestep(self): del model model = Model() model.load(temp_dir) - model.river_dishcarge_order = [] assert model.nodes["node_name"].custom_attr == 1 model.run(dates=[to_datetime("2000-01-01")]) assert model.nodes["node_name"].custom_attr == 2 diff --git a/wsimod/orchestration/model.py b/wsimod/orchestration/model.py index bab31985..cb7f0282 100644 --- a/wsimod/orchestration/model.py +++ b/wsimod/orchestration/model.py @@ -183,10 +183,16 @@ def load(self, address, config_name="config.yml", overrides={}): for key, item in overrides.items(): data[key] = item - constants.POLLUTANTS = data["pollutants"] - constants.ADDITIVE_POLLUTANTS = data["additive_pollutants"] - constants.NON_ADDITIVE_POLLUTANTS = data["non_additive_pollutants"] - constants.FLOAT_ACCURACY = float(data["float_accuracy"]) + constants.POLLUTANTS = data.get("pollutants", constants.POLLUTANTS) + constants.ADDITIVE_POLLUTANTS = data.get( + "additive_pollutants", constants.ADDITIVE_POLLUTANTS + ) + constants.NON_ADDITIVE_POLLUTANTS = data.get( + "non_additive_pollutants", constants.NON_ADDITIVE_POLLUTANTS + ) + constants.FLOAT_ACCURACY = float( + data.get("float_accuracy", constants.FLOAT_ACCURACY) + ) self.__dict__.update(Model().__dict__) """ @@ -199,6 +205,9 @@ def load(self, address, config_name="config.yml", overrides={}): # Update orchestration self.orchestration = data["orchestration"] + if "nodes" not in data.keys(): + raise ValueError("No nodes found in the config") + nodes = data["nodes"] for name, node in nodes.items(): @@ -215,7 +224,7 @@ def load(self, address, config_name="config.yml", overrides={}): ) del surface["filename"] node["surfaces"] = list(node["surfaces"].values()) - arcs = data["arcs"] + arcs = data.get("arcs", {}) self.add_nodes(list(nodes.values())) self.add_arcs(list(arcs.values())) if "dates" in data.keys():