#include <ATen/ATen.h>
#include <ATen/mps/MPSProfiler.h>
#include <ATen/native/mps/OperationUtils.h>
#include "mps_kernels.h"

namespace vision {
namespace ops {

namespace {

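// Forward pass of deform_conv2d on MPS. The structure below follows the usual
// deformable-convolution recipe: a `deformable_im2col` Metal kernel gathers
// input samples at offset-shifted locations (optionally weighted by `mask`)
// into a `columns` buffer, and the convolution itself is then a grouped matrix
// multiply of the weights against those columns, plus a broadcast bias add.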
at::Tensor deform_conv2d_forward_kernel(
    const at::Tensor& input,
    const at::Tensor& weight,
    const at::Tensor& offset,
    const at::Tensor& mask,
    const at::Tensor& bias,
    int64_t stride_h,
    int64_t stride_w,
    int64_t pad_h,
    int64_t pad_w,
    int64_t dilation_h,
    int64_t dilation_w,
    int64_t n_weight_grps,
    int64_t n_offset_grps,
    bool use_mask) {
  using namespace at::native::mps;
  at::Tensor input_c = input.contiguous();
  at::Tensor weight_c = weight.contiguous();
  at::Tensor offset_c = offset.contiguous();
  at::Tensor mask_c = mask.contiguous();
  at::Tensor bias_c = bias.contiguous();

  TORCH_CHECK(input_c.ndimension() == 4, "Input tensor must be 4D");
  TORCH_CHECK(weight_c.ndimension() == 4, "Weight tensor must be 4D");
  TORCH_CHECK(offset_c.ndimension() == 4, "Offset tensor must be 4D");
  TORCH_CHECK(!use_mask || mask_c.ndimension() == 4, "Mask tensor must be 4D if use_mask is true");
  TORCH_CHECK(input_c.is_mps(), "input must be an MPS tensor");
  TORCH_CHECK(weight.is_mps(), "weight must be an MPS tensor");
  TORCH_CHECK(offset.is_mps(), "offset must be an MPS tensor");
  TORCH_CHECK(mask.is_mps(), "mask must be an MPS tensor");
  TORCH_CHECK(bias.is_mps(), "bias must be an MPS tensor");

  at::DeviceGuard guard(input_c.device());

  uint32_t batch = input_c.size(0);
  uint32_t in_channels = input_c.size(1);
  uint32_t in_h = input_c.size(2);
  uint32_t in_w = input_c.size(3);
  uint32_t weight_h = weight_c.size(2);
  uint32_t weight_w = weight_c.size(3);
  uint32_t out_channels = weight_c.size(0);
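  // Effective (dilated) kernel extent and the standard convolution output size:
  // out = (in + 2 * pad - dilated_kernel) / stride + 1.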
  uint32_t ker_h = dilation_h * (weight_h - 1) + 1;
  uint32_t ker_w = dilation_w * (weight_w - 1) + 1;
  uint32_t out_h = ((in_h + 2 * pad_h - ker_h) / stride_h) + 1;
  uint32_t out_w = ((in_w + 2 * pad_w - ker_w) / stride_w) + 1;
  uint32_t pad_h_u = static_cast<uint32_t>(pad_h);
  uint32_t pad_w_u = static_cast<uint32_t>(pad_w);
  uint32_t stride_h_u = static_cast<uint32_t>(stride_h);
  uint32_t stride_w_u = static_cast<uint32_t>(stride_w);
  uint32_t dilation_h_u = static_cast<uint32_t>(dilation_h);
  uint32_t dilation_w_u = static_cast<uint32_t>(dilation_w);

  TORCH_CHECK(weight_c.size(1) * n_weight_grps == in_channels,
              "Input channels (", in_channels,
              ") must equal weight.size(1) * n_weight_grps (", weight_c.size(1), " * ", n_weight_grps, ")");
  TORCH_CHECK(weight_c.size(0) % n_weight_grps == 0,
              "Weight tensor's out channels (", weight_c.size(0),
              ") must be divisible by n_weight_grps (", n_weight_grps, ")");
  TORCH_CHECK(offset_c.size(1) == n_offset_grps * 2 * weight_h * weight_w,
              "Offset tensor shape[1] is invalid: got ", offset_c.size(1),
              ", expected ", n_offset_grps * 2 * weight_h * weight_w);
  TORCH_CHECK(!use_mask || mask_c.size(1) == n_offset_grps * weight_h * weight_w,
              "Mask tensor shape[1] is invalid: got ", mask_c.size(1),
              ", expected ", n_offset_grps * weight_h * weight_w);
  TORCH_CHECK(in_channels % n_offset_grps == 0,
              "Input tensor channels (", in_channels,
              ") must be divisible by n_offset_grps (", n_offset_grps, ")");
  TORCH_CHECK(offset_c.size(0) == batch,
              "Offset tensor batch size (", offset_c.size(0),
              ") must match input tensor batch size (", batch, ")");
  TORCH_CHECK(offset_c.size(2) == out_h && offset_c.size(3) == out_w,
              "Offset tensor spatial dimensions (", offset_c.size(2), ", ", offset_c.size(3),
              ") must match calculated output dimensions (", out_h, ", ", out_w, ")");
  TORCH_CHECK(!use_mask || mask_c.size(0) == batch,
              "Mask tensor batch size (", mask_c.size(0),
              ") must match input tensor batch size (", batch, ")");
  TORCH_CHECK(!use_mask || (mask_c.size(2) == out_h && mask_c.size(3) == out_w),
              "Mask tensor spatial dimensions (", mask_c.size(2), ", ", mask_c.size(3),
              ") must match calculated output dimensions (", out_h, ", ", out_w, ")");
  TORCH_CHECK(out_h > 0 && out_w > 0,
              "Calculated output size too small - out_h: ", out_h, " out_w: ", out_w);

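  // im2col buffer: one row per (input channel, kernel position) pair and one
  // column per output location across the whole batch.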
  auto columns = at::empty({in_channels * weight_h * weight_w, batch * out_h * out_w}, input_c.options());

  id<MTLBuffer> inputBuffer = getMTLBufferStorage(input_c);
  id<MTLBuffer> offsetBuffer = getMTLBufferStorage(offset_c);
  id<MTLBuffer> maskBuffer = use_mask ? getMTLBufferStorage(mask_c) : nil;
  id<MTLBuffer> outputBuffer = getMTLBufferStorage(columns);

  id<MTLDevice> device = MPSDevice::getInstance()->device();
  std::string kernelName = "deformable_im2col_" + scalarToMetalTypeString(input.scalar_type());
  id<MTLComputePipelineState> pipelineState = mps::visionPipelineState(device, kernelName);

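  // Launch one thread per im2col element (channel x output pixel x batch image),
  // rounding the grid up to whole threadgroups.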
  int num_kernels = in_channels * out_h * out_w * batch;
  NSUInteger threadsPerThreadgroup = pipelineState.maxTotalThreadsPerThreadgroup;
  NSUInteger threadgroups = (num_kernels + threadsPerThreadgroup - 1) / threadsPerThreadgroup;
  MTLSize threadGroupSize = MTLSizeMake(threadsPerThreadgroup, 1, 1);
  MTLSize threadgroupsPerGrid = MTLSizeMake(threadgroups, 1, 1);

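  // Encode the im2col kernel on the current MPS stream. The argument order here
  // must match the buffer indices declared by the Metal kernel in mps_kernels.h.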
  MPSStream* mpsStream = getCurrentMPSStream();
  dispatch_sync(mpsStream->queue(), ^{
    @autoreleasepool {
      id<MTLComputeCommandEncoder> computeEncoder = mpsStream->commandEncoder();
      [computeEncoder setComputePipelineState:pipelineState];
      at::native::mps::mtl_setArgs(computeEncoder, inputBuffer, offsetBuffer, maskBuffer,
                                   std::array<uint32_t, 2>{in_h, in_w},
                                   std::array<uint32_t, 2>{weight_h, weight_w},
                                   std::array<uint32_t, 2>{pad_h_u, pad_w_u},
                                   std::array<uint32_t, 2>{stride_h_u, stride_w_u},
                                   std::array<uint32_t, 2>{dilation_h_u, dilation_w_u},
                                   batch, in_channels, n_offset_grps,
                                   std::array<uint32_t, 2>{out_h, out_w},
                                   use_mask, outputBuffer);
      [computeEncoder dispatchThreadgroups:threadgroupsPerGrid threadsPerThreadgroup:threadGroupSize];
    }
  });
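
  // Grouped GEMM: reshape the weights to (groups, out_channels / groups, in_channels / groups * kH * kW),
  // batch-multiply against the matching slice of `columns`, then fold the result
  // back into (batch, out_channels, out_h, out_w).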
  int in_channels_per_grp = in_channels / n_weight_grps;
  int out_channels_per_grp = out_channels / n_weight_grps;
  auto weight_grouped = weight_c.view({n_weight_grps, out_channels_per_grp, in_channels_per_grp, weight_h, weight_w});
  auto columns_grouped = columns.view({n_weight_grps,
                                       (in_channels * weight_h * weight_w) / n_weight_grps,
                                       batch * out_h * out_w});
  auto weight_reshaped = weight_grouped.reshape({n_weight_grps, out_channels_per_grp, -1});
  auto out_grouped = at::bmm(weight_reshaped, columns_grouped);
  auto out = out_grouped.reshape({n_weight_grps * out_channels_per_grp, batch, out_h, out_w})
                 .transpose(0, 1);
  return out + bias_c.view({1, out_channels, 1, 1});
}

} // namespace

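// Register the forward kernel with the dispatcher so MPS tensors are routed here,
// e.g. when called from Python via torch.ops.torchvision.deform_conv2d (which backs
// torchvision.ops.deform_conv2d).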
TORCH_LIBRARY_IMPL(torchvision, MPS, m) {
  m.impl(
      TORCH_SELECTIVE_NAME("torchvision::deform_conv2d"),
      TORCH_FN(deform_conv2d_forward_kernel));
}

} // namespace ops
} // namespace vision