Commit 6dc55ba

+ ping_pong_test_score_models for test_rerank_models_correctness

1 parent adf3d36 · commit 6dc55ba

6 files changed: +143 −33 lines

examples/offline_inference/converting2seq_cls_models.py

Lines changed: 59 additions & 33 deletions
@@ -9,8 +9,9 @@
 import transformers
 
 
-def from_2_way_softmax(causal_lm, seq_cls_model, tokenizer,
-                       classifier_from_tokens, device):
+def from_2_way_softmax(
+    causal_lm, seq_cls_model, tokenizer, classifier_from_tokens, device
+):
     # for Qwen3-Reranker
     assert len(classifier_from_tokens) == 2
 
@@ -20,32 +21,53 @@ def from_2_way_softmax(causal_lm, seq_cls_model, tokenizer,
     b = tokenizer.convert_tokens_to_ids(classifier_from_tokens[1])
 
     score_weight = lm_head_weights[b].to(torch.float32).to(device).to(
-        torch.float32) - lm_head_weights[a].to(device)
+        torch.float32
+    ) - lm_head_weights[a].to(device)
 
     with torch.no_grad():
         seq_cls_model.score.weight.copy_(score_weight.unsqueeze(0))
         if seq_cls_model.score.bias is not None:
             seq_cls_model.score.bias.zero_()
 
 
-method_map = {function.__name__: function for function in [from_2_way_softmax]}
+def from_1_way_sigmoid(
+    causal_lm, seq_cls_model, tokenizer, classifier_from_tokens, device
+):
+    # for BAAI/bge-reranker-v2-gemma
+    assert len(classifier_from_tokens) == 1
+
+    lm_head_weights = causal_lm.lm_head.weight
+
+    a = tokenizer.convert_tokens_to_ids(classifier_from_tokens[0])
+
+    score_weight = lm_head_weights[a].to(device)
+
+    with torch.no_grad():
+        seq_cls_model.score.weight.copy_(score_weight.unsqueeze(0))
+        if seq_cls_model.score.bias is not None:
+            seq_cls_model.score.bias.zero_()
+
+
+method_map = {
+    function.__name__: function for function in [from_2_way_softmax, from_1_way_sigmoid]
+}
 
 
 def converting(model_name, classifier_from_tokens, path, method, device="cpu"):
     assert method in method_map
 
     tokenizer = transformers.AutoTokenizer.from_pretrained(model_name)
     causal_lm = transformers.AutoModelForCausalLM.from_pretrained(
-        model_name, device_map=device)
+        model_name, device_map=device
+    )
 
     seq_cls_model = transformers.AutoModelForSequenceClassification.from_pretrained(
-        model_name,
-        num_labels=1,
-        ignore_mismatched_sizes=True,
-        device_map=device)
+        model_name, num_labels=1, ignore_mismatched_sizes=True, device_map=device
+    )
 
-    method_map[method](causal_lm, seq_cls_model, tokenizer,
-                       classifier_from_tokens, device)
+    method_map[method](
+        causal_lm, seq_cls_model, tokenizer, classifier_from_tokens, device
+    )
 
     seq_cls_model.config.pad_token_id = tokenizer.pad_token_id
 
@@ -55,31 +77,35 @@ def converting(model_name, classifier_from_tokens, path, method, device="cpu"):
 
 def parse_args():
     parser = argparse.ArgumentParser(
-        description=
-        "Converting *ForCausalLM models to *ForSequenceClassification models.")
-    parser.add_argument("--model_name",
-                        type=str,
-                        default="Qwen/Qwen3-Reranker-0.6B",
-                        help="Model name")
-    parser.add_argument("--classifier_from_tokens",
-                        type=str,
-                        default='["no", "yes"]',
-                        help="classifier from tokens")
-    parser.add_argument("--method",
-                        type=str,
-                        default='from_2_way_softmax',
-                        help="Converting converting")
-    parser.add_argument("--path",
-                        type=str,
-                        default="./converted_model",
-                        help="Path to save converted model")
+        description="Converting *ForCausalLM models to *ForSequenceClassification models."
+    )
+    parser.add_argument(
+        "--model_name", type=str, default="Qwen/Qwen3-Reranker-0.6B", help="Model name"
+    )
+    parser.add_argument(
+        "--classifier_from_tokens",
+        type=str,
+        default='["no", "yes"]',
+        help="classifier from tokens",
+    )
+    parser.add_argument(
+        "--method", type=str, default="from_2_way_softmax", help="Converting converting"
+    )
+    parser.add_argument(
+        "--path",
+        type=str,
+        default="./converted_model",
+        help="Path to save converted model",
+    )
     return parser.parse_args()
 
 
 if __name__ == "__main__":
     args = parse_args()
 
-    converting(model_name=args.model_name,
-               classifier_from_tokens=json.loads(args.classifier_from_tokens),
-               method=args.method,
-               path=args.path)
+    converting(
+        model_name=args.model_name,
+        classifier_from_tokens=json.loads(args.classifier_from_tokens),
+        method=args.method,
+        path=args.path,
+    )
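
A note on why from_2_way_softmax only needs the difference of two lm_head rows: for a two-token classifier ("no"/"yes"), the softmax probability of "yes" over just those two logits collapses to sigmoid((w_yes − w_no) · h), so copying w_yes − w_no into a single-label score head reproduces the CausalLM reranker's score. The snippet below is a standalone numerical check of that identity, not part of the commit; the shapes and names are illustrative.

# Standalone sketch (not part of the commit): numerically checks that a
# softmax over the two classifier-token logits equals a sigmoid of the
# weight-difference score that from_2_way_softmax installs.
import torch

hidden_size = 8                   # illustrative size
h = torch.randn(hidden_size)      # stand-in for the final hidden state
w_no = torch.randn(hidden_size)   # lm_head row for the "no" token
w_yes = torch.randn(hidden_size)  # lm_head row for the "yes" token

# CausalLM-style score: P("yes") from a softmax over the two token logits.
p_softmax = torch.softmax(torch.stack([w_no @ h, w_yes @ h]), dim=0)[1]

# Sequence-classification-style score: sigmoid of the difference row,
# i.e. the weight that from_2_way_softmax copies into score.weight.
p_sigmoid = torch.sigmoid((w_yes - w_no) @ h)

assert torch.allclose(p_softmax, p_sigmoid)

from_1_way_sigmoid is the one-token degenerate case: with a single classifier token there is nothing to subtract, so the lm_head row is copied as-is and the score head emits sigmoid(w · h) directly.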
tests/models/language/pooling/score_utils.py (new file)

Lines changed: 56 additions & 0 deletions
@@ -0,0 +1,56 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+import math
+
+import pytest
+
+from tests.models.utils import RerankModelInfo
+
+
+def ping_pong_test_score_models(hf_runner,
+                                vllm_runner,
+                                model_info: RerankModelInfo,
+                                vllm_extra_kwargs=None,
+                                hf_model_callback=None):
+    if not model_info.enable_test:
+        # A model family has many models with the same architecture,
+        # and we don't need to test each one.
+        pytest.skip("Skipping test.")
+
+    sentences = []
+
+    vllm_extra_kwargs = vllm_extra_kwargs or {}
+    # This test must use float32 to pass.
+    vllm_extra_kwargs["dtype"] = "float32"
+
+    with vllm_runner(model_info.name,
+                     task="score",
+                     max_model_len=None,
+                     **vllm_extra_kwargs) as vllm_model:
+
+        max_model_len = vllm_model.model.llm_engine.model_config.max_model_len
+
+        for i in range(0, int(math.log2(max_model_len - 1))):
+            sentences.append(("ping", "pong" * 2**i))
+
+        text_1 = [x[0] for x in sentences]
+        text_2 = [x[1] for x in sentences]
+        vllm_outputs = vllm_model.score(text_1=text_1, text_2=text_2)
+
+    with hf_runner(
+            model_info.name,
+            dtype="float32",
+            is_cross_encoder=True,
+    ) as hf_model:
+
+        if hf_model_callback is not None:
+            hf_model_callback(hf_model)
+
+        # use batchsize = 1 to avoid oom
+        hf_outputs = [
+            hf_model.predict([sentences[i]])[0] for i in range(len(sentences))
+        ]
+
+    for i in range(len(sentences)):
+        assert float(hf_outputs[i]) == pytest.approx(float(vllm_outputs[i]), rel=0.01), \
+            f"Test failed at #{i}, vllm: {vllm_outputs[i]}, st: {hf_outputs[i]}"

tests/models/language/pooling/test_baai.py

Lines changed: 7 additions & 0 deletions
@@ -5,6 +5,7 @@
 from ...utils import EmbedModelInfo, RerankModelInfo
 from .embed_utils import correctness_test_embed_models
 from .mteb_utils import mteb_test_embed_models, mteb_test_rerank_models
+from .score_utils import ping_pong_test_score_models
 
 MODELS = [
     ########## BertModel
@@ -91,3 +92,9 @@ def test_embed_models_correctness(hf_runner, vllm_runner,
 def test_rerank_models_mteb(hf_runner, vllm_runner,
                             model_info: RerankModelInfo) -> None:
     mteb_test_rerank_models(hf_runner, vllm_runner, model_info)
+
+
+@pytest.mark.parametrize("model_info", RERANK_MODELS)
+def test_rerank_models_correctness(hf_runner, vllm_runner,
+                                   model_info: RerankModelInfo) -> None:
+    ping_pong_test_score_models(hf_runner, vllm_runner, model_info)
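
The same two-hunk change repeats in the three test files below: import the helper, then register a parametrized test_rerank_models_correctness alongside the existing MTEB test. Any one of them can be run on its own with pytest's keyword filter, for example (illustrative invocation):

pytest tests/models/language/pooling/test_baai.py -k test_rerank_models_correctness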

tests/models/language/pooling/test_cross_encoder.py

Lines changed: 7 additions & 0 deletions
@@ -3,6 +3,7 @@
 import pytest
 
 from .mteb_utils import RerankModelInfo, mteb_test_rerank_models
+from .score_utils import ping_pong_test_score_models
 
 RERANK_MODELS = [
     RerankModelInfo("cross-encoder/ms-marco-TinyBERT-L-2-v2",
@@ -16,3 +17,9 @@
 def test_rerank_models_mteb(hf_runner, vllm_runner,
                             model_info: RerankModelInfo) -> None:
     mteb_test_rerank_models(hf_runner, vllm_runner, model_info)
+
+
+@pytest.mark.parametrize("model_info", RERANK_MODELS)
+def test_rerank_models_correctness(hf_runner, vllm_runner,
+                                   model_info: RerankModelInfo) -> None:
+    ping_pong_test_score_models(hf_runner, vllm_runner, model_info)

tests/models/language/pooling/test_gte.py

Lines changed: 7 additions & 0 deletions
@@ -7,6 +7,7 @@
 from ...utils import RerankModelInfo
 from .embed_utils import EmbedModelInfo, correctness_test_embed_models
 from .mteb_utils import mteb_test_embed_models, mteb_test_rerank_models
+from .score_utils import ping_pong_test_score_models
 
 MODELS = [
     ########## BertModel
@@ -93,3 +94,9 @@ def test_embed_models_correctness(hf_runner, vllm_runner,
 def test_rerank_models_mteb(hf_runner, vllm_runner,
                             model_info: RerankModelInfo) -> None:
     mteb_test_rerank_models(hf_runner, vllm_runner, model_info)
+
+
+@pytest.mark.parametrize("model_info", RERANK_MODELS)
+def test_rerank_models_correctness(hf_runner, vllm_runner,
+                                   model_info: RerankModelInfo) -> None:
+    ping_pong_test_score_models(hf_runner, vllm_runner, model_info)

tests/models/language/pooling/test_jina.py

Lines changed: 7 additions & 0 deletions
@@ -10,6 +10,7 @@
 from .embed_utils import (check_embeddings_close,
                           correctness_test_embed_models, matryoshka_fy)
 from .mteb_utils import mteb_test_embed_models, mteb_test_rerank_models
+from .score_utils import ping_pong_test_score_models
 
 EMBEDDING_MODELS = [
     EmbedModelInfo("jinaai/jina-embeddings-v3",
@@ -60,6 +61,12 @@ def test_rerank_models_mteb(hf_runner, vllm_runner,
     mteb_test_rerank_models(hf_runner, vllm_runner, model_info)
 
 
+@pytest.mark.parametrize("model_info", RERANK_MODELS)
+def test_rerank_models_correctness(hf_runner, vllm_runner,
+                                   model_info: RerankModelInfo) -> None:
+    ping_pong_test_score_models(hf_runner, vllm_runner, model_info)
+
+
 @pytest.mark.parametrize("model_info", EMBEDDING_MODELS)
 @pytest.mark.parametrize("dtype", ["half"])
 @pytest.mark.parametrize("dimensions", [16, 32])
