From ab716302e475921424b110f13eda6a0b3bf17959 Mon Sep 17 00:00:00 2001 From: Nir Weingarten Date: Wed, 3 Jan 2024 11:52:38 +0200 Subject: [PATCH] Added cli argument for wandb session name --- library/train_util.py | 6 ++++++ train_network.py | 2 ++ 2 files changed, 8 insertions(+) diff --git a/library/train_util.py b/library/train_util.py index 3c850019e..2d954364d 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -2935,6 +2935,12 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth: default=None, help="name of tracker to use for logging, default is script-specific default name / ログ出力に使用するtrackerの名前、省略時はスクリプトごとのデフォルト名", ) + parser.add_argument( + "--wandb_run_name", + type=str, + default=None, + help="The name of the specific wandb session / wandb ログに表示される特定の実行の名前", + ) parser.add_argument( "--log_tracker_config", type=str, diff --git a/train_network.py b/train_network.py index 9cba78da0..a75299cda 100644 --- a/train_network.py +++ b/train_network.py @@ -684,6 +684,8 @@ def train(self, args): if accelerator.is_main_process: init_kwargs = {} + if args.wandb_run_name: + init_kwargs['wandb'] = {'name': args.wandb_run_name} if args.log_tracker_config is not None: init_kwargs = toml.load(args.log_tracker_config) accelerator.init_trackers(