Skip to content

Commit

Permalink
test: Reactant support for the models
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Nov 8, 2024
1 parent dde207c commit ea76000
Show file tree
Hide file tree
Showing 4 changed files with 210 additions and 2 deletions.
2 changes: 2 additions & 0 deletions test/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
ReTestItems = "817f1d60-ba6b-4fd5-9520-3cf149f6a823"
Reactant = "3c362404-f566-11ee-1572-e11a4b42c853"
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Expand Down Expand Up @@ -51,6 +52,7 @@ NNlib = "0.9.21"
Pkg = "1.10"
Random = "1.10"
ReTestItems = "1.24.0"
Reactant = "0.2.5"
Reexport = "1.2.2"
StableRNGs = "1.0.2"
Test = "1.10"
Expand Down
25 changes: 24 additions & 1 deletion test/layer_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,33 @@

model = Layers.MLP(2, (4, 4, 2), act; norm_layer=norm)
ps, st = Lux.setup(StableRNG(0), model) |> dev
st_test = Lux.testmode(st)

x = randn(Float32, 2, 2) |> aType

@jet model(x, ps, st)
@jet model(x, ps, st_test)

__f = (x, ps) -> sum(abs2, first(model(x, ps, st)))
@test_gradients(__f, x, ps; atol=1e-3, rtol=1e-3,
soft_fail=[AutoFiniteDiff()])

if test_reactant(mode)
set_reactant_backend!(mode)
rdev = reactant_device()

ps_ra = rdev(ps)
st_ra = rdev(st_test)
x_ra = rdev(x)

model_compiled = Reactant.compile(model, (x_ra, ps_ra, st_ra))
if nType === GroupNorm
@test first(model_compiled(x_ra, ps_ra, st_ra))
zeros(Float32, 2, 2)
else
@test first(model_compiled(x_ra, ps_ra, st_ra))
Array(first(model(x, ps, st_test)))
end
end
end
end
end
Expand Down Expand Up @@ -217,6 +236,10 @@ end

__f = x -> sum(first(layer(x, ps, st)))
@test_gradients(__f, x; atol=1.0f-3, rtol=1.0f-3, enzyme_set_runtime_activity=true)

# TODO: Reactant testing
# We need to solve https://github.com/EnzymeAD/Reactant.jl/issues/242 and
# https://github.com/EnzymeAD/Reactant.jl/issues/243 first
end
end

Expand Down
20 changes: 20 additions & 0 deletions test/shared_testsetup.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,24 @@ if BACKEND_GROUP == "all" || BACKEND_GROUP == "amdgpu"
using AMDGPU
end

@static if !Sys.iswindows()
@reexport using Reactant
test_reactant(mode) = mode != "amdgpu"
function set_reactant_backend!(mode)
if mode == "cuda"
Reactant.set_default_backend("gpu")
elseif mode == "cpu"
Reactant.set_default_backend("cpu")
end
end
else
test_reactant(::Any) = true
set_reactant_backend!(::Any) = nothing
macro compile(expr)
return :()
end
end

cpu_testing() = BACKEND_GROUP == "all" || BACKEND_GROUP == "cpu"
function cuda_testing()
return (BACKEND_GROUP == "all" || BACKEND_GROUP == "cuda") &&
Expand All @@ -38,5 +56,7 @@ const MODES = begin
end

export MODES, BACKEND_GROUP
export test_reactant, set_reactant_backend!
export @compile

end
165 changes: 164 additions & 1 deletion test/vision_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,10 @@

using Lux, Downloads, JLD2

@static if !Sys.iswindows()
using Reactant
end

function normalize_imagenet(data)
cmean = reshape(Float32[0.485, 0.456, 0.406], (1, 1, 3, 1))
cstd = reshape(Float32[0.229, 0.224, 0.225], (1, 1, 3, 1))
Expand All @@ -23,7 +27,12 @@ function imagenet_acctest(model, ps, st, dev; size=224)
TEST_X = size == 224 ? MONARCH_224 :
(size == 256 ? MONARCH_256 : error("size must be 224 or 256"))
x = TEST_X |> dev
ypred = first(model(x, ps, st)) |> collect |> vec

if dev isa MLDataDevices.ReactantDevice
model = Reactant.compile(model, (x, ps, st))
end

ypred = first(model(x, ps, st)) |> cpu_device() |> collect |> vec
top5 = TEST_LBLS[sortperm(ypred; rev=true)]
return "monarch" in top5
end
Expand All @@ -48,6 +57,23 @@ end
end

GC.gc(true)

if test_reactant(mode)
set_reactant_backend!(mode)
rdev = reactant_device()

ps_ra = rdev(ps)
st_ra = rdev(st)
img_ra = rdev(img)

model_compiled = Reactant.compile(model, (img_ra, ps_ra, st_ra))
@test first(model_compiled(
img_ra, ps_ra, st_ra))Array(first(model(img, ps, st))) atol=1e-3 rtol=1e-3

if pretrained
@test imagenet_acctest(model, ps_ra, st_ra, rdev)
end
end
end
end
end
Expand All @@ -63,6 +89,19 @@ end
@test size(first(model(img, ps, st))) == (1000, 2)

GC.gc(true)

if test_reactant(mode)
set_reactant_backend!(mode)
rdev = reactant_device()

ps_ra = rdev(ps)
st_ra = rdev(st)
img_ra = rdev(img)

model_compiled = Reactant.compile(model, (img_ra, ps_ra, st_ra))
@test first(model_compiled(
img_ra, ps_ra, st_ra))Array(first(model(img, ps, st))) atol=1e-3 rtol=1e-3
end
end
end

Expand All @@ -77,6 +116,19 @@ end
@test size(first(model(img, ps, st))) == (1000, 2)

GC.gc(true)

if test_reactant(mode)
set_reactant_backend!(mode)
rdev = reactant_device()

ps_ra = rdev(ps)
st_ra = rdev(st)
img_ra = rdev(img)

model_compiled = Reactant.compile(model, (img_ra, ps_ra, st_ra))
@test first(model_compiled(
img_ra, ps_ra, st_ra))Array(first(model(img, ps, st))) atol=1e-3 rtol=1e-3
end
end
end

Expand All @@ -91,6 +143,19 @@ end
@test size(first(model(img, ps, st))) == (1000, 2)

GC.gc(true)

if test_reactant(mode)
set_reactant_backend!(mode)
rdev = reactant_device()

ps_ra = rdev(ps)
st_ra = rdev(st)
img_ra = rdev(img)

model_compiled = Reactant.compile(model, (img_ra, ps_ra, st_ra))
@test first(model_compiled(
img_ra, ps_ra, st_ra))Array(first(model(img, ps, st))) atol=1e-3 rtol=1e-3
end
end
end

Expand All @@ -110,6 +175,23 @@ end
end

GC.gc(true)

if test_reactant(mode)
set_reactant_backend!(mode)
rdev = reactant_device()

ps_ra = rdev(ps)
st_ra = rdev(st)
img_ra = rdev(img)

model_compiled = Reactant.compile(model, (img_ra, ps_ra, st_ra))
@test first(model_compiled(
img_ra, ps_ra, st_ra))Array(first(model(img, ps, st))) atol=1e-3 rtol=1e-3

if pretrained
@test imagenet_acctest(model, ps_ra, st_ra, rdev)
end
end
end
end
end
Expand All @@ -134,6 +216,23 @@ end
end

GC.gc(true)

if test_reactant(mode)
set_reactant_backend!(mode)
rdev = reactant_device()

ps_ra = rdev(ps)
st_ra = rdev(st)
img_ra = rdev(img)

model_compiled = Reactant.compile(model, (img_ra, ps_ra, st_ra))
@test first(model_compiled(
img_ra, ps_ra, st_ra))Array(first(model(img, ps, st))) atol=1e-3 rtol=1e-3

if pretrained
@test imagenet_acctest(model, ps_ra, st_ra, rdev)
end
end
end
end
end
Expand All @@ -157,6 +256,23 @@ end
end

GC.gc(true)

if test_reactant(mode)
set_reactant_backend!(mode)
rdev = reactant_device()

ps_ra = rdev(ps)
st_ra = rdev(st)
img_ra = rdev(img)

model_compiled = Reactant.compile(model, (img_ra, ps_ra, st_ra))
@test first(model_compiled(
img_ra, ps_ra, st_ra))Array(first(model(img, ps, st))) atol=1e-3 rtol=1e-3

if pretrained
@test imagenet_acctest(model, ps_ra, st_ra, rdev)
end
end
end
end
end
Expand All @@ -177,6 +293,23 @@ end
end

GC.gc(true)

if test_reactant(mode)
set_reactant_backend!(mode)
rdev = reactant_device()

ps_ra = rdev(ps)
st_ra = rdev(st)
img_ra = rdev(img)

model_compiled = Reactant.compile(model, (img_ra, ps_ra, st_ra))
@test first(model_compiled(
img_ra, ps_ra, st_ra))Array(first(model(img, ps, st))) atol=1e-3 rtol=1e-3

if pretrained
@test imagenet_acctest(model, ps_ra, st_ra, rdev)
end
end
end
end
end
Expand All @@ -197,6 +330,23 @@ end
end

GC.gc(true)

if test_reactant(mode)
set_reactant_backend!(mode)
rdev = reactant_device()

ps_ra = rdev(ps)
st_ra = rdev(st)
img_ra = rdev(img)

model_compiled = Reactant.compile(model, (img_ra, ps_ra, st_ra))
@test first(model_compiled(
img_ra, ps_ra, st_ra))Array(first(model(img, ps, st))) atol=1e-3 rtol=1e-3

if pretrained
@test imagenet_acctest(model, ps_ra, st_ra, rdev)
end
end
end
end
end
Expand All @@ -221,5 +371,18 @@ end
@test size(first(model(img, ps, st))) == (1000, 2)

GC.gc(true)

if test_reactant(mode)
set_reactant_backend!(mode)
rdev = reactant_device()

ps_ra = rdev(ps)
st_ra = rdev(st)
img_ra = rdev(img)

model_compiled = Reactant.compile(model, (img_ra, ps_ra, st_ra))
@test first(model_compiled(
img_ra, ps_ra, st_ra))Array(first(model(img, ps, st))) atol=1e-3 rtol=1e-3
end
end
end

0 comments on commit ea76000

Please sign in to comment.