Skip to content

Commit

Permalink
Disallow mixing offset and non-offset axes in conv input (#586)
Browse files Browse the repository at this point in the history
  • Loading branch information
martinholters authored Nov 20, 2024
1 parent 56773aa commit 3d87980
Show file tree
Hide file tree
Showing 3 changed files with 32 additions and 17 deletions.
2 changes: 1 addition & 1 deletion ext/OffsetArraysExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,6 @@ module OffsetArraysExt
import DSP
import OffsetArrays

DSP.conv_with_offset(::OffsetArrays.IdOffsetRange) = true
DSP.conv_axis_with_offset(::OffsetArrays.IdOffsetRange) = true

end
34 changes: 21 additions & 13 deletions src/dspbase.jl
Original file line number Diff line number Diff line change
Expand Up @@ -660,8 +660,16 @@ function _conv_td!(out, output_indices, u::AbstractArray{<:Number, N}, v::Abstra
end

# whether the given axis are to be considered to carry an offset for `conv!` and `conv`
conv_with_offset(::Base.OneTo) = false
conv_with_offset(a::Any) = throw(ArgumentError("unsupported axis type $(typeof(a))"))
conv_axis_with_offset(::Base.OneTo) = false
conv_axis_with_offset(a::Any) = throw(ArgumentError("unsupported axis type $(typeof(a))"))

function conv_axes_with_offset(as::Tuple...)
with_offset = ((map(a -> map(conv_axis_with_offset, a), as)...)...,)
if !allequal(with_offset)
throw(ArgumentError("cannot mix offset and non-offset axes"))
end
return !isempty(with_offset) && first(with_offset)
end

const FFTTypes = Union{Float32, Float64, ComplexF32, ComplexF64}

Expand All @@ -677,7 +685,7 @@ offsets. If none of them has offset axes,
`size(out,d) ≥ size(u,d) + size(v,d) - 1` must hold. If both input and output
have offset axes, `firstindex(out,d) ≤ firstindex(u,d) + firstindex(v,d)` and
`lastindex(out,d) ≥ lastindex(u,d) + lastindex(v,d)` must hold (for d = 1,...,N).
A mix of offset and non-offset axes between input and output is not permitted.
A mix of offset and non-offset axes is not permitted.
The `algorithm` keyword allows choosing the algorithm to use:
* `:direct`: Evaluates the convolution sum in time domain.
Expand All @@ -704,12 +712,8 @@ function conv!(
v::AbstractArray{<:Number, N};
algorithm=:auto
) where {T<:Number, N}
offset = conv_axes_with_offset(axes(out), axes(u), axes(v)) ? 0 : 1
output_indices = CartesianIndices(map(axes(out), axes(u), axes(v)) do ao, au, av
input_has_offset = conv_with_offset(au) || conv_with_offset(av)
if input_has_offset !== conv_with_offset(ao)
throw(ArgumentError("output must have offset axes if and only if the input has"))
end
offset = input_has_offset ? 0 : 1
return (first(au)+first(av) : last(au)+last(av)) .- offset
end)

Expand Down Expand Up @@ -752,9 +756,13 @@ function conv!(
end
end

conv_output_axis(au, av) =
conv_with_offset(au) || conv_with_offset(av) ?
(first(au)+first(av):last(au)+last(av)) : Base.OneTo(last(au) + last(av) - 1)
function conv_output_axes(au::Tuple, av::Tuple)
if conv_axes_with_offset(au, av)
return map((au, av) -> first(au)+first(av):last(au)+last(av), au, av)
else
return map((au, av) -> Base.OneTo(last(au) + last(av) - 1), au, av)
end
end

"""
conv(u, v; algorithm)
Expand All @@ -768,7 +776,7 @@ function conv(
u::AbstractArray{Tu, N}, v::AbstractArray{Tv, N}; kwargs...
) where {Tu<:Number, Tv<:Number, N}
T = promote_type(Tu, Tv)
out_axes = map(conv_output_axis, axes(u), axes(v))
out_axes = conv_output_axes(axes(u), axes(v))
out = similar(u, T, out_axes)
return conv!(out, u, v; kwargs...)
end
Expand All @@ -792,7 +800,7 @@ Uses 2-D FFT algorithm.
"""
function conv(u::AbstractVector{T}, v::Transpose{T,<:AbstractVector}, A::AbstractMatrix{T}) where T
# Arbitrary indexing offsets not implemented
if any(conv_with_offset, (axes(u)..., axes(v)..., axes(A)...))
if any(conv_axis_with_offset, (axes(u)..., axes(v)..., axes(A)...))
throw(ArgumentError("offset axes not supported"))
end
m = length(u)+size(A,1)-1
Expand Down
13 changes: 10 additions & 3 deletions test/dsp.jl
Original file line number Diff line number Diff line change
Expand Up @@ -72,13 +72,19 @@ end

offset_arr = OffsetVector{Int}(undef, -1:2)
offset_arr[:] = a
@test conv(offset_arr, 1:3) == OffsetVector(expectation, 0:5)
@test_throws ArgumentError conv(offset_arr, 1:3)
@test conv(offset_arr, OffsetArray(1:3)) == OffsetVector(expectation, 0:5)
offset_arr_f = OffsetVector{Float64}(undef, -1:2)
offset_arr_f[:] = fa
@test conv(offset_arr_f, 1:3) OffsetVector(fexp, 0:5)
@test_throws ArgumentError conv(offset_arr_f, 1:3)
@test conv(offset_arr_f, OffsetArray(1:3)) OffsetVector(fexp, 0:5)
@test_throws ArgumentError conv!(zeros(6), offset_arr, 1:3) # output needs to be OA, too
@test_throws ArgumentError conv!(OffsetVector{Int}(undef, 1:6), 1:4, 1:3) # output mustn't be OA

@test conv(fa, fill(true)) == conv(fill(true), fa) == fa
@test_broken conv(offset_arr_f, fill(true)) == conv(fill(true), offset_arr_f) == offset_arr_f
@test conv(fill(true), fill(true)) == fill(true)

for M in [10, 200], N in [10, 200], T in [Float64, ComplexF64]
u = rand(T, M)
v = rand(T, N)
Expand Down Expand Up @@ -156,7 +162,8 @@ end

offset_arr = OffsetMatrix{Int}(undef, -1:1, -1:1)
offset_arr[:] = a
@test conv(offset_arr, b) == OffsetArray(expectation, 0:3, 0:3)
@test_throws ArgumentError conv(offset_arr, b)
@test conv(offset_arr, OffsetArray(b)) == OffsetArray(expectation, 0:3, 0:3)

for (M1, M2) in [(10, 20), (190, 200)], (N1, N2) in [(20, 10), (210, 200)], T in [Float64, ComplexF64]
u = rand(T, M1, M2)
Expand Down

0 comments on commit 3d87980

Please sign in to comment.