Skip to content

Commit

Permalink
And KA backend for fold/unfold
Browse files Browse the repository at this point in the history
  • Loading branch information
pxl-th committed Jul 12, 2024
1 parent 52f22f9 commit 7e25ad6
Show file tree
Hide file tree
Showing 8 changed files with 174 additions and 175 deletions.
1 change: 0 additions & 1 deletion ext/NNlibCUDAExt/NNlibCUDAExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@ include("activations.jl")
include("batchedadjtrans.jl")
include("batchedmul.jl")
include("ctc.jl")
include("fold.jl")
include("scatter.jl")
include("utils.jl")

Expand Down
111 changes: 0 additions & 111 deletions ext/NNlibCUDAExt/fold.jl

This file was deleted.

4 changes: 2 additions & 2 deletions ext/NNlibFFTWExt/stft.jl
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ function NNlib.stft(x;
ids = [
row + hop_length * col
for row in 1:n_fft, col in 0:(n_frames - 1)]
x = x[ids, ntuple(_ -> Colon(), ndims(x) - 1)...]
x = @inbounds x[ids, ntuple(_ -> Colon(), ndims(x) - 1)...]
end

region = 1
Expand Down Expand Up @@ -113,7 +113,7 @@ function NNlib.istft(y;
# In case of batched input, reshaped it (n_fft, n_frames, batch) -> (:, batch).
nd = ntuple(_ -> Colon(), ndims(x) - 2)
ndims(x) == 3 && (x = reshape(x, (:, size(x, 3)));)
x = x[ids, nd...]
x = @inbounds x[ids, nd...]

# Trim padding.
left = center ? (n_fft ÷ 2 + 1) : 1
Expand Down
2 changes: 1 addition & 1 deletion src/audio/spectrogram.jl
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ function spectrogram(waveform;
window_normalized && (spec = spec .* inv(norm(window));)

if power > 0
p = real(eltype(spec)(power))
p = eltype(waveform)(power)
spec = abs.(spec).^p
end
return spec
Expand Down
135 changes: 119 additions & 16 deletions src/fold.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@

"""
unfold(x, kernel_size; stride = 1, pad = 0, dilation = 0, flipped = true)
Expand All @@ -7,10 +6,10 @@ window_size, batchsize)`. The window size is determined by the `prod(spatial dim
of kernel)*input_channels`. The number of sliding windows will match those of
convolution (`conv`) with the same kernel_size and arguments. Note that
by default `conv` flips the spatial dimensions of its kernel (default
`flipped=false`), whereas `unfold` does not (default `flipped=true`).
Uses `NNlib.im2col!` as backend.
`flipped=false`), whereas `unfold` does not (default `flipped=true`).
Uses `NNlib.im2col!` as backend.
See also [`fold`](@ref), the adjoint/transpose operator
See also [`fold`](@ref), the adjoint/transpose operator
and a potential inverse of `unfold`.
# Example
Expand All @@ -23,7 +22,7 @@ julia> w = reshape([1 0 -1], 3, 1, 1); # 1D conv kernel of length 3
julia> kws = (pad=1, stride=2, flipped=true); # use same args for conv and unfold
julia> z = NNlib.unfold(x, size(w); kws...)
julia> z = NNlib.unfold(x, size(w); kws...)
4×3×1 Array{Int64, 3}:
[:, :, 1] =
0 100 2
Expand Down Expand Up @@ -61,8 +60,8 @@ end
The adjoint/transpose operator of `unfold`. It accumulates sliding windows from
the output of `unfold` into a container tensor of size `output_size`. An inverse
to `unfold` may be obtained (in some cases) by using `fold` and accounting for scaling issues
with a divisor (see example). Uses `NNlib.col2im!` as backend.
to `unfold` may be obtained (in some cases) by using `fold` and accounting for scaling issues
with a divisor (see example). Uses `NNlib.col2im!` as backend.
See also [`unfold`](@ref).
Expand Down Expand Up @@ -101,7 +100,7 @@ julia> divisor = NNlib.fold(NNlib.unfold(ones(size(x)...), (3,1,1)), size(x), (3
2.0
1.0
julia> z ./ divisor
julia> z ./ divisor
7×1×1 Array{Float64, 3}:
[:, :, 1] =
100.0
Expand Down Expand Up @@ -133,30 +132,30 @@ function unfold(x::AbstractArray{T, N}, cdims::DenseConvDims) where {T, N}
end

function fold(y::AbstractArray{T, 3}, output_size::NTuple, cdims::DenseConvDims) where {T}
x = similar(y, output_size)
x = similar(y, output_size)
return fold!(x, y, cdims)
end

# N < 5 -dimension in-place versions
# N < 5 -dimension in-place versions
function unfold!(y::AbstractArray{yT, 3}, x::AbstractArray{xT, N}, cdims::DenseConvDims) where {yT, xT, N}
unfold!(
y,
insert_singleton_spatial_dimension(x, 5-N),
insert_singleton_spatial_dimension(cdims, 5-N),
y,
insert_singleton_spatial_dimension(x, 5-N),
insert_singleton_spatial_dimension(cdims, 5-N),
)
return y
end

function fold!(x::AbstractArray{xT, N}, y::AbstractArray{yT, 3}, cdims::DenseConvDims) where {yT, xT, N}
fold!(
insert_singleton_spatial_dimension(x, 5-N),
insert_singleton_spatial_dimension(x, 5-N),
y,
insert_singleton_spatial_dimension(cdims, 5-N),
insert_singleton_spatial_dimension(cdims, 5-N),
)
return x
end

# 5-dimension in-place versions
# 5-dimension in-place versions
function unfold!(y::AbstractArray{yT, 3}, x::AbstractArray{xT, 5}, cdims::DenseConvDims) where {yT, xT}
@threads for batch_idx in 1:size(x, 5)
y_slice = view(y, :, :, batch_idx)
Expand All @@ -173,6 +172,110 @@ function fold!(x::AbstractArray{xT, 5}, y::AbstractArray{yT, 3}, cdims::DenseCon
return x
end

@kernel function unfold_kernel!(

Check warning on line 175 in src/fold.jl

View check run for this annotation

Codecov / codecov/patch

src/fold.jl#L175

Added line #L175 was not covered by tests
col::AbstractArray{T}, x, col_size,
input_size, output_size, kernel_size,
flipkernel, stride, pad_lo, dilation, max_idx,
) where T
index = @index(Global)

Check warning on line 180 in src/fold.jl

View check run for this annotation

Codecov / codecov/patch

src/fold.jl#L180

Added line #L180 was not covered by tests

@inbounds if index max_idx
i, kw, kh, kd, c, b = CartesianIndices(col_size)[index].I # col indices
w, h, d = CartesianIndices(output_size)[i].I # x indices

Check warning on line 184 in src/fold.jl

View check run for this annotation

Codecov / codecov/patch

src/fold.jl#L182-L184

Added lines #L182 - L184 were not covered by tests

# project
w, h, d = @. ((w, h, d) - 1) * stride - pad_lo + 1 + ((kw, kh, kd) - 1) * dilation

Check warning on line 187 in src/fold.jl

View check run for this annotation

Codecov / codecov/patch

src/fold.jl#L187

Added line #L187 was not covered by tests

if !flipkernel
kw, kh, kd = kernel_size .- (kw, kh, kd) .+ 1

Check warning on line 190 in src/fold.jl

View check run for this annotation

Codecov / codecov/patch

src/fold.jl#L189-L190

Added lines #L189 - L190 were not covered by tests
end

# check out of bounds
if !all(checkindex.(Bool, UnitRange.(1, input_size), (w, h, d)))
col[i, kw, kh, kd, c, b] = T(0)

Check warning on line 195 in src/fold.jl

View check run for this annotation

Codecov / codecov/patch

src/fold.jl#L194-L195

Added lines #L194 - L195 were not covered by tests
else
xval::T = x[w, h, d, c, b]
col[i, kw, kh, kd, c, b] = xval

Check warning on line 198 in src/fold.jl

View check run for this annotation

Codecov / codecov/patch

src/fold.jl#L197-L198

Added lines #L197 - L198 were not covered by tests
end
end
end

@kernel function fold_kernel!(

Check warning on line 203 in src/fold.jl

View check run for this annotation

Codecov / codecov/patch

src/fold.jl#L203

Added line #L203 was not covered by tests
x::AbstractArray{T}, col, col_size,
input_size, output_size, kernel_size,
flipkernel, stride, pad_lo, dilation, max_idx,
) where T
index = @index(Global)

Check warning on line 208 in src/fold.jl

View check run for this annotation

Codecov / codecov/patch

src/fold.jl#L208

Added line #L208 was not covered by tests

@inbounds if index max_idx
i, kw, kh, kd, c, b = CartesianIndices(col_size)[index].I # col indices
w, h, d = CartesianIndices(output_size)[i].I # x indices

Check warning on line 212 in src/fold.jl

View check run for this annotation

Codecov / codecov/patch

src/fold.jl#L210-L212

Added lines #L210 - L212 were not covered by tests

# project
w, h, d = @. ((w, h, d) - 1) * stride - pad_lo + 1 + ((kw, kh, kd) - 1) * dilation

Check warning on line 215 in src/fold.jl

View check run for this annotation

Codecov / codecov/patch

src/fold.jl#L215

Added line #L215 was not covered by tests

# check out of bounds
if all(checkindex.(Bool, UnitRange.(1, input_size), (w, h, d)))
if !flipkernel
kw, kh, kd = kernel_size .- (kw, kh, kd) .+ 1

Check warning on line 220 in src/fold.jl

View check run for this annotation

Codecov / codecov/patch

src/fold.jl#L218-L220

Added lines #L218 - L220 were not covered by tests
end

cval::T = col[i, kw, kh, kd, c, b]
@atomic x[w, h, d, c, b] += cval

Check warning on line 224 in src/fold.jl

View check run for this annotation

Codecov / codecov/patch

src/fold.jl#L223-L224

Added lines #L223 - L224 were not covered by tests
end
end
end

function unfold!(
col::AnyGPUArray{cT,3}, x::AnyGPUArray{xT,5}, cdims::DenseConvDims,
) where {cT, xT}
spatial_dims(cdims) != 3 && throw(DimensionMismatch(
"unfold!() only accepts 3d convoluitional inputs"))

C_in = channels_in(cdims)
ker_size = kernel_size(cdims)
pad_w_lo, pad_w_hi, pad_h_lo, pad_h_hi, pad_d_lo, pad_d_hi = padding(cdims)
pad_lo = (pad_w_lo, pad_h_lo, pad_d_lo)

out_size = output_size(cdims)
col_reshaped = reshape(col, (prod(out_size), ker_size..., C_in, :))

max_idx = prod(size(col))
unfold_kernel!(get_backend(x))(
col_reshaped, x, size(col_reshaped),
input_size(cdims), out_size, ker_size,
flipkernel(cdims), stride(cdims), pad_lo, dilation(cdims), max_idx;
ndrange=max_idx)
return col
end

function fold!(
x::AnyGPUArray{xT,5}, col::AnyGPUArray{cT,3}, cdims::DenseConvDims,
) where {xT, cT}
spatial_dims(cdims) != 3 && throw(DimensionMismatch(
"fold!() only accepts 3d convoluitional inputs"))

# going to accumulate into x
fill!(x, xT(0))

C_in = channels_in(cdims)
ker_size = kernel_size(cdims)
pad_w_lo, pad_w_hi, pad_h_lo, pad_h_hi, pad_d_lo, pad_d_hi = padding(cdims)
pad_lo = (pad_w_lo, pad_h_lo, pad_d_lo)
out_size = output_size(cdims)

col_reshaped = reshape(col, (prod(out_size), ker_size..., C_in, :))

max_idx = prod(size(col))
fold_kernel!(get_backend(x))(
x, col_reshaped, size(col_reshaped),
input_size(cdims), out_size, ker_size,
flipkernel(cdims), stride(cdims), pad_lo, dilation(cdims), max_idx;
ndrange=max_idx)

return x
end

# reverse diff rules
function rrule(::typeof(unfold), x, cdims::DenseConvDims; kw...)
function unfold_pullback(Δ)
Expand Down
40 changes: 0 additions & 40 deletions test/fold.jl

This file was deleted.

Loading

0 comments on commit 7e25ad6

Please sign in to comment.