From 02af65f8238927f30306523e29f9dd3d9f5d0e70 Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Sat, 4 Dec 2021 20:52:24 -0500 Subject: [PATCH] type stability --- base/abstractarray.jl | 17 +++++++++++++---- test/abstractarray.jl | 14 ++++++++++---- 2 files changed, 23 insertions(+), 8 deletions(-) diff --git a/base/abstractarray.jl b/base/abstractarray.jl index bebc955df939b2..fdf4d7c09d3549 100644 --- a/base/abstractarray.jl +++ b/base/abstractarray.jl @@ -2611,7 +2611,7 @@ end function _vstack_plus(itr) z = iterate(itr) - isnothing(z) && throw(ArgumentError("cannot stack an empty collection")) + z === nothing && throw(ArgumentError("cannot stack an empty collection")) val, state = z val isa Union{AbstractArray, Tuple} || throw(ArgumentError("cannot stack elements of type $(typeof(val))")) @@ -2619,7 +2619,12 @@ function _vstack_plus(itr) len = length(val) n = haslength(itr) ? len*length(itr) : nothing - v = similar(val isa Tuple ? (1:0) : val, eltype(val), something(n, len)) + v = if val isa Tuple + T = mapreduce(typeof, promote_type, val) + similar(1:0, T, something(n, len)) + else + similar(val, something(n, len)) + end copyto!(v, 1, val, firstindex(val), len) w = _stack_rest!(v, 0, n, axe, itr, state) @@ -2630,12 +2635,16 @@ function _stack_rest!(v::AbstractVector, i, n, axe, itr, state) len = prod(length, axe; init=1) while true z = iterate(itr, state) - isnothing(z) && return v + z === nothing && return v val, state = z axes(val) == axe || throw(DimensionMismatch( "expected a consistent size, got axes $(UnitRange.(axes(val))) compared to $(UnitRange.(axe)) for the first")) i += 1 - T′ = promote_type(eltype(v), eltype(val)) + T′ = if val isa Tuple + promote_type(eltype(v), mapreduce(typeof, promote_type, val)) + else + promote_type(eltype(v), eltype(val)) + end if T′ <: eltype(v) if n isa Integer copyto!(v, i*len+1, val, firstindex(val), len) diff --git a/test/abstractarray.jl b/test/abstractarray.jl index f8f83d5ece7025..b4960631a31b47 100644 --- a/test/abstractarray.jl +++ b/test/abstractarray.jl @@ -1565,6 +1565,11 @@ end X3 = stack(x for x in args if true) @test X3 == Y @test typeof(X3) === typeof(Y) + + if isconcretetype(eltype(args)) + @inferred stack(args) + @inferred stack(x for x in args) + end end # Higher dims @test size(stack([rand(2,3) for _ in 1:4, _ in 1:5])) == (2,3,4,5) @@ -1574,12 +1579,13 @@ end # Tuples @test stack([(1,2), (3,4)]) == [1 3; 2 4] @test stack(((1,2), (3,4))) == [1 3; 2 4] - @test size(stack(Iterators.product(1:3, 1:4))) == (2,3,4) - @test stack([('a', 'b'), ('c', 'd')]) == ['a' 'c'; 'b' 'd'] + @test size(@inferred stack(Iterators.product(1:3, 1:4))) == (2,3,4) + @test @inferred(stack([('a', 'b'), ('c', 'd')])) == ['a' 'c'; 'b' 'd'] + @test @inferred(stack([(1,2+3im), (4, 5+6im)])) isa Matrix{Complex{Int}} # stack(f, iter) - @test stack(x -> [x, 2x], 3:5) == [3 4 5; 6 8 10] - @test stack(x -> x*x'/2, [1:2, 3:4]) == [0.5 1.0; 1.0 2.0;;; 4.5 6.0; 6.0 8.0] + @test @inferred(stack(x -> [x, 2x], 3:5)) == [3 4 5; 6 8 10] + @test @inferred(stack(x -> x*x'/2, [1:2, 3:4])) == [0.5 1.0; 1.0 2.0;;; 4.5 6.0; 6.0 8.0] # Mismatched sizes @test_throws DimensionMismatch stack([1:2, 1:3])