Commit be92fb4

Add im2row per tensor overload

Authored Apr 1, 2025
Differential Revision: D70938715
Pull Request resolved: #9121
1 parent 2972388 · commit be92fb4

6 files changed · +293 −18 lines changed
 

backends/cadence/aot/functions.yaml

+10

@@ -234,6 +234,11 @@
     - arg_meta: null
       kernel_name: impl::reference::im2row_out
 
+- func: cadence::im2row.per_tensor_out(Tensor input, int[2] kernel_size, int[2] dilation, int[2] padding, int[2] stride, int in_zero_point, bool channel_last=False, *, Tensor(a!) out) -> Tensor(a!)
+  kernels:
+    - arg_meta: null
+      kernel_name: impl::reference::im2row_per_tensor_out
+
 - func: cadence::quantized_conv.per_tensor_out(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, int weight_zero_point, float bias_scale, float out_scale, int out_zero_point, int out_multiplier, int out_shift, bool channel_last=False, *, Tensor(a!) out) -> Tensor(a!)
   kernels:
     - arg_meta: null
@@ -253,3 +258,8 @@
   kernels:
     - arg_meta: null
       kernel_name: impl::reference::requantize_out
+
+- func: cadence::requantize.per_tensor_out(Tensor input, float in_scale, int in_zero_point, float out_scale, int out_zero_point, ScalarType out_dtype, *, Tensor(a!) out) -> Tensor(a!)
+  kernels:
+    - arg_meta: null
+      kernel_name: impl::reference::requantize_per_tensor_out

backends/cadence/aot/ops_registrations.py

+55 −18

@@ -10,11 +10,14 @@
 from typing import Optional, Tuple
 
 import torch
+from executorch.backends.cadence.aot.utils import (
+    get_conv1d_output_size,
+    get_conv2d_output_size,
+    get_im2row_output_size,
+)
 from executorch.exir.scalar_type import ScalarType
 from torch.library import Library, register_fake
 
-from .utils import get_conv1d_output_size, get_conv2d_output_size
-
 lib = Library("cadence", "DEF")
 
 lib.define(
@@ -131,6 +134,10 @@
     "im2row(Tensor input, int[2] kernel_size, int[2] dilation, int[2] padding, int[2] stride, "
     "Tensor in_zero_point, bool channel_last=False) -> (Tensor out)"
 )
+lib.define(
+    "im2row.per_tensor(Tensor input, int[2] kernel_size, int[2] dilation, int[2] padding, int[2] stride, "
+    "int in_zero_point, bool channel_last=False) -> (Tensor out)"
+)
 lib.define("linalg_vector_norm(Tensor X) -> (Tensor Y)")
 lib.define(
     "transposed_im2row(Tensor input, int[2] kernel_size, int[2] dilation, int[2] padding, int[2] stride, "
@@ -140,6 +147,10 @@
     "requantize(Tensor input, Tensor in_scale, Tensor in_zero_point, Tensor out_scale, "
     "Tensor out_zero_point, ScalarType out_dtype) -> (Tensor Y)"
 )
+lib.define(
+    "requantize.per_tensor(Tensor input, float in_scale, int in_zero_point, float out_scale, "
+    "int out_zero_point, ScalarType out_dtype) -> (Tensor Y)"
+)
 lib.define(
     "fully_connected(Tensor input, Tensor weight, Tensor? bias=None) -> (Tensor out)"
 )
@@ -223,6 +234,10 @@
     "im2row.out(Tensor input, int[2] kernel_size, int[2] dilation, int[2] padding, int[2] stride, "
     "Tensor in_zero_point, bool channel_last=False, *, Tensor(a!) out) -> Tensor(a!)"
 )
+lib.define(
+    "im2row.per_tensor_out(Tensor input, int[2] kernel_size, int[2] dilation, int[2] padding, int[2] stride, "
+    "int in_zero_point, bool channel_last=False, *, Tensor(a!) out) -> Tensor(a!)"
+)
 lib.define(
     "transposed_im2row.out(Tensor input, int[2] kernel_size, int[2] dilation, int[2] padding, "
     "int[2] stride, int[2] output_padding, Tensor in_zero_point, bool channel_last=False, *, Tensor(a!) out) -> Tensor(a!)"
@@ -231,7 +246,10 @@
     "requantize.out(Tensor input, Tensor in_scale, Tensor in_zero_point, Tensor out_scale, "
     "Tensor out_zero_point, ScalarType out_dtype, *, Tensor(a!) out) -> Tensor(a!)"
 )
-
+lib.define(
+    "requantize.per_tensor_out(Tensor input, float in_scale, int in_zero_point, float out_scale, "
+    "int out_zero_point, ScalarType out_dtype, *, Tensor(a!) out) -> Tensor(a!)"
+)
 
 # Custom ops with aten namespace. Need to specify the lib var as FRAGMENT type as aten library is already defined
 aten_lib = Library("aten", "FRAGMENT")
@@ -562,22 +580,25 @@ def im2row_meta(
     in_zero_point: torch.Tensor,
     channel_last: bool = False,
 ) -> torch.Tensor:
-    if len(input.shape) == 3:
-        height_dim = 1 if channel_last else 2
-        input = input.unsqueeze(height_dim)
+    output_size = get_im2row_output_size(
+        input, kernel_size, dilation, padding, stride, channel_last
+    )
+    return input.new_empty(output_size, dtype=input.dtype)
 
-    batch_size = input.shape[0]
-    n_input_plane = input.shape[3] if channel_last else input.shape[1]
-    input_height = input.shape[1] if channel_last else input.shape[2]
-    input_width = input.shape[2] if channel_last else input.shape[3]
-    output_height = (
-        input_height + 2 * padding[0] - (dilation[0] * (kernel_size[0] - 1) + 1)
-    ) // stride[0] + 1
-    output_width = (
-        input_width + 2 * padding[1] - (dilation[1] * (kernel_size[1] - 1) + 1)
-    ) // stride[1] + 1
-    n_output_plane = n_input_plane * kernel_size[0] * kernel_size[1]
-    output_size = torch.Size((batch_size, output_height * output_width, n_output_plane))
+
+@register_fake("cadence::im2row.per_tensor")
+def im2row_per_tensor_meta(
+    input: torch.Tensor,
+    kernel_size: Tuple[int],
+    dilation: Tuple[int],
+    padding: Tuple[int],
+    stride: Tuple[int],
+    in_zero_point: int,
+    channel_last: bool = False,
+) -> torch.Tensor:
+    output_size = get_im2row_output_size(
+        input, kernel_size, dilation, padding, stride, channel_last
+    )
     return input.new_empty(output_size, dtype=input.dtype)
 
 
@@ -606,6 +627,22 @@ def requantize_meta(
     )
 
 
+@register_fake("cadence::requantize.per_tensor")
+def requantize_per_tensor_meta(
+    input: torch.Tensor,
+    in_scale: float,
+    in_zero_point: int,
+    out_scale: float,
+    out_zero_point: int,
+    dtype: ScalarType,
+) -> torch.Tensor:
+    return input.new_empty(
+        input.size(),
+        # pyre-ignore[6]: Incompatible type
+        dtype=dtype,
+    )
+
+
 @register_fake("cadence::quantized_relu.per_tensor")
 def quantized_relu_per_tensor_meta(
     input: torch.Tensor,
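With these schemas and fake kernels registered, the per_tensor overloads can be traced and shape-propagated without a real CPU kernel. A minimal sketch of the resulting shape inference, assuming `ops_registrations` is importable so that `torch.ops.cadence.*` exists (the concrete sizes below are illustrative, not part of the commit):

```python
import torch
from torch._subclasses.fake_tensor import FakeTensorMode

# Importing the module registers the cadence ops and their fake kernels.
from executorch.backends.cadence.aot import ops_registrations  # noqa: F401

with FakeTensorMode():
    x = torch.empty(1, 3, 32, 32, dtype=torch.int8)  # NCHW, channel_last=False
    out = torch.ops.cadence.im2row.per_tensor(
        x,
        [3, 3],  # kernel_size
        [1, 1],  # dilation
        [1, 1],  # padding
        [2, 2],  # stride
        0,       # in_zero_point: now a plain int rather than a 1-element tensor
        False,   # channel_last
    )
    # im2row_per_tensor_meta returns (N, out_h * out_w, C * kh * kw).
    assert out.shape == (1, 16 * 16, 3 * 3 * 3)
```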

backends/cadence/aot/replace_ops.py

+11

@@ -1864,6 +1864,14 @@ class ReplaceSingleElementTensorArgumentsFromFullOpWithScalarPass(ExportPass):
             exir_ops.edge.cadence.quantized_relu.per_tensor,
             [1, 3, 4],
         ),
+        exir_ops.edge.cadence.im2row: (
+            exir_ops.edge.cadence.im2row.per_tensor,
+            [5],
+        ),
+        exir_ops.edge.cadence.requantize: (
+            exir_ops.edge.cadence.requantize.per_tensor,
+            [1, 2, 3, 4],
+        ),
     }
 
     def call_operator(self, op, args, kwargs, meta):
@@ -1884,6 +1892,9 @@ def call_operator(self, op, args, kwargs, meta):
             if not arg.is_tensor():
                 return super().call_operator(op, args, kwargs, meta)
 
+            if not isinstance(arg.node.target, EdgeOpOverload):
+                return super().call_operator(op, args, kwargs, meta)
+
             if get_edge_overload_packet(arg.node.target) != exir_ops.edge.aten.full:
                 # Only replace if arg generated by a full op.
                 return super().call_operator(op, args, kwargs, meta)
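The two new table entries tell `ReplaceSingleElementTensorArgumentsFromFullOpWithScalarPass` which argument positions may be demoted from single-element tensors produced by `aten.full` to scalars: index 5 is `im2row`'s `in_zero_point`, and indices 1–4 are `requantize`'s scales and zero points; the added `isinstance` guard skips arguments whose producer is not an edge op. Below is a rough, self-contained sketch of the rewrite rule using plain tuples in place of the real EXIR `ProxyValue`/`EdgeOpOverload` machinery (all names here are hypothetical):

```python
# Hypothetical, simplified illustration of the replacement rule; the real pass
# operates on EXIR graph nodes, not on plain strings and tuples.
REPLACEABLE = {
    "cadence.im2row": ("cadence.im2row.per_tensor", [5]),
    "cadence.requantize": ("cadence.requantize.per_tensor", [1, 2, 3, 4]),
}


def replace_full_args_with_scalars(op, args):
    """If every listed arg is a single-element 'full' literal, switch to the
    per_tensor overload and pass the fill values through as plain scalars."""
    if op not in REPLACEABLE:
        return op, args
    new_op, indices = REPLACEABLE[op]
    new_args = list(args)
    for idx in indices:
        arg = args[idx]
        # Stand-in for "tensor produced by aten.full with a single element".
        if not (isinstance(arg, tuple) and arg[0] == "full"):
            return op, args  # keep the tensor overload untouched
        new_args[idx] = arg[1]  # the scalar fill value
    return new_op, tuple(new_args)


# im2row(input, kernel_size, dilation, padding, stride, in_zero_point, channel_last)
op, args = replace_full_args_with_scalars(
    "cadence.im2row",
    ("x", [3, 3], [1, 1], [1, 1], [2, 2], ("full", 0), False),
)
assert op == "cadence.im2row.per_tensor" and args[5] == 0
```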

backends/cadence/aot/utils.py

+27

@@ -75,6 +75,33 @@ def get_conv2d_output_size(
     return torch.Size((in_size[0], out_channels, hout, wout))
 
 
+def get_im2row_output_size(
+    input: torch.Tensor,
+    kernel_size: Tuple[int],
+    dilation: Tuple[int],
+    padding: Tuple[int],
+    stride: Tuple[int],
+    channel_last: bool,
+) -> torch.Size:
+    if len(input.shape) == 3:
+        height_dim = 1 if channel_last else 2
+        input = input.unsqueeze(height_dim)
+
+    batch_size = input.shape[0]
+    n_input_plane = input.shape[3] if channel_last else input.shape[1]
+    input_height = input.shape[1] if channel_last else input.shape[2]
+    input_width = input.shape[2] if channel_last else input.shape[3]
+    output_height = (
+        input_height + 2 * padding[0] - (dilation[0] * (kernel_size[0] - 1) + 1)
+    ) // stride[0] + 1
+    output_width = (
+        input_width + 2 * padding[1] - (dilation[1] * (kernel_size[1] - 1) + 1)
+    ) // stride[1] + 1
+    n_output_plane = n_input_plane * kernel_size[0] * kernel_size[1]
+    output_size = torch.Size((batch_size, output_height * output_width, n_output_plane))
+    return torch.Size(output_size)
+
+
 # Return the overload packet for the edge op
 def get_edge_overload_packet(edge_op: EdgeOpOverload) -> EdgeOpOverloadPacket:
     edge_op_namespace, edge_op_name = (
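`get_im2row_output_size` returns `(batch, out_h * out_w, C * kh * kw)`, where `out_h`/`out_w` follow the usual convolution output formula and a 3-D input is first given a unit height dimension. A quick worked example, assuming the import path above (the concrete sizes are illustrative):

```python
import torch

from executorch.backends.cadence.aot.utils import get_im2row_output_size

x = torch.zeros(1, 3, 32, 32)  # NCHW input
size = get_im2row_output_size(
    x,
    kernel_size=(3, 3),
    dilation=(1, 1),
    padding=(1, 1),
    stride=(2, 2),
    channel_last=False,
)
# out_h = out_w = (32 + 2*1 - (1*(3-1) + 1)) // 2 + 1 = 16
# patch length  = 3 * 3 * 3 = 27
assert size == torch.Size([1, 16 * 16, 27])
```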

backends/cadence/reference/operators/im2row_out.cpp

+86

@@ -207,6 +207,92 @@ void im2row_out(
 #undef typed_im2row
 }
 
+void im2row_per_tensor_out(
+    __ET_UNUSED KernelRuntimeContext& ctx,
+    const Tensor& input,
+    IntArrayRef kernel_size,
+    IntArrayRef dilation,
+    IntArrayRef padding,
+    IntArrayRef stride,
+    int64_t in_zero_point,
+    bool channel_last,
+    Tensor& out) {
+  // Compute the input tensor's dims
+  bool unit_height = input.dim() == 3;
+  const int32_t batch_size = input.size(0);
+  const int32_t in_c =
+      channel_last ? input.size(3 - unit_height) : input.size(1);
+  const int32_t in_h =
+      unit_height ? 1 : (channel_last ? input.size(1) : input.size(2));
+  const int32_t in_w =
+      channel_last ? input.size(2 - unit_height) : input.size(3 - unit_height);
+
+  // Get the kernel parameters
+  int32_t kernel_h = kernel_size[0];
+  int32_t kernel_w = kernel_size[1];
+  int32_t dilation_h = dilation[0];
+  int32_t dilation_w = dilation[1];
+  int32_t pad_h = padding[0];
+  int32_t pad_w = padding[1];
+  int32_t stride_h = stride[0];
+  int32_t stride_w = stride[1];
+
+  // If we were to apply a convolution on the input tensor, compute the output
+  // height and width.
+  int32_t out_h =
+      (in_h + 2 * pad_h - dilation_h * (kernel_h - 1) - 1) / stride_h + 1;
+  int32_t out_w =
+      (in_w + 2 * pad_w - dilation_w * (kernel_w - 1) - 1) / stride_w + 1;
+
+  ET_DCHECK_MSG(
+      (out_h * out_w) == out.size(1), "dimension mismatch for output");
+  ET_DCHECK_MSG(
+      (kernel_h * kernel_w * in_c) == out.size(2),
+      "dimension mismatch for output");
+
+#define typed_im2row_per_tensor(dtype, ctype) \
+  case ScalarType::dtype: { \
+    const ctype* __restrict__ in_data = input.const_data_ptr<ctype>(); \
+    ctype* __restrict__ out_data = out.mutable_data_ptr<ctype>(); \
+    int32_t in_plane = in_c * in_h * in_w; \
+    int32_t out_plane = kernel_h * kernel_w * in_c * out_h * out_w; \
+    for (size_t n = 0; n < batch_size; ++n) { \
+      im2row_<ctype>( \
+          &in_data[n * in_plane], \
+          in_zero_point, \
+          in_c, \
+          in_h, \
+          in_w, \
+          out_h, \
+          out_w, \
+          kernel_h, \
+          kernel_w, \
+          pad_h, \
+          pad_w, \
+          stride_h, \
+          stride_w, \
+          dilation_h, \
+          dilation_w, \
+          &out_data[n * out_plane], \
+          channel_last); \
+    } \
+    break; \
+  }
+
+  ScalarType dtype = input.scalar_type();
+  switch (dtype) {
+    typed_im2row_per_tensor(Float, float);
+    typed_im2row_per_tensor(Byte, uint8_t);
+    typed_im2row_per_tensor(Char, int8_t);
+    default:
+      ET_DCHECK_MSG(
+          false,
+          "im2row.per_tensor not implemented for dtype %s",
+          torch::executor::toString(dtype));
+  }
+#undef typed_im2row_per_tensor
+}
+
 } // namespace native
 } // namespace reference
 } // namespace impl
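For intuition, the reference kernel extracts the same sliding patches that `torch.nn.functional.unfold` produces, laid out row-wise as `(N, out_h * out_w, C * kh * kw)` rather than unfold's `(N, C * kh * kw, L)`; with `in_zero_point = 0` the padded border values also agree. A rough cross-check of the geometry only (the exact ordering within the patch dimension depends on `im2row_`'s layout, which is not shown in this diff):

```python
import torch
import torch.nn.functional as F

x = torch.randn(1, 3, 32, 32)
# unfold: (N, C * kh * kw, L) with L = out_h * out_w
cols = F.unfold(x, kernel_size=3, dilation=1, padding=1, stride=2)
# im2row lays the same patches out row-wise: (N, L, C * kh * kw)
im2row_like = cols.transpose(1, 2)
assert im2row_like.shape == (1, 16 * 16, 3 * 3 * 3)
```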

backends/cadence/reference/operators/requantize_out.cpp

+104

@@ -157,6 +157,110 @@ Tensor& requantize_out(
   return out;
 }
 
+// Requantize the int8_t/uint8_t input tensor to a uint8_t/int8_t out tensor.
+// The scale and zero_point for requantization are in the args.
+Tensor& requantize_per_tensor_out(
+    KernelRuntimeContext& ctx,
+    const Tensor& input,
+    double in_scale,
+    int64_t in_zero_point,
+    double out_scale,
+    int64_t out_zero_point,
+    const ScalarType out_dtype,
+    Tensor& out) {
+  ET_KERNEL_CHECK_MSG(
+      ctx,
+      out.scalar_type() == out_dtype,
+      InvalidArgument,
+      out,
+      "Out tensor dtype (%s) does not match the passed in out dtype (%s)",
+      torch::executor::toString(out.scalar_type()),
+      torch::executor::toString(out_dtype));
+
+  const size_t numel = out.numel();
+  ScalarType in_dtype = input.scalar_type();
+
+  // Assert that the output tensor's dtype is same as out_dtype.
+  ET_KERNEL_CHECK_MSG(
+      ctx,
+      out_dtype == out.scalar_type(),
+      InvalidArgument,
+      out,
+      "Out dtype %s does not match requant dtype %s",
+      torch::executor::toString(out.scalar_type()),
+      torch::executor::toString(out_dtype));
+
+#define typed_requantize(ctype, dtype) \
+  const ctype* input_data = input.const_data_ptr<ctype>(); \
+  dtype* out_data = out.mutable_data_ptr<dtype>(); \
+  kernels::requantize<ctype, dtype>( \
+      out_data, \
+      input_data, \
+      static_cast<float>(in_scale), \
+      static_cast<int32_t>(in_zero_point), \
+      1.0 / static_cast<float>(out_scale), \
+      static_cast<int32_t>(out_zero_point), \
+      numel);
+
+#define typed_requantize_in(ctype) \
+  switch (out_dtype) { \
+    case ScalarType::Byte: { \
+      typed_requantize(ctype, uint8_t); \
+      break; \
+    } \
+    case ScalarType::Char: { \
+      typed_requantize(ctype, int8_t); \
+      break; \
+    } \
+    case ScalarType::UInt16: { \
+      typed_requantize(ctype, uint16_t); \
+      break; \
+    } \
+    case ScalarType::Short: { \
+      typed_requantize(ctype, int16_t); \
+      break; \
+    } \
+    default: \
+      ET_KERNEL_CHECK_MSG( \
+          ctx, \
+          false, \
+          InvalidArgument, \
+          out, \
+          "Unhandled output dtype %s", \
+          torch::executor::toString(out_dtype)); \
+  }
+
+  switch (in_dtype) {
+    case ScalarType::Byte: {
+      typed_requantize_in(uint8_t);
+      break;
+    }
+    case ScalarType::Char: {
+      typed_requantize_in(int8_t);
+      break;
+    }
+    case ScalarType::UInt16: {
+      typed_requantize_in(uint16_t);
+      break;
+    }
+    case ScalarType::Short: {
+      typed_requantize_in(int16_t);
+      break;
+    }
+    default:
+      ET_KERNEL_CHECK_MSG(
+          ctx,
+          false,
+          InvalidArgument,
+          out,
+          "Unhandled input dtype %s",
+          torch::executor::toString(in_dtype));
+  }
+#undef typed_requantize_in
+#undef typed_requantize
+  return out;
+}
+
 }; // namespace native
 }; // namespace reference
 }; // namespace impl
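The per-tensor variant performs the same requantization as the existing tensor-overload kernel, just with the scales and zero points passed as plain scalars; note that it hands `1.0 / out_scale` to `kernels::requantize`, so the kernel multiplies by the inverse output scale. `kernels::requantize` itself is not part of this diff; the sketch below is a plausible reference for the standard per-tensor requantization math it corresponds to, not the library's actual implementation:

```python
import torch


def requantize_per_tensor_ref(
    x_q: torch.Tensor,   # already-quantized input (int8/uint8/int16/uint16)
    in_scale: float,
    in_zero_point: int,
    out_scale: float,
    out_zero_point: int,
    out_dtype: torch.dtype,
) -> torch.Tensor:
    # Dequantize with the input parameters, then re-quantize with the output ones.
    real = (x_q.to(torch.float32) - in_zero_point) * in_scale
    q = torch.round(real * (1.0 / out_scale)) + out_zero_point
    info = torch.iinfo(out_dtype)
    return q.clamp(info.min, info.max).to(out_dtype)


x = torch.tensor([0, 64, 127], dtype=torch.int8)
y = requantize_per_tensor_ref(x, 0.1, 0, 0.05, 10, torch.uint8)
assert y.tolist() == [10, 138, 255]  # 127 maps to 264 and is clamped to 255
```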
