Skip to content

Commit 6bbe010

Browse files
authored
[MPS] deformable conv2d kernel (#9017)
1 parent d5df0d6 commit 6bbe010

File tree

3 files changed

+367
-10
lines changed

3 files changed

+367
-10
lines changed

test/test_ops.py

Lines changed: 64 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -929,6 +929,7 @@ def test_batched_nms_implementations(self, seed):
929929

930930
class TestDeformConv:
931931
dtype = torch.float64
932+
mps_dtype = torch.float32
932933

933934
def expected_fn(self, x, weight, offset, mask, bias, stride=1, padding=0, dilation=1):
934935
stride_h, stride_w = _pair(stride)
@@ -1050,12 +1051,11 @@ def test_is_leaf_node(self, device):
10501051
assert len(graph_node_names[0]) == len(graph_node_names[1])
10511052
assert len(graph_node_names[0]) == 1 + op_obj.n_inputs
10521053

1053-
@pytest.mark.parametrize("device", cpu_and_cuda())
1054+
@pytest.mark.parametrize("device", cpu_and_cuda_and_mps())
10541055
@pytest.mark.parametrize("contiguous", (True, False))
10551056
@pytest.mark.parametrize("batch_sz", (0, 33))
1056-
@pytest.mark.opcheck_only_one()
10571057
def test_forward(self, device, contiguous, batch_sz, dtype=None):
1058-
dtype = dtype or self.dtype
1058+
dtype = self.mps_dtype if device == "mps" else dtype or self.dtype
10591059
x, _, offset, mask, _, stride, padding, dilation = self.get_fn_args(device, contiguous, batch_sz, dtype)
10601060
in_channels = 6
10611061
out_channels = 2
@@ -1201,13 +1201,67 @@ def test_forward_scriptability(self):
12011201
torch.jit.script(ops.DeformConv2d(in_channels=8, out_channels=8, kernel_size=3))
12021202

12031203

1204-
optests.generate_opcheck_tests(
1205-
testcase=TestDeformConv,
1206-
namespaces=["torchvision"],
1207-
failures_dict_path=os.path.join(os.path.dirname(__file__), "optests_failures_dict.json"),
1208-
additional_decorators=[],
1209-
test_utils=OPTESTS,
1210-
)
1204+
@pytest.mark.parametrize("dtype", (torch.float16, torch.float32, torch.float64))
1205+
@pytest.mark.parametrize("device", cpu_and_cuda())
1206+
@pytest.mark.parametrize("requires_grad", (True, False))
1207+
def test_deform_conv2d_opcheck(dtype, device, requires_grad):
1208+
batch_size, channels_in, height, width = 1, 6, 10, 10
1209+
kernel_size = (3, 3)
1210+
stride = (1, 1)
1211+
padding = (1, 1)
1212+
dilation = (1, 1)
1213+
groups = 2
1214+
out_channels = 4
1215+
out_h = (height + 2 * padding[0] - dilation[0] * (kernel_size[0] - 1) - 1) // stride[0] + 1
1216+
out_w = (width + 2 * padding[1] - dilation[1] * (kernel_size[1] - 1) - 1) // stride[1] + 1
1217+
x = torch.randn(batch_size, channels_in, height, width, dtype=dtype, device=device, requires_grad=requires_grad)
1218+
offset = torch.randn(
1219+
batch_size,
1220+
2 * kernel_size[0] * kernel_size[1],
1221+
out_h,
1222+
out_w,
1223+
dtype=dtype,
1224+
device=device,
1225+
requires_grad=requires_grad,
1226+
)
1227+
weight = torch.randn(
1228+
out_channels,
1229+
channels_in // groups,
1230+
kernel_size[0],
1231+
kernel_size[1],
1232+
dtype=dtype,
1233+
device=device,
1234+
requires_grad=requires_grad,
1235+
)
1236+
bias = torch.randn(out_channels, dtype=dtype, device=device, requires_grad=requires_grad)
1237+
use_mask = True
1238+
mask = torch.sigmoid(
1239+
torch.randn(
1240+
batch_size,
1241+
kernel_size[0] * kernel_size[1],
1242+
out_h,
1243+
out_w,
1244+
dtype=dtype,
1245+
device=device,
1246+
requires_grad=requires_grad,
1247+
)
1248+
)
1249+
kwargs = {
1250+
"offset": offset,
1251+
"weight": weight,
1252+
"bias": bias,
1253+
"stride_h": stride[0],
1254+
"stride_w": stride[1],
1255+
"pad_h": padding[0],
1256+
"pad_w": padding[1],
1257+
"dilation_h": dilation[0],
1258+
"dilation_w": dilation[1],
1259+
"groups": groups,
1260+
"offset_groups": 1,
1261+
"use_mask": use_mask,
1262+
"mask": mask, # no modulation in this test
1263+
}
1264+
optests.opcheck(torch.ops.torchvision.deform_conv2d, args=(x,), kwargs=kwargs)
12111265

12121266

12131267
class TestFrozenBNT:
Lines changed: 149 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,149 @@
1+
#include <ATen/ATen.h>
2+
#include <ATen/mps/MPSProfiler.h>
3+
#include <ATen/native/mps/OperationUtils.h>
4+
#include "mps_kernels.h"
5+
6+
namespace vision {
7+
namespace ops {
8+
9+
namespace {
10+
11+
at::Tensor deform_conv2d_forward_kernel(
12+
const at::Tensor& input,
13+
const at::Tensor& weight,
14+
const at::Tensor& offset,
15+
const at::Tensor& mask,
16+
const at::Tensor& bias,
17+
int64_t stride_h,
18+
int64_t stride_w,
19+
int64_t pad_h,
20+
int64_t pad_w,
21+
int64_t dilation_h,
22+
int64_t dilation_w,
23+
int64_t n_weight_grps,
24+
int64_t n_offset_grps,
25+
bool use_mask) {
26+
using namespace at::native::mps;
27+
at::Tensor input_c = input.contiguous();
28+
at::Tensor weight_c = weight.contiguous();
29+
at::Tensor offset_c = offset.contiguous();
30+
at::Tensor mask_c = mask.contiguous();
31+
at::Tensor bias_c = bias.contiguous();
32+
33+
TORCH_CHECK(input_c.ndimension() == 4, "Input tensor must be 4D");
34+
TORCH_CHECK(weight_c.ndimension() == 4, "Weight tensor must be 4D");
35+
TORCH_CHECK(offset_c.ndimension() == 4, "Offset tensor must be 4D");
36+
TORCH_CHECK(!use_mask || mask_c.ndimension() == 4, "Mask tensor must be 4D if use_mask is true");
37+
TORCH_CHECK(input_c.is_mps(), "input must be a MPS tensor");
38+
TORCH_CHECK(weight.is_mps(), "weight must be a MPS tensor");
39+
TORCH_CHECK(offset.is_mps(), "offset must be a MPS tensor");
40+
TORCH_CHECK(mask.is_mps(), "mask must be a MPS tensor");
41+
TORCH_CHECK(bias.is_mps(), "bias must be a MPS tensor");
42+
43+
at::DeviceGuard guard(input_c.device());
44+
45+
uint32_t batch = input_c.size(0);
46+
uint32_t in_channels = input_c.size(1);
47+
uint32_t in_h = input_c.size(2);
48+
uint32_t in_w = input_c.size(3);
49+
uint32_t weight_h = weight_c.size(2);
50+
uint32_t weight_w = weight_c.size(3);
51+
uint32_t out_channels = weight_c.size(0);
52+
uint32_t ker_h = dilation_h * (weight_h - 1) + 1;
53+
uint32_t ker_w = dilation_w * (weight_w - 1) + 1;
54+
uint32_t out_h = ((in_h + 2 * pad_h - ker_h) / stride_h) + 1;
55+
uint32_t out_w = ((in_w + 2 * pad_w - ker_w) / stride_w) + 1;
56+
uint32_t pad_h_u = static_cast<uint32_t>(pad_h);
57+
uint32_t pad_w_u = static_cast<uint32_t>(pad_w);
58+
uint32_t stride_h_u = static_cast<uint32_t>(stride_h);
59+
uint32_t stride_w_u = static_cast<uint32_t>(stride_w);
60+
uint32_t dilation_h_u = static_cast<uint32_t>(dilation_h);
61+
uint32_t dilation_w_u = static_cast<uint32_t>(dilation_w);
62+
63+
TORCH_CHECK(weight_c.size(1) * n_weight_grps == in_channels,
64+
"Input channels (", in_channels,
65+
") must equal weight.size(1) * n_weight_grps (", weight_c.size(1), " * ", n_weight_grps, ")");
66+
TORCH_CHECK(weight_c.size(0) % n_weight_grps == 0,
67+
"Weight tensor's out channels (", weight_c.size(0),
68+
") must be divisible by n_weight_grps (", n_weight_grps, ")");
69+
TORCH_CHECK(offset_c.size(1) == n_offset_grps * 2 * weight_h * weight_w,
70+
"Offset tensor shape[1] is invalid: got ", offset_c.size(1),
71+
", expected ", n_offset_grps * 2 * weight_h * weight_w);
72+
TORCH_CHECK(!use_mask || mask_c.size(1) == n_offset_grps * weight_h * weight_w,
73+
"Mask tensor shape[1] is invalid: got ", mask_c.size(1),
74+
", expected ", n_offset_grps * weight_h * weight_w);
75+
TORCH_CHECK(in_channels % n_offset_grps == 0,
76+
"Input tensor channels (", in_channels,
77+
") must be divisible by n_offset_grps (", n_offset_grps, ")");
78+
TORCH_CHECK(offset_c.size(0) == batch,
79+
"Offset tensor batch size (", offset_c.size(0),
80+
") must match input tensor batch size (", batch, ")");
81+
TORCH_CHECK(offset_c.size(2) == out_h && offset_c.size(3) == out_w,
82+
"Offset tensor spatial dimensions (", offset_c.size(2), ", ", offset_c.size(3),
83+
") must match calculated output dimensions (", out_h, ", ", out_w, ")");
84+
TORCH_CHECK(!use_mask || mask_c.size(0) == batch,
85+
"Mask tensor batch size (", mask_c.size(0),
86+
") must match input tensor batch size (", batch, ")");
87+
TORCH_CHECK(!use_mask || (mask_c.size(2) == out_h && mask_c.size(3) == out_w),
88+
"Mask tensor spatial dimensions (", mask_c.size(2), ", ", mask_c.size(3),
89+
") must match calculated output dimensions (", out_h, ", ", out_w, ")");
90+
TORCH_CHECK(out_h > 0 && out_w > 0,
91+
"Calculated output size too small - out_h: ", out_h, " out_w: ", out_w);
92+
93+
auto columns = at::empty({in_channels * weight_h * weight_w, batch * out_h * out_w}, input_c.options());
94+
95+
id<MTLBuffer> inputBuffer = getMTLBufferStorage(input_c);
96+
id<MTLBuffer> offsetBuffer = getMTLBufferStorage(offset_c);
97+
id<MTLBuffer> maskBuffer = use_mask ? getMTLBufferStorage(mask_c) : nil;
98+
id<MTLBuffer> outputBuffer = getMTLBufferStorage(columns);
99+
100+
id<MTLDevice> device = MPSDevice::getInstance()->device();
101+
std::string kernelName = "deformable_im2col_" + scalarToMetalTypeString(input.scalar_type());
102+
id<MTLComputePipelineState> pipelineState = mps::visionPipelineState(device, kernelName);
103+
104+
int num_kernels = in_channels * out_h * out_w * batch;
105+
NSUInteger threadsPerThreadgroup = pipelineState.maxTotalThreadsPerThreadgroup;
106+
NSUInteger threadgroups = (num_kernels + threadsPerThreadgroup - 1) / threadsPerThreadgroup;
107+
MTLSize threadGroupSize = MTLSizeMake(threadsPerThreadgroup, 1, 1);
108+
MTLSize threadgroupsPerGrid = MTLSizeMake(threadgroups, 1, 1);
109+
110+
MPSStream* mpsStream = getCurrentMPSStream();
111+
dispatch_sync(mpsStream->queue(), ^{
112+
@autoreleasepool {
113+
id<MTLComputeCommandEncoder> computeEncoder = mpsStream->commandEncoder();
114+
[computeEncoder setComputePipelineState:pipelineState];
115+
at::native::mps::mtl_setArgs(computeEncoder, inputBuffer, offsetBuffer, maskBuffer,
116+
std::array<uint32_t, 2>{in_h, in_w},
117+
std::array<uint32_t, 2>{weight_h, weight_w},
118+
std::array<uint32_t, 2>{pad_h_u, pad_w_u},
119+
std::array<uint32_t, 2>{stride_h_u, stride_w_u},
120+
std::array<uint32_t, 2>{dilation_h_u, dilation_w_u},
121+
batch, in_channels, n_offset_grps,
122+
std::array<uint32_t, 2>{out_h, out_w},
123+
use_mask, outputBuffer);
124+
[computeEncoder dispatchThreadgroups:threadgroupsPerGrid threadsPerThreadgroup:threadGroupSize];
125+
}
126+
});
127+
int in_channels_per_grp = in_channels / n_weight_grps;
128+
int out_channels_per_grp = out_channels / n_weight_grps;
129+
auto weight_grouped = weight_c.view({n_weight_grps, out_channels_per_grp, in_channels_per_grp, weight_h, weight_w});
130+
auto columns_grouped = columns.view({n_weight_grps,
131+
(in_channels * weight_h * weight_w) / n_weight_grps,
132+
batch * out_h * out_w});
133+
auto weight_reshaped = weight_grouped.reshape({n_weight_grps, out_channels_per_grp, -1});
134+
auto out_grouped = at::bmm(weight_reshaped, columns_grouped);
135+
auto out = out_grouped.reshape({n_weight_grps * out_channels_per_grp, batch, out_h, out_w})
136+
.transpose(0, 1);
137+
return out + bias_c.view({1, out_channels, 1, 1});
138+
}
139+
140+
} // namespace
141+
142+
TORCH_LIBRARY_IMPL(torchvision, MPS, m) {
143+
m.impl(
144+
TORCH_SELECTIVE_NAME("torchvision::deform_conv2d"),
145+
TORCH_FN(deform_conv2d_forward_kernel));
146+
}
147+
148+
} // namespace ops
149+
} // namespace vision

0 commit comments

Comments
 (0)