|
24 | 24 | import warnings
|
25 | 25 |
|
26 | 26 | from flax import traceback_util
|
| 27 | +from flax import traverse_util |
27 | 28 | from flax.typing import (
|
28 | 29 | In,
|
29 | 30 | InOutAxis,
|
@@ -1499,6 +1500,81 @@ def _hashable_filter(x):
|
1499 | 1500 | return x
|
1500 | 1501 |
|
1501 | 1502 |
|
class CountsHolder:
  """Wraps a flat ``{key: count}`` dict and supports key-wise arithmetic.

  Attributes are documented here rather than inline:
    flat_d: flat mapping from keys to counts. ``make`` fills it with the
      output of ``traverse_util.flatten_dict``; ``unflat`` feeds it back to
      ``traverse_util.unflatten_dict``.
  """

  def __init__(self, flat_d):
    self.flat_d = flat_d

  @classmethod
  def make(cls, d):
    """Builds a holder from ``d`` by flattening it with ``flatten_dict``.

    The flattened mapping is copied into a plain ``dict`` so the holder owns
    its own top-level container.
    """
    return cls(dict(traverse_util.flatten_dict(d)))

  def _combine(self, other, op):
    """Applies ``op(mine, theirs)`` to every key of ``self``.

    Only the keys of ``self.flat_d`` appear in the result. A key missing from
    ``other.flat_d`` contributes ``0``; keys that exist only in ``other`` are
    ignored. Neither operand is modified.
    """
    theirs = other.flat_d
    return CountsHolder(
        {key: op(mine, theirs.get(key, 0)) for key, mine in self.flat_d.items()}
    )

  def sub(self, other):
    """Returns a new holder with ``self[k] - other[k]`` for each key of ``self``.

    Keys absent from ``other`` are treated as ``0``; keys present only in
    ``other`` do not appear in the result.
    """
    return self._combine(other, lambda mine, theirs: mine - theirs)

  def add(self, other):
    """Returns a new holder with ``self[k] + other[k]`` for each key of ``self``.

    Keys absent from ``other`` are treated as ``0``; keys present only in
    ``other`` do not appear in the result, so this operation is not
    symmetric in its operands.
    """
    return self._combine(other, lambda mine, theirs: mine + theirs)

  def unflat(self):
    """Returns ``flat_d`` passed through ``traverse_util.unflatten_dict``."""
    return traverse_util.unflatten_dict(self.flat_d)
| 1532 | + |
| 1533 | + |
def set_from_dict(original, updates):
  """Recursively writes the entries of ``updates`` into ``original`` in place.

  For each key in ``updates``:
    * if the new value is a ``dict`` and ``original`` already holds a mutable
      mapping under that key, the two are merged recursively;
    * otherwise the value in ``original`` is set (or replaced) with the new
      value. This includes the case where an existing non-mapping leaf is
      replaced by a nested dict, which previously raised because the function
      tried to recurse into the leaf.

  Keys of ``original`` that are not mentioned in ``updates`` are untouched.
  Values taken from ``updates`` are stored by reference, not copied.

  Args:
    original: mutable mapping that is modified in place.
    updates: mapping whose entries are written into ``original``.

  Returns:
    None.
  """
  # Function-scope import so the helper does not depend on the module's
  # import block providing ``collections.abc``.
  from collections.abc import MutableMapping  # pylint: disable=g-import-not-at-top

  for key, new_value in updates.items():
    mergeable = (
        isinstance(new_value, dict)
        and key in original
        and isinstance(original[key], MutableMapping)
    )
    if mergeable:
      set_from_dict(original[key], new_value)
    else:
      original[key] = new_value
| 1543 | + |
| 1544 | + |
class _SideEffectCache(threading.local):
  """Thread-local holder of a ``cache`` dict.

  The class derives from ``threading.local``, so ``__init__`` runs once per
  thread that touches an instance and each thread starts with its own empty
  ``cache``.
  """

  def __init__(self):
    super().__init__()
    # Starts empty; entries are added by code elsewhere in this module.
    self.cache = dict()


# Module-level instance shared by the helpers in this file.
_side_effect_cache = _SideEffectCache()
| 1552 | + |
| 1553 | + |
def _restore_rng_counters(scopes, fingerprint, capture_old_counts):
  """Records or replays the change in ``rng_counters`` for ``fingerprint``.

  First call for a given ``fingerprint`` (no entry in the thread-local cache):
  the current ``rng_counters`` of every scope are snapshotted, the per-key
  difference against ``capture_old_counts`` is computed, and that difference
  is stored in the cache under ``fingerprint``. The scopes are not modified.

  Later calls with the same ``fingerprint``: the stored difference is added to
  ``capture_old_counts`` and the result is written into each scope's
  ``rng_counters`` via ``set_from_dict``.

  NOTE(review): presumably the replay branch exists for calls where the
  wrapped function body was not re-executed and so did not advance the
  counters itself -- confirm against the caller.

  Args:
    scopes: tree of scope objects exposing ``rng_counters``.
    fingerprint: hashable key identifying the cache entry.
    capture_old_counts: tree of ``CountsHolder`` snapshots taken before the
      wrapped call, with the same structure as ``scopes``.

  Returns:
    None.
  """
  cache = _side_effect_cache.cache

  if fingerprint in cache:
    # Replay: previously recorded delta + counters before this call.
    def _advance(delta, before):
      return delta.add(before).unflat()

    def _write_back(scope, counts):
      return set_from_dict(scope.rng_counters, counts)

    replayed_counts = jax.tree.map(
        _advance, cache[fingerprint], capture_old_counts
    )
    jax.tree.map(_write_back, scopes, replayed_counts)
    return

  # Record: how far did each counter move during the call that just ran?
  counts_after = jax.tree.map(
      lambda scope: CountsHolder.make(scope.rng_counters), scopes
  )
  cache[fingerprint] = jax.tree.map(
      lambda before, after: after.sub(before),
      capture_old_counts,
      counts_after,
  )
| 1576 | + |
| 1577 | + |
1502 | 1578 | def jit(
|
1503 | 1579 | fn: Callable[..., Any],
|
1504 | 1580 | variables: CollectionFilter = True,
|
@@ -1599,13 +1675,18 @@ def inner(
|
1599 | 1675 | mutable = tuple(_hashable_filter(scope.mutable) for scope in scopes)
|
1600 | 1676 |
|
1601 | 1677 | rng_groups = jax.tree.map(
|
1602 |
| - lambda x: x.fold() if isinstance(x, LazyRng) else x, |
| 1678 | + lambda x: x.clear_suffix() if isinstance(x, LazyRng) else x, |
1603 | 1679 | rng_groups,
|
1604 | 1680 | is_leaf=lambda x: isinstance(x, LazyRng),
|
1605 | 1681 | )
|
1606 | 1682 |
|
1607 | 1683 | fingerprint = (mutable, module_hash_key)
|
1608 |
| - return jitted(fingerprint, variable_groups, rng_groups, *args, **kwargs) |
| 1684 | + capture_old_counts = jax.tree.map( |
| 1685 | + lambda s: CountsHolder.make(s.rng_counters), scopes |
| 1686 | + ) |
| 1687 | + res = jitted(fingerprint, variable_groups, rng_groups, *args, **kwargs) |
| 1688 | + _restore_rng_counters(scopes, fingerprint, capture_old_counts) |
| 1689 | + return res |
1609 | 1690 |
|
1610 | 1691 | return pack(
|
1611 | 1692 | inner, (variables,), (variables,), (rngs,), name='jit', enable_kwargs=True
|
@@ -1692,3 +1773,64 @@ def inner_loop(scope, carry):
|
def _unzip2(xs):
  """Transposes an iterable of tuples into a tuple of tuples.

  ``[(a1, b1), (a2, b2)]`` becomes ``((a1, a2), (b1, b2))``. When ``xs``
  yields nothing, a pair of empty tuples ``((), ())`` is returned so callers
  can still unpack two values.
  """
  columns = tuple(zip(*xs))
  if not columns:
    return ((), ())
  return columns
|
| 1776 | + |
| 1777 | + |
def fold_rngs(
    fn: Callable[..., Any],
    variables: CollectionFilter = True,
    rngs: PRNGSequenceFilter = True,
) -> Callable[..., Any]:
  """Lifts ``fn`` through ``pack`` under the name ``'fold_rngs'``.

  Inside the lifted call, every ``LazyRng`` leaf of the rng groups has
  ``clear_suffix()`` applied, ``fn`` is invoked as
  ``fn(scope, module_hash_key, *args, **kwargs)``, and afterwards
  ``_restore_rng_counters`` is run on the scopes with the counter snapshot
  taken before the call.

  Args:
    fn: function to lift; receives the scope and the module hash key followed
      by the caller's positional and keyword arguments.
    variables: collection filter, passed to ``pack`` as both the input and the
      output variable filter.
    rngs: PRNG sequence filter passed to ``pack``.

  Returns:
    The callable produced by ``pack``.
  """
  # Hands the active (scope_fn, repack_fn) pair from inner_fold_rngs to
  # wrapped_fold_rngs without threading it through the call arguments.
  fold_rngs_context = TransformContext[tuple[Callable, Callable]]()

  @functools.wraps(fn)
  def wrapped_fold_rngs(fingerprint, variable_groups, rng_groups, *args, **kwargs):
    scope_fn, repack_fn = fold_rngs_context.get()
    # fingerprint is (mutable, module_hash_key); fn only needs the hash key.
    module_key = fingerprint[1]
    scope = scope_fn(variable_groups, rng_groups)  # pylint: disable=not-callable
    out = fn(scope, module_key, *args, **kwargs)
    return out, repack_fn(scope)  # pylint: disable=not-callable

  def inner_fold_rngs(
      scope_fn,
      repack_fn,
      variable_groups,
      rng_groups,
      module_hash_key,
      *args,
      **kwargs,
  ):
    def _is_lazy_rng(leaf):
      return isinstance(leaf, LazyRng)

    def _strip_suffix(leaf):
      return leaf.clear_suffix() if _is_lazy_rng(leaf) else leaf

    def _snapshot(scope):
      return CountsHolder.make(scope.rng_counters)

    with fold_rngs_context.push((scope_fn, repack_fn)):
      # Scopes are built from the rng groups *before* suffixes are cleared.
      scopes: list[Scope] = jax.tree_util.tree_leaves(
          scope_fn(variable_groups, rng_groups)
      )
      mutable = tuple(_hashable_filter(s.mutable) for s in scopes)

      rng_groups = jax.tree.map(
          _strip_suffix, rng_groups, is_leaf=_is_lazy_rng
      )

      fingerprint = (mutable, module_hash_key)
      # Snapshot counters first so the change made by the call can be
      # recorded or replayed afterwards.
      capture_old_counts = jax.tree.map(_snapshot, scopes)
      res = wrapped_fold_rngs(
          fingerprint, variable_groups, rng_groups, *args, **kwargs
      )
      _restore_rng_counters(scopes, fingerprint, capture_old_counts)
      return res

  return pack(
      inner_fold_rngs,
      (variables,),
      (variables,),
      (rngs,),
      name='fold_rngs',
      enable_kwargs=True,
  )
0 commit comments