diff --git a/docs/models/supported_models.md b/docs/models/supported_models.md index 803d2938d2b..2c343e3d6c4 100644 --- a/docs/models/supported_models.md +++ b/docs/models/supported_models.md @@ -447,7 +447,7 @@ Specified using `--task classify`. | `JambaForSequenceClassification` | Jamba | `ai21labs/Jamba-tiny-reward-dev`, etc. | ✅︎ | ✅︎ | | If your model is not in the above list, we will try to automatically convert the model using -[as_classification_model][vllm.model_executor.models.adapters.as_classification_model]. By default, the class probabilities are extracted from the softmaxed hidden state corresponding to the last token. +[as_seq_cls_model][vllm.model_executor.models.adapters.as_seq_cls_model]. By default, the class probabilities are extracted from the softmaxed hidden state corresponding to the last token. #### Sentence Pair Scoring diff --git a/docs/serving/openai_compatible_server.md b/docs/serving/openai_compatible_server.md index 3002b2f92e4..91c02d76914 100644 --- a/docs/serving/openai_compatible_server.md +++ b/docs/serving/openai_compatible_server.md @@ -379,7 +379,7 @@ Code example: Our Classification API directly supports Hugging Face sequence-classification models such as [ai21labs/Jamba-tiny-reward-dev](https://huggingface.co/ai21labs/Jamba-tiny-reward-dev) and [jason9693/Qwen2.5-1.5B-apeach](https://huggingface.co/jason9693/Qwen2.5-1.5B-apeach). -We automatically wrap any other transformer via `as_classification_model()`, which pools on the last token, attaches a `RowParallelLinear` head, and applies a softmax to produce per-class probabilities. +We automatically wrap any other transformer via `as_seq_cls_model()`, which pools on the last token, attaches a `RowParallelLinear` head, and applies a softmax to produce per-class probabilities. 
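For context, a minimal sketch of exercising this auto-wrapping path against the server's Classification API. The model name and port are assumptions (a server started with something like `vllm serve jason9693/Qwen2.5-1.5B-apeach --task classify`), and the exact response schema may differ across versions, so the sketch just prints the raw JSON:

```python
# Hedged sketch: query the OpenAI-compatible server's /classify endpoint.
# Model name, host, and port are assumptions; adjust to your deployment.
import requests

response = requests.post(
    "http://localhost:8000/classify",
    json={
        "model": "jason9693/Qwen2.5-1.5B-apeach",
        "input": ["This comment is friendly and helpful."],
    },
)
response.raise_for_status()

# Each result should carry the per-class probabilities produced by the
# softmaxed classification head described above.
print(response.json())
```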
Code example: diff --git a/examples/offline_inference/convert_model_to_seq_cls.py b/examples/offline_inference/convert_model_to_seq_cls.py new file mode 100644 index 00000000000..77959b77f37 --- /dev/null +++ b/examples/offline_inference/convert_model_to_seq_cls.py @@ -0,0 +1,131 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# ruff: noqa: E501 + +import argparse +import json + +import torch +import transformers + +# Usage: +# for Qwen3-Reranker +# python convert_model_to_seq_cls.py --model_name Qwen/Qwen3-Reranker-0.6B --classifier_from_tokens '["no", "yes"]' --method from_2_way_softmax --path ./Qwen3-Reranker-0.6B-seq-cls +# for BAAI/bge-reranker-v2-gemma +# python convert_model_to_seq_cls.py --model_name BAAI/bge-reranker-v2-gemma --classifier_from_tokens '["Yes"]' --method no_post_processing --path ./bge-reranker-v2-gemma-seq-cls +# Caution: "Yes" and "yes" are two different tokens + + +def from_2_way_softmax( + causal_lm, seq_cls_model, tokenizer, classifier_from_tokens, device +): + # for Qwen3-Reranker + # Adapted from https://huggingface.co/Qwen/Qwen3-Reranker-0.6B/discussions/3 + lm_head_weights = causal_lm.lm_head.weight + + a = tokenizer.convert_tokens_to_ids(classifier_from_tokens[0]) + b = tokenizer.convert_tokens_to_ids(classifier_from_tokens[1]) + + score_weight = lm_head_weights[b].to(torch.float32).to(device).to( + torch.float32 + ) - lm_head_weights[a].to(device) + + with torch.no_grad(): + seq_cls_model.score.weight.copy_(score_weight.unsqueeze(0)) + if seq_cls_model.score.bias is not None: + seq_cls_model.score.bias.zero_() + + +def no_post_processing( + causal_lm, seq_cls_model, tokenizer, classifier_from_tokens, device +): + # for BAAI/bge-reranker-v2-gemma + + lm_head_weights = causal_lm.lm_head.weight + tokens = [tokenizer.convert_tokens_to_ids(t) for t in classifier_from_tokens] + score_weight = lm_head_weights[tokens].to(device) + + with torch.no_grad(): + seq_cls_model.score.weight.copy_(score_weight) + if seq_cls_model.score.bias is not None: + seq_cls_model.score.bias.zero_() + + +method_map = { + function.__name__: function for function in [from_2_way_softmax, no_post_processing] +} + + +def converting( + model_name, classifier_from_tokens, path, method, use_pad_token=False, device="cpu" +): + assert method in method_map + + if method == "from_2_way_softmax": + assert len(classifier_from_tokens) == 2 + num_labels = 1 + else: + num_labels = len(classifier_from_tokens) + + tokenizer = transformers.AutoTokenizer.from_pretrained(model_name) + causal_lm = transformers.AutoModelForCausalLM.from_pretrained( + model_name, device_map=device + ) + + seq_cls_model = transformers.AutoModelForSequenceClassification.from_pretrained( + model_name, + num_labels=num_labels, + ignore_mismatched_sizes=True, + device_map=device, + ) + + method_map[method]( + causal_lm, seq_cls_model, tokenizer, classifier_from_tokens, device + ) + + seq_cls_model.config.pad_token_id = tokenizer.pad_token_id + seq_cls_model.config.use_pad_token = use_pad_token + + seq_cls_model.save_pretrained(path) + tokenizer.save_pretrained(path) + + +def parse_args(): + parser = argparse.ArgumentParser( + description="Converting *ForCausalLM models to " + "*ForSequenceClassification models." 
+    )
+    parser.add_argument(
+        "--model_name", type=str, default="Qwen/Qwen3-Reranker-0.6B", help="Model name"
+    )
+    parser.add_argument(
+        "--classifier_from_tokens",
+        type=str,
+        default='["no", "yes"]',
+        help="JSON list of tokens to build the classifier head from",
+    )
+    parser.add_argument(
+        "--use-pad-token", action="store_true", help="Whether to use pad_token"
+    )
+    parser.add_argument(
+        "--method", type=str, default="from_2_way_softmax", help="Conversion method"
+    )
+    parser.add_argument(
+        "--path",
+        type=str,
+        default="./converted_model",
+        help="Path to save the converted model",
+    )
+    return parser.parse_args()
+
+
+if __name__ == "__main__":
+    args = parse_args()
+
+    converting(
+        model_name=args.model_name,
+        classifier_from_tokens=json.loads(args.classifier_from_tokens),
+        method=args.method,
+        use_pad_token=args.use_pad_token,
+        path=args.path,
+    )
diff --git a/examples/offline_inference/qwen3_reranker.py b/examples/offline_inference/qwen3_reranker.py
index 27c4071bf09..4230368c1c0 100644
--- a/examples/offline_inference/qwen3_reranker.py
+++ b/examples/offline_inference/qwen3_reranker.py
@@ -19,6 +19,11 @@
 # concise, for example.
 # model = LLM(model="tomaarsen/Qwen3-Reranker-0.6B-seq-cls", task="score")
 
+# To convert the official original checkpoint into a sequence classification
+# model offline, please refer to: convert_model_to_seq_cls.py
+# The converted model is then initialized as follows.
+# model = LLM(model="path_to/converted_model", task="score")
+
 # If you want to load the official original version, the init parameters are
 # as follows.
 
diff --git a/tests/entrypoints/openai/correctness/test_mteb_score.py b/tests/entrypoints/openai/correctness/test_mteb_score.py
index f90fc0b9be0..05e953de4a0 100644
--- a/tests/entrypoints/openai/correctness/test_mteb_score.py
+++ b/tests/entrypoints/openai/correctness/test_mteb_score.py
@@ -6,19 +6,16 @@
 
 # yapf conflicts with isort for this block
 # yapf: disable
-from tests.models.language.pooling.mteb_utils import (MTEB_RERANK_LANGS,
-                                                      MTEB_RERANK_TASKS,
-                                                      MTEB_RERANK_TOL,
-                                                      RerankClientMtebEncoder,
-                                                      ScoreClientMtebEncoder,
-                                                      run_mteb_rerank)
+from tests.models.language.pooling.mteb_utils import (
+    MTEB_RERANK_LANGS, MTEB_RERANK_TASKS, MTEB_RERANK_TOL,
+    RerankClientMtebEncoder, ScoreClientMtebEncoder,
+    mteb_test_rerank_models_hf, run_mteb_rerank)
 # yapf: enable
 from tests.utils import RemoteOpenAIServer
 
 os.environ["VLLM_LOGGING_LEVEL"] = "WARNING"
 
 MODEL_NAME = "cross-encoder/ms-marco-MiniLM-L-6-v2"
-MAIN_SCORE = 0.33437
 
 
 @pytest.fixture(scope="module")
@@ -31,12 +28,19 @@ def server():
         yield remote_server
 
 
-def test_mteb_score(server):
+@pytest.fixture(scope="module")
+def st_main_score(hf_runner):
+    # The main score depends on the version of the installed dependencies,
+    # so we need to recalculate it on every run.
+ main_score, st_dtype = mteb_test_rerank_models_hf(hf_runner, MODEL_NAME) + return main_score + + +def test_mteb_score(server, st_main_score): url = server.url_for("score") encoder = ScoreClientMtebEncoder(MODEL_NAME, url) vllm_main_score = run_mteb_rerank(encoder, MTEB_RERANK_TASKS, MTEB_RERANK_LANGS) - st_main_score = MAIN_SCORE print("VLLM main score: ", vllm_main_score) print("SentenceTransformer main score: ", st_main_score) @@ -45,12 +49,11 @@ def test_mteb_score(server): assert st_main_score == pytest.approx(vllm_main_score, abs=MTEB_RERANK_TOL) -def test_mteb_rerank(server): +def test_mteb_rerank(server, st_main_score): url = server.url_for("rerank") encoder = RerankClientMtebEncoder(MODEL_NAME, url) vllm_main_score = run_mteb_rerank(encoder, MTEB_RERANK_TASKS, MTEB_RERANK_LANGS) - st_main_score = MAIN_SCORE print("VLLM main score: ", vllm_main_score) print("SentenceTransformer main score: ", st_main_score) diff --git a/tests/entrypoints/openai/test_classification.py b/tests/entrypoints/openai/test_classification.py index 6d5f925152c..15d72519c98 100644 --- a/tests/entrypoints/openai/test_classification.py +++ b/tests/entrypoints/openai/test_classification.py @@ -18,6 +18,8 @@ def server(): "--enforce-eager", "--max-model-len", "512", + "--task", + "classify", "--dtype", DTYPE, ] diff --git a/tests/models/language/pooling/mteb_utils.py b/tests/models/language/pooling/mteb_utils.py index 21d55c418c3..f3c92926643 100644 --- a/tests/models/language/pooling/mteb_utils.py +++ b/tests/models/language/pooling/mteb_utils.py @@ -234,6 +234,35 @@ def run_mteb_rerank(cross_encoder, tasks, languages): return main_score +def mteb_test_rerank_models_hf(hf_runner, model_name, hf_model_callback=None): + with hf_runner(model_name, is_cross_encoder=True, + dtype="float32") as hf_model: + + original_predict = hf_model.predict + + def _predict( + sentences: list[tuple[str, str, + Optional[str]]], # query, corpus, prompt + *args, + **kwargs, + ): + # vllm and st both remove the prompt, fair comparison. + prompts = [(s[0], s[1]) for s in sentences] + return original_predict(prompts, *args, **kwargs, batch_size=8) + + hf_model.predict = _predict + hf_model.original_predict = original_predict + + if hf_model_callback is not None: + hf_model_callback(hf_model) + + st_main_score = run_mteb_rerank(hf_model, + tasks=MTEB_RERANK_TASKS, + languages=MTEB_RERANK_LANGS) + st_dtype = next(hf_model.model.model.parameters()).dtype + return st_main_score, st_dtype + + def mteb_test_rerank_models(hf_runner, vllm_runner, model_info: RerankModelInfo, @@ -261,31 +290,8 @@ def mteb_test_rerank_models(hf_runner, languages=MTEB_RERANK_LANGS) vllm_dtype = vllm_model.model.llm_engine.model_config.dtype - with hf_runner(model_info.name, is_cross_encoder=True, - dtype="float32") as hf_model: - - original_predict = hf_model.predict - - def _predict( - sentences: list[tuple[str, str, - Optional[str]]], # query, corpus, prompt - *args, - **kwargs, - ): - # vllm and st both remove the prompt, fair comparison. 
- prompts = [(s[0], s[1]) for s in sentences] - return original_predict(prompts, *args, **kwargs, batch_size=8) - - hf_model.predict = _predict - hf_model.original_predict = original_predict - - if hf_model_callback is not None: - hf_model_callback(hf_model) - - st_main_score = run_mteb_rerank(hf_model, - tasks=MTEB_RERANK_TASKS, - languages=MTEB_RERANK_LANGS) - st_dtype = next(hf_model.model.model.parameters()).dtype + st_main_score, st_dtype = mteb_test_rerank_models_hf( + hf_runner, model_info.name, hf_model_callback) print("VLLM:", vllm_dtype, vllm_main_score) print("SentenceTransformers:", st_dtype, st_main_score) diff --git a/tests/models/language/pooling/test_classification.py b/tests/models/language/pooling/test_classification.py index 77df6d16a36..d62ba17891b 100644 --- a/tests/models/language/pooling/test_classification.py +++ b/tests/models/language/pooling/test_classification.py @@ -37,7 +37,8 @@ def test_models( # switch to use ROCm CK FA backend monkeypatch.setenv("VLLM_USE_TRITON_FLASH_ATTN", "False") - with vllm_runner(model, max_model_len=512, dtype=dtype) as vllm_model: + with vllm_runner(model, max_model_len=512, dtype=dtype, + task="classify") as vllm_model: vllm_outputs = vllm_model.classify(example_prompts) with hf_runner(model, diff --git a/tests/models/language/pooling/test_gte.py b/tests/models/language/pooling/test_gte.py index 6a3a0f150b6..a1c1d96a52f 100644 --- a/tests/models/language/pooling/test_gte.py +++ b/tests/models/language/pooling/test_gte.py @@ -33,10 +33,10 @@ enable_test=True), EmbedModelInfo("Alibaba-NLP/gte-base-en-v1.5", architecture="GteNewModel", - enable_test=True), + enable_test=False), EmbedModelInfo("Alibaba-NLP/gte-large-en-v1.5", architecture="GteNewModel", - enable_test=True), + enable_test=False), ########### Qwen2ForCausalLM EmbedModelInfo("Alibaba-NLP/gte-Qwen2-1.5B-instruct", architecture="Qwen2ForCausalLM", diff --git a/tests/models/test_registry.py b/tests/models/test_registry.py index b7527ca2706..01b2260abe8 100644 --- a/tests/models/test_registry.py +++ b/tests/models/test_registry.py @@ -9,9 +9,9 @@ from vllm.model_executor.models import (is_pooling_model, is_text_generation_model, supports_multimodal) -from vllm.model_executor.models.adapters import (as_classification_model, - as_embedding_model, - as_reward_model) +from vllm.model_executor.models.adapters import (as_embedding_model, + as_reward_model, + as_seq_cls_model) from vllm.model_executor.models.registry import (_MULTIMODAL_MODELS, _SPECULATIVE_DECODING_MODELS, _TEXT_GENERATION_MODELS, @@ -38,7 +38,7 @@ def test_registry_imports(model_arch): assert is_text_generation_model(model_cls) # All vLLM models should be convertible to a pooling model - assert is_pooling_model(as_classification_model(model_cls)) + assert is_pooling_model(as_seq_cls_model(model_cls)) assert is_pooling_model(as_embedding_model(model_cls)) assert is_pooling_model(as_reward_model(model_cls)) diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index 05e0be61ada..ebd5ee519e0 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -1178,7 +1178,6 @@ def _cross_encoding_score( lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None, ) -> list[ScoringRequestOutput]: - if isinstance(tokenizer, MistralTokenizer): raise ValueError( "Score API is only enabled for `--task embed or score`") @@ -1189,6 +1188,8 @@ def _cross_encoding_score( input_pairs = [(t1, t2) for t1, t2 in zip(text_1, text_2)] pooling_params 
= PoolingParams() + use_pad_token = getattr(self.llm_engine.model_config.hf_config, + "use_pad_token", True) tokenization_kwargs: dict[str, Any] = {} _validate_truncation_size(self.llm_engine.model_config.max_model_len, @@ -1197,9 +1198,12 @@ def _cross_encoding_score( parsed_prompts = [] for q, t in input_pairs: - prompt_inputs = tokenizer(text=q, - text_pair=t, - **tokenization_kwargs) + if use_pad_token: + prompt_inputs = tokenizer(text=q, + text_pair=t, + **tokenization_kwargs) + else: + prompt_inputs = tokenizer(text=q + t, **tokenization_kwargs) engine_prompt = TokensPrompt( prompt_token_ids=prompt_inputs["input_ids"], token_type_ids=prompt_inputs.get("token_type_ids")) diff --git a/vllm/entrypoints/openai/serving_score.py b/vllm/entrypoints/openai/serving_score.py index 9f333c02ab5..59d2b6622b4 100644 --- a/vllm/entrypoints/openai/serving_score.py +++ b/vllm/entrypoints/openai/serving_score.py @@ -173,8 +173,12 @@ async def _cross_encoding_score( *(tokenize_async(text=t1, text_pair=t2, **tokenization_kwargs) for t1, t2 in input_pairs)) + use_pad_token = getattr(self.model_config.hf_config, "use_pad_token", + True) + for prompt_inputs, (t1, t2) in zip(tokenized_prompts, input_pairs): - sep_token = tokenizer.sep_token if tokenizer.sep_token else '' + sep_token = tokenizer.sep_token if (tokenizer.sep_token + and use_pad_token) else '' request_prompt = f"{t1}{sep_token}{t2}" input_ids = prompt_inputs["input_ids"] diff --git a/vllm/model_executor/model_loader/utils.py b/vllm/model_executor/model_loader/utils.py index 79e6fa7b16d..b0f5e5df198 100644 --- a/vllm/model_executor/model_loader/utils.py +++ b/vllm/model_executor/model_loader/utils.py @@ -21,9 +21,9 @@ from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig, QuantizeMethodBase) from vllm.model_executor.models import ModelRegistry -from vllm.model_executor.models.adapters import (as_classification_model, - as_embedding_model, - as_reward_model) +from vllm.model_executor.models.adapters import (as_embedding_model, + as_reward_model, + as_seq_cls_model) from vllm.utils import is_pin_memory_available logger = init_logger(__name__) @@ -244,8 +244,8 @@ def get_model_architecture( model_cls, arch = ModelRegistry.resolve_model_cls(architectures) if model_config.task == "embed": model_cls = as_embedding_model(model_cls) - elif model_config.task == "classify": - model_cls = as_classification_model(model_cls) + elif model_config.task in ["classify", "score"]: + model_cls = as_seq_cls_model(model_cls) elif model_config.task == "reward": model_cls = as_reward_model(model_cls) diff --git a/vllm/model_executor/models/adapters.py b/vllm/model_executor/models/adapters.py index 1651e3e429e..70de304ae86 100644 --- a/vllm/model_executor/models/adapters.py +++ b/vllm/model_executor/models/adapters.py @@ -2,7 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from collections.abc import Iterable -from typing import TYPE_CHECKING, Any, Optional, TypeVar +from typing import TYPE_CHECKING, Any, Optional, TypeVar, Union import torch import torch.nn as nn @@ -145,9 +145,9 @@ def as_embedding_model(cls: _T) -> _T: return ModelForEmbedding # type: ignore -def as_classification_model(cls: _T) -> _T: +def as_seq_cls_model(cls: _T) -> _T: """ - Subclass an existing vLLM model to support classification. + Subclass an existing vLLM model to support classify and score tasks. By default, the class probabilities are extracted from the softmaxed hidden state corresponding to the last token. 
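As a concrete illustration of the adapter this docstring describes, the PR wires up in-tree aliases such as `GemmaForSequenceClassification = as_seq_cls_model(GemmaForCausalLM)` further below. A hedged sketch of applying the same wrapper to another causal-LM class; `LlamaForCausalLM` here is only an illustrative choice, not something this diff touches:

```python
# Sketch only: wrap an existing *ForCausalLM implementation so it can serve
# the "classify" and "score" tasks. The wrapped class pools the last-token
# hidden state, projects it through a RowParallelLinear `score` head, and
# softmaxes the resulting logits into class probabilities.
from vllm.model_executor.models.adapters import as_seq_cls_model
from vllm.model_executor.models.llama import LlamaForCausalLM

LlamaForSequenceClassification = as_seq_cls_model(LlamaForCausalLM)
```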
@@ -164,7 +164,9 @@ def as_classification_model(cls: _T) -> _T: # Lazy import from vllm.config import VllmConfig from vllm.model_executor.layers.linear import RowParallelLinear - from vllm.model_executor.layers.pooler import PoolingType + from vllm.model_executor.layers.pooler import PoolerOutput, PoolingType + from vllm.model_executor.models.interfaces import SupportsCrossEncoding + from vllm.model_executor.pooling_metadata import PoolingMetadata from vllm.sequence import IntermediateTensors from .utils import maybe_prefix @@ -176,7 +178,8 @@ def as_classification_model(cls: _T) -> _T: default_softmax=True, ) - class ModelForClassification(ModelForPooling): + class ModelForSequenceClassification(ModelForPooling, + SupportsCrossEncoding): def __init__( self, @@ -186,10 +189,18 @@ def __init__( **kwargs: Any, ) -> None: super().__init__(vllm_config=vllm_config, prefix=prefix, **kwargs) + self.verify_and_update_config(vllm_config) config = vllm_config.model_config.hf_config quant_config = vllm_config.quant_config + self.task = vllm_config.model_config.task + self.pooling_type = ( + vllm_config.model_config.pooler_config.pooling_type) + + if self.task == "score": + assert config.num_labels == 1 + self.score = RowParallelLinear(config.hidden_size, config.num_labels, quant_config=quant_config, @@ -198,6 +209,11 @@ def __init__( prefix=maybe_prefix( prefix, "score")) + def verify_and_update_config(self, vllm_config): + # Leave an interface for validating and modifying model_config + # for slightly different models + pass + def forward( self, input_ids: torch.Tensor, @@ -205,17 +221,44 @@ def forward( intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, ) -> torch.Tensor: - hidden_states = super().forward(input_ids, positions, - intermediate_tensors, - inputs_embeds) - logits, _ = self.score(hidden_states) - return logits + return super().forward(input_ids, positions, intermediate_tensors, + inputs_embeds) + + def pooler( + self, + hidden_states: Union[torch.Tensor, list[torch.Tensor]], + pooling_metadata: PoolingMetadata, + ) -> PoolerOutput: + + def get_logits(hidden_states): + if isinstance(hidden_states, list): + logits = [self.score(state)[0] for state in hidden_states] + else: + logits, _ = self.score(hidden_states) + return logits + + if self.pooling_type == PoolingType.ALL: + logits = get_logits(hidden_states) + return self._pooler(logits, pooling_metadata) + else: + hidden_states = self._pooler.extract_states( + hidden_states, pooling_metadata) + logits = get_logits(hidden_states) + pooled_data = self._pooler.head(logits, pooling_metadata) + + if self.task == "score": + pooled_data = [data.squeeze(-1) for data in pooled_data] + + pooled_outputs = [ + self._pooler.build_output(data) for data in pooled_data + ] + return PoolerOutput(outputs=pooled_outputs) - ModelForClassification.__name__ = \ - _get_pooling_model_name(cls.__name__, "ForClassification") + ModelForSequenceClassification.__name__ = \ + _get_pooling_model_name(cls.__name__, "ForSequenceClassification") - return ModelForClassification # type: ignore + return ModelForSequenceClassification # type: ignore def as_reward_model(cls: _T) -> _T: diff --git a/vllm/model_executor/models/gemma.py b/vllm/model_executor/models/gemma.py index 59c3102add4..bc8179f886f 100644 --- a/vllm/model_executor/models/gemma.py +++ b/vllm/model_executor/models/gemma.py @@ -43,6 +43,7 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors 
+from .adapters import as_seq_cls_model from .interfaces import SupportsLoRA, SupportsPP from .utils import (AutoWeightsLoader, is_pp_missing_parameter, make_empty_intermediate_tensors_factory, make_layers, @@ -425,3 +426,6 @@ def load_weights(self, weights: Iterable[tuple[str, if self.config.tie_word_embeddings else None), ) return loader.load_weights(weights) + + +GemmaForSequenceClassification = as_seq_cls_model(GemmaForCausalLM) diff --git a/vllm/model_executor/models/qwen2.py b/vllm/model_executor/models/qwen2.py index 23f65b99c22..7ef9d248da4 100644 --- a/vllm/model_executor/models/qwen2.py +++ b/vllm/model_executor/models/qwen2.py @@ -50,6 +50,7 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors +from .adapters import as_seq_cls_model from .interfaces import SupportsLoRA, SupportsPP from .utils import (AutoWeightsLoader, PPMissingLayer, extract_layer_index, is_pp_missing_parameter, @@ -495,3 +496,6 @@ def load_weights(self, weights: Iterable[tuple[str, if self.config.tie_word_embeddings else None), ) return loader.load_weights(weights) + + +Qwen2ForSequenceClassification = as_seq_cls_model(Qwen2ForCausalLM) diff --git a/vllm/model_executor/models/qwen3.py b/vllm/model_executor/models/qwen3.py index 216c1f1c7ff..2c87fbceed2 100644 --- a/vllm/model_executor/models/qwen3.py +++ b/vllm/model_executor/models/qwen3.py @@ -38,15 +38,14 @@ from vllm.model_executor.layers.linear import (QKVParallelLinear, RowParallelLinear) from vllm.model_executor.layers.logits_processor import LogitsProcessor -from vllm.model_executor.layers.pooler import Pooler, PoolingType from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead -from vllm.model_executor.pooling_metadata import PoolingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.sequence import IntermediateTensors, PoolerOutput +from vllm.sequence import IntermediateTensors -from .interfaces import SupportsCrossEncoding, SupportsLoRA, SupportsPP +from .adapters import as_seq_cls_model +from .interfaces import SupportsLoRA, SupportsPP from .qwen2 import Qwen2MLP as Qwen3MLP from .qwen2 import Qwen2Model from .utils import AutoWeightsLoader, PPMissingLayer, maybe_prefix @@ -323,69 +322,31 @@ def load_weights(self, weights: Iterable[tuple[str, return loader.load_weights(weights) -class Qwen3ForSequenceClassification(nn.Module, SupportsLoRA, - SupportsCrossEncoding): +class Qwen3ForSequenceClassification(as_seq_cls_model(Qwen3ForCausalLM)): def __init__( self, vllm_config: "VllmConfig", prefix: str = "", ) -> None: - super().__init__() + super().__init__(vllm_config=vllm_config, prefix=prefix) + def verify_and_update_config(self, vllm_config: "VllmConfig"): config = vllm_config.model_config.hf_config - quant_config = vllm_config.quant_config - pooler_config = vllm_config.model_config.pooler_config - - self.vllm_config = vllm_config - self.config = config - self.quant_config = quant_config - self.prefix = prefix - self.model = Qwen3Model(vllm_config=vllm_config, - prefix=maybe_prefix(prefix, "model")) - self.score = RowParallelLinear(config.hidden_size, - config.num_labels, - quant_config=quant_config, - input_is_parallel=False, - bias=False, - prefix=maybe_prefix(prefix, "score")) - - self._pooler = Pooler.from_config_with_defaults( - pooler_config, - pooling_type=PoolingType.LAST, - 
normalize=False, - softmax=True) - - def forward( - self, - input_ids: torch.Tensor, - positions: torch.Tensor, - intermediate_tensors: Optional[IntermediateTensors] = None, - inputs_embeds: Optional[torch.Tensor] = None, - ) -> torch.Tensor: - return self.model(input_ids=input_ids, - positions=positions, - inputs_embeds=inputs_embeds, - intermediate_tensors=intermediate_tensors) - def pooler( - self, - hidden_states: torch.Tensor, - pooling_metadata: PoolingMetadata, - ) -> Optional[PoolerOutput]: - hidden_states = self._pooler.extract_states(hidden_states, - pooling_metadata) + is_original_qwen3_reranker = getattr(config, + "is_original_qwen3_reranker", + False) - if isinstance(hidden_states, list): - logits = [self.score(state)[0] for state in hidden_states] - else: - logits, _ = self.score(hidden_states) + if not is_original_qwen3_reranker: + return - pooled_data = self._pooler.head(logits, pooling_metadata) - pooled_outputs = [ - self._pooler.build_output(data.squeeze(-1)) for data in pooled_data - ] - return PoolerOutput(outputs=pooled_outputs) + tokens = getattr(config, "classifier_from_token", None) + assert tokens is not None and len(tokens) == 2, \ + ("Try loading the original Qwen3 Reranker?, see: " + "https://github.com/vllm-project/vllm/tree/main/examples/offline_inference/qwen3_reranker.py") + config.num_labels = 1 + self.vllm_config = vllm_config def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): is_original_qwen3_reranker = getattr(self.config, @@ -400,22 +361,10 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): def load_weights_from_original_qwen3_reranker( self, weights: Iterable[tuple[str, torch.Tensor]]): - tokens = getattr(self.config, "classifier_from_token", None) - assert tokens is not None and len(tokens) == 2, \ - ("Try loading the original Qwen3 Reranker?, see: " - "https://github.com/vllm-project/vllm/tree/main/examples/offline_inference/qwen3_reranker.py") - self.config.num_labels = 1 model_config = self.vllm_config.model_config - + tokens = getattr(self.config, "classifier_from_token", None) device = self.score.weight.device - self.score = RowParallelLinear(self.config.hidden_size, - self.config.num_labels, - quant_config=self.quant_config, - input_is_parallel=False, - bias=False, - prefix=maybe_prefix( - self.prefix, "score")).to(device) if self.config.tie_word_embeddings: self.lm_head = self.model.embed_tokens @@ -443,5 +392,6 @@ def load_weights_from_original_qwen3_reranker( self.score.weight.data.copy_(weight) del self.lm_head - loaded_weights.add("classifier.weight") + loaded_weights.add("score.weight") loaded_weights.discard("lm_head.weight") + return loaded_weights diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py index faeaf6ef68c..5c1d08a28bd 100644 --- a/vllm/model_executor/models/registry.py +++ b/vllm/model_executor/models/registry.py @@ -158,8 +158,6 @@ "LlavaNextForConditionalGeneration": ("llava_next", "LlavaNextForConditionalGeneration"), # noqa: E501 "Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"), "Qwen2VLForConditionalGeneration": ("qwen2_vl", "Qwen2VLForConditionalGeneration"), # noqa: E501 - # [Auto-converted (see adapters.py)] - "Qwen2ForSequenceClassification": ("qwen2", "Qwen2ForCausalLM"), # Technically PrithviGeoSpatialMAE is a model that works on images, both in # input and output. I am adding it here because it piggy-backs on embedding # models for the time being. 
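To make the refactored `verify_and_update_config` / `load_weights_from_original_qwen3_reranker` path above concrete: loading the official, unconverted Qwen3-Reranker checkpoint is expected to look roughly like the referenced `examples/offline_inference/qwen3_reranker.py`, i.e. via `hf_overrides`. Treat the exact override keys as belonging to that example rather than to this diff:

```python
# Rough sketch, mirroring the referenced qwen3_reranker.py example: these
# hf_overrides route the checkpoint through verify_and_update_config() and
# the original-reranker weight-loading path added above.
from vllm import LLM

model = LLM(
    model="Qwen/Qwen3-Reranker-0.6B",
    task="score",
    hf_overrides={
        "architectures": ["Qwen3ForSequenceClassification"],
        "classifier_from_token": ["no", "yes"],
        "is_original_qwen3_reranker": True,
    },
)
```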
@@ -174,6 +172,9 @@ "RobertaForSequenceClassification"), "ModernBertForSequenceClassification": ("modernbert", "ModernBertForSequenceClassification"), + # [Auto-converted (see adapters.py)] + "GemmaForSequenceClassification": ("gemma", "GemmaForSequenceClassification"), # noqa: E501 + "Qwen2ForSequenceClassification": ("qwen2", "Qwen2ForSequenceClassification"), # noqa: E501 "Qwen3ForSequenceClassification": ("qwen3", "Qwen3ForSequenceClassification"), # noqa: E501 }
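End to end, the converted checkpoints registered above are meant to be consumed through the score task. A minimal, hedged usage sketch: the local path is assumed to be the output of `convert_model_to_seq_cls.py` from this PR, and note that the official Qwen3-Reranker additionally expects its instruction-style prompt template, which is omitted here for brevity:

```python
from vllm import LLM

# "./Qwen3-Reranker-0.6B-seq-cls" is assumed to be the directory produced by
# convert_model_to_seq_cls.py earlier in this PR.
llm = LLM(model="./Qwen3-Reranker-0.6B-seq-cls", task="score")

query = "What is the capital of France?"
documents = [
    "Paris is the capital and most populous city of France.",
    "Photosynthesis converts light energy into chemical energy.",
]

outputs = llm.score(query, documents)
for doc, output in zip(documents, outputs):
    # Each ScoringRequestOutput holds a single relevance probability.
    print(f"{output.outputs.score:.4f}  {doc!r}")
```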