Skip to content

Commit

Permalink
fix DI
Browse files Browse the repository at this point in the history
  • Loading branch information
prbzrg committed May 21, 2024
1 parent ea9a94b commit 4fed861
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,13 @@
f::Lux.StatefulLuxLayer,
xs::AbstractMatrix{<:Real},
) where {T}
y, VJ = DifferentiationInterface.value_and_pullback_split(f, icnf.autodiff_backend, xs)
y = f(xs)

Check warning on line 6 in src/utils.jl

View check run for this annotation

Codecov / codecov/patch

src/utils.jl#L6

Added line #L6 was not covered by tests
z = similar(xs)
ChainRulesCore.@ignore_derivatives fill!(z, zero(T))
res = Zygote.Buffer(xs, size(xs, 1), size(xs, 1), size(xs, 2))
for i in axes(xs, 1)
ChainRulesCore.@ignore_derivatives z[i, :] .= one(T)
res[i, :, :] = VJ(z)
res[i, :, :] = DifferentiationInterface.pullback(f, icnf.autodiff_backend, xs, z)

Check warning on line 12 in src/utils.jl

View check run for this annotation

Codecov / codecov/patch

src/utils.jl#L12

Added line #L12 was not covered by tests
ChainRulesCore.@ignore_derivatives z[i, :] .= zero(T)
end
y, eachslice(copy(res); dims = 3)
Expand Down

0 comments on commit 4fed861

Please sign in to comment.