diff --git a/Project.toml b/Project.toml index d1b4882d..8f1e07bd 100644 --- a/Project.toml +++ b/Project.toml @@ -5,18 +5,21 @@ version = "0.6.11" [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" DataAPI = "9a962f9c-6df0-11e9-0e5d-c546b8b5ee8a" +GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527" StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c" [compat] Adapt = "1, 2, 3" DataAPI = "1" -StaticArrays = "1" +GPUArraysCore = "= 0.1.2" +StaticArrays = ">= 1.4.2" Tables = "1" -julia = "1.3" +julia = "1.6" [extras] Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" +JLArrays = "27aeb0d3-9eb9-45fb-866b-73c2ecf80fcb" OffsetArrays = "6fe1bfb0-de20-5000-8ca7-80f57d26f881" PooledArrays = "2dfb63ee-cc39-5dd5-95bd-886bf059d720" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" @@ -24,4 +27,4 @@ TypedTables = "9d95f2ec-7b3d-5a63-8d20-e2491e220bb9" WeakRefStrings = "ea10d353-3f73-51f8-a26c-33c1cb351aa5" [targets] -test = ["Test", "OffsetArrays", "PooledArrays", "TypedTables", "WeakRefStrings", "Documenter"] +test = ["Test", "JLArrays", "OffsetArrays", "PooledArrays", "TypedTables", "WeakRefStrings", "Documenter"] diff --git a/src/StructArrays.jl b/src/StructArrays.jl index 6fed453e..f7431992 100644 --- a/src/StructArrays.jl +++ b/src/StructArrays.jl @@ -29,4 +29,12 @@ end import Adapt Adapt.adapt_structure(to, s::StructArray) = replace_storage(x->Adapt.adapt(to, x), s) +# for GPU broadcast +import GPUArraysCore: backend +function backend(::Type{T}) where {T<:StructArray} + backs = map(backend, fieldtypes(array_types(T))) + all(Base.Fix2(===, backs[1]), tail(backs)) || error("backend mismatch!") + return backs[1] +end + end # module diff --git a/test/runtests.jl b/test/runtests.jl index ccbba8a4..a33f72d3 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -6,6 +6,7 @@ import Tables, PooledArrays, WeakRefStrings using TypedTables: Table using DataAPI: refarray, refvalue using Adapt: adapt, Adapt +using JLArrays using Test using Documenter: doctest @@ -1178,6 +1179,16 @@ Base.BroadcastStyle(::Broadcast.ArrayStyle{MyArray2}, S::Broadcast.DefaultArrayS @test (@inferred bclog(s)) isa typeof(s) test_allocated(bclog, s) end + + @testset "StructJLArray" begin + bcabs(a) = abs.(a) + bcmul2(a) = 2 .* a + a = StructArray(randn(ComplexF32, 10, 10)) + sa = jl(a) + @test collect(@inferred(bcabs(sa))) == bcabs(a) + @test @inferred(bcmul2(sa)) isa StructArray + @test (sa .+= 1) isa StructArray + end end @testset "map" begin