Skip to content

Commit d52c72d

Browse files
chenchunwx-csy
authored andcommitted
Fix missing device for pphy2mlog tensor
1 parent e1100fe commit d52c72d

File tree

1 file changed

+2
-1
lines changed

1 file changed

+2
-1
lines changed

eplb.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -121,7 +121,8 @@ def inverse(perm: torch.Tensor) -> torch.Tensor:
121121

122122
pphy2mlog = phy2mlog.gather(-1, pphy2phy) # [num_layers * num_nodes, num_log_per_nodes]
123123
pphy2mlog = (pphy2mlog.view(num_layers, num_nodes, -1) +
124-
torch.arange(0, num_logical_experts, num_logical_experts // num_nodes).view(1, -1, 1)).flatten(-2)
124+
torch.arange(0, num_logical_experts, num_logical_experts // num_nodes,
125+
device=group_pack_index.device).view(1, -1, 1)).flatten(-2)
125126
pphy2log = mlog2log.gather(-1, pphy2mlog)
126127
pphyrank = phyrank.gather(-1, pphy2phy).view(num_layers, -1)
127128
logcnt = mlogcnt.view(num_layers, -1).gather(-1, log2mlog)

0 commit comments

Comments
 (0)