Commit 7b0b712

[DOC] Use case for retain_graph (#302)
Adds a tutorial for BackPACK's `retain_graph` option. It shows how to distribute the GGN diagonal computation of an auto-encoder architecture over multiple backward passes to reduce peak memory. This use case recently came up in a discussion with @wiseodd on Laplace approximations for auto-encoders (or any neural network with a large output and square loss).

* [ADD] Prototype of `retain_graph` example
* [DOC] Add comments to `retain_graph` example
* [REF] Improve comments
* [REF] Improve title format
1 parent d4530c0 commit 7b0b712

1 file changed: +365 −0 lines
r"""BackPACK's retain_graph option
==================================

This tutorial demonstrates how to perform multiple backward passes through the
same computation graph with BackPACK. This option can be useful if you run into
out-of-memory errors. If your computation can be chunked, you might consider
distributing it over multiple backward passes to reduce peak memory.

Our use case for such a chunked computation is the GGN diagonal of an
auto-encoder's reconstruction error.

But first, the imports:
"""

from functools import partial
from time import time
from typing import List

from memory_profiler import memory_usage
from torch import Tensor, allclose, manual_seed, rand, zeros_like
from torch.nn import Conv2d, ConvTranspose2d, Flatten, MSELoss, Sequential, Sigmoid

from backpack import backpack, extend
from backpack.custom_module.slicing import Slicing
from backpack.extensions import DiagGGNExact

# make deterministic
manual_seed(0)

# %%
#
# Setup
# -----
#
# Let :math:`f_{\mathbf{\theta}}` denote the auto-encoder, and
# :math:`\mathbf{x}' = f_{\mathbf{\theta}}(\mathbf{x}) \in \mathbb{R}^M` its
# reconstruction of an input :math:`\mathbf{x} \in \mathbb{R}^M`. The
# associated reconstruction error is measured by the mean squared error
#
# .. math::
#     \ell(\mathbf{\theta})
#     =
#     \frac{1}{M}
#     \left\lVert f_{\mathbf{\theta}}(\mathbf{x}) - \mathbf{x} \right\rVert^2_2\,.
#
# On a batch of :math:`N` examples, :math:`\mathbf{x}_1, \dots, \mathbf{x}_N`,
# the loss is
#
# .. math::
#     \mathcal{L}(\mathbf{\theta})
#     =
#     \frac{1}{N} \frac{1}{M}
#     \sum_{n=1}^N
#     \left\lVert f_{\mathbf{\theta}}(\mathbf{x}_n) - \mathbf{x}_n \right\rVert^2_2\,.
#
# Let's create a toy model and data:

# data
batch_size, channels, spatial_dims = 5, 3, (32, 32)
X = rand(batch_size, channels, *spatial_dims)

# model (auto-encoder)
hidden_channels = 10

encoder = Sequential(
    Conv2d(channels, hidden_channels, 3),
    Sigmoid(),
)
decoder = Sequential(
    ConvTranspose2d(hidden_channels, channels, 3),
    Flatten(),
)
model = Sequential(
    encoder,
    decoder,
)
loss_func = MSELoss()

# %%
#
# We will use BackPACK to compute the GGN diagonal of the mini-batch loss. To
# do that, we need to :py:func:`extend <backpack.extend>` the model and the
# loss function.

model = extend(model)
loss_func = extend(loss_func)

# %%
#
# GGN diagonal in one backward pass
# ---------------------------------
#
# As usual, we can compute the GGN diagonal for the mini-batch loss in a single
# backward pass. The following function does that:


def diag_ggn_one_pass() -> List[Tensor]:
    """Compute the GGN diagonal in a single backward pass.

    Returns:
        GGN diagonal in parameter list format.
    """
    reconstruction = model(X)
    error = loss_func(reconstruction, X.flatten(start_dim=1))

    with backpack(DiagGGNExact()):
        error.backward()

    return [p.diag_ggn_exact.clone() for p in model.parameters() if p.requires_grad]

# %%
#
# Let's run it and determine (i) its peak memory consumption and (ii) its run
# time.

print("GGN diagonal in one backward pass:")
start = time()
max_mem, diag_ggn = memory_usage(
    diag_ggn_one_pass, interval=1e-3, max_usage=True, retval=True
)
end = time()

print(f"\tPeak memory [MiB]: {max_mem:.2e}")
print(f"\tTime [s]: {end - start:.2e}")

# %%
#
# The memory consumption is pretty high, although our model is relatively
# small! If we make the model deeper, or increase the mini-batch size, we will
# quickly run out of memory.
#
# This is because the cost of computing the GGN diagonal scales with the
# network's output dimension. For classification settings like MNIST and
# CIFAR-10, this dimension is relatively small (:code:`10`). But for an
# auto-encoder, it equals the input dimension :code:`M`, which in our case is

print(f"Output dimension: {model(X).shape[1:].numel()}")
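
# %%
#
# A back-of-the-envelope estimate of the scaling (our own addition, assuming
# the exact GGN computation backpropagates one output-sized vector per output
# dimension and sample, i.e. roughly ``N * M * M`` float32 entries at the
# output layer):

out_dim = model(X).shape[1:].numel()
backprop_floats = batch_size * out_dim**2  # ~ N * M * M entries
print(f"Estimated backpropagated tensor size [MiB]: {4 * backprop_floats / 2**20:.2e}")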

# %%
#
# We will now take a look at how to circumvent the high peak memory by
# distributing the computation over multiple backward passes.

# %%
#
# GGN diagonal in chunks
# ----------------------
#
# The GGN diagonal computation can be distributed across multiple backward
# passes. This greatly reduces peak memory consumption.
#
# To see this, let's consider the GGN diagonal for a single example
# :math:`\mathbf{x}`,
#
# .. math::
#     \mathrm{diag}
#     \left(
#     \left[
#     \mathbf{J}_{\mathbf{\theta}} f_{\mathbf{\theta}}(\mathbf{x})
#     \right]^\top
#     \left[
#     \frac{2}{M} \mathbf{I}_{M\times M}
#     \right]
#     \mathbf{J}_{\mathbf{\theta}} f_{\mathbf{\theta}}(\mathbf{x})
#     \right)\,,
#
# with the :math:`M \times |\mathbf{\theta}|` Jacobian
# :math:`\mathbf{J}_{\mathbf{\theta}} f_{\mathbf{\theta}}(\mathbf{x})` of the
# model, and :math:`\frac{2}{M} \mathbf{I}_{M\times M}` the mean squared
# error's Hessian w.r.t. the reconstructed input. Here you can see that the
# memory consumption scales with the output dimension, as we need to compute
# :code:`M` vector-Jacobian products.
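
# %%
#
# To make the vector-Jacobian product view concrete, here is an inefficient,
# purely illustrative sketch (our own addition, not BackPACK's implementation)
# that accumulates a single example's GGN diagonal from one VJP per output
# dimension. We only loop over a few dimensions to keep it cheap:

from torch import autograd

out_single = model(X[:1]).flatten()  # one example's reconstruction
M_out = out_single.numel()
params = [p for p in model.parameters() if p.requires_grad]
naive_diag = [zeros_like(p) for p in params]

for m in range(3):  # the full computation would loop over all M dimensions
    # one VJP: extracts the m-th row of the Jacobian
    row = autograd.grad(out_single[m], params, retain_graph=True)
    for buf, r in zip(naive_diag, row):
        buf += 2 / M_out * r**2  # contribution of output dimension m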

# %%
#
# Let :math:`S`, the chunk size, be a number that divides the output dimension
# :math:`M`. Then, we can decompose the above computation into chunks:
#
# .. math::
#     \frac{S}{M}
#     \left\{
#     \mathrm{diag}
#     \left(
#     \left[
#     \mathbf{J}_{\mathbf{\theta}} f_{\mathbf{\theta}}(\mathbf{x})
#     \right]^\top_{:, 0:S}
#     \left[
#     \frac{2}{S} \mathbf{I}_{S\times S}
#     \right]
#     \left[
#     \mathbf{J}_{\mathbf{\theta}} f_{\mathbf{\theta}}(\mathbf{x})
#     \right]_{0:S, :}
#     \right) \right.
#     \\
#     +
#     \left.
#     \mathrm{diag}
#     \left(
#     \left[
#     \mathbf{J}_{\mathbf{\theta}} f_{\mathbf{\theta}}(\mathbf{x})
#     \right]^\top_{:, S:2S}
#     \left[
#     \frac{2}{S} \mathbf{I}_{S\times S}
#     \right]
#     \left[
#     \mathbf{J}_{\mathbf{\theta}} f_{\mathbf{\theta}}(\mathbf{x})
#     \right]_{S:2S, :}
#     \right)
#     \right.
#     \\
#     +
#     \left.
#     \mathrm{diag}
#     \left(
#     \left[
#     \mathbf{J}_{\mathbf{\theta}} f_{\mathbf{\theta}}(\mathbf{x})
#     \right]^\top_{:, 2S:3S}
#     \left[
#     \frac{2}{S} \mathbf{I}_{S\times S}
#     \right]
#     \left[
#     \mathbf{J}_{\mathbf{\theta}} f_{\mathbf{\theta}}(\mathbf{x})
#     \right]_{2S:3S, :}
#     \right)
#     +
#     \dots
#     \right\}\,.
#
# Each summand is the GGN diagonal of the mean squared error on a chunk
#
# .. math::
#     \tilde{\ell}(\mathbf{\theta})
#     =
#     \frac{1}{S}
#     \lVert
#     [f_{\mathbf{\theta}}(\mathbf{x})]_{i S: (i+1) S}
#     -
#     [\mathbf{x}]_{i S: (i+1) S}
#     \rVert_2^2\,,
#     \qquad i = 0, 1, \dots, \frac{M}{S} - 1\,,
#
# and its memory consumption scales with :math:`S < M`.
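
# %%
#
# As a quick numerical sanity check of this decomposition (our own addition),
# we can verify the identity with a random matrix standing in for the
# Jacobian:

from torch import eye

out, chunk = 6, 2  # toy values for M and S
J = rand(out, 4)  # stand-in for the M x |theta| Jacobian

full = (J.T @ (2 / out * eye(out)) @ J).diag()  # full GGN diagonal
chunk_sum = zeros_like(full)
for i in range(out // chunk):
    J_i = J[i * chunk : (i + 1) * chunk]  # rows of the i-th chunk
    chunk_sum += (J_i.T @ (2 / chunk * eye(chunk)) @ J_i).diag()

assert allclose(full, chunk / out * chunk_sum)  # matches up to the S/M factor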

# %%
#
# In summary, the computation split works as follows:
#
# - Compute :math:`f_{\mathbf{\theta}}(\mathbf{x})` in a single forward pass.
#
# - Compute the reconstruction error for a chunk and its GGN in one backward
#   pass.
#
# - Repeat the last step for the other chunks, accumulating the GGN diagonals
#   over all chunks.
#
# (This carries over to the mini-batch case in a straightforward fashion; we
# skip the presentation here only because of the more involved notation.)
#
# Note that because we perform multiple backward passes through the same
# computation graph, we need to tell PyTorch (and BackPACK) to retain the
# graph.
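
# %%
#
# To see why, here is a minimal sketch (our own illustration, independent of
# BackPACK) of two backward passes through the same graph:

z = rand(3, requires_grad=True)
demo_loss = (z**2).sum()
demo_loss.backward(retain_graph=True)  # keep the graph alive for another pass
demo_loss.backward()  # without ``retain_graph=True`` above, this second call
# would raise "RuntimeError: Trying to backward through the graph a second
# time ...".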

# %%
#
# To slice out a chunk, we use BackPACK's
# :py:class:`Slicing <backpack.custom_module.slicing.Slicing>` module.
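
# %%
#
# As a quick illustrative check of our understanding, the module's forward
# pass should match plain tensor indexing:

demo_slicing = (slice(None), slice(0, 5))  # all rows, first five columns
assert allclose(Slicing(demo_slicing)(model(X)), model(X)[demo_slicing])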

# %%
#
# Here is the implementation:

def diag_ggn_multiple_passes(num_chunks: int) -> List[Tensor]:
    """Compute the GGN diagonal in multiple backward passes.

    Uses less memory than ``diag_ggn_one_pass`` if ``num_chunks > 1``.
    Does the same as ``diag_ggn_one_pass`` for ``num_chunks = 1``.

    Args:
        num_chunks: Number of backward passes. Must divide the model's output
            dimension.

    Returns:
        GGN diagonal in parameter list format.

    Raises:
        ValueError:
            If ``num_chunks`` does not divide the model's output dimension.
        NotImplementedError:
            If the model does not return a batched vector (the slicing logic is
            only implemented for batched vectors, i.e. 2d tensors).
    """
    reconstruction = model(X)

    if reconstruction.dim() != 2:
        raise NotImplementedError("Slicing logic only implemented for 2d outputs.")
    if reconstruction.shape[1] % num_chunks != 0:
        raise ValueError("Network output must be divisible by number of chunks.")

    chunk_size = reconstruction.shape[1] // num_chunks
    diag_ggn_exact = [zeros_like(p) for p in model.parameters()]

    for idx in range(num_chunks):
        # set up the layer that extracts the current slice
        slicing = (slice(None), slice(idx * chunk_size, (idx + 1) * chunk_size))
        chunk_module = extend(Slicing(slicing))

        # compute the chunk's loss
        sliced_reconstruction = chunk_module(reconstruction)
        sliced_X = X.flatten(start_dim=1)[slicing]
        slice_error = loss_func(sliced_reconstruction, sliced_X)

        # compute its GGN diagonal ...
        with backpack(DiagGGNExact(), retain_graph=True):
            slice_error.backward(retain_graph=True)

        # ... and accumulate it
        for p_idx, p in enumerate(model.parameters()):
            diag_ggn_exact[p_idx] += p.diag_ggn_exact

    # fix normalization: each chunk used 1/S instead of 1/M = 1/(S * num_chunks)
    return [ggn / num_chunks for ggn in diag_ggn_exact]

# %%
#
# Let's benchmark peak memory and run time for different numbers of chunks:

num_chunks = [1, 4, 16, 64]

for n in num_chunks:
    print(f"GGN diagonal in {n} backward passes:")
    start = time()
    max_mem, diag_ggn_chunk = memory_usage(
        partial(diag_ggn_multiple_passes, n), interval=1e-3, max_usage=True, retval=True
    )
    end = time()
    print(f"\tPeak memory [MiB]: {max_mem:.2e}")
    print(f"\tTime [s]: {end - start:.2e}")

    correct = [
        allclose(diag1, diag2, rtol=5e-3, atol=5e-5)
        for diag1, diag2 in zip(diag_ggn, diag_ggn_chunk)
    ]
    print(f"\tCorrect: {correct}")

    if not all(correct):
        raise RuntimeError("Mismatch in GGN diagonals.")

# %%
#
# We can see that using more chunks consistently decreases the peak memory.
# Run time even decreases at first, up to a sweet spot beyond which increasing
# the number of chunks further slows down the computation. The details of this
# trade-off will depend on your model and compute architecture.
#
# Concluding remarks
# ------------------
#
# Here, we considered chunking the computation along the auto-encoder's output
# dimension. There are other ways to achieve the desired effect of reducing
# peak memory:
#
# - In the mini-batch setting, we could consider only a subset of mini-batch
#   samples at each backpropagation. This can be done with the optional
#   :code:`subsampling` argument in many of BackPACK's extensions. See the
#   :ref:`mini-batch sub-sampling tutorial <Mini-batch sub-sampling>`. This
#   technique can be combined with the above.
#
# - We could turn off the gradient computation (and thereby BackPACK's
#   computation) for all but a subgroup of parameters by setting their
#   :code:`requires_grad` attribute to :code:`False`, and compute the GGN
#   diagonal only for that subgroup (see the sketch below). However, for this
#   to work, we need to perform a new forward pass for each parameter subgroup.
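
# %%
#
# A minimal sketch of the second idea (our own illustration): restrict the GGN
# diagonal computation to the encoder's parameters by freezing the decoder.

for p in decoder.parameters():
    p.requires_grad = False  # exclude the decoder from the backward pass

reconstruction = model(X)
error = loss_func(reconstruction, X.flatten(start_dim=1))
with backpack(DiagGGNExact()):
    error.backward()

encoder_diag_ggn = [p.diag_ggn_exact for p in encoder.parameters()]
print(f"Computed GGN diagonal for {len(encoder_diag_ggn)} encoder parameters.")

# restore the decoder's parameters for completeness
for p in decoder.parameters():
    p.requires_grad = True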
