Skip to content

Commit

Permalink
Merge pull request #1043 from bghira/main
Browse files Browse the repository at this point in the history
regression
  • Loading branch information
bghira authored Oct 11, 2024
2 parents a1fe9ad + 66d288f commit f608ef9
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 3 deletions.
23 changes: 23 additions & 0 deletions configure.py
Original file line number Diff line number Diff line change
Expand Up @@ -789,6 +789,26 @@ def configure_env():
"Enter the path to your dataset. This should be a directory containing images and text files for their caption. For reliability, use an absolute (full) path, beginning with a '/'",
"/datasets/my-dataset",
)
dataset_caption_strategy = prompt_user(
(
"How should the dataloader handle captions?"
"\n-> 'filename' will use the names of your image files as the caption"
"\n-> 'textfile' requires a image.txt file to go next to your image.png file"
"\n-> 'instanceprompt' will just use one trigger phrase for all images"
"\n"
"\n(Options: filename, textfile, instanceprompt)"
),
"textfile",
)
if dataset_caption_strategy not in ["filename", "textfile", "instanceprompt"]:
print(f"Invalid caption strategy: {dataset_caption_strategy}")
dataset_caption_strategy = "textfile"
dataset_instance_prompt = None
if "instanceprompt" in dataset_caption_strategy:
dataset_instance_prompt = prompt_user(
"Enter the instance_prompt you want to use for all images in this dataset",
"CatchPhrase",
)
dataset_repeats = int(
prompt_user(
"How many times do you want to repeat each image in the dataset?", 10
Expand Down Expand Up @@ -818,6 +838,9 @@ def configure_env():
dataset["maximum_image_size"] = dataset["resolution"]
dataset["target_downsample_size"] = dataset["resolution"]
dataset["id"] = dataset["id"].replace("PLACEHOLDER", dataset_id)
if dataset_instance_prompt:
dataset["instance_prompt"] = dataset_instance_prompt
dataset["caption_strategy"] = dataset_caption_strategy

print("Dataloader configuration:")
print(default_local_configuration)
Expand Down
8 changes: 5 additions & 3 deletions helpers/training/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1343,9 +1343,11 @@ def init_unload_vae(self):

def init_validations(self):
if (
self.accelerator.state.deepspeed_plugin.deepspeed_config[
"zero_optimization"
].get("stage")
hasattr(self.accelerator, "state")
and hasattr(self.accelerator.state, "deepspeed_plugin")
and getattr(self.accelerator.state.deepspeed_plugin, "deepspeed_config", {})
.get("zero_optimization", {})
.get("stage")
== 3
):
logger.error("Cannot run validations with DeepSpeed ZeRO stage 3.")
Expand Down

0 comments on commit f608ef9

Please sign in to comment.