Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

chore: update to newer versions #161

Merged
merged 4 commits into from
Sep 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 1 addition & 3 deletions .buildkite/pipeline.yml
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ steps:
queue: "juliagpu"
cuda: "*"
env:
GROUP: "CUDA"
BACKEND_GROUP: "CUDA"
if: build.message !~ /\[skip tests\]/
timeout_in_minutes: 240
matrix:
Expand Down Expand Up @@ -54,7 +54,5 @@ steps:
timeout_in_minutes: 240

env:
RETESTITEMS_NWORKERS: 4
RETESTITEMS_NWORKER_THREADS: 2
SECRET_CODECOV_TOKEN: "fbSN+ZbScLIWr1FOpAu1Z8PYWFobqbLGFayOgZE1ebhE8LIH/PILGXUMcdm9gkXVSwgdETDD0s33k14lBkJ90O4dV9w6k79F/pEgzVHV8baMoXZG03BPMxztlcoRXrKtRtAp+MwoATc3Ldb9H5vqgAnVNn5rhn4Rp0Z6LOVRC43hbhKBBKYh/N4gqpIQlcW4dBXmELhlnMFnUILjwGRVgEt/zh8H+vmf0qiIulNIQ/rfGISROHqFML0QDL4icloiqX08J76ZP/gZCeg6rJ0gl3ok3IspNPz51rlbvijqsPNyIHWi29OrAtWX3qKHfrAOoGIrE1d5Oy4wx4XaN/YBhg==;U2FsdGVkX188gcRjkUNMEC2Z5fEFfhsYY4WJbhhINOuCUgqq9XNHVDbJhzFUFVQ+UiuPHFg7CW/gn+3IkSVyOA=="
SECRET_DOCUMENTER_KEY: "jzyAET5IdazYwPAEZAmYmnBALb2dC1GPizCDCdt8xpjIi4ce6QbGGJMKo00ZNzJ/A7ii4bhqysVPXniifFwIGl7x+GSCeavwcSr15pfxJSqPuQYLKxESzIo+SM+l2uJWUz8KYMJ1tSt/Z3Up3qQfLeQFtR+f43b9QrLfhgZGAAdxpwu5VHdI3Xm/gZo5d8xEJ1xs4gqVP0e2A5EFr/j/exaWJL9+AvgO+Gko8NaJGG5B89zP1W2NBlpjttbwzj2naBhDx8A43Qe4eXm+BZd9CIZImiEJnnqoGxLkAyLDksbA68getUHW5z3nGyhWTrg5yfRqq0uyZZGTIOFz6dJrRg==;U2FsdGVkX19QOxLLkdNoQf7Rid3mcSR/renIHQ+/X3o0WxTmU8KDDxzfKuWPeK1fxMon8y45HCJv3HlMuzyfvPWrOmUXccfHK272D8vHu1kk/qZZw8nPd7iYBU9+VAIxwfmI3Av2gC+8tUlOcuUTEVMtMbi/MiLHp+phLYcELKzzrxL8VdrLzna81M+8xVLu7zzNuyK0cUPWLxRHcZc/fewK5Nh7EQ2x8u1b6e5zR0/AcqjCzMayD1RiE7QhRVGdF5GJYnAxc1eoyCwIjXTRfFo0a0Q2h6DEz9FEat/ZCekIuWyVrUkGbpsRqXUTrSH0An7FRRqRlZ9lStRaQY4Z3XBkoIh94vQlXwwLUH20jC7yRTV73CeYmhfigQckHL0JsjjIENz04Ac346fCV6WNQtEak0m3pN/BucoiwRA8l+WU4AK1r84cwGSphKk4SnWRAqeZVuFHck7NkcmHDEkO4C7WTP400oui/5NDMtVZbtnZfLxVzQqijxXj7IflWqF1vKqGmW5aPFMVNeAqwNGu3xM4oIIeHRu0u+k2S5dp1wqRVlMxYXdPtcoFzE0CNsMQdWgsvPd2eet38YRc8ftXNjKzoUSRRCbjGbVr0iJXeNmPg3jfZoVdILHjCN/hcz4nY+61P11OlJAdfE/6HzEr4VoOS4CN+s/brjWycmAKZo2+1e4fSV1xBH7t1spOlESLvsBhZNtj9/zUKgWgMct5hnF4anQcPAeRpz/MBrkwX1gW3WOvCxaqVlRfgGSy6boPgRd3p/ZXN4Xnfeg9RFqKZn21d2gcrc3/1+PTUEkOIv+C9BGszo9IaUziW/Tz2mVP386kX86SF4fF4y3PofcUT2FLTm8Q9ZJBnslOsRP8bq3rIjDiQR3Iz3uGctkGZPs+GOtCR5OrhnnS6BXxkGwt/n9PJsnbXt0Z4tuXihC1B8KfP7mzDvZr3q9X/DGKyZ+oMHdDI+f2+lRwx42nJnsu+nZW9lyhdIwWla9F1rIoVz59HbUrmUhsVmFQYfjy7Nl18g8Wh5r9CkFL/vr6Zpy5lj1J/vhe1501X2FIkKOnLAM73GwtAa4GkbHyu5rNcij6YoozPrJWT4KRNFWGVAqNZ1atG8WwmziwIl2KfBn8jiuP/8o6rXQkmrAzBr6jVnto5FTWnIexEmnbELs20XDck8pO5WQxU1IR9YhKMbrDGbn0jWzVoRmCWpaJgV1AkWu09a++DxIec4+Zt+3SZLj/H57XsBchWHmkFz4NVTBeSans26VmdDd3LxprT8qeH6cioceakmu6yegsKQnJGLmSNyUkHqBqmsCcvyTUyaQUBTFkjLmDeZB3Ifu2kD7AFdx5n58wdJTMZxYviybOCgCV4qe95v5XfIqthp5mF/0F1Wt9ZcEreFSM2Paj5GrQ+M25cZ+kqOSlMet51Q+QBCfQyDF8jdu3j1hVniwpgMI1gqyb2alRfyNx52elTqRn9hPqpFptGH4uJXi8H72YPe4fYkFS7wwELeRIv+nKkNYNLPQAyQFvZ/qB/PRI1YoFBbpi0Vi6iE9xLRq7QVvhJde2EgNbvQk8uakwV630Tht2OuwVdJu/PIbXsQ5i+EuknIlPRdQdhbEIkpuBHFAzxBqA2K92gJ4bbcOjGtDHc0pt0RtvIVoyyJMkYVzr1yBeMWEmsL8qYJ5yzuAFGqpCTmJzXE0ETZLDDJtxwSKj5M2vG59wPNfo9DF+LgJLF+94VydYGNOHy9KuY2Oo3ejV7iFXUtsEV3Id9EkNGq8+t5KIAGk6lnDcM1TTOAc5W8fGGNhYzlqgWK1n3nwPJLykqY7VFHPZjF8Il/8E1IubnPCIyOCTJwKqQlBB5td/bt7YIDEFmkpl7OvUwyc2uYkFmrxGv81OtopsYZOJ+WnwSkqqZ3p2MyqNj3xp92p8itz5tM3tzjrkdfPXsx1QJGY+rkZhCsSf6DSG18AFqI4+Q8uWUwqO5/TJb2z/F2LT88+wJfGPtwGeR+98XgvwjsMWIA/TZfwTrTQsZX2YOIf0bg3yjlEbFM16xAFAA2oItBuvbC6d6NIit4Dukn2WamnOceoTyO6mdHYRh5SBOryr3AWnBJZsUPL3HsC+Xiibgixuwjjalj+HOrAzDlQc8L0Z77dZJhpST0x/gwCleSA3lOKs7MA8ASolCaPVL2pPJXkb97mBxZx8k1n6abhK1w3QVJuYvp7CyGhavsYEqcR+vYx/T0tN4MVOjfRhimqhNihz0VDfY97YS5XavZV07jycqoAlufmH5VSwNbiy8/NY6Q2djc46ISbqvKr6Pf0TZBuJti2gIpm02Btm4rMCawVPpEPieU3GI77nhQp6orq0Zjl5f4XfOKnfcxiqdgip4SVFTglHBTWTiRFnGTF0I3VX4V+RRmqJwwKPN8cxDsNd6wSpylhDUAMfxEvvb+0vAt1yGNUC52OB4bSOXOyZIAU8+08xl7mYGIVUnoWHaR6Y0aHdnywJUuzQ2q3dotfnI1j72MzlHsTK6Lro3YiolDNJpTqLtxmSzkWctw/PfijnoEXtmDnZKptZ7t0v7oTAkdE3kk0RrnFTnMAkyCOREcFcyxglROCoDHsZx3Q+MkWLG/tPMVpuMRhy9gJ1WZTpeExNgs5KgwtrS1HJg7KunWXguFH/zDgODTdKclgfvsVe/SCtlpbO6z5fZji0j1y9LRBVLyTN/LzeR40OBX3r0abk4SGyslAdZMgg2WJdSLVAJ+MtxxbnlKDXDNmu5YehpWdTvm/wIYwTKw+1A48plKburw8fBEofVy9Ubmc8E4z6hQRX2cwcNN9N/60aCwlpM7wVbYfBo4Hw9H/6EawjbRRN9UwmgsfmYUuqCTSi8fNi2dR36bqaoHHURgyqW7DiR7BYgVnOZ+B/2GM8uO2rYgSOhVJf+OK+2HsNly0MW5v3/ft6W7PEsab8IweYWmPLVvJNfHW5CDP6KotdDgm/DcD5owgQ12D95BGWawR5gQxpyjX9uIlxORPq6h0Z79j8gFFsYIfddIdxsJZS9r59FtZe2JL7nK3Dum07tXDGlBCUD4mwv+LNxOJLa8DM6YoEd4Nh8qosfQNJu505Vh/r+PgegnFvG9LRkwQnk8fgPTNKThB067s82YuVg0mv4O9q4Hlm13wTWuvlMr4k1ShBrNyy08YaFCu2hmZm7RizU1rU5MpieiwpQ6cGx+sHBTszB+c89045n2TM4VUedi+vjEq2KuhmIl9ID0EHvWDy5iwOQV5nJ7Rk/Alky2GRZ1CpnJPN29q5lrs6fhvfPquolJTyBTNgVjQ7f0z1zuUQFdhWQX5BFyq/bT46qc+X6dSOvlFenioPDe5MYpA3SZCi2lmVQqHnTOcRZp2HtYpoRUzNB6cT7P1wkRTvAR5PQyuAknkKB+T6HvRb2H8EBLfk+imlyu7mb2iAJNORpZ0Rz+no/5A3wd6qHfTq27h/CDZ91YqGamylrLBdFqyefFYfSbFa1BKikiavpZnYh19hHNl9v0Q4Tkb7ogy7biw8icOvMPWCrxd50zoY1EUe2maNKtkyytJeEOV+Yj6VaUa88M+7WSKaK8QbEB+fBKmwvGkknRKs8lYRoABChwqDZ7M+98pL057QbquxseKX6alzV+IOHlO7I8csRHnF6OpVQG6wdzJZhEwg/0n1K2qTExF9Z3WzUoiQ+NVC3gRQ9Au+x3fpkuLu85lOVjelX3JtdVt1T3623sqxmcEr/TCZT/+X1QyflAkJyw1EMq4sat7wjYK3ugyPPPjo/v2h4TuaoWc0X/+qJPV/o2Vu489loIQ+N59ABZOLldpbkPM7VJIOnnfg+/GMvaEor2YCYElDGXx4BdRmSfOyzFF2Wqz5iTxMbdCo3iZbPQqbFTdMOX7Hy3nT8vUOhCLo+Dkgb7B01nPnm7crmC+TOgi4iDLp4nCqx5OSiG5gd/m54gZHe6Cymwj+DbW303KcvpGBrk0xr2sGUkQiu4vxNz+uW09EyMNCp5cg2AUWG4w6ykTHXUeDDQG232C5K7/tTt8Z09Kp9v71PkwH7hmZUrjAERGvF29zucdTVRmzr++JMH82Sk0chAi5UFs/lbVcN/birI7OVl6okyO3+bKWMCuhje1huOgeZzSk5xKFrgJ1v69TyD1mOa5wYx6IskbWSrFW/sqrhURqpSlfdWVCZiaOHLb/UIgQ0s1xlHyZ0/YOBQFz1VCgKH217ALijV3FOr+q00761SKNFc/IhZLNtVhHhE7lutAjVqyme7RHKd4fjFOD6oREyDYXHULmDGPRTmlFSxwE4+n3N9AInajQLH82CGWO1nV3u7qSY5vSbuzQIxCr8OKQfW8AzTdNjUoEtU+ojprLZ4V4r1dr01eLqXSVJ12Yq9Sm/Ivu1SZkHJl9oIxNjWSbRRMoYIVH3yVv1HyGGajcmKuzIfORuBZm"
11 changes: 4 additions & 7 deletions .github/workflows/CI.yml
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,12 @@ on:
branches:
- main
paths-ignore:
- 'docs/**'
- "docs/**"
push:
branches:
- main
paths-ignore:
- 'docs/**'
- "docs/**"
concurrency:
# Skip intermediate builds: always.
# Cancel intermediate builds: only if it is a pull request build.
Expand All @@ -22,7 +22,7 @@ jobs:
fail-fast: false
matrix:
version:
- '1'
- "1"
steps:
- uses: actions/checkout@v4
- uses: julia-actions/setup-julia@v2
Expand All @@ -41,10 +41,7 @@ jobs:
- uses: julia-actions/julia-buildpkg@v1
- uses: julia-actions/julia-runtest@v1
env:
GROUP: "CPU"
JULIA_NUM_THREADS: 12
RETESTITEMS_NWORKERS: 4
RETESTITEMS_NWORKER_THREADS: 2
BACKEND_GROUP: "CPU"
- uses: julia-actions/julia-processcoverage@v1
with:
directories: src,ext
Expand Down
2 changes: 0 additions & 2 deletions LocalPreferences.toml

This file was deleted.

20 changes: 13 additions & 7 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@ FastClosures = "9aa1b823-49e4-5ca5-8b0f-3971ec8bab6a"
Lux = "b2108857-7c20-44ae-9111-449ecde12c47"
LuxCore = "bb33d45b-7691-41d6-9220-0943567d0623"
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
Expand All @@ -40,23 +39,26 @@ ExplicitImports = "1.6.0"
FastClosures = "0.3"
ForwardDiff = "0.10.36"
Functors = "0.4.10"
GPUArraysCore = "0.1.6"
Hwloc = "3"
InteractiveUtils = "<0.0.1, 1"
LinearSolve = "2.21.2"
Lux = "0.5.56"
LuxCUDA = "0.3.2"
LuxCore = "0.1.14"
LuxTestUtils = "0.1.15"
LuxTestUtils = "1"
MLDataDevices = "1"
NLsolve = "4.5.1"
NNlib = "0.9.17"
NonlinearSolve = "3.10.0"
OrdinaryDiffEq = "6.74.1"
PrecompileTools = "1"
Pkg = "1.10"
Random = "1.10"
ReTestItems = "1.23.1"
SciMLBase = "2"
SciMLSensitivity = "7.43"
StableRNGs = "1.0.2"
Statistics = "1.10"
SteadyStateDiffEq = "2"
SteadyStateDiffEq = "2.3.2"
Test = "1.10"
Zygote = "0.6.69"
julia = "1.10"
Expand All @@ -67,16 +69,20 @@ Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
ExplicitImports = "7d51a73a-1435-4ff3-83d9-f097790105c7"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
LuxCUDA = "d0bbae9a-e099-4d5b-a835-1c6931763bda"
GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527"
Hwloc = "0e44f5e4-bd66-52a0-8798-143a42290a1d"
InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
LuxTestUtils = "ac9de150-d08f-4546-94fb-7472b5760531"
MLDataDevices = "7e8f7934-dd98-4c1a-8fe8-92b47a384d40"
NLsolve = "2774e3e8-f4cf-5e23-947b-6d7e65073b56"
NonlinearSolve = "8913a72c-1f9b-4ce2-8d82-65094dcecaec"
OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed"
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
ReTestItems = "817f1d60-ba6b-4fd5-9520-3cf149f6a823"
SciMLSensitivity = "1ed8b502-d754-442c-8d5d-10ac956f44a1"
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[targets]
test = ["Aqua", "Documenter", "ExplicitImports", "ForwardDiff", "Functors", "LuxCUDA", "LuxTestUtils", "NLsolve", "NonlinearSolve", "OrdinaryDiffEq", "ReTestItems", "SciMLSensitivity", "StableRNGs", "Test", "Zygote"]
test = ["Aqua", "Documenter", "ExplicitImports", "ForwardDiff", "Functors", "GPUArraysCore", "Hwloc", "InteractiveUtils", "LuxTestUtils", "MLDataDevices", "NLsolve", "NonlinearSolve", "OrdinaryDiffEq", "Pkg", "ReTestItems", "SciMLSensitivity", "StableRNGs", "Test", "Zygote"]
2 changes: 1 addition & 1 deletion ext/DeepEquilibriumNetworksSciMLSensitivityExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,6 @@ using DeepEquilibriumNetworks: DEQs
linsolve_kwargs = (; maxiters=10, abstol=1e-3, reltol=1e-3)
return SteadyStateAdjoint(; linsolve, linsolve_kwargs, autojacvec=ZygoteVJP())
end
@inline DEQs.__default_sensealg(prob::ODEProblem) = GaussAdjoint(; autojacvec=ZygoteVJP())
@inline DEQs.__default_sensealg(::ODEProblem) = GaussAdjoint(; autojacvec=ZygoteVJP())

end
3 changes: 1 addition & 2 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -98,8 +98,7 @@ CRC.@non_differentiable __gaussian_like(::Any...)
@inline __tupleify(x) = @closure(u->(u, x))

# Jacobian Stabilization
## Don't remove `ad`. See https://github.com/ericphanson/ExplicitImports.jl/issues/33
function __estimate_jacobian_trace(ad::AutoFiniteDiff, model::StatefulLuxLayer, z, x, rng)
function __estimate_jacobian_trace(::AutoFiniteDiff, model::StatefulLuxLayer, z, x, rng)
__f = @closure u -> model((u, x))
res = zero(eltype(x))
ϵ = cbrt(eps(typeof(res)))
Expand Down
57 changes: 28 additions & 29 deletions test/layers_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,14 @@ export loss_function, SOLVERS

end

@testitem "DEQ" setup=[SharedTestSetup, LayersTestSetup] timeout=10000 begin
@testitem "DEQ" setup=[SharedTestSetup, LayersTestSetup] begin
using ADTypes, Lux, NonlinearSolve, OrdinaryDiffEq, SciMLSensitivity, Zygote

rng = __get_prng(0)
rng = StableRNG(0)

base_models = [Parallel(+, __get_dense_layer(2 => 2), __get_dense_layer(2 => 2)),
Parallel(+, __get_conv_layer((1, 1), 1 => 1), __get_conv_layer((1, 1), 1 => 1))]
init_models = [__get_dense_layer(2 => 2), __get_conv_layer((1, 1), 1 => 1)]
base_models = [Parallel(+, dense_layer(2 => 2), dense_layer(2 => 2)),
Parallel(+, conv_layer((1, 1), 1 => 1), conv_layer((1, 1), 1 => 1))]
init_models = [dense_layer(2 => 2), conv_layer((1, 1), 1 => 1)]
x_sizes = [(2, 14), (3, 3, 1, 3)]

model_type = (:deq, :skipdeq, :skipregdeq)
Expand All @@ -34,7 +34,7 @@ end
jacobian_regularizations = ongpu ? _jacobian_regularizations[1:(end - 1)] :
_jacobian_regularizations

@testset "Solver: $(__nameof(solver)) | Model Type: $(mtype) | Jac. Reg: $(jacobian_regularization)" for solver in SOLVERS,
@testset "Solver: $(nameof(typeof(solver))) | Model Type: $(mtype) | Jac. Reg: $(jacobian_regularization)" for solver in SOLVERS,
mtype in model_type,
jacobian_regularization in jacobian_regularizations

Expand All @@ -55,8 +55,7 @@ end
x = randn(rng, Float32, x_size...) |> dev
z, st = model(x, ps, st)

opt_broken = solver isa SimpleLimitedMemoryBroyden
@jet model(x, ps, st) opt_broken=opt_broken
@jet model(x, ps, st) opt_broken=true

@test all(isfinite, z)
@test size(z) == size(x)
Expand All @@ -65,8 +64,8 @@ end

_, gs_x, gs_ps, _ = Zygote.gradient(loss_function, model, x, ps, st)

@test __is_finite_gradient(gs_x)
@test __is_finite_gradient(gs_ps)
@test is_finite_gradient(gs_x)
@test is_finite_gradient(gs_ps)

ps, st = Lux.setup(rng, model) |> dev
st = Lux.update_state(st, :fixed_depth, Val(10))
Expand All @@ -82,28 +81,28 @@ end

_, gs_x, gs_ps, _ = Zygote.gradient(loss_function, model, x, ps, st)

@test __is_finite_gradient(gs_x)
@test __is_finite_gradient(gs_ps)
@test is_finite_gradient(gs_x)
@test is_finite_gradient(gs_ps)
end
end
end
end

@testitem "Multiscale DEQ" setup=[SharedTestSetup, LayersTestSetup] timeout=10000 begin
@testitem "Multiscale DEQ" setup=[SharedTestSetup, LayersTestSetup] begin
using ADTypes, Lux, NonlinearSolve, OrdinaryDiffEq, SciMLSensitivity, Zygote

rng = __get_prng(0)
rng = StableRNG(0)

main_layers = [(Parallel(+, __get_dense_layer(4 => 4), __get_dense_layer(4 => 4)),
__get_dense_layer(3 => 3), __get_dense_layer(2 => 2), __get_dense_layer(1 => 1))]
main_layers = [(Parallel(+, dense_layer(4 => 4), dense_layer(4 => 4)),
dense_layer(3 => 3), dense_layer(2 => 2), dense_layer(1 => 1))]

mapping_layers = [[NoOpLayer() __get_dense_layer(4 => 3) __get_dense_layer(4 => 2) __get_dense_layer(4 => 1);
__get_dense_layer(3 => 4) NoOpLayer() __get_dense_layer(3 => 2) __get_dense_layer(3 => 1);
__get_dense_layer(2 => 4) __get_dense_layer(2 => 3) NoOpLayer() __get_dense_layer(2 => 1);
__get_dense_layer(1 => 4) __get_dense_layer(1 => 3) __get_dense_layer(1 => 2) NoOpLayer()]]
mapping_layers = [[NoOpLayer() dense_layer(4 => 3) dense_layer(4 => 2) dense_layer(4 => 1);
dense_layer(3 => 4) NoOpLayer() dense_layer(3 => 2) dense_layer(3 => 1);
dense_layer(2 => 4) dense_layer(2 => 3) NoOpLayer() dense_layer(2 => 1);
dense_layer(1 => 4) dense_layer(1 => 3) dense_layer(1 => 2) NoOpLayer()]]

init_layers = [(__get_dense_layer(4 => 4), __get_dense_layer(4 => 3),
__get_dense_layer(4 => 2), __get_dense_layer(4 => 1))]
init_layers = [(
dense_layer(4 => 4), dense_layer(4 => 3), dense_layer(4 => 2), dense_layer(4 => 1))]

x_sizes = [(4, 3)]
scales = [((4,), (3,), (2,), (1,))]
Expand All @@ -112,7 +111,7 @@ end
jacobian_regularizations = (nothing,)

@testset "$mode" for (mode, aType, dev, ongpu) in MODES
@testset "Solver: $(__nameof(solver))" for solver in SOLVERS,
@testset "Solver: $(nameof(typeof(solver)))" for solver in SOLVERS,
mtype in model_type,
jacobian_regularization in jacobian_regularizations

Expand Down Expand Up @@ -141,8 +140,8 @@ end
z, st = model(x, ps, st)
z_ = DEQs.__flatten_vcat(z)

opt_broken = solver isa SimpleLimitedMemoryBroyden
@jet model(x, ps, st) opt_broken=opt_broken # Broken due to nfe dynamic dispatch
opt_broken = mtype !== :node
@jet model(x, ps, st) opt_broken=opt_broken

@test all(isfinite, z_)
@test size(z_) == (sum(prod, scale), size(x, ndims(x)))
Expand All @@ -153,8 +152,8 @@ end

_, gs_x, gs_ps, _ = Zygote.gradient(loss_function, model, x, ps, st)

@test __is_finite_gradient(gs_x)
@test __is_finite_gradient(gs_ps)
@test is_finite_gradient(gs_x)
@test is_finite_gradient(gs_ps)

ps, st = Lux.setup(rng, model) |> dev
st = Lux.update_state(st, :fixed_depth, Val(10))
Expand All @@ -172,8 +171,8 @@ end

_, gs_x, gs_ps, _ = Zygote.gradient(loss_function, model, x, ps, st)

@test __is_finite_gradient(gs_x)
@test __is_finite_gradient(gs_ps)
@test is_finite_gradient(gs_x)
@test is_finite_gradient(gs_ps)
end
end
end
Expand Down
28 changes: 26 additions & 2 deletions test/runtests.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,27 @@
using ReTestItems
using ReTestItems, Pkg, InteractiveUtils, Hwloc

ReTestItems.runtests(@__DIR__)
@info sprint(versioninfo)

const BACKEND_GROUP = lowercase(get(ENV, "BACKEND_GROUP", "all"))
const EXTRA_PKGS = String[]

(BACKEND_GROUP == "all" || BACKEND_GROUP == "cuda") && push!(EXTRA_PKGS, "LuxCUDA")

if !isempty(EXTRA_PKGS)
@info "Installing Extra Packages for testing" EXTRA_PKGS=EXTRA_PKGS
Pkg.add(EXTRA_PKGS)
Pkg.update()
Base.retry_load_extensions()
Pkg.instantiate()
end

using DeepEquilibriumNetworks

const RETESTITEMS_NWORKERS = parse(
Int, get(ENV, "RETESTITEMS_NWORKERS", string(min(Hwloc.num_physical_cores(), 16))))
const RETESTITEMS_NWORKER_THREADS = parse(Int,
get(ENV, "RETESTITEMS_NWORKER_THREADS",
string(max(Hwloc.num_virtual_cores() ÷ RETESTITEMS_NWORKERS, 1))))

ReTestItems.runtests(DeepEquilibriumNetworks; nworkers=RETESTITEMS_NWORKERS,
nworker_threads=RETESTITEMS_NWORKER_THREADS, testitem_timeout=12000)
60 changes: 26 additions & 34 deletions test/shared_testsetup.jl
Original file line number Diff line number Diff line change
@@ -1,55 +1,47 @@
@testsetup module SharedTestSetup

using DeepEquilibriumNetworks, Functors, Lux, Random, StableRNGs, Zygote, ForwardDiff
import LuxTestUtils: @jet
using LuxCUDA
using LuxTestUtils
using MLDataDevices, GPUArraysCore

CUDA.allowscalar(false)
LuxTestUtils.jet_target_modules!(["DeepEquilibriumNetworks", "Lux", "LuxLib"])

__nameof(::X) where {X} = nameof(X)
const BACKEND_GROUP = lowercase(get(ENV, "BACKEND_GROUP", "all"))

__get_prng(seed::Int) = StableRNG(seed)
if BACKEND_GROUP == "all" || BACKEND_GROUP == "cuda"
using LuxCUDA
end

GPUArraysCore.allowscalar(false)

__is_finite_gradient(x::AbstractArray) = all(isfinite, x)
cpu_testing() = BACKEND_GROUP == "all" || BACKEND_GROUP == "cpu"
function cuda_testing()
return (BACKEND_GROUP == "all" || BACKEND_GROUP == "cuda") &&
MLDataDevices.functional(CUDADevice)
end

function __is_finite_gradient(gs::NamedTuple)
gradient_is_finite = Ref(true)
function __is_gradient_finite(x)
!isnothing(x) && !all(isfinite, x) && (gradient_is_finite[] = false)
return x
end
fmap(__is_gradient_finite, gs)
return gradient_is_finite[]
const MODES = begin
modes = []
cpu_testing() && push!(modes, ("cpu", Array, CPUDevice(), false))
cuda_testing() && push!(modes, ("cuda", CuArray, CUDADevice(), true))
modes
end

function __get_dense_layer(args...; kwargs...)
is_finite_gradient(x::AbstractArray) = all(isfinite, x)
is_finite_gradient(::Nothing) = true
is_finite_gradient(gs) = all(is_finite_gradient, fleaves(gs))

function dense_layer(args...; kwargs...)
init_weight(rng::AbstractRNG, dims...) = randn(rng, Float32, dims) .* 0.001f0
return Dense(args...; init_weight, use_bias=false, kwargs...)
end

function __get_conv_layer(args...; kwargs...)
function conv_layer(args...; kwargs...)
init_weight(rng::AbstractRNG, dims...) = randn(rng, Float32, dims) .* 0.001f0
return Conv(args...; init_weight, use_bias=false, kwargs...)
end

const GROUP = get(ENV, "GROUP", "All")

cpu_testing() = GROUP == "All" || GROUP == "CPU"
cuda_testing() = LuxCUDA.functional() && (GROUP == "All" || GROUP == "CUDA")

const MODES = begin
cpu_mode = ("CPU", Array, LuxCPUDevice(), false)
cuda_mode = ("CUDA", CuArray, LuxCUDADevice(), true)

modes = []
cpu_testing() && push!(modes, cpu_mode)
cuda_testing() && push!(modes, cuda_mode)

modes
end

export Lux, LuxCore, LuxLib
export MODES, __get_dense_layer, __get_conv_layer, __is_finite_gradient, __get_prng,
__nameof, @jet
export MODES, dense_layer, conv_layer, is_finite_gradient, StableRNG, @jet, test_gradients

end
Loading