From 3f12bddc7c0643eb0cab165f6f2a1476d3629a94 Mon Sep 17 00:00:00 2001 From: bghira Date: Thu, 10 Oct 2024 18:16:44 -0600 Subject: [PATCH 1/2] configure.py can configure caption strategy --- configure.py | 23 +++++++++++++++++++++++ 1 file changed, 23 insertions(+) diff --git a/configure.py b/configure.py index da3c3d94..7adf84da 100644 --- a/configure.py +++ b/configure.py @@ -789,6 +789,26 @@ def configure_env(): "Enter the path to your dataset. This should be a directory containing images and text files for their caption. For reliability, use an absolute (full) path, beginning with a '/'", "/datasets/my-dataset", ) + dataset_caption_strategy = prompt_user( + ( + "How should the dataloader handle captions?" + "\n-> 'filename' will use the names of your image files as the caption" + "\n-> 'textfile' requires a image.txt file to go next to your image.png file" + "\n-> 'instanceprompt' will just use one trigger phrase for all images" + "\n" + "\n(Options: filename, textfile, instanceprompt)" + ), + "textfile", + ) + if dataset_caption_strategy not in ["filename", "textfile", "instanceprompt"]: + print(f"Invalid caption strategy: {dataset_caption_strategy}") + dataset_caption_strategy = "textfile" + dataset_instance_prompt = None + if "instanceprompt" in dataset_caption_strategy: + dataset_instance_prompt = prompt_user( + "Enter the instance_prompt you want to use for all images in this dataset", + "CatchPhrase", + ) dataset_repeats = int( prompt_user( "How many times do you want to repeat each image in the dataset?", 10 @@ -818,6 +838,9 @@ def configure_env(): dataset["maximum_image_size"] = dataset["resolution"] dataset["target_downsample_size"] = dataset["resolution"] dataset["id"] = dataset["id"].replace("PLACEHOLDER", dataset_id) + if dataset_instance_prompt: + dataset["instance_prompt"] = dataset_instance_prompt + dataset["caption_strategy"] = dataset_caption_strategy print("Dataloader configuration:") print(default_local_configuration) From 66d288fb5c0fdf9b8d7cc13e04a45d7f40073fe2 Mon Sep 17 00:00:00 2001 From: bghira Date: Thu, 10 Oct 2024 19:44:02 -0600 Subject: [PATCH 2/2] regression: deepspeed check fix --- helpers/training/trainer.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/helpers/training/trainer.py b/helpers/training/trainer.py index aa467587..6f321ffe 100644 --- a/helpers/training/trainer.py +++ b/helpers/training/trainer.py @@ -1343,9 +1343,11 @@ def init_unload_vae(self): def init_validations(self): if ( - self.accelerator.state.deepspeed_plugin.deepspeed_config[ - "zero_optimization" - ].get("stage") + hasattr(self.accelerator, "state") + and hasattr(self.accelerator.state, "deepspeed_plugin") + and getattr(self.accelerator.state.deepspeed_plugin, "deepspeed_config", {}) + .get("zero_optimization", {}) + .get("stage") == 3 ): logger.error("Cannot run validations with DeepSpeed ZeRO stage 3.")