Skip to content

Commit

Permalink
Refactor SpecialFunctionsExt
Browse files Browse the repository at this point in the history
  • Loading branch information
tansongchen committed Sep 27, 2024
1 parent c90fc69 commit 8f8e5d9
Show file tree
Hide file tree
Showing 4 changed files with 36 additions and 59 deletions.
28 changes: 1 addition & 27 deletions ext/TaylorDiffSFExt.jl
Original file line number Diff line number Diff line change
@@ -1,34 +1,8 @@
module TaylorDiffSFExt
using TaylorDiff, SpecialFunctions
using Symbolics: @variables
using SymbolicUtils, SymbolicUtils.Code
using SymbolicUtils: Pow
using TaylorDiff: value, raise
using ChainRules, ChainRulesCore

dummy = (NoTangent(), 1)
@variables z
# logerfc, logerfcx, erfinv, gamma, digamma, trigamma
for func in (erf, erfc, erfcinv, erfcx, erfi)
F = typeof(func)
# base case
@eval function (op::$F)(t::TaylorScalar{T, 2}) where {T}
t0, t1 = value(t)
TaylorScalar{T, 2}(frule((NoTangent(), t1), op, t0))
end
der = frule(dummy, func, z)[2]
term, raiser = der isa Pow && der.exp == -1 ? (der.base, raiseinv) : (der, raise)
# recursion by raising
@eval @generated function (op::$F)(t::TaylorScalar{T, N}) where {T, N}
der_expr = $(QuoteNode(toexpr(term)))
f = $func
quote
$(Expr(:meta, :inline))
z = TaylorScalar{T, N - 1}(t)
df = $der_expr
$$raiser($f(value(t)[1]), df, t)
end
end
TaylorDiff.define_unary_function(func, TaylorDiffSFExt)
end

end
1 change: 1 addition & 0 deletions src/TaylorDiff.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ module TaylorDiff

include("scalar.jl")
include("primitive.jl")
include("utils.jl")
include("codegen.jl")
include("derivative.jl")
include("chainrules.jl")
Expand Down
34 changes: 2 additions & 32 deletions src/codegen.jl
Original file line number Diff line number Diff line change
@@ -1,40 +1,10 @@
using ChainRules
using ChainRulesCore
using Symbolics: @variables
using SymbolicUtils, SymbolicUtils.Code
using SymbolicUtils: Pow

func_list = (
for unary_func in (
+, -, deg2rad, rad2deg,
sinh, cosh, tanh,
asin, acos, atan, asec, acsc, acot,
log, log10, log1p, log2,
asinh, acosh, atanh, asech, acsch,
acoth,
abs, sign)

dummy = (NoTangent(), 1)
@variables z
for func in func_list
F = typeof(func)
# base case
@eval function (op::$F)(t::TaylorScalar{T, 2}) where {T}
t0, t1 = value(t)
f0, f1 = frule((NoTangent(), t1), op, t0)
TaylorScalar{T, 2}(f0, zero_tangent(f0) + f1)
end
der = frule(dummy, func, z)[2]
term, raiser = der isa Pow && der.exp == -1 ? (der.base, raiseinv) : (der, raise)
# recursion by raising
@eval @generated function (op::$F)(t::TaylorScalar{T, N}) where {T, N}
der_expr = $(QuoteNode(toexpr(term)))
f = $func
quote
$(Expr(:meta, :inline))
z = TaylorScalar{T, N - 1}(t)
f0 = $f(value(t)[1])
df = zero_tangent(z) + $der_expr
$$raiser(f0, df, t)
end
end
define_unary_function(unary_func, TaylorDiff)
end
32 changes: 32 additions & 0 deletions src/utils.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
using ChainRules
using ChainRulesCore
using Symbolics: @variables
using SymbolicUtils, SymbolicUtils.Code
using SymbolicUtils: Pow

dummy = (NoTangent(), 1)
@variables z

function define_unary_function(func, m)
F = typeof(func)
# base case
@eval m function (op::$F)(t::TaylorScalar{T, 2}) where {T}
t0, t1 = value(t)
f0, f1 = frule((NoTangent(), t1), op, t0)
TaylorScalar{T, 2}(f0, zero_tangent(f0) + f1)
end
der = frule(dummy, func, z)[2]
term, raiser = der isa Pow && der.exp == -1 ? (der.base, raiseinv) : (der, raise)
# recursion by raising
@eval m @generated function (op::$F)(t::TaylorScalar{T, N}) where {T, N}
der_expr = $(QuoteNode(toexpr(term)))
f = $func
quote
$(Expr(:meta, :inline))
z = TaylorScalar{T, N - 1}(t)
f0 = $f(value(t)[1])
df = zero_tangent(z) + $der_expr
$$raiser(f0, df, t)
end
end
end

0 comments on commit 8f8e5d9

Please sign in to comment.