
Commit 6e7797f

+ as_seq_cls_model
Signed-off-by: wang.yuqi <[email protected]>
1 parent c6703d1 commit 6e7797f

6 files changed: +78 −61 lines changed

This commit renames the `as_classification_model` adapter to `as_seq_cls_model`, routes the `score` task through it alongside `classify`, and reimplements `Qwen3ForSequenceClassification` on top of the adapter.

docs/models/supported_models.md

Lines changed: 1 addition & 1 deletion
@@ -446,7 +446,7 @@ Specified using `--task classify`.
 | `JambaForSequenceClassification` | Jamba | `ai21labs/Jamba-tiny-reward-dev`, etc. | ✅︎ | ✅︎ | |
 
 If your model is not in the above list, we will try to automatically convert the model using
-[as_classification_model][vllm.model_executor.models.adapters.as_classification_model]. By default, the class probabilities are extracted from the softmaxed hidden state corresponding to the last token.
+[as_seq_cls_model][vllm.model_executor.models.adapters.as_seq_cls_model]. By default, the class probabilities are extracted from the softmaxed hidden state corresponding to the last token.
 
 #### Sentence Pair Scoring
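To see what this conversion path looks like from the user side, here is a minimal offline sketch using the `jason9693/Qwen2.5-1.5B-apeach` checkpoint named elsewhere in this commit; `LLM(task="classify")` and `llm.classify()` are existing vLLM APIs, but the exact output field shown is an assumption.

```python
# Minimal sketch: a causal LM checkpoint is auto-converted through
# as_seq_cls_model() when loaded with task="classify".
from vllm import LLM

llm = LLM(model="jason9693/Qwen2.5-1.5B-apeach", task="classify")
(output,) = llm.classify(["vLLM is wonderful!"])
# Per-class probabilities: softmax over the score head applied to the
# last token's hidden state (field name assumed from the pooling API).
print(output.outputs.probs)
```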

docs/serving/openai_compatible_server.md

Lines changed: 1 addition & 1 deletion
@@ -379,7 +379,7 @@ Code example: <gh-file:examples/online_serving/openai_pooling_client.py>
 
 Our Classification API directly supports Hugging Face sequence-classification models such as [ai21labs/Jamba-tiny-reward-dev](https://huggingface.co/ai21labs/Jamba-tiny-reward-dev) and [jason9693/Qwen2.5-1.5B-apeach](https://huggingface.co/jason9693/Qwen2.5-1.5B-apeach).
 
-We automatically wrap any other transformer via `as_classification_model()`, which pools on the last token, attaches a `RowParallelLinear` head, and applies a softmax to produce per-class probabilities.
+We automatically wrap any other transformer via `as_seq_cls_model()`, which pools on the last token, attaches a `RowParallelLinear` head, and applies a softmax to produce per-class probabilities.
 
 Code example: <gh-file:examples/online_serving/openai_classification_client.py>
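A hedged HTTP sketch to go with the linked client example; the `/classify` route and the JSON request shape follow `openai_classification_client.py`, so treat the exact fields as assumptions.

```python
# Hedged sketch of the Classification API over HTTP, assuming a server
# started with: vllm serve jason9693/Qwen2.5-1.5B-apeach
# The /classify route and JSON fields follow the linked example client.
import requests

resp = requests.post(
    "http://localhost:8000/classify",
    json={
        "model": "jason9693/Qwen2.5-1.5B-apeach",
        "input": ["vLLM is wonderful!"],
    },
)
resp.raise_for_status()
print(resp.json())  # per-class probabilities for each input
```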

tests/models/test_registry.py

Lines changed: 4 additions & 4 deletions
@@ -9,9 +9,9 @@
 from vllm.model_executor.models import (is_pooling_model,
                                         is_text_generation_model,
                                         supports_multimodal)
-from vllm.model_executor.models.adapters import (as_classification_model,
-                                                 as_embedding_model,
-                                                 as_reward_model)
+from vllm.model_executor.models.adapters import (as_embedding_model,
+                                                 as_reward_model,
+                                                 as_seq_cls_model)
 from vllm.model_executor.models.registry import (_MULTIMODAL_MODELS,
                                                  _SPECULATIVE_DECODING_MODELS,
                                                  _TEXT_GENERATION_MODELS,
@@ -38,7 +38,7 @@ def test_registry_imports(model_arch):
         assert is_text_generation_model(model_cls)
 
     # All vLLM models should be convertible to a pooling model
-    assert is_pooling_model(as_classification_model(model_cls))
+    assert is_pooling_model(as_seq_cls_model(model_cls))
     assert is_pooling_model(as_embedding_model(model_cls))
     assert is_pooling_model(as_reward_model(model_cls))
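Spelled out as a standalone sketch, this is the invariant the test enforces (the Qwen3 import is an illustrative choice; any registered text-generation model class would do):

```python
# Standalone sketch of the test's invariant: any vLLM model class can be
# wrapped by as_seq_cls_model() into a pooling-capable classifier class.
from vllm.model_executor.models import is_pooling_model
from vllm.model_executor.models.adapters import as_seq_cls_model
from vllm.model_executor.models.qwen3 import Qwen3ForCausalLM  # illustrative

Qwen3SeqCls = as_seq_cls_model(Qwen3ForCausalLM)
assert is_pooling_model(Qwen3SeqCls)
# The adapter renames the generated class via _get_pooling_model_name():
print(Qwen3SeqCls.__name__)  # "Qwen3ForSequenceClassification"
```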

vllm/model_executor/model_loader/utils.py

Lines changed: 5 additions & 5 deletions
@@ -21,9 +21,9 @@
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig, QuantizeMethodBase)
 from vllm.model_executor.models import ModelRegistry
-from vllm.model_executor.models.adapters import (as_classification_model,
-                                                 as_embedding_model,
-                                                 as_reward_model)
+from vllm.model_executor.models.adapters import (as_embedding_model,
+                                                 as_reward_model,
+                                                 as_seq_cls_model)
 from vllm.utils import is_pin_memory_available
 
 logger = init_logger(__name__)
@@ -244,8 +244,8 @@ def get_model_architecture(
     model_cls, arch = ModelRegistry.resolve_model_cls(architectures)
     if model_config.task == "embed":
         model_cls = as_embedding_model(model_cls)
-    elif model_config.task == "classify":
-        model_cls = as_classification_model(model_cls)
+    elif model_config.task in ["classify", "score"]:
+        model_cls = as_seq_cls_model(model_cls)
     elif model_config.task == "reward":
         model_cls = as_reward_model(model_cls)
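The behavioral change is that `score` now routes through the same adapter as `classify`. A hypothetical helper (not vLLM code) mirroring the dispatch above:

```python
# Hypothetical helper mirroring get_model_architecture()'s dispatch:
# "score" now shares the sequence-classification adapter with "classify".
from vllm.model_executor.models.adapters import (as_embedding_model,
                                                 as_reward_model,
                                                 as_seq_cls_model)

def convert_for_task(model_cls, task: str):
    if task == "embed":
        return as_embedding_model(model_cls)
    if task in ("classify", "score"):
        return as_seq_cls_model(model_cls)
    if task == "reward":
        return as_reward_model(model_cls)
    return model_cls  # generation tasks use the class unchanged
```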

vllm/model_executor/models/adapters.py

Lines changed: 49 additions & 13 deletions
@@ -145,9 +145,9 @@ def as_embedding_model(cls: _T) -> _T:
     return ModelForEmbedding  # type: ignore
 
 
-def as_classification_model(cls: _T) -> _T:
+def as_seq_cls_model(cls: _T) -> _T:
     """
-    Subclass an existing vLLM model to support classification.
+    Subclass an existing vLLM model to support classify and score tasks.
 
     By default, the class probabilities are extracted from the softmaxed
     hidden state corresponding to the last token.
@@ -164,19 +164,23 @@ def as_classification_model(cls: _T) -> _T:
     # Lazy import
     from vllm.config import VllmConfig
     from vllm.model_executor.layers.linear import RowParallelLinear
-    from vllm.model_executor.layers.pooler import PoolingType
+    from vllm.model_executor.layers.pooler import PoolerOutput, PoolingType
+    from vllm.model_executor.models.interfaces import SupportsCrossEncoding
+    from vllm.model_executor.pooling_metadata import PoolingMetadata
     from vllm.sequence import IntermediateTensors
 
     from .utils import maybe_prefix
 
     ModelForPooling = _create_pooling_model_cls(
         cls,
-        default_pooling_type=PoolingType.LAST,
+        default_pooling_type=getattr(cls, "default_pooling_type",
+                                     PoolingType.LAST),
         default_normalize=False,
         default_softmax=True,
     )
 
-    class ModelForClassification(ModelForPooling):
+    class ModelForSequenceClassification(ModelForPooling,
+                                         SupportsCrossEncoding):
 
         def __init__(
             self,
@@ -186,10 +190,18 @@ def __init__(
             **kwargs: Any,
         ) -> None:
             super().__init__(vllm_config=vllm_config, prefix=prefix, **kwargs)
+            self.config_verify(vllm_config)
 
             config = vllm_config.model_config.hf_config
             quant_config = vllm_config.quant_config
 
+            self.task = vllm_config.model_config.task
+            self.pooling_type = (
+                vllm_config.model_config.pooler_config.pooling_type)
+
+            if self.task == "score":
+                assert config.num_labels == 1
+
             self.score = RowParallelLinear(config.hidden_size,
                                            config.num_labels,
                                            quant_config=quant_config,
@@ -198,24 +210,48 @@ def __init__(
                                            prefix=maybe_prefix(
                                                prefix, "score"))
 
+        def config_verify(self, vllm_config):
+            # Leave an interface for validating and modifying model_config
+            # for slightly different models
+            pass
+
         def forward(
             self,
             input_ids: torch.Tensor,
             positions: torch.Tensor,
             intermediate_tensors: Optional[IntermediateTensors] = None,
             inputs_embeds: Optional[torch.Tensor] = None,
         ) -> torch.Tensor:
-            hidden_states = super().forward(input_ids, positions,
-                                            intermediate_tensors,
-                                            inputs_embeds)
-            logits, _ = self.score(hidden_states)
-            return logits
+            return super().forward(input_ids, positions, intermediate_tensors,
+                                   inputs_embeds)
+
+        def pooler(
+            self,
+            hidden_states: torch.Tensor,
+            pooling_metadata: PoolingMetadata,
+        ) -> PoolerOutput:
+            if self.pooling_type == PoolingType.ALL:
+                logits, _ = self.score(hidden_states)
+                return self._pooler(hidden_states, pooling_metadata)
+            else:
+                hidden_states = self._pooler.extract_states(
+                    hidden_states, pooling_metadata)
+                logits, _ = self.score(hidden_states)
+                pooled_data = self._pooler.head(logits, pooling_metadata)
+
+            if self.task == "score":
+                pooled_data = [data.squeeze(-1) for data in pooled_data]
+
+            pooled_outputs = [
+                self._pooler.build_output(data) for data in pooled_data
+            ]
+            return PoolerOutput(outputs=pooled_outputs)
 
 
-    ModelForClassification.__name__ = \
-        _get_pooling_model_name(cls.__name__, "ForClassification")
+    ModelForSequenceClassification.__name__ = \
+        _get_pooling_model_name(cls.__name__, "ForSequenceClassification")
 
-    return ModelForClassification  # type: ignore
+    return ModelForSequenceClassification  # type: ignore
 
 
 def as_reward_model(cls: _T) -> _T:
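To make the default pooling math concrete, here is a plain-PyTorch sketch (illustrative only, not the vLLM pooler) of the classify path the adapter builds: LAST-token pooling, the `score` head, then softmax.

```python
# Illustrative stand-in for PoolingType.LAST + RowParallelLinear("score")
# + default_softmax=True; not the vLLM implementation itself.
import torch

def classify_last_token(hidden_states: torch.Tensor,
                        score_weight: torch.Tensor) -> torch.Tensor:
    # hidden_states: [seq_len, hidden_size]
    # score_weight:  [num_labels, hidden_size]
    pooled = hidden_states[-1]            # LAST-token pooling
    logits = pooled @ score_weight.t()    # the "score" head
    return torch.softmax(logits, dim=-1)  # per-class probabilities

probs = classify_last_token(torch.randn(16, 64), torch.randn(3, 64))
assert torch.isclose(probs.sum(), torch.tensor(1.0))
```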

vllm/model_executor/models/qwen3.py

Lines changed: 18 additions & 37 deletions
@@ -38,15 +38,15 @@
 from vllm.model_executor.layers.linear import (QKVParallelLinear,
                                                RowParallelLinear)
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
-from vllm.model_executor.layers.pooler import Pooler, PoolingType
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
 from vllm.model_executor.pooling_metadata import PoolingMetadata
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.sequence import IntermediateTensors, PoolerOutput
 
-from .interfaces import SupportsCrossEncoding, SupportsLoRA, SupportsPP
+from .adapters import as_seq_cls_model
+from .interfaces import SupportsLoRA, SupportsPP
 from .qwen2 import Qwen2MLP as Qwen3MLP
 from .qwen2 import Qwen2Model
 from .utils import AutoWeightsLoader, PPMissingLayer, maybe_prefix
@@ -323,38 +323,31 @@ def load_weights(self, weights: Iterable[tuple[str,
         return loader.load_weights(weights)
 
 
-class Qwen3ForSequenceClassification(nn.Module, SupportsLoRA,
-                                     SupportsCrossEncoding):
+class Qwen3ForSequenceClassification(as_seq_cls_model(Qwen3ForCausalLM)):
 
     def __init__(
         self,
         vllm_config: "VllmConfig",
         prefix: str = "",
     ) -> None:
-        super().__init__()
+        super().__init__(vllm_config=vllm_config, prefix=prefix)
 
+    def config_verify(self, vllm_config: "VllmConfig"):
         config = vllm_config.model_config.hf_config
-        quant_config = vllm_config.quant_config
-        pooler_config = vllm_config.model_config.pooler_config
 
+        is_original_qwen3_reranker = getattr(config,
+                                             "is_original_qwen3_reranker",
+                                             False)
+
+        if not is_original_qwen3_reranker:
+            return
+
+        tokens = getattr(config, "classifier_from_token", None)
+        assert tokens is not None and len(tokens) == 2, \
+            ("Try loading the original Qwen3 Reranker?, see: "
+             "https://github.com/vllm-project/vllm/tree/main/examples/offline_inference/qwen3_reranker.py")
+        config.num_labels = 1
         self.vllm_config = vllm_config
-        self.config = config
-        self.quant_config = quant_config
-        self.prefix = prefix
-        self.model = Qwen3Model(vllm_config=vllm_config,
-                                prefix=maybe_prefix(prefix, "model"))
-        self.score = RowParallelLinear(config.hidden_size,
-                                       config.num_labels,
-                                       quant_config=quant_config,
-                                       input_is_parallel=False,
-                                       bias=False,
-                                       prefix=maybe_prefix(prefix, "score"))
-
-        self._pooler = Pooler.from_config_with_defaults(
-            pooler_config,
-            pooling_type=PoolingType.LAST,
-            normalize=False,
-            softmax=True)
 
     def forward(
         self,
@@ -395,22 +388,10 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
 
     def load_weights_from_original_qwen3_reranker(
             self, weights: Iterable[tuple[str, torch.Tensor]]):
-        tokens = getattr(self.config, "classifier_from_token", None)
-        assert tokens is not None and len(tokens) == 2, \
-            ("Try loading the original Qwen3 Reranker?, see: "
-             "https://github.com/vllm-project/vllm/tree/main/examples/offline_inference/qwen3_reranker.py")
 
-        self.config.num_labels = 1
         model_config = self.vllm_config.model_config
-
+        tokens = getattr(self.config, "classifier_from_token", None)
         device = self.score.weight.device
-        self.score = RowParallelLinear(self.config.hidden_size,
-                                       self.config.num_labels,
-                                       quant_config=self.quant_config,
-                                       input_is_parallel=False,
-                                       bias=False,
-                                       prefix=maybe_prefix(
-                                           self.prefix, "score")).to(device)
 
         if self.config.tie_word_embeddings:
             self.lm_head = self.model.embed_tokens
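Putting the Qwen3 changes together, a hedged sketch of loading the original Qwen3 Reranker through this class; the `hf_overrides` keys mirror the config attributes checked in `config_verify()` and the linked `qwen3_reranker.py` example, so treat the exact model name and values as assumptions.

```python
# Hedged sketch: point vLLM at the original Qwen3 Reranker weights.
# Override keys mirror config_verify() above and the linked example;
# exact values are assumptions, not verified against this commit.
from vllm import LLM

llm = LLM(
    model="Qwen/Qwen3-Reranker-0.6B",
    task="score",
    hf_overrides={
        "architectures": ["Qwen3ForSequenceClassification"],
        "classifier_from_token": ["no", "yes"],  # must be exactly 2 tokens
        "is_original_qwen3_reranker": True,
    },
)
scores = llm.score("What is a panda?",
                   "The giant panda is a bear species endemic to China.")
print(scores[0].outputs.score)
```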
