# common Base functionality

import Base: _RepeatInnerOuter

# Overload methods used by `Base.repeat`.
# `Base` implements the generic `repeat_inner_outer` as
# `repeat_outer(repeat_inner(arr, inner), outer)`, so overloading `repeat_inner` and
# `repeat_outer` is sufficient for correctness. Vectors additionally get a fused
# `repeat_inner_outer` that fills the result with a single broadcast.

function _RepeatInnerOuter.repeat_inner(xs::AnyGPUArray{<:Any, N}, inner) where {N}
    return _repeat(xs; inner)
end

function _RepeatInnerOuter.repeat_outer(
    xs::AnyGPUArray{<:Any, N}, outer::NTuple{N}
) where {N}
    return _repeat(xs; outer)
end

# NOTE(review): `Base` carries `repeat_outer` methods specialised on `AbstractVector` /
# `AbstractMatrix` with tuple arguments. The two overloads below (same signatures as the
# ones that existed before this rewrite) keep dispatch unambiguous for 1-D and 2-D arrays.
function _RepeatInnerOuter.repeat_outer(xs::AnyGPUArray{<:Any, 1}, outer::Tuple{Any})
    return _repeat(xs; outer)
end

function _RepeatInnerOuter.repeat_outer(xs::AnyGPUArray{<:Any, 2}, outer::NTuple{2, Any})
    return _repeat(xs; outer)
end

function _RepeatInnerOuter.repeat_inner_outer(xs::AnyGPUArray{<:Any, 1}, inner, outer)
    return _repeat(xs; inner, outer)
end

# `Base.repeat` passes `nothing` for an absent `inner`/`outer` keyword and dispatches on
# it. Without the methods below, the untyped vector overload above would be ambiguous
# with Base's `(arr, ::Nothing, outer)`-style methods.
_RepeatInnerOuter.repeat_inner_outer(xs::AnyGPUArray{<:Any, 1}, ::Nothing, ::Nothing) = xs

function _RepeatInnerOuter.repeat_inner_outer(xs::AnyGPUArray{<:Any, 1}, ::Nothing, outer)
    return _repeat(xs; outer)
end

function _RepeatInnerOuter.repeat_inner_outer(xs::AnyGPUArray{<:Any, 1}, inner, ::Nothing)
    return _repeat(xs; inner)
end

"""
    _repeat(x::AbstractArray, counts::Integer...)

Tile `x` `counts[d]` times along dimension `d`; dimensions without a count are repeated
once. The result is allocated with `similar(x, ...)` and filled by one in-place broadcast.
"""
function _repeat(x::AbstractArray, counts::Integer...)
    N = max(ndims(x), length(counts))
    size_y = ntuple(d -> size(x, d) * get(counts, d, 1), N)

    # Interleave each dimension of `x` with a singleton dimension ...
    size_x2 = ntuple(d -> isodd(d) ? size(x, 1 + d ÷ 2) : 1, 2 * N)
    # ... which in the destination is expanded to that dimension's repeat count.
    size_y2 = ntuple(2 * N) do d
        isodd(d) ? size(x, 1 + d ÷ 2) : get(counts, d ÷ 2, 1)
    end

    # A non-mutating variant (broadcasting `x` against dummy ranges) is possible; the
    # in-place form is used here.
    y = similar(x, size_y)
    reshape(y, size_y2) .= reshape(x, size_x2)
    return y
end

"""
    _repeat(x::AbstractArray; inner=1, outer=1)

Repeat every element of `x` `inner[d]` times and then the whole array `outer[d]` times
along dimension `d`. `inner` and `outer` are looked up with `get(_, d, 1)`, so missing
entries default to 1. The result is allocated with `similar(x, ...)` and filled by one
in-place broadcast.
"""
function _repeat(x::AbstractArray; inner=1, outer=1)
    N = max(ndims(x), length(inner), length(outer))
    size_y = ntuple(d -> size(x, d) * get(inner, d, 1) * get(outer, d, 1), N)

    # Each dimension of `x` becomes an (inner, size, outer) triple in the destination.
    # `divrem(d3 + 2, 3)` maps d3 = 1, 2, 3, 4, ... to (1, 0), (1, 1), (1, 2), (2, 0), ...
    size_y3 = ntuple(3 * N) do d3
        dim, class = divrem(d3 + 2, 3)
        # Exhaustive branches: `class` is always 0, 1 or 2, and every branch yields a
        # size, so the closure never falls through to a `Bool`.
        if class == 0
            get(inner, dim, 1)
        elseif class == 1
            size(x, dim)
        else
            get(outer, dim, 1)
        end
    end
    # In the source only the middle slot of each triple is non-singleton.
    size_x3 = ntuple(3 * N) do d3
        dim, class = divrem(d3 + 2, 3)
        class == 1 ? size(x, dim) : 1
    end

    y = similar(x, size_y)
    reshape(y, size_y3) .= reshape(x, size_x3)
    return y
end

## PermutedDimsArrays