Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Rename ProjectionStyle -> AdjointStyle, _mul -> AdjointMul, and improve docs #109

Merged
merged 9 commits into from
Jul 27, 2023
19 changes: 18 additions & 1 deletion docs/src/api.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
# Public Interface

## FFT and FFT planning functions

```@docs
AbstractFFTs.fft
AbstractFFTs.fft!
Expand All @@ -20,11 +22,26 @@ AbstractFFTs.plan_rfft
AbstractFFTs.plan_brfft
AbstractFFTs.plan_irfft
AbstractFFTs.fftdims
Base.adjoint
AbstractFFTs.fftshift
AbstractFFTs.fftshift!
AbstractFFTs.ifftshift
AbstractFFTs.ifftshift!
AbstractFFTs.fftfreq
AbstractFFTs.rfftfreq
Base.size
```

## Adjoint functionality

The following API is supported by plans that support adjoint functionality.
It is also relevant to implementers of FFT plans that wish to support adjoints.
```@docs
Base.adjoint
AbstractFFTs.AdjointStyle
AbstractFFTs.output_size
AbstractFFTs.adjoint_mul
AbstractFFTs.FFTAdjointStyle
AbstractFFTs.RFFTAdjointStyle
AbstractFFTs.IRFFTAdjointStyle
AbstractFFTs.UnitaryAdjointStyle
```
10 changes: 5 additions & 5 deletions docs/src/implementations.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,8 @@ To define a new FFT implementation in your own module, you should
inverse plan.

* Define a new method `AbstractFFTs.plan_fft(x, region; kws...)` that returns a `MyPlan` for at least some types of
`x` and some set of dimensions `region`. The `region` (or a copy thereof) should be accessible via `fftdims(p::MyPlan)` (which defaults to `p.region`).
`x` and some set of dimensions `region`. The `region` (or a copy thereof) should be accessible via `fftdims(p::MyPlan)`
(which defaults to `p.region`), and the input size `size(x)` should be accessible via `size(p::MyPlan)`.

* Define a method of `LinearAlgebra.mul!(y, p::MyPlan, x)` that computes the transform `p` of `x` and stores the result in `y`.

Expand All @@ -32,10 +33,9 @@ To define a new FFT implementation in your own module, you should

* You can also define similar methods of `plan_rfft` and `plan_brfft` for real-input FFTs.

* To enable automatic computation of adjoint plans via [`Base.adjoint`](@ref) (used in rules for reverse-mode differentiation), define the trait `AbstractFFTs.ProjectionStyle(::MyPlan)`, which can return:
* `AbstractFFTs.NoProjectionStyle()`,
* `AbstractFFTs.RealProjectionStyle()`, for plans that halve one of the output's dimensions analogously to [`rfft`](@ref),
* `AbstractFFTs.RealInverseProjectionStyle(d::Int)`, for plans that expect an input with a halved dimension analogously to [`irfft`](@ref), where `d` is the original length of the dimension.
* To support adjoints in a new plan, define the trait [`AbstractFFTs.AdjointStyle`](@ref).
`AbstractFFTs` implements the following adjoint styles: [`AbstractFFTs.FFTAdjointStyle`](@ref), [`AbstractFFTs.RFFTAdjointStyle`](@ref), [`AbstractFFTs.IRFFTAdjointStyle`](@ref), and [`AbstractFFTs.UnitaryAdjointStyle`](@ref).
To define a new adjoint style, define the methods [`AbstractFFTs.adjoint_mul`](@ref) and [`AbstractFFTs.output_size`](@ref).

The normalization convention for your FFT should be that it computes ``y_k = \sum_j x_j \exp(-2\pi i j k/n)`` for a transform of
length ``n``, and the "backwards" (unnormalized inverse) transform computes the same thing but with ``\exp(+2\pi i jk/n)``.
131 changes: 100 additions & 31 deletions src/definitions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,12 @@ abstract type Plan{T} end

eltype(::Type{<:Plan{T}}) where {T} = T

# size(p) should return the size of the input array for p
size(p::Plan, d) = size(p)[d]
output_size(p::Plan, d) = output_size(p)[d]
"""
size(p::Plan, [dim])

Return the size of the input of a plan `p`, optionally at a specified dimenion `dim`.
"""
size(p::Plan, dim) = size(p)[dim]
ndims(p::Plan) = length(size(p))
length(p::Plan) = prod(size(p))::Int

Expand Down Expand Up @@ -583,17 +586,73 @@ plan_brfft

##############################################################################

struct NoProjectionStyle end
struct RealProjectionStyle end
struct RealInverseProjectionStyle
"""
AbstractFFTs.AdjointStyle(::Plan)

Return the adjoint style of a plan, enabling automatic computation of adjoint plans via
[`Base.adjoint`](@ref). Instructions for supporting adjoint styles are provided in the
[implementation instructions](implementations.md#Defining-a-new-implementation).
"""
abstract type AdjointStyle end

"""
FFTAdjointStyle()

Adjoint style for complex to complex discrete Fourier transforms that normalize
the output analogously to [`fft`](@ref).

Since the Fourier transform is unitary up to a scaling, the adjoint simply applies
the transform's inverse with an appropriate scaling.
"""
struct FFTAdjointStyle <: AdjointStyle end

"""
RFFTAdjointStyle()

Adjoint style for real to complex discrete Fourier transforms that halve one of
the output's dimensions and normalize the output analogously to [`rfft`](@ref).

Since the Fourier transform is unitary up to a scaling, the adjoint applies the transform's
inverse, but with appropriate scaling and additional logic to handle the fact that the
output is projected to exploit its conjugate symmetry (see [`rfft`](@ref)).
"""
struct RFFTAdjointStyle <: AdjointStyle end

"""
IRFFTAdjointStyle(d::Dim)

Adjoint style for complex to real discrete Fourier transforms that expect an input
with a halved dimension and normalize the output analogously to [`irfft`](@ref),
where `d` is the original length of the dimension.

Since the Fourier transform is unitary up to a scaling, the adjoint applies the transform's
inverse, but with appropriate scaling and additional logic to handle the fact that the
input is projected to exploit its conjugate symmetry (see [`irfft`](@ref)).
"""
struct IRFFTAdjointStyle <: AdjointStyle
dim::Int
end
const ProjectionStyle = Union{NoProjectionStyle, RealProjectionStyle, RealInverseProjectionStyle}

output_size(p::Plan) = _output_size(p, ProjectionStyle(p))
_output_size(p::Plan, ::NoProjectionStyle) = size(p)
_output_size(p::Plan, ::RealProjectionStyle) = rfft_output_size(size(p), fftdims(p))
_output_size(p::Plan, s::RealInverseProjectionStyle) = brfft_output_size(size(p), s.dim, fftdims(p))
"""
UnitaryAdjointStyle()

Adjoint style for unitary transforms, whose adjoint equals their inverse.
"""
struct UnitaryAdjointStyle <: AdjointStyle end

"""
output_size(p::Plan, [dim])

Return the size of the output of a plan `p`, optionally at a specified dimension `dim`.

Implementations of a new adjoint style `AS <: AbstractFFTs.AdjointStyle` should define `output_size(::Plan, ::AS)`.
"""
output_size(p::Plan) = output_size(p, AdjointStyle(p))
output_size(p::Plan, dim) = output_size(p)[dim]
output_size(p::Plan, ::FFTAdjointStyle) = size(p)
output_size(p::Plan, ::RFFTAdjointStyle) = rfft_output_size(size(p), fftdims(p))
output_size(p::Plan, s::IRFFTAdjointStyle) = brfft_output_size(size(p), s.dim, fftdims(p))
output_size(p::Plan, ::UnitaryAdjointStyle) = size(p)

struct AdjointPlan{T,P<:Plan} <: Plan{T}
p::P
Expand All @@ -604,9 +663,7 @@ end
(p::Plan)'
adjoint(p::Plan)

Form the adjoint operator of an FFT plan. Returns a plan that performs the adjoint operation of
the original plan. Note that this differs from the corresponding backwards plan in the case of real
FFTs due to the halving of one of the dimensions of the FFT output, as described in [`rfft`](@ref).
Return a plan that performs the adjoint operation of the original plan.

!!! note
Adjoint plans do not currently support `LinearAlgebra.mul!`. Further, as a new addition to `AbstractFFTs`,
Expand All @@ -620,40 +677,52 @@ Base.adjoint(p::ScaledPlan) = ScaledPlan(p.p', p.scale)
size(p::AdjointPlan) = output_size(p.p)
output_size(p::AdjointPlan) = size(p.p)

Base.:*(p::AdjointPlan, x::AbstractArray) = _mul(p, x, ProjectionStyle(p.p))
Base.:*(p::AdjointPlan, x::AbstractArray) = adjoint_mul(p.p, x)

"""
adjoint_mul(p::Plan, x::AbstractArray)

Multiply an array `x` by the adjoint of a plan `p`. This is equivalent to `p' * x`.

Implementations of a new adjoint style `AS <: AbstractFFTs.AdjointStyle` should define
`adjoint_mul(::Plan, ::AbstractArray, ::AS)`.
"""
adjoint_mul(p::Plan, x::AbstractArray) = adjoint_mul(p, x, AdjointStyle(p))

function _mul(p::AdjointPlan{T}, x::AbstractArray, ::NoProjectionStyle) where {T}
dims = fftdims(p.p)
N = normalization(T, size(p.p), dims)
return (p.p \ x) / N
function adjoint_mul(p::Plan{T}, x::AbstractArray, ::FFTAdjointStyle) where {T}
dims = fftdims(p)
N = normalization(T, size(p), dims)
return (p \ x) / N
end

function _mul(p::AdjointPlan{T}, x::AbstractArray, ::RealProjectionStyle) where {T<:Real}
dims = fftdims(p.p)
N = normalization(T, size(p.p), dims)
function adjoint_mul(p::Plan{T}, x::AbstractArray, ::RFFTAdjointStyle) where {T<:Real}
dims = fftdims(p)
N = normalization(T, size(p), dims)
halfdim = first(dims)
d = size(p.p, halfdim)
n = output_size(p.p, halfdim)
d = size(p, halfdim)
n = output_size(p, halfdim)
scale = reshape(
[(i == 1 || (i == n && 2 * (i - 1)) == d) ? N : 2 * N for i in 1:n],
ntuple(i -> i == halfdim ? n : 1, Val(ndims(x)))
)
return p.p \ (x ./ convert(typeof(x), scale))
return p \ (x ./ convert(typeof(x), scale))
end

function _mul(p::AdjointPlan{T}, x::AbstractArray, ::RealInverseProjectionStyle) where {T}
dims = fftdims(p.p)
N = normalization(real(T), output_size(p.p), dims)
function adjoint_mul(p::Plan{T}, x::AbstractArray, ::IRFFTAdjointStyle) where {T}
dims = fftdims(p)
N = normalization(real(T), output_size(p), dims)
halfdim = first(dims)
n = size(p.p, halfdim)
d = output_size(p.p, halfdim)
n = size(p, halfdim)
d = output_size(p, halfdim)
scale = reshape(
[(i == 1 || (i == n && 2 * (i - 1)) == d) ? 1 : 2 for i in 1:n],
ntuple(i -> i == halfdim ? n : 1, Val(ndims(x)))
)
return (convert(typeof(x), scale) ./ N) .* (p.p \ x)
return (convert(typeof(x), scale) ./ N) .* (p \ x)
end

adjoint_mul(p::Plan, x::AbstractArray, ::UnitaryAdjointStyle) = p \ x

# Analogously to ScaledPlan, define both plan_inv (for no caching) and inv (caches inner plan only).
plan_inv(p::AdjointPlan) = adjoint(plan_inv(p.p))
inv(p::AdjointPlan) = adjoint(inv(p.p))
10 changes: 5 additions & 5 deletions test/testplans.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@ Base.ndims(::TestPlan{T,N}) where {T,N} = N
Base.size(p::InverseTestPlan) = p.sz
Base.ndims(::InverseTestPlan{T,N}) where {T,N} = N

AbstractFFTs.ProjectionStyle(::TestPlan) = AbstractFFTs.NoProjectionStyle()
AbstractFFTs.ProjectionStyle(::InverseTestPlan) = AbstractFFTs.NoProjectionStyle()
AbstractFFTs.AdjointStyle(::TestPlan) = AbstractFFTs.FFTAdjointStyle()
AbstractFFTs.AdjointStyle(::InverseTestPlan) = AbstractFFTs.FFTAdjointStyle()

function AbstractFFTs.plan_fft(x::AbstractArray{T}, region; kwargs...) where {T}
return TestPlan{T}(region, size(x))
Expand Down Expand Up @@ -110,8 +110,8 @@ mutable struct InverseTestRPlan{T,N,G} <: Plan{Complex{T}}
end
end

AbstractFFTs.ProjectionStyle(::TestRPlan) = AbstractFFTs.RealProjectionStyle()
AbstractFFTs.ProjectionStyle(p::InverseTestRPlan) = AbstractFFTs.RealInverseProjectionStyle(p.d)
AbstractFFTs.AdjointStyle(::TestRPlan) = AbstractFFTs.RFFTAdjointStyle()
AbstractFFTs.AdjointStyle(p::InverseTestRPlan) = AbstractFFTs.IRFFTAdjointStyle(p.d)

function AbstractFFTs.plan_rfft(x::AbstractArray{T}, region; kwargs...) where {T<:Real}
return TestRPlan{T}(region, size(x))
Expand Down Expand Up @@ -241,7 +241,7 @@ end

Base.size(p::InplaceTestPlan) = size(p.plan)
Base.ndims(p::InplaceTestPlan) = ndims(p.plan)
AbstractFFTs.ProjectionStyle(p::InplaceTestPlan) = AbstractFFTs.ProjectionStyle(p.plan)
AbstractFFTs.AdjointStyle(p::InplaceTestPlan) = AbstractFFTs.AdjointStyle(p.plan)

function AbstractFFTs.plan_fft!(x::AbstractArray, region; kwargs...)
return InplaceTestPlan(plan_fft(x, region; kwargs...))
Expand Down