Skip to content

Commit 2d11141

Browse files
committed
Temporary fix for is_cross_encoder
1 parent cd0da36 commit 2d11141

File tree

3 files changed

+15
-6
lines changed

3 files changed

+15
-6
lines changed

examples/offline_inference/convert_model_to_seq_cls.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,15 +13,14 @@
1313
# python convert_model_to_seq_cls.py --model_name Qwen/Qwen3-Reranker-0.6B --classifier_from_tokens '["no", "yes"]' --method from_2_way_softmax --path ./Qwen3-Reranker-0.6B-seq-cls
1414
# for BAAI/bge-reranker-v2-gemma
1515
# python convert_model_to_seq_cls.py --model_name BAAI/bge-reranker-v2-gemma --classifier_from_tokens '["Yes"]' --method no_post_processing --path ./bge-reranker-v2-gemma-seq-cls
16+
# Caution: "Yes" and "yes" are two different tokens
1617

1718

1819
def from_2_way_softmax(
1920
causal_lm, seq_cls_model, tokenizer, classifier_from_tokens, device
2021
):
2122
# for Qwen3-Reranker
2223
# Adapted from https://huggingface.co/Qwen/Qwen3-Reranker-0.6B/discussions/3
23-
assert len(classifier_from_tokens) == 2
24-
2524
lm_head_weights = causal_lm.lm_head.weight
2625

2726
a = tokenizer.convert_tokens_to_ids(classifier_from_tokens[0])
@@ -62,13 +61,22 @@ def converting(
6261
):
6362
assert method in method_map
6463

64+
if method == "from_2_way_softmax":
65+
assert len(classifier_from_tokens) == 2
66+
num_labels = 1
67+
else:
68+
num_labels = len(classifier_from_tokens)
69+
6570
tokenizer = transformers.AutoTokenizer.from_pretrained(model_name)
6671
causal_lm = transformers.AutoModelForCausalLM.from_pretrained(
6772
model_name, device_map=device
6873
)
6974

7075
seq_cls_model = transformers.AutoModelForSequenceClassification.from_pretrained(
71-
model_name, num_labels=1, ignore_mismatched_sizes=True, device_map=device
76+
model_name,
77+
num_labels=num_labels,
78+
ignore_mismatched_sizes=True,
79+
device_map=device,
7280
)
7381

7482
method_map[method](

vllm/config.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1400,7 +1400,9 @@ def is_multimodal_model(self) -> bool:
14001400

14011401
@property
14021402
def is_cross_encoder(self) -> bool:
1403-
return self.registry.is_cross_encoder_model(self.architectures)
1403+
# Temporary solution, See #19675
1404+
return (self.registry.is_cross_encoder_model(self.architectures) or
1405+
"forsequenceclassification" in self.architectures[0].lower())
14041406

14051407
@property
14061408
def use_mla(self) -> bool:

vllm/model_executor/models/registry.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -510,8 +510,7 @@ def is_cross_encoder_model(
510510
architectures: Union[str, list[str]],
511511
) -> bool:
512512
model_cls, _ = self.inspect_model_cls(architectures)
513-
#return model_cls.supports_cross_encoding
514-
return True
513+
return model_cls.supports_cross_encoding
515514

516515
def is_multimodal_model(
517516
self,

0 commit comments

Comments (0)