Skip to content

Commit

Permalink
feat: handle RNGs in layers correctly
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Jan 1, 2025
1 parent ce9f77f commit a06c26e
Show file tree
Hide file tree
Showing 4 changed files with 39 additions and 4 deletions.
4 changes: 2 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ LinearAlgebra = "1.10"
LossFunctions = "0.11.1, 1"
LuxCore = "1.2"
LuxLib = "1.3.7"
MLDataDevices = "1.6"
MLDataDevices = "1.6.6"
MLUtils = "0.4.4"
MPI = "0.20.19"
MacroTools = "0.5.13"
Expand All @@ -110,7 +110,7 @@ NNlib = "0.9.26"
Optimisers = "0.4.1"
Preferences = "1.4.3"
Random = "1.10"
Reactant = "0.2.12"
Reactant = "0.2.13"
Reexport = "1.2.2"
ReverseDiff = "1.15"
SIMDTypes = "0.1"
Expand Down
2 changes: 1 addition & 1 deletion lib/MLDataDevices/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ Metal = "1"
OneHotArrays = "0.2.5"
Preferences = "1.4"
Random = "1.10"
Reactant = "0.2.12"
Reactant = "0.2.13"
RecursiveArrayTools = "3.8"
ReverseDiff = "1.15"
SparseArrays = "1.10"
Expand Down
3 changes: 2 additions & 1 deletion lib/MLDataDevices/test/xla_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,8 @@ using FillArrays, Zygote # Extensions

device = reactant_device()
aType = MLDataDevices.functional(ReactantDevice) ? Reactant.ConcreteRArray : Array
rngType = Random.AbstractRNG
rngType = MLDataDevices.functional(ReactantDevice) ? Reactant.ConcreteRNG :
Random.AbstractRNG

ps_xpu = ps |> device
@test get_device(ps_xpu) isa ReactantDevice
Expand Down
34 changes: 34 additions & 0 deletions test/reactant/layer_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -63,3 +63,37 @@ end
end
end
end

@testitem "Dropout Layers" tags=[:reactant] setup=[SharedTestSetup] skip=:(Sys.iswindows()) begin
using Reactant, Lux, Random

@testset "$(mode)" for (mode, atype, dev, ongpu) in MODES
if mode == "amdgpu"
@warn "Skipping AMDGPU tests for Reactant"
continue
end

dev = reactant_device(; force=true)

if ongpu
Reactant.set_default_backend("gpu")
else
Reactant.set_default_backend("cpu")
end

@testset for layer in (AlphaDropout, Dropout, VariationalHiddenDropout)
model = layer(0.5f0)
ps, st = Lux.setup(Random.default_rng(), model) |> dev
x = randn(Float32, 10, 10) |> dev

@test st.rng isa Reactant.ConcreteRNG

hlo = @code_hlo model(x, ps, st)
@test contains(repr(hlo), "stablehlo.rng_bit_generator")

y, st2 = @jit model(x, ps, st)
@test st2.rng isa Reactant.ConcreteRNG
@test st.rng.seed != st2.rng.seed
end
end
end

0 comments on commit a06c26e

Please sign in to comment.