Skip to content

Commit 6bf01f6

Browse files
committed
update datamodules how-to-guides
Signed-off-by: Samet Akcay <[email protected]>
1 parent 47d111b commit 6bf01f6

File tree

1 file changed

+15
-18
lines changed

1 file changed

+15
-18
lines changed

docs/source/markdown/guides/how_to/data/datamodules.md

+15-18
Original file line number | Diff line number | Diff line change
@@ -20,8 +20,8 @@ DataModules encapsulate all the steps needed to process data:
2020
A typical Anomalib DataModule follows this structure:
2121

2222
```python
23-
from pytorch_lightning import LightningDataModule
24-
from anomalib.data import AnomalibDataset
23+
from lightning.pytorch import LightningDataModule
24+
from anomalib.data.datasets.base.image import AnomalibDataset
2525
from torch.utils.data import DataLoader
2626

2727
class AnomalibDataModule(LightningDataModule):
@@ -108,7 +108,6 @@ from anomalib.data import MVTec
108108
datamodule = MVTec(
109109
root="./datasets/MVTec",
110110
category="bottle",
111-
image_size=(256, 256),
112111
train_batch_size=32,
113112
eval_batch_size=32,
114113
num_workers=8
@@ -132,16 +131,14 @@ for batch in train_loader:
132131
from anomalib.data import Avenue
133132

134133
datamodule = Avenue(
135-
root="./datasets/Avenue",
136-
frame_size=(256, 256),
137-
train_batch_size=8,
138-
clip_length=16
134+
clip_length_in_frames=2,
135+
frames_between_clips=1,
136+
target_frame="last",
139137
)
140-
141-
# Access video batches
142-
for batch in datamodule.train_dataloader():
143-
print(batch.frames.shape) # torch.Size([8, 16, 3, 256, 256])
144-
print(batch.target_frame) # Frame indices
138+
datamodule.setup()
139+
i, data = next(enumerate(datamodule.train_dataloader()))
140+
data["image"].shape
141+
# torch.Size([32, 2, 3, 256, 256])
145142
```
146143

147144
### 3. Depth DataModule
@@ -152,14 +149,15 @@ from anomalib.data import MVTec3D
152149
datamodule = MVTec3D(
153150
root="./datasets/MVTec3D",
154151
category="bagel",
155-
image_size=(256, 256),
156-
train_batch_size=32
152+
train_batch_size=32,
157153
)
158154

159155
# Access RGB-D batches
160-
for batch in datamodule.train_dataloader():
161-
print(batch.image.shape) # RGB images
162-
print(batch.depth_map.shape) # Depth maps
156+
i, data = next(enumerate(datamodule.train_dataloader()))
157+
data["image"].shape
158+
# torch.Size([32, 3, 256, 256])
159+
data["depth_map"].shape
160+
# torch.Size([32, 1, 256, 256])
163161
```
164162

165163
## Creating Custom DataModules
@@ -176,7 +174,6 @@ class CustomDataModule(LightningDataModule):
176174
self,
177175
root: str,
178176
category: str,
179-
image_size: tuple[int, int] = (256, 256),
180177
train_batch_size: int = 32,
181178
**kwargs
182179
):

0 commit comments

Comments (0)