@@ -75,9 +75,9 @@ def get_padded_input_len(input_len: int, prefill_seq_len: int, ctx_len: int):
     """
     num_chunks = -(input_len // -prefill_seq_len)  # ceil divide without float
     input_len_padded = num_chunks * prefill_seq_len  # Convert input_len to a multiple of prefill_seq_len
-    assert input_len_padded <= ctx_len, (
-        "input_len rounded to nearest prefill_seq_len multiple should be less than ctx_len"
-    )
+    assert (
+        input_len_padded <= ctx_len
+    ), "input_len rounded to nearest prefill_seq_len multiple should be less than ctx_len"
     return input_len_padded
 
 
@@ -93,6 +93,7 @@ def split_dlm_bonus_token_inputs(dlm_decode_inputs):
     return bonus_token_inputs, dlm_decode_inputs
 
 
+@pytest.mark.on_qaic
 @pytest.mark.skip()  # remove when the SDK 1.20.0 issue solved for compiling this model
 @pytest.mark.parametrize(
     "prompts, num_speculative_tokens, prefill_seq_len, ctx_len, prefill_bsz, draft_model_name, target_model_name, full_batch_size",
@@ -319,9 +320,9 @@ def test_spec_decode_inference(
     for prompt, generation in zip(prompts, batch_decode):
         print(f"{prompt=} {generation=}")
     # validation check
-    assert mean_num_accepted_tokens == float(num_speculative_tokens + 1), (
-        f"mean number of accepted tokens is {mean_num_accepted_tokens} but should be {num_speculative_tokens + 1}"
-    )
+    assert mean_num_accepted_tokens == float(
+        num_speculative_tokens + 1
+    ), f"mean number of accepted tokens is {mean_num_accepted_tokens} but should be {num_speculative_tokens + 1}"
     del target_model_session
     del draft_model_session
     generated_ids = np.asarray(generated_ids[0]).flatten()
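
A note on the ceil-divide idiom the first hunk relies on: `-(a // -b)` computes `ceil(a / b)` in pure integer arithmetic, because Python's `//` floors toward negative infinity, so negating the divisor (and the result) turns floor division into ceiling division. A minimal standalone sketch of the padding helper (the example values are illustrative, not from this diff):

```python
def get_padded_input_len(input_len: int, prefill_seq_len: int, ctx_len: int) -> int:
    # -(a // -b) == ceil(a / b): floor division rounds toward -inf,
    # so dividing by the negated divisor and negating back rounds up.
    num_chunks = -(input_len // -prefill_seq_len)
    input_len_padded = num_chunks * prefill_seq_len  # round up to a multiple of prefill_seq_len
    assert input_len_padded <= ctx_len, "padded input must still fit within ctx_len"
    return input_len_padded


# 100 tokens with 32-token prefill chunks -> 4 chunks -> padded length 128
assert get_padded_input_len(input_len=100, prefill_seq_len=32, ctx_len=256) == 128
```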
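
For context on the second hunk: `@pytest.mark.on_qaic` is a custom pytest marker (presumably gating tests that need a QAIC device). Custom markers like this are typically registered so that pytest's `--strict-markers` mode accepts them and so they can be selected from the command line; a sketch of one standard way to do that, with the marker description assumed rather than taken from this repo:

```python
# conftest.py (sketch; the description string is an assumption)
def pytest_configure(config):
    # Register the custom marker so pytest recognizes it under --strict-markers.
    config.addinivalue_line(
        "markers", "on_qaic: marks tests that require a QAIC device to run"
    )
```

With the marker registered, `pytest -m on_qaic` selects only these tests, and `pytest -m "not on_qaic"` skips them on machines without the device.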