Skip to content

Commit

Permalink
auto updates (#6324)
Browse files Browse the repository at this point in the history
Signed-off-by: monai-bot <[email protected]>

(closing #6318 closing #6319 auto3dseg)
(closing #6314 closing #6315 dtype conversion)
(closing #6326 closing #6329 metatensor clone)
(including a workaround for #6311)

---------

Signed-off-by: monai-bot <[email protected]>
Signed-off-by: Mingxin Zheng <[email protected]>
Signed-off-by: Liam Chalcroft <[email protected]>
Signed-off-by: Wenqi Li <[email protected]>
Signed-off-by: KumoLiu <[email protected]>
Co-authored-by: Mingxin Zheng <[email protected]>
Co-authored-by: Liam Chalcroft <[email protected]>
Co-authored-by: Wenqi Li <[email protected]>
Co-authored-by: YunLiu <[email protected]>
Co-authored-by: Wenqi Li <[email protected]>
  • Loading branch information
6 people authored Apr 11, 2023
1 parent e4b313d commit 9ef42ff
Show file tree
Hide file tree
Showing 18 changed files with 100 additions and 63 deletions.
20 changes: 10 additions & 10 deletions monai/apps/auto3dseg/auto_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@
from monai.auto3dseg.utils import algo_to_pickle
from monai.bundle import ConfigParser
from monai.transforms import SaveImage
from monai.utils.enums import AlgoEnsembleKeys
from monai.utils.enums import AlgoKeys
from monai.utils.module import look_up_option, optional_import

logger = get_logger(module_name=__name__)
Expand Down Expand Up @@ -636,11 +636,11 @@ def _train_algo_in_sequence(self, history: list[dict[str, Any]]) -> None:
progress.yaml, accuracies in CSV and a pickle file of the Algo object.
"""
for algo_dict in history:
algo = algo_dict[AlgoEnsembleKeys.ALGO]
algo = algo_dict[AlgoKeys.ALGO]
algo.train(self.train_params)
acc = algo.get_score()

algo_meta_data = {str(AlgoEnsembleKeys.SCORE): acc}
algo_meta_data = {str(AlgoKeys.SCORE): acc}
algo_to_pickle(algo, template_path=algo.template_path, **algo_meta_data)

def _train_algo_in_nni(self, history: list[dict[str, Any]]) -> None:
Expand Down Expand Up @@ -675,8 +675,8 @@ def _train_algo_in_nni(self, history: list[dict[str, Any]]) -> None:
last_total_tasks = len(import_bundle_algo_history(self.work_dir, only_trained=True))
mode_dry_run = self.hpo_params.pop("nni_dry_run", False)
for algo_dict in history:
name = algo_dict[AlgoEnsembleKeys.ID]
algo = algo_dict[AlgoEnsembleKeys.ALGO]
name = algo_dict[AlgoKeys.ID]
algo = algo_dict[AlgoKeys.ALGO]
nni_gen = NNIGen(algo=algo, params=self.hpo_params)
obj_filename = nni_gen.get_obj_filename()
nni_config = deepcopy(default_nni_config)
Expand Down Expand Up @@ -772,13 +772,13 @@ def run(self):
)

if auto_train_choice:
skip_algos = [h[AlgoEnsembleKeys.ID] for h in history if h["is_trained"]]
skip_algos = [h[AlgoKeys.ID] for h in history if h[AlgoKeys.IS_TRAINED]]
if len(skip_algos) > 0:
logger.info(
f"Skipping already trained algos {skip_algos}."
"Set option train=True to always retrain all algos."
)
history = [h for h in history if not h["is_trained"]]
history = [h for h in history if not h[AlgoKeys.IS_TRAINED]]

if len(history) > 0:
if not self.hpo:
Expand All @@ -794,13 +794,13 @@ def run(self):
if self.ensemble:
history = import_bundle_algo_history(self.work_dir, only_trained=False)

history_untrained = [h for h in history if not h["is_trained"]]
history_untrained = [h for h in history if not h[AlgoKeys.IS_TRAINED]]
if len(history_untrained) > 0:
warnings.warn(
f"Ensembling step will skip {[h['name'] for h in history_untrained]} untrained algos."
"Generally it means these algos did not complete training."
)
history = [h for h in history if h["is_trained"]]
history = [h for h in history if h[AlgoKeys.IS_TRAINED]]

if len(history) == 0:
raise ValueError(
Expand All @@ -816,7 +816,7 @@ def run(self):
if len(preds) > 0:
logger.info("Auto3Dseg picked the following networks to ensemble:")
for algo in ensembler.get_algo_ensemble():
logger.info(algo[AlgoEnsembleKeys.ID])
logger.info(algo[AlgoKeys.ID])

for pred in preds:
self.save_image(pred)
Expand Down
4 changes: 2 additions & 2 deletions monai/apps/auto3dseg/bundle_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
from monai.auto3dseg.utils import algo_to_pickle
from monai.bundle.config_parser import ConfigParser
from monai.utils import ensure_tuple
from monai.utils.enums import AlgoEnsembleKeys
from monai.utils.enums import AlgoKeys

logger = get_logger(module_name=__name__)
ALGO_HASH = os.environ.get("MONAI_ALGO_HASH", "7758ad1")
Expand Down Expand Up @@ -539,5 +539,5 @@ def generate(

algo_to_pickle(gen_algo, template_path=algo.template_path)
self.history.append(
{AlgoEnsembleKeys.ID: name, AlgoEnsembleKeys.ALGO: gen_algo}
{AlgoKeys.ID: name, AlgoKeys.ALGO: gen_algo}
) # track the previous, may create a persistent history
22 changes: 11 additions & 11 deletions monai/apps/auto3dseg/ensemble_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
from monai.auto3dseg.utils import datafold_read
from monai.bundle import ConfigParser
from monai.transforms import MeanEnsemble, VoteEnsemble
from monai.utils.enums import AlgoEnsembleKeys
from monai.utils.enums import AlgoKeys
from monai.utils.misc import prob2class
from monai.utils.module import look_up_option

Expand Down Expand Up @@ -59,7 +59,7 @@ def get_algo(self, identifier):
identifier: the name of the bundleAlgo
"""
for algo in self.algos:
if identifier == algo[AlgoEnsembleKeys.ID]:
if identifier == algo[AlgoKeys.ID]:
return algo

def get_algo_ensemble(self):
Expand Down Expand Up @@ -160,7 +160,7 @@ def __call__(self, pred_param: dict[str, Any] | None = None) -> list[torch.Tenso
print(i)
preds = []
for algo in self.algo_ensemble:
infer_instance = algo[AlgoEnsembleKeys.ALGO]
infer_instance = algo[AlgoKeys.ALGO]
pred = infer_instance.predict(predict_files=[file], predict_params=param)
preds.append(pred[0])
outputs.append(self.ensemble_pred(preds, sigmoid=sigmoid))
Expand All @@ -187,7 +187,7 @@ def sort_score(self):
"""
Sort the best_metrics
"""
scores = concat_val_to_np(self.algos, [AlgoEnsembleKeys.SCORE])
scores = concat_val_to_np(self.algos, [AlgoKeys.SCORE])
return np.argsort(scores).tolist()

def collect_algos(self, n_best: int = -1) -> None:
Expand Down Expand Up @@ -238,14 +238,14 @@ def collect_algos(self) -> None:
best_model: BundleAlgo | None = None
for algo in self.algos:
# algorithm folder: {net}_{fold_index}_{other}
identifier = algo[AlgoEnsembleKeys.ID].split("_")[1]
identifier = algo[AlgoKeys.ID].split("_")[1]
try:
algo_id = int(identifier)
except ValueError as err:
raise ValueError(f"model identifier {identifier} is not number.") from err
if algo_id == f_idx and algo[AlgoEnsembleKeys.SCORE] > best_score:
if algo_id == f_idx and algo[AlgoKeys.SCORE] > best_score:
best_model = algo
best_score = algo[AlgoEnsembleKeys.SCORE]
best_score = algo[AlgoKeys.SCORE]
self.algo_ensemble.append(best_model)


Expand All @@ -268,7 +268,7 @@ class AlgoEnsembleBuilder:
"""

def __init__(self, history: Sequence[dict[str, Any]], data_src_cfg_filename: str | None = None):
self.infer_algos: list[dict[AlgoEnsembleKeys, Any]] = []
self.infer_algos: list[dict[AlgoKeys, Any]] = []
self.ensemble: AlgoEnsemble
self.data_src_cfg = ConfigParser(globals=False)

Expand All @@ -278,8 +278,8 @@ def __init__(self, history: Sequence[dict[str, Any]], data_src_cfg_filename: str
for algo_dict in history:
# load inference_config_paths

name = algo_dict[AlgoEnsembleKeys.ID]
gen_algo = algo_dict[AlgoEnsembleKeys.ALGO]
name = algo_dict[AlgoKeys.ID]
gen_algo = algo_dict[AlgoKeys.ALGO]

best_metric = gen_algo.get_score()
algo_path = gen_algo.output_path
Expand All @@ -306,7 +306,7 @@ def add_inferer(self, identifier: str, gen_algo: BundleAlgo, best_metric: float
if best_metric is None:
raise ValueError("Feature to re-validate is to be implemented")

algo = {AlgoEnsembleKeys.ID: identifier, AlgoEnsembleKeys.ALGO: gen_algo, AlgoEnsembleKeys.SCORE: best_metric}
algo = {AlgoKeys.ID: identifier, AlgoKeys.ALGO: gen_algo, AlgoKeys.SCORE: best_metric}
self.infer_algos.append(algo)

def set_ensemble_method(self, ensemble: AlgoEnsemble, *args: Any, **kwargs: Any) -> None:
Expand Down
10 changes: 5 additions & 5 deletions monai/apps/auto3dseg/hpo_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from monai.bundle.config_parser import ConfigParser
from monai.config import PathLike
from monai.utils import optional_import
from monai.utils.enums import AlgoEnsembleKeys
from monai.utils.enums import AlgoKeys

nni, has_nni = optional_import("nni")
optuna, has_optuna = optional_import("optuna")
Expand Down Expand Up @@ -99,8 +99,8 @@ class NNIGen(HPOGen):
# Bundle Algorithms are already generated by BundleGen in work_dir
import_bundle_algo_history(work_dir, only_trained=False)
algo_dict = self.history[0] # pick the first algorithm
algo_name = algo_dict[AlgoEnsembleKeys.ID]
onealgo = algo_dict[AlgoEnsembleKeys.ALGO]
algo_name = algo_dict[AlgoKeys.ID]
onealgo = algo_dict[AlgoKeys.ALGO]
nni_gen = NNIGen(algo=onealgo)
nni_gen.print_bundle_algo_instruction()
Expand Down Expand Up @@ -238,7 +238,7 @@ def run_algo(self, obj_filename: str, output_folder: str = ".", template_path: P
self.algo.train(self.params)
# step 4 report validation acc to controller
acc = self.algo.get_score()
algo_meta_data = {str(AlgoEnsembleKeys.SCORE): acc}
algo_meta_data = {str(AlgoKeys.SCORE): acc}

if isinstance(self.algo, BundleAlgo):
algo_to_pickle(self.algo, template_path=self.algo.template_path, **algo_meta_data)
Expand Down Expand Up @@ -411,7 +411,7 @@ def run_algo(self, obj_filename: str, output_folder: str = ".", template_path: P
self.algo.train(self.params)
# step 4 report validation acc to controller
acc = self.algo.get_score()
algo_meta_data = {str(AlgoEnsembleKeys.SCORE): acc}
algo_meta_data = {str(AlgoKeys.SCORE): acc}
if isinstance(self.algo, BundleAlgo):
algo_to_pickle(self.algo, template_path=self.algo.template_path, **algo_meta_data)
else:
Expand Down
13 changes: 4 additions & 9 deletions monai/apps/auto3dseg/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@

from monai.apps.auto3dseg.bundle_gen import BundleAlgo
from monai.auto3dseg import algo_from_pickle, algo_to_pickle
from monai.utils.enums import AlgoEnsembleKeys
from monai.utils.enums import AlgoKeys


def import_bundle_algo_history(
Expand Down Expand Up @@ -49,17 +49,12 @@ def import_bundle_algo_history(
if isinstance(algo, BundleAlgo): # algo's template path needs override
algo.template_path = algo_meta_data["template_path"]

best_metric = algo_meta_data.get(AlgoEnsembleKeys.SCORE, None)
best_metric = algo_meta_data.get(AlgoKeys.SCORE, None)
is_trained = best_metric is not None

if (only_trained and is_trained) or not only_trained:
history.append(
{
AlgoEnsembleKeys.ID: name,
AlgoEnsembleKeys.ALGO: algo,
AlgoEnsembleKeys.SCORE: best_metric,
"is_trained": is_trained,
}
{AlgoKeys.ID: name, AlgoKeys.ALGO: algo, AlgoKeys.SCORE: best_metric, AlgoKeys.IS_TRAINED: is_trained}
)

return history
Expand All @@ -73,5 +68,5 @@ def export_bundle_algo_history(history: list[dict[str, BundleAlgo]]) -> None:
history: a List of Bundle. Typically, the history can be obtained from BundleGen get_history method
"""
for algo_dict in history:
algo = algo_dict[AlgoEnsembleKeys.ALGO]
algo = algo_dict[AlgoKeys.ALGO]
algo_to_pickle(algo, template_path=algo.template_path)
13 changes: 10 additions & 3 deletions monai/data/meta_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -510,9 +510,16 @@ def new_empty(self, size, dtype=None, device=None, requires_grad=False):
self.as_tensor().new_empty(size=size, dtype=dtype, device=device, requires_grad=requires_grad)
)

def clone(self):
"""returns a copy of the MetaTensor instance."""
new_inst = MetaTensor(self.as_tensor().clone())
def clone(self, **kwargs):
"""
Returns a copy of the MetaTensor instance.
Args:
kwargs: additional keyword arguments to `torch.clone`.
See also: https://pytorch.org/docs/stable/generated/torch.clone.html
"""
new_inst = MetaTensor(self.as_tensor().clone(**kwargs))
new_inst.__dict__ = deepcopy(self.__dict__)
return new_inst

Expand Down
12 changes: 7 additions & 5 deletions monai/inferers/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@


def sliding_window_inference(
inputs: torch.Tensor,
inputs: torch.Tensor | MetaTensor,
roi_size: Sequence[int] | int,
sw_batch_size: int,
predictor: Callable[..., torch.Tensor | Sequence[torch.Tensor] | dict[Any, torch.Tensor]],
Expand Down Expand Up @@ -307,9 +307,11 @@ def sliding_window_inference(
output_image_list[ss] = output_i[(slice(None), slice(None), *final_slicing)]

final_output = _pack_struct(output_image_list, dict_keys)
final_output = convert_to_dst_type(final_output, inputs, device=device)[0] # type: ignore
if temp_meta is not None:
final_output = MetaTensor(final_output).copy_meta_from(temp_meta)
final_output = convert_to_dst_type(final_output, temp_meta, device=device)[0] # type: ignore
else:
final_output = convert_to_dst_type(final_output, inputs, device=device)[0]

return final_output # type: ignore


Expand All @@ -322,7 +324,7 @@ def _create_buffered_slices(slices, batch_size, sw_batch_size, buffer_dim, buffe

_, _, _b_lens = np.unique(slices_np[:, 0], return_counts=True, return_index=True)
b_ends = np.cumsum(_b_lens).tolist() # possible buffer flush boundaries
x = [0, *b_ends][:: min(len(b_ends), int(buffer_steps))] # type: ignore
x = [0, *b_ends][:: min(len(b_ends), int(buffer_steps))]
if x[-1] < b_ends[-1]:
x.append(b_ends[-1])
n_per_batch = len(x) - 1
Expand Down Expand Up @@ -385,7 +387,7 @@ def _flatten_struct(seg_out):
dict_keys = sorted(seg_out.keys()) # track predictor's output keys
seg_probs = tuple(seg_out[k] for k in dict_keys)
else:
seg_probs = ensure_tuple(seg_out) # type: ignore
seg_probs = ensure_tuple(seg_out)
return dict_keys, seg_probs


Expand Down
2 changes: 1 addition & 1 deletion monai/networks/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -672,7 +672,7 @@ def convert_to_onnx(
set_determinism(seed=None)
# compare onnx/ort and PyTorch results
for r1, r2 in zip(torch_out, onnx_out):
torch.testing.assert_allclose(r1.cpu(), r2, rtol=rtol, atol=atol) # type: ignore
torch.testing.assert_allclose(r1.cpu(), r2, rtol=rtol, atol=atol)

return onnx_model

Expand Down
1 change: 1 addition & 0 deletions monai/transforms/spatial/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,6 +230,7 @@ def inverse(self, data: torch.Tensor) -> torch.Tensor:
with self.trace_transform(False):
# we can't use `self.__call__` in case a child class calls this inverse.
out: torch.Tensor = SpatialResample.__call__(self, data, **kw_args)
kw_args["src_affine"] = kw_args.get("dst_affine")
return out


Expand Down
19 changes: 19 additions & 0 deletions monai/utils/enums.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
import random
from enum import Enum

from monai.utils import deprecated

__all__ = [
"StrEnum",
"NumpyPadMode",
Expand Down Expand Up @@ -56,6 +58,7 @@
"LazyAttr",
"BundleProperty",
"BundlePropertyConfig",
"AlgoKeys",
]


Expand Down Expand Up @@ -592,6 +595,7 @@ class LabelStatsKeys(StrEnum):
LABEL_NCOMP = "ncomponents"


@deprecated(since="1.2", msg_suffix="please use `AlgoKeys` instead.")
class AlgoEnsembleKeys(StrEnum):
"""
Default keys for Mixed Ensemble
Expand Down Expand Up @@ -664,3 +668,18 @@ class BundlePropertyConfig(StrEnum):

ID = "id"
REF_ID = "refer_id"


class AlgoKeys(StrEnum):
"""
Default keys for templated Auto3DSeg Algo.
`ID` is the identifier of the algorithm. The string has the format of <name>_<idx>_<other>.
`ALGO` is the Auto3DSeg Algo instance.
`IS_TRAINED` is the status that shows if the Algo has been trained.
`SCORE` is the score the Algo has achieved after training.
"""

ID = "identifier"
ALGO = "algo_instance"
IS_TRAINED = "is_trained"
SCORE = "best_metric"
2 changes: 1 addition & 1 deletion monai/utils/type_conversion.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@

def get_numpy_dtype_from_string(dtype: str) -> np.dtype:
"""Get a numpy dtype (e.g., `np.float32`) from its string (e.g., `"float32"`)."""
return np.empty([], dtype=dtype).dtype
return np.empty([], dtype=str(dtype).split(".")[-1]).dtype


def get_torch_dtype_from_string(dtype: str) -> torch.dtype:
Expand Down
Loading

0 comments on commit 9ef42ff

Please sign in to comment.