diff --git a/oommfc/drivers/hysteresisdriver.py b/oommfc/drivers/hysteresisdriver.py index 089c90b..69a2fc5 100644 --- a/oommfc/drivers/hysteresisdriver.py +++ b/oommfc/drivers/hysteresisdriver.py @@ -74,6 +74,11 @@ def _checkargs(self, **kwargs): msg = f"Cannot drive with {n=}." raise ValueError(msg) + def _check_system(self, system): + """Checks the system has energy in it""" + if len(system.energy) == 0: + raise RuntimeError("System's energy is not defined") + @property def _x(self): return "B_hysteresis" diff --git a/oommfc/drivers/mindriver.py b/oommfc/drivers/mindriver.py index c0aa3c8..413d275 100644 --- a/oommfc/drivers/mindriver.py +++ b/oommfc/drivers/mindriver.py @@ -59,6 +59,11 @@ class MinDriver(Driver): def _checkargs(self, **kwargs): pass # no kwargs should be checked + def _check_system(self, system): + """Checks the system has energy in it""" + if len(system.energy) == 0: + raise RuntimeError("System's energy is not defined") + @property def _x(self): return "iteration" diff --git a/oommfc/drivers/timedriver.py b/oommfc/drivers/timedriver.py index 6b4ec1d..328cdef 100644 --- a/oommfc/drivers/timedriver.py +++ b/oommfc/drivers/timedriver.py @@ -67,6 +67,13 @@ def _checkargs(self, **kwargs): msg = f"Cannot drive with {n=}." raise ValueError(msg) + def _check_system(self, system): + """Checks the system has dynamics in it""" + if len(system.dynamics) == 0: + raise RuntimeError("System's dynamics is not defined") + if len(system.energy) == 0: + raise RuntimeError("System's energy is not defined") + @property def _x(self): return "t" diff --git a/oommfc/tests/conftest.py b/oommfc/tests/conftest.py index bbd5033..d4447a3 100644 --- a/oommfc/tests/conftest.py +++ b/oommfc/tests/conftest.py @@ -2,7 +2,20 @@ import oommfc as oc +not_supported_by_oommf = ["test_relax_check_for_energy", "test_relaxdriver"] + @pytest.fixture(scope="module") def calculator(): return oc + + +@pytest.fixture(autouse=True) +def skip_unsupported_or_missing(request): + requesting_test_function = ( + f"{request.cls.__name__}.{request.function.__name__}" + if request.cls + else request.function.__name__ + ) + if requesting_test_function in not_supported_by_oommf: + pytest.skip("Not supported by OOMMF.")