Skip to content

Commit fc6ed65

Browse files
committed
Fixed cuSPARSE SpMM() call bug
1 parent eaeadab commit fc6ed65

File tree

4 files changed

+36
-12
lines changed

4 files changed

+36
-12
lines changed

.gitignore

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,3 +11,7 @@ sparse_coo_tensor_cpp.egg-info/
1111
.nfs*
1212
*.txt
1313
*.pt
14+
*.qdrep
15+
tests/
16+
job*
17+
slurm_outputs/

gcn_distr.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -321,6 +321,7 @@ def backward(ctx, grad_output):
321321
sigmap = torch.autograd.grad(outputs=func_eval, inputs=z, grad_outputs=grad_output)[0]
322322
grad_output = sigmap
323323

324+
324325
# First backprop equation
325326
ag = broad_func(adj_matrix.size(0), am_partitions, grad_output, rank, size, group)
326327

@@ -342,8 +343,11 @@ def train(inputs, weight1, weight2, adj_matrix, am_partitions, optimizer, data,
342343
outputs = GCNFunc.apply(outputs, weight2, adj_matrix, am_partitions, rank, size, group, F.log_softmax)
343344

344345
optimizer.zero_grad()
345-
rank_train_mask = torch.split(data.train_mask.bool(), outputs.size(0), dim=0)[rank]
346-
datay_rank = torch.split(data.y, outputs.size(0), dim=0)[rank]
346+
347+
node_count = adj_matrix.size(0)
348+
n_per_proc = int(math.ceil(float(node_count) / size))
349+
rank_train_mask = torch.split(data.train_mask.bool(), n_per_proc, dim=0)[rank]
350+
datay_rank = torch.split(data.y, n_per_proc, dim=0)[rank]
347351

348352
# Note: bool type removes warnings, unsure of perf penalty
349353
# loss = F.nll_loss(outputs[data.train_mask.bool()], data.y[data.train_mask.bool()])

gcn_distr_15d.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -353,9 +353,9 @@ def train(inputs, weight1, weight2, adj_matrix, am_partitions, optimizer, data,
353353
optimizer.zero_grad()
354354

355355
rank_c = rank // replication
356-
357-
rank_train_mask = torch.split(data.train_mask.bool(), outputs.size(0), dim=0)[rank_c]
358-
datay_rank = torch.split(data.y, outputs.size(0), dim=0)[rank_c]
356+
n_per_proc = int(math.ceil(float(node_count) / (size / replication)))
357+
rank_train_mask = torch.split(data.train_mask.bool(), n_per_proc, dim=0)[rank_c]
358+
datay_rank = torch.split(data.y, n_per_proc, dim=0)[rank_c]
359359

360360
# Note: bool type removes warnings, unsure of perf penalty
361361
# loss = F.nll_loss(outputs[data.train_mask.bool()], data.y[data.train_mask.bool()])

sparse-extension/sparse_coo_tensor.cpp

Lines changed: 23 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -155,7 +155,8 @@ void spmm_gpu(const at::Tensor& A_rowindices,
155155
cusparseSpMatDescr_t matA;
156156
CHECK_CUSPARSE(cusparseCreateCsr(&matA,
157157
n, // rows
158-
b_col, // cols
158+
// b_col, // cols
159+
m, // cols
159160
nnz, // nnz
160161
d_a_csrrows, // csrRowOffsets
161162
A_colindices.data<int>(), // csrColInd
@@ -165,11 +166,19 @@ void spmm_gpu(const at::Tensor& A_rowindices,
165166
CUSPARSE_INDEX_BASE_ZERO, // idxBase,
166167
CUDA_R_32F)); // valueType
167168

169+
// Row-major to column-major
170+
B.t_();
171+
B.set_data(B.contiguous());
172+
B.set_data(B.view({b_row, b_col}));
173+
168174
cusparseDnMatDescr_t matB;
169175
CHECK_CUSPARSE(cusparseCreateDnMat(&matB,
170-
B.size(1), // rows
176+
// b_col, // rows
177+
b_row, // rows
178+
// b_row, // cols
171179
b_col, // cols
172-
B.size(1), // ld
180+
// b_col, // ld
181+
b_row, // ld
173182
B.data<float>(), // values
174183
CUDA_R_32F, // valueType
175184
CUSPARSE_ORDER_COL)); // order
@@ -183,6 +192,7 @@ void spmm_gpu(const at::Tensor& A_rowindices,
183192
CHECK_CUSPARSE(cusparseCreateDnMat(&matC,
184193
n, // rows
185194
B.size(1), // cols
195+
// n, // ld
186196
n, // ld
187197
C.data<float>(), // values
188198
CUDA_R_32F, // valueType
@@ -191,7 +201,8 @@ void spmm_gpu(const at::Tensor& A_rowindices,
191201
size_t bufferSize;
192202
CHECK_CUSPARSE(cusparseSpMM_bufferSize(handle, // handle,
193203
CUSPARSE_OPERATION_NON_TRANSPOSE, // opA
194-
CUSPARSE_OPERATION_TRANSPOSE, // opB
204+
// CUSPARSE_OPERATION_TRANSPOSE, // opB
205+
CUSPARSE_OPERATION_NON_TRANSPOSE, // opB
195206
&alpha, // alpha
196207
matA, // matA
197208
matB, // matB
@@ -207,7 +218,8 @@ void spmm_gpu(const at::Tensor& A_rowindices,
207218

208219
CHECK_CUSPARSE(cusparseSpMM(handle, // handle,
209220
CUSPARSE_OPERATION_NON_TRANSPOSE, // opA
210-
CUSPARSE_OPERATION_TRANSPOSE, // opB
221+
// CUSPARSE_OPERATION_TRANSPOSE, // opB
222+
CUSPARSE_OPERATION_NON_TRANSPOSE, // opB
211223
&alpha, // alpha
212224
matA, // matA
213225
matB, // matB
@@ -218,12 +230,16 @@ void spmm_gpu(const at::Tensor& A_rowindices,
218230
d_buffer)); // buffer
219231

220232

221-
cudaFree(d_a_csrrows);
222-
cudaFree(d_buffer);
233+
CHECK_ERROR(cudaFree(d_a_csrrows));
234+
CHECK_ERROR(cudaFree(d_buffer));
223235

224236
// Column-major to row-major
225237
C.set_data(C.view({c_col, c_row}));
226238
C.t_();
239+
240+
// Column-major to row-major
241+
B.set_data(B.view({b_col, b_row}));
242+
B.t_();
227243
}
228244

229245
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {

0 commit comments

Comments
 (0)