
Commit 54dff15

Red-Portal, github-actions[bot], and yebai authored

refactor interface for projections/proximal operators (#147)

* fix outdated type parameters in `LocationScale`
* add `operator` keyword argument to `optimize` so that projection/proximal operators can have their own interface
* fix benchmark

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: Hong Ge <[email protected]>
1 parent 04a894a · commit 54dff15

22 files changed: +241 −221 lines

README.md

Lines changed: 1 addition & 0 deletions

```diff
@@ -109,6 +109,7 @@ q_avg, _, stats, _ = AdvancedVI.optimize(
     max_iter;
     adtype=ADTypes.AutoForwardDiff(),
     optimizer=Optimisers.Adam(1e-3),
+    operator=ClipScale(),
 )
 
 # Evaluate final ELBO with 10^3 Monte Carlo samples
```
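Read together with its surrounding context, the hunk above extends a single `AdvancedVI.optimize` call. A hedged sketch of the full call for orientation; `model`, `objective`, and `q0` are placeholder names standing in for README context elided from the hunk:

```julia
using AdvancedVI, ADTypes, Optimisers

# Sketch only: `model`, `objective`, `q0`, and `max_iter` are placeholders
# for context not shown in the hunk above.
q_avg, _, stats, _ = AdvancedVI.optimize(
    model,
    objective,
    q0,
    max_iter;
    adtype=ADTypes.AutoForwardDiff(),
    optimizer=Optimisers.Adam(1e-3),
    operator=ClipScale(),  # the new keyword: a projection applied after each update
)
```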

bench/benchmarks.jl

Lines changed: 2 additions & 0 deletions

```diff
@@ -41,6 +41,7 @@ begin
     max_iter = 10^4
     d = LogDensityProblems.dimension(prob)
     optimizer = Optimisers.Adam(T(1e-3))
+    operator = ClipScale()
 
     for (objname, obj) in [
         ("RepGradELBO", RepGradELBO(10)),
@@ -73,6 +74,7 @@ begin
                 $max_iter;
                 adtype=$adtype,
                 optimizer=$optimizer,
+                operator=$operator,
                 show_progress=false,
             )
         end
```
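The `$`-prefixed names in the second hunk follow the BenchmarkTools interpolation convention, which splices values into the benchmarked expression so that global-variable lookup is excluded from the measurement. A minimal, self-contained sketch of the pattern (hypothetical benchmark, not from this file):

```julia
using BenchmarkTools

x = randn(1000)

# `$x` interpolates the value of `x` into the expression under test,
# avoiding the overhead of referencing a non-constant global.
@benchmark sum($x)
```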

docs/src/elbo/repgradelbo.md

Lines changed: 3 additions & 0 deletions

```diff
@@ -219,6 +219,7 @@ _, _, stats_cfe, _ = AdvancedVI.optimize(
     show_progress = false,
     adtype = AutoForwardDiff(),
     optimizer = Optimisers.Adam(3e-3),
+    operator = ClipScale(),
     callback = callback,
 );
 
@@ -230,6 +231,7 @@ _, _, stats_stl, _ = AdvancedVI.optimize(
     show_progress = false,
     adtype = AutoForwardDiff(),
     optimizer = Optimisers.Adam(3e-3),
+    operator = ClipScale(),
     callback = callback,
 );
 
@@ -317,6 +319,7 @@ _, _, stats_qmc, _ = AdvancedVI.optimize(
     show_progress = false,
     adtype = AutoForwardDiff(),
     optimizer = Optimisers.Adam(3e-3),
+    operator = ClipScale(),
     callback = callback,
 );
 
```

docs/src/examples.md

Lines changed: 4 additions & 0 deletions

````diff
@@ -118,10 +118,14 @@ q_avg_trans, q_trans, stats, _ = AdvancedVI.optimize(
     show_progress=false,
     adtype=AutoForwardDiff(),
     optimizer=Optimisers.Adam(1e-3),
+    operator=ClipScale(),
 );
 nothing
 ```
 
+`ClipScale` is a projection operator, which ensures that the variational approximation stays within a stable region of the variational family.
+For more information see [this section](@ref clipscale).
+
 `q_avg_trans` is the final output of the optimization procedure.
 If a parameter averaging strategy is used through the keyword argument `averager`, `q_avg_trans` will be the output of the averaging strategy, while `q_trans` is the last iterate.
````
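Concretely, the projection performed by `ClipScale` can be written as an elementwise map over the diagonal of the scale matrix, with ϵ the clipping threshold; this restates the prose added above:

```math
\sigma_{ii} \leftarrow \max(\sigma_{ii}, \epsilon), \qquad i = 1, \dots, d
```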

docs/src/optimization.md

Lines changed: 15 additions & 0 deletions

````diff
@@ -26,3 +26,18 @@ PolynomialAveraging
 [^DCAMHV2020]: Dhaka, A. K., Catalina, A., Andersen, M. R., Magnusson, M., Huggins, J., & Vehtari, A. (2020). Robust, accurate stochastic optimization for variational inference. Advances in Neural Information Processing Systems, 33, 10961-10973.
 [^KMJ2024]: Khaled, A., Mishchenko, K., & Jin, C. (2023). Dowg unleashed: An efficient universal parameter-free gradient descent method. Advances in Neural Information Processing Systems, 36, 6748-6769.
 [^IHC2023]: Ivgi, M., Hinder, O., & Carmon, Y. (2023). Dog is sgd's best friend: A parameter-free dynamic step size schedule. In International Conference on Machine Learning (pp. 14465-14499). PMLR.
+## Operators
+
+Depending on the variational family, variational objective, and optimization strategy, it might be necessary to modify the variational parameters after performing a gradient-based update.
+For this, an operator acting on the parameters can be supplied via the `operator` keyword argument of `AdvancedVI.optimize`.
+
+### [`ClipScale`](@id clipscale)
+
+For location-scale families, it is often the case that optimization is stable only when the smallest eigenvalue of the scale matrix is strictly positive[^D2020].
+To ensure this, we provide the following projection operator:
+
+```@docs
+ClipScale
+```
+
+[^D2020]: Domke, J. (2020). Provable smoothness guarantees for black-box variational inference. In *International Conference on Machine Learning*.
````
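Since the new section presents `operator` as an open extension point, a user-defined operator should only need to subtype `AbstractOperator` and overload `apply`. A hedged sketch under that assumption; `ClipLocation` and its radius are hypothetical and not part of AdvancedVI:

```julia
using AdvancedVI, Optimisers

# Hypothetical operator: clamp the location parameters into a box.
# Assumes the variational family exposes a `location` field (e.g. `MvLocationScale`).
struct ClipLocation <: AdvancedVI.AbstractOperator
    radius::Float64
end

function AdvancedVI.apply(op::ClipLocation, ::Type, params, restructure)
    q = restructure(params)
    @. q.location = clamp(q.location, -op.radius, op.radius)
    params, _ = Optimisers.destructure(q)
    return params
end

# Supplied like any built-in operator:
# AdvancedVI.optimize(...; operator=ClipLocation(10.0))
```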

ext/AdvancedVIBijectorsExt.jl

Lines changed: 20 additions & 6 deletions

```diff
@@ -15,24 +15,38 @@ else
 using ..Random
 end
 
-function AdvancedVI.update_variational_params!(
+function AdvancedVI.apply(
+    op::ClipScale,
     ::Type{<:Bijectors.TransformedDistribution{<:AdvancedVI.MvLocationScale}},
-    opt_st,
     params,
     restructure,
-    grad,
 )
-    opt_st, params = Optimisers.update!(opt_st, params, grad)
     q = restructure(params)
-    ϵ = q.dist.scale_eps
+    ϵ = convert(eltype(params), op.epsilon)
 
     # Project the scale matrix to the set of positive definite triangular matrices
     diag_idx = diagind(q.dist.scale)
     @. q.dist.scale[diag_idx] = max(q.dist.scale[diag_idx], ϵ)
 
     params, _ = Optimisers.destructure(q)
 
-    return opt_st, params
+    return params
+end
+
+function AdvancedVI.apply(
+    op::ClipScale,
+    ::Type{<:Bijectors.TransformedDistribution{<:AdvancedVI.MvLocationScaleLowRank}},
+    params,
+    restructure,
+)
+    q = restructure(params)
+    ϵ = convert(eltype(params), op.epsilon)
+
+    @. q.dist.scale_diag = max(q.dist.scale_diag, ϵ)
+
+    params, _ = Optimisers.destructure(q)
+
+    return params
 end
 
 function AdvancedVI.reparam_with_entropy(
```
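The projection itself is just an elementwise clamp on the diagonal of the scale. A small standalone demonstration with made-up numbers, not taken from the commit:

```julia
using LinearAlgebra

scale = [1e-8 0.0; 0.3 0.5]  # first diagonal entry has collapsed below the threshold
ϵ = 1e-4

# Clamp only the diagonal, as in the `apply` methods above.
diag_idx = diagind(scale)
@. scale[diag_idx] = max(scale[diag_idx], ϵ)

scale  # diagonal is now [1e-4, 0.5]; off-diagonal entries are untouched
```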

src/AdvancedVI.jl

Lines changed: 35 additions & 30 deletions

```diff
@@ -60,34 +60,6 @@ This is an indirection for handling the type stability of `restructure`, as some
 """
 restructure_ad_forward(::ADTypes.AbstractADType, restructure, params) = restructure(params)
 
-# Update for gradient descent step
-"""
-    update_variational_params!(family_type, opt_st, params, restructure, grad)
-
-Update variational distribution according to the update rule in the optimizer state `opt_st` and the variational family `family_type`.
-
-This is a wrapper around `Optimisers.update!` to provide some indirection.
-For example, depending on the optimizer and the variational family, this may do additional things such as applying projection or proximal mappings.
-Same as the default behavior of `Optimisers.update!`, `params` and `opt_st` may be updated by the routine and are no longer valid after calling this functino.
-Instead, the return values should be used.
-
-# Arguments
-- `family_type::Type`: Type of the variational family `typeof(restructure(params))`.
-- `opt_st`: Optimizer state returned by `Optimisers.setup`.
-- `params`: Current set of parameters to be updated.
-- `restructure`: Callable for restructuring the varitional distribution from `params`.
-- `grad`: Gradient to be used by the update rule of `opt_st`.
-
-# Returns
-- `opt_st`: Updated optimizer state.
-- `params`: Updated parameters.
-"""
-function update_variational_params! end
-
-function update_variational_params!(::Type, opt_st, params, restructure, grad)
-    return Optimisers.update!(opt_st, params, grad)
-end
-
 # estimators
 """
     AbstractVariationalObjective
@@ -149,7 +121,7 @@ Estimate (possibly stochastic) gradients of the variational objective `obj` targ
 - `out::DiffResults.MutableDiffResult`: Buffer containing the objective value and gradient estimates.
 - `prob`: The target log-joint likelihood implementing the `LogDensityProblem` interface.
 - `params`: Variational parameters to evaluate the gradient on.
-- `restructure`: Function that reconstructs the variational approximation from `λ`.
+- `restructure`: Function that reconstructs the variational approximation from `params`.
 - `obj_state`: Previous state of the objective.
 
 # Returns
@@ -215,7 +187,7 @@ Initialize the state of the averaging strategy `avg` with the initial parameters
 init(::AbstractAverager, ::Any) = nothing
 
 """
-    apply(avg, avg_st, params)
+    apply(avg::AbstractAverager, avg_st, params)
 
 Apply averaging strategy `avg` on `params` given the state `avg_st`.
 
@@ -241,6 +213,39 @@ include("optimization/averaging.jl")
 
 export NoAveraging, PolynomialAveraging
 
+# Operators for Optimization
+abstract type AbstractOperator end
+
+"""
+    apply(op::AbstractOperator, family, params, restructure)
+
+Apply operator `op` on the variational parameters `params`. For instance, `op` could be a projection or proximal operator.
+
+# Arguments
+- `op::AbstractOperator`: Operator operating on the parameters `params`.
+- `family::Type`: Type of the variational approximation `restructure(params)`.
+- `params`: Variational parameters.
+- `restructure`: Function that reconstructs the variational approximation from `params`.
+
+# Returns
+- `oped_params`: Parameters resulting from applying the operator.
+"""
+function apply(::AbstractOperator, ::Type, ::Any, ::Any) end
+
+"""
+    IdentityOperator()
+
+Identity operator.
+"""
+struct IdentityOperator <: AbstractOperator end
+
+apply(::IdentityOperator, ::Type, params, restructure) = params
+
+include("optimization/clip_scale.jl")
+
+export IdentityOperator, ClipScale
+
+# Main optimization routine
 function optimize end
 
 export optimize
```
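With `update_variational_params!` removed, the gradient step and the operator are presumably decoupled inside `optimize`. A hedged sketch of what one iteration might look like under the new interface; the actual loop in `optimize` may differ in detail:

```julia
using AdvancedVI, Optimisers

# Sketch of a single step: plain optimizer update, then the operator.
function step_sketch(operator, opt_st, params, restructure, grad)
    opt_st, params = Optimisers.update!(opt_st, params, grad)
    family = typeof(restructure(params))  # the variational family type
    params = AdvancedVI.apply(operator, family, params, restructure)
    return opt_st, params
end
```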

src/families/location_scale.jl

Lines changed: 16 additions & 60 deletions

````diff
@@ -1,14 +1,6 @@
 
-struct MvLocationScale{S,D<:ContinuousDistribution,L,E<:Real} <:
-       ContinuousMultivariateDistribution
-    location::L
-    scale::S
-    dist::D
-    scale_eps::E
-end
-
 """
-    MvLocationScale(location, scale, dist; scale_eps)
+    MvLocationScale(location, scale, dist)
 
 The location scale variational family broadly represents various variational
 families using `location` and `scale` variational parameters.
@@ -20,21 +12,11 @@ represented as follows:
     u = rand(dist, d)
     z = scale*u + location
 ```
-
-`scale_eps` sets a constraint on the smallest value of `scale` to be enforced during optimization.
-This is necessary to guarantee stable convergence.
-
-# Keyword Arguments
-- `scale_eps`: Lower bound constraint for the diagonal of the scale. (default: `1e-4`).
 """
-function MvLocationScale(
-    location::AbstractVector{T},
-    scale::AbstractMatrix{T},
-    dist::ContinuousUnivariateDistribution;
-    scale_eps::T=T(1e-4),
-) where {T<:Real}
-    @assert minimum(diag(scale)) ≥ scale_eps "Initial scale is too small (smallest diagonal value is $(minimum(diag(scale)))). This might result in unstable optimization behavior."
-    return MvLocationScale(location, scale, dist, scale_eps)
+struct MvLocationScale{S,D<:ContinuousDistribution,L} <: ContinuousMultivariateDistribution
+    location::L
+    scale::S
+    dist::D
 end
 
 Functors.@functor MvLocationScale (location, scale)
@@ -44,18 +26,18 @@ Functors.@functor MvLocationScale (location, scale)
 # `scale <: Diagonal`, which is not the default behavior. Otherwise, forward-mode AD
 # is very inefficient.
 # begin
-struct RestructureMeanField{S<:Diagonal,D,L,E}
-    model::MvLocationScale{S,D,L,E}
+struct RestructureMeanField{S<:Diagonal,D,L}
+    model::MvLocationScale{S,D,L}
 end
 
 function (re::RestructureMeanField)(flat::AbstractVector)
     n_dims = div(length(flat), 2)
     location = first(flat, n_dims)
     scale = Diagonal(last(flat, n_dims))
-    return MvLocationScale(location, scale, re.model.dist, re.model.scale_eps)
+    return MvLocationScale(location, scale, re.model.dist)
 end
 
-function Optimisers.destructure(q::MvLocationScale{<:Diagonal,D,L,E}) where {D,L,E}
+function Optimisers.destructure(q::MvLocationScale{<:Diagonal,D,L}) where {D,L}
     (; location, scale, dist) = q
     flat = vcat(location, diag(scale))
     return flat, RestructureMeanField(q)
@@ -66,7 +48,7 @@ Base.length(q::MvLocationScale) = length(q.location)
 
 Base.size(q::MvLocationScale) = size(q.location)
 
-Base.eltype(::Type{<:MvLocationScale{S,D,L,E}}) where {S,D,L,E} = eltype(D)
+Base.eltype(::Type{<:MvLocationScale{S,D,L}}) where {S,D,L} = eltype(D)
 
 function StatsBase.entropy(q::MvLocationScale)
     (; location, scale, dist) = q
@@ -131,55 +113,29 @@ function Distributions.cov(q::MvLocationScale)
 end
 
 """
-    FullRankGaussian(μ, L; scale_eps)
+    FullRankGaussian(μ, L)
 
 Construct a Gaussian variational approximation with a dense covariance matrix.
 
 # Arguments
 - `μ::AbstractVector{T}`: Mean of the Gaussian.
 - `L::LinearAlgebra.AbstractTriangular{T}`: Cholesky factor of the covariance of the Gaussian.
-
-# Keyword Arguments
-- `scale_eps`: Smallest value allowed for the diagonal of the scale. (default: `1e-4`).
 """
 function FullRankGaussian(
-    μ::AbstractVector{T}, L::LinearAlgebra.AbstractTriangular{T}; scale_eps::T=T(1e-4)
+    μ::AbstractVector{T}, L::LinearAlgebra.AbstractTriangular{T}
 ) where {T<:Real}
-    q_base = Normal{T}(zero(T), one(T))
-    return MvLocationScale(μ, L, q_base, scale_eps)
+    return MvLocationScale(μ, L, Normal{T}(zero(T), one(T)))
 end
 
 """
-    MeanFieldGaussian(μ, L; scale_eps)
+    MeanFieldGaussian(μ, L)
 
 Construct a Gaussian variational approximation with a diagonal covariance matrix.
 
 # Arguments
 - `μ::AbstractVector{T}`: Mean of the Gaussian.
 - `L::Diagonal{T}`: Diagonal Cholesky factor of the covariance of the Gaussian.
-
-# Keyword Arguments
-- `scale_eps`: Smallest value allowed for the diagonal of the scale. (default: `1e-4`).
 """
-function MeanFieldGaussian(
-    μ::AbstractVector{T}, L::Diagonal{T}; scale_eps::T=T(1e-4)
-) where {T<:Real}
-    q_base = Normal{T}(zero(T), one(T))
-    return MvLocationScale(μ, L, q_base, scale_eps)
-end
-
-function update_variational_params!(
-    ::Type{<:MvLocationScale}, opt_st, params, restructure, grad
-)
-    opt_st, params = Optimisers.update!(opt_st, params, grad)
-    q = restructure(params)
-    ϵ = q.scale_eps
-
-    # Project the scale matrix to the set of positive definite triangular matrices
-    diag_idx = diagind(q.scale)
-    @. q.scale[diag_idx] = max(q.scale[diag_idx], ϵ)
-
-    params, _ = Optimisers.destructure(q)
-
-    return opt_st, params
+function MeanFieldGaussian(μ::AbstractVector{T}, L::Diagonal{T}) where {T<:Real}
+    return MvLocationScale(μ, L, Normal{T}(zero(T), one(T)))
 end
````
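After this refactor the families no longer carry `scale_eps`, so construction is purely structural and the stability constraint is opted into at `optimize` time. A hedged usage sketch:

```julia
using AdvancedVI, LinearAlgebra

d = 3
q_mf = MeanFieldGaussian(zeros(d), Diagonal(ones(d)))
q_fr = FullRankGaussian(zeros(d), LowerTriangular(Matrix(1.0I, d, d)))

# The ϵ-constraint now lives in the operator rather than the family:
# AdvancedVI.optimize(...; operator=ClipScale())
```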
