Skip to content

Commit 72018ce

Browse files
committed
fastapi for modelcache_demo
1 parent 642e1ca commit 72018ce

File tree

2 files changed

+163
-2
lines changed

2 files changed

+163
-2
lines changed

fastapi4modelcache_demo.py

+162
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,162 @@
1+
# -*- coding: utf-8 -*-
2+
import time
3+
import uvicorn
4+
import asyncio
5+
import logging
6+
# import configparser
7+
import json
8+
from fastapi import FastAPI, Request, HTTPException
9+
from pydantic import BaseModel
10+
from concurrent.futures import ThreadPoolExecutor
11+
from starlette.responses import PlainTextResponse
12+
import functools
13+
14+
from modelcache import cache
15+
from modelcache.adapter import adapter
16+
from modelcache.manager import CacheBase, VectorBase, get_data_manager
17+
from modelcache.similarity_evaluation.distance import SearchDistanceEvaluation
18+
from modelcache.processor.pre import query_multi_splicing
19+
from modelcache.processor.pre import insert_multi_splicing
20+
from modelcache.utils.model_filter import model_blacklist_filter
21+
from modelcache.embedding import Data2VecAudio
22+
23+
# 创建一个FastAPI实例
24+
app = FastAPI()
25+
26+
class RequestData(BaseModel):
27+
type: str
28+
scope: dict = None
29+
query: str = None
30+
chat_info: list = None
31+
remove_type: str = None
32+
id_list: list = []
33+
34+
data2vec = Data2VecAudio()
35+
36+
data_manager = get_data_manager(CacheBase("sqlite"), VectorBase("faiss", dimension=data2vec.dimension))
37+
38+
cache.init(
39+
embedding_func=data2vec.to_embeddings,
40+
data_manager=data_manager,
41+
similarity_evaluation=SearchDistanceEvaluation(),
42+
query_pre_embedding_func=query_multi_splicing,
43+
insert_pre_embedding_func=insert_multi_splicing,
44+
)
45+
46+
executor = ThreadPoolExecutor(max_workers=6)
47+
48+
# 异步保存查询信息
49+
async def save_query_info_fastapi(result, model, query, delta_time_log):
50+
loop = asyncio.get_running_loop()
51+
func = functools.partial(cache.data_manager.save_query_resp, result, model=model, query=json.dumps(query, ensure_ascii=False), delta_time=delta_time_log)
52+
await loop.run_in_executor(None, func)
53+
54+
55+
56+
@app.get("/welcome", response_class=PlainTextResponse)
57+
async def first_fastapi():
58+
return "hello, modelcache!"
59+
60+
@app.post("/modelcache")
61+
async def user_backend(request: Request):
62+
try:
63+
raw_body = await request.body()
64+
# 解析字符串为JSON对象
65+
if isinstance(raw_body, bytes):
66+
raw_body = raw_body.decode("utf-8")
67+
if isinstance(raw_body, str):
68+
try:
69+
# 尝试将字符串解析为JSON对象
70+
request_data = json.loads(raw_body)
71+
except json.JSONDecodeError:
72+
# 如果无法解析,返回格式错误
73+
raise HTTPException(status_code=400, detail="Invalid JSON format")
74+
else:
75+
request_data = raw_body
76+
77+
# 确保request_data是字典对象
78+
if isinstance(request_data, str):
79+
try:
80+
request_data = json.loads(request_data)
81+
except json.JSONDecodeError:
82+
raise HTTPException(status_code=400, detail="Invalid JSON format")
83+
84+
request_type = request_data.get('type')
85+
model = None
86+
if 'scope' in request_data:
87+
model = request_data['scope'].get('model', '').replace('-', '_').replace('.', '_')
88+
query = request_data.get('query')
89+
chat_info = request_data.get('chat_info')
90+
91+
if not request_type or request_type not in ['query', 'insert', 'remove', 'detox']:
92+
raise HTTPException(status_code=400, detail="Type exception, should be one of ['query', 'insert', 'remove', 'detox']")
93+
94+
except Exception as e:
95+
request_data = raw_body if 'raw_body' in locals() else None
96+
result = {
97+
"errorCode": 103,
98+
"errorDesc": str(e),
99+
"cacheHit": False,
100+
"delta_time": 0,
101+
"hit_query": '',
102+
"answer": '',
103+
"para_dict": request_data
104+
}
105+
return result
106+
107+
108+
# model filter
109+
filter_resp = model_blacklist_filter(model, request_type)
110+
if isinstance(filter_resp, dict):
111+
return filter_resp
112+
113+
if request_type == 'query':
114+
try:
115+
start_time = time.time()
116+
response = adapter.ChatCompletion.create_query(scope={"model": model}, query=query)
117+
delta_time = f"{round(time.time() - start_time, 2)}s"
118+
119+
if response is None:
120+
result = {"errorCode": 0, "errorDesc": '', "cacheHit": False, "delta_time": delta_time, "hit_query": '', "answer": ''}
121+
elif response in ['adapt_query_exception']:
122+
# elif isinstance(response, str):
123+
result = {"errorCode": 201, "errorDesc": response, "cacheHit": False, "delta_time": delta_time,
124+
"hit_query": '', "answer": ''}
125+
else:
126+
answer = response['data']
127+
hit_query = response['hitQuery']
128+
result = {"errorCode": 0, "errorDesc": '', "cacheHit": True, "delta_time": delta_time, "hit_query": hit_query, "answer": answer}
129+
130+
delta_time_log = round(time.time() - start_time, 2)
131+
asyncio.create_task(save_query_info_fastapi(result, model, query, delta_time_log))
132+
return result
133+
except Exception as e:
134+
result = {"errorCode": 202, "errorDesc": str(e), "cacheHit": False, "delta_time": 0,
135+
"hit_query": '', "answer": ''}
136+
logging.info(f'result: {str(result)}')
137+
return result
138+
139+
if request_type == 'insert':
140+
try:
141+
response = adapter.ChatCompletion.create_insert(model=model, chat_info=chat_info)
142+
if response == 'success':
143+
return {"errorCode": 0, "errorDesc": "", "writeStatus": "success"}
144+
else:
145+
return {"errorCode": 301, "errorDesc": response, "writeStatus": "exception"}
146+
except Exception as e:
147+
return {"errorCode": 303, "errorDesc": str(e), "writeStatus": "exception"}
148+
149+
if request_type == 'remove':
150+
response = adapter.ChatCompletion.create_remove(model=model, remove_type=request_data.get("remove_type"), id_list=request_data.get("id_list"))
151+
if not isinstance(response, dict):
152+
return {"errorCode": 401, "errorDesc": "", "response": response, "removeStatus": "exception"}
153+
154+
state = response.get('status')
155+
if state == 'success':
156+
return {"errorCode": 0, "errorDesc": "", "response": response, "writeStatus": "success"}
157+
else:
158+
return {"errorCode": 402, "errorDesc": "", "response": response, "writeStatus": "exception"}
159+
160+
# TODO: 可以修改为在命令行中使用`uvicorn your_module_name:app --host 0.0.0.0 --port 5000 --reload`的命令启动
161+
if __name__ == '__main__':
162+
uvicorn.run(app, host='0.0.0.0', port=5000)

modelcache/manager/scalar_data/sql_storage_sqlite.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -100,8 +100,7 @@ def insert_query_resp(self, query_resp, **kwargs):
100100
hit_query = json.dumps(hit_query, ensure_ascii=False)
101101

102102
table_name = "modelcache_query_log"
103-
insert_sql = "INSERT INTO {} (error_code, error_desc, cache_hit, model, query, delta_time, hit_query, answer) VALUES (%s, %s, %s, %s, %s, %s, %s, %s)".format(table_name)
104-
103+
insert_sql = "INSERT INTO {} (error_code, error_desc, cache_hit, model, query, delta_time, hit_query, answer) VALUES (?, ?, ?, ?, ?, ?, ?, ?)".format(table_name)
105104
conn = sqlite3.connect(self._url)
106105
try:
107106
cursor = conn.cursor()

0 commit comments

Comments
 (0)