docs/source/JAX_examples_image_segmentation.md
+95 -42
@@ -12,7 +12,7 @@ kernelspec:
  name: python3
---

-# Train a transformer-based UNETR model for image segmentation with JAX
+# Image segmentation with Vision Transformer and UNETR using JAX

[](https://colab.research.google.com/github/jax-ml/jax-ai-stack/blob/main/docs/source/JAX_examples_image_segmentation.ipynb)
@@ -24,6 +24,14 @@ The tutorial covers the preparation of the [Oxford Pets](https://www.robots.ox.a

The image above shows the UNETR architecture for processing 3D inputs, but it can be adapted to 2D inputs.

+By the end of this tutorial, you will learn how to:
+
+- Prepare and preprocess the Oxford Pets dataset for image segmentation.
+- Implement the UNETR model with a Vision Transformer encoder using Flax NNX.
+- Train the model, evaluate its performance, and visualize predictions.
+
+This tutorial assumes familiarity with JAX, Flax NNX, and basic deep learning and AI concepts. If you are new to JAX for AI, check out the [introductory tutorial](https://jax-ai-stack.readthedocs.io/en/latest/getting_started_with_jax_for_AI.html), which covers neural network building with [Flax NNX](https://flax.readthedocs.io/en/latest/nnx_basics.html).
+
+++

## Setup
@@ -57,6 +65,9 @@ import cv2
import numpy as np
from PIL import Image # we'll read images with opencv and use Pillow as a fallback
+from typing import Any, Callable
+import grain.python as grain
+
print("Jax version:", jax.__version__)
print("Flax version:", flax.__version__)
print("Optax version:", optax.__version__)
@@ -93,7 +104,7 @@ We can inspect the `images` folder, listing a subset of these files:

### Splitting the dataset into training and validation sets

-Next, we'll implement the `OxfordPetsDataset` class providing the access to the images and masks. The class implements `__len__` and `__getitem__` methods. In this example, we do not have a hard training and validation data split, so we will use the total dataset and make a random training/validation split by indices. For this purpose, we create a helper `SubsetDataset` class to map indices into training and validation (test) set parts.
+Next, we'll create the `OxfordPetsDataset` class, which provides access to our images and masks through its `__len__` and `__getitem__` methods. In this example, we do not have a predefined training/validation split, so we will use the full dataset and make a random training/validation split by indices. For this purpose, we create a helper `SubsetDataset` class that maps indices into the training and validation (test) set parts.
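
The class bodies themselves are not shown in this diff. As a rough sketch of what such dataset classes can look like (the on-disk layout, split proportion, and seed below are assumptions, not taken from the notebook):

```python
import os
import numpy as np
from PIL import Image

class OxfordPetsDataset:
    """Minimal map-style dataset: pairs each image with its trimap segmentation mask."""

    def __init__(self, path: str):
        # Assumed layout: <path>/images/*.jpg and <path>/annotations/trimaps/*.png.
        self.images_dir = os.path.join(path, "images")
        self.masks_dir = os.path.join(path, "annotations", "trimaps")
        self.names = sorted(
            f[:-4] for f in os.listdir(self.images_dir) if f.endswith(".jpg")
        )

    def __len__(self) -> int:
        return len(self.names)

    def __getitem__(self, index: int) -> dict:
        name = self.names[index]
        image = np.asarray(Image.open(os.path.join(self.images_dir, name + ".jpg")).convert("RGB"))
        mask = np.asarray(Image.open(os.path.join(self.masks_dir, name + ".png")))
        return {"image": image, "mask": mask}


class SubsetDataset:
    """Exposes only the given indices of a wrapped dataset (e.g. a train or val split)."""

    def __init__(self, dataset, indices):
        self.dataset = dataset
        self.indices = list(indices)

    def __len__(self) -> int:
        return len(self.indices)

    def __getitem__(self, index: int):
        return self.dataset[self.indices[index]]


# Example split by shuffled indices (the 95/5 proportion and seed are assumed):
# indices = np.random.default_rng(12).permutation(len(dataset))
# split = int(0.95 * len(indices))
# train_dataset = SubsetDataset(dataset, indices[:split])
# val_dataset = SubsetDataset(dataset, indices[split:])
```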
-Here, we'll define a simple data augmentation pipeline of joined image and mask transformations using [Albumentations](https://albumentations.ai/docs/examples/example/). We will apply geometric and color transformations to increase the diversity of the training data. For more details on the Albumentations transformations, check the [Albumentations reference API](https://albumentations.ai/docs/api_reference/full_reference/).
+Data augmentation can be important for increasing the diversity of the training data: here it includes random rotations, resized crops, horizontal flips, and brightness/contrast adjustments. In this section, we'll define a simple data augmentation pipeline of joined image and mask transformations using [Albumentations](https://albumentations.ai/docs/examples/example/), so that we can apply these geometric and color transformations. For more details on the Albumentations transformations, check the [Albumentations reference API](https://albumentations.ai/docs/api_reference/full_reference/).

```{code-cell} ipython3
img_size = 256

train_transforms = A.Compose([
    A.Affine(rotate=(-35, 35), cval_mask=1, p=0.3), # Random rotations from -35 to 35 degrees.
    A.RandomResizedCrop(width=img_size, height=img_size, scale=(0.7, 1.0)), # Crop a random part of the input and rescale it to a specified size.
-    A.HorizontalFlip(p=0.5), # Horizontal random flip
-    A.RandomBrightnessContrast(p=0.4), # Randomly changes the brightness and contrast
+    A.HorizontalFlip(p=0.5), # Horizontal random flip.
+    A.RandomBrightnessContrast(p=0.4), # Randomly changes the brightness and contrast.
    A.Normalize(), # Normalize the image and cast to float.
])
@@ -243,6 +254,16 @@ val_transforms = A.Compose([
])
```

+In the code above:
+
+- `Affine`: Applies random rotations to augment the dataset.
+- `RandomResizedCrop`: Crops a random part of the image and then rescales it.
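
To see how such a composed Albumentations pipeline is applied jointly to an image and its mask, here is a short, self-contained usage sketch (the dummy arrays are only for illustration and are not part of the notebook):

```python
import numpy as np

# Dummy image/mask pair just to illustrate the call convention.
image = np.random.randint(0, 255, size=(300, 400, 3), dtype=np.uint8)
mask = np.random.randint(0, 3, size=(300, 400), dtype=np.uint8)

# Albumentations pipelines are called with named arguments and return a dict;
# spatial transforms (crop, flip, affine) are applied identically to image and mask,
# while photometric transforms (brightness/contrast, normalization) only touch the image.
out = train_transforms(image=image, mask=mask)
aug_image = out["image"]  # float32, normalized, shape (256, 256, 3)
aug_mask = out["mask"]    # same spatial size as the image, label values preserved
```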
### Data loading with `grain.IndexSampler` and `grain.DataLoader`

-Let's now use [`grain`](https://github.com/google/grain) to perform data loading, augmentations and batching on a single device using multiple workers. We will create a random index sampler for training and an unshuffled sampler for validation.
+Let's now use [`grain`](https://github.com/google/grain) to perform data loading, augmentations and batching on a single device using multiple workers. We will create a random index sampler for training and an unshuffled sampler for validation. Note that using multiple workers (`worker_count`) allows us to parallelize data transformations, speeding up the data loading process.
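
The sampler and loader cells themselves are not shown in this diff; a minimal sketch of what this setup with `grain` typically looks like is below (the seed, batch size, worker count, and the transform wrapper are assumptions, not the notebook's exact code):

```python
# Samplers decide which record indices are read, and in what order.
train_sampler = grain.IndexSampler(
    num_records=len(train_dataset),
    shuffle=True,
    seed=12,                            # assumed seed
    shard_options=grain.NoSharding(),   # single-device data loading
    num_epochs=1,
)
val_sampler = grain.IndexSampler(
    num_records=len(val_dataset),
    shuffle=False,                      # keep validation order deterministic
    seed=12,
    shard_options=grain.NoSharding(),
    num_epochs=1,
)


class AugmentAndPrepare(grain.MapTransform):
    """Hypothetical grain operation applying the Albumentations pipeline to each sample."""

    def __init__(self, transforms):
        self.transforms = transforms

    def map(self, sample):
        out = self.transforms(image=sample["image"], mask=sample["mask"])
        return out["image"], out["mask"]


train_loader = grain.DataLoader(
    data_source=train_dataset,
    sampler=train_sampler,
    worker_count=4,                     # parallel workers speed up the transforms
    operations=[
        AugmentAndPrepare(train_transforms),
        grain.Batch(batch_size=8, drop_remainder=True),
    ],
)
```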
-## Defining the UNETR architecture with the ViT encoder
+## Implementing the UNETR architecture with the ViT encoder

In this section, we will implement the UNETR model from scratch using Flax NNX. The transformer encoder of UNETR is a Vision Transformer (ViT), as discussed at the beginning of this tutorial. The feature maps returned by the ViT all have the same spatial size (`H / 16, W / 16`); in the decoder, deconvolutions are used to upsample these feature maps, which are then concatenated step by step up to the original image size.

The reference PyTorch implementation of this model can be found in the [MONAI Library GitHub repository](https://github.com/Project-MONAI/MONAI/blob/dev/monai/networks/nets/unetr.py).
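
To make the `H / 16, W / 16` statement concrete for the image size used in this tutorial, here is a small worked example (illustrative only; the variable names are not from the notebook):

```python
img_size, patch_size = 256, 16

# The ViT tokenizes the image into non-overlapping 16x16 patches, so every
# feature map it returns lives on an H/16 x W/16 grid.
grid = img_size // patch_size      # 256 // 16 = 16
n_patches = grid * grid            # 16 * 16 = 256 tokens per image

# The UNETR decoder therefore has to upsample by a factor of 16 overall,
# e.g. four 2x deconvolution stages: 16 -> 32 -> 64 -> 128 -> 256.
assert grid * 2**4 == img_size
```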
### The ViT encoder implementation

-Here, we will implement the following modules of the ViT:
+Here, we will implement the following modules of the ViT, based on the ViT paper (["An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale"](https://arxiv.org/abs/2010.11929)):

- `PatchEmbeddingBlock`: The patch embedding block, which maps patches of pixels to a sequence of vectors.
+- `MLPBlock`: The multilayer perceptron (MLP) block.
- `ViTEncoderBlock`: The ViT encoder block.
-- `MLPBlock`: The multilayer perceptron (MLP) block.

```{code-cell} ipython3
---
@@ -385,7 +403,7 @@ jupyter:
---
class PatchEmbeddingBlock(nnx.Module):
    """
-    A patch embedding block, based on the ViT ("An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale" https://arxiv.org/abs/2010.11929
+    A patch embedding block, based on the ViT paper.

    Args:
        in_channels (int): Number of input channels in the image (such as 3 for RGB).
@@ -407,6 +425,7 @@ class PatchEmbeddingBlock(nnx.Module):
        rngs: nnx.Rngs = nnx.Rngs(0),
    ):
        n_patches = (img_size // patch_size) ** 2
+        # The convolution to extract patch embeddings using `flax.nnx.Conv`.
        self.patch_embeddings = nnx.Conv(
            in_channels,
            hidden_size,
@@ -417,20 +436,27 @@ class PatchEmbeddingBlock(nnx.Module):
            rngs=rngs,
        )

+        # Positional embeddings for each patch using `flax.nnx.Param` and `jax.nn.initializers.truncated_normal`.
+        # Apply the convolution to extract patch embeddings.
        x = self.patch_embeddings(x)
+        # Reshape for adding positional embeddings.
        x = x.reshape(x.shape[0], -1, x.shape[-1])
+        # Add positional embeddings.
        embeddings = x + self.position_embeddings
+        # Apply dropout for regularization.
        embeddings = self.dropout(embeddings)
        return embeddings


+# Instantiate the patch embedding block.
mod = PatchEmbeddingBlock(3, 256, 16, 768, 0.5)
x = jnp.ones((4, 256, 256, 3))
y = mod(x)
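# (Added note, not in the original diff): with img_size=256 and patch_size=16 there are
# (256 // 16) ** 2 = 256 patch tokens, so `y` here has shape (4, 256, 768).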
@@ -447,13 +473,19 @@ from typing import Callable
class MLPBlock(nnx.Sequential):
    """
-    A multi-layer perceptron block, based on: "Dosovitskiy et al.,
-    An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale <https://arxiv.org/abs/2010.11929>"
+    A multi-layer perceptron (MLP) block, inheriting from `flax.nnx.Sequential`.
+
+    Args:
+        hidden_size (int): Dimensionality of the hidden layer.
+        mlp_dim (int): Dimension of the hidden layers in the feed-forward/MLP block.
+        dropout_rate (float): Dropout rate (for regularization). Defaults to 0.0.
+        activation_layer: Activation function. Defaults to `flax.nnx.gelu` (GeLU).
+        rngs (flax.nnx.Rngs): A set of named `flax.nnx.RngStream` objects that generate a stream of JAX pseudo-random number generator (PRNG) keys. Defaults to `flax.nnx.Rngs(0)`.
    """
    def __init__(
        self,
-        hidden_size: int, # dimension of hidden layer.
-        mlp_dim: int, # dimension of feedforward layer
+        hidden_size: int, # Dimension of hidden layer.
+        mlp_dim: int, # Dimension of feedforward layer.
        dropout_rate: float = 0.0,
        activation_layer: Callable = nnx.gelu,
        *,
@@ -468,7 +500,7 @@ class MLPBlock(nnx.Sequential):
        ]
        super().__init__(*layers)

-
+# Instantiate the MLP block.
mod = MLPBlock(768, 3072, 0.5)
x = jnp.ones((4, 256, 768))
y = mod(x)
@@ -482,14 +514,21 @@ jupyter:
---
class ViTEncoderBlock(nnx.Module):
    """
-    A transformer encoder block, based on: "Dosovitskiy et al.,
-    An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale <https://arxiv.org/abs/2010.11929>"
+    A ViT encoder block, inheriting from `flax.nnx.Module`.
+
+    Args:
+        hidden_size (int): Dimensionality of the hidden layer.
+        mlp_dim (int): Dimension of the hidden layers in the feed-forward/MLP block.
+        num_heads (int): Number of attention heads in each transformer layer.
+        dropout_rate (float): Dropout rate (for regularization). Defaults to 0.0.
+        rngs (flax.nnx.Rngs): A set of named `flax.nnx.RngStream` objects that generate a stream of JAX pseudo-random number generator (PRNG) keys. Defaults to `flax.nnx.Rngs(0)`.
+
    """
    def __init__(
        self,
-        hidden_size: int, # dimension of hidden layer.
-        mlp_dim: int, # dimension of feedforward layer.
-        num_heads: int, # number of attention heads
+        hidden_size: int, # Dimension of hidden layer.
+        mlp_dim: int, # Dimension of feedforward layer.
+        num_heads: int, # Number of attention heads.
        dropout_rate: float = 0.0,
        *,
        rngs: nnx.Rngs = nnx.Rngs(0),
@@ -511,7 +550,7 @@ class ViTEncoderBlock(nnx.Module):
        x = x + self.mlp(self.norm2(x))
        return x

-
+# Instantiate the ViT encoder block.
mod = ViTEncoderBlock(768, 3072, 12)
x = jnp.ones((4, 256, 768))
y = mod(x)
@@ -524,19 +563,28 @@ jupyter:
  source_hidden: true
---
class ViT(nnx.Module):
-    """
-    Vision Transformer (ViT) Feature Extractor, based on: "Dosovitskiy et al.,
-    An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale <https://arxiv.org/abs/2010.11929>"
+    """ Implements the ViT feature extractor, inheriting from `flax.nnx.Module`.
+
+    Args:
+        in_channels (int): Number of input channels in the image (such as 3 for RGB).
+        img_size (int): Input image size.
+        patch_size (int): Size of the patches extracted from the image.
+        hidden_size (int): Dimensionality of the embedding vectors. Defaults to 768.
+        mlp_dim (int): Dimension of the hidden layers in the feed-forward/MLP block. Defaults to 3072.
+        num_layers (int): Number of transformer encoder layers. Defaults to 12.
+        num_heads (int): Number of attention heads in each transformer layer. Defaults to 12.
+        dropout_rate (float): Dropout rate (for regularization). Defaults to 0.0.
+        rngs (flax.nnx.Rngs): A set of named `flax.nnx.RngStream` objects that generate a stream of JAX pseudo-random number generator (PRNG) keys. Defaults to `flax.nnx.Rngs(0)`.
    """
    def __init__(
        self,
-        in_channels: int, # dimension of input channels
-        img_size: int, # dimension of input image
-        patch_size: int, # dimension of patch size
-        hidden_size: int = 768, # dimension of hidden layer
-        mlp_dim: int = 3072, # dimension of feedforward layer
-        num_layers: int = 12, # number of transformer blocks
-        num_heads: int = 12, # number of attention heads
+        in_channels: int, # Dimension of input channels.
+        img_size: int, # Dimension of input image.
+        patch_size: int, # Dimension of patch size.
+        hidden_size: int = 768, # Dimension of hidden layer.
+        mlp_dim: int = 3072, # Dimension of feedforward layer.
+        num_layers: int = 12, # Number of transformer blocks.
-Now, we will implement a confusion matrix metric derived from [`nnx.Metric`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/training/metrics.html#flax.nnx.metrics.Metric). A confusion matrix will help us to compute the Intersection-Over-Union (IoU) metric per class and on average. Finally, we can also compute the accuracy metric using the confusion matrix.
+Now, we will implement a confusion matrix metric derived from [`flax.nnx.Metric`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/training/metrics.html#flax.nnx.metrics.Metric). A confusion matrix will help us to compute the Intersection-Over-Union (IoU) metric per class and on average. Finally, we can also compute the accuracy metric using the confusion matrix.
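
As a reminder of how these quantities fall out of a confusion matrix, here is a standalone sketch using plain JAX (not the notebook's `ConfusionMatrix` implementation): for class `c`, the IoU is the diagonal entry divided by the sum of row `c` and column `c` minus the diagonal entry, and accuracy is the trace divided by the total count.

```python
import jax.numpy as jnp

def iou_and_accuracy_from_cm(cm: jnp.ndarray):
    """cm[i, j] counts pixels of true class i predicted as class j."""
    tp = jnp.diag(cm)
    # Union of predictions and ground truth for each class: row + column - diagonal.
    union = cm.sum(axis=1) + cm.sum(axis=0) - tp
    iou_per_class = tp / jnp.maximum(union, 1)  # avoid division by zero
    mean_iou = iou_per_class.mean()
    accuracy = tp.sum() / cm.sum()
    return iou_per_class, mean_iou, accuracy

# Tiny 3-class example (made-up counts, purely for illustration).
cm = jnp.array([
    [50.0, 2.0, 3.0],
    [4.0, 40.0, 1.0],
    [2.0, 1.0, 20.0],
])
iou, miou, acc = iou_and_accuracy_from_cm(cm)
```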

```{code-cell} ipython3
class ConfusionMatrix(nnx.Metric):
@@ -1226,8 +1278,9 @@ def eval_step(
    ) # In-place updates.
```

-We will also define metrics we want to compute during the evaluation phase: total loss and confusion matrix computed on training and validation datasets. Finally, we define helper objects to store the metrics history.
-Metrics like IoU per class, mean IoU and accuracy will be computed using the confusion matrix in the evaluation code.
+Next, we'll define metrics for the evaluation phase: the total loss and the confusion matrix computed on training and validation datasets. We'll also define helper objects to store the metrics history.
+
+Metrics like IoU per class, mean IoU and accuracy will be calculated using the confusion matrix in the evaluation code.
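
The corresponding cells are not included in this diff; a rough sketch of what such metric containers and history objects often look like with Flax NNX (the `ConfusionMatrix` constructor arguments and the history keys here are assumptions, not taken from the notebook):

```python
eval_metrics = nnx.MultiMetric(
    total_loss=nnx.metrics.Average("total_loss"),
    confusion_matrix=ConfusionMatrix(num_classes=3),  # assumed constructor signature
)

# Plain Python containers for per-epoch metric curves.
eval_metrics_history = {
    "train_total_loss": [],
    "val_total_loss": [],
    "val_mean_iou": [],
    "val_accuracy": [],
}
```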