diff --git a/monai/utils/__init__.py b/monai/utils/__init__.py index 38717f27fd..884e9f4fd6 100644 --- a/monai/utils/__init__.py +++ b/monai/utils/__init__.py @@ -13,6 +13,7 @@ # have to explicitly bring these in here to resolve circular import issues from .aliases import alias, resolve_name +from .component_store import ComponentStore from .decorators import MethodReplacer, RestartGenerator from .deprecate_utils import DeprecatedError, deprecated, deprecated_arg, deprecated_arg_default from .dist import RankFilter, evenly_divisible_all_gather, get_dist_device, string_list_all_gather diff --git a/monai/utils/component_store.py b/monai/utils/component_store.py new file mode 100644 index 0000000000..6fd8e8884f --- /dev/null +++ b/monai/utils/component_store.py @@ -0,0 +1,117 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from collections import namedtuple +from keyword import iskeyword +from textwrap import dedent, indent +from typing import Any, Callable, Iterable, TypeVar + +T = TypeVar("T") + + +def is_variable(name): + """Returns True if `name` is a valid Python variable name and also not a keyword.""" + return name.isidentifier() and not iskeyword(name) + + +class ComponentStore: + """ + Represents a storage object for other objects (specifically functions) keyed to a name with a description. + + These objects act as global named places for storing components for objects parameterised by component names. + Typically this is functions although other objects can be added. Printing a component store will produce a + list of members along with their docstring information if present. + + Example: + + .. code-block:: python + + TestStore = ComponentStore("Test Store", "A test store for demo purposes") + + @TestStore.add_def("my_func_name", "Some description of your function") + def _my_func(a, b): + '''A description of your function here.''' + return a * b + + print(TestStore) # will print out name, description, and 'my_func_name' with the docstring + + func = TestStore["my_func_name"] + result = func(7, 6) + + """ + + _Component = namedtuple("Component", ("description", "value")) # internal value pair + + def __init__(self, name: str, description: str) -> None: + self.components: dict[str, self._Component] = {} + self.name: str = name + self.description: str = description + + self.__doc__ = f"Component Store '{name}': {description}\n{self.__doc__ or ''}".strip() + + def add(self, name: str, desc: str, value: T) -> T: + """Store the object `value` under the name `name` with description `desc`.""" + if not is_variable(name): + raise ValueError("Name of component must be valid Python identifier") + + self.components[name] = self._Component(desc, value) + return value + + def add_def(self, name: str, desc: str) -> Callable: + """Returns a decorator which stores the decorated function under `name` with description `desc`.""" + + def deco(func): + """Decorator to add a function to a store.""" + return self.add(name, desc, func) + + return deco + + def __contains__(self, name: str) -> bool: + """Returns True if the given name is stored.""" + return name in self.components + + def __len__(self) -> int: + """Returns the number of stored components.""" + return len(self.components) + + def __iter__(self) -> Iterable: + """Yields name/component pairs.""" + for k, v in self.components.items(): + yield k, v.value + + def __str__(self): + result = f"Component Store '{self.name}': {self.description}\nAvailable components:" + for k, v in self.components.items(): + result += f"\n* {k}:" + + if hasattr(v.value, "__doc__"): + doc = indent(dedent(v.value.__doc__.lstrip("\n").rstrip()), " ") + result += f"\n{doc}\n" + else: + result += f" {v.description}" + + return result + + def __getattr__(self, name: str) -> Any: + """Returns the stored object under the given name.""" + if name in self.components: + return self.components[name].value + else: + return self.__getattribute__(name) + + def __getitem__(self, name: str) -> Any: + """Returns the stored object under the given name.""" + if name in self.components: + return self.components[name].value + else: + raise ValueError(f"Component '{name}' not found") diff --git a/monai/utils/misc.py b/monai/utils/misc.py index 3b11af41b0..86abe591fd 100644 --- a/monai/utils/misc.py +++ b/monai/utils/misc.py @@ -888,3 +888,13 @@ def is_sqrt(num: Sequence[int] | int) -> bool: sqrt_num = [int(math.sqrt(_num)) for _num in num] ret = [_i * _j for _i, _j in zip(sqrt_num, sqrt_num)] return ensure_tuple(ret) == num + + +def unsqueeze_right(arr: T, ndim: int) -> T: + """Append 1-sized dimensions to `arr` to create a result with `ndim` dimensions.""" + return arr[(...,) + (None,) * (ndim - arr.ndim)] + + +def unsqueeze_left(arr: T, ndim: int) -> T: + """Prepend 1-sized dimensions to `arr` to create a result with `ndim` dimensions.""" + return arr[(None,) * (ndim - arr.ndim)] diff --git a/tests/test_component_store.py b/tests/test_component_store.py new file mode 100644 index 0000000000..614f387754 --- /dev/null +++ b/tests/test_component_store.py @@ -0,0 +1,72 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import unittest + +from monai.utils import ComponentStore + + +class TestComponentStore(unittest.TestCase): + def setUp(self): + self.cs = ComponentStore("TestStore", "I am a test store, please ignore") + + def test_empty(self): + self.assertEqual(len(self.cs), 0) + self.assertEqual(list(self.cs), []) + + def test_add(self): + test_obj = object() + + self.assertFalse("test_obj" in self.cs) + + self.cs.add("test_obj", "Test object", test_obj) + + self.assertTrue("test_obj" in self.cs) + + self.assertEqual(len(self.cs), 1) + self.assertEqual(list(self.cs), [("test_obj", test_obj)]) + + self.assertEqual(self.cs.test_obj, test_obj) + self.assertEqual(self.cs["test_obj"], test_obj) + + def test_add2(self): + test_obj1 = object() + test_obj2 = object() + + self.cs.add("test_obj1", "Test object", test_obj1) + self.cs.add("test_obj2", "Test object", test_obj2) + + self.assertEqual(len(self.cs), 2) + self.assertTrue("test_obj1" in self.cs) + self.assertTrue("test_obj2" in self.cs) + + def test_add_def(self): + self.assertFalse("test_func" in self.cs) + + @self.cs.add_def("test_func", "Test function") + def test_func(): + return 123 + + self.assertTrue("test_func" in self.cs) + + self.assertEqual(len(self.cs), 1) + self.assertEqual(list(self.cs), [("test_func", test_func)]) + + self.assertEqual(self.cs.test_func, test_func) + self.assertEqual(self.cs["test_func"], test_func) + + # try adding the same function again + self.cs.add_def("test_func", "Test function but with new description")(test_func) + + self.assertEqual(len(self.cs), 1) + self.assertEqual(self.cs.test_func, test_func)