@@ -56,8 +56,8 @@ class QuerysetAggregateWrapper:
56
56
"""
57
57
Wrapper around queryset to indicate that we want to fetch the result of .aggregate()
58
58
This is useful for executing aggregate queries in a single database query along with other querysets.
59
-
60
- Since aggregates don't support lazy evaluation, we need to store the queryset and
59
+
60
+ Since aggregates don't support lazy evaluation, we need to store the queryset and
61
61
the aggregate expressions separately.
62
62
"""
63
63
@@ -218,15 +218,17 @@ def _get_sanitized_sql_param(self, param) -> str:
218
218
return str (param )
219
219
if isinstance (param , bool ):
220
220
return "TRUE" if param else "FALSE"
221
-
221
+
222
222
param_str = str (param )
223
-
223
+
224
224
try :
225
225
from psycopg import sql
226
+
226
227
return sql .quote (param_str )
227
228
except ImportError :
228
229
try :
229
230
from psycopg2 .extensions import QuotedString
231
+
230
232
return QuotedString (param_str ).getquoted ().decode ("utf-8" )
231
233
except ImportError :
232
234
raise ImportError ("psycopg or psycopg2 not installed" )
@@ -268,10 +270,9 @@ def _get_django_sql_for_queryset(self, queryset: QuerysetWrapperType) -> str:
268
270
django_sql = sql % quoted_params
269
271
270
272
if isinstance (queryset , QuerysetAggregateWrapper ):
271
-
272
273
compiler = self ._get_compiler_from_queryset (queryset .queryset )
273
274
sql , params = compiler .as_sql ()
274
-
275
+
275
276
if isinstance (params , dict ):
276
277
quoted_params = {}
277
278
for key , value in params .items ():
@@ -282,15 +283,15 @@ def _get_django_sql_for_queryset(self, queryset: QuerysetWrapperType) -> str:
282
283
for value in params :
283
284
quoted_params .append (self ._get_sanitized_sql_param (value ))
284
285
base_sql = sql % tuple (quoted_params )
285
-
286
+
286
287
aggregate_sql_parts = []
287
288
for key , value in queryset .aggregate_expressions .items ():
288
289
if isinstance (value , Sum ):
289
290
field = value .source_expressions [0 ].name
290
291
aggregate_sql_parts .append (f"'{ key } ', SUM(subquery.{ field } )" )
291
292
elif isinstance (value , Count ):
292
293
field = value .source_expressions [0 ].name
293
- if field == '*' :
294
+ if field == "*" :
294
295
aggregate_sql_parts .append (f"'{ key } ', COUNT(*)" )
295
296
else :
296
297
aggregate_sql_parts .append (f"'{ key } ', COUNT(subquery.{ field } )" )
@@ -303,7 +304,7 @@ def _get_django_sql_for_queryset(self, queryset: QuerysetWrapperType) -> str:
303
304
elif isinstance (value , Min ):
304
305
field = value .source_expressions [0 ].name
305
306
aggregate_sql_parts .append (f"'{ key } ', MIN(subquery.{ field } )" )
306
-
307
+
307
308
if aggregate_sql_parts :
308
309
return f"(SELECT array_to_json(array[row(json_build_object({ ', ' .join (aggregate_sql_parts )} ))]) FROM ({ base_sql } ) AS subquery)"
309
310
else :
@@ -436,10 +437,12 @@ def _convert_raw_results_to_final_queryset_results(
436
437
queryset_results = queryset_raw_results [0 ]["__count" ]
437
438
elif isinstance (queryset , QuerysetAggregateWrapper ):
438
439
if queryset_raw_results and len (queryset_raw_results ) > 0 :
439
- nested_result = queryset_raw_results [0 ].get ('f1' , {})
440
+ nested_result = queryset_raw_results [0 ].get ("f1" , {})
440
441
queryset_results = nested_result
441
442
else :
442
- queryset_results = {key : None for key in queryset .aggregate_expressions .keys ()}
443
+ queryset_results = {
444
+ key : None for key in queryset .aggregate_expressions .keys ()
445
+ }
443
446
for key , value in queryset .aggregate_expressions .items ():
444
447
if isinstance (value , Count ):
445
448
queryset_results [key ] = 0
@@ -519,11 +522,9 @@ def execute(self) -> list[list[Any]]:
519
522
with connections ["default" ].cursor () as cursor :
520
523
cursor .execute (raw_sql , params = {})
521
524
raw_sql_result_dict : dict = cursor .fetchone ()[0 ]
522
- print (f"Raw SQL result: { raw_sql_result_dict } " )
523
525
else :
524
526
# all querysets are always empty (EmptyResultSet)
525
527
raw_sql_result_dict = {}
526
- print ("All querysets are empty" )
527
528
528
529
final_result = []
529
530
index = 0
@@ -532,14 +533,14 @@ def execute(self) -> list[list[Any]]:
532
533
# empty sql case
533
534
final_result .append (result )
534
535
continue
535
-
536
+
536
537
raw_results = raw_sql_result_dict .get (str (index ), [])
537
-
538
+
538
539
converted_results = self ._convert_raw_results_to_final_queryset_results (
539
540
queryset = queryset ,
540
541
queryset_raw_results = raw_results ,
541
542
)
542
-
543
+
543
544
final_result .append (converted_results )
544
545
index += 1
545
546
0 commit comments