From fef311e6e0a23384376a729f334a7d400bfc4df8 Mon Sep 17 00:00:00 2001 From: N5N3 <2642243996@qq.com> Date: Wed, 27 Jul 2022 10:02:55 +0800 Subject: [PATCH 1/9] Try to resolve style conflict. --- src/structarray.jl | 28 ++++++++++++++--- test/runtests.jl | 75 +++++++++++++++++++++++++++++++++++----------- 2 files changed, 82 insertions(+), 21 deletions(-) diff --git a/src/structarray.jl b/src/structarray.jl index d4bf529f..14c22e3f 100644 --- a/src/structarray.jl +++ b/src/structarray.jl @@ -486,7 +486,7 @@ function Base.showarg(io::IO, s::StructArray{T}, toplevel) where T end # broadcast -import Base.Broadcast: BroadcastStyle, ArrayStyle, AbstractArrayStyle, Broadcasted, DefaultArrayStyle +import Base.Broadcast: BroadcastStyle, AbstractArrayStyle, Broadcasted, DefaultArrayStyle, Unknown struct StructArrayStyle{S, N} <: AbstractArrayStyle{N} end @@ -496,19 +496,39 @@ function StructArrayStyle{S, M}(::Val{N}) where {S, M, N} return StructArrayStyle{T, N}() end +# StructArrayStyle is a wrapped style. +# Here we try our best to resolve style conflict. +function BroadcastStyle(b::AbstractArrayStyle{M}, a::StructArrayStyle{S, N}) where {S, N, M} + N′ = M === Any || N === Any ? Any : max(M, N) + S′ = Broadcast.result_style(S(), b) + return S′ isa StructArrayStyle ? typeof(S′)(Val{N′}()) : StructArrayStyle{typeof(S′), N′}() +end +BroadcastStyle(::StructArrayStyle, ::DefaultArrayStyle) = Unknown() + @inline combine_style_types(::Type{A}, args...) where {A<:AbstractArray} = combine_style_types(BroadcastStyle(A), args...) @inline combine_style_types(s::BroadcastStyle, ::Type{A}, args...) where {A<:AbstractArray} = combine_style_types(Broadcast.result_style(s, BroadcastStyle(A)), args...) +combine_style_types(::StructArrayStyle{S}) where {S} = S() # avoid nested StructArrayStyle combine_style_types(s::BroadcastStyle) = s Base.@pure cst(::Type{SA}) where {SA} = combine_style_types(array_types(SA).parameters...) BroadcastStyle(::Type{SA}) where {SA<:StructArray} = StructArrayStyle{typeof(cst(SA)), ndims(SA)}() -function Base.similar(bc::Broadcasted{StructArrayStyle{S, N}}, ::Type{ElType}) where {S<:Union{DefaultArrayStyle,StructArrayStyle}, N, ElType} - ContainerType = isnonemptystructtype(ElType) ? StructArray{ElType} : Array{ElType} - return similar(ContainerType, axes(bc)) +# Here we use `similar` defined for `S` to build the dest Array. +function Base.similar(bc::Broadcasted{StructArrayStyle{S, N}}, ::Type{ElType}) where {S, N, ElType} + bc′ = convert(Broadcasted{S}, bc) + return isnonemptystructtype(ElType) ? buildfromschema(T -> similar(bc′, T), ElType) : similar(bc′, ElType) +end + +# Unwrapper to recover the behaviour defined by parent style. +@inline function Base.copyto!(dest::AbstractArray, bc::Broadcasted{StructArrayStyle{S, N}}) where {S, N} + return copyto!(dest, convert(Broadcasted{S}, bc)) +end + +@inline function Broadcast.materialize!(::StructArrayStyle{S}, dest, bc::Broadcasted) where {S} + return Broadcast.materialize!(S(), dest, bc) end # for aliasing analysis during broadcast diff --git a/test/runtests.jl b/test/runtests.jl index 4693ca1b..9ceef216 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1100,17 +1100,26 @@ Adapt.adapt_storage(::ArrayConverter, xs::AbstractArray) = convert(Array, xs) @test t.b.d isa Array end -struct MyArray{T,N} <: AbstractArray{T,N} - A::Array{T,N} +for S in (1, 2, 3) + MyArray = Symbol(:MyArray, S) + @eval begin + struct $MyArray{T,N} <: AbstractArray{T,N} + A::Array{T,N} + end + $MyArray{T}(::UndefInitializer, sz::Dims) where T = $MyArray(Array{T}(undef, sz)) + Base.IndexStyle(::Type{<:$MyArray}) = IndexLinear() + Base.getindex(A::$MyArray, i::Int) = A.A[i] + Base.setindex!(A::$MyArray, val, i::Int) = A.A[i] = val + Base.size(A::$MyArray) = Base.size(A.A) + Base.BroadcastStyle(::Type{<:$MyArray}) = Broadcast.ArrayStyle{$MyArray}() + end end -MyArray{T}(::UndefInitializer, sz::Dims) where T = MyArray(Array{T}(undef, sz)) -Base.IndexStyle(::Type{<:MyArray}) = IndexLinear() -Base.getindex(A::MyArray, i::Int) = A.A[i] -Base.setindex!(A::MyArray, val, i::Int) = A.A[i] = val -Base.size(A::MyArray) = Base.size(A.A) -Base.BroadcastStyle(::Type{<:MyArray}) = Broadcast.ArrayStyle{MyArray}() -Base.similar(bc::Broadcast.Broadcasted{Broadcast.ArrayStyle{MyArray}}, ::Type{ElType}) where ElType = - MyArray{ElType}(undef, size(bc)) +Base.similar(bc::Broadcast.Broadcasted{Broadcast.ArrayStyle{MyArray1}}, ::Type{ElType}) where ElType = + MyArray1{ElType}(undef, size(bc)) +Base.similar(bc::Broadcast.Broadcasted{Broadcast.ArrayStyle{MyArray2}}, ::Type{ElType}) where ElType = + MyArray2{ElType}(undef, size(bc)) +Base.BroadcastStyle(::Broadcast.ArrayStyle{MyArray1}, ::Broadcast.ArrayStyle{MyArray3}) = Broadcast.ArrayStyle{MyArray1}() +Base.BroadcastStyle(::Broadcast.ArrayStyle{MyArray2}, S::Broadcast.DefaultArrayStyle) = S @testset "broadcast" begin s = StructArray{ComplexF64}((rand(2,2), rand(2,2))) @@ -1128,19 +1137,34 @@ Base.similar(bc::Broadcast.Broadcasted{Broadcast.ArrayStyle{MyArray}}, ::Type{El # used inside of broadcast but we also test it here explicitly @test isa(@inferred(Base.dataids(s)), NTuple{N, UInt} where {N}) - s = StructArray{ComplexF64}((MyArray(rand(2)), MyArray(rand(2)))) - @test_throws MethodError s .+ s + # Make sure we can handle style with similar defined + # And we can handle most conflict + # s1 and s2 has similar defined, but s3 not + # s2 are conflict with s1 and s3. (And it's weaker than DefaultArrayStyle) + s1 = StructArray{ComplexF64}((MyArray1(rand(2)), MyArray1(rand(2)))) + s2 = StructArray{ComplexF64}((MyArray2(rand(2)), MyArray2(rand(2)))) + s3 = StructArray{ComplexF64}((MyArray3(rand(2)), MyArray3(rand(2)))) + s4 = StructArray{ComplexF64}((rand(2), rand(2))) + + function _test_similar(a, b, c) + try + d = StructArray{ComplexF64}((a.re .+ b.re .- c.re, a.im .+ b.im .- c.im)) + @test typeof(a .+ b .- c) == typeof(d) + catch + @test_throws MethodError a .+ b .- c + end + end + for s in (s1,s2,s3,s4), s′ in (s1,s2,s3,s4), s″ in (s1,s2,s3,s4) + _test_similar(s, s′, s″) + end # test for dimensionality track + s = s1 @test Base.broadcasted(+, s, s) isa Broadcast.Broadcasted{<:Broadcast.AbstractArrayStyle{1}} @test Base.broadcasted(+, s, 1:2) isa Broadcast.Broadcasted{<:Broadcast.AbstractArrayStyle{1}} @test Base.broadcasted(+, s, reshape(1:2,1,2)) isa Broadcast.Broadcasted{<:Broadcast.AbstractArrayStyle{2}} @test Base.broadcasted(+, reshape(1:2,1,1,2), s) isa Broadcast.Broadcasted{<:Broadcast.AbstractArrayStyle{3}} - - a = StructArray([1;2+im]) - b = StructArray([1;;2+im]) - @test a .+ b == a .+ collect(b) == collect(a) .+ b == collect(a) .+ collect(b) - @test a .+ Any[1] isa StructArray + @test Base.broadcasted(+, s, MyArray1(rand(2))) isa Broadcast.Broadcasted{<:Broadcast.AbstractArrayStyle{Any}} # issue #185 A = StructArray(randn(ComplexF64, 3, 3)) @@ -1155,6 +1179,23 @@ Base.similar(bc::Broadcast.Broadcasted{Broadcast.ArrayStyle{MyArray}}, ::Type{El @test identity.(StructArray(x=StructArray(a=1:3)))::StructArray == [(x=(a=1,),), (x=(a=2,),), (x=(a=3,),)] @test (x -> x.x.a).(StructArray(x=StructArray(a=1:3))) == [1, 2, 3] + + @testset "ambiguity check" begin + function _test(a, b, c) + if a isa StructArray || b isa StructArray || c isa StructArray + d = @inferred a .+ b .- c + @test d == collect(a) .+ collect(b) .- collect(c) + @test d isa StructArray + end + end + testset = Any[StructArray([1;2+im]), + 1:2, + (1,2), + ] + for aa in testset, bb in testset, cc in testset + _test(aa, bb, cc) + end + end end @testset "map" begin From a78cab2a41e34a75718a8ce6ba8a438d0f423348 Mon Sep 17 00:00:00 2001 From: N5N3 <2642243996@qq.com> Date: Tue, 26 Jul 2022 22:12:01 +0800 Subject: [PATCH 2/9] Convert to `StaticArrayStyle` We first call broadcast from `StaticArrays` then split the output. This should has no extra runtime overhead. But some type info might missing because the eltype change. I think there's no better ways as we don't want to depend on the full `StaticArrays`. We don't overloading `Size` and `similar_type` at present. as they are only used for `broadcast`. With this, we can move much less code to `StaticArraysCore`. The only downside is that SizedArray would be allocated twice. That's not idea, but we can't do any better if we don't depend on StaticArray or copy a lot of code from there. --- src/interface.jl | 4 ++- src/staticarrays_support.jl | 49 ++++++++++++++++++++++++++++++++++++- test/runtests.jl | 14 ++++++++++- 3 files changed, 64 insertions(+), 3 deletions(-) diff --git a/src/interface.jl b/src/interface.jl index 461e2d49..d82aa4f6 100644 --- a/src/interface.jl +++ b/src/interface.jl @@ -49,4 +49,6 @@ function createinstance(::Type{T}, args...) where {T} isconcretetype(T) ? bypass_constructor(T, args) : T(args...) end -createinstance(::Type{T}, args...) where {T<:Tup} = T(args) \ No newline at end of file +createinstance(::Type{T}, args...) where {T<:Tup} = T(args) + +createinstance(::Type{T}) where {T} = (x...) -> createinstance(T, x...) diff --git a/src/staticarrays_support.jl b/src/staticarrays_support.jl index 3fa9af98..a796b578 100644 --- a/src/staticarrays_support.jl +++ b/src/staticarrays_support.jl @@ -1,4 +1,4 @@ -import StaticArraysCore: StaticArray, FieldArray, tuple_prod +using StaticArraysCore: StaticArray, FieldArray, tuple_prod """ StructArrays.staticschema(::Type{<:StaticArray{S, T}}) where {S, T} @@ -27,3 +27,50 @@ StructArrays.component(s::StaticArray, i) = getindex(s, i) end StructArrays.component(s::FieldArray, i) = invoke(StructArrays.component, Tuple{Any, Any}, s, i) StructArrays.createinstance(T::Type{<:FieldArray}, args...) = invoke(createinstance, Tuple{Type{<:Any}, Vararg}, T, args...) + +# Broadcast overload +using StaticArraysCore: StaticArrayStyle, similar_type +StructStaticArrayStyle{N} = StructArrayStyle{StaticArrayStyle{N}, N} +function Broadcast.instantiate(bc::Broadcasted{StructStaticArrayStyle{M}}) where {M} + bc′ = Broadcast.instantiate(replace_structarray(bc)) + return convert(Broadcasted{StructStaticArrayStyle{M}}, bc′) +end +# This looks costy, but compiler should be able to optimize them away +Broadcast._axes(bc::Broadcasted{<:StructStaticArrayStyle}, ::Nothing) = axes(replace_structarray(bc)) + +to_staticstyle(@nospecialize(x::Type)) = x +to_staticstyle(::Type{StructStaticArrayStyle{N}}) where {N} = StaticArrayStyle{N} +function replace_structarray(bc::Broadcasted{Style}) where {Style} + args = replace_structarray_args(bc.args) + return Broadcasted{to_staticstyle(Style)}(bc.f, args, nothing) +end +function replace_structarray(A::StructArray) + f = createinstance(eltype(A)) + args = Tuple(components(A)) + return Broadcasted{StaticArrayStyle{ndims(A)}}(f, args, nothing) +end +replace_structarray(@nospecialize(A)) = A + +replace_structarray_args(args::Tuple) = (replace_structarray(args[1]), replace_structarray_args(Base.tail(args))...) +replace_structarray_args(::Tuple{}) = () + +# StaticArrayStyle has no similar defined. +# Overload `Base.copy` instead. +@inline function Base.copy(bc::Broadcasted{StructStaticArrayStyle{M}}) where {M} + sa = copy(convert(Broadcasted{StaticArrayStyle{M}}, bc)) + ET = eltype(sa) + isnonemptystructtype(ET) || return sa + elements = Tuple(sa) + arrs = ntuple(Val(fieldcount(ET))) do i + similar_type(sa, fieldtype(ET, i))(_getfields(elements, i)) + end + return StructArray{ET}(arrs) +end + +@inline function _getfields(x::Tuple, i::Int) + if @generated + return Expr(:tuple, (:(getfield(x[$j], i)) for j in 1:fieldcount(x))...) + else + return map(Base.Fix2(getfield, i), x) + end +end diff --git a/test/runtests.jl b/test/runtests.jl index 9ceef216..33031d46 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1190,12 +1190,24 @@ Base.BroadcastStyle(::Broadcast.ArrayStyle{MyArray2}, S::Broadcast.DefaultArrayS end testset = Any[StructArray([1;2+im]), 1:2, - (1,2), + (1,2), + StructArray(@SArray [1 1+2im]), + (@SArray [1 2]) ] for aa in testset, bb in testset, cc in testset _test(aa, bb, cc) end end + + @testset "StructStaticArray" begin + bclog(s) = log.(s) + test_allocated(f, s) = @test (@allocated f(s)) == 0 + a = @SMatrix [float(i) for i in 1:10, j in 1:10] + b = @SMatrix [0. for i in 1:10, j in 1:10] + s = StructArray{ComplexF64}((a , b)) + @test (@inferred bclog(s)) isa typeof(s) + test_allocated(bclog, s) + end end @testset "map" begin From a6fe8a50159330794fc40936b1e4c416374d8ed6 Mon Sep 17 00:00:00 2001 From: N5N3 <2642243996@qq.com> Date: Wed, 27 Jul 2022 11:06:27 +0800 Subject: [PATCH 3/9] Add GPU broadcast support. --- Project.toml | 9 ++++++--- src/StructArrays.jl | 8 ++++++++ test/runtests.jl | 11 +++++++++++ 3 files changed, 25 insertions(+), 3 deletions(-) diff --git a/Project.toml b/Project.toml index d8049eef..3e2021e7 100644 --- a/Project.toml +++ b/Project.toml @@ -5,19 +5,22 @@ version = "0.6.13" [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" DataAPI = "9a962f9c-6df0-11e9-0e5d-c546b8b5ee8a" +GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527" StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c" Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c" [compat] Adapt = "1, 2, 3" DataAPI = "1" -StaticArraysCore = "1.1" -StaticArrays = "1.5.4" +StaticArraysCore = "1.3" +StaticArrays = "1.5.6" +GPUArraysCore = "~0.1.2" Tables = "1" julia = "1.6" [extras] Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" +JLArrays = "27aeb0d3-9eb9-45fb-866b-73c2ecf80fcb" OffsetArrays = "6fe1bfb0-de20-5000-8ca7-80f57d26f881" PooledArrays = "2dfb63ee-cc39-5dd5-95bd-886bf059d720" StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" @@ -26,4 +29,4 @@ TypedTables = "9d95f2ec-7b3d-5a63-8d20-e2491e220bb9" WeakRefStrings = "ea10d353-3f73-51f8-a26c-33c1cb351aa5" [targets] -test = ["Test", "StaticArrays", "OffsetArrays", "PooledArrays", "TypedTables", "WeakRefStrings", "Documenter"] +test = ["Test", "JLArrays", "StaticArrays", "OffsetArrays", "PooledArrays", "TypedTables", "WeakRefStrings", "Documenter"] diff --git a/src/StructArrays.jl b/src/StructArrays.jl index 6fed453e..f7431992 100644 --- a/src/StructArrays.jl +++ b/src/StructArrays.jl @@ -29,4 +29,12 @@ end import Adapt Adapt.adapt_structure(to, s::StructArray) = replace_storage(x->Adapt.adapt(to, x), s) +# for GPU broadcast +import GPUArraysCore: backend +function backend(::Type{T}) where {T<:StructArray} + backs = map(backend, fieldtypes(array_types(T))) + all(Base.Fix2(===, backs[1]), tail(backs)) || error("backend mismatch!") + return backs[1] +end + end # module diff --git a/test/runtests.jl b/test/runtests.jl index 33031d46..b9b2fd27 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -6,6 +6,7 @@ import Tables, PooledArrays, WeakRefStrings using TypedTables: Table using DataAPI: refarray, refvalue using Adapt: adapt, Adapt +using JLArrays using Test using Documenter: doctest @@ -1208,6 +1209,16 @@ Base.BroadcastStyle(::Broadcast.ArrayStyle{MyArray2}, S::Broadcast.DefaultArrayS @test (@inferred bclog(s)) isa typeof(s) test_allocated(bclog, s) end + + @testset "StructJLArray" begin + bcabs(a) = abs.(a) + bcmul2(a) = 2 .* a + a = StructArray(randn(ComplexF32, 10, 10)) + sa = jl(a) + @test collect(@inferred(bcabs(sa))) == bcabs(a) + @test @inferred(bcmul2(sa)) isa StructArray + @test (sa .+= 1) isa StructArray + end end @testset "map" begin From c165c61e4046282408f79df7c317a361e87d17fb Mon Sep 17 00:00:00 2001 From: N5N3 <2642243996@qq.com> Date: Wed, 24 Aug 2022 20:01:09 +0800 Subject: [PATCH 4/9] 1.6 fix + cov fix --- src/staticarrays_support.jl | 12 ++++++++++-- test/runtests.jl | 4 ++++ 2 files changed, 14 insertions(+), 2 deletions(-) diff --git a/src/staticarrays_support.jl b/src/staticarrays_support.jl index a796b578..09a93a5d 100644 --- a/src/staticarrays_support.jl +++ b/src/staticarrays_support.jl @@ -61,8 +61,16 @@ replace_structarray_args(::Tuple{}) = () ET = eltype(sa) isnonemptystructtype(ET) || return sa elements = Tuple(sa) - arrs = ntuple(Val(fieldcount(ET))) do i - similar_type(sa, fieldtype(ET, i))(_getfields(elements, i)) + @static if VERSION >= v"1.7" + arrs = ntuple(Val(fieldcount(ET))) do i + similar_type(sa, fieldtype(ET, i))(_getfields(elements, i)) + end + else + _fieldtype(::Type{T}) where {T} = i -> fieldtype(T, i) + __fieldtype = _fieldtype(ET) + arrs = ntuple(Val(fieldcount(ET))) do i + similar_type(sa, __fieldtype(i))(_getfields(elements, i)) + end end return StructArray{ET}(arrs) end diff --git a/test/runtests.jl b/test/runtests.jl index b9b2fd27..4be6fe4e 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1208,6 +1208,10 @@ Base.BroadcastStyle(::Broadcast.ArrayStyle{MyArray2}, S::Broadcast.DefaultArrayS s = StructArray{ComplexF64}((a , b)) @test (@inferred bclog(s)) isa typeof(s) test_allocated(bclog, s) + @test abs.(s) .+ ((1,) .+ (1,2,3,4,5,6,7,8,9,10)) isa SMatrix + bc = Base.broadcasted(+, s, s); + bc = Base.broadcasted(+, bc, bc, s); + @test @inferred(Broadcast.axes(bc)) === axes(s) end @testset "StructJLArray" begin From e711ebebcec916c7e42d25c1fc3087fcd515fb43 Mon Sep 17 00:00:00 2001 From: N5N3 <2642243996@qq.com> Date: Sat, 15 Oct 2022 10:45:10 +0800 Subject: [PATCH 5/9] Adopt suggestions and add more internal doc/ comments. Co-Authored-By: Pietro Vertechi <6333339+piever@users.noreply.github.com> --- src/StructArrays.jl | 6 +++--- src/interface.jl | 6 +++++- src/staticarrays_support.jl | 15 ++++++++++++--- test/runtests.jl | 19 ++++++++++++++++--- 4 files changed, 36 insertions(+), 10 deletions(-) diff --git a/src/StructArrays.jl b/src/StructArrays.jl index f7431992..5d256c0d 100644 --- a/src/StructArrays.jl +++ b/src/StructArrays.jl @@ -30,9 +30,9 @@ import Adapt Adapt.adapt_structure(to, s::StructArray) = replace_storage(x->Adapt.adapt(to, x), s) # for GPU broadcast -import GPUArraysCore: backend -function backend(::Type{T}) where {T<:StructArray} - backs = map(backend, fieldtypes(array_types(T))) +import GPUArraysCore +function GPUArraysCore.backend(::Type{T}) where {T<:StructArray} + backs = map(GPUArraysCore.backend, fieldtypes(array_types(T))) all(Base.Fix2(===, backs[1]), tail(backs)) || error("backend mismatch!") return backs[1] end diff --git a/src/interface.jl b/src/interface.jl index d82aa4f6..010fc2f9 100644 --- a/src/interface.jl +++ b/src/interface.jl @@ -51,4 +51,8 @@ end createinstance(::Type{T}, args...) where {T<:Tup} = T(args) -createinstance(::Type{T}) where {T} = (x...) -> createinstance(T, x...) +struct Instantiator{T} end + +Instantiator(::Type{T}) where {T} = Instantiator{T}() + +(::Instantiator{T})(args...) where {T} = createinstance(T, args...) diff --git a/src/staticarrays_support.jl b/src/staticarrays_support.jl index 09a93a5d..1f898f82 100644 --- a/src/staticarrays_support.jl +++ b/src/staticarrays_support.jl @@ -35,23 +35,32 @@ function Broadcast.instantiate(bc::Broadcasted{StructStaticArrayStyle{M}}) where bc′ = Broadcast.instantiate(replace_structarray(bc)) return convert(Broadcasted{StructStaticArrayStyle{M}}, bc′) end -# This looks costy, but compiler should be able to optimize them away +# This looks costly, but the compiler should be able to optimize them away Broadcast._axes(bc::Broadcasted{<:StructStaticArrayStyle}, ::Nothing) = axes(replace_structarray(bc)) to_staticstyle(@nospecialize(x::Type)) = x to_staticstyle(::Type{StructStaticArrayStyle{N}}) where {N} = StaticArrayStyle{N} + +""" + replace_structarray(bc::Broadcasted) + +An internal function transforms the `Broadcasted` with `StructArray` into +an equivalent one without it. This is not a must if the root `BroadcastStyle` +supports `AbstractArray`. But some `BroadcastStyle` limits the input array types, +e.g. `StaticArrayStyle`, thus we have to omit all `StructArray`. +""" function replace_structarray(bc::Broadcasted{Style}) where {Style} args = replace_structarray_args(bc.args) return Broadcasted{to_staticstyle(Style)}(bc.f, args, nothing) end function replace_structarray(A::StructArray) - f = createinstance(eltype(A)) + f = Instantiator(eltype(A)) args = Tuple(components(A)) return Broadcasted{StaticArrayStyle{ndims(A)}}(f, args, nothing) end replace_structarray(@nospecialize(A)) = A -replace_structarray_args(args::Tuple) = (replace_structarray(args[1]), replace_structarray_args(Base.tail(args))...) +replace_structarray_args(args::Tuple) = (replace_structarray(args[1]), replace_structarray_args(tail(args))...) replace_structarray_args(::Tuple{}) = () # StaticArrayStyle has no similar defined. diff --git a/test/runtests.jl b/test/runtests.jl index 4be6fe4e..6055acef 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1101,6 +1101,19 @@ Adapt.adapt_storage(::ArrayConverter, xs::AbstractArray) = convert(Array, xs) @test t.b.d isa Array end +# The following code defines `MyArray1/2/3` with different `BroadcastStyle`s. +# 1. `MyArray1` and `MyArray1` have `similar` defined. +# We use them to simulate `BroadcastStyle` overloading `Base.copyto!`. +# 2. `MyArray3` has no `similar` defined. +# We use it to simulate `BroadcastStyle` overloading `Base.copy`. +# 3. Their resolved style could be summaryized as (`-` means conflict) +# | MyArray1 | MyArray2 | MyArray3 | Array +# ------------------------------------------------------------- +# MyArray1 | MyArray1 | - | MyArray1 | MyArray1 +# MyArray2 | - | MyArray2 | - | MyArray2 +# MyArray3 | MyArray1 | - | MyArray3 | MyArray3 +# Array | MyArray1 | Array | MyArray3 | Array + for S in (1, 2, 3) MyArray = Symbol(:MyArray, S) @eval begin @@ -1139,9 +1152,9 @@ Base.BroadcastStyle(::Broadcast.ArrayStyle{MyArray2}, S::Broadcast.DefaultArrayS @test isa(@inferred(Base.dataids(s)), NTuple{N, UInt} where {N}) # Make sure we can handle style with similar defined - # And we can handle most conflict - # s1 and s2 has similar defined, but s3 not - # s2 are conflict with s1 and s3. (And it's weaker than DefaultArrayStyle) + # And we can handle most conflicts + # `s1` and `s2` have similar defined, but `s3` does not + # `s2` conflicts with `s1` and `s3` and is weaker than `DefaultArrayStyle` s1 = StructArray{ComplexF64}((MyArray1(rand(2)), MyArray1(rand(2)))) s2 = StructArray{ComplexF64}((MyArray2(rand(2)), MyArray2(rand(2)))) s3 = StructArray{ComplexF64}((MyArray3(rand(2)), MyArray3(rand(2)))) From 8c832209a23d45b9fb4837f0bab27f24acd3c24c Mon Sep 17 00:00:00 2001 From: N5N3 <2642243996@qq.com> Date: Sun, 6 Nov 2022 09:52:45 +0800 Subject: [PATCH 6/9] Resolve the review comments 1. Update Project.toml. 2. test `backend`'s inferability. Co-Authored-By: Pietro Vertechi <6333339+piever@users.noreply.github.com> --- Project.toml | 2 +- src/StructArrays.jl | 8 +++++--- test/runtests.jl | 2 ++ 3 files changed, 8 insertions(+), 4 deletions(-) diff --git a/Project.toml b/Project.toml index 3e2021e7..e6ee76f2 100644 --- a/Project.toml +++ b/Project.toml @@ -14,7 +14,7 @@ Adapt = "1, 2, 3" DataAPI = "1" StaticArraysCore = "1.3" StaticArrays = "1.5.6" -GPUArraysCore = "~0.1.2" +GPUArraysCore = "0.1.2" Tables = "1" julia = "1.6" diff --git a/src/StructArrays.jl b/src/StructArrays.jl index 5d256c0d..27e234d5 100644 --- a/src/StructArrays.jl +++ b/src/StructArrays.jl @@ -32,9 +32,11 @@ Adapt.adapt_structure(to, s::StructArray) = replace_storage(x->Adapt.adapt(to, x # for GPU broadcast import GPUArraysCore function GPUArraysCore.backend(::Type{T}) where {T<:StructArray} - backs = map(GPUArraysCore.backend, fieldtypes(array_types(T))) - all(Base.Fix2(===, backs[1]), tail(backs)) || error("backend mismatch!") - return backs[1] + backends = map_params(GPUArraysCore.backend, array_types(T)) + backend, others = backends[1], tail(backends) + isconsistent = mapfoldl(isequal(backend), &, others; init=true) + isconsistent || throw(ArgumentError("all component arrays must have the same GPU backend")) + return backend end end # module diff --git a/test/runtests.jl b/test/runtests.jl index 6055acef..1b119036 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1232,6 +1232,8 @@ Base.BroadcastStyle(::Broadcast.ArrayStyle{MyArray2}, S::Broadcast.DefaultArrayS bcmul2(a) = 2 .* a a = StructArray(randn(ComplexF32, 10, 10)) sa = jl(a) + backend = StructArrays.GPUArraysCore.backend + @test @inferred(backend(sa)) === backend(sa.re) === backend(sa.im) @test collect(@inferred(bcabs(sa))) == bcabs(a) @test @inferred(bcmul2(sa)) isa StructArray @test (sa .+= 1) isa StructArray From 10e6442cad09da46371f0ba99a20292657397cc0 Mon Sep 17 00:00:00 2001 From: N5N3 <2642243996@qq.com> Date: Sun, 6 Nov 2022 10:52:21 +0800 Subject: [PATCH 7/9] also test deep nested `StructArray`'s broadcast. --- test/runtests.jl | 2 ++ 1 file changed, 2 insertions(+) diff --git a/test/runtests.jl b/test/runtests.jl index 1b119036..9a77b0e2 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1193,6 +1193,8 @@ Base.BroadcastStyle(::Broadcast.ArrayStyle{MyArray2}, S::Broadcast.DefaultArrayS @test identity.(StructArray(x=StructArray(a=1:3)))::StructArray == [(x=(a=1,),), (x=(a=2,),), (x=(a=3,),)] @test (x -> x.x.a).(StructArray(x=StructArray(a=1:3))) == [1, 2, 3] + @test identity.(StructArray(x=StructArray(x=StructArray(a=1:3))))::StructArray == [(x=(x=(a=1,),),), (x=(x=(a=2,),),), (x=(x=(a=3,),),)] + @test (x -> x.x.x.a).(StructArray(x=StructArray(x=StructArray(a=1:3)))) == [1, 2, 3] @testset "ambiguity check" begin function _test(a, b, c) From 960e1c778c2cd4408f0115b8ed4089072f43d518 Mon Sep 17 00:00:00 2001 From: N5N3 <2642243996@qq.com> Date: Mon, 7 Nov 2022 03:54:25 +0800 Subject: [PATCH 8/9] Skip redundant test only test each dest style once. --- Project.toml | 7 +++-- test/runtests.jl | 80 ++++++++++++++++++++++++++++-------------------- 2 files changed, 51 insertions(+), 36 deletions(-) diff --git a/Project.toml b/Project.toml index e6ee76f2..4b4f7dd7 100644 --- a/Project.toml +++ b/Project.toml @@ -12,9 +12,9 @@ Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c" [compat] Adapt = "1, 2, 3" DataAPI = "1" -StaticArraysCore = "1.3" -StaticArrays = "1.5.6" GPUArraysCore = "0.1.2" +StaticArrays = "1.5.6" +StaticArraysCore = "1.3" Tables = "1" julia = "1.6" @@ -23,10 +23,11 @@ Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" JLArrays = "27aeb0d3-9eb9-45fb-866b-73c2ecf80fcb" OffsetArrays = "6fe1bfb0-de20-5000-8ca7-80f57d26f881" PooledArrays = "2dfb63ee-cc39-5dd5-95bd-886bf059d720" +Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" TypedTables = "9d95f2ec-7b3d-5a63-8d20-e2491e220bb9" WeakRefStrings = "ea10d353-3f73-51f8-a26c-33c1cb351aa5" [targets] -test = ["Test", "JLArrays", "StaticArrays", "OffsetArrays", "PooledArrays", "TypedTables", "WeakRefStrings", "Documenter"] +test = ["Test", "JLArrays", "StaticArrays", "OffsetArrays", "PooledArrays", "TypedTables", "WeakRefStrings", "Documenter", "Random"] diff --git a/test/runtests.jl b/test/runtests.jl index 9a77b0e2..f6a389b8 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -7,6 +7,7 @@ using TypedTables: Table using DataAPI: refarray, refvalue using Adapt: adapt, Adapt using JLArrays +using Random using Test using Documenter: doctest @@ -1151,29 +1152,39 @@ Base.BroadcastStyle(::Broadcast.ArrayStyle{MyArray2}, S::Broadcast.DefaultArrayS # used inside of broadcast but we also test it here explicitly @test isa(@inferred(Base.dataids(s)), NTuple{N, UInt} where {N}) - # Make sure we can handle style with similar defined - # And we can handle most conflicts - # `s1` and `s2` have similar defined, but `s3` does not - # `s2` conflicts with `s1` and `s3` and is weaker than `DefaultArrayStyle` - s1 = StructArray{ComplexF64}((MyArray1(rand(2)), MyArray1(rand(2)))) - s2 = StructArray{ComplexF64}((MyArray2(rand(2)), MyArray2(rand(2)))) - s3 = StructArray{ComplexF64}((MyArray3(rand(2)), MyArray3(rand(2)))) - s4 = StructArray{ComplexF64}((rand(2), rand(2))) - - function _test_similar(a, b, c) - try - d = StructArray{ComplexF64}((a.re .+ b.re .- c.re, a.im .+ b.im .- c.im)) - @test typeof(a .+ b .- c) == typeof(d) - catch - @test_throws MethodError a .+ b .- c + + @testset "style conflict check" begin + using StructArrays: StructArrayStyle + # Make sure we can handle style with similar defined + # And we can handle most conflicts + # `s1` and `s2` have similar defined, but `s3` does not + # `s2` conflicts with `s1` and `s3` and is weaker than `DefaultArrayStyle` + s1 = StructArray{ComplexF64}((MyArray1(rand(2)), MyArray1(rand(2)))) + s2 = StructArray{ComplexF64}((MyArray2(rand(2)), MyArray2(rand(2)))) + s3 = StructArray{ComplexF64}((MyArray3(rand(2)), MyArray3(rand(2)))) + s4 = StructArray{ComplexF64}((rand(2), rand(2))) + test_set = Any[s1, s2, s3, s4] + tested_style = Any[] + dotaddadd((a, b, c),) = @. a + b + c + for is in Iterators.product(randperm(4), randperm(4), randperm(4)) + as = map(i -> test_set[i], is) + ares = map(a->a.re, as) + aims = map(a->a.im, as) + style = Broadcast.combine_styles(ares...) + if !(style in tested_style) + push!(tested_style, style) + if style isa Broadcast.ArrayStyle{MyArray3} + @test_throws MethodError dotaddadd(as) + else + d = StructArray{ComplexF64}((dotaddadd(ares), dotaddadd(aims))) + @test @inferred(dotaddadd(as))::typeof(d) == d + end + end end + @test length(tested_style) == 5 end - for s in (s1,s2,s3,s4), s′ in (s1,s2,s3,s4), s″ in (s1,s2,s3,s4) - _test_similar(s, s′, s″) - end - # test for dimensionality track - s = s1 + s = StructArray{ComplexF64}((MyArray1(rand(2)), MyArray1(rand(2)))) @test Base.broadcasted(+, s, s) isa Broadcast.Broadcasted{<:Broadcast.AbstractArrayStyle{1}} @test Base.broadcasted(+, s, 1:2) isa Broadcast.Broadcasted{<:Broadcast.AbstractArrayStyle{1}} @test Base.broadcasted(+, s, reshape(1:2,1,2)) isa Broadcast.Broadcasted{<:Broadcast.AbstractArrayStyle{2}} @@ -1197,22 +1208,25 @@ Base.BroadcastStyle(::Broadcast.ArrayStyle{MyArray2}, S::Broadcast.DefaultArrayS @test (x -> x.x.x.a).(StructArray(x=StructArray(x=StructArray(a=1:3)))) == [1, 2, 3] @testset "ambiguity check" begin - function _test(a, b, c) - if a isa StructArray || b isa StructArray || c isa StructArray - d = @inferred a .+ b .- c - @test d == collect(a) .+ collect(b) .- collect(c) - @test d isa StructArray - end - end - testset = Any[StructArray([1;2+im]), + test_set = Any[StructArray([1;2+im]), 1:2, (1,2), - StructArray(@SArray [1 1+2im]), - (@SArray [1 2]) - ] - for aa in testset, bb in testset, cc in testset - _test(aa, bb, cc) + StructArray(@SArray [1;1+2im]), + (@SArray [1 2]), + 1] + tested_style = StructArrayStyle[] + dotaddsub((a, b, c),) = @. a + b - c + for is in Iterators.product(randperm(6), randperm(6), randperm(6)) + as = map(i -> test_set[i], is) + if any(a -> a isa StructArray, as) + style = Broadcast.combine_styles(as...) + if !(style in tested_style) + push!(tested_style, style) + @test @inferred(dotaddsub(as))::StructArray == dotaddsub(map(collect, as)) + end + end end + @test length(tested_style) == 4 end @testset "StructStaticArray" begin From 14c7a84eec1875cbcae1da397044bda1cd7eba2e Mon Sep 17 00:00:00 2001 From: N5N3 <2642243996@qq.com> Date: Mon, 7 Nov 2022 17:50:22 +0800 Subject: [PATCH 9/9] remove `randperm` --- Project.toml | 3 +-- test/runtests.jl | 14 +++++++++----- 2 files changed, 10 insertions(+), 7 deletions(-) diff --git a/Project.toml b/Project.toml index 4b4f7dd7..0d8880ac 100644 --- a/Project.toml +++ b/Project.toml @@ -23,11 +23,10 @@ Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" JLArrays = "27aeb0d3-9eb9-45fb-866b-73c2ecf80fcb" OffsetArrays = "6fe1bfb0-de20-5000-8ca7-80f57d26f881" PooledArrays = "2dfb63ee-cc39-5dd5-95bd-886bf059d720" -Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" TypedTables = "9d95f2ec-7b3d-5a63-8d20-e2491e220bb9" WeakRefStrings = "ea10d353-3f73-51f8-a26c-33c1cb351aa5" [targets] -test = ["Test", "JLArrays", "StaticArrays", "OffsetArrays", "PooledArrays", "TypedTables", "WeakRefStrings", "Documenter", "Random"] +test = ["Test", "JLArrays", "StaticArrays", "OffsetArrays", "PooledArrays", "TypedTables", "WeakRefStrings", "Documenter"] diff --git a/test/runtests.jl b/test/runtests.jl index f6a389b8..9ac56f6f 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -7,7 +7,6 @@ using TypedTables: Table using DataAPI: refarray, refvalue using Adapt: adapt, Adapt using JLArrays -using Random using Test using Documenter: doctest @@ -1166,11 +1165,11 @@ Base.BroadcastStyle(::Broadcast.ArrayStyle{MyArray2}, S::Broadcast.DefaultArrayS test_set = Any[s1, s2, s3, s4] tested_style = Any[] dotaddadd((a, b, c),) = @. a + b + c - for is in Iterators.product(randperm(4), randperm(4), randperm(4)) - as = map(i -> test_set[i], is) + for as in Iterators.product(test_set, test_set, test_set) ares = map(a->a.re, as) aims = map(a->a.im, as) style = Broadcast.combine_styles(ares...) + @test Broadcast.combine_styles(as...) === StructArrayStyle{typeof(style),1}() if !(style in tested_style) push!(tested_style, style) if style isa Broadcast.ArrayStyle{MyArray3} @@ -1216,8 +1215,7 @@ Base.BroadcastStyle(::Broadcast.ArrayStyle{MyArray2}, S::Broadcast.DefaultArrayS 1] tested_style = StructArrayStyle[] dotaddsub((a, b, c),) = @. a + b - c - for is in Iterators.product(randperm(6), randperm(6), randperm(6)) - as = map(i -> test_set[i], is) + for as in Iterators.product(test_set, test_set, test_set) if any(a -> a isa StructArray, as) style = Broadcast.combine_styles(as...) if !(style in tested_style) @@ -1229,6 +1227,12 @@ Base.BroadcastStyle(::Broadcast.ArrayStyle{MyArray2}, S::Broadcast.DefaultArrayS @test length(tested_style) == 4 end + @testset "allocation test" begin + a = StructArray{ComplexF64}(undef, 1) + allocated(a) = @allocated a .+ 1 + @test allocated(a) == 2allocated(a.re) + end + @testset "StructStaticArray" begin bclog(s) = log.(s) test_allocated(f, s) = @test (@allocated f(s)) == 0