Skip to content

Commit

Permalink
Fix ChainRules codegen problem
Browse files Browse the repository at this point in the history
  • Loading branch information
tansongchen committed May 28, 2024
1 parent 2cf2659 commit 9e4b171
Show file tree
Hide file tree
Showing 5 changed files with 79 additions and 9 deletions.
45 changes: 45 additions & 0 deletions examples/ode.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
using ADTypes
using DifferentiationInterface
using ModelingToolkit, DifferentialEquations
using TaylorDiff, ForwardDiff
using Enzyme, Zygote, ReverseDiff
using SciMLSensitivity

@parameters a
@variables t x1(t)
D = Differential(t)
states = [x1]
parameters = [a]

@named pre_model = ODESystem([D(x1) ~ a * x1], t, states, parameters)
model = structural_simplify(pre_model)

ic = Dict(x1 => 1.0)
p_true = Dict(a => 2.0)

problem = ODEProblem{true, SciMLBase.FullSpecialize}(model, ic, [0.0, 1.0], p_true)
soln = ModelingToolkit.solve(problem, Tsit5(), abstol = 1e-12, reltol = 1e-12)
display(soln(0.5, idxs = [x1]))

function different_time(new_ic, new_params, new_t)
#newprob = ODEProblem{true, SciMLBase.FullSpecialize}(model, new_ic, [0.0, new_t*2], new_params)
#newprob = remake(problem, u0=new_ic, tspan = [0.0, new_t], p = new_params)
newprob = remake(problem, u0 = new_ic, tspan = [0.0, new_t], p = new_params)
newprob = remake(newprob, u0 = typeof(new_t).(newprob.u0))
new_soln = ModelingToolkit.solve(newprob, Tsit5(), abstol = 1e-12, reltol = 1e-12)
return (soln(new_t, idxs = [x1]))
end

function just_t(new_t)
return different_time(ic, p_true, new_t)[1]
end
display(different_time(ic, p_true, 2e-5))
display(just_t(0.5))

#display(ForwardDiff.derivative(just_t,1.0))
display(TaylorDiff.derivative(just_t, 1.0, 1)) #isnan error
#display(value_and_gradient(just_t, AutoForwardDiff(), 1.0))
#display(value_and_gradient(just_t, AutoReverseDiff(), 1.0))
#display(value_and_gradient(just_t, AutoEnzyme(Enzyme.Reverse), 1.0))
#display(value_and_gradient(just_t, AutoEnzyme(Enzyme.Forward), 1.0))
#display(value_and_gradient(just_t, AutoZygote(), 1.0))
17 changes: 11 additions & 6 deletions src/codegen.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,20 +4,24 @@ using Symbolics: @variables
using SymbolicUtils, SymbolicUtils.Code
using SymbolicUtils: Pow

dummy = (NoTangent(), 1)
@variables z
for func in (+, -, deg2rad, rad2deg,
func_list = (
+, -, deg2rad, rad2deg,
sinh, cosh, tanh,
asin, acos, atan, asec, acsc, acot,
log, log10, log1p, log2,
asinh, acosh, atanh, asech, acsch,
acoth,
abs, sign)

dummy = (NoTangent(), 1)
@variables z
for func in func_list
F = typeof(func)
# base case
@eval function (op::$F)(t::TaylorScalar{T, 2}) where {T}
t0, t1 = value(t)
TaylorScalar{T, 2}(frule((NoTangent(), t1), op, t0))
f0, f1 = frule((NoTangent(), t1), op, t0)
TaylorScalar{T, 2}(f0, zero_tangent(f0) + f1)
end
der = frule(dummy, func, z)[2]
term, raiser = der isa Pow && der.exp == -1 ? (der.base, raiseinv) : (der, raise)
Expand All @@ -28,8 +32,9 @@ for func in (+, -, deg2rad, rad2deg,
quote
$(Expr(:meta, :inline))
z = TaylorScalar{T, N - 1}(t)
df = $der_expr
$$raiser($f(value(t)[1]), df, t)
f0 = $f(value(t)[1])
df = zero_tangent(z) + $der_expr
$$raiser(f0, df, t)
end
end
end
5 changes: 5 additions & 0 deletions src/primitive.jl
Original file line number Diff line number Diff line change
Expand Up @@ -151,8 +151,13 @@ for R in (Integer, Real)
ex = :($ex; TaylorScalar($([Symbol('u', i) for i in 1:N]...)))
return :(@inbounds $ex)
end
@eval function ^(a::S, t::TaylorScalar{T, N}) where {S <: $R, T, N}
exp(t * log(a))
end
end

^(t::TaylorScalar, s::TaylorScalar) = exp(s * log(t))

@generated function raise(f::T, df::TaylorScalar{T, M},
t::TaylorScalar{T, N}) where {T, M, N} # M + 1 == N
return quote
Expand Down
19 changes: 17 additions & 2 deletions src/scalar.jl
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,21 @@ function promote_rule(::Type{TaylorScalar{T, N}},
TaylorScalar{promote_type(T, S), N}
end

function Base.AbstractFloat(x::TaylorScalar{T, N}) where {T, N}
TaylorScalar{Float64, N}(convert(NTuple{N, Float64}, x.value))
function (::Type{F})(x::TaylorScalar{T, N}) where {T, N, F <: AbstractFloat}
F(primal(x))
end

function Base.nextfloat(x::TaylorScalar{T, N}) where {T, N}
TaylorScalar{T, N}(ntuple(i -> i == 1 ? nextfloat(value(x)[i]) : value(x)[i], N))
end

function Base.prevfloat(x::TaylorScalar{T, N}) where {T, N}
TaylorScalar{T, N}(ntuple(i -> i == 1 ? prevfloat(value(x)[i]) : value(x)[i], N))
end

const UNARY_PREDICATES = Symbol[
:isinf, :isnan, :isfinite, :iseven, :isodd, :isreal, :isinteger]

for pred in UNARY_PREDICATES
@eval Base.$(pred)(x::TaylorScalar) = $(pred)(primal(x))
end
2 changes: 1 addition & 1 deletion test/primitive.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ using FiniteDifferences

@testset "No derivative or linear" begin
some_number, another_number = 1.9, 2.6
for f in (+, -, zero, one, adjoint, conj, deg2rad, rad2deg), order in (2,)
for f in (+, -, zero, one, adjoint, conj, deg2rad, rad2deg, abs, sign), order in (2,)
@test derivative(f, some_number, order) 0.0
end
for f in (+, -, <, <=, >, >=, ==), order in (2,)
Expand Down

0 comments on commit 9e4b171

Please sign in to comment.