diff --git a/src/init.jl b/src/init.jl index b276cd662..b1a4a19af 100644 --- a/src/init.jl +++ b/src/init.jl @@ -34,4 +34,13 @@ function __init__() @inline ODE_DEFAULT_NORM(u::Unitful.Quantity) = abs(value(u)) end + @require Flux="587475ba-b771-5e3f-ad9e-33799f191a9c" begin + value(x::Flux.Tracker.TrackedReal) = x.data + value(x::Flux.Tracker.TrackedArray) = x.data + @inline function ODE_DEFAULT_NORM(u::Flux.Tracker.TrackedArray) where {N} + sqrt(sum(ODE_DEFAULT_NORM,(value(x) for x in u)) / length(u)) + end + @inline ODE_DEFAULT_NORM(u::Flux.Tracker.TrackedReal) = abs(value(u)) + end + end