
[Model] Automatic conversion of score (CrossEncoding) models #19675


Open
wants to merge 12 commits into main
2 changes: 1 addition & 1 deletion docs/models/supported_models.md
@@ -447,7 +447,7 @@ Specified using `--task classify`.
| `JambaForSequenceClassification` | Jamba | `ai21labs/Jamba-tiny-reward-dev`, etc. | ✅︎ | ✅︎ | |

If your model is not in the above list, we will try to automatically convert the model using
[as_classification_model][vllm.model_executor.models.adapters.as_classification_model]. By default, the class probabilities are extracted from the softmaxed hidden state corresponding to the last token.
[as_seq_cls_model][vllm.model_executor.models.adapters.as_seq_cls_model]. By default, the class probabilities are extracted from the softmaxed hidden state corresponding to the last token.
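
For illustration, a minimal sketch of relying on that automatic conversion (the model path is a placeholder, not a specific checkpoint):

```python
from vllm import LLM

# Hypothetical checkpoint path; vLLM applies as_seq_cls_model automatically
# when the architecture is not natively listed above.
llm = LLM(model="path/to/your-seq-cls-model", task="classify")
(output,) = llm.classify(["vLLM makes model serving easy."])
print(output.outputs.probs)  # softmaxed per-class probabilities
```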

#### Sentence Pair Scoring

2 changes: 1 addition & 1 deletion docs/serving/openai_compatible_server.md
@@ -379,7 +379,7 @@ Code example: <gh-file:examples/online_serving/openai_pooling_client.py>

Our Classification API directly supports Hugging Face sequence-classification models such as [ai21labs/Jamba-tiny-reward-dev](https://huggingface.co/ai21labs/Jamba-tiny-reward-dev) and [jason9693/Qwen2.5-1.5B-apeach](https://huggingface.co/jason9693/Qwen2.5-1.5B-apeach).

We automatically wrap any other transformer via `as_classification_model()`, which pools on the last token, attaches a `RowParallelLinear` head, and applies a softmax to produce per-class probabilities.
We automatically wrap any other transformer via `as_seq_cls_model()`, which pools on the last token, attaches a `RowParallelLinear` head, and applies a softmax to produce per-class probabilities.

Code example: <gh-file:examples/online_serving/openai_classification_client.py>
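
As a rough illustration of that wrapping, a hand-written sketch (not the actual `as_seq_cls_model` implementation):

```python
import torch


def classify_from_hidden_states(hidden_states: torch.Tensor,
                                score_weight: torch.Tensor) -> torch.Tensor:
    # hidden_states: [seq_len, hidden_size]; score_weight: [num_labels, hidden_size]
    pooled = hidden_states[-1]            # pool on the last token
    logits = pooled @ score_weight.T      # linear classification head
    return torch.softmax(logits, dim=-1)  # per-class probabilities
```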

131 changes: 131 additions & 0 deletions examples/offline_inference/convert_model_to_seq_cls.py
@@ -0,0 +1,131 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# ruff: noqa: E501

import argparse
import json

import torch
import transformers

# Usage:
# for Qwen3-Reranker
# python convert_model_to_seq_cls.py --model_name Qwen/Qwen3-Reranker-0.6B --classifier_from_tokens '["no", "yes"]' --method from_2_way_softmax --path ./Qwen3-Reranker-0.6B-seq-cls
# for BAAI/bge-reranker-v2-gemma
# python convert_model_to_seq_cls.py --model_name BAAI/bge-reranker-v2-gemma --classifier_from_tokens '["Yes"]' --method no_post_processing --path ./bge-reranker-v2-gemma-seq-cls
# Caution: "Yes" and "yes" are two different tokens
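# After conversion, the resulting checkpoint can be loaded with vLLM for scoring,
# e.g. (illustrative path):
#   LLM(model="./Qwen3-Reranker-0.6B-seq-cls", task="score")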


def from_2_way_softmax(
    causal_lm, seq_cls_model, tokenizer, classifier_from_tokens, device
):
    # for Qwen3-Reranker
    # Adapted from https://huggingface.co/Qwen/Qwen3-Reranker-0.6B/discussions/3
    lm_head_weights = causal_lm.lm_head.weight

    a = tokenizer.convert_tokens_to_ids(classifier_from_tokens[0])
    b = tokenizer.convert_tokens_to_ids(classifier_from_tokens[1])

    score_weight = lm_head_weights[b].to(device).to(
        torch.float32
    ) - lm_head_weights[a].to(device).to(torch.float32)

    with torch.no_grad():
        seq_cls_model.score.weight.copy_(score_weight.unsqueeze(0))
        if seq_cls_model.score.bias is not None:
            seq_cls_model.score.bias.zero_()

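# Note (added for clarity): collapsing the 2-way softmax into a single logit is
# exact, because softmax([no, yes])[1] == sigmoid(yes - no); a one-label head
# whose weight is (w_yes - w_no) therefore reproduces the original probability
# of the "yes" token.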

def no_post_processing(
    causal_lm, seq_cls_model, tokenizer, classifier_from_tokens, device
):
    # for BAAI/bge-reranker-v2-gemma

    lm_head_weights = causal_lm.lm_head.weight
    tokens = [tokenizer.convert_tokens_to_ids(t) for t in classifier_from_tokens]
    score_weight = lm_head_weights[tokens].to(device)

    with torch.no_grad():
        seq_cls_model.score.weight.copy_(score_weight)
        if seq_cls_model.score.bias is not None:
            seq_cls_model.score.bias.zero_()

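# Note (added for clarity): here each label's weight is the unmodified lm_head
# row of the corresponding token, so the converted head emits that token's raw
# logit as its score.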

method_map = {
    function.__name__: function for function in [from_2_way_softmax, no_post_processing]
}


def converting(
    model_name, classifier_from_tokens, path, method, use_pad_token=False, device="cpu"
):
    assert method in method_map

    if method == "from_2_way_softmax":
        assert len(classifier_from_tokens) == 2
        num_labels = 1
    else:
        num_labels = len(classifier_from_tokens)

    tokenizer = transformers.AutoTokenizer.from_pretrained(model_name)
    causal_lm = transformers.AutoModelForCausalLM.from_pretrained(
        model_name, device_map=device
    )

    seq_cls_model = transformers.AutoModelForSequenceClassification.from_pretrained(
        model_name,
        num_labels=num_labels,
        ignore_mismatched_sizes=True,
        device_map=device,
    )

    method_map[method](
        causal_lm, seq_cls_model, tokenizer, classifier_from_tokens, device
    )

    seq_cls_model.config.pad_token_id = tokenizer.pad_token_id
    seq_cls_model.config.use_pad_token = use_pad_token

    seq_cls_model.save_pretrained(path)
    tokenizer.save_pretrained(path)


def parse_args():
    parser = argparse.ArgumentParser(
        description="Convert *ForCausalLM models to "
        "*ForSequenceClassification models."
    )
    parser.add_argument(
        "--model_name", type=str, default="Qwen/Qwen3-Reranker-0.6B", help="Model name"
    )
    parser.add_argument(
        "--classifier_from_tokens",
        type=str,
        default='["no", "yes"]',
        help="Tokens whose lm_head rows initialize the classifier head",
    )
    parser.add_argument(
        "--use-pad-token", action="store_true", help="Whether to use pad_token"
    )
    parser.add_argument(
        "--method", type=str, default="from_2_way_softmax", help="Conversion method"
    )
    parser.add_argument(
        "--path",
        type=str,
        default="./converted_model",
        help="Path to save converted model",
    )
    return parser.parse_args()


if __name__ == "__main__":
    args = parse_args()

    converting(
        model_name=args.model_name,
        classifier_from_tokens=json.loads(args.classifier_from_tokens),
        method=args.method,
        use_pad_token=args.use_pad_token,
        path=args.path,
    )
5 changes: 5 additions & 0 deletions examples/offline_inference/qwen3_reranker.py
@@ -19,6 +19,11 @@
# concise, for example.
# model = LLM(model="tomaarsen/Qwen3-Reranker-0.6B-seq-cls", task="score")

# To convert the official original checkpoint into a sequence classification
# model offline, please refer to: convert_model_to_seq_cls.py
# The init parameters are as follows.
# model = LLM(model="path_to/converted_model", task="score")

# If you want to load the official original version, the init parameters are
# as follows.

25 changes: 14 additions & 11 deletions tests/entrypoints/openai/correctness/test_mteb_score.py
@@ -6,19 +6,16 @@

# yapf conflicts with isort for this block
# yapf: disable
from tests.models.language.pooling.mteb_utils import (MTEB_RERANK_LANGS,
                                                       MTEB_RERANK_TASKS,
                                                       MTEB_RERANK_TOL,
                                                       RerankClientMtebEncoder,
                                                       ScoreClientMtebEncoder,
                                                       run_mteb_rerank)
from tests.models.language.pooling.mteb_utils import (
    MTEB_RERANK_LANGS, MTEB_RERANK_TASKS, MTEB_RERANK_TOL,
    RerankClientMtebEncoder, ScoreClientMtebEncoder,
    mteb_test_rerank_models_hf, run_mteb_rerank)
# yapf: enable
from tests.utils import RemoteOpenAIServer

os.environ["VLLM_LOGGING_LEVEL"] = "WARNING"

MODEL_NAME = "cross-encoder/ms-marco-MiniLM-L-6-v2"
MAIN_SCORE = 0.33437


@pytest.fixture(scope="module")
@@ -31,12 +28,19 @@ def server():
        yield remote_server


def test_mteb_score(server):
@pytest.fixture(scope="module")
def st_main_score(hf_runner):
    # The main score depends on the version of the dependency,
    # so we need to recalculate it every time.
    main_score, st_dtype = mteb_test_rerank_models_hf(hf_runner, MODEL_NAME)
    return main_score


def test_mteb_score(server, st_main_score):
    url = server.url_for("score")
    encoder = ScoreClientMtebEncoder(MODEL_NAME, url)
    vllm_main_score = run_mteb_rerank(encoder, MTEB_RERANK_TASKS,
                                      MTEB_RERANK_LANGS)
    st_main_score = MAIN_SCORE

    print("VLLM main score: ", vllm_main_score)
    print("SentenceTransformer main score: ", st_main_score)
@@ -45,12 +49,11 @@ def test_mteb_score(server):
    assert st_main_score == pytest.approx(vllm_main_score, abs=MTEB_RERANK_TOL)


def test_mteb_rerank(server):
def test_mteb_rerank(server, st_main_score):
    url = server.url_for("rerank")
    encoder = RerankClientMtebEncoder(MODEL_NAME, url)
    vllm_main_score = run_mteb_rerank(encoder, MTEB_RERANK_TASKS,
                                      MTEB_RERANK_LANGS)
    st_main_score = MAIN_SCORE

    print("VLLM main score: ", vllm_main_score)
    print("SentenceTransformer main score: ", st_main_score)
2 changes: 2 additions & 0 deletions tests/entrypoints/openai/test_classification.py
@@ -18,6 +18,8 @@ def server():
"--enforce-eager",
"--max-model-len",
"512",
"--task",
"classify",
"--dtype",
DTYPE,
]
56 changes: 31 additions & 25 deletions tests/models/language/pooling/mteb_utils.py
@@ -234,6 +234,35 @@ def run_mteb_rerank(cross_encoder, tasks, languages):
    return main_score


def mteb_test_rerank_models_hf(hf_runner, model_name, hf_model_callback=None):
    with hf_runner(model_name, is_cross_encoder=True,
                   dtype="float32") as hf_model:

        original_predict = hf_model.predict

        def _predict(
            sentences: list[tuple[str, str,
                                  Optional[str]]],  # query, corpus, prompt
            *args,
            **kwargs,
        ):
            # vllm and st both remove the prompt, fair comparison.
            prompts = [(s[0], s[1]) for s in sentences]
            return original_predict(prompts, *args, **kwargs, batch_size=8)

        hf_model.predict = _predict
        hf_model.original_predict = original_predict

        if hf_model_callback is not None:
            hf_model_callback(hf_model)

        st_main_score = run_mteb_rerank(hf_model,
                                        tasks=MTEB_RERANK_TASKS,
                                        languages=MTEB_RERANK_LANGS)
        st_dtype = next(hf_model.model.model.parameters()).dtype
    return st_main_score, st_dtype


def mteb_test_rerank_models(hf_runner,
                            vllm_runner,
                            model_info: RerankModelInfo,
@@ -261,31 +290,8 @@ def mteb_test_rerank_models(hf_runner,
                                          languages=MTEB_RERANK_LANGS)
        vllm_dtype = vllm_model.model.llm_engine.model_config.dtype

    with hf_runner(model_info.name, is_cross_encoder=True,
                   dtype="float32") as hf_model:

        original_predict = hf_model.predict

        def _predict(
            sentences: list[tuple[str, str,
                                  Optional[str]]],  # query, corpus, prompt
            *args,
            **kwargs,
        ):
            # vllm and st both remove the prompt, fair comparison.
            prompts = [(s[0], s[1]) for s in sentences]
            return original_predict(prompts, *args, **kwargs, batch_size=8)

        hf_model.predict = _predict
        hf_model.original_predict = original_predict

        if hf_model_callback is not None:
            hf_model_callback(hf_model)

        st_main_score = run_mteb_rerank(hf_model,
                                        tasks=MTEB_RERANK_TASKS,
                                        languages=MTEB_RERANK_LANGS)
        st_dtype = next(hf_model.model.model.parameters()).dtype
    st_main_score, st_dtype = mteb_test_rerank_models_hf(
        hf_runner, model_info.name, hf_model_callback)

print("VLLM:", vllm_dtype, vllm_main_score)
print("SentenceTransformers:", st_dtype, st_main_score)
3 changes: 2 additions & 1 deletion tests/models/language/pooling/test_classification.py
@@ -37,7 +37,8 @@ def test_models(
    # switch to use ROCm CK FA backend
    monkeypatch.setenv("VLLM_USE_TRITON_FLASH_ATTN", "False")

    with vllm_runner(model, max_model_len=512, dtype=dtype) as vllm_model:
    with vllm_runner(model, max_model_len=512, dtype=dtype,
                     task="classify") as vllm_model:
        vllm_outputs = vllm_model.classify(example_prompts)

    with hf_runner(model,
4 changes: 2 additions & 2 deletions tests/models/language/pooling/test_gte.py
@@ -33,10 +33,10 @@
                   enable_test=True),
    EmbedModelInfo("Alibaba-NLP/gte-base-en-v1.5",
                   architecture="GteNewModel",
                   enable_test=True),
                   enable_test=False),
    EmbedModelInfo("Alibaba-NLP/gte-large-en-v1.5",
                   architecture="GteNewModel",
                   enable_test=True),
                   enable_test=False),
    ########### Qwen2ForCausalLM
    EmbedModelInfo("Alibaba-NLP/gte-Qwen2-1.5B-instruct",
                   architecture="Qwen2ForCausalLM",
8 changes: 4 additions & 4 deletions tests/models/test_registry.py
@@ -9,9 +9,9 @@
from vllm.model_executor.models import (is_pooling_model,
                                        is_text_generation_model,
                                        supports_multimodal)
from vllm.model_executor.models.adapters import (as_classification_model,
                                                  as_embedding_model,
                                                  as_reward_model)
from vllm.model_executor.models.adapters import (as_embedding_model,
                                                  as_reward_model,
                                                  as_seq_cls_model)
from vllm.model_executor.models.registry import (_MULTIMODAL_MODELS,
                                                  _SPECULATIVE_DECODING_MODELS,
                                                  _TEXT_GENERATION_MODELS,
@@ -38,7 +38,7 @@ def test_registry_imports(model_arch):
        assert is_text_generation_model(model_cls)

    # All vLLM models should be convertible to a pooling model
    assert is_pooling_model(as_classification_model(model_cls))
    assert is_pooling_model(as_seq_cls_model(model_cls))
    assert is_pooling_model(as_embedding_model(model_cls))
    assert is_pooling_model(as_reward_model(model_cls))

12 changes: 8 additions & 4 deletions vllm/entrypoints/llm.py
@@ -1178,7 +1178,6 @@ def _cross_encoding_score(
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> list[ScoringRequestOutput]:

        if isinstance(tokenizer, MistralTokenizer):
            raise ValueError(
                "Score API is only enabled for `--task embed or score`")
@@ -1189,6 +1188,8 @@
        input_pairs = [(t1, t2) for t1, t2 in zip(text_1, text_2)]

        pooling_params = PoolingParams()
        use_pad_token = getattr(self.llm_engine.model_config.hf_config,
                                "use_pad_token", True)

        tokenization_kwargs: dict[str, Any] = {}
        _validate_truncation_size(self.llm_engine.model_config.max_model_len,
@@ -1197,9 +1198,12 @@
        parsed_prompts = []

        for q, t in input_pairs:
            prompt_inputs = tokenizer(text=q,
                                      text_pair=t,
                                      **tokenization_kwargs)
            # When use_pad_token is False, concatenate the query and document
            # directly instead of encoding them as a (text, text_pair) input.
            if use_pad_token:
                prompt_inputs = tokenizer(text=q,
                                          text_pair=t,
                                          **tokenization_kwargs)
            else:
                prompt_inputs = tokenizer(text=q + t, **tokenization_kwargs)
            engine_prompt = TokensPrompt(
                prompt_token_ids=prompt_inputs["input_ids"],
                token_type_ids=prompt_inputs.get("token_type_ids"))
6 changes: 5 additions & 1 deletion vllm/entrypoints/openai/serving_score.py
@@ -173,8 +173,12 @@ async def _cross_encoding_score(
            *(tokenize_async(text=t1, text_pair=t2, **tokenization_kwargs)
              for t1, t2 in input_pairs))

        use_pad_token = getattr(self.model_config.hf_config, "use_pad_token",
                                True)

        for prompt_inputs, (t1, t2) in zip(tokenized_prompts, input_pairs):
            sep_token = tokenizer.sep_token if tokenizer.sep_token else ''
            sep_token = tokenizer.sep_token if (tokenizer.sep_token
                                                and use_pad_token) else ''
            request_prompt = f"{t1}{sep_token}{t2}"

            input_ids = prompt_inputs["input_ids"]