diff --git a/base/abstractarray.jl b/base/abstractarray.jl index 932beb2a268fc4..42753a44ef98dc 100644 --- a/base/abstractarray.jl +++ b/base/abstractarray.jl @@ -2662,44 +2662,41 @@ _typed_stack(dims::Integer, ::Type{T}, ::Type{S}, ::IteratorSize, A) where {T,S} _vec_axis(A, ax=_axes(A)) = length(ax) == 1 ? only(ax) : OneTo(prod(length, ax; init=1)) function _dim_stack(dims::Integer, ::Type{T}, ::Type{S}, A) where {T,S} - xit = iterate(A) + xit = Iterators.peel(A) nothing === xit && return _empty_stack(dims, T, S, A) - x1, _ = xit + x1, xrest = xit ax1 = _axes(x1) N1 = length(ax1)+1 dims in 1:N1 || throw(ArgumentError("cannot stack slices ndims(x) = $(N1-1) along dims = $dims")) newaxis = _vec_axis(A) - outax = ntuple(d -> d==dims ? newaxis : _axes(x1)[d - (d>dims)], N1) + outax = ntuple(d -> d==dims ? newaxis : ax1[d - (d>dims)], N1) B = similar(_first_array(x1, A), T, outax...) - iit = iterate(newaxis) - while xit !== nothing - x, state = xit - i, istate = iit - _stack_size_check(x, ax1) - @inbounds if dims==1 - inds1 = ntuple(d -> d==1 ? i : Colon(), N1) - if x isa AbstractArray - B[inds1...] = x - else - copyto!(view(B, inds1...), x) - end - else - inds = ntuple(d -> d==dims ? i : Colon(), N1) - if x isa AbstractArray - B[inds...] = x - else - # This is where the type-instability of inds hurts, but it is pretty exotic: - copyto!(view(B, inds...), x) - end - end - xit = iterate(A, state) - iit = iterate(newaxis, istate) + if dims == 1 + _dim_stack!(Val(1), B, x1, xrest) + elseif dims == 2 + _dim_stack!(Val(2), B, x1, xrest) + else + _dim_stack!(Val(dims), B, x1, xrest) end B end +function _dim_stack!(::Val{dims}, B::AbstractArray, x1, xrest) where {dims} + before = ntuple(d -> Colon(), dims - 1) + after = ntuple(d -> Colon(), ndims(B) - dims) + + i = firstindex(B, dims) + copyto!(view(B, before..., i, after...), x1) + + for x in xrest + _stack_size_check(x, _axes(x1)) + i += 1 + @inbounds copyto!(view(B, before..., i, after...), x) + end +end + @inline function _stack_size_check(x, ax1::Tuple) if _axes(x) != ax1 uax1 = UnitRange.(ax1) diff --git a/test/abstractarray.jl b/test/abstractarray.jl index fef44da4628ef5..b0dd7407feb299 100644 --- a/test/abstractarray.jl +++ b/test/abstractarray.jl @@ -1606,6 +1606,20 @@ end # Trivial, because numbers are iterable: @test stack(abs2, 1:3) == [1, 4, 9] == collect(Iterators.flatten(abs2(x) for x in 1:3)) + # Allocation tests + xv = [rand(10) for _ in 1:100] + xt = Tuple.(xv) + for dims in (1, 2, :) + @test stack(xv; dims) == stack(xt; dims) + @test 9000 > @allocated stack(xv; dims) + @test 9000 > @allocated stack(xt; dims) + end + xr = (reshape(1:1000,10,10,10) for _ = 1:1000) + for dims in (1, 2, 3, :) + stack(xr; dims) + @test 8.1e6 > @allocated stack(xr; dims) + end + # Mismatched sizes @test_throws DimensionMismatch stack([1:2, 1:3]) @test_throws DimensionMismatch stack([1:2, 1:3]; dims=1)