
Commit b51e730

Geonhwa Jeong authored and facebook-github-bot committed
Optimize kv cache usage for yoco (#4030)
Summary:
X-link: facebookresearch/FBGEMM#1114

This diff removes redundant temporary tensors and KV caches when using yoco. The main changes make XK/XV optional inputs to the rope/nope kernels and add a flag that controls whether the KV cache is updated.

Reviewed By: Aya-ZIbra

Differential Revision: D73570737
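For context on the change: in a YOCO ("You Only Cache Once") architecture, cross-decoder layers reuse the KV cache produced by the self-decoder, so their attention ops only need to position-encode the queries; materializing XK/XV and rewriting the cache is pure overhead. The sketch below shows how a caller might use the updated ops. It is a minimal illustration: the tensor shapes, dtypes, and the [B, MAX_T, N_KVH, D_H] cache layout are assumptions for the example, not taken from this diff.

```python
import torch
import fbgemm_gpu  # noqa: F401  -- assumed installed; registers the fbgemm ops

# Illustrative sizes: batch 2, one decode step, 8 query heads, head dim 128.
B, MAX_T, N_H, N_KVH, D_H = 2, 4096, 8, 1, 128
xq = torch.randn(B, 1, N_H, D_H, device="cuda", dtype=torch.bfloat16)
cache_k = torch.zeros(B, MAX_T, N_KVH, D_H, device="cuda", dtype=torch.bfloat16)
cache_v = torch.zeros_like(cache_k)
seqpos = torch.tensor([17, 42], device="cuda", dtype=torch.int32)

# Self-decoder step: XK/XV are real tensors and the cache is written
# (update_kv defaults to True, so existing call sites are unchanged).
xk = torch.randn(B, 1, N_KVH, D_H, device="cuda", dtype=torch.bfloat16)
xv = torch.randn_like(xk)
q_out = torch.ops.fbgemm.rope_qkv_decoding(
    xq, xk, xv, cache_k, cache_v, seqpos, theta=10000.0)

# Cross-decoder step (yoco): K/V already live in the shared cache, so the
# temporary XK/XV tensors are dropped and the cache is left untouched.
q_only = torch.ops.fbgemm.rope_qkv_decoding(
    xq, None, None, cache_k, cache_v, seqpos, theta=10000.0, update_kv=False)
```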
1 parent 0485fcf commit b51e730

File tree

2 files changed (+241, -149 lines)

fbgemm_gpu/experimental/gen_ai/src/kv_cache/kv_cache.cpp

+41, -33
@@ -28,8 +28,8 @@ namespace fbgemm_gpu {
 
 at::Tensor nope_qkv_varseq_prefill(
     at::Tensor XQ,
-    at::Tensor XK,
-    at::Tensor XV,
+    std::optional<at::Tensor> XK,
+    std::optional<at::Tensor> XV,
     at::Tensor cache_K,
     at::Tensor cache_V,
     at::Tensor varseq_batch,
@@ -39,12 +39,13 @@ at::Tensor nope_qkv_varseq_prefill(
     std::optional<at::Tensor> varseq_cache_seqpos,
     std::optional<at::Tensor> qparam_k,
     std::optional<at::Tensor> qparam_v,
-    bool k_norm);
+    bool k_norm,
+    bool update_kv);
 
 at::Tensor nope_qkv_decoding(
     at::Tensor XQ,
-    at::Tensor XK,
-    at::Tensor XV,
+    std::optional<at::Tensor> XK,
+    std::optional<at::Tensor> XV,
     at::Tensor cache_K,
     at::Tensor cache_V,
     at::Tensor seqpos,
@@ -55,12 +56,13 @@ at::Tensor nope_qkv_decoding(
     std::optional<at::Tensor> cache_seqpos,
     std::optional<at::Tensor> qparam_k,
     std::optional<at::Tensor> qparam_v,
-    bool k_norm);
+    bool k_norm,
+    bool update_kv);
 
 at::Tensor rope_qkv_varseq_prefill(
     at::Tensor XQ,
-    at::Tensor XK,
-    at::Tensor XV,
+    std::optional<at::Tensor> XK,
+    std::optional<at::Tensor> XV,
     at::Tensor cache_K,
     at::Tensor cache_V,
     at::Tensor varseq_batch,
@@ -79,12 +81,13 @@ at::Tensor rope_qkv_varseq_prefill(
     std::optional<at::Tensor> qparam_k,
     std::optional<at::Tensor> qparam_v,
     bool write_k_back,
-    bool k_norm);
+    bool k_norm,
+    bool update_kv);
 
 at::Tensor rope_qkv_decoding(
     at::Tensor XQ,
-    at::Tensor XK,
-    at::Tensor XV,
+    std::optional<at::Tensor> XK,
+    std::optional<at::Tensor> XV,
     at::Tensor cache_K,
     at::Tensor cache_V,
     at::Tensor seqpos,
@@ -103,7 +106,8 @@ at::Tensor rope_qkv_decoding(
     double hi_freq_factor,
     std::optional<at::Tensor> qparam_k,
     std::optional<at::Tensor> qparam_v,
-    bool k_norm);
+    bool k_norm,
+    bool update_kv);
 
 at::Tensor xpos_qkv_varseq_prefill(
     at::Tensor XQ,
@@ -181,15 +185,15 @@ at::Tensor mqa_attn(
     int64_t cache_logical_dtype_int);
 
 TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
-  m.def("rope_qkv_varseq_prefill(Tensor XQ, Tensor(a!) XK, Tensor XV, Tensor(b!) cache_K, Tensor(c!) cache_V, Tensor varseq_batch, Tensor varseq_seqpos, float theta, int? num_groups=1, Tensor? block_tables=None, int page_size=" STRING(
-      DEFAULT_PAGE_SIZE) ", Tensor? varseq_cache_seqpos=None, int cache_logical_dtype_int=0, bool rope_scaling=False, int old_context_len=8192"
-      ", float scaling_factor=16, float lo_freq_factor=1, float hi_freq_factor=32, Tensor? qparam_k=None, Tensor? qparam_v=None, bool write_k_back=False, bool k_norm=False) -> Tensor");
-  m.def("rope_qkv_decoding(Tensor XQ, Tensor XK, Tensor XV, Tensor(a!) cache_K, Tensor(b!) cache_V, Tensor seqpos, float theta, int? num_groups=1, Tensor? block_tables=None, int page_size=" STRING(
-      DEFAULT_PAGE_SIZE) ", Tensor? actual_batch_size=None, Tensor? batch=None, Tensor? cache_seqpos=None, int cache_logical_dtype_int=0, bool rope_scaling=False, int old_context_len=8192, float scaling_factor=16, float lo_freq_factor=1, float hi_freq_factor=32, Tensor? qparam_k=None, Tensor? qparam_v=None, bool k_norm=False) -> Tensor");
-  m.def("nope_qkv_varseq_prefill(Tensor XQ, Tensor XK, Tensor XV, Tensor(a!) cache_K, Tensor(b!) cache_V, Tensor varseq_batch, Tensor varseq_seqpos, Tensor? block_tables=None, int page_size=" STRING(
-      DEFAULT_PAGE_SIZE) ", Tensor? varseq_cache_seqpos=None, Tensor? qparam_k=None, Tensor? qparam_v=None, bool k_norm=False) -> Tensor");
-  m.def("nope_qkv_decoding(Tensor XQ, Tensor XK, Tensor XV, Tensor(a!) cache_K, Tensor(b!) cache_V, Tensor seqpos, Tensor? block_tables=None, int page_size=" STRING(
-      DEFAULT_PAGE_SIZE) ", Tensor? actual_batch_size=None, Tensor? batch=None, Tensor? cache_seqpos=None, Tensor? qparam_k=None, Tensor? qparam_v=None, bool k_norm=False) -> Tensor");
+  m.def("rope_qkv_varseq_prefill(Tensor XQ, Tensor(a!)? XK, Tensor? XV, Tensor(b!) cache_K, Tensor(c!) cache_V, Tensor varseq_batch, Tensor varseq_seqpos, float theta, int? num_groups=1, Tensor? block_tables=None, int page_size=" STRING(
+      DEFAULT_PAGE_SIZE) ", Tensor? varseq_cache_seqpos=None, int cache_logical_dtype_int=0, bool rope_scaling=False, int old_context_len=8192"
+      ", float scaling_factor=16, float lo_freq_factor=1, float hi_freq_factor=32, Tensor? qparam_k=None, Tensor? qparam_v=None, bool write_k_back=False, bool k_norm=False, bool update_kv=True) -> Tensor");
+  m.def("rope_qkv_decoding(Tensor XQ, Tensor? XK, Tensor? XV, Tensor(a!) cache_K, Tensor(b!) cache_V, Tensor seqpos, float theta, int? num_groups=1, Tensor? block_tables=None, int page_size=" STRING(
+      DEFAULT_PAGE_SIZE) ", Tensor? actual_batch_size=None, Tensor? batch=None, Tensor? cache_seqpos=None, int cache_logical_dtype_int=0, bool rope_scaling=False, int old_context_len=8192, float scaling_factor=16, float lo_freq_factor=1, float hi_freq_factor=32, Tensor? qparam_k=None, Tensor? qparam_v=None, bool k_norm=False, bool update_kv=True) -> Tensor");
+  m.def("nope_qkv_varseq_prefill(Tensor XQ, Tensor? XK, Tensor? XV, Tensor(a!) cache_K, Tensor(b!) cache_V, Tensor varseq_batch, Tensor varseq_seqpos, Tensor? block_tables=None, int page_size=" STRING(
+      DEFAULT_PAGE_SIZE) ", Tensor? varseq_cache_seqpos=None, Tensor? qparam_k=None, Tensor? qparam_v=None, bool k_norm=False, bool update_kv=True) -> Tensor");
+  m.def("nope_qkv_decoding(Tensor XQ, Tensor? XK, Tensor? XV, Tensor(a!) cache_K, Tensor(b!) cache_V, Tensor seqpos, Tensor? block_tables=None, int page_size=" STRING(
+      DEFAULT_PAGE_SIZE) ", Tensor? actual_batch_size=None, Tensor? batch=None, Tensor? cache_seqpos=None, Tensor? qparam_k=None, Tensor? qparam_v=None, bool k_norm=False, bool update_kv=True) -> Tensor");
   m.def("xpos_qkv_varseq_prefill(Tensor XQ, Tensor XK, Tensor XV, Tensor(a!) cache_K, Tensor(b!) cache_V, Tensor varseq_batch, Tensor varseq_seqpos, float theta, float gamma, float scale_base, float exponent_offset, int? num_groups=1, Tensor? block_tables=None, int page_size=" STRING(
       DEFAULT_PAGE_SIZE) ", Tensor? varseq_cache_seqpos=None, int cache_logical_dtype_int=0, bool rope_scaling=False, int old_context_len=8192, float scaling_factor=16, float lo_freq_factor=1, float hi_freq_factor=32, Tensor? qparam_k=None, Tensor? qparam_v=None) -> Tensor");
   m.def("xpos_qkv_decoding(Tensor XQ, Tensor XK, Tensor XV, Tensor(a!) cache_K, Tensor(b!) cache_V, Tensor seqpos, float theta, float gamma, float scale_base, float exponent_offset, int? num_groups=1, Tensor? block_tables=None, int page_size=" STRING(
@@ -225,8 +229,8 @@ TORCH_LIBRARY_IMPL(fbgemm, CUDA, m) {
 
 at::Tensor rope_qkv_varseq_prefill_meta(
     at::Tensor XQ,
-    at::Tensor /* XK */,
-    at::Tensor /* XV */,
+    std::optional<at::Tensor> /* XK */,
+    std::optional<at::Tensor> /* XV */,
     at::Tensor /* cache_K */,
     at::Tensor /* cache_V */,
     at::Tensor /* varseq_batch */,
@@ -245,15 +249,16 @@ at::Tensor rope_qkv_varseq_prefill_meta(
     std::optional<at::Tensor> /* qparam_k */,
     std::optional<at::Tensor> /* qparam_v */,
     bool /* write_k_back */,
-    bool /* k_norm */
+    bool /* k_norm */,
+    bool /* update_kv */
 ) {
   return at::empty_like(XQ);
 }
 
 at::Tensor rope_qkv_decoding_meta(
     at::Tensor XQ,
-    at::Tensor /* XK */,
-    at::Tensor /* XV */,
+    std::optional<at::Tensor> /* XK */,
+    std::optional<at::Tensor> /* XV */,
     at::Tensor /* cache_K */,
     at::Tensor /* cache_V */,
     at::Tensor /* seqpos */,
@@ -272,15 +277,16 @@ at::Tensor rope_qkv_decoding_meta(
     double /* hi_freq_factor */,
     std::optional<at::Tensor> /* qparam_k */,
     std::optional<at::Tensor> /* qparam_v */,
-    bool /* k_norm */
+    bool /* k_norm */,
+    bool /* update_kv */
 ) {
   return at::empty_like(XQ);
 }
 
 at::Tensor nope_qkv_varseq_prefill_meta(
     at::Tensor XQ,
-    at::Tensor /* XK */,
-    at::Tensor /* XV */,
+    std::optional<at::Tensor> /* XK */,
+    std::optional<at::Tensor> /* XV */,
     at::Tensor /* cache_K */,
     at::Tensor /* cache_V */,
     at::Tensor /* varseq_batch */,
@@ -290,15 +296,16 @@ at::Tensor nope_qkv_varseq_prefill_meta(
     std::optional<at::Tensor> /* varseq_cache_seqpos */,
     std::optional<at::Tensor> /* qparam_k */,
     std::optional<at::Tensor> /* qparam_v */,
-    bool /* k_norm */
+    bool /* k_norm */,
+    bool /* update_kv */
 ) {
   return at::empty_like(XQ);
 }
 
 at::Tensor nope_qkv_decoding_meta(
     at::Tensor XQ,
-    at::Tensor /* XK */,
-    at::Tensor /* XV */,
+    std::optional<at::Tensor> /* XK */,
+    std::optional<at::Tensor> /* XV */,
     at::Tensor /* cache_K */,
     at::Tensor /* cache_V */,
     at::Tensor /* seqpos */,
@@ -309,7 +316,8 @@ at::Tensor nope_qkv_decoding_meta(
     std::optional<at::Tensor> /* cache_seqpos */,
     std::optional<at::Tensor> /* qparam_k */,
     std::optional<at::Tensor> /* qparam_v */,
-    bool /* k_norm */
+    bool /* k_norm */,
+    bool /* update_kv */
 ) {
   return at::empty_like(XQ);
 }
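The meta overloads only describe the output, returning `at::empty_like(XQ)`, so fake-tensor tracing and `torch.compile` can infer shapes without launching the CUDA kernel; accordingly they just grow the same trailing `update_kv` parameter. A small sketch of that shape-only behavior, assuming these meta kernels are registered for the ops (shapes are illustrative):

```python
import torch
import fbgemm_gpu  # noqa: F401  -- assumed installed

# On the meta device no data exists; the op only propagates XQ's shape and
# dtype, mirroring `return at::empty_like(XQ);` in the meta functions above.
xq = torch.empty(2, 1, 8, 128, device="meta", dtype=torch.bfloat16)
cache_k = torch.empty(2, 4096, 1, 128, device="meta", dtype=torch.bfloat16)
cache_v = torch.empty_like(cache_k)
seqpos = torch.empty(2, device="meta", dtype=torch.int32)

out = torch.ops.fbgemm.nope_qkv_decoding(
    xq, None, None, cache_k, cache_v, seqpos, update_kv=False)
assert out.shape == xq.shape and out.device.type == "meta"
```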
