diff --git a/Project.toml b/Project.toml index 7626cd4..22d23ba 100644 --- a/Project.toml +++ b/Project.toml @@ -4,6 +4,8 @@ version = "1.2.1" [deps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" +Debugger = "31a5f54b-26ea-5ae9-a837-f05ce5417438" +FFTW = "7a1cc6ca-52ef-59f5-83cd-3a7055c09341" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" Unitful = "1986cc42-f94f-5a68-af5c-568840ba703d" diff --git a/src/TestUtils.jl b/src/TestUtils.jl index fecfb52..1ad6b77 100644 --- a/src/TestUtils.jl +++ b/src/TestUtils.jl @@ -11,18 +11,22 @@ using LinearAlgebra using Test """ - test_fft_backend(array_constructor) + test_fft_backend(array_constructor; test_real=true, test_inplace=true) -Run tests to verify correctness of all FFT functions based on a particular +Run tests to verify correctness of FFT functions using a particular backend plan implementation. The backend implementation is assumed to be loaded prior to calling this function. -The input `array_constructor` determines the `AbstractArray` implementation for +# Arguments + +- `array_constructor`: determines the `AbstractArray` implementation for which the correctness tests are run. It is assumed to be a callable object that takes in input arrays of type `Array` and return arrays of the desired type for testing: this would most commonly be a constructor such as `Array` or `CuArray`. +- `test_real=true`: whether to test real-to-complex and complex-to-real FFTs. +- `test_inplace=true`: whether to test in-place plans. """ -function test_fft_backend(array_constructor) +function test_fft_backend(array_constructor; test_real=true, test_inplace=true) @testset "fft correctness" begin # DFT along last dimension, results computed using FFTW for (_x, _fftw_fft) in ( @@ -51,82 +55,101 @@ function test_fft_backend(array_constructor) 15.0+0.0im -4.5+2.598076211353316im -4.5-2.598076211353316im; 18.0+0.0im -4.5+2.598076211353316im -4.5-2.598076211353316im]), ) - x = array_constructor(_x) - xcopy_float = array_constructor(copy(float.(x))) - xcopy_complex = array_constructor(copy(complex.(xcopy_float))) + x = array_constructor(_x) # dummy array that will be passed to plans + x_real = float.(x) # for testing real FFTs + x_complex = complex.(x_real) # for testing complex FFTs fftw_fft = array_constructor(_fftw_fft) + dims = ndims(x) # TODO: this is a single dimension, should check multidimensional FFTs too + # FFT - dims = ndims(x) - y = AbstractFFTs.fft(x, dims) - ycopy = array_constructor(copy(y)) + y = AbstractFFTs.fft(x_complex, dims) @test y ≈ fftw_fft + test_inplace && (@test AbstractFFTs.fft!(copy(x_complex), dims) ≈ fftw_fft) # test plan_fft and also inv and plan_inv of plan_ifft, which should all give # functionally identical plans - for P in [plan_fft(x, dims), inv(plan_ifft(x, dims)), - AbstractFFTs.plan_inv(plan_ifft(x, dims))] + plans_to_test = [plan_fft(x, dims), inv(plan_ifft(x, dims)), + AbstractFFTs.plan_inv(plan_ifft(x, dims))] + for P in plans_to_test + @test mul!(similar(y), P, x_complex) ≈ fftw_fft + end + test_inplace && (plans_to_test = vcat(plans_to_test, plan_fft!(similar(x_complex), dims))) + for P in plans_to_test @test eltype(P) <: Complex - @test P * x ≈ fftw_fft - @test mul!(ycopy, P, x) ≈ fftw_fft - @test P \ (P * x) ≈ x + @test P * copy(x_complex) ≈ fftw_fft + @test P \ (P * copy(x_complex)) ≈ x_complex @test fftdims(P) == dims end # BFFT - fftw_bfft = complex.(size(x, dims) .* x) + fftw_bfft = size(x_complex, dims) .* x_complex @test AbstractFFTs.bfft(y, dims) ≈ fftw_bfft - P = plan_bfft(x, dims) - @test P * y ≈ fftw_bfft - @test P \ (P * y) ≈ y - @test mul!(xcopy_complex, P, y) ≈ fftw_bfft - @test fftdims(P) == dims + test_inplace && (@test AbstractFFTs.bfft!(copy(y), dims) ≈ fftw_bfft) + plans_to_test = [plan_bfft(similar(y), dims)] + for P in plans_to_test + @test mul!(similar(x_complex), P, y) ≈ fftw_bfft + end + test_inplace && (plans_to_test = vcat(plans_to_test, plan_bfft!(similar(y), dims))) + for P in plans_to_test + @test eltype(P) <: Complex + @test P * copy(y) ≈ fftw_bfft + @test P \ (P * copy(y)) ≈ y + @test fftdims(P) == dims + end # IFFT - fftw_ifft = complex.(x) + fftw_ifft = x_complex @test AbstractFFTs.ifft(y, dims) ≈ fftw_ifft - for P in [plan_ifft(x, dims), inv(plan_fft(x, dims)), - AbstractFFTs.plan_inv(plan_fft(x, dims))] - @test P * y ≈ fftw_ifft - @test mul!(xcopy_complex, P, y) ≈ fftw_ifft - @test P \ (P * y) ≈ y - @test fftdims(P) == dims + test_inplace && (@test AbstractFFTs.ifft!(copy(y), dims) ≈ fftw_ifft) + plans_to_test = [plan_ifft(x, dims), inv(plan_fft(x, dims)), + AbstractFFTs.plan_inv(plan_fft(x, dims))] + for P in plans_to_test + @test mul!(similar(x_complex), P, y) ≈ fftw_ifft end - - # RFFT - fftw_rfft = fftw_fft[ - (Colon() for _ in 1:(ndims(fftw_fft) - 1))..., - 1:(size(fftw_fft, ndims(fftw_fft)) ÷ 2 + 1) - ] - ry = AbstractFFTs.rfft(x, dims) - rycopy = array_constructor(copy(ry)) - @test ry ≈ fftw_rfft - for P in [plan_rfft(x, dims), inv(plan_irfft(ry, size(x, dims), dims)), - AbstractFFTs.plan_inv(plan_irfft(ry, size(x, dims), dims))] - @test eltype(P) <: Real - @test P * x ≈ fftw_rfft - @test mul!(rycopy, P, x) ≈ fftw_rfft - @test P \ (P * x) ≈ x + test_inplace && (plan_to_test = vcat(plans_to_test, plan_ifft!(similar(x_complex), dims))) + for P in plans_to_test + @test eltype(P) <: Complex + @test P * copy(y) ≈ fftw_ifft + @test P \ (P * copy(y)) ≈ y @test fftdims(P) == dims end - # BRFFT - fftw_brfft = complex.(size(x, dims) .* x) - @test AbstractFFTs.brfft(ry, size(x, dims), dims) ≈ fftw_brfft - P = plan_brfft(ry, size(x, dims), dims) - @test P * ry ≈ fftw_brfft - @test mul!(xcopy_float, P, ry) ≈ fftw_brfft - @test P \ (P * ry) ≈ ry - @test fftdims(P) == dims - - # IRFFT - fftw_irfft = complex.(x) - @test AbstractFFTs.irfft(ry, size(x, dims), dims) ≈ fftw_irfft - for P in [plan_irfft(ry, size(x, dims), dims), inv(plan_rfft(x, dims)), - AbstractFFTs.plan_inv(plan_rfft(x, dims))] - @test P * ry ≈ fftw_irfft - @test mul!(xcopy_float, P, ry) ≈ fftw_irfft + if test_real + # RFFT + fftw_rfft = fftw_fft[ + (Colon() for _ in 1:(ndims(fftw_fft) - 1))..., + 1:(size(fftw_fft, ndims(fftw_fft)) ÷ 2 + 1) + ] + ry = AbstractFFTs.rfft(x_real, dims) + @test ry ≈ fftw_rfft + for P in [plan_rfft(x_real, dims), inv(plan_irfft(ry, size(x, dims), dims)), + AbstractFFTs.plan_inv(plan_irfft(ry, size(x, dims), dims))] + @test eltype(P) <: Real + @test P * x_real ≈ fftw_rfft + @test mul!(similar(ry), P, x_real) ≈ fftw_rfft + @test P \ (P * x_real) ≈ x_real + @test fftdims(P) == dims + end + + # BRFFT + fftw_brfft = complex.(size(x, dims) .* x_real) + @test AbstractFFTs.brfft(ry, size(x_real, dims), dims) ≈ fftw_brfft + P = plan_brfft(ry, size(x_real, dims), dims) + @test P * ry ≈ fftw_brfft + @test mul!(similar(x_real), P, ry) ≈ fftw_brfft @test P \ (P * ry) ≈ ry @test fftdims(P) == dims + + # IRFFT + fftw_irfft = x_complex + @test AbstractFFTs.irfft(ry, size(x, dims), dims) ≈ fftw_irfft + for P in [plan_irfft(ry, size(x, dims), dims), inv(plan_rfft(x_real, dims)), + AbstractFFTs.plan_inv(plan_rfft(x_real, dims))] + @test P * ry ≈ fftw_irfft + @test mul!(similar(x_real), P, ry) ≈ fftw_irfft + @test P \ (P * ry) ≈ ry + @test fftdims(P) == dims + end end end end