Merge branch 'main' of ssh://github.com/bghira/SimpleTuner into feature/fastapi-endpoints
bghira committed Sep 15, 2024
2 parents 03ba7c5 + 4dd81a2 commit cc82914
Showing 8 changed files with 420 additions and 18 deletions.
2 changes: 1 addition & 1 deletion helpers/configuration/env_file.py
@@ -133,7 +133,7 @@ def load_env():

        print(f"[CONFIG.ENV] Loaded environment variables from {config_env_path}")
    else:
-        raise ValueError(f"Cannot find config file: {config_env_path}")
+        logger.error(f"Cannot find config file: {config_env_path}")

    return config_file_contents
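Note on the change above: a missing config file was previously a hard failure; it now only logs an error, so load_env() falls through to return config_file_contents, which may be None. A minimal sketch of the resulting behavior, condensed with hypothetical names (this is not SimpleTuner's actual code):

    import logging
    import os

    logger = logging.getLogger("SimpleTuner")  # assumed logger name

    def load_env_sketch(config_env_path: str):
        # Hypothetical, condensed stand-in for load_env() after this commit.
        config_file_contents = None
        if os.path.exists(config_env_path):
            with open(config_env_path) as f:
                config_file_contents = f.read()
            print(f"[CONFIG.ENV] Loaded environment variables from {config_env_path}")
        else:
            # Previously: raise ValueError(...). Callers must now tolerate None.
            logger.error(f"Cannot find config file: {config_env_path}")
        return config_file_contents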
8 changes: 7 additions & 1 deletion helpers/metadata/backends/base.py
@@ -261,7 +261,13 @@ def compute_aspect_ratio_bucket_indices(self, ignore_existing_cache: bool = False
            if self.should_abort:
                logger.info("Aborting aspect bucket update.")
                return
-        while any(worker.is_alive() for worker in workers):
+        while (
+            any(worker.is_alive() for worker in workers)
+            or not tqdm_queue.empty()
+            or not aspect_ratio_bucket_indices_queue.empty()
+            or not metadata_updates_queue.empty()
+            or not written_files_queue.empty()
+        ):
            current_time = time.time()
            while not tqdm_queue.empty():
                pbar.update(tqdm_queue.get())
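The old loop exited as soon as every worker thread had died, which could strand results that workers enqueued just before exiting; the new condition keeps polling until the result queues are drained as well. A minimal, self-contained sketch of the pattern with generic names (not the method's real body):

    import queue
    import threading
    import time

    def wait_and_drain(workers: list[threading.Thread], result_queues: list[queue.Queue]):
        # Keep polling while any worker is alive OR any queue still holds items,
        # so results enqueued just before a worker exits are not dropped.
        while any(w.is_alive() for w in workers) or any(
            not q.empty() for q in result_queues
        ):
            for q in result_queues:
                while not q.empty():
                    item = q.get()  # process the dequeued item here
            time.sleep(0.01)  # brief sleep so the poll loop does not spin hot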
4 changes: 2 additions & 2 deletions helpers/metadata/backends/discovery.py
@@ -122,11 +122,11 @@ def reload_cache(self, set_config: bool = True):
            dict: The cache data.
        """
        # Query our DataBackend to see whether the cache file exists.
-        logger.info(f"Checking for cache file: {self.cache_file}")
+        logger.debug(f"Checking for cache file: {self.cache_file}")
        if self.data_backend.exists(self.cache_file):
            try:
                # Use our DataBackend to actually read the cache file.
-                logger.info("Pulling cache file from storage")
+                logger.debug("Pulling cache file from storage")
                cache_data_raw = self.data_backend.read(self.cache_file)
                cache_data = json.loads(cache_data_raw)
            except Exception as e:
4 changes: 2 additions & 2 deletions helpers/training/save_hooks.py
@@ -168,8 +168,8 @@ def __init__(
        if args.controlnet:
            self.denoiser_class = ControlNetModel
            self.denoiser_subdir = "controlnet"
-        logger.info(f"Denoiser class set to: {self.denoiser_class.__name__}.")
-        logger.info(f"Pipeline class set to: {self.pipeline_class.__name__}.")
+        logger.debug(f"Denoiser class set to: {self.denoiser_class.__name__}.")
+        logger.debug(f"Pipeline class set to: {self.pipeline_class.__name__}.")

        self.ema_model_cls = None
        self.ema_model_subdir = None
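This change and the discovery.py change above are the same cleanup: routine progress messages drop from info to debug, so they only surface when verbose logging is enabled. A sketch of the effect; the logger name and environment variable are illustrative assumptions, not taken from this diff:

    import logging
    import os

    # Hypothetical configuration; SimpleTuner's real logging setup may differ.
    logging.basicConfig(level=os.environ.get("SIMPLETUNER_LOG_LEVEL", "INFO"))
    logger = logging.getLogger("SaveHookManager")

    logger.debug("Pipeline class set to: StableDiffusionXLPipeline.")  # hidden at the default INFO level
    logger.info("Applying BitFit freezing strategy to the U-net.")     # still visible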
30 changes: 22 additions & 8 deletions helpers/training/trainer.py
@@ -158,6 +158,7 @@ def __init__(self, config: dict = None):
        self.lycoris_wrapped_network = None
        self.lycoris_config = None
        self.lr_scheduler = None
+        self.webhook_handler = None
        self.should_abort = False

    def _config_to_obj(self, config):
@@ -257,7 +258,6 @@ def run(self):

            raise e

-
    def _initialize_components_with_signal_check(self, initializers):
        """
        Runs a list of initializer functions with signal checks after each.
@@ -921,26 +921,40 @@ def init_post_load_freeze(self):

        if self.unet is not None:
            logger.info("Applying BitFit freezing strategy to the U-net.")
-            self.unet = apply_bitfit_freezing(self.unet, self.config)
+            self.unet = apply_bitfit_freezing(
+                unwrap_model(self.accelerator, self.unet), self.config
+            )
        if self.transformer is not None:
            logger.warning(
                "Training DiT models with BitFit is not yet tested, and unexpected results may occur."
            )
-            self.transformer = apply_bitfit_freezing(self.transformer, self.config)
+            self.transformer = apply_bitfit_freezing(
+                unwrap_model(self.accelerator, self.transformer), self.config
+            )

        if self.config.gradient_checkpointing:
            if self.unet is not None:
-                self.unet.enable_gradient_checkpointing()
+                unwrap_model(
+                    self.accelerator, self.unet
+                ).enable_gradient_checkpointing()
            if self.transformer is not None and self.config.model_family != "smoldit":
-                self.transformer.enable_gradient_checkpointing()
+                unwrap_model(
+                    self.accelerator, self.transformer
+                ).enable_gradient_checkpointing()
            if self.config.controlnet:
-                self.controlnet.enable_gradient_checkpointing()
+                unwrap_model(
+                    self.accelerator, self.controlnet
+                ).enable_gradient_checkpointing()
            if (
                hasattr(self.config, "train_text_encoder")
                and self.config.train_text_encoder
            ):
-                self.text_encoder_1.gradient_checkpointing_enable()
-                self.text_encoder_2.gradient_checkpointing_enable()
+                unwrap_model(
+                    self.accelerator, self.text_encoder_1
+                ).gradient_checkpointing_enable()
+                unwrap_model(
+                    self.accelerator, self.text_encoder_2
+                ).gradient_checkpointing_enable()

    def _recalculate_training_steps(self):
        # Scheduler and math around the number of training steps.
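The recurring edit in this file wraps each model in unwrap_model(self.accelerator, ...) before calling module-level methods. When Accelerate prepares a model it may wrap it (for example in DistributedDataParallel), and methods such as enable_gradient_checkpointing() live on the underlying diffusers module rather than the wrapper. A plausible shape for the helper, assuming it delegates to Accelerate; the torch.compile handling is an assumption, not confirmed by this diff:

    import torch
    from accelerate import Accelerator

    def unwrap_model(accelerator: Accelerator, model: torch.nn.Module) -> torch.nn.Module:
        # Peel away wrappers (e.g. DistributedDataParallel) added by accelerator.prepare().
        model = accelerator.unwrap_model(model)
        # torch.compile() wraps modules and keeps the original at ._orig_mod.
        return getattr(model, "_orig_mod", model)

With a helper of this shape, unwrap_model(self.accelerator, self.unet).enable_gradient_checkpointing() works whether or not the U-net has been prepared for distributed training.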
2 changes: 1 addition & 1 deletion tests/test_metadata_backend.py
@@ -22,7 +22,7 @@ def setUp(self):
        self.image_path_str = "test_image.jpg"

        self.instance_data_dir = "/some/fake/path"
-        self.cache_file = "/some/fake/cache.json"
+        self.cache_file = "/some/fake/cache"
        self.metadata_file = "/some/fake/metadata.json"
        StateTracker.set_args(MagicMock())
        # Overload cache file with json: