Skip to content

Commit

Permalink
Removed CPU randn() from schedulers (#8145)
Browse files Browse the repository at this point in the history
Fixes performance issued due to extra CPU/GPU sync:
https://nvbugswb.nvidia.com/NvBugs5/SWBug.aspx?bugid=4904446&cmtNo=

---------

Signed-off-by: Boris Fomitchev <[email protected]>
Signed-off-by: YunLiu <[email protected]>
Co-authored-by: YunLiu <[email protected]>
  • Loading branch information
borisfom and KumoLiu authored Oct 15, 2024
1 parent 5002fd9 commit 4a4c251
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 4 deletions.
2 changes: 1 addition & 1 deletion monai/networks/schedulers/ddim.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,7 +220,7 @@ def step(
if eta > 0:
# randn_like does not support generator https://github.com/pytorch/pytorch/issues/27072
device: torch.device = torch.device(model_output.device if torch.is_tensor(model_output) else "cpu")
noise = torch.randn(model_output.shape, dtype=model_output.dtype, generator=generator).to(device)
noise = torch.randn(model_output.shape, dtype=model_output.dtype, generator=generator, device=device)
variance = self._get_variance(timestep, prev_timestep) ** 0.5 * eta * noise

pred_prev_sample = pred_prev_sample + variance
Expand Down
8 changes: 6 additions & 2 deletions monai/networks/schedulers/ddpm.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,8 +241,12 @@ def step(
variance = 0
if timestep > 0:
noise = torch.randn(
model_output.size(), dtype=model_output.dtype, layout=model_output.layout, generator=generator
).to(model_output.device)
model_output.size(),
dtype=model_output.dtype,
layout=model_output.layout,
generator=generator,
device=model_output.device,
)
variance = (self._get_variance(timestep, predicted_variance=predicted_variance) ** 0.5) * noise

pred_prev_sample = pred_prev_sample + variance
Expand Down
2 changes: 1 addition & 1 deletion requirements-dev.txt
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ isort>=5.1
ruff
pytype>=2020.6.1; platform_system != "Windows"
types-setuptools
mypy>=1.5.0
mypy>=1.5.0, <1.12.0
ninja
torchvision
psutil
Expand Down

0 comments on commit 4a4c251

Please sign in to comment.