This repository has been archived by the owner on Nov 4, 2024. It is now read-only.

Commit

Add frules for nested conv AD
avik-pal committed Apr 24, 2024
1 parent a6a8820 commit 84a5ca3
Showing 3 changed files with 62 additions and 47 deletions.
Project.toml (2 changes: 1 addition & 1 deletion)
@@ -1,7 +1,7 @@
name = "LuxLib"
uuid = "82251201-b29d-42c6-8e01-566dec8acb11"
authors = ["Avik Pal <[email protected]> and contributors"]
version = "0.3.16"
version = "0.3.17"

[deps]
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
ext/LuxLibForwardDiffExt.jl (54 changes: 25 additions & 29 deletions)
@@ -12,55 +12,51 @@
# Convolutions: We might want to capture these further down in `conv!`
# NOTE: In principle we can concatenate all of the partials along the batch dimension
# and cut down substantially on the time to compute jacobians.
for op in [:conv, :depthwiseconv]
# Here we should be broadcasting with `Tag` for safety but that breaks GPU compilation.
for op in [:conv, :depthwiseconv, :∇conv_data, :∇conv_filter]
op! = Symbol("$(op)!")

@eval function NNlib.$(op)(
x::AbstractArray{<:ForwardDiff.Dual{Tag, V, P}, N}, w::AbstractArray{<:Real, N},
cdims::NNlib.ConvDims; kwargs...) where {N, Tag, V, P}
x_ = ForwardDiff.value.(x)
@eval function NNlib.$(op)(x1::AbstractArray{<:ForwardDiff.Dual{Tag, V, P}, N},
x2::AbstractArray{<:Real, N}, cdims::NNlib.ConvDims;
kwargs...) where {N, Tag, V, P}
x1_data = ForwardDiff.value.(x1)

y = NNlib.$(op)(x_, w, cdims; kwargs...)
dys = ntuple(i -> NNlib.$(op)(ForwardDiff.partials.(x, i), w, cdims; kwargs...), P)
y = NNlib.$(op)(x1_data, x2, cdims; kwargs...)
dys = ntuple(
i -> NNlib.$(op)(ForwardDiff.partials.(x1, i), x2, cdims; kwargs...), P)

return map(
(yᵢ, dyᵢ...) -> ForwardDiff.Dual{Tag, V, P}(yᵢ, ForwardDiff.Partials(dyᵢ)),
y, dys...)
end

@eval function NNlib.$(op)(
x::AbstractArray{<:Real, N}, w::AbstractArray{<:ForwardDiff.Dual{Tag, V, P}, N},
@eval function NNlib.$(op)(x1::AbstractArray{<:Real, N},
x2::AbstractArray{<:ForwardDiff.Dual{Tag, V, P}, N},
cdims::NNlib.ConvDims; kwargs...) where {N, Tag, V, P}
w_ = ForwardDiff.value.(w)
x2_data = ForwardDiff.value.(x2)

y = NNlib.$(op)(x, w_, cdims; kwargs...)
dys = ntuple(i -> NNlib.$(op)(x, ForwardDiff.partials.(w, i), cdims; kwargs...), P)
y = NNlib.$(op)(x1, x2_data, cdims; kwargs...)
dys = ntuple(
i -> NNlib.$(op)(x1, ForwardDiff.partials.(x2, i), cdims; kwargs...), P)

return map(
(yᵢ, dyᵢ...) -> ForwardDiff.Dual{Tag, V, P}(yᵢ, ForwardDiff.Partials(dyᵢ)),
y, dys...)
end

@eval function NNlib.$(op)(x::AbstractArray{<:ForwardDiff.Dual{Tag, Vₓ, P}, N},
w::AbstractArray{<:ForwardDiff.Dual{Tag, Vₚ, P}, N},
@eval function NNlib.$(op)(x1::AbstractArray{<:ForwardDiff.Dual{Tag, Vₓ, P}, N},
x2::AbstractArray{<:ForwardDiff.Dual{Tag, Vₚ, P}, N},
cdims::NNlib.ConvDims; kwargs...) where {N, Tag, Vₓ, Vₚ, P}
x_ = ForwardDiff.value.(x)
w_ = ForwardDiff.value.(w)
x1_data = ForwardDiff.value.(x1)
x2_data = ForwardDiff.value.(x2)

y = NNlib.$(op)(x_, w_, cdims; kwargs...)
y = NNlib.$(op)(x1_data, x2_data, cdims; kwargs...)

dys₁ = ntuple(
_ -> similar(
x_, Vₓ, NNlib.output_size(cdims)..., NNlib.channels_out(cdims), size(x, N)),
P)
dys₂ = ntuple(
_ -> similar(
x_, Vₓ, NNlib.output_size(cdims)..., NNlib.channels_out(cdims), size(x, N)),
P)
for i in 1:P
NNlib.$(op!)(dys₁[i], ForwardDiff.partials.(x, i), w_, cdims; kwargs...)
NNlib.$(op!)(dys₂[i], x_, ForwardDiff.partials.(w, i), cdims; kwargs...)
dys₁[i] .+= dys₂[i]
dys₁ = ntuple(P) do i
dys₁ᵢ = NNlib.$(op)(ForwardDiff.partials.(x1, i), x2_data, cdims; kwargs...)
dys₂ᵢ = NNlib.$(op)(x1_data, ForwardDiff.partials.(x2, i), cdims; kwargs...)
dys₁ᵢ .+= dys₂ᵢ
return dys₁ᵢ
end

# Technically it should `promote_type(Vₓ, Vₚ)` but this causes GPU compilation
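The overloads above rely on convolution being linear in each argument: the value part and each partial of a Dual-valued array can be pushed through the same NNlib call and reassembled into Duals, which is also how the JVPs in the tests below are formed. A minimal sketch of that idea, not part of this commit, assuming LuxLib is loaded so this extension is active (the tag struct and variable names are illustrative):

using LuxLib, NNlib, ForwardDiff

struct JVPTag end                  # illustrative tag type, mirroring the test setup below

x = randn(Float32, 8, 8, 3, 2)     # input:  width x height x channels x batch
w = randn(Float32, 3, 3, 3, 4)     # kernel: kW x kH x C_in x C_out
u = randn(Float32, size(x)...)     # tangent direction for x
cdims = NNlib.DenseConvDims(x, w)

# Seed x with single-partial Duals and push them through conv.
T = typeof(ForwardDiff.Tag(JVPTag(), Float32))
xd = ForwardDiff.Dual{T, Float32, 1}.(x, ForwardDiff.Partials.(tuple.(u)))
jvp_dual = ForwardDiff.partials.(NNlib.conv(xd, w, cdims), 1)

# conv is linear in x, so the same JVP is conv applied directly to the tangent.
jvp_linear = NNlib.conv(u, w, cdims)
@assert jvp_dual ≈ jvp_linear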
test/forwarddiff_tests.jl (53 changes: 36 additions & 17 deletions)
@@ -1,40 +1,38 @@
@testitem "Efficient JVPs" tags=[:nworkers, :others] setup=[SharedTestSetup] begin
using ForwardDiff, Zygote, ComponentArrays

struct LuxLibTestTag end

# Computes (∂f/∂x)u
function jvp_forwarddiff(f, x, u)
function jvp_forwarddiff(f::F, x, u) where {F}
uu = reshape(u, axes(x))
y = ForwardDiff.Dual{typeof(ForwardDiff.Tag(LuxLibTestTag(), eltype(x))),
eltype(x), 1}.(x, ForwardDiff.Partials.(tuple.(uu)))
y = ForwardDiff.Dual{typeof(ForwardDiff.Tag(f, eltype(x))), eltype(x),
1}.(x, ForwardDiff.Partials.(tuple.(uu)))
return vec(ForwardDiff.partials.(vec(f(y)), 1))
end

function jvp_forwarddiff(f, x::ComponentArray, u)
function jvp_forwarddiff(f::F, x::ComponentArray, u) where {F}
xx = getdata(x)
uu = vec(u)
y = ComponentArray(
ForwardDiff.Dual{typeof(ForwardDiff.Tag(LuxLibTestTag(), eltype(x))),
eltype(x), 1}.(xx, ForwardDiff.Partials.(tuple.(uu))),
ForwardDiff.Dual{typeof(ForwardDiff.Tag(f, eltype(x))), eltype(x),
1}.(xx, ForwardDiff.Partials.(tuple.(uu))),
getaxes(x))
return vec(ForwardDiff.partials.(vec(f(y)), 1))
end

## This exists exclusively for testing. It has horrifying performance implications
jvp_forwarddiff_concrete(f, x, u) = ForwardDiff.jacobian(f, x) * vec(u)
jvp_zygote(f, x, u) = only(Zygote.jacobian(f, x)) * vec(u)
jvp_forwarddiff_concrete(f::F, x, u) where {F} = ForwardDiff.jacobian(f, x) * vec(u)
jvp_zygote(f::F, x, u) where {F} = only(Zygote.jacobian(f, x)) * vec(u)

function test_jvp_computation(f, x, u, on_gpu)
function test_jvp_computation(f::F, x, u, on_gpu) where {F}
jvp₁ = jvp_forwarddiff(f, x, u)
if !(x isa ComponentArray && on_gpu)
# ComponentArray + ForwardDiff on GPU don't play nice
jvp₂ = jvp_forwarddiff_concrete(f, x, u)
@test check_approx(jvp₁, jvp₂; atol=1e-5, rtol=1e-5)

jvp₃ = jvp_zygote(f, x, u)
@test check_approx(jvp₁, jvp₃; atol=1e-5, rtol=1e-5)
end

jvp₃ = jvp_zygote(f, x, u)
@test check_approx(jvp₁, jvp₃; atol=1e-5, rtol=1e-5)
end

@testset "$(mode): Jacobian Vector Products" for (mode, aType, on_gpu) in MODES
@@ -44,10 +42,10 @@
op === depthwiseconv && on_gpu && continue

input_dims = [(2, 4, 2, 1, 3), (4, 4, 1, 3), (4, 4, 3, 2), (4, 1, 3), (4, 3, 2)]
weight_dims = if op === conv
[(2, 2, 2, 1, 4), (3, 3, 1, 4), (3, 3, 3, 2), (3, 1, 4), (3, 3, 2)]
else
weight_dims = if op === depthwiseconv
[(2, 2, 2, 1, 1), (3, 3, 1, 1), (3, 3, 3, 3), (3, 1, 1), (3, 3, 3)]
else
[(2, 2, 2, 1, 4), (3, 3, 1, 4), (3, 3, 3, 2), (3, 1, 4), (3, 3, 2)]
end

@testset "Input Dims: $(in_dims) | Weight Dims: $(w_dims)" for (in_dims, w_dims) in zip(
@@ -62,6 +60,27 @@
test_jvp_computation(w -> op(x, w; flipped), w, uw, on_gpu)
test_jvp_computation(
xw -> op(xw.x, xw.w; flipped), ComponentArray(; x, w), u, on_gpu)

# Zygote.gradient here is used to test the ∇conv_data and ∇conv_filter
# functions. Also implicitly tests nested AD
test_jvp_computation(
x -> only(Zygote.gradient(w -> sum(abs2, op(x, w; flipped)), w)),
x, ux, on_gpu)
test_jvp_computation(
x -> only(Zygote.gradient(x -> sum(abs2, op(x, w; flipped)), x)),
x, ux, on_gpu)
test_jvp_computation(
w -> only(Zygote.gradient(x -> sum(abs2, op(x, w; flipped)), x)),
w, uw, on_gpu)
test_jvp_computation(
w -> only(Zygote.gradient(w -> sum(abs2, op(x, w; flipped)), w)),
w, uw, on_gpu)
test_jvp_computation(
xw -> only(Zygote.gradient(
xw -> sum(abs2, op(xw.x, xw.w; flipped)), xw)),
ComponentArray(; x, w),
u,
on_gpu)
end
end
end
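The Zygote.gradient calls added above push Dual-valued arrays through reverse-mode pullbacks, which is where the new ∇conv_data and ∇conv_filter rules are exercised: forward-over-reverse nested AD. A rough sketch of that usage pattern, not taken from the commit, assuming LuxLib is loaded and using illustrative names:

using LuxLib, NNlib, ForwardDiff, Zygote

x = randn(Float32, 8, 8, 3, 2)
w = randn(Float32, 3, 3, 3, 4)
u = randn(Float32, size(w)...)     # tangent direction for w

# Reverse-mode gradient of a conv loss with respect to the kernel.
grad_w(w) = only(Zygote.gradient(w -> sum(abs2, NNlib.conv(x, w)), w))

# Seeding w with single-partial Duals makes the pullback evaluate ∇conv_filter
# and ∇conv_data on Dual arrays, hitting the new rules. The result is a
# Hessian-vector product: the directional derivative of the gradient along u.
T = typeof(ForwardDiff.Tag(grad_w, Float32))
wd = ForwardDiff.Dual{T, Float32, 1}.(w, ForwardDiff.Partials.(tuple.(u)))
hvp = ForwardDiff.partials.(grad_w(wd), 1)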
