Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Small fixups to #600 #601

Merged
merged 3 commits into from
Dec 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 4 additions & 7 deletions src/Filters/filt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -61,16 +61,15 @@ end
function filt!(out::AbstractArray, f::SecondOrderSections{:z}, x::AbstractArray,
si::AbstractArray{S,N}=_zerosi(f, x)) where {S,N}
biquads = f.biquads
ncols = Base.trailingsize(x, 2)

size(x) != size(out) && throw(DimensionMismatch("out size must match x"))
(size(si, 1) != 2 || size(si, 2) != length(biquads) || (N > 2 && Base.trailingsize(si, 3) != ncols)) &&
(size(si, 1) != 2 || size(si, 2) != length(biquads) || (N > 2 && size(si)[3:end] != size(x)[2:end])) &&
throw(ArgumentError("si must be 2 x nbiquads or 2 x nbiquads x nsignals"))

initial_si = si
si = similar(si, axes(si)[1:2])
for col in CartesianIndices(axes(x)[2:end])
copyto!(si, view(initial_si, :, :, N > 2 ? col : 1))
copyto!(si, view(initial_si, :, :, N > 2 ? col : CartesianIndex()))
_filt!(out, si, f, x, col)
end
out
Expand Down Expand Up @@ -99,14 +98,12 @@ end
# filt! variant that preserves si
function filt!(out::AbstractArray, f::Biquad{:z}, x::AbstractArray,
si::AbstractArray{S,N}=_zerosi(f, x)) where {S,N}
ncols = Base.trailingsize(x, 2)

size(x) != size(out) && throw(DimensionMismatch("out size must match x"))
(size(si, 1) != 2 || (N > 1 && Base.trailingsize(si, 2) != ncols)) &&
(size(si, 1) != 2 || (N > 1 && size(si)[2:end] != size(x)[2:end])) &&
Comment on lines -105 to +102
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The checks base on trailingsize were more permissive, but the indexing below would then fail if the sizes didn't match, so test up here. (Also, trailingsize was accessing a Base internal.)

throw(ArgumentError("si must have two rows and 1 or nsignals columns"))

for col in CartesianIndices(axes(x)[2:end])
_filt!(out, si[1, N > 1 ? col : 1], si[2, N > 1 ? col : 1], f, x, col)
_filt!(out, si[1, N > 1 ? col : CartesianIndex()], si[2, N > 1 ? col : CartesianIndex()], f, x, col)
end
out
end
Expand Down
9 changes: 4 additions & 5 deletions src/dspbase.jl
Original file line number Diff line number Diff line change
Expand Up @@ -42,11 +42,10 @@ function filt!(out::AbstractArray, b::Union{AbstractVector, Number}, a::Union{Ab
bs = length(b)
sz = max(as, bs)
silen = sz - 1
ncols = size(x, 2)

if size(si, 1) != silen
throw(ArgumentError("initial state vector si must have max(length(a),length(b))-1 rows"))
elseif N > 1 && size(si, 2) != ncols
elseif N > 1 && size(si)[2:end] != size(x)[2:end]
throw(ArgumentError("initial state si must be a vector or have the same number of columns as x"))
end

Expand All @@ -70,7 +69,7 @@ function filt!(out::AbstractArray, b::Union{AbstractVector, Number}, a::Union{Ab
si = similar(si, axes(si, 1))
for col in CartesianIndices(axes(x)[2:end])
# Reset the filter state
copyto!(si, view(initial_si, :, N > 1 ? col : 1))
copyto!(si, view(initial_si, :, N > 1 ? col : CartesianIndex()))
if as > 1
_filt_iir!(out, b, a, x, si, col)
else
Expand Down Expand Up @@ -124,7 +123,7 @@ const SMALL_FILT_VECT_CUTOFF = 19
si_end = Symbol(:si_, silen)

quote
col = colv isa Val{:DF2} ? 1 : colv
col = colv isa Val{:DF2} ? CartesianIndex() : colv
N <= SMALL_FILT_VECT_CUTOFF && checkbounds(siarr, $silen)
Base.@nextract $silen si siarr
for i in axes(x, 1)
Expand Down Expand Up @@ -152,7 +151,7 @@ function _small_filt_fir!(
bs < 2 && throw(ArgumentError("invalid tuple size"))
length(h) != bs && throw(ArgumentError("length(h) does not match bs"))
b = ntuple(j -> h[j], Val(bs))
for col in axes(x, 2)
for col in CartesianIndices(axes(x)[2:end])
v_si = N > 1 ? view(si, :, col) : si
_filt_fir!(out, b, x, v_si, col)
end
Expand Down
20 changes: 16 additions & 4 deletions test/filt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -71,10 +71,22 @@ end
sz = (10, ntuple(n -> n+1, Val(D))...)
y_ref = filt(b, a, ones(sz[1]))
x = ones(sz)
@test all(col -> col ≈ y_ref, eachslice(filt(b, a, x); dims=ntuple(n -> n+1, Val(D))))
@test all(col -> col ≈ y_ref, eachslice(filt(PolynomialRatio(b, a), x); dims=ntuple(n -> n+1, Val(D))))
@test all(col -> col ≈ y_ref, eachslice(filt(Biquad(PolynomialRatio(b, a)), x); dims=ntuple(n -> n+1, Val(D))))
@test all(col -> col ≈ y_ref, eachslice(filt(SecondOrderSections(PolynomialRatio(b, a)), x); dims=ntuple(n -> n+1, Val(D))))
slicedims = ntuple(n -> n+1, Val(D))
@test all(col -> col ≈ y_ref, eachslice(filt(b, a, x); dims=slicedims))
@test all(col -> col ≈ y_ref, eachslice(filt(PolynomialRatio(b, a), x); dims=slicedims))
@test all(col -> col ≈ y_ref, eachslice(filt(Biquad(PolynomialRatio(b, a)), x); dims=slicedims))
@test all(col -> col ≈ y_ref, eachslice(filt(SecondOrderSections(PolynomialRatio(b, a)), x); dims=slicedims))
# with si given
@test all(col -> col ≈ y_ref, eachslice(filt(b, a, x, zeros(1, sz[2:end]...)); dims=slicedims))
@test all(col -> col ≈ y_ref, eachslice(filt(PolynomialRatio(b, a), x, zeros(1, sz[2:end]...)); dims=slicedims))
@test all(col -> col ≈ y_ref, eachslice(filt(Biquad(PolynomialRatio(b, a)), x, zeros(2, sz[2:end]...)); dims=slicedims))
@test all(col -> col ≈ y_ref, eachslice(filt(SecondOrderSections(PolynomialRatio(b, a)), x, zeros(2, 1, sz[2:end]...)); dims=slicedims))
# use _small_filt_fir!
b = [0.1, 0.1]
a = [1.0]
y_ref = filt(b, a, ones(sz[1]))
@test all(col -> col ≈ y_ref, eachslice(filt(b, a, x); dims=slicedims))
@test all(col -> col ≈ y_ref, eachslice(filt(PolynomialRatio(b, a), x); dims=slicedims))
end

#
Expand Down
Loading