
Testing code #3

Open · ozzyou opened this issue Mar 30, 2021 · 14 comments

Comments

@ozzyou

ozzyou commented Mar 30, 2021

Hi, thank you very much for the really nice implementation! I have trained the model for 100 epochs and the evaluation results look nice. I was wondering if there's also testing code available. I implemented my own, but I get results such as the image below.

Thank you very much in advance for your reply.

[image: 123]

@ozzyou
Author

ozzyou commented Mar 30, 2021

See my implementation below

import torch
from PIL import Image
from torchvision import transforms
from torchvision import utils as vutils

from slot_attention.data import CLEVRDataModule
from slot_attention.method import SlotAttentionMethod
from slot_attention.model import SlotAttentionModel
from slot_attention.params import SlotAttentionParams
from slot_attention.utils import rescale, to_rgb_from_tensor

# Rebuild the model exactly as it was configured for training.
params = SlotAttentionParams()
model = SlotAttentionModel(
    resolution=params.resolution,
    num_slots=params.num_slots,
    num_iterations=params.num_iterations,
    empty_cache=params.empty_cache,
)
clevr_transforms = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Lambda(rescale),  # rescale between -1 and 1
        transforms.Resize(params.resolution),
    ]
)

clevr_datamodule = CLEVRDataModule(
    data_root=params.data_root,
    max_n_objects=params.num_slots - 1,
    train_batch_size=params.batch_size,
    val_batch_size=params.val_batch_size,
    clevr_transforms=clevr_transforms,
    num_train_images=params.num_train_images,
    num_val_images=params.num_val_images,
    num_workers=params.num_workers,
)

# Wrap the model and load the trained weights from the checkpoint.
root = "/home/ozzy/Projects/slot_attention/"
model = SlotAttentionMethod(model=model, datamodule=clevr_datamodule, params=params)
model.load_state_dict(torch.load(root + "wandb/offline-run-1/files/slot-attention-clevr6/3cy530ay/checkpoints/epoch=99-step=27298.ckpt"), strict=False)
model.eval()

# Load and preprocess a single test image.
img = Image.open(root + "data/CLEVR_v1.0/images/test/CLEVR_test_014999.png")
img = img.convert("RGB")
img = clevr_transforms(img).unsqueeze(0)

# Forward pass, then build a grid: original image, combined reconstruction,
# and each slot composited over a white background.
recon_combined, recons, masks, slots = model(img)
out = to_rgb_from_tensor(
    torch.cat(
        [
            img.unsqueeze(1),  # original images
            recon_combined.unsqueeze(1),  # reconstructions
            recons * masks + (1 - masks),  # each slot
        ],
        dim=1,
    )
)
print("RECON SHAPE", out.shape)

batch_size, num_slots, C, H, W = recons.shape
images = vutils.make_grid(
    out.view(batch_size * out.shape[1], C, H, W).cpu(), normalize=False, nrow=out.shape[1],
)

transforms.ToPILImage()(images).save("123.jpg")
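
An alternative loading path, assuming SlotAttentionMethod is a standard LightningModule (not verified against this repo), is to let Lightning restore the checkpoint itself; the keyword arguments below mirror the constructor call above:

# Hedged sketch: Lightning's load_from_checkpoint unwraps the checkpoint's
# "state_dict" key internally. Assumes SlotAttentionMethod is a
# LightningModule and that these kwargs match its __init__.
method = SlotAttentionMethod.load_from_checkpoint(
    root + "wandb/offline-run-1/files/slot-attention-clevr6/3cy530ay/checkpoints/epoch=99-step=27298.ckpt",
    model=SlotAttentionModel(
        resolution=params.resolution,
        num_slots=params.num_slots,
        num_iterations=params.num_iterations,
        empty_cache=params.empty_cache,
    ),
    datamodule=clevr_datamodule,
    params=params,
)
method.eval()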

@brydenfogelman
Contributor

Hmm, not sure I quite understand what you're asking for here ... do you just want some code to test the model?

Also, the first image you linked: is that the result from the model, or from the test code you pasted?

@ozzyou
Author

ozzyou commented Mar 31, 2021

Hi, thanks for your reply. Yes, that would be great. The image is the result of the test code.

@liuyvchi

> Hi, thanks for your reply. Yes, that would be great. The image is the result of the test code.

Hi, has your issue been resolved? I'm running into the same issue.
[image]

@greeneggsandyaml

greeneggsandyaml commented May 30, 2021

I also find that training diverges -- I get similar-looking results (i.e. terrible results) after 100 epochs.

In other words, we are saying that we cannot get the model to train properly with this code. @brydenfogelman have you been able to successfully train a model with this code?

@greeneggsandyaml

Hello, I just wanted to follow up on this issue.

@brydenfogelman
Contributor

> I also find that training diverges -- I get similar-looking results (i.e. terrible results) after 100 epochs.
>
> In other words, we are saying that we cannot get the model to train properly with this code. @brydenfogelman have you been able to successfully train a model with this code?

Hi! I was able to successfully train the model ... the resulting image in the README was from this model. I can try rerunning the model and seeing if I can replicate the issue you all are having.

I may have also introduced a bug in 2fdd396 by switching the LR scheduler to match the paper. I'll test this over the weekend.

In the meantime, @greeneggsandyaml, could you try reverting the model back to the exponential LR scheduler and seeing if that works?
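
For anyone trying this, here is a rough sketch of the two schedules being compared. The hyperparameter values are the ones reported in the paper and are illustrative here, not necessarily what this repo uses; pick one scheduler per optimizer:

import torch

# Illustrative sketch only; `net` stands in for the model being trained.
optimizer = torch.optim.Adam(net.parameters(), lr=4e-4)

# Option A: plain exponential decay (the pre-2fdd396 style; gamma is a guess).
exp_scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.999)

# Option B: paper-style linear warmup followed by exponential decay,
# stepped once per training step rather than per epoch.
def warmup_then_decay(step, warmup_steps=10_000, decay_steps=100_000, decay_rate=0.5):
    warmup = min(1.0, step / warmup_steps)
    return warmup * decay_rate ** (step / decay_steps)

paper_scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=warmup_then_decay)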

@greeneggsandyaml

greeneggsandyaml commented Jun 3, 2021 via email

@greeneggsandyaml

As promised, here is an update. I ran the code from commit 603787ddebde0e19ff9419e6a4e4311ce362956d and everything worked well!

For those who are interested, here is my Weights and Biases log: https://wandb.ai/lukemelas2/public-experiments/runs/td2j9zcn?workspace=user-

Overall, this is great to see. I'll be doing more investigation into this as well.

@greeneggsandyaml

Hello, I'm back with another update. Also, @brydenfogelman, did you manage to run the code again?

I'm finding that sometimes I get results that look good:
[image: looks-good]

and sometimes I get results that look bad:
[image: looks-bad]

Have you seen these sorts of "splotchy" results before? Is it just due to random initialization? It feels to me like there is too much variation for it to be caused solely by random initialization.

@brydenfogelman
Contributor

@greeneggsandyaml How long did you train for?

I think even the original authors found that results can vary with the network and slot initialization. I think this figure from the paper demonstrates this finding.
[image: figure from the paper]

(Looking at this caption again, I also realized that I didn't increase the number of slot iterations at test time; increasing this would probably make the results look better.)

Here's an image of one of my earlier experiments where it did randomly learn to separate the background image.
[image: earlier experiment]

They also trained their model for significantly longer than I trained it here (5 days wall clock time).

My best guess is that increasing the number of slot iterations at test time will improve the visualizations. What are your thoughts @greeneggsandyaml?
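
To make that concrete, here is a minimal sketch of raising the iteration count at inference time. The attribute path below is an assumption about this codebase, so check where num_iterations actually lives:

# Hedged sketch: the paper trains with 3 slot-attention iterations and
# notes that running more iterations at test time can sharpen results.
# `model.model.num_iterations` is an assumed attribute path.
model.eval()
model.model.num_iterations = 5  # e.g. trained with params.num_iterations = 3
with torch.no_grad():
    recon_combined, recons, masks, slots = model(img)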

@ZiwenZhuang

> Hello, I'm back with another update. Also, @brydenfogelman, did you manage to run the code again?
>
> I'm finding that sometimes I get results that look good:
> [image: looks-good]
>
> and sometimes I get results that look bad:
> [image: looks-bad]
>
> Have you seen these sorts of "splotchy" results before? Is it just due to random initialization? It feels to me like there is too much variation for it to be caused solely by random initialization.

Hi, how did you solve the problem?

I used the test code from above but got grayscale output.
[image: 123]

I checked the images in the wandb log files, and they look acceptable.
[image: images_13699_0]

Is there anything wrong with the testing code above? How should I change it?

Thank you,

@ZiwenZhuang

Hi, this is a follow-up: I know where the test code above might go wrong.

Depending on how you save your checkpoint file, the loaded state_dict might not match the model's state_dict, and using strict=False can then silently fail to load any weights into the model.

You should check the state_dict keys while loading the model.

:)
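
For anyone hitting this, a minimal sketch of that check, assuming a standard Lightning .ckpt (which nests the weights under a "state_dict" key):

import torch

ckpt = torch.load("path/to/checkpoint.ckpt", map_location="cpu")
# Lightning wraps the weights; unwrap them if the key is present.
state_dict = ckpt.get("state_dict", ckpt)

result = model.load_state_dict(state_dict, strict=False)
print("missing keys:", result.missing_keys)
print("unexpected keys:", result.unexpected_keys)
# If unexpected_keys lists every checkpoint key, nothing was actually
# loaded, which is exactly the silent failure strict=False allows.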

@brydenfogelman
Contributor

@ZiwenZhuang As mentioned above, the issue is that the LR scheduler change ended up breaking subsequent runs. I'll try to push the fixes to the repo later today.
