From 8e1580220721a3a54d2ac128defa8ce05e62a7e9 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sat, 14 Dec 2024 11:58:31 +0530 Subject: [PATCH] fix: use functors for testing wrapped arrays --- lib/MLDataDevices/Project.toml | 7 +++++-- lib/MLDataDevices/ext/MLDataDevicesComponentArraysExt.jl | 8 ++++++++ lib/MLDataDevices/src/public.jl | 6 ++++-- 3 files changed, 17 insertions(+), 4 deletions(-) create mode 100644 lib/MLDataDevices/ext/MLDataDevicesComponentArraysExt.jl diff --git a/lib/MLDataDevices/Project.toml b/lib/MLDataDevices/Project.toml index f9893771c..2bc461363 100644 --- a/lib/MLDataDevices/Project.toml +++ b/lib/MLDataDevices/Project.toml @@ -1,7 +1,7 @@ name = "MLDataDevices" uuid = "7e8f7934-dd98-4c1a-8fe8-92b47a384d40" authors = ["Avik Pal and contributors"] -version = "1.6.4" +version = "1.6.5" [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" @@ -15,6 +15,7 @@ AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e" CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" ChainRules = "082447d4-558c-5d27-93f4-14fc19e9eca2" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" +ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66" FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b" GPUArrays = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7" MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54" @@ -32,8 +33,9 @@ oneAPI = "8f75cd03-7ff8-4ecb-9b8f-daf728133b1b" [extensions] MLDataDevicesAMDGPUExt = "AMDGPU" MLDataDevicesCUDAExt = "CUDA" -MLDataDevicesChainRulesExt = "ChainRules" MLDataDevicesChainRulesCoreExt = "ChainRulesCore" +MLDataDevicesChainRulesExt = "ChainRules" +MLDataDevicesComponentArraysExt = "ComponentArrays" MLDataDevicesFillArraysExt = "FillArrays" MLDataDevicesGPUArraysExt = "GPUArrays" MLDataDevicesMLUtilsExt = "MLUtils" @@ -55,6 +57,7 @@ CUDA = "5.2" ChainRules = "1.51" ChainRulesCore = "1.23" Compat = "4.16" +ComponentArrays = "0.15.18" FillArrays = "1" Functors = "0.5" GPUArrays = "10, 11" diff --git a/lib/MLDataDevices/ext/MLDataDevicesComponentArraysExt.jl b/lib/MLDataDevices/ext/MLDataDevicesComponentArraysExt.jl new file mode 100644 index 000000000..4b34749b0 --- /dev/null +++ b/lib/MLDataDevices/ext/MLDataDevicesComponentArraysExt.jl @@ -0,0 +1,8 @@ +module MLDataDevicesComponentArraysExt + +using ComponentArrays: ComponentArrays +using MLDataDevices: MLDataDevices + +MLDataDevices.isleaf(::ComponentArrays.ComponentArray) = true + +end diff --git a/lib/MLDataDevices/src/public.jl b/lib/MLDataDevices/src/public.jl index 068b8abf9..583b7dccb 100644 --- a/lib/MLDataDevices/src/public.jl +++ b/lib/MLDataDevices/src/public.jl @@ -399,5 +399,7 @@ If `MLDataDevices.isleaf(x::T)` is not defined, then it will fall back to `Funct """ isleaf(x) = Functors.isleaf(x) -isleaf(::AbstractArray{T}) where {T} = isbitstype(T) || T <: Number # BigFloat and such are not bitstype -isleaf(::Adapt.WrappedArray) = false +function isleaf(x::AbstractArray{T}) where {T} + parent(x) !== x && return Functors.isleaf(x) + return isbitstype(T) || T <: Number # BigFloat and such are not bitstype +end