
Commit 7a0bb60

hustyichi authored
ENH: add normalize to rerank model (#2509)
Co-authored-by: libing <[email protected]>
Co-authored-by: codingl2k1 <[email protected]>
1 parent 042eb5b commit 7a0bb60

File tree: 4 files changed (+20, -13 lines)


Diff for: xinference/api/restful_api.py

+6-6
@@ -115,6 +115,7 @@ class RerankRequest(BaseModel):
     return_documents: Optional[bool] = False
     return_len: Optional[bool] = False
     max_chunks_per_doc: Optional[int] = None
+    kwargs: Optional[str] = None


 class TextToImageRequest(BaseModel):
@@ -1315,11 +1316,6 @@ async def rerank(self, request: Request) -> Response:
         payload = await request.json()
         body = RerankRequest.parse_obj(payload)
         model_uid = body.model
-        kwargs = {
-            key: value
-            for key, value in payload.items()
-            if key not in RerankRequest.__annotations__.keys()
-        }

         try:
             model = await (await self._get_supervisor_ref()).get_model(model_uid)
@@ -1333,14 +1329,18 @@
             raise HTTPException(status_code=500, detail=str(e))

         try:
+            if body.kwargs is not None:
+                parsed_kwargs = json.loads(body.kwargs)
+            else:
+                parsed_kwargs = {}
             scores = await model.rerank(
                 body.documents,
                 body.query,
                 top_n=body.top_n,
                 max_chunks_per_doc=body.max_chunks_per_doc,
                 return_documents=body.return_documents,
                 return_len=body.return_len,
-                **kwargs,
+                **parsed_kwargs,
             )
             return Response(scores, media_type="application/json")
         except RuntimeError as re:
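
Note: with this change, extra rerank parameters reach the server as a JSON-encoded string in the new `kwargs` field of the request body. A minimal sketch of a raw call, assuming a locally running server on port 9997, a launched reranker with UID bge-reranker-base, and the /v1/rerank path (none of which appear in this diff):

import json
import requests

payload = {
    "model": "bge-reranker-base",  # hypothetical model UID
    "query": "A man is eating pasta.",
    "documents": ["A man is eating food.", "The girl is carrying a baby."],
    # extra parameters travel as a JSON string and are parsed server-side
    "kwargs": json.dumps({"normalize": True}),
}
resp = requests.post("http://127.0.0.1:9997/v1/rerank", json=payload)
print(resp.json())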

Diff for: xinference/client/restful/restful_client.py

+1
@@ -174,6 +174,7 @@ def rerank(
             "max_chunks_per_doc": max_chunks_per_doc,
             "return_documents": return_documents,
             "return_len": return_len,
+            "kwargs": json.dumps(kwargs),
         }
         request_body.update(kwargs)
         response = requests.post(url, json=request_body, headers=self.auth_headers)
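
Note: the client now serializes any extra keyword arguments into that `kwargs` field, so callers just pass them to rerank(). A usage sketch, assuming a running server and an already launched rerank model; the model UID and the `normalize` flag are illustrative rather than taken from this diff:

from xinference.client import Client

client = Client("http://127.0.0.1:9997")
model = client.get_model("bge-reranker-base")  # hypothetical model UID
result = model.rerank(
    documents=["A man is eating food.", "The girl is carrying a baby."],
    query="A man is eating pasta.",
    normalize=True,  # forwarded to the server via the new `kwargs` JSON field
)
print(result)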

Diff for: xinference/model/rerank/core.py

+11-4
@@ -179,6 +179,7 @@ def _auto_detect_type(model_path):
         return rerank_type

     def load(self):
+        logger.info("Loading rerank model: %s", self._model_path)
         flash_attn_installed = importlib.util.find_spec("flash_attn") is not None
         if (
             self._auto_detect_type(self._model_path) != "normal"
@@ -189,6 +190,7 @@ def load(self):
                 "will force set `use_fp16` to True"
             )
             self._use_fp16 = True
+
         if self._model_spec.type == "normal":
             try:
                 import sentence_transformers
@@ -250,22 +252,27 @@ def rerank(
         **kwargs,
     ) -> Rerank:
         assert self._model is not None
-        if kwargs:
-            raise ValueError("rerank hasn't support extra parameter.")
         if max_chunks_per_doc is not None:
             raise ValueError("rerank hasn't support `max_chunks_per_doc` parameter.")
+        logger.info("Rerank with kwargs: %s, model: %s", kwargs, self._model)
         sentence_combinations = [[query, doc] for doc in documents]
         # reset n tokens
         self._model.model.n_tokens = 0
         if self._model_spec.type == "normal":
             similarity_scores = self._model.predict(
-                sentence_combinations, convert_to_numpy=False, convert_to_tensor=True
+                sentence_combinations,
+                convert_to_numpy=False,
+                convert_to_tensor=True,
+                **kwargs,
             ).cpu()
             if similarity_scores.dtype == torch.bfloat16:
                 similarity_scores = similarity_scores.float()
         else:
             # Related issue: https://github.com/xorbitsai/inference/issues/1775
-            similarity_scores = self._model.compute_score(sentence_combinations)
+            similarity_scores = self._model.compute_score(
+                sentence_combinations, **kwargs
+            )
+
         if not isinstance(similarity_scores, Sequence):
             similarity_scores = [similarity_scores]
         elif (
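
Note: core.py now forwards the extra kwargs to the underlying scorer: to the sentence_transformers predict() call for "normal" models and to compute_score() otherwise. A sketch of the latter path, assuming a FlagEmbedding-based reranker (the model name is an assumption); in FlagEmbedding, normalize=True maps the raw relevance logits into [0, 1] with a sigmoid:

from FlagEmbedding import FlagReranker

reranker = FlagReranker("BAAI/bge-reranker-v2-m3", use_fp16=True)
pairs = [["what is a panda?", "The giant panda is a bear species endemic to China."]]
scores = reranker.compute_score(pairs, normalize=True)  # sigmoid-normalized scores
print(scores)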

Diff for: xinference/model/rerank/tests/test_rerank.py

+2-3
@@ -118,9 +118,8 @@ def test_restful_api(model_name, setup):
     kwargs = {
         "invalid": "invalid",
     }
-    with pytest.raises(RuntimeError) as err:
-        scores = model.rerank(corpus, query, **kwargs)
-    assert "hasn't support" in str(err.value)
+    with pytest.raises(RuntimeError):
+        model.rerank(corpus, query, **kwargs)


 def test_from_local_uri():
