From 7338be22f26a47ca0bb13ef7f24077964e28cc73 Mon Sep 17 00:00:00 2001
From: Miguel Varela Ramos
Date: Fri, 31 May 2019 20:02:17 +0200
Subject: [PATCH] Save full configuration in output dir (#835)

* Merge branch 'master' of /home/braincreator/projects/maskrcnn-benchmark with conflicts.

* update Dockerfile

* save config in output dir

* replace string format with os.path.join
---
 maskrcnn_benchmark/utils/miscellaneous.py | 7 +++++++
 tools/train_net.py                        | 7 ++++++-
 2 files changed, 13 insertions(+), 1 deletion(-)

diff --git a/maskrcnn_benchmark/utils/miscellaneous.py b/maskrcnn_benchmark/utils/miscellaneous.py
index db9a8b367..ecd3ef6a2 100644
--- a/maskrcnn_benchmark/utils/miscellaneous.py
+++ b/maskrcnn_benchmark/utils/miscellaneous.py
@@ -1,6 +1,7 @@
 # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
 import errno
 import os
+from .comm import is_main_process


 def mkdir(path):
@@ -9,3 +10,9 @@ def mkdir(path):
     except OSError as e:
         if e.errno != errno.EEXIST:
             raise
+
+
+def save_config(cfg, path):
+    if is_main_process():
+        with open(path, 'w') as f:
+            f.write(cfg.dump())
diff --git a/tools/train_net.py b/tools/train_net.py
index 9f4761b3f..3468fbb4a 100644
--- a/tools/train_net.py
+++ b/tools/train_net.py
@@ -23,7 +23,7 @@
 from maskrcnn_benchmark.utils.comm import synchronize, get_rank
 from maskrcnn_benchmark.utils.imports import import_file
 from maskrcnn_benchmark.utils.logger import setup_logger
-from maskrcnn_benchmark.utils.miscellaneous import mkdir
+from maskrcnn_benchmark.utils.miscellaneous import mkdir, save_config

 # See if we can use apex.DistributedDataParallel instead of the torch default,
 # and enable mixed-precision via apex.amp
@@ -176,6 +176,11 @@ def main():
     logger.info(config_str)
     logger.info("Running with config:\n{}".format(cfg))

+    output_config_path = os.path.join(cfg.OUTPUT_DIR, 'config.yml')
+    logger.info("Saving config into: {}".format(output_config_path))
+    # save overloaded model config in the output directory
+    save_config(cfg, output_config_path)
+
     model = train(cfg, args.local_rank, args.distributed)

     if not args.skip_test:
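
A minimal usage sketch, not part of the patch: it shows how the config.yml written by save_config() could be loaded back (for evaluation or resuming), assuming the yacs-based cfg object that maskrcnn-benchmark exposes via maskrcnn_benchmark.config; merge_from_file and freeze are standard yacs CfgNode methods.

# Sketch: reload the configuration dumped into OUTPUT_DIR by save_config().
import os

from maskrcnn_benchmark.config import cfg

output_config_path = os.path.join(cfg.OUTPUT_DIR, "config.yml")
cfg.merge_from_file(output_config_path)  # restore the saved training configuration
cfg.freeze()                             # make the restored config read-only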