@@ -89,6 +89,17 @@ def _is_sampling_metadata_changed(model_runner,
89
89
sampling_metadata_before )
90
90
91
91
92
+ def _is_req_state_block_table_match (model_runner , req_id : str ) -> bool :
93
+ req_index = model_runner .input_batch .req_id_to_index [req_id ]
94
+ block_table = model_runner .input_batch .block_table
95
+ req_state = model_runner .requests [req_id ]
96
+ if block_table .num_blocks_per_row [req_index ] != len (req_state .block_ids ):
97
+ return False
98
+ num_blocks = block_table .num_blocks_per_row [req_index ]
99
+ return (block_table .block_table_np [req_index , :num_blocks ] ==
100
+ req_state .block_ids ).all ()
101
+
102
+
92
103
def test_update_states_new_request (model_runner ):
93
104
req_id = "req_0"
94
105
@@ -100,6 +111,7 @@ def test_update_states_new_request(model_runner):
100
111
assert _is_sampling_metadata_changed (model_runner , metadata_before )
101
112
assert _is_req_added (model_runner , req_id )
102
113
assert _is_req_scheduled (model_runner , req_id )
114
+ assert _is_req_state_block_table_match (model_runner , req_id )
103
115
104
116
105
117
def test_update_states_request_finished (model_runner ):
@@ -185,6 +197,7 @@ def test_update_states_request_resumed(model_runner):
185
197
assert _is_sampling_metadata_changed (model_runner , metadata_before )
186
198
assert _is_req_added (model_runner , req_id )
187
199
assert _is_req_scheduled (model_runner , req_id )
200
+ assert _is_req_state_block_table_match (model_runner , req_id )
188
201
189
202
190
203
def test_update_states_no_changes (model_runner ):
@@ -215,6 +228,7 @@ def test_update_states_no_changes(model_runner):
215
228
assert not _is_sampling_metadata_changed (model_runner , metadata_before )
216
229
assert _is_req_added (model_runner , req_id )
217
230
assert _is_req_scheduled (model_runner , req_id )
231
+ assert _is_req_state_block_table_match (model_runner , req_id )
218
232
219
233
220
234
def test_update_states_request_unscheduled (model_runner ):
0 commit comments