Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

test: Reactant support for the models #82

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions test/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
ReTestItems = "817f1d60-ba6b-4fd5-9520-3cf149f6a823"
Reactant = "3c362404-f566-11ee-1572-e11a4b42c853"
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Expand Down Expand Up @@ -51,6 +52,7 @@ NNlib = "0.9.21"
Pkg = "1.10"
Random = "1.10"
ReTestItems = "1.24.0"
Reactant = "0.2.5"
Reexport = "1.2.2"
StableRNGs = "1.0.2"
Test = "1.10"
Expand Down
25 changes: 24 additions & 1 deletion test/layer_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,33 @@

model = Layers.MLP(2, (4, 4, 2), act; norm_layer=norm)
ps, st = Lux.setup(StableRNG(0), model) |> dev
st_test = Lux.testmode(st)

x = randn(Float32, 2, 2) |> aType

@jet model(x, ps, st)
@jet model(x, ps, st_test)

__f = (x, ps) -> sum(abs2, first(model(x, ps, st)))
@test_gradients(__f, x, ps; atol=1e-3, rtol=1e-3,
soft_fail=[AutoFiniteDiff()])

if test_reactant(mode)
set_reactant_backend!(mode)
rdev = reactant_device()

ps_ra = rdev(ps)
st_ra = rdev(st_test)
x_ra = rdev(x)

model_compiled = Reactant.compile(model, (x_ra, ps_ra, st_ra))
if nType === GroupNorm
@test first(model_compiled(x_ra, ps_ra, st_ra)) ≈
zeros(Float32, 2, 2)
else
@test first(model_compiled(x_ra, ps_ra, st_ra)) ≈
Array(first(model(x, ps, st_test)))
end
end
end
end
end
Expand Down Expand Up @@ -217,6 +236,10 @@ end

__f = x -> sum(first(layer(x, ps, st)))
@test_gradients(__f, x; atol=1.0f-3, rtol=1.0f-3, enzyme_set_runtime_activity=true)

# TODO: Reactant testing
# We need to solve https://github.com/EnzymeAD/Reactant.jl/issues/242 and
# https://github.com/EnzymeAD/Reactant.jl/issues/243 first
end
end

Expand Down
20 changes: 20 additions & 0 deletions test/shared_testsetup.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,24 @@ if BACKEND_GROUP == "all" || BACKEND_GROUP == "amdgpu"
using AMDGPU
end

@static if !Sys.iswindows()
@reexport using Reactant
test_reactant(mode) = mode != "amdgpu"
function set_reactant_backend!(mode)
if mode == "cuda"
Reactant.set_default_backend("gpu")
elseif mode == "cpu"
Reactant.set_default_backend("cpu")
end
end
else
test_reactant(::Any) = true
set_reactant_backend!(::Any) = nothing
macro compile(expr)
return :()
end
end

cpu_testing() = BACKEND_GROUP == "all" || BACKEND_GROUP == "cpu"
function cuda_testing()
return (BACKEND_GROUP == "all" || BACKEND_GROUP == "cuda") &&
Expand All @@ -38,5 +56,7 @@ const MODES = begin
end

export MODES, BACKEND_GROUP
export test_reactant, set_reactant_backend!
export @compile

end
165 changes: 164 additions & 1 deletion test/vision_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,10 @@

using Lux, Downloads, JLD2

@static if !Sys.iswindows()
using Reactant
end

function normalize_imagenet(data)
cmean = reshape(Float32[0.485, 0.456, 0.406], (1, 1, 3, 1))
cstd = reshape(Float32[0.229, 0.224, 0.225], (1, 1, 3, 1))
Expand All @@ -23,7 +27,12 @@ function imagenet_acctest(model, ps, st, dev; size=224)
TEST_X = size == 224 ? MONARCH_224 :
(size == 256 ? MONARCH_256 : error("size must be 224 or 256"))
x = TEST_X |> dev
ypred = first(model(x, ps, st)) |> collect |> vec

if dev isa MLDataDevices.ReactantDevice
model = Reactant.compile(model, (x, ps, st))
end

ypred = first(model(x, ps, st)) |> cpu_device() |> collect |> vec
top5 = TEST_LBLS[sortperm(ypred; rev=true)]
return "monarch" in top5
end
Expand All @@ -48,6 +57,23 @@ end
end

GC.gc(true)

if test_reactant(mode)
set_reactant_backend!(mode)
rdev = reactant_device()

ps_ra = rdev(ps)
st_ra = rdev(st)
img_ra = rdev(img)

model_compiled = Reactant.compile(model, (img_ra, ps_ra, st_ra))
@test first(model_compiled(
img_ra, ps_ra, st_ra))≈Array(first(model(img, ps, st))) atol=1e-3 rtol=1e-3

if pretrained
@test imagenet_acctest(model, ps_ra, st_ra, rdev)
end
end
end
end
end
Expand All @@ -63,6 +89,19 @@ end
@test size(first(model(img, ps, st))) == (1000, 2)

GC.gc(true)

if test_reactant(mode)
set_reactant_backend!(mode)
rdev = reactant_device()

ps_ra = rdev(ps)
st_ra = rdev(st)
img_ra = rdev(img)

model_compiled = Reactant.compile(model, (img_ra, ps_ra, st_ra))
@test first(model_compiled(
img_ra, ps_ra, st_ra))≈Array(first(model(img, ps, st))) atol=1e-3 rtol=1e-3
end
end
end

Expand All @@ -77,6 +116,19 @@ end
@test size(first(model(img, ps, st))) == (1000, 2)

GC.gc(true)

if test_reactant(mode)
set_reactant_backend!(mode)
rdev = reactant_device()

ps_ra = rdev(ps)
st_ra = rdev(st)
img_ra = rdev(img)

model_compiled = Reactant.compile(model, (img_ra, ps_ra, st_ra))
@test first(model_compiled(
img_ra, ps_ra, st_ra))≈Array(first(model(img, ps, st))) atol=1e-3 rtol=1e-3
end
end
end

Expand All @@ -91,6 +143,19 @@ end
@test size(first(model(img, ps, st))) == (1000, 2)

GC.gc(true)

if test_reactant(mode)
set_reactant_backend!(mode)
rdev = reactant_device()

ps_ra = rdev(ps)
st_ra = rdev(st)
img_ra = rdev(img)

model_compiled = Reactant.compile(model, (img_ra, ps_ra, st_ra))
@test first(model_compiled(
img_ra, ps_ra, st_ra))≈Array(first(model(img, ps, st))) atol=1e-3 rtol=1e-3
end
end
end

Expand All @@ -110,6 +175,23 @@ end
end

GC.gc(true)

if test_reactant(mode)
set_reactant_backend!(mode)
rdev = reactant_device()

ps_ra = rdev(ps)
st_ra = rdev(st)
img_ra = rdev(img)

model_compiled = Reactant.compile(model, (img_ra, ps_ra, st_ra))
@test first(model_compiled(
img_ra, ps_ra, st_ra))≈Array(first(model(img, ps, st))) atol=1e-3 rtol=1e-3

if pretrained
@test imagenet_acctest(model, ps_ra, st_ra, rdev)
end
end
end
end
end
Expand All @@ -134,6 +216,23 @@ end
end

GC.gc(true)

if test_reactant(mode)
set_reactant_backend!(mode)
rdev = reactant_device()

ps_ra = rdev(ps)
st_ra = rdev(st)
img_ra = rdev(img)

model_compiled = Reactant.compile(model, (img_ra, ps_ra, st_ra))
@test first(model_compiled(
img_ra, ps_ra, st_ra))≈Array(first(model(img, ps, st))) atol=1e-3 rtol=1e-3

if pretrained
@test imagenet_acctest(model, ps_ra, st_ra, rdev)
end
end
end
end
end
Expand All @@ -157,6 +256,23 @@ end
end

GC.gc(true)

if test_reactant(mode)
set_reactant_backend!(mode)
rdev = reactant_device()

ps_ra = rdev(ps)
st_ra = rdev(st)
img_ra = rdev(img)

model_compiled = Reactant.compile(model, (img_ra, ps_ra, st_ra))
@test first(model_compiled(
img_ra, ps_ra, st_ra))≈Array(first(model(img, ps, st))) atol=1e-3 rtol=1e-3

if pretrained
@test imagenet_acctest(model, ps_ra, st_ra, rdev)
end
end
end
end
end
Expand All @@ -177,6 +293,23 @@ end
end

GC.gc(true)

if test_reactant(mode)
set_reactant_backend!(mode)
rdev = reactant_device()

ps_ra = rdev(ps)
st_ra = rdev(st)
img_ra = rdev(img)

model_compiled = Reactant.compile(model, (img_ra, ps_ra, st_ra))
@test first(model_compiled(
img_ra, ps_ra, st_ra))≈Array(first(model(img, ps, st))) atol=1e-3 rtol=1e-3

if pretrained
@test imagenet_acctest(model, ps_ra, st_ra, rdev)
end
end
end
end
end
Expand All @@ -197,6 +330,23 @@ end
end

GC.gc(true)

if test_reactant(mode)
set_reactant_backend!(mode)
rdev = reactant_device()

ps_ra = rdev(ps)
st_ra = rdev(st)
img_ra = rdev(img)

model_compiled = Reactant.compile(model, (img_ra, ps_ra, st_ra))
@test first(model_compiled(
img_ra, ps_ra, st_ra))≈Array(first(model(img, ps, st))) atol=1e-3 rtol=1e-3

if pretrained
@test imagenet_acctest(model, ps_ra, st_ra, rdev)
end
end
end
end
end
Expand All @@ -221,5 +371,18 @@ end
@test size(first(model(img, ps, st))) == (1000, 2)

GC.gc(true)

if test_reactant(mode)
set_reactant_backend!(mode)
rdev = reactant_device()

ps_ra = rdev(ps)
st_ra = rdev(st)
img_ra = rdev(img)

model_compiled = Reactant.compile(model, (img_ra, ps_ra, st_ra))
@test first(model_compiled(
img_ra, ps_ra, st_ra))≈Array(first(model(img, ps, st))) atol=1e-3 rtol=1e-3
end
end
end
Loading