Skip to content

Commit

Permalink
Accumulate NamedTuple + Tangent (#88)
Browse files Browse the repository at this point in the history
* accumulate NamedTuple + Tangent

* fixup

* don't test on 1.8
  • Loading branch information
mcabbott authored Sep 4, 2022
1 parent 9a8a788 commit fb1d4ec
Show file tree
Hide file tree
Showing 4 changed files with 24 additions and 13 deletions.
4 changes: 3 additions & 1 deletion .github/workflows/CI.yml
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,9 @@ jobs:
matrix:
version:
- '1.7' # Lowest claimed support in Project.toml
- '1' # Latest Release
# - '1' # Latest Release # Testing on 1.8 gives this message:
# ┌ Warning: ir verification broken. Either use 1.9 or 1.7
# └ @ Diffractor ~/work/Diffractor.jl/Diffractor.jl/src/stage1/recurse.jl:889
- 'nightly'
os:
- ubuntu-latest
Expand Down
3 changes: 3 additions & 0 deletions src/extra_rules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -266,3 +266,6 @@ end
function ChainRulesCore.rrule(::DiffractorRuleConfig, ::Type{InplaceableThunk}, add!!, val)
val, Δ->(NoTangent(), NoTangent(), Δ)
end

Base.real(z::ZeroTangent) = z # TODO should be in CRC
Base.real(z::NoTangent) = z
20 changes: 17 additions & 3 deletions src/runtime.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,25 @@ struct DiffractorRuleConfig <: RuleConfig{Union{HasReverseMode,HasForwardsMode}}
@Base.constprop :aggressive accum(a::Tuple, b::Tuple) = map(accum, a, b)
@Base.constprop :aggressive @generated function accum(x::NamedTuple, y::NamedTuple)
fnames = union(fieldnames(x), fieldnames(y))
isempty(fnames) && return :((;)) # code below makes () instead
gradx(f) = f in fieldnames(x) ? :(getfield(x, $(quot(f)))) : :(ZeroTangent())
grady(f) = f in fieldnames(y) ? :(getfield(y, $(quot(f)))) : :(ZeroTangent())
Expr(:tuple, [:($f=accum($(gradx(f)), $(grady(f)))) for f in fnames]...)
end
@Base.constprop :aggressive accum(a, b, c, args...) = accum(accum(a, b), c, args...)
@Base.constprop :aggressive accum(a::NoTangent, b) = b
@Base.constprop :aggressive accum(a, b::NoTangent) = a
@Base.constprop :aggressive accum(a::NoTangent, b::NoTangent) = NoTangent()
@Base.constprop :aggressive accum(a::AbstractZero, b) = b
@Base.constprop :aggressive accum(a, b::AbstractZero) = a
@Base.constprop :aggressive accum(a::AbstractZero, b::AbstractZero) = NoTangent()

using ChainRulesCore: Tangent, backing

function accum(x::Tangent{T}, y::NamedTuple) where T
# @warn "gradient is both a Tangent and a NamedTuple" x y
_tangent(T, accum(backing(x), y))
end
accum(x::NamedTuple, y::Tangent) = accum(y, x)
# This solves an ambiguity, but also avoids Tangent{ZeroTangent}() which + does not:
accum(x::Tangent{T}, y::Tangent) where T = _tangent(T, accum(backing(x), backing(y)))

_tangent(::Type{T}, z) where T = Tangent{T,typeof(z)}(z)
_tangent(::Type, ::NamedTuple{()}) = NoTangent()
10 changes: 1 addition & 9 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -162,15 +162,7 @@ end
# Make sure that there's no infinite recursion in kwarg calls
g_kw(;x=1.0) = sin(x)
f_kw(x) = g_kw(;x)
@test bwd(f_kw)(1.0) == bwd(sin)(1.0) broken=true
#=
MethodError: no method matching +(::Tangent{var"#g_kw#47"{var"#g_kw#11#48"}, NamedTuple{(Symbol("#g_kw#11"),), Tuple{ZeroTangent}}}, ::Tangent{Diffractor.KwFunc{var"#g_kw#47"{var"#g_kw#11#48"}, var"#g_kw#47##kw"}, NamedTuple{(:kwf,), Tuple{ZeroTangent}}})
...
[2] elementwise_add(a::NamedTuple{(:contents,), Tuple{Tangent{var"#g_kw#47"{var"#g_kw#11#48"}, NamedTuple{(Symbol("#g_kw#11"),), Tuple{ZeroTangent}}}}}, b::NamedTuple{(:contents,), Tuple{Tangent{Diffractor.KwFunc{var"#g_kw#47"{var"#g_kw#11#48"}, var"#g_kw#47##kw"}, NamedTuple{(:kwf,), Tuple{ZeroTangent}}}}})
@ ChainRulesCore ~/.julia/packages/ChainRulesCore/ctmSK/src/tangent_types/tangent.jl:287
[3] +(a::Tangent{Core.Box, NamedTuple{(:contents,), Tuple{Tangent{var"#g_kw#47"{var"#g_kw#11#48"}, NamedTuple{(Symbol("#g_kw#11"),), Tuple{ZeroTangent}}}}}}, b::Tangent{Core.Box, NamedTuple{(:contents,), Tuple{Tangent{Diffractor.KwFunc{var"#g_kw#47"{var"#g_kw#11#48"}, var"#g_kw#47##kw"}, NamedTuple{(:kwf,), Tuple{ZeroTangent}}}}}})
@ ChainRulesCore ~/.julia/packages/ChainRulesCore/ctmSK/src/tangent_arithmetic.jl:130
=#
@test bwd(f_kw)(1.0) == bwd(sin)(1.0)

function f_crit_edge(a, b, c, x)
# A function with two critical edges. This used to trigger an issue where
Expand Down

0 comments on commit fb1d4ec

Please sign in to comment.