
Commit 5ae466d

yebai, github-actions[bot], and Red-Portal authored
Remove the Enzyme extension, prepare gradient (#166)
* Remove the Enzyme extension. For more details, open the [Copilot Workspace session](https://copilot-workspace.githubnext.com/TuringLang/AdvancedVI.jl?shareId=XXXX-XXXX-XXXX-XXXX).
* Update ad.jl
* Update test/interface/ad.jl
* Update repgradelbo.jl
* Update test/interface/repgradelbo.jl
* Update AdvancedVI.jl
* Avoid type piracy
* Implement #101
* Update src/objectives/elbo/repgradelbo.jl
* Update src/AdvancedVI.jl
* Update scoregradelbo_locationscale.jl
* Apply suggestions from code review
* Add prepare gradient; change order of arguments in DI wrappers

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: Kyurae Kim <[email protected]>
1 parent b020976 commit 5ae466d

File tree: 10 files changed (+143, −54 lines)

Project.toml

Lines changed: 0 additions & 4 deletions

@@ -21,11 +21,9 @@ StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
 
 [weakdeps]
 Bijectors = "76274a88-744f-5084-9051-94815aaf08c4"
-Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
 
 [extensions]
 AdvancedVIBijectorsExt = "Bijectors"
-AdvancedVIEnzymeExt = "Enzyme"
 
 [compat]
 ADTypes = "1"
@@ -36,7 +34,6 @@ DiffResults = "1"
 DifferentiationInterface = "0.6"
 Distributions = "0.25.111"
 DocStringExtensions = "0.8, 0.9"
-Enzyme = "0.13"
 FillArrays = "1.3"
 Functors = "0.4, 0.5"
 LinearAlgebra = "1"
@@ -49,7 +46,6 @@ julia = "1.10, 1.11.2"
 
 [extras]
 Bijectors = "76274a88-744f-5084-9051-94815aaf08c4"
-Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
 Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
 Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

ext/AdvancedVIEnzymeExt.jl

Lines changed: 0 additions & 31 deletions
This file was deleted.
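
With the extension deleted, no AdvancedVI-side Enzyme glue remains: the backend is configured entirely through ADTypes, using exactly the construction that the docstring and tests below now require. A minimal sketch, assuming Enzyme is loaded alongside AdvancedVI:

# The AutoEnzyme configuration required by the new docs and tests,
# built purely from ADTypes and Enzyme; no package extension involved.
using ADTypes, Enzyme

adtype = AutoEnzyme(;
    mode=Enzyme.set_runtime_activity(Enzyme.Reverse),
    function_annotation=Enzyme.Const,
)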

src/AdvancedVI.jl

Lines changed: 50 additions & 6 deletions

@@ -26,27 +26,63 @@ using StatsBase
 
 # Derivatives
 """
-    value_and_gradient!(ad, f, x, aux, out)
+    _value_and_gradient!(f, out, ad, x, aux)
+    _value_and_gradient!(f, out, prep, ad, x, aux)
 
 Evaluate the value and gradient of a function `f` at `x` using the automatic differentiation backend `ad` and store the result in `out`.
 `f` may receive auxiliary input as `f(x,aux)`.
 
 # Arguments
-- `ad::ADTypes.AbstractADType`: Automatic differentiation backend.
+- `ad::ADTypes.AbstractADType`: Automatic differentiation backend. Currently supports
+  `ADTypes.AutoZygote()`, `ADTypes.AutoForwardDiff()`, `ADTypes.AutoReverseDiff()`,
+  `ADTypes.AutoMooncake()`, and
+  `ADTypes.AutoEnzyme(;
+      mode=Enzyme.set_runtime_activity(Enzyme.Reverse),
+      function_annotation=Enzyme.Const,
+  )`.
+  To use `AutoEnzyme`, make sure to include `set_runtime_activity` and `function_annotation` as shown above.
 - `f`: Function subject to differentiation.
 - `x`: The point to evaluate the gradient.
 - `aux`: Auxiliary input passed to `f`.
+- `prep`: Output of `DifferentiationInterface.prepare_gradient`.
 - `out::DiffResults.MutableDiffResult`: Buffer to contain the output gradient and function value.
 """
-function value_and_gradient!(
-    ad::ADTypes.AbstractADType, f, x, aux, out::DiffResults.MutableDiffResult
+function _value_and_gradient!(
+    f, out::DiffResults.MutableDiffResult, ad::ADTypes.AbstractADType, x, aux
 )
     grad_buf = DiffResults.gradient(out)
     y, _ = DifferentiationInterface.value_and_gradient!(f, grad_buf, ad, x, Constant(aux))
     DiffResults.value!(out, y)
     return out
 end
 
+function _value_and_gradient!(
+    f, out::DiffResults.MutableDiffResult, prep, ad::ADTypes.AbstractADType, x, aux
+)
+    grad_buf = DiffResults.gradient(out)
+    y, _ = DifferentiationInterface.value_and_gradient!(
+        f, grad_buf, prep, ad, x, Constant(aux)
+    )
+    DiffResults.value!(out, y)
+    return out
+end
+
+"""
+    _prepare_gradient(f, ad, x, aux)
+
+Prepare the automatic differentiation backend `ad` for taking gradients of a function `f` at `x`.
+
+# Arguments
+- `ad::ADTypes.AbstractADType`: Automatic differentiation backend.
+- `f`: Function subject to differentiation.
+- `x`: The point to evaluate the gradient.
+- `aux`: Auxiliary input passed to `f`.
+"""
+function _prepare_gradient(f, ad::ADTypes.AbstractADType, x, aux)
+    return DifferentiationInterface.prepare_gradient(f, ad, x, Constant(aux))
+end
+
 """
     restructure_ad_forward(adtype, restructure, params)
 
@@ -74,18 +110,26 @@ If the estimator is stateful, it can implement `init` to initialize the state.
 abstract type AbstractVariationalObjective end
 
 """
-    init(rng, obj, prob, params, restructure)
+    init(rng, obj, adtype, prob, params, restructure)
 
 Initialize the state of the variational objective `obj` given the initial variational parameters `λ`.
 This function needs to be implemented only if `obj` is stateful.
 
 # Arguments
 - `rng::Random.AbstractRNG`: Random number generator.
 - `obj::AbstractVariationalObjective`: Variational objective.
+- `adtype::ADTypes.AbstractADType`: Automatic differentiation backend.
 - `params`: Initial variational parameters.
 - `restructure`: Function that reconstructs the variational approximation from `λ`.
 """
-init(::Random.AbstractRNG, ::AbstractVariationalObjective, ::Any, ::Any, ::Any) = nothing
+init(
+    ::Random.AbstractRNG,
+    ::AbstractVariationalObjective,
+    ::ADTypes.AbstractADType,
+    ::Any,
+    ::Any,
+    ::Any,
+) = nothing
 
 """
     estimate_objective([rng,] obj, q, prob; kwargs...)
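
The new wrappers put the differentiated function first and the DiffResults buffer second, mirroring DifferentiationInterface's own argument order, with an optional `prep` slot between `out` and `ad`. A minimal sketch of both call forms, assuming ForwardDiff as the backend and a toy objective (note that these `_`-prefixed helpers are internal, not public API):

# Toy quadratic with auxiliary input, differentiated via the new wrappers.
using ADTypes, DiffResults
import AdvancedVI, ForwardDiff

f(x, aux) = sum(abs2, x) / 2 + sum(aux.b .* x)

x   = randn(5)
aux = (b=randn(5),)
ad  = AutoForwardDiff()
out = DiffResults.GradientResult(x)

# One-shot form:
AdvancedVI._value_and_gradient!(f, out, ad, x, aux)

# Prepared form: build the preparation once, then reuse it across calls.
prep = AdvancedVI._prepare_gradient(f, ad, x, aux)
AdvancedVI._value_and_gradient!(f, out, prep, ad, x, aux)

DiffResults.value(out), DiffResults.gradient(out)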

src/objectives/elbo/repgradelbo.jl

Lines changed: 27 additions & 2 deletions

@@ -33,6 +33,28 @@ struct RepGradELBO{EntropyEst<:AbstractEntropyEstimator} <: AbstractVariationalObjective
     n_samples::Int
 end
 
+function init(
+    rng::Random.AbstractRNG,
+    obj::RepGradELBO,
+    adtype::ADTypes.AbstractADType,
+    prob,
+    params,
+    restructure,
+)
+    q_stop = restructure(params)
+    aux = (
+        rng=rng,
+        adtype=adtype,
+        obj=obj,
+        problem=prob,
+        restructure=restructure,
+        q_stop=q_stop,
+    )
+    return AdvancedVI._prepare_gradient(
+        estimate_repgradelbo_ad_forward, adtype, params, aux
+    )
+end
+
 function RepGradELBO(n_samples::Int; entropy::AbstractEntropyEstimator=ClosedFormEntropy())
     return RepGradELBO(entropy, n_samples)
 end
@@ -129,6 +151,7 @@ function estimate_gradient!(
     restructure,
     state,
 )
+    prep = state
     q_stop = restructure(params)
     aux = (
         rng=rng,
@@ -138,8 +161,10 @@
         restructure=restructure,
         q_stop=q_stop,
     )
-    value_and_gradient!(adtype, estimate_repgradelbo_ad_forward, params, aux, out)
+    AdvancedVI._value_and_gradient!(
+        estimate_repgradelbo_ad_forward, out, prep, adtype, params, aux
+    )
     nelbo = DiffResults.value(out)
     stat = (elbo=-nelbo,)
-    return out, nothing, stat
+    return out, state, stat
 end
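
This is the pattern downstream objectives now follow: `init` receives the AD backend, builds the DifferentiationInterface preparation, and hands it back as the objective state, which `estimate_gradient!` reuses and returns unchanged. For a third-party stateful objective, the contract looks roughly like the sketch below (hypothetical `MyObj` and `my_loss`, not part of AdvancedVI):

using ADTypes, Random
import AdvancedVI

struct MyObj <: AdvancedVI.AbstractVariationalObjective end

# A stand-in differentiable loss; real objectives differentiate an ELBO estimator.
my_loss(params, aux) = sum(abs2, params)

# `init` now receives `adtype`, so the preparation can be built once up front
# and threaded back into every `estimate_gradient!` call as the state.
function AdvancedVI.init(
    rng::Random.AbstractRNG,
    obj::MyObj,
    adtype::ADTypes.AbstractADType,
    prob,
    params,
    restructure,
)
    aux = (problem=prob, restructure=restructure)
    return AdvancedVI._prepare_gradient(my_loss, adtype, params, aux)
end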

src/objectives/elbo/scoregradelbo.jl

Lines changed: 21 additions & 3 deletions

@@ -16,6 +16,23 @@ struct ScoreGradELBO <: AbstractVariationalObjective
     n_samples::Int
 end
 
+function init(
+    rng::Random.AbstractRNG,
+    obj::ScoreGradELBO,
+    adtype::ADTypes.AbstractADType,
+    prob,
+    params,
+    restructure,
+)
+    q = restructure(params)
+    samples = rand(rng, q, obj.n_samples)
+    ℓπ = map(Base.Fix1(LogDensityProblems.logdensity, prob), eachsample(samples))
+    aux = (adtype=adtype, logprob_stop=ℓπ, samples_stop=samples, restructure=restructure)
+    return AdvancedVI._prepare_gradient(
+        estimate_scoregradelbo_ad_forward, adtype, params, aux
+    )
+end
+
 function Base.show(io::IO, obj::ScoreGradELBO)
     print(io, "ScoreGradELBO(n_samples=")
     print(io, obj.n_samples)
@@ -71,14 +88,15 @@ function AdvancedVI.estimate_gradient!(
     state,
 )
     q = restructure(params)
+    prep = state
     samples = rand(rng, q, obj.n_samples)
     ℓπ = map(Base.Fix1(LogDensityProblems.logdensity, prob), eachsample(samples))
     aux = (adtype=adtype, logprob_stop=ℓπ, samples_stop=samples, restructure=restructure)
-    AdvancedVI.value_and_gradient!(
-        adtype, estimate_scoregradelbo_ad_forward, params, aux, out
+    AdvancedVI._value_and_gradient!(
+        estimate_scoregradelbo_ad_forward, out, prep, adtype, params, aux
     )
     ℓq = logpdf.(Ref(q), AdvancedVI.eachsample(samples))
     elbo = mean(ℓπ - ℓq)
     stat = (elbo=elbo,)
-    return out, nothing, stat
+    return out, state, stat
 end
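
Note that `init` builds the preparation with one draw of samples while later `estimate_gradient!` calls differentiate at fresh samples; this is sound because a DifferentiationInterface preparation depends on the types and shapes of the arguments, not their values. A small self-contained sketch of that property, in plain DifferentiationInterface and independent of AdvancedVI:

using ADTypes, DifferentiationInterface
import ForwardDiff

g(x, c) = sum(abs2, x .- c)
backend = AutoForwardDiff()

# Prepare once at placeholder inputs...
prep = prepare_gradient(g, backend, zeros(3), Constant(zeros(3)))

# ...then reuse the preparation at different values of the same shape.
for _ in 1:5
    x, c = randn(3), randn(3)
    val, grad = value_and_gradient(g, prep, backend, x, Constant(c))
end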

src/optimize.jl

Lines changed: 3 additions & 1 deletion

@@ -68,7 +68,9 @@ function optimize(
 )
     params, restructure = Optimisers.destructure(deepcopy(q_init))
     opt_st = maybe_init_optimizer(state_init, optimizer, params)
-    obj_st = maybe_init_objective(state_init, rng, objective, problem, params, restructure)
+    obj_st = maybe_init_objective(
+        state_init, rng, objective, adtype, problem, params, restructure
+    )
     avg_st = maybe_init_averager(state_init, averager, params)
     grad_buf = DiffResults.DiffResult(zero(eltype(params)), similar(params))
     stats = NamedTuple[]

src/utils.jl

Lines changed: 2 additions & 1 deletion

@@ -25,14 +25,15 @@ function maybe_init_objective(
     state_init::NamedTuple,
     rng::Random.AbstractRNG,
     objective::AbstractVariationalObjective,
+    adtype::ADTypes.AbstractADType,
     problem,
     params,
     restructure,
 )
     if haskey(state_init, :objective)
         state_init.objective
     else
-        init(rng, objective, problem, params, restructure)
+        init(rng, objective, adtype, problem, params, restructure)
     end
 end
 

test/inference/scoregradelbo_locationscale.jl

Lines changed: 6 additions & 1 deletion

@@ -1,6 +1,11 @@
 
 AD_scoregradelbo_locationscale = if TEST_GROUP == "Enzyme"
-    Dict(:Enzyme => AutoEnzyme())
+    Dict(
+        :Enzyme => AutoEnzyme(;
+            mode=Enzyme.set_runtime_activity(Enzyme.Reverse),
+            function_annotation=Enzyme.Const,
+        ),
+    )
 else
     Dict(
         :ForwarDiff => AutoForwardDiff(),

test/interface/ad.jl

Lines changed: 26 additions & 2 deletions

@@ -2,7 +2,12 @@
 using Test
 
 AD_interface = if TEST_GROUP == "Enzyme"
-    Dict(:Enzyme => AutoEnzyme())
+    Dict(
+        :Enzyme => AutoEnzyme(;
+            mode=Enzyme.set_runtime_activity(Enzyme.Reverse),
+            function_annotation=Enzyme.Const,
+        ),
+    )
 else
     Dict(
         :ForwarDiff => AutoForwardDiff(),
@@ -20,7 +25,26 @@ end
     b = randn(D)
     grad_buf = DiffResults.GradientResult(λ)
     f(λ′, aux) = λ′' * A * λ′ / 2 + dot(aux.b, λ′)
-    AdvancedVI.value_and_gradient!(adtype, f, λ, (b=b,), grad_buf)
+    AdvancedVI._value_and_gradient!(f, grad_buf, adtype, λ, (b=b,))
+    ∇ = DiffResults.gradient(grad_buf)
+    f = DiffResults.value(grad_buf)
+    @test ∇ ≈ (A + A') * λ / 2 + b
+    @test f ≈ λ' * A * λ / 2 + dot(b, λ)
+end
+
+@testset "$(adname) with prep" for (adname, adtype) in AD_interface
+    D = 10
+    λ = randn(D)
+    A = randn(D, D)
+    grad_buf = DiffResults.GradientResult(λ)
+
+    b_prep = randn(D)
+    f(λ′, aux) = λ′' * A * λ′ / 2 + dot(aux.b, λ′)
+    prep = AdvancedVI._prepare_gradient(f, adtype, λ, (b=b_prep,))
+
+    b = randn(D)
+    AdvancedVI._value_and_gradient!(f, grad_buf, prep, adtype, λ, (b=b,))
+
     ∇ = DiffResults.gradient(grad_buf)
     f = DiffResults.value(grad_buf)
     @test ∇ ≈ (A + A') * λ / 2 + b

test/interface/repgradelbo.jl

Lines changed: 8 additions & 3 deletions

@@ -1,6 +1,11 @@
 
 AD_repgradelbo_interface = if TEST_GROUP == "Enzyme"
-    [AutoEnzyme()]
+    [
+        AutoEnzyme(;
+            mode=Enzyme.set_runtime_activity(Enzyme.Reverse),
+            function_annotation=Enzyme.Const,
+        ),
+    ]
 else
     [
         AutoForwardDiff(),
@@ -71,8 +76,8 @@ end
     aux = (
         rng=rng, obj=obj, problem=model, restructure=re, q_stop=q_true, adtype=adtype
    )
-    AdvancedVI.value_and_gradient!(
-        adtype, AdvancedVI.estimate_repgradelbo_ad_forward, params, aux, out
+    AdvancedVI._value_and_gradient!(
+        AdvancedVI.estimate_repgradelbo_ad_forward, out, adtype, params, aux
     )
     grad = DiffResults.gradient(out)
     @test norm(grad) ≈ 0 atol = 1e-5
