Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Provide constructors from generators #792

Merged
merged 11 commits into from
May 23, 2020
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,9 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
julia = "1"

[extras]
BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"

[targets]
test = ["InteractiveUtils", "Test", "BenchmarkTools"]
56 changes: 56 additions & 0 deletions src/SArray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,62 @@ end
end
end


@noinline function generator_too_short_error(inds::CartesianIndices, i::CartesianIndex)
error("Generator produced too few elements: Expected exactly $(shape_string(inds)) elements, but generator stopped at $(shape_string(i))")
end
@noinline function generator_too_long_error(inds::CartesianIndices)
error("Generator produced too many elements: Expected exactly $(shape_string(inds)) elements, but generator yields more")
end

shape_string(inds::CartesianIndices) = join(length.(inds.indices), '×')
shape_string(inds::CartesianIndex) = join(Tuple(inds), '×')

@inline throw_if_nothing(x, inds, i) =
(x === nothing && generator_too_short_error(inds, i); x)
eschnett marked this conversation as resolved.
Show resolved Hide resolved

@generated function sacollect(::Type{SA}, gen) where {SA <: StaticArray{S}} where {S <: Tuple}
stmts = [:(Base.@_inline_meta)]
args = []
iter = :(iterate(gen))
inds = CartesianIndices(size_to_tuple(S))
for i in inds
el = Symbol(:el, i)
push!(stmts, :(($el,st) = throw_if_nothing($iter, $inds, $i)))
push!(args, el)
iter = :(iterate(gen,st))
end
push!(stmts, :($iter === nothing || generator_too_long_error($inds)))
push!(stmts, :(SA($(args...))))
Expr(:block, stmts...)
end
"""
sacollect(SA, gen)

Construct a statically-sized vector of type `SA`.from a generator
`gen`. `SA` needs to have a size parameter since the length of `vec`
is unknown to the compiler. `SA` can optionally specify the element
type as well.

Example:

sacollect(SVector{3, Int}, 2i+1 for i in 1:3)
sacollect(SMatrix{2, 3}, i+j for i in 1:2, j in 1:3)
sacollect(SArray{2, 3}, i+j for i in 1:2, j in 1:3)

This creates the same statically-sized vector as if the generator were
collected in an array, but is more efficient since no array is
allocated.

Equivalent:

SVector{3, Int}([2i+1 for i in 1:3])
"""
sacollect

@inline (::Type{SA})(gen::Base.Generator) where {SA <: StaticArray} =
sacollect(SA, gen)

@inline SArray(a::StaticArray) = SArray{size_tuple(Size(a))}(Tuple(a))

####################
Expand Down
5 changes: 5 additions & 0 deletions src/SMatrix.jl
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,11 @@ end
end
end

@inline SMatrix{M, N, T}(gen::Base.Generator) where {M, N, T} =
sacollect(SMatrix{M, N, T}, gen)
@inline SMatrix{M, N}(gen::Base.Generator) where {M, N} =
sacollect(SMatrix{M, N}, gen)

@inline convert(::Type{SMatrix{S1,S2}}, a::StaticArray{<:Tuple, T}) where {S1,S2,T} = SMatrix{S1,S2,T}(Tuple(a))
@inline SMatrix(a::StaticMatrix{S1, S2}) where {S1, S2} = SMatrix{S1, S2}(Tuple(a))

Expand Down
5 changes: 5 additions & 0 deletions src/SVector.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,11 @@ const SVector{S, T} = SArray{Tuple{S}, T, 1, S}
@inline SVector{S}(x::NTuple{S,T}) where {S, T} = SVector{S,T}(x)
@inline SVector{S}(x::T) where {S, T <: Tuple} = SVector{S,promote_tuple_eltype(T)}(x)

@inline SVector{N, T}(gen::Base.Generator) where {N, T} =
sacollect(SVector{N, T}, gen)
@inline SVector{N}(gen::Base.Generator) where {N} =
sacollect(SVector{N}, gen)

# conversion from AbstractVector / AbstractArray (better inference than default)
#@inline convert{S,T}(::Type{SVector{S}}, a::AbstractArray{T}) = SVector{S,T}((a...))

Expand Down
23 changes: 23 additions & 0 deletions test/MArray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,29 @@
v = MArray{Tuple{2}}(1,2)
@test MArray(v) !== v && MArray(v) == v

@test MArray{Tuple{}}(i for i in 1:1).data === (1,)
@test MArray{Tuple{3}}(i for i in 1:3).data === (1,2,3)
@test MArray{Tuple{3}}(float(i) for i in 1:3).data === (1.0,2.0,3.0)
@test MArray{Tuple{2,3}}(i+10j for i in 1:2, j in 1:3).data === (11,12,21,22,31,32)
@test MArray{Tuple{1,2,3}}(i+10j+100k for i in 1:1, j in 1:2, k in 1:3).data === (111,121,211,221,311,321)
@test_throws Exception MArray{Tuple{}}(i for i in 1:0)
@test_throws Exception MArray{Tuple{}}(i for i in 1:2)
@test_throws Exception MArray{Tuple{3}}(i for i in 1:2)
@test_throws Exception MArray{Tuple{3}}(i for i in 1:4)
@test_throws Exception MArray{Tuple{2,3}}(10i+j for i in 1:1, j in 1:3)
@test_throws Exception MArray{Tuple{2,3}}(10i+j for i in 1:3, j in 1:3)

@test StaticArrays.sacollect(MVector{6}, Iterators.product(1:2, 1:3)) ==
MVector{6}(collect(Iterators.product(1:2, 1:3)))
@test StaticArrays.sacollect(MVector{2}, Iterators.zip(1:2, 2:3)) ==
MVector{2}(collect(Iterators.zip(1:2, 2:3)))
@test StaticArrays.sacollect(MVector{3}, Iterators.take(1:10, 3)) ==
MVector{3}(collect(Iterators.take(1:10, 3)))
@test StaticArrays.sacollect(MMatrix{2,3}, Iterators.product(1:2, 1:3)) ==
MMatrix{2,3}(collect(Iterators.product(1:2, 1:3)))
@test StaticArrays.sacollect(MArray{Tuple{2,3,4}}, 1:24) ==
MArray{Tuple{2,3,4}}(collect(1:24))

@test ((@MArray [1])::MArray{Tuple{1}}).data === (1,)
@test ((@MArray [1,2])::MArray{Tuple{2}}).data === (1,2)
@test ((@MArray Float64[1,2,3])::MArray{Tuple{3}}).data === (1.0, 2.0, 3.0)
Expand Down
23 changes: 23 additions & 0 deletions test/SArray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,29 @@

@test SArray(SArray{Tuple{2}}(1,2)) === SArray{Tuple{2}}(1,2)

@test SArray{Tuple{}}(i for i in 1:1).data === (1,)
@test SArray{Tuple{3}}(i for i in 1:3).data === (1,2,3)
@test SArray{Tuple{3}}(float(i) for i in 1:3).data === (1.0,2.0,3.0)
@test SArray{Tuple{2,3}}(i+10j for i in 1:2, j in 1:3).data === (11,12,21,22,31,32)
@test SArray{Tuple{1,2,3}}(i+10j+100k for i in 1:1, j in 1:2, k in 1:3).data === (111,121,211,221,311,321)
@test_throws Exception SArray{Tuple{}}(i for i in 1:0)
@test_throws Exception SArray{Tuple{}}(i for i in 1:2)
@test_throws Exception SArray{Tuple{3}}(i for i in 1:2)
@test_throws Exception SArray{Tuple{3}}(i for i in 1:4)
@test_throws Exception SArray{Tuple{2,3}}(10i+j for i in 1:1, j in 1:3)
@test_throws Exception SArray{Tuple{2,3}}(10i+j for i in 1:3, j in 1:3)

@test StaticArrays.sacollect(SVector{6}, Iterators.product(1:2, 1:3)) ==
SVector{6}(collect(Iterators.product(1:2, 1:3)))
@test StaticArrays.sacollect(SVector{2}, Iterators.zip(1:2, 2:3)) ==
SVector{2}(collect(Iterators.zip(1:2, 2:3)))
@test StaticArrays.sacollect(SVector{3}, Iterators.take(1:10, 3)) ==
SVector{3}(collect(Iterators.take(1:10, 3)))
@test StaticArrays.sacollect(SMatrix{2,3}, Iterators.product(1:2, 1:3)) ==
SMatrix{2,3}(collect(Iterators.product(1:2, 1:3)))
@test StaticArrays.sacollect(SArray{Tuple{2,3,4}}, 1:24) ==
SArray{Tuple{2,3,4}}(collect(1:24))

@test ((@SArray [1])::SArray{Tuple{1}}).data === (1,)
@test ((@SArray [1,2])::SArray{Tuple{2}}).data === (1,2)
@test ((@SArray Float64[1,2,3])::SArray{Tuple{3}}).data === (1.0, 2.0, 3.0)
Expand Down
16 changes: 16 additions & 0 deletions test/SMatrix.jl
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,22 @@
@test SMatrix{2}((1,2,3,4)).data === (1,2,3,4)
@test_throws DimensionMismatch SMatrix{2}((1,2,3,4,5))

@test (SMatrix{2,3}(i+10j for i in 1:2, j in 1:3)::SMatrix{2,3}).data ===
(11,12,21,22,31,32)
@test (SMatrix{2,3}(float(i+10j) for i in 1:2, j in 1:3)::SMatrix{2,3}).data ===
(11.0,12.0,21.0,22.0,31.0,32.0)
@test (SMatrix{0,0,Int}()::SMatrix{0,0}).data === ()
@test (SMatrix{0,3,Int}()::SMatrix{0,3}).data === ()
@test (SMatrix{2,0,Int}()::SMatrix{2,0}).data === ()
@test (SMatrix{2,3,Int}(i+10j for i in 1:2, j in 1:3)::SMatrix{2,3}).data ===
(11,12,21,22,31,32)
@test (SMatrix{2,3,Float64}(i+10j for i in 1:2, j in 1:3)::SMatrix{2,3}).data ===
(11.0,12.0,21.0,22.0,31.0,32.0)
@test_throws Exception SMatrix{2,3}(i+10j for i in 1:1, j in 1:3)
@test_throws Exception SMatrix{2,3}(i+10j for i in 1:3, j in 1:3)
@test_throws Exception SMatrix{2,3,Int}(i+10j for i in 1:1, j in 1:3)
@test_throws Exception SMatrix{2,3,Int}(i+10j for i in 1:3, j in 1:3)

@test ((@SMatrix [1.0])::SMatrix{1,1}).data === (1.0,)
@test ((@SMatrix [1 2])::SMatrix{1,2}).data === (1, 2)
@test ((@SMatrix [1 ; 2])::SMatrix{2,1}).data === (1, 2)
Expand Down
10 changes: 10 additions & 0 deletions test/SVector.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,16 @@
@test SVector((1,)).data === (1,)
@test SVector((1.0,)).data === (1.0,)

@test SVector{3}(i for i in 1:3).data === (1,2,3)
@test SVector{3}(float(i) for i in 1:3).data === (1.0,2.0,3.0)
@test SVector{0,Int}().data === ()
@test SVector{3,Int}(i for i in 1:3).data === (1,2,3)
@test SVector{3,Float64}(i for i in 1:3).data === (1.0,2.0,3.0)
@test_throws Exception SVector{3}(i for i in 1:2)
@test_throws Exception SVector{3}(i for i in 1:4)
@test_throws Exception SVector{3,Int}(i for i in 1:2)
@test_throws Exception SVector{3,Int}(i for i in 1:4)

@test SVector(1).data === (1,)
@test SVector(1,1.0).data === (1.0,1.0)

Expand Down