The bad generations of generate.py. #120

Closed
ylsung opened this issue Mar 24, 2021 · 10 comments · Fixed by #129

Comments

@ylsung
Contributor

ylsung commented Mar 24, 2021

Thanks for the repo.

I trained DALLE on the Visual Genome dataset. One of the generations produced during training is shown below:

[2 images]

But when I generate images with generate.py, the results are nonsense, even though I use text that also appeared during training.

The scripts I use

# separate the punctuation marks from the words with spaces
python generate.py --dalle_path ./dalle.pt --text "tire on bus . window on bus . window on bus . window on bus . window on bus . pole in grass . window on bus ."

and

# don't separate the punctuation marks from the words
python generate.py --dalle_path ./dalle.pt --text "tire on bus. window on bus. window on bus. window on bus. window on bus. pole in grass. window on bus."

The results of both scripts are similar, and the generations are

[3 images]

I have checked that the model weights are loaded correctly. Any thoughts on this issue?

@robvanvolt
Contributor

The training dataset has to be drastically larger to get decent results - have a look at afiaka87's results a few issues below: #86 (comment)

@ylsung
Contributor Author

ylsung commented Mar 25, 2021

I see. Maybe I didn't describe the issue clearly. My problem is the mismatch between the images generated during training and those produced by generate.py. As you can see, the two outputs look very different even though I use the same input text. Moreover, no matter what input text I use, the generated images look similar and contain lots of black blocks.

@afiaka87
Contributor

@louis2889184

I have noticed similar behavior: reconstructions (even ones directly from the training set) tend to be quite blurry and abstract. I'm out of my depth on that, but I assume it is because the transformer is forced to recreate the phrase verbatim while the image output is allowed to deviate a bit. Hopefully someone more knowledgeable than I am can chime in on why this happens.

As for the "mismatch" between images generated:

a.) It won't be deterministic. Run generation many times and see if you can find one or two that are similar.
b.) If your model hasn't generalized (and it needs a kind of absurd amount of data to do so), then even the slightest deviation in your phrasing from what is in the training set will give you totally different output. Is the phrase identical?
c.) Most importantly: what are your parameters? Learning rate, batch_size, depth, heads? These numbers dramatically affect the output.

@ylsung
Contributor Author

ylsung commented Mar 26, 2021

@afiaka87 Thanks for your help. I have three different texts for each image and randomly pick one when loading the data. I'll try making the image-text pairing one-to-one and see if the mismatch goes away.

@afiaka87
Contributor

@lucidrains I haven't yet reproduced this exact issue, but I have discussed it extensively on the Discord, and I'm having trouble understanding why @louis2889184's generations are decidedly worse than the reconstructions from training even though they're providing the exact same text. Can you clarify why that might happen? To be clear - I've only used the generate code a few times.

@afiaka87
Contributor

@louis2889184 let me know here once you have tried re-running with lower depth/heads and reversible turned off.

@ylsung
Contributor Author

ylsung commented Mar 27, 2021

@afiaka87 It turns out the cause is that train_dalle.py and generate.py save images in different ways.

In train_dalle.py we use wandb.Image to process the image, and it will automatically normalize and scale the image.
https://github.com/wandb/client/blob/9cc04578ebc6d593450e9dbbcae07452bf7bec35/wandb/sdk/data_types.py#L1676-L1679

However, in generate.py we use torchvision.utils.save_image, which won't normalize the image to (0, 1) unless we pass the argument normalize=True. The VAE's output range is roughly -1 to 1, so if we don't normalize the image, save_image converts the float values directly to uint8 with

ndarr = grid.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to('cpu', torch.uint8).numpy()

Hence, every pixel whose original value is below 0 is clamped to 0. That's why a large part of the image is black.
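To make the arithmetic concrete, here is a minimal pure-Python sketch of the conversion above applied to a few sample values in the -1 to 1 range, with and without min-max rescaling first. The helper names (to_uint8, min_max_normalize) and the sample values are hypothetical and chosen for illustration; the rescaling is an approximation of what normalize=True does, not the library's actual code.

```python
def to_uint8(x):
    """Mimic mul(255).add(0.5).clamp(0, 255) followed by a truncating cast."""
    return int(min(max(x * 255 + 0.5, 0), 255))


def min_max_normalize(values):
    """Rescale values to [0, 1] using their own min and max (assumed behavior)."""
    lo, hi = min(values), max(values)
    return [(v - lo) / max(hi - lo, 1e-5) for v in values]


# Sample "pixel" values spanning the rough VAE output range of -1 to 1.
pixels = [-1.0, -0.5, 0.0, 0.5, 1.0]

# Without normalization, everything at or below 0 collapses to black.
raw = [to_uint8(p) for p in pixels]

# With normalization first, the full range is preserved.
normalized = [to_uint8(p) for p in min_max_normalize(pixels)]

print(raw)         # [0, 0, 0, 128, 255]
print(normalized)  # [0, 64, 128, 191, 255]
```

In this sample, three of the five values end up as 0 without rescaling, which matches the large black regions in the saved images.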

Here are some results. The config I use is

EPOCHS = 20
BATCH_SIZE = 8
LEARNING_RATE = 3e-4
GRAD_CLIP_NORM = 0.5

MODEL_DIM = 256 # 512
TEXT_SEQ_LEN = 64 # 256
DEPTH = 32
HEADS = 16
DIM_HEAD = 64
REVERSIBLE = False
ATTN_TYPES = None

And the outputs of generate.py, given text "frame on wall. lamp by bed. wall on building.", are
[images 0 to 4]

After I add normalize=True in save_image(image, outputs_dir / f'{i}.jpg', normalize=True), the outputs are
[images 0 to 4]

Looks much better.

BTW, the mask also seems important for generation: output = dalle.generate_images(text_chunk, mask = mask, filter_thres = args.top_k). The results above were generated with the mask passed in, but I haven't dug into this much.


Edit
Generations without masks (original implementation):
[images 0 to 4]

The results are quite weird, and I cannot match the images to the text at all.

Generations with masks:
[images 0 to 4]

We can see some beds and lamps in the images, so the quality is higher than without masks. It seems that the pad tokens influence the results a lot, so we need to use a mask to exclude them.
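As a rough illustration of what such a mask looks like, here is a pure-Python sketch that pads a token list to a fixed length and marks the real tokens. It assumes padding positions hold token id 0; PAD_ID, pad_to_length, build_mask, and the token ids are all hypothetical names and values, not taken from the repo. With tensors, the equivalent would presumably be an elementwise comparison of the text tensor against the pad id.

```python
PAD_ID = 0  # assumption: padding positions hold token id 0


def pad_to_length(token_ids, seq_len, pad_id=PAD_ID):
    """Truncate or right-pad a list of token ids to exactly seq_len entries."""
    clipped = token_ids[:seq_len]
    return clipped + [pad_id] * (seq_len - len(clipped))


def build_mask(padded_ids, pad_id=PAD_ID):
    """True for real tokens, False for padding positions."""
    return [tok != pad_id for tok in padded_ids]


tokens = [17, 42, 8, 99]  # made-up token ids for a short caption
padded = pad_to_length(tokens, seq_len=8)
mask = build_mask(padded)

print(padded)  # [17, 42, 8, 99, 0, 0, 0, 0]
print(mask)    # [True, True, True, True, False, False, False, False]
```

Passing a mask like this lets the model ignore the trailing padding instead of treating it as part of the caption.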

@afiaka87
Contributor

This is great work and an obvious opportunity to submit a pull request if you'd like. @louis2889184

@afiaka87
Contributor

A simple inline edit with the GitHub UI should make this a fairly easy fix.

@ylsung
Contributor Author

ylsung commented Mar 28, 2021

The PR is made! Thanks, @afiaka87.
