From 186636d3e3ba247c3096104385996c3d4fe6c545 Mon Sep 17 00:00:00 2001 From: Joao Felipe Santos Date: Fri, 5 Jun 2015 20:25:49 -0400 Subject: [PATCH] istft diff applied to current master --- doc/util.rst | 9 ++++++++ src/periodograms.jl | 19 +++------------- src/util.jl | 54 ++++++++++++++++++++++++++++++++++++++++++++- test/util.jl | 35 +++++++++++++++++++++++++++++ 4 files changed, 100 insertions(+), 17 deletions(-) diff --git a/doc/util.rst b/doc/util.rst index ce318ff0b..ae2039b42 100644 --- a/doc/util.rst +++ b/doc/util.rst @@ -22,6 +22,15 @@ \hat{x}`, where :math:`\hat{x}` is the Hilbert transform of x, along the first dimension of x. +.. function:: istft(S, wlen, overlap; nfft=nextfastfft(wlen), window=nothing) + + Computes the inverse short-time Fourier transform (STFT) of S (a complex + matrix, as computed by `stft`). `wlen` and `overlap` are respectively the + window length and overlap used for computing the STFT. `nfft` is the used + FFT size (defaults to `nextfastfft(wlen)`) and `window` can be either a + function or a vector with the window elements (defaults to a rectangular + window). + .. function:: fftfreq(n, fs=1) Return discrete fourier transform sample frequencies. The returned diff --git a/src/periodograms.jl b/src/periodograms.jl index 44fb9b364..b0ef8c2a2 100644 --- a/src/periodograms.jl +++ b/src/periodograms.jl @@ -162,19 +162,6 @@ function fft2oneortwosided!{T}(out::Array{Complex{T}}, s_fft::Vector{Complex{T}} out end -# Evaluate a window function at n points, returning both the window -# (or nothing if no window) and the squared L2 norm of the window -compute_window(::Nothing, n::Int) = (nothing, n) -function compute_window(window::Function, n::Int) - win = window(n)::Vector{Float64} - norm2 = sumabs2(win) - (win, norm2) -end -function compute_window(window::AbstractVector, n::Int) - length(window) == n || error("length of window must match input") - (window, sumabs2(window)) -end - ## PERIODOGRAMS abstract TFR{T} immutable Periodogram{T,F<:Union(Frequencies,Range)} <: TFR{T} @@ -207,7 +194,7 @@ function periodogram{T<:Number}(s::AbstractVector{T}; onesided::Bool=eltype(s)<: onesided && T <: Complex && error("cannot compute one-sided FFT of a complex signal") nfft >= length(s) || error("nfft must be >= n") - win, norm2 = compute_window(window, length(s)) + win, norm2 = Util.compute_window(window, length(s)) if nfft == length(s) && win == nothing && isa(s, StridedArray) input = s # no need to pad else @@ -286,7 +273,7 @@ function welch_pgram{T<:Number}(s::AbstractVector{T}, n::Int=length(s)>>3, nover onesided && T <: Complex && error("cannot compute one-sided FFT of a complex signal") nfft >= n || error("nfft must be >= n") - win, norm2 = compute_window(window, n) + win, norm2 = Util.compute_window(window, n) sig_split = arraysplit(s, n, noverlap, nfft, win) out = zeros(fftabs2type(T), onesided ? (nfft >> 1)+1 : nfft) r = fs*norm2*length(sig_split) @@ -369,7 +356,7 @@ function stft{T}(s::AbstractVector{T}, n::Int=length(s)>>3, noverlap::Int=n>>1, window::Union(Function,AbstractVector,Nothing)=nothing) onesided && T <: Complex && error("cannot compute one-sided FFT of a complex signal") - win, norm2 = compute_window(window, n) + win, norm2 = Util.compute_window(window, n) sig_split = arraysplit(s, n, noverlap, nfft, win) nout = onesided ? (nfft >> 1)+1 : nfft out = zeros(stfttype(T, psdonly), nout, length(sig_split)) diff --git a/src/util.jl b/src/util.jl index a085cb5c5..a09c777e0 100644 --- a/src/util.jl +++ b/src/util.jl @@ -21,7 +21,8 @@ export unwrap!, rmsfft, unsafe_dot, polyfit, - shiftin! + shiftin!, + istft function unwrap!{T <: FloatingPoint}(m::Array{T}, dim::Integer=ndims(m); range::Number=2pi) @@ -98,6 +99,57 @@ function hilbert{T<:Real}(x::AbstractArray{T}) out end +# Evaluate a window function at n points, returning both the window +# (or nothing if no window) and the squared L2 norm of the window +compute_window(::Nothing, n::Int) = (nothing, n) +function compute_window(window::Function, n::Int) + win = window(n)::Vector{Float64} + norm2 = sumabs2(win) + (win, norm2) +end +function compute_window(window::AbstractVector, n::Int) + length(window) == n || error("length of window must match input") + (window, sumabs2(window)) +end + +backward_plan{T<:Union(Float32, Float64)}(X::AbstractArray{Complex{T}}, Y::AbstractArray{T}) = + FFTW.Plan(X, Y, 1, FFTW.ESTIMATE, FFTW.NO_TIMELIMIT).plan + +function istft{T<:Union(Float32, Float64)}(S::AbstractMatrix{Complex{T}}, wlen::Int, overlap::Int; nfft=nextfastfft(wlen), window::Union(Function,AbstractVector,Nothing)=nothing) + winc = wlen-overlap + win, norm2 = compute_window(window, wlen) + if win != nothing + win² = win.^2 + end + nframes = size(S,2)-1 + outlen = nfft + nframes*winc + out = zeros(T, outlen) + tmp1 = Array(eltype(S), size(S, 1)) + tmp2 = zeros(T, nfft) + p = backward_plan(tmp1, tmp2) + wsum = zeros(outlen) + for k = 1:size(S,2) + copy!(tmp1, 1, S, 1+(k-1)*size(S,1), length(tmp1)) + FFTW.execute(p, tmp1, tmp2) + scale!(tmp2, FFTW.normalization(tmp2)) + if win != nothing + ix = (k-1)*winc + for n=1:nfft + @inbounds out[ix+n] += tmp2[n]*win[n] + @inbounds wsum[ix+n] += win²[n] + end + else + copy!(out, 1+(k-1)*winc, tmp2, 1, nfft) + end + end + if win != nothing + for i=1:length(wsum) + @inbounds wsum[i] != 0 && (out[i] /= wsum[i]) + end + end + out +end + ## FFT TYPES # Get the input element type of FFT for a given type diff --git a/test/util.jl b/test/util.jl index 9ccfea40b..cf3fd300d 100644 --- a/test/util.jl +++ b/test/util.jl @@ -74,6 +74,41 @@ r = round(Int, rand(128)*20) # Test hilbert with 2D input @test_approx_eq h hilbert(a) +## ISTFT + +# rectangular window with 50% overlap and regular sizes +x1 = rand(128) +X1 = stft(x1, 16, 8) +y1 = istft(X1, 16, 8) +@test_approx_eq x1 y1 + +# rectangular window with 50% overlap and irregular size: y will have less +# elements than x unless x is zero-padded to a FFT-friendly size +x2 = rand(171) +X2 = stft(x2, 16, 8) +y2 = istft(X2, 16, 8) +@test_approx_eq x2[1:length(y2)] y2 + +# Hanning window with 25% overlap +# First sample will be wrong since the first element in the window is zero +function hann(M, sym=true) + odd = mod(M,2) == 1 + if !sym && !odd + M = M+1 + end + w = [0.5-0.5*cos(2*pi*n/(M-1)) for n=0:(M-1)] + if !sym && !odd + return w[1:end-1] + else + return w + end +end + +hann_periodic(M::Int) = hann(M, false) +X1w = stft(x1, 16, 12; window=hann_periodic) +y1w = istft(X1w, 16, 12; window=hann_periodic) +@test_approx_eq x1[2:end] y1w[2:end] + ## FFTFREQ @test_approx_eq fftfreq(1) [0.]