Skip to content

Commit

Permalink
Added new simplelayer SavitskyGolayFilter()
Browse files Browse the repository at this point in the history
Signed-off-by: Christian Baker <[email protected]>
  • Loading branch information
crnbaker committed Jan 8, 2021
1 parent a4ef691 commit d6ad24a
Show file tree
Hide file tree
Showing 2 changed files with 98 additions and 7 deletions.
3 changes: 2 additions & 1 deletion monai/networks/layers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,10 @@
ChannelPad,
Flatten,
    GaussianFilter,
    HilbertTransform,
    Reshape,
    SavitskyGolayFilter,
    SkipConnection,
separable_filtering,
)
from .spatial_transforms import AffineTransform, grid_count, grid_grad, grid_pull, grid_push
102 changes: 96 additions & 6 deletions monai/networks/layers/simplelayers.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,9 @@
"LLTM",
"Reshape",
"separable_filtering",
"SavitskyGolayFilter",
"HilbertTransform",
    "ChannelPad",
]


Expand Down Expand Up @@ -163,14 +164,18 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
return x.reshape(shape)


def separable_filtering(
    x: torch.Tensor, kernels: Union[Sequence[torch.Tensor], torch.Tensor], mode: str = "zeros"
) -> torch.Tensor:
    """
    Apply 1-D convolutions along each spatial dimension of `x`.

    Args:
        x: the input image. must have shape (batch, channels, H[, W, ...]).
        kernels: kernel along each spatial dimension.
            could be a single kernel (duplicated for all dimension), or `spatial_dims` number of kernels.
        mode (string, optional): padding mode passed to convolution class. ``'zeros'``, ``'reflect'``, ``'replicate'`` or
            ``'circular'``. Default: ``'zeros'``. See torch.nn.Conv1d() for more information.

    Returns:
        torch.Tensor: ``x`` convolved with the given (separable) kernels along every spatial dimension.

    Raises:
        TypeError: When ``x`` is not a ``torch.Tensor``.
    """
    if not torch.is_tensor(x):
        raise TypeError(f"x must be a torch.Tensor but is {type(x).__name__}.")

    spatial_dims = len(x.shape) - 2
    # normalize to one float kernel tensor per spatial dimension
    _kernels = [
        torch.as_tensor(s, dtype=torch.float, device=s.device if torch.is_tensor(s) else None)
        for s in ensure_tuple_rep(kernels, spatial_dims)
    ]
    _paddings = [cast(int, (same_padding(k.shape[0]))) for k in _kernels]
    n_chs = x.shape[1]

    def _conv(input_: torch.Tensor, d: int) -> torch.Tensor:
        # recursively convolve along dimension d, then d-1, ... down to 0
        if d < 0:
            return input_
        s = [1] * len(input_.shape)
        s[d + 2] = -1
        # index the normalized ``_kernels`` list, NOT the raw ``kernels`` argument:
        # when a single kernel tensor is passed, ``kernels[d]`` would index into
        # that tensor instead of selecting the d-th per-dimension kernel.
        _kernel = _kernels[d].reshape(s)
        # identity (length-1, value-1) kernel: convolution is a no-op, skip it.
        # compare as explicit 1-element tensors -- ``squeeze()`` of a length-1
        # kernel yields a 0-d tensor, which ``torch.equal`` never matches
        # against the 1-d ``torch.ones(1)``.
        if _kernel.numel() == 1 and torch.equal(_kernel.reshape(1), torch.ones(1, device=_kernel.device)):
            return _conv(input_, d - 1)
        _kernel = _kernel.repeat([n_chs, 1] + [1] * spatial_dims)
        _padding = [0] * spatial_dims
        _padding[d] = _paddings[d]

        if mode == "zeros":  # zero padding (default) can use the functional convolution directly
            conv_type = [F.conv1d, F.conv2d, F.conv3d][spatial_dims - 1]
            return conv_type(input=_conv(input_, d - 1), weight=_kernel, padding=_padding, groups=n_chs)
        # other padding modes are only supported by the module form of convolution;
        # select the class first so only one module is constructed.
        conv_cls = [nn.Conv1d, nn.Conv2d, nn.Conv3d][spatial_dims - 1]
        conv_mod = conv_cls(
            n_chs,
            n_chs,
            # kernel_size is the spatial extent only; the (out_channels,
            # in_channels/groups) leading dims belong to the weight, not here.
            _kernel.shape[2:],
            padding=_padding,
            groups=n_chs,
            bias=False,
            padding_mode=mode,
        )
        # install the precomputed separable kernel as the (fixed) weight
        conv_mod.weight = torch.nn.Parameter(_kernel, requires_grad=_kernel.requires_grad)
        return conv_mod(input=_conv(input_, d - 1))

    return _conv(x, spatial_dims - 1)


class SavitskyGolayFilter(nn.Module):
    """
    Convolve a Tensor along a particular axis with a Savitsky-Golay kernel.

    Args:
        window_length: Length of the filter window, must be a positive odd integer.
        order: Order of the polynomial to fit to each window, must be less than ``window_length``.
        axis (optional): Axis along which to apply the filter kernel. Default 2 (first spatial dimension).
        mode (string, optional): padding mode passed to convolution class. ``'zeros'``, ``'reflect'``, ``'replicate'`` or
            ``'circular'``. Default: ``'zeros'``. See torch.nn.Conv1d() for more information.

    Raises:
        ValueError: When ``order`` is not less than ``window_length``, or ``window_length`` is not odd.
    """

    def __init__(self, window_length: int, order: int, axis: int = 2, mode: str = "zeros"):
        super().__init__()
        if order >= window_length:
            raise ValueError("order must be less than window_length.")

        self.axis = axis
        self.mode = mode
        # filter taps are precomputed once here and moved to the input's
        # device/dtype inside forward()
        self.coeffs = self._make_coeffs(window_length, order)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: Tensor or array-like to filter. Must be real, in shape ``[Batch, chns, spatial1, spatial2, ...]``.

        Returns:
            torch.Tensor: ``x`` filtered by Savitsky-Golay kernel with window length ``self.window_length`` using
            polynomials of order ``self.order``, along axis specified in ``self.axis``.

        Raises:
            ValueError: When ``x`` is complex or ``self.axis`` is out of range for ``x``.
        """
        # Make the input a real float tensor (kept on its original device).
        x = torch.as_tensor(x, device=x.device if torch.is_tensor(x) else None)
        if torch.is_complex(x):
            raise ValueError("x must be real.")
        x = x.to(dtype=torch.float)

        if (self.axis < 0) or (self.axis > len(x.shape) - 1):
            raise ValueError("Invalid axis for shape of x.")

        # Create list of filter kernels (1 per spatial dimension). The kernel for self.axis will be the savgol coeffs,
        # while the other kernels will be set to [1] (identity).
        n_spatial_dims = len(x.shape) - 2
        spatial_processing_axis = self.axis - 2  # axis counted among spatial dims only
        new_dims_before = spatial_processing_axis
        new_dims_after = n_spatial_dims - spatial_processing_axis - 1
        kernel_list = [self.coeffs.to(device=x.device, dtype=x.dtype)]
        for _ in range(new_dims_before):
            kernel_list.insert(0, torch.ones(1, device=x.device, dtype=x.dtype))
        for _ in range(new_dims_after):
            kernel_list.append(torch.ones(1, device=x.device, dtype=x.dtype))

        return separable_filtering(x, kernel_list, mode=self.mode)

    @staticmethod
    def _make_coeffs(window_length, order):
        """Compute the 1-D Savitsky-Golay filter taps for the given window length and polynomial order."""
        half_length, rem = divmod(window_length, 2)
        if rem == 0:
            raise ValueError("window_length must be odd.")

        # Sample positions within the window, centre-aligned, in descending order.
        idx = torch.arange(window_length - half_length - 1, -half_length - 1, -1, dtype=torch.float, device="cpu")
        # Vandermonde-style design matrix: row i holds idx ** i for i = 0..order.
        a = idx ** torch.arange(order + 1, dtype=torch.float, device="cpu").reshape(-1, 1)
        # Constrain the fitted polynomial to reproduce the window centre value.
        y = torch.zeros(order + 1, dtype=torch.float, device="cpu")
        y[0] = 1.0
        # torch.lstsq was deprecated and removed from PyTorch; torch.linalg.lstsq
        # returns the same minimum-norm least-squares solution for this
        # underdetermined system (default CPU driver handles m < n).
        return torch.linalg.lstsq(a, y.unsqueeze(-1)).solution.squeeze()


class HilbertTransform(nn.Module):
"""
Determine the analytical signal of a Tensor along a particular axis.
Expand Down

0 comments on commit d6ad24a

Please sign in to comment.