Skip to content

Commit 44e249d

Browse files
SomeUserName1KumoLiuericspod
authored
Implement TorchIO transforms wrapper analogous to TorchVision transfo… (#7579)
…rms wrapper and test case Fixes #7499 . ### Description As discussed in the issue, this PR implements a wrapper class for TorchIO transforms, analogous to the TorchVision transforms wrapper. The test cases just check that transforms are callable and that after applying a transform, the result is different from the inputs. ### Types of changes <!--- Put an `x` in all the boxes that apply, and remove the not applicable items --> - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [x] New tests added to cover the changes. - [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [x] In-line docstrings updated. - [x] Documentation updated, tested `make html` command in the `docs/` folder. --------- Signed-off-by: Fabian Klopfer <[email protected]> Signed-off-by: Fabian Klopfer <[email protected]> Co-authored-by: YunLiu <[email protected]> Co-authored-by: Fabian Klopfer <[email protected]> Co-authored-by: Eric Kerfoot <[email protected]>
1 parent 20372f0 commit 44e249d

12 files changed

+397
-6
lines changed

.gitignore

+3
Original file line numberDiff line numberDiff line change
@@ -149,6 +149,9 @@ tests/testing_data/nrrd_example.nrrd
149149
# clang format tool
150150
.clang-format-bin/
151151

152+
# ctags
153+
tags
154+
152155
# VSCode
153156
.vscode/
154157
*.zip

docs/source/transforms.rst

+24
Original file line numberDiff line numberDiff line change
@@ -1180,6 +1180,18 @@ Utility
11801180
:members:
11811181
:special-members: __call__
11821182

1183+
`TorchIO`
1184+
"""""""""
1185+
.. autoclass:: TorchIO
1186+
:members:
1187+
:special-members: __call__
1188+
1189+
`RandTorchIO`
1190+
"""""""""""""
1191+
.. autoclass:: RandTorchIO
1192+
:members:
1193+
:special-members: __call__
1194+
11831195
`MapLabelValue`
11841196
"""""""""""""""
11851197
.. autoclass:: MapLabelValue
@@ -2253,6 +2265,18 @@ Utility (Dict)
22532265
:members:
22542266
:special-members: __call__
22552267

2268+
`TorchIOd`
2269+
""""""""""
2270+
.. autoclass:: TorchIOd
2271+
:members:
2272+
:special-members: __call__
2273+
2274+
`RandTorchIOd`
2275+
""""""""""""""
2276+
.. autoclass:: RandTorchIOd
2277+
:members:
2278+
:special-members: __call__
2279+
22562280
`MapLabelValued`
22572281
""""""""""""""""
22582282
.. autoclass:: MapLabelValued

environment-dev.yml

+1
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ channels:
77
dependencies:
88
- numpy>=1.24,<2.0
99
- pytorch>=1.9
10+
- torchio
1011
- torchvision
1112
- pytorch-cuda>=11.6
1213
- pip

monai/transforms/__init__.py

+9
Original file line numberDiff line numberDiff line change
@@ -531,6 +531,8 @@
531531
RandIdentity,
532532
RandImageFilter,
533533
RandLambda,
534+
RandTorchIO,
535+
RandTorchVision,
534536
RemoveRepeatedChannel,
535537
RepeatChannel,
536538
SimulateDelay,
@@ -540,6 +542,7 @@
540542
ToDevice,
541543
ToNumpy,
542544
ToPIL,
545+
TorchIO,
543546
TorchVision,
544547
ToTensor,
545548
Transpose,
@@ -620,6 +623,9 @@
620623
RandLambdad,
621624
RandLambdaD,
622625
RandLambdaDict,
626+
RandTorchIOd,
627+
RandTorchIOD,
628+
RandTorchIODict,
623629
RandTorchVisiond,
624630
RandTorchVisionD,
625631
RandTorchVisionDict,
@@ -653,6 +659,9 @@
653659
ToPILd,
654660
ToPILD,
655661
ToPILDict,
662+
TorchIOd,
663+
TorchIOD,
664+
TorchIODict,
656665
TorchVisiond,
657666
TorchVisionD,
658667
TorchVisionDict,

monai/transforms/utility/array.py

+103-6
Original file line numberDiff line numberDiff line change
@@ -18,10 +18,10 @@
1818
import sys
1919
import time
2020
import warnings
21-
from collections.abc import Mapping, Sequence
21+
from collections.abc import Hashable, Mapping, Sequence
2222
from copy import deepcopy
2323
from functools import partial
24-
from typing import Any, Callable
24+
from typing import Any, Callable, Union
2525

2626
import numpy as np
2727
import torch
@@ -99,11 +99,14 @@
9999
"ConvertToMultiChannelBasedOnBratsClasses",
100100
"AddExtremePointsChannel",
101101
"TorchVision",
102+
"TorchIO",
102103
"MapLabelValue",
103104
"IntensityStats",
104105
"ToDevice",
105106
"CuCIM",
106107
"RandCuCIM",
108+
"RandTorchIO",
109+
"RandTorchVision",
107110
"ToCupy",
108111
"ImageFilter",
109112
"RandImageFilter",
@@ -1136,12 +1139,44 @@ def __call__(
11361139
return concatenate((img, points_image), axis=0)
11371140

11381141

1139-
class TorchVision:
1142+
class TorchVision(Transform):
11401143
"""
1141-
This is a wrapper transform for PyTorch TorchVision transform based on the specified transform name and args.
1142-
As most of the TorchVision transforms only work for PIL image and PyTorch Tensor, this transform expects input
1143-
data to be PyTorch Tensor, users can easily call `ToTensor` transform to convert a Numpy array to Tensor.
1144+
This is a wrapper transform for PyTorch TorchVision non-randomized transform based on the specified transform name and args.
1145+
Data is converted to a torch.tensor before applying the transform and then converted back to the original data type.
1146+
"""
1147+
1148+
backend = [TransformBackends.TORCH]
1149+
1150+
def __init__(self, name: str, *args, **kwargs) -> None:
1151+
"""
1152+
Args:
1153+
name: The transform name in TorchVision package.
1154+
args: parameters for the TorchVision transform.
1155+
kwargs: parameters for the TorchVision transform.
1156+
1157+
"""
1158+
super().__init__()
1159+
self.name = name
1160+
transform, _ = optional_import("torchvision.transforms", "0.8.0", min_version, name=name)
1161+
self.trans = transform(*args, **kwargs)
1162+
1163+
def __call__(self, img: NdarrayOrTensor):
1164+
"""
1165+
Args:
1166+
img: PyTorch Tensor data for the TorchVision transform.
11441167
1168+
"""
1169+
img_t, *_ = convert_data_type(img, torch.Tensor)
1170+
1171+
out = self.trans(img_t)
1172+
out, *_ = convert_to_dst_type(src=out, dst=img)
1173+
return out
1174+
1175+
1176+
class RandTorchVision(Transform, RandomizableTrait):
1177+
"""
1178+
This is a wrapper transform for PyTorch TorchVision randomized transform based on the specified transform name and args.
1179+
Data is converted to a torch.tensor before applying the transform and then converted back to the original data type.
11451180
"""
11461181

11471182
backend = [TransformBackends.TORCH]
@@ -1172,6 +1207,68 @@ def __call__(self, img: NdarrayOrTensor):
11721207
return out
11731208

11741209

1210+
class TorchIO(Transform):
1211+
"""
1212+
This is a wrapper for TorchIO non-randomized transforms based on the specified transform name and args.
1213+
See https://torchio.readthedocs.io/transforms/transforms.html for more details.
1214+
"""
1215+
1216+
backend = [TransformBackends.TORCH]
1217+
1218+
def __init__(self, name: str, *args, **kwargs) -> None:
1219+
"""
1220+
Args:
1221+
name: The transform name in TorchIO package.
1222+
args: parameters for the TorchIO transform.
1223+
kwargs: parameters for the TorchIO transform.
1224+
"""
1225+
super().__init__()
1226+
self.name = name
1227+
transform, _ = optional_import("torchio.transforms", "0.18.0", min_version, name=name)
1228+
self.trans = transform(*args, **kwargs)
1229+
1230+
def __call__(self, img: Union[NdarrayOrTensor, Mapping[Hashable, NdarrayOrTensor]]):
1231+
"""
1232+
Args:
1233+
img: an instance of torchio.Subject, torchio.Image, numpy.ndarray, torch.Tensor, SimpleITK.Image,
1234+
or dict containing 4D tensors as values
1235+
1236+
"""
1237+
return self.trans(img)
1238+
1239+
1240+
class RandTorchIO(Transform, RandomizableTrait):
1241+
"""
1242+
This is a wrapper for TorchIO randomized transforms based on the specified transform name and args.
1243+
See https://torchio.readthedocs.io/transforms/transforms.html for more details.
1244+
Use this wrapper for all TorchIO transform inheriting from RandomTransform:
1245+
https://torchio.readthedocs.io/transforms/augmentation.html#randomtransform
1246+
"""
1247+
1248+
backend = [TransformBackends.TORCH]
1249+
1250+
def __init__(self, name: str, *args, **kwargs) -> None:
1251+
"""
1252+
Args:
1253+
name: The transform name in TorchIO package.
1254+
args: parameters for the TorchIO transform.
1255+
kwargs: parameters for the TorchIO transform.
1256+
"""
1257+
super().__init__()
1258+
self.name = name
1259+
transform, _ = optional_import("torchio.transforms", "0.18.0", min_version, name=name)
1260+
self.trans = transform(*args, **kwargs)
1261+
1262+
def __call__(self, img: Union[NdarrayOrTensor, Mapping[Hashable, NdarrayOrTensor]]):
1263+
"""
1264+
Args:
1265+
img: an instance of torchio.Subject, torchio.Image, numpy.ndarray, torch.Tensor, SimpleITK.Image,
1266+
or dict containing 4D tensors as values
1267+
1268+
"""
1269+
return self.trans(img)
1270+
1271+
11751272
class MapLabelValue:
11761273
"""
11771274
Utility to map label values to another set of values.

monai/transforms/utility/dictionary.py

+67
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,7 @@
6060
ToDevice,
6161
ToNumpy,
6262
ToPIL,
63+
TorchIO,
6364
TorchVision,
6465
ToTensor,
6566
Transpose,
@@ -136,6 +137,9 @@
136137
"RandLambdaD",
137138
"RandLambdaDict",
138139
"RandLambdad",
140+
"RandTorchIOd",
141+
"RandTorchIOD",
142+
"RandTorchIODict",
139143
"RandTorchVisionD",
140144
"RandTorchVisionDict",
141145
"RandTorchVisiond",
@@ -172,6 +176,9 @@
172176
"ToTensorD",
173177
"ToTensorDict",
174178
"ToTensord",
179+
"TorchIOD",
180+
"TorchIODict",
181+
"TorchIOd",
175182
"TorchVisionD",
176183
"TorchVisionDict",
177184
"TorchVisiond",
@@ -1445,6 +1452,64 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, N
14451452
return d
14461453

14471454

1455+
class TorchIOd(MapTransform):
1456+
"""
1457+
Dictionary-based wrapper of :py:class:`monai.transforms.TorchIO` for non-randomized transforms.
1458+
For randomized transforms of TorchIO use :py:class:`monai.transforms.RandTorchIOd`.
1459+
"""
1460+
1461+
backend = TorchIO.backend
1462+
1463+
def __init__(self, keys: KeysCollection, name: str, allow_missing_keys: bool = False, *args, **kwargs) -> None:
1464+
"""
1465+
Args:
1466+
keys: keys of the corresponding items to be transformed.
1467+
See also: :py:class:`monai.transforms.compose.MapTransform`
1468+
name: The transform name in TorchIO package.
1469+
allow_missing_keys: don't raise exception if key is missing.
1470+
args: parameters for the TorchIO transform.
1471+
kwargs: parameters for the TorchIO transform.
1472+
1473+
"""
1474+
super().__init__(keys, allow_missing_keys)
1475+
self.name = name
1476+
kwargs["include"] = self.keys
1477+
1478+
self.trans = TorchIO(name, *args, **kwargs)
1479+
1480+
def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Mapping[Hashable, NdarrayOrTensor]:
1481+
return dict(self.trans(data))
1482+
1483+
1484+
class RandTorchIOd(MapTransform, RandomizableTrait):
1485+
"""
1486+
Dictionary-based wrapper of :py:class:`monai.transforms.TorchIO` for randomized transforms.
1487+
For non-randomized transforms of TorchIO use :py:class:`monai.transforms.TorchIOd`.
1488+
"""
1489+
1490+
backend = TorchIO.backend
1491+
1492+
def __init__(self, keys: KeysCollection, name: str, allow_missing_keys: bool = False, *args, **kwargs) -> None:
1493+
"""
1494+
Args:
1495+
keys: keys of the corresponding items to be transformed.
1496+
See also: :py:class:`monai.transforms.compose.MapTransform`
1497+
name: The transform name in TorchIO package.
1498+
allow_missing_keys: don't raise exception if key is missing.
1499+
args: parameters for the TorchIO transform.
1500+
kwargs: parameters for the TorchIO transform.
1501+
1502+
"""
1503+
super().__init__(keys, allow_missing_keys)
1504+
self.name = name
1505+
kwargs["include"] = self.keys
1506+
1507+
self.trans = TorchIO(name, *args, **kwargs)
1508+
1509+
def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Mapping[Hashable, NdarrayOrTensor]:
1510+
return dict(self.trans(data))
1511+
1512+
14481513
class MapLabelValued(MapTransform):
14491514
"""
14501515
Dictionary-based wrapper of :py:class:`monai.transforms.MapLabelValue`.
@@ -1871,8 +1936,10 @@ def inverse(self, data: Mapping[Hashable, torch.Tensor]) -> dict[Hashable, torch
18711936
ConvertToMultiChannelBasedOnBratsClassesd
18721937
)
18731938
AddExtremePointsChannelD = AddExtremePointsChannelDict = AddExtremePointsChanneld
1939+
TorchIOD = TorchIODict = TorchIOd
18741940
TorchVisionD = TorchVisionDict = TorchVisiond
18751941
RandTorchVisionD = RandTorchVisionDict = RandTorchVisiond
1942+
RandTorchIOD = RandTorchIODict = RandTorchIOd
18761943
RandLambdaD = RandLambdaDict = RandLambdad
18771944
MapLabelValueD = MapLabelValueDict = MapLabelValued
18781945
IntensityStatsD = IntensityStatsDict = IntensityStatsd

requirements-dev.txt

+1
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ pytype>=2020.6.1; platform_system != "Windows"
2424
types-setuptools
2525
mypy>=1.5.0, <1.12.0
2626
ninja
27+
torchio
2728
torchvision
2829
psutil
2930
cucim-cu12; platform_system == "Linux" and python_version >= "3.9" and python_version <= "3.10"

setup.cfg

+3
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@ all =
5555
tensorboard
5656
gdown>=4.7.3
5757
pytorch-ignite==0.4.11
58+
torchio
5859
torchvision
5960
itk>=5.2
6061
tqdm>=4.47.0
@@ -102,6 +103,8 @@ gdown =
102103
gdown>=4.7.3
103104
ignite =
104105
pytorch-ignite==0.4.11
106+
torchio =
107+
torchio
105108
torchvision =
106109
torchvision
107110
itk =

0 commit comments

Comments
 (0)