
Commit adbad95

Merge pull request #4438 from google:nnx-tabulate

Authored and committed by: Flax Authors
PiperOrigin-RevId: 713824899
Parents: 9ebdbdc + c59bc1a

24 files changed: +834 -475 lines

docs_nnx/guides/checkpointing.ipynb

Lines changed: 11 additions & 11 deletions
Large diffs are not rendered by default.

docs_nnx/mnist_tutorial.ipynb

Lines changed: 59 additions & 176 deletions
Large diffs are not rendered by default.

docs_nnx/mnist_tutorial.md

Lines changed: 16 additions & 33 deletions
@@ -112,7 +112,7 @@ Let's put the CNN model to the test! Here, you’ll perform a forward pass with
 import jax.numpy as jnp  # JAX NumPy
 
 y = model(jnp.ones((1, 28, 28, 1)))
-nnx.display(y)
+y
 ```
 
 ## 4. Create the optimizer and define some metrics
@@ -179,6 +179,9 @@ the accuracy) during the process. Typically this leads to the model achieving ar
 ```{code-cell} ipython3
 :outputId: 258a2c76-2c8f-4a9e-d48b-dde57c342a87
 
+from IPython.display import clear_output
+import matplotlib.pyplot as plt
+
 metrics_history = {
   'train_loss': [],
   'train_accuracy': [],
@@ -208,40 +211,20 @@ for step, batch in enumerate(train_ds.as_numpy_iterator()):
       metrics_history[f'test_{metric}'].append(value)
     metrics.reset()  # Reset the metrics for the next training epoch.
 
-    print(
-      f"[train] step: {step}, "
-      f"loss: {metrics_history['train_loss'][-1]}, "
-      f"accuracy: {metrics_history['train_accuracy'][-1] * 100}"
-    )
-    print(
-      f"[test] step: {step}, "
-      f"loss: {metrics_history['test_loss'][-1]}, "
-      f"accuracy: {metrics_history['test_accuracy'][-1] * 100}"
-    )
-```
-
-## 7. Visualize the metrics
-
-With Matplotlib, you can create plots for the loss and the accuracy:
-
-```{code-cell} ipython3
-:outputId: 431a2fcd-44fa-4202-f55a-906555f060ac
-
-import matplotlib.pyplot as plt  # Visualization
-
-# Plot loss and accuracy in subplots
-fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 5))
-ax1.set_title('Loss')
-ax2.set_title('Accuracy')
-for dataset in ('train', 'test'):
-  ax1.plot(metrics_history[f'{dataset}_loss'], label=f'{dataset}_loss')
-  ax2.plot(metrics_history[f'{dataset}_accuracy'], label=f'{dataset}_accuracy')
-ax1.legend()
-ax2.legend()
-plt.show()
+    clear_output(wait=True)
+    # Plot loss and accuracy in subplots
+    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 5))
+    ax1.set_title('Loss')
+    ax2.set_title('Accuracy')
+    for dataset in ('train', 'test'):
+      ax1.plot(metrics_history[f'{dataset}_loss'], label=f'{dataset}_loss')
+      ax2.plot(metrics_history[f'{dataset}_accuracy'], label=f'{dataset}_accuracy')
+    ax1.legend()
+    ax2.legend()
+    plt.show()
 ```
 
-## 10. Perform inference on the test set
+## 7. Perform inference on the test set
 
 Create a `jit`-compiled model inference function (with `nnx.jit`) - `pred_step` - to generate predictions on the test set using the learned model parameters. This will enable you to visualize test images alongside their predicted labels for a qualitative assessment of model performance.
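This diff replaces per-step print logging with an in-place live plot. For readers skimming the diff, here is a self-contained sketch of the same `clear_output(wait=True)` redraw pattern (the fake `metrics_history` values are stand-ins for the tutorial's real training metrics):

```python
from IPython.display import clear_output
import matplotlib.pyplot as plt

metrics_history = {'train_loss': [], 'test_loss': []}

for step in range(1, 6):
  # Stand-in for a real train/eval step.
  metrics_history['train_loss'].append(1.0 / step)
  metrics_history['test_loss'].append(1.2 / step)

  clear_output(wait=True)  # Clear the old figure only once new output arrives.
  fig, ax = plt.subplots(figsize=(6, 4))
  ax.set_title('Loss')
  for dataset in ('train', 'test'):
    ax.plot(metrics_history[f'{dataset}_loss'], label=f'{dataset}_loss')
  ax.legend()
  plt.show()
```

The pattern is intended for notebook cells, where each redraw replaces the previous figure in the cell output.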

docs_nnx/nnx_basics.ipynb

Lines changed: 88 additions & 41 deletions
Large diffs are not rendered by default.

docs_nnx/nnx_basics.md

Lines changed: 2 additions & 13 deletions
@@ -12,18 +12,7 @@ jupytext:
 
 Flax NNX is a new simplified API that is designed to make it easier to create, inspect, debug, and analyze neural networks in [JAX](https://jax.readthedocs.io/). It achieves this by adding first class support for Python reference semantics. This allows users to express their models using regular Python objects, which are modeled as PyGraphs (instead of pytrees), enabling reference sharing and mutability. Such API design should make PyTorch or Keras users feel at home.
 
-In this guide you will learn about:
-
-- The Flax [`nnx.Module`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/module.html) system: An example of creating and initializing a custom `Linear` layer.
-- Stateful computation: An example of creating a Flax [`nnx.Variable`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/variables.html#flax.nnx.Variable) and updating its value (such as state updates needed during the forward pass).
-- Nested [`nnx.Module`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/module.html)s: An MLP example with `Linear`, [`nnx.Dropout`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/nn/stochastic.html#flax.nnx.Dropout), and [`nnx.BatchNorm`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/nn/normalization.html#flax.nnx.BatchNorm) layers.
-- Model surgery: An example of replacing custom `Linear` layers inside a model with custom `LoraLinear` layers.
-- Flax transformations: An example of using [`nnx.jit`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/transforms.html#flax.nnx.jit) for automatic state management.
-- [`nnx.scan`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/transforms.html#flax.nnx.scan) over layers.
-- The Flax NNX Functional API: An example of a custom `StatefulLinear` layer with [`nnx.Param`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/variables.html#flax.nnx.Param)s with fine-grained control over the state.
-- [`State`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/state.html#flax.nnx.State) and [`GraphDef`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/graph.html#flax.nnx.GraphDef).
-- [`split`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/graph.html#flax.nnx.split), [`merge`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/graph.html#flax.nnx.merge), and `update`
-- Fine-grained [`State`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/state.html#flax.nnx.State) control: An example of using [`nnx.Variable`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/variables.html#flax.nnx.Variable) type `Filter`s ([`nnx.filterlib.Filter`](https://flax.readthedocs.io/en/latest/guides/filters_guide.html)) to split into multiple [`nnx.State`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/state.html#flax.nnx.State)s.
+To begin, install Flax with `pip` and import necessary dependencies:
 
 ## Setup
 
@@ -106,7 +95,7 @@ to handle them, as demonstrated in later sections of this guide.
 
 Flax [`nnx.Module`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/module.html)s can be used to compose other `Module`s in a nested structure. These can be assigned directly as attributes, or inside an attribute of any (nested) pytree type, such as a `list`, `dict`, `tuple`, and so on.
 
-The example below shows how to define a simple `MLP` by subclassing [`nnx.Module`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/module.html). The model consists of two `Linear` layers, an [`nnx.Dropout`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/nn/stochastic.html#flax.nnx.Dropout) layer, and an [`nnx.BatchNorm`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/nn/normalization.html#flax.nnx.BatchNorm) layer:
+The example below shows how to define a simple `MLP` Module consisting of two `Linear` layers, an [`nnx.Dropout`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/nn/stochastic.html#flax.nnx.Dropout) layer, and an [`nnx.BatchNorm`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/nn/normalization.html#flax.nnx.BatchNorm) layer.
 
 ```{code-cell} ipython3
 class MLP(nnx.Module):
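For reference, a rough sketch of the kind of `MLP` that section builds (this version uses the built-in `nnx.Linear` in place of the guide's hand-written `Linear`, so the exact layer definitions and sizes here are assumptions):

```python
import jax.numpy as jnp
from flax import nnx

class MLP(nnx.Module):
  def __init__(self, din: int, dmid: int, dout: int, *, rngs: nnx.Rngs):
    self.linear1 = nnx.Linear(din, dmid, rngs=rngs)
    self.dropout = nnx.Dropout(rate=0.1, rngs=rngs)
    self.bn = nnx.BatchNorm(dmid, rngs=rngs)
    self.linear2 = nnx.Linear(dmid, dout, rngs=rngs)

  def __call__(self, x):
    # Sub-Modules assigned as attributes are tracked automatically.
    x = nnx.relu(self.dropout(self.bn(self.linear1(x))))
    return self.linear2(x)

model = MLP(2, 16, 5, rngs=nnx.Rngs(0))
y = model(jnp.ones((3, 2)))  # output shape (3, 5)
```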

flax/linen/summary.py

Lines changed: 11 additions & 7 deletions
@@ -48,6 +48,13 @@
   LogicalNames,
 )
 
+try:
+  from IPython import get_ipython
+
+  in_ipython = get_ipython() is not None
+except ImportError:
+  in_ipython = False
+
 
 class _ValueRepresentation(ABC):
   """A class that represents a value in the summary table."""
@@ -242,11 +249,6 @@ def tabulate(
 
   Total Parameters: 50 (200 B)
 
-
-  **Note**: rows order in the table does not represent execution order,
-  instead it aligns with the order of keys in `variables` which are sorted
-  alphabetically.
-
   **Note**: `vjp_flops` returns `0` if the module is not differentiable.
 
   Args:
@@ -267,7 +269,9 @@ def tabulate(
       mutable.
     console_kwargs: An optional dictionary with additional keyword arguments
       that are passed to `rich.console.Console` when rendering the table.
-      Default arguments are `{'force_terminal': True, 'force_jupyter': False}`.
+      Default arguments are ``'force_terminal': True``, and ``'force_jupyter'``
+      is set to ``True`` if the code is running in a Jupyter notebook, otherwise
+      it is set to ``False``.
     table_kwargs: An optional dictionary with additional keyword arguments that
       are passed to `rich.table.Table` constructor.
     column_kwargs: An optional dictionary with additional keyword arguments that
@@ -564,7 +568,7 @@ def _render_table(
   non_params_cols: list[str],
 ) -> str:
   """A function that renders a Table to a string representation using rich."""
-  console_kwargs = {'force_terminal': True, 'force_jupyter': False}
+  console_kwargs = {'force_terminal': True, 'force_jupyter': in_ipython}
  if console_extras is not None:
    console_kwargs.update(console_extras)
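The runtime effect is that `rich` now auto-selects notebook-friendly rendering when the summary is printed inside Jupyter. A minimal sketch of the detection plus the new default `Console` arguments (the table here is made-up demo content, not real `tabulate` output):

```python
from rich.console import Console
from rich.table import Table

try:
  from IPython import get_ipython
  in_ipython = get_ipython() is not None  # None outside IPython/Jupyter
except ImportError:
  in_ipython = False

# Mirrors the defaults _render_table now passes to rich.
console = Console(force_terminal=True, force_jupyter=in_ipython)

table = Table(title='demo')
table.add_column('module')
table.add_column('params')
table.add_row('Dense', '50')
console.print(table)
```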

flax/nnx/filterlib.py

Lines changed: 3 additions & 1 deletion
@@ -54,7 +54,9 @@ def to_predicate(filter: Filter) -> Predicate:
   else:
     raise TypeError(f'Invalid collection filter: {filter:!r}. ')
 
-def filters_to_predicates(filters: tuple[Filter, ...]) -> tuple[Predicate, ...]:
+def filters_to_predicates(
+  filters: tp.Sequence[Filter],
+) -> tuple[Predicate, ...]:
   for i, filter_ in enumerate(filters):
     if filter_ in (..., True) and i != len(filters) - 1:
       remaining_filters = filters[i + 1 :]
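With the widened annotation, any sequence of filters is accepted, not just a tuple. A small usage sketch (this is an internal helper, so importing it from `flax.nnx.filterlib` is an assumption about its location rather than documented public API):

```python
from flax import nnx
from flax.nnx import filterlib

# A list now satisfies the signature; `...` (match everything) must come last.
predicates = filterlib.filters_to_predicates([nnx.Param, ...])
params_pred, rest_pred = predicates
```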

flax/nnx/graph.py

Lines changed: 5 additions & 9 deletions
@@ -24,7 +24,7 @@
 import numpy as np
 import typing_extensions as tpe
 
-from flax.nnx import filterlib, reprlib
+from flax.nnx import filterlib, reprlib, visualization
 from flax.nnx.proxy_caller import (
   ApplyCaller,
   CallableProxy,
@@ -63,7 +63,7 @@ def is_node_leaf(x: tp.Any) -> tpe.TypeGuard[NodeLeaf]:
   return isinstance(x, Variable)
 
 
-class RefMap(tp.MutableMapping[A, B], reprlib.MappingReprMixin[A, B]):
+class RefMap(tp.MutableMapping[A, B], reprlib.MappingReprMixin):
   """A mapping that uses object id as the hash for the keys."""
 
   def __init__(
@@ -248,8 +248,7 @@ def __nnx_repr__(self):
     yield reprlib.Attr('index', self.index)
 
   def __treescope_repr__(self, path, subtree_renderer):
-    import treescope  # type: ignore[import-not-found,import-untyped]
-    return treescope.repr_lib.render_object_constructor(
+    return visualization.render_object_constructor(
       object_type=type(self),
       attributes={'type': self.type, 'index': self.index},
       path=path,
@@ -272,9 +271,7 @@ def __nnx_repr__(self):
     yield reprlib.Attr('metadata', reprlib.PrettyMapping(self.metadata))
 
   def __treescope_repr__(self, path, subtree_renderer):
-    import treescope  # type: ignore[import-not-found,import-untyped]
-
-    return treescope.repr_lib.render_object_constructor(
+    return visualization.render_object_constructor(
      object_type=type(self),
      attributes={
        'type': self.type,
@@ -353,8 +350,7 @@ def __nnx_repr__(self):
     )
 
   def __treescope_repr__(self, path, subtree_renderer):
-    import treescope  # type: ignore[import-not-found,import-untyped]
-    return treescope.repr_lib.render_object_constructor(
+    return visualization.render_object_constructor(
      object_type=type(self),
      attributes={
        'type': self.type,
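These hunks route all treescope rendering through a single `visualization` helper instead of importing `treescope` at each call site. A sketch of how a custom node type could use the same helper (the `Pair` class is invented for illustration; the keyword arguments mirror the call sites above):

```python
from flax.nnx import visualization

class Pair:
  def __init__(self, first, second):
    self.first = first
    self.second = second

  def __treescope_repr__(self, path, subtree_renderer):
    # Delegates to the shared helper; no local treescope import needed.
    return visualization.render_object_constructor(
      object_type=type(self),
      attributes={'first': self.first, 'second': self.second},
      path=path,
      subtree_renderer=subtree_renderer,
    )
```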

flax/nnx/module.py

Lines changed: 0 additions & 17 deletions
@@ -403,23 +403,6 @@ def __init_subclass__(cls, experimental_pytree: bool = False) -> None:
       flatten_func=partial(_module_flatten, with_keys=False),
     )
 
-  def __treescope_repr__(self, path, subtree_renderer):
-    import treescope  # type: ignore[import-not-found,import-untyped]
-    children = {}
-    for name, value in vars(self).items():
-      if name.startswith('_'):
-        continue
-      children[name] = value
-    return treescope.repr_lib.render_object_constructor(
-      object_type=type(self),
-      attributes=children,
-      path=path,
-      subtree_renderer=subtree_renderer,
-      color=treescope.formatting_util.color_from_string(
-        type(self).__qualname__
-      )
-    )
-
   # -------------------------
   # Pytree Definition
   # -------------------------

flax/nnx/nn/linear.py

Lines changed: 1 addition & 1 deletion
@@ -1063,7 +1063,7 @@ class Embed(Module):
   >>> layer = nnx.Embed(num_embeddings=5, features=3, rngs=nnx.Rngs(0))
   >>> nnx.state(layer)
   State({
-    'embedding': VariableState(
+    'embedding': VariableState( # 15 (60 B)
       type=Param,
       value=Array([[-0.90411377, -0.3648777 , -1.1083648 ],
        [ 0.01070483, 0.27923733, 1.7487359 ],
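The new trailing comment is an element count and byte size: the `Embed` table holds 5 × 3 = 15 parameters, and at 4 bytes per value that is 60 bytes. The same arithmetic, assuming float32 (the default parameter dtype):

```python
import numpy as np

num_embeddings, features = 5, 3
n_elems = num_embeddings * features                # 15
n_bytes = n_elems * np.dtype(np.float32).itemsize  # 15 * 4 = 60
print(f'# {n_elems} ({n_bytes} B)')                # "# 15 (60 B)"
```

This also explains the `# 6 (24 B)` annotations in the normalization docstrings below: six float32 values per `bias` or `scale` vector.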

flax/nnx/nn/normalization.py

Lines changed: 5 additions & 5 deletions
@@ -395,11 +395,11 @@ class LayerNorm(Module):
 
   >>> nnx.state(layer)
   State({
-    'bias': VariableState(
+    'bias': VariableState( # 6 (24 B)
       type=Param,
       value=Array([0., 0., 0., 0., 0., 0.], dtype=float32)
     ),
-    'scale': VariableState(
+    'scale': VariableState( # 6 (24 B)
       type=Param,
       value=Array([1., 1., 1., 1., 1., 1.], dtype=float32)
     )
@@ -531,7 +531,7 @@ class RMSNorm(Module):
 
   >>> nnx.state(layer)
   State({
-    'scale': VariableState(
+    'scale': VariableState( # 6 (24 B)
       type=Param,
       value=Array([1., 1., 1., 1., 1., 1.], dtype=float32)
     )
@@ -655,11 +655,11 @@ class GroupNorm(Module):
   >>> layer = nnx.GroupNorm(num_features=6, num_groups=3, rngs=nnx.Rngs(0))
   >>> nnx.state(layer)
   State({
-    'bias': VariableState(
+    'bias': VariableState( # 6 (24 B)
       type=Param,
       value=Array([0., 0., 0., 0., 0., 0.], dtype=float32)
     ),
-    'scale': VariableState(
+    'scale': VariableState( # 6 (24 B)
       type=Param,
       value=Array([1., 1., 1., 1., 1., 1.], dtype=float32)
     )

flax/nnx/nn/stochastic.py

Lines changed: 1 addition & 1 deletion
@@ -24,7 +24,7 @@
 from flax.nnx.module import Module, first_from
 
 
-@dataclasses.dataclass
+@dataclasses.dataclass(repr=False)
 class Dropout(Module):
   """Create a dropout layer.
 
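`repr=False` tells `dataclasses` not to generate a `__repr__`, so `Dropout` keeps the representation it inherits from `Module`. A standalone sketch of the mechanism (the `Base` class here is illustrative, not Flax code):

```python
import dataclasses

class Base:
  def __repr__(self):
    return f'<{type(self).__name__} with custom repr>'

@dataclasses.dataclass
class A(Base):
  rate: float = 0.5

@dataclasses.dataclass(repr=False)
class B(Base):
  rate: float = 0.5

print(A())  # A(rate=0.5)           <- dataclass-generated __repr__
print(B())  # <B with custom repr>  <- inherited __repr__ preserved
```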
