|
| 1 | +r"""BackPACK's retain_graph option |
| 2 | +================================== |
| 3 | +
|
| 4 | +This tutorial demonstrates how to perform multiple backward passes through the |
| 5 | +same computation graph with BackPACK. This option can be useful if you run into |
| 6 | +out-of-memory errors. If your computation can be chunked, you might consider |
| 7 | +distributing it onto multiple backward passes to reduce peak memory. |
| 8 | +
|
Our use case is computing the GGN diagonal of an auto-encoder's reconstruction
error.

But first, the imports:
"""

from functools import partial
from time import time
from typing import List

from memory_profiler import memory_usage
from torch import Tensor, allclose, manual_seed, rand, zeros_like
from torch.nn import Conv2d, ConvTranspose2d, Flatten, MSELoss, Sequential, Sigmoid

from backpack import backpack, extend
from backpack.custom_module.slicing import Slicing
from backpack.extensions import DiagGGNExact

# make deterministic
manual_seed(0)

# %%
#
# Setup
# -----
#
# Let :math:`f_{\mathbf{\theta}}` denote the auto-encoder, and
# :math:`\mathbf{x'} = f_{\mathbf{\theta}}(\mathbf{x}) \in \mathbb{R}^M` its
# reconstruction of an input :math:`\mathbf{x} \in \mathbb{R}^M`. The
# associated reconstruction error is measured by the mean squared error
#
# .. math::
#     \ell(\mathbf{\theta})
#     =
#     \frac{1}{M}
#     \left\lVert f_{\mathbf{\theta}}(\mathbf{x}) - \mathbf{x} \right\rVert^2_2\,.
#
# On a batch of :math:`N` examples, :math:`\mathbf{x}_1, \dots, \mathbf{x}_N`,
# the loss is
#
# .. math::
#     \mathcal{L}(\mathbf{\theta})
#     =
#     \frac{1}{N} \frac{1}{M}
#     \sum_{n=1}^N
#     \left\lVert f_{\mathbf{\theta}}(\mathbf{x}_n) - \mathbf{x}_n \right\rVert^2_2\,.
#
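# As a quick sanity check (a toy example with made-up tensors, not part of the
# setup above), the :math:`\frac{1}{N} \frac{1}{M}` normalization is exactly
# what :code:`MSELoss` computes with its default :code:`reduction="mean"`:

N_demo, M_demo = 3, 4
output_demo, target_demo = rand(N_demo, M_demo), rand(N_demo, M_demo)
manual = ((output_demo - target_demo) ** 2).sum() / (N_demo * M_demo)
print(f"Losses match: {allclose(MSELoss()(output_demo, target_demo), manual)}")

# %%
#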
# Let's create a toy model and data:

# data
batch_size, channels, spatial_dims = 5, 3, (32, 32)
X = rand(batch_size, channels, *spatial_dims)

# model (auto-encoder)
hidden_channels = 10

encoder = Sequential(
    Conv2d(channels, hidden_channels, 3),
    Sigmoid(),
)
decoder = Sequential(
    ConvTranspose2d(hidden_channels, channels, 3),
    Flatten(),
)
model = Sequential(
    encoder,
    decoder,
)
loss_func = MSELoss()

# %%
#
# We will use BackPACK to compute the GGN diagonal of the mini-batch loss. To
# do that, we need to :py:func:`extend <backpack.extend>` the model and the
# loss function.

model = extend(model)
loss_func = extend(loss_func)

# %%
#
# GGN diagonal in one backward pass
# ---------------------------------
#
# As usual, we can compute the GGN diagonal for the mini-batch loss in a single
# backward pass. The following function does that:


def diag_ggn_one_pass() -> List[Tensor]:
    """Compute the GGN diagonal in a single backward pass.

    Returns:
        GGN diagonal in parameter list format.
    """
    reconstruction = model(X)
    error = loss_func(reconstruction, X.flatten(start_dim=1))

    with backpack(DiagGGNExact()):
        error.backward()

    return [p.diag_ggn_exact.clone() for p in model.parameters() if p.requires_grad]


# %%
#
# Let's run it and determine (i) its peak memory consumption and (ii) its run
# time.

print("GGN diagonal in one backward pass:")
start = time()
max_mem, diag_ggn = memory_usage(
    diag_ggn_one_pass, interval=1e-3, max_usage=True, retval=True
)
end = time()

print(f"\tPeak memory [MiB]: {max_mem:.2e}")
print(f"\tTime [s]: {end-start:.2e}")

# %%
#
# The memory consumption is pretty high, although our model is relatively
# small! If we make the model deeper or increase the mini-batch size, we will
# quickly run out of memory.
#
# This is because the cost of computing the GGN diagonal scales with the
# network's output dimension. For classification settings like MNIST and
# CIFAR-10, this number is relatively small (:code:`10`). But for an
# auto-encoder, it is the input dimension :code:`M`, which in our case is

print(f"Output dimension: {model(X).shape[1:].numel()}")

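# %%
#
# To get a feeling for the scale, here is a back-of-the-envelope sketch. It
# assumes (this is an assumption about BackPACK's internals, not an exact
# accounting) that the extension backpropagates a tensor of shape
# ``[M, batch_size, *feature_shape]`` through each intermediate feature map:

M = model(X).shape[1:].numel()
encoder_out_numel = encoder(X).shape[1:].numel()
mib = 4 * batch_size * M * encoder_out_numel / 2**20  # float32 entries in MiB
print(f"Backpropagated quantity at the encoder output [MiB]: {mib:.0f}")
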
# %%
#
# We will now take a look at how to circumvent the high peak memory by
# distributing the computation over multiple backward passes.

# %%
#
# GGN diagonal in chunks
# ----------------------
#
# The GGN diagonal computation can be distributed across multiple backward
# passes. This greatly reduces peak memory consumption.
#
# To see this, let's consider the GGN diagonal for a single example
# :math:`\mathbf{x}`,
#
# .. math::
#     \mathrm{diag}
#     \left(
#     \left[
#     \mathbf{J}_{\mathbf{\theta}} f_{\mathbf{\theta}}(\mathbf{x})
#     \right]^\top
#     \left[
#     \frac{2}{M} \mathbf{I}_{M\times M}
#     \right]
#     \mathbf{J}_{\mathbf{\theta}} f_{\mathbf{\theta}}(\mathbf{x})
#     \right)\,,
#
# with the :math:`M \times |\mathbf{\theta}|` Jacobian
# :math:`\mathbf{J}_{\mathbf{\theta}} f_{\mathbf{\theta}}(\mathbf{x})` of the
# model, and :math:`\frac{2}{M} \mathbf{I}_{M\times M}` the mean squared
# error's Hessian w.r.t. the reconstructed input. Here you can see that the
# memory consumption scales with the output dimension, as we need to compute
# :code:`M` vector-Jacobian products.
#
# Let :math:`S`, the chunk size, be a number that divides the output dimension
# :math:`M`. Then, we can decompose the above computation into chunks:
#
# .. math::
#     \frac{S}{M}
#     \left\{
#     \mathrm{diag}
#     \left(
#     \left[
#     \mathbf{J}_{\mathbf{\theta}} f_{\mathbf{\theta}}(\mathbf{x})
#     \right]^\top_{:, 0:S}
#     \left[
#     \frac{2}{S} \mathbf{I}_{S\times S}
#     \right]
#     \left[
#     \mathbf{J}_{\mathbf{\theta}} f_{\mathbf{\theta}}(\mathbf{x})
#     \right]_{0:S, :}
#     \right) \right.
#     \\
#     +
#     \left.
#     \mathrm{diag}
#     \left(
#     \left[
#     \mathbf{J}_{\mathbf{\theta}} f_{\mathbf{\theta}}(\mathbf{x})
#     \right]^\top_{:, S:2S}
#     \left[
#     \frac{2}{S} \mathbf{I}_{S\times S}
#     \right]
#     \left[
#     \mathbf{J}_{\mathbf{\theta}} f_{\mathbf{\theta}}(\mathbf{x})
#     \right]_{S:2S, :}
#     \right)
#     \right.
#     \\
#     +
#     \left.
#     \mathrm{diag}
#     \left(
#     \left[
#     \mathbf{J}_{\mathbf{\theta}} f_{\mathbf{\theta}}(\mathbf{x})
#     \right]^\top_{:, 2S:3S}
#     \left[
#     \frac{2}{S} \mathbf{I}_{S\times S}
#     \right]
#     \left[
#     \mathbf{J}_{\mathbf{\theta}} f_{\mathbf{\theta}}(\mathbf{x})
#     \right]_{2S:3S, :}
#     \right)
#     +
#     \dots
#     \right\}\,.
#
# Each summand is the GGN diagonal of the mean squared error on a chunk
#
# .. math::
#     \tilde{\ell}(\mathbf{\theta})
#     =
#     \frac{1}{S}
#     \lVert
#     [f_{\mathbf{\theta}}(\mathbf{x})]_{i S: (i+1) S}
#     -
#     [\mathbf{x}]_{i S: (i+1) S}
#     \rVert_2^2\,,
#     \qquad i = 0, 1, \dots, \frac{M}{S} - 1\,,
#
# and its memory consumption scales with :math:`S < M`.
#
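# As a quick numerical check of this decomposition (plain PyTorch with toy
# tensors made up for this cell), the full mean squared error is the average of
# the per-chunk mean squared errors. This is also why the implementation below
# divides the accumulated result by the number of chunks:

output_toy, target_toy = rand(batch_size, 8), rand(batch_size, 8)
S = 4  # chunk size; must divide the output dimension 8
chunks = [slice(i * S, (i + 1) * S) for i in range(8 // S)]
chunk_losses = [MSELoss()(output_toy[:, c], target_toy[:, c]) for c in chunks]
full = MSELoss()(output_toy, target_toy)
print(f"Losses match: {allclose(full, sum(chunk_losses) / len(chunk_losses))}")

# %%
#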
# In summary, the computation split works as follows:
#
# - Compute :math:`f_{\mathbf{\theta}}(\mathbf{x})` in a single forward pass.
#
# - Compute the reconstruction error for a chunk and its GGN in one backward
#   pass.
#
# - Repeat the last step for the other chunks and accumulate the GGN diagonals
#   over all chunks.
#
# (This carries over to the mini-batch case in a straightforward fashion. We
# omit it here because the notation becomes rather involved.)
#
# Note that because we perform multiple backward passes through the same
# computation graph, we need to tell PyTorch (and BackPACK) to retain the
# graph.
#
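# The following snippet (plain PyTorch, independent of the GGN computation)
# illustrates the mechanism: a second backward pass through the same graph only
# works if the first one retained it.

z = rand(3, requires_grad=True)
demo_loss = (z**2).sum()
demo_loss.backward(retain_graph=True)  # keep the graph alive for another pass
demo_loss.backward()  # would raise a RuntimeError without retain_graph above
print(f"Accumulated gradient: {z.grad}")

# %%
#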
# To slice out a chunk, we use BackPACK's :py:class:`Slicing
# <backpack.custom_module.slicing.Slicing>` module.
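#
# As a quick sanity check (an illustration based on the assumption that the
# module simply applies the given slices), :code:`Slicing` should behave like
# plain tensor indexing:

demo_tensor = rand(2, 6)
sliced = Slicing((slice(None), slice(0, 3)))(demo_tensor)
print(f"Slicing matches indexing: {allclose(sliced, demo_tensor[:, 0:3])}")

# %%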
#
# Here is the implementation:


def diag_ggn_multiple_passes(num_chunks: int) -> List[Tensor]:
    """Compute the GGN diagonal in multiple backward passes.

    Uses less memory than ``diag_ggn_one_pass`` if ``num_chunks > 1``.
    Does the same as ``diag_ggn_one_pass`` for ``num_chunks = 1``.

    Args:
        num_chunks: Number of backward passes. Must divide the model's output
            dimension.

    Returns:
        GGN diagonal in parameter list format.

    Raises:
        ValueError:
            If ``num_chunks`` does not divide the model's output dimension.
        NotImplementedError:
            If the model does not return a batched vector (the slicing logic
            is only implemented for batched vectors, i.e. 2d tensors).
    """
    reconstruction = model(X)

    if reconstruction.shape[1:].numel() % num_chunks != 0:
        raise ValueError("Output dimension must be divisible by num_chunks.")
    if reconstruction.dim() != 2:
        raise NotImplementedError("Slicing logic only implemented for 2d outputs.")

    chunk_size = reconstruction.shape[1:].numel() // num_chunks
    diag_ggn_exact = [zeros_like(p) for p in model.parameters()]

    for idx in range(num_chunks):
        # set up the layer that extracts the current slice
        slicing = (slice(None), slice(idx * chunk_size, (idx + 1) * chunk_size))
        chunk_module = extend(Slicing(slicing))

        # compute the chunk's loss
        sliced_reconstruction = chunk_module(reconstruction)
        sliced_X = X.flatten(start_dim=1)[slicing]
        slice_error = loss_func(sliced_reconstruction, sliced_X)

        # compute its GGN diagonal ...
        with backpack(DiagGGNExact(), retain_graph=True):
            slice_error.backward(retain_graph=True)

        # ... and accumulate it
        for p_idx, p in enumerate(model.parameters()):
            diag_ggn_exact[p_idx] += p.diag_ggn_exact

    # fix normalization
    return [ggn / num_chunks for ggn in diag_ggn_exact]


# %%
#
# Let's benchmark peak memory and run time for different numbers of chunks:

num_chunks = [1, 4, 16, 64]

for n in num_chunks:
    print(f"GGN diagonal in {n} backward passes:")
    start = time()
    max_mem, diag_ggn_chunk = memory_usage(
        partial(diag_ggn_multiple_passes, n), interval=1e-3, max_usage=True, retval=True
    )
    end = time()
    print(f"\tPeak memory [MiB]: {max_mem:.2e}")
    print(f"\tTime [s]: {end-start:.2e}")

    correct = [
        allclose(diag1, diag2, rtol=5e-3, atol=5e-5)
        for diag1, diag2 in zip(diag_ggn, diag_ggn_chunk)
    ]
    print(f"\tCorrect: {correct}")

    if not all(correct):
        raise RuntimeError("Mismatch in GGN diagonals.")

# %%
#
# We can see that using more chunks consistently decreases the peak memory.
# Run time also decreases at first, up to a sweet spot beyond which additional
# chunks slow the computation down again. The details of this trade-off depend
# on your model and compute architecture.
#
# Concluding remarks
# ------------------
#
# Here, we considered chunking the computation along the auto-encoder's output
# dimension. There are other ways to achieve the desired effect of reducing
# peak memory:
#
# - In the mini-batch setting, we could consider only a subset of mini-batch
#   samples in each backward pass. This can be done with the optional
#   :code:`subsampling` argument of many of BackPACK's extensions. See the
#   :ref:`mini-batch sub-sampling tutorial <Mini-batch sub-sampling>`. This
#   technique can be combined with the one above.
#
# - We could turn off the gradient computation (and thereby BackPACK's
#   computation) for all but a subgroup of parameters by setting their
#   :code:`requires_grad` attribute to :code:`False`, and compute the GGN
#   diagonal only for these parameters. However, this requires a new forward
#   pass for each parameter subgroup. A minimal sketch of this idea follows
#   below.
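
# %%
#
# A minimal sketch of that second idea (hedged: it assumes BackPACK simply
# skips parameters with :code:`requires_grad=False`; this cell is not part of
# the benchmark above). We restrict the GGN diagonal to the decoder:

# freeze the encoder: no gradients and no BackPACK quantities are computed for it
for p in encoder.parameters():
    p.requires_grad = False

reconstruction = model(X)
error = loss_func(reconstruction, X.flatten(start_dim=1))
with backpack(DiagGGNExact()):
    error.backward()

print([p.diag_ggn_exact.shape for p in decoder.parameters()])

# undo the freeze
for p in encoder.parameters():
    p.requires_grad = True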