From 9e4b171411b2b52e916323a49bfeeb58a888799e Mon Sep 17 00:00:00 2001
From: Songchen Tan
Date: Tue, 28 May 2024 15:51:20 -0400
Subject: [PATCH] Fix ChainRules codegen problem

---
 examples/ode.jl   | 45 +++++++++++++++++++++++++++++++++++++++++++++
 src/codegen.jl    | 17 +++++++++++------
 src/primitive.jl  |  5 +++++
 src/scalar.jl     | 19 +++++++++++++++++--
 test/primitive.jl |  2 +-
 5 files changed, 79 insertions(+), 9 deletions(-)
 create mode 100644 examples/ode.jl

diff --git a/examples/ode.jl b/examples/ode.jl
new file mode 100644
index 0000000..86ca42e
--- /dev/null
+++ b/examples/ode.jl
@@ -0,0 +1,45 @@
+using ADTypes
+using DifferentiationInterface
+using ModelingToolkit, DifferentialEquations
+using TaylorDiff, ForwardDiff
+using Enzyme, Zygote, ReverseDiff
+using SciMLSensitivity
+
+@parameters a
+@variables t x1(t)
+D = Differential(t)
+states = [x1]
+parameters = [a]
+
+@named pre_model = ODESystem([D(x1) ~ a * x1], t, states, parameters)
+model = structural_simplify(pre_model)
+
+ic = Dict(x1 => 1.0)
+p_true = Dict(a => 2.0)
+
+problem = ODEProblem{true, SciMLBase.FullSpecialize}(model, ic, [0.0, 1.0], p_true)
+soln = ModelingToolkit.solve(problem, Tsit5(), abstol = 1e-12, reltol = 1e-12)
+display(soln(0.5, idxs = [x1]))
+
+function different_time(new_ic, new_params, new_t)
+    #newprob = ODEProblem{true, SciMLBase.FullSpecialize}(model, new_ic, [0.0, new_t*2], new_params)
+    #newprob = remake(problem, u0=new_ic, tspan = [0.0, new_t], p = new_params)
+    newprob = remake(problem, u0 = new_ic, tspan = [0.0, new_t], p = new_params)
+    newprob = remake(newprob, u0 = typeof(new_t).(newprob.u0))
+    new_soln = ModelingToolkit.solve(newprob, Tsit5(), abstol = 1e-12, reltol = 1e-12)
+    return (soln(new_t, idxs = [x1]))
+end
+
+function just_t(new_t)
+    return different_time(ic, p_true, new_t)[1]
+end
+display(different_time(ic, p_true, 2e-5))
+display(just_t(0.5))
+
+#display(ForwardDiff.derivative(just_t,1.0))
+display(TaylorDiff.derivative(just_t, 1.0, 1)) #isnan error
+#display(value_and_gradient(just_t, AutoForwardDiff(), 1.0))
+#display(value_and_gradient(just_t, AutoReverseDiff(), 1.0))
+#display(value_and_gradient(just_t, AutoEnzyme(Enzyme.Reverse), 1.0))
+#display(value_and_gradient(just_t, AutoEnzyme(Enzyme.Forward), 1.0))
+#display(value_and_gradient(just_t, AutoZygote(), 1.0))
diff --git a/src/codegen.jl b/src/codegen.jl
index 6a5c003..8c61711 100644
--- a/src/codegen.jl
+++ b/src/codegen.jl
@@ -4,20 +4,24 @@ using Symbolics: @variables
 using SymbolicUtils, SymbolicUtils.Code
 using SymbolicUtils: Pow
 
-dummy = (NoTangent(), 1)
-@variables z
-for func in (+, -, deg2rad, rad2deg,
+func_list = (
+    +, -, deg2rad, rad2deg,
     sinh, cosh, tanh,
     asin, acos, atan, asec, acsc, acot,
     log, log10, log1p, log2,
     asinh, acosh, atanh, asech, acsch, acoth,
     abs, sign)
+
+dummy = (NoTangent(), 1)
+@variables z
+for func in func_list
     F = typeof(func)
     # base case
     @eval function (op::$F)(t::TaylorScalar{T, 2}) where {T}
         t0, t1 = value(t)
-        TaylorScalar{T, 2}(frule((NoTangent(), t1), op, t0))
+        f0, f1 = frule((NoTangent(), t1), op, t0)
+        TaylorScalar{T, 2}(f0, zero_tangent(f0) + f1)
     end
     der = frule(dummy, func, z)[2]
     term, raiser = der isa Pow && der.exp == -1 ?
                    (der.base, raiseinv) : (der, raise)
@@ -28,8 +32,9 @@ for func in (+, -, deg2rad, rad2deg,
         quote
             $(Expr(:meta, :inline))
             z = TaylorScalar{T, N - 1}(t)
-            df = $der_expr
-            $$raiser($f(value(t)[1]), df, t)
+            f0 = $f(value(t)[1])
+            df = zero_tangent(z) + $der_expr
+            $$raiser(f0, df, t)
         end
     end
 end
diff --git a/src/primitive.jl b/src/primitive.jl
index ee7617e..0e21b4a 100644
--- a/src/primitive.jl
+++ b/src/primitive.jl
@@ -151,8 +151,13 @@ for R in (Integer, Real)
         ex = :($ex; TaylorScalar($([Symbol('u', i) for i in 1:N]...)))
         return :(@inbounds $ex)
     end
+    @eval function ^(a::S, t::TaylorScalar{T, N}) where {S <: $R, T, N}
+        exp(t * log(a))
+    end
 end
 
+^(t::TaylorScalar, s::TaylorScalar) = exp(s * log(t))
+
 @generated function raise(f::T, df::TaylorScalar{T, M},
         t::TaylorScalar{T, N}) where {T, M, N} # M + 1 == N
     return quote
diff --git a/src/scalar.jl b/src/scalar.jl
index 6ad6c9c..ca523fb 100644
--- a/src/scalar.jl
+++ b/src/scalar.jl
@@ -90,6 +90,21 @@ function promote_rule(::Type{TaylorScalar{T, N}},
     TaylorScalar{promote_type(T, S), N}
 end
 
-function Base.AbstractFloat(x::TaylorScalar{T, N}) where {T, N}
-    TaylorScalar{Float64, N}(convert(NTuple{N, Float64}, x.value))
+function (::Type{F})(x::TaylorScalar{T, N}) where {T, N, F <: AbstractFloat}
+    F(primal(x))
+end
+
+function Base.nextfloat(x::TaylorScalar{T, N}) where {T, N}
+    TaylorScalar{T, N}(ntuple(i -> i == 1 ? nextfloat(value(x)[i]) : value(x)[i], N))
+end
+
+function Base.prevfloat(x::TaylorScalar{T, N}) where {T, N}
+    TaylorScalar{T, N}(ntuple(i -> i == 1 ? prevfloat(value(x)[i]) : value(x)[i], N))
+end
+
+const UNARY_PREDICATES = Symbol[
+    :isinf, :isnan, :isfinite, :iseven, :isodd, :isreal, :isinteger]
+
+for pred in UNARY_PREDICATES
+    @eval Base.$(pred)(x::TaylorScalar) = $(pred)(primal(x))
 end
diff --git a/test/primitive.jl b/test/primitive.jl
index de4a221..4cd6bb7 100644
--- a/test/primitive.jl
+++ b/test/primitive.jl
@@ -2,7 +2,7 @@ using FiniteDifferences
 
 @testset "No derivative or linear" begin
     some_number, another_number = 1.9, 2.6
-    for f in (+, -, zero, one, adjoint, conj, deg2rad, rad2deg), order in (2,)
+    for f in (+, -, zero, one, adjoint, conj, deg2rad, rad2deg, abs, sign), order in (2,)
        @test derivative(f, some_number, order) ≈ 0.0
     end
     for f in (+, -, <, <=, >, >=, ==), order in (2,)
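
Usage sketch (not part of the patch; assumes it is applied on top of TaylorDiff.jl): the new `^(a::Real, t::TaylorScalar)` method and the unary predicates forwarded to the primal value can be exercised through `TaylorDiff.derivative`, the same entry point used in test/primitive.jl above. The expressions and tolerances below are illustrative only.

    using TaylorDiff

    # d/dt 2.0^t = log(2) * 2.0^t, so at t = 1.0 the derivative is 2 * log(2)
    d1 = TaylorDiff.derivative(t -> 2.0^t, 1.0, 1)
    @assert isapprox(d1, 2.0 * log(2.0); atol = 1e-8)

    # isfinite(::TaylorScalar) now dispatches on the primal value, so branching
    # on it inside a differentiated closure works
    d2 = TaylorDiff.derivative(t -> isfinite(t) ? 2.0^t : zero(t), 1.0, 1)
    @assert isapprox(d2, d1; atol = 1e-8)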