|
18 | 18 | import sys
|
19 | 19 | import time
|
20 | 20 | import warnings
|
21 |
| -from collections.abc import Mapping, Sequence |
| 21 | +from collections.abc import Hashable, Mapping, Sequence |
22 | 22 | from copy import deepcopy
|
23 | 23 | from functools import partial
|
24 |
| -from typing import Any, Callable |
| 24 | +from typing import Any, Callable, Union |
25 | 25 |
|
26 | 26 | import numpy as np
|
27 | 27 | import torch
|
|
99 | 99 | "ConvertToMultiChannelBasedOnBratsClasses",
|
100 | 100 | "AddExtremePointsChannel",
|
101 | 101 | "TorchVision",
|
| 102 | + "TorchIO", |
102 | 103 | "MapLabelValue",
|
103 | 104 | "IntensityStats",
|
104 | 105 | "ToDevice",
|
105 | 106 | "CuCIM",
|
106 | 107 | "RandCuCIM",
|
| 108 | + "RandTorchIO", |
| 109 | + "RandTorchVision", |
107 | 110 | "ToCupy",
|
108 | 111 | "ImageFilter",
|
109 | 112 | "RandImageFilter",
|
@@ -1136,12 +1139,44 @@ def __call__(
|
1136 | 1139 | return concatenate((img, points_image), axis=0)
|
1137 | 1140 |
|
1138 | 1141 |
|
1139 |
| -class TorchVision: |
| 1142 | +class TorchVision(Transform): |
1140 | 1143 | """
|
1141 |
| - This is a wrapper transform for PyTorch TorchVision transform based on the specified transform name and args. |
1142 |
| - As most of the TorchVision transforms only work for PIL image and PyTorch Tensor, this transform expects input |
1143 |
| - data to be PyTorch Tensor, users can easily call `ToTensor` transform to convert a Numpy array to Tensor. |
| 1144 | + This is a wrapper transform for PyTorch TorchVision non-randomized transform based on the specified transform name and args. |
| 1145 | + Data is converted to a torch.tensor before applying the transform and then converted back to the original data type. |
| 1146 | + """ |
| 1147 | + |
| 1148 | + backend = [TransformBackends.TORCH] |
| 1149 | + |
| 1150 | + def __init__(self, name: str, *args, **kwargs) -> None: |
| 1151 | + """ |
| 1152 | + Args: |
| 1153 | + name: The transform name in TorchVision package. |
| 1154 | + args: parameters for the TorchVision transform. |
| 1155 | + kwargs: parameters for the TorchVision transform. |
| 1156 | +
|
| 1157 | + """ |
| 1158 | + super().__init__() |
| 1159 | + self.name = name |
| 1160 | + transform, _ = optional_import("torchvision.transforms", "0.8.0", min_version, name=name) |
| 1161 | + self.trans = transform(*args, **kwargs) |
| 1162 | + |
| 1163 | + def __call__(self, img: NdarrayOrTensor): |
| 1164 | + """ |
| 1165 | + Args: |
| 1166 | + img: PyTorch Tensor data for the TorchVision transform. |
1144 | 1167 |
|
| 1168 | + """ |
| 1169 | + img_t, *_ = convert_data_type(img, torch.Tensor) |
| 1170 | + |
| 1171 | + out = self.trans(img_t) |
| 1172 | + out, *_ = convert_to_dst_type(src=out, dst=img) |
| 1173 | + return out |
| 1174 | + |
| 1175 | + |
| 1176 | +class RandTorchVision(Transform, RandomizableTrait): |
| 1177 | + """ |
| 1178 | + This is a wrapper transform for PyTorch TorchVision randomized transform based on the specified transform name and args. |
| 1179 | + Data is converted to a torch.tensor before applying the transform and then converted back to the original data type. |
1145 | 1180 | """
|
1146 | 1181 |
|
1147 | 1182 | backend = [TransformBackends.TORCH]
|
@@ -1172,6 +1207,68 @@ def __call__(self, img: NdarrayOrTensor):
|
1172 | 1207 | return out
|
1173 | 1208 |
|
1174 | 1209 |
|
| 1210 | +class TorchIO(Transform): |
| 1211 | + """ |
| 1212 | + This is a wrapper for TorchIO non-randomized transforms based on the specified transform name and args. |
| 1213 | + See https://torchio.readthedocs.io/transforms/transforms.html for more details. |
| 1214 | + """ |
| 1215 | + |
| 1216 | + backend = [TransformBackends.TORCH] |
| 1217 | + |
| 1218 | + def __init__(self, name: str, *args, **kwargs) -> None: |
| 1219 | + """ |
| 1220 | + Args: |
| 1221 | + name: The transform name in TorchIO package. |
| 1222 | + args: parameters for the TorchIO transform. |
| 1223 | + kwargs: parameters for the TorchIO transform. |
| 1224 | + """ |
| 1225 | + super().__init__() |
| 1226 | + self.name = name |
| 1227 | + transform, _ = optional_import("torchio.transforms", "0.18.0", min_version, name=name) |
| 1228 | + self.trans = transform(*args, **kwargs) |
| 1229 | + |
| 1230 | + def __call__(self, img: Union[NdarrayOrTensor, Mapping[Hashable, NdarrayOrTensor]]): |
| 1231 | + """ |
| 1232 | + Args: |
| 1233 | + img: an instance of torchio.Subject, torchio.Image, numpy.ndarray, torch.Tensor, SimpleITK.Image, |
| 1234 | + or dict containing 4D tensors as values |
| 1235 | +
|
| 1236 | + """ |
| 1237 | + return self.trans(img) |
| 1238 | + |
| 1239 | + |
| 1240 | +class RandTorchIO(Transform, RandomizableTrait): |
| 1241 | + """ |
| 1242 | + This is a wrapper for TorchIO randomized transforms based on the specified transform name and args. |
| 1243 | + See https://torchio.readthedocs.io/transforms/transforms.html for more details. |
| 1244 | + Use this wrapper for all TorchIO transform inheriting from RandomTransform: |
| 1245 | + https://torchio.readthedocs.io/transforms/augmentation.html#randomtransform |
| 1246 | + """ |
| 1247 | + |
| 1248 | + backend = [TransformBackends.TORCH] |
| 1249 | + |
| 1250 | + def __init__(self, name: str, *args, **kwargs) -> None: |
| 1251 | + """ |
| 1252 | + Args: |
| 1253 | + name: The transform name in TorchIO package. |
| 1254 | + args: parameters for the TorchIO transform. |
| 1255 | + kwargs: parameters for the TorchIO transform. |
| 1256 | + """ |
| 1257 | + super().__init__() |
| 1258 | + self.name = name |
| 1259 | + transform, _ = optional_import("torchio.transforms", "0.18.0", min_version, name=name) |
| 1260 | + self.trans = transform(*args, **kwargs) |
| 1261 | + |
| 1262 | + def __call__(self, img: Union[NdarrayOrTensor, Mapping[Hashable, NdarrayOrTensor]]): |
| 1263 | + """ |
| 1264 | + Args: |
| 1265 | + img: an instance of torchio.Subject, torchio.Image, numpy.ndarray, torch.Tensor, SimpleITK.Image, |
| 1266 | + or dict containing 4D tensors as values |
| 1267 | +
|
| 1268 | + """ |
| 1269 | + return self.trans(img) |
| 1270 | + |
| 1271 | + |
1175 | 1272 | class MapLabelValue:
|
1176 | 1273 | """
|
1177 | 1274 | Utility to map label values to another set of values.
|
|
0 commit comments