Skip to content

Commit

Permalink
Use weakdeps on Julia v1.9
Browse files Browse the repository at this point in the history
  • Loading branch information
oschulz committed Apr 4, 2023
1 parent f23f3a8 commit 9eeb22b
Show file tree
Hide file tree
Showing 5 changed files with 53 additions and 13 deletions.
10 changes: 10 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,16 @@ GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527"
StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c"
Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c"

[weakdeps]
GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527"
StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c"
Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c"

[extensions]
StructArraysGPUArraysCoreExt = "GPUArraysCore"
StructArraysStaticArraysCoreExt = "StaticArraysCore"
StructArraysTablesExt = "Tables"

[compat]
Adapt = "1, 2, 3"
DataAPI = "1"
Expand Down
19 changes: 19 additions & 0 deletions ext/StructArraysGPUArraysCoreExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
module StructArraysGPUArraysCoreExt


using StructArrays
import GPUArraysCore

# for GPU broadcast
import GPUArraysCore
function GPUArraysCore.backend(::Type{T}) where {T<:StructArray}
backends = map_params(GPUArraysCore.backend, array_types(T))
backend, others = backends[1], tail(backends)
isconsistent = mapfoldl(isequal(backend), &, others; init=true)
isconsistent || throw(ArgumentError("all component arrays must have the same GPU backend"))
return backend
end
always_struct_broadcast(::GPUArraysCore.AbstractGPUArrayStyle) = true


end # module
Original file line number Diff line number Diff line change
@@ -1,3 +1,11 @@
module StructArraysStaticArraysCoreExt


using StructArrays
using StructArrays: StructArrayStyle

using Base.Broadcast: Broadcasted

using StaticArraysCore: StaticArray, FieldArray, tuple_prod

"""
Expand Down Expand Up @@ -66,3 +74,6 @@ end
return map(Base.Fix2(getfield, i), x)
end
end


end # module
7 changes: 7 additions & 0 deletions src/tables.jl → ext/StructArraysTablesExt.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
module StructArraysTablesExt


using StructArrays
import Tables

Tables.isrowtable(::Type{<:StructArray}) = true
Expand Down Expand Up @@ -38,3 +42,6 @@ for (f, g) in zip((:append!, :prepend!), (:push!, :pushfirst!))
end
end
end


end # module
19 changes: 6 additions & 13 deletions src/StructArrays.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,12 @@ include("utils.jl")
include("collect.jl")
include("sort.jl")
include("lazy.jl")
include("tables.jl")
include("staticarrays_support.jl")

@static if !isdefined(Base, :get_extension)
include("../ext/StructArraysGPUArraysCoreExt.jl")
include("../ext/StructArraysTablesExt.jl")
include("../ext/StructArraysStaticArraysCoreExt.jl")
end

# Implement refarray and refvalue to deal with pooled arrays and weakrefstrings effectively
import DataAPI: refarray, refvalue
Expand All @@ -29,15 +33,4 @@ end
import Adapt
Adapt.adapt_structure(to, s::StructArray) = replace_storage(x->Adapt.adapt(to, x), s)

# for GPU broadcast
import GPUArraysCore
function GPUArraysCore.backend(::Type{T}) where {T<:StructArray}
backends = map_params(GPUArraysCore.backend, array_types(T))
backend, others = backends[1], tail(backends)
isconsistent = mapfoldl(isequal(backend), &, others; init=true)
isconsistent || throw(ArgumentError("all component arrays must have the same GPU backend"))
return backend
end
always_struct_broadcast(::GPUArraysCore.AbstractGPUArrayStyle) = true

end # module

0 comments on commit 9eeb22b

Please sign in to comment.