21
21
DECOE_BSZ = 4
22
22
DTYPE = "mxfp6"
23
23
KV_DTYPE = "mxint8"
24
- DEVICE_GROUP = [0 ]
25
24
26
25
27
26
@pytest .mark .vllm
@@ -43,7 +42,6 @@ def test_output_consistency(model_name):
43
42
# Creating LLM Object
44
43
qllm = LLM (
45
44
model = model_name ,
46
- device_group = DEVICE_GROUP ,
47
45
max_num_seqs = DECOE_BSZ ,
48
46
max_model_len = CTX_LEN ,
49
47
max_seq_len_to_capture = SEQ_LEN ,
@@ -53,17 +51,20 @@ def test_output_consistency(model_name):
53
51
)
54
52
55
53
# Single prompt test
56
- prompt1 = ["My name is" ]
54
+ single_prompt = ["My name is" ]
57
55
58
- output1 = qllm .generate (prompt1 * 5 , sampling_params )
56
+ single_prompt_output = qllm .generate (single_prompt * 5 , sampling_params )
59
57
60
- check_output1 = []
61
- for i , op in enumerate (output1 ):
62
- check_output1 .append (op .outputs [0 ].text )
58
+ check_output = []
59
+ for i , op in enumerate (single_prompt_output ):
60
+ check_output .append (op .outputs [0 ].text )
61
+
62
+ # Assertion to check the consistency of single prompt.
63
+ assert len (set (check_output )) == 1 , "Outputs from different slots for same prompt does not match!!"
63
64
64
65
# Multiple prompt test
65
66
outputDict = dict ()
66
- prompt2 = [
67
+ multiple_prompt = [
67
68
"My name is" ,
68
69
"How to eat mangosteen?" ,
69
70
"How many people died in World War II" ,
@@ -80,22 +81,21 @@ def test_output_consistency(model_name):
80
81
"Where is Statue of Liberty located?" ,
81
82
]
82
83
83
- for p in prompt2 :
84
+ for p in multiple_prompt :
84
85
outputDict [p ] = []
85
86
86
87
for _ in range (5 ):
87
- random .shuffle (prompt2 )
88
- output2 = qllm .generate (prompt2 , sampling_params )
89
- for i , op in enumerate (output2 ):
88
+ random .shuffle (multiple_prompt )
89
+ multiple_prompt_output = qllm .generate (multiple_prompt , sampling_params )
90
+ for i , op in enumerate (multiple_prompt_output ):
90
91
generated_text = op .outputs [0 ].text
91
- outputDict [prompt2 [i ]].append (str (prompt2 [i ] + generated_text ))
92
-
93
- # Assertion to check the consistency of single prompt.
94
- assert len (set (check_output1 )) == 1 , "Outputs from different slots for same prompt does not match!!"
92
+ outputDict [multiple_prompt [i ]].append (str (multiple_prompt [i ] + generated_text ))
95
93
96
94
# Assertion to check multiple prompts.
97
95
for key in outputDict .keys ():
98
96
assert len (set (outputDict [key ])) == 1 , "Outputs from different slots for same prompt does not match!!"
99
97
100
98
# Assertion to check if any prompts are missed.
101
- assert len (prompt2 ) == len (output2 ), "Number of Generated Tokens do not match the number of valid inputs!!"
99
+ assert len (multiple_prompt ) == len (multiple_prompt_output ), (
100
+ "Number of Generated Tokens do not match the number of valid inputs!!"
101
+ )
0 commit comments