Skip to content

Commit 7efdfc8

Browse files
author
Flax Authors
committed
Merge pull request #4493 from google:nnx-tabulate-2
PiperOrigin-RevId: 724371721
2 parents 429033b + caed9ca commit 7efdfc8

File tree

9 files changed

+663
-28
lines changed

9 files changed

+663
-28
lines changed
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
summary
2+
------------------------
3+
4+
.. automodule:: flax.nnx
5+
.. currentmodule:: flax.nnx
6+
7+
.. autofunction:: tabulate

flax/nnx/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -168,4 +168,5 @@
168168
from .extract import to_tree as to_tree
169169
from .extract import from_tree as from_tree
170170
from .extract import NodeStates as NodeStates
171+
from .summary import tabulate as tabulate
171172
from . import traversals as traversals

flax/nnx/extract.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -144,7 +144,7 @@ def check_consistent_aliasing(
144144
for path, value in graph.iter_graph(node):
145145
if graph.is_graph_node(value) or isinstance(value, graph.Variable):
146146
if isinstance(value, Object):
147-
value.check_valid_context(
147+
value._check_valid_context(
148148
lambda: f'Trying to extract graph node from different trace level, got {value!r}'
149149
)
150150
if isinstance(value, graph.Variable):

flax/nnx/object.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@
3535
visualization,
3636
)
3737
from flax.nnx.variablelib import Variable, VariableState
38-
from flax.typing import SizeBytes, value_stats
38+
from flax.typing import SizeBytes
3939

4040
G = tp.TypeVar('G', bound='Object')
4141

@@ -57,7 +57,7 @@ def _collect_stats(
5757
var_type = type(node)
5858
if issubclass(var_type, nnx.RngState):
5959
var_type = nnx.RngState
60-
size_bytes = value_stats(node.value)
60+
size_bytes = SizeBytes.from_any(node.value)
6161
if size_bytes:
6262
stats[var_type] = size_bytes
6363

@@ -136,6 +136,10 @@ class Array(reprlib.Representable):
136136
shape: tp.Tuple[int, ...]
137137
dtype: tp.Any
138138

139+
@staticmethod
140+
def from_array(array: jax.Array | np.ndarray) -> Array:
141+
return Array(array.shape, array.dtype)
142+
139143
def __nnx_repr__(self):
140144
yield reprlib.Object(type='Array', same_line=True)
141145
yield reprlib.Attr('shape', self.shape)
@@ -169,12 +173,12 @@ def __setattr__(self, name: str, value: Any) -> None:
169173
self._setattr(name, value)
170174

171175
def _setattr(self, name: str, value: tp.Any) -> None:
172-
self.check_valid_context(
176+
self._check_valid_context(
173177
lambda: f"Cannot mutate '{type(self).__name__}' from different trace level"
174178
)
175179
object.__setattr__(self, name, value)
176180

177-
def check_valid_context(self, error_msg: tp.Callable[[], str]) -> None:
181+
def _check_valid_context(self, error_msg: tp.Callable[[], str]) -> None:
178182
if not self._object__state.trace_state.is_valid():
179183
raise errors.TraceContextError(error_msg())
180184

flax/nnx/rnglib.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ def __post_init__(self):
6464
raise TypeError(f'key must be a jax.Array, got {type(self.key)}')
6565

6666
def __call__(self) -> jax.Array:
67-
self.check_valid_context(
67+
self._check_valid_context(
6868
lambda: 'Cannot call RngStream from a different trace level'
6969
)
7070
key = jax.random.fold_in(self.key.value, self.count.value)

flax/nnx/summary.py

Lines changed: 553 additions & 0 deletions
Large diffs are not rendered by default.

flax/nnx/variablelib.py

Lines changed: 5 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -25,11 +25,7 @@
2525

2626
from flax import errors
2727
from flax.nnx import filterlib, reprlib, tracers, visualization
28-
from flax.typing import (
29-
Missing,
30-
PathParts,
31-
value_stats,
32-
)
28+
from flax.typing import Missing, PathParts, SizeBytes
3329
import jax.tree_util as jtu
3430

3531
A = tp.TypeVar('A')
@@ -315,7 +311,7 @@ def to_state(self: Variable[A]) -> VariableState[A]:
315311
return VariableState(type(self), self.raw_value, **self._var_metadata)
316312

317313
def __nnx_repr__(self):
318-
stats = value_stats(self.value)
314+
stats = SizeBytes.from_any(self.value)
319315
if stats:
320316
comment = f' # {stats}'
321317
else:
@@ -327,7 +323,7 @@ def __nnx_repr__(self):
327323
yield reprlib.Attr(name, repr(value))
328324

329325
def __treescope_repr__(self, path, subtree_renderer):
330-
size_bytes = value_stats(self.value)
326+
size_bytes = SizeBytes.from_any(self.value)
331327
if size_bytes:
332328
stats_repr = f' # {size_bytes}'
333329
first_line_annotation = treescope.rendering_parts.comment_color(
@@ -814,7 +810,7 @@ def __delattr__(self, name: str) -> None:
814810
del self._var_metadata[name]
815811

816812
def __nnx_repr__(self):
817-
stats = value_stats(self.value)
813+
stats = SizeBytes.from_any(self.value)
818814
if stats:
819815
comment = f' # {stats}'
820816
else:
@@ -828,7 +824,7 @@ def __nnx_repr__(self):
828824
yield reprlib.Attr(name, value)
829825

830826
def __treescope_repr__(self, path, subtree_renderer):
831-
size_bytes = value_stats(self.value)
827+
size_bytes = SizeBytes.from_any(self.value)
832828
if size_bytes:
833829
stats_repr = f' # {size_bytes}'
834830
first_line_annotation = treescope.rendering_parts.comment_color(

flax/typing.py

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -194,19 +194,19 @@ class SizeBytes: # type: ignore[misc]
194194
size: int
195195
bytes: int
196196

197-
@staticmethod
198-
def from_array(x: ShapeDtype) -> SizeBytes:
197+
@classmethod
198+
def from_array(cls, x: ShapeDtype):
199199
size = int(np.prod(x.shape))
200200
dtype: jnp.dtype
201201
if isinstance(x.dtype, str):
202202
dtype = jnp.dtype(x.dtype)
203203
else:
204204
dtype = x.dtype # type: ignore
205205
bytes = size * dtype.itemsize # type: ignore
206-
return SizeBytes(size, bytes)
206+
return cls(size, bytes)
207207

208-
def __add__(self, other: SizeBytes) -> SizeBytes:
209-
return SizeBytes(self.size + other.size, self.bytes + other.bytes)
208+
def __add__(self, other: SizeBytes):
209+
return type(self)(self.size + other.size, self.bytes + other.bytes)
210210

211211
def __bool__(self) -> bool:
212212
return bool(self.size)
@@ -215,12 +215,12 @@ def __repr__(self) -> str:
215215
bytes_repr = _bytes_repr(self.bytes)
216216
return f'{self.size:,} ({bytes_repr})'
217217

218+
@classmethod
219+
def from_any(cls, x):
220+
leaves = jax.tree.leaves(x)
221+
size_bytes = cls(0, 0)
222+
for leaf in leaves:
223+
if has_shape_dtype(leaf):
224+
size_bytes += cls.from_array(leaf)
218225

219-
def value_stats(x):
220-
leaves = jax.tree.leaves(x)
221-
size_bytes = SizeBytes(0, 0)
222-
for leaf in leaves:
223-
if has_shape_dtype(leaf):
224-
size_bytes += SizeBytes.from_array(leaf)
225-
226-
return size_bytes
226+
return size_bytes

tests/nnx/summary_test.py

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
# Copyright 2024 The Flax Authors.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import jax.numpy as jnp
16+
from absl.testing import absltest
17+
18+
from flax import nnx
19+
20+
CONSOLE_TEST_KWARGS = dict(force_terminal=False, no_color=True, width=10_000)
21+
22+
23+
class SummaryTest(absltest.TestCase):
24+
def test_tabulate(self):
25+
class Block(nnx.Module):
26+
def __init__(self, din, dout, rngs: nnx.Rngs):
27+
self.linear = nnx.Linear(din, dout, rngs=rngs)
28+
self.bn = nnx.BatchNorm(dout, rngs=rngs)
29+
self.dropout = nnx.Dropout(0.2, rngs=rngs)
30+
31+
def forward(self, x):
32+
return nnx.relu(self.dropout(self.bn(self.linear(x))))
33+
34+
class Foo(nnx.Module):
35+
def __init__(self, rngs: nnx.Rngs):
36+
self.block1 = Block(32, 128, rngs=rngs)
37+
self.block2 = Block(128, 10, rngs=rngs)
38+
39+
def __call__(self, x):
40+
return self.block2.forward(self.block1.forward(x))
41+
42+
foo = Foo(nnx.Rngs(0))
43+
x = jnp.ones((1, 32))
44+
table_repr = nnx.tabulate(
45+
foo, x, console_kwargs=CONSOLE_TEST_KWARGS
46+
).splitlines()
47+
48+
self.assertIn('Foo Summary', table_repr[0])
49+
self.assertIn('path', table_repr[2])
50+
self.assertIn('type', table_repr[2])
51+
self.assertIn('BatchStat', table_repr[2])
52+
self.assertIn('Param', table_repr[2])
53+
self.assertIn('block1/forward', table_repr[6])
54+
self.assertIn('Block', table_repr[6])
55+
self.assertIn('block1/linear', table_repr[8])
56+
self.assertIn('Linear', table_repr[8])
57+
self.assertIn('block1/bn', table_repr[13])
58+
self.assertIn('BatchNorm', table_repr[13])
59+
self.assertIn('block1/dropout', table_repr[18])
60+
self.assertIn('Dropout', table_repr[18])
61+
self.assertIn('block2/forward', table_repr[20])
62+
self.assertIn('Block', table_repr[20])
63+
self.assertIn('block2/linear', table_repr[22])
64+
self.assertIn('Linear', table_repr[22])
65+
self.assertIn('block2/bn', table_repr[27])
66+
self.assertIn('BatchNorm', table_repr[27])
67+
self.assertIn('block2/dropout', table_repr[32])
68+
self.assertIn('Dropout', table_repr[32])
69+
70+
self.assertIn('Total', table_repr[34])
71+
self.assertIn('276 (1.1 KB)', table_repr[34])
72+
self.assertIn('5,790 (23.2 KB)', table_repr[34])
73+
self.assertIn('2 (12 B)', table_repr[34])
74+
self.assertIn('Total Parameters: 6,068 (24.3 KB)', table_repr[37])

0 commit comments

Comments
 (0)