Skip to content

Commit eaeadab

Browse files
committed
Replaces torch.mm with SpMM() in 1D
1 parent c5a08b5 commit eaeadab

File tree

1 file changed

+3
-4
lines changed

1 file changed

+3
-4
lines changed

gcn_distr.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -235,10 +235,9 @@ def broad_func(node_count, am_partitions, inputs, rank, size, group):
235235

236236
tstart_comp = start_time(group, rank)
237237

238-
# spmm_gpu(am_partitions[i].indices()[0].int(), am_partitions[i].indices()[1].int(),
239-
# am_partitions[i].values(), am_partitions[i].size(0),
240-
# am_partitions[i].size(1), inputs_recv, z_loc)
241-
z_loc += torch.mm(am_partitions[i], inputs_recv)
238+
spmm_gpu(am_partitions[i].indices()[0].int(), am_partitions[i].indices()[1].int(),
239+
am_partitions[i].values(), am_partitions[i].size(0),
240+
am_partitions[i].size(1), inputs_recv, z_loc)
242241

243242
dur = stop_time(group, rank, tstart_comp)
244243
comp_time[run][rank] += dur

0 commit comments

Comments
 (0)