Skip to content

Commit bd6980c

Browse files
smolixastonzhang
authored andcommitted
adagrad
1 parent dcc4e4d commit bd6980c

File tree

2 files changed

+81
-51
lines changed

2 files changed

+81
-51
lines changed

chapter_optimization/adagrad.md

Lines changed: 77 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -1,54 +1,77 @@
11
# Adagrad
22
:label:`sec_adagrad`
33

4-
In the optimization algorithms we introduced previously, each element of the objective function's independent variables uses the same learning rate at the same time step for self-iteration. For example, if we assume that the objective function is $f$ and the independent variable is a two-dimensional vector $[x_1, x_2]^\top$, each element in the vector uses the same learning rate when iterating. For example, in gradient descent with the learning rate $\eta$, element $x_1$ and $x_2$ both use the same learning rate $\eta$ for iteration:
5-
6-
$$
7-
x_1 \leftarrow x_1 - \eta \frac{\partial{f}}{\partial{x_1}}, \quad
8-
x_2 \leftarrow x_2 - \eta \frac{\partial{f}}{\partial{x_2}}.
9-
$$
10-
11-
In :numref:`sec_momentum`, we can see that, when there is a big difference
12-
between the gradient values $x_1$ and $x_2$, a sufficiently small learning rate
13-
needs to be selected so that the independent variable will not diverge in the
14-
dimension of larger gradient values. However, this will cause the independent
15-
variables to iterate too slowly in the dimension with smaller gradient
16-
values. The momentum method relies on the exponentially weighted moving average
17-
(EWMA) to make the direction of the independent variable more consistent, thus
18-
reducing the possibility of divergence. In this section, we are going to
19-
introduce Adagrad :cite:`Duchi.Hazan.Singer.2011`, an algorithm that adjusts the learning rate according to the
20-
gradient value of the independent variable in each dimension to eliminate
21-
problems caused when a unified learning rate has to adapt to all dimensions.
4+
## Sparse Features and Learning Rates
225

6+
Imagine that we're training a language model. To get good accuracy we typically want to decrease the learning rate as we keep on training, usually at a rate of $O(t^{-\frac{1}{2}})$ or slower. Now consider a model training on sparse features, i.e. features that occur only infrequently. This is common for natural language, e.g. it is a lot less likely that we'll see the word *preconditioning* than *learning*. However, it is also common in other areas such as computational advertising and personalized collaborative filtering. After all, there are many things that are of interest only for a small number of people.
237

24-
## The Algorithm
8+
Parameters associated with infrequent features only receive meaningful updates whenever these features occur. Given a decreasing learning rate we might end up in a situation where the parameters for common features converge rather quickly to their optimal values, whereas for infrequent features we are still short of observing them sufficiently frequently before their optimal values can be determined. In other words, the learning rate either decreases too slowly for frequent features or too slowly for infrequent ones.
9+
10+
A possible hack to redress this issue would be to count the number of times we see a particular feature and to use this as a clock for adjusting learning rates. That is, rather than choosing a learning rate of the form $\eta = \frac{\eta_0}{\sqrt{t + c}}$ we could use $\eta_i = \frac{\eta_0}{\sqrt{s(i,t) + c}}$. Here $s(i,t)$ counts the number of nonzeros for feature $i$ that we have observed up to time $t$. This is actually quite easy to implement at no meaningful overhead. However, it fails whenever we don't quite have sparsity but rather just data where the gradients are often very small and only rarely large. After all, it is unclear where one would draw the line between something that qualifies as an observed feature or not.
11+
12+
Adagrad by :cite:`Duchi.Hazan.Singer.2011` addresses this by repacing the rather crude counter $s(i,t)$ by an aggregate of the squares of previously observed gradients. In particular, it uses $s(i,t+1) = s(i,t) + \left(\partial_i f(\mathbf{x})\right)^2$ as a means to adjust the learning rate. This has two benefits: firstly, we no longer need to decide just when a gradient is large enough. Secondly, it scales automatically with the magnitude of the gradients. Coordinates that routinely correspond to large gradients are scaled down significantly, whereas others with small gradients receive a much more gentle treatment. In practice this leads to a very effective optimization procedure for computational advertising and related problems. But this hides some of the additional benefits inherent in Adagrad that are best understood in the context of preconditioning.
13+
14+
15+
## Preconditioning
16+
17+
Convex optimization problems are good for analyzing the characteristics of algorithms. After all, for most nonconvex problems it is difficult to derive meaningful theoretical guarantees, but *intuition* and *insight* often carry over. Let's look at the problem of minimizing $f(\mathbf{x}) = \frac{1}{2} \mathbf{x}^\top Q \mathbf{x} + \mathbf{c}^\top \mathbf{x} + b$.
18+
19+
As we saw in :ref:`sec_momentum`, it is possible to rewrite this problem in terms of its eigendecomposition $Q = U^\top \Lambda U$ to arrive at a much simplified problem where each coordinate can be solved individually:
2520

26-
The Adagrad algorithm uses the cumulative variable $\mathbf{s}_t$ obtained from a square by element operation on the minibatch stochastic gradient $\mathbf{g}_t$. At time step 0, Adagrad initializes each element in $\mathbf{s}_0$ to 0. At time step $t$, we first sum the results of the square by element operation for the minibatch gradient $\mathbf{g}_t$ to get the variable $\mathbf{s}_t$:
21+
$$f(\mathbf{x}) = \bar{f}(\bar{\mathbf{x}}) = \frac{1}{2} \bar{\mathbf{x}}^\top \Lambda \bar{\mathbf{x}} + \bar{\mathbf{c}}^\top \bar{\mathbf{x}} + b$$
2722

28-
$$\mathbf{s}_t \leftarrow \mathbf{s}_{t-1} + \mathbf{g}_t \odot \mathbf{g}_t,$$
23+
Here we used $\mathbf{x} = U \mathbf{x}$ and consequently $\mathbf{c} = U \mathbf{c}$. The modified problem has as its minimizer $\bar{\mathbf{x}} = -\Lambda^{-1} \bar{\mathbf{c}}$ and minimum value $-\frac{1}{2} \bar{\mathbf{c}}^\top \Lambda^{-1} \bar{\mathbf{c}} + b$. This is much easier to compute since $\Lambda$ is a diagonal matrix containing the eigenvalues of $Q$.
2924

30-
Here, $\odot$ is the symbol for multiplication by element. Next, we re-adjust the learning rate of each element in the independent variable of the objective function using element operations:
25+
If we perturb $\mathbf{c}$ slightly we would hope to find only slight changes in the minimizer of $f$. Unfortunately this is not the case. While slight changes in $\mathbf{c}$ lead to equally slight changes in $\bar{\mathbf{c}}$, this is not the case for the minimizer of $f$ (and of $\bar{f}$ respectively). Whenever the eigenvalues $\lambda_i$ are large we will see only small changes in $\bar{x}_i$ and in the minimum of $\bar{f}$. Conversely, for small $\lambda_i$ changes in $\bar{x}_i$ can be dramatic. The ratio between the largest and the smallest eigenvalue is called the condition number of an optimization problem.
3126

32-
$$\mathbf{x}_t \leftarrow \mathbf{x}_{t-1} - \frac{\eta}{\sqrt{\mathbf{s}_t + \epsilon}} \odot \mathbf{g}_t,$$
27+
$$\kappa = \frac{\lambda_1}{\lambda_d}$$
3328

34-
Here, $\eta$ is the learning rate while $\epsilon$ is a constant added to maintain numerical stability, such as $10^{-6}$. Here, the square root, division, and multiplication operations are all element operations. Each element in the independent variable of the objective function will have its own learning rate after the operations by elements.
29+
If the condition number $\kappa$ is large, it is difficult to solve the optimization problem accurately. We need to ensure that we are careful in getting a large dynamic range of values right. Our analysis leads to an obvious, albeit somewhat naive question: couldn't we simply 'fix' the problem by distorting the space such that all eigenvalues are $1$. In theory this is quite easy - we only need the eigenvalues and eigenvectors of $Q$ to rescale the problem from $\mathbf{x}$ to one in $\mathbf{z} := \Lambda^{\frac{1}{2}} U \mathbf{x}$. In the new coordinate system $\mathbf{x}^\top Q \mathbf{x}$ could be simplified to $\|\mathbf{z}\|^2$. Alas, this is a rather impractical suggestion. Computing eigenvalues and eigenvectors is in general *much more* expensive than solving the actual problem.
3530

36-
## Features
31+
While computing eigenvalues exactly might be expensive, guessing them and computing them even somewhat approximately may already be a lot better than not doing anything at all. In particular, we could use the diagonal entries of $Q$ and rescale it accordingly. This is *much* cheaper than computing eigenvalues.
3732

38-
We should emphasize that the cumulative variable $\mathbf{s}_t$ produced by a square by element operation on the minibatch stochastic gradient is part of the learning rate denominator. Therefore, if an element in the independent variable of the objective function has a constant and large partial derivative, the learning rate of this element will drop faster. On the contrary, if the partial derivative of such an element remains small, then its learning rate will decline more slowly. However, since $\mathbf{s}_t$ accumulates the square by element gradient, the learning rate of each element in the independent variable declines (or remains unchanged) during iteration. Therefore, when the learning rate declines very fast during early iteration, yet the current solution is still not desirable, Adagrad might have difficulty finding a useful solution because the learning rate will be too small at later stages of iteration.
33+
$$\tilde{Q} = \mathrm{diag}^{-\frac{1}{2}}(Q) Q \mathrm{diag}^{-\frac{1}{2}}(Q).$$
3934

40-
Below we will continue to use the objective function $f(\mathbf{x})=0.1x_1^2+2x_2^2$ as an example to observe the iterative trajectory of the independent variable in Adagrad. We are going to implement Adagrad using the same learning rate as the experiment in last section, 0.4. As we can see, the iterative trajectory of the independent variable is smoother. However, due to the cumulative effect of $\mathbf{s}_t$, the learning rate continuously decays, so the independent variable does not move as much during later stages of iteration.
35+
In this case we have $\tilde{Q}_{ij} = Q_{ij} / \sqrt{Q_{ii} Q_{jj}}$ and specifically $\tilde{Q}_{ii} = 1$ for all $i$. In most cases this simplifies the condition number considerably. For instance, the the cases we discussed previously, this would entirely eliminate the problem at hand since the problem is axis aligned.
36+
37+
Unfortunately we face yet another problem - in deep learning we typically don't even have access to the second derivative of the objective function: for $\mathbf{x} \in \mathbb{R}^d$ the second derivative even on a minibatch may require $O(d^2)$ space and work to compute, thus making it practically infeasible. The ingenious idea of Adagrad is to use a proxy for that elusive diagonal of the Hessian that is both relatively cheap to compute and effective - the magnitude of the gradient itself.
38+
39+
In order to see why this works, let's look at $\bar{f}(\bar{\mathbf{x}})$. We have that
40+
41+
$$\partial_{\bar{\mathbf{x}}} \bar{f}(\bar{\mathbf{x}}) = \Lambda \bar{\mathbf{x}} + \bar{\mathbf{c}} = \Lambda \left(\bar{\mathbf{x}} - \bar{\mathbf{x}}_0\right)$$
42+
43+
where $\bar{\mathbf{x}}_0$ is the minimizer of $\bar{f}$. Hence the magnitude of the gradient depends both on $\Lambda$ and the distance from optimality. If $\bar{\mathbf{x}} - \bar{\mathbf{x}}_0$ didn't change, this would be all that's needed. After all, in this case the magnitude of the gradient $\partial_{\bar{\mathbf{x}}} \bar{f}(\bar{\mathbf{x}})$ suffices. Since AdaGrad is a stochastic gradient descent algorithm, we will see gradients with nonzero variance even at optimality. As a result we can safely use the variance of the gradients as a cheap proxy for the scale of the Hessian. A thorough analysis is beyond the scope of this section (it would be several pages). We refer the reader to :cite:`Duchi.Hazan.Singer.2011` for details.
44+
45+
## The Algorithm
4146

42-
```{.python .input n=1}
47+
Let's formalize the discussion from above. We use the variable $\mathbf{s}_t$ to accumulate past gradient variance as follows.
48+
49+
$$\begin{aligned}
50+
\mathbf{g}_t & = \partial_{\mathbf{w}} l(y_t, f(\mathbf{x}_t, \mathbf{w})) \\
51+
\mathbf{s}_t & = \mathbf{s}_{t-1} + \mathbf{g}_t^2 \\
52+
\mathbf{w}_t & = \mathbf{w}_{t-1} - \frac{\eta}{\sqrt{\mathbf{s}_t + \epsilon}} \cdot \mathbf{g}_t
53+
\end{aligned}$$
54+
55+
Here the operation are applied coordinate wise. That is, $\mathbf{v}^2$ has entries $v_i^2$. Likewise $\frac{1}{\sqrt{v}}$ has entries $\frac{1}{\sqrt{v_i}}$ and $\mathbf{u} \cdot \mathbf{v}$ has entries $u_i v_i$. As before $\eta$ is the learning rate and $\epsilon$ is an additive constant that ensures that we do not divide by $0$. Lastly, we initialize $\mathbf{s}_0 = \mathbf{0}$.
56+
57+
Just like in the case of momentum we need to keep track of an auxiliary variable, in this case to allow for an individual learning rate per coordinate. This doesn't increase the cost of Adagrad significantly relative to SGD, simply since the main cost is typically to compute $l(y_t, f(\mathbf{x}_t, \mathbf{w}))$ and its derivative.
58+
59+
Note that accumulating squared gradients in $\mathbf{s}_t$ means that $\mathbf{s}_t$ grows essentially at linear rate (somewhat slower than linearly in practice, since the gradients initially diminish). This leads to an $O(t^{-\frac{1}{2}})$ learning rate, albeit adjusted on a per coordinate basis. For convex problems this is perfectly adequate. In deep learning, though, we might want to decrease the learning rate rather more slowly. This led to a number of Adagrad variants that we will discuss in the subsequent chapters. For now let's see how it behaves in a quadratic convex problem. We use the same problem as before:
60+
61+
$$f(\mathbf{x}) = 0.1 x_1^2 + 2 x_2^2$$
62+
63+
We are going to implement Adagrad using the same learning rate previously, i.e. $\eta = 0.4$. As we can see, the iterative trajectory of the independent variable is smoother. However, due to the cumulative effect of $\boldsymbol{s}_t$, the learning rate continuously decays, so the independent variable does not move as much during later stages of iteration.
64+
65+
```{.python .input n=6}
4366
%matplotlib inline
4467
import d2l
4568
import math
4669
from mxnet import np, npx
4770
npx.set_np()
4871
4972
def adagrad_2d(x1, x2, s1, s2):
50-
# The first two terms are the independent variable gradients
51-
g1, g2, eps = 0.2 * x1, 4 * x2, 1e-6
73+
eps = 1e-6
74+
g1, g2 = 0.2 * x1, 4 * x2
5275
s1 += g1 ** 2
5376
s2 += g2 ** 2
5477
x1 -= eta / math.sqrt(s1 + eps) * g1
@@ -62,19 +85,18 @@ eta = 0.4
6285
d2l.show_trace_2d(f_2d, d2l.train_2d(adagrad_2d))
6386
```
6487

65-
Now, we are going to increase the learning rate to $2$. As we can see, the independent variable approaches the optimal solution more quickly.
88+
As we increase the learning rate to $2$ we see much better behavior. This already indicates that the decrease in learning rate might be rather aggressive, even in the noise-free case and we need to ensure that parameters converge appropriately.
6689

67-
```{.python .input n=2}
90+
```{.python .input n=10}
6891
eta = 2
6992
d2l.show_trace_2d(f_2d, d2l.train_2d(adagrad_2d))
7093
```
7194

7295
## Implementation from Scratch
7396

74-
Like the momentum method, Adagrad needs to maintain a state variable of the same shape for each independent variable. We use the formula from the algorithm to implement Adagrad.
75-
76-
```{.python .input n=3}
97+
Just like the momentum method, Adagrad needs to maintain a state variable of the same shape as the parameters.
7798

99+
```{.python .input n=8}
78100
def init_adagrad_states(feature_dim):
79101
s_w = np.zeros((feature_dim, 1))
80102
s_b = np.zeros(1)
@@ -87,34 +109,46 @@ def adagrad(params, states, hyperparams):
87109
p[:] -= hyperparams['lr'] * p.grad / np.sqrt(s + eps)
88110
```
89111

90-
Compared with the experiment in :numref:`sec_minibatch_sgd`, here, we use a
112+
Compared to the experiment in :numref:`sec_minibatch_sgd` we use a
91113
larger learning rate to train the model.
92114

93-
```{.python .input n=4}
115+
```{.python .input n=9}
94116
data_iter, feature_dim = d2l.get_data_ch10(batch_size=10)
95-
d2l.train_ch10(adagrad, init_adagrad_states(feature_dim),
117+
d2l.train_ch10(adagrad, init_adagrad_states(feature_dim),
96118
{'lr': 0.1}, data_iter, feature_dim);
97119
```
98120

99121
## Concise Implementation
100122

101-
Using the `Trainer` instance of the algorithm named “adagrad, we can implement the Adagrad algorithm with Gluon to train models.
123+
Using the `Trainer` instance of the algorithm `adagrad`, we can invoke the Adagrad algorithm in Gluon.
102124

103125
```{.python .input n=5}
104126
d2l.train_gluon_ch10('adagrad', {'learning_rate': 0.1}, data_iter)
105127
```
106128

107129
## Summary
108130

109-
* Adagrad constantly adjusts the learning rate during iteration to give each element in the independent variable of the objective function its own learning rate.
110-
* When using Adagrad, the learning rate of each element in the independent variable decreases (or remains unchanged) during iteration.
131+
* Adagrad decreases the learning rate dynamically on a per-coordinate basis.
132+
* It uses the magnitude of the gradient as a means of adjusting how quickly progress is achieved - coordinates with large gradients are compensated with a smaller learning rate.
133+
* Computing the exact second derivative is typically infeasible in deep learing problems due to memory and computational constraints. The gradient can be a useful proxy.
134+
* If the optimization problem has a rather uneven uneven structure Adagrad can help mitigate the distortion.
135+
* Adagrad is particularly effective for sparse features where the learning rate needs to decrease more slowly for infrequently occurring terms.
136+
* On deep learning problems Adagrad can sometimes be too aggressive in reducing learning rates. We will discuss strategies for mitigating this in the context of :ref:`sec_adam`.
111137

112138
## Exercises
113139

114-
* When introducing the features of Adagrad, we mentioned a potential problem. What solutions can you think of to fix this problem?
115-
* Try to use other initial learning rates in the experiment. How does this change the results?
140+
1. Prove that for an orthogonal matrix $U$ and a vector $\mathbf{c}$ the following holds: $\|\mathbf{c} - \mathbf{\delta}\|_2 = \|U \mathbf{c} - U \mathbf{\delta}\|_2$. Why does this mean that the magnitude of perturbations does not change after an orthogonal change of variables?
141+
1. Try out Adagrad for $f(\mathbf{x}) = 0.1 x_1^2 + 2 x_2^2$ and also for the objective function was rotated by 45 degrees, i.e. $f(\mathbf{x}) = 0.1 (x_1 + x_2)^2 + 2 (x_1 - x_2)^2$. Does it behave differently?
142+
1. Prove [Gerschgorin's circle theorem](https://en.wikipedia.org/wiki/Gershgorin_circle_theorem) which states that eigenvalues $\lambda_i$ of a matrix $M$ satisfy $|\lambda_i - M_{jj}| \leq \sum_{k \neq j} |M_{jk}|$ for at least one choice of $j$.
143+
1. What does Gerschgorin's theorem tell us about the eigenvalues of the diagonally preconditioned matrix $\mathrm{diag}^{-\frac{1}{2}}(M) M \mathrm{diag}^{-\frac{1}{2}}(M)$?
144+
1. Try out Adagrad for a proper deep network, such as :ref:`sec_lenet` when applied to Fashion MNIST.
145+
1. How would you need to modify Adagrad to achieve a less aggressive decay in learning rate?
116146

117147

118148
## Scan the QR Code to [Discuss](https://discuss.mxnet.io/t/2375)
119149

120150
![](../img/qr_adagrad.svg)
151+
152+
```{.python .input}
153+
154+
```

0 commit comments

Comments
 (0)