Skip to content

Commit

Permalink
add custom class from file
Browse files Browse the repository at this point in the history
  • Loading branch information
Dobson committed Oct 3, 2024
1 parent 1e608d5 commit a0e485e
Show file tree
Hide file tree
Showing 3 changed files with 58 additions and 8 deletions.
12 changes: 12 additions & 0 deletions tests/custom_class.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
from wsimod.nodes.nodes import Node

class CustomNode(Node):
"""A custom node."""
def __init__(self, name):
"""Initialise the node."""
super().__init__(name)
self.custom_attr = 1

def end_timestep(self):
self.custom_attr += 1
super().end_timestep()
35 changes: 32 additions & 3 deletions tests/test_extensions.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,14 +174,45 @@ def dummy_patch(self, vqip, *args, **kwargs):
assert node.pull_set({"volume": 1}) == "dummy_node - 1"


def test_custom_class():
def test_custom_class_from_file():
"""Test a custom class."""
from pathlib import Path
import yaml
import tempfile

from wsimod.nodes.nodes import NODES_REGISTRY
from wsimod.orchestration.model import Model, to_datetime

# Remove in case it was in there from previous test
NODES_REGISTRY.pop("CustomNode", None)

with tempfile.TemporaryDirectory() as temp_dir:
config = {
"nodes": {"node_name": {"type_": "CustomNode", "name": "node_name"}},
"extensions": [str(Path(__file__).parent / "custom_class.py")],
}

with open(temp_dir + "/config.yml", "w") as f:
yaml.dump(config, f)

model = Model()
model.load(temp_dir)
assert model.nodes["node_name"].custom_attr == 1
model.run(dates=[to_datetime("2000-01-01")])
assert model.nodes["node_name"].custom_attr == 2


def test_custom_class_on_the_fly():
"""Test a custom class."""

import tempfile

from wsimod.nodes.nodes import Node, NODES_REGISTRY
from wsimod.orchestration.model import Model, to_datetime

# Remove in case it was in there from previous test
NODES_REGISTRY.pop("CustomNode", None)

class CustomNode(Node):
def __init__(self, name):
super().__init__(name)
Expand All @@ -191,7 +222,6 @@ def end_timestep(self):
self.custom_attr += 1
super().end_timestep()


with tempfile.TemporaryDirectory() as temp_dir:
model = Model()
model.nodes["node_name"] = CustomNode("node_name")
Expand All @@ -200,7 +230,6 @@ def end_timestep(self):
del model
model = Model()
model.load(temp_dir)
model.river_dishcarge_order = []
assert model.nodes["node_name"].custom_attr == 1
model.run(dates=[to_datetime("2000-01-01")])
assert model.nodes["node_name"].custom_attr == 2
19 changes: 14 additions & 5 deletions wsimod/orchestration/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,10 +183,16 @@ def load(self, address, config_name="config.yml", overrides={}):
for key, item in overrides.items():
data[key] = item

constants.POLLUTANTS = data["pollutants"]
constants.ADDITIVE_POLLUTANTS = data["additive_pollutants"]
constants.NON_ADDITIVE_POLLUTANTS = data["non_additive_pollutants"]
constants.FLOAT_ACCURACY = float(data["float_accuracy"])
constants.POLLUTANTS = data.get("pollutants", constants.POLLUTANTS)
constants.ADDITIVE_POLLUTANTS = data.get(
"additive_pollutants", constants.ADDITIVE_POLLUTANTS
)
constants.NON_ADDITIVE_POLLUTANTS = data.get(
"non_additive_pollutants", constants.NON_ADDITIVE_POLLUTANTS
)
constants.FLOAT_ACCURACY = float(
data.get("float_accuracy", constants.FLOAT_ACCURACY)
)
self.__dict__.update(Model().__dict__)

"""
Expand All @@ -199,6 +205,9 @@ def load(self, address, config_name="config.yml", overrides={}):
# Update orchestration
self.orchestration = data["orchestration"]

if "nodes" not in data.keys():
raise ValueError("No nodes found in the config")

nodes = data["nodes"]

for name, node in nodes.items():
Expand All @@ -215,7 +224,7 @@ def load(self, address, config_name="config.yml", overrides={}):
)
del surface["filename"]
node["surfaces"] = list(node["surfaces"].values())
arcs = data["arcs"]
arcs = data.get("arcs", {})
self.add_nodes(list(nodes.values()))
self.add_arcs(list(arcs.values()))
if "dates" in data.keys():
Expand Down

0 comments on commit a0e485e

Please sign in to comment.