Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Rename ProjectionStyle -> AdjointStyle, _mul -> AdjointMul, and improve docs #109

Merged
merged 9 commits into from
Jul 27, 2023
4 changes: 4 additions & 0 deletions docs/src/api.md
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,10 @@ AbstractFFTs.plan_brfft
AbstractFFTs.plan_irfft
AbstractFFTs.fftdims
Base.adjoint
AbstractFFTs.FFTAdjointStyle
AbstractFFTs.RFFTAdjointStyle
AbstractFFTs.IRFFTAdjointStyle
AbstractFFTs.UnitaryAdjointStyle
AbstractFFTs.fftshift
AbstractFFTs.fftshift!
AbstractFFTs.ifftshift
Expand Down
9 changes: 5 additions & 4 deletions docs/src/implementations.md
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,11 @@ To define a new FFT implementation in your own module, you should

* You can also define similar methods of `plan_rfft` and `plan_brfft` for real-input FFTs.

* To enable automatic computation of adjoint plans via [`Base.adjoint`](@ref) (used in rules for reverse-mode differentiation), define the trait `AbstractFFTs.ProjectionStyle(::MyPlan)`, which can return:
* `AbstractFFTs.NoProjectionStyle()`,
* `AbstractFFTs.RealProjectionStyle()`, for plans that halve one of the output's dimensions analogously to [`rfft`](@ref),
* `AbstractFFTs.RealInverseProjectionStyle(d::Int)`, for plans that expect an input with a halved dimension analogously to [`irfft`](@ref), where `d` is the original length of the dimension.
* We offer an experimental `AdjointStyle` trait to enable automatic computation of adjoint plans via [`Base.adjoint`](@ref).
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why is the trait experimental? And what exactly is experimental? I guess downstream packages would be interested in what parts (if any) are considered stable, if every patch release may break things etc.

To support adjoints in a new plan, define the trait `AbstractFFTs.AdjointStyle(::MyPlan)`. This should return a subtype of `AS <: AbstractFFTs.AdjointStyle` supporting `AbstractFFTs.adjoint_mul(::Plan, ::AbstractArray, ::AS)` and
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I.e., adjoint_mul is part of the API?

`AbstractFFTs._output_size(::Plan, ::AS)`.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I assume _output_size should be implemented as well? Could we just rename it to output_size? To me the underscore suggests that it is an internal method.


`AbstractFFTs` pre-implements the following adjoint styles: [`AbstractFFTs.FFTAdjointStyle`](@ref), [`AbstractFFTs.RFFTAdjointStyle`](@ref), [`AbstractFFTs.IRFFTAdjointStyle`](@ref), and [`AbstractFFTs.UnitaryAdjointStyle`](@ref).
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What is meant with pre-implements? AFAICT the package implements these styles? Or is there anything missing?


The normalization convention for your FFT should be that it computes ``y_k = \sum_j x_j \exp(-2\pi i j k/n)`` for a transform of
length ``n``, and the "backwards" (unnormalized inverse) transform computes the same thing but with ``\exp(+2\pi i jk/n)``.
93 changes: 68 additions & 25 deletions src/definitions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -583,17 +583,58 @@ plan_brfft

##############################################################################

struct NoProjectionStyle end
struct RealProjectionStyle end
struct RealInverseProjectionStyle
abstract type AdjointStyle end

"""
FFTAdjointStyle()

Projection style for complex to complex discrete Fourier transforms that normalize
the output analogously to [`fft`](@ref).

Since the Fourier transform is unitary up to a scaling, the adjoint simply applies
the transform's inverse with an appropriate scaling.
"""
struct FFTAdjointStyle <: AdjointStyle end

"""
RFFTAdjointStyle()

Projection style for real to complex discrete Fourier transforms that halve
one of the output's dimensions and normalize the output analogously to [`rfft`](@ref).

Since the Fourier transform is unitary up to a scaling, the adjoint applies the transform's
inverse, but with additional logic to handle the fact that the output is projected
to exploit its conjugate symmetry (see [`rfft`](@ref)).
"""
struct RFFTAdjointStyle <: AdjointStyle end

"""
IRFFTAdjointStyle(d::Dim)

Projection style for complex to real discrete Fourier transforms that expect
an input with a halved dimension and normalize the output analogously to [`irfft`](@ref),
where `d` is the original length of the dimension.

Since the Fourier transform is unitary up to a scaling, the adjoint applies the transform's
inverse, but with additional logic to handle the fact that the input is projected
to exploit its conjugate symmetry (see [`irfft`](@ref)).
"""
struct IRFFTAdjointStyle <: AdjointStyle
dim::Int
end
const ProjectionStyle = Union{NoProjectionStyle, RealProjectionStyle, RealInverseProjectionStyle}

output_size(p::Plan) = _output_size(p, ProjectionStyle(p))
_output_size(p::Plan, ::NoProjectionStyle) = size(p)
_output_size(p::Plan, ::RealProjectionStyle) = rfft_output_size(size(p), fftdims(p))
_output_size(p::Plan, s::RealInverseProjectionStyle) = brfft_output_size(size(p), s.dim, fftdims(p))
"""
UnitaryAdjointStyle()

Projection style for unitary transforms, whose adjoint equals their inverse.
"""
struct UnitaryAdjointStyle <: AdjointStyle end

output_size(p::Plan) = _output_size(p, AdjointStyle(p))
_output_size(p::Plan, ::FFTAdjointStyle) = size(p)
_output_size(p::Plan, ::RFFTAdjointStyle) = rfft_output_size(size(p), fftdims(p))
_output_size(p::Plan, s::IRFFTAdjointStyle) = brfft_output_size(size(p), s.dim, fftdims(p))
_output_size(p::Plan, ::UnitaryAdjointStyle) = size(p)

struct AdjointPlan{T,P<:Plan} <: Plan{T}
p::P
Expand All @@ -620,40 +661,42 @@ Base.adjoint(p::ScaledPlan) = ScaledPlan(p.p', p.scale)
size(p::AdjointPlan) = output_size(p.p)
output_size(p::AdjointPlan) = size(p.p)

Base.:*(p::AdjointPlan, x::AbstractArray) = _mul(p, x, ProjectionStyle(p.p))
Base.:*(p::AdjointPlan, x::AbstractArray) = adjoint_mul(p.p, x, AdjointStyle(p.p))

function _mul(p::AdjointPlan{T}, x::AbstractArray, ::NoProjectionStyle) where {T}
dims = fftdims(p.p)
N = normalization(T, size(p.p), dims)
return (p.p \ x) / N
function adjoint_mul(p::Plan{T}, x::AbstractArray, ::FFTAdjointStyle) where {T}
dims = fftdims(p)
N = normalization(T, size(p), dims)
return (p \ x) / N
end

function _mul(p::AdjointPlan{T}, x::AbstractArray, ::RealProjectionStyle) where {T<:Real}
dims = fftdims(p.p)
N = normalization(T, size(p.p), dims)
function adjoint_mul(p::Plan{T}, x::AbstractArray, ::RFFTAdjointStyle) where {T<:Real}
dims = fftdims(p)
N = normalization(T, size(p), dims)
halfdim = first(dims)
d = size(p.p, halfdim)
n = output_size(p.p, halfdim)
d = size(p, halfdim)
n = output_size(p, halfdim)
scale = reshape(
[(i == 1 || (i == n && 2 * (i - 1)) == d) ? N : 2 * N for i in 1:n],
ntuple(i -> i == halfdim ? n : 1, Val(ndims(x)))
)
return p.p \ (x ./ convert(typeof(x), scale))
return p \ (x ./ convert(typeof(x), scale))
end

function _mul(p::AdjointPlan{T}, x::AbstractArray, ::RealInverseProjectionStyle) where {T}
dims = fftdims(p.p)
N = normalization(real(T), output_size(p.p), dims)
function adjoint_mul(p::Plan{T}, x::AbstractArray, ::IRFFTAdjointStyle) where {T}
dims = fftdims(p)
N = normalization(real(T), output_size(p), dims)
halfdim = first(dims)
n = size(p.p, halfdim)
d = output_size(p.p, halfdim)
n = size(p, halfdim)
d = output_size(p, halfdim)
scale = reshape(
[(i == 1 || (i == n && 2 * (i - 1)) == d) ? 1 : 2 for i in 1:n],
ntuple(i -> i == halfdim ? n : 1, Val(ndims(x)))
)
return (convert(typeof(x), scale) ./ N) .* (p.p \ x)
return (convert(typeof(x), scale) ./ N) .* (p \ x)
end

adjoint_mul(p::Plan, x::AbstractArray, ::UnitaryAdjointStyle) = p \ x

# Analogously to ScaledPlan, define both plan_inv (for no caching) and inv (caches inner plan only).
plan_inv(p::AdjointPlan) = adjoint(plan_inv(p.p))
inv(p::AdjointPlan) = adjoint(inv(p.p))
10 changes: 5 additions & 5 deletions test/testplans.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@ Base.ndims(::TestPlan{T,N}) where {T,N} = N
Base.size(p::InverseTestPlan) = p.sz
Base.ndims(::InverseTestPlan{T,N}) where {T,N} = N

AbstractFFTs.ProjectionStyle(::TestPlan) = AbstractFFTs.NoProjectionStyle()
AbstractFFTs.ProjectionStyle(::InverseTestPlan) = AbstractFFTs.NoProjectionStyle()
AbstractFFTs.AdjointStyle(::TestPlan) = AbstractFFTs.FFTAdjointStyle()
AbstractFFTs.AdjointStyle(::InverseTestPlan) = AbstractFFTs.FFTAdjointStyle()

function AbstractFFTs.plan_fft(x::AbstractArray{T}, region; kwargs...) where {T}
return TestPlan{T}(region, size(x))
Expand Down Expand Up @@ -110,8 +110,8 @@ mutable struct InverseTestRPlan{T,N,G} <: Plan{Complex{T}}
end
end

AbstractFFTs.ProjectionStyle(::TestRPlan) = AbstractFFTs.RealProjectionStyle()
AbstractFFTs.ProjectionStyle(p::InverseTestRPlan) = AbstractFFTs.RealInverseProjectionStyle(p.d)
AbstractFFTs.AdjointStyle(::TestRPlan) = AbstractFFTs.RFFTAdjointStyle()
AbstractFFTs.AdjointStyle(p::InverseTestRPlan) = AbstractFFTs.IRFFTAdjointStyle(p.d)

function AbstractFFTs.plan_rfft(x::AbstractArray{T}, region; kwargs...) where {T<:Real}
return TestRPlan{T}(region, size(x))
Expand Down Expand Up @@ -241,7 +241,7 @@ end

Base.size(p::InplaceTestPlan) = size(p.plan)
Base.ndims(p::InplaceTestPlan) = ndims(p.plan)
AbstractFFTs.ProjectionStyle(p::InplaceTestPlan) = AbstractFFTs.ProjectionStyle(p.plan)
AbstractFFTs.AdjointStyle(p::InplaceTestPlan) = AbstractFFTs.AdjointStyle(p.plan)

function AbstractFFTs.plan_fft!(x::AbstractArray, region; kwargs...)
return InplaceTestPlan(plan_fft(x, region; kwargs...))
Expand Down