From ee3d77e9e16c2a8975caffebd8e49bdeaad8bcf3 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sun, 3 Nov 2024 20:31:33 -0500 Subject: [PATCH] refactor: restrict Reactant to 0.2.4 --- Project.toml | 2 +- docs/Project.toml | 2 +- lib/MLDataDevices/Project.toml | 2 +- .../ext/MLDataDevicesReactantExt.jl | 29 +++++++------------ 4 files changed, 14 insertions(+), 21 deletions(-) diff --git a/Project.toml b/Project.toml index f584e52c1..361dece44 100644 --- a/Project.toml +++ b/Project.toml @@ -98,7 +98,7 @@ NNlib = "0.9.24" Optimisers = "0.3.3" Preferences = "1.4.3" Random = "1.10" -Reactant = "0.2.3" +Reactant = "0.2.4" Reexport = "1.2.2" ReverseDiff = "1.15" SIMDTypes = "0.1" diff --git a/docs/Project.toml b/docs/Project.toml index a3f7543ef..f7d98eb37 100644 --- a/docs/Project.toml +++ b/docs/Project.toml @@ -54,7 +54,7 @@ Optimisers = "0.3.3" Pkg = "1.10" Printf = "1.10" Random = "1.10" -Reactant = "0.2.1" +Reactant = "0.2.4" StableRNGs = "1" StaticArrays = "1" WeightInitializers = "1" diff --git a/lib/MLDataDevices/Project.toml b/lib/MLDataDevices/Project.toml index 640a309c7..3cc272fd3 100644 --- a/lib/MLDataDevices/Project.toml +++ b/lib/MLDataDevices/Project.toml @@ -57,7 +57,7 @@ MLUtils = "0.4.4" Metal = "1" Preferences = "1.4" Random = "1.10" -Reactant = "0.2" +Reactant = "0.2.4" RecursiveArrayTools = "3.8" ReverseDiff = "1.15" SparseArrays = "1.10" diff --git a/lib/MLDataDevices/ext/MLDataDevicesReactantExt.jl b/lib/MLDataDevices/ext/MLDataDevicesReactantExt.jl index 2e53362c6..a21486bc9 100644 --- a/lib/MLDataDevices/ext/MLDataDevicesReactantExt.jl +++ b/lib/MLDataDevices/ext/MLDataDevicesReactantExt.jl @@ -2,17 +2,8 @@ module MLDataDevicesReactantExt using Adapt: Adapt using MLDataDevices: MLDataDevices, Internal, ReactantDevice, CPUDevice, get_device_type -using Reactant: Reactant, XLA, RArray, ConcreteRArray, TracedRArray, TracedRNumber - -@static if isdefined(Reactant, :ConcreteRNumber) - const ConcreteRType = Union{ConcreteRArray, Reactant.ConcreteRNumber} - const ReactantType = Union{ - RArray, TracedRArray, TracedRNumber, Reactant.ConcreteRNumber - } -else - const ConcreteRType = ConcreteRArray - const ReactantType = Union{RArray, TracedRArray, TracedRNumber} -end +using Reactant: Reactant, XLA, RArray, ConcreteRArray, ConcreteRNumber, TracedRArray, + TracedRNumber MLDataDevices.loaded(::Union{ReactantDevice, Type{<:ReactantDevice}}) = true MLDataDevices.functional(::Union{ReactantDevice, Type{<:ReactantDevice}}) = true @@ -23,7 +14,7 @@ function MLDataDevices.default_device_rng(::ReactantDevice) end # Query Device from Array -function Internal.get_device(x::ConcreteRType) +function Internal.get_device(x::Union{ConcreteRNumber, ConcreteRArray}) client = XLA.client(x.data) device = XLA.device(x.data) return ReactantDevice(client, device) @@ -33,13 +24,17 @@ function Internal.get_device(::Union{TracedRArray, TracedRNumber}) error("`get_device` isn't meant to be called inside `Reactant.@compile` context.") end -Internal.get_device_type(::ReactantType) = ReactantDevice +function Internal.get_device_type( + ::Union{TracedRArray, TracedRNumber, ConcreteRArray, ConcreteRNumber}) + return ReactantDevice +end # unsafe_free! Internal.unsafe_free_internal!(::Type{ReactantDevice}, x::AbstractArray) = nothing # Device Transfer -function Adapt.adapt_storage(dev::ReactantDevice, x::AbstractArray{<:Reactant.ReactantPrimitive}) +function Adapt.adapt_storage( + dev::ReactantDevice, x::AbstractArray{<:Reactant.ReactantPrimitive}) @warn "ReactantDevice got an array on device: $(get_device_type(x)). We will have to \ transfer this via CPU." maxlog=1 return Adapt.adapt_storage(dev, Adapt.adapt_storage(CPUDevice(), x)) @@ -47,10 +42,8 @@ end function Adapt.adapt_storage(dev::ReactantDevice, x::Array{<:Reactant.ReactantPrimitive}) client = dev.client === missing ? XLA.default_backend[] : dev.client - device = dev.device === missing ? - XLA.ClientGetDevice(client, XLA.default_device_idx[]) : dev.device - return ConcreteRArray{eltype(x), ndims(x)}( - XLA.AsyncBuffer(XLA.ArrayFromHostBuffer(client, x, device), nothing), size(x)) + device = dev.device === missing ? nothing : dev.device + return ConcreteRArray(x; client, device) end end