Skip to content

Commit

Permalink
Make broadcast extendable to user array wrapper. (#1001)
Browse files Browse the repository at this point in the history
* Reviewed

1. Make `broadcast` extendable
2. Fix empty case.

* Add more test

* Bump release.

* Test fix
  • Loading branch information
N5N3 authored Mar 2, 2022
1 parent df49828 commit ca50465
Show file tree
Hide file tree
Showing 4 changed files with 53 additions and 12 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "StaticArrays"
uuid = "90137ffa-7385-5640-81b9-e52037218182"
version = "1.4.1"
version = "1.4.2"

[deps]
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Expand Down
28 changes: 18 additions & 10 deletions src/broadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -97,19 +97,28 @@ end
scalar_getindex(x) = x
scalar_getindex(x::Ref) = x[]

@generated function _broadcast(f, ::Size{newsize}, s::Tuple{Vararg{Size}}, a...) where newsize
first_staticarray = a[findfirst(ai -> ai <: Union{StaticArray, Transpose{<:Any, <:StaticArray}, Adjoint{<:Any, <:StaticArray}, Diagonal{<:Any, <:StaticArray}}, a)]
isstatic(::StaticArray) = true
isstatic(::Transpose{<:Any, <:StaticArray}) = true
isstatic(::Adjoint{<:Any, <:StaticArray}) = true
isstatic(::Diagonal{<:Any, <:StaticArray}) = true
isstatic(_) = false

@inline first_statictype(x, y...) = isstatic(x) ? typeof(x) : first_statictype(y...)
first_statictype() = error("unresolved dest type")

@inline function _broadcast(f, sz::Size{newsize}, s::Tuple{Vararg{Size}}, a...) where newsize
first_staticarray = first_statictype(a...)
if prod(newsize) == 0
# Use inference to get eltype in empty case (see also comments in _map)
eltys = [:(eltype(a[$i])) for i 1:length(a)]
return quote
@_inline_meta
T = Core.Compiler.return_type(f, Tuple{$(eltys...)})
@inbounds return similar_type($first_staticarray, T, Size(newsize))()
end
eltys = Tuple{map(eltype, a)...}
T = Core.Compiler.return_type(f, eltys)
@inbounds return similar_type(first_staticarray, T, Size(newsize))()
end
elements = __broadcast(f, sz, s, a...)
@inbounds return similar_type(first_staticarray, eltype(elements), Size(newsize))(elements)
end

@generated function __broadcast(f, ::Size{newsize}, s::Tuple{Vararg{Size}}, a...) where newsize
sizes = [sz.parameters[1] for sz s.parameters]
indices = CartesianIndices(newsize)
exprs = similar(indices, Expr)
Expand All @@ -123,8 +132,7 @@ scalar_getindex(x::Ref) = x[]

return quote
@_inline_meta
@inbounds elements = tuple($(exprs...))
@inbounds return similar_type($first_staticarray, eltype(elements), Size(newsize))(elements)
@inbounds return elements = tuple($(exprs...))
end
end

Expand Down
2 changes: 1 addition & 1 deletion src/precompile.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ function _precompile_()
end

# Some expensive generators
@assert precompile(Tuple{typeof(which(_broadcast,(Any,Size,Tuple{Vararg{Size}},Vararg{Any},)).generator.gen),Any,Any,Any,Any,Any,Any})
@assert precompile(Tuple{typeof(which(__broadcast,(Any,Size,Tuple{Vararg{Size}},Vararg{Any},)).generator.gen),Any,Any,Any,Any,Any,Any})
@assert precompile(Tuple{typeof(which(_zeros,(Size,Type{<:StaticArray},)).generator.gen),Any,Any,Any,Type,Any})
@assert precompile(Tuple{typeof(which(combine_sizes,(Tuple{Vararg{Size}},)).generator.gen),Any,Any})
@assert precompile(Tuple{typeof(which(_mapfoldl,(Any,Any,Colon,Any,Size,Vararg{StaticArray},)).generator.gen),Any,Any,Any,Any,Any,Any,Any,Any})
Expand Down
33 changes: 33 additions & 0 deletions test/broadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -282,3 +282,36 @@ end
end

end

# A help struct to test style-based broadcast dispatch with unknown array wrapper.
# `WrapArray(A)` behaves like `A` during broadcast. But its not a `StaticArray`.
struct WrapArray{T,N,P<:AbstractArray{T,N}} <: AbstractArray{T,N}
data::P
end
Base.@propagate_inbounds Base.getindex(A::WrapArray, i::Integer...) = A.data[i...]
Base.@propagate_inbounds Base.setindex!(A::WrapArray, v::Any, i::Integer...) = setindex!(A.data, v, i...)
Base.size(A::WrapArray) = size(A.data)
Base.axes(A::WrapArray) = axes(A.data)
Broadcast.BroadcastStyle(::Type{WrapArray{T,N,P}}) where {T,N,P} = Broadcast.BroadcastStyle(P)
StaticArrays.isstatic(A::WrapArray) = StaticArrays.isstatic(A.data)
StaticArrays.Size(::Type{WrapArray{T,N,P}}) where {T,N,P} = StaticArrays.Size(P)
function StaticArrays.similar_type(::Type{WrapArray{T,N,P}}, ::Type{t}, s::Size{S}) where {T,N,P,t,S}
return StaticArrays.similar_type(P, t, s)
end

@testset "Broadcast with unknown wrapper" begin
data = (1, 2)
for T in (SVector{2}, MVector{2})
destT = T <: SArray ? SArray : MArray
a = T(data)
for b in (WrapArray(a), WrapArray(a'))
@test @inferred(b .+ a) isa destT
@test @inferred(b .+ b) isa destT
@test @inferred(b .+ (1, 2)) isa destT
@test @inferred(b .+ a') isa destT
@test @inferred(a' .+ b) isa destT
# @test @inferred(b' .+ a') isa StaticMatrix # Adjoint doesn't propagate style
@test b .+ b.data == b .+ b == b.data .+ b.data
end
end
end

0 comments on commit ca50465

Please sign in to comment.