From c71db74efd6089b5ed48203d6cb4684ee7469cb8 Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Sat, 14 Jan 2023 22:06:12 +0000 Subject: [PATCH] reference object method as _target_ Signed-off-by: Wenqi Li --- monai/bundle/config_item.py | 2 +- monai/utils/module.py | 3 +-- tests/test_config_parser.py | 8 ++++++++ 3 files changed, 10 insertions(+), 3 deletions(-) diff --git a/monai/bundle/config_item.py b/monai/bundle/config_item.py index acd664e725c..10fb082ee78 100644 --- a/monai/bundle/config_item.py +++ b/monai/bundle/config_item.py @@ -236,7 +236,7 @@ def resolve_module_name(self): config = dict(self.get_config()) target = config.get("_target_") if not isinstance(target, str): - raise ValueError("must provide a string for the `_target_` of component to instantiate.") + return target module = self.locator.get_component_module_name(target) if module is None: diff --git a/monai/utils/module.py b/monai/utils/module.py index 1670ecadadc..e5ccdf99923 100644 --- a/monai/utils/module.py +++ b/monai/utils/module.py @@ -226,8 +226,7 @@ def instantiate(path: str, **kwargs): for `partial` function. """ - - component = locate(path) + component = locate(path) if isinstance(path, str) else path if component is None: raise ModuleNotFoundError(f"Cannot locate class or function path: '{path}'.") try: diff --git a/tests/test_config_parser.py b/tests/test_config_parser.py index 00945a46ba7..a29780e1bce 100644 --- a/tests/test_config_parser.py +++ b/tests/test_config_parser.py @@ -248,6 +248,14 @@ def test_lambda_reference(self): result = trans(np.ones(64)) self.assertTupleEqual(result.shape, (1, 8, 8)) + def test_non_str_target(self): + configs = { + "forward": {"_target_": "$@model().forward", "x": "$torch.rand(1, 3, 256, 256)"}, + "model": {"_target_": "monai.networks.nets.resnet.resnet18", "pretrained": False, "spatial_dims": 2}, + } + breakpoint() + self.assertisTrue(callable(ConfigParser(config=configs).forward)) + def test_error_instance(self): config = {"transform": {"_target_": "Compose", "transforms_wrong_key": []}} parser = ConfigParser(config=config)