Skip to content

Commit

Permalink
[core] Convert to Tensor.float() only if necessary
Browse files Browse the repository at this point in the history
  • Loading branch information
aschuh-hf committed Nov 7, 2023
1 parent 121999f commit 14629bb
Showing 1 changed file with 6 additions and 3 deletions.
9 changes: 6 additions & 3 deletions src/deepali/core/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -1306,7 +1306,8 @@ def clamp_fn(data: Tensor, a: float, b: float) -> Tensor:
return data.clamp_(a, b)

else:
data = data.float()
if not data.is_floating_point():
data = data.float()

def add_fn(data: Tensor, a: float) -> Tensor:
return data.add(a)
Expand Down Expand Up @@ -1546,7 +1547,8 @@ def spatial_derivatives(
unique_keys = SpatialDerivativeKeys.unique(which)
max_order = SpatialDerivativeKeys.max_order(which)

data = data.float()
if not data.is_floating_point():
data = data.float()

if mode is None:
mode = "forward_central_backward"
Expand Down Expand Up @@ -1699,7 +1701,8 @@ def finite_differences(
if step_size.ndim > 1 or step_size.shape[0] not in (1, N):
raise ValueError(f"finite_differences() 'spacing' must be scalar or sequence of length {N}")

data = data.float()
if not data.is_floating_point():
data = data.float()

def pad_spatial_dim(data: Tensor, left: int, right: int) -> Tensor:
pad = [(left, right) if d == spatial_dim else (0, 0) for d in range(data.ndim - 2)]
Expand Down

0 comments on commit 14629bb

Please sign in to comment.