Skip to content

Commit

Permalink
fix: use functors for testing wrapped arrays
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Dec 14, 2024
1 parent fdb0170 commit 8e15802
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 4 deletions.
7 changes: 5 additions & 2 deletions lib/MLDataDevices/Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "MLDataDevices"
uuid = "7e8f7934-dd98-4c1a-8fe8-92b47a384d40"
authors = ["Avik Pal <[email protected]> and contributors"]
version = "1.6.4"
version = "1.6.5"

[deps]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
Expand All @@ -15,6 +15,7 @@ AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e"
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
ChainRules = "082447d4-558c-5d27-93f4-14fc19e9eca2"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66"
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
GPUArrays = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7"
MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
Expand All @@ -32,8 +33,9 @@ oneAPI = "8f75cd03-7ff8-4ecb-9b8f-daf728133b1b"
[extensions]
MLDataDevicesAMDGPUExt = "AMDGPU"
MLDataDevicesCUDAExt = "CUDA"
MLDataDevicesChainRulesExt = "ChainRules"
MLDataDevicesChainRulesCoreExt = "ChainRulesCore"
MLDataDevicesChainRulesExt = "ChainRules"
MLDataDevicesComponentArraysExt = "ComponentArrays"
MLDataDevicesFillArraysExt = "FillArrays"
MLDataDevicesGPUArraysExt = "GPUArrays"
MLDataDevicesMLUtilsExt = "MLUtils"
Expand All @@ -55,6 +57,7 @@ CUDA = "5.2"
ChainRules = "1.51"
ChainRulesCore = "1.23"
Compat = "4.16"
ComponentArrays = "0.15.18"
FillArrays = "1"
Functors = "0.5"
GPUArrays = "10, 11"
Expand Down
8 changes: 8 additions & 0 deletions lib/MLDataDevices/ext/MLDataDevicesComponentArraysExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
module MLDataDevicesComponentArraysExt

using ComponentArrays: ComponentArrays
using MLDataDevices: MLDataDevices

MLDataDevices.isleaf(::ComponentArrays.ComponentArray) = true

end
6 changes: 4 additions & 2 deletions lib/MLDataDevices/src/public.jl
Original file line number Diff line number Diff line change
Expand Up @@ -399,5 +399,7 @@ If `MLDataDevices.isleaf(x::T)` is not defined, then it will fall back to `Funct
"""
isleaf(x) = Functors.isleaf(x)

isleaf(::AbstractArray{T}) where {T} = isbitstype(T) || T <: Number # BigFloat and such are not bitstype
isleaf(::Adapt.WrappedArray) = false
function isleaf(x::AbstractArray{T}) where {T}
parent(x) !== x && return Functors.isleaf(x)
return isbitstype(T) || T <: Number # BigFloat and such are not bitstype
end

0 comments on commit 8e15802

Please sign in to comment.