Skip to content

Commit 8310c22

Browse files
authored
Merge pull request #38 from ptiede/ptiede-enzymefix
Fix Enzyme implementation for 0.13
2 parents 3be9056 + 386d77f commit 8310c22

File tree

5 files changed

+22
-11
lines changed

5 files changed

+22
-11
lines changed

.github/workflows/CI.yml

+2-2
Original file line numberDiff line numberDiff line change
@@ -17,15 +17,15 @@ jobs:
1717
fail-fast: false
1818
matrix:
1919
version:
20-
- '1.9' # Replace this with the minimum Julia version that your package supports.
20+
- 'min' # Minimum Julia version that the package supports.
2121
- '1' # Leave this line unchanged. '1' will automatically expand to the latest stable 1.x release of Julia.
2222
os:
2323
- ubuntu-latest
2424
arch:
2525
- x64
2626
steps:
2727
- uses: actions/checkout@v2
28-
- uses: julia-actions/setup-julia@v1
28+
- uses: julia-actions/setup-julia@v2
2929
with:
3030
version: ${{ matrix.version }}
3131
arch: ${{ matrix.arch }}

Project.toml

+2-2
Original file line numberDiff line numberDiff line change
@@ -31,14 +31,14 @@ LogDensityProblemsADZygoteExt = "Zygote"
3131
ADTypes = "1.5"
3232
BenchmarkTools = "1"
3333
DocStringExtensions = "0.8, 0.9"
34-
Enzyme = "0.11, 0.12, 0.13"
34+
Enzyme = "0.13.3"
3535
FiniteDifferences = "0.12"
3636
ForwardDiff = "0.10"
3737
LogDensityProblems = "1, 2"
3838
ReverseDiff = "1"
3939
Tracker = "0.2"
4040
Zygote = "0.6"
41-
julia = "1.9"
41+
julia = "1.10"
4242

4343
[extras]
4444
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"

ext/LogDensityProblemsADADTypesExt.jl

+7-2
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,13 @@ The AD configuration specified by `ad` is forwarded to the corresponding calls o
2424
"""
2525
LogDensityProblemsAD.ADgradient(::ADTypes.AbstractADType, ℓ)
2626

27-
function LogDensityProblemsAD.ADgradient(::ADTypes.AutoEnzyme, ℓ)
28-
return LogDensityProblemsAD.ADgradient(Val(:Enzyme), ℓ)
27+
function LogDensityProblemsAD.ADgradient(ad::ADTypes.AutoEnzyme, ℓ)
28+
if ad.mode === nothing
29+
# Use default mode (Enzyme.Reverse)
30+
return LogDensityProblemsAD.ADgradient(Val(:Enzyme), ℓ)
31+
else
32+
return LogDensityProblemsAD.ADgradient(Val(:Enzyme), ℓ; mode=ad.mode)
33+
end
2934
end
3035

3136
function LogDensityProblemsAD.ADgradient(ad::ADTypes.AutoForwardDiff{C}, ℓ) where {C}

ext/LogDensityProblemsADEnzymeExt.jl

+5-4
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ if isdefined(Base, :get_extension)
88

99
import LogDensityProblemsAD: ADgradient, logdensity_and_gradient
1010
import Enzyme
11+
using Enzyme: EnzymeCore
1112
else
1213
using ..LogDensityProblemsAD: ADGradientWrapper, logdensity
1314

@@ -46,7 +47,7 @@ function ADgradient(::Val{:Enzyme}, ℓ; mode::Enzyme.Mode = Enzyme.Reverse, sha
4647
@info "keyword argument `shadow` is ignored in reverse mode"
4748
shadow = nothing
4849
end
49-
return EnzymeGradientLogDensity(ℓ, mode, shadow)
50+
return EnzymeGradientLogDensity(ℓ, EnzymeCore.WithPrimal(mode), shadow)
5051
end
5152

5253
function Base.show(io::IO, ∇ℓ::EnzymeGradientLogDensity)
@@ -58,17 +59,17 @@ function logdensity_and_gradient(∇ℓ::EnzymeGradientLogDensity{<:Any,<:Enzyme
5859
x::AbstractVector)
5960
(; ℓ, mode, shadow) = ∇ℓ
6061
_shadow = shadow === nothing ? Enzyme.onehot(x) : shadow
61-
y, ∂ℓ_∂x = Enzyme.autodiff(mode, logdensity, Enzyme.BatchDuplicated,
62+
∂ℓ_∂x, y = Enzyme.autodiff(mode, logdensity, Enzyme.BatchDuplicated,
6263
Enzyme.Const(ℓ),
6364
Enzyme.BatchDuplicated(x, _shadow))
6465
return y, collect(∂ℓ_∂x)
6566
end
6667

6768
function logdensity_and_gradient(∇ℓ::EnzymeGradientLogDensity{<:Any,<:Enzyme.ReverseMode},
6869
x::AbstractVector)
69-
(; ℓ) = ∇ℓ
70+
(; ℓ, mode) = ∇ℓ
7071
∂ℓ_∂x = zero(x)
71-
_, y = Enzyme.autodiff(Enzyme.ReverseWithPrimal, logdensity, Enzyme.Active,
72+
_, y = Enzyme.autodiff(mode, logdensity, Enzyme.Active,
7273
Enzyme.Const(ℓ), Enzyme.Duplicated(x, ∂ℓ_∂x))
7374
y, ∂ℓ_∂x
7475
end

test/runtests.jl

+6-1
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ import ADTypes # load support for AD types with options
77
import BenchmarkTools # load the heuristic chunks code
88
using ComponentArrays: ComponentVector # test with other vector types
99

10-
struct EnzymeTestMode <: Enzyme.Mode{Enzyme.DefaultABI, false} end
10+
struct EnzymeTestMode <: Enzyme.Mode{Enzyme.DefaultABI, false, false} end
1111

1212
####
1313
#### test setup and utilities
@@ -233,13 +233,18 @@ end
233233
= TestLogDensity(test_logdensity1)
234234

235235
∇ℓ_reverse = ADgradient(:Enzyme, ℓ)
236+
∇ℓ_forward = ADgradient(:Enzyme, ℓ; mode=Enzyme.Forward)
236237
@test ∇ℓ_reverse === ADgradient(:Enzyme, ℓ; mode=Enzyme.Reverse)
238+
@test ∇ℓ_reverse.mode === Enzyme.ReverseWithPrimal
239+
@test ADgradient(:Enzyme, ℓ; mode=Enzyme.Reverse) === ADgradient(:Enzyme, ℓ; mode=Enzyme.ReverseWithPrimal)
237240
@test repr(∇ℓ_reverse) == "Enzyme AD wrapper for " * repr(ℓ) * " with reverse mode"
238241

239242
# ADTypes support
240243
@test ADgradient(ADTypes.AutoEnzyme(), ℓ) === ∇ℓ_reverse
241244

242245
∇ℓ_forward = ADgradient(:Enzyme, ℓ; mode=Enzyme.Forward)
246+
@test ADgradient(ADTypes.AutoEnzyme(;mode=Enzyme.Forward), ℓ) === ∇ℓ_forward
247+
@test ADgradient(:Enzyme, ℓ; mode=Enzyme.ForwardWithPrimal) === ADgradient(:Enzyme, ℓ; mode=Enzyme.ForwardWithPrimal)
243248
∇ℓ_forward_shadow = ADgradient(:Enzyme, ℓ;
244249
mode=Enzyme.Forward,
245250
shadow=Enzyme.onehot(Vector{Float64}(undef, dimension(ℓ))))

0 commit comments

Comments
 (0)