-
-
Notifications
You must be signed in to change notification settings - Fork 5
/
chain.jl
62 lines (51 loc) · 2.16 KB
/
chain.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
import Flux: ChainRulesCore
"""
    apply(chain::Flux.Chain, x) -> (new_chain, y)

Run `x` through every layer of `chain` and return both the final output `y` and a new
`Flux.Chain` built from the layer values that `_apply` hands back for each layer.

Experimental: part of an effort to remove the need for `Recur` to be mutable, following the
discussion in the recurrent-network rework issue. Unlike Flux's `_applychain`, which returns
only the output, this also returns a rebuilt chain carrying any modified internal state.
"""
function apply(chain::Flux.Chain, x)
    new_layers, y = _apply(chain.layers, x)
    return Flux.Chain(new_layers), y
end
# NamedTuple of layers: delegate to the positional (Tuple) method, then re-attach the original
# field names to the rebuilt layers. The output value is passed through unchanged.
function _apply(layers::NamedTuple{names}, x) where {names}
    rebuilt, y = _apply(values(layers), x)
    return NamedTuple{names}(rebuilt), y
end
"""
    _scan(layers::AbstractVector, x) -> (new_layers, y)

Thread `x` through `layers` in order, calling `_apply` on each element and feeding each
output into the next element. Returns a vector of the layer values returned by each `_apply`
call, together with the final output `y`. An empty `layers` yields an empty vector and `x`
unchanged.
"""
function _scan(layers::AbstractVector, x)
    # `similar` works for any AbstractVector (views, ranges, offset arrays, ...), whereas
    # `typeof(layers)(undef, n)` only exists for concrete containers such as `Vector`.
    new_layers = similar(layers)
    # `eachindex` keeps the write index aligned with `layers`' own axes.
    for i in eachindex(layers)
        new_layers[i], x = _apply(layers[i], x)
    end
    return new_layers, x
end
# Reverse-mode rule for `_scan`, adapted from
# https://github.com/mcabbott/Flux.jl/blob/chain_rrule/src/cuda/cuda.jl
#
# The forward pass records one pullback per layer via `rrule_via_ad` on `_apply`. Each
# `_apply` call returns a pair `(new_layer, output)`, so its pullback needs a cotangent for
# BOTH halves:
#   * the entry of `dy[1]` belonging to that layer's rebuilt value, and
#   * the cotangent flowing back from the next layer's input (or `dy[2]` for the last layer).
# Its pullback returns `(d_apply, d_layer, d_input)`. `d_layer` is that layer's gradient, and
# `d_input` becomes the output cotangent of the previous layer.
function ChainRulesCore.rrule(cfg::ChainRulesCore.RuleConfig, ::typeof(_scan), layers, x)
    outs = similar(layers, Any)   # `(new_layer, output)` pair per layer
    backs = similar(layers, Any)  # pullback per layer
    h = x
    for i in eachindex(layers)
        outs[i], backs[i] = ChainRulesCore.rrule_via_ad(cfg, _apply, layers[i], h)
        h = last(outs[i])
    end
    new_layers = map(first, outs)

    function _scan_pullback(dy)
        dy = ChainRulesCore.unthunk(dy)
        # `dy` is the cotangent of `(new_layers, output)`; the whole thing may be a zero.
        dnew, delta = dy isa ChainRulesCore.AbstractZero ? (dy, dy) : (dy[1], dy[2])
        dlayers = similar(layers, Any)
        for i in reverse(eachindex(layers))
            dnew_i = dnew isa ChainRulesCore.AbstractZero ? dnew : dnew[i]
            _, dlayers[i], delta = backs[i]((dnew_i, delta))
        end
        # After the reverse sweep, `delta` is the cotangent with respect to the input `x`.
        return (ChainRulesCore.NoTangent(), dlayers, delta)
    end

    # Empty `layers` falls through naturally: no layers, output `x`, and `dx = dy[2]`.
    return (new_layers, h), _scan_pullback
end
# Vector of layers: forward to `_scan`. This path is not type-stable (unlike the unrolled
# Tuple method), trading that for shorter compile times.
_apply(layers::AbstractVector, x) = _scan(layers, x)
"""
    _apply(layers::Tuple, x) -> (new_layers, y)

Thread `x` through each element of `layers` in order, calling `_apply` on every element and
feeding each output into the next. Returns a tuple of the layer values returned by each
`_apply` call, together with the final output `y`. An empty tuple returns `((), x)`.

Generated so that the per-layer calls are unrolled at compile time for each tuple length
`N`. This is presumably meant to keep inference working over heterogeneous layer tuples.
"""
@generated function _apply(layers::Tuple{Vararg{Any,N}}, x) where {N}
    # `Vararg{Any,N}` rather than `Vararg{<:Any,N}`: wrapping `Vararg` in a UnionAll is
    # deprecated since Julia 1.7, and Tuple covariance makes the two forms equivalent.
    xs = vcat(:x, [gensym(:x) for _ in 1:N])
    ls = [gensym(:layer) for _ in 1:N]
    calls = [:(($(ls[i]), $(xs[i + 1])) = _apply(layers[$i], $(xs[i]))) for i in 1:N]
    return Expr(:block, calls..., :(return tuple($(ls...)), $(xs[end])))
end
# Fallback for a single callable layer: call it on `x` and hand the layer back unchanged
# alongside its output.
function _apply(layer, x)
    y = layer(x)
    return (layer, y)
end