Skip to content

Commit c451405

Browse files
authored
Merge pull request #21 from codefuse-ai/release_20240409
Release 20240409
2 parents 5af9cf8 + 560c862 commit c451405

File tree

10 files changed

+248
-4
lines changed

10 files changed

+248
-4
lines changed

examples/flask/register.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
# -*- coding: utf-8 -*-
2+
"""
3+
register index for redis
4+
"""
5+
import json
6+
import requests
7+
8+
9+
def run():
10+
url = 'http://127.0.0.1:5000/modelcache'
11+
type = 'register'
12+
scope = {"model": "CODEGPT-1117"}
13+
data = {'type': type, 'scope': scope}
14+
headers = {"Content-Type": "application/json"}
15+
res = requests.post(url, headers=headers, json=json.dumps(data))
16+
res_text = res.text
17+
print('res_text: {}'.format(res_text))
18+
19+
20+
if __name__ == '__main__':
21+
run()

flask4modelcache.py

Lines changed: 22 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
# -*- coding: utf-8 -*-
22
import time
3-
from datetime import datetime
43
from flask import Flask, request
54
import logging
65
import configparser
@@ -15,7 +14,6 @@
1514
from modelcache.utils.model_filter import model_blacklist_filter
1615
from modelcache.embedding import Data2VecAudio
1716

18-
1917
# 创建一个Flask实例
2018
app = Flask(__name__)
2119

@@ -36,11 +34,20 @@ def response_hitquery(cache_resp):
3634
data2vec = Data2VecAudio()
3735
mysql_config = configparser.ConfigParser()
3836
mysql_config.read('modelcache/config/mysql_config.ini')
37+
3938
milvus_config = configparser.ConfigParser()
4039
milvus_config.read('modelcache/config/milvus_config.ini')
40+
41+
# redis_config = configparser.ConfigParser()
42+
# redis_config.read('modelcache/config/redis_config.ini')
43+
44+
4145
data_manager = get_data_manager(CacheBase("mysql", config=mysql_config),
4246
VectorBase("milvus", dimension=data2vec.dimension, milvus_config=milvus_config))
4347

48+
# data_manager = get_data_manager(CacheBase("mysql", config=mysql_config),
49+
# VectorBase("redis", dimension=data2vec.dimension, redis_config=redis_config))
50+
4451

4552
cache.init(
4653
embedding_func=data2vec.to_embeddings,
@@ -85,9 +92,9 @@ def user_backend():
8592
model = model.replace('.', '_')
8693
query = param_dict.get("query")
8794
chat_info = param_dict.get("chat_info")
88-
if request_type is None or request_type not in ['query', 'insert', 'detox', 'remove']:
95+
if request_type is None or request_type not in ['query', 'insert', 'remove', 'register']:
8996
result = {"errorCode": 102,
90-
"errorDesc": "type exception, should one of ['query', 'insert', 'detox', 'remove']",
97+
"errorDesc": "type exception, should one of ['query', 'insert', 'remove', 'register']",
9198
"cacheHit": False, "delta_time": 0, "hit_query": '', "answer": ''}
9299
cache.data_manager.save_query_resp(result, model=model, query='', delta_time=0)
93100
return json.dumps(result)
@@ -170,6 +177,17 @@ def user_backend():
170177
result = {"errorCode": 402, "errorDesc": "", "response": response, "writeStatus": "exception"}
171178
return json.dumps(result)
172179

180+
if request_type == 'register':
181+
# iat_type = param_dict.get("iat_type")
182+
response = adapter.ChatCompletion.create_register(
183+
model=model
184+
)
185+
if response in ['create_success', 'already_exists']:
186+
result = {"errorCode": 0, "errorDesc": "", "response": response, "writeStatus": "success"}
187+
else:
188+
result = {"errorCode": 502, "errorDesc": "", "response": response, "writeStatus": "exception"}
189+
return json.dumps(result)
190+
173191

174192
if __name__ == '__main__':
175193
app.run(host='0.0.0.0', port=5000, debug=True)

modelcache/adapter/adapter.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from modelcache.adapter.adapter_query import adapt_query
66
from modelcache.adapter.adapter_insert import adapt_insert
77
from modelcache.adapter.adapter_remove import adapt_remove
8+
from modelcache.adapter.adapter_register import adapt_register
89

910

1011
class ChatCompletion(openai.ChatCompletion):
@@ -44,6 +45,16 @@ def create_remove(cls, *args, **kwargs):
4445
logging.info('adapt_remove_e: {}'.format(e))
4546
return str(e)
4647

48+
@classmethod
49+
def create_register(cls, *args, **kwargs):
50+
try:
51+
return adapt_register(
52+
*args,
53+
**kwargs
54+
)
55+
except Exception as e:
56+
return str(e)
57+
4758

4859
def construct_resp_from_cache(return_message, return_query):
4960
return {
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
# -*- coding: utf-8 -*-
2+
from modelcache import cache
3+
4+
5+
def adapt_register(*args, **kwargs):
6+
chat_cache = kwargs.pop("cache_obj", cache)
7+
model = kwargs.pop("model", None)
8+
if model is None or len(model) == 0:
9+
return ValueError('')
10+
11+
register_resp = chat_cache.data_manager.create_index(model)
12+
print('register_resp: {}'.format(register_resp))
13+
return register_resp

modelcache/manager/data_manager.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -256,6 +256,9 @@ def delete(self, id_list, **kwargs):
256256
return {'status': 'success', 'milvus': 'delete_count: '+str(v_delete_count),
257257
'mysql': 'delete_count: '+str(s_delete_count)}
258258

259+
def create_index(self, model, **kwargs):
260+
return self.v.create(model)
261+
259262
def truncate(self, model_name):
260263
# model = kwargs.pop("model", None)
261264
# drop milvus data

modelcache/manager/vector_data/manager.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,28 @@ def get(name, **kwargs):
6868
local_mode=local_mode,
6969
local_data=local_data
7070
)
71+
elif name == "redis":
72+
from modelcache.manager.vector_data.redis import RedisVectorStore
73+
dimension = kwargs.get("dimension", DIMENSION)
74+
VectorBase.check_dimension(dimension)
75+
76+
redis_config = kwargs.get("redis_config")
77+
host = redis_config.get('redis', 'host')
78+
port = redis_config.get('redis', 'port')
79+
user = redis_config.get('redis', 'user')
80+
password = redis_config.get('redis', 'password')
81+
namespace = kwargs.get("namespace", "")
82+
# collection_name = kwargs.get("collection_name", COLLECTION_NAME)
83+
84+
vector_base = RedisVectorStore(
85+
host=host,
86+
port=port,
87+
username=user,
88+
password=password,
89+
namespace=namespace,
90+
top_k=top_k,
91+
dimension=dimension,
92+
)
7193
elif name == "faiss":
7294
from modelcache.manager.vector_data.faiss import Faiss
7395

Lines changed: 141 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,141 @@
1+
# -*- coding: utf-8 -*-
2+
from typing import List
3+
import numpy as np
4+
from redis.commands.search.indexDefinition import IndexDefinition, IndexType
5+
from redis.commands.search.query import Query
6+
from redis.commands.search.field import TagField, VectorField, NumericField
7+
from redis.client import Redis
8+
9+
from modelcache.manager.vector_data.base import VectorBase, VectorData
10+
from modelcache.utils import import_redis
11+
from modelcache.utils.log import modelcache_log
12+
from modelcache.utils.index_util import get_index_name
13+
from modelcache.utils.index_util import get_index_prefix
14+
import_redis()
15+
16+
17+
class RedisVectorStore(VectorBase):
18+
def __init__(
19+
self,
20+
host: str = "localhost",
21+
port: str = "6379",
22+
username: str = "",
23+
password: str = "",
24+
dimension: int = 0,
25+
top_k: int = 1,
26+
namespace: str = "",
27+
):
28+
if dimension <= 0:
29+
raise ValueError(
30+
f"invalid `dim` param: {dimension} in the Milvus vector store."
31+
)
32+
self._client = Redis(
33+
host=host, port=int(port), username=username, password=password
34+
)
35+
self.top_k = top_k
36+
self.dimension = dimension
37+
self.namespace = namespace
38+
self.doc_prefix = f"{self.namespace}doc:"
39+
40+
def _check_index_exists(self, index_name: str) -> bool:
41+
"""Check if Redis index exists."""
42+
try:
43+
self._client.ft(index_name).info()
44+
except:
45+
modelcache_log.info("Index does not exist")
46+
return False
47+
modelcache_log.info("Index already exists")
48+
return True
49+
50+
def create_index(self, index_name, index_prefix):
51+
dimension = self.dimension
52+
print('dimension: {}'.format(dimension))
53+
if self._check_index_exists(index_name):
54+
modelcache_log.info(
55+
"The %s already exists, and it will be used directly", index_name
56+
)
57+
return 'already_exists'
58+
else:
59+
id_field_name = "data_id"
60+
embedding_field_name = "data_vector"
61+
62+
id = NumericField(name=id_field_name)
63+
embedding = VectorField(embedding_field_name,
64+
"HNSW", {
65+
"TYPE": "FLOAT32",
66+
"DIM": dimension,
67+
"DISTANCE_METRIC": "L2",
68+
"INITIAL_CAP": 1000,
69+
}
70+
)
71+
fields = [id, embedding]
72+
definition = IndexDefinition(prefix=[index_prefix], index_type=IndexType.HASH)
73+
74+
# create Index
75+
self._client.ft(index_name).create_index(
76+
fields=fields, definition=definition
77+
)
78+
return 'create_success'
79+
80+
def mul_add(self, datas: List[VectorData], model=None):
81+
# pipe = self._client.pipeline()
82+
for data in datas:
83+
id: int = data.id
84+
embedding = data.data.astype(np.float32).tobytes()
85+
id_field_name = "data_id"
86+
embedding_field_name = "data_vector"
87+
obj = {id_field_name: id, embedding_field_name: embedding}
88+
index_prefix = get_index_prefix(model)
89+
self._client.hset(f"{index_prefix}{id}", mapping=obj)
90+
91+
def search(self, data: np.ndarray, top_k: int = -1, model=None):
92+
index_name = get_index_name(model)
93+
id_field_name = "data_id"
94+
embedding_field_name = "data_vector"
95+
96+
base_query = f'*=>[KNN 2 @{embedding_field_name} $vector AS distance]'
97+
query = (
98+
Query(base_query)
99+
.sort_by("distance")
100+
.return_fields(id_field_name, "distance")
101+
.dialect(2)
102+
)
103+
104+
query_params = {"vector": data.astype(np.float32).tobytes()}
105+
results = (
106+
self._client.ft(index_name)
107+
.search(query, query_params=query_params)
108+
.docs
109+
)
110+
return [(float(result.distance), int(getattr(result, id_field_name))) for result in results]
111+
112+
def rebuild(self, ids=None) -> bool:
113+
pass
114+
115+
def rebuild_col(self, model):
116+
index_name_model = get_index_name(model)
117+
if self._check_index_exists(index_name_model):
118+
try:
119+
self._client.ft(index_name_model).dropindex(delete_documents=True)
120+
except Exception as e:
121+
raise ValueError(str(e))
122+
try:
123+
index_prefix = get_index_prefix(model)
124+
self.create_index(index_name_model, index_prefix)
125+
except Exception as e:
126+
raise ValueError(str(e))
127+
return 'rebuild success'
128+
129+
def delete(self, ids) -> None:
130+
pipe = self._client.pipeline()
131+
for data_id in ids:
132+
pipe.delete(f"{self.doc_prefix}{data_id}")
133+
pipe.execute()
134+
135+
def create(self, model=None):
136+
index_name = get_index_name(model)
137+
index_prefix = get_index_prefix(model)
138+
return self.create_index(index_name, index_prefix)
139+
140+
def get_index_by_name(self, index_name):
141+
pass

modelcache/utils/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,3 +69,7 @@ def import_timm():
6969

7070
def import_pillow():
7171
_check_library("PIL", package="pillow")
72+
73+
74+
def import_redis():
75+
_check_library("redis")

modelcache/utils/index_util.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
# -*- coding: utf-8 -*-
2+
3+
4+
def get_index_name(model):
5+
return 'modelcache' + '_' + model
6+
7+
8+
def get_index_prefix(model):
9+
return 'prefix' + '_' + model

requirements.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,3 +10,5 @@ Requests==2.31.0
1010
torch==2.1.0
1111
transformers==4.34.1
1212
faiss-cpu==1.7.4
13+
redis==5.0.1
14+

0 commit comments

Comments
 (0)