Skip to content

Commit

Permalink
Faster Nested AD
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Apr 24, 2024
1 parent 29e971e commit bb61c5f
Show file tree
Hide file tree
Showing 3 changed files with 81 additions and 10 deletions.
38 changes: 34 additions & 4 deletions Manifest.toml
Original file line number Diff line number Diff line change
Expand Up @@ -497,7 +497,7 @@ uuid = "56ddb016-857b-54e1-b83d-db4d58db5568"

[[deps.Lux]]
deps = ["ADTypes", "Adapt", "ArrayInterface", "ChainRulesCore", "ConcreteStructs", "ConstructionBase", "FastClosures", "Functors", "GPUArraysCore", "LinearAlgebra", "LuxCore", "LuxDeviceUtils", "LuxLib", "MacroTools", "Markdown", "PrecompileTools", "Preferences", "Random", "Reexport", "Setfield", "Statistics", "WeightInitializers"]
git-tree-sha1 = "d7f49df9abfbb372fcbde5f41e547aa3679e9793"
git-tree-sha1 = "295c76513705518749fd4e151d9de77c75049d43"
repo-rev = "ap/nested_ad"
repo-url = "https://github.com/LuxDL/Lux.jl.git"
uuid = "b2108857-7c20-44ae-9111-449ecde12c47"
Expand Down Expand Up @@ -573,12 +573,13 @@ version = "0.1.20"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[[deps.LuxLib]]
deps = ["ChainRulesCore", "FastClosures", "KernelAbstractions", "LuxCore", "Markdown", "NNlib", "PrecompileTools", "Random", "Reexport", "Statistics"]
git-tree-sha1 = "b1f81a8aa8313c1f1b4cbfb18733db17c023427e"
deps = ["ArrayInterface", "ChainRulesCore", "FastBroadcast", "FastClosures", "GPUArraysCore", "KernelAbstractions", "LinearAlgebra", "LuxCore", "Markdown", "NNlib", "PrecompileTools", "Random", "Reexport", "Statistics", "Strided"]
git-tree-sha1 = "7cb3cdf01835d508f2c81e09d2e93f309434b5d6"
uuid = "82251201-b29d-42c6-8e01-566dec8acb11"
version = "0.3.14"
version = "0.3.15"

[deps.LuxLib.extensions]
LuxLibAMDGPUExt = "AMDGPU"
LuxLibForwardDiffExt = "ForwardDiff"
LuxLibReverseDiffExt = "ReverseDiff"
LuxLibTrackerAMDGPUExt = ["AMDGPU", "Tracker"]
Expand Down Expand Up @@ -684,6 +685,12 @@ git-tree-sha1 = "dfdf5519f235516220579f949664f1bf44e741c5"
uuid = "bac558e1-5e72-5ebc-8fee-abe8a469f55d"
version = "1.6.3"

[[deps.PackageExtensionCompat]]
git-tree-sha1 = "fb28e33b8a95c4cee25ce296c817d89cc2e53518"
uuid = "65ce6f38-6b18-4e1d-a461-8949797d7930"
version = "1.0.2"
weakdeps = ["Requires", "TOML"]

[[deps.Parameters]]
deps = ["OrderedCollections", "UnPack"]
git-tree-sha1 = "34c0e9ad262e5f7fc75b10a9952ca7692cfc5fbe"
Expand Down Expand Up @@ -927,6 +934,24 @@ git-tree-sha1 = "25349bf8f63aa36acbff5e3550a86e9f5b0ef682"
uuid = "7792a7ef-975c-4747-a70f-980b88e8d1da"
version = "0.5.6"

[[deps.Strided]]
deps = ["LinearAlgebra", "StridedViews", "TupleTools"]
git-tree-sha1 = "40c69be0e1b72ee2f42923b7d1ff13e0b04e675c"
uuid = "5e0ebb24-38b0-5f93-81fe-25c709ecae67"
version = "2.0.4"

[[deps.StridedViews]]
deps = ["LinearAlgebra", "PackageExtensionCompat"]
git-tree-sha1 = "5b765c4e401693ab08981989f74a36a010aa1d8e"
uuid = "4db3bf67-4bd7-4b4e-b153-31dc3fb37143"
version = "0.2.2"

[deps.StridedViews.extensions]
StridedViewsCUDAExt = "CUDA"

[deps.StridedViews.weakdeps]
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"

[[deps.SuiteSparse]]
deps = ["Libdl", "LinearAlgebra", "Serialization", "SparseArrays"]
uuid = "4607b0f0-06f3-5cda-b6b1-a6196a1729e9"
Expand Down Expand Up @@ -985,6 +1010,11 @@ git-tree-sha1 = "ea3e54c2bdde39062abf5a9758a23735558705e1"
uuid = "781d530d-4396-4725-bb49-402e4bee1e77"
version = "1.4.0"

[[deps.TupleTools]]
git-tree-sha1 = "41d61b1c545b06279871ef1a4b5fcb2cac2191cd"
uuid = "9d95972d-f1c8-5527-a6e0-b4b365fa01f6"
version = "1.5.0"

[[deps.UUIDs]]
deps = ["Random", "SHA"]
uuid = "cf7118a7-6976-5b1a-9a39-7adc72f591a4"
Expand Down
4 changes: 3 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,14 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
SteadyStateDiffEq = "9672c7b4-1e72-59bd-8a11-6ac3964bc41f"

[weakdeps]
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
LinearSolve = "7ed4a6bd-45f5-4d41-b270-4a48e9bafcae"
SciMLSensitivity = "1ed8b502-d754-442c-8d5d-10ac956f44a1"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[extensions]
DeepEquilibriumNetworksSciMLSensitivityExt = ["LinearSolve", "SciMLSensitivity"]
DeepEquilibriumNetworksZygoteExt = "Zygote"
DeepEquilibriumNetworksZygoteExt = ["ForwardDiff", "Zygote"]

[compat]
ADTypes = "0.2.5, 1"
Expand All @@ -38,6 +39,7 @@ ConstructionBase = "1"
DiffEqBase = "6.119"
ExplicitImports = "1.4.1"
FastClosures = "0.3"
ForwardDiff = "0.10.36"
Functors = "0.4.10"
LinearSolve = "2.21.2"
Lux = "0.5.37"
Expand Down
49 changes: 44 additions & 5 deletions ext/DeepEquilibriumNetworksZygoteExt.jl
Original file line number Diff line number Diff line change
@@ -1,19 +1,58 @@
module DeepEquilibriumNetworksZygoteExt

using ADTypes: AutoZygote
using ChainRulesCore: ChainRulesCore
using DeepEquilibriumNetworks: DEQs
using FastClosures: @closure
using ForwardDiff: ForwardDiff # This is a dependency of Zygote
using Lux: Lux, StatefulLuxLayer
using Statistics: mean
using Zygote: Zygote
using DeepEquilibriumNetworks: DEQs

const CRC = ChainRulesCore

# `__tupleify(x)` captures `x` and returns a closure mapping `u` to the tuple `(u, x)`.
# Only this single definition is kept: a second one-argument method with the same
# signature would overwrite it, and only the last definition would take effect anyway.
@inline __tupleify(x) = @closure(u -> (u, x))

## One day we will overload DI's APIs for Lux Layers and we can remove this
## Main challenge with overloading Zygote.pullback is that we need to return the correct
## tangent for the pullback to compute the correct gradient, which is quite hard. But
## wrapping the overall vjp is not that hard.
# Vector-Jacobian product of `u -> model((u, x))` evaluated at `u = z`, contracted with
# the array returned by `DEQs.__gaussian_like(rng, res)` (presumably a random Gaussian
# array shaped like `res` -- confirm against its definition in DeepEquilibriumNetworks).
#
# `ps` is not read in the primal computation; it is part of the signature so that the
# `CRC.rrule` defined for this function can differentiate with respect to it.
@inline function __compute_vector_jacobian_product(model::StatefulLuxLayer, ps, z, x, rng)
    # `model ∘ __tupleify(x)` is `u -> model((u, x))`, so the pullback is w.r.t. `z` only.
    res, back = Zygote.pullback(model ∘ __tupleify(x), z)
    # The pullback returns a 1-tuple (one differentiated argument); `only` unwraps it.
    return only(back(DEQs.__gaussian_like(rng, res)))
end

# Reverse rule for `__compute_vector_jacobian_product`.
#
# The primal is itself a Zygote pullback, so differentiating through it again would need
# reverse-over-reverse AD. Instead the pullback below runs forward-over-reverse: `z` is
# lifted to ForwardDiff duals seeded with the incoming cotangent `Δ`, a single Zygote
# pullback is evaluated on the duals, and the partials of the resulting gradients are
# returned as the cotangents for `ps`, `z` and `x`.
#
# Returns `(y, pullback)`, where `y` equals the primal output and `pullback(Δ)` yields six
# tangents, ordered as (function, `model`, `ps`, `z`, `x`, `rng`).
function CRC.rrule(::typeof(__compute_vector_jacobian_product),
        model::StatefulLuxLayer, ps, z, x, rng)
    res, back = Zygote.pullback(model ∘ __tupleify(x), z)
    ε = DEQs.__gaussian_like(rng, res)
    y = only(back(ε))
    ∇internal_gradient_capture = Δ -> begin
        # Nothing to propagate: one `NoTangent` per rrule argument (function + 5 inputs).
        (Δ isa CRC.NoTangent || Δ isa CRC.ZeroTangent) &&
            return ntuple(Returns(CRC.NoTangent()), 6)

        # The cotangent of `y` must be laid out like `z` to seed the duals elementwise.
        Δ_ = reshape(CRC.unthunk(Δ), size(z))

        # One-partial duals: value `z`, perturbation direction `Δ_`.
        Tag = typeof(ForwardDiff.Tag(model, eltype(z)))
        partials = ForwardDiff.Partials{1, eltype(z)}.(tuple.(Δ_))
        z_dual = ForwardDiff.Dual{Tag, eltype(z), 1}.(z, partials)

        # Reverse pass over the dual-valued input, reusing the same `ε` as the primal.
        _, pb_f = Zygote.pullback((x1, x2, p) -> model((x1, x2), p), z_dual, x, ps)
        ∂z_duals, ∂x_duals, ∂ps_duals = pb_f(ε)

        # Extract the first (only) partial of each gradient.
        ∂z = Lux.__partials(Tag, ∂z_duals, 1)
        ∂x = Lux.__partials(Tag, ∂x_duals, 1)
        ∂ps = Lux.__partials(Tag, ∂ps_duals, 1)

        return CRC.NoTangent(), CRC.NoTangent(), ∂ps, ∂z, ∂x, CRC.NoTangent()
    end
    return y, ∇internal_gradient_capture
end

## Don't remove `ad`. See https://github.com/ericphanson/ExplicitImports.jl/issues/33
## FIXME: This will be broken in the new Lux release let's fix this
# Estimate used for the Jacobian trace with the Zygote backend: the mean of the squared
# entries of the vector-Jacobian product returned by `__compute_vector_jacobian_product`.
#
# `ad` is used only for dispatch. The VJP goes through
# `__compute_vector_jacobian_product` (rather than an inline `Zygote.pullback`) so that
# the custom `CRC.rrule` defined for it is picked up when this estimate is differentiated.
# `model.ps` is passed explicitly so that rule can produce a parameter gradient.
function DEQs.__estimate_jacobian_trace(ad::AutoZygote, model, z, x, rng)
    return mean(abs2, __compute_vector_jacobian_product(model, model.ps, z, x, rng))
end

end

0 comments on commit bb61c5f

Please sign in to comment.