Add include_fc and use_combined_linear argument in the SABlock #7996

Merged: 36 commits into Project-MONAI:dev from proj-atten on Aug 9, 2024

Conversation

@KumoLiu (Contributor) commented on Aug 6, 2024

Fixes #7991
Fixes #7992

Description

Add include_fc and use_combined_linear arguments to the SABlock.
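
A minimal usage sketch of the new options (this snippet is illustrative, not taken from the PR; exact defaults and behaviour are defined by the implementation — as I read the linked issues, include_fc toggles the final linear output projection and use_combined_linear switches between one combined qkv projection and separate q/k/v layers):

import torch

from monai.networks.blocks.selfattention import SABlock

# build the block without the trailing linear projection and with separate q/k/v layers
block = SABlock(
    hidden_size=360,
    num_heads=4,
    dropout_rate=0.0,
    include_fc=False,
    use_combined_linear=False,
)

x = torch.randn(2, 512, 360)  # (batch, sequence length, hidden_size)
y = block(x)
print(y.shape)  # expected to remain (2, 512, 360)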

Types of changes

  • Non-breaking change (fix or new feature that would not break existing functionality).
  • Breaking change (fix or new feature that would cause existing functionality to change).
  • New tests added to cover the changes.
  • Integration tests passed locally by running ./runtests.sh -f -u --net --coverage.
  • Quick tests passed locally by running ./runtests.sh --quick --unittests --disttests.
  • In-line docstrings updated.
  • Documentation updated, tested make html command in the docs/ folder.

@KumoLiu (Contributor, Author) commented on Aug 6, 2024

RelPosEmbedding.DECOMPOSED cannot be converted to TorchScript; I'll leave it as-is in this PR. The error is shown below.

'__torch__.torch.nn.modules.container.ParameterList (of Python compilation unit at: 0x55c00fc6db20)' object has no attribute or method '__len__'. Did you forget to initialize an attribute in __init__()?:
  File "/workspace/Code/MONAI/monai/networks/blocks/attention_utils.py", line 98
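
For context, this class of TorchScript failure can be reproduced with a minimal sketch (not code from this PR; the module name is hypothetical): calling len() on an nn.ParameterList inside a scripted forward triggers the same missing __len__ complaint on affected PyTorch versions.

import torch
import torch.nn as nn


class DecomposedRelPosSketch(nn.Module):  # hypothetical stand-in, not the MONAI block
    def __init__(self, num_tables: int = 2, dim: int = 8):
        super().__init__()
        self.rel_pos_tables = nn.ParameterList(
            [nn.Parameter(torch.zeros(16, dim)) for _ in range(num_tables)]
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # len()/integer indexing on the ParameterList is what TorchScript rejects
        for i in range(len(self.rel_pos_tables)):
            x = x + self.rel_pos_tables[i].sum()
        return x


try:
    torch.jit.script(DecomposedRelPosSketch())
    print("scripted successfully on this PyTorch version")
except Exception as err:  # scripting fails on affected PyTorch versions
    print("TorchScript conversion failed:", type(err).__name__)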

@KumoLiu changed the title from "fix #7991" to "Add include_fc argument in the SABlock" on Aug 6, 2024
@KumoLiu changed the title from "Add include_fc argument in the SABlock" to "Add include_fc and use_combined_linear argument in the SABlock" on Aug 6, 2024
@ericspod mentioned this pull request on Aug 6, 2024
@KumoLiu (Contributor, Author) commented on Aug 6, 2024

Remaining parts to be addressed:
  • #7977 (comment) -- addressed.
  • #7977 (comment) -- addressed.
  • #7977 (comment) -- addressed.

@KumoLiu marked this pull request as draft on August 6, 2024 16:01
KumoLiu added 2 commits August 8, 2024 01:05
Signed-off-by: YunLiu <[email protected]>
Signed-off-by: YunLiu <[email protected]>
@KumoLiu (Contributor, Author) commented on Aug 8, 2024

/build

KumoLiu added 2 commits August 8, 2024 11:59
Signed-off-by: YunLiu <[email protected]>
Signed-off-by: YunLiu <[email protected]>
@KumoLiu mentioned this pull request on Aug 8, 2024
@ericspod (Member) commented on Aug 8, 2024

For reference, I attempted to test the similarity between the SABlock in core and the one in GenerativeModels, and didn't have any luck narrowing down where the differences were. I have the following at the end of test_selfattention.py; it defines some failing tests that should pass if I haven't missed anything:

# imports assumed to already be present at the top of test_selfattention.py,
# listed here so the snippet is self-contained
import unittest
from unittest import skipUnless

import torch
from parameterized import parameterized

from monai.networks import eval_mode
from monai.networks.blocks.selfattention import SABlock
from monai.utils import optional_import, set_determinism

from tests.utils import SkipIfBeforePyTorchVersion, assert_allclose

einops, has_einops = optional_import("einops")
generative, has_generative = optional_import("generative")
xops, has_xformers = optional_import("xformers.ops")


class TestComparison(unittest.TestCase):
    @parameterized.expand([["cuda:0", True], ["cuda:0", False], ["cpu", False]])
    @skipUnless(has_einops, "Requires einops")
    @skipUnless(has_generative, "Requires generative")
    @skipUnless(has_xformers, "Requires xformers")
    @SkipIfBeforePyTorchVersion((2, 0))
    def test_generative_vs_core(self, _device, use_flash_attention):
        device = torch.device(_device)
        input_shape = (2, 512, 360)
        input_param = {
            "hidden_size": 360,
            "num_heads": 4,
            "dropout_rate": 0,
            "use_flash_attention": use_flash_attention,
        }

        set_determinism(0)
        net_gen = generative.networks.blocks.SABlock(**input_param).to(device)
        set_determinism(0)
        net_monai = SABlock(**input_param).to(device)
        set_determinism(0)
        input = torch.randn(input_shape).to(device)

        net_monai.load_state_dict(net_gen.state_dict())

        with eval_mode(net_gen, net_monai):
            set_determinism(0)
            r1 = net_gen(input)
            set_determinism(0)
            r2 = net_monai(input)
            assert_allclose(r1.detach().cpu().numpy(), r2.detach().cpu().numpy())

I'm not suggesting we add this test, but we should come back to working out where the differences are coming from.

@ericspod (Member) commented on Aug 8, 2024

Hi @KumoLiu, I had a few comments about the tests that need to be addressed so they exercise the actual cases you want. I'm otherwise good with things.

@guopengf (Contributor) commented on Aug 9, 2024

Hi @KumoLiu, shall we also add those parameters (use_flash_attention, include_fc and use_combined_linear) to monai.networks.nets.controlnet and monai.networks.nets.autoencoderkl? Those two networks also use attention layers.
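
A hypothetical sketch of that plumbing (the wrapper class below is illustrative only, not the actual controlnet/autoencoderkl API): the network constructor simply forwards the attention options down to every attention block it builds.

import torch
import torch.nn as nn

from monai.networks.blocks.selfattention import SABlock


class TinyAttnNet(nn.Module):  # hypothetical wrapper, not a MONAI network
    def __init__(self, hidden_size: int, num_heads: int, include_fc: bool = True,
                 use_combined_linear: bool = False, use_flash_attention: bool = False):
        super().__init__()
        # pass the new options through to each attention layer the network creates
        self.attn = SABlock(
            hidden_size=hidden_size,
            num_heads=num_heads,
            include_fc=include_fc,
            use_combined_linear=use_combined_linear,
            use_flash_attention=use_flash_attention,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.attn(x)


net = TinyAttnNet(hidden_size=64, num_heads=4, include_fc=False, use_combined_linear=False)
out = net(torch.randn(1, 16, 64))  # (batch, sequence length, hidden_size)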

KumoLiu and others added 3 commits August 9, 2024 11:26
Co-authored-by: Eric Kerfoot <[email protected]>
Signed-off-by: YunLiu <[email protected]>
Co-authored-by: Eric Kerfoot <[email protected]>
Signed-off-by: YunLiu <[email protected]>
Signed-off-by: YunLiu <[email protected]>
@KumoLiu (Contributor, Author) commented on Aug 9, 2024

For reference, I had attempted to test the similarity between the SABlock in core versus that in GenerativeModels and didn't have any luck narrowing down where the differences were.

I actually tested this locally and confirmed that they return the same result. The issue with your script is likely due to differences in the initial weights; to address this, I copied and reused the state dict. I didn't include this test in this PR because we want to avoid introducing a dependency on generative.

For your reference, here is the test code I used:

import torch
from copy import deepcopy
from generative.networks.nets.diffusion_model_unet import CrossAttention, AttentionBlock
from monai.utils import set_determinism
from monai.networks.blocks import CrossAttentionBlock, SpatialAttentionBlock

set_determinism(0)
use_flash_attention = True

# cross attention: build the GenerativeModels block and the core block with matching sizes
cross_block_gen = CrossAttention(query_dim=512, num_attention_heads=16, num_head_channels=32, use_flash_attention=use_flash_attention).to('cuda')
cross_block_core = CrossAttentionBlock(hidden_size=512, num_heads=16, dim_head=32, use_flash_attention=use_flash_attention).to('cuda')
# copy the generative weights, renaming the output projection (to_out.0 -> out_proj)
state_dict_gen_cross = cross_block_gen.state_dict()
state_dict_core_cross = deepcopy(state_dict_gen_cross)
state_dict_core_cross["out_proj.weight"] = state_dict_gen_cross["to_out.0.weight"]
state_dict_core_cross["out_proj.bias"] = state_dict_gen_cross["to_out.0.bias"]
del state_dict_core_cross["to_out.0.weight"]
del state_dict_core_cross["to_out.0.bias"]

# spatial attention: the core block keeps its attention weights under an "attn." prefix
atten_block_gen = AttentionBlock(spatial_dims=2, num_channels=512, num_head_channels=32, use_flash_attention=use_flash_attention).to('cuda')
atten_block_core = SpatialAttentionBlock(spatial_dims=2, num_channels=512, num_head_channels=32, use_combined_linear=False, include_fc=False, use_flash_attention=use_flash_attention).to('cuda')
state_dict_gen_atten = atten_block_gen.state_dict()
state_dict_core_atten = atten_block_core.state_dict()
for key in state_dict_core_atten.keys():
    if 'attn' in key and "proj" not in key:
        # strip the "attn." prefix to find the matching generative key
        state_dict_core_atten[key] = state_dict_gen_atten[key[5:]]
    if key in state_dict_gen_atten:
        # keys shared by both blocks are copied directly
        state_dict_core_atten[key] = state_dict_gen_atten[key]
# the generative block's proj_attn maps onto the core block's attn.out_proj
state_dict_core_atten["attn.out_proj.weight"] = state_dict_gen_atten["proj_attn.weight"]
state_dict_core_atten["attn.out_proj.bias"] = state_dict_gen_atten["proj_attn.bias"]


input_cross = torch.rand(1, 256, 512).to("cuda")
input_cross2 = deepcopy(input_cross)
cross_block_gen.load_state_dict(state_dict_gen_cross)

out_gen = cross_block_gen(input_cross)
cross_block_core.load_state_dict(state_dict_core_cross)
out_core = cross_block_core(input_cross2)

print('diff in cross', (out_gen-out_core).abs().sum())

input_atten = torch.rand(1, 512, 16, 16).to("cuda")
input_atten2 = deepcopy(input_atten)
atten_block_gen.load_state_dict(state_dict_gen_atten)

out_gen = atten_block_gen(input_atten)
atten_block_core.load_state_dict(state_dict_core_atten)
out_core = atten_block_core(input_atten2)

print('diff in atten', (out_gen-out_core).abs().sum())

Result:

diff in cross tensor(0., device='cuda:0', grad_fn=<SumBackward0>)
diff in atten tensor(0., device='cuda:0', grad_fn=<SumBackward0>)
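
The renaming above follows a simple old-name to new-name pattern; a small generic helper (not part of this PR, just a sketch) that captures it could look like this:

def remap_state_dict(src_state_dict, key_map):
    """Return a copy of ``src_state_dict`` with keys renamed via ``key_map`` (old -> new)."""
    return {key_map.get(key, key): value for key, value in src_state_dict.items()}


# e.g. for the cross-attention case above:
# core_sd = remap_state_dict(
#     state_dict_gen_cross,
#     {"to_out.0.weight": "out_proj.weight", "to_out.0.bias": "out_proj.bias"},
# )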

And for this one:

We may also want to test with and without CUDA; I've noticed some differences on CPU between using flash attention and not.

Yes, we also include these tests. The GPU pipeline has been moved to Blossom, but these cases will still be tested there.

@KumoLiu (Contributor, Author) commented on Aug 9, 2024

Hi @KumoLiu, shall we also add those parameters (use_flash_attention, include_fc and use_combined_linear) to monai.networks.nets.controlnet and monai.networks.nets.autoencoderkl? Those two networks also use attention layers.

Hi @guopengf, added in the latest commit.

@KumoLiu (Contributor, Author) commented on Aug 9, 2024

/build

@KumoLiu merged commit 069519d into Project-MONAI:dev on Aug 9, 2024
28 checks passed
@KumoLiu deleted the proj-atten branch on August 9, 2024 08:00
@KumoLiu added this to the Refactor MAISI [P0 v1.4] milestone on Aug 19, 2024
rcremese pushed a commit to rcremese/MONAI that referenced this pull request Sep 2, 2024
Add include_fc and use_combined_linear argument in the SABlock (Project-MONAI#7996)

Fixes Project-MONAI#7991
Fixes Project-MONAI#7992

Signed-off-by: YunLiu <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Eric Kerfoot <[email protected]>
Development

Successfully merging this pull request may close these issues:

  • Inconsistent usage of the linear layer in the SABlock
  • Unused proj_attn in the attention block

5 participants