[V1] Solve potential deadlock issue in v1 engine core client internally #19927

Open · wants to merge 5 commits into main
28 changes: 10 additions & 18 deletions tests/v1/engine/test_async_llm.py
@@ -15,7 +15,6 @@
from vllm.inputs import PromptType
from vllm.platforms import current_platform
from vllm.sampling_params import RequestOutputKind
from vllm.utils import set_default_torch_num_threads
from vllm.v1.engine.async_llm import AsyncLLM
from vllm.v1.metrics.loggers import LoggingStatLogger

@@ -108,8 +107,7 @@ async def test_load(
with monkeypatch.context() as m, ExitStack() as after:
m.setenv("VLLM_USE_V1", "1")

with set_default_torch_num_threads(1):
engine = AsyncLLM.from_engine_args(engine_args)
engine = AsyncLLM.from_engine_args(engine_args)
after.callback(engine.shutdown)

NUM_REQUESTS = 100
@@ -156,8 +154,7 @@ async def test_abort(
with monkeypatch.context() as m, ExitStack() as after:
m.setenv("VLLM_USE_V1", "1")

with set_default_torch_num_threads(1):
engine = AsyncLLM.from_engine_args(engine_args)
engine = AsyncLLM.from_engine_args(engine_args)
after.callback(engine.shutdown)

NUM_REQUESTS = 100
@@ -229,8 +226,7 @@ async def test_finished_flag(
with monkeypatch.context() as m, ExitStack() as after:
m.setenv("VLLM_USE_V1", "1")

with set_default_torch_num_threads(1):
engine = AsyncLLM.from_engine_args(engine_args)
engine = AsyncLLM.from_engine_args(engine_args)
after.callback(engine.shutdown)

sampling_params = SamplingParams(
@@ -264,8 +260,7 @@ async def test_mid_stream_cancellation(monkeypatch: pytest.MonkeyPatch,
with monkeypatch.context() as m, ExitStack() as after:
m.setenv("VLLM_USE_V1", "1")

with set_default_torch_num_threads(1):
engine = AsyncLLM.from_engine_args(engine_args)
engine = AsyncLLM.from_engine_args(engine_args)
after.callback(engine.shutdown)

NUM_REQUESTS = 100
@@ -327,11 +322,10 @@ async def test_customize_loggers(monkeypatch):
with monkeypatch.context() as m, ExitStack() as after:
m.setenv("VLLM_USE_V1", "1")

with set_default_torch_num_threads(1):
engine = AsyncLLM.from_engine_args(
TEXT_ENGINE_ARGS,
stat_loggers=[MockLoggingStatLogger],
)
engine = AsyncLLM.from_engine_args(
TEXT_ENGINE_ARGS,
stat_loggers=[MockLoggingStatLogger],
)
after.callback(engine.shutdown)

await engine.do_log_stats()
@@ -346,8 +340,7 @@ async def test_dp_rank_argument(monkeypatch: pytest.MonkeyPatch):
with monkeypatch.context() as m, ExitStack() as after:
m.setenv("VLLM_USE_V1", "1")

with set_default_torch_num_threads(1):
engine = AsyncLLM.from_engine_args(TEXT_ENGINE_ARGS)
engine = AsyncLLM.from_engine_args(TEXT_ENGINE_ARGS)
after.callback(engine.shutdown)

sampling_params = SamplingParams(max_tokens=100,
@@ -383,8 +376,7 @@ async def test_check_health(monkeypatch: pytest.MonkeyPatch):
with monkeypatch.context() as m, ExitStack() as after:
m.setenv("VLLM_USE_V1", "1")

with set_default_torch_num_threads(1):
engine = AsyncLLM.from_engine_args(TEXT_ENGINE_ARGS)
engine = AsyncLLM.from_engine_args(TEXT_ENGINE_ARGS)
after.callback(engine.shutdown)

# Test 1: Healthy engine should not raise any exception
29 changes: 12 additions & 17 deletions tests/v1/engine/test_engine_core.py
@@ -12,7 +12,6 @@
from vllm import SamplingParams
from vllm.engine.arg_utils import EngineArgs
from vllm.platforms import current_platform
from vllm.utils import set_default_torch_num_threads
from vllm.v1.engine import EngineCoreRequest
from vllm.v1.engine.core import EngineCore
from vllm.v1.executor.abstract import Executor, UniProcExecutor
@@ -58,10 +57,9 @@ def test_engine_core(monkeypatch: pytest.MonkeyPatch):
vllm_config = engine_args.create_engine_config()
executor_class = Executor.get_class(vllm_config)

with set_default_torch_num_threads(1):
engine_core = EngineCore(vllm_config=vllm_config,
executor_class=executor_class,
log_stats=True)
engine_core = EngineCore(vllm_config=vllm_config,
executor_class=executor_class,
log_stats=True)
"""Test basic request lifecycle."""

# First request.
@@ -193,10 +191,9 @@ def test_engine_core_advanced_sampling(monkeypatch: pytest.MonkeyPatch):
vllm_config = engine_args.create_engine_config()
executor_class = Executor.get_class(vllm_config)

with set_default_torch_num_threads(1):
engine_core = EngineCore(vllm_config=vllm_config,
executor_class=executor_class,
log_stats=True)
engine_core = EngineCore(vllm_config=vllm_config,
executor_class=executor_class,
log_stats=True)
"""Test basic request lifecycle."""
# First request.
request: EngineCoreRequest = make_request()
@@ -290,10 +287,9 @@ def shutdown(self):
enforce_eager=True,
)
vllm_config = engine_args.create_engine_config()
with set_default_torch_num_threads(1):
engine_core = EngineCore(vllm_config=vllm_config,
log_stats=False,
executor_class=DummyExecutor)
engine_core = EngineCore(vllm_config=vllm_config,
log_stats=False,
executor_class=DummyExecutor)
assert engine_core.batch_queue is not None

# Add two requests in a row. Each request have 12 prompt tokens.
@@ -399,10 +395,9 @@ def test_engine_core_tp(monkeypatch: pytest.MonkeyPatch):
vllm_config = engine_args.create_engine_config()
executor_class = Executor.get_class(vllm_config)

with set_default_torch_num_threads(1):
engine_core = EngineCore(vllm_config=vllm_config,
executor_class=executor_class,
log_stats=True)
engine_core = EngineCore(vllm_config=vllm_config,
executor_class=executor_class,
log_stats=True)

def get_worker_cache_config_field(worker, key: str):
return getattr(worker.cache_config, key)
61 changes: 28 additions & 33 deletions tests/v1/engine/test_engine_core_client.py
@@ -19,7 +19,6 @@
from vllm.engine.arg_utils import EngineArgs
from vllm.platforms import current_platform
from vllm.usage.usage_lib import UsageContext
from vllm.utils import set_default_torch_num_threads
from vllm.v1.engine import EngineCoreRequest
from vllm.v1.engine.core import EngineCore
from vllm.v1.engine.core_client import (AsyncMPClient, EngineCoreClient,
@@ -141,14 +140,13 @@ def test_engine_core_client(monkeypatch: pytest.MonkeyPatch,
UsageContext.UNKNOWN_CONTEXT)
executor_class = Executor.get_class(vllm_config)

with set_default_torch_num_threads(1):
client = EngineCoreClient.make_client(
multiprocess_mode=multiprocessing_mode,
asyncio_mode=False,
vllm_config=vllm_config,
executor_class=executor_class,
log_stats=False,
)
client = EngineCoreClient.make_client(
multiprocess_mode=multiprocessing_mode,
asyncio_mode=False,
vllm_config=vllm_config,
executor_class=executor_class,
log_stats=False,
)

MAX_TOKENS = 20
params = SamplingParams(max_tokens=MAX_TOKENS)
@@ -228,14 +226,13 @@ async def test_engine_core_client_asyncio(monkeypatch: pytest.MonkeyPatch):
usage_context=UsageContext.UNKNOWN_CONTEXT)
executor_class = Executor.get_class(vllm_config)

with set_default_torch_num_threads(1):
client = EngineCoreClient.make_client(
multiprocess_mode=True,
asyncio_mode=True,
vllm_config=vllm_config,
executor_class=executor_class,
log_stats=True,
)
client = EngineCoreClient.make_client(
multiprocess_mode=True,
asyncio_mode=True,
vllm_config=vllm_config,
executor_class=executor_class,
log_stats=True,
)

try:
MAX_TOKENS = 20
@@ -318,14 +315,13 @@ def test_kv_cache_events(
UsageContext.UNKNOWN_CONTEXT)

executor_class = Executor.get_class(vllm_config)
with set_default_torch_num_threads(1):
client = EngineCoreClient.make_client(
multiprocess_mode=multiprocessing_mode,
asyncio_mode=False,
vllm_config=vllm_config,
executor_class=executor_class,
log_stats=False,
)
client = EngineCoreClient.make_client(
multiprocess_mode=multiprocessing_mode,
asyncio_mode=False,
vllm_config=vllm_config,
executor_class=executor_class,
log_stats=False,
)
endpoint = publisher_config.endpoint.replace("*", "127.0.0.1")
subscriber = MockSubscriber(endpoint,
topic=publisher_config.topic,
@@ -401,14 +397,13 @@ async def test_kv_cache_events_dp(
UsageContext.UNKNOWN_CONTEXT)

executor_class = Executor.get_class(vllm_config)
with set_default_torch_num_threads(1):
client = EngineCoreClient.make_client(
multiprocess_mode=multiprocessing_mode,
asyncio_mode=True,
vllm_config=vllm_config,
executor_class=executor_class,
log_stats=False,
)
client = EngineCoreClient.make_client(
multiprocess_mode=multiprocessing_mode,
asyncio_mode=True,
vllm_config=vllm_config,
executor_class=executor_class,
log_stats=False,
)
await asyncio.sleep(1)

# Build endpoints for all DP ranks
11 changes: 7 additions & 4 deletions vllm/utils.py
@@ -194,10 +194,13 @@
@contextlib.contextmanager
def set_default_torch_num_threads(num_threads: int):
"""Sets the default number of threads for PyTorch to the given value."""
Comment on lines 195 to 196 (Contributor, medium):

The set_default_torch_num_threads context manager now has special behavior when num_threads is -1. This should be documented in the function's docstring to clearly communicate its contract to users (see the usage sketch after this file's diff).

def set_default_torch_num_threads(num_threads: int):
    """Sets the default number of threads for PyTorch to the given value.
    If `num_threads` is -1, no change is made to PyTorch's thread settings.
    """

old_num_threads = torch.get_num_threads()
torch.set_num_threads(num_threads)
yield
torch.set_num_threads(old_num_threads)
if num_threads == -1:
yield
else:
old_num_threads = torch.get_num_threads()
torch.set_num_threads(num_threads)
yield
torch.set_num_threads(old_num_threads)


P = ParamSpec('P')
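A quick usage illustration of the new contract (a minimal sketch, not part of this PR, assuming the vllm.utils version shown in this diff):

import torch

from vllm.utils import set_default_torch_num_threads

# A positive value is applied inside the block and restored afterwards.
before = torch.get_num_threads()
with set_default_torch_num_threads(1):
    assert torch.get_num_threads() == 1
assert torch.get_num_threads() == before

# -1 makes the context manager a no-op: the setting is left untouched.
with set_default_torch_num_threads(-1):
    assert torch.get_num_threads() == before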
8 changes: 6 additions & 2 deletions vllm/v1/engine/core.py
@@ -17,6 +17,7 @@
import msgspec
import zmq

import vllm.envs as envs
from vllm.config import ParallelConfig, VllmConfig
from vllm.distributed import stateless_destroy_torch_distributed_process_group
from vllm.executor.multiproc_worker_utils import _add_prefix
@@ -25,7 +26,8 @@
from vllm.lora.request import LoRARequest
from vllm.transformers_utils.config import (
maybe_register_config_serialize_by_value)
from vllm.utils import make_zmq_socket, resolve_obj_by_qualname
from vllm.utils import (make_zmq_socket, resolve_obj_by_qualname,
set_default_torch_num_threads)
from vllm.v1.core.kv_cache_utils import (get_kv_cache_config,
unify_kv_cache_configs)
from vllm.v1.core.sched.interface import SchedulerInterface
@@ -72,7 +74,9 @@ def __init__(self,
self.log_stats = log_stats

# Setup Model.
self.model_executor = executor_class(vllm_config)
disable_omp = envs.VLLM_WORKER_MULTIPROC_METHOD == "fork"
with set_default_torch_num_threads(1 if disable_omp else -1):
self.model_executor = executor_class(vllm_config)
Comment on lines +77 to +79 (Collaborator Author):

This is a workaround for tests/v1/engine/test_engine_core.py, because that test constructs EngineCore directly and bypasses the with set_default_torch_num_threads guard in the core client (see the sketch after this file's diff).

Note that if we only add with set_default_torch_num_threads here, without the ones in core_client.py, a deadlock can still occur when initializing AsyncLLM.

if executor_fail_callback is not None:
self.model_executor.register_failure_callback(
executor_fail_callback)
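To make the comment above concrete, here is a hypothetical sketch (not from this PR) of the two construction paths it refers to; direct_path and client_path are made-up names, and vllm_config / executor_class are assumed to be built as in tests/v1/engine/test_engine_core.py:

from vllm.v1.engine.core import EngineCore
from vllm.v1.engine.core_client import EngineCoreClient


def direct_path(vllm_config, executor_class):
    # The tests construct EngineCore directly, skipping the core client, so
    # the guard added in EngineCore.__init__ is the only one on this path.
    return EngineCore(vllm_config=vllm_config,
                      executor_class=executor_class,
                      log_stats=True)


def client_path(vllm_config, executor_class):
    # AsyncLLM reaches EngineCore through the client, where the guard around
    # _init_engines_direct in core_client.py also covers the rest of engine
    # initialization.
    return EngineCoreClient.make_client(multiprocess_mode=True,
                                        asyncio_mode=False,
                                        vllm_config=vllm_config,
                                        executor_class=executor_class,
                                        log_stats=False)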
17 changes: 12 additions & 5 deletions vllm/v1/engine/core_client.py
@@ -18,11 +18,12 @@
import zmq
import zmq.asyncio

import vllm.envs as envs
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.utils import (get_open_zmq_inproc_path, make_zmq_socket,
zmq_socket_ctx)
set_default_torch_num_threads, zmq_socket_ctx)
from vllm.v1.engine import (EngineCoreOutputs, EngineCoreRequest,
EngineCoreRequestType, UtilityOutput)
from vllm.v1.engine.coordinator import DPCoordinator
@@ -419,10 +420,16 @@ def __init__(
self.ctx, output_address, zmq.PULL)

if client_addresses is None:
self._init_engines_direct(vllm_config, local_only,
local_start_index, input_address,
output_address, executor_class,
log_stats)
# If we use fork for multiproc, we need to disable OpenMP for
# multimodal models during engine initialization; otherwise the
# multimodal processor will call blocking ops like x.to(dtype),
# which will cause a deadlock with OpenMP.
disable_omp = envs.VLLM_WORKER_MULTIPROC_METHOD == "fork"
with set_default_torch_num_threads(1 if disable_omp else -1):
self._init_engines_direct(vllm_config, local_only,
local_start_index, input_address,
output_address, executor_class,
log_stats)
Comment on lines +423 to +432 (Collaborator Author, @Isotr0py, Jun 23, 2025):

After investigation, I found that the deadlock occurs not only in compute_encoder_budget but also throughout the whole engine initialization process, including model profiling.

Since it's tricky to add with set_default_torch_num_threads everywhere in the executor and engine core that might deadlock, I decided to add it here at a high level for easier maintenance (see the sketch after this file's diff).

coordinator = self.resources.coordinator
if coordinator:
self.stats_update_address = (
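A minimal sketch of the high-level placement described in the comment above (not part of this PR; guarded_engine_init and initialize_engines are hypothetical stand-ins for _init_engines_direct and the work it triggers):

import vllm.envs as envs

from vllm.utils import set_default_torch_num_threads


def guarded_engine_init(initialize_engines, *args, **kwargs):
    # Restrict torch to a single thread only when workers are forked; with
    # spawn, passing -1 makes the context manager a no-op.
    disable_omp = envs.VLLM_WORKER_MULTIPROC_METHOD == "fork"
    with set_default_torch_num_threads(1 if disable_omp else -1):
        # Every initialization step that may dispatch to OpenMP (e.g. a
        # multimodal processor calling x.to(dtype) during profiling) now runs
        # single-threaded, avoiding the fork/OpenMP deadlock.
        return initialize_engines(*args, **kwargs)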