diff --git a/test/plugin_test.py b/test/plugin_test.py new file mode 100644 index 00000000..2fd80755 --- /dev/null +++ b/test/plugin_test.py @@ -0,0 +1,71 @@ +import pytest + +from veros.plugins import load_plugin +from veros.routines import veros_routine +from veros.state import get_default_state +from veros.variables import Variable +from veros.settings import Setting + + +@pytest.fixture +def fake_plugin(): + class FakePlugin: + pass + + def run_setup(state): + plugin._setup_ran = True + + def run_main(state): + plugin._main_ran = True + + plugin = FakePlugin() + plugin.__name__ = "foobar" + plugin._setup_ran = False + plugin._main_ran = False + plugin.__VEROS_INTERFACE__ = { + "name": "foo", + "setup_entrypoint": run_setup, + "run_entrypoint": run_main, + "settings": dict(mydimsetting=Setting(15, int, "bar")), + "variables": dict(myvar=Variable("myvar", ("xt", "yt", "mydim"))), + "dimensions": dict(mydim="mydimsetting"), + "diagnostics": [], + } + yield plugin + + +def test_load_plugin(fake_plugin): + plugin_interface = load_plugin(fake_plugin) + assert plugin_interface.name == "foo" + + +def test_state_plugin(fake_plugin): + plugin_interface = load_plugin(fake_plugin) + state = get_default_state(plugin_interfaces=plugin_interface) + assert "mydimsetting" in state.settings + assert "mydim" in state.dimensions + assert state.dimensions["mydim"] == state.settings.mydimsetting + state.initialize_variables() + assert "myvar" in state.variables + assert state.variables.myvar.shape == (4, 4, state.settings.mydimsetting) + + +def test_run_plugin(fake_plugin): + from veros.setups.acc_basic import ACCBasicSetup + + class FakeSetup(ACCBasicSetup): + __veros_plugins__ = (fake_plugin,) + + @veros_routine + def set_diagnostics(self, state): + pass + + setup = FakeSetup(override=dict(dt_tracer=100, runlen=100)) + + assert not fake_plugin._setup_ran + setup.setup() + assert fake_plugin._setup_ran + + assert not fake_plugin._main_ran + setup.run() + assert fake_plugin._main_ran diff --git a/test/setup_test.py b/test/setup_test.py index 64de8029..0ed78c72 100644 --- a/test/setup_test.py +++ b/test/setup_test.py @@ -6,6 +6,10 @@ def set_options(): from veros import runtime_settings object.__setattr__(runtime_settings, "diskless_mode", True) + try: + yield + finally: + object.__setattr__(runtime_settings, "diskless_mode", False) @pytest.mark.parametrize("float_type", ("float32", "float64")) diff --git a/veros/plugins.py b/veros/plugins.py index 55ecec10..b16a4d48 100644 --- a/veros/plugins.py +++ b/veros/plugins.py @@ -12,6 +12,7 @@ "run_entrypoint", "settings", "variables", + "dimensions", "diagnostics", ], ) @@ -37,18 +38,23 @@ def load_plugin(module): if not callable(run_entrypoint): raise RuntimeError(f"module {modname} is missing a valid run entrypoint") - name = interface.get("name", module.__name__) + name = interface.get("name", modname) - settings = interface.get("settings", []) + settings = interface.get("settings", {}) for setting, val in settings.items(): if not isinstance(val, Setting): raise TypeError(f"got unexpected type {type(val)} for setting {setting}") - variables = interface.get("variables", []) + variables = interface.get("variables", {}) for variable, val in variables.items(): if not isinstance(val, Variable): raise TypeError(f"got unexpected type {type(val)} for variable {variable}") + dimensions = interface.get("dimensions", {}) + for dim, val in dimensions.items(): + if not isinstance(val, (str, int)): + raise TypeError(f"got unexpected type {type(val)} for dimension {dim}") + diagnostics = interface.get("diagnostics", []) for diagnostic in diagnostics: if not issubclass(diagnostic, VerosDiagnostic): @@ -61,5 +67,6 @@ def load_plugin(module): run_entrypoint=run_entrypoint, settings=settings, variables=variables, + dimensions=dimensions, diagnostics=diagnostics, ) diff --git a/veros/state.py b/veros/state.py index 703af9a5..a806fe5e 100644 --- a/veros/state.py +++ b/veros/state.py @@ -438,24 +438,24 @@ def to_xarray(self): return xr.Dataset(data_vars, coords=coords, attrs=attrs) -def get_default_state(use_plugins=None): - if use_plugins is not None: - plugin_interfaces = tuple(plugins.load_plugin(p) for p in use_plugins) - else: - plugin_interfaces = tuple() - - default_settings = deepcopy(settings_mod.SETTINGS) +def get_default_state(plugin_interfaces=()): + if isinstance(plugin_interfaces, plugins.VerosPlugin): + plugin_interfaces = [plugin_interfaces] for plugin in plugin_interfaces: - default_settings.update(plugin.settings) + if not isinstance(plugin, plugins.VerosPlugin): + raise TypeError(f"Got unexpected type {type(plugin)}") - default_dimensions = deepcopy(var_mod.DIM_TO_SHAPE_VAR) + settings = deepcopy(settings_mod.SETTINGS) + dimensions = deepcopy(var_mod.DIM_TO_SHAPE_VAR) var_meta = deepcopy(var_mod.VARIABLES) for plugin in plugin_interfaces: + settings.update(plugin.settings) var_meta.update(plugin.variables) + dimensions.update(plugin.dimensions) - return VerosState(var_meta, default_settings, default_dimensions, plugin_interfaces=plugin_interfaces) + return VerosState(var_meta, settings, dimensions, plugin_interfaces=plugin_interfaces) def veros_state_pytree_flatten(state): diff --git a/veros/variables.py b/veros/variables.py index 55d3bb6a..3284bf32 100644 --- a/veros/variables.py +++ b/veros/variables.py @@ -36,7 +36,7 @@ def __init__( self.get_mask = mask - elif dims is not None: + elif isinstance(dims, tuple): if dims[:3] in DEFAULT_MASKS: self.get_mask = DEFAULT_MASKS[dims[:3]] elif dims[:2] in DEFAULT_MASKS: diff --git a/veros/veros.py b/veros/veros.py index fea8c6d7..c8858265 100644 --- a/veros/veros.py +++ b/veros/veros.py @@ -15,7 +15,7 @@ class VerosSetup(metaclass=abc.ABCMeta): This class is meant to be subclassed. Subclasses need to implement the methods :meth:`set_parameter`, :meth:`set_topography`, :meth:`set_grid`, :meth:`set_coriolis`, :meth:`set_initial_conditions`, :meth:`set_forcing`, - and :meth:`set_diagnostics`. + :meth:`set_diagnostics`, and :meth:`after_timestep`. Example: >>> import matplotlib.pyplot as plt @@ -42,7 +42,7 @@ def __init__(self, override=None): self._plugin_interfaces = tuple(load_plugin(p) for p in self.__veros_plugins__) self._setup_done = False - self.state = get_default_state(use_plugins=self.__veros_plugins__) + self.state = get_default_state(plugin_interfaces=self._plugin_interfaces) @abc.abstractmethod def set_parameter(self, state): @@ -142,7 +142,7 @@ def set_forcing(self, state): pass @abc.abstractmethod - def set_diagnostics(self, vs): + def set_diagnostics(self, state): """To be implemented by subclass. Called before setting up the :ref:`diagnostics `. Use this method e.g. to @@ -204,7 +204,7 @@ def setup(self): self.state.diagnostics.update(diagnostics.create_default_diagnostics(self.state)) - for plugin in self.state.plugin_interfaces: + for plugin in self._plugin_interfaces: for diagnostic in plugin.diagnostics: self.state.diagnostics[diagnostic.name] = diagnostic() @@ -413,7 +413,7 @@ def _timing_summary(self): timing_summary.extend( [ " {:<22} = {:.2f}s".format(plugin.name, self.state.timers[plugin.name].total_time) - for plugin in self.state._plugin_interfaces + for plugin in self._plugin_interfaces ] )