Combining transformer with diffusion models #375
Replies: 5 comments 6 replies
-
This is a very interesting line of work! I haven't yet grokked the architecture completely, but I'll dig into this when I have some more time. Thanks for spending the time; training guided diffusion (in the general case) is indeed an extremely slow endeavor. A fine usage for your 4x3090s!
-
This is really cool! I actually wanted to try this out and asked about it recently (you should join the DALLE-Pytorch Discord). I had only thought of the simplest approach: "deblurring" the VAE image output with the diffusion model, with the VAE frozen. The idea of using the diffusion model early is very interesting, but shouldn't you put it a bit later, to separate encoding from decoding (so the transformer can still work on the encoded tokens)? Also, I'm not very familiar with diffusion models, but do you have to add noise? For the simple model it would just "deblur" the VAE output. I was thinking the input would be the encoded image and the output would be the original, but I didn't think we would also need to add noise (the noise is already known and is just the difference between the two). If that's the case, maybe pix2pix, StyleGAN, etc. models are more appropriate? For the model where you add it at the start of the decoder, the noise would basically be the quantization error (codebook output minus encoder output).
-
Is this applicable to speech compression?
-
This is excellent
-
Hi guys, just wanted to show you something I've been working on. Here's a link to the code and models: https://github.com/Jack000/guided-diffusion
For a while now I've been thinking about how to combine transformer models with DDPMs. The intuition is that diffusion models are great at generating low-level texture but poor at global composition, while transformers have the opposite problem. It would be great if they could be combined somehow.
Anyway, after a lot of trial and error I've found an approach that works pretty well. The idea is really simple: just feed the image latents from the VQGAN encoder into a conditional DDPM. So after training a DALLE-pytorch model with the VQGAN VAE, you can feed the image latent codes into this diffusion model and get potentially much better image quality.
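To make the intended usage concrete, here's a rough sketch of that pipeline. `vqgan.encode` and `diffusion.p_sample_loop` follow the taming-transformers / guided-diffusion conventions, but the `latent` conditioning hook and the helper name are just illustrations of the idea, not the exact interface in the linked repo:

```python
import torch

@torch.no_grad()
def refine_with_diffusion(image, vqgan, diffusion, unet, image_size=256):
    # 1. get quantized latents from the frozen VQGAN encoder
    #    (at generation time these would come from the DALLE-pytorch transformer instead)
    z_q, _, _ = vqgan.encode(image)          # e.g. a [B, C, 16, 16] latent grid

    # 2. sample from the conditional DDPM, conditioning on those latents
    sample = diffusion.p_sample_loop(
        unet,
        (image.shape[0], 3, image_size, image_size),
        model_kwargs={"latent": z_q},        # hypothetical conditioning hook
    )
    return sample
```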
The diffusion model architecture is also really simple: it's exactly the same as the super-resolution model in guided-diffusion, except that where they concatenate the low-res image channel-wise with the noised input, we skip the encoder layers entirely and inject the latents into the middle block. This has the added benefit of letting us reuse the encoder and decoder weights from the pretrained model, so only the middle block needs to be retrained.
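Roughly, the injection might look like the sketch below. This is a simplified illustration of the idea (module and argument names are placeholders, not the actual guided-diffusion class names), assuming the VQGAN latents are projected to the middle block's channel count and added there:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class LatentConditionedUNet(nn.Module):
    def __init__(self, encoder, middle_block, decoder, latent_channels, model_channels):
        super().__init__()
        self.encoder = encoder            # reused pretrained weights
        self.middle_block = middle_block  # the only part that gets retrained
        self.decoder = decoder            # reused pretrained weights
        # project VQGAN latents to the middle block's channel count
        self.latent_proj = nn.Conv2d(latent_channels, model_channels, 1)

    def forward(self, x_noisy, t_emb, latent):
        hs = []
        h = x_noisy
        for block in self.encoder:
            h = block(h, t_emb)
            hs.append(h)                  # skip connections as usual
        # inject the latents here instead of concatenating them with the noised input
        latent = F.interpolate(latent, size=h.shape[-2:])
        h = h + self.latent_proj(latent)
        h = self.middle_block(h, t_emb)
        for block in self.decoder:
            h = block(torch.cat([h, hs.pop()], dim=1), t_emb)
        return h
```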
I also tried some other approaches with varying degrees of success:
This was the first and most obvious approach, but the images always came out distorted, even with blurring + noise augmentation.
Errors made by the low-resolution model seem to be amplified by the upscaling process. I think you really need the conditioning augmentation from Google's SR3 to make this work, although the 128x128 -> 64x64 -> 256x256 upscaling pipeline seems to work OK.
When I tried this, the diffusion model quickly converged to producing nearly exactly the same image as the VAE output, errors and all. I think it gets too much information from the skip connections.
I thought that since most of the original OpenAI models had a class embedding, it would help to give the model class information (just replace the nn.embed with a nn.linear), but after testing I found that it makes almost zero difference (a minimal sketch of the swap is below).
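For reference, here's what that swap might look like. The dimensions and variable names are made up for illustration; in guided-diffusion the class embedding is added to the timestep embedding, and the same would apply to the linear projection:

```python
import torch
import torch.nn as nn

num_classes, time_embed_dim, cond_dim = 1000, 512, 512

# original class-conditional path: discrete label id -> embedding vector
label_emb = nn.Embedding(num_classes, time_embed_dim)
y = torch.randint(0, num_classes, (4,))
cond = label_emb(y)                       # [4, 512], added to the timestep embedding

# swapped-in path: continuous conditioning vector -> same embedding space
label_emb = nn.Linear(cond_dim, time_embed_dim)
z = torch.randn(4, cond_dim)              # e.g. pooled VQGAN/CLIP features
cond = label_emb(z)                       # [4, 512], used the same way as before
```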
So after settling on the current approach, I tried different image embeddings to see how it would affect the diffusion model:
I was curious to see what would happen if the latents contained not reconstruction information but the semantic content of the image.
Here are some results from my early experiments:
A person with glasses is a particularly difficult case, so I've been using this image from Unsplash for testing.
I tested on a lot more images, but overall I think the Gumbel f8 embeddings work better for reconstruction. The classifier approach is more like a complete re-interpretation and looks very GAN-like. More training might help but I imagine it would end up similar to regular clip-guided-diffusion.
I also tried training the embeddings together with the diffusion model, from scratch. The codebook collapsed and didn't move much - I think there's a conceptual issue with this, since the diffusion model is training for denoising and we want the embeddings to train for reconstruction, which are not totally aligned goals.
Some generations from the current model:
So I find it interesting that the initial noise matters a lot more than CLIP guidance. I haven't been tracking FID scores or anything, so it's very possible that the model is still under-trained.
It's a bit of a wash, but I think the 256x256 model does a bit better in the mouth and nose areas.
There's one experiment that I really wanted to try but couldn't get working: the idea is to train a VQGAN model on an edge-detector version of the image (a Sobel filter would work, I think), then feed these edge embeddings to the DDPM. This way the lighting and colors would be entirely generated by the diffusion model, and the transformer would only be responsible for the most salient features of the image. With only edges, the codebook could be a few hundred entries instead of 8192, possibly enabling 256x24x24 latents, which would help a lot with the global structure.
I did try this but couldn't get the VAE to converge (either the DALLE-pytorch DVAE or the VQGAN). There's some kind of issue when the image is predominantly black.
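For anyone who wants to poke at this, here's a minimal sketch of the Sobel preprocessing described above; the VQGAN would then be trained on these edge maps instead of RGB images. The kernels are the standard 3x3 Sobel kernels, everything else is illustrative:

```python
import torch
import torch.nn.functional as F

def sobel_edges(img):
    """img: [B, 3, H, W] in [0, 1]; returns a single-channel edge-magnitude map."""
    gray = img.mean(dim=1, keepdim=True)
    kx = torch.tensor([[-1., 0., 1.],
                       [-2., 0., 2.],
                       [-1., 0., 1.]], device=img.device).view(1, 1, 3, 3)
    ky = kx.transpose(2, 3)                  # transpose of the x kernel
    gx = F.conv2d(gray, kx, padding=1)
    gy = F.conv2d(gray, ky, padding=1)
    return torch.sqrt(gx ** 2 + gy ** 2 + 1e-8)
```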
Anyway, try it out and let me know what you think. I'd be interested in any suggestions for other vector-quantized image embeddings we could use. I'm still training the 256x256 model, but it's going pretty slowly on 4x3090s.