Skip to content

Commit

Permalink
Auto3DSeg continue training (skip trained algos) (#6310)
Browse files Browse the repository at this point in the history
Second PR for issue #6291, since the previous PR #6290
was reverted in #6295.

Allows skipping the already trained algos and continuing training only
for the untrained ones.

After this PR, the default option AutoRunner(train=None) will have this
behavior, whereas manually setting AutoRunner(train=True/False) will
always train all algos or skip all training. Previously we could only
train all or skip all (without any option to resume).

I changed import_bundle_algo_history() to return a better algo_dict.

Previously it returned "list[dict(name: algo)]" - a list of dicts, where
each dict had to contain a single "name => algo" entry. Now it returns a
list of dicts, each with several keys: dict(AlgoEnsembleKeys.ID: name,
AlgoEnsembleKeys.ALGO: algo, "is_trained": bool, etc).
This allows putting additional metadata inside each algo_dict, and
makes it easier to read back.

Previously, to get a name we had to use "name = list(history[0].keys())[0]";
now it's the more elegant "name = history[0][AlgoEnsembleKeys.ID]".

This, however, required changing many files: everywhere that
import_bundle_algo_history and export_bundle_algo_history were used.

All the tests have passed, except for the "integration GPU utilization
tests", but those errors seem unrelated.



After this PR, the tutorials need to be updated too:
Project-MONAI/tutorials#1288

---------

Signed-off-by: myron <[email protected]>
  • Loading branch information
myron authored Apr 6, 2023
1 parent 06defb7 commit e4b313d
Show file tree
Hide file tree
Showing 11 changed files with 162 additions and 106 deletions.
36 changes: 26 additions & 10 deletions .github/workflows/integration.yml
Original file line number Diff line number Diff line change
Expand Up @@ -41,27 +41,33 @@ jobs:
python -m pip install --upgrade torch torchvision torchaudio
python -m pip install -r requirements-dev.txt
rm -rf /github/home/.cache/torch/hub/mmars/
- name: Run integration tests
- name: Clean directory
run: |
python -m pip list
git config --global --add safe.directory /__w/MONAI/MONAI
git clean -ffdx
nvidia-smi
export CUDA_VISIBLE_DEVICES=$(python -m tests.utils -c 1 | tail -n 1)
echo $CUDA_VISIBLE_DEVICES
trap 'if pgrep python; then pkill python; fi;' ERR
python -c $'import torch\na=[torch.zeros(1,device=f"cuda:{i}") for i in range(torch.cuda.device_count())];\nwhile True:print(a)' > /dev/null &
python -c "import torch; print(torch.__version__); print('{} of GPUs available'.format(torch.cuda.device_count()))"
python -c 'import torch; print(torch.rand(5,3, device=torch.device("cuda:0")))'
# test auto3dseg
echo "test tag algo"
- name: Auto3dseg tag algo
shell: bash
env:
BUILD_MONAI: 0
run: |
BUILD_MONAI=0 ./runtests.sh --build
python -m tests.test_auto3dseg_ensemble
python -m tests.test_auto3dseg_hpo
python -m tests.test_integration_autorunner
python -m tests.test_integration_gpu_customization
- name: Auto3dseg latest algo
shell: bash
env:
BUILD_MONAI: 0
run: |
# test latest template
echo "test latest algo"
cd ../
Expand All @@ -81,14 +87,24 @@ jobs:
python -m tests.test_integration_autorunner
python -m tests.test_integration_gpu_customization
# the other tests
echo "the other tests"
- name: Integration tests
shell: bash
env:
BUILD_MONAI: 1
run: |
pwd
ls -ll
BUILD_MONAI=1 ./runtests.sh --build --net
BUILD_MONAI=1 ./runtests.sh --build --unittests
if pgrep python; then pkill python; fi
./runtests.sh --build --net
- name: Unit tests
shell: bash
env:
BUILD_MONAI: 1
run: |
pwd
ls -ll
./runtests.sh --unittests
- name: Add reaction
uses: peter-evans/create-or-update-comment@v2
if: github.event.pull_request.number != ''
Expand Down
118 changes: 71 additions & 47 deletions monai/apps/auto3dseg/auto_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -281,7 +281,7 @@ def __init__(
# determine if we need to analyze, algo_gen or train from cache, unless manually provided
self.analyze = not self.cache["analyze"] if analyze is None else analyze
self.algo_gen = not self.cache["algo_gen"] if algo_gen is None else algo_gen
self.train = not self.cache["train"] if train is None else train
self.train = train
self.ensemble = ensemble # last step, no need to check

self.set_training_params()
Expand Down Expand Up @@ -635,13 +635,15 @@ def _train_algo_in_sequence(self, history: list[dict[str, Any]]) -> None:
folders under the working directory. The results include the model checkpoints, a
progress.yaml, accuracies in CSV and a pickle file of the Algo object.
"""
for task in history:
for _, algo in task.items():
algo.train(self.train_params)
acc = algo.get_score()
algo_to_pickle(algo, template_path=algo.template_path, best_metrics=acc)
for algo_dict in history:
algo = algo_dict[AlgoEnsembleKeys.ALGO]
algo.train(self.train_params)
acc = algo.get_score()

def _train_algo_in_nni(self, history):
algo_meta_data = {str(AlgoEnsembleKeys.SCORE): acc}
algo_to_pickle(algo, template_path=algo.template_path, **algo_meta_data)

def _train_algo_in_nni(self, history: list[dict[str, Any]]) -> None:
"""
Train the Algos using HPO.
Expand Down Expand Up @@ -672,40 +674,41 @@ def _train_algo_in_nni(self, history):

last_total_tasks = len(import_bundle_algo_history(self.work_dir, only_trained=True))
mode_dry_run = self.hpo_params.pop("nni_dry_run", False)
for task in history:
for name, algo in task.items():
nni_gen = NNIGen(algo=algo, params=self.hpo_params)
obj_filename = nni_gen.get_obj_filename()
nni_config = deepcopy(default_nni_config)
# override the default nni config with the same key in hpo_params
for key in self.hpo_params:
if key in nni_config:
nni_config[key] = self.hpo_params[key]
nni_config.update({"experimentName": name})
nni_config.update({"search_space": self.search_space})
trial_cmd = "python -m monai.apps.auto3dseg NNIGen run_algo " + obj_filename + " " + self.work_dir
nni_config.update({"trialCommand": trial_cmd})
nni_config_filename = os.path.abspath(os.path.join(self.work_dir, f"{name}_nni_config.yaml"))
ConfigParser.export_config_file(nni_config, nni_config_filename, fmt="yaml", default_flow_style=None)

max_trial = min(self.hpo_tasks, cast(int, default_nni_config["maxTrialNumber"]))
cmd = "nnictl create --config " + nni_config_filename + " --port 8088"

if mode_dry_run:
logger.info(f"AutoRunner HPO is in dry-run mode. Please manually launch: {cmd}")
continue

subprocess.run(cmd.split(), check=True)

for algo_dict in history:
name = algo_dict[AlgoEnsembleKeys.ID]
algo = algo_dict[AlgoEnsembleKeys.ALGO]
nni_gen = NNIGen(algo=algo, params=self.hpo_params)
obj_filename = nni_gen.get_obj_filename()
nni_config = deepcopy(default_nni_config)
# override the default nni config with the same key in hpo_params
for key in self.hpo_params:
if key in nni_config:
nni_config[key] = self.hpo_params[key]
nni_config.update({"experimentName": name})
nni_config.update({"search_space": self.search_space})
trial_cmd = "python -m monai.apps.auto3dseg NNIGen run_algo " + obj_filename + " " + self.work_dir
nni_config.update({"trialCommand": trial_cmd})
nni_config_filename = os.path.abspath(os.path.join(self.work_dir, f"{name}_nni_config.yaml"))
ConfigParser.export_config_file(nni_config, nni_config_filename, fmt="yaml", default_flow_style=None)

max_trial = min(self.hpo_tasks, cast(int, default_nni_config["maxTrialNumber"]))
cmd = "nnictl create --config " + nni_config_filename + " --port 8088"

if mode_dry_run:
logger.info(f"AutoRunner HPO is in dry-run mode. Please manually launch: {cmd}")
continue

subprocess.run(cmd.split(), check=True)

n_trainings = len(import_bundle_algo_history(self.work_dir, only_trained=True))
while n_trainings - last_total_tasks < max_trial:
sleep(1)
n_trainings = len(import_bundle_algo_history(self.work_dir, only_trained=True))
while n_trainings - last_total_tasks < max_trial:
sleep(1)
n_trainings = len(import_bundle_algo_history(self.work_dir, only_trained=True))

cmd = "nnictl stop --all"
subprocess.run(cmd.split(), check=True)
logger.info(f"NNI completes HPO on {name}")
last_total_tasks = n_trainings
cmd = "nnictl stop --all"
subprocess.run(cmd.split(), check=True)
logger.info(f"NNI completes HPO on {name}")
last_total_tasks = n_trainings

def run(self):
"""
Expand Down Expand Up @@ -758,7 +761,8 @@ def run(self):
logger.info("Skipping algorithm generation...")

# step 3: algo training
if self.train:
auto_train_choice = self.train is None
if self.train or (auto_train_choice and not self.cache["train"]):
history = import_bundle_algo_history(self.work_dir, only_trained=False)

if len(history) == 0:
Expand All @@ -767,20 +771,40 @@ def run(self):
"Possibly the required algorithms generation step was not completed."
)

if not self.hpo:
self._train_algo_in_sequence(history)
else:
self._train_algo_in_nni(history)
if auto_train_choice:
skip_algos = [h[AlgoEnsembleKeys.ID] for h in history if h["is_trained"]]
if len(skip_algos) > 0:
logger.info(
f"Skipping already trained algos {skip_algos}."
"Set option train=True to always retrain all algos."
)
history = [h for h in history if not h["is_trained"]]

if len(history) > 0:
if not self.hpo:
self._train_algo_in_sequence(history)
else:
self._train_algo_in_nni(history)

self.export_cache(train=True)
else:
logger.info("Skipping algorithm training...")

# step 4: model ensemble and write the prediction to disks.
if self.ensemble:
history = import_bundle_algo_history(self.work_dir, only_trained=True)
history = import_bundle_algo_history(self.work_dir, only_trained=False)

history_untrained = [h for h in history if not h["is_trained"]]
if len(history_untrained) > 0:
warnings.warn(
f"Ensembling step will skip {[h['name'] for h in history_untrained]} untrained algos."
"Generally it means these algos did not complete training."
)
history = [h for h in history if h["is_trained"]]

if len(history) == 0:
raise ValueError(
f"Could not find the trained results in {self.work_dir}. "
f"Could not find any trained algos in {self.work_dir}. "
"Possibly the required training step was not completed."
)

Expand All @@ -798,4 +822,4 @@ def run(self):
self.save_image(pred)
logger.info(f"Auto3Dseg ensemble prediction outputs are saved in {self.output_dir}.")

logger.info("Auto3Dseg pipeline is complete successfully.")
logger.info("Auto3Dseg pipeline is completed successfully.")
5 changes: 4 additions & 1 deletion monai/apps/auto3dseg/bundle_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
from monai.auto3dseg.utils import algo_to_pickle
from monai.bundle.config_parser import ConfigParser
from monai.utils import ensure_tuple
from monai.utils.enums import AlgoEnsembleKeys

logger = get_logger(module_name=__name__)
ALGO_HASH = os.environ.get("MONAI_ALGO_HASH", "7758ad1")
Expand Down Expand Up @@ -537,4 +538,6 @@ def generate(
gen_algo.export_to_disk(output_folder, name, fold=f_id)

algo_to_pickle(gen_algo, template_path=algo.template_path)
self.history.append({name: gen_algo}) # track the previous, may create a persistent history
self.history.append(
{AlgoEnsembleKeys.ID: name, AlgoEnsembleKeys.ALGO: gen_algo}
) # track the previous, may create a persistent history
12 changes: 5 additions & 7 deletions monai/apps/auto3dseg/ensemble_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -267,22 +267,20 @@ class AlgoEnsembleBuilder:
"""

def __init__(self, history: Sequence[dict], data_src_cfg_filename: str | None = None):
def __init__(self, history: Sequence[dict[str, Any]], data_src_cfg_filename: str | None = None):
self.infer_algos: list[dict[AlgoEnsembleKeys, Any]] = []
self.ensemble: AlgoEnsemble
self.data_src_cfg = ConfigParser(globals=False)

if data_src_cfg_filename is not None and os.path.exists(str(data_src_cfg_filename)):
self.data_src_cfg.read_config(data_src_cfg_filename)

for h in history:
for algo_dict in history:
# load inference_config_paths
# raise warning/error if not found
if len(h) > 1:
raise ValueError(f"{h} should only contain one set of genAlgo key-value")

name = list(h.keys())[0]
gen_algo = h[name]
name = algo_dict[AlgoEnsembleKeys.ID]
gen_algo = algo_dict[AlgoEnsembleKeys.ALGO]

best_metric = gen_algo.get_score()
algo_path = gen_algo.output_path
infer_path = os.path.join(algo_path, "scripts", "infer.py")
Expand Down
16 changes: 10 additions & 6 deletions monai/apps/auto3dseg/hpo_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from monai.bundle.config_parser import ConfigParser
from monai.config import PathLike
from monai.utils import optional_import
from monai.utils.enums import AlgoEnsembleKeys

nni, has_nni = optional_import("nni")
optuna, has_optuna = optional_import("optuna")
Expand Down Expand Up @@ -98,8 +99,8 @@ class NNIGen(HPOGen):
# Bundle Algorithms are already generated by BundleGen in work_dir
import_bundle_algo_history(work_dir, only_trained=False)
algo_dict = self.history[0] # pick the first algorithm
algo_name = list(algo_dict.keys())[0]
onealgo = algo_dict[algo_name]
algo_name = algo_dict[AlgoEnsembleKeys.ID]
onealgo = algo_dict[AlgoEnsembleKeys.ALGO]
nni_gen = NNIGen(algo=onealgo)
nni_gen.print_bundle_algo_instruction()
Expand Down Expand Up @@ -237,10 +238,12 @@ def run_algo(self, obj_filename: str, output_folder: str = ".", template_path: P
self.algo.train(self.params)
# step 4 report validation acc to controller
acc = self.algo.get_score()
algo_meta_data = {str(AlgoEnsembleKeys.SCORE): acc}

if isinstance(self.algo, BundleAlgo):
algo_to_pickle(self.algo, template_path=self.algo.template_path, best_metrics=acc)
algo_to_pickle(self.algo, template_path=self.algo.template_path, **algo_meta_data)
else:
algo_to_pickle(self.algo, best_metrics=acc)
algo_to_pickle(self.algo, **algo_meta_data)
self.set_score(acc)


Expand Down Expand Up @@ -408,8 +411,9 @@ def run_algo(self, obj_filename: str, output_folder: str = ".", template_path: P
self.algo.train(self.params)
# step 4 report validation acc to controller
acc = self.algo.get_score()
algo_meta_data = {str(AlgoEnsembleKeys.SCORE): acc}
if isinstance(self.algo, BundleAlgo):
algo_to_pickle(self.algo, template_path=self.algo.template_path, best_metrics=acc)
algo_to_pickle(self.algo, template_path=self.algo.template_path, **algo_meta_data)
else:
algo_to_pickle(self.algo, best_metrics=acc)
algo_to_pickle(self.algo, **algo_meta_data)
self.set_score(acc)
27 changes: 18 additions & 9 deletions monai/apps/auto3dseg/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,15 @@

from monai.apps.auto3dseg.bundle_gen import BundleAlgo
from monai.auto3dseg import algo_from_pickle, algo_to_pickle
from monai.utils.enums import AlgoEnsembleKeys


def import_bundle_algo_history(
output_folder: str = ".", template_path: str | None = None, only_trained: bool = True
) -> list:
"""
import the history of the bundleAlgo object with their names/identifiers
import the history of the bundleAlgo objects as a list of algo dicts.
each algo_dict has keys name (folder name), algo (bundleAlgo), is_trained (bool),
Args:
output_folder: the root path of the algorithms templates.
Expand All @@ -47,11 +49,18 @@ def import_bundle_algo_history(
if isinstance(algo, BundleAlgo): # algo's template path needs override
algo.template_path = algo_meta_data["template_path"]

if only_trained:
if "best_metrics" in algo_meta_data:
history.append({name: algo})
else:
history.append({name: algo})
best_metric = algo_meta_data.get(AlgoEnsembleKeys.SCORE, None)
is_trained = best_metric is not None

if (only_trained and is_trained) or not only_trained:
history.append(
{
AlgoEnsembleKeys.ID: name,
AlgoEnsembleKeys.ALGO: algo,
AlgoEnsembleKeys.SCORE: best_metric,
"is_trained": is_trained,
}
)

return history

Expand All @@ -63,6 +72,6 @@ def export_bundle_algo_history(history: list[dict[str, BundleAlgo]]) -> None:
Args:
history: a List of Bundle. Typically, the history can be obtained from BundleGen get_history method
"""
for task in history:
for _, algo in task.items():
algo_to_pickle(algo, template_path=algo.template_path)
for algo_dict in history:
algo = algo_dict[AlgoEnsembleKeys.ALGO]
algo_to_pickle(algo, template_path=algo.template_path)
Loading

0 comments on commit e4b313d

Please sign in to comment.