Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Revert "Auto3DSeg skip trained algos" #6295

Merged
merged 16 commits into from
Apr 5, 2023
20 changes: 7 additions & 13 deletions monai/apps/auto3dseg/auto_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -281,7 +281,7 @@ def __init__(
# determine if we need to analyze, algo_gen or train from cache, unless manually provided
self.analyze = not self.cache["analyze"] if analyze is None else analyze
self.algo_gen = not self.cache["algo_gen"] if algo_gen is None else algo_gen
self.train = train
self.train = not self.cache["train"] if train is None else train
self.ensemble = ensemble # last step, no need to check

self.set_training_params()
Expand Down Expand Up @@ -758,8 +758,7 @@ def run(self):
logger.info("Skipping algorithm generation...")

# step 3: algo training
auto_train_choice = self.train is None
if self.train or (auto_train_choice and not self.cache["train"]):
if self.train:
history = import_bundle_algo_history(self.work_dir, only_trained=False)

if len(history) == 0:
Expand All @@ -768,15 +767,10 @@ def run(self):
"Possibly the required algorithms generation step was not completed."
)

if auto_train_choice:
history = [h for h in history if not h["is_trained"]] # skip trained

if len(history) > 0:
if not self.hpo:
self._train_algo_in_sequence(history)
else:
self._train_algo_in_nni(history)

if not self.hpo:
self._train_algo_in_sequence(history)
else:
self._train_algo_in_nni(history)
self.export_cache(train=True)
else:
logger.info("Skipping algorithm training...")
Expand Down Expand Up @@ -804,4 +798,4 @@ def run(self):
self.save_image(pred)
logger.info(f"Auto3Dseg ensemble prediction outputs are saved in {self.output_dir}.")

logger.info("Auto3Dseg pipeline is completed successfully.")
logger.info("Auto3Dseg pipeline is complete successfully.")
9 changes: 3 additions & 6 deletions monai/apps/auto3dseg/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,14 +47,11 @@ def import_bundle_algo_history(
if isinstance(algo, BundleAlgo): # algo's template path needs override
algo.template_path = algo_meta_data["template_path"]

best_metrics = "best_metrics"
is_trained = best_metrics in algo_meta_data

if only_trained:
if is_trained:
history.append({name: algo, "is_trained": is_trained, best_metrics: algo_meta_data[best_metrics]})
if "best_metrics" in algo_meta_data:
history.append({name: algo})
else:
history.append({name: algo, "is_trained": is_trained, best_metrics: algo_meta_data.get(best_metrics, None)})
history.append({name: algo})

return history

Expand Down