Skip to content

Commit

Permalink
Add spectrogram
Browse files Browse the repository at this point in the history
  • Loading branch information
pxl-th committed Jun 5, 2024
1 parent 01cc8a2 commit bd76e9f
Show file tree
Hide file tree
Showing 4 changed files with 133 additions and 4 deletions.
3 changes: 3 additions & 0 deletions docs/src/reference.md
Original file line number Diff line number Diff line change
Expand Up @@ -171,4 +171,7 @@ hann_window
hamming_window
stft
istft
spectrogram
NNlib.power_to_db
NNlib.db_to_power
```
3 changes: 2 additions & 1 deletion src/NNlib.jl
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,7 @@ include("rotation.jl")
export imrotate, ∇imrotate

include("audio/stft.jl")
export stft, istft, hann_window, hamming_window
include("audio/spectrogram.jl")
export stft, istft, hann_window, hamming_window, spectrogram

end # module NNlib
89 changes: 89 additions & 0 deletions src/audio/spectrogram.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
"""
    spectrogram(waveform;
        pad::Int = 0, n_fft::Int, hop_length::Int, window,
        center::Bool = true, power::Real = 2.0,
        normalized::Bool = false, window_normalized::Bool = false)

Create a spectrogram or a batch of spectrograms from a raw audio signal.

# Arguments

- `waveform`:
    Raw signal. The first dimension is the signal itself,
    any trailing dimensions are treated as batch dimensions.
- `pad::Int`:
    The amount of zero-padding to apply on both sides of the signal
    (along the first dimension). Must be ≥ 0.
    Default is `0`.
- `window_normalized::Bool`:
    Whether to scale the spectrum by `1 / norm(window)`,
    i.e. by the window's L2 energy.
    Default is `false`.
- `power::Real`:
    Exponent for the magnitude spectrogram (must be ≥ 0)
    e.g., `1` for magnitude, `2` for power, etc.
    If `0`, complex spectrum is returned instead.

See [`stft`](@ref) for other arguments.

# Returns

Spectrogram with its first two dimensions laid out as in the [`stft`](@ref)
output (frequency bins, `F = n_fft ÷ 2 + 1`, and window hops, `T`),
followed by the batch dimensions of `waveform`.
The result is real-valued when `power > 0` and complex-valued when `power == 0`.

# Throws

- `ArgumentError` if `pad` or `power` is negative.

# Example

```julia
julia> waveform, sampling_rate = load("test.flac");

julia> spec = spectrogram(waveform;
           n_fft=1024, hop_length=128, window=hann_window(1024));

julia> spec_db = NNlib.power_to_db(spec);

julia> Makie.heatmap(spec_db[:, :, 1])
```
"""
function spectrogram(waveform;
    pad::Int = 0, n_fft::Int, hop_length::Int, window,
    center::Bool = true, power::Real = 2.0,
    normalized::Bool = false, window_normalized::Bool = false,
)
    # Validate up front: previously a negative `pad` was silently ignored and
    # a negative `power` silently produced the complex spectrum.
    pad ≥ 0 || throw(ArgumentError("`pad` must be non-negative, got $pad"))
    power ≥ 0 || throw(ArgumentError("`power` must be non-negative, got $power"))

    if pad > 0
        waveform = pad_zeros(waveform, pad; dims=1)
    end

    # Pack batch dimensions into a single one before calling `stft`.
    sz = size(waveform)
    spec_ = stft(reshape(waveform, (sz[1], :));
        n_fft, hop_length, window, center, normalized)
    # Unpack batch dimensions.
    spec = reshape(spec_, (size(spec_)[1:2]..., sz[2:end]...))

    # Out-of-place scaling: in-place `.*=` cannot be differentiated by
    # reverse-mode AD backends that reject array mutation.
    if window_normalized
        spec = spec .* inv(norm(window))
    end

    if power > 0
        # Exponent in the real counterpart of the (complex) spectrum eltype,
        # so the result does not get promoted to a wider float type.
        p = real(eltype(spec))(power)
        spec = abs.(spec) .^ p
    end
    return spec
end

"""
    power_to_db(s; ref::Real = 1f0, amin::Real = 1f-10, top_db::Real = 80f0)

Convert a power spectrogram (amplitude squared) to decibel (dB) units.

# Arguments

- `s`: Input power.
- `ref`: Scalar w.r.t. which the input is scaled. Default is `1`.
- `amin`: Minimum threshold for `s` (and for `ref`). Default is `1f-10`.
- `top_db`: Threshold the output at `top_db` below the peak:
    `max.(s_db, maximum(s_db) - top_db)`. Default is `80`.

# Returns

`s_db ~= 10 * log10(s) - 10 * log10(ref)`
"""
function power_to_db(s; ref::Real = 1f0, amin::Real = 1f-10, top_db::Real = 80f0)
    # Scalar kernel: floor the value (and the reference) at `amin`,
    # then express the ratio in decibels.
    to_db(x) = 10f0 * (log10(max(amin, x)) - log10(max(amin, ref)))

    s_db = to_db.(s)
    # Nothing is allowed to fall more than `top_db` below the loudest entry.
    floor_db = maximum(s_db) - top_db
    return max.(s_db, floor_db)
end

"""
    db_to_power(s_db; ref::Real = 1f0)

Convert decibel (dB) values back to power, scaled by `ref`:
`ref * 10^(s_db / 10)`.

Inverse of [`power_to_db`](@ref).
"""
function db_to_power(s_db; ref::Real = 1f0)
    # Scalar kernel, broadcast over the input.
    to_power(x) = ref * 10f0^(x * 0.1f0)
    return to_power.(s_db)
end
42 changes: 39 additions & 3 deletions test/testsuite/spectral.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
using NNlib

function spectral_testsuite(Backend)
cpu(x) = adapt(CPU(), x)
device(x) = adapt(Backend(), x)

@testset "Window functions" begin
for window_fn in (hann_window, hamming_window)
@inferred window_fn(10, Float32)
Expand All @@ -18,9 +21,6 @@ function spectral_testsuite(Backend)
end

@testset "STFT" begin
cpu(x) = adapt(CPU(), x)
device(x) = adapt(Backend(), x)

for batch in ((), (3,))
@testset "Batch $batch" begin
x = device(ones(Float32, 16, batch...))
Expand Down Expand Up @@ -85,4 +85,40 @@ function spectral_testsuite(Backend)
end
end
end

@testset "Spectrogram" begin
    x = device(rand(Float32, 1024))
    window = device(hann_window(1024))

    # With the default `power = 2`, the spectrogram must equal
    # the squared magnitude of the STFT computed with the same arguments.
    y = stft(x;
        n_fft=1024, hop_length=128, window,
        center=true, normalized=false)
    spec = spectrogram(x;
        n_fft=1024, hop_length=128, window,
        center=true, normalized=false)

    @test abs.(y).^2 ≈ spec

    # Batched: every batch slice must match the un-batched computation.
    x = device(rand(Float32, 1024, 3))
    spec = spectrogram(x;
        n_fft=1024, hop_length=128, window,
        center=true, normalized=false)
    for i in 1:3
        y = stft(x[:, i];
            n_fft=1024, hop_length=128, window,
            center=true, normalized=false)
        @test abs.(y).^2 ≈ spec[:, :, i]
    end
end

@testset "Power to dB" begin
    x = device(rand(Float32, 1024))
    window = device(hann_window(1024))
    spec = spectrogram(x; pad=0, n_fft=1024, hop_length=128, window)

    # Round trip: power -> dB -> power should reproduce the input.
    @test spec ≈ NNlib.db_to_power(NNlib.power_to_db(spec))
    # Both conversions must be type-stable.
    @inferred NNlib.power_to_db(spec)
    @inferred NNlib.db_to_power(NNlib.power_to_db(spec))
end
end

0 comments on commit bd76e9f

Please sign in to comment.