Skip to content

Commit

Permalink
use Val(dims) and Iterators.peel for _dim_stack
Browse files Browse the repository at this point in the history
  • Loading branch information
mcabbott committed Feb 10, 2022
1 parent a9db4de commit fa7e79c
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 26 deletions.
49 changes: 23 additions & 26 deletions base/abstractarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2662,44 +2662,41 @@ _typed_stack(dims::Integer, ::Type{T}, ::Type{S}, ::IteratorSize, A) where {T,S}
_vec_axis(A, ax=_axes(A)) = length(ax) == 1 ? only(ax) : OneTo(prod(length, ax; init=1))

function _dim_stack(dims::Integer, ::Type{T}, ::Type{S}, A) where {T,S}
xit = iterate(A)
xit = Iterators.peel(A)
nothing === xit && return _empty_stack(dims, T, S, A)
x1, _ = xit
x1, xrest = xit
ax1 = _axes(x1)
N1 = length(ax1)+1
dims in 1:N1 || throw(ArgumentError("cannot stack slices ndims(x) = $(N1-1) along dims = $dims"))

newaxis = _vec_axis(A)
outax = ntuple(d -> d==dims ? newaxis : _axes(x1)[d - (d>dims)], N1)
outax = ntuple(d -> d==dims ? newaxis : ax1[d - (d>dims)], N1)
B = similar(_first_array(x1, A), T, outax...)

iit = iterate(newaxis)
while xit !== nothing
x, state = xit
i, istate = iit
_stack_size_check(x, ax1)
@inbounds if dims==1
inds1 = ntuple(d -> d==1 ? i : Colon(), N1)
if x isa AbstractArray
B[inds1...] = x
else
copyto!(view(B, inds1...), x)
end
else
inds = ntuple(d -> d==dims ? i : Colon(), N1)
if x isa AbstractArray
B[inds...] = x
else
# This is where the type-instability of inds hurts, but it is pretty exotic:
copyto!(view(B, inds...), x)
end
end
xit = iterate(A, state)
iit = iterate(newaxis, istate)
if dims == 1
_dim_stack!(Val(1), B, x1, xrest)
elseif dims == 2
_dim_stack!(Val(2), B, x1, xrest)
else
_dim_stack!(Val(dims), B, x1, xrest)
end
B
end

function _dim_stack!(::Val{dims}, B::AbstractArray, x1, xrest) where {dims}
before = ntuple(d -> Colon(), dims - 1)
after = ntuple(d -> Colon(), ndims(B) - dims)

i = firstindex(B, dims)
copyto!(view(B, before..., i, after...), x1)

for x in xrest
_stack_size_check(x, _axes(x1))
i += 1
@inbounds copyto!(view(B, before..., i, after...), x)
end
end

@inline function _stack_size_check(x, ax1::Tuple)
if _axes(x) != ax1
uax1 = UnitRange.(ax1)
Expand Down
14 changes: 14 additions & 0 deletions test/abstractarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1606,6 +1606,20 @@ end
# Trivial, because numbers are iterable:
@test stack(abs2, 1:3) == [1, 4, 9] == collect(Iterators.flatten(abs2(x) for x in 1:3))

# Allocation tests
xv = [rand(10) for _ in 1:100]
xt = Tuple.(xv)
for dims in (1, 2, :)
@test stack(xv; dims) == stack(xt; dims)
@test 9000 > @allocated stack(xv; dims)
@test 9000 > @allocated stack(xt; dims)
end
xr = (reshape(1:1000,10,10,10) for _ = 1:1000)
for dims in (1, 2, 3, :)
stack(xr; dims)
@test 8.1e6 > @allocated stack(xr; dims)
end

# Mismatched sizes
@test_throws DimensionMismatch stack([1:2, 1:3])
@test_throws DimensionMismatch stack([1:2, 1:3]; dims=1)
Expand Down

0 comments on commit fa7e79c

Please sign in to comment.