Skip to content

Commit

Permalink
Create WelchConfig object (#502)
Browse files Browse the repository at this point in the history
  • Loading branch information
haberdashPI authored Feb 24, 2024
1 parent 8f0e091 commit 10a7c1e
Show file tree
Hide file tree
Showing 3 changed files with 144 additions and 23 deletions.
2 changes: 2 additions & 0 deletions docs/src/periodograms.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ Common procedures like computing the [short-time Fourier transform](@ref stft),
arraysplit
periodogram(s::AbstractVector{T}) where T <: Number
welch_pgram
welch_pgram!
spectrogram
stft
periodogram(s::AbstractMatrix{T}) where T <: Real
Expand Down Expand Up @@ -35,6 +36,7 @@ mt_coherence!
## Configuration objects

```@docs
WelchConfig
MTConfig
MTSpectrogramConfig
MTCrossSpectraConfig
Expand Down
153 changes: 130 additions & 23 deletions src/periodograms.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@ module Periodograms
using LinearAlgebra: mul!
using ..Util, ..Windows
using Statistics: mean!
export arraysplit, nextfastfft, periodogram, welch_pgram,
export arraysplit, nextfastfft, periodogram,
WelchConfig, welch_pgram, welch_pgram!,
spectrogram, power, freq, stft,
MTConfig, mt_pgram, mt_pgram!,
MTSpectrogramConfig, mt_spectrogram, mt_spectrogram!,
Expand All @@ -24,15 +25,22 @@ struct ArraySplit{T<:AbstractVector,S,W} <: AbstractVector{Vector{S}}
window::W
k::Int

function ArraySplit{Ti,Si,Wi}(s, n, noverlap, nfft, window) where {Ti<:AbstractVector,Si,Wi}
function ArraySplit{Ti,Si,Wi}(s, n, noverlap, nfft, window;
buffer::Vector{Si}=zeros(Si, max(nfft, 0))) where {Ti<:AbstractVector,Si,Wi}

# n = noverlap is a problem - the algorithm will not terminate.
(0 noverlap < n) || error("noverlap must be between zero and n")
nfft >= n || error("nfft must be >= n")
new{Ti,Si,Wi}(s, zeros(Si, nfft), n, noverlap, window, length(s) >= n ? div((length(s) - n), n - noverlap)+1 : 0)
length(buffer) == nfft ||
throw(ArgumentError("buffer length ($(length(buffer))) must equal `nfft` ($nfft)"))

new{Ti,Si,Wi}(s, buffer, n, noverlap, window, length(s) >= n ? div((length(s) - n),
n - noverlap) + 1 : 0)
end

end
ArraySplit(s::AbstractVector, n, noverlap, nfft, window) =
ArraySplit{typeof(s),fftintype(eltype(s)),typeof(window)}(s, n, noverlap, nfft, window)
ArraySplit(s::AbstractVector, n, noverlap, nfft, window; kwargs...) =
ArraySplit{typeof(s),fftintype(eltype(s)),typeof(window)}(s, n, noverlap, nfft, window; kwargs...)

function Base.getindex(x::ArraySplit{T,S,Nothing}, i::Int) where {T,S}
(i >= 1 && i <= x.k) || throw(BoundsError())
Expand All @@ -54,13 +62,14 @@ end
Base.size(x::ArraySplit) = (x.k,)

"""
arraysplit(s, n, m)
arraysplit(s, n, m, nfft=n, window=nothing; buffer=zeros(eltype(s), nfft))
Split an array into arrays of length `n` with overlapping regions
of length `m`. Iterating or indexing the returned AbstractVector
always yields the same Vector with different contents.
Optionally provide a buffer of length `nfft`
"""
arraysplit(s, n, noverlap, nfft=n, window=nothing) = ArraySplit(s, n, noverlap, nfft, window)
arraysplit(s, n, noverlap, nfft=n, window=nothing; kwargs...) = ArraySplit(s, n, noverlap, nfft, window; kwargs...)

## Make collect() return the correct split arrays rather than repeats of the last computed copy
Base.collect(x::ArraySplit) = collect(copy(a) for a in x)
Expand Down Expand Up @@ -354,40 +363,138 @@ forward_plan(X::AbstractArray{T}, Y::AbstractArray{Complex{T}}) where {T<:Union{
forward_plan(X::AbstractArray{T}, Y::AbstractArray{T}) where {T<:Union{ComplexF32, ComplexF64}} =
plan_fft(X)

struct WelchConfig{F,Fr,W,P,T1,T2,R}
nsamples::Int
noverlap::Int
onesided::Bool
nfft::Int
fs::F
freq::Fr
window::W
plan::P
inbuf::T1
outbuf::T2
r::R # inverse normalization
end

"""
WelchConfig(data; n=size(signal, ndims(signal))>>3, noverlap=n>>1,
onesided=eltype(signal)<:Real, nfft=nextfastfft(n),
fs=1, window=nothing)
WelchConfig(nsamples, eltype; n=nsamples>>3, noverlap=n>>1,
onesided=eltype<:Real, nfft=nextfastfft(n),
fs=1, window=nothing)
Captures all configuration options for [`welch_pgram`](@ref) in a single struct (akin to
[`MTConfig`](@ref)). When passed on the second argument of [`welch_pgram`](@ref), computes the
periodogram based on segments with `n` samples with overlap of `noverlap` samples, and
returns a Periodogram object. For a Bartlett periodogram, set `noverlap=0`. See
[`periodogram`](@ref) for description of optional keyword arguments.
!!! note
WelchConfig precomputes an fft plan, and preallocates the necessary intermediate buffers.
Thus, repeated calls to `welch_pgram` that use the same `WelchConfig` object
will be more efficient than otherwise possible.
"""
function WelchConfig(nsamples, ::Type{T}; n::Int=nsamples >> 3, noverlap::Int=n >> 1,
onesided::Bool=T <: Real, nfft::Int=nextfastfft(n),
fs::Real=1, window::Union{Function,AbstractVector,Nothing}=nothing) where T

onesided && T <: Complex && throw(ArgumentError("cannot compute one-sided FFT of a complex signal"))
nfft >= n || throw(DomainError((; nfft, n), "nfft must be >= n"))

win, norm2 = compute_window(window, n)
r = fs * norm2
inbuf = zeros(float(T), nfft)
outbuf = Vector{fftouttype(T)}(undef, T<:Real ? (nfft >> 1)+1 : nfft)
plan = forward_plan(inbuf, outbuf)

freq = onesided ? rfftfreq(nfft, fs) : fftfreq(nfft, fs)

return WelchConfig(n, noverlap, onesided, nfft, fs, freq, win, plan, inbuf, outbuf, r)
end

function WelchConfig(data::AbstractArray; kwargs...)
return WelchConfig(size(data, ndims(data)), eltype(data); kwargs...)
end

# Compute an estimate of the power spectral density of a signal s via Welch's
# method. The resulting periodogram has length N and is computed with an overlap
# region of length M. The method is detailed in "The Use of Fast Fourier Transform
# for the Estimation of Power Spectra: A Method based on Time Averaging over Short,
# Modified Periodograms." P. Welch, IEEE Transactions on Audio and Electroacoustics,
# vol AU-15, pp 70-73, 1967.
"""
welch_pgram(s, n=div(length(s), 8), noverlap=div(n, 2); onesided=eltype(s)<:Real, nfft=nextfastfft(n), fs=1, window=nothing)
welch_pgram(s, n=div(length(s), 8), noverlap=div(n, 2); onesided=eltype(s)<:Real,
nfft=nextfastfft(n), fs=1, window=nothing)
Computes the Welch periodogram of a signal `s` based on segments with `n` samples
with overlap of `noverlap` samples, and returns a Periodogram
object. For a Bartlett periodogram, set `noverlap=0`. See
[`periodogram`](@ref) for description of optional keyword arguments.
"""
function welch_pgram(s::AbstractVector{T}, n::Int=length(s)>>3, noverlap::Int=n>>1;
onesided::Bool=eltype(s)<:Real,
nfft::Int=nextfastfft(n), fs::Real=1,
window::Union{Function,AbstractVector,Nothing}=nothing) where T<:Number
onesided && T <: Complex && error("cannot compute one-sided FFT of a complex signal")
nfft >= n || error("nfft must be >= n")
function welch_pgram(s::AbstractVector, n::Int=length(s)>>3, noverlap::Int=n>>1; kwargs...)
welch_pgram(s, WelchConfig(s; n, noverlap, kwargs...))
end

win, norm2 = compute_window(window, n)
sig_split = arraysplit(s, n, noverlap, nfft, win)
out = zeros(fftabs2type(T), onesided ? (nfft >> 1)+1 : nfft)
r = fs*norm2*length(sig_split)
"""
welch_pgram!(out::AbstractVector, in::AbstractVector, n=div(length(s), 8),
noverlap=div(n, 2); onesided=eltype(s)<:Real, nfft=nextfastfft(n),
fs=1, window=nothing)
Computes the Welch periodogram of a signal `s`, storing the result in `out`, based on
segments with `n` samples with overlap of `noverlap` samples, and returns a Periodogram
object. For a Bartlett periodogram, set `noverlap=0`. See [`periodogram`](@ref) for
description of optional keyword arguments.
"""
function welch_pgram!(output::AbstractVector, s::AbstractVector, n::Int=length(s)>>3, noverlap::Int=n>>1;
kwargs...)
welch_pgram!(output, s, WelchConfig(s; n, noverlap, kwargs...))
end

"""
welch_pgram(signal::AbstractVector, config::WelchConfig)
Computes the Welch periodogram of the given signal using a predefined [`WelchConfig`](@ref) object.
"""
function welch_pgram(s::AbstractVector{T}, config::WelchConfig) where T<:Number
out = Vector{fftabs2type(T)}(undef, config.onesided ? (config.nfft >> 1)+1 : config.nfft)
return welch_pgram_helper!(out, s, config)
end

"""
welch_pgram!(out::AbstractVector, in::AbstractVector, config::WelchConfig)
Computes the Welch periodogram of the given signal, storing the result in `out`,
using a predefined [`WelchConfig`](@ref) object.
"""
function welch_pgram!(out::AbstractVector, s::AbstractVector{T}, config::WelchConfig{T}) where T<:Number
if length(out) != length(config.freq)
throw(DimensionMismatch("""Expected `output` to be of length `length(config.freq)`;
got `length(output)` = $(length(out)) and `length(config.freq)` = $(length(config.freq))"""))
elseif eltype(out) != fftabs2type(T)
throw(ArgumentError("Eltype of output ($(eltype(out))) doesn't match the expected "*
"type: $(fftabs2type(T))."))
end
welch_pgram_helper!(out, s, config)
end

function welch_pgram_helper!(out, s, config)
fill!(out, 0)
sig_split = arraysplit(s, config.nsamples, config.noverlap, config.nfft, config.window;
buffer=config.inbuf)

r = length(sig_split) * config.r

tmp = Vector{fftouttype(T)}(undef, T<:Real ? (nfft >> 1)+1 : nfft)
plan = forward_plan(sig_split.buf, tmp)
for sig in sig_split
mul!(tmp, plan, sig)
fft2pow!(out, tmp, nfft, r, onesided)
mul!(config.outbuf, config.plan, sig)
fft2pow!(out, config.outbuf, config.nfft, r, config.onesided)
end

Periodogram(out, onesided ? rfftfreq(nfft, fs) : fftfreq(nfft, fs))
Periodogram(out, config.freq)
end

## SPECTROGRAM
Expand Down
12 changes: 12 additions & 0 deletions test/periodograms.jl
Original file line number Diff line number Diff line change
Expand Up @@ -217,6 +217,18 @@ end
@test power(welch_pgram(data, length(data), 0; window=hamming, nfft=32)) expected
@test power(spectrogram(data, length(data), 0; window=hamming, nfft=32)) expected

# test welch_pgram configuration object
expected = power(welch_pgram(data, length(data), 0; window=hamming, nfft=32))
config = WelchConfig(data; n=length(data), noverlap=0, window=hamming, nfft=32)
@test power(welch_pgram(data, config)) == expected

# test welch_pgram!
out = similar(expected)
@test power(welch_pgram!(out, data, config)) == expected
@test power(welch_pgram!(out, data, length(data), 0; window=hamming, nfft=32)) == expected
@test_throws ArgumentError welch_pgram!(convert(Vector{Float32}, out), data, config)
@test_throws DimensionMismatch welch_pgram!(empty!(out), data, config)

# Test fftshift
p = periodogram(data)
@test power(p) == power(fftshift(p))
Expand Down

0 comments on commit 10a7c1e

Please sign in to comment.