
Commit 69387f1

feat: batch processing
1 parent a9ff1f7 commit 69387f1

2 files changed: +128 -52 lines changed

Diff for: backend/app/routers/feed.py

+31-15
@@ -1,4 +1,6 @@
+import asyncio
 from pydantic.networks import HttpUrl
+from datetime import datetime, timedelta, timezone
 import re
 from fastapi import Request
 import json
@@ -8,8 +10,10 @@
 from app.models.feed import Feed, generate_feed, parse_feed, UpstreamError
 from app.models.user import User
 from app.recommend import filter_articles
+from app.tasks import fetch_feed_batch
 from .common import get_engine
 from fastapi import HTTPException
+from fastapi import BackgroundTasks
 
 # from fastapi_cache.coder import PickleCoder
 # from fastapi_cache.decorator import cache
@@ -21,35 +25,33 @@
 )
 
 
+FEED_REFRESH_INTERVAL = timedelta(days=1)  # Adjust as needed
+
+
 @router.get("/{user_id}/{feed_url:path}")
-# @cache(expire=300, coder=PickleCoder)
 async def get_feed(
-    request: Request, user_id: str, feed_url: HttpUrl, engine=Depends(get_engine)
+    request: Request,
+    user_id: str,
+    feed_url: HttpUrl,
+    background_tasks: BackgroundTasks,
+    engine=Depends(get_engine),
 ) -> str:
-    """Get filtered feed."""
     if feed_url.host and re.match(
         r"^(25[0-5]|(2[0-4]|1\d|[1-9]|)\d)\.?\b", feed_url.host
     ):
         raise HTTPException(status_code=422, detail="Invalid URL")
-    if request.query_params:
-        feed_url = f"{feed_url}?"  # type: ignore
-        for key, value in request.query_params.items():
-            feed_url = f"{feed_url}&{key}={value}"  # type: ignore
+
     with Session(engine, autoflush=False) as session:
+        # User handling (same as before)
         try:
             user: User = session.exec(select(User).where(User.id == user_id)).one()
         except NoResultFound:
             logger.info(f"User {user_id} not found in database, creating new user")
             user = User(id=user_id)
             session.add(user)
-            try:
-                session.commit()
-            except Exception as e:
-                # might happen if the user was created before by another thread
-                logger.warning(f"Failed to add user {user_id} to database: {e}")
-                session.rollback()
-                user = session.exec(select(User).where(User.id == user_id)).one()
+            session.commit()
 
+        # Feed handling
         try:
             feed: Feed = session.exec(
                 select(Feed).where(Feed.url == str(feed_url))
@@ -70,9 +72,23 @@ async def get_feed(
             logger.warning(f"Failed to add feed {feed_url} to database: {e}")
             session.rollback()
             feed = session.exec(select(Feed).where(Feed.url == str(feed_url))).one()
+            session.add(feed)
+            session.commit()
+
         if feed not in user.feeds:
             user.feeds.append(feed)
-        session.commit()
+            session.commit()
+
+        # Check if feed needs refreshing
+        now = datetime.now(timezone.utc)
+        if (
+            feed.updated_at is None
+            or (now - feed.updated_at.replace(tzinfo=timezone.utc))
+            > FEED_REFRESH_INTERVAL
+        ):
+            logger.info(f"Feed {feed_url} needs refreshing")
+            await fetch_feed_batch([feed.id])
+            session.refresh(feed)
 
         articles = feed.articles[-30:]
 
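The refresh gate above treats `feed.updated_at` as a naive UTC timestamp and compares it against `FEED_REFRESH_INTERVAL` with an aware UTC clock. A minimal standalone sketch of that check, under the same naive-UTC assumption (the `needs_refresh` helper is illustrative, not part of this commit):

from datetime import datetime, timedelta, timezone
from typing import Optional

FEED_REFRESH_INTERVAL = timedelta(days=1)


def needs_refresh(updated_at: Optional[datetime]) -> bool:
    # A feed that has never been fetched, or whose last fetch is older than
    # the interval, is considered stale.
    if updated_at is None:
        return True
    now = datetime.now(timezone.utc)
    # updated_at is stored naive but in UTC, so attach tzinfo before subtracting.
    return (now - updated_at.replace(tzinfo=timezone.utc)) > FEED_REFRESH_INTERVAL


assert needs_refresh(None)
assert needs_refresh(datetime.utcnow() - timedelta(days=2))
assert not needs_refresh(datetime.utcnow())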
Diff for: backend/app/tasks.py

+97-37
@@ -4,8 +4,8 @@
 from sqlmodel import create_engine, Session, select
 from datetime import datetime, timezone, timedelta
 from loguru import logger
-from redis import Redis
-from rq import Queue
+from redis import Redis  # type: ignore
+from rq import Queue, Retry
 
 from app.models.article import Article
 from app.models.feed import Feed, parse_feed, generate_feed
@@ -25,32 +25,6 @@
 gpu_queue = Queue("gpu", connection=redis_conn)
 
 
-def fetch_feed(feed_id):
-    with Session(ENGINE) as session:
-        feed = session.get(Feed, feed_id)
-        if not feed:
-            logger.error(f"Feed {feed_id} not found")
-            return
-
-        try:
-            parsed_feed = asyncio.run(parse_feed(feed.url))
-            for article in parsed_feed.articles:
-                existing_article = session.exec(
-                    select(Article).where(Article.url == article.url)
-                ).first()
-                if not existing_article:
-                    article.feed = feed
-                    session.add(article)
-                    session.commit()
-                    gpu_queue.enqueue(compute_article_embedding, article.id)
-
-            logger.info(
-                f"Fetched {len(parsed_feed.articles)} articles for feed {feed_id}"
-            )
-        except Exception as e:
-            logger.error(f"Error fetching feed {feed_id}: {e}")
-
-
 def compute_article_embedding(article_id):
     with Session(ENGINE) as session:
         article = session.get(Article, article_id)
@@ -87,7 +61,7 @@ def remove_old_embeddings():
         old_articles = session.exec(
             select(Article)
             .where(Article.updated < one_month_ago)
-            .where(Article.embedding is not None)
+            .where(Article.embedding != None)  # noqa: E711
         ).all()
 
         for article in old_articles:
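The switch from `is not None` to `!= None` above is a behavior fix, not a style tweak: inside a SQLAlchemy/SQLModel column expression, Python evaluates `is not None` itself (a column attribute is never None, so the test is always True) and the clause filters nothing, whereas `!= None` compiles to the intended `embedding IS NOT NULL`. A small sketch of the two spellings (`is_not()` assumes SQLAlchemy 1.4+):

from sqlmodel import select

from app.models.article import Article

# Old form: a plain Python boolean, so the WHERE clause was effectively a no-op.
# select(Article).where(Article.embedding is not None)

# New form: compiles to "WHERE article.embedding IS NOT NULL".
stmt = select(Article).where(Article.embedding != None)  # noqa: E711

# Equivalent without the flake8 suppression:
stmt = select(Article).where(Article.embedding.is_not(None))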
@@ -97,12 +71,86 @@
         logger.info(f"Removed embeddings from {len(old_articles)} old articles")
 
 
+BATCH_SIZE = int(os.getenv("FEED_FETCH_BATCH_SIZE", "50"))
+
+
+async def fetch_feed_batch(feed_ids):
+    async def fetch_single_feed(feed):
+        try:
+            return await parse_feed(feed.url)
+        except Exception as e:
+            logger.error(f"Error fetching feed {feed.id}: {e}")
+            return None
+
+    with Session(ENGINE) as session:
+        feeds = session.exec(select(Feed).where(Feed.id.in_(feed_ids))).all()
+        tasks = [asyncio.create_task(fetch_single_feed(feed)) for feed in feeds]
+        results = await asyncio.gather(*tasks)
+
+        new_articles = []
+        for feed, parsed_feed in zip(feeds, results):
+            if parsed_feed is None:
+                continue
+
+            for article in parsed_feed.articles:
+                existing_article = session.exec(
+                    select(Article).where(Article.url == article.url)
+                ).first()
+                if not existing_article:
+                    article.feed = feed
+                    session.add(article)
+                    new_articles.append(article)
+
+            feed.updated_at = datetime.now(timezone.utc)
+
+        session.commit()
+
+        if new_articles:
+            new_article_ids = [article.id for article in new_articles]
+            enqueue_gpu_task(compute_embeddings_batch, new_article_ids)
+
+        logger.info(
+            f"Fetched {len(feeds)} feeds, added {len(new_articles)} new articles"
+        )
+
+
+def compute_embeddings_batch(article_ids):
+    with Session(ENGINE) as session:
+        articles = session.exec(
+            select(Article).where(Article.id.in_(article_ids))
+        ).all()
+        articles_to_embed = [
+            article for article in articles if article.embedding is None
+        ]
+
+        if not articles_to_embed:
+            return
+
+        try:
+            compute_embeddings(articles_to_embed)
+            session.commit()
+            logger.info(f"Computed embeddings for {len(articles_to_embed)} articles")
+        except Exception as e:
+            logger.error(f"Error computing embeddings for articles: {e}")
+
+
 def fetch_all_feeds():
     with Session(ENGINE) as session:
-        feeds = session.exec(select(Feed)).all()
-        for feed in feeds:
-            low_queue.enqueue(fetch_feed, feed.id)
-        logger.info(f"Enqueued fetch tasks for {len(feeds)} feeds")
+        one_month_ago = datetime.now(timezone.utc) - timedelta(days=30)
+        active_feeds = session.exec(
+            select(Feed)
+            .join(User.feeds)
+            .where(User.last_request > one_month_ago)
+            .distinct()
+        ).all()
+
+        for i in range(0, len(active_feeds), BATCH_SIZE):
+            batch = active_feeds[i : i + BATCH_SIZE]
+            asyncio.run(fetch_feed_batch([feed.id for feed in batch]))
+
+        logger.info(
+            f"Processed {len(active_feeds)} active feeds in batches of {BATCH_SIZE}"
+        )
 
 
 def log_user_action(user_id: str, article_id: int, link_url: str):
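`fetch_feed_batch` is a coroutine, so synchronous callers (RQ workers, maintenance scripts, the `fetch_all_feeds` entry point) drive it with `asyncio.run`, while async callers such as the endpoint above await it directly. A hedged usage sketch, with illustrative feed ids and an illustrative choice of queue:

import asyncio

from app.tasks import enqueue_low_priority, fetch_all_feeds, fetch_feed_batch

# Ad-hoc refresh of a few feeds from a synchronous context (script or worker):
asyncio.run(fetch_feed_batch([1, 2, 3]))

# One way to trigger the periodic refresh of all recently requested feeds:
# push fetch_all_feeds onto the low-priority queue and let an RQ worker run it.
enqueue_low_priority(fetch_all_feeds)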
@@ -123,7 +171,7 @@ def log_user_action(user_id: str, article_id: int, link_url: str):
         if article not in user.articles:
             user.articles.append(article)
             session.add(user)
-            medium_queue.enqueue(recompute_user_clusters, user.id)
+            enqueue_medium_priority(recompute_user_clusters, user.id)
 
         session.commit()
         logger.info(f"Logged action for user {user_id}, article {article_id}")
@@ -151,12 +199,24 @@ def generate_filtered_feed(feed_id: int, user_id: str):
 
 # Helper functions to enqueue tasks
 def enqueue_low_priority(func, *args, **kwargs):
-    return low_queue.enqueue(func, *args, **kwargs)
+    return low_queue.enqueue(
+        func, args=args, kwargs=kwargs, retry=Retry(max=3), timeout=300
+    )
 
 
 def enqueue_medium_priority(func, *args, **kwargs):
-    return medium_queue.enqueue(func, *args, **kwargs)
+    return medium_queue.enqueue(
+        func, args=args, kwargs=kwargs, retry=Retry(max=3), timeout=300
+    )
 
 
 def enqueue_high_priority(func, *args, **kwargs):
-    return high_queue.enqueue(func, *args, **kwargs)
+    return high_queue.enqueue(
+        func, args=args, kwargs=kwargs, retry=Retry(max=2), timeout=10
+    )
+
+
+def enqueue_gpu_task(func, *args, **kwargs):
+    return gpu_queue.enqueue(
+        func, args=args, kwargs=kwargs, retry=Retry(max=3), timeout=600
+    )
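The helper functions above centralize the per-queue retry and timeout policy. A usage sketch with illustrative article ids (workers are started separately, e.g. `rq worker gpu low medium high`):

from app.tasks import compute_embeddings_batch, enqueue_gpu_task

# Embed a batch of newly fetched articles on the GPU queue; the helper requests
# Retry(max=3) and a 600-second timeout for the job.
job = enqueue_gpu_task(compute_embeddings_batch, [101, 102, 103])
print(job.id, job.get_status())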
