
VAE training sample script #3726

Closed

zhuliyi0 opened this issue Jun 9, 2023 · 96 comments

Comments
@zhuliyi0

zhuliyi0 commented Jun 9, 2023

I believe the current lack of easy access to VAE training is stopping diffusion models from disrupting even more industries.

I'm talking about consistent details on things that are less represented in the original training data. 64x64 res can only carry so much detail. Very often I get a good result from latent space (by checking the low-res intermediate image) before the final image is ruined by bad details. No prompting or finetuning or ControlNet could solve this issue; I tried, and I know lots of other people tried, most of them without realising that the problem cannot be solved unless the thing that produces the final details can be trained on their domain data.

Right now the VAE cannot be easily trained, at least not by someone like me who is not very good at math and Python, so there is definitely demand here. May I hope for a sample script based on diffusers to start with? I tried messing with the ones in the CompVis repo, but to no avail. Thanks in advance!

@zhuliyi0 zhuliyi0 changed the title sample VAE training script VAE training sample script Jun 9, 2023
@patrickvonplaten
Contributor

Currently I don't have the bandwidth to dive deeper into this, but I agree an easy training script for VAEs would make sense :-)

Let's see if the community has time for it!

@aandyw
Contributor

aandyw commented Jun 13, 2023

Definitely would love to dive deeper into this, but I'd appreciate some guidance if possible.

@aandyw aandyw mentioned this issue Jun 15, 2023
@aandyw
Contributor

aandyw commented Jun 24, 2023

Update: VAE training script runs successfully but I'll need to test on a full dataset and evaluate the results.

@zhuliyi0 Is there a dataset you would like me to try fine-tuning on? Preferably one hosted on Hugging Face.

@zhuliyi0
Author

wow super cool! I was planning to train a VAE to re-create certain architecture styles with consistent details, so I found this dataset on HF:

https://huggingface.co/datasets/Xpitfire/cmp_facade

Not a big dataset though, so I'm not sure if it works for you. Also, some of the images have extreme aspect ratios. Let me know if there are more specific requirements on the dataset and I will try to find/assemble a better one.

@aandyw
Contributor

aandyw commented Jun 26, 2023

@zhuliyi0 No worries, and thanks for responding. I might be a little busy this week, but I'll try out the new dataset and see if the VAE is improving in terms of learning the new data.

@zhuliyi0
Author

I got the script to run, but it looks like my 12 GB of VRAM is far from enough. I assume VRAM usage will go down once Adam8bit and other optimizations are in place?

@aandyw
Contributor

aandyw commented Jun 29, 2023

@zhuliyi0 Perhaps but I can't really confirm anything at the moment. I'm basing hardware requirements on the docs (https://huggingface.co/docs/diffusers/training/text2image):

Using gradient_checkpointing and mixed_precision, it should be possible to finetune the model on a single 24GB GPU. For higher batch_size’s and faster training, it’s better to use GPUs with more than 30GB of GPU memory.

But this is obviously for training the Stable Diffusion model, so the requirements will be different for sure.

At this time, I'm trying to confirm that the AutoencoderKL is indeed being fine-tuned with reasonable performance before implementing further techniques like EMA weights, an MSE-focused reconstruction loss, etc. (details are here: https://huggingface.co/stabilityai/sd-vae-ft-mse-original).
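
For reference, a minimal sketch of what the EMA part could look like with diffusers' EMAModel (the exact integration in the PR may differ):

    from diffusers import AutoencoderKL
    from diffusers.training_utils import EMAModel

    vae = AutoencoderKL.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="vae")
    ema_vae = EMAModel(vae.parameters(), model_cls=AutoencoderKL, model_config=vae.config)

    # after each optimizer step:
    #     ema_vae.step(vae.parameters())
    # and before saving the final weights:
    #     ema_vae.copy_to(vae.parameters())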

If you would like to work on this PR together, I would appreciate the help, since I may be a little MIA for the next 2 weeks at most.

@zhuliyi0
Author

zhuliyi0 commented Jul 8, 2023

I am a total newbie at Python and ML. I am still trying to run the script on my local GPU. The OOM is gone now that I stick to the arguments you provided in the script, and VRAM usage and training speed are fine, but there is an error when saving the validation image: it basically says an image file inside a wandb temp folder cannot be found. I checked and there is no such folder. I don't know how to use wandb to debug this one.

Colab seems to be running without error, but the speed is a bit slow compared to my local GPU (probably normal for a T4). From the validation images, I see signs of improvement in the image details I was talking about; I will validate with inference after a reasonably sized training run has finished.

@zhuliyi0
Author

I got training to run on my local GPU on Windows. The directory error was due to the path naming convention on Windows. Again, from the validation images I can see it was learning. The loss was also going down.

I noticed there is a VRAM leak in the log_validation function when the number of test images is 5 or above. I also failed to use the trained VAE inside a1111 for inference; it gives the error "Missing key(s) in state_dict".

@aandyw
Contributor

aandyw commented Jul 11, 2023

Hey @zhuliyi0, thanks for taking the time to test things. The script is definitely not perfect yet, but I'll work on the things you mentioned. In terms of transferring the VAE over to a1111, I'm not quite sure; I haven't played around with a1111, so I would need some time.

My current focus will be to clean up the script and implement the memory-saving techniques to improve training. Then I'll see how we can make the VAE transferable to a1111.

@zhuliyi0
Author

Totally understand that the script wouldn't be perfect at this point. I am glad to help whenever I can. I will try using the pipeline to test inference performance. @pie31415

@zhuliyi0
Author

zhuliyi0 commented Jul 12, 2023

here is a training test run:

https://wandb.ai//zhuliyi0/goa_5e5/reports/VAE-training-test--Vmlldzo0ODYzMzcx

Also did a quick inference test using a finetuned model that was trained on the same dataset, comparing results between the default and the trained VAE. I can confirm the VAE is adding details, making the image better.

Another issue: the output from the trained VAE looks white-washed. This happens on both sd15 and the finetuned model. I had to apply some brightness and contrast changes to the image. The validation images during training do not have this issue.

@aandyw
Contributor

aandyw commented Jul 12, 2023

here is a training test run:

https://wandb.ai//zhuliyi0/goa_5e5/reports/VAE-training-test--Vmlldzo0ODYzMzcx

Your wandb experiment seems to be private/locked.

I can confirm the VAE is adding details, making the image better.

Are you referring to the default VAE or a custom-trained one? If it is a custom-trained one, can you provide a link to the weights? It'll be extremely beneficial to have some results to compare against when I'm fixing up experiments for the script.

Another issue: the output from the trained VAE looks white-washed. This happens on both sd15 and the finetuned model. I had to apply some brightness and contrast changes to the image. The validation images during training do not have this issue.

Hmm yeah, it may be how we're training the VAE. I'll take a look over the weekend. Most likely the substantial changes will have to be done this weekend, since I'm a little preoccupied before then.

Thanks a lot for your patience though. 🤗

@zhuliyi0
Author

zhuliyi0 commented Jul 13, 2023

I made the project public. And here is the weight file:

https://drive.google.com/file/d/1gTQqWuVA7m7GYIStVbulYS-tN_CMY-PM/view?usp=sharing

Some inference images that show the white-wash issue, using the VAE at steps 4k to 40k, gradually getting worse:

https://drive.google.com/drive/folders/16ivRLiLgb7dDixfFbNIL7vf_wNe9BaRO?usp=sharing

@ThibaultCastells

Hello,
This project is really cool, thank you!
I noticed a potential mistake in the code: the KL loss is applied on the output, but I think it should be applied on the latent space if I understood correctly (I may be wrong, I am not an expert in VAE training).
However, using it gives me bad results; I think it is because it changes the organization of the latent space too much (in the end I use it with a really small coefficient).

The LPIPS loss gives great results, however (without it, the image tends to become too 'smooth'). I used this library.
I hope this helps!

    import lpips
    import torch.nn.functional as F

    # perceptual loss (https://github.com/richzhang/PerceptualSimilarity)
    lpips_loss_fn = lpips.LPIPS(net='alex').to(accelerator.device)

    for epoch in range(first_epoch, args.num_train_epochs):
        vae.train()
        train_loss = 0.0
        for step, batch in enumerate(train_dataloader):
            with accelerator.accumulate(vae):
                target = batch["pixel_values"].to(weight_dtype)

                # https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/autoencoder_kl.py
                posterior = vae.encode(target).latent_dist
                z = posterior.mode()
                pred = vae.decode(z).sample

                # KL is computed on the posterior (i.e. on the latent space), not on the output
                kl_loss = posterior.kl().mean()
                mse_loss = F.mse_loss(pred, target, reduction="mean")
                lpips_loss = lpips_loss_fn(pred, target).mean()

                logger.info(f'mse:{mse_loss.item()}, lpips:{lpips_loss.item()}, kl:{kl_loss.item()}')

                loss = mse_loss + args.lpips_scale * lpips_loss + args.kl_scale * kl_loss

                # Gather the losses across all processes for logging (if we use distributed training).
                avg_loss = accelerator.gather(loss.repeat(args.train_batch_size)).mean()
                train_loss += avg_loss.item() / args.gradient_accumulation_steps

                accelerator.backward(loss)
                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad()
@aandyw
Contributor

aandyw commented Jul 22, 2023

I noticed a potential mistake in the code: the KL loss is applied on the output, but I think it should be applied on the latent space [...] The LPIPS loss gives great results, however. (quoting the comment and code snippet above)

Thanks for the feedback, that definitely might be the case. I'll take a look and make the necessary changes. Thanks again.

@aandyw
Contributor

aandyw commented Jul 23, 2023

@zhuliyi0 I updated the PR with @ThibaultCastells's code. Can you give your training another try and let us know the results? (e.g., is the white-washing issue improved?)

Also, I took a look at the VRAM issue you mentioned with test_images >= 5. I can't seem to reproduce it; can you give more details if you're still experiencing this issue?

@ThibaultCastells I've credited the recent commit to you and I plan to mention your contribution in the PR as well.

@aandyw
Contributor

aandyw commented Jul 23, 2023

@patrickvonplaten Do you mind giving the PR a look over when you're free?

@ThibaultCastells

@pie31415 Thank you very much! I will let you know if I have other improvement suggestions.

@ThibaultCastells

ThibaultCastells commented Jul 24, 2023

By the way:

However, using it gives me bad results; I think it is because it changes the organization of the latent space too much (in the end I use it with a really small coefficient)

With a scale coefficient around 1e-7 and a long enough training (using my own dataset), the image quality first got much worse and then came back to normal, so my assumption about 'latent space reorganization' was right, I think.
The KL loss went from >20,000 to ~100 by the time it converged.
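
(For reference, the KL term here is the closed-form divergence between the diagonal Gaussian posterior and the standard normal prior, which is what posterior.kl() computes per sample:

$$\mathrm{KL}\left(q(z \mid x)\,\|\,\mathcal{N}(0, I)\right) = \frac{1}{2}\sum_i \left(\mu_i^2 + \sigma_i^2 - \log \sigma_i^2 - 1\right)$$

so very large values at the start just mean the pretrained latents are far from a unit Gaussian.)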

@zhuliyi0
Author

@pie31415 I re-ran the training with the new script; the result was conceivably no different. The white-wash issue still exists, the same as before. It seems like the training gradually makes the contrast lower and the brightness higher, but not by much.

@ThibaultCastells do you mean "learning rate" when you say "coefficient"?

@ThibaultCastells

ThibaultCastells commented Jul 24, 2023

No, I meant the coefficient that multiplies the loss term (kl_scale):
loss = mse_loss + args.lpips_scale * lpips_loss + args.kl_scale * kl_loss

Note that by default kl_scale and lpips_scale are 0, so if you didn't change them you won't see any difference (I suggest using lpips_scale = 0.1, as this is the value that was used to finetune the VAE of SD).
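
(With the script's flags, that would be e.g. `--lpips_scale 0.1` together with a small `--kl_scale` such as `1e-6` or `1e-7`.)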

@NoahRe1

NoahRe1 commented Jul 26, 2023

I noticed that there is no transforms.Normalize([0.5], [0.5]) applied to the images in the training script, and the output images seem to be correct. However, in the other model training scripts, normalization is performed before using the VAE. Is it an error in the other scripts?
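
For context, the preprocessing in the other scripts (and in the full script posted later in this thread) looks roughly like the sketch below; ToTensor() maps pixels to [0, 1] and Normalize([0.5], [0.5]) then maps them to [-1, 1], which is the value range the SD VAE was trained on:

    from torchvision import transforms

    train_transforms = transforms.Compose(
        [
            transforms.Resize(512, interpolation=transforms.InterpolationMode.BILINEAR),
            transforms.RandomCrop(512),
            transforms.ToTensor(),
            # (x - 0.5) / 0.5 rescales [0, 1] -> [-1, 1]
            transforms.Normalize([0.5], [0.5]),
        ]
    )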

@aandyw
Contributor

aandyw commented Jul 26, 2023

@ThibaultCastells Do you have any thoughts on why the VAE might be outputting white-washed reconstructions? I seem to have seen some Civitai models with a similar issue. Not sure how it was resolved though.

@aandyw
Contributor

aandyw commented Jul 26, 2023

I noticed that there is no transforms.Normalize([0.5], [0.5]) applied to the images in the training script, and the output images seem to be correct. However, in the other model training scripts, normalization is performed before using the VAE. Is it an error in the other scripts?

You're right, a blunder on my part. I guess it must have been removed when I was playing around with things and I forgot to put it back. Thanks for the catch.

@ThibaultCastells

@pie31415 I am not too surprised that this issue happens when using only the MSE loss, because this is a very different training configuration from the one in the paper, so we don't know what to expect in this case. Therefore I would like to confirm whether @zhuliyi0 changed the default values of the loss scale coefficients when he checked the new code. And if so, what values were used?

Note that when they finetuned the VAE for SD they only finetuned the decoder; that's probably why they do not use the KL loss (they do not need it, since the decoder does not affect the latent space).

Also, not related, but is it normal that there is no .eval() when evaluating the model (and therefore no .train() after evaluation)? Is it handled by the accelerator.unwrap_model function?
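
For reference, a conventional pattern would be the sketch below (as far as I know, accelerator.unwrap_model only strips the distributed wrapper and does not switch the model out of train mode):

    vae.eval()
    with torch.no_grad():
        for batch in test_dataloader:
            recon = vae(batch["pixel_values"]).sample  # AutoencoderKL's forward returns a DecoderOutput
            # ... log originals vs. reconstructions ...
    vae.train()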

@bghira
Contributor

bghira commented Aug 24, 2023

see #4636 for what happens with that normalisation range :D

@yeonsikch

I compared this VAE training script with the original latent-diffusion code.

As a result, I'm sure that we need to fix our loss term.
In my experience, the original latent-diffusion code is better.

I recommend using the original latent-diffusion code.
If you want to use it, you should write your own custom dataset class, as in the sketch below.
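
For anyone going that route, a minimal custom dataset could look like this (a sketch, assuming the usual latent-diffusion convention of returning an "image" entry as an HWC float array in [-1, 1]; adapt the keys to the config you use):

    import glob
    import numpy as np
    from PIL import Image
    from torch.utils.data import Dataset

    class FolderDataset(Dataset):
        def __init__(self, root, size=256):
            self.paths = sorted(glob.glob(f"{root}/*.jpg") + glob.glob(f"{root}/*.png"))
            self.size = size

        def __len__(self):
            return len(self.paths)

        def __getitem__(self, i):
            img = Image.open(self.paths[i]).convert("RGB").resize((self.size, self.size))
            arr = np.array(img).astype(np.float32) / 127.5 - 1.0  # HWC, scaled to [-1, 1]
            return {"image": arr}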

@aleksmirosh

I compared this VAE training script with the original latent-diffusion code.

As a result, I'm sure that we need to fix our loss term. In my experience, the original latent-diffusion code is better.

I recommend using the original latent-diffusion code. If you want to use it, you should write your own custom dataset class.

Do they use the same autoencoder model?
Why do you think it is better?
I tried both and did not get results with either.

@aandyw
Contributor

aandyw commented Sep 26, 2023

I fixed some issues:

  1. fp16
  2. multi-GPU
  3. EMA
  4. train script
  5. Hugging Face upload

But I still have to fix the loss function; the current loss function is not right.

If anyone needs to fix the fp16 issue, try this:

vae, vae.encoder, vae.decoder = accelerator.prepare(vae, vae.encoder, vae.decoder)

Can you make a commit for these changes? I'll work on the VAE loss and hopefully try to get it matching LDM.

@FrsECM

FrsECM commented Oct 30, 2023

Hi everyone!
Thanks a lot for this thread! It helped me a lot.
I've implemented my own version of the script based on your great work.

I have a question regarding the decoder training. In my mind, it was necessary to sample from the encoder's output distribution:

src: https://towardsdatascience.com/understanding-variational-autoencoders-vaes-f70510919f73

But in your implementation you feed the decoder directly with the "mode":

    for epoch in range(first_epoch, args.num_train_epochs):
        vae.train()
        train_loss = 0.0
        for step, batch in enumerate(train_dataloader):
            with accelerator.accumulate(vae):
                target = batch["pixel_values"].to(weight_dtype)

                # https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/autoencoder_kl.py
                posterior = vae.encode(target).latent_dist
                z = posterior.mode()
                pred = vae.decode(z).sample

I tried to challenge this assumption, and I performed 2 trainings with 100k iterations on CelebA-HQ.

In my different tries, I noticed that mode seems to render better images, but I don't know what the behavior will be when we train an LDM on top of that. Giving uncertainty to the latent space can definitely help the subsequent diffusion process.

image

Does somebody have any insights about that?
Thanks!
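
For context, in diffusers' DiagonalGaussianDistribution the two options boil down to the sketch below (the reparameterization trick, not the library's exact code):

    posterior = vae.encode(target).latent_dist

    z_mode = posterior.mode()      # deterministic: z = mu
    # stochastic: z = mu + sigma * eps, with eps ~ N(0, I);
    # this is the branch that makes the KL term meaningful during training
    z_sample = posterior.sample()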

@ThibaultCastells

Hello @FrsECM !
Thank you for sharing your results and this interesting Medium post!
Could you share the implementation you tried?

I don't know what the behavior will be when we train an LDM on top of that. Giving uncertainty to the latent space can definitely help the subsequent diffusion process.

It's hard to tell without trying, but I think we also need to keep in mind that Stable Diffusion's performance is bounded by the VAE's performance: if the VAE can only generate blurry images, then Stable Diffusion will produce blurry images, no matter how well the UNet is trained.

@trouble-maker007

@FrsECM Would you like to share your VAE training implementation?

@FrsECM

FrsECM commented Nov 22, 2023

Hi @trouble-maker007 @ThibaultCastells,
I tried training an LDM model based on both VAEs: the one with sampling and the one without.
Below are the results for the same number of iterations:

image

image

For me this confirms that it's better to train with uncertainty.
Anyway, the issues I face in making the VAE converge remain on generated images.

You can find my training script here:
https://github.com/FrsECM/diffusers/blob/add-semantic-diffusion/examples/community/semantic_image_synthesis/train_vae_ldm.py

@bigcornflake

In my different tries, I noticed that mode seems to render better images, but I don't know what the behavior will be when we train an LDM on top of that. [...] (quoting the comment above)

I think your method is correct. Although using posterior.mode() looks better, it actually abandons randomness.

@sapkun

sapkun commented Mar 18, 2024

@yeonsikch

Hello, I have tried what you said for the fp16 issue, but I still get an error:

    (
        vae,
        vae.encode,
        vae.decode,
        optimizer,
        train_dataloader,
        test_dataloader,
        lr_scheduler,
    ) = accelerator.prepare(
        vae, vae.encode, vae.decode, optimizer, train_dataloader, test_dataloader, lr_scheduler
    )

    Traceback (most recent call last):
  File "train_vae.py", line 551, in <module>
    main()
  File "train_vae.py", line 509, in main
    optimizer.step()
  File "/abdB5045/sd_model/lib/python3.8/site-packages/accelerate/optimizer.py", line 132, in step
    self.scaler.step(self.optimizer, closure)
  File "/abdB5045/sd_model/lib/python3.8/site-packages/torch/cuda/amp/grad_scaler.py", line 446, in step
    self.unscale_(optimizer)
  File "/abdB5045/sd_model/lib/python3.8/site-packages/torch/cuda/amp/grad_scaler.py", line 336, in unscale_
    optimizer_state["found_inf_per_device"] = self._unscale_grads_(
  File "/abdB5045/sd_model/lib/python3.8/site-packages/torch/cuda/amp/grad_scaler.py", line 258, in _unscale_grads_
    raise ValueError("Attempting to unscale FP16 gradients.")
ValueError: Attempting to unscale FP16 gradients.

Can you share your full script? Thanks!

- `Accelerate` version: 0.27.0
- Platform: Linux-3.10.0-1160.el7.x86_64-x86_64-with-glibc2.10
- Python version: 3.8.12
- Numpy version: 1.24.4
- PyTorch version (GPU?): 2.2.0+cu121 (True)
- PyTorch XPU available: False
- PyTorch NPU available: False
- System RAM: 1007.35 GB
- GPU type: NVIDIA A100-PCIE-40GB
- `Accelerate` default config:
        - compute_environment: LOCAL_MACHINE
        - distributed_type: NO
        - mixed_precision: fp16
        - use_cpu: False
        - debug: False
        - num_processes: 1
        - machine_rank: 0
        - num_machines: 1
        - gpu_ids: 0
        - rdzv_backend: static
        - same_network: True
        - main_training_function: main
        - downcast_bf16: no
        - tpu_use_cluster: False
        - tpu_use_sudo: False
        - tpu_env: []

@yeonsikch

@sapkun
I've been trying to train with the diffusers VAE code here, but the problem is that it doesn't have LPIPS as a loss function. Of course, the loss term is something you can customize, but if you want to replicate the original, use the latent-diffusion code. I was able to finish training the VAE using that code.

However, that VAE code isn't perfect either. I've modified it to train with my own data and will share it again once it's up on git.

@sapkun

sapkun commented Mar 20, 2024

Thanks for your reply. My question is: you mentioned that you used the accelerator.prepare method to prepare various components, (vae, vae.encode, vae.decode, optimizer, train_dataloader, test_dataloader, lr_scheduler) = accelerator.prepare(vae, vae.encode, vae.decode, optimizer, train_dataloader, test_dataloader, lr_scheduler), but I was unable to resolve the issue related to mixed-precision training (fp16) and encountered difficulties running the code across multiple GPUs. That is my question.
FYI, I used an MSE + 0.1 * LPIPS term to obtain good results; I removed the KL term from the loss.
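
(So effectively `loss = mse_loss + 0.1 * lpips_loss`, i.e. the `kl_scale` term set to 0.)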

@yeonsikch

@sapkun
It's been a while since I've used that code, so it took me a while to find it.
I'll post the full code here.

"""
TODO: fix training mixed precision -- issue with AdamW optimizer
"""

import argparse
import logging
import math
import os
from pathlib import Path

import accelerate
import numpy as np
import torch
import torch.nn.functional as F
import torch.utils.checkpoint
import torchvision
from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.utils import ProjectConfiguration, set_seed
from datasets import load_dataset
from huggingface_hub import create_repo, upload_folder
from packaging import version
from torchvision import transforms
from tqdm.auto import tqdm

import diffusers
from diffusers import AutoencoderKL
from diffusers.optimization import get_scheduler
from diffusers.training_utils import EMAModel
from diffusers.utils import is_wandb_available

import lpips
from PIL import Image

if is_wandb_available():
    import wandb

logger = get_logger(__name__, log_level="INFO")


@torch.no_grad()
def log_validation(args, repo_id, test_dataloader, vae, accelerator, weight_dtype, epoch):
    logger.info("Running validation... ")

    vae_model = accelerator.unwrap_model(vae)
    images = []
    
    for _, sample in enumerate(test_dataloader):
        x = sample["pixel_values"].to(weight_dtype)
        reconstructions = vae_model(x).sample
        images.append(
            torch.cat([sample["pixel_values"].cpu(), reconstructions.cpu()], axis=0)
        )

    for tracker in accelerator.trackers:
        if tracker.name == "tensorboard":
            np_images = np.stack([np.asarray(img) for img in images])
            tracker.writer.add_images(
                "Original (left), Reconstruction (right)", np_images, epoch
            )
        elif tracker.name == "wandb":
            tracker.log(
                {
                    "Original (left), Reconstruction (right)": [
                        wandb.Image(torchvision.utils.make_grid(image))
                        for _, image in enumerate(images)
                    ]
                }
            )
        else:
            logger.warning(f"image logging not implemented for {tracker.name}")

    if args.push_to_hub:
        try:
            save_model_card(args, repo_id, images, repo_folder=args.output_dir)
            upload_folder(
                repo_id=repo_id,
                folder_path=args.output_dir,
                commit_message="End of training",
                ignore_patterns=["step_*", "epoch_*"],
            )
        except Exception:
            logger.info(f"UserWarning: Your Hugging Face storage is limited. The weights will only be saved locally to: {args.output_dir}")
    
    del vae_model
    torch.cuda.empty_cache()

def make_image_grid(imgs, rows, cols):

    w, h = imgs[0].size
    grid = Image.new("RGB", size=(cols * w, rows * h))

    for i, img in enumerate(imgs):
        grid.paste(img, box=(i % cols * w, i // cols * h))
    return grid

def save_model_card(
    args,
    repo_id: str,
    images=None,
    repo_folder=None,
):
    # img_str = ""
    # if len(images) > 0:
    #     image_grid = make_image_grid(images, 1, "example")
    #     image_grid.save(os.path.join(repo_folder, "val_imgs_grid.png"))
    #     img_str += "![val_imgs_grid](./val_imgs_grid.png)\n"

    yaml = f"""
---
license: creativeml-openrail-m
base_model: {args.pretrained_model_name_or_path}
datasets:
- {args.dataset_name}
tags:
- stable-diffusion
- stable-diffusion-diffusers
- text-to-image
- diffusers
inference: true
---
    """
    model_card = f"""
# Text-to-image finetuning - {repo_id}

This pipeline was finetuned from **{args.pretrained_model_name_or_path}** on the **{args.dataset_name}** dataset. Below are some example images generated with the finetuned pipeline using the following prompts: Nothing: \n

## Training info

These are the key hyperparameters used during training:

* Epochs: {args.num_train_epochs}
* Learning rate: {args.learning_rate}
* Batch size: {args.train_batch_size}
* Gradient accumulation steps: {args.gradient_accumulation_steps}
* Image resolution: {args.resolution}
* Mixed-precision: {args.mixed_precision}

"""
    wandb_info = ""
    wandb_run_url = None  # initialize so we don't hit a NameError when wandb is unavailable
    if is_wandb_available() and wandb.run is not None:
        wandb_run_url = wandb.run.url

    if wandb_run_url is not None:
        wandb_info = f"""
More information on all the CLI arguments and the environment are available on your [`wandb` run page]({wandb_run_url}).
"""

    model_card += wandb_info

    with open(os.path.join(repo_folder, "README.md"), "w") as f:
        f.write(yaml + model_card)

def parse_args():
    parser = argparse.ArgumentParser(
        description="Simple example of a VAE training script."
    )
    parser.add_argument(
        "--pretrained_model_name_or_path",
        type=str,
        default=None,
        required=False,
        help="Path to pretrained model or model identifier from huggingface.co/models.",
    )
    parser.add_argument(
        "--revision",
        type=str,
        default=None,
        required=False,
        help="Revision of pretrained model identifier from huggingface.co/models.",
    )
    parser.add_argument(
        "--dataset_name",
        type=str,
        default=None,
        help=(
            "The name of the Dataset (from the HuggingFace hub) to train on (could be your own, possibly private,"
            " dataset). It can also be a path pointing to a local copy of a dataset in your filesystem,"
            " or to a folder containing files that 🤗 Datasets can understand."
        ),
    )
    parser.add_argument(
        "--dataset_config_name",
        type=str,
        default=None,
        help="The config of the Dataset, leave as None if there's only one config.",
    )
    parser.add_argument(
        "--train_data_dir",
        type=str,
        default=None,
        help=(
            "A folder containing the training data. Folder contents must follow the structure described in"
            " https://huggingface.co/docs/datasets/image_dataset#imagefolder. In particular, a `metadata.jsonl` file"
            " must exist to provide the captions for the images. Ignored if `dataset_name` is specified."
        ),
    )
    parser.add_argument(
        "--test_data_dir",
        type=str,
        default=None,
        help=(
            "A folder containing the validation data. Folder contents must follow the structure described in"
            " https://huggingface.co/docs/datasets/image_dataset#imagefolder. In particular, a `metadata.jsonl` file"
            " must exist to provide the captions for the images. Ignored if `dataset_name` is specified."
        ),
    )
    parser.add_argument(
        "--image_column",
        type=str,
        default="image",
        help="The column of the dataset containing an image.",
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default="outputs",
        help="The output directory where the model predictions and checkpoints will be written.",
    )
    parser.add_argument(
        "--huggingface_repo",
        type=str,
        default="vae-model-finetuned",
        help="The output directory where the model predictions and checkpoints will be written.",
    )
    parser.add_argument(
        "--cache_dir",
        type=str,
        default=None,
        help="The directory where the downloaded models and datasets will be stored.",
    )
    parser.add_argument(
        "--seed", type=int, default=21, help="A seed for reproducible training."
    )
    parser.add_argument(
        "--resolution",
        type=int,
        default=512,
        help=(
            "The resolution for input images, all the images in the train/validation dataset will be resized to this"
            " resolution"
        ),
    )
    parser.add_argument(
        "--train_batch_size",
        type=int,
        default=1,
        help="Batch size (per device) for the training dataloader.",
    )
    parser.add_argument("--num_train_epochs", type=int, default=100)
    parser.add_argument(
        "--gradient_accumulation_steps",
        type=int,
        default=2,
        help="Number of updates steps to accumulate before performing a backward/update pass.",
    )
    parser.add_argument(
        "--gradient_checkpointing",
        action="store_true",
        help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
    )
    parser.add_argument(
        "--learning_rate",
        type=float,
        default=1.5e-7, # Reference : Waifu-diffusion-v1-4 config
        # default=4.5e-8,
        help="Initial learning rate (after the potential warmup period) to use.",
    )
    parser.add_argument(
        "--scale_lr",
        action="store_true",
        default=True,
        help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
    )
    parser.add_argument(
        "--lr_scheduler",
        type=str,
        default="constant",
        help=(
            'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
            ' "constant", "constant_with_warmup"]'
        ),
    )
    parser.add_argument(
        "--lr_warmup_steps",
        type=int,
        default=500,
        help="Number of steps for the warmup in the lr scheduler.",
    )
    parser.add_argument(
        "--logging_dir",
        type=str,
        default="logs",
        help=(
            "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
            " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
        ),
    )
    parser.add_argument(
        "--mixed_precision",
        type=str,
        default=None,
        choices=["no", "fp16", "bf16"],
        help=(
            "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
            " 1.10.and an Nvidia Ampere GPU.  Default to the value of accelerate config of the current system or the"
            " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
        ),
    )
    parser.add_argument(
        "--report_to",
        type=str,
        default="tensorboard",
        help=(
            'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
            ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
        ),
    )
    parser.add_argument(
        "--checkpointing_steps",
        type=int,
        default=5000,
        help=(
            "Save a checkpoint of the training state every X updates. These checkpoints are only suitable for resuming"
            " training using `--resume_from_checkpoint`."
        ),
    )
    parser.add_argument(
        "--checkpoints_total_limit",
        type=int,
        default=None,
        help=(
            "Max number of checkpoints to store. Passed as `total_limit` to the `Accelerator` `ProjectConfiguration`."
            " See Accelerator::save_state https://huggingface.co/docs/accelerate/package_reference/accelerator#accelerate.Accelerator.save_state"
            " for more docs"
        ),
    )
    parser.add_argument(
        "--resume_from_checkpoint",
        type=str,
        default=None,
        help=(
            "Whether training should be resumed from a previous checkpoint. Use a path saved by"
            ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
        ),
    )
    parser.add_argument(
        "--test_samples",
        type=int,
        default=20,
        help="Number of images to remove from training set to be used as validation.",
    )
    parser.add_argument(
        "--validation_epochs",
        type=int,
        default=5,
        help="Run validation every X epochs.",
    )
    parser.add_argument(
        "--tracker_project_name",
        type=str,
        default="vae-fine-tune",
        help=(
            "The `project_name` argument passed to Accelerator.init_trackers for"
            " more information see https://huggingface.co/docs/accelerate/v0.17.0/en/package_reference/accelerator#accelerate.Accelerator"
        ),
    )
    parser.add_argument(
        "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes."
    )
    parser.add_argument("--use_ema", action="store_true", help="Whether to use EMA model.")
    parser.add_argument(
        "--kl_scale",
        type=float,
        default=1e-6,
        help="Scaling factor for the Kullback-Leibler divergence penalty term.",
    )
    parser.add_argument(
        "--lpips_scale",
        type=float,
        default=5e-1,
        help="Scaling factor for the LPIPS metric",
    )
    parser.add_argument(
        "--lpips_start",
        type=int,
        default=50001,
        help="Start for the LPIPS metric",
    )
    parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
    parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
    parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
    parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
    parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
    parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
    parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
    
    args = parser.parse_args()

    # args.mixed_precision='fp16'
    # args.pretrained_model_name_or_path="runwayml/stable-diffusion-v1-5"
    # args.dataset_name="yeonsikc/sample_repeat"
    # args.seed=21
    # args.train_batch_size=1
    # args.num_train_epochs=100
    # args.learning_rate=1e-07
    # args.output_dir="/app/output_vae"
    # args.report_to='wandb'
    # args.push_to_hub=True
    # args.validation_epochs=1
    # args.resolution=128
    # args.use_8bit_adam=False

    # Sanity checks
    if args.dataset_name is None and args.train_data_dir is None:
        raise ValueError("Need either a dataset name or a training folder.")

    return args

# train_transforms = transforms.Compose(
#     [
#         transforms.Resize(
#             (128,128), interpolation=transforms.InterpolationMode.BILINEAR
#         ),
#         # transforms.RandomCrop(128),
#         transforms.ToTensor(),
#         transforms.Normalize([0.5], [0.5]),
#     ]
# )

# def preprocess(examples):
#     images = [image.convert("RGB") for image in examples["image"]]
#     examples["pixel_values"] = [train_transforms(image) for image in images]
#     return examples

# def collate_fn(examples):
#     pixel_values = torch.stack([example["pixel_values"] for example in examples])
#     pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
#     return {"pixel_values": pixel_values}

def main():
    args = parse_args()

    logging_dir = os.path.join(args.output_dir, args.logging_dir)

    accelerator_project_config = ProjectConfiguration(
        total_limit=args.checkpoints_total_limit,
        project_dir=args.output_dir,
        logging_dir=logging_dir,
    )

    accelerator = Accelerator(
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        mixed_precision=args.mixed_precision,
        log_with=args.report_to,
        project_config=accelerator_project_config,
    )

    # Make one log on every process with the configuration for debugging.
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
    )
    logger.info(accelerator.state, main_process_only=False)

    if args.seed is not None:
        set_seed(args.seed)

    if accelerator.is_main_process:
        if args.output_dir is not None:
            os.makedirs(args.output_dir, exist_ok=True)
            
        repo_id = None  # defined up front; log_validation receives it even when not pushing to the Hub
        if args.push_to_hub:
            repo_id = create_repo(
                repo_id=Path(args.huggingface_repo).name, exist_ok=True, token=args.hub_token
            ).repo_id

    # Load vae
    try:
        vae = AutoencoderKL.from_pretrained(
            args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision, torch_dtype=torch.float32
        )
    except Exception:
        vae = AutoencoderKL.from_pretrained(
            args.pretrained_model_name_or_path, revision=args.revision, torch_dtype=torch.float32
        )
    if args.use_ema:
        try:
            ema_vae = AutoencoderKL.from_pretrained(
                args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision)
        except Exception:
            ema_vae = AutoencoderKL.from_pretrained(
                args.pretrained_model_name_or_path, revision=args.revision, torch_dtype=torch.float32)
        ema_vae = EMAModel(ema_vae.parameters(), model_cls=AutoencoderKL, model_config=ema_vae.config)
        
    vae.requires_grad_(True)
    vae_params = vae.parameters()

    # `accelerate` 0.16.0 will have better support for customized saving
    if version.parse(accelerate.__version__) >= version.parse("0.16.0"):
        # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
        def save_model_hook(vae, weights, output_dir):
            if args.use_ema:
                ema_vae.save_pretrained(os.path.join(output_dir, "vae_ema"))

            logger.info(f"{vae = }")
            vae = vae[0]
            vae.save_pretrained(os.path.join(output_dir, "vae"))

        def load_model_hook(vae, input_dir):
            if args.use_ema:
                load_model = EMAModel.from_pretrained(os.path.join(input_dir, "vae_ema"), AutoencoderKL)
                ema_vae.load_state_dict(load_model.state_dict())
                ema_vae.to(accelerator.device)
                del load_model

            # load diffusers style into model
            load_model = AutoencoderKL.from_pretrained(input_dir, subfolder="vae")
            vae.register_to_config(**load_model.config)

            vae.load_state_dict(load_model.state_dict())
            del load_model

        accelerator.register_save_state_pre_hook(save_model_hook)
        accelerator.register_load_state_pre_hook(load_model_hook)

    if args.gradient_checkpointing:
        vae.enable_gradient_checkpointing()

    if args.scale_lr:
        args.learning_rate = (
            args.learning_rate
            * args.gradient_accumulation_steps
            * args.train_batch_size
            * accelerator.num_processes
        )
        
    # Initialize the optimizer
    if args.use_8bit_adam:
        try:
            import bitsandbytes as bnb
        except ImportError:
            raise ImportError(
                "Please install bitsandbytes to use 8-bit Adam. You can do so by running `pip install bitsandbytes` or `pip install bitsandbytes-windows` for Windows"
            )

        optimizer_cls = bnb.optim.AdamW8bit
    else:
        optimizer_cls = torch.optim.AdamW

    optimizer = optimizer_cls(
        vae.parameters(),
        lr=args.learning_rate,
        betas=(args.adam_beta1, args.adam_beta2),
        weight_decay=args.adam_weight_decay,
        eps=args.adam_epsilon,
    )

    # Get the datasets: you can either provide your own training and evaluation files (see below)
    # or specify a Dataset from the hub (the dataset will be downloaded automatically from the datasets Hub).

    # In distributed training, the load_dataset function guarantees that only one local process can concurrently
    # download the dataset.
    if args.dataset_name is not None:
        # Downloading and loading a dataset from the hub.
        dataset = load_dataset(
            args.dataset_name,
            args.dataset_config_name,
            cache_dir=args.cache_dir,
        )
    else:
        data_files = {}
        if args.train_data_dir is not None:
            data_files["train"] = os.path.join(args.train_data_dir, "**")
        dataset = load_dataset(
            "imagefolder",
            data_files=data_files,
            cache_dir=args.cache_dir,
        )

    column_names = dataset["train"].column_names
    if args.image_column is None:
        image_column = column_names[0]
    else:
        image_column = args.image_column
        if image_column not in column_names:
            raise ValueError(
                f"--image_column' value '{args.image_column}' needs to be one of: {', '.join(column_names)}"
            )

    train_transforms = transforms.Compose(
        [
            transforms.Resize(
                args.resolution, interpolation=transforms.InterpolationMode.BILINEAR
            ),
            transforms.RandomCrop(args.resolution),
            transforms.ToTensor(),
            transforms.Normalize([0.5], [0.5]),
        ]
    )
    
    # test_transforms = transforms.Compose(
    #     [
    #         transforms.Resize(
    #             args.resolution, interpolation=transforms.InterpolationMode.BILINEAR
    #         ),
    #         transforms.CenterCrop(args.resolution),
    #         transforms.ToTensor(),
    #         transforms.Normalize([0.5], [0.5]),
    #     ]
    # )

    def preprocess(examples):
        images = [image.convert("RGB") for image in examples[image_column]]
        examples["pixel_values"] = [train_transforms(image) for image in images]
        return examples

    # def test_preprocess(examples):
    #     images = [image.convert("RGB") for image in examples[image_column]]
    #     examples["pixel_values"] = [test_transforms(image) for image in images]
    #     return examples

    with accelerator.main_process_first():
        # Load test data from test_data_dir
        if(args.test_data_dir is not None and args.train_data_dir is not None):
            logger.info(f"load test data from {args.test_data_dir}")
            test_dir = os.path.join(args.test_data_dir, "**")        
            test_dataset = load_dataset(
                "imagefolder",
                data_files=test_dir,
                cache_dir=args.cache_dir,
            )
            # Set the training transforms
            train_dataset = dataset["train"].with_transform(preprocess)
            test_dataset = test_dataset["train"].with_transform(preprocess)
        # Load train/test data from train_data_dir
        elif "test" in dataset.keys():
            train_dataset = dataset["train"].with_transform(preprocess)
            test_dataset = dataset["test"].with_transform(preprocess)
        # Split into train/test
        else:
            dataset = dataset["train"].train_test_split(test_size=args.test_samples)        
            # Set the training transforms
            train_dataset = dataset["train"].with_transform(preprocess)
            test_dataset = dataset["test"].with_transform(preprocess)

    def collate_fn(examples):
        pixel_values = torch.stack([example["pixel_values"] for example in examples])
        pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
        return {"pixel_values": pixel_values}

    # DataLoaders creation:
    train_dataloader = torch.utils.data.DataLoader(
        train_dataset,
        shuffle=True,
        collate_fn=collate_fn,
        batch_size=args.train_batch_size,
        num_workers=args.train_batch_size*accelerator.num_processes,
    )

    test_dataloader = torch.utils.data.DataLoader(
        test_dataset, shuffle=False, collate_fn=collate_fn, batch_size=args.train_batch_size, num_workers=1,
    )

    lr_scheduler = get_scheduler(
        args.lr_scheduler,
        optimizer=optimizer,
        num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
        num_training_steps=args.num_train_epochs * args.gradient_accumulation_steps,  # note: not scaled by steps per epoch; harmless for the default "constant" scheduler
    )

    # Prepare everything with our `accelerator`.
    
    (vae,
        vae.encoder,
        vae.decoder,
        optimizer,
        train_dataloader,
        test_dataloader,
        lr_scheduler,
    ) = accelerator.prepare(
        vae,vae.encoder, vae.decoder, optimizer, train_dataloader, test_dataloader, lr_scheduler
    )

    if args.use_ema:
        ema_vae.to(accelerator.device)

    weight_dtype = torch.float32
    if accelerator.mixed_precision == "fp16":
        weight_dtype = torch.float16
    elif accelerator.mixed_precision == "bf16":
        weight_dtype = torch.bfloat16

    
    # We need to initialize the trackers we use, and also store our configuration.
    # The trackers initializes automatically on the main process.
    if accelerator.is_main_process:
        tracker_config = dict(vars(args))
        accelerator.init_trackers(args.tracker_project_name, tracker_config)

    num_update_steps_per_epoch = math.ceil(
        len(train_dataloader) / args.gradient_accumulation_steps
    )
    args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch

    # ------------------------------ TRAIN ------------------------------ #
    total_batch_size = (
        args.train_batch_size
        * accelerator.num_processes
        * args.gradient_accumulation_steps
    )

    logger.info("***** Running training *****")
    logger.info(f"  Num examples = {len(train_dataset)}")
    logger.info(f"  Num test samples = {len(test_dataset)}")
    logger.info(f"  Num Epochs = {args.num_train_epochs}")
    logger.info(f"  Instantaneous batch size per device = {args.train_batch_size}")
    logger.info(
        f"  Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}"
    )
    logger.info(f"  Gradient Accumulation steps = {args.gradient_accumulation_steps}")
    global_step = 0
    first_epoch = 0

    # Potentially load in the weights and states from a previous save
    if args.resume_from_checkpoint:
        if args.resume_from_checkpoint != "latest":
            path = os.path.basename(args.resume_from_checkpoint)
        else:
            # Get the most recent checkpoint
            # dirs = os.listdir(args.output_dir)
            dirs = os.listdir(args.resume_from_checkpoint)
            dirs = [d for d in dirs if d.startswith("checkpoint")]
            dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
            path = dirs[-1] if len(dirs) > 0 else None

        if path is None:
            print(
                f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
            )
            args.resume_from_checkpoint = None
        else:
            print(f"Resuming from checkpoint {path}")
            # accelerator.load_state(os.path.join(args.output_dir, path))
            accelerator.load_state(os.path.join(path)) #kiml
            global_step = int(path.split("-")[1])

            resume_global_step = global_step * args.gradient_accumulation_steps
            first_epoch = global_step // num_update_steps_per_epoch
            resume_step = resume_global_step % (num_update_steps_per_epoch * args.gradient_accumulation_steps)

    progress_bar = tqdm(
        range(global_step, args.max_train_steps),
        disable=not accelerator.is_local_main_process,
    )
    progress_bar.set_description("Steps")

    lpips_loss_fn = lpips.LPIPS(net="alex").to(accelerator.device, dtype=weight_dtype)
    lpips_loss_fn.requires_grad_(False)

    for epoch in range(first_epoch, args.num_train_epochs):
        vae.train()
        train_loss = 0.0
        logger.info(f"{epoch = }")

        for step, batch in enumerate(train_dataloader):
            # Skip steps until we reach the resumed step
            if args.resume_from_checkpoint and epoch == first_epoch and step < resume_step:
                if step % args.gradient_accumulation_steps == 0:
                    progress_bar.update(1)
                continue
            with accelerator.accumulate(vae):
                target = batch["pixel_values"].to(weight_dtype)

                # https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/autoencoder_kl.py
                if accelerator.num_processes > 1:
                    posterior = vae.module.encode(target).latent_dist
                else:
                    posterior = vae.encode(target).latent_dist
                
                # z = mean                    if posterior.mode()
                # z = mean + std * epsilon    if posterior.sample()
                z = posterior.sample()  # not mode()
                if accelerator.num_processes > 1:
                    pred = vae.module.decode(z).sample
                else:
                    pred = vae.decode(z).sample

                kl_loss = posterior.kl().mean()
                
                # if global_step > args.mse_start:
                #     pixel_loss = F.mse_loss(pred.float(), target.float(), reduction="mean")
                # else:
                #     pixel_loss = F.mse_loss(pred.float(), target.float(), reduction="mean")
                
                mse_loss = F.mse_loss(pred.float(), target.float(), reduction="mean")
                
                with torch.no_grad():
                    # note: computed under no_grad, so LPIPS shows up in the logs and in the
                    # reported loss value, but it contributes no gradient to the update here
                    lpips_loss = lpips_loss_fn(pred.to(dtype=weight_dtype), target).mean()
                    if not torch.isfinite(lpips_loss):
                        lpips_loss = torch.tensor(0)

                loss = (
                    mse_loss + args.lpips_scale * lpips_loss + args.kl_scale * kl_loss
                )

                if not torch.isfinite(loss):
                    pred_mean = pred.mean()
                    target_mean = target.mean()
                    # note: this only logs a warning; training continues with the next batch
                    logger.info(f"\nWARNING: non-finite loss (pred mean {pred_mean}, target mean {target_mean})")

                # Gather the losses across all processes for logging (if we use distributed training).
                avg_loss = accelerator.gather(loss.repeat(args.train_batch_size)).mean()
                train_loss += avg_loss.item() / args.gradient_accumulation_steps
                    
                accelerator.backward(loss)
                if accelerator.sync_gradients:
                    accelerator.clip_grad_norm_(vae.parameters(), args.max_grad_norm)
                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad()

            # Checks if the accelerator has performed an optimization step behind the scenes
            if accelerator.sync_gradients:
                if args.use_ema:
                    ema_vae.step(vae.parameters())
                progress_bar.update(1)
                global_step += 1
                accelerator.log({"train_loss": train_loss}, step=global_step)
                train_loss = 0.0

                if global_step % args.checkpointing_steps == 0:
                    if accelerator.is_main_process:
                        save_path = os.path.join(
                            args.output_dir, f"checkpoint-{global_step}"
                        )
                        accelerator.save_state(save_path)
                        logger.info(f"Saved state to {save_path}")

            logs = {
                "step_loss": loss.detach().item(),
                "lr": lr_scheduler.get_last_lr()[0],
                "mse": mse_loss.detach().item(),
                "lpips": lpips_loss.detach().item(),
                "kl": kl_loss.detach().item(),
            }
            accelerator.log(logs)
            progress_bar.set_postfix(**logs)

        if accelerator.is_main_process:
            if epoch % args.validation_epochs == 0:
                with torch.no_grad():
                    log_validation(args, repo_id, test_dataloader, vae, accelerator, weight_dtype, epoch)

    # Create the pipeline using the trained modules and save it.
    accelerator.wait_for_everyone()
    if accelerator.is_main_process:
        vae = accelerator.unwrap_model(vae)
        if args.use_ema:
            ema_vae.copy_to(vae.parameters())
        vae.save_pretrained(args.output_dir)

    accelerator.end_training()


if __name__ == "__main__":
    # torch.autograd.set_detect_anomaly(True)
    main()
    

@yeonsikch

@sapkun

And if you train a VAE to use it with Stable Diffusion, you should definitely train only the decoder part.
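
A minimal sketch of decoder-only finetuning with a diffusers AutoencoderKL (the learning rate is just a placeholder; post_quant_conv sits between the latents and the decoder, so it arguably belongs to the decoder side as well):

    vae.requires_grad_(False)
    vae.decoder.requires_grad_(True)
    vae.post_quant_conv.requires_grad_(True)

    optimizer = torch.optim.AdamW(
        (p for p in vae.parameters() if p.requires_grad), lr=1e-5
    )

This leaves the encoder, and therefore the latent space the UNet was trained on, untouched.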

@jiangyuhangcn

Hi! Can you finetune the VAE with fp16? Thanks!

@linnanwang

@yeonsikch thanks for the great work; however, I ran into a negative tensor dimension error when running your example code above, see below:
image

Any potential solutions? Thanks.

@humanely

I am a bit confused by HF's AutoencoderKL after training the CompVis autoencoder.
There is no tokenizer in HF, whereas CompVis uses a tokenizer. Basically, I am training a new language, so I need to use a custom tokenizer. Why the difference in autoencoders?

@yeonsikch

@linnanwang I'm sorry, I don't know either. I think it's a version issue (torch version).

@yeonsikch

@humanely The AutoencoderKL encodes images into a latent space. Therefore, you don't need a tokenizer when you train the AutoencoderKL (= VAE).
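
A quick round trip makes this concrete; no tokenizer or text is involved anywhere (shapes assume the SD 1.5 VAE, which downsamples by a factor of 8):

    import torch
    from diffusers import AutoencoderKL

    vae = AutoencoderKL.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="vae")

    img = torch.randn(1, 3, 512, 512)               # a normalized RGB batch
    latents = vae.encode(img).latent_dist.sample()  # -> (1, 4, 64, 64)
    recon = vae.decode(latents).sample              # -> (1, 3, 512, 512)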

@GiilDe

GiilDe commented Jul 28, 2024

I tried training an LDM model based on both VAEs: the one with sampling and the one without. [...] For me this confirms that it's better to train with uncertainty. (quoting the comment above)

Can you elaborate on this? Is the result here: left image, diffusion model sampling then feeding to the decoder trained with noise; right image, diffusion model sampling then feeding to the decoder trained without noise? And is scale = 7.5 the CFG scale?

@kukaiN

kukaiN commented Aug 16, 2024

@linnanwang @jiangyuhangcn I made a "fork" of the code here: https://github.com/kukaiN/vae_finetune/tree/main

I also had the same issue with mixed precision (I wanted to use bf16) and the negative dimension error (caused by mismatching precisions), so I made some modifications. I also added xformers to the code. The changes are listed in the readme, but tl;dr: force-initializing the trainable weights and using autocast in the training loop fixes the code to run with mixed precision.

@KimbingNg
Contributor

@kukaiN Thanks! Good fix. Can you elaborate more on the cause (mismatching precisions)?
In my code, I didn't use accelerator. Both the model weights and the inputs are in float32. The forward passes are called under with torch.cuda.amp.autocast(), but the same issue still arises.
I also found that removing the attention blocks in the mid_block makes my code run perfectly. So I believe it is the attention operations in the mid_block that lead to this error (only during mixed-precision training). Can you help me with that?

@kukaiN

kukaiN commented Aug 20, 2024

@KimbingNg I just want to confirm whether you reloaded the weights to float32 after the initial loading, and whether the autocast scope contains the forward pass up to the backpropagation, like the snippet below.

My suspicion is that the mixed-precision error happens because a part of the model is not properly cast. I made the changes based on the linked question/discussion, but I didn't pinpoint which layer is causing the problem.

# line 413 ~ 432:
# we load it with float32 here, but we cast it again right after
vae = AutoencoderKL.from_pretrained(model_path, ..., torch_dtype=torch.float32)

vae.requires_grad_(True)

# https://stackoverflow.com/questions/75802877/issues-when-using-huggingface-accelerate-with-fp16
# load params with fp32, which is auto casted later to mixed precision, may be needed for ema
#
# from stackoverflow's answer, it links to diffuser's sdxl training script example and in that code there's another link
# which points to https://github.com/huggingface/diffusers/pull/6514#discussion_r1447020705
# which may suggest we need to do all this casting before passing the learnable params to the optimizer

for param in vae.parameters():
    if param.requires_grad:
        param.data = param.to(torch.float32)

...

vae.to(dtype=weight_dtype) #weight_dtype is fp16 or bf16

...

# training loop:
#  with autocast(): 
#     forward process
#      ...
#     backpropagate the loss
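
To make the commented loop concrete, here is a minimal sketch of the autocast training step (vae, weight_dtype, train_dataloader, optimizer, and accelerator carried over from the snippet above):

import torch
import torch.nn.functional as F

for batch in train_dataloader:
    target = batch["pixel_values"].to(accelerator.device)
    # run the forward pass under autocast; it handles the per-op casting inside this scope
    with torch.autocast("cuda", dtype=weight_dtype):
        posterior = vae.encode(target).latent_dist
        z = posterior.sample()
        pred = vae.decode(z).sample
        loss = F.mse_loss(pred.float(), target.float())
    accelerator.backward(loss)
    optimizer.step()
    optimizer.zero_grad()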

@sayakpaul
Member

We have a VQGAN VAE: https://github.com/huggingface/diffusers/tree/main/examples/vqgan

@dill-shower

We have a VQGAN VAE: https://github.com/huggingface/diffusers/tree/main/examples/vqgan

Can we use this to finetune the SDXL VAE or the SD3 VAE?
