Open
Description
As suggested and implemented in pull request #635 by @domenicoMuscill0.
I'm splitting this feature out of that PR because that PR is already large.
Here's the implementation by @domenicoMuscill0:
def to_device(
    x: Union[torch.Tensor, nn.Parameter, List, Tuple],
    tensor=None,
    device=None,
    dtype: Union[torch.dtype, List, Tuple, None] = None,
):
    """Move a tensor, or each tensor in a list/tuple, to a device and optionally cast it.

    Args:
        x: A single tensor/parameter, or a list/tuple of them.
        tensor: Optional reference tensor. Its ``.device`` is used when
            ``device`` is None, and it is forwarded to ``to_dtype``.
        device: Target device. Takes precedence over ``tensor.device``.
        dtype: ``None`` to only move the tensors; a single dtype applied to
            every element; or a list/tuple of dtypes with either one entry
            (applied to every element) or exactly one entry per element of x.

    Returns:
        The converted tensor when ``x`` is a single tensor. Otherwise, a list
        with one converted tensor per element of ``x``.

    Raises:
        ValueError: If both ``device`` and ``tensor`` are None.
        RuntimeError: If ``dtype`` is a list/tuple whose length is neither 1
            nor ``len(x)``.
    """
    if device is None and tensor is None:
        raise ValueError("to_device needs either `device` or a reference `tensor`")
    target_device = device if device is not None else tensor.device

    single = not is_list_or_tuple(x)
    items = [x] if single else list(x)

    if is_list_or_tuple(dtype):
        if len(dtype) == 1:
            # A single-entry dtype list is broadcast to every element.
            dtypes = list(dtype) * len(items)
        elif len(dtype) == len(items):
            dtypes = list(dtype)
        else:
            raise RuntimeError(
                f"dtype has {len(dtype)} entries; expected 1 or {len(items)} "
                "(one per element of x)"
            )
        moved = [
            to_dtype(item.to(target_device), tensor=tensor, dtype=dt)
            for item, dt in zip(items, dtypes)
        ]
    elif dtype is not None:
        moved = [
            to_dtype(item.to(target_device), tensor=tensor, dtype=dtype)
            for item in items
        ]
    else:
        # No cast requested: still move every element to the target device
        # (the original returned x unchanged here, never moving it).
        moved = [item.to(target_device) for item in items]

    # Unwrap only when the caller passed a single tensor, so that a
    # one-element list input still comes back as a list.
    return moved[0] if single else moved