Skip to content

Commit

Permalink
initial broadcast based repeat
Browse files Browse the repository at this point in the history
  • Loading branch information
awadell1 committed Jul 1, 2022
1 parent e8103f2 commit 7b82b02
Showing 1 changed file with 37 additions and 66 deletions.
103 changes: 37 additions & 66 deletions src/host/base.jl
Original file line number Diff line number Diff line change
@@ -1,82 +1,53 @@
# common Base functionality
import Base: _RepeatInnerOuter

# Kernel for the "inner" repeat: each invocation reads one element of `xs` and
# writes it into the `inner`-sized block of `out` that belongs to that element.
#
# - `ctx`:   kernel context (kept under this name for the index macro).
# - `xs`:    source array.
# - `inner`: per-dimension repetition counts.
# - `out`:   destination array.
function repeat_inner_kernel!(
    ctx::AbstractKernelContext,
    xs::AbstractArray{<:Any, N},
    inner::NTuple{N, Int},
    out::AbstractArray{<:Any, N}
) where {N}
    # Source position handled by this invocation and the value stored there.
    src = @cartesianidx xs
    @inbounds elem = xs[src]

    # Visit every offset inside the repeat block of this element.
    for rep in CartesianIndices(inner)
        # Block origin is `(src - 1) * inner`; `rep` is the 1-based offset within it.
        dest = CartesianIndex(
            ntuple(d -> @inbounds((src[d] - 1) * inner[d] + rep[d]), N)
        )
        @inbounds out[dest] = elem
    end

    return nothing
end

# Repeat every element of `xs` `inner[d]` times along dimension `d`.
# Returns a new array of size `inner .* size(xs)` with the element type of `xs`.
function repeat_inner(xs::AnyGPUArray, inner)
# Destination is allocated up front; its contents are written by the kernel below.
out = similar(xs, eltype(xs), inner .* size(xs))
# Return before launching the kernel when any output dimension is zero.
any(==(0), size(out)) && return out # consistent with `Base.repeat`
# One work item per source element (`elements=length(xs)`).
gpu_call(repeat_inner_kernel!, xs, inner, out; elements=length(xs))
return out
# Overload methods used by `Base.repeat`.
# `repeat_inner_outer` is only specialized for 1-D arrays (see below); otherwise `Base`
# implements it as `repeat_outer(repeat_inner(arr, inner), outer)`.
# `Base.repeat` hook for the "inner" repetition of a GPU array; forwards the counts
# to `_repeat` as its `inner` keyword.
function _RepeatInnerOuter.repeat_inner(xs::AnyGPUArray{<:Any, N}, inner) where {N}
    return _repeat(xs; inner=inner)
end

# Kernel for the "outer" repeat: each invocation reads one element of `xs` and
# writes it to the matching position of every tile in `out`.
#
# - `ctx`:    kernel context.
# - `xs`:     source array; the element handled here is selected by `@cartesianidx`.
# - `xssize`: per-dimension offset between successive tiles (the caller passes `size(xs)`).
# - `outer`:  number of tiles along each dimension.
# - `out`:    destination array.
function repeat_outer_kernel!(
ctx::AbstractKernelContext,
xs::AbstractArray{<:Any, N},
xssize::NTuple{N},
outer::NTuple{N},
out::AbstractArray{<:Any, N}
) where {N}
# Get index to input element
idx = @cartesianidx xs
@inbounds val = xs[idx]

# Loop over repeat indices, copying val to out
for rdx in CartesianIndices(outer)
# Get destination CartesianIndex
# Tile `rdx` starts `xssize * (rdx - 1)` past the origin; `idx` is the position inside it.
odx = ntuple(N) do i
@inbounds idx[i] + xssize[i] * (rdx[i] -1)
end
@inbounds out[CartesianIndex(odx)] = val
end

return nothing
# `Base.repeat` hook for the "outer" repetition of a GPU array; forwards the counts
# to `_repeat` as its `outer` keyword.
function _RepeatInnerOuter.repeat_outer(
    xs::AnyGPUArray{<:Any, N}, outer::NTuple{N}
) where {N}
    return _repeat(xs; outer=outer)
end

# Tile `xs` `outer[d]` times along dimension `d`.
# Returns a new array of size `outer .* size(xs)` with the element type of `xs`.
function repeat_outer(xs::AnyGPUArray, outer)
# Destination is allocated up front; its contents are written by the kernel below.
out = similar(xs, eltype(xs), outer .* size(xs))
# Return before launching the kernel when any output dimension is zero.
any(==(0), size(out)) && return out # consistent with `Base.repeat`
# One work item per source element; `size(xs)` is passed as the kernel's `xssize`.
gpu_call(repeat_outer_kernel!, xs, size(xs), outer, out; elements=length(xs))
return out
# `Base.repeat` hook performing inner and outer repetition of a 1-D GPU array in a
# single `_repeat` call.
#
# `inner` and `outer` are restricted to tuples on purpose: `Base` also defines
# `repeat_inner_outer` methods that take `nothing` for either count. Leaving these
# arguments untyped makes this method ambiguous with them for 1-D GPU arrays (this
# one is more specific in `xs`, Base's in the `nothing` argument). With the `Tuple`
# constraint, calls where a count is `nothing` dispatch to Base's methods instead.
function _RepeatInnerOuter.repeat_inner_outer(
    xs::AnyGPUArray{<:Any, 1}, inner::Tuple, outer::Tuple
)
    return _repeat(xs; inner=inner, outer=outer)
end

# Overload methods used by `Base.repeat`.
# No need to implement `repeat_inner_outer` since this is implemented in `Base` as
# `repeat_outer(repeat_inner(arr, inner), outer)`.
# Route `Base.repeat`'s inner step for GPU arrays to the local `repeat_inner`.
_RepeatInnerOuter.repeat_inner(xs::AnyGPUArray{<:Any, N}, dims::NTuple{N}) where {N} =
    repeat_inner(xs, dims)
# Repeat `x` `counts[d]` times along each dimension `d`, filling the result with a
# single broadcast assignment. Returns a new array obtained from `similar(x, ...)`.
function _repeat(x::AbstractArray, counts::Integer...)
# Result rank: enough dimensions for `x` and for every supplied count.
N = max(ndims(x), length(counts))
# Result size: each dimension of `x` scaled by its count (1 when no count is given).
size_y = ntuple(d -> size(x,d) * get(counts, d, 1), N)
# `x` expanded to 2N dimensions: its own sizes on odd dims, singletons on even dims.
size_x2 = ntuple(d -> isodd(d) ? size(x, 1+d÷2) : 1, 2*N)

function _RepeatInnerOuter.repeat_outer(xs::AnyGPUArray{<:Any, N}, dims::NTuple{N}) where {N}
return repeat_outer(xs, dims)
end
## version without mutation
# ignores = ntuple(d -> reshape(Base.OneTo(counts[d]), ntuple(_->1, 2d-1)..., :), length(counts))
# y = reshape(broadcast(first∘tuple, reshape(x, size_x2), ignores...), size_y)

function _RepeatInnerOuter.repeat_outer(xs::AnyGPUArray{<:Any, 1}, dims::Tuple{Any})
return repeat_outer(xs, dims)
# ## version with mutation
# Result expanded to 2N dimensions: sizes of `x` on odd dims, repeat counts on even dims.
size_y2 = ntuple(d -> isodd(d) ? size(x, 1+d÷2) : get(counts, d÷2, 1), 2*N)
y = similar(x, size_y)
# The singleton dims of the reshaped `x` broadcast across the count dims of `y`.
reshape(y, size_y2) .= reshape(x, size_x2)
y
end

# 2-D method taking any 2-tuple of counts (`NTuple{2, Any}`); forwards to `repeat_outer`.
function _RepeatInnerOuter.repeat_outer(xs::AnyGPUArray{<:Any, 2}, dims::NTuple{2, Any})
return repeat_outer(xs, dims)
# Repeat `x` with per-dimension `inner` (element-wise) and `outer` (whole-array)
# counts, filling the result with a single broadcast assignment.
#
# Every dimension `k` of the result is split into three consecutive axes of lengths
# `inner[k]`, `size(x, k)` and `outer[k]`; `x` is reshaped so that only the middle
# axis of each triple is non-singleton, and broadcasting fills the other two.
# Counts missing from `inner` or `outer` default to 1. Returns a new array obtained
# from `similar(x, ...)`.
function _repeat(x::AbstractArray; inner=1, outer=1)
    nd = max(ndims(x), length(inner), length(outer))

    # `divrem(j + 2, 3)` maps axis `j` of the expanded shape to `(dimension, slot)`,
    # where slot 0 is the inner axis, slot 1 the source axis and slot 2 the outer axis.
    expanded_dst = ntuple(3 * nd) do j
        k, slot = divrem(j + 2, 3)
        if slot == 0
            get(inner, k, 1)
        elseif slot == 1
            size(x, k)
        else
            get(outer, k, 1)
        end
    end
    expanded_src = ntuple(3 * nd) do j
        k, slot = divrem(j + 2, 3)
        slot == 1 ? size(x, k) : 1
    end

    dst = similar(x, ntuple(k -> size(x, k) * get(inner, k, 1) * get(outer, k, 1), nd))
    reshape(dst, expanded_dst) .= reshape(x, expanded_src)
    return dst
end

## PermutedDimsArrays
Expand Down

0 comments on commit 7b82b02

Please sign in to comment.