Skip to content

Commit

Permalink
Apply code review suggestions and refactor TestUtils
Browse files Browse the repository at this point in the history
  • Loading branch information
gaurav-arya committed Jul 9, 2023
1 parent 98fdcde commit 3537f76
Show file tree
Hide file tree
Showing 6 changed files with 87 additions and 86 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[extensions]
AbstractFFTsChainRulesCoreExt = "ChainRulesCore"
AbstractFFTsTestUtilsExt = "Test"
AbstractFFTsTestExt = "Test"

[compat]
ChainRulesCore = "1"
Expand Down
10 changes: 7 additions & 3 deletions docs/src/implementations.md
Original file line number Diff line number Diff line change
Expand Up @@ -43,9 +43,13 @@ length ``n``, and the "backwards" (unnormalized inverse) transform computes the
## Testing implementations

`AbstractFFTs.jl` provides a `TestUtils` module to help with testing downstream implementations.

The following functions test that all FFT functionality has been correctly implemented:
```@docs
AbstractFFTs.TestUtils.test_complex_ffts
AbstractFFTs.TestUtils.test_real_ffts
```
`TestUtils` also exposes lower level functions for generically testing particular plans:
```@docs
AbstractFFTs.TestUtils.test_complex_fft
AbstractFFTs.TestUtils.test_real_fft
AbstractFFTs.TestUtils.test_plan
AbstractFFTs.TestUtils.test_plan_adjoint
```
111 changes: 46 additions & 65 deletions ext/AbstractFFTsTestUtilsExt.jl → ext/AbstractFFTsTestExt.jl
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
# This file contains code that was formerly part of Julia. License is MIT: https://julialang.org/license

module AbstractFFTsTestUtilsExt
module AbstractFFTsTestExt

using AbstractFFTs
using AbstractFFTs: TestUtils
using AbstractFFTs.LinearAlgebra
using Test

# Ground truth _x_fft computed using FFTW library
# Ground truth x_fft computed using FFTW library
const TEST_CASES = (
(; x = collect(1:7), dims = 1,
x_fft = [28.0 + 0.0im,
Expand Down Expand Up @@ -51,29 +51,47 @@ const TEST_CASES = (
dims=3)),
)

function TestUtils.test_plan_adjoint(P::AbstractFFTs.Plan, x::AbstractArray; real_plan=false)
y = rand(eltype(P * x), size(P * x))

function TestUtils.test_plan(P::AbstractFFTs.Plan, x::AbstractArray, x_transformed::AbstractArray; inplace_plan=false, copy_input=false)
_copy = copy_input ? copy : identity
if !inplace_plan
@test P * _copy(x) x_transformed
@test P \ (P * _copy(x)) x
_x_out = similar(P * _copy(x))
@test mul!(_x_out, P, _copy(x)) x_transformed
@test _x_out x_transformed
else
_x = copy(x)
@test P * _copy(_x) x_transformed
@test _x x_transformed
@test P \ _copy(_x) x
@test _x x
end
end

function TestUtils.test_plan_adjoint(P::AbstractFFTs.Plan, x::AbstractArray; real_plan=false, copy_input=false)
_copy = copy_input ? copy : identity
y = rand(eltype(P * _copy(x)), size(P * _copy(x)))
# test basic properties
@test_broken eltype(P') === typeof(y) # (AbstractFFTs.jl#110)
@test fftdims(P') == fftdims(P)
@test_skip eltype(P') === typeof(y) # (AbstractFFTs.jl#110)
@test (P')' === P # test adjoint of adjoint
@test size(P') == AbstractFFTs.output_size(P) # test size of adjoint
# test correctness of adjoint and its inverse via the dot test
if !real_plan
@test dot(y, P * x) dot(P' * y, x)
@test dot(y, P \ x) dot(P' \ y, x)
@test dot(y, P * _copy(x)) dot(P' * _copy(y), x)
@test dot(y, P \ _copy(x)) dot(P' \ _copy(y), x)
else
_component_dot(x, y) = dot(real.(x), real.(y)) + dot(imag.(x), imag.(y))
@test _component_dot(y, P * copy(x)) _component_dot(P' * copy(y), x)
@test _component_dot(x, P \ copy(y)) _component_dot(P' \ copy(x), y)
@test _component_dot(y, P * _copy(x)) _component_dot(P' * _copy(y), x)
@test _component_dot(x, P \ _copy(y)) _component_dot(P' \ _copy(x), y)
end
@test_throws MethodError mul!(x, P', y)
end

function TestUtils.test_complex_fft(ArrayType=Array; test_inplace=true, test_adjoint=true)
function TestUtils.test_complex_ffts(ArrayType=Array; test_inplace=true, test_adjoint=true)
@testset "correctness of fft, bfft, ifft" begin
for test_case in TEST_CASES
_x, dims, _x_fft = test_case.x, test_case.dims, test_case.x_fft
_x, dims, _x_fft = copy(test_case.x), test_case.dims, copy(test_case.x_fft)
x = convert(ArrayType, _x) # dummy array that will be passed to plans
x_complexf = convert(ArrayType, complex.(float.(x))) # for testing mutating complex FFTs
x_fft = convert(ArrayType, _x_fft)
Expand All @@ -90,25 +108,16 @@ function TestUtils.test_complex_fft(ArrayType=Array; test_inplace=true, test_adj
for P in (plan_fft(similar(x_complexf), dims), inv(plan_ifft(similar(x_complexf), dims)))
@test eltype(P) <: Complex
@test fftdims(P) == dims
@test P * x x_fft
@test P \ (P * x) x
_x_out = similar(x_fft)
@test mul!(_x_out, P, x_complexf) x_fft
@test _x_out x_fft
TestUtils.test_plan(P, x_complexf, x_fft)
if test_adjoint
@test fftdims(P') == fftdims(P)
TestUtils.test_plan_adjoint(P, x_complexf)
end
end
if test_inplace
# test IIP plans
for P in (plan_fft!(similar(x_complexf), dims), inv(plan_ifft!(similar(x_complexf), dims)))
@test eltype(P) <: Complex
@test fftdims(P) == dims
_x_complexf = copy(x_complexf)
@test P * _x_complexf x_fft
@test _x_complexf x_fft
@test P \ _x_complexf x
@test _x_complexf x
TestUtils.test_plan(P, x_complexf, x_fft; inplace_plan=true)
end
end

Expand All @@ -124,24 +133,16 @@ function TestUtils.test_complex_fft(ArrayType=Array; test_inplace=true, test_adj
for P in (plan_bfft(similar(x_fft), dims),)
@test eltype(P) <: Complex
@test fftdims(P) == dims
@test P * x_fft x_scaled
@test P \ (P * x_fft) x_fft
_x_complexf = similar(x_complexf)
@test mul!(_x_complexf, P, x_fft) x_scaled
@test _x_complexf x_scaled
TestUtils.test_plan(P, x_fft, x_scaled)
if test_adjoint
TestUtils.test_plan_adjoint(P, x_complexf)
TestUtils.test_plan_adjoint(P, x_fft)
end
end
# test IIP plans
for P in (plan_bfft!(similar(x_fft), dims),)
@test eltype(P) <: Complex
@test fftdims(P) == dims
_x_fft = copy(x_fft)
@test P * _x_fft x_scaled
@test _x_fft x_scaled
@test P \ _x_fft x_fft
@test _x_fft x_fft
TestUtils.test_plan(P, x_fft, x_scaled; inplace_plan=true)
end

# IFFT
Expand All @@ -155,35 +156,27 @@ function TestUtils.test_complex_fft(ArrayType=Array; test_inplace=true, test_adj
for P in (plan_ifft(similar(x_complexf), dims), inv(plan_fft(similar(x_complexf), dims)))
@test eltype(P) <: Complex
@test fftdims(P) == dims
@test P * x_fft x
@test P \ (P * x_fft) x_fft
_x_complexf = similar(x_complexf)
@test mul!(_x_complexf, P, x_fft) x
@test _x_complexf x
TestUtils.test_plan(P, x_fft, x)
if test_adjoint
TestUtils.test_plan_adjoint(P, x_complexf)
TestUtils.test_plan_adjoint(P, x_fft)
end
end
# test IIP plans
if test_inplace
for P in (plan_ifft!(similar(x_complexf), dims), inv(plan_fft!(similar(x_complexf), dims)))
@test eltype(P) <: Complex
@test fftdims(P) == dims
_x_fft = copy(x_fft)
@test P * _x_fft x
@test _x_fft x
@test P \ _x_fft x_fft
@test _x_fft x_fft
TestUtils.test_plan(P, x_fft, x; inplace_plan=true)
end
end
end
end
end

function TestUtils.test_real_fft(ArrayType=Array; test_inplace=true, test_adjoint=true)
function TestUtils.test_real_ffts(ArrayType=Array; test_adjoint=true, copy_input=false)
@testset "correctness of rfft, brfft, irfft" begin
for test_case in TEST_CASES
_x, dims, _x_fft = test_case.x, test_case.dims, test_case.x_fft
_x, dims, _x_fft = copy(test_case.x), test_case.dims, copy(test_case.x_fft)
x = convert(ArrayType, _x) # dummy array that will be passed to plans
x_real = float.(x) # for testing mutating real FFTs
x_fft = convert(ArrayType, _x_fft)
Expand All @@ -198,14 +191,9 @@ function TestUtils.test_real_fft(ArrayType=Array; test_inplace=true, test_adjoin
for P in (plan_rfft(similar(x_real), dims), inv(plan_irfft(similar(x_rfft), size(x, first(dims)), dims)))
@test eltype(P) <: Real
@test fftdims(P) == dims
# Always copy input before application due to FFTW real plans possibly mutating input (AbstractFFTs.jl#101)
@test P * copy(x) x_rfft
@test P \ (P * copy(x)) x
_x_rfft = similar(x_rfft)
@test mul!(_x_rfft, P, copy(x_real)) x_rfft
@test _x_rfft x_rfft
TestUtils.test_plan(P, x_real, x_rfft; copy_input)
if test_adjoint
TestUtils.test_plan_adjoint(P, x_real; real_plan=true)
TestUtils.test_plan_adjoint(P, x_real; real_plan=true, copy_input)
end
end

Expand All @@ -215,25 +203,18 @@ function TestUtils.test_real_fft(ArrayType=Array; test_inplace=true, test_adjoin
for P in (plan_brfft(similar(x_rfft), size(x, first(dims)), dims),)
@test eltype(P) <: Complex
@test fftdims(P) == dims
@test P * copy(x_rfft) x_scaled
@test P \ (P * copy(x_rfft)) x_rfft
_x_scaled = similar(x_real)
@test mul!(_x_scaled, P, copy(x_rfft)) x_scaled
@test _x_scaled x_scaled
TestUtils.test_plan(P, x_rfft, x_scaled; copy_input)
end

# IRFFT
@test irfft(x_rfft, size(x, first(dims)), dims) x
for P in (plan_irfft(similar(x_rfft), size(x, first(dims)), dims), inv(plan_rfft(similar(x_real), dims)))
@test eltype(P) <: Complex
@test fftdims(P) == dims
@test P * copy(x_rfft) x
@test P \ (P * copy(x_rfft)) x_rfft
_x_real = similar(x_real)
@test mul!(_x_real, P, copy(x_rfft)) x_real
TestUtils.test_plan(P, x_rfft, x; copy_input)
end
end
end
end

end
end
2 changes: 1 addition & 1 deletion src/AbstractFFTs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ include("TestUtils.jl")

if !isdefined(Base, :get_extension)
include("../ext/AbstractFFTsChainRulesCoreExt.jl")
include("../ext/AbstractFFTsTestUtilsExt.jl")
include("../ext/AbstractFFTsTestExt.jl")
end

end # module
44 changes: 30 additions & 14 deletions src/TestUtils.jl
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
module TestUtils

"""
TestUtils.test_complex_fft(ArrayType=Array; test_real=true, test_inplace=true)
TestUtils.test_complex_ffts(ArrayType=Array; test_adjoint=true, test_inplace=true)
Run tests to verify correctness of FFT/BFFT/IFFT functionality using a particular backend plan implementation.
Run tests to verify correctness of FFT, BFFT, and IFFT functionality using a particular backend plan implementation.
The backend implementation is assumed to be loaded prior to calling this function.
# Arguments
Expand All @@ -14,43 +14,59 @@ The backend implementation is assumed to be loaded prior to calling this functio
- `test_inplace=true`: whether to test in-place plans.
- `test_adjoint=true`: whether to test [plan adjoints](api.md#Base.adjoint).
"""
function test_complex_fft end
function test_complex_ffts end

"""
TestUtils.test_real_fft(ArrayType=Array; test_real=true, test_inplace=true)
TestUtils.test_real_ffts(ArrayType=Array; test_adjoint=true, copy_input=false)
Run tests to verify correctness of RFFT/BRFFT/IRFFT functionality using a particular backend plan implementation.
Run tests to verify correctness of RFFT, BRFFT, and IRFFT functionality using a particular backend plan implementation.
The backend implementation is assumed to be loaded prior to calling this function.
# Arguments
- `ArrayType`: determines the `AbstractArray` implementation for
which the correctness tests are run. Arrays are constructed via
`convert(ArrayType, ...)`.
- `test_inplace=true`: whether to test in-place plans.
- `test_adjoint=true`: whether to test [plan adjoints](api.md#Base.adjoint).
- `copy_input=false`: whether to copy the input before applying the plan in tests, to accomodate for
[input-mutating behaviour of real FFTW plans](https://github.com/JuliaMath/AbstractFFTs.jl/issues/101).
"""
function test_real_fft end
function test_real_ffts end

# Always copy input before application due to FFTW real plans possibly mutating input (AbstractFFTs.jl#101)
"""
TestUtils.test_plan_adjoint(P::Plan, x::AbstractArray; real_plan=false)
TestUtils.test_plan(P::Plan, x::AbstractArray, x_transformed::AbstractArray;
inplace_plan=false, copy_input=false)
Test basic properties of a plan `P` given an input array `x` and expected output `x_transformed`.
Test basic properties of the adjoint `P'` of a particular plan given an input array to the plan `x`,
including its accuracy via the dot test. Real-to-complex and complex-to-real plans require
a slightly modified dot test, in which case `real_plan=true` should be provided.
Because [real FFTW plans may mutate their input in some cases](https://github.com/JuliaMath/AbstractFFTs.jl/issues/101),
we allow specifying `copy_input=true` to allow for this behaviour in tests by copying the input before applying the plan.
"""
function test_plan end

"""
TestUtils.test_plan_adjoint(P::Plan, x::AbstractArray; real_plan=false, copy_input=false)
Test basic properties of the [adjoint](api.md#Base.adjoint) `P'` of a particular plan given an input array `x`,
including its accuracy via the dot test.
Real-to-complex and complex-to-real plans require a slightly modified dot test, in which case `real_plan=true` should be provided.
The plan is assumed out-of-place, as adjoints are not yet supported for in-place plans.
Because [real FFTW plans may mutate their input in some cases](https://github.com/JuliaMath/AbstractFFTs.jl/issues/101),
we allow specifying `copy_input=true` to allow for this behaviour in tests by copying the input before applying the plan.
"""
function test_plan_adjoint end

function __init__()
if isdefined(Base, :Experimental)
if isdefined(Base, :get_extension) && isdefined(Base.Experimental, :register_error_hint)
# Better error message if users forget to load Test
Base.Experimental.register_error_hint(MethodError) do io, exc, _, _
if exc.f in (test_real_fft, test_complex_fft)
if (exc.f === test_real_fft || exc.f === test_complex_fft) && Base.get_extension(AbstractFFTs, :AbstractFFTsTestExt) === nothing
print(io, "\nDid you forget to load Test?")
end
end
end
end

end
end
4 changes: 2 additions & 2 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,8 @@ Random.seed!(1234)
include("TestPlans.jl")

# Run interface tests for TestPlans
AbstractFFTs.TestUtils.test_complex_fft(Array)
AbstractFFTs.TestUtils.test_real_fft(Array)
AbstractFFTs.TestUtils.test_complex_ffts(Array)
AbstractFFTs.TestUtils.test_real_ffts(Array)

@testset "rfft sizes" begin
A = rand(11, 10)
Expand Down

0 comments on commit 3537f76

Please sign in to comment.