Only have contiguous calls after attention blocks (#7763)
Towards #7227.

### Description
There were many `.contiguous()` calls scattered throughout the `DiffusionModelUNet`. It turns
out these are only necessary after attention blocks, as the einops rearrange operations
can produce non-contiguous tensors that cause errors downstream. I've
tidied the code up so that `.contiguous()` is called only on the output of attention
blocks.
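
For illustration, here is a minimal sketch (not MONAI code) of how an einops rearrange into the `b, t, d` layout used before attention can yield a non-contiguous tensor; the shape and pattern below are just an example:

```python
# Minimal sketch (not MONAI code): einops rearrange can return a strided,
# non-contiguous view when projecting an image-like tensor to b, t, d.
import torch
from einops import rearrange

x = torch.randn(2, 8, 4, 4)  # example image-like tensor: b, c, h, w

# Flatten the spatial dims and move channels last, as done before attention.
t = rearrange(x, "b c h w -> b (h w) c")
print(t.is_contiguous())  # False here: the result is a permuted view of x

# Operations that need contiguous memory (e.g. .view) can then raise errors,
# so the output of each attention block is made contiguous once.
t = t.contiguous()
print(t.is_contiguous())  # True
```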

### Types of changes
<!--- Put an `x` in all the boxes that apply, and remove the not
applicable items -->
- [x] Non-breaking change (fix or new feature that would not break
existing functionality).
- [ ] Breaking change (fix or new feature that would cause existing
functionality to change).
- [ ] New tests added to cover the changes.
- [ ] Integration tests passed locally by running `./runtests.sh -f -u
--net --coverage`.
- [ ] Quick tests passed locally by running `./runtests.sh --quick
--unittests --disttests`.
- [ ] In-line docstrings updated.
- [ ] Documentation updated, tested `make html` command in the `docs/`
folder.

---------

Signed-off-by: Mark Graham <[email protected]>
Co-authored-by: Eric Kerfoot <[email protected]>
marksgraham and ericspod authored May 14, 2024
1 parent c54bf3c commit a052c44
Showing 2 changed files with 10 additions and 20 deletions.
21 changes: 7 additions & 14 deletions monai/networks/nets/diffusion_model_unet.py
@@ -115,10 +115,6 @@ def forward(self, x: torch.Tensor, context: torch.Tensor | None = None) -> torch

class SpatialTransformer(nn.Module):
"""
NOTE This is a private block that we plan to merge with existing MONAI blocks in the future. Please do not make
use of this block as support is not guaranteed. For more information see:
https://github.com/Project-MONAI/MONAI/issues/7227
Transformer block for image-like data. First, project the input (aka embedding) and reshape to b, t, d. Then apply
standard transformer action. Finally, reshape to image.
@@ -396,14 +392,11 @@ def __init__(
)

def forward(self, x: torch.Tensor, emb: torch.Tensor) -> torch.Tensor:
h = x.contiguous()
h = x
h = self.norm1(h)
h = self.nonlinearity(h)

if self.upsample is not None:
if h.shape[0] >= 64:
x = x.contiguous()
h = h.contiguous()
x = self.upsample(x)
h = self.upsample(h)
elif self.downsample is not None:
@@ -609,7 +602,7 @@ def forward(

for resnet, attn in zip(self.resnets, self.attentions):
hidden_states = resnet(hidden_states, temb)
hidden_states = attn(hidden_states)
hidden_states = attn(hidden_states).contiguous()
output_states.append(hidden_states)

if self.downsampler is not None:
@@ -726,7 +719,7 @@ def forward(

for resnet, attn in zip(self.resnets, self.attentions):
hidden_states = resnet(hidden_states, temb)
hidden_states = attn(hidden_states, context=context)
hidden_states = attn(hidden_states, context=context).contiguous()
output_states.append(hidden_states)

if self.downsampler is not None:
@@ -790,7 +783,7 @@ def forward(
) -> torch.Tensor:
del context
hidden_states = self.resnet_1(hidden_states, temb)
hidden_states = self.attention(hidden_states)
hidden_states = self.attention(hidden_states).contiguous()
hidden_states = self.resnet_2(hidden_states, temb)

return hidden_states
@@ -1091,7 +1084,7 @@ def forward(
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)

hidden_states = resnet(hidden_states, temb)
hidden_states = attn(hidden_states)
hidden_states = attn(hidden_states).contiguous()

if self.upsampler is not None:
hidden_states = self.upsampler(hidden_states, temb)
@@ -1669,7 +1662,7 @@ def forward(
down_block_res_samples = new_down_block_res_samples

# 5. mid
h = self.middle_block(hidden_states=h.contiguous(), temb=emb, context=context)
h = self.middle_block(hidden_states=h, temb=emb, context=context)

# Additional residual conections for Controlnets
if mid_block_additional_residual is not None:
@@ -1682,7 +1675,7 @@ def forward(
h = upsample_block(hidden_states=h, res_hidden_states_list=res_samples, temb=emb, context=context)

# 7. output block
output: torch.Tensor = self.out(h.contiguous())
output: torch.Tensor = self.out(h)

return output

9 changes: 3 additions & 6 deletions monai/networks/nets/spade_diffusion_model_unet.py
@@ -170,9 +170,6 @@ def forward(self, x: torch.Tensor, emb: torch.Tensor, seg: torch.Tensor) -> torc
h = self.nonlinearity(h)

if self.upsample is not None:
if h.shape[0] >= 64:
x = x.contiguous()
h = h.contiguous()
x = self.upsample(x)
h = self.upsample(h)
elif self.downsample is not None:
@@ -430,7 +427,7 @@ def forward(
res_hidden_states_list = res_hidden_states_list[:-1]
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
hidden_states = resnet(hidden_states, temb, seg)
hidden_states = attn(hidden_states)
hidden_states = attn(hidden_states).contiguous()

if self.upsampler is not None:
hidden_states = self.upsampler(hidden_states, temb)
@@ -568,7 +565,7 @@ def forward(
res_hidden_states_list = res_hidden_states_list[:-1]
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
hidden_states = resnet(hidden_states, temb, seg)
hidden_states = attn(hidden_states, context=context)
hidden_states = attn(hidden_states, context=context).contiguous()

if self.upsampler is not None:
hidden_states = self.upsampler(hidden_states, temb)
@@ -919,7 +916,7 @@ def forward(
down_block_res_samples = new_down_block_res_samples

# 5. mid
h = self.middle_block(hidden_states=h.contiguous(), temb=emb, context=context)
h = self.middle_block(hidden_states=h, temb=emb, context=context)

# Additional residual conections for Controlnets
if mid_block_additional_residual is not None: