@@ -107,7 +107,9 @@ def set_block_size(self, block_size: int) -> None:
     def _prepare_prompt(
         self,
         seq_group_metadata_list: List[SequenceGroupMetadata],
-    ) -> tuple[list[list[int]], list[list[int]], InputMetadata, list[int], list[Tensor]]:
+    ) -> tuple[
+        list[list[int]], list[list[int]], InputMetadata, list[int], list[Tensor]
+    ]:
         assert len(seq_group_metadata_list) > 0
         input_tokens: List[List[int]] = []
         input_positions: List[List[int]] = []
@@ -360,17 +362,23 @@ def _prepare_sample(
     def prepare_input_tensors(
         self,
         seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
-    ) -> Tuple[torch.Tensor, torch.Tensor, InputMetadata, SamplingMetadata, list[torch.Tensor]]:
+    ) -> Tuple[
+        torch.Tensor, torch.Tensor, InputMetadata, SamplingMetadata, list[torch.Tensor]
+    ]:
         speaker_embedding = None
         if self.is_driver_worker:
             # NOTE: We assume that all sequences in the group are all prompts or
             # all decodes.
             is_prompt = seq_group_metadata_list[0].is_prompt
             # Prepare input tensors.
             if is_prompt:
-                (input_tokens, input_positions, input_metadata, prompt_lens, speaker_embedding) = (
-                    self._prepare_prompt(seq_group_metadata_list)
-                )
+                (
+                    input_tokens,
+                    input_positions,
+                    input_metadata,
+                    prompt_lens,
+                    speaker_embedding,
+                ) = self._prepare_prompt(seq_group_metadata_list)
             else:
                 (input_tokens, input_positions, input_metadata) = self._prepare_decode(
                     seq_group_metadata_list
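Only a slice of prepare_input_tensors appears in this hunk; the NOTE it carries (all sequences in a scheduled group are either all prompts or all decodes) is the invariant the is_prompt check relies on. A runnable illustration of that invariant, using a hypothetical stand-in for SequenceGroupMetadata:

    # Hedged sketch: SeqMeta is a hypothetical stand-in, not this repo's class.
    from dataclasses import dataclass
    from typing import List

    @dataclass
    class SeqMeta:
        is_prompt: bool

    def batch_is_homogeneous(groups: List[SeqMeta]) -> bool:
        # Every group must be in the same phase as the first one.
        return all(g.is_prompt == groups[0].is_prompt for g in groups)

    assert batch_is_homogeneous([SeqMeta(True), SeqMeta(True)])
    assert not batch_is_homogeneous([SeqMeta(True), SeqMeta(False)])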
@@ -462,7 +470,13 @@ def get_size_or_none(x: Optional[torch.Tensor]):
             perform_sampling=False,
         )

-        return input_tokens, input_positions, input_metadata, sampling_metadata, speaker_embedding
+        return (
+            input_tokens,
+            input_positions,
+            input_metadata,
+            sampling_metadata,
+            speaker_embedding,
+        )

     @torch.inference_mode()
     def execute_model(
@@ -471,9 +485,13 @@ def execute_model(
         kv_caches: List[Tuple[torch.Tensor, torch.Tensor]],
     ) -> Optional[SamplerOutput]:

-        input_tokens, input_positions, input_metadata, sampling_metadata, speaker_embedding = (
-            self.prepare_input_tensors(seq_group_metadata_list)
-        )
+        (
+            input_tokens,
+            input_positions,
+            input_metadata,
+            sampling_metadata,
+            speaker_embedding,
+        ) = self.prepare_input_tensors(seq_group_metadata_list)
         # print(sampling_metadata.seq_data)
         seq_groups = []
         for i, rtn in enumerate(sampling_metadata.seq_groups):
@@ -522,7 +540,9 @@ def execute_model(
                 if speaker_embedding_params is None:
                     speaker_embedding_params = speaker_embedding[i]
                 else:
-                    speaker_embedding_params = torch.cat((speaker_embedding_params, speaker_embedding[i]))
+                    speaker_embedding_params = torch.cat(
+                        (speaker_embedding_params, speaker_embedding[i])
+                    )

         else:
             speaker_embedding_params = self.post_model(input_tokens, text_mask)
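The branch above grows speaker_embedding_params one torch.cat at a time. A small runnable equivalent (the per-sequence shape (1, 768) is assumed, not shown in the diff) that collects the chunks and concatenates once, trading the quadratic copy pattern for a single allocation:

    import torch

    # Assumed shape (1, 768) per sequence; only the accumulation pattern
    # comes from the diff above.
    speaker_embedding = [torch.randn(1, 768) for _ in range(4)]

    # Incremental cat, as in the hunk:
    params = None
    for emb in speaker_embedding:
        params = emb if params is None else torch.cat((params, emb))

    # Single cat over the list: identical result, one allocation.
    assert torch.equal(params, torch.cat(speaker_embedding, dim=0))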
@@ -560,7 +580,7 @@ def execute_model(
         #     sampling_metadata=sampling_metadata,
         # )
         results = []
-        for i,val in enumerate(seq_groups):
+        for i, val in enumerate(seq_groups):
             idx_next_i = idx_next[i, 0, :].tolist()
             logprob_i = logprob[i].tolist()
             tmp_hidden_states = hidden_states[i]
@@ -781,7 +801,9 @@ def _make_tensor_with_pad(
     for x_i in x:
         pad_i = pad
         if isinstance(x[0][0], list):
-            pad_i = [0,] * len(x[0][0])
+            pad_i = [
+                0,
+            ] * len(x[0][0])
         elif isinstance(x[0][0], tuple):
             pad_i = (0,) * len(x[0][0])
         padded_x.append(_pad_to_max(x_i, max_len, pad_i))
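_pad_to_max itself is not part of this diff. A plausible definition consistent with the call site above and with upstream vLLM (an assumption; the real helper may differ), right-padding each row to max_len with the element-shaped pad chosen in the loop:

    from typing import List, TypeVar

    T = TypeVar("T")

    def _pad_to_max(x: List[T], max_len: int, pad: T) -> List[T]:
        # Right-pad with `pad` until the row reaches max_len (assumed semantics).
        assert len(x) <= max_len
        return x + [pad] * (max_len - len(x))

    # Scalar rows pad with a scalar; nested rows pad with a zero-row of
    # matching width, mirroring the isinstance checks in the hunk.
    assert _pad_to_max([1, 2], 4, 0) == [1, 2, 0, 0]
    assert _pad_to_max([[1, 2]], 3, [0, 0]) == [[1, 2], [0, 0], [0, 0]]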
@@ -791,6 +813,7 @@ def _make_tensor_with_pad(
         device=device,
     )

+
 def _make_with_pad(
     x: List[torch.Tensor],
     max_len: int,
@@ -805,11 +828,15 @@ def _make_with_pad(
             padded_x.append(x_i)
         else:
             padded_x.append(
-                torch.cat((torch.zeros(1, max_len - x_i.shape[-2], 768).to(device), x_i), dim=1)
+                torch.cat(
+                    (torch.zeros(1, max_len - x_i.shape[-2], 768).to(device), x_i),
+                    dim=1,
+                )
             )

     return padded_x

+
 def _get_graph_batch_size(batch_size: int) -> int:
     if batch_size <= 2:
         return batch_size
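The final hunk is cut off inside _get_graph_batch_size. In the vLLM code this runner appears to derive from, the remainder rounds the batch size up to the nearest CUDA-graph capture size; a sketch under that assumption (the truncated lines may differ in this repo):

    def _get_graph_batch_size(batch_size: int) -> int:
        # Round up to a captured CUDA-graph batch size: 1, 2, 4, then
        # multiples of 8 (assumed vLLM behavior, not shown in the diff).
        if batch_size <= 2:
            return batch_size
        elif batch_size <= 4:
            return 4
        else:
            return (batch_size + 7) // 8 * 8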