From ce978119aeb043c45de30b4c53f768caad35da7c Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Tue, 18 Jun 2024 19:51:44 -0700 Subject: [PATCH] Update the code to the latest versions of Lux --- Project.toml | 16 ++-- docs/Project.toml | 2 +- docs/make.jl | 2 +- docs/src/tutorials/basic_mnist_deq.md | 13 +-- ext/DeepEquilibriumNetworksZygoteExt.jl | 57 ------------- src/DeepEquilibriumNetworks.jl | 5 +- src/layers.jl | 107 ++++++------------------ src/utils.jl | 20 ++++- test/layers_tests.jl | 6 +- test/qa_tests.jl | 16 +++- test/shared_testsetup.jl | 2 +- 11 files changed, 82 insertions(+), 164 deletions(-) delete mode 100644 ext/DeepEquilibriumNetworksZygoteExt.jl diff --git a/Project.toml b/Project.toml index 0f57e746..450b5c66 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "DeepEquilibriumNetworks" uuid = "6748aba7-0e9b-415e-a410-ae3cc0ecb334" authors = ["Avik Pal "] -version = "2.1.0" +version = "2.1.1" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" @@ -13,6 +13,7 @@ DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e" FastClosures = "9aa1b823-49e4-5ca5-8b0f-3971ec8bab6a" Lux = "b2108857-7c20-44ae-9111-449ecde12c47" LuxCore = "bb33d45b-7691-41d6-9220-0943567d0623" +NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462" @@ -20,14 +21,11 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" SteadyStateDiffEq = "9672c7b4-1e72-59bd-8a11-6ac3964bc41f" [weakdeps] -ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" LinearSolve = "7ed4a6bd-45f5-4d41-b270-4a48e9bafcae" SciMLSensitivity = "1ed8b502-d754-442c-8d5d-10ac956f44a1" -Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [extensions] DeepEquilibriumNetworksSciMLSensitivityExt = ["LinearSolve", "SciMLSensitivity"] -DeepEquilibriumNetworksZygoteExt = ["ForwardDiff", "Zygote"] [compat] ADTypes = "0.2.5, 1" @@ -37,16 +35,18 @@ CommonSolve = "0.2.4" ConcreteStructs = "0.2" ConstructionBase = "1" DiffEqBase = "6.119" -ExplicitImports = "1.4.1" +Documenter = "1.4" +ExplicitImports = "1.6.0" FastClosures = "0.3" ForwardDiff = "0.10.36" Functors = "0.4.10" LinearSolve = "2.21.2" -Lux = "0.5.38" +Lux = "0.5.50" LuxCUDA = "0.3.2" LuxCore = "0.1.14" LuxTestUtils = "0.1.15" NLsolve = "4.5.1" +NNlib = "0.9.17" NonlinearSolve = "3.10.0" OrdinaryDiffEq = "6.74.1" PrecompileTools = "1" @@ -63,7 +63,9 @@ julia = "1.10" [extras] Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" +Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" ExplicitImports = "7d51a73a-1435-4ff3-83d9-f097790105c7" +ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196" LuxCUDA = "d0bbae9a-e099-4d5b-a835-1c6931763bda" LuxTestUtils = "ac9de150-d08f-4546-94fb-7472b5760531" @@ -77,4 +79,4 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [targets] -test = ["Aqua", "ExplicitImports", "Functors", "LuxCUDA", "LuxTestUtils", "NLsolve", "NonlinearSolve", "OrdinaryDiffEq", "ReTestItems", "SciMLSensitivity", "StableRNGs", "Test", "Zygote"] +test = ["Aqua", "Documenter", "ExplicitImports", "ForwardDiff", "Functors", "LuxCUDA", "LuxTestUtils", "NLsolve", "NonlinearSolve", "OrdinaryDiffEq", "ReTestItems", "SciMLSensitivity", "StableRNGs", "Test", "Zygote"] diff --git a/docs/Project.toml b/docs/Project.toml index 79874acd..0ad07734 100644 --- a/docs/Project.toml +++ b/docs/Project.toml @@ -33,4 +33,4 @@ Random = "1" SciMLSensitivity = "7" Statistics = "1" Zygote = "0.6" -julia = "1.9" +julia = "1.10" diff --git a/docs/make.jl b/docs/make.jl index 3117b436..1b00e024 100644 --- a/docs/make.jl +++ b/docs/make.jl @@ -11,7 +11,7 @@ makedocs(; sitename="Deep Equilibrium Networks", authors="Avik Pal et al.", modules=[DeepEquilibriumNetworks], clean=true, - doctest=true, + doctest=false, # Tested in CI linkcheck=true, format=Documenter.HTML(; assets=["assets/favicon.ico"], canonical="https://docs.sciml.ai/DeepEquilibriumNetworks/stable/"), diff --git a/docs/src/tutorials/basic_mnist_deq.md b/docs/src/tutorials/basic_mnist_deq.md index 3684f4a7..815e89d0 100644 --- a/docs/src/tutorials/basic_mnist_deq.md +++ b/docs/src/tutorials/basic_mnist_deq.md @@ -20,7 +20,8 @@ const cdev = cpu_device() const gdev = gpu_device() ``` -We can now construct our dataloader. +We can now construct our dataloader. We are using only limited part of the data for +demonstration. ```@example basic_mnist_deq function onehot(labels_raw) @@ -32,17 +33,17 @@ function loadmnist(batchsize, split) mnist = MNIST(; split) imgs, labels_raw = mnist.features, mnist.targets # Process images into (H,W,C,BS) batches - x_train = Float32.(reshape(imgs, size(imgs, 1), size(imgs, 2), 1, size(imgs, 3))) |> - gdev + x_train = Float32.(reshape(imgs, size(imgs, 1), size(imgs, 2), 1, size(imgs, 3)))[ + :, :, 1:1, 1:128] |> gdev x_train = batchview(x_train, batchsize) # Onehot and batch the labels - y_train = onehot(labels_raw) |> gdev + y_train = onehot(labels_raw)[:, 1:128] |> gdev y_train = batchview(y_train, batchsize) return x_train, y_train end -x_train, y_train = loadmnist(128, :train); -x_test, y_test = loadmnist(128, :test); +x_train, y_train = loadmnist(16, :train); +x_test, y_test = loadmnist(16, :test); ``` Construct the Lux Neural Network containing a DEQ layer. diff --git a/ext/DeepEquilibriumNetworksZygoteExt.jl b/ext/DeepEquilibriumNetworksZygoteExt.jl deleted file mode 100644 index a04697e0..00000000 --- a/ext/DeepEquilibriumNetworksZygoteExt.jl +++ /dev/null @@ -1,57 +0,0 @@ -module DeepEquilibriumNetworksZygoteExt - -using ADTypes: AutoZygote -using ChainRulesCore: ChainRulesCore -using DeepEquilibriumNetworks: DEQs -using FastClosures: @closure -using ForwardDiff: ForwardDiff # This is a dependency of Zygote -using Lux: Lux, StatefulLuxLayer -using Statistics: mean -using Zygote: Zygote - -const CRC = ChainRulesCore - -@inline __tupleify(x) = @closure(u->(u, x)) - -## One day we will overload DI's APIs for Lux Layers and we can remove this -## Main challenge with overloading Zygote.pullback is that we need to return the correct -## tangent for the pullback to compute the correct gradient, which is quite hard. But -## wrapping the overall vjp is not that hard. -@inline function __compute_vector_jacobian_product(model::StatefulLuxLayer, ps, z, x, rng) - res, back = Zygote.pullback(model ∘ __tupleify(x), z) - return only(back(DEQs.__gaussian_like(rng, res))) -end - -function CRC.rrule( - ::typeof(__compute_vector_jacobian_product), model::StatefulLuxLayer, ps, z, x, rng) - res, back = Zygote.pullback(model ∘ __tupleify(x), z) - ε = DEQs.__gaussian_like(rng, res) - y = only(back(ε)) - ∇internal_gradient_capture = Δ -> begin - (Δ isa CRC.NoTangent || Δ isa CRC.ZeroTangent) && - return ntuple(Returns(CRC.NoTangent()), 6) - - Δ_ = reshape(CRC.unthunk(Δ), size(z)) - - Tag = typeof(ForwardDiff.Tag(model, eltype(z))) - partials = ForwardDiff.Partials{1, eltype(z)}.(tuple.(Δ_)) - z_dual = ForwardDiff.Dual{Tag, eltype(z), 1}.(z, partials) - - _, pb_f = Zygote.pullback((x1, x2, p) -> model((x1, x2), p), z_dual, x, ps) - ∂z_duals, ∂x_duals, ∂ps_duals = pb_f(ε) - - ∂z = Lux.__partials(Tag, ∂z_duals, 1) - ∂x = Lux.__partials(Tag, ∂x_duals, 1) - ∂ps = Lux.__partials(Tag, ∂ps_duals, 1) - - return CRC.NoTangent(), CRC.NoTangent(), ∂ps, ∂z, ∂x, CRC.NoTangent() - end - return y, ∇internal_gradient_capture -end - -## Don't remove `ad`. See https://github.com/ericphanson/ExplicitImports.jl/issues/33 -function DEQs.__estimate_jacobian_trace(ad::AutoZygote, model, z, x, rng) - return mean(abs2, __compute_vector_jacobian_product(model, model.ps, z, x, rng)) -end - -end diff --git a/src/DeepEquilibriumNetworks.jl b/src/DeepEquilibriumNetworks.jl index abaccfbb..8cd3ea5d 100644 --- a/src/DeepEquilibriumNetworks.jl +++ b/src/DeepEquilibriumNetworks.jl @@ -3,7 +3,7 @@ module DeepEquilibriumNetworks import PrecompileTools: @recompile_invalidations @recompile_invalidations begin - using ADTypes: AutoFiniteDiff + using ADTypes: AutoFiniteDiff, AutoForwardDiff, AutoZygote using ChainRulesCore: ChainRulesCore using CommonSolve: solve using ConcreteStructs: @concrete @@ -12,7 +12,8 @@ import PrecompileTools: @recompile_invalidations using FastClosures: @closure using Lux: Lux, BranchLayer, Chain, NoOpLayer, Parallel, RepeatedLayer, StatefulLuxLayer, WrappedFunction - using LuxCore: AbstractExplicitLayer, AbstractExplicitContainerLayer + using LuxCore: LuxCore, AbstractExplicitLayer, AbstractExplicitContainerLayer + using NNlib: ⊠ using Random: Random, AbstractRNG, randn! using SciMLBase: SciMLBase, AbstractNonlinearAlgorithm, AbstractODEAlgorithm, NonlinearSolution, ODESolution, ODEFunction, ODEProblem, diff --git a/src/layers.jl b/src/layers.jl index 995f94db..4665f492 100644 --- a/src/layers.jl +++ b/src/layers.jl @@ -64,11 +64,11 @@ const DEQ = DeepEquilibriumNetwork ConstructionBase.constructorof(::Type{<:DEQ{pType}}) where {pType} = DEQ{pType} -function Lux.initialstates(rng::AbstractRNG, deq::DEQ) - rng = Lux.replicate(rng) +function LuxCore.initialstates(rng::AbstractRNG, deq::DEQ) + rng = LuxCore.replicate(rng) randn(rng, 1) - return (; model=Lux.initialstates(rng, deq.model), fixed_depth=Val(0), - init=Lux.initialstates(rng, deq.init), solution=DeepEquilibriumSolution(), rng) + return (; model=LuxCore.initialstates(rng, deq.model), fixed_depth=Val(0), + init=LuxCore.initialstates(rng, deq.init), solution=DeepEquilibriumSolution(), rng) end (deq::DEQ)(x, ps, st::NamedTuple) = deq(x, ps, st, __check_unrolled_mode(st)) @@ -79,10 +79,10 @@ function (deq::DEQ)(x, ps, st::NamedTuple, ::Val{true}) repeated_model = RepeatedLayer(deq.model; repeats=st.fixed_depth) z_star, st_ = repeated_model((z, x), ps.model, st.model) - model = StatefulLuxLayer(deq.model, ps.model, st_) + model = StatefulLuxLayer{true}(deq.model, ps.model, st_) resid = CRC.ignore_derivatives(z_star .- model((z_star, x))) - rng = Lux.replicate(st.rng) + rng = LuxCore.replicate(st.rng) jac_loss = __estimate_jacobian_trace( __getproperty(deq, Val(:jacobian_regularization)), model, z_star, x, rng) @@ -97,7 +97,7 @@ end function (deq::DEQ{pType})(x, ps, st::NamedTuple, ::Val{false}) where {pType} z, st = __get_initial_condition(deq, x, ps, st) - model = StatefulLuxLayer(deq.model, ps.model, st.model) + model = StatefulLuxLayer{true}(deq.model, ps.model, st.model) dudt = @closure (u, p, t) -> begin # The type-assert is needed because of an upstream Lux issue with type stability of @@ -113,7 +113,7 @@ function (deq::DEQ{pType})(x, ps, st::NamedTuple, ::Val{false}) where {pType} reltol=1e-3, termination_condition, maxiters=32, deq.kwargs...) z_star = __get_steady_state(sol) - rng = Lux.replicate(st.rng) + rng = LuxCore.replicate(st.rng) jac_loss = __estimate_jacobian_trace( __getproperty(deq, Val(:jacobian_regularization)), model, z_star, x, rng) @@ -144,7 +144,8 @@ Deep Equilibrium Network as proposed in [baideep2019](@cite) and [pal2022mixing] condition is set to `zero(x)`. If `missing`, the initial condition is set to `WrappedFunction(zero)`. In other cases the initial condition is set to `init(x, ps, st)`. - - `jacobian_regularization`: Must be one of `nothing`, `AutoFiniteDiff` or `AutoZygote`. + - `jacobian_regularization`: Must be one of `nothing`, `AutoForwardDiff`, `AutoFiniteDiff` + or `AutoZygote`. - `problem_type`: Provides a way to simulate a Vanilla Neural ODE by setting the `problem_type` to `ODEProblem`. By default, the problem type is set to `SteadyStateProblem`. @@ -152,9 +153,7 @@ Deep Equilibrium Network as proposed in [baideep2019](@cite) and [pal2022mixing] ## Example -```julia -julia> using DeepEquilibriumNetworks, Lux, Random, OrdinaryDiffEq - +```jldoctest julia> model = DeepEquilibriumNetwork( Parallel(+, Dense(2, 2; use_bias=false), Dense(2, 2; use_bias=false)), VCABM3(); verbose=false) @@ -168,13 +167,12 @@ DeepEquilibriumNetwork( ) # Total: 8 parameters, # plus 0 states. -julia> rng = Random.default_rng() -TaskLocalRNG() +julia> rng = Xoshiro(0); julia> ps, st = Lux.setup(rng, model); -julia> model(ones(Float32, 2, 1), ps, st); - +julia> size(first(model(ones(Float32, 2, 1), ps, st))) +(2, 1) ``` See also: [`SkipDeepEquilibriumNetwork`](@ref), [`MultiScaleDeepEquilibriumNetwork`](@ref), @@ -234,84 +232,27 @@ For keyword arguments, see [`DeepEquilibriumNetwork`](@ref). ## Example -```julia -julia> using DeepEquilibriumNetworks, Lux, Random, NonlinearSolve - +```jldoctest julia> main_layers = ( Parallel(+, Dense(4 => 4, tanh; use_bias=false), Dense(4 => 4, tanh; use_bias=false)), - Dense(3 => 3, tanh), Dense(2 => 2, tanh), Dense(1 => 1, tanh)) -(Parallel(), Dense(3 => 3, tanh_fast), Dense(2 => 2, tanh_fast), Dense(1 => 1, tanh_fast)) + Dense(3 => 3, tanh), Dense(2 => 2, tanh), Dense(1 => 1, tanh)); julia> mapping_layers = [NoOpLayer() Dense(4 => 3, tanh) Dense(4 => 2, tanh) Dense(4 => 1, tanh); Dense(3 => 4, tanh) NoOpLayer() Dense(3 => 2, tanh) Dense(3 => 1, tanh); Dense(2 => 4, tanh) Dense(2 => 3, tanh) NoOpLayer() Dense(2 => 1, tanh); - Dense(1 => 4, tanh) Dense(1 => 3, tanh) Dense(1 => 2, tanh) NoOpLayer()] -4×4 Matrix{LuxCore.AbstractExplicitLayer}: - NoOpLayer() … Dense(4 => 1, tanh_fast) - Dense(3 => 4, tanh_fast) Dense(3 => 1, tanh_fast) - Dense(2 => 4, tanh_fast) Dense(2 => 1, tanh_fast) - Dense(1 => 4, tanh_fast) NoOpLayer() + Dense(1 => 4, tanh) Dense(1 => 3, tanh) Dense(1 => 2, tanh) NoOpLayer()]; julia> model = MultiScaleDeepEquilibriumNetwork( - main_layers, mapping_layers, nothing, NewtonRaphson(), ((4,), (3,), (2,), (1,))) -DeepEquilibriumNetwork( - model = MultiScaleInputLayer{scales = 4}( - model = Chain( - layer_1 = Parallel( - layer_1 = Parallel( - + - Dense(4 => 4, tanh_fast, bias=false), # 16 parameters - Dense(4 => 4, tanh_fast, bias=false), # 16 parameters - ), - layer_2 = Dense(3 => 3, tanh_fast), # 12 parameters - layer_3 = Dense(2 => 2, tanh_fast), # 6 parameters - layer_4 = Dense(1 => 1, tanh_fast), # 2 parameters - ), - layer_2 = BranchLayer( - layer_1 = Parallel( - + - NoOpLayer(), - Dense(3 => 4, tanh_fast), # 16 parameters - Dense(2 => 4, tanh_fast), # 12 parameters - Dense(1 => 4, tanh_fast), # 8 parameters - ), - layer_2 = Parallel( - + - Dense(4 => 3, tanh_fast), # 15 parameters - NoOpLayer(), - Dense(2 => 3, tanh_fast), # 9 parameters - Dense(1 => 3, tanh_fast), # 6 parameters - ), - layer_3 = Parallel( - + - Dense(4 => 2, tanh_fast), # 10 parameters - Dense(3 => 2, tanh_fast), # 8 parameters - NoOpLayer(), - Dense(1 => 2, tanh_fast), # 4 parameters - ), - layer_4 = Parallel( - + - Dense(4 => 1, tanh_fast), # 5 parameters - Dense(3 => 1, tanh_fast), # 4 parameters - Dense(2 => 1, tanh_fast), # 3 parameters - NoOpLayer(), - ), - ), - ), - ), - init = WrappedFunction(Base.Fix1{typeof(DeepEquilibriumNetworks.__zeros_init), Val{((4,), (3,), (2,), (1,))}}(DeepEquilibriumNetworks.__zeros_init, Val{((4,), (3,), (2,), (1,))}())), -) # Total: 152 parameters, - # plus 0 states. + main_layers, mapping_layers, nothing, NewtonRaphson(), ((4,), (3,), (2,), (1,))); -julia> rng = Random.default_rng() -TaskLocalRNG() +julia> rng = Xoshiro(0); julia> ps, st = Lux.setup(rng, model); julia> x = rand(rng, Float32, 4, 12); -julia> model(x, ps, st); - +julia> size.(first(model(x, ps, st))) +((4, 12), (3, 12), (2, 12), (1, 12)) ``` """ function MultiScaleDeepEquilibriumNetwork(main_layers::Tuple, mapping_layers::Matrix, @@ -390,7 +331,9 @@ end function ConstructionBase.constructorof(::Type{<:MultiScaleInputLayer{N}}) where {N} return MultiScaleInputLayer{N} end -Lux.display_name(::MultiScaleInputLayer{N}) where {N} = "MultiScaleInputLayer{scales = $N}" +function LuxCore.display_name(::MultiScaleInputLayer{N}) where {N} + return "MultiScaleInputLayer{scales = $N}" +end function MultiScaleInputLayer(model, split_idxs, scales::Val{S}) where {S} return MultiScaleInputLayer{length(S)}(model, split_idxs, scales) @@ -401,7 +344,7 @@ end return quote u, x = z u_ = __split_and_reshape(u, m.split_idxs, m.scales) - u_res, st = Lux.apply(m.model, ($(inputs...),), ps, st) + u_res, st = LuxCore.apply(m.model, ($(inputs...),), ps, st) return __flatten_vcat(u_res), st end end diff --git a/src/utils.jl b/src/utils.jl index 647636dc..26c942a7 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -95,9 +95,11 @@ end CRC.@non_differentiable __gaussian_like(::Any...) +@inline __tupleify(x) = @closure(u->(u, x)) + # Jacobian Stabilization ## Don't remove `ad`. See https://github.com/ericphanson/ExplicitImports.jl/issues/33 -function __estimate_jacobian_trace(ad::AutoFiniteDiff, model, z, x, rng) +function __estimate_jacobian_trace(ad::AutoFiniteDiff, model::StatefulLuxLayer, z, x, rng) __f = @closure u -> model((u, x)) res = zero(eltype(x)) ϵ = cbrt(eps(typeof(res))) @@ -119,4 +121,20 @@ function __estimate_jacobian_trace(ad::AutoFiniteDiff, model, z, x, rng) return res end +function __estimate_jacobian_trace(ad::AutoZygote, model::StatefulLuxLayer, z, x, rng) + v = __gaussian_like(rng, x) + smodel = model ∘ __tupleify(x) + vjp = Lux.vector_jacobian_product(smodel, ad, z, v) + return sum(reshape(vjp, 1, :, size(vjp, ndims(vjp))) ⊠ + reshape(v, :, 1, size(v, ndims(v)))) +end + +function __estimate_jacobian_trace(ad::AutoForwardDiff, model::StatefulLuxLayer, z, x, rng) + v = __gaussian_like(rng, x) + smodel = model ∘ __tupleify(x) + jvp = Lux.jacobian_vector_product(smodel, ad, z, v) + return sum(reshape(v, 1, :, size(v, ndims(v))) ⊠ + reshape(jvp, :, 1, size(jvp, ndims(jvp)))) +end + __estimate_jacobian_trace(::Nothing, model, z, x, rng) = zero(eltype(x)) diff --git a/test/layers_tests.jl b/test/layers_tests.jl index aa19ea45..4a69127d 100644 --- a/test/layers_tests.jl +++ b/test/layers_tests.jl @@ -17,7 +17,7 @@ export loss_function, SOLVERS end -@testitem "DEQ" setup=[SharedTestSetup, LayersTestSetup] begin +@testitem "DEQ" setup=[SharedTestSetup, LayersTestSetup] timeout=10000 begin using ADTypes, Lux, NonlinearSolve, OrdinaryDiffEq, SciMLSensitivity, Zygote rng = __get_prng(0) @@ -28,7 +28,7 @@ end x_sizes = [(2, 14), (3, 3, 1, 3)] model_type = (:deq, :skipdeq, :skipregdeq) - _jacobian_regularizations = (nothing, AutoZygote(), AutoFiniteDiff()) + _jacobian_regularizations = (nothing, AutoZygote(), AutoForwardDiff(), AutoFiniteDiff()) @testset "$mode" for (mode, aType, dev, ongpu) in MODES jacobian_regularizations = ongpu ? _jacobian_regularizations[1:(end - 1)] : @@ -89,7 +89,7 @@ end end end -@testitem "Multiscale DEQ" setup=[SharedTestSetup, LayersTestSetup] begin +@testitem "Multiscale DEQ" setup=[SharedTestSetup, LayersTestSetup] timeout=10000 begin using ADTypes, Lux, NonlinearSolve, OrdinaryDiffEq, SciMLSensitivity, Zygote rng = __get_prng(0) diff --git a/test/qa_tests.jl b/test/qa_tests.jl index 2dd1d11e..03d36d20 100644 --- a/test/qa_tests.jl +++ b/test/qa_tests.jl @@ -6,12 +6,22 @@ end @testitem "ExplicitImports" begin - import SciMLSensitivity, Zygote + import SciMLSensitivity using ExplicitImports - # Skip our own packages @test check_no_implicit_imports(DeepEquilibriumNetworks) === nothing - ## AbstractRNG seems to be a spurious detection in LuxFluxExt @test check_no_stale_explicit_imports(DeepEquilibriumNetworks) === nothing + @test check_all_qualified_accesses_via_owners(DeepEquilibriumNetworks) === nothing +end + +@testitem "Doctests" begin + using Documenter + + doctestexpr = quote + using DeepEquilibriumNetworks, Lux, Random, OrdinaryDiffEq, NonlinearSolve + end + + DocMeta.setdocmeta!(DeepEquilibriumNetworks, :DocTestSetup, doctestexpr; recursive=true) + doctest(DeepEquilibriumNetworks; manual=false) end diff --git a/test/shared_testsetup.jl b/test/shared_testsetup.jl index b22de31b..a0a321a6 100644 --- a/test/shared_testsetup.jl +++ b/test/shared_testsetup.jl @@ -1,6 +1,6 @@ @testsetup module SharedTestSetup -using DeepEquilibriumNetworks, Functors, Lux, Random, StableRNGs, Zygote +using DeepEquilibriumNetworks, Functors, Lux, Random, StableRNGs, Zygote, ForwardDiff import LuxTestUtils: @jet using LuxCUDA