10
10
11
11
from pandasql import sqldf
12
12
from pandas .api .types import is_integer_dtype , is_bool_dtype , is_float_dtype , is_datetime64_dtype , is_object_dtype , is_categorical_dtype
13
+ from sentence_transformers import SentenceTransformer
14
+ import torch
13
15
14
16
from llm4crs .utils import raise_error , SentBERTEngine
15
17
@@ -34,16 +36,14 @@ def _pd_type_to_sql_type(col: pd.Series) -> str:
34
36
35
37
36
38
class BaseGallery :
37
-
38
39
def __init__ (self , fpath : str , column_meaning_file : str , name : str = 'Item_Information' , columns : List [str ]= None , sep : str = ',' , parquet_engine : str = 'pyarrow' ,
39
40
fuzzy_cols : List [str ]= ['title' ], categorical_cols : List [str ]= ['tags' ]) -> None :
40
41
self .fpath = fpath
41
42
self .name = name # name of the table
42
43
self .corups = self ._read_file (fpath , columns , sep , parquet_engine )
43
- # tags to be displayed to LLM: topk query-related tags + random selected tags
44
44
self .disp_cate_topk : int = 6
45
45
self .disp_cate_total : int = 10
46
- self ._fuzzy_bert_base = "BAAI/bge -base-en-v1.5 "
46
+ self ._fuzzy_bert_base = "thenlper/gte -base"
47
47
self ._required_columns_validate ()
48
48
self .column_meaning = self ._load_col_desc_file (column_meaning_file )
49
49
@@ -56,11 +56,33 @@ def __init__(self, fpath: str, column_meaning_file: str, name: str='Item_Informa
56
56
else :
57
57
self .corups [col ] = self .corups [col ].apply (lambda x : ', ' .join (x ))
58
58
59
- self .fuzzy_engine : Dict [str :SentBERTEngine ] = {
60
- col : SentBERTEngine (self .corups [col ].to_numpy (), self .corups ['id' ].to_numpy (), case_sensitive = False , model_name = self ._fuzzy_bert_base ) if col not in categorical_cols
61
- else SentBERTEngine (self .categorical_col_values [col ], np .arange (len (self .categorical_col_values [col ])), case_sensitive = False , model_name = self ._fuzzy_bert_base )
59
+ if torch .cuda .is_available ():
60
+ device = 'cuda'
61
+ else :
62
+ device = 'cpu'
63
+ _fuzzy_bert_engine = SentenceTransformer (self ._fuzzy_bert_base , device = device )
64
+ self .fuzzy_engine : Dict [str , SentBERTEngine ] = {
65
+ col : SentBERTEngine (
66
+ self .corups [col ].to_numpy (),
67
+ self .corups ["id" ].to_numpy (),
68
+ case_sensitive = False ,
69
+ model = _fuzzy_bert_engine
70
+ )
71
+ if col not in categorical_cols
72
+ else SentBERTEngine (
73
+ self .categorical_col_values [col ],
74
+ np .arange (len (self .categorical_col_values [col ])),
75
+ case_sensitive = False ,
76
+ model = _fuzzy_bert_engine
77
+ )
62
78
for col in fuzzy_cols
63
79
}
80
+ self .fuzzy_engine ['sql_cols' ] = SentBERTEngine (
81
+ np .array (columns ),
82
+ np .arange (len (columns )),
83
+ case_sensitive = False ,
84
+ model = _fuzzy_bert_engine
85
+ ) # fuzzy engine for column names
64
86
# title as index
65
87
self .corups_title = self .corups .set_index ('title' , drop = True )
66
88
# id as index
@@ -76,25 +98,18 @@ def __call__(self, sql: str, corups: pd.DataFrame=None, return_id_only: bool=Tru
76
98
Returns:
77
99
list: the result represents by id
78
100
"""
79
- try :
80
- if corups is None :
81
- res = sqldf (sql , {self .name : self .corups }) # all games
82
- else :
83
- res = sqldf (sql , {self .name : corups }) # games in buffer
101
+ if corups is None :
102
+ result = sqldf (sql , {self .name : self .corups }) # all games
103
+ else :
104
+ result = sqldf (sql , {self .name : corups }) # games in buffer
84
105
85
- if return_id_only :
86
- res = res [self .corups .index .name ].to_list ()
87
- else :
88
- pass
89
- return res
90
- except Exception as e :
91
- print (e )
92
- return []
106
+ if return_id_only :
107
+ result = result [self .corups .index .name ].to_list ()
108
+ return result
93
109
94
110
95
111
def __len__(self) -> int:
    """Return the number of entries in the item table ``self.corups``."""
    n_items = len(self.corups)
    return n_items
97
-
98
113
99
114
def info (self , remove_game_titles : bool = False , query : str = None ):
100
115
prefix = 'Table information:'
@@ -110,7 +125,7 @@ def info(self, remove_game_titles: bool=False, query: str=None):
110
125
disp_values = self .sample_categoricol_values (col , total_n = self .disp_cate_total , query = query , topk = self .disp_cate_topk )
111
126
_prefix = f" Related values: [{ ', ' .join (disp_values )} ]."
112
127
cols_info += _prefix
113
-
128
+
114
129
if dtype in {'float' , 'datetime' , 'integer' }:
115
130
_min = self .corups [col ].min ()
116
131
_max = self .corups [col ].max ()
@@ -120,17 +135,18 @@ def info(self, remove_game_titles: bool=False, query: str=None):
120
135
cols_info += _prefix
121
136
122
137
primary_key = f"Primary Key: { self .corups .index .name } "
123
- foreign_key = f"Foreign Key: None"
138
+ categorical_cols = list (self .categorical_col_values .keys ())
139
+ note = f"Note that [{ ',' .join (categorical_cols )} ] columns are categorical, must use related values to search otherwise no result would be returned."
124
140
res = ''
125
- for i , s in enumerate ([table_name , cols_info , primary_key , foreign_key ]):
141
+ for i , s in enumerate ([table_name , cols_info , primary_key , note ]):
126
142
res += f"\n { i } . { s } "
127
143
res = prefix + res
128
144
return res
129
145
130
146
def sample_categoricol_values (self , col_name : str , total_n : int , query : str = None , topk : int = None ) -> List :
131
147
# Select topk related tags according to query and sample (total_n-topk) tags
132
148
if query is None :
133
- result = random .sample (self .categorical_col_values [col_name ]. tolist () , k = total_n )
149
+ result = random .sample (self .categorical_col_values [col_name ], k = total_n )
134
150
else :
135
151
if topk is None :
136
152
topk = total_n
@@ -146,11 +162,11 @@ def sample_categoricol_values(self, col_name: str, total_n: int, query: str=None
146
162
return result
147
163
148
164
149
- def convert_id_2_info (self , id : Union [int , List [int ], np .ndarray ], col_names : Union [str , List [str ]]= None ) -> Union [Dict , List [Dict ]]:
150
- """Given game id , get game informations.
165
+ def convert_id_2_info (self , item_id : Union [int , List [int ], np .ndarray ], col_names : Union [str , List [str ]]= None ) -> Union [Dict , List [Dict ]]:
166
+ """Given game item_id , get game informations.
151
167
152
168
Args:
153
- - id : game ids.
169
+ - item_id : game ids.
154
170
- col_names: column names to be returned
155
171
156
172
Returns:
@@ -167,13 +183,13 @@ def convert_id_2_info(self, id: Union[int, List[int], np.ndarray], col_names: Un
167
183
else :
168
184
raise_error (TypeError , "Not supported type for `col_names`." )
169
185
170
- if isinstance (id , int ):
171
- items = self .corups .loc [id ][col_names ].to_dict ()
172
- elif isinstance (id , list ) or isinstance (id , np .ndarray ):
173
- items = self .corups .loc [id ][col_names ].to_dict (orient = 'list' )
186
+ if isinstance (item_id , int ):
187
+ items = self .corups .loc [item_id ][col_names ].to_dict ()
188
+ elif isinstance (item_id , list ) or isinstance (item_id , np .ndarray ):
189
+ items = self .corups .loc [item_id ][col_names ].to_dict (orient = 'list' )
174
190
else :
175
- raise_error (TypeError , "Not supported type for `id `." )
176
-
191
+ raise_error (TypeError , "Not supported type for `item_id `." )
192
+
177
193
return items
178
194
179
195
@@ -204,7 +220,7 @@ def convert_title_2_info(self, titles: Union[int, List[int], np.ndarray], col_na
204
220
items = self .corups_title .loc [titles ][col_names ].to_dict (orient = 'list' )
205
221
else :
206
222
raise_error (TypeError , "Not supported type for `titles`." )
207
-
223
+
208
224
return items
209
225
210
226
@@ -225,15 +241,14 @@ def _read_file(self, fpath: str, columns: List[str]=None, sep: str=',', parquet_
225
241
226
242
def _load_col_desc_file(self, fpath: str) -> Dict:
    """Read the column-description file and return its parsed content.

    Args:
        - fpath: path of the description file; it must end with ``.json``.

    Returns:
        The object produced by ``json.load`` for the file, which is opened
        as UTF-8 text.
    """
    is_json = fpath.endswith('.json')
    assert is_json, "Only support json file now."
    with open(fpath, mode='r', encoding='utf-8') as desc_file:
        col_desc = json.load(desc_file)
    return col_desc
230
246
231
-
247
+
232
248
def _required_columns_validate(self) -> None:
    """Check that every column in ``_REQUIRED_COLUMNS`` exists in ``self.corups``.

    All missing columns are collected first and reported together through a
    single ``raise_error(ValueError, ...)`` call. The list of required columns
    in the message is built from ``_REQUIRED_COLUMNS`` rather than hard-coded,
    so the text cannot drift from what is actually checked.
    """
    present = self.corups.columns
    missing = [col for col in _REQUIRED_COLUMNS if col not in present]
    if not missing:
        return

    required_txt = ', '.join(f"`{col}`" for col in _REQUIRED_COLUMNS)
    missing_txt = ', '.join(f"`{col}`" for col in missing)
    raise_error(
        ValueError,
        f"{required_txt} are required in item corups table but {missing_txt} not found, "
        f"please check the table file `{self.fpath}`."
    )
236
-
237
252
238
253
def fuzzy_match (self , value : Union [str , List [str ]], col : str ) -> Union [str , List [str ]]:
239
254
if col not in self .fuzzy_engine :
@@ -244,9 +259,8 @@ def fuzzy_match(self, value: Union[str, List[str]], col: str) -> Union[str, List
244
259
return res
245
260
246
261
247
-
248
262
if __name__ == '__main__' :
249
- from llm4crs .environ_variables import *
263
+ from llm4crs .environ_variables import GAME_INFO_FILE , TABLE_COL_DESC_FILE
250
264
gallery = BaseGallery (GAME_INFO_FILE , column_meaning_file = TABLE_COL_DESC_FILE )
251
265
print (gallery .info ())
252
266
0 commit comments