Add padding=same, support half-precision input #18

Merged (3 commits) on Sep 28, 2023
50 changes: 33 additions & 17 deletions fft_conv_pytorch/fft_conv.py
@@ -5,6 +5,7 @@
import torch.nn.functional as f
from torch import Tensor, nn
from torch.fft import irfftn, rfftn
+from math import ceil, floor


def complex_matmul(a: Tensor, b: Tensor, groups: int = 1) -> Tensor:
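
Note: the body of complex_matmul is collapsed in this diff. For orientation, it multiplies the signal and kernel spectra per frequency bin, contracting over input channels (with optional groups). A minimal ungrouped sketch of the idea, not the repo's exact code:

```python
import torch
from torch import Tensor

def complex_matmul_sketch(a: Tensor, b: Tensor) -> Tensor:
    """(B, C_in, ...freq) x (C_out, C_in, ...freq) -> (B, C_out, ...freq)."""
    # einsum contracts the channel dim and broadcasts over the frequency dims.
    return torch.einsum("bi...,oi...->bo...", a, b)
```
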
@@ -55,7 +56,7 @@ def fft_conv(
signal: Tensor,
kernel: Tensor,
bias: Tensor = None,
-padding: Union[int, Iterable[int]] = 0,
+padding: Union[int, Iterable[int], str] = 0,
padding_mode: str = "constant",
stride: Union[int, Iterable[int]] = 1,
dilation: Union[int, Iterable[int]] = 1,
@@ -69,19 +70,31 @@
signal: (Tensor) Input tensor to be convolved with the kernel.
kernel: (Tensor) Convolution kernel.
bias: (Tensor) Bias tensor to add to the output.
-padding: (Union[int, Iterable[int]) Number of zero samples to pad the
-    input on the last dimension.
+padding: (Union[int, Iterable[int], str]) If int, the number of zero samples to pad the
+    input on the last dimension. If str, only "same" is supported, which pads the
+    input so that the output size matches the input size.
padding_mode: (str) Padding mode to use from {constant, reflection, replication}.
    Reflection is not available for 3d.
stride: (Union[int, Iterable[int]]) Stride size for computing output values.
dilation: (Union[int, Iterable[int]]) Dilation rate for the kernel.
groups: (int) Number of groups for the convolution.

Returns:
(Tensor) Convolved tensor
"""

# Cast padding, stride & dilation to tuples.
n = signal.ndim - 2
-padding_ = to_ntuple(padding, n=n)
stride_ = to_ntuple(stride, n=n)
dilation_ = to_ntuple(dilation, n=n)
+if isinstance(padding, str):
+    if padding == "same":
+        if stride != 1 or dilation != 1:
+            raise ValueError("stride and dilation must be 1 for padding='same'.")
+        padding_ = [(k - 1) / 2 for k in kernel.shape[2:]]
+    else:
+        raise ValueError(f"Padding mode {padding} not supported.")
+else:
+    padding_ = to_ntuple(padding, n=n)
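
Note: for odd kernels (k - 1) / 2 is a whole number, while for even kernels it is fractional; the fractional amount is split into floor/ceil halves when the signal is padded below. A quick sanity check of the values (illustrative, not part of the diff):

```python
for k in (3, 4, 5):
    print(k, (k - 1) / 2)  # 3 -> 1.0, 4 -> 1.5, 5 -> 2.0
```
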

# internal dilation offsets
offset = torch.zeros(1, 1, *dilation_, device=signal.device, dtype=signal.dtype)
@@ -93,36 +106,34 @@
# pad the kernel internally according to the dilation parameters
kernel = torch.kron(kernel, offset)[(slice(None), slice(None)) + cutoff]
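
Note: the torch.kron trick above inserts dilation - 1 zeros between kernel taps, so the FFT path needs no dedicated dilated-convolution code. A standalone sketch of the same idea on toy tensors (not the diff's exact cutoff logic):

```python
import torch

kernel = torch.tensor([[[1.0, 2.0, 3.0]]])  # (out_ch, in_ch, k)
offset = torch.zeros(1, 1, 2)               # dilation = 2
offset[..., 0] = 1.0
dilated = torch.kron(kernel, offset)[..., :-1]  # trim trailing zeros
print(dilated)  # tensor([[[1., 0., 2., 0., 3.]]])
```
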

-# Pad the input signal & kernel tensors
-signal_padding = [p for p in padding_[::-1] for _ in range(2)]
+# Pad the input signal & kernel tensors (round to support even sized convolutions)
+signal_padding = [r(p) for p in padding_[::-1] for r in (floor, ceil)]
signal = f.pad(signal, signal_padding, mode=padding_mode)
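
Note: for fractional padding_ entries this produces an asymmetric pad (the extra sample goes on the right), and the list is ordered last-dimension-first as F.pad expects. A worked example, assuming a 2D signal with kernel sizes (4, 3):

```python
from math import ceil, floor

padding_ = [1.5, 1.0]  # (k - 1) / 2 for kernel sizes (4, 3)
signal_padding = [r(p) for p in padding_[::-1] for r in (floor, ceil)]
print(signal_padding)  # [1, 1, 1, 2]: (left, right) for dim -1, then dim -2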

# Because PyTorch computes a *one-sided* FFT, we need the final dimension to
# have *even* length. Just pad with one more zero if the final dimension is odd.
+signal_size = signal.size()  # original signal size without padding to even
if signal.size(-1) % 2 != 0:
-    signal_ = f.pad(signal, [0, 1])
-else:
-    signal_ = signal
+    signal = f.pad(signal, [0, 1])
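
Note: the even-length requirement comes from the one-sided transform: irfftn reconstructs a default length of 2 * (m - 1) from a spectrum of length m, which only round-trips exactly when the original length is even. A small demonstration:

```python
import torch
import torch.nn.functional as f
from torch.fft import irfftn, rfftn

x = torch.randn(1, 1, 7)                    # odd final dim
bad = irfftn(rfftn(x, dim=(2,)), dim=(2,))  # comes back with size 6, not 7
good = irfftn(rfftn(f.pad(x, [0, 1]), dim=(2,)), dim=(2,))
print(bad.shape, good.shape)  # (1, 1, 6) vs. (1, 1, 8)
```
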

kernel_padding = [
pad
-for i in reversed(range(2, signal_.ndim))
-for pad in [0, signal_.size(i) - kernel.size(i)]
+for i in reversed(range(2, signal.ndim))
+for pad in [0, signal.size(i) - kernel.size(i)]
]
padded_kernel = f.pad(kernel, kernel_padding)
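
Note: the kernel is zero-padded on the right of every spatial dim up to the padded signal size, so both tensors share one FFT grid. With a hypothetical padded signal of shape (1, 1, 10, 8) and kernel of shape (1, 1, 3, 3):

```python
signal_shape = (1, 1, 10, 8)  # hypothetical shapes, not from the diff
kernel_shape = (1, 1, 3, 3)
kernel_padding = [
    pad
    for i in reversed(range(2, len(signal_shape)))
    for pad in [0, signal_shape[i] - kernel_shape[i]]
]
print(kernel_padding)  # [0, 5, 0, 7]: dim -1 grows to 8, dim -2 to 10
```
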

# Perform Fourier convolution -- FFT, matrix multiply, then IFFT
-# signal_ = signal_.reshape(signal_.size(0), groups, -1, *signal_.shape[2:])
-signal_fr = rfftn(signal_, dim=tuple(range(2, signal.ndim)))
-kernel_fr = rfftn(padded_kernel, dim=tuple(range(2, signal.ndim)))
+signal_fr = rfftn(signal.float(), dim=tuple(range(2, signal.ndim)))
+kernel_fr = rfftn(padded_kernel.float(), dim=tuple(range(2, signal.ndim)))

kernel_fr.imag *= -1
output_fr = complex_matmul(signal_fr, kernel_fr, groups=groups)
output = irfftn(output_fr, dim=tuple(range(2, signal.ndim)))
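
Note: two details here. The .float() casts added above are what enable half-precision input, since FFT support for float16 is limited; presumably the result is cast back to the input dtype in the collapsed lines below (not visible in this diff). And negating the imaginary part conjugates the kernel spectrum, turning the FFT's circular convolution into the cross-correlation that torch's convNd actually computes. A self-contained 1D sketch of this core, simplified to a single channel with no groups or bias:

```python
import torch
import torch.nn.functional as F
from torch.fft import irfft, rfft

def fft_xcorr1d(x: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
    """Valid cross-correlation of (B, 1, N) x with (1, 1, K) w via FFT."""
    n = x.size(-1) + (x.size(-1) % 2)  # even length for the one-sided FFT
    xf = rfft(x.float(), n=n)          # run the transform in float32
    wf = rfft(w.float(), n=n).conj()   # conj <=> kernel_fr.imag *= -1
    out = irfft(xf * wf, n=n)
    out = out[..., : x.size(-1) - w.size(-1) + 1]  # crop to the valid region
    return out.to(x.dtype)             # cast back to the input dtype

x = torch.randn(2, 1, 9, dtype=torch.half)
w = torch.randn(1, 1, 3, dtype=torch.half)
ref = F.conv1d(x.float(), w.float())
assert torch.allclose(fft_xcorr1d(x, w).float(), ref, atol=1e-2)
```
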

# Remove extra padded values
-crop_slices = [slice(0, output.size(0)), slice(0, output.size(1))] + [
-    slice(0, (signal.size(i) - kernel.size(i) + 1), stride_[i - 2])
+crop_slices = [slice(None), slice(None)] + [
+    slice(0, (signal_size[i] - kernel.size(i) + 1), stride_[i - 2])
for i in range(2, signal.ndim)
]
output = output[crop_slices].contiguous()
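
Note: stride is applied after the fact: the FFT produces the dense stride-1 result, and the slice keeps every stride-th sample of the valid region, now measured against signal_size so the extra even-length column never leaks into the output. The 1D equivalent, with illustrative numbers:

```python
import torch

dense = torch.arange(10.0)   # stand-in for a stride-1 valid output
n, k, stride = 12, 3, 2      # padded signal length, kernel size, stride
print(dense[slice(0, n - k + 1, stride)])  # tensor([0., 2., 4., 6., 8.])
```
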
@@ -157,9 +168,14 @@ def __init__(
out_channels: (int) Number of channels in output tensors
kernel_size: (Union[int, Iterable[int]]) Size of the convolution kernel
padding: (Union[int, Iterable[int], str]) Number of zero samples to pad the
-    input on the last dimension.
+    input on the last dimension. If str, only "same" is supported, which pads the
+    input so that the output size matches the input size.
padding_mode: (str) Padding mode to use from {constant, reflection, replication}.
    Reflection is not available for 3d.
stride: (Union[int, Iterable[int]]) Stride size for computing output values.
dilation: (Union[int, Iterable[int]]) Dilation rate for the kernel.
groups: (int) Number of groups for the convolution.
bias: (bool) If True, includes a bias term, which is added after convolution.
ndim: (int) Number of dimensions of the input tensor.
"""
super().__init__()
self.in_channels = in_channels
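
Note: taken together, the changes let both the functional and module paths accept padding="same", and the functional path a half-precision signal. A usage sketch, assuming the FFTConv1d and fft_conv names exported by fft_conv_pytorch:

```python
import torch
from fft_conv_pytorch import FFTConv1d, fft_conv

x = torch.randn(2, 3, 128)
conv = FFTConv1d(3, 8, kernel_size=5, padding="same", bias=True)
print(conv(x).shape)  # torch.Size([2, 8, 128]): spatial size preserved

# Functional form with a half-precision signal and kernel:
x_half = x.half()
w_half = torch.randn(8, 3, 5, dtype=torch.half)
print(fft_conv(x_half, w_half, padding="same").shape)  # torch.Size([2, 8, 128])
```
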
12 changes: 10 additions & 2 deletions tests/test_functional.py
@@ -12,7 +12,7 @@
@pytest.mark.parametrize("out_channels", [2, 3])
@pytest.mark.parametrize("groups", [1, 2, 3])
@pytest.mark.parametrize("kernel_size", [2, 3])
@pytest.mark.parametrize("padding", [0, 1])
@pytest.mark.parametrize("padding", [0, 1, "same"])
@pytest.mark.parametrize("stride", [1, 2])
@pytest.mark.parametrize("dilation", [1, 2])
@pytest.mark.parametrize("bias", [True])
@@ -30,6 +30,10 @@ def test_fft_conv_functional(
ndim: int,
input_size: int,
):
if padding == "same" and (stride != 1 or dilation != 1):
# padding='same' is not compatible with strided convolutions
return

torch_conv = getattr(f, f"conv{ndim}d")
groups = _gcd(in_channels, _gcd(out_channels, groups))

@@ -70,7 +74,7 @@
@pytest.mark.parametrize("out_channels", [2, 3])
@pytest.mark.parametrize("groups", [1, 2, 3])
@pytest.mark.parametrize("kernel_size", [2, 3])
@pytest.mark.parametrize("padding", [0, 1])
@pytest.mark.parametrize("padding", [0, 1, "same"])
@pytest.mark.parametrize("stride", [1, 2])
@pytest.mark.parametrize("dilation", [1, 2])
@pytest.mark.parametrize("bias", [True])
@@ -88,6 +92,10 @@ def test_fft_conv_backward_functional(
ndim: int,
input_size: int,
):
if padding == "same" and (stride != 1 or dilation != 1):
# padding='same' is not compatible with strided convolutions
return

torch_conv = getattr(f, f"conv{ndim}d")
groups = _gcd(in_channels, _gcd(out_channels, groups))

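
Note: the test bodies are collapsed above; following the file's existing pattern, they presumably run fft_conv and the corresponding torch convolution on identical inputs and compare the results. A hedged reconstruction of that comparison (shapes and tolerance are assumptions):

```python
import torch
import torch.nn.functional as f
from fft_conv_pytorch import fft_conv

signal = torch.randn(2, 3, 32)
weight = torch.randn(2, 3, 3)
bias = torch.randn(2)

y_fft = fft_conv(signal, weight, bias=bias, padding="same")
y_ref = f.conv1d(signal, weight, bias=bias, padding="same")
assert torch.allclose(y_fft, y_ref, atol=1e-5)
```
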
8 changes: 8 additions & 0 deletions tests/test_module.py
@@ -30,6 +30,10 @@ def test_fft_conv_module(
ndim: int,
input_size: int,
):
if padding == "same" and (stride != 1 or dilation != 1):
# padding='same' is not compatible with strided convolutions
return

torch_conv = getattr(f, f"conv{ndim}d")
groups = _gcd(in_channels, _gcd(out_channels, groups))
fft_conv_layer = _FFTConv(
@@ -85,6 +89,10 @@ def test_fft_conv_backward_module(
ndim: int,
input_size: int,
):
if padding == "same" and (stride != 1 or dilation != 1):
# padding='same' is not compatible with strided convolutions
return

torch_conv = getattr(f, f"conv{ndim}d")
groups = _gcd(in_channels, _gcd(out_channels, groups))
fft_conv_layer = _FFTConv(