Updated getting started doc #2698
Codecov Report
@@            Coverage Diff             @@
##             main    #2698      +/-   ##
==========================================
+ Coverage   81.15%   81.22%   +0.06%
==========================================
  Files          51       53       +2
  Lines        5493     5636     +143
==========================================
+ Hits         4458     4578     +120
- Misses       1035     1058      +23
0f603fd
to
c55c2da
Compare
docs/getting_started.md
Outdated
labels_onehot = jax.nn.one_hot(labels, num_classes=10) | ||
return optax.softmax_cross_entropy(logits=logits, labels=labels_onehot).mean() |
labels_onehot = jax.nn.one_hot(labels, num_classes=10) | |
return optax.softmax_cross_entropy(logits=logits, labels=labels_onehot).mean() | |
return optax.softmax_cross_entropy_with_integer_labels( | |
logits=logits, labels=labels).mean() |
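For reference, a minimal plain-Python sketch of why the two forms in this suggestion should agree: cross entropy against a one-hot target only picks out the log-probability of the true class, so the integer-label form can skip the one-hot step. This does not call Optax; the helper names (`log_softmax`, `one_hot`, `xent_one_hot`, `xent_integer`) and the sample values are assumptions made up for illustration.

```python
import math

def log_softmax(logits):
    """Numerically stable log-softmax over one row of logits."""
    m = max(logits)
    log_z = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - log_z for x in logits]

def one_hot(label, num_classes):
    """Dense target row with 1.0 at the true class index."""
    return [1.0 if i == label else 0.0 for i in range(num_classes)]

def xent_one_hot(logits, onehot):
    """Cross entropy against a dense (one-hot) target row."""
    return -sum(t * lp for t, lp in zip(onehot, log_softmax(logits)))

def xent_integer(logits, label):
    """Cross entropy against an integer class index."""
    return -log_softmax(logits)[label]

# Hypothetical batch of 2 examples with 3 classes.
batch_logits = [[2.0, 0.5, -1.0], [0.1, 0.2, 3.0]]
batch_labels = [0, 2]

dense = [xent_one_hot(l, one_hot(y, 3)) for l, y in zip(batch_logits, batch_labels)]
sparse = [xent_integer(l, y) for l, y in zip(batch_logits, batch_labels)]

# Reduce the per-example losses (shape `[batch]`) to a scalar, as `.mean()` does.
mean_loss = sum(sparse) / len(sparse)
```

Both lists hold the same per-example losses up to floating-point rounding, which is why the one-hot conversion can be dropped.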
c55c2da
to
4229fa7
Compare
This tutorial demonstrates how to construct a simple convolutional neural | ||
Welcome to Flax! | ||
|
||
Flax is an open source Python neural network library that's built on top of [JAX](https://github.com/google/jax). This tutorial demonstrates how to construct a simple convolutional neural | ||
network (CNN) using the [Flax](https://flax.readthedocs.io) Linen API and train |
Nit: "and train it..."
docs/getting_started.md
Outdated
This tutorial demonstrates how to construct a simple convolutional neural | ||
Welcome to Flax! | ||
|
||
Flax is an open source Python neural network library that's built on top of [JAX](https://github.com/google/jax). This tutorial demonstrates how to construct a simple convolutional neural |
Nit:
Flax is an open source Python neural network library that's built on top of [JAX](https://github.com/google/jax). This tutorial demonstrates how to construct a simple convolutional neural | |
Flax is an open source Python neural network library built on top of [JAX](https://github.com/google/jax). This tutorial demonstrates how to construct a simple convolutional neural |
docs/getting_started.md
Outdated
executionInfo: | ||
elapsed: 54 | ||
status: ok | ||
timestamp: 1671500846075 | ||
user: | ||
displayName: Marcus Chiam | ||
userId: '17531616275590396120' | ||
user_tz: 300 | ||
id: a9633134 |
Can assist with cleaning up the Colab Jupyter metadata. @IvyZX may also have a method.
docs/getting_started.md
Outdated
|
||
import numpy as np # Ordinary NumPy | ||
import optax # Optimizers | ||
import tensorflow as tf # Tensorflow to operate on TFDS |
Nit: "TensorFlow"
Maybe:
import tensorflow as tf # Tensorflow to operate on TFDS | |
import tensorflow as tf # TensorFlow for certain ops like `tf.data.Dataset` |
WDYT? TF is used for things like `tf.random.set_seed(0)`, `tf.cast()`, and dtypes such as `tf.float32`. Plus, we're using the `tf.data.Dataset` API (which is separate from TFDS, AFAIK). TFDS is `tensorflow_datasets`.
docs/getting_started.md
Outdated
|
||
## 2. Define network | ||
## 4. Define network | ||
|
||
Create a convolutional neural network with the Linen API by subclassing | ||
[Module](https://flax.readthedocs.io/en/latest/flax.linen.html#core-module-abstraction). |
Nit:
In other guides we've started using "Flax Module", since "module/Module" is a common word.
Maybe here:
[Module](https://flax.readthedocs.io/en/latest/flax.linen.html#core-module-abstraction). | |
[Flax Module](https://flax.readthedocs.io/en/latest/flax.linen.html#core-module-abstraction). |
And then repeat this, which can help new users.
docs/getting_started.md
Outdated
|
||
Our function returns a simple scalar value ready for optimization, so we first take the mean of the vector shaped `[batch]` returned by Optax's loss function. | ||
Create an instance of the Module and use the [`Module.tabulate`](https://flax.readthedocs.io/en/latest/api_reference/flax.linen.html#flax.linen.Module.tabulate) method to visualize a table of the model layers by passing an RNG key and template image input. |
As before:
Create an instance of the Module and use the [`Module.tabulate`](https://flax.readthedocs.io/en/latest/api_reference/flax.linen.html#flax.linen.Module.tabulate) method to visualize a table of the model layers by passing an RNG key and template image input. | |
Create an instance of the Flax Module and use the [`Module.tabulate`](https://flax.readthedocs.io/en/latest/api_reference/flax.linen.html#flax.linen.Module.tabulate) method to visualize a table of the model layers by passing an RNG key and template image input. |
docs/getting_started.md
Outdated
+++ {"id": "lYz0Emry-ele"} | ||
|
||
## 5. Loading data | ||
We simply use `optax.softmax_cross_entropy()`. Note that this function expects both `logits` and `labels` to have shape `[batch, num_classes]`. Since the labels will be read from TFDS as integer values, we first need to convert them to a onehot encoding. |
Nit:
- Add a link to the Optax softmax cross entropy API doc and mention Optax since it's an external library.
- If you can, use "second person" like "you"/"your" (Google Style Guide).
For example:
We simply use `optax.softmax_cross_entropy()`. Note that this function expects both `logits` and `labels` to have shape `[batch, num_classes]`. Since the labels will be read from TFDS as integer values, we first need to convert them to a onehot encoding. | |
For your loss, use a predefined [`optax.softmax_cross_entropy()`](https://optax.readthedocs.io/en/latest/api.html#optax.softmax_cross_entropy) from the Optax library. Note that this function expects both `logits` and `labels` to have shape `[batch, num_classes]`. Since the labels will be read from TFDS as integer values, first convert them to a one-hot encoding. |
docs/getting_started.md
Outdated
``` | ||
|
||
+++ {"id": "UMFK51rsAUX4"} | ||
+++ {"id": "4b5ac16e"} | ||
|
||
## 6. Create train state |
Nit: Since it's a Flax term/class maybe use
## 6. Create train state | |
## 6. Create a `TrainState` |
docs/getting_started.md
Outdated
that serves most basic usecases. Usually one would subclass it to add more data | ||
to be tracked, but in this example we can use it without any modifications. | ||
[`flax.training.train_state.TrainState`](https://flax.readthedocs.io/en/latest/flax.training.html#train-state) | ||
that serves most basic usecases. We can then subclass `TrainState` so that it also contains metrics. |
Nit:
that serves most basic usecases. We can then subclass `TrainState` so that it also contains metrics. | |
that serves most basic usecases. You can then subclass `TrainState` so that it also contains metrics. |
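To illustrate the "subclass it so that it also contains metrics" idea, here is a plain-Python analogue built on the standard `dataclasses` module. It is not Flax's `TrainState` API: `BaseState`, `StateWithMetrics`, and their fields are hypothetical names, chosen only to show how a subclass adds a field and how an immutable state is updated by creating a copy.

```python
import dataclasses

@dataclasses.dataclass(frozen=True)
class BaseState:
    """Hypothetical stand-in for a train state: a step counter and parameters."""
    step: int
    params: dict

    def replace(self, **updates):
        """Returns a new state with the given fields replaced."""
        return dataclasses.replace(self, **updates)

@dataclasses.dataclass(frozen=True)
class StateWithMetrics(BaseState):
    """Subclass that additionally tracks metrics."""
    metrics: dict = dataclasses.field(default_factory=dict)

state = StateWithMetrics(step=0, params={'w': 1.0})
# The state is immutable, so an update returns a new object.
state = state.replace(step=state.step + 1, metrics={'loss': 0.25})
```

The subclass keeps everything the base state tracks and only adds the extra field, which is the pattern the sentence under review describes.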
docs/getting_started.md
Outdated
user_tz: 300 | ||
id: e0102447 | ||
--- | ||
def create_train_state(module, rng, learning_rate, momentum): | ||
"""Creates initial `TrainState`.""" |
Nit:
"""Creates initial `TrainState`.""" | |
"""Creates an initial `TrainState`.""" |
docs/getting_started.md
Outdated
import jax | ||
import jax.numpy as jnp # JAX NumPy | ||
|
||
from flax import linen as nn # The Linen API | ||
from flax.training import train_state # Useful dataclass to keep train state | ||
from flax import struct # Flax dataclasses | ||
|
||
import numpy as np # Ordinary NumPy | ||
import optax # Optimizers |
import optax # Optimizers | |
import optax # Optax for common losses and optimizers |
docs/getting_started.md
Outdated
|
||
Define a function that loads and prepares the MNIST dataset and converts the | ||
samples to floating-point numbers. | ||
Our function returns a simple scalar value ready for optimization, so we first take the mean of the vector shaped `[batch]` returned by Optax's loss function. |
Nit:
- If you can, use "second person" like "you"/"your" (Google Style Guide).
Our function returns a simple scalar value ready for optimization, so we first take the mean of the vector shaped `[batch]` returned by Optax's loss function. | |
Your function returns a simple scalar value ready for optimization, so make sure to first take the mean of the vector shaped `[batch]` returned by Optax's loss function. |
docs/getting_started.md
Outdated
+++ {"id": "mHQi20yVCsSf"} | ||
+++ {"id": "80fbb60b"} | ||
|
||
## 11. Initialize train state |
As mentioned before, maybe:
## 11. Initialize train state | |
## 11. Initialize the `TrainState` |
docs/getting_started.md
Outdated
[data stored](https://flax.readthedocs.io/en/latest/design_notes/linen_design_principles.html#how-are-parameters-represented-and-how-do-we-handle-general-differentiable-algorithms-that-update-stateful-variables) | ||
in a JAX | ||
[pytree](https://jax.readthedocs.io/en/latest/pytrees.html#pytrees-and-jax-functions). | ||
- Set TF random seed to ensure dataset shuffling is reproducible. |
Nit:
- Set TF random seed to ensure dataset shuffling is reproducible. | |
- Set the TF random seed to ensure dataset shuffling (with `tf.data.Dataset.shuffle`) is reproducible. |
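As a rough analogy for why seeding makes shuffling reproducible, the sketch below uses Python's standard `random.Random` rather than `tf.random.set_seed` or `tf.data.Dataset.shuffle`. The function name `shuffled_indices` is hypothetical; the point is only that the same seed yields the same order.

```python
import random

def shuffled_indices(num_examples, seed):
    """Returns a permutation of range(num_examples) determined by `seed`."""
    rng = random.Random(seed)  # private generator, independent of global state
    indices = list(range(num_examples))
    rng.shuffle(indices)
    return indices

first = shuffled_indices(10, seed=0)
second = shuffled_indices(10, seed=0)
# Same seed, same shuffle order on both calls.
```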
docs/getting_started.md
Outdated
|
||
## 14. Train and evaluate | ||
## 14. Inference on test set |
Nit:
## 14. Inference on test set | |
## 14. Perform inference on the test set |
docs/getting_started.md
Outdated
``` | ||
|
||
+++ {"id": "oKcRiQ89xQkF"} | ||
+++ {"id": "edb528b6"} | ||
|
||
Congrats! You made it to the end of the annotated MNIST example. You can revisit |
Nit:
Congrats! You made it to the end of the annotated MNIST example. You can revisit | |
Congratulations! You made it to the end of the annotated MNIST example. You can revisit |
"Congrats" may be considered slang (https://developers.google.com/style/translation#be-inclusive).
Left a few minor suggestions. Feel free to add them/change them or ignore them 👍 Hope this helps!
docs/getting_started.md
Outdated
@@ -30,62 +32,105 @@ If you see any changes between the two feel free to create a | |||
[pull request](https://github.com/google/flax/compare) |
This is no longer true as we've heavily modified the notebook. Maybe we should remove this note?
docs/getting_started.md
Outdated
for test_batch in test_ds.as_numpy_iterator(): | ||
test_state = compute_metrics(state=test_state, batch=test_batch) | ||
pred = state.apply_fn({'params': state.params}, test_batch['image']) # model inference | ||
break # get only the first batch | ||
pred = pred.argmax(axis=1) # argmax the logits to get predicted labels |
Realistically, we might want to create a jitted function for inference, e.g.:
for test_batch in test_ds.as_numpy_iterator(): | |
test_state = compute_metrics(state=test_state, batch=test_batch) | |
pred = state.apply_fn({'params': state.params}, test_batch['image']) # model inference | |
break # get only the first batch | |
pred = pred.argmax(axis=1) # argmax the logits to get predicted labels | |
@jax.jit | |
def pred_step(state, batch): | |
logits = state.apply_fn({'params': state.params}, batch['image']) | |
return logits.argmax(axis=1) | |
test_batch = test_ds.as_numpy_iterator().next() | |
pred = pred_step(state, test_batch) |
Even if we don't create `pred_step`, it's better to use `test_batch = test_ds.as_numpy_iterator().next()` instead of `break` in the loop.
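A self-contained sketch of the jitted inference step suggested above, under a few assumptions: `apply_fn` here is a hypothetical stand-in (a single dense layer on flattened images) instead of the CNN from the doc, and `pred_step` takes the parameters directly rather than a `TrainState`. Note that the jitted function reads from its `batch` argument and not from an outer variable.

```python
import jax
import jax.numpy as jnp

def apply_fn(variables, images):
    """Hypothetical model: one dense layer applied to flattened images."""
    flat = images.reshape((images.shape[0], -1))
    return flat @ variables['params']['kernel']

@jax.jit
def pred_step(params, batch):
    """Runs the model on a batch and returns the predicted class per example."""
    logits = apply_fn({'params': params}, batch['image'])
    return logits.argmax(axis=1)

# Toy inputs: two 2x2 "images" and a 4x3 kernel (3 classes).
params = {'kernel': jnp.eye(4, 3)}
batch = {'image': jnp.array([[[0., 1.], [0., 0.]],
                             [[0., 0.], [1., 0.]]])}
pred = pred_step(params, batch)
```

With the real model, the same shape of function would take the first test batch fetched without a loop, as proposed in this thread.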
docs/getting_started.md
Outdated
def show_img(img, ax=None, title=None): | ||
"""Shows a single image.""" | ||
if ax is None: | ||
ax = plt.gca() | ||
ax.imshow(img[..., 0], cmap='gray') | ||
ax.set_xticks([]) | ||
ax.set_yticks([]) | ||
if title: | ||
ax.set_title(title) | ||
|
||
def show_img_grid(imgs, titles): | ||
"""Shows a grid of images.""" | ||
n = int(np.ceil(len(imgs)**.5)) | ||
_, axs = plt.subplots(n, n, figsize=(3 * n, 3 * n)) | ||
for i, (img, title) in enumerate(zip(imgs, titles)): | ||
show_img(img, axs[i // n][i % n], title) | ||
``` | ||
|
||
```{code-cell} | ||
--- | ||
colab: | ||
base_uri: https://localhost:8080/ | ||
id: ugGlV3u6Iq1A | ||
outputId: d0944ddb-8d5d-4e9f-9727-040789ef3f17 | ||
--- | ||
for epoch in range(1, num_epochs + 1): | ||
# Use a separate PRNG key to permute image data during shuffling | ||
rng, input_rng = jax.random.split(rng) | ||
# Run an optimization step over a training batch | ||
state = train_epoch(state, train_ds, batch_size, epoch, input_rng) | ||
# Evaluate on the test set after each training epoch | ||
test_loss, test_accuracy = eval_model(state.params, test_ds) | ||
print(' test epoch: %d, loss: %.2f, accuracy: %.2f' % ( | ||
epoch, test_loss, test_accuracy * 100)) | ||
height: 866 | ||
executionInfo: | ||
elapsed: 981 | ||
status: ok | ||
timestamp: 1671500872908 | ||
user: | ||
displayName: Marcus Chiam | ||
userId: '17531616275590396120' | ||
user_tz: 300 | ||
id: 5d5nF3u44JFI | ||
outputId: 22e013f6-b9b7-4088-84f3-caaf248377ce | ||
--- | ||
show_img_grid( | ||
[test_batch['image'][idx] for idx in range(25)], | ||
[f'label={pred[idx]}' for idx in range(25)], | ||
) |
All of this can be reduced to:
fig, axs = plt.subplots(5, 5, figsize=(12, 12))
for i, ax in enumerate(axs.flatten()):
  ax.imshow(test_batch['image'][i, ..., 0], cmap='gray')
  ax.set_title(f"label={pred[i]}")
  ax.axis('off')
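The shorter loop visits the axes in row-major order, which matches the `axs[i // n][i % n]` indexing used by `show_img_grid` in the original cell. A small stdlib-only sketch of that index arithmetic, with the hypothetical helper names `grid_side` and `grid_position`:

```python
import math

def grid_side(num_images):
    """Side length of the smallest square grid that fits `num_images`."""
    return int(math.ceil(num_images ** 0.5))

def grid_position(i, n):
    """Row and column of the i-th image in an n-by-n row-major grid."""
    return i // n, i % n

n = grid_side(25)
positions = [grid_position(i, n) for i in range(25)]
```

For 25 images this gives a 5-by-5 grid, so hard-coding `plt.subplots(5, 5, ...)` covers the same cells in the same order.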
4229fa7
to
4f2381e
Compare
90c8bde
to
bbd9379
Compare
Awesome @chiamp, looks very good! Approved.
|
||
[](https://colab.research.google.com/github/google/flax/blob/main/docs/getting_started.ipynb) | ||
[](https://github.com/google/flax/blob/main/docs/getting_started.ipynb) | ||
|
||
# Getting Started | ||
# Quick Start |
Nit: Both "quickstart" and "quick start" seem to be OK. To be consistent with "JAX Quickstart", would it make sense to name our doc "Quickstart" (one word) or "Flax quickstart"?
## 2. Loading data | ||
|
||
Flax can use any | ||
data-loading pipeline and this example demonstrates how to utilize TFDS. Define a function that loads and prepares the MNIST dataset and converts the |
Nit:
data-loading pipeline and this example demonstrates how to utilize TFDS. Define a function that loads and prepares the MNIST dataset and converts the | |
data-loading pipeline and this example demonstrates how to utilize TensorFlow Datasets (TFDS). Define a function that loads and prepares the MNIST dataset and converts the |
Since we haven't mentioned TFDS before, it may help to spell out the full name of the library
bbd9379
to
06c7160
Compare
Updated the getting started doc as part of the content restructuring mentioned in #2627. View the doc here.
Renamed "Getting Started" to "Quick Start"