Skip to content

Commit

Permalink
reference object method as _target_
Browse files Browse the repository at this point in the history
Signed-off-by: Wenqi Li <[email protected]>
  • Loading branch information
wyli committed Jan 14, 2023
1 parent a5c4d29 commit c71db74
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 3 deletions.
2 changes: 1 addition & 1 deletion monai/bundle/config_item.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,7 +236,7 @@ def resolve_module_name(self):
config = dict(self.get_config())
target = config.get("_target_")
if not isinstance(target, str):
raise ValueError("must provide a string for the `_target_` of component to instantiate.")
return target

module = self.locator.get_component_module_name(target)
if module is None:
Expand Down
3 changes: 1 addition & 2 deletions monai/utils/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,8 +226,7 @@ def instantiate(path: str, **kwargs):
for `partial` function.
"""

component = locate(path)
component = locate(path) if isinstance(path, str) else path
if component is None:
raise ModuleNotFoundError(f"Cannot locate class or function path: '{path}'.")
try:
Expand Down
8 changes: 8 additions & 0 deletions tests/test_config_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,6 +248,14 @@ def test_lambda_reference(self):
result = trans(np.ones(64))
self.assertTupleEqual(result.shape, (1, 8, 8))

def test_non_str_target(self):
configs = {
"forward": {"_target_": "$@model().forward", "x": "$torch.rand(1, 3, 256, 256)"},
"model": {"_target_": "monai.networks.nets.resnet.resnet18", "pretrained": False, "spatial_dims": 2},
}
breakpoint()
self.assertisTrue(callable(ConfigParser(config=configs).forward))

def test_error_instance(self):
config = {"transform": {"_target_": "Compose", "transforms_wrong_key": []}}
parser = ConfigParser(config=config)
Expand Down

0 comments on commit c71db74

Please sign in to comment.