Skip to content

Commit 7273f45

Browse files
committed
Edited to work on Perlmutter
1 parent 24a8c86 commit 7273f45

File tree

5 files changed

+256
-6
lines changed

5 files changed

+256
-6
lines changed

.gitignore

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,3 +15,7 @@ sparse_coo_tensor_cpp.egg-info/
1515
tests/
1616
job*
1717
slurm_outputs/
18+
examples/cagnet_outputs/
19+
*.out
20+
examples/*.nsys-rep
21+
examples/*.ncu-rep

gcn_distr_15d.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -212,7 +212,6 @@ def broad_func(node_count, am_partitions, inputs, rank, size, row_groups, col_gr
212212
elif q_c == size // replication - 1:
213213
inputs_recv = torch.cuda.FloatTensor(am_partitions[am_partid].size(1), inputs.size(1), device=device).fill_(0)
214214
# inputs_recv = torch.zeros(list(am_partitions[i].t().size())[1], inputs.size(1))
215-
216215
tstart_comm = start_time(col_groups[rank_col], rank)
217216

218217
inputs_recv = inputs_recv.contiguous()
@@ -752,7 +751,8 @@ def main():
752751
devcount = torch.cuda.device_count()
753752

754753
if graphname == "Cora":
755-
path = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', graphname)
754+
# path = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', graphname)
755+
path = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data')
756756
dataset = Planetoid(path, graphname, transform=T.NormalizeFeatures())
757757
data = dataset[0]
758758
data = data.to(device)
@@ -778,6 +778,7 @@ def main():
778778
elif graphname == 'Amazon':
779779
print(f"Loading coo...", flush=True)
780780
edge_index = torch.load("../data/Amazon/processed/data.pt")
781+
edge_index = edge_index.t_()
781782
print(f"Done loading coo", flush=True)
782783
# edge_index = edge_index.t_()
783784
# n = 9430088
@@ -799,7 +800,9 @@ def main():
799800
data.y = data.y.to(device)
800801
elif graphname == 'subgraph3':
801802
print(f"Loading coo...", flush=True)
802-
edge_index = torch.load("../data/subgraph3/processed/data.pt")
803+
# edge_index = torch.load("../data/subgraph3/processed/data.pt")
804+
edge_index = torch.load("../data/protein/processed/protein.pt")
805+
edge_index = edge_index.t_()
803806
print(f"Done loading coo", flush=True)
804807
n = 8745542
805808
num_features = 128

sparse-extension/setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,6 @@
22
from torch.utils import cpp_extension
33

44
setup(name='sparse_coo_tensor_cpp',
5-
ext_modules=[cpp_extension.CppExtension('sparse_coo_tensor_cpp', ['sparse_coo_tensor.cpp'],
5+
ext_modules=[cpp_extension.CppExtension('sparse_coo_tensor_cpp', ['sparse_coo_tensor.cu'],
66
extra_compile_args=["-lcusparse"])],
77
cmdclass={'build_ext': cpp_extension.BuildExtension})

sparse-extension/sparse_coo_tensor.cpp

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,6 @@
1111

1212
#include <pybind11/pybind11.h>
1313

14-
#include <THC/THCGeneral.hpp>
15-
1614
#include <torch/extension.h>
1715

1816
namespace py = pybind11;

sparse-extension/sparse_coo_tensor.cu

Lines changed: 245 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,245 @@
1+
#include <ATen/ATen.h>
2+
#include <ATen/cuda/CUDAContext.h>
3+
#include <ATen/Layout.h>
4+
#include <ATen/Parallel.h>
5+
#include <ATen/SparseTensorImpl.h>
6+
#include <ATen/NativeFunctions.h>
7+
#include <ATen/InitialTensorOptions.h>
8+
#include <ATen/SparseTensorUtils.h>
9+
10+
#include "cusparse.h"
11+
12+
#include <pybind11/pybind11.h>
13+
14+
#include <torch/extension.h>
15+
16+
namespace py = pybind11;
17+
18+
using namespace at::sparse;
19+
20+
#define CHECK_CUSPARSE(func) \
21+
{ \
22+
cusparseStatus_t status = (func); \
23+
if (status != CUSPARSE_STATUS_SUCCESS) { \
24+
printf("CUSPARSE API failed at line %d with error: %s (%d)\n", \
25+
__LINE__, cusparseGetErrorString(status), status); \
26+
} \
27+
}
28+
29+
#define CHECK_ERROR(str) \
30+
{cudaDeviceSynchronize(); cudaError_t err; err = cudaGetLastError(); if(err!=0) {printf("ERROR %s: %d %s\n", str, err, cudaGetErrorString(err)); fflush(stdout);}}
31+
32+
33+
at::Tensor expand_values_if_needed(const at::Tensor& values) {
34+
// expand
35+
if (values.dim() == 0) {
36+
// Mimic Numpy behavior here and treat it as a 1D tensor
37+
return values.expand({1});
38+
} else {
39+
return values;
40+
}
41+
}
42+
43+
at::Tensor sparse_coo_tensor_gpu(const at::Tensor& indices,
44+
const at::Tensor& values_,
45+
at::ArrayRef<int64_t> size) {
46+
47+
at::Tensor values = expand_values_if_needed(values_);
48+
49+
int64_t sparse_dim = indices.size(0);
50+
int64_t dense_dim = values.dim() - 1;
51+
52+
return at::_sparse_coo_tensor_with_dims_and_tensors(
53+
sparse_dim, dense_dim, size, indices, values, values.options().layout(at::kSparse));
54+
}
55+
56+
template<typename T>
57+
void printCusparseDnMat(int64_t rows, int64_t cols, int64_t ld, T *values_dev) {
58+
T* values_host = new T[rows*cols];
59+
cudaMemcpy(values_host, values_dev, rows*cols*sizeof(T), cudaMemcpyDeviceToHost);
60+
for (int64_t row = 0; row < rows; row++) {
61+
for (int64_t col = 0; col < cols; col++) {
62+
// Cusparse dense matrices are stored in column-major order
63+
std::cout << values_host[col*rows+row] << " ";
64+
}
65+
std::cout << std::endl;
66+
}
67+
std::cout << " values: ";
68+
for (int64_t i = 0; i < rows*cols; i++) {
69+
std::cout << values_host[i] << " ";
70+
}
71+
std::cout << std::endl;
72+
std::cout << " shape: " << rows << ", " << cols << std::endl;
73+
delete [] values_host;
74+
}
75+
76+
template<typename T>
77+
void printCusparseSpMat(int32_t rows, int32_t cols, int32_t nnz, int32_t *row_indices_dev,
78+
int32_t *col_indices_dev, T *values_dev) {
79+
T* values_host = new T[nnz];
80+
int32_t* row_indices_host = new int32_t[nnz];
81+
int32_t* col_indices_host = new int32_t[nnz];
82+
cudaMemcpy(values_host, values_dev, nnz*sizeof(T), cudaMemcpyDeviceToHost);
83+
cudaMemcpy(row_indices_host, row_indices_dev, nnz*sizeof(int32_t), cudaMemcpyDeviceToHost);
84+
cudaMemcpy(col_indices_host, col_indices_dev, nnz*sizeof(int32_t), cudaMemcpyDeviceToHost);
85+
86+
for (int64_t i = 0; i < nnz; i++) {
87+
std::cout << "(" << row_indices_host[i]
88+
<< ", " << col_indices_host[i]
89+
<< "): " << values_host[i] << std::endl;
90+
}
91+
std::cout << " values: ";
92+
for (int64_t i = 0; i < nnz; i++) {
93+
std::cout << values_host[i] << " ";
94+
}
95+
std::cout << std::endl;
96+
std::cout << " row_indices: ";
97+
for (int64_t i = 0; i < nnz; i++) {
98+
std::cout << row_indices_host[i] << " ";
99+
}
100+
std::cout << std::endl;
101+
std::cout << " col_indices: ";
102+
for (int64_t i = 0; i < nnz; i++) {
103+
std::cout << col_indices_host[i] << " ";
104+
}
105+
std::cout << std::endl;
106+
std::cout << " shape: " << rows << ", " << cols << std::endl;
107+
delete [] values_host;
108+
delete [] row_indices_host;
109+
delete [] col_indices_host;
110+
}
111+
112+
// at::Tensor spmm_gpu(const at::Tensor& A_rowindices,
113+
void spmm_gpu(const at::Tensor& A_rowindices,
114+
const at::Tensor& A_colindices,
115+
const at::Tensor& A_values,
116+
int32_t n,
117+
int32_t m,
118+
at::Tensor& B,
119+
at::Tensor& C) {
120+
121+
// cusparseHandle_t handle;
122+
// CHECK_CUSPARSE(cusparseCreate(&handle));
123+
auto handle = at::cuda::getCurrentCUDASparseHandle();
124+
125+
// Impl1 -- coo2csr + csrmm2
126+
int nnz = A_values.size(0);
127+
128+
clock_t start, stop;
129+
130+
int32_t *d_a_csrrows;
131+
132+
// int devid_old = 0;
133+
// cudaGetDevice(&devid_old);
134+
// cudaSetDevice(devid);
135+
136+
cudaMalloc(&d_a_csrrows, (n + 1) * sizeof(int32_t));
137+
CHECK_CUSPARSE(cusparseXcoo2csr(handle,
138+
A_rowindices.data<int>(),
139+
nnz,
140+
n,
141+
d_a_csrrows,
142+
CUSPARSE_INDEX_BASE_ZERO));
143+
144+
int32_t b_row = B.size(0);
145+
int32_t b_col = B.size(1);
146+
int32_t c_row = C.size(0);
147+
int32_t c_col = C.size(1);
148+
149+
float alpha = 1;
150+
float beta = 1;
151+
cusparseSpMatDescr_t matA;
152+
CHECK_CUSPARSE(cusparseCreateCsr(&matA,
153+
n, // rows
154+
// b_col, // cols
155+
m, // cols
156+
nnz, // nnz
157+
d_a_csrrows, // csrRowOffsets
158+
A_colindices.data<int>(), // csrColInd
159+
A_values.data<float>(), // csrValues
160+
CUSPARSE_INDEX_32I, // csrRowOffsetsType
161+
CUSPARSE_INDEX_32I, // csrColIndType
162+
CUSPARSE_INDEX_BASE_ZERO, // idxBase,
163+
CUDA_R_32F)); // valueType
164+
165+
// Row-major to column-major
166+
B.t_();
167+
B.set_data(B.contiguous());
168+
B.set_data(B.view({b_row, b_col}));
169+
170+
cusparseDnMatDescr_t matB;
171+
CHECK_CUSPARSE(cusparseCreateDnMat(&matB,
172+
// b_col, // rows
173+
b_row, // rows
174+
// b_row, // cols
175+
b_col, // cols
176+
// b_col, // ld
177+
b_row, // ld
178+
B.data<float>(), // values
179+
CUDA_R_32F, // valueType
180+
CUSPARSE_ORDER_COL)); // order
181+
182+
// Row-major to column-major
183+
C.t_();
184+
C.set_data(C.contiguous());
185+
C.set_data(C.view({c_row, c_col}));
186+
187+
cusparseDnMatDescr_t matC;
188+
CHECK_CUSPARSE(cusparseCreateDnMat(&matC,
189+
n, // rows
190+
B.size(1), // cols
191+
// n, // ld
192+
n, // ld
193+
C.data<float>(), // values
194+
CUDA_R_32F, // valueType
195+
CUSPARSE_ORDER_COL)); // order
196+
197+
size_t bufferSize;
198+
CHECK_CUSPARSE(cusparseSpMM_bufferSize(handle, // handle,
199+
CUSPARSE_OPERATION_NON_TRANSPOSE, // opA
200+
// CUSPARSE_OPERATION_TRANSPOSE, // opB
201+
CUSPARSE_OPERATION_NON_TRANSPOSE, // opB
202+
&alpha, // alpha
203+
matA, // matA
204+
matB, // matB
205+
&beta, // beta
206+
matC, // matC
207+
CUDA_R_32F, // computeType
208+
CUSPARSE_CSRMM_ALG1, // alg
209+
&bufferSize)); // bufferSize
210+
211+
212+
void* d_buffer = NULL;
213+
cudaMalloc(&d_buffer, bufferSize);
214+
215+
CHECK_CUSPARSE(cusparseSpMM(handle, // handle,
216+
CUSPARSE_OPERATION_NON_TRANSPOSE, // opA
217+
// CUSPARSE_OPERATION_TRANSPOSE, // opB
218+
CUSPARSE_OPERATION_NON_TRANSPOSE, // opB
219+
&alpha, // alpha
220+
matA, // matA
221+
matB, // matB
222+
&beta, // beta
223+
matC, // matC
224+
CUDA_R_32F, // computeType
225+
CUSPARSE_CSRMM_ALG1, // alg
226+
d_buffer)); // buffer
227+
228+
229+
cudaFree(d_a_csrrows);
230+
cudaFree(d_buffer);
231+
CHECK_ERROR("spmm_gpu error")
232+
233+
// Column-major to row-major
234+
C.set_data(C.view({c_col, c_row}));
235+
C.t_();
236+
237+
// Column-major to row-major
238+
B.set_data(B.view({b_col, b_row}));
239+
B.t_();
240+
}
241+
242+
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
243+
m.def("sparse_coo_tensor_gpu", &sparse_coo_tensor_gpu, "Sparse Tensor GPU-only constructor");
244+
m.def("spmm_gpu", &spmm_gpu, "SpMM wrapper for cusparse");
245+
}

0 commit comments

Comments
 (0)