From bb18173a054736492b329b54d245734f48bf548c Mon Sep 17 00:00:00 2001 From: Christian Date: Mon, 29 Aug 2022 14:10:23 -0300 Subject: [PATCH 1/2] Make params non_differentiable --- src/functor.jl | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/functor.jl b/src/functor.jl index d05489104f..dc9ac4113c 100644 --- a/src/functor.jl +++ b/src/functor.jl @@ -88,6 +88,9 @@ function params(m...) return ps end +# Allows caching of the parameters when params is called within gradient() +@non_differentiable params(m...) + struct FluxCUDAAdaptor end adapt_storage(to::FluxCUDAAdaptor, x) = CUDA.cu(x) adapt_storage(to::FluxCUDAAdaptor, x::Zygote.FillArrays.AbstractFill) = CUDA.cu(collect(x)) From f5793b5a18eb156f897210268bd414a2b8df5fc3 Mon Sep 17 00:00:00 2001 From: Christian Date: Mon, 29 Aug 2022 21:19:02 -0300 Subject: [PATCH 2/2] Add mention of issue #2040 --- src/functor.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/functor.jl b/src/functor.jl index dc9ac4113c..bfa075a6b8 100644 --- a/src/functor.jl +++ b/src/functor.jl @@ -88,7 +88,7 @@ function params(m...) return ps end -# Allows caching of the parameters when params is called within gradient() +# Allows caching of the parameters when params is called within gradient() to fix #2040. @non_differentiable params(m...) struct FluxCUDAAdaptor end