Skip to content

Draft: WIP — NixlConnector: drop ZMQ in favor of HTTP metadata exchanges #19447

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 17 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions vllm/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@
import vllm.envs as envs
from vllm import version
from vllm.compilation.inductor_pass import CallableInductorPass, InductorPass
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
KVConnectorHandshakeMetadata)
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization import (QUANTIZATION_METHODS,
QuantizationMethods,
Expand Down Expand Up @@ -1511,6 +1513,10 @@ class CacheConfig:
num_cpu_blocks: Optional[int] = field(default=None, init=False)
"""The number of blocks to allocate for CPU memory."""

transfer_handshake_metadata: Optional[dict[int, dict[
int, KVConnectorHandshakeMetadata]]] = field(default=None, init=False)
"""Metadata for the KV connector handshake. Structure: dp_rank -> tp_rank -> metadata"""

def compute_hash(self) -> str:
"""
WARNING: Whenever a new field is added to this config,
Expand Down Expand Up @@ -4504,6 +4510,11 @@ def __post_init__(self):
if self.kv_events_config is not None:
# Hybrid KV cache manager is not compatible with KV events.
self.scheduler_config.disable_hybrid_kv_cache_manager = True

if (self.kv_transfer_config is not None
and self.kv_transfer_config.is_kv_transfer_instance):
from collections import defaultdict
self.cache_config.transfer_handshake_metadata = defaultdict(dict)

def update_sizes_for_sequence_parallelism(self,
possible_sizes: list) -> list:
Expand Down
156 changes: 122 additions & 34 deletions vllm/distributed/kv_transfer/kv_connector/v1/base.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
KVConnectorBase_V1 Class for Distributed KV Cache & Hidden State
communication in vLLM v1
Expand All @@ -8,15 +7,9 @@
Scheduler-side: runs in the scheduler, binds metadata, which
is used by the worker-side to load/save KV cache.
get_num_new_matched_tokens() - get number of new tokens
that exist in the remote KV cache. Might be called multiple
times for a given request and should be side-effect free.
that exist in the remote KV cache
update_state_after_alloc() - update KVConnector state after
temporary buffer alloc by the CacheManager.
request_finished() - called when a request is finished, with
the computed kv cache blocks for the request.
Returns whether KV cache should be freed now or will be
freed asynchronously and optionally returns KV transfer
params.

Worker-side: runs in each worker, loads/saves KV cache to/from
the Connector based on the metadata.
Expand All @@ -25,16 +18,16 @@

save_kv_layer() - starts saving KV for layer i (maybe async)
wait_for_save() - blocks until all saves are done

get_finished() - called with ids of finished requests, returns
ids of requests that have completed async sending/recving.
"""

import enum
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Any, Optional
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Callable, Optional

import msgspec
import torch
from pydantic_core import core_schema

from vllm.logger import init_logger
from vllm.v1.core.sched.output import SchedulerOutput
Expand All @@ -49,6 +42,64 @@
logger = init_logger(__name__)


@dataclass
class KVTransferFinishedResult:
    """Outcome of a connector's ``get_finished()`` poll.

    Bundles the request-id sets for completed sends, completed receives,
    and transfers still waiting on a peer handshake.
    """

    finished_sending: set[str]
    finished_recving: set[str]
    pending_handshake: set[str]

    def has_any_finished(self) -> bool:
        """Return True if any of the three request-id sets is non-empty."""
        return bool(self.finished_sending or self.finished_recving
                    or self.pending_handshake)

    def is_empty(self) -> bool:
        """Return True when no request ids are tracked in any set."""
        return not self.has_any_finished()

    def get_all_finished_req_ids(self) -> set[str]:
        """Return the union of finished-sending and finished-receiving ids."""
        return self.finished_sending | self.finished_recving

    def merge(self,
              other: 'KVTransferFinishedResult') -> 'KVTransferFinishedResult':
        """Return a new result whose sets are the union with *other*'s."""
        return KVTransferFinishedResult(
            finished_sending=self.finished_sending | other.finished_sending,
            finished_recving=self.finished_recving | other.finished_recving,
            pending_handshake=self.pending_handshake | other.pending_handshake,
        )

    @classmethod
    def empty(cls) -> 'KVTransferFinishedResult':
        """Build a result with all three sets empty."""
        return cls(finished_sending=set(),
                   finished_recving=set(),
                   pending_handshake=set())

    @classmethod
    def from_tuple(
        cls, result_tuple: tuple[set[str], set[str], set[str]]
    ) -> 'KVTransferFinishedResult':
        """Build a result from the legacy (sending, recving, pending) tuple."""
        sending, recving, pending = result_tuple
        return cls(finished_sending=sending,
                   finished_recving=recving,
                   pending_handshake=pending)

    def to_tuple(self) -> tuple[set[str], set[str], set[str]]:
        """Export as the legacy (sending, recving, pending) tuple."""
        return (self.finished_sending, self.finished_recving,
                self.pending_handshake)


class KVConnectorRole(enum.Enum):
# Connector running in the scheduler process
SCHEDULER = 0
Expand All @@ -65,6 +116,39 @@ class KVConnectorMetadata:
pass


class KVConnectorHandshakeMetadata(
        msgspec.Struct,
        omit_defaults=True,  # type: ignore[call-arg]
        # dict=True gives instances a __dict__, which @cached_property needs.
        dict=True):
    """
    Metadata optionally used for out of band connector handshake between
    P/D workers.
    """
    # Identifies which connector implementation produced this metadata.
    connector_type: str = "base"

    @classmethod
    def __get_pydantic_core_schema__(
            cls, _source_type: Any, _handler: Callable[[Any],
                                                       core_schema.CoreSchema]
    ) -> core_schema.CoreSchema:
        """Bridge msgspec.Struct with pydantic for schema generation.

        Pydantic validates the input as a plain dict, then passes it to
        ``cls`` to build the msgspec struct.
        """
        return core_schema.no_info_after_validator_function(
            cls, core_schema.dict_schema())


class KVConnectorTransferMetadata(
        msgspec.Struct,
        omit_defaults=True,  # type: ignore[call-arg]
        dict=True):
    """
    Wrapper for transfer handshake metadata sent between engine and utils.
    """
    # TP rank of the worker this metadata belongs to.
    tensor_parallel_rank: int
    # DP rank of the worker this metadata belongs to.
    data_parallel_rank: int
    # Handshake payload as a plain dict; None when no metadata is available.
    # NOTE(review): presumably a serialized KVConnectorHandshakeMetadata —
    # confirm against the sender.
    content: Optional[dict]


class KVConnectorBase_V1(ABC):

def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole):
Expand All @@ -74,6 +158,7 @@ def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole):
self._connector_metadata = KVConnectorMetadata()
self._vllm_config = vllm_config
self._role = role
self._handshake_metadata: Optional[KVConnectorHandshakeMetadata] = None

@property
def role(self) -> KVConnectorRole:
Expand Down Expand Up @@ -104,7 +189,7 @@ def clear_connector_metadata(self) -> None:
"""
self._connector_metadata = KVConnectorMetadata()

def _get_connector_metadata(self) -> KVConnectorMetadata:
def get_connector_metadata(self) -> KVConnectorMetadata:
"""Get the connector metadata.

This function should only be called inside the connector.
Expand Down Expand Up @@ -185,21 +270,37 @@ def wait_for_save(self):
"""
pass

def get_finished(
self, finished_req_ids: set[str]
) -> tuple[Optional[set[str]], Optional[set[str]]]:
def get_finished(self,
finished_req_ids: set[str]) -> KVTransferFinishedResult:
"""
Notifies worker-side connector ids of requests that have
finished generating tokens.

Returns:
ids of requests that have finished asynchronous transfer
(requests that previously returned True from request_finished()),
tuple of (sending/saving ids, recving/loading ids).
KVTransferFinishedResult containing sets of finished sending,
finished receiving, and pending handshake request IDs.
The finished saves/sends req ids must belong to a set provided in a
call to this method (this call or a prior one).
"""
return None, None
return KVTransferFinishedResult.empty()

def get_pending_handshake_req_ids(self) -> Optional[set[str]]:
"""
Get request IDs that are currently pending handshake completion.

Returns:
Set of request IDs waiting for handshake, or None if not applicable.
"""
return None

def get_handshake_metadata(self) -> Optional[KVConnectorHandshakeMetadata]:
"""
Get the handshake metadata for the connector.

Returns:
KVConnectorHandshakeMetadata: the handshake metadata.
"""
return self._handshake_metadata

# ==============================
# Scheduler-side methods
Expand All @@ -225,8 +326,7 @@ def get_num_new_matched_tokens(
- The number of tokens that can be loaded from the
external KV cache beyond what is already computed.
- `True` if external KV cache tokens will be loaded
asynchronously (between scheduler steps). Must be
'False' if the first element is 0.
asynchronously (between scheduler steps).
"""
pass

Expand All @@ -236,18 +336,6 @@ def update_state_after_alloc(self, request: "Request",
num_external_tokens: int):
"""
Update KVConnector state after block allocation.

If get_num_new_matched_tokens previously returned True for a
request, this function may be called twice for that same request -
first when blocks are allocated for the connector tokens to be
asynchronously loaded into, and second when any additional blocks
are allocated, after the load/transfer is complete.

Args:
request (Request): the request object.
blocks (KVCacheBlocks): the blocks allocated for the request.
num_external_tokens (int): the number of tokens that will be
loaded from the external KV cache.
"""
pass

Expand Down
36 changes: 27 additions & 9 deletions vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,8 @@
from vllm.distributed.kv_transfer.kv_connector.factory import (
KVConnectorFactory)
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole)
KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole,
KVTransferFinishedResult)
from vllm.logger import init_logger
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
from vllm.v1.core.sched.output import SchedulerOutput
Expand Down Expand Up @@ -102,21 +103,27 @@ def wait_for_save(self):
for c in self._connectors:
c.wait_for_save()

def get_finished(
self, finished_req_ids: set[str]
) -> tuple[Optional[set[str]], Optional[set[str]]]:
def get_finished(self,
finished_req_ids: set[str]) -> KVTransferFinishedResult:
finished_sending: set[str] = set()
finished_recving: set[str] = set()
pending_handshake: set[str] = set()

for c in self._connectors:
sending, recving = c.get_finished(finished_req_ids)
if not recving and not sending:
result = c.get_finished(finished_req_ids)
if result.is_empty():
continue

# Aggregate finished recving request ids.
finished_recving.update(recving or ())
finished_recving.update(result.finished_recving)

# Aggregate pending handshake request ids.
pending_handshake.update(result.pending_handshake)

# Aggregate finished sending request ids - only include
# once we've drained the "extra" count (for cases where
# more than one connector is async-saving the same request).
for req_id in sending or ():
for req_id in result.finished_sending:
extra_pending = self._extra_async_saves.get(req_id)
if extra_pending is None:
finished_sending.add(req_id)
Expand All @@ -127,7 +134,18 @@ def get_finished(
else:
self._extra_async_saves[req_id] = extra_pending - 1

return finished_sending or None, finished_recving or None
return KVTransferFinishedResult(finished_sending=finished_sending,
finished_recving=finished_recving,
pending_handshake=pending_handshake)

def get_pending_handshake_req_ids(self) -> Optional[set[str]]:
"""Get request IDs that are currently pending handshake completion."""
pending_handshake: set[str] = set()
for c in self._connectors:
connector_pending = c.get_pending_handshake_req_ids()
if connector_pending:
pending_handshake.update(connector_pending)
return pending_handshake or None

# ==============================
# Scheduler-side methods
Expand Down
Loading