Skip to content

WIP [P/D] Use ThreadPoolExecutor to do handshake for each P-D pair #19823

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 5 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
71 changes: 65 additions & 6 deletions vllm/distributed/kv_transfer/kv_connector/v1/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,11 +27,14 @@
wait_for_save() - blocks until all saves are done

get_finished() - called with ids of finished requests, returns
ids of requests that have completed async sending/recving.
`KVTransferResult` that contains ids of requests that have
completed async sending/recving, as well as requests that are
pending handshake.
"""

import enum
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Optional

import torch
Expand All @@ -49,6 +52,63 @@
logger = init_logger(__name__)


@dataclass
class KVTransferResult:
"""Result of KV transfer get_finished() operation."""

finished_sending: set[str]
finished_recving: set[str]
pending_handshake: set[str]

def has_any_finished(self) -> bool:
"""Check if any requests finished sending or recving."""
return bool(self.finished_sending or self.finished_recving)

def is_empty(self) -> bool:
"""Check if all sets are empty."""
return not (self.has_any_finished() or bool(self.pending_handshake))

def get_all_finished_req_ids(self) -> set[str]:
"""Get all request IDs that have finished (sending or receiving)."""
return self.finished_sending.union(self.finished_recving)

def merge(self,
other: 'KVTransferResult') -> 'KVTransferResult':
"""Merge with another result, combining each field separately."""
return KVTransferResult(
finished_sending=self.finished_sending.union(
other.finished_sending),
finished_recving=self.finished_recving.union(
other.finished_recving),
pending_handshake=self.pending_handshake.union(
other.pending_handshake))

@classmethod
def empty(cls) -> 'KVTransferResult':
"""Create an empty result."""
return cls(finished_sending=set(),
finished_recving=set(),
pending_handshake=set())

@classmethod
def from_tuple(
cls, result_tuple: tuple[set[str], set[str], set[str]]
) -> 'KVTransferResult':
"""Create from the legacy tuple format for backward compatibility."""
finished_sending, finished_recving, pending_handshake = result_tuple
return cls(finished_sending=finished_sending,
finished_recving=finished_recving,
pending_handshake=pending_handshake)

def to_tuple(self) -> tuple[set[str], set[str], set[str]]:
"""Convert to the legacy tuple format for backward compatibility."""
return (
self.finished_sending,
self.finished_recving,
self.pending_handshake,
)


class KVConnectorRole(enum.Enum):
# Connector running in the scheduler process
SCHEDULER = 0
Expand Down Expand Up @@ -187,19 +247,18 @@ def wait_for_save(self):

def get_finished(
self, finished_req_ids: set[str]
) -> tuple[Optional[set[str]], Optional[set[str]]]:
) -> KVTransferResult:
"""
Notifies worker-side connector ids of requests that have
finished generating tokens.

Returns:
ids of requests that have finished asynchronous transfer
(requests that previously returned True from request_finished()),
tuple of (sending/saving ids, recving/loading ids).
KVTransferResult containing sets of finished sending,
finished receiving, and pending handshake request IDs.
The finished saves/sends req ids must belong to a set provided in a
call to this method (this call or a prior one).
"""
return None, None
return KVTransferResult.empty()

# ==============================
# Scheduler-side methods
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

from vllm.config import VllmConfig
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole)
KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole, KVTransferResult)
from vllm.logger import init_logger
from vllm.v1.core.sched.output import SchedulerOutput

Expand Down Expand Up @@ -89,7 +89,7 @@ def wait_for_save(self):

def get_finished(
self, finished_req_ids: set[str]
) -> tuple[Optional[set[str]], Optional[set[str]]]:
) -> KVTransferResult:
"""
Notifies worker-side connector ids of requests that have
finished generating tokens.
Expand All @@ -101,7 +101,10 @@ def get_finished(
The finished saves/sends req ids must belong to a set provided in a
call to this method (this call or a prior one).
"""
return self._lmcache_engine.get_finished(finished_req_ids)
finished_sending, finished_recving = self._lmcache_engine.get_finished(finished_req_ids)
return KVTransferResult(finished_sending=finished_sending,
finished_recving=finished_recving,
pending_handshake=set())

# ==============================
# Scheduler-side methods
Expand Down
22 changes: 14 additions & 8 deletions vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from vllm.distributed.kv_transfer.kv_connector.factory import (
KVConnectorFactory)
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole)
KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole, KVTransferResult)
from vllm.logger import init_logger
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
from vllm.v1.core.sched.output import SchedulerOutput
Expand Down Expand Up @@ -104,19 +104,23 @@ def wait_for_save(self):

def get_finished(
self, finished_req_ids: set[str]
) -> tuple[Optional[set[str]], Optional[set[str]]]:
) -> KVTransferResult:
finished_sending: set[str] = set()
finished_recving: set[str] = set()
pending_handshake: set[str] = set()

for c in self._connectors:
sending, recving = c.get_finished(finished_req_ids)
if not recving and not sending:
result = c.get_finished(finished_req_ids)
if result.is_empty():
continue
# Aggregate finished recving request ids.
finished_recving.update(recving or ())

finished_recving.update(result.finished_recving)
# Aggregate pending handshake request ids.
pending_handshake.update(result.pending_handshake)
# Aggregate finished sending request ids - only include
# once we've drained the "extra" count (for cases where
# more than one connector is async-saving the same request).
for req_id in sending or ():
for req_id in result.finished_sending:
extra_pending = self._extra_async_saves.get(req_id)
if extra_pending is None:
finished_sending.add(req_id)
Expand All @@ -127,7 +131,9 @@ def get_finished(
else:
self._extra_async_saves[req_id] = extra_pending - 1

return finished_sending or None, finished_recving or None
return KVTransferResult(finished_sending=finished_sending,
finished_recving=finished_recving,
pending_handshake=pending_handshake)

# ==============================
# Scheduler-side methods
Expand Down
Loading
Loading