Skip to content

Commit

Permalink
Generalize StructArray's broadcast. (#215)
Browse files Browse the repository at this point in the history
  • Loading branch information
N5N3 authored Nov 30, 2022
1 parent 4056c71 commit 1afecf4
Show file tree
Hide file tree
Showing 6 changed files with 232 additions and 26 deletions.
9 changes: 6 additions & 3 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,19 +5,22 @@ version = "0.6.13"
[deps]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
DataAPI = "9a962f9c-6df0-11e9-0e5d-c546b8b5ee8a"
GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527"
StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c"
Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c"

[compat]
Adapt = "1, 2, 3"
DataAPI = "1"
StaticArraysCore = "1.1"
StaticArrays = "1.5.4"
GPUArraysCore = "0.1.2"
StaticArrays = "1.5.6"
StaticArraysCore = "1.3"
Tables = "1"
julia = "1.6"

[extras]
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
JLArrays = "27aeb0d3-9eb9-45fb-866b-73c2ecf80fcb"
OffsetArrays = "6fe1bfb0-de20-5000-8ca7-80f57d26f881"
PooledArrays = "2dfb63ee-cc39-5dd5-95bd-886bf059d720"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
Expand All @@ -26,4 +29,4 @@ TypedTables = "9d95f2ec-7b3d-5a63-8d20-e2491e220bb9"
WeakRefStrings = "ea10d353-3f73-51f8-a26c-33c1cb351aa5"

[targets]
test = ["Test", "StaticArrays", "OffsetArrays", "PooledArrays", "TypedTables", "WeakRefStrings", "Documenter"]
test = ["Test", "JLArrays", "StaticArrays", "OffsetArrays", "PooledArrays", "TypedTables", "WeakRefStrings", "Documenter"]
10 changes: 10 additions & 0 deletions src/StructArrays.jl
Original file line number Diff line number Diff line change
Expand Up @@ -29,4 +29,14 @@ end
import Adapt
Adapt.adapt_structure(to, s::StructArray) = replace_storage(x->Adapt.adapt(to, x), s)

# for GPU broadcast
import GPUArraysCore
function GPUArraysCore.backend(::Type{T}) where {T<:StructArray}
backends = map_params(GPUArraysCore.backend, array_types(T))
backend, others = backends[1], tail(backends)
isconsistent = mapfoldl(isequal(backend), &, others; init=true)
isconsistent || throw(ArgumentError("all component arrays must have the same GPU backend"))
return backend
end

end # module
8 changes: 7 additions & 1 deletion src/interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -49,4 +49,10 @@ function createinstance(::Type{T}, args...) where {T}
isconcretetype(T) ? bypass_constructor(T, args) : T(args...)
end

createinstance(::Type{T}, args...) where {T<:Tup} = T(args)
createinstance(::Type{T}, args...) where {T<:Tup} = T(args)

struct Instantiator{T} end

Instantiator(::Type{T}) where {T} = Instantiator{T}()

(::Instantiator{T})(args...) where {T} = createinstance(T, args...)
66 changes: 65 additions & 1 deletion src/staticarrays_support.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import StaticArraysCore: StaticArray, FieldArray, tuple_prod
using StaticArraysCore: StaticArray, FieldArray, tuple_prod

"""
StructArrays.staticschema(::Type{<:StaticArray{S, T}}) where {S, T}
Expand Down Expand Up @@ -27,3 +27,67 @@ StructArrays.component(s::StaticArray, i) = getindex(s, i)
end
StructArrays.component(s::FieldArray, i) = invoke(StructArrays.component, Tuple{Any, Any}, s, i)
StructArrays.createinstance(T::Type{<:FieldArray}, args...) = invoke(createinstance, Tuple{Type{<:Any}, Vararg}, T, args...)

# Broadcast overload
using StaticArraysCore: StaticArrayStyle, similar_type
StructStaticArrayStyle{N} = StructArrayStyle{StaticArrayStyle{N}, N}
function Broadcast.instantiate(bc::Broadcasted{StructStaticArrayStyle{M}}) where {M}
bc′ = Broadcast.instantiate(replace_structarray(bc))
return convert(Broadcasted{StructStaticArrayStyle{M}}, bc′)
end
# This looks costly, but the compiler should be able to optimize them away
Broadcast._axes(bc::Broadcasted{<:StructStaticArrayStyle}, ::Nothing) = axes(replace_structarray(bc))

to_staticstyle(@nospecialize(x::Type)) = x
to_staticstyle(::Type{StructStaticArrayStyle{N}}) where {N} = StaticArrayStyle{N}

"""
replace_structarray(bc::Broadcasted)
An internal function transforms the `Broadcasted` with `StructArray` into
an equivalent one without it. This is not a must if the root `BroadcastStyle`
supports `AbstractArray`. But some `BroadcastStyle` limits the input array types,
e.g. `StaticArrayStyle`, thus we have to omit all `StructArray`.
"""
function replace_structarray(bc::Broadcasted{Style}) where {Style}
args = replace_structarray_args(bc.args)
return Broadcasted{to_staticstyle(Style)}(bc.f, args, nothing)
end
function replace_structarray(A::StructArray)
f = Instantiator(eltype(A))
args = Tuple(components(A))
return Broadcasted{StaticArrayStyle{ndims(A)}}(f, args, nothing)
end
replace_structarray(@nospecialize(A)) = A

replace_structarray_args(args::Tuple) = (replace_structarray(args[1]), replace_structarray_args(tail(args))...)
replace_structarray_args(::Tuple{}) = ()

# StaticArrayStyle has no similar defined.
# Overload `Base.copy` instead.
@inline function Base.copy(bc::Broadcasted{StructStaticArrayStyle{M}}) where {M}
sa = copy(convert(Broadcasted{StaticArrayStyle{M}}, bc))
ET = eltype(sa)
isnonemptystructtype(ET) || return sa
elements = Tuple(sa)
@static if VERSION >= v"1.7"
arrs = ntuple(Val(fieldcount(ET))) do i
similar_type(sa, fieldtype(ET, i))(_getfields(elements, i))
end
else
_fieldtype(::Type{T}) where {T} = i -> fieldtype(T, i)
__fieldtype = _fieldtype(ET)
arrs = ntuple(Val(fieldcount(ET))) do i
similar_type(sa, __fieldtype(i))(_getfields(elements, i))
end
end
return StructArray{ET}(arrs)
end

@inline function _getfields(x::Tuple, i::Int)
if @generated
return Expr(:tuple, (:(getfield(x[$j], i)) for j in 1:fieldcount(x))...)
else
return map(Base.Fix2(getfield, i), x)
end
end
28 changes: 24 additions & 4 deletions src/structarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -486,7 +486,7 @@ function Base.showarg(io::IO, s::StructArray{T}, toplevel) where T
end

# broadcast
import Base.Broadcast: BroadcastStyle, ArrayStyle, AbstractArrayStyle, Broadcasted, DefaultArrayStyle
import Base.Broadcast: BroadcastStyle, AbstractArrayStyle, Broadcasted, DefaultArrayStyle, Unknown

struct StructArrayStyle{S, N} <: AbstractArrayStyle{N} end

Expand All @@ -496,19 +496,39 @@ function StructArrayStyle{S, M}(::Val{N}) where {S, M, N}
return StructArrayStyle{T, N}()
end

# StructArrayStyle is a wrapped style.
# Here we try our best to resolve style conflict.
function BroadcastStyle(b::AbstractArrayStyle{M}, a::StructArrayStyle{S, N}) where {S, N, M}
N′ = M === Any || N === Any ? Any : max(M, N)
S′ = Broadcast.result_style(S(), b)
return S′ isa StructArrayStyle ? typeof(S′)(Val{N′}()) : StructArrayStyle{typeof(S′), N′}()
end
BroadcastStyle(::StructArrayStyle, ::DefaultArrayStyle) = Unknown()

@inline combine_style_types(::Type{A}, args...) where {A<:AbstractArray} =
combine_style_types(BroadcastStyle(A), args...)
@inline combine_style_types(s::BroadcastStyle, ::Type{A}, args...) where {A<:AbstractArray} =
combine_style_types(Broadcast.result_style(s, BroadcastStyle(A)), args...)
combine_style_types(::StructArrayStyle{S}) where {S} = S() # avoid nested StructArrayStyle
combine_style_types(s::BroadcastStyle) = s

Base.@pure cst(::Type{SA}) where {SA} = combine_style_types(array_types(SA).parameters...)

BroadcastStyle(::Type{SA}) where {SA<:StructArray} = StructArrayStyle{typeof(cst(SA)), ndims(SA)}()

function Base.similar(bc::Broadcasted{StructArrayStyle{S, N}}, ::Type{ElType}) where {S<:Union{DefaultArrayStyle,StructArrayStyle}, N, ElType}
ContainerType = isnonemptystructtype(ElType) ? StructArray{ElType} : Array{ElType}
return similar(ContainerType, axes(bc))
# Here we use `similar` defined for `S` to build the dest Array.
function Base.similar(bc::Broadcasted{StructArrayStyle{S, N}}, ::Type{ElType}) where {S, N, ElType}
bc′ = convert(Broadcasted{S}, bc)
return isnonemptystructtype(ElType) ? buildfromschema(T -> similar(bc′, T), ElType) : similar(bc′, ElType)
end

# Unwrapper to recover the behaviour defined by parent style.
@inline function Base.copyto!(dest::AbstractArray, bc::Broadcasted{StructArrayStyle{S, N}}) where {S, N}
return copyto!(dest, convert(Broadcasted{S}, bc))
end

@inline function Broadcast.materialize!(::StructArrayStyle{S}, dest, bc::Broadcasted) where {S}
return Broadcast.materialize!(S(), dest, bc)
end

# for aliasing analysis during broadcast
Expand Down
137 changes: 120 additions & 17 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import Tables, PooledArrays, WeakRefStrings
using TypedTables: Table
using DataAPI: refarray, refvalue
using Adapt: adapt, Adapt
using JLArrays
using Test

using Documenter: doctest
Expand Down Expand Up @@ -1100,17 +1101,39 @@ Adapt.adapt_storage(::ArrayConverter, xs::AbstractArray) = convert(Array, xs)
@test t.b.d isa Array
end

struct MyArray{T,N} <: AbstractArray{T,N}
A::Array{T,N}
# The following code defines `MyArray1/2/3` with different `BroadcastStyle`s.
# 1. `MyArray1` and `MyArray1` have `similar` defined.
# We use them to simulate `BroadcastStyle` overloading `Base.copyto!`.
# 2. `MyArray3` has no `similar` defined.
# We use it to simulate `BroadcastStyle` overloading `Base.copy`.
# 3. Their resolved style could be summaryized as (`-` means conflict)
# | MyArray1 | MyArray2 | MyArray3 | Array
# -------------------------------------------------------------
# MyArray1 | MyArray1 | - | MyArray1 | MyArray1
# MyArray2 | - | MyArray2 | - | MyArray2
# MyArray3 | MyArray1 | - | MyArray3 | MyArray3
# Array | MyArray1 | Array | MyArray3 | Array

for S in (1, 2, 3)
MyArray = Symbol(:MyArray, S)
@eval begin
struct $MyArray{T,N} <: AbstractArray{T,N}
A::Array{T,N}
end
$MyArray{T}(::UndefInitializer, sz::Dims) where T = $MyArray(Array{T}(undef, sz))
Base.IndexStyle(::Type{<:$MyArray}) = IndexLinear()
Base.getindex(A::$MyArray, i::Int) = A.A[i]
Base.setindex!(A::$MyArray, val, i::Int) = A.A[i] = val
Base.size(A::$MyArray) = Base.size(A.A)
Base.BroadcastStyle(::Type{<:$MyArray}) = Broadcast.ArrayStyle{$MyArray}()
end
end
MyArray{T}(::UndefInitializer, sz::Dims) where T = MyArray(Array{T}(undef, sz))
Base.IndexStyle(::Type{<:MyArray}) = IndexLinear()
Base.getindex(A::MyArray, i::Int) = A.A[i]
Base.setindex!(A::MyArray, val, i::Int) = A.A[i] = val
Base.size(A::MyArray) = Base.size(A.A)
Base.BroadcastStyle(::Type{<:MyArray}) = Broadcast.ArrayStyle{MyArray}()
Base.similar(bc::Broadcast.Broadcasted{Broadcast.ArrayStyle{MyArray}}, ::Type{ElType}) where ElType =
MyArray{ElType}(undef, size(bc))
Base.similar(bc::Broadcast.Broadcasted{Broadcast.ArrayStyle{MyArray1}}, ::Type{ElType}) where ElType =
MyArray1{ElType}(undef, size(bc))
Base.similar(bc::Broadcast.Broadcasted{Broadcast.ArrayStyle{MyArray2}}, ::Type{ElType}) where ElType =
MyArray2{ElType}(undef, size(bc))
Base.BroadcastStyle(::Broadcast.ArrayStyle{MyArray1}, ::Broadcast.ArrayStyle{MyArray3}) = Broadcast.ArrayStyle{MyArray1}()
Base.BroadcastStyle(::Broadcast.ArrayStyle{MyArray2}, S::Broadcast.DefaultArrayStyle) = S

@testset "broadcast" begin
s = StructArray{ComplexF64}((rand(2,2), rand(2,2)))
Expand All @@ -1128,19 +1151,44 @@ Base.similar(bc::Broadcast.Broadcasted{Broadcast.ArrayStyle{MyArray}}, ::Type{El
# used inside of broadcast but we also test it here explicitly
@test isa(@inferred(Base.dataids(s)), NTuple{N, UInt} where {N})

s = StructArray{ComplexF64}((MyArray(rand(2)), MyArray(rand(2))))
@test_throws MethodError s .+ s

@testset "style conflict check" begin
using StructArrays: StructArrayStyle
# Make sure we can handle style with similar defined
# And we can handle most conflicts
# `s1` and `s2` have similar defined, but `s3` does not
# `s2` conflicts with `s1` and `s3` and is weaker than `DefaultArrayStyle`
s1 = StructArray{ComplexF64}((MyArray1(rand(2)), MyArray1(rand(2))))
s2 = StructArray{ComplexF64}((MyArray2(rand(2)), MyArray2(rand(2))))
s3 = StructArray{ComplexF64}((MyArray3(rand(2)), MyArray3(rand(2))))
s4 = StructArray{ComplexF64}((rand(2), rand(2)))
test_set = Any[s1, s2, s3, s4]
tested_style = Any[]
dotaddadd((a, b, c),) = @. a + b + c
for as in Iterators.product(test_set, test_set, test_set)
ares = map(a->a.re, as)
aims = map(a->a.im, as)
style = Broadcast.combine_styles(ares...)
@test Broadcast.combine_styles(as...) === StructArrayStyle{typeof(style),1}()
if !(style in tested_style)
push!(tested_style, style)
if style isa Broadcast.ArrayStyle{MyArray3}
@test_throws MethodError dotaddadd(as)
else
d = StructArray{ComplexF64}((dotaddadd(ares), dotaddadd(aims)))
@test @inferred(dotaddadd(as))::typeof(d) == d
end
end
end
@test length(tested_style) == 5
end
# test for dimensionality track
s = StructArray{ComplexF64}((MyArray1(rand(2)), MyArray1(rand(2))))
@test Base.broadcasted(+, s, s) isa Broadcast.Broadcasted{<:Broadcast.AbstractArrayStyle{1}}
@test Base.broadcasted(+, s, 1:2) isa Broadcast.Broadcasted{<:Broadcast.AbstractArrayStyle{1}}
@test Base.broadcasted(+, s, reshape(1:2,1,2)) isa Broadcast.Broadcasted{<:Broadcast.AbstractArrayStyle{2}}
@test Base.broadcasted(+, reshape(1:2,1,1,2), s) isa Broadcast.Broadcasted{<:Broadcast.AbstractArrayStyle{3}}

a = StructArray([1;2+im])
b = StructArray([1;;2+im])
@test a .+ b == a .+ collect(b) == collect(a) .+ b == collect(a) .+ collect(b)
@test a .+ Any[1] isa StructArray
@test Base.broadcasted(+, s, MyArray1(rand(2))) isa Broadcast.Broadcasted{<:Broadcast.AbstractArrayStyle{Any}}

# issue #185
A = StructArray(randn(ComplexF64, 3, 3))
Expand All @@ -1155,6 +1203,61 @@ Base.similar(bc::Broadcast.Broadcasted{Broadcast.ArrayStyle{MyArray}}, ::Type{El

@test identity.(StructArray(x=StructArray(a=1:3)))::StructArray == [(x=(a=1,),), (x=(a=2,),), (x=(a=3,),)]
@test (x -> x.x.a).(StructArray(x=StructArray(a=1:3))) == [1, 2, 3]
@test identity.(StructArray(x=StructArray(x=StructArray(a=1:3))))::StructArray == [(x=(x=(a=1,),),), (x=(x=(a=2,),),), (x=(x=(a=3,),),)]
@test (x -> x.x.x.a).(StructArray(x=StructArray(x=StructArray(a=1:3)))) == [1, 2, 3]

@testset "ambiguity check" begin
test_set = Any[StructArray([1;2+im]),
1:2,
(1,2),
StructArray(@SArray [1;1+2im]),
(@SArray [1 2]),
1]
tested_style = StructArrayStyle[]
dotaddsub((a, b, c),) = @. a + b - c
for as in Iterators.product(test_set, test_set, test_set)
if any(a -> a isa StructArray, as)
style = Broadcast.combine_styles(as...)
if !(style in tested_style)
push!(tested_style, style)
@test @inferred(dotaddsub(as))::StructArray == dotaddsub(map(collect, as))
end
end
end
@test length(tested_style) == 4
end

@testset "allocation test" begin
a = StructArray{ComplexF64}(undef, 1)
allocated(a) = @allocated a .+ 1
@test allocated(a) == 2allocated(a.re)
end

@testset "StructStaticArray" begin
bclog(s) = log.(s)
test_allocated(f, s) = @test (@allocated f(s)) == 0
a = @SMatrix [float(i) for i in 1:10, j in 1:10]
b = @SMatrix [0. for i in 1:10, j in 1:10]
s = StructArray{ComplexF64}((a , b))
@test (@inferred bclog(s)) isa typeof(s)
test_allocated(bclog, s)
@test abs.(s) .+ ((1,) .+ (1,2,3,4,5,6,7,8,9,10)) isa SMatrix
bc = Base.broadcasted(+, s, s);
bc = Base.broadcasted(+, bc, bc, s);
@test @inferred(Broadcast.axes(bc)) === axes(s)
end

@testset "StructJLArray" begin
bcabs(a) = abs.(a)
bcmul2(a) = 2 .* a
a = StructArray(randn(ComplexF32, 10, 10))
sa = jl(a)
backend = StructArrays.GPUArraysCore.backend
@test @inferred(backend(sa)) === backend(sa.re) === backend(sa.im)
@test collect(@inferred(bcabs(sa))) == bcabs(a)
@test @inferred(bcmul2(sa)) isa StructArray
@test (sa .+= 1) isa StructArray
end
end

@testset "map" begin
Expand Down

0 comments on commit 1afecf4

Please sign in to comment.