Skip to content
This repository has been archived by the owner on Nov 4, 2024. It is now read-only.

feat: support passing in device and client to XLA #94

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "MLDataDevices"
uuid = "7e8f7934-dd98-4c1a-8fe8-92b47a384d40"
authors = ["Avik Pal <[email protected]> and contributors"]
version = "1.4.2"
version = "1.5.0"

[deps]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
Expand Down Expand Up @@ -57,7 +57,7 @@ MLUtils = "0.4.4"
Metal = "1"
Preferences = "1.4"
Random = "1.10"
Reactant = "0.2"
Reactant = "0.2.3"
RecursiveArrayTools = "3.8"
ReverseDiff = "1.15"
SparseArrays = "1.10"
Expand Down
31 changes: 26 additions & 5 deletions ext/MLDataDevicesReactantExt.jl
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
module MLDataDevicesReactantExt

using Adapt: Adapt
using MLDataDevices: MLDataDevices, Internal, XLADevice, CPUDevice
using Reactant: Reactant, RArray
using MLDataDevices: MLDataDevices, Internal, XLADevice, CPUDevice, get_device_type
using Reactant: Reactant, XLA, RArray, ConcreteRArray, TracedRArray, ConcreteRNumber,
TracedRNumber

# With this extension active, `XLADevice` (passed as an instance or as a type) is reported
# as both loaded and functional.
MLDataDevices.loaded(::Union{XLADevice, Type{<:XLADevice}}) = true
MLDataDevices.functional(::Union{XLADevice, Type{<:XLADevice}}) = true
Expand All @@ -13,14 +14,34 @@ function MLDataDevices.default_device_rng(::XLADevice)
end

# Query Device from Array
Internal.get_device(::RArray) = XLADevice()
# Build the `XLADevice` for a concrete Reactant array or number from the client and device
# that XLA reports for the value's `data` field.
function Internal.get_device(x::Union{ConcreteRArray, ConcreteRNumber})
    buffer = x.data
    return XLADevice(XLA.client(buffer), XLA.device(buffer))
end

# Traced values always raise: querying the device of a traced array or number is rejected
# with an error instead of returning a device.
Internal.get_device(::Union{TracedRArray, TracedRNumber}) = error(
    "`get_device` isn't meant to be called inside `Reactant.@compile` context.")

Internal.get_device_type(::RArray) = XLADevice
# Every Reactant array, traced number, or concrete number maps to the `XLADevice` type.
Internal.get_device_type(::Union{RArray, TracedRNumber, ConcreteRNumber}) = XLADevice

# unsafe_free!
# No-op for `XLADevice`: returns `nothing` and leaves `x` untouched.
Internal.unsafe_free_internal!(::Type{XLADevice}, x::AbstractArray) = nothing

# Device Transfer
Adapt.adapt_storage(::XLADevice, x::AbstractArray) = Reactant.to_rarray(x)
# Fallback for non-`Array` inputs with a Reactant primitive eltype: warn (at most once, via
# `maxlog=1`), move the data to the CPU, then dispatch again on the CPU-side result.
# NOTE(review): presumably the CPU transfer yields an `Array`, so the second call lands on
# the `Array` method instead of re-entering this one — confirm.
function Adapt.adapt_storage(dev::XLADevice, x::AbstractArray{<:Reactant.ReactantPrimitive})
    @warn "XLADevice got an array on device: $(get_device_type(x)). We will have to \
           transfer this via CPU." maxlog=1
    host_side = Adapt.adapt_storage(CPUDevice(), x)
    return Adapt.adapt_storage(dev, host_side)
end

# Direct host-to-XLA transfer for plain `Array`s. A `missing` client on `dev` is replaced by
# `XLA.default_backend[]`; a `missing` device is looked up on the chosen client using
# `XLA.default_device_idx[]`. The host data is then wrapped in a `ConcreteRArray`.
function Adapt.adapt_storage(dev::XLADevice, x::Array{<:Reactant.ReactantPrimitive})
    client = ismissing(dev.client) ? XLA.default_backend[] : dev.client
    device = if ismissing(dev.device)
        XLA.ClientGetDevice(client, XLA.default_device_idx[])
    else
        dev.device
    end
    buffer = XLA.AsyncBuffer(XLA.ArrayFromHostBuffer(client, x, device), nothing)
    return ConcreteRArray{eltype(x), ndims(x)}(buffer, size(x))
end

end
17 changes: 12 additions & 5 deletions src/public.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,10 @@ end
struct MetalDevice <: AbstractGPUDevice end
struct oneAPIDevice <: AbstractGPUDevice end

# TODO: Later we might want to add the client field here?
struct XLADevice <: AbstractAcceleratorDevice end
# Accelerator device handle for XLA. Both fields are keyword-constructible and default to
# `missing`.
# NOTE(review): `missing` appears to mean "fall back to the default client/device" when
# data is transferred — confirm against the Reactant extension.
@kwdef struct XLADevice{C, D} <: AbstractAcceleratorDevice
# XLA client, or `missing` when none was given.
client::C = missing
# XLA device, or `missing` when none was given.
device::D = missing
end

# Fallback for when we don't know the device type
struct UnknownDevice <: AbstractDevice end
Expand Down Expand Up @@ -189,20 +191,25 @@ Return a `CPUDevice` object which can be used to transfer data to CPU.
cpu_device() = CPUDevice()

"""
xla_device(; force::Bool=false) -> Union{XLADevice, CPUDevice}
xla_device(;
force::Bool=false, client=missing, device=missing
) -> Union{XLADevice, CPUDevice}

Return a `XLADevice` object if functional. Otherwise, throw an error if `force` is `true`.
Falls back to `CPUDevice` if `force` is `false`.

`client` and `device` specify the XLA client and the XLA device to use. If either is left
as `missing`, the default client or the default device is used in its place.

!!! danger

This is an experimental feature and might change without deprecations
"""
function xla_device(; force::Bool=false)
function xla_device(; force::Bool=false, client=missing, device=missing)
msg = "`XLADevice` is not loaded or not functional. Load `Reactant.jl` before calling \
this function. Defaulting to CPU."
if loaded(XLADevice)
functional(XLADevice) && return XLADevice()
functional(XLADevice) && return XLADevice(client, device)
msg = "`XLADevice` is loaded but not functional. Defaulting to CPU."
end
force && throw(Internal.DeviceSelectionException("XLA"))
Expand Down
Loading