From 3537f76a4043ea20e88107da175c0a731bf13e67 Mon Sep 17 00:00:00 2001 From: Gaurav Arya Date: Sat, 8 Jul 2023 23:43:09 -0400 Subject: [PATCH] Apply code review suggestions and refactor TestUtils --- Project.toml | 2 +- docs/src/implementations.md | 10 +- ...TestUtilsExt.jl => AbstractFFTsTestExt.jl} | 111 ++++++++---------- src/AbstractFFTs.jl | 2 +- src/TestUtils.jl | 44 ++++--- test/runtests.jl | 4 +- 6 files changed, 87 insertions(+), 86 deletions(-) rename ext/{AbstractFFTsTestUtilsExt.jl => AbstractFFTsTestExt.jl} (69%) diff --git a/Project.toml b/Project.toml index f069e4f..6404c36 100644 --- a/Project.toml +++ b/Project.toml @@ -13,7 +13,7 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [extensions] AbstractFFTsChainRulesCoreExt = "ChainRulesCore" -AbstractFFTsTestUtilsExt = "Test" +AbstractFFTsTestExt = "Test" [compat] ChainRulesCore = "1" diff --git a/docs/src/implementations.md b/docs/src/implementations.md index 5c8303c..6db6ee0 100644 --- a/docs/src/implementations.md +++ b/docs/src/implementations.md @@ -43,9 +43,13 @@ length ``n``, and the "backwards" (unnormalized inverse) transform computes the ## Testing implementations `AbstractFFTs.jl` provides a `TestUtils` module to help with testing downstream implementations. - +The following functions test that all FFT functionality has been correctly implemented: +```@docs +AbstractFFTs.TestUtils.test_complex_ffts +AbstractFFTs.TestUtils.test_real_ffts +``` +`TestUtils` also exposes lower level functions for generically testing particular plans: ```@docs -AbstractFFTs.TestUtils.test_complex_fft -AbstractFFTs.TestUtils.test_real_fft +AbstractFFTs.TestUtils.test_plan AbstractFFTs.TestUtils.test_plan_adjoint ``` diff --git a/ext/AbstractFFTsTestUtilsExt.jl b/ext/AbstractFFTsTestExt.jl similarity index 69% rename from ext/AbstractFFTsTestUtilsExt.jl rename to ext/AbstractFFTsTestExt.jl index 1d6efc8..6d8fbec 100644 --- a/ext/AbstractFFTsTestUtilsExt.jl +++ b/ext/AbstractFFTsTestExt.jl @@ -1,13 +1,13 @@ # This file contains code that was formerly part of Julia. License is MIT: https://julialang.org/license -module AbstractFFTsTestUtilsExt +module AbstractFFTsTestExt using AbstractFFTs using AbstractFFTs: TestUtils using AbstractFFTs.LinearAlgebra using Test -# Ground truth _x_fft computed using FFTW library +# Ground truth x_fft computed using FFTW library const TEST_CASES = ( (; x = collect(1:7), dims = 1, x_fft = [28.0 + 0.0im, @@ -51,29 +51,47 @@ const TEST_CASES = ( dims=3)), ) -function TestUtils.test_plan_adjoint(P::AbstractFFTs.Plan, x::AbstractArray; real_plan=false) - y = rand(eltype(P * x), size(P * x)) + +function TestUtils.test_plan(P::AbstractFFTs.Plan, x::AbstractArray, x_transformed::AbstractArray; inplace_plan=false, copy_input=false) + _copy = copy_input ? copy : identity + if !inplace_plan + @test P * _copy(x) ≈ x_transformed + @test P \ (P * _copy(x)) ≈ x + _x_out = similar(P * _copy(x)) + @test mul!(_x_out, P, _copy(x)) ≈ x_transformed + @test _x_out ≈ x_transformed + else + _x = copy(x) + @test P * _copy(_x) ≈ x_transformed + @test _x ≈ x_transformed + @test P \ _copy(_x) ≈ x + @test _x ≈ x + end +end + +function TestUtils.test_plan_adjoint(P::AbstractFFTs.Plan, x::AbstractArray; real_plan=false, copy_input=false) + _copy = copy_input ? copy : identity + y = rand(eltype(P * _copy(x)), size(P * _copy(x))) # test basic properties - @test_broken eltype(P') === typeof(y) # (AbstractFFTs.jl#110) - @test fftdims(P') == fftdims(P) + @test_skip eltype(P') === typeof(y) # (AbstractFFTs.jl#110) @test (P')' === P # test adjoint of adjoint @test size(P') == AbstractFFTs.output_size(P) # test size of adjoint # test correctness of adjoint and its inverse via the dot test if !real_plan - @test dot(y, P * x) ≈ dot(P' * y, x) - @test dot(y, P \ x) ≈ dot(P' \ y, x) + @test dot(y, P * _copy(x)) ≈ dot(P' * _copy(y), x) + @test dot(y, P \ _copy(x)) ≈ dot(P' \ _copy(y), x) else _component_dot(x, y) = dot(real.(x), real.(y)) + dot(imag.(x), imag.(y)) - @test _component_dot(y, P * copy(x)) ≈ _component_dot(P' * copy(y), x) - @test _component_dot(x, P \ copy(y)) ≈ _component_dot(P' \ copy(x), y) + @test _component_dot(y, P * _copy(x)) ≈ _component_dot(P' * _copy(y), x) + @test _component_dot(x, P \ _copy(y)) ≈ _component_dot(P' \ _copy(x), y) end @test_throws MethodError mul!(x, P', y) end -function TestUtils.test_complex_fft(ArrayType=Array; test_inplace=true, test_adjoint=true) +function TestUtils.test_complex_ffts(ArrayType=Array; test_inplace=true, test_adjoint=true) @testset "correctness of fft, bfft, ifft" begin for test_case in TEST_CASES - _x, dims, _x_fft = test_case.x, test_case.dims, test_case.x_fft + _x, dims, _x_fft = copy(test_case.x), test_case.dims, copy(test_case.x_fft) x = convert(ArrayType, _x) # dummy array that will be passed to plans x_complexf = convert(ArrayType, complex.(float.(x))) # for testing mutating complex FFTs x_fft = convert(ArrayType, _x_fft) @@ -90,25 +108,16 @@ function TestUtils.test_complex_fft(ArrayType=Array; test_inplace=true, test_adj for P in (plan_fft(similar(x_complexf), dims), inv(plan_ifft(similar(x_complexf), dims))) @test eltype(P) <: Complex @test fftdims(P) == dims - @test P * x ≈ x_fft - @test P \ (P * x) ≈ x - _x_out = similar(x_fft) - @test mul!(_x_out, P, x_complexf) ≈ x_fft - @test _x_out ≈ x_fft + TestUtils.test_plan(P, x_complexf, x_fft) if test_adjoint + @test fftdims(P') == fftdims(P) TestUtils.test_plan_adjoint(P, x_complexf) end end if test_inplace # test IIP plans for P in (plan_fft!(similar(x_complexf), dims), inv(plan_ifft!(similar(x_complexf), dims))) - @test eltype(P) <: Complex - @test fftdims(P) == dims - _x_complexf = copy(x_complexf) - @test P * _x_complexf ≈ x_fft - @test _x_complexf ≈ x_fft - @test P \ _x_complexf ≈ x - @test _x_complexf ≈ x + TestUtils.test_plan(P, x_complexf, x_fft; inplace_plan=true) end end @@ -124,24 +133,16 @@ function TestUtils.test_complex_fft(ArrayType=Array; test_inplace=true, test_adj for P in (plan_bfft(similar(x_fft), dims),) @test eltype(P) <: Complex @test fftdims(P) == dims - @test P * x_fft ≈ x_scaled - @test P \ (P * x_fft) ≈ x_fft - _x_complexf = similar(x_complexf) - @test mul!(_x_complexf, P, x_fft) ≈ x_scaled - @test _x_complexf ≈ x_scaled + TestUtils.test_plan(P, x_fft, x_scaled) if test_adjoint - TestUtils.test_plan_adjoint(P, x_complexf) + TestUtils.test_plan_adjoint(P, x_fft) end end # test IIP plans for P in (plan_bfft!(similar(x_fft), dims),) @test eltype(P) <: Complex @test fftdims(P) == dims - _x_fft = copy(x_fft) - @test P * _x_fft ≈ x_scaled - @test _x_fft ≈ x_scaled - @test P \ _x_fft ≈ x_fft - @test _x_fft ≈ x_fft + TestUtils.test_plan(P, x_fft, x_scaled; inplace_plan=true) end # IFFT @@ -155,13 +156,9 @@ function TestUtils.test_complex_fft(ArrayType=Array; test_inplace=true, test_adj for P in (plan_ifft(similar(x_complexf), dims), inv(plan_fft(similar(x_complexf), dims))) @test eltype(P) <: Complex @test fftdims(P) == dims - @test P * x_fft ≈ x - @test P \ (P * x_fft) ≈ x_fft - _x_complexf = similar(x_complexf) - @test mul!(_x_complexf, P, x_fft) ≈ x - @test _x_complexf ≈ x + TestUtils.test_plan(P, x_fft, x) if test_adjoint - TestUtils.test_plan_adjoint(P, x_complexf) + TestUtils.test_plan_adjoint(P, x_fft) end end # test IIP plans @@ -169,21 +166,17 @@ function TestUtils.test_complex_fft(ArrayType=Array; test_inplace=true, test_adj for P in (plan_ifft!(similar(x_complexf), dims), inv(plan_fft!(similar(x_complexf), dims))) @test eltype(P) <: Complex @test fftdims(P) == dims - _x_fft = copy(x_fft) - @test P * _x_fft ≈ x - @test _x_fft ≈ x - @test P \ _x_fft ≈ x_fft - @test _x_fft ≈ x_fft + TestUtils.test_plan(P, x_fft, x; inplace_plan=true) end end end end end -function TestUtils.test_real_fft(ArrayType=Array; test_inplace=true, test_adjoint=true) +function TestUtils.test_real_ffts(ArrayType=Array; test_adjoint=true, copy_input=false) @testset "correctness of rfft, brfft, irfft" begin for test_case in TEST_CASES - _x, dims, _x_fft = test_case.x, test_case.dims, test_case.x_fft + _x, dims, _x_fft = copy(test_case.x), test_case.dims, copy(test_case.x_fft) x = convert(ArrayType, _x) # dummy array that will be passed to plans x_real = float.(x) # for testing mutating real FFTs x_fft = convert(ArrayType, _x_fft) @@ -198,14 +191,9 @@ function TestUtils.test_real_fft(ArrayType=Array; test_inplace=true, test_adjoin for P in (plan_rfft(similar(x_real), dims), inv(plan_irfft(similar(x_rfft), size(x, first(dims)), dims))) @test eltype(P) <: Real @test fftdims(P) == dims - # Always copy input before application due to FFTW real plans possibly mutating input (AbstractFFTs.jl#101) - @test P * copy(x) ≈ x_rfft - @test P \ (P * copy(x)) ≈ x - _x_rfft = similar(x_rfft) - @test mul!(_x_rfft, P, copy(x_real)) ≈ x_rfft - @test _x_rfft ≈ x_rfft + TestUtils.test_plan(P, x_real, x_rfft; copy_input) if test_adjoint - TestUtils.test_plan_adjoint(P, x_real; real_plan=true) + TestUtils.test_plan_adjoint(P, x_real; real_plan=true, copy_input) end end @@ -215,11 +203,7 @@ function TestUtils.test_real_fft(ArrayType=Array; test_inplace=true, test_adjoin for P in (plan_brfft(similar(x_rfft), size(x, first(dims)), dims),) @test eltype(P) <: Complex @test fftdims(P) == dims - @test P * copy(x_rfft) ≈ x_scaled - @test P \ (P * copy(x_rfft)) ≈ x_rfft - _x_scaled = similar(x_real) - @test mul!(_x_scaled, P, copy(x_rfft)) ≈ x_scaled - @test _x_scaled ≈ x_scaled + TestUtils.test_plan(P, x_rfft, x_scaled; copy_input) end # IRFFT @@ -227,13 +211,10 @@ function TestUtils.test_real_fft(ArrayType=Array; test_inplace=true, test_adjoin for P in (plan_irfft(similar(x_rfft), size(x, first(dims)), dims), inv(plan_rfft(similar(x_real), dims))) @test eltype(P) <: Complex @test fftdims(P) == dims - @test P * copy(x_rfft) ≈ x - @test P \ (P * copy(x_rfft)) ≈ x_rfft - _x_real = similar(x_real) - @test mul!(_x_real, P, copy(x_rfft)) ≈ x_real + TestUtils.test_plan(P, x_rfft, x; copy_input) end end end end -end \ No newline at end of file +end diff --git a/src/AbstractFFTs.jl b/src/AbstractFFTs.jl index 6ace165..3225916 100644 --- a/src/AbstractFFTs.jl +++ b/src/AbstractFFTs.jl @@ -10,7 +10,7 @@ include("TestUtils.jl") if !isdefined(Base, :get_extension) include("../ext/AbstractFFTsChainRulesCoreExt.jl") - include("../ext/AbstractFFTsTestUtilsExt.jl") + include("../ext/AbstractFFTsTestExt.jl") end end # module diff --git a/src/TestUtils.jl b/src/TestUtils.jl index 17ac76a..14f74d2 100644 --- a/src/TestUtils.jl +++ b/src/TestUtils.jl @@ -1,9 +1,9 @@ module TestUtils """ - TestUtils.test_complex_fft(ArrayType=Array; test_real=true, test_inplace=true) + TestUtils.test_complex_ffts(ArrayType=Array; test_adjoint=true, test_inplace=true) -Run tests to verify correctness of FFT/BFFT/IFFT functionality using a particular backend plan implementation. +Run tests to verify correctness of FFT, BFFT, and IFFT functionality using a particular backend plan implementation. The backend implementation is assumed to be loaded prior to calling this function. # Arguments @@ -14,12 +14,12 @@ The backend implementation is assumed to be loaded prior to calling this functio - `test_inplace=true`: whether to test in-place plans. - `test_adjoint=true`: whether to test [plan adjoints](api.md#Base.adjoint). """ -function test_complex_fft end +function test_complex_ffts end """ - TestUtils.test_real_fft(ArrayType=Array; test_real=true, test_inplace=true) + TestUtils.test_real_ffts(ArrayType=Array; test_adjoint=true, copy_input=false) -Run tests to verify correctness of RFFT/BRFFT/IRFFT functionality using a particular backend plan implementation. +Run tests to verify correctness of RFFT, BRFFT, and IRFFT functionality using a particular backend plan implementation. The backend implementation is assumed to be loaded prior to calling this function. # Arguments @@ -27,30 +27,46 @@ The backend implementation is assumed to be loaded prior to calling this functio - `ArrayType`: determines the `AbstractArray` implementation for which the correctness tests are run. Arrays are constructed via `convert(ArrayType, ...)`. -- `test_inplace=true`: whether to test in-place plans. - `test_adjoint=true`: whether to test [plan adjoints](api.md#Base.adjoint). +- `copy_input=false`: whether to copy the input before applying the plan in tests, to accomodate for + [input-mutating behaviour of real FFTW plans](https://github.com/JuliaMath/AbstractFFTs.jl/issues/101). """ -function test_real_fft end +function test_real_ffts end + # Always copy input before application due to FFTW real plans possibly mutating input (AbstractFFTs.jl#101) """ - TestUtils.test_plan_adjoint(P::Plan, x::AbstractArray; real_plan=false) + TestUtils.test_plan(P::Plan, x::AbstractArray, x_transformed::AbstractArray; + inplace_plan=false, copy_input=false) + +Test basic properties of a plan `P` given an input array `x` and expected output `x_transformed`. -Test basic properties of the adjoint `P'` of a particular plan given an input array to the plan `x`, -including its accuracy via the dot test. Real-to-complex and complex-to-real plans require -a slightly modified dot test, in which case `real_plan=true` should be provided. +Because [real FFTW plans may mutate their input in some cases](https://github.com/JuliaMath/AbstractFFTs.jl/issues/101), +we allow specifying `copy_input=true` to allow for this behaviour in tests by copying the input before applying the plan. +""" +function test_plan end +""" + TestUtils.test_plan_adjoint(P::Plan, x::AbstractArray; real_plan=false, copy_input=false) + +Test basic properties of the [adjoint](api.md#Base.adjoint) `P'` of a particular plan given an input array `x`, +including its accuracy via the dot test. + +Real-to-complex and complex-to-real plans require a slightly modified dot test, in which case `real_plan=true` should be provided. +The plan is assumed out-of-place, as adjoints are not yet supported for in-place plans. +Because [real FFTW plans may mutate their input in some cases](https://github.com/JuliaMath/AbstractFFTs.jl/issues/101), +we allow specifying `copy_input=true` to allow for this behaviour in tests by copying the input before applying the plan. """ function test_plan_adjoint end function __init__() - if isdefined(Base, :Experimental) + if isdefined(Base, :get_extension) && isdefined(Base.Experimental, :register_error_hint) # Better error message if users forget to load Test Base.Experimental.register_error_hint(MethodError) do io, exc, _, _ - if exc.f in (test_real_fft, test_complex_fft) + if (exc.f === test_real_fft || exc.f === test_complex_fft) && Base.get_extension(AbstractFFTs, :AbstractFFTsTestExt) === nothing print(io, "\nDid you forget to load Test?") end end end end -end \ No newline at end of file +end diff --git a/test/runtests.jl b/test/runtests.jl index 3c5d4c6..fe74897 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -13,8 +13,8 @@ Random.seed!(1234) include("TestPlans.jl") # Run interface tests for TestPlans -AbstractFFTs.TestUtils.test_complex_fft(Array) -AbstractFFTs.TestUtils.test_real_fft(Array) +AbstractFFTs.TestUtils.test_complex_ffts(Array) +AbstractFFTs.TestUtils.test_real_ffts(Array) @testset "rfft sizes" begin A = rand(11, 10)