Skip to content

Commit 0af6b6b

Browse files
committed
Implement TorchIO transforms wrapper analogous to TorchVision transforms wrapper and test case
1 parent c86e790 commit 0af6b6b

File tree

3 files changed

+91
-0
lines changed

3 files changed

+91
-0
lines changed

monai/transforms/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -505,6 +505,7 @@
505505
ToDevice,
506506
ToNumpy,
507507
ToPIL,
508+
TorchIO,
508509
TorchVision,
509510
ToTensor,
510511
Transpose,

monai/transforms/utility/array.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,7 @@
9898
"ConvertToMultiChannelBasedOnBratsClasses",
9999
"AddExtremePointsChannel",
100100
"TorchVision",
101+
"TorchIO",
101102
"MapLabelValue",
102103
"IntensityStats",
103104
"ToDevice",
@@ -1163,6 +1164,42 @@ def __call__(self, img: NdarrayOrTensor):
11631164
return out
11641165

11651166

1167+
class TorchIO:
1168+
"""
1169+
This is a wrapper transform for TorchIO transforms based on the specified transform name and args.
1170+
As most of the TorchIO transforms only work for PyTorch Tensor, this transform expects input
1171+
data to be PyTorch Tensor, users can easily call `ToTensor` transform to convert a Numpy array to Tensor.
1172+
1173+
"""
1174+
1175+
backend = [TransformBackends.TORCH]
1176+
1177+
def __init__(self, name: str, *args, **kwargs) -> None:
1178+
"""
1179+
Args:
1180+
name: The transform name in TorchIO package.
1181+
args: parameters for the TorchIO transform.
1182+
kwargs: parameters for the TorchIO transform.
1183+
1184+
"""
1185+
super().__init__()
1186+
self.name = name
1187+
transform, _ = optional_import("torchio.transforms", "0.18.0", min_version, name=name)
1188+
self.trans = transform(*args, **kwargs)
1189+
1190+
def __call__(self, img: NdarrayOrTensor):
1191+
"""
1192+
Args:
1193+
img: PyTorch Tensor data for the TorchIO transform.
1194+
1195+
"""
1196+
img_t, *_ = convert_data_type(img, torch.Tensor)
1197+
1198+
out = self.trans(img_t)
1199+
out, *_ = convert_to_dst_type(src=out, dst=img)
1200+
return out
1201+
1202+
11661203
class MapLabelValue:
11671204
"""
11681205
Utility to map label values to another set of values.

tests/test_torchio.py

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
# Copyright (c) MONAI Consortium
2+
# Licensed under the Apache License, Version 2.0 (the "License");
3+
# you may not use this file except in compliance with the License.
4+
# You may obtain a copy of the License at
5+
# http://www.apache.org/licenses/LICENSE-2.0
6+
# Unless required by applicable law or agreed to in writing, software
7+
# distributed under the License is distributed on an "AS IS" BASIS,
8+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9+
# See the License for the specific language governing permissions and
10+
# limitations under the License.
11+
12+
from __future__ import annotations
13+
14+
import unittest
15+
16+
from parameterized import parameterized
17+
import numpy as np
18+
import torch
19+
20+
from monai.transforms import TorchIO
21+
from monai.utils import set_determinism
22+
23+
TEST_DIMS = [3, 128, 160, 160]
24+
TESTS = [
25+
[{"name": "RescaleIntensity"}, torch.rand(TEST_DIMS)],
26+
[{"name": "ZNormalization"}, torch.rand(TEST_DIMS)],
27+
[{"name": "RandomAffine"}, torch.rand(TEST_DIMS)],
28+
[{"name": "RandomElasticDeformation"}, torch.rand(TEST_DIMS)],
29+
[{"name": "RandomAnisotropy"}, torch.rand(TEST_DIMS)],
30+
[{"name": "RandomMotion"}, torch.rand(TEST_DIMS)],
31+
[{"name": "RandomGhosting"}, torch.rand(TEST_DIMS)],
32+
[{"name": "RandomSpike"}, torch.rand(TEST_DIMS)],
33+
[{"name": "RandomBiasField"}, torch.rand(TEST_DIMS)],
34+
[{"name": "RandomBlur"}, torch.rand(TEST_DIMS)],
35+
[{"name": "RandomNoise"}, torch.rand(TEST_DIMS)],
36+
[{"name": "RandomSwap"}, torch.rand(TEST_DIMS)],
37+
[{"name": "RandomGamma"}, torch.rand(TEST_DIMS)],
38+
]
39+
40+
41+
class TestTorchIO(unittest.TestCase):
42+
43+
@parameterized.expand(TESTS)
44+
def test_value(self, input_param, input_data):
45+
set_determinism(seed=0)
46+
result = TorchIO(**input_param)(input_data)
47+
self.assertIsNotNone(result)
48+
self.assertFalse(np.array_equal(result.numpy(), input_data.numpy()),
49+
f'{input_param} failed')
50+
51+
52+
if __name__ == "__main__":
53+
unittest.main()

0 commit comments

Comments
 (0)