Skip to content

Commit 99738a7

Browse files
committed
+ converting2seq_cls_models.py
1 parent 02de010 commit 99738a7

File tree

2 files changed

+120
-0
lines changed

2 files changed

+120
-0
lines changed
Lines changed: 115 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,115 @@
1+
# SPDX-License-Identifier: Apache-2.0
2+
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3+
# ruff: noqa: E501
4+
import argparse
5+
import json
6+
7+
import torch
8+
import transformers
9+
10+
# Usage:
11+
# for Qwen3-Reranker
12+
# python converting.py --model_name Qwen/Qwen3-Reranker-0.6B --classifier_from_tokens '["no", "yes"]' --method from_2_way_softmax --path ./Qwen3-Reranker-0.6B-seq-cls
13+
# for BAAI/bge-reranker-v2-gemma
14+
# python converting.py --model_name BAAI/bge-reranker-v2-gemma --classifier_from_tokens '["yes"]' --method no_post_processing --path ./bge-reranker-v2-gemma-seq-cls
15+
16+
17+
def from_2_way_softmax(
18+
causal_lm, seq_cls_model, tokenizer, classifier_from_tokens, device
19+
):
20+
# for Qwen3-Reranker
21+
# Adapted from https://huggingface.co/Qwen/Qwen3-Reranker-0.6B/discussions/3
22+
assert len(classifier_from_tokens) == 2
23+
24+
lm_head_weights = causal_lm.lm_head.weight
25+
26+
a = tokenizer.convert_tokens_to_ids(classifier_from_tokens[0])
27+
b = tokenizer.convert_tokens_to_ids(classifier_from_tokens[1])
28+
29+
score_weight = lm_head_weights[b].to(torch.float32).to(device).to(
30+
torch.float32
31+
) - lm_head_weights[a].to(device)
32+
33+
with torch.no_grad():
34+
seq_cls_model.score.weight.copy_(score_weight.unsqueeze(0))
35+
if seq_cls_model.score.bias is not None:
36+
seq_cls_model.score.bias.zero_()
37+
38+
39+
def no_post_processing(
40+
causal_lm, seq_cls_model, tokenizer, classifier_from_tokens, device
41+
):
42+
# for BAAI/bge-reranker-v2-gemma
43+
44+
lm_head_weights = causal_lm.lm_head.weight
45+
tokens = [tokenizer.convert_tokens_to_ids(t) for t in classifier_from_tokens]
46+
score_weight = lm_head_weights[tokens].to(device)
47+
48+
with torch.no_grad():
49+
seq_cls_model.score.weight.copy_(score_weight)
50+
if seq_cls_model.score.bias is not None:
51+
seq_cls_model.score.bias.zero_()
52+
53+
54+
method_map = {
55+
function.__name__: function for function in [from_2_way_softmax, no_post_processing]
56+
}
57+
58+
59+
def converting(model_name, classifier_from_tokens, path, method, device="cpu"):
60+
assert method in method_map
61+
62+
tokenizer = transformers.AutoTokenizer.from_pretrained(model_name)
63+
causal_lm = transformers.AutoModelForCausalLM.from_pretrained(
64+
model_name, device_map=device
65+
)
66+
67+
seq_cls_model = transformers.AutoModelForSequenceClassification.from_pretrained(
68+
model_name, num_labels=1, ignore_mismatched_sizes=True, device_map=device
69+
)
70+
71+
method_map[method](
72+
causal_lm, seq_cls_model, tokenizer, classifier_from_tokens, device
73+
)
74+
75+
seq_cls_model.config.pad_token_id = tokenizer.pad_token_id
76+
77+
seq_cls_model.save_pretrained(path)
78+
tokenizer.save_pretrained(path)
79+
80+
81+
def parse_args():
82+
parser = argparse.ArgumentParser(
83+
description="Converting *ForCausalLM models to "
84+
"*ForSequenceClassification models."
85+
)
86+
parser.add_argument(
87+
"--model_name", type=str, default="Qwen/Qwen3-Reranker-0.6B", help="Model name"
88+
)
89+
parser.add_argument(
90+
"--classifier_from_tokens",
91+
type=str,
92+
default='["no", "yes"]',
93+
help="classifier from tokens",
94+
)
95+
parser.add_argument(
96+
"--method", type=str, default="from_2_way_softmax", help="Converting converting"
97+
)
98+
parser.add_argument(
99+
"--path",
100+
type=str,
101+
default="./converted_model",
102+
help="Path to save converted model",
103+
)
104+
return parser.parse_args()
105+
106+
107+
if __name__ == "__main__":
108+
args = parse_args()
109+
110+
converting(
111+
model_name=args.model_name,
112+
classifier_from_tokens=json.loads(args.classifier_from_tokens),
113+
method=args.method,
114+
path=args.path,
115+
)

examples/offline_inference/qwen3_reranker.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,11 @@
1919
# concise, for example.
2020
# model = LLM(model="tomaarsen/Qwen3-Reranker-0.6B-seq-cls", task="score")
2121

22+
# Offline conversion from official original version to sequence classification
23+
# model code please refer to: converting2seq_cls_models.py
24+
# The init parameters are as follows.
25+
# model = LLM(model="path_to/converted_model", task="score")
26+
2227
# If you want to load the official original version, the init parameters are
2328
# as follows.
2429

0 commit comments

Comments
 (0)