Skip to content

Commit 25b46a7

Browse files
kiddog99imClumsyPanda1729457433fxjhello
authored
标题增强 (chatchat-space#631)
* Add files via upload * Update local_doc_qa.py * Update model_config.py * Update zh_title_enhance.py * Add files via upload * Update README.md * fix bugs in MyFAISS.delete_doc * fix:前端知识库获取失败. * update zh_title_enhance.py * update zh_title_enhance.py * Update zh_title_enhance.py * add test/textsplitter * add test_zh_title_enhance.py --------- Co-authored-by: imClumsyPanda <[email protected]> Co-authored-by: JZF <[email protected]> Co-authored-by: fxjhello <[email protected]>
1 parent 3f7e815 commit 25b46a7

File tree

6 files changed

+136
-3
lines changed

6 files changed

+136
-3
lines changed

chains/local_doc_qa.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from agent import bing_search
1818
from langchain.docstore.document import Document
1919
from functools import lru_cache
20+
from textsplitter.zh_title_enhance import zh_title_enhance
2021

2122

2223
# patch HuggingFaceEmbeddings to make it hashable
@@ -56,7 +57,7 @@ def tree(filepath, ignore_dir_names=None, ignore_file_names=None):
5657
return ret_list, [os.path.basename(p) for p in ret_list]
5758

5859

59-
def load_file(filepath, sentence_size=SENTENCE_SIZE):
60+
def load_file(filepath, sentence_size=SENTENCE_SIZE, using_zh_title_enhance=ZH_TITLE_ENHANCE):
6061
if filepath.lower().endswith(".md"):
6162
loader = UnstructuredFileLoader(filepath, mode="elements")
6263
docs = loader.load()
@@ -79,6 +80,8 @@ def load_file(filepath, sentence_size=SENTENCE_SIZE):
7980
loader = UnstructuredFileLoader(filepath, mode="elements")
8081
textsplitter = ChineseTextSplitter(pdf=False, sentence_size=sentence_size)
8182
docs = loader.load_and_split(text_splitter=textsplitter)
83+
if using_zh_title_enhance:
84+
docs = zh_title_enhance(docs)
8285
write_check_file(filepath, docs)
8386
return docs
8487

configs/model_config.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -173,4 +173,9 @@
173173

174174
# 此外,如果是在服务器上,报Failed to establish a new connection: [Errno 110] Connection timed out
175175
# 是因为服务器加了防火墙,需要联系管理员加白名单,如果公司的服务器的话,就别想了GG
176-
BING_SUBSCRIPTION_KEY = ""
176+
BING_SUBSCRIPTION_KEY = ""
177+
178+
# 是否开启中文标题加强,以及标题增强的相关配置
179+
# 通过增加标题判断,判断哪些文本为标题,并在metadata中进行标记;
180+
# 然后将文本与往上一级的标题进行拼合,实现文本信息的增强。
181+
ZH_TITLE_ENHANCE = True
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
from configs.model_config import *
2+
from langchain.embeddings.huggingface import HuggingFaceEmbeddings
3+
import nltk
4+
from vectorstores import MyFAISS
5+
from chains.local_doc_qa import load_file
6+
7+
8+
nltk.data.path = [NLTK_DATA_PATH] + nltk.data.path
9+
10+
if __name__ == "__main__":
11+
filepath = os.path.join(os.path.dirname(os.path.dirname(os.path.dirname(__file__))),
12+
"knowledge_base", "samples", "content", "test.txt")
13+
embeddings = HuggingFaceEmbeddings(model_name=embedding_model_dict[EMBEDDING_MODEL],
14+
model_kwargs={'device': EMBEDDING_DEVICE})
15+
16+
docs = load_file(filepath, using_zh_title_enhance=True)
17+
vector_store = MyFAISS.from_documents(docs, embeddings)
18+
query = "指令提示技术有什么示例"
19+
search_result = vector_store.similarity_search(query)
20+
print(search_result)
21+
pass

textsplitter/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,3 @@
11
from .chinese_text_splitter import ChineseTextSplitter
2-
from .ali_text_splitter import AliTextSplitter
2+
from .ali_text_splitter import AliTextSplitter
3+
from .zh_title_enhance import zh_title_enhance

textsplitter/zh_title_enhance.py

Lines changed: 99 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,99 @@
1+
from langchain.docstore.document import Document
2+
import re
3+
4+
5+
def under_non_alpha_ratio(text: str, threshold: float = 0.5):
6+
"""Checks if the proportion of non-alpha characters in the text snippet exceeds a given
7+
threshold. This helps prevent text like "-----------BREAK---------" from being tagged
8+
as a title or narrative text. The ratio does not count spaces.
9+
10+
Parameters
11+
----------
12+
text
13+
The input string to test
14+
threshold
15+
If the proportion of non-alpha characters exceeds this threshold, the function
16+
returns False
17+
"""
18+
if len(text) == 0:
19+
return False
20+
21+
alpha_count = len([char for char in text if char.strip() and char.isalpha()])
22+
total_count = len([char for char in text if char.strip()])
23+
try:
24+
ratio = alpha_count / total_count
25+
return ratio < threshold
26+
except:
27+
return False
28+
29+
30+
def is_possible_title(
31+
text: str,
32+
title_max_word_length: int = 20,
33+
non_alpha_threshold: float = 0.5,
34+
) -> bool:
35+
"""Checks to see if the text passes all of the checks for a valid title.
36+
37+
Parameters
38+
----------
39+
text
40+
The input text to check
41+
title_max_word_length
42+
The maximum number of words a title can contain
43+
non_alpha_threshold
44+
The minimum number of alpha characters the text needs to be considered a title
45+
"""
46+
47+
# 文本长度为0的话,肯定不是title
48+
if len(text) == 0:
49+
print("Not a title. Text is empty.")
50+
return False
51+
52+
# 文本中有标点符号,就不是title
53+
ENDS_IN_PUNCT_PATTERN = r"[^\w\s]\Z"
54+
ENDS_IN_PUNCT_RE = re.compile(ENDS_IN_PUNCT_PATTERN)
55+
if ENDS_IN_PUNCT_RE.search(text) is not None:
56+
return False
57+
58+
# 文本长度不能超过设定值,默认20
59+
# NOTE(robinson) - splitting on spaces here instead of word tokenizing because it
60+
# is less expensive and actual tokenization doesn't add much value for the length check
61+
if len(text) > title_max_word_length:
62+
return False
63+
64+
# 文本中数字的占比不能太高,否则不是title
65+
if under_non_alpha_ratio(text, threshold=non_alpha_threshold):
66+
return False
67+
68+
# NOTE(robinson) - Prevent flagging salutations like "To My Dearest Friends," as titles
69+
if text.endswith((",", ".", ",", "。")):
70+
return False
71+
72+
if text.isnumeric():
73+
print(f"Not a title. Text is all numeric:\n\n{text}") # type: ignore
74+
return False
75+
76+
# 开头的字符内应该有数字,默认5个字符内
77+
if len(text) < 5:
78+
text_5 = text
79+
else:
80+
text_5 = text[:5]
81+
alpha_in_text_5 = sum(list(map(lambda x: x.isnumeric(), list(text_5))))
82+
if not alpha_in_text_5:
83+
return False
84+
85+
return True
86+
87+
88+
def zh_title_enhance(docs: Document) -> Document:
89+
title = None
90+
if len(docs) > 0:
91+
for doc in docs:
92+
if is_possible_title(doc.page_content):
93+
doc.metadata['category'] = 'cn_Title'
94+
title = doc.page_content
95+
elif title:
96+
doc.page_content = f"下文与({title})有关。{doc.page_content}"
97+
return docs
98+
else:
99+
print("文件不存在")

vectorstores/MyFAISS.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import numpy as np
88
import copy
99
import os
10+
from configs.model_config import *
1011

1112

1213
class MyFAISS(FAISS, VectorStore):
@@ -23,6 +24,9 @@ def __init__(
2324
docstore=docstore,
2425
index_to_docstore_id=index_to_docstore_id,
2526
normalize_L2=normalize_L2)
27+
self.score_threshold=VECTOR_SEARCH_SCORE_THRESHOLD
28+
self.chunk_size = CHUNK_SIZE
29+
self.chunk_conent = False
2630

2731
def seperate_list(self, ls: List[int]) -> List[List[int]]:
2832
# TODO: 增加是否属于同一文档的判断

0 commit comments

Comments
 (0)