Skip to content

Commit

Permalink
[core] Rename SpatialDerivativeKeys all() and unmixed() 'ndim' to 'sp…
Browse files Browse the repository at this point in the history
…atial_dims'
  • Loading branch information
aschuh-hf committed Nov 24, 2023
1 parent 64314a6 commit 24f82ad
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 9 deletions.
14 changes: 7 additions & 7 deletions src/deepali/core/enum.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,12 +206,12 @@ def is_mixed(key: SpatialDerivativeKey) -> bool:
return len(set(key)) > 1

@staticmethod
def all(ndim: int, order: Union[int, Sequence[int]]) -> List[SpatialDerivativeKey]:
def all(spatial_dims: int, order: Union[int, Sequence[int]]) -> List[SpatialDerivativeKey]:
r"""Unmixed spatial derivatives of specified order."""
if isinstance(order, int):
order = [order]
keys = []
dims = [str(SpatialDim(d)) for d in range(ndim)]
dims = [str(SpatialDim(d)) for d in range(spatial_dims)]
for n in order:
if n > 0:
codes = dims
Expand All @@ -221,11 +221,11 @@ def all(ndim: int, order: Union[int, Sequence[int]]) -> List[SpatialDerivativeKe
return keys

@staticmethod
def unmixed(ndim: int, order: int) -> List[SpatialDerivativeKey]:
def unmixed(spatial_dims: int, order: int) -> List[SpatialDerivativeKey]:
r"""Unmixed spatial derivatives of specified order."""
if order <= 0:
return []
return [SpatialDim(d).symbol() * order for d in range(ndim)]
return [SpatialDim(d).symbol() * order for d in range(spatial_dims)]

@classmethod
def unique(cls, keys: Iterable[SpatialDerivativeKey]) -> Set[SpatialDerivativeKey]:
Expand Down Expand Up @@ -466,7 +466,7 @@ def all(
if order == 0:
return []
channels = cls._channels(spatial_dims, channel)
derivs = SpatialDerivativeKeys.all(spatial_dims, order=order)
derivs = SpatialDerivativeKeys.all(spatial_dims=spatial_dims, order=order)
return [cls.symbol(c, d) for c, d in itertools.product(channels, derivs)]

@classmethod
Expand All @@ -482,7 +482,7 @@ def unmixed(
if order == 0:
return []
channels = cls._channels(spatial_dims, channel)
derivs = SpatialDerivativeKeys.unmixed(spatial_dims, order=order)
derivs = SpatialDerivativeKeys.unmixed(spatial_dims=spatial_dims, order=order)
return [cls.symbol(c, d) for c, d in itertools.product(channels, derivs)]

@classmethod
Expand All @@ -504,7 +504,7 @@ def divergence(cls, spatial_dims: int) -> List[FlowDerivativeKey]:
@classmethod
def curvature(cls, spatial_dims: int) -> List[FlowDerivativeKey]:
channels = range(spatial_dims)
derivs = SpatialDerivativeKeys.unmixed(spatial_dims, order=2)
derivs = SpatialDerivativeKeys.unmixed(spatial_dims=spatial_dims, order=2)
return [cls.symbol(c, d) for c, d in itertools.product(channels, derivs)]

@classmethod
Expand Down
2 changes: 1 addition & 1 deletion src/deepali/core/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -1536,7 +1536,7 @@ def spatial_derivatives(
if which is None:
if order is None:
order = 1
which = SpatialDerivativeKeys.all(ndim=D, order=order)
which = SpatialDerivativeKeys.all(spatial_dims=D, order=order)
elif order is not None:
which = [arg for arg in which if len(arg) == order]

Expand Down
2 changes: 1 addition & 1 deletion tests/test_core_enum.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ def test_flow_derivative_keys_unique() -> None:
def test_flow_derivative_keys_all() -> None:
for d, order in itertools.product([2, 3], [0, 1, 2]):
channel_keys = ["u", "v", "w"][:d]
spatial_keys = SpatialDerivativeKeys.all(ndim=d, order=order)
spatial_keys = SpatialDerivativeKeys.all(spatial_dims=d, order=order)
expected = [f"d{a}/d{b}" for a, b in itertools.product(channel_keys, spatial_keys)]
assert FlowDerivativeKeys.all(spatial_dims=d, order=order) == expected

Expand Down

0 comments on commit 24f82ad

Please sign in to comment.