diff --git a/fine_tune.py b/fine_tune.py index 2ecb4ff36..4a3f49c7e 100644 --- a/fine_tune.py +++ b/fine_tune.py @@ -32,6 +32,7 @@ get_weighted_text_embeddings, prepare_scheduler_for_custom_training, scale_v_prediction_loss_like_noise_prediction, + apply_debiased_estimation, ) @@ -339,7 +340,7 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module): else: target = noise - if args.min_snr_gamma or args.scale_v_pred_loss_like_noise_pred: + if args.min_snr_gamma or args.scale_v_pred_loss_like_noise_pred or args.debiased_estimation_loss: # do not mean over batch dimension for snr weight or scale v-pred loss loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none") loss = loss.mean([1, 2, 3]) @@ -348,6 +349,8 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module): loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma) if args.scale_v_pred_loss_like_noise_pred: loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler) + if args.debiased_estimation_loss: + loss = apply_debiased_estimation(loss, timesteps, noise_scheduler) loss = loss.mean() # mean over batch dimension else: diff --git a/library/custom_train_functions.py b/library/custom_train_functions.py index 677d1bf46..28b625d30 100644 --- a/library/custom_train_functions.py +++ b/library/custom_train_functions.py @@ -86,6 +86,12 @@ def add_v_prediction_like_loss(loss, timesteps, noise_scheduler, v_pred_like_los loss = loss + loss / scale * v_pred_like_loss return loss +def apply_debiased_estimation(loss, timesteps, noise_scheduler): + snr_t = torch.stack([noise_scheduler.all_snr[t] for t in timesteps]) # batch_size + snr_t = torch.minimum(snr_t, torch.ones_like(snr_t) * 1000) # if timestep is 0, snr_t is inf, so limit it to 1000 + weight = 1/torch.sqrt(snr_t) + loss = weight * loss + return loss # TODO train_utilと分散しているのでどちらかに寄せる @@ -108,6 +114,11 @@ def add_custom_train_arguments(parser: argparse.ArgumentParser, support_weighted
default=None, help="add v-prediction like loss multiplied by this value / v-prediction lossをこの値をかけたものをlossに加算する", ) + parser.add_argument( + "--debiased_estimation_loss", + action="store_true", + help="debiased estimation loss / debiased estimation loss", + ) if support_weighted_captions: parser.add_argument( "--weighted_captions", diff --git a/sdxl_train.py b/sdxl_train.py index 7bde3cab7..55c11f9ce 100644 --- a/sdxl_train.py +++ b/sdxl_train.py @@ -34,6 +34,7 @@ prepare_scheduler_for_custom_training, scale_v_prediction_loss_like_noise_prediction, add_v_prediction_like_loss, + apply_debiased_estimation, ) from library.sdxl_original_unet import SdxlUNet2DConditionModel @@ -548,7 +549,7 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module): target = noise - if args.min_snr_gamma or args.scale_v_pred_loss_like_noise_pred or args.v_pred_like_loss: + if args.min_snr_gamma or args.scale_v_pred_loss_like_noise_pred or args.v_pred_like_loss or args.debiased_estimation_loss: # do not mean over batch dimension for snr weight or scale v-pred loss loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none") loss = loss.mean([1, 2, 3]) @@ -559,6 +560,8 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module): loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler) if args.v_pred_like_loss: loss = add_v_prediction_like_loss(loss, timesteps, noise_scheduler, args.v_pred_like_loss) + if args.debiased_estimation_loss: + loss = apply_debiased_estimation(loss, timesteps, noise_scheduler) loss = loss.mean() # mean over batch dimension else: diff --git a/sdxl_train_control_net_lllite.py b/sdxl_train_control_net_lllite.py index 0df61e848..7a141bb4e 100644 --- a/sdxl_train_control_net_lllite.py +++ b/sdxl_train_control_net_lllite.py @@ -44,6 +44,7 @@ pyramid_noise_like, apply_noise_offset, scale_v_prediction_loss_like_noise_prediction, + apply_debiased_estimation, ) import networks.control_net_lllite_for_train as 
control_net_lllite_for_train @@ -465,6 +466,8 @@ def remove_model(old_ckpt_name): loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler) if args.v_pred_like_loss: loss = add_v_prediction_like_loss(loss, timesteps, noise_scheduler, args.v_pred_like_loss) + if args.debiased_estimation_loss: + loss = apply_debiased_estimation(loss, timesteps, noise_scheduler) loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし diff --git a/sdxl_train_control_net_lllite_old.py b/sdxl_train_control_net_lllite_old.py index 79920a972..e256badca 100644 --- a/sdxl_train_control_net_lllite_old.py +++ b/sdxl_train_control_net_lllite_old.py @@ -40,6 +40,7 @@ pyramid_noise_like, apply_noise_offset, scale_v_prediction_loss_like_noise_prediction, + apply_debiased_estimation, ) import networks.control_net_lllite as control_net_lllite @@ -435,6 +436,8 @@ def remove_model(old_ckpt_name): loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler) if args.v_pred_like_loss: loss = add_v_prediction_like_loss(loss, timesteps, noise_scheduler, args.v_pred_like_loss) + if args.debiased_estimation_loss: + loss = apply_debiased_estimation(loss, timesteps, noise_scheduler) loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし diff --git a/train_db.py b/train_db.py index a1b9cac8b..7316c27ee 100644 --- a/train_db.py +++ b/train_db.py @@ -35,6 +35,7 @@ pyramid_noise_like, apply_noise_offset, scale_v_prediction_loss_like_noise_prediction, + apply_debiased_estimation, ) # perlin_noise, @@ -336,6 +337,8 @@ def train(args): loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma) if args.scale_v_pred_loss_like_noise_pred: loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler) + if args.debiased_estimation_loss: + loss = apply_debiased_estimation(loss, timesteps, noise_scheduler) loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし diff --git a/train_network.py b/train_network.py index 2232a384a..9deb53313 100644 --- 
a/train_network.py +++ b/train_network.py @@ -43,6 +43,7 @@ prepare_scheduler_for_custom_training, scale_v_prediction_loss_like_noise_prediction, add_v_prediction_like_loss, + apply_debiased_estimation, ) @@ -528,6 +529,7 @@ def train(self, args): "ss_min_snr_gamma": args.min_snr_gamma, "ss_scale_weight_norms": args.scale_weight_norms, "ss_ip_noise_gamma": args.ip_noise_gamma, + "ss_debiased_estimation": bool(args.debiased_estimation_loss), } if use_user_config: @@ -811,6 +813,8 @@ def remove_model(old_ckpt_name): loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler) if args.v_pred_like_loss: loss = add_v_prediction_like_loss(loss, timesteps, noise_scheduler, args.v_pred_like_loss) + if args.debiased_estimation_loss: + loss = apply_debiased_estimation(loss, timesteps, noise_scheduler) loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし diff --git a/train_textual_inversion.py b/train_textual_inversion.py index 252add536..6b6e7f5a0 100644 --- a/train_textual_inversion.py +++ b/train_textual_inversion.py @@ -32,6 +32,7 @@ prepare_scheduler_for_custom_training, scale_v_prediction_loss_like_noise_prediction, add_v_prediction_like_loss, + apply_debiased_estimation, ) imagenet_templates_small = [ @@ -582,6 +583,8 @@ def remove_model(old_ckpt_name): loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler) if args.v_pred_like_loss: loss = add_v_prediction_like_loss(loss, timesteps, noise_scheduler, args.v_pred_like_loss) + if args.debiased_estimation_loss: + loss = apply_debiased_estimation(loss, timesteps, noise_scheduler) loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし diff --git a/train_textual_inversion_XTI.py b/train_textual_inversion_XTI.py index 525e612f1..8dd5c672f 100644 --- a/train_textual_inversion_XTI.py +++ b/train_textual_inversion_XTI.py @@ -34,6 +34,7 @@ pyramid_noise_like, apply_noise_offset, scale_v_prediction_loss_like_noise_prediction, + apply_debiased_estimation, ) import library.original_unet as 
original_unet from XTI_hijack import unet_forward_XTI, downblock_forward_XTI, upblock_forward_XTI @@ -471,6 +472,8 @@ def remove_model(old_ckpt_name): loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma) if args.scale_v_pred_loss_like_noise_pred: loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler) + if args.debiased_estimation_loss: + loss = apply_debiased_estimation(loss, timesteps, noise_scheduler) loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし