diff --git a/docs/models/supported_models.md b/docs/models/supported_models.md index 04d9923f921..55491db0fd8 100644 --- a/docs/models/supported_models.md +++ b/docs/models/supported_models.md @@ -463,7 +463,7 @@ Specified using `--task classify`. | `JambaForSequenceClassification` | Jamba | `ai21labs/Jamba-tiny-reward-dev`, etc. | ✅︎ | ✅︎ | | | `GPT2ForSequenceClassification` | GPT2 | `nie3e/sentiment-polish-gpt2-small` | | | ✅︎ | If your model is not in the above list, we will try to automatically convert the model using -[as_classification_model][vllm.model_executor.models.adapters.as_classification_model]. By default, the class probabilities are extracted from the softmaxed hidden state corresponding to the last token. +[as_seq_cls_model][vllm.model_executor.models.adapters.as_seq_cls_model]. By default, the class probabilities are extracted from the softmaxed hidden state corresponding to the last token. #### Sentence Pair Scoring diff --git a/docs/serving/openai_compatible_server.md b/docs/serving/openai_compatible_server.md index 00756e71999..b18db1264e8 100644 --- a/docs/serving/openai_compatible_server.md +++ b/docs/serving/openai_compatible_server.md @@ -431,7 +431,7 @@ Code example: Our Classification API directly supports Hugging Face sequence-classification models such as [ai21labs/Jamba-tiny-reward-dev](https://huggingface.co/ai21labs/Jamba-tiny-reward-dev) and [jason9693/Qwen2.5-1.5B-apeach](https://huggingface.co/jason9693/Qwen2.5-1.5B-apeach). -We automatically wrap any other transformer via `as_classification_model()`, which pools on the last token, attaches a `RowParallelLinear` head, and applies a softmax to produce per-class probabilities. +We automatically wrap any other transformer via `as_seq_cls_model()`, which pools on the last token, attaches a `RowParallelLinear` head, and applies a softmax to produce per-class probabilities. Code example: diff --git a/tests/entrypoints/openai/correctness/test_mteb_score.py b/tests/entrypoints/openai/correctness/test_mteb_score.py index f90fc0b9be0..05e953de4a0 100644 --- a/tests/entrypoints/openai/correctness/test_mteb_score.py +++ b/tests/entrypoints/openai/correctness/test_mteb_score.py @@ -6,19 +6,16 @@ # yapf conflicts with isort for this block # yapf: disable -from tests.models.language.pooling.mteb_utils import (MTEB_RERANK_LANGS, - MTEB_RERANK_TASKS, - MTEB_RERANK_TOL, - RerankClientMtebEncoder, - ScoreClientMtebEncoder, - run_mteb_rerank) +from tests.models.language.pooling.mteb_utils import ( + MTEB_RERANK_LANGS, MTEB_RERANK_TASKS, MTEB_RERANK_TOL, + RerankClientMtebEncoder, ScoreClientMtebEncoder, + mteb_test_rerank_models_hf, run_mteb_rerank) # yapf: enable from tests.utils import RemoteOpenAIServer os.environ["VLLM_LOGGING_LEVEL"] = "WARNING" MODEL_NAME = "cross-encoder/ms-marco-MiniLM-L-6-v2" -MAIN_SCORE = 0.33437 @pytest.fixture(scope="module") @@ -31,12 +28,19 @@ def server(): yield remote_server -def test_mteb_score(server): +@pytest.fixture(scope="module") +def st_main_score(hf_runner): + # The main score related to the version of the dependency. + # So we need to recalculate every time. 
+ main_score, st_dtype = mteb_test_rerank_models_hf(hf_runner, MODEL_NAME) + return main_score + + +def test_mteb_score(server, st_main_score): url = server.url_for("score") encoder = ScoreClientMtebEncoder(MODEL_NAME, url) vllm_main_score = run_mteb_rerank(encoder, MTEB_RERANK_TASKS, MTEB_RERANK_LANGS) - st_main_score = MAIN_SCORE print("VLLM main score: ", vllm_main_score) print("SentenceTransformer main score: ", st_main_score) @@ -45,12 +49,11 @@ def test_mteb_score(server): assert st_main_score == pytest.approx(vllm_main_score, abs=MTEB_RERANK_TOL) -def test_mteb_rerank(server): +def test_mteb_rerank(server, st_main_score): url = server.url_for("rerank") encoder = RerankClientMtebEncoder(MODEL_NAME, url) vllm_main_score = run_mteb_rerank(encoder, MTEB_RERANK_TASKS, MTEB_RERANK_LANGS) - st_main_score = MAIN_SCORE print("VLLM main score: ", vllm_main_score) print("SentenceTransformer main score: ", st_main_score) diff --git a/tests/models/language/pooling/mteb_utils.py b/tests/models/language/pooling/mteb_utils.py index 0284e69f3f0..a83d2581858 100644 --- a/tests/models/language/pooling/mteb_utils.py +++ b/tests/models/language/pooling/mteb_utils.py @@ -234,6 +234,35 @@ def run_mteb_rerank(cross_encoder, tasks, languages): return main_score +def mteb_test_rerank_models_hf(hf_runner, model_name, hf_model_callback=None): + with hf_runner(model_name, is_cross_encoder=True, + dtype="float32") as hf_model: + + original_predict = hf_model.predict + + def _predict( + sentences: list[tuple[str, str, + Optional[str]]], # query, corpus, prompt + *args, + **kwargs, + ): + # vllm and st both remove the prompt, fair comparison. + prompts = [(s[0], s[1]) for s in sentences] + return original_predict(prompts, *args, **kwargs, batch_size=8) + + hf_model.predict = _predict + hf_model.original_predict = original_predict + + if hf_model_callback is not None: + hf_model_callback(hf_model) + + st_main_score = run_mteb_rerank(hf_model, + tasks=MTEB_RERANK_TASKS, + languages=MTEB_RERANK_LANGS) + st_dtype = next(hf_model.model.model.parameters()).dtype + return st_main_score, st_dtype + + def mteb_test_rerank_models(hf_runner, vllm_runner, model_info: RerankModelInfo, @@ -264,31 +293,8 @@ def mteb_test_rerank_models(hf_runner, languages=MTEB_RERANK_LANGS) vllm_dtype = model_config.dtype - with hf_runner(model_info.name, is_cross_encoder=True, - dtype="float32") as hf_model: - - original_predict = hf_model.predict - - def _predict( - sentences: list[tuple[str, str, - Optional[str]]], # query, corpus, prompt - *args, - **kwargs, - ): - # vllm and st both remove the prompt, fair comparison. 
- prompts = [(s[0], s[1]) for s in sentences] - return original_predict(prompts, *args, **kwargs, batch_size=8) - - hf_model.predict = _predict - hf_model.original_predict = original_predict - - if hf_model_callback is not None: - hf_model_callback(hf_model) - - st_main_score = run_mteb_rerank(hf_model, - tasks=MTEB_RERANK_TASKS, - languages=MTEB_RERANK_LANGS) - st_dtype = next(hf_model.model.model.parameters()).dtype + st_main_score, st_dtype = mteb_test_rerank_models_hf( + hf_runner, model_info.name, hf_model_callback) print("VLLM:", vllm_dtype, vllm_main_score) print("SentenceTransformers:", st_dtype, st_main_score) diff --git a/tests/models/test_registry.py b/tests/models/test_registry.py index b7527ca2706..01b2260abe8 100644 --- a/tests/models/test_registry.py +++ b/tests/models/test_registry.py @@ -9,9 +9,9 @@ from vllm.model_executor.models import (is_pooling_model, is_text_generation_model, supports_multimodal) -from vllm.model_executor.models.adapters import (as_classification_model, - as_embedding_model, - as_reward_model) +from vllm.model_executor.models.adapters import (as_embedding_model, + as_reward_model, + as_seq_cls_model) from vllm.model_executor.models.registry import (_MULTIMODAL_MODELS, _SPECULATIVE_DECODING_MODELS, _TEXT_GENERATION_MODELS, @@ -38,7 +38,7 @@ def test_registry_imports(model_arch): assert is_text_generation_model(model_cls) # All vLLM models should be convertible to a pooling model - assert is_pooling_model(as_classification_model(model_cls)) + assert is_pooling_model(as_seq_cls_model(model_cls)) assert is_pooling_model(as_embedding_model(model_cls)) assert is_pooling_model(as_reward_model(model_cls)) diff --git a/tests/test_config.py b/tests/test_config.py index 5d5c4453d30..1e06cde9cbd 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -85,7 +85,7 @@ def test_get_field(): ("distilbert/distilgpt2", "generate", "generate"), ("intfloat/multilingual-e5-small", "pooling", "embed"), ("jason9693/Qwen2.5-1.5B-apeach", "pooling", "classify"), - ("cross-encoder/ms-marco-MiniLM-L-6-v2", "pooling", "score"), + ("cross-encoder/ms-marco-MiniLM-L-6-v2", "pooling", "classify"), ("Qwen/Qwen2.5-Math-RM-72B", "pooling", "reward"), ("openai/whisper-small", "transcription", "transcription"), ], @@ -105,6 +105,32 @@ def test_auto_task(model_id, expected_runner_type, expected_task): assert config.task == expected_task +@pytest.mark.parametrize( + ("model_id", "expected_runner_type", "expected_task"), + [ + ("distilbert/distilgpt2", "pooling", "embed"), + ("intfloat/multilingual-e5-small", "pooling", "embed"), + ("jason9693/Qwen2.5-1.5B-apeach", "pooling", "classify"), + ("cross-encoder/ms-marco-MiniLM-L-6-v2", "pooling", "classify"), + ("Qwen/Qwen2.5-Math-RM-72B", "pooling", "embed"), + ("openai/whisper-small", "pooling", "embed"), + ], +) +def test_score_task(model_id, expected_runner_type, expected_task): + config = ModelConfig( + model_id, + task="score", + tokenizer=model_id, + tokenizer_mode="auto", + trust_remote_code=False, + seed=0, + dtype="float16", + ) + + assert config.runner_type == expected_runner_type + assert config.task == expected_task + + @pytest.mark.parametrize(("model_id", "bad_task"), [ ("Qwen/Qwen2.5-Math-RM-72B", "generate"), ]) diff --git a/vllm/config.py b/vllm/config.py index 7a3329aea5f..8ee0ce04a7a 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -93,14 +93,14 @@ TaskOption = Literal["auto", "generate", "embedding", "embed", "classify", "score", "reward", "transcription"] -_ResolvedTask = Literal["generate", "embed", "classify", 
"score", "reward", - "draft", "transcription"] +_ResolvedTask = Literal["generate", "embed", "classify", "reward", "draft", + "transcription"] RunnerType = Literal["generate", "pooling", "draft", "transcription"] _RUNNER_TASKS: dict[RunnerType, list[_ResolvedTask]] = { "generate": ["generate"], - "pooling": ["embed", "classify", "score", "reward"], + "pooling": ["embed", "classify", "reward"], "draft": ["draft"], "transcription": ["transcription"], } @@ -790,7 +790,7 @@ def _get_preferred_task( if get_pooling_config(model_id, self.revision): return "embed" if self.registry.is_cross_encoder_model(architectures): - return "score" + return "classify" if self.registry.is_transcription_model(architectures): return "transcription" @@ -854,14 +854,24 @@ def _resolve_task( "This model supports multiple tasks: %s. " "Defaulting to '%s'.", supported_tasks, selected_task) else: - # Aliases - if task_option == "embedding": - msg = ("The 'embedding' task has been renamed to " - "'embed', please use the new name. The old name " - "will be removed in v1.0.") - warnings.warn(msg, DeprecationWarning, stacklevel=2) - - task_option = "embed" + if task_option == "score": + if not runner_support["pooling"]: + msg = (f"This model does not support the '{task_option}' " + f"task. Supported tasks: {supported_tasks}") + raise ValueError(msg) + if self.registry.is_cross_encoder_model(architectures): + task_option = "classify" + else: + task_option = "embed" + else: + # Aliases + if task_option == "embedding": + msg = ("The 'embedding' task has been renamed to " + "'embed', please use the new name. The old name " + "will be removed in v1.0.") + warnings.warn(msg, DeprecationWarning, stacklevel=2) + + task_option = "embed" if task_option not in supported_tasks: msg = ( diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index 63967e4d2d4..0572d60b9d0 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -1273,9 +1273,13 @@ def score( raise ValueError(" ".join(messages)) - if self.llm_engine.model_config.task not in ("embed", "score"): - raise ValueError( - "Score API is only enabled for `--task embed or --task score`") + if self.llm_engine.model_config.task not in ("embed", "classify"): + raise ValueError("Score API is only enabled for " + "`--task embed or --task classify`.") + + if (self.llm_engine.model_config.task == "classify" + and self.llm_engine.model_config.hf_config.num_labels != 1): + raise ValueError("Score API is only enabled for num_labels == 1.") # the tokenizer for models such as # "cross-encoder/ms-marco-MiniLM-L-6-v2" doesn't support passing diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index 681633a2aff..84a24828771 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -1263,24 +1263,27 @@ async def init_app_state( chat_template=resolved_chat_template, chat_template_content_format=args.chat_template_content_format, ) if model_config.task == "embed" else None - state.openai_serving_scores = ServingScores( - engine_client, - model_config, - state.openai_serving_models, - request_logger=request_logger) if model_config.task in ( - "score", "embed", "pooling") else None state.openai_serving_classification = ServingClassification( engine_client, model_config, state.openai_serving_models, request_logger=request_logger, ) if model_config.task == "classify" else None + + enable_serving_reranking = (model_config.task == "classify" and getattr( + model_config.hf_config, "num_labels", 0) == 1) 
state.jinaai_serving_reranking = ServingScores( engine_client, model_config, state.openai_serving_models, - request_logger=request_logger - ) if model_config.task == "score" else None + request_logger=request_logger) if enable_serving_reranking else None + state.openai_serving_scores = ServingScores( + engine_client, + model_config, + state.openai_serving_models, + request_logger=request_logger) if ( + model_config.task == "embed" or enable_serving_reranking) else None + state.openai_serving_tokenization = OpenAIServingTokenization( engine_client, model_config, diff --git a/vllm/entrypoints/openai/run_batch.py b/vllm/entrypoints/openai/run_batch.py index 29740fc7e60..e112e2f893a 100644 --- a/vllm/entrypoints/openai/run_batch.py +++ b/vllm/entrypoints/openai/run_batch.py @@ -357,12 +357,16 @@ async def main(args): chat_template=None, chat_template_content_format="auto", ) if model_config.task == "embed" else None + + enable_serving_reranking = (model_config.task == "classify" and getattr( + model_config.hf_config, "num_labels", 0) == 1) + openai_serving_scores = (ServingScores( engine, model_config, openai_serving_models, request_logger=request_logger, - ) if model_config.task == "score" else None) + ) if (model_config.task == "embed" or enable_serving_reranking) else None) tracker = BatchProgressTracker() logger.info("Reading batch from %s...", args.input_file) diff --git a/vllm/model_executor/layers/pooler.py b/vllm/model_executor/layers/pooler.py index 8a33cd6be40..48adcc5fef8 100644 --- a/vllm/model_executor/layers/pooler.py +++ b/vllm/model_executor/layers/pooler.py @@ -285,6 +285,7 @@ def forward(self, pooled_data: Union[list[torch.Tensor], torch.Tensor], else: pooled_data = pooled_data.to(torch.float32) + # for matryoshka representation if isinstance(pooling_metadata, V0PoolingMetadata): dimensions_list = [ pooling_param.dimensions @@ -299,10 +300,16 @@ def forward(self, pooled_data: Union[list[torch.Tensor], torch.Tensor], if any(d is not None for d in dimensions_list): # change the output dimension assert len(pooled_data) == len(dimensions_list) - pooled_data = [ - vecs if d is None else vecs[..., :d] - for vecs, d in zip(pooled_data, dimensions_list) - ] + if len(set(dimensions_list)) == 1 and not isinstance( + pooled_data, list): + # if all dimensions are the same + d = dimensions_list[0] + pooled_data = pooled_data[..., :d] + else: + pooled_data = [ + vecs if d is None else vecs[..., :d] + for vecs, d in zip(pooled_data, dimensions_list) + ] if self.normalize: if isinstance(pooled_data, list): @@ -325,6 +332,10 @@ def forward(self, pooled_data: Union[list[torch.Tensor], torch.Tensor], else: pooled_data = F.sigmoid(pooled_data) + # shape: + # classify (& score) -> (batch_size, num_classes) + # embed -> (batch_size, embedding_dim) or list(embedding_dim) + # (batch_size, dimensions) or list(dimensions) if using MRL return pooled_data @@ -419,7 +430,6 @@ def forward( offset += prompt_len pooled_data.append(pooled_data_i) - offset = 0 pooled_data_lst = [] for pooled_data_i in pooled_data: @@ -436,7 +446,8 @@ def forward( # apply classifier once on the full batch if possible pooled_output = self.classifier(pooled_output) - scores = self.default_activation_function(pooled_output).squeeze(-1) + # shape: (batch_size, num_labels) + scores = self.default_activation_function(pooled_output) pooled_outputs = [PoolingSequenceGroupOutput(data) for data in scores] return PoolerOutput(outputs=pooled_outputs) diff --git a/vllm/model_executor/model_loader/utils.py 
b/vllm/model_executor/model_loader/utils.py index 79e6fa7b16d..6f8f7316ca2 100644 --- a/vllm/model_executor/model_loader/utils.py +++ b/vllm/model_executor/model_loader/utils.py @@ -21,8 +21,7 @@ from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig, QuantizeMethodBase) from vllm.model_executor.models import ModelRegistry -from vllm.model_executor.models.adapters import (as_classification_model, - as_embedding_model, +from vllm.model_executor.models.adapters import (as_embedding_model, as_reward_model) from vllm.utils import is_pin_memory_available @@ -245,7 +244,9 @@ def get_model_architecture( if model_config.task == "embed": model_cls = as_embedding_model(model_cls) elif model_config.task == "classify": - model_cls = as_classification_model(model_cls) + # Cannot automatically run as_seq_cls_model, + # otherwise it will cause a circular reference on is_cross_encoder_model + pass elif model_config.task == "reward": model_cls = as_reward_model(model_cls) diff --git a/vllm/model_executor/models/adapters.py b/vllm/model_executor/models/adapters.py index 1651e3e429e..4611f6704e1 100644 --- a/vllm/model_executor/models/adapters.py +++ b/vllm/model_executor/models/adapters.py @@ -2,7 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from collections.abc import Iterable -from typing import TYPE_CHECKING, Any, Optional, TypeVar +from typing import TYPE_CHECKING, Any, Optional, TypeVar, Union import torch import torch.nn as nn @@ -145,9 +145,9 @@ def as_embedding_model(cls: _T) -> _T: return ModelForEmbedding # type: ignore -def as_classification_model(cls: _T) -> _T: +def as_seq_cls_model(cls: _T) -> _T: """ - Subclass an existing vLLM model to support classification. + Subclass an existing vLLM model to support classify and score tasks. By default, the class probabilities are extracted from the softmaxed hidden state corresponding to the last token. 
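# --- Illustrative sketch (reviewer note, not part of this patch) -------------
# A minimal, hypothetical example of the pattern the docstring above
# describes: as_seq_cls_model pools on the last token's hidden state,
# applies a linear `score` head (a RowParallelLinear in vLLM), and
# softmaxes the logits into per-class probabilities. The class name
# ToySeqClsHead is invented for illustration only.
import torch
import torch.nn as nn
import torch.nn.functional as F

class ToySeqClsHead(nn.Module):
    def __init__(self, hidden_size: int, num_labels: int):
        super().__init__()
        # Stand-in for the adapter's RowParallelLinear `score` head.
        self.score = nn.Linear(hidden_size, num_labels, bias=False)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # hidden_states: (seq_len, hidden_size); pool on the last token.
        last_token_state = hidden_states[-1]
        logits = self.score(last_token_state)
        # Per-class probabilities, shape (num_labels,).
        return F.softmax(logits, dim=-1)
# -----------------------------------------------------------------------------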
@@ -164,7 +164,9 @@ def as_classification_model(cls: _T) -> _T: # Lazy import from vllm.config import VllmConfig from vllm.model_executor.layers.linear import RowParallelLinear - from vllm.model_executor.layers.pooler import PoolingType + from vllm.model_executor.layers.pooler import PoolerOutput, PoolingType + from vllm.model_executor.models.interfaces import SupportsCrossEncoding + from vllm.model_executor.pooling_metadata import PoolingMetadata from vllm.sequence import IntermediateTensors from .utils import maybe_prefix @@ -176,7 +178,8 @@ def as_classification_model(cls: _T) -> _T: default_softmax=True, ) - class ModelForClassification(ModelForPooling): + class ModelForSequenceClassification(ModelForPooling, + SupportsCrossEncoding): def __init__( self, @@ -190,6 +193,10 @@ def __init__( config = vllm_config.model_config.hf_config quant_config = vllm_config.quant_config + self.task = vllm_config.model_config.task + self.pooling_type = ( + vllm_config.model_config.pooler_config.pooling_type) + self.score = RowParallelLinear(config.hidden_size, config.num_labels, quant_config=quant_config, @@ -205,17 +212,41 @@ def forward( intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, ) -> torch.Tensor: - hidden_states = super().forward(input_ids, positions, - intermediate_tensors, - inputs_embeds) - logits, _ = self.score(hidden_states) - return logits + return super().forward(input_ids, positions, intermediate_tensors, + inputs_embeds) + + def pooler( + self, + hidden_states: Union[torch.Tensor, list[torch.Tensor]], + pooling_metadata: PoolingMetadata, + ) -> PoolerOutput: + + def get_logits(hidden_states): + if isinstance(hidden_states, list): + logits = [self.score(state)[0] for state in hidden_states] + else: + logits, _ = self.score(hidden_states) + return logits + + if self.pooling_type == PoolingType.ALL: + logits = get_logits(hidden_states) + return self._pooler(logits, pooling_metadata) + else: + hidden_states = self._pooler.extract_states( + hidden_states, pooling_metadata) + logits = get_logits(hidden_states) + pooled_data = self._pooler.head(logits, pooling_metadata) + + pooled_outputs = [ + self._pooler.build_output(data) for data in pooled_data + ] + return PoolerOutput(outputs=pooled_outputs) - ModelForClassification.__name__ = \ - _get_pooling_model_name(cls.__name__, "ForClassification") + ModelForSequenceClassification.__name__ = \ + _get_pooling_model_name(cls.__name__, "ForSequenceClassification") - return ModelForClassification # type: ignore + return ModelForSequenceClassification # type: ignore def as_reward_model(cls: _T) -> _T: diff --git a/vllm/model_executor/models/qwen2.py b/vllm/model_executor/models/qwen2.py index 23f65b99c22..7ef9d248da4 100644 --- a/vllm/model_executor/models/qwen2.py +++ b/vllm/model_executor/models/qwen2.py @@ -50,6 +50,7 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors +from .adapters import as_seq_cls_model from .interfaces import SupportsLoRA, SupportsPP from .utils import (AutoWeightsLoader, PPMissingLayer, extract_layer_index, is_pp_missing_parameter, @@ -495,3 +496,6 @@ def load_weights(self, weights: Iterable[tuple[str, if self.config.tie_word_embeddings else None), ) return loader.load_weights(weights) + + +Qwen2ForSequenceClassification = as_seq_cls_model(Qwen2ForCausalLM) diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py index faeaf6ef68c..a04fbccfc55 100644 --- 
a/vllm/model_executor/models/registry.py +++ b/vllm/model_executor/models/registry.py @@ -158,8 +158,6 @@ "LlavaNextForConditionalGeneration": ("llava_next", "LlavaNextForConditionalGeneration"), # noqa: E501 "Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"), "Qwen2VLForConditionalGeneration": ("qwen2_vl", "Qwen2VLForConditionalGeneration"), # noqa: E501 - # [Auto-converted (see adapters.py)] - "Qwen2ForSequenceClassification": ("qwen2", "Qwen2ForCausalLM"), # Technically PrithviGeoSpatialMAE is a model that works on images, both in # input and output. I am adding it here because it piggy-backs on embedding # models for the time being. @@ -174,7 +172,9 @@ "RobertaForSequenceClassification"), "ModernBertForSequenceClassification": ("modernbert", "ModernBertForSequenceClassification"), - "Qwen3ForSequenceClassification": ("qwen3", "Qwen3ForSequenceClassification"), # noqa: E501 + # [Auto-converted (see adapters.py)] + "Qwen2ForSequenceClassification": ("qwen2", "Qwen2ForSequenceClassification"), # noqa: E501 + "Qwen3ForSequenceClassification": ("qwen3", "Qwen3ForSequenceClassification"), # noqa: E501 } _MULTIMODAL_MODELS = { diff --git a/vllm/outputs.py b/vllm/outputs.py index 891305eb793..9784a889447 100644 --- a/vllm/outputs.py +++ b/vllm/outputs.py @@ -453,6 +453,7 @@ class ClassificationOutput: @staticmethod def from_base(pooling_output: PoolingOutput): + # pooling_output shape: (num_classes) pooled_data = pooling_output.data if pooled_data.ndim != 1: raise ValueError("pooled_data should be a 1-D probability vector") @@ -490,7 +491,10 @@ class ScoringOutput: @staticmethod def from_base(pooling_output: PoolingOutput): - pooled_data = pooling_output.data + # pooling_output shape: + # classify task: (num_classes) num_classes == 1 + # embed task: a scalar value + pooled_data = pooling_output.data.squeeze() if pooled_data.ndim != 0: raise ValueError("pooled_data should be a scalar score")
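For reviewers who want to exercise the new behavior end to end, here is a minimal usage sketch. Only the `task="score"` resolution and the `num_labels == 1` gate come from this diff; the model choice, example sentences, and printed values are illustrative assumptions.

```python
from vllm import LLM

# With this change, "score" is no longer a resolved task: for a
# cross-encoder such as ms-marco-MiniLM-L-6-v2 it resolves to "classify"
# (num_labels == 1); for other pooling models it falls back to "embed".
llm = LLM(model="cross-encoder/ms-marco-MiniLM-L-6-v2", task="score")

# LLM.score() now requires task in ("embed", "classify") and, for
# "classify", a single-label head (num_labels == 1).
outputs = llm.score(
    "What is the capital of France?",
    ["The capital of France is Paris.", "Berlin is a city in Germany."],
)

for output in outputs:
    # ScoringOutput.from_base() squeezes the (num_classes,) vector,
    # with num_classes == 1, down to a scalar score.
    print(output.outputs.score)
```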