Skip to content

Commit

Permalink
Provide constructors from generators (#792)
Browse files Browse the repository at this point in the history
* Add SVector constructors from generators

* Add SMatrix constructors from generators

* Add SArray constructor from generators

* Implement sacollect. Add error messages.

This also much simplifies the code.

* Simplify sacollect doc string

* Make test case backward compatible with Julia <1.5

* Correct test case

* Update src/SArray.jl

Co-authored-by: Takafumi Arakaki <[email protected]>

* Define generator constructors for StaticArray instead of just SArray

* Do not export sacollect

* Update sacollect tests

Co-authored-by: Takafumi Arakaki <[email protected]>
  • Loading branch information
eschnett and tkf authored May 23, 2020
1 parent c75c664 commit 64c64b2
Show file tree
Hide file tree
Showing 8 changed files with 139 additions and 1 deletion.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,9 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
julia = "1"

[extras]
BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"

[targets]
test = ["InteractiveUtils", "Test", "BenchmarkTools"]
56 changes: 56 additions & 0 deletions src/SArray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,62 @@ end
end
end


@noinline function generator_too_short_error(inds::CartesianIndices, i::CartesianIndex)
error("Generator produced too few elements: Expected exactly $(shape_string(inds)) elements, but generator stopped at $(shape_string(i))")
end
@noinline function generator_too_long_error(inds::CartesianIndices)
error("Generator produced too many elements: Expected exactly $(shape_string(inds)) elements, but generator yields more")
end

shape_string(inds::CartesianIndices) = join(length.(inds.indices), '×')
shape_string(inds::CartesianIndex) = join(Tuple(inds), '×')

@inline throw_if_nothing(x, inds, i) =
(x === nothing && generator_too_short_error(inds, i); x)

@generated function sacollect(::Type{SA}, gen) where {SA <: StaticArray{S}} where {S <: Tuple}
stmts = [:(Base.@_inline_meta)]
args = []
iter = :(iterate(gen))
inds = CartesianIndices(size_to_tuple(S))
for i in inds
el = Symbol(:el, i)
push!(stmts, :(($el,st) = throw_if_nothing($iter, $inds, $i)))
push!(args, el)
iter = :(iterate(gen,st))
end
push!(stmts, :($iter === nothing || generator_too_long_error($inds)))
push!(stmts, :(SA($(args...))))
Expr(:block, stmts...)
end
"""
sacollect(SA, gen)
Construct a statically-sized vector of type `SA`.from a generator
`gen`. `SA` needs to have a size parameter since the length of `vec`
is unknown to the compiler. `SA` can optionally specify the element
type as well.
Example:
sacollect(SVector{3, Int}, 2i+1 for i in 1:3)
sacollect(SMatrix{2, 3}, i+j for i in 1:2, j in 1:3)
sacollect(SArray{2, 3}, i+j for i in 1:2, j in 1:3)
This creates the same statically-sized vector as if the generator were
collected in an array, but is more efficient since no array is
allocated.
Equivalent:
SVector{3, Int}([2i+1 for i in 1:3])
"""
sacollect

@inline (::Type{SA})(gen::Base.Generator) where {SA <: StaticArray} =
sacollect(SA, gen)

@inline SArray(a::StaticArray) = SArray{size_tuple(Size(a))}(Tuple(a))

####################
Expand Down
5 changes: 5 additions & 0 deletions src/SMatrix.jl
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,11 @@ end
end
end

@inline SMatrix{M, N, T}(gen::Base.Generator) where {M, N, T} =
sacollect(SMatrix{M, N, T}, gen)
@inline SMatrix{M, N}(gen::Base.Generator) where {M, N} =
sacollect(SMatrix{M, N}, gen)

@inline convert(::Type{SMatrix{S1,S2}}, a::StaticArray{<:Tuple, T}) where {S1,S2,T} = SMatrix{S1,S2,T}(Tuple(a))
@inline SMatrix(a::StaticMatrix{S1, S2}) where {S1, S2} = SMatrix{S1, S2}(Tuple(a))

Expand Down
5 changes: 5 additions & 0 deletions src/SVector.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,11 @@ const SVector{S, T} = SArray{Tuple{S}, T, 1, S}
@inline SVector{S}(x::NTuple{S,T}) where {S, T} = SVector{S,T}(x)
@inline SVector{S}(x::T) where {S, T <: Tuple} = SVector{S,promote_tuple_eltype(T)}(x)

@inline SVector{N, T}(gen::Base.Generator) where {N, T} =
sacollect(SVector{N, T}, gen)
@inline SVector{N}(gen::Base.Generator) where {N} =
sacollect(SVector{N}, gen)

# conversion from AbstractVector / AbstractArray (better inference than default)
#@inline convert{S,T}(::Type{SVector{S}}, a::AbstractArray{T}) = SVector{S,T}((a...))

Expand Down
23 changes: 23 additions & 0 deletions test/MArray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,29 @@
v = MArray{Tuple{2}}(1,2)
@test MArray(v) !== v && MArray(v) == v

@test MArray{Tuple{}}(i for i in 1:1).data === (1,)
@test MArray{Tuple{3}}(i for i in 1:3).data === (1,2,3)
@test MArray{Tuple{3}}(float(i) for i in 1:3).data === (1.0,2.0,3.0)
@test MArray{Tuple{2,3}}(i+10j for i in 1:2, j in 1:3).data === (11,12,21,22,31,32)
@test MArray{Tuple{1,2,3}}(i+10j+100k for i in 1:1, j in 1:2, k in 1:3).data === (111,121,211,221,311,321)
@test_throws Exception MArray{Tuple{}}(i for i in 1:0)
@test_throws Exception MArray{Tuple{}}(i for i in 1:2)
@test_throws Exception MArray{Tuple{3}}(i for i in 1:2)
@test_throws Exception MArray{Tuple{3}}(i for i in 1:4)
@test_throws Exception MArray{Tuple{2,3}}(10i+j for i in 1:1, j in 1:3)
@test_throws Exception MArray{Tuple{2,3}}(10i+j for i in 1:3, j in 1:3)

@test StaticArrays.sacollect(MVector{6}, Iterators.product(1:2, 1:3)) ==
MVector{6}(collect(Iterators.product(1:2, 1:3)))
@test StaticArrays.sacollect(MVector{2}, Iterators.zip(1:2, 2:3)) ==
MVector{2}(collect(Iterators.zip(1:2, 2:3)))
@test StaticArrays.sacollect(MVector{3}, Iterators.take(1:10, 3)) ==
MVector{3}(collect(Iterators.take(1:10, 3)))
@test StaticArrays.sacollect(MMatrix{2,3}, Iterators.product(1:2, 1:3)) ==
MMatrix{2,3}(collect(Iterators.product(1:2, 1:3)))
@test StaticArrays.sacollect(MArray{Tuple{2,3,4}}, 1:24) ==
MArray{Tuple{2,3,4}}(collect(1:24))

@test ((@MArray [1])::MArray{Tuple{1}}).data === (1,)
@test ((@MArray [1,2])::MArray{Tuple{2}}).data === (1,2)
@test ((@MArray Float64[1,2,3])::MArray{Tuple{3}}).data === (1.0, 2.0, 3.0)
Expand Down
23 changes: 23 additions & 0 deletions test/SArray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,29 @@

@test SArray(SArray{Tuple{2}}(1,2)) === SArray{Tuple{2}}(1,2)

@test SArray{Tuple{}}(i for i in 1:1).data === (1,)
@test SArray{Tuple{3}}(i for i in 1:3).data === (1,2,3)
@test SArray{Tuple{3}}(float(i) for i in 1:3).data === (1.0,2.0,3.0)
@test SArray{Tuple{2,3}}(i+10j for i in 1:2, j in 1:3).data === (11,12,21,22,31,32)
@test SArray{Tuple{1,2,3}}(i+10j+100k for i in 1:1, j in 1:2, k in 1:3).data === (111,121,211,221,311,321)
@test_throws Exception SArray{Tuple{}}(i for i in 1:0)
@test_throws Exception SArray{Tuple{}}(i for i in 1:2)
@test_throws Exception SArray{Tuple{3}}(i for i in 1:2)
@test_throws Exception SArray{Tuple{3}}(i for i in 1:4)
@test_throws Exception SArray{Tuple{2,3}}(10i+j for i in 1:1, j in 1:3)
@test_throws Exception SArray{Tuple{2,3}}(10i+j for i in 1:3, j in 1:3)

@test StaticArrays.sacollect(SVector{6}, Iterators.product(1:2, 1:3)) ==
SVector{6}(collect(Iterators.product(1:2, 1:3)))
@test StaticArrays.sacollect(SVector{2}, Iterators.zip(1:2, 2:3)) ==
SVector{2}(collect(Iterators.zip(1:2, 2:3)))
@test StaticArrays.sacollect(SVector{3}, Iterators.take(1:10, 3)) ==
SVector{3}(collect(Iterators.take(1:10, 3)))
@test StaticArrays.sacollect(SMatrix{2,3}, Iterators.product(1:2, 1:3)) ==
SMatrix{2,3}(collect(Iterators.product(1:2, 1:3)))
@test StaticArrays.sacollect(SArray{Tuple{2,3,4}}, 1:24) ==
SArray{Tuple{2,3,4}}(collect(1:24))

@test ((@SArray [1])::SArray{Tuple{1}}).data === (1,)
@test ((@SArray [1,2])::SArray{Tuple{2}}).data === (1,2)
@test ((@SArray Float64[1,2,3])::SArray{Tuple{3}}).data === (1.0, 2.0, 3.0)
Expand Down
16 changes: 16 additions & 0 deletions test/SMatrix.jl
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,22 @@
@test SMatrix{2}((1,2,3,4)).data === (1,2,3,4)
@test_throws DimensionMismatch SMatrix{2}((1,2,3,4,5))

@test (SMatrix{2,3}(i+10j for i in 1:2, j in 1:3)::SMatrix{2,3}).data ===
(11,12,21,22,31,32)
@test (SMatrix{2,3}(float(i+10j) for i in 1:2, j in 1:3)::SMatrix{2,3}).data ===
(11.0,12.0,21.0,22.0,31.0,32.0)
@test (SMatrix{0,0,Int}()::SMatrix{0,0}).data === ()
@test (SMatrix{0,3,Int}()::SMatrix{0,3}).data === ()
@test (SMatrix{2,0,Int}()::SMatrix{2,0}).data === ()
@test (SMatrix{2,3,Int}(i+10j for i in 1:2, j in 1:3)::SMatrix{2,3}).data ===
(11,12,21,22,31,32)
@test (SMatrix{2,3,Float64}(i+10j for i in 1:2, j in 1:3)::SMatrix{2,3}).data ===
(11.0,12.0,21.0,22.0,31.0,32.0)
@test_throws Exception SMatrix{2,3}(i+10j for i in 1:1, j in 1:3)
@test_throws Exception SMatrix{2,3}(i+10j for i in 1:3, j in 1:3)
@test_throws Exception SMatrix{2,3,Int}(i+10j for i in 1:1, j in 1:3)
@test_throws Exception SMatrix{2,3,Int}(i+10j for i in 1:3, j in 1:3)

@test ((@SMatrix [1.0])::SMatrix{1,1}).data === (1.0,)
@test ((@SMatrix [1 2])::SMatrix{1,2}).data === (1, 2)
@test ((@SMatrix [1 ; 2])::SMatrix{2,1}).data === (1, 2)
Expand Down
10 changes: 10 additions & 0 deletions test/SVector.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,16 @@
@test SVector((1,)).data === (1,)
@test SVector((1.0,)).data === (1.0,)

@test SVector{3}(i for i in 1:3).data === (1,2,3)
@test SVector{3}(float(i) for i in 1:3).data === (1.0,2.0,3.0)
@test SVector{0,Int}().data === ()
@test SVector{3,Int}(i for i in 1:3).data === (1,2,3)
@test SVector{3,Float64}(i for i in 1:3).data === (1.0,2.0,3.0)
@test_throws Exception SVector{3}(i for i in 1:2)
@test_throws Exception SVector{3}(i for i in 1:4)
@test_throws Exception SVector{3,Int}(i for i in 1:2)
@test_throws Exception SVector{3,Int}(i for i in 1:4)

@test SVector(1).data === (1,)
@test SVector(1,1.0).data === (1.0,1.0)

Expand Down

0 comments on commit 64c64b2

Please sign in to comment.