autorunner params from config #7175

Merged: 9 commits, Nov 3, 2023
79 changes: 52 additions & 27 deletions monai/apps/auto3dseg/auto_runner.py
@@ -214,22 +214,11 @@ def __init__(
         mlflow_tracking_uri: str | None = None,
         **kwargs: Any,
     ):
-        logger.info(f"AutoRunner using work directory {work_dir}")
-        os.makedirs(work_dir, exist_ok=True)
-
-        self.work_dir = os.path.abspath(work_dir)
-        self.data_src_cfg = dict()
-        self.data_src_cfg_name = os.path.join(self.work_dir, "input.yaml")
-        self.algos = algos
-        self.templates_path_or_url = templates_path_or_url
-        self.allow_skip = allow_skip
-        self.mlflow_tracking_uri = mlflow_tracking_uri
-        self.kwargs = deepcopy(kwargs)
-
-        if input is None and os.path.isfile(self.data_src_cfg_name):
-            input = self.data_src_cfg_name
+        if input is None and os.path.isfile(os.path.join(os.path.abspath(work_dir), "input.yaml")):
+            input = os.path.join(os.path.abspath(work_dir), "input.yaml")
             logger.info(f"Input config is not provided, using the default {input}")
 
+        self.data_src_cfg = dict()
         if isinstance(input, dict):
             self.data_src_cfg = input
         elif isinstance(input, str) and os.path.isfile(input):
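
Note on the hunk above: when no input is given, the fallback path is now resolved from work_dir directly rather than from self.data_src_cfg_name, because self.work_dir is no longer assigned this early in __init__. A minimal sketch of the resulting round trip, assuming PyYAML is available and the paths below are placeholders:

import os
import yaml  # assumption: PyYAML, which MONAI uses for its YAML configs

work_dir = "./work_dir"  # hypothetical working directory
os.makedirs(work_dir, exist_ok=True)

# A first run (or the user) exports input.yaml into the working directory;
# dataroot/datalist/modality are the keys AutoRunner requires (see the
# missing_keys check later in this diff).
cfg = {"modality": "MRI", "dataroot": "/data/Task04", "datalist": "datalist.json"}
with open(os.path.join(work_dir, "input.yaml"), "w") as f:
    yaml.safe_dump(cfg, f)

# A later run can then omit the input argument entirely:
# runner = AutoRunner(work_dir=work_dir)  # falls back to ./work_dir/input.yaml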
@@ -238,6 +227,51 @@ def __init__(
         else:
             raise ValueError(f"{input} is not a valid file or dict")
 
+        if "work_dir" in self.data_src_cfg:  # override from config
+            work_dir = self.data_src_cfg["work_dir"]
+        self.work_dir = os.path.abspath(work_dir)
+
+        logger.info(f"AutoRunner using work directory {self.work_dir}")
+        os.makedirs(self.work_dir, exist_ok=True)
+        self.data_src_cfg_name = os.path.join(self.work_dir, "input.yaml")
+
+        self.algos = algos
+        self.templates_path_or_url = templates_path_or_url
+        self.allow_skip = allow_skip
+
+        # cache.yaml
+        self.not_use_cache = not_use_cache
+        self.cache_filename = os.path.join(self.work_dir, "cache.yaml")
+        self.cache = self.read_cache()
+        self.export_cache()
+
+        # determine if we need to analyze, algo_gen or train from cache, unless manually provided
+        self.analyze = not self.cache["analyze"] if analyze is None else analyze
+        self.algo_gen = not self.cache["algo_gen"] if algo_gen is None else algo_gen
+        self.train = train
+        self.ensemble = ensemble  # last step, no need to check
+        self.hpo = hpo and has_nni
+        self.hpo_backend = hpo_backend
+        self.mlflow_tracking_uri = mlflow_tracking_uri
+        self.kwargs = deepcopy(kwargs)
+
+        # parse input config for AutoRunner param overrides
+        for param in [
+            "analyze",
+            "algo_gen",
+            "train",
+            "hpo",
+            "ensemble",
+            "not_use_cache",
+            "allow_skip",
+        ]:  # override from config
+            if param in self.data_src_cfg and isinstance(self.data_src_cfg[param], bool):
+                setattr(self, param, self.data_src_cfg[param])  # e.g. self.analyze = self.data_src_cfg["analyze"]
+
+        for param in ["algos", "hpo_backend", "templates_path_or_url", "mlflow_tracking_uri"]:  # override from config
+            if param in self.data_src_cfg:
+                setattr(self, param, self.data_src_cfg[param])  # e.g. self.algos = self.data_src_cfg["algos"]
+
         missing_keys = {"dataroot", "datalist", "modality"}.difference(self.data_src_cfg.keys())
         if len(missing_keys) > 0:
             raise ValueError(f"Config keys are missing {missing_keys}")
@@ -256,6 +290,8 @@ def __init__(

         # inspect and update folds
         num_fold = self.inspect_datalist_folds(datalist_filename=datalist_filename)
+        if "num_fold" in self.data_src_cfg:
+            num_fold = int(self.data_src_cfg["num_fold"])  # override from config
 
         self.data_src_cfg["datalist"] = datalist_filename  # update path to a version in work_dir and save user input
         ConfigParser.export_config_file(
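
With the hunk above, a "num_fold" key in the input config takes precedence over the fold count inferred from the datalist; the int() cast also accepts a numeric string. A sketch, with placeholder paths and an illustrative fold count:

input_cfg = {
    "modality": "CT",
    "dataroot": "/data/Task09",   # placeholder paths
    "datalist": "datalist.json",
    "num_fold": 3,                # overrides the fold count inferred from the datalist
}
# runner = AutoRunner(input=input_cfg)  # cross-validation uses 3 folds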
@@ -266,17 +302,6 @@ def __init__(
         self.datastats_filename = os.path.join(self.work_dir, "datastats.yaml")
         self.datalist_filename = datalist_filename
 
-        self.not_use_cache = not_use_cache
-        self.cache_filename = os.path.join(self.work_dir, "cache.yaml")
-        self.cache = self.read_cache()
-        self.export_cache()
-
-        # determine if we need to analyze, algo_gen or train from cache, unless manually provided
-        self.analyze = not self.cache["analyze"] if analyze is None else analyze
-        self.algo_gen = not self.cache["algo_gen"] if algo_gen is None else algo_gen
-        self.train = train
-        self.ensemble = ensemble  # last step, no need to check
-
         self.set_training_params()
         self.set_device_info()
         self.set_prediction_params()
@@ -288,9 +313,9 @@ def __init__(
         self.gpu_customization_specs: dict[str, Any] = {}
 
         # hpo
-        if hpo_backend.lower() != "nni":
+        if self.hpo_backend.lower() != "nni":
             raise NotImplementedError("HPOGen backend only supports NNI")
-        self.hpo = hpo and has_nni
+        self.hpo = self.hpo and has_nni
         self.set_hpo_params()
         self.search_space: dict[str, dict[str, Any]] = {}
         self.hpo_tasks = 0
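
Because the backend check now reads self.hpo_backend, a value overridden from the config is validated the same way as a constructor argument; anything other than "nni" fails at construction time. Illustrative sketch:

# Illustrative: an unsupported backend coming from the config now raises,
# exactly like an unsupported constructor argument would.
bad_cfg = {
    "modality": "CT",
    "dataroot": "/data/Task09",
    "datalist": "datalist.json",
    "hpo_backend": "optuna",  # hypothetical value; only "nni" is supported
}
# AutoRunner(input=bad_cfg)  # -> NotImplementedError: HPOGen backend only supports NNI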
3 changes: 2 additions & 1 deletion tests/test_vis_gradcam.py
@@ -20,7 +20,7 @@

 from monai.networks.nets import DenseNet, DenseNet121, SEResNet50
 from monai.visualize import GradCAM, GradCAMpp
-from tests.utils import assert_allclose
+from tests.utils import assert_allclose, skip_if_quick
 
 
 class DenseNetAdjoint(DenseNet121):
@@ -147,6 +147,7 @@ def __call__(self, x, adjoint_info):
     TESTS_ILL.append([cam])
 
 
+@skip_if_quick
 class TestGradientClassActivationMap(unittest.TestCase):
     @parameterized.expand(TESTS)
     def test_shape(self, cam_class, input_data, expected_shape):
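
skip_if_quick comes from MONAI's tests/utils.py and skips the decorated test class during quick test runs. A rough equivalent of the helper, assuming it keys off a QUICKTEST environment variable (the exact variable name and accepted truthy values are assumptions here):

import os
import unittest

def skip_if_quick(obj):
    """Skip the decorated test/case when quick-test mode is requested."""
    is_quick = os.environ.get("QUICKTEST", "").lower() in ("true", "1", "yes")
    return unittest.skipIf(is_quick, "Skipping slow tests")(obj)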