diff --git a/test/Project.toml b/test/Project.toml
index a05d3bd..a5a5fb5 100644
--- a/test/Project.toml
+++ b/test/Project.toml
@@ -22,6 +22,7 @@ NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
 Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
 Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
 ReTestItems = "817f1d60-ba6b-4fd5-9520-3cf149f6a823"
+Reactant = "3c362404-f566-11ee-1572-e11a4b42c853"
 Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
 StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
 Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
@@ -51,6 +52,7 @@ NNlib = "0.9.21"
 Pkg = "1.10"
 Random = "1.10"
 ReTestItems = "1.24.0"
+Reactant = "0.2.5"
 Reexport = "1.2.2"
 StableRNGs = "1.0.2"
 Test = "1.10"
diff --git a/test/layer_tests.jl b/test/layer_tests.jl
index 3477ba0..52c1c05 100644
--- a/test/layer_tests.jl
+++ b/test/layer_tests.jl
@@ -16,14 +16,33 @@
             model = Layers.MLP(2, (4, 4, 2), act; norm_layer=norm)
             ps, st = Lux.setup(StableRNG(0), model) |> dev
+            st_test = Lux.testmode(st)
 
             x = randn(Float32, 2, 2) |> aType
 
-            @jet model(x, ps, st)
+            @jet model(x, ps, st_test)
 
             __f = (x, ps) -> sum(abs2, first(model(x, ps, st)))
             @test_gradients(__f, x, ps; atol=1e-3, rtol=1e-3, soft_fail=[AutoFiniteDiff()])
+
+            if test_reactant(mode)
+                set_reactant_backend!(mode)
+                rdev = reactant_device()
+
+                ps_ra = rdev(ps)
+                st_ra = rdev(st_test)
+                x_ra = rdev(x)
+
+                model_compiled = Reactant.compile(model, (x_ra, ps_ra, st_ra))
+                if nType === GroupNorm
+                    @test first(model_compiled(x_ra, ps_ra, st_ra)) ≈
+                          zeros(Float32, 2, 2)
+                else
+                    @test first(model_compiled(x_ra, ps_ra, st_ra)) ≈
+                          Array(first(model(x, ps, st_test)))
+                end
+            end
         end
     end
 end
@@ -217,6 +236,10 @@ end
 
         __f = x -> sum(first(layer(x, ps, st)))
         @test_gradients(__f, x; atol=1.0f-3, rtol=1.0f-3, enzyme_set_runtime_activity=true)
+
+        # TODO: Reactant testing
+        # We need to solve https://github.com/EnzymeAD/Reactant.jl/issues/242 and
+        # https://github.com/EnzymeAD/Reactant.jl/issues/243 first
     end
 end
diff --git a/test/shared_testsetup.jl b/test/shared_testsetup.jl
index dc8e429..b5bb538 100644
--- a/test/shared_testsetup.jl
+++ b/test/shared_testsetup.jl
@@ -19,6 +19,24 @@ if BACKEND_GROUP == "all" || BACKEND_GROUP == "amdgpu"
     using AMDGPU
 end
 
+@static if !Sys.iswindows()
+    @reexport using Reactant
+    test_reactant(mode) = mode != "amdgpu"
+    function set_reactant_backend!(mode)
+        if mode == "cuda"
+            Reactant.set_default_backend("gpu")
+        elseif mode == "cpu"
+            Reactant.set_default_backend("cpu")
+        end
+    end
+else
+    test_reactant(::Any) = false
+    set_reactant_backend!(::Any) = nothing
+    macro compile(expr)
+        return :()
+    end
+end
+
 cpu_testing() = BACKEND_GROUP == "all" || BACKEND_GROUP == "cpu"
 function cuda_testing()
     return (BACKEND_GROUP == "all" || BACKEND_GROUP == "cuda") &&
@@ -38,5 +56,7 @@
 end
 
 export MODES, BACKEND_GROUP
+export test_reactant, set_reactant_backend!
+export @compile
 
 end
diff --git a/test/vision_tests.jl b/test/vision_tests.jl
index 449e90d..12903b1 100644
--- a/test/vision_tests.jl
+++ b/test/vision_tests.jl
@@ -2,6 +2,10 @@
 
 using Lux, Downloads, JLD2
 
+@static if !Sys.iswindows()
+    using Reactant
+end
+
 function normalize_imagenet(data)
     cmean = reshape(Float32[0.485, 0.456, 0.406], (1, 1, 3, 1))
     cstd = reshape(Float32[0.229, 0.224, 0.225], (1, 1, 3, 1))
@@ -23,7 +27,12 @@ function imagenet_acctest(model, ps, st, dev; size=224)
     TEST_X = size == 224 ? MONARCH_224 :
              (size == 256 ? MONARCH_256 : error("size must be 224 or 256"))
     x = TEST_X |> dev
-    ypred = first(model(x, ps, st)) |> collect |> vec
+
+    if dev isa MLDataDevices.ReactantDevice
+        model = Reactant.compile(model, (x, ps, st))
+    end
+
+    ypred = first(model(x, ps, st)) |> cpu_device() |> collect |> vec
     top5 = TEST_LBLS[sortperm(ypred; rev=true)]
     return "monarch" in top5
 end
@@ -48,6 +57,23 @@ end
             end
 
             GC.gc(true)
+
+            if test_reactant(mode)
+                set_reactant_backend!(mode)
+                rdev = reactant_device()
+
+                ps_ra = rdev(ps)
+                st_ra = rdev(st)
+                img_ra = rdev(img)
+
+                model_compiled = Reactant.compile(model, (img_ra, ps_ra, st_ra))
+                @test first(model_compiled(
+                    img_ra, ps_ra, st_ra))≈Array(first(model(img, ps, st))) atol=1e-3 rtol=1e-3
+
+                if pretrained
+                    @test imagenet_acctest(model, ps_ra, st_ra, rdev)
+                end
+            end
         end
     end
 end
@@ -63,6 +89,19 @@ end
         @test size(first(model(img, ps, st))) == (1000, 2)
 
         GC.gc(true)
+
+        if test_reactant(mode)
+            set_reactant_backend!(mode)
+            rdev = reactant_device()
+
+            ps_ra = rdev(ps)
+            st_ra = rdev(st)
+            img_ra = rdev(img)
+
+            model_compiled = Reactant.compile(model, (img_ra, ps_ra, st_ra))
+            @test first(model_compiled(
+                img_ra, ps_ra, st_ra))≈Array(first(model(img, ps, st))) atol=1e-3 rtol=1e-3
+        end
     end
 end
@@ -77,6 +116,19 @@ end
         @test size(first(model(img, ps, st))) == (1000, 2)
 
         GC.gc(true)
+
+        if test_reactant(mode)
+            set_reactant_backend!(mode)
+            rdev = reactant_device()
+
+            ps_ra = rdev(ps)
+            st_ra = rdev(st)
+            img_ra = rdev(img)
+
+            model_compiled = Reactant.compile(model, (img_ra, ps_ra, st_ra))
+            @test first(model_compiled(
+                img_ra, ps_ra, st_ra))≈Array(first(model(img, ps, st))) atol=1e-3 rtol=1e-3
+        end
     end
 end
@@ -91,6 +143,19 @@ end
         @test size(first(model(img, ps, st))) == (1000, 2)
 
        GC.gc(true)
+
+        if test_reactant(mode)
+            set_reactant_backend!(mode)
+            rdev = reactant_device()
+
+            ps_ra = rdev(ps)
+            st_ra = rdev(st)
+            img_ra = rdev(img)
+
+            model_compiled = Reactant.compile(model, (img_ra, ps_ra, st_ra))
+            @test first(model_compiled(
+                img_ra, ps_ra, st_ra))≈Array(first(model(img, ps, st))) atol=1e-3 rtol=1e-3
+        end
     end
 end
@@ -110,6 +175,23 @@ end
             end
 
             GC.gc(true)
+
+            if test_reactant(mode)
+                set_reactant_backend!(mode)
+                rdev = reactant_device()
+
+                ps_ra = rdev(ps)
+                st_ra = rdev(st)
+                img_ra = rdev(img)
+
+                model_compiled = Reactant.compile(model, (img_ra, ps_ra, st_ra))
+                @test first(model_compiled(
+                    img_ra, ps_ra, st_ra))≈Array(first(model(img, ps, st))) atol=1e-3 rtol=1e-3
+
+                if pretrained
+                    @test imagenet_acctest(model, ps_ra, st_ra, rdev)
+                end
+            end
         end
     end
 end
@@ -134,6 +216,23 @@ end
             end
 
             GC.gc(true)
+
+            if test_reactant(mode)
+                set_reactant_backend!(mode)
+                rdev = reactant_device()
+
+                ps_ra = rdev(ps)
+                st_ra = rdev(st)
+                img_ra = rdev(img)
+
+                model_compiled = Reactant.compile(model, (img_ra, ps_ra, st_ra))
+                @test first(model_compiled(
+                    img_ra, ps_ra, st_ra))≈Array(first(model(img, ps, st))) atol=1e-3 rtol=1e-3
+
+                if pretrained
+                    @test imagenet_acctest(model, ps_ra, st_ra, rdev)
+                end
+            end
         end
     end
 end
@@ -157,6 +256,23 @@ end
             end
 
            GC.gc(true)
+
+            if test_reactant(mode)
+                set_reactant_backend!(mode)
+                rdev = reactant_device()
+
+                ps_ra = rdev(ps)
+                st_ra = rdev(st)
+                img_ra = rdev(img)
+
+                model_compiled = Reactant.compile(model, (img_ra, ps_ra, st_ra))
+                @test first(model_compiled(
+                    img_ra, ps_ra, st_ra))≈Array(first(model(img, ps, st))) atol=1e-3 rtol=1e-3
+
+                if pretrained
+                    @test imagenet_acctest(model, ps_ra, st_ra, rdev)
+                end
+            end
         end
     end
 end
@@ -177,6 +293,23 @@ end
             end
 
             GC.gc(true)
+
+            if test_reactant(mode)
+                set_reactant_backend!(mode)
+                rdev = reactant_device()
+
+                ps_ra = rdev(ps)
+                st_ra = rdev(st)
+                img_ra = rdev(img)
+
+                model_compiled = Reactant.compile(model, (img_ra, ps_ra, st_ra))
+                @test first(model_compiled(
+                    img_ra, ps_ra, st_ra))≈Array(first(model(img, ps, st))) atol=1e-3 rtol=1e-3
+
+                if pretrained
+                    @test imagenet_acctest(model, ps_ra, st_ra, rdev)
+                end
+            end
         end
     end
 end
@@ -197,6 +330,23 @@ end
             end
 
             GC.gc(true)
+
+            if test_reactant(mode)
+                set_reactant_backend!(mode)
+                rdev = reactant_device()
+
+                ps_ra = rdev(ps)
+                st_ra = rdev(st)
+                img_ra = rdev(img)
+
+                model_compiled = Reactant.compile(model, (img_ra, ps_ra, st_ra))
+                @test first(model_compiled(
+                    img_ra, ps_ra, st_ra))≈Array(first(model(img, ps, st))) atol=1e-3 rtol=1e-3
+
+                if pretrained
+                    @test imagenet_acctest(model, ps_ra, st_ra, rdev)
+                end
+            end
         end
     end
 end
@@ -221,5 +371,18 @@
         @test size(first(model(img, ps, st))) == (1000, 2)
 
         GC.gc(true)
+
+        if test_reactant(mode)
+            set_reactant_backend!(mode)
+            rdev = reactant_device()
+
+            ps_ra = rdev(ps)
+            st_ra = rdev(st)
+            img_ra = rdev(img)
+
+            model_compiled = Reactant.compile(model, (img_ra, ps_ra, st_ra))
+            @test first(model_compiled(
+                img_ra, ps_ra, st_ra))≈Array(first(model(img, ps, st))) atol=1e-3 rtol=1e-3
+        end
     end
 end
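
For reference, the pattern these tests repeat is: move the parameters, states, and inputs to the Reactant device, compile the model once with Reactant.compile, then compare the compiled output against the eager result. Below is a minimal sketch of that flow outside the test harness; the toy Chain of Dense layers and all variable names are illustrative assumptions, not code from this repository.

using Lux, Reactant, Random

Reactant.set_default_backend("cpu")   # mirrors set_reactant_backend!("cpu") in the test setup

model = Chain(Dense(2 => 4, tanh), Dense(4 => 2))
ps, st = Lux.setup(Random.default_rng(), model)
st = Lux.testmode(st)                 # inference mode, as the tests use

x = randn(Float32, 2, 8)

rdev = reactant_device()              # device from MLDataDevices, re-exported by Lux
ps_ra, st_ra, x_ra = rdev(ps), rdev(st), rdev(x)

# Compile once for these concrete argument types, then call like the eager model.
model_compiled = Reactant.compile(model, (x_ra, ps_ra, st_ra))
y_compiled = first(model_compiled(x_ra, ps_ra, st_ra))

# The tests assert that compiled and eager outputs agree to a loose tolerance.
y_eager = first(model(x, ps, st))
@assert isapprox(Array(y_compiled), y_eager; atol=1.0f-3, rtol=1.0f-3)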