Add support for encoder embedding models #19988
base: main
@@ -716,6 +716,11 @@ def _init_pooler_config(self) -> Optional["PoolerConfig"]:
             self.override_pooler_config = PoolerConfig(
                 **self.override_pooler_config)

+        # WIP: currently cuda graphs are not working for encoder models.
+        logger.warning("CUDA graph is not supported for pooling yet, "
+                       "fallback to the eager mode.")
+        self.enforce_eager = True
Comment on lines +720 to +722

This warning message and the subsequent setting of `self.enforce_eager` apply to every model, not only pooling models. Consider guarding them with the runner type check:

    if self.runner_type == "pooling":
        # WIP: currently cuda graphs are not working for encoder models.
        logger.warning("CUDA graph is not supported for pooling yet, "
                       "fallback to the eager mode.")
        self.enforce_eager = True
         pooler_config = self.override_pooler_config or PoolerConfig()

         base_config = get_pooling_config(self.model, self.revision)
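For context, this PR targets running encoder-only embedding models on the v1 engine. A minimal usage sketch of that path follows; the checkpoint name and the `task="embed"` flag are illustrative assumptions, not taken from this diff:

    from vllm import LLM

    # Any encoder-only embedding model; the checkpoint below is only an example.
    llm = LLM(model="BAAI/bge-base-en-v1.5", task="embed")

    # With this change, the pooling runner falls back to eager mode (no CUDA
    # graphs), and the engine disables prefix caching / chunked prefill when it
    # finds encoder attention layers.
    outputs = llm.embed(["vLLM now supports encoder embedding models."])
    print(len(outputs[0].outputs.embedding))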
@@ -386,11 +386,13 @@ def __init__(
                 f"Supported head sizes are: {support_head_sizes}. "
                 "Set VLLM_USE_V1=0 to use another attention backend.")

-        if attn_type != AttentionType.DECODER:
-            raise NotImplementedError("Encoder self-attention and "
-                                      "encoder/decoder cross-attention "
-                                      "are not implemented for "
+        if attn_type not in [
+                AttentionType.DECODER, AttentionType.ENCODER_ONLY
+        ]:
+            raise NotImplementedError("Encoder/decoder cross-attention "
+                                      "is not implemented for "
                                       "FlashAttentionImpl")
Comment on lines +389 to 394

The error message could state directly which attention types FlashAttentionImpl supports:

    if attn_type not in [
            AttentionType.DECODER, AttentionType.ENCODER_ONLY
    ]:
        raise NotImplementedError("FlashAttentionImpl only supports DECODER and ENCODER_ONLY attention types.")
         self.attn_type = attn_type
         self.use_irope = use_irope
         self.vllm_flash_attn_version = get_flash_attn_version()
         if is_quantized_kv_cache(self.kv_cache_dtype) \
@@ -509,7 +511,7 @@ def forward(
             seqused_k=seqused_k,
             max_seqlen_k=max_seqlen_k,
             softmax_scale=self.scale,
-            causal=True,
+            causal=_get_causal_option(self.attn_type),
             alibi_slopes=self.alibi_slopes,
             window_size=self.sliding_window,
             block_table=block_table,
@@ -711,3 +713,21 @@ def cascade_attention(
     # Merge prefix and suffix outputs, and store the result in output.
     merge_attn_states(output, prefix_output, prefix_lse, suffix_output,
                       suffix_lse)
+
+
+def _get_causal_option(attn_type: str) -> bool:
+    """
+    Determine whether the given attention type is suitable for causal
+    attention mechanisms.
+
+    Args:
+        attn_type (AttentionType): The type of attention being evaluated
+
+    Returns:
+        bool: Returns `True` if the attention type is suitable for causal
+        attention (i.e., not encoder, encoder-only, or encoder-decoder),
+        otherwise returns `False`.
+    """
+    return not (attn_type == AttentionType.ENCODER
+                or attn_type == AttentionType.ENCODER_ONLY
+                or attn_type == AttentionType.ENCODER_DECODER)
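To make the helper's effect on the flash-attention call concrete, here is a small illustrative check; it assumes direct access to this private helper and that `AttentionType` is importable from `vllm.attention`:

    from vllm.attention import AttentionType

    # Decoder self-attention keeps the causal mask; every encoder-style
    # attention type runs bidirectional (non-causal) attention.
    assert _get_causal_option(AttentionType.DECODER) is True
    assert _get_causal_option(AttentionType.ENCODER) is False
    assert _get_causal_option(AttentionType.ENCODER_ONLY) is False
    assert _get_causal_option(AttentionType.ENCODER_DECODER) is False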
@@ -35,7 +35,7 @@
                                             EngineCoreRequestType, UtilityOutput)
 from vllm.v1.engine.mm_input_cache import MirroredProcessingCache
 from vllm.v1.executor.abstract import Executor
-from vllm.v1.kv_cache_interface import KVCacheConfig
+from vllm.v1.kv_cache_interface import AttentionSpec, KVCacheConfig
 from vllm.v1.metrics.stats import SchedulerStats
 from vllm.v1.outputs import ModelRunnerOutput
 from vllm.v1.request import Request, RequestStatus
@@ -150,6 +150,23 @@ def _initialize_kv_caches(
             zip(kv_cache_specs, available_gpu_memory)
         ]

+        for kv_cache_spec_one_worker in kv_cache_specs:
+            for _, spec in kv_cache_spec_one_worker.items():
+                if isinstance(spec, AttentionSpec) and \
+                        spec.attn_type != "decoder":
+                    logger.info("Found non-decoder layer. Disabling "
+                                "prefix cache and chunked prefill")
+                    self.vllm_config.cache_config.\
+                        enable_prefix_caching = False
+                    self.vllm_config.scheduler_config.\
+                        enable_chunked_prefill = False
+                    self.vllm_config.scheduler_config.\
+                        chunked_prefill_enabled = False
+                    self.vllm_config.scheduler_config.\
+                        long_prefill_token_threshold = 0
+                    break
Comment on lines +155 to +168

Consider adding a check to ensure that prefix caching and chunked prefill are not enabled when a non-decoder layer is found. This would provide a more explicit error message to the user, rather than silently disabling the features.
         # Since we use a shared centralized controller, we need the
         # `kv_cache_config` to be consistent across all workers to make sure
         # all the memory operators can be applied to all workers.
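A hedged sketch of the fail-fast behavior the reviewer describes above, reusing the attribute names that already appear in the diff; the exact wording and placement are assumptions, since the reviewer's suggested change is not shown here:

    # Sketch only: raise instead of silently disabling incompatible features.
    if isinstance(spec, AttentionSpec) and spec.attn_type != "decoder":
        if self.vllm_config.cache_config.enable_prefix_caching:
            raise ValueError(
                "Prefix caching is not supported for models with "
                "non-decoder attention layers; please disable it.")
        if self.vllm_config.scheduler_config.enable_chunked_prefill:
            raise ValueError(
                "Chunked prefill is not supported for models with "
                "non-decoder attention layers; please disable it.")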
Consider removing the skip mark for v0 tests, as it seems the models are now supported in both engines.
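For reference, this is the kind of marker the comment points at, shown purely as an illustration; the actual test module and reason string are not part of this capture:

    import pytest

    # Removing a marker like this re-enables the test on the v0 engine.
    @pytest.mark.skip(reason="Encoder embedding models not supported on V0")
    def test_encoder_embedding_model():
        ...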