diff --git a/test/TestPlans.jl b/test/TestPlans.jl index 6e10efe..e0d22f2 100644 --- a/test/TestPlans.jl +++ b/test/TestPlans.jl @@ -4,23 +4,25 @@ import AbstractFFTs import LinearAlgebra.mul! using AbstractFFTs: Plan -mutable struct TestPlan{T,N} <: Plan{T} +mutable struct TestPlan{T,N,inplace} <: Plan{T} region sz::NTuple{N,Int} pinv::Plan{T} - function TestPlan{T}(region, sz::NTuple{N,Int}) where {T,N} - return new{T,N}(region, sz) + function TestPlan{T,inplace}(region, sz::NTuple{N,Int}) where {T,N,inplace} + return new{T,N,inplace}(region, sz) end end +TestPlan{T}(region, sz) where {T} = TestPlan{T,false}(region, sz) -mutable struct InverseTestPlan{T,N} <: Plan{T} +mutable struct InverseTestPlan{T,N,inplace} <: Plan{T} region sz::NTuple{N,Int} pinv::Plan{T} - function InverseTestPlan{T}(region, sz::NTuple{N,Int}) where {T,N} - return new{T,N}(region, sz) + function InverseTestPlan{T,inplace}(region, sz::NTuple{N,Int}) where {T,N,inplace} + return new{T,N,inplace}(region, sz) end end +InverseTestPlan{T}(region, sz) where {T} = InverseTestPlan{T,false}(region, sz) Base.size(p::TestPlan) = p.sz Base.ndims(::TestPlan{T,N}) where {T,N} = N @@ -34,18 +36,25 @@ function AbstractFFTs.plan_bfft(x::AbstractArray{T}, region; kwargs...) where {T return InverseTestPlan{T}(region, size(x)) end -function AbstractFFTs.plan_inv(p::TestPlan{T}) where {T} - unscaled_pinv = InverseTestPlan{T}(p.region, p.sz) - N = AbstractFFTs.normalization(T, p.sz, p.region) - unscaled_pinv.pinv = AbstractFFTs.ScaledPlan(p, N) - pinv = AbstractFFTs.ScaledPlan(unscaled_pinv, N) +function AbstractFFTs.plan_fft!(x::AbstractArray{T}, region; kwargs...) where {T} + return TestPlan{T,true}(region, size(x)) +end +function AbstractFFTs.plan_bfft!(x::AbstractArray{T}, region; kwargs...) where {T} + return InverseTestPlan{T,true}(region, size(x)) +end + +function AbstractFFTs.plan_inv(p::TestPlan{T,N,inplace}) where {T,N,inplace} + unscaled_pinv = InverseTestPlan{T,inplace}(p.region, p.sz) + _N = AbstractFFTs.normalization(T, p.sz, p.region) + unscaled_pinv.pinv = AbstractFFTs.ScaledPlan(p, _N) + pinv = AbstractFFTs.ScaledPlan(unscaled_pinv, _N) return pinv end -function AbstractFFTs.plan_inv(pinv::InverseTestPlan{T}) where {T} - unscaled_p = TestPlan{T}(pinv.region, pinv.sz) - N = AbstractFFTs.normalization(T, pinv.sz, pinv.region) - unscaled_p.pinv = AbstractFFTs.ScaledPlan(pinv, N) - p = AbstractFFTs.ScaledPlan(unscaled_p, N) +function AbstractFFTs.plan_inv(pinv::InverseTestPlan{T,N,inplace}) where {T,N,inplace} + unscaled_p = TestPlan{T,inplace}(pinv.region, pinv.sz) + _N = AbstractFFTs.normalization(T, pinv.sz, pinv.region) + unscaled_p.pinv = AbstractFFTs.ScaledPlan(pinv, _N) + p = AbstractFFTs.ScaledPlan(unscaled_p, _N) return p end @@ -80,20 +89,23 @@ function dft!( end function mul!( - y::AbstractArray{<:Complex,N}, p::TestPlan, x::AbstractArray{<:Union{Complex,Real},N} -) where {N} + y::AbstractArray{<:Complex,N}, p::TestPlan{T,N,false}, x::AbstractArray{<:Union{Complex,Real},N} +) where {T,N} size(y) == size(p) == size(x) || throw(DimensionMismatch()) dft!(y, x, p.region, -1) end function mul!( - y::AbstractArray{<:Complex,N}, p::InverseTestPlan, x::AbstractArray{<:Union{Complex,Real},N} -) where {N} + y::AbstractArray{<:Complex,N}, p::InverseTestPlan{T,N,false}, x::AbstractArray{<:Union{Complex,Real},N} +) where {T,N} size(y) == size(p) == size(x) || throw(DimensionMismatch()) dft!(y, x, p.region, 1) end -Base.:*(p::TestPlan, x::AbstractArray) = mul!(similar(x, complex(float(eltype(x)))), p, x) -Base.:*(p::InverseTestPlan, x::AbstractArray) = mul!(similar(x, complex(float(eltype(x)))), p, x) +Base.:*(p::TestPlan{T,N,false}, x::AbstractArray) where {T,N} = mul!(similar(x, complex(float(eltype(x)))), p, x) +Base.:*(p::InverseTestPlan{T,N,false}, x::AbstractArray) where {T,N} = mul!(similar(x, complex(float(eltype(x)))), p, x) + +Base.:*(p::TestPlan{T,N,true}, x::AbstractArray) where {T,N} = copy!(x, dft!(similar(x), x, p.region, -1)) +Base.:*(p::InverseTestPlan{T,N,true}, x::AbstractArray) where {T,N} = copy!(x, dft!(similar(x), x, p.region, 1)) mutable struct TestRPlan{T,N} <: Plan{T} region