Skip to content

Commit 0f631a2

Browse files
danielsuo authored and Flax Authors committed
Fix PRNG handling in nn.jit under nn.scan.
* `nn.scan` does an abstract eval before compilation to check for constants that are then traced out. Before this change, the abstract eval incremented static RNG counters, a side effect that left the RNG counters improperly updated once inner functions were jitted (i.e., under `nn.jit`).
* In this fix, we cache the effect that a first pass through `nn.scan` and `nn.jit` has on the RNG counters and "replay" that effect on subsequent passes, so that the RNG state remains unaffected.
* This solution doesn't affect PRNG derivations outside this isolated case; it is a placeholder while a more permanent solution, which would affect PRNG derivations, is worked out.

PiperOrigin-RevId: 694463650
1 parent 0679702 commit 0f631a2

File tree

6 files changed

+337
-92
lines changed

6 files changed

+337
-92
lines changed

flax/core/lift.py

Lines changed: 144 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
import warnings
2525

2626
from flax import traceback_util
27+
from flax import traverse_util
2728
from flax.typing import (
2829
In,
2930
InOutAxis,
@@ -1499,6 +1500,81 @@ def _hashable_filter(x):
14991500
return x
15001501

15011502

1503+
class CountsHolder:
  """Snapshot of a (possibly nested) counter dict, stored in flattened form.

  Attributes are limited to ``flat_d``: the dict obtained by passing the
  nested counters through ``traverse_util.flatten_dict``. Arithmetic between
  two holders is key-wise and driven by the keys of the left-hand holder.
  """

  def __init__(self, flat_d):
    # Flattened mapping from key to count.
    self.flat_d = flat_d

  @classmethod
  def make(cls, d):
    """Builds a holder from the nested dict ``d``.

    The flattened result is copied into a fresh ``dict`` so the holder does
    not share its top-level mapping with whatever ``flatten_dict`` returned.
    """
    return cls(dict(traverse_util.flatten_dict(d)))

  def sub(self, other):
    """Returns a new holder with ``self[k] - other[k]`` for each key of self.

    Keys missing from ``other`` count as 0. Keys that appear only in
    ``other`` are not part of the result.
    """
    theirs = other.flat_d
    return CountsHolder(
      {k: v - theirs.get(k, 0) for k, v in self.flat_d.items()}
    )

  def add(self, other):
    """Returns a new holder with ``self[k] + other[k]`` for each key of self.

    Keys missing from ``other`` count as 0. Keys that appear only in
    ``other`` are not part of the result.
    """
    theirs = other.flat_d
    return CountsHolder(
      {k: v + theirs.get(k, 0) for k, v in self.flat_d.items()}
    )

  def unflat(self):
    """Returns the counts as a nested dict via ``unflatten_dict``."""
    return traverse_util.unflatten_dict(self.flat_d)
1532+
1533+
1534+
def set_from_dict(original, updates):
  """Recursively writes the entries of ``updates`` into ``original`` in place.

  For each key in ``updates``:
    * if both sides hold a dict under that key, the two dicts are merged
      recursively, so entries of ``original`` that ``updates`` does not
      mention are kept;
    * otherwise the value from ``updates`` replaces (or creates) the entry.

  A dict in ``updates`` that meets a non-dict value in ``original`` replaces
  that value instead of being recursed into, which would fail on a leaf.

  Args:
    original: the dict to mutate.
    updates: the dict whose entries are written into ``original``.
  """
  for key, value in updates.items():
    can_recurse = (
      isinstance(value, dict)
      and key in original
      and isinstance(original[key], dict)
    )
    if can_recurse:
      set_from_dict(original[key], value)
    else:
      original[key] = value
1543+
1544+
1545+
class _SideEffectCache(threading.local):
1546+
1547+
def __init__(self):
1548+
self.cache = {}
1549+
1550+
1551+
_side_effect_cache = _SideEffectCache()
1552+
1553+
1554+
def _restore_rng_counters(scopes, fingerprint, capture_old_counts):
  """Records or replays the change in RNG counters for ``fingerprint``.

  First call for a given ``fingerprint`` (in the current thread's cache): the
  current ``rng_counters`` of every scope are snapshotted, the difference to
  ``capture_old_counts`` is computed per scope, and that difference is stored
  in ``_side_effect_cache``. The scopes are not modified.

  Later calls: the stored difference is added to ``capture_old_counts`` and
  the resulting values are written into each scope's ``rng_counters``.

  Args:
    scopes: tree of scopes exposing an ``rng_counters`` dict.
    fingerprint: hashable cache key.
    capture_old_counts: tree of ``CountsHolder`` snapshots taken before the
      wrapped call, matching the structure of ``scopes``.
  """
  cache = _side_effect_cache.cache

  if fingerprint in cache:
    # Replay: counters become "value before the call" + "recorded delta".
    replayed = jax.tree.map(
      lambda delta, before: delta.add(before).unflat(),
      cache[fingerprint],
      capture_old_counts,
    )
    jax.tree.map(
      lambda scope, counts: set_from_dict(scope.rng_counters, counts),
      scopes,
      replayed,
    )
    return

  # Record: delta = counters after the call - counters before the call.
  after_counts = jax.tree.map(
    lambda scope: CountsHolder.make(scope.rng_counters), scopes
  )
  cache[fingerprint] = jax.tree.map(
    lambda before, after: after.sub(before),
    capture_old_counts,
    after_counts,
  )
1576+
1577+
15021578
def jit(
15031579
fn: Callable[..., Any],
15041580
variables: CollectionFilter = True,
@@ -1599,13 +1675,18 @@ def inner(
15991675
mutable = tuple(_hashable_filter(scope.mutable) for scope in scopes)
16001676

16011677
rng_groups = jax.tree.map(
1602-
lambda x: x.fold() if isinstance(x, LazyRng) else x,
1678+
lambda x: x.clear_suffix() if isinstance(x, LazyRng) else x,
16031679
rng_groups,
16041680
is_leaf=lambda x: isinstance(x, LazyRng),
16051681
)
16061682

16071683
fingerprint = (mutable, module_hash_key)
1608-
return jitted(fingerprint, variable_groups, rng_groups, *args, **kwargs)
1684+
capture_old_counts = jax.tree.map(
1685+
lambda s: CountsHolder.make(s.rng_counters), scopes
1686+
)
1687+
res = jitted(fingerprint, variable_groups, rng_groups, *args, **kwargs)
1688+
_restore_rng_counters(scopes, fingerprint, capture_old_counts)
1689+
return res
16091690

16101691
return pack(
16111692
inner, (variables,), (variables,), (rngs,), name='jit', enable_kwargs=True
@@ -1692,3 +1773,64 @@ def inner_loop(scope, carry):
16921773
def _unzip2(xs):
16931774
ys = tuple(zip(*xs))
16941775
return ys if ys else ((), ())
1776+
1777+
1778+
def fold_rngs(
  fn: Callable[..., Any],
  variables: CollectionFilter = True,
  rngs: PRNGSequenceFilter = True,
) -> Callable[..., Any]:
  """Lifts ``fn`` so that it runs with suffix-cleared RNGs and replayed counters.

  The lifted function calls ``fn(scope, hash_key, *args, **kwargs)``, where
  ``hash_key`` is the module hash key supplied by the caller. Before the call,
  every ``LazyRng`` in the rng groups is replaced by ``clear_suffix()`` of
  itself; after the call, ``_restore_rng_counters`` records or replays the
  change in the scopes' RNG counters for the computed fingerprint.

  Args:
    fn: the function to lift.
    variables: collection filter forwarded to ``pack`` (as both the input and
      output filter).
    rngs: PRNG sequence filter forwarded to ``pack``.

  Returns:
    The result of ``pack`` applied to the inner function, named ``'fold_rngs'``.
  """
  # scope_fn and repack_fn reach the wrapped function through this context
  # stack rather than through its argument list.
  context = TransformContext[tuple[Callable, Callable]]()

  @functools.wraps(fn)
  def wrapped_fold_rngs(
    fingerprint, variable_groups, rng_groups, *args, **kwargs
  ):
    scope_fn, repack_fn = context.get()
    # The second fingerprint entry is the module hash key expected by `fn`.
    hash_key = fingerprint[1]
    scope = scope_fn(variable_groups, rng_groups)  # pylint: disable=not-callable
    y = fn(scope, hash_key, *args, **kwargs)
    return y, repack_fn(scope)  # pylint: disable=not-callable

  def inner_fold_rngs(
    scope_fn,
    repack_fn,
    variable_groups,
    rng_groups,
    module_hash_key,
    *args,
    **kwargs,
  ):
    with context.push((scope_fn, repack_fn)):
      scopes: list[Scope] = jax.tree_util.tree_leaves(
        scope_fn(variable_groups, rng_groups)
      )
      mutable = tuple(_hashable_filter(scope.mutable) for scope in scopes)
      fingerprint = (mutable, module_hash_key)

      rng_groups = jax.tree.map(
        lambda x: x.clear_suffix() if isinstance(x, LazyRng) else x,
        rng_groups,
        is_leaf=lambda x: isinstance(x, LazyRng),
      )

      # Snapshot the counters before running so the delta can be recorded
      # (first pass) or replayed (later passes) afterwards.
      counts_before = jax.tree.map(
        lambda s: CountsHolder.make(s.rng_counters), scopes
      )
      result = wrapped_fold_rngs(
        fingerprint, variable_groups, rng_groups, *args, **kwargs
      )
      _restore_rng_counters(scopes, fingerprint, counts_before)
      return result

  return pack(
    inner_fold_rngs,
    (variables,),
    (variables,),
    (rngs,),
    name='fold_rngs',
    enable_kwargs=True,
  )

flax/core/scope.py

Lines changed: 2 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -102,8 +102,8 @@ def create(
102102
else:
103103
return LazyRng(rng, suffix)
104104

105-
def fold(self):
106-
key = self.as_jax_rng()
105+
def clear_suffix(self):
106+
key = self.rng
107107
return LazyRng(key, ())
108108

109109

@@ -583,13 +583,6 @@ def default_name(self, prefix: str) -> str:
583583
return name
584584
i += 1
585585

586-
def fold_rngs(self):
587-
"""Folds the rngs of this scope into the parent scope."""
588-
self._check_valid()
589-
for name, rng in self.rngs.items():
590-
assert isinstance(rng, LazyRng)
591-
self.rngs[name] = rng.fold()
592-
593586
def push(
594587
self, name: str | None = None, prefix: str = '', reuse=False
595588
) -> 'Scope':

flax/linen/__init__.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -72,8 +72,8 @@
7272
from .batch_apply import BatchApply as BatchApply
7373
from .combinators import Sequential as Sequential
7474
from .fp8_ops import (
75-
Fp8DotGeneralOp as Fp8DotGeneralOp,
7675
Fp8DirectDotGeneralOp as Fp8DirectDotGeneralOp,
76+
Fp8DotGeneralOp as Fp8DotGeneralOp,
7777
NANOOFp8DotGeneralOp as NANOOFp8DotGeneralOp,
7878
)
7979
from .initializers import (
@@ -95,8 +95,8 @@
9595
Module as Module,
9696
Variable as Variable,
9797
apply as apply,
98-
compact as compact,
9998
compact_name_scope as compact_name_scope,
99+
compact as compact,
100100
disable_named_call as disable_named_call,
101101
enable_named_call as enable_named_call,
102102
init_with_output as init_with_output,
@@ -114,19 +114,19 @@
114114
LayerNorm as LayerNorm,
115115
RMSNorm as RMSNorm,
116116
SpectralNorm as SpectralNorm,
117-
WeightNorm as WeightNorm
117+
WeightNorm as WeightNorm,
118118
)
119119
from .pooling import (avg_pool as avg_pool, max_pool as max_pool, pool as pool)
120120
from .recurrent import (
121121
Bidirectional as Bidirectional,
122122
ConvLSTMCell as ConvLSTMCell,
123-
SimpleCell as SimpleCell,
124123
GRUCell as GRUCell,
125-
MGUCell as MGUCell,
126124
LSTMCell as LSTMCell,
125+
MGUCell as MGUCell,
127126
OptimizedLSTMCell as OptimizedLSTMCell,
128127
RNNCellBase as RNNCellBase,
129128
RNN as RNN,
129+
SimpleCell as SimpleCell,
130130
)
131131
from .spmd import (
132132
LogicallyPartitioned as LogicallyPartitioned,
@@ -146,6 +146,8 @@
146146
checkpoint as checkpoint,
147147
cond as cond,
148148
custom_vjp as custom_vjp,
149+
fold_rngs as fold_rngs,
150+
grad as grad,
149151
jit as jit,
150152
jvp as jvp,
151153
map_variables as map_variables,
@@ -154,9 +156,8 @@
154156
remat as remat,
155157
scan as scan,
156158
switch as switch,
157-
vjp as vjp,
158-
grad as grad,
159159
value_and_grad as value_and_grad,
160+
vjp as vjp,
160161
vmap as vmap,
161162
while_loop as while_loop,
162163
)

0 commit comments

Comments
 (0)