Commit e7bd944

[v1] Cleanup the BlockTable in InputBatch (vllm-project#13977)
Signed-off-by: Chen Zhang <[email protected]>
Parent: c3b6559

File tree: 5 files changed, +25 -17 lines

  tests/v1/worker/test_gpu_model_runner.py
  vllm/v1/worker/block_table.py
  vllm/v1/worker/gpu_input_batch.py
  vllm/v1/worker/gpu_model_runner.py
  vllm/v1/worker/tpu_model_runner.py
tests/v1/worker/test_gpu_model_runner.py (+14)

@@ -89,6 +89,17 @@ def _is_sampling_metadata_changed(model_runner,
         sampling_metadata_before)
 
 
+def _is_req_state_block_table_match(model_runner, req_id: str) -> bool:
+    req_index = model_runner.input_batch.req_id_to_index[req_id]
+    block_table = model_runner.input_batch.block_table
+    req_state = model_runner.requests[req_id]
+    if block_table.num_blocks_per_row[req_index] != len(req_state.block_ids):
+        return False
+    num_blocks = block_table.num_blocks_per_row[req_index]
+    return (block_table.block_table_np[req_index, :num_blocks] ==
+            req_state.block_ids).all()
+
+
 def test_update_states_new_request(model_runner):
     req_id = "req_0"
 
@@ -100,6 +111,7 @@ def test_update_states_new_request(model_runner):
     assert _is_sampling_metadata_changed(model_runner, metadata_before)
     assert _is_req_added(model_runner, req_id)
     assert _is_req_scheduled(model_runner, req_id)
+    assert _is_req_state_block_table_match(model_runner, req_id)
 
 
 def test_update_states_request_finished(model_runner):
@@ -185,6 +197,7 @@ def test_update_states_request_resumed(model_runner):
     assert _is_sampling_metadata_changed(model_runner, metadata_before)
     assert _is_req_added(model_runner, req_id)
     assert _is_req_scheduled(model_runner, req_id)
+    assert _is_req_state_block_table_match(model_runner, req_id)
 
 
 def test_update_states_no_changes(model_runner):
@@ -215,6 +228,7 @@ def test_update_states_no_changes(model_runner):
     assert not _is_sampling_metadata_changed(model_runner, metadata_before)
     assert _is_req_added(model_runner, req_id)
     assert _is_req_scheduled(model_runner, req_id)
+    assert _is_req_state_block_table_match(model_runner, req_id)
 
 
 def test_update_states_request_unscheduled(model_runner):
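
The consistency check the new `_is_req_state_block_table_match` helper performs can be read in isolation: a row matches iff it holds exactly the request's block IDs, in order. Below is a minimal standalone sketch of that check (the `row_matches` name is hypothetical; the arrays are shaped like `block_table_np` and `num_blocks_per_row` above):

```python
import numpy as np

# Hypothetical standalone version of the check in
# _is_req_state_block_table_match.
def row_matches(block_table_np: np.ndarray,
                num_blocks_per_row: np.ndarray,
                req_index: int,
                block_ids: list) -> bool:
    # Length must match first; otherwise a prefix could match spuriously.
    if num_blocks_per_row[req_index] != len(block_ids):
        return False
    num_blocks = num_blocks_per_row[req_index]
    return bool((block_table_np[req_index, :num_blocks] == block_ids).all())

table = np.zeros((2, 4), dtype=np.int32)
counts = np.zeros(2, dtype=np.int32)
table[0, :2] = [7, 9]
counts[0] = 2
assert row_matches(table, counts, 0, [7, 9])
assert not row_matches(table, counts, 0, [7])
```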

vllm/v1/worker/block_table.py (+6 -7)

@@ -15,13 +15,11 @@ class BlockTable:
     def __init__(
         self,
         max_num_reqs: int,
-        max_model_len: int,
         max_num_blocks_per_req: int,
         pin_memory: bool,
         device: torch.device,
     ):
         self.max_num_reqs = max_num_reqs
-        self.max_model_len = max_model_len
         self.max_num_blocks_per_req = max_num_blocks_per_req
         self.pin_memory = pin_memory
         self.device = device
@@ -42,18 +40,19 @@ def __init__(
 
     def append_row(
         self,
-        row_idx: int,
-        start: int,
         block_ids: List[int],
+        row_idx: int,
     ) -> None:
         if not block_ids:
             return
         num_blocks = len(block_ids)
+        start = self.num_blocks_per_row[row_idx]
+        self.num_blocks_per_row[row_idx] += num_blocks
         self.block_table_np[row_idx, start:start + num_blocks] = block_ids
-        self.num_blocks_per_row[row_idx] = start + num_blocks
 
-    def add_row(self, row_idx: int, block_ids: List[int]) -> None:
-        self.append_row(row_idx, 0, block_ids)
+    def add_row(self, block_ids: List[int], row_idx: int) -> None:
+        self.num_blocks_per_row[row_idx] = 0
+        self.append_row(block_ids, row_idx)
 
     def move_row(self, src: int, tgt: int) -> None:
         num_blocks = self.num_blocks_per_row[src]
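
To make the reworked contract concrete: `append_row` now derives its write offset from `num_blocks_per_row` and advances it, while `add_row` resets the row before delegating. A minimal runnable sketch with a hypothetical `MiniBlockTable` (NumPy only; the real class also mirrors this array into pinned-CPU and device torch tensors):

```python
import numpy as np

# Hypothetical, simplified stand-in for the updated BlockTable.
class MiniBlockTable:
    def __init__(self, max_num_reqs: int, max_num_blocks_per_req: int):
        self.block_table_np = np.zeros(
            (max_num_reqs, max_num_blocks_per_req), dtype=np.int32)
        self.num_blocks_per_row = np.zeros(max_num_reqs, dtype=np.int32)

    def append_row(self, block_ids: list, row_idx: int) -> None:
        if not block_ids:
            return
        num_blocks = len(block_ids)
        # The write offset comes from internal state, not from the caller.
        start = self.num_blocks_per_row[row_idx]
        self.num_blocks_per_row[row_idx] += num_blocks
        self.block_table_np[row_idx, start:start + num_blocks] = block_ids

    def add_row(self, block_ids: list, row_idx: int) -> None:
        # Reset the row, then reuse the append path.
        self.num_blocks_per_row[row_idx] = 0
        self.append_row(block_ids, row_idx)

table = MiniBlockTable(max_num_reqs=4, max_num_blocks_per_req=8)
table.add_row([10, 11], row_idx=0)   # new request: row 0 -> [10, 11]
table.append_row([12], row_idx=0)    # next step:   row 0 -> [10, 11, 12]
assert table.num_blocks_per_row[0] == 3
assert table.block_table_np[0, :3].tolist() == [10, 11, 12]
```

Keeping the offset inside the table removes a piece of state that every caller previously had to reconstruct on its own.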

vllm/v1/worker/gpu_input_batch.py (+1 -2)

@@ -92,7 +92,6 @@ def __init__(
         # Block table.
         self.block_table = BlockTable(
             max_num_reqs=max_num_reqs,
-            max_model_len=max_model_len,
             max_num_blocks_per_req=max_num_blocks_per_req,
             pin_memory=pin_memory,
             device=device,
@@ -249,7 +248,7 @@ def add_request(
         self.num_tokens_no_spec[req_index] = request.num_tokens
 
         self.num_computed_tokens_cpu[req_index] = request.num_computed_tokens
-        self.block_table.add_row(req_index, request.block_ids)
+        self.block_table.add_row(request.block_ids, req_index)
 
         sampling_params = request.sampling_params
         if sampling_params.sampling_type == SamplingType.GREEDY:

vllm/v1/worker/gpu_model_runner.py (+2 -4)

@@ -399,10 +399,8 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
             # Update the persistent batch.
             self.input_batch.num_computed_tokens_cpu[req_index] = (
                 num_computed_tokens)
-            start_index = (len(req_state.block_ids) -
-                           len(req_data.new_block_ids))
-            self.input_batch.block_table.append_row(req_index, start_index,
-                                                    req_data.new_block_ids)
+            self.input_batch.block_table.append_row(req_data.new_block_ids,
+                                                    req_index)
             # Add new_token_ids to token_ids_cpu.
             start_token_index = num_computed_tokens
             end_token_index = num_computed_tokens + len(req_data.new_token_ids)
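
The deleted `start_index` arithmetic was redundant because `req_state.block_ids` is extended with `req_data.new_block_ids` earlier in `_update_states`, so the difference of the two lengths is just the number of blocks the row already held. A small worked example (hypothetical numbers):

```python
# Worked example of the invariant that made start_index redundant:
# req_state.block_ids already includes the new blocks when the row is
# appended.
existing_blocks = [3, 4, 5]            # blocks the row already holds
new_block_ids = [6, 7]                 # blocks allocated this step
req_state_block_ids = existing_blocks + new_block_ids

start_index = len(req_state_block_ids) - len(new_block_ids)
assert start_index == len(existing_blocks) == 3
# This is exactly the count BlockTable now records in num_blocks_per_row,
# so the caller-side computation can be dropped.
```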

vllm/v1/worker/tpu_model_runner.py (+2 -4)

@@ -247,10 +247,8 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> bool:
             # Update the persistent batch.
             self.input_batch.num_computed_tokens_cpu[req_index] = (
                 req_data.num_computed_tokens)
-            start_index = len(req_state.block_ids) - len(
-                req_data.new_block_ids)
-            self.input_batch.block_table.append_row(req_index, start_index,
-                                                    req_data.new_block_ids)
+            self.input_batch.block_table.append_row(req_data.new_block_ids,
+                                                    req_index)
 
         # Add the new or resumed requests to the persistent batch.
         # The smaller empty indices are filled first.
