Skip to content

Commit 09fedcf

Browse files
authored
Merge pull request #39 from gdalle/gd/di_new
Use DI for non-implemented ADTypes
2 parents 4886546 + b4be1be commit 09fedcf

5 files changed

+114
-13
lines changed

Project.toml

+5-1
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c"
1010
[weakdeps]
1111
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
1212
BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
13+
DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
1314
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
1415
FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
1516
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
@@ -19,6 +20,7 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
1920

2021
[extensions]
2122
LogDensityProblemsADADTypesExt = "ADTypes"
23+
LogDensityProblemsADDifferentiationInterfaceExt = ["ADTypes", "DifferentiationInterface"]
2224
LogDensityProblemsADEnzymeExt = "Enzyme"
2325
LogDensityProblemsADFiniteDifferencesExt = "FiniteDifferences"
2426
LogDensityProblemsADForwardDiffBenchmarkToolsExt = ["BenchmarkTools", "ForwardDiff"]
@@ -29,6 +31,7 @@ LogDensityProblemsADZygoteExt = "Zygote"
2931

3032
[compat]
3133
ADTypes = "1.5"
34+
DifferentiationInterface = "0.6.1"
3235
BenchmarkTools = "1"
3336
DocStringExtensions = "0.8, 0.9"
3437
Enzyme = "0.13.3"
@@ -44,6 +47,7 @@ julia = "1.10"
4447
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
4548
BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
4649
ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66"
50+
DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
4751
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
4852
FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
4953
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
@@ -54,4 +58,4 @@ Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
5458
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
5559

5660
[targets]
57-
test = ["ADTypes", "BenchmarkTools", "ComponentArrays", "Enzyme", "FiniteDifferences", "ForwardDiff", "Random", "ReverseDiff", "Test", "Tracker", "Zygote"]
61+
test = ["ADTypes", "BenchmarkTools", "ComponentArrays", "DifferentiationInterface", "Enzyme", "FiniteDifferences", "ForwardDiff", "Random", "ReverseDiff", "Test", "Tracker", "Zygote"]

README.md

+3-1
Original file line numberDiff line numberDiff line change
@@ -46,4 +46,6 @@ x = zeros(LogDensityProblems.dimension(ℓ)) # ℓ is your log density
4646

4747
5. [FiniteDifferences.jl](https://github.com/JuliaDiff/FiniteDifferences.jl) Finite differences are very robust, with a small numerical error, but usually not fast enough to practically replace AD on nontrivial problems. The backend in this package is mainly intended for checking and debugging results from other backends; but note that in most cases ForwardDiff is faster and more accurate.
4848

49-
PRs for other AD frameworks are welcome, even if they are WIP.
49+
Other AD frameworks are supported thanks to [ADTypes.jl](https://github.com/SciML/ADTypes.jl) and [DifferentiationInterface.jl](https://github.com/gdalle/DifferentiationInterface.jl).
50+
51+
PRs for remaining AD frameworks are welcome, even if they are WIP.

ext/LogDensityProblemsADADTypesExt.jl

+24-11
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ else
99
end
1010

1111
"""
12-
ADgradient(ad::ADTypes.AbstractADType, ℓ)
12+
ADgradient(ad::ADTypes.AbstractADType, ℓ; x::Union{Nothing,AbstractVector}=nothing)
1313
1414
Wrap log density `ℓ` using automatic differentiation (AD) of type `ad` to obtain a gradient.
1515
@@ -19,12 +19,19 @@ Currently,
1919
- `ad::ADTypes.AutoReverseDiff`
2020
- `ad::ADTypes.AutoTracker`
2121
- `ad::ADTypes.AutoZygote`
22-
are supported.
23-
The AD configuration specified by `ad` is forwarded to the corresponding calls of `ADgradient(Val(...), ℓ)`.
22+
are supported with custom implementations.
23+
The AD configuration specified by `ad` is forwarded to the corresponding calls of `ADgradient(Val(...), ℓ)`.
24+
25+
Passing `x` as a keyword argument means that the gradient operator will be "prepared" for the specific type and size of the array `x`. This can speed up further evaluations on similar inputs, but will likely cause errors if the new inputs have a different type or size. With `AutoReverseDiff`, it can also yield incorrect results if the logdensity contains value-dependent control flow.
26+
27+
If you want to use another backend from [ADTypes.jl](https://github.com/SciML/ADTypes.jl) which is not in the list above, you need to load [DifferentiationInterface.jl](https://github.com/gdalle/DifferentiationInterface.jl) first.
2428
"""
2529
LogDensityProblemsAD.ADgradient(::ADTypes.AbstractADType, ℓ)
2630

27-
function LogDensityProblemsAD.ADgradient(ad::ADTypes.AutoEnzyme, ℓ)
31+
function LogDensityProblemsAD.ADgradient(ad::ADTypes.AutoEnzyme, ℓ; x::Union{Nothing,AbstractVector}=nothing)
32+
if x !== nothing
33+
@warn "`ADgradient`: Keyword argument `x` is ignored"
34+
end
2835
if ad.mode === nothing
2936
# Use default mode (Enzyme.Reverse)
3037
return LogDensityProblemsAD.ADgradient(Val(:Enzyme), ℓ)
@@ -33,25 +40,31 @@ function LogDensityProblemsAD.ADgradient(ad::ADTypes.AutoEnzyme, ℓ)
3340
end
3441
end
3542

36-
function LogDensityProblemsAD.ADgradient(ad::ADTypes.AutoForwardDiff{C}, ℓ) where {C}
43+
function LogDensityProblemsAD.ADgradient(ad::ADTypes.AutoForwardDiff{C}, ℓ; x::Union{Nothing,AbstractVector}=nothing) where {C}
3744
if C === nothing
3845
# Use default chunk size
39-
return LogDensityProblemsAD.ADgradient(Val(:ForwardDiff), ℓ; tag = ad.tag)
46+
return LogDensityProblemsAD.ADgradient(Val(:ForwardDiff), ℓ; tag = ad.tag, x=x)
4047
else
41-
return LogDensityProblemsAD.ADgradient(Val(:ForwardDiff), ℓ; chunk = C, tag = ad.tag)
48+
return LogDensityProblemsAD.ADgradient(Val(:ForwardDiff), ℓ; chunk = C, tag = ad.tag, x=x)
4249
end
4350
end
4451

45-
function LogDensityProblemsAD.ADgradient(ad::ADTypes.AutoReverseDiff{T}, ℓ) where {T}
46-
return LogDensityProblemsAD.ADgradient(Val(:ReverseDiff), ℓ; compile = Val(T))
52+
function LogDensityProblemsAD.ADgradient(ad::ADTypes.AutoReverseDiff{T}, ℓ; x::Union{Nothing,AbstractVector}=nothing) where {T}
53+
return LogDensityProblemsAD.ADgradient(Val(:ReverseDiff), ℓ; compile = Val(T), x=x)
4754
end
4855

49-
function LogDensityProblemsAD.ADgradient(::ADTypes.AutoTracker, ℓ)
56+
function LogDensityProblemsAD.ADgradient(::ADTypes.AutoTracker, ℓ; x::Union{Nothing,AbstractVector}=nothing)
57+
if x !== nothing
58+
@warn "`ADgradient`: Keyword argument `x` is ignored"
59+
end
5060
return LogDensityProblemsAD.ADgradient(Val(:Tracker), ℓ)
5161
end
5262

5363

54-
function LogDensityProblemsAD.ADgradient(::ADTypes.AutoZygote, ℓ)
64+
function LogDensityProblemsAD.ADgradient(::ADTypes.AutoZygote, ℓ; x::Union{Nothing,AbstractVector}=nothing)
65+
if x !== nothing
66+
@warn "`ADgradient`: Keyword argument `x` is ignored"
67+
end
5568
return LogDensityProblemsAD.ADgradient(Val(:Zygote), ℓ)
5669
end
5770

Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
module LogDensityProblemsADDifferentiationInterfaceExt
2+
3+
import LogDensityProblemsAD
4+
import ADTypes
5+
import DifferentiationInterface as DI
6+
7+
"""
8+
DIGradient <: LogDensityProblemsAD.ADGradientWrapper
9+
10+
Gradient wrapper which uses [DifferentiationInterface.jl](https://github.com/gdalle/DifferentiationInterface.jl)
11+
12+
# Fields
13+
14+
- `backend::AbstractADType`: one of the autodiff backend types defined in [ADTypes.jl](https://github.com/SciML/ADTypes.jl), for example `ADTypes.AutoForwardDiff()`
15+
- `prep`: either `nothing` or the output of `DifferentiationInterface.prepare_gradient` applied to the logdensity and the provided input
16+
- `ℓ`: logdensity function, amenable to `LogDensityProblemsAD.logdensity(ℓ, x)`
17+
"""
18+
struct DIGradient{B<:ADTypes.AbstractADType,P,L} <: LogDensityProblemsAD.ADGradientWrapper
19+
backend::B
20+
prep::P
21+
::L
22+
end
23+
24+
# DifferentiationInterface expects the active (differentiated) argument first,
# while LogDensityProblems puts the logdensity object first — swap them here.
logdensity_switched(x, ℓ) = LogDensityProblemsAD.logdensity(ℓ, x)
28+
29+
# Fallback `ADgradient` for any ADTypes backend without a custom implementation:
# wrap `ℓ` in a `DIGradient` driven by DifferentiationInterface. When a
# representative input `x` is supplied, the gradient operator is "prepared"
# (specialized) for that input's type and size; otherwise no preparation is stored.
function LogDensityProblemsAD.ADgradient(backend::ADTypes.AbstractADType, ℓ; x::Union{Nothing,AbstractVector}=nothing)
    prep = x === nothing ? nothing : DI.prepare_gradient(logdensity_switched, backend, x, DI.Constant(ℓ))
    return DIGradient(backend, prep, ℓ)
end
37+
38+
# Compute `(logdensity, gradient)` at `x` via DifferentiationInterface,
# reusing the stored preparation object when one is available.
function LogDensityProblemsAD.logdensity_and_gradient(∇ℓ::DIGradient, x::AbstractVector)
    (; backend, prep, ℓ) = ∇ℓ
    # Guard clause: unprepared path when no preparation was created at wrap time.
    prep === nothing && return DI.value_and_gradient(logdensity_switched, backend, x, DI.Constant(ℓ))
    return DI.value_and_gradient(logdensity_switched, prep, backend, x, DI.Constant(ℓ))
end
46+
47+
end

test/runtests.jl

+35
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ import FiniteDifferences, ForwardDiff, Enzyme, Tracker, Zygote, ReverseDiff # ba
66
import ADTypes # load support for AD types with options
77
import BenchmarkTools # load the heuristic chunks code
88
using ComponentArrays: ComponentVector # test with other vector types
9+
import DifferentiationInterface
910

1011
struct EnzymeTestMode <: Enzyme.Mode{Enzyme.DefaultABI, false, false} end
1112

@@ -91,6 +92,9 @@ ForwardDiff.checktag(::Type{ForwardDiff.Tag{TestTag, V}}, ::Base.Fix1{typeof(log
9192

9293
# ADTypes support
9394
@test typeof(ADgradient(ADTypes.AutoReverseDiff(; compile = Val(true)), ℓ)) === typeof(∇ℓ_compile)
95+
@test typeof(ADgradient(ADTypes.AutoReverseDiff(; compile = Val(true)), ℓ; x=rand(3))) === typeof(∇ℓ_compile_x)
96+
@test nameof(typeof(ADgradient(ADTypes.AutoReverseDiff(), ℓ))) !== :DIGradient
97+
@test nameof(typeof(ADgradient(ADTypes.AutoReverseDiff(), ℓ; x=rand(3)))) !== :DIGradient
9498

9599
for ∇ℓ in (∇ℓ_default, ∇ℓ_nocompile, ∇ℓ_compile, ∇ℓ_compile_x)
96100
@test dimension(∇ℓ) == 3
@@ -127,6 +131,8 @@ end
127131

128132
# ADTypes support
129133
@test ADgradient(ADTypes.AutoForwardDiff(), ℓ) === ∇ℓ
134+
@test nameof(typeof(ADgradient(ADTypes.AutoForwardDiff(), ℓ))) !== :DIGradient
135+
@test nameof(typeof(ADgradient(ADTypes.AutoForwardDiff(), ℓ; x=rand(3)))) !== :DIGradient
130136

131137
for _ in 1:100
132138
x = randn(3)
@@ -175,6 +181,7 @@ end
175181
# ADTypes support
176182
@test ADgradient(ADTypes.AutoForwardDiff(; chunksize = 3), ℓ) === ADgradient(:ForwardDiff, ℓ; chunk = 3)
177183
@test ADgradient(ADTypes.AutoForwardDiff(; chunksize = 3, tag = TestTag()), ℓ) === ADgradient(:ForwardDiff, ℓ; chunk = 3, tag = TestTag())
184+
@test typeof(ADgradient(ADTypes.AutoForwardDiff(), ℓ; x=rand(3))) == typeof(ADgradient(:ForwardDiff, ℓ; x=rand(3)))
178185
end
179186

180187
@testset "component vectors" begin
@@ -211,6 +218,8 @@ end
211218

212219
# ADTypes support
213220
@test ADgradient(ADTypes.AutoTracker(), ℓ) === ∇ℓ
221+
@test nameof(typeof(ADgradient(ADTypes.AutoTracker(), ℓ))) !== :DIGradient
222+
@test nameof(typeof(ADgradient(ADTypes.AutoTracker(), ℓ; x=rand(3)))) !== :DIGradient
214223
end
215224

216225
@testset "AD via Zygote" begin
@@ -227,6 +236,8 @@ end
227236

228237
# ADTypes support
229238
@test ADgradient(ADTypes.AutoZygote(), ℓ) === ∇ℓ
239+
@test nameof(typeof(ADgradient(ADTypes.AutoZygote(), ℓ))) !== :DIGradient
240+
@test nameof(typeof(ADgradient(ADTypes.AutoZygote(), ℓ; x=rand(3)))) !== :DIGradient
230241
end
231242

232243
@testset "AD via Enzyme" begin
@@ -241,6 +252,8 @@ end
241252

242253
# ADTypes support
243254
@test ADgradient(ADTypes.AutoEnzyme(), ℓ) === ∇ℓ_reverse
255+
@test nameof(typeof(ADgradient(ADTypes.AutoEnzyme(), ℓ))) !== :DIGradient
256+
@test nameof(typeof(ADgradient(ADTypes.AutoEnzyme(), ℓ; x=rand(3)))) !== :DIGradient
244257

245258
∇ℓ_forward = ADgradient(:Enzyme, ℓ; mode=Enzyme.Forward)
246259
@test ADgradient(ADTypes.AutoEnzyme(;mode=Enzyme.Forward), ℓ) === ∇ℓ_forward
@@ -291,3 +304,25 @@ end
291304
@test b isa Vector{Pair{Int,Float64}}
292305
@test length(b) ≤ 20
293306
end
307+
308+
# Exercise the DifferentiationInterface fallback for backends that have no
# custom ADgradient implementation, both with and without input preparation.
@testset verbose=true "DifferentiationInterface for unsupported ADTypes" begin
    ℓ = TestLogDensity(test_logdensity1)
    backends = [
        ADTypes.AutoFiniteDifferences(; fdm=FiniteDifferences.central_fdm(5, 1)),
    ]
    ∇ℓ_candidates = []
    for backend in backends
        push!(∇ℓ_candidates, ADgradient(backend, ℓ))              # unprepared
        push!(∇ℓ_candidates, ADgradient(backend, ℓ; x=zeros(3)))  # prepared for length-3 input
    end
    @testset "$(typeof(∇ℓ))" for ∇ℓ in ∇ℓ_candidates
        @test nameof(typeof(∇ℓ)) == :DIGradient
        @test dimension(∇ℓ) == 3
        # NOTE(review): a non-ASCII comparison operator was lost in extraction here;
        # `==` is restored — equivalent to `≡` since `LogDensityOrder(1)` is a singleton.
        @test capabilities(∇ℓ) == LogDensityOrder(1)
        for _ in 1:100
            x = randn(3)
            @test @inferred(logdensity(∇ℓ, x)) ≈ test_logdensity1(x)
            @test logdensity_and_gradient(∇ℓ, x) ≈ (test_logdensity1(x), test_gradient(x)) atol = 1e-5
        end
    end
end;

0 commit comments

Comments
 (0)