Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implementation of dpnp.fft.fft2, dpnp.fft.ifft2, dpnp.fft.fftn, dpnp.fft.ifftn #1961

Merged
merged 17 commits into from
Aug 12, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
121 changes: 72 additions & 49 deletions dpnp/fft/dpnp_utils_fft.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,61 +242,78 @@ def _copy_array(x, complex_input):
return x, copy_flag


def _extract_axes_chunk(a, chunk_size=3):
def _extract_axes_chunk(a, s, chunk_size=3):
"""
Classify input into a list of list with each list containing
only unique values and its length is at most `chunk_size`.
Classify the first input into a list of lists with each list containing
only unique values in reverse order and its length is at most `chunk_size`.
The second input is also classified into a list of lists with each list
containing the corresponding values of the first input.

Parameters
----------
a : list, tuple
Input.
a : list or tuple of ints
The first input.
s : list or tuple of ints
The second input.
chunk_size : int
Maximum number of elements in each chunk.

Returns
------
out : list of lists
List of lists with each list containing only unique values
and its length is at most `chunk_size`.
The final list is returned in reverse order.
out : a tuple of two lists
The first element of output is a list of lists with each list
containing only unique values in reverse order and its length is
at most `chunk_size`.
The second element of output is a list of lists with each list
containing the corresponding values of the first input.

Examples
--------
>>> axes = (0, 1, 2, 3, 4)
>>> _extract_axes_chunk(axes, chunk_size=3)
[[2, 3, 4], [0, 1]]
>>> shape = (7, 8, 10, 9, 5)
>>> _extract_axes_chunk(axes, shape, chunk_size=3)
([[4, 3], [2, 1, 0]], [[5, 9], [10, 8, 7]])

>>> axes = (0, 1, 2, 3, 4, 4)
>>> _extract_axes_chunk(axes, chunk_size=3)
[[4], [2, 3, 4], [0, 1]]
>>> axes = (1, 0, 3, 2, 4, 4)
>>> shape = (7, 8, 10, 5, 7, 6)
>>> _extract_axes_chunk(axes, shape, chunk_size=3)
([[4], [4, 2], [3, 0, 1]], [[6], [7, 5], [10, 8, 7]])

"""

chunks = []
current_chunk = []
a_chunks = []
a_current_chunk = []
seen_elements = set()

for elem in a:
if elem in seen_elements:
s_chunks = []
s_current_chunk = []

for a_elem, s_elem in zip(a, s):
if a_elem in seen_elements:
# If element is already seen, start a new chunk
chunks.append(current_chunk)
current_chunk = [elem]
seen_elements = {elem}
a_chunks.append(a_current_chunk[::-1])
s_chunks.append(s_current_chunk[::-1])
a_current_chunk = [a_elem]
s_current_chunk = [s_elem]
seen_elements = {a_elem}
else:
current_chunk.append(elem)
seen_elements.add(elem)

if len(current_chunk) == chunk_size:
chunks.append(current_chunk)
current_chunk = []
a_current_chunk.append(a_elem)
s_current_chunk.append(s_elem)
seen_elements.add(a_elem)

if len(a_current_chunk) == chunk_size:
a_chunks.append(a_current_chunk[::-1])
s_chunks.append(s_current_chunk[::-1])
a_current_chunk = []
s_current_chunk = []
seen_elements = set()

# Add the last chunk if it's not empty
if current_chunk:
chunks.append(current_chunk)
if a_current_chunk:
a_chunks.append(a_current_chunk[::-1])
s_chunks.append(s_current_chunk[::-1])

return chunks[::-1]
return a_chunks[::-1], s_chunks[::-1]


def _fft(a, norm, out, forward, in_place, c2c, axes=None):
Expand Down Expand Up @@ -392,7 +409,7 @@ def _truncate_or_pad(a, shape, axes):
return a


def _validate_out_keyword(a, out, axis, c2r, r2c):
def _validate_out_keyword(a, out, s, axes, c2r, r2c):
"""Validate out keyword argument."""
if out is not None:
dpnp.check_supported_arrays_type(out)
Expand All @@ -404,16 +421,18 @@ def _validate_out_keyword(a, out, axis, c2r, r2c):
"Input and output allocation queues are not compatible"
)

# validate out shape
expected_shape = a.shape
# validate out shape against the final shape,
# intermediate shapes may vary
expected_shape = list(a.shape)
for s_i, axis in zip(s[::-1], axes[::-1]):
expected_shape[axis] = s_i
if r2c:
expected_shape = list(a.shape)
expected_shape[axis] = a.shape[axis] // 2 + 1
expected_shape = tuple(expected_shape)
if out.shape != expected_shape:
expected_shape[axes[-1]] = expected_shape[axes[-1]] // 2 + 1

if out.shape != tuple(expected_shape):
raise ValueError(
"output array has incorrect shape, expected "
f"{expected_shape}, got {out.shape}."
f"{tuple(expected_shape)}, got {out.shape}."
)

# validate out data type
Expand Down Expand Up @@ -477,7 +496,7 @@ def dpnp_fft(a, forward, real, n=None, axis=-1, norm=None, out=None):

_check_norm(norm)
a = _truncate_or_pad(a, n, axis)
_validate_out_keyword(a, out, axis, c2r, r2c)
_validate_out_keyword(a, out, (n,), (axis,), c2r, r2c)
# if input array is copied, in-place FFT can be used
a, in_place = _copy_array(a, c2c or c2r)
if not in_place and out is not None:
Expand Down Expand Up @@ -519,36 +538,40 @@ def dpnp_fftn(a, forward, s=None, axes=None, norm=None, out=None):

_validate_s_axes(a, s, axes)
s, axes = _cook_nd_args(a, s, axes)
a = _truncate_or_pad(a, s, axes)
# TODO: None, False, False are place holder for future development of
# TODO: False and False are placeholders for future development of
# rfft2, irfft2, rfftn, irfftn
_validate_out_keyword(a, out, None, False, False)
_validate_out_keyword(a, out, s, axes, False, False)
# TODO: True is a placeholder for future development of
# rfft2, irfft2, rfftn, irfftn
a, in_place = _copy_array(a, True)

if a.size == 0:
return dpnp.get_result_array(a, out=out, casting="same_kind")

len_axes = len(axes)
# OneMKL supports up to 3-dimensional FFT on GPU
# repeated axis in OneMKL FFT is not allowed
if len_axes > 3 or len(set(axes)) < len_axes:
axes_chunk = _extract_axes_chunk(axes, chunk_size=3)
for chunk in axes_chunk:
axes_chunk, shape_chunk = _extract_axes_chunk(axes, s, chunk_size=3)
for s_chunk, a_chunk in zip(shape_chunk, axes_chunk):
a = _truncate_or_pad(a, shape=s_chunk, axes=a_chunk)
if out is not None and out.shape == a.shape:
tmp_out = out
else:
tmp_out = None
a = _fft(
a,
norm=norm,
out=out,
out=tmp_out,
forward=forward,
in_place=in_place,
# TODO: c2c=True is a placeholder for future development of
# rfft2, irfft2, rfftn, irfftn
c2c=True,
axes=chunk,
axes=a_chunk,
)
return a

a = _truncate_or_pad(a, s, axes)
if a.size == 0:
return dpnp.get_result_array(a, out=out, casting="same_kind")
if a.ndim == len_axes:
# non-batch FFT
axes = None
Expand Down
51 changes: 44 additions & 7 deletions tests/test_fft.py
Original file line number Diff line number Diff line change
Expand Up @@ -486,7 +486,7 @@ def test_fftn_repeated_axes(self, axes):
assert_dtype_allclose(iresult, iexpected, check_only_type_kind=True)

@pytest.mark.parametrize("axes", [(2, 3, 3, 2), (0, 0, 3, 3)])
@pytest.mark.parametrize("s", [(5, 4, 3, 3), (7, 8, 10, 7)])
@pytest.mark.parametrize("s", [(5, 4, 3, 3), (7, 8, 10, 9)])
def test_fftn_repeated_axes_with_s(self, axes, s):
x1 = numpy.random.uniform(-10, 10, 120)
x2 = numpy.random.uniform(-10, 10, 120)
Expand All @@ -495,19 +495,56 @@ def test_fftn_repeated_axes_with_s(self, axes, s):
)
a = dpnp.asarray(a_np)

result = dpnp.fft.fftn(a, axes=axes)
result = dpnp.fft.fftn(a, s=s, axes=axes)
# Intel® NumPy ignores repeated axes, handle it one by one
expected = a_np
for ii in axes:
expected = numpy.fft.fft(expected, axis=ii)
for jj, ii in zip(s[::-1], axes[::-1]):
expected = numpy.fft.fft(expected, n=jj, axis=ii)
assert_dtype_allclose(result, expected, check_only_type_kind=True)

iresult = dpnp.fft.ifftn(result, axes=axes)
iresult = dpnp.fft.ifftn(result, s=s, axes=axes)
iexpected = expected
for ii in axes:
iexpected = numpy.fft.ifft(iexpected, axis=ii)
for jj, ii in zip(s[::-1], axes[::-1]):
iexpected = numpy.fft.ifft(iexpected, n=jj, axis=ii)
assert_dtype_allclose(iresult, iexpected, check_only_type_kind=True)

@pytest.mark.parametrize("axes", [(0, 1, 2, 3), (1, 2, 1, 2), (2, 2, 2, 3)])
@pytest.mark.parametrize("s", [(2, 3, 4, 5), (5, 4, 7, 8), (2, 5, 1, 2)])
def test_fftn_out(self, axes, s):
x1 = numpy.random.uniform(-10, 10, 120)
x2 = numpy.random.uniform(-10, 10, 120)
a_np = numpy.array(x1 + 1j * x2, dtype=numpy.complex64).reshape(
2, 3, 4, 5
)
a = dpnp.asarray(a_np)

out_shape = list(a.shape)
for s_i, axis in zip(s[::-1], axes[::-1]):
out_shape[axis] = s_i
result = dpnp.empty(out_shape, dtype=a.dtype)
dpnp.fft.fftn(a, out=result, s=s, axes=axes)
# Intel® NumPy ignores repeated axes, handle it one by one
expected = a_np
for jj, ii in zip(s[::-1], axes[::-1]):
expected = numpy.fft.fft(expected, n=jj, axis=ii)
assert_dtype_allclose(result, expected, check_only_type_kind=True)

iresult = dpnp.empty(out_shape, dtype=a.dtype)
dpnp.fft.ifftn(result, out=iresult, s=s, axes=axes)
iexpected = expected
for jj, ii in zip(s[::-1], axes[::-1]):
iexpected = numpy.fft.ifft(iexpected, n=jj, axis=ii)
assert_dtype_allclose(iresult, iexpected, check_only_type_kind=True)

def test_negative_s(self):
# In stock NumPy 2.0, if s is -1, the whole input is used (no padding/trimming).
a_np = numpy.empty((3, 4, 5), dtype=numpy.complex64)
a = dpnp.array(a_np)

result = dpnp.fft.fftn(a, s=(-1, -1), axes=(0, 2))
expected = numpy.fft.fftn(a_np, s=(3, 5), axes=(0, 2))
assert_dtype_allclose(result, expected, check_only_type_kind=True)

def test_fftn_empty_array(self):
a_np = numpy.empty((10, 0, 4), dtype=numpy.complex64)
a = dpnp.array(a_np)
Expand Down
2 changes: 1 addition & 1 deletion tests/third_party/cupy/fft_tests/test_fft.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,7 +216,7 @@ def test_ifft2(self, xp, dtype, order):
{"shape": (3, 4), "s": (1, 5), "axes": (0, 1)},
{"shape": (3, 4), "s": None, "axes": (-2, -1)},
{"shape": (3, 4), "s": None, "axes": (-1, -2)},
{"shape": (3, 4), "s": None, "axes": [-1, -2]},
{"shape": (3, 4), "s": None, "axes": (-1, -2)},
vtavana marked this conversation as resolved.
Show resolved Hide resolved
# {"shape": (3, 4), "s": None, "axes": (0,)}, # mkl_fft gh-109
# {"shape": (3, 4), "s": None, "axes": ()}, # mkl_fft gh-108
{"shape": (3, 4), "s": None, "axes": None},
Expand Down
Loading