Skip to content

Commit

Permalink
Merge pull request #112 from ArnoStrouwen/fix_unbound
Browse files Browse the repository at this point in the history
fix unbound type parameters in NTuple
  • Loading branch information
ChrisRackauckas authored Dec 12, 2023
2 parents 25eece3 + 42a35fe commit b38c669
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 8 deletions.
17 changes: 10 additions & 7 deletions src/layers/mdeq.jl
Original file line number Diff line number Diff line change
Expand Up @@ -94,8 +94,9 @@ function Lux.initialstates(rng::AbstractRNG, deq::MultiScaleDeepEquilibriumNetwo
end

function MultiScaleDeepEquilibriumNetwork(main_layers::Tuple, mapping_layers::Matrix,
post_fuse_layer::Union{Nothing, Tuple}, solver, scales::NTuple{N, NTuple{L, Int64}};
sensealg=SteadyStateAdjoint(), kwargs...) where {N, L}
post_fuse_layer::Union{Nothing, Tuple}, solver,
scales::Tuple{NTuple{L, Int64}, Vararg{NTuple{L, Int64}, nMinus1}};
sensealg=SteadyStateAdjoint(), kwargs...) where {nMinus1, L}
l1 = Parallel(nothing, main_layers...)
l2 = BranchLayer(Parallel.(+, map(x -> tuple(x...), eachrow(mapping_layers))...)...)

Expand Down Expand Up @@ -226,8 +227,9 @@ end

function MultiScaleSkipDeepEquilibriumNetwork(main_layers::Tuple, mapping_layers::Matrix,
post_fuse_layer::Union{Nothing, Tuple}, shortcut_layers::Union{Nothing, Tuple},
solver, scales::NTuple{N, NTuple{L, Int64}};
sensealg=SteadyStateAdjoint(), kwargs...) where {N, L}
solver,
scales::Tuple{NTuple{L, Int64}, Vararg{NTuple{L, Int64}, nMinus1}};
sensealg=SteadyStateAdjoint(), kwargs...) where {nMinus1, L}
l1 = Parallel(nothing, main_layers...)
l2 = BranchLayer(Parallel.(+, map(x -> tuple(x...), eachrow(mapping_layers))...)...)
shortcut = shortcut_layers === nothing ? nothing : Parallel(nothing, shortcut_layers...)
Expand Down Expand Up @@ -330,8 +332,9 @@ model(x, ps, st)
See also: [`DeepEquilibriumNetwork`](@ref), [`SkipDeepEquilibriumNetwork`](@ref), [`MultiScaleDeepEquilibriumNetwork`](@ref), [`MultiScaleSkipDeepEquilibriumNetwork`](@ref)
"""
function MultiScaleNeuralODE(main_layers::Tuple, mapping_layers::Matrix,
post_fuse_layer::Union{Nothing, Tuple}, solver, scales::NTuple{N, NTuple{L, Int64}};
sensealg=GaussAdjoint(; autojacvec=ZygoteVJP()), kwargs...) where {N, L}
post_fuse_layer::Union{Nothing, Tuple}, solver,
scales::Tuple{NTuple{L, Int64}, Vararg{NTuple{L, Int64}, nMinus1}};
sensealg=GaussAdjoint(; autojacvec=ZygoteVJP()), kwargs...) where {nMinus1, L}
l1 = Parallel(nothing, main_layers...)
l2 = BranchLayer(Parallel.(+, map(x -> tuple(x...), eachrow(mapping_layers))...)...)

Expand All @@ -344,7 +347,7 @@ function MultiScaleNeuralODE(main_layers::Tuple, mapping_layers::Matrix,
split_idxs, scales)
end

return MultiScaleNeuralODE{N}(model, solver, sensealg, scales, split_idxs, kwargs)
return MultiScaleNeuralODE{nMinus1+1}(model, solver, sensealg, scales, split_idxs, kwargs)
end

_jacobian_regularization(::MultiScaleNeuralODE) = false
Expand Down
2 changes: 1 addition & 1 deletion test/qa.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,6 @@ using DeepEquilibriumNetworks, Aqua
Aqua.test_piracies(DeepEquilibriumNetworks; broken=true)
Aqua.test_project_extras(DeepEquilibriumNetworks)
Aqua.test_stale_deps(DeepEquilibriumNetworks)
Aqua.test_unbound_args(DeepEquilibriumNetworks; broken=true)
Aqua.test_unbound_args(DeepEquilibriumNetworks)
Aqua.test_undefined_exports(DeepEquilibriumNetworks; broken=true)
end

0 comments on commit b38c669

Please sign in to comment.