
Commit 8d8f677

In quantized sdpa dequant v
Differential Revision: D71833063 Pull Request resolved: #10097
1 parent 91a14f1 commit 8d8f677

File tree

1 file changed: +108 -14 lines changed


extension/llm/custom_ops/op_sdpa_impl.h

+108 -14
@@ -120,6 +120,88 @@ void _q_at_k_gemm(
   }
 }
 
+// Refactor op_dequantize.cpp to avoid code duplication
+void dequantize_optimized(
+    const int8_t* in,
+    const float scale,
+    const int8_t zero_point,
+    float* out,
+    int64_t quant_min,
+    int64_t quant_max,
+    size_t numel) {
+  size_t i = 0;
+#if defined(__aarch64__) || defined(__ARM_NEON)
+  int8x8_t zero_point_vec = vdup_n_s8(zero_point);
+  float32x4_t scales = vdupq_n_f32(static_cast<float>(scale));
+  constexpr int32_t kVecSize = 16;
+  const size_t num_vecs = numel / kVecSize;
+  const int8_t* in_copy = in;
+  float* out_copy = out;
+  for (; i < num_vecs; i++) {
+    int8x16_t in_vec = vld1q_s8(in_copy);
+    int16x8_t sub_vec_0_7 = vsubl_s8(vget_low_s8(in_vec), zero_point_vec);
+    int32x4_t sub_vec_0_3 = vmovl_s16(vget_low_s16(sub_vec_0_7));
+    int32x4_t sub_vec_4_7 = vmovl_s16(vget_high_s16(sub_vec_0_7));
+    float32x4_t out_vec_0_3 = vmulq_f32(vcvtq_f32_s32(sub_vec_0_3), scales);
+    float32x4_t out_vec_4_7 = vmulq_f32(vcvtq_f32_s32(sub_vec_4_7), scales);
+
+    int16x8_t sub_vec_8_15 = vsubl_s8(vget_high_s8(in_vec), zero_point_vec);
+    int32x4_t sub_vec_8_11 = vmovl_s16(vget_low_s16(sub_vec_8_15));
+    int32x4_t sub_vec_12_15 = vmovl_s16(vget_high_s16(sub_vec_8_15));
+    float32x4_t out_vec_8_11 = vmulq_f32(vcvtq_f32_s32(sub_vec_8_11), scales);
+    float32x4_t out_vec_12_15 = vmulq_f32(vcvtq_f32_s32(sub_vec_12_15), scales);
+    vst1q_f32(out_copy + 0, out_vec_0_3);
+    vst1q_f32(out_copy + 4, out_vec_4_7);
+    vst1q_f32(out_copy + 8, out_vec_8_11);
+    vst1q_f32(out_copy + 12, out_vec_12_15);
+    in_copy += kVecSize;
+    out_copy += kVecSize;
+  }
+  i = i * kVecSize;
+#endif
+  for (; i < numel; i++) {
+    out[i] = (static_cast<int16_t>(in[i]) - static_cast<int16_t>(zero_point)) *
+        scale;
+  }
+}
+
+void dequantize_per_channel_optimized(
+    const int8_t* in_data,
+    const float* scales_data,
+    const int8_t* zero_points_data,
+    float* out_data,
+    int64_t quant_min,
+    int64_t quant_max,
+    size_t outer_size,
+    size_t in_outer_stride,
+    size_t out_outer_stride,
+    size_t num_channels,
+    size_t in_channel_stride,
+    size_t out_channel_stride,
+    size_t channel_size,
+    size_t qparams_stride) {
+  for (size_t outer_idx = 0; outer_idx < outer_size; ++outer_idx) {
+    // Loop through dim
+    for (size_t channel_idx = 0; channel_idx < num_channels; ++channel_idx) {
+      const int8_t* in_data_local = in_data + outer_idx * in_outer_stride +
+          channel_idx * in_channel_stride;
+      const float scale = *(scales_data + channel_idx * qparams_stride);
+      const int8_t zero_point =
+          *(zero_points_data + channel_idx * qparams_stride);
+      float* out_data_local = out_data + outer_idx * out_outer_stride +
+          channel_idx * out_channel_stride;
+      dequantize_optimized(
+          in_data_local,
+          scale,
+          zero_point,
+          out_data_local,
+          quant_min,
+          quant_max,
+          channel_size);
+    }
+  }
+}
+
 template <typename accum_t>
 void _qk_at_v_gemm(
     const int64_t m,
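
For reference, the NEON fast path and the scalar tail added above compute the same affine dequantization, out[i] = (in[i] - zero_point) * scale. The sketch below is a minimal scalar illustration of that math; dequantize_reference and the sample values are assumptions for this example, not part of the patch, and (as in the patch) quant_min/quant_max are accepted but not used for clamping.

#include <cstddef>
#include <cstdint>
#include <cstdio>
#include <vector>

// Scalar reference: out[i] = (in[i] - zero_point) * scale, the same math the
// 16-lane NEON loop above performs per int8x16_t block.
void dequantize_reference(
    const int8_t* in,
    float scale,
    int8_t zero_point,
    float* out,
    size_t numel) {
  for (size_t i = 0; i < numel; ++i) {
    out[i] = (static_cast<int16_t>(in[i]) - static_cast<int16_t>(zero_point)) *
        scale;
  }
}

int main() {
  // Assumed sample values, for illustration only.
  const std::vector<int8_t> in = {-128, -1, 0, 1, 127};
  std::vector<float> out(in.size());
  dequantize_reference(
      in.data(), /*scale=*/0.05f, /*zero_point=*/2, out.data(), in.size());
  for (float v : out) {
    std::printf("%f\n", v); // e.g. (-128 - 2) * 0.05 = -6.5
  }
  return 0;
}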
@@ -134,24 +216,36 @@ void _qk_at_v_gemm(
     const accum_t beta) {
   if (v_data.dtype == ScalarType::Char) {
     if constexpr (std::is_same<accum_t, float>::value) {
-      int a_stride_m_tmp, b_stride_n_tmp;
-      auto kernel = torchao::kernels::cpu::quantized_matmul::
-          get_fp32_a_input_channelwise_8bit_b_f32_c_matmul(
-              m, n, k, false, false, a_stride_m_tmp, b_stride_n_tmp);
-      kernel(
-          m,
+      std::vector<float> dequantized_v_data(v_data.m * v_data.n);
+      dequantize_per_channel_optimized(
+          static_cast<const int8_t*>(v_data.data),
+          static_cast<const float*>(v_data.scales),
+          static_cast<const int8_t*>(v_data.zero_points),
+          dequantized_v_data.data(),
+          -128,
+          127,
+          1,
+          0,
+          0,
+          v_data.m,
+          v_stride_n,
+          v_data.n,
+          v_data.n,
+          v_data.zero_points_stride);
+      ::executorch::cpublas::gemm(
+          ::executorch::cpublas::TransposeType::NoTranspose,
+          ::executorch::cpublas::TransposeType::NoTranspose,
           n,
+          m,
           k,
+          static_cast<accum_t>(1),
+          dequantized_v_data.data(),
+          v_data.n,
           qk_data,
-          qk_stride_m /*lhs_stride_m*/,
-          static_cast<const int8_t*>(v_data.data),
-          v_stride_n /*rhs_stride_n*/,
-          o_data,
-          o_stride_m /*out_stride_n*/,
-          static_cast<const int8_t*>(v_data.zero_points),
-          static_cast<const float*>(v_data.scales),
+          qk_stride_m,
           beta,
-          v_data.zero_points_stride);
+          o_data,
+          o_stride_m);
     } else {
       ET_CHECK_MSG(
           false, "Accumulation in dtype other than float not supported yet");
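
To see what the rewritten branch computes end to end, the sketch below spells out the equivalent plain loops: V is dequantized with one (scale, zero_point) pair per row (matching num_channels == v_data.m and channel_size == v_data.n in the call above), then out = qk * V_dequant with the existing beta accumulation. The names, the row-major layout, unit qparams stride, and the reading that ::executorch::cpublas::gemm follows the usual BLAS column-major convention (hence n, m, k with the dequantized V passed first) are illustrative assumptions, not statements from the patch.

#include <cstddef>
#include <cstdint>
#include <vector>

// Plain-loop reference for the new _qk_at_v_gemm branch: per-row (channel)
// dequantization of V followed by out = qk * V_dequant + beta * out.
// All buffers are assumed contiguous row-major; this mirrors, rather than
// replaces, the dequantize_per_channel_optimized + cpublas::gemm calls above.
void qk_at_v_reference(
    size_t m, // rows of qk and out
    size_t k, // cols of qk == rows of V (the quantization channels)
    size_t n, // cols of V and out
    const float* qk, // m x k
    const int8_t* v, // k x n, int8 with per-row scale/zero_point
    const float* v_scales, // k entries
    const int8_t* v_zero_points, // k entries
    float beta,
    float* out) { // m x n
  // Per-channel dequantization: each row of V has its own scale/zero_point.
  std::vector<float> v_deq(k * n);
  for (size_t row = 0; row < k; ++row) {
    for (size_t col = 0; col < n; ++col) {
      v_deq[row * n + col] =
          (static_cast<int16_t>(v[row * n + col]) -
           static_cast<int16_t>(v_zero_points[row])) *
          v_scales[row];
    }
  }
  // out = qk * v_deq + beta * out (alpha is 1, as in the gemm call above).
  for (size_t i = 0; i < m; ++i) {
    for (size_t j = 0; j < n; ++j) {
      float acc = 0.f;
      for (size_t p = 0; p < k; ++p) {
        acc += qk[i * k + p] * v_deq[p * n + j];
      }
      out[i * n + j] = beta * out[i * n + j] + acc;
    }
  }
}

Relative to the removed torchao fused kernel, the new path materializes a temporary float copy of V (the std::vector above) before handing the matmul to the generic cpublas gemm.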
