Commit 160820c

Update JAX_examples_image_segmentation.md
1 parent af513d0 commit 160820c

File tree: 1 file changed, +95 -42 lines

docs/source/JAX_examples_image_segmentation.md (+95 -42)
@@ -12,7 +12,7 @@ kernelspec:
   name: python3
 ---
 
-# Train a transformer-based UNETR model for image segmentation with JAX
+# Image segmentation with Vision Transformer and UNETR using JAX
 
 [![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/jax-ml/jax-ai-stack/blob/main/docs/source/JAX_examples_image_segmentation.ipynb)
 
@@ -24,6 +24,14 @@ The tutorial covers the preparation of the [Oxford Pets](https://www.robots.ox.a
 
 The image above shows the UNETR architecture for processing 3D inputs, but it can be adapted to 2D inputs.
 
+By the end of this tutorial, you will learn how to:
+
+- Prepare and preprocess the Oxford Pets dataset for image segmentation.
+- Implement the UNETR model with a Vision Transformer encoder using Flax NNX.
+- Train the model, evaluate its performance, and visualize predictions.
+
+This tutorial assumes familiarity with JAX, Flax NNX, and basic deep learning and AI concepts. If you are new to JAX for AI, check out the [introductory tutorial](https://jax-ai-stack.readthedocs.io/en/latest/getting_started_with_jax_for_AI.html), which covers neural network building with [Flax NNX](https://flax.readthedocs.io/en/latest/nnx_basics.html).
+
 +++
 
 ## Setup
@@ -57,6 +65,9 @@ import cv2
 import numpy as np
 from PIL import Image # we'll read images with opencv and use Pillow as a fallback
 
+from typing import Any, Callable
+import grain.python as grain
+
 print("Jax version:", jax.__version__)
 print("Flax version:", flax.__version__)
 print("Optax version:", optax.__version__)
@@ -93,7 +104,7 @@ We can inspect the `images` folder, listing a subset of these files:
 
 ### Splitting the dataset into training and validation sets
 
-Next, we'll implement the `OxfordPetsDataset` class providing the access to the images and masks. The class implements `__len__` and `__getitem__` methods. In this example, we do not have a hard training and validation data split, so we will use the total dataset and make a random training/validation split by indices. For this purpose, we create a helper `SubsetDataset` class to map indices into training and validation (test) set parts.
+Next, we'll create the `OxfordPetsDataset` class providing access to our images and masks. The class implements `__len__` and `__getitem__` methods. In this example, we do not have a hard training and validation data split, so we will use the total dataset and make a random training/validation split by indices. For this purpose, we create a helper `SubsetDataset` class to map indices into training and validation (test) set parts.
 
 ```{code-cell} ipython3
 class OxfordPetsDataset:
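
The hunk above is truncated at the start of `OxfordPetsDataset`, and `SubsetDataset` is not shown at all. As a minimal sketch of the idea described in the new paragraph (illustrative only, not the tutorial's actual implementation), an index-mapping subset wrapper only needs the `__len__`/`__getitem__` protocol:

```python
class SubsetDataset:
    """Maps a list of indices onto an underlying dataset (e.g. a train/validation split)."""

    def __init__(self, dataset, indices):
        self.dataset = dataset
        self.indices = list(indices)

    def __len__(self) -> int:
        return len(self.indices)

    def __getitem__(self, i: int):
        # Translate the subset index into an index of the full dataset.
        return self.dataset[self.indices[i]]
```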
@@ -223,16 +234,16 @@ display_datapoint(val_dataset[0], label=" (val set)")
 
 ### Data augmentation
 
-Here, we'll define a simple data augmentation pipeline of joined image and mask transformations using [Albumentations](https://albumentations.ai/docs/examples/example/). We will apply geometric and color transformations to increase the diversity of the training data. For more details on the Albumentations transformations, check the [Albumentations reference API](https://albumentations.ai/docs/api_reference/full_reference/).
+Data augmentation increases the diversity of the training data through random rotations, resized crops, horizontal flips, and brightness/contrast adjustments. In this section, we'll define a simple augmentation pipeline of joined image and mask transformations using [Albumentations](https://albumentations.ai/docs/examples/example/), applying both geometric and color transformations. For more details on the Albumentations transformations, check the [Albumentations reference API](https://albumentations.ai/docs/api_reference/full_reference/).
 
 ```{code-cell} ipython3
 img_size = 256
 
 train_transforms = A.Compose([
     A.Affine(rotate=(-35, 35), cval_mask=1, p=0.3), # Random rotations -35 to 35 degrees
     A.RandomResizedCrop(width=img_size, height=img_size, scale=(0.7, 1.0)), # Crop a random part of the input and rescale it to a specified size
-    A.HorizontalFlip(p=0.5), # Horizontal random flip
-    A.RandomBrightnessContrast(p=0.4), # Randomly changes the brightness and contrast
+    A.HorizontalFlip(p=0.5), # Horizontal random flip.
+    A.RandomBrightnessContrast(p=0.4), # Randomly changes the brightness and contrast.
     A.Normalize(), # Normalize the image and cast to float
 ])
 
@@ -243,6 +254,16 @@ val_transforms = A.Compose([
 ])
 ```
 
+In the code above:
+
+- `Affine`: Applies random rotations to augment the dataset.
+- `RandomResizedCrop`: Crops a random part of the image and then rescales it.
+- `HorizontalFlip`: Randomly flips images horizontally.
+- `RandomBrightnessContrast`: Adjusts brightness and contrast to introduce variation to our data.
+- `Normalize`: Normalizes the images.
+
+Let's preview the dataset after transformations:
+
 ```{code-cell} ipython3
 output = train_transforms(**train_dataset[0])
 img, mask = output["image"], output["mask"]
@@ -259,14 +280,9 @@ print("Mask array info:", mask.dtype, mask.shape, mask.min(), mask.max())
 
 ### Data loading with `grain.IndexSampler` and `grain.DataLoader`
 
-Let's now use [`grain`](https://github.com/google/grain) to perform data loading, augmentations and batching on a single device using multiple workers. We will create a random index sampler for training and an unshuffled sampler for validation.
+Let's now use [`grain`](https://github.com/google/grain) to perform data loading, augmentations and batching on a single device using multiple workers. We will create a random index sampler for training and an unshuffled sampler for validation. Note that using multiple workers (`worker_count`) allows us to parallelize data transformations, speeding up the data loading process.
 
 ```{code-cell} ipython3
-from typing import Any, Callable
-
-import grain.python as grain
-
-
 class DataAugs(grain.MapTransform):
     def __init__(self, transforms: Callable):
         self.albu_transforms = transforms
@@ -299,6 +315,8 @@ val_sampler = grain.IndexSampler(
 )
 ```
 
+Using multiple workers (`worker_count=4`) allows for parallel processing of transformations, improving efficiency.
+
 ```{code-cell} ipython3
 train_loader = grain.DataLoader(
     data_source=train_dataset,
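
The rest of this `DataLoader` call lies outside the hunk. A hedged sketch of how such a loader is typically assembled with `grain.python` (the sampler, batch size, and transform names below are assumptions, not necessarily the tutorial's exact values):

```python
# Sketch only: wires together the pieces described above.
train_loader = grain.DataLoader(
    data_source=train_dataset,        # Dataset exposing __len__/__getitem__.
    sampler=train_sampler,            # Shuffled grain.IndexSampler.
    operations=[
        DataAugs(train_transforms),   # Joint image/mask Albumentations transforms.
        grain.Batch(16, drop_remainder=True),
    ],
    worker_count=4,                   # Run the operations in 4 parallel worker processes.
)
```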
@@ -336,7 +354,7 @@ train_eval_loader = grain.DataLoader(
 )
 ```
 
-Split the training and validation sets into batches.
+Split the training and validation sets into batches:
 
 ```{code-cell} ipython3
 train_batch = next(iter(train_loader))
@@ -364,19 +382,19 @@ for img, mask in zip(images[:3], masks[:3]):
     display_datapoint({"image": img, "mask": mask}, label=" (augmented validation set)")
 ```
 
-## Defining the UNETR architecture with the ViT encoder
+## Implementing the UNETR architecture with the ViT encoder
 
 In this section, we will implement the UNETR model from scratch using Flax NNX. The transformer encoder of UNETR is a Vision Transformer (ViT), as discussed in the beginning of this tutorial. The feature maps returned by the ViT all have the same spatial size (`H / 16, W / 16`); deconvolutions are used to upsample and concatenate them until the original image size is reached.
 
 The reference PyTorch implementation of this model can be found on the [MONAI Library GitHub repository](https://github.com/Project-MONAI/MONAI/blob/dev/monai/networks/nets/unetr.py).
 
 ### The ViT encoder implementation
 
-Here, we will implement the following modules of the ViT:
+Here, we will implement the following modules of the ViT, based on the ViT paper (["An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale"](https://arxiv.org/abs/2010.11929)):
 
 - `PatchEmbeddingBlock`: The patch embedding block, which maps patches of pixels to a sequence of vectors.
+- `MLPBlock`: The multilayer perceptron (MLP) block.
 - `ViTEncoderBlock`: The ViT encoder block.
-- `MLPBlock`: The multilayer perceptron (MLP) block.
 
 ```{code-cell} ipython3
 ---
@@ -385,7 +403,7 @@ jupyter:
 ---
 class PatchEmbeddingBlock(nnx.Module):
     """
-    A patch embedding block, based on the ViT ("An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale" https://arxiv.org/abs/2010.11929
+    A patch embedding block, based on the ViT paper.
 
     Args:
         in_channels (int): Number of input channels in the image (such as 3 for RGB).
@@ -407,6 +425,7 @@ class PatchEmbeddingBlock(nnx.Module):
         rngs: nnx.Rngs = nnx.Rngs(0),
     ):
         n_patches = (img_size // patch_size) ** 2
+        # The convolution to extract patch embeddings using `flax.nnx.Conv`.
         self.patch_embeddings = nnx.Conv(
             in_channels,
             hidden_size,
@@ -417,20 +436,27 @@ class PatchEmbeddingBlock(nnx.Module):
             rngs=rngs,
         )
 
+        # Positional embeddings for each patch using `flax.nnx.Param` and `jax.nn.initializers.truncated_normal`.
         initializer = jax.nn.initializers.truncated_normal(stddev=0.02)
         self.position_embeddings = nnx.Param(
             initializer(rngs.params(), (1, n_patches, hidden_size), jnp.float32)
         )
+        # Dropout for regularization using `flax.nnx.Dropout`.
         self.dropout = nnx.Dropout(dropout_rate, rngs=rngs)
 
     def __call__(self, x: jax.Array) -> jax.Array:
+        # Apply the convolution to extract patch embeddings.
         x = self.patch_embeddings(x)
+        # Reshape for adding positional embeddings.
         x = x.reshape(x.shape[0], -1, x.shape[-1])
+        # Add positional embeddings.
         embeddings = x + self.position_embeddings
+        # Apply dropout for regularization.
         embeddings = self.dropout(embeddings)
         return embeddings
 
 
+# Instantiate the patch embedding block.
 mod = PatchEmbeddingBlock(3, 256, 16, 768, 0.5)
 x = jnp.ones((4, 256, 256, 3))
 y = mod(x)
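
For orientation: with `img_size=256`, `patch_size=16`, and `hidden_size=768`, the example instantiation above produces `(256 // 16) ** 2 = 256` patch embeddings per image, so for the `(4, 256, 256, 3)` input the output is expected to satisfy:

```python
# Sanity check on the patch embedding output (follows from the sizes used above).
assert y.shape == (4, (256 // 16) ** 2, 768)  # (batch, num_patches, hidden_size)
```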
@@ -447,13 +473,19 @@ from typing import Callable
 
 class MLPBlock(nnx.Sequential):
     """
-    A multi-layer perceptron block, based on: "Dosovitskiy et al.,
-    An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale <https://arxiv.org/abs/2010.11929>"
+    A multi-layer perceptron (MLP) block, inheriting from `flax.nnx.Sequential`.
+
+    Args:
+        hidden_size (int): Dimensionality of the hidden layer.
+        mlp_dim (int): Dimension of the hidden layers in the feed-forward/MLP block.
+        dropout_rate (float): Dropout rate (for regularization). Defaults to 0.0.
+        activation_layer: Activation function. Defaults to `flax.nnx.gelu` (GeLU).
+        rngs (flax.nnx.Rngs): A set of named `flax.nnx.RngStream` objects that generate a stream of JAX pseudo-random number generator (PRNG) keys. Defaults to `flax.nnx.Rngs(0)`.
     """
     def __init__(
         self,
-        hidden_size: int, # dimension of hidden layer.
-        mlp_dim: int, # dimension of feedforward layer
+        hidden_size: int, # Dimension of hidden layer.
+        mlp_dim: int, # Dimension of feedforward layer.
         dropout_rate: float = 0.0,
         activation_layer: Callable = nnx.gelu,
         *,
@@ -468,7 +500,7 @@ class MLPBlock(nnx.Sequential):
         ]
         super().__init__(*layers)
 
-
+# Instantiate the MLP block.
 mod = MLPBlock(768, 3072, 0.5)
 x = jnp.ones((4, 256, 768))
 y = mod(x)
@@ -482,14 +514,21 @@ jupyter:
 ---
 class ViTEncoderBlock(nnx.Module):
     """
-    A transformer encoder block, based on: "Dosovitskiy et al.,
-    An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale <https://arxiv.org/abs/2010.11929>"
+    A ViT encoder block, inheriting from `flax.nnx.Module`.
+
+    Args:
+        hidden_size (int): Dimensionality of the hidden layer.
+        mlp_dim (int): Dimension of the hidden layers in the feed-forward/MLP block.
+        num_heads (int): Number of attention heads in each transformer layer.
+        dropout_rate (float): Dropout rate (for regularization). Defaults to 0.0.
+        rngs (flax.nnx.Rngs): A set of named `flax.nnx.RngStream` objects that generate a stream of JAX pseudo-random number generator (PRNG) keys. Defaults to `flax.nnx.Rngs(0)`.
+
     """
     def __init__(
         self,
-        hidden_size: int, # dimension of hidden layer.
-        mlp_dim: int, # dimension of feedforward layer.
-        num_heads: int, # number of attention heads
+        hidden_size: int, # Dimension of hidden layer.
+        mlp_dim: int, # Dimension of feedforward layer.
+        num_heads: int, # Number of attention heads.
         dropout_rate: float = 0.0,
         *,
         rngs: nnx.Rngs = nnx.Rngs(0),
@@ -511,7 +550,7 @@ class ViTEncoderBlock(nnx.Module):
         x = x + self.mlp(self.norm2(x))
         return x
 
-
+# Instantiate the ViT encoder block.
 mod = ViTEncoderBlock(768, 3072, 12)
 x = jnp.ones((4, 256, 768))
 y = mod(x)
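
Note that the encoder block is shape-preserving: the residual pattern visible above (`x = x + self.mlp(self.norm2(x))`, with an analogous self-attention branch before it) maps a `(batch, tokens, hidden_size)` sequence to a sequence of the same shape:

```python
# Shape check for the example instantiation above.
assert y.shape == x.shape == (4, 256, 768)
```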
@@ -524,19 +563,28 @@ jupyter:
   source_hidden: true
 ---
 class ViT(nnx.Module):
-    """
-    Vision Transformer (ViT) Feature Extractor, based on: "Dosovitskiy et al.,
-    An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale <https://arxiv.org/abs/2010.11929>"
+    """Implements the ViT feature extractor, inheriting from `flax.nnx.Module`.
+
+    Args:
+        in_channels (int): Number of input channels in the image (such as 3 for RGB).
+        img_size (int): Input image size.
+        patch_size (int): Size of the patches extracted from the image.
+        hidden_size (int): Dimensionality of the embedding vectors. Defaults to 768.
+        mlp_dim (int): Dimension of the hidden layers in the feed-forward/MLP block. Defaults to 3072.
+        num_layers (int): Number of transformer encoder layers. Defaults to 12.
+        num_heads (int): Number of attention heads in each transformer layer. Defaults to 12.
+        dropout_rate (float): Dropout rate (for regularization). Defaults to 0.0.
+        rngs (flax.nnx.Rngs): A set of named `flax.nnx.RngStream` objects that generate a stream of JAX pseudo-random number generator (PRNG) keys. Defaults to `flax.nnx.Rngs(0)`.
     """
     def __init__(
         self,
-        in_channels: int, # dimension of input channels
-        img_size: int, # dimension of input image
-        patch_size: int, # dimension of patch size
-        hidden_size: int = 768, # dimension of hidden layer
-        mlp_dim: int = 3072, # dimension of feedforward layer
-        num_layers: int = 12, # number of transformer blocks
-        num_heads: int = 12, # number of attention heads
+        in_channels: int, # Dimension of input channels.
+        img_size: int, # Dimension of input image.
+        patch_size: int, # Dimension of patch size.
+        hidden_size: int = 768, # Dimension of hidden layer.
+        mlp_dim: int = 3072, # Dimension of feedforward layer.
+        num_layers: int = 12, # Number of transformer blocks.
+        num_heads: int = 12, # Number of attention heads.
         dropout_rate: float = 0.0,
         *,
         rngs: nnx.Rngs = nnx.Rngs(0),
@@ -567,7 +615,7 @@ class ViT(nnx.Module):
         x = self.norm(x)
         return x, hidden_states_out
 
-
+# Instantiate the ViT feature extractor.
 mod = ViT(3, 224, 16)
 x = jnp.ones((4, 224, 224, 3))
 y, hstates = mod(x)
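
With `img_size=224` and `patch_size=16` the feature extractor produces `(224 // 16) ** 2 = 196` tokens of width 768, and `hstates` collects intermediate encoder outputs; UNETR later reshapes selected ones back into `H/16 x W/16` feature maps for the decoder's skip connections:

```python
# Expected output shape for the example instantiation above.
assert y.shape == (4, (224 // 16) ** 2, 768)
```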
@@ -1069,7 +1117,7 @@ plt.show()
 optimizer = nnx.Optimizer(model, optax.adam(lr_schedule, momentum))
 ```
 
-Let us implement Jaccard loss and the loss function combining Cross-Entropy and Jaccard losses.
+Let us implement the Jaccard loss, and then define the total loss combining the Cross-Entropy and Jaccard losses:
 
 ```{code-cell} ipython3
 def compute_softmax_jaccard_loss(logits, masks, reduction="mean"):
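
The `compute_softmax_jaccard_loss` cell is cut off in this hunk. For reference, a minimal, self-contained sketch of a softmax Jaccard (soft IoU) loss is shown below; it is illustrative and may differ from the tutorial's exact implementation (for example, in how the `reduction` argument is handled):

```python
import jax
import jax.numpy as jnp


def softmax_jaccard_loss(logits: jax.Array, masks: jax.Array, eps: float = 1e-6) -> jax.Array:
    """Soft Jaccard/IoU loss: logits (N, H, W, C), masks (N, H, W) integer labels."""
    num_classes = logits.shape[-1]
    probs = jax.nn.softmax(logits, axis=-1)
    onehot = jax.nn.one_hot(masks, num_classes)
    reduce_dims = tuple(range(probs.ndim - 1))  # Sum over batch and spatial dims, keep classes.
    intersection = jnp.sum(probs * onehot, axis=reduce_dims)
    union = jnp.sum(probs + onehot, axis=reduce_dims) - intersection
    jaccard = (intersection + eps) / (union + eps)
    return jnp.mean(1.0 - jaccard)  # Average the per-class losses.


# Quick smoke test with dummy data.
dummy_logits = jnp.zeros((2, 8, 8, 3))
dummy_masks = jnp.zeros((2, 8, 8), dtype=jnp.int32)
print(softmax_jaccard_loss(dummy_logits, dummy_masks))
```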
@@ -1100,16 +1148,20 @@ def compute_softmax_jaccard_loss(logits, masks, reduction="mean"):
 def compute_losses_and_logits(model: nnx.Module, images: jax.Array, masks: jax.Array):
     logits = model(images)
 
+    # Cross-Entropy loss.
     xentropy_loss = optax.softmax_cross_entropy_with_integer_labels(
         logits=logits, labels=masks
     ).mean()
 
+    # Jaccard loss.
     jacc_loss = compute_softmax_jaccard_loss(logits=logits, masks=masks)
+
+    # Total loss.
     loss = xentropy_loss + jacc_loss
     return loss, (xentropy_loss, jacc_loss, logits)
 ```
 
-Now, we will implement a confusion matrix metric derived from [`nnx.Metric`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/training/metrics.html#flax.nnx.metrics.Metric). A confusion matrix will help us to compute the Intersection-Over-Union (IoU) metric per class and on average. Finally, we can also compute the accuracy metric using the confusion matrix.
+Now, we will implement a confusion matrix metric derived from [`flax.nnx.Metric`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/training/metrics.html#flax.nnx.metrics.Metric). A confusion matrix will help us to compute the Intersection-Over-Union (IoU) metric per class and on average. Finally, we can also compute the accuracy metric using the confusion matrix.
 
 ```{code-cell} ipython3
 class ConfusionMatrix(nnx.Metric):
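
The body of `ConfusionMatrix` lies outside this hunk. As a reference for the metrics mentioned above (not the tutorial's exact code), per-class IoU, mean IoU, and accuracy can be derived from a `(num_classes, num_classes)` confusion matrix `cm` with ground-truth classes along the rows and predictions along the columns:

```python
import jax.numpy as jnp


def metrics_from_confusion_matrix(cm: jnp.ndarray, eps: float = 1e-15):
    true_pos = jnp.diag(cm)
    # IoU per class: TP / (TP + FP + FN).
    iou_per_class = true_pos / (cm.sum(axis=0) + cm.sum(axis=1) - true_pos + eps)
    mean_iou = iou_per_class.mean()
    accuracy = true_pos.sum() / (cm.sum() + eps)
    return iou_per_class, mean_iou, accuracy


cm = jnp.array([[50.0, 2.0, 1.0], [3.0, 40.0, 5.0], [0.0, 4.0, 45.0]])
print(metrics_from_confusion_matrix(cm))
```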
@@ -1226,8 +1278,9 @@ def eval_step(
     ) # In-place updates.
 ```
 
-We will also define metrics we want to compute during the evaluation phase: total loss and confusion matrix computed on training and validation datasets. Finally, we define helper objects to store the metrics history.
-Metrics like IoU per class, mean IoU and accuracy will be computed using the confusion matrix in the evaluation code.
+Next, we'll define metrics for the evaluation phase: the total loss and the confusion matrix computed on training and validation datasets. We'll also define helper objects to store the metrics history.
+
+Metrics like IoU per class, mean IoU and accuracy will be calculated using the confusion matrix in the evaluation code.
 
 ```{code-cell} ipython3
 eval_metrics = nnx.MultiMetric(
