Skip to content

Commit

Permalink
Add tests for inplace plans
Browse files Browse the repository at this point in the history
  • Loading branch information
gaurav-arya committed Aug 31, 2022
1 parent 315b9ae commit ae78179
Show file tree
Hide file tree
Showing 2 changed files with 84 additions and 59 deletions.
2 changes: 2 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@ version = "1.2.1"

[deps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Debugger = "31a5f54b-26ea-5ae9-a837-f05ce5417438"
FFTW = "7a1cc6ca-52ef-59f5-83cd-3a7055c09341"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Unitful = "1986cc42-f94f-5a68-af5c-568840ba703d"
Expand Down
141 changes: 82 additions & 59 deletions src/TestUtils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,18 +11,22 @@ using LinearAlgebra
using Test

"""
test_fft_backend(array_constructor)
test_fft_backend(array_constructor; test_real=true, test_inplace=true)
Run tests to verify correctness of all FFT functions based on a particular
Run tests to verify correctness of FFT functions using a particular
backend plan implementation. The backend implementation is assumed to be loaded
prior to calling this function.
The input `array_constructor` determines the `AbstractArray` implementation for
# Arguments
- `array_constructor`: determines the `AbstractArray` implementation for
which the correctness tests are run. It is assumed to be a callable object that
takes in input arrays of type `Array` and return arrays of the desired type for
testing: this would most commonly be a constructor such as `Array` or `CuArray`.
- `test_real=true`: whether to test real-to-complex and complex-to-real FFTs.
- `test_inplace=true`: whether to test in-place plans.
"""
function test_fft_backend(array_constructor)
function test_fft_backend(array_constructor; test_real=true, test_inplace=true)
@testset "fft correctness" begin
# DFT along last dimension, results computed using FFTW
for (_x, _fftw_fft) in (
Expand Down Expand Up @@ -51,82 +55,101 @@ function test_fft_backend(array_constructor)
15.0+0.0im -4.5+2.598076211353316im -4.5-2.598076211353316im;
18.0+0.0im -4.5+2.598076211353316im -4.5-2.598076211353316im]),
)
x = array_constructor(_x)
xcopy_float = array_constructor(copy(float.(x)))
xcopy_complex = array_constructor(copy(complex.(xcopy_float)))
x = array_constructor(_x) # dummy array that will be passed to plans
x_real = float.(x) # for testing real FFTs
x_complex = complex.(x_real) # for testing complex FFTs
fftw_fft = array_constructor(_fftw_fft)

dims = ndims(x) # TODO: this is a single dimension, should check multidimensional FFTs too

# FFT
dims = ndims(x)
y = AbstractFFTs.fft(x, dims)
ycopy = array_constructor(copy(y))
y = AbstractFFTs.fft(x_complex, dims)
@test y fftw_fft
test_inplace && (@test AbstractFFTs.fft!(copy(x_complex), dims) fftw_fft)
# test plan_fft and also inv and plan_inv of plan_ifft, which should all give
# functionally identical plans
for P in [plan_fft(x, dims), inv(plan_ifft(x, dims)),
AbstractFFTs.plan_inv(plan_ifft(x, dims))]
plans_to_test = [plan_fft(x, dims), inv(plan_ifft(x, dims)),
AbstractFFTs.plan_inv(plan_ifft(x, dims))]
for P in plans_to_test
@test mul!(similar(y), P, x_complex) fftw_fft
end
test_inplace && (plans_to_test = vcat(plans_to_test, plan_fft!(similar(x_complex), dims)))
for P in plans_to_test
@test eltype(P) <: Complex
@test P * x fftw_fft
@test mul!(ycopy, P, x) fftw_fft
@test P \ (P * x) x
@test P * copy(x_complex) fftw_fft
@test P \ (P * copy(x_complex)) x_complex
@test fftdims(P) == dims
end

# BFFT
fftw_bfft = complex.(size(x, dims) .* x)
fftw_bfft = size(x_complex, dims) .* x_complex
@test AbstractFFTs.bfft(y, dims) fftw_bfft
P = plan_bfft(x, dims)
@test P * y fftw_bfft
@test P \ (P * y) y
@test mul!(xcopy_complex, P, y) fftw_bfft
@test fftdims(P) == dims
test_inplace && (@test AbstractFFTs.bfft!(copy(y), dims) fftw_bfft)
plans_to_test = [plan_bfft(similar(y), dims)]
for P in plans_to_test
@test mul!(similar(x_complex), P, y) fftw_bfft
end
test_inplace && (plans_to_test = vcat(plans_to_test, plan_bfft!(similar(y), dims)))
for P in plans_to_test
@test eltype(P) <: Complex
@test P * copy(y) fftw_bfft
@test P \ (P * copy(y)) y
@test fftdims(P) == dims
end

# IFFT
fftw_ifft = complex.(x)
fftw_ifft = x_complex
@test AbstractFFTs.ifft(y, dims) fftw_ifft
for P in [plan_ifft(x, dims), inv(plan_fft(x, dims)),
AbstractFFTs.plan_inv(plan_fft(x, dims))]
@test P * y fftw_ifft
@test mul!(xcopy_complex, P, y) fftw_ifft
@test P \ (P * y) y
@test fftdims(P) == dims
test_inplace && (@test AbstractFFTs.ifft!(copy(y), dims) fftw_ifft)
plans_to_test = [plan_ifft(x, dims), inv(plan_fft(x, dims)),
AbstractFFTs.plan_inv(plan_fft(x, dims))]
for P in plans_to_test
@test mul!(similar(x_complex), P, y) fftw_ifft
end

# RFFT
fftw_rfft = fftw_fft[
(Colon() for _ in 1:(ndims(fftw_fft) - 1))...,
1:(size(fftw_fft, ndims(fftw_fft)) ÷ 2 + 1)
]
ry = AbstractFFTs.rfft(x, dims)
rycopy = array_constructor(copy(ry))
@test ry fftw_rfft
for P in [plan_rfft(x, dims), inv(plan_irfft(ry, size(x, dims), dims)),
AbstractFFTs.plan_inv(plan_irfft(ry, size(x, dims), dims))]
@test eltype(P) <: Real
@test P * x fftw_rfft
@test mul!(rycopy, P, x) fftw_rfft
@test P \ (P * x) x
test_inplace && (plan_to_test = vcat(plans_to_test, plan_ifft!(similar(x_complex), dims)))
for P in plans_to_test
@test eltype(P) <: Complex
@test P * copy(y) fftw_ifft
@test P \ (P * copy(y)) y
@test fftdims(P) == dims
end

# BRFFT
fftw_brfft = complex.(size(x, dims) .* x)
@test AbstractFFTs.brfft(ry, size(x, dims), dims) fftw_brfft
P = plan_brfft(ry, size(x, dims), dims)
@test P * ry fftw_brfft
@test mul!(xcopy_float, P, ry) fftw_brfft
@test P \ (P * ry) ry
@test fftdims(P) == dims

# IRFFT
fftw_irfft = complex.(x)
@test AbstractFFTs.irfft(ry, size(x, dims), dims) fftw_irfft
for P in [plan_irfft(ry, size(x, dims), dims), inv(plan_rfft(x, dims)),
AbstractFFTs.plan_inv(plan_rfft(x, dims))]
@test P * ry fftw_irfft
@test mul!(xcopy_float, P, ry) fftw_irfft
if test_real
# RFFT
fftw_rfft = fftw_fft[
(Colon() for _ in 1:(ndims(fftw_fft) - 1))...,
1:(size(fftw_fft, ndims(fftw_fft)) ÷ 2 + 1)
]
ry = AbstractFFTs.rfft(x_real, dims)
@test ry fftw_rfft
for P in [plan_rfft(x_real, dims), inv(plan_irfft(ry, size(x, dims), dims)),
AbstractFFTs.plan_inv(plan_irfft(ry, size(x, dims), dims))]
@test eltype(P) <: Real
@test P * x_real fftw_rfft
@test mul!(similar(ry), P, x_real) fftw_rfft
@test P \ (P * x_real) x_real
@test fftdims(P) == dims
end

# BRFFT
fftw_brfft = complex.(size(x, dims) .* x_real)
@test AbstractFFTs.brfft(ry, size(x_real, dims), dims) fftw_brfft
P = plan_brfft(ry, size(x_real, dims), dims)
@test P * ry fftw_brfft
@test mul!(similar(x_real), P, ry) fftw_brfft
@test P \ (P * ry) ry
@test fftdims(P) == dims

# IRFFT
fftw_irfft = x_complex
@test AbstractFFTs.irfft(ry, size(x, dims), dims) fftw_irfft
for P in [plan_irfft(ry, size(x, dims), dims), inv(plan_rfft(x_real, dims)),
AbstractFFTs.plan_inv(plan_rfft(x_real, dims))]
@test P * ry fftw_irfft
@test mul!(similar(x_real), P, ry) fftw_irfft
@test P \ (P * ry) ry
@test fftdims(P) == dims
end
end
end
end
Expand Down

0 comments on commit ae78179

Please sign in to comment.