Skip to content

Commit 61e47c2

Browse files
author
Flax Authors
committed
Merge pull request #2499 from levskaya:lifetimefix
PiperOrigin-RevId: 479153741
2 parents 63e7657 + bfa517f commit 61e47c2

File tree

2 files changed

+51
-0
lines changed

2 files changed

+51
-0
lines changed

flax/linen/module.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -523,6 +523,27 @@ def reimport(self, other: '_ModuleInternalState') -> None:
523523
_ParentType = Union[Type['Module'], Type[Scope], Type[_Sentinel], None]
524524

525525

526+
class ParentDescriptor:
  """Wraps parent module references in weak refs.

  This prevents reference cycles from forming via parent links which can lead
  to accidental OOMs in eager mode due to slow garbage collection as well as
  spurious tracer leaks during jit compilation.

  Note: "descriptors" are the underlying python mechanism for implementing
  dynamic @property decorators. We need to use a raw descriptor instead of the
  more common decorator in order to force that the appropriate getter/setter
  logic applies in subclasses even after various dataclass transforms.

  The actual value is kept on the owning instance under the ``_parent_ref``
  attribute; this descriptor only translates between the stored form and the
  value seen by callers.
  """

  def __get__(self, obj, objtype=None):
    """Returns the parent stored on ``obj``, dereferencing weak refs.

    Args:
      obj: the instance the attribute is read from, or ``None`` when the
        attribute is looked up on the class itself.
      objtype: the owner class (unused).

    Returns:
      ``None`` for class-level access. Otherwise the stored value, or, if the
      stored value is a ``weakref.ref``, its referent (which is ``None`` once
      the referent has been collected).
    """
    # Class-level access (e.g. ``SomeModule.parent``) passes ``obj=None``.
    # There is no instance to read from, so report "no parent" instead of
    # failing with an AttributeError about NoneType lacking ``_parent_ref``.
    if obj is None:
      return None
    # Bypass any custom __getattr__/__getattribute__ on the owning class.
    parent = object.__getattribute__(obj, "_parent_ref")
    return parent() if isinstance(parent, weakref.ReferenceType) else parent

  def __set__(self, obj, value):
    """Stores ``value`` on ``obj``, weakly if it is a ``Module`` instance.

    Args:
      obj: the instance the attribute is assigned on.
      value: the new parent. ``Module`` instances are wrapped in
        ``weakref.ref``; any other value is stored unchanged.
    """
    maybe_weak = weakref.ref(value) if isinstance(value, Module) else value
    # Bypass any custom __setattr__ on the owning class.
    object.__setattr__(obj, "_parent_ref", maybe_weak)
545+
546+
526547
# Base Module definition.
527548
# -----------------------------------------------------------------------------
528549

@@ -588,6 +609,8 @@ def __init_subclass__(cls, **kwargs: Any) -> None:
588609
# Set empty class defaults.
589610
cls._state = _uninitialized_module_internal_state
590611
cls.scope: Optional[Scope] = None
612+
# Handles weak referencing of parent Modules to prevent reference cycles.
613+
cls.parent = ParentDescriptor()
591614

592615
@classmethod
593616
def _customized_dataclass_transform(cls):

tests/linen/linen_module_test.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717
import dataclasses
1818
import functools
19+
import gc
1920
import operator
2021
from typing import (Any, Callable, Generic, Mapping, NamedTuple, Sequence,
2122
Tuple, TypeVar)
@@ -31,6 +32,7 @@
3132
from jax.nn import initializers
3233
import jax.numpy as jnp
3334
import numpy as np
35+
from unittest.mock import patch
3436

3537
# Parse absl flags test_srcdir and test_tmpdir.
3638
jax.config.parse_flags_with_absl()
@@ -1755,5 +1757,31 @@ def __call__(self):
17551757
self.assertFalse(foo.apply({}))
17561758

17571759

1760+
class LeakTests(absltest.TestCase):
  """Checks that Module usage under JAX transforms does not leak tracers."""

  def test_tracer_leaks(self):
    """Runs init/apply under jit+vmap with the JAX leak checker active."""
    model = nn.Sequential([nn.Dense(50)])

    def sample_from_prior(rng, inp):
      variables = model.init(rng, np.zeros((10, 50)))
      result = model.apply(variables, inp)
      del variables
      return result

    # Map over the rng axis only, then compile the batched function.
    jitted_sampler = jax.jit(jax.vmap(sample_from_prior, in_axes=(0, None)))

    # Stub out gc.collect so the manual collection done by the jax leak
    # checker becomes a no-op; that lets this test observe tracer leaks held
    # in reference cycles. This is a reasonable proxy for transiently leaked
    # memory during eager execution.
    with patch.object(gc, 'collect', return_value=0), jax.checking_leaks():
      for _ in range(5):
        keys = jax.random.split(jax.random.PRNGKey(23), 100)
        samples = jitted_sampler(keys, np.ones((4, 50)))
        samples.block_until_ready()
        del samples, keys
1784+
1785+
17581786
if __name__ == '__main__':
17591787
absltest.main()

0 commit comments

Comments
 (0)