Skip to content

Commit

Permalink
Implement in-place test plans
Browse files Browse the repository at this point in the history
  • Loading branch information
gaurav-arya committed Aug 31, 2022
1 parent ae78179 commit 388ba18
Showing 1 changed file with 34 additions and 22 deletions.
56 changes: 34 additions & 22 deletions test/TestPlans.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,23 +4,25 @@ import AbstractFFTs
import LinearAlgebra.mul!
using AbstractFFTs: Plan

mutable struct TestPlan{T,N} <: Plan{T}
mutable struct TestPlan{T,N,inplace} <: Plan{T}
region
sz::NTuple{N,Int}
pinv::Plan{T}
function TestPlan{T}(region, sz::NTuple{N,Int}) where {T,N}
return new{T,N}(region, sz)
function TestPlan{T,inplace}(region, sz::NTuple{N,Int}) where {T,N,inplace}
return new{T,N,inplace}(region, sz)
end
end
TestPlan{T}(region, sz) where {T} = TestPlan{T,false}(region, sz)

mutable struct InverseTestPlan{T,N} <: Plan{T}
mutable struct InverseTestPlan{T,N,inplace} <: Plan{T}
region
sz::NTuple{N,Int}
pinv::Plan{T}
function InverseTestPlan{T}(region, sz::NTuple{N,Int}) where {T,N}
return new{T,N}(region, sz)
function InverseTestPlan{T,inplace}(region, sz::NTuple{N,Int}) where {T,N,inplace}
return new{T,N,inplace}(region, sz)
end
end
InverseTestPlan{T}(region, sz) where {T} = InverseTestPlan{T,false}(region, sz)

Base.size(p::TestPlan) = p.sz
Base.ndims(::TestPlan{T,N}) where {T,N} = N
Expand All @@ -34,18 +36,25 @@ function AbstractFFTs.plan_bfft(x::AbstractArray{T}, region; kwargs...) where {T
return InverseTestPlan{T}(region, size(x))
end

function AbstractFFTs.plan_inv(p::TestPlan{T}) where {T}
unscaled_pinv = InverseTestPlan{T}(p.region, p.sz)
N = AbstractFFTs.normalization(T, p.sz, p.region)
unscaled_pinv.pinv = AbstractFFTs.ScaledPlan(p, N)
pinv = AbstractFFTs.ScaledPlan(unscaled_pinv, N)
function AbstractFFTs.plan_fft!(x::AbstractArray{T}, region; kwargs...) where {T}
return TestPlan{T,true}(region, size(x))
end
function AbstractFFTs.plan_bfft!(x::AbstractArray{T}, region; kwargs...) where {T}
return InverseTestPlan{T,true}(region, size(x))
end

function AbstractFFTs.plan_inv(p::TestPlan{T,N,inplace}) where {T,N,inplace}
unscaled_pinv = InverseTestPlan{T,inplace}(p.region, p.sz)
_N = AbstractFFTs.normalization(T, p.sz, p.region)
unscaled_pinv.pinv = AbstractFFTs.ScaledPlan(p, _N)
pinv = AbstractFFTs.ScaledPlan(unscaled_pinv, _N)
return pinv
end
function AbstractFFTs.plan_inv(pinv::InverseTestPlan{T}) where {T}
unscaled_p = TestPlan{T}(pinv.region, pinv.sz)
N = AbstractFFTs.normalization(T, pinv.sz, pinv.region)
unscaled_p.pinv = AbstractFFTs.ScaledPlan(pinv, N)
p = AbstractFFTs.ScaledPlan(unscaled_p, N)
function AbstractFFTs.plan_inv(pinv::InverseTestPlan{T,N,inplace}) where {T,N,inplace}
unscaled_p = TestPlan{T,inplace}(pinv.region, pinv.sz)
_N = AbstractFFTs.normalization(T, pinv.sz, pinv.region)
unscaled_p.pinv = AbstractFFTs.ScaledPlan(pinv, _N)
p = AbstractFFTs.ScaledPlan(unscaled_p, _N)
return p
end

Expand Down Expand Up @@ -80,20 +89,23 @@ function dft!(
end

function mul!(
y::AbstractArray{<:Complex,N}, p::TestPlan, x::AbstractArray{<:Union{Complex,Real},N}
) where {N}
y::AbstractArray{<:Complex,N}, p::TestPlan{T,N,false}, x::AbstractArray{<:Union{Complex,Real},N}
) where {T,N}
size(y) == size(p) == size(x) || throw(DimensionMismatch())
dft!(y, x, p.region, -1)
end
function mul!(
y::AbstractArray{<:Complex,N}, p::InverseTestPlan, x::AbstractArray{<:Union{Complex,Real},N}
) where {N}
y::AbstractArray{<:Complex,N}, p::InverseTestPlan{T,N,false}, x::AbstractArray{<:Union{Complex,Real},N}
) where {T,N}
size(y) == size(p) == size(x) || throw(DimensionMismatch())
dft!(y, x, p.region, 1)
end

Base.:*(p::TestPlan, x::AbstractArray) = mul!(similar(x, complex(float(eltype(x)))), p, x)
Base.:*(p::InverseTestPlan, x::AbstractArray) = mul!(similar(x, complex(float(eltype(x)))), p, x)
Base.:*(p::TestPlan{T,N,false}, x::AbstractArray) where {T,N} = mul!(similar(x, complex(float(eltype(x)))), p, x)
Base.:*(p::InverseTestPlan{T,N,false}, x::AbstractArray) where {T,N} = mul!(similar(x, complex(float(eltype(x)))), p, x)

Base.:*(p::TestPlan{T,N,true}, x::AbstractArray) where {T,N} = copy!(x, dft!(similar(x), x, p.region, -1))
Base.:*(p::InverseTestPlan{T,N,true}, x::AbstractArray) where {T,N} = copy!(x, dft!(similar(x), x, p.region, 1))

mutable struct TestRPlan{T,N} <: Plan{T}
region
Expand Down

0 comments on commit 388ba18

Please sign in to comment.