@@ -259,6 +259,151 @@ def update3D(
 
         return k_out, v_out
 
+    def _sliding_update(
+        self,
+        layer_idx,
+        key_states,
+        value_states,
+        position_ids,
+        batch_index,
+        k_out,
+        v_out,
+    ):
+        N = self.key_cache[layer_idx].shape[2]
+
+        # Update the position_ids to handle the sliding window
+        kv_position_ids = torch.where(position_ids == -1, position_ids, position_ids % (N - 1))
+        kv_position_ids = torch.where(position_ids.max() >= (N - 1) * 2, (position_ids + 1) % N, kv_position_ids)
+
+        # Update the cache
+        self.key_cache[layer_idx] = CtxScatterFunc.apply(self.key_cache[layer_idx], kv_position_ids, key_states)
+        self.value_cache[layer_idx] = CtxScatterFunc.apply(self.value_cache[layer_idx], kv_position_ids, value_states)
+
+        k_out, v_out = self.key_cache[layer_idx], self.value_cache[layer_idx]
+
+        # Original Gather
+        ctx_len = min(N, k_out.shape[2])
+        ctx_indices = torch.arange(ctx_len)[None, None, ...]
+        gather_limit = kv_position_ids.max(1, keepdim=True).values.unsqueeze(1)
+        invalid_mask = ctx_indices > gather_limit
+        if torch.onnx.is_in_onnx_export():
+            invalid_idx_value = torch.iinfo(torch.int32).max
+        else:
+            invalid_idx_value = 0
+        ctx_indices = torch.where(invalid_mask, invalid_idx_value, ctx_indices)
+
+        # rolling indices
+        all_indices = torch.arange(N) + kv_position_ids.max() + 1
+        rolling_indices = torch.where(all_indices > N - 1, all_indices % N, all_indices)
+
+        final_indices = torch.where(position_ids.max() >= (N - 1), rolling_indices, ctx_indices)
+
+        k_out = CtxGatherFunc.apply(k_out, final_indices)
+        v_out = CtxGatherFunc.apply(v_out, final_indices)
+        prefill_v_out = torch.where(invalid_mask.unsqueeze(-1), torch.tensor(0.0, dtype=torch.float32), v_out)
+
+        # Handle the rolling indices
+        v_out = torch.where(position_ids.max() >= (N - 1), v_out, prefill_v_out)
+        return k_out, v_out
+
+    def _static_update(
+        self,
+        layer_idx,
+        key_states,
+        value_states,
+        position_ids,
+        batch_index,
+        k_out,
+        v_out,
+    ):
+        # Update the cache
+        if len(self.key_cache) <= layer_idx:
+            self.key_cache.append(key_states)
+            self.value_cache.append(value_states)
+            k_out, v_out = key_states, value_states
+        else:
+            # Scatter
+            if batch_index is not None:
+                invalid_scatter_index = torch.iinfo(torch.int32).max
+                scatter_position_ids = torch.where(position_ids < 0, invalid_scatter_index, position_ids)
+
+                self.key_cache[layer_idx] = CtxScatterFuncCB.apply(
+                    self.key_cache[layer_idx], batch_index, scatter_position_ids, key_states
+                )
+
+                self.value_cache[layer_idx] = CtxScatterFuncCB.apply(
+                    self.value_cache[layer_idx], batch_index, scatter_position_ids, value_states
+                )
+            else:
+                self.key_cache[layer_idx] = CtxScatterFunc.apply(self.key_cache[layer_idx], position_ids, key_states)
+                self.value_cache[layer_idx] = CtxScatterFunc.apply(
+                    self.value_cache[layer_idx], position_ids, value_states
+                )
+
+            k_out, v_out = self.key_cache[layer_idx], self.value_cache[layer_idx]
+
+            # Gather
+            ctx_len = k_out.shape[2]
+            ctx_indices = torch.arange(ctx_len)[None, None, ...]
+            gather_limit = position_ids.max(1, keepdim=True).values.unsqueeze(1)
+            invalid_mask = ctx_indices > gather_limit
+
+            if torch.onnx.is_in_onnx_export():
+                invalid_idx_value = torch.iinfo(torch.int32).max
+            else:
+                invalid_idx_value = 0
+
+            ctx_indices = torch.where(invalid_mask, invalid_idx_value, ctx_indices)
+            if batch_index is not None:
+                k_out = CtxGatherFuncCB.apply(k_out, batch_index, ctx_indices)
+                v_out = CtxGatherFuncCB.apply(v_out, batch_index, ctx_indices)
+            else:
+                k_out = CtxGatherFunc.apply(k_out, ctx_indices)
+                v_out = CtxGatherFunc.apply(v_out, ctx_indices)
+            v_out = torch.where(invalid_mask.unsqueeze(-1), torch.tensor(0.0, dtype=torch.float32), v_out)
+
+        return k_out, v_out
+
+    def update_hybrid_chunked(
+        self,
+        key_states: torch.Tensor,
+        value_states: torch.Tensor,
+        layer_idx: int,
+        cache_kwargs: Optional[Dict[str, Any]] = None,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """
+        Updates cache with support for both sliding window and position-based updates.
+        """
+        if cache_kwargs is None:
+            cache_kwargs = {}
+
+        k_out = self.key_cache[layer_idx]
+        v_out = self.value_cache[layer_idx]
+        key_states = key_states.to(k_out.dtype)
+        value_states = value_states.to(v_out.dtype)
+
+        # Get cache parameters
+        position_ids = cache_kwargs.get("position_ids")
+        batch_index = cache_kwargs.get("batch_index", None)
+        sliding_window = cache_kwargs.get("is_sliding", None)
+
+        if sliding_window[layer_idx]:
+            update_fn = self._sliding_update
+        else:
+            update_fn = self._static_update
+
+        k_out, v_out = update_fn(
+            layer_idx,
+            key_states,
+            value_states,
+            position_ids,
+            batch_index,
+            k_out,
+            v_out,
+        )
+
+        return k_out, v_out
+
 
 class QEffEncoderDecoderCache(EncoderDecoderCache):
     """