Skip to content

Commit

Permalink
Use ONNX / Core ML compatible method to broadcast (open-mmlab#310)
Browse files Browse the repository at this point in the history
* Use ONNX / Core ML compatible method to broadcast.

Unfortunately `tile` could not be used either, it's still not compatible
with ONNX.

See open-mmlab#284.

* Add comment about why broadcast_to is not used.

Also, apply style to changed files.

* Make sure broadcast remains in same device.
  • Loading branch information
pcuenca authored Sep 2, 2022
1 parent 7b628a2 commit e49dd03
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 8 deletions.
5 changes: 2 additions & 3 deletions src/diffusers/models/unet_2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,6 @@ def __init__(
def forward(
self, sample: torch.FloatTensor, timestep: Union[torch.Tensor, float, int]
) -> Dict[str, torch.FloatTensor]:

# 0. center input if necessary
if self.config.center_input_sample:
sample = 2 * sample - 1.0
Expand All @@ -132,8 +131,8 @@ def forward(
elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0:
timesteps = timesteps[None].to(sample.device)

# broadcast to batch dimension
timesteps = timesteps.broadcast_to(sample.shape[0])
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
timesteps = timesteps * torch.ones(sample.shape[0], dtype=timesteps.dtype, device=timesteps.device)

t_emb = self.time_proj(timesteps)
emb = self.time_embedding(t_emb)
Expand Down
7 changes: 2 additions & 5 deletions src/diffusers/models/unet_2d_condition.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,6 @@ def forward(
timestep: Union[torch.Tensor, float, int],
encoder_hidden_states: torch.Tensor,
) -> Dict[str, torch.FloatTensor]:

# 0. center input if necessary
if self.config.center_input_sample:
sample = 2 * sample - 1.0
Expand All @@ -133,8 +132,8 @@ def forward(
elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0:
timesteps = timesteps[None].to(sample.device)

# broadcast to batch dimension
timesteps = timesteps.broadcast_to(sample.shape[0])
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
timesteps = timesteps * torch.ones(sample.shape[0], dtype=timesteps.dtype, device=timesteps.device)

t_emb = self.time_proj(timesteps)
emb = self.time_embedding(t_emb)
Expand All @@ -145,7 +144,6 @@ def forward(
# 3. down
down_block_res_samples = (sample,)
for downsample_block in self.down_blocks:

if hasattr(downsample_block, "attentions") and downsample_block.attentions is not None:
sample, res_samples = downsample_block(
hidden_states=sample, temb=emb, encoder_hidden_states=encoder_hidden_states
Expand All @@ -160,7 +158,6 @@ def forward(

# 5. up
for upsample_block in self.up_blocks:

res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]

Expand Down

0 comments on commit e49dd03

Please sign in to comment.