Optimize presigning for replay.json #2516


Open · wants to merge 13 commits into main
133 changes: 117 additions & 16 deletions backend/btrixcloud/basecrawls.py
@@ -1,7 +1,18 @@
"""base crawl type"""

from datetime import datetime
from typing import Optional, List, Union, Dict, Any, Type, TYPE_CHECKING, cast, Tuple
from typing import (
Optional,
List,
Union,
Dict,
Any,
Type,
TYPE_CHECKING,
cast,
Tuple,
AsyncIterable,
)
from uuid import UUID
import os
import urllib.parse
@@ -76,6 +87,7 @@ def __init__(
background_job_ops: BackgroundJobOps,
):
self.crawls = mdb["crawls"]
self.presigned_urls = mdb["presigned_urls"]
self.crawl_configs = crawl_configs
self.user_manager = users
self.orgs = orgs
@@ -463,29 +475,118 @@ async def resolve_signed_urls(
) -> List[CrawlFileOut]:
"""Regenerate presigned URLs for files as necessary"""
if not files:
print("no files")
return []

out_files = []

for file_ in files:
presigned_url, expire_at = await self.storage_ops.get_presigned_url(
org, file_, force_update=force_update
cursor = self.presigned_urls.find(
{"_id": {"$in": [file.filename for file in files]}}
)

presigned = await cursor.to_list(10000)

files_dict = [file.dict() for file in files]

# need an async generator to call bulk_presigned_files
async def async_gen():
yield {"presigned": presigned, "files": files_dict, "_id": crawl_id}

out_files, _ = await self.bulk_presigned_files(async_gen(), org, force_update)

return out_files

async def get_presigned_files(
self, match: dict[str, Any], org: Organization
) -> tuple[list[CrawlFileOut], bool]:
"""return presigned crawl files queried as batch, merging presigns with files in one pass"""
cursor = self.crawls.aggregate(
[
{"$match": match},
{"$project": {"files": "$files", "version": 1}},
{
"$lookup": {
"from": "presigned_urls",
"localField": "files.filename",
"foreignField": "_id",
"as": "presigned",
}
},
]
)

return await self.bulk_presigned_files(cursor, org)
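
For orientation, each document this aggregation yields pairs one crawl's files with whatever cached presigns the $lookup found. A sketch of that shape, with illustrative values only (the field names follow the pipeline above; the concrete values are made up):

    # One aggregation result as consumed by bulk_presigned_files below
    # (illustrative values only)
    from datetime import datetime

    result = {
        "_id": "crawl-id-1",   # crawl id
        "version": 2,          # projected so callers can report pages_optimized
        "files": [
            {
                "filename": "data/crawl-id-1.wacz",
                "hash": "sha256:...",
                "size": 123456,
                "replicas": [],
                # ... plus the remaining CrawlFile fields (storage ref, etc.)
            },
        ],
        # joined from presigned_urls by filename; empty if nothing is cached yet
        "presigned": [
            {
                "_id": "data/crawl-id-1.wacz",
                "url": "https://bucket.s3.example.com/data/crawl-id-1.wacz?X-Amz-Signature=...",
                "signedAt": datetime(2025, 1, 1),
            },
        ],
    }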

async def bulk_presigned_files(
self,
cursor: AsyncIterable[dict[str, Any]],
org: Organization,
force_update=False,
) -> tuple[list[CrawlFileOut], bool]:
"""process presigned files in batches"""
resources = []
pages_optimized = False

sign_files = []

async for result in cursor:
pages_optimized = result.get("version") == 2

mapping = {}
# create mapping of filename -> file data
for file in result["files"]:
file["crawl_id"] = result["_id"]
mapping[file["filename"]] = file

if not force_update:
# add already presigned resources
for presigned in result["presigned"]:
file = mapping.get(presigned["_id"])
if file:
file["signedAt"] = presigned["signedAt"]
file["path"] = presigned["url"]
resources.append(
CrawlFileOut(
name=os.path.basename(file["filename"]),
path=presigned["url"],
hash=file["hash"],
size=file["size"],
crawlId=file["crawl_id"],
numReplicas=len(file.get("replicas") or []),
expireAt=date_to_str(
presigned["signedAt"]
+ self.storage_ops.signed_duration_delta
),
)
)

del mapping[presigned["_id"]]

sign_files.extend(list(mapping.values()))

if sign_files:
names = [file["filename"] for file in sign_files]

first_file = CrawlFile(**sign_files[0])
s3storage = self.storage_ops.get_org_storage_by_ref(org, first_file.storage)

signed_urls, expire_at = await self.storage_ops.get_presigned_urls_bulk(
org, s3storage, names
)

out_files.append(
CrawlFileOut(
name=os.path.basename(file_.filename),
path=presigned_url or "",
hash=file_.hash,
size=file_.size,
crawlId=crawl_id,
numReplicas=len(file_.replicas) if file_.replicas else 0,
expireAt=date_to_str(expire_at),
for url, file in zip(signed_urls, sign_files):
resources.append(
CrawlFileOut(
name=os.path.basename(file["filename"]),
path=url,
hash=file["hash"],
size=file["size"],
crawlId=file["crawl_id"],
numReplicas=len(file.get("replicas") or []),
expireAt=date_to_str(expire_at),
)
)
)

return out_files
return resources, pages_optimized
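
Nothing in this change shows how stale entries leave the presigned_urls collection; a plausible approach (an assumption, not part of this diff) is a TTL index keyed on signedAt, roughly:

    # Assumed housekeeping, not shown in this diff: let MongoDB expire cached
    # presigns once the signing window has passed.
    from motor.motor_asyncio import AsyncIOMotorDatabase

    async def init_presigned_url_index(
        mdb: AsyncIOMotorDatabase, presign_duration_seconds: int
    ) -> None:
        # TTL is measured from each document's signedAt value
        await mdb["presigned_urls"].create_index(
            "signedAt", expireAfterSeconds=presign_duration_seconds
        )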

async def add_to_collection(
self, crawl_ids: List[str], collection_id: UUID, org: Organization
60 changes: 22 additions & 38 deletions backend/btrixcloud/colls.py
@@ -28,7 +28,6 @@
UpdateColl,
AddRemoveCrawlList,
BaseCrawl,
CrawlOutWithResources,
CrawlFileOut,
Organization,
PaginatedCollOutResponse,
@@ -50,7 +49,12 @@
MIN_UPLOAD_PART_SIZE,
PublicCollOut,
)
from .utils import dt_now, slug_from_name, get_duplicate_key_error_field, get_origin
from .utils import (
dt_now,
slug_from_name,
get_duplicate_key_error_field,
get_origin,
)

if TYPE_CHECKING:
from .orgs import OrgOps
@@ -346,7 +350,7 @@ async def get_collection_out(
result["resources"],
crawl_ids,
pages_optimized,
) = await self.get_collection_crawl_resources(coll_id)
) = await self.get_collection_crawl_resources(coll_id, org)

initial_pages, _ = await self.page_ops.list_pages(
crawl_ids=crawl_ids,
@@ -400,7 +404,9 @@ async def get_public_collection_out(
if result.get("access") not in allowed_access:
raise HTTPException(status_code=404, detail="collection_not_found")

result["resources"], _, _ = await self.get_collection_crawl_resources(coll_id)
result["resources"], _, _ = await self.get_collection_crawl_resources(
coll_id, org
)

thumbnail = result.get("thumbnail")
if thumbnail:
@@ -554,32 +560,24 @@ async def list_collections(

return collections, total

# pylint: disable=too-many-locals
async def get_collection_crawl_resources(
self, coll_id: UUID
self, coll_id: Optional[UUID], org: Organization
) -> tuple[List[CrawlFileOut], List[str], bool]:
"""Return pre-signed resources for all collection crawl files."""
# Ensure collection exists
_ = await self.get_collection_raw(coll_id)
match: dict[str, Any]

resources = []
pages_optimized = True
if coll_id:
crawl_ids = await self.get_collection_crawl_ids(coll_id)
match = {"_id": {"$in": crawl_ids}}
else:
crawl_ids = []
match = {"oid": org.id}

crawls, _ = await self.crawl_ops.list_all_base_crawls(
collection_id=coll_id,
states=list(SUCCESSFUL_STATES),
page_size=10_000,
cls_type=CrawlOutWithResources,
resources, pages_optimized = await self.crawl_ops.get_presigned_files(
match, org
)

crawl_ids = []

for crawl in crawls:
crawl_ids.append(crawl.id)
if crawl.resources:
resources.extend(crawl.resources)
if crawl.version != 2:
pages_optimized = False

return resources, crawl_ids, pages_optimized
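
For reference, the two call shapes used elsewhere in this diff (a sketch; colls, coll_id and org are whatever the caller already has in scope):

    async def example(colls, coll_id, org):
        # Single collection: match is restricted to that collection's crawl ids
        resources, crawl_ids, pages_optimized = (
            await colls.get_collection_crawl_resources(coll_id, org)
        )

        # The "$all" endpoint passes None, so match falls back to every crawl in the org
        all_resources, _, _ = await colls.get_collection_crawl_resources(None, org)
        return resources, all_resources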

async def get_collection_names(self, uuids: List[UUID]):
@@ -1006,24 +1004,10 @@ async def list_collection_all(
@app.get(
"/orgs/{oid}/collections/$all",
tags=["collections"],
response_model=Dict[str, List[CrawlFileOut]],
)
async def get_collection_all(org: Organization = Depends(org_viewer_dep)):
results = {}
try:
all_collections, _ = await colls.list_collections(org, page_size=10_000)
for collection in all_collections:
(
results[collection.name],
_,
_,
) = await colls.get_collection_crawl_resources(collection.id)
except Exception as exc:
# pylint: disable=raise-missing-from
raise HTTPException(
status_code=400, detail="Error Listing All Crawled Files: " + str(exc)
)

results["resources"] = await colls.get_collection_crawl_resources(None, org)
return results

@app.get(
74 changes: 66 additions & 8 deletions backend/btrixcloud/storages.py
@@ -70,7 +70,7 @@


# ============================================================================
# pylint: disable=broad-except,raise-missing-from
# pylint: disable=broad-except,raise-missing-from,too-many-instance-attributes
class StorageOps:
"""All storage handling, download/upload operations"""

@@ -104,6 +104,8 @@ def __init__(self, org_ops, crawl_manager, mdb) -> None:
default_namespace = os.environ.get("DEFAULT_NAMESPACE", "default")
self.frontend_origin = f"{frontend_origin}.{default_namespace}"

self.presign_batch_size = int(os.environ.get("PRESIGN_BATCH_SIZE", 8))

with open(os.environ["STORAGES_JSON"], encoding="utf-8") as fh:
storage_list = json.loads(fh.read())

@@ -485,12 +487,9 @@ async def get_presigned_url(
s3storage,
for_presign=True,
) as (client, bucket, key):
orig_key = key
key += crawlfile.filename

presigned_url = await client.generate_presigned_url(
"get_object",
Params={"Bucket": bucket, "Key": key},
Params={"Bucket": bucket, "Key": key + crawlfile.filename},
ExpiresIn=PRESIGN_DURATION_SECONDS,
)

@@ -499,9 +498,7 @@
and s3storage.access_endpoint_url != s3storage.endpoint_url
):
parts = urlsplit(s3storage.endpoint_url)
host_endpoint_url = (
f"{parts.scheme}://{bucket}.{parts.netloc}/{orig_key}"
)
host_endpoint_url = f"{parts.scheme}://{bucket}.{parts.netloc}/{key}"
presigned_url = presigned_url.replace(
host_endpoint_url, s3storage.access_endpoint_url
)
@@ -521,6 +518,67 @@

return presigned_url, now + self.signed_duration_delta
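
Concretely, the rewrite in get_presigned_url swaps the internal bucket endpoint for the public access endpoint. A worked example with made-up values, assuming the client emits virtual-hosted-style URLs and that get_s3_client splits the bucket and key as shown:

    from urllib.parse import urlsplit

    endpoint_url = "https://s3.internal.example.com/my-bucket/prefix/"
    access_endpoint_url = "https://data.example.org/prefix/"
    bucket, key = "my-bucket", "prefix/"   # assumed split performed by get_s3_client

    parts = urlsplit(endpoint_url)
    host_endpoint_url = f"{parts.scheme}://{bucket}.{parts.netloc}/{key}"
    # -> "https://my-bucket.s3.internal.example.com/prefix/"

    presigned_url = host_endpoint_url + "crawl.wacz?X-Amz-Signature=abc"
    public_url = presigned_url.replace(host_endpoint_url, access_endpoint_url)
    # -> "https://data.example.org/prefix/crawl.wacz?X-Amz-Signature=abc"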

async def get_presigned_urls_bulk(
self, org: Organization, s3storage: S3Storage, filenames: list[str]
) -> tuple[list[str], datetime]:
"""generate pre-signed url for crawl file"""

urls = []

futures = []
num_batch = self.presign_batch_size

now = dt_now()

async with self.get_s3_client(
s3storage,
for_presign=True,
) as (client, bucket, key):

if (
s3storage.access_endpoint_url
and s3storage.access_endpoint_url != s3storage.endpoint_url
):
parts = urlsplit(s3storage.endpoint_url)
host_endpoint_url = f"{parts.scheme}://{bucket}.{parts.netloc}/{key}"
else:
host_endpoint_url = None

for filename in filenames:
futures.append(
client.generate_presigned_url(
"get_object",
Params={"Bucket": bucket, "Key": key + filename},
ExpiresIn=PRESIGN_DURATION_SECONDS,
)
)

for i in range(0, len(futures), num_batch):
batch = futures[i : i + num_batch]
results = await asyncio.gather(*batch)

presigned_obj = []

for presigned_url, filename in zip(
results, filenames[i : i + num_batch]
):
if host_endpoint_url:
presigned_url = presigned_url.replace(
host_endpoint_url, s3storage.access_endpoint_url
)

urls.append(presigned_url)

presigned_obj.append(
PresignedUrl(
id=filename, url=presigned_url, signedAt=now, oid=org.id
).to_dict()
)

await self.presigned_urls.insert_many(presigned_obj, ordered=False)

return urls, now + self.signed_duration_delta
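
The loop above is a standard bounded fan-out: build every presign coroutine up front, then await them in groups of presign_batch_size so a large file list is not signed all at once. In isolation the pattern looks roughly like this (a sketch, not code from this change):

    import asyncio
    from typing import Any, Coroutine

    async def gather_in_batches(
        coros: list[Coroutine[Any, Any, Any]], batch_size: int
    ) -> list[Any]:
        # await at most batch_size coroutines at a time
        results: list[Any] = []
        for i in range(0, len(coros), batch_size):
            results.extend(await asyncio.gather(*coros[i : i + batch_size]))
        return results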

async def delete_file_object(self, org: Organization, crawlfile: BaseFile) -> bool:
"""delete crawl file from storage."""
return await self._delete_file(org, crawlfile.filename, crawlfile.storage)
2 changes: 2 additions & 0 deletions chart/templates/configmap.yaml
@@ -90,6 +90,8 @@ data:

REPLICA_DELETION_DELAY_DAYS: "{{ .Values.replica_deletion_delay_days | default 0 }}"

PRESIGN_BATCH_SIZE: "{{ .Values.presign_batch_size | default 8 }}"


---
apiVersion: v1