Skip to content

Commit 916c8e2

Browse files
committed
v1: Support worker -> scheduler KV connector MD
This commit makes the following changes: 1. Add a new kv_connector_metadata to ModelRunnerOutput to allow arbitrary connector metadata flow from workers to the scheduler. 2. Add a new worker-side connector API to build the above metadata. 3. Change MultiprocExecutor to get ModelRunnerOutput from all workers, and aggregate the kv_connector_metadata from all. 3. Move the get_finished connector API from the worker side to the scheduler side. 4. Change the nixl and multi connectors to match the above API changes. Signed-off-by: Or Ozeri <[email protected]>
1 parent eaa2e51 commit 916c8e2

File tree

11 files changed

+288
-194
lines changed

11 files changed

+288
-194
lines changed

tests/v1/kv_connector/unit/test_remote_decode_lifecycle.py

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,11 @@
11
# SPDX-License-Identifier: Apache-2.0
22
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3-
import copy
4-
53
from vllm.v1.outputs import EMPTY_MODEL_RUNNER_OUTPUT
64
from vllm.v1.request import FinishReason, RequestStatus
75

8-
from .utils import (assert_scheduler_empty, create_model_runner_output,
9-
create_request, create_scheduler, create_vllm_config)
6+
from .utils import (assert_scheduler_empty, create_empty_model_runner_output,
7+
create_model_runner_output, create_request,
8+
create_scheduler, create_vllm_config)
109

1110

1211
def test_basic_lifecycle():
@@ -85,8 +84,8 @@ def test_basic_lifecycle():
8584
assert len(scheduler.finished_req_ids) == 0
8685

8786
# (3b): execute_model()
88-
model_runner_output = copy.deepcopy(EMPTY_MODEL_RUNNER_OUTPUT)
89-
model_runner_output.finished_sending = [request_id]
87+
model_runner_output = create_empty_model_runner_output(
88+
finished_sending=[request_id])
9089

9190
# (3c): update_from_output()
9291
scheduler.update_from_output(scheduler_output, model_runner_output)
@@ -175,8 +174,8 @@ def test_prefix_cache_lifecycle():
175174
# STEP (2): Ensure it is freed.
176175
scheduler_output = scheduler.schedule()
177176
scheduler.schedule()
178-
model_runner_output = copy.deepcopy(EMPTY_MODEL_RUNNER_OUTPUT)
179-
model_runner_output.finished_sending = [request_remote.request_id]
177+
model_runner_output = create_empty_model_runner_output(
178+
finished_sending=[request_remote.request_id])
180179
scheduler.update_from_output(scheduler_output, model_runner_output)
181180
_ = scheduler.schedule()
182181
assert_scheduler_empty(scheduler)

tests/v1/kv_connector/unit/test_remote_prefill_lifecycle.py

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,11 @@
11
# SPDX-License-Identifier: Apache-2.0
22
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3-
import copy
4-
53
from vllm.v1.outputs import EMPTY_MODEL_RUNNER_OUTPUT
64
from vllm.v1.request import FinishReason, RequestStatus
75

8-
from .utils import (assert_scheduler_empty, create_model_runner_output,
9-
create_request, create_scheduler, create_vllm_config)
6+
from .utils import (assert_scheduler_empty, create_empty_model_runner_output,
7+
create_model_runner_output, create_request,
8+
create_scheduler, create_vllm_config)
109

1110

1211
def test_basic_lifecycle():
@@ -71,8 +70,8 @@ def test_basic_lifecycle():
7170
assert len(scheduler.running) == 0
7271

7372
# (2b): forward(): request finishes recv.
74-
model_runner_output = copy.deepcopy(EMPTY_MODEL_RUNNER_OUTPUT)
75-
model_runner_output.finished_recving = [request_id]
73+
model_runner_output = create_empty_model_runner_output(
74+
finished_recving=[request_id])
7675

7776
# (2c): update_from_output():
7877
engine_core_outputs = scheduler.update_from_output(scheduler_output,
@@ -308,8 +307,8 @@ def test_full_block_prompt():
308307

309308
# # STEP (2): Recv.
310309
scheduler_output = scheduler.schedule()
311-
model_runner_output = copy.deepcopy(EMPTY_MODEL_RUNNER_OUTPUT)
312-
model_runner_output.finished_recving = [request_id]
310+
model_runner_output = create_empty_model_runner_output(
311+
finished_recving=[request_id])
313312
scheduler.update_from_output(scheduler_output, model_runner_output)
314313
assert len(scheduler.waiting) == 1
315314
assert (request_id in scheduler.finished_recving_kv_req_ids)

tests/v1/kv_connector/unit/utils.py

Lines changed: 29 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,19 @@
11
# SPDX-License-Identifier: Apache-2.0
22
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3+
import copy
34
from typing import Any, Optional
45

56
import torch
67

78
from vllm import SamplingParams
89
from vllm.config import (CacheConfig, DeviceConfig, KVTransferConfig,
910
ModelConfig, SchedulerConfig, VllmConfig)
11+
from vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector import (
12+
NixlWorkerConnectorMetadata)
1013
from vllm.v1.core.sched.scheduler import Scheduler
1114
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
1215
KVCacheGroupSpec)
13-
from vllm.v1.outputs import ModelRunnerOutput
16+
from vllm.v1.outputs import EMPTY_MODEL_RUNNER_OUTPUT, ModelRunnerOutput
1417
from vllm.v1.request import Request
1518
from vllm.v1.structured_output import StructuredOutputManager
1619

@@ -175,6 +178,14 @@ def create_model_runner_output(
175178
sampled_token = EOS_TOKEN_ID if use_eos else 0
176179
sampled_token_ids = [[sampled_token] for _ in req_ids]
177180

181+
# Make worker connector metadata
182+
kv_connector_metadata = None
183+
if finished_sending or finished_recving:
184+
kv_connector_metadata = [
185+
NixlWorkerConnectorMetadata(finished_sending or [],
186+
finished_recving or [])
187+
]
188+
178189
# Make output data structure.
179190
return ModelRunnerOutput(
180191
req_ids=req_ids,
@@ -183,6 +194,21 @@ def create_model_runner_output(
183194
spec_token_ids=None,
184195
logprobs=None,
185196
prompt_logprobs_dict={},
186-
finished_sending=finished_sending,
187-
finished_recving=finished_recving,
197+
kv_connector_metadata=kv_connector_metadata,
188198
)
199+
200+
201+
def create_empty_model_runner_output(
202+
finished_sending: Optional[list[str]] = None,
203+
finished_recving: Optional[list[str]] = None,
204+
) -> ModelRunnerOutput:
205+
"""Make dummy empty model runner output for testing."""
206+
model_runner_output = copy.deepcopy(EMPTY_MODEL_RUNNER_OUTPUT)
207+
208+
kv_connector_metadata = [
209+
NixlWorkerConnectorMetadata(finished_sending or [], finished_recving
210+
or [])
211+
]
212+
model_runner_output.kv_connector_metadata = kv_connector_metadata
213+
214+
return model_runner_output

vllm/distributed/kv_transfer/kv_connector/v1/base.py

Lines changed: 35 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@
1717
Returns whether KV cache should be freed now or will be
1818
freed asynchronously and optionally returns KV transfer
1919
params.
20+
get_finished() - returns ids of requests that have completed
21+
async sending/recving.
2022
2123
Worker-side: runs in each worker, loads/saves KV cache to/from
2224
the Connector based on the metadata.
@@ -26,8 +28,8 @@
2628
save_kv_layer() - starts saving KV for layer i (maybe async)
2729
wait_for_save() - blocks until all saves are done
2830
29-
get_finished() - called with ids of finished requests, returns
30-
ids of requests that have completed async sending/recving.
31+
build_worker_connector_meta() - builds metadata to be sent
32+
back to the scheduler.
3133
"""
3234

3335
import enum
@@ -38,6 +40,7 @@
3840

3941
from vllm.logger import init_logger
4042
from vllm.v1.core.sched.output import SchedulerOutput
43+
from vllm.v1.outputs import ModelRunnerOutput
4144

4245
if TYPE_CHECKING:
4346
from vllm.attention.backends.abstract import AttentionMetadata
@@ -185,21 +188,21 @@ def wait_for_save(self):
185188
"""
186189
pass
187190

188-
def get_finished(
189-
self, finished_req_ids: set[str]
190-
) -> tuple[Optional[set[str]], Optional[set[str]]]:
191+
def build_worker_connector_meta(
192+
self, scheduler_output: SchedulerOutput,
193+
model_runner_output: ModelRunnerOutput
194+
) -> Optional[KVConnectorMetadata]:
191195
"""
192-
Notifies worker-side connector ids of requests that have
193-
finished generating tokens.
196+
Build the worker->scheduler connector metadata for this step.
194197
195-
Returns:
196-
ids of requests that have finished asynchronous transfer
197-
(requests that previously returned True from request_finished()),
198-
tuple of (sending/saving ids, recving/loading ids).
199-
The finished saves/sends req ids must belong to a set provided in a
200-
call to this method (this call or a prior one).
198+
This function should NOT modify fields of its arguments.
199+
200+
Args:
201+
scheduler_output (SchedulerOutput): the scheduler output object.
202+
model_runner_output (ModelRunnerOutput):
203+
the model runner (worker) output object.
201204
"""
202-
return None, None
205+
return None
203206

204207
# ==============================
205208
# Scheduler-side methods
@@ -281,3 +284,21 @@ def request_finished(
281284
returned by the engine.
282285
"""
283286
return False, None
287+
288+
def get_finished(
289+
self,
290+
model_runner_output: ModelRunnerOutput,
291+
) -> tuple[Optional[set[str]], Optional[set[str]]]:
292+
"""
293+
Get request IDs that recently finished async transfer.
294+
295+
Args:
296+
model_runner_output (ModelRunnerOutput):
297+
the model runner (worker) output object.
298+
299+
Returns:
300+
ids of requests that have finished asynchronous transfer
301+
(requests that previously returned True from request_finished()),
302+
tuple of (sending/saving ids, recving/loading ids).
303+
"""
304+
return None, None

vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py

Lines changed: 41 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
from vllm.logger import init_logger
1515
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
1616
from vllm.v1.core.sched.output import SchedulerOutput
17+
from vllm.v1.outputs import ModelRunnerOutput
1718

1819
if TYPE_CHECKING:
1920
from vllm.attention.backends.abstract import AttentionMetadata
@@ -26,7 +27,11 @@
2627
@dataclass
2728
class MultiKVConnectorMetadata(KVConnectorMetadata):
2829
metadata: tuple[KVConnectorMetadata, ...]
29-
extra_async_saves: Optional[dict[str, int]] = None
30+
31+
32+
@dataclass
33+
class MultiKVWorkerConnectorMetadata(KVConnectorMetadata):
34+
metadata: tuple[Optional[KVConnectorMetadata], ...]
3035

3136

3237
class MultiConnector(KVConnectorBase_V1):
@@ -58,7 +63,6 @@ def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole):
5863
# Keeps track of *additional* remaining async saves (beyond 1) to be
5964
# finished per request. Not needed for async loads since we only allow
6065
# a single connector to load.
61-
# Propagated from scheduler to worker side via the connector metadata.
6266
self._extra_async_saves: dict[str, int] = {}
6367

6468
def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
@@ -71,9 +75,6 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
7175
def bind_connector_metadata(
7276
self, connector_metadata: KVConnectorMetadata) -> None:
7377
assert isinstance(connector_metadata, MultiKVConnectorMetadata)
74-
if connector_metadata.extra_async_saves:
75-
self._extra_async_saves.update(
76-
connector_metadata.extra_async_saves)
7778
for c, cm in zip(self._connectors, connector_metadata.metadata):
7879
c.bind_connector_metadata(cm)
7980

@@ -102,32 +103,14 @@ def wait_for_save(self):
102103
for c in self._connectors:
103104
c.wait_for_save()
104105

105-
def get_finished(
106-
self, finished_req_ids: set[str]
107-
) -> tuple[Optional[set[str]], Optional[set[str]]]:
108-
finished_sending: set[str] = set()
109-
finished_recving: set[str] = set()
110-
for c in self._connectors:
111-
sending, recving = c.get_finished(finished_req_ids)
112-
if not recving and not sending:
113-
continue
114-
# Aggregate finished recving request ids.
115-
finished_recving.update(recving or ())
116-
# Aggregate finished sending request ids - only include
117-
# once we've drained the "extra" count (for cases where
118-
# more than one connector is async-saving the same request).
119-
for req_id in sending or ():
120-
extra_pending = self._extra_async_saves.get(req_id)
121-
if extra_pending is None:
122-
finished_sending.add(req_id)
123-
continue
124-
assert extra_pending > 0
125-
if extra_pending == 1:
126-
del self._extra_async_saves[req_id]
127-
else:
128-
self._extra_async_saves[req_id] = extra_pending - 1
129-
130-
return finished_sending or None, finished_recving or None
106+
def build_worker_connector_meta(
107+
self, scheduler_output: SchedulerOutput,
108+
model_runner_output: ModelRunnerOutput
109+
) -> Optional[MultiKVWorkerConnectorMetadata]:
110+
return MultiKVWorkerConnectorMetadata(metadata=tuple(
111+
c.build_worker_connector_meta(scheduler_output,
112+
model_runner_output)
113+
for c in self._connectors))
131114

132115
# ==============================
133116
# Scheduler-side methods
@@ -169,9 +152,6 @@ def build_connector_meta(
169152
metadata = MultiKVConnectorMetadata(metadata=tuple(
170153
c.build_connector_meta(scheduler_output)
171154
for c in self._connectors))
172-
if self._extra_async_saves:
173-
metadata.extra_async_saves = self._extra_async_saves
174-
self._extra_async_saves = {}
175155
return metadata
176156

177157
def request_finished(
@@ -199,3 +179,30 @@ def request_finished(
199179
self._requests_to_connector.pop(request.request_id, None)
200180

201181
return async_saves > 0, kv_txfer_params
182+
183+
def get_finished(
184+
self, model_runner_output: ModelRunnerOutput
185+
) -> tuple[Optional[set[str]], Optional[set[str]]]:
186+
finished_sending: set[str] = set()
187+
finished_recving: set[str] = set()
188+
for c in self._connectors:
189+
sending, recving = c.get_finished(model_runner_output)
190+
if not recving and not sending:
191+
continue
192+
# Aggregate finished recving request ids.
193+
finished_recving.update(recving or ())
194+
# Aggregate finished sending request ids - only include
195+
# once we've drained the "extra" count (for cases where
196+
# more than one connector is async-saving the same request).
197+
for req_id in sending or ():
198+
extra_pending = self._extra_async_saves.get(req_id)
199+
if extra_pending is None:
200+
finished_sending.add(req_id)
201+
continue
202+
assert extra_pending > 0
203+
if extra_pending == 1:
204+
del self._extra_async_saves[req_id]
205+
else:
206+
self._extra_async_saves[req_id] = extra_pending - 1
207+
208+
return finished_sending or None, finished_recving or None

0 commit comments

Comments
 (0)