[Model][2/N] Automatic conversion of CrossEncoding model #19978

Open · wants to merge 7 commits into base: main
2 changes: 1 addition & 1 deletion docs/models/supported_models.md
@@ -463,7 +463,7 @@ Specified using `--task classify`.
| `JambaForSequenceClassification` | Jamba | `ai21labs/Jamba-tiny-reward-dev`, etc. | ✅︎ | ✅︎ | |
| `GPT2ForSequenceClassification` | GPT2 | `nie3e/sentiment-polish-gpt2-small` | | | ✅︎ |
If your model is not in the above list, we will try to automatically convert the model using
[as_classification_model][vllm.model_executor.models.adapters.as_classification_model]. By default, the class probabilities are extracted from the softmaxed hidden state corresponding to the last token.
[as_seq_cls_model][vllm.model_executor.models.adapters.as_seq_cls_model]. By default, the class probabilities are extracted from the softmaxed hidden state corresponding to the last token.
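For intuition, here is a minimal plain-PyTorch sketch of the behaviour described above (last-token pooling, a linear classification head, then softmax); the class and attribute names are illustrative only and are not part of the adapter's API:

```python
# Illustrative sketch, not vLLM internals: class probabilities come from the
# softmaxed output of a linear head applied to the last token's hidden state.
import torch
import torch.nn as nn


class LastTokenClassifier(nn.Module):
    def __init__(self, hidden_size: int, num_labels: int):
        super().__init__()
        # Hypothetical stand-in for the classification head the adapter attaches.
        self.score = nn.Linear(hidden_size, num_labels, bias=False)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # hidden_states: [seq_len, hidden_size] for a single sequence.
        last_token_state = hidden_states[-1]
        logits = self.score(last_token_state)
        # Per-class probabilities from the softmaxed logits.
        return torch.softmax(logits, dim=-1)
```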

#### Sentence Pair Scoring

2 changes: 1 addition & 1 deletion docs/serving/openai_compatible_server.md
@@ -431,7 +431,7 @@ Code example: <gh-file:examples/online_serving/openai_pooling_client.py>

Our Classification API directly supports Hugging Face sequence-classification models such as [ai21labs/Jamba-tiny-reward-dev](https://huggingface.co/ai21labs/Jamba-tiny-reward-dev) and [jason9693/Qwen2.5-1.5B-apeach](https://huggingface.co/jason9693/Qwen2.5-1.5B-apeach).

We automatically wrap any other transformer via `as_classification_model()`, which pools on the last token, attaches a `RowParallelLinear` head, and applies a softmax to produce per-class probabilities.
We automatically wrap any other transformer via `as_seq_cls_model()`, which pools on the last token, attaches a `RowParallelLinear` head, and applies a softmax to produce per-class probabilities.

Code example: <gh-file:examples/online_serving/openai_classification_client.py>
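For quick reference, a request to the Classification API might look like the sketch below, assuming a locally running `vllm serve` instance; the endpoint path and field names follow the linked example and may differ between vLLM versions:

```python
# Hedged sketch of a Classification API call; adjust the host, port, and model
# name to match your deployment.
import requests

response = requests.post(
    "http://localhost:8000/classify",
    json={
        "model": "jason9693/Qwen2.5-1.5B-apeach",
        "input": ["vLLM makes serving classification models easy."],
    },
)
print(response.json())  # includes the per-class probabilities from the softmax head
```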

25 changes: 14 additions & 11 deletions tests/entrypoints/openai/correctness/test_mteb_score.py
@@ -6,19 +6,16 @@

# yapf conflicts with isort for this block
# yapf: disable
from tests.models.language.pooling.mteb_utils import (MTEB_RERANK_LANGS,
MTEB_RERANK_TASKS,
MTEB_RERANK_TOL,
RerankClientMtebEncoder,
ScoreClientMtebEncoder,
run_mteb_rerank)
from tests.models.language.pooling.mteb_utils import (
MTEB_RERANK_LANGS, MTEB_RERANK_TASKS, MTEB_RERANK_TOL,
RerankClientMtebEncoder, ScoreClientMtebEncoder,
mteb_test_rerank_models_hf, run_mteb_rerank)
# yapf: enable
from tests.utils import RemoteOpenAIServer

os.environ["VLLM_LOGGING_LEVEL"] = "WARNING"

MODEL_NAME = "cross-encoder/ms-marco-MiniLM-L-6-v2"
MAIN_SCORE = 0.33437


@pytest.fixture(scope="module")
@@ -31,12 +28,19 @@ def server():
yield remote_server


def test_mteb_score(server):
@pytest.fixture(scope="module")
def st_main_score(hf_runner):
# The main score depends on the installed version of the dependency,
# so we recalculate it every time instead of hard-coding a constant.
main_score, st_dtype = mteb_test_rerank_models_hf(hf_runner, MODEL_NAME)
return main_score


def test_mteb_score(server, st_main_score):
url = server.url_for("score")
encoder = ScoreClientMtebEncoder(MODEL_NAME, url)
vllm_main_score = run_mteb_rerank(encoder, MTEB_RERANK_TASKS,
MTEB_RERANK_LANGS)
st_main_score = MAIN_SCORE

print("VLLM main score: ", vllm_main_score)
print("SentenceTransformer main score: ", st_main_score)
Expand All @@ -45,12 +49,11 @@ def test_mteb_score(server):
assert st_main_score == pytest.approx(vllm_main_score, abs=MTEB_RERANK_TOL)


def test_mteb_rerank(server):
def test_mteb_rerank(server, st_main_score):
url = server.url_for("rerank")
encoder = RerankClientMtebEncoder(MODEL_NAME, url)
vllm_main_score = run_mteb_rerank(encoder, MTEB_RERANK_TASKS,
MTEB_RERANK_LANGS)
st_main_score = MAIN_SCORE

print("VLLM main score: ", vllm_main_score)
print("SentenceTransformer main score: ", st_main_score)
56 changes: 31 additions & 25 deletions tests/models/language/pooling/mteb_utils.py
@@ -234,6 +234,35 @@ def run_mteb_rerank(cross_encoder, tasks, languages):
return main_score


def mteb_test_rerank_models_hf(hf_runner, model_name, hf_model_callback=None):
with hf_runner(model_name, is_cross_encoder=True,
dtype="float32") as hf_model:

original_predict = hf_model.predict

def _predict(
sentences: list[tuple[str, str,
Optional[str]]], # query, corpus, prompt
*args,
**kwargs,
):
# vllm and st both remove the prompt, fair comparison.
prompts = [(s[0], s[1]) for s in sentences]
return original_predict(prompts, *args, **kwargs, batch_size=8)

hf_model.predict = _predict
hf_model.original_predict = original_predict

if hf_model_callback is not None:
hf_model_callback(hf_model)

st_main_score = run_mteb_rerank(hf_model,
tasks=MTEB_RERANK_TASKS,
languages=MTEB_RERANK_LANGS)
st_dtype = next(hf_model.model.model.parameters()).dtype
return st_main_score, st_dtype


def mteb_test_rerank_models(hf_runner,
vllm_runner,
model_info: RerankModelInfo,
@@ -264,31 +293,8 @@ def mteb_test_rerank_models(hf_runner,
languages=MTEB_RERANK_LANGS)
vllm_dtype = model_config.dtype

with hf_runner(model_info.name, is_cross_encoder=True,
dtype="float32") as hf_model:

original_predict = hf_model.predict

def _predict(
sentences: list[tuple[str, str,
Optional[str]]], # query, corpus, prompt
*args,
**kwargs,
):
# vllm and st both remove the prompt, fair comparison.
prompts = [(s[0], s[1]) for s in sentences]
return original_predict(prompts, *args, **kwargs, batch_size=8)

hf_model.predict = _predict
hf_model.original_predict = original_predict

if hf_model_callback is not None:
hf_model_callback(hf_model)

st_main_score = run_mteb_rerank(hf_model,
tasks=MTEB_RERANK_TASKS,
languages=MTEB_RERANK_LANGS)
st_dtype = next(hf_model.model.model.parameters()).dtype
st_main_score, st_dtype = mteb_test_rerank_models_hf(
hf_runner, model_info.name, hf_model_callback)

print("VLLM:", vllm_dtype, vllm_main_score)
print("SentenceTransformers:", st_dtype, st_main_score)
8 changes: 4 additions & 4 deletions tests/models/test_registry.py
@@ -9,9 +9,9 @@
from vllm.model_executor.models import (is_pooling_model,
is_text_generation_model,
supports_multimodal)
from vllm.model_executor.models.adapters import (as_classification_model,
as_embedding_model,
as_reward_model)
from vllm.model_executor.models.adapters import (as_embedding_model,
as_reward_model,
as_seq_cls_model)
from vllm.model_executor.models.registry import (_MULTIMODAL_MODELS,
_SPECULATIVE_DECODING_MODELS,
_TEXT_GENERATION_MODELS,
@@ -38,7 +38,7 @@ def test_registry_imports(model_arch):
assert is_text_generation_model(model_cls)

# All vLLM models should be convertible to a pooling model
assert is_pooling_model(as_classification_model(model_cls))
assert is_pooling_model(as_seq_cls_model(model_cls))
assert is_pooling_model(as_embedding_model(model_cls))
assert is_pooling_model(as_reward_model(model_cls))

28 changes: 27 additions & 1 deletion tests/test_config.py
@@ -85,7 +85,7 @@ def test_get_field():
("distilbert/distilgpt2", "generate", "generate"),
("intfloat/multilingual-e5-small", "pooling", "embed"),
("jason9693/Qwen2.5-1.5B-apeach", "pooling", "classify"),
("cross-encoder/ms-marco-MiniLM-L-6-v2", "pooling", "score"),
("cross-encoder/ms-marco-MiniLM-L-6-v2", "pooling", "classify"),
("Qwen/Qwen2.5-Math-RM-72B", "pooling", "reward"),
("openai/whisper-small", "transcription", "transcription"),
],
@@ -105,6 +105,32 @@ def test_auto_task(model_id, expected_runner_type, expected_task):
assert config.task == expected_task


@pytest.mark.parametrize(
("model_id", "expected_runner_type", "expected_task"),
[
("distilbert/distilgpt2", "pooling", "embed"),
("intfloat/multilingual-e5-small", "pooling", "embed"),
("jason9693/Qwen2.5-1.5B-apeach", "pooling", "classify"),
("cross-encoder/ms-marco-MiniLM-L-6-v2", "pooling", "classify"),
("Qwen/Qwen2.5-Math-RM-72B", "pooling", "embed"),
("openai/whisper-small", "pooling", "embed"),
],
)
def test_score_task(model_id, expected_runner_type, expected_task):
config = ModelConfig(
model_id,
task="score",
tokenizer=model_id,
tokenizer_mode="auto",
trust_remote_code=False,
seed=0,
dtype="float16",
)

assert config.runner_type == expected_runner_type
assert config.task == expected_task


@pytest.mark.parametrize(("model_id", "bad_task"), [
("Qwen/Qwen2.5-Math-RM-72B", "generate"),
])
34 changes: 22 additions & 12 deletions vllm/config.py
@@ -93,14 +93,14 @@
TaskOption = Literal["auto", "generate", "embedding", "embed", "classify",
"score", "reward", "transcription"]

_ResolvedTask = Literal["generate", "embed", "classify", "score", "reward",
"draft", "transcription"]
_ResolvedTask = Literal["generate", "embed", "classify", "reward", "draft",
"transcription"]

RunnerType = Literal["generate", "pooling", "draft", "transcription"]

_RUNNER_TASKS: dict[RunnerType, list[_ResolvedTask]] = {
"generate": ["generate"],
"pooling": ["embed", "classify", "score", "reward"],
"pooling": ["embed", "classify", "reward"],
"draft": ["draft"],
"transcription": ["transcription"],
}
@@ -790,7 +790,7 @@ def _get_preferred_task(
if get_pooling_config(model_id, self.revision):
return "embed"
if self.registry.is_cross_encoder_model(architectures):
return "score"
return "classify"
if self.registry.is_transcription_model(architectures):
return "transcription"

@@ -854,14 +854,24 @@ def _resolve_task(
"This model supports multiple tasks: %s. "
"Defaulting to '%s'.", supported_tasks, selected_task)
else:
# Aliases
if task_option == "embedding":
msg = ("The 'embedding' task has been renamed to "
"'embed', please use the new name. The old name "
"will be removed in v1.0.")
warnings.warn(msg, DeprecationWarning, stacklevel=2)

task_option = "embed"
if task_option == "score":
if not runner_support["pooling"]:
msg = (f"This model does not support the '{task_option}' "
f"task. Supported tasks: {supported_tasks}")
raise ValueError(msg)
if self.registry.is_cross_encoder_model(architectures):
task_option = "classify"
else:
task_option = "embed"
else:
# Aliases
if task_option == "embedding":
msg = ("The 'embedding' task has been renamed to "
"'embed', please use the new name. The old name "
"will be removed in v1.0.")
warnings.warn(msg, DeprecationWarning, stacklevel=2)

task_option = "embed"

if task_option not in supported_tasks:
msg = (
10 changes: 7 additions & 3 deletions vllm/entrypoints/llm.py
@@ -1273,9 +1273,13 @@ def score(

raise ValueError(" ".join(messages))

if self.llm_engine.model_config.task not in ("embed", "score"):
raise ValueError(
"Score API is only enabled for `--task embed or --task score`")
if self.llm_engine.model_config.task not in ("embed", "classify"):
raise ValueError("Score API is only enabled for "
"`--task embed or --task classify`.")

if (self.llm_engine.model_config.task == "classify"
and self.llm_engine.model_config.hf_config.num_labels != 1):
raise ValueError("Score API is only enabled for num_labels == 1.")

# the tokenizer for models such as
# "cross-encoder/ms-marco-MiniLM-L-6-v2" doesn't support passing
19 changes: 11 additions & 8 deletions vllm/entrypoints/openai/api_server.py
@@ -1263,24 +1263,27 @@ async def init_app_state(
chat_template=resolved_chat_template,
chat_template_content_format=args.chat_template_content_format,
) if model_config.task == "embed" else None
state.openai_serving_scores = ServingScores(
engine_client,
model_config,
state.openai_serving_models,
request_logger=request_logger) if model_config.task in (
"score", "embed", "pooling") else None
state.openai_serving_classification = ServingClassification(
engine_client,
model_config,
state.openai_serving_models,
request_logger=request_logger,
) if model_config.task == "classify" else None

enable_serving_reranking = (model_config.task == "classify" and getattr(
model_config.hf_config, "num_labels", 0) == 1)
state.jinaai_serving_reranking = ServingScores(
engine_client,
model_config,
state.openai_serving_models,
request_logger=request_logger
) if model_config.task == "score" else None
request_logger=request_logger) if enable_serving_reranking else None
state.openai_serving_scores = ServingScores(
engine_client,
model_config,
state.openai_serving_models,
request_logger=request_logger) if (
model_config.task == "embed" or enable_serving_reranking) else None

state.openai_serving_tokenization = OpenAIServingTokenization(
engine_client,
model_config,
6 changes: 5 additions & 1 deletion vllm/entrypoints/openai/run_batch.py
@@ -357,12 +357,16 @@ async def main(args):
chat_template=None,
chat_template_content_format="auto",
) if model_config.task == "embed" else None

enable_serving_reranking = (model_config.task == "classify" and getattr(
model_config.hf_config, "num_labels", 0) == 1)

openai_serving_scores = (ServingScores(
engine,
model_config,
openai_serving_models,
request_logger=request_logger,
) if model_config.task == "score" else None)
) if (model_config.task == "embed" or enable_serving_reranking) else None)

tracker = BatchProgressTracker()
logger.info("Reading batch from %s...", args.input_file)