Updated getting started doc #2698


Merged: 1 commit merged into google:main on Jan 18, 2023

Conversation

@chiamp (Collaborator) commented Dec 8, 2022

Updated getting started doc, as part of content restructuring mentioned in #2627. View the doc here.

Renamed "Getting Started" to "Quick Start"

@review-notebook-app
Check out this pull request on ReviewNB to see visual diffs and provide feedback on Jupyter Notebooks.

@chiamp chiamp self-assigned this Dec 8, 2022
@chiamp chiamp marked this pull request as draft December 8, 2022 10:41
@chiamp chiamp requested a review from cgarciae December 8, 2022 10:41
@codecov-commenter commented Dec 8, 2022

Codecov Report

Merging #2698 (4f2381e) into main (fec10eb) will increase coverage by 0.06%.
The diff coverage is n/a.

@@            Coverage Diff             @@
##             main    #2698      +/-   ##
==========================================
+ Coverage   81.15%   81.22%   +0.06%     
==========================================
  Files          51       53       +2     
  Lines        5493     5636     +143     
==========================================
+ Hits         4458     4578     +120     
- Misses       1035     1058      +23     
Impacted Files Coverage Δ
flax/linen/partitioning.py 79.06% <0.00%> (-3.15%) ⬇️
flax/linen/module.py 92.22% <0.00%> (-0.42%) ⬇️
flax/io.py 84.84% <0.00%> (-0.42%) ⬇️
flax/errors.py 85.58% <0.00%> (-0.13%) ⬇️
flax/core/scope.py 90.13% <0.00%> (ø)
flax/linen/linear.py 97.51% <0.00%> (ø)
flax/linen/summary.py 99.01% <0.00%> (ø)
flax/linen/__init__.py 100.00% <0.00%> (ø)
flax/linen/recurrent.py 100.00% <0.00%> (ø)
flax/linen/activation.py 100.00% <0.00%> (ø)
... and 8 more


Comment on lines 158 to 197
labels_onehot = jax.nn.one_hot(labels, num_classes=10)
return optax.softmax_cross_entropy(logits=logits, labels=labels_onehot).mean()

Suggested change
labels_onehot = jax.nn.one_hot(labels, num_classes=10)
return optax.softmax_cross_entropy(logits=logits, labels=labels_onehot).mean()
return optax.softmax_cross_entropy_with_integer_labels(
    logits=logits, labels=labels).mean()
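
For reference, both forms compute the same value; a minimal sketch verifying the equivalence (the logits and labels below are illustrative):

```
import jax
import jax.numpy as jnp
import optax

logits = jnp.array([[2.0, -1.0, 0.5],
                    [0.1, 1.5, -0.3]])
labels = jnp.array([0, 1])

# One-hot path (the original code):
loss_onehot = optax.softmax_cross_entropy(
    logits=logits, labels=jax.nn.one_hot(labels, num_classes=3)).mean()

# Integer-label path (the suggested change); no explicit one-hot needed:
loss_int = optax.softmax_cross_entropy_with_integer_labels(
    logits=logits, labels=labels).mean()

assert jnp.allclose(loss_onehot, loss_int)
```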

@chiamp chiamp force-pushed the getting_started_doc branch from c55c2da to 4229fa7 Compare December 20, 2022 01:50
This tutorial demonstrates how to construct a simple convolutional neural
Welcome to Flax!

Flax is an open source Python neural network library that's built on top of [JAX](https://github.com/google/jax). This tutorial demonstrates how to construct a simple convolutional neural
network (CNN) using the [Flax](https://flax.readthedocs.io) Linen API and train

Nit: "and train it..."


Nit:

Suggested change
Flax is an open source Python neural network library that's built on top of [JAX](https://github.com/google/jax). This tutorial demonstrates how to construct a simple convolutional neural
Flax is an open source Python neural network library built on top of [JAX](https://github.com/google/jax). This tutorial demonstrates how to construct a simple convolutional neural

Comment on lines 56 to 64
executionInfo:
elapsed: 54
status: ok
timestamp: 1671500846075
user:
displayName: Marcus Chiam
userId: '17531616275590396120'
user_tz: 300
id: a9633134

Can assist with cleaning up the Colab Jupyter metadata. @IvyZX may also have a method.


import numpy as np # Ordinary NumPy
import optax # Optimizers
import tensorflow as tf # Tensorflow to operate on TFDS
@8bitmp3 commented Dec 26, 2022

Nit: "TensorFlow"

Maybe:

Suggested change
import tensorflow as tf # Tensorflow to operate on TFDS
import tensorflow as tf # TensorFlow for certain ops like `tf.data.Dataset`

WDYT?

TF is used for calls like `tf.random.set_seed(0)`, as well as `tf.cast()` and setting dtypes such as `tf.float32`. Plus, we're using the `tf.data.Dataset` API (which is separate from TFDS, AFAIK). TFDS is `tensorflow_datasets`.
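
A rough sketch of how those pieces fit together in the data pipeline (batch and buffer sizes below are illustrative, and the notebook's exact code may differ):

```
import tensorflow as tf              # TensorFlow proper: tf.data, tf.cast, seeding
import tensorflow_datasets as tfds   # TFDS, a separate package

tf.random.set_seed(0)  # makes dataset shuffling reproducible

train_ds = tfds.load('mnist', split='train')
# tf.cast converts the uint8 images to tf.float32 (here also rescaled to [0, 1]):
train_ds = train_ds.map(
    lambda s: {'image': tf.cast(s['image'], tf.float32) / 255.0,
               'label': s['label']})
# tf.data.Dataset (distinct from TFDS) handles shuffling and batching:
train_ds = train_ds.shuffle(1024).batch(32).prefetch(1)
```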


## 2. Define network
## 4. Define network

Create a convolutional neural network with the Linen API by subclassing
[Module](https://flax.readthedocs.io/en/latest/flax.linen.html#core-module-abstraction).

Nit:

In other guides we have started using "Flax Module", since "module"/"Module" is a common word.

Maybe here:

Suggested change
[Module](https://flax.readthedocs.io/en/latest/flax.linen.html#core-module-abstraction).
[Flax Module](https://flax.readthedocs.io/en/latest/flax.linen.html#core-module-abstraction).

And then repeat this, which can help new users.


Our function returns a simple scalar value ready for optimization, so we first take the mean of the vector shaped `[batch]` returned by Optax's loss function.
Create an instance of the Module and use the [`Module.tabulate`](https://flax.readthedocs.io/en/latest/api_reference/flax.linen.html#flax.linen.Module.tabulate) method to visualize a table of the model layers by passing an RNG key and template image input.

As before:

Suggested change
Create an instance of the Module and use the [`Module.tabulate`](https://flax.readthedocs.io/en/latest/api_reference/flax.linen.html#flax.linen.Module.tabulate) method to visualize a table of the model layers by passing an RNG key and template image input.
Create an instance of the Flax Module and use the [`Module.tabulate`](https://flax.readthedocs.io/en/latest/api_reference/flax.linen.html#flax.linen.Module.tabulate) method to visualize a table of the model layers by passing an RNG key and template image input.
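
For readers following along, a minimal sketch of that `tabulate` call (the stand-in CNN below is simplified; the guide's model has more layers):

```
import jax
import jax.numpy as jnp
from flax import linen as nn

class CNN(nn.Module):  # simplified stand-in for the guide's model
  @nn.compact
  def __call__(self, x):
    x = nn.Conv(features=32, kernel_size=(3, 3))(x)
    x = nn.relu(x)
    x = x.reshape((x.shape[0], -1))  # flatten
    return nn.Dense(features=10)(x)

cnn = CNN()
# Pass an RNG key and a template image input (a batch of one 28x28x1 image):
print(cnn.tabulate(jax.random.PRNGKey(0), jnp.ones((1, 28, 28, 1))))
```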

+++ {"id": "lYz0Emry-ele"}

## 5. Loading data
We simply use `optax.softmax_cross_entropy()`. Note that this function expects both `logits` and `labels` to have shape `[batch, num_classes]`. Since the labels will be read from TFDS as integer values, we first need to convert them to a onehot encoding.

Nit:

  • Add a link to the Optax softmax cross entropy API doc and mention Optax since it's an external library.
  • If you can, use "second person" like "you"/"your" (Google Style Guide).

For example:

Suggested change
We simply use `optax.softmax_cross_entropy()`. Note that this function expects both `logits` and `labels` to have shape `[batch, num_classes]`. Since the labels will be read from TFDS as integer values, we first need to convert them to a onehot encoding.
For your loss, use a predefined [`optax.softmax_cross_entropy()`](https://optax.readthedocs.io/en/latest/api.html#optax.softmax_cross_entropy) from the Optax library. Note that this function expects both `logits` and `labels` to have shape `[batch, num_classes]`. Since the labels will be read from TFDS as integer values, first convert them to a one-hot encoding.


+++ {"id": "UMFK51rsAUX4"}
+++ {"id": "4b5ac16e"}

## 6. Create train state

Nit: Since it's a Flax term/class maybe use

Suggested change
## 6. Create train state
## 6. Create a `TrainState`

that serves most basic use cases. Usually one would subclass it to add more data
to be tracked, but in this example we can use it without any modifications.
[`flax.training.train_state.TrainState`](https://flax.readthedocs.io/en/latest/flax.training.html#train-state)
that serves most basic use cases. We can then subclass `TrainState` so that it also contains metrics.

Nit:

Suggested change
that serves most basic use cases. We can then subclass `TrainState` so that it also contains metrics.
that serves most basic use cases. You can then subclass `TrainState` so that it also contains metrics.
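
One possible shape of that subclass, as a sketch (the `Metrics` container and the `Dense` stand-in below are illustrative; the notebook's metrics type may differ):

```
import jax
import jax.numpy as jnp
import optax
from flax import linen as nn, struct
from flax.training import train_state

@struct.dataclass
class Metrics:  # illustrative metrics container
  loss: float
  accuracy: float

class TrainState(train_state.TrainState):
  metrics: Metrics  # extra field tracked alongside params and optimizer state

model = nn.Dense(features=10)  # stand-in for the guide's CNN
params = model.init(jax.random.PRNGKey(0), jnp.ones((1, 784)))['params']
state = TrainState.create(
    apply_fn=model.apply, params=params, tx=optax.sgd(0.01),
    metrics=Metrics(loss=0.0, accuracy=0.0))
```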

user_tz: 300
id: e0102447
---
def create_train_state(module, rng, learning_rate, momentum):
"""Creates initial `TrainState`."""

Nit:

Suggested change
"""Creates initial `TrainState`."""
"""Creates an initial `TrainState`."""

import jax
import jax.numpy as jnp # JAX NumPy

from flax import linen as nn # The Linen API
from flax.training import train_state # Useful dataclass to keep train state
from flax import struct # Flax dataclasses

import numpy as np # Ordinary NumPy
import optax # Optimizers

Suggested change
import optax # Optimizers
import optax # Optax for common losses and optimizers


Define a function that loads and prepares the MNIST dataset and converts the
samples to floating-point numbers.
Our function returns a simple scalar value ready for optimization, so we first take the mean of the vector shaped `[batch]` returned by Optax's loss function.

Nit:

  • If you can, use "second person" like "you"/"your" (Google Style Guide).
Suggested change
Our function returns a simple scalar value ready for optimization, so we first take the mean of the vector shaped `[batch]` returned by Optax's loss function.
Your function returns a simple scalar value ready for optimization, so make sure to first take the mean of the vector shaped `[batch]` returned by Optax's loss function.

+++ {"id": "mHQi20yVCsSf"}
+++ {"id": "80fbb60b"}

## 11. Initialize train state

As mentioned before, maybe:

Suggested change
## 11. Initialize train state
## 11. Initialize the `TrainState`

[data stored](https://flax.readthedocs.io/en/latest/design_notes/linen_design_principles.html#how-are-parameters-represented-and-how-do-we-handle-general-differentiable-algorithms-that-update-stateful-variables)
in a JAX
[pytree](https://jax.readthedocs.io/en/latest/pytrees.html#pytrees-and-jax-functions).
- Set TF random seed to ensure dataset shuffling is reproducible.

Nit:

Suggested change
- Set TF random seed to ensure dataset shuffling is reproducible.
- Set the TF random seed to ensure dataset shuffling (with `tf.data.Dataset.shuffle`) is reproducible.
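
On the point quoted above about parameters being stored in a JAX pytree, a minimal sketch of inspecting one (the `Dense` stand-in is illustrative):

```
import jax
import jax.numpy as jnp
from flax import linen as nn

model = nn.Dense(features=10)  # stand-in model; the guide uses a CNN
params = model.init(jax.random.PRNGKey(0), jnp.ones((1, 784)))['params']
# The parameters are a nested dict (a pytree) of arrays:
print(jax.tree_util.tree_map(jnp.shape, params))
# e.g. {'bias': (10,), 'kernel': (784, 10)}
```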


## 14. Train and evaluate
## 14. Inference on test set

Nit:

Suggested change
## 14. Inference on test set
## 14. Perform inference on the test set


+++ {"id": "oKcRiQ89xQkF"}
+++ {"id": "edb528b6"}

Congrats! You made it to the end of the annotated MNIST example. You can revisit

Nit:

Suggested change
Congrats! You made it to the end of the annotated MNIST example. You can revisit
Congratulations! You made it to the end of the annotated MNIST example. You can revisit

"Congrats" may be considered slang (https://developers.google.com/style/translation#be-inclusive).

@8bitmp3 (Collaborator) left a comment:

Left a few minor suggestions. Feel free to add them/change them or ignore them 👍 Hope this helps!

@@ -30,62 +32,105 @@ If you see any changes between the two feel free to create a
[pull request](https://github.com/google/flax/compare)

This is no longer true as we've heavily modified the notebook. Maybe we should remove this note?

Comment on lines 599 to 603
for test_batch in test_ds.as_numpy_iterator():
  test_state = compute_metrics(state=test_state, batch=test_batch)
  pred = state.apply_fn({'params': state.params}, test_batch['image'])  # model inference
  break  # get only the first batch
pred = pred.argmax(axis=1)  # argmax the logits to get predicted labels

Realistically we might want to create a jitted function for inference, e.g.:

Suggested change
for test_batch in test_ds.as_numpy_iterator():
  test_state = compute_metrics(state=test_state, batch=test_batch)
  pred = state.apply_fn({'params': state.params}, test_batch['image'])  # model inference
  break  # get only the first batch
pred = pred.argmax(axis=1)  # argmax the logits to get predicted labels

@jax.jit
def pred_step(state, batch):
  logits = state.apply_fn({'params': state.params}, batch['image'])
  return logits.argmax(axis=1)

test_batch = test_ds.as_numpy_iterator().next()
pred = pred_step(state, test_batch)


Even if we don't create `pred_step`, it's better to use:

test_batch = test_ds.as_numpy_iterator().next()

instead of `break` in the loop.

Comment on lines 618 to 654
def show_img(img, ax=None, title=None):
  """Shows a single image."""
  if ax is None:
    ax = plt.gca()
  ax.imshow(img[..., 0], cmap='gray')
  ax.set_xticks([])
  ax.set_yticks([])
  if title:
    ax.set_title(title)

def show_img_grid(imgs, titles):
  """Shows a grid of images."""
  n = int(np.ceil(len(imgs)**.5))
  _, axs = plt.subplots(n, n, figsize=(3 * n, 3 * n))
  for i, (img, title) in enumerate(zip(imgs, titles)):
    show_img(img, axs[i // n][i % n], title)

for epoch in range(1, num_epochs + 1):
  # Use a separate PRNG key to permute image data during shuffling
  rng, input_rng = jax.random.split(rng)
  # Run an optimization step over a training batch
  state = train_epoch(state, train_ds, batch_size, epoch, input_rng)
  # Evaluate on the test set after each training epoch
  test_loss, test_accuracy = eval_model(state.params, test_ds)
  print(' test epoch: %d, loss: %.2f, accuracy: %.2f' % (
      epoch, test_loss, test_accuracy * 100))

show_img_grid(
    [test_batch['image'][idx] for idx in range(25)],
    [f'label={pred[idx]}' for idx in range(25)],
)

All of this can be reduced to:

fig, axs = plt.subplots(5, 5, figsize=(12, 12))
for i, ax in enumerate(axs.flatten()):
    ax.imshow(test_batch['image'][i, ..., 0], cmap='gray')
    ax.set_title(f"label={pred[i]}")
    ax.axis('off')

@chiamp chiamp force-pushed the getting_started_doc branch from 4229fa7 to 4f2381e Compare January 6, 2023 21:32
@chiamp (Collaborator, Author) commented Jan 6, 2023

Thanks for the suggestions @cgarciae @8bitmp3! I've made the suggested updates.

@chiamp chiamp marked this pull request as ready for review January 11, 2023 21:19
@chiamp chiamp force-pushed the getting_started_doc branch 3 times, most recently from 90c8bde to bbd9379 Compare January 12, 2023 01:19
@cgarciae (Collaborator) left a comment:

Awesome @chiamp, looks very good! Approved.


[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/flax/blob/main/docs/getting_started.ipynb)
[![Open On GitHub](https://img.shields.io/badge/Open-on%20GitHub-blue?logo=GitHub)](https://github.com/google/flax/blob/main/docs/getting_started.ipynb)

# Getting Started
# Quick Start

Nit: Both "quickstart" and "quick start" seem to be OK. To be consistent with "JAX Quickstart", would it make sense to name our doc "Quickstart" (one word) or "Flax quickstart"?

## 2. Loading data

Flax can use any
data-loading pipeline and this example demonstrates how to utilize TFDS. Define a function that loads and prepares the MNIST dataset and converts the

Nit:

Suggested change
data-loading pipeline and this example demonstrates how to utilize TFDS. Define a function that loads and prepares the MNIST dataset and converts the
data-loading pipeline and this example demonstrates how to utilize TensorFlow Datasets (TFDS). Define a function that loads and prepares the MNIST dataset and converts the

Since we haven't mentioned TFDS before, it may help to spell out the full name of the library

@chiamp chiamp force-pushed the getting_started_doc branch from bbd9379 to 06c7160 Compare January 17, 2023 23:37
@copybara-service copybara-service bot merged commit 71772f6 into google:main Jan 18, 2023
@chiamp chiamp deleted the getting_started_doc branch January 18, 2023 04:30