diff --git a/datafusion/expr/src/logical_plan/invariants.rs b/datafusion/expr/src/logical_plan/invariants.rs index 0c30c9785766..0d425c57f55b 100644 --- a/datafusion/expr/src/logical_plan/invariants.rs +++ b/datafusion/expr/src/logical_plan/invariants.rs @@ -200,21 +200,26 @@ pub fn check_subquery_expr( } }?; match outer_plan { - LogicalPlan::Projection(_) - | LogicalPlan::Filter(_) => Ok(()), - LogicalPlan::Aggregate(Aggregate { group_expr, aggr_expr, .. }) => { + LogicalPlan::Projection(_) | LogicalPlan::Filter(_) => Ok(()), + LogicalPlan::Aggregate(Aggregate { + group_expr, + aggr_expr, + .. + }) => { if group_expr.contains(expr) && !aggr_expr.contains(expr) { // TODO revisit this validation logic plan_err!( - "Correlated scalar subquery in the GROUP BY clause must also be in the aggregate expressions" + "Correlated scalar subquery in the GROUP BY clause must \ + also be in the aggregate expressions" ) } else { Ok(()) } } _ => plan_err!( - "Correlated scalar subquery can only be used in Projection, Filter, Aggregate plan nodes" - ) + "Correlated scalar subquery can only be used in Projection, \ + Filter, Aggregate plan nodes" + ), }?; } check_correlations_in_subquery(inner_plan) diff --git a/datafusion/optimizer/src/push_down_filter.rs b/datafusion/optimizer/src/push_down_filter.rs index bbf0b0dd810e..499cfeebe421 100644 --- a/datafusion/optimizer/src/push_down_filter.rs +++ b/datafusion/optimizer/src/push_down_filter.rs @@ -1089,7 +1089,13 @@ impl OptimizerRule for PushDownFilter { let (volatile_filters, non_volatile_filters): (Vec<&Expr>, Vec<&Expr>) = filter_predicates .into_iter() - .partition(|pred| pred.is_volatile()); + // TODO: subquery decorrelation sometimes cannot decorrelated all the expr + // (i.e in the case of recursive subquery) + // this function may accidentally pushdown the subquery expr as well + // until then, we have to exclude these exprs here + .partition(|pred| { + pred.is_volatile() || has_scalar_subquery(pred) + }); // Check which non-volatile filters are supported by source let supported_filters = scan @@ -1382,6 +1388,14 @@ fn contain(e: &Expr, check_map: &HashMap) -> bool { is_contain } +fn has_scalar_subquery(expr: &Expr) -> bool { + expr.exists(|e| match e { + Expr::ScalarSubquery(_) => Ok(true), + _ => Ok(false), + }) + .unwrap() +} + #[cfg(test)] mod tests { use std::any::Any; diff --git a/datafusion/sql/src/expr/identifier.rs b/datafusion/sql/src/expr/identifier.rs index 7c276ce53e35..9ee7b22e6dde 100644 --- a/datafusion/sql/src/expr/identifier.rs +++ b/datafusion/sql/src/expr/identifier.rs @@ -69,7 +69,7 @@ impl SqlToRel<'_, S> { } // Check the outer query schema - if let Some(outer) = planner_context.outer_query_schema() { + for outer in planner_context.outer_queries_schemas() { if let Ok((qualifier, field)) = outer.qualified_field_with_unqualified_name(normalize_ident.as_str()) { @@ -165,35 +165,43 @@ impl SqlToRel<'_, S> { not_impl_err!("compound identifier: {ids:?}") } else { // Check the outer_query_schema and try to find a match - if let Some(outer) = planner_context.outer_query_schema() { - let search_result = search_dfschema(&ids, outer); - match search_result { - // Found matching field with spare identifier(s) for nested field(s) in structure - Some((field, qualifier, nested_names)) - if !nested_names.is_empty() => - { - // TODO: remove when can support nested identifiers for OuterReferenceColumn - not_impl_err!( + let outer_schemas = planner_context.outer_queries_schemas(); + let mut maybe_result = None; + if !outer_schemas.is_empty() { + for outer in planner_context.outer_queries_schemas() { + let search_result = search_dfschema(&ids, &outer); + let result = match search_result { + // Found matching field with spare identifier(s) for nested field(s) in structure + Some((field, qualifier, nested_names)) + if !nested_names.is_empty() => + { + // TODO: remove when can support nested identifiers for OuterReferenceColumn + not_impl_err!( "Nested identifiers are not yet supported for OuterReferenceColumn {}", Column::from((qualifier, field)).quoted_flat_name() ) - } - // Found matching field with no spare identifier(s) - Some((field, qualifier, _nested_names)) => { - // Found an exact match on a qualified name in the outer plan schema, so this is an outer reference column - Ok(Expr::OuterReferenceColumn( - field.data_type().clone(), - Column::from((qualifier, field)), - )) - } - // Found no matching field, will return a default - None => { - let s = &ids[0..ids.len()]; - // safe unwrap as s can never be empty or exceed the bounds - let (relation, column_name) = - form_identifier(s).unwrap(); - Ok(Expr::Column(Column::new(relation, column_name))) - } + } + // Found matching field with no spare identifier(s) + Some((field, qualifier, _nested_names)) => { + // Found an exact match on a qualified name in the outer plan schema, so this is an outer reference column + Ok(Expr::OuterReferenceColumn( + field.data_type().clone(), + Column::from((qualifier, field)), + )) + } + // Found no matching field, will return a default + None => continue, + }; + maybe_result = Some(result); + break; + } + if let Some(result) = maybe_result { + result + } else { + let s = &ids[0..ids.len()]; + // safe unwrap as s can never be empty or exceed the bounds + let (relation, column_name) = form_identifier(s).unwrap(); + Ok(Expr::Column(Column::new(relation, column_name))) } } else { let s = &ids[0..ids.len()]; diff --git a/datafusion/sql/src/expr/subquery.rs b/datafusion/sql/src/expr/subquery.rs index 602d39233d58..6e10607d8533 100644 --- a/datafusion/sql/src/expr/subquery.rs +++ b/datafusion/sql/src/expr/subquery.rs @@ -31,11 +31,10 @@ impl SqlToRel<'_, S> { input_schema: &DFSchema, planner_context: &mut PlannerContext, ) -> Result { - let old_outer_query_schema = - planner_context.set_outer_query_schema(Some(input_schema.clone().into())); + planner_context.append_outer_query_schema(input_schema.clone().into()); let sub_plan = self.query_to_plan(subquery, planner_context)?; let outer_ref_columns = sub_plan.all_out_ref_exprs(); - planner_context.set_outer_query_schema(old_outer_query_schema); + planner_context.pop_outer_query_schema(); Ok(Expr::Exists(Exists { subquery: Subquery { subquery: Arc::new(sub_plan), @@ -54,8 +53,7 @@ impl SqlToRel<'_, S> { input_schema: &DFSchema, planner_context: &mut PlannerContext, ) -> Result { - let old_outer_query_schema = - planner_context.set_outer_query_schema(Some(input_schema.clone().into())); + planner_context.append_outer_query_schema(input_schema.clone().into()); let mut spans = Spans::new(); if let SetExpr::Select(select) = subquery.body.as_ref() { @@ -70,7 +68,7 @@ impl SqlToRel<'_, S> { let sub_plan = self.query_to_plan(subquery, planner_context)?; let outer_ref_columns = sub_plan.all_out_ref_exprs(); - planner_context.set_outer_query_schema(old_outer_query_schema); + planner_context.pop_outer_query_schema(); self.validate_single_column( &sub_plan, @@ -98,8 +96,8 @@ impl SqlToRel<'_, S> { input_schema: &DFSchema, planner_context: &mut PlannerContext, ) -> Result { - let old_outer_query_schema = - planner_context.set_outer_query_schema(Some(input_schema.clone().into())); + planner_context.append_outer_query_schema(input_schema.clone().into()); + let mut spans = Spans::new(); if let SetExpr::Select(select) = subquery.body.as_ref() { for item in &select.projection { @@ -112,7 +110,7 @@ impl SqlToRel<'_, S> { } let sub_plan = self.query_to_plan(subquery, planner_context)?; let outer_ref_columns = sub_plan.all_out_ref_exprs(); - planner_context.set_outer_query_schema(old_outer_query_schema); + planner_context.pop_outer_query_schema(); self.validate_single_column( &sub_plan, diff --git a/datafusion/sql/src/planner.rs b/datafusion/sql/src/planner.rs index 73d136d7d1cc..c7bf248e1b1a 100644 --- a/datafusion/sql/src/planner.rs +++ b/datafusion/sql/src/planner.rs @@ -198,7 +198,15 @@ pub struct PlannerContext { /// Map of CTE name to logical plan of the WITH clause. /// Use `Arc` to allow cheap cloning ctes: HashMap>, + + /// The queries schemas of outer query relations, used to resolve the outer referenced + /// columns in subquery (recursive aware) + outer_queries_schemas_stack: Vec, + /// The query schema of the outer query plan, used to resolve the columns in subquery + /// This field is maintained to support deprecated functions + /// `outer_query_schema` and `set_outer_query_schema` + /// which is only aware of the adjacent outer relation outer_query_schema: Option, /// The joined schemas of all FROM clauses planned so far. When planning LATERAL /// FROM clauses, this should become a suffix of the `outer_query_schema`. @@ -220,6 +228,7 @@ impl PlannerContext { prepare_param_data_types: Arc::new(vec![]), ctes: HashMap::new(), outer_query_schema: None, + outer_queries_schemas_stack: vec![], outer_from_schema: None, create_table_schema: None, } @@ -234,13 +243,22 @@ impl PlannerContext { self } - // Return a reference to the outer query's schema + /// Return a reference to the outer query's schema + /// This function should not be used together with + /// `outer_queries_schemas`, `append_outer_query_schema` + /// `latest_outer_query_schema` and `pop_outer_query_schema` + #[deprecated(note = "Use outer_queries_schemas instead")] pub fn outer_query_schema(&self) -> Option<&DFSchema> { self.outer_query_schema.as_ref().map(|s| s.as_ref()) } /// Sets the outer query schema, returning the existing one, if - /// any + /// any, this function should not be used together with + /// `outer_queries_schemas`, `append_outer_query_schema` + /// `latest_outer_query_schema` and `pop_outer_query_schema` + #[deprecated( + note = "This struct is now aware of a stack of schemas, check pop_outer_query_schema" + )] pub fn set_outer_query_schema( &mut self, mut schema: Option, @@ -249,6 +267,28 @@ impl PlannerContext { schema } + /// Return the stack of outer relations' schemas, the outer most + /// relation are at the first entry + pub fn outer_queries_schemas(&self) -> Vec { + self.outer_queries_schemas_stack.to_vec() + } + + /// Sets the outer query schema, returning the existing one, if + /// any + pub fn append_outer_query_schema(&mut self, schema: DFSchemaRef) { + self.outer_queries_schemas_stack.push(schema); + } + + /// The schema of the adjacent outer relation + pub fn latest_outer_query_schema(&mut self) -> Option { + self.outer_queries_schemas_stack.last().cloned() + } + + /// Remove the schema of the adjacent outer relation + pub fn pop_outer_query_schema(&mut self) -> Option { + self.outer_queries_schemas_stack.pop() + } + pub fn set_table_schema( &mut self, mut schema: Option, diff --git a/datafusion/sql/src/relation/mod.rs b/datafusion/sql/src/relation/mod.rs index 88a32a218341..e494404a50a7 100644 --- a/datafusion/sql/src/relation/mod.rs +++ b/datafusion/sql/src/relation/mod.rs @@ -184,20 +184,24 @@ impl SqlToRel<'_, S> { let old_from_schema = planner_context .set_outer_from_schema(None) .unwrap_or_else(|| Arc::new(DFSchema::empty())); - let new_query_schema = match planner_context.outer_query_schema() { - Some(old_query_schema) => { + let outer_query_schema = planner_context.pop_outer_query_schema(); + let new_query_schema = match outer_query_schema { + Some(ref old_query_schema) => { let mut new_query_schema = old_from_schema.as_ref().clone(); - new_query_schema.merge(old_query_schema); - Some(Arc::new(new_query_schema)) + new_query_schema.merge(old_query_schema.as_ref()); + Arc::new(new_query_schema) } - None => Some(Arc::clone(&old_from_schema)), + None => Arc::clone(&old_from_schema), }; - let old_query_schema = planner_context.set_outer_query_schema(new_query_schema); + planner_context.append_outer_query_schema(new_query_schema); let plan = self.create_relation(subquery, planner_context)?; let outer_ref_columns = plan.all_out_ref_exprs(); - planner_context.set_outer_query_schema(old_query_schema); + planner_context.pop_outer_query_schema(); + if let Some(schema) = outer_query_schema { + planner_context.append_outer_query_schema(schema); + } planner_context.set_outer_from_schema(Some(old_from_schema)); // We can omit the subquery wrapper if there are no columns diff --git a/datafusion/sql/src/select.rs b/datafusion/sql/src/select.rs index 9fad274b51c0..ac2fea310933 100644 --- a/datafusion/sql/src/select.rs +++ b/datafusion/sql/src/select.rs @@ -29,7 +29,7 @@ use crate::utils::{ use datafusion_common::error::DataFusionErrorBuilder; use datafusion_common::tree_node::{TreeNode, TreeNodeRecursion}; -use datafusion_common::{not_impl_err, plan_err, Result}; +use datafusion_common::{not_impl_err, plan_err, DFSchema, Result}; use datafusion_common::{RecursionUnnestOption, UnnestOptions}; use datafusion_expr::expr::{Alias, PlannedReplaceSelectItem, WildcardOptions}; use datafusion_expr::expr_rewriter::{ @@ -506,12 +506,8 @@ impl SqlToRel<'_, S> { match selection { Some(predicate_expr) => { let fallback_schemas = plan.fallback_normalize_schemas(); - let outer_query_schema = planner_context.outer_query_schema().cloned(); - let outer_query_schema_vec = outer_query_schema - .as_ref() - .map(|schema| vec![schema]) - .unwrap_or_else(Vec::new); + let outer_query_schema_vec = planner_context.outer_queries_schemas(); let filter_expr = self.sql_to_expr(predicate_expr, plan.schema(), planner_context)?; @@ -526,9 +522,19 @@ impl SqlToRel<'_, S> { let mut using_columns = HashSet::new(); expr_to_columns(&filter_expr, &mut using_columns)?; + let mut schema_stack: Vec> = + vec![vec![plan.schema()], fallback_schemas]; + for sc in outer_query_schema_vec.iter().rev() { + schema_stack.push(vec![sc.as_ref()]); + } + let filter_expr = normalize_col_with_schemas_and_ambiguity_check( filter_expr, - &[&[plan.schema()], &fallback_schemas, &outer_query_schema_vec], + schema_stack + .iter() + .map(|sc| sc.as_slice()) + .collect::>() + .as_slice(), &[using_columns], )?; diff --git a/datafusion/sqllogictest/test_files/joins.slt b/datafusion/sqllogictest/test_files/joins.slt index b5189d16ec24..aae1a4fae758 100644 --- a/datafusion/sqllogictest/test_files/joins.slt +++ b/datafusion/sqllogictest/test_files/joins.slt @@ -4689,6 +4689,30 @@ logical_plan 08)----------TableScan: j3 projection=[j3_string, j3_id] physical_plan_error This feature is not implemented: Physical plan does not support logical expression OuterReferenceColumn(Int32, Column { relation: Some(Bare { table: "j1" }), name: "j1_id" }) +# 2 nested lateral join with the deepest join referencing the outer most relation +query TT +explain SELECT * FROM j1 j1_outer, LATERAL ( + SELECT * FROM j1 j1_inner, LATERAL ( + SELECT * FROM j2 WHERE j1_inner.j1_id = j2_id and j1_outer.j1_id=j2_id + ) as j2 +) as j2; +---- +logical_plan +01)Cross Join: +02)--SubqueryAlias: j1_outer +03)----TableScan: j1 projection=[j1_string, j1_id] +04)--SubqueryAlias: j2 +05)----Subquery: +06)------Cross Join: +07)--------SubqueryAlias: j1_inner +08)----------TableScan: j1 projection=[j1_string, j1_id] +09)--------SubqueryAlias: j2 +10)----------Subquery: +11)------------Filter: outer_ref(j1_inner.j1_id) = j2.j2_id AND outer_ref(j1_outer.j1_id) = j2.j2_id +12)--------------TableScan: j2 projection=[j2_string, j2_id] +physical_plan_error This feature is not implemented: Physical plan does not support logical expression OuterReferenceColumn(Int32, Column { relation: Some(Bare { table: "j1_inner" }), name: "j1_id" }) + + query TT explain SELECT * FROM j1, LATERAL (SELECT 1) AS j2; ---- diff --git a/datafusion/sqllogictest/test_files/subquery.slt b/datafusion/sqllogictest/test_files/subquery.slt index 796570633f67..03372e32baca 100644 --- a/datafusion/sqllogictest/test_files/subquery.slt +++ b/datafusion/sqllogictest/test_files/subquery.slt @@ -1482,3 +1482,85 @@ logical_plan statement count 0 drop table person; + +# correlated_recursive_scalar_subquery_with_level_3_scalar_subquery_referencing_level1_relation +query TT +explain select c_custkey from customer +where c_acctbal < ( + select sum(o_totalprice) from orders + where o_custkey = c_custkey + and o_totalprice < ( + select sum(l_extendedprice) as price from lineitem where l_orderkey = o_orderkey + and l_extendedprice < c_acctbal + ) +) order by c_custkey; +---- +logical_plan +01)Sort: customer.c_custkey ASC NULLS LAST +02)--Projection: customer.c_custkey +03)----Inner Join: customer.c_custkey = __scalar_sq_1.o_custkey Filter: CAST(customer.c_acctbal AS Decimal128(25, 2)) < __scalar_sq_1.sum(orders.o_totalprice) +04)------TableScan: customer projection=[c_custkey, c_acctbal] +05)------SubqueryAlias: __scalar_sq_1 +06)--------Projection: sum(orders.o_totalprice), orders.o_custkey +07)----------Aggregate: groupBy=[[orders.o_custkey]], aggr=[[sum(orders.o_totalprice)]] +08)------------Projection: orders.o_custkey, orders.o_totalprice +09)--------------Filter: CAST(orders.o_totalprice AS Decimal128(25, 2)) < () +10)----------------Subquery: +11)------------------Projection: sum(lineitem.l_extendedprice) AS price +12)--------------------Aggregate: groupBy=[[]], aggr=[[sum(lineitem.l_extendedprice)]] +13)----------------------Filter: lineitem.l_orderkey = outer_ref(orders.o_orderkey) AND lineitem.l_extendedprice < outer_ref(customer.c_acctbal) +14)------------------------TableScan: lineitem, partial_filters=[lineitem.l_orderkey = outer_ref(orders.o_orderkey), lineitem.l_extendedprice < outer_ref(customer.c_acctbal)] +15)----------------TableScan: orders projection=[o_orderkey, o_custkey, o_totalprice] + +# correlated_recursive_scalar_subquery_with_level_3_exists_subquery_referencing_level1_relation +query TT +explain select c_custkey from customer +where c_acctbal < ( + select sum(o_totalprice) from orders + where o_custkey = c_custkey + and exists ( + select * from lineitem where l_orderkey = o_orderkey + and l_extendedprice < c_acctbal + ) +) order by c_custkey; +---- +logical_plan +01)Sort: customer.c_custkey ASC NULLS LAST +02)--Projection: customer.c_custkey +03)----Inner Join: customer.c_custkey = __scalar_sq_2.o_custkey Filter: CAST(customer.c_acctbal AS Decimal128(25, 2)) < __scalar_sq_2.sum(orders.o_totalprice) +04)------TableScan: customer projection=[c_custkey, c_acctbal] +05)------SubqueryAlias: __scalar_sq_2 +06)--------Projection: sum(orders.o_totalprice), orders.o_custkey +07)----------Aggregate: groupBy=[[orders.o_custkey]], aggr=[[sum(orders.o_totalprice)]] +08)------------Projection: orders.o_custkey, orders.o_totalprice +09)--------------LeftSemi Join: orders.o_orderkey = __correlated_sq_1.l_orderkey Filter: __correlated_sq_1.l_extendedprice < customer.c_acctbal +10)----------------TableScan: orders projection=[o_orderkey, o_custkey, o_totalprice] +11)----------------SubqueryAlias: __correlated_sq_1 +12)------------------TableScan: lineitem projection=[l_orderkey, l_extendedprice] + +# correlated_recursive_scalar_subquery_with_level_3_in_subquery_referencing_level1_relation +query TT +explain select c_custkey from customer +where c_acctbal < ( + select sum(o_totalprice) from orders + where o_custkey = c_custkey + and o_totalprice in ( + select l_extendedprice as price from lineitem where l_orderkey = o_orderkey + and l_extendedprice < c_acctbal + ) +) order by c_custkey; +---- +logical_plan +01)Sort: customer.c_custkey ASC NULLS LAST +02)--Projection: customer.c_custkey +03)----Inner Join: customer.c_custkey = __scalar_sq_2.o_custkey Filter: CAST(customer.c_acctbal AS Decimal128(25, 2)) < __scalar_sq_2.sum(orders.o_totalprice) +04)------TableScan: customer projection=[c_custkey, c_acctbal] +05)------SubqueryAlias: __scalar_sq_2 +06)--------Projection: sum(orders.o_totalprice), orders.o_custkey +07)----------Aggregate: groupBy=[[orders.o_custkey]], aggr=[[sum(orders.o_totalprice)]] +08)------------Projection: orders.o_custkey, orders.o_totalprice +09)--------------LeftSemi Join: orders.o_totalprice = __correlated_sq_1.price, orders.o_orderkey = __correlated_sq_1.l_orderkey Filter: __correlated_sq_1.l_extendedprice < customer.c_acctbal +10)----------------TableScan: orders projection=[o_orderkey, o_custkey, o_totalprice] +11)----------------SubqueryAlias: __correlated_sq_1 +12)------------------Projection: lineitem.l_extendedprice AS price, lineitem.l_extendedprice, lineitem.l_orderkey +13)--------------------TableScan: lineitem projection=[l_orderkey, l_extendedprice] \ No newline at end of file