Skip to content

Fix a bug related to XLA subgroup creation and usage #2948

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 6 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions ignite/distributed/comp_models/xla.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ def _do_all_reduce(self, tensor: torch.Tensor, op: str = "SUM", group: Optional[
if group is not None and not self._check_group_type(group):
raise ValueError("Argument group should be list of int")
op = self._reduce_op_map[op]
xm.all_reduce(op, [tensor], groups=group)
xm.all_reduce(op, [tensor], groups=[group])
return tensor

def _do_all_gather(self, tensor: torch.Tensor, group: Optional[Any] = None) -> torch.Tensor:
Expand All @@ -152,11 +152,11 @@ def _do_all_gather(self, tensor: torch.Tensor, group: Optional[Any] = None) -> t
group_size = self.get_world_size()
output = torch.zeros((group_size,) + tensor.shape, dtype=tensor.dtype, device=tensor.device)
output[self.get_rank() % group_size] = tensor
xm.all_reduce("sum", [output], groups=group)
xm.all_reduce("sum", [output], groups=[group])
return output.reshape(-1, *output.shape[2:])

def _do_new_group(self, ranks: List[int], **kwargs: Any) -> Any:
return [ranks]
return ranks

def _do_broadcast(self, tensor: torch.Tensor, src: int) -> torch.Tensor:
# from https://github.com/jysohn23/xla/blob/model-parallel-colab/Gather_Scatter_Broadcast_PyTorch_XLA.ipynb
Expand Down
14 changes: 7 additions & 7 deletions tests/ignite/distributed/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,32 +155,34 @@ def _test_distrib_all_reduce_group(device):


def _test_distrib_all_gather(device):
rank = idist.get_rank()

res = torch.tensor(idist.all_gather(10), device=device)
true_res = torch.tensor([10] * idist.get_world_size(), device=device)
assert (res == true_res).all()

t = torch.tensor(idist.get_rank(), device=device)
t = torch.tensor(rank, device=device)
res = idist.all_gather(t)
true_res = torch.tensor([i for i in range(idist.get_world_size())], device=device)
assert (res == true_res).all()

x = "test-test"
if idist.get_rank() == 0:
if rank == 0:
x = "abc"
res = idist.all_gather(x)
true_res = ["abc"] + ["test-test"] * (idist.get_world_size() - 1)
assert res == true_res

base_x = "tests/ignite/distributed/utils/test_native.py" * 2000
x = base_x
if idist.get_rank() == 0:
if rank == 0:
x = "abc"

res = idist.all_gather(x)
true_res = ["abc"] + [base_x] * (idist.get_world_size() - 1)
assert res == true_res

t = torch.arange(100, device=device).reshape(4, 25) * (idist.get_rank() + 1)
t = torch.arange(100, device=device).reshape(4, 25) * (rank + 1)
in_dtype = t.dtype
res = idist.all_gather(t)
assert res.shape == (idist.get_world_size() * 4, 25)
Expand Down Expand Up @@ -218,8 +220,6 @@ def _test_distrib_all_gather_group(device):
res = idist.all_gather(t, group=ranks)
assert torch.equal(res, torch.tensor(ranks, device=device))

ranks = "abc"

if bnd in ("nccl", "gloo", "mpi"):
with pytest.raises(ValueError, match=r"Argument group should be list of int or ProcessGroup"):
res = idist.all_gather(t, group="abc")
Expand Down Expand Up @@ -307,7 +307,7 @@ def _test_distrib_new_group(device):
if rank in ranks:
assert g1.rank() == g2.rank()
elif idist.has_xla_support and bnd in ("xla-tpu"):
assert idist.new_group(ranks) == [ranks]
assert idist.new_group(ranks) == ranks
elif idist.has_hvd_support and bnd in ("horovod"):
from horovod.common.process_sets import ProcessSet

Expand Down