-
-
Notifications
You must be signed in to change notification settings - Fork 8.3k
[V1] Solve potential deadlock issue in v1 engine core client internally #19927
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -17,6 +17,7 @@ | |
import msgspec | ||
import zmq | ||
|
||
import vllm.envs as envs | ||
from vllm.config import ParallelConfig, VllmConfig | ||
from vllm.distributed import stateless_destroy_torch_distributed_process_group | ||
from vllm.executor.multiproc_worker_utils import _add_prefix | ||
|
@@ -25,7 +26,8 @@ | |
from vllm.lora.request import LoRARequest | ||
from vllm.transformers_utils.config import ( | ||
maybe_register_config_serialize_by_value) | ||
from vllm.utils import make_zmq_socket, resolve_obj_by_qualname | ||
from vllm.utils import (make_zmq_socket, resolve_obj_by_qualname, | ||
set_default_torch_num_threads) | ||
from vllm.v1.core.kv_cache_utils import (get_kv_cache_config, | ||
unify_kv_cache_configs) | ||
from vllm.v1.core.sched.interface import SchedulerInterface | ||
|
@@ -72,7 +74,9 @@ def __init__(self, | |
self.log_stats = log_stats | ||
|
||
# Setup Model. | ||
self.model_executor = executor_class(vllm_config) | ||
disable_omp = envs.VLLM_WORKER_MULTIPROC_METHOD == "fork" | ||
with set_default_torch_num_threads(1 if disable_omp else -1): | ||
self.model_executor = executor_class(vllm_config) | ||
Comment on lines
+77
to
+79
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is a workaround for Note that if we only add |
||
if executor_fail_callback is not None: | ||
self.model_executor.register_failure_callback( | ||
executor_fail_callback) | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -18,11 +18,12 @@ | |
import zmq | ||
import zmq.asyncio | ||
|
||
import vllm.envs as envs | ||
from vllm.config import VllmConfig | ||
from vllm.logger import init_logger | ||
from vllm.lora.request import LoRARequest | ||
from vllm.utils import (get_open_zmq_inproc_path, make_zmq_socket, | ||
zmq_socket_ctx) | ||
set_default_torch_num_threads, zmq_socket_ctx) | ||
from vllm.v1.engine import (EngineCoreOutputs, EngineCoreRequest, | ||
EngineCoreRequestType, UtilityOutput) | ||
from vllm.v1.engine.coordinator import DPCoordinator | ||
|
@@ -419,10 +420,16 @@ def __init__( | |
self.ctx, output_address, zmq.PULL) | ||
|
||
if client_addresses is None: | ||
self._init_engines_direct(vllm_config, local_only, | ||
local_start_index, input_address, | ||
output_address, executor_class, | ||
log_stats) | ||
# If we use fork for multiproc, we need to disable OpenMP for | ||
# multimodal model during engine initialization, otherwise | ||
# multimodal processor will call blocking ops like x.to(dtype) | ||
# which will cause deadlock with OpenMP. | ||
disable_omp = envs.VLLM_WORKER_MULTIPROC_METHOD == "fork" | ||
with set_default_torch_num_threads(1 if disable_omp else -1): | ||
self._init_engines_direct(vllm_config, local_only, | ||
local_start_index, input_address, | ||
output_address, executor_class, | ||
log_stats) | ||
Comment on lines
+423
to
+432
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. After investigation, I found the deadlock not only occurred in Since it's tricky to add |
||
coordinator = self.resources.coordinator | ||
if coordinator: | ||
self.stats_update_address = ( | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The `set_default_torch_num_threads` context manager now has a special behavior when `num_threads` is -1. This should be documented in the function's docstring to clearly communicate its contract to users.