@@ -120,6 +120,88 @@ void _q_at_k_gemm(
 }
 }
 
+// Refactor op_dequantize.cpp to avoid code duplication
+void dequantize_optimized(
+    const int8_t* in,
+    const float scale,
+    const int8_t zero_point,
+    float* out,
+    int64_t quant_min,
+    int64_t quant_max,
+    size_t numel) {
+  size_t i = 0;
+#if defined(__aarch64__) || defined(__ARM_NEON)
+  int8x8_t zero_point_vec = vdup_n_s8(zero_point);
+  float32x4_t scales = vdupq_n_f32(static_cast<float>(scale));
+  constexpr int32_t kVecSize = 16;
+  const size_t num_vecs = numel / kVecSize;
+  const int8_t* in_copy = in;
+  float* out_copy = out;
+  for (; i < num_vecs; i++) {
+    int8x16_t in_vec = vld1q_s8(in_copy);
+    int16x8_t sub_vec_0_7 = vsubl_s8(vget_low_s8(in_vec), zero_point_vec);
+    int32x4_t sub_vec_0_3 = vmovl_s16(vget_low_s16(sub_vec_0_7));
+    int32x4_t sub_vec_4_7 = vmovl_s16(vget_high_s16(sub_vec_0_7));
+    float32x4_t out_vec_0_3 = vmulq_f32(vcvtq_f32_s32(sub_vec_0_3), scales);
+    float32x4_t out_vec_4_7 = vmulq_f32(vcvtq_f32_s32(sub_vec_4_7), scales);
+
+    int16x8_t sub_vec_8_15 = vsubl_s8(vget_high_s8(in_vec), zero_point_vec);
+    int32x4_t sub_vec_8_11 = vmovl_s16(vget_low_s16(sub_vec_8_15));
+    int32x4_t sub_vec_12_15 = vmovl_s16(vget_high_s16(sub_vec_8_15));
+    float32x4_t out_vec_8_11 = vmulq_f32(vcvtq_f32_s32(sub_vec_8_11), scales);
+    float32x4_t out_vec_12_15 = vmulq_f32(vcvtq_f32_s32(sub_vec_12_15), scales);
+    vst1q_f32(out_copy + 0, out_vec_0_3);
+    vst1q_f32(out_copy + 4, out_vec_4_7);
+    vst1q_f32(out_copy + 8, out_vec_8_11);
+    vst1q_f32(out_copy + 12, out_vec_12_15);
+    in_copy += kVecSize;
+    out_copy += kVecSize;
+  }
+  i = i * kVecSize;
+#endif
+  for (; i < numel; i++) {
+    out[i] = (static_cast<int16_t>(in[i]) - static_cast<int16_t>(zero_point)) *
+        scale;
+  }
+}
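Note on the added function: the NEON path widens each int8 lane in two steps (vsubl_s8 to int16, vmovl_s16 to int32) before converting to float, and it only covers numel rounded down to a multiple of 16; the trailing scalar loop resumes at i * kVecSize. A minimal host-side check of that hand-off, assuming dequantize_optimized from above is in scope (the harness, sample values, and size are illustrative, not part of the change):

    #include <cmath>
    #include <cstddef>
    #include <cstdint>
    #include <cstdio>

    int main() {
      // 19 = 16 vectorized elements + 3 handled by the scalar tail loop.
      constexpr size_t kNumel = 19;
      int8_t in[kNumel];
      float out[kNumel];
      for (size_t i = 0; i < kNumel; ++i) {
        in[i] = static_cast<int8_t>(static_cast<int>(i) * 7 - 60);
      }
      const float scale = 0.05f;
      const int8_t zero_point = -3;
      dequantize_optimized(in, scale, zero_point, out, -128, 127, kNumel);
      for (size_t i = 0; i < kNumel; ++i) {
        // Reference: plain (q - zp) * scale, matching the scalar loop.
        const float expected =
            (static_cast<int16_t>(in[i]) - static_cast<int16_t>(zero_point)) *
            scale;
        if (std::fabs(out[i] - expected) > 1e-6f) {
          std::printf("mismatch at %zu\n", i);
          return 1;
        }
      }
      std::printf("ok\n");
      return 0;
    }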
+
+void dequantize_per_channel_optimized(
+    const int8_t* in_data,
+    const float* scales_data,
+    const int8_t* zero_points_data,
+    float* out_data,
+    int64_t quant_min,
+    int64_t quant_max,
+    size_t outer_size,
+    size_t in_outer_stride,
+    size_t out_outer_stride,
+    size_t num_channels,
+    size_t in_channel_stride,
+    size_t out_channel_stride,
+    size_t channel_size,
+    size_t qparams_stride) {
+  for (size_t outer_idx = 0; outer_idx < outer_size; ++outer_idx) {
+    // Loop through dim
+    for (size_t channel_idx = 0; channel_idx < num_channels; ++channel_idx) {
+      const int8_t* in_data_local = in_data + outer_idx * in_outer_stride +
+          channel_idx * in_channel_stride;
+      const float scale = *(scales_data + channel_idx * qparams_stride);
+      const int8_t zero_point =
+          *(zero_points_data + channel_idx * qparams_stride);
+      float* out_data_local = out_data + outer_idx * out_outer_stride +
+          channel_idx * out_channel_stride;
+      dequantize_optimized(
+          in_data_local,
+          scale,
+          zero_point,
+          out_data_local,
+          quant_min,
+          quant_max,
+          channel_size);
+    }
+  }
+}
+
 template <typename accum_t>
 void _qk_at_v_gemm(
     const int64_t m,
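A usage sketch may help with the long parameter list of dequantize_per_channel_optimized. The values below are illustrative, not from the PR: a dense row-major [2 x 4] int8 matrix with one (scale, zero_point) pair per row, a single outer block (outer_size = 1, outer strides 0), and densely packed qparams (qparams_stride = 1), which is the same shape of call the second hunk makes for v:

    #include <cstdint>
    #include <cstdio>

    int main() {
      // Two channels (rows) of four elements each, contiguous row-major.
      const int8_t quantized[8] = {-4, -2, 0, 2, 10, 20, 30, 40};
      const float scales[2] = {0.5f, 0.1f};
      const int8_t zero_points[2] = {0, 10};
      float dequantized[8];
      dequantize_per_channel_optimized(
          quantized,
          scales,
          zero_points,
          dequantized,
          /*quant_min=*/-128,
          /*quant_max=*/127,
          /*outer_size=*/1,
          /*in_outer_stride=*/0,
          /*out_outer_stride=*/0,
          /*num_channels=*/2,
          /*in_channel_stride=*/4,  // rows are contiguous
          /*out_channel_stride=*/4,
          /*channel_size=*/4,
          /*qparams_stride=*/1);    // scales/zero_points densely packed
      // Row 0: (q - 0) * 0.5  -> -2 -1 0 1
      // Row 1: (q - 10) * 0.1 ->  0  1 2 3
      for (float val : dequantized) std::printf("%g ", val);
      std::printf("\n");
      return 0;
    }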
@@ -134,24 +216,36 @@ void _qk_at_v_gemm(
     const accum_t beta) {
   if (v_data.dtype == ScalarType::Char) {
     if constexpr (std::is_same<accum_t, float>::value) {
-      int a_stride_m_tmp, b_stride_n_tmp;
-      auto kernel = torchao::kernels::cpu::quantized_matmul::
-          get_fp32_a_input_channelwise_8bit_b_f32_c_matmul(
-              m, n, k, false, false, a_stride_m_tmp, b_stride_n_tmp);
-      kernel(
-          m,
+      std::vector<float> dequantized_v_data(v_data.m * v_data.n);
+      dequantize_per_channel_optimized(
+          static_cast<const int8_t*>(v_data.data),
+          static_cast<const float*>(v_data.scales),
+          static_cast<const int8_t*>(v_data.zero_points),
+          dequantized_v_data.data(),
+          -128,
+          127,
+          1,
+          0,
+          0,
+          v_data.m,
+          v_stride_n,
+          v_data.n,
+          v_data.n,
+          v_data.zero_points_stride);
+      ::executorch::cpublas::gemm(
+          ::executorch::cpublas::TransposeType::NoTranspose,
+          ::executorch::cpublas::TransposeType::NoTranspose,
           n,
+          m,
           k,
+          static_cast<accum_t>(1),
+          dequantized_v_data.data(),
+          v_data.n,
           qk_data,
-          qk_stride_m /*lhs_stride_m*/,
-          static_cast<const int8_t*>(v_data.data),
-          v_stride_n /*rhs_stride_n*/,
-          o_data,
-          o_stride_m /*out_stride_n*/,
-          static_cast<const int8_t*>(v_data.zero_points),
-          static_cast<const float*>(v_data.scales),
+          qk_stride_m,
           beta,
-          v_data.zero_points_stride);
+          o_data,
+          o_stride_m);
     } else {
       ET_CHECK_MSG(
           false, "Accumulation in dtype other than float not supported yet");
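On the gemm call above: assuming ::executorch::cpublas::gemm follows the standard column-major BLAS convention (which the argument order suggests), handing it row-major buffers makes it see each matrix transposed. Passing dequantized_v first and qk_data second with dimensions (n, m, k) therefore computes o^T = v^T * qk^T in column-major terms, which is exactly o = qk * v (m x n, row-major) with leading dimensions v_data.n, qk_stride_m, and o_stride_m, and beta still accumulating into o_data. A self-contained sketch of that identity (naive_colmajor_gemm is an illustrative stand-in, not the library routine):

    #include <cstdio>

    // Reference column-major gemm: C(m x n) = A(m x k) * B(k x n),
    // no transposes, leading dimensions lda/ldb/ldc.
    void naive_colmajor_gemm(int m, int n, int k,
                             const float* a, int lda,
                             const float* b, int ldb,
                             float* c, int ldc) {
      for (int j = 0; j < n; ++j) {
        for (int i = 0; i < m; ++i) {
          float acc = 0.0f;
          for (int p = 0; p < k; ++p) {
            acc += a[i + p * lda] * b[p + j * ldb];
          }
          c[i + j * ldc] = acc;
        }
      }
    }

    int main() {
      // Row-major qk (2 x 3) and v (3 x 2); we want o = qk * v (2 x 2).
      const float qk[6] = {1, 2, 3, 4, 5, 6};
      const float v[6] = {1, 0, 0, 1, 1, 1};
      float o[4];
      // Swapped operands, dims (n, m, k) = (2, 2, 3): column-major this is
      // o^T = v^T * qk^T, which lands in memory as row-major o = qk * v.
      naive_colmajor_gemm(2, 2, 3, v, /*lda=*/2, qk, /*ldb=*/3, o, /*ldc=*/2);
      std::printf("%g %g\n%g %g\n", o[0], o[1], o[2], o[3]);  // 4 5 / 10 11
      return 0;
    }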