Skip to content

Commit d0efe02

Browse files
authored
Migrate to DifferentiationInterface (#98)
* migrate to DifferentiationInterface * run formatter * tighten compat bound for ADTypes * fix compat bound for docs
1 parent 4eab1ac commit d0efe02

22 files changed

+97
-263
lines changed

.buildkite/pipeline.yml

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
steps:
2+
- label: "CUDA with julia {{matrix.julia}}"
3+
plugins:
4+
- JuliaCI/julia#v1:
5+
version: "{{matrix.julia}}"
6+
- JuliaCI/julia-test#v1:
7+
test_args: "--quickfail"
8+
agents:
9+
queue: "juliagpu"
10+
cuda: "*"
11+
timeout_in_minutes: 60
12+
env:
13+
GROUP: "GPU"
14+
ADVANCEDVI_TEST_CUDA: "true"
15+
matrix:
16+
setup:
17+
julia:
18+
- "1.10"

Project.toml

Lines changed: 9 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
77
Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697"
88
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
99
DiffResults = "163ba53b-c6d8-5494-b064-1a9d43ac40c5"
10+
DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
1011
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
1112
DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
1213
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
@@ -24,50 +25,47 @@ StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
2425
Bijectors = "76274a88-744f-5084-9051-94815aaf08c4"
2526
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
2627
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
28+
Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6"
2729
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
28-
Tapir = "07d77754-e150-4737-8c94-cd238a1fb45b"
2930
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
3031

3132
[extensions]
3233
AdvancedVIBijectorsExt = "Bijectors"
3334
AdvancedVIEnzymeExt = "Enzyme"
34-
AdvancedVIForwardDiffExt = "ForwardDiff"
35-
AdvancedVIReverseDiffExt = "ReverseDiff"
36-
AdvancedVITapirExt = "Tapir"
37-
AdvancedVIZygoteExt = "Zygote"
3835

3936
[compat]
40-
ADTypes = "0.1, 0.2, 1"
37+
ADTypes = "1"
4138
Accessors = "0.1"
4239
Bijectors = "0.13"
4340
ChainRulesCore = "1.16"
4441
DiffResults = "1"
42+
DifferentiationInterface = "0.6"
4543
Distributions = "0.25.111"
4644
DocStringExtensions = "0.8, 0.9"
4745
Enzyme = "0.13"
4846
FillArrays = "1.3"
49-
ForwardDiff = "0.10.36"
47+
ForwardDiff = "0.10"
5048
Functors = "0.4"
5149
LinearAlgebra = "1"
5250
LogDensityProblems = "2"
51+
Mooncake = "0.4"
5352
Optimisers = "0.2.16, 0.3"
5453
ProgressMeter = "1.6"
5554
Random = "1"
5655
Requires = "1.0"
57-
ReverseDiff = "1.15.1"
56+
ReverseDiff = "1"
5857
SimpleUnPack = "1.1.0"
5958
StatsBase = "0.32, 0.33, 0.34"
60-
Tapir = "0.2"
61-
Zygote = "0.6.63"
59+
Zygote = "0.6"
6260
julia = "1.7"
6361

6462
[extras]
6563
Bijectors = "76274a88-744f-5084-9051-94815aaf08c4"
6664
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
6765
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
66+
Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6"
6867
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
6968
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
70-
Tapir = "07d77754-e150-4737-8c94-cd238a1fb45b"
7169
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
7270
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
7371

docs/Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ SimpleUnPack = "ce78b400-467f-4804-87d8-8f486da07d0a"
1414
StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c"
1515

1616
[compat]
17-
ADTypes = "0.1.6"
17+
ADTypes = "1"
1818
AdvancedVI = "0.3"
1919
Bijectors = "0.13.6"
2020
Distributions = "0.25"

ext/AdvancedVIEnzymeExt.jl

Lines changed: 0 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
21
module AdvancedVIEnzymeExt
32

43
if isdefined(Base, :get_extension)
@@ -15,21 +14,6 @@ function AdvancedVI.restructure_ad_forward(::ADTypes.AutoEnzyme, restructure, pa
1514
return restructure(params)::typeof(restructure.model)
1615
end
1716

18-
function AdvancedVI.value_and_gradient!(
19-
::ADTypes.AutoEnzyme, f, x::AbstractVector{<:Real}, out::DiffResults.MutableDiffResult
20-
)
21-
∇x = DiffResults.gradient(out)
22-
fill!(∇x, zero(eltype(∇x)))
23-
_, y = Enzyme.autodiff(
24-
Enzyme.set_runtime_activity(Enzyme.ReverseWithPrimal, true),
25-
Enzyme.Const(f),
26-
Enzyme.Active,
27-
Enzyme.Duplicated(x, ∇x),
28-
)
29-
DiffResults.value!(out, y)
30-
return out
31-
end
32-
3317
function AdvancedVI.value_and_gradient!(
3418
::ADTypes.AutoEnzyme,
3519
f,

ext/AdvancedVIForwardDiffExt.jl

Lines changed: 0 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -1,42 +0,0 @@
1-
2-
module AdvancedVIForwardDiffExt
3-
4-
if isdefined(Base, :get_extension)
5-
using ForwardDiff
6-
using AdvancedVI
7-
using AdvancedVI: ADTypes, DiffResults
8-
else
9-
using ..ForwardDiff
10-
using ..AdvancedVI
11-
using ..AdvancedVI: ADTypes, DiffResults
12-
end
13-
14-
getchunksize(::ADTypes.AutoForwardDiff{chunksize}) where {chunksize} = chunksize
15-
16-
function AdvancedVI.value_and_gradient!(
17-
ad::ADTypes.AutoForwardDiff,
18-
f,
19-
x::AbstractVector{<:Real},
20-
out::DiffResults.MutableDiffResult,
21-
)
22-
chunk_size = getchunksize(ad)
23-
config = if isnothing(chunk_size)
24-
ForwardDiff.GradientConfig(f, x)
25-
else
26-
ForwardDiff.GradientConfig(f, x, ForwardDiff.Chunk(length(x), chunk_size))
27-
end
28-
ForwardDiff.gradient!(out, f, x, config)
29-
return out
30-
end
31-
32-
function AdvancedVI.value_and_gradient!(
33-
ad::ADTypes.AutoForwardDiff,
34-
f,
35-
x::AbstractVector,
36-
aux,
37-
out::DiffResults.MutableDiffResult,
38-
)
39-
return AdvancedVI.value_and_gradient!(ad, x′ -> f(x′, aux), x, out)
40-
end
41-
42-
end

ext/AdvancedVIMooncakeExt.jl

Whitespace-only changes.

ext/AdvancedVIReverseDiffExt.jl

Lines changed: 0 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -1,36 +0,0 @@
1-
2-
module AdvancedVIReverseDiffExt
3-
4-
if isdefined(Base, :get_extension)
5-
using AdvancedVI
6-
using AdvancedVI: ADTypes, DiffResults
7-
using ReverseDiff
8-
else
9-
using ..AdvancedVI
10-
using ..AdvancedVI: ADTypes, DiffResults
11-
using ..ReverseDiff
12-
end
13-
14-
# ReverseDiff without compiled tape
15-
function AdvancedVI.value_and_gradient!(
16-
ad::ADTypes.AutoReverseDiff,
17-
f,
18-
x::AbstractVector{<:Real},
19-
out::DiffResults.MutableDiffResult,
20-
)
21-
tp = ReverseDiff.GradientTape(f, x)
22-
ReverseDiff.gradient!(out, tp, x)
23-
return out
24-
end
25-
26-
function AdvancedVI.value_and_gradient!(
27-
ad::ADTypes.AutoReverseDiff,
28-
f,
29-
x::AbstractVector{<:Real},
30-
aux,
31-
out::DiffResults.MutableDiffResult,
32-
)
33-
return AdvancedVI.value_and_gradient!(ad, x′ -> f(x′, aux), x, out)
34-
end
35-
36-
end

ext/AdvancedVITapirExt.jl

Lines changed: 0 additions & 37 deletions
This file was deleted.

ext/AdvancedVIZygoteExt.jl

Lines changed: 0 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -1,36 +0,0 @@
1-
2-
module AdvancedVIZygoteExt
3-
4-
if isdefined(Base, :get_extension)
5-
using AdvancedVI
6-
using AdvancedVI: ADTypes, DiffResults
7-
using ChainRulesCore
8-
using Zygote
9-
else
10-
using ..AdvancedVI
11-
using ..AdvancedVI: ADTypes, DiffResults
12-
using ..ChainRulesCore
13-
using ..Zygote
14-
end
15-
16-
function AdvancedVI.value_and_gradient!(
17-
::ADTypes.AutoZygote, f, x::AbstractVector{<:Real}, out::DiffResults.MutableDiffResult
18-
)
19-
y, back = Zygote.pullback(f, x)
20-
∇x = back(one(y))
21-
DiffResults.value!(out, y)
22-
DiffResults.gradient!(out, only(∇x))
23-
return out
24-
end
25-
26-
function AdvancedVI.value_and_gradient!(
27-
ad::ADTypes.AutoZygote,
28-
f,
29-
x::AbstractVector{<:Real},
30-
aux,
31-
out::DiffResults.MutableDiffResult,
32-
)
33-
return AdvancedVI.value_and_gradient!(ad, x′ -> f(x′, aux), x, out)
34-
end
35-
36-
end

src/AdvancedVI.jl

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -16,16 +16,17 @@ using LinearAlgebra
1616

1717
using LogDensityProblems
1818

19-
using ADTypes, DiffResults
19+
using ADTypes
20+
using DiffResults
21+
using DifferentiationInterface
2022
using ChainRulesCore
2123

2224
using FillArrays
2325

2426
using StatsBase
2527

26-
# derivatives
28+
# Derivatives
2729
"""
28-
value_and_gradient!(ad, f, x, out)
2930
value_and_gradient!(ad, f, x, aux, out)
3031
3132
Evaluate the value and gradient of a function `f` at `x` using the automatic differentiation backend `ad` and store the result in `out`.
@@ -38,7 +39,14 @@ Evaluate the value and gradient of a function `f` at `x` using the automatic dif
3839
- `aux`: Auxiliary input passed to `f`.
3940
- `out::DiffResults.MutableDiffResult`: Buffer to contain the output gradient and function value.
4041
"""
41-
function value_and_gradient! end
42+
function value_and_gradient!(
43+
ad::ADTypes.AbstractADType, f, x, aux, out::DiffResults.MutableDiffResult
44+
)
45+
grad_buf = DiffResults.gradient(out)
46+
y, _ = DifferentiationInterface.value_and_gradient!(f, grad_buf, ad, x, Constant(aux))
47+
DiffResults.value!(out, y)
48+
return out
49+
end
4250

4351
"""
4452
restructure_ad_forward(adtype, restructure, params)
@@ -131,7 +139,7 @@ function estimate_objective end
131139
export estimate_objective
132140

133141
"""
134-
estimate_gradient!(rng, obj, adtype, out, prob, λ, restructure, obj_state)
142+
estimate_gradient!(rng, obj, adtype, out, prob, params, restructure, obj_state)
135143
136144
Estimate (possibly stochastic) gradients of the variational objective `obj` targeting `prob` with respect to the variational parameters `λ`
137145
@@ -141,7 +149,7 @@ Estimate (possibly stochastic) gradients of the variational objective `obj` targ
141149
- `adtype::ADTypes.AbstractADType`: Automatic differentiation backend.
142150
- `out::DiffResults.MutableDiffResult`: Buffer containing the objective value and gradient estimates.
143151
- `prob`: The target log-joint likelihood implementing the `LogDensityProblem` interface.
144-
- `λ`: Variational parameters to evaluate the gradient on.
152+
- `params`: Variational parameters to evaluate the gradient on.
145153
- `restructure`: Function that reconstructs the variational approximation from `λ`.
146154
- `obj_state`: Previous state of the objective.
147155

src/optimize.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ The arguments are as follows:
4242
- `restructure`: Function that restructures the variational approximation from the variational parameters. Calling `restructure(param)` reconstructs the variational approximation.
4343
- `gradient`: The estimated (possibly stochastic) gradient.
4444
45-
`cb` can return a `NamedTuple` containing some additional information computed within `cb`.
45+
`callback` can return a `NamedTuple` containing some additional information computed within `cb`.
4646
This will be appended to the statistic of the current corresponding iteration.
4747
Otherwise, just return `nothing`.
4848

test/Project.toml

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,9 @@
22
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
33
Bijectors = "76274a88-744f-5084-9051-94815aaf08c4"
44
DiffResults = "163ba53b-c6d8-5494-b064-1a9d43ac40c5"
5+
DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
56
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
67
DistributionsAD = "ced4e74d-a319-5a8a-b0ac-84af2272839c"
7-
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
88
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
99
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
1010
Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
@@ -26,7 +26,8 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
2626
[compat]
2727
ADTypes = "0.2.1, 1"
2828
Bijectors = "0.13"
29-
DiffResults = "1.0"
29+
DiffResults = "1"
30+
DifferentiationInterface = "0.6"
3031
Distributions = "0.25.111"
3132
DistributionsAD = "0.6.45"
3233
FillArrays = "1.6.1"
@@ -41,6 +42,7 @@ ReverseDiff = "1.15.1"
4142
SimpleUnPack = "1.1.0"
4243
StableRNGs = "1.0.0"
4344
Statistics = "1"
45+
StatsBase = "0.34"
4446
Test = "1"
4547
Tracker = "0.2.20"
4648
Zygote = "0.6.63"

test/inference/repgradelbo_distributionsad.jl

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,14 @@ AD_distributionsad = Dict(
55
:Zygote => AutoZygote(),
66
)
77

8-
if @isdefined(Tapir)
9-
AD_distributionsad[:Tapir] = AutoTapir(; safe_mode=false)
8+
if @isdefined(Mooncake)
9+
AD_distributionsad[:Mooncake] = AutoMooncake(; config=nothing)
1010
end
1111

1212
if @isdefined(Enzyme)
13-
AD_distributionsad[:Enzyme] = AutoEnzyme()
13+
AD_distributionsad[:Enzyme] = AutoEnzyme(;
14+
mode=set_runtime_activity(ReverseWithPrimal), function_annotation=Const
15+
)
1416
end
1517

1618
@testset "inference RepGradELBO DistributionsAD" begin

0 commit comments

Comments
 (0)