diff --git a/django_cte/cte.py b/django_cte/cte.py index 96d03e4..1141ace 100644 --- a/django_cte/cte.py +++ b/django_cte/cte.py @@ -153,6 +153,17 @@ def as_manager(cls): as_manager.queryset_only = True as_manager = classmethod(as_manager) + def _combinator_query(self, combinator, *other_qs, all=False): + clone = super(CTEQuerySet, self)._combinator_query(combinator, *other_qs, all=False) + if clone.query.combinator: + # Move CTE onto parent query, we can modify as they are cloned by super + sub_cte = [] + for query in clone.query.combined_queries: + sub_cte.extend(query._with_ctes) + query._with_ctes=[] + clone.query._with_ctes = sub_cte + return clone + class CTEManager(Manager.from_queryset(CTEQuerySet)): """Manager for models that perform CTE queries""" diff --git a/tests/test_cte.py b/tests/test_cte.py index ed7a112..7d10383 100644 --- a/tests/test_cte.py +++ b/tests/test_cte.py @@ -597,6 +597,25 @@ def test_left_outer_join_on_empty_result_set_cte(self): self.assertEqual(len(orders), 22) + def test_union_query_with_no_ctes(self): + query_partA = Order.objects.filter(region__name="earth") + query_partB = Order.objects.filter(region__name="mars") + query_combined = ( + query_partA.union(query_partB) + .order_by('region__name', 'amount') + .values_list('region__name', 'amount') + ) + print(query_combined.query) + self.assertEqual(list(query_combined), [ + ('earth', 30), + ('earth', 31), + ('earth', 32), + ('earth', 33), + ('mars', 40), + ('mars', 41), + ('mars', 42), + ]) + def test_union_query_with_cte(self): orders = ( Order.objects @@ -627,3 +646,54 @@ def test_union_query_with_cte(self): ('mars', 41), ('mars', 42), ]) + + def test_recursive_union_query_with_cte(self): + origin_node_pk = Region.objects.get(name='sun').pk + + def make_root_mapping(leaf_cte): + return Region.objects.filter( + parent_id=origin_node_pk + ).values( + rid=F('name'), + ).union( + leaf_cte.join( + Region, parent=leaf_cte.col.rid + ).values( + rid=F('name'), + ).distinct(), + all=False + ) + + self_leaf_node = ( # Find leaf origin nodes + Region.objects.filter(pk=origin_node_pk) + .annotate( + linked=Exists( + Region.objects.filter( + parent_id = OuterRef('pk')) + ) + ) + .filter(linked=False) + ) + leaf_cte = With.recursive(make_root_mapping, name="leaf_cte") + non_self_leaf_nodes = ( + leaf_cte.join(Region, pk=leaf_cte.col.rid) + .with_cte(leaf_cte) + .annotate( + linked=Exists( + Region.objects.filter(parent_id=OuterRef('pk')) + ) + ) + .filter(linked=False) + ) + all_leaf_nodes = non_self_leaf_nodes.union(self_leaf_node).order_by('name') + print(all_leaf_nodes.query) + self.assertEqual( + list(all_leaf_nodes.values_list('name')), + [ + ('deimos',), + ('mercury',), + ('moon',), + ('phobos',), + ('venus',) + ] + )