diff --git a/monai/networks/blocks/crossattention.py b/monai/networks/blocks/crossattention.py index daa5abdd56..bdecf63168 100644 --- a/monai/networks/blocks/crossattention.py +++ b/monai/networks/blocks/crossattention.py @@ -59,13 +59,12 @@ def __init__( causal (bool, optional): whether to use causal attention. sequence_length (int, optional): if causal is True, it is necessary to specify the sequence length. rel_pos_embedding (str, optional): Add relative positional embeddings to the attention map. For now only - "decomposed" is supported (see https://arxiv.org/abs/2112.01526). 2D and 3D are supported. + "decomposed" is supported (see https://arxiv.org/abs/2112.01526). 2D and 3D are supported. input_size (tuple(spatial_dim), optional): Input resolution for calculating the relative positional - parameter size. + parameter size. attention_dtype: cast attention operations to this dtype. - use_flash_attention: if True, use Pytorch's inbuilt - flash attention for a memory efficient attention mechanism (see - https://pytorch.org/docs/2.2/generated/torch.nn.functional.scaled_dot_product_attention.html). + use_flash_attention: if True, use Pytorch's inbuilt flash attention for a memory efficient attention mechanism + (see https://pytorch.org/docs/2.2/generated/torch.nn.functional.scaled_dot_product_attention.html). """ super().__init__() @@ -109,7 +108,7 @@ def __init__( self.to_v = nn.Linear(self.context_input_size, inner_size, bias=qkv_bias) self.input_rearrange = Rearrange("b h (l d) -> b l h d", l=num_heads) - self.out_rearrange = Rearrange("b h l d -> b l (h d)") + self.out_rearrange = Rearrange("b l h d -> b h (l d)") self.drop_output = nn.Dropout(dropout_rate) self.drop_weights = nn.Dropout(dropout_rate) self.dropout_rate = dropout_rate @@ -152,31 +151,20 @@ def forward(self, x: torch.Tensor, context: Optional[torch.Tensor] = None): # calculate query, key, values for all heads in batch and move head forward to be the batch dim b, t, c = x.size() # batch size, sequence length, embedding dimensionality (hidden_size) - q = self.to_q(x) + q = self.input_rearrange(self.to_q(x)) kv = context if context is not None else x _, kv_t, _ = kv.size() - k = self.to_k(kv) - v = self.to_v(kv) + k = self.input_rearrange(self.to_k(kv)) + v = self.input_rearrange(self.to_v(kv)) if self.attention_dtype is not None: q = q.to(self.attention_dtype) k = k.to(self.attention_dtype) - q = q.view(b, t, self.num_heads, c // self.num_heads).transpose(1, 2) # (b, nh, t, hs) # - k = k.view(b, kv_t, self.num_heads, c // self.num_heads).transpose(1, 2) # (b, nh, kv_t, hs) - v = v.view(b, kv_t, self.num_heads, c // self.num_heads).transpose(1, 2) # (b, nh, kv_t, hs) - if self.use_flash_attention: x = torch.nn.functional.scaled_dot_product_attention( - query=q.transpose(1, 2), - key=k.transpose(1, 2), - value=v.transpose(1, 2), - scale=self.scale, - dropout_p=self.dropout_rate, - is_causal=self.causal, - ).transpose( - 1, 2 - ) # Back to (b, nh, t, hs) + query=q, key=k, value=v, scale=self.scale, dropout_p=self.dropout_rate, is_causal=self.causal + ) else: att_mat = torch.einsum("blxd,blyd->blxy", q, k) * self.scale # apply relative positional embedding if defined @@ -195,6 +183,7 @@ def forward(self, x: torch.Tensor, context: Optional[torch.Tensor] = None): att_mat = self.drop_weights(att_mat) x = torch.einsum("bhxy,bhyd->bhxd", att_mat, v) + x = self.out_rearrange(x) x = self.out_proj(x) x = self.drop_output(x) diff --git a/monai/networks/blocks/selfattention.py b/monai/networks/blocks/selfattention.py index 
124c00acc6..ac96b077bd 100644 --- a/monai/networks/blocks/selfattention.py +++ b/monai/networks/blocks/selfattention.py @@ -11,7 +11,7 @@ from __future__ import annotations -from typing import Optional, Tuple +from typing import Tuple, Union import torch import torch.nn as nn @@ -40,9 +40,11 @@ def __init__( hidden_input_size: int | None = None, causal: bool = False, sequence_length: int | None = None, - rel_pos_embedding: Optional[str] = None, - input_size: Optional[Tuple] = None, - attention_dtype: Optional[torch.dtype] = None, + rel_pos_embedding: str | None = None, + input_size: Tuple | None = None, + attention_dtype: torch.dtype | None = None, + include_fc: bool = True, + use_combined_linear: bool = True, use_flash_attention: bool = False, ) -> None: """ @@ -61,9 +63,10 @@ def __init__( input_size (tuple(spatial_dim), optional): Input resolution for calculating the relative positional parameter size. attention_dtype: cast attention operations to this dtype. - use_flash_attention: if True, use Pytorch's inbuilt - flash attention for a memory efficient attention mechanism (see - https://pytorch.org/docs/2.2/generated/torch.nn.functional.scaled_dot_product_attention.html). + include_fc: whether to include the final linear layer. Default to True. + use_combined_linear: whether to use a single linear layer for qkv projection, default to True. + use_flash_attention: if True, use Pytorch's inbuilt flash attention for a memory efficient attention mechanism + (see https://pytorch.org/docs/2.2/generated/torch.nn.functional.scaled_dot_product_attention.html). """ @@ -105,9 +108,22 @@ def __init__( self.hidden_input_size = hidden_input_size if hidden_input_size else hidden_size self.out_proj = nn.Linear(self.inner_dim, self.hidden_input_size) - self.qkv = nn.Linear(self.hidden_input_size, self.inner_dim * 3, bias=qkv_bias) - self.input_rearrange = Rearrange("b h (qkv l d) -> qkv b l h d", qkv=3, l=num_heads) - self.out_rearrange = Rearrange("b h l d -> b l (h d)") + self.qkv: Union[nn.Linear, nn.Identity] + self.to_q: Union[nn.Linear, nn.Identity] + self.to_k: Union[nn.Linear, nn.Identity] + self.to_v: Union[nn.Linear, nn.Identity] + + if use_combined_linear: + self.qkv = nn.Linear(self.hidden_input_size, self.inner_dim * 3, bias=qkv_bias) + self.to_q = self.to_k = self.to_v = nn.Identity() # add to enable torchscript + self.input_rearrange = Rearrange("b h (qkv l d) -> qkv b l h d", qkv=3, l=num_heads) + else: + self.to_q = nn.Linear(self.hidden_input_size, self.inner_dim, bias=qkv_bias) + self.to_k = nn.Linear(self.hidden_input_size, self.inner_dim, bias=qkv_bias) + self.to_v = nn.Linear(self.hidden_input_size, self.inner_dim, bias=qkv_bias) + self.qkv = nn.Identity() # add to enable torchscript + self.input_rearrange = Rearrange("b h (l d) -> b l h d", l=num_heads) + self.out_rearrange = Rearrange("b l h d -> b h (l d)") self.drop_output = nn.Dropout(dropout_rate) self.drop_weights = nn.Dropout(dropout_rate) self.dropout_rate = dropout_rate @@ -117,6 +133,8 @@ def __init__( self.attention_dtype = attention_dtype self.causal = causal self.sequence_length = sequence_length + self.include_fc = include_fc + self.use_combined_linear = use_combined_linear self.use_flash_attention = use_flash_attention if causal and sequence_length is not None: @@ -144,8 +162,13 @@ def forward(self, x): Return: torch.Tensor: B x (s_dim_1 * ... 
* s_dim_n) x C """ - output = self.input_rearrange(self.qkv(x)) - q, k, v = output[0], output[1], output[2] + if self.use_combined_linear: + output = self.input_rearrange(self.qkv(x)) + q, k, v = output[0], output[1], output[2] + else: + q = self.input_rearrange(self.to_q(x)) + k = self.input_rearrange(self.to_k(x)) + v = self.input_rearrange(self.to_v(x)) if self.attention_dtype is not None: q = q.to(self.attention_dtype) @@ -153,13 +176,8 @@ def forward(self, x): if self.use_flash_attention: x = F.scaled_dot_product_attention( - query=q.transpose(1, 2), - key=k.transpose(1, 2), - value=v.transpose(1, 2), - scale=self.scale, - dropout_p=self.dropout_rate, - is_causal=self.causal, - ).transpose(1, 2) + query=q, key=k, value=v, scale=self.scale, dropout_p=self.dropout_rate, is_causal=self.causal + ) else: att_mat = torch.einsum("blxd,blyd->blxy", q, k) * self.scale @@ -179,7 +197,9 @@ def forward(self, x): att_mat = self.drop_weights(att_mat) x = torch.einsum("bhxy,bhyd->bhxd", att_mat, v) + x = self.out_rearrange(x) - x = self.out_proj(x) + if self.include_fc: + x = self.out_proj(x) x = self.drop_output(x) return x diff --git a/monai/networks/blocks/spatialattention.py b/monai/networks/blocks/spatialattention.py index 1cfafb1585..665442b55e 100644 --- a/monai/networks/blocks/spatialattention.py +++ b/monai/networks/blocks/spatialattention.py @@ -32,8 +32,13 @@ class SpatialAttentionBlock(nn.Module): spatial_dims: number of spatial dimensions, could be 1, 2, or 3. num_channels: number of input channels. Must be divisible by num_head_channels. num_head_channels: number of channels per head. + norm_num_groups: Number of groups for the group norm layer. + norm_eps: Epsilon for the normalization. attention_dtype: cast attention operations to this dtype. - use_flash_attention: if True, use flash attention for a memory efficient attention mechanism. + include_fc: whether to include the final linear layer. Default to True. + use_combined_linear: whether to use a single linear layer for qkv projection, default to False. + use_flash_attention: if True, use Pytorch's inbuilt flash attention for a memory efficient attention mechanism + (see https://pytorch.org/docs/2.2/generated/torch.nn.functional.scaled_dot_product_attention.html). """ @@ -45,6 +50,8 @@ def __init__( norm_num_groups: int = 32, norm_eps: float = 1e-6, attention_dtype: Optional[torch.dtype] = None, + include_fc: bool = True, + use_combined_linear: bool = False, use_flash_attention: bool = False, ) -> None: super().__init__() @@ -60,6 +67,8 @@ def __init__( num_heads=num_heads, qkv_bias=True, attention_dtype=attention_dtype, + include_fc=include_fc, + use_combined_linear=use_combined_linear, use_flash_attention=use_flash_attention, ) diff --git a/monai/networks/blocks/transformerblock.py b/monai/networks/blocks/transformerblock.py index 28d9c563ac..05eb3b07ab 100644 --- a/monai/networks/blocks/transformerblock.py +++ b/monai/networks/blocks/transformerblock.py @@ -37,6 +37,8 @@ def __init__( sequence_length: int | None = None, with_cross_attention: bool = False, use_flash_attention: bool = False, + include_fc: bool = True, + use_combined_linear: bool = True, ) -> None: """ Args: @@ -47,7 +49,9 @@ def __init__( qkv_bias(bool, optional): apply bias term for the qkv linear layer. Defaults to False. save_attn (bool, optional): to make accessible the attention matrix. Defaults to False. 
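For orientation, the refactor above keeps q, k and v in (batch, heads, sequence, head_dim) layout from the start, which is why the flash-attention path can now call scaled_dot_product_attention without the transposes that were removed. Below is a minimal numerical sketch of that equivalence, assuming a recent PyTorch (>= 2.1 for the keyword scale argument); the shapes and einsum strings mirror the block code above, and dropout is omitted:

import torch
import torch.nn.functional as F

# q, k, v in the layout produced by input_rearrange: (batch, heads, sequence, head_dim)
b, heads, seq, head_dim = 2, 4, 16, 32
q, k, v = (torch.randn(b, heads, seq, head_dim) for _ in range(3))
scale = head_dim ** -0.5

# manual path used when use_flash_attention=False
att_mat = torch.softmax(torch.einsum("blxd,blyd->blxy", q, k) * scale, dim=-1)
out_manual = torch.einsum("bhxy,bhyd->bhxd", att_mat, v)

# SDPA path used when use_flash_attention=True
out_sdpa = F.scaled_dot_product_attention(q, k, v, dropout_p=0.0, is_causal=False, scale=scale)

assert torch.allclose(out_manual, out_sdpa, atol=1e-4)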
use_flash_attention: if True, use Pytorch's inbuilt flash attention for a memory efficient attention mechanism - (see https://pytorch.org/docs/2.2/generated/torch.nn.functional.scaled_dot_product_attention.html). + (see https://pytorch.org/docs/2.2/generated/torch.nn.functional.scaled_dot_product_attention.html). + include_fc: whether to include the final linear layer. Default to True. + use_combined_linear: whether to use a single linear layer for qkv projection, default to True. """ @@ -69,6 +73,8 @@ def __init__( save_attn=save_attn, causal=causal, sequence_length=sequence_length, + include_fc=include_fc, + use_combined_linear=use_combined_linear, use_flash_attention=use_flash_attention, ) self.norm2 = nn.LayerNorm(hidden_size) diff --git a/monai/networks/nets/autoencoderkl.py b/monai/networks/nets/autoencoderkl.py index 35d80e0565..836027796f 100644 --- a/monai/networks/nets/autoencoderkl.py +++ b/monai/networks/nets/autoencoderkl.py @@ -157,6 +157,10 @@ class Encoder(nn.Module): norm_eps: epsilon for the normalization. attention_levels: indicate which level from num_channels contain an attention block. with_nonlocal_attn: if True use non-local attention block. + include_fc: whether to include the final linear layer. Default to True. + use_combined_linear: whether to use a single linear layer for qkv projection, default to False. + use_flash_attention: if True, use Pytorch's inbuilt flash attention for a memory efficient attention mechanism + (see https://pytorch.org/docs/2.2/generated/torch.nn.functional.scaled_dot_product_attention.html). """ def __init__( @@ -170,6 +174,9 @@ def __init__( norm_eps: float, attention_levels: Sequence[bool], with_nonlocal_attn: bool = True, + include_fc: bool = True, + use_combined_linear: bool = False, + use_flash_attention: bool = False, ) -> None: super().__init__() self.spatial_dims = spatial_dims @@ -220,6 +227,9 @@ def __init__( num_channels=input_channel, norm_num_groups=norm_num_groups, norm_eps=norm_eps, + include_fc=include_fc, + use_combined_linear=use_combined_linear, + use_flash_attention=use_flash_attention, ) ) @@ -243,6 +253,9 @@ def __init__( num_channels=channels[-1], norm_num_groups=norm_num_groups, norm_eps=norm_eps, + include_fc=include_fc, + use_combined_linear=use_combined_linear, + use_flash_attention=use_flash_attention, ) ) blocks.append( @@ -291,6 +304,10 @@ class Decoder(nn.Module): attention_levels: indicate which level from num_channels contain an attention block. with_nonlocal_attn: if True use non-local attention block. use_convtranspose: if True, use ConvTranspose to upsample feature maps in decoder. + include_fc: whether to include the final linear layer. Default to True. + use_combined_linear: whether to use a single linear layer for qkv projection, default to False. + use_flash_attention: if True, use Pytorch's inbuilt flash attention for a memory efficient attention mechanism + (see https://pytorch.org/docs/2.2/generated/torch.nn.functional.scaled_dot_product_attention.html). 
""" def __init__( @@ -305,6 +322,9 @@ def __init__( attention_levels: Sequence[bool], with_nonlocal_attn: bool = True, use_convtranspose: bool = False, + include_fc: bool = True, + use_combined_linear: bool = False, + use_flash_attention: bool = False, ) -> None: super().__init__() self.spatial_dims = spatial_dims @@ -350,6 +370,9 @@ def __init__( num_channels=reversed_block_out_channels[0], norm_num_groups=norm_num_groups, norm_eps=norm_eps, + include_fc=include_fc, + use_combined_linear=use_combined_linear, + use_flash_attention=use_flash_attention, ) ) blocks.append( @@ -389,6 +412,9 @@ def __init__( num_channels=block_in_ch, norm_num_groups=norm_num_groups, norm_eps=norm_eps, + include_fc=include_fc, + use_combined_linear=use_combined_linear, + use_flash_attention=use_flash_attention, ) ) @@ -463,6 +489,10 @@ class AutoencoderKL(nn.Module): with_decoder_nonlocal_attn: if True use non-local attention block in the decoder. use_checkpoint: if True, use activation checkpoint to save memory. use_convtranspose: if True, use ConvTranspose to upsample feature maps in decoder. + include_fc: whether to include the final linear layer in the attention block. Default to True. + use_combined_linear: whether to use a single linear layer for qkv projection in the attention block, default to False. + use_flash_attention: if True, use Pytorch's inbuilt flash attention for a memory efficient attention mechanism + (see https://pytorch.org/docs/2.2/generated/torch.nn.functional.scaled_dot_product_attention.html). """ def __init__( @@ -480,6 +510,9 @@ def __init__( with_decoder_nonlocal_attn: bool = True, use_checkpoint: bool = False, use_convtranspose: bool = False, + include_fc: bool = True, + use_combined_linear: bool = False, + use_flash_attention: bool = False, ) -> None: super().__init__() @@ -509,6 +542,9 @@ def __init__( norm_eps=norm_eps, attention_levels=attention_levels, with_nonlocal_attn=with_encoder_nonlocal_attn, + include_fc=include_fc, + use_combined_linear=use_combined_linear, + use_flash_attention=use_flash_attention, ) self.decoder = Decoder( spatial_dims=spatial_dims, @@ -521,6 +557,9 @@ def __init__( attention_levels=attention_levels, with_nonlocal_attn=with_decoder_nonlocal_attn, use_convtranspose=use_convtranspose, + include_fc=include_fc, + use_combined_linear=use_combined_linear, + use_flash_attention=use_flash_attention, ) self.quant_conv_mu = Convolution( spatial_dims=spatial_dims, @@ -665,27 +704,18 @@ def load_old_state_dict(self, old_state_dict: dict, verbose=False) -> None: # copy over all matching keys for k in new_state_dict: if k in old_state_dict: - new_state_dict[k] = old_state_dict[k] + new_state_dict[k] = old_state_dict.pop(k) # fix the attention blocks - attention_blocks = [k.replace(".attn.qkv.weight", "") for k in new_state_dict if "attn.qkv.weight" in k] + attention_blocks = [k.replace(".attn.to_q.weight", "") for k in new_state_dict if "attn.to_q.weight" in k] for block in attention_blocks: - new_state_dict[f"{block}.attn.qkv.weight"] = torch.cat( - [ - old_state_dict[f"{block}.to_q.weight"], - old_state_dict[f"{block}.to_k.weight"], - old_state_dict[f"{block}.to_v.weight"], - ], - dim=0, - ) - new_state_dict[f"{block}.attn.qkv.bias"] = torch.cat( - [ - old_state_dict[f"{block}.to_q.bias"], - old_state_dict[f"{block}.to_k.bias"], - old_state_dict[f"{block}.to_v.bias"], - ], - dim=0, - ) + new_state_dict[f"{block}.attn.to_q.weight"] = old_state_dict.pop(f"{block}.to_q.weight") + new_state_dict[f"{block}.attn.to_k.weight"] = 
old_state_dict.pop(f"{block}.to_k.weight") + new_state_dict[f"{block}.attn.to_v.weight"] = old_state_dict.pop(f"{block}.to_v.weight") + new_state_dict[f"{block}.attn.to_q.bias"] = old_state_dict.pop(f"{block}.to_q.bias") + new_state_dict[f"{block}.attn.to_k.bias"] = old_state_dict.pop(f"{block}.to_k.bias") + new_state_dict[f"{block}.attn.to_v.bias"] = old_state_dict.pop(f"{block}.to_v.bias") + # old version did not have a projection so set these to the identity new_state_dict[f"{block}.attn.out_proj.weight"] = torch.eye( new_state_dict[f"{block}.attn.out_proj.weight"].shape[0] @@ -698,5 +728,8 @@ def load_old_state_dict(self, old_state_dict: dict, verbose=False) -> None: for k in new_state_dict: if "postconv" in k: old_name = k.replace("postconv", "conv") - new_state_dict[k] = old_state_dict[old_name] - self.load_state_dict(new_state_dict) + new_state_dict[k] = old_state_dict.pop(old_name) + if verbose: + # print all remaining keys in old_state_dict + print("remaining keys in old_state_dict:", old_state_dict.keys()) + self.load_state_dict(new_state_dict, strict=True) diff --git a/monai/networks/nets/controlnet.py b/monai/networks/nets/controlnet.py index ed3654733d..8b08eaae10 100644 --- a/monai/networks/nets/controlnet.py +++ b/monai/networks/nets/controlnet.py @@ -143,6 +143,10 @@ class ControlNet(nn.Module): upcast_attention: if True, upcast attention operations to full precision. conditioning_embedding_in_channels: number of input channels for the conditioning embedding. conditioning_embedding_num_channels: number of channels for the blocks in the conditioning embedding. + include_fc: whether to include the final linear layer. Default to True. + use_combined_linear: whether to use a single linear layer for qkv projection, default to True. + use_flash_attention: if True, use Pytorch's inbuilt flash attention for a memory efficient attention mechanism + (see https://pytorch.org/docs/2.2/generated/torch.nn.functional.scaled_dot_product_attention.html). 
""" def __init__( @@ -163,6 +167,9 @@ def __init__( upcast_attention: bool = False, conditioning_embedding_in_channels: int = 1, conditioning_embedding_num_channels: Sequence[int] = (16, 32, 96, 256), + include_fc: bool = True, + use_combined_linear: bool = False, + use_flash_attention: bool = False, ) -> None: super().__init__() if with_conditioning is True and cross_attention_dim is None: @@ -282,6 +289,9 @@ def __init__( transformer_num_layers=transformer_num_layers, cross_attention_dim=cross_attention_dim, upcast_attention=upcast_attention, + include_fc=include_fc, + use_combined_linear=use_combined_linear, + use_flash_attention=use_flash_attention, ) self.down_blocks.append(down_block) @@ -326,6 +336,9 @@ def __init__( transformer_num_layers=transformer_num_layers, cross_attention_dim=cross_attention_dim, upcast_attention=upcast_attention, + include_fc=include_fc, + use_combined_linear=use_combined_linear, + use_flash_attention=use_flash_attention, ) controlnet_block = Convolution( @@ -441,25 +454,16 @@ def load_old_state_dict(self, old_state_dict: dict, verbose=False) -> None: # copy over all matching keys for k in new_state_dict: if k in old_state_dict: - new_state_dict[k] = old_state_dict[k] + new_state_dict[k] = old_state_dict.pop(k) # fix the attention blocks - attention_blocks = [k.replace(".attn1.qkv.weight", "") for k in new_state_dict if "attn1.qkv.weight" in k] + attention_blocks = [k.replace(".out_proj.weight", "") for k in new_state_dict if "out_proj.weight" in k] for block in attention_blocks: - new_state_dict[f"{block}.attn1.qkv.weight"] = torch.cat( - [ - old_state_dict[f"{block}.attn1.to_q.weight"], - old_state_dict[f"{block}.attn1.to_k.weight"], - old_state_dict[f"{block}.attn1.to_v.weight"], - ], - dim=0, - ) - # projection - new_state_dict[f"{block}.attn1.out_proj.weight"] = old_state_dict[f"{block}.attn1.to_out.0.weight"] - new_state_dict[f"{block}.attn1.out_proj.bias"] = old_state_dict[f"{block}.attn1.to_out.0.bias"] - - new_state_dict[f"{block}.attn2.out_proj.weight"] = old_state_dict[f"{block}.attn2.to_out.0.weight"] - new_state_dict[f"{block}.attn2.out_proj.bias"] = old_state_dict[f"{block}.attn2.to_out.0.bias"] + new_state_dict[f"{block}.out_proj.weight"] = old_state_dict.pop(f"{block}.to_out.0.weight") + new_state_dict[f"{block}.out_proj.bias"] = old_state_dict.pop(f"{block}.to_out.0.bias") + if verbose: + # print all remaining keys in old_state_dict + print("remaining keys in old_state_dict:", old_state_dict.keys()) self.load_state_dict(new_state_dict) diff --git a/monai/networks/nets/diffusion_model_unet.py b/monai/networks/nets/diffusion_model_unet.py index a885339d0d..f57fe251d2 100644 --- a/monai/networks/nets/diffusion_model_unet.py +++ b/monai/networks/nets/diffusion_model_unet.py @@ -67,7 +67,9 @@ class DiffusionUNetTransformerBlock(nn.Module): cross_attention_dim: size of the context vector for cross attention. upcast_attention: if True, upcast attention operations to full precision. use_flash_attention: if True, use Pytorch's inbuilt flash attention for a memory efficient attention mechanism - (see https://pytorch.org/docs/2.2/generated/torch.nn.functional.scaled_dot_product_attention.html). + (see https://pytorch.org/docs/2.2/generated/torch.nn.functional.scaled_dot_product_attention.html). + include_fc: whether to include the final linear layer. Default to True. + use_combined_linear: whether to use a single linear layer for qkv projection, default to False. 
""" @@ -80,6 +82,8 @@ def __init__( cross_attention_dim: int | None = None, upcast_attention: bool = False, use_flash_attention: bool = False, + include_fc: bool = True, + use_combined_linear: bool = False, ) -> None: super().__init__() self.attn1 = SABlock( @@ -89,6 +93,8 @@ def __init__( dim_head=num_head_channels, dropout_rate=dropout, attention_dtype=torch.float if upcast_attention else None, + include_fc=include_fc, + use_combined_linear=use_combined_linear, use_flash_attention=use_flash_attention, ) self.ff = MLPBlock(hidden_size=num_channels, mlp_dim=num_channels * 4, act="GEGLU", dropout_rate=dropout) @@ -134,6 +140,11 @@ class SpatialTransformer(nn.Module): norm_eps: epsilon for the normalization. cross_attention_dim: number of context dimensions to use. upcast_attention: if True, upcast attention operations to full precision. + include_fc: whether to include the final linear layer. Default to True. + use_combined_linear: whether to use a single linear layer for qkv projection, default to False. + use_flash_attention: if True, use Pytorch's inbuilt flash attention for a memory efficient attention mechanism + (see https://pytorch.org/docs/2.2/generated/torch.nn.functional.scaled_dot_product_attention.html). + """ def __init__( @@ -148,6 +159,9 @@ def __init__( norm_eps: float = 1e-6, cross_attention_dim: int | None = None, upcast_attention: bool = False, + include_fc: bool = True, + use_combined_linear: bool = False, + use_flash_attention: bool = False, ) -> None: super().__init__() self.spatial_dims = spatial_dims @@ -175,6 +189,9 @@ def __init__( dropout=dropout, cross_attention_dim=cross_attention_dim, upcast_attention=upcast_attention, + include_fc=include_fc, + use_combined_linear=use_combined_linear, + use_flash_attention=use_flash_attention, ) for _ in range(num_layers) ] @@ -529,6 +546,10 @@ class AttnDownBlock(nn.Module): resblock_updown: if True use residual blocks for downsampling. downsample_padding: padding used in the downsampling block. num_head_channels: number of channels in each attention head. + include_fc: whether to include the final linear layer. Default to True. + use_combined_linear: whether to use a single linear layer for qkv projection, default to False. + use_flash_attention: if True, use Pytorch's inbuilt flash attention for a memory efficient attention mechanism + (see https://pytorch.org/docs/2.2/generated/torch.nn.functional.scaled_dot_product_attention.html). """ def __init__( @@ -544,6 +565,9 @@ def __init__( resblock_updown: bool = False, downsample_padding: int = 1, num_head_channels: int = 1, + include_fc: bool = True, + use_combined_linear: bool = False, + use_flash_attention: bool = False, ) -> None: super().__init__() self.resblock_updown = resblock_updown @@ -570,6 +594,9 @@ def __init__( num_head_channels=num_head_channels, norm_num_groups=norm_num_groups, norm_eps=norm_eps, + include_fc=include_fc, + use_combined_linear=use_combined_linear, + use_flash_attention=use_flash_attention, ) ) @@ -636,7 +663,11 @@ class CrossAttnDownBlock(nn.Module): transformer_num_layers: number of layers of Transformer blocks to use. cross_attention_dim: number of context dimensions to use. upcast_attention: if True, upcast attention operations to full precision. - dropout_cattn: if different from zero, this will be the dropout value for the cross-attention layers + dropout_cattn: if different from zero, this will be the dropout value for the cross-attention layers. + include_fc: whether to include the final linear layer. Default to True. 
+ use_combined_linear: whether to use a single linear layer for qkv projection, default to False. + use_flash_attention: if True, use Pytorch's inbuilt flash attention for a memory efficient attention mechanism + (see https://pytorch.org/docs/2.2/generated/torch.nn.functional.scaled_dot_product_attention.html). """ def __init__( @@ -656,6 +687,9 @@ def __init__( cross_attention_dim: int | None = None, upcast_attention: bool = False, dropout_cattn: float = 0.0, + include_fc: bool = True, + use_combined_linear: bool = False, + use_flash_attention: bool = False, ) -> None: super().__init__() self.resblock_updown = resblock_updown @@ -688,6 +722,9 @@ def __init__( cross_attention_dim=cross_attention_dim, upcast_attention=upcast_attention, dropout=dropout_cattn, + include_fc=include_fc, + use_combined_linear=use_combined_linear, + use_flash_attention=use_flash_attention, ) ) @@ -745,6 +782,10 @@ class AttnMidBlock(nn.Module): norm_num_groups: number of groups for the group normalization. norm_eps: epsilon for the group normalization. num_head_channels: number of channels in each attention head. + include_fc: whether to include the final linear layer. Default to True. + use_combined_linear: whether to use a single linear layer for qkv projection, default to False. + use_flash_attention: if True, use Pytorch's inbuilt flash attention for a memory efficient attention mechanism + (see https://pytorch.org/docs/2.2/generated/torch.nn.functional.scaled_dot_product_attention.html). """ def __init__( @@ -755,6 +796,9 @@ def __init__( norm_num_groups: int = 32, norm_eps: float = 1e-6, num_head_channels: int = 1, + include_fc: bool = True, + use_combined_linear: bool = False, + use_flash_attention: bool = False, ) -> None: super().__init__() @@ -772,6 +816,9 @@ def __init__( num_head_channels=num_head_channels, norm_num_groups=norm_num_groups, norm_eps=norm_eps, + include_fc=include_fc, + use_combined_linear=use_combined_linear, + use_flash_attention=use_flash_attention, ) self.resnet_2 = DiffusionUNetResnetBlock( @@ -808,6 +855,10 @@ class CrossAttnMidBlock(nn.Module): transformer_num_layers: number of layers of Transformer blocks to use. cross_attention_dim: number of context dimensions to use. upcast_attention: if True, upcast attention operations to full precision. + include_fc: whether to include the final linear layer. Default to True. + use_combined_linear: whether to use a single linear layer for qkv projection, default to False. + use_flash_attention: if True, use Pytorch's inbuilt flash attention for a memory efficient attention mechanism + (see https://pytorch.org/docs/2.2/generated/torch.nn.functional.scaled_dot_product_attention.html). """ def __init__( @@ -822,6 +873,9 @@ def __init__( cross_attention_dim: int | None = None, upcast_attention: bool = False, dropout_cattn: float = 0.0, + include_fc: bool = True, + use_combined_linear: bool = False, + use_flash_attention: bool = False, ) -> None: super().__init__() @@ -844,6 +898,9 @@ def __init__( cross_attention_dim=cross_attention_dim, upcast_attention=upcast_attention, dropout=dropout_cattn, + include_fc=include_fc, + use_combined_linear=use_combined_linear, + use_flash_attention=use_flash_attention, ) self.resnet_2 = DiffusionUNetResnetBlock( spatial_dims=spatial_dims, @@ -989,6 +1046,10 @@ class AttnUpBlock(nn.Module): add_upsample: if True add downsample block. resblock_updown: if True use residual blocks for upsampling. num_head_channels: number of channels in each attention head. + include_fc: whether to include the final linear layer. 
Default to True. + use_combined_linear: whether to use a single linear layer for qkv projection, default to False. + use_flash_attention: if True, use Pytorch's inbuilt flash attention for a memory efficient attention mechanism + (see https://pytorch.org/docs/2.2/generated/torch.nn.functional.scaled_dot_product_attention.html). """ def __init__( @@ -1004,6 +1065,9 @@ def __init__( add_upsample: bool = True, resblock_updown: bool = False, num_head_channels: int = 1, + include_fc: bool = True, + use_combined_linear: bool = False, + use_flash_attention: bool = False, ) -> None: super().__init__() self.resblock_updown = resblock_updown @@ -1032,6 +1096,9 @@ def __init__( num_head_channels=num_head_channels, norm_num_groups=norm_num_groups, norm_eps=norm_eps, + include_fc=include_fc, + use_combined_linear=use_combined_linear, + use_flash_attention=use_flash_attention, ) ) @@ -1116,7 +1183,11 @@ class CrossAttnUpBlock(nn.Module): transformer_num_layers: number of layers of Transformer blocks to use. cross_attention_dim: number of context dimensions to use. upcast_attention: if True, upcast attention operations to full precision. - dropout_cattn: if different from zero, this will be the dropout value for the cross-attention layers + dropout_cattn: if different from zero, this will be the dropout value for the cross-attention layers. + include_fc: whether to include the final linear layer. Default to True. + use_combined_linear: whether to use a single linear layer for qkv projection, default to False. + use_flash_attention: if True, use Pytorch's inbuilt flash attention for a memory efficient attention mechanism + (see https://pytorch.org/docs/2.2/generated/torch.nn.functional.scaled_dot_product_attention.html). """ def __init__( @@ -1136,6 +1207,9 @@ def __init__( cross_attention_dim: int | None = None, upcast_attention: bool = False, dropout_cattn: float = 0.0, + include_fc: bool = True, + use_combined_linear: bool = False, + use_flash_attention: bool = False, ) -> None: super().__init__() self.resblock_updown = resblock_updown @@ -1169,6 +1243,9 @@ def __init__( cross_attention_dim=cross_attention_dim, upcast_attention=upcast_attention, dropout=dropout_cattn, + include_fc=include_fc, + use_combined_linear=use_combined_linear, + use_flash_attention=use_flash_attention, ) ) @@ -1250,6 +1327,9 @@ def get_down_block( cross_attention_dim: int | None, upcast_attention: bool = False, dropout_cattn: float = 0.0, + include_fc: bool = True, + use_combined_linear: bool = False, + use_flash_attention: bool = False, ) -> nn.Module: if with_attn: return AttnDownBlock( @@ -1263,6 +1343,9 @@ def get_down_block( add_downsample=add_downsample, resblock_updown=resblock_updown, num_head_channels=num_head_channels, + include_fc=include_fc, + use_combined_linear=use_combined_linear, + use_flash_attention=use_flash_attention, ) elif with_cross_attn: return CrossAttnDownBlock( @@ -1280,6 +1363,9 @@ def get_down_block( cross_attention_dim=cross_attention_dim, upcast_attention=upcast_attention, dropout_cattn=dropout_cattn, + include_fc=include_fc, + use_combined_linear=use_combined_linear, + use_flash_attention=use_flash_attention, ) else: return DownBlock( @@ -1307,6 +1393,9 @@ def get_mid_block( cross_attention_dim: int | None, upcast_attention: bool = False, dropout_cattn: float = 0.0, + include_fc: bool = True, + use_combined_linear: bool = False, + use_flash_attention: bool = False, ) -> nn.Module: if with_conditioning: return CrossAttnMidBlock( @@ -1320,6 +1409,9 @@ def get_mid_block( 
cross_attention_dim=cross_attention_dim, upcast_attention=upcast_attention, dropout_cattn=dropout_cattn, + include_fc=include_fc, + use_combined_linear=use_combined_linear, + use_flash_attention=use_flash_attention, ) else: return AttnMidBlock( @@ -1329,6 +1421,9 @@ def get_mid_block( norm_num_groups=norm_num_groups, norm_eps=norm_eps, num_head_channels=num_head_channels, + include_fc=include_fc, + use_combined_linear=use_combined_linear, + use_flash_attention=use_flash_attention, ) @@ -1350,6 +1445,9 @@ def get_up_block( cross_attention_dim: int | None, upcast_attention: bool = False, dropout_cattn: float = 0.0, + include_fc: bool = True, + use_combined_linear: bool = False, + use_flash_attention: bool = False, ) -> nn.Module: if with_attn: return AttnUpBlock( @@ -1364,6 +1462,9 @@ def get_up_block( add_upsample=add_upsample, resblock_updown=resblock_updown, num_head_channels=num_head_channels, + include_fc=include_fc, + use_combined_linear=use_combined_linear, + use_flash_attention=use_flash_attention, ) elif with_cross_attn: return CrossAttnUpBlock( @@ -1382,6 +1483,9 @@ def get_up_block( cross_attention_dim=cross_attention_dim, upcast_attention=upcast_attention, dropout_cattn=dropout_cattn, + include_fc=include_fc, + use_combined_linear=use_combined_linear, + use_flash_attention=use_flash_attention, ) else: return UpBlock( @@ -1419,9 +1523,13 @@ class DiffusionModelUNet(nn.Module): transformer_num_layers: number of layers of Transformer blocks to use. cross_attention_dim: number of context dimensions to use. num_class_embeds: if specified (as an int), then this model will be class-conditional with `num_class_embeds` - classes. + classes. upcast_attention: if True, upcast attention operations to full precision. - dropout_cattn: if different from zero, this will be the dropout value for the cross-attention layers + dropout_cattn: if different from zero, this will be the dropout value for the cross-attention layers. + include_fc: whether to include the final linear layer. Default to True. + use_combined_linear: whether to use a single linear layer for qkv projection, default to True. + use_flash_attention: if True, use Pytorch's inbuilt flash attention for a memory efficient attention mechanism + (see https://pytorch.org/docs/2.2/generated/torch.nn.functional.scaled_dot_product_attention.html). 
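Since the block factories above only forward the three new flags, they surface as plain constructor arguments at the network level. A hedged instantiation sketch follows; only include_fc, use_combined_linear and use_flash_attention come from this diff, while the remaining argument names follow the current DiffusionModelUNet signature and should be checked against your MONAI version:

import torch
from monai.networks.nets.diffusion_model_unet import DiffusionModelUNet

# channels / attention_levels / num_res_blocks / num_head_channels are assumed names
model = DiffusionModelUNet(
    spatial_dims=2,
    in_channels=1,
    out_channels=1,
    channels=(32, 64, 64),
    attention_levels=(False, True, True),
    num_res_blocks=1,
    num_head_channels=16,
    include_fc=True,
    use_combined_linear=False,
    use_flash_attention=True,
)
x = torch.randn(1, 1, 32, 32)
timesteps = torch.randint(0, 1000, (1,)).long()
out = model(x, timesteps=timesteps)  # same spatial shape as the input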
""" def __init__( @@ -1442,6 +1550,9 @@ def __init__( num_class_embeds: int | None = None, upcast_attention: bool = False, dropout_cattn: float = 0.0, + include_fc: bool = True, + use_combined_linear: bool = False, + use_flash_attention: bool = False, ) -> None: super().__init__() if with_conditioning is True and cross_attention_dim is None: @@ -1536,6 +1647,9 @@ def __init__( cross_attention_dim=cross_attention_dim, upcast_attention=upcast_attention, dropout_cattn=dropout_cattn, + include_fc=include_fc, + use_combined_linear=use_combined_linear, + use_flash_attention=use_flash_attention, ) self.down_blocks.append(down_block) @@ -1553,6 +1667,9 @@ def __init__( cross_attention_dim=cross_attention_dim, upcast_attention=upcast_attention, dropout_cattn=dropout_cattn, + include_fc=include_fc, + use_combined_linear=use_combined_linear, + use_flash_attention=use_flash_attention, ) # up @@ -1587,6 +1704,9 @@ def __init__( cross_attention_dim=cross_attention_dim, upcast_attention=upcast_attention, dropout_cattn=dropout_cattn, + include_fc=include_fc, + use_combined_linear=use_combined_linear, + use_flash_attention=use_flash_attention, ) self.up_blocks.append(up_block) @@ -1714,31 +1834,23 @@ def load_old_state_dict(self, old_state_dict: dict, verbose=False) -> None: # copy over all matching keys for k in new_state_dict: if k in old_state_dict: - new_state_dict[k] = old_state_dict[k] + new_state_dict[k] = old_state_dict.pop(k) # fix the attention blocks - attention_blocks = [k.replace(".attn1.qkv.weight", "") for k in new_state_dict if "attn1.qkv.weight" in k] + attention_blocks = [k.replace(".out_proj.weight", "") for k in new_state_dict if "out_proj.weight" in k] for block in attention_blocks: - new_state_dict[f"{block}.attn1.qkv.weight"] = torch.cat( - [ - old_state_dict[f"{block}.attn1.to_q.weight"], - old_state_dict[f"{block}.attn1.to_k.weight"], - old_state_dict[f"{block}.attn1.to_v.weight"], - ], - dim=0, - ) - # projection - new_state_dict[f"{block}.attn1.out_proj.weight"] = old_state_dict[f"{block}.attn1.to_out.0.weight"] - new_state_dict[f"{block}.attn1.out_proj.bias"] = old_state_dict[f"{block}.attn1.to_out.0.bias"] + new_state_dict[f"{block}.out_proj.weight"] = old_state_dict.pop(f"{block}.to_out.0.weight") + new_state_dict[f"{block}.out_proj.bias"] = old_state_dict.pop(f"{block}.to_out.0.bias") - new_state_dict[f"{block}.attn2.out_proj.weight"] = old_state_dict[f"{block}.attn2.to_out.0.weight"] - new_state_dict[f"{block}.attn2.out_proj.bias"] = old_state_dict[f"{block}.attn2.to_out.0.bias"] # fix the upsample conv blocks which were renamed postconv for k in new_state_dict: if "postconv" in k: old_name = k.replace("postconv", "conv") - new_state_dict[k] = old_state_dict[old_name] + new_state_dict[k] = old_state_dict.pop(old_name) + if verbose: + # print all remaining keys in old_state_dict + print("remaining keys in old_state_dict:", old_state_dict.keys()) self.load_state_dict(new_state_dict) @@ -1782,6 +1894,9 @@ def __init__( cross_attention_dim: int | None = None, num_class_embeds: int | None = None, upcast_attention: bool = False, + include_fc: bool = True, + use_combined_linear: bool = False, + use_flash_attention: bool = False, ) -> None: super().__init__() if with_conditioning is True and cross_attention_dim is None: @@ -1866,6 +1981,9 @@ def __init__( transformer_num_layers=transformer_num_layers, cross_attention_dim=cross_attention_dim, upcast_attention=upcast_attention, + include_fc=include_fc, + use_combined_linear=use_combined_linear, + 
use_flash_attention=use_flash_attention, ) self.down_blocks.append(down_block) diff --git a/monai/networks/nets/spade_autoencoderkl.py b/monai/networks/nets/spade_autoencoderkl.py index d5794a9227..cc8909194a 100644 --- a/monai/networks/nets/spade_autoencoderkl.py +++ b/monai/networks/nets/spade_autoencoderkl.py @@ -137,6 +137,10 @@ class SPADEDecoder(nn.Module): label_nc: number of semantic channels for SPADE normalisation. with_nonlocal_attn: if True use non-local attention block. spade_intermediate_channels: number of intermediate channels for SPADE block layer. + include_fc: whether to include the final linear layer. Default to True. + use_combined_linear: whether to use a single linear layer for qkv projection, default to False. + use_flash_attention: if True, use Pytorch's inbuilt flash attention for a memory efficient attention mechanism + (see https://pytorch.org/docs/2.2/generated/torch.nn.functional.scaled_dot_product_attention.html). """ def __init__( @@ -152,6 +156,9 @@ def __init__( label_nc: int, with_nonlocal_attn: bool = True, spade_intermediate_channels: int = 128, + include_fc: bool = True, + use_combined_linear: bool = False, + use_flash_attention: bool = False, ) -> None: super().__init__() self.spatial_dims = spatial_dims @@ -200,6 +207,9 @@ def __init__( num_channels=reversed_block_out_channels[0], norm_num_groups=norm_num_groups, norm_eps=norm_eps, + include_fc=include_fc, + use_combined_linear=use_combined_linear, + use_flash_attention=use_flash_attention, ) ) blocks.append( @@ -243,6 +253,9 @@ def __init__( num_channels=block_in_ch, norm_num_groups=norm_num_groups, norm_eps=norm_eps, + include_fc=include_fc, + use_combined_linear=use_combined_linear, + use_flash_attention=use_flash_attention, ) ) @@ -331,6 +344,9 @@ def __init__( with_encoder_nonlocal_attn: bool = True, with_decoder_nonlocal_attn: bool = True, spade_intermediate_channels: int = 128, + include_fc: bool = True, + use_combined_linear: bool = False, + use_flash_attention: bool = False, ) -> None: super().__init__() @@ -360,6 +376,9 @@ def __init__( norm_eps=norm_eps, attention_levels=attention_levels, with_nonlocal_attn=with_encoder_nonlocal_attn, + include_fc=include_fc, + use_combined_linear=use_combined_linear, + use_flash_attention=use_flash_attention, ) self.decoder = SPADEDecoder( spatial_dims=spatial_dims, @@ -373,6 +392,9 @@ def __init__( label_nc=label_nc, with_nonlocal_attn=with_decoder_nonlocal_attn, spade_intermediate_channels=spade_intermediate_channels, + include_fc=include_fc, + use_combined_linear=use_combined_linear, + use_flash_attention=use_flash_attention, ) self.quant_conv_mu = Convolution( spatial_dims=spatial_dims, diff --git a/monai/networks/nets/spade_diffusion_model_unet.py b/monai/networks/nets/spade_diffusion_model_unet.py index 75d1687df3..a9609b1d39 100644 --- a/monai/networks/nets/spade_diffusion_model_unet.py +++ b/monai/networks/nets/spade_diffusion_model_unet.py @@ -325,6 +325,10 @@ class SPADEAttnUpBlock(nn.Module): resblock_updown: if True use residual blocks for upsampling. num_head_channels: number of channels in each attention head. spade_intermediate_channels: number of intermediate channels for SPADE block layer + include_fc: whether to include the final linear layer. Default to True. + use_combined_linear: whether to use a single linear layer for qkv projection, default to False. 
+ use_flash_attention: if True, use Pytorch's inbuilt flash attention for a memory efficient attention mechanism + (see https://pytorch.org/docs/2.2/generated/torch.nn.functional.scaled_dot_product_attention.html). """ def __init__( @@ -342,6 +346,9 @@ def __init__( resblock_updown: bool = False, num_head_channels: int = 1, spade_intermediate_channels: int = 128, + include_fc: bool = True, + use_combined_linear: bool = False, + use_flash_attention: bool = False, ) -> None: super().__init__() self.resblock_updown = resblock_updown @@ -371,6 +378,9 @@ def __init__( num_head_channels=num_head_channels, norm_num_groups=norm_num_groups, norm_eps=norm_eps, + include_fc=include_fc, + use_combined_linear=use_combined_linear, + use_flash_attention=use_flash_attention, ) ) @@ -457,6 +467,8 @@ class SPADECrossAttnUpBlock(nn.Module): cross_attention_dim: number of context dimensions to use. upcast_attention: if True, upcast attention operations to full precision. spade_intermediate_channels: number of intermediate channels for SPADE block layer. + use_flash_attention: if True, use Pytorch's inbuilt flash attention for a memory efficient attention mechanism. + (see https://pytorch.org/docs/2.2/generated/torch.nn.functional.scaled_dot_product_attention.html). """ def __init__( @@ -477,6 +489,9 @@ def __init__( cross_attention_dim: int | None = None, upcast_attention: bool = False, spade_intermediate_channels: int = 128, + include_fc: bool = True, + use_combined_linear: bool = False, + use_flash_attention: bool = False, ) -> None: super().__init__() self.resblock_updown = resblock_updown @@ -510,6 +525,9 @@ def __init__( num_layers=transformer_num_layers, cross_attention_dim=cross_attention_dim, upcast_attention=upcast_attention, + include_fc=include_fc, + use_combined_linear=use_combined_linear, + use_flash_attention=use_flash_attention, ) ) @@ -592,6 +610,9 @@ def get_spade_up_block( cross_attention_dim: int | None, upcast_attention: bool = False, spade_intermediate_channels: int = 128, + include_fc: bool = True, + use_combined_linear: bool = False, + use_flash_attention: bool = False, ) -> nn.Module: if with_attn: return SPADEAttnUpBlock( @@ -608,6 +629,9 @@ def get_spade_up_block( resblock_updown=resblock_updown, num_head_channels=num_head_channels, spade_intermediate_channels=spade_intermediate_channels, + include_fc=include_fc, + use_combined_linear=use_combined_linear, + use_flash_attention=use_flash_attention, ) elif with_cross_attn: return SPADECrossAttnUpBlock( @@ -627,6 +651,7 @@ def get_spade_up_block( cross_attention_dim=cross_attention_dim, upcast_attention=upcast_attention, spade_intermediate_channels=spade_intermediate_channels, + use_flash_attention=use_flash_attention, ) else: return SPADEUpBlock( @@ -667,9 +692,11 @@ class SPADEDiffusionModelUNet(nn.Module): transformer_num_layers: number of layers of Transformer blocks to use. cross_attention_dim: number of context dimensions to use. num_class_embeds: if specified (as an int), then this model will be class-conditional with `num_class_embeds` - classes. + classes. upcast_attention: if True, upcast attention operations to full precision. - spade_intermediate_channels: number of intermediate channels for SPADE block layer + spade_intermediate_channels: number of intermediate channels for SPADE block layer. + use_flash_attention: if True, use Pytorch's inbuilt flash attention for a memory efficient attention mechanism + (see https://pytorch.org/docs/2.2/generated/torch.nn.functional.scaled_dot_product_attention.html). 
""" def __init__( @@ -691,6 +718,9 @@ def __init__( num_class_embeds: int | None = None, upcast_attention: bool = False, spade_intermediate_channels: int = 128, + include_fc: bool = True, + use_combined_linear: bool = False, + use_flash_attention: bool = False, ) -> None: super().__init__() if with_conditioning is True and cross_attention_dim is None: @@ -783,6 +813,9 @@ def __init__( transformer_num_layers=transformer_num_layers, cross_attention_dim=cross_attention_dim, upcast_attention=upcast_attention, + include_fc=include_fc, + use_combined_linear=use_combined_linear, + use_flash_attention=use_flash_attention, ) self.down_blocks.append(down_block) @@ -799,6 +832,9 @@ def __init__( transformer_num_layers=transformer_num_layers, cross_attention_dim=cross_attention_dim, upcast_attention=upcast_attention, + include_fc=include_fc, + use_combined_linear=use_combined_linear, + use_flash_attention=use_flash_attention, ) # up @@ -834,6 +870,7 @@ def __init__( upcast_attention=upcast_attention, label_nc=label_nc, spade_intermediate_channels=spade_intermediate_channels, + use_flash_attention=use_flash_attention, ) self.up_blocks.append(up_block) diff --git a/monai/networks/nets/transformer.py b/monai/networks/nets/transformer.py index 1af725abda..3a278c112a 100644 --- a/monai/networks/nets/transformer.py +++ b/monai/networks/nets/transformer.py @@ -51,6 +51,10 @@ class DecoderOnlyTransformer(nn.Module): attn_layers_heads: Number of attention heads. with_cross_attention: Whether to use cross attention for conditioning. embedding_dropout_rate: Dropout rate for the embedding. + include_fc: whether to include the final linear layer. Default to True. + use_combined_linear: whether to use a single linear layer for qkv projection, default to True. + use_flash_attention: if True, use Pytorch's inbuilt flash attention for a memory efficient attention mechanism + (see https://pytorch.org/docs/2.2/generated/torch.nn.functional.scaled_dot_product_attention.html). 
""" def __init__( @@ -62,6 +66,9 @@ def __init__( attn_layers_heads: int, with_cross_attention: bool = False, embedding_dropout_rate: float = 0.0, + include_fc: bool = True, + use_combined_linear: bool = False, + use_flash_attention: bool = False, ) -> None: super().__init__() self.num_tokens = num_tokens @@ -86,6 +93,9 @@ def __init__( causal=True, sequence_length=max_seq_len, with_cross_attention=with_cross_attention, + include_fc=include_fc, + use_combined_linear=use_combined_linear, + use_flash_attention=use_flash_attention, ) for _ in range(attn_layers_depth) ] @@ -133,25 +143,15 @@ def load_old_state_dict(self, old_state_dict: dict, verbose=False) -> None: # copy over all matching keys for k in new_state_dict: if k in old_state_dict: - new_state_dict[k] = old_state_dict[k] - - # fix the attention blocks - attention_blocks = [k.replace(".attn.qkv.weight", "") for k in new_state_dict if "attn.qkv.weight" in k] - for block in attention_blocks: - new_state_dict[f"{block}.attn.qkv.weight"] = torch.cat( - [ - old_state_dict[f"{block}.attn.to_q.weight"], - old_state_dict[f"{block}.attn.to_k.weight"], - old_state_dict[f"{block}.attn.to_v.weight"], - ], - dim=0, - ) + new_state_dict[k] = old_state_dict.pop(k) # fix the renamed norm blocks first norm2 -> norm_cross_attention , norm3 -> norm2 - for k in old_state_dict: + for k in list(old_state_dict.keys()): if "norm2" in k: - new_state_dict[k.replace("norm2", "norm_cross_attn")] = old_state_dict[k] + new_state_dict[k.replace("norm2", "norm_cross_attn")] = old_state_dict.pop(k) if "norm3" in k: - new_state_dict[k.replace("norm3", "norm2")] = old_state_dict[k] - + new_state_dict[k.replace("norm3", "norm2")] = old_state_dict.pop(k) + if verbose: + # print all remaining keys in old_state_dict + print("remaining keys in old_state_dict:", old_state_dict.keys()) self.load_state_dict(new_state_dict) diff --git a/tests/test_crossattention.py b/tests/test_crossattention.py index 44458147d6..e034e42290 100644 --- a/tests/test_crossattention.py +++ b/tests/test_crossattention.py @@ -22,7 +22,7 @@ from monai.networks.blocks.crossattention import CrossAttentionBlock from monai.networks.layers.factories import RelPosEmbedding from monai.utils import optional_import -from tests.utils import SkipIfBeforePyTorchVersion +from tests.utils import SkipIfBeforePyTorchVersion, assert_allclose einops, has_einops = optional_import("einops") @@ -166,6 +166,21 @@ def test_access_attn_matrix(self): matrix_acess_blk(torch.randn(input_shape)) assert matrix_acess_blk.att_mat.shape == (input_shape[0], input_shape[0], input_shape[1], input_shape[1]) + @parameterized.expand([[True], [False]]) + @skipUnless(has_einops, "Requires einops") + @SkipIfBeforePyTorchVersion((2, 0)) + def test_flash_attention(self, causal): + input_param = {"hidden_size": 128, "num_heads": 1, "causal": causal, "sequence_length": 16 if causal else None} + device = "cuda:0" if torch.cuda.is_available() else "cpu" + block_w_flash_attention = CrossAttentionBlock(**input_param, use_flash_attention=True).to(device) + block_wo_flash_attention = CrossAttentionBlock(**input_param, use_flash_attention=False).to(device) + block_wo_flash_attention.load_state_dict(block_w_flash_attention.state_dict()) + test_data = torch.randn(1, 16, 128).to(device) + + out_1 = block_w_flash_attention(test_data) + out_2 = block_wo_flash_attention(test_data) + assert_allclose(out_1, out_2, atol=1e-4) + if __name__ == "__main__": unittest.main() diff --git a/tests/test_selfattention.py b/tests/test_selfattention.py index 
3e98f4c5c4..88919fd8b1 100644 --- a/tests/test_selfattention.py +++ b/tests/test_selfattention.py @@ -22,7 +22,7 @@ from monai.networks.blocks.selfattention import SABlock from monai.networks.layers.factories import RelPosEmbedding from monai.utils import optional_import -from tests.utils import SkipIfBeforePyTorchVersion +from tests.utils import SkipIfBeforePyTorchVersion, assert_allclose, test_script_save einops, has_einops = optional_import("einops") @@ -32,20 +32,23 @@ for num_heads in [4, 6, 8, 12]: for rel_pos_embedding in [None, RelPosEmbedding.DECOMPOSED]: for input_size in [(16, 32), (8, 8, 8)]: - for flash_attn in [True, False]: - test_case = [ - { - "hidden_size": hidden_size, - "num_heads": num_heads, - "dropout_rate": dropout_rate, - "rel_pos_embedding": rel_pos_embedding if not flash_attn else None, - "input_size": input_size, - "use_flash_attention": flash_attn, - }, - (2, 512, hidden_size), - (2, 512, hidden_size), - ] - TEST_CASE_SABLOCK.append(test_case) + for include_fc in [True, False]: + for use_combined_linear in [True, False]: + test_case = [ + { + "hidden_size": hidden_size, + "num_heads": num_heads, + "dropout_rate": dropout_rate, + "rel_pos_embedding": rel_pos_embedding, + "input_size": input_size, + "include_fc": include_fc, + "use_combined_linear": use_combined_linear, + "use_flash_attention": True if rel_pos_embedding is None else False, + }, + (2, 512, hidden_size), + (2, 512, hidden_size), + ] + TEST_CASE_SABLOCK.append(test_case) class TestResBlock(unittest.TestCase): @@ -175,6 +178,39 @@ def count_sablock_params(*args, **kwargs): nparams_default_more_heads = count_sablock_params(hidden_size=hidden_size, num_heads=num_heads * 2) self.assertEqual(nparams_default, nparams_default_more_heads) + @parameterized.expand([[True, False], [True, True], [False, True], [False, False]]) + @skipUnless(has_einops, "Requires einops") + @SkipIfBeforePyTorchVersion((2, 0)) + def test_script(self, include_fc, use_combined_linear): + input_param = { + "hidden_size": 360, + "num_heads": 4, + "dropout_rate": 0.0, + "rel_pos_embedding": None, + "input_size": (16, 32), + "include_fc": include_fc, + "use_combined_linear": use_combined_linear, + } + net = SABlock(**input_param) + input_shape = (2, 512, 360) + test_data = torch.randn(input_shape) + test_script_save(net, test_data) + + @skipUnless(has_einops, "Requires einops") + @SkipIfBeforePyTorchVersion((2, 0)) + def test_flash_attention(self): + for causal in [True, False]: + input_param = {"hidden_size": 360, "num_heads": 4, "input_size": (16, 32), "causal": causal} + device = "cuda:0" if torch.cuda.is_available() else "cpu" + block_w_flash_attention = SABlock(**input_param, use_flash_attention=True).to(device) + block_wo_flash_attention = SABlock(**input_param, use_flash_attention=False).to(device) + block_wo_flash_attention.load_state_dict(block_w_flash_attention.state_dict()) + test_data = torch.randn(2, 512, 360).to(device) + + out_1 = block_w_flash_attention(test_data) + out_2 = block_wo_flash_attention(test_data) + assert_allclose(out_1, out_2, atol=1e-4) + if __name__ == "__main__": unittest.main() diff --git a/tests/test_unetr.py b/tests/test_unetr.py index 46018d2bc0..1217c9d85f 100644 --- a/tests/test_unetr.py +++ b/tests/test_unetr.py @@ -123,7 +123,7 @@ def test_ill_arg(self): ) @parameterized.expand(TEST_CASE_UNETR) - @SkipIfBeforePyTorchVersion((1, 9)) + @SkipIfBeforePyTorchVersion((2, 0)) def test_script(self, input_param, input_shape, _): net = UNETR(**(input_param)) net.eval()
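The new flash-attention tests above follow a simple pattern that also works as a quick local sanity check: build the block twice, share the weights, and compare the two outputs. A minimal sketch for SABlock, mirroring the arguments used in the tests (einops is required, and CPU works because scaled_dot_product_attention falls back to the math backend):

import torch
from monai.networks.blocks.selfattention import SABlock

params = {"hidden_size": 360, "num_heads": 4, "dropout_rate": 0.0, "input_size": (16, 32)}
block_flash = SABlock(**params, use_flash_attention=True)
block_manual = SABlock(**params, use_flash_attention=False)
block_manual.load_state_dict(block_flash.state_dict())

x = torch.randn(2, 512, 360)
with torch.no_grad():
    assert torch.allclose(block_flash(x), block_manual(x), atol=1e-4)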