@@ -10,11 +10,14 @@
 from typing import Optional, Tuple
 
 import torch
+from executorch.backends.cadence.aot.utils import (
+    get_conv1d_output_size,
+    get_conv2d_output_size,
+    get_im2row_output_size,
+)
 from executorch.exir.scalar_type import ScalarType
 from torch.library import Library, register_fake
 
-from .utils import get_conv1d_output_size, get_conv2d_output_size
-
 lib = Library("cadence", "DEF")
 
 lib.define(
@@ -131,6 +134,10 @@
     "im2row(Tensor input, int[2] kernel_size, int[2] dilation, int[2] padding, int[2] stride, "
     "Tensor in_zero_point, bool channel_last=False) -> (Tensor out)"
 )
+lib.define(
+    "im2row.per_tensor(Tensor input, int[2] kernel_size, int[2] dilation, int[2] padding, int[2] stride, "
+    "int in_zero_point, bool channel_last=False) -> (Tensor out)"
+)
 lib.define("linalg_vector_norm(Tensor X) -> (Tensor Y)")
 lib.define(
     "transposed_im2row(Tensor input, int[2] kernel_size, int[2] dilation, int[2] padding, int[2] stride, "
@@ -140,6 +147,10 @@
     "requantize(Tensor input, Tensor in_scale, Tensor in_zero_point, Tensor out_scale, "
     "Tensor out_zero_point, ScalarType out_dtype) -> (Tensor Y)"
 )
+lib.define(
+    "requantize.per_tensor(Tensor input, float in_scale, int in_zero_point, float out_scale, "
+    "int out_zero_point, ScalarType out_dtype) -> (Tensor Y)"
+)
 lib.define(
     "fully_connected(Tensor input, Tensor weight, Tensor? bias=None) -> (Tensor out)"
 )
@@ -223,6 +234,10 @@
     "im2row.out(Tensor input, int[2] kernel_size, int[2] dilation, int[2] padding, int[2] stride, "
     "Tensor in_zero_point, bool channel_last=False, *, Tensor(a!) out) -> Tensor(a!)"
 )
+lib.define(
+    "im2row.per_tensor_out(Tensor input, int[2] kernel_size, int[2] dilation, int[2] padding, int[2] stride, "
+    "int in_zero_point, bool channel_last=False, *, Tensor(a!) out) -> Tensor(a!)"
+)
 lib.define(
     "transposed_im2row.out(Tensor input, int[2] kernel_size, int[2] dilation, int[2] padding, "
     "int[2] stride, int[2] output_padding, Tensor in_zero_point, bool channel_last=False, *, Tensor(a!) out) -> Tensor(a!)"
@@ -231,7 +246,10 @@
     "requantize.out(Tensor input, Tensor in_scale, Tensor in_zero_point, Tensor out_scale, "
     "Tensor out_zero_point, ScalarType out_dtype, *, Tensor(a!) out) -> Tensor(a!)"
 )
-
+lib.define(
+    "requantize.per_tensor_out(Tensor input, float in_scale, int in_zero_point, float out_scale, "
+    "int out_zero_point, ScalarType out_dtype, *, Tensor(a!) out) -> Tensor(a!)"
+)
 
 # Custom ops with aten namespace. Need to specify the lib var as FRAGMENT type as aten library is already defined
 aten_lib = Library("aten", "FRAGMENT")
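The hunks above add `.per_tensor` overloads whose quantization parameters are plain `float`/`int` scalars rather than tensors, the natural form when those parameters are fixed at export time. As a quick orientation, here is a sketch (assuming this module has been imported so the `cadence` namespace is registered with the dispatcher) showing that each `lib.define()` call surfaces an overload under `torch.ops` whose schema can be inspected directly:

```python
import torch

# Assumes the module containing the lib.define() calls above has been
# imported, which registers the "cadence" namespace.
requantize = torch.ops.cadence.requantize

# The original overload carries scale/zero-point as tensors ...
print(requantize.default._schema)     # ... Tensor in_scale, Tensor in_zero_point ...
# ... while the new per-tensor overload takes plain Python scalars.
print(requantize.per_tensor._schema)  # ... float in_scale, int in_zero_point ...
```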
@@ -562,22 +580,25 @@ def im2row_meta(
     in_zero_point: torch.Tensor,
     channel_last: bool = False,
 ) -> torch.Tensor:
-    if len(input.shape) == 3:
-        height_dim = 1 if channel_last else 2
-        input = input.unsqueeze(height_dim)
+    output_size = get_im2row_output_size(
+        input, kernel_size, dilation, padding, stride, channel_last
+    )
+    return input.new_empty(output_size, dtype=input.dtype)
 
-    batch_size = input.shape[0]
-    n_input_plane = input.shape[3] if channel_last else input.shape[1]
-    input_height = input.shape[1] if channel_last else input.shape[2]
-    input_width = input.shape[2] if channel_last else input.shape[3]
-    output_height = (
-        input_height + 2 * padding[0] - (dilation[0] * (kernel_size[0] - 1) + 1)
-    ) // stride[0] + 1
-    output_width = (
-        input_width + 2 * padding[1] - (dilation[1] * (kernel_size[1] - 1) + 1)
-    ) // stride[1] + 1
-    n_output_plane = n_input_plane * kernel_size[0] * kernel_size[1]
-    output_size = torch.Size((batch_size, output_height * output_width, n_output_plane))
+
+@register_fake("cadence::im2row.per_tensor")
+def im2row_per_tensor_meta(
+    input: torch.Tensor,
+    kernel_size: Tuple[int],
+    dilation: Tuple[int],
+    padding: Tuple[int],
+    stride: Tuple[int],
+    in_zero_point: int,
+    channel_last: bool = False,
+) -> torch.Tensor:
+    output_size = get_im2row_output_size(
+        input, kernel_size, dilation, padding, stride, channel_last
+    )
     return input.new_empty(output_size, dtype=input.dtype)
 
 
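This hunk replaces `im2row_meta`'s inline shape arithmetic with the shared helper `get_im2row_output_size`, so both the tensor and per-tensor fakes compute the output size the same way. For reference, the computation the removed lines performed (and which the helper now centralizes) is reproduced below as a standalone sketch; the function name here is hypothetical, and only the arithmetic is taken from the diff:

```python
from typing import Tuple

import torch


def im2row_output_size_sketch(
    input: torch.Tensor,
    kernel_size: Tuple[int, int],
    dilation: Tuple[int, int],
    padding: Tuple[int, int],
    stride: Tuple[int, int],
    channel_last: bool = False,
) -> torch.Size:
    # Promote 3-D (N, C, W) / (N, W, C) input to 4-D with a height of one.
    if len(input.shape) == 3:
        height_dim = 1 if channel_last else 2
        input = input.unsqueeze(height_dim)
    batch_size = input.shape[0]
    n_input_plane = input.shape[3] if channel_last else input.shape[1]
    input_height = input.shape[1] if channel_last else input.shape[2]
    input_width = input.shape[2] if channel_last else input.shape[3]
    # Standard convolution arithmetic: one output row per sliding-window
    # position, one output column per (channel, kernel element) pair.
    output_height = (
        input_height + 2 * padding[0] - (dilation[0] * (kernel_size[0] - 1) + 1)
    ) // stride[0] + 1
    output_width = (
        input_width + 2 * padding[1] - (dilation[1] * (kernel_size[1] - 1) + 1)
    ) // stride[1] + 1
    n_output_plane = n_input_plane * kernel_size[0] * kernel_size[1]
    return torch.Size((batch_size, output_height * output_width, n_output_plane))


# Example: an 8x8 single-channel image with a 3x3 kernel, unit stride and
# dilation, and no padding yields 36 window positions of 9 values each.
size = im2row_output_size_sketch(
    torch.zeros(1, 1, 8, 8), (3, 3), (1, 1), (0, 0), (1, 1)
)
assert size == torch.Size((1, 36, 9))
```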
@@ -606,6 +627,22 @@ def requantize_meta(
     )
 
 
+@register_fake("cadence::requantize.per_tensor")
+def requantize_per_tensor_meta(
+    input: torch.Tensor,
+    in_scale: float,
+    in_zero_point: int,
+    out_scale: float,
+    out_zero_point: int,
+    dtype: ScalarType,
+) -> torch.Tensor:
+    return input.new_empty(
+        input.size(),
+        # pyre-ignore[6]: Incompatible type
+        dtype=dtype,
+    )
+
+
 @register_fake("cadence::quantized_relu.per_tensor")
 def quantized_relu_per_tensor_meta(
     input: torch.Tensor,