diff --git a/src/StructArrays.jl b/src/StructArrays.jl index 8a347676..19b5048a 100644 --- a/src/StructArrays.jl +++ b/src/StructArrays.jl @@ -4,7 +4,7 @@ using Base: tuple_type_cons, tuple_type_head, tuple_type_tail, tail using PooledArrays: PooledArray export StructArray, StructVector, LazyRow, LazyRows -export collect_structarray, fieldarrays +export collect_structarray, collect_to_structarray!, fieldarrays export replace_storage include("interface.jl") diff --git a/src/collect.jl b/src/collect.jl index c2e898ea..871942e3 100644 --- a/src/collect.jl +++ b/src/collect.jl @@ -128,3 +128,41 @@ function widenarray(dest::AbstractArray, i, ::Type{T}) where T copyto!(new, 1, dest, 1, i-1) new end + +""" +`collect_to_structarray!(dest, itr) -> dest′` + +Try to append `itr` into a vector `dest`. Widen element type of +`dest` if it cannot hold the elements of `itr`. That is to say, + +```julia +vcat(dest, StructVector(itr)) == collect_to_structarray!(dest, itr) +``` + +holds. Note that `dest′` may or may not be the same object as `dest`. +The state of `dest` is unpredictable after `collect_to_structarray!` +is called (e.g., it may contain just half of the elements from `itr`). +""" +collect_to_structarray!(dest::AbstractVector, itr) = + _collect_or_grow!(dest, itr, Base.IteratorSize(itr)) + +function _collect_or_grow!(dest::AbstractVector, itr, ::Union{Base.HasShape, Base.HasLength}) + n = length(itr) # itr may be stateful so do this first + fr = iterate(itr) + fr === nothing && return dest + el, st = fr + i = lastindex(dest) + 1 + if iscompatible(el, dest) + resize!(dest, length(dest) + n) + dest[i] = el + return _collect_to_structarray!(dest, itr, i + 1, st) + else + new = widenstructarray(dest, i, el) + resize!(new, length(dest) + n) + @inbounds new[i] = el + return _collect_to_structarray!(new, itr, i + 1, st) + end +end + +_collect_or_grow!(dest::AbstractVector, itr, ::Base.SizeUnknown) = + grow_to_structarray!(dest, itr) diff --git a/test/runtests.jl b/test/runtests.jl index 49f2ac41..aeb8a4bb 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -646,3 +646,24 @@ end str = String(take!(io)) @test str == "StructArray(::Array{Int64,1}, ::Array{Int64,1})" end + +@testset "collect_to_structarray!" begin + dest_examples = [ + ("mutate", StructVector(a = [1], b = [2])), + ("widen", StructVector(a = [1], b = [nothing])), + ] + itr = [(a = 1, b = 2), (a = 1, b = 2), (a = 1, b = 12)] + itr_examples = [ + ("HasLength", () -> itr), + ("SizeUnknown", () -> (x for x in itr if isodd(x.a))), + # Broken due to https://github.com/JuliaArrays/StructArrays.jl/issues/100: + # ("empty", (x for x in itr if false)), + # Broken due to https://github.com/JuliaArrays/StructArrays.jl/issues/99: + # ("stateful", () -> Iterators.Stateful(itr)), + ] + @testset "$destlabel $itrlabel" for (destlabel, dest) in dest_examples, + (itrlabel, makeitr) in itr_examples + + @test vcat(dest, StructVector(makeitr())) == collect_to_structarray!(copy(dest), makeitr()) + end +end