Skip to content

Commit 3b64bda

Browse files
committed
优化代码，提高并发 TTFT 性能
1 parent 578caa7 commit 3b64bda

File tree

1 file changed

+78
-90
lines changed

1 file changed

+78
-90
lines changed

gpt_server/serving/openai_api_server.py

Lines changed: 78 additions & 90 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
import json
1313
import os
1414
import time
15+
import traceback
1516
from typing import Generator, Optional, Union, Dict, List, Any
1617

1718
import aiohttp
@@ -105,8 +106,47 @@ class AppSettings(BaseSettings):
105106

106107

107108
app_settings = AppSettings()
108-
app = fastapi.FastAPI(docs_url="/")
109-
headers = {"User-Agent": "FastChat API Server"}
109+
from contextlib import asynccontextmanager
110+
111+
# In-process cache keyed by model name; each value is the worker address
# fetched from the controller's /get_worker_address endpoint. It is filled
# by timing_tasks() and read by the request handlers.
model_address_map = {}
112+
113+
114+
async def timing_tasks(interval: float = 6):
    """Periodically refresh ``model_address_map`` from the controller.

    On every cycle the controller is asked for its model list
    (``/list_models``) and then, concurrently, for one worker address per
    model (``/get_worker_address``). The module-level ``model_address_map``
    is then replaced with exactly that snapshot, so a model the controller
    no longer reports is dropped instead of lingering with a stale address.

    If any request in a cycle fails, the error is logged and the previous
    mapping is left untouched until the next cycle succeeds.

    :param interval: Seconds to wait between refresh attempts.
    """
    logger.info("定时任务已启动!")
    controller_address = app_settings.controller_address

    while True:
        try:
            models = await fetch_remote(
                controller_address + "/list_models", None, "models"
            )
            # Resolve every model's worker address in parallel.
            worker_addresses = await asyncio.gather(
                *(
                    fetch_remote(
                        controller_address + "/get_worker_address",
                        {"model": model},
                        "address",
                    )
                    for model in models
                )
            )
            fresh_map = dict(zip(models, worker_addresses))
            # Mutate the existing dict in place (no rebinding) and without an
            # ``await`` between the two calls, so no other coroutine can
            # observe a half-updated or empty mapping.
            model_address_map.clear()
            model_address_map.update(fresh_map)
        except Exception:
            # Best-effort background job: log with traceback and retry on
            # the next cycle. CancelledError is not an Exception subclass,
            # so task cancellation still propagates.
            logger.exception("定时任务刷新模型地址失败")
        await asyncio.sleep(interval)
140+
141+
142+
@asynccontextmanager
async def lifespan(app: fastapi.FastAPI):
    """Run the background refresh task for the lifetime of the application.

    The task is started before the app begins serving and is cancelled when
    the app shuts down.

    :param app: The FastAPI application this lifespan is attached to.
    """
    # Hold a strong reference: the event loop only keeps weak references to
    # tasks, so an unreferenced task may be garbage-collected mid-execution.
    refresh_task = asyncio.create_task(timing_tasks())
    try:
        yield
    finally:
        # Stop the endless refresh loop and wait for it to unwind so that
        # shutdown does not leave a pending task behind.
        refresh_task.cancel()
        try:
            await refresh_task
        except asyncio.CancelledError:
            pass
146+
147+
148+
app = fastapi.FastAPI(docs_url="/", lifespan=lifespan)
149+
headers = {"User-Agent": "gpt_server API Server"}
110150
get_bearer_token = HTTPBearer(auto_error=False)
111151

112152

@@ -143,11 +183,10 @@ async def validation_exception_handler(request, exc):
143183
return create_error_response(ErrorCode.VALIDATION_TYPE_ERROR, str(exc))
144184

145185

146-
async def check_model(request) -> Optional[JSONResponse]:
147-
controller_address = app_settings.controller_address
186+
def check_model(request) -> Optional[JSONResponse]:
187+
global model_address_map
148188
ret = None
149-
150-
models = await fetch_remote(controller_address + "/list_models", None, "models")
189+
models = list(model_address_map.keys())
151190
if request.model not in models:
152191
ret = create_error_response(
153192
ErrorCode.INVALID_MODEL,
@@ -156,54 +195,6 @@ async def check_model(request) -> Optional[JSONResponse]:
156195
return ret
157196

158197

159-
def check_requests(request) -> Optional[JSONResponse]:
160-
# Check all params
161-
if request.max_tokens is not None and request.max_tokens <= 0:
162-
return create_error_response(
163-
ErrorCode.PARAM_OUT_OF_RANGE,
164-
f"{request.max_tokens} is less than the minimum of 1 - 'max_tokens'",
165-
)
166-
if request.n is not None and request.n <= 0:
167-
return create_error_response(
168-
ErrorCode.PARAM_OUT_OF_RANGE,
169-
f"{request.n} is less than the minimum of 1 - 'n'",
170-
)
171-
if request.temperature is not None and request.temperature < 0:
172-
return create_error_response(
173-
ErrorCode.PARAM_OUT_OF_RANGE,
174-
f"{request.temperature} is less than the minimum of 0 - 'temperature'",
175-
)
176-
if request.temperature is not None and request.temperature > 2:
177-
return create_error_response(
178-
ErrorCode.PARAM_OUT_OF_RANGE,
179-
f"{request.temperature} is greater than the maximum of 2 - 'temperature'",
180-
)
181-
if request.top_p is not None and request.top_p < 0:
182-
return create_error_response(
183-
ErrorCode.PARAM_OUT_OF_RANGE,
184-
f"{request.top_p} is less than the minimum of 0 - 'top_p'",
185-
)
186-
if request.top_p is not None and request.top_p > 1:
187-
return create_error_response(
188-
ErrorCode.PARAM_OUT_OF_RANGE,
189-
f"{request.top_p} is greater than the maximum of 1 - 'top_p'",
190-
)
191-
if request.top_k is not None and (request.top_k > -1 and request.top_k < 1):
192-
return create_error_response(
193-
ErrorCode.PARAM_OUT_OF_RANGE,
194-
f"{request.top_k} is out of Range. Either set top_k to -1 or >=1.",
195-
)
196-
if request.stop is not None and (
197-
not isinstance(request.stop, str) and not isinstance(request.stop, list)
198-
):
199-
return create_error_response(
200-
ErrorCode.PARAM_OUT_OF_RANGE,
201-
f"{request.stop} is not valid under any of the given schemas - 'stop'",
202-
)
203-
204-
return None
205-
206-
207198
def process_input(model_name, inp):
208199
if isinstance(inp, str):
209200
inp = [inp]
@@ -242,7 +233,7 @@ def _add_to_set(s, new_stop):
242233
new_stop.update(s)
243234

244235

245-
async def get_gen_params(
236+
def get_gen_params(
246237
model_name: str,
247238
worker_addr: str,
248239
messages: Union[str, List[Dict[str, str]]],
@@ -305,18 +296,16 @@ async def get_gen_params(
305296
return gen_params
306297

307298

308-
async def get_worker_address(model_name: str) -> str:
299+
def get_worker_address(model_name: str) -> str:
309300
"""
310301
Get worker address based on the requested model
311302
312303
:param model_name: The worker's model name
313304
:return: Worker address from the controller
314305
:raises: :class:`ValueError`: No available worker for requested model
315306
"""
316-
controller_address = app_settings.controller_address
317-
worker_addr = await fetch_remote(
318-
controller_address + "/get_worker_address", {"model": model_name}, "address"
319-
)
307+
global model_address_map
308+
worker_addr = model_address_map[model_name]
320309

321310
# No available worker
322311
if worker_addr == "":
@@ -366,25 +355,30 @@ async def show_available_models():
366355
)
367356

368357

358+
@app.get(
    "/get_model_address_map",
    response_class=responses.ORJSONResponse,
    dependencies=[Depends(check_api_key)],
)
def get_model_address_map():
    """Return the module-level ``model_address_map`` as the response body."""
    # Reading a module-level name needs no ``global`` declaration.
    return model_address_map
366+
367+
369368
@app.post(
370369
"/v1/chat/completions",
371370
dependencies=[Depends(check_api_key)],
372371
response_class=responses.ORJSONResponse,
373372
)
374373
async def create_chat_completion(request: CustomChatCompletionRequest):
375374
"""Creates a completion for the chat message"""
376-
error_check_ret = await check_model(request)
377-
if error_check_ret is not None:
378-
return error_check_ret
379-
error_check_ret = check_requests(request)
375+
error_check_ret = check_model(request)
380376
if error_check_ret is not None:
381377
return error_check_ret
382-
383-
worker_addr = await get_worker_address(request.model)
384-
385-
gen_params = await get_gen_params(
378+
worker_addr = get_worker_address(request.model)
379+
gen_params = get_gen_params(
386380
request.model,
387-
worker_addr,
381+
"",
388382
request.messages,
389383
temperature=request.temperature,
390384
top_p=request.top_p,
@@ -507,16 +501,13 @@ async def chat_completion_stream_generator(
507501
response_class=responses.ORJSONResponse,
508502
)
509503
async def create_completion(request: CompletionRequest):
510-
error_check_ret = await check_model(request)
511-
if error_check_ret is not None:
512-
return error_check_ret
513-
error_check_ret = check_requests(request)
504+
error_check_ret = check_model(request)
514505
if error_check_ret is not None:
515506
return error_check_ret
516507

517508
request.prompt = process_input(request.model, request.prompt)
518509

519-
worker_addr = await get_worker_address(request.model)
510+
worker_addr = get_worker_address(request.model)
520511
max_tokens = request.max_tokens
521512
for text in request.prompt:
522513
if isinstance(max_tokens, int) and max_tokens < request.max_tokens:
@@ -529,7 +520,7 @@ async def create_completion(request: CompletionRequest):
529520
else:
530521
text_completions = []
531522
for text in request.prompt:
532-
gen_params = await get_gen_params(
523+
gen_params = get_gen_params(
533524
request.model,
534525
worker_addr,
535526
text,
@@ -587,7 +578,7 @@ async def generate_completion_stream_generator(
587578
for text in request.prompt:
588579
for i in range(n):
589580
previous_text = ""
590-
gen_params = await get_gen_params(
581+
gen_params = get_gen_params(
591582
request.model,
592583
worker_addr,
593584
text,
@@ -705,7 +696,7 @@ async def speech(request: OpenAISpeechRequest):
705696
if error_check_ret is not None:
706697
return error_check_ret
707698

708-
worker_addr = await get_worker_address(request.model)
699+
worker_addr = get_worker_address(request.model)
709700
response_format = request.response_format
710701
payload = {
711702
"model": request.model,
@@ -766,7 +757,7 @@ async def speech(request: SpeechRequest):
766757
async def get_transcriptions(payload: Dict[str, Any]):
767758
controller_address = app_settings.controller_address
768759
model_name = payload["model"]
769-
worker_addr = await get_worker_address(model_name)
760+
worker_addr = get_worker_address(model_name)
770761

771762
transcription = await fetch_remote(
772763
worker_addr + "/worker_get_transcription", payload
@@ -812,7 +803,7 @@ async def transcriptions(file: UploadFile, model: str = Form()):
812803
response_class=responses.ORJSONResponse,
813804
)
814805
async def classify(request: ModerationsRequest):
815-
error_check_ret = await check_model(request)
806+
error_check_ret = check_model(request)
816807
if error_check_ret is not None:
817808
return error_check_ret
818809
request.input = process_input(request.model, request.input)
@@ -855,7 +846,7 @@ async def classify(request: ModerationsRequest):
855846
response_class=responses.ORJSONResponse,
856847
)
857848
async def rerank(request: RerankRequest):
858-
error_check_ret = await check_model(request)
849+
error_check_ret = check_model(request)
859850
if error_check_ret is not None:
860851
return error_check_ret
861852
request.documents = process_input(request.model, request.documents)
@@ -906,7 +897,7 @@ async def create_embeddings(request: CustomEmbeddingsRequest, model_name: str =
906897
"""Creates embeddings for the text"""
907898
if request.model is None:
908899
request.model = model_name
909-
error_check_ret = await check_model(request)
900+
error_check_ret = check_model(request)
910901
if error_check_ret is not None:
911902
return error_check_ret
912903

@@ -952,7 +943,7 @@ async def create_embeddings(request: CustomEmbeddingsRequest, model_name: str =
952943
async def get_classify(payload: Dict[str, Any]):
953944
controller_address = app_settings.controller_address
954945
model_name = payload["model"]
955-
worker_addr = await get_worker_address(model_name)
946+
worker_addr = get_worker_address(model_name)
956947

957948
classify = await fetch_remote(worker_addr + "/worker_get_classify", payload)
958949
return json.loads(classify)
@@ -961,7 +952,7 @@ async def get_classify(payload: Dict[str, Any]):
961952
async def get_embedding(payload: Dict[str, Any]):
962953
controller_address = app_settings.controller_address
963954
model_name = payload["model"]
964-
worker_addr = await get_worker_address(model_name)
955+
worker_addr = get_worker_address(model_name)
965956

966957
embedding = await fetch_remote(worker_addr + "/worker_get_embeddings", payload)
967958
return json.loads(embedding)
@@ -978,7 +969,7 @@ async def count_tokens(request: APITokenCheckRequest):
978969
"""
979970
checkedList = []
980971
for item in request.prompts:
981-
worker_addr = await get_worker_address(item.model)
972+
worker_addr = get_worker_address(item.model)
982973

983974
context_len = await fetch_remote(
984975
worker_addr + "/model_details",
@@ -1008,16 +999,13 @@ async def count_tokens(request: APITokenCheckRequest):
1008999
@app.post("/api/v1/chat/completions")
10091000
async def create_chat_completion(request: APIChatCompletionRequest):
10101001
"""Creates a completion for the chat message"""
1011-
error_check_ret = await check_model(request)
1012-
if error_check_ret is not None:
1013-
return error_check_ret
1014-
error_check_ret = check_requests(request)
1002+
error_check_ret = check_model(request)
10151003
if error_check_ret is not None:
10161004
return error_check_ret
10171005

1018-
worker_addr = await get_worker_address(request.model)
1006+
worker_addr = get_worker_address(request.model)
10191007

1020-
gen_params = await get_gen_params(
1008+
gen_params = get_gen_params(
10211009
request.model,
10221010
worker_addr,
10231011
request.messages,

0 commit comments

Comments
 (0)