12
12
import json
13
13
import os
14
14
import time
15
+ import traceback
15
16
from typing import Generator , Optional , Union , Dict , List , Any
16
17
17
18
import aiohttp
@@ -105,8 +106,47 @@ class AppSettings(BaseSettings):
105
106
106
107
107
108
app_settings = AppSettings ()
108
- app = fastapi .FastAPI (docs_url = "/" )
109
- headers = {"User-Agent" : "FastChat API Server" }
109
+ from contextlib import asynccontextmanager
110
+
111
+ model_address_map = {}
112
+
113
+
114
async def timing_tasks(refresh_interval: float = 6):
    """Keep ``model_address_map`` in sync with the controller, forever.

    Each cycle asks the controller for its model list (``/list_models``),
    resolves one worker address per model concurrently
    (``/get_worker_address``), and then replaces the contents of the
    module-level ``model_address_map`` with the fresh result. Models the
    controller no longer reports are therefore dropped from the cache instead
    of lingering as stale entries.

    If any step of a cycle raises, the traceback is printed, the previous map
    is left untouched, and the next cycle runs after the usual pause.

    :param refresh_interval: Seconds to wait between polling cycles.
    """
    logger.info("定时任务已启动!")
    controller_address = app_settings.controller_address

    while True:
        try:
            models = await fetch_remote(
                controller_address + "/list_models", None, "models"
            )
            # Resolve all worker addresses concurrently; gather() returns the
            # results in the same order as the awaitables it was given.
            worker_addresses = await asyncio.gather(
                *(
                    fetch_remote(
                        controller_address + "/get_worker_address",
                        {"model": model},
                        "address",
                    )
                    for model in models
                )
            )
            latest = dict(zip(models, worker_addresses))
            # Update the existing dict object in place so every reader of the
            # module-level name sees the new state. There is no await between
            # clear() and update(), so other coroutines on this event loop
            # never observe an empty map.
            model_address_map.clear()
            model_address_map.update(latest)
        except Exception:
            # Best-effort poller: report the failure and retry next cycle.
            # Task cancellation is not an ``Exception`` subclass on
            # Python 3.8+, so it still propagates out of this loop.
            traceback.print_exc()
        await asyncio.sleep(refresh_interval)
140
+
141
+
142
@asynccontextmanager
async def lifespan(app: fastapi.FastAPI):
    """Run the background ``timing_tasks`` poller for the app's lifetime.

    The poller is started before the application begins serving and is
    cancelled and awaited when the application shuts down.

    :param app: The FastAPI application this lifespan is attached to.
    """
    # Keep a strong reference: the event loop only holds weak references to
    # tasks, so an unreferenced task may be garbage-collected mid-execution.
    refresh_task = asyncio.create_task(timing_tasks())
    try:
        yield
    finally:
        refresh_task.cancel()
        # return_exceptions=True collects the task's CancelledError as a
        # result instead of re-raising it during shutdown.
        await asyncio.gather(refresh_task, return_exceptions=True)
146
+
147
+
148
+ app = fastapi .FastAPI (docs_url = "/" , lifespan = lifespan )
149
+ headers = {"User-Agent" : "gpt_server API Server" }
110
150
get_bearer_token = HTTPBearer (auto_error = False )
111
151
112
152
@@ -143,11 +183,10 @@ async def validation_exception_handler(request, exc):
143
183
return create_error_response (ErrorCode .VALIDATION_TYPE_ERROR , str (exc ))
144
184
145
185
146
- async def check_model (request ) -> Optional [JSONResponse ]:
147
- controller_address = app_settings . controller_address
186
+ def check_model (request ) -> Optional [JSONResponse ]:
187
+ global model_address_map
148
188
ret = None
149
-
150
- models = await fetch_remote (controller_address + "/list_models" , None , "models" )
189
+ models = list (model_address_map .keys ())
151
190
if request .model not in models :
152
191
ret = create_error_response (
153
192
ErrorCode .INVALID_MODEL ,
@@ -156,54 +195,6 @@ async def check_model(request) -> Optional[JSONResponse]:
156
195
return ret
157
196
158
197
159
- def check_requests (request ) -> Optional [JSONResponse ]:
160
- # Check all params
161
- if request .max_tokens is not None and request .max_tokens <= 0 :
162
- return create_error_response (
163
- ErrorCode .PARAM_OUT_OF_RANGE ,
164
- f"{ request .max_tokens } is less than the minimum of 1 - 'max_tokens'" ,
165
- )
166
- if request .n is not None and request .n <= 0 :
167
- return create_error_response (
168
- ErrorCode .PARAM_OUT_OF_RANGE ,
169
- f"{ request .n } is less than the minimum of 1 - 'n'" ,
170
- )
171
- if request .temperature is not None and request .temperature < 0 :
172
- return create_error_response (
173
- ErrorCode .PARAM_OUT_OF_RANGE ,
174
- f"{ request .temperature } is less than the minimum of 0 - 'temperature'" ,
175
- )
176
- if request .temperature is not None and request .temperature > 2 :
177
- return create_error_response (
178
- ErrorCode .PARAM_OUT_OF_RANGE ,
179
- f"{ request .temperature } is greater than the maximum of 2 - 'temperature'" ,
180
- )
181
- if request .top_p is not None and request .top_p < 0 :
182
- return create_error_response (
183
- ErrorCode .PARAM_OUT_OF_RANGE ,
184
- f"{ request .top_p } is less than the minimum of 0 - 'top_p'" ,
185
- )
186
- if request .top_p is not None and request .top_p > 1 :
187
- return create_error_response (
188
- ErrorCode .PARAM_OUT_OF_RANGE ,
189
- f"{ request .top_p } is greater than the maximum of 1 - 'top_p'" ,
190
- )
191
- if request .top_k is not None and (request .top_k > - 1 and request .top_k < 1 ):
192
- return create_error_response (
193
- ErrorCode .PARAM_OUT_OF_RANGE ,
194
- f"{ request .top_k } is out of Range. Either set top_k to -1 or >=1." ,
195
- )
196
- if request .stop is not None and (
197
- not isinstance (request .stop , str ) and not isinstance (request .stop , list )
198
- ):
199
- return create_error_response (
200
- ErrorCode .PARAM_OUT_OF_RANGE ,
201
- f"{ request .stop } is not valid under any of the given schemas - 'stop'" ,
202
- )
203
-
204
- return None
205
-
206
-
207
198
def process_input (model_name , inp ):
208
199
if isinstance (inp , str ):
209
200
inp = [inp ]
@@ -242,7 +233,7 @@ def _add_to_set(s, new_stop):
242
233
new_stop .update (s )
243
234
244
235
245
- async def get_gen_params (
236
+ def get_gen_params (
246
237
model_name : str ,
247
238
worker_addr : str ,
248
239
messages : Union [str , List [Dict [str , str ]]],
@@ -305,18 +296,16 @@ async def get_gen_params(
305
296
return gen_params
306
297
307
298
308
- async def get_worker_address (model_name : str ) -> str :
299
+ def get_worker_address (model_name : str ) -> str :
309
300
"""
310
301
Get worker address based on the requested model
311
302
312
303
:param model_name: The worker's model name
313
304
:return: Worker address from the controller
314
305
:raises: :class:`ValueError`: No available worker for requested model
315
306
"""
316
- controller_address = app_settings .controller_address
317
- worker_addr = await fetch_remote (
318
- controller_address + "/get_worker_address" , {"model" : model_name }, "address"
319
- )
307
+ global model_address_map
308
+ worker_addr = model_address_map [model_name ]
320
309
321
310
# No available worker
322
311
if worker_addr == "" :
@@ -366,25 +355,30 @@ async def show_available_models():
366
355
)
367
356
368
357
358
@app.get(
    "/get_model_address_map",
    dependencies=[Depends(check_api_key)],
    response_class=responses.ORJSONResponse,
)
def get_model_address_map():
    """Return the cached mapping of model name to worker address."""
    # Read-only access to the module-level cache; no ``global`` declaration
    # is needed because the name is never rebound here.
    return model_address_map
366
+
367
+
369
368
@app .post (
370
369
"/v1/chat/completions" ,
371
370
dependencies = [Depends (check_api_key )],
372
371
response_class = responses .ORJSONResponse ,
373
372
)
374
373
async def create_chat_completion (request : CustomChatCompletionRequest ):
375
374
"""Creates a completion for the chat message"""
376
- error_check_ret = await check_model (request )
377
- if error_check_ret is not None :
378
- return error_check_ret
379
- error_check_ret = check_requests (request )
375
+ error_check_ret = check_model (request )
380
376
if error_check_ret is not None :
381
377
return error_check_ret
382
-
383
- worker_addr = await get_worker_address (request .model )
384
-
385
- gen_params = await get_gen_params (
378
+ worker_addr = get_worker_address (request .model )
379
+ gen_params = get_gen_params (
386
380
request .model ,
387
- worker_addr ,
381
+ "" ,
388
382
request .messages ,
389
383
temperature = request .temperature ,
390
384
top_p = request .top_p ,
@@ -507,16 +501,13 @@ async def chat_completion_stream_generator(
507
501
response_class = responses .ORJSONResponse ,
508
502
)
509
503
async def create_completion (request : CompletionRequest ):
510
- error_check_ret = await check_model (request )
511
- if error_check_ret is not None :
512
- return error_check_ret
513
- error_check_ret = check_requests (request )
504
+ error_check_ret = check_model (request )
514
505
if error_check_ret is not None :
515
506
return error_check_ret
516
507
517
508
request .prompt = process_input (request .model , request .prompt )
518
509
519
- worker_addr = await get_worker_address (request .model )
510
+ worker_addr = get_worker_address (request .model )
520
511
max_tokens = request .max_tokens
521
512
for text in request .prompt :
522
513
if isinstance (max_tokens , int ) and max_tokens < request .max_tokens :
@@ -529,7 +520,7 @@ async def create_completion(request: CompletionRequest):
529
520
else :
530
521
text_completions = []
531
522
for text in request .prompt :
532
- gen_params = await get_gen_params (
523
+ gen_params = get_gen_params (
533
524
request .model ,
534
525
worker_addr ,
535
526
text ,
@@ -587,7 +578,7 @@ async def generate_completion_stream_generator(
587
578
for text in request .prompt :
588
579
for i in range (n ):
589
580
previous_text = ""
590
- gen_params = await get_gen_params (
581
+ gen_params = get_gen_params (
591
582
request .model ,
592
583
worker_addr ,
593
584
text ,
@@ -705,7 +696,7 @@ async def speech(request: OpenAISpeechRequest):
705
696
if error_check_ret is not None :
706
697
return error_check_ret
707
698
708
- worker_addr = await get_worker_address (request .model )
699
+ worker_addr = get_worker_address (request .model )
709
700
response_format = request .response_format
710
701
payload = {
711
702
"model" : request .model ,
@@ -766,7 +757,7 @@ async def speech(request: SpeechRequest):
766
757
async def get_transcriptions (payload : Dict [str , Any ]):
767
758
controller_address = app_settings .controller_address
768
759
model_name = payload ["model" ]
769
- worker_addr = await get_worker_address (model_name )
760
+ worker_addr = get_worker_address (model_name )
770
761
771
762
transcription = await fetch_remote (
772
763
worker_addr + "/worker_get_transcription" , payload
@@ -812,7 +803,7 @@ async def transcriptions(file: UploadFile, model: str = Form()):
812
803
response_class = responses .ORJSONResponse ,
813
804
)
814
805
async def classify (request : ModerationsRequest ):
815
- error_check_ret = await check_model (request )
806
+ error_check_ret = check_model (request )
816
807
if error_check_ret is not None :
817
808
return error_check_ret
818
809
request .input = process_input (request .model , request .input )
@@ -855,7 +846,7 @@ async def classify(request: ModerationsRequest):
855
846
response_class = responses .ORJSONResponse ,
856
847
)
857
848
async def rerank (request : RerankRequest ):
858
- error_check_ret = await check_model (request )
849
+ error_check_ret = check_model (request )
859
850
if error_check_ret is not None :
860
851
return error_check_ret
861
852
request .documents = process_input (request .model , request .documents )
@@ -906,7 +897,7 @@ async def create_embeddings(request: CustomEmbeddingsRequest, model_name: str =
906
897
"""Creates embeddings for the text"""
907
898
if request .model is None :
908
899
request .model = model_name
909
- error_check_ret = await check_model (request )
900
+ error_check_ret = check_model (request )
910
901
if error_check_ret is not None :
911
902
return error_check_ret
912
903
@@ -952,7 +943,7 @@ async def create_embeddings(request: CustomEmbeddingsRequest, model_name: str =
952
943
async def get_classify (payload : Dict [str , Any ]):
953
944
controller_address = app_settings .controller_address
954
945
model_name = payload ["model" ]
955
- worker_addr = await get_worker_address (model_name )
946
+ worker_addr = get_worker_address (model_name )
956
947
957
948
classify = await fetch_remote (worker_addr + "/worker_get_classify" , payload )
958
949
return json .loads (classify )
@@ -961,7 +952,7 @@ async def get_classify(payload: Dict[str, Any]):
961
952
async def get_embedding (payload : Dict [str , Any ]):
962
953
controller_address = app_settings .controller_address
963
954
model_name = payload ["model" ]
964
- worker_addr = await get_worker_address (model_name )
955
+ worker_addr = get_worker_address (model_name )
965
956
966
957
embedding = await fetch_remote (worker_addr + "/worker_get_embeddings" , payload )
967
958
return json .loads (embedding )
@@ -978,7 +969,7 @@ async def count_tokens(request: APITokenCheckRequest):
978
969
"""
979
970
checkedList = []
980
971
for item in request .prompts :
981
- worker_addr = await get_worker_address (item .model )
972
+ worker_addr = get_worker_address (item .model )
982
973
983
974
context_len = await fetch_remote (
984
975
worker_addr + "/model_details" ,
@@ -1008,16 +999,13 @@ async def count_tokens(request: APITokenCheckRequest):
1008
999
@app .post ("/api/v1/chat/completions" )
1009
1000
async def create_chat_completion (request : APIChatCompletionRequest ):
1010
1001
"""Creates a completion for the chat message"""
1011
- error_check_ret = await check_model (request )
1012
- if error_check_ret is not None :
1013
- return error_check_ret
1014
- error_check_ret = check_requests (request )
1002
+ error_check_ret = check_model (request )
1015
1003
if error_check_ret is not None :
1016
1004
return error_check_ret
1017
1005
1018
- worker_addr = await get_worker_address (request .model )
1006
+ worker_addr = get_worker_address (request .model )
1019
1007
1020
- gen_params = await get_gen_params (
1008
+ gen_params = get_gen_params (
1021
1009
request .model ,
1022
1010
worker_addr ,
1023
1011
request .messages ,
0 commit comments