Skip to content

Commit

Permalink
Merge pull request #104 from omlins/ad
Browse files Browse the repository at this point in the history
Fix AD for threads
  • Loading branch information
omlins authored Jul 19, 2023
2 parents 6bb0e04 + fb533d2 commit 81ddca9
Show file tree
Hide file tree
Showing 4 changed files with 8 additions and 1 deletion.
3 changes: 3 additions & 0 deletions src/AD.jl
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,9 @@ Provides GPU-compatible wrappers for automatic differentiation functions of the
main()
!!! note "Enzyme runtime activity default"
If ParallelStencil is initialized with Threads, then `Enzyme.API.runtimeActivity!(true)` is called to ensure correct behavior of Enzyme. If you want to disable this behavior, then call `Enzyme.API.runtimeActivity!(false)` after loading ParallelStencil.
To see a description of a function type `?<functionname>`.
"""
module AD
Expand Down
3 changes: 3 additions & 0 deletions src/ParallelKernel/AD.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,9 @@ Provides GPU-compatible wrappers for automatic differentiation functions of the
- `autodiff_deferred!`: wraps function `autodiff_deferred`.
- `autodiff_deferred_thunk!`: wraps function `autodiff_deferred_thunk`.
!!! note "Enzyme runtime activity default"
If ParallelKernel is initialized with Threads, then `Enzyme.API.runtimeActivity!(true)` is called to ensure correct behavior of Enzyme. If you want to disable this behavior, then call `Enzyme.API.runtimeActivity!(false)` after loading ParallelStencil.
To see a description of a function type `?<functionname>`.
"""
module AD
Expand Down
1 change: 1 addition & 0 deletions src/ParallelKernel/init_parallel_kernel.jl
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ function init_parallel_kernel(caller::Module, package::Symbol, numbertype::DataT
pkg_import_cmd = :()
end
ad_import_cmd = :(import ParallelStencil.ParallelKernel.Enzyme)
if (package == PKG_THREADS) ad_import_cmd = :(import ParallelStencil.ParallelKernel.Enzyme; Enzyme.API.runtimeActivity!(true)) end # NOTE: Enzyme requires this currently to work correctly with threads.
if !isdefined(caller, :Data) || (@eval(caller, isa(Data, Module)) && length(symbols(caller, :Data)) == 1) # Only if the module Data does not exist in the caller or is empty, create it.
if (datadoc_call==:())
if (numbertype == NUMBERTYPE_NONE) datadoc_call = :(@doc ParallelStencil.ParallelKernel.DATA_DOC_NUMBERTYPE_NONE Data)
Expand Down
2 changes: 1 addition & 1 deletion test/ParallelKernel/test_parallel.jl
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ end
end
@parallel configcall=f!(A, B, a) AD.autodiff_deferred!(Enzyme.Reverse, f!, DuplicatedNoNeed(A, Ā), DuplicatedNoNeed(B, B̄), Const(a))
Enzyme.autodiff_deferred(Enzyme.Reverse, g!, DuplicatedNoNeed(A_ref, Ā_ref), DuplicatedNoNeed(B_ref, B̄_ref), Const(a))
#@test Array(Ā) ≈ Ā_ref # NOTE: this test does not pass when run with the package manager.
@test Array(Ā) Ā_ref
@test Array(B̄) B̄_ref
end
end
Expand Down

0 comments on commit 81ddca9

Please sign in to comment.