Skip to content

Commit bb60399

Browse files
authored
Merge pull request #176 from omlins/ad
Improve AD module
2 parents c5a945a + fe0c955 commit bb60399

File tree

5 files changed

+12
-17
lines changed

5 files changed

+12
-17
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ ParallelStencil_MetalExt = "Metal"
2626
AMDGPU = "0.6, 0.7, 0.8, 0.9, 1"
2727
CUDA = "3.12, 4, 5"
2828
CellArrays = "0.3"
29-
Enzyme = "0.11, 0.12, 0.13"
29+
Enzyme = "0.12, 0.13"
3030
MacroTools = "0.5"
3131
Metal = "1.2"
3232
Polyester = "0.7"

src/AD.jl

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,8 @@ Provides GPU-compatible wrappers for automatic differentiation functions of the
77
import ParallelStencil.AD
88
99
# Functions
10-
- `autodiff_deferred!`: wraps function `autodiff_deferred`.
11-
- `autodiff_deferred_thunk!`: wraps function `autodiff_deferred_thunk`.
10+
- `autodiff_deferred!`: wraps function `autodiff_deferred`, promoting all arguments that are not Enzyme.Annotations to Enzyme.Const.
11+
- `autodiff_deferred_thunk!`: wraps function `autodiff_deferred_thunk`, promoting all arguments that are not Enzyme.Annotations to Enzyme.Const.
1212
1313
# Examples
1414
const USE_GPU = true
@@ -43,9 +43,6 @@ Provides GPU-compatible wrappers for automatic differentiation functions of the
4343
4444
main()
4545
46-
!!! note "Enzyme runtime activity default"
47-
If ParallelStencil is initialized with Threads, then `Enzyme.API.runtimeActivity!(true)` is called to ensure correct behavior of Enzyme. If you want to disable this behavior, then call `Enzyme.API.runtimeActivity!(false)` after loading ParallelStencil.
48-
4946
To see a description of a function type `?<functionname>`.
5047
"""
5148
module AD

src/ParallelKernel/EnzymeExt/AD.jl

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7,11 +7,8 @@ Provides GPU-compatible wrappers for automatic differentiation functions of the
77
import ParallelKernel.AD
88
99
# Functions
10-
- `autodiff_deferred!`: wraps function `autodiff_deferred`.
11-
- `autodiff_deferred_thunk!`: wraps function `autodiff_deferred_thunk`.
12-
13-
!!! note "Enzyme runtime activity default"
14-
If ParallelKernel is initialized with Threads, then `Enzyme.API.runtimeActivity!(true)` is called to ensure correct behavior of Enzyme. If you want to disable this behavior, then call `Enzyme.API.runtimeActivity!(false)` after loading ParallelStencil.
10+
- `autodiff_deferred!`: wraps function `autodiff_deferred`, promoting all arguments that are not Enzyme.Annotations to Enzyme.Const.
11+
- `autodiff_deferred_thunk!`: wraps function `autodiff_deferred_thunk`, promoting all arguments that are not Enzyme.Annotations to Enzyme.Const.
1512
1613
To see a description of a function type `?<functionname>`.
1714
"""

src/ParallelKernel/EnzymeExt/autodiff_gpu.jl

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,16 +2,17 @@ import ParallelStencil
22
import ParallelStencil: PKG_THREADS, PKG_POLYESTER
33
import Enzyme
44

5+
# NOTE: package specific initialization of Enzyme could be done as follows (not needed in the currently supported versions of Enzyme)
56
# function ParallelStencil.ParallelKernel.AD.init_AD(package::Symbol)
67
# if iscpu(package)
78
# Enzyme.API.runtimeActivity!(true) # NOTE: this is currently required for Enzyme to work correctly with threads
89
# end
910
# end
1011

11-
# ParallelStencil injects a configuration parameter at the end, for Enzyme we need to wrap that parameter as a Annotation
12-
# for all purposes this ought to be Const. This is not ideal since we might accidentially wrap other parameters the user
13-
# provided as well. This is needed to support @parallel autodiff_deferred(...)
14-
function promote_to_const(args::Vararg{Any,N}) where N
12+
# NOTE: @parallel injects four parameters at the end, which need to be wrapped as Annotations. The current solution is to wrap all
13+
# arguments which are not already Annotations (all the other arguments must be Annotations). Should this change, then one could
14+
# explicitly wrap just the injected parameters.
15+
function promote_to_const(args::Vararg{Any,N}) where N
1516
ntuple(Val(N)) do i
1617
@inline
1718
if !(args[i] isa Enzyme.Annotation ||

test/ParallelKernel/test_parallel.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -133,8 +133,8 @@ eval(:(
133133
end
134134
return
135135
end
136-
@parallel configcall=f!(A, B, a) AD.autodiff_deferred!(Enzyme.Reverse, Const(f!), Const, DuplicatedNoNeed(A, Ā), DuplicatedNoNeed(B, B̄), Const(a))
137-
Enzyme.autodiff_deferred(Enzyme.Reverse, Const(g!),Const, DuplicatedNoNeed(A_ref, Ā_ref), DuplicatedNoNeed(B_ref, B̄_ref), Const(a))
136+
@parallel configcall=f!(A, B, a) AD.autodiff_deferred!(Enzyme.Reverse, f!, Const, DuplicatedNoNeed(A, Ā), DuplicatedNoNeed(B, B̄), Const(a)) # NOTE: f! is automatically promoted to Const.
137+
Enzyme.autodiff_deferred(Enzyme.Reverse, Const(g!), Const, DuplicatedNoNeed(A_ref, Ā_ref), DuplicatedNoNeed(B_ref, B̄_ref), Const(a))
138138
@test Array(Ā) Ā_ref
139139
@test Array(B̄) B̄_ref
140140
end

0 commit comments

Comments
 (0)