
Commit 63f85bb

add long-term memory and support toolllama as brains; (#11)
* add long-term memory and support toolllama as brains; support partial match of tool names
* add some docstrings
1 parent 0b8d6e1 commit 63f85bb

18 files changed: +768 −333 lines changed

eval/user_simulator.py

+296 −149
Large diffs are not rendered by default.

llm4crs/agent_plan_first_openai.py

+140 −80
Large diffs are not rendered by default.

llm4crs/corups/base.py

+53 −39

@@ -10,6 +10,8 @@
 
 from pandasql import sqldf
 from pandas.api.types import is_integer_dtype, is_bool_dtype, is_float_dtype, is_datetime64_dtype, is_object_dtype, is_categorical_dtype
+from sentence_transformers import SentenceTransformer
+import torch
 
 from llm4crs.utils import raise_error, SentBERTEngine
 
@@ -34,16 +36,14 @@ def _pd_type_to_sql_type(col: pd.Series) -> str:
 
 
 class BaseGallery:
-
     def __init__(self, fpath: str, column_meaning_file: str, name: str='Item_Information', columns: List[str]=None, sep: str=',', parquet_engine: str='pyarrow',
                  fuzzy_cols: List[str]=['title'], categorical_cols: List[str]=['tags']) -> None:
         self.fpath = fpath
         self.name = name  # name of the table
         self.corups = self._read_file(fpath, columns, sep, parquet_engine)
-        # tags to be displayed to LLM: topk query-related tags + random selected tags
         self.disp_cate_topk: int = 6
         self.disp_cate_total: int = 10
-        self._fuzzy_bert_base = "BAAI/bge-base-en-v1.5"
+        self._fuzzy_bert_base = "thenlper/gte-base"
         self._required_columns_validate()
         self.column_meaning = self._load_col_desc_file(column_meaning_file)
 
@@ -56,11 +56,33 @@ def __init__(self, fpath: str, column_meaning_file: str, name: str='Item_Informa
             else:
                 self.corups[col] = self.corups[col].apply(lambda x: ', '.join(x))
 
-        self.fuzzy_engine: Dict[str:SentBERTEngine] = {
-            col : SentBERTEngine(self.corups[col].to_numpy(), self.corups['id'].to_numpy(), case_sensitive=False, model_name=self._fuzzy_bert_base) if col not in categorical_cols
-                else SentBERTEngine(self.categorical_col_values[col], np.arange(len(self.categorical_col_values[col])), case_sensitive=False, model_name=self._fuzzy_bert_base)
+        if torch.cuda.is_available():
+            device = 'cuda'
+        else:
+            device = 'cpu'
+        _fuzzy_bert_engine = SentenceTransformer(self._fuzzy_bert_base, device=device)
+        self.fuzzy_engine: Dict[str, SentBERTEngine] = {
+            col: SentBERTEngine(
+                self.corups[col].to_numpy(),
+                self.corups["id"].to_numpy(),
+                case_sensitive=False,
+                model=_fuzzy_bert_engine
+            )
+            if col not in categorical_cols
+            else SentBERTEngine(
+                self.categorical_col_values[col],
+                np.arange(len(self.categorical_col_values[col])),
+                case_sensitive=False,
+                model=_fuzzy_bert_engine
+            )
             for col in fuzzy_cols
         }
+        self.fuzzy_engine['sql_cols'] = SentBERTEngine(
+            np.array(columns),
+            np.arange(len(columns)),
+            case_sensitive=False,
+            model=_fuzzy_bert_engine
+        )  # fuzzy engine for column names
         # title as index
         self.corups_title = self.corups.set_index('title', drop=True)
         # id as index
@@ -76,25 +98,18 @@ def __call__(self, sql: str, corups: pd.DataFrame=None, return_id_only: bool=Tru
         Returns:
             list: the result represents by id
         """
-        try:
-            if corups is None:
-                res = sqldf(sql, {self.name: self.corups})  # all games
-            else:
-                res = sqldf(sql, {self.name: corups})  # games in buffer
+        if corups is None:
+            result = sqldf(sql, {self.name: self.corups})  # all games
+        else:
+            result = sqldf(sql, {self.name: corups})  # games in buffer
 
-            if return_id_only:
-                res = res[self.corups.index.name].to_list()
-            else:
-                pass
-            return res
-        except Exception as e:
-            print(e)
-            return []
+        if return_id_only:
+            result = result[self.corups.index.name].to_list()
+        return result
 
 
     def __len__(self) -> int:
         return len(self.corups)
-
 
     def info(self, remove_game_titles: bool=False, query: str=None):
         prefix = 'Table information:'
@@ -110,7 +125,7 @@ def info(self, remove_game_titles: bool=False, query: str=None):
                 disp_values = self.sample_categoricol_values(col, total_n=self.disp_cate_total, query=query, topk=self.disp_cate_topk)
                 _prefix = f" Related values: [{', '.join(disp_values)}]."
                 cols_info += _prefix
-
+
             if dtype in {'float', 'datetime', 'integer'}:
                 _min = self.corups[col].min()
                 _max = self.corups[col].max()
@@ -120,17 +135,18 @@ def info(self, remove_game_titles: bool=False, query: str=None):
             cols_info += _prefix
 
         primary_key = f"Primary Key: {self.corups.index.name}"
-        foreign_key = f"Foreign Key: None"
+        categorical_cols = list(self.categorical_col_values.keys())
+        note = f"Note that [{','.join(categorical_cols)}] columns are categorical, must use related values to search otherwise no result would be returned."
         res = ''
-        for i, s in enumerate([table_name, cols_info, primary_key, foreign_key]):
+        for i, s in enumerate([table_name, cols_info, primary_key, note]):
             res += f"\n{i}. {s}"
         res = prefix + res
         return res
 
     def sample_categoricol_values(self, col_name: str, total_n: int, query: str=None, topk: int=None) -> List:
         # Select topk related tags according to query and sample (total_n-topk) tags
         if query is None:
-            result = random.sample(self.categorical_col_values[col_name].tolist(), k=total_n)
+            result = random.sample(self.categorical_col_values[col_name], k=total_n)
         else:
             if topk is None:
                 topk = total_n
@@ -146,11 +162,11 @@ def sample_categoricol_values(self, col_name: str, total_n: int, query: str=None
         return result
 
 
-    def convert_id_2_info(self, id: Union[int, List[int], np.ndarray], col_names: Union[str, List[str]]=None) -> Union[Dict, List[Dict]]:
-        """Given game id, get game informations.
+    def convert_id_2_info(self, item_id: Union[int, List[int], np.ndarray], col_names: Union[str, List[str]]=None) -> Union[Dict, List[Dict]]:
+        """Given game item_id, get game informations.
 
         Args:
-            - id: game ids.
+            - item_id: game ids.
             - col_names: column names to be returned
 
         Returns:
@@ -167,13 +183,13 @@ def convert_id_2_info(self, id: Union[int, List[int], np.ndarray], col_names: Un
         else:
             raise_error(TypeError, "Not supported type for `col_names`.")
 
-        if isinstance(id, int):
-            items = self.corups.loc[id][col_names].to_dict()
-        elif isinstance(id, list) or isinstance(id, np.ndarray):
-            items = self.corups.loc[id][col_names].to_dict(orient='list')
+        if isinstance(item_id, int):
+            items = self.corups.loc[item_id][col_names].to_dict()
+        elif isinstance(item_id, list) or isinstance(item_id, np.ndarray):
+            items = self.corups.loc[item_id][col_names].to_dict(orient='list')
         else:
-            raise_error(TypeError, "Not supported type for `id`.")
-
+            raise_error(TypeError, "Not supported type for `item_id`.")
+
         return items
 
 
@@ -204,7 +220,7 @@ def convert_title_2_info(self, titles: Union[int, List[int], np.ndarray], col_na
             items = self.corups_title.loc[titles][col_names].to_dict(orient='list')
         else:
             raise_error(TypeError, "Not supported type for `titles`.")
-
+
         return items
 
 
@@ -225,15 +241,14 @@ def _read_file(self, fpath: str, columns: List[str]=None, sep: str=',', parquet_
 
     def _load_col_desc_file(self, fpath: str) -> Dict:
         assert fpath.endswith('.json'), "Only support json file now."
-        with open(fpath, 'r') as f:
+        with open(fpath, 'r', encoding='utf-8') as f:
             return json.load(f)
 
-
+
     def _required_columns_validate(self) -> None:
         for col in _REQUIRED_COLUMNS:
            if col not in self.corups.columns:
                 raise_error(ValueError, f"`id` and `name` are required in item corups table but {col} not found, please check the table file `{self.fpath}`.")
-
 
     def fuzzy_match(self, value: Union[str, List[str]], col: str) -> Union[str, List[str]]:
         if col not in self.fuzzy_engine:
@@ -244,9 +259,8 @@ def fuzzy_match(self, value: Union[str, List[str]], col: str) -> Union[str, List
         return res
 
 
-
 if __name__ == '__main__':
-    from llm4crs.environ_variables import *
+    from llm4crs.environ_variables import GAME_INFO_FILE, TABLE_COL_DESC_FILE
     gallery = BaseGallery(GAME_INFO_FILE, column_meaning_file=TABLE_COL_DESC_FILE)
     print(gallery.info())

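The corups/base.py changes above build the sentence-transformer encoder once and share it across every SentBERTEngine (including the new 'sql_cols' engine, which is what enables partial matching of column names), rather than letting each engine load its own copy via model_name. A minimal sketch of that shared-encoder pattern, assuming only the public sentence_transformers API; the fuzzy_match helper below is illustrative and is not the repository's SentBERTEngine:

# Shared-encoder fuzzy matching: the model is loaded once and reused.
import torch
from sentence_transformers import SentenceTransformer, util

device = "cuda" if torch.cuda.is_available() else "cpu"
encoder = SentenceTransformer("thenlper/gte-base", device=device)  # built once, reused by all matchers

def fuzzy_match(query: str, candidates: list, topk: int = 1) -> list:
    """Return the candidates most similar to `query` under the shared encoder."""
    q_emb = encoder.encode([query.lower()], convert_to_tensor=True, normalize_embeddings=True)
    c_emb = encoder.encode([str(c).lower() for c in candidates], convert_to_tensor=True, normalize_embeddings=True)
    scores = util.cos_sim(q_emb, c_emb)[0]
    idx = torch.topk(scores, k=min(topk, len(candidates))).indices.tolist()
    return [candidates[i] for i in idx]

# e.g. mapping a loosely phrased column/tool name onto the real schema:
print(fuzzy_match("release date", ["title", "tags", "release_date", "price"]))

Loading the model once keeps memory usage flat as the number of fuzzy-matched columns grows, which appears to be the motivation for passing model= rather than model_name= to SentBERTEngine.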
llm4crs/critic/base.py

+5 −8

@@ -86,13 +86,10 @@ def _call(self, request: str, answer: str, history: str, tracks: str):
             answer=answer,
             **TOOL_NAMES,
         )
-        if self.bot_type == "chat":
-            prompt = [
-                {"role": "system", "content": sys_msg},
-                {"role": "user", "content": usr_msg},
-            ]
-        else:
-            prompt = f"{sys_msg}\n{usr_msg}"
 
-        reply = self.bot.call(prompt, max_tokens=128)
+        reply = self.bot.call(
+            sys_prompt=sys_msg,
+            user_prompt=usr_msg,
+            max_tokens=128
+        )
         return reply

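With the critic/base.py change above, the chat-versus-completion branching no longer lives in the critic; it is assumed to be handled behind the bot's call interface. A hypothetical sketch of such a wrapper (the real OpenAICall in llm4crs.utils.open_ai is not shown in this commit):

from typing import Dict, List, Union

class BotWrapper:
    # Hypothetical wrapper: callers pass sys_prompt/user_prompt and the wrapper
    # decides whether to build chat messages or a single completion prompt.
    def __init__(self, bot_type: str = "chat"):
        self.bot_type = bot_type

    def _build_prompt(self, sys_prompt: str, user_prompt: str) -> Union[str, List[Dict[str, str]]]:
        if self.bot_type == "chat":
            return [
                {"role": "system", "content": sys_prompt},
                {"role": "user", "content": user_prompt},
            ]
        return f"{sys_prompt}\n{user_prompt}"

    def call(self, sys_prompt: str, user_prompt: str, max_tokens: int = 128) -> str:
        prompt = self._build_prompt(sys_prompt, user_prompt)
        # A real implementation would send `prompt` to the LLM backend here;
        # a placeholder reply keeps the sketch self-contained.
        return f"<reply to a {len(str(prompt))}-char prompt, max_tokens={max_tokens}>"

Moving the formatting behind call keeps callers such as the critic agnostic to whether the underlying model is a chat or a completion endpoint.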
llm4crs/memory/__init__.py

Whitespace-only changes.

llm4crs/memory/memory.py

+93

@@ -0,0 +1,93 @@
+
+import json
+
+from llm4crs.utils.open_ai import OpenAICall
+
+_FEW_SHOT_EXAMPLES = \
+"""
+> Conversations
+User: My history is ITEM-1, ITEM-2, ITEM-3. Now I want something new.
+Assistent: Based on your preference, I recommend you ITEM-17, ITEM-19, ITEM-30.
+User: I don't like those items, give me more options.
+Assistent: Based on your feedbacks, I recommend you ITEM-5, ITEM-100.
+User: I think ITEM-100 may be very interesting. I may like it.
+> Profiles
+{"history": ["ITEM-1", "ITEM-2", "ITEM-3"], "like": ["ITEM-100"], "unwanted": ["ITEM-17", "ITEM-19", "ITEM-30"]}
+
+> Conversations
+User: I used to enjoy ITEM-89, ITEM-11, ITEM-78, ITEM-67. Now I want something new.
+Assistent: Based on your preference, I recommend you ITEM-53, ITEM-10.
+User: I think ITEM-10 may be very interesting, but I don't like it.
+Assistent: Based on your feedbacks, I recommend you ITEM-88, ITEM-70.
+User: I don't like those items, give me more options.
+> Profiles
+{"history": ["ITEM-89", "ITEM-11", "ITEM-78", "ITEM-67"], "like": [], "unwanted": ["ITEM-10", "ITEM-88", "ITEM-70"]}
+
+"""
+
+class UserProfileMemory:
+    """
+    The memory is used to store long-term user profile. It can be updated by the conversation and used as the input for recommendation tool.
+
+    The memory consists of three parts: history, like and unwanted. Each part is a set. The history is a set of items that the user has interacted with. The like is a set of items that the user likes. The unwanted is a set of items that the user dislikes.
+    """
+    def __init__(self, llm_engine=None, **kwargs) -> None:
+        if llm_engine:
+            self.llm_engine = llm_engine
+        else:
+            self.llm_engine = OpenAICall(**kwargs)
+        self.profile = {
+            "history": set([]),
+            "like": set([]),
+            "unwanted": set([]),
+        }
+
+    def conclude_user_profile(self, conversation: str) -> str:
+        prompt = "Your task is to extract user profile from the conversation."
+        prompt += f"The profile consists of three parts: history, like and unwanted.Each part is a list. You should return a json-format string.\nHere are some examples.\n{_FEW_SHOT_EXAMPLES}\nNow extract user profiles from below conversation: \n> Conversation\n{conversation}\n> Profiles\n"
+        return self.llm_engine.call(
+            user_prompt=prompt,
+            temperature=0.0
+        )
+
+
+    def correct_format(self, err_resp: str) -> str:
+        prompt = "Your task is to correct the string to json format. Here are two examples of the format:\n{\"history\": [\"ITEM-1\", \"ITEM-2\", \"ITEM-3\"], \"like\": [\"ITEM-100\"], \"unwanted\": [\"ITEM-17\", \"ITEM-19\", \"ITEM-30\"]}\nThe string to be corrected is {err_resp}. It can not be parsed by Python json.loads(). Now give the corrected json format string.".replace("{err_resp}", err_resp)
+        return self.llm_engine.call(
+            user_prompt=prompt,
+            sys_prompt="You are an assistent and good at writing json string.",
+            temperature=0.0
+        )
+
+
+    def update(self, conversation: str):
+        cur_profile: str = self.conclude_user_profile(conversation)
+        parse_success = False
+        limit = 3
+        tries = 0
+        while not parse_success and tries < limit:
+            try:
+                cur_profile_dict = json.loads(cur_profile)
+                parse_success = True
+            except json.decoder.JSONDecodeError as e:
+                cur_profile = self.correct_format(cur_profile)
+                tries += 1
+        if parse_success:
+            # update profile
+            self.profile['like'] -= set(cur_profile_dict.get('unwanted', []))
+            self.profile['like'].update(cur_profile_dict.get('like', []))
+            self.profile['unwanted'] -= set(cur_profile_dict.get('like', []))
+            self.profile['unwanted'].update(cur_profile_dict.get('unwanted', []))
+            self.profile['history'].update(cur_profile_dict.get('history', []))
+
+    def get(self) -> dict:
+        return {k: list(v) for k, v in self.profile.items()}
+
+
+    def clear(self):
+        self.profile = {
+            "history": set([]),
+            "like": set([]),
+            "unwanted": set([]),
+        }
+

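A hypothetical usage sketch of the new UserProfileMemory. The stub engine stands in for OpenAICall (whose constructor arguments are not part of this diff); it only needs to expose the same call(...) keyword interface used above:

from llm4crs.memory.memory import UserProfileMemory

class StubLLM:
    """Stand-in engine that always returns a well-formed profile string."""
    def call(self, user_prompt: str, sys_prompt: str = None, temperature: float = 0.0, **kwargs) -> str:
        return '{"history": ["ITEM-1"], "like": ["ITEM-100"], "unwanted": ["ITEM-17"]}'

memory = UserProfileMemory(llm_engine=StubLLM())
memory.update("User: My history is ITEM-1. I think ITEM-100 may be very interesting.")
print(memory.get())   # {'history': ['ITEM-1'], 'like': ['ITEM-100'], 'unwanted': ['ITEM-17']}
memory.clear()        # resets history/like/unwanted to empty sets

Because update() removes newly liked items from 'unwanted' (and vice versa) before merging, repeated calls keep the profile consistent when the user changes their mind across turns.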
llm4crs/prompt/critic.py

+2

@@ -1,6 +1,8 @@
 # Copyright (c) Microsoft Corporation.
 # Licensed under the MIT license.
 
+
+
 CRITIC_PROMPT = \
 """
 {{#system~}}
