
Commit 4cf6b40

rm score task
Signed-off-by: wang.yuqi <[email protected]>
1 parent 1bcd15e commit 4cf6b40

9 files changed: 117 additions, 37 deletions


docs/models/supported_models.md

Lines changed: 1 addition & 1 deletion
```diff
@@ -462,7 +462,7 @@ Specified using `--task classify`.
 | `JambaForSequenceClassification` | Jamba | `ai21labs/Jamba-tiny-reward-dev`, etc. | ✅︎ | ✅︎ | |
 | `GPT2ForSequenceClassification` | GPT2 | `nie3e/sentiment-polish-gpt2-small` | | | |
 If your model is not in the above list, we will try to automatically convert the model using
-[as_classification_model][vllm.model_executor.models.adapters.as_classification_model]. By default, the class probabilities are extracted from the softmaxed hidden state corresponding to the last token.
+[as_seq_cls_model][vllm.model_executor.models.adapters.as_seq_cls_model]. By default, the class probabilities are extracted from the softmaxed hidden state corresponding to the last token.
 
 #### Sentence Pair Scoring
```
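The auto-converted classifier is used like any native sequence-classification model. A minimal offline sketch (the model name is taken from the table above; `LLM.classify` is vLLM's offline classification entry point):

```python
# Loading with task="classify" triggers the as_seq_cls_model() auto-conversion
# described above for models without a native classification head.
from vllm import LLM

llm = LLM(model="jason9693/Qwen2.5-1.5B-apeach", task="classify")
(output,) = llm.classify("vLLM makes serving classifiers easy.")
print(output.outputs.probs)  # softmaxed per-class probabilities
```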

docs/serving/openai_compatible_server.md

Lines changed: 1 addition & 1 deletion
```diff
@@ -401,7 +401,7 @@ Code example: <gh-file:examples/online_serving/openai_pooling_client.py>
 
 Our Classification API directly supports Hugging Face sequence-classification models such as [ai21labs/Jamba-tiny-reward-dev](https://huggingface.co/ai21labs/Jamba-tiny-reward-dev) and [jason9693/Qwen2.5-1.5B-apeach](https://huggingface.co/jason9693/Qwen2.5-1.5B-apeach).
 
-We automatically wrap any other transformer via `as_classification_model()`, which pools on the last token, attaches a `RowParallelLinear` head, and applies a softmax to produce per-class probabilities.
+We automatically wrap any other transformer via `as_seq_cls_model()`, which pools on the last token, attaches a `RowParallelLinear` head, and applies a softmax to produce per-class probabilities.
 
 Code example: <gh-file:examples/online_serving/openai_classification_client.py>
```
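The online request shape is unchanged by this commit. A hedged sketch of calling the server's Classification API over plain HTTP (endpoint path and payload shape assumed from the docs; the bundled `openai_classification_client.py` example is the authoritative reference):

```python
# Sketch: classify a prompt against a vLLM server started with something like
# `vllm serve jason9693/Qwen2.5-1.5B-apeach --task classify` (assumed flags).
import requests

response = requests.post(
    "http://localhost:8000/classify",
    json={"model": "jason9693/Qwen2.5-1.5B-apeach",
          "input": "vLLM makes serving classifiers easy."},
)
print(response.json())  # per-class probabilities for the input
```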

tests/models/test_registry.py

Lines changed: 4 additions & 4 deletions
```diff
@@ -9,9 +9,9 @@
 from vllm.model_executor.models import (is_pooling_model,
                                         is_text_generation_model,
                                         supports_multimodal)
-from vllm.model_executor.models.adapters import (as_classification_model,
-                                                 as_embedding_model,
-                                                 as_reward_model)
+from vllm.model_executor.models.adapters import (as_embedding_model,
+                                                 as_reward_model,
+                                                 as_seq_cls_model)
 from vllm.model_executor.models.registry import (_MULTIMODAL_MODELS,
                                                  _SPECULATIVE_DECODING_MODELS,
                                                  _TEXT_GENERATION_MODELS,
@@ -38,7 +38,7 @@ def test_registry_imports(model_arch):
         assert is_text_generation_model(model_cls)
 
     # All vLLM models should be convertible to a pooling model
-    assert is_pooling_model(as_classification_model(model_cls))
+    assert is_pooling_model(as_seq_cls_model(model_cls))
     assert is_pooling_model(as_embedding_model(model_cls))
     assert is_pooling_model(as_reward_model(model_cls))
```

tests/test_config.py

Lines changed: 27 additions & 1 deletion
```diff
@@ -85,7 +85,7 @@ def test_get_field():
         ("distilbert/distilgpt2", "generate", "generate"),
         ("intfloat/multilingual-e5-small", "pooling", "embed"),
         ("jason9693/Qwen2.5-1.5B-apeach", "pooling", "classify"),
-        ("cross-encoder/ms-marco-MiniLM-L-6-v2", "pooling", "score"),
+        ("cross-encoder/ms-marco-MiniLM-L-6-v2", "pooling", "classify"),
         ("Qwen/Qwen2.5-Math-RM-72B", "pooling", "reward"),
         ("openai/whisper-small", "transcription", "transcription"),
     ],
@@ -105,6 +105,32 @@ def test_auto_task(model_id, expected_runner_type, expected_task):
     assert config.task == expected_task
 
 
+@pytest.mark.parametrize(
+    ("model_id", "expected_runner_type", "expected_task"),
+    [
+        ("distilbert/distilgpt2", "pooling", "embed"),
+        ("intfloat/multilingual-e5-small", "pooling", "embed"),
+        ("jason9693/Qwen2.5-1.5B-apeach", "pooling", "classify"),
+        ("cross-encoder/ms-marco-MiniLM-L-6-v2", "pooling", "classify"),
+        ("Qwen/Qwen2.5-Math-RM-72B", "pooling", "embed"),
+        ("openai/whisper-small", "pooling", "embed"),
+    ],
+)
+def test_score_task(model_id, expected_runner_type, expected_task):
+    config = ModelConfig(
+        model_id,
+        task="score",
+        tokenizer=model_id,
+        tokenizer_mode="auto",
+        trust_remote_code=False,
+        seed=0,
+        dtype="float16",
+    )
+
+    assert config.runner_type == expected_runner_type
+    assert config.task == expected_task
+
+
 @pytest.mark.parametrize(("model_id", "bad_task"), [
     ("Qwen/Qwen2.5-Math-RM-72B", "generate"),
 ])
```
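The user-visible behavior this test pins down: `task="score"` still works for both model families, it just resolves differently under the hood. A hedged sketch using the offline scoring API (`LLM.score` is vLLM's existing entry point for sentence-pair scoring):

```python
# Sketch: a cross-encoder requested with task="score" now runs as a
# "classify" pooling task internally; the scoring call itself is unchanged.
from vllm import LLM

llm = LLM(model="cross-encoder/ms-marco-MiniLM-L-6-v2", task="score")
(output,) = llm.score("What is the capital of France?",
                      "The capital of France is Paris.")
print(output.outputs.score)  # relevance score for the sentence pair
```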

vllm/config.py

Lines changed: 22 additions & 12 deletions
```diff
@@ -82,14 +82,14 @@
 TaskOption = Literal["auto", "generate", "embedding", "embed", "classify",
                      "score", "reward", "transcription"]
 
-_ResolvedTask = Literal["generate", "embed", "classify", "score", "reward",
-                        "draft", "transcription"]
+_ResolvedTask = Literal["generate", "embed", "classify", "reward", "draft",
+                        "transcription"]
 
 RunnerType = Literal["generate", "pooling", "draft", "transcription"]
 
 _RUNNER_TASKS: dict[RunnerType, list[_ResolvedTask]] = {
     "generate": ["generate"],
-    "pooling": ["embed", "classify", "score", "reward"],
+    "pooling": ["embed", "classify", "reward"],
     "draft": ["draft"],
     "transcription": ["transcription"],
 }
@@ -768,7 +768,7 @@ def _get_preferred_task(
         if get_pooling_config(model_id, self.revision):
             return "embed"
         if self.registry.is_cross_encoder_model(architectures):
-            return "score"
+            return "classify"
         if self.registry.is_transcription_model(architectures):
             return "transcription"
 
@@ -832,14 +832,24 @@ def _resolve_task(
                     "This model supports multiple tasks: %s. "
                     "Defaulting to '%s'.", supported_tasks, selected_task)
         else:
-            # Aliases
-            if task_option == "embedding":
-                msg = ("The 'embedding' task has been renamed to "
-                       "'embed', please use the new name. The old name "
-                       "will be removed in v1.0.")
-                warnings.warn(msg, DeprecationWarning, stacklevel=2)
-
-                task_option = "embed"
+            if task_option == "score":
+                if not runner_support["pooling"]:
+                    msg = (f"This model does not support the '{task_option}' "
+                           f"task. Supported tasks: {supported_tasks}")
+                    raise ValueError(msg)
+                if self.registry.is_cross_encoder_model(architectures):
+                    task_option = "classify"
+                else:
+                    task_option = "embed"
+            else:
+                # Aliases
+                if task_option == "embedding":
+                    msg = ("The 'embedding' task has been renamed to "
+                           "'embed', please use the new name. The old name "
+                           "will be removed in v1.0.")
+                    warnings.warn(msg, DeprecationWarning, stacklevel=2)
+
+                    task_option = "embed"
 
         if task_option not in supported_tasks:
             msg = (
```
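The net effect of these changes: "score" is no longer a resolved task of its own. It is still accepted as user input but immediately rewritten to `classify` for cross-encoders and to `embed` for everything else. A standalone restatement of that dispatch (illustrative only; the real logic lives in `ModelConfig._resolve_task`):

```python
# Illustrative re-statement of the new "score" resolution rule.
def resolve_score_task(is_cross_encoder: bool, supports_pooling: bool) -> str:
    if not supports_pooling:
        raise ValueError("This model does not support the 'score' task.")
    # Cross-encoders classify the concatenated sentence pair directly;
    # other models embed each sentence and score the pair by similarity.
    return "classify" if is_cross_encoder else "embed"

assert resolve_score_task(is_cross_encoder=True, supports_pooling=True) == "classify"
assert resolve_score_task(is_cross_encoder=False, supports_pooling=True) == "embed"
```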

vllm/model_executor/model_loader/utils.py

Lines changed: 5 additions & 2 deletions
```diff
@@ -21,7 +21,7 @@
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig, QuantizeMethodBase)
 from vllm.model_executor.models import ModelRegistry
-from vllm.model_executor.models.adapters import (as_classification_model,
+from vllm.model_executor.models.adapters import (as_seq_cls_model,
                                                  as_embedding_model,
                                                  as_reward_model)
 from vllm.utils import is_pin_memory_available
@@ -245,7 +245,10 @@ def get_model_architecture(
     if model_config.task == "embed":
         model_cls = as_embedding_model(model_cls)
     elif model_config.task == "classify":
-        model_cls = as_classification_model(model_cls)
+        # Cannot automatically run as_seq_cls_model,
+        # otherwise it will cause a circular reference on is_cross_encoder_model
+        from vllm.model_executor.models.interfaces import SupportsCrossEncoding
+        assert isinstance(model_cls, SupportsCrossEncoding)
     elif model_config.task == "reward":
         model_cls = as_reward_model(model_cls)
 
```

vllm/model_executor/models/adapters.py

Lines changed: 50 additions & 13 deletions
```diff
@@ -2,7 +2,7 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
 from collections.abc import Iterable
-from typing import TYPE_CHECKING, Any, Optional, TypeVar
+from typing import TYPE_CHECKING, Any, Optional, TypeVar, Union
 
 import torch
 import torch.nn as nn
@@ -145,9 +145,9 @@ def as_embedding_model(cls: _T) -> _T:
     return ModelForEmbedding  # type: ignore
 
 
-def as_classification_model(cls: _T) -> _T:
+def as_seq_cls_model(cls: _T) -> _T:
     """
-    Subclass an existing vLLM model to support classification.
+    Subclass an existing vLLM model to support classify and score tasks.
 
     By default, the class probabilities are extracted from the softmaxed
     hidden state corresponding to the last token.
@@ -164,7 +164,9 @@ def as_classification_model(cls: _T) -> _T:
     # Lazy import
     from vllm.config import VllmConfig
     from vllm.model_executor.layers.linear import RowParallelLinear
-    from vllm.model_executor.layers.pooler import PoolingType
+    from vllm.model_executor.layers.pooler import PoolerOutput, PoolingType
+    from vllm.model_executor.models.interfaces import SupportsCrossEncoding
+    from vllm.model_executor.pooling_metadata import PoolingMetadata
     from vllm.sequence import IntermediateTensors
 
     from .utils import maybe_prefix
@@ -176,7 +178,8 @@
         default_softmax=True,
     )
 
-    class ModelForClassification(ModelForPooling):
+    class ModelForSequenceClassification(ModelForPooling,
+                                         SupportsCrossEncoding):
 
         def __init__(
             self,
@@ -186,10 +189,15 @@ def __init__(
             **kwargs: Any,
         ) -> None:
             super().__init__(vllm_config=vllm_config, prefix=prefix, **kwargs)
+            self.verify_and_update_config(vllm_config)
 
             config = vllm_config.model_config.hf_config
             quant_config = vllm_config.quant_config
 
+            self.task = vllm_config.model_config.task
+            self.pooling_type = (
+                vllm_config.model_config.pooler_config.pooling_type)
+
             self.score = RowParallelLinear(config.hidden_size,
                                            config.num_labels,
                                            quant_config=quant_config,
@@ -198,24 +206,53 @@
                                            prefix=maybe_prefix(
                                                prefix, "score"))
 
+        def verify_and_update_config(self, vllm_config):
+            # Leave an interface for validating and modifying model_config
+            # for slightly different models
+            pass
+
         def forward(
             self,
             input_ids: torch.Tensor,
             positions: torch.Tensor,
             intermediate_tensors: Optional[IntermediateTensors] = None,
             inputs_embeds: Optional[torch.Tensor] = None,
         ) -> torch.Tensor:
-            hidden_states = super().forward(input_ids, positions,
-                                            intermediate_tensors,
-                                            inputs_embeds)
-            logits, _ = self.score(hidden_states)
-            return logits
+            return super().forward(input_ids, positions, intermediate_tensors,
+                                   inputs_embeds)
+
+        def pooler(
+            self,
+            hidden_states: Union[torch.Tensor, list[torch.Tensor]],
+            pooling_metadata: PoolingMetadata,
+        ) -> PoolerOutput:
+
+            def get_logits(hidden_states):
+                if isinstance(hidden_states, list):
+                    logits = [self.score(state)[0] for state in hidden_states]
+                else:
+                    logits, _ = self.score(hidden_states)
+                return logits
+
+            if self.pooling_type == PoolingType.ALL:
+                logits = get_logits(hidden_states)
+                return self._pooler(logits, pooling_metadata)
+            else:
+                hidden_states = self._pooler.extract_states(
+                    hidden_states, pooling_metadata)
+                logits = get_logits(hidden_states)
+                pooled_data = self._pooler.head(logits, pooling_metadata)
+
+                pooled_outputs = [
+                    self._pooler.build_output(data) for data in pooled_data
+                ]
+                return PoolerOutput(outputs=pooled_outputs)
 
 
-    ModelForClassification.__name__ = \
-        _get_pooling_model_name(cls.__name__, "ForClassification")
+    ModelForSequenceClassification.__name__ = \
+        _get_pooling_model_name(cls.__name__, "ForSequenceClassification")
 
-    return ModelForClassification  # type: ignore
+    return ModelForSequenceClassification  # type: ignore
 
 
 def as_reward_model(cls: _T) -> _T:
```
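The key behavioral change in the adapter: the classification head no longer runs inside `forward()`; it moved into `pooler()`, which must handle both pooling orders. A standalone PyTorch sketch of the two paths (illustrative shapes only, not vLLM code):

```python
import torch

hidden = torch.randn(7, 16)      # 7 tokens across all sequences, hidden size 16
score = torch.nn.Linear(16, 3)   # classification head with 3 labels

# LAST-style path: extract one state per sequence first, then project.
last_token_idx = torch.tensor([2, 6])    # last token of each of 2 sequences
logits = score(hidden[last_token_idx])   # shape (2, 3)
probs = torch.softmax(logits, dim=-1)    # per-class probabilities

# ALL-style path: project every token first; pooling then operates on the
# per-token logits instead of the hidden states.
all_logits = score(hidden)               # shape (7, 3)
```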

vllm/model_executor/models/qwen2.py

Lines changed: 4 additions & 0 deletions
```diff
@@ -50,6 +50,7 @@
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.sequence import IntermediateTensors
 
+from .adapters import as_seq_cls_model
 from .interfaces import SupportsLoRA, SupportsPP
 from .utils import (AutoWeightsLoader, PPMissingLayer, extract_layer_index,
                     is_pp_missing_parameter,
@@ -495,3 +496,6 @@ def load_weights(self, weights: Iterable[tuple[str,
             if self.config.tie_word_embeddings else None),
         )
         return loader.load_weights(weights)
+
+
+Qwen2ForSequenceClassification = as_seq_cls_model(Qwen2ForCausalLM)
```
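This one-liner is now the whole model definition: the adapter subclasses `Qwen2ForCausalLM`, attaches the score head, and the registry entry below points at the generated class. The same pattern could in principle extend to other decoder-only models (a hypothetical sketch; only the Qwen2/Qwen3 wiring is part of this commit):

```python
# Hypothetical: wrapping another CausalLM the same way. Not part of this commit.
from vllm.model_executor.models.adapters import as_seq_cls_model
from vllm.model_executor.models.llama import LlamaForCausalLM

LlamaForSequenceClassification = as_seq_cls_model(LlamaForCausalLM)
```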

vllm/model_executor/models/registry.py

Lines changed: 3 additions & 3 deletions
```diff
@@ -158,8 +158,6 @@
     "LlavaNextForConditionalGeneration": ("llava_next", "LlavaNextForConditionalGeneration"),  # noqa: E501
     "Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"),
     "Qwen2VLForConditionalGeneration": ("qwen2_vl", "Qwen2VLForConditionalGeneration"),  # noqa: E501
-    # [Auto-converted (see adapters.py)]
-    "Qwen2ForSequenceClassification": ("qwen2", "Qwen2ForCausalLM"),
     # Technically PrithviGeoSpatialMAE is a model that works on images, both in
     # input and output. I am adding it here because it piggy-backs on embedding
     # models for the time being.
@@ -174,7 +172,9 @@
                                          "RobertaForSequenceClassification"),
     "ModernBertForSequenceClassification": ("modernbert",
                                             "ModernBertForSequenceClassification"),
-    "Qwen3ForSequenceClassification": ("qwen3", "Qwen3ForSequenceClassification"),  # noqa: E501
+    # [Auto-converted (see adapters.py)]
+    "Qwen2ForSequenceClassification": ("qwen2", "Qwen2ForSequenceClassification"),  # noqa: E501
+    "Qwen3ForSequenceClassification": ("qwen3", "Qwen3ForSequenceClassification"),  # noqa: E501
 }
 
 _MULTIMODAL_MODELS = {
```
