VAE training sample script #3801
Conversation
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint.
[06/24/2023] VAE fine-tuning runs successfully, but the image results still need to be tested and evaluated.
--dataset_name="<DATASET_NAME>" \
--train_batch_size=1 \
--gradient_accumulation_steps=4 \
--gradient_checkpointing
Suggested change (the trailing line-continuation backslash is missing):
- --gradient_checkpointing
+ --gradient_checkpointing \
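As a side note (not from the PR), the effective batch size implied by these example flags is train_batch_size × gradient_accumulation_steps:

train_batch_size = 1
gradient_accumulation_steps = 4

# Gradients are accumulated over 4 micro-batches of size 1 before each
# optimizer step, so the effective batch size is 4.
effective_batch_size = train_batch_size * gradient_accumulation_steps
print(effective_batch_size)  # 4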
with accelerator.main_process_first():
    # Split into train/test
    dataset = dataset["train"].train_test_split(test_size=args.test_samples)
    # Set the training transforms
    train_dataset = dataset["train"].with_transform(preprocess)
    test_dataset = dataset["test"].with_transform(preprocess)
Support loading the test set from a dedicated test data folder (test_data_dir):
Suggested change:
- with accelerator.main_process_first():
-     # Split into train/test
-     dataset = dataset["train"].train_test_split(test_size=args.test_samples)
-     # Set the training transforms
-     train_dataset = dataset["train"].with_transform(preprocess)
-     test_dataset = dataset["test"].with_transform(preprocess)
+ with accelerator.main_process_first():
+     if args.test_data_dir is not None and args.train_data_dir is not None:
+         # Load test data from test_data_dir
+         logger.info(f"load test data from {args.test_data_dir}")
+         test_dir = os.path.join(args.test_data_dir, "**")
+         test_dataset = load_dataset(
+             "imagefolder",
+             data_files=test_dir,
+             cache_dir=args.cache_dir,
+         )
+         # Set the training transforms
+         train_dataset = dataset["train"].with_transform(preprocess)
+         test_dataset = test_dataset["train"].with_transform(preprocess)
+     else:
+         # Split into train/test
+         dataset = dataset["train"].train_test_split(test_size=args.test_samples)
+         # Set the training transforms
+         train_dataset = dataset["train"].with_transform(preprocess)
+         test_dataset = dataset["test"].with_transform(preprocess)
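For context, a minimal self-contained sketch of the imagefolder loading pattern the suggestion relies on (the ./my_test_images path is a placeholder):

from datasets import load_dataset

# Load every image under the directory (and its subdirectories) as a dataset.
# "imagefolder" returns a DatasetDict with a single "train" split, which is
# why the suggestion indexes test_dataset["train"].
test_dataset = load_dataset(
    "imagefolder",
    data_files="./my_test_images/**",
)
print(test_dataset["train"][0]["image"])  # a PIL image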
    type=int,
    default=4,
    help="Number of images to remove from training set to be used as validation.",
)
Add a new argument, test_data_dir, for a dedicated test data folder:
Suggested change:
  )
+ parser.add_argument(
+     "--test_data_dir",
+     type=str,
+     default=None,
+     help=(
+         "If not None, will override test_samples arg and use data inside this dir as test dataset."
+     ),
+ )
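A minimal sketch of the intended precedence, assuming --test_data_dir simply wins over --test_samples whenever it is set (the parser calls mirror the suggestion; the invocation is hypothetical):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--test_samples", type=int, default=4)
parser.add_argument("--test_data_dir", type=str, default=None)

# Hypothetical invocation: a dedicated test folder is supplied.
args = parser.parse_args(["--test_data_dir", "./my_test_images"])

# When test_data_dir is set it overrides test_samples entirely;
# otherwise test_samples images are split off the training set.
if args.test_data_dir is not None:
    print(f"using dedicated test set from {args.test_data_dir}")
else:
    print(f"holding out {args.test_samples} training images for validation")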
if tracker.name == "tensorboard":
    np_images = np.stack([np.asarray(img) for img in images])
    tracker.writer.add_images(
        "Original (left) / Reconstruction (right)", np_images, epoch
Change the name so it is compatible with Windows; the "/" can be treated as a path separator:
"Original (left) / Reconstruction (right)", np_images, epoch | |
"Original (left)-Reconstruction (right)", np_images, epoch |
progress_bar.set_description("Steps")

lpips_loss_fn = lpips.LPIPS(net="alex").to(accelerator.device)
Suggested change:
+ # initial validation as baseline
+ with torch.no_grad():
+     log_validation(test_dataloader, vae, accelerator, weight_dtype, 0)
Run one validation pass before training starts, as a baseline for comparison.
Co-authored-by: zhuliyi0 <[email protected]>
pred = vae.decode(z).sample

kl_loss = posterior.kl().mean()
mse_loss = F.mse_loss(pred, target, reduction="mean")
In the original stable-diffusion repo and the SDXL repo, the VAE loss is averaged over the batch dim only, which means it is summed over the channel/height/width dims. Is reduction="mean" (averaging over every element) the correct way to average the reconstruction loss here?
https://github.com/CompVis/stable-diffusion/blob/21f890f9da3cfbeaba8e2ac3c425ee9e998d5229/ldm/modules/losses/contperceptual.py#L58
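For concreteness, a minimal sketch of the two reduction conventions (tensor shapes are illustrative, not taken from the script):

import torch
import torch.nn.functional as F

pred = torch.randn(2, 3, 4, 4)    # (batch, channel, height, width)
target = torch.randn(2, 3, 4, 4)

# This script: mean over every element, including C/H/W.
per_element = F.mse_loss(pred, target, reduction="mean")

# CompVis/LDM convention: sum over C/H/W, then mean over the batch dim.
per_sample = F.mse_loss(pred, target, reduction="none").sum(dim=(1, 2, 3)).mean()

# The two differ by a constant factor of C*H*W (3 * 4 * 4 = 48 here),
# which changes the weight of the reconstruction term relative to kl_loss.
assert torch.allclose(per_sample, per_element * 3 * 4 * 4)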
This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread. Please note that issues that do not follow the contributing guidelines are likely to be ignored.
lgtm
Is there any progress on this now?
PR for Issue #3726
Todos