
Commit f358d8f

add pointops lib for eval; add example slurm train script
1 parent 1b9c78a commit f358d8f


41 files changed: +2090 additions, -0 deletions

libs/pointops/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -0,0 +1 @@
from .functions import *

libs/pointops/functions/__init__.py

Lines changed: 14 additions & 0 deletions
@@ -0,0 +1,14 @@
from .aggregation import aggregation
from .attention import attention_fusion_step, attention_relation_step
from .grouping import grouping, grouping2
from .interpolation import interpolation, interpolation2
from .query import ball_query, knn_query, random_ball_query
from .sampling import farthest_point_sampling
from .subtraction import subtraction
from .utils import (
    ball_query_and_group,
    batch2offset,
    knn_query_and_group,
    offset2batch,
    query_and_group,
)
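
For orientation, here is a hedged sketch of the packed-batch convention these exported helpers operate on: points from all samples are concatenated along the first dimension and `offset` stores the cumulative point count per sample. The helper names come from the imports above, but the concrete values and return contents shown in the comments are assumptions, not part of this commit.

    import torch
    from pointops import batch2offset, offset2batch

    # Two point clouds with 4 and 6 points packed into one (10, 3) tensor;
    # offset holds cumulative counts per sample: [4, 10].
    offset = torch.tensor([4, 10], device="cuda")
    batch = offset2batch(offset)        # assumed result: [0, 0, 0, 0, 1, 1, 1, 1, 1, 1]
    offset_again = batch2offset(batch)  # assumed result: [4, 10]
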
libs/pointops/functions/aggregation.py

Lines changed: 56 additions & 0 deletions
@@ -0,0 +1,56 @@
import torch
from pointops._C import aggregation_backward_cuda, aggregation_forward_cuda
from torch.autograd import Function


class Aggregation(Function):
    @staticmethod
    def forward(ctx, input, position, weight, idx):
        """
        input: input: (n, c), position: (n, nsample, c), weight: (n, nsample, c'), idx: (n, nsample)
        output: (n, c)
        """
        assert (
            input.is_contiguous()
            and position.is_contiguous()
            and weight.is_contiguous()
        )
        n, nsample, c = position.shape
        w_c = weight.shape[-1]
        output = torch.cuda.FloatTensor(n, c).zero_()
        aggregation_forward_cuda(
            n, nsample, c, w_c, input, position, weight, idx, output
        )
        ctx.save_for_backward(input, position, weight, idx)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        """
        input: grad_output: (n, c)
        output: grad_input: (n, c), grad_position: (n, nsample, c), grad_weight: (n, nsample, c')
        """
        input, position, weight, idx = ctx.saved_tensors
        n, nsample, c = position.shape
        w_c = weight.shape[-1]
        grad_input = torch.cuda.FloatTensor(n, c).zero_()
        grad_position = torch.cuda.FloatTensor(n, nsample, c).zero_()
        grad_weight = torch.cuda.FloatTensor(n, nsample, w_c).zero_()
        aggregation_backward_cuda(
            n,
            nsample,
            c,
            w_c,
            input,
            position,
            weight,
            idx,
            grad_output,
            grad_input,
            grad_position,
            grad_weight,
        )
        return grad_input, grad_position, grad_weight, None


aggregation = Aggregation.apply
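
As a quick usage sketch (not from this commit), `aggregation` can be called with tensors matching the shapes documented in the forward docstring; the sizes below are illustrative, and choosing the weight channel count equal to the feature channel count is an assumption.

    import torch
    from pointops import aggregation

    n, nsample, c = 1024, 16, 32
    w_c = c  # assumed: weight channels set equal to feature channels
    feat = torch.rand(n, c, device="cuda")               # (n, c)
    position = torch.rand(n, nsample, c, device="cuda")  # (n, nsample, c)
    weight = torch.rand(n, nsample, w_c, device="cuda")  # (n, nsample, c')
    idx = torch.randint(0, n, (n, nsample), device="cuda", dtype=torch.int32)  # (n, nsample)

    out = aggregation(feat, position, weight, idx)  # (n, c)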

libs/pointops/functions/attention.py

Lines changed: 119 additions & 0 deletions
@@ -0,0 +1,119 @@
import torch
from pointops._C import (
    attention_fusion_step_backward_cuda,
    attention_fusion_step_forward_cuda,
    attention_relation_step_backward_cuda,
    attention_relation_step_forward_cuda,
)
from torch.autograd import Function


class AttentionRelationStep(Function):
    @staticmethod
    def forward(ctx, query, key, weight, index_target, index_refer):
        """
        input - query: (n, g, c), key: (n, g, c), weight: (c) 1_c for scatter attention,
                index_target: (m), index_refer: (m)
        output - relation: (m, g)
        """

        assert (
            query.is_contiguous()
            and key.is_contiguous()
            and index_target.is_contiguous()
            and index_refer.is_contiguous()
            and weight.is_contiguous()
        )

        assert index_target.shape[0] == index_refer.shape[0]

        _, g, c = query.shape
        m = index_target.shape[0]
        output = torch.cuda.FloatTensor(m, g).zero_()
        attention_relation_step_forward_cuda(
            m, g, c, query, key, weight, index_target.int(), index_refer.int(), output
        )
        ctx.save_for_backward(query, key, weight, index_target, index_refer)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        query, key, weight, index_target, index_refer = ctx.saved_tensors
        n, g, c = query.shape
        m = index_target.shape[0]
        grad_query = torch.cuda.FloatTensor(n, g, c).zero_()
        grad_key = torch.cuda.FloatTensor(n, g, c).zero_()
        grad_weight = torch.cuda.FloatTensor(c).zero_()
        attention_relation_step_backward_cuda(
            m,
            g,
            c,
            query,
            grad_query,
            key,
            grad_key,
            weight,
            grad_weight,
            index_target.int(),
            index_refer.int(),
            grad_output,
        )
        return grad_query, grad_key, None, None, None


class AttentionFusionStep(Function):
    @staticmethod
    def forward(ctx, weight, value, index_target, index_refer):
        """
        input - weight: (m, g), value: (n, g, c),
                index_target: (m), index_refer: (m)
        output - output: (n, g, c)
        """

        assert (
            weight.is_contiguous()
            and value.is_contiguous()
            and index_target.is_contiguous()
            and index_refer.is_contiguous()
        )

        assert index_target.shape[0] == index_refer.shape[0]

        n, g, c = value.shape
        m = index_refer.shape[0]
        output = torch.cuda.FloatTensor(n, g, c).zero_()
        attention_fusion_step_forward_cuda(
            m, g, c, weight, value, index_target.int(), index_refer.int(), output
        )
        ctx.save_for_backward(weight, value, index_target, index_refer)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        """
        input: grad_output: (n, g, c)
        output: grad_weight: (m, g), grad_value: (n, g, c), None, None
        """
        weight, value, index_target, index_refer = ctx.saved_tensors
        n, g, c = value.shape
        m = index_target.shape[0]
        grad_weight = torch.cuda.FloatTensor(m, g).zero_()
        grad_value = torch.cuda.FloatTensor(n, g, c).zero_()
        attention_fusion_step_backward_cuda(
            m,
            g,
            c,
            weight,
            grad_weight,
            value,
            grad_value,
            index_target.int(),
            index_refer.int(),
            grad_output,
        )
        return grad_weight, grad_value, None, None


attention_relation_step = AttentionRelationStep.apply
attention_fusion_step = AttentionFusionStep.apply
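
A minimal sketch of chaining the two exported steps for scatter-style attention over a list of (target, reference) index pairs, using the shapes from the docstrings above; the sizes are illustrative, and in practice `relation` would typically be normalized (for example by a scatter softmax over `index_target`) before the fusion step.

    import torch
    from pointops import attention_fusion_step, attention_relation_step

    n, g, c, m = 512, 4, 16, 4096
    query = torch.rand(n, g, c, device="cuda")
    key = torch.rand(n, g, c, device="cuda")
    value = torch.rand(n, g, c, device="cuda")
    weight = torch.ones(c, device="cuda")  # (c), all-ones as noted in the docstring
    index_target = torch.randint(0, n, (m,), device="cuda")
    index_refer = torch.randint(0, n, (m,), device="cuda")

    relation = attention_relation_step(query, key, weight, index_target, index_refer)  # (m, g)
    # normalization of `relation` over each target's neighbors is omitted here for brevity
    out = attention_fusion_step(relation, value, index_target, index_refer)  # (n, g, c)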

libs/pointops/functions/grouping.py

Lines changed: 62 additions & 0 deletions
@@ -0,0 +1,62 @@
import torch
from pointops._C import grouping_backward_cuda, grouping_forward_cuda
from torch.autograd import Function


class Grouping(Function):
    @staticmethod
    def forward(ctx, input, idx):
        """
        input: input: (n, c), idx: (m, nsample)
        output: (m, nsample, c)
        """
        assert input.is_contiguous() and idx.is_contiguous()
        m, nsample, n, c = idx.shape[0], idx.shape[1], input.shape[0], input.shape[1]
        output = torch.cuda.FloatTensor(m, nsample, c)
        grouping_forward_cuda(m, nsample, c, input, idx, output)
        ctx.n = n
        ctx.save_for_backward(idx)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        """
        input: grad_output: (m, nsample, c)
        output: (n, c), None
        """
        n = ctx.n
        (idx,) = ctx.saved_tensors
        m, nsample, c = grad_output.shape
        grad_input = torch.cuda.FloatTensor(n, c).zero_()
        grouping_backward_cuda(m, nsample, c, grad_output, idx, grad_input)
        return grad_input, None


def grouping(idx, feat, xyz, new_xyz=None, with_xyz=False):
    if new_xyz is None:
        new_xyz = xyz
    assert xyz.is_contiguous() and feat.is_contiguous()
    m, nsample, c = idx.shape[0], idx.shape[1], feat.shape[1]
    xyz = torch.cat([xyz, torch.zeros([1, 3]).to(xyz.device)], dim=0)
    feat = torch.cat([feat, torch.zeros([1, c]).to(feat.device)], dim=0)
    grouped_feat = feat[idx.view(-1).long(), :].view(
        m, nsample, c
    )  # (m, num_sample, c)

    if with_xyz:
        assert new_xyz.is_contiguous()
        mask = torch.sign(idx + 1)
        grouped_xyz = xyz[idx.view(-1).long(), :].view(
            m, nsample, 3
        ) - new_xyz.unsqueeze(
            1
        )  # (m, num_sample, 3)
        grouped_xyz = torch.einsum(
            "n s c, n s -> n s c", grouped_xyz, mask
        )  # (m, num_sample, 3)
        return torch.cat((grouped_xyz, grouped_feat), -1)
    else:
        return grouped_feat


grouping2 = Grouping.apply
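
A short usage sketch of the Python-level `grouping` helper above: given a precomputed (m, nsample) neighbor index (where -1 marks padded neighbors), it gathers per-neighbor features and, with `with_xyz=True`, prepends the neighbor coordinates relative to each query point. The sizes are illustrative only.

    import torch
    from pointops import grouping

    n, m, nsample, c = 2048, 256, 8, 32
    xyz = torch.rand(n, 3, device="cuda")      # source coordinates
    feat = torch.rand(n, c, device="cuda")     # source features
    new_xyz = torch.rand(m, 3, device="cuda")  # query coordinates
    idx = torch.randint(0, n, (m, nsample), device="cuda")  # neighbor indices; -1 entries are treated as padding

    grouped = grouping(idx, feat, xyz, new_xyz, with_xyz=True)  # (m, nsample, 3 + c)
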
libs/pointops/functions/interpolation.py

Lines changed: 59 additions & 0 deletions
@@ -0,0 +1,59 @@
import torch
from pointops._C import interpolation_backward_cuda, interpolation_forward_cuda
from torch.autograd import Function

from .query import knn_query


def interpolation(xyz, new_xyz, feat, offset, new_offset, k=3):
    """
    input: coords: (m, 3), new_xyz: (n, 3), feat: (m, c), offset: (b), new_offset: (b)
    output: (n, c)
    """
    assert xyz.is_contiguous() and new_xyz.is_contiguous() and feat.is_contiguous()
    idx, dist = knn_query(k, xyz, offset, new_xyz, new_offset)  # (n, k), (n, k)
    dist_recip = 1.0 / (dist + 1e-8)  # (n, k)
    norm = torch.sum(dist_recip, dim=1, keepdim=True)
    weight = dist_recip / norm  # (n, k)

    new_feat = torch.cuda.FloatTensor(new_xyz.shape[0], feat.shape[1]).zero_()
    for i in range(k):
        new_feat += feat[idx[:, i].long(), :] * weight[:, i].unsqueeze(-1)
    return new_feat


class Interpolation(Function):
    @staticmethod
    def forward(ctx, xyz, new_xyz, input, offset, new_offset, k=3):
        """
        input: coords: (m, 3), new_xyz: (n, 3), input: (m, c), offset: (b), new_offset: (b)
        output: (n, c)
        """
        assert xyz.is_contiguous() and new_xyz.is_contiguous() and input.is_contiguous()
        idx, dist = knn_query(k, xyz, offset, new_xyz, new_offset)  # (n, k), (n, k)
        dist_recip = 1.0 / (dist + 1e-8)  # (n, k)
        norm = torch.sum(dist_recip, dim=1, keepdim=True)
        weight = dist_recip / norm  # (n, k)

        n, c, m = new_xyz.shape[0], input.shape[1], input.shape[0]
        output = torch.cuda.FloatTensor(n, c).zero_()
        interpolation_forward_cuda(n, c, k, input, idx, weight, output)
        ctx.m, ctx.k = m, k
        ctx.save_for_backward(idx, weight)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        """
        input: grad_output: (n, c)
        output: None, None, grad_input: (m, c), None, None, None
        """
        m, k = ctx.m, ctx.k
        idx, weight = ctx.saved_tensors
        n, c = grad_output.shape
        grad_input = torch.cuda.FloatTensor(m, c).zero_()
        interpolation_backward_cuda(n, c, k, grad_output, idx, weight, grad_input)
        return None, None, grad_input, None, None, None


interpolation2 = Interpolation.apply
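
Closing with a hedged usage sketch of the `interpolation` helper above: it upsamples features from a coarse point set onto a denser one with inverse-distance weighting over the k nearest neighbors. A single batch sample is assumed and the sizes are illustrative.

    import torch
    from pointops import interpolation

    m, n, c = 512, 2048, 32
    xyz = torch.rand(m, 3, device="cuda")      # coarse point coordinates
    feat = torch.rand(m, c, device="cuda")     # features living on the coarse points
    new_xyz = torch.rand(n, 3, device="cuda")  # dense points to interpolate onto
    offset = torch.tensor([m], device="cuda", dtype=torch.int32)      # one sample with m points
    new_offset = torch.tensor([n], device="cuda", dtype=torch.int32)  # one sample with n points

    new_feat = interpolation(xyz, new_xyz, feat, offset, new_offset, k=3)  # (n, c)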
