Replace call to deprecated torch.norm (#16758)
### Description
torch.norm is deprecated, as noted in issue #16751. This PR replaces
the call to torch.norm with the alternatives suggested by the torch documentation.
xadupre authored Jul 21, 2023
1 parent b7176f9 commit b508c72
Showing 1 changed file with 7 additions and 1 deletion.
8 changes: 7 additions & 1 deletion orttraining/orttraining/python/training/optim/_modifier.py

@@ -137,7 +137,13 @@ def param_is_not_tensor_parallel_duplicate(param):
 
     else:
         for grad in grads_for_norm:
-            grad_norm = torch.norm(grad, norm_type)
+            # torch.norm is deprecated and moved to torch.linalg.norm
+            # with a different signature
+            # see https://pytorch.org/docs/stable/generated/torch.norm.html
+            if norm_type in {"fro", "nuc"}:
+                grad_norm = torch.linalg.matrix_norm(grad, norm_type)
+            else:
+                grad_norm = torch.linalg.norm(grad, norm_type)
             total_norm += grad_norm**norm_type
 
     if horizontal_model_parallel_grad_norm_aggregation:
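For context on the loop being patched: it accumulates each gradient's norm raised to the power `norm_type`, and the caller later takes the `norm_type`-th root to get the global gradient norm. A minimal sketch of that aggregation pattern, using NumPy in place of torch and made-up gradient values with `norm_type=2` (not taken from the real code path):

```python
import numpy as np

# Hypothetical gradients standing in for grads_for_norm (illustration only).
grads = [np.array([3.0, 4.0]), np.array([2.0, 2.0, 1.0])]
norm_type = 2.0

# Mirror of the loop in _modifier.py: accumulate ||grad||^p, then take the p-th root.
total_norm = 0.0
for grad in grads:
    # np.linalg.norm plays the role of torch.linalg.norm for vector norms here.
    grad_norm = np.linalg.norm(grad, norm_type)
    total_norm += grad_norm**norm_type
total_norm = total_norm ** (1.0 / norm_type)

print(total_norm)  # sqrt(5**2 + 3**2) = sqrt(34)
```

This also shows why the diff branches on `{"fro", "nuc"}`: `torch.linalg.norm` computes vector norms for a float `ord`, while the Frobenius and nuclear norms are matrix norms and belong to `torch.linalg.matrix_norm`.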
