From ca504651afbba9f3d94dfa37c53c74e174acdd1c Mon Sep 17 00:00:00 2001 From: N5N3 <2642243996@qq.com> Date: Wed, 2 Mar 2022 19:35:16 +0800 Subject: [PATCH] Make `broadcast` extendable to user array wrapper. (#1001) * Reviewed 1. Make `broadcast` extendable 2. Fix empty case. * Add more test * Bump release. * Test fix --- Project.toml | 2 +- src/broadcast.jl | 28 ++++++++++++++++++---------- src/precompile.jl | 2 +- test/broadcast.jl | 33 +++++++++++++++++++++++++++++++++ 4 files changed, 53 insertions(+), 12 deletions(-) diff --git a/Project.toml b/Project.toml index 6d0f8b42..ae3692a1 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "StaticArrays" uuid = "90137ffa-7385-5640-81b9-e52037218182" -version = "1.4.1" +version = "1.4.2" [deps] LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" diff --git a/src/broadcast.jl b/src/broadcast.jl index b134f2dd..b5a69b63 100644 --- a/src/broadcast.jl +++ b/src/broadcast.jl @@ -97,19 +97,28 @@ end scalar_getindex(x) = x scalar_getindex(x::Ref) = x[] -@generated function _broadcast(f, ::Size{newsize}, s::Tuple{Vararg{Size}}, a...) where newsize - first_staticarray = a[findfirst(ai -> ai <: Union{StaticArray, Transpose{<:Any, <:StaticArray}, Adjoint{<:Any, <:StaticArray}, Diagonal{<:Any, <:StaticArray}}, a)] +isstatic(::StaticArray) = true +isstatic(::Transpose{<:Any, <:StaticArray}) = true +isstatic(::Adjoint{<:Any, <:StaticArray}) = true +isstatic(::Diagonal{<:Any, <:StaticArray}) = true +isstatic(_) = false +@inline first_statictype(x, y...) = isstatic(x) ? typeof(x) : first_statictype(y...) +first_statictype() = error("unresolved dest type") + +@inline function _broadcast(f, sz::Size{newsize}, s::Tuple{Vararg{Size}}, a...) where newsize + first_staticarray = first_statictype(a...) if prod(newsize) == 0 # Use inference to get eltype in empty case (see also comments in _map) - eltys = [:(eltype(a[$i])) for i ∈ 1:length(a)] - return quote - @_inline_meta - T = Core.Compiler.return_type(f, Tuple{$(eltys...)}) - @inbounds return similar_type($first_staticarray, T, Size(newsize))() - end + eltys = Tuple{map(eltype, a)...} + T = Core.Compiler.return_type(f, eltys) + @inbounds return similar_type(first_staticarray, T, Size(newsize))() end + elements = __broadcast(f, sz, s, a...) + @inbounds return similar_type(first_staticarray, eltype(elements), Size(newsize))(elements) +end +@generated function __broadcast(f, ::Size{newsize}, s::Tuple{Vararg{Size}}, a...) where newsize sizes = [sz.parameters[1] for sz ∈ s.parameters] indices = CartesianIndices(newsize) exprs = similar(indices, Expr) @@ -123,8 +132,7 @@ scalar_getindex(x::Ref) = x[] return quote @_inline_meta - @inbounds elements = tuple($(exprs...)) - @inbounds return similar_type($first_staticarray, eltype(elements), Size(newsize))(elements) + @inbounds return elements = tuple($(exprs...)) end end diff --git a/src/precompile.jl b/src/precompile.jl index 8287f149..90e0c2af 100644 --- a/src/precompile.jl +++ b/src/precompile.jl @@ -22,7 +22,7 @@ function _precompile_() end # Some expensive generators - @assert precompile(Tuple{typeof(which(_broadcast,(Any,Size,Tuple{Vararg{Size}},Vararg{Any},)).generator.gen),Any,Any,Any,Any,Any,Any}) + @assert precompile(Tuple{typeof(which(__broadcast,(Any,Size,Tuple{Vararg{Size}},Vararg{Any},)).generator.gen),Any,Any,Any,Any,Any,Any}) @assert precompile(Tuple{typeof(which(_zeros,(Size,Type{<:StaticArray},)).generator.gen),Any,Any,Any,Type,Any}) @assert precompile(Tuple{typeof(which(combine_sizes,(Tuple{Vararg{Size}},)).generator.gen),Any,Any}) @assert precompile(Tuple{typeof(which(_mapfoldl,(Any,Any,Colon,Any,Size,Vararg{StaticArray},)).generator.gen),Any,Any,Any,Any,Any,Any,Any,Any}) diff --git a/test/broadcast.jl b/test/broadcast.jl index e039039e..1baa4f28 100644 --- a/test/broadcast.jl +++ b/test/broadcast.jl @@ -282,3 +282,36 @@ end end end + +# A help struct to test style-based broadcast dispatch with unknown array wrapper. +# `WrapArray(A)` behaves like `A` during broadcast. But its not a `StaticArray`. +struct WrapArray{T,N,P<:AbstractArray{T,N}} <: AbstractArray{T,N} + data::P +end +Base.@propagate_inbounds Base.getindex(A::WrapArray, i::Integer...) = A.data[i...] +Base.@propagate_inbounds Base.setindex!(A::WrapArray, v::Any, i::Integer...) = setindex!(A.data, v, i...) +Base.size(A::WrapArray) = size(A.data) +Base.axes(A::WrapArray) = axes(A.data) +Broadcast.BroadcastStyle(::Type{WrapArray{T,N,P}}) where {T,N,P} = Broadcast.BroadcastStyle(P) +StaticArrays.isstatic(A::WrapArray) = StaticArrays.isstatic(A.data) +StaticArrays.Size(::Type{WrapArray{T,N,P}}) where {T,N,P} = StaticArrays.Size(P) +function StaticArrays.similar_type(::Type{WrapArray{T,N,P}}, ::Type{t}, s::Size{S}) where {T,N,P,t,S} + return StaticArrays.similar_type(P, t, s) +end + +@testset "Broadcast with unknown wrapper" begin + data = (1, 2) + for T in (SVector{2}, MVector{2}) + destT = T <: SArray ? SArray : MArray + a = T(data) + for b in (WrapArray(a), WrapArray(a')) + @test @inferred(b .+ a) isa destT + @test @inferred(b .+ b) isa destT + @test @inferred(b .+ (1, 2)) isa destT + @test @inferred(b .+ a') isa destT + @test @inferred(a' .+ b) isa destT + # @test @inferred(b' .+ a') isa StaticMatrix # Adjoint doesn't propagate style + @test b .+ b.data == b .+ b == b.data .+ b.data + end + end +end