
Make to_device support list inputs #640

Open
@KevinMusgrave

Description


As suggested and implemented in pull request #635 by @domenicoMuscill0.

I'm removing this feature from that PR because it's already a large PR.

Here's the implementation by @domenicoMuscill0:

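```python
from typing import List, Tuple, Union

import torch
from torch import nn

# is_list_or_tuple and to_dtype are existing helpers in the same module.


def to_device(
    x: Union[torch.Tensor, nn.Parameter, List, Tuple],
    tensor=None,
    device=None,
    dtype: Union[torch.dtype, List, Tuple] = None,
):
    dv = device if device is not None else tensor.device
    is_iterable = is_list_or_tuple(x)
    if not is_iterable:
        x = [x]

    # Move every element to the target device, even when dtype is None.
    xd = [xt.to(dv) for xt in x]
    if is_list_or_tuple(dtype):
        # One dtype per element of x.
        if len(dtype) != len(xd):
            raise RuntimeError(
                f"dtype has length {len(dtype)}, but it must be a single "
                f"torch.dtype or have the same length as x ({len(xd)})"
            )
        xd = [to_dtype(xt, tensor=tensor, dtype=dt) for xt, dt in zip(xd, dtype)]
    elif dtype is not None:
        # A single dtype shared by every element of x.
        xd = [to_dtype(xt, tensor=tensor, dtype=dtype) for xt in xd]

    # Unwrap only a bare tensor input; a one-element list stays a list.
    return xd if is_iterable else xd[0]
```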
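The call patterns this feature is meant to support can be illustrated with a small, self-contained sketch. It assumes simplified stand-ins for the helpers `is_list_or_tuple` and `to_dtype` (not the library's exact code) and uses CPU tensors, with `ref` supplying the target device. In this sketch every element is moved to the target device even when `dtype` is `None`, a list `dtype` must have one entry per element of `x`, and only a bare tensor input is unwrapped, so a one-element list comes back as a list.

```python
import torch


# Simplified stand-ins for the library helpers (assumptions, not the exact code).
def is_list_or_tuple(x):
    return isinstance(x, (list, tuple))


def to_dtype(x, tensor=None, dtype=None):
    dt = dtype if dtype is not None else tensor.dtype
    return x if x.dtype == dt else x.type(dt)


def to_device(x, tensor=None, device=None, dtype=None):
    dv = device if device is not None else tensor.device
    is_iterable = is_list_or_tuple(x)
    xs = [xt.to(dv) for xt in (x if is_iterable else [x])]  # always move
    if is_list_or_tuple(dtype):
        if len(dtype) != len(xs):
            raise RuntimeError("dtype must be one torch.dtype or one per element of x")
        xs = [to_dtype(xt, dtype=dt) for xt, dt in zip(xs, dtype)]
    elif dtype is not None:
        xs = [to_dtype(xt, dtype=dtype) for xt in xs]
    return xs if is_iterable else xs[0]


ref = torch.zeros(2)  # reference tensor supplying the target device
a, b = torch.arange(3), torch.arange(3)

# Single tensor in, single tensor out.
single = to_device(a, tensor=ref, dtype=torch.float32)

# List in, list out, one dtype shared by all elements.
shared = to_device([a, b], tensor=ref, dtype=torch.float64)

# List in, list out, one dtype per element.
mixed = to_device([a, b], tensor=ref, dtype=[torch.float32, torch.int32])

# A one-element list without a dtype is still moved and stays a list.
one = to_device([a], tensor=ref)

print(single.dtype, [t.dtype for t in shared], [t.dtype for t in mixed], type(one))
```

Returning a list only when a list or tuple was passed keeps the existing single-tensor call sites unchanged.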

Metadata


Assignees: No one assigned

Labels: enhancement (New feature or request)

Projects: No projects

Relationships: None yet

Development: No branches or pull requests
