diff --git a/monai/bundle/__init__.py b/monai/bundle/__init__.py
index e8ea9d62b0..dd556e9eb3 100644
--- a/monai/bundle/__init__.py
+++ b/monai/bundle/__init__.py
@@ -13,10 +13,11 @@
from .config_item import ComponentLocator, ConfigComponent, ConfigExpression, ConfigItem, Instantiable
from .config_parser import ConfigParser
-from .properties import InferProperties, TrainProperties
+from .properties import InferProperties, MetaProperties, TrainProperties
from .reference_resolver import ReferenceResolver
from .scripts import (
ckpt_export,
+ create_workflow,
download,
get_all_bundles_list,
get_bundle_info,
diff --git a/monai/bundle/properties.py b/monai/bundle/properties.py
index 16ecf77268..a75e862a84 100644
--- a/monai/bundle/properties.py
+++ b/monai/bundle/properties.py
@@ -13,7 +13,7 @@
to interact with the bundle workflow.
Some properties are required and some are optional, optional properties mean: if some component of the
bundle workflow refer to the property, the property must be defined, otherwise, the property can be None.
-Every item in this `TrainProperties` or `InferProperties` dictionary is a property,
+Every item in the `TrainProperties`, `InferProperties`, or `MetaProperties` dictionaries is a property,
the key is the property name and the values include:
1. description.
2. whether it's a required property.
@@ -48,6 +48,11 @@
BundleProperty.REQUIRED: True,
BundlePropertyConfig.ID: f"train{ID_SEP_KEY}trainer",
},
+ "network_def": {
+ BundleProperty.DESC: "network module for the training.",
+ BundleProperty.REQUIRED: False,
+ BundlePropertyConfig.ID: "network_def",
+ },
"max_epochs": {
BundleProperty.DESC: "max number of epochs to execute the training.",
BundleProperty.REQUIRED: True,
@@ -216,3 +221,42 @@
BundlePropertyConfig.REF_ID: f"evaluator{ID_SEP_KEY}key_val_metric",
},
}
+
+MetaProperties = {
+ "version": {
+ BundleProperty.DESC: "bundle version",
+ BundleProperty.REQUIRED: True,
+ BundlePropertyConfig.ID: f"_meta_{ID_SEP_KEY}version",
+ },
+ "monai_version": {
+ BundleProperty.DESC: "required monai version used for bundle",
+ BundleProperty.REQUIRED: True,
+ BundlePropertyConfig.ID: f"_meta_{ID_SEP_KEY}monai_version",
+ },
+ "pytorch_version": {
+ BundleProperty.DESC: "required pytorch version used for bundle",
+ BundleProperty.REQUIRED: True,
+ BundlePropertyConfig.ID: f"_meta_{ID_SEP_KEY}pytorch_version",
+ },
+ "numpy_version": {
+ BundleProperty.DESC: "required numpy version used for bundle",
+ BundleProperty.REQUIRED: True,
+ BundlePropertyConfig.ID: f"_meta_{ID_SEP_KEY}numpy_version",
+ },
+ "description": {
+ BundleProperty.DESC: "description for bundle",
+ BundleProperty.REQUIRED: False,
+ BundlePropertyConfig.ID: f"_meta_{ID_SEP_KEY}description",
+ },
+ "spatial_shape": {
+ BundleProperty.DESC: "spatial shape for the inputs",
+ BundleProperty.REQUIRED: False,
+ BundlePropertyConfig.ID: f"_meta_{ID_SEP_KEY}network_data_format{ID_SEP_KEY}inputs{ID_SEP_KEY}image"
+ f"{ID_SEP_KEY}spatial_shape",
+ },
+ "channel_def": {
+ BundleProperty.DESC: "channel definition for the prediction",
+ BundleProperty.REQUIRED: False,
+ BundlePropertyConfig.ID: f"_meta_{ID_SEP_KEY}network_data_format{ID_SEP_KEY}outputs{ID_SEP_KEY}pred{ID_SEP_KEY}channel_def",
+ },
+}
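
Each `BundlePropertyConfig.ID` above is a `ConfigParser` ID path, so the new meta properties resolve directly against the parsed bundle metadata (stored under the `_meta_` key). A minimal usage sketch, not part of the patch, with illustrative file paths:

    from monai.bundle import ConfigParser
    from monai.bundle.utils import ID_SEP_KEY

    parser = ConfigParser()
    parser.read_config(f="configs/train.json")    # hypothetical bundle config
    parser.read_meta(f="configs/metadata.json")   # metadata is parsed under the "_meta_" key
    # resolve the same ID path used by MetaProperties["monai_version"]
    print(parser[f"_meta_{ID_SEP_KEY}monai_version"])
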
diff --git a/monai/bundle/scripts.py b/monai/bundle/scripts.py
index 2b1d3cd6f7..be6c0caba6 100644
--- a/monai/bundle/scripts.py
+++ b/monai/bundle/scripts.py
@@ -28,7 +28,6 @@
from monai.apps.mmars.mmars import _get_all_ngc_models
from monai.apps.utils import _basename, download_url, extractall, get_logger
-from monai.bundle.config_item import ConfigComponent
from monai.bundle.config_parser import ConfigParser
from monai.bundle.utils import DEFAULT_INFERENCE, DEFAULT_METADATA
from monai.bundle.workflows import BundleWorkflow, ConfigWorkflow
@@ -63,7 +62,7 @@
# set BUNDLE_DOWNLOAD_SRC="ngc" to use NGC source in default for bundle download
# set BUNDLE_DOWNLOAD_SRC="monaihosting" to use monaihosting source in default for bundle download
-download_source = os.environ.get("BUNDLE_DOWNLOAD_SRC", "github")
+DEFAULT_DOWNLOAD_SOURCE = os.environ.get("BUNDLE_DOWNLOAD_SRC", "github")
PPRINT_CONFIG_N = 5
@@ -253,7 +252,7 @@ def download(
name: str | None = None,
version: str | None = None,
bundle_dir: PathLike | None = None,
- source: str = download_source,
+ source: str = DEFAULT_DOWNLOAD_SOURCE,
repo: str | None = None,
url: str | None = None,
remove_prefix: str | None = "monai_",
@@ -376,21 +375,28 @@ def download(
)
+@deprecated_arg("net_name", since="1.3", removed="1.4", msg_suffix="please use ``model`` instead.")
+@deprecated_arg("net_kwargs", since="1.3", removed="1.3", msg_suffix="please use ``model`` instead.")
def load(
name: str,
+ model: torch.nn.Module | None = None,
version: str | None = None,
+ workflow_type: str = "train",
model_file: str | None = None,
load_ts_module: bool = False,
bundle_dir: PathLike | None = None,
- source: str = download_source,
+ source: str = DEFAULT_DOWNLOAD_SOURCE,
repo: str | None = None,
remove_prefix: str | None = "monai_",
progress: bool = True,
device: str | None = None,
key_in_ckpt: str | None = None,
config_files: Sequence[str] = (),
+ workflow_name: str | BundleWorkflow | None = None,
+ args_file: str | None = None,
+ copy_model_args: dict | None = None,
net_name: str | None = None,
- **net_kwargs: Any,
+ **net_override: Any,
) -> object | tuple[torch.nn.Module, dict, dict] | Any:
"""
Load model weights or TorchScript module of a bundle.
@@ -402,8 +408,15 @@ def load(
https://github.com/Project-MONAI/model-zoo/releases/tag/hosting_storage_v1.
"monai_brats_mri_segmentation" in ngc:
https://catalog.ngc.nvidia.com/models?filters=&orderBy=scoreDESC&query=monai.
+ "mednist_gan" in monaihosting:
+ https://api.ngc.nvidia.com/v2/models/nvidia/monaihosting/mednist_gan/versions/0.2.0/files/mednist_gan_v0.2.0.zip
+        model: a PyTorch module to be updated. Defaults to None, in which case the "network_def" in the bundle is used.
version: version name of the target bundle to download, like: "0.1.0". If `None`, will download
the latest version.
+ workflow_type: specifies the workflow type: "train" or "training" for a training workflow,
+            or "infer", "inference", "eval", "evaluation" for an inference workflow;
+            any other unsupported string will raise a ValueError.
+            Default to `train` for a training workflow.
model_file: the relative path of the model weights or TorchScript module within bundle.
If `None`, "models/model.pt" or "models/model.ts" will be used.
load_ts_module: a flag to specify if loading the TorchScript module.
@@ -417,7 +430,7 @@ def load(
If used, it should be in the form of "repo_owner/repo_name/release_tag".
remove_prefix: This argument is used when `source` is "ngc". Currently, all ngc bundles
have the ``monai_`` prefix, which is not existing in their model zoo contrasts. In order to
- maintain the consistency between these two sources, remove prefix is necessary.
+            maintain the consistency between these three sources, removing the prefix is necessary.
Therefore, if specified, downloaded folder name will remove the prefix.
progress: whether to display a progress bar when downloading.
device: target device of returned weights or module, if `None`, prefer to "cuda" if existing.
@@ -425,13 +438,16 @@ def load(
weights. if not nested checkpoint, no need to set.
config_files: extra filenames would be loaded. The argument only works when loading a TorchScript module,
see `_extra_files` in `torch.jit.load` for more details.
- net_name: if not `None`, a corresponding network will be instantiated and load the achieved weights.
- This argument only works when loading weights.
- net_kwargs: other arguments that are used to instantiate the network class defined by `net_name`.
+ workflow_name: specified bundle workflow name, should be a string or class, default to "ConfigWorkflow".
+        args_file: a JSON or YAML file to provide default values for all the args in the "download" function.
+ copy_model_args: other arguments for the `monai.networks.copy_model_state` function.
+ net_override: id-value pairs to override the parameters in the network of the bundle.
Returns:
- 1. If `load_ts_module` is `False` and `net_name` is `None`, return model weights.
- 2. If `load_ts_module` is `False` and `net_name` is not `None`,
+ 1. If `load_ts_module` is `False` and `model` is `None`,
+            return the model weights if "network_def" cannot be found in the bundle,
+ else return an instantiated network that loaded the weights.
+ 2. If `load_ts_module` is `False` and `model` is not `None`,
return an instantiated network that loaded the weights.
3. If `load_ts_module` is `True`, return a triple that include a TorchScript module,
the corresponding metadata dict, and extra files dict.
@@ -439,15 +455,14 @@ def load(
"""
bundle_dir_ = _process_bundle_dir(bundle_dir)
+ copy_model_args = {} if copy_model_args is None else copy_model_args
+ if device is None:
+ device = "cuda:0" if is_available() else "cpu"
if model_file is None:
model_file = os.path.join("models", "model.ts" if load_ts_module is True else "model.pt")
- if source == "ngc":
- name = _add_ngc_prefix(name)
- if remove_prefix:
- name = _remove_ngc_prefix(name, prefix=remove_prefix)
full_path = os.path.join(bundle_dir_, name, model_file)
- if not os.path.exists(full_path):
+ if not os.path.exists(full_path) or model is None:
download(
name=name,
version=version,
@@ -456,10 +471,21 @@ def load(
repo=repo,
remove_prefix=remove_prefix,
progress=progress,
+ args_file=args_file,
)
+ train_config_file = bundle_dir_ / name / "configs" / f"{workflow_type}.json"
+ if train_config_file.is_file():
+ _net_override = {f"network_def#{key}": value for key, value in net_override.items()}
+ _workflow = create_workflow(
+ workflow_name=workflow_name,
+ args_file=args_file,
+ config_file=str(train_config_file),
+ workflow_type=workflow_type,
+ **_net_override,
+ )
+ else:
+ _workflow = None
- if device is None:
- device = "cuda:0" if is_available() else "cpu"
# loading with `torch.jit.load`
if load_ts_module is True:
return load_net_with_metadata(full_path, map_location=torch.device(device), more_extra_files=config_files)
@@ -469,13 +495,12 @@ def load(
warnings.warn(f"the state dictionary from {full_path} should be a dictionary but got {type(model_dict)}.")
model_dict = get_state_dict(model_dict)
- if net_name is None:
+ if model is None and _workflow is None:
return model_dict
- net_kwargs["_target_"] = net_name
- configer = ConfigComponent(config=net_kwargs)
- model = configer.instantiate()
- model.to(device) # type: ignore
- copy_model_state(dst=model, src=model_dict if key_in_ckpt is None else model_dict[key_in_ckpt]) # type: ignore
+ model = _workflow.network_def if model is None else model # type: ignore
+ model.to(device)
+
+ copy_model_state(dst=model, src=model_dict if key_in_ckpt is None else model_dict[key_in_ckpt], **copy_model_args)
return model
@@ -675,12 +700,12 @@ def run(
final_id: ID name of the expected config expression to finalize after running, default to "finalize".
it's optional for both configs and this `run` function.
meta_file: filepath of the metadata file, if it is a list of file paths, the content of them will be merged.
- Default to "configs/metadata.json", which is commonly used for bundles in MONAI model zoo.
+ Default to None.
config_file: filepath of the config file, if `None`, must be provided in `args_file`.
if it is a list of file paths, the content of them will be merged.
logging_file: config file for `logging` module in the program. for more details:
https://docs.python.org/3/library/logging.config.html#logging.config.fileConfig.
- Default to "configs/logging.conf", which is commonly used for bundles in MONAI model zoo.
+ Default to None.
tracking: if not None, enable the experiment tracking at runtime with optionally configurable and extensible.
if "mlflow", will add `MLFlowHandler` to the parsed bundle with default tracking settings,
if other string, treat it as file path to load the tracking settings.
@@ -695,46 +720,24 @@ def run(
"""
- _args = _update_args(
- args=args_file,
- run_id=run_id,
- init_id=init_id,
- final_id=final_id,
- meta_file=meta_file,
+ workflow = create_workflow(
config_file=config_file,
+ args_file=args_file,
+ meta_file=meta_file,
logging_file=logging_file,
+ init_id=init_id,
+ run_id=run_id,
+ final_id=final_id,
tracking=tracking,
**override,
)
- if "config_file" not in _args:
- warnings.warn("`config_file` not provided for 'monai.bundle run'.")
- _log_input_summary(tag="run", args=_args)
- config_file_, meta_file_, init_id_, run_id_, final_id_, logging_file_, tracking_ = _pop_args(
- _args,
- config_file=None,
- meta_file="configs/metadata.json",
- init_id="initialize",
- run_id="run",
- final_id="finalize",
- logging_file="configs/logging.conf",
- tracking=None,
- )
- workflow = ConfigWorkflow(
- config_file=config_file_,
- meta_file=meta_file_,
- logging_file=logging_file_,
- init_id=init_id_,
- run_id=run_id_,
- final_id=final_id_,
- tracking=tracking_,
- **_args,
- )
- workflow.initialize()
workflow.run()
workflow.finalize()
-def run_workflow(workflow: str | BundleWorkflow | None = None, args_file: str | None = None, **kwargs: Any) -> None:
+def run_workflow(
+ workflow_name: str | BundleWorkflow | None = None, args_file: str | None = None, **kwargs: Any
+) -> None:
"""
Specify `bundle workflow` to run monai bundle components and workflows.
The workflow should be subclass of `BundleWorkflow` and be available to import.
@@ -748,35 +751,17 @@ def run_workflow(workflow: str | BundleWorkflow | None = None, args_file: str |
python -m monai.bundle run_workflow --meta_file --config_file
# Set the workflow to other customized BundleWorkflow subclass:
- python -m monai.bundle run_workflow --workflow CustomizedWorkflow ...
+ python -m monai.bundle run_workflow --workflow_name CustomizedWorkflow ...
Args:
- workflow: specified bundle workflow name, should be a string or class, default to "ConfigWorkflow".
+ workflow_name: specified bundle workflow name, should be a string or class, default to "ConfigWorkflow".
args_file: a JSON or YAML file to provide default values for this API.
so that the command line inputs can be simplified.
kwargs: arguments to instantiate the workflow class.
"""
- _args = _update_args(args=args_file, workflow=workflow, **kwargs)
- _log_input_summary(tag="run", args=_args)
- (workflow_name,) = _pop_args(_args, workflow=ConfigWorkflow) # the default workflow name is "ConfigWorkflow"
- if isinstance(workflow_name, str):
- workflow_class, has_built_in = optional_import("monai.bundle", name=str(workflow_name)) # search built-in
- if not has_built_in:
- workflow_class = locate(str(workflow_name)) # search dotted path
- if workflow_class is None:
- raise ValueError(f"cannot locate specified workflow class: {workflow_name}.")
- elif issubclass(workflow_name, BundleWorkflow):
- workflow_class = workflow_name
- else:
- raise ValueError(
- "Argument `workflow` must be a bundle workflow class name"
- f"or subclass of BundleWorkflow, got: {workflow_name}."
- )
-
- workflow_ = workflow_class(**_args)
- workflow_.initialize()
+ workflow_ = create_workflow(workflow_name=workflow_name, args_file=args_file, **kwargs)
workflow_.run()
workflow_.finalize()
@@ -1539,3 +1524,61 @@ def init_bundle(
copyfile(str(ckpt_file), str(models_dir / "model.pt"))
elif network is not None:
save_state(network, str(models_dir / "model.pt"))
+
+
+def create_workflow(
+ workflow_name: str | BundleWorkflow | None = None,
+ config_file: str | Sequence[str] | None = None,
+ args_file: str | None = None,
+ **kwargs: Any,
+) -> Any:
+ """
+ Specify `bundle workflow` to create monai bundle workflows.
+ The workflow should be subclass of `BundleWorkflow` and be available to import.
+ It can be MONAI existing bundle workflows or user customized workflows.
+
+ Typical usage examples:
+
+ .. code-block:: python
+
+ # Specify config_file path to create workflow:
+ workflow = create_workflow(config_file="/workspace/spleen_ct_segmentation/configs/train.json", workflow_type="train")
+
+ # Set the workflow to other customized BundleWorkflow subclass to create workflow:
+ workflow = create_workflow(workflow_name=CustomizedWorkflow)
+
+ Args:
+ workflow_name: specified bundle workflow name, should be a string or class, default to "ConfigWorkflow".
+ config_file: filepath of the config file, if it is a list of file paths, the content of them will be merged.
+ args_file: a JSON or YAML file to provide default values for this API.
+ so that the command line inputs can be simplified.
+ kwargs: arguments to instantiate the workflow class.
+
+ """
+ _args = _update_args(args=args_file, workflow_name=workflow_name, config_file=config_file, **kwargs)
+ _log_input_summary(tag="run", args=_args)
+ (workflow_name, config_file) = _pop_args(
+ _args, workflow_name=ConfigWorkflow, config_file=None
+ ) # the default workflow name is "ConfigWorkflow"
+ if isinstance(workflow_name, str):
+ workflow_class, has_built_in = optional_import("monai.bundle", name=str(workflow_name)) # search built-in
+ if not has_built_in:
+ workflow_class = locate(str(workflow_name)) # search dotted path
+ if workflow_class is None:
+ raise ValueError(f"cannot locate specified workflow class: {workflow_name}.")
+ elif issubclass(workflow_name, BundleWorkflow): # type: ignore
+ workflow_class = workflow_name
+ else:
+ raise ValueError(
+            "Argument `workflow_name` must be a bundle workflow class name "
+ f"or subclass of BundleWorkflow, got: {workflow_name}."
+ )
+
+ if config_file is not None:
+ workflow_ = workflow_class(config_file=config_file, **_args)
+ else:
+ workflow_ = workflow_class(**_args)
+
+ workflow_.initialize()
+
+ return workflow_
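
A hedged usage sketch of the updated `load` and the new `create_workflow` APIs; bundle name, directory, and override values are illustrative, mirroring the tests further below:

    from monai.bundle import create_workflow, load

    # load weights into the network instantiated from the bundle's "network_def",
    # overriding selected network kwargs via `net_override` id-value pairs
    model = load(
        name="spleen_ct_segmentation",
        bundle_dir="./bundles",
        source="monaihosting",
        workflow_type="train",
        spatial_dims=3,
        out_channels=5,
    )

    # build and initialize a workflow directly from a config file
    workflow = create_workflow(
        config_file="./bundles/spleen_ct_segmentation/configs/train.json",
        workflow_type="train",
    )
    net = workflow.network_def  # the network definition exposed as a bundle property
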
diff --git a/monai/bundle/workflows.py b/monai/bundle/workflows.py
index 6bd966592e..3b349e1103 100644
--- a/monai/bundle/workflows.py
+++ b/monai/bundle/workflows.py
@@ -22,9 +22,9 @@
from monai.apps.utils import get_logger
from monai.bundle.config_parser import ConfigParser
-from monai.bundle.properties import InferProperties, TrainProperties
+from monai.bundle.properties import InferProperties, MetaProperties, TrainProperties
from monai.bundle.utils import DEFAULT_EXP_MGMT_SETTINGS, EXPR_KEY, ID_REF_KEY, ID_SEP_KEY
-from monai.utils import BundleProperty, BundlePropertyConfig
+from monai.utils import BundleProperty, BundlePropertyConfig, deprecated_arg, deprecated_arg_default, ensure_tuple
__all__ = ["BundleWorkflow", "ConfigWorkflow"]
@@ -38,7 +38,7 @@ class BundleWorkflow(ABC):
And also provides the interface to get / set public properties to interact with a bundle workflow.
Args:
- workflow: specifies the workflow type: "train" or "training" for a training workflow,
+ workflow_type: specifies the workflow type: "train" or "training" for a training workflow,
or "infer", "inference", "eval", "evaluation" for a inference workflow,
other unsupported string will raise a ValueError.
default to `None` for common workflow.
@@ -48,19 +48,26 @@ class BundleWorkflow(ABC):
supported_train_type: tuple = ("train", "training")
supported_infer_type: tuple = ("infer", "inference", "eval", "evaluation")
- def __init__(self, workflow: str | None = None):
- if workflow is None:
- self.properties = None
- self.workflow = None
+ @deprecated_arg(
+ "workflow",
+ since="1.3",
+ removed="1.5",
+ new_name="workflow_type",
+ msg_suffix="please use `workflow_type` instead.",
+ )
+ def __init__(self, workflow_type: str | None = None):
+ if workflow_type is None:
+ self.properties = copy(MetaProperties)
+ self.workflow_type = None
return
- if workflow.lower() in self.supported_train_type:
- self.properties = copy(TrainProperties)
- self.workflow = "train"
- elif workflow.lower() in self.supported_infer_type:
- self.properties = copy(InferProperties)
- self.workflow = "infer"
+ if workflow_type.lower() in self.supported_train_type:
+ self.properties = {**TrainProperties, **MetaProperties}
+ self.workflow_type = "train"
+ elif workflow_type.lower() in self.supported_infer_type:
+ self.properties = {**InferProperties, **MetaProperties}
+ self.workflow_type = "infer"
else:
- raise ValueError(f"Unsupported workflow type: '{workflow}'.")
+ raise ValueError(f"Unsupported workflow type: '{workflow_type}'.")
@abstractmethod
def initialize(self, *args: Any, **kwargs: Any) -> Any:
@@ -128,7 +135,7 @@ def get_workflow_type(self):
Get the workflow type, it can be `None`, "train", or "infer".
"""
- return self.workflow
+ return self.workflow_type
def add_property(self, name: str, required: str, desc: str | None = None) -> None:
"""
@@ -166,18 +173,18 @@ class ConfigWorkflow(BundleWorkflow):
For more information: https://docs.monai.io/en/latest/mb_specification.html.
Args:
- run_id: ID name of the expected config expression to run, default to "run".
- to run the config, the target config must contain this ID.
+ config_file: filepath of the config file, if this is a list of file paths, their contents will be merged in order.
+ meta_file: filepath of the metadata file, if this is a list of file paths, their contents will be merged in order.
+ If None, default to "configs/metadata.json", which is commonly used for bundles in MONAI model zoo.
+ logging_file: config file for `logging` module in the program. for more details:
+ https://docs.python.org/3/library/logging.config.html#logging.config.fileConfig.
+ If None, default to "configs/logging.conf", which is commonly used for bundles in MONAI model zoo.
init_id: ID name of the expected config expression to initialize before running, default to "initialize".
allow a config to have no `initialize` logic and the ID.
+ run_id: ID name of the expected config expression to run, default to "run".
+ to run the config, the target config must contain this ID.
final_id: ID name of the expected config expression to finalize after running, default to "finalize".
allow a config to have no `finalize` logic and the ID.
- meta_file: filepath of the metadata file, if it is a list of file paths, the content of them will be merged.
- Default to "configs/metadata.json", which is commonly used for bundles in MONAI model zoo.
- config_file: filepath of the config file, if it is a list of file paths, the content of them will be merged.
- logging_file: config file for `logging` module in the program. for more details:
- https://docs.python.org/3/library/logging.config.html#logging.config.fileConfig.
- Default to "configs/logging.conf", which is commonly used for bundles in MONAI model zoo.
tracking: if not None, enable the experiment tracking at runtime with optionally configurable and extensible.
if "mlflow", will add `MLFlowHandler` to the parsed bundle with default tracking settings,
if other string, treat it as file path to load the tracking settings.
@@ -185,7 +192,7 @@ class ConfigWorkflow(BundleWorkflow):
will patch the target config content with `tracking handlers` and the top-level items of `configs`.
for detailed usage examples, please check the tutorial:
https://github.com/Project-MONAI/tutorials/blob/main/experiment_management/bundle_integrate_mlflow.ipynb.
- workflow: specifies the workflow type: "train" or "training" for a training workflow,
+ workflow_type: specifies the workflow type: "train" or "training" for a training workflow,
or "infer", "inference", "eval", "evaluation" for a inference workflow,
other unsupported string will raise a ValueError.
default to `None` for common workflow.
@@ -194,23 +201,47 @@ class ConfigWorkflow(BundleWorkflow):
"""
+ @deprecated_arg(
+ "workflow",
+ since="1.3",
+ removed="1.5",
+ new_name="workflow_type",
+ msg_suffix="please use `workflow_type` instead.",
+ )
+ @deprecated_arg_default("workflow_type", None, "train", since="1.3", replaced="1.4")
def __init__(
self,
config_file: str | Sequence[str],
- meta_file: str | Sequence[str] | None = "configs/metadata.json",
- logging_file: str | None = "configs/logging.conf",
+ meta_file: str | Sequence[str] | None = None,
+ logging_file: str | None = None,
init_id: str = "initialize",
run_id: str = "run",
final_id: str = "finalize",
tracking: str | dict | None = None,
- workflow: str | None = None,
+ workflow_type: str | None = None,
**override: Any,
) -> None:
- super().__init__(workflow=workflow)
+ super().__init__(workflow_type=workflow_type)
+ if config_file is not None:
+ _config_files = ensure_tuple(config_file)
+ config_root_path = Path(_config_files[0]).parent
+ for _config_file in _config_files:
+ _config_file = Path(_config_file)
+ if _config_file.parent != config_root_path:
+ warnings.warn(
+                        f"Not all config files are in {config_root_path}. If logging_file and meta_file are "
+ f"not specified, {config_root_path} will be used as the default config root directory."
+ )
+ if not _config_file.is_file():
+ raise FileNotFoundError(f"Cannot find the config file: {_config_file}.")
+ else:
+ config_root_path = Path("configs")
+
+ logging_file = str(config_root_path / "logging.conf") if logging_file is None else logging_file
if logging_file is not None:
if not os.path.exists(logging_file):
- if logging_file == "configs/logging.conf":
- warnings.warn("Default logging file in 'configs/logging.conf' does not exist, skipping logging.")
+ if logging_file == str(config_root_path / "logging.conf"):
+ warnings.warn(f"Default logging file in {logging_file} does not exist, skipping logging.")
else:
raise FileNotFoundError(f"Cannot find the logging config file: {logging_file}.")
else:
@@ -219,14 +250,11 @@ def __init__(
self.parser = ConfigParser()
self.parser.read_config(f=config_file)
- if meta_file is not None:
- if isinstance(meta_file, str) and not os.path.exists(meta_file):
- if meta_file == "configs/metadata.json":
- warnings.warn("Default metadata file in 'configs/metadata.json' does not exist, skipping loading.")
- else:
- raise FileNotFoundError(f"Cannot find the metadata config file: {meta_file}.")
- else:
- self.parser.read_meta(f=meta_file)
+ meta_file = str(config_root_path / "metadata.json") if meta_file is None else meta_file
+ if isinstance(meta_file, str) and not os.path.exists(meta_file):
+ raise FileNotFoundError(f"Cannot find the metadata config file: {meta_file}.")
+ else:
+ self.parser.read_meta(f=meta_file)
# the rest key-values in the _args are to override config content
self.parser.update(pairs=override)
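
A minimal sketch of the renamed `workflow_type` argument and the new config-root defaulting in `ConfigWorkflow` (paths are illustrative): when `meta_file` / `logging_file` are not given, "metadata.json" / "logging.conf" are looked up in the directory of the first config file:

    from monai.bundle import ConfigWorkflow

    trainer = ConfigWorkflow(
        config_file="./spleen_ct_segmentation/configs/train.json",
        workflow_type="train",  # replaces the deprecated `workflow` argument
        # meta_file defaults to ./spleen_ct_segmentation/configs/metadata.json
        # logging_file defaults to ./spleen_ct_segmentation/configs/logging.conf
    )
    trainer.initialize()
    trainer.run()
    trainer.finalize()
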
diff --git a/monai/fl/client/monai_algo.py b/monai/fl/client/monai_algo.py
index 4838b784e7..626bc9651d 100644
--- a/monai/fl/client/monai_algo.py
+++ b/monai/fl/client/monai_algo.py
@@ -149,7 +149,7 @@ def initialize(self, extra=None):
if self.workflow is None:
config_train_files = self._add_config_files(self.config_train_filename)
self.workflow = ConfigWorkflow(
- config_file=config_train_files, meta_file=None, logging_file=None, workflow="train"
+ config_file=config_train_files, meta_file=None, logging_file=None, workflow_type="train"
)
self.workflow.initialize()
self.workflow.bundle_root = self.bundle_root
@@ -317,12 +317,12 @@ class MonaiAlgo(ClientAlgo, MonaiAlgoStats):
config_train_filename: bundle training config path relative to bundle_root. can be a list of files.
defaults to "configs/train.json". only useful when `train_workflow` is None.
train_kwargs: other args of the `ConfigWorkflow` of train, except for `config_file`, `meta_file`,
- `logging_file`, `workflow`. only useful when `train_workflow` is None.
+ `logging_file`, `workflow_type`. only useful when `train_workflow` is None.
config_evaluate_filename: bundle evaluation config path relative to bundle_root. can be a list of files.
if "default", ["configs/train.json", "configs/evaluate.json"] will be used.
this arg is only useful when `eval_workflow` is None.
eval_kwargs: other args of the `ConfigWorkflow` of evaluation, except for `config_file`, `meta_file`,
- `logging_file`, `workflow`. only useful when `eval_workflow` is None.
+ `logging_file`, `workflow_type`. only useful when `eval_workflow` is None.
config_filters_filename: filter configuration file. Can be a list of files; defaults to `None`.
disable_ckpt_loading: do not use any CheckpointLoader if defined in train/evaluate configs; defaults to `True`.
best_model_filepath: location of best model checkpoint; defaults "models/model.pt" relative to `bundle_root`.
@@ -431,7 +431,11 @@ def initialize(self, extra=None):
if "run_name" not in self.train_kwargs:
self.train_kwargs["run_name"] = f"{self.client_name}_{timestamp}"
self.train_workflow = ConfigWorkflow(
- config_file=config_train_files, meta_file=None, logging_file=None, workflow="train", **self.train_kwargs
+ config_file=config_train_files,
+ meta_file=None,
+ logging_file=None,
+ workflow_type="train",
+ **self.train_kwargs,
)
if self.train_workflow is not None:
self.train_workflow.initialize()
@@ -455,7 +459,7 @@ def initialize(self, extra=None):
config_file=config_eval_files,
meta_file=None,
logging_file=None,
- workflow=self.eval_workflow_name,
+ workflow_type=self.eval_workflow_name,
**self.eval_kwargs,
)
if self.eval_workflow is not None:
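
The federated-learning client only changes by the keyword rename; a pre-built workflow is still passed in the same way (paths are placeholders):

    from monai.bundle import ConfigWorkflow
    from monai.fl.client import MonaiAlgo

    algo = MonaiAlgo(
        bundle_root="./bundle",
        train_workflow=ConfigWorkflow(
            config_file="./bundle/configs/train.json",
            workflow_type="train",  # formerly `workflow="train"`
        ),
        config_evaluate_filename=None,
    )
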
diff --git a/tests/nonconfig_workflow.py b/tests/nonconfig_workflow.py
index 34f22aa565..7b5328bf72 100644
--- a/tests/nonconfig_workflow.py
+++ b/tests/nonconfig_workflow.py
@@ -37,7 +37,7 @@ class NonConfigWorkflow(BundleWorkflow):
"""
def __init__(self, filename, output_dir):
- super().__init__(workflow="inference")
+ super().__init__(workflow_type="inference")
self.filename = filename
self.output_dir = output_dir
self._bundle_root = "will override"
@@ -50,9 +50,25 @@ def __init__(self, filename, output_dir):
self._preprocessing = None
self._postprocessing = None
self._evaluator = None
+ self._version = None
+ self._monai_version = None
+ self._pytorch_version = None
+ self._numpy_version = None
def initialize(self):
set_determinism(0)
+ if self._version is None:
+ self._version = "0.1.0"
+
+ if self._monai_version is None:
+ self._monai_version = "1.1.0"
+
+ if self._pytorch_version is None:
+ self._pytorch_version = "1.13.1"
+
+ if self._numpy_version is None:
+ self._numpy_version = "1.22.2"
+
if self._preprocessing is None:
self._preprocessing = Compose(
[LoadImaged(keys="image"), EnsureChannelFirstd(keys="image"), ScaleIntensityd(keys="image")]
@@ -118,6 +134,14 @@ def _get_property(self, name, property):
return self._preprocessing
if name == "postprocessing":
return self._postprocessing
+ if name == "version":
+ return self._version
+ if name == "monai_version":
+ return self._monai_version
+ if name == "pytorch_version":
+ return self._pytorch_version
+ if name == "numpy_version":
+ return self._numpy_version
if property[BundleProperty.REQUIRED]:
raise ValueError(f"unsupported property '{name}' is required in the bundle properties.")
@@ -142,5 +166,13 @@ def _set_property(self, name, property, value):
self._preprocessing = value
elif name == "postprocessing":
self._postprocessing = value
+ elif name == "version":
+ self._version = value
+ elif name == "monai_version":
+ self._monai_version = value
+ elif name == "pytorch_version":
+ self._pytorch_version = value
+ elif name == "numpy_version":
+ self._numpy_version = value
elif property[BundleProperty.REQUIRED]:
raise ValueError(f"unsupported property '{name}' is required in the bundle properties.")
diff --git a/tests/test_bundle_download.py b/tests/test_bundle_download.py
index 36e935bf08..2457af3229 100644
--- a/tests/test_bundle_download.py
+++ b/tests/test_bundle_download.py
@@ -16,12 +16,13 @@
import tempfile
import unittest
+import numpy as np
import torch
from parameterized import parameterized
import monai.networks.nets as nets
from monai.apps import check_hash
-from monai.bundle import ConfigParser, load
+from monai.bundle import ConfigParser, create_workflow, load
from tests.utils import (
SkipIfBeforePyTorchVersion,
assert_allclose,
@@ -64,6 +65,12 @@
"https://api.ngc.nvidia.com/v2/models/nvidia/monaihosting/brats_mri_segmentation/versions/0.3.9/files/brats_mri_segmentation_v0.3.9.zip",
]
+TEST_CASE_7 = [
+ "spleen_ct_segmentation",
+ "cuda" if torch.cuda.is_available() else "cpu",
+ {"spatial_dims": 3, "out_channels": 5},
+]
+
class TestDownload(unittest.TestCase):
@parameterized.expand([TEST_CASE_1, TEST_CASE_2])
@@ -146,7 +153,7 @@ def test_load_weights(self, bundle_files, bundle_name, repo, device, model_file)
net_args = json.load(f)["network_def"]
model_name = net_args["_target_"]
del net_args["_target_"]
- model = nets.__dict__[model_name](**net_args)
+ model = getattr(nets, model_name)(**net_args)
model.to(device)
model.load_state_dict(weights)
model.eval()
@@ -159,20 +166,58 @@ def test_load_weights(self, bundle_files, bundle_name, repo, device, model_file)
# load instantiated model directly and test, since the bundle has been downloaded,
# there is no need to input `repo`
+ _model_2 = getattr(nets, model_name)(**net_args)
model_2 = load(
name=bundle_name,
+ model=_model_2,
model_file=model_file,
bundle_dir=tempdir,
progress=False,
device=device,
net_name=model_name,
source="github",
- **net_args,
)
model_2.eval()
output_2 = model_2.forward(input_tensor)
assert_allclose(output_2, expected_output, atol=1e-4, rtol=1e-4, type_test=False)
+ @parameterized.expand([TEST_CASE_7])
+ @skip_if_quick
+ def test_load_weights_with_net_override(self, bundle_name, device, net_override):
+ with skip_if_downloading_fails():
+ # download bundle, and load weights from the downloaded path
+ with tempfile.TemporaryDirectory() as tempdir:
+ # load weights
+ model = load(name=bundle_name, bundle_dir=tempdir, source="monaihosting", progress=False, device=device)
+
+ # prepare data and test
+ input_tensor = torch.rand(1, 1, 96, 96, 96).to(device)
+ output = model(input_tensor)
+ model_path = f"{tempdir}/spleen_ct_segmentation/models/model.pt"
+ workflow = create_workflow(
+ config_file=f"{tempdir}/spleen_ct_segmentation/configs/train.json", workflow_type="train"
+ )
+ expected_model = workflow.network_def.to(device)
+ expected_model.load_state_dict(torch.load(model_path))
+ expected_output = expected_model(input_tensor)
+ assert_allclose(output, expected_output, atol=1e-4, rtol=1e-4, type_test=False)
+
+ # using net_override to override kwargs in network directly
+ model_2 = load(
+ name=bundle_name,
+ bundle_dir=tempdir,
+ source="monaihosting",
+ progress=False,
+ device=device,
+ **net_override,
+ )
+
+ # prepare data and test
+ input_tensor = torch.rand(1, 1, 96, 96, 96).to(device)
+ output = model_2(input_tensor)
+ expected_shape = (1, 5, 96, 96, 96)
+ np.testing.assert_equal(output.shape, expected_shape)
+
@parameterized.expand([TEST_CASE_5])
@skip_if_quick
@SkipIfBeforePyTorchVersion((1, 7, 1))
diff --git a/tests/test_bundle_utils.py b/tests/test_bundle_utils.py
index d92f6e517f..391a56bc3c 100644
--- a/tests/test_bundle_utils.py
+++ b/tests/test_bundle_utils.py
@@ -100,7 +100,21 @@ def test_load_config_zip(self):
self.assertEqual(p["test_dict"]["b"], "c")
def test_run(self):
- command_line_tests(["python", "-m", "monai.bundle", "run", "test", "--test", "$print('hello world')"])
+ command_line_tests(
+ [
+ "python",
+ "-m",
+ "monai.bundle",
+ "run",
+ "test",
+ "--test",
+ "$print('hello world')",
+ "--config_file",
+ self.test_name,
+ "--meta_file",
+ self.metadata_name,
+ ]
+ )
def test_load_config_ts(self):
# create a Torchscript zip of the bundle
diff --git a/tests/test_bundle_workflow.py b/tests/test_bundle_workflow.py
index 247ed5ecd4..4291eedf3f 100644
--- a/tests/test_bundle_workflow.py
+++ b/tests/test_bundle_workflow.py
@@ -95,7 +95,7 @@ def test_inference_config(self, config_file):
}
# test standard MONAI model-zoo config workflow
inferer = ConfigWorkflow(
- workflow="infer",
+ workflow_type="infer",
config_file=config_file,
logging_file=os.path.join(os.path.dirname(__file__), "testing_data", "logging.conf"),
**override,
@@ -106,7 +106,7 @@ def test_inference_config(self, config_file):
def test_train_config(self, config_file):
# test standard MONAI model-zoo config workflow
trainer = ConfigWorkflow(
- workflow="train",
+ workflow_type="train",
config_file=config_file,
logging_file=os.path.join(os.path.dirname(__file__), "testing_data", "logging.conf"),
init_id="initialize",
diff --git a/tests/test_fl_monai_algo.py b/tests/test_fl_monai_algo.py
index 026f7ca8b8..ca781ff166 100644
--- a/tests/test_fl_monai_algo.py
+++ b/tests/test_fl_monai_algo.py
@@ -36,7 +36,9 @@
{
"bundle_root": _data_dir,
"train_workflow": ConfigWorkflow(
- config_file=os.path.join(_data_dir, "config_fl_train.json"), workflow="train", logging_file=_logging_file
+ config_file=os.path.join(_data_dir, "config_fl_train.json"),
+ workflow_type="train",
+ logging_file=_logging_file,
),
"config_evaluate_filename": None,
"config_filters_filename": os.path.join(_data_dir, "config_fl_filters.json"),
@@ -54,7 +56,9 @@
{
"bundle_root": _data_dir,
"train_workflow": ConfigWorkflow(
- config_file=os.path.join(_data_dir, "config_fl_train.json"), workflow="train", logging_file=_logging_file
+ config_file=os.path.join(_data_dir, "config_fl_train.json"),
+ workflow_type="train",
+ logging_file=_logging_file,
),
"config_evaluate_filename": None,
"config_filters_filename": os.path.join(_data_dir, "config_fl_filters.json"),
@@ -66,7 +70,7 @@
"bundle_root": _data_dir,
"train_workflow": ConfigWorkflow(
config_file=os.path.join(_data_dir, "config_fl_train.json"),
- workflow="train",
+ workflow_type="train",
logging_file=_logging_file,
tracking={
"handlers_id": DEFAULT_HANDLERS_ID,
@@ -95,7 +99,7 @@
os.path.join(_data_dir, "config_fl_train.json"),
os.path.join(_data_dir, "config_fl_evaluate.json"),
],
- workflow="train",
+ workflow_type="train",
logging_file=_logging_file,
tracking="mlflow",
tracking_uri=path_to_uri(_data_dir) + "/mlflow_1",
@@ -130,7 +134,7 @@
os.path.join(_data_dir, "config_fl_train.json"),
os.path.join(_data_dir, "config_fl_evaluate.json"),
],
- workflow="train",
+ workflow_type="train",
logging_file=_logging_file,
),
"config_filters_filename": os.path.join(_data_dir, "config_fl_filters.json"),
@@ -141,7 +145,9 @@
{
"bundle_root": _data_dir,
"train_workflow": ConfigWorkflow(
- config_file=os.path.join(_data_dir, "config_fl_train.json"), workflow="train", logging_file=_logging_file
+ config_file=os.path.join(_data_dir, "config_fl_train.json"),
+ workflow_type="train",
+ logging_file=_logging_file,
),
"config_evaluate_filename": None,
"send_weight_diff": False,
@@ -161,7 +167,9 @@
{
"bundle_root": _data_dir,
"train_workflow": ConfigWorkflow(
- config_file=os.path.join(_data_dir, "config_fl_train.json"), workflow="train", logging_file=_logging_file
+ config_file=os.path.join(_data_dir, "config_fl_train.json"),
+ workflow_type="train",
+ logging_file=_logging_file,
),
"config_evaluate_filename": None,
"send_weight_diff": True,
diff --git a/tests/test_fl_monai_algo_dist.py b/tests/test_fl_monai_algo_dist.py
index f6dc626ad9..1302ab6618 100644
--- a/tests/test_fl_monai_algo_dist.py
+++ b/tests/test_fl_monai_algo_dist.py
@@ -41,15 +41,15 @@ def test_train(self):
pathjoin(_data_dir, "config_fl_evaluate.json"),
pathjoin(_data_dir, "multi_gpu_evaluate.json"),
]
- train_workflow = ConfigWorkflow(config_file=train_configs, workflow="train", logging_file=_logging_file)
+ train_workflow = ConfigWorkflow(config_file=train_configs, workflow_type="train", logging_file=_logging_file)
# simulate the case that this application has specific requirements for a bundle workflow
train_workflow.add_property(name="loader", required=True, config_id="train#training_transforms#0", desc="NA")
# initialize algo
algo = MonaiAlgo(
bundle_root=_data_dir,
- train_workflow=ConfigWorkflow(config_file=train_configs, workflow="train", logging_file=_logging_file),
- eval_workflow=ConfigWorkflow(config_file=eval_configs, workflow="train", logging_file=_logging_file),
+ train_workflow=ConfigWorkflow(config_file=train_configs, workflow_type="train", logging_file=_logging_file),
+ eval_workflow=ConfigWorkflow(config_file=eval_configs, workflow_type="train", logging_file=_logging_file),
config_filters_filename=pathjoin(_root_dir, "testing_data", "config_fl_filters.json"),
)
algo.initialize(extra={ExtraItems.CLIENT_NAME: "test_fl"})
@@ -90,7 +90,7 @@ def test_evaluate(self):
algo = MonaiAlgo(
bundle_root=_data_dir,
config_train_filename=None,
- eval_workflow=ConfigWorkflow(config_file=config_file, workflow="train", logging_file=_logging_file),
+ eval_workflow=ConfigWorkflow(config_file=config_file, workflow_type="train", logging_file=_logging_file),
config_filters_filename=pathjoin(_data_dir, "config_fl_filters.json"),
)
algo.initialize(extra={ExtraItems.CLIENT_NAME: "test_fl"})
diff --git a/tests/test_fl_monai_algo_stats.py b/tests/test_fl_monai_algo_stats.py
index e46b6b899a..307b3f539c 100644
--- a/tests/test_fl_monai_algo_stats.py
+++ b/tests/test_fl_monai_algo_stats.py
@@ -30,7 +30,7 @@
{
"bundle_root": _data_dir,
"workflow": ConfigWorkflow(
- workflow="train",
+ workflow_type="train",
config_file=os.path.join(_data_dir, "config_fl_stats_1.json"),
logging_file=_logging_file,
meta_file=None,
@@ -49,7 +49,7 @@
{
"bundle_root": _data_dir,
"workflow": ConfigWorkflow(
- workflow="train",
+ workflow_type="train",
config_file=[
os.path.join(_data_dir, "config_fl_stats_1.json"),
os.path.join(_data_dir, "config_fl_stats_2.json"),
diff --git a/tests/test_handler_mlflow.py b/tests/test_handler_mlflow.py
index d5578c01bc..92cf17eadb 100644
--- a/tests/test_handler_mlflow.py
+++ b/tests/test_handler_mlflow.py
@@ -255,7 +255,7 @@ def test_dataset_tracking(self):
meta_file = os.path.join(bundle_root, "configs/metadata.json")
logging_file = os.path.join(bundle_root, "configs/logging.conf")
workflow = ConfigWorkflow(
- workflow="infer",
+ workflow_type="infer",
config_file=config_file,
meta_file=meta_file,
logging_file=logging_file,
diff --git a/tests/test_integration_bundle_run.py b/tests/test_integration_bundle_run.py
index 74ac93bc27..42abc1a5e0 100644
--- a/tests/test_integration_bundle_run.py
+++ b/tests/test_integration_bundle_run.py
@@ -53,6 +53,7 @@ def tearDown(self):
def test_tiny(self):
config_file = os.path.join(self.data_dir, "tiny_config.json")
+ meta_file = os.path.join(self.data_dir, "tiny_meta.json")
with open(config_file, "w") as f:
json.dump(
{
@@ -62,14 +63,25 @@ def test_tiny(self):
},
f,
)
+ with open(meta_file, "w") as f:
+ json.dump(
+ {"version": "0.1.0", "monai_version": "1.1.0", "pytorch_version": "1.13.1", "numpy_version": "1.22.2"},
+ f,
+ )
cmd = ["coverage", "run", "-m", "monai.bundle"]
# test both CLI entry "run" and "run_workflow"
- command_line_tests(cmd + ["run", "training", "--config_file", config_file])
- command_line_tests(cmd + ["run_workflow", "--run_id", "training", "--config_file", config_file])
+ command_line_tests(cmd + ["run", "training", "--config_file", config_file, "--meta_file", meta_file])
+ command_line_tests(
+ cmd + ["run_workflow", "--run_id", "training", "--config_file", config_file, "--meta_file", meta_file]
+ )
with self.assertRaises(RuntimeError):
# test wrong run_id="run"
command_line_tests(cmd + ["run", "run", "--config_file", config_file])
+ with self.assertRaises(RuntimeError):
+ # test missing meta file
+ command_line_tests(cmd + ["run", "training", "--config_file", config_file])
+
@parameterized.expand([TEST_CASE_1, TEST_CASE_2])
def test_shape(self, config_file, expected_shape):
test_image = np.random.rand(*expected_shape)
@@ -147,7 +159,7 @@ def test_customized_workflow(self):
filename = os.path.join(self.data_dir, "image.nii")
nib.save(nib.Nifti1Image(test_image, np.eye(4)), filename)
- cmd = "-m fire monai.bundle.scripts run_workflow --workflow tests.nonconfig_workflow.NonConfigWorkflow"
+ cmd = "-m fire monai.bundle.scripts run_workflow --workflow_name tests.nonconfig_workflow.NonConfigWorkflow"
cmd += f" --filename {filename} --output_dir {self.data_dir}"
command_line_tests(["coverage", "run"] + cmd.split(" "))
loader = LoadImage(image_only=True)
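
For reference, a sketch of the stricter behaviour this integration test checks: with the reworked `run`, a metadata file must be resolvable next to the config file (or passed explicitly), otherwise the CLI call fails. Illustrative file names matching the test above:

    import subprocess

    subprocess.run(
        ["python", "-m", "monai.bundle", "run", "training",
         "--config_file", "tiny_config.json", "--meta_file", "tiny_meta.json"],
        check=True,
    )
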