Skip to content

Commit

Permalink
Let filt/filt! return state if provided as input
Browse files Browse the repository at this point in the history
  • Loading branch information
martinholters committed Dec 3, 2024
1 parent 710ee46 commit 6d5e1e5
Show file tree
Hide file tree
Showing 4 changed files with 95 additions and 39 deletions.
48 changes: 29 additions & 19 deletions src/Filters/filt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,7 @@ Same as [`filt()`](@ref) but writes the result into the `out`
argument. Output array `out` may not be an alias of `x`, i.e. filtering may
not be done in place.
"""
filt!(out, f::PolynomialRatio{:z}, x::AbstractArray, si=_zerosi(f, x)) =
filt!(out, coefb(f), coefa(f), x, si)
filt!(out, f::PolynomialRatio{:z}, x::AbstractArray, si...) = filt!(out, coefb(f), coefa(f), x, si...)

"""
filt(f::FilterCoefficients{:z}, x::AbstractArray[, si])
Expand All @@ -32,7 +31,7 @@ to zeros). If `f` is a `PolynomialRatio`, `Biquad`, or
interpreted as an FIR filter, and a naïve or FFT-based algorithm is
selected based on the data and filter length.
"""
filt(f::PolynomialRatio{:z}, x, si=_zerosi(f, x)) = filt(coefb(f), coefa(f), x, si)
filt(f::PolynomialRatio{:z}, x, si...) = filt(coefb(f), coefa(f), x, si...)

## SecondOrderSections
_zerosi(f::SecondOrderSections{:z,T,G}, ::AbstractArray{S}) where {T,G,S} =
Expand All @@ -58,25 +57,29 @@ function _filt!(out::AbstractArray, si::AbstractArray{S,N}, f::SecondOrderSectio
si
end

filt!(out::AbstractArray, f::SecondOrderSections{:z}, x::AbstractArray) =
first(filt!(out, f, x, _zerosi(f, x)))
function filt!(out::AbstractArray, f::SecondOrderSections{:z}, x::AbstractArray,
si::AbstractArray{S,N}=_zerosi(f, x)) where {S,N}
si::AbstractArray{S,N}) where {S,N}
biquads = f.biquads

size(x) != size(out) && throw(DimensionMismatch("out size must match x"))
(size(si, 1) != 2 || size(si, 2) != length(biquads) || (N > 2 && size(si)[3:end] != size(x)[2:end])) &&
throw(ArgumentError("si must be 2 x nbiquads or 2 x nbiquads x nsignals"))

initial_si = si
si = similar(si, axes(si)[1:2])
if N > 2
si = copy(si)
else
si = repeat(si, outer=(1, 1, size(x)[2:end]...))
end
for col in CartesianIndices(axes(x)[2:end])
copyto!(si, view(initial_si, :, :, N > 2 ? col : CartesianIndex()))
_filt!(out, si, f, x, col)
_filt!(out, view(si, :, :, col), f, x, col)
end
out
return (out, si)
end

filt(f::SecondOrderSections{:z,T,G}, x::AbstractArray{S}, si=_zerosi(f, x)) where {T,G,S<:Number} =
filt!(similar(x, promote_type(T, G, S)), f, x, si)
filt(f::SecondOrderSections{:z,T,G}, x::AbstractArray{S}, si...) where {T,G,S<:Number} =
filt!(similar(x, promote_type(T, G, S)), f, x, si...)

## Biquad
_zerosi(::Biquad{:z,T}, ::AbstractArray{S}) where {T,S} =
Expand All @@ -95,25 +98,32 @@ function _filt!(out::AbstractArray, si1::Number, si2::Number, f::Biquad{:z},
(si1, si2)
end

filt!(out::AbstractArray, f::Biquad{:z}, x::AbstractArray) =
first(filt!(out, f, x, _zerosi(f, x)))
# filt! variant that preserves si
function filt!(out::AbstractArray, f::Biquad{:z}, x::AbstractArray,
si::AbstractArray{S,N}=_zerosi(f, x)) where {S,N}
si::AbstractArray{S,N}) where {S,N}
size(x) != size(out) && throw(DimensionMismatch("out size must match x"))
(size(si, 1) != 2 || (N > 1 && size(si)[2:end] != size(x)[2:end])) &&
throw(ArgumentError("si must have two rows and 1 or nsignals columns"))

if N > 1
si = copy(si)
else
si = repeat(si, outer=(1, size(x)[2:end]...))
end
for col in CartesianIndices(axes(x)[2:end])
_filt!(out, si[1, N > 1 ? col : CartesianIndex()], si[2, N > 1 ? col : CartesianIndex()], f, x, col)
si[:,col] .= _filt!(out, si[1, col], si[2, col], f, x, col)
end
out
return (out, si)
end

filt(f::Biquad{:z,T}, x::AbstractArray{S}, si=_zerosi(f, x)) where {T,S<:Number} =
filt!(similar(x, promote_type(T, S)), f, x, si)
filt(f::Biquad{:z,T}, x::AbstractArray{S}, si...) where {T,S<:Number} =
filt!(similar(x, promote_type(T, S)), f, x, si...)

## For arbitrary filters, convert to SecondOrderSections
filt(f::FilterCoefficients{:z}, x) = filt(convert(SecondOrderSections, f), x)
filt!(out, f::FilterCoefficients{:z}, x) = filt!(out, convert(SecondOrderSections, f), x)
filt(f::FilterCoefficients{:z}, x, si...) = filt(convert(SecondOrderSections, f), x, si...)
filt!(out, f::FilterCoefficients{:z}, x, si...) = filt!(out, convert(SecondOrderSections, f), x, si...)

"""
DF2TFilter(coef::FilterCoefficients{:z})
Expand Down Expand Up @@ -375,7 +385,7 @@ function filtfilt(f::SecondOrderSections{:z,T,G}, x::AbstractArray{S}) where {T,
istart = 1
for i = 1:Base.trailingsize(x, 2)
extrapolate_signal!(extrapolated, 1, x, istart, size(x, 1), pad_length)
reverse!(filt!(extrapolated, f, extrapolated, mul!(zitmp, zi, extrapolated[1])))
reverse!(first(filt!(extrapolated, f, extrapolated, mul!(zitmp, zi, extrapolated[1]))))
filt!(extrapolated, f, extrapolated, mul!(zitmp, zi, extrapolated[1]))
for j = 1:size(x, 1)
@inbounds out[j, i] = extrapolated[end-pad_length+1-j]
Expand Down
38 changes: 25 additions & 13 deletions src/dspbase.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,11 @@ state vector `si` (defaults to zeros).
Inputs that are `Number`s are treated as one-element `Vector`s.
"""
function filt(b::Union{AbstractVector, Number}, a::Union{AbstractVector, Number},
x::AbstractArray{T}, si::AbstractArray{S} = _zerosi(b,a,T)) where {T,S}
x::AbstractArray{T}) where {T}
first(filt(b, a, x, _zerosi(b,a,T)))
end
function filt(b::Union{AbstractVector, Number}, a::Union{AbstractVector, Number},
x::AbstractArray{T}, si::AbstractArray{S}) where {T,S}
filt!(similar(x, promote_type(eltype(b), eltype(a), T, S)), b, a, x, si)
end

Expand All @@ -29,8 +33,11 @@ end
Same as [`filt`](@ref) but writes the result into the `out` argument, which may
alias the input `x` to modify it in-place.
"""
filt!(out::AbstractArray, b::Union{AbstractVector, Number}, a::Union{AbstractVector, Number},
x::AbstractArray{T}) where {T} =
first(filt!(out, b, a, x, _zerosi(b,a,T)))
function filt!(out::AbstractArray, b::Union{AbstractVector, Number}, a::Union{AbstractVector, Number},
x::AbstractArray{T}, si::AbstractArray{S,N} = _zerosi(b,a,T)) where {T,S,N}
x::AbstractArray{T}, si::AbstractArray{S,N}) where {T,S,N}
isempty(b) && throw(ArgumentError("filter vector b must be non-empty"))
isempty(a) && throw(ArgumentError("filter vector a must be non-empty"))
a[1] == 0 && throw(ArgumentError("filter vector a[1] must be nonzero"))
Expand All @@ -49,8 +56,17 @@ function filt!(out::AbstractArray, b::Union{AbstractVector, Number}, a::Union{Ab
throw(ArgumentError("initial state si must be a vector or have the same number of columns as x"))
end

iszero(size(x, 1)) && return out
isone(sz) && return (k = b[1] / a[1]; @noinline mul!(out, x, k)) # Simple scaling without memory
if N > 1
si = copy(si)
else
si = repeat(si, outer=(1, size(x)[2:end]...))
end

iszero(size(x, 1)) && return (out, si)
if isone(sz)
k = b[1] / a[1]
return (@noinline mul!(out, x, k), si) # Simple scaling without memory
end

# Filter coefficient normalization
if !isone(a[1])
Expand All @@ -65,19 +81,15 @@ function filt!(out::AbstractArray, b::Union{AbstractVector, Number}, a::Union{Ab
if as == 1 && bs <= SMALL_FILT_CUTOFF
_small_filt_fir!(out, b, x, si, Val(bs))
else
initial_si = si
si = similar(si, axes(si, 1))
for col in CartesianIndices(axes(x)[2:end])
# Reset the filter state
copyto!(si, view(initial_si, :, N > 1 ? col : CartesianIndex()))
if as > 1
_filt_iir!(out, b, a, x, si, col)
_filt_iir!(out, b, a, x, view(si, :, col), col)
else
_filt_fir!(out, b, x, si, col)
_filt_fir!(out, b, x, view(si, :, col), col)
end
end
end
return out
return (out, si)
end

# Transposed direct form II
Expand Down Expand Up @@ -146,13 +158,13 @@ end
# Convert array filter tap input to tuple for small-filtering
function _small_filt_fir!(
out::AbstractArray, h::AbstractVector, x::AbstractArray,
si::AbstractArray{S,N}, ::Val{bs}) where {S,N,bs}
si::AbstractArray, ::Val{bs}) where {bs}

bs < 2 && throw(ArgumentError("invalid tuple size"))
length(h) != bs && throw(ArgumentError("length(h) does not match bs"))
b = ntuple(j -> h[j], Val(bs))
for col in CartesianIndices(axes(x)[2:end])
v_si = N > 1 ? view(si, :, col) : si
v_si = view(si, :, col)
_filt_fir!(out, b, x, v_si, col)
end
end
Expand Down
6 changes: 3 additions & 3 deletions test/dsp.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,16 +20,16 @@ using DSP: filt, filt!, deconv, conv, xcorr,
@test filt(b, 1., [x 1.0:8.0]) == [filt(b, 1., x) filt(b, 1., 1.0:8.0)]
@test filt(b, [1., -0.5], [x 1.0:8.0]) == [filt(b, [1., -0.5], x) filt(b, [1., -0.5], 1.0:8.0)]
si = zeros(3)
@test filt(b, 1., [x 1.0:8.0], si) == [filt(b, 1., x, si) filt(b, 1., 1.0:8.0, si)]
@test first(filt(b, 1., [x 1.0:8.0], si)) == [first(filt(b, 1., x, si)) first(filt(b, 1., 1.0:8.0, si))]
@test si == zeros(3) # Will likely fail if/when arrayviews are implemented
si = [zeros(3) ones(3)]
@test filt(b, 1., [x 1.0:8.0], si) == [filt(b, 1., x, zeros(3)) filt(b, 1., 1.0:8.0, ones(3))]
@test first(filt(b, 1., [x 1.0:8.0], si)) == [first(filt(b, 1., x, zeros(3))) first(filt(b, 1., 1.0:8.0, ones(3)))]
# With initial conditions: a lowpass 5-pole butterworth filter with W_n = 0.25,
# and a stable initial filter condition matched to the initial value
b = [0.003279216306360201,0.016396081531801006,0.03279216306360201,0.03279216306360201,0.016396081531801006,0.003279216306360201]
a = [1.0,-2.4744161749781606,2.8110063119115782,-1.703772240915465,0.5444326948885326,-0.07231566910295834]
si = [0.9967207836936347,-1.4940914728163142,1.2841226760316475,-0.4524417279474106,0.07559488540931815]
@test filt(b, a, ones(10), si) ones(10) # Shouldn't affect DC offset
@test first(filt(b, a, ones(10), si)) ones(10) # Shouldn't affect DC offset

@test_throws ArgumentError filt!([1, 2], [1], [1], [1])
end
Expand Down
42 changes: 38 additions & 4 deletions test/filt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -80,10 +80,10 @@ end
@test all(col -> col y_ref, eachslice(filt(Biquad(PolynomialRatio(b, a)), x); dims=slicedims))
@test all(col -> col y_ref, eachslice(filt(SecondOrderSections(PolynomialRatio(b, a)), x); dims=slicedims))
# with si given
@test all(col -> col y_ref, eachslice(filt(b, a, x, zeros(1, sz[2:end]...)); dims=slicedims))
@test all(col -> col y_ref, eachslice(filt(PolynomialRatio(b, a), x, zeros(1, sz[2:end]...)); dims=slicedims))
@test all(col -> col y_ref, eachslice(filt(Biquad(PolynomialRatio(b, a)), x, zeros(2, sz[2:end]...)); dims=slicedims))
@test all(col -> col y_ref, eachslice(filt(SecondOrderSections(PolynomialRatio(b, a)), x, zeros(2, 1, sz[2:end]...)); dims=slicedims))
@test all(col -> col y_ref, eachslice(first(filt(b, a, x, zeros(1, sz[2:end]...))); dims=slicedims))
@test all(col -> col y_ref, eachslice(first(filt(PolynomialRatio(b, a), x, zeros(1, sz[2:end]...))); dims=slicedims))
@test all(col -> col y_ref, eachslice(first(filt(Biquad(PolynomialRatio(b, a)), x, zeros(2, sz[2:end]...))); dims=slicedims))
@test all(col -> col y_ref, eachslice(first(filt(SecondOrderSections(PolynomialRatio(b, a)), x, zeros(2, 1, sz[2:end]...))); dims=slicedims))
# use _small_filt_fir!
b = [0.1, 0.1]
a = [1.0]
Expand Down Expand Up @@ -192,6 +192,40 @@ end
@test matlab_filt x
end

@testset "blockwise filt for $T" for T in [PolynomialRatio, ZeroPoleGain, SecondOrderSections], extra_dims in [(), (2,), (2, 3)]
x = rand(1000, extra_dims...)
H = T(PolynomialRatio([0.1, 0.1], [1, 0.8]))
y_ref = filt(H, x)
if T == PolynomialRatio
state = DSP.Filters._zerosi(H, x)
else
state = DSP.Filters._zerosi(SecondOrderSections(H), x)
end
y_test = similar(x)
all_cols = map(_ -> :, extra_dims)
for i in 1:100:size(x,1)
y_test[i:i+99, all_cols...], state = filt(H, x[i:i+99, all_cols...], state)
end
@test y_ref y_test
end

@testset "blockwise filt! for $T" for T in [PolynomialRatio, ZeroPoleGain, SecondOrderSections], extra_dims in [(), (2,), (2, 3)]
x = rand(1000, extra_dims...)
H = T(PolynomialRatio([0.1, 0.1], [1, 0.8]))
y_ref = filt(H, x)
if T == PolynomialRatio
state = DSP.Filters._zerosi(H, x)
else
state = DSP.Filters._zerosi(SecondOrderSections(H), x)
end
y_test = similar(x)
all_cols = map(_ -> :, extra_dims)
for i in 1:100:size(x,1)
_, state = filt!(view(y_test, i:i+99, all_cols...), H, x[i:i+99, all_cols...], state)
end
@test y_ref y_test
end

#######################################
#
# Test 1d filtfilt against matlab results
Expand Down

0 comments on commit 6d5e1e5

Please sign in to comment.