VAE training sample script #3801

Status: Closed · wants to merge 29 commits

Conversation

@aandyw (Contributor) commented Jun 15, 2023

PR for Issue #3726

Todos

  • implement training loop for VAE
  • KL loss implementation (see the sketch after this list)
  • evaluate performance of VAE training
  • fix script to work for mixed precision
  • integration with a1111
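
Since the todo list mentions the KL term, here is a minimal sketch of the closed-form KL for the AutoencoderKL posterior, assuming the diffusers API where encode() returns a DiagonalGaussianDistribution whose kl() computes exactly this reduction (the checkpoint name is a placeholder):

import torch
from diffusers import AutoencoderKL

def kl_to_standard_normal(mean: torch.Tensor, logvar: torch.Tensor) -> torch.Tensor:
    # KL(N(mean, var) || N(0, I)) for a diagonal Gaussian, summed over the
    # latent dims; take .mean() over the batch when forming the loss term.
    return 0.5 * torch.sum(mean.pow(2) + logvar.exp() - 1.0 - logvar, dim=[1, 2, 3])

vae = AutoencoderKL.from_pretrained("<MODEL_NAME>", subfolder="vae")  # placeholder checkpoint
posterior = vae.encode(torch.randn(1, 3, 256, 256)).latent_dist
kl_loss = posterior.kl().mean()  # matches kl_to_standard_normal(posterior.mean, posterior.logvar).mean()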

@HuggingFaceDocBuilderDev commented

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint.

@aandyw (Contributor, Author) commented Jun 24, 2023

[06/24/2023] VAE fine-tuning runs successfully, but the image results still need to be tested and evaluated.

@aandyw marked this pull request as ready for review on June 24, 2023 at 19:13
--dataset_name="<DATASET_NAME>" \
--train_batch_size=1 \
--gradient_accumulation_steps=4 \
--gradient_checkpointing

Suggested change
--gradient_checkpointing
--gradient_checkpointing \
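
For context, the flags above come from a launch command for the new script. A minimal sketch of a full invocation, assuming an accelerate entry point and the usual diffusers example flags (--pretrained_model_name_or_path and --output_dir are assumptions, not confirmed by this PR):

accelerate launch train_vae.py \
  --pretrained_model_name_or_path="<MODEL_NAME>" \
  --dataset_name="<DATASET_NAME>" \
  --train_batch_size=1 \
  --gradient_accumulation_steps=4 \
  --gradient_checkpointing \
  --output_dir="vae-finetuned"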

@aandyw changed the title from "[WIP] VAE training sample script" to "VAE training sample script" on Jul 27, 2023
Comment on lines +390 to +395
with accelerator.main_process_first():
    # Split into train/test
    dataset = dataset["train"].train_test_split(test_size=args.test_samples)
    # Set the training transforms
    train_dataset = dataset["train"].with_transform(preprocess)
    test_dataset = dataset["test"].with_transform(preprocess)
@zhuliyi0 commented Jul 31, 2023

Support loading the test set from a dedicated test_data_dir folder.

Suggested change

Remove:

with accelerator.main_process_first():
    # Split into train/test
    dataset = dataset["train"].train_test_split(test_size=args.test_samples)
    # Set the training transforms
    train_dataset = dataset["train"].with_transform(preprocess)
    test_dataset = dataset["test"].with_transform(preprocess)

Add:

with accelerator.main_process_first():
    if args.test_data_dir is not None and args.train_data_dir is not None:
        # Load test data from test_data_dir
        logger.info(f"load test data from {args.test_data_dir}")
        test_dir = os.path.join(args.test_data_dir, "**")
        test_dataset = load_dataset(
            "imagefolder",
            data_files=test_dir,
            cache_dir=args.cache_dir,
        )
        # Set the training transforms
        train_dataset = dataset["train"].with_transform(preprocess)
        test_dataset = test_dataset["train"].with_transform(preprocess)
    else:
        # Split into train/test
        dataset = dataset["train"].train_test_split(test_size=args.test_samples)
        # Set the training transforms
        train_dataset = dataset["train"].with_transform(preprocess)
        test_dataset = dataset["test"].with_transform(preprocess)

    type=int,
    default=4,
    help="Number of images to remove from training set to be used as validation.",
)

Add a new argument, test_data_dir, for a dedicated test data folder.

Suggested change

)
parser.add_argument(
    "--test_data_dir",
    type=str,
    default=None,
    help=(
        "If not None, will override test_samples arg and use data inside this dir as test dataset."
    ),
)
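
With this argument in place, a dedicated test folder can be passed alongside the training folder. A hypothetical invocation (flag names taken from the suggestion above; everything else is assumed):

accelerate launch train_vae.py \
  --train_data_dir="<TRAIN_DATA_DIR>" \
  --test_data_dir="<TEST_DATA_DIR>" \
  --train_batch_size=1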

if tracker.name == "tensorboard":
    np_images = np.stack([np.asarray(img) for img in images])
    tracker.writer.add_images(
        "Original (left) / Reconstruction (right)", np_images, epoch

Change the tag so the resulting file name is compatible with Windows ("/" is not allowed in Windows paths).

Suggested change

        "Original (left) / Reconstruction (right)", np_images, epoch
        "Original (left)-Reconstruction (right)", np_images, epoch

examples/vae/train_vae.py (outdated review thread, resolved)
progress_bar.set_description("Steps")

lpips_loss_fn = lpips.LPIPS(net="alex").to(accelerator.device)

Suggested change

# initial validation as baseline
with torch.no_grad():
    log_validation(test_dataloader, vae, accelerator, weight_dtype, 0)

One validation pass before training starts, as a baseline for comparison.
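
For reference, the lpips line in the excerpt above is the usual way the package is applied. A minimal sketch of how such a perceptual term is typically computed (the [-1, 1] input range is the lpips package's documented expectation; how the PR weights this term is not shown here):

import lpips
import torch

lpips_loss_fn = lpips.LPIPS(net="alex")
pred = torch.rand(2, 3, 256, 256) * 2 - 1    # lpips expects inputs scaled to [-1, 1]
target = torch.rand(2, 3, 256, 256) * 2 - 1
lpips_loss = lpips_loss_fn(pred, target).mean()  # per-sample distances, averaged over the batch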

pred = vae.decode(z).sample

kl_loss = posterior.kl().mean()
mse_loss = F.mse_loss(pred, target, reduction="mean")

In the original stable-diffusion repo and the SDXL repo, the VAE loss is averaged only over the batch dim, i.e. summed over the channel/height/width dims. Is this the correct way to average the reconstruction loss?
https://github.com/CompVis/stable-diffusion/blob/21f890f9da3cfbeaba8e2ac3c425ee9e998d5229/ldm/modules/losses/contperceptual.py#L58
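
To make the question concrete, a minimal sketch contrasting the two reductions (illustrative tensors; the linked LDM code takes the per-sample sum and then a batch mean):

import torch
import torch.nn.functional as F

pred = torch.randn(4, 3, 256, 256)
target = torch.randn(4, 3, 256, 256)

# Reduction in this PR: mean over every element (batch and C/H/W dims).
mse_all_mean = F.mse_loss(pred, target, reduction="mean")

# LDM-style reduction: sum over C/H/W per sample, then mean over the batch.
mse_ldm = ((pred - target) ** 2).flatten(1).sum(dim=1).mean()

# The two differ by a constant factor of C*H*W, which rescales the
# reconstruction term relative to the KL term.
assert torch.allclose(mse_ldm, mse_all_mean * pred[0].numel())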

@github-actions bot commented Sep 2, 2023

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

@github-actions bot added the "stale" label (Issues that haven't received updates) on Sep 2, 2023
@github-actions bot closed this on Sep 12, 2023
@JunzheJosephZhu left a comment

lgtm

@lavinal712 commented

Is there any progress now?
