diff --git a/tests/test_model.py b/tests/test_model.py index e5e404c..3ffdbf7 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -4,10 +4,10 @@ @author: Barney """ - -# import pytest +import os +import pytest import unittest -from unittest import TestCase +from unittest import TestCase, mock from wsimod.arcs.arcs import Arc from wsimod.nodes.land import Land @@ -305,5 +305,47 @@ def test_customise_orchestration(self): self.assertListEqual(my_model.orchestration, revised_orchestration) +class TestLoadExtensionFiles: + def test_load_extension_files_valid(self, tmp_path_factory): + from wsimod.orchestration.model import load_extension_files + + with tmp_path_factory.mktemp("extensions") as tempdir: + valid_file = os.path.join(tempdir, "valid_extension.py") + with open(valid_file, "w") as f: + f.write("def test_func(): pass") + + load_extension_files([valid_file]) + + def test_load_extension_files_invalid_extension(self, tmp_path_factory): + from wsimod.orchestration.model import load_extension_files + + with tmp_path_factory.mktemp("extensions") as tempdir: + invalid_file = os.path.join(tempdir, "invalid_extension.txt") + with open(invalid_file, "w") as f: + f.write("This is a text file") + + with pytest.raises(ValueError, match="Only .py files are supported"): + load_extension_files([invalid_file]) + + def test_load_extension_files_nonexistent_file(self): + from wsimod.orchestration.model import load_extension_files + + with pytest.raises( + FileNotFoundError, match="File nonexistent_file.py does not exist" + ): + load_extension_files(["nonexistent_file.py"]) + + def test_load_extension_files_import_error(self, tmp_path_factory): + from wsimod.orchestration.model import load_extension_files + + with tmp_path_factory.mktemp("extensions") as tempdir: + valid_file = os.path.join(tempdir, "valid_extension.py") + with open(valid_file, "w") as f: + f.write("raise ImportError") + + with pytest.raises(ImportError): + load_extension_files([valid_file]) + + if __name__ == "__main__": unittest.main() diff --git a/wsimod/orchestration/model.py b/wsimod/orchestration/model.py index 4026a32..5b3347f 100644 --- a/wsimod/orchestration/model.py +++ b/wsimod/orchestration/model.py @@ -178,7 +178,7 @@ def load(self, address, config_name="config.yml", overrides={}): from ..extensions import apply_patches with open(os.path.join(address, config_name), "r") as file: - data = yaml.safe_load(file) + data: dict = yaml.safe_load(file) for key, item in overrides.items(): data[key] = item @@ -220,6 +220,7 @@ def load(self, address, config_name="config.yml", overrides={}): if "dates" in data.keys(): self.dates = [to_datetime(x) for x in data["dates"]] + load_extension_files(data.get("extensions", [])) apply_patches(self) def save(self, address, config_name="config.yml", compress=False): @@ -1269,3 +1270,27 @@ def yaml2csv(address, config_name="config.yml", csv_folder_name="csv"): writer.writerow( [str(value_[x]) if x in value_.keys() else None for x in fields] ) + + +def load_extension_files(files: list[str]) -> None: + """Load extension files from a list of files. + + Args: + files (list[str]): List of file paths to load + + Raises: + ValueError: If file is not a .py file + FileNotFoundError: If file does not exist + """ + import importlib + from pathlib import Path + + for file in files: + if not file.endswith(".py"): + raise ValueError(f"Only .py files are supported. Invalid file: {file}") + if not Path(file).exists(): + raise FileNotFoundError(f"File {file} does not exist") + + spec = importlib.util.spec_from_file_location("module.name", file) + module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(module)