diff --git a/Project.toml b/Project.toml index 2310f14a..95ba30db 100644 --- a/Project.toml +++ b/Project.toml @@ -3,21 +3,22 @@ uuid = "09ab397b-f2b6-538f-b94a-2f83cf4a842a" version = "0.6.16" [deps] -Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" ConstructionBase = "187b0558-2788-49d3-abe0-74a17ed4e7c9" DataAPI = "9a962f9c-6df0-11e9-0e5d-c546b8b5ee8a" Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c" [weakdeps] +Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527" StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" [extensions] +StructArraysAdaptExt = "Adapt" StructArraysGPUArraysCoreExt = "GPUArraysCore" StructArraysStaticArraysExt = "StaticArrays" [compat] -Adapt = "1, 2, 3" +Adapt = "2, 3" ConstructionBase = "1" DataAPI = "1" GPUArraysCore = "0.1.2" diff --git a/ext/StructArraysAdaptExt.jl b/ext/StructArraysAdaptExt.jl new file mode 100644 index 00000000..2607105d --- /dev/null +++ b/ext/StructArraysAdaptExt.jl @@ -0,0 +1,16 @@ +module StructArraysAdaptExt +# Use Adapt allows for automatic conversion of CPU to GPU StructArrays +using Adapt, StructArrays +@static if !applicable(Adapt.adapt, Int) + # Adapt.jl has curried support, implement it ourself + adpat(to) = Base.Fix1(Adapt.adapt, to) + if VERSION < v"1.9.0-DEV.857" + @eval function adapt(to::Type{T}) where {T} + (@isdefined T) || return Base.Fix1(Adapt.adapt, to) + AT = Base.Fix1{typeof(Adapt.adapt),Type{T}} + return $(Expr(:new, :AT, :(Adapt.adapt), :to)) + end + end +end +Adapt.adapt_structure(to, s::StructArray) = replace_storage(adapt(to), s) +end diff --git a/src/StructArrays.jl b/src/StructArrays.jl index 8cfbdc48..4130d059 100644 --- a/src/StructArrays.jl +++ b/src/StructArrays.jl @@ -25,11 +25,8 @@ function refvalue(s::StructArray{T}, v::Tup) where {T} createinstance(T, map(refvalue, components(s), v)...) end -# Use Adapt allows for automatic conversion of CPU to GPU StructArrays -import Adapt -Adapt.adapt_structure(to, s::StructArray) = replace_storage(x->Adapt.adapt(to, x), s) - @static if !isdefined(Base, :get_extension) + include("../ext/StructArraysAdaptExt.jl") include("../ext/StructArraysGPUArraysCoreExt.jl") include("../ext/StructArraysStaticArraysExt.jl") end