From e690242474a5360f74e6c575bcd8d2f40ba79888 Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Thu, 10 Oct 2024 11:54:24 +0200 Subject: [PATCH 1/4] Speed up Jacobian and Hessian with mapreduce --- DifferentiationInterface/src/first_order/jacobian.jl | 8 ++------ DifferentiationInterface/src/second_order/hessian.jl | 7 +------ 2 files changed, 3 insertions(+), 12 deletions(-) diff --git a/DifferentiationInterface/src/first_order/jacobian.jl b/DifferentiationInterface/src/first_order/jacobian.jl index 3984dd49f..be4d37072 100644 --- a/DifferentiationInterface/src/first_order/jacobian.jl +++ b/DifferentiationInterface/src/first_order/jacobian.jl @@ -232,7 +232,7 @@ function _jacobian_aux( f_or_f!y..., pushforward_prep, backend, x, batched_seeds[1], contexts... ) - jac_blocks = map(eachindex(batched_seeds)) do a + jac = mapreduce(hcat, eachindex(batched_seeds)) do a dy_batch = pushforward( f_or_f!y..., pushforward_prep_same, @@ -247,8 +247,6 @@ function _jacobian_aux( end block end - - jac = reduce(hcat, jac_blocks) return jac end @@ -265,7 +263,7 @@ function _jacobian_aux( f_or_f!y..., prep.pullback_prep, backend, x, batched_seeds[1], contexts... ) - jac_blocks = map(eachindex(batched_seeds)) do a + jac = mapreduce(vcat, eachindex(batched_seeds)) do a dx_batch = pullback( f_or_f!y..., pullback_prep_same, backend, x, batched_seeds[a], contexts... ) @@ -275,8 +273,6 @@ function _jacobian_aux( end block end - - jac = reduce(vcat, jac_blocks) return jac end diff --git a/DifferentiationInterface/src/second_order/hessian.jl b/DifferentiationInterface/src/second_order/hessian.jl index f8e612fd5..8b12db6f4 100644 --- a/DifferentiationInterface/src/second_order/hessian.jl +++ b/DifferentiationInterface/src/second_order/hessian.jl @@ -111,7 +111,7 @@ function hessian( f, hvp_prep, backend, x, batched_seeds[1], contexts... ) - hess_blocks = map(eachindex(batched_seeds)) do a + hess = mapreduce(hcat, eachindex(batched_seeds)) do a dg_batch = hvp(f, hvp_prep_same, backend, x, batched_seeds[a], contexts...) block = stack(vec, dg_batch; dims=2) if N % B != 0 && a == lastindex(batched_seeds) @@ -119,11 +119,6 @@ function hessian( end block end - - hess = reduce(hcat, hess_blocks) - if N < size(hess, 2) - hess = hess[:, 1:N] - end return hess end From 21a28f62b8f577e6f4c54674ae150f4f48581a3c Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Thu, 10 Oct 2024 11:56:07 +0200 Subject: [PATCH 2/4] Bump version --- DifferentiationInterface/Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/DifferentiationInterface/Project.toml b/DifferentiationInterface/Project.toml index ce7502c68..41d30a3ab 100644 --- a/DifferentiationInterface/Project.toml +++ b/DifferentiationInterface/Project.toml @@ -1,7 +1,7 @@ name = "DifferentiationInterface" uuid = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63" authors = ["Guillaume Dalle", "Adrian Hill"] -version = "0.6.9" +version = "0.6.10" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" From a033d337e47d69053ce0b44f68e29740e57b3248 Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Thu, 10 Oct 2024 13:11:58 +0200 Subject: [PATCH 3/4] Revert --- DifferentiationInterface/src/first_order/jacobian.jl | 8 ++++++-- DifferentiationInterface/src/second_order/hessian.jl | 4 +++- 2 files changed, 9 insertions(+), 3 deletions(-) diff --git a/DifferentiationInterface/src/first_order/jacobian.jl b/DifferentiationInterface/src/first_order/jacobian.jl index be4d37072..3984dd49f 100644 --- a/DifferentiationInterface/src/first_order/jacobian.jl +++ b/DifferentiationInterface/src/first_order/jacobian.jl @@ -232,7 +232,7 @@ function _jacobian_aux( f_or_f!y..., pushforward_prep, backend, x, batched_seeds[1], contexts... ) - jac = mapreduce(hcat, eachindex(batched_seeds)) do a + jac_blocks = map(eachindex(batched_seeds)) do a dy_batch = pushforward( f_or_f!y..., pushforward_prep_same, @@ -247,6 +247,8 @@ function _jacobian_aux( end block end + + jac = reduce(hcat, jac_blocks) return jac end @@ -263,7 +265,7 @@ function _jacobian_aux( f_or_f!y..., prep.pullback_prep, backend, x, batched_seeds[1], contexts... ) - jac = mapreduce(vcat, eachindex(batched_seeds)) do a + jac_blocks = map(eachindex(batched_seeds)) do a dx_batch = pullback( f_or_f!y..., pullback_prep_same, backend, x, batched_seeds[a], contexts... ) @@ -273,6 +275,8 @@ function _jacobian_aux( end block end + + jac = reduce(vcat, jac_blocks) return jac end diff --git a/DifferentiationInterface/src/second_order/hessian.jl b/DifferentiationInterface/src/second_order/hessian.jl index 8b12db6f4..a40c2efd5 100644 --- a/DifferentiationInterface/src/second_order/hessian.jl +++ b/DifferentiationInterface/src/second_order/hessian.jl @@ -111,7 +111,7 @@ function hessian( f, hvp_prep, backend, x, batched_seeds[1], contexts... ) - hess = mapreduce(hcat, eachindex(batched_seeds)) do a + hess_blocks = map(eachindex(batched_seeds)) do a dg_batch = hvp(f, hvp_prep_same, backend, x, batched_seeds[a], contexts...) block = stack(vec, dg_batch; dims=2) if N % B != 0 && a == lastindex(batched_seeds) @@ -119,6 +119,8 @@ function hessian( end block end + + hess = reduce(hcat, hess_blocks) return hess end From a9d3c040919b79c1c9ff86ba7317ef50d84468e3 Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Thu, 10 Oct 2024 13:21:34 +0200 Subject: [PATCH 4/4] Unbump --- DifferentiationInterface/Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/DifferentiationInterface/Project.toml b/DifferentiationInterface/Project.toml index 41d30a3ab..ce7502c68 100644 --- a/DifferentiationInterface/Project.toml +++ b/DifferentiationInterface/Project.toml @@ -1,7 +1,7 @@ name = "DifferentiationInterface" uuid = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63" authors = ["Guillaume Dalle", "Adrian Hill"] -version = "0.6.10" +version = "0.6.9" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"