
Commit e270741

4636 4637 backward compatible types (#4638)
Signed-off-by: Wenqi Li <[email protected]>
1 parent 4ddd2bc commit e270741

23 files changed: +152 -99 lines changed

monai/data/meta_tensor.py

Lines changed: 36 additions & 7 deletions
@@ -15,13 +15,15 @@
 from copy import deepcopy
 from typing import Any, Sequence
 
+import numpy as np
 import torch
 
 from monai.config.type_definitions import NdarrayTensor
 from monai.data.meta_obj import MetaObj, get_track_meta
 from monai.data.utils import affine_to_spacing, decollate_batch, list_data_collate, remove_extra_metadata
+from monai.utils import look_up_option
 from monai.utils.enums import PostFix
-from monai.utils.type_conversion import convert_to_tensor
+from monai.utils.type_conversion import convert_data_type, convert_to_tensor
 
 __all__ = ["MetaTensor"]

@@ -307,6 +309,33 @@ def as_dict(self, key: str) -> dict:
             PostFix.transforms(key): deepcopy(self.applied_operations),
         }
 
+    def astype(self, dtype, device=None, *unused_args, **unused_kwargs):
+        """
+        Cast to ``dtype``, sharing data whenever possible.
+
+        Args:
+            dtype: dtypes such as np.float32, torch.float, "np.float32", float.
+            device: the device if `dtype` is a torch data type.
+            unused_args: additional args (currently unused).
+            unused_kwargs: additional kwargs (currently unused).
+
+        Returns:
+            data array instance
+        """
+        if isinstance(dtype, str):
+            mod_str, *dtype = dtype.split(".", 1)
+            dtype = mod_str if not dtype else dtype[0]
+        else:
+            mod_str = getattr(dtype, "__module__", "torch")
+        mod_str = look_up_option(mod_str, {"torch", "numpy", "np"}, default="numpy")
+        if mod_str == "torch":
+            out_type = torch.Tensor
+        elif mod_str in ("numpy", "np"):
+            out_type = np.ndarray
+        else:
+            out_type = None
+        return convert_data_type(self, output_type=out_type, device=device, dtype=dtype, wrap_sequence=True)[0]
+
     @property
     def affine(self) -> torch.Tensor:
         """Get the affine."""
@@ -334,7 +363,7 @@ def new_empty(self, size, dtype=None, device=None, requires_grad=False):
         )
 
     @staticmethod
-    def ensure_torch_and_prune_meta(im: NdarrayTensor, meta: dict):
+    def ensure_torch_and_prune_meta(im: NdarrayTensor, meta: dict, simple_keys: bool = False):
         """
         Convert the image to `torch.Tensor`. If `affine` is in the `meta` dictionary,
         convert that to `torch.Tensor`, too. Remove any superfluous metadata.
@@ -353,12 +382,12 @@ def ensure_torch_and_prune_meta(im: NdarrayTensor, meta: dict):
         if not get_track_meta() or meta is None:
             return img
 
-        # ensure affine is of type `torch.Tensor`
-        if "affine" in meta:
-            meta["affine"] = convert_to_tensor(meta["affine"])
-
         # remove any superfluous metadata.
-        remove_extra_metadata(meta)
+        if simple_keys:
+            # ensure affine is of type `torch.Tensor`
+            if "affine" in meta:
+                meta["affine"] = convert_to_tensor(meta["affine"])  # bc-breaking
+            remove_extra_metadata(meta)  # bc-breaking
 
         # return the `MetaTensor`
         return MetaTensor(img, meta=meta)
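The `simple_keys` flag keeps the metadata pruning and affine conversion opt-in. A hedged sketch of the difference follows; the extra metadata key is hypothetical and the exact effect of `remove_extra_metadata` depends on the keys present, so only the calls themselves are taken from the diff.

```python
import numpy as np

from monai.data import MetaTensor

meta = {"affine": np.eye(4), "original_channel_dim": "no_channel"}  # hypothetical metadata

# default (simple_keys=False): the metadata dict is passed through untouched
compat = MetaTensor.ensure_torch_and_prune_meta(np.zeros((8, 8)), dict(meta))

# simple_keys=True opts in to the newer behaviour: the affine is converted to a
# torch.Tensor and remove_extra_metadata() prunes keys (marked "bc-breaking" in the diff)
simple = MetaTensor.ensure_torch_and_prune_meta(np.zeros((8, 8)), dict(meta), simple_keys=True)

print(type(compat), type(simple))  # both are MetaTensor instances when meta tracking is on
```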

monai/transforms/io/array.py

Lines changed: 5 additions & 2 deletions
@@ -108,9 +108,10 @@ class LoadImage(Transform):
     def __init__(
         self,
         reader=None,
-        image_only: bool = True,
+        image_only: bool = False,
         dtype: DtypeLike = np.float32,
         ensure_channel_first: bool = False,
+        simple_keys: bool = False,
         *args,
         **kwargs,
     ) -> None:
@@ -127,6 +128,7 @@ def __init__(
             dtype: if not None convert the loaded image to this data type.
             ensure_channel_first: if `True` and loaded both image array and metadata, automatically convert
                 the image array shape to `channel first`. default to `False`.
+            simple_keys: whether to remove redundant metadata keys, default to False for backward compatibility.
             args: additional parameters for reader if providing a reader name.
             kwargs: additional parameters for reader if providing a reader name.
 
@@ -145,6 +147,7 @@ def __init__(
         self.image_only = image_only
         self.dtype = dtype
         self.ensure_channel_first = ensure_channel_first
+        self.simple_keys = simple_keys
 
         self.readers: List[ImageReader] = []
         for r in SUPPORTED_READERS:  # set predefined readers as default
@@ -255,7 +258,7 @@ def __call__(self, filename: Union[Sequence[PathLike], PathLike], reader: Option
         meta_data = switch_endianness(meta_data, "<")
 
         meta_data[Key.FILENAME_OR_OBJ] = f"{ensure_tuple(filename)[0]}"  # Path obj should be strings for data loader
-        img = MetaTensor.ensure_torch_and_prune_meta(img_array, meta_data)
+        img = MetaTensor.ensure_torch_and_prune_meta(img_array, meta_data, self.simple_keys)
         if self.ensure_channel_first:
             img = EnsureChannelFirst()(img)
         if self.image_only:
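With `image_only` defaulting to `False` again and `simple_keys` defaulting to `False`, existing pipelines keep receiving an `(image, metadata)` pair unless they opt in. A small sketch follows; the file path is hypothetical, so the loading calls are left commented out.

```python
from monai.transforms import LoadImage

# default: backward-compatible two-value return
loader = LoadImage()  # image_only=False, simple_keys=False
# img, meta = loader("spleen_01.nii.gz")  # hypothetical path

# opt-in: single MetaTensor output with metadata pruning enabled
meta_loader = LoadImage(image_only=True, simple_keys=True)
# img = meta_loader("spleen_01.nii.gz")   # hypothetical path
```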

monai/transforms/io/dictionary.py

Lines changed: 4 additions & 2 deletions
@@ -72,8 +72,9 @@ def __init__(
         meta_keys: Optional[KeysCollection] = None,
         meta_key_postfix: str = DEFAULT_POST_FIX,
         overwriting: bool = False,
-        image_only: bool = True,
+        image_only: bool = False,
         ensure_channel_first: bool = False,
+        simple_keys: bool = False,
         allow_missing_keys: bool = False,
         *args,
         **kwargs,
@@ -103,12 +104,13 @@ def __init__(
                 dictionary containing image data array and header dict per input key.
             ensure_channel_first: if `True` and loaded both image array and metadata, automatically convert
                 the image array shape to `channel first`. default to `False`.
+            simple_keys: whether to remove redundant metadata keys, default to False for backward compatibility.
             allow_missing_keys: don't raise exception if key is missing.
             args: additional parameters for reader if providing a reader name.
             kwargs: additional parameters for reader if providing a reader name.
         """
         super().__init__(keys, allow_missing_keys)
-        self._loader = LoadImage(reader, image_only, dtype, ensure_channel_first, *args, **kwargs)
+        self._loader = LoadImage(reader, image_only, dtype, ensure_channel_first, simple_keys, *args, **kwargs)
         if not isinstance(meta_key_postfix, str):
             raise TypeError(f"meta_key_postfix must be a str but is {type(meta_key_postfix).__name__}.")
         self.meta_keys = ensure_tuple_rep(None, len(self.keys)) if meta_keys is None else ensure_tuple(meta_keys)
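The dictionary wrapper simply forwards the new flag positionally to `LoadImage` (note the argument order in the constructor above). A hedged usage sketch with a hypothetical input dictionary; the `<key>_meta_dict` entry is the usual behaviour when `image_only=False`, not something this diff changes.

```python
from monai.transforms import LoadImaged

# image_only=False (the restored default) keeps the separate "<key>_meta_dict" entry
load = LoadImaged(keys="image", simple_keys=False)
# sample = load({"image": "spleen_01.nii.gz"})   # hypothetical path
# sample["image"], sample["image_meta_dict"]     # image array plus metadata dict
```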

monai/transforms/utility/array.py

Lines changed: 13 additions & 2 deletions
@@ -435,6 +435,8 @@ class EnsureType(Transform):
         device: for Tensor data type, specify the target device.
         wrap_sequence: if `False`, then lists will recursively call this function, default to `True`.
             E.g., if `False`, `[1, 2]` -> `[tensor(1), tensor(2)]`, if `True`, then `[1, 2]` -> `tensor([1, 2])`.
+        track_meta: whether to convert to `MetaTensor` when `data_type` is "tensor".
+            If False, the output data type will be `torch.Tensor`. Default to the return value of ``get_track_meta``.
 
     """
 
@@ -446,11 +448,13 @@ def __init__(
         dtype: Optional[Union[DtypeLike, torch.dtype]] = None,
         device: Optional[torch.device] = None,
         wrap_sequence: bool = True,
+        track_meta: Optional[bool] = None,
     ) -> None:
         self.data_type = look_up_option(data_type.lower(), {"tensor", "numpy"})
         self.dtype = dtype
         self.device = device
         self.wrap_sequence = wrap_sequence
+        self.track_meta = get_track_meta() if track_meta is None else bool(track_meta)
 
     def __call__(self, data: NdarrayOrTensor):
         """
@@ -461,10 +465,17 @@ def __call__(self, data: NdarrayOrTensor):
             if applicable and `wrap_sequence=False`.
 
         """
-        output_type = torch.Tensor if self.data_type == "tensor" else np.ndarray
+        if self.data_type == "tensor":
+            output_type = MetaTensor if self.track_meta else torch.Tensor
+        else:
+            output_type = np.ndarray  # type: ignore
         out: NdarrayOrTensor
         out, *_ = convert_data_type(
-            data=data, output_type=output_type, dtype=self.dtype, device=self.device, wrap_sequence=self.wrap_sequence
+            data=data,
+            output_type=output_type,  # type: ignore
+            dtype=self.dtype,
+            device=self.device,
+            wrap_sequence=self.wrap_sequence,
         )
         return out
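To show what the new `track_meta` switch controls, a minimal sketch follows (the array values are illustrative; leaving `track_meta=None` defers to the global `get_track_meta()` setting, as in the diff and the updated tests below).

```python
import numpy as np
import torch

from monai.data import MetaTensor
from monai.transforms import EnsureType

data = np.array([[1.0, 2.0], [3.0, 4.0]])

as_meta = EnsureType(data_type="tensor", track_meta=True)(data)    # -> MetaTensor
as_plain = EnsureType(data_type="tensor", track_meta=False)(data)  # -> plain torch.Tensor

assert isinstance(as_meta, MetaTensor)
assert isinstance(as_plain, torch.Tensor)
```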

monai/transforms/utility/dictionary.py

Lines changed: 8 additions & 14 deletions
@@ -62,7 +62,7 @@
 )
 from monai.transforms.utils import extreme_points_to_image, get_extreme_points
 from monai.transforms.utils_pytorch_numpy_unification import concatenate
-from monai.utils import convert_to_numpy, deprecated, deprecated_arg, ensure_tuple, ensure_tuple_rep
+from monai.utils import deprecated, deprecated_arg, ensure_tuple, ensure_tuple_rep
 from monai.utils.enums import PostFix, TraceKeys, TransformBackends
 from monai.utils.type_conversion import convert_to_dst_type
 
@@ -519,7 +519,7 @@ def inverse(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, Nd
         return d
 
 
-class EnsureTyped(MapTransform, InvertibleTransform):
+class EnsureTyped(MapTransform):
     """
     Dictionary-based wrapper of :py:class:`monai.transforms.EnsureType`.
 
@@ -541,6 +541,7 @@ def __init__(
         dtype: Union[DtypeLike, torch.dtype] = None,
         device: Optional[torch.device] = None,
         wrap_sequence: bool = True,
+        track_meta: Optional[bool] = None,
         allow_missing_keys: bool = False,
     ) -> None:
         """
@@ -552,28 +553,21 @@ def __init__(
             device: for Tensor data type, specify the target device.
             wrap_sequence: if `False`, then lists will recursively call this function, default to `True`.
                 E.g., if `False`, `[1, 2]` -> `[tensor(1), tensor(2)]`, if `True`, then `[1, 2]` -> `tensor([1, 2])`.
+            track_meta: whether to convert to `MetaTensor` when `data_type` is "tensor".
+                If False, the output data type will be `torch.Tensor`. Default to the return value of `get_track_meta`.
             allow_missing_keys: don't raise exception if key is missing.
         """
         super().__init__(keys, allow_missing_keys)
-        self.converter = EnsureType(data_type=data_type, dtype=dtype, device=device, wrap_sequence=wrap_sequence)
+        self.converter = EnsureType(
+            data_type=data_type, dtype=dtype, device=device, wrap_sequence=wrap_sequence, track_meta=track_meta
+        )
 
     def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]:
         d = dict(data)
         for key in self.key_iterator(d):
-            self.push_transform(d, key)
             d[key] = self.converter(d[key])
         return d
 
-    def inverse(self, data: Mapping[Hashable, Any]) -> Dict[Hashable, Any]:
-        d = deepcopy(dict(data))
-        for key in self.key_iterator(d):
-            # FIXME: currently, only convert tensor data to numpy array or scalar number,
-            # need to also invert numpy array but it's not easy to determine the previous data type
-            d[key] = convert_to_numpy(d[key])
-            # Remove the applied transform
-            self.pop_transform(d, key)
-        return d
-
 
 class ToNumpyd(MapTransform):
     """
monai/utils/type_conversion.py

Lines changed: 5 additions & 4 deletions
@@ -172,7 +172,7 @@ def convert_to_numpy(data, dtype: DtypeLike = None, wrap_sequence: bool = False)
             E.g., `[1, 2]` -> `[array(1), array(2)]`. If `True`, then `[1, 2]` -> `array([1, 2])`.
     """
     if isinstance(data, torch.Tensor):
-        data = data.detach().to(dtype=get_equivalent_dtype(dtype, torch.Tensor), device="cpu").numpy()
+        data = np.asarray(data.detach().to(device="cpu").numpy(), dtype=get_equivalent_dtype(dtype, np.ndarray))
     elif has_cp and isinstance(data, cp_ndarray):
         data = cp.asnumpy(data).astype(dtype, copy=False)
     elif isinstance(data, (np.ndarray, float, int, bool)):
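The `convert_to_numpy` change moves the dtype cast from the torch side to the numpy side (`np.asarray` after `.numpy()` instead of a torch `.to(dtype=...)`). A minimal sketch of the call this affects:

```python
import numpy as np
import torch

from monai.utils.type_conversion import convert_to_numpy

# the requested dtype is now applied via np.asarray after .numpy(), not via a torch cast
arr = convert_to_numpy(torch.ones(3, dtype=torch.float64), dtype=np.float32)
assert isinstance(arr, np.ndarray) and arr.dtype == np.float32
```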
@@ -235,12 +235,13 @@ def convert_data_type(
     wrap_sequence: bool = False,
 ) -> Tuple[NdarrayTensor, type, Optional[torch.device]]:
     """
-    Convert to `torch.Tensor`/`np.ndarray` from `torch.Tensor`/`np.ndarray`/`float`/`int` etc.
+    Convert to `MetaTensor`, `torch.Tensor` or `np.ndarray` from `MetaTensor`, `torch.Tensor`,
+    `np.ndarray`, `float`, `int`, etc.
 
     Args:
         data: data to be converted
-        output_type: `torch.Tensor` or `np.ndarray` (if `None`, unchanged)
-        device: if output is `torch.Tensor`, select device (if `None`, unchanged)
+        output_type: `monai.data.MetaTensor`, `torch.Tensor`, or `np.ndarray` (if `None`, unchanged)
+        device: if output is `MetaTensor` or `torch.Tensor`, select device (if `None`, unchanged)
         dtype: dtype of output data. Converted to correct library type (e.g.,
             `np.float32` is converted to `torch.float32` if output type is `torch.Tensor`).
             If left blank, it remains unchanged.
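The updated docstring reflects that `output_type` may now be `monai.data.MetaTensor`. A hedged sketch of that call, mirroring how the `EnsureType` change above uses it:

```python
import numpy as np
import torch

from monai.data import MetaTensor
from monai.utils.type_conversion import convert_data_type

converted, orig_type, orig_device = convert_data_type(
    np.ones((2, 2)), output_type=MetaTensor, dtype=torch.float32
)
assert isinstance(converted, MetaTensor)
assert orig_type is np.ndarray  # the original type is reported back for round-tripping
```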

tests/test_arraydataset.py

Lines changed: 7 additions & 7 deletions
@@ -23,15 +23,15 @@
 from monai.transforms import AddChannel, Compose, LoadImage, RandAdjustContrast, RandGaussianNoise, Spacing
 
 TEST_CASE_1 = [
-    Compose([LoadImage(), AddChannel(), RandGaussianNoise(prob=1.0)]),
-    Compose([LoadImage(), AddChannel(), RandGaussianNoise(prob=1.0)]),
+    Compose([LoadImage(image_only=True), AddChannel(), RandGaussianNoise(prob=1.0)]),
+    Compose([LoadImage(image_only=True), AddChannel(), RandGaussianNoise(prob=1.0)]),
     (0, 1),
     (1, 128, 128, 128),
 ]
 
 TEST_CASE_2 = [
-    Compose([LoadImage(), AddChannel(), RandAdjustContrast(prob=1.0)]),
-    Compose([LoadImage(), AddChannel(), RandAdjustContrast(prob=1.0)]),
+    Compose([LoadImage(image_only=True), AddChannel(), RandAdjustContrast(prob=1.0)]),
+    Compose([LoadImage(image_only=True), AddChannel(), RandAdjustContrast(prob=1.0)]),
     (0, 1),
     (1, 128, 128, 128),
 ]
@@ -48,13 +48,13 @@ def __call__(self, input_):
 
 
 TEST_CASE_3 = [
-    TestCompose([LoadImage(), AddChannel(), Spacing(pixdim=(2, 2, 4)), RandAdjustContrast(prob=1.0)]),
-    TestCompose([LoadImage(), AddChannel(), Spacing(pixdim=(2, 2, 4)), RandAdjustContrast(prob=1.0)]),
+    TestCompose([LoadImage(image_only=True), AddChannel(), Spacing(pixdim=(2, 2, 4)), RandAdjustContrast(prob=1.0)]),
+    TestCompose([LoadImage(image_only=True), AddChannel(), Spacing(pixdim=(2, 2, 4)), RandAdjustContrast(prob=1.0)]),
     (0, 2),
     (1, 64, 64, 33),
 ]
 
-TEST_CASE_4 = [Compose([LoadImage(), AddChannel(), RandGaussianNoise(prob=1.0)]), (1, 128, 128, 128)]
+TEST_CASE_4 = [Compose([LoadImage(image_only=True), AddChannel(), RandGaussianNoise(prob=1.0)]), (1, 128, 128, 128)]
 
 
 class TestArrayDataset(unittest.TestCase):

tests/test_decollate.py

Lines changed: 3 additions & 3 deletions
@@ -131,7 +131,7 @@ def test_decollation_dict(self, *transforms):
         t_compose = Compose([AddChanneld(KEYS), Compose(transforms), ToTensord(KEYS)])
         # If nibabel present, read from disk
         if has_nib:
-            t_compose = Compose([LoadImaged("image"), t_compose])
+            t_compose = Compose([LoadImaged("image", image_only=True), t_compose])
 
         dataset = CacheDataset(self.data_dict, t_compose, progress=False)
         self.check_decollate(dataset=dataset)
@@ -141,7 +141,7 @@ def test_decollation_tensor(self, *transforms):
         t_compose = Compose([AddChannel(), Compose(transforms), ToTensor()])
         # If nibabel present, read from disk
         if has_nib:
-            t_compose = Compose([LoadImage(), t_compose])
+            t_compose = Compose([LoadImage(image_only=True), t_compose])
 
         dataset = Dataset(self.data_list, t_compose)
         self.check_decollate(dataset=dataset)
@@ -151,7 +151,7 @@ def test_decollation_list(self, *transforms):
         t_compose = Compose([AddChannel(), Compose(transforms), ToTensor()])
         # If nibabel present, read from disk
         if has_nib:
-            t_compose = Compose([LoadImage(), t_compose])
+            t_compose = Compose([LoadImage(image_only=True), t_compose])
 
         dataset = Dataset(self.data_list, t_compose)
         self.check_decollate(dataset=dataset)

tests/test_ensure_channel_first.py

Lines changed: 3 additions & 3 deletions
@@ -52,13 +52,13 @@ def test_load_nifti(self, input_param, filenames, original_channel_dim):
                 filenames[i] = os.path.join(tempdir, name)
                 nib.save(nib.Nifti1Image(test_image, np.eye(4)), filenames[i])
 
-            result = LoadImage(**input_param)(filenames)
+            result = LoadImage(image_only=True, **input_param)(filenames)
             result = EnsureChannelFirst()(result)
             self.assertEqual(result.shape[0], len(filenames))
 
     @parameterized.expand([TEST_CASE_7])
     def test_itk_dicom_series_reader(self, input_param, filenames, _):
-        result = LoadImage(**input_param)(filenames)
+        result = LoadImage(image_only=True, **input_param)(filenames)
         result = EnsureChannelFirst()(result)
         self.assertEqual(result.shape[0], 1)
 
@@ -68,7 +68,7 @@ def test_load_png(self):
         with tempfile.TemporaryDirectory() as tempdir:
             filename = os.path.join(tempdir, "test_image.png")
             Image.fromarray(test_image.astype("uint8")).save(filename)
-            result = LoadImage()(filename)
+            result = LoadImage(image_only=True)(filename)
             result = EnsureChannelFirst()(result)
             self.assertEqual(result.shape[0], 3)

tests/test_ensure_type.py

Lines changed: 4 additions & 3 deletions
@@ -14,6 +14,7 @@
 import numpy as np
 import torch
 
+from monai.data import MetaTensor
 from monai.transforms import EnsureType
 from tests.utils import assert_allclose
 
@@ -59,9 +60,9 @@ def test_string(self):
 
     def test_list_tuple(self):
         for dtype in ("tensor", "numpy"):
-            result = EnsureType(data_type=dtype, wrap_sequence=False)([[1, 2], [3, 4]])
+            result = EnsureType(data_type=dtype, wrap_sequence=False, track_meta=True)([[1, 2], [3, 4]])
             self.assertTrue(isinstance(result, list))
-            self.assertTrue(isinstance(result[0][1], torch.Tensor if dtype == "tensor" else np.ndarray))
+            self.assertTrue(isinstance(result[0][1], MetaTensor if dtype == "tensor" else np.ndarray))
             torch.testing.assert_allclose(result[1][0], torch.as_tensor(3))
             # tuple of numpy arrays
             result = EnsureType(data_type=dtype, wrap_sequence=False)((np.array([1, 2]), np.array([3, 4])))
@@ -77,7 +78,7 @@ def test_dict(self):
             "extra": None,
         }
         for dtype in ("tensor", "numpy"):
-            result = EnsureType(data_type=dtype)(test_data)
+            result = EnsureType(data_type=dtype, track_meta=False)(test_data)
             self.assertTrue(isinstance(result, dict))
             self.assertTrue(isinstance(result["img"], torch.Tensor if dtype == "tensor" else np.ndarray))
             torch.testing.assert_allclose(result["img"], torch.as_tensor([1.0, 2.0]))
