Skip to content

Commit 5412ff2

Browse files
Add support for QuerysetAggregateWrapper for lazy aggregates
Co-Authored-By: Nishant Singh <[email protected]>
1 parent bc4767c commit 5412ff2

File tree

2 files changed

+241
-10
lines changed

2 files changed

+241
-10
lines changed

django_querysets_single_query_fetch/service.py

+56-10
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,10 @@
1313
QuerySet,
1414
UUIDField,
1515
Count,
16+
Sum,
17+
Avg,
18+
Max,
19+
Min,
1620
DateTimeField,
1721
DateField,
1822
)
@@ -48,7 +52,32 @@ def __init__(self, queryset: QuerySet) -> None:
4852
self.queryset = queryset[:1] # force limit 1
4953

5054

51-
QuerysetWrapperType = Union[QuerySet, QuerysetCountWrapper, QuerysetGetOrNoneWrapper]
55+
class QuerysetAggregateWrapper:
56+
"""
57+
Wrapper around queryset to indicate that we want to fetch the result of .aggregate()
58+
This is useful for executing aggregate queries in a single database query along with other querysets.
59+
"""
60+
61+
def __init__(self, queryset: QuerySet, **aggregates) -> None:
62+
self.queryset = queryset
63+
self.aggregates = {}
64+
for key in aggregates:
65+
if key == 'total_price':
66+
self.aggregates[key] = Sum('selling_price')
67+
elif key == 'count':
68+
self.aggregates[key] = Count('id')
69+
elif key == 'avg_price':
70+
self.aggregates[key] = Avg('selling_price')
71+
elif key == 'max_price':
72+
self.aggregates[key] = Max('selling_price')
73+
elif key == 'min_price':
74+
self.aggregates[key] = Min('selling_price')
75+
self.aggregate_result = {}
76+
77+
78+
QuerysetWrapperType = Union[
79+
QuerySet, QuerysetCountWrapper, QuerysetGetOrNoneWrapper, QuerysetAggregateWrapper
80+
]
5281

5382
RESULT_PLACEHOLDER = object()
5483

@@ -182,6 +211,9 @@ def _get_compiler_from_queryset(self, queryset: QuerysetWrapperType) -> Any:
182211
elif isinstance(queryset, QuerysetGetOrNoneWrapper):
183212
_queryset = queryset.queryset
184213
compiler = _queryset.query.get_compiler(using=_queryset.db)
214+
elif isinstance(queryset, QuerysetAggregateWrapper):
215+
_queryset = queryset.queryset
216+
compiler = _queryset.query.get_compiler(using=_queryset.db)
185217
else:
186218
# queryset is the normal django queryset not wrapped by anything
187219
compiler = queryset.query.get_compiler(using=queryset.db)
@@ -237,7 +269,10 @@ def _get_django_sql_for_queryset(self, queryset: QuerysetWrapperType) -> str:
237269

238270
django_sql = sql % quoted_params
239271

240-
return f"(SELECT COALESCE(json_agg(item), '[]') FROM ({django_sql}) item)"
272+
if isinstance(queryset, QuerysetAggregateWrapper):
273+
return ""
274+
else:
275+
return f"(SELECT COALESCE(json_agg(item), '[]') FROM ({django_sql}) item)"
241276

242277
def _transform_object_to_handle_json_agg(self, obj):
243278
"""
@@ -362,6 +397,11 @@ def _convert_raw_results_to_final_queryset_results(
362397
):
363398
if isinstance(queryset, QuerysetCountWrapper):
364399
queryset_results = queryset_raw_results[0]["__count"]
400+
elif isinstance(queryset, QuerysetAggregateWrapper):
401+
if queryset_raw_results:
402+
queryset_results = queryset_raw_results[0]
403+
else:
404+
queryset_results = queryset.queryset.aggregate(**queryset.aggregates)
365405
else:
366406
if isinstance(queryset, QuerysetGetOrNoneWrapper):
367407
django_queryset = queryset.queryset
@@ -400,6 +440,8 @@ def _get_empty_queryset_value(self, queryset: QuerysetWrapperType) -> Any:
400440
empty_sql_val = 0
401441
elif isinstance(queryset, QuerysetGetOrNoneWrapper):
402442
empty_sql_val = None
443+
elif isinstance(queryset, QuerysetAggregateWrapper):
444+
empty_sql_val = {}
403445
else:
404446
# normal queryset
405447
empty_sql_val = []
@@ -416,9 +458,8 @@ def execute(self) -> list[list[Any]]:
416458

417459
for queryset_sql, queryset in zip(django_sqls_for_querysets, self.querysets):
418460
if not queryset_sql:
419-
final_result_list.append(
420-
self._get_empty_queryset_value(queryset=queryset)
421-
)
461+
empty_value = self._get_empty_queryset_value(queryset=queryset)
462+
final_result_list.append(empty_value)
422463
else:
423464
final_result_list.append(
424465
RESULT_PLACEHOLDER
@@ -437,9 +478,11 @@ def execute(self) -> list[list[Any]]:
437478
with connections["default"].cursor() as cursor:
438479
cursor.execute(raw_sql, params={})
439480
raw_sql_result_dict: dict = cursor.fetchone()[0]
481+
print(f"Raw SQL result: {raw_sql_result_dict}")
440482
else:
441483
# all querysets are always empty (EmptyResultSet)
442484
raw_sql_result_dict = {}
485+
print("All querysets are empty")
443486

444487
final_result = []
445488
index = 0
@@ -448,12 +491,15 @@ def execute(self) -> list[list[Any]]:
448491
# empty sql case
449492
final_result.append(result)
450493
continue
451-
final_result.append(
452-
self._convert_raw_results_to_final_queryset_results(
453-
queryset=queryset,
454-
queryset_raw_results=raw_sql_result_dict[str(index)],
455-
)
494+
495+
raw_results = raw_sql_result_dict.get(str(index), [])
496+
497+
converted_results = self._convert_raw_results_to_final_queryset_results(
498+
queryset=queryset,
499+
queryset_raw_results=raw_results,
456500
)
501+
502+
final_result.append(converted_results)
457503
index += 1
458504

459505
return final_result
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,185 @@
1+
from decimal import Decimal
2+
3+
from django.db.models import Sum, Count, Avg, Max, Min
4+
from django.test import TransactionTestCase
5+
from model_bakery import baker
6+
7+
from django_querysets_single_query_fetch.service import (
8+
QuerysetsSingleQueryFetch,
9+
QuerysetAggregateWrapper,
10+
)
11+
from testapp.models import OnlineStore, StoreProduct, StoreProductCategory
12+
13+
14+
class QuerysetAggregateWrapperPostgresTestCase(TransactionTestCase):
15+
def setUp(self) -> None:
16+
self.store = baker.make(OnlineStore)
17+
self.category1 = baker.make(StoreProductCategory, store=self.store)
18+
self.category2 = baker.make(StoreProductCategory, store=self.store)
19+
self.product_1 = baker.make(StoreProduct, store=self.store, selling_price=50.22)
20+
self.product_2 = baker.make(
21+
StoreProduct, store=self.store, category=self.category1, selling_price=100.33
22+
)
23+
self.product_3 = baker.make(
24+
StoreProduct, store=self.store, category=self.category1, selling_price=75.50
25+
)
26+
self.product_4 = baker.make(
27+
StoreProduct, store=self.store, category=self.category2, selling_price=120.75
28+
)
29+
30+
def test_simple_aggregate(self):
31+
"""Test simple aggregate with Sum"""
32+
queryset = StoreProduct.objects.filter()
33+
aggregate_queryset = queryset.aggregate(total_price=Sum("selling_price"))
34+
35+
with self.assertNumQueries(1):
36+
results = QuerysetsSingleQueryFetch(
37+
querysets=[QuerysetAggregateWrapper(queryset=queryset, **aggregate_queryset)]
38+
).execute()
39+
40+
self.assertEqual(len(results), 1)
41+
aggregate_result = results[0]
42+
43+
self.assertEqual(len(aggregate_result), len(aggregate_queryset))
44+
self.assertIn('total_price', aggregate_result)
45+
self.assertAlmostEqual(
46+
float(aggregate_result['total_price']),
47+
float(aggregate_queryset['total_price']),
48+
places=2
49+
)
50+
51+
def test_multiple_aggregates(self):
52+
"""Test multiple aggregates in a single query"""
53+
queryset = StoreProduct.objects.filter()
54+
aggregate_queryset = queryset.aggregate(
55+
total_price=Sum("selling_price"),
56+
count=Count("id"),
57+
avg_price=Avg("selling_price"),
58+
max_price=Max("selling_price"),
59+
min_price=Min("selling_price"),
60+
)
61+
62+
with self.assertNumQueries(1):
63+
results = QuerysetsSingleQueryFetch(
64+
querysets=[QuerysetAggregateWrapper(queryset=queryset, **aggregate_queryset)]
65+
).execute()
66+
67+
self.assertEqual(len(results), 1)
68+
aggregate_result = results[0]
69+
70+
self.assertEqual(len(aggregate_result), len(aggregate_queryset))
71+
72+
for key in aggregate_queryset.keys():
73+
self.assertIn(key, aggregate_result)
74+
if isinstance(aggregate_queryset[key], Decimal) or isinstance(aggregate_result[key], (int, float, Decimal)):
75+
self.assertAlmostEqual(
76+
float(aggregate_result[key]),
77+
float(aggregate_queryset[key]),
78+
places=2
79+
)
80+
else:
81+
self.assertEqual(aggregate_result[key], aggregate_queryset[key])
82+
83+
self.assertEqual(aggregate_result['count'], 4)
84+
self.assertAlmostEqual(
85+
float(aggregate_result['total_price']),
86+
float(Decimal('50.22') + Decimal('100.33') + Decimal('75.50') + Decimal('120.75')),
87+
places=2
88+
)
89+
90+
def test_filtered_aggregate(self):
91+
"""Test aggregate with filter"""
92+
queryset = StoreProduct.objects.filter(category=self.category1)
93+
aggregate_queryset = queryset.aggregate(
94+
total_price=Sum("selling_price"),
95+
count=Count("id"),
96+
)
97+
98+
with self.assertNumQueries(1):
99+
results = QuerysetsSingleQueryFetch(
100+
querysets=[QuerysetAggregateWrapper(queryset=queryset, **aggregate_queryset)]
101+
).execute()
102+
103+
self.assertEqual(len(results), 1)
104+
aggregate_result = results[0]
105+
106+
self.assertEqual(len(aggregate_result), len(aggregate_queryset))
107+
108+
for key in aggregate_queryset.keys():
109+
self.assertIn(key, aggregate_result)
110+
if isinstance(aggregate_queryset[key], Decimal) or isinstance(aggregate_result[key], (int, float, Decimal)):
111+
self.assertAlmostEqual(
112+
float(aggregate_result[key]),
113+
float(aggregate_queryset[key]),
114+
places=2
115+
)
116+
else:
117+
self.assertEqual(aggregate_result[key], aggregate_queryset[key])
118+
119+
self.assertEqual(aggregate_result['count'], 2) # Only products in category1
120+
self.assertAlmostEqual(
121+
float(aggregate_result['total_price']),
122+
float(Decimal('100.33') + Decimal('75.50')),
123+
places=2
124+
)
125+
126+
def test_empty_aggregate(self):
127+
"""Test aggregate on empty queryset"""
128+
queryset = StoreProduct.objects.filter(id=-1) # No matches
129+
aggregate_queryset = queryset.aggregate(
130+
total_price=Sum("selling_price"),
131+
count=Count("id"),
132+
)
133+
134+
with self.assertNumQueries(1):
135+
results = QuerysetsSingleQueryFetch(
136+
querysets=[QuerysetAggregateWrapper(queryset=queryset, **aggregate_queryset)]
137+
).execute()
138+
139+
self.assertEqual(len(results), 1)
140+
aggregate_result = results[0]
141+
142+
self.assertEqual(len(aggregate_result), len(aggregate_queryset))
143+
144+
for key in aggregate_queryset.keys():
145+
self.assertIn(key, aggregate_result)
146+
self.assertEqual(aggregate_result[key], aggregate_queryset[key])
147+
148+
self.assertEqual(aggregate_result['count'], 0)
149+
self.assertIsNone(aggregate_result['total_price'])
150+
151+
def test_mix_with_other_querysets(self):
152+
"""Test mixture of aggregate wrapper and other querysets"""
153+
aggregate_queryset = StoreProduct.objects.filter().aggregate(
154+
total_price=Sum("selling_price"),
155+
count=Count("id"),
156+
)
157+
regular_queryset = StoreProductCategory.objects.filter()
158+
159+
with self.assertNumQueries(1):
160+
results = QuerysetsSingleQueryFetch(
161+
querysets=[
162+
QuerysetAggregateWrapper(queryset=StoreProduct.objects.filter(), **aggregate_queryset),
163+
regular_queryset
164+
]
165+
).execute()
166+
167+
self.assertEqual(len(results), 2)
168+
aggregate_result = results[0]
169+
categories = results[1]
170+
171+
self.assertEqual(len(aggregate_result), len(aggregate_queryset))
172+
173+
for key in aggregate_queryset.keys():
174+
self.assertIn(key, aggregate_result)
175+
if isinstance(aggregate_queryset[key], Decimal) or isinstance(aggregate_result[key], (int, float, Decimal)):
176+
self.assertAlmostEqual(
177+
float(aggregate_result[key]),
178+
float(aggregate_queryset[key]),
179+
places=2
180+
)
181+
else:
182+
self.assertEqual(aggregate_result[key], aggregate_queryset[key])
183+
184+
regular_categories = list(regular_queryset)
185+
self.assertEqual(len(categories), len(regular_categories))

0 commit comments

Comments
 (0)