[losses] Mostly cosmetic changes in grad_loss() functions
aschuh-hf committed Dec 14, 2023
1 parent 6170778 commit 838a990
Showing 1 changed file with 6 additions and 4 deletions.
src/deepali/losses/functional.py (10 changes: 6 additions & 4 deletions)
@@ -1149,7 +1149,6 @@ def grad_loss(
                 "grad_loss() not implemented for linear transformation and 'reduction'='none'"
             )
         return torch.tensor(0, dtype=u.dtype, device=u.device)
-    N = u.shape[0]
     D = u.shape[1]
     if u.ndim - 2 != D:
         raise ValueError("grad_loss() 'u' must be tensor of shape (N, D, ..., X)")
@@ -1167,9 +1166,11 @@
         deriv = {k: v.pow_(p) for k, v in deriv.items()}
     else:
         deriv = {k: v.abs_().pow_(p) for k, v in deriv.items()}
-    loss = torch.zeros((N, 1) + u.shape[2:], dtype=u.dtype, device=u.device)
+    loss: Optional[Tensor] = None
     for value in deriv.values():
-        loss = loss.add_(value.sum(dim=1, keepdim=True))
+        value = value.sum(dim=1, keepdim=True)
+        loss = value if loss is None else loss.add_(value)
+    assert loss is not None
     if q == 0:
         loss.abs_()
     elif q != 1:
@@ -1299,7 +1300,8 @@ def curvature_loss(
     kwargs = dict(mode=mode or "sobel", sigma=sigma, spacing=spacing, stride=stride)
     which = FlowDerivativeKeys.curvature(spatial_dims=D)
     deriv = flow_derivatives(u, which=which, **kwargs)
-    loss = torch.zeros((N, D) + u.shape[2:], dtype=u.dtype, device=u.device)
+    shape = deriv["du/dxx"].shape
+    loss = torch.zeros((N, D) + shape[2:], dtype=u.dtype, device=u.device)
     for i, j in itertools.product(range(D), repeat=2):
         loss.narrow(1, i, 1).add_(deriv[FlowDerivativeKeys.symbol(i, j, j)])
     loss = loss.square_().sum(dim=1, keepdim=True)
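
The grad_loss() hunks above drop the preallocated zero tensor, which was sized from u.shape, in favor of lazy accumulation: the first summed derivative term initializes the loss and later terms are added in place, so the accumulator always matches the shape of the derivative maps and no throwaway allocation is needed. Below is a minimal runnable sketch of this pattern; the toy deriv dict stands in for the output of flow_derivatives() and its shapes are illustrative only.

    # Minimal sketch of the lazy accumulation pattern adopted in grad_loss().
    # The `deriv` dict and its shapes are illustrative, not deepali's actual output.
    from typing import Optional

    import torch
    from torch import Tensor

    deriv = {
        "du/dx": torch.randn(2, 3, 8, 8),
        "du/dy": torch.randn(2, 3, 8, 8),
    }

    loss: Optional[Tensor] = None
    for value in deriv.values():
        value = value.sum(dim=1, keepdim=True)
        # First term initializes the accumulator; later terms are added in place.
        loss = value if loss is None else loss.add_(value)
    assert loss is not None  # deriv is non-empty, so a loss tensor exists
    print(loss.shape)  # torch.Size([2, 1, 8, 8])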
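The curvature_loss() hunk similarly sizes its accumulator from a computed derivative map instead of from u. A plausible reading is that flow_derivatives() can return maps that are spatially smaller than u when stride is greater than 1, in which case zeros shaped like u.shape[2:] would not match; this is an assumption, not confirmed by the commit. The sketch below illustrates that scenario, with fake_derivatives as a hypothetical stand-in for the real derivative maps.

    # Sketch of why the accumulator is now sized from a derivative map rather
    # than from `u`. `fake_derivatives` is a hypothetical stand-in that mimics
    # second derivatives computed with stride=2, i.e. spatially smaller than u.
    import torch

    N, D = 2, 2
    u = torch.randn(N, D, 16, 16)

    fake_derivatives = {
        "du/dxx": torch.randn(N, 1, 8, 8),
        "du/dyy": torch.randn(N, 1, 8, 8),
    }

    # Old: torch.zeros((N, D) + u.shape[2:]) -> (2, 2, 16, 16), mismatched.
    # New: take the spatial shape from a computed derivative map.
    shape = fake_derivatives["du/dxx"].shape
    loss = torch.zeros((N, D) + shape[2:], dtype=u.dtype, device=u.device)
    print(loss.shape)  # torch.Size([2, 2, 8, 8])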
