Skip to content

Commit 8b3be42

Browse files
authored
revert batch query (langgenius#17707)
1 parent 1d5c07d commit 8b3be42

File tree

1 file changed

+74
-109
lines changed

1 file changed

+74
-109
lines changed

api/core/rag/datasource/retrieval_service.py

+74 -109
Original file line number | Diff line number | Diff line change
@@ -1,13 +1,9 @@
11
import concurrent.futures
2-
import logging
3-
import time
42
from concurrent.futures import ThreadPoolExecutor
53
from typing import Optional
64

75
from flask import Flask, current_app
8-
from sqlalchemy import and_, or_
96
from sqlalchemy.orm import load_only
10-
from sqlalchemy.sql.expression import false
117

128
from configs import dify_config
139
from core.rag.data_post_processor.data_post_processor import DataPostProcessor
@@ -182,7 +178,6 @@ def embedding_search(
182178
if not dataset:
183179
raise ValueError("dataset not found")
184180

185-
start = time.time()
186181
vector = Vector(dataset=dataset)
187182
documents = vector.search_by_vector(
188183
query,
@@ -192,7 +187,6 @@ def embedding_search(
192187
filter={"group_id": [dataset.id]},
193188
document_ids_filter=document_ids_filter,
194189
)
195-
logging.debug(f"embedding_search ends at {time.time() - start:.2f} seconds")
196190

197191
if documents:
198192
if (
@@ -276,8 +270,7 @@ def format_retrieval_documents(cls, documents: list[Document]) -> list[Retrieval
276270
return []
277271

278272
try:
279-
start_time = time.time()
280-
# Collect document IDs with existence check
273+
# Collect document IDs
281274
document_ids = {doc.metadata.get("document_id") for doc in documents if "document_id" in doc.metadata}
282275
if not document_ids:
283276
return []
@@ -295,138 +288,110 @@ def format_retrieval_documents(cls, documents: list[Document]) -> list[Retrieval
295288
include_segment_ids = set()
296289
segment_child_map = {}
297290

298-
# Precompute doc_forms to avoid redundant checks
299-
doc_forms = {}
300-
for doc in documents:
301-
document_id = doc.metadata.get("document_id")
302-
dataset_doc = dataset_documents.get(document_id)
303-
if dataset_doc:
304-
doc_forms[document_id] = dataset_doc.doc_form
305-
306-
# Batch collect index node IDs with type safety
307-
child_index_node_ids = []
308-
index_node_ids = []
309-
for doc in documents:
310-
document_id = doc.metadata.get("document_id")
311-
if doc_forms.get(document_id) == IndexType.PARENT_CHILD_INDEX:
312-
child_index_node_ids.append(doc.metadata.get("doc_id"))
313-
else:
314-
index_node_ids.append(doc.metadata.get("doc_id"))
315-
316-
# Batch query ChildChunk
317-
child_chunks = db.session.query(ChildChunk).filter(ChildChunk.index_node_id.in_(child_index_node_ids)).all()
318-
child_chunk_map = {chunk.index_node_id: chunk for chunk in child_chunks}
319-
320-
segment_ids_from_child = [chunk.segment_id for chunk in child_chunks]
321-
segment_conditions = []
322-
323-
if index_node_ids:
324-
segment_conditions.append(DocumentSegment.index_node_id.in_(index_node_ids))
325-
326-
if segment_ids_from_child:
327-
segment_conditions.append(DocumentSegment.id.in_(segment_ids_from_child))
328-
329-
if segment_conditions:
330-
filter_expr = or_(*segment_conditions)
331-
else:
332-
filter_expr = false()
333-
334-
segment_map = {
335-
segment.id: segment
336-
for segment in db.session.query(DocumentSegment)
337-
.filter(
338-
and_(
339-
filter_expr,
340-
DocumentSegment.enabled == True,
341-
DocumentSegment.status == "completed",
342-
)
343-
)
344-
.options(
345-
load_only(
346-
DocumentSegment.id,
347-
DocumentSegment.content,
348-
DocumentSegment.answer,
349-
)
350-
)
351-
.all()
352-
}
353-
291+
# Process documents
354292
for document in documents:
355293
document_id = document.metadata.get("document_id")
356-
dataset_document = dataset_documents.get(document_id)
294+
if document_id not in dataset_documents:
295+
continue
296+
297+
dataset_document = dataset_documents[document_id]
357298
if not dataset_document:
358299
continue
359300

360-
doc_form = doc_forms.get(document_id)
361-
if doc_form == IndexType.PARENT_CHILD_INDEX:
362-
# Handle parent-child documents using preloaded data
301+
if dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX:
302+
# Handle parent-child documents
363303
child_index_node_id = document.metadata.get("doc_id")
364-
if not child_index_node_id:
365-
continue
366304

367-
child_chunk = child_chunk_map.get(child_index_node_id)
305+
child_chunk = (
306+
db.session.query(ChildChunk).filter(ChildChunk.index_node_id == child_index_node_id).first()
307+
)
308+
368309
if not child_chunk:
369310
continue
370311

371-
segment = segment_map.get(child_chunk.segment_id)
312+
segment = (
313+
db.session.query(DocumentSegment)
314+
.filter(
315+
DocumentSegment.dataset_id == dataset_document.dataset_id,
316+
DocumentSegment.enabled == True,
317+
DocumentSegment.status == "completed",
318+
DocumentSegment.id == child_chunk.segment_id,
319+
)
320+
.options(
321+
load_only(
322+
DocumentSegment.id,
323+
DocumentSegment.content,
324+
DocumentSegment.answer,
325+
)
326+
)
327+
.first()
328+
)
329+
372330
if not segment:
373331
continue
374332

375333
if segment.id not in include_segment_ids:
376334
include_segment_ids.add(segment.id)
377-
map_detail = {"max_score": document.metadata.get("score", 0.0), "child_chunks": []}
335+
child_chunk_detail = {
336+
"id": child_chunk.id,
337+
"content": child_chunk.content,
338+
"position": child_chunk.position,
339+
"score": document.metadata.get("score", 0.0),
340+
}
341+
map_detail = {
342+
"max_score": document.metadata.get("score", 0.0),
343+
"child_chunks": [child_chunk_detail],
344+
}
378345
segment_child_map[segment.id] = map_detail
379-
records.append({"segment": segment})
380-
381-
# Append child chunk details
382-
child_chunk_detail = {
383-
"id": child_chunk.id,
384-
"content": child_chunk.content,
385-
"position": child_chunk.position,
386-
"score": document.metadata.get("score", 0.0),
387-
}
388-
segment_child_map[segment.id]["child_chunks"].append(child_chunk_detail)
389-
segment_child_map[segment.id]["max_score"] = max(
390-
segment_child_map[segment.id]["max_score"], document.metadata.get("score", 0.0)
391-
)
392-
346+
record = {
347+
"segment": segment,
348+
}
349+
records.append(record)
350+
else:
351+
child_chunk_detail = {
352+
"id": child_chunk.id,
353+
"content": child_chunk.content,
354+
"position": child_chunk.position,
355+
"score": document.metadata.get("score", 0.0),
356+
}
357+
segment_child_map[segment.id]["child_chunks"].append(child_chunk_detail)
358+
segment_child_map[segment.id]["max_score"] = max(
359+
segment_child_map[segment.id]["max_score"], document.metadata.get("score", 0.0)
360+
)
393361
else:
394362
# Handle normal documents
395363
index_node_id = document.metadata.get("doc_id")
396364
if not index_node_id:
397365
continue
398366

399-
segment = next(
400-
(
401-
s
402-
for s in segment_map.values()
403-
if s.index_node_id == index_node_id and s.dataset_id == dataset_document.dataset_id
404-
),
405-
None,
367+
segment = (
368+
db.session.query(DocumentSegment)
369+
.filter(
370+
DocumentSegment.dataset_id == dataset_document.dataset_id,
371+
DocumentSegment.enabled == True,
372+
DocumentSegment.status == "completed",
373+
DocumentSegment.index_node_id == index_node_id,
374+
)
375+
.first()
406376
)
407377

408378
if not segment:
409379
continue
410380

411-
if segment.id not in include_segment_ids:
412-
include_segment_ids.add(segment.id)
413-
records.append(
414-
{
415-
"segment": segment,
416-
"score": document.metadata.get("score", 0.0),
417-
}
418-
)
381+
include_segment_ids.add(segment.id)
382+
record = {
383+
"segment": segment,
384+
"score": document.metadata.get("score"), # type: ignore
385+
}
386+
records.append(record)
419387

420-
# Merge child chunks information
388+
# Add child chunks information to records
421389
for record in records:
422-
segment_id = record["segment"].id
423-
if segment_id in segment_child_map:
424-
record["child_chunks"] = segment_child_map[segment_id]["child_chunks"]
425-
record["score"] = segment_child_map[segment_id]["max_score"]
390+
if record["segment"].id in segment_child_map:
391+
record["child_chunks"] = segment_child_map[record["segment"].id].get("child_chunks") # type: ignore
392+
record["score"] = segment_child_map[record["segment"].id]["max_score"]
426393

427-
logging.debug(f"Formatting retrieval documents took {time.time() - start_time:.2f} seconds")
428394
return [RetrievalSegments(**record) for record in records]
429395
except Exception as e:
430-
# Only rollback if there were write operations
431396
db.session.rollback()
432397
raise e

0 commit comments

Comments
 (0)