diff --git a/src/layers/mdeq.jl b/src/layers/mdeq.jl index 72351e04..7e31e0ef 100644 --- a/src/layers/mdeq.jl +++ b/src/layers/mdeq.jl @@ -94,8 +94,9 @@ function Lux.initialstates(rng::AbstractRNG, deq::MultiScaleDeepEquilibriumNetwo end function MultiScaleDeepEquilibriumNetwork(main_layers::Tuple, mapping_layers::Matrix, - post_fuse_layer::Union{Nothing, Tuple}, solver, scales::NTuple{N, NTuple{L, Int64}}; - sensealg=SteadyStateAdjoint(), kwargs...) where {N, L} + post_fuse_layer::Union{Nothing, Tuple}, solver, + scales::Tuple{NTuple{L, Int64}, Vararg{NTuple{L, Int64}, nMinus1}}; + sensealg=SteadyStateAdjoint(), kwargs...) where {nMinus1, L} l1 = Parallel(nothing, main_layers...) l2 = BranchLayer(Parallel.(+, map(x -> tuple(x...), eachrow(mapping_layers))...)...) @@ -226,8 +227,9 @@ end function MultiScaleSkipDeepEquilibriumNetwork(main_layers::Tuple, mapping_layers::Matrix, post_fuse_layer::Union{Nothing, Tuple}, shortcut_layers::Union{Nothing, Tuple}, - solver, scales::NTuple{N, NTuple{L, Int64}}; - sensealg=SteadyStateAdjoint(), kwargs...) where {N, L} + solver, + scales::Tuple{NTuple{L, Int64}, Vararg{NTuple{L, Int64}, nMinus1}}; + sensealg=SteadyStateAdjoint(), kwargs...) where {nMinus1, L} l1 = Parallel(nothing, main_layers...) l2 = BranchLayer(Parallel.(+, map(x -> tuple(x...), eachrow(mapping_layers))...)...) shortcut = shortcut_layers === nothing ? nothing : Parallel(nothing, shortcut_layers...) @@ -330,8 +332,9 @@ model(x, ps, st) See also: [`DeepEquilibriumNetwork`](@ref), [`SkipDeepEquilibriumNetwork`](@ref), [`MultiScaleDeepEquilibriumNetwork`](@ref), [`MultiScaleSkipDeepEquilibriumNetwork`](@ref) """ function MultiScaleNeuralODE(main_layers::Tuple, mapping_layers::Matrix, - post_fuse_layer::Union{Nothing, Tuple}, solver, scales::NTuple{N, NTuple{L, Int64}}; - sensealg=GaussAdjoint(; autojacvec=ZygoteVJP()), kwargs...) where {N, L} + post_fuse_layer::Union{Nothing, Tuple}, solver, + scales::Tuple{NTuple{L, Int64}, Vararg{NTuple{L, Int64}, nMinus1}}; + sensealg=GaussAdjoint(; autojacvec=ZygoteVJP()), kwargs...) where {nMinus1, L} l1 = Parallel(nothing, main_layers...) l2 = BranchLayer(Parallel.(+, map(x -> tuple(x...), eachrow(mapping_layers))...)...) @@ -344,7 +347,7 @@ function MultiScaleNeuralODE(main_layers::Tuple, mapping_layers::Matrix, split_idxs, scales) end - return MultiScaleNeuralODE{N}(model, solver, sensealg, scales, split_idxs, kwargs) + return MultiScaleNeuralODE{nMinus1+1}(model, solver, sensealg, scales, split_idxs, kwargs) end _jacobian_regularization(::MultiScaleNeuralODE) = false diff --git a/test/qa.jl b/test/qa.jl index 472f4881..e6a491d7 100644 --- a/test/qa.jl +++ b/test/qa.jl @@ -6,6 +6,6 @@ using DeepEquilibriumNetworks, Aqua Aqua.test_piracies(DeepEquilibriumNetworks; broken=true) Aqua.test_project_extras(DeepEquilibriumNetworks) Aqua.test_stale_deps(DeepEquilibriumNetworks) - Aqua.test_unbound_args(DeepEquilibriumNetworks; broken=true) + Aqua.test_unbound_args(DeepEquilibriumNetworks) Aqua.test_undefined_exports(DeepEquilibriumNetworks; broken=true) end