From 4ba36c0f259398939074574e9b12ff9c9ae8a80e Mon Sep 17 00:00:00 2001 From: Duong Cong Toai Date: Mon, 3 Feb 2025 04:44:04 +0100 Subject: [PATCH 01/70] chore: add test --- .../sqllogictest/test_files/unsupported.slt | 69 +++++++++++++++++++ 1 file changed, 69 insertions(+) create mode 100644 datafusion/sqllogictest/test_files/unsupported.slt diff --git a/datafusion/sqllogictest/test_files/unsupported.slt b/datafusion/sqllogictest/test_files/unsupported.slt new file mode 100644 index 000000000000..742d8f529e3d --- /dev/null +++ b/datafusion/sqllogictest/test_files/unsupported.slt @@ -0,0 +1,69 @@ +statement ok +CREATE TABLE students( + id int, + name varchar, + major varchar, + year int +) +AS VALUES + (1,'toai','math',2014), + (2,'manh','math',2015), + (3,'bao','math',2025) +; + +statement ok +CREATE TABLE exams( + sid int, + curriculum varchar, + grade int, + date int +) +AS VALUES + (1, 'math', 10, 2014), + (2, 'math', 9, 2015), + (3, 'math', 4, 2025) +; + +query TTR +select s.name, e.curriculum, pulled.m as standard_grade from students s, exams e, ( + select avg(e2.grade) as m, id ,d.year ,d.major from ( + select distinct id, year, major from students + ) as d join exams e2 where d.id=e2.sid or ( + d.year > e2.date and d.major = e2.curriculum + ) group by id,year,major +) as pulled where +s.id=e.sid +and e.grade < pulled.m +and ( + pulled.id=s.id and pulled.year=s.year and pulled.major=s.major -- join with the domain columns +) +---- +manh math 9.5 +bao math 7.666666666667 + +query TT +explain select s.name, e.curriculum from students s, exams e where s.id=e.sid +and (s.major='math') and e.grade < ( + select avg(e2.grade) from exams e2 where s.id=e2.sid or ( + s.year) +10)----------Subquery: +11)------------Projection: avg(e2.grade) +12)--------------Aggregate: groupBy=[[]], aggr=[[avg(CAST(e2.grade AS Float64))]] +13)----------------SubqueryAlias: e2 +14)------------------Filter: outer_ref(s.id) = exams.sid OR outer_ref(s.year) < exams.date AND exams.curriculum = outer_ref(s.major) +15)--------------------TableScan: exams +16)----------TableScan: exams projection=[sid, curriculum, grade] +physical_plan_error This feature is not implemented: Physical plan does not support logical expression ScalarSubquery() \ No newline at end of file From 79eaca3a2b5bf84ffe8b89d971632b7ea3e32348 Mon Sep 17 00:00:00 2001 From: Duong Cong Toai Date: Mon, 10 Feb 2025 03:13:43 +0100 Subject: [PATCH 02/70] chore: more progress --- datafusion/expr/src/utils.rs | 25 + .../optimizer/src/decorrelate_general.rs | 662 ++++++++++++++++++ datafusion/optimizer/src/lib.rs | 1 + .../optimizer/src/scalar_subquery_to_join.rs | 15 +- datafusion/sqllogictest/test_files/debug.slt | 67 ++ .../sqllogictest/test_files/subquery.slt | 7 + .../sqllogictest/test_files/unsupported.slt | 7 + 7 files changed, 783 insertions(+), 1 deletion(-) create mode 100644 datafusion/optimizer/src/decorrelate_general.rs create mode 100644 datafusion/sqllogictest/test_files/debug.slt diff --git a/datafusion/expr/src/utils.rs b/datafusion/expr/src/utils.rs index 049926fb0bcd..e616b511d3af 100644 --- a/datafusion/expr/src/utils.rs +++ b/datafusion/expr/src/utils.rs @@ -1093,6 +1093,31 @@ pub fn split_conjunction(expr: &Expr) -> Vec<&Expr> { split_conjunction_impl(expr, vec![]) } +/// Splits a conjunctive [`Expr`] such as `A OR B OR C` => `[A, B, C]` +/// +/// See [`split_disjunction`] for more details and an example. 
+pub fn split_disjunction(expr: &Expr) -> Vec<&Expr> { + split_disjunction_impl(expr, vec![]) +} + +fn split_disjunction_impl<'a>(expr: &'a Expr, mut exprs: Vec<&'a Expr>) -> Vec<&'a Expr> { + match expr { + Expr::BinaryExpr(BinaryExpr { + right, + op: Operator::Or, + left, + }) => { + let exprs = split_disjunction_impl(left, exprs); + split_disjunction_impl(right, exprs) + } + Expr::Alias(Alias { expr, .. }) => split_disjunction_impl(expr, exprs), + other => { + exprs.push(other); + exprs + } + } +} + fn split_conjunction_impl<'a>(expr: &'a Expr, mut exprs: Vec<&'a Expr>) -> Vec<&'a Expr> { match expr { Expr::BinaryExpr(BinaryExpr { diff --git a/datafusion/optimizer/src/decorrelate_general.rs b/datafusion/optimizer/src/decorrelate_general.rs new file mode 100644 index 000000000000..c8b6ff4f832c --- /dev/null +++ b/datafusion/optimizer/src/decorrelate_general.rs @@ -0,0 +1,662 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! [`GeneralPullUpCorrelatedExpr`] converts correlated subqueries to `Joins` + +use std::collections::BTreeSet; +use std::ops::Deref; +use std::sync::Arc; + +use crate::simplify_expressions::ExprSimplifier; + +use datafusion_common::tree_node::{ + Transformed, TransformedResult, TreeNode, TreeNodeRecursion, TreeNodeRewriter, +}; +use datafusion_common::{plan_err, Column, DFSchemaRef, HashMap, Result, ScalarValue}; +use datafusion_expr::expr::Alias; +use datafusion_expr::simplify::SimplifyContext; +use datafusion_expr::utils::{ + collect_subquery_cols, conjunction, find_join_exprs, split_conjunction, + split_disjunction, +}; +use datafusion_expr::{ + expr, lit, BinaryExpr, Cast, EmptyRelation, Expr, FetchType, LogicalPlan, + LogicalPlanBuilder, Operator, +}; +use datafusion_physical_expr::execution_props::ExecutionProps; + +/// This struct rewrite the sub query plan by pull up the correlated +/// expressions(contains outer reference columns) from the inner subquery's +/// 'Filter'. It adds the inner reference columns to the 'Projection' or +/// 'Aggregate' of the subquery if they are missing, so that they can be +/// evaluated by the parent operator as the join condition. +#[derive(Debug)] +pub struct GeneralPullUpCorrelatedExpr { + pub join_filters: Vec, + /// mapping from the plan to its holding correlated columns + pub correlated_subquery_cols_map: HashMap>, + pub in_predicate_opt: Option, + /// Is this an Exists(Not Exists) SubQuery. Defaults to **FALSE** + pub exists_sub_query: bool, + /// Can the correlated expressions be pulled up. Defaults to **TRUE** + pub can_pull_up: bool, + /// Indicates if we encounter any correlated expression that can not be pulled up + /// above a aggregation without changing the meaning of the query. 
+ can_pull_over_aggregation: bool, + /// Do we need to handle [the Count bug] during the pull up process + /// + /// [the Count bug]: https://github.com/apache/datafusion/pull/10500 + pub need_handle_count_bug: bool, + /// mapping from the plan to its expressions' evaluation result on empty batch + pub collected_count_expr_map: HashMap, + /// pull up having expr, which must be evaluated after the Join + pub pull_up_having_expr: Option, +} + +impl Default for GeneralPullUpCorrelatedExpr { + fn default() -> Self { + Self::new() + } +} + +impl GeneralPullUpCorrelatedExpr { + pub fn new() -> Self { + Self { + join_filters: vec![], + correlated_subquery_cols_map: HashMap::new(), + in_predicate_opt: None, + exists_sub_query: false, + can_pull_up: true, + can_pull_over_aggregation: true, + need_handle_count_bug: false, + collected_count_expr_map: HashMap::new(), + pull_up_having_expr: None, + } + } + + /// Set if we need to handle [the Count bug] during the pull up process + /// + /// [the Count bug]: https://github.com/apache/datafusion/pull/10500 + pub fn with_need_handle_count_bug(mut self, need_handle_count_bug: bool) -> Self { + self.need_handle_count_bug = need_handle_count_bug; + self + } + + /// Set the in_predicate_opt + pub fn with_in_predicate_opt(mut self, in_predicate_opt: Option) -> Self { + self.in_predicate_opt = in_predicate_opt; + self + } + + /// Set if this is an Exists(Not Exists) SubQuery + pub fn with_exists_sub_query(mut self, exists_sub_query: bool) -> Self { + self.exists_sub_query = exists_sub_query; + self + } +} + +/// Used to indicate the unmatched rows from the inner(subquery) table after the left out Join +/// This is used to handle [the Count bug] +/// +/// [the Count bug]: https://github.com/apache/datafusion/pull/10500 +pub const UN_MATCHED_ROW_INDICATOR: &str = "__always_true"; + +/// Mapping from expr display name to its evaluation result on empty record +/// batch (for example: 'count(*)' is 'ScalarValue(0)', 'count(*) + 2' is +/// 'ScalarValue(2)') +pub type ExprResultMap = HashMap; + +impl TreeNodeRewriter for GeneralPullUpCorrelatedExpr { + type Node = LogicalPlan; + + fn f_down(&mut self, plan: LogicalPlan) -> Result> { + match plan { + LogicalPlan::Filter(_) => Ok(Transformed::no(plan)), + LogicalPlan::Union(_) | LogicalPlan::Sort(_) | LogicalPlan::Extension(_) => { + let plan_hold_outer = !plan.all_out_ref_exprs().is_empty(); + println!("plan hold outer and contains union"); + if plan_hold_outer { + // the unsupported case + self.can_pull_up = false; + Ok(Transformed::new(plan, false, TreeNodeRecursion::Jump)) + } else { + Ok(Transformed::no(plan)) + } + } + LogicalPlan::Limit(_) => { + let plan_hold_outer = !plan.all_out_ref_exprs().is_empty(); + match (self.exists_sub_query, plan_hold_outer) { + (false, true) => { + // the unsupported case + println!("plan has limit and no subquery found and plan hold outer ref"); + self.can_pull_up = false; + Ok(Transformed::new(plan, false, TreeNodeRecursion::Jump)) + } + _ => Ok(Transformed::no(plan)), + } + } + _ if plan.contains_outer_reference() => { + println!("plan contains outer reference, cannot pull up"); + // the unsupported cases, the plan expressions contain out reference columns(like window expressions) + self.can_pull_up = false; + Ok(Transformed::new(plan, false, TreeNodeRecursion::Jump)) + } + _ => Ok(Transformed::no(plan)), + } + } + + fn f_up(&mut self, plan: LogicalPlan) -> Result> { + let subquery_schema = plan.schema(); + println!("XXXXXXXXXXXXX Plan type {}", plan.display()); + match &plan { + 
// TODO: what if this happen recursively? + // select * from exams e1 where grade > (select avg(grade) from exams as e2 where e1.sid = e2.sid + // and e2.curriculum=(select max(grade) from exams e3 group by curriculum)) + LogicalPlan::Filter(plan_filter) => { + let subquery_filter_exprs = split_conjunction(&plan_filter.predicate); + let or_filters = split_disjunction(&plan_filter.predicate); + or_filters.iter().for_each(|f| { + println!("or filter {}", f); + }); + self.can_pull_over_aggregation = self.can_pull_over_aggregation + && subquery_filter_exprs + .iter() + .filter(|e| e.contains_outer()) + .all(|&e| { + let ret = can_pullup_over_aggregation(e); + if !ret { + println!("can NOT pull up over aggregation {:?}", e); + } + ret + }); + + let (mut join_filters, subquery_filters) = + find_join_exprs(subquery_filter_exprs)?; + + if let Some(in_predicate) = &self.in_predicate_opt { + // in_predicate may be already included in the join filters, remove it from the join filters first. + join_filters = remove_duplicated_filter(join_filters, in_predicate); + } + println!("JOIN FILTERS"); + for expr in join_filters.iter() { + println!("{}", expr); + } + + // TODO: these cols only include the inner's table columns which is not sufficient + // in the case of complex unnest + // + // We need to collect all the columns in the outer table to construct the domain + // and join the domain with the inner table to prepare for the aggregation + let correlated_subquery_cols = + collect_subquery_cols(&join_filters, subquery_schema)?; + println!("CORRELATED COLUMS"); + for col in correlated_subquery_cols.iter() { + println!("{}", col); + } + // TODO: these join filters may need to be transformed, because now the join + // happen between the outer table columns and the newly built relation + for expr in join_filters { + if !self.join_filters.contains(&expr) { + self.join_filters.push(expr) + } + } + + let mut expr_result_map_for_count_bug = HashMap::new(); + let pull_up_expr_opt = if let Some(expr_result_map) = + self.collected_count_expr_map.get(plan_filter.input.deref()) + { + if let Some(expr) = conjunction(subquery_filters.clone()) { + filter_exprs_evaluation_result_on_empty_batch( + &expr, + Arc::clone(plan_filter.input.schema()), + expr_result_map, + &mut expr_result_map_for_count_bug, + )? + } else { + None + } + } else { + None + }; + + match (&pull_up_expr_opt, &self.pull_up_having_expr) { + (Some(_), Some(_)) => { + // Error path + plan_err!("Unsupported Subquery plan") + } + (Some(_), None) => { + self.pull_up_having_expr = pull_up_expr_opt; + let new_plan = + LogicalPlanBuilder::from((*plan_filter.input).clone()) + .build()?; + self.correlated_subquery_cols_map + .insert(new_plan.clone(), correlated_subquery_cols); + Ok(Transformed::yes(new_plan)) + } + (None, _) => { + // if the subquery still has filter expressions, restore them. + let mut plan = + LogicalPlanBuilder::from((*plan_filter.input).clone()); + if let Some(expr) = conjunction(subquery_filters) { + plan = plan.filter(expr)? 
+ } + let new_plan = plan.build()?; + self.correlated_subquery_cols_map + .insert(new_plan.clone(), correlated_subquery_cols); + Ok(Transformed::yes(new_plan)) + } + } + } + LogicalPlan::Projection(projection) + if self.in_predicate_opt.is_some() || !self.join_filters.is_empty() => + { + let mut local_correlated_cols = BTreeSet::new(); + collect_local_correlated_cols( + &plan, + &self.correlated_subquery_cols_map, + &mut local_correlated_cols, + ); + // add missing columns to Projection + let mut missing_exprs = + self.collect_missing_exprs(&projection.expr, &local_correlated_cols)?; + + let mut expr_result_map_for_count_bug = HashMap::new(); + if let Some(expr_result_map) = + self.collected_count_expr_map.get(projection.input.deref()) + { + proj_exprs_evaluation_result_on_empty_batch( + &projection.expr, + projection.input.schema(), + expr_result_map, + &mut expr_result_map_for_count_bug, + )?; + if !expr_result_map_for_count_bug.is_empty() { + // has count bug + let un_matched_row = Expr::Column(Column::new_unqualified( + UN_MATCHED_ROW_INDICATOR.to_string(), + )); + // add the unmatched rows indicator to the Projection expressions + missing_exprs.push(un_matched_row); + } + } + + let new_plan = LogicalPlanBuilder::from((*projection.input).clone()) + .project(missing_exprs)? + .build()?; + if !expr_result_map_for_count_bug.is_empty() { + self.collected_count_expr_map + .insert(new_plan.clone(), expr_result_map_for_count_bug); + } + Ok(Transformed::yes(new_plan)) + } + LogicalPlan::Aggregate(aggregate) + if self.in_predicate_opt.is_some() || !self.join_filters.is_empty() => + { + // If the aggregation is from a distinct it will not change the result for + // exists/in subqueries so we can still pull up all predicates. + let is_distinct = aggregate.aggr_expr.is_empty(); + if !is_distinct { + println!( + "can pull up {:?} and can pull over aggregation {:?}", + self.can_pull_up, self.can_pull_over_aggregation + ); + self.can_pull_up = self.can_pull_up && self.can_pull_over_aggregation; + } + let mut local_correlated_cols = BTreeSet::new(); + collect_local_correlated_cols( + &plan, + &self.correlated_subquery_cols_map, + &mut local_correlated_cols, + ); + // add missing columns to Aggregation's group expressions + let mut missing_exprs = self.collect_missing_exprs( + &aggregate.group_expr, + &local_correlated_cols, + )?; + + // if the original group expressions are empty, need to handle the Count bug + let mut expr_result_map_for_count_bug = HashMap::new(); + if self.need_handle_count_bug + && aggregate.group_expr.is_empty() + && !missing_exprs.is_empty() + { + agg_exprs_evaluation_result_on_empty_batch( + &aggregate.aggr_expr, + aggregate.input.schema(), + &mut expr_result_map_for_count_bug, + )?; + if !expr_result_map_for_count_bug.is_empty() { + // has count bug + let un_matched_row = lit(true).alias(UN_MATCHED_ROW_INDICATOR); + // add the unmatched rows indicator to the Aggregation's group expressions + missing_exprs.push(un_matched_row); + } + } + let new_plan = LogicalPlanBuilder::from((*aggregate.input).clone()) + .aggregate(missing_exprs, aggregate.aggr_expr.to_vec())? 
+ .build()?; + if !expr_result_map_for_count_bug.is_empty() { + self.collected_count_expr_map + .insert(new_plan.clone(), expr_result_map_for_count_bug); + } + Ok(Transformed::yes(new_plan)) + } + LogicalPlan::SubqueryAlias(alias) => { + let mut local_correlated_cols = BTreeSet::new(); + collect_local_correlated_cols( + &plan, + &self.correlated_subquery_cols_map, + &mut local_correlated_cols, + ); + let mut new_correlated_cols = BTreeSet::new(); + for col in local_correlated_cols.iter() { + new_correlated_cols + .insert(Column::new(Some(alias.alias.clone()), col.name.clone())); + } + self.correlated_subquery_cols_map + .insert(plan.clone(), new_correlated_cols); + if let Some(input_map) = + self.collected_count_expr_map.get(alias.input.deref()) + { + self.collected_count_expr_map + .insert(plan.clone(), input_map.clone()); + } + Ok(Transformed::no(plan)) + } + LogicalPlan::Limit(limit) => { + let input_expr_map = self + .collected_count_expr_map + .get(limit.input.deref()) + .cloned(); + // handling the limit clause in the subquery + let new_plan = match (self.exists_sub_query, self.join_filters.is_empty()) + { + // Correlated exist subquery, remove the limit(so that correlated expressions can pull up) + (true, false) => Transformed::yes(match limit.get_fetch_type()? { + FetchType::Literal(Some(0)) => { + LogicalPlan::EmptyRelation(EmptyRelation { + produce_one_row: false, + schema: Arc::clone(limit.input.schema()), + }) + } + _ => LogicalPlanBuilder::from((*limit.input).clone()).build()?, + }), + _ => Transformed::no(plan), + }; + if let Some(input_map) = input_expr_map { + self.collected_count_expr_map + .insert(new_plan.data.clone(), input_map); + } + Ok(new_plan) + } + _ => Ok(Transformed::no(plan)), + } + } +} + +impl GeneralPullUpCorrelatedExpr { + fn collect_missing_exprs( + &self, + exprs: &[Expr], + correlated_subquery_cols: &BTreeSet, + ) -> Result> { + let mut missing_exprs = vec![]; + for expr in exprs { + if !missing_exprs.contains(expr) { + missing_exprs.push(expr.clone()) + } + } + for col in correlated_subquery_cols.iter() { + let col_expr = Expr::Column(col.clone()); + if !missing_exprs.contains(&col_expr) { + missing_exprs.push(col_expr) + } + } + if let Some(pull_up_having) = &self.pull_up_having_expr { + let filter_apply_columns = pull_up_having.column_refs(); + for col in filter_apply_columns { + // add to missing_exprs if not already there + let contains = missing_exprs + .iter() + .any(|expr| matches!(expr, Expr::Column(c) if c == col)); + if !contains { + missing_exprs.push(Expr::Column(col.clone())) + } + } + } + Ok(missing_exprs) + } +} + +/// for now only simple exprs can be pulled up over aggregation +/// such as binaryExpr between a outer column ref vs non column expr +/// In the general unnesting framework, the complex expr is pulled up, but being decomposed in some way +/// for example: +/// select * from exams e1 where score > (select avg(score) from exams e2 where e1.student_id = e2.student_id +/// or (e2.year > e1.year and e2.subject=e1.subject)) +/// In this case, the complex expr to be pulled up is +/// ``` +/// e1.student_id=e1.student_id or (e2.year > e1.year and e2.subject=e1.subject) +/// ``` +/// The complex expr is decomposed during the pull up over aggregation avg(score) +/// into a new relation +fn can_pullup_over_aggregation(expr: &Expr) -> bool { + if let Expr::BinaryExpr(BinaryExpr { + left, + op: Operator::Eq, + right, + }) = expr + { + match (left.deref(), right.deref()) { + (Expr::Column(_), right) => !right.any_column_refs(), + (left, 
Expr::Column(_)) => !left.any_column_refs(), + (Expr::Cast(Cast { expr, .. }), right) + if matches!(expr.deref(), Expr::Column(_)) => + { + !right.any_column_refs() + } + (left, Expr::Cast(Cast { expr, .. })) + if matches!(expr.deref(), Expr::Column(_)) => + { + !left.any_column_refs() + } + (_, _) => false, + } + } else { + false + } +} + +fn collect_local_correlated_cols( + plan: &LogicalPlan, + all_cols_map: &HashMap>, + local_cols: &mut BTreeSet, +) { + for child in plan.inputs() { + if let Some(cols) = all_cols_map.get(child) { + local_cols.extend(cols.clone()); + } + // SubqueryAlias is treated as the leaf node + if !matches!(child, LogicalPlan::SubqueryAlias(_)) { + collect_local_correlated_cols(child, all_cols_map, local_cols); + } + } +} + +fn remove_duplicated_filter(filters: Vec, in_predicate: &Expr) -> Vec { + filters + .into_iter() + .filter(|filter| { + if filter == in_predicate { + return false; + } + + // ignore the binary order + !match (filter, in_predicate) { + (Expr::BinaryExpr(a_expr), Expr::BinaryExpr(b_expr)) => { + (a_expr.op == b_expr.op) + && (a_expr.left == b_expr.left && a_expr.right == b_expr.right) + || (a_expr.left == b_expr.right && a_expr.right == b_expr.left) + } + _ => false, + } + }) + .collect::>() +} + +fn agg_exprs_evaluation_result_on_empty_batch( + agg_expr: &[Expr], + schema: &DFSchemaRef, + expr_result_map_for_count_bug: &mut ExprResultMap, +) -> Result<()> { + for e in agg_expr.iter() { + let result_expr = e + .clone() + .transform_up(|expr| { + let new_expr = match expr { + Expr::AggregateFunction(expr::AggregateFunction { func, .. }) => { + if func.name() == "count" { + Transformed::yes(Expr::Literal(ScalarValue::Int64(Some(0)))) + } else { + Transformed::yes(Expr::Literal(ScalarValue::Null)) + } + } + _ => Transformed::no(expr), + }; + Ok(new_expr) + }) + .data()?; + + let result_expr = result_expr.unalias(); + let props = ExecutionProps::new(); + let info = SimplifyContext::new(&props).with_schema(Arc::clone(schema)); + let simplifier = ExprSimplifier::new(info); + let result_expr = simplifier.simplify(result_expr)?; + if matches!(result_expr, Expr::Literal(ScalarValue::Int64(_))) { + expr_result_map_for_count_bug + .insert(e.schema_name().to_string(), result_expr); + } + } + Ok(()) +} + +fn proj_exprs_evaluation_result_on_empty_batch( + proj_expr: &[Expr], + schema: &DFSchemaRef, + input_expr_result_map_for_count_bug: &ExprResultMap, + expr_result_map_for_count_bug: &mut ExprResultMap, +) -> Result<()> { + for expr in proj_expr.iter() { + let result_expr = expr + .clone() + .transform_up(|expr| { + if let Expr::Column(Column { name, .. }) = &expr { + if let Some(result_expr) = + input_expr_result_map_for_count_bug.get(name) + { + Ok(Transformed::yes(result_expr.clone())) + } else { + Ok(Transformed::no(expr)) + } + } else { + Ok(Transformed::no(expr)) + } + }) + .data()?; + + if result_expr.ne(expr) { + let props = ExecutionProps::new(); + let info = SimplifyContext::new(&props).with_schema(Arc::clone(schema)); + let simplifier = ExprSimplifier::new(info); + let result_expr = simplifier.simplify(result_expr)?; + let expr_name = match expr { + Expr::Alias(Alias { name, .. 
}) => name.to_string(), + Expr::Column(Column { + relation: _, + name, + spans: _, + }) => name.to_string(), + _ => expr.schema_name().to_string(), + }; + expr_result_map_for_count_bug.insert(expr_name, result_expr); + } + } + Ok(()) +} + +fn filter_exprs_evaluation_result_on_empty_batch( + filter_expr: &Expr, + schema: DFSchemaRef, + input_expr_result_map_for_count_bug: &ExprResultMap, + expr_result_map_for_count_bug: &mut ExprResultMap, +) -> Result> { + let result_expr = filter_expr + .clone() + .transform_up(|expr| { + if let Expr::Column(Column { name, .. }) = &expr { + if let Some(result_expr) = input_expr_result_map_for_count_bug.get(name) { + Ok(Transformed::yes(result_expr.clone())) + } else { + Ok(Transformed::no(expr)) + } + } else { + Ok(Transformed::no(expr)) + } + }) + .data()?; + + let pull_up_expr = if result_expr.ne(filter_expr) { + let props = ExecutionProps::new(); + let info = SimplifyContext::new(&props).with_schema(schema); + let simplifier = ExprSimplifier::new(info); + let result_expr = simplifier.simplify(result_expr)?; + match &result_expr { + // evaluate to false or null on empty batch, no need to pull up + Expr::Literal(ScalarValue::Null) + | Expr::Literal(ScalarValue::Boolean(Some(false))) => None, + // evaluate to true on empty batch, need to pull up the expr + Expr::Literal(ScalarValue::Boolean(Some(true))) => { + for (name, exprs) in input_expr_result_map_for_count_bug { + expr_result_map_for_count_bug.insert(name.clone(), exprs.clone()); + } + Some(filter_expr.clone()) + } + // can not evaluate statically + _ => { + for input_expr in input_expr_result_map_for_count_bug.values() { + let new_expr = Expr::Case(expr::Case { + expr: None, + when_then_expr: vec![( + Box::new(result_expr.clone()), + Box::new(input_expr.clone()), + )], + else_expr: Some(Box::new(Expr::Literal(ScalarValue::Null))), + }); + let expr_key = new_expr.schema_name().to_string(); + expr_result_map_for_count_bug.insert(expr_key, new_expr); + } + None + } + } + } else { + for (name, exprs) in input_expr_result_map_for_count_bug { + expr_result_map_for_count_bug.insert(name.clone(), exprs.clone()); + } + None + }; + Ok(pull_up_expr) +} diff --git a/datafusion/optimizer/src/lib.rs b/datafusion/optimizer/src/lib.rs index 614284e1b477..5a8e51ccb4a7 100644 --- a/datafusion/optimizer/src/lib.rs +++ b/datafusion/optimizer/src/lib.rs @@ -34,6 +34,7 @@ pub mod analyzer; pub mod common_subexpr_eliminate; pub mod decorrelate; +pub mod decorrelate_general; pub mod decorrelate_predicate_subquery; pub mod eliminate_cross_join; pub mod eliminate_duplicated_expr; diff --git a/datafusion/optimizer/src/scalar_subquery_to_join.rs b/datafusion/optimizer/src/scalar_subquery_to_join.rs index 3a8aef267be5..12c9257a4b0c 100644 --- a/datafusion/optimizer/src/scalar_subquery_to_join.rs +++ b/datafusion/optimizer/src/scalar_subquery_to_join.rs @@ -21,6 +21,7 @@ use std::collections::{BTreeSet, HashMap}; use std::sync::Arc; use crate::decorrelate::{PullUpCorrelatedExpr, UN_MATCHED_ROW_INDICATOR}; +use crate::decorrelate_general::GeneralPullUpCorrelatedExpr; use crate::optimizer::ApplyOrder; use crate::utils::replace_qualified_name; use crate::{OptimizerConfig, OptimizerRule}; @@ -297,12 +298,16 @@ fn build_join( subquery_alias: &str, ) -> Result)>> { let subquery_plan = subquery.subquery.as_ref(); - let mut pull_up = PullUpCorrelatedExpr::new().with_need_handle_count_bug(true); + let mut pull_up = GeneralPullUpCorrelatedExpr::new().with_need_handle_count_bug(true); let new_plan = subquery_plan.clone().rewrite(&mut 
pull_up).data()?; + if !pull_up.can_pull_up { return Ok(None); } + println!("before rewrite: {}", subquery_plan); + println!("ater rewrite: {}", new_plan); + let collected_count_expr_map = pull_up.collected_count_expr_map.get(&new_plan).cloned(); let sub_query_alias = LogicalPlanBuilder::from(new_plan) @@ -314,12 +319,19 @@ fn build_join( .correlated_subquery_cols_map .values() .for_each(|cols| all_correlated_cols.extend(cols.clone())); + println!("========\ncorrelated cols"); + for col in &all_correlated_cols { + println!("{}", col); + } + println!("===================="); // alias the join filter let join_filter_opt = conjunction(pull_up.join_filters).map_or(Ok(None), |filter| { replace_qualified_name(filter, &all_correlated_cols, subquery_alias).map(Some) })?; + // TODO: build domain from filter input + // select distinct columns from filter input // join our sub query into the main plan let new_plan = if join_filter_opt.is_none() { @@ -336,6 +348,7 @@ fn build_join( } } } else { + println!("++++++++++++++++filter input: {}", filter_input); // left join if correlated, grouping by the join keys so we don't change row count LogicalPlanBuilder::from(filter_input.clone()) .join_on(sub_query_alias, JoinType::Left, join_filter_opt)? diff --git a/datafusion/sqllogictest/test_files/debug.slt b/datafusion/sqllogictest/test_files/debug.slt new file mode 100644 index 000000000000..55affb07fb6b --- /dev/null +++ b/datafusion/sqllogictest/test_files/debug.slt @@ -0,0 +1,67 @@ +statement ok +CREATE TABLE students( + id int, + name varchar, + major varchar, + year timestamp +) +AS VALUES + (1,'A','math','2014-01-01T00:00:00'::timestamp), + (2,'B','math','2015-01-01T00:00:00'::timestamp), + (3,'C','math','2016-01-01T00:00:00'::timestamp) +; + +statement ok +CREATE TABLE exams( + sid int, + curriculum varchar, + grade int, + date timestamp +) +AS VALUES + (1, 'math', 10, '2014-01-01T00:00:00'::timestamp), + (2, 'math', 9, '2015-01-01T00:00:00'::timestamp), + (3, 'math', 4, '2016-01-01T00:00:00'::timestamp) +; + +## Multi-level correlated subquery +##query TT +##explain select * from exams e1 where grade > (select avg(grade) from exams as e2 where e1.sid = e2.sid +##and e2.curriculum=(select max(grade) from exams e3 group by curriculum)) +##---- + +query TT +explain select * from exams e1 where grade > (select avg(grade) from exams as e2 where e1.sid = e2.sid) +---- + + +## select * from exams e1, ( +## select avg(score) as avg_score, e2.sid, e2.year,e2.subject from exams e2 group by e2.sid,e2.year,e2.subject +## ) as pulled_up where e1.score > pulled_up.avg_score + + +## query TT +## explain select s.name, e.curriculum from students s, exams e where s.id=e.sid +## and (s.major='math') and e.grade < ( +## select avg(e2.grade) from exams e2 where s.id=e2.sid or ( +## s.year) +## 10)----------Subquery: +## 11)------------Projection: avg(e2.grade) +## 12)--------------Aggregate: groupBy=[[]], aggr=[[avg(CAST(e2.grade AS Float64))]] +## 13)----------------SubqueryAlias: e2 +## 14)------------------Filter: outer_ref(s.id) = exams.sid OR outer_ref(s.year) < exams.date AND exams.curriculum = outer_ref(s.major) +## 15)--------------------TableScan: exams +## 16)----------TableScan: exams projection=[sid, curriculum, grade] \ No newline at end of file diff --git a/datafusion/sqllogictest/test_files/subquery.slt b/datafusion/sqllogictest/test_files/subquery.slt index 8895a2986103..ce6ebfc6f4f3 100644 --- a/datafusion/sqllogictest/test_files/subquery.slt +++ b/datafusion/sqllogictest/test_files/subquery.slt @@ 
-870,6 +870,13 @@ SELECT t1_id, (SELECT count(*) + 2 as _cnt FROM t2 WHERE t2.t2_int = t1.t1_int) #correlated_scalar_subquery_count_agg_where_clause query TT explain select t1.t1_int from t1 where (select count(*) from t2 where t1.t1_id = t2.t2_id) < t1.t1_int +select t1.t1_int from t1, +( + select count(*) as count_all from t2, ( + select distinct t1_id + ) as domain where t2.t2_id = domain.t1_id +) as pulled_up +where t1.t1_id=pulled_up.t1_id and pulled_up.count_all < t1.t1_int ---- logical_plan 01)Projection: t1.t1_int diff --git a/datafusion/sqllogictest/test_files/unsupported.slt b/datafusion/sqllogictest/test_files/unsupported.slt index 742d8f529e3d..101a3ecd4442 100644 --- a/datafusion/sqllogictest/test_files/unsupported.slt +++ b/datafusion/sqllogictest/test_files/unsupported.slt @@ -24,6 +24,13 @@ AS VALUES (3, 'math', 4, 2025) ; +-- explain select s.name, e.curriculum from students s, exams e where s.id=e.sid +-- and (s.major='math') and e.grade < ( +-- select avg(e2.grade) from exams e2 where s.id=e2.sid or ( +-- s.year Date: Tue, 18 Mar 2025 20:56:52 +0100 Subject: [PATCH 03/70] temp --- datafusion/sqllogictest/test_files/debug.slt | 41 -------------------- 1 file changed, 41 deletions(-) diff --git a/datafusion/sqllogictest/test_files/debug.slt b/datafusion/sqllogictest/test_files/debug.slt index 55affb07fb6b..36bf75072759 100644 --- a/datafusion/sqllogictest/test_files/debug.slt +++ b/datafusion/sqllogictest/test_files/debug.slt @@ -24,44 +24,3 @@ AS VALUES (3, 'math', 4, '2016-01-01T00:00:00'::timestamp) ; -## Multi-level correlated subquery -##query TT -##explain select * from exams e1 where grade > (select avg(grade) from exams as e2 where e1.sid = e2.sid -##and e2.curriculum=(select max(grade) from exams e3 group by curriculum)) -##---- - -query TT -explain select * from exams e1 where grade > (select avg(grade) from exams as e2 where e1.sid = e2.sid) ----- - - -## select * from exams e1, ( -## select avg(score) as avg_score, e2.sid, e2.year,e2.subject from exams e2 group by e2.sid,e2.year,e2.subject -## ) as pulled_up where e1.score > pulled_up.avg_score - - -## query TT -## explain select s.name, e.curriculum from students s, exams e where s.id=e.sid -## and (s.major='math') and e.grade < ( -## select avg(e2.grade) from exams e2 where s.id=e2.sid or ( -## s.year) -## 10)----------Subquery: -## 11)------------Projection: avg(e2.grade) -## 12)--------------Aggregate: groupBy=[[]], aggr=[[avg(CAST(e2.grade AS Float64))]] -## 13)----------------SubqueryAlias: e2 -## 14)------------------Filter: outer_ref(s.id) = exams.sid OR outer_ref(s.year) < exams.date AND exams.curriculum = outer_ref(s.major) -## 15)--------------------TableScan: exams -## 16)----------TableScan: exams projection=[sid, curriculum, grade] \ No newline at end of file From 68fd9cad269a84e82edd1f167ea8b97bee43c9d6 Mon Sep 17 00:00:00 2001 From: Duong Cong Toai Date: Wed, 16 Apr 2025 22:12:29 +0200 Subject: [PATCH 04/70] chore: some work --- .../optimizer/src/decorrelate_general.rs | 768 ++++-------------- .../optimizer/src/scalar_subquery_to_join.rs | 24 +- datafusion/sqllogictest/test_files/debug.slt | 35 + datafusion/sqllogictest/test_files/debug2.slt | 114 +++ .../sqllogictest/test_files/unsupported.slt | 16 +- 5 files changed, 330 insertions(+), 627 deletions(-) create mode 100644 datafusion/sqllogictest/test_files/debug2.slt diff --git a/datafusion/optimizer/src/decorrelate_general.rs b/datafusion/optimizer/src/decorrelate_general.rs index c8b6ff4f832c..42f7f09aae0d 100644 --- 
a/datafusion/optimizer/src/decorrelate_general.rs +++ b/datafusion/optimizer/src/decorrelate_general.rs @@ -17,646 +17,206 @@ //! [`GeneralPullUpCorrelatedExpr`] converts correlated subqueries to `Joins` -use std::collections::BTreeSet; +use std::cell::RefCell; +use std::collections::{BTreeSet, HashSet}; use std::ops::Deref; +use std::rc::{Rc, Weak}; use std::sync::Arc; use crate::simplify_expressions::ExprSimplifier; +use crate::{ApplyOrder, OptimizerConfig, OptimizerRule}; use datafusion_common::tree_node::{ - Transformed, TransformedResult, TreeNode, TreeNodeRecursion, TreeNodeRewriter, + Transformed, TransformedResult, TreeNode, TreeNodeContainer, TreeNodeRecursion, + TreeNodeRewriter, TreeNodeVisitor, }; -use datafusion_common::{plan_err, Column, DFSchemaRef, HashMap, Result, ScalarValue}; -use datafusion_expr::expr::Alias; -use datafusion_expr::simplify::SimplifyContext; -use datafusion_expr::utils::{ - collect_subquery_cols, conjunction, find_join_exprs, split_conjunction, - split_disjunction, -}; -use datafusion_expr::{ - expr, lit, BinaryExpr, Cast, EmptyRelation, Expr, FetchType, LogicalPlan, - LogicalPlanBuilder, Operator, -}; -use datafusion_physical_expr::execution_props::ExecutionProps; +use datafusion_common::{internal_err, Column, Result}; +use datafusion_expr::{Expr, LogicalPlan}; +use indexmap::map::Entry; +use indexmap::IndexMap; -/// This struct rewrite the sub query plan by pull up the correlated -/// expressions(contains outer reference columns) from the inner subquery's -/// 'Filter'. It adds the inner reference columns to the 'Projection' or -/// 'Aggregate' of the subquery if they are missing, so that they can be -/// evaluated by the parent operator as the join condition. #[derive(Debug)] -pub struct GeneralPullUpCorrelatedExpr { - pub join_filters: Vec, - /// mapping from the plan to its holding correlated columns - pub correlated_subquery_cols_map: HashMap>, - pub in_predicate_opt: Option, - /// Is this an Exists(Not Exists) SubQuery. Defaults to **FALSE** - pub exists_sub_query: bool, - /// Can the correlated expressions be pulled up. Defaults to **TRUE** - pub can_pull_up: bool, - /// Indicates if we encounter any correlated expression that can not be pulled up - /// above a aggregation without changing the meaning of the query. 
- can_pull_over_aggregation: bool, - /// Do we need to handle [the Count bug] during the pull up process - /// - /// [the Count bug]: https://github.com/apache/datafusion/pull/10500 - pub need_handle_count_bug: bool, - /// mapping from the plan to its expressions' evaluation result on empty batch - pub collected_count_expr_map: HashMap, - /// pull up having expr, which must be evaluated after the Join - pub pull_up_having_expr: Option, +pub struct GeneralDecorrelation { + root: Option, + current_id: usize, + nodes: IndexMap, // column_ + stack: Vec, } -impl Default for GeneralPullUpCorrelatedExpr { +impl Default for GeneralDecorrelation { fn default() -> Self { - Self::new() + return GeneralDecorrelation { + root: None, + current_id: 0, + nodes: IndexMap::new(), + stack: vec![], + }; } } -impl GeneralPullUpCorrelatedExpr { - pub fn new() -> Self { - Self { - join_filters: vec![], - correlated_subquery_cols_map: HashMap::new(), - in_predicate_opt: None, - exists_sub_query: false, - can_pull_up: true, - can_pull_over_aggregation: true, - need_handle_count_bug: false, - collected_count_expr_map: HashMap::new(), - pull_up_having_expr: None, - } - } - - /// Set if we need to handle [the Count bug] during the pull up process - /// - /// [the Count bug]: https://github.com/apache/datafusion/pull/10500 - pub fn with_need_handle_count_bug(mut self, need_handle_count_bug: bool) -> Self { - self.need_handle_count_bug = need_handle_count_bug; - self - } - - /// Set the in_predicate_opt - pub fn with_in_predicate_opt(mut self, in_predicate_opt: Option) -> Self { - self.in_predicate_opt = in_predicate_opt; - self - } +#[derive(Debug)] +struct Operator { + id: usize, + plan: LogicalPlan, + parent: Option, + // children: Vec>>, + accesses: HashSet, + provides: HashSet, +} - /// Set if this is an Exists(Not Exists) SubQuery - pub fn with_exists_sub_query(mut self, exists_sub_query: bool) -> Self { - self.exists_sub_query = exists_sub_query; - self +impl GeneralDecorrelation { + fn build_algebra_index(&mut self, plan: LogicalPlan) -> Result<()> { + plan.visit(self)?; + Ok(()) } } -/// Used to indicate the unmatched rows from the inner(subquery) table after the left out Join -/// This is used to handle [the Count bug] -/// -/// [the Count bug]: https://github.com/apache/datafusion/pull/10500 -pub const UN_MATCHED_ROW_INDICATOR: &str = "__always_true"; - -/// Mapping from expr display name to its evaluation result on empty record -/// batch (for example: 'count(*)' is 'ScalarValue(0)', 'count(*) + 2' is -/// 'ScalarValue(2)') -pub type ExprResultMap = HashMap; - -impl TreeNodeRewriter for GeneralPullUpCorrelatedExpr { +impl TreeNodeVisitor<'_> for GeneralDecorrelation { type Node = LogicalPlan; - - fn f_down(&mut self, plan: LogicalPlan) -> Result> { - match plan { - LogicalPlan::Filter(_) => Ok(Transformed::no(plan)), - LogicalPlan::Union(_) | LogicalPlan::Sort(_) | LogicalPlan::Extension(_) => { - let plan_hold_outer = !plan.all_out_ref_exprs().is_empty(); - println!("plan hold outer and contains union"); - if plan_hold_outer { - // the unsupported case - self.can_pull_up = false; - Ok(Transformed::new(plan, false, TreeNodeRecursion::Jump)) - } else { - Ok(Transformed::no(plan)) - } - } - LogicalPlan::Limit(_) => { - let plan_hold_outer = !plan.all_out_ref_exprs().is_empty(); - match (self.exists_sub_query, plan_hold_outer) { - (false, true) => { - // the unsupported case - println!("plan has limit and no subquery found and plan hold outer ref"); - self.can_pull_up = false; - Ok(Transformed::new(plan, 
false, TreeNodeRecursion::Jump)) - } - _ => Ok(Transformed::no(plan)), - } - } - _ if plan.contains_outer_reference() => { - println!("plan contains outer reference, cannot pull up"); - // the unsupported cases, the plan expressions contain out reference columns(like window expressions) - self.can_pull_up = false; - Ok(Transformed::new(plan, false, TreeNodeRecursion::Jump)) - } - _ => Ok(Transformed::no(plan)), - } + fn f_down(&mut self, node: &LogicalPlan) -> Result { + self.stack.push(node.clone()); + println!("+++node {:?}", node); + // for each node, find which column it is accessing, which column it is providing + // Set of columns current node access + let (accesses, provides): (HashSet, HashSet) = match node { + LogicalPlan::Filter(f) => ( + HashSet::new(), + f.predicate + .column_refs() + .into_iter() + .map(|r| r.to_owned()) + .collect(), + ), + LogicalPlan::TableScan(tbl_scan) => { + let provided_columns: HashSet = + tbl_scan.projected_schema.columns().into_iter().collect(); + (provided_columns, HashSet::new()) + } + LogicalPlan::Aggregate(_) => (HashSet::new(), HashSet::new()), + LogicalPlan::EmptyRelation(_) => (HashSet::new(), HashSet::new()), + LogicalPlan::Limit(_) => (HashSet::new(), HashSet::new()), + LogicalPlan::Subquery(_) => (HashSet::new(), HashSet::new()), + _ => { + return internal_err!("impl scan for node type {:?}", node); + } + }; + + let parent = if self.stack.is_empty() { + None + } else { + Some(self.stack.last().unwrap().to_owned()) + }; + self.nodes.insert( + node.clone(), + Operator { + id: self.current_id, + parent, + plan: node.clone(), + accesses, + provides, + }, + ); + // let operator = match self.nodes.entry(node.clone()) { + // Entry::Occupied(entry) => entry.into_mut(), + // Entry::Vacant(entry) => { + // let parent = if self.stack.len() == 0 { + // None + // } else { + // Some(self.stack.last().unwrap().to_owned()) + // }; + // entry.insert(Operator { + // id: self.current_id, + // parent, + // plan: node.clone(), + // accesses, + // provides, + // }) + // } + // }; + + Ok(TreeNodeRecursion::Continue) } - fn f_up(&mut self, plan: LogicalPlan) -> Result> { - let subquery_schema = plan.schema(); - println!("XXXXXXXXXXXXX Plan type {}", plan.display()); - match &plan { - // TODO: what if this happen recursively? - // select * from exams e1 where grade > (select avg(grade) from exams as e2 where e1.sid = e2.sid - // and e2.curriculum=(select max(grade) from exams e3 group by curriculum)) - LogicalPlan::Filter(plan_filter) => { - let subquery_filter_exprs = split_conjunction(&plan_filter.predicate); - let or_filters = split_disjunction(&plan_filter.predicate); - or_filters.iter().for_each(|f| { - println!("or filter {}", f); - }); - self.can_pull_over_aggregation = self.can_pull_over_aggregation - && subquery_filter_exprs - .iter() - .filter(|e| e.contains_outer()) - .all(|&e| { - let ret = can_pullup_over_aggregation(e); - if !ret { - println!("can NOT pull up over aggregation {:?}", e); - } - ret - }); - - let (mut join_filters, subquery_filters) = - find_join_exprs(subquery_filter_exprs)?; - - if let Some(in_predicate) = &self.in_predicate_opt { - // in_predicate may be already included in the join filters, remove it from the join filters first. 
- join_filters = remove_duplicated_filter(join_filters, in_predicate); - } - println!("JOIN FILTERS"); - for expr in join_filters.iter() { - println!("{}", expr); - } - - // TODO: these cols only include the inner's table columns which is not sufficient - // in the case of complex unnest - // - // We need to collect all the columns in the outer table to construct the domain - // and join the domain with the inner table to prepare for the aggregation - let correlated_subquery_cols = - collect_subquery_cols(&join_filters, subquery_schema)?; - println!("CORRELATED COLUMS"); - for col in correlated_subquery_cols.iter() { - println!("{}", col); - } - // TODO: these join filters may need to be transformed, because now the join - // happen between the outer table columns and the newly built relation - for expr in join_filters { - if !self.join_filters.contains(&expr) { - self.join_filters.push(expr) - } - } - - let mut expr_result_map_for_count_bug = HashMap::new(); - let pull_up_expr_opt = if let Some(expr_result_map) = - self.collected_count_expr_map.get(plan_filter.input.deref()) - { - if let Some(expr) = conjunction(subquery_filters.clone()) { - filter_exprs_evaluation_result_on_empty_batch( - &expr, - Arc::clone(plan_filter.input.schema()), - expr_result_map, - &mut expr_result_map_for_count_bug, - )? - } else { - None - } - } else { - None - }; - - match (&pull_up_expr_opt, &self.pull_up_having_expr) { - (Some(_), Some(_)) => { - // Error path - plan_err!("Unsupported Subquery plan") - } - (Some(_), None) => { - self.pull_up_having_expr = pull_up_expr_opt; - let new_plan = - LogicalPlanBuilder::from((*plan_filter.input).clone()) - .build()?; - self.correlated_subquery_cols_map - .insert(new_plan.clone(), correlated_subquery_cols); - Ok(Transformed::yes(new_plan)) - } - (None, _) => { - // if the subquery still has filter expressions, restore them. - let mut plan = - LogicalPlanBuilder::from((*plan_filter.input).clone()); - if let Some(expr) = conjunction(subquery_filters) { - plan = plan.filter(expr)? - } - let new_plan = plan.build()?; - self.correlated_subquery_cols_map - .insert(new_plan.clone(), correlated_subquery_cols); - Ok(Transformed::yes(new_plan)) - } - } - } - LogicalPlan::Projection(projection) - if self.in_predicate_opt.is_some() || !self.join_filters.is_empty() => - { - let mut local_correlated_cols = BTreeSet::new(); - collect_local_correlated_cols( - &plan, - &self.correlated_subquery_cols_map, - &mut local_correlated_cols, - ); - // add missing columns to Projection - let mut missing_exprs = - self.collect_missing_exprs(&projection.expr, &local_correlated_cols)?; - - let mut expr_result_map_for_count_bug = HashMap::new(); - if let Some(expr_result_map) = - self.collected_count_expr_map.get(projection.input.deref()) - { - proj_exprs_evaluation_result_on_empty_batch( - &projection.expr, - projection.input.schema(), - expr_result_map, - &mut expr_result_map_for_count_bug, - )?; - if !expr_result_map_for_count_bug.is_empty() { - // has count bug - let un_matched_row = Expr::Column(Column::new_unqualified( - UN_MATCHED_ROW_INDICATOR.to_string(), - )); - // add the unmatched rows indicator to the Projection expressions - missing_exprs.push(un_matched_row); - } - } - - let new_plan = LogicalPlanBuilder::from((*projection.input).clone()) - .project(missing_exprs)? 
- .build()?; - if !expr_result_map_for_count_bug.is_empty() { - self.collected_count_expr_map - .insert(new_plan.clone(), expr_result_map_for_count_bug); - } - Ok(Transformed::yes(new_plan)) - } - LogicalPlan::Aggregate(aggregate) - if self.in_predicate_opt.is_some() || !self.join_filters.is_empty() => - { - // If the aggregation is from a distinct it will not change the result for - // exists/in subqueries so we can still pull up all predicates. - let is_distinct = aggregate.aggr_expr.is_empty(); - if !is_distinct { - println!( - "can pull up {:?} and can pull over aggregation {:?}", - self.can_pull_up, self.can_pull_over_aggregation - ); - self.can_pull_up = self.can_pull_up && self.can_pull_over_aggregation; - } - let mut local_correlated_cols = BTreeSet::new(); - collect_local_correlated_cols( - &plan, - &self.correlated_subquery_cols_map, - &mut local_correlated_cols, - ); - // add missing columns to Aggregation's group expressions - let mut missing_exprs = self.collect_missing_exprs( - &aggregate.group_expr, - &local_correlated_cols, - )?; - - // if the original group expressions are empty, need to handle the Count bug - let mut expr_result_map_for_count_bug = HashMap::new(); - if self.need_handle_count_bug - && aggregate.group_expr.is_empty() - && !missing_exprs.is_empty() - { - agg_exprs_evaluation_result_on_empty_batch( - &aggregate.aggr_expr, - aggregate.input.schema(), - &mut expr_result_map_for_count_bug, - )?; - if !expr_result_map_for_count_bug.is_empty() { - // has count bug - let un_matched_row = lit(true).alias(UN_MATCHED_ROW_INDICATOR); - // add the unmatched rows indicator to the Aggregation's group expressions - missing_exprs.push(un_matched_row); - } - } - let new_plan = LogicalPlanBuilder::from((*aggregate.input).clone()) - .aggregate(missing_exprs, aggregate.aggr_expr.to_vec())? - .build()?; - if !expr_result_map_for_count_bug.is_empty() { - self.collected_count_expr_map - .insert(new_plan.clone(), expr_result_map_for_count_bug); - } - Ok(Transformed::yes(new_plan)) - } - LogicalPlan::SubqueryAlias(alias) => { - let mut local_correlated_cols = BTreeSet::new(); - collect_local_correlated_cols( - &plan, - &self.correlated_subquery_cols_map, - &mut local_correlated_cols, - ); - let mut new_correlated_cols = BTreeSet::new(); - for col in local_correlated_cols.iter() { - new_correlated_cols - .insert(Column::new(Some(alias.alias.clone()), col.name.clone())); - } - self.correlated_subquery_cols_map - .insert(plan.clone(), new_correlated_cols); - if let Some(input_map) = - self.collected_count_expr_map.get(alias.input.deref()) - { - self.collected_count_expr_map - .insert(plan.clone(), input_map.clone()); - } - Ok(Transformed::no(plan)) - } - LogicalPlan::Limit(limit) => { - let input_expr_map = self - .collected_count_expr_map - .get(limit.input.deref()) - .cloned(); - // handling the limit clause in the subquery - let new_plan = match (self.exists_sub_query, self.join_filters.is_empty()) - { - // Correlated exist subquery, remove the limit(so that correlated expressions can pull up) - (true, false) => Transformed::yes(match limit.get_fetch_type()? 
{ - FetchType::Literal(Some(0)) => { - LogicalPlan::EmptyRelation(EmptyRelation { - produce_one_row: false, - schema: Arc::clone(limit.input.schema()), - }) - } - _ => LogicalPlanBuilder::from((*limit.input).clone()).build()?, - }), - _ => Transformed::no(plan), - }; - if let Some(input_map) = input_expr_map { - self.collected_count_expr_map - .insert(new_plan.data.clone(), input_map); - } - Ok(new_plan) - } - _ => Ok(Transformed::no(plan)), - } + /// Invoked while traversing up the tree after children are visited. Default + /// implementation continues the recursion. + fn f_up(&mut self, _node: &Self::Node) -> Result { + Ok(TreeNodeRecursion::Continue) } } -impl GeneralPullUpCorrelatedExpr { - fn collect_missing_exprs( +impl OptimizerRule for GeneralDecorrelation { + fn supports_rewrite(&self) -> bool { + true + } + fn rewrite( &self, - exprs: &[Expr], - correlated_subquery_cols: &BTreeSet, - ) -> Result> { - let mut missing_exprs = vec![]; - for expr in exprs { - if !missing_exprs.contains(expr) { - missing_exprs.push(expr.clone()) - } - } - for col in correlated_subquery_cols.iter() { - let col_expr = Expr::Column(col.clone()); - if !missing_exprs.contains(&col_expr) { - missing_exprs.push(col_expr) - } - } - if let Some(pull_up_having) = &self.pull_up_having_expr { - let filter_apply_columns = pull_up_having.column_refs(); - for col in filter_apply_columns { - // add to missing_exprs if not already there - let contains = missing_exprs - .iter() - .any(|expr| matches!(expr, Expr::Column(c) if c == col)); - if !contains { - missing_exprs.push(Expr::Column(col.clone())) - } - } - } - Ok(missing_exprs) + plan: LogicalPlan, + config: &dyn OptimizerConfig, + ) -> Result> { + internal_err!("todo") } -} -/// for now only simple exprs can be pulled up over aggregation -/// such as binaryExpr between a outer column ref vs non column expr -/// In the general unnesting framework, the complex expr is pulled up, but being decomposed in some way -/// for example: -/// select * from exams e1 where score > (select avg(score) from exams e2 where e1.student_id = e2.student_id -/// or (e2.year > e1.year and e2.subject=e1.subject)) -/// In this case, the complex expr to be pulled up is -/// ``` -/// e1.student_id=e1.student_id or (e2.year > e1.year and e2.subject=e1.subject) -/// ``` -/// The complex expr is decomposed during the pull up over aggregation avg(score) -/// into a new relation -fn can_pullup_over_aggregation(expr: &Expr) -> bool { - if let Expr::BinaryExpr(BinaryExpr { - left, - op: Operator::Eq, - right, - }) = expr - { - match (left.deref(), right.deref()) { - (Expr::Column(_), right) => !right.any_column_refs(), - (left, Expr::Column(_)) => !left.any_column_refs(), - (Expr::Cast(Cast { expr, .. }), right) - if matches!(expr.deref(), Expr::Column(_)) => - { - !right.any_column_refs() - } - (left, Expr::Cast(Cast { expr, .. 
})) - if matches!(expr.deref(), Expr::Column(_)) => - { - !left.any_column_refs() - } - (_, _) => false, - } - } else { - false + fn name(&self) -> &str { + "decorrelate_subquery" } -} -fn collect_local_correlated_cols( - plan: &LogicalPlan, - all_cols_map: &HashMap>, - local_cols: &mut BTreeSet, -) { - for child in plan.inputs() { - if let Some(cols) = all_cols_map.get(child) { - local_cols.extend(cols.clone()); - } - // SubqueryAlias is treated as the leaf node - if !matches!(child, LogicalPlan::SubqueryAlias(_)) { - collect_local_correlated_cols(child, all_cols_map, local_cols); - } + fn apply_order(&self) -> Option { + Some(ApplyOrder::TopDown) } } -fn remove_duplicated_filter(filters: Vec, in_predicate: &Expr) -> Vec { - filters - .into_iter() - .filter(|filter| { - if filter == in_predicate { - return false; - } - - // ignore the binary order - !match (filter, in_predicate) { - (Expr::BinaryExpr(a_expr), Expr::BinaryExpr(b_expr)) => { - (a_expr.op == b_expr.op) - && (a_expr.left == b_expr.left && a_expr.right == b_expr.right) - || (a_expr.left == b_expr.right && a_expr.right == b_expr.left) - } - _ => false, - } - }) - .collect::>() -} +#[cfg(test)] +mod tests { + use std::sync::Arc; -fn agg_exprs_evaluation_result_on_empty_batch( - agg_expr: &[Expr], - schema: &DFSchemaRef, - expr_result_map_for_count_bug: &mut ExprResultMap, -) -> Result<()> { - for e in agg_expr.iter() { - let result_expr = e - .clone() - .transform_up(|expr| { - let new_expr = match expr { - Expr::AggregateFunction(expr::AggregateFunction { func, .. }) => { - if func.name() == "count" { - Transformed::yes(Expr::Literal(ScalarValue::Int64(Some(0)))) - } else { - Transformed::yes(Expr::Literal(ScalarValue::Null)) - } - } - _ => Transformed::no(expr), - }; - Ok(new_expr) - }) - .data()?; + use datafusion_common::{DFSchema, Result}; + use datafusion_expr::{ + expr_fn::{self, col}, + lit, out_ref_col, scalar_subquery, table_scan, CreateMemoryTable, EmptyRelation, + Expr, LogicalPlan, LogicalPlanBuilder, + }; + use datafusion_functions_aggregate::sum::sum; + use regex_syntax::ast::LiteralKind; - let result_expr = result_expr.unalias(); - let props = ExecutionProps::new(); - let info = SimplifyContext::new(&props).with_schema(Arc::clone(schema)); - let simplifier = ExprSimplifier::new(info); - let result_expr = simplifier.simplify(result_expr)?; - if matches!(result_expr, Expr::Literal(ScalarValue::Int64(_))) { - expr_result_map_for_count_bug - .insert(e.schema_name().to_string(), result_expr); - } - } - Ok(()) -} + use crate::test::{test_table_scan, test_table_scan_with_name}; -fn proj_exprs_evaluation_result_on_empty_batch( - proj_expr: &[Expr], - schema: &DFSchemaRef, - input_expr_result_map_for_count_bug: &ExprResultMap, - expr_result_map_for_count_bug: &mut ExprResultMap, -) -> Result<()> { - for expr in proj_expr.iter() { - let result_expr = expr - .clone() - .transform_up(|expr| { - if let Expr::Column(Column { name, .. 
}) = &expr { - if let Some(result_expr) = - input_expr_result_map_for_count_bug.get(name) - { - Ok(Transformed::yes(result_expr.clone())) - } else { - Ok(Transformed::no(expr)) - } - } else { - Ok(Transformed::no(expr)) - } - }) - .data()?; + use super::GeneralDecorrelation; + use arrow::{ + array::{Int32Array, StringArray}, + datatypes::{DataType as ArrowDataType, Field, Fields, Schema}, + }; - if result_expr.ne(expr) { - let props = ExecutionProps::new(); - let info = SimplifyContext::new(&props).with_schema(Arc::clone(schema)); - let simplifier = ExprSimplifier::new(info); - let result_expr = simplifier.simplify(result_expr)?; - let expr_name = match expr { - Expr::Alias(Alias { name, .. }) => name.to_string(), - Expr::Column(Column { - relation: _, - name, - spans: _, - }) => name.to_string(), - _ => expr.schema_name().to_string(), - }; - expr_result_map_for_count_bug.insert(expr_name, result_expr); - } + #[test] + fn todo() -> Result<()> { + let mut a = GeneralDecorrelation::default(); + + let outer_table = test_table_scan_with_name("outer_table")?; + let inner_table = test_table_scan_with_name("inner_table")?; + let sq = Arc::new( + LogicalPlanBuilder::from(inner_table) + .filter( + col("inner_table.a") + .eq(out_ref_col(ArrowDataType::UInt64, "outer_table.a")), + )? + .aggregate(Vec::::new(), vec![sum(col("inner_table.b"))])? + .project(vec![sum(col("inner_table.b"))])? + .build()?, + ); + + let input1 = LogicalPlanBuilder::from(outer_table.clone()) + .filter(col("outer_table.a").gt(lit(1)))? + .filter(col("inner_table.b").gt(scalar_subquery(sq)))? + .build()?; + a.build_algebra_index(input1.clone())?; + + // let input2 = LogicalPlanBuilder::from(input.clone()) + // .filter(col("int_col").gt(lit(1)))? + // .project(vec![col("string_col")])? + // .build()?; + + // let mut b = GeneralDecorrelation::default(); + // b.build_algebra_index(input2)?; + + Ok(()) } - Ok(()) -} - -fn filter_exprs_evaluation_result_on_empty_batch( - filter_expr: &Expr, - schema: DFSchemaRef, - input_expr_result_map_for_count_bug: &ExprResultMap, - expr_result_map_for_count_bug: &mut ExprResultMap, -) -> Result> { - let result_expr = filter_expr - .clone() - .transform_up(|expr| { - if let Expr::Column(Column { name, .. 
}) = &expr { - if let Some(result_expr) = input_expr_result_map_for_count_bug.get(name) { - Ok(Transformed::yes(result_expr.clone())) - } else { - Ok(Transformed::no(expr)) - } - } else { - Ok(Transformed::no(expr)) - } - }) - .data()?; - - let pull_up_expr = if result_expr.ne(filter_expr) { - let props = ExecutionProps::new(); - let info = SimplifyContext::new(&props).with_schema(schema); - let simplifier = ExprSimplifier::new(info); - let result_expr = simplifier.simplify(result_expr)?; - match &result_expr { - // evaluate to false or null on empty batch, no need to pull up - Expr::Literal(ScalarValue::Null) - | Expr::Literal(ScalarValue::Boolean(Some(false))) => None, - // evaluate to true on empty batch, need to pull up the expr - Expr::Literal(ScalarValue::Boolean(Some(true))) => { - for (name, exprs) in input_expr_result_map_for_count_bug { - expr_result_map_for_count_bug.insert(name.clone(), exprs.clone()); - } - Some(filter_expr.clone()) - } - // can not evaluate statically - _ => { - for input_expr in input_expr_result_map_for_count_bug.values() { - let new_expr = Expr::Case(expr::Case { - expr: None, - when_then_expr: vec![( - Box::new(result_expr.clone()), - Box::new(input_expr.clone()), - )], - else_expr: Some(Box::new(Expr::Literal(ScalarValue::Null))), - }); - let expr_key = new_expr.schema_name().to_string(); - expr_result_map_for_count_bug.insert(expr_key, new_expr); - } - None - } - } - } else { - for (name, exprs) in input_expr_result_map_for_count_bug { - expr_result_map_for_count_bug.insert(name.clone(), exprs.clone()); - } - None - }; - Ok(pull_up_expr) } diff --git a/datafusion/optimizer/src/scalar_subquery_to_join.rs b/datafusion/optimizer/src/scalar_subquery_to_join.rs index c19fa0364585..e3a1cf93b653 100644 --- a/datafusion/optimizer/src/scalar_subquery_to_join.rs +++ b/datafusion/optimizer/src/scalar_subquery_to_join.rs @@ -21,7 +21,6 @@ use std::collections::{BTreeSet, HashMap}; use std::sync::Arc; use crate::decorrelate::{PullUpCorrelatedExpr, UN_MATCHED_ROW_INDICATOR}; -use crate::decorrelate_general::GeneralPullUpCorrelatedExpr; use crate::optimizer::ApplyOrder; use crate::utils::{evaluates_to_null, replace_qualified_name}; use crate::{OptimizerConfig, OptimizerRule}; @@ -29,7 +28,8 @@ use crate::{OptimizerConfig, OptimizerRule}; use crate::analyzer::type_coercion::TypeCoercionRewriter; use datafusion_common::alias::AliasGenerator; use datafusion_common::tree_node::{ - Transformed, TransformedResult, TreeNode, TreeNodeRecursion, TreeNodeRewriter, + Transformed, TransformedResult, TreeNode, TreeNodeContainer, TreeNodeRecursion, + TreeNodeRewriter, }; use datafusion_common::{internal_err, plan_err, Column, Result, ScalarValue}; use datafusion_expr::expr_rewriter::create_col_from_scalar_expr; @@ -74,7 +74,6 @@ impl OptimizerRule for ScalarSubqueryToJoin { fn supports_rewrite(&self) -> bool { true } - fn rewrite( &self, plan: LogicalPlan, @@ -88,6 +87,8 @@ impl OptimizerRule for ScalarSubqueryToJoin { return Ok(Transformed::no(LogicalPlan::Filter(filter))); } + // reWriteExpr is all the filter in the subquery that is irrelevant to the subquery execution + // i.e where outer=some col, or outer + binary operator with some aggregated value let (subqueries, mut rewrite_expr) = self.extract_subquery_exprs( &filter.predicate, config.alias_generator(), @@ -289,26 +290,25 @@ impl TreeNodeRewriter for ExtractScalarSubQuery<'_> { /// /// # Arguments /// -/// * `query_info` - The subquery portion of the `where` (select avg(total) from orders) +/// * `subquery` - The subquery 
portion of the `where` (select avg(total) from orders) /// * `filter_input` - The non-subquery portion (from customers) -/// * `outer_others` - Any additional parts to the `where` expression (and c.x = y) /// * `subquery_alias` - Subquery aliases +/// # Returns +/// * an optimize subquery if any +/// * a map of original count expr to a transformed expr (a hacky way to handle count bug) fn build_join( subquery: &Subquery, filter_input: &LogicalPlan, subquery_alias: &str, ) -> Result)>> { let subquery_plan = subquery.subquery.as_ref(); - let mut pull_up = GeneralPullUpCorrelatedExpr::new().with_need_handle_count_bug(true); + let mut pull_up = PullUpCorrelatedExpr::new().with_need_handle_count_bug(true); let new_plan = subquery_plan.clone().rewrite(&mut pull_up).data()?; if !pull_up.can_pull_up { return Ok(None); } - println!("before rewrite: {}", subquery_plan); - println!("ater rewrite: {}", new_plan); - let collected_count_expr_map = pull_up.collected_count_expr_map.get(&new_plan).cloned(); let sub_query_alias = LogicalPlanBuilder::from(new_plan) @@ -320,11 +320,6 @@ fn build_join( .correlated_subquery_cols_map .values() .for_each(|cols| all_correlated_cols.extend(cols.clone())); - println!("========\ncorrelated cols"); - for col in &all_correlated_cols { - println!("{}", col); - } - println!("===================="); // alias the join filter let join_filter_opt = @@ -353,7 +348,6 @@ fn build_join( } } } else { - println!("++++++++++++++++filter input: {}", filter_input); // left join if correlated, grouping by the join keys so we don't change row count LogicalPlanBuilder::from(filter_input.clone()) .join_on(sub_query_alias, JoinType::Left, join_filter_opt)? diff --git a/datafusion/sqllogictest/test_files/debug.slt b/datafusion/sqllogictest/test_files/debug.slt index 36bf75072759..d56f2a210d64 100644 --- a/datafusion/sqllogictest/test_files/debug.slt +++ b/datafusion/sqllogictest/test_files/debug.slt @@ -24,3 +24,38 @@ AS VALUES (3, 'math', 4, '2016-01-01T00:00:00'::timestamp) ; +## Multi-level correlated subquery +##query TT +##explain select * from exams e1 where grade > (select avg(grade) from exams as e2 where e1.sid = e2.sid +##and e2.curriculum=(select max(grade) from exams e3 group by curriculum)) +##---- + +# query TT +#explain select * from exams e1 where grade > (select avg(grade) from exams as e2 where e1.sid = e2.sid +# and e2.sid='some fixed value 1' +# or e2.sid='some fixed value 2' +#) +# ---- + + +## select * from exams e1, ( +## select avg(score) as avg_score, e2.sid, e2.year,e2.subject from exams e2 group by e2.sid,e2.year,e2.subject +## ) as pulled_up where e1.score > pulled_up.avg_score + +query TT +explain select s.name, ( + select count(e2.grade) as c from exams e2 + having c > 10 +) from students s +---- + +## query TT +## explain select s.name, e.curriculum from students s, exams e where s.id=e.sid +## and s.major='math' and 0 < ( +## select count(e2.grade) from exams e2 where s.id=e2.sid and e2.grade>0 +## having count(e2.grade) < 10 +## -- or (s.year1) from t1 +---- +logical_plan +01)Projection: t1.t1_id, __scalar_sq_1.cnt_plus_2 AS cnt_plus_2 +02)--Left Join: t1.t1_int = __scalar_sq_1.t2_int +03)----TableScan: t1 projection=[t1_id, t1_int] +04)----SubqueryAlias: __scalar_sq_1 +05)------Projection: count(Int64(1)) AS count(*) + Int64(2) AS cnt_plus_2, t2.t2_int +06)--------Filter: count(Int64(1)) > Int64(1) +07)----------Aggregate: groupBy=[[t2.t2_int]], aggr=[[count(Int64(1))]] +08)------------TableScan: t2 projection=[t2_int] + + +query TT +explain SELECT t1_id, 
(SELECT count(*) + 2 as cnt_plus_2 FROM t2 WHERE t2.t2_int = t1.t1_int having count(*) = 0) from t1 +---- +logical_plan +01)Projection: t1.t1_id, CASE WHEN __scalar_sq_1.__always_true IS NULL THEN Int64(2) WHEN __scalar_sq_1.count(Int64(1)) != Int64(0) THEN NULL ELSE __scalar_sq_1.cnt_plus_2 END AS cnt_plus_2 +02)--Left Join: t1.t1_int = __scalar_sq_1.t2_int +03)----TableScan: t1 projection=[t1_id, t1_int] +04)----SubqueryAlias: __scalar_sq_1 +05)------Projection: count(Int64(1)) + Int64(2) AS cnt_plus_2, t2.t2_int, count(Int64(1)), Boolean(true) AS __always_true +06)--------Aggregate: groupBy=[[t2.t2_int]], aggr=[[count(Int64(1))]] +07)----------TableScan: t2 projection=[t2_int] + +query TT +explain select t1.t1_int from t1 where (select cnt from (select count(*) as cnt, sum(t2_int) from t2 where t1.t1_int = t2.t2_int)) = 0 +---- +logical_plan +01)Projection: t1.t1_int +02)--Filter: CASE WHEN __scalar_sq_1.__always_true IS NULL THEN Int64(0) ELSE __scalar_sq_1.cnt END = Int64(0) +03)----Projection: t1.t1_int, __scalar_sq_1.cnt, __scalar_sq_1.__always_true +04)------Left Join: t1.t1_int = __scalar_sq_1.t2_int +05)--------TableScan: t1 projection=[t1_int] +06)--------SubqueryAlias: __scalar_sq_1 +07)----------Projection: count(Int64(1)) AS cnt, t2.t2_int, Boolean(true) AS __always_true +08)------------Aggregate: groupBy=[[t2.t2_int]], aggr=[[count(Int64(1))]] +09)--------------TableScan: t2 projection=[t2_int] + diff --git a/datafusion/sqllogictest/test_files/unsupported.slt b/datafusion/sqllogictest/test_files/unsupported.slt index 101a3ecd4442..b4c581d332e0 100644 --- a/datafusion/sqllogictest/test_files/unsupported.slt +++ b/datafusion/sqllogictest/test_files/unsupported.slt @@ -3,12 +3,12 @@ CREATE TABLE students( id int, name varchar, major varchar, - year int + year timestamp ) AS VALUES - (1,'toai','math',2014), - (2,'manh','math',2015), - (3,'bao','math',2025) + (1,'A','math','2014-01-01T00:00:00'::timestamp), + (2,'B','math','2015-01-01T00:00:00'::timestamp), + (3,'C','math','2016-01-01T00:00:00'::timestamp) ; statement ok @@ -16,12 +16,12 @@ CREATE TABLE exams( sid int, curriculum varchar, grade int, - date int + date timestamp ) AS VALUES - (1, 'math', 10, 2014), - (2, 'math', 9, 2015), - (3, 'math', 4, 2025) + (1, 'math', 10, '2014-01-01T00:00:00'::timestamp), + (2, 'math', 9, '2015-01-01T00:00:00'::timestamp), + (3, 'math', 4, '2016-01-01T00:00:00'::timestamp) ; -- explain select s.name, e.curriculum from students s, exams e where s.id=e.sid From ace332e16604ef400be3487020e98647666575ac Mon Sep 17 00:00:00 2001 From: Duong Cong Toai Date: Sun, 27 Apr 2025 15:24:18 +0200 Subject: [PATCH 05/70] chore: some work on indexed algebra --- datafusion/expr/src/expr.rs | 19 + .../optimizer/src/decorrelate_general.rs | 342 +++++++++++++++--- 2 files changed, 319 insertions(+), 42 deletions(-) diff --git a/datafusion/expr/src/expr.rs b/datafusion/expr/src/expr.rs index 9f6855b69824..f11fea405b00 100644 --- a/datafusion/expr/src/expr.rs +++ b/datafusion/expr/src/expr.rs @@ -1655,6 +1655,25 @@ impl Expr { using_columns } + pub fn outer_column_refs(&self) -> HashSet<&Column> { + let mut using_columns = HashSet::new(); + self.add_outer_column_refs(&mut using_columns); + using_columns + } + + /// Adds references to all outer columns in this expression to the set + /// + /// See [`Self::column_refs`] for details + pub fn add_outer_column_refs<'a>(&'a self, set: &mut HashSet<&'a Column>) { + self.apply(|expr| { + if let Expr::OuterReferenceColumn(_, col) = expr { + set.insert(col); + } + 
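// Usage sketch for the new helper (values are illustrative): for a correlated predicate
// built as `col("inner.a").eq(out_ref_col(DataType::Int32, "outer.a"))`,
// `outer_column_refs` returns {outer.a}, while the existing `column_refs` would return
// {inner.a}; only `Expr::OuterReferenceColumn` variants are collected here.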
Ok(TreeNodeRecursion::Continue) + }) + .expect("traversal is infallible"); + } + /// Adds references to all columns in this expression to the set /// /// See [`Self::column_refs`] for details diff --git a/datafusion/optimizer/src/decorrelate_general.rs b/datafusion/optimizer/src/decorrelate_general.rs index 42f7f09aae0d..3d001008d9b8 100644 --- a/datafusion/optimizer/src/decorrelate_general.rs +++ b/datafusion/optimizer/src/decorrelate_general.rs @@ -19,6 +19,7 @@ use std::cell::RefCell; use std::collections::{BTreeSet, HashSet}; +use std::fmt; use std::ops::Deref; use std::rc::{Rc, Weak}; use std::sync::Arc; @@ -32,16 +33,134 @@ use datafusion_common::tree_node::{ }; use datafusion_common::{internal_err, Column, Result}; use datafusion_expr::{Expr, LogicalPlan}; +use datafusion_sql::unparser::Unparser; use indexmap::map::Entry; use indexmap::IndexMap; +use log::Log; -#[derive(Debug)] pub struct GeneralDecorrelation { - root: Option, + root: Option, current_id: usize, nodes: IndexMap, // column_ + // TODO: use a different identifier for a node, instead of the whole logical plan obj stack: Vec, } +impl fmt::Debug for GeneralDecorrelation { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + writeln!(f, "GeneralDecorrelation Tree:")?; + if let Some(root_op) = &self.root { + self.fmt_operator(f, root_op, 0, false)?; + } else { + writeln!(f, " ")?; + } + Ok(()) + } +} + +impl GeneralDecorrelation { + fn fmt_operator( + &self, + f: &mut fmt::Formatter<'_>, + lp: &LogicalPlan, + indent: usize, + is_last: bool, + ) -> fmt::Result { + // Find the LogicalPlan corresponding to this Operator + let op = self.nodes.get(lp).unwrap(); + + for i in 0..indent { + if i + 1 == indent { + if is_last { + write!(f, " ")?; // if last child, no vertical line + } else { + write!(f, "| ")?; // vertical line continues + } + } else { + write!(f, "| ")?; + } + } + if indent > 0 { + write!(f, "|--- ")?; // branch + } + + let unparsed_sql = match Unparser::default().plan_to_sql(lp) { + Ok(str) => str.to_string(), + Err(_) => "".to_string(), + }; + writeln!(f, "\x1b[33m{}\x1b[0m", lp.display())?; + if !unparsed_sql.is_empty() { + for i in 0..=indent { + if i < indent { + write!(f, "| ")?; + } else if indent > 0 { + write!(f, "| ")?; // Align with LogicalPlan text + } + } + + writeln!(f, "{}", unparsed_sql)?; + } + + for i in 0..=indent { + if i < indent { + write!(f, "| ")?; + } else if indent > 0 { + write!(f, "| ")?; // Align with LogicalPlan text + } + } + + let access_string = op + .accesses + .iter() + .map(|c| c.debug()) + .collect::>() + .join(", "); + let provide_string = op + .provides + .iter() + .map(|c| c.debug()) + .collect::>() + .join(", "); + // Now print the Operator details + writeln!( + f, + "accesses: {}, provides: {}", + access_string, provide_string, + )?; + let len = op.children.len(); + + // Recursively print children if Operator has children + for (i, child) in op.children.iter().enumerate() { + let last = i + 1 == len; + + self.fmt_operator(f, child, indent + 1, last)?; + } + + Ok(()) + } + + fn update_ancestor_node_accesses(&mut self, col: &Column) { + // iter from bottom to top, the goal is to find the LCA only + for node in self.stack.iter().rev() { + let operator = self.nodes.get_mut(node).unwrap(); + let to_insert = ColumnUsage::Outer(col.clone()); + // This is the LCA between the current node and the outer column provider + if operator.accesses.contains(&to_insert) { + return; + } + operator.accesses.insert(to_insert); + } + } + fn build_algebra_index(&mut self, plan: 
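// Note on the traversal below: `LogicalPlan::visit` does not descend into subquery plans
// embedded in expressions (`Expr::ScalarSubquery`, `Expr::Exists`, `Expr::InSubquery`),
// so the index is built with `visit_with_subqueries`; that is also what makes the
// subquery side of a dependent join be visited before the outer input, an ordering the
// later access bookkeeping in this file relies on.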
LogicalPlan) -> Result<()> { + println!("======================================begin"); + plan.visit_with_subqueries(self)?; + println!("======================================end"); + Ok(()) + } + fn update_children(&mut self, parent: &LogicalPlan, child: &LogicalPlan) { + let operator = self.nodes.get_mut(parent).unwrap(); + operator.children.push(child.clone()); + } +} impl Default for GeneralDecorrelation { fn default() -> Self { @@ -54,58 +173,182 @@ impl Default for GeneralDecorrelation { } } +#[derive(Debug, Hash, PartialEq, PartialOrd, Eq, Clone)] +enum ColumnUsage { + Own(Column), + Outer(Column), +} +impl ColumnUsage { + fn debug(&self) -> String { + match self { + ColumnUsage::Own(col) => format!("\x1b[34m{}\x1b[0m", col.flat_name()), + ColumnUsage::Outer(col) => format!("\x1b[31m{}\x1b[0m", col.flat_name()), + } + } +} #[derive(Debug)] struct Operator { id: usize, plan: LogicalPlan, parent: Option, // children: Vec>>, - accesses: HashSet, - provides: HashSet, + // Note if the current node is a Subquery + // at the first time this node is visited, + // the set of accesses columns are not sufficient + // (i.e) some where deep down the ast another recursive subquery + // exists and also referencing some columns belongs to the outer part + // of the subquery + // Thus, on discovery of new subquery, we must + // add the accesses columns to the ancestor nodes which are Subquery + accesses: HashSet, + provides: HashSet, + + // for now only care about filter/projection with one of the expr is subquery + is_dependent_join_node: bool, + children: Vec, } -impl GeneralDecorrelation { - fn build_algebra_index(&mut self, plan: LogicalPlan) -> Result<()> { - plan.visit(self)?; - Ok(()) - } +fn contains_subquery(expr: &Expr) -> bool { + expr.exists(|expr| { + Ok(matches!( + expr, + Expr::ScalarSubquery(_) | Expr::InSubquery(_) | Expr::Exists(_) + )) + }) + .expect("Inner is always Ok") +} + +// struct ExtractScalarSubQuery<'a> { +// sub_query_info: Vec<(Subquery, String)>, +// in_sub_query_info: Vec<(InSubquery, String)>, +// alias_gen: &'a Arc, +// } + +// impl TreeNodeRewriter for ExtractScalarSubQuery<'_> { +// type Node = Expr; + +// fn f_down(&mut self, expr: Expr) -> Result> { +// match expr { +// Expr::InSubquery(in_subquery) => {} +// Expr::ScalarSubquery(subquery) => { +// let subqry_alias = self.alias_gen.next("__scalar_sq"); +// self.sub_query_info +// .push((subquery.clone(), subqry_alias.clone())); +// let scalar_expr = subquery +// .subquery +// .head_output_expr()? 
+// .map_or(plan_err!("single expression required."), Ok)?; +// Ok(Transformed::new( +// Expr::Column(create_col_from_scalar_expr( +// &scalar_expr, +// subqry_alias, +// )?), +// true, +// TreeNodeRecursion::Jump, +// )) +// } +// _ => Ok(Transformed::no(expr)), +// } +// } +// } + +fn print(a: &Expr) -> Result<()> { + let unparser = Unparser::default(); + let round_trip_sql = unparser.expr_to_sql(a)?.to_string(); + println!("{}", round_trip_sql); + Ok(()) } impl TreeNodeVisitor<'_> for GeneralDecorrelation { type Node = LogicalPlan; fn f_down(&mut self, node: &LogicalPlan) -> Result { - self.stack.push(node.clone()); - println!("+++node {:?}", node); + if self.root.is_none() { + self.root = Some(node.clone()); + } + let mut is_dependent_join_node = false; + println!("{}\nnode {}", "----".repeat(self.stack.len()), node); // for each node, find which column it is accessing, which column it is providing // Set of columns current node access - let (accesses, provides): (HashSet, HashSet) = match node { - LogicalPlan::Filter(f) => ( - HashSet::new(), - f.predicate - .column_refs() - .into_iter() - .map(|r| r.to_owned()) - .collect(), - ), - LogicalPlan::TableScan(tbl_scan) => { - let provided_columns: HashSet = - tbl_scan.projected_schema.columns().into_iter().collect(); - (provided_columns, HashSet::new()) - } - LogicalPlan::Aggregate(_) => (HashSet::new(), HashSet::new()), - LogicalPlan::EmptyRelation(_) => (HashSet::new(), HashSet::new()), - LogicalPlan::Limit(_) => (HashSet::new(), HashSet::new()), - LogicalPlan::Subquery(_) => (HashSet::new(), HashSet::new()), - _ => { - return internal_err!("impl scan for node type {:?}", node); - } - }; + let (accesses, provides): (HashSet, HashSet) = + match node { + LogicalPlan::Filter(f) => { + if contains_subquery(&f.predicate) { + is_dependent_join_node = true; + print(&f.predicate); + } + let mut outer_col_refs: HashSet = f + .predicate + .outer_column_refs() + .into_iter() + .map(|f| { + self.update_ancestor_node_accesses(f); + ColumnUsage::Outer(f.clone()) + }) + .collect(); + + outer_col_refs.extend( + f.predicate + .column_refs() + .into_iter() + .map(|f| ColumnUsage::Own(f.clone())), + ); + (outer_col_refs, HashSet::new()) + } + LogicalPlan::TableScan(tbl_scan) => { + let provided_columns: HashSet = tbl_scan + .projected_schema + .columns() + .into_iter() + .map(|col| ColumnUsage::Own(col)) + .collect(); + (HashSet::new(), provided_columns) + } + LogicalPlan::Aggregate(_) => (HashSet::new(), HashSet::new()), + LogicalPlan::EmptyRelation(_) => (HashSet::new(), HashSet::new()), + LogicalPlan::Limit(_) => (HashSet::new(), HashSet::new()), + // TODO + // 1.handle subquery inside projection + // 2.projection also provide some new columns + // 3.if within projection exists multiple subquery, how does this work + LogicalPlan::Projection(proj) => { + for expr in &proj.expr { + if contains_subquery(expr) { + is_dependent_join_node = true; + } + } + // proj.expr + // TODO: fix me + (HashSet::new(), HashSet::new()) + } + LogicalPlan::Subquery(subquery) => { + // TODO: once we detect the subquery + let accessed = subquery + .outer_ref_columns + .iter() + .filter_map(|f| match f { + Expr::Column(col) => Some(ColumnUsage::Outer(col.clone())), + Expr::OuterReferenceColumn(_, col) => { + Some(ColumnUsage::Outer(col.clone())) + } + _ => None, + }) + .collect(); + (accessed, HashSet::new()) + } + _ => { + return internal_err!("impl scan for node type {:?}", node); + } + }; let parent = if self.stack.is_empty() { None } else { + let previous_node = 
self.stack.last().unwrap().to_owned(); + self.update_children(&previous_node, node); Some(self.stack.last().unwrap().to_owned()) }; + + self.stack.push(node.clone()); self.nodes.insert( node.clone(), Operator { @@ -114,6 +357,8 @@ impl TreeNodeVisitor<'_> for GeneralDecorrelation { plan: node.clone(), accesses, provides, + is_dependent_join_node, + children: vec![], }, ); // let operator = match self.nodes.entry(node.clone()) { @@ -140,6 +385,7 @@ impl TreeNodeVisitor<'_> for GeneralDecorrelation { /// Invoked while traversing up the tree after children are visited. Default /// implementation continues the recursion. fn f_up(&mut self, _node: &Self::Node) -> Result { + self.stack.pop(); Ok(TreeNodeRecursion::Continue) } } @@ -175,7 +421,7 @@ mod tests { lit, out_ref_col, scalar_subquery, table_scan, CreateMemoryTable, EmptyRelation, Expr, LogicalPlan, LogicalPlanBuilder, }; - use datafusion_functions_aggregate::sum::sum; + use datafusion_functions_aggregate::{count::count, sum::sum}; use regex_syntax::ast::LiteralKind; use crate::test::{test_table_scan, test_table_scan_with_name}; @@ -191,23 +437,35 @@ mod tests { let mut a = GeneralDecorrelation::default(); let outer_table = test_table_scan_with_name("outer_table")?; - let inner_table = test_table_scan_with_name("inner_table")?; - let sq = Arc::new( - LogicalPlanBuilder::from(inner_table) + let inner_table_lv1 = test_table_scan_with_name("inner_table_lv1")?; + let inner_table_lv2 = test_table_scan_with_name("inner_table_lv2")?; + let sq_level2 = Arc::new( + LogicalPlanBuilder::from(inner_table_lv2) + .filter( + out_ref_col(ArrowDataType::UInt32, "inner_table_lv1.b") + .eq(col("inner_table_lv2.b")), + )? + .aggregate(Vec::::new(), vec![count(col("inner_table_lv2.a"))])? + .build()?, + ); + let sq_level1 = Arc::new( + LogicalPlanBuilder::from(inner_table_lv1) .filter( - col("inner_table.a") - .eq(out_ref_col(ArrowDataType::UInt64, "outer_table.a")), + col("inner_table_lv1.a") + .eq(out_ref_col(ArrowDataType::UInt32, "outer_table.a")), )? - .aggregate(Vec::::new(), vec![sum(col("inner_table.b"))])? - .project(vec![sum(col("inner_table.b"))])? + .filter(scalar_subquery(sq_level2).gt(lit(5)))? + .aggregate(Vec::::new(), vec![sum(col("inner_table_lv1.b"))])? + .project(vec![sum(col("inner_table_lv1.b"))])? .build()?, ); let input1 = LogicalPlanBuilder::from(outer_table.clone()) .filter(col("outer_table.a").gt(lit(1)))? - .filter(col("inner_table.b").gt(scalar_subquery(sq)))? + .filter(col("outer_table.b").gt(scalar_subquery(sq_level1)))? .build()?; a.build_algebra_index(input1.clone())?; + println!("{:?}", a); // let input2 = LogicalPlanBuilder::from(input.clone()) // .filter(col("int_col").gt(lit(1)))? 
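// For orientation, the plan assembled in this test corresponds roughly to the SQL below
// (fixture names come from `test_table_scan_with_name`; the SQL is only an illustration):
//
//   SELECT * FROM outer_table
//   WHERE outer_table.a > 1
//     AND outer_table.b > (
//           SELECT sum(inner_table_lv1.b) FROM inner_table_lv1
//           WHERE inner_table_lv1.a = outer_table.a
//             AND (SELECT count(inner_table_lv2.a) FROM inner_table_lv2
//                  WHERE inner_table_lv2.b = inner_table_lv1.b) > 5)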
From da8980c445ed6bbe23d7a593815cfffc154993fa Mon Sep 17 00:00:00 2001 From: Duong Cong Toai Date: Sun, 4 May 2025 09:11:25 +0200 Subject: [PATCH 06/70] chore: more progress --- .../optimizer/src/decorrelate_general.rs | 794 +++++++++++++++--- 1 file changed, 672 insertions(+), 122 deletions(-) diff --git a/datafusion/optimizer/src/decorrelate_general.rs b/datafusion/optimizer/src/decorrelate_general.rs index 3d001008d9b8..c1a0050c702d 100644 --- a/datafusion/optimizer/src/decorrelate_general.rs +++ b/datafusion/optimizer/src/decorrelate_general.rs @@ -31,25 +31,407 @@ use datafusion_common::tree_node::{ Transformed, TransformedResult, TreeNode, TreeNodeContainer, TreeNodeRecursion, TreeNodeRewriter, TreeNodeVisitor, }; -use datafusion_common::{internal_err, Column, Result}; -use datafusion_expr::{Expr, LogicalPlan}; +use datafusion_common::{internal_err, not_impl_err, Column, Result}; +use datafusion_expr::{binary_expr, Expr, JoinType, LogicalPlan}; use datafusion_sql::unparser::Unparser; use indexmap::map::Entry; use indexmap::IndexMap; +use itertools::Itertools; use log::Log; -pub struct GeneralDecorrelation { - root: Option, +pub struct AlgebraIndex { + root: Option, current_id: usize, - nodes: IndexMap, // column_ + nodes: IndexMap, // column_ // TODO: use a different identifier for a node, instead of the whole logical plan obj - stack: Vec, + stack: Vec, + accessed_columns: IndexMap>, } -impl fmt::Debug for GeneralDecorrelation { + +#[derive(Debug, Hash, PartialEq, PartialOrd, Eq, Clone)] +struct ColumnAccess { + stack: Vec, + node_id: usize, + col: Column, +} +// pub struct GeneralDecorrelation { +// index: AlgebraIndex, +// } + +// data structure to store equivalent columns +// Expr is used to represent either own column or outer referencing columns +#[derive(Clone)] +pub struct UnionFind { + parent: IndexMap, + rank: IndexMap, +} + +impl UnionFind { + pub fn new() -> Self { + Self { + parent: IndexMap::new(), + rank: IndexMap::new(), + } + } + + pub fn find(&mut self, x: Expr) -> Expr { + let p = self.parent.get(&x).cloned(); + match p { + None => { + self.parent.insert(x.clone(), x.clone()); + self.rank.insert(x.clone(), 0); + x + } + Some(parent) => { + if parent == x { + x + } else { + let root = self.find(parent.clone()); + self.parent.insert(x, root.clone()); + root + } + } + } + } + + pub fn union(&mut self, x: Expr, y: Expr) -> bool { + let root_x = self.find(x.clone()); + let root_y = self.find(y.clone()); + if root_x == root_y { + return false; + } + + let rank_x = *self.rank.get(&root_x).unwrap_or(&0); + let rank_y = *self.rank.get(&root_y).unwrap_or(&0); + + if rank_x < rank_y { + self.parent.insert(root_x, root_y); + } else if rank_x > rank_y { + self.parent.insert(root_y, root_x); + } else { + // asign y as children of x + self.parent.insert(root_y.clone(), root_x.clone()); + *self.rank.entry(root_x).or_insert(0) += 1; + } + + true + } +} +// TODO: impl me +#[derive(Clone)] +struct DependentJoin { + // + original_expr: LogicalPlan, + left: Operator, + right: Operator, + // TODO: combine into one Expr + join_conditions: Vec, + // join_type: +} +impl DependentJoin { + fn replace_right( + &mut self, + plan: LogicalPlan, + unnesting: &UnnestingInfo, + replacements: &IndexMap, + ) { + self.right.plan = plan; + for col in unnesting.outer_refs.iter() { + let replacement = replacements.get(col).unwrap(); + self.join_conditions.push(binary_expr( + Expr::Column(col.clone()), + datafusion_expr::Operator::IsNotDistinctFrom, + Expr::Column(replacement.clone()), + )); + } + } + fn 
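// Usage sketch for the `UnionFind` equivalence classes above (assuming `col` from
// `datafusion_expr`): a correlated predicate such as `outer.c = inner.a` is recorded as
//     classes.union(col("outer.c"), col("inner.a"));
//     assert_eq!(classes.find(col("outer.c")), classes.find(col("inner.a")));
// so a later rewrite can substitute the outer reference with the inner column known to
// be equal to it, as the `Unnesting::replaces` comment above describes.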
replace_left( + &mut self, + plan: LogicalPlan, + column_replacements: &IndexMap, + ) { + self.left.plan = plan + // TODO: + // - update join condition + // - check if the relation with children should be removed + } +} + +#[derive(Clone)] +struct UnnestingInfo { + join: DependentJoin, + outer_refs: Vec, + domain: Vec, + parent: Option, +} +#[derive(Clone)] +struct Unnesting { + info: Arc, // cclasses: union find data structure of equivalent columns + equivalences: UnionFind, + replaces: IndexMap, + // mapping from outer ref column to new column, if any + // i.e in some subquery ( + // ... where outer.column_c=inner.column_a + // ) + // and through union find we have outer.column_c = some_other_expr + // we can substitute the inner query with inner.column_a=some_other_expr +} + +// impl Default for GeneralDecorrelation { +// fn default() -> Self { +// return GeneralDecorrelation { +// index: AlgebraIndex::default(), +// }; +// } +// } +impl AlgebraIndex { + fn is_linear_operator(&self, plan: &LogicalPlan) -> bool { + match plan { + LogicalPlan::Limit(_) => true, + LogicalPlan::TableScan(_) => true, + LogicalPlan::Projection(_) => true, + LogicalPlan::Filter(_) => true, + LogicalPlan::Repartition(_) => true, + _ => false, + } + } + fn is_linear_path(&self, parent: &usize, child: &usize) -> bool { + let mut current_node = *child; + + loop { + let child_node = self.nodes.get(¤t_node).unwrap(); + if !self.is_linear_operator(&child_node.plan) { + return false; + } + if current_node == *parent { + return true; + } + match child_node.parent { + None => return true, + Some(new_parent) => { + if new_parent == *parent { + return true; + } + current_node = new_parent; + } + }; + } + } + // decorrelate all children with simple unnesting + // returns true if all children were eliminated + // TODO(impl me) + fn try_decorrelate_child(&self, root: &usize, child: &usize) -> Result { + if !self.is_linear_path(root, child) { + return Ok(false); + } + let child_node = self.nodes.get(child).unwrap(); + let root_node = self.nodes.get(root).unwrap(); + match &child_node.plan { + LogicalPlan::Projection(proj) => {} + LogicalPlan::Filter(filter) => { + let accessed_from_child = &child_node.access_tracker; + for col_access in accessed_from_child { + println!( + "checking if col {} can be merged into parent's join filter {}", + col_access.debug(), + root_node.plan + ) + } + } + _ => {} + } + Ok(false) + } + + fn unnest( + &mut self, + node_id: usize, + unnesting: &mut Unnesting, + outer_refs_from_parent: HashSet, + ) -> Result { + unimplemented!() + // if unnesting.info.parent.is_some() { + // not_impl_err!("impl me") + // // TODO + // } + // // info = Un + // let node = self.nodes.get(node_id).unwrap(); + // match node.plan { + // LogicalPlan::Aggregate(aggr) => {} + // _ => {} + // } + // Ok(()) + } + fn right(&self, node: &Operator) -> &Operator { + assert_eq!(2, node.children.len()); + // during the building of the tree, the subquery (right node) is always traversed first + let node_id = node.children.get(0).unwrap(); + return self.nodes.get(node_id).unwrap(); + } + fn left(&self, node: &Operator) -> &Operator { + assert_eq!(2, node.children.len()); + // during the building of the tree, the subquery (right node) is always traversed first + let node_id = node.children.get(1).unwrap(); + return self.nodes.get(node_id).unwrap(); + } + fn root_dependent_join_elimination(&mut self) -> Result { + let root = self.root.unwrap(); + let node = self.nodes.get(&root).unwrap(); + // TODO: need to store the first dependent 
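// Reading aid for `left`/`right` above: because the tree is built with
// `visit_with_subqueries`, the Subquery expression of a dependent join node is visited
// before the plan's own input, so `children[0]` is the dependent (subquery) side and
// `children[1]` is the outer side; that is why `right()` reads index 0 and `left()`
// reads index 1.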
join node + assert!( + node.is_dependent_join_node, + "need to handle the case root node is not dependent join node" + ); + let unnesting_info = UnnestingInfo { + parent: None, + join: DependentJoin { + original_expr: node.plan.clone(), + left: self.left(node).clone(), + right: self.right(node).clone(), + join_conditions: vec![], + }, + domain: vec![], + outer_refs: vec![], + }; + // let unnesting = Unnesting { + // info: Arc::new(unnesting), + // equivalences: UnionFind::new(), + // replaces: IndexMap::new(), + // }; + + self.dependent_join_elimination(node.id, &unnesting_info, HashSet::new()) + } + + fn column_accesses(&self, node_id: usize) -> Vec<&ColumnAccess> { + let node = self.nodes.get(&node_id).unwrap(); + node.access_tracker.iter().collect() + } + fn new_dependent_join(&self, node: &Operator) -> DependentJoin { + DependentJoin { + original_expr: node.plan.clone(), + left: self.left(node).clone(), + right: self.left(node).clone(), + join_conditions: vec![], + } + } + + fn dependent_join_elimination( + &mut self, + node: usize, + unnesting: &UnnestingInfo, + outer_refs_from_parent: HashSet, + ) -> Result { + let parent = unnesting.parent.clone(); + let operator = self.nodes.get(&node).unwrap(); + let plan = &operator.plan; + let mut join = self.new_dependent_join(operator); + // we have to do the reversed iter, because we know the subquery (right side of + // the dependent join) is always the first child of the node, and we want to visit + // the left side first + + let (dependent_join, finished) = self.simple_decorrelation(node)?; + if finished { + if parent.is_some() { + // for each projection of outer column moved up by simple_decorrelation + // replace them with the expr store inside parent.replaces + unimplemented!(""); + return self.unnest(node, &mut parent.unwrap(), outer_refs_from_parent); + } + return Ok(dependent_join); + } + if parent.is_some() { + // i.e exists (where inner.col_a = outer_col.b and other_nested_subquery...) 
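// A sketch of the shape this elimination aims for (one possible rewrite, modelled on the
// manual decorrelation example in unsupported.slt; names are illustrative):
//
//   before:  SELECT * FROM outer_table o
//            WHERE o.b > (SELECT sum(i.b) FROM inner_table_lv1 i WHERE i.a = o.a)
//
//   after:   SELECT * FROM outer_table o
//            LEFT JOIN (SELECT sum(i.b) AS s, d.a
//                       FROM (SELECT DISTINCT a FROM outer_table) d
//                       JOIN inner_table_lv1 i ON i.a = d.a
//                       GROUP BY d.a) sq
//              ON o.a IS NOT DISTINCT FROM sq.a
//            WHERE o.b > sq.s
//
// Here the `domain` field above plays the role of the DISTINCT outer columns (d), and
// `replace_right` adds the `IS NOT DISTINCT FROM` conditions that join the rewritten
// subquery back to that domain.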
+ + let mut outer_ref_from_left = HashSet::new(); + let left = join.left.clone(); + for col_from_parent in outer_refs_from_parent.iter() { + if left + .plan + .all_out_ref_exprs() + .contains(&Expr::Column(col_from_parent.clone())) + { + outer_ref_from_left.insert(col_from_parent.clone()); + } + } + let mut parent_unnesting = parent.clone().unwrap(); + let new_left = + self.unnest(left.id, &mut parent_unnesting, outer_ref_from_left)?; + join.replace_left(new_left, &parent_unnesting.replaces); + + // TODO: after imple simple_decorrelation, rewrite the projection pushed up column as well + } + let new_unnesting_info = UnnestingInfo { + parent: parent.clone(), + join: join.clone(), + domain: vec![], // TODO: populate me + outer_refs: vec![], // TODO: populate me + }; + let mut unnesting = Unnesting { + info: Arc::new(new_unnesting_info.clone()), + equivalences: UnionFind { + parent: IndexMap::new(), + rank: IndexMap::new(), + }, + replaces: IndexMap::new(), + }; + let mut accesses: HashSet = self + .column_accesses(node) + .iter() + .map(|a| a.col.clone()) + .collect(); + if parent.is_some() { + for col_access in outer_refs_from_parent { + if join + .right + .plan + .all_out_ref_exprs() + .contains(&Expr::Column(col_access.clone())) + { + accesses.insert(col_access.clone()); + } + } + // add equivalences from join.condition to unnest.cclasses + } + + let new_right = self.unnest(join.right.id, &mut unnesting, accesses)?; + join.replace_right(new_right, &new_unnesting_info, &unnesting.replaces); + // for acc in new_unnesting_info.outer_refs{ + // join.join_conditions.append(other); + // } + + unimplemented!() + } + fn rewrite_columns(expr: Expr, unnesting: Unnesting) { + unimplemented!() + // expr.apply(|expr| { + // if let Expr::OuterReferenceColumn(_, col) = expr { + // set.insert(col); + // } + // Ok(TreeNodeRecursion::Continue) + // }) + // .expect("traversal is infallible"); + } + + fn simple_decorrelation(&mut self, node: usize) -> Result<(LogicalPlan, bool)> { + let node = self.nodes.get(&node).unwrap(); + let mut all_eliminated = false; + for child in node.children.iter() { + let branch_all_eliminated = self.try_decorrelate_child(child, child)?; + all_eliminated = all_eliminated || branch_all_eliminated; + } + Ok((node.plan.clone(), false)) + } + fn build(&mut self, root: &LogicalPlan) -> Result<()> { + self.build_algebra_index(root.clone())?; + println!("{:?}", self); + Ok(()) + } +} +impl fmt::Debug for AlgebraIndex { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { writeln!(f, "GeneralDecorrelation Tree:")?; if let Some(root_op) = &self.root { - self.fmt_operator(f, root_op, 0, false)?; + self.fmt_operator(f, *root_op, 0, false)?; } else { writeln!(f, " ")?; } @@ -57,16 +439,17 @@ impl fmt::Debug for GeneralDecorrelation { } } -impl GeneralDecorrelation { +impl AlgebraIndex { fn fmt_operator( &self, f: &mut fmt::Formatter<'_>, - lp: &LogicalPlan, + node_id: usize, indent: usize, is_last: bool, ) -> fmt::Result { // Find the LogicalPlan corresponding to this Operator - let op = self.nodes.get(lp).unwrap(); + let op = self.nodes.get(&node_id).unwrap(); + let lp = &op.plan; for i in 0..indent { if i + 1 == indent { @@ -87,7 +470,7 @@ impl GeneralDecorrelation { Ok(str) => str.to_string(), Err(_) => "".to_string(), }; - writeln!(f, "\x1b[33m{}\x1b[0m", lp.display())?; + writeln!(f, "\x1b[33m [{}] {}\x1b[0m", node_id, lp.display())?; if !unparsed_sql.is_empty() { for i in 0..=indent { if i < indent { @@ -108,14 +491,14 @@ impl GeneralDecorrelation { } } - let access_string = op 
- .accesses + let accessing_string = op + .potential_accesses .iter() .map(|c| c.debug()) .collect::>() .join(", "); - let provide_string = op - .provides + let accessed_by_string = op + .access_tracker .iter() .map(|c| c.debug()) .collect::>() @@ -123,8 +506,8 @@ impl GeneralDecorrelation { // Now print the Operator details writeln!( f, - "accesses: {}, provides: {}", - access_string, provide_string, + "acccessing: {}, accessed_by: {}", + accessing_string, accessed_by_string, )?; let len = op.children.len(); @@ -132,43 +515,86 @@ impl GeneralDecorrelation { for (i, child) in op.children.iter().enumerate() { let last = i + 1 == len; - self.fmt_operator(f, child, indent + 1, last)?; + self.fmt_operator(f, *child, indent + 1, last)?; } Ok(()) } - fn update_ancestor_node_accesses(&mut self, col: &Column) { - // iter from bottom to top, the goal is to find the LCA only - for node in self.stack.iter().rev() { - let operator = self.nodes.get_mut(node).unwrap(); - let to_insert = ColumnUsage::Outer(col.clone()); - // This is the LCA between the current node and the outer column provider - if operator.accesses.contains(&to_insert) { - return; + fn lca_from_stack(a: &[usize], b: &[usize]) -> usize { + let mut lca = None; + + let min_len = a.len().min(b.len()); + + for i in 0..min_len { + let ai = a[i]; + let bi = b[i]; + + if ai == bi { + lca = Some(ai); + } else { + break; + } + } + + lca.unwrap() + } + + // because the column providers are visited after column-accessor + // function visit_with_subqueries always visit the subquery before visiting the other child + // we can always infer the LCA inside this function, by getting the deepest common parent + fn conclude_lca_for_column(&mut self, child_id: usize, col: &Column) { + if let Some(accesses) = self.accessed_columns.get(col) { + for access in accesses.iter() { + let mut cur_stack = self.stack.clone(); + cur_stack.push(child_id); + // this is a dependen join node + let lca_node = Self::lca_from_stack(&cur_stack, &access.stack); + let node = self.nodes.get_mut(&lca_node).unwrap(); + node.access_tracker.insert(ColumnAccess { + col: col.clone(), + node_id: access.node_id, + stack: access.stack.clone(), + }); } - operator.accesses.insert(to_insert); } } + + fn mark_column_access(&mut self, child_id: usize, col: &Column) { + // iter from bottom to top, the goal is to mark the independen_join node + // the current child's access + let mut stack = self.stack.clone(); + stack.push(child_id); + self.accessed_columns + .entry(col.clone()) + .or_default() + .push(ColumnAccess { + stack, + node_id: child_id, + col: col.clone(), + }); + } fn build_algebra_index(&mut self, plan: LogicalPlan) -> Result<()> { println!("======================================begin"); + // let mut index = AlgebraIndex::default(); plan.visit_with_subqueries(self)?; println!("======================================end"); Ok(()) } - fn update_children(&mut self, parent: &LogicalPlan, child: &LogicalPlan) { - let operator = self.nodes.get_mut(parent).unwrap(); - operator.children.push(child.clone()); + fn create_child_relationship(&mut self, parent: usize, child: usize) { + let operator = self.nodes.get_mut(&parent).unwrap(); + operator.children.push(child); } } -impl Default for GeneralDecorrelation { +impl Default for AlgebraIndex { fn default() -> Self { - return GeneralDecorrelation { + return AlgebraIndex { root: None, current_id: 0, nodes: IndexMap::new(), stack: vec![], + accessed_columns: IndexMap::new(), }; } } @@ -186,12 +612,16 @@ impl ColumnUsage { } } } -#[derive(Debug)] 
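// Usage sketch for `lca_from_stack` above (node ids are illustrative): stacks record the
// path of node ids from the root, so an outer-column access recorded under [1, 2, 5, 7]
// and its provider visited under [1, 2, 3] share node 2 as their lowest common ancestor,
// and node 2 is the dependent join that receives the `access_tracker` entry.
fn lca_sketch() {
    assert_eq!(AlgebraIndex::lca_from_stack(&[1, 2, 5, 7], &[1, 2, 3]), 2);
}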
+impl ColumnAccess { + fn debug(&self) -> String { + format!("\x1b[31m{} ({})\x1b[0m", self.node_id, self.col) + } +} +#[derive(Debug, Clone)] struct Operator { id: usize, plan: LogicalPlan, - parent: Option, - // children: Vec>>, + parent: Option, // Note if the current node is a Subquery // at the first time this node is visited, // the set of accesses columns are not sufficient @@ -200,12 +630,26 @@ struct Operator { // of the subquery // Thus, on discovery of new subquery, we must // add the accesses columns to the ancestor nodes which are Subquery - accesses: HashSet, + potential_accesses: HashSet, provides: HashSet, - // for now only care about filter/projection with one of the expr is subquery + // This field is only set if the node is dependent join node + // it track which child still accessing which column of + access_tracker: HashSet, + is_dependent_join_node: bool, - children: Vec, + is_subquery_node: bool, + children: Vec, +} +impl Operator { + // fn to_dependent_join(&self) -> DependentJoin { + // DependentJoin { + // original_expr: self.plan.clone(), + // left: self.left(), + // right: self.right(), + // join_conditions: vec![], + // } + // } } fn contains_subquery(expr: &Expr) -> bool { @@ -218,40 +662,6 @@ fn contains_subquery(expr: &Expr) -> bool { .expect("Inner is always Ok") } -// struct ExtractScalarSubQuery<'a> { -// sub_query_info: Vec<(Subquery, String)>, -// in_sub_query_info: Vec<(InSubquery, String)>, -// alias_gen: &'a Arc, -// } - -// impl TreeNodeRewriter for ExtractScalarSubQuery<'_> { -// type Node = Expr; - -// fn f_down(&mut self, expr: Expr) -> Result> { -// match expr { -// Expr::InSubquery(in_subquery) => {} -// Expr::ScalarSubquery(subquery) => { -// let subqry_alias = self.alias_gen.next("__scalar_sq"); -// self.sub_query_info -// .push((subquery.clone(), subqry_alias.clone())); -// let scalar_expr = subquery -// .subquery -// .head_output_expr()? 
-// .map_or(plan_err!("single expression required."), Ok)?; -// Ok(Transformed::new( -// Expr::Column(create_col_from_scalar_expr( -// &scalar_expr, -// subqry_alias, -// )?), -// true, -// TreeNodeRecursion::Jump, -// )) -// } -// _ => Ok(Transformed::no(expr)), -// } -// } -// } - fn print(a: &Expr) -> Result<()> { let unparser = Unparser::default(); let round_trip_sql = unparser.expr_to_sql(a)?.to_string(); @@ -259,12 +669,14 @@ fn print(a: &Expr) -> Result<()> { Ok(()) } -impl TreeNodeVisitor<'_> for GeneralDecorrelation { +impl TreeNodeVisitor<'_> for AlgebraIndex { type Node = LogicalPlan; fn f_down(&mut self, node: &LogicalPlan) -> Result { + self.current_id += 1; if self.root.is_none() { - self.root = Some(node.clone()); + self.root = Some(self.current_id); } + let mut is_subquery_node = false; let mut is_dependent_join_node = false; println!("{}\nnode {}", "----".repeat(self.stack.len()), node); // for each node, find which column it is accessing, which column it is providing @@ -274,24 +686,17 @@ impl TreeNodeVisitor<'_> for GeneralDecorrelation { LogicalPlan::Filter(f) => { if contains_subquery(&f.predicate) { is_dependent_join_node = true; - print(&f.predicate); } let mut outer_col_refs: HashSet = f .predicate .outer_column_refs() .into_iter() .map(|f| { - self.update_ancestor_node_accesses(f); + self.mark_column_access(self.current_id, f); ColumnUsage::Outer(f.clone()) }) .collect(); - outer_col_refs.extend( - f.predicate - .column_refs() - .into_iter() - .map(|f| ColumnUsage::Own(f.clone())), - ); (outer_col_refs, HashSet::new()) } LogicalPlan::TableScan(tbl_scan) => { @@ -299,7 +704,10 @@ impl TreeNodeVisitor<'_> for GeneralDecorrelation { .projected_schema .columns() .into_iter() - .map(|col| ColumnUsage::Own(col)) + .map(|col| { + self.conclude_lca_for_column(self.current_id, &col); + ColumnUsage::Own(col) + }) .collect(); (HashSet::new(), provided_columns) } @@ -314,6 +722,7 @@ impl TreeNodeVisitor<'_> for GeneralDecorrelation { for expr in &proj.expr { if contains_subquery(expr) { is_dependent_join_node = true; + break; } } // proj.expr @@ -321,6 +730,7 @@ impl TreeNodeVisitor<'_> for GeneralDecorrelation { (HashSet::new(), HashSet::new()) } LogicalPlan::Subquery(subquery) => { + is_subquery_node = true; // TODO: once we detect the subquery let accessed = subquery .outer_ref_columns @@ -344,40 +754,25 @@ impl TreeNodeVisitor<'_> for GeneralDecorrelation { None } else { let previous_node = self.stack.last().unwrap().to_owned(); - self.update_children(&previous_node, node); + self.create_child_relationship(previous_node, self.current_id); Some(self.stack.last().unwrap().to_owned()) }; - self.stack.push(node.clone()); + self.stack.push(self.current_id); self.nodes.insert( - node.clone(), + self.current_id, Operator { id: self.current_id, parent, plan: node.clone(), - accesses, + potential_accesses: accesses, provides, + is_subquery_node, is_dependent_join_node, children: vec![], + access_tracker: HashSet::new(), }, ); - // let operator = match self.nodes.entry(node.clone()) { - // Entry::Occupied(entry) => entry.into_mut(), - // Entry::Vacant(entry) => { - // let parent = if self.stack.len() == 0 { - // None - // } else { - // Some(self.stack.last().unwrap().to_owned()) - // }; - // entry.insert(Operator { - // id: self.current_id, - // parent, - // plan: node.clone(), - // accesses, - // provides, - // }) - // } - // }; Ok(TreeNodeRecursion::Continue) } @@ -390,7 +785,7 @@ impl TreeNodeVisitor<'_> for GeneralDecorrelation { } } -impl OptimizerRule for GeneralDecorrelation 
{ +impl OptimizerRule for AlgebraIndex { fn supports_rewrite(&self) -> bool { true } @@ -418,54 +813,157 @@ mod tests { use datafusion_common::{DFSchema, Result}; use datafusion_expr::{ expr_fn::{self, col}, - lit, out_ref_col, scalar_subquery, table_scan, CreateMemoryTable, EmptyRelation, - Expr, LogicalPlan, LogicalPlanBuilder, + in_subquery, lit, out_ref_col, scalar_subquery, table_scan, CreateMemoryTable, + EmptyRelation, Expr, LogicalPlan, LogicalPlanBuilder, }; use datafusion_functions_aggregate::{count::count, sum::sum}; use regex_syntax::ast::LiteralKind; use crate::test::{test_table_scan, test_table_scan_with_name}; - use super::GeneralDecorrelation; + use super::AlgebraIndex; use arrow::{ array::{Int32Array, StringArray}, datatypes::{DataType as ArrowDataType, Field, Fields, Schema}, }; #[test] - fn todo() -> Result<()> { - let mut a = GeneralDecorrelation::default(); + fn play_unnest_simple_projection_pull_up() -> Result<()> { + // let mut framework = GeneralDecorrelation::default(); let outer_table = test_table_scan_with_name("outer_table")?; let inner_table_lv1 = test_table_scan_with_name("inner_table_lv1")?; - let inner_table_lv2 = test_table_scan_with_name("inner_table_lv2")?; - let sq_level2 = Arc::new( - LogicalPlanBuilder::from(inner_table_lv2) + let sq_level1 = Arc::new( + LogicalPlanBuilder::from(inner_table_lv1) .filter( - out_ref_col(ArrowDataType::UInt32, "inner_table_lv1.b") - .eq(col("inner_table_lv2.b")), + col("inner_table_lv1.a") + .eq(out_ref_col(ArrowDataType::UInt32, "outer_table.a")), )? - .aggregate(Vec::::new(), vec![count(col("inner_table_lv2.a"))])? + .project(vec![out_ref_col(ArrowDataType::UInt32, "outer_table.b")])? .build()?, ); + + let input1 = LogicalPlanBuilder::from(outer_table.clone()) + .filter( + col("outer_table.a") + .gt(lit(1)) + .and(in_subquery(col("outer_table.c"), sq_level1)), + )? + .build()?; + let mut index = AlgebraIndex::default(); + index.build(&input1)?; + let new_plan = index.root_dependent_join_elimination()?; + println!("{}", new_plan); + + // let input2 = LogicalPlanBuilder::from(input.clone()) + // .filter(col("int_col").gt(lit(1)))? + // .project(vec![col("string_col")])? + // .build()?; + + // let mut b = GeneralDecorrelation::default(); + // b.build_algebra_index(input2)?; + + Ok(()) + } + #[test] + fn play_unnest_simple_predicate_pull_up() -> Result<()> { + // let mut framework = GeneralDecorrelation::default(); + + let outer_table = test_table_scan_with_name("outer_table")?; + let inner_table_lv1 = test_table_scan_with_name("inner_table_lv1")?; + // let inner_table_lv2 = test_table_scan_with_name("inner_table_lv2")?; + // let sq_level2 = Arc::new( + // LogicalPlanBuilder::from(inner_table_lv2) + // .filter( + // out_ref_col(ArrowDataType::UInt32, "inner_table_lv1.b") + // .eq(col("inner_table_lv2.b")) + // .and( + // out_ref_col(ArrowDataType::UInt32, "outer_table.c") + // .eq(col("inner_table_lv2.c")), + // ), + // )? + // .aggregate(Vec::::new(), vec![count(col("inner_table_lv2.a"))])? + // .build()?, + // ); + let sq_level1 = Arc::new( + LogicalPlanBuilder::from(inner_table_lv1) + .filter( + col("inner_table_lv1.a") + .eq(out_ref_col(ArrowDataType::UInt32, "outer_table.a")) + .and( + out_ref_col(ArrowDataType::UInt32, "outer_table.a") + .eq(lit(1)), + ), + )? + .aggregate(Vec::::new(), vec![sum(col("inner_table_lv1.b"))])? + .project(vec![sum(col("inner_table_lv1.b"))])? 
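// The subquery built here corresponds roughly to
//   SELECT sum(inner_table_lv1.b) FROM inner_table_lv1
//   WHERE inner_table_lv1.a = outer_table.a AND outer_table.a = 1
// and is used below as `outer_table.b > (<subquery>)`; both correlated conjuncts are
// candidates for the simple predicate pull-up this test exercises.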
+ .build()?, + ); + + let input1 = LogicalPlanBuilder::from(outer_table.clone()) + .filter( + col("outer_table.a") + .gt(lit(1)) + .and(col("outer_table.b").gt(scalar_subquery(sq_level1))), + )? + .build()?; + let mut index = AlgebraIndex::default(); + index.build(&input1)?; + let new_plan = index.root_dependent_join_elimination()?; + println!("{}", new_plan); + + // let input2 = LogicalPlanBuilder::from(input.clone()) + // .filter(col("int_col").gt(lit(1)))? + // .project(vec![col("string_col")])? + // .build()?; + + // let mut b = GeneralDecorrelation::default(); + // b.build_algebra_index(input2)?; + + Ok(()) + } + #[test] + fn play_unnest() -> Result<()> { + // let mut framework = GeneralDecorrelation::default(); + + let outer_table = test_table_scan_with_name("outer_table")?; + let inner_table_lv1 = test_table_scan_with_name("inner_table_lv1")?; + // let inner_table_lv2 = test_table_scan_with_name("inner_table_lv2")?; + // let sq_level2 = Arc::new( + // LogicalPlanBuilder::from(inner_table_lv2) + // .filter( + // out_ref_col(ArrowDataType::UInt32, "inner_table_lv1.b") + // .eq(col("inner_table_lv2.b")) + // .and( + // out_ref_col(ArrowDataType::UInt32, "outer_table.c") + // .eq(col("inner_table_lv2.c")), + // ), + // )? + // .aggregate(Vec::::new(), vec![count(col("inner_table_lv2.a"))])? + // .build()?, + // ); let sq_level1 = Arc::new( LogicalPlanBuilder::from(inner_table_lv1) .filter( col("inner_table_lv1.a") .eq(out_ref_col(ArrowDataType::UInt32, "outer_table.a")), )? - .filter(scalar_subquery(sq_level2).gt(lit(5)))? .aggregate(Vec::::new(), vec![sum(col("inner_table_lv1.b"))])? .project(vec![sum(col("inner_table_lv1.b"))])? .build()?, ); let input1 = LogicalPlanBuilder::from(outer_table.clone()) - .filter(col("outer_table.a").gt(lit(1)))? - .filter(col("outer_table.b").gt(scalar_subquery(sq_level1)))? + .filter( + col("outer_table.a") + .gt(lit(1)) + .and(col("outer_table.b").gt(scalar_subquery(sq_level1))), + )? .build()?; - a.build_algebra_index(input1.clone())?; - println!("{:?}", a); + let mut index = AlgebraIndex::default(); + index.build(&input1)?; + let new_plan = index.root_dependent_join_elimination()?; + println!("{}", new_plan); // let input2 = LogicalPlanBuilder::from(input.clone()) // .filter(col("int_col").gt(lit(1)))? @@ -477,4 +975,56 @@ mod tests { Ok(()) } + + // #[test] + // fn todo() -> Result<()> { + // let mut framework = GeneralDecorrelation::default(); + + // let outer_table = test_table_scan_with_name("outer_table")?; + // let inner_table_lv1 = test_table_scan_with_name("inner_table_lv1")?; + // let inner_table_lv2 = test_table_scan_with_name("inner_table_lv2")?; + // let sq_level2 = Arc::new( + // LogicalPlanBuilder::from(inner_table_lv2) + // .filter( + // out_ref_col(ArrowDataType::UInt32, "inner_table_lv1.b") + // .eq(col("inner_table_lv2.b")) + // .and( + // out_ref_col(ArrowDataType::UInt32, "outer_table.c") + // .eq(col("inner_table_lv2.c")), + // ), + // )? + // .aggregate(Vec::::new(), vec![count(col("inner_table_lv2.a"))])? + // .build()?, + // ); + // let sq_level1 = Arc::new( + // LogicalPlanBuilder::from(inner_table_lv1) + // .filter( + // col("inner_table_lv1.a") + // .eq(out_ref_col(ArrowDataType::UInt32, "outer_table.a")) + // .and(scalar_subquery(sq_level2).gt(lit(5))), + // )? + // .aggregate(Vec::::new(), vec![sum(col("inner_table_lv1.b"))])? + // .project(vec![sum(col("inner_table_lv1.b"))])? 
+ // .build()?, + // ); + + // let input1 = LogicalPlanBuilder::from(outer_table.clone()) + // .filter( + // col("outer_table.a") + // .gt(lit(1)) + // .and(col("outer_table.b").gt(scalar_subquery(sq_level1))), + // )? + // .build()?; + // framework.build(&input1)?; + + // // let input2 = LogicalPlanBuilder::from(input.clone()) + // // .filter(col("int_col").gt(lit(1)))? + // // .project(vec![col("string_col")])? + // // .build()?; + + // // let mut b = GeneralDecorrelation::default(); + // // b.build_algebra_index(input2)?; + + // Ok(()) + // } } From 483e3ac81440b60cc85792f4cd5636c0e16fe046 Mon Sep 17 00:00:00 2001 From: Duong Cong Toai Date: Sun, 4 May 2025 20:11:34 +0200 Subject: [PATCH 07/70] chore: impl projection pull up --- .../optimizer/src/decorrelate_general.rs | 190 +++++++++++++++--- 1 file changed, 163 insertions(+), 27 deletions(-) diff --git a/datafusion/optimizer/src/decorrelate_general.rs b/datafusion/optimizer/src/decorrelate_general.rs index c1a0050c702d..a19f442961e7 100644 --- a/datafusion/optimizer/src/decorrelate_general.rs +++ b/datafusion/optimizer/src/decorrelate_general.rs @@ -24,7 +24,8 @@ use std::ops::Deref; use std::rc::{Rc, Weak}; use std::sync::Arc; -use crate::simplify_expressions::ExprSimplifier; +use crate::simplify_expressions::{ExprSimplifier, SimplifyExpressions}; +use crate::utils::has_all_column_refs; use crate::{ApplyOrder, OptimizerConfig, OptimizerRule}; use datafusion_common::tree_node::{ @@ -35,7 +36,7 @@ use datafusion_common::{internal_err, not_impl_err, Column, Result}; use datafusion_expr::{binary_expr, Expr, JoinType, LogicalPlan}; use datafusion_sql::unparser::Unparser; use indexmap::map::Entry; -use indexmap::IndexMap; +use indexmap::{IndexMap, IndexSet}; use itertools::Itertools; use log::Log; @@ -177,6 +178,17 @@ struct Unnesting { // we can substitute the inner query with inner.column_a=some_other_expr } +struct SimpleDecorrelationResult { + // new: Option, + // if projectoin pull up happened, each will be tracked, so that later on general decorrelation + // can rewrite them (a.k.a outer ref column maybe renamed/substituted some where in the parent already + // because the decorrelation is top-down) + pulled_up_projections: IndexSet, + pulled_up_predicates: Vec, + // simple decorrelation has eliminated all dependent joins + finished: bool, +} + // impl Default for GeneralDecorrelation { // fn default() -> Self { // return GeneralDecorrelation { @@ -192,6 +204,7 @@ impl AlgebraIndex { LogicalPlan::Projection(_) => true, LogicalPlan::Filter(_) => true, LogicalPlan::Repartition(_) => true, + LogicalPlan::Subquery(_) => true, // TODO: is this true??? 
_ => false, } } @@ -217,17 +230,80 @@ impl AlgebraIndex { }; } } - // decorrelate all children with simple unnesting + fn remove_node(&mut self, parent: &mut Operator, node: &mut Operator) { + let next_children = node.children.get(0).unwrap(); + let next_children_node = self.nodes.swap_remove(next_children).unwrap(); + // let next_children_node = self.nodes.get_mut(next_children).unwrap(); + *node = next_children_node; + node.parent = Some(parent.id); + } + // decorrelate all descendant(recursively) with simple unnesting // returns true if all children were eliminated // TODO(impl me) - fn try_decorrelate_child(&self, root: &usize, child: &usize) -> Result { - if !self.is_linear_path(root, child) { + fn try_simple_unnest_descendent( + &mut self, + root_node: &mut Operator, + child_node: &mut Operator, + col_access: &ColumnAccess, + result: &mut SimpleDecorrelationResult, + ) -> Result { + // unnest children first + // println!("decorrelating {} from {}", child, root); + + if !self.is_linear_path(&root_node.id, &child_node.id) { + // TODO: return Ok(false); } - let child_node = self.nodes.get(child).unwrap(); - let root_node = self.nodes.get(root).unwrap(); - match &child_node.plan { - LogicalPlan::Projection(proj) => {} + + // TODO: inplace update + // let mut child_node = self.nodes.swap_remove(child).unwrap().clone(); + // let mut root_node = self.nodes.swap_remove(root).unwrap(); + println!("child node is {}", child_node.plan); + + match &mut child_node.plan { + LogicalPlan::Projection(proj) => { + // TODO: handle the case outer_ref_a + outer_ref_b??? + // if we only see outer_ref_a and decide to move the whole expr + // outer_ref_b is accidentally pulled up + let pulled_up_expr: IndexSet<_> = proj + .expr + .iter() + .filter(|proj_expr| { + proj_expr + .exists(|expr| { + // TODO: what if parent has already rewritten outer_ref_col + if let Expr::OuterReferenceColumn(_, col) = expr { + root_node.access_tracker.remove(col_access); + return Ok(*col == col_access.col); + } + Ok(false) + }) + .unwrap() + }) + .cloned() + .collect(); + println!("{:?}", pulled_up_expr); + + if !pulled_up_expr.is_empty() { + for expr in pulled_up_expr.iter() { + result.pulled_up_projections.insert(expr.clone()); + } + // all expr of this node is pulled up, fully remove this node from the tree + if proj.expr.len() == pulled_up_expr.len() { + self.remove_node(root_node, child_node); + return Ok(true); + } + + let new_proj = proj + .expr + .iter() + .filter(|expr| !pulled_up_expr.contains(*expr)) + .cloned() + .collect(); + proj.expr = new_proj; + } + // TODO: try_decorrelate for each of the child + } LogicalPlan::Filter(filter) => { let accessed_from_child = &child_node.access_tracker; for col_access in accessed_from_child { @@ -238,8 +314,26 @@ impl AlgebraIndex { ) } } - _ => {} - } + + // LogicalPlan::Subquery(sq) => { + // let descendent_id = child_node.children.get(0).unwrap(); + // let mut descendent_node = self.nodes.get(descendent_id).unwrap().clone(); + // self.try_simple_unnest_descendent( + // root_node, + // &mut descendent_node, + // result, + // )?; + // self.nodes.insert(*descendent_id, descendent_node); + // } + _ => { + // unimplemented!( + // "simple unnest is missing for this operator {}", + // child_node.plan + // ) + } + }; + // self.nodes.insert(*root, root_node); + // self.nodes.insert(*child, child_node); Ok(false) } @@ -329,16 +423,21 @@ impl AlgebraIndex { // the dependent join) is always the first child of the node, and we want to visit // the left side first - let (dependent_join, 
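// Sketch of the simple projection pull-up implemented above (plan shapes are
// illustrative):
//
//   before:   Projection: outer_ref(outer_table.b)          (inside the subquery)
//               Filter: inner_table_lv1.a = outer_ref(outer_table.a)
//                 TableScan: inner_table_lv1
//
//   after:    Filter: inner_table_lv1.a = outer_ref(outer_table.a)
//               TableScan: inner_table_lv1
//
// with `outer_ref(outer_table.b)` remembered in `pulled_up_projections`; when every
// expression of the Projection is pulled up, the node is spliced out via `remove_node`,
// and the parent dependent join re-introduces the expression later, possibly renamed,
// which is why pulled-up exprs are tracked rather than rewritten in place.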
finished) = self.simple_decorrelation(node)?; - if finished { + let simple_unnest_result = self.simple_decorrelation(node)?; + let new_root = self.nodes.get(&node).unwrap(); + if new_root.access_tracker.len() == 0 { + unimplemented!("reached"); if parent.is_some() { // for each projection of outer column moved up by simple_decorrelation // replace them with the expr store inside parent.replaces unimplemented!(""); return self.unnest(node, &mut parent.unwrap(), outer_refs_from_parent); } - return Ok(dependent_join); + unimplemented!() + // return Ok(dependent_join); } + println!("after rewriting================================"); + println!("{:?}", self); if parent.is_some() { // i.e exists (where inner.col_a = outer_col.b and other_nested_subquery...) @@ -412,14 +511,40 @@ impl AlgebraIndex { // .expect("traversal is infallible"); } - fn simple_decorrelation(&mut self, node: usize) -> Result<(LogicalPlan, bool)> { - let node = self.nodes.get(&node).unwrap(); + fn simple_decorrelation( + &mut self, + node_id: usize, + ) -> Result { + let mut node = self.nodes.get(&node_id).unwrap().clone(); let mut all_eliminated = false; - for child in node.children.iter() { - let branch_all_eliminated = self.try_decorrelate_child(child, child)?; + let mut result = SimpleDecorrelationResult { + // new: None, + pulled_up_projections: IndexSet::new(), + pulled_up_predicates: vec![], + finished: false, + }; + // only iter with direct child + // TODO: confirm if this needs to happen also with descendant + // most likely no, because if this is recursive, it is already non-linear anyway + // and simple decorrleation will stop + for col_access in node.clone().access_tracker.iter() { + println!("here"); + let mut parent_node = self.nodes.get(&node_id).unwrap().clone(); + let mut cloned_child_node = + self.nodes.get(&col_access.node_id).unwrap().clone(); + let branch_all_eliminated = self.try_simple_unnest_descendent( + &mut parent_node, + &mut cloned_child_node, + col_access, + &mut result, + )?; + self.nodes.insert(node_id, parent_node.clone()); + self.nodes.insert(col_access.node_id, cloned_child_node); all_eliminated = all_eliminated || branch_all_eliminated; } - Ok((node.plan.clone(), false)) + + result.finished = all_eliminated; + Ok(result) } fn build(&mut self, root: &LogicalPlan) -> Result<()> { self.build_algebra_index(root.clone())?; @@ -470,7 +595,12 @@ impl AlgebraIndex { Ok(str) => str.to_string(), Err(_) => "".to_string(), }; - writeln!(f, "\x1b[33m [{}] {}\x1b[0m", node_id, lp.display())?; + let (node_color, display_str) = match lp { + LogicalPlan::Subquery(_) => ("\x1b[32m", format!("\x1b[1m{}", lp.display())), + _ => ("\x1b[33m", lp.display().to_string()), + }; + + writeln!(f, "{} [{}] {}\x1b[0m", node_color, node_id, display_str)?; if !unparsed_sql.is_empty() { for i in 0..=indent { if i < indent { @@ -575,10 +705,8 @@ impl AlgebraIndex { }); } fn build_algebra_index(&mut self, plan: LogicalPlan) -> Result<()> { - println!("======================================begin"); // let mut index = AlgebraIndex::default(); plan.visit_with_subqueries(self)?; - println!("======================================end"); Ok(()) } fn create_child_relationship(&mut self, parent: usize, child: usize) { @@ -678,7 +806,6 @@ impl TreeNodeVisitor<'_> for AlgebraIndex { } let mut is_subquery_node = false; let mut is_dependent_join_node = false; - println!("{}\nnode {}", "----".repeat(self.stack.len()), node); // for each node, find which column it is accessing, which column it is providing // Set of columns current node 
access let (accesses, provides): (HashSet, HashSet) = @@ -687,7 +814,7 @@ impl TreeNodeVisitor<'_> for AlgebraIndex { if contains_subquery(&f.predicate) { is_dependent_join_node = true; } - let mut outer_col_refs: HashSet = f + let outer_col_refs: HashSet = f .predicate .outer_column_refs() .into_iter() @@ -719,15 +846,24 @@ impl TreeNodeVisitor<'_> for AlgebraIndex { // 2.projection also provide some new columns // 3.if within projection exists multiple subquery, how does this work LogicalPlan::Projection(proj) => { + let mut outer_cols = HashSet::new(); for expr in &proj.expr { if contains_subquery(expr) { is_dependent_join_node = true; break; } + expr.add_outer_column_refs(&mut outer_cols); } - // proj.expr - // TODO: fix me - (HashSet::new(), HashSet::new()) + ( + outer_cols + .into_iter() + .map(|c| { + self.mark_column_access(self.current_id, c); + ColumnUsage::Outer(c.clone()) + }) + .collect(), + HashSet::new(), + ) } LogicalPlan::Subquery(subquery) => { is_subquery_node = true; From f14b14512c85a15e1730dd2bb2aeb942bc0764d5 Mon Sep 17 00:00:00 2001 From: Duong Cong Toai Date: Tue, 6 May 2025 21:26:20 +0200 Subject: [PATCH 08/70] chore: complete unnesting simple subquery --- .../optimizer/src/decorrelate_general.rs | 194 ++++++++++++++++-- 1 file changed, 172 insertions(+), 22 deletions(-) diff --git a/datafusion/optimizer/src/decorrelate_general.rs b/datafusion/optimizer/src/decorrelate_general.rs index a19f442961e7..2800102b79ee 100644 --- a/datafusion/optimizer/src/decorrelate_general.rs +++ b/datafusion/optimizer/src/decorrelate_general.rs @@ -33,7 +33,13 @@ use datafusion_common::tree_node::{ TreeNodeRewriter, TreeNodeVisitor, }; use datafusion_common::{internal_err, not_impl_err, Column, Result}; -use datafusion_expr::{binary_expr, Expr, JoinType, LogicalPlan}; +use datafusion_expr::expr_rewriter::strip_outer_reference; +use datafusion_expr::select_expr::SelectExpr; +use datafusion_expr::utils::{conjunction, split_conjunction}; +use datafusion_expr::{ + binary_expr, BinaryExpr, Cast, Expr, JoinType, LogicalPlan, LogicalPlanBuilder, + Operator as ExprOperator, Subquery, +}; use datafusion_sql::unparser::Unparser; use indexmap::map::Entry; use indexmap::{IndexMap, IndexSet}; @@ -141,7 +147,7 @@ impl DependentJoin { let replacement = replacements.get(col).unwrap(); self.join_conditions.push(binary_expr( Expr::Column(col.clone()), - datafusion_expr::Operator::IsNotDistinctFrom, + ExprOperator::IsNotDistinctFrom, Expr::Column(replacement.clone()), )); } @@ -178,9 +184,37 @@ struct Unnesting { // we can substitute the inner query with inner.column_a=some_other_expr } +// TODO: looks like this function can be improved to allow more expr pull up +fn can_pull_up(expr: &Expr) -> bool { + if let Expr::BinaryExpr(BinaryExpr { + left, + op: ExprOperator::Eq, + right, + }) = expr + { + match (left.deref(), right.deref()) { + (Expr::Column(_), right) => !right.any_column_refs(), + (left, Expr::Column(_)) => !left.any_column_refs(), + (Expr::Cast(Cast { expr, .. }), right) + if matches!(expr.deref(), Expr::Column(_)) => + { + !right.any_column_refs() + } + (left, Expr::Cast(Cast { expr, .. 
})) + if matches!(expr.deref(), Expr::Column(_)) => + { + !left.any_column_refs() + } + (_, _) => false, + } + } else { + false + } +} + struct SimpleDecorrelationResult { // new: Option, - // if projectoin pull up happened, each will be tracked, so that later on general decorrelation + // if projection pull up happened, each will be tracked, so that later on general decorrelation // can rewrite them (a.k.a outer ref column maybe renamed/substituted some where in the parent already // because the decorrelation is top-down) pulled_up_projections: IndexSet, @@ -188,6 +222,19 @@ struct SimpleDecorrelationResult { // simple decorrelation has eliminated all dependent joins finished: bool, } +fn expr_contains_sq(expr: &Expr, sq: &Subquery) -> bool { + expr.exists(|e| match e { + Expr::InSubquery(isq) => Ok(isq.subquery == *sq), + Expr::ScalarSubquery(ssq) => { + if let LogicalPlan::Subquery(inner_sq) = ssq.subquery.as_ref() { + return Ok(inner_sq.clone() == *sq); + } + Ok(false) + } + _ => Ok(false), + }) + .unwrap() +} // impl Default for GeneralDecorrelation { // fn default() -> Self { @@ -204,7 +251,6 @@ impl AlgebraIndex { LogicalPlan::Projection(_) => true, LogicalPlan::Filter(_) => true, LogicalPlan::Repartition(_) => true, - LogicalPlan::Subquery(_) => true, // TODO: is this true??? _ => false, } } @@ -214,7 +260,17 @@ impl AlgebraIndex { loop { let child_node = self.nodes.get(¤t_node).unwrap(); if !self.is_linear_operator(&child_node.plan) { - return false; + match child_node.parent { + None => { + unimplemented!("traversing from descedent to top does not meet expected root") + } + Some(new_parent) => { + if new_parent == *parent { + return true; + } + return false; + } + } } if current_node == *parent { return true; @@ -222,9 +278,6 @@ impl AlgebraIndex { match child_node.parent { None => return true, Some(new_parent) => { - if new_parent == *parent { - return true; - } current_node = new_parent; } }; @@ -305,14 +358,41 @@ impl AlgebraIndex { // TODO: try_decorrelate for each of the child } LogicalPlan::Filter(filter) => { - let accessed_from_child = &child_node.access_tracker; - for col_access in accessed_from_child { + // let accessed_from_child = &child_node.access_tracker; + let subquery_filter_exprs: Vec = + split_conjunction(&filter.predicate) + .into_iter() + .cloned() + .collect(); + + let (pulled_up, kept): (Vec<_>, Vec<_>) = subquery_filter_exprs + .iter() + .cloned() + .partition(|e| e.contains_outer() && can_pull_up(e)); + // only remove the access tracker if non of the kept expr contains reference to the column + // i.e some of the remaining expr still reference to the column and not pullable + let removable = kept.iter().all(|e| { + !e.exists(|e| { + if let Expr::Column(col) = e { + return Ok(*col == col_access.col); + } + Ok(false) + }) + .unwrap() + }); + if removable { + root_node.access_tracker.remove(col_access); println!( - "checking if col {} can be merged into parent's join filter {}", - col_access.debug(), - root_node.plan - ) + "remove {} access from node {:?}", + col_access.col, root_node.id + ); + } + result.pulled_up_predicates.extend(pulled_up); + if kept.is_empty() { + self.remove_node(root_node, child_node); + return Ok(true); } + filter.predicate = conjunction(kept).unwrap(); } // LogicalPlan::Subquery(sq) => { @@ -408,6 +488,69 @@ impl AlgebraIndex { join_conditions: vec![], } } + fn get_subquery_children( + &self, + parent: &Operator, + ) -> Result<(LogicalPlan, Subquery)> { + let subquery = parent.children.get(0).unwrap(); + let sq_node = 
self.nodes.get(subquery).unwrap(); + assert!(sq_node.is_subquery_node); + let query = sq_node.children.get(0).unwrap(); + let target_node = self.nodes.get(query).unwrap(); + // let op = .clone(); + if let LogicalPlan::Subquery(subquery) = sq_node.plan.clone() { + return Ok((target_node.plan.clone(), subquery)); + } else { + internal_err!("") + } + } + + fn build_join_from_simple_unnest( + &self, + dependent_join_node: &mut Operator, + ret: SimpleDecorrelationResult, + ) -> Result { + let (subquery_children, subquery) = + self.get_subquery_children(dependent_join_node)?; + match dependent_join_node.plan { + LogicalPlan::Filter(ref mut filter) => { + let exprs = split_conjunction(&filter.predicate); + let mut kept_predicates: Vec = exprs + .into_iter() + .filter(|e| !expr_contains_sq(e, &subquery)) + .cloned() + .collect(); + let new_predicates = ret + .pulled_up_predicates + .iter() + .map(|e| strip_outer_reference(e.clone())); + // TODO: some predicate is join predicate, some is just filter + // kept_predicates.extend(new_predicates); + // filter.predicate = conjunction(kept_predicates).unwrap(); + // left + let mut builder = LogicalPlanBuilder::new(filter.input.deref().clone()); + + builder = + builder.join_on(subquery_children, JoinType::Left, new_predicates)?; + if !ret.pulled_up_projections.is_empty() { + // TODO: do we need to pull up projection? + // when most of the case they will be eliminated anyway + // builder = builder.project( + // ret.pulled_up_projections + // .iter() + // .map(|e| SelectExpr::Expression(e.clone())), + // )?; + } + if kept_predicates.len() > 0 { + builder = builder.filter(conjunction(kept_predicates).unwrap())? + } + builder.build() + } + _ => { + unimplemented!() + } + } + } fn dependent_join_elimination( &mut self, @@ -424,9 +567,12 @@ impl AlgebraIndex { // the left side first let simple_unnest_result = self.simple_decorrelation(node)?; - let new_root = self.nodes.get(&node).unwrap(); + let mut new_root = self.nodes.get(&node).unwrap().clone(); if new_root.access_tracker.len() == 0 { - unimplemented!("reached"); + println!("after rewriting================================"); + println!("{:?}", self); + return self + .build_join_from_simple_unnest(&mut new_root, simple_unnest_result); if parent.is_some() { // for each projection of outer column moved up by simple_decorrelation // replace them with the expr store inside parent.replaces @@ -515,7 +661,7 @@ impl AlgebraIndex { &mut self, node_id: usize, ) -> Result { - let mut node = self.nodes.get(&node_id).unwrap().clone(); + let node = self.nodes.get(&node_id).unwrap().clone(); let mut all_eliminated = false; let mut result = SimpleDecorrelationResult { // new: None, @@ -528,8 +674,8 @@ impl AlgebraIndex { // most likely no, because if this is recursive, it is already non-linear anyway // and simple decorrleation will stop for col_access in node.clone().access_tracker.iter() { - println!("here"); let mut parent_node = self.nodes.get(&node_id).unwrap().clone(); + println!("{}", col_access.node_id); let mut cloned_child_node = self.nodes.get(&col_access.node_id).unwrap().clone(); let branch_all_eliminated = self.try_simple_unnest_descendent( @@ -596,7 +742,10 @@ impl AlgebraIndex { Err(_) => "".to_string(), }; let (node_color, display_str) = match lp { - LogicalPlan::Subquery(_) => ("\x1b[32m", format!("\x1b[1m{}", lp.display())), + LogicalPlan::Subquery(sq) => ( + "\x1b[32m", + format!("\x1b[1m{}{}", lp.display(), sq.subquery), + ), _ => ("\x1b[33m", lp.display().to_string()), }; @@ -691,7 +840,7 @@ impl 
AlgebraIndex { } fn mark_column_access(&mut self, child_id: usize, col: &Column) { - // iter from bottom to top, the goal is to mark the independen_join node + // iter from bottom to top, the goal is to mark the dependent node // the current child's access let mut stack = self.stack.clone(); stack.push(child_id); @@ -763,7 +912,8 @@ struct Operator { // This field is only set if the node is dependent join node // it track which child still accessing which column of - access_tracker: HashSet, + // the insertion order is top down + access_tracker: IndexSet, is_dependent_join_node: bool, is_subquery_node: bool, @@ -906,7 +1056,7 @@ impl TreeNodeVisitor<'_> for AlgebraIndex { is_subquery_node, is_dependent_join_node, children: vec![], - access_tracker: HashSet::new(), + access_tracker: IndexSet::new(), }, ); From 0cd814357f1813ea242a99d6e4b82c14b87b1aa8 Mon Sep 17 00:00:00 2001 From: Duong Cong Toai Date: Thu, 8 May 2025 12:21:53 +0200 Subject: [PATCH 09/70] chore: correct join condition --- .../optimizer/src/decorrelate_general.rs | 411 ++++++++++-------- 1 file changed, 228 insertions(+), 183 deletions(-) diff --git a/datafusion/optimizer/src/decorrelate_general.rs b/datafusion/optimizer/src/decorrelate_general.rs index 2800102b79ee..ba8ee189feff 100644 --- a/datafusion/optimizer/src/decorrelate_general.rs +++ b/datafusion/optimizer/src/decorrelate_general.rs @@ -18,6 +18,7 @@ //! [`GeneralPullUpCorrelatedExpr`] converts correlated subqueries to `Joins` use std::cell::RefCell; +use std::cmp::Ordering; use std::collections::{BTreeSet, HashSet}; use std::fmt; use std::ops::Deref; @@ -28,6 +29,7 @@ use crate::simplify_expressions::{ExprSimplifier, SimplifyExpressions}; use crate::utils::has_all_column_refs; use crate::{ApplyOrder, OptimizerConfig, OptimizerRule}; +use arrow::compute::kernels::cmp::eq; use datafusion_common::tree_node::{ Transformed, TransformedResult, TreeNode, TreeNodeContainer, TreeNodeRecursion, TreeNodeRewriter, TreeNodeVisitor, @@ -37,8 +39,8 @@ use datafusion_expr::expr_rewriter::strip_outer_reference; use datafusion_expr::select_expr::SelectExpr; use datafusion_expr::utils::{conjunction, split_conjunction}; use datafusion_expr::{ - binary_expr, BinaryExpr, Cast, Expr, JoinType, LogicalPlan, LogicalPlanBuilder, - Operator as ExprOperator, Subquery, + binary_expr, col, expr_fn, BinaryExpr, Cast, Expr, JoinType, LogicalPlan, + LogicalPlanBuilder, Operator as ExprOperator, Subquery, }; use datafusion_sql::unparser::Unparser; use indexmap::map::Entry; @@ -46,12 +48,16 @@ use indexmap::{IndexMap, IndexSet}; use itertools::Itertools; use log::Log; -pub struct AlgebraIndex { +pub struct DependentJoinTracker { root: Option, + // each logical plan traversal will assign it a integer id current_id: usize, - nodes: IndexMap, // column_ - // TODO: use a different identifier for a node, instead of the whole logical plan obj + // each newly visted operator is inserted inside this map for tracking + nodes: IndexMap, + // all the node ids from root to the current node + // this is used during traversal only stack: Vec, + // track for each column, the nodes/logical plan that reference to its within the tree accessed_columns: IndexMap>, } @@ -186,12 +192,15 @@ struct Unnesting { // TODO: looks like this function can be improved to allow more expr pull up fn can_pull_up(expr: &Expr) -> bool { - if let Expr::BinaryExpr(BinaryExpr { - left, - op: ExprOperator::Eq, - right, - }) = expr - { + if let Expr::BinaryExpr(BinaryExpr { left, op, right }) = expr { + match op { + ExprOperator::Eq + | 
ExprOperator::Gt + | ExprOperator::Lt + | ExprOperator::GtEq + | ExprOperator::LtEq => {} + _ => return false, + } match (left.deref(), right.deref()) { (Expr::Column(_), right) => !right.any_column_refs(), (left, Expr::Column(_)) => !left.any_column_refs(), @@ -219,21 +228,88 @@ struct SimpleDecorrelationResult { // because the decorrelation is top-down) pulled_up_projections: IndexSet, pulled_up_predicates: Vec, - // simple decorrelation has eliminated all dependent joins - finished: bool, } -fn expr_contains_sq(expr: &Expr, sq: &Subquery) -> bool { - expr.exists(|e| match e { - Expr::InSubquery(isq) => Ok(isq.subquery == *sq), + +fn transform_subquery_to_join_expr( + expr: &Expr, + sq: &Subquery, + replace_columns: &[Expr], +) -> Result<(bool, Option)> { + let mut transformed_expr = None; + if replace_columns.len() != 1 { + for expr in replace_columns { + println!("{}", expr) + } + return internal_err!("result of in subquery should only involve one column"); + } + let found_sq = expr.exists(|e| match e { + Expr::InSubquery(isq) => { + if replace_columns.len() != 1 { + println!("{:?}", replace_columns); + return internal_err!( + "result of in subquery should only involve one column" + ); + } + if isq.subquery == *sq { + if isq.negated { + transformed_expr = Some(binary_expr( + *isq.expr.clone(), + ExprOperator::NotEq, + replace_columns[0].clone(), + )); + return Ok(true); + } + + transformed_expr = Some(binary_expr( + *isq.expr.clone(), + ExprOperator::NotEq, + replace_columns[0].clone(), + )); + return Ok(true); + } + return Ok(false); + } + Expr::BinaryExpr(BinaryExpr { left, op, right }) => { + let (exist, transformed) = + transform_subquery_to_join_expr(left.as_ref(), sq, replace_columns)?; + if !exist { + let (right_exist, transformed_right) = + transform_subquery_to_join_expr(right.as_ref(), sq, replace_columns)?; + if !right_exist { + return Ok(false); + } + // TODO: exist query won't have any transformed expr, + // meaning this query is not supported `where bool_col = exists(subquery)` + transformed_expr = Some(binary_expr( + *left.clone(), + op.clone(), + transformed_right.unwrap(), + )); + return Ok(true); + } + // TODO: exist query won't have any transformed expr, + // meaning this query is not supported `where bool_col = exists(subquery)` + transformed_expr = Some(binary_expr( + transformed.unwrap(), + op.clone(), + *right.clone(), + )); + return Ok(true); + } Expr::ScalarSubquery(ssq) => { + unimplemented!( + "we need to store map between scalarsubquery and replaced_expr later on" + ); if let LogicalPlan::Subquery(inner_sq) = ssq.subquery.as_ref() { - return Ok(inner_sq.clone() == *sq); + if inner_sq.clone() == *sq { + return Ok(true); + } } - Ok(false) + return Ok(false); } _ => Ok(false), - }) - .unwrap() + })?; + return Ok((found_sq, transformed_expr)); } // impl Default for GeneralDecorrelation { @@ -243,7 +319,7 @@ fn expr_contains_sq(expr: &Expr, sq: &Subquery) -> bool { // }; // } // } -impl AlgebraIndex { +impl DependentJoinTracker { fn is_linear_operator(&self, plan: &LogicalPlan) -> bool { match plan { LogicalPlan::Limit(_) => true, @@ -284,7 +360,7 @@ impl AlgebraIndex { } } fn remove_node(&mut self, parent: &mut Operator, node: &mut Operator) { - let next_children = node.children.get(0).unwrap(); + let next_children = node.children.first().unwrap(); let next_children_node = self.nodes.swap_remove(next_children).unwrap(); // let next_children_node = self.nodes.get_mut(next_children).unwrap(); *node = next_children_node; @@ -293,25 +369,24 @@ impl AlgebraIndex { // 
decorrelate all descendant(recursively) with simple unnesting // returns true if all children were eliminated // TODO(impl me) - fn try_simple_unnest_descendent( + fn try_simple_decorrelate_descendent( &mut self, root_node: &mut Operator, child_node: &mut Operator, col_access: &ColumnAccess, result: &mut SimpleDecorrelationResult, - ) -> Result { + ) -> Result<()> { // unnest children first // println!("decorrelating {} from {}", child, root); if !self.is_linear_path(&root_node.id, &child_node.id) { // TODO: - return Ok(false); + return Ok(()); } // TODO: inplace update // let mut child_node = self.nodes.swap_remove(child).unwrap().clone(); // let mut root_node = self.nodes.swap_remove(root).unwrap(); - println!("child node is {}", child_node.plan); match &mut child_node.plan { LogicalPlan::Projection(proj) => { @@ -344,7 +419,7 @@ impl AlgebraIndex { // all expr of this node is pulled up, fully remove this node from the tree if proj.expr.len() == pulled_up_expr.len() { self.remove_node(root_node, child_node); - return Ok(true); + return Ok(()); } let new_proj = proj @@ -369,6 +444,7 @@ impl AlgebraIndex { .iter() .cloned() .partition(|e| e.contains_outer() && can_pull_up(e)); + // only remove the access tracker if non of the kept expr contains reference to the column // i.e some of the remaining expr still reference to the column and not pullable let removable = kept.iter().all(|e| { @@ -381,16 +457,12 @@ impl AlgebraIndex { .unwrap() }); if removable { - root_node.access_tracker.remove(col_access); - println!( - "remove {} access from node {:?}", - col_access.col, root_node.id - ); + root_node.access_tracker.swap_remove(col_access); } result.pulled_up_predicates.extend(pulled_up); if kept.is_empty() { self.remove_node(root_node, child_node); - return Ok(true); + return Ok(()); } filter.predicate = conjunction(kept).unwrap(); } @@ -412,16 +484,15 @@ impl AlgebraIndex { // ) } }; - // self.nodes.insert(*root, root_node); - // self.nodes.insert(*child, child_node); - Ok(false) + + Ok(()) } fn unnest( &mut self, node_id: usize, unnesting: &mut Unnesting, - outer_refs_from_parent: HashSet, + outer_refs_from_parent: IndexSet, ) -> Result { unimplemented!() // if unnesting.info.parent.is_some() { @@ -439,7 +510,7 @@ impl AlgebraIndex { fn right(&self, node: &Operator) -> &Operator { assert_eq!(2, node.children.len()); // during the building of the tree, the subquery (right node) is always traversed first - let node_id = node.children.get(0).unwrap(); + let node_id = node.children.first().unwrap(); return self.nodes.get(node_id).unwrap(); } fn left(&self, node: &Operator) -> &Operator { @@ -473,7 +544,7 @@ impl AlgebraIndex { // replaces: IndexMap::new(), // }; - self.dependent_join_elimination(node.id, &unnesting_info, HashSet::new()) + self.dependent_join_elimination(node.id, &unnesting_info, IndexSet::new()) } fn column_accesses(&self, node_id: usize) -> Vec<&ColumnAccess> { @@ -515,32 +586,50 @@ impl AlgebraIndex { match dependent_join_node.plan { LogicalPlan::Filter(ref mut filter) => { let exprs = split_conjunction(&filter.predicate); - let mut kept_predicates: Vec = exprs - .into_iter() - .filter(|e| !expr_contains_sq(e, &subquery)) + let mut join_exprs = vec![]; + let mut kept_predicates = vec![]; + // maybe we also need to collect join columns here + let pulled_projection: Vec = ret + .pulled_up_projections + .iter() .cloned() + .map(strip_outer_reference) .collect(); + for expr in exprs.into_iter() { + // exist query may not have any transformed expr + // i.e where exists(suquery) => 
semi join + let (transformed, maybe_transformed_expr) = + transform_subquery_to_join_expr( + expr, + &subquery, + &pulled_projection, + )?; + if maybe_transformed_expr.is_some() { + join_exprs.push(maybe_transformed_expr.unwrap()); + } + if !transformed { + kept_predicates.push(expr.clone()) + } + } + let new_predicates = ret .pulled_up_predicates .iter() .map(|e| strip_outer_reference(e.clone())); + join_exprs.extend(new_predicates); // TODO: some predicate is join predicate, some is just filter // kept_predicates.extend(new_predicates); // filter.predicate = conjunction(kept_predicates).unwrap(); // left let mut builder = LogicalPlanBuilder::new(filter.input.deref().clone()); - builder = - builder.join_on(subquery_children, JoinType::Left, new_predicates)?; - if !ret.pulled_up_projections.is_empty() { - // TODO: do we need to pull up projection? - // when most of the case they will be eliminated anyway - // builder = builder.project( - // ret.pulled_up_projections - // .iter() - // .map(|e| SelectExpr::Expression(e.clone())), - // )?; - } + builder = builder.join_on( + subquery_children, + // TODO: join type based on filter condition + JoinType::LeftSemi, + join_exprs, + )?; + if kept_predicates.len() > 0 { builder = builder.filter(conjunction(kept_predicates).unwrap())? } @@ -556,7 +645,7 @@ impl AlgebraIndex { &mut self, node: usize, unnesting: &UnnestingInfo, - outer_refs_from_parent: HashSet, + outer_refs_from_parent: IndexSet, ) -> Result { let parent = unnesting.parent.clone(); let operator = self.nodes.get(&node).unwrap(); @@ -587,7 +676,7 @@ impl AlgebraIndex { if parent.is_some() { // i.e exists (where inner.col_a = outer_col.b and other_nested_subquery...) - let mut outer_ref_from_left = HashSet::new(); + let mut outer_ref_from_left = IndexSet::new(); let left = join.left.clone(); for col_from_parent in outer_refs_from_parent.iter() { if left @@ -619,7 +708,7 @@ impl AlgebraIndex { }, replaces: IndexMap::new(), }; - let mut accesses: HashSet = self + let mut accesses: IndexSet = self .column_accesses(node) .iter() .map(|a| a.col.clone()) @@ -656,40 +745,45 @@ impl AlgebraIndex { // }) // .expect("traversal is infallible"); } + fn get_node_uncheck(&self, node_id: &usize) -> Operator { + self.nodes.get(node_id).unwrap().clone() + } fn simple_decorrelation( &mut self, node_id: usize, ) -> Result { - let node = self.nodes.get(&node_id).unwrap().clone(); + let node = self.get_node_uncheck(&node_id); let mut all_eliminated = false; let mut result = SimpleDecorrelationResult { // new: None, pulled_up_projections: IndexSet::new(), pulled_up_predicates: vec![], - finished: false, }; - // only iter with direct child - // TODO: confirm if this needs to happen also with descendant - // most likely no, because if this is recursive, it is already non-linear anyway - // and simple decorrleation will stop - for col_access in node.clone().access_tracker.iter() { - let mut parent_node = self.nodes.get(&node_id).unwrap().clone(); - println!("{}", col_access.node_id); - let mut cloned_child_node = - self.nodes.get(&col_access.node_id).unwrap().clone(); - let branch_all_eliminated = self.try_simple_unnest_descendent( + + let accesses_bottom_up = node.access_tracker.clone().sorted_by(|a, b| { + if a.node_id < b.node_id { + Ordering::Greater + } else { + Ordering::Less + } + }); + + for col_access in accesses_bottom_up { + // create two copy because of + let mut parent_node = self.get_node_uncheck(&node_id); + let mut descendent = self.get_node_uncheck(&col_access.node_id); + 
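+            // Note on the ordering, as a reading of the sort above: `f_down` assigns
+            // ids in pre-order, so within the subquery a descendant always has a
+            // larger id than its ancestors. Sorting the accesses by descending
+            // `node_id` therefore handles the deepest correlated operator first
+            // (bottom-up), matching the name `accesses_bottom_up`.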
self.try_simple_decorrelate_descendent( &mut parent_node, - &mut cloned_child_node, - col_access, + &mut descendent, + &col_access, &mut result, )?; + // TODO: find a nicer way to do in-place update self.nodes.insert(node_id, parent_node.clone()); - self.nodes.insert(col_access.node_id, cloned_child_node); - all_eliminated = all_eliminated || branch_all_eliminated; + self.nodes.insert(col_access.node_id, descendent); } - result.finished = all_eliminated; Ok(result) } fn build(&mut self, root: &LogicalPlan) -> Result<()> { @@ -698,7 +792,7 @@ impl AlgebraIndex { Ok(()) } } -impl fmt::Debug for AlgebraIndex { +impl fmt::Debug for DependentJoinTracker { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { writeln!(f, "GeneralDecorrelation Tree:")?; if let Some(root_op) = &self.root { @@ -710,7 +804,7 @@ impl fmt::Debug for AlgebraIndex { } } -impl AlgebraIndex { +impl DependentJoinTracker { fn fmt_operator( &self, f: &mut fmt::Formatter<'_>, @@ -770,12 +864,6 @@ impl AlgebraIndex { } } - let accessing_string = op - .potential_accesses - .iter() - .map(|c| c.debug()) - .collect::>() - .join(", "); let accessed_by_string = op .access_tracker .iter() @@ -783,11 +871,7 @@ impl AlgebraIndex { .collect::>() .join(", "); // Now print the Operator details - writeln!( - f, - "acccessing: {}, accessed_by: {}", - accessing_string, accessed_by_string, - )?; + writeln!(f, "accessed_by: {}", accessed_by_string,)?; let len = op.children.len(); // Recursively print children if Operator has children @@ -822,7 +906,7 @@ impl AlgebraIndex { // because the column providers are visited after column-accessor // function visit_with_subqueries always visit the subquery before visiting the other child // we can always infer the LCA inside this function, by getting the deepest common parent - fn conclude_lca_for_column(&mut self, child_id: usize, col: &Column) { + fn conclude_lowest_dependent_join_node(&mut self, child_id: usize, col: &Column) { if let Some(accesses) = self.accessed_columns.get(col) { for access in accesses.iter() { let mut cur_stack = self.stack.clone(); @@ -830,6 +914,7 @@ impl AlgebraIndex { // this is a dependen join node let lca_node = Self::lca_from_stack(&cur_stack, &access.stack); let node = self.nodes.get_mut(&lca_node).unwrap(); + println!("inserting {}", access.node_id); node.access_tracker.insert(ColumnAccess { col: col.clone(), node_id: access.node_id, @@ -864,9 +949,9 @@ impl AlgebraIndex { } } -impl Default for AlgebraIndex { +impl Default for DependentJoinTracker { fn default() -> Self { - return AlgebraIndex { + return DependentJoinTracker { root: None, current_id: 0, nodes: IndexMap::new(), @@ -899,16 +984,6 @@ struct Operator { id: usize, plan: LogicalPlan, parent: Option, - // Note if the current node is a Subquery - // at the first time this node is visited, - // the set of accesses columns are not sufficient - // (i.e) some where deep down the ast another recursive subquery - // exists and also referencing some columns belongs to the outer part - // of the subquery - // Thus, on discovery of new subquery, we must - // add the accesses columns to the ancestor nodes which are Subquery - potential_accesses: HashSet, - provides: HashSet, // This field is only set if the node is dependent join node // it track which child still accessing which column of @@ -947,7 +1022,7 @@ fn print(a: &Expr) -> Result<()> { Ok(()) } -impl TreeNodeVisitor<'_> for AlgebraIndex { +impl TreeNodeVisitor<'_> for DependentJoinTracker { type Node = LogicalPlan; fn f_down(&mut self, node: &LogicalPlan) 
-> Result { self.current_id += 1; @@ -958,83 +1033,45 @@ impl TreeNodeVisitor<'_> for AlgebraIndex { let mut is_dependent_join_node = false; // for each node, find which column it is accessing, which column it is providing // Set of columns current node access - let (accesses, provides): (HashSet, HashSet) = - match node { - LogicalPlan::Filter(f) => { - if contains_subquery(&f.predicate) { - is_dependent_join_node = true; - } - let outer_col_refs: HashSet = f - .predicate - .outer_column_refs() - .into_iter() - .map(|f| { - self.mark_column_access(self.current_id, f); - ColumnUsage::Outer(f.clone()) - }) - .collect(); - - (outer_col_refs, HashSet::new()) - } - LogicalPlan::TableScan(tbl_scan) => { - let provided_columns: HashSet = tbl_scan - .projected_schema - .columns() - .into_iter() - .map(|col| { - self.conclude_lca_for_column(self.current_id, &col); - ColumnUsage::Own(col) - }) - .collect(); - (HashSet::new(), provided_columns) + match node { + LogicalPlan::Filter(f) => { + if contains_subquery(&f.predicate) { + is_dependent_join_node = true; } - LogicalPlan::Aggregate(_) => (HashSet::new(), HashSet::new()), - LogicalPlan::EmptyRelation(_) => (HashSet::new(), HashSet::new()), - LogicalPlan::Limit(_) => (HashSet::new(), HashSet::new()), - // TODO - // 1.handle subquery inside projection - // 2.projection also provide some new columns - // 3.if within projection exists multiple subquery, how does this work - LogicalPlan::Projection(proj) => { - let mut outer_cols = HashSet::new(); - for expr in &proj.expr { - if contains_subquery(expr) { - is_dependent_join_node = true; - break; - } - expr.add_outer_column_refs(&mut outer_cols); + f.predicate.outer_column_refs().into_iter().for_each(|f| { + self.mark_column_access(self.current_id, f); + }); + } + LogicalPlan::TableScan(tbl_scan) => { + tbl_scan.projected_schema.columns().iter().for_each(|col| { + self.conclude_lowest_dependent_join_node(self.current_id, &col); + }); + } + // TODO + // 1.handle subquery inside projection + // 2.projection also provide some new columns + // 3.if within projection exists multiple subquery, how does this work + LogicalPlan::Projection(proj) => { + let mut outer_cols = HashSet::new(); + for expr in &proj.expr { + if contains_subquery(expr) { + is_dependent_join_node = true; + break; } - ( - outer_cols - .into_iter() - .map(|c| { - self.mark_column_access(self.current_id, c); - ColumnUsage::Outer(c.clone()) - }) - .collect(), - HashSet::new(), - ) - } - LogicalPlan::Subquery(subquery) => { - is_subquery_node = true; - // TODO: once we detect the subquery - let accessed = subquery - .outer_ref_columns - .iter() - .filter_map(|f| match f { - Expr::Column(col) => Some(ColumnUsage::Outer(col.clone())), - Expr::OuterReferenceColumn(_, col) => { - Some(ColumnUsage::Outer(col.clone())) - } - _ => None, - }) - .collect(); - (accessed, HashSet::new()) - } - _ => { - return internal_err!("impl scan for node type {:?}", node); + expr.add_outer_column_refs(&mut outer_cols); } - }; + outer_cols.into_iter().for_each(|c| { + self.mark_column_access(self.current_id, c); + }); + } + LogicalPlan::Subquery(subquery) => { + is_subquery_node = true; + // TODO: once we detect the subquery + } + _ => { + return internal_err!("impl scan for node type {:?}", node); + } + }; let parent = if self.stack.is_empty() { None @@ -1051,8 +1088,6 @@ impl TreeNodeVisitor<'_> for AlgebraIndex { id: self.current_id, parent, plan: node.clone(), - potential_accesses: accesses, - provides, is_subquery_node, is_dependent_join_node, children: 
vec![], @@ -1071,7 +1106,7 @@ impl TreeNodeVisitor<'_> for AlgebraIndex { } } -impl OptimizerRule for AlgebraIndex { +impl OptimizerRule for DependentJoinTracker { fn supports_rewrite(&self) -> bool { true } @@ -1107,7 +1142,7 @@ mod tests { use crate::test::{test_table_scan, test_table_scan_with_name}; - use super::AlgebraIndex; + use super::DependentJoinTracker; use arrow::{ array::{Int32Array, StringArray}, datatypes::{DataType as ArrowDataType, Field, Fields, Schema}, @@ -1123,9 +1158,19 @@ mod tests { LogicalPlanBuilder::from(inner_table_lv1) .filter( col("inner_table_lv1.a") - .eq(out_ref_col(ArrowDataType::UInt32, "outer_table.a")), + .eq(out_ref_col(ArrowDataType::UInt32, "outer_table.a")) + .and( + out_ref_col(ArrowDataType::UInt32, "outer_table.a") + .gt(col("inner_table_lv1.c")), + ) + .and(col("inner_table_lv1.b").eq(lit(1))) + .and( + out_ref_col(ArrowDataType::UInt32, "outer_table.b") + .eq(col("inner_table_lv1.b")), + ), )? - .project(vec![out_ref_col(ArrowDataType::UInt32, "outer_table.b")])? + .project(vec![out_ref_col(ArrowDataType::UInt32, "outer_table.b") + .alias("outer_b_alias")])? .build()?, ); @@ -1136,7 +1181,7 @@ mod tests { .and(in_subquery(col("outer_table.c"), sq_level1)), )? .build()?; - let mut index = AlgebraIndex::default(); + let mut index = DependentJoinTracker::default(); index.build(&input1)?; let new_plan = index.root_dependent_join_elimination()?; println!("{}", new_plan); @@ -1193,7 +1238,7 @@ mod tests { .and(col("outer_table.b").gt(scalar_subquery(sq_level1))), )? .build()?; - let mut index = AlgebraIndex::default(); + let mut index = DependentJoinTracker::default(); index.build(&input1)?; let new_plan = index.root_dependent_join_elimination()?; println!("{}", new_plan); @@ -1246,7 +1291,7 @@ mod tests { .and(col("outer_table.b").gt(scalar_subquery(sq_level1))), )? 
.build()?; - let mut index = AlgebraIndex::default(); + let mut index = DependentJoinTracker::default(); index.build(&input1)?; let new_plan = index.root_dependent_join_elimination()?; println!("{}", new_plan); From cc3e01cab5a19a7fea32a16d8666120c24846107 Mon Sep 17 00:00:00 2001 From: Duong Cong Toai Date: Thu, 8 May 2025 12:48:56 +0200 Subject: [PATCH 10/70] chore: handle exist query --- .../optimizer/src/decorrelate_general.rs | 75 +++++++++++++------ 1 file changed, 52 insertions(+), 23 deletions(-) diff --git a/datafusion/optimizer/src/decorrelate_general.rs b/datafusion/optimizer/src/decorrelate_general.rs index ba8ee189feff..9b687d747f7c 100644 --- a/datafusion/optimizer/src/decorrelate_general.rs +++ b/datafusion/optimizer/src/decorrelate_general.rs @@ -35,6 +35,7 @@ use datafusion_common::tree_node::{ TreeNodeRewriter, TreeNodeVisitor, }; use datafusion_common::{internal_err, not_impl_err, Column, Result}; +use datafusion_expr::expr::Exists; use datafusion_expr::expr_rewriter::strip_outer_reference; use datafusion_expr::select_expr::SelectExpr; use datafusion_expr::utils::{conjunction, split_conjunction}; @@ -234,8 +235,11 @@ fn transform_subquery_to_join_expr( expr: &Expr, sq: &Subquery, replace_columns: &[Expr], -) -> Result<(bool, Option)> { - let mut transformed_expr = None; +) -> Result<(bool, Option, Option)> { + let mut post_join_predicate = None; + + // this is used for exist query + let mut join_predicate = None; if replace_columns.len() != 1 { for expr in replace_columns { println!("{}", expr) @@ -252,7 +256,7 @@ fn transform_subquery_to_join_expr( } if isq.subquery == *sq { if isq.negated { - transformed_expr = Some(binary_expr( + join_predicate = Some(binary_expr( *isq.expr.clone(), ExprOperator::NotEq, replace_columns[0].clone(), @@ -260,7 +264,7 @@ fn transform_subquery_to_join_expr( return Ok(true); } - transformed_expr = Some(binary_expr( + join_predicate = Some(binary_expr( *isq.expr.clone(), ExprOperator::NotEq, replace_columns[0].clone(), @@ -270,32 +274,53 @@ fn transform_subquery_to_join_expr( return Ok(false); } Expr::BinaryExpr(BinaryExpr { left, op, right }) => { - let (exist, transformed) = + let (exist, transformed, post_join_expr_from_left) = transform_subquery_to_join_expr(left.as_ref(), sq, replace_columns)?; if !exist { - let (right_exist, transformed_right) = + let (right_exist, transformed_right, post_join_expr_from_right) = transform_subquery_to_join_expr(right.as_ref(), sq, replace_columns)?; if !right_exist { return Ok(false); } - // TODO: exist query won't have any transformed expr, - // meaning this query is not supported `where bool_col = exists(subquery)` - transformed_expr = Some(binary_expr( - *left.clone(), - op.clone(), - transformed_right.unwrap(), - )); + if let Some(transformed_right) = transformed_right { + join_predicate = + Some(binary_expr(*left.clone(), op.clone(), transformed_right)); + } + if let Some(transformed_right) = post_join_expr_from_right { + post_join_predicate = + Some(binary_expr(*left.clone(), op.clone(), transformed_right)); + } + return Ok(true); } // TODO: exist query won't have any transformed expr, // meaning this query is not supported `where bool_col = exists(subquery)` - transformed_expr = Some(binary_expr( - transformed.unwrap(), - op.clone(), - *right.clone(), - )); + + if let Some(transformed) = transformed { + join_predicate = + Some(binary_expr(transformed, op.clone(), *right.clone())); + } + if let Some(transformed) = post_join_expr_from_left { + post_join_predicate = + Some(binary_expr(transformed, 
op.clone(), *right.clone())); + } return Ok(true); } + Expr::Exists(Exists { subquery, negated }) => { + if let LogicalPlan::Subquery(inner_sq) = subquery.subquery.as_ref() { + if inner_sq.clone() == *sq { + let op = if *negated { + ExprOperator::NotEq + } else { + ExprOperator::Eq + }; + join_predicate = + Some(binary_expr(col("mark"), op, replace_columns[0].clone())); + return Ok(true); + } + } + internal_err!("subquery field of Exists is not a subquery") + } Expr::ScalarSubquery(ssq) => { unimplemented!( "we need to store map between scalarsubquery and replaced_expr later on" @@ -309,7 +334,7 @@ fn transform_subquery_to_join_expr( } _ => Ok(false), })?; - return Ok((found_sq, transformed_expr)); + return Ok((found_sq, join_predicate, post_join_predicate)); } // impl Default for GeneralDecorrelation { @@ -598,14 +623,18 @@ impl DependentJoinTracker { for expr in exprs.into_iter() { // exist query may not have any transformed expr // i.e where exists(suquery) => semi join - let (transformed, maybe_transformed_expr) = + let (transformed, maybe_transformed_expr, maybe_post_join_expr) = transform_subquery_to_join_expr( expr, &subquery, &pulled_projection, )?; - if maybe_transformed_expr.is_some() { - join_exprs.push(maybe_transformed_expr.unwrap()); + + if let Some(transformed) = maybe_transformed_expr { + join_exprs.push(transformed) + } + if let Some(post_join_expr) = maybe_post_join_expr { + kept_predicates.push(post_join_expr) } if !transformed { kept_predicates.push(expr.clone()) @@ -626,7 +655,7 @@ impl DependentJoinTracker { builder = builder.join_on( subquery_children, // TODO: join type based on filter condition - JoinType::LeftSemi, + JoinType::LeftMark, join_exprs, )?; From 9b5daa2fe23f35700953db6ece8468613d302640 Mon Sep 17 00:00:00 2001 From: Duong Cong Toai Date: Sat, 10 May 2025 10:15:21 +0200 Subject: [PATCH 11/70] test: in sq test --- .../optimizer/src/decorrelate_general.rs | 118 +++++++++++++----- 1 file changed, 85 insertions(+), 33 deletions(-) diff --git a/datafusion/optimizer/src/decorrelate_general.rs b/datafusion/optimizer/src/decorrelate_general.rs index 9b687d747f7c..a8d52df533e3 100644 --- a/datafusion/optimizer/src/decorrelate_general.rs +++ b/datafusion/optimizer/src/decorrelate_general.rs @@ -40,7 +40,7 @@ use datafusion_expr::expr_rewriter::strip_outer_reference; use datafusion_expr::select_expr::SelectExpr; use datafusion_expr::utils::{conjunction, split_conjunction}; use datafusion_expr::{ - binary_expr, col, expr_fn, BinaryExpr, Cast, Expr, JoinType, LogicalPlan, + binary_expr, col, expr_fn, lit, BinaryExpr, Cast, Expr, JoinType, LogicalPlan, LogicalPlanBuilder, Operator as ExprOperator, Subquery, }; use datafusion_sql::unparser::Unparser; @@ -231,7 +231,7 @@ struct SimpleDecorrelationResult { pulled_up_predicates: Vec, } -fn transform_subquery_to_join_expr( +fn try_transform_subquery_to_join_expr( expr: &Expr, sq: &Subquery, replace_columns: &[Expr], @@ -259,15 +259,15 @@ fn transform_subquery_to_join_expr( join_predicate = Some(binary_expr( *isq.expr.clone(), ExprOperator::NotEq, - replace_columns[0].clone(), + strip_outer_reference(replace_columns[0].clone()), )); return Ok(true); } join_predicate = Some(binary_expr( *isq.expr.clone(), - ExprOperator::NotEq, - replace_columns[0].clone(), + ExprOperator::Eq, + strip_outer_reference(replace_columns[0].clone()), )); return Ok(true); } @@ -275,20 +275,24 @@ fn transform_subquery_to_join_expr( } Expr::BinaryExpr(BinaryExpr { left, op, right }) => { let (exist, transformed, post_join_expr_from_left) = - 
transform_subquery_to_join_expr(left.as_ref(), sq, replace_columns)?; + try_transform_subquery_to_join_expr(left.as_ref(), sq, replace_columns)?; if !exist { let (right_exist, transformed_right, post_join_expr_from_right) = - transform_subquery_to_join_expr(right.as_ref(), sq, replace_columns)?; + try_transform_subquery_to_join_expr( + right.as_ref(), + sq, + replace_columns, + )?; if !right_exist { return Ok(false); } if let Some(transformed_right) = transformed_right { join_predicate = - Some(binary_expr(*left.clone(), op.clone(), transformed_right)); + Some(binary_expr(*left.clone(), *op, transformed_right)); } if let Some(transformed_right) = post_join_expr_from_right { post_join_predicate = - Some(binary_expr(*left.clone(), op.clone(), transformed_right)); + Some(binary_expr(*left.clone(), *op, transformed_right)); } return Ok(true); @@ -297,25 +301,22 @@ fn transform_subquery_to_join_expr( // meaning this query is not supported `where bool_col = exists(subquery)` if let Some(transformed) = transformed { - join_predicate = - Some(binary_expr(transformed, op.clone(), *right.clone())); + join_predicate = Some(binary_expr(transformed, *op, *right.clone())); } if let Some(transformed) = post_join_expr_from_left { - post_join_predicate = - Some(binary_expr(transformed, op.clone(), *right.clone())); + post_join_predicate = Some(binary_expr(transformed, *op, *right.clone())); } return Ok(true); } Expr::Exists(Exists { subquery, negated }) => { if let LogicalPlan::Subquery(inner_sq) = subquery.subquery.as_ref() { if inner_sq.clone() == *sq { - let op = if *negated { - ExprOperator::NotEq + let mark_predicate = if *negated { + expr_fn::not(col("mark")) } else { - ExprOperator::Eq + col("mark") }; - join_predicate = - Some(binary_expr(col("mark"), op, replace_columns[0].clone())); + join_predicate = Some(mark_predicate); return Ok(true); } } @@ -620,20 +621,42 @@ impl DependentJoinTracker { .cloned() .map(strip_outer_reference) .collect(); + let right_exprs: Vec = if ret.pulled_up_projections.is_empty() { + subquery_children.expressions() + } else { + ret.pulled_up_projections + .iter() + .cloned() + .map(strip_outer_reference) + .collect() + }; + let mut join_type = JoinType::LeftSemi; for expr in exprs.into_iter() { // exist query may not have any transformed expr // i.e where exists(suquery) => semi join let (transformed, maybe_transformed_expr, maybe_post_join_expr) = - transform_subquery_to_join_expr( + try_transform_subquery_to_join_expr( expr, &subquery, - &pulled_projection, + &right_exprs, )?; if let Some(transformed) = maybe_transformed_expr { join_exprs.push(transformed) } if let Some(post_join_expr) = maybe_post_join_expr { + if post_join_expr + .exists(|e| { + if let Expr::Column(col) = e { + return Ok(col.name == "mark"); + } + return Ok(false); + }) + .unwrap() + { + // only use mark join if required + join_type = JoinType::LeftMark + } kept_predicates.push(post_join_expr) } if !transformed { @@ -655,7 +678,7 @@ impl DependentJoinTracker { builder = builder.join_on( subquery_children, // TODO: join type based on filter condition - JoinType::LeftMark, + join_type, join_exprs, )?; @@ -1162,7 +1185,7 @@ mod tests { use datafusion_common::{DFSchema, Result}; use datafusion_expr::{ - expr_fn::{self, col}, + expr_fn::{self, col, not}, in_subquery, lit, out_ref_col, scalar_subquery, table_scan, CreateMemoryTable, EmptyRelation, Expr, LogicalPlan, LogicalPlanBuilder, }; @@ -1176,9 +1199,41 @@ mod tests { array::{Int32Array, StringArray}, datatypes::{DataType as ArrowDataType, Field, 
Fields, Schema}, }; + #[test] + fn simple_decorrelate_with_in_subquery_no_dependent_column() -> Result<()> { + // let mut framework = GeneralDecorrelation::default(); + let outer_table = test_table_scan_with_name("outer_table")?; + let inner_table_lv1 = test_table_scan_with_name("inner_table_lv1")?; + let sq_level1 = Arc::new( + LogicalPlanBuilder::from(inner_table_lv1) + .filter(col("inner_table_lv1.b").eq(lit(1)))? + .project(vec![col("inner_table_lv1.b")])? + .build()?, + ); + + let input1 = LogicalPlanBuilder::from(outer_table.clone()) + .filter( + col("outer_table.a") + .gt(lit(1)) + .and(in_subquery(col("outer_table.c"), sq_level1)), + )? + .build()?; + let mut index = DependentJoinTracker::default(); + index.build(&input1)?; + let new_plan = index.root_dependent_join_elimination()?; + let expected = "\ + Filter: outer_table.a > Int32(1)\ + \n LeftSemi Join: Filter: outer_table.c = inner_table_lv1.b\ + \n TableScan: outer_table\ + \n Projection: inner_table_lv1.b\ + \n Filter: inner_table_lv1.b = Int32(1)\ + \n TableScan: inner_table_lv1"; + assert_eq!(expected, format!("{new_plan}")); + Ok(()) + } #[test] - fn play_unnest_simple_projection_pull_up() -> Result<()> { + fn simple_decorrelate_with_in_subquery_has_dependent_column() -> Result<()> { // let mut framework = GeneralDecorrelation::default(); let outer_table = test_table_scan_with_name("outer_table")?; @@ -1213,16 +1268,13 @@ mod tests { let mut index = DependentJoinTracker::default(); index.build(&input1)?; let new_plan = index.root_dependent_join_elimination()?; - println!("{}", new_plan); - - // let input2 = LogicalPlanBuilder::from(input.clone()) - // .filter(col("int_col").gt(lit(1)))? - // .project(vec![col("string_col")])? - // .build()?; - - // let mut b = GeneralDecorrelation::default(); - // b.build_algebra_index(input2)?; - + let expected = "\ + Filter: outer_table.a > Int32(1)\ + \n LeftSemi Join: Filter: outer_table.c != outer_table.b AS outer_b_alias AND inner_table_lv1.a = outer_table.a AND outer_table.a > inner_table_lv1.c AND outer_table.b = inner_table_lv1.b\ + \n TableScan: outer_table\ + \n Filter: inner_table_lv1.b = Int32(1)\ + \n TableScan: inner_table_lv1"; + assert_eq!(expected, format!("{new_plan}")); Ok(()) } #[test] From f26baf8fc723f24db6efa86be6341de1b7cd0b10 Mon Sep 17 00:00:00 2001 From: Duong Cong Toai Date: Sat, 10 May 2025 11:01:41 +0200 Subject: [PATCH 12/70] test: exist with no dependent column --- .../optimizer/src/decorrelate_general.rs | 88 ++++++++++++------- 1 file changed, 54 insertions(+), 34 deletions(-) diff --git a/datafusion/optimizer/src/decorrelate_general.rs b/datafusion/optimizer/src/decorrelate_general.rs index a8d52df533e3..331a23794705 100644 --- a/datafusion/optimizer/src/decorrelate_general.rs +++ b/datafusion/optimizer/src/decorrelate_general.rs @@ -240,18 +240,12 @@ fn try_transform_subquery_to_join_expr( // this is used for exist query let mut join_predicate = None; - if replace_columns.len() != 1 { - for expr in replace_columns { - println!("{}", expr) - } - return internal_err!("result of in subquery should only involve one column"); - } + let found_sq = expr.exists(|e| match e { Expr::InSubquery(isq) => { if replace_columns.len() != 1 { - println!("{:?}", replace_columns); return internal_err!( - "result of in subquery should only involve one column" + "result of IN subquery should only involve one column" ); } if isq.subquery == *sq { @@ -308,19 +302,21 @@ fn try_transform_subquery_to_join_expr( } return Ok(true); } - Expr::Exists(Exists { subquery, negated }) => 
{ - if let LogicalPlan::Subquery(inner_sq) = subquery.subquery.as_ref() { - if inner_sq.clone() == *sq { - let mark_predicate = if *negated { - expr_fn::not(col("mark")) - } else { - col("mark") - }; - join_predicate = Some(mark_predicate); - return Ok(true); - } + Expr::Exists(Exists { + subquery: inner_sq, + negated, + .. + }) => { + if inner_sq.clone() == *sq { + let mark_predicate = if *negated { + expr_fn::not(col("mark")) + } else { + col("mark") + }; + post_join_predicate = Some(mark_predicate); + return Ok(true); } - internal_err!("subquery field of Exists is not a subquery") + return Ok(false); } Expr::ScalarSubquery(ssq) => { unimplemented!( @@ -675,12 +671,16 @@ impl DependentJoinTracker { // left let mut builder = LogicalPlanBuilder::new(filter.input.deref().clone()); - builder = builder.join_on( - subquery_children, - // TODO: join type based on filter condition - join_type, - join_exprs, - )?; + builder = if join_exprs.is_empty() { + builder.join_on(subquery_children, join_type, vec![lit(true)])? + } else { + builder.join_on( + subquery_children, + // TODO: join type based on filter condition + join_type, + join_exprs, + )? + }; if kept_predicates.len() > 0 { builder = builder.filter(conjunction(kept_predicates).unwrap())? @@ -710,8 +710,6 @@ impl DependentJoinTracker { let simple_unnest_result = self.simple_decorrelation(node)?; let mut new_root = self.nodes.get(&node).unwrap().clone(); if new_root.access_tracker.len() == 0 { - println!("after rewriting================================"); - println!("{:?}", self); return self .build_join_from_simple_unnest(&mut new_root, simple_unnest_result); if parent.is_some() { @@ -723,8 +721,6 @@ impl DependentJoinTracker { unimplemented!() // return Ok(dependent_join); } - println!("after rewriting================================"); - println!("{:?}", self); if parent.is_some() { // i.e exists (where inner.col_a = outer_col.b and other_nested_subquery...) @@ -1185,6 +1181,7 @@ mod tests { use datafusion_common::{DFSchema, Result}; use datafusion_expr::{ + exists, expr_fn::{self, col, not}, in_subquery, lit, out_ref_col, scalar_subquery, table_scan, CreateMemoryTable, EmptyRelation, Expr, LogicalPlan, LogicalPlanBuilder, @@ -1200,9 +1197,34 @@ mod tests { datatypes::{DataType as ArrowDataType, Field, Fields, Schema}, }; #[test] - fn simple_decorrelate_with_in_subquery_no_dependent_column() -> Result<()> { - // let mut framework = GeneralDecorrelation::default(); + fn simple_decorrelate_with_exist_subquery_no_dependent_column() -> Result<()> { + let outer_table = test_table_scan_with_name("outer_table")?; + let inner_table_lv1 = test_table_scan_with_name("inner_table_lv1")?; + let sq_level1 = Arc::new( + LogicalPlanBuilder::from(inner_table_lv1) + .filter(col("inner_table_lv1.b").eq(lit(1)))? + .project(vec![col("inner_table_lv1.b"), col("inner_table_lv1.a")])? + .build()?, + ); + let input1 = LogicalPlanBuilder::from(outer_table.clone()) + .filter(col("outer_table.a").gt(lit(1)).and(exists(sq_level1)))? 
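+            // Informal reading of the assertion below: an uncorrelated EXISTS
+            // becomes a LeftMark join whose join filter is just `true`; the mark
+            // column records, per outer row, whether any inner row existed, and the
+            // original predicate is then evaluated afterwards as
+            // `outer_table.a > Int32(1) AND inner_table_lv1.mark`.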
+ .build()?; + let mut index = DependentJoinTracker::default(); + index.build(&input1)?; + let new_plan = index.root_dependent_join_elimination()?; + let expected = "\ + Filter: outer_table.a > Int32(1) AND inner_table_lv1.mark\ + \n LeftMark Join: Filter: Boolean(true)\ + \n TableScan: outer_table\ + \n Projection: inner_table_lv1.b, inner_table_lv1.a\ + \n Filter: inner_table_lv1.b = Int32(1)\ + \n TableScan: inner_table_lv1"; + assert_eq!(expected, format!("{new_plan}")); + Ok(()) + } + #[test] + fn simple_decorrelate_with_in_subquery_no_dependent_column() -> Result<()> { let outer_table = test_table_scan_with_name("outer_table")?; let inner_table_lv1 = test_table_scan_with_name("inner_table_lv1")?; let sq_level1 = Arc::new( @@ -1234,8 +1256,6 @@ mod tests { } #[test] fn simple_decorrelate_with_in_subquery_has_dependent_column() -> Result<()> { - // let mut framework = GeneralDecorrelation::default(); - let outer_table = test_table_scan_with_name("outer_table")?; let inner_table_lv1 = test_table_scan_with_name("inner_table_lv1")?; let sq_level1 = Arc::new( From 37852c1d556eb5893eebe007b2c6ba9c3118e03d Mon Sep 17 00:00:00 2001 From: Duong Cong Toai Date: Sat, 10 May 2025 11:46:30 +0200 Subject: [PATCH 13/70] test: exist with dependent columns --- .../optimizer/src/decorrelate_general.rs | 222 ++++-------------- 1 file changed, 49 insertions(+), 173 deletions(-) diff --git a/datafusion/optimizer/src/decorrelate_general.rs b/datafusion/optimizer/src/decorrelate_general.rs index 331a23794705..81137ca04de1 100644 --- a/datafusion/optimizer/src/decorrelate_general.rs +++ b/datafusion/optimizer/src/decorrelate_general.rs @@ -30,6 +30,7 @@ use crate::utils::has_all_column_refs; use crate::{ApplyOrder, OptimizerConfig, OptimizerRule}; use arrow::compute::kernels::cmp::eq; +use datafusion_common::alias::AliasGenerator; use datafusion_common::tree_node::{ Transformed, TransformedResult, TreeNode, TreeNodeContainer, TreeNodeRecursion, TreeNodeRewriter, TreeNodeVisitor, @@ -60,6 +61,7 @@ pub struct DependentJoinTracker { stack: Vec, // track for each column, the nodes/logical plan that reference to its within the tree accessed_columns: IndexMap>, + alias_generator: Arc, } #[derive(Debug, Hash, PartialEq, PartialOrd, Eq, Clone)] @@ -308,11 +310,7 @@ fn try_transform_subquery_to_join_expr( .. 
}) => { if inner_sq.clone() == *sq { - let mark_predicate = if *negated { - expr_fn::not(col("mark")) - } else { - col("mark") - }; + let mark_predicate = if *negated { !col("mark") } else { col("mark") }; post_join_predicate = Some(mark_predicate); return Ok(true); } @@ -997,10 +995,11 @@ impl DependentJoinTracker { } } -impl Default for DependentJoinTracker { - fn default() -> Self { +impl DependentJoinTracker { + fn new(alias_generator: Arc) -> Self { return DependentJoinTracker { root: None, + alias_generator, current_id: 0, nodes: IndexMap::new(), stack: vec![], @@ -1179,7 +1178,7 @@ impl OptimizerRule for DependentJoinTracker { mod tests { use std::sync::Arc; - use datafusion_common::{DFSchema, Result}; + use datafusion_common::{alias::AliasGenerator, DFSchema, Result}; use datafusion_expr::{ exists, expr_fn::{self, col, not}, @@ -1197,6 +1196,45 @@ mod tests { datatypes::{DataType as ArrowDataType, Field, Fields, Schema}, }; #[test] + fn simple_decorrelate_with_exist_subquery_with_dependent_columns() -> Result<()> { + let outer_table = test_table_scan_with_name("outer_table")?; + let inner_table_lv1 = test_table_scan_with_name("inner_table_lv1")?; + let sq_level1 = Arc::new( + LogicalPlanBuilder::from(inner_table_lv1) + .filter( + col("inner_table_lv1.a") + .eq(out_ref_col(ArrowDataType::UInt32, "outer_table.a")) + .and( + out_ref_col(ArrowDataType::UInt32, "outer_table.a") + .gt(col("inner_table_lv1.c")), + ) + .and(col("inner_table_lv1.b").eq(lit(1))) + .and( + out_ref_col(ArrowDataType::UInt32, "outer_table.b") + .eq(col("inner_table_lv1.b")), + ), + )? + .project(vec![out_ref_col(ArrowDataType::UInt32, "outer_table.b") + .alias("outer_b_alias")])? + .build()?, + ); + + let input1 = LogicalPlanBuilder::from(outer_table.clone()) + .filter(col("outer_table.a").gt(lit(1)).and(exists(sq_level1)))? + .build()?; + let mut index = DependentJoinTracker::new(Arc::new(AliasGenerator::new())); + index.build(&input1)?; + let new_plan = index.root_dependent_join_elimination()?; + let expected = "\ + Filter: outer_table.a > Int32(1) AND inner_table_lv1.mark\ + \n LeftMark Join: Filter: inner_table_lv1.a = outer_table.a AND outer_table.a > inner_table_lv1.c AND outer_table.b = inner_table_lv1.b\ + \n TableScan: outer_table\ + \n Filter: inner_table_lv1.b = Int32(1)\ + \n TableScan: inner_table_lv1"; + assert_eq!(expected, format!("{new_plan}")); + Ok(()) + } + #[test] fn simple_decorrelate_with_exist_subquery_no_dependent_column() -> Result<()> { let outer_table = test_table_scan_with_name("outer_table")?; let inner_table_lv1 = test_table_scan_with_name("inner_table_lv1")?; @@ -1210,7 +1248,7 @@ mod tests { let input1 = LogicalPlanBuilder::from(outer_table.clone()) .filter(col("outer_table.a").gt(lit(1)).and(exists(sq_level1)))? .build()?; - let mut index = DependentJoinTracker::default(); + let mut index = DependentJoinTracker::new(Arc::new(AliasGenerator::new())); index.build(&input1)?; let new_plan = index.root_dependent_join_elimination()?; let expected = "\ @@ -1241,7 +1279,7 @@ mod tests { .and(in_subquery(col("outer_table.c"), sq_level1)), )? .build()?; - let mut index = DependentJoinTracker::default(); + let mut index = DependentJoinTracker::new(Arc::new(AliasGenerator::new())); index.build(&input1)?; let new_plan = index.root_dependent_join_elimination()?; let expected = "\ @@ -1285,7 +1323,7 @@ mod tests { .and(in_subquery(col("outer_table.c"), sq_level1)), )? 
.build()?; - let mut index = DependentJoinTracker::default(); + let mut index = DependentJoinTracker::new(Arc::new(AliasGenerator::new())); index.build(&input1)?; let new_plan = index.root_dependent_join_elimination()?; let expected = "\ @@ -1297,166 +1335,4 @@ mod tests { assert_eq!(expected, format!("{new_plan}")); Ok(()) } - #[test] - fn play_unnest_simple_predicate_pull_up() -> Result<()> { - // let mut framework = GeneralDecorrelation::default(); - - let outer_table = test_table_scan_with_name("outer_table")?; - let inner_table_lv1 = test_table_scan_with_name("inner_table_lv1")?; - // let inner_table_lv2 = test_table_scan_with_name("inner_table_lv2")?; - // let sq_level2 = Arc::new( - // LogicalPlanBuilder::from(inner_table_lv2) - // .filter( - // out_ref_col(ArrowDataType::UInt32, "inner_table_lv1.b") - // .eq(col("inner_table_lv2.b")) - // .and( - // out_ref_col(ArrowDataType::UInt32, "outer_table.c") - // .eq(col("inner_table_lv2.c")), - // ), - // )? - // .aggregate(Vec::::new(), vec![count(col("inner_table_lv2.a"))])? - // .build()?, - // ); - let sq_level1 = Arc::new( - LogicalPlanBuilder::from(inner_table_lv1) - .filter( - col("inner_table_lv1.a") - .eq(out_ref_col(ArrowDataType::UInt32, "outer_table.a")) - .and( - out_ref_col(ArrowDataType::UInt32, "outer_table.a") - .eq(lit(1)), - ), - )? - .aggregate(Vec::::new(), vec![sum(col("inner_table_lv1.b"))])? - .project(vec![sum(col("inner_table_lv1.b"))])? - .build()?, - ); - - let input1 = LogicalPlanBuilder::from(outer_table.clone()) - .filter( - col("outer_table.a") - .gt(lit(1)) - .and(col("outer_table.b").gt(scalar_subquery(sq_level1))), - )? - .build()?; - let mut index = DependentJoinTracker::default(); - index.build(&input1)?; - let new_plan = index.root_dependent_join_elimination()?; - println!("{}", new_plan); - - // let input2 = LogicalPlanBuilder::from(input.clone()) - // .filter(col("int_col").gt(lit(1)))? - // .project(vec![col("string_col")])? - // .build()?; - - // let mut b = GeneralDecorrelation::default(); - // b.build_algebra_index(input2)?; - - Ok(()) - } - #[test] - fn play_unnest() -> Result<()> { - // let mut framework = GeneralDecorrelation::default(); - - let outer_table = test_table_scan_with_name("outer_table")?; - let inner_table_lv1 = test_table_scan_with_name("inner_table_lv1")?; - // let inner_table_lv2 = test_table_scan_with_name("inner_table_lv2")?; - // let sq_level2 = Arc::new( - // LogicalPlanBuilder::from(inner_table_lv2) - // .filter( - // out_ref_col(ArrowDataType::UInt32, "inner_table_lv1.b") - // .eq(col("inner_table_lv2.b")) - // .and( - // out_ref_col(ArrowDataType::UInt32, "outer_table.c") - // .eq(col("inner_table_lv2.c")), - // ), - // )? - // .aggregate(Vec::::new(), vec![count(col("inner_table_lv2.a"))])? - // .build()?, - // ); - let sq_level1 = Arc::new( - LogicalPlanBuilder::from(inner_table_lv1) - .filter( - col("inner_table_lv1.a") - .eq(out_ref_col(ArrowDataType::UInt32, "outer_table.a")), - )? - .aggregate(Vec::::new(), vec![sum(col("inner_table_lv1.b"))])? - .project(vec![sum(col("inner_table_lv1.b"))])? - .build()?, - ); - - let input1 = LogicalPlanBuilder::from(outer_table.clone()) - .filter( - col("outer_table.a") - .gt(lit(1)) - .and(col("outer_table.b").gt(scalar_subquery(sq_level1))), - )? - .build()?; - let mut index = DependentJoinTracker::default(); - index.build(&input1)?; - let new_plan = index.root_dependent_join_elimination()?; - println!("{}", new_plan); - - // let input2 = LogicalPlanBuilder::from(input.clone()) - // .filter(col("int_col").gt(lit(1)))? 
-        //     .project(vec![col("string_col")])?
-        //     .build()?;
-
-        // let mut b = GeneralDecorrelation::default();
-        // b.build_algebra_index(input2)?;
-
-        Ok(())
-    }
-
-    // #[test]
-    // fn todo() -> Result<()> {
-    //     let mut framework = GeneralDecorrelation::default();
-
-    //     let outer_table = test_table_scan_with_name("outer_table")?;
-    //     let inner_table_lv1 = test_table_scan_with_name("inner_table_lv1")?;
-    //     let inner_table_lv2 = test_table_scan_with_name("inner_table_lv2")?;
-    //     let sq_level2 = Arc::new(
-    //         LogicalPlanBuilder::from(inner_table_lv2)
-    //             .filter(
-    //                 out_ref_col(ArrowDataType::UInt32, "inner_table_lv1.b")
-    //                     .eq(col("inner_table_lv2.b"))
-    //                     .and(
-    //                         out_ref_col(ArrowDataType::UInt32, "outer_table.c")
-    //                             .eq(col("inner_table_lv2.c")),
-    //                     ),
-    //             )?
-    //             .aggregate(Vec::<Expr>::new(), vec![count(col("inner_table_lv2.a"))])?
-    //             .build()?,
-    //     );
-    //     let sq_level1 = Arc::new(
-    //         LogicalPlanBuilder::from(inner_table_lv1)
-    //             .filter(
-    //                 col("inner_table_lv1.a")
-    //                     .eq(out_ref_col(ArrowDataType::UInt32, "outer_table.a"))
-    //                     .and(scalar_subquery(sq_level2).gt(lit(5))),
-    //             )?
-    //             .aggregate(Vec::<Expr>::new(), vec![sum(col("inner_table_lv1.b"))])?
-    //             .project(vec![sum(col("inner_table_lv1.b"))])?
-    //             .build()?,
-    //     );
-
-    //     let input1 = LogicalPlanBuilder::from(outer_table.clone())
-    //         .filter(
-    //             col("outer_table.a")
-    //                 .gt(lit(1))
-    //                 .and(col("outer_table.b").gt(scalar_subquery(sq_level1))),
-    //         )?
-    //         .build()?;
-    //     framework.build(&input1)?;
-
-    //     // let input2 = LogicalPlanBuilder::from(input.clone())
-    //     //     .filter(col("int_col").gt(lit(1)))?
-    //     //     .project(vec![col("string_col")])?
-    //     //     .build()?;
-
-    //     // let mut b = GeneralDecorrelation::default();
-    //     // b.build_algebra_index(input2)?;
-
-    //     Ok(())
-    // }
 }

From e984a55b2f711e7b5974b6eaa13a953e3239a056 Mon Sep 17 00:00:00 2001
From: Duong Cong Toai
Date: Sun, 11 May 2025 13:23:57 +0200
Subject: [PATCH 14/70] chore: remove redundant clone

---
 .../optimizer/src/decorrelate_general.rs      | 95 ++++++++++++++-----
 1 file changed, 69 insertions(+), 26 deletions(-)

diff --git a/datafusion/optimizer/src/decorrelate_general.rs b/datafusion/optimizer/src/decorrelate_general.rs
index 81137ca04de1..845b19df5518 100644
--- a/datafusion/optimizer/src/decorrelate_general.rs
+++ b/datafusion/optimizer/src/decorrelate_general.rs
@@ -41,8 +41,8 @@ use datafusion_expr::expr_rewriter::strip_outer_reference;
 use datafusion_expr::select_expr::SelectExpr;
 use datafusion_expr::utils::{conjunction, split_conjunction};
 use datafusion_expr::{
-    binary_expr, col, expr_fn, lit, BinaryExpr, Cast, Expr, JoinType, LogicalPlan,
-    LogicalPlanBuilder, Operator as ExprOperator, Subquery,
+    binary_expr, col, expr_fn, lit, Aggregate, BinaryExpr, Cast, Expr, JoinType,
+    LogicalPlan, LogicalPlanBuilder, Operator as ExprOperator, Subquery,
 };
 use datafusion_sql::unparser::Unparser;
 use indexmap::map::Entry;
@@ -350,8 +350,12 @@ impl DependentJoinTracker {
             _ => false,
         }
     }
-    fn is_linear_path(&self, parent: &usize, child: &usize) -> bool {
-        let mut current_node = *child;
+    fn is_linear_path(&self, parent: &Operator, child: &Operator) -> bool {
+        if !self.is_linear_operator(&child.plan) {
+            return false;
+        }
+
+        let mut current_node = child.parent.unwrap();
         loop {
             let child_node = self.nodes.get(&current_node).unwrap();

@@ -361,16 +365,13 @@
                     unimplemented!("traversing from descendant to top does not meet expected root")
                 }
                 Some(new_parent) => {
-                    if new_parent == *parent {
+                    if new_parent == parent.id {
return true; } return false; } } } - if current_node == *parent { - return true; - } match child_node.parent { None => return true, Some(new_parent) => { @@ -399,7 +400,7 @@ impl DependentJoinTracker { // unnest children first // println!("decorrelating {} from {}", child, root); - if !self.is_linear_path(&root_node.id, &child_node.id) { + if !self.is_linear_path(root_node, child_node) { // TODO: return Ok(()); } @@ -698,27 +699,27 @@ impl DependentJoinTracker { outer_refs_from_parent: IndexSet, ) -> Result { let parent = unnesting.parent.clone(); - let operator = self.nodes.get(&node).unwrap(); - let plan = &operator.plan; - let mut join = self.new_dependent_join(operator); + let mut root_node = self.nodes.swap_remove(&node).unwrap(); + // let plan = &root_node.plan; // we have to do the reversed iter, because we know the subquery (right side of // the dependent join) is always the first child of the node, and we want to visit // the left side first - let simple_unnest_result = self.simple_decorrelation(node)?; - let mut new_root = self.nodes.get(&node).unwrap().clone(); - if new_root.access_tracker.len() == 0 { - return self - .build_join_from_simple_unnest(&mut new_root, simple_unnest_result); + let simple_unnest_result = self.simple_decorrelation(&mut root_node)?; + if root_node.access_tracker.is_empty() { if parent.is_some() { // for each projection of outer column moved up by simple_decorrelation // replace them with the expr store inside parent.replaces - unimplemented!(""); + unimplemented!("simple dependent join not implemented for the case of recursive subquery"); return self.unnest(node, &mut parent.unwrap(), outer_refs_from_parent); } + return self + .build_join_from_simple_unnest(&mut root_node, simple_unnest_result); unimplemented!() // return Ok(dependent_join); } + + let mut join = self.new_dependent_join(&root_node); if parent.is_some() { // i.e exists (where inner.col_a = outer_col.b and other_nested_subquery...) 
@@ -797,10 +798,8 @@ impl DependentJoinTracker { fn simple_decorrelation( &mut self, - node_id: usize, + node: &mut Operator, ) -> Result { - let node = self.get_node_uncheck(&node_id); - let mut all_eliminated = false; let mut result = SimpleDecorrelationResult { // new: None, pulled_up_projections: IndexSet::new(), @@ -817,16 +816,16 @@ impl DependentJoinTracker { for col_access in accesses_bottom_up { // create two copy because of - let mut parent_node = self.get_node_uncheck(&node_id); - let mut descendent = self.get_node_uncheck(&col_access.node_id); + // let mut descendent = self.get_node_uncheck(&col_access.node_id); + let mut descendent = self.nodes.swap_remove(&col_access.node_id).unwrap(); self.try_simple_decorrelate_descendent( - &mut parent_node, + node, &mut descendent, &col_access, &mut result, )?; // TODO: find a nicer way to do in-place update - self.nodes.insert(node_id, parent_node.clone()); + // self.nodes.insert(node_id, parent_node.clone()); self.nodes.insert(col_access.node_id, descendent); } @@ -1115,6 +1114,7 @@ impl TreeNodeVisitor<'_> for DependentJoinTracker { is_subquery_node = true; // TODO: once we detect the subquery } + LogicalPlan::Aggregate(_) => {} _ => { return internal_err!("impl scan for node type {:?}", node); } @@ -1196,6 +1196,49 @@ mod tests { datatypes::{DataType as ArrowDataType, Field, Fields, Schema}, }; #[test] + fn complex_1_level_decorrelate_in_subquery_with_count() -> Result<()> { + let outer_table = test_table_scan_with_name("outer_table")?; + let inner_table_lv1 = test_table_scan_with_name("inner_table_lv1")?; + let sq_level1 = Arc::new( + LogicalPlanBuilder::from(inner_table_lv1) + .filter( + col("inner_table_lv1.a") + .eq(out_ref_col(ArrowDataType::UInt32, "outer_table.a")) + .and( + out_ref_col(ArrowDataType::UInt32, "outer_table.a") + .gt(col("inner_table_lv1.c")), + ) + .and(col("inner_table_lv1.b").eq(lit(1))) + .and( + out_ref_col(ArrowDataType::UInt32, "outer_table.b") + .eq(col("inner_table_lv1.b")), + ), + )? + .aggregate(Vec::::new(), vec![count(col("inner_table_lv1.a"))])? + .project(vec![count(col("inner_table_lv1.a")).alias("count_a")])? + .build()?, + ); + + let input1 = LogicalPlanBuilder::from(outer_table.clone()) + .filter( + col("outer_table.a") + .gt(lit(1)) + .and(in_subquery(col("outer_table.c"), sq_level1)), + )? 
+ .build()?; + let mut index = DependentJoinTracker::new(Arc::new(AliasGenerator::new())); + index.build(&input1)?; + let new_plan = index.root_dependent_join_elimination()?; + let expected = "\ + Filter: outer_table.a > Int32(1) AND inner_table_lv1.mark\ + \n LeftMark Join: Filter: inner_table_lv1.a = outer_table.a AND outer_table.a > inner_table_lv1.c AND outer_table.b = inner_table_lv1.b\ + \n TableScan: outer_table\ + \n Filter: inner_table_lv1.b = Int32(1)\ + \n TableScan: inner_table_lv1"; + assert_eq!(expected, format!("{new_plan}")); + Ok(()) + } + #[test] fn simple_decorrelate_with_exist_subquery_with_dependent_columns() -> Result<()> { let outer_table = test_table_scan_with_name("outer_table")?; let inner_table_lv1 = test_table_scan_with_name("inner_table_lv1")?; @@ -1328,7 +1371,7 @@ mod tests { let new_plan = index.root_dependent_join_elimination()?; let expected = "\ Filter: outer_table.a > Int32(1)\ - \n LeftSemi Join: Filter: outer_table.c != outer_table.b AS outer_b_alias AND inner_table_lv1.a = outer_table.a AND outer_table.a > inner_table_lv1.c AND outer_table.b = inner_table_lv1.b\ + \n LeftSemi Join: Filter: outer_table.c = outer_table.b AS outer_b_alias AND inner_table_lv1.a = outer_table.a AND outer_table.a > inner_table_lv1.c AND outer_table.b = inner_table_lv1.b\ \n TableScan: outer_table\ \n Filter: inner_table_lv1.b = Int32(1)\ \n TableScan: inner_table_lv1"; From 94aba08cbcfc255896d284811c30162cecd74d60 Mon Sep 17 00:00:00 2001 From: Duong Cong Toai Date: Tue, 13 May 2025 20:28:21 +0200 Subject: [PATCH 15/70] feat: dummy implementation for aggregation --- .../optimizer/src/decorrelate_general.rs | 450 +++++++++++++----- 1 file changed, 324 insertions(+), 126 deletions(-) diff --git a/datafusion/optimizer/src/decorrelate_general.rs b/datafusion/optimizer/src/decorrelate_general.rs index 845b19df5518..f74e989cd7be 100644 --- a/datafusion/optimizer/src/decorrelate_general.rs +++ b/datafusion/optimizer/src/decorrelate_general.rs @@ -25,26 +25,34 @@ use std::ops::Deref; use std::rc::{Rc, Weak}; use std::sync::Arc; +use crate::decorrelate::UN_MATCHED_ROW_INDICATOR; use crate::simplify_expressions::{ExprSimplifier, SimplifyExpressions}; use crate::utils::has_all_column_refs; use crate::{ApplyOrder, OptimizerConfig, OptimizerRule}; use arrow::compute::kernels::cmp::eq; +use arrow::datatypes::Schema; use datafusion_common::alias::AliasGenerator; use datafusion_common::tree_node::{ Transformed, TransformedResult, TreeNode, TreeNodeContainer, TreeNodeRecursion, TreeNodeRewriter, TreeNodeVisitor, }; -use datafusion_common::{internal_err, not_impl_err, Column, Result}; +use datafusion_common::{ + internal_err, not_impl_err, Column, DFSchemaRef, HashMap, Result, +}; use datafusion_expr::expr::Exists; use datafusion_expr::expr_rewriter::strip_outer_reference; use datafusion_expr::select_expr::SelectExpr; -use datafusion_expr::utils::{conjunction, split_conjunction}; +use datafusion_expr::utils::{ + conjunction, disjunction, split_conjunction, split_disjunction, +}; use datafusion_expr::{ - binary_expr, col, expr_fn, lit, Aggregate, BinaryExpr, Cast, Expr, JoinType, - LogicalPlan, LogicalPlanBuilder, Operator as ExprOperator, Subquery, + binary_expr, col, expr_fn, lit, table_scan, Aggregate, BinaryExpr, Cast, Expr, + JoinType, LogicalPlan, LogicalPlanBuilder, Operator as ExprOperator, Subquery, }; +use datafusion_functions_aggregate::count; use datafusion_sql::unparser::Unparser; +use datafusion_sql::TableReference; use indexmap::map::Entry; use indexmap::{IndexMap, IndexSet}; 
use itertools::Itertools; @@ -144,53 +152,48 @@ struct DependentJoin { join_conditions: Vec, // join_type: } -impl DependentJoin { - fn replace_right( - &mut self, - plan: LogicalPlan, - unnesting: &UnnestingInfo, - replacements: &IndexMap, - ) { - self.right.plan = plan; - for col in unnesting.outer_refs.iter() { - let replacement = replacements.get(col).unwrap(); - self.join_conditions.push(binary_expr( - Expr::Column(col.clone()), - ExprOperator::IsNotDistinctFrom, - Expr::Column(replacement.clone()), - )); - } - } - fn replace_left( - &mut self, - plan: LogicalPlan, - column_replacements: &IndexMap, - ) { - self.left.plan = plan - // TODO: - // - update join condition - // - check if the relation with children should be removed - } -} +impl DependentJoin {} #[derive(Clone)] struct UnnestingInfo { - join: DependentJoin, - outer_refs: Vec, - domain: Vec, + // join: DependentJoin, + domain: LogicalPlan, parent: Option, } #[derive(Clone)] struct Unnesting { info: Arc, // cclasses: union find data structure of equivalent columns equivalences: UnionFind, - replaces: IndexMap, + + // for each outer exprs on the left, the set of exprs + // on the right required pulling up for the join condition to happen + // i.e select * from t1 where t1.col1 = ( + // select count(*) from t2 where t2.col1 > t1.col2 + t2.col2 or t1.col3 = t1.col2 or t1.col4=2 and t1.col3=1) + // we do this by split the complex expr into conjuctive sets + // for each of such set, if there exists any or binary operator + // we substitute the whole binary operator as true and add every expr appearing in the or condition + // to grouped_by + // and push every + pulled_up_columns: Vec, + //these predicates are disjunctive (combined by `Or` operator) + pulled_up_predicates: Vec, // mapping from outer ref column to new column, if any // i.e in some subquery ( // ... 
where outer.column_c=inner.column_a // ) // and through union find we have outer.column_c = some_other_expr // we can substitute the inner query with inner.column_a=some_other_expr + replaces: IndexMap, + + join_conditions: Vec, +} +impl Unnesting { + fn get_replaced_col(&self, col: &Column) -> Column { + match self.replaces.get(col) { + Some(col) => col.clone(), + None => col.clone(), + } + } } // TODO: looks like this function can be improved to allow more expr pull up @@ -339,6 +342,14 @@ fn try_transform_subquery_to_join_expr( // }; // } // } +struct GeneralDecorrelationResult { + // i.e for aggregation, dependent columns are added to the projection for joining + added_columns: Vec, + // the reason is, unnesting group by happen at lower nodes, + // but the filtering (if any) of such expr may happen higher node + // (because of known count_bug) + count_expr_map: HashSet, +} impl DependentJoinTracker { fn is_linear_operator(&self, plan: &LogicalPlan) -> bool { match plan { @@ -509,12 +520,159 @@ impl DependentJoinTracker { Ok(()) } - fn unnest( + fn general_decorrelate( &mut self, - node_id: usize, + node: &mut Operator, unnesting: &mut Unnesting, - outer_refs_from_parent: IndexSet, - ) -> Result { + outer_refs_from_parent: &mut IndexSet, + ) -> Result<()> { + if node.is_dependent_join_node { + unimplemented!("recursive unnest not implemented yet") + } + + match &mut node.plan { + LogicalPlan::Subquery(sq) => { + let next_node = node.children.first().unwrap(); + let mut only_child = self.nodes.swap_remove(next_node).unwrap(); + self.general_decorrelate( + &mut only_child, + unnesting, + outer_refs_from_parent, + )?; + *node = only_child; + return Ok(()); + } + LogicalPlan::Aggregate(agg) => { + let is_static = agg.group_expr.is_empty(); // TODO: grouping set also needs to check is_static + let next_node = node.children.first().unwrap(); + let mut only_child = self.nodes.swap_remove(next_node).unwrap(); + self.general_decorrelate( + &mut only_child, + unnesting, + outer_refs_from_parent, + )?; + agg.input = Arc::new(only_child.plan.clone()); + self.nodes.insert(*next_node, only_child); + + Self::rewrite_columns(agg.group_expr.iter_mut(), unnesting)?; + for col in unnesting.pulled_up_columns.iter() { + let replaced_col = unnesting.get_replaced_col(col); + agg.group_expr.push(Expr::Column(replaced_col.clone())); + } + + let need_handle_count_bug = true; + if need_handle_count_bug { + let un_matched_row = lit(true).alias(UN_MATCHED_ROW_INDICATOR); + agg.group_expr.push(un_matched_row.clone()); + // unnesting.pulled_up_predicates.push(value); + } + + if is_static { + let join_condition = unnesting + .pulled_up_predicates + .iter() + .map(|e| strip_outer_reference(e.clone())); + // Building the Domain to join with the group by + // TODO: maybe the construction of domain can happen somewhere else + let new_plan = LogicalPlanBuilder::new(unnesting.info.domain.clone()) + .join_detailed( + node.plan.clone(), + JoinType::Left, + (Vec::::new(), Vec::::new()), + disjunction(join_condition), + true, + )? 
+ .build()?; + println!("{}", new_plan); + node.plan = new_plan; + // self.remove_node(parent, node); + + // 01)Projection: t1.t1_id, CASE WHEN __scalar_sq_1.__always_true IS NULL THEN Int64(0) ELSE __scalar_sq_1.count(*) END AS count(*) + // TODO: how domain projection work + // left = select distinct domain + // right = new group by + // if there exists count in the group by, the projection set should be something like + // + // 01)Projection: t1.t1_id, CASE WHEN __scalar_sq_1.__always_true IS NULL THEN Int64(0) ELSE __scalar_sq_1.count(*) END AS count(*) + // 02)--Left Join: t1.t1_int = __scalar_sq_1.t2_int + } else { + unimplemented!("non static aggregation sq decorrelation not implemented, i.e exists sq with count") + } + } + LogicalPlan::Filter(filter) => { + let disjunctions: Vec = split_disjunction(&filter.predicate) + .into_iter() + .cloned() + .collect(); + let mut remained_expr = vec![]; + // TODO: the paper mention there are 2 approaches to remove these dependent predicate + // - substitute the outer ref columns and push them to the parent node (i.e add them to aggregation node) + // - perform a join with domain directly here + // for now we only implement with the approach substituting + + let mut pulled_up_columns = IndexSet::new(); + for expr in disjunctions.iter() { + if !expr.contains_outer() { + remained_expr.push(expr.clone()); + continue; + } + // extract all columns mentioned in this expr + // and push them up the dependent join + + unnesting.pulled_up_predicates.push(expr.clone()); + expr.clone().map_children(|e| { + if let Expr::Column(ref col) = e { + pulled_up_columns.insert(col.clone()); + } + Ok(Transformed::no(e)) + })?; + } + filter.predicate = match disjunction(remained_expr) { + Some(expr) => expr, + None => lit(true), + }; + unnesting.pulled_up_columns.extend(pulled_up_columns); + outer_refs_from_parent.retain(|ac| ac.node_id != node.id); + if !outer_refs_from_parent.is_empty() { + let next_node = node.children.first().unwrap(); + let mut only_child = self.nodes.swap_remove(next_node).unwrap(); + self.general_decorrelate( + &mut only_child, + unnesting, + outer_refs_from_parent, + )?; + self.nodes.insert(*next_node, only_child); + } + // TODO: add equivalences from select.predicate to info.cclasses + Self::rewrite_columns(vec![&mut filter.predicate].into_iter(), unnesting); + return Ok(()); + } + LogicalPlan::Projection(proj) => { + let next_node = node.children.first().unwrap(); + let mut only_child = self.nodes.swap_remove(next_node).unwrap(); + // TODO: if the children of this node was added with some extra column (i.e) + // aggregation + group by dependent_column + // the projection exprs must also include these new expr + self.general_decorrelate( + &mut only_child, + unnesting, + outer_refs_from_parent, + )?; + + self.nodes.insert(*next_node, only_child); + proj.expr.extend( + unnesting + .pulled_up_columns + .iter() + .map(|c| Expr::Column(c.clone())), + ); + Self::rewrite_columns(proj.expr.iter_mut(), unnesting); + return Ok(()); + } + _ => { + unimplemented!() + } + }; unimplemented!() // if unnesting.info.parent.is_some() { // not_impl_err!("impl me") @@ -528,17 +686,17 @@ impl DependentJoinTracker { // } // Ok(()) } - fn right(&self, node: &Operator) -> &Operator { + fn right_owned(&mut self, node: &Operator) -> Operator { assert_eq!(2, node.children.len()); // during the building of the tree, the subquery (right node) is always traversed first let node_id = node.children.first().unwrap(); - return self.nodes.get(node_id).unwrap(); + return 
self.nodes.swap_remove(node_id).unwrap(); } - fn left(&self, node: &Operator) -> &Operator { + fn left_owned(&mut self, node: &Operator) -> Operator { assert_eq!(2, node.children.len()); // during the building of the tree, the subquery (right node) is always traversed first let node_id = node.children.get(1).unwrap(); - return self.nodes.get(node_id).unwrap(); + return self.nodes.swap_remove(node_id).unwrap(); } fn root_dependent_join_elimination(&mut self) -> Result { let root = self.root.unwrap(); @@ -548,38 +706,34 @@ impl DependentJoinTracker { node.is_dependent_join_node, "need to handle the case root node is not dependent join node" ); + let unnesting_info = UnnestingInfo { parent: None, - join: DependentJoin { - original_expr: node.plan.clone(), - left: self.left(node).clone(), - right: self.right(node).clone(), - join_conditions: vec![], - }, - domain: vec![], - outer_refs: vec![], + domain: node.plan.clone(), // dummy }; + + let mut outer_refs = node.access_tracker.clone(); // let unnesting = Unnesting { // info: Arc::new(unnesting), // equivalences: UnionFind::new(), // replaces: IndexMap::new(), // }; - self.dependent_join_elimination(node.id, &unnesting_info, IndexSet::new()) + self.dependent_join_elimination(node.id, &unnesting_info, &mut IndexSet::new()) } fn column_accesses(&self, node_id: usize) -> Vec<&ColumnAccess> { let node = self.nodes.get(&node_id).unwrap(); node.access_tracker.iter().collect() } - fn new_dependent_join(&self, node: &Operator) -> DependentJoin { - DependentJoin { - original_expr: node.plan.clone(), - left: self.left(node).clone(), - right: self.left(node).clone(), - join_conditions: vec![], - } - } + // fn new_dependent_join(&self, node: &Operator) -> DependentJoin { + // DependentJoin { + // original_expr: node.plan.clone(), + // left: self.left(node).clone(), + // right: self.right(node).clone(), + // join_conditions: vec![], + // } + // } fn get_subquery_children( &self, parent: &Operator, @@ -692,26 +846,47 @@ impl DependentJoinTracker { } } + fn build_domain(&self, node: &Operator, left: &Operator) -> Result { + let unique_outer_refs: Vec = node + .access_tracker + .iter() + .map(|c| c.col.clone()) + .unique() + .collect(); + + // TODO: handle this correctly. + // the direct left child of root is not always the table scan node + // and there are many more table providing logical plan + let initial_domain = LogicalPlanBuilder::new(left.plan.clone()) + .project( + unique_outer_refs + .iter() + .map(|col| SelectExpr::Expression(Expr::Column(col.clone()))), + )? 
+ .build()?; + return Ok(initial_domain); + } + fn dependent_join_elimination( &mut self, node: usize, unnesting: &UnnestingInfo, - outer_refs_from_parent: IndexSet, + outer_refs_from_parent: &mut IndexSet, ) -> Result { let parent = unnesting.parent.clone(); let mut root_node = self.nodes.swap_remove(&node).unwrap(); - // let plan = &root_node.plan; - // we have to do the reversed iter, because we know the subquery (right side of - // the dependent join) is always the first child of the node, and we want to visit - // the left side first - let simple_unnest_result = self.simple_decorrelation(&mut root_node)?; if root_node.access_tracker.is_empty() { if parent.is_some() { // for each projection of outer column moved up by simple_decorrelation // replace them with the expr store inside parent.replaces unimplemented!("simple dependent join not implemented for the case of recursive subquery"); - return self.unnest(node, &mut parent.unwrap(), outer_refs_from_parent); + self.general_decorrelate( + &mut root_node, + &mut parent.unwrap(), + outer_refs_from_parent, + )?; + return Ok(root_node.plan.clone()); } return self .build_join_from_simple_unnest(&mut root_node, simple_unnest_result); @@ -719,54 +894,69 @@ impl DependentJoinTracker { // return Ok(dependent_join); } - let mut join = self.new_dependent_join(&root_node); + // let mut join = self.new_dependent_join(&root_node); + let mut left = self.left_owned(&root_node); + let mut right = self.right_owned(&root_node); if parent.is_some() { + unimplemented!(""); // i.e exists (where inner.col_a = outer_col.b and other_nested_subquery...) let mut outer_ref_from_left = IndexSet::new(); - let left = join.left.clone(); + // let left = join.left.clone(); for col_from_parent in outer_refs_from_parent.iter() { if left .plan .all_out_ref_exprs() - .contains(&Expr::Column(col_from_parent.clone())) + .contains(&Expr::Column(col_from_parent.col)) { outer_ref_from_left.insert(col_from_parent.clone()); } } let mut parent_unnesting = parent.clone().unwrap(); - let new_left = - self.unnest(left.id, &mut parent_unnesting, outer_ref_from_left)?; - join.replace_left(new_left, &parent_unnesting.replaces); + self.general_decorrelate( + &mut left, + &mut parent_unnesting, + &mut outer_ref_from_left, + )?; + // join.replace_left(new_left, &parent_unnesting.replaces); // TODO: after imple simple_decorrelation, rewrite the projection pushed up column as well } + let domain = match parent { + None => self.build_domain(&root_node, &left)?, + Some(info) => { + unimplemented!() + } + }; + let new_unnesting_info = UnnestingInfo { parent: parent.clone(), - join: join.clone(), - domain: vec![], // TODO: populate me - outer_refs: vec![], // TODO: populate me + domain, + // join: join.clone(), + // domain: vec![], // TODO: populate me }; let mut unnesting = Unnesting { info: Arc::new(new_unnesting_info.clone()), + join_conditions: vec![], equivalences: UnionFind { parent: IndexMap::new(), rank: IndexMap::new(), }, replaces: IndexMap::new(), + pulled_up_columns: vec![], + pulled_up_predicates: vec![], + // outer_col_ref_map: HashMap::new(), }; - let mut accesses: IndexSet = self - .column_accesses(node) - .iter() - .map(|a| a.col.clone()) - .collect(); + let mut accesses: IndexSet = root_node.access_tracker.clone(); + // .iter() + // .map(|a| a.col.clone()) + // .collect(); if parent.is_some() { - for col_access in outer_refs_from_parent { - if join - .right + for col_access in outer_refs_from_parent.iter() { + if right .plan .all_out_ref_exprs() - 
.contains(&Expr::Column(col_access.clone()))
+                .contains(&Expr::Column(col_access.col.clone()))
             {
                 accesses.insert(col_access.clone());
             }
@@ -774,23 +964,46 @@ impl DependentJoinTracker {
             // add equivalences from join.condition to unnest.cclasses
         }

-        let new_right = self.unnest(join.right.id, &mut unnesting, accesses)?;
-        join.replace_right(new_right, &new_unnesting_info, &unnesting.replaces);
+        //TODO: add equivalences from join.condition to unnest.cclasses
+        self.general_decorrelate(&mut right, &mut unnesting, &mut accesses)?;
+        println!("temporary transformed result {:?}", self);
+        unimplemented!("implement replacing right node");
+        // join.replace_right(new_right, &new_unnesting_info, &unnesting.replaces);
         // for acc in new_unnesting_info.outer_refs{
         //     join.join_conditions.append(other);
         // }
-
-        unimplemented!()
     }
-    fn rewrite_columns(expr: Expr, unnesting: Unnesting) {
-        unimplemented!()
-        // expr.apply(|expr| {
-        //     if let Expr::OuterReferenceColumn(_, col) = expr {
-        //         set.insert(col);
-        //     }
-        //     Ok(TreeNodeRecursion::Continue)
-        // })
-        // .expect("traversal is infallible");
+    fn rewrite_columns<'a>(
+        exprs: impl Iterator<Item = &'a mut Expr>,
+        unnesting: &Unnesting,
+    ) -> Result<()> {
+        for expr in exprs {
+            *expr = expr
+                .clone()
+                .transform(|e| {
+                    match &e {
+                        Expr::Column(col) => {
+                            if let Some(replaced_by) = unnesting.replaces.get(col) {
+                                return Ok(Transformed::yes(Expr::Column(
+                                    replaced_by.clone(),
+                                )));
+                            }
+                        }
+                        Expr::OuterReferenceColumn(_, col) => {
+                            if let Some(replaced_by) = unnesting.replaces.get(col) {
+                                // TODO: not sure if we should use column or outer ref column here
+                                return Ok(Transformed::yes(Expr::Column(
+                                    replaced_by.clone(),
+                                )));
+                            }
+                        }
+                        _ => {}
+                    };
+                    Ok(Transformed::no(e))
+                })?
+                .data;
+        }
+        Ok(())
     }
     fn get_node_uncheck(&self, node_id: &usize) -> Operator {
         self.nodes.get(node_id).unwrap().clone()
@@ -806,6 +1019,8 @@ impl DependentJoinTracker {
             pulled_up_predicates: vec![],
         };

+        // the iteration should happen with the order of bottom up, so any node push up won't
+        // affect its children (by accident)
         let accesses_bottom_up = node.access_tracker.clone().sorted_by(|a, b| {
             if a.node_id < b.node_id {
                 Ordering::Greater
@@ -831,11 +1046,6 @@ impl DependentJoinTracker {
         Ok(result)
     }
-    fn build(&mut self, root: &LogicalPlan) -> Result<()> {
-        self.build_algebra_index(root.clone())?;
-        println!("{:?}", self);
-        Ok(())
-    }
 }

 impl fmt::Debug for DependentJoinTracker {
     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
@@ -949,17 +1159,16 @@ impl DependentJoinTracker {
     }

     // because the column providers are visited after column-accessor
-    // function visit_with_subqueries always visit the subquery before visiting the other child
+    // (function visit_with_subqueries always visit the subquery before visiting the other children)
     // we can always infer the LCA inside this function, by getting the deepest common parent
     fn conclude_lowest_dependent_join_node(&mut self, child_id: usize, col: &Column) {
         if let Some(accesses) = self.accessed_columns.get(col) {
             for access in accesses.iter() {
                 let mut cur_stack = self.stack.clone();
                 cur_stack.push(child_id);
-                // this is a dependen join node
+                // this is a dependent join node
                 let lca_node = Self::lca_from_stack(&cur_stack, &access.stack);
                 let node = self.nodes.get_mut(&lca_node).unwrap();
-                println!("inserting {}", access.node_id);
                 node.access_tracker.insert(ColumnAccess {
                     col: col.clone(),
                     node_id: access.node_id,
@@ -983,7 +1192,7 @@ impl DependentJoinTracker {
                 col: col.clone(),
             });
     }
-    fn build_algebra_index(&mut
self, plan: LogicalPlan) -> Result<()> { + fn build(&mut self, plan: LogicalPlan) -> Result<()> { // let mut index = AlgebraIndex::default(); plan.visit_with_subqueries(self)?; Ok(()) @@ -1007,19 +1216,6 @@ impl DependentJoinTracker { } } -#[derive(Debug, Hash, PartialEq, PartialOrd, Eq, Clone)] -enum ColumnUsage { - Own(Column), - Outer(Column), -} -impl ColumnUsage { - fn debug(&self) -> String { - match self { - ColumnUsage::Own(col) => format!("\x1b[34m{}\x1b[0m", col.flat_name()), - ColumnUsage::Outer(col) => format!("\x1b[31m{}\x1b[0m", col.flat_name()), - } - } -} impl ColumnAccess { fn debug(&self) -> String { format!("\x1b[31m{} ({})\x1b[0m", self.node_id, self.col) @@ -1195,6 +1391,7 @@ mod tests { array::{Int32Array, StringArray}, datatypes::{DataType as ArrowDataType, Field, Fields, Schema}, }; + #[test] fn complex_1_level_decorrelate_in_subquery_with_count() -> Result<()> { let outer_table = test_table_scan_with_name("outer_table")?; @@ -1227,7 +1424,8 @@ mod tests { )? .build()?; let mut index = DependentJoinTracker::new(Arc::new(AliasGenerator::new())); - index.build(&input1)?; + index.build(input1)?; + println!("{:?}", index); let new_plan = index.root_dependent_join_elimination()?; let expected = "\ Filter: outer_table.a > Int32(1) AND inner_table_lv1.mark\ @@ -1266,7 +1464,7 @@ mod tests { .filter(col("outer_table.a").gt(lit(1)).and(exists(sq_level1)))? .build()?; let mut index = DependentJoinTracker::new(Arc::new(AliasGenerator::new())); - index.build(&input1)?; + index.build(input1)?; let new_plan = index.root_dependent_join_elimination()?; let expected = "\ Filter: outer_table.a > Int32(1) AND inner_table_lv1.mark\ @@ -1292,7 +1490,7 @@ mod tests { .filter(col("outer_table.a").gt(lit(1)).and(exists(sq_level1)))? .build()?; let mut index = DependentJoinTracker::new(Arc::new(AliasGenerator::new())); - index.build(&input1)?; + index.build(input1)?; let new_plan = index.root_dependent_join_elimination()?; let expected = "\ Filter: outer_table.a > Int32(1) AND inner_table_lv1.mark\ @@ -1323,7 +1521,7 @@ mod tests { )? .build()?; let mut index = DependentJoinTracker::new(Arc::new(AliasGenerator::new())); - index.build(&input1)?; + index.build(input1)?; let new_plan = index.root_dependent_join_elimination()?; let expected = "\ Filter: outer_table.a > Int32(1)\ @@ -1367,7 +1565,7 @@ mod tests { )? 
.build()?; let mut index = DependentJoinTracker::new(Arc::new(AliasGenerator::new())); - index.build(&input1)?; + index.build(input1)?; let new_plan = index.root_dependent_join_elimination()?; let expected = "\ Filter: outer_table.a > Int32(1)\ From 0f039fe7b85daddef026a15ec97b1c263b015d86 Mon Sep 17 00:00:00 2001 From: Duong Cong Toai Date: Thu, 15 May 2025 06:58:45 +0200 Subject: [PATCH 16/70] feat: handle count bug --- datafusion/optimizer/Cargo.toml | 1 + .../optimizer/src/decorrelate_general.rs | 145 ++++++++++++------ 2 files changed, 101 insertions(+), 45 deletions(-) diff --git a/datafusion/optimizer/Cargo.toml b/datafusion/optimizer/Cargo.toml index 60358d20e2a1..1f303088a294 100644 --- a/datafusion/optimizer/Cargo.toml +++ b/datafusion/optimizer/Cargo.toml @@ -46,6 +46,7 @@ chrono = { workspace = true } datafusion-common = { workspace = true, default-features = true } datafusion-expr = { workspace = true } datafusion-physical-expr = { workspace = true } +datafusion-sql = { workspace = true } indexmap = { workspace = true } itertools = { workspace = true } log = { workspace = true } diff --git a/datafusion/optimizer/src/decorrelate_general.rs b/datafusion/optimizer/src/decorrelate_general.rs index f74e989cd7be..d495fbb178f4 100644 --- a/datafusion/optimizer/src/decorrelate_general.rs +++ b/datafusion/optimizer/src/decorrelate_general.rs @@ -17,42 +17,35 @@ //! [`GeneralPullUpCorrelatedExpr`] converts correlated subqueries to `Joins` -use std::cell::RefCell; use std::cmp::Ordering; -use std::collections::{BTreeSet, HashSet}; +use std::collections::HashSet; use std::fmt; use std::ops::Deref; -use std::rc::{Rc, Weak}; use std::sync::Arc; +use crate::analyzer::type_coercion::TypeCoercionRewriter; use crate::decorrelate::UN_MATCHED_ROW_INDICATOR; -use crate::simplify_expressions::{ExprSimplifier, SimplifyExpressions}; -use crate::utils::has_all_column_refs; use crate::{ApplyOrder, OptimizerConfig, OptimizerRule}; -use arrow::compute::kernels::cmp::eq; -use arrow::datatypes::Schema; use datafusion_common::alias::AliasGenerator; use datafusion_common::tree_node::{ - Transformed, TransformedResult, TreeNode, TreeNodeContainer, TreeNodeRecursion, - TreeNodeRewriter, TreeNodeVisitor, + Transformed, TreeNode, TreeNodeRecursion, TreeNodeVisitor, }; -use datafusion_common::{ - internal_err, not_impl_err, Column, DFSchemaRef, HashMap, Result, -}; -use datafusion_expr::expr::Exists; +use datafusion_common::{internal_err, Column, Result}; +use datafusion_expr::expr::{self, Exists}; use datafusion_expr::expr_rewriter::strip_outer_reference; use datafusion_expr::select_expr::SelectExpr; use datafusion_expr::utils::{ conjunction, disjunction, split_conjunction, split_disjunction, }; use datafusion_expr::{ - binary_expr, col, expr_fn, lit, table_scan, Aggregate, BinaryExpr, Cast, Expr, - JoinType, LogicalPlan, LogicalPlanBuilder, Operator as ExprOperator, Subquery, + binary_expr, col, lit, BinaryExpr, Cast, Expr, JoinType, LogicalPlan, + LogicalPlanBuilder, Operator as ExprOperator, Subquery, }; -use datafusion_functions_aggregate::count; +// use datafusion_sql::unparser::Unparser; + use datafusion_sql::unparser::Unparser; -use datafusion_sql::TableReference; +// use datafusion_sql::TableReference; use indexmap::map::Entry; use indexmap::{IndexMap, IndexSet}; use itertools::Itertools; @@ -164,6 +157,7 @@ struct UnnestingInfo { struct Unnesting { info: Arc, // cclasses: union find data structure of equivalent columns equivalences: UnionFind, + need_handle_count_bug: bool, // for each outer exprs on 
the left, the set of exprs // on the right required pulling up for the join condition to happen @@ -175,8 +169,9 @@ struct Unnesting { // to grouped_by // and push every pulled_up_columns: Vec, - //these predicates are disjunctive (combined by `Or` operator) + //these predicates are conjunctive pulled_up_predicates: Vec, + count_exprs_dectected: IndexSet, // mapping from outer ref column to new column, if any // i.e in some subquery ( // ... where outer.column_c=inner.column_a @@ -442,7 +437,6 @@ impl DependentJoinTracker { }) .cloned() .collect(); - println!("{:?}", pulled_up_expr); if !pulled_up_expr.is_empty() { for expr in pulled_up_expr.iter() { @@ -546,6 +540,10 @@ impl DependentJoinTracker { let is_static = agg.group_expr.is_empty(); // TODO: grouping set also needs to check is_static let next_node = node.children.first().unwrap(); let mut only_child = self.nodes.swap_remove(next_node).unwrap(); + // keep this for later projection + let mut original_expr = agg.aggr_expr.clone(); + original_expr.extend_from_slice(&agg.group_expr); + self.general_decorrelate( &mut only_child, unnesting, @@ -559,32 +557,73 @@ impl DependentJoinTracker { let replaced_col = unnesting.get_replaced_col(col); agg.group_expr.push(Expr::Column(replaced_col.clone())); } - - let need_handle_count_bug = true; - if need_handle_count_bug { - let un_matched_row = lit(true).alias(UN_MATCHED_ROW_INDICATOR); - agg.group_expr.push(un_matched_row.clone()); - // unnesting.pulled_up_predicates.push(value); + for agg in agg.aggr_expr.iter() { + if contains_count_expr(agg) { + unnesting.count_exprs_dectected.insert(agg.clone()); + } } if is_static { - let join_condition = unnesting - .pulled_up_predicates - .iter() - .map(|e| strip_outer_reference(e.clone())); - // Building the Domain to join with the group by - // TODO: maybe the construction of domain can happen somewhere else - let new_plan = LogicalPlanBuilder::new(unnesting.info.domain.clone()) - .join_detailed( - node.plan.clone(), - JoinType::Left, - (Vec::::new(), Vec::::new()), - disjunction(join_condition), - true, - )? + if !unnesting.count_exprs_dectected.is_empty() + & unnesting.need_handle_count_bug + { + let un_matched_row = lit(true).alias(UN_MATCHED_ROW_INDICATOR); + agg.group_expr.push(un_matched_row); + } + // let right = LogicalPlanBuilder::new(node.plan.clone()); + // the evaluation of + // let mut post_join_projection = vec![]; + + let join_condition = + unnesting.pulled_up_predicates.iter().filter_map(|e| { + let stripped_outer = strip_outer_reference(e.clone()); + if contains_count_expr(&stripped_outer) { + unimplemented!("handle having count(*) predicate pull up") + // post_join_predicates.push(stripped_outer); + // return None; + } + return Some(stripped_outer); + }); + + let right = LogicalPlanBuilder::new(agg.input.deref().clone()) + .aggregate(agg.group_expr.clone(), agg.aggr_expr.clone())? 
.build()?; - println!("{}", new_plan); - node.plan = new_plan; + let mut new_plan = + LogicalPlanBuilder::new(unnesting.info.domain.clone()) + .join_detailed( + right, + JoinType::Left, + (Vec::::new(), Vec::::new()), + conjunction(join_condition), + true, + )?; + for expr in original_expr.iter_mut() { + if contains_count_expr(expr) { + let new_expr = Expr::Case(expr::Case { + expr: None, + when_then_expr: vec![( + Box::new(Expr::IsNull(Box::new(Expr::Column( + Column::new_unqualified(UN_MATCHED_ROW_INDICATOR), + )))), + Box::new(lit(0)), + )], + else_expr: Some(Box::new(Expr::Column( + Column::new_unqualified( + expr.schema_name().to_string(), + ), + ))), + }); + let mut expr_rewrite = TypeCoercionRewriter { + schema: new_plan.schema(), + }; + *expr = new_expr.rewrite(&mut expr_rewrite)?.data; + } + } + new_plan = new_plan.project(original_expr)?; + + node.plan = new_plan.build()?; + + println!("{}", node.plan); // self.remove_node(parent, node); // 01)Projection: t1.t1_id, CASE WHEN __scalar_sq_1.__always_true IS NULL THEN Int64(0) ELSE __scalar_sq_1.count(*) END AS count(*) @@ -600,7 +639,7 @@ impl DependentJoinTracker { } } LogicalPlan::Filter(filter) => { - let disjunctions: Vec = split_disjunction(&filter.predicate) + let conjuctives: Vec = split_conjunction(&filter.predicate) .into_iter() .cloned() .collect(); @@ -611,7 +650,7 @@ impl DependentJoinTracker { // for now we only implement with the approach substituting let mut pulled_up_columns = IndexSet::new(); - for expr in disjunctions.iter() { + for expr in conjuctives.iter() { if !expr.contains_outer() { remained_expr.push(expr.clone()); continue; @@ -627,7 +666,7 @@ impl DependentJoinTracker { Ok(Transformed::no(e)) })?; } - filter.predicate = match disjunction(remained_expr) { + filter.predicate = match conjunction(remained_expr) { Some(expr) => expr, None => lit(true), }; @@ -945,7 +984,8 @@ impl DependentJoinTracker { replaces: IndexMap::new(), pulled_up_columns: vec![], pulled_up_predicates: vec![], - // outer_col_ref_map: HashMap::new(), + count_exprs_dectected: IndexSet::new(), // outer_col_ref_map: HashMap::new(), + need_handle_count_bug: true, // TODO }; let mut accesses: IndexSet = root_node.access_tracker.clone(); // .iter() @@ -1047,6 +1087,21 @@ impl DependentJoinTracker { Ok(result) } } + +fn contains_count_expr( + expr: &Expr, + // schema: &DFSchemaRef, + // expr_result_map_for_count_bug: &mut HashMap, +) -> bool { + expr.exists(|e| match e { + Expr::AggregateFunction(expr::AggregateFunction { func, .. }) => { + Ok(func.name() == "count") + } + _ => Ok(false), + }) + .unwrap() +} + impl fmt::Debug for DependentJoinTracker { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { writeln!(f, "GeneralDecorrelation Tree:")?; From 898bdc435563a89301d1f0d99b9dbb36928460e9 Mon Sep 17 00:00:00 2001 From: Duong Cong Toai Date: Fri, 16 May 2025 17:12:48 +0200 Subject: [PATCH 17/70] feat: add sq alias step --- datafusion/expr/src/expr.rs | 13 + datafusion/expr/src/expr_rewriter/mod.rs | 22 + .../optimizer/src/decorrelate_general.rs | 458 ++++++++++++++---- 3 files changed, 400 insertions(+), 93 deletions(-) diff --git a/datafusion/expr/src/expr.rs b/datafusion/expr/src/expr.rs index 95a5c76fea46..4cc4e347659c 100644 --- a/datafusion/expr/src/expr.rs +++ b/datafusion/expr/src/expr.rs @@ -1734,6 +1734,19 @@ impl Expr { .expect("exists closure is infallible") } + /// Return true if the expression contains out reference(correlated) expressions. 
+ pub fn contains_outer_from_relation(&self, outer_relation_name: &String) -> bool { + self.exists(|expr| { + if let Expr::OuterReferenceColumn(_, col) = expr { + if let Some(relation) = &col.relation { + return Ok(relation.table() == outer_relation_name); + } + } + Ok(false) + }) + .expect("exists closure is infallible") + } + /// Returns true if the expression node is volatile, i.e. whether it can return /// different results when evaluated multiple times with the same input. /// Note: unlike [`Self::is_volatile`], this function does not consider inputs: diff --git a/datafusion/expr/src/expr_rewriter/mod.rs b/datafusion/expr/src/expr_rewriter/mod.rs index 90dcbce46b01..b463dd43b228 100644 --- a/datafusion/expr/src/expr_rewriter/mod.rs +++ b/datafusion/expr/src/expr_rewriter/mod.rs @@ -130,6 +130,28 @@ pub fn normalize_sorts( .collect() } +/// Recursively rename the table of all [`Column`] expressions in a given expression tree with +/// a new name, ignoring the `skip_tables` +pub fn replace_col_base_table( + expr: Expr, + skip_tables: &[&str], + new_table: String, +) -> Result { + expr.transform(|expr| { + if let Expr::Column(c) = &expr { + if let Some(relation) = &c.relation { + if !skip_tables.contains(&relation.table()) { + return Ok(Transformed::yes(Expr::Column( + c.with_relation(TableReference::bare(new_table.clone())), + ))); + } + } + } + Ok(Transformed::no(expr)) + }) + .data() +} + /// Recursively replace all [`Column`] expressions in a given expression tree with /// `Column` expressions provided by the hash map argument. pub fn replace_col(expr: Expr, replace_map: &HashMap<&Column, &Column>) -> Result { diff --git a/datafusion/optimizer/src/decorrelate_general.rs b/datafusion/optimizer/src/decorrelate_general.rs index d495fbb178f4..db1965c7fc8b 100644 --- a/datafusion/optimizer/src/decorrelate_general.rs +++ b/datafusion/optimizer/src/decorrelate_general.rs @@ -32,19 +32,20 @@ use datafusion_common::tree_node::{ Transformed, TreeNode, TreeNodeRecursion, TreeNodeVisitor, }; use datafusion_common::{internal_err, Column, Result}; -use datafusion_expr::expr::{self, Exists}; -use datafusion_expr::expr_rewriter::strip_outer_reference; +use datafusion_expr::expr::{self, Exists, InSubquery}; +use datafusion_expr::expr_rewriter::{normalize_col, strip_outer_reference}; use datafusion_expr::select_expr::SelectExpr; use datafusion_expr::utils::{ conjunction, disjunction, split_conjunction, split_disjunction, }; use datafusion_expr::{ - binary_expr, col, lit, BinaryExpr, Cast, Expr, JoinType, LogicalPlan, + binary_expr, col, expr_fn, lit, BinaryExpr, Cast, Expr, JoinType, LogicalPlan, LogicalPlanBuilder, Operator as ExprOperator, Subquery, }; // use datafusion_sql::unparser::Unparser; use datafusion_sql::unparser::Unparser; +use datafusion_sql::TableReference; // use datafusion_sql::TableReference; use indexmap::map::Entry; use indexmap::{IndexMap, IndexSet}; @@ -155,6 +156,7 @@ struct UnnestingInfo { } #[derive(Clone)] struct Unnesting { + original_subquery: LogicalPlan, info: Arc, // cclasses: union find data structure of equivalent columns equivalences: UnionFind, need_handle_count_bug: bool, @@ -171,7 +173,10 @@ struct Unnesting { pulled_up_columns: Vec, //these predicates are conjunctive pulled_up_predicates: Vec, - count_exprs_dectected: IndexSet, + + subquery_alias_prefix: String, + // need this tracked to later on transform for which original subquery requires which join using which metadata + count_exprs_detected: IndexSet, // mapping from outer ref column to new column, if any // 
i.e in some subquery ( // ... where outer.column_c=inner.column_a @@ -189,6 +194,44 @@ impl Unnesting { None => col.clone(), } } + + fn rewrite_all_pulled_up_expr( + &mut self, + alias_name: &String, + outer_relations: &[&str], + ) -> Result<()> { + for expr in self.pulled_up_predicates.iter_mut() { + *expr = replace_col_base_table(expr.clone(), &outer_relations, alias_name)?; + } + // let rewritten_projections = self + // .pulled_up_columns + // .iter() + // .map(|e| replace_col_base_table(e.clone(), &outer_relations, alias_name)) + // .collect::>>()?; + // self.pulled_up_projections = rewritten_projections; + Ok(()) + } +} + +pub fn replace_col_base_table( + expr: Expr, + skip_tables: &[&str], + new_table: &String, +) -> Result { + Ok(expr + .transform(|expr| { + if let Expr::Column(c) = &expr { + if let Some(relation) = &c.relation { + if !skip_tables.contains(&relation.table()) { + return Ok(Transformed::yes(Expr::Column( + c.with_relation(TableReference::bare(new_table.clone())), + ))); + } + } + } + Ok(Transformed::no(expr)) + })? + .data) } // TODO: looks like this function can be improved to allow more expr pull up @@ -230,38 +273,63 @@ struct SimpleDecorrelationResult { pulled_up_projections: IndexSet, pulled_up_predicates: Vec, } +impl SimpleDecorrelationResult { + fn rewrite_all_pulled_up_expr( + &mut self, + alias_name: &String, + outer_relations: &[&str], + ) -> Result<()> { + for expr in self.pulled_up_predicates.iter_mut() { + *expr = replace_col_base_table(expr.clone(), &outer_relations, alias_name)?; + } + let rewritten_projections = self + .pulled_up_projections + .iter() + .map(|e| replace_col_base_table(e.clone(), &outer_relations, alias_name)) + .collect::>>()?; + self.pulled_up_projections = rewritten_projections; + Ok(()) + } +} -fn try_transform_subquery_to_join_expr( +fn extract_join_metadata_from_subquery( expr: &Expr, sq: &Subquery, - replace_columns: &[Expr], + subquery_projected_exprs: &[Expr], + alias: &String, + outer_relations: &[&str], ) -> Result<(bool, Option, Option)> { let mut post_join_predicate = None; - // this is used for exist query - let mut join_predicate = None; + // this can either be a projection expr or a predicate expr + let mut transformed_expr = None; let found_sq = expr.exists(|e| match e { Expr::InSubquery(isq) => { - if replace_columns.len() != 1 { + if subquery_projected_exprs.len() != 1 { return internal_err!( "result of IN subquery should only involve one column" ); } if isq.subquery == *sq { + let expr_with_alias = replace_col_base_table( + subquery_projected_exprs[0].clone(), + outer_relations, + alias, + )?; if isq.negated { - join_predicate = Some(binary_expr( + transformed_expr = Some(binary_expr( *isq.expr.clone(), ExprOperator::NotEq, - strip_outer_reference(replace_columns[0].clone()), + strip_outer_reference(expr_with_alias), )); return Ok(true); } - join_predicate = Some(binary_expr( + transformed_expr = Some(binary_expr( *isq.expr.clone(), ExprOperator::Eq, - strip_outer_reference(replace_columns[0].clone()), + strip_outer_reference(expr_with_alias), )); return Ok(true); } @@ -269,19 +337,27 @@ fn try_transform_subquery_to_join_expr( } Expr::BinaryExpr(BinaryExpr { left, op, right }) => { let (exist, transformed, post_join_expr_from_left) = - try_transform_subquery_to_join_expr(left.as_ref(), sq, replace_columns)?; + extract_join_metadata_from_subquery( + left.as_ref(), + sq, + subquery_projected_exprs, + alias, + outer_relations, + )?; if !exist { let (right_exist, transformed_right, post_join_expr_from_right) = - 
try_transform_subquery_to_join_expr( + extract_join_metadata_from_subquery( right.as_ref(), sq, - replace_columns, + subquery_projected_exprs, + alias, + outer_relations, )?; if !right_exist { return Ok(false); } if let Some(transformed_right) = transformed_right { - join_predicate = + transformed_expr = Some(binary_expr(*left.clone(), *op, transformed_right)); } if let Some(transformed_right) = post_join_expr_from_right { @@ -291,11 +367,8 @@ fn try_transform_subquery_to_join_expr( return Ok(true); } - // TODO: exist query won't have any transformed expr, - // meaning this query is not supported `where bool_col = exists(subquery)` - if let Some(transformed) = transformed { - join_predicate = Some(binary_expr(transformed, *op, *right.clone())); + transformed_expr = Some(binary_expr(transformed, *op, *right.clone())); } if let Some(transformed) = post_join_expr_from_left { post_join_predicate = Some(binary_expr(transformed, *op, *right.clone())); @@ -315,11 +388,14 @@ fn try_transform_subquery_to_join_expr( return Ok(false); } Expr::ScalarSubquery(ssq) => { - unimplemented!( - "we need to store map between scalarsubquery and replaced_expr later on" - ); + if subquery_projected_exprs.len() != 1 { + return internal_err!( + "result of scalar subquery should only involve one column" + ); + } if let LogicalPlan::Subquery(inner_sq) = ssq.subquery.as_ref() { if inner_sq.clone() == *sq { + transformed_expr = Some(subquery_projected_exprs[0].clone()); return Ok(true); } } @@ -327,7 +403,7 @@ fn try_transform_subquery_to_join_expr( } _ => Ok(false), })?; - return Ok((found_sq, join_predicate, post_join_predicate)); + return Ok((found_sq, transformed_expr, post_join_predicate)); } // impl Default for GeneralDecorrelation { @@ -345,6 +421,7 @@ struct GeneralDecorrelationResult { // (because of known count_bug) count_expr_map: HashSet, } + impl DependentJoinTracker { fn is_linear_operator(&self, plan: &LogicalPlan) -> bool { match plan { @@ -426,7 +503,6 @@ impl DependentJoinTracker { .filter(|proj_expr| { proj_expr .exists(|expr| { - // TODO: what if parent has already rewritten outer_ref_col if let Expr::OuterReferenceColumn(_, col) = expr { root_node.access_tracker.remove(col_access); return Ok(*col == col_access.col); @@ -559,12 +635,12 @@ impl DependentJoinTracker { } for agg in agg.aggr_expr.iter() { if contains_count_expr(agg) { - unnesting.count_exprs_dectected.insert(agg.clone()); + unnesting.count_exprs_detected.insert(agg.clone()); } } if is_static { - if !unnesting.count_exprs_dectected.is_empty() + if !unnesting.count_exprs_detected.is_empty() & unnesting.need_handle_count_bug { let un_matched_row = lit(true).alias(UN_MATCHED_ROW_INDICATOR); @@ -573,6 +649,8 @@ impl DependentJoinTracker { // let right = LogicalPlanBuilder::new(node.plan.clone()); // the evaluation of // let mut post_join_projection = vec![]; + let alias = + self.alias_generator.next(&unnesting.subquery_alias_prefix); let join_condition = unnesting.pulled_up_predicates.iter().filter_map(|e| { @@ -582,11 +660,18 @@ impl DependentJoinTracker { // post_join_predicates.push(stripped_outer); // return None; } - return Some(stripped_outer); + match &stripped_outer { + Expr::Column(col) => { + println!("{:?}", col); + } + _ => {} + } + Some(stripped_outer) }); let right = LogicalPlanBuilder::new(agg.input.deref().clone()) .aggregate(agg.group_expr.clone(), agg.aggr_expr.clone())? + .alias(alias.clone())? 
.build()?; let mut new_plan = LogicalPlanBuilder::new(unnesting.info.domain.clone()) @@ -618,12 +703,18 @@ impl DependentJoinTracker { }; *expr = new_expr.rewrite(&mut expr_rewrite)?.data; } + + // *expr = Expr::Column(create_col_from_scalar_expr( + // expr, + // alias.clone(), + // )?); } new_plan = new_plan.project(original_expr)?; node.plan = new_plan.build()?; println!("{}", node.plan); + return Ok(()); // self.remove_node(parent, node); // 01)Projection: t1.t1_id, CASE WHEN __scalar_sq_1.__always_true IS NULL THEN Int64(0) ELSE __scalar_sq_1.count(*) END AS count(*) @@ -712,7 +803,6 @@ impl DependentJoinTracker { unimplemented!() } }; - unimplemented!() // if unnesting.info.parent.is_some() { // not_impl_err!("impl me") // // TODO @@ -734,7 +824,7 @@ impl DependentJoinTracker { fn left_owned(&mut self, node: &Operator) -> Operator { assert_eq!(2, node.children.len()); // during the building of the tree, the subquery (right node) is always traversed first - let node_id = node.children.get(1).unwrap(); + let node_id = node.children.last().unwrap(); return self.nodes.swap_remove(node_id).unwrap(); } fn root_dependent_join_elimination(&mut self) -> Result { @@ -776,33 +866,43 @@ impl DependentJoinTracker { fn get_subquery_children( &self, parent: &Operator, - ) -> Result<(LogicalPlan, Subquery)> { - let subquery = parent.children.get(0).unwrap(); + // because one dependent join node can have multiple subquery at a time + sq_offset: usize, + ) -> Result<(LogicalPlan, Subquery, SubqueryType)> { + let subquery = parent.children.get(sq_offset).unwrap(); let sq_node = self.nodes.get(subquery).unwrap(); assert!(sq_node.is_subquery_node); let query = sq_node.children.get(0).unwrap(); let target_node = self.nodes.get(query).unwrap(); // let op = .clone(); if let LogicalPlan::Subquery(subquery) = sq_node.plan.clone() { - return Ok((target_node.plan.clone(), subquery)); + return Ok((target_node.plan.clone(), subquery, sq_node.subquery_type)); } else { internal_err!("") } } - fn build_join_from_simple_unnest( + fn build_join_from_simple_decorrelation_result( &self, dependent_join_node: &mut Operator, - ret: SimpleDecorrelationResult, + mut ret: SimpleDecorrelationResult, ) -> Result { - let (subquery_children, subquery) = - self.get_subquery_children(dependent_join_node)?; + let (subquery_children, subquery, sq_type) = + self.get_subquery_children(dependent_join_node, 0)?; + let outer_relations: Vec<&str> = dependent_join_node + .correlated_relations + .iter() + .map(String::as_str) + .collect(); + match dependent_join_node.plan { LogicalPlan::Filter(ref mut filter) => { - let exprs = split_conjunction(&filter.predicate); - let mut join_exprs = vec![]; - let mut kept_predicates = vec![]; + let predicate_expr = split_conjunction(&filter.predicate); + let mut join_predicates = vec![]; + let mut post_join_predicates = vec![]; // maybe we also need to collect join columns here + // TODO: we need to also pull up projectoin to support subqueries that appear + // in select expressions let pulled_projection: Vec = ret .pulled_up_projections .iter() @@ -818,21 +918,27 @@ impl DependentJoinTracker { .map(strip_outer_reference) .collect() }; - let mut join_type = JoinType::LeftSemi; - for expr in exprs.into_iter() { + let mut join_type = sq_type.default_join_type(); + let alias_name = self.alias_generator.next(&sq_type.prefix()).to_string(); + ret.rewrite_all_pulled_up_expr(&alias_name, &outer_relations)?; + + for expr in predicate_expr.into_iter() { // exist query may not have any transformed expr // i.e 
where exists(suquery) => semi join - let (transformed, maybe_transformed_expr, maybe_post_join_expr) = - try_transform_subquery_to_join_expr( + let (transformed, maybe_join_predicate, maybe_post_join_predicate) = + extract_join_metadata_from_subquery( expr, &subquery, &right_exprs, + &alias_name, + &outer_relations, )?; - if let Some(transformed) = maybe_transformed_expr { - join_exprs.push(transformed) + if let Some(transformed) = maybe_join_predicate { + println!("join predicate is {}", transformed.clone()); + join_predicates.push(transformed) } - if let Some(post_join_expr) = maybe_post_join_expr { + if let Some(post_join_expr) = maybe_post_join_predicate { if post_join_expr .exists(|e| { if let Expr::Column(col) = e { @@ -845,10 +951,10 @@ impl DependentJoinTracker { // only use mark join if required join_type = JoinType::LeftMark } - kept_predicates.push(post_join_expr) + post_join_predicates.push(post_join_expr) } if !transformed { - kept_predicates.push(expr.clone()) + post_join_predicates.push(expr.clone()) } } @@ -856,26 +962,32 @@ impl DependentJoinTracker { .pulled_up_predicates .iter() .map(|e| strip_outer_reference(e.clone())); - join_exprs.extend(new_predicates); + + join_predicates.extend(new_predicates); // TODO: some predicate is join predicate, some is just filter // kept_predicates.extend(new_predicates); // filter.predicate = conjunction(kept_predicates).unwrap(); // left + + let mut right = LogicalPlanBuilder::new(subquery_children) + .alias(&alias_name)? + .build()?; let mut builder = LogicalPlanBuilder::new(filter.input.deref().clone()); - builder = if join_exprs.is_empty() { - builder.join_on(subquery_children, join_type, vec![lit(true)])? + builder = if join_predicates.is_empty() { + builder.join_on(right, join_type, vec![lit(true)])? } else { builder.join_on( - subquery_children, + right, // TODO: join type based on filter condition join_type, - join_exprs, + join_predicates, )? }; - if kept_predicates.len() > 0 { - builder = builder.filter(conjunction(kept_predicates).unwrap())? + if post_join_predicates.len() > 0 { + builder = + builder.filter(conjunction(post_join_predicates).unwrap())? 
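+                    // predicates that could not become join conditions (including the mark
+                    // column check) stay above the new join as a plain filter, e.g.
+                    // `Filter: outer_table.a > Int32(1) AND __exists_sq_1.mark`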
} builder.build() } @@ -915,6 +1027,7 @@ impl DependentJoinTracker { let parent = unnesting.parent.clone(); let mut root_node = self.nodes.swap_remove(&node).unwrap(); let simple_unnest_result = self.simple_decorrelation(&mut root_node)?; + let (original_subquery, _, _) = self.get_subquery_children(&root_node, 0)?; if root_node.access_tracker.is_empty() { if parent.is_some() { // for each projection of outer column moved up by simple_decorrelation @@ -927,13 +1040,16 @@ impl DependentJoinTracker { )?; return Ok(root_node.plan.clone()); } - return self - .build_join_from_simple_unnest(&mut root_node, simple_unnest_result); + return self.build_join_from_simple_decorrelation_result( + &mut root_node, + simple_unnest_result, + ); unimplemented!() // return Ok(dependent_join); } // let mut join = self.new_dependent_join(&root_node); + // TODO: handle the case where one dependent join node contains multiple subqueries let mut left = self.left_owned(&root_node); let mut right = self.right_owned(&root_node); if parent.is_some() { @@ -975,6 +1091,7 @@ impl DependentJoinTracker { // domain: vec![], // TODO: populate me }; let mut unnesting = Unnesting { + original_subquery, info: Arc::new(new_unnesting_info.clone()), join_conditions: vec![], equivalences: UnionFind { @@ -984,8 +1101,9 @@ impl DependentJoinTracker { replaces: IndexMap::new(), pulled_up_columns: vec![], pulled_up_predicates: vec![], - count_exprs_dectected: IndexSet::new(), // outer_col_ref_map: HashMap::new(), - need_handle_count_bug: true, // TODO + count_exprs_detected: IndexSet::new(), // outer_col_ref_map: HashMap::new(), + need_handle_count_bug: true, // TODO + subquery_alias_prefix: "__scalar_sq".to_string(), // TODO }; let mut accesses: IndexSet = root_node.access_tracker.clone(); // .iter() @@ -1013,6 +1131,96 @@ impl DependentJoinTracker { // join.join_conditions.append(other); // } } + + fn build_join_from_general_unnesting_info( + &self, + dependent_join_node: &mut Operator, + decorrelated_right_node: &mut Operator, + unnesting: Unnesting, + ) -> Result { + let (subquery_children, subquery, subquery_type) = + self.get_subquery_children(dependent_join_node, 0)?; + let outer_relations: Vec<&str> = dependent_join_node + .correlated_relations + .iter() + .map(String::as_str) + .collect(); + match dependent_join_node.plan { + LogicalPlan::Filter(ref mut filter) => { + let exprs = split_conjunction(&filter.predicate); + let mut join_exprs = vec![]; + let mut kept_predicates = vec![]; + let right_expr: Vec<_> = decorrelated_right_node + .plan + .schema() + .columns() + .iter() + .map(|c| Expr::Column(c.clone())) + .collect(); + let mut join_type = subquery_type.default_join_type(); + let alias = self.alias_generator.next(&subquery_type.prefix()); + for expr in exprs.into_iter() { + // exist query may not have any transformed expr + // i.e where exists(suquery) => semi join + let (transformed, maybe_transformed_expr, maybe_post_join_expr) = + extract_join_metadata_from_subquery( + expr, + &subquery, + &right_expr, + &alias, + &outer_relations, + )?; + + if let Some(transformed) = maybe_transformed_expr { + join_exprs.push(transformed) + } + if let Some(post_join_expr) = maybe_post_join_expr { + if post_join_expr + .exists(|e| { + if let Expr::Column(col) = e { + return Ok(col.name == "mark"); + } + return Ok(false); + }) + .unwrap() + { + // only use mark join if required + join_type = JoinType::LeftMark + } + kept_predicates.push(post_join_expr) + } + if !transformed { + kept_predicates.push(expr.clone()) + } + } + + // TODO: 
some predicate is join predicate, some is just filter + // kept_predicates.extend(new_predicates); + // filter.predicate = conjunction(kept_predicates).unwrap(); + // left + let mut builder = LogicalPlanBuilder::new(filter.input.deref().clone()); + + builder = if join_exprs.is_empty() { + builder.join_on(subquery_children, join_type, vec![lit(true)])? + } else { + builder.join_on( + subquery_children, + // TODO: join type based on filter condition + join_type, + join_exprs, + )? + }; + + if kept_predicates.len() > 0 { + builder = builder.filter(conjunction(kept_predicates).unwrap())? + } + builder.build() + } + _ => { + unimplemented!() + } + } + } fn rewrite_columns<'a>( exprs: impl Iterator, unnesting: &Unnesting, @@ -1216,7 +1424,12 @@ impl DependentJoinTracker { // because the column providers are visited after column-accessor // (function visit_with_subqueries always visit the subquery before visiting the other children) // we can always infer the LCA inside this function, by getting the deepest common parent - fn conclude_lowest_dependent_join_node(&mut self, child_id: usize, col: &Column) { + fn conclude_lowest_dependent_join_node( + &mut self, + child_id: usize, + col: &Column, + tbl_name: &str, + ) { if let Some(accesses) = self.accessed_columns.get(col) { for access in accesses.iter() { let mut cur_stack = self.stack.clone(); @@ -1229,6 +1442,7 @@ impl DependentJoinTracker { node_id: access.node_id, stack: access.stack.clone(), }); + node.correlated_relations.insert(tbl_name.to_string()); } } } @@ -1283,23 +1497,50 @@ struct Operator { parent: Option, // This field is only set if the node is dependent join node - // it track which child still accessing which column of + // it track which descendent nodes still accessing the outer columns provided by its + // left child // the insertion order is top down access_tracker: IndexSet, is_dependent_join_node: bool, is_subquery_node: bool, + + // note that for dependent join nodes, there can be more than 1 + // subquery children at a time, but always 1 outer-column-providing-child + // which is at the last element children: Vec, + subquery_type: SubqueryType, + correlated_relations: IndexSet, } -impl Operator { - // fn to_dependent_join(&self) -> DependentJoin { - // DependentJoin { - // original_expr: self.plan.clone(), - // left: self.left(), - // right: self.right(), - // join_conditions: vec![], - // } - // } +#[derive(Debug, Clone, Copy)] +enum SubqueryType { + None, + In, + Exists, + Scalar, +} +impl SubqueryType { + fn default_join_type(&self) -> JoinType { + match self { + SubqueryType::None => { + panic!("not reached") + } + SubqueryType::In => JoinType::LeftSemi, + SubqueryType::Exists => JoinType::LeftSemi, + // TODO: in duckdb, they have JoinType::Single + // where there is only at most one join partner entry on the LEFT + SubqueryType::Scalar => JoinType::Left, + } + } + fn prefix(&self) -> String { + match self { + SubqueryType::None => "", + SubqueryType::In => "__in_sq", + SubqueryType::Exists => "__exists_sq", + SubqueryType::Scalar => "__scalar_sq", + } + .to_string() + } } fn contains_subquery(expr: &Expr) -> bool { @@ -1328,6 +1569,8 @@ impl TreeNodeVisitor<'_> for DependentJoinTracker { } let mut is_subquery_node = false; let mut is_dependent_join_node = false; + + let mut subquery_type = SubqueryType::None; // for each node, find which column it is accessing, which column it is providing // Set of columns current node access match node { @@ -1341,7 +1584,11 @@ impl TreeNodeVisitor<'_> for DependentJoinTracker { } 
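            // a table scan acts as a column provider: for every column it exposes we look up
            // the recorded outer-column accesses, attach them to the lowest dependent join
            // ancestor, and remember the base relation name so pulled-up predicates can be
            // re-qualified later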
LogicalPlan::TableScan(tbl_scan) => { tbl_scan.projected_schema.columns().iter().for_each(|col| { - self.conclude_lowest_dependent_join_node(self.current_id, &col); + self.conclude_lowest_dependent_join_node( + self.current_id, + &col, + tbl_scan.table_name.table(), + ); }); } // TODO @@ -1363,7 +1610,29 @@ impl TreeNodeVisitor<'_> for DependentJoinTracker { } LogicalPlan::Subquery(subquery) => { is_subquery_node = true; - // TODO: once we detect the subquery + let parent = self.stack.last().unwrap(); + let parent_node = self.get_node_uncheck(parent); + for expr in parent_node.plan.expressions() { + expr.exists(|e| { + let (found_sq, checking_type) = match e { + Expr::ScalarSubquery(sq) => { + (sq == subquery, SubqueryType::Scalar) + } + Expr::Exists(Exists { subquery: sq, .. }) => { + (sq == subquery, SubqueryType::Exists) + } + Expr::InSubquery(InSubquery { subquery: sq, .. }) => { + (sq == subquery, SubqueryType::In) + } + _ => (false, SubqueryType::None), + }; + if found_sq { + subquery_type = checking_type; + } + + Ok(found_sq) + })?; + } } LogicalPlan::Aggregate(_) => {} _ => { @@ -1390,6 +1659,8 @@ impl TreeNodeVisitor<'_> for DependentJoinTracker { is_dependent_join_node, children: vec![], access_tracker: IndexSet::new(), + subquery_type, + correlated_relations: IndexSet::new(), }, ); @@ -1442,10 +1713,7 @@ mod tests { use crate::test::{test_table_scan, test_table_scan_with_name}; use super::DependentJoinTracker; - use arrow::{ - array::{Int32Array, StringArray}, - datatypes::{DataType as ArrowDataType, Field, Fields, Schema}, - }; + use arrow::datatypes::DataType as ArrowDataType; #[test] fn complex_1_level_decorrelate_in_subquery_with_count() -> Result<()> { @@ -1522,11 +1790,12 @@ mod tests { index.build(input1)?; let new_plan = index.root_dependent_join_elimination()?; let expected = "\ - Filter: outer_table.a > Int32(1) AND inner_table_lv1.mark\ - \n LeftMark Join: Filter: inner_table_lv1.a = outer_table.a AND outer_table.a > inner_table_lv1.c AND outer_table.b = inner_table_lv1.b\ + Filter: outer_table.a > Int32(1) AND __exists_sq_1.mark\ + \n LeftMark Join: Filter: __exists_sq_1.a = outer_table.a AND outer_table.a > __exists_sq_1.c AND outer_table.b = __exists_sq_1.b\ \n TableScan: outer_table\ - \n Filter: inner_table_lv1.b = Int32(1)\ - \n TableScan: inner_table_lv1"; + \n SubqueryAlias: __exists_sq_1\ + \n Filter: inner_table_lv1.b = Int32(1)\ + \n TableScan: inner_table_lv1"; assert_eq!(expected, format!("{new_plan}")); Ok(()) } @@ -1548,12 +1817,13 @@ mod tests { index.build(input1)?; let new_plan = index.root_dependent_join_elimination()?; let expected = "\ - Filter: outer_table.a > Int32(1) AND inner_table_lv1.mark\ + Filter: outer_table.a > Int32(1) AND __exists_sq_1.mark\ \n LeftMark Join: Filter: Boolean(true)\ \n TableScan: outer_table\ - \n Projection: inner_table_lv1.b, inner_table_lv1.a\ - \n Filter: inner_table_lv1.b = Int32(1)\ - \n TableScan: inner_table_lv1"; + \n SubqueryAlias: __exists_sq_1\ + \n Projection: inner_table_lv1.b, inner_table_lv1.a\ + \n Filter: inner_table_lv1.b = Int32(1)\ + \n TableScan: inner_table_lv1"; assert_eq!(expected, format!("{new_plan}")); Ok(()) } @@ -1580,11 +1850,12 @@ mod tests { let new_plan = index.root_dependent_join_elimination()?; let expected = "\ Filter: outer_table.a > Int32(1)\ - \n LeftSemi Join: Filter: outer_table.c = inner_table_lv1.b\ + \n LeftSemi Join: Filter: outer_table.c = __in_sq_1.b\ \n TableScan: outer_table\ - \n Projection: inner_table_lv1.b\ - \n Filter: inner_table_lv1.b = Int32(1)\ - \n TableScan: 
inner_table_lv1"; + \n SubqueryAlias: __in_sq_1\ + \n Projection: inner_table_lv1.b\ + \n Filter: inner_table_lv1.b = Int32(1)\ + \n TableScan: inner_table_lv1"; assert_eq!(expected, format!("{new_plan}")); Ok(()) } @@ -1624,10 +1895,11 @@ mod tests { let new_plan = index.root_dependent_join_elimination()?; let expected = "\ Filter: outer_table.a > Int32(1)\ - \n LeftSemi Join: Filter: outer_table.c = outer_table.b AS outer_b_alias AND inner_table_lv1.a = outer_table.a AND outer_table.a > inner_table_lv1.c AND outer_table.b = inner_table_lv1.b\ + \n LeftSemi Join: Filter: outer_table.c = outer_table.b AS outer_b_alias AND __in_sq_1.a = outer_table.a AND outer_table.a > __in_sq_1.c AND outer_table.b = __in_sq_1.b\ \n TableScan: outer_table\ - \n Filter: inner_table_lv1.b = Int32(1)\ - \n TableScan: inner_table_lv1"; + \n SubqueryAlias: __in_sq_1\ + \n Filter: inner_table_lv1.b = Int32(1)\ + \n TableScan: inner_table_lv1"; assert_eq!(expected, format!("{new_plan}")); Ok(()) } From 1a600b659437248afa0768bdaf547a4981823fe5 Mon Sep 17 00:00:00 2001 From: Duong Cong Toai Date: Fri, 16 May 2025 18:01:03 +0200 Subject: [PATCH 18/70] test: simple count decorrelate --- .../optimizer/src/decorrelate_general.rs | 60 +++++++++++++------ 1 file changed, 43 insertions(+), 17 deletions(-) diff --git a/datafusion/optimizer/src/decorrelate_general.rs b/datafusion/optimizer/src/decorrelate_general.rs index db1965c7fc8b..2d0093a9b785 100644 --- a/datafusion/optimizer/src/decorrelate_general.rs +++ b/datafusion/optimizer/src/decorrelate_general.rs @@ -174,7 +174,6 @@ struct Unnesting { //these predicates are conjunctive pulled_up_predicates: Vec, - subquery_alias_prefix: String, // need this tracked to later on transform for which original subquery requires which join using which metadata count_exprs_detected: IndexSet, // mapping from outer ref column to new column, if any @@ -186,6 +185,8 @@ struct Unnesting { replaces: IndexMap, join_conditions: Vec, + subquery_type: SubqueryType, + decorrelated_subquery: Option, } impl Unnesting { fn get_replaced_col(&self, col: &Column) -> Column { @@ -609,6 +610,7 @@ impl DependentJoinTracker { unnesting, outer_refs_from_parent, )?; + unnesting.decorrelated_subquery = Some(sq.clone()); *node = only_child; return Ok(()); } @@ -650,7 +652,7 @@ impl DependentJoinTracker { // the evaluation of // let mut post_join_projection = vec![]; let alias = - self.alias_generator.next(&unnesting.subquery_alias_prefix); + self.alias_generator.next(&unnesting.subquery_type.prefix()); let join_condition = unnesting.pulled_up_predicates.iter().filter_map(|e| { @@ -1009,10 +1011,11 @@ impl DependentJoinTracker { // the direct left child of root is not always the table scan node // and there are many more table providing logical plan let initial_domain = LogicalPlanBuilder::new(left.plan.clone()) - .project( + .aggregate( unique_outer_refs .iter() - .map(|col| SelectExpr::Expression(Expr::Column(col.clone()))), + .map(|col| Expr::Column(col.clone())), + Vec::::new(), )? 
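+            // grouping on the outer columns with an empty aggregate list is effectively a
+            // DISTINCT over them, so the domain is the set of distinct outer rows referenced
+            // by the subquery rather than a duplicate-preserving projection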
.build()?; return Ok(initial_domain); @@ -1027,7 +1030,8 @@ impl DependentJoinTracker { let parent = unnesting.parent.clone(); let mut root_node = self.nodes.swap_remove(&node).unwrap(); let simple_unnest_result = self.simple_decorrelation(&mut root_node)?; - let (original_subquery, _, _) = self.get_subquery_children(&root_node, 0)?; + let (original_subquery, _, subquery_type) = + self.get_subquery_children(&root_node, 0)?; if root_node.access_tracker.is_empty() { if parent.is_some() { // for each projection of outer column moved up by simple_decorrelation @@ -1103,7 +1107,8 @@ impl DependentJoinTracker { pulled_up_predicates: vec![], count_exprs_detected: IndexSet::new(), // outer_col_ref_map: HashMap::new(), need_handle_count_bug: true, // TODO - subquery_alias_prefix: "__scalar_sq".to_string(), // TODO + subquery_type, + decorrelated_subquery: None, }; let mut accesses: IndexSet = root_node.access_tracker.clone(); // .iter() @@ -1124,7 +1129,18 @@ impl DependentJoinTracker { //TODO: add equivalences from join.condition to unnest.cclasses self.general_decorrelate(&mut right, &mut unnesting, &mut accesses)?; - println!("temporary transformed result {:?}", self); + let decorrelated_plan = self.build_join_from_general_unnesting_info( + &mut root_node, + &mut left, + &mut right, + unnesting, + )?; + return Ok(decorrelated_plan); + + // self.nodes.insert(left.id, left); + // self.nodes.insert(right.id, right); + // self.nodes.insert(node, root_node); + unimplemented!("implement relacing right node"); // join.replace_right(new_right, &new_unnesting_info, &unnesting.replaces); // for acc in new_unnesting_info.outer_refs{ @@ -1135,16 +1151,24 @@ impl DependentJoinTracker { fn build_join_from_general_unnesting_info( &self, dependent_join_node: &mut Operator, + left_node: &mut Operator, decorrelated_right_node: &mut Operator, - unnesting: Unnesting, + mut unnesting: Unnesting, ) -> Result { - let (subquery_children, subquery, subquery_type) = - self.get_subquery_children(dependent_join_node, 0)?; + let subquery = unnesting.decorrelated_subquery.take().unwrap(); + let decorrelated_right = decorrelated_right_node.plan.clone(); + let subquery_type = unnesting.subquery_type; + + let alias = self.alias_generator.next(&subquery_type.prefix()); let outer_relations: Vec<&str> = dependent_join_node .correlated_relations .iter() .map(String::as_str) .collect(); + + unnesting.rewrite_all_pulled_up_expr(&alias, &outer_relations)?; + // TODO: do this on left instead of dependent_join_node directly, because with recursive + // the left side can also be rewritten match dependent_join_node.plan { LogicalPlan::Filter(ref mut filter) => { let exprs = split_conjunction(&filter.predicate); @@ -1158,7 +1182,6 @@ impl DependentJoinTracker { .map(|c| Expr::Column(c.clone())) .collect(); let mut join_type = subquery_type.default_join_type(); - let alias = self.alias_generator.next(&subquery_type.prefix()); for expr in exprs.into_iter() { // exist query may not have any transformed expr // i.e where exists(suquery) => semi join @@ -1201,10 +1224,10 @@ impl DependentJoinTracker { let mut builder = LogicalPlanBuilder::new(filter.input.deref().clone()); builder = if join_exprs.is_empty() { - builder.join_on(subquery_children, join_type, vec![lit(true)])? + builder.join_on(decorrelated_right, join_type, vec![lit(true)])? 
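+                    // nothing was pulled up as an equi-predicate, so join on a constant TRUE
+                    // and rely on the post-join filter / mark column to do the filtering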
} else { builder.join_on( - subquery_children, + decorrelated_right, // TODO: join type based on filter condition join_type, join_exprs, @@ -1750,12 +1773,15 @@ mod tests { index.build(input1)?; println!("{:?}", index); let new_plan = index.root_dependent_join_elimination()?; + println!("{}", new_plan); let expected = "\ - Filter: outer_table.a > Int32(1) AND inner_table_lv1.mark\ - \n LeftMark Join: Filter: inner_table_lv1.a = outer_table.a AND outer_table.a > inner_table_lv1.c AND outer_table.b = inner_table_lv1.b\ + Filter: outer_table.a > Int32(1)\ + \n LeftSemi Join: Filter: outer_table.c = count_a\ \n TableScan: outer_table\ - \n Filter: inner_table_lv1.b = Int32(1)\ - \n TableScan: inner_table_lv1"; + \n Projection: count(inner_table_lv1.a) AS count_a, inner_table_lv1.a, inner_table_lv1.c, inner_table_lv1.b\ + \n Aggregate: groupBy=[[]], aggr=[[count(inner_table_lv1.a)]]\ + \n Filter: inner_table_lv1.a = outer_ref(outer_table.a) AND outer_ref(outer_table.a) > inner_table_lv1.c AND inner_table_lv1.b = Int32(1) AND outer_ref(outer_table.b) = inner_table_lv1.b\ + \n TableScan: inner_table_lv1"; assert_eq!(expected, format!("{new_plan}")); Ok(()) } From 6ce21b396f523e1e4a9372415084da111465276d Mon Sep 17 00:00:00 2001 From: Duong Cong Toai Date: Sat, 17 May 2025 15:26:18 +0200 Subject: [PATCH 19/70] chore: some work to support multiple subqueries per level --- .../optimizer/src/decorrelate_general.rs | 635 ++++++++++++------ .../sqllogictest/test_files/debug_count.slt | 116 ++++ 2 files changed, 534 insertions(+), 217 deletions(-) create mode 100644 datafusion/sqllogictest/test_files/debug_count.slt diff --git a/datafusion/optimizer/src/decorrelate_general.rs b/datafusion/optimizer/src/decorrelate_general.rs index 2d0093a9b785..510ce44fed9f 100644 --- a/datafusion/optimizer/src/decorrelate_general.rs +++ b/datafusion/optimizer/src/decorrelate_general.rs @@ -31,7 +31,7 @@ use datafusion_common::alias::AliasGenerator; use datafusion_common::tree_node::{ Transformed, TreeNode, TreeNodeRecursion, TreeNodeVisitor, }; -use datafusion_common::{internal_err, Column, Result}; +use datafusion_common::{internal_err, Column, HashMap, Result}; use datafusion_expr::expr::{self, Exists, InSubquery}; use datafusion_expr::expr_rewriter::{normalize_col, strip_outer_reference}; use datafusion_expr::select_expr::SelectExpr; @@ -39,8 +39,8 @@ use datafusion_expr::utils::{ conjunction, disjunction, split_conjunction, split_disjunction, }; use datafusion_expr::{ - binary_expr, col, expr_fn, lit, BinaryExpr, Cast, Expr, JoinType, LogicalPlan, - LogicalPlanBuilder, Operator as ExprOperator, Subquery, + binary_expr, col, expr_fn, lit, BinaryExpr, Cast, Expr, Filter, JoinType, + LogicalPlan, LogicalPlanBuilder, Operator as ExprOperator, Subquery, }; // use datafusion_sql::unparser::Unparser; @@ -57,7 +57,7 @@ pub struct DependentJoinTracker { // each logical plan traversal will assign it a integer id current_id: usize, // each newly visted operator is inserted inside this map for tracking - nodes: IndexMap, + nodes: IndexMap, // all the node ids from root to the current node // this is used during traversal only stack: Vec, @@ -68,7 +68,9 @@ pub struct DependentJoinTracker { #[derive(Debug, Hash, PartialEq, PartialOrd, Eq, Clone)] struct ColumnAccess { + // node ids from root to the node that is referencing the column stack: Vec, + // the node referencing the column node_id: usize, col: Column, } @@ -135,18 +137,6 @@ impl UnionFind { true } } -// TODO: impl me -#[derive(Clone)] -struct DependentJoin { - // - 
original_expr: LogicalPlan, - left: Operator, - right: Operator, - // TODO: combine into one Expr - join_conditions: Vec, - // join_type: -} -impl DependentJoin {} #[derive(Clone)] struct UnnestingInfo { @@ -184,7 +174,6 @@ struct Unnesting { // we can substitute the inner query with inner.column_a=some_other_expr replaces: IndexMap, - join_conditions: Vec, subquery_type: SubqueryType, decorrelated_subquery: Option, } @@ -199,10 +188,10 @@ impl Unnesting { fn rewrite_all_pulled_up_expr( &mut self, alias_name: &String, - outer_relations: &[&str], + outer_relations: &[String], ) -> Result<()> { for expr in self.pulled_up_predicates.iter_mut() { - *expr = replace_col_base_table(expr.clone(), &outer_relations, alias_name)?; + *expr = replace_col_base_table(expr.clone(), outer_relations, alias_name)?; } // let rewritten_projections = self // .pulled_up_columns @@ -216,14 +205,14 @@ impl Unnesting { pub fn replace_col_base_table( expr: Expr, - skip_tables: &[&str], + skip_tables: &[String], new_table: &String, ) -> Result { Ok(expr .transform(|expr| { if let Expr::Column(c) = &expr { if let Some(relation) = &c.relation { - if !skip_tables.contains(&relation.table()) { + if !skip_tables.contains(&relation.table().to_string()) { return Ok(Transformed::yes(Expr::Column( c.with_relation(TableReference::bare(new_table.clone())), ))); @@ -266,27 +255,70 @@ fn can_pull_up(expr: &Expr) -> bool { } } +#[derive(Debug, Hash, PartialEq, PartialOrd, Eq, Clone)] +struct PulledUpExpr { + expr: Expr, + // multiple expr can be pulled up at a time, and because multiple subquery exists + // at the same level, we need to track which subquery the pulling up is happening for + subquery_node_id: usize, +} + struct SimpleDecorrelationResult { - // new: Option, - // if projection pull up happened, each will be tracked, so that later on general decorrelation - // can rewrite them (a.k.a outer ref column maybe renamed/substituted some where in the parent already - // because the decorrelation is top-down) - pulled_up_projections: IndexSet, - pulled_up_predicates: Vec, + pulled_up_projections: IndexSet, + pulled_up_predicates: Vec, } impl SimpleDecorrelationResult { + // fn get_decorrelated_subquery_node_ids(&self) -> Vec { + // self.pulled_up_predicates + // .iter() + // .map(|e| e.subquery_node_id) + // .chain( + // self.pulled_up_projections + // .iter() + // .map(|e| e.subquery_node_id), + // ) + // .unique() + // .collect() + // // node_ids.extend( + // // self.pulled_up_projections + // // .iter() + // // .map(|e| e.subquery_node_id), + // // ); + // // node_ids.into_iter().unique().collect() + // } + // because we don't track which expr was pullled up for which relation to give alias for fn rewrite_all_pulled_up_expr( &mut self, - alias_name: &String, - outer_relations: &[&str], + subquery_node_alias_map: &IndexMap, + outer_relations: &[String], ) -> Result<()> { + let alias_by_subquery_node_id: IndexMap = subquery_node_alias_map + .iter() + .map(|(alias, node)| (node.id, alias)) + .collect(); for expr in self.pulled_up_predicates.iter_mut() { - *expr = replace_col_base_table(expr.clone(), &outer_relations, alias_name)?; + let alias = alias_by_subquery_node_id + .get(&expr.subquery_node_id) + .unwrap(); + expr.expr = + replace_col_base_table(expr.expr.clone(), &outer_relations, *alias)?; } let rewritten_projections = self .pulled_up_projections .iter() - .map(|e| replace_col_base_table(e.clone(), &outer_relations, alias_name)) + .map(|expr| { + let alias = alias_by_subquery_node_id + .get(&expr.subquery_node_id) + 
.unwrap(); + Ok(PulledUpExpr { + subquery_node_id: expr.subquery_node_id, + expr: replace_col_base_table( + expr.expr.clone(), + &outer_relations, + *alias, + )?, + }) + }) .collect::>>()?; self.pulled_up_projections = rewritten_projections; Ok(()) @@ -298,7 +330,7 @@ fn extract_join_metadata_from_subquery( sq: &Subquery, subquery_projected_exprs: &[Expr], alias: &String, - outer_relations: &[&str], + outer_relations: &[String], ) -> Result<(bool, Option, Option)> { let mut post_join_predicate = None; @@ -434,7 +466,7 @@ impl DependentJoinTracker { _ => false, } } - fn is_linear_path(&self, parent: &Operator, child: &Operator) -> bool { + fn is_linear_path(&self, parent: &Node, child: &Node) -> bool { if !self.is_linear_operator(&child.plan) { return false; } @@ -464,39 +496,40 @@ impl DependentJoinTracker { }; } } - fn remove_node(&mut self, parent: &mut Operator, node: &mut Operator) { + fn remove_node(&mut self, parent: &mut Node, node: &mut Node) { let next_children = node.children.first().unwrap(); let next_children_node = self.nodes.swap_remove(next_children).unwrap(); // let next_children_node = self.nodes.get_mut(next_children).unwrap(); *node = next_children_node; node.parent = Some(parent.id); } - // decorrelate all descendant(recursively) with simple unnesting - // returns true if all children were eliminated - // TODO(impl me) + + // decorrelate all descendant with simple unnesting + // this function will remove corresponding entry in root_node.access_tracker if applicable + // , so caller can rely on the length of this field to detect if simple decorrelation is enough + // and the decorrelation can stop using "simple method". + // It also does the in-place update to + // + // TODO: this is not yet recursive, but theoreically nested subqueries + // can be decorrelated using simple method as long as they are independent + // with each other fn try_simple_decorrelate_descendent( &mut self, - root_node: &mut Operator, - child_node: &mut Operator, + root_node: &mut Node, + child_node: &mut Node, col_access: &ColumnAccess, result: &mut SimpleDecorrelationResult, ) -> Result<()> { - // unnest children first - // println!("decorrelating {} from {}", child, root); - if !self.is_linear_path(root_node, child_node) { - // TODO: return Ok(()); } - - // TODO: inplace update - // let mut child_node = self.nodes.swap_remove(child).unwrap().clone(); - // let mut root_node = self.nodes.swap_remove(root).unwrap(); + // offest 0 (root) is dependent join node, will immediately followed by subquery node + let subquery_node_id = col_access.stack[1]; match &mut child_node.plan { LogicalPlan::Projection(proj) => { - // TODO: handle the case outer_ref_a + outer_ref_b??? - // if we only see outer_ref_a and decide to move the whole expr + // TODO: handle the case select binary_expr(outer_ref_a, outer_ref_b) ??? 
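+                // e.g. a subquery projection like `outer_ref(outer_table.a) + outer_ref(outer_table.b)`: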
+ // if we only see outer_ref_a and decide to pull up the whole expr here // outer_ref_b is accidentally pulled up let pulled_up_expr: IndexSet<_> = proj .expr @@ -505,7 +538,7 @@ impl DependentJoinTracker { proj_expr .exists(|expr| { if let Expr::OuterReferenceColumn(_, col) = expr { - root_node.access_tracker.remove(col_access); + root_node.access_tracker.swap_remove(col_access); return Ok(*col == col_access.col); } Ok(false) @@ -517,7 +550,10 @@ impl DependentJoinTracker { if !pulled_up_expr.is_empty() { for expr in pulled_up_expr.iter() { - result.pulled_up_projections.insert(expr.clone()); + result.pulled_up_projections.insert(PulledUpExpr { + expr: expr.clone(), + subquery_node_id, + }); } // all expr of this node is pulled up, fully remove this node from the tree if proj.expr.len() == pulled_up_expr.len() { @@ -546,10 +582,17 @@ impl DependentJoinTracker { let (pulled_up, kept): (Vec<_>, Vec<_>) = subquery_filter_exprs .iter() .cloned() + // NOTE: if later on we decide to support nested subquery inside this function + // (i.e multiple subqueries exist in the stack) + // the call to e.contains_outer must be aware of which subquery it is checking for:w .partition(|e| e.contains_outer() && can_pull_up(e)); - // only remove the access tracker if non of the kept expr contains reference to the column + // only remove the access tracker if none of the kept expr contains reference to the column // i.e some of the remaining expr still reference to the column and not pullable + // For example where outer.col_a=1 and outer.col_a=(some nested subqueries) + // in this case outer.col_a=1 is pull up, but the access tracker must remain + // so later on we can tell "simple approach" is not enough, and continue with + // the "general approach". let removable = kept.iter().all(|e| { !e.exists(|e| { if let Expr::Column(col) = e { @@ -562,7 +605,12 @@ impl DependentJoinTracker { if removable { root_node.access_tracker.swap_remove(col_access); } - result.pulled_up_predicates.extend(pulled_up); + result + .pulled_up_predicates + .extend(pulled_up.iter().map(|e| PulledUpExpr { + expr: e.clone(), + subquery_node_id, + })); if kept.is_empty() { self.remove_node(root_node, child_node); return Ok(()); @@ -570,6 +618,8 @@ impl DependentJoinTracker { filter.predicate = conjunction(kept).unwrap(); } + // TODO: nested subqueries can also be linear with each other + // i.e select expr, (subquery1) where expr = subquery2 // LogicalPlan::Subquery(sq) => { // let descendent_id = child_node.children.get(0).unwrap(); // let mut descendent_node = self.nodes.get(descendent_id).unwrap().clone(); @@ -580,11 +630,11 @@ impl DependentJoinTracker { // )?; // self.nodes.insert(*descendent_id, descendent_node); // } - _ => { - // unimplemented!( - // "simple unnest is missing for this operator {}", - // child_node.plan - // ) + unsupported => { + unimplemented!( + "simple unnest is missing for this operator {}", + unsupported + ) } }; @@ -593,7 +643,7 @@ impl DependentJoinTracker { fn general_decorrelate( &mut self, - node: &mut Operator, + node: &mut Node, unnesting: &mut Unnesting, outer_refs_from_parent: &mut IndexSet, ) -> Result<()> { @@ -817,13 +867,13 @@ impl DependentJoinTracker { // } // Ok(()) } - fn right_owned(&mut self, node: &Operator) -> Operator { + fn right_owned(&mut self, node: &Node) -> Node { assert_eq!(2, node.children.len()); // during the building of the tree, the subquery (right node) is always traversed first let node_id = node.children.first().unwrap(); return self.nodes.swap_remove(node_id).unwrap(); 
} - fn left_owned(&mut self, node: &Operator) -> Operator { + fn left_owned(&mut self, node: &Node) -> Node { assert_eq!(2, node.children.len()); // during the building of the tree, the subquery (right node) is always traversed first let node_id = node.children.last().unwrap(); @@ -843,13 +893,6 @@ impl DependentJoinTracker { domain: node.plan.clone(), // dummy }; - let mut outer_refs = node.access_tracker.clone(); - // let unnesting = Unnesting { - // info: Arc::new(unnesting), - // equivalences: UnionFind::new(), - // replaces: IndexMap::new(), - // }; - self.dependent_join_elimination(node.id, &unnesting_info, &mut IndexSet::new()) } @@ -857,149 +900,224 @@ impl DependentJoinTracker { let node = self.nodes.get(&node_id).unwrap(); node.access_tracker.iter().collect() } - // fn new_dependent_join(&self, node: &Operator) -> DependentJoin { - // DependentJoin { - // original_expr: node.plan.clone(), - // left: self.left(node).clone(), - // right: self.right(node).clone(), - // join_conditions: vec![], - // } - // } - fn get_subquery_children( + fn get_children_subquery_ids(&self, node: &Node) -> Vec { + return node.children[..node.children.len() - 1].to_owned(); + } + + fn get_subquery_info( &self, - parent: &Operator, + parent: &Node, // because one dependent join node can have multiple subquery at a time sq_offset: usize, ) -> Result<(LogicalPlan, Subquery, SubqueryType)> { let subquery = parent.children.get(sq_offset).unwrap(); let sq_node = self.nodes.get(subquery).unwrap(); assert!(sq_node.is_subquery_node); - let query = sq_node.children.get(0).unwrap(); + let query = sq_node.children.first().unwrap(); let target_node = self.nodes.get(query).unwrap(); // let op = .clone(); if let LogicalPlan::Subquery(subquery) = sq_node.plan.clone() { - return Ok((target_node.plan.clone(), subquery, sq_node.subquery_type)); + Ok((target_node.plan.clone(), subquery, sq_node.subquery_type)) } else { - internal_err!("") + internal_err!( + "object construction error: subquery.plan is not with type Subquery" + ) } } - fn build_join_from_simple_decorrelation_result( + // this function is aware that multiple subqueries may exist inside the filter predicate + // and it tries it best to decorrelate all possible exprs, while leave the un-correlatable + // expr untouched + // + // Example of such expression + // `select * from outer_table where exists(select * from inner_table where ...) & col_b < complex_subquery` + // the relationship tree looks like this + // [1]dependent_join_node (filter exists(select * from inner_table where ...) 
& col_b < complex_subquery) + // | + // |- [2]simple_subquery + // |- [3]complex_subquery + // |- [4]outer_table scan + // After decorrelation, the relationship tree may be translated using 2 approaches + // Approach 1: Replace the left side of the join using the new input + // [1]dependent_join_node (filter col_b < complex_subquery) + // | + // |- [2]REMOVED + // |- [3]complex_subquery + // |- [4]markjoin <-------- This was modified + // |-outer_table scan + // |-inner_table scan + // + // Approach 2: Keep everything except for the decorrelated expressions, + // and add a new join above the original dependent join + // [NEW_NODE_ID] markjoin <----------------- This was added + // |-inner_table scan + // |-[1]dependent_join_node (filter col_b < complex_subquery) + // | + // |- [2]REMOVED + // |- [3]complex_subquery + // |- [4]outer_table scan + // The following uses approach 2 + // + // This function will returns a new Node object that is supposed to be the new root of the tree + fn build_join_from_simple_decorrelation_result_filter( &self, - dependent_join_node: &mut Operator, - mut ret: SimpleDecorrelationResult, - ) -> Result { - let (subquery_children, subquery, sq_type) = - self.get_subquery_children(dependent_join_node, 0)?; - let outer_relations: Vec<&str> = dependent_join_node - .correlated_relations + dependent_join_node: &mut Node, + outer_relations: &[String], + ret: &mut SimpleDecorrelationResult, + mut filter: Filter, + ) -> Result { + let subquery_node_ids = self.get_children_subquery_ids(dependent_join_node); + let subquery_node_alias_map: IndexMap = subquery_node_ids .iter() - .map(String::as_str) + .map(|id| { + let subquery_node = self.nodes.get(id).unwrap(); + let subquery_alias = self + .alias_generator + .next(&subquery_node.subquery_type.prefix()); + (subquery_alias, subquery_node) + }) .collect(); - match dependent_join_node.plan { - LogicalPlan::Filter(ref mut filter) => { - let predicate_expr = split_conjunction(&filter.predicate); - let mut join_predicates = vec![]; - let mut post_join_predicates = vec![]; - // maybe we also need to collect join columns here - // TODO: we need to also pull up projectoin to support subqueries that appear - // in select expressions - let pulled_projection: Vec = ret - .pulled_up_projections + ret.rewrite_all_pulled_up_expr(&subquery_node_alias_map, &outer_relations)?; + for (subquery_alias, subquery_node) in subquery_node_alias_map.iter() { + let input_plan = filter.input.as_ref().clone(); + let mut join_predicates = vec![]; + let mut post_join_predicates = vec![]; // this loop heavily assume that all subqueries belong to the same `dependent_join_node` + let sq_type = subquery_node.subquery_type; + let subquery = if let LogicalPlan::Subquery(subquery) = &subquery_node.plan { + Ok(subquery) + } else { + internal_err!( + "object construction error: subquery.plan is not with type Subquery" + ) + }?; + let subquery_children = self + .nodes + .get(subquery_node.children.first().unwrap()) + .unwrap() + .plan + .clone(); + + let predicate_expr = split_conjunction(&filter.predicate); + + // maybe we also need to collect join columns here + // TODO: we need to also pull up projectoin to support subqueries that appear + // in select expressions + let pulled_projection: Vec = ret + .pulled_up_projections + .iter() + .cloned() + .map(|pe| strip_outer_reference(pe.expr)) + .collect(); + let right_exprs: Vec = if ret.pulled_up_projections.is_empty() { + subquery_children.expressions() + } else { + ret.pulled_up_projections .iter() .cloned() - 
.map(strip_outer_reference) - .collect(); - let right_exprs: Vec = if ret.pulled_up_projections.is_empty() { - subquery_children.expressions() - } else { - ret.pulled_up_projections - .iter() - .cloned() - .map(strip_outer_reference) - .collect() - }; - let mut join_type = sq_type.default_join_type(); - let alias_name = self.alias_generator.next(&sq_type.prefix()).to_string(); - ret.rewrite_all_pulled_up_expr(&alias_name, &outer_relations)?; + .map(|pe| strip_outer_reference(pe.expr)) + .collect() + }; + let mut join_type = sq_type.default_join_type(); - for expr in predicate_expr.into_iter() { - // exist query may not have any transformed expr - // i.e where exists(suquery) => semi join - let (transformed, maybe_join_predicate, maybe_post_join_predicate) = - extract_join_metadata_from_subquery( - expr, - &subquery, - &right_exprs, - &alias_name, - &outer_relations, - )?; + for expr in predicate_expr.into_iter() { + // exist query may not have any transformed expr + // i.e where exists(suquery) => semi join + let (transformed, maybe_join_predicate, maybe_post_join_predicate) = + extract_join_metadata_from_subquery( + expr, + &subquery, + &right_exprs, + &subquery_alias, + &outer_relations, + )?; - if let Some(transformed) = maybe_join_predicate { - println!("join predicate is {}", transformed.clone()); - join_predicates.push(transformed) - } - if let Some(post_join_expr) = maybe_post_join_predicate { - if post_join_expr - .exists(|e| { - if let Expr::Column(col) = e { - return Ok(col.name == "mark"); - } - return Ok(false); - }) - .unwrap() - { - // only use mark join if required - join_type = JoinType::LeftMark - } - post_join_predicates.push(post_join_expr) - } - if !transformed { - post_join_predicates.push(expr.clone()) + if let Some(transformed) = maybe_join_predicate { + join_predicates.push(transformed) + } + if let Some(post_join_expr) = maybe_post_join_predicate { + if post_join_expr + .exists(|e| { + if let Expr::Column(col) = e { + return Ok(col.name == "mark"); + } + return Ok(false); + }) + .unwrap() + { + // only use mark join if required + join_type = JoinType::LeftMark } + post_join_predicates.push(post_join_expr) } + if !transformed { + post_join_predicates.push(expr.clone()) + } + } + let new_predicates = ret + .pulled_up_predicates + .iter() + .map(|e| strip_outer_reference(e.expr.clone())); - let new_predicates = ret - .pulled_up_predicates - .iter() - .map(|e| strip_outer_reference(e.clone())); - - join_predicates.extend(new_predicates); - // TODO: some predicate is join predicate, some is just filter - // kept_predicates.extend(new_predicates); - // filter.predicate = conjunction(kept_predicates).unwrap(); - // left + join_predicates.extend(new_predicates); - let mut right = LogicalPlanBuilder::new(subquery_children) - .alias(&alias_name)? - .build()?; - let mut builder = LogicalPlanBuilder::new(filter.input.deref().clone()); + let mut right = LogicalPlanBuilder::new(subquery_children) + .alias(subquery_alias)? + .build()?; + let mut builder = LogicalPlanBuilder::new(*filter.input); - builder = if join_predicates.is_empty() { - builder.join_on(right, join_type, vec![lit(true)])? - } else { - builder.join_on( - right, - // TODO: join type based on filter condition - join_type, - join_predicates, - )? - }; + builder = if join_predicates.is_empty() { + builder.join_on(right, join_type, vec![lit(true)])? + } else { + builder.join_on( + right, + // TODO: join type based on filter condition + join_type, + join_predicates, + )? 
+ }; - if post_join_predicates.len() > 0 { - builder = - builder.filter(conjunction(post_join_predicates).unwrap())? - } - builder.build() + if post_join_predicates.len() > 0 { + builder = builder.filter(conjunction(post_join_predicates).unwrap())? } + let temp_plan = builder.build()?; + filter.input = Arc::new(temp_plan); + // self.remove_node(parent, node); + // TODO: filter predicate is kept + // remove this subquery node from the map + // remove this subquery node from the current dependent join node + // update the dependent join node input + println!("temp plan\n{}", plan); + } + Ok(plan) + } + + fn build_join_from_simple_decorrelation_result( + &self, + dependent_join_node: &mut Node, + ret: &mut SimpleDecorrelationResult, + ) -> Result { + let outer_relations: Vec = dependent_join_node + .correlated_relations + .iter() + .cloned() + .collect(); + + match dependent_join_node.plan.clone() { + LogicalPlan::Filter(filter) => self + .build_join_from_simple_decorrelation_result_filter( + dependent_join_node, + &outer_relations, + ret, + filter, + ), _ => { unimplemented!() } } } - fn build_domain(&self, node: &Operator, left: &Operator) -> Result { + fn build_domain(&self, node: &Node, left: &Node) -> Result { let unique_outer_refs: Vec = node .access_tracker .iter() @@ -1023,39 +1141,55 @@ impl DependentJoinTracker { fn dependent_join_elimination( &mut self, - node: usize, + dependent_join_node_id: usize, unnesting: &UnnestingInfo, outer_refs_from_parent: &mut IndexSet, ) -> Result { let parent = unnesting.parent.clone(); - let mut root_node = self.nodes.swap_remove(&node).unwrap(); - let simple_unnest_result = self.simple_decorrelation(&mut root_node)?; - let (original_subquery, _, subquery_type) = - self.get_subquery_children(&root_node, 0)?; - if root_node.access_tracker.is_empty() { + let mut dependent_join_node = + self.nodes.swap_remove(&dependent_join_node_id).unwrap(); + + assert!(dependent_join_node.is_dependent_join_node); + + let mut simple_unnesting = SimpleDecorrelationResult { + pulled_up_predicates: vec![], + pulled_up_projections: IndexSet::new(), + }; + + self.simple_decorrelation(&mut dependent_join_node, &mut simple_unnesting)?; + if dependent_join_node.access_tracker.is_empty() { if parent.is_some() { // for each projection of outer column moved up by simple_decorrelation // replace them with the expr store inside parent.replaces unimplemented!("simple dependent join not implemented for the case of recursive subquery"); self.general_decorrelate( - &mut root_node, + &mut dependent_join_node, &mut parent.unwrap(), outer_refs_from_parent, )?; - return Ok(root_node.plan.clone()); + return Ok(dependent_join_node.plan.clone()); } return self.build_join_from_simple_decorrelation_result( - &mut root_node, - simple_unnest_result, + &mut dependent_join_node, + &mut simple_unnesting, ); - unimplemented!() - // return Ok(dependent_join); + } else { + // TODO: some of the expr was removed and expect to be pulled up in a best effort fashion + // (i.e partially decorrelate) + } + if self.get_children_subquery_ids(&dependent_join_node).len() > 1 { + unimplemented!( + "general decorrelation for multiple subqueries in the same node" + ) } + // for children_offset in self.get_children_subquery_ids(&dependent_join_node) { + let (original_subquery, _, subquery_type) = + self.get_subquery_info(&dependent_join_node, 0)?; // let mut join = self.new_dependent_join(&root_node); // TODO: handle the case where one dependent join node contains multiple subqueries - let mut left = 
self.left_owned(&root_node); - let mut right = self.right_owned(&root_node); + let mut left = self.left_owned(&dependent_join_node); + let mut right = self.right_owned(&dependent_join_node); if parent.is_some() { unimplemented!(""); // i.e exists (where inner.col_a = outer_col.b and other_nested_subquery...) @@ -1082,7 +1216,7 @@ impl DependentJoinTracker { // TODO: after imple simple_decorrelation, rewrite the projection pushed up column as well } let domain = match parent { - None => self.build_domain(&root_node, &left)?, + None => self.build_domain(&dependent_join_node, &left)?, Some(info) => { unimplemented!() } @@ -1091,13 +1225,10 @@ impl DependentJoinTracker { let new_unnesting_info = UnnestingInfo { parent: parent.clone(), domain, - // join: join.clone(), - // domain: vec![], // TODO: populate me }; let mut unnesting = Unnesting { original_subquery, info: Arc::new(new_unnesting_info.clone()), - join_conditions: vec![], equivalences: UnionFind { parent: IndexMap::new(), rank: IndexMap::new(), @@ -1105,12 +1236,13 @@ impl DependentJoinTracker { replaces: IndexMap::new(), pulled_up_columns: vec![], pulled_up_predicates: vec![], - count_exprs_detected: IndexSet::new(), // outer_col_ref_map: HashMap::new(), - need_handle_count_bug: true, // TODO + count_exprs_detected: IndexSet::new(), + need_handle_count_bug: true, // TODO subquery_type, decorrelated_subquery: None, }; - let mut accesses: IndexSet = root_node.access_tracker.clone(); + let mut accesses: IndexSet = + dependent_join_node.access_tracker.clone(); // .iter() // .map(|a| a.col.clone()) // .collect(); @@ -1130,12 +1262,13 @@ impl DependentJoinTracker { //TODO: add equivalences from join.condition to unnest.cclasses self.general_decorrelate(&mut right, &mut unnesting, &mut accesses)?; let decorrelated_plan = self.build_join_from_general_unnesting_info( - &mut root_node, + &mut dependent_join_node, &mut left, &mut right, unnesting, )?; return Ok(decorrelated_plan); + // } // self.nodes.insert(left.id, left); // self.nodes.insert(right.id, right); @@ -1150,9 +1283,9 @@ impl DependentJoinTracker { fn build_join_from_general_unnesting_info( &self, - dependent_join_node: &mut Operator, - left_node: &mut Operator, - decorrelated_right_node: &mut Operator, + dependent_join_node: &mut Node, + left_node: &mut Node, + decorrelated_right_node: &mut Node, mut unnesting: Unnesting, ) -> Result { let subquery = unnesting.decorrelated_subquery.take().unwrap(); @@ -1160,10 +1293,10 @@ impl DependentJoinTracker { let subquery_type = unnesting.subquery_type; let alias = self.alias_generator.next(&subquery_type.prefix()); - let outer_relations: Vec<&str> = dependent_join_node + let outer_relations: Vec = dependent_join_node .correlated_relations .iter() - .map(String::as_str) + .cloned() .collect(); unnesting.rewrite_all_pulled_up_expr(&alias, &outer_relations)?; @@ -1276,21 +1409,16 @@ impl DependentJoinTracker { } Ok(()) } - fn get_node_uncheck(&self, node_id: &usize) -> Operator { + fn get_node_uncheck(&self, node_id: &usize) -> Node { self.nodes.get(node_id).unwrap().clone() } fn simple_decorrelation( &mut self, - node: &mut Operator, - ) -> Result { - let mut result = SimpleDecorrelationResult { - // new: None, - pulled_up_projections: IndexSet::new(), - pulled_up_predicates: vec![], - }; - - // the iteration should happen with the order of bottom up, so any node push up won't + node: &mut Node, + simple_unnesting: &mut SimpleDecorrelationResult, + ) -> Result<()> { + // the iteration should happen with the order of bottom up, so any node 
pull up won't // affect its children (by accident) let accesses_bottom_up = node.access_tracker.clone().sorted_by(|a, b| { if a.node_id < b.node_id { @@ -1308,14 +1436,14 @@ impl DependentJoinTracker { node, &mut descendent, &col_access, - &mut result, + simple_unnesting, )?; // TODO: find a nicer way to do in-place update // self.nodes.insert(node_id, parent_node.clone()); self.nodes.insert(col_access.node_id, descendent); } - Ok(result) + Ok(()) } } @@ -1514,7 +1642,7 @@ impl ColumnAccess { } } #[derive(Debug, Clone)] -struct Operator { +struct Node { id: usize, plan: LogicalPlan, parent: Option, @@ -1601,7 +1729,9 @@ impl TreeNodeVisitor<'_> for DependentJoinTracker { if contains_subquery(&f.predicate) { is_dependent_join_node = true; } + println!("debug predicate {}", f.predicate); f.predicate.outer_column_refs().into_iter().for_each(|f| { + println!("outer column ref {}", f); self.mark_column_access(self.current_id, f); }); } @@ -1674,7 +1804,7 @@ impl TreeNodeVisitor<'_> for DependentJoinTracker { self.stack.push(self.current_id); self.nodes.insert( self.current_id, - Operator { + Node { id: self.current_id, parent, plan: node.clone(), @@ -1737,6 +1867,59 @@ mod tests { use super::DependentJoinTracker; use arrow::datatypes::DataType as ArrowDataType; + #[test] + fn simple_1_level_subquery_in_from_expr() -> Result<()> { + unimplemented!() + } + #[test] + fn simple_1_level_subquery_in_selection_expr() -> Result<()> { + unimplemented!() + } + #[test] + fn complex_1_level_decorrelate_2_subqueries_at_the_same_level() -> Result<()> { + unimplemented!() + } + #[test] + fn simple_1_level_decorrelate_2_subqueries_at_the_same_level() -> Result<()> { + let outer_table = test_table_scan_with_name("outer_table")?; + let inner_table_lv1 = test_table_scan_with_name("inner_table_lv1")?; + let in_sq_level1 = Arc::new( + LogicalPlanBuilder::from(inner_table_lv1.clone()) + .filter(col("inner_table_lv1.c").eq(lit(2)))? + .project(vec![col("inner_table_lv1.a")])? + .build()?, + ); + let exist_sq_level1 = Arc::new( + LogicalPlanBuilder::from(inner_table_lv1) + .filter( + col("inner_table_lv1.a").and(col("inner_table_lv1.b").eq(lit(1))), + )? + .build()?, + ); + + let input1 = LogicalPlanBuilder::from(outer_table.clone()) + .filter( + col("outer_table.a") + .gt(lit(1)) + .and(exists(exist_sq_level1)) + .and(in_subquery(col("outer_table.b"), in_sq_level1)), + )? 
+ .build()?; + let mut index = DependentJoinTracker::new(Arc::new(AliasGenerator::new())); + index.build(input1)?; + println!("{:?}", index); + let new_plan = index.root_dependent_join_elimination()?; + println!("{}", new_plan); + let expected = "\ + Filter: outer_table.a > Int32(1) AND __exists_sq_1.mark\ + \n LeftMark Join: Filter: __exists_sq_1.a = outer_table.a AND outer_table.a > __exists_sq_1.c AND outer_table.b = __exists_sq_1.b\ + \n TableScan: outer_table\ + \n SubqueryAlias: __exists_sq_1\ + \n Filter: inner_table_lv1.b = Int32(1)\ + \n TableScan: inner_table_lv1"; + assert_eq!(expected, format!("{new_plan}")); + Ok(()) + } #[test] fn complex_1_level_decorrelate_in_subquery_with_count() -> Result<()> { @@ -1930,3 +2113,21 @@ mod tests { Ok(()) } } + +// filter col < subquery1 & col < subquery 2 +// 1.subquery +// (table inner scan) +// ------------------ +// post joint +// join +// table scan +// inner table scan +// items todo: +// create a new plan, set this new plan = parent's input +// replace parent's last children with this plan + +// create new operator and replace parent's last children +// maybe invoke indexing for this new branch + +// 2.subquery2 +// 3.table scan diff --git a/datafusion/sqllogictest/test_files/debug_count.slt b/datafusion/sqllogictest/test_files/debug_count.slt new file mode 100644 index 000000000000..d52df0afba83 --- /dev/null +++ b/datafusion/sqllogictest/test_files/debug_count.slt @@ -0,0 +1,116 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# make sure to a batch size smaller than row number of the table. 
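+# (each table created below holds 4 rows, so a batch size of 2 forces execution to
+# span multiple record batches)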
+statement ok +set datafusion.execution.batch_size = 2; + +############# +## Subquery Tests +############# + + +############# +## Setup test data table +############# +# there tables for subquery +statement ok +CREATE TABLE t0(t0_id INT, t0_name TEXT, t0_int INT) AS VALUES +(11, 'o', 6), +(22, 'p', 7), +(33, 'q', 8), +(44, 'r', 9); + +statement ok +CREATE TABLE t1(t1_id INT, t1_name TEXT, t1_int INT) AS VALUES +(11, 'a', 1), +(22, 'b', 2), +(33, 'c', 3), +(44, 'd', 4); + +statement ok +CREATE TABLE t2(t2_id INT, t2_name TEXT, t2_int INT) AS VALUES +(11, 'z', 3), +(22, 'y', 1), +(44, 'x', 3), +(55, 'w', 3); + +statement ok +CREATE TABLE t3(t3_id INT PRIMARY KEY, t3_name TEXT, t3_int INT) AS VALUES +(11, 'e', 3), +(22, 'f', 1), +(44, 'g', 3), +(55, 'h', 3); + +statement ok +CREATE EXTERNAL TABLE IF NOT EXISTS customer ( + c_custkey BIGINT, + c_name VARCHAR, + c_address VARCHAR, + c_nationkey BIGINT, + c_phone VARCHAR, + c_acctbal DECIMAL(15, 2), + c_mktsegment VARCHAR, + c_comment VARCHAR, +) STORED AS CSV LOCATION '../core/tests/tpch-csv/customer.csv' OPTIONS ('format.delimiter' ',', 'format.has_header' 'true'); + +statement ok +CREATE EXTERNAL TABLE IF NOT EXISTS orders ( + o_orderkey BIGINT, + o_custkey BIGINT, + o_orderstatus VARCHAR, + o_totalprice DECIMAL(15, 2), + o_orderdate DATE, + o_orderpriority VARCHAR, + o_clerk VARCHAR, + o_shippriority INTEGER, + o_comment VARCHAR, +) STORED AS CSV LOCATION '../core/tests/tpch-csv/orders.csv' OPTIONS ('format.delimiter' ',', 'format.has_header' 'true'); + +statement ok +CREATE EXTERNAL TABLE IF NOT EXISTS lineitem ( + l_orderkey BIGINT, + l_partkey BIGINT, + l_suppkey BIGINT, + l_linenumber INTEGER, + l_quantity DECIMAL(15, 2), + l_extendedprice DECIMAL(15, 2), + l_discount DECIMAL(15, 2), + l_tax DECIMAL(15, 2), + l_returnflag VARCHAR, + l_linestatus VARCHAR, + l_shipdate DATE, + l_commitdate DATE, + l_receiptdate DATE, + l_shipinstruct VARCHAR, + l_shipmode VARCHAR, + l_comment VARCHAR, +) STORED AS CSV LOCATION '../core/tests/tpch-csv/lineitem.csv' OPTIONS ('format.delimiter' ',', 'format.has_header' 'true'); + + +#correlated_scalar_subquery_count_agg +query TT +explain SELECT t1_id, (SELECT count(*) FROM t2 WHERE t2.t2_int = t1.t1_int) from t1 +---- +logical_plan +01)Projection: t1.t1_id, CASE WHEN __scalar_sq_1.__always_true IS NULL THEN Int64(0) ELSE __scalar_sq_1.count(*) END AS count(*) +02)--Left Join: t1.t1_int = __scalar_sq_1.t2_int +03)----TableScan: t1 projection=[t1_id, t1_int] +04)----SubqueryAlias: __scalar_sq_1 +05)------Projection: count(Int64(1)) AS count(*), t2.t2_int, Boolean(true) AS __always_true +06)--------Aggregate: groupBy=[[t2.t2_int]], aggr=[[count(Int64(1))]] +07)----------TableScan: t2 projection=[t2_int] From 67923d4cb6f136b6d3ed76bd86a5e0585ec5b760 Mon Sep 17 00:00:00 2001 From: Duong Cong Toai Date: Mon, 19 May 2025 06:41:02 +0200 Subject: [PATCH 20/70] feat: support multiple subqueries decorrelation untested --- .../optimizer/src/decorrelate_general.rs | 308 ++++++++++++------ 1 file changed, 209 insertions(+), 99 deletions(-) diff --git a/datafusion/optimizer/src/decorrelate_general.rs b/datafusion/optimizer/src/decorrelate_general.rs index 510ce44fed9f..98626f063bd1 100644 --- a/datafusion/optimizer/src/decorrelate_general.rs +++ b/datafusion/optimizer/src/decorrelate_general.rs @@ -289,7 +289,7 @@ impl SimpleDecorrelationResult { // because we don't track which expr was pullled up for which relation to give alias for fn rewrite_all_pulled_up_expr( &mut self, - subquery_node_alias_map: &IndexMap, + 
subquery_node_alias_map: &IndexMap, outer_relations: &[String], ) -> Result<()> { let alias_by_subquery_node_id: IndexMap = subquery_node_alias_map @@ -593,7 +593,7 @@ impl DependentJoinTracker { // in this case outer.col_a=1 is pull up, but the access tracker must remain // so later on we can tell "simple approach" is not enough, and continue with // the "general approach". - let removable = kept.iter().all(|e| { + let can_pull_up = kept.iter().all(|e| { !e.exists(|e| { if let Expr::Column(col) = e { return Ok(*col == col_access.col); @@ -602,9 +602,10 @@ impl DependentJoinTracker { }) .unwrap() }); - if removable { - root_node.access_tracker.swap_remove(col_access); + if !can_pull_up { + return Ok(()); } + root_node.access_tracker.swap_remove(col_access); result .pulled_up_predicates .extend(pulled_up.iter().map(|e| PulledUpExpr { @@ -925,6 +926,74 @@ impl DependentJoinTracker { } } + // Rewrite from + // TopNodeParent + // | + // TopNode + // |-SubqueryNode -----> This was decorelated + // | |- SubqueryInputNode + // |-SubqueryNode2 + // |-SomeTableScan + // + // Into + // TopNodeParent + // | + // NewTopNode <-------- This was added + // | + // |----TopNode + // | |-SubqueryNode2 + // | |-SomeTableScan + // | + // |----SubqueryInputNode + fn create_new_top_node<'a>( + &'a mut self, + new_plan: LogicalPlan, + current_top_node: &mut Node, + mut subquery_input_node: Node, + post_join_predicates: Option, + ) -> Result { + let mut new_node = self.new_empty_node(new_plan); + + if let Some(parent) = current_top_node.parent { + unimplemented!() + } + subquery_input_node.parent = Some(new_node.id); + new_node.children = vec![current_top_node.id, subquery_input_node.id]; + let mut node_id = new_node.id; + if let Some(expr) = post_join_predicates { + let new_plan = LogicalPlanBuilder::new(new_node.plan.clone()) + .filter(expr)? + .build()?; + let new_node = self.new_empty_node(new_plan); + new_node.parent = Some(node_id); + new_node.children = vec![node_id]; + node_id = new_node.id; + } + + self.root = Some(node_id); + self.nodes + .insert(subquery_input_node.id, subquery_input_node); + + Ok(self.nodes.swap_remove(&node_id).unwrap()) + } + fn new_empty_node<'a>(&'a mut self, plan: LogicalPlan) -> &'a mut Node { + self.current_id = self.current_id + 1; + let node_id = self.current_id; + let new_node = Node { + id: node_id, + plan, + parent: None, + is_subquery_node: false, + is_dependent_join_node: false, + children: vec![], + access_tracker: IndexSet::new(), + subquery_type: SubqueryType::None, + correlated_relations: IndexSet::new(), + }; + self.nodes.insert(node_id, new_node); + self.nodes.get_mut(&node_id).unwrap() + } + // this function is aware that multiple subqueries may exist inside the filter predicate // and it tries it best to decorrelate all possible exprs, while leave the un-correlatable // expr untouched @@ -932,14 +1001,18 @@ impl DependentJoinTracker { // Example of such expression // `select * from outer_table where exists(select * from inner_table where ...) & col_b < complex_subquery` // the relationship tree looks like this - // [1]dependent_join_node (filter exists(select * from inner_table where ...) & col_b < complex_subquery) + // [0] some parent node + // | + // -[1]dependent_join_node (filter exists(select * from inner_table where ...) 
& col_b < complex_subquery) // | // |- [2]simple_subquery // |- [3]complex_subquery // |- [4]outer_table scan // After decorrelation, the relationship tree may be translated using 2 approaches // Approach 1: Replace the left side of the join using the new input - // [1]dependent_join_node (filter col_b < complex_subquery) + // [0] some parent node + // | + // -[1]dependent_join_node (filter col_b < complex_subquery) // | // |- [2]REMOVED // |- [3]complex_subquery @@ -949,7 +1022,10 @@ impl DependentJoinTracker { // // Approach 2: Keep everything except for the decorrelated expressions, // and add a new join above the original dependent join - // [NEW_NODE_ID] markjoin <----------------- This was added + // [0] some parent node + // | + // -[NEW_NODE_ID] markjoin <----------------- This was added + // | // |-inner_table scan // |-[1]dependent_join_node (filter col_b < complex_subquery) // | @@ -958,19 +1034,27 @@ impl DependentJoinTracker { // |- [4]outer_table scan // The following uses approach 2 // - // This function will returns a new Node object that is supposed to be the new root of the tree + // If decorrelation happen, this function will returns a new Node object that is supposed to be the new root of the tree fn build_join_from_simple_decorrelation_result_filter( - &self, - dependent_join_node: &mut Node, + &mut self, + mut dependent_join_node: Node, outer_relations: &[String], ret: &mut SimpleDecorrelationResult, - mut filter: Filter, - ) -> Result { - let subquery_node_ids = self.get_children_subquery_ids(dependent_join_node); - let subquery_node_alias_map: IndexMap = subquery_node_ids + ) -> Result<()> { + let still_correlated_sq_ids: Vec = dependent_join_node + .access_tracker .iter() + .map(|ac| ac.stack[1]) + .unique() + .collect(); + + let decorrelated_sq_ids = self + .get_children_subquery_ids(&dependent_join_node) + .into_iter() + .filter(|n| still_correlated_sq_ids.contains(n)); + let subquery_node_alias_map: IndexMap = decorrelated_sq_ids .map(|id| { - let subquery_node = self.nodes.get(id).unwrap(); + let subquery_node = self.nodes.swap_remove(&id).unwrap(); let subquery_alias = self .alias_generator .next(&subquery_node.subquery_type.prefix()); @@ -979,10 +1063,43 @@ impl DependentJoinTracker { .collect(); ret.rewrite_all_pulled_up_expr(&subquery_node_alias_map, &outer_relations)?; - for (subquery_alias, subquery_node) in subquery_node_alias_map.iter() { - let input_plan = filter.input.as_ref().clone(); + let mut pullup_projection_by_sq_id: IndexMap> = ret + .pulled_up_projections + .iter() + .fold(IndexMap::>::new(), |mut acc, e| { + acc.entry(e.subquery_node_id) + .or_default() + .push(e.expr.clone()); + acc + }); + let mut pullup_predicate_by_sq_id: IndexMap> = ret + .pulled_up_predicates + .iter() + .fold(IndexMap::>::new(), |mut acc, e| { + acc.entry(e.subquery_node_id) + .or_default() + .push(e.expr.clone()); + acc + }); + let mut filter = + if let LogicalPlan::Filter(filter) = dependent_join_node.plan.clone() { + filter + } else { + return internal_err!("dependent join node is not a filter"); + }; + + let dependent_join_node_id = dependent_join_node.id; + let mut top_node = dependent_join_node; + + for (subquery_alias, subquery_node) in subquery_node_alias_map { + let subquery_input_node = self + .nodes + .swap_remove(subquery_node.children.first().unwrap()) + .unwrap(); + let subquery_input_plan = subquery_input_node.plan.clone(); let mut join_predicates = vec![]; let mut post_join_predicates = vec![]; // this loop heavily assume that all subqueries belong to 
the same `dependent_join_node` + let mut remained_predicates = vec![]; let sq_type = subquery_node.subquery_type; let subquery = if let LogicalPlan::Subquery(subquery) = &subquery_node.plan { Ok(subquery) @@ -991,34 +1108,16 @@ impl DependentJoinTracker { "object construction error: subquery.plan is not with type Subquery" ) }?; - let subquery_children = self - .nodes - .get(subquery_node.children.first().unwrap()) - .unwrap() - .plan - .clone(); + let mut join_type = sq_type.default_join_type(); let predicate_expr = split_conjunction(&filter.predicate); - // maybe we also need to collect join columns here - // TODO: we need to also pull up projectoin to support subqueries that appear - // in select expressions - let pulled_projection: Vec = ret - .pulled_up_projections - .iter() - .cloned() - .map(|pe| strip_outer_reference(pe.expr)) - .collect(); - let right_exprs: Vec = if ret.pulled_up_projections.is_empty() { - subquery_children.expressions() - } else { - ret.pulled_up_projections - .iter() - .cloned() - .map(|pe| strip_outer_reference(pe.expr)) - .collect() - }; - let mut join_type = sq_type.default_join_type(); + let pulled_up_projections = pullup_projection_by_sq_id + .swap_remove(&subquery_node.id) + .unwrap_or(vec![]); + let pulled_up_predicates = pullup_predicate_by_sq_id + .swap_remove(&subquery_node.id) + .unwrap_or(vec![]); for expr in predicate_expr.into_iter() { // exist query may not have any transformed expr @@ -1027,13 +1126,13 @@ impl DependentJoinTracker { extract_join_metadata_from_subquery( expr, &subquery, - &right_exprs, + &subquery_input_plan.expressions(), &subquery_alias, &outer_relations, )?; if let Some(transformed) = maybe_join_predicate { - join_predicates.push(transformed) + join_predicates.push(strip_outer_reference(transformed)); } if let Some(post_join_expr) = maybe_post_join_predicate { if post_join_expr @@ -1048,23 +1147,23 @@ impl DependentJoinTracker { // only use mark join if required join_type = JoinType::LeftMark } - post_join_predicates.push(post_join_expr) + post_join_predicates.push(strip_outer_reference(post_join_expr)) } if !transformed { - post_join_predicates.push(expr.clone()) + remained_predicates.push(expr.clone()); } } - let new_predicates = ret - .pulled_up_predicates - .iter() - .map(|e| strip_outer_reference(e.expr.clone())); - join_predicates.extend(new_predicates); + join_predicates + .extend(pulled_up_predicates.into_iter().map(strip_outer_reference)); + filter.predicate = conjunction(remained_predicates).unwrap(); - let mut right = LogicalPlanBuilder::new(subquery_children) + // building new join node + // let left = top_node.plan.clone(); + let mut right = LogicalPlanBuilder::new(subquery_input_plan) .alias(subquery_alias)? .build()?; - let mut builder = LogicalPlanBuilder::new(*filter.input); + let mut builder = LogicalPlanBuilder::empty(false); builder = if join_predicates.is_empty() { builder.join_on(right, join_type, vec![lit(true)])? @@ -1077,24 +1176,48 @@ impl DependentJoinTracker { )? }; - if post_join_predicates.len() > 0 { - builder = builder.filter(conjunction(post_join_predicates).unwrap())? 
- } - let temp_plan = builder.build()?; - filter.input = Arc::new(temp_plan); - // self.remove_node(parent, node); - // TODO: filter predicate is kept - // remove this subquery node from the map - // remove this subquery node from the current dependent join node - // update the dependent join node input - println!("temp plan\n{}", plan); + let new_plan = builder.build()?; + let new_top_node = self.create_new_top_node( + new_plan, + &mut top_node, + subquery_input_node, + conjunction(post_join_predicates), + // TODO: post join projection + )?; + self.nodes.insert(top_node.id, top_node); + top_node = new_top_node; + } + self.nodes.insert(top_node.id, top_node); + self.nodes.get_mut(&dependent_join_node_id).unwrap().plan = + LogicalPlan::Filter((filter)); + + Ok(()) + } + fn rewrite_node(&mut self, node_id: usize) -> Result { + let mut node = self.nodes.swap_remove(&node_id).unwrap(); + assert!( + !node.is_subquery_node, + "calling on rewrite_node while still exists subquery in the tree" + ); + if node.children.is_empty() { + return Ok(node.plan); } - Ok(plan) + let new_children = node + .children + .iter() + .map(|c| self.rewrite_node(*c)) + .collect::>>()?; + node.plan + .with_new_exprs(node.plan.expressions(), new_children) + } + + fn rewrite_from_root(&mut self) -> Result { + self.rewrite_node(self.root.unwrap()) } fn build_join_from_simple_decorrelation_result( - &self, - dependent_join_node: &mut Node, + &mut self, + mut dependent_join_node: Node, ret: &mut SimpleDecorrelationResult, ) -> Result { let outer_relations: Vec = dependent_join_node @@ -1104,13 +1227,14 @@ impl DependentJoinTracker { .collect(); match dependent_join_node.plan.clone() { - LogicalPlan::Filter(filter) => self - .build_join_from_simple_decorrelation_result_filter( + LogicalPlan::Filter(filter) => { + self.build_join_from_simple_decorrelation_result_filter( dependent_join_node, &outer_relations, ret, - filter, - ), + )?; + self.rewrite_from_root() + } _ => { unimplemented!() } @@ -1156,7 +1280,8 @@ impl DependentJoinTracker { pulled_up_projections: IndexSet::new(), }; - self.simple_decorrelation(&mut dependent_join_node, &mut simple_unnesting)?; + dependent_join_node = + self.simple_decorrelation(dependent_join_node, &mut simple_unnesting)?; if dependent_join_node.access_tracker.is_empty() { if parent.is_some() { // for each projection of outer column moved up by simple_decorrelation @@ -1169,10 +1294,7 @@ impl DependentJoinTracker { )?; return Ok(dependent_join_node.plan.clone()); } - return self.build_join_from_simple_decorrelation_result( - &mut dependent_join_node, - &mut simple_unnesting, - ); + return self.rewrite_from_root(); } else { // TODO: some of the expr was removed and expect to be pulled up in a best effort fashion // (i.e partially decorrelate) @@ -1413,11 +1535,17 @@ impl DependentJoinTracker { self.nodes.get(node_id).unwrap().clone() } + // Decorrelate the current node using `simple` approach. + // It will consume the node and returns a new node where the decorrelatoin should continue + // using `general` approach, should `simple` approach is not sufficient. 
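+    // For illustration (roughly, based on the tests in this module): a predicate such as
+    //   outer_table.a > 1 AND EXISTS (SELECT * FROM inner_table_lv1 WHERE inner_table_lv1.a = outer_table.a)
+    // can be handled here by pulling the correlated comparison up into the LeftMark join
+    // condition, while a scalar aggregate such as
+    //   outer_table.c < (SELECT count(*) FROM inner_table_lv1 WHERE inner_table_lv1.b = outer_table.b)
+    // still needs the `general` approach, because the count has to account for unmatched
+    // outer rows (see the UN_MATCHED_ROW_INDICATOR / CASE WHEN handling).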
+ // Most of the time the same Node is returned, avoid using &mut Node because of borrow checker + // Beware that after calling this function, the root node may be changed (as new join node being added to the top) fn simple_decorrelation( &mut self, - node: &mut Node, + mut node: Node, simple_unnesting: &mut SimpleDecorrelationResult, - ) -> Result<()> { + ) -> Result { + let node_id = node.id; // the iteration should happen with the order of bottom up, so any node pull up won't // affect its children (by accident) let accesses_bottom_up = node.access_tracker.clone().sorted_by(|a, b| { @@ -1433,17 +1561,17 @@ impl DependentJoinTracker { // let mut descendent = self.get_node_uncheck(&col_access.node_id); let mut descendent = self.nodes.swap_remove(&col_access.node_id).unwrap(); self.try_simple_decorrelate_descendent( - node, + &mut node, &mut descendent, &col_access, simple_unnesting, )?; // TODO: find a nicer way to do in-place update - // self.nodes.insert(node_id, parent_node.clone()); self.nodes.insert(col_access.node_id, descendent); } + self.build_join_from_simple_decorrelation_result(node, simple_unnesting)?; - Ok(()) + Ok(self.nodes.swap_remove(&node_id).unwrap()) } } @@ -2113,21 +2241,3 @@ mod tests { Ok(()) } } - -// filter col < subquery1 & col < subquery 2 -// 1.subquery -// (table inner scan) -// ------------------ -// post joint -// join -// table scan -// inner table scan -// items todo: -// create a new plan, set this new plan = parent's input -// replace parent's last children with this plan - -// create new operator and replace parent's last children -// maybe invoke indexing for this new branch - -// 2.subquery2 -// 3.table scan From 64538cc92721a523eadbb9de733d94358aaab1eb Mon Sep 17 00:00:00 2001 From: Duong Cong Toai Date: Mon, 19 May 2025 15:52:09 +0200 Subject: [PATCH 21/70] feat: correct node rewriting rule --- .../optimizer/src/decorrelate_general.rs | 75 +++++++++++-------- 1 file changed, 44 insertions(+), 31 deletions(-) diff --git a/datafusion/optimizer/src/decorrelate_general.rs b/datafusion/optimizer/src/decorrelate_general.rs index 98626f063bd1..a533181f5771 100644 --- a/datafusion/optimizer/src/decorrelate_general.rs +++ b/datafusion/optimizer/src/decorrelate_general.rs @@ -953,10 +953,9 @@ impl DependentJoinTracker { post_join_predicates: Option, ) -> Result { let mut new_node = self.new_empty_node(new_plan); + let parent = current_top_node.parent.clone(); + let previous_node_id = new_node.id; - if let Some(parent) = current_top_node.parent { - unimplemented!() - } subquery_input_node.parent = Some(new_node.id); new_node.children = vec![current_top_node.id, subquery_input_node.id]; let mut node_id = new_node.id; @@ -969,6 +968,15 @@ impl DependentJoinTracker { new_node.children = vec![node_id]; node_id = new_node.id; } + if let Some(parent) = parent { + let parent_node = self.nodes.get_mut(&parent).unwrap(); + for child_id in parent_node.children.iter_mut() { + if *child_id == previous_node_id { + *child_id = node_id; + } + } + current_top_node.parent = Some(parent); + } self.root = Some(node_id); self.nodes @@ -1040,7 +1048,8 @@ impl DependentJoinTracker { mut dependent_join_node: Node, outer_relations: &[String], ret: &mut SimpleDecorrelationResult, - ) -> Result<()> { + mut filter: Filter, + ) -> Result<(Node)> { let still_correlated_sq_ids: Vec = dependent_join_node .access_tracker .iter() @@ -1051,9 +1060,13 @@ impl DependentJoinTracker { let decorrelated_sq_ids = self .get_children_subquery_ids(&dependent_join_node) .into_iter() - .filter(|n| 
still_correlated_sq_ids.contains(n)); + .filter(|n| !still_correlated_sq_ids.contains(n)); let subquery_node_alias_map: IndexMap = decorrelated_sq_ids .map(|id| { + dependent_join_node + .children + .retain(|current_children| *current_children != id); + let subquery_node = self.nodes.swap_remove(&id).unwrap(); let subquery_alias = self .alias_generator @@ -1081,12 +1094,6 @@ impl DependentJoinTracker { .push(e.expr.clone()); acc }); - let mut filter = - if let LogicalPlan::Filter(filter) = dependent_join_node.plan.clone() { - filter - } else { - return internal_err!("dependent join node is not a filter"); - }; let dependent_join_node_id = dependent_join_node.id; let mut top_node = dependent_join_node; @@ -1188,13 +1195,18 @@ impl DependentJoinTracker { top_node = new_top_node; } self.nodes.insert(top_node.id, top_node); - self.nodes.get_mut(&dependent_join_node_id).unwrap().plan = - LogicalPlan::Filter((filter)); + let mut dependent_join_node = + self.nodes.swap_remove(&dependent_join_node_id).unwrap(); + dependent_join_node.plan = LogicalPlan::Filter((filter)); - Ok(()) + Ok(dependent_join_node) } + fn rewrite_node(&mut self, node_id: usize) -> Result { let mut node = self.nodes.swap_remove(&node_id).unwrap(); + if node.is_subquery_node { + println!("{} {}", node.id, node.plan); + } assert!( !node.is_subquery_node, "calling on rewrite_node while still exists subquery in the tree" @@ -1219,7 +1231,7 @@ impl DependentJoinTracker { &mut self, mut dependent_join_node: Node, ret: &mut SimpleDecorrelationResult, - ) -> Result { + ) -> Result { let outer_relations: Vec = dependent_join_node .correlated_relations .iter() @@ -1227,14 +1239,13 @@ impl DependentJoinTracker { .collect(); match dependent_join_node.plan.clone() { - LogicalPlan::Filter(filter) => { - self.build_join_from_simple_decorrelation_result_filter( + LogicalPlan::Filter(filter) => self + .build_join_from_simple_decorrelation_result_filter( dependent_join_node, &outer_relations, ret, - )?; - self.rewrite_from_root() - } + filter, + ), _ => { unimplemented!() } @@ -1294,6 +1305,8 @@ impl DependentJoinTracker { )?; return Ok(dependent_join_node.plan.clone()); } + self.nodes + .insert(dependent_join_node.id, dependent_join_node); return self.rewrite_from_root(); } else { // TODO: some of the expr was removed and expect to be pulled up in a best effort fashion @@ -1569,9 +1582,7 @@ impl DependentJoinTracker { // TODO: find a nicer way to do in-place update self.nodes.insert(col_access.node_id, descendent); } - self.build_join_from_simple_decorrelation_result(node, simple_unnesting)?; - - Ok(self.nodes.swap_remove(&node_id).unwrap()) + self.build_join_from_simple_decorrelation_result(node, simple_unnesting) } } @@ -1857,9 +1868,7 @@ impl TreeNodeVisitor<'_> for DependentJoinTracker { if contains_subquery(&f.predicate) { is_dependent_join_node = true; } - println!("debug predicate {}", f.predicate); f.predicate.outer_column_refs().into_iter().for_each(|f| { - println!("outer column ref {}", f); self.mark_column_access(self.current_id, f); }); } @@ -2039,12 +2048,16 @@ mod tests { let new_plan = index.root_dependent_join_elimination()?; println!("{}", new_plan); let expected = "\ - Filter: outer_table.a > Int32(1) AND __exists_sq_1.mark\ - \n LeftMark Join: Filter: __exists_sq_1.a = outer_table.a AND outer_table.a > __exists_sq_1.c AND outer_table.b = __exists_sq_1.b\ - \n TableScan: outer_table\ - \n SubqueryAlias: __exists_sq_1\ - \n Filter: inner_table_lv1.b = Int32(1)\ - \n TableScan: inner_table_lv1"; + LeftSemi Join: Filter: 
outer_table.b = __in_sq_2.a\ + \n Filter: __exists_sq_1.mark\ + \n LeftMark Join: Filter: Boolean(true)\ + \n Filter: outer_table.a > Int32(1)\ + \n TableScan: outer_table\ + \n Filter: inner_table_lv1.a AND inner_table_lv1.b = Int32(1)\ + \n TableScan: inner_table_lv1\ + \n Projection: inner_table_lv1.a\ + \n Filter: inner_table_lv1.c = Int32(2)\ + \n TableScan: inner_table_lv1"; assert_eq!(expected, format!("{new_plan}")); Ok(()) } From 957403fe927439fa18fc94e1277020cc75f4e012 Mon Sep 17 00:00:00 2001 From: Duong Cong Toai Date: Mon, 19 May 2025 17:03:00 +0200 Subject: [PATCH 22/70] fix: subquery alias --- .../optimizer/src/decorrelate_general.rs | 159 +++++++++++------- 1 file changed, 99 insertions(+), 60 deletions(-) diff --git a/datafusion/optimizer/src/decorrelate_general.rs b/datafusion/optimizer/src/decorrelate_general.rs index a533181f5771..8e8c7f592f7b 100644 --- a/datafusion/optimizer/src/decorrelate_general.rs +++ b/datafusion/optimizer/src/decorrelate_general.rs @@ -929,10 +929,10 @@ impl DependentJoinTracker { // Rewrite from // TopNodeParent // | - // TopNode + // (current_top_node) // |-SubqueryNode -----> This was decorelated - // | |- SubqueryInputNode - // |-SubqueryNode2 + // | |- (subquery_input_node) + // |-SubqueryNode2 -----> This is not yet decorrelated // |-SomeTableScan // // Into @@ -940,60 +940,112 @@ impl DependentJoinTracker { // | // NewTopNode <-------- This was added // | - // |----TopNode + // |----(current_top_node) // | |-SubqueryNode2 // | |-SomeTableScan // | - // |----SubqueryInputNode - fn create_new_top_node<'a>( + // |----(subquery_input_node) + fn create_new_join_node_on_top<'a>( &'a mut self, - new_plan: LogicalPlan, + subquery_alias: String, + join_type: JoinType, current_top_node: &mut Node, - mut subquery_input_node: Node, + subquery_input_node: Node, + join_predicates: Vec, post_join_predicates: Option, ) -> Result { - let mut new_node = self.new_empty_node(new_plan); - let parent = current_top_node.parent.clone(); - let previous_node_id = new_node.id; + self.nodes + .insert(subquery_input_node.id, subquery_input_node.clone()); + // Build the join node + let mut right = LogicalPlanBuilder::new(subquery_input_node.plan.clone()) + .alias(subquery_alias)? + .build()?; + let alias_node = self.insert_node_and_links( + right.clone(), + 0, + None, + vec![subquery_input_node.id], + ); + let right_node_id = alias_node.id; + // the left input does not matter, because later on the rewritting will happen using the pointers + // from top node, following the children using Node.chilren field + let mut builder = LogicalPlanBuilder::empty(false); + + builder = if join_predicates.is_empty() { + builder.join_on(right, join_type, vec![lit(true)])? + } else { + builder.join_on( + right, + // TODO: join type based on filter condition + join_type, + join_predicates, + )? + }; + + let join_node = builder.build()?; + + let upper_most_parent = current_top_node.parent.clone(); + let mut new_node = self.insert_node_and_links( + join_node, + current_top_node.id, + upper_most_parent, + vec![current_top_node.id, right_node_id], + ); + current_top_node.parent = Some(new_node.id); - subquery_input_node.parent = Some(new_node.id); - new_node.children = vec![current_top_node.id, subquery_input_node.id]; - let mut node_id = new_node.id; + let mut new_node_id = new_node.id; if let Some(expr) = post_join_predicates { let new_plan = LogicalPlanBuilder::new(new_node.plan.clone()) .filter(expr)? 
.build()?; - let new_node = self.new_empty_node(new_plan); - new_node.parent = Some(node_id); - new_node.children = vec![node_id]; - node_id = new_node.id; - } - if let Some(parent) = parent { - let parent_node = self.nodes.get_mut(&parent).unwrap(); - for child_id in parent_node.children.iter_mut() { - if *child_id == previous_node_id { - *child_id = node_id; - } - } - current_top_node.parent = Some(parent); + let new_node = self.insert_node_and_links( + new_plan, + new_node_id, + upper_most_parent, + vec![new_node_id], + ); + new_node_id = new_node.id; } - self.root = Some(node_id); - self.nodes - .insert(subquery_input_node.id, subquery_input_node); + self.root = Some(new_node_id); - Ok(self.nodes.swap_remove(&node_id).unwrap()) + Ok(self.nodes.swap_remove(&new_node_id).unwrap()) } - fn new_empty_node<'a>(&'a mut self, plan: LogicalPlan) -> &'a mut Node { + + // insert a new node, if any link of parent, children is mentioned + // also update the relationship in these remote nodes + fn insert_node_and_links<'a>( + &'a mut self, + plan: LogicalPlan, + // which node id in the parent should be replaced by this new node + swapped_node_id: usize, + parent: Option, + children: Vec, + ) -> &'a mut Node { self.current_id = self.current_id + 1; let node_id = self.current_id; + + // update parent + if let Some(parent_id) = parent { + for child_id in self.nodes.get_mut(&parent_id).unwrap().children.iter_mut() { + if *child_id == swapped_node_id { + *child_id = node_id; + } + } + } + for child_id in children.iter() { + if let Some(node) = self.nodes.get_mut(child_id) { + node.parent = Some(node_id); + } + } + let new_node = Node { id: node_id, plan, - parent: None, + parent, is_subquery_node: false, is_dependent_join_node: false, - children: vec![], + children, access_tracker: IndexSet::new(), subquery_type: SubqueryType::None, correlated_relations: IndexSet::new(), @@ -1103,7 +1155,7 @@ impl DependentJoinTracker { .nodes .swap_remove(subquery_node.children.first().unwrap()) .unwrap(); - let subquery_input_plan = subquery_input_node.plan.clone(); + // let subquery_input_plan = subquery_input_node.plan.clone(); let mut join_predicates = vec![]; let mut post_join_predicates = vec![]; // this loop heavily assume that all subqueries belong to the same `dependent_join_node` let mut remained_predicates = vec![]; @@ -1133,7 +1185,7 @@ impl DependentJoinTracker { extract_join_metadata_from_subquery( expr, &subquery, - &subquery_input_plan.expressions(), + &subquery_input_node.plan.expressions(), &subquery_alias, &outer_relations, )?; @@ -1167,27 +1219,12 @@ impl DependentJoinTracker { // building new join node // let left = top_node.plan.clone(); - let mut right = LogicalPlanBuilder::new(subquery_input_plan) - .alias(subquery_alias)? - .build()?; - let mut builder = LogicalPlanBuilder::empty(false); - - builder = if join_predicates.is_empty() { - builder.join_on(right, join_type, vec![lit(true)])? - } else { - builder.join_on( - right, - // TODO: join type based on filter condition - join_type, - join_predicates, - )? 
- }; - - let new_plan = builder.build()?; - let new_top_node = self.create_new_top_node( - new_plan, + let new_top_node = self.create_new_join_node_on_top( + subquery_alias, + join_type, &mut top_node, subquery_input_node, + join_predicates, conjunction(post_join_predicates), // TODO: post join projection )?; @@ -2053,11 +2090,13 @@ mod tests { \n LeftMark Join: Filter: Boolean(true)\ \n Filter: outer_table.a > Int32(1)\ \n TableScan: outer_table\ - \n Filter: inner_table_lv1.a AND inner_table_lv1.b = Int32(1)\ - \n TableScan: inner_table_lv1\ - \n Projection: inner_table_lv1.a\ - \n Filter: inner_table_lv1.c = Int32(2)\ - \n TableScan: inner_table_lv1"; + \n SubqueryAlias: __exists_sq_1\ + \n Filter: inner_table_lv1.a AND inner_table_lv1.b = Int32(1)\ + \n TableScan: inner_table_lv1\ + \n SubqueryAlias: __in_sq_2\ + \n Projection: inner_table_lv1.a\ + \n Filter: inner_table_lv1.c = Int32(2)\ + \n TableScan: inner_table_lv1"; assert_eq!(expected, format!("{new_plan}")); Ok(()) } From a46545967f8904852cf86e7bd0722cf55e86c985 Mon Sep 17 00:00:00 2001 From: Duong Cong Toai Date: Mon, 19 May 2025 17:58:01 +0200 Subject: [PATCH 23/70] fix: adjust test case expectation --- .../optimizer/src/decorrelate_general.rs | 77 +++++++++---------- 1 file changed, 38 insertions(+), 39 deletions(-) diff --git a/datafusion/optimizer/src/decorrelate_general.rs b/datafusion/optimizer/src/decorrelate_general.rs index 8e8c7f592f7b..c95b6f1280d5 100644 --- a/datafusion/optimizer/src/decorrelate_general.rs +++ b/datafusion/optimizer/src/decorrelate_general.rs @@ -1181,31 +1181,29 @@ impl DependentJoinTracker { for expr in predicate_expr.into_iter() { // exist query may not have any transformed expr // i.e where exists(suquery) => semi join + + let projected_exprs: Vec = if pulled_up_projections.is_empty() { + subquery_input_node.plan.expressions() + } else { + pulled_up_projections + .iter() + .cloned() + .map(strip_outer_reference) + .collect() + }; let (transformed, maybe_join_predicate, maybe_post_join_predicate) = extract_join_metadata_from_subquery( expr, &subquery, - &subquery_input_node.plan.expressions(), + &projected_exprs, &subquery_alias, &outer_relations, )?; if let Some(transformed) = maybe_join_predicate { - join_predicates.push(strip_outer_reference(transformed)); + join_predicates.push(transformed); } if let Some(post_join_expr) = maybe_post_join_predicate { - if post_join_expr - .exists(|e| { - if let Expr::Column(col) = e { - return Ok(col.name == "mark"); - } - return Ok(false); - }) - .unwrap() - { - // only use mark join if required - join_type = JoinType::LeftMark - } post_join_predicates.push(strip_outer_reference(post_join_expr)) } if !transformed { @@ -1853,7 +1851,7 @@ impl SubqueryType { panic!("not reached") } SubqueryType::In => JoinType::LeftSemi, - SubqueryType::Exists => JoinType::LeftSemi, + SubqueryType::Exists => JoinType::LeftMark, // TODO: in duckdb, they have JoinType::Single // where there is only at most one join partner entry on the LEFT SubqueryType::Scalar => JoinType::Left, @@ -2042,19 +2040,19 @@ mod tests { use super::DependentJoinTracker; use arrow::datatypes::DataType as ArrowDataType; #[test] - fn simple_1_level_subquery_in_from_expr() -> Result<()> { + fn simple_in_subquery_inside_from_expr() -> Result<()> { unimplemented!() } #[test] - fn simple_1_level_subquery_in_selection_expr() -> Result<()> { + fn simple_in_subquery_inside_select_expr() -> Result<()> { unimplemented!() } #[test] - fn complex_1_level_decorrelate_2_subqueries_at_the_same_level() -> Result<()> { 
+ fn one_simple_and_one_complex_subqueries_at_the_same_level() -> Result<()> { unimplemented!() } #[test] - fn simple_1_level_decorrelate_2_subqueries_at_the_same_level() -> Result<()> { + fn two_simple_subqueries_at_the_same_level() -> Result<()> { let outer_table = test_table_scan_with_name("outer_table")?; let inner_table_lv1 = test_table_scan_with_name("inner_table_lv1")?; let in_sq_level1 = Arc::new( @@ -2102,7 +2100,7 @@ mod tests { } #[test] - fn complex_1_level_decorrelate_in_subquery_with_count() -> Result<()> { + fn in_subquery_with_count_depth_1() -> Result<()> { let outer_table = test_table_scan_with_name("outer_table")?; let inner_table_lv1 = test_table_scan_with_name("inner_table_lv1")?; let sq_level1 = Arc::new( @@ -2134,9 +2132,7 @@ mod tests { .build()?; let mut index = DependentJoinTracker::new(Arc::new(AliasGenerator::new())); index.build(input1)?; - println!("{:?}", index); let new_plan = index.root_dependent_join_elimination()?; - println!("{}", new_plan); let expected = "\ Filter: outer_table.a > Int32(1)\ \n LeftSemi Join: Filter: outer_table.c = count_a\ @@ -2149,7 +2145,7 @@ mod tests { Ok(()) } #[test] - fn simple_decorrelate_with_exist_subquery_with_dependent_columns() -> Result<()> { + fn simple_exist_subquery_with_dependent_columns() -> Result<()> { let outer_table = test_table_scan_with_name("outer_table")?; let inner_table_lv1 = test_table_scan_with_name("inner_table_lv1")?; let sq_level1 = Arc::new( @@ -2179,9 +2175,10 @@ mod tests { index.build(input1)?; let new_plan = index.root_dependent_join_elimination()?; let expected = "\ - Filter: outer_table.a > Int32(1) AND __exists_sq_1.mark\ + Filter: __exists_sq_1.mark\ \n LeftMark Join: Filter: __exists_sq_1.a = outer_table.a AND outer_table.a > __exists_sq_1.c AND outer_table.b = __exists_sq_1.b\ - \n TableScan: outer_table\ + \n Filter: outer_table.a > Int32(1)\ + \n TableScan: outer_table\ \n SubqueryAlias: __exists_sq_1\ \n Filter: inner_table_lv1.b = Int32(1)\ \n TableScan: inner_table_lv1"; @@ -2189,7 +2186,7 @@ mod tests { Ok(()) } #[test] - fn simple_decorrelate_with_exist_subquery_no_dependent_column() -> Result<()> { + fn simple_exist_subquery_with_no_dependent_columns() -> Result<()> { let outer_table = test_table_scan_with_name("outer_table")?; let inner_table_lv1 = test_table_scan_with_name("inner_table_lv1")?; let sq_level1 = Arc::new( @@ -2206,9 +2203,10 @@ mod tests { index.build(input1)?; let new_plan = index.root_dependent_join_elimination()?; let expected = "\ - Filter: outer_table.a > Int32(1) AND __exists_sq_1.mark\ + Filter: __exists_sq_1.mark\ \n LeftMark Join: Filter: Boolean(true)\ - \n TableScan: outer_table\ + \n Filter: outer_table.a > Int32(1)\ + \n TableScan: outer_table\ \n SubqueryAlias: __exists_sq_1\ \n Projection: inner_table_lv1.b, inner_table_lv1.a\ \n Filter: inner_table_lv1.b = Int32(1)\ @@ -2238,13 +2236,13 @@ mod tests { index.build(input1)?; let new_plan = index.root_dependent_join_elimination()?; let expected = "\ - Filter: outer_table.a > Int32(1)\ - \n LeftSemi Join: Filter: outer_table.c = __in_sq_1.b\ + LeftSemi Join: Filter: outer_table.c = __in_sq_1.b\ + \n Filter: outer_table.a > Int32(1)\ \n TableScan: outer_table\ - \n SubqueryAlias: __in_sq_1\ - \n Projection: inner_table_lv1.b\ - \n Filter: inner_table_lv1.b = Int32(1)\ - \n TableScan: inner_table_lv1"; + \n SubqueryAlias: __in_sq_1\ + \n Projection: inner_table_lv1.b\ + \n Filter: inner_table_lv1.b = Int32(1)\ + \n TableScan: inner_table_lv1"; assert_eq!(expected, format!("{new_plan}")); Ok(()) } @@ 
-2282,13 +2280,14 @@ mod tests { let mut index = DependentJoinTracker::new(Arc::new(AliasGenerator::new())); index.build(input1)?; let new_plan = index.root_dependent_join_elimination()?; + println!("{new_plan}"); let expected = "\ - Filter: outer_table.a > Int32(1)\ - \n LeftSemi Join: Filter: outer_table.c = outer_table.b AS outer_b_alias AND __in_sq_1.a = outer_table.a AND outer_table.a > __in_sq_1.c AND outer_table.b = __in_sq_1.b\ + LeftSemi Join: Filter: outer_table.c = outer_table.b AS outer_b_alias AND __in_sq_1.a = outer_table.a AND outer_table.a > __in_sq_1.c AND outer_table.b = __in_sq_1.b\ + \n Filter: outer_table.a > Int32(1)\ \n TableScan: outer_table\ - \n SubqueryAlias: __in_sq_1\ - \n Filter: inner_table_lv1.b = Int32(1)\ - \n TableScan: inner_table_lv1"; + \n SubqueryAlias: __in_sq_1\ + \n Filter: inner_table_lv1.b = Int32(1)\ + \n TableScan: inner_table_lv1"; assert_eq!(expected, format!("{new_plan}")); Ok(()) } From 479ae64fa514c8a61873e4c78b0eea0064a3d16a Mon Sep 17 00:00:00 2001 From: Duong Cong Toai Date: Sat, 24 May 2025 13:36:28 +0200 Subject: [PATCH 24/70] feat: convert sq to dependent joins --- .../common/src/functional_dependencies.rs | 3 + datafusion/common/src/join_type.rs | 9 + datafusion/expr/src/logical_plan/builder.rs | 1 + .../expr/src/logical_plan/invariants.rs | 1 + datafusion/expr/src/logical_plan/plan.rs | 2 + .../optimizer/src/decorrelate_general.rs | 1832 ++++------------- .../physical-expr/src/equivalence/class.rs | 1 + datafusion/sql/src/unparser/plan.rs | 2 + 8 files changed, 396 insertions(+), 1455 deletions(-) diff --git a/datafusion/common/src/functional_dependencies.rs b/datafusion/common/src/functional_dependencies.rs index c4f2805f8285..2de7db873af1 100644 --- a/datafusion/common/src/functional_dependencies.rs +++ b/datafusion/common/src/functional_dependencies.rs @@ -382,6 +382,9 @@ impl FunctionalDependencies { // All of the functional dependencies are lost in a FULL join: FunctionalDependencies::empty() } + JoinType::LeftDependent => { + unreachable!("LeftDependent should not be reached") + } } } diff --git a/datafusion/common/src/join_type.rs b/datafusion/common/src/join_type.rs index ac81d977b729..7f962c065d7a 100644 --- a/datafusion/common/src/join_type.rs +++ b/datafusion/common/src/join_type.rs @@ -67,6 +67,10 @@ pub enum JoinType { /// /// [1]: http://btw2017.informatik.uni-stuttgart.de/slidesandpapers/F1-10-37/paper_web.pdf LeftMark, + /// TODO: document me more + /// used to represent a virtual join in a complex expr containing subquery(ies), + /// The actual join type depends on the correlated expr + LeftDependent, } impl JoinType { @@ -90,6 +94,9 @@ impl JoinType { JoinType::LeftMark => { unreachable!("LeftMark join type does not support swapping") } + JoinType::LeftDependent => { + unreachable!("Dependent join type does not support swapping") + } } } @@ -121,6 +128,7 @@ impl Display for JoinType { JoinType::LeftAnti => "LeftAnti", JoinType::RightAnti => "RightAnti", JoinType::LeftMark => "LeftMark", + JoinType::LeftDependent => "LeftDependent", }; write!(f, "{join_type}") } @@ -141,6 +149,7 @@ impl FromStr for JoinType { "LEFTANTI" => Ok(JoinType::LeftAnti), "RIGHTANTI" => Ok(JoinType::RightAnti), "LEFTMARK" => Ok(JoinType::LeftMark), + "LEFTDEPENDENT" => Ok(JoinType::LeftDependent), _ => _not_impl_err!("The join type {s} does not exist or is not implemented"), } } diff --git a/datafusion/expr/src/logical_plan/builder.rs b/datafusion/expr/src/logical_plan/builder.rs index d4d45226d354..2fe21387830b 100644 --- 
a/datafusion/expr/src/logical_plan/builder.rs +++ b/datafusion/expr/src/logical_plan/builder.rs @@ -1620,6 +1620,7 @@ pub fn build_join_schema( .map(|(q, f)| (q.cloned(), Arc::clone(f))) .collect() } + JoinType::LeftDependent => todo!(), }; let func_dependencies = left.functional_dependencies().join( right.functional_dependencies(), diff --git a/datafusion/expr/src/logical_plan/invariants.rs b/datafusion/expr/src/logical_plan/invariants.rs index 0c30c9785766..8e5b3156f53e 100644 --- a/datafusion/expr/src/logical_plan/invariants.rs +++ b/datafusion/expr/src/logical_plan/invariants.rs @@ -321,6 +321,7 @@ fn check_inner_plan(inner_plan: &LogicalPlan) -> Result<()> { })?; Ok(()) } + JoinType::LeftDependent => todo!(), }, LogicalPlan::Extension(_) => Ok(()), plan => check_no_outer_references(plan), diff --git a/datafusion/expr/src/logical_plan/plan.rs b/datafusion/expr/src/logical_plan/plan.rs index edf5f1126be9..24d2dda4f5c5 100644 --- a/datafusion/expr/src/logical_plan/plan.rs +++ b/datafusion/expr/src/logical_plan/plan.rs @@ -546,6 +546,7 @@ impl LogicalPlan { join_type, .. }) => match join_type { + JoinType::LeftDependent => todo!(), JoinType::Inner | JoinType::Left | JoinType::Right | JoinType::Full => { if left.schema().fields().is_empty() { right.head_output_expr() @@ -1331,6 +1332,7 @@ impl LogicalPlan { join_type, .. }) => match join_type { + JoinType::LeftDependent => todo!(), JoinType::Inner => Some(left.max_rows()? * right.max_rows()?), JoinType::Left | JoinType::Right | JoinType::Full => { match (left.max_rows()?, right.max_rows()?, join_type) { diff --git a/datafusion/optimizer/src/decorrelate_general.rs b/datafusion/optimizer/src/decorrelate_general.rs index c95b6f1280d5..9f8931322742 100644 --- a/datafusion/optimizer/src/decorrelate_general.rs +++ b/datafusion/optimizer/src/decorrelate_general.rs @@ -17,6 +17,7 @@ //! 
[`GeneralPullUpCorrelatedExpr`] converts correlated subqueries to `Joins` +use std::any::Any; use std::cmp::Ordering; use std::collections::HashSet; use std::fmt; @@ -27,13 +28,15 @@ use crate::analyzer::type_coercion::TypeCoercionRewriter; use crate::decorrelate::UN_MATCHED_ROW_INDICATOR; use crate::{ApplyOrder, OptimizerConfig, OptimizerRule}; +use arrow::datatypes::DataType; use datafusion_common::alias::AliasGenerator; use datafusion_common::tree_node::{ - Transformed, TreeNode, TreeNodeRecursion, TreeNodeVisitor, + Transformed, TreeNode, TreeNodeRecursion, TreeNodeRewriter, TreeNodeVisitor, }; use datafusion_common::{internal_err, Column, HashMap, Result}; use datafusion_expr::expr::{self, Exists, InSubquery}; use datafusion_expr::expr_rewriter::{normalize_col, strip_outer_reference}; +use datafusion_expr::out_ref_col; use datafusion_expr::select_expr::SelectExpr; use datafusion_expr::utils::{ conjunction, disjunction, split_conjunction, split_disjunction, @@ -62,7 +65,7 @@ pub struct DependentJoinTracker { // this is used during traversal only stack: Vec, // track for each column, the nodes/logical plan that reference to its within the tree - accessed_columns: IndexMap>, + all_outer_ref_columns: IndexMap>, alias_generator: Arc, } @@ -73,6 +76,7 @@ struct ColumnAccess { // the node referencing the column node_id: usize, col: Column, + data_type: DataType, } // pub struct GeneralDecorrelation { // index: AlgebraIndex, @@ -262,66 +266,12 @@ struct PulledUpExpr { // at the same level, we need to track which subquery the pulling up is happening for subquery_node_id: usize, } - -struct SimpleDecorrelationResult { - pulled_up_projections: IndexSet, - pulled_up_predicates: Vec, -} -impl SimpleDecorrelationResult { - // fn get_decorrelated_subquery_node_ids(&self) -> Vec { - // self.pulled_up_predicates - // .iter() - // .map(|e| e.subquery_node_id) - // .chain( - // self.pulled_up_projections - // .iter() - // .map(|e| e.subquery_node_id), - // ) - // .unique() - // .collect() - // // node_ids.extend( - // // self.pulled_up_projections - // // .iter() - // // .map(|e| e.subquery_node_id), - // // ); - // // node_ids.into_iter().unique().collect() - // } - // because we don't track which expr was pullled up for which relation to give alias for - fn rewrite_all_pulled_up_expr( - &mut self, - subquery_node_alias_map: &IndexMap, - outer_relations: &[String], - ) -> Result<()> { - let alias_by_subquery_node_id: IndexMap = subquery_node_alias_map - .iter() - .map(|(alias, node)| (node.id, alias)) - .collect(); - for expr in self.pulled_up_predicates.iter_mut() { - let alias = alias_by_subquery_node_id - .get(&expr.subquery_node_id) - .unwrap(); - expr.expr = - replace_col_base_table(expr.expr.clone(), &outer_relations, *alias)?; +fn unwrap_subquery(n: &Node) -> &Subquery { + match n.plan { + LogicalPlan::Subquery(ref sq) => sq, + _ => { + unreachable!() } - let rewritten_projections = self - .pulled_up_projections - .iter() - .map(|expr| { - let alias = alias_by_subquery_node_id - .get(&expr.subquery_node_id) - .unwrap(); - Ok(PulledUpExpr { - subquery_node_id: expr.subquery_node_id, - expr: replace_col_base_table( - expr.expr.clone(), - &outer_relations, - *alias, - )?, - }) - }) - .collect::>>()?; - self.pulled_up_projections = rewritten_projections; - Ok(()) } } @@ -455,1172 +405,6 @@ struct GeneralDecorrelationResult { count_expr_map: HashSet, } -impl DependentJoinTracker { - fn is_linear_operator(&self, plan: &LogicalPlan) -> bool { - match plan { - LogicalPlan::Limit(_) => true, - 
LogicalPlan::TableScan(_) => true, - LogicalPlan::Projection(_) => true, - LogicalPlan::Filter(_) => true, - LogicalPlan::Repartition(_) => true, - _ => false, - } - } - fn is_linear_path(&self, parent: &Node, child: &Node) -> bool { - if !self.is_linear_operator(&child.plan) { - return false; - } - - let mut current_node = child.parent.unwrap(); - - loop { - let child_node = self.nodes.get(¤t_node).unwrap(); - if !self.is_linear_operator(&child_node.plan) { - match child_node.parent { - None => { - unimplemented!("traversing from descedent to top does not meet expected root") - } - Some(new_parent) => { - if new_parent == parent.id { - return true; - } - return false; - } - } - } - match child_node.parent { - None => return true, - Some(new_parent) => { - current_node = new_parent; - } - }; - } - } - fn remove_node(&mut self, parent: &mut Node, node: &mut Node) { - let next_children = node.children.first().unwrap(); - let next_children_node = self.nodes.swap_remove(next_children).unwrap(); - // let next_children_node = self.nodes.get_mut(next_children).unwrap(); - *node = next_children_node; - node.parent = Some(parent.id); - } - - // decorrelate all descendant with simple unnesting - // this function will remove corresponding entry in root_node.access_tracker if applicable - // , so caller can rely on the length of this field to detect if simple decorrelation is enough - // and the decorrelation can stop using "simple method". - // It also does the in-place update to - // - // TODO: this is not yet recursive, but theoreically nested subqueries - // can be decorrelated using simple method as long as they are independent - // with each other - fn try_simple_decorrelate_descendent( - &mut self, - root_node: &mut Node, - child_node: &mut Node, - col_access: &ColumnAccess, - result: &mut SimpleDecorrelationResult, - ) -> Result<()> { - if !self.is_linear_path(root_node, child_node) { - return Ok(()); - } - // offest 0 (root) is dependent join node, will immediately followed by subquery node - let subquery_node_id = col_access.stack[1]; - - match &mut child_node.plan { - LogicalPlan::Projection(proj) => { - // TODO: handle the case select binary_expr(outer_ref_a, outer_ref_b) ??? 
- // if we only see outer_ref_a and decide to pull up the whole expr here - // outer_ref_b is accidentally pulled up - let pulled_up_expr: IndexSet<_> = proj - .expr - .iter() - .filter(|proj_expr| { - proj_expr - .exists(|expr| { - if let Expr::OuterReferenceColumn(_, col) = expr { - root_node.access_tracker.swap_remove(col_access); - return Ok(*col == col_access.col); - } - Ok(false) - }) - .unwrap() - }) - .cloned() - .collect(); - - if !pulled_up_expr.is_empty() { - for expr in pulled_up_expr.iter() { - result.pulled_up_projections.insert(PulledUpExpr { - expr: expr.clone(), - subquery_node_id, - }); - } - // all expr of this node is pulled up, fully remove this node from the tree - if proj.expr.len() == pulled_up_expr.len() { - self.remove_node(root_node, child_node); - return Ok(()); - } - - let new_proj = proj - .expr - .iter() - .filter(|expr| !pulled_up_expr.contains(*expr)) - .cloned() - .collect(); - proj.expr = new_proj; - } - // TODO: try_decorrelate for each of the child - } - LogicalPlan::Filter(filter) => { - // let accessed_from_child = &child_node.access_tracker; - let subquery_filter_exprs: Vec = - split_conjunction(&filter.predicate) - .into_iter() - .cloned() - .collect(); - - let (pulled_up, kept): (Vec<_>, Vec<_>) = subquery_filter_exprs - .iter() - .cloned() - // NOTE: if later on we decide to support nested subquery inside this function - // (i.e multiple subqueries exist in the stack) - // the call to e.contains_outer must be aware of which subquery it is checking for:w - .partition(|e| e.contains_outer() && can_pull_up(e)); - - // only remove the access tracker if none of the kept expr contains reference to the column - // i.e some of the remaining expr still reference to the column and not pullable - // For example where outer.col_a=1 and outer.col_a=(some nested subqueries) - // in this case outer.col_a=1 is pull up, but the access tracker must remain - // so later on we can tell "simple approach" is not enough, and continue with - // the "general approach". 
- let can_pull_up = kept.iter().all(|e| { - !e.exists(|e| { - if let Expr::Column(col) = e { - return Ok(*col == col_access.col); - } - Ok(false) - }) - .unwrap() - }); - if !can_pull_up { - return Ok(()); - } - root_node.access_tracker.swap_remove(col_access); - result - .pulled_up_predicates - .extend(pulled_up.iter().map(|e| PulledUpExpr { - expr: e.clone(), - subquery_node_id, - })); - if kept.is_empty() { - self.remove_node(root_node, child_node); - return Ok(()); - } - filter.predicate = conjunction(kept).unwrap(); - } - - // TODO: nested subqueries can also be linear with each other - // i.e select expr, (subquery1) where expr = subquery2 - // LogicalPlan::Subquery(sq) => { - // let descendent_id = child_node.children.get(0).unwrap(); - // let mut descendent_node = self.nodes.get(descendent_id).unwrap().clone(); - // self.try_simple_unnest_descendent( - // root_node, - // &mut descendent_node, - // result, - // )?; - // self.nodes.insert(*descendent_id, descendent_node); - // } - unsupported => { - unimplemented!( - "simple unnest is missing for this operator {}", - unsupported - ) - } - }; - - Ok(()) - } - - fn general_decorrelate( - &mut self, - node: &mut Node, - unnesting: &mut Unnesting, - outer_refs_from_parent: &mut IndexSet, - ) -> Result<()> { - if node.is_dependent_join_node { - unimplemented!("recursive unnest not implemented yet") - } - - match &mut node.plan { - LogicalPlan::Subquery(sq) => { - let next_node = node.children.first().unwrap(); - let mut only_child = self.nodes.swap_remove(next_node).unwrap(); - self.general_decorrelate( - &mut only_child, - unnesting, - outer_refs_from_parent, - )?; - unnesting.decorrelated_subquery = Some(sq.clone()); - *node = only_child; - return Ok(()); - } - LogicalPlan::Aggregate(agg) => { - let is_static = agg.group_expr.is_empty(); // TODO: grouping set also needs to check is_static - let next_node = node.children.first().unwrap(); - let mut only_child = self.nodes.swap_remove(next_node).unwrap(); - // keep this for later projection - let mut original_expr = agg.aggr_expr.clone(); - original_expr.extend_from_slice(&agg.group_expr); - - self.general_decorrelate( - &mut only_child, - unnesting, - outer_refs_from_parent, - )?; - agg.input = Arc::new(only_child.plan.clone()); - self.nodes.insert(*next_node, only_child); - - Self::rewrite_columns(agg.group_expr.iter_mut(), unnesting)?; - for col in unnesting.pulled_up_columns.iter() { - let replaced_col = unnesting.get_replaced_col(col); - agg.group_expr.push(Expr::Column(replaced_col.clone())); - } - for agg in agg.aggr_expr.iter() { - if contains_count_expr(agg) { - unnesting.count_exprs_detected.insert(agg.clone()); - } - } - - if is_static { - if !unnesting.count_exprs_detected.is_empty() - & unnesting.need_handle_count_bug - { - let un_matched_row = lit(true).alias(UN_MATCHED_ROW_INDICATOR); - agg.group_expr.push(un_matched_row); - } - // let right = LogicalPlanBuilder::new(node.plan.clone()); - // the evaluation of - // let mut post_join_projection = vec![]; - let alias = - self.alias_generator.next(&unnesting.subquery_type.prefix()); - - let join_condition = - unnesting.pulled_up_predicates.iter().filter_map(|e| { - let stripped_outer = strip_outer_reference(e.clone()); - if contains_count_expr(&stripped_outer) { - unimplemented!("handle having count(*) predicate pull up") - // post_join_predicates.push(stripped_outer); - // return None; - } - match &stripped_outer { - Expr::Column(col) => { - println!("{:?}", col); - } - _ => {} - } - Some(stripped_outer) - }); - - let right = 
LogicalPlanBuilder::new(agg.input.deref().clone()) - .aggregate(agg.group_expr.clone(), agg.aggr_expr.clone())? - .alias(alias.clone())? - .build()?; - let mut new_plan = - LogicalPlanBuilder::new(unnesting.info.domain.clone()) - .join_detailed( - right, - JoinType::Left, - (Vec::::new(), Vec::::new()), - conjunction(join_condition), - true, - )?; - for expr in original_expr.iter_mut() { - if contains_count_expr(expr) { - let new_expr = Expr::Case(expr::Case { - expr: None, - when_then_expr: vec![( - Box::new(Expr::IsNull(Box::new(Expr::Column( - Column::new_unqualified(UN_MATCHED_ROW_INDICATOR), - )))), - Box::new(lit(0)), - )], - else_expr: Some(Box::new(Expr::Column( - Column::new_unqualified( - expr.schema_name().to_string(), - ), - ))), - }); - let mut expr_rewrite = TypeCoercionRewriter { - schema: new_plan.schema(), - }; - *expr = new_expr.rewrite(&mut expr_rewrite)?.data; - } - - // *expr = Expr::Column(create_col_from_scalar_expr( - // expr, - // alias.clone(), - // )?); - } - new_plan = new_plan.project(original_expr)?; - - node.plan = new_plan.build()?; - - println!("{}", node.plan); - return Ok(()); - // self.remove_node(parent, node); - - // 01)Projection: t1.t1_id, CASE WHEN __scalar_sq_1.__always_true IS NULL THEN Int64(0) ELSE __scalar_sq_1.count(*) END AS count(*) - // TODO: how domain projection work - // left = select distinct domain - // right = new group by - // if there exists count in the group by, the projection set should be something like - // - // 01)Projection: t1.t1_id, CASE WHEN __scalar_sq_1.__always_true IS NULL THEN Int64(0) ELSE __scalar_sq_1.count(*) END AS count(*) - // 02)--Left Join: t1.t1_int = __scalar_sq_1.t2_int - } else { - unimplemented!("non static aggregation sq decorrelation not implemented, i.e exists sq with count") - } - } - LogicalPlan::Filter(filter) => { - let conjuctives: Vec = split_conjunction(&filter.predicate) - .into_iter() - .cloned() - .collect(); - let mut remained_expr = vec![]; - // TODO: the paper mention there are 2 approaches to remove these dependent predicate - // - substitute the outer ref columns and push them to the parent node (i.e add them to aggregation node) - // - perform a join with domain directly here - // for now we only implement with the approach substituting - - let mut pulled_up_columns = IndexSet::new(); - for expr in conjuctives.iter() { - if !expr.contains_outer() { - remained_expr.push(expr.clone()); - continue; - } - // extract all columns mentioned in this expr - // and push them up the dependent join - - unnesting.pulled_up_predicates.push(expr.clone()); - expr.clone().map_children(|e| { - if let Expr::Column(ref col) = e { - pulled_up_columns.insert(col.clone()); - } - Ok(Transformed::no(e)) - })?; - } - filter.predicate = match conjunction(remained_expr) { - Some(expr) => expr, - None => lit(true), - }; - unnesting.pulled_up_columns.extend(pulled_up_columns); - outer_refs_from_parent.retain(|ac| ac.node_id != node.id); - if !outer_refs_from_parent.is_empty() { - let next_node = node.children.first().unwrap(); - let mut only_child = self.nodes.swap_remove(next_node).unwrap(); - self.general_decorrelate( - &mut only_child, - unnesting, - outer_refs_from_parent, - )?; - self.nodes.insert(*next_node, only_child); - } - // TODO: add equivalences from select.predicate to info.cclasses - Self::rewrite_columns(vec![&mut filter.predicate].into_iter(), unnesting); - return Ok(()); - } - LogicalPlan::Projection(proj) => { - let next_node = node.children.first().unwrap(); - let mut only_child = 
self.nodes.swap_remove(next_node).unwrap(); - // TODO: if the children of this node was added with some extra column (i.e) - // aggregation + group by dependent_column - // the projection exprs must also include these new expr - self.general_decorrelate( - &mut only_child, - unnesting, - outer_refs_from_parent, - )?; - - self.nodes.insert(*next_node, only_child); - proj.expr.extend( - unnesting - .pulled_up_columns - .iter() - .map(|c| Expr::Column(c.clone())), - ); - Self::rewrite_columns(proj.expr.iter_mut(), unnesting); - return Ok(()); - } - _ => { - unimplemented!() - } - }; - // if unnesting.info.parent.is_some() { - // not_impl_err!("impl me") - // // TODO - // } - // // info = Un - // let node = self.nodes.get(node_id).unwrap(); - // match node.plan { - // LogicalPlan::Aggregate(aggr) => {} - // _ => {} - // } - // Ok(()) - } - fn right_owned(&mut self, node: &Node) -> Node { - assert_eq!(2, node.children.len()); - // during the building of the tree, the subquery (right node) is always traversed first - let node_id = node.children.first().unwrap(); - return self.nodes.swap_remove(node_id).unwrap(); - } - fn left_owned(&mut self, node: &Node) -> Node { - assert_eq!(2, node.children.len()); - // during the building of the tree, the subquery (right node) is always traversed first - let node_id = node.children.last().unwrap(); - return self.nodes.swap_remove(node_id).unwrap(); - } - fn root_dependent_join_elimination(&mut self) -> Result { - let root = self.root.unwrap(); - let node = self.nodes.get(&root).unwrap(); - // TODO: need to store the first dependent join node - assert!( - node.is_dependent_join_node, - "need to handle the case root node is not dependent join node" - ); - - let unnesting_info = UnnestingInfo { - parent: None, - domain: node.plan.clone(), // dummy - }; - - self.dependent_join_elimination(node.id, &unnesting_info, &mut IndexSet::new()) - } - - fn column_accesses(&self, node_id: usize) -> Vec<&ColumnAccess> { - let node = self.nodes.get(&node_id).unwrap(); - node.access_tracker.iter().collect() - } - fn get_children_subquery_ids(&self, node: &Node) -> Vec { - return node.children[..node.children.len() - 1].to_owned(); - } - - fn get_subquery_info( - &self, - parent: &Node, - // because one dependent join node can have multiple subquery at a time - sq_offset: usize, - ) -> Result<(LogicalPlan, Subquery, SubqueryType)> { - let subquery = parent.children.get(sq_offset).unwrap(); - let sq_node = self.nodes.get(subquery).unwrap(); - assert!(sq_node.is_subquery_node); - let query = sq_node.children.first().unwrap(); - let target_node = self.nodes.get(query).unwrap(); - // let op = .clone(); - if let LogicalPlan::Subquery(subquery) = sq_node.plan.clone() { - Ok((target_node.plan.clone(), subquery, sq_node.subquery_type)) - } else { - internal_err!( - "object construction error: subquery.plan is not with type Subquery" - ) - } - } - - // Rewrite from - // TopNodeParent - // | - // (current_top_node) - // |-SubqueryNode -----> This was decorelated - // | |- (subquery_input_node) - // |-SubqueryNode2 -----> This is not yet decorrelated - // |-SomeTableScan - // - // Into - // TopNodeParent - // | - // NewTopNode <-------- This was added - // | - // |----(current_top_node) - // | |-SubqueryNode2 - // | |-SomeTableScan - // | - // |----(subquery_input_node) - fn create_new_join_node_on_top<'a>( - &'a mut self, - subquery_alias: String, - join_type: JoinType, - current_top_node: &mut Node, - subquery_input_node: Node, - join_predicates: Vec, - post_join_predicates: Option, - 
) -> Result { - self.nodes - .insert(subquery_input_node.id, subquery_input_node.clone()); - // Build the join node - let mut right = LogicalPlanBuilder::new(subquery_input_node.plan.clone()) - .alias(subquery_alias)? - .build()?; - let alias_node = self.insert_node_and_links( - right.clone(), - 0, - None, - vec![subquery_input_node.id], - ); - let right_node_id = alias_node.id; - // the left input does not matter, because later on the rewritting will happen using the pointers - // from top node, following the children using Node.chilren field - let mut builder = LogicalPlanBuilder::empty(false); - - builder = if join_predicates.is_empty() { - builder.join_on(right, join_type, vec![lit(true)])? - } else { - builder.join_on( - right, - // TODO: join type based on filter condition - join_type, - join_predicates, - )? - }; - - let join_node = builder.build()?; - - let upper_most_parent = current_top_node.parent.clone(); - let mut new_node = self.insert_node_and_links( - join_node, - current_top_node.id, - upper_most_parent, - vec![current_top_node.id, right_node_id], - ); - current_top_node.parent = Some(new_node.id); - - let mut new_node_id = new_node.id; - if let Some(expr) = post_join_predicates { - let new_plan = LogicalPlanBuilder::new(new_node.plan.clone()) - .filter(expr)? - .build()?; - let new_node = self.insert_node_and_links( - new_plan, - new_node_id, - upper_most_parent, - vec![new_node_id], - ); - new_node_id = new_node.id; - } - - self.root = Some(new_node_id); - - Ok(self.nodes.swap_remove(&new_node_id).unwrap()) - } - - // insert a new node, if any link of parent, children is mentioned - // also update the relationship in these remote nodes - fn insert_node_and_links<'a>( - &'a mut self, - plan: LogicalPlan, - // which node id in the parent should be replaced by this new node - swapped_node_id: usize, - parent: Option, - children: Vec, - ) -> &'a mut Node { - self.current_id = self.current_id + 1; - let node_id = self.current_id; - - // update parent - if let Some(parent_id) = parent { - for child_id in self.nodes.get_mut(&parent_id).unwrap().children.iter_mut() { - if *child_id == swapped_node_id { - *child_id = node_id; - } - } - } - for child_id in children.iter() { - if let Some(node) = self.nodes.get_mut(child_id) { - node.parent = Some(node_id); - } - } - - let new_node = Node { - id: node_id, - plan, - parent, - is_subquery_node: false, - is_dependent_join_node: false, - children, - access_tracker: IndexSet::new(), - subquery_type: SubqueryType::None, - correlated_relations: IndexSet::new(), - }; - self.nodes.insert(node_id, new_node); - self.nodes.get_mut(&node_id).unwrap() - } - - // this function is aware that multiple subqueries may exist inside the filter predicate - // and it tries it best to decorrelate all possible exprs, while leave the un-correlatable - // expr untouched - // - // Example of such expression - // `select * from outer_table where exists(select * from inner_table where ...) & col_b < complex_subquery` - // the relationship tree looks like this - // [0] some parent node - // | - // -[1]dependent_join_node (filter exists(select * from inner_table where ...) 
& col_b < complex_subquery) - // | - // |- [2]simple_subquery - // |- [3]complex_subquery - // |- [4]outer_table scan - // After decorrelation, the relationship tree may be translated using 2 approaches - // Approach 1: Replace the left side of the join using the new input - // [0] some parent node - // | - // -[1]dependent_join_node (filter col_b < complex_subquery) - // | - // |- [2]REMOVED - // |- [3]complex_subquery - // |- [4]markjoin <-------- This was modified - // |-outer_table scan - // |-inner_table scan - // - // Approach 2: Keep everything except for the decorrelated expressions, - // and add a new join above the original dependent join - // [0] some parent node - // | - // -[NEW_NODE_ID] markjoin <----------------- This was added - // | - // |-inner_table scan - // |-[1]dependent_join_node (filter col_b < complex_subquery) - // | - // |- [2]REMOVED - // |- [3]complex_subquery - // |- [4]outer_table scan - // The following uses approach 2 - // - // If decorrelation happen, this function will returns a new Node object that is supposed to be the new root of the tree - fn build_join_from_simple_decorrelation_result_filter( - &mut self, - mut dependent_join_node: Node, - outer_relations: &[String], - ret: &mut SimpleDecorrelationResult, - mut filter: Filter, - ) -> Result<(Node)> { - let still_correlated_sq_ids: Vec = dependent_join_node - .access_tracker - .iter() - .map(|ac| ac.stack[1]) - .unique() - .collect(); - - let decorrelated_sq_ids = self - .get_children_subquery_ids(&dependent_join_node) - .into_iter() - .filter(|n| !still_correlated_sq_ids.contains(n)); - let subquery_node_alias_map: IndexMap = decorrelated_sq_ids - .map(|id| { - dependent_join_node - .children - .retain(|current_children| *current_children != id); - - let subquery_node = self.nodes.swap_remove(&id).unwrap(); - let subquery_alias = self - .alias_generator - .next(&subquery_node.subquery_type.prefix()); - (subquery_alias, subquery_node) - }) - .collect(); - - ret.rewrite_all_pulled_up_expr(&subquery_node_alias_map, &outer_relations)?; - let mut pullup_projection_by_sq_id: IndexMap> = ret - .pulled_up_projections - .iter() - .fold(IndexMap::>::new(), |mut acc, e| { - acc.entry(e.subquery_node_id) - .or_default() - .push(e.expr.clone()); - acc - }); - let mut pullup_predicate_by_sq_id: IndexMap> = ret - .pulled_up_predicates - .iter() - .fold(IndexMap::>::new(), |mut acc, e| { - acc.entry(e.subquery_node_id) - .or_default() - .push(e.expr.clone()); - acc - }); - - let dependent_join_node_id = dependent_join_node.id; - let mut top_node = dependent_join_node; - - for (subquery_alias, subquery_node) in subquery_node_alias_map { - let subquery_input_node = self - .nodes - .swap_remove(subquery_node.children.first().unwrap()) - .unwrap(); - // let subquery_input_plan = subquery_input_node.plan.clone(); - let mut join_predicates = vec![]; - let mut post_join_predicates = vec![]; // this loop heavily assume that all subqueries belong to the same `dependent_join_node` - let mut remained_predicates = vec![]; - let sq_type = subquery_node.subquery_type; - let subquery = if let LogicalPlan::Subquery(subquery) = &subquery_node.plan { - Ok(subquery) - } else { - internal_err!( - "object construction error: subquery.plan is not with type Subquery" - ) - }?; - let mut join_type = sq_type.default_join_type(); - - let predicate_expr = split_conjunction(&filter.predicate); - - let pulled_up_projections = pullup_projection_by_sq_id - .swap_remove(&subquery_node.id) - .unwrap_or(vec![]); - let pulled_up_predicates = 
pullup_predicate_by_sq_id - .swap_remove(&subquery_node.id) - .unwrap_or(vec![]); - - for expr in predicate_expr.into_iter() { - // exist query may not have any transformed expr - // i.e where exists(suquery) => semi join - - let projected_exprs: Vec = if pulled_up_projections.is_empty() { - subquery_input_node.plan.expressions() - } else { - pulled_up_projections - .iter() - .cloned() - .map(strip_outer_reference) - .collect() - }; - let (transformed, maybe_join_predicate, maybe_post_join_predicate) = - extract_join_metadata_from_subquery( - expr, - &subquery, - &projected_exprs, - &subquery_alias, - &outer_relations, - )?; - - if let Some(transformed) = maybe_join_predicate { - join_predicates.push(transformed); - } - if let Some(post_join_expr) = maybe_post_join_predicate { - post_join_predicates.push(strip_outer_reference(post_join_expr)) - } - if !transformed { - remained_predicates.push(expr.clone()); - } - } - - join_predicates - .extend(pulled_up_predicates.into_iter().map(strip_outer_reference)); - filter.predicate = conjunction(remained_predicates).unwrap(); - - // building new join node - // let left = top_node.plan.clone(); - let new_top_node = self.create_new_join_node_on_top( - subquery_alias, - join_type, - &mut top_node, - subquery_input_node, - join_predicates, - conjunction(post_join_predicates), - // TODO: post join projection - )?; - self.nodes.insert(top_node.id, top_node); - top_node = new_top_node; - } - self.nodes.insert(top_node.id, top_node); - let mut dependent_join_node = - self.nodes.swap_remove(&dependent_join_node_id).unwrap(); - dependent_join_node.plan = LogicalPlan::Filter((filter)); - - Ok(dependent_join_node) - } - - fn rewrite_node(&mut self, node_id: usize) -> Result { - let mut node = self.nodes.swap_remove(&node_id).unwrap(); - if node.is_subquery_node { - println!("{} {}", node.id, node.plan); - } - assert!( - !node.is_subquery_node, - "calling on rewrite_node while still exists subquery in the tree" - ); - if node.children.is_empty() { - return Ok(node.plan); - } - let new_children = node - .children - .iter() - .map(|c| self.rewrite_node(*c)) - .collect::>>()?; - node.plan - .with_new_exprs(node.plan.expressions(), new_children) - } - - fn rewrite_from_root(&mut self) -> Result { - self.rewrite_node(self.root.unwrap()) - } - - fn build_join_from_simple_decorrelation_result( - &mut self, - mut dependent_join_node: Node, - ret: &mut SimpleDecorrelationResult, - ) -> Result { - let outer_relations: Vec = dependent_join_node - .correlated_relations - .iter() - .cloned() - .collect(); - - match dependent_join_node.plan.clone() { - LogicalPlan::Filter(filter) => self - .build_join_from_simple_decorrelation_result_filter( - dependent_join_node, - &outer_relations, - ret, - filter, - ), - _ => { - unimplemented!() - } - } - } - - fn build_domain(&self, node: &Node, left: &Node) -> Result { - let unique_outer_refs: Vec = node - .access_tracker - .iter() - .map(|c| c.col.clone()) - .unique() - .collect(); - - // TODO: handle this correctly. - // the direct left child of root is not always the table scan node - // and there are many more table providing logical plan - let initial_domain = LogicalPlanBuilder::new(left.plan.clone()) - .aggregate( - unique_outer_refs - .iter() - .map(|col| Expr::Column(col.clone())), - Vec::::new(), - )? 
- .build()?; - return Ok(initial_domain); - } - - fn dependent_join_elimination( - &mut self, - dependent_join_node_id: usize, - unnesting: &UnnestingInfo, - outer_refs_from_parent: &mut IndexSet, - ) -> Result { - let parent = unnesting.parent.clone(); - let mut dependent_join_node = - self.nodes.swap_remove(&dependent_join_node_id).unwrap(); - - assert!(dependent_join_node.is_dependent_join_node); - - let mut simple_unnesting = SimpleDecorrelationResult { - pulled_up_predicates: vec![], - pulled_up_projections: IndexSet::new(), - }; - - dependent_join_node = - self.simple_decorrelation(dependent_join_node, &mut simple_unnesting)?; - if dependent_join_node.access_tracker.is_empty() { - if parent.is_some() { - // for each projection of outer column moved up by simple_decorrelation - // replace them with the expr store inside parent.replaces - unimplemented!("simple dependent join not implemented for the case of recursive subquery"); - self.general_decorrelate( - &mut dependent_join_node, - &mut parent.unwrap(), - outer_refs_from_parent, - )?; - return Ok(dependent_join_node.plan.clone()); - } - self.nodes - .insert(dependent_join_node.id, dependent_join_node); - return self.rewrite_from_root(); - } else { - // TODO: some of the expr was removed and expect to be pulled up in a best effort fashion - // (i.e partially decorrelate) - } - if self.get_children_subquery_ids(&dependent_join_node).len() > 1 { - unimplemented!( - "general decorrelation for multiple subqueries in the same node" - ) - } - - // for children_offset in self.get_children_subquery_ids(&dependent_join_node) { - let (original_subquery, _, subquery_type) = - self.get_subquery_info(&dependent_join_node, 0)?; - // let mut join = self.new_dependent_join(&root_node); - // TODO: handle the case where one dependent join node contains multiple subqueries - let mut left = self.left_owned(&dependent_join_node); - let mut right = self.right_owned(&dependent_join_node); - if parent.is_some() { - unimplemented!(""); - // i.e exists (where inner.col_a = outer_col.b and other_nested_subquery...) 
- - let mut outer_ref_from_left = IndexSet::new(); - // let left = join.left.clone(); - for col_from_parent in outer_refs_from_parent.iter() { - if left - .plan - .all_out_ref_exprs() - .contains(&Expr::Column(col_from_parent.col)) - { - outer_ref_from_left.insert(col_from_parent.clone()); - } - } - let mut parent_unnesting = parent.clone().unwrap(); - self.general_decorrelate( - &mut left, - &mut parent_unnesting, - &mut outer_ref_from_left, - )?; - // join.replace_left(new_left, &parent_unnesting.replaces); - - // TODO: after imple simple_decorrelation, rewrite the projection pushed up column as well - } - let domain = match parent { - None => self.build_domain(&dependent_join_node, &left)?, - Some(info) => { - unimplemented!() - } - }; - - let new_unnesting_info = UnnestingInfo { - parent: parent.clone(), - domain, - }; - let mut unnesting = Unnesting { - original_subquery, - info: Arc::new(new_unnesting_info.clone()), - equivalences: UnionFind { - parent: IndexMap::new(), - rank: IndexMap::new(), - }, - replaces: IndexMap::new(), - pulled_up_columns: vec![], - pulled_up_predicates: vec![], - count_exprs_detected: IndexSet::new(), - need_handle_count_bug: true, // TODO - subquery_type, - decorrelated_subquery: None, - }; - let mut accesses: IndexSet = - dependent_join_node.access_tracker.clone(); - // .iter() - // .map(|a| a.col.clone()) - // .collect(); - if parent.is_some() { - for col_access in outer_refs_from_parent.iter() { - if right - .plan - .all_out_ref_exprs() - .contains(&Expr::Column(col_access.col.clone())) - { - accesses.insert(col_access.clone()); - } - } - // add equivalences from join.condition to unnest.cclasses - } - - //TODO: add equivalences from join.condition to unnest.cclasses - self.general_decorrelate(&mut right, &mut unnesting, &mut accesses)?; - let decorrelated_plan = self.build_join_from_general_unnesting_info( - &mut dependent_join_node, - &mut left, - &mut right, - unnesting, - )?; - return Ok(decorrelated_plan); - // } - - // self.nodes.insert(left.id, left); - // self.nodes.insert(right.id, right); - // self.nodes.insert(node, root_node); - - unimplemented!("implement relacing right node"); - // join.replace_right(new_right, &new_unnesting_info, &unnesting.replaces); - // for acc in new_unnesting_info.outer_refs{ - // join.join_conditions.append(other); - // } - } - - fn build_join_from_general_unnesting_info( - &self, - dependent_join_node: &mut Node, - left_node: &mut Node, - decorrelated_right_node: &mut Node, - mut unnesting: Unnesting, - ) -> Result { - let subquery = unnesting.decorrelated_subquery.take().unwrap(); - let decorrelated_right = decorrelated_right_node.plan.clone(); - let subquery_type = unnesting.subquery_type; - - let alias = self.alias_generator.next(&subquery_type.prefix()); - let outer_relations: Vec = dependent_join_node - .correlated_relations - .iter() - .cloned() - .collect(); - - unnesting.rewrite_all_pulled_up_expr(&alias, &outer_relations)?; - // TODO: do this on left instead of dependent_join_node directly, because with recursive - // the left side can also be rewritten - match dependent_join_node.plan { - LogicalPlan::Filter(ref mut filter) => { - let exprs = split_conjunction(&filter.predicate); - let mut join_exprs = vec![]; - let mut kept_predicates = vec![]; - let right_expr: Vec<_> = decorrelated_right_node - .plan - .schema() - .columns() - .iter() - .map(|c| Expr::Column(c.clone())) - .collect(); - let mut join_type = subquery_type.default_join_type(); - for expr in exprs.into_iter() { - // exist query may not 
have any transformed expr - // i.e where exists(suquery) => semi join - let (transformed, maybe_transformed_expr, maybe_post_join_expr) = - extract_join_metadata_from_subquery( - expr, - &subquery, - &right_expr, - &alias, - &outer_relations, - )?; - - if let Some(transformed) = maybe_transformed_expr { - join_exprs.push(transformed) - } - if let Some(post_join_expr) = maybe_post_join_expr { - if post_join_expr - .exists(|e| { - if let Expr::Column(col) = e { - return Ok(col.name == "mark"); - } - return Ok(false); - }) - .unwrap() - { - // only use mark join if required - join_type = JoinType::LeftMark - } - kept_predicates.push(post_join_expr) - } - if !transformed { - kept_predicates.push(expr.clone()) - } - } - - // TODO: some predicate is join predicate, some is just filter - // kept_predicates.extend(new_predicates); - // filter.predicate = conjunction(kept_predicates).unwrap(); - // left - let mut builder = LogicalPlanBuilder::new(filter.input.deref().clone()); - - builder = if join_exprs.is_empty() { - builder.join_on(decorrelated_right, join_type, vec![lit(true)])? - } else { - builder.join_on( - decorrelated_right, - // TODO: join type based on filter condition - join_type, - join_exprs, - )? - }; - - if kept_predicates.len() > 0 { - builder = builder.filter(conjunction(kept_predicates).unwrap())? - } - builder.build() - } - _ => { - unimplemented!() - } - } - } - fn rewrite_columns<'a>( - exprs: impl Iterator, - unnesting: &Unnesting, - ) -> Result<()> { - for expr in exprs { - *expr = expr - .clone() - .transform(|e| { - match &e { - Expr::Column(col) => { - if let Some(replaced_by) = unnesting.replaces.get(col) { - return Ok(Transformed::yes(Expr::Column( - replaced_by.clone(), - ))); - } - } - Expr::OuterReferenceColumn(_, col) => { - if let Some(replaced_by) = unnesting.replaces.get(col) { - // TODO: no sure if we should use column or outer ref column here - return Ok(Transformed::yes(Expr::Column( - replaced_by.clone(), - ))); - } - } - _ => {} - }; - Ok(Transformed::no(e)) - })? - .data; - } - Ok(()) - } - fn get_node_uncheck(&self, node_id: &usize) -> Node { - self.nodes.get(node_id).unwrap().clone() - } - - // Decorrelate the current node using `simple` approach. - // It will consume the node and returns a new node where the decorrelatoin should continue - // using `general` approach, should `simple` approach is not sufficient. 
- // Most of the time the same Node is returned, avoid using &mut Node because of borrow checker - // Beware that after calling this function, the root node may be changed (as new join node being added to the top) - fn simple_decorrelation( - &mut self, - mut node: Node, - simple_unnesting: &mut SimpleDecorrelationResult, - ) -> Result { - let node_id = node.id; - // the iteration should happen with the order of bottom up, so any node pull up won't - // affect its children (by accident) - let accesses_bottom_up = node.access_tracker.clone().sorted_by(|a, b| { - if a.node_id < b.node_id { - Ordering::Greater - } else { - Ordering::Less - } - }); - - for col_access in accesses_bottom_up { - // create two copy because of - // let mut descendent = self.get_node_uncheck(&col_access.node_id); - let mut descendent = self.nodes.swap_remove(&col_access.node_id).unwrap(); - self.try_simple_decorrelate_descendent( - &mut node, - &mut descendent, - &col_access, - simple_unnesting, - )?; - // TODO: find a nicer way to do in-place update - self.nodes.insert(col_access.node_id, descendent); - } - self.build_join_from_simple_decorrelation_result(node, simple_unnesting) - } -} - fn contains_count_expr( expr: &Expr, // schema: &DFSchemaRef, @@ -1710,9 +494,11 @@ impl DependentJoinTracker { let accessed_by_string = op .access_tracker .iter() - .map(|c| c.debug()) + .map(|(_, ac)| ac.clone()) + .flatten() + .map(|ac| ac.debug()) .collect::>() - .join(", "); + .join(","); // Now print the Operator details writeln!(f, "accessed_by: {}", accessed_by_string,)?; let len = op.children.len(); @@ -1727,17 +513,37 @@ impl DependentJoinTracker { Ok(()) } - fn lca_from_stack(a: &[usize], b: &[usize]) -> usize { + // lowest common ancestor from stack + // given a tree of + // n1 + // | + // n2 filter where outer.column = exists(subquery) + // ---------------------- + // | \ + // | n5: subquery + // | | + // n3 scan table outer n6 filter outer.column=inner.column + // | + // n7 scan table inner + // this function is called with 2 args a:[1,2,3] and [1,2,5,6,7] + // it then returns the id of the dependent join node (2) + // and the id of the subquery node (5) + fn dependent_join_and_subquery_node_ids( + stack_with_table_provider: &[usize], + stack_with_subquery: &[usize], + ) -> (usize, usize) { let mut lca = None; - let min_len = a.len().min(b.len()); + let min_len = stack_with_table_provider + .len() + .min(stack_with_subquery.len()); for i in 0..min_len { - let ai = a[i]; - let bi = b[i]; + let ai = stack_with_subquery[i]; + let bi = stack_with_table_provider[i]; if ai == bi { - lca = Some(ai); + lca = Some((ai, stack_with_subquery[ai + 1])); } else { break; } @@ -1755,45 +561,51 @@ impl DependentJoinTracker { col: &Column, tbl_name: &str, ) { - if let Some(accesses) = self.accessed_columns.get(col) { + if let Some(accesses) = self.all_outer_ref_columns.get(col) { for access in accesses.iter() { let mut cur_stack = self.stack.clone(); cur_stack.push(child_id); // this is a dependent join node - let lca_node = Self::lca_from_stack(&cur_stack, &access.stack); - let node = self.nodes.get_mut(&lca_node).unwrap(); - node.access_tracker.insert(ColumnAccess { + let (dependent_join_node_id, subquery_node_id) = + Self::dependent_join_and_subquery_node_ids(&cur_stack, &access.stack); + let node = self.nodes.get_mut(&dependent_join_node_id).unwrap(); + let accesses = node.access_tracker.entry(subquery_node_id).or_default(); + accesses.push(ColumnAccess { col: col.clone(), node_id: access.node_id, stack: access.stack.clone(), + 
data_type: access.data_type.clone(), }); node.correlated_relations.insert(tbl_name.to_string()); } } } - fn mark_column_access(&mut self, child_id: usize, col: &Column) { + fn mark_outer_column_access( + &mut self, + child_id: usize, + data_type: &DataType, + col: &Column, + ) { // iter from bottom to top, the goal is to mark the dependent node // the current child's access let mut stack = self.stack.clone(); stack.push(child_id); - self.accessed_columns + self.all_outer_ref_columns .entry(col.clone()) .or_default() .push(ColumnAccess { stack, node_id: child_id, col: col.clone(), + data_type: data_type.clone(), }); } - fn build(&mut self, plan: LogicalPlan) -> Result<()> { - // let mut index = AlgebraIndex::default(); - plan.visit_with_subqueries(self)?; - Ok(()) - } - fn create_child_relationship(&mut self, parent: usize, child: usize) { - let operator = self.nodes.get_mut(&parent).unwrap(); - operator.children.push(child); + fn rewrite_subqueries_into_dependent_joins( + &mut self, + plan: LogicalPlan, + ) -> Result> { + plan.rewrite_with_subqueries(self) } } @@ -1805,7 +617,7 @@ impl DependentJoinTracker { current_id: 0, nodes: IndexMap::new(), stack: vec![], - accessed_columns: IndexMap::new(), + all_outer_ref_columns: IndexMap::new(), }; } } @@ -1821,11 +633,11 @@ struct Node { plan: LogicalPlan, parent: Option, - // This field is only set if the node is dependent join node + // This field is only meaningful if the node is dependent join node // it track which descendent nodes still accessing the outer columns provided by its // left child // the insertion order is top down - access_tracker: IndexSet, + access_tracker: IndexMap>, is_dependent_join_node: bool, is_subquery_node: bool, @@ -1885,9 +697,9 @@ fn print(a: &Expr) -> Result<()> { Ok(()) } -impl TreeNodeVisitor<'_> for DependentJoinTracker { +impl TreeNodeRewriter for DependentJoinTracker { type Node = LogicalPlan; - fn f_down(&mut self, node: &LogicalPlan) -> Result { + fn f_down(&mut self, node: LogicalPlan) -> Result> { self.current_id += 1; if self.root.is_none() { self.root = Some(self.current_id); @@ -1898,14 +710,24 @@ impl TreeNodeVisitor<'_> for DependentJoinTracker { let mut subquery_type = SubqueryType::None; // for each node, find which column it is accessing, which column it is providing // Set of columns current node access - match node { + match &node { LogicalPlan::Filter(f) => { if contains_subquery(&f.predicate) { is_dependent_join_node = true; } - f.predicate.outer_column_refs().into_iter().for_each(|f| { - self.mark_column_access(self.current_id, f); - }); + + f.predicate + .apply(|expr| { + if let Expr::OuterReferenceColumn(data_type, col) = expr { + self.mark_outer_column_access( + self.current_id, + data_type, + col, + ); + } + Ok(TreeNodeRecursion::Continue) + }) + .expect("traversal is infallible"); } LogicalPlan::TableScan(tbl_scan) => { tbl_scan.projected_schema.columns().iter().for_each(|col| { @@ -1921,22 +743,27 @@ impl TreeNodeVisitor<'_> for DependentJoinTracker { // 2.projection also provide some new columns // 3.if within projection exists multiple subquery, how does this work LogicalPlan::Projection(proj) => { - let mut outer_cols = HashSet::new(); for expr in &proj.expr { if contains_subquery(expr) { is_dependent_join_node = true; break; } - expr.add_outer_column_refs(&mut outer_cols); + expr.apply(|expr| { + if let Expr::OuterReferenceColumn(data_type, col) = expr { + self.mark_outer_column_access( + self.current_id, + data_type, + col, + ); + } + Ok(TreeNodeRecursion::Continue) + })?; } - 
outer_cols.into_iter().for_each(|c| { - self.mark_column_access(self.current_id, c); - }); } LogicalPlan::Subquery(subquery) => { is_subquery_node = true; let parent = self.stack.last().unwrap(); - let parent_node = self.get_node_uncheck(parent); + let parent_node = self.nodes.get(parent).unwrap(); for expr in parent_node.plan.expressions() { expr.exists(|e| { let (found_sq, checking_type) = match e { @@ -1961,7 +788,7 @@ impl TreeNodeVisitor<'_> for DependentJoinTracker { } LogicalPlan::Aggregate(_) => {} _ => { - return internal_err!("impl scan for node type {:?}", node); + return internal_err!("impl f_down for node type {:?}", node); } }; @@ -1969,7 +796,6 @@ impl TreeNodeVisitor<'_> for DependentJoinTracker { None } else { let previous_node = self.stack.last().unwrap().to_owned(); - self.create_child_relationship(previous_node, self.current_id); Some(self.stack.last().unwrap().to_owned()) }; @@ -1983,24 +809,112 @@ impl TreeNodeVisitor<'_> for DependentJoinTracker { is_subquery_node, is_dependent_join_node, children: vec![], - access_tracker: IndexSet::new(), + access_tracker: IndexMap::new(), subquery_type, correlated_relations: IndexSet::new(), }, ); - - Ok(TreeNodeRecursion::Continue) + Ok(Transformed::no(node)) } + fn f_up(&mut self, node: LogicalPlan) -> Result> { + // if the node in the f_up meet any node in the stack, it means that node itself + // is a dependent join node,transformation by + // build a join based on + let current_node_id = self.stack.pop().unwrap(); + let node_info = self.nodes.get(¤t_node_id).unwrap(); + if !node_info.is_dependent_join_node { + return Ok(Transformed::no(node)); + } + assert!( + 1 == node.inputs().len(), + "a dependent join node cannot have more than 1 child" + ); - /// Invoked while traversing up the tree after children are visited. Default - /// implementation continues the recursion. - fn f_up(&mut self, _node: &Self::Node) -> Result { - self.stack.pop(); - Ok(TreeNodeRecursion::Continue) + let cloned_input = (**node.inputs().first().unwrap()).clone(); + let mut current_plan = LogicalPlanBuilder::new(cloned_input); + let mut subquery_alias_map = HashMap::new(); + let mut subquery_alias_by_node_id = HashMap::new(); + for (subquery_id, column_accesses) in node_info.access_tracker.iter() { + let subquery_node = self.nodes.get(subquery_id).unwrap(); + let subquery_input = subquery_node.plan.inputs().first().unwrap(); + let alias = self + .alias_generator + .next(&subquery_node.subquery_type.prefix()); + subquery_alias_by_node_id.insert(subquery_id, alias.clone()); + subquery_alias_map.insert(unwrap_subquery(subquery_node), alias); + } + + match &node { + LogicalPlan::Filter(filter) => { + let new_predicate = filter + .predicate + .clone() + .transform(|e| { + // replace any subquery expr with subquery_alias.output + // column + match e { + Expr::InSubquery(isq) => { + let alias = + subquery_alias_map.get(&isq.subquery).unwrap(); + // TODO: this assume that after decorrelation + // the dependent join will provide an extra column with the structure + // of "subquery_alias.output" + Ok(Transformed::yes(col(format!("{}.output", alias)))) + } + Expr::Exists(esq) => { + let alias = + subquery_alias_map.get(&esq.subquery).unwrap(); + Ok(Transformed::yes(col(format!("{}.output", alias)))) + } + Expr::ScalarSubquery(sq) => { + let alias = subquery_alias_map.get(&sq).unwrap(); + Ok(Transformed::yes(col(format!("{}.output", alias)))) + } + _ => Ok(Transformed::no(e)), + } + })? 
+ .data; + let post_join_projections: Vec = filter + .input + .schema() + .columns() + .iter() + .map(|c| col(c.clone())) + .collect(); + for (subquery_id, column_accesses) in node_info.access_tracker.iter() { + let alias = subquery_alias_by_node_id.get(subquery_id).unwrap(); + let subquery_node = self.nodes.get(subquery_id).unwrap(); + let subquery_input = + subquery_node.plan.inputs().first().unwrap().clone(); + let right = LogicalPlanBuilder::new(subquery_input.clone()) + .alias(alias.clone())? + .build()?; + let on_exprs = column_accesses + .iter() + .map(|ac| (ac.data_type.clone(), ac.col.clone())) + .unique() + .map(|(data_type, column)| { + out_ref_col(data_type.clone(), column.clone()).eq(col(column)) + }); + + current_plan = + current_plan.join_on(right, JoinType::LeftDependent, on_exprs)?; + } + current_plan = current_plan + .filter(new_predicate.clone())? + .project(post_join_projections)?; + } + _ => { + unimplemented!("implement more dependent join node creation") + } + } + Ok(Transformed::yes(current_plan.build()?)) } } +#[derive(Debug)] +struct Decorrelation {} -impl OptimizerRule for DependentJoinTracker { +impl OptimizerRule for Decorrelation { fn supports_rewrite(&self) -> bool { true } @@ -2009,23 +923,30 @@ impl OptimizerRule for DependentJoinTracker { plan: LogicalPlan, config: &dyn OptimizerConfig, ) -> Result> { - internal_err!("todo") + let mut transformer = DependentJoinTracker::new(config.alias_generator().clone()); + let rewrite_result = transformer.rewrite_subqueries_into_dependent_joins(plan)?; + if rewrite_result.transformed { + // At this point, we have a logical plan with DependentJoin similar to duckdb + unimplemented!("implement dependent join decorrelation") + } + Ok(rewrite_result) } fn name(&self) -> &str { "decorrelate_subquery" } - fn apply_order(&self) -> Option { - Some(ApplyOrder::TopDown) - } + // The rewriter handle recursion + // fn apply_order(&self) -> Option { + // None + // } } #[cfg(test)] mod tests { use std::sync::Arc; - use datafusion_common::{alias::AliasGenerator, DFSchema, Result}; + use datafusion_common::{alias::AliasGenerator, DFSchema, Result, ScalarValue}; use datafusion_expr::{ exists, expr_fn::{self, col, not}, @@ -2053,197 +974,197 @@ mod tests { } #[test] fn two_simple_subqueries_at_the_same_level() -> Result<()> { - let outer_table = test_table_scan_with_name("outer_table")?; - let inner_table_lv1 = test_table_scan_with_name("inner_table_lv1")?; - let in_sq_level1 = Arc::new( - LogicalPlanBuilder::from(inner_table_lv1.clone()) - .filter(col("inner_table_lv1.c").eq(lit(2)))? - .project(vec![col("inner_table_lv1.a")])? - .build()?, - ); - let exist_sq_level1 = Arc::new( - LogicalPlanBuilder::from(inner_table_lv1) - .filter( - col("inner_table_lv1.a").and(col("inner_table_lv1.b").eq(lit(1))), - )? - .build()?, - ); - - let input1 = LogicalPlanBuilder::from(outer_table.clone()) - .filter( - col("outer_table.a") - .gt(lit(1)) - .and(exists(exist_sq_level1)) - .and(in_subquery(col("outer_table.b"), in_sq_level1)), - )? 
- .build()?; - let mut index = DependentJoinTracker::new(Arc::new(AliasGenerator::new())); - index.build(input1)?; - println!("{:?}", index); - let new_plan = index.root_dependent_join_elimination()?; - println!("{}", new_plan); - let expected = "\ - LeftSemi Join: Filter: outer_table.b = __in_sq_2.a\ - \n Filter: __exists_sq_1.mark\ - \n LeftMark Join: Filter: Boolean(true)\ - \n Filter: outer_table.a > Int32(1)\ - \n TableScan: outer_table\ - \n SubqueryAlias: __exists_sq_1\ - \n Filter: inner_table_lv1.a AND inner_table_lv1.b = Int32(1)\ - \n TableScan: inner_table_lv1\ - \n SubqueryAlias: __in_sq_2\ - \n Projection: inner_table_lv1.a\ - \n Filter: inner_table_lv1.c = Int32(2)\ - \n TableScan: inner_table_lv1"; - assert_eq!(expected, format!("{new_plan}")); + // let outer_table = test_table_scan_with_name("outer_table")?; + // let inner_table_lv1 = test_table_scan_with_name("inner_table_lv1")?; + // let in_sq_level1 = Arc::new( + // LogicalPlanBuilder::from(inner_table_lv1.clone()) + // .filter(col("inner_table_lv1.c").eq(lit(2)))? + // .project(vec![col("inner_table_lv1.a")])? + // .build()?, + // ); + // let exist_sq_level1 = Arc::new( + // LogicalPlanBuilder::from(inner_table_lv1) + // .filter( + // col("inner_table_lv1.a").and(col("inner_table_lv1.b").eq(lit(1))), + // )? + // .build()?, + // ); + + // let input1 = LogicalPlanBuilder::from(outer_table.clone()) + // .filter( + // col("outer_table.a") + // .gt(lit(1)) + // .and(exists(exist_sq_level1)) + // .and(in_subquery(col("outer_table.b"), in_sq_level1)), + // )? + // .build()?; + // let mut index = DependentJoinTracker::new(Arc::new(AliasGenerator::new())); + // index.rewrite_subqueries_into_dependent_joins(input1)?; + // println!("{:?}", index); + // let new_plan = index.root_dependent_join_elimination()?; + // println!("{}", new_plan); + // let expected = "\ + // LeftSemi Join: Filter: outer_table.b = __in_sq_2.a\ + // \n Filter: __exists_sq_1.mark\ + // \n LeftMark Join: Filter: Boolean(true)\ + // \n Filter: outer_table.a > Int32(1)\ + // \n TableScan: outer_table\ + // \n SubqueryAlias: __exists_sq_1\ + // \n Filter: inner_table_lv1.a AND inner_table_lv1.b = Int32(1)\ + // \n TableScan: inner_table_lv1\ + // \n SubqueryAlias: __in_sq_2\ + // \n Projection: inner_table_lv1.a\ + // \n Filter: inner_table_lv1.c = Int32(2)\ + // \n TableScan: inner_table_lv1"; + // assert_eq!(expected, format!("{new_plan}")); Ok(()) } #[test] fn in_subquery_with_count_depth_1() -> Result<()> { - let outer_table = test_table_scan_with_name("outer_table")?; - let inner_table_lv1 = test_table_scan_with_name("inner_table_lv1")?; - let sq_level1 = Arc::new( - LogicalPlanBuilder::from(inner_table_lv1) - .filter( - col("inner_table_lv1.a") - .eq(out_ref_col(ArrowDataType::UInt32, "outer_table.a")) - .and( - out_ref_col(ArrowDataType::UInt32, "outer_table.a") - .gt(col("inner_table_lv1.c")), - ) - .and(col("inner_table_lv1.b").eq(lit(1))) - .and( - out_ref_col(ArrowDataType::UInt32, "outer_table.b") - .eq(col("inner_table_lv1.b")), - ), - )? - .aggregate(Vec::::new(), vec![count(col("inner_table_lv1.a"))])? - .project(vec![count(col("inner_table_lv1.a")).alias("count_a")])? - .build()?, - ); - - let input1 = LogicalPlanBuilder::from(outer_table.clone()) - .filter( - col("outer_table.a") - .gt(lit(1)) - .and(in_subquery(col("outer_table.c"), sq_level1)), - )? 
- .build()?; - let mut index = DependentJoinTracker::new(Arc::new(AliasGenerator::new())); - index.build(input1)?; - let new_plan = index.root_dependent_join_elimination()?; - let expected = "\ - Filter: outer_table.a > Int32(1)\ - \n LeftSemi Join: Filter: outer_table.c = count_a\ - \n TableScan: outer_table\ - \n Projection: count(inner_table_lv1.a) AS count_a, inner_table_lv1.a, inner_table_lv1.c, inner_table_lv1.b\ - \n Aggregate: groupBy=[[]], aggr=[[count(inner_table_lv1.a)]]\ - \n Filter: inner_table_lv1.a = outer_ref(outer_table.a) AND outer_ref(outer_table.a) > inner_table_lv1.c AND inner_table_lv1.b = Int32(1) AND outer_ref(outer_table.b) = inner_table_lv1.b\ - \n TableScan: inner_table_lv1"; - assert_eq!(expected, format!("{new_plan}")); + // let outer_table = test_table_scan_with_name("outer_table")?; + // let inner_table_lv1 = test_table_scan_with_name("inner_table_lv1")?; + // let sq_level1 = Arc::new( + // LogicalPlanBuilder::from(inner_table_lv1) + // .filter( + // col("inner_table_lv1.a") + // .eq(out_ref_col(ArrowDataType::UInt32, "outer_table.a")) + // .and( + // out_ref_col(ArrowDataType::UInt32, "outer_table.a") + // .gt(col("inner_table_lv1.c")), + // ) + // .and(col("inner_table_lv1.b").eq(lit(1))) + // .and( + // out_ref_col(ArrowDataType::UInt32, "outer_table.b") + // .eq(col("inner_table_lv1.b")), + // ), + // )? + // .aggregate(Vec::::new(), vec![count(col("inner_table_lv1.a"))])? + // .project(vec![count(col("inner_table_lv1.a")).alias("count_a")])? + // .build()?, + // ); + + // let input1 = LogicalPlanBuilder::from(outer_table.clone()) + // .filter( + // col("outer_table.a") + // .gt(lit(1)) + // .and(in_subquery(col("outer_table.c"), sq_level1)), + // )? + // .build()?; + // let mut index = DependentJoinTracker::new(Arc::new(AliasGenerator::new())); + // index.rewrite_subqueries_into_dependent_joins(input1)?; + // let new_plan = index.root_dependent_join_elimination()?; + // let expected = "\ + // Filter: outer_table.a > Int32(1)\ + // \n LeftSemi Join: Filter: outer_table.c = count_a\ + // \n TableScan: outer_table\ + // \n Projection: count(inner_table_lv1.a) AS count_a, inner_table_lv1.a, inner_table_lv1.c, inner_table_lv1.b\ + // \n Aggregate: groupBy=[[]], aggr=[[count(inner_table_lv1.a)]]\ + // \n Filter: inner_table_lv1.a = outer_ref(outer_table.a) AND outer_ref(outer_table.a) > inner_table_lv1.c AND inner_table_lv1.b = Int32(1) AND outer_ref(outer_table.b) = inner_table_lv1.b\ + // \n TableScan: inner_table_lv1"; + // assert_eq!(expected, format!("{new_plan}")); Ok(()) } #[test] fn simple_exist_subquery_with_dependent_columns() -> Result<()> { - let outer_table = test_table_scan_with_name("outer_table")?; - let inner_table_lv1 = test_table_scan_with_name("inner_table_lv1")?; - let sq_level1 = Arc::new( - LogicalPlanBuilder::from(inner_table_lv1) - .filter( - col("inner_table_lv1.a") - .eq(out_ref_col(ArrowDataType::UInt32, "outer_table.a")) - .and( - out_ref_col(ArrowDataType::UInt32, "outer_table.a") - .gt(col("inner_table_lv1.c")), - ) - .and(col("inner_table_lv1.b").eq(lit(1))) - .and( - out_ref_col(ArrowDataType::UInt32, "outer_table.b") - .eq(col("inner_table_lv1.b")), - ), - )? - .project(vec![out_ref_col(ArrowDataType::UInt32, "outer_table.b") - .alias("outer_b_alias")])? - .build()?, - ); - - let input1 = LogicalPlanBuilder::from(outer_table.clone()) - .filter(col("outer_table.a").gt(lit(1)).and(exists(sq_level1)))? 
- .build()?; - let mut index = DependentJoinTracker::new(Arc::new(AliasGenerator::new())); - index.build(input1)?; - let new_plan = index.root_dependent_join_elimination()?; - let expected = "\ - Filter: __exists_sq_1.mark\ - \n LeftMark Join: Filter: __exists_sq_1.a = outer_table.a AND outer_table.a > __exists_sq_1.c AND outer_table.b = __exists_sq_1.b\ - \n Filter: outer_table.a > Int32(1)\ - \n TableScan: outer_table\ - \n SubqueryAlias: __exists_sq_1\ - \n Filter: inner_table_lv1.b = Int32(1)\ - \n TableScan: inner_table_lv1"; - assert_eq!(expected, format!("{new_plan}")); + // let outer_table = test_table_scan_with_name("outer_table")?; + // let inner_table_lv1 = test_table_scan_with_name("inner_table_lv1")?; + // let sq_level1 = Arc::new( + // LogicalPlanBuilder::from(inner_table_lv1) + // .filter( + // col("inner_table_lv1.a") + // .eq(out_ref_col(ArrowDataType::UInt32, "outer_table.a")) + // .and( + // out_ref_col(ArrowDataType::UInt32, "outer_table.a") + // .gt(col("inner_table_lv1.c")), + // ) + // .and(col("inner_table_lv1.b").eq(lit(1))) + // .and( + // out_ref_col(ArrowDataType::UInt32, "outer_table.b") + // .eq(col("inner_table_lv1.b")), + // ), + // )? + // .project(vec![out_ref_col(ArrowDataType::UInt32, "outer_table.b") + // .alias("outer_b_alias")])? + // .build()?, + // ); + + // let input1 = LogicalPlanBuilder::from(outer_table.clone()) + // .filter(col("outer_table.a").gt(lit(1)).and(exists(sq_level1)))? + // .build()?; + // let mut index = DependentJoinTracker::new(Arc::new(AliasGenerator::new())); + // index.rewrite_subqueries_into_dependent_joins(input1)?; + // let new_plan = index.root_dependent_join_elimination()?; + // let expected = "\ + // Filter: __exists_sq_1.mark\ + // \n LeftMark Join: Filter: __exists_sq_1.a = outer_table.a AND outer_table.a > __exists_sq_1.c AND outer_table.b = __exists_sq_1.b\ + // \n Filter: outer_table.a > Int32(1)\ + // \n TableScan: outer_table\ + // \n SubqueryAlias: __exists_sq_1\ + // \n Filter: inner_table_lv1.b = Int32(1)\ + // \n TableScan: inner_table_lv1"; + // assert_eq!(expected, format!("{new_plan}")); Ok(()) } #[test] fn simple_exist_subquery_with_no_dependent_columns() -> Result<()> { - let outer_table = test_table_scan_with_name("outer_table")?; - let inner_table_lv1 = test_table_scan_with_name("inner_table_lv1")?; - let sq_level1 = Arc::new( - LogicalPlanBuilder::from(inner_table_lv1) - .filter(col("inner_table_lv1.b").eq(lit(1)))? - .project(vec![col("inner_table_lv1.b"), col("inner_table_lv1.a")])? - .build()?, - ); - - let input1 = LogicalPlanBuilder::from(outer_table.clone()) - .filter(col("outer_table.a").gt(lit(1)).and(exists(sq_level1)))? - .build()?; - let mut index = DependentJoinTracker::new(Arc::new(AliasGenerator::new())); - index.build(input1)?; - let new_plan = index.root_dependent_join_elimination()?; - let expected = "\ - Filter: __exists_sq_1.mark\ - \n LeftMark Join: Filter: Boolean(true)\ - \n Filter: outer_table.a > Int32(1)\ - \n TableScan: outer_table\ - \n SubqueryAlias: __exists_sq_1\ - \n Projection: inner_table_lv1.b, inner_table_lv1.a\ - \n Filter: inner_table_lv1.b = Int32(1)\ - \n TableScan: inner_table_lv1"; - assert_eq!(expected, format!("{new_plan}")); + // let outer_table = test_table_scan_with_name("outer_table")?; + // let inner_table_lv1 = test_table_scan_with_name("inner_table_lv1")?; + // let sq_level1 = Arc::new( + // LogicalPlanBuilder::from(inner_table_lv1) + // .filter(col("inner_table_lv1.b").eq(lit(1)))? + // .project(vec![col("inner_table_lv1.b"), col("inner_table_lv1.a")])? 
+ // .build()?, + // ); + + // let input1 = LogicalPlanBuilder::from(outer_table.clone()) + // .filter(col("outer_table.a").gt(lit(1)).and(exists(sq_level1)))? + // .build()?; + // let mut index = DependentJoinTracker::new(Arc::new(AliasGenerator::new())); + // index.rewrite_subqueries_into_dependent_joins(input1)?; + // let new_plan = index.root_dependent_join_elimination()?; + // let expected = "\ + // Filter: __exists_sq_1.mark\ + // \n LeftMark Join: Filter: Boolean(true)\ + // \n Filter: outer_table.a > Int32(1)\ + // \n TableScan: outer_table\ + // \n SubqueryAlias: __exists_sq_1\ + // \n Projection: inner_table_lv1.b, inner_table_lv1.a\ + // \n Filter: inner_table_lv1.b = Int32(1)\ + // \n TableScan: inner_table_lv1"; + // assert_eq!(expected, format!("{new_plan}")); Ok(()) } #[test] fn simple_decorrelate_with_in_subquery_no_dependent_column() -> Result<()> { - let outer_table = test_table_scan_with_name("outer_table")?; - let inner_table_lv1 = test_table_scan_with_name("inner_table_lv1")?; - let sq_level1 = Arc::new( - LogicalPlanBuilder::from(inner_table_lv1) - .filter(col("inner_table_lv1.b").eq(lit(1)))? - .project(vec![col("inner_table_lv1.b")])? - .build()?, - ); - - let input1 = LogicalPlanBuilder::from(outer_table.clone()) - .filter( - col("outer_table.a") - .gt(lit(1)) - .and(in_subquery(col("outer_table.c"), sq_level1)), - )? - .build()?; - let mut index = DependentJoinTracker::new(Arc::new(AliasGenerator::new())); - index.build(input1)?; - let new_plan = index.root_dependent_join_elimination()?; - let expected = "\ - LeftSemi Join: Filter: outer_table.c = __in_sq_1.b\ - \n Filter: outer_table.a > Int32(1)\ - \n TableScan: outer_table\ - \n SubqueryAlias: __in_sq_1\ - \n Projection: inner_table_lv1.b\ - \n Filter: inner_table_lv1.b = Int32(1)\ - \n TableScan: inner_table_lv1"; - assert_eq!(expected, format!("{new_plan}")); + // let outer_table = test_table_scan_with_name("outer_table")?; + // let inner_table_lv1 = test_table_scan_with_name("inner_table_lv1")?; + // let sq_level1 = Arc::new( + // LogicalPlanBuilder::from(inner_table_lv1) + // .filter(col("inner_table_lv1.b").eq(lit(1)))? + // .project(vec![col("inner_table_lv1.b")])? + // .build()?, + // ); + + // let input1 = LogicalPlanBuilder::from(outer_table.clone()) + // .filter( + // col("outer_table.a") + // .gt(lit(1)) + // .and(in_subquery(col("outer_table.c"), sq_level1)), + // )? + // .build()?; + // let mut index = DependentJoinTracker::new(Arc::new(AliasGenerator::new())); + // index.rewrite_subqueries_into_dependent_joins(input1)?; + // let new_plan = index.root_dependent_join_elimination()?; + // let expected = "\ + // LeftSemi Join: Filter: outer_table.c = __in_sq_1.b\ + // \n Filter: outer_table.a > Int32(1)\ + // \n TableScan: outer_table\ + // \n SubqueryAlias: __in_sq_1\ + // \n Projection: inner_table_lv1.b\ + // \n Filter: inner_table_lv1.b = Int32(1)\ + // \n TableScan: inner_table_lv1"; + // assert_eq!(expected, format!("{new_plan}")); Ok(()) } #[test] @@ -2278,8 +1199,9 @@ mod tests { )? 
.build()?; let mut index = DependentJoinTracker::new(Arc::new(AliasGenerator::new())); - index.build(input1)?; - let new_plan = index.root_dependent_join_elimination()?; + let transformed = index.rewrite_subqueries_into_dependent_joins(input1)?; + assert!(transformed.transformed); + let new_plan = transformed.data; println!("{new_plan}"); let expected = "\ LeftSemi Join: Filter: outer_table.c = outer_table.b AS outer_b_alias AND __in_sq_1.a = outer_table.a AND outer_table.a > __in_sq_1.c AND outer_table.b = __in_sq_1.b\ diff --git a/datafusion/physical-expr/src/equivalence/class.rs b/datafusion/physical-expr/src/equivalence/class.rs index 13a3c79a47a2..a7607416a9c8 100644 --- a/datafusion/physical-expr/src/equivalence/class.rs +++ b/datafusion/physical-expr/src/equivalence/class.rs @@ -671,6 +671,7 @@ impl EquivalenceGroup { } JoinType::LeftSemi | JoinType::LeftAnti | JoinType::LeftMark => self.clone(), JoinType::RightSemi | JoinType::RightAnti => right_equivalences.clone(), + JoinType::LeftDependent => todo!(), } } diff --git a/datafusion/sql/src/unparser/plan.rs b/datafusion/sql/src/unparser/plan.rs index 1401a153b06d..487cd8614e3e 100644 --- a/datafusion/sql/src/unparser/plan.rs +++ b/datafusion/sql/src/unparser/plan.rs @@ -710,6 +710,7 @@ impl Unparser<'_> { }; match join.join_type { + JoinType::LeftDependent => todo!(), JoinType::LeftSemi | JoinType::LeftAnti | JoinType::LeftMark @@ -1237,6 +1238,7 @@ impl Unparser<'_> { ast::JoinOperator::CrossJoin } }, + JoinType::LeftDependent => todo!(), JoinType::Left => ast::JoinOperator::LeftOuter(constraint), JoinType::Right => ast::JoinOperator::RightOuter(constraint), JoinType::Full => ast::JoinOperator::FullOuter(constraint), From 2171e5293b3892b689e544cf3a41357397c7188d Mon Sep 17 00:00:00 2001 From: Duong Cong Toai Date: Sat, 24 May 2025 14:35:28 +0200 Subject: [PATCH 25/70] feat: impl dependent join rewriter --- .../common/src/functional_dependencies.rs | 3 - datafusion/common/src/join_type.rs | 9 - datafusion/expr/src/logical_plan/builder.rs | 1 - .../expr/src/logical_plan/invariants.rs | 1 - datafusion/expr/src/logical_plan/plan.rs | 2 - .../optimizer/src/decorrelate_general.rs | 167 ++++-------------- .../physical-expr/src/equivalence/class.rs | 1 - datafusion/sql/src/unparser/plan.rs | 2 - 8 files changed, 32 insertions(+), 154 deletions(-) diff --git a/datafusion/common/src/functional_dependencies.rs b/datafusion/common/src/functional_dependencies.rs index 2de7db873af1..c4f2805f8285 100644 --- a/datafusion/common/src/functional_dependencies.rs +++ b/datafusion/common/src/functional_dependencies.rs @@ -382,9 +382,6 @@ impl FunctionalDependencies { // All of the functional dependencies are lost in a FULL join: FunctionalDependencies::empty() } - JoinType::LeftDependent => { - unreachable!("LeftDependent should not be reached") - } } } diff --git a/datafusion/common/src/join_type.rs b/datafusion/common/src/join_type.rs index 7f962c065d7a..ac81d977b729 100644 --- a/datafusion/common/src/join_type.rs +++ b/datafusion/common/src/join_type.rs @@ -67,10 +67,6 @@ pub enum JoinType { /// /// [1]: http://btw2017.informatik.uni-stuttgart.de/slidesandpapers/F1-10-37/paper_web.pdf LeftMark, - /// TODO: document me more - /// used to represent a virtual join in a complex expr containing subquery(ies), - /// The actual join type depends on the correlated expr - LeftDependent, } impl JoinType { @@ -94,9 +90,6 @@ impl JoinType { JoinType::LeftMark => { unreachable!("LeftMark join type does not support swapping") } - JoinType::LeftDependent => { - 
unreachable!("Dependent join type does not support swapping") - } } } @@ -128,7 +121,6 @@ impl Display for JoinType { JoinType::LeftAnti => "LeftAnti", JoinType::RightAnti => "RightAnti", JoinType::LeftMark => "LeftMark", - JoinType::LeftDependent => "LeftDependent", }; write!(f, "{join_type}") } @@ -149,7 +141,6 @@ impl FromStr for JoinType { "LEFTANTI" => Ok(JoinType::LeftAnti), "RIGHTANTI" => Ok(JoinType::RightAnti), "LEFTMARK" => Ok(JoinType::LeftMark), - "LEFTDEPENDENT" => Ok(JoinType::LeftDependent), _ => _not_impl_err!("The join type {s} does not exist or is not implemented"), } } diff --git a/datafusion/expr/src/logical_plan/builder.rs b/datafusion/expr/src/logical_plan/builder.rs index 2fe21387830b..d4d45226d354 100644 --- a/datafusion/expr/src/logical_plan/builder.rs +++ b/datafusion/expr/src/logical_plan/builder.rs @@ -1620,7 +1620,6 @@ pub fn build_join_schema( .map(|(q, f)| (q.cloned(), Arc::clone(f))) .collect() } - JoinType::LeftDependent => todo!(), }; let func_dependencies = left.functional_dependencies().join( right.functional_dependencies(), diff --git a/datafusion/expr/src/logical_plan/invariants.rs b/datafusion/expr/src/logical_plan/invariants.rs index 8e5b3156f53e..0c30c9785766 100644 --- a/datafusion/expr/src/logical_plan/invariants.rs +++ b/datafusion/expr/src/logical_plan/invariants.rs @@ -321,7 +321,6 @@ fn check_inner_plan(inner_plan: &LogicalPlan) -> Result<()> { })?; Ok(()) } - JoinType::LeftDependent => todo!(), }, LogicalPlan::Extension(_) => Ok(()), plan => check_no_outer_references(plan), diff --git a/datafusion/expr/src/logical_plan/plan.rs b/datafusion/expr/src/logical_plan/plan.rs index 24d2dda4f5c5..edf5f1126be9 100644 --- a/datafusion/expr/src/logical_plan/plan.rs +++ b/datafusion/expr/src/logical_plan/plan.rs @@ -546,7 +546,6 @@ impl LogicalPlan { join_type, .. }) => match join_type { - JoinType::LeftDependent => todo!(), JoinType::Inner | JoinType::Left | JoinType::Right | JoinType::Full => { if left.schema().fields().is_empty() { right.head_output_expr() @@ -1332,7 +1331,6 @@ impl LogicalPlan { join_type, .. }) => match join_type { - JoinType::LeftDependent => todo!(), JoinType::Inner => Some(left.max_rows()? 
* right.max_rows()?), JoinType::Left | JoinType::Right | JoinType::Full => { match (left.max_rows()?, right.max_rows()?, join_type) { diff --git a/datafusion/optimizer/src/decorrelate_general.rs b/datafusion/optimizer/src/decorrelate_general.rs index 9f8931322742..bf6c4b460e64 100644 --- a/datafusion/optimizer/src/decorrelate_general.rs +++ b/datafusion/optimizer/src/decorrelate_general.rs @@ -36,7 +36,6 @@ use datafusion_common::tree_node::{ use datafusion_common::{internal_err, Column, HashMap, Result}; use datafusion_expr::expr::{self, Exists, InSubquery}; use datafusion_expr::expr_rewriter::{normalize_col, strip_outer_reference}; -use datafusion_expr::out_ref_col; use datafusion_expr::select_expr::SelectExpr; use datafusion_expr::utils::{ conjunction, disjunction, split_conjunction, split_disjunction, @@ -45,6 +44,7 @@ use datafusion_expr::{ binary_expr, col, expr_fn, lit, BinaryExpr, Cast, Expr, Filter, JoinType, LogicalPlan, LogicalPlanBuilder, Operator as ExprOperator, Subquery, }; +use datafusion_expr::{in_list, out_ref_col}; // use datafusion_sql::unparser::Unparser; use datafusion_sql::unparser::Unparser; @@ -55,8 +55,7 @@ use indexmap::{IndexMap, IndexSet}; use itertools::Itertools; use log::Log; -pub struct DependentJoinTracker { - root: Option, +pub struct DependentJoinRewriter { // each logical plan traversal will assign it a integer id current_id: usize, // each newly visted operator is inserted inside this map for tracking @@ -419,100 +418,7 @@ fn contains_count_expr( .unwrap() } -impl fmt::Debug for DependentJoinTracker { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - writeln!(f, "GeneralDecorrelation Tree:")?; - if let Some(root_op) = &self.root { - self.fmt_operator(f, *root_op, 0, false)?; - } else { - writeln!(f, " ")?; - } - Ok(()) - } -} - -impl DependentJoinTracker { - fn fmt_operator( - &self, - f: &mut fmt::Formatter<'_>, - node_id: usize, - indent: usize, - is_last: bool, - ) -> fmt::Result { - // Find the LogicalPlan corresponding to this Operator - let op = self.nodes.get(&node_id).unwrap(); - let lp = &op.plan; - - for i in 0..indent { - if i + 1 == indent { - if is_last { - write!(f, " ")?; // if last child, no vertical line - } else { - write!(f, "| ")?; // vertical line continues - } - } else { - write!(f, "| ")?; - } - } - if indent > 0 { - write!(f, "|--- ")?; // branch - } - - let unparsed_sql = match Unparser::default().plan_to_sql(lp) { - Ok(str) => str.to_string(), - Err(_) => "".to_string(), - }; - let (node_color, display_str) = match lp { - LogicalPlan::Subquery(sq) => ( - "\x1b[32m", - format!("\x1b[1m{}{}", lp.display(), sq.subquery), - ), - _ => ("\x1b[33m", lp.display().to_string()), - }; - - writeln!(f, "{} [{}] {}\x1b[0m", node_color, node_id, display_str)?; - if !unparsed_sql.is_empty() { - for i in 0..=indent { - if i < indent { - write!(f, "| ")?; - } else if indent > 0 { - write!(f, "| ")?; // Align with LogicalPlan text - } - } - - writeln!(f, "{}", unparsed_sql)?; - } - - for i in 0..=indent { - if i < indent { - write!(f, "| ")?; - } else if indent > 0 { - write!(f, "| ")?; // Align with LogicalPlan text - } - } - - let accessed_by_string = op - .access_tracker - .iter() - .map(|(_, ac)| ac.clone()) - .flatten() - .map(|ac| ac.debug()) - .collect::>() - .join(","); - // Now print the Operator details - writeln!(f, "accessed_by: {}", accessed_by_string,)?; - let len = op.children.len(); - - // Recursively print children if Operator has children - for (i, child) in op.children.iter().enumerate() { - let last = i + 1 == 
len; - - self.fmt_operator(f, *child, indent + 1, last)?; - } - - Ok(()) - } - +impl DependentJoinRewriter { // lowest common ancestor from stack // given a tree of // n1 @@ -543,7 +449,7 @@ impl DependentJoinTracker { let bi = stack_with_table_provider[i]; if ai == bi { - lca = Some((ai, stack_with_subquery[ai + 1])); + lca = Some((ai, stack_with_subquery[ai])); } else { break; } @@ -576,7 +482,6 @@ impl DependentJoinTracker { stack: access.stack.clone(), data_type: access.data_type.clone(), }); - node.correlated_relations.insert(tbl_name.to_string()); } } } @@ -609,10 +514,9 @@ impl DependentJoinTracker { } } -impl DependentJoinTracker { +impl DependentJoinRewriter { fn new(alias_generator: Arc) -> Self { - return DependentJoinTracker { - root: None, + return DependentJoinRewriter { alias_generator, current_id: 0, nodes: IndexMap::new(), @@ -631,7 +535,6 @@ impl ColumnAccess { struct Node { id: usize, plan: LogicalPlan, - parent: Option, // This field is only meaningful if the node is dependent join node // it track which descendent nodes still accessing the outer columns provided by its @@ -645,9 +548,7 @@ struct Node { // note that for dependent join nodes, there can be more than 1 // subquery children at a time, but always 1 outer-column-providing-child // which is at the last element - children: Vec, subquery_type: SubqueryType, - correlated_relations: IndexSet, } #[derive(Debug, Clone, Copy)] enum SubqueryType { @@ -697,16 +598,12 @@ fn print(a: &Expr) -> Result<()> { Ok(()) } -impl TreeNodeRewriter for DependentJoinTracker { +impl TreeNodeRewriter for DependentJoinRewriter { type Node = LogicalPlan; fn f_down(&mut self, node: LogicalPlan) -> Result> { self.current_id += 1; - if self.root.is_none() { - self.root = Some(self.current_id); - } let mut is_subquery_node = false; let mut is_dependent_join_node = false; - let mut subquery_type = SubqueryType::None; // for each node, find which column it is accessing, which column it is providing // Set of columns current node access @@ -792,28 +689,19 @@ impl TreeNodeRewriter for DependentJoinTracker { } }; - let parent = if self.stack.is_empty() { - None - } else { - let previous_node = self.stack.last().unwrap().to_owned(); - Some(self.stack.last().unwrap().to_owned()) - }; - self.stack.push(self.current_id); self.nodes.insert( self.current_id, Node { id: self.current_id, - parent, plan: node.clone(), is_subquery_node, is_dependent_join_node, - children: vec![], access_tracker: IndexMap::new(), subquery_type, - correlated_relations: IndexSet::new(), }, ); + Ok(Transformed::no(node)) } fn f_up(&mut self, node: LogicalPlan) -> Result> { @@ -897,8 +785,9 @@ impl TreeNodeRewriter for DependentJoinTracker { out_ref_col(data_type.clone(), column.clone()).eq(col(column)) }); + // TODO: create a new dependent join logical plan current_plan = - current_plan.join_on(right, JoinType::LeftDependent, on_exprs)?; + current_plan.join_on(right, JoinType::LeftMark, on_exprs)?; } current_plan = current_plan .filter(new_predicate.clone())? 
@@ -923,7 +812,8 @@ impl OptimizerRule for Decorrelation { plan: LogicalPlan, config: &dyn OptimizerConfig, ) -> Result> { - let mut transformer = DependentJoinTracker::new(config.alias_generator().clone()); + let mut transformer = + DependentJoinRewriter::new(config.alias_generator().clone()); let rewrite_result = transformer.rewrite_subqueries_into_dependent_joins(plan)?; if rewrite_result.transformed { // At this point, we have a logical plan with DependentJoin similar to duckdb @@ -954,11 +844,15 @@ mod tests { EmptyRelation, Expr, LogicalPlan, LogicalPlanBuilder, }; use datafusion_functions_aggregate::{count::count, sum::sum}; + use insta::assert_snapshot; use regex_syntax::ast::LiteralKind; - use crate::test::{test_table_scan, test_table_scan_with_name}; + use crate::{ + assert_optimized_plan_eq_display_indent_snapshot, + test::{test_table_scan, test_table_scan_with_name}, + }; - use super::DependentJoinTracker; + use super::DependentJoinRewriter; use arrow::datatypes::DataType as ArrowDataType; #[test] fn simple_in_subquery_inside_from_expr() -> Result<()> { @@ -1198,19 +1092,22 @@ mod tests { .and(in_subquery(col("outer_table.c"), sq_level1)), )? .build()?; - let mut index = DependentJoinTracker::new(Arc::new(AliasGenerator::new())); + let mut index = DependentJoinRewriter::new(Arc::new(AliasGenerator::new())); let transformed = index.rewrite_subqueries_into_dependent_joins(input1)?; assert!(transformed.transformed); - let new_plan = transformed.data; - println!("{new_plan}"); - let expected = "\ - LeftSemi Join: Filter: outer_table.c = outer_table.b AS outer_b_alias AND __in_sq_1.a = outer_table.a AND outer_table.a > __in_sq_1.c AND outer_table.b = __in_sq_1.b\ - \n Filter: outer_table.a > Int32(1)\ - \n TableScan: outer_table\ - \n SubqueryAlias: __in_sq_1\ - \n Filter: inner_table_lv1.b = Int32(1)\ - \n TableScan: inner_table_lv1"; - assert_eq!(expected, format!("{new_plan}")); + + let formatted_plan = transformed.data.display_indent_schema(); + assert_snapshot!(formatted_plan, + @r" +Projection: outer_table.a, outer_table.b, outer_table.c [a:UInt32, b:UInt32, c:UInt32] + Filter: outer_table.a > Int32(1) AND __in_sq_1.output [a:UInt32, b:UInt32, c:UInt32, outer_b_alias:UInt32;N] + LeftMark Join: Filter: outer_ref(outer_table.a) = outer_table.a AND outer_ref(outer_table.b) = outer_table.b [a:UInt32, b:UInt32, c:UInt32, mark;Boolean] + TableScan: outer_table [a:UInt32, b:UInt32, c:UInt32] + SubqueryAlias: __in_sq_1 [outer_b_alias:UInt32;N] + Projection: outer_ref(outer_table.b) AS outer_b_alias [outer_b_alias:UInt32;N] + Filter: inner_table_lv1.a = outer_ref(outer_table.a) AND outer_ref(outer_table.a) > inner_table_lv1.c AND inner_table_lv1.b = Int32(1) AND outer_ref(outer_table.b) = inner_table_lv1.b [a:UInt32, b:UInt32, c:UInt32] + TableScan: inner_table_lv1 [a:UInt32, b:UInt32, c:UInt32] + "); Ok(()) } } diff --git a/datafusion/physical-expr/src/equivalence/class.rs b/datafusion/physical-expr/src/equivalence/class.rs index a7607416a9c8..13a3c79a47a2 100644 --- a/datafusion/physical-expr/src/equivalence/class.rs +++ b/datafusion/physical-expr/src/equivalence/class.rs @@ -671,7 +671,6 @@ impl EquivalenceGroup { } JoinType::LeftSemi | JoinType::LeftAnti | JoinType::LeftMark => self.clone(), JoinType::RightSemi | JoinType::RightAnti => right_equivalences.clone(), - JoinType::LeftDependent => todo!(), } } diff --git a/datafusion/sql/src/unparser/plan.rs b/datafusion/sql/src/unparser/plan.rs index 487cd8614e3e..1401a153b06d 100644 --- a/datafusion/sql/src/unparser/plan.rs +++ 
b/datafusion/sql/src/unparser/plan.rs @@ -710,7 +710,6 @@ impl Unparser<'_> { }; match join.join_type { - JoinType::LeftDependent => todo!(), JoinType::LeftSemi | JoinType::LeftAnti | JoinType::LeftMark @@ -1238,7 +1237,6 @@ impl Unparser<'_> { ast::JoinOperator::CrossJoin } }, - JoinType::LeftDependent => todo!(), JoinType::Left => ast::JoinOperator::LeftOuter(constraint), JoinType::Right => ast::JoinOperator::RightOuter(constraint), JoinType::Full => ast::JoinOperator::FullOuter(constraint), From 9d264374ded1886aef462892fdf5a739438ad75e Mon Sep 17 00:00:00 2001 From: Duong Cong Toai Date: Sat, 24 May 2025 14:41:06 +0200 Subject: [PATCH 26/70] chore: clean up unused function --- datafusion/expr/src/expr.rs | 32 -- datafusion/expr/src/expr_rewriter/mod.rs | 22 -- datafusion/expr/src/utils.rs | 25 -- datafusion/optimizer/Cargo.toml | 1 - .../optimizer/src/decorrelate_general.rs | 333 ------------------ 5 files changed, 413 deletions(-) diff --git a/datafusion/expr/src/expr.rs b/datafusion/expr/src/expr.rs index 4cc4e347659c..b8e4204a9c9e 100644 --- a/datafusion/expr/src/expr.rs +++ b/datafusion/expr/src/expr.rs @@ -1655,25 +1655,6 @@ impl Expr { using_columns } - pub fn outer_column_refs(&self) -> HashSet<&Column> { - let mut using_columns = HashSet::new(); - self.add_outer_column_refs(&mut using_columns); - using_columns - } - - /// Adds references to all outer columns in this expression to the set - /// - /// See [`Self::column_refs`] for details - pub fn add_outer_column_refs<'a>(&'a self, set: &mut HashSet<&'a Column>) { - self.apply(|expr| { - if let Expr::OuterReferenceColumn(_, col) = expr { - set.insert(col); - } - Ok(TreeNodeRecursion::Continue) - }) - .expect("traversal is infallible"); - } - /// Adds references to all columns in this expression to the set /// /// See [`Self::column_refs`] for details @@ -1734,19 +1715,6 @@ impl Expr { .expect("exists closure is infallible") } - /// Return true if the expression contains out reference(correlated) expressions. - pub fn contains_outer_from_relation(&self, outer_relation_name: &String) -> bool { - self.exists(|expr| { - if let Expr::OuterReferenceColumn(_, col) = expr { - if let Some(relation) = &col.relation { - return Ok(relation.table() == outer_relation_name); - } - } - Ok(false) - }) - .expect("exists closure is infallible") - } - /// Returns true if the expression node is volatile, i.e. whether it can return /// different results when evaluated multiple times with the same input. 
/// Note: unlike [`Self::is_volatile`], this function does not consider inputs: diff --git a/datafusion/expr/src/expr_rewriter/mod.rs b/datafusion/expr/src/expr_rewriter/mod.rs index b463dd43b228..90dcbce46b01 100644 --- a/datafusion/expr/src/expr_rewriter/mod.rs +++ b/datafusion/expr/src/expr_rewriter/mod.rs @@ -130,28 +130,6 @@ pub fn normalize_sorts( .collect() } -/// Recursively rename the table of all [`Column`] expressions in a given expression tree with -/// a new name, ignoring the `skip_tables` -pub fn replace_col_base_table( - expr: Expr, - skip_tables: &[&str], - new_table: String, -) -> Result { - expr.transform(|expr| { - if let Expr::Column(c) = &expr { - if let Some(relation) = &c.relation { - if !skip_tables.contains(&relation.table()) { - return Ok(Transformed::yes(Expr::Column( - c.with_relation(TableReference::bare(new_table.clone())), - ))); - } - } - } - Ok(Transformed::no(expr)) - }) - .data() -} - /// Recursively replace all [`Column`] expressions in a given expression tree with /// `Column` expressions provided by the hash map argument. pub fn replace_col(expr: Expr, replace_map: &HashMap<&Column, &Column>) -> Result { diff --git a/datafusion/expr/src/utils.rs b/datafusion/expr/src/utils.rs index 061782a5aa33..552ce1502d46 100644 --- a/datafusion/expr/src/utils.rs +++ b/datafusion/expr/src/utils.rs @@ -933,31 +933,6 @@ pub fn split_conjunction(expr: &Expr) -> Vec<&Expr> { split_conjunction_impl(expr, vec![]) } -/// Splits a conjunctive [`Expr`] such as `A OR B OR C` => `[A, B, C]` -/// -/// See [`split_disjunction`] for more details and an example. -pub fn split_disjunction(expr: &Expr) -> Vec<&Expr> { - split_disjunction_impl(expr, vec![]) -} - -fn split_disjunction_impl<'a>(expr: &'a Expr, mut exprs: Vec<&'a Expr>) -> Vec<&'a Expr> { - match expr { - Expr::BinaryExpr(BinaryExpr { - right, - op: Operator::Or, - left, - }) => { - let exprs = split_disjunction_impl(left, exprs); - split_disjunction_impl(right, exprs) - } - Expr::Alias(Alias { expr, .. 
}) => split_disjunction_impl(expr, exprs), - other => { - exprs.push(other); - exprs - } - } -} - fn split_conjunction_impl<'a>(expr: &'a Expr, mut exprs: Vec<&'a Expr>) -> Vec<&'a Expr> { match expr { Expr::BinaryExpr(BinaryExpr { diff --git a/datafusion/optimizer/Cargo.toml b/datafusion/optimizer/Cargo.toml index 1f303088a294..60358d20e2a1 100644 --- a/datafusion/optimizer/Cargo.toml +++ b/datafusion/optimizer/Cargo.toml @@ -46,7 +46,6 @@ chrono = { workspace = true } datafusion-common = { workspace = true, default-features = true } datafusion-expr = { workspace = true } datafusion-physical-expr = { workspace = true } -datafusion-sql = { workspace = true } indexmap = { workspace = true } itertools = { workspace = true } log = { workspace = true } diff --git a/datafusion/optimizer/src/decorrelate_general.rs b/datafusion/optimizer/src/decorrelate_general.rs index bf6c4b460e64..d1b56e24ffc0 100644 --- a/datafusion/optimizer/src/decorrelate_general.rs +++ b/datafusion/optimizer/src/decorrelate_general.rs @@ -45,11 +45,9 @@ use datafusion_expr::{ LogicalPlan, LogicalPlanBuilder, Operator as ExprOperator, Subquery, }; use datafusion_expr::{in_list, out_ref_col}; -// use datafusion_sql::unparser::Unparser; use datafusion_sql::unparser::Unparser; use datafusion_sql::TableReference; -// use datafusion_sql::TableReference; use indexmap::map::Entry; use indexmap::{IndexMap, IndexSet}; use itertools::Itertools; @@ -77,194 +75,7 @@ struct ColumnAccess { col: Column, data_type: DataType, } -// pub struct GeneralDecorrelation { -// index: AlgebraIndex, -// } -// data structure to store equivalent columns -// Expr is used to represent either own column or outer referencing columns -#[derive(Clone)] -pub struct UnionFind { - parent: IndexMap, - rank: IndexMap, -} - -impl UnionFind { - pub fn new() -> Self { - Self { - parent: IndexMap::new(), - rank: IndexMap::new(), - } - } - - pub fn find(&mut self, x: Expr) -> Expr { - let p = self.parent.get(&x).cloned(); - match p { - None => { - self.parent.insert(x.clone(), x.clone()); - self.rank.insert(x.clone(), 0); - x - } - Some(parent) => { - if parent == x { - x - } else { - let root = self.find(parent.clone()); - self.parent.insert(x, root.clone()); - root - } - } - } - } - - pub fn union(&mut self, x: Expr, y: Expr) -> bool { - let root_x = self.find(x.clone()); - let root_y = self.find(y.clone()); - if root_x == root_y { - return false; - } - - let rank_x = *self.rank.get(&root_x).unwrap_or(&0); - let rank_y = *self.rank.get(&root_y).unwrap_or(&0); - - if rank_x < rank_y { - self.parent.insert(root_x, root_y); - } else if rank_x > rank_y { - self.parent.insert(root_y, root_x); - } else { - // asign y as children of x - self.parent.insert(root_y.clone(), root_x.clone()); - *self.rank.entry(root_x).or_insert(0) += 1; - } - - true - } -} - -#[derive(Clone)] -struct UnnestingInfo { - // join: DependentJoin, - domain: LogicalPlan, - parent: Option, -} -#[derive(Clone)] -struct Unnesting { - original_subquery: LogicalPlan, - info: Arc, // cclasses: union find data structure of equivalent columns - equivalences: UnionFind, - need_handle_count_bug: bool, - - // for each outer exprs on the left, the set of exprs - // on the right required pulling up for the join condition to happen - // i.e select * from t1 where t1.col1 = ( - // select count(*) from t2 where t2.col1 > t1.col2 + t2.col2 or t1.col3 = t1.col2 or t1.col4=2 and t1.col3=1) - // we do this by split the complex expr into conjuctive sets - // for each of such set, if there exists any or binary operator - 
// we substitute the whole binary operator as true and add every expr appearing in the or condition - // to grouped_by - // and push every - pulled_up_columns: Vec, - //these predicates are conjunctive - pulled_up_predicates: Vec, - - // need this tracked to later on transform for which original subquery requires which join using which metadata - count_exprs_detected: IndexSet, - // mapping from outer ref column to new column, if any - // i.e in some subquery ( - // ... where outer.column_c=inner.column_a - // ) - // and through union find we have outer.column_c = some_other_expr - // we can substitute the inner query with inner.column_a=some_other_expr - replaces: IndexMap, - - subquery_type: SubqueryType, - decorrelated_subquery: Option, -} -impl Unnesting { - fn get_replaced_col(&self, col: &Column) -> Column { - match self.replaces.get(col) { - Some(col) => col.clone(), - None => col.clone(), - } - } - - fn rewrite_all_pulled_up_expr( - &mut self, - alias_name: &String, - outer_relations: &[String], - ) -> Result<()> { - for expr in self.pulled_up_predicates.iter_mut() { - *expr = replace_col_base_table(expr.clone(), outer_relations, alias_name)?; - } - // let rewritten_projections = self - // .pulled_up_columns - // .iter() - // .map(|e| replace_col_base_table(e.clone(), &outer_relations, alias_name)) - // .collect::>>()?; - // self.pulled_up_projections = rewritten_projections; - Ok(()) - } -} - -pub fn replace_col_base_table( - expr: Expr, - skip_tables: &[String], - new_table: &String, -) -> Result { - Ok(expr - .transform(|expr| { - if let Expr::Column(c) = &expr { - if let Some(relation) = &c.relation { - if !skip_tables.contains(&relation.table().to_string()) { - return Ok(Transformed::yes(Expr::Column( - c.with_relation(TableReference::bare(new_table.clone())), - ))); - } - } - } - Ok(Transformed::no(expr)) - })? - .data) -} - -// TODO: looks like this function can be improved to allow more expr pull up -fn can_pull_up(expr: &Expr) -> bool { - if let Expr::BinaryExpr(BinaryExpr { left, op, right }) = expr { - match op { - ExprOperator::Eq - | ExprOperator::Gt - | ExprOperator::Lt - | ExprOperator::GtEq - | ExprOperator::LtEq => {} - _ => return false, - } - match (left.deref(), right.deref()) { - (Expr::Column(_), right) => !right.any_column_refs(), - (left, Expr::Column(_)) => !left.any_column_refs(), - (Expr::Cast(Cast { expr, .. }), right) - if matches!(expr.deref(), Expr::Column(_)) => - { - !right.any_column_refs() - } - (left, Expr::Cast(Cast { expr, .. 
})) - if matches!(expr.deref(), Expr::Column(_)) => - { - !left.any_column_refs() - } - (_, _) => false, - } - } else { - false - } -} - -#[derive(Debug, Hash, PartialEq, PartialOrd, Eq, Clone)] -struct PulledUpExpr { - expr: Expr, - // multiple expr can be pulled up at a time, and because multiple subquery exists - // at the same level, we need to track which subquery the pulling up is happening for - subquery_node_id: usize, -} fn unwrap_subquery(n: &Node) -> &Subquery { match n.plan { LogicalPlan::Subquery(ref sq) => sq, @@ -274,150 +85,6 @@ fn unwrap_subquery(n: &Node) -> &Subquery { } } -fn extract_join_metadata_from_subquery( - expr: &Expr, - sq: &Subquery, - subquery_projected_exprs: &[Expr], - alias: &String, - outer_relations: &[String], -) -> Result<(bool, Option, Option)> { - let mut post_join_predicate = None; - - // this can either be a projection expr or a predicate expr - let mut transformed_expr = None; - - let found_sq = expr.exists(|e| match e { - Expr::InSubquery(isq) => { - if subquery_projected_exprs.len() != 1 { - return internal_err!( - "result of IN subquery should only involve one column" - ); - } - if isq.subquery == *sq { - let expr_with_alias = replace_col_base_table( - subquery_projected_exprs[0].clone(), - outer_relations, - alias, - )?; - if isq.negated { - transformed_expr = Some(binary_expr( - *isq.expr.clone(), - ExprOperator::NotEq, - strip_outer_reference(expr_with_alias), - )); - return Ok(true); - } - - transformed_expr = Some(binary_expr( - *isq.expr.clone(), - ExprOperator::Eq, - strip_outer_reference(expr_with_alias), - )); - return Ok(true); - } - return Ok(false); - } - Expr::BinaryExpr(BinaryExpr { left, op, right }) => { - let (exist, transformed, post_join_expr_from_left) = - extract_join_metadata_from_subquery( - left.as_ref(), - sq, - subquery_projected_exprs, - alias, - outer_relations, - )?; - if !exist { - let (right_exist, transformed_right, post_join_expr_from_right) = - extract_join_metadata_from_subquery( - right.as_ref(), - sq, - subquery_projected_exprs, - alias, - outer_relations, - )?; - if !right_exist { - return Ok(false); - } - if let Some(transformed_right) = transformed_right { - transformed_expr = - Some(binary_expr(*left.clone(), *op, transformed_right)); - } - if let Some(transformed_right) = post_join_expr_from_right { - post_join_predicate = - Some(binary_expr(*left.clone(), *op, transformed_right)); - } - - return Ok(true); - } - if let Some(transformed) = transformed { - transformed_expr = Some(binary_expr(transformed, *op, *right.clone())); - } - if let Some(transformed) = post_join_expr_from_left { - post_join_predicate = Some(binary_expr(transformed, *op, *right.clone())); - } - return Ok(true); - } - Expr::Exists(Exists { - subquery: inner_sq, - negated, - .. 
- }) => { - if inner_sq.clone() == *sq { - let mark_predicate = if *negated { !col("mark") } else { col("mark") }; - post_join_predicate = Some(mark_predicate); - return Ok(true); - } - return Ok(false); - } - Expr::ScalarSubquery(ssq) => { - if subquery_projected_exprs.len() != 1 { - return internal_err!( - "result of scalar subquery should only involve one column" - ); - } - if let LogicalPlan::Subquery(inner_sq) = ssq.subquery.as_ref() { - if inner_sq.clone() == *sq { - transformed_expr = Some(subquery_projected_exprs[0].clone()); - return Ok(true); - } - } - return Ok(false); - } - _ => Ok(false), - })?; - return Ok((found_sq, transformed_expr, post_join_predicate)); -} - -// impl Default for GeneralDecorrelation { -// fn default() -> Self { -// return GeneralDecorrelation { -// index: AlgebraIndex::default(), -// }; -// } -// } -struct GeneralDecorrelationResult { - // i.e for aggregation, dependent columns are added to the projection for joining - added_columns: Vec, - // the reason is, unnesting group by happen at lower nodes, - // but the filtering (if any) of such expr may happen higher node - // (because of known count_bug) - count_expr_map: HashSet, -} - -fn contains_count_expr( - expr: &Expr, - // schema: &DFSchemaRef, - // expr_result_map_for_count_bug: &mut HashMap, -) -> bool { - expr.exists(|e| match e { - Expr::AggregateFunction(expr::AggregateFunction { func, .. }) => { - Ok(func.name() == "count") - } - _ => Ok(false), - }) - .unwrap() -} - impl DependentJoinRewriter { // lowest common ancestor from stack // given a tree of From 24d122377fdda36a3d50e4175c70d521b40f1eae Mon Sep 17 00:00:00 2001 From: Duong Cong Toai Date: Sat, 24 May 2025 14:47:08 +0200 Subject: [PATCH 27/70] chore: clean up debug slt --- .../optimizer/src/decorrelate_general.rs | 15 +-- datafusion/sqllogictest/test_files/debug.slt | 61 --------- datafusion/sqllogictest/test_files/debug2.slt | 114 ----------------- .../sqllogictest/test_files/debug_count.slt | 116 ------------------ .../sqllogictest/test_files/subquery.slt | 7 -- .../sqllogictest/test_files/unsupported.slt | 76 ------------ 6 files changed, 2 insertions(+), 387 deletions(-) delete mode 100644 datafusion/sqllogictest/test_files/debug.slt delete mode 100644 datafusion/sqllogictest/test_files/debug2.slt delete mode 100644 datafusion/sqllogictest/test_files/debug_count.slt delete mode 100644 datafusion/sqllogictest/test_files/unsupported.slt diff --git a/datafusion/optimizer/src/decorrelate_general.rs b/datafusion/optimizer/src/decorrelate_general.rs index d1b56e24ffc0..ff1ac64527fe 100644 --- a/datafusion/optimizer/src/decorrelate_general.rs +++ b/datafusion/optimizer/src/decorrelate_general.rs @@ -37,17 +37,13 @@ use datafusion_common::{internal_err, Column, HashMap, Result}; use datafusion_expr::expr::{self, Exists, InSubquery}; use datafusion_expr::expr_rewriter::{normalize_col, strip_outer_reference}; use datafusion_expr::select_expr::SelectExpr; -use datafusion_expr::utils::{ - conjunction, disjunction, split_conjunction, split_disjunction, -}; +use datafusion_expr::utils::{conjunction, disjunction, split_conjunction}; use datafusion_expr::{ binary_expr, col, expr_fn, lit, BinaryExpr, Cast, Expr, Filter, JoinType, LogicalPlan, LogicalPlanBuilder, Operator as ExprOperator, Subquery, }; use datafusion_expr::{in_list, out_ref_col}; -use datafusion_sql::unparser::Unparser; -use datafusion_sql::TableReference; use indexmap::map::Entry; use indexmap::{IndexMap, IndexSet}; use itertools::Itertools; @@ -258,13 +254,6 @@ fn 
contains_subquery(expr: &Expr) -> bool { .expect("Inner is always Ok") } -fn print(a: &Expr) -> Result<()> { - let unparser = Unparser::default(); - let round_trip_sql = unparser.expr_to_sql(a)?.to_string(); - println!("{}", round_trip_sql); - Ok(()) -} - impl TreeNodeRewriter for DependentJoinRewriter { type Node = LogicalPlan; fn f_down(&mut self, node: LogicalPlan) -> Result> { @@ -767,7 +756,7 @@ mod tests { assert_snapshot!(formatted_plan, @r" Projection: outer_table.a, outer_table.b, outer_table.c [a:UInt32, b:UInt32, c:UInt32] - Filter: outer_table.a > Int32(1) AND __in_sq_1.output [a:UInt32, b:UInt32, c:UInt32, outer_b_alias:UInt32;N] + Filter: outer_table.a > Int32(1) AND __in_sq_1.output [a:UInt32, b:UInt32, c:UInt32, mark:Boolean] LeftMark Join: Filter: outer_ref(outer_table.a) = outer_table.a AND outer_ref(outer_table.b) = outer_table.b [a:UInt32, b:UInt32, c:UInt32, mark;Boolean] TableScan: outer_table [a:UInt32, b:UInt32, c:UInt32] SubqueryAlias: __in_sq_1 [outer_b_alias:UInt32;N] diff --git a/datafusion/sqllogictest/test_files/debug.slt b/datafusion/sqllogictest/test_files/debug.slt deleted file mode 100644 index d56f2a210d64..000000000000 --- a/datafusion/sqllogictest/test_files/debug.slt +++ /dev/null @@ -1,61 +0,0 @@ -statement ok -CREATE TABLE students( - id int, - name varchar, - major varchar, - year timestamp -) -AS VALUES - (1,'A','math','2014-01-01T00:00:00'::timestamp), - (2,'B','math','2015-01-01T00:00:00'::timestamp), - (3,'C','math','2016-01-01T00:00:00'::timestamp) -; - -statement ok -CREATE TABLE exams( - sid int, - curriculum varchar, - grade int, - date timestamp -) -AS VALUES - (1, 'math', 10, '2014-01-01T00:00:00'::timestamp), - (2, 'math', 9, '2015-01-01T00:00:00'::timestamp), - (3, 'math', 4, '2016-01-01T00:00:00'::timestamp) -; - -## Multi-level correlated subquery -##query TT -##explain select * from exams e1 where grade > (select avg(grade) from exams as e2 where e1.sid = e2.sid -##and e2.curriculum=(select max(grade) from exams e3 group by curriculum)) -##---- - -# query TT -#explain select * from exams e1 where grade > (select avg(grade) from exams as e2 where e1.sid = e2.sid -# and e2.sid='some fixed value 1' -# or e2.sid='some fixed value 2' -#) -# ---- - - -## select * from exams e1, ( -## select avg(score) as avg_score, e2.sid, e2.year,e2.subject from exams e2 group by e2.sid,e2.year,e2.subject -## ) as pulled_up where e1.score > pulled_up.avg_score - -query TT -explain select s.name, ( - select count(e2.grade) as c from exams e2 - having c > 10 -) from students s ----- - -## query TT -## explain select s.name, e.curriculum from students s, exams e where s.id=e.sid -## and s.major='math' and 0 < ( -## select count(e2.grade) from exams e2 where s.id=e2.sid and e2.grade>0 -## having count(e2.grade) < 10 -## -- or (s.year1) from t1 ----- -logical_plan -01)Projection: t1.t1_id, __scalar_sq_1.cnt_plus_2 AS cnt_plus_2 -02)--Left Join: t1.t1_int = __scalar_sq_1.t2_int -03)----TableScan: t1 projection=[t1_id, t1_int] -04)----SubqueryAlias: __scalar_sq_1 -05)------Projection: count(Int64(1)) AS count(*) + Int64(2) AS cnt_plus_2, t2.t2_int -06)--------Filter: count(Int64(1)) > Int64(1) -07)----------Aggregate: groupBy=[[t2.t2_int]], aggr=[[count(Int64(1))]] -08)------------TableScan: t2 projection=[t2_int] - - -query TT -explain SELECT t1_id, (SELECT count(*) + 2 as cnt_plus_2 FROM t2 WHERE t2.t2_int = t1.t1_int having count(*) = 0) from t1 ----- -logical_plan -01)Projection: t1.t1_id, CASE WHEN __scalar_sq_1.__always_true IS NULL THEN Int64(2) WHEN 
__scalar_sq_1.count(Int64(1)) != Int64(0) THEN NULL ELSE __scalar_sq_1.cnt_plus_2 END AS cnt_plus_2 -02)--Left Join: t1.t1_int = __scalar_sq_1.t2_int -03)----TableScan: t1 projection=[t1_id, t1_int] -04)----SubqueryAlias: __scalar_sq_1 -05)------Projection: count(Int64(1)) + Int64(2) AS cnt_plus_2, t2.t2_int, count(Int64(1)), Boolean(true) AS __always_true -06)--------Aggregate: groupBy=[[t2.t2_int]], aggr=[[count(Int64(1))]] -07)----------TableScan: t2 projection=[t2_int] - -query TT -explain select t1.t1_int from t1 where (select cnt from (select count(*) as cnt, sum(t2_int) from t2 where t1.t1_int = t2.t2_int)) = 0 ----- -logical_plan -01)Projection: t1.t1_int -02)--Filter: CASE WHEN __scalar_sq_1.__always_true IS NULL THEN Int64(0) ELSE __scalar_sq_1.cnt END = Int64(0) -03)----Projection: t1.t1_int, __scalar_sq_1.cnt, __scalar_sq_1.__always_true -04)------Left Join: t1.t1_int = __scalar_sq_1.t2_int -05)--------TableScan: t1 projection=[t1_int] -06)--------SubqueryAlias: __scalar_sq_1 -07)----------Projection: count(Int64(1)) AS cnt, t2.t2_int, Boolean(true) AS __always_true -08)------------Aggregate: groupBy=[[t2.t2_int]], aggr=[[count(Int64(1))]] -09)--------------TableScan: t2 projection=[t2_int] - diff --git a/datafusion/sqllogictest/test_files/debug_count.slt b/datafusion/sqllogictest/test_files/debug_count.slt deleted file mode 100644 index d52df0afba83..000000000000 --- a/datafusion/sqllogictest/test_files/debug_count.slt +++ /dev/null @@ -1,116 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at - -# http://www.apache.org/licenses/LICENSE-2.0 - -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -# make sure to a batch size smaller than row number of the table. 
-statement ok -set datafusion.execution.batch_size = 2; - -############# -## Subquery Tests -############# - - -############# -## Setup test data table -############# -# there tables for subquery -statement ok -CREATE TABLE t0(t0_id INT, t0_name TEXT, t0_int INT) AS VALUES -(11, 'o', 6), -(22, 'p', 7), -(33, 'q', 8), -(44, 'r', 9); - -statement ok -CREATE TABLE t1(t1_id INT, t1_name TEXT, t1_int INT) AS VALUES -(11, 'a', 1), -(22, 'b', 2), -(33, 'c', 3), -(44, 'd', 4); - -statement ok -CREATE TABLE t2(t2_id INT, t2_name TEXT, t2_int INT) AS VALUES -(11, 'z', 3), -(22, 'y', 1), -(44, 'x', 3), -(55, 'w', 3); - -statement ok -CREATE TABLE t3(t3_id INT PRIMARY KEY, t3_name TEXT, t3_int INT) AS VALUES -(11, 'e', 3), -(22, 'f', 1), -(44, 'g', 3), -(55, 'h', 3); - -statement ok -CREATE EXTERNAL TABLE IF NOT EXISTS customer ( - c_custkey BIGINT, - c_name VARCHAR, - c_address VARCHAR, - c_nationkey BIGINT, - c_phone VARCHAR, - c_acctbal DECIMAL(15, 2), - c_mktsegment VARCHAR, - c_comment VARCHAR, -) STORED AS CSV LOCATION '../core/tests/tpch-csv/customer.csv' OPTIONS ('format.delimiter' ',', 'format.has_header' 'true'); - -statement ok -CREATE EXTERNAL TABLE IF NOT EXISTS orders ( - o_orderkey BIGINT, - o_custkey BIGINT, - o_orderstatus VARCHAR, - o_totalprice DECIMAL(15, 2), - o_orderdate DATE, - o_orderpriority VARCHAR, - o_clerk VARCHAR, - o_shippriority INTEGER, - o_comment VARCHAR, -) STORED AS CSV LOCATION '../core/tests/tpch-csv/orders.csv' OPTIONS ('format.delimiter' ',', 'format.has_header' 'true'); - -statement ok -CREATE EXTERNAL TABLE IF NOT EXISTS lineitem ( - l_orderkey BIGINT, - l_partkey BIGINT, - l_suppkey BIGINT, - l_linenumber INTEGER, - l_quantity DECIMAL(15, 2), - l_extendedprice DECIMAL(15, 2), - l_discount DECIMAL(15, 2), - l_tax DECIMAL(15, 2), - l_returnflag VARCHAR, - l_linestatus VARCHAR, - l_shipdate DATE, - l_commitdate DATE, - l_receiptdate DATE, - l_shipinstruct VARCHAR, - l_shipmode VARCHAR, - l_comment VARCHAR, -) STORED AS CSV LOCATION '../core/tests/tpch-csv/lineitem.csv' OPTIONS ('format.delimiter' ',', 'format.has_header' 'true'); - - -#correlated_scalar_subquery_count_agg -query TT -explain SELECT t1_id, (SELECT count(*) FROM t2 WHERE t2.t2_int = t1.t1_int) from t1 ----- -logical_plan -01)Projection: t1.t1_id, CASE WHEN __scalar_sq_1.__always_true IS NULL THEN Int64(0) ELSE __scalar_sq_1.count(*) END AS count(*) -02)--Left Join: t1.t1_int = __scalar_sq_1.t2_int -03)----TableScan: t1 projection=[t1_id, t1_int] -04)----SubqueryAlias: __scalar_sq_1 -05)------Projection: count(Int64(1)) AS count(*), t2.t2_int, Boolean(true) AS __always_true -06)--------Aggregate: groupBy=[[t2.t2_int]], aggr=[[count(Int64(1))]] -07)----------TableScan: t2 projection=[t2_int] diff --git a/datafusion/sqllogictest/test_files/subquery.slt b/datafusion/sqllogictest/test_files/subquery.slt index 878ba5da7eba..a0ac15b740d7 100644 --- a/datafusion/sqllogictest/test_files/subquery.slt +++ b/datafusion/sqllogictest/test_files/subquery.slt @@ -873,13 +873,6 @@ SELECT t1_id, (SELECT count(*) + 2 as _cnt FROM t2 WHERE t2.t2_int = t1.t1_int) #correlated_scalar_subquery_count_agg_where_clause query TT explain select t1.t1_int from t1 where (select count(*) from t2 where t1.t1_id = t2.t2_id) < t1.t1_int -select t1.t1_int from t1, -( - select count(*) as count_all from t2, ( - select distinct t1_id - ) as domain where t2.t2_id = domain.t1_id -) as pulled_up -where t1.t1_id=pulled_up.t1_id and pulled_up.count_all < t1.t1_int ---- logical_plan 01)Projection: t1.t1_int diff --git 
a/datafusion/sqllogictest/test_files/unsupported.slt b/datafusion/sqllogictest/test_files/unsupported.slt deleted file mode 100644 index b4c581d332e0..000000000000 --- a/datafusion/sqllogictest/test_files/unsupported.slt +++ /dev/null @@ -1,76 +0,0 @@ -statement ok -CREATE TABLE students( - id int, - name varchar, - major varchar, - year timestamp -) -AS VALUES - (1,'A','math','2014-01-01T00:00:00'::timestamp), - (2,'B','math','2015-01-01T00:00:00'::timestamp), - (3,'C','math','2016-01-01T00:00:00'::timestamp) -; - -statement ok -CREATE TABLE exams( - sid int, - curriculum varchar, - grade int, - date timestamp -) -AS VALUES - (1, 'math', 10, '2014-01-01T00:00:00'::timestamp), - (2, 'math', 9, '2015-01-01T00:00:00'::timestamp), - (3, 'math', 4, '2016-01-01T00:00:00'::timestamp) -; - --- explain select s.name, e.curriculum from students s, exams e where s.id=e.sid --- and (s.major='math') and e.grade < ( --- select avg(e2.grade) from exams e2 where s.id=e2.sid or ( --- s.year e2.date and d.major = e2.curriculum - ) group by id,year,major -) as pulled where -s.id=e.sid -and e.grade < pulled.m -and ( - pulled.id=s.id and pulled.year=s.year and pulled.major=s.major -- join with the domain columns -) ----- -manh math 9.5 -bao math 7.666666666667 - -query TT -explain select s.name, e.curriculum from students s, exams e where s.id=e.sid -and (s.major='math') and e.grade < ( - select avg(e2.grade) from exams e2 where s.id=e2.sid or ( - s.year) -10)----------Subquery: -11)------------Projection: avg(e2.grade) -12)--------------Aggregate: groupBy=[[]], aggr=[[avg(CAST(e2.grade AS Float64))]] -13)----------------SubqueryAlias: e2 -14)------------------Filter: outer_ref(s.id) = exams.sid OR outer_ref(s.year) < exams.date AND exams.curriculum = outer_ref(s.major) -15)--------------------TableScan: exams -16)----------TableScan: exams projection=[sid, curriculum, grade] -physical_plan_error This feature is not implemented: Physical plan does not support logical expression ScalarSubquery() \ No newline at end of file From 3533cd18949e9152153eaf48326b8f88067f687f Mon Sep 17 00:00:00 2001 From: Duong Cong Toai Date: Sat, 24 May 2025 15:54:43 +0200 Subject: [PATCH 28/70] chore: simple logical plan type for dependent join --- datafusion/expr/src/logical_plan/display.rs | 1 + datafusion/expr/src/logical_plan/mod.rs | 11 +-- datafusion/expr/src/logical_plan/plan.rs | 77 +++++++++++++++++++ datafusion/expr/src/logical_plan/tree_node.rs | 6 ++ .../optimizer/src/common_subexpr_eliminate.rs | 3 +- .../optimizer/src/decorrelate_general.rs | 37 ++++++--- .../optimizer/src/optimize_projections/mod.rs | 4 +- datafusion/sql/src/unparser/plan.rs | 5 +- 8 files changed, 125 insertions(+), 19 deletions(-) diff --git a/datafusion/expr/src/logical_plan/display.rs b/datafusion/expr/src/logical_plan/display.rs index 14758b61e859..dfcccbb087ff 100644 --- a/datafusion/expr/src/logical_plan/display.rs +++ b/datafusion/expr/src/logical_plan/display.rs @@ -485,6 +485,7 @@ impl<'a, 'b> PgJsonVisitor<'a, 'b> { object } + LogicalPlan::DependentJoin(..) 
=> todo!(), LogicalPlan::Join(Join { on: ref keys, filter, diff --git a/datafusion/expr/src/logical_plan/mod.rs b/datafusion/expr/src/logical_plan/mod.rs index a55f4d97b212..8bd1417b6f06 100644 --- a/datafusion/expr/src/logical_plan/mod.rs +++ b/datafusion/expr/src/logical_plan/mod.rs @@ -37,11 +37,12 @@ pub use ddl::{ }; pub use dml::{DmlStatement, WriteOp}; pub use plan::{ - projection_schema, Aggregate, Analyze, ColumnUnnestList, DescribeTable, Distinct, - DistinctOn, EmptyRelation, Explain, ExplainFormat, Extension, FetchType, Filter, - Join, JoinConstraint, JoinType, Limit, LogicalPlan, Partitioning, PlanType, - Projection, RecursiveQuery, Repartition, SkipType, Sort, StringifiedPlan, Subquery, - SubqueryAlias, TableScan, ToStringifiedPlan, Union, Unnest, Values, Window, + projection_schema, Aggregate, Analyze, ColumnUnnestList, DependentJoin, + DescribeTable, Distinct, DistinctOn, EmptyRelation, Explain, ExplainFormat, + Extension, FetchType, Filter, Join, JoinConstraint, JoinType, Limit, LogicalPlan, + Partitioning, PlanType, Projection, RecursiveQuery, Repartition, SkipType, Sort, + StringifiedPlan, Subquery, SubqueryAlias, TableScan, ToStringifiedPlan, Union, + Unnest, Values, Window, }; pub use statement::{ Deallocate, Execute, Prepare, SetVariable, Statement, TransactionAccessMode, diff --git a/datafusion/expr/src/logical_plan/plan.rs b/datafusion/expr/src/logical_plan/plan.rs index edf5f1126be9..e00ef51aee86 100644 --- a/datafusion/expr/src/logical_plan/plan.rs +++ b/datafusion/expr/src/logical_plan/plan.rs @@ -287,6 +287,63 @@ pub enum LogicalPlan { Unnest(Unnest), /// A variadic query (e.g. "Recursive CTEs") RecursiveQuery(RecursiveQuery), + /// A node type that only exist during subquery decorrelation + /// TODO: maybe we can avoid creating new type of LogicalPlan for this usecase + DependentJoin(DependentJoin), +} + +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub struct DependentJoin { + pub schema: DFSchemaRef, + // all the columns provided by the LHS being referenced + // in the RHS (and its children nested subqueries, if any) (note that not all outer_refs from the RHS are mentioned in this vectors + // because RHS may reference columns provided somewhere from the above join) + pub correlated_columns: Vec, + // the upper expr that containing the subquery expr + // i.e for predicates: where outer = scalar_sq + 1 + // correlated exprs are `scalar_sq + 1` + pub subquery_expr: Expr, + // subquery depth + // begins with depth = 1 + pub depth: usize, + pub left: Arc, + // dependent side accessing columns from left hand side (and maybe columns) + // belong to the parent dependent join node in case of recursion) + pub right: Arc, +} + +impl PartialOrd for DependentJoin { + fn partial_cmp(&self, other: &Self) -> Option { + #[derive(PartialEq, PartialOrd)] + struct ComparableJoin<'a> { + correlated_columns: &'a Vec, + // the upper expr that containing the subquery expr + // i.e for predicates: where outer = scalar_sq + 1 + // correlated exprs are `scalar_sq + 1` + subquery_expr: &'a Expr, + + depth: &'a usize, + left: &'a Arc, + // dependent side accessing columns from left hand side (and maybe columns) + // belong to the parent dependent join node in case of recursion) + right: &'a Arc, + } + let comparable_self = ComparableJoin { + left: &self.left, + right: &self.right, + correlated_columns: &self.correlated_columns, + subquery_expr: &self.subquery_expr, + depth: &self.depth, + }; + let comparable_other = ComparableJoin { + left: &other.left, + right: &other.right, + 
correlated_columns: &other.correlated_columns, + subquery_expr: &other.subquery_expr, + depth: &other.depth, + }; + comparable_self.partial_cmp(&comparable_other) + } } impl Default for LogicalPlan { @@ -318,6 +375,7 @@ impl LogicalPlan { /// Get a reference to the logical plan's schema pub fn schema(&self) -> &DFSchemaRef { match self { + LogicalPlan::DependentJoin(DependentJoin { schema, .. }) => schema, LogicalPlan::EmptyRelation(EmptyRelation { schema, .. }) => schema, LogicalPlan::Values(Values { schema, .. }) => schema, LogicalPlan::TableScan(TableScan { @@ -452,6 +510,9 @@ impl LogicalPlan { LogicalPlan::Aggregate(Aggregate { input, .. }) => vec![input], LogicalPlan::Sort(Sort { input, .. }) => vec![input], LogicalPlan::Join(Join { left, right, .. }) => vec![left, right], + LogicalPlan::DependentJoin(DependentJoin { left, right, .. }) => { + vec![left, right] + } LogicalPlan::Limit(Limit { input, .. }) => vec![input], LogicalPlan::Subquery(Subquery { subquery, .. }) => vec![subquery], LogicalPlan::SubqueryAlias(SubqueryAlias { input, .. }) => vec![input], @@ -540,6 +601,7 @@ impl LogicalPlan { | LogicalPlan::Limit(Limit { input, .. }) | LogicalPlan::Repartition(Repartition { input, .. }) | LogicalPlan::Window(Window { input, .. }) => input.head_output_expr(), + LogicalPlan::DependentJoin(..) => todo!(), LogicalPlan::Join(Join { left, right, @@ -650,6 +712,7 @@ impl LogicalPlan { }) => Aggregate::try_new(input, group_expr, aggr_expr) .map(LogicalPlan::Aggregate), LogicalPlan::Sort(_) => Ok(self), + LogicalPlan::DependentJoin(_) => todo!(), LogicalPlan::Join(Join { left, right, @@ -837,6 +900,7 @@ impl LogicalPlan { Filter::try_new(predicate, Arc::new(input)).map(LogicalPlan::Filter) } + LogicalPlan::DependentJoin(DependentJoin { left, right, .. }) => todo!(), LogicalPlan::Repartition(Repartition { partitioning_scheme, .. @@ -1293,6 +1357,7 @@ impl LogicalPlan { /// If `Some(n)` then the plan can return at most `n` rows but may return fewer. pub fn max_rows(self: &LogicalPlan) -> Option { match self { + LogicalPlan::DependentJoin(DependentJoin { left, .. }) => left.max_rows(), LogicalPlan::Projection(Projection { input, .. }) => input.max_rows(), LogicalPlan::Filter(filter) => { if filter.is_scalar() { @@ -1885,6 +1950,18 @@ impl LogicalPlan { Ok(()) } + + LogicalPlan::DependentJoin(DependentJoin{ + left,right, + subquery_expr, + correlated_columns, + .. + }) => { + let correlated_str = correlated_columns.iter().map(|c|{ + format!("{c}") + }).collect::>().join(", "); + write!(f,"DependentJoin on {} with expr {}",correlated_str,subquery_expr) + }, LogicalPlan::Join(Join { on: ref keys, filter, diff --git a/datafusion/expr/src/logical_plan/tree_node.rs b/datafusion/expr/src/logical_plan/tree_node.rs index 7f6e1e025387..c07fe828b907 100644 --- a/datafusion/expr/src/logical_plan/tree_node.rs +++ b/datafusion/expr/src/logical_plan/tree_node.rs @@ -53,6 +53,8 @@ use datafusion_common::tree_node::{ }; use datafusion_common::{internal_err, Result}; +use super::plan::DependentJoin; + impl TreeNode for LogicalPlan { fn apply_children<'n, F: FnMut(&'n Self) -> Result>( &'n self, @@ -356,6 +358,7 @@ impl TreeNode for LogicalPlan { | LogicalPlan::EmptyRelation { .. } | LogicalPlan::Values { .. } | LogicalPlan::DescribeTable(_) => Transformed::no(self), + LogicalPlan::DependentJoin(..) => todo!(), }) } } @@ -408,6 +411,8 @@ impl LogicalPlan { mut f: F, ) -> Result { match self { + // TODO: apply expr on the subquery + LogicalPlan::DependentJoin(..) 
=> Ok(TreeNodeRecursion::Continue), LogicalPlan::Projection(Projection { expr, .. }) => expr.apply_elements(f), LogicalPlan::Values(Values { values, .. }) => values.apply_elements(f), LogicalPlan::Filter(Filter { predicate, .. }) => f(predicate), @@ -495,6 +500,7 @@ impl LogicalPlan { mut f: F, ) -> Result> { Ok(match self { + LogicalPlan::DependentJoin(DependentJoin { .. }) => todo!(), LogicalPlan::Projection(Projection { expr, input, diff --git a/datafusion/optimizer/src/common_subexpr_eliminate.rs b/datafusion/optimizer/src/common_subexpr_eliminate.rs index 69b5fbb9f8c0..825dc804e1c1 100644 --- a/datafusion/optimizer/src/common_subexpr_eliminate.rs +++ b/datafusion/optimizer/src/common_subexpr_eliminate.rs @@ -564,7 +564,8 @@ impl OptimizerRule for CommonSubexprEliminate { | LogicalPlan::Dml(_) | LogicalPlan::Copy(_) | LogicalPlan::Unnest(_) - | LogicalPlan::RecursiveQuery(_) => { + | LogicalPlan::RecursiveQuery(_) + | LogicalPlan::DependentJoin(_) => { // This rule handles recursion itself in a `ApplyOrder::TopDown` like // manner. plan.map_children(|c| self.rewrite(c, config))? diff --git a/datafusion/optimizer/src/decorrelate_general.rs b/datafusion/optimizer/src/decorrelate_general.rs index ff1ac64527fe..a9cbfc514417 100644 --- a/datafusion/optimizer/src/decorrelate_general.rs +++ b/datafusion/optimizer/src/decorrelate_general.rs @@ -39,8 +39,8 @@ use datafusion_expr::expr_rewriter::{normalize_col, strip_outer_reference}; use datafusion_expr::select_expr::SelectExpr; use datafusion_expr::utils::{conjunction, disjunction, split_conjunction}; use datafusion_expr::{ - binary_expr, col, expr_fn, lit, BinaryExpr, Cast, Expr, Filter, JoinType, - LogicalPlan, LogicalPlanBuilder, Operator as ExprOperator, Subquery, + binary_expr, col, expr_fn, lit, BinaryExpr, Cast, DependentJoin, Expr, Filter, + JoinType, LogicalPlan, LogicalPlanBuilder, Operator as ExprOperator, Subquery, }; use datafusion_expr::{in_list, out_ref_col}; @@ -52,6 +52,7 @@ use log::Log; pub struct DependentJoinRewriter { // each logical plan traversal will assign it a integer id current_id: usize, + subquery_depth: usize, // each newly visted operator is inserted inside this map for tracking nodes: IndexMap, // all the node ids from root to the current node @@ -185,6 +186,7 @@ impl DependentJoinRewriter { nodes: IndexMap::new(), stack: vec![], all_outer_ref_columns: IndexMap::new(), + subquery_depth: 0, }; } } @@ -345,6 +347,9 @@ impl TreeNodeRewriter for DependentJoinRewriter { } }; + if is_dependent_join_node { + self.subquery_depth += 1 + } self.stack.push(self.current_id); self.nodes.insert( self.current_id, @@ -369,6 +374,8 @@ impl TreeNodeRewriter for DependentJoinRewriter { if !node_info.is_dependent_join_node { return Ok(Transformed::no(node)); } + let current_subquery_depth = self.subquery_depth; + self.subquery_depth -= 1; assert!( 1 == node.inputs().len(), "a dependent join node cannot have more than 1 child" @@ -433,17 +440,25 @@ impl TreeNodeRewriter for DependentJoinRewriter { let right = LogicalPlanBuilder::new(subquery_input.clone()) .alias(alias.clone())? 
.build()?; - let on_exprs = column_accesses + let correlated_columns = column_accesses .iter() - .map(|ac| (ac.data_type.clone(), ac.col.clone())) + .map(|ac| (ac.col.clone())) .unique() - .map(|(data_type, column)| { - out_ref_col(data_type.clone(), column.clone()).eq(col(column)) - }); + .collect(); + let left = current_plan.build()?; // TODO: create a new dependent join logical plan - current_plan = - current_plan.join_on(right, JoinType::LeftMark, on_exprs)?; + let dependent_join = DependentJoin { + left: Arc::new(left.clone()), + right: Arc::new(right), + schema: left.schema().clone(), + correlated_columns, + depth: current_subquery_depth, + subquery_expr: lit(true), + }; + current_plan = LogicalPlanBuilder::new(LogicalPlan::DependentJoin( + dependent_join, + )); } current_plan = current_plan .filter(new_predicate.clone())? @@ -756,8 +771,8 @@ mod tests { assert_snapshot!(formatted_plan, @r" Projection: outer_table.a, outer_table.b, outer_table.c [a:UInt32, b:UInt32, c:UInt32] - Filter: outer_table.a > Int32(1) AND __in_sq_1.output [a:UInt32, b:UInt32, c:UInt32, mark:Boolean] - LeftMark Join: Filter: outer_ref(outer_table.a) = outer_table.a AND outer_ref(outer_table.b) = outer_table.b [a:UInt32, b:UInt32, c:UInt32, mark;Boolean] + Filter: outer_table.a > Int32(1) AND __in_sq_1.output [a:UInt32, b:UInt32, c:UInt32] + DependentJoin on outer_table.a, outer_table.b with expr Boolean(true) [a:UInt32, b:UInt32, c:UInt32] TableScan: outer_table [a:UInt32, b:UInt32, c:UInt32] SubqueryAlias: __in_sq_1 [outer_b_alias:UInt32;N] Projection: outer_ref(outer_table.b) AS outer_b_alias [outer_b_alias:UInt32;N] diff --git a/datafusion/optimizer/src/optimize_projections/mod.rs b/datafusion/optimizer/src/optimize_projections/mod.rs index a443c4cc81ef..5a0b5c5ae8f3 100644 --- a/datafusion/optimizer/src/optimize_projections/mod.rs +++ b/datafusion/optimizer/src/optimize_projections/mod.rs @@ -347,7 +347,8 @@ fn optimize_projections( LogicalPlan::EmptyRelation(_) | LogicalPlan::RecursiveQuery(_) | LogicalPlan::Values(_) - | LogicalPlan::DescribeTable(_) => { + | LogicalPlan::DescribeTable(_) + | LogicalPlan::DependentJoin(_) => { // These operators have no inputs, so stop the optimization process. return Ok(Transformed::no(plan)); } @@ -382,6 +383,7 @@ fn optimize_projections( dependency_indices.clone(), )] } + LogicalPlan::DependentJoin(..) => todo!(), }; // Required indices are currently ordered (child0, child1, ...) 
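For context on the f_up rewrite shown above: the filter predicate is walked with Expr::transform and every subquery expression encountered is swapped for a reference to the synthetic "alias.output" column that the dependent join will expose. Subqueries are matched by the order in which they are encountered (the offset feeding subquery_alias_by_offset), which is why the lookup still works even after lower-level rewrites have already replaced the subquery plans. Below is a toy sketch of that substitution using a hand-rolled expression enum rather than DataFusion's Expr; the names, shapes, and the "__scalar_sq_1" alias are illustrative only.

// Toy model of the offset-based replacement: every subquery found while walking
// the predicate is replaced by "<alias>.output", with aliases handed out in
// encounter order, mirroring subquery_alias_by_offset in the patch above.
#[derive(Debug, Clone)]
enum Pred {
    Column(String),
    GtLit(Box<Pred>, i64),
    And(Box<Pred>, Box<Pred>),
    ScalarSubquery(&'static str), // stands in for a nested plan
}

fn replace_subqueries(p: Pred, aliases: &[&str], offset: &mut usize) -> Pred {
    match p {
        Pred::ScalarSubquery(_) => {
            let alias = aliases[*offset];
            *offset += 1;
            Pred::Column(format!("{alias}.output"))
        }
        Pred::GtLit(inner, v) => {
            Pred::GtLit(Box::new(replace_subqueries(*inner, aliases, offset)), v)
        }
        Pred::And(l, r) => Pred::And(
            Box::new(replace_subqueries(*l, aliases, offset)),
            Box::new(replace_subqueries(*r, aliases, offset)),
        ),
        other => other,
    }
}

fn main() {
    // WHERE outer_table.a > 1 AND (SELECT count(*) ...) > 0
    let predicate = Pred::And(
        Box::new(Pred::GtLit(Box::new(Pred::Column("outer_table.a".into())), 1)),
        Box::new(Pred::GtLit(Box::new(Pred::ScalarSubquery("SELECT count(*) ...")), 0)),
    );
    let mut offset = 0;
    // becomes: outer_table.a > 1 AND __scalar_sq_1.output > 0
    let rewritten = replace_subqueries(predicate, &["__scalar_sq_1"], &mut offset);
    println!("{rewritten:?}");
}

The projection added right after the rewritten filter then drops these helper columns again, so the plan keeps the output schema of the original query.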
diff --git a/datafusion/sql/src/unparser/plan.rs b/datafusion/sql/src/unparser/plan.rs index 1401a153b06d..80e9232987a1 100644 --- a/datafusion/sql/src/unparser/plan.rs +++ b/datafusion/sql/src/unparser/plan.rs @@ -124,7 +124,10 @@ impl Unparser<'_> { | LogicalPlan::Copy(_) | LogicalPlan::DescribeTable(_) | LogicalPlan::RecursiveQuery(_) - | LogicalPlan::Unnest(_) => not_impl_err!("Unsupported plan: {plan:?}"), + | LogicalPlan::Unnest(_) + | LogicalPlan::DependentJoin(_) => { + not_impl_err!("Unsupported plan: {plan:?}") + } } } From e1002f8f0ee832c2faf4b7013cb28c7193973180 Mon Sep 17 00:00:00 2001 From: Duong Cong Toai Date: Sat, 24 May 2025 21:04:15 +0200 Subject: [PATCH 29/70] fix: recursive dependent join rewrite --- datafusion/expr/src/logical_plan/builder.rs | 57 ++ datafusion/expr/src/logical_plan/plan.rs | 11 +- .../optimizer/src/decorrelate_general.rs | 626 ++++++++++-------- 3 files changed, 411 insertions(+), 283 deletions(-) diff --git a/datafusion/expr/src/logical_plan/builder.rs b/datafusion/expr/src/logical_plan/builder.rs index d4d45226d354..b4a71ef8d0a4 100644 --- a/datafusion/expr/src/logical_plan/builder.rs +++ b/datafusion/expr/src/logical_plan/builder.rs @@ -49,6 +49,7 @@ use crate::{ use super::dml::InsertOp; use super::plan::{ColumnUnnestList, ExplainFormat}; +use super::DependentJoin; use arrow::compute::can_cast_types; use arrow::datatypes::{DataType, Field, Fields, Schema, SchemaRef}; use datafusion_common::display::ToStringifiedPlan; @@ -880,6 +881,41 @@ impl LogicalPlanBuilder { )))) } + /// + pub fn dependent_join( + self, + right: LogicalPlan, + correlated_columns: Vec, + subquery_expr: Expr, + subquery_depth: usize, + subquery_name: String, + ) -> Result { + let left = self.build()?; + let mut schema = left.schema(); + let qualified_fields = schema + .iter() + .map(|(q, f)| (q.cloned(), Arc::clone(f))) + .chain(once(subquery_output_field( + &subquery_name, + right.schema(), + &subquery_expr, + ))) + .collect(); + let func_dependencies = schema.functional_dependencies(); + let metadata = schema.metadata().clone(); + let dfschema = DFSchema::new_with_metadata(qualified_fields, metadata)?; + + Ok(Self::new(LogicalPlan::DependentJoin(DependentJoin { + schema: DFSchemaRef::new(dfschema), + left: Arc::new(left), + right: Arc::new(right), + correlated_columns, + subquery_expr, + subquery_name, + subquery_depth, + }))) + } + /// Apply a join to `right` using explicitly specified columns and an /// optional filter expression. /// @@ -1544,6 +1580,27 @@ fn mark_field(schema: &DFSchema) -> (Option, Arc) { ) } +fn subquery_output_field( + subquery_alias: &String, + right_schema: &DFSchema, + subquery_expr: &Expr, +) -> (Option, Arc) { + // TODO: check nullability + let field = match subquery_expr { + Expr::InSubquery(_) => Arc::new(Field::new("output", DataType::Boolean, false)), + Expr::Exists(_) => Arc::new(Field::new("output", DataType::Boolean, false)), + Expr::ScalarSubquery(sq) => { + let data_type = sq.subquery.schema().field(0).data_type().clone(); + Arc::new(Field::new("output", data_type, false)) + } + _ => { + unreachable!() + } + }; + + (Some(TableReference::bare(subquery_alias.clone())), field) +} + /// Creates a schema for a join operation. 
/// The fields from the left side are first pub fn build_join_schema( diff --git a/datafusion/expr/src/logical_plan/plan.rs b/datafusion/expr/src/logical_plan/plan.rs index e00ef51aee86..627a548d50b1 100644 --- a/datafusion/expr/src/logical_plan/plan.rs +++ b/datafusion/expr/src/logical_plan/plan.rs @@ -303,13 +303,13 @@ pub struct DependentJoin { // i.e for predicates: where outer = scalar_sq + 1 // correlated exprs are `scalar_sq + 1` pub subquery_expr: Expr, - // subquery depth // begins with depth = 1 - pub depth: usize, + pub subquery_depth: usize, pub left: Arc, // dependent side accessing columns from left hand side (and maybe columns) // belong to the parent dependent join node in case of recursion) pub right: Arc, + pub subquery_name: String, } impl PartialOrd for DependentJoin { @@ -333,14 +333,14 @@ impl PartialOrd for DependentJoin { right: &self.right, correlated_columns: &self.correlated_columns, subquery_expr: &self.subquery_expr, - depth: &self.depth, + depth: &self.subquery_depth, }; let comparable_other = ComparableJoin { left: &other.left, right: &other.right, correlated_columns: &other.correlated_columns, subquery_expr: &other.subquery_expr, - depth: &other.depth, + depth: &other.subquery_depth, }; comparable_self.partial_cmp(&comparable_other) } @@ -1955,12 +1955,13 @@ impl LogicalPlan { left,right, subquery_expr, correlated_columns, + subquery_depth, .. }) => { let correlated_str = correlated_columns.iter().map(|c|{ format!("{c}") }).collect::>().join(", "); - write!(f,"DependentJoin on {} with expr {}",correlated_str,subquery_expr) + write!(f,"DependentJoin on [{}] with expr {} depth {}",correlated_str,subquery_expr,subquery_depth) }, LogicalPlan::Join(Join { on: ref keys, diff --git a/datafusion/optimizer/src/decorrelate_general.rs b/datafusion/optimizer/src/decorrelate_general.rs index a9cbfc514417..536190243b61 100644 --- a/datafusion/optimizer/src/decorrelate_general.rs +++ b/datafusion/optimizer/src/decorrelate_general.rs @@ -73,15 +73,6 @@ struct ColumnAccess { data_type: DataType, } -fn unwrap_subquery(n: &Node) -> &Subquery { - match n.plan { - LogicalPlan::Subquery(ref sq) => sq, - _ => { - unreachable!() - } - } -} - impl DependentJoinRewriter { // lowest common ancestor from stack // given a tree of @@ -125,16 +116,16 @@ impl DependentJoinRewriter { // because the column providers are visited after column-accessor // (function visit_with_subqueries always visit the subquery before visiting the other children) // we can always infer the LCA inside this function, by getting the deepest common parent - fn conclude_lowest_dependent_join_node( - &mut self, - child_id: usize, - col: &Column, - tbl_name: &str, - ) { + fn conclude_lowest_dependent_join_node(&mut self, child_id: usize, col: &Column) { if let Some(accesses) = self.all_outer_ref_columns.get(col) { for access in accesses.iter() { let mut cur_stack = self.stack.clone(); + cur_stack.push(child_id); + if col.name() == "outer_table.a" || col.name == "a" { + println!("{:?}", access); + println!("{:?}", cur_stack); + } // this is a dependent join node let (dependent_join_node_id, subquery_node_id) = Self::dependent_join_and_subquery_node_ids(&cur_stack, &access.stack); @@ -204,7 +195,8 @@ struct Node { // This field is only meaningful if the node is dependent join node // it track which descendent nodes still accessing the outer columns provided by its // left child - // the insertion order is top down + // the key of this map is node_id of the children subquery + // and insertion order matters here, and 
thus we use IndexMap access_tracker: IndexMap>, is_dependent_join_node: bool, @@ -215,36 +207,55 @@ struct Node { // which is at the last element subquery_type: SubqueryType, } -#[derive(Debug, Clone, Copy)] +#[derive(Debug, Clone)] enum SubqueryType { None, - In, - Exists, - Scalar, + In(Expr), + Exists(Expr), + Scalar(Expr), } + impl SubqueryType { + fn unwrap_expr(&self) -> Expr { + match self { + SubqueryType::None => { + panic!("not reached") + } + SubqueryType::In(e) | SubqueryType::Exists(e) | SubqueryType::Scalar(e) => { + e.clone() + } + } + } fn default_join_type(&self) -> JoinType { match self { SubqueryType::None => { panic!("not reached") } - SubqueryType::In => JoinType::LeftSemi, - SubqueryType::Exists => JoinType::LeftMark, + SubqueryType::In(_) => JoinType::LeftSemi, + SubqueryType::Exists(_) => JoinType::LeftMark, // TODO: in duckdb, they have JoinType::Single // where there is only at most one join partner entry on the LEFT - SubqueryType::Scalar => JoinType::Left, + SubqueryType::Scalar(_) => JoinType::Left, } } fn prefix(&self) -> String { match self { SubqueryType::None => "", - SubqueryType::In => "__in_sq", - SubqueryType::Exists => "__exists_sq", - SubqueryType::Scalar => "__scalar_sq", + SubqueryType::In(_) => "__in_sq", + SubqueryType::Exists(_) => "__exists_sq", + SubqueryType::Scalar(_) => "__scalar_sq", } .to_string() } } +fn unwrap_subquery_input_from_expr(expr: &Expr) -> Arc { + match expr { + Expr::ScalarSubquery(sq) => sq.subquery.clone(), + Expr::Exists(exists) => exists.subquery.subquery.clone(), + Expr::InSubquery(in_sq) => in_sq.subquery.subquery.clone(), + _ => unreachable!(), + } +} fn contains_subquery(expr: &Expr) -> bool { expr.exists(|expr| { @@ -286,11 +297,7 @@ impl TreeNodeRewriter for DependentJoinRewriter { } LogicalPlan::TableScan(tbl_scan) => { tbl_scan.projected_schema.columns().iter().for_each(|col| { - self.conclude_lowest_dependent_join_node( - self.current_id, - &col, - tbl_scan.table_name.table(), - ); + self.conclude_lowest_dependent_join_node(self.current_id, &col); }); } // TODO @@ -318,18 +325,31 @@ impl TreeNodeRewriter for DependentJoinRewriter { LogicalPlan::Subquery(subquery) => { is_subquery_node = true; let parent = self.stack.last().unwrap(); - let parent_node = self.nodes.get(parent).unwrap(); + let parent_node = self.nodes.get_mut(parent).unwrap(); + parent_node.access_tracker.insert(self.current_id, vec![]); for expr in parent_node.plan.expressions() { expr.exists(|e| { let (found_sq, checking_type) = match e { Expr::ScalarSubquery(sq) => { - (sq == subquery, SubqueryType::Scalar) + if sq == subquery { + (true, SubqueryType::Scalar(e.clone())) + } else { + (false, SubqueryType::None) + } } - Expr::Exists(Exists { subquery: sq, .. }) => { - (sq == subquery, SubqueryType::Exists) + Expr::Exists(exist) => { + if &exist.subquery == subquery { + (true, SubqueryType::Exists(e.clone())) + } else { + (false, SubqueryType::None) + } } - Expr::InSubquery(InSubquery { subquery: sq, .. 
}) => { - (sq == subquery, SubqueryType::In) + Expr::InSubquery(in_sq) => { + if &in_sq.subquery == subquery { + (true, SubqueryType::In(e.clone())) + } else { + (false, SubqueryType::None) + } } _ => (false, SubqueryType::None), }; @@ -383,48 +403,59 @@ impl TreeNodeRewriter for DependentJoinRewriter { let cloned_input = (**node.inputs().first().unwrap()).clone(); let mut current_plan = LogicalPlanBuilder::new(cloned_input); - let mut subquery_alias_map = HashMap::new(); - let mut subquery_alias_by_node_id = HashMap::new(); - for (subquery_id, column_accesses) in node_info.access_tracker.iter() { + let mut subquery_alias_by_offset = HashMap::new(); + // let mut subquery_alias_by_node_id = HashMap::new(); + let mut subquery_expr_by_offset = HashMap::new(); + for (subquery_offset, (subquery_id, column_accesses)) in + node_info.access_tracker.iter().enumerate() + { let subquery_node = self.nodes.get(subquery_id).unwrap(); - let subquery_input = subquery_node.plan.inputs().first().unwrap(); + // let subquery_input = subquery_node.plan.inputs().first().unwrap(); let alias = self .alias_generator .next(&subquery_node.subquery_type.prefix()); - subquery_alias_by_node_id.insert(subquery_id, alias.clone()); - subquery_alias_map.insert(unwrap_subquery(subquery_node), alias); + subquery_alias_by_offset.insert(subquery_offset, alias); } match &node { LogicalPlan::Filter(filter) => { + // everytime we meet a subquery during traversal, we increment this by 1 + // we can use this offset to lookup the original subquery info + // in subquery_alias_by_offset + // the reason why we cannot create a hashmap keyed by Subquery object + // is that the subquery inside this filter expr may have been rewritten in + // the lower level + let mut offset = 0; + let offset_ref = &mut offset; let new_predicate = filter .predicate .clone() .transform(|e| { // replace any subquery expr with subquery_alias.output // column - match e { - Expr::InSubquery(isq) => { - let alias = - subquery_alias_map.get(&isq.subquery).unwrap(); - // TODO: this assume that after decorrelation - // the dependent join will provide an extra column with the structure - // of "subquery_alias.output" - Ok(Transformed::yes(col(format!("{}.output", alias)))) + let alias = match e { + Expr::InSubquery(_) | Expr::Exists(_) => { + subquery_alias_by_offset.get(offset_ref).unwrap() } - Expr::Exists(esq) => { - let alias = - subquery_alias_map.get(&esq.subquery).unwrap(); - Ok(Transformed::yes(col(format!("{}.output", alias)))) + Expr::ScalarSubquery(ref s) => { + println!("inserting new expr {}", s.subquery); + subquery_alias_by_offset.get(offset_ref).unwrap() } - Expr::ScalarSubquery(sq) => { - let alias = subquery_alias_map.get(&sq).unwrap(); - Ok(Transformed::yes(col(format!("{}.output", alias)))) - } - _ => Ok(Transformed::no(e)), - } + _ => return Ok(Transformed::no(e)), + }; + // we are aware that the original subquery can be rewritten + // update the latest expr to this map + subquery_expr_by_offset.insert(*offset_ref, e); + *offset_ref += 1; + // TODO: this assume that after decorrelation + // the dependent join will provide an extra column with the structure + // of "subquery_alias.output" + Ok(Transformed::yes(col(format!("{}.output", alias)))) })? 
.data; + // because dependent join may introduce extra columns + // to evaluate the subquery, the final plan should + // has another projection to remove these redundant columns let post_join_projections: Vec = filter .input .schema() @@ -432,33 +463,28 @@ impl TreeNodeRewriter for DependentJoinRewriter { .iter() .map(|c| col(c.clone())) .collect(); - for (subquery_id, column_accesses) in node_info.access_tracker.iter() { - let alias = subquery_alias_by_node_id.get(subquery_id).unwrap(); - let subquery_node = self.nodes.get(subquery_id).unwrap(); - let subquery_input = - subquery_node.plan.inputs().first().unwrap().clone(); - let right = LogicalPlanBuilder::new(subquery_input.clone()) - .alias(alias.clone())? - .build()?; + for (subquery_offset, (_, column_accesses)) in + node_info.access_tracker.iter().enumerate() + { + let alias = subquery_alias_by_offset.get(&subquery_offset).unwrap(); + let subquery_expr = + subquery_expr_by_offset.get(&subquery_offset).unwrap(); + + let subquery_input = unwrap_subquery_input_from_expr(subquery_expr); + let correlated_columns = column_accesses .iter() .map(|ac| (ac.col.clone())) .unique() .collect(); - let left = current_plan.build()?; - // TODO: create a new dependent join logical plan - let dependent_join = DependentJoin { - left: Arc::new(left.clone()), - right: Arc::new(right), - schema: left.schema().clone(), + current_plan = current_plan.dependent_join( + subquery_input.deref().clone(), correlated_columns, - depth: current_subquery_depth, - subquery_expr: lit(true), - }; - current_plan = LogicalPlanBuilder::new(LogicalPlan::DependentJoin( - dependent_join, - )); + subquery_expr.clone(), + current_subquery_depth, + alias.clone(), + )?; } current_plan = current_plan .filter(new_predicate.clone())? @@ -525,215 +551,266 @@ mod tests { use super::DependentJoinRewriter; use arrow::datatypes::DataType as ArrowDataType; + + macro_rules! assert_dependent_join_rewrite { + ( + $plan:expr, + @ $expected:literal $(,)? + ) => {{ + let mut index = DependentJoinRewriter::new(Arc::new(AliasGenerator::new())); + let transformed = index.rewrite_subqueries_into_dependent_joins($plan)?; + assert!(transformed.transformed); + let display = transformed.data.display_indent_schema(); + assert_snapshot!( + display, + @ $expected, + ) + }}; + } #[test] fn simple_in_subquery_inside_from_expr() -> Result<()> { - unimplemented!() + Ok(()) } #[test] fn simple_in_subquery_inside_select_expr() -> Result<()> { - unimplemented!() + Ok(()) } #[test] - fn one_simple_and_one_complex_subqueries_at_the_same_level() -> Result<()> { - unimplemented!() + fn rewrite_dependent_join_two_nested_subqueries() -> Result<()> { + let outer_table = test_table_scan_with_name("outer_table")?; + let inner_table_lv1 = test_table_scan_with_name("inner_table_lv1")?; + + let inner_table_lv2 = test_table_scan_with_name("inner_table_lv2")?; + let scalar_sq_level2 = + Arc::new( + LogicalPlanBuilder::from(inner_table_lv2) + .filter( + col("inner_table_lv2.a") + .eq(out_ref_col(ArrowDataType::UInt32, "outer_table.a")) + .and(col("inner_table_lv2.b").eq(out_ref_col( + ArrowDataType::UInt32, + "inner_table_lv1.b", + ))), + )? + .aggregate(Vec::::new(), vec![count(col("inner_table_lv2.a"))])? + .build()?, + ); + let scalar_sq_level1 = Arc::new( + LogicalPlanBuilder::from(inner_table_lv1.clone()) + .filter( + col("inner_table_lv1.c") + .eq(out_ref_col(ArrowDataType::UInt32, "outer_table.c")) + .and(scalar_subquery(scalar_sq_level2).eq(lit(1))), + )? + .aggregate(Vec::::new(), vec![count(col("inner_table_lv1.a"))])? 
+ .build()?, + ); + + let input1 = LogicalPlanBuilder::from(outer_table.clone()) + .filter( + col("outer_table.a") + .gt(lit(1)) + .and(scalar_subquery(scalar_sq_level1).eq(col("outer_table.a"))), + )? + .build()?; + assert_dependent_join_rewrite!(input1,@r" +Projection: outer_table.a, outer_table.b, outer_table.c [a:UInt32, b:UInt32, c:UInt32] + Filter: outer_table.a > Int32(1) AND __scalar_sq_2.output = outer_table.a [a:UInt32, b:UInt32, c:UInt32, output:Int64] + DependentJoin on [outer_table.a, outer_table.c] with expr () depth 1 [a:UInt32, b:UInt32, c:UInt32, output:Int64] + TableScan: outer_table [a:UInt32, b:UInt32, c:UInt32] + Aggregate: groupBy=[[]], aggr=[[count(inner_table_lv1.a)]] [count(inner_table_lv1.a):Int64] + Projection: inner_table_lv1.a, inner_table_lv1.b, inner_table_lv1.c [a:UInt32, b:UInt32, c:UInt32] + Filter: inner_table_lv1.c = outer_ref(outer_table.c) AND __scalar_sq_1.output = Int32(1) [a:UInt32, b:UInt32, c:UInt32, output:Int64] + DependentJoin on [inner_table_lv1.b] with expr () depth 2 [a:UInt32, b:UInt32, c:UInt32, output:Int64] + TableScan: inner_table_lv1 [a:UInt32, b:UInt32, c:UInt32] + Aggregate: groupBy=[[]], aggr=[[count(inner_table_lv2.a)]] [count(inner_table_lv2.a):Int64] + Filter: inner_table_lv2.a = outer_ref(outer_table.a) AND inner_table_lv2.b = outer_ref(inner_table_lv1.b) [a:UInt32, b:UInt32, c:UInt32] + TableScan: inner_table_lv2 [a:UInt32, b:UInt32, c:UInt32] + "); + Ok(()) } #[test] - fn two_simple_subqueries_at_the_same_level() -> Result<()> { - // let outer_table = test_table_scan_with_name("outer_table")?; - // let inner_table_lv1 = test_table_scan_with_name("inner_table_lv1")?; - // let in_sq_level1 = Arc::new( - // LogicalPlanBuilder::from(inner_table_lv1.clone()) - // .filter(col("inner_table_lv1.c").eq(lit(2)))? - // .project(vec![col("inner_table_lv1.a")])? - // .build()?, - // ); - // let exist_sq_level1 = Arc::new( - // LogicalPlanBuilder::from(inner_table_lv1) - // .filter( - // col("inner_table_lv1.a").and(col("inner_table_lv1.b").eq(lit(1))), - // )? - // .build()?, - // ); - - // let input1 = LogicalPlanBuilder::from(outer_table.clone()) - // .filter( - // col("outer_table.a") - // .gt(lit(1)) - // .and(exists(exist_sq_level1)) - // .and(in_subquery(col("outer_table.b"), in_sq_level1)), - // )? - // .build()?; - // let mut index = DependentJoinTracker::new(Arc::new(AliasGenerator::new())); - // index.rewrite_subqueries_into_dependent_joins(input1)?; - // println!("{:?}", index); - // let new_plan = index.root_dependent_join_elimination()?; - // println!("{}", new_plan); - // let expected = "\ - // LeftSemi Join: Filter: outer_table.b = __in_sq_2.a\ - // \n Filter: __exists_sq_1.mark\ - // \n LeftMark Join: Filter: Boolean(true)\ - // \n Filter: outer_table.a > Int32(1)\ - // \n TableScan: outer_table\ - // \n SubqueryAlias: __exists_sq_1\ - // \n Filter: inner_table_lv1.a AND inner_table_lv1.b = Int32(1)\ - // \n TableScan: inner_table_lv1\ - // \n SubqueryAlias: __in_sq_2\ - // \n Projection: inner_table_lv1.a\ - // \n Filter: inner_table_lv1.c = Int32(2)\ - // \n TableScan: inner_table_lv1"; - // assert_eq!(expected, format!("{new_plan}")); + fn rewrite_dependent_join_two_subqueries_at_the_same_level() -> Result<()> { + let outer_table = test_table_scan_with_name("outer_table")?; + let inner_table_lv1 = test_table_scan_with_name("inner_table_lv1")?; + let in_sq_level1 = Arc::new( + LogicalPlanBuilder::from(inner_table_lv1.clone()) + .filter(col("inner_table_lv1.c").eq(lit(2)))? + .project(vec![col("inner_table_lv1.a")])? 
+ .build()?, + ); + let exist_sq_level1 = Arc::new( + LogicalPlanBuilder::from(inner_table_lv1) + .filter( + col("inner_table_lv1.a").and(col("inner_table_lv1.b").eq(lit(1))), + )? + .build()?, + ); + + let input1 = LogicalPlanBuilder::from(outer_table.clone()) + .filter( + col("outer_table.a") + .gt(lit(1)) + .and(exists(exist_sq_level1)) + .and(in_subquery(col("outer_table.b"), in_sq_level1)), + )? + .build()?; + assert_dependent_join_rewrite!(input1,@r" +Projection: outer_table.a, outer_table.b, outer_table.c [a:UInt32, b:UInt32, c:UInt32] + Filter: outer_table.a > Int32(1) AND __exists_sq_1.output AND __in_sq_2.output [a:UInt32, b:UInt32, c:UInt32, output:Boolean, output:Boolean] + DependentJoin on [] with expr outer_table.b IN () depth 1 [a:UInt32, b:UInt32, c:UInt32, output:Boolean, output:Boolean] + DependentJoin on [] with expr EXISTS () depth 1 [a:UInt32, b:UInt32, c:UInt32, output:Boolean] + TableScan: outer_table [a:UInt32, b:UInt32, c:UInt32] + Filter: inner_table_lv1.a AND inner_table_lv1.b = Int32(1) [a:UInt32, b:UInt32, c:UInt32] + TableScan: inner_table_lv1 [a:UInt32, b:UInt32, c:UInt32] + Projection: inner_table_lv1.a [a:UInt32] + Filter: inner_table_lv1.c = Int32(2) [a:UInt32, b:UInt32, c:UInt32] + TableScan: inner_table_lv1 [a:UInt32, b:UInt32, c:UInt32] + "); Ok(()) } #[test] - fn in_subquery_with_count_depth_1() -> Result<()> { - // let outer_table = test_table_scan_with_name("outer_table")?; - // let inner_table_lv1 = test_table_scan_with_name("inner_table_lv1")?; - // let sq_level1 = Arc::new( - // LogicalPlanBuilder::from(inner_table_lv1) - // .filter( - // col("inner_table_lv1.a") - // .eq(out_ref_col(ArrowDataType::UInt32, "outer_table.a")) - // .and( - // out_ref_col(ArrowDataType::UInt32, "outer_table.a") - // .gt(col("inner_table_lv1.c")), - // ) - // .and(col("inner_table_lv1.b").eq(lit(1))) - // .and( - // out_ref_col(ArrowDataType::UInt32, "outer_table.b") - // .eq(col("inner_table_lv1.b")), - // ), - // )? - // .aggregate(Vec::::new(), vec![count(col("inner_table_lv1.a"))])? - // .project(vec![count(col("inner_table_lv1.a")).alias("count_a")])? - // .build()?, - // ); - - // let input1 = LogicalPlanBuilder::from(outer_table.clone()) - // .filter( - // col("outer_table.a") - // .gt(lit(1)) - // .and(in_subquery(col("outer_table.c"), sq_level1)), - // )? 
- // .build()?; - // let mut index = DependentJoinTracker::new(Arc::new(AliasGenerator::new())); - // index.rewrite_subqueries_into_dependent_joins(input1)?; - // let new_plan = index.root_dependent_join_elimination()?; - // let expected = "\ - // Filter: outer_table.a > Int32(1)\ - // \n LeftSemi Join: Filter: outer_table.c = count_a\ - // \n TableScan: outer_table\ - // \n Projection: count(inner_table_lv1.a) AS count_a, inner_table_lv1.a, inner_table_lv1.c, inner_table_lv1.b\ - // \n Aggregate: groupBy=[[]], aggr=[[count(inner_table_lv1.a)]]\ - // \n Filter: inner_table_lv1.a = outer_ref(outer_table.a) AND outer_ref(outer_table.a) > inner_table_lv1.c AND inner_table_lv1.b = Int32(1) AND outer_ref(outer_table.b) = inner_table_lv1.b\ - // \n TableScan: inner_table_lv1"; - // assert_eq!(expected, format!("{new_plan}")); + fn rewrite_dependent_join_in_subquery_with_count_depth_1() -> Result<()> { + let outer_table = test_table_scan_with_name("outer_table")?; + let inner_table_lv1 = test_table_scan_with_name("inner_table_lv1")?; + let sq_level1 = Arc::new( + LogicalPlanBuilder::from(inner_table_lv1) + .filter( + col("inner_table_lv1.a") + .eq(out_ref_col(ArrowDataType::UInt32, "outer_table.a")) + .and( + out_ref_col(ArrowDataType::UInt32, "outer_table.a") + .gt(col("inner_table_lv1.c")), + ) + .and(col("inner_table_lv1.b").eq(lit(1))) + .and( + out_ref_col(ArrowDataType::UInt32, "outer_table.b") + .eq(col("inner_table_lv1.b")), + ), + )? + .aggregate(Vec::::new(), vec![count(col("inner_table_lv1.a"))])? + .project(vec![count(col("inner_table_lv1.a")).alias("count_a")])? + .build()?, + ); + + let input1 = LogicalPlanBuilder::from(outer_table.clone()) + .filter( + col("outer_table.a") + .gt(lit(1)) + .and(in_subquery(col("outer_table.c"), sq_level1)), + )? + .build()?; + assert_dependent_join_rewrite!(input1,@r" +Projection: outer_table.a, outer_table.b, outer_table.c [a:UInt32, b:UInt32, c:UInt32] + Filter: outer_table.a > Int32(1) AND __in_sq_1.output [a:UInt32, b:UInt32, c:UInt32, output:Boolean] + DependentJoin on [outer_table.a, outer_table.b] with expr outer_table.c IN () depth 1 [a:UInt32, b:UInt32, c:UInt32, output:Boolean] + TableScan: outer_table [a:UInt32, b:UInt32, c:UInt32] + Projection: count(inner_table_lv1.a) AS count_a [count_a:Int64] + Aggregate: groupBy=[[]], aggr=[[count(inner_table_lv1.a)]] [count(inner_table_lv1.a):Int64] + Filter: inner_table_lv1.a = outer_ref(outer_table.a) AND outer_ref(outer_table.a) > inner_table_lv1.c AND inner_table_lv1.b = Int32(1) AND outer_ref(outer_table.b) = inner_table_lv1.b [a:UInt32, b:UInt32, c:UInt32] + TableScan: inner_table_lv1 [a:UInt32, b:UInt32, c:UInt32] + "); Ok(()) } #[test] - fn simple_exist_subquery_with_dependent_columns() -> Result<()> { - // let outer_table = test_table_scan_with_name("outer_table")?; - // let inner_table_lv1 = test_table_scan_with_name("inner_table_lv1")?; - // let sq_level1 = Arc::new( - // LogicalPlanBuilder::from(inner_table_lv1) - // .filter( - // col("inner_table_lv1.a") - // .eq(out_ref_col(ArrowDataType::UInt32, "outer_table.a")) - // .and( - // out_ref_col(ArrowDataType::UInt32, "outer_table.a") - // .gt(col("inner_table_lv1.c")), - // ) - // .and(col("inner_table_lv1.b").eq(lit(1))) - // .and( - // out_ref_col(ArrowDataType::UInt32, "outer_table.b") - // .eq(col("inner_table_lv1.b")), - // ), - // )? - // .project(vec![out_ref_col(ArrowDataType::UInt32, "outer_table.b") - // .alias("outer_b_alias")])? 
- // .build()?, - // ); - - // let input1 = LogicalPlanBuilder::from(outer_table.clone()) - // .filter(col("outer_table.a").gt(lit(1)).and(exists(sq_level1)))? - // .build()?; - // let mut index = DependentJoinTracker::new(Arc::new(AliasGenerator::new())); - // index.rewrite_subqueries_into_dependent_joins(input1)?; - // let new_plan = index.root_dependent_join_elimination()?; - // let expected = "\ - // Filter: __exists_sq_1.mark\ - // \n LeftMark Join: Filter: __exists_sq_1.a = outer_table.a AND outer_table.a > __exists_sq_1.c AND outer_table.b = __exists_sq_1.b\ - // \n Filter: outer_table.a > Int32(1)\ - // \n TableScan: outer_table\ - // \n SubqueryAlias: __exists_sq_1\ - // \n Filter: inner_table_lv1.b = Int32(1)\ - // \n TableScan: inner_table_lv1"; - // assert_eq!(expected, format!("{new_plan}")); + fn rewrite_dependent_join_exist_subquery_with_dependent_columns() -> Result<()> { + let outer_table = test_table_scan_with_name("outer_table")?; + let inner_table_lv1 = test_table_scan_with_name("inner_table_lv1")?; + let sq_level1 = Arc::new( + LogicalPlanBuilder::from(inner_table_lv1) + .filter( + col("inner_table_lv1.a") + .eq(out_ref_col(ArrowDataType::UInt32, "outer_table.a")) + .and( + out_ref_col(ArrowDataType::UInt32, "outer_table.a") + .gt(col("inner_table_lv1.c")), + ) + .and(col("inner_table_lv1.b").eq(lit(1))) + .and( + out_ref_col(ArrowDataType::UInt32, "outer_table.b") + .eq(col("inner_table_lv1.b")), + ), + )? + .project(vec![out_ref_col(ArrowDataType::UInt32, "outer_table.b") + .alias("outer_b_alias")])? + .build()?, + ); + + let input1 = LogicalPlanBuilder::from(outer_table.clone()) + .filter(col("outer_table.a").gt(lit(1)).and(exists(sq_level1)))? + .build()?; + assert_dependent_join_rewrite!(input1,@r" +Projection: outer_table.a, outer_table.b, outer_table.c [a:UInt32, b:UInt32, c:UInt32] + Filter: outer_table.a > Int32(1) AND __exists_sq_1.output [a:UInt32, b:UInt32, c:UInt32, output:Boolean] + DependentJoin on [outer_table.a, outer_table.b] with expr EXISTS () depth 1 [a:UInt32, b:UInt32, c:UInt32, output:Boolean] + TableScan: outer_table [a:UInt32, b:UInt32, c:UInt32] + Projection: outer_ref(outer_table.b) AS outer_b_alias [outer_b_alias:UInt32;N] + Filter: inner_table_lv1.a = outer_ref(outer_table.a) AND outer_ref(outer_table.a) > inner_table_lv1.c AND inner_table_lv1.b = Int32(1) AND outer_ref(outer_table.b) = inner_table_lv1.b [a:UInt32, b:UInt32, c:UInt32] + TableScan: inner_table_lv1 [a:UInt32, b:UInt32, c:UInt32] + "); Ok(()) } + #[test] - fn simple_exist_subquery_with_no_dependent_columns() -> Result<()> { - // let outer_table = test_table_scan_with_name("outer_table")?; - // let inner_table_lv1 = test_table_scan_with_name("inner_table_lv1")?; - // let sq_level1 = Arc::new( - // LogicalPlanBuilder::from(inner_table_lv1) - // .filter(col("inner_table_lv1.b").eq(lit(1)))? - // .project(vec![col("inner_table_lv1.b"), col("inner_table_lv1.a")])? - // .build()?, - // ); - - // let input1 = LogicalPlanBuilder::from(outer_table.clone()) - // .filter(col("outer_table.a").gt(lit(1)).and(exists(sq_level1)))? 
- // .build()?; - // let mut index = DependentJoinTracker::new(Arc::new(AliasGenerator::new())); - // index.rewrite_subqueries_into_dependent_joins(input1)?; - // let new_plan = index.root_dependent_join_elimination()?; - // let expected = "\ - // Filter: __exists_sq_1.mark\ - // \n LeftMark Join: Filter: Boolean(true)\ - // \n Filter: outer_table.a > Int32(1)\ - // \n TableScan: outer_table\ - // \n SubqueryAlias: __exists_sq_1\ - // \n Projection: inner_table_lv1.b, inner_table_lv1.a\ - // \n Filter: inner_table_lv1.b = Int32(1)\ - // \n TableScan: inner_table_lv1"; - // assert_eq!(expected, format!("{new_plan}")); + fn rewrite_exist_subquery_with_no_dependent_columns() -> Result<()> { + let outer_table = test_table_scan_with_name("outer_table")?; + let inner_table_lv1 = test_table_scan_with_name("inner_table_lv1")?; + let sq_level1 = Arc::new( + LogicalPlanBuilder::from(inner_table_lv1) + .filter(col("inner_table_lv1.b").eq(lit(1)))? + .project(vec![col("inner_table_lv1.b"), col("inner_table_lv1.a")])? + .build()?, + ); + + let input1 = LogicalPlanBuilder::from(outer_table.clone()) + .filter(col("outer_table.a").gt(lit(1)).and(exists(sq_level1)))? + .build()?; + + assert_dependent_join_rewrite!(input1,@r" +Projection: outer_table.a, outer_table.b, outer_table.c [a:UInt32, b:UInt32, c:UInt32] + Filter: outer_table.a > Int32(1) AND __exists_sq_1.output [a:UInt32, b:UInt32, c:UInt32, output:Boolean] + DependentJoin on [] with expr EXISTS () depth 1 [a:UInt32, b:UInt32, c:UInt32, output:Boolean] + TableScan: outer_table [a:UInt32, b:UInt32, c:UInt32] + Projection: inner_table_lv1.b, inner_table_lv1.a [b:UInt32, a:UInt32] + Filter: inner_table_lv1.b = Int32(1) [a:UInt32, b:UInt32, c:UInt32] + TableScan: inner_table_lv1 [a:UInt32, b:UInt32, c:UInt32] +"); + Ok(()) } #[test] - fn simple_decorrelate_with_in_subquery_no_dependent_column() -> Result<()> { - // let outer_table = test_table_scan_with_name("outer_table")?; - // let inner_table_lv1 = test_table_scan_with_name("inner_table_lv1")?; - // let sq_level1 = Arc::new( - // LogicalPlanBuilder::from(inner_table_lv1) - // .filter(col("inner_table_lv1.b").eq(lit(1)))? - // .project(vec![col("inner_table_lv1.b")])? - // .build()?, - // ); - - // let input1 = LogicalPlanBuilder::from(outer_table.clone()) - // .filter( - // col("outer_table.a") - // .gt(lit(1)) - // .and(in_subquery(col("outer_table.c"), sq_level1)), - // )? - // .build()?; - // let mut index = DependentJoinTracker::new(Arc::new(AliasGenerator::new())); - // index.rewrite_subqueries_into_dependent_joins(input1)?; - // let new_plan = index.root_dependent_join_elimination()?; - // let expected = "\ - // LeftSemi Join: Filter: outer_table.c = __in_sq_1.b\ - // \n Filter: outer_table.a > Int32(1)\ - // \n TableScan: outer_table\ - // \n SubqueryAlias: __in_sq_1\ - // \n Projection: inner_table_lv1.b\ - // \n Filter: inner_table_lv1.b = Int32(1)\ - // \n TableScan: inner_table_lv1"; - // assert_eq!(expected, format!("{new_plan}")); + fn rewrite_dependent_join_with_in_subquery_no_dependent_column() -> Result<()> { + let outer_table = test_table_scan_with_name("outer_table")?; + let inner_table_lv1 = test_table_scan_with_name("inner_table_lv1")?; + let sq_level1 = Arc::new( + LogicalPlanBuilder::from(inner_table_lv1) + .filter(col("inner_table_lv1.b").eq(lit(1)))? + .project(vec![col("inner_table_lv1.b")])? 
+ .build()?, + ); + + let input1 = LogicalPlanBuilder::from(outer_table.clone()) + .filter( + col("outer_table.a") + .gt(lit(1)) + .and(in_subquery(col("outer_table.c"), sq_level1)), + )? + .build()?; + assert_dependent_join_rewrite!(input1,@r" +Projection: outer_table.a, outer_table.b, outer_table.c [a:UInt32, b:UInt32, c:UInt32] + Filter: outer_table.a > Int32(1) AND __in_sq_1.output [a:UInt32, b:UInt32, c:UInt32, output:Boolean] + DependentJoin on [] with expr outer_table.c IN () depth 1 [a:UInt32, b:UInt32, c:UInt32, output:Boolean] + TableScan: outer_table [a:UInt32, b:UInt32, c:UInt32] + Projection: inner_table_lv1.b [b:Uint32] + Filter: inner_table_lv1.a = Int32(1) [a:UInt32, b:UInt32, c:UInt32] + TableScan: inner_table_lv1 [a:UInt32, b:UInt32, c:UInt32] + "); + Ok(()) } #[test] - fn simple_decorrelate_with_in_subquery_has_dependent_column() -> Result<()> { + fn rewrite_dependent_join_with_in_subquery_has_dependent_column() -> Result<()> { let outer_table = test_table_scan_with_name("outer_table")?; let inner_table_lv1 = test_table_scan_with_name("inner_table_lv1")?; let sq_level1 = Arc::new( @@ -763,21 +840,14 @@ mod tests { .and(in_subquery(col("outer_table.c"), sq_level1)), )? .build()?; - let mut index = DependentJoinRewriter::new(Arc::new(AliasGenerator::new())); - let transformed = index.rewrite_subqueries_into_dependent_joins(input1)?; - assert!(transformed.transformed); - - let formatted_plan = transformed.data.display_indent_schema(); - assert_snapshot!(formatted_plan, - @r" + assert_dependent_join_rewrite!(input1,@r" Projection: outer_table.a, outer_table.b, outer_table.c [a:UInt32, b:UInt32, c:UInt32] - Filter: outer_table.a > Int32(1) AND __in_sq_1.output [a:UInt32, b:UInt32, c:UInt32] - DependentJoin on outer_table.a, outer_table.b with expr Boolean(true) [a:UInt32, b:UInt32, c:UInt32] + Filter: outer_table.a > Int32(1) AND __in_sq_1.output [a:UInt32, b:UInt32, c:UInt32, output:Boolean] + DependentJoin on [outer_table.a, outer_table.b] with expr outer_table.c IN () depth 1 [a:UInt32, b:UInt32, c:UInt32, output:Boolean] TableScan: outer_table [a:UInt32, b:UInt32, c:UInt32] - SubqueryAlias: __in_sq_1 [outer_b_alias:UInt32;N] - Projection: outer_ref(outer_table.b) AS outer_b_alias [outer_b_alias:UInt32;N] - Filter: inner_table_lv1.a = outer_ref(outer_table.a) AND outer_ref(outer_table.a) > inner_table_lv1.c AND inner_table_lv1.b = Int32(1) AND outer_ref(outer_table.b) = inner_table_lv1.b [a:UInt32, b:UInt32, c:UInt32] - TableScan: inner_table_lv1 [a:UInt32, b:UInt32, c:UInt32] + Projection: outer_ref(outer_table.b) AS outer_b_alias [outer_b_alias:UInt32;N] + Filter: inner_table_lv1.a = outer_ref(outer_table.a) AND outer_ref(outer_table.a) > inner_table_lv1.c AND inner_table_lv1.b = Int32(1) AND outer_ref(outer_table.b) = inner_table_lv1.b [a:UInt32, b:UInt32, c:UInt32] + TableScan: inner_table_lv1 [a:UInt32, b:UInt32, c:UInt32] "); Ok(()) } From e3c77d65d8c3df1eb19c2629cd3ca480d8ce22e6 Mon Sep 17 00:00:00 2001 From: Duong Cong Toai Date: Sat, 24 May 2025 21:16:40 +0200 Subject: [PATCH 30/70] chore: some more note on further implementation --- datafusion/expr/src/logical_plan/builder.rs | 6 +++++- datafusion/optimizer/src/decorrelate_general.rs | 15 ++++++++------- .../optimizer/src/scalar_subquery_to_join.rs | 11 +---------- 3 files changed, 14 insertions(+), 18 deletions(-) diff --git a/datafusion/expr/src/logical_plan/builder.rs b/datafusion/expr/src/logical_plan/builder.rs index c69ffb0f2c86..da27fa0644da 100644 --- a/datafusion/expr/src/logical_plan/builder.rs +++ 
b/datafusion/expr/src/logical_plan/builder.rs @@ -881,7 +881,11 @@ impl LogicalPlanBuilder { )))) } - /// + /// Build a dependent join provided a subquery plan + /// this function should only be used by the optimizor + /// a dependent join node will provides all columns belonging to the LHS + /// and one additional column as the result of evaluating the subquery on the RHS + /// under the name "subquery_name.output" pub fn dependent_join( self, right: LogicalPlan, diff --git a/datafusion/optimizer/src/decorrelate_general.rs b/datafusion/optimizer/src/decorrelate_general.rs index 536190243b61..579aad09e69b 100644 --- a/datafusion/optimizer/src/decorrelate_general.rs +++ b/datafusion/optimizer/src/decorrelate_general.rs @@ -122,10 +122,6 @@ impl DependentJoinRewriter { let mut cur_stack = self.stack.clone(); cur_stack.push(child_id); - if col.name() == "outer_table.a" || col.name == "a" { - println!("{:?}", access); - println!("{:?}", cur_stack); - } // this is a dependent join node let (dependent_join_node_id, subquery_node_id) = Self::dependent_join_and_subquery_node_ids(&cur_stack, &access.stack); @@ -326,6 +322,8 @@ impl TreeNodeRewriter for DependentJoinRewriter { is_subquery_node = true; let parent = self.stack.last().unwrap(); let parent_node = self.nodes.get_mut(parent).unwrap(); + // the inserting sequence matter here + // when a parent has multiple children subquery at the same time parent_node.access_tracker.insert(self.current_id, vec![]); for expr in parent_node.plan.expressions() { expr.exists(|e| { @@ -438,7 +436,6 @@ impl TreeNodeRewriter for DependentJoinRewriter { subquery_alias_by_offset.get(offset_ref).unwrap() } Expr::ScalarSubquery(ref s) => { - println!("inserting new expr {}", s.subquery); subquery_alias_by_offset.get(offset_ref).unwrap() } _ => return Ok(Transformed::no(e)), @@ -568,11 +565,15 @@ mod tests { }}; } #[test] - fn simple_in_subquery_inside_from_expr() -> Result<()> { + fn rewrite_dependent_join_with_lateral_join() -> Result<()> { + Ok(()) + } + #[test] + fn rewrite_dependent_join_in_from_expr() -> Result<()> { Ok(()) } #[test] - fn simple_in_subquery_inside_select_expr() -> Result<()> { + fn rewrite_dependent_join_inside_select_expr() -> Result<()> { Ok(()) } #[test] diff --git a/datafusion/optimizer/src/scalar_subquery_to_join.rs b/datafusion/optimizer/src/scalar_subquery_to_join.rs index 64a438997f5d..b3de703e8991 100644 --- a/datafusion/optimizer/src/scalar_subquery_to_join.rs +++ b/datafusion/optimizer/src/scalar_subquery_to_join.rs @@ -28,8 +28,7 @@ use crate::{OptimizerConfig, OptimizerRule}; use crate::analyzer::type_coercion::TypeCoercionRewriter; use datafusion_common::alias::AliasGenerator; use datafusion_common::tree_node::{ - Transformed, TransformedResult, TreeNode, TreeNodeContainer, TreeNodeRecursion, - TreeNodeRewriter, + Transformed, TransformedResult, TreeNode, TreeNodeRecursion, TreeNodeRewriter, }; use datafusion_common::{internal_err, plan_err, Column, Result, ScalarValue}; use datafusion_expr::expr_rewriter::create_col_from_scalar_expr; @@ -87,8 +86,6 @@ impl OptimizerRule for ScalarSubqueryToJoin { return Ok(Transformed::no(LogicalPlan::Filter(filter))); } - // reWriteExpr is all the filter in the subquery that is irrelevant to the subquery execution - // i.e where outer=some col, or outer + binary operator with some aggregated value let (subqueries, mut rewrite_expr) = self.extract_subquery_exprs( &filter.predicate, config.alias_generator(), @@ -290,12 +287,8 @@ impl TreeNodeRewriter for ExtractScalarSubQuery<'_> { /// /// # 
Arguments /// -/// * `subquery` - The subquery portion of the `where` (select avg(total) from orders) /// * `filter_input` - The non-subquery portion (from customers) /// * `subquery_alias` - Subquery aliases -/// # Returns -/// * an optimize subquery if any -/// * a map of original count expr to a transformed expr (a hacky way to handle count bug) fn build_join( subquery: &Subquery, filter_input: &LogicalPlan, @@ -326,8 +319,6 @@ fn build_join( conjunction(pull_up.join_filters).map_or(Ok(None), |filter| { replace_qualified_name(filter, &all_correlated_cols, subquery_alias).map(Some) })?; - // TODO: build domain from filter input - // select distinct columns from filter input // join our sub query into the main plan let new_plan = if join_filter_opt.is_none() { From 1ae09262e099a7ab1a0c4519e410e94c735ae25e Mon Sep 17 00:00:00 2001 From: Duong Cong Toai Date: Sat, 24 May 2025 22:05:27 +0200 Subject: [PATCH 31/70] chore: lint --- datafusion/expr/src/logical_plan/builder.rs | 10 +---- datafusion/expr/src/logical_plan/plan.rs | 3 +- .../optimizer/src/decorrelate_general.rs | 37 ++++++++----------- 3 files changed, 18 insertions(+), 32 deletions(-) diff --git a/datafusion/expr/src/logical_plan/builder.rs b/datafusion/expr/src/logical_plan/builder.rs index da27fa0644da..104242c72237 100644 --- a/datafusion/expr/src/logical_plan/builder.rs +++ b/datafusion/expr/src/logical_plan/builder.rs @@ -895,17 +895,12 @@ impl LogicalPlanBuilder { subquery_name: String, ) -> Result { let left = self.build()?; - let mut schema = left.schema(); + let schema = left.schema(); let qualified_fields = schema .iter() .map(|(q, f)| (q.cloned(), Arc::clone(f))) - .chain(once(subquery_output_field( - &subquery_name, - right.schema(), - &subquery_expr, - ))) + .chain(once(subquery_output_field(&subquery_name, &subquery_expr))) .collect(); - let func_dependencies = schema.functional_dependencies(); let metadata = schema.metadata().clone(); let dfschema = DFSchema::new_with_metadata(qualified_fields, metadata)?; @@ -1586,7 +1581,6 @@ fn mark_field(schema: &DFSchema) -> (Option, Arc) { fn subquery_output_field( subquery_alias: &String, - right_schema: &DFSchema, subquery_expr: &Expr, ) -> (Option, Arc) { // TODO: check nullability diff --git a/datafusion/expr/src/logical_plan/plan.rs b/datafusion/expr/src/logical_plan/plan.rs index 76d58730a8fb..5c5174f6701e 100644 --- a/datafusion/expr/src/logical_plan/plan.rs +++ b/datafusion/expr/src/logical_plan/plan.rs @@ -897,7 +897,6 @@ impl LogicalPlan { Filter::try_new(predicate, Arc::new(input)).map(LogicalPlan::Filter) } - LogicalPlan::DependentJoin(DependentJoin { left, right, .. }) => todo!(), LogicalPlan::Repartition(Repartition { partitioning_scheme, .. 
@@ -1202,6 +1201,7 @@ impl LogicalPlan { unnest_with_options(input, columns.clone(), options.clone())?; Ok(new_plan) } + LogicalPlan::DependentJoin(_) => todo!(), } } @@ -1949,7 +1949,6 @@ impl LogicalPlan { } LogicalPlan::DependentJoin(DependentJoin{ - left,right, subquery_expr, correlated_columns, subquery_depth, diff --git a/datafusion/optimizer/src/decorrelate_general.rs b/datafusion/optimizer/src/decorrelate_general.rs index 579aad09e69b..adeb451362fc 100644 --- a/datafusion/optimizer/src/decorrelate_general.rs +++ b/datafusion/optimizer/src/decorrelate_general.rs @@ -206,40 +206,30 @@ struct Node { #[derive(Debug, Clone)] enum SubqueryType { None, - In(Expr), - Exists(Expr), - Scalar(Expr), + In, + Exists, + Scalar, } impl SubqueryType { - fn unwrap_expr(&self) -> Expr { - match self { - SubqueryType::None => { - panic!("not reached") - } - SubqueryType::In(e) | SubqueryType::Exists(e) | SubqueryType::Scalar(e) => { - e.clone() - } - } - } fn default_join_type(&self) -> JoinType { match self { SubqueryType::None => { panic!("not reached") } - SubqueryType::In(_) => JoinType::LeftSemi, - SubqueryType::Exists(_) => JoinType::LeftMark, + SubqueryType::In => JoinType::LeftSemi, + SubqueryType::Exists => JoinType::LeftMark, // TODO: in duckdb, they have JoinType::Single // where there is only at most one join partner entry on the LEFT - SubqueryType::Scalar(_) => JoinType::Left, + SubqueryType::Scalar => JoinType::Left, } } fn prefix(&self) -> String { match self { SubqueryType::None => "", - SubqueryType::In(_) => "__in_sq", - SubqueryType::Exists(_) => "__exists_sq", - SubqueryType::Scalar(_) => "__scalar_sq", + SubqueryType::In => "__in_sq", + SubqueryType::Exists => "__exists_sq", + SubqueryType::Scalar => "__scalar_sq", } .to_string() } @@ -330,21 +320,21 @@ impl TreeNodeRewriter for DependentJoinRewriter { let (found_sq, checking_type) = match e { Expr::ScalarSubquery(sq) => { if sq == subquery { - (true, SubqueryType::Scalar(e.clone())) + (true, SubqueryType::Scalar) } else { (false, SubqueryType::None) } } Expr::Exists(exist) => { if &exist.subquery == subquery { - (true, SubqueryType::Exists(e.clone())) + (true, SubqueryType::Exists) } else { (false, SubqueryType::None) } } Expr::InSubquery(in_sq) => { if &in_sq.subquery == subquery { - (true, SubqueryType::In(e.clone())) + (true, SubqueryType::In) } else { (false, SubqueryType::None) } @@ -416,6 +406,9 @@ impl TreeNodeRewriter for DependentJoinRewriter { } match &node { + LogicalPlan::Projection(_) => { + // TODO: implement me + } LogicalPlan::Filter(filter) => { // everytime we meet a subquery during traversal, we increment this by 1 // we can use this offset to lookup the original subquery info From d15c2aa9905cab10b31dc80adf7b610f2d589d4c Mon Sep 17 00:00:00 2001 From: Duong Cong Toai Date: Sun, 25 May 2025 00:24:39 +0200 Subject: [PATCH 32/70] chore: clippy --- .../optimizer/src/decorrelate_general.rs | 87 +++++-------------- 1 file changed, 21 insertions(+), 66 deletions(-) diff --git a/datafusion/optimizer/src/decorrelate_general.rs b/datafusion/optimizer/src/decorrelate_general.rs index adeb451362fc..13e2e12daf1a 100644 --- a/datafusion/optimizer/src/decorrelate_general.rs +++ b/datafusion/optimizer/src/decorrelate_general.rs @@ -15,39 +15,23 @@ // specific language governing permissions and limitations // under the License. -//! [`GeneralPullUpCorrelatedExpr`] converts correlated subqueries to `Joins` +//! 
[`DependentJoinRewriter`] converts correlated subqueries to `DependentJoin` -use std::any::Any; -use std::cmp::Ordering; -use std::collections::HashSet; -use std::fmt; use std::ops::Deref; use std::sync::Arc; -use crate::analyzer::type_coercion::TypeCoercionRewriter; -use crate::decorrelate::UN_MATCHED_ROW_INDICATOR; -use crate::{ApplyOrder, OptimizerConfig, OptimizerRule}; +use crate::{OptimizerConfig, OptimizerRule}; use arrow::datatypes::DataType; use datafusion_common::alias::AliasGenerator; use datafusion_common::tree_node::{ - Transformed, TreeNode, TreeNodeRecursion, TreeNodeRewriter, TreeNodeVisitor, + Transformed, TreeNode, TreeNodeRecursion, TreeNodeRewriter, }; use datafusion_common::{internal_err, Column, HashMap, Result}; -use datafusion_expr::expr::{self, Exists, InSubquery}; -use datafusion_expr::expr_rewriter::{normalize_col, strip_outer_reference}; -use datafusion_expr::select_expr::SelectExpr; -use datafusion_expr::utils::{conjunction, disjunction, split_conjunction}; -use datafusion_expr::{ - binary_expr, col, expr_fn, lit, BinaryExpr, Cast, DependentJoin, Expr, Filter, - JoinType, LogicalPlan, LogicalPlanBuilder, Operator as ExprOperator, Subquery, -}; -use datafusion_expr::{in_list, out_ref_col}; +use datafusion_expr::{col, Expr, LogicalPlan, LogicalPlanBuilder}; -use indexmap::map::Entry; -use indexmap::{IndexMap, IndexSet}; +use indexmap::IndexMap; use itertools::Itertools; -use log::Log; pub struct DependentJoinRewriter { // each logical plan traversal will assign it a integer id @@ -178,11 +162,6 @@ impl DependentJoinRewriter { } } -impl ColumnAccess { - fn debug(&self) -> String { - format!("\x1b[31m{} ({})\x1b[0m", self.node_id, self.col) - } -} #[derive(Debug, Clone)] struct Node { id: usize, @@ -196,7 +175,6 @@ struct Node { access_tracker: IndexMap>, is_dependent_join_node: bool, - is_subquery_node: bool, // note that for dependent join nodes, there can be more than 1 // subquery children at a time, but always 1 outer-column-providing-child @@ -212,18 +190,6 @@ enum SubqueryType { } impl SubqueryType { - fn default_join_type(&self) -> JoinType { - match self { - SubqueryType::None => { - panic!("not reached") - } - SubqueryType::In => JoinType::LeftSemi, - SubqueryType::Exists => JoinType::LeftMark, - // TODO: in duckdb, they have JoinType::Single - // where there is only at most one join partner entry on the LEFT - SubqueryType::Scalar => JoinType::Left, - } - } fn prefix(&self) -> String { match self { SubqueryType::None => "", @@ -236,9 +202,9 @@ impl SubqueryType { } fn unwrap_subquery_input_from_expr(expr: &Expr) -> Arc { match expr { - Expr::ScalarSubquery(sq) => sq.subquery.clone(), - Expr::Exists(exists) => exists.subquery.subquery.clone(), - Expr::InSubquery(in_sq) => in_sq.subquery.subquery.clone(), + Expr::ScalarSubquery(sq) => Arc::clone(&sq.subquery), + Expr::Exists(exists) => Arc::clone(&exists.subquery.subquery), + Expr::InSubquery(in_sq) => Arc::clone(&in_sq.subquery.subquery), _ => unreachable!(), } } @@ -257,7 +223,6 @@ impl TreeNodeRewriter for DependentJoinRewriter { type Node = LogicalPlan; fn f_down(&mut self, node: LogicalPlan) -> Result> { self.current_id += 1; - let mut is_subquery_node = false; let mut is_dependent_join_node = false; let mut subquery_type = SubqueryType::None; // for each node, find which column it is accessing, which column it is providing @@ -283,7 +248,7 @@ impl TreeNodeRewriter for DependentJoinRewriter { } LogicalPlan::TableScan(tbl_scan) => { tbl_scan.projected_schema.columns().iter().for_each(|col| { - 
self.conclude_lowest_dependent_join_node(self.current_id, &col); + self.conclude_lowest_dependent_join_node(self.current_id, col); }); } // TODO @@ -309,7 +274,6 @@ impl TreeNodeRewriter for DependentJoinRewriter { } } LogicalPlan::Subquery(subquery) => { - is_subquery_node = true; let parent = self.stack.last().unwrap(); let parent_node = self.nodes.get_mut(parent).unwrap(); // the inserting sequence matter here @@ -364,7 +328,6 @@ impl TreeNodeRewriter for DependentJoinRewriter { Node { id: self.current_id, plan: node.clone(), - is_subquery_node, is_dependent_join_node, access_tracker: IndexMap::new(), subquery_type, @@ -394,7 +357,7 @@ impl TreeNodeRewriter for DependentJoinRewriter { let mut subquery_alias_by_offset = HashMap::new(); // let mut subquery_alias_by_node_id = HashMap::new(); let mut subquery_expr_by_offset = HashMap::new(); - for (subquery_offset, (subquery_id, column_accesses)) in + for (subquery_offset, (subquery_id, _)) in node_info.access_tracker.iter().enumerate() { let subquery_node = self.nodes.get(subquery_id).unwrap(); @@ -425,10 +388,9 @@ impl TreeNodeRewriter for DependentJoinRewriter { // replace any subquery expr with subquery_alias.output // column let alias = match e { - Expr::InSubquery(_) | Expr::Exists(_) => { - subquery_alias_by_offset.get(offset_ref).unwrap() - } - Expr::ScalarSubquery(ref s) => { + Expr::InSubquery(_) + | Expr::Exists(_) + | Expr::ScalarSubquery(_) => { subquery_alias_by_offset.get(offset_ref).unwrap() } _ => return Ok(Transformed::no(e)), @@ -440,7 +402,7 @@ impl TreeNodeRewriter for DependentJoinRewriter { // TODO: this assume that after decorrelation // the dependent join will provide an extra column with the structure // of "subquery_alias.output" - Ok(Transformed::yes(col(format!("{}.output", alias)))) + Ok(Transformed::yes(col(format!("{alias}.output")))) })? 
.data; // because dependent join may introduce extra columns @@ -500,7 +462,7 @@ impl OptimizerRule for Decorrelation { config: &dyn OptimizerConfig, ) -> Result> { let mut transformer = - DependentJoinRewriter::new(config.alias_generator().clone()); + DependentJoinRewriter::new(Arc::clone(config.alias_generator())); let rewrite_result = transformer.rewrite_subqueries_into_dependent_joins(plan)?; if rewrite_result.transformed { // At this point, we have a logical plan with DependentJoin similar to duckdb @@ -521,23 +483,16 @@ impl OptimizerRule for Decorrelation { #[cfg(test)] mod tests { - use std::sync::Arc; - - use datafusion_common::{alias::AliasGenerator, DFSchema, Result, ScalarValue}; + use datafusion_common::{alias::AliasGenerator, Result}; use datafusion_expr::{ - exists, - expr_fn::{self, col, not}, - in_subquery, lit, out_ref_col, scalar_subquery, table_scan, CreateMemoryTable, - EmptyRelation, Expr, LogicalPlan, LogicalPlanBuilder, + exists, expr_fn::col, in_subquery, lit, out_ref_col, scalar_subquery, Expr, + LogicalPlanBuilder, }; - use datafusion_functions_aggregate::{count::count, sum::sum}; + use datafusion_functions_aggregate::count::count; use insta::assert_snapshot; - use regex_syntax::ast::LiteralKind; + use std::sync::Arc; - use crate::{ - assert_optimized_plan_eq_display_indent_snapshot, - test::{test_table_scan, test_table_scan_with_name}, - }; + use crate::test::test_table_scan_with_name; use super::DependentJoinRewriter; use arrow::datatypes::DataType as ArrowDataType; From e5baf2cac3cee8a90266f5b175050dbb03f2d61a Mon Sep 17 00:00:00 2001 From: Duong Cong Toai Date: Sun, 25 May 2025 07:25:06 +0200 Subject: [PATCH 33/70] fix: test --- datafusion/core/src/physical_planner.rs | 5 ++++ .../optimizer/src/decorrelate_general.rs | 26 +++++++++---------- .../optimizer/src/optimize_projections/mod.rs | 5 ++-- 3 files changed, 19 insertions(+), 17 deletions(-) diff --git a/datafusion/core/src/physical_planner.rs b/datafusion/core/src/physical_planner.rs index fbb4250fc4df..ddb9db235335 100644 --- a/datafusion/core/src/physical_planner.rs +++ b/datafusion/core/src/physical_planner.rs @@ -1246,6 +1246,11 @@ impl DefaultPhysicalPlanner { "Unsupported logical plan: Analyze must be root of the plan" ) } + LogicalPlan::DependentJoin(_) => { + return internal_err!( + "Optimizors have not completely remove dependent join" + ) + } }; Ok(exec_node) } diff --git a/datafusion/optimizer/src/decorrelate_general.rs b/datafusion/optimizer/src/decorrelate_general.rs index 13e2e12daf1a..87daead85662 100644 --- a/datafusion/optimizer/src/decorrelate_general.rs +++ b/datafusion/optimizer/src/decorrelate_general.rs @@ -20,7 +20,7 @@ use std::ops::Deref; use std::sync::Arc; -use crate::{OptimizerConfig, OptimizerRule}; +use crate::{ApplyOrder, OptimizerConfig, OptimizerRule}; use arrow::datatypes::DataType; use datafusion_common::alias::AliasGenerator; @@ -164,7 +164,6 @@ impl DependentJoinRewriter { #[derive(Debug, Clone)] struct Node { - id: usize, plan: LogicalPlan, // This field is only meaningful if the node is dependent join node @@ -221,6 +220,7 @@ fn contains_subquery(expr: &Expr) -> bool { impl TreeNodeRewriter for DependentJoinRewriter { type Node = LogicalPlan; + // fn f_down(&mut self, node: LogicalPlan) -> Result> { self.current_id += 1; let mut is_dependent_join_node = false; @@ -326,7 +326,6 @@ impl TreeNodeRewriter for DependentJoinRewriter { self.nodes.insert( self.current_id, Node { - id: self.current_id, plan: node.clone(), is_dependent_join_node, access_tracker: 
IndexMap::new(), @@ -449,6 +448,8 @@ impl TreeNodeRewriter for DependentJoinRewriter { Ok(Transformed::yes(current_plan.build()?)) } } + +#[allow(dead_code)] #[derive(Debug)] struct Decorrelation {} @@ -475,14 +476,16 @@ impl OptimizerRule for Decorrelation { "decorrelate_subquery" } - // The rewriter handle recursion - // fn apply_order(&self) -> Option { - // None - // } + fn apply_order(&self) -> Option { + None + } } #[cfg(test)] mod tests { + use super::DependentJoinRewriter; + use crate::test::test_table_scan_with_name; + use arrow::datatypes::DataType as ArrowDataType; use datafusion_common::{alias::AliasGenerator, Result}; use datafusion_expr::{ exists, expr_fn::col, in_subquery, lit, out_ref_col, scalar_subquery, Expr, @@ -492,11 +495,6 @@ mod tests { use insta::assert_snapshot; use std::sync::Arc; - use crate::test::test_table_scan_with_name; - - use super::DependentJoinRewriter; - use arrow::datatypes::DataType as ArrowDataType; - macro_rules! assert_dependent_join_rewrite { ( $plan:expr, @@ -751,8 +749,8 @@ Projection: outer_table.a, outer_table.b, outer_table.c [a:UInt32, b:UInt32, c:U Filter: outer_table.a > Int32(1) AND __in_sq_1.output [a:UInt32, b:UInt32, c:UInt32, output:Boolean] DependentJoin on [] with expr outer_table.c IN () depth 1 [a:UInt32, b:UInt32, c:UInt32, output:Boolean] TableScan: outer_table [a:UInt32, b:UInt32, c:UInt32] - Projection: inner_table_lv1.b [b:Uint32] - Filter: inner_table_lv1.a = Int32(1) [a:UInt32, b:UInt32, c:UInt32] + Projection: inner_table_lv1.b [b:UInt32] + Filter: inner_table_lv1.b = Int32(1) [a:UInt32, b:UInt32, c:UInt32] TableScan: inner_table_lv1 [a:UInt32, b:UInt32, c:UInt32] "); diff --git a/datafusion/optimizer/src/optimize_projections/mod.rs b/datafusion/optimizer/src/optimize_projections/mod.rs index 086c19c7dcc2..9b41893dffaa 100644 --- a/datafusion/optimizer/src/optimize_projections/mod.rs +++ b/datafusion/optimizer/src/optimize_projections/mod.rs @@ -347,8 +347,7 @@ fn optimize_projections( LogicalPlan::EmptyRelation(_) | LogicalPlan::RecursiveQuery(_) | LogicalPlan::Values(_) - | LogicalPlan::DescribeTable(_) - | LogicalPlan::DependentJoin(_) => { + | LogicalPlan::DescribeTable(_) => { // These operators have no inputs, so stop the optimization process. return Ok(Transformed::no(plan)); } @@ -383,7 +382,7 @@ fn optimize_projections( dependency_indices.clone(), )] } - LogicalPlan::DependentJoin(..) => todo!(), + LogicalPlan::DependentJoin(..) => unreachable!(), }; // Required indices are currently ordered (child0, child1, ...) 
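The `assert_dependent_join_rewrite!` macro above is the whole flow in miniature: build a logical plan that contains subquery expressions, run `DependentJoinRewriter`, and snapshot the rewritten plan. The sketch below spells that flow out by hand for the EXISTS case covered by the snapshots above. It is only an illustration, not part of the patch series: it assumes it runs inside the optimizer crate's test code, where the crate-private helper `test_table_scan_with_name` and the rewriter's `rewrite_subqueries_into_dependent_joins` entry point are visible.

```rust
use std::sync::Arc;

use datafusion_common::alias::AliasGenerator;
use datafusion_common::Result;
use datafusion_expr::{exists, expr_fn::col, lit, LogicalPlanBuilder};

use crate::decorrelate_general::DependentJoinRewriter;
use crate::test::test_table_scan_with_name;

fn rewrite_exists_by_hand() -> Result<()> {
    // SELECT * FROM outer_table
    // WHERE a > 1 AND EXISTS (SELECT * FROM inner_table_lv1 WHERE b = 1)
    let outer = test_table_scan_with_name("outer_table")?;
    let inner = test_table_scan_with_name("inner_table_lv1")?;
    let subquery = Arc::new(
        LogicalPlanBuilder::from(inner)
            .filter(col("inner_table_lv1.b").eq(lit(1)))?
            .build()?,
    );
    let plan = LogicalPlanBuilder::from(outer)
        .filter(col("outer_table.a").gt(lit(1)).and(exists(subquery)))?
        .build()?;

    // Rewrite every subquery expression into a DependentJoin node. The EXISTS
    // expr inside the Filter is replaced by the `__exists_sq_1.output` column
    // that the DependentJoin provides, as in the snapshots above
    // ("DependentJoin on [] with expr EXISTS (...) depth 1").
    let mut rewriter = DependentJoinRewriter::new(Arc::new(AliasGenerator::new()));
    let rewritten = rewriter.rewrite_subqueries_into_dependent_joins(plan)?;
    assert!(rewritten.transformed);
    println!("{}", rewritten.data.display_indent_schema());
    Ok(())
}
```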
From 11dbb803cef8ed7d6b58533569d8422b1843ff19 Mon Sep 17 00:00:00 2001 From: Duong Cong Toai Date: Sun, 25 May 2025 08:39:26 +0200 Subject: [PATCH 34/70] doc: draw diagram --- .../optimizer/src/decorrelate_general.rs | 136 ++++++++++++------ 1 file changed, 94 insertions(+), 42 deletions(-) diff --git a/datafusion/optimizer/src/decorrelate_general.rs b/datafusion/optimizer/src/decorrelate_general.rs index 87daead85662..e296928e6eb5 100644 --- a/datafusion/optimizer/src/decorrelate_general.rs +++ b/datafusion/optimizer/src/decorrelate_general.rs @@ -37,10 +37,10 @@ pub struct DependentJoinRewriter { // each logical plan traversal will assign it a integer id current_id: usize, subquery_depth: usize, - // each newly visted operator is inserted inside this map for tracking + // each newly visted `LogicalPlan` is inserted inside this map for tracking nodes: IndexMap, // all the node ids from root to the current node - // this is used during traversal only + // this is mutated duri traversal stack: Vec, // track for each column, the nodes/logical plan that reference to its within the tree all_outer_ref_columns: IndexMap>, @@ -129,13 +129,11 @@ impl DependentJoinRewriter { ) { // iter from bottom to top, the goal is to mark the dependent node // the current child's access - let mut stack = self.stack.clone(); - stack.push(child_id); self.all_outer_ref_columns .entry(col.clone()) .or_default() .push(ColumnAccess { - stack, + stack: self.stack.clone(), node_id: child_id, col: col.clone(), data_type: data_type.clone(), @@ -208,6 +206,8 @@ fn unwrap_subquery_input_from_expr(expr: &Expr) -> Arc { } } +// if current expr contains any subquery expr +// this function must not be recursive fn contains_subquery(expr: &Expr) -> bool { expr.exists(|expr| { Ok(matches!( @@ -218,10 +218,52 @@ fn contains_subquery(expr: &Expr) -> bool { .expect("Inner is always Ok") } +/// The rewriting happens up-down, where the parent nodes are downward-visited +/// before its children (subqueries children are visited first). 
+/// This behavior allow the fact that, at any moment, if we observe a `LogicalPlan` +/// that provides the data for columns, we can assume that all subqueries that reference +/// its data were already visited, and we can conclude the information of the `DependentJoin` +/// needed for the decorrelation: +/// - The subquery expr +/// - The correlated columns on the LHS referenced from the RHS (and its recursing subqueries if any) +/// If in the original node there exists multiple subqueries at the same time +/// two nested `DependentJoin` plans are generated (with equal depth) +/// +/// For illustration, given this query +/// ```sql +/// SELECT ID FROM T1 WHERE EXISTS(SELECT * FROM T2 WHERE T2.ID=T1.ID) OR EXISTS(SELECT * FROM T2 WHERE T2.VALUE=T1.ID); +/// ``` +/// +/// The traversal happens in the following sequence +/// +/// ```text +/// ↓1 +/// ↑12 +/// ┌────────────┐ +/// │ FILTER │<--- DependentJoin rewrite +/// │ │ happens here +/// └────┬────┬──┘ +/// ↓2 ↓6 ↓10 +/// ↑5 ↑9 ↑11 <---Here we already have enough information +/// │ | | of which node is accessing which column +/// │ | | provided by "Table Scan t1" node +/// │ | | +/// ┌─────┘ │ └─────┐ +/// │ │ │ +/// ┌───▼───┐ ┌──▼───┐ ┌───▼───────┐ +/// │SUBQ1 │ │SUBQ2 │ │TABLE SCAN │ +/// └──┬────┘ └──┬───┘ │ t1 │ +/// ↓3 ↓7 └───────────┘ +/// ↑4 ↑8 +/// ┌──▼────┐ ┌──▼────┐ +/// │SCAN t2│ │SCAN t2│ +/// └───────┘ └───────┘ +/// ``` impl TreeNodeRewriter for DependentJoinRewriter { type Node = LogicalPlan; - // + fn f_down(&mut self, node: LogicalPlan) -> Result> { + let new_id = self.current_id; self.current_id += 1; let mut is_dependent_join_node = false; let mut subquery_type = SubqueryType::None; @@ -236,25 +278,20 @@ impl TreeNodeRewriter for DependentJoinRewriter { f.predicate .apply(|expr| { if let Expr::OuterReferenceColumn(data_type, col) = expr { - self.mark_outer_column_access( - self.current_id, - data_type, - col, - ); + self.mark_outer_column_access(new_id, data_type, col); } Ok(TreeNodeRecursion::Continue) }) .expect("traversal is infallible"); } + // TODO: maybe there are more logical plan that provides columns + // aside from TableScan LogicalPlan::TableScan(tbl_scan) => { tbl_scan.projected_schema.columns().iter().for_each(|col| { - self.conclude_lowest_dependent_join_node(self.current_id, col); + self.conclude_lowest_dependent_join_node(new_id, col); }); } - // TODO - // 1.handle subquery inside projection - // 2.projection also provide some new columns - // 3.if within projection exists multiple subquery, how does this work + // TODO: this is untested LogicalPlan::Projection(proj) => { for expr in &proj.expr { if contains_subquery(expr) { @@ -263,11 +300,7 @@ impl TreeNodeRewriter for DependentJoinRewriter { } expr.apply(|expr| { if let Expr::OuterReferenceColumn(data_type, col) = expr { - self.mark_outer_column_access( - self.current_id, - data_type, - col, - ); + self.mark_outer_column_access(new_id, data_type, col); } Ok(TreeNodeRecursion::Continue) })?; @@ -278,7 +311,10 @@ impl TreeNodeRewriter for DependentJoinRewriter { let parent_node = self.nodes.get_mut(parent).unwrap(); // the inserting sequence matter here // when a parent has multiple children subquery at the same time - parent_node.access_tracker.insert(self.current_id, vec![]); + // we rely on the order in which subquery children are visited + // to later on find back the corresponding subquery (if some part of them + // were rewritten in the lower node) + parent_node.access_tracker.insert(new_id, vec![]); for expr in parent_node.plan.expressions() { 
expr.exists(|e| { let (found_sq, checking_type) = match e { @@ -322,9 +358,9 @@ impl TreeNodeRewriter for DependentJoinRewriter { if is_dependent_join_node { self.subquery_depth += 1 } - self.stack.push(self.current_id); + self.stack.push(new_id); self.nodes.insert( - self.current_id, + new_id, Node { plan: node.clone(), is_dependent_join_node, @@ -335,6 +371,11 @@ impl TreeNodeRewriter for DependentJoinRewriter { Ok(Transformed::no(node)) } + + /// All rewrite happens inside upward traversal + /// and only happens if the node is a "dependent join node" + /// (i.e the node with at least one subquery expr) + /// When all dependency information are already collected fn f_up(&mut self, node: LogicalPlan) -> Result> { // if the node in the f_up meet any node in the stack, it means that node itself // is a dependent join node,transformation by @@ -354,13 +395,11 @@ impl TreeNodeRewriter for DependentJoinRewriter { let cloned_input = (**node.inputs().first().unwrap()).clone(); let mut current_plan = LogicalPlanBuilder::new(cloned_input); let mut subquery_alias_by_offset = HashMap::new(); - // let mut subquery_alias_by_node_id = HashMap::new(); let mut subquery_expr_by_offset = HashMap::new(); for (subquery_offset, (subquery_id, _)) in node_info.access_tracker.iter().enumerate() { let subquery_node = self.nodes.get(subquery_id).unwrap(); - // let subquery_input = subquery_node.plan.inputs().first().unwrap(); let alias = self .alias_generator .next(&subquery_node.subquery_type.prefix()); @@ -375,7 +414,7 @@ impl TreeNodeRewriter for DependentJoinRewriter { // everytime we meet a subquery during traversal, we increment this by 1 // we can use this offset to lookup the original subquery info // in subquery_alias_by_offset - // the reason why we cannot create a hashmap keyed by Subquery object + // the reason why we cannot create a hashmap keyed by Subquery object HashMap // is that the subquery inside this filter expr may have been rewritten in // the lower level let mut offset = 0; @@ -398,9 +437,16 @@ impl TreeNodeRewriter for DependentJoinRewriter { // update the latest expr to this map subquery_expr_by_offset.insert(*offset_ref, e); *offset_ref += 1; + // TODO: this assume that after decorrelation // the dependent join will provide an extra column with the structure // of "subquery_alias.output" + // On later step of decorrelation, it rely on this structure + // to again rename the expression after join + // for example if the real join type is LeftMark, the correct output + // column should be "mark" instead, else after the join + // one extra layer of projection is needed to alias "mark" into + // "alias.output" Ok(Transformed::yes(col(format!("{alias}.output")))) })? .data; @@ -457,6 +503,11 @@ impl OptimizerRule for Decorrelation { fn supports_rewrite(&self) -> bool { true } + + // There will be 2 rewrites going on + // - Convert all subqueries (maybe including lateral join in the future) to temporary + // LogicalPlan node called DependentJoin + // - Decorrelate DependentJoin following top-down approach recursively fn rewrite( &self, plan: LogicalPlan, @@ -553,14 +604,14 @@ mod tests { .build()?, ); - let input1 = LogicalPlanBuilder::from(outer_table.clone()) + let plan = LogicalPlanBuilder::from(outer_table.clone()) .filter( col("outer_table.a") .gt(lit(1)) .and(scalar_subquery(scalar_sq_level1).eq(col("outer_table.a"))), )? 
.build()?; - assert_dependent_join_rewrite!(input1,@r" + assert_dependent_join_rewrite!(plan, @r" Projection: outer_table.a, outer_table.b, outer_table.c [a:UInt32, b:UInt32, c:UInt32] Filter: outer_table.a > Int32(1) AND __scalar_sq_2.output = outer_table.a [a:UInt32, b:UInt32, c:UInt32, output:Int64] DependentJoin on [outer_table.a, outer_table.c] with expr () depth 1 [a:UInt32, b:UInt32, c:UInt32, output:Int64] @@ -594,7 +645,7 @@ Projection: outer_table.a, outer_table.b, outer_table.c [a:UInt32, b:UInt32, c:U .build()?, ); - let input1 = LogicalPlanBuilder::from(outer_table.clone()) + let plan = LogicalPlanBuilder::from(outer_table.clone()) .filter( col("outer_table.a") .gt(lit(1)) @@ -602,7 +653,7 @@ Projection: outer_table.a, outer_table.b, outer_table.c [a:UInt32, b:UInt32, c:U .and(in_subquery(col("outer_table.b"), in_sq_level1)), )? .build()?; - assert_dependent_join_rewrite!(input1,@r" + assert_dependent_join_rewrite!(plan, @r" Projection: outer_table.a, outer_table.b, outer_table.c [a:UInt32, b:UInt32, c:UInt32] Filter: outer_table.a > Int32(1) AND __exists_sq_1.output AND __in_sq_2.output [a:UInt32, b:UInt32, c:UInt32, output:Boolean, output:Boolean] DependentJoin on [] with expr outer_table.b IN () depth 1 [a:UInt32, b:UInt32, c:UInt32, output:Boolean, output:Boolean] @@ -641,14 +692,14 @@ Projection: outer_table.a, outer_table.b, outer_table.c [a:UInt32, b:UInt32, c:U .build()?, ); - let input1 = LogicalPlanBuilder::from(outer_table.clone()) + let plan = LogicalPlanBuilder::from(outer_table.clone()) .filter( col("outer_table.a") .gt(lit(1)) .and(in_subquery(col("outer_table.c"), sq_level1)), )? .build()?; - assert_dependent_join_rewrite!(input1,@r" + assert_dependent_join_rewrite!(plan, @r" Projection: outer_table.a, outer_table.b, outer_table.c [a:UInt32, b:UInt32, c:UInt32] Filter: outer_table.a > Int32(1) AND __in_sq_1.output [a:UInt32, b:UInt32, c:UInt32, output:Boolean] DependentJoin on [outer_table.a, outer_table.b] with expr outer_table.c IN () depth 1 [a:UInt32, b:UInt32, c:UInt32, output:Boolean] @@ -684,10 +735,10 @@ Projection: outer_table.a, outer_table.b, outer_table.c [a:UInt32, b:UInt32, c:U .build()?, ); - let input1 = LogicalPlanBuilder::from(outer_table.clone()) + let plan = LogicalPlanBuilder::from(outer_table.clone()) .filter(col("outer_table.a").gt(lit(1)).and(exists(sq_level1)))? .build()?; - assert_dependent_join_rewrite!(input1,@r" + assert_dependent_join_rewrite!(plan, @r" Projection: outer_table.a, outer_table.b, outer_table.c [a:UInt32, b:UInt32, c:UInt32] Filter: outer_table.a > Int32(1) AND __exists_sq_1.output [a:UInt32, b:UInt32, c:UInt32, output:Boolean] DependentJoin on [outer_table.a, outer_table.b] with expr EXISTS () depth 1 [a:UInt32, b:UInt32, c:UInt32, output:Boolean] @@ -700,7 +751,8 @@ Projection: outer_table.a, outer_table.b, outer_table.c [a:UInt32, b:UInt32, c:U } #[test] - fn rewrite_exist_subquery_with_no_dependent_columns() -> Result<()> { + fn rewrite_dependent_join_with_exist_subquery_with_no_dependent_columns() -> Result<()> + { let outer_table = test_table_scan_with_name("outer_table")?; let inner_table_lv1 = test_table_scan_with_name("inner_table_lv1")?; let sq_level1 = Arc::new( @@ -710,11 +762,11 @@ Projection: outer_table.a, outer_table.b, outer_table.c [a:UInt32, b:UInt32, c:U .build()?, ); - let input1 = LogicalPlanBuilder::from(outer_table.clone()) + let plan = LogicalPlanBuilder::from(outer_table.clone()) .filter(col("outer_table.a").gt(lit(1)).and(exists(sq_level1)))? 
.build()?; - assert_dependent_join_rewrite!(input1,@r" + assert_dependent_join_rewrite!(plan, @r" Projection: outer_table.a, outer_table.b, outer_table.c [a:UInt32, b:UInt32, c:UInt32] Filter: outer_table.a > Int32(1) AND __exists_sq_1.output [a:UInt32, b:UInt32, c:UInt32, output:Boolean] DependentJoin on [] with expr EXISTS () depth 1 [a:UInt32, b:UInt32, c:UInt32, output:Boolean] @@ -737,14 +789,14 @@ Projection: outer_table.a, outer_table.b, outer_table.c [a:UInt32, b:UInt32, c:U .build()?, ); - let input1 = LogicalPlanBuilder::from(outer_table.clone()) + let plan = LogicalPlanBuilder::from(outer_table.clone()) .filter( col("outer_table.a") .gt(lit(1)) .and(in_subquery(col("outer_table.c"), sq_level1)), )? .build()?; - assert_dependent_join_rewrite!(input1,@r" + assert_dependent_join_rewrite!(plan, @r" Projection: outer_table.a, outer_table.b, outer_table.c [a:UInt32, b:UInt32, c:UInt32] Filter: outer_table.a > Int32(1) AND __in_sq_1.output [a:UInt32, b:UInt32, c:UInt32, output:Boolean] DependentJoin on [] with expr outer_table.c IN () depth 1 [a:UInt32, b:UInt32, c:UInt32, output:Boolean] @@ -780,14 +832,14 @@ Projection: outer_table.a, outer_table.b, outer_table.c [a:UInt32, b:UInt32, c:U .build()?, ); - let input1 = LogicalPlanBuilder::from(outer_table.clone()) + let plan = LogicalPlanBuilder::from(outer_table.clone()) .filter( col("outer_table.a") .gt(lit(1)) .and(in_subquery(col("outer_table.c"), sq_level1)), )? .build()?; - assert_dependent_join_rewrite!(input1,@r" + assert_dependent_join_rewrite!(plan, @r" Projection: outer_table.a, outer_table.b, outer_table.c [a:UInt32, b:UInt32, c:UInt32] Filter: outer_table.a > Int32(1) AND __in_sq_1.output [a:UInt32, b:UInt32, c:UInt32, output:Boolean] DependentJoin on [outer_table.a, outer_table.b] with expr outer_table.c IN () depth 1 [a:UInt32, b:UInt32, c:UInt32, output:Boolean] From 58562134af7b07b79ee4beed3adb42ac8207a54d Mon Sep 17 00:00:00 2001 From: Duong Cong Toai Date: Sun, 25 May 2025 08:56:43 +0200 Subject: [PATCH 35/70] fix: proto --- datafusion/proto/src/logical_plan/mod.rs | 15 +++++++++++++-- 1 file changed, 13 insertions(+), 2 deletions(-) diff --git a/datafusion/proto/src/logical_plan/mod.rs b/datafusion/proto/src/logical_plan/mod.rs index d934b24dc341..e488687e7acb 100644 --- a/datafusion/proto/src/logical_plan/mod.rs +++ b/datafusion/proto/src/logical_plan/mod.rs @@ -71,8 +71,8 @@ use datafusion_expr::{ Statement, WindowUDF, }; use datafusion_expr::{ - AggregateUDF, ColumnUnnestList, DmlStatement, FetchType, RecursiveQuery, SkipType, - TableSource, Unnest, + AggregateUDF, ColumnUnnestList, DependentJoin, DmlStatement, FetchType, + RecursiveQuery, SkipType, TableSource, Unnest, }; use self::to_proto::{serialize_expr, serialize_exprs}; @@ -1804,6 +1804,17 @@ impl AsLogicalPlan for LogicalPlanNode { ))), }) } + LogicalPlan::DependentJoin(DependentJoin { + schema, + left, + right, + subquery_depth, + correlated_columns, + subquery_expr, + subquery_name, + }) => Err(proto_error( + "LogicalPlan serde is not implemented for DependentJoin", + )), } } } From a3f11a8b2a4fec62ac213848ea5dc51338b29160 Mon Sep 17 00:00:00 2001 From: Duong Cong Toai Date: Sun, 25 May 2025 08:59:55 +0200 Subject: [PATCH 36/70] chore: revert unrelated change --- datafusion/optimizer/src/scalar_subquery_to_join.rs | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/datafusion/optimizer/src/scalar_subquery_to_join.rs b/datafusion/optimizer/src/scalar_subquery_to_join.rs index b3de703e8991..ece6f00cacc3 100644 --- 
a/datafusion/optimizer/src/scalar_subquery_to_join.rs +++ b/datafusion/optimizer/src/scalar_subquery_to_join.rs @@ -287,7 +287,9 @@ impl TreeNodeRewriter for ExtractScalarSubQuery<'_> { /// /// # Arguments /// +/// * `query_info` - The subquery portion of the `where` (select avg(total) from orders /// * `filter_input` - The non-subquery portion (from customers) +/// * `outer_others` - Any additional parts to the `where` expression (and c.x = y) /// * `subquery_alias` - Subquery aliases fn build_join( subquery: &Subquery, @@ -297,7 +299,6 @@ fn build_join( let subquery_plan = subquery.subquery.as_ref(); let mut pull_up = PullUpCorrelatedExpr::new().with_need_handle_count_bug(true); let new_plan = subquery_plan.clone().rewrite(&mut pull_up).data()?; - if !pull_up.can_pull_up { return Ok(None); } From e2d9d14bfe2003fbfb71a4fc923b1a1be59e047e Mon Sep 17 00:00:00 2001 From: Duong Cong Toai Date: Sun, 25 May 2025 09:10:36 +0200 Subject: [PATCH 37/70] chore: lint --- datafusion/proto/src/logical_plan/mod.rs | 14 +++----------- 1 file changed, 3 insertions(+), 11 deletions(-) diff --git a/datafusion/proto/src/logical_plan/mod.rs b/datafusion/proto/src/logical_plan/mod.rs index e488687e7acb..ce3600b03ccd 100644 --- a/datafusion/proto/src/logical_plan/mod.rs +++ b/datafusion/proto/src/logical_plan/mod.rs @@ -71,8 +71,8 @@ use datafusion_expr::{ Statement, WindowUDF, }; use datafusion_expr::{ - AggregateUDF, ColumnUnnestList, DependentJoin, DmlStatement, FetchType, - RecursiveQuery, SkipType, TableSource, Unnest, + AggregateUDF, ColumnUnnestList, DmlStatement, FetchType, RecursiveQuery, SkipType, + TableSource, Unnest, }; use self::to_proto::{serialize_expr, serialize_exprs}; @@ -1804,15 +1804,7 @@ impl AsLogicalPlan for LogicalPlanNode { ))), }) } - LogicalPlan::DependentJoin(DependentJoin { - schema, - left, - right, - subquery_depth, - correlated_columns, - subquery_expr, - subquery_name, - }) => Err(proto_error( + LogicalPlan::DependentJoin(_) => Err(proto_error( "LogicalPlan serde is not implemented for DependentJoin", )), } From b29842617ba64a635b709bd5731279eadfa29d31 Mon Sep 17 00:00:00 2001 From: Duong Cong Toai Date: Sun, 25 May 2025 09:28:12 +0200 Subject: [PATCH 38/70] fix: subtrait --- datafusion/substrait/src/logical_plan/producer/rel/mod.rs | 3 +++ 1 file changed, 3 insertions(+) diff --git a/datafusion/substrait/src/logical_plan/producer/rel/mod.rs b/datafusion/substrait/src/logical_plan/producer/rel/mod.rs index c3599a2635ff..3efaab642a66 100644 --- a/datafusion/substrait/src/logical_plan/producer/rel/mod.rs +++ b/datafusion/substrait/src/logical_plan/producer/rel/mod.rs @@ -74,5 +74,8 @@ pub fn to_substrait_rel( LogicalPlan::RecursiveQuery(plan) => { not_impl_err!("Unsupported plan type: {plan:?}")? } + LogicalPlan::DescribeTable(join) => { + not_impl_err!("Unsupported plan type: {plan:?}")? 
+ } } } From cb1a757823fdb2c4fc918522446933c08376bb53 Mon Sep 17 00:00:00 2001 From: Duong Cong Toai Date: Sun, 25 May 2025 09:46:00 +0200 Subject: [PATCH 39/70] fix: subtrait again --- datafusion/substrait/src/logical_plan/producer/rel/mod.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/datafusion/substrait/src/logical_plan/producer/rel/mod.rs b/datafusion/substrait/src/logical_plan/producer/rel/mod.rs index 3efaab642a66..2204e9913ea0 100644 --- a/datafusion/substrait/src/logical_plan/producer/rel/mod.rs +++ b/datafusion/substrait/src/logical_plan/producer/rel/mod.rs @@ -74,8 +74,8 @@ pub fn to_substrait_rel( LogicalPlan::RecursiveQuery(plan) => { not_impl_err!("Unsupported plan type: {plan:?}")? } - LogicalPlan::DescribeTable(join) => { - not_impl_err!("Unsupported plan type: {plan:?}")? + LogicalPlan::DependentJoin(join) => { + not_impl_err!("Unsupported plan type: {join:?}")? } } } From baef0662f662125f858b881b2f05e57aebb9db3c Mon Sep 17 00:00:00 2001 From: Duong Cong Toai Date: Sun, 25 May 2025 10:38:36 +0200 Subject: [PATCH 40/70] fix: fail test --- datafusion/expr/src/logical_plan/builder.rs | 4 ++-- datafusion/expr/src/logical_plan/plan.rs | 2 +- datafusion/optimizer/src/decorrelate_general.rs | 15 +++++++++------ 3 files changed, 12 insertions(+), 9 deletions(-) diff --git a/datafusion/expr/src/logical_plan/builder.rs b/datafusion/expr/src/logical_plan/builder.rs index 104242c72237..d58583876ee0 100644 --- a/datafusion/expr/src/logical_plan/builder.rs +++ b/datafusion/expr/src/logical_plan/builder.rs @@ -1580,7 +1580,7 @@ fn mark_field(schema: &DFSchema) -> (Option, Arc) { } fn subquery_output_field( - subquery_alias: &String, + subquery_alias: &str, subquery_expr: &Expr, ) -> (Option, Arc) { // TODO: check nullability @@ -1596,7 +1596,7 @@ fn subquery_output_field( } }; - (Some(TableReference::bare(subquery_alias.clone())), field) + (Some(TableReference::bare(subquery_alias)), field) } /// Creates a schema for a join operation. 
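The decorrelate_general.rs hunk below in this patch replaces the Option-based LCA bookkeeping (which mistakenly indexed the stack by node id) with a plain walk that tracks both the lowest common ancestor and the subquery child hanging off it. A minimal, self-contained sketch of that walk over two root-to-node id stacks (illustrative names only, not the DataFusion types):

fn lca_and_subquery_child(
    subquery_stack: &[usize],
    provider_stack: &[usize],
) -> (usize, usize) {
    // The stacks share a prefix from the root; the last id where they agree is the
    // lowest common ancestor, and the next id on the subquery stack is the subquery
    // child of that ancestor (assumed to exist, i.e. the subquery sits strictly
    // below the common ancestor).
    let mut lca = 0;
    let mut subquery_child = 0;
    for i in 0..subquery_stack.len().min(provider_stack.len()) {
        if subquery_stack[i] == provider_stack[i] {
            lca = subquery_stack[i];
            subquery_child = subquery_stack[i + 1];
        } else {
            break;
        }
    }
    (lca, subquery_child)
}

fn main() {
    // e.g. Filter(1) -> Subquery(4) -> Aggregate(7) on the subquery side,
    //      Filter(1) -> TableScan(2) on the column-provider side:
    assert_eq!(lca_and_subquery_child(&[1, 4, 7], &[1, 2]), (1, 4));
}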
diff --git a/datafusion/expr/src/logical_plan/plan.rs b/datafusion/expr/src/logical_plan/plan.rs index 5c5174f6701e..07ea2439f2fb 100644 --- a/datafusion/expr/src/logical_plan/plan.rs +++ b/datafusion/expr/src/logical_plan/plan.rs @@ -1957,7 +1957,7 @@ impl LogicalPlan { let correlated_str = correlated_columns.iter().map(|c|{ format!("{c}") }).collect::>().join(", "); - write!(f,"DependentJoin on [{}] with expr {} depth {}",correlated_str,subquery_expr,subquery_depth) + write!(f,"DependentJoin on [{correlated_str}] with expr {subquery_expr} depth {subquery_depth}") }, LogicalPlan::Join(Join { on: ref keys, diff --git a/datafusion/optimizer/src/decorrelate_general.rs b/datafusion/optimizer/src/decorrelate_general.rs index e296928e6eb5..e845da78fd51 100644 --- a/datafusion/optimizer/src/decorrelate_general.rs +++ b/datafusion/optimizer/src/decorrelate_general.rs @@ -77,24 +77,27 @@ impl DependentJoinRewriter { stack_with_table_provider: &[usize], stack_with_subquery: &[usize], ) -> (usize, usize) { - let mut lca = None; + let mut lowest_common_ancestor = 0; + let mut subquery_node_id = 0; let min_len = stack_with_table_provider .len() .min(stack_with_subquery.len()); for i in 0..min_len { - let ai = stack_with_subquery[i]; - let bi = stack_with_table_provider[i]; + let right_id = stack_with_subquery[i]; + let left_id = stack_with_table_provider[i]; - if ai == bi { - lca = Some((ai, stack_with_subquery[ai])); + if right_id == left_id { + // common parent + lowest_common_ancestor = right_id; + subquery_node_id = stack_with_subquery[i + 1]; } else { break; } } - lca.unwrap() + (lowest_common_ancestor, subquery_node_id) } // because the column providers are visited after column-accessor From a07b3b0ac8d129e9e7c55cbd886bd651efc0c15a Mon Sep 17 00:00:00 2001 From: Duong Cong Toai Date: Sun, 25 May 2025 10:57:03 +0200 Subject: [PATCH 41/70] chore: clippy --- .../optimizer/src/decorrelate_general.rs | 33 +++++++++++-------- 1 file changed, 19 insertions(+), 14 deletions(-) diff --git a/datafusion/optimizer/src/decorrelate_general.rs b/datafusion/optimizer/src/decorrelate_general.rs index e845da78fd51..d6ec2125139b 100644 --- a/datafusion/optimizer/src/decorrelate_general.rs +++ b/datafusion/optimizer/src/decorrelate_general.rs @@ -109,11 +109,13 @@ impl DependentJoinRewriter { let mut cur_stack = self.stack.clone(); cur_stack.push(child_id); - // this is a dependent join node let (dependent_join_node_id, subquery_node_id) = Self::dependent_join_and_subquery_node_ids(&cur_stack, &access.stack); let node = self.nodes.get_mut(&dependent_join_node_id).unwrap(); - let accesses = node.access_tracker.entry(subquery_node_id).or_default(); + let accesses = node + .columns_accesses_by_subquery_id + .entry(subquery_node_id) + .or_default(); accesses.push(ColumnAccess { col: col.clone(), node_id: access.node_id, @@ -152,14 +154,14 @@ impl DependentJoinRewriter { impl DependentJoinRewriter { fn new(alias_generator: Arc) -> Self { - return DependentJoinRewriter { + DependentJoinRewriter { alias_generator, current_id: 0, nodes: IndexMap::new(), stack: vec![], all_outer_ref_columns: IndexMap::new(), subquery_depth: 0, - }; + } } } @@ -167,12 +169,12 @@ impl DependentJoinRewriter { struct Node { plan: LogicalPlan, - // This field is only meaningful if the node is dependent join node - // it track which descendent nodes still accessing the outer columns provided by its + // This field is only meaningful if the node is dependent join node. 
+ // It tracks which descendent nodes still accessing the outer columns provided by its // left child - // the key of this map is node_id of the children subquery - // and insertion order matters here, and thus we use IndexMap - access_tracker: IndexMap>, + // The key of this map is node_id of the children subqueries. + // The insertion order matters here, and thus we use IndexMap + columns_accesses_by_subquery_id: IndexMap>, is_dependent_join_node: bool, @@ -229,8 +231,9 @@ fn contains_subquery(expr: &Expr) -> bool { /// needed for the decorrelation: /// - The subquery expr /// - The correlated columns on the LHS referenced from the RHS (and its recursing subqueries if any) +/// /// If in the original node there exists multiple subqueries at the same time -/// two nested `DependentJoin` plans are generated (with equal depth) +/// two nested `DependentJoin` plans are generated (with equal depth). /// /// For illustration, given this query /// ```sql @@ -317,7 +320,9 @@ impl TreeNodeRewriter for DependentJoinRewriter { // we rely on the order in which subquery children are visited // to later on find back the corresponding subquery (if some part of them // were rewritten in the lower node) - parent_node.access_tracker.insert(new_id, vec![]); + parent_node + .columns_accesses_by_subquery_id + .insert(new_id, vec![]); for expr in parent_node.plan.expressions() { expr.exists(|e| { let (found_sq, checking_type) = match e { @@ -367,7 +372,7 @@ impl TreeNodeRewriter for DependentJoinRewriter { Node { plan: node.clone(), is_dependent_join_node, - access_tracker: IndexMap::new(), + columns_accesses_by_subquery_id: IndexMap::new(), subquery_type, }, ); @@ -400,7 +405,7 @@ impl TreeNodeRewriter for DependentJoinRewriter { let mut subquery_alias_by_offset = HashMap::new(); let mut subquery_expr_by_offset = HashMap::new(); for (subquery_offset, (subquery_id, _)) in - node_info.access_tracker.iter().enumerate() + node_info.columns_accesses_by_subquery_id.iter().enumerate() { let subquery_node = self.nodes.get(subquery_id).unwrap(); let alias = self @@ -464,7 +469,7 @@ impl TreeNodeRewriter for DependentJoinRewriter { .map(|c| col(c.clone())) .collect(); for (subquery_offset, (_, column_accesses)) in - node_info.access_tracker.iter().enumerate() + node_info.columns_accesses_by_subquery_id.iter().enumerate() { let alias = subquery_alias_by_offset.get(&subquery_offset).unwrap(); let subquery_expr = From 32db3a922016d16e6a3a20567020bd91ded31c51 Mon Sep 17 00:00:00 2001 From: Duong Cong Toai Date: Mon, 26 May 2025 07:15:32 +0200 Subject: [PATCH 42/70] chore: add depth and data_type to correlated columns --- datafusion/expr/src/logical_plan/builder.rs | 2 +- datafusion/expr/src/logical_plan/plan.rs | 22 +++++++++----- .../src/.decorrelate_general.rs.pending-snap | 7 +++++ .../optimizer/src/decorrelate_general.rs | 29 +++++++++++++++---- 4 files changed, 46 insertions(+), 14 deletions(-) create mode 100644 datafusion/optimizer/src/.decorrelate_general.rs.pending-snap diff --git a/datafusion/expr/src/logical_plan/builder.rs b/datafusion/expr/src/logical_plan/builder.rs index d58583876ee0..5adb2bfb0bb8 100644 --- a/datafusion/expr/src/logical_plan/builder.rs +++ b/datafusion/expr/src/logical_plan/builder.rs @@ -889,7 +889,7 @@ impl LogicalPlanBuilder { pub fn dependent_join( self, right: LogicalPlan, - correlated_columns: Vec, + correlated_columns: Vec<(usize, Expr)>, subquery_expr: Expr, subquery_depth: usize, subquery_name: String, diff --git a/datafusion/expr/src/logical_plan/plan.rs 
b/datafusion/expr/src/logical_plan/plan.rs index 07ea2439f2fb..229baa429b30 100644 --- a/datafusion/expr/src/logical_plan/plan.rs +++ b/datafusion/expr/src/logical_plan/plan.rs @@ -295,10 +295,14 @@ pub enum LogicalPlan { #[derive(Debug, Clone, PartialEq, Eq, Hash)] pub struct DependentJoin { pub schema: DFSchemaRef, - // all the columns provided by the LHS being referenced - // in the RHS (and its children nested subqueries, if any) (note that not all outer_refs from the RHS are mentioned in this vectors - // because RHS may reference columns provided somewhere from the above join) - pub correlated_columns: Vec, + // All combinatoins of (subquery,OuterReferencedExpr) on the RHS (and its descendant) + // which points to a column on the LHS. + // The Expr should always be Expr::OuterRefColumn. + // Note that not all outer_refs from the RHS are mentioned in this vectors + // because RHS may reference columns provided somewhere from the above join. + // Depths of each correlated_columns should always be gte current dependent join + // subquery_depth + pub correlated_columns: Vec<(usize, Expr)>, // the upper expr that containing the subquery expr // i.e for predicates: where outer = scalar_sq + 1 // correlated exprs are `scalar_sq + 1` @@ -316,7 +320,7 @@ impl PartialOrd for DependentJoin { fn partial_cmp(&self, other: &Self) -> Option { #[derive(PartialEq, PartialOrd)] struct ComparableJoin<'a> { - correlated_columns: &'a Vec, + correlated_columns: &'a Vec<(usize, Expr)>, // the upper expr that containing the subquery expr // i.e for predicates: where outer = scalar_sq + 1 // correlated exprs are `scalar_sq + 1` @@ -1954,8 +1958,12 @@ impl LogicalPlan { subquery_depth, .. }) => { - let correlated_str = correlated_columns.iter().map(|c|{ - format!("{c}") + let correlated_str = correlated_columns.iter() + .map(|(level,c)|{ + if let Expr::OuterReferenceColumn(_, ref col) = c{ + return format!("{col} lvl {level}"); + } + "".to_string() }).collect::>().join(", "); write!(f,"DependentJoin on [{correlated_str}] with expr {subquery_expr} depth {subquery_depth}") }, diff --git a/datafusion/optimizer/src/.decorrelate_general.rs.pending-snap b/datafusion/optimizer/src/.decorrelate_general.rs.pending-snap new file mode 100644 index 000000000000..08b734d7cd1b --- /dev/null +++ b/datafusion/optimizer/src/.decorrelate_general.rs.pending-snap @@ -0,0 +1,7 @@ +{"run_id":"1748236183-681925417","line":639,"new":{"module_name":"datafusion_optimizer__decorrelate_general__tests","snapshot_name":"rewrite_dependent_join_two_nested_subqueries","metadata":{"source":"datafusion/optimizer/src/decorrelate_general.rs","assertion_line":639,"expression":"display"},"snapshot":"Projection: outer_table.a, outer_table.b, outer_table.c [a:UInt32, b:UInt32, c:UInt32]\n Filter: outer_table.a > Int32(1) AND __scalar_sq_2.output = outer_table.a [a:UInt32, b:UInt32, c:UInt32, output:Int64]\n DependentJoin on [outer_table.a lvl 1, outer_table.c lvl 1] with expr () depth 1 [a:UInt32, b:UInt32, c:UInt32, output:Int64]\n TableScan: outer_table [a:UInt32, b:UInt32, c:UInt32]\n Aggregate: groupBy=[[]], aggr=[[count(inner_table_lv1.a)]] [count(inner_table_lv1.a):Int64]\n Projection: inner_table_lv1.a, inner_table_lv1.b, inner_table_lv1.c [a:UInt32, b:UInt32, c:UInt32]\n Filter: inner_table_lv1.c = outer_ref(outer_table.c) AND __scalar_sq_1.output = Int32(1) [a:UInt32, b:UInt32, c:UInt32, output:Int64]\n DependentJoin on [inner_table_lv1.b lvl 2] with expr () depth 2 [a:UInt32, b:UInt32, c:UInt32, output:Int64]\n TableScan: inner_table_lv1 
[a:UInt32, b:UInt32, c:UInt32]\n Aggregate: groupBy=[[]], aggr=[[count(inner_table_lv2.a)]] [count(inner_table_lv2.a):Int64]\n Filter: inner_table_lv2.a = outer_ref(outer_table.a) AND inner_table_lv2.b = outer_ref(inner_table_lv1.b) [a:UInt32, b:UInt32, c:UInt32]\n TableScan: inner_table_lv2 [a:UInt32, b:UInt32, c:UInt32]"},"old":{"module_name":"datafusion_optimizer__decorrelate_general__tests","metadata":{},"snapshot":"Projection: outer_table.a, outer_table.b, outer_table.c [a:UInt32, b:UInt32, c:UInt32]\n Filter: outer_table.a > Int32(1) AND __scalar_sq_2.output = outer_table.a [a:UInt32, b:UInt32, c:UInt32, output:Int64]\n DependentJoin on [outer_table.a lv 2, outer_table.c lv1 1] with expr () depth 1 [a:UInt32, b:UInt32, c:UInt32, output:Int64]\n TableScan: outer_table [a:UInt32, b:UInt32, c:UInt32]\n Aggregate: groupBy=[[]], aggr=[[count(inner_table_lv1.a)]] [count(inner_table_lv1.a):Int64]\n Projection: inner_table_lv1.a, inner_table_lv1.b, inner_table_lv1.c [a:UInt32, b:UInt32, c:UInt32]\n Filter: inner_table_lv1.c = outer_ref(outer_table.c) AND __scalar_sq_1.output = Int32(1) [a:UInt32, b:UInt32, c:UInt32, output:Int64]\n DependentJoin on [inner_table_lv1.b lv2] with expr () depth 2 [a:UInt32, b:UInt32, c:UInt32, output:Int64]\n TableScan: inner_table_lv1 [a:UInt32, b:UInt32, c:UInt32]\n Aggregate: groupBy=[[]], aggr=[[count(inner_table_lv2.a)]] [count(inner_table_lv2.a):Int64]\n Filter: inner_table_lv2.a = outer_ref(outer_table.a) AND inner_table_lv2.b = outer_ref(inner_table_lv1.b) [a:UInt32, b:UInt32, c:UInt32]\n TableScan: inner_table_lv2 [a:UInt32, b:UInt32, c:UInt32]"}} +{"run_id":"1748236309-277591598","line":639,"new":{"module_name":"datafusion_optimizer__decorrelate_general__tests","snapshot_name":"rewrite_dependent_join_two_nested_subqueries","metadata":{"source":"datafusion/optimizer/src/decorrelate_general.rs","assertion_line":639,"expression":"display"},"snapshot":"Projection: outer_table.a, outer_table.b, outer_table.c [a:UInt32, b:UInt32, c:UInt32]\n Filter: outer_table.a > Int32(1) AND __scalar_sq_2.output = outer_table.a [a:UInt32, b:UInt32, c:UInt32, output:Int64]\n DependentJoin on [outer_table.a lvl 2, outer_table.c lvl 1] with expr () depth 1 [a:UInt32, b:UInt32, c:UInt32, output:Int64]\n TableScan: outer_table [a:UInt32, b:UInt32, c:UInt32]\n Aggregate: groupBy=[[]], aggr=[[count(inner_table_lv1.a)]] [count(inner_table_lv1.a):Int64]\n Projection: inner_table_lv1.a, inner_table_lv1.b, inner_table_lv1.c [a:UInt32, b:UInt32, c:UInt32]\n Filter: inner_table_lv1.c = outer_ref(outer_table.c) AND __scalar_sq_1.output = Int32(1) [a:UInt32, b:UInt32, c:UInt32, output:Int64]\n DependentJoin on [inner_table_lv1.b lvl 2] with expr () depth 2 [a:UInt32, b:UInt32, c:UInt32, output:Int64]\n TableScan: inner_table_lv1 [a:UInt32, b:UInt32, c:UInt32]\n Aggregate: groupBy=[[]], aggr=[[count(inner_table_lv2.a)]] [count(inner_table_lv2.a):Int64]\n Filter: inner_table_lv2.a = outer_ref(outer_table.a) AND inner_table_lv2.b = outer_ref(inner_table_lv1.b) [a:UInt32, b:UInt32, c:UInt32]\n TableScan: inner_table_lv2 [a:UInt32, b:UInt32, c:UInt32]"},"old":{"module_name":"datafusion_optimizer__decorrelate_general__tests","metadata":{},"snapshot":"Projection: outer_table.a, outer_table.b, outer_table.c [a:UInt32, b:UInt32, c:UInt32]\n Filter: outer_table.a > Int32(1) AND __scalar_sq_2.output = outer_table.a [a:UInt32, b:UInt32, c:UInt32, output:Int64]\n DependentJoin on [outer_table.a lvl 1, outer_table.c lv1 1] with expr () depth 1 [a:UInt32, b:UInt32, c:UInt32, output:Int64]\n TableScan: 
outer_table [a:UInt32, b:UInt32, c:UInt32]\n Aggregate: groupBy=[[]], aggr=[[count(inner_table_lv1.a)]] [count(inner_table_lv1.a):Int64]\n Projection: inner_table_lv1.a, inner_table_lv1.b, inner_table_lv1.c [a:UInt32, b:UInt32, c:UInt32]\n Filter: inner_table_lv1.c = outer_ref(outer_table.c) AND __scalar_sq_1.output = Int32(1) [a:UInt32, b:UInt32, c:UInt32, output:Int64]\n DependentJoin on [inner_table_lv1.b lv2] with expr () depth 2 [a:UInt32, b:UInt32, c:UInt32, output:Int64]\n TableScan: inner_table_lv1 [a:UInt32, b:UInt32, c:UInt32]\n Aggregate: groupBy=[[]], aggr=[[count(inner_table_lv2.a)]] [count(inner_table_lv2.a):Int64]\n Filter: inner_table_lv2.a = outer_ref(outer_table.a) AND inner_table_lv2.b = outer_ref(inner_table_lv1.b) [a:UInt32, b:UInt32, c:UInt32]\n TableScan: inner_table_lv2 [a:UInt32, b:UInt32, c:UInt32]"}} +{"run_id":"1748236350-215042560","line":639,"new":{"module_name":"datafusion_optimizer__decorrelate_general__tests","snapshot_name":"rewrite_dependent_join_two_nested_subqueries","metadata":{"source":"datafusion/optimizer/src/decorrelate_general.rs","assertion_line":639,"expression":"display"},"snapshot":"Projection: outer_table.a, outer_table.b, outer_table.c [a:UInt32, b:UInt32, c:UInt32]\n Filter: outer_table.a > Int32(1) AND __scalar_sq_2.output = outer_table.a [a:UInt32, b:UInt32, c:UInt32, output:Int64]\n DependentJoin on [outer_table.a lvl 2, outer_table.c lvl 1] with expr () depth 1 [a:UInt32, b:UInt32, c:UInt32, output:Int64]\n TableScan: outer_table [a:UInt32, b:UInt32, c:UInt32]\n Aggregate: groupBy=[[]], aggr=[[count(inner_table_lv1.a)]] [count(inner_table_lv1.a):Int64]\n Projection: inner_table_lv1.a, inner_table_lv1.b, inner_table_lv1.c [a:UInt32, b:UInt32, c:UInt32]\n Filter: inner_table_lv1.c = outer_ref(outer_table.c) AND __scalar_sq_1.output = Int32(1) [a:UInt32, b:UInt32, c:UInt32, output:Int64]\n DependentJoin on [inner_table_lv1.b lvl 2] with expr () depth 2 [a:UInt32, b:UInt32, c:UInt32, output:Int64]\n TableScan: inner_table_lv1 [a:UInt32, b:UInt32, c:UInt32]\n Aggregate: groupBy=[[]], aggr=[[count(inner_table_lv2.a)]] [count(inner_table_lv2.a):Int64]\n Filter: inner_table_lv2.a = outer_ref(outer_table.a) AND inner_table_lv2.b = outer_ref(inner_table_lv1.b) [a:UInt32, b:UInt32, c:UInt32]\n TableScan: inner_table_lv2 [a:UInt32, b:UInt32, c:UInt32]"},"old":{"module_name":"datafusion_optimizer__decorrelate_general__tests","metadata":{},"snapshot":"Projection: outer_table.a, outer_table.b, outer_table.c [a:UInt32, b:UInt32, c:UInt32]\n Filter: outer_table.a > Int32(1) AND __scalar_sq_2.output = outer_table.a [a:UInt32, b:UInt32, c:UInt32, output:Int64]\n DependentJoin on [outer_table.a lvl 2, outer_table.c lv1 1] with expr () depth 1 [a:UInt32, b:UInt32, c:UInt32, output:Int64]\n TableScan: outer_table [a:UInt32, b:UInt32, c:UInt32]\n Aggregate: groupBy=[[]], aggr=[[count(inner_table_lv1.a)]] [count(inner_table_lv1.a):Int64]\n Projection: inner_table_lv1.a, inner_table_lv1.b, inner_table_lv1.c [a:UInt32, b:UInt32, c:UInt32]\n Filter: inner_table_lv1.c = outer_ref(outer_table.c) AND __scalar_sq_1.output = Int32(1) [a:UInt32, b:UInt32, c:UInt32, output:Int64]\n DependentJoin on [inner_table_lv1.b lvl 2] with expr () depth 2 [a:UInt32, b:UInt32, c:UInt32, output:Int64]\n TableScan: inner_table_lv1 [a:UInt32, b:UInt32, c:UInt32]\n Aggregate: groupBy=[[]], aggr=[[count(inner_table_lv2.a)]] [count(inner_table_lv2.a):Int64]\n Filter: inner_table_lv2.a = outer_ref(outer_table.a) AND inner_table_lv2.b = outer_ref(inner_table_lv1.b) [a:UInt32, b:UInt32, 
c:UInt32]\n TableScan: inner_table_lv2 [a:UInt32, b:UInt32, c:UInt32]"}} +{"run_id":"1748236388-546462772","line":639,"new":null,"old":null} +{"run_id":"1748236393-584012208","line":727,"new":null,"old":null} +{"run_id":"1748236398-329271850","line":766,"new":{"module_name":"datafusion_optimizer__decorrelate_general__tests","snapshot_name":"rewrite_dependent_join_exist_subquery_with_dependent_columns","metadata":{"source":"datafusion/optimizer/src/decorrelate_general.rs","assertion_line":766,"expression":"display"},"snapshot":"Projection: outer_table.a, outer_table.b, outer_table.c [a:UInt32, b:UInt32, c:UInt32]\n Filter: outer_table.a > Int32(1) AND __exists_sq_1.output [a:UInt32, b:UInt32, c:UInt32, output:Boolean]\n DependentJoin on [outer_table.a lvl 1, outer_table.b lvl 1] with expr EXISTS () depth 1 [a:UInt32, b:UInt32, c:UInt32, output:Boolean]\n TableScan: outer_table [a:UInt32, b:UInt32, c:UInt32]\n Projection: outer_ref(outer_table.b) AS outer_b_alias [outer_b_alias:UInt32;N]\n Filter: inner_table_lv1.a = outer_ref(outer_table.a) AND outer_ref(outer_table.a) > inner_table_lv1.c AND inner_table_lv1.b = Int32(1) AND outer_ref(outer_table.b) = inner_table_lv1.b [a:UInt32, b:UInt32, c:UInt32]\n TableScan: inner_table_lv1 [a:UInt32, b:UInt32, c:UInt32]"},"old":{"module_name":"datafusion_optimizer__decorrelate_general__tests","metadata":{},"snapshot":"Projection: outer_table.a, outer_table.b, outer_table.c [a:UInt32, b:UInt32, c:UInt32]\n Filter: outer_table.a > Int32(1) AND __exists_sq_1.output [a:UInt32, b:UInt32, c:UInt32, output:Boolean]\n DependentJoin on [outer_table.a lv1, outer_table.b lv1] with expr EXISTS () depth 1 [a:UInt32, b:UInt32, c:UInt32, output:Boolean]\n TableScan: outer_table [a:UInt32, b:UInt32, c:UInt32]\n Projection: outer_ref(outer_table.b) AS outer_b_alias [outer_b_alias:UInt32;N]\n Filter: inner_table_lv1.a = outer_ref(outer_table.a) AND outer_ref(outer_table.a) > inner_table_lv1.c AND inner_table_lv1.b = Int32(1) AND outer_ref(outer_table.b) = inner_table_lv1.b [a:UInt32, b:UInt32, c:UInt32]\n TableScan: inner_table_lv1 [a:UInt32, b:UInt32, c:UInt32]"}} +{"run_id":"1748236438-51540154","line":766,"new":null,"old":null} diff --git a/datafusion/optimizer/src/decorrelate_general.rs b/datafusion/optimizer/src/decorrelate_general.rs index d6ec2125139b..eb53c582b8a6 100644 --- a/datafusion/optimizer/src/decorrelate_general.rs +++ b/datafusion/optimizer/src/decorrelate_general.rs @@ -55,6 +55,7 @@ struct ColumnAccess { node_id: usize, col: Column, data_type: DataType, + subquery_depth: usize, } impl DependentJoinRewriter { @@ -121,6 +122,7 @@ impl DependentJoinRewriter { node_id: access.node_id, stack: access.stack.clone(), data_type: access.data_type.clone(), + subquery_depth: access.subquery_depth, }); } } @@ -142,6 +144,7 @@ impl DependentJoinRewriter { node_id: child_id, col: col.clone(), data_type: data_type.clone(), + subquery_depth: self.subquery_depth, }); } fn rewrite_subqueries_into_dependent_joins( @@ -418,6 +421,12 @@ impl TreeNodeRewriter for DependentJoinRewriter { LogicalPlan::Projection(_) => { // TODO: implement me } + LogicalPlan::SubqueryAlias(_) => { + unimplemented!( + "handle the case when the LHS has alias\ + and the RHS's subquery reference the alias column name" + ) + } LogicalPlan::Filter(filter) => { // everytime we meet a subquery during traversal, we increment this by 1 // we can use this offset to lookup the original subquery info @@ -479,7 +488,15 @@ impl TreeNodeRewriter for DependentJoinRewriter { let correlated_columns = 
column_accesses .iter() - .map(|ac| (ac.col.clone())) + .map(|ac| { + ( + ac.subquery_depth, + Expr::OuterReferenceColumn( + ac.data_type.clone(), + ac.col.clone(), + ), + ) + }) .unique() .collect(); @@ -622,12 +639,12 @@ mod tests { assert_dependent_join_rewrite!(plan, @r" Projection: outer_table.a, outer_table.b, outer_table.c [a:UInt32, b:UInt32, c:UInt32] Filter: outer_table.a > Int32(1) AND __scalar_sq_2.output = outer_table.a [a:UInt32, b:UInt32, c:UInt32, output:Int64] - DependentJoin on [outer_table.a, outer_table.c] with expr () depth 1 [a:UInt32, b:UInt32, c:UInt32, output:Int64] + DependentJoin on [outer_table.a lvl 2, outer_table.c lvl 1] with expr () depth 1 [a:UInt32, b:UInt32, c:UInt32, output:Int64] TableScan: outer_table [a:UInt32, b:UInt32, c:UInt32] Aggregate: groupBy=[[]], aggr=[[count(inner_table_lv1.a)]] [count(inner_table_lv1.a):Int64] Projection: inner_table_lv1.a, inner_table_lv1.b, inner_table_lv1.c [a:UInt32, b:UInt32, c:UInt32] Filter: inner_table_lv1.c = outer_ref(outer_table.c) AND __scalar_sq_1.output = Int32(1) [a:UInt32, b:UInt32, c:UInt32, output:Int64] - DependentJoin on [inner_table_lv1.b] with expr () depth 2 [a:UInt32, b:UInt32, c:UInt32, output:Int64] + DependentJoin on [inner_table_lv1.b lvl 2] with expr () depth 2 [a:UInt32, b:UInt32, c:UInt32, output:Int64] TableScan: inner_table_lv1 [a:UInt32, b:UInt32, c:UInt32] Aggregate: groupBy=[[]], aggr=[[count(inner_table_lv2.a)]] [count(inner_table_lv2.a):Int64] Filter: inner_table_lv2.a = outer_ref(outer_table.a) AND inner_table_lv2.b = outer_ref(inner_table_lv1.b) [a:UInt32, b:UInt32, c:UInt32] @@ -710,7 +727,7 @@ Projection: outer_table.a, outer_table.b, outer_table.c [a:UInt32, b:UInt32, c:U assert_dependent_join_rewrite!(plan, @r" Projection: outer_table.a, outer_table.b, outer_table.c [a:UInt32, b:UInt32, c:UInt32] Filter: outer_table.a > Int32(1) AND __in_sq_1.output [a:UInt32, b:UInt32, c:UInt32, output:Boolean] - DependentJoin on [outer_table.a, outer_table.b] with expr outer_table.c IN () depth 1 [a:UInt32, b:UInt32, c:UInt32, output:Boolean] + DependentJoin on [outer_table.a lvl 1, outer_table.b lvl 1] with expr outer_table.c IN () depth 1 [a:UInt32, b:UInt32, c:UInt32, output:Boolean] TableScan: outer_table [a:UInt32, b:UInt32, c:UInt32] Projection: count(inner_table_lv1.a) AS count_a [count_a:Int64] Aggregate: groupBy=[[]], aggr=[[count(inner_table_lv1.a)]] [count(inner_table_lv1.a):Int64] @@ -749,7 +766,7 @@ Projection: outer_table.a, outer_table.b, outer_table.c [a:UInt32, b:UInt32, c:U assert_dependent_join_rewrite!(plan, @r" Projection: outer_table.a, outer_table.b, outer_table.c [a:UInt32, b:UInt32, c:UInt32] Filter: outer_table.a > Int32(1) AND __exists_sq_1.output [a:UInt32, b:UInt32, c:UInt32, output:Boolean] - DependentJoin on [outer_table.a, outer_table.b] with expr EXISTS () depth 1 [a:UInt32, b:UInt32, c:UInt32, output:Boolean] + DependentJoin on [outer_table.a lvl 1, outer_table.b lvl 1] with expr EXISTS () depth 1 [a:UInt32, b:UInt32, c:UInt32, output:Boolean] TableScan: outer_table [a:UInt32, b:UInt32, c:UInt32] Projection: outer_ref(outer_table.b) AS outer_b_alias [outer_b_alias:UInt32;N] Filter: inner_table_lv1.a = outer_ref(outer_table.a) AND outer_ref(outer_table.a) > inner_table_lv1.c AND inner_table_lv1.b = Int32(1) AND outer_ref(outer_table.b) = inner_table_lv1.b [a:UInt32, b:UInt32, c:UInt32] @@ -850,7 +867,7 @@ Projection: outer_table.a, outer_table.b, outer_table.c [a:UInt32, b:UInt32, c:U assert_dependent_join_rewrite!(plan, @r" Projection: outer_table.a, 
outer_table.b, outer_table.c [a:UInt32, b:UInt32, c:UInt32] Filter: outer_table.a > Int32(1) AND __in_sq_1.output [a:UInt32, b:UInt32, c:UInt32, output:Boolean] - DependentJoin on [outer_table.a, outer_table.b] with expr outer_table.c IN () depth 1 [a:UInt32, b:UInt32, c:UInt32, output:Boolean] + DependentJoin on [outer_table.a lvl 1, outer_table.b lvl 1] with expr outer_table.c IN () depth 1 [a:UInt32, b:UInt32, c:UInt32, output:Boolean] TableScan: outer_table [a:UInt32, b:UInt32, c:UInt32] Projection: outer_ref(outer_table.b) AS outer_b_alias [outer_b_alias:UInt32;N] Filter: inner_table_lv1.a = outer_ref(outer_table.a) AND outer_ref(outer_table.a) > inner_table_lv1.c AND inner_table_lv1.b = Int32(1) AND outer_ref(outer_table.b) = inner_table_lv1.b [a:UInt32, b:UInt32, c:UInt32] From 50d26f3446c889938c1c952815f6d2e7104c8438 Mon Sep 17 00:00:00 2001 From: Duong Cong Toai Date: Mon, 26 May 2025 09:13:41 +0200 Subject: [PATCH 43/70] chore: rm snapshot --- .../optimizer/src/.decorrelate_general.rs.pending-snap | 7 ------- 1 file changed, 7 deletions(-) delete mode 100644 datafusion/optimizer/src/.decorrelate_general.rs.pending-snap diff --git a/datafusion/optimizer/src/.decorrelate_general.rs.pending-snap b/datafusion/optimizer/src/.decorrelate_general.rs.pending-snap deleted file mode 100644 index 08b734d7cd1b..000000000000 --- a/datafusion/optimizer/src/.decorrelate_general.rs.pending-snap +++ /dev/null @@ -1,7 +0,0 @@ -{"run_id":"1748236183-681925417","line":639,"new":{"module_name":"datafusion_optimizer__decorrelate_general__tests","snapshot_name":"rewrite_dependent_join_two_nested_subqueries","metadata":{"source":"datafusion/optimizer/src/decorrelate_general.rs","assertion_line":639,"expression":"display"},"snapshot":"Projection: outer_table.a, outer_table.b, outer_table.c [a:UInt32, b:UInt32, c:UInt32]\n Filter: outer_table.a > Int32(1) AND __scalar_sq_2.output = outer_table.a [a:UInt32, b:UInt32, c:UInt32, output:Int64]\n DependentJoin on [outer_table.a lvl 1, outer_table.c lvl 1] with expr () depth 1 [a:UInt32, b:UInt32, c:UInt32, output:Int64]\n TableScan: outer_table [a:UInt32, b:UInt32, c:UInt32]\n Aggregate: groupBy=[[]], aggr=[[count(inner_table_lv1.a)]] [count(inner_table_lv1.a):Int64]\n Projection: inner_table_lv1.a, inner_table_lv1.b, inner_table_lv1.c [a:UInt32, b:UInt32, c:UInt32]\n Filter: inner_table_lv1.c = outer_ref(outer_table.c) AND __scalar_sq_1.output = Int32(1) [a:UInt32, b:UInt32, c:UInt32, output:Int64]\n DependentJoin on [inner_table_lv1.b lvl 2] with expr () depth 2 [a:UInt32, b:UInt32, c:UInt32, output:Int64]\n TableScan: inner_table_lv1 [a:UInt32, b:UInt32, c:UInt32]\n Aggregate: groupBy=[[]], aggr=[[count(inner_table_lv2.a)]] [count(inner_table_lv2.a):Int64]\n Filter: inner_table_lv2.a = outer_ref(outer_table.a) AND inner_table_lv2.b = outer_ref(inner_table_lv1.b) [a:UInt32, b:UInt32, c:UInt32]\n TableScan: inner_table_lv2 [a:UInt32, b:UInt32, c:UInt32]"},"old":{"module_name":"datafusion_optimizer__decorrelate_general__tests","metadata":{},"snapshot":"Projection: outer_table.a, outer_table.b, outer_table.c [a:UInt32, b:UInt32, c:UInt32]\n Filter: outer_table.a > Int32(1) AND __scalar_sq_2.output = outer_table.a [a:UInt32, b:UInt32, c:UInt32, output:Int64]\n DependentJoin on [outer_table.a lv 2, outer_table.c lv1 1] with expr () depth 1 [a:UInt32, b:UInt32, c:UInt32, output:Int64]\n TableScan: outer_table [a:UInt32, b:UInt32, c:UInt32]\n Aggregate: groupBy=[[]], aggr=[[count(inner_table_lv1.a)]] [count(inner_table_lv1.a):Int64]\n Projection: inner_table_lv1.a, 
inner_table_lv1.b, inner_table_lv1.c [a:UInt32, b:UInt32, c:UInt32]\n Filter: inner_table_lv1.c = outer_ref(outer_table.c) AND __scalar_sq_1.output = Int32(1) [a:UInt32, b:UInt32, c:UInt32, output:Int64]\n DependentJoin on [inner_table_lv1.b lv2] with expr () depth 2 [a:UInt32, b:UInt32, c:UInt32, output:Int64]\n TableScan: inner_table_lv1 [a:UInt32, b:UInt32, c:UInt32]\n Aggregate: groupBy=[[]], aggr=[[count(inner_table_lv2.a)]] [count(inner_table_lv2.a):Int64]\n Filter: inner_table_lv2.a = outer_ref(outer_table.a) AND inner_table_lv2.b = outer_ref(inner_table_lv1.b) [a:UInt32, b:UInt32, c:UInt32]\n TableScan: inner_table_lv2 [a:UInt32, b:UInt32, c:UInt32]"}} -{"run_id":"1748236309-277591598","line":639,"new":{"module_name":"datafusion_optimizer__decorrelate_general__tests","snapshot_name":"rewrite_dependent_join_two_nested_subqueries","metadata":{"source":"datafusion/optimizer/src/decorrelate_general.rs","assertion_line":639,"expression":"display"},"snapshot":"Projection: outer_table.a, outer_table.b, outer_table.c [a:UInt32, b:UInt32, c:UInt32]\n Filter: outer_table.a > Int32(1) AND __scalar_sq_2.output = outer_table.a [a:UInt32, b:UInt32, c:UInt32, output:Int64]\n DependentJoin on [outer_table.a lvl 2, outer_table.c lvl 1] with expr () depth 1 [a:UInt32, b:UInt32, c:UInt32, output:Int64]\n TableScan: outer_table [a:UInt32, b:UInt32, c:UInt32]\n Aggregate: groupBy=[[]], aggr=[[count(inner_table_lv1.a)]] [count(inner_table_lv1.a):Int64]\n Projection: inner_table_lv1.a, inner_table_lv1.b, inner_table_lv1.c [a:UInt32, b:UInt32, c:UInt32]\n Filter: inner_table_lv1.c = outer_ref(outer_table.c) AND __scalar_sq_1.output = Int32(1) [a:UInt32, b:UInt32, c:UInt32, output:Int64]\n DependentJoin on [inner_table_lv1.b lvl 2] with expr () depth 2 [a:UInt32, b:UInt32, c:UInt32, output:Int64]\n TableScan: inner_table_lv1 [a:UInt32, b:UInt32, c:UInt32]\n Aggregate: groupBy=[[]], aggr=[[count(inner_table_lv2.a)]] [count(inner_table_lv2.a):Int64]\n Filter: inner_table_lv2.a = outer_ref(outer_table.a) AND inner_table_lv2.b = outer_ref(inner_table_lv1.b) [a:UInt32, b:UInt32, c:UInt32]\n TableScan: inner_table_lv2 [a:UInt32, b:UInt32, c:UInt32]"},"old":{"module_name":"datafusion_optimizer__decorrelate_general__tests","metadata":{},"snapshot":"Projection: outer_table.a, outer_table.b, outer_table.c [a:UInt32, b:UInt32, c:UInt32]\n Filter: outer_table.a > Int32(1) AND __scalar_sq_2.output = outer_table.a [a:UInt32, b:UInt32, c:UInt32, output:Int64]\n DependentJoin on [outer_table.a lvl 1, outer_table.c lv1 1] with expr () depth 1 [a:UInt32, b:UInt32, c:UInt32, output:Int64]\n TableScan: outer_table [a:UInt32, b:UInt32, c:UInt32]\n Aggregate: groupBy=[[]], aggr=[[count(inner_table_lv1.a)]] [count(inner_table_lv1.a):Int64]\n Projection: inner_table_lv1.a, inner_table_lv1.b, inner_table_lv1.c [a:UInt32, b:UInt32, c:UInt32]\n Filter: inner_table_lv1.c = outer_ref(outer_table.c) AND __scalar_sq_1.output = Int32(1) [a:UInt32, b:UInt32, c:UInt32, output:Int64]\n DependentJoin on [inner_table_lv1.b lv2] with expr () depth 2 [a:UInt32, b:UInt32, c:UInt32, output:Int64]\n TableScan: inner_table_lv1 [a:UInt32, b:UInt32, c:UInt32]\n Aggregate: groupBy=[[]], aggr=[[count(inner_table_lv2.a)]] [count(inner_table_lv2.a):Int64]\n Filter: inner_table_lv2.a = outer_ref(outer_table.a) AND inner_table_lv2.b = outer_ref(inner_table_lv1.b) [a:UInt32, b:UInt32, c:UInt32]\n TableScan: inner_table_lv2 [a:UInt32, b:UInt32, c:UInt32]"}} 
-{"run_id":"1748236350-215042560","line":639,"new":{"module_name":"datafusion_optimizer__decorrelate_general__tests","snapshot_name":"rewrite_dependent_join_two_nested_subqueries","metadata":{"source":"datafusion/optimizer/src/decorrelate_general.rs","assertion_line":639,"expression":"display"},"snapshot":"Projection: outer_table.a, outer_table.b, outer_table.c [a:UInt32, b:UInt32, c:UInt32]\n Filter: outer_table.a > Int32(1) AND __scalar_sq_2.output = outer_table.a [a:UInt32, b:UInt32, c:UInt32, output:Int64]\n DependentJoin on [outer_table.a lvl 2, outer_table.c lvl 1] with expr () depth 1 [a:UInt32, b:UInt32, c:UInt32, output:Int64]\n TableScan: outer_table [a:UInt32, b:UInt32, c:UInt32]\n Aggregate: groupBy=[[]], aggr=[[count(inner_table_lv1.a)]] [count(inner_table_lv1.a):Int64]\n Projection: inner_table_lv1.a, inner_table_lv1.b, inner_table_lv1.c [a:UInt32, b:UInt32, c:UInt32]\n Filter: inner_table_lv1.c = outer_ref(outer_table.c) AND __scalar_sq_1.output = Int32(1) [a:UInt32, b:UInt32, c:UInt32, output:Int64]\n DependentJoin on [inner_table_lv1.b lvl 2] with expr () depth 2 [a:UInt32, b:UInt32, c:UInt32, output:Int64]\n TableScan: inner_table_lv1 [a:UInt32, b:UInt32, c:UInt32]\n Aggregate: groupBy=[[]], aggr=[[count(inner_table_lv2.a)]] [count(inner_table_lv2.a):Int64]\n Filter: inner_table_lv2.a = outer_ref(outer_table.a) AND inner_table_lv2.b = outer_ref(inner_table_lv1.b) [a:UInt32, b:UInt32, c:UInt32]\n TableScan: inner_table_lv2 [a:UInt32, b:UInt32, c:UInt32]"},"old":{"module_name":"datafusion_optimizer__decorrelate_general__tests","metadata":{},"snapshot":"Projection: outer_table.a, outer_table.b, outer_table.c [a:UInt32, b:UInt32, c:UInt32]\n Filter: outer_table.a > Int32(1) AND __scalar_sq_2.output = outer_table.a [a:UInt32, b:UInt32, c:UInt32, output:Int64]\n DependentJoin on [outer_table.a lvl 2, outer_table.c lv1 1] with expr () depth 1 [a:UInt32, b:UInt32, c:UInt32, output:Int64]\n TableScan: outer_table [a:UInt32, b:UInt32, c:UInt32]\n Aggregate: groupBy=[[]], aggr=[[count(inner_table_lv1.a)]] [count(inner_table_lv1.a):Int64]\n Projection: inner_table_lv1.a, inner_table_lv1.b, inner_table_lv1.c [a:UInt32, b:UInt32, c:UInt32]\n Filter: inner_table_lv1.c = outer_ref(outer_table.c) AND __scalar_sq_1.output = Int32(1) [a:UInt32, b:UInt32, c:UInt32, output:Int64]\n DependentJoin on [inner_table_lv1.b lvl 2] with expr () depth 2 [a:UInt32, b:UInt32, c:UInt32, output:Int64]\n TableScan: inner_table_lv1 [a:UInt32, b:UInt32, c:UInt32]\n Aggregate: groupBy=[[]], aggr=[[count(inner_table_lv2.a)]] [count(inner_table_lv2.a):Int64]\n Filter: inner_table_lv2.a = outer_ref(outer_table.a) AND inner_table_lv2.b = outer_ref(inner_table_lv1.b) [a:UInt32, b:UInt32, c:UInt32]\n TableScan: inner_table_lv2 [a:UInt32, b:UInt32, c:UInt32]"}} -{"run_id":"1748236388-546462772","line":639,"new":null,"old":null} -{"run_id":"1748236393-584012208","line":727,"new":null,"old":null} -{"run_id":"1748236398-329271850","line":766,"new":{"module_name":"datafusion_optimizer__decorrelate_general__tests","snapshot_name":"rewrite_dependent_join_exist_subquery_with_dependent_columns","metadata":{"source":"datafusion/optimizer/src/decorrelate_general.rs","assertion_line":766,"expression":"display"},"snapshot":"Projection: outer_table.a, outer_table.b, outer_table.c [a:UInt32, b:UInt32, c:UInt32]\n Filter: outer_table.a > Int32(1) AND __exists_sq_1.output [a:UInt32, b:UInt32, c:UInt32, output:Boolean]\n DependentJoin on [outer_table.a lvl 1, outer_table.b lvl 1] with expr EXISTS () depth 1 [a:UInt32, b:UInt32, 
c:UInt32, output:Boolean]\n TableScan: outer_table [a:UInt32, b:UInt32, c:UInt32]\n Projection: outer_ref(outer_table.b) AS outer_b_alias [outer_b_alias:UInt32;N]\n Filter: inner_table_lv1.a = outer_ref(outer_table.a) AND outer_ref(outer_table.a) > inner_table_lv1.c AND inner_table_lv1.b = Int32(1) AND outer_ref(outer_table.b) = inner_table_lv1.b [a:UInt32, b:UInt32, c:UInt32]\n TableScan: inner_table_lv1 [a:UInt32, b:UInt32, c:UInt32]"},"old":{"module_name":"datafusion_optimizer__decorrelate_general__tests","metadata":{},"snapshot":"Projection: outer_table.a, outer_table.b, outer_table.c [a:UInt32, b:UInt32, c:UInt32]\n Filter: outer_table.a > Int32(1) AND __exists_sq_1.output [a:UInt32, b:UInt32, c:UInt32, output:Boolean]\n DependentJoin on [outer_table.a lv1, outer_table.b lv1] with expr EXISTS () depth 1 [a:UInt32, b:UInt32, c:UInt32, output:Boolean]\n TableScan: outer_table [a:UInt32, b:UInt32, c:UInt32]\n Projection: outer_ref(outer_table.b) AS outer_b_alias [outer_b_alias:UInt32;N]\n Filter: inner_table_lv1.a = outer_ref(outer_table.a) AND outer_ref(outer_table.a) > inner_table_lv1.c AND inner_table_lv1.b = Int32(1) AND outer_ref(outer_table.b) = inner_table_lv1.b [a:UInt32, b:UInt32, c:UInt32]\n TableScan: inner_table_lv1 [a:UInt32, b:UInt32, c:UInt32]"}} -{"run_id":"1748236438-51540154","line":766,"new":null,"old":null} From 28dc7a4180b7341bd5a7cb4f2b7cb9d8ba762274 Mon Sep 17 00:00:00 2001 From: Duong Cong Toai Date: Mon, 26 May 2025 21:33:38 +0200 Subject: [PATCH 44/70] feat: support alias and join --- .../optimizer/src/decorrelate_general.rs | 123 +++++++++++++++--- 1 file changed, 102 insertions(+), 21 deletions(-) diff --git a/datafusion/optimizer/src/decorrelate_general.rs b/datafusion/optimizer/src/decorrelate_general.rs index eb53c582b8a6..5c280bea8988 100644 --- a/datafusion/optimizer/src/decorrelate_general.rs +++ b/datafusion/optimizer/src/decorrelate_general.rs @@ -279,6 +279,11 @@ impl TreeNodeRewriter for DependentJoinRewriter { // for each node, find which column it is accessing, which column it is providing // Set of columns current node access match &node { + LogicalPlan::SubqueryAlias(alias) => { + alias.schema.columns().iter().for_each(|col| { + self.conclude_lowest_dependent_join_node(new_id, col); + }); + } LogicalPlan::Filter(f) => { if contains_subquery(&f.predicate) { is_dependent_join_node = true; @@ -360,7 +365,7 @@ impl TreeNodeRewriter for DependentJoinRewriter { })?; } } - LogicalPlan::Aggregate(_) => {} + LogicalPlan::Aggregate(_) | LogicalPlan::Join(_) => {} _ => { return internal_err!("impl f_down for node type {:?}", node); } @@ -421,12 +426,6 @@ impl TreeNodeRewriter for DependentJoinRewriter { LogicalPlan::Projection(_) => { // TODO: implement me } - LogicalPlan::SubqueryAlias(_) => { - unimplemented!( - "handle the case when the LHS has alias\ - and the RHS's subquery reference the alias column name" - ) - } LogicalPlan::Filter(filter) => { // everytime we meet a subquery during traversal, we increment this by 1 // we can use this offset to lookup the original subquery info @@ -564,8 +563,8 @@ mod tests { use arrow::datatypes::DataType as ArrowDataType; use datafusion_common::{alias::AliasGenerator, Result}; use datafusion_expr::{ - exists, expr_fn::col, in_subquery, lit, out_ref_col, scalar_subquery, Expr, - LogicalPlanBuilder, + binary_expr, exists, expr_fn::col, in_subquery, lit, out_ref_col, + scalar_subquery, Expr, JoinType, LogicalPlanBuilder, }; use datafusion_functions_aggregate::count::count; use insta::assert_snapshot; @@ -590,6 +589,50 
@@ mod tests { fn rewrite_dependent_join_with_lateral_join() -> Result<()> { Ok(()) } + + #[test] + fn rewrite_dependent_join_with_lhs_as_a_join() -> Result<()> { + let outer_left_table = test_table_scan_with_name("outer_right_table")?; + let outer_right_table = test_table_scan_with_name("outer_left_table")?; + let inner_table_lv1 = test_table_scan_with_name("inner_table_lv1")?; + let sq_level1 = Arc::new( + LogicalPlanBuilder::from(inner_table_lv1) + .filter(col("inner_table_lv1.a").eq(binary_expr( + out_ref_col(ArrowDataType::UInt32, "outer_left_table.a"), + datafusion_expr::Operator::Plus, + out_ref_col(ArrowDataType::UInt32, "outer_right_table.a"), + )))? + .aggregate(Vec::::new(), vec![count(col("inner_table_lv1.a"))])? + .project(vec![count(col("inner_table_lv1.a")).alias("count_a")])? + .build()?, + ); + + let plan = LogicalPlanBuilder::from(outer_left_table.clone()) + .join_on( + outer_right_table, + JoinType::Left, + vec![col("outer_left_table.a").eq(col("outer_right_table.a"))], + )? + .filter( + col("outer_table.a") + .gt(lit(1)) + .and(in_subquery(col("outer_table.c"), sq_level1)), + )? + .build()?; + assert_dependent_join_rewrite!(plan, @r" + Projection: outer_right_table.a, outer_right_table.b, outer_right_table.c, outer_left_table.a, outer_left_table.b, outer_left_table.c [a:UInt32, b:UInt32, c:UInt32, a:UInt32;N, b:UInt32;N, c:UInt32;N] + Filter: outer_table.a > Int32(1) AND __in_sq_1.output [a:UInt32, b:UInt32, c:UInt32, a:UInt32;N, b:UInt32;N, c:UInt32;N, output:Boolean] + DependentJoin on [outer_right_table.a lvl 1, outer_left_table.a lvl 1] with expr outer_table.c IN () depth 1 [a:UInt32, b:UInt32, c:UInt32, a:UInt32;N, b:UInt32;N, c:UInt32;N, output:Boolean] + Left Join: Filter: outer_left_table.a = outer_right_table.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32;N, b:UInt32;N, c:UInt32;N] + TableScan: outer_right_table [a:UInt32, b:UInt32, c:UInt32] + TableScan: outer_left_table [a:UInt32, b:UInt32, c:UInt32] + Projection: count(inner_table_lv1.a) AS count_a [count_a:Int64] + Aggregate: groupBy=[[]], aggr=[[count(inner_table_lv1.a)]] [count(inner_table_lv1.a):Int64] + Filter: inner_table_lv1.a = outer_ref(outer_left_table.a) + outer_ref(outer_right_table.a) [a:UInt32, b:UInt32, c:UInt32] + TableScan: inner_table_lv1 [a:UInt32, b:UInt32, c:UInt32] + "); + Ok(()) + } #[test] fn rewrite_dependent_join_in_from_expr() -> Result<()> { Ok(()) @@ -598,6 +641,7 @@ mod tests { fn rewrite_dependent_join_inside_select_expr() -> Result<()> { Ok(()) } + #[test] fn rewrite_dependent_join_two_nested_subqueries() -> Result<()> { let outer_table = test_table_scan_with_name("outer_table")?; @@ -637,18 +681,18 @@ mod tests { )? 
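            // In the re-indented snapshot below, the depth-1 DependentJoin also lists
            // `outer_table.a lvl 2`: the level-2 subquery reaches past inner_table_lv1
            // to the outermost table, and that access is recorded on the outermost
            // dependent join together with its originating depth.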
.build()?; assert_dependent_join_rewrite!(plan, @r" -Projection: outer_table.a, outer_table.b, outer_table.c [a:UInt32, b:UInt32, c:UInt32] - Filter: outer_table.a > Int32(1) AND __scalar_sq_2.output = outer_table.a [a:UInt32, b:UInt32, c:UInt32, output:Int64] - DependentJoin on [outer_table.a lvl 2, outer_table.c lvl 1] with expr () depth 1 [a:UInt32, b:UInt32, c:UInt32, output:Int64] - TableScan: outer_table [a:UInt32, b:UInt32, c:UInt32] - Aggregate: groupBy=[[]], aggr=[[count(inner_table_lv1.a)]] [count(inner_table_lv1.a):Int64] - Projection: inner_table_lv1.a, inner_table_lv1.b, inner_table_lv1.c [a:UInt32, b:UInt32, c:UInt32] - Filter: inner_table_lv1.c = outer_ref(outer_table.c) AND __scalar_sq_1.output = Int32(1) [a:UInt32, b:UInt32, c:UInt32, output:Int64] - DependentJoin on [inner_table_lv1.b lvl 2] with expr () depth 2 [a:UInt32, b:UInt32, c:UInt32, output:Int64] - TableScan: inner_table_lv1 [a:UInt32, b:UInt32, c:UInt32] - Aggregate: groupBy=[[]], aggr=[[count(inner_table_lv2.a)]] [count(inner_table_lv2.a):Int64] - Filter: inner_table_lv2.a = outer_ref(outer_table.a) AND inner_table_lv2.b = outer_ref(inner_table_lv1.b) [a:UInt32, b:UInt32, c:UInt32] - TableScan: inner_table_lv2 [a:UInt32, b:UInt32, c:UInt32] + Projection: outer_table.a, outer_table.b, outer_table.c [a:UInt32, b:UInt32, c:UInt32] + Filter: outer_table.a > Int32(1) AND __scalar_sq_2.output = outer_table.a [a:UInt32, b:UInt32, c:UInt32, output:Int64] + DependentJoin on [outer_table.a lvl 2, outer_table.c lvl 1] with expr () depth 1 [a:UInt32, b:UInt32, c:UInt32, output:Int64] + TableScan: outer_table [a:UInt32, b:UInt32, c:UInt32] + Aggregate: groupBy=[[]], aggr=[[count(inner_table_lv1.a)]] [count(inner_table_lv1.a):Int64] + Projection: inner_table_lv1.a, inner_table_lv1.b, inner_table_lv1.c [a:UInt32, b:UInt32, c:UInt32] + Filter: inner_table_lv1.c = outer_ref(outer_table.c) AND __scalar_sq_1.output = Int32(1) [a:UInt32, b:UInt32, c:UInt32, output:Int64] + DependentJoin on [inner_table_lv1.b lvl 2] with expr () depth 2 [a:UInt32, b:UInt32, c:UInt32, output:Int64] + TableScan: inner_table_lv1 [a:UInt32, b:UInt32, c:UInt32] + Aggregate: groupBy=[[]], aggr=[[count(inner_table_lv2.a)]] [count(inner_table_lv2.a):Int64] + Filter: inner_table_lv2.a = outer_ref(outer_table.a) AND inner_table_lv2.b = outer_ref(inner_table_lv1.b) [a:UInt32, b:UInt32, c:UInt32] + TableScan: inner_table_lv2 [a:UInt32, b:UInt32, c:UInt32] "); Ok(()) } @@ -875,4 +919,41 @@ Projection: outer_table.a, outer_table.b, outer_table.c [a:UInt32, b:UInt32, c:U "); Ok(()) } + + #[test] + fn rewrite_dependent_join_reference_outer_column_with_alias_name() -> Result<()> { + let outer_table = test_table_scan_with_name("outer_table")?; + let inner_table_lv1 = test_table_scan_with_name("inner_table_lv1")?; + let sq_level1 = Arc::new( + LogicalPlanBuilder::from(inner_table_lv1) + .filter( + col("inner_table_lv1.a") + .eq(out_ref_col(ArrowDataType::UInt32, "outer_table_alias.a")), + )? + .aggregate(Vec::::new(), vec![count(col("inner_table_lv1.a"))])? + .project(vec![count(col("inner_table_lv1.a")).alias("count_a")])? + .build()?, + ); + + let plan = LogicalPlanBuilder::from(outer_table.clone()) + .alias("outer_table_alias")? + .filter( + col("outer_table.a") + .gt(lit(1)) + .and(in_subquery(col("outer_table.c"), sq_level1)), + )? 
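            // Because the subquery references `outer_table_alias.a`, the SubqueryAlias
            // node (rather than the underlying TableScan) is the column provider that
            // concludes the dependent join node, which is why the alias is kept as the
            // left child of the DependentJoin in the snapshot below.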
+ .build()?; + assert_dependent_join_rewrite!(plan, @r" + Projection: outer_table_alias.a, outer_table_alias.b, outer_table_alias.c [a:UInt32, b:UInt32, c:UInt32] + Filter: outer_table.a > Int32(1) AND __in_sq_1.output [a:UInt32, b:UInt32, c:UInt32, output:Boolean] + DependentJoin on [outer_table_alias.a lvl 1] with expr outer_table.c IN () depth 1 [a:UInt32, b:UInt32, c:UInt32, output:Boolean] + SubqueryAlias: outer_table_alias [a:UInt32, b:UInt32, c:UInt32] + TableScan: outer_table [a:UInt32, b:UInt32, c:UInt32] + Projection: count(inner_table_lv1.a) AS count_a [count_a:Int64] + Aggregate: groupBy=[[]], aggr=[[count(inner_table_lv1.a)]] [count(inner_table_lv1.a):Int64] + Filter: inner_table_lv1.a = outer_ref(outer_table_alias.a) [a:UInt32, b:UInt32, c:UInt32] + TableScan: inner_table_lv1 [a:UInt32, b:UInt32, c:UInt32] + "); + Ok(()) + } } From cf830cbede88c3c5af85b597cd0fe41904e8c2cb Mon Sep 17 00:00:00 2001 From: Duong Cong Toai Date: Mon, 26 May 2025 21:52:14 +0200 Subject: [PATCH 45/70] feat: add lateral join fields to dependent join --- datafusion/expr/src/logical_plan/builder.rs | 2 ++ datafusion/expr/src/logical_plan/plan.rs | 5 ++++ .../optimizer/src/decorrelate_general.rs | 23 ++++++++++++------- 3 files changed, 22 insertions(+), 8 deletions(-) diff --git a/datafusion/expr/src/logical_plan/builder.rs b/datafusion/expr/src/logical_plan/builder.rs index 5adb2bfb0bb8..f78c5b3e12bb 100644 --- a/datafusion/expr/src/logical_plan/builder.rs +++ b/datafusion/expr/src/logical_plan/builder.rs @@ -893,6 +893,7 @@ impl LogicalPlanBuilder { subquery_expr: Expr, subquery_depth: usize, subquery_name: String, + lateral_join_condition: Option<(JoinType, Expr)>, ) -> Result { let left = self.build()?; let schema = left.schema(); @@ -912,6 +913,7 @@ impl LogicalPlanBuilder { subquery_expr, subquery_name, subquery_depth, + lateral_join_condition, }))) } diff --git a/datafusion/expr/src/logical_plan/plan.rs b/datafusion/expr/src/logical_plan/plan.rs index 1f99f0244251..59f5a9d243a2 100644 --- a/datafusion/expr/src/logical_plan/plan.rs +++ b/datafusion/expr/src/logical_plan/plan.rs @@ -314,6 +314,8 @@ pub struct DependentJoin { // belong to the parent dependent join node in case of recursion) pub right: Arc, pub subquery_name: String, + + pub lateral_join_condition: Option<(JoinType, Expr)>, } impl PartialOrd for DependentJoin { @@ -331,6 +333,7 @@ impl PartialOrd for DependentJoin { // dependent side accessing columns from left hand side (and maybe columns) // belong to the parent dependent join node in case of recursion) right: &'a Arc, + lateral_join_condition: &'a Option<(JoinType, Expr)>, } let comparable_self = ComparableJoin { left: &self.left, @@ -338,6 +341,7 @@ impl PartialOrd for DependentJoin { correlated_columns: &self.correlated_columns, subquery_expr: &self.subquery_expr, depth: &self.subquery_depth, + lateral_join_condition: &self.lateral_join_condition, }; let comparable_other = ComparableJoin { left: &other.left, @@ -345,6 +349,7 @@ impl PartialOrd for DependentJoin { correlated_columns: &other.correlated_columns, subquery_expr: &other.subquery_expr, depth: &other.subquery_depth, + lateral_join_condition: &other.lateral_join_condition, }; comparable_self.partial_cmp(&comparable_other) } diff --git a/datafusion/optimizer/src/decorrelate_general.rs b/datafusion/optimizer/src/decorrelate_general.rs index 5c280bea8988..1b102335ed64 100644 --- a/datafusion/optimizer/src/decorrelate_general.rs +++ b/datafusion/optimizer/src/decorrelate_general.rs @@ -104,7 +104,11 @@ impl 
DependentJoinRewriter { // because the column providers are visited after column-accessor // (function visit_with_subqueries always visit the subquery before visiting the other children) // we can always infer the LCA inside this function, by getting the deepest common parent - fn conclude_lowest_dependent_join_node(&mut self, child_id: usize, col: &Column) { + fn conclude_lowest_dependent_join_node_if_any( + &mut self, + child_id: usize, + col: &Column, + ) { if let Some(accesses) = self.all_outer_ref_columns.get(col) { for access in accesses.iter() { let mut cur_stack = self.stack.clone(); @@ -279,11 +283,6 @@ impl TreeNodeRewriter for DependentJoinRewriter { // for each node, find which column it is accessing, which column it is providing // Set of columns current node access match &node { - LogicalPlan::SubqueryAlias(alias) => { - alias.schema.columns().iter().for_each(|col| { - self.conclude_lowest_dependent_join_node(new_id, col); - }); - } LogicalPlan::Filter(f) => { if contains_subquery(&f.predicate) { is_dependent_join_node = true; @@ -302,7 +301,14 @@ impl TreeNodeRewriter for DependentJoinRewriter { // aside from TableScan LogicalPlan::TableScan(tbl_scan) => { tbl_scan.projected_schema.columns().iter().for_each(|col| { - self.conclude_lowest_dependent_join_node(new_id, col); + self.conclude_lowest_dependent_join_node_if_any(new_id, col); + }); + } + // Similar to TableScan, this node may provide column names which + // is referenced inside some subqueries + LogicalPlan::SubqueryAlias(alias) => { + alias.schema.columns().iter().for_each(|col| { + self.conclude_lowest_dependent_join_node_if_any(new_id, col); }); } // TODO: this is untested @@ -505,6 +511,7 @@ impl TreeNodeRewriter for DependentJoinRewriter { subquery_expr.clone(), current_subquery_depth, alias.clone(), + None, // TODO: handle this when we support lateral join rewrite )?; } current_plan = current_plan @@ -586,7 +593,7 @@ mod tests { }}; } #[test] - fn rewrite_dependent_join_with_lateral_join() -> Result<()> { + fn rewrite_dependent_join_with_nested_lateral_join() -> Result<()> { Ok(()) } From 95994da667f42eece62cd6bf5f088719f4eaf8ac Mon Sep 17 00:00:00 2001 From: Duong Cong Toai Date: Tue, 27 May 2025 18:43:19 +0200 Subject: [PATCH 46/70] feat: rewrite lateral join --- datafusion/expr/src/logical_plan/builder.rs | 9 +- datafusion/expr/src/logical_plan/plan.rs | 20 +- .../optimizer/src/decorrelate_general.rs | 245 +++++++++++++++--- 3 files changed, 232 insertions(+), 42 deletions(-) diff --git a/datafusion/expr/src/logical_plan/builder.rs b/datafusion/expr/src/logical_plan/builder.rs index f78c5b3e12bb..1a179613d072 100644 --- a/datafusion/expr/src/logical_plan/builder.rs +++ b/datafusion/expr/src/logical_plan/builder.rs @@ -890,17 +890,22 @@ impl LogicalPlanBuilder { self, right: LogicalPlan, correlated_columns: Vec<(usize, Expr)>, - subquery_expr: Expr, + subquery_expr: Option, subquery_depth: usize, subquery_name: String, lateral_join_condition: Option<(JoinType, Expr)>, ) -> Result { let left = self.build()?; let schema = left.schema(); + // TODO: for lateral join, output schema is similar to a normal join let qualified_fields = schema .iter() .map(|(q, f)| (q.cloned(), Arc::clone(f))) - .chain(once(subquery_output_field(&subquery_name, &subquery_expr))) + .chain( + subquery_expr + .iter() + .map(|expr| subquery_output_field(&subquery_name, expr)), + ) .collect(); let metadata = schema.metadata().clone(); let dfschema = DFSchema::new_with_metadata(qualified_fields, metadata)?; diff --git 
a/datafusion/expr/src/logical_plan/plan.rs b/datafusion/expr/src/logical_plan/plan.rs index 59f5a9d243a2..97f06f24b824 100644 --- a/datafusion/expr/src/logical_plan/plan.rs +++ b/datafusion/expr/src/logical_plan/plan.rs @@ -306,7 +306,7 @@ pub struct DependentJoin { // the upper expr that containing the subquery expr // i.e for predicates: where outer = scalar_sq + 1 // correlated exprs are `scalar_sq + 1` - pub subquery_expr: Expr, + pub subquery_expr: Option, // begins with depth = 1 pub subquery_depth: usize, pub left: Arc, @@ -326,7 +326,7 @@ impl PartialOrd for DependentJoin { // the upper expr that containing the subquery expr // i.e for predicates: where outer = scalar_sq + 1 // correlated exprs are `scalar_sq + 1` - subquery_expr: &'a Expr, + subquery_expr: &'a Option, depth: &'a usize, left: &'a Arc, @@ -1961,6 +1961,7 @@ impl LogicalPlan { subquery_expr, correlated_columns, subquery_depth, + lateral_join_condition, .. }) => { let correlated_str = correlated_columns.iter() @@ -1970,7 +1971,20 @@ impl LogicalPlan { } "".to_string() }).collect::>().join(", "); - write!(f,"DependentJoin on [{correlated_str}] with expr {subquery_expr} depth {subquery_depth}") + let lateral_join_info = if let Some((join_type,join_expr))= + lateral_join_condition { + format!(" lateral {join_type} join with {join_expr}") + }else{ + "".to_string() + }; + let subquery_expr_str = if let Some(expr) = + subquery_expr{ + format!(" with expr {expr}") + }else{ + "".to_string() + }; + write!(f,"DependentJoin on [{correlated_str}]{subquery_expr_str}\ + {lateral_join_info} depth {subquery_depth}") }, LogicalPlan::Join(Join { on: ref keys, diff --git a/datafusion/optimizer/src/decorrelate_general.rs b/datafusion/optimizer/src/decorrelate_general.rs index 1b102335ed64..89abb3cf79b1 100644 --- a/datafusion/optimizer/src/decorrelate_general.rs +++ b/datafusion/optimizer/src/decorrelate_general.rs @@ -28,7 +28,7 @@ use datafusion_common::tree_node::{ Transformed, TreeNode, TreeNodeRecursion, TreeNodeRewriter, }; use datafusion_common::{internal_err, Column, HashMap, Result}; -use datafusion_expr::{col, Expr, LogicalPlan, LogicalPlanBuilder}; +use datafusion_expr::{col, lit, Expr, LogicalPlan, LogicalPlanBuilder}; use indexmap::IndexMap; use itertools::Itertools; @@ -196,6 +196,7 @@ enum SubqueryType { In, Exists, Scalar, + LateralJoin, } impl SubqueryType { @@ -205,6 +206,7 @@ impl SubqueryType { SubqueryType::In => "__in_sq", SubqueryType::Exists => "__exists_sq", SubqueryType::Scalar => "__scalar_sq", + SubqueryType::LateralJoin => "__lateral_sq", } .to_string() } @@ -337,41 +339,108 @@ impl TreeNodeRewriter for DependentJoinRewriter { parent_node .columns_accesses_by_subquery_id .insert(new_id, vec![]); - for expr in parent_node.plan.expressions() { - expr.exists(|e| { - let (found_sq, checking_type) = match e { - Expr::ScalarSubquery(sq) => { - if sq == subquery { - (true, SubqueryType::Scalar) - } else { - (false, SubqueryType::None) + + if let LogicalPlan::Join(_) = parent_node.plan { + subquery_type = SubqueryType::LateralJoin; + } else { + for expr in parent_node.plan.expressions() { + expr.exists(|e| { + let (found_sq, checking_type) = match e { + Expr::ScalarSubquery(sq) => { + if sq == subquery { + (true, SubqueryType::Scalar) + } else { + (false, SubqueryType::None) + } } - } - Expr::Exists(exist) => { - if &exist.subquery == subquery { - (true, SubqueryType::Exists) - } else { - (false, SubqueryType::None) + Expr::Exists(exist) => { + if &exist.subquery == subquery { + (true, SubqueryType::Exists) + } else 
{ + (false, SubqueryType::None) + } } - } - Expr::InSubquery(in_sq) => { - if &in_sq.subquery == subquery { - (true, SubqueryType::In) - } else { - (false, SubqueryType::None) + Expr::InSubquery(in_sq) => { + if &in_sq.subquery == subquery { + (true, SubqueryType::In) + } else { + (false, SubqueryType::None) + } } + _ => (false, SubqueryType::None), + }; + if found_sq { + subquery_type = checking_type; } - _ => (false, SubqueryType::None), - }; - if found_sq { - subquery_type = checking_type; - } - Ok(found_sq) - })?; + Ok(found_sq) + })?; + } + } + } + LogicalPlan::Aggregate(_) => {} + LogicalPlan::Join(join) => { + let mut sq_count = if let LogicalPlan::Subquery(_) = &join.left.as_ref() { + 1 + } else { + 0 + }; + sq_count += if let LogicalPlan::Subquery(_) = join.right.as_ref() { + 1 + } else { + 0 + }; + match sq_count { + 0 => {} + 1 => { + is_dependent_join_node = true; + } + _ => { + return internal_err!( + "plan error: join logical plan has both children with type \ + Subquery" + ); + } + }; + + if is_dependent_join_node { + self.subquery_depth += 1; + self.stack.push(new_id); + self.nodes.insert( + new_id, + Node { + plan: node.clone(), + is_dependent_join_node, + columns_accesses_by_subquery_id: IndexMap::new(), + subquery_type, + }, + ); + + // we assume that RHS is always a subquery for the join + // and because this function assume that subquery side is visited first + // during f_down, we have to visit it at this step, else + // the function visit_with_subqueries will call f_down for the LHS instead + let transformed_subquery = self + .rewrite_subqueries_into_dependent_joins( + join.right.deref().clone(), + )? + .data; + let transformed_left = self + .rewrite_subqueries_into_dependent_joins( + join.left.deref().clone(), + )? + .data; + let mut new_join_node = join.clone(); + new_join_node.right = Arc::new(transformed_subquery); + new_join_node.left = Arc::new(transformed_left); + return Ok(Transformed::new( + LogicalPlan::Join(new_join_node), + true, + // since we rewrite the children directly in this function, + TreeNodeRecursion::Jump, + )); } } - LogicalPlan::Aggregate(_) | LogicalPlan::Join(_) => {} _ => { return internal_err!("impl f_down for node type {:?}", node); } @@ -409,10 +478,6 @@ impl TreeNodeRewriter for DependentJoinRewriter { } let current_subquery_depth = self.subquery_depth; self.subquery_depth -= 1; - assert!( - 1 == node.inputs().len(), - "a dependent join node cannot have more than 1 child" - ); let cloned_input = (**node.inputs().first().unwrap()).clone(); let mut current_plan = LogicalPlanBuilder::new(cloned_input); @@ -432,6 +497,50 @@ impl TreeNodeRewriter for DependentJoinRewriter { LogicalPlan::Projection(_) => { // TODO: implement me } + LogicalPlan::Join(join) => { + assert!(node_info.columns_accesses_by_subquery_id.len() == 1); + let (_, column_accesses) = + node_info.columns_accesses_by_subquery_id.first().unwrap(); + let alias = subquery_alias_by_offset.get(&0).unwrap(); + let correlated_columns = column_accesses + .iter() + .map(|ac| { + ( + ac.subquery_depth, + Expr::OuterReferenceColumn( + ac.data_type.clone(), + ac.col.clone(), + ), + ) + }) + .unique() + .collect(); + + let subquery_plan = &join.right; + let sq = if let LogicalPlan::Subquery(sq) = subquery_plan.as_ref() { + sq + } else { + return internal_err!( + "lateral join must have right join as a subquery" + ); + }; + let right = sq.subquery.deref().clone(); + // At the time of implementation lateral join condition is not fully clear yet + // So a TODO for future tracking + let 
lateral_join_condition = if let Some(ref filter) = join.filter { + filter.clone() + } else { + lit(true) + }; + current_plan = current_plan.dependent_join( + right, + correlated_columns, + None, + current_subquery_depth, + alias.to_string(), + Some((join.join_type, lateral_join_condition)), + )?; + } LogicalPlan::Filter(filter) => { // everytime we meet a subquery during traversal, we increment this by 1 // we can use this offset to lookup the original subquery info @@ -508,7 +617,7 @@ impl TreeNodeRewriter for DependentJoinRewriter { current_plan = current_plan.dependent_join( subquery_input.deref().clone(), correlated_columns, - subquery_expr.clone(), + Some(subquery_expr.clone()), current_subquery_depth, alias.clone(), None, // TODO: handle this when we support lateral join rewrite @@ -519,7 +628,10 @@ impl TreeNodeRewriter for DependentJoinRewriter { .project(post_join_projections)?; } _ => { - unimplemented!("implement more dependent join node creation") + unimplemented!( + "implement more dependent join node creation for node {}", + node + ) } } Ok(Transformed::yes(current_plan.build()?)) @@ -568,10 +680,10 @@ mod tests { use super::DependentJoinRewriter; use crate::test::test_table_scan_with_name; use arrow::datatypes::DataType as ArrowDataType; - use datafusion_common::{alias::AliasGenerator, Result}; + use datafusion_common::{alias::AliasGenerator, Result, Spans}; use datafusion_expr::{ binary_expr, exists, expr_fn::col, in_subquery, lit, out_ref_col, - scalar_subquery, Expr, JoinType, LogicalPlanBuilder, + scalar_subquery, Expr, JoinType, LogicalPlan, LogicalPlanBuilder, Subquery, }; use datafusion_functions_aggregate::count::count; use insta::assert_snapshot; @@ -594,6 +706,65 @@ mod tests { } #[test] fn rewrite_dependent_join_with_nested_lateral_join() -> Result<()> { + let outer_table = test_table_scan_with_name("outer_table")?; + + let inner_table_lv1 = test_table_scan_with_name("inner_table_lv1")?; + + let inner_table_lv2 = test_table_scan_with_name("inner_table_lv2")?; + let scalar_sq_level2 = + Arc::new( + LogicalPlanBuilder::from(inner_table_lv2) + .filter( + col("inner_table_lv2.a") + .eq(out_ref_col(ArrowDataType::UInt32, "outer_table.a")) + .and(col("inner_table_lv2.b").eq(out_ref_col( + ArrowDataType::UInt32, + "inner_table_lv1.b", + ))), + )? + .aggregate(Vec::::new(), vec![count(col("inner_table_lv2.a"))])? + .build()?, + ); + let sq_level1 = Arc::new( + LogicalPlanBuilder::from(inner_table_lv1.clone()) + .filter( + col("inner_table_lv1.c") + .eq(out_ref_col(ArrowDataType::UInt32, "outer_table.c")) + .and(scalar_subquery(scalar_sq_level2).eq(lit(1))), + )? + .aggregate(Vec::::new(), vec![count(col("inner_table_lv1.a"))])? + .build()?, + ); + + let plan = LogicalPlanBuilder::from(outer_table.clone()) + .join_on( + LogicalPlan::Subquery(Subquery { + subquery: sq_level1, + outer_ref_columns: vec![out_ref_col( + ArrowDataType::UInt32, + "outer_table.c", + // note that subquery lvl2 is referencing outer_table.a, and it is not being listed here + // this simulate the limitation of current subquery planning and assert + // that the rewriter can fill in this gap + )], + spans: Spans::new(), + }), + JoinType::Inner, + vec![lit(true)], + )? 
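+            // Rough SQL shape of the plan assembled above (illustration only,
+            // not the actual test input):
+            //   SELECT ... FROM outer_table JOIN LATERAL (
+            //       SELECT count(a) FROM inner_table_lv1
+            //       WHERE inner_table_lv1.c = outer_table.c
+            //         AND (<level-2 scalar subquery over inner_table_lv2>) = 1
+            //   ) ON true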
+ .build()?; + assert_dependent_join_rewrite!(plan, @r" + DependentJoin on [outer_table.a lvl 2, outer_table.c lvl 1] lateral Inner join with Boolean(true) depth 1 [a:UInt32, b:UInt32, c:UInt32] + TableScan: outer_table [a:UInt32, b:UInt32, c:UInt32] + Aggregate: groupBy=[[]], aggr=[[count(inner_table_lv1.a)]] [count(inner_table_lv1.a):Int64] + Projection: inner_table_lv1.a, inner_table_lv1.b, inner_table_lv1.c [a:UInt32, b:UInt32, c:UInt32] + Filter: inner_table_lv1.c = outer_ref(outer_table.c) AND __scalar_sq_1.output = Int32(1) [a:UInt32, b:UInt32, c:UInt32, output:Int64] + DependentJoin on [inner_table_lv1.b lvl 2] with expr () depth 2 [a:UInt32, b:UInt32, c:UInt32, output:Int64] + TableScan: inner_table_lv1 [a:UInt32, b:UInt32, c:UInt32] + Aggregate: groupBy=[[]], aggr=[[count(inner_table_lv2.a)]] [count(inner_table_lv2.a):Int64] + Filter: inner_table_lv2.a = outer_ref(outer_table.a) AND inner_table_lv2.b = outer_ref(inner_table_lv1.b) [a:UInt32, b:UInt32, c:UInt32] + TableScan: inner_table_lv2 [a:UInt32, b:UInt32, c:UInt32] + "); Ok(()) } From 9745a4f7b7fa4ac66e4eb0a009d34b8589c8edbb Mon Sep 17 00:00:00 2001 From: Duong Cong Toai Date: Wed, 28 May 2025 20:52:13 +0200 Subject: [PATCH 47/70] feat: rewrite projection --- datafusion/expr/src/logical_plan/display.rs | 2 +- .../optimizer/src/decorrelate_general.rs | 232 ++++++++++++++++-- datafusion/sqllogictest/test_files/debug.slt | 25 ++ 3 files changed, 237 insertions(+), 22 deletions(-) create mode 100644 datafusion/sqllogictest/test_files/debug.slt diff --git a/datafusion/expr/src/logical_plan/display.rs b/datafusion/expr/src/logical_plan/display.rs index 6bb60327fdf1..b24a25463276 100644 --- a/datafusion/expr/src/logical_plan/display.rs +++ b/datafusion/expr/src/logical_plan/display.rs @@ -485,7 +485,7 @@ impl<'a, 'b> PgJsonVisitor<'a, 'b> { object } - LogicalPlan::DependentJoin(..) => todo!(), + LogicalPlan::DependentJoin(..) 
=> json!({}), LogicalPlan::Join(Join { on: ref keys, filter, diff --git a/datafusion/optimizer/src/decorrelate_general.rs b/datafusion/optimizer/src/decorrelate_general.rs index 89abb3cf79b1..e74814772908 100644 --- a/datafusion/optimizer/src/decorrelate_general.rs +++ b/datafusion/optimizer/src/decorrelate_general.rs @@ -28,8 +28,9 @@ use datafusion_common::tree_node::{ Transformed, TreeNode, TreeNodeRecursion, TreeNodeRewriter, }; use datafusion_common::{internal_err, Column, HashMap, Result}; -use datafusion_expr::{col, lit, Expr, LogicalPlan, LogicalPlanBuilder}; +use datafusion_expr::{col, lit, Expr, LogicalPlan, LogicalPlanBuilder, Projection}; +use indexmap::map::Entry; use indexmap::IndexMap; use itertools::Itertools; @@ -59,6 +60,94 @@ struct ColumnAccess { } impl DependentJoinRewriter { + fn rewrite_projection( + &mut self, + original_proj: &Projection, + dependent_join_node: &Node, + current_subquery_depth: usize, + mut current_plan: LogicalPlanBuilder, + subquery_alias_by_offset: HashMap, + ) -> Result { + // everytime we meet a subquery during traversal, we increment this by 1 + // we can use this offset to lookup the original subquery info + // in subquery_alias_by_offset + // the reason why we cannot create a hashmap keyed by Subquery object HashMap + // is that the subquery inside this filter expr may have been rewritten in + // the lower level + let mut offset = 0; + let offset_ref = &mut offset; + let mut subquery_expr_by_offset = HashMap::new(); + // for each projected expr, we convert the SubqueryExpr into a ColExpr + // with structure "{subquery_alias}.output" + let new_projections = original_proj + .expr + .iter() + .cloned() + .map(|e| { + Ok(e.transform(|e| { + // replace any subquery expr with subquery_alias.output + // column + let alias = match e { + Expr::InSubquery(_) + | Expr::Exists(_) + | Expr::ScalarSubquery(_) => { + subquery_alias_by_offset.get(offset_ref).unwrap() + } + _ => return Ok(Transformed::no(e)), + }; + // we are aware that the original subquery can be rewritten + // update the latest expr to this map + subquery_expr_by_offset.insert(*offset_ref, e); + *offset_ref += 1; + + // TODO: this assume that after decorrelation + // the dependent join will provide an extra column with the structure + // of "subquery_alias.output" + // On later step of decorrelation, it rely on this structure + // to again rename the expression after join + // for example if the real join type is LeftMark, the correct output + // column should be "mark" instead, else after the join + // one extra layer of projection is needed to alias "mark" into + // "alias.output" + Ok(Transformed::yes(col(format!("{alias}.output")))) + })? 
+ .data) + }) + .collect::>>()?; + + for (subquery_offset, (_, column_accesses)) in dependent_join_node + .columns_accesses_by_subquery_id + .iter() + .enumerate() + { + let alias = subquery_alias_by_offset.get(&subquery_offset).unwrap(); + let subquery_expr = subquery_expr_by_offset.get(&subquery_offset).unwrap(); + + let subquery_input = unwrap_subquery_input_from_expr(subquery_expr); + + let correlated_columns = column_accesses + .iter() + .map(|ac| { + ( + ac.subquery_depth, + Expr::OuterReferenceColumn(ac.data_type.clone(), ac.col.clone()), + ) + }) + .unique() + .collect(); + + current_plan = current_plan.dependent_join( + subquery_input.deref().clone(), + correlated_columns, + Some(subquery_expr.clone()), + current_subquery_depth, + alias.clone(), + None, // TODO: handle this when we support lateral join rewrite + )?; + } + current_plan = current_plan.project(new_projections)?; + Ok(current_plan) + } // lowest common ancestor from stack // given a tree of // n1 @@ -256,20 +345,32 @@ fn contains_subquery(expr: &Expr) -> bool { /// ↑12 /// ┌────────────┐ /// │ FILTER │<--- DependentJoin rewrite -/// │ │ happens here -/// └────┬────┬──┘ -/// ↓2 ↓6 ↓10 -/// ↑5 ↑9 ↑11 <---Here we already have enough information -/// │ | | of which node is accessing which column -/// │ | | provided by "Table Scan t1" node -/// │ | | -/// ┌─────┘ │ └─────┐ -/// │ │ │ +/// │ (1) │ happens here (step 12) +/// └─────┬────┬─┘ Here we already have enough information +/// | | | of which node is accessing which column +/// | | | provided by "Table Scan t1" node +/// │ | | (for example node (6) below ) +/// │ | | +/// │ | | +/// │ | | +/// ↓2────┘ ↓6 └────↓10 +/// ↑5 ↑11 ↑11 /// ┌───▼───┐ ┌──▼───┐ ┌───▼───────┐ /// │SUBQ1 │ │SUBQ2 │ │TABLE SCAN │ /// └──┬────┘ └──┬───┘ │ t1 │ -/// ↓3 ↓7 └───────────┘ -/// ↑4 ↑8 +/// | | └───────────┘ +/// | | +/// | | +/// | ↓7 +/// | ↑10 +/// | ┌──▼───────┐ +/// | │Filter │----> mark_outer_column_access(outer_ref) +/// | │outer_ref | +/// | │ (6) | +/// | └──┬───────┘ +/// | | +/// ↓3 ↓8 +/// ↑4 ↑9 /// ┌──▼────┐ ┌──▼────┐ /// │SCAN t2│ │SCAN t2│ /// └───────┘ └───────┘ @@ -318,7 +419,6 @@ impl TreeNodeRewriter for DependentJoinRewriter { for expr in &proj.expr { if contains_subquery(expr) { is_dependent_join_node = true; - break; } expr.apply(|expr| { if let Expr::OuterReferenceColumn(data_type, col) = expr { @@ -472,20 +572,29 @@ impl TreeNodeRewriter for DependentJoinRewriter { // is a dependent join node,transformation by // build a join based on let current_node_id = self.stack.pop().unwrap(); - let node_info = self.nodes.get(¤t_node_id).unwrap(); - if !node_info.is_dependent_join_node { - return Ok(Transformed::no(node)); - } + let node_info = if let Entry::Occupied(e) = self.nodes.entry(current_node_id) { + let node_info = e.get(); + if !node_info.is_dependent_join_node { + return Ok(Transformed::no(node)); + } + e.swap_remove() + } else { + unreachable!() + }; + let current_subquery_depth = self.subquery_depth; self.subquery_depth -= 1; let cloned_input = (**node.inputs().first().unwrap()).clone(); let mut current_plan = LogicalPlanBuilder::new(cloned_input); let mut subquery_alias_by_offset = HashMap::new(); - let mut subquery_expr_by_offset = HashMap::new(); for (subquery_offset, (subquery_id, _)) in node_info.columns_accesses_by_subquery_id.iter().enumerate() { + if self.nodes.get(subquery_id).is_none() { + println!("{node} {subquery_offset}"); + } + let subquery_node = self.nodes.get(subquery_id).unwrap(); let alias = self .alias_generator @@ -494,8 +603,14 @@ impl 
TreeNodeRewriter for DependentJoinRewriter { } match &node { - LogicalPlan::Projection(_) => { - // TODO: implement me + LogicalPlan::Projection(projection) => { + current_plan = self.rewrite_projection( + projection, + &node_info, + current_subquery_depth, + current_plan, + subquery_alias_by_offset, + )?; } LogicalPlan::Join(join) => { assert!(node_info.columns_accesses_by_subquery_id.len() == 1); @@ -550,6 +665,7 @@ impl TreeNodeRewriter for DependentJoinRewriter { // the lower level let mut offset = 0; let offset_ref = &mut offset; + let mut subquery_expr_by_offset = HashMap::new(); let new_predicate = filter .predicate .clone() @@ -816,7 +932,81 @@ mod tests { Ok(()) } #[test] - fn rewrite_dependent_join_inside_select_expr() -> Result<()> { + fn rewrite_dependent_join_inside_project_exprs() -> Result<()> { + let outer_table = test_table_scan_with_name("outer_table")?; + let inner_table_lv1 = test_table_scan_with_name("inner_table_lv1")?; + + let inner_table_lv2 = test_table_scan_with_name("inner_table_lv2")?; + let scalar_sq_level2 = + Arc::new( + LogicalPlanBuilder::from(inner_table_lv2) + .filter( + col("inner_table_lv2.a") + .eq(out_ref_col(ArrowDataType::UInt32, "outer_table.a")) + .and(col("inner_table_lv2.b").eq(out_ref_col( + ArrowDataType::UInt32, + "inner_table_lv1.b", + ))), + )? + .aggregate(Vec::::new(), vec![count(col("inner_table_lv2.a"))])? + .build()?, + ); + let scalar_sq_level1_a = Arc::new( + LogicalPlanBuilder::from(inner_table_lv1.clone()) + .filter( + col("inner_table_lv1.c") + .eq(out_ref_col(ArrowDataType::UInt32, "outer_table.c")) + // scalar_sq_level2 is intentionally shared between both + // scalar_sq_level1_a and scalar_sq_level1_b + // to check if the framework can uniquely identify the correlated columns + .and(scalar_subquery(Arc::clone(&scalar_sq_level2)).eq(lit(1))), + )? + .aggregate(Vec::::new(), vec![count(col("inner_table_lv1.a"))])? + .build()?, + ); + let scalar_sq_level1_b = Arc::new( + LogicalPlanBuilder::from(inner_table_lv1.clone()) + .filter( + col("inner_table_lv1.c") + .eq(out_ref_col(ArrowDataType::UInt32, "outer_table.c")) + .and(scalar_subquery(scalar_sq_level2).eq(lit(1))), + )? + .aggregate(Vec::::new(), vec![count(col("inner_table_lv1.b"))])? + .build()?, + ); + + let plan = LogicalPlanBuilder::from(outer_table.clone()) + .project(vec![ + col("outer_table.a"), + binary_expr( + scalar_subquery(scalar_sq_level1_a), + datafusion_expr::Operator::Plus, + scalar_subquery(scalar_sq_level1_b), + ), + ])? 
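+            // Rough SQL shape of the projection being tested (illustration only):
+            //   SELECT a,
+            //          (SELECT count(a) FROM inner_table_lv1 WHERE ...)
+            //        + (SELECT count(b) FROM inner_table_lv1 WHERE ...)
+            //   FROM outer_table
+            // where both scalar subqueries are correlated on outer_table.c and each
+            // contains the shared level-2 scalar subquery over inner_table_lv2.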
+ .build()?; + assert_dependent_join_rewrite!(plan, @r" + Projection: outer_table.a, __scalar_sq_3.output + __scalar_sq_4.output [a:UInt32, __scalar_sq_3.output + __scalar_sq_4.output:Int64] + DependentJoin on [outer_table.a lvl 2, outer_table.c lvl 1] with expr () depth 1 [a:UInt32, b:UInt32, c:UInt32, output:Int64, output:Int64] + DependentJoin on [inner_table_lv1.b lvl 2, outer_table.a lvl 2, outer_table.c lvl 1] with expr () depth 1 [a:UInt32, b:UInt32, c:UInt32, output:Int64] + TableScan: outer_table [a:UInt32, b:UInt32, c:UInt32] + Aggregate: groupBy=[[]], aggr=[[count(inner_table_lv1.a)]] [count(inner_table_lv1.a):Int64] + Projection: inner_table_lv1.a, inner_table_lv1.b, inner_table_lv1.c [a:UInt32, b:UInt32, c:UInt32] + Filter: inner_table_lv1.c = outer_ref(outer_table.c) AND __scalar_sq_1.output = Int32(1) [a:UInt32, b:UInt32, c:UInt32, output:Int64] + DependentJoin on [inner_table_lv1.b lvl 2] with expr () depth 2 [a:UInt32, b:UInt32, c:UInt32, output:Int64] + TableScan: inner_table_lv1 [a:UInt32, b:UInt32, c:UInt32] + Aggregate: groupBy=[[]], aggr=[[count(inner_table_lv2.a)]] [count(inner_table_lv2.a):Int64] + Filter: inner_table_lv2.a = outer_ref(outer_table.a) AND inner_table_lv2.b = outer_ref(inner_table_lv1.b) [a:UInt32, b:UInt32, c:UInt32] + TableScan: inner_table_lv2 [a:UInt32, b:UInt32, c:UInt32] + Aggregate: groupBy=[[]], aggr=[[count(inner_table_lv1.b)]] [count(inner_table_lv1.b):Int64] + Projection: inner_table_lv1.a, inner_table_lv1.b, inner_table_lv1.c [a:UInt32, b:UInt32, c:UInt32] + Filter: inner_table_lv1.c = outer_ref(outer_table.c) AND __scalar_sq_2.output = Int32(1) [a:UInt32, b:UInt32, c:UInt32, output:Int64] + DependentJoin on [inner_table_lv1.b lvl 2] with expr () depth 2 [a:UInt32, b:UInt32, c:UInt32, output:Int64] + TableScan: inner_table_lv1 [a:UInt32, b:UInt32, c:UInt32] + Aggregate: groupBy=[[]], aggr=[[count(inner_table_lv2.a)]] [count(inner_table_lv2.a):Int64] + Filter: inner_table_lv2.a = outer_ref(outer_table.a) AND inner_table_lv2.b = outer_ref(inner_table_lv1.b) [a:UInt32, b:UInt32, c:UInt32] + TableScan: inner_table_lv2 [a:UInt32, b:UInt32, c:UInt32] + "); Ok(()) } diff --git a/datafusion/sqllogictest/test_files/debug.slt b/datafusion/sqllogictest/test_files/debug.slt new file mode 100644 index 000000000000..b190aec6152e --- /dev/null +++ b/datafusion/sqllogictest/test_files/debug.slt @@ -0,0 +1,25 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
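+
+# Scratch file used while developing the dependent-join rewrite: the EXPLAIN
+# below is only for manual inspection of the lateral-join plan, and this file
+# is removed again in a later commit ("chore: rm debug file").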
+ +statement ok +create table t1 as values(1); + +statement ok +create table t2 as values(2); + +query TT +explain select * from t1 join lateral (select * from t2 where t1.column1+t2.column1=1) on t1.column1 Date: Wed, 28 May 2025 21:08:38 +0200 Subject: [PATCH 48/70] refactor: split rewrite logic --- datafusion/expr/src/logical_plan/plan.rs | 65 +++--- .../optimizer/src/decorrelate_general.rs | 195 +++++++++--------- 2 files changed, 140 insertions(+), 120 deletions(-) diff --git a/datafusion/expr/src/logical_plan/plan.rs b/datafusion/expr/src/logical_plan/plan.rs index 97f06f24b824..e6d500a68238 100644 --- a/datafusion/expr/src/logical_plan/plan.rs +++ b/datafusion/expr/src/logical_plan/plan.rs @@ -318,6 +318,41 @@ pub struct DependentJoin { pub lateral_join_condition: Option<(JoinType, Expr)>, } +impl DependentJoin { + fn indent_string(&self) -> String {} +} +impl Display for DependentJoin { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + let correlated_str = self + .correlated_columns + .iter() + .map(|(level, c)| { + if let Expr::OuterReferenceColumn(_, ref col) = c { + return format!("{col} lvl {level}"); + } + "".to_string() + }) + .collect::>() + .join(", "); + let lateral_join_info = + if let Some((join_type, join_expr)) = self.lateral_join_condition { + format!(" lateral {join_type} join with {join_expr}") + } else { + "".to_string() + }; + let subquery_expr_str = if let Some(expr) = self.subquery_expr { + format!(" with expr {expr}") + } else { + "".to_string() + }; + write!( + f, + "DependentJoin on [{correlated_str}]{subquery_expr_str}\ + {lateral_join_info} depth {subquery_depth}" + ) + } +} + impl PartialOrd for DependentJoin { fn partial_cmp(&self, other: &Self) -> Option { #[derive(PartialEq, PartialOrd)] @@ -1957,34 +1992,8 @@ impl LogicalPlan { Ok(()) } - LogicalPlan::DependentJoin(DependentJoin{ - subquery_expr, - correlated_columns, - subquery_depth, - lateral_join_condition, - .. 
- }) => { - let correlated_str = correlated_columns.iter() - .map(|(level,c)|{ - if let Expr::OuterReferenceColumn(_, ref col) = c{ - return format!("{col} lvl {level}"); - } - "".to_string() - }).collect::>().join(", "); - let lateral_join_info = if let Some((join_type,join_expr))= - lateral_join_condition { - format!(" lateral {join_type} join with {join_expr}") - }else{ - "".to_string() - }; - let subquery_expr_str = if let Some(expr) = - subquery_expr{ - format!(" with expr {expr}") - }else{ - "".to_string() - }; - write!(f,"DependentJoin on [{correlated_str}]{subquery_expr_str}\ - {lateral_join_info} depth {subquery_depth}") + LogicalPlan::DependentJoin(dependent_join) => { + dependent_join.fmt(f) }, LogicalPlan::Join(Join { on: ref keys, diff --git a/datafusion/optimizer/src/decorrelate_general.rs b/datafusion/optimizer/src/decorrelate_general.rs index e74814772908..e90270fa2e1a 100644 --- a/datafusion/optimizer/src/decorrelate_general.rs +++ b/datafusion/optimizer/src/decorrelate_general.rs @@ -28,7 +28,9 @@ use datafusion_common::tree_node::{ Transformed, TreeNode, TreeNodeRecursion, TreeNodeRewriter, }; use datafusion_common::{internal_err, Column, HashMap, Result}; -use datafusion_expr::{col, lit, Expr, LogicalPlan, LogicalPlanBuilder, Projection}; +use datafusion_expr::{ + col, lit, Expr, Filter, LogicalPlan, LogicalPlanBuilder, Projection, +}; use indexmap::map::Entry; use indexmap::IndexMap; @@ -60,6 +62,97 @@ struct ColumnAccess { } impl DependentJoinRewriter { + fn rewrite_filter( + &mut self, + filter: &Filter, + dependent_join_node: &Node, + current_subquery_depth: usize, + mut current_plan: LogicalPlanBuilder, + subquery_alias_by_offset: HashMap, + ) -> Result { + // everytime we meet a subquery during traversal, we increment this by 1 + // we can use this offset to lookup the original subquery info + // in subquery_alias_by_offset + // the reason why we cannot create a hashmap keyed by Subquery object HashMap + // is that the subquery inside this filter expr may have been rewritten in + // the lower level + let mut offset = 0; + let offset_ref = &mut offset; + let mut subquery_expr_by_offset = HashMap::new(); + let new_predicate = filter + .predicate + .clone() + .transform(|e| { + // replace any subquery expr with subquery_alias.output + // column + let alias = match e { + Expr::InSubquery(_) | Expr::Exists(_) | Expr::ScalarSubquery(_) => { + subquery_alias_by_offset.get(offset_ref).unwrap() + } + _ => return Ok(Transformed::no(e)), + }; + // we are aware that the original subquery can be rewritten + // update the latest expr to this map + subquery_expr_by_offset.insert(*offset_ref, e); + *offset_ref += 1; + + // TODO: this assume that after decorrelation + // the dependent join will provide an extra column with the structure + // of "subquery_alias.output" + // On later step of decorrelation, it rely on this structure + // to again rename the expression after join + // for example if the real join type is LeftMark, the correct output + // column should be "mark" instead, else after the join + // one extra layer of projection is needed to alias "mark" into + // "alias.output" + Ok(Transformed::yes(col(format!("{alias}.output")))) + })? 
+ .data; + // because dependent join may introduce extra columns + // to evaluate the subquery, the final plan should + // has another projection to remove these redundant columns + let post_join_projections: Vec = filter + .input + .schema() + .columns() + .iter() + .map(|c| col(c.clone())) + .collect(); + for (subquery_offset, (_, column_accesses)) in dependent_join_node + .columns_accesses_by_subquery_id + .iter() + .enumerate() + { + let alias = subquery_alias_by_offset.get(&subquery_offset).unwrap(); + let subquery_expr = subquery_expr_by_offset.get(&subquery_offset).unwrap(); + + let subquery_input = unwrap_subquery_input_from_expr(subquery_expr); + + let correlated_columns = column_accesses + .iter() + .map(|ac| { + ( + ac.subquery_depth, + Expr::OuterReferenceColumn(ac.data_type.clone(), ac.col.clone()), + ) + }) + .unique() + .collect(); + + current_plan = current_plan.dependent_join( + subquery_input.deref().clone(), + correlated_columns, + Some(subquery_expr.clone()), + current_subquery_depth, + alias.clone(), + None, // TODO: handle this when we support lateral join rewrite + )?; + } + current_plan + .filter(new_predicate.clone())? + .project(post_join_projections) + } + fn rewrite_projection( &mut self, original_proj: &Projection, @@ -591,10 +684,6 @@ impl TreeNodeRewriter for DependentJoinRewriter { for (subquery_offset, (subquery_id, _)) in node_info.columns_accesses_by_subquery_id.iter().enumerate() { - if self.nodes.get(subquery_id).is_none() { - println!("{node} {subquery_offset}"); - } - let subquery_node = self.nodes.get(subquery_id).unwrap(); let alias = self .alias_generator @@ -612,6 +701,15 @@ impl TreeNodeRewriter for DependentJoinRewriter { subquery_alias_by_offset, )?; } + LogicalPlan::Filter(filter) => { + current_plan = self.rewrite_filter( + filter, + &node_info, + current_subquery_depth, + current_plan, + subquery_alias_by_offset, + )?; + } LogicalPlan::Join(join) => { assert!(node_info.columns_accesses_by_subquery_id.len() == 1); let (_, column_accesses) = @@ -656,93 +754,6 @@ impl TreeNodeRewriter for DependentJoinRewriter { Some((join.join_type, lateral_join_condition)), )?; } - LogicalPlan::Filter(filter) => { - // everytime we meet a subquery during traversal, we increment this by 1 - // we can use this offset to lookup the original subquery info - // in subquery_alias_by_offset - // the reason why we cannot create a hashmap keyed by Subquery object HashMap - // is that the subquery inside this filter expr may have been rewritten in - // the lower level - let mut offset = 0; - let offset_ref = &mut offset; - let mut subquery_expr_by_offset = HashMap::new(); - let new_predicate = filter - .predicate - .clone() - .transform(|e| { - // replace any subquery expr with subquery_alias.output - // column - let alias = match e { - Expr::InSubquery(_) - | Expr::Exists(_) - | Expr::ScalarSubquery(_) => { - subquery_alias_by_offset.get(offset_ref).unwrap() - } - _ => return Ok(Transformed::no(e)), - }; - // we are aware that the original subquery can be rewritten - // update the latest expr to this map - subquery_expr_by_offset.insert(*offset_ref, e); - *offset_ref += 1; - - // TODO: this assume that after decorrelation - // the dependent join will provide an extra column with the structure - // of "subquery_alias.output" - // On later step of decorrelation, it rely on this structure - // to again rename the expression after join - // for example if the real join type is LeftMark, the correct output - // column should be "mark" instead, else after the join - // one 
extra layer of projection is needed to alias "mark" into - // "alias.output" - Ok(Transformed::yes(col(format!("{alias}.output")))) - })? - .data; - // because dependent join may introduce extra columns - // to evaluate the subquery, the final plan should - // has another projection to remove these redundant columns - let post_join_projections: Vec = filter - .input - .schema() - .columns() - .iter() - .map(|c| col(c.clone())) - .collect(); - for (subquery_offset, (_, column_accesses)) in - node_info.columns_accesses_by_subquery_id.iter().enumerate() - { - let alias = subquery_alias_by_offset.get(&subquery_offset).unwrap(); - let subquery_expr = - subquery_expr_by_offset.get(&subquery_offset).unwrap(); - - let subquery_input = unwrap_subquery_input_from_expr(subquery_expr); - - let correlated_columns = column_accesses - .iter() - .map(|ac| { - ( - ac.subquery_depth, - Expr::OuterReferenceColumn( - ac.data_type.clone(), - ac.col.clone(), - ), - ) - }) - .unique() - .collect(); - - current_plan = current_plan.dependent_join( - subquery_input.deref().clone(), - correlated_columns, - Some(subquery_expr.clone()), - current_subquery_depth, - alias.clone(), - None, // TODO: handle this when we support lateral join rewrite - )?; - } - current_plan = current_plan - .filter(new_predicate.clone())? - .project(post_join_projections)?; - } _ => { unimplemented!( "implement more dependent join node creation for node {}", From c083501e1c37ac617a6733acbb6e883883fe8888 Mon Sep 17 00:00:00 2001 From: Duong Cong Toai Date: Wed, 28 May 2025 23:35:46 +0200 Subject: [PATCH 49/70] feat: impl other api of logical plan for dependent join --- datafusion/expr/src/logical_plan/plan.rs | 12 ++--- datafusion/expr/src/logical_plan/tree_node.rs | 46 +++++++++++++++++-- .../optimizer/src/decorrelate_general.rs | 2 + .../optimizer/src/scalar_subquery_to_join.rs | 2 +- 4 files changed, 51 insertions(+), 11 deletions(-) diff --git a/datafusion/expr/src/logical_plan/plan.rs b/datafusion/expr/src/logical_plan/plan.rs index e6d500a68238..e3ad16e98a4c 100644 --- a/datafusion/expr/src/logical_plan/plan.rs +++ b/datafusion/expr/src/logical_plan/plan.rs @@ -318,9 +318,6 @@ pub struct DependentJoin { pub lateral_join_condition: Option<(JoinType, Expr)>, } -impl DependentJoin { - fn indent_string(&self) -> String {} -} impl Display for DependentJoin { fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { let correlated_str = self @@ -335,12 +332,12 @@ impl Display for DependentJoin { .collect::>() .join(", "); let lateral_join_info = - if let Some((join_type, join_expr)) = self.lateral_join_condition { + if let Some((join_type, join_expr)) = &self.lateral_join_condition { format!(" lateral {join_type} join with {join_expr}") } else { "".to_string() }; - let subquery_expr_str = if let Some(expr) = self.subquery_expr { + let subquery_expr_str = if let Some(expr) = &self.subquery_expr { format!(" with expr {expr}") } else { "".to_string() @@ -348,7 +345,8 @@ impl Display for DependentJoin { write!( f, "DependentJoin on [{correlated_str}]{subquery_expr_str}\ - {lateral_join_info} depth {subquery_depth}" + {lateral_join_info} depth {0}", + self.subquery_depth, ) } } @@ -1993,7 +1991,7 @@ impl LogicalPlan { } LogicalPlan::DependentJoin(dependent_join) => { - dependent_join.fmt(f) + Display::fmt(dependent_join,f) }, LogicalPlan::Join(Join { on: ref keys, diff --git a/datafusion/expr/src/logical_plan/tree_node.rs b/datafusion/expr/src/logical_plan/tree_node.rs index 0b9f4a40fff9..caa6449573d1 100644 --- 
a/datafusion/expr/src/logical_plan/tree_node.rs +++ b/datafusion/expr/src/logical_plan/tree_node.rs @@ -350,7 +350,27 @@ impl TreeNode for LogicalPlan { | LogicalPlan::EmptyRelation { .. } | LogicalPlan::Values { .. } | LogicalPlan::DescribeTable(_) => Transformed::no(self), - LogicalPlan::DependentJoin(..) => todo!(), + LogicalPlan::DependentJoin(DependentJoin { + schema, + correlated_columns, + subquery_expr, + subquery_depth, + subquery_name, + lateral_join_condition, + left, + right, + }) => (left, right).map_elements(f)?.update_data(|(left, right)| { + LogicalPlan::DependentJoin(DependentJoin { + schema, + correlated_columns, + subquery_expr, + subquery_depth, + subquery_name, + lateral_join_condition, + left, + right, + }) + }), }) } } @@ -403,8 +423,28 @@ impl LogicalPlan { mut f: F, ) -> Result { match self { - // TODO: apply expr on the subquery - LogicalPlan::DependentJoin(..) => Ok(TreeNodeRecursion::Continue), + LogicalPlan::DependentJoin(DependentJoin { + correlated_columns, + subquery_expr, + lateral_join_condition, + .. + }) => { + let correlated_column_exprs = correlated_columns + .iter() + .map(|(_, c)| c.clone()) + .collect::>(); + let subquery_expr_opt = subquery_expr.clone(); + let maybe_lateral_join_condition = match lateral_join_condition { + Some((_, condition)) => Some(condition.clone()), + None => None, + }; + ( + &correlated_column_exprs, + &subquery_expr_opt, + &maybe_lateral_join_condition, + ) + .apply_ref_elements(f) + } LogicalPlan::Projection(Projection { expr, .. }) => expr.apply_elements(f), LogicalPlan::Values(Values { values, .. }) => values.apply_elements(f), LogicalPlan::Filter(Filter { predicate, .. }) => f(predicate), diff --git a/datafusion/optimizer/src/decorrelate_general.rs b/datafusion/optimizer/src/decorrelate_general.rs index e90270fa2e1a..d5b3d9d8c8f7 100644 --- a/datafusion/optimizer/src/decorrelate_general.rs +++ b/datafusion/optimizer/src/decorrelate_general.rs @@ -805,8 +805,10 @@ impl OptimizerRule for Decorrelation { #[cfg(test)] mod tests { use super::DependentJoinRewriter; + use crate::test::test_table_scan_with_name; use arrow::datatypes::DataType as ArrowDataType; + use datafusion_common::tree_node::{TreeNodeRecursion, TreeNodeRefContainer}; use datafusion_common::{alias::AliasGenerator, Result, Spans}; use datafusion_expr::{ binary_expr, exists, expr_fn::col, in_subquery, lit, out_ref_col, diff --git a/datafusion/optimizer/src/scalar_subquery_to_join.rs b/datafusion/optimizer/src/scalar_subquery_to_join.rs index ece6f00cacc3..b01a55d98fec 100644 --- a/datafusion/optimizer/src/scalar_subquery_to_join.rs +++ b/datafusion/optimizer/src/scalar_subquery_to_join.rs @@ -287,7 +287,7 @@ impl TreeNodeRewriter for ExtractScalarSubQuery<'_> { /// /// # Arguments /// -/// * `query_info` - The subquery portion of the `where` (select avg(total) from orders +/// * `query_info` - The subquery portion of the `where` (select avg(total) from orders) /// * `filter_input` - The non-subquery portion (from customers) /// * `outer_others` - Any additional parts to the `where` expression (and c.x = y) /// * `subquery_alias` - Subquery aliases From 9512ccccb17ca92dfd012270b3379f8db7fdcf7f Mon Sep 17 00:00:00 2001 From: Duong Cong Toai Date: Wed, 28 May 2025 23:38:15 +0200 Subject: [PATCH 50/70] chore: rm debug file --- datafusion/sqllogictest/test_files/debug.slt | 25 -------------------- 1 file changed, 25 deletions(-) delete mode 100644 datafusion/sqllogictest/test_files/debug.slt diff --git a/datafusion/sqllogictest/test_files/debug.slt 
b/datafusion/sqllogictest/test_files/debug.slt deleted file mode 100644 index b190aec6152e..000000000000 --- a/datafusion/sqllogictest/test_files/debug.slt +++ /dev/null @@ -1,25 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at - -# http://www.apache.org/licenses/LICENSE-2.0 - -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -statement ok -create table t1 as values(1); - -statement ok -create table t2 as values(2); - -query TT -explain select * from t1 join lateral (select * from t2 where t1.column1+t2.column1=1) on t1.column1 Date: Thu, 29 May 2025 13:10:14 +0200 Subject: [PATCH 51/70] fix: not expose subquery expr for dependentjoin --- datafusion/expr/src/logical_plan/tree_node.rs | 7 +------ datafusion/optimizer/src/decorrelate_general.rs | 1 - 2 files changed, 1 insertion(+), 7 deletions(-) diff --git a/datafusion/expr/src/logical_plan/tree_node.rs b/datafusion/expr/src/logical_plan/tree_node.rs index caa6449573d1..936350188434 100644 --- a/datafusion/expr/src/logical_plan/tree_node.rs +++ b/datafusion/expr/src/logical_plan/tree_node.rs @@ -433,16 +433,11 @@ impl LogicalPlan { .iter() .map(|(_, c)| c.clone()) .collect::>(); - let subquery_expr_opt = subquery_expr.clone(); let maybe_lateral_join_condition = match lateral_join_condition { Some((_, condition)) => Some(condition.clone()), None => None, }; - ( - &correlated_column_exprs, - &subquery_expr_opt, - &maybe_lateral_join_condition, - ) + (&correlated_column_exprs, &maybe_lateral_join_condition) .apply_ref_elements(f) } LogicalPlan::Projection(Projection { expr, .. 
}) => expr.apply_elements(f), diff --git a/datafusion/optimizer/src/decorrelate_general.rs b/datafusion/optimizer/src/decorrelate_general.rs index d5b3d9d8c8f7..39f266b1f311 100644 --- a/datafusion/optimizer/src/decorrelate_general.rs +++ b/datafusion/optimizer/src/decorrelate_general.rs @@ -808,7 +808,6 @@ mod tests { use crate::test::test_table_scan_with_name; use arrow::datatypes::DataType as ArrowDataType; - use datafusion_common::tree_node::{TreeNodeRecursion, TreeNodeRefContainer}; use datafusion_common::{alias::AliasGenerator, Result, Spans}; use datafusion_expr::{ binary_expr, exists, expr_fn::col, in_subquery, lit, out_ref_col, From 10f9aeb0bfe3a3c70aef13cc2b4936f762dcb13e Mon Sep 17 00:00:00 2001 From: Duong Cong Toai Date: Sat, 7 Jun 2025 07:55:22 +0200 Subject: [PATCH 52/70] chore: add data type to correlated column --- datafusion/expr/src/logical_plan/builder.rs | 2 +- .../expr/src/logical_plan/invariants.rs | 25 +++++++++++++------ datafusion/expr/src/logical_plan/plan.rs | 20 ++++++--------- datafusion/expr/src/logical_plan/tree_node.rs | 11 +++----- 4 files changed, 29 insertions(+), 29 deletions(-) diff --git a/datafusion/expr/src/logical_plan/builder.rs b/datafusion/expr/src/logical_plan/builder.rs index 1a179613d072..a7c36a0f87b0 100644 --- a/datafusion/expr/src/logical_plan/builder.rs +++ b/datafusion/expr/src/logical_plan/builder.rs @@ -889,7 +889,7 @@ impl LogicalPlanBuilder { pub fn dependent_join( self, right: LogicalPlan, - correlated_columns: Vec<(usize, Expr)>, + correlated_columns: Vec<(usize, Column, DataType)>, subquery_expr: Option, subquery_depth: usize, subquery_name: String, diff --git a/datafusion/expr/src/logical_plan/invariants.rs b/datafusion/expr/src/logical_plan/invariants.rs index 0c30c9785766..ebd1699ea99b 100644 --- a/datafusion/expr/src/logical_plan/invariants.rs +++ b/datafusion/expr/src/logical_plan/invariants.rs @@ -201,20 +201,27 @@ pub fn check_subquery_expr( }?; match outer_plan { LogicalPlan::Projection(_) - | LogicalPlan::Filter(_) => Ok(()), - LogicalPlan::Aggregate(Aggregate { group_expr, aggr_expr, .. }) => { + | LogicalPlan::Filter(_) + | LogicalPlan::DependentJoin(_) => Ok(()), + LogicalPlan::Aggregate(Aggregate { + group_expr, + aggr_expr, + .. 
+ }) => { if group_expr.contains(expr) && !aggr_expr.contains(expr) { // TODO revisit this validation logic plan_err!( - "Correlated scalar subquery in the GROUP BY clause must also be in the aggregate expressions" + "Correlated scalar subquery in the GROUP BY clause must \ + also be in the aggregate expressions" ) } else { Ok(()) } } _ => plan_err!( - "Correlated scalar subquery can only be used in Projection, Filter, Aggregate plan nodes" - ) + "Correlated scalar subquery can only be used in Projection, Filter, \ + Aggregate, DependentJoin plan nodes" + ), }?; } check_correlations_in_subquery(inner_plan) @@ -235,11 +242,12 @@ pub fn check_subquery_expr( | LogicalPlan::TableScan(_) | LogicalPlan::Window(_) | LogicalPlan::Aggregate(_) - | LogicalPlan::Join(_) => Ok(()), + | LogicalPlan::Join(_) + | LogicalPlan::DependentJoin(_) => Ok(()), _ => plan_err!( "In/Exist subquery can only be used in \ - Projection, Filter, TableScan, Window functions, Aggregate and Join plan nodes, \ - but was used in [{}]", + Projection, Filter, TableScan, Window functions, Aggregate, Join and \ + Dependent Join plan nodes, but was used in [{}]", outer_plan.display() ), }?; @@ -323,6 +331,7 @@ fn check_inner_plan(inner_plan: &LogicalPlan) -> Result<()> { } }, LogicalPlan::Extension(_) => Ok(()), + LogicalPlan::DependentJoin(_) => Ok(()), plan => check_no_outer_references(plan), } } diff --git a/datafusion/expr/src/logical_plan/plan.rs b/datafusion/expr/src/logical_plan/plan.rs index e3ad16e98a4c..a658e2a41381 100644 --- a/datafusion/expr/src/logical_plan/plan.rs +++ b/datafusion/expr/src/logical_plan/plan.rs @@ -295,14 +295,13 @@ pub enum LogicalPlan { #[derive(Debug, Clone, PartialEq, Eq, Hash)] pub struct DependentJoin { pub schema: DFSchemaRef, - // All combinatoins of (subquery,OuterReferencedExpr) on the RHS (and its descendant) - // which points to a column on the LHS. - // The Expr should always be Expr::OuterRefColumn. + // All combinations of (subquery depth,Column and its DataType) on the RHS (and its descendant) + // which points to a column on the LHS of this dependent join // Note that not all outer_refs from the RHS are mentioned in this vectors - // because RHS may reference columns provided somewhere from the above join. + // because RHS may reference columns provided somewhere from the above parent dependent join. 
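+    // For example (illustrative values only): a depth-1 correlated predicate such
+    // as `inner_table.a = outer_ref(outer_table.a)` on a UInt32 column would be
+    // recorded here as the tuple (1, outer_table.a, UInt32).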
// Depths of each correlated_columns should always be gte current dependent join // subquery_depth - pub correlated_columns: Vec<(usize, Expr)>, + pub correlated_columns: Vec<(usize, Column, DataType)>, // the upper expr that containing the subquery expr // i.e for predicates: where outer = scalar_sq + 1 // correlated exprs are `scalar_sq + 1` @@ -323,12 +322,7 @@ impl Display for DependentJoin { let correlated_str = self .correlated_columns .iter() - .map(|(level, c)| { - if let Expr::OuterReferenceColumn(_, ref col) = c { - return format!("{col} lvl {level}"); - } - "".to_string() - }) + .map(|(level, col, _)| format!("{col} lvl {level}")) .collect::>() .join(", "); let lateral_join_info = @@ -355,7 +349,7 @@ impl PartialOrd for DependentJoin { fn partial_cmp(&self, other: &Self) -> Option { #[derive(PartialEq, PartialOrd)] struct ComparableJoin<'a> { - correlated_columns: &'a Vec<(usize, Expr)>, + correlated_columns: &'a Vec<(usize, Column, DataType)>, // the upper expr that containing the subquery expr // i.e for predicates: where outer = scalar_sq + 1 // correlated exprs are `scalar_sq + 1` @@ -1991,7 +1985,7 @@ impl LogicalPlan { } LogicalPlan::DependentJoin(dependent_join) => { - Display::fmt(dependent_join,f) + Display::fmt(dependent_join, f) }, LogicalPlan::Join(Join { on: ref keys, diff --git a/datafusion/expr/src/logical_plan/tree_node.rs b/datafusion/expr/src/logical_plan/tree_node.rs index 936350188434..7f9a40a49c06 100644 --- a/datafusion/expr/src/logical_plan/tree_node.rs +++ b/datafusion/expr/src/logical_plan/tree_node.rs @@ -39,9 +39,9 @@ use crate::{ dml::CopyTo, Aggregate, Analyze, CreateMemoryTable, CreateView, DdlStatement, - Distinct, DistinctOn, DmlStatement, Execute, Explain, Expr, Extension, Filter, Join, - Limit, LogicalPlan, Partitioning, Prepare, Projection, RecursiveQuery, Repartition, - Sort, Statement, Subquery, SubqueryAlias, TableScan, Union, Unnest, + DependentJoin, Distinct, DistinctOn, DmlStatement, Execute, Explain, Expr, Extension, + Filter, Join, Limit, LogicalPlan, Partitioning, Prepare, Projection, RecursiveQuery, + Repartition, Sort, Statement, Subquery, SubqueryAlias, TableScan, Union, Unnest, UserDefinedLogicalNode, Values, Window, }; use datafusion_common::tree_node::TreeNodeRefContainer; @@ -53,8 +53,6 @@ use datafusion_common::tree_node::{ }; use datafusion_common::{internal_err, Result}; -use super::plan::DependentJoin; - impl TreeNode for LogicalPlan { fn apply_children<'n, F: FnMut(&'n Self) -> Result>( &'n self, @@ -425,13 +423,12 @@ impl LogicalPlan { match self { LogicalPlan::DependentJoin(DependentJoin { correlated_columns, - subquery_expr, lateral_join_condition, .. 
}) => { let correlated_column_exprs = correlated_columns .iter() - .map(|(_, c)| c.clone()) + .map(|(_, c, _)| c.clone()) .collect::>(); let maybe_lateral_join_condition = match lateral_join_condition { Some((_, condition)) => Some(condition.clone()), From 92bb17506ebc7842364ac97bd5f1be88440d4b22 Mon Sep 17 00:00:00 2001 From: Duong Cong Toai Date: Thu, 29 May 2025 13:10:14 +0200 Subject: [PATCH 53/70] fix: not expose subquery expr for dependentjoin support sort support agg dummy unnest update test --- .../optimizer/src/decorrelate_general.rs | 1088 ++++++++++++++--- datafusion/optimizer/src/test/mod.rs | 27 + 2 files changed, 972 insertions(+), 143 deletions(-) diff --git a/datafusion/optimizer/src/decorrelate_general.rs b/datafusion/optimizer/src/decorrelate_general.rs index 39f266b1f311..2dbd055829eb 100644 --- a/datafusion/optimizer/src/decorrelate_general.rs +++ b/datafusion/optimizer/src/decorrelate_general.rs @@ -241,6 +241,114 @@ impl DependentJoinRewriter { current_plan = current_plan.project(new_projections)?; Ok(current_plan) } + + fn rewrite_aggregate( + &mut self, + aggregate: &Aggregate, + dependent_join_node: &Node, + current_subquery_depth: usize, + mut current_plan: LogicalPlanBuilder, + subquery_alias_by_offset: HashMap, + ) -> Result { + let mut offset = 0; + let offset_ref = &mut offset; + let mut subquery_expr_by_offset = HashMap::new(); + let new_group_expr = aggregate + .group_expr + .iter() + .cloned() + .map(|e| { + Ok(e.transform(|e| { + // replace any subquery expr with subquery_alias.output column + let alias = match e { + Expr::InSubquery(_) + | Expr::Exists(_) + | Expr::ScalarSubquery(_) => { + subquery_alias_by_offset.get(offset_ref).unwrap() + } + _ => return Ok(Transformed::no(e)), + }; + + // We are aware that the original subquery can be rewritten update the + // latest expr to this map. + subquery_expr_by_offset.insert(*offset_ref, e); + *offset_ref += 1; + + Ok(Transformed::yes(col(format!("{alias}.output")))) + })? + .data) + }) + .collect::>>()?; + + let new_agg_expr = aggregate + .aggr_expr + .clone() + .iter() + .cloned() + .map(|e| { + Ok(e.transform(|e| { + // replace any subquery expr with subquery_alias.output column + let alias = match e { + Expr::InSubquery(_) + | Expr::Exists(_) + | Expr::ScalarSubquery(_) => { + subquery_alias_by_offset.get(offset_ref).unwrap() + } + _ => return Ok(Transformed::no(e)), + }; + + // We are aware that the original subquery can be rewritten update the + // latest expr to this map. + subquery_expr_by_offset.insert(*offset_ref, e); + *offset_ref += 1; + + Ok(Transformed::yes(col(format!("{alias}.output")))) + })? 
+ .data) + }) + .collect::>>()?; + + for (subquery_offset, (_, column_accesses)) in dependent_join_node + .columns_accesses_by_subquery_id + .iter() + .enumerate() + { + let alias = subquery_alias_by_offset.get(&subquery_offset).unwrap(); + let subquery_expr = subquery_expr_by_offset.get(&subquery_offset).unwrap(); + + let subquery_input = unwrap_subquery_input_from_expr(subquery_expr); + + let correlated_columns = column_accesses + .iter() + .map(|ac| (ac.subquery_depth, ac.col.clone(), ac.data_type.clone())) + .unique() + .collect(); + + current_plan = current_plan.dependent_join( + subquery_input.deref().clone(), + correlated_columns, + Some(subquery_expr.clone()), + current_subquery_depth, + alias.clone(), + None, // TODO: handle this when we support lateral join rewrite + )?; + } + + // because dependent join may introduce extra columns + // to evaluate the subquery, the final plan should + // has another projection to remove these redundant columns + let post_join_projections: Vec = aggregate + .schema + .columns() + .iter() + .map(|c| col(c.clone())) + .collect(); + + current_plan + .aggregate(new_group_expr.clone(), new_agg_expr.clone())? + .project(post_join_projections) + } + // lowest common ancestor from stack // given a tree of // n1 @@ -333,6 +441,7 @@ impl DependentJoinRewriter { subquery_depth: self.subquery_depth, }); } + fn rewrite_subqueries_into_dependent_joins( &mut self, plan: LogicalPlan, @@ -434,39 +543,39 @@ fn contains_subquery(expr: &Expr) -> bool { /// The traversal happens in the following sequence /// /// ```text -/// ↓1 -/// ↑12 -/// ┌────────────┐ -/// │ FILTER │<--- DependentJoin rewrite -/// │ (1) │ happens here (step 12) -/// └─────┬────┬─┘ Here we already have enough information -/// | | | of which node is accessing which column -/// | | | provided by "Table Scan t1" node -/// │ | | (for example node (6) below ) -/// │ | | -/// │ | | -/// │ | | +/// ↓1 +/// ↑12 +/// ┌──────────────┐ +/// │ FILTER │<--- DependentJoin rewrite +/// │ (1) │ happens here (step 12) +/// └─┬─────┬────┬─┘ Here we already have enough information +/// │ │ │ of which node is accessing which column +/// │ │ │ provided by "Table Scan t1" node +/// │ │ │ (for example node (6) below ) +/// │ │ │ +/// │ │ │ +/// │ │ │ /// ↓2────┘ ↓6 └────↓10 /// ↑5 ↑11 ↑11 /// ┌───▼───┐ ┌──▼───┐ ┌───▼───────┐ /// │SUBQ1 │ │SUBQ2 │ │TABLE SCAN │ /// └──┬────┘ └──┬───┘ │ t1 │ -/// | | └───────────┘ -/// | | -/// | | -/// | ↓7 -/// | ↑10 -/// | ┌──▼───────┐ -/// | │Filter │----> mark_outer_column_access(outer_ref) -/// | │outer_ref | -/// | │ (6) | -/// | └──┬───────┘ -/// | | +/// │ │ └───────────┘ +/// │ │ +/// │ │ +/// │ ↓7 +/// │ ↑10 +/// │ ┌───▼──────┐ +/// │ │Filter │----> mark_outer_column_access(outer_ref) +/// │ │outer_ref │ +/// │ │ (6) │ +/// │ └──┬───────┘ +/// │ │ /// ↓3 ↓8 /// ↑4 ↑9 -/// ┌──▼────┐ ┌──▼────┐ -/// │SCAN t2│ │SCAN t2│ -/// └───────┘ └───────┘ +/// ┌──▼────┐ ┌──▼────┐ +/// │SCAN t2│ │SCAN t2│ +/// └───────┘ └───────┘ /// ``` impl TreeNodeRewriter for DependentJoinRewriter { type Node = LogicalPlan; @@ -507,6 +616,7 @@ impl TreeNodeRewriter for DependentJoinRewriter { self.conclude_lowest_dependent_join_node_if_any(new_id, col); }); } + LogicalPlan::Unnest(_unnest) => {} // TODO: this is untested LogicalPlan::Projection(proj) => { for expr in &proj.expr { @@ -571,7 +681,33 @@ impl TreeNodeRewriter for DependentJoinRewriter { } } } - LogicalPlan::Aggregate(_) => {} + LogicalPlan::Aggregate(aggregate) => { + for expr in &aggregate.group_expr { + if contains_subquery(expr) { + 
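+                        // Any group/aggregate expr containing a subquery turns this
+                        // aggregate into a dependent join node; note that we keep
+                        // scanning the remaining exprs so their outer references are
+                        // still recorded via mark_outer_column_access below.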
is_dependent_join_node = true; + } + + expr.apply(|expr| { + if let Expr::OuterReferenceColumn(data_type, col) = expr { + self.mark_outer_column_access(new_id, data_type, col); + } + Ok(TreeNodeRecursion::Continue) + })?; + } + + for expr in &aggregate.aggr_expr { + if contains_subquery(expr) { + is_dependent_join_node = true; + } + + expr.apply(|expr| { + if let Expr::OuterReferenceColumn(data_type, col) = expr { + self.mark_outer_column_access(new_id, data_type, col); + } + Ok(TreeNodeRecursion::Continue) + })?; + } + } LogicalPlan::Join(join) => { let mut sq_count = if let LogicalPlan::Subquery(_) = &join.left.as_ref() { 1 @@ -634,9 +770,21 @@ impl TreeNodeRewriter for DependentJoinRewriter { )); } } - _ => { - return internal_err!("impl f_down for node type {:?}", node); + LogicalPlan::Sort(sort) => { + for expr in &sort.expr { + if contains_subquery(&expr.expr) { + is_dependent_join_node = true; + } + + expr.expr.apply(|expr| { + if let Expr::OuterReferenceColumn(data_type, col) = expr { + self.mark_outer_column_access(new_id, data_type, col); + } + Ok(TreeNodeRecursion::Continue) + })?; + } } + _ => {} }; if is_dependent_join_node { @@ -754,6 +902,15 @@ impl TreeNodeRewriter for DependentJoinRewriter { Some((join.join_type, lateral_join_condition)), )?; } + LogicalPlan::Aggregate(aggregate) => { + current_plan = self.rewrite_aggregate( + aggregate, + &node_info, + current_subquery_depth, + current_plan, + subquery_alias_by_offset, + )?; + } _ => { unimplemented!( "implement more dependent join node creation for node {}", @@ -806,14 +963,25 @@ impl OptimizerRule for Decorrelation { mod tests { use super::DependentJoinRewriter; +<<<<<<< HEAD use crate::test::test_table_scan_with_name; +======= + use crate::test::{test_table_scan_with_name, test_table_with_columns}; + use crate::{ + assert_optimized_plan_eq_display_indent_snapshot, + decorrelate_general::Decorrelation, OptimizerConfig, OptimizerContext, + OptimizerRule, + }; +>>>>>>> 496703d58 (fix: not expose subquery expr for dependentjoin) use arrow::datatypes::DataType as ArrowDataType; + use arrow::datatypes::{DataType, Field}; use datafusion_common::{alias::AliasGenerator, Result, Spans}; use datafusion_expr::{ - binary_expr, exists, expr_fn::col, in_subquery, lit, out_ref_col, - scalar_subquery, Expr, JoinType, LogicalPlan, LogicalPlanBuilder, Subquery, + binary_expr, exists, expr::InSubquery, expr_fn::col, in_subquery, lit, + out_ref_col, scalar_subquery, Expr, JoinType, LogicalPlan, LogicalPlanBuilder, + Operator, SortExpr, Subquery, }; - use datafusion_functions_aggregate::count::count; + use datafusion_functions_aggregate::{count::count, sum::sum}; use insta::assert_snapshot; use std::sync::Arc; @@ -832,6 +1000,7 @@ mod tests { ) }}; } + #[test] fn rewrite_dependent_join_with_nested_lateral_join() -> Result<()> { let outer_table = test_table_scan_with_name("outer_table")?; @@ -839,25 +1008,24 @@ mod tests { let inner_table_lv1 = test_table_scan_with_name("inner_table_lv1")?; let inner_table_lv2 = test_table_scan_with_name("inner_table_lv2")?; - let scalar_sq_level2 = - Arc::new( - LogicalPlanBuilder::from(inner_table_lv2) - .filter( - col("inner_table_lv2.a") - .eq(out_ref_col(ArrowDataType::UInt32, "outer_table.a")) - .and(col("inner_table_lv2.b").eq(out_ref_col( - ArrowDataType::UInt32, - "inner_table_lv1.b", - ))), - )? - .aggregate(Vec::::new(), vec![count(col("inner_table_lv2.a"))])? 
- .build()?, - ); + let scalar_sq_level2 = Arc::new( + LogicalPlanBuilder::from(inner_table_lv2) + .filter( + col("inner_table_lv2.a") + .eq(out_ref_col(DataType::UInt32, "outer_table.a")) + .and( + col("inner_table_lv2.b") + .eq(out_ref_col(DataType::UInt32, "inner_table_lv1.b")), + ), + )? + .aggregate(Vec::::new(), vec![count(col("inner_table_lv2.a"))])? + .build()?, + ); let sq_level1 = Arc::new( LogicalPlanBuilder::from(inner_table_lv1.clone()) .filter( col("inner_table_lv1.c") - .eq(out_ref_col(ArrowDataType::UInt32, "outer_table.c")) + .eq(out_ref_col(DataType::UInt32, "outer_table.c")) .and(scalar_subquery(scalar_sq_level2).eq(lit(1))), )? .aggregate(Vec::::new(), vec![count(col("inner_table_lv1.a"))])? @@ -869,7 +1037,7 @@ mod tests { LogicalPlan::Subquery(Subquery { subquery: sq_level1, outer_ref_columns: vec![out_ref_col( - ArrowDataType::UInt32, + DataType::UInt32, "outer_table.c", // note that subquery lvl2 is referencing outer_table.a, and it is not being listed here // this simulate the limitation of current subquery planning and assert @@ -881,6 +1049,18 @@ mod tests { vec![lit(true)], )? .build()?; + + // Inner Join: Filter: Boolean(true) + // TableScan: outer_table + // Subquery: + // Aggregate: groupBy=[[]], aggr=[[count(inner_table_lv1.a)]] + // Filter: inner_table_lv1.c = outer_ref(outer_table.c) AND () = Int32(1) + // Subquery: + // Aggregate: groupBy=[[]], aggr=[[count(inner_table_lv2.a)]] + // Filter: inner_table_lv2.a = outer_ref(outer_table.a) AND inner_table_lv2.b = outer_ref(inner_table_lv1.b) + // TableScan: inner_table_lv2 + // TableScan: inner_table_lv1 + assert_dependent_join_rewrite!(plan, @r" DependentJoin on [outer_table.a lvl 2, outer_table.c lvl 1] lateral Inner join with Boolean(true) depth 1 [a:UInt32, b:UInt32, c:UInt32] TableScan: outer_table [a:UInt32, b:UInt32, c:UInt32] @@ -904,9 +1084,9 @@ mod tests { let sq_level1 = Arc::new( LogicalPlanBuilder::from(inner_table_lv1) .filter(col("inner_table_lv1.a").eq(binary_expr( - out_ref_col(ArrowDataType::UInt32, "outer_left_table.a"), - datafusion_expr::Operator::Plus, - out_ref_col(ArrowDataType::UInt32, "outer_right_table.a"), + out_ref_col(DataType::UInt32, "outer_left_table.a"), + Operator::Plus, + out_ref_col(DataType::UInt32, "outer_right_table.a"), )))? .aggregate(Vec::::new(), vec![count(col("inner_table_lv1.a"))])? .project(vec![count(col("inner_table_lv1.a")).alias("count_a")])? @@ -925,6 +1105,17 @@ mod tests { .and(in_subquery(col("outer_table.c"), sq_level1)), )? 
.build()?; + + // Filter: outer_table.a > Int32(1) AND outer_table.c IN () + // Subquery: + // Projection: count(inner_table_lv1.a) AS count_a + // Aggregate: groupBy=[[]], aggr=[[count(inner_table_lv1.a)]] + // Filter: inner_table_lv1.a = outer_ref(outer_left_table.a) + outer_ref(outer_right_table.a) + // TableScan: inner_table_lv1 + // Left Join: Filter: outer_left_table.a = outer_right_table.a + // TableScan: outer_right_table + // TableScan: outer_left_table + assert_dependent_join_rewrite!(plan, @r" Projection: outer_right_table.a, outer_right_table.b, outer_right_table.c, outer_left_table.a, outer_left_table.b, outer_left_table.c [a:UInt32, b:UInt32, c:UInt32, a:UInt32;N, b:UInt32;N, c:UInt32;N] Filter: outer_table.a > Int32(1) AND __in_sq_1.output [a:UInt32, b:UInt32, c:UInt32, a:UInt32;N, b:UInt32;N, c:UInt32;N, output:Boolean] @@ -949,25 +1140,24 @@ mod tests { let inner_table_lv1 = test_table_scan_with_name("inner_table_lv1")?; let inner_table_lv2 = test_table_scan_with_name("inner_table_lv2")?; - let scalar_sq_level2 = - Arc::new( - LogicalPlanBuilder::from(inner_table_lv2) - .filter( - col("inner_table_lv2.a") - .eq(out_ref_col(ArrowDataType::UInt32, "outer_table.a")) - .and(col("inner_table_lv2.b").eq(out_ref_col( - ArrowDataType::UInt32, - "inner_table_lv1.b", - ))), - )? - .aggregate(Vec::::new(), vec![count(col("inner_table_lv2.a"))])? - .build()?, - ); + let scalar_sq_level2 = Arc::new( + LogicalPlanBuilder::from(inner_table_lv2) + .filter( + col("inner_table_lv2.a") + .eq(out_ref_col(DataType::UInt32, "outer_table.a")) + .and( + col("inner_table_lv2.b") + .eq(out_ref_col(DataType::UInt32, "inner_table_lv1.b")), + ), + )? + .aggregate(Vec::::new(), vec![count(col("inner_table_lv2.a"))])? + .build()?, + ); let scalar_sq_level1_a = Arc::new( LogicalPlanBuilder::from(inner_table_lv1.clone()) .filter( col("inner_table_lv1.c") - .eq(out_ref_col(ArrowDataType::UInt32, "outer_table.c")) + .eq(out_ref_col(DataType::UInt32, "outer_table.c")) // scalar_sq_level2 is intentionally shared between both // scalar_sq_level1_a and scalar_sq_level1_b // to check if the framework can uniquely identify the correlated columns @@ -980,7 +1170,7 @@ mod tests { LogicalPlanBuilder::from(inner_table_lv1.clone()) .filter( col("inner_table_lv1.c") - .eq(out_ref_col(ArrowDataType::UInt32, "outer_table.c")) + .eq(out_ref_col(DataType::UInt32, "outer_table.c")) .and(scalar_subquery(scalar_sq_level2).eq(lit(1))), )? .aggregate(Vec::::new(), vec![count(col("inner_table_lv1.b"))])? @@ -992,11 +1182,31 @@ mod tests { col("outer_table.a"), binary_expr( scalar_subquery(scalar_sq_level1_a), - datafusion_expr::Operator::Plus, + Operator::Plus, scalar_subquery(scalar_sq_level1_b), ), ])? 
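        // Assumed SQL shape of the projection being built (both scalar subqueries share
        // the same level-2 subquery plan on purpose):
        //   SELECT a,
        //          (SELECT count(a) FROM inner_table_lv1
        //            WHERE c = outer_table.c AND (<level-2 count>) = 1)
        //        + (SELECT count(b) FROM inner_table_lv1
        //            WHERE c = outer_table.c AND (<level-2 count>) = 1)
        //   FROM outer_table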
.build()?; + + // Projection: outer_table.a, () + () + // Subquery: + // Aggregate: groupBy=[[]], aggr=[[count(inner_table_lv1.a)]] + // Filter: inner_table_lv1.c = outer_ref(outer_table.c) AND () = Int32(1) + // Subquery: + // Aggregate: groupBy=[[]], aggr=[[count(inner_table_lv2.a)]] + // Filter: inner_table_lv2.a = outer_ref(outer_table.a) AND inner_table_lv2.b = outer_ref(inner_table_lv1.b) + // TableScan: inner_table_lv2 + // TableScan: inner_table_lv1 + // Subquery: + // Aggregate: groupBy=[[]], aggr=[[count(inner_table_lv1.b)]] + // Filter: inner_table_lv1.c = outer_ref(outer_table.c) AND () = Int32(1) + // Subquery: + // Aggregate: groupBy=[[]], aggr=[[count(inner_table_lv2.a)]] + // Filter: inner_table_lv2.a = outer_ref(outer_table.a) AND inner_table_lv2.b = outer_ref(inner_table_lv1.b) + // TableScan: inner_table_lv2 + // TableScan: inner_table_lv1 + // TableScan: outer_table + assert_dependent_join_rewrite!(plan, @r" Projection: outer_table.a, __scalar_sq_3.output + __scalar_sq_4.output [a:UInt32, __scalar_sq_3.output + __scalar_sq_4.output:Int64] DependentJoin on [outer_table.a lvl 2, outer_table.c lvl 1] with expr () depth 1 [a:UInt32, b:UInt32, c:UInt32, output:Int64, output:Int64] @@ -1028,25 +1238,24 @@ mod tests { let inner_table_lv1 = test_table_scan_with_name("inner_table_lv1")?; let inner_table_lv2 = test_table_scan_with_name("inner_table_lv2")?; - let scalar_sq_level2 = - Arc::new( - LogicalPlanBuilder::from(inner_table_lv2) - .filter( - col("inner_table_lv2.a") - .eq(out_ref_col(ArrowDataType::UInt32, "outer_table.a")) - .and(col("inner_table_lv2.b").eq(out_ref_col( - ArrowDataType::UInt32, - "inner_table_lv1.b", - ))), - )? - .aggregate(Vec::::new(), vec![count(col("inner_table_lv2.a"))])? - .build()?, - ); + let scalar_sq_level2 = Arc::new( + LogicalPlanBuilder::from(inner_table_lv2) + .filter( + col("inner_table_lv2.a") + .eq(out_ref_col(DataType::UInt32, "outer_table.a")) + .and( + col("inner_table_lv2.b") + .eq(out_ref_col(DataType::UInt32, "inner_table_lv1.b")), + ), + )? + .aggregate(Vec::::new(), vec![count(col("inner_table_lv2.a"))])? + .build()?, + ); let scalar_sq_level1 = Arc::new( LogicalPlanBuilder::from(inner_table_lv1.clone()) .filter( col("inner_table_lv1.c") - .eq(out_ref_col(ArrowDataType::UInt32, "outer_table.c")) + .eq(out_ref_col(DataType::UInt32, "outer_table.c")) .and(scalar_subquery(scalar_sq_level2).eq(lit(1))), )? .aggregate(Vec::::new(), vec![count(col("inner_table_lv1.a"))])? @@ -1060,6 +1269,19 @@ mod tests { .and(scalar_subquery(scalar_sq_level1).eq(col("outer_table.a"))), )? .build()?; + + // Filter: outer_table.a > Int32(1) AND () = outer_table.a + // Subquery: + // Aggregate: groupBy=[[]], aggr=[[count(inner_table_lv1.a)]] + // Filter: inner_table_lv1.c = outer_ref(outer_table.c) AND () = Int32(1) + // Subquery: + // Aggregate: groupBy=[[]], aggr=[[count(inner_table_lv2.a)]] + // Filter: inner_table_lv2.a = outer_ref(outer_table.a) AND inner_table_lv2.b = outer_ref(inner_table_lv1 + // .b) + // TableScan: inner_table_lv2 + // TableScan: inner_table_lv1 + // TableScan: outer_table + assert_dependent_join_rewrite!(plan, @r" Projection: outer_table.a, outer_table.b, outer_table.c [a:UInt32, b:UInt32, c:UInt32] Filter: outer_table.a > Int32(1) AND __scalar_sq_2.output = outer_table.a [a:UInt32, b:UInt32, c:UInt32, output:Int64] @@ -1102,17 +1324,28 @@ mod tests { .and(in_subquery(col("outer_table.b"), in_sq_level1)), )? 
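        // Assumed SQL shape: two independent, uncorrelated subqueries inside one filter:
        //   SELECT * FROM outer_table
        //   WHERE a > 1
        //     AND EXISTS (SELECT * FROM inner_table_lv1 WHERE a AND b = 1)
        //     AND b IN (SELECT a FROM inner_table_lv1 WHERE c = 2)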
.build()?; + + // Filter: outer_table.a > Int32(1) AND EXISTS () AND outer_table.b IN () + // Subquery: + // Filter: inner_table_lv1.a AND inner_table_lv1.b = Int32(1) + // TableScan: inner_table_lv1 + // Subquery: + // Projection: inner_table_lv1.a + // Filter: inner_table_lv1.c = Int32(2) + // TableScan: inner_table_lv1 + // TableScan: outer_table + assert_dependent_join_rewrite!(plan, @r" -Projection: outer_table.a, outer_table.b, outer_table.c [a:UInt32, b:UInt32, c:UInt32] - Filter: outer_table.a > Int32(1) AND __exists_sq_1.output AND __in_sq_2.output [a:UInt32, b:UInt32, c:UInt32, output:Boolean, output:Boolean] - DependentJoin on [] with expr outer_table.b IN () depth 1 [a:UInt32, b:UInt32, c:UInt32, output:Boolean, output:Boolean] - DependentJoin on [] with expr EXISTS () depth 1 [a:UInt32, b:UInt32, c:UInt32, output:Boolean] - TableScan: outer_table [a:UInt32, b:UInt32, c:UInt32] - Filter: inner_table_lv1.a AND inner_table_lv1.b = Int32(1) [a:UInt32, b:UInt32, c:UInt32] - TableScan: inner_table_lv1 [a:UInt32, b:UInt32, c:UInt32] - Projection: inner_table_lv1.a [a:UInt32] - Filter: inner_table_lv1.c = Int32(2) [a:UInt32, b:UInt32, c:UInt32] - TableScan: inner_table_lv1 [a:UInt32, b:UInt32, c:UInt32] + Projection: outer_table.a, outer_table.b, outer_table.c [a:UInt32, b:UInt32, c:UInt32] + Filter: outer_table.a > Int32(1) AND __exists_sq_1.output AND __in_sq_2.output [a:UInt32, b:UInt32, c:UInt32, output:Boolean, output:Boolean] + DependentJoin on [] with expr outer_table.b IN () depth 1 [a:UInt32, b:UInt32, c:UInt32, output:Boolean, output:Boolean] + DependentJoin on [] with expr EXISTS () depth 1 [a:UInt32, b:UInt32, c:UInt32, output:Boolean] + TableScan: outer_table [a:UInt32, b:UInt32, c:UInt32] + Filter: inner_table_lv1.a AND inner_table_lv1.b = Int32(1) [a:UInt32, b:UInt32, c:UInt32] + TableScan: inner_table_lv1 [a:UInt32, b:UInt32, c:UInt32] + Projection: inner_table_lv1.a [a:UInt32] + Filter: inner_table_lv1.c = Int32(2) [a:UInt32, b:UInt32, c:UInt32] + TableScan: inner_table_lv1 [a:UInt32, b:UInt32, c:UInt32] "); Ok(()) } @@ -1125,14 +1358,14 @@ Projection: outer_table.a, outer_table.b, outer_table.c [a:UInt32, b:UInt32, c:U LogicalPlanBuilder::from(inner_table_lv1) .filter( col("inner_table_lv1.a") - .eq(out_ref_col(ArrowDataType::UInt32, "outer_table.a")) + .eq(out_ref_col(DataType::UInt32, "outer_table.a")) .and( - out_ref_col(ArrowDataType::UInt32, "outer_table.a") + out_ref_col(DataType::UInt32, "outer_table.a") .gt(col("inner_table_lv1.c")), ) .and(col("inner_table_lv1.b").eq(lit(1))) .and( - out_ref_col(ArrowDataType::UInt32, "outer_table.b") + out_ref_col(DataType::UInt32, "outer_table.b") .eq(col("inner_table_lv1.b")), ), )? @@ -1148,15 +1381,24 @@ Projection: outer_table.a, outer_table.b, outer_table.c [a:UInt32, b:UInt32, c:U .and(in_subquery(col("outer_table.c"), sq_level1)), )? 
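        // Assumed SQL shape (the subquery is correlated on both outer_table.a and
        // outer_table.b):
        //   SELECT * FROM outer_table
        //   WHERE a > 1 AND c IN (
        //     SELECT count(a) AS count_a FROM inner_table_lv1
        //     WHERE a = outer_table.a AND outer_table.a > c
        //       AND b = 1 AND outer_table.b = b)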
.build()?; + + // Filter: outer_table.a > Int32(1) AND outer_table.c IN () + // Subquery: + // Projection: count(inner_table_lv1.a) AS count_a + // Aggregate: groupBy=[[]], aggr=[[count(inner_table_lv1.a)]] + // Filter: inner_table_lv1.a = outer_ref(outer_table.a) AND outer_ref(outer_table.a) > inner_table_lv1.c AND inner_table_lv1.b = Int32(1) AND outer_ref(outer_table.b) = inner_table_lv1.b + // TableScan: inner_table_lv1 + // TableScan: outer_table + assert_dependent_join_rewrite!(plan, @r" -Projection: outer_table.a, outer_table.b, outer_table.c [a:UInt32, b:UInt32, c:UInt32] - Filter: outer_table.a > Int32(1) AND __in_sq_1.output [a:UInt32, b:UInt32, c:UInt32, output:Boolean] - DependentJoin on [outer_table.a lvl 1, outer_table.b lvl 1] with expr outer_table.c IN () depth 1 [a:UInt32, b:UInt32, c:UInt32, output:Boolean] - TableScan: outer_table [a:UInt32, b:UInt32, c:UInt32] - Projection: count(inner_table_lv1.a) AS count_a [count_a:Int64] - Aggregate: groupBy=[[]], aggr=[[count(inner_table_lv1.a)]] [count(inner_table_lv1.a):Int64] - Filter: inner_table_lv1.a = outer_ref(outer_table.a) AND outer_ref(outer_table.a) > inner_table_lv1.c AND inner_table_lv1.b = Int32(1) AND outer_ref(outer_table.b) = inner_table_lv1.b [a:UInt32, b:UInt32, c:UInt32] - TableScan: inner_table_lv1 [a:UInt32, b:UInt32, c:UInt32] + Projection: outer_table.a, outer_table.b, outer_table.c [a:UInt32, b:UInt32, c:UInt32] + Filter: outer_table.a > Int32(1) AND __in_sq_1.output [a:UInt32, b:UInt32, c:UInt32, output:Boolean] + DependentJoin on [outer_table.a lvl 1, outer_table.b lvl 1] with expr outer_table.c IN () depth 1 [a:UInt32, b:UInt32, c:UInt32, output:Boolean] + TableScan: outer_table [a:UInt32, b:UInt32, c:UInt32] + Projection: count(inner_table_lv1.a) AS count_a [count_a:Int64] + Aggregate: groupBy=[[]], aggr=[[count(inner_table_lv1.a)]] [count(inner_table_lv1.a):Int64] + Filter: inner_table_lv1.a = outer_ref(outer_table.a) AND outer_ref(outer_table.a) > inner_table_lv1.c AND inner_table_lv1.b = Int32(1) AND outer_ref(outer_table.b) = inner_table_lv1.b [a:UInt32, b:UInt32, c:UInt32] + TableScan: inner_table_lv1 [a:UInt32, b:UInt32, c:UInt32] "); Ok(()) } @@ -1168,33 +1410,43 @@ Projection: outer_table.a, outer_table.b, outer_table.c [a:UInt32, b:UInt32, c:U LogicalPlanBuilder::from(inner_table_lv1) .filter( col("inner_table_lv1.a") - .eq(out_ref_col(ArrowDataType::UInt32, "outer_table.a")) + .eq(out_ref_col(DataType::UInt32, "outer_table.a")) .and( - out_ref_col(ArrowDataType::UInt32, "outer_table.a") + out_ref_col(DataType::UInt32, "outer_table.a") .gt(col("inner_table_lv1.c")), ) .and(col("inner_table_lv1.b").eq(lit(1))) .and( - out_ref_col(ArrowDataType::UInt32, "outer_table.b") + out_ref_col(DataType::UInt32, "outer_table.b") .eq(col("inner_table_lv1.b")), ), )? - .project(vec![out_ref_col(ArrowDataType::UInt32, "outer_table.b") - .alias("outer_b_alias")])? + .project(vec![ + out_ref_col(DataType::UInt32, "outer_table.b").alias("outer_b_alias") + ])? .build()?, ); let plan = LogicalPlanBuilder::from(outer_table.clone()) .filter(col("outer_table.a").gt(lit(1)).and(exists(sq_level1)))? 
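        // Assumed SQL shape; note that the subquery only projects an outer column
        // under an alias:
        //   SELECT * FROM outer_table
        //   WHERE a > 1 AND EXISTS (
        //     SELECT outer_table.b AS outer_b_alias FROM inner_table_lv1
        //     WHERE a = outer_table.a AND outer_table.a > c
        //       AND b = 1 AND outer_table.b = b)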
.build()?; + + // Filter: outer_table.a > Int32(1) AND EXISTS () + // Subquery: + // Projection: outer_ref(outer_table.b) AS outer_b_alias + // Filter: inner_table_lv1.a = outer_ref(outer_table.a) AND outer_ref(outer_table.a) > inner_table_lv1.c AND in + // ner_table_lv1.b = Int32(1) AND outer_ref(outer_table.b) = inner_table_lv1.b + // TableScan: inner_table_lv1 + // TableScan: outer_table + assert_dependent_join_rewrite!(plan, @r" -Projection: outer_table.a, outer_table.b, outer_table.c [a:UInt32, b:UInt32, c:UInt32] - Filter: outer_table.a > Int32(1) AND __exists_sq_1.output [a:UInt32, b:UInt32, c:UInt32, output:Boolean] - DependentJoin on [outer_table.a lvl 1, outer_table.b lvl 1] with expr EXISTS () depth 1 [a:UInt32, b:UInt32, c:UInt32, output:Boolean] - TableScan: outer_table [a:UInt32, b:UInt32, c:UInt32] - Projection: outer_ref(outer_table.b) AS outer_b_alias [outer_b_alias:UInt32;N] - Filter: inner_table_lv1.a = outer_ref(outer_table.a) AND outer_ref(outer_table.a) > inner_table_lv1.c AND inner_table_lv1.b = Int32(1) AND outer_ref(outer_table.b) = inner_table_lv1.b [a:UInt32, b:UInt32, c:UInt32] - TableScan: inner_table_lv1 [a:UInt32, b:UInt32, c:UInt32] + Projection: outer_table.a, outer_table.b, outer_table.c [a:UInt32, b:UInt32, c:UInt32] + Filter: outer_table.a > Int32(1) AND __exists_sq_1.output [a:UInt32, b:UInt32, c:UInt32, output:Boolean] + DependentJoin on [outer_table.a lvl 1, outer_table.b lvl 1] with expr EXISTS () depth 1 [a:UInt32, b:UInt32, c:UInt32, output:Boolean] + TableScan: outer_table [a:UInt32, b:UInt32, c:UInt32] + Projection: outer_ref(outer_table.b) AS outer_b_alias [outer_b_alias:UInt32;N] + Filter: inner_table_lv1.a = outer_ref(outer_table.a) AND outer_ref(outer_table.a) > inner_table_lv1.c AND inner_table_lv1.b = Int32(1) AND outer_ref(outer_table.b) = inner_table_lv1.b [a:UInt32, b:UInt32, c:UInt32] + TableScan: inner_table_lv1 [a:UInt32, b:UInt32, c:UInt32] "); Ok(()) } @@ -1215,14 +1467,21 @@ Projection: outer_table.a, outer_table.b, outer_table.c [a:UInt32, b:UInt32, c:U .filter(col("outer_table.a").gt(lit(1)).and(exists(sq_level1)))? 
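        // This EXISTS subquery has no outer references at all, so the rewrite is still
        // expected to produce a dependent join, just with an empty correlated-column
        // list (the `DependentJoin on []` line in the snapshot below).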
.build()?; + // Filter: outer_table.a > Int32(1) AND EXISTS () + // Subquery: + // Projection: inner_table_lv1.b, inner_table_lv1.a + // Filter: inner_table_lv1.b = Int32(1) + // TableScan: inner_table_lv1 + // TableScan: outer_table + assert_dependent_join_rewrite!(plan, @r" -Projection: outer_table.a, outer_table.b, outer_table.c [a:UInt32, b:UInt32, c:UInt32] - Filter: outer_table.a > Int32(1) AND __exists_sq_1.output [a:UInt32, b:UInt32, c:UInt32, output:Boolean] - DependentJoin on [] with expr EXISTS () depth 1 [a:UInt32, b:UInt32, c:UInt32, output:Boolean] - TableScan: outer_table [a:UInt32, b:UInt32, c:UInt32] - Projection: inner_table_lv1.b, inner_table_lv1.a [b:UInt32, a:UInt32] - Filter: inner_table_lv1.b = Int32(1) [a:UInt32, b:UInt32, c:UInt32] - TableScan: inner_table_lv1 [a:UInt32, b:UInt32, c:UInt32] + Projection: outer_table.a, outer_table.b, outer_table.c [a:UInt32, b:UInt32, c:UInt32] + Filter: outer_table.a > Int32(1) AND __exists_sq_1.output [a:UInt32, b:UInt32, c:UInt32, output:Boolean] + DependentJoin on [] with expr EXISTS () depth 1 [a:UInt32, b:UInt32, c:UInt32, output:Boolean] + TableScan: outer_table [a:UInt32, b:UInt32, c:UInt32] + Projection: inner_table_lv1.b, inner_table_lv1.a [b:UInt32, a:UInt32] + Filter: inner_table_lv1.b = Int32(1) [a:UInt32, b:UInt32, c:UInt32] + TableScan: inner_table_lv1 [a:UInt32, b:UInt32, c:UInt32] "); Ok(()) @@ -1245,14 +1504,22 @@ Projection: outer_table.a, outer_table.b, outer_table.c [a:UInt32, b:UInt32, c:U .and(in_subquery(col("outer_table.c"), sq_level1)), )? .build()?; + + // Filter: outer_table.a > Int32(1) AND outer_table.c IN () + // Subquery: + // Projection: inner_table_lv1.b + // Filter: inner_table_lv1.b = Int32(1) + // TableScan: inner_table_lv1 + // TableScan: outer_table + assert_dependent_join_rewrite!(plan, @r" -Projection: outer_table.a, outer_table.b, outer_table.c [a:UInt32, b:UInt32, c:UInt32] - Filter: outer_table.a > Int32(1) AND __in_sq_1.output [a:UInt32, b:UInt32, c:UInt32, output:Boolean] - DependentJoin on [] with expr outer_table.c IN () depth 1 [a:UInt32, b:UInt32, c:UInt32, output:Boolean] - TableScan: outer_table [a:UInt32, b:UInt32, c:UInt32] - Projection: inner_table_lv1.b [b:UInt32] - Filter: inner_table_lv1.b = Int32(1) [a:UInt32, b:UInt32, c:UInt32] - TableScan: inner_table_lv1 [a:UInt32, b:UInt32, c:UInt32] + Projection: outer_table.a, outer_table.b, outer_table.c [a:UInt32, b:UInt32, c:UInt32] + Filter: outer_table.a > Int32(1) AND __in_sq_1.output [a:UInt32, b:UInt32, c:UInt32, output:Boolean] + DependentJoin on [] with expr outer_table.c IN () depth 1 [a:UInt32, b:UInt32, c:UInt32, output:Boolean] + TableScan: outer_table [a:UInt32, b:UInt32, c:UInt32] + Projection: inner_table_lv1.b [b:UInt32] + Filter: inner_table_lv1.b = Int32(1) [a:UInt32, b:UInt32, c:UInt32] + TableScan: inner_table_lv1 [a:UInt32, b:UInt32, c:UInt32] "); Ok(()) @@ -1265,19 +1532,20 @@ Projection: outer_table.a, outer_table.b, outer_table.c [a:UInt32, b:UInt32, c:U LogicalPlanBuilder::from(inner_table_lv1) .filter( col("inner_table_lv1.a") - .eq(out_ref_col(ArrowDataType::UInt32, "outer_table.a")) + .eq(out_ref_col(DataType::UInt32, "outer_table.a")) .and( - out_ref_col(ArrowDataType::UInt32, "outer_table.a") + out_ref_col(DataType::UInt32, "outer_table.a") .gt(col("inner_table_lv1.c")), ) .and(col("inner_table_lv1.b").eq(lit(1))) .and( - out_ref_col(ArrowDataType::UInt32, "outer_table.b") + out_ref_col(DataType::UInt32, "outer_table.b") .eq(col("inner_table_lv1.b")), ), )? 
- .project(vec![out_ref_col(ArrowDataType::UInt32, "outer_table.b") - .alias("outer_b_alias")])? + .project(vec![ + out_ref_col(DataType::UInt32, "outer_table.b").alias("outer_b_alias") + ])? .build()?, ); @@ -1288,14 +1556,22 @@ Projection: outer_table.a, outer_table.b, outer_table.c [a:UInt32, b:UInt32, c:U .and(in_subquery(col("outer_table.c"), sq_level1)), )? .build()?; + + // Filter: outer_table.a > Int32(1) AND outer_table.c IN () + // Subquery: + // Projection: outer_ref(outer_table.b) AS outer_b_alias + // Filter: inner_table_lv1.a = outer_ref(outer_table.a) AND outer_ref(outer_table.a) > inner_table_lv1.c AND inner_table_lv1.b = Int32(1) AND outer_ref(outer_table.b) = inner_table_lv1.b + // TableScan: inner_table_lv1 + // TableScan: outer_table + assert_dependent_join_rewrite!(plan, @r" -Projection: outer_table.a, outer_table.b, outer_table.c [a:UInt32, b:UInt32, c:UInt32] - Filter: outer_table.a > Int32(1) AND __in_sq_1.output [a:UInt32, b:UInt32, c:UInt32, output:Boolean] - DependentJoin on [outer_table.a lvl 1, outer_table.b lvl 1] with expr outer_table.c IN () depth 1 [a:UInt32, b:UInt32, c:UInt32, output:Boolean] - TableScan: outer_table [a:UInt32, b:UInt32, c:UInt32] - Projection: outer_ref(outer_table.b) AS outer_b_alias [outer_b_alias:UInt32;N] - Filter: inner_table_lv1.a = outer_ref(outer_table.a) AND outer_ref(outer_table.a) > inner_table_lv1.c AND inner_table_lv1.b = Int32(1) AND outer_ref(outer_table.b) = inner_table_lv1.b [a:UInt32, b:UInt32, c:UInt32] - TableScan: inner_table_lv1 [a:UInt32, b:UInt32, c:UInt32] + Projection: outer_table.a, outer_table.b, outer_table.c [a:UInt32, b:UInt32, c:UInt32] + Filter: outer_table.a > Int32(1) AND __in_sq_1.output [a:UInt32, b:UInt32, c:UInt32, output:Boolean] + DependentJoin on [outer_table.a lvl 1, outer_table.b lvl 1] with expr outer_table.c IN () depth 1 [a:UInt32, b:UInt32, c:UInt32, output:Boolean] + TableScan: outer_table [a:UInt32, b:UInt32, c:UInt32] + Projection: outer_ref(outer_table.b) AS outer_b_alias [outer_b_alias:UInt32;N] + Filter: inner_table_lv1.a = outer_ref(outer_table.a) AND outer_ref(outer_table.a) > inner_table_lv1.c AND inner_table_lv1.b = Int32(1) AND outer_ref(outer_table.b) = inner_table_lv1.b [a:UInt32, b:UInt32, c:UInt32] + TableScan: inner_table_lv1 [a:UInt32, b:UInt32, c:UInt32] "); Ok(()) } @@ -1308,7 +1584,7 @@ Projection: outer_table.a, outer_table.b, outer_table.c [a:UInt32, b:UInt32, c:U LogicalPlanBuilder::from(inner_table_lv1) .filter( col("inner_table_lv1.a") - .eq(out_ref_col(ArrowDataType::UInt32, "outer_table_alias.a")), + .eq(out_ref_col(DataType::UInt32, "outer_table_alias.a")), )? .aggregate(Vec::::new(), vec![count(col("inner_table_lv1.a"))])? .project(vec![count(col("inner_table_lv1.a")).alias("count_a")])? @@ -1323,6 +1599,16 @@ Projection: outer_table.a, outer_table.b, outer_table.c [a:UInt32, b:UInt32, c:U .and(in_subquery(col("outer_table.c"), sq_level1)), )? 
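        // Assumed SQL shape; here the correlated column is referenced through the
        // table alias rather than the base table name:
        //   SELECT * FROM outer_table AS outer_table_alias
        //   WHERE a > 1 AND c IN (
        //     SELECT count(a) AS count_a FROM inner_table_lv1
        //     WHERE inner_table_lv1.a = outer_table_alias.a)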
.build()?; + + // Filter: outer_table.a > Int32(1) AND outer_table.c IN () + // Subquery: + // Projection: count(inner_table_lv1.a) AS count_a + // Aggregate: groupBy=[[]], aggr=[[count(inner_table_lv1.a)]] + // Filter: inner_table_lv1.a = outer_ref(outer_table_alias.a) + // TableScan: inner_table_lv1 + // SubqueryAlias: outer_table_alias + // TableScan: outer_table + assert_dependent_join_rewrite!(plan, @r" Projection: outer_table_alias.a, outer_table_alias.b, outer_table_alias.c [a:UInt32, b:UInt32, c:UInt32] Filter: outer_table.a > Int32(1) AND __in_sq_1.output [a:UInt32, b:UInt32, c:UInt32, output:Boolean] @@ -1336,4 +1622,520 @@ Projection: outer_table.a, outer_table.b, outer_table.c [a:UInt32, b:UInt32, c:U "); Ok(()) } +<<<<<<< HEAD +======= + #[test] + fn decorrelate_with_in_subquery_has_dependent_column() -> Result<()> { + let outer_table = test_table_scan_with_name("outer_table")?; + let inner_table_lv1 = test_table_scan_with_name("inner_table_lv1")?; + let sq_level1 = Arc::new( + LogicalPlanBuilder::from(inner_table_lv1) + .filter( + col("inner_table_lv1.a") + .eq(out_ref_col(ArrowDataType::UInt32, "outer_table.a")) + .and( + out_ref_col(ArrowDataType::UInt32, "outer_table.a") + .gt(col("inner_table_lv1.c")), + ) + .and(col("inner_table_lv1.b").eq(lit(1))) + .and( + out_ref_col(ArrowDataType::UInt32, "outer_table.b") + .eq(col("inner_table_lv1.b")), + ), + )? + .project(vec![out_ref_col(ArrowDataType::UInt32, "outer_table.b")])? + .build()?, + ); + + let plan = LogicalPlanBuilder::from(outer_table.clone()) + .filter( + col("outer_table.a") + .gt(lit(1)) + .and(in_subquery(col("outer_table.c"), sq_level1)), + )? + .build()?; + let dec = Decorrelation::new(); + let ctx: Box = Box::new(OptimizerContext::new()); + let plan = dec.rewrite(plan, ctx.as_ref())?.data; + assert_decorrelate!(plan, @r" + Projection: outer_table.a, outer_table.b, outer_table.c [a:UInt32, b:UInt32, c:UInt32] + Filter: outer_table.a > Int32(1) AND __in_sq_1.output [a:UInt32, b:UInt32, c:UInt32, __in_sq_1.output:Boolean] + Projection: outer_table.a, outer_table.b, outer_table.c, delim_scan_1.mark AS __in_sq_1.output [a:UInt32, b:UInt32, c:UInt32, __in_sq_1.output:Boolean] + LeftMark Join: Filter: outer_table.c = outer_ref(outer_table.b) AND outer_table.a IS NOT DISTINCT FROM delim_scan_1.a AND outer_table.b IS NOT DISTINCT FROM delim_scan_1.b [a:UInt32, b:UInt32, c:UInt32, mark:Boolean] + TableScan: outer_table [a:UInt32, b:UInt32, c:UInt32] + Projection: delim_scan_1.b, delim_scan_1.a, delim_scan_1.b [outer_ref(outer_table.b):UInt32;N, a:UInt32;N, b:UInt32;N] + Filter: inner_table_lv1.a = delim_scan_1.a AND delim_scan_1.a > inner_table_lv1.c AND inner_table_lv1.b = Int32(1) AND delim_scan_1.b = inner_table_lv1.b [a:UInt32, b:UInt32, c:UInt32, a:UInt32;N, b:UInt32;N] + Cross Join: [a:UInt32, b:UInt32, c:UInt32, a:UInt32;N, b:UInt32;N] + TableScan: inner_table_lv1 [a:UInt32, b:UInt32, c:UInt32] + SubqueryAlias: delim_scan_1 [a:UInt32;N, b:UInt32;N] + DelimGet: outer_table.a, outer_table.b [a:UInt32;N, b:UInt32;N] + "); + + Ok(()) + } + + // from duckdb test: https://github.com/duckdb/duckdb/blob/main/test/sql/subquery/any_all/test_correlated_any_all.test + #[test] + fn test_correlated_any_all_1() -> Result<()> { + // CREATE TABLE integers(i INTEGER); + // SELECT i = ANY( + // SELECT i + // FROM integers + // WHERE i = i1.i + // ) + // FROM integers i1 + // ORDER BY i; + + // Create base table + let integers = test_table_with_columns("integers", &[("i", DataType::Int32)])?; + + // Build correlated subquery: + // 
SELECT i FROM integers WHERE i = i1.i + let subquery = Arc::new( + LogicalPlanBuilder::from(integers.clone()) + .filter(col("integers.i").eq(out_ref_col(DataType::Int32, "i1.i")))? + .project(vec![col("integers.i")])? + .build()?, + ); + + // Build main query with table alias i1 + let plan = LogicalPlanBuilder::from(integers) + .alias("i1")? // Alias the table as i1 + .filter( + // i = ANY(subquery) + Expr::InSubquery(InSubquery { + expr: Box::new(col("i1.i")), + subquery: Subquery { + subquery, + outer_ref_columns: vec![out_ref_col(DataType::Int32, "i1.i")], + spans: Spans::new(), + }, + negated: false, + }), + )? + .sort(vec![SortExpr::new(col("i1.i"), false, false)])? // ORDER BY i + .build()?; + + // original plan: + // Sort: i1.i DESC NULLS LAST + // Filter: i1.i IN () + // Subquery: + // Projection: integers.i + // Filter: integers.i = outer_ref(i1.i) + // TableScan: integers + // SubqueryAlias: i1 + // TableScan: integers + + // Verify the rewrite result + assert_dependent_join_rewrite!( + plan, + @r#" + Sort: i1.i DESC NULLS LAST [i:Int32] + Projection: i1.i [i:Int32] + Filter: __in_sq_1.output [i:Int32, output:Boolean] + DependentJoin on [i1.i lvl 1] with expr i1.i IN () depth 1 [i:Int32, output:Boolean] + SubqueryAlias: i1 [i:Int32] + TableScan: integers [i:Int32] + Projection: integers.i [i:Int32] + Filter: integers.i = outer_ref(i1.i) [i:Int32] + TableScan: integers [i:Int32] + "# + ); + + Ok(()) + } + + // from duckdb: https://github.com/duckdb/duckdb/blob/main/test/sql/subquery/any_all/issue_2999.test + #[test] + fn test_any_subquery_with_derived_join() -> Result<()> { + // SQL equivalent: + // CREATE TABLE t0 (c0 INT); + // CREATE TABLE t1 (c0 INT); + // SELECT 1 = ANY( + // SELECT 1 + // FROM t1 + // JOIN ( + // SELECT count(*) + // GROUP BY t0.c0 + // ) AS x(x) ON TRUE + // ) + // FROM t0; + + // Create base tables + let t0 = test_table_with_columns("t0", &[("c0", DataType::Int32)])?; + let t1 = test_table_with_columns("t1", &[("c0", DataType::Int32)])?; + + // Build derived table subquery: + // SELECT count(*) GROUP BY t0.c0 + let derived_table = Arc::new( + LogicalPlanBuilder::from(t1.clone()) + .aggregate( + vec![out_ref_col(DataType::Int32, "t0.c0")], // GROUP BY t0.c0 + vec![count(lit(1))], // count(*) + )? + .build()?, + ); + + // Build the join subquery: + // SELECT 1 FROM t1 JOIN (derived_table) x(x) ON TRUE + let join_subquery = Arc::new( + LogicalPlanBuilder::from(t1) + .join_on( + LogicalPlan::Subquery(Subquery { + subquery: derived_table, + outer_ref_columns: vec![out_ref_col(DataType::Int32, "t0.c0")], + spans: Spans::new(), + }), + JoinType::Inner, + vec![lit(true)], // ON TRUE + )? + .project(vec![lit(1)])? // SELECT 1 + .build()?, + ); + + // Build main query + let plan = LogicalPlanBuilder::from(t0) + .filter( + // 1 = ANY(subquery) + Expr::InSubquery(InSubquery { + expr: Box::new(lit(1)), + subquery: Subquery { + subquery: join_subquery, + outer_ref_columns: vec![out_ref_col(DataType::Int32, "t0.c0")], + spans: Spans::new(), + }, + negated: false, + }), + )? 
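        // `1 = ANY(<subquery>)` from the SQL above is expressed directly as an
        // Expr::InSubquery with lit(1) as the probe expression; the derived aggregate
        // is wrapped in LogicalPlan::Subquery so that its GROUP BY on outer_ref(t0.c0)
        // is picked up as a nested correlation (see `t0.c0 lvl 2` in the snapshot
        // below).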
+ .build()?; + + // Filter: Int32(1) IN () + // Subquery: + // Projection: Int32(1) + // Inner Join: Filter: Boolean(true) + // TableScan: t1 + // Subquery: + // Aggregate: groupBy=[[outer_ref(t0.c0)]], aggr=[[count(Int32(1))]] + // TableScan: t1 + // TableScan: t0 + + // Verify the rewrite result + assert_dependent_join_rewrite!( + plan, + @r#" + Projection: t0.c0 [c0:Int32] + Filter: __in_sq_2.output [c0:Int32, output:Boolean] + DependentJoin on [t0.c0 lvl 2] with expr Int32(1) IN () depth 1 [c0:Int32, output:Boolean] + TableScan: t0 [c0:Int32] + Projection: Int32(1) [Int32(1):Int32] + DependentJoin on [] lateral Inner join with Boolean(true) depth 2 [c0:Int32] + TableScan: t1 [c0:Int32] + Aggregate: groupBy=[[outer_ref(t0.c0)]], aggr=[[count(Int32(1))]] [outer_ref(t0.c0):Int32;N, count(Int32(1)):Int64] + TableScan: t1 [c0:Int32] + "# + ); + + Ok(()) + } + + #[test] + fn decorrelate_two_subqueries_at_the_same_level() -> Result<()> { + let outer_table = test_table_scan_with_name("outer_table")?; + let inner_table_lv1 = test_table_scan_with_name("inner_table_lv1")?; + let in_sq_level1 = Arc::new( + LogicalPlanBuilder::from(inner_table_lv1.clone()) + .filter(col("inner_table_lv1.c").eq(lit(2)))? + .project(vec![col("inner_table_lv1.a")])? + .build()?, + ); + let exist_sq_level1 = Arc::new( + LogicalPlanBuilder::from(inner_table_lv1) + .filter( + col("inner_table_lv1.a").and(col("inner_table_lv1.b").eq(lit(1))), + )? + .build()?, + ); + + let plan = LogicalPlanBuilder::from(outer_table.clone()) + .filter( + col("outer_table.a") + .gt(lit(1)) + .and(exists(exist_sq_level1)) + .and(in_subquery(col("outer_table.b"), in_sq_level1)), + )? + .build()?; + assert_decorrelate!(plan, @r" + Projection: outer_table.a, outer_table.b, outer_table.c [a:UInt32, b:UInt32, c:UInt32] + Filter: outer_table.a > Int32(1) AND __exists_sq_1.output AND __in_sq_2.output [a:UInt32, b:UInt32, c:UInt32, __exists_sq_1.output:Boolean, __in_sq_2.output:Boolean] + Projection: outer_table.a, outer_table.b, outer_table.c, __exists_sq_1.output, inner_table_lv1.mark AS __in_sq_2.output [a:UInt32, b:UInt32, c:UInt32, __exists_sq_1.output:Boolean, __in_sq_2.output:Boolean] + LeftMark Join: Filter: outer_table.b = inner_table_lv1.a [a:UInt32, b:UInt32, c:UInt32, __exists_sq_1.output:Boolean, mark:Boolean] + Projection: outer_table.a, outer_table.b, outer_table.c, inner_table_lv1.mark AS __exists_sq_1.output [a:UInt32, b:UInt32, c:UInt32, __exists_sq_1.output:Boolean] + LeftMark Join: Filter: Boolean(true) [a:UInt32, b:UInt32, c:UInt32, mark:Boolean] + TableScan: outer_table [a:UInt32, b:UInt32, c:UInt32] + Cross Join: [a:UInt32, b:UInt32, c:UInt32] + Filter: inner_table_lv1.a AND inner_table_lv1.b = Int32(1) [a:UInt32, b:UInt32, c:UInt32] + TableScan: inner_table_lv1 [a:UInt32, b:UInt32, c:UInt32] + SubqueryAlias: delim_scan_1 [] + DelimGet: [] + Projection: inner_table_lv1.a [a:UInt32] + Cross Join: [a:UInt32, b:UInt32, c:UInt32] + Filter: inner_table_lv1.c = Int32(2) [a:UInt32, b:UInt32, c:UInt32] + TableScan: inner_table_lv1 [a:UInt32, b:UInt32, c:UInt32] + SubqueryAlias: delim_scan_1 [] + DelimGet: [] + "); + Ok(()) + } + #[test] + fn decorrelate_join_in_subquery_with_count_depth_1() -> Result<()> { + let outer_table = test_table_scan_with_name("outer_table")?; + let inner_table_lv1 = test_table_scan_with_name("inner_table_lv1")?; + let sq_level1 = Arc::new( + LogicalPlanBuilder::from(inner_table_lv1) + .filter( + col("inner_table_lv1.a") + .eq(out_ref_col(ArrowDataType::UInt32, "outer_table.a")) + .and( + 
out_ref_col(ArrowDataType::UInt32, "outer_table.a") + .gt(col("inner_table_lv1.c")), + ) + .and(col("inner_table_lv1.b").eq(lit(1))) + .and( + out_ref_col(ArrowDataType::UInt32, "outer_table.b") + .eq(col("inner_table_lv1.b")), + ), + )? + .aggregate(Vec::::new(), vec![count(col("inner_table_lv1.a"))])? + // TODO: if uncomment this the test fail + // .project(vec![count(col("inner_table_lv1.a")).alias("count_a")])? + .build()?, + ); + + let plan = LogicalPlanBuilder::from(outer_table.clone()) + .filter( + col("outer_table.a") + .gt(lit(1)) + .and(in_subquery(col("outer_table.c"), sq_level1)), + )? + .build()?; + assert_decorrelate!(plan, @r" + Projection: outer_table.a, outer_table.b, outer_table.c [a:UInt32, b:UInt32, c:UInt32] + Filter: outer_table.a > Int32(1) AND __in_sq_1.output [a:UInt32, b:UInt32, c:UInt32, __in_sq_1.output:Boolean] + Projection: outer_table.a, outer_table.b, outer_table.c, mark AS __in_sq_1.output [a:UInt32, b:UInt32, c:UInt32, __in_sq_1.output:Boolean] + LeftMark Join: Filter: outer_table.c = CASE WHEN count(inner_table_lv1.a) IS NULL THEN Int32(0) ELSE count(inner_table_lv1.a) END AND outer_table.a IS NOT DISTINCT FROM delim_scan_2.a AND outer_table.b IS NOT DISTINCT FROM delim_scan_2.b [a:UInt32, b:UInt32, c:UInt32, mark:Boolean] + TableScan: outer_table [a:UInt32, b:UInt32, c:UInt32] + Projection: CASE WHEN count(inner_table_lv1.a) IS NULL THEN Int32(0) ELSE count(inner_table_lv1.a) END, delim_scan_2.b, delim_scan_2.a, delim_scan_1.a, delim_scan_1.b [CASE WHEN count(inner_table_lv1.a) IS NULL THEN Int32(0) ELSE count(inner_table_lv1.a) END:Int32, b:UInt32;N, a:UInt32;N, a:UInt32;N, b:UInt32;N] + Inner Join: Filter: outer_table.a IS NOT DISTINCT FROM delim_scan_1.a AND outer_table.b IS NOT DISTINCT FROM delim_scan_1.b [count(inner_table_lv1.a):Int64, b:UInt32;N, a:UInt32;N, a:UInt32;N, b:UInt32;N] + Projection: count(inner_table_lv1.a), delim_scan_2.b, delim_scan_2.a [count(inner_table_lv1.a):Int64, b:UInt32;N, a:UInt32;N] + Aggregate: groupBy=[[delim_scan_2.a, delim_scan_2.b]], aggr=[[count(inner_table_lv1.a)]] [a:UInt32;N, b:UInt32;N, count(inner_table_lv1.a):Int64] + Filter: inner_table_lv1.a = delim_scan_2.a AND delim_scan_2.a > inner_table_lv1.c AND inner_table_lv1.b = Int32(1) AND delim_scan_2.b = inner_table_lv1.b [a:UInt32, b:UInt32, c:UInt32, a:UInt32;N, b:UInt32;N] + Cross Join: [a:UInt32, b:UInt32, c:UInt32, a:UInt32;N, b:UInt32;N] + TableScan: inner_table_lv1 [a:UInt32, b:UInt32, c:UInt32] + SubqueryAlias: delim_scan_2 [a:UInt32;N, b:UInt32;N] + DelimGet: outer_table.a, outer_table.b [a:UInt32;N, b:UInt32;N] + SubqueryAlias: delim_scan_1 [a:UInt32;N, b:UInt32;N] + DelimGet: outer_table.a, outer_table.b [a:UInt32;N, b:UInt32;N] + "); + Ok(()) + } + #[test] + fn decorrelated_two_nested_subqueries() -> Result<()> { + let outer_table = test_table_scan_with_name("outer_table")?; + let inner_table_lv1 = test_table_scan_with_name("inner_table_lv1")?; + + let inner_table_lv2 = test_table_scan_with_name("inner_table_lv2")?; + let scalar_sq_level2 = + Arc::new( + LogicalPlanBuilder::from(inner_table_lv2) + .filter( + col("inner_table_lv2.a") + .eq(out_ref_col(ArrowDataType::UInt32, "outer_table.a")) + .and(col("inner_table_lv2.b").eq(out_ref_col( + ArrowDataType::UInt32, + "inner_table_lv1.b", + ))), + )? + .aggregate(Vec::::new(), vec![count(col("inner_table_lv2.a"))])? 
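                    // Assumed SQL shape for the whole nested case exercised below:
                    //   SELECT * FROM outer_table
                    //   WHERE a > 1 AND (
                    //     SELECT count(a) FROM inner_table_lv1
                    //     WHERE c = outer_table.c
                    //       AND (SELECT count(a) FROM inner_table_lv2
                    //            WHERE a = outer_table.a
                    //              AND b = inner_table_lv1.b) = 1
                    //   ) = outer_table.a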
+ .build()?, + ); + let scalar_sq_level1 = Arc::new( + LogicalPlanBuilder::from(inner_table_lv1.clone()) + .filter( + col("inner_table_lv1.c") + .eq(out_ref_col(ArrowDataType::UInt32, "outer_table.c")) + .and(scalar_subquery(scalar_sq_level2).eq(lit(1))), + )? + .aggregate(Vec::::new(), vec![count(col("inner_table_lv1.a"))])? + .build()?, + ); + + let plan = LogicalPlanBuilder::from(outer_table.clone()) + .filter( + col("outer_table.a") + .gt(lit(1)) + .and(scalar_subquery(scalar_sq_level1).eq(col("outer_table.a"))), + )? + .build()?; + assert_decorrelate!(plan, @r" + Projection: outer_table.a, outer_table.b, outer_table.c [a:UInt32, b:UInt32, c:UInt32] + Filter: outer_table.a > Int32(1) AND __scalar_sq_2.output = outer_table.a [a:UInt32, b:UInt32, c:UInt32, CASE WHEN count(inner_table_lv1.a) IS NULL THEN Int32(0) ELSE count(inner_table_lv1.a) END:Int32;N, c:UInt32;N, a:UInt32;N, a:UInt32;N, c:UInt32;N, __scalar_sq_2.output:Int32;N] + Projection: outer_table.a, outer_table.b, outer_table.c, CASE WHEN count(inner_table_lv1.a) IS NULL THEN Int32(0) ELSE count(inner_table_lv1.a) END, delim_scan_2.c, delim_scan_2.a, delim_scan_1.a, delim_scan_1.c, CASE WHEN count(inner_table_lv1.a) IS NULL THEN Int32(0) ELSE count(inner_table_lv1.a) END AS __scalar_sq_2.output [a:UInt32, b:UInt32, c:UInt32, CASE WHEN count(inner_table_lv1.a) IS NULL THEN Int32(0) ELSE count(inner_table_lv1.a) END:Int32;N, c:UInt32;N, a:UInt32;N, a:UInt32;N, c:UInt32;N, __scalar_sq_2.output:Int32;N] + Left Join: Filter: outer_table.a IS NOT DISTINCT FROM delim_scan_2.a AND outer_table.c IS NOT DISTINCT FROM delim_scan_2.c [a:UInt32, b:UInt32, c:UInt32, CASE WHEN count(inner_table_lv1.a) IS NULL THEN Int32(0) ELSE count(inner_table_lv1.a) END:Int32;N, c:UInt32;N, a:UInt32;N, a:UInt32;N, c:UInt32;N] + TableScan: outer_table [a:UInt32, b:UInt32, c:UInt32] + Projection: CASE WHEN count(inner_table_lv1.a) IS NULL THEN Int32(0) ELSE count(inner_table_lv1.a) END, delim_scan_2.c, delim_scan_2.a, delim_scan_1.a, delim_scan_1.c [CASE WHEN count(inner_table_lv1.a) IS NULL THEN Int32(0) ELSE count(inner_table_lv1.a) END:Int32, c:UInt32;N, a:UInt32;N, a:UInt32;N, c:UInt32;N] + Inner Join: Filter: outer_table.a IS NOT DISTINCT FROM delim_scan_1.a AND outer_table.c IS NOT DISTINCT FROM delim_scan_1.c [count(inner_table_lv1.a):Int64, c:UInt32;N, a:UInt32;N, a:UInt32;N, c:UInt32;N] + Projection: count(inner_table_lv1.a), delim_scan_2.c, delim_scan_2.a [count(inner_table_lv1.a):Int64, c:UInt32;N, a:UInt32;N] + Aggregate: groupBy=[[delim_scan_2.a, delim_scan_2.c]], aggr=[[count(inner_table_lv1.a)]] [a:UInt32;N, c:UInt32;N, count(inner_table_lv1.a):Int64] + Projection: inner_table_lv1.a, inner_table_lv1.b, inner_table_lv1.c, delim_scan_2.a, delim_scan_2.c [a:UInt32, b:UInt32, c:UInt32, a:UInt32;N, c:UInt32;N] + Filter: inner_table_lv1.c = delim_scan_2.c AND __scalar_sq_1.output = Int32(1) [a:UInt32, b:UInt32, c:UInt32, a:UInt32;N, c:UInt32;N, CASE WHEN count(inner_table_lv2.a) IS NULL THEN Int32(0) ELSE count(inner_table_lv2.a) END:Int32;N, c:UInt32;N, a:UInt32;N, b:UInt32;N, b:UInt32;N, a:UInt32;N, c:UInt32;N, __scalar_sq_1.output:Int32;N] + Projection: inner_table_lv1.a, inner_table_lv1.b, inner_table_lv1.c, delim_scan_2.a, delim_scan_2.c, CASE WHEN count(inner_table_lv2.a) IS NULL THEN Int32(0) ELSE count(inner_table_lv2.a) END, delim_scan_4.c, delim_scan_4.a, delim_scan_4.b, delim_scan_3.b, delim_scan_3.a, delim_scan_3.c, CASE WHEN count(inner_table_lv2.a) IS NULL THEN Int32(0) ELSE count(inner_table_lv2.a) END AS __scalar_sq_1.output 
[a:UInt32, b:UInt32, c:UInt32, a:UInt32;N, c:UInt32;N, CASE WHEN count(inner_table_lv2.a) IS NULL THEN Int32(0) ELSE count(inner_table_lv2.a) END:Int32;N, c:UInt32;N, a:UInt32;N, b:UInt32;N, b:UInt32;N, a:UInt32;N, c:UInt32;N, __scalar_sq_1.output:Int32;N] + Left Join: Filter: inner_table_lv1.b IS NOT DISTINCT FROM delim_scan_4.b [a:UInt32, b:UInt32, c:UInt32, a:UInt32;N, c:UInt32;N, CASE WHEN count(inner_table_lv2.a) IS NULL THEN Int32(0) ELSE count(inner_table_lv2.a) END:Int32;N, c:UInt32;N, a:UInt32;N, b:UInt32;N, b:UInt32;N, a:UInt32;N, c:UInt32;N] + Cross Join: [a:UInt32, b:UInt32, c:UInt32, a:UInt32;N, c:UInt32;N] + TableScan: inner_table_lv1 [a:UInt32, b:UInt32, c:UInt32] + SubqueryAlias: delim_scan_2 [a:UInt32;N, c:UInt32;N] + DelimGet: outer_table.a, outer_table.c [a:UInt32;N, c:UInt32;N] + Projection: CASE WHEN count(inner_table_lv2.a) IS NULL THEN Int32(0) ELSE count(inner_table_lv2.a) END, delim_scan_4.c, delim_scan_4.a, delim_scan_4.b, delim_scan_3.b, delim_scan_3.a, delim_scan_3.c [CASE WHEN count(inner_table_lv2.a) IS NULL THEN Int32(0) ELSE count(inner_table_lv2.a) END:Int32, c:UInt32;N, a:UInt32;N, b:UInt32;N, b:UInt32;N, a:UInt32;N, c:UInt32;N] + Inner Join: Filter: inner_table_lv1.b IS NOT DISTINCT FROM delim_scan_3.b AND outer_table.a IS NOT DISTINCT FROM delim_scan_3.a AND outer_table.c IS NOT DISTINCT FROM delim_scan_3.c [count(inner_table_lv2.a):Int64, c:UInt32;N, a:UInt32;N, b:UInt32;N, b:UInt32;N, a:UInt32;N, c:UInt32;N] + Projection: count(inner_table_lv2.a), delim_scan_4.c, delim_scan_4.a, delim_scan_4.b [count(inner_table_lv2.a):Int64, c:UInt32;N, a:UInt32;N, b:UInt32;N] + Aggregate: groupBy=[[delim_scan_4.b, delim_scan_4.a, delim_scan_4.c]], aggr=[[count(inner_table_lv2.a)]] [b:UInt32;N, a:UInt32;N, c:UInt32;N, count(inner_table_lv2.a):Int64] + Filter: inner_table_lv2.a = delim_scan_4.a AND inner_table_lv2.b = delim_scan_4.b [a:UInt32, b:UInt32, c:UInt32, b:UInt32;N, a:UInt32;N, c:UInt32;N] + Cross Join: [a:UInt32, b:UInt32, c:UInt32, b:UInt32;N, a:UInt32;N, c:UInt32;N] + TableScan: inner_table_lv2 [a:UInt32, b:UInt32, c:UInt32] + SubqueryAlias: delim_scan_4 [b:UInt32;N, a:UInt32;N, c:UInt32;N] + DelimGet: inner_table_lv1.b, outer_table.a, outer_table.c [b:UInt32;N, a:UInt32;N, c:UInt32;N] + SubqueryAlias: delim_scan_3 [b:UInt32;N, a:UInt32;N, c:UInt32;N] + DelimGet: inner_table_lv1.b, outer_table.a, outer_table.c [b:UInt32;N, a:UInt32;N, c:UInt32;N] + SubqueryAlias: delim_scan_1 [a:UInt32;N, c:UInt32;N] + DelimGet: outer_table.a, outer_table.c [a:UInt32;N, c:UInt32;N] + "); + Ok(()) + } + + fn test_simple_correlated_agg_subquery() -> Result<()> { + // CREATE TABLE t(a INT, b INT); + // SELECT a, + // (SELECT SUM(b) + // FROM t t2 + // WHERE t2.a = t1.a) as sum_b + // FROM t t1; + + // Create base table + let t = test_table_with_columns( + "t", + &[("a", DataType::Int32), ("b", DataType::Int32)], + )?; + + // Build scalar subquery: + // SELECT SUM(b) FROM t t2 WHERE t2.a = t1.a + let scalar_sub = Arc::new( + LogicalPlanBuilder::from(t.clone()) + .alias("t2")? + .filter(col("t2.a").eq(out_ref_col(DataType::Int32, "t1.a")))? + .aggregate( + vec![col("t2.b")], // No GROUP BY + vec![sum(col("t2.b"))], // SUM(b) + )? + .build()?, + ); + + // Build main query + let plan = LogicalPlanBuilder::from(t) + .alias("t1")? + .project(vec![ + col("t1.a"), // a + scalar_subquery(scalar_sub), // (SELECT SUM(b) ...) + ])? 
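        // Note: despite the SQL sketch above, the aggregate built here passes
        // vec![col("t2.b")] as the grouping expressions, so the inner plan groups by
        // t2.b (the snapshot below shows groupBy=[[t2.b]] accordingly).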
+ .build()?; + + // Projection: t1.a, () + // Subquery: + // Aggregate: groupBy=[[t2.b]], aggr=[[sum(t2.b)]] + // Filter: t2.a = outer_ref(t1.a) + // SubqueryAlias: t2 + // TableScan: t + // SubqueryAlias: t1 + // TableScan: t + + // Verify the rewrite result + assert_dependent_join_rewrite!( + plan, + @r#" + Projection: t1.a, __scalar_sq_1.output [a:Int32, output:Int32] + DependentJoin on [t1.a lvl 1] with expr () depth 1 [a:Int32, b:Int32, output:Int32] + SubqueryAlias: t1 [a:Int32, b:Int32] + TableScan: t [a:Int32, b:Int32] + Aggregate: groupBy=[[t2.b]], aggr=[[sum(t2.b)]] [b:Int32, sum(t2.b):Int64;N] + Filter: t2.a = outer_ref(t1.a) [a:Int32, b:Int32] + SubqueryAlias: t2 [a:Int32, b:Int32] + TableScan: t [a:Int32, b:Int32] + "# + ); + + Ok(()) + } + + #[test] + fn test_simple_subquery_in_agg() -> Result<()> { + // CREATE TABLE t(a INT, b INT); + // SELECT a, + // SUM( + // (SELECT b FROM t t2 WHERE t2.a = t1.a) + // ) as sum_scalar + // FROM t t1 + // GROUP BY a; + + // Create base table + let t = test_table_with_columns( + "t", + &[("a", DataType::Int32), ("b", DataType::Int32)], + )?; + + // Build inner scalar subquery: + // SELECT b FROM t t2 WHERE t2.a = t1.a + let scalar_sub = Arc::new( + LogicalPlanBuilder::from(t.clone()) + .alias("t2")? + .filter(col("t2.a").eq(out_ref_col(DataType::Int32, "t1.a")))? + .project(vec![col("t2.b")])? // SELECT b + .build()?, + ); + + // Build main query + let plan = LogicalPlanBuilder::from(t) + .alias("t1")? + .aggregate( + vec![col("t1.a")], // GROUP BY a + vec![sum(scalar_subquery(scalar_sub)) // SUM((SELECT b ...)) + .alias("sum_scalar")], + )? + .build()?; + + // Aggregate: groupBy=[[t1.a]], aggr=[[sum(()) AS sum_scalar]] + // Subquery: + // Projection: t2.b + // Filter: t2.a = outer_ref(t1.a) + // SubqueryAlias: t2 + // TableScan: t + // SubqueryAlias: t1 + // TableScan: t + + // Verify the rewrite result + assert_dependent_join_rewrite!( + plan, + @r#" + Projection: t1.a, sum_scalar [a:Int32, sum_scalar:Int64;N] + Aggregate: groupBy=[[t1.a]], aggr=[[sum(__scalar_sq_1.output) AS sum_scalar]] [a:Int32, sum_scalar:Int64;N] + DependentJoin on [t1.a lvl 1] with expr () depth 1 [a:Int32, b:Int32, output:Int32] + SubqueryAlias: t1 [a:Int32, b:Int32] + TableScan: t [a:Int32, b:Int32] + Projection: t2.b [b:Int32] + Filter: t2.a = outer_ref(t1.a) [a:Int32, b:Int32] + SubqueryAlias: t2 [a:Int32, b:Int32] + TableScan: t [a:Int32, b:Int32] + "# + ); + + Ok(()) + } +>>>>>>> 496703d58 (fix: not expose subquery expr for dependentjoin) } diff --git a/datafusion/optimizer/src/test/mod.rs b/datafusion/optimizer/src/test/mod.rs index 6e0b734bb928..b93fb3d4ff84 100644 --- a/datafusion/optimizer/src/test/mod.rs +++ b/datafusion/optimizer/src/test/mod.rs @@ -45,6 +45,33 @@ pub fn test_table_scan() -> Result { test_table_scan_with_name("test") } +/// Create a table with the given name and column definitions. 
+/// +/// # Arguments +/// * `name` - The name of the table to create +/// * `columns` - Column definitions as slice of tuples (name, data_type) +/// +/// # Example +/// ``` +/// let plan = test_table_with_columns("integers", &[("i", DataType::Int32)])?; +/// ``` +pub fn test_table_with_columns( + name: &str, + columns: &[(&str, DataType)], +) -> Result { + // Create fields with specified types for each column + let fields: Vec = columns + .iter() + .map(|&(col_name, ref data_type)| Field::new(col_name, data_type.clone(), false)) + .collect(); + + // Create schema from fields + let schema = Schema::new(fields); + + // Create table scan + table_scan(Some(name), &schema, None)?.build() +} + /// Scan an empty data source, mainly used in tests pub fn scan_empty( name: Option<&str>, From 29eff4b4a0750bb68f1837a7634ce9333e8ca676 Mon Sep 17 00:00:00 2001 From: irenjj Date: Fri, 6 Jun 2025 12:06:20 +0800 Subject: [PATCH 54/70] spilt into rewrite_dependent_join & decorrelate_dependent_join --- datafusion/optimizer/src/lib.rs | 2 +- ...e_general.rs => rewrite_dependent_join.rs} | 49 ++++++++++++++----- 2 files changed, 39 insertions(+), 12 deletions(-) rename datafusion/optimizer/src/{decorrelate_general.rs => rewrite_dependent_join.rs} (98%) diff --git a/datafusion/optimizer/src/lib.rs b/datafusion/optimizer/src/lib.rs index 0fad43f248a6..8efeb20f5516 100644 --- a/datafusion/optimizer/src/lib.rs +++ b/datafusion/optimizer/src/lib.rs @@ -40,7 +40,6 @@ pub mod analyzer; pub mod common_subexpr_eliminate; pub mod decorrelate; -pub mod decorrelate_general; pub mod decorrelate_lateral_join; pub mod decorrelate_predicate_subquery; pub mod eliminate_cross_join; @@ -60,6 +59,7 @@ pub mod propagate_empty_relation; pub mod push_down_filter; pub mod push_down_limit; pub mod replace_distinct_aggregate; +pub mod rewrite_dependent_join; pub mod scalar_subquery_to_join; pub mod simplify_expressions; pub mod single_distinct_to_groupby; diff --git a/datafusion/optimizer/src/decorrelate_general.rs b/datafusion/optimizer/src/rewrite_dependent_join.rs similarity index 98% rename from datafusion/optimizer/src/decorrelate_general.rs rename to datafusion/optimizer/src/rewrite_dependent_join.rs index 2dbd055829eb..25c8425edba1 100644 --- a/datafusion/optimizer/src/decorrelate_general.rs +++ b/datafusion/optimizer/src/rewrite_dependent_join.rs @@ -29,7 +29,7 @@ use datafusion_common::tree_node::{ }; use datafusion_common::{internal_err, Column, HashMap, Result}; use datafusion_expr::{ - col, lit, Expr, Filter, LogicalPlan, LogicalPlanBuilder, Projection, + col, lit, Aggregate, Expr, Filter, LogicalPlan, LogicalPlanBuilder, Projection, }; use indexmap::map::Entry; @@ -442,7 +442,7 @@ impl DependentJoinRewriter { }); } - fn rewrite_subqueries_into_dependent_joins( + pub fn rewrite_subqueries_into_dependent_joins( &mut self, plan: LogicalPlan, ) -> Result> { @@ -451,7 +451,7 @@ impl DependentJoinRewriter { } impl DependentJoinRewriter { - fn new(alias_generator: Arc) -> Self { + pub fn new(alias_generator: Arc) -> Self { DependentJoinRewriter { alias_generator, current_id: 0, @@ -527,10 +527,12 @@ fn contains_subquery(expr: &Expr) -> bool { /// before its children (subqueries children are visited first). 
/// This behavior allow the fact that, at any moment, if we observe a `LogicalPlan` /// that provides the data for columns, we can assume that all subqueries that reference -/// its data were already visited, and we can conclude the information of the `DependentJoin` +/// its data were already visited, and we can conclude the information of +/// the `DependentJoin` /// needed for the decorrelation: /// - The subquery expr -/// - The correlated columns on the LHS referenced from the RHS (and its recursing subqueries if any) +/// - The correlated columns on the LHS referenced from the RHS +/// (and its recursing subqueries if any) /// /// If in the original node there exists multiple subqueries at the same time /// two nested `DependentJoin` plans are generated (with equal depth). @@ -922,19 +924,30 @@ impl TreeNodeRewriter for DependentJoinRewriter { } } +<<<<<<< HEAD:datafusion/optimizer/src/decorrelate_general.rs #[allow(dead_code)] #[derive(Debug)] struct Decorrelation {} +======= +/// Optimizer rule for rewriting subqueries to dependent join. +#[allow(dead_code)] +#[derive(Debug)] +pub struct RewriteDependentJoin {} + +impl RewriteDependentJoin { + pub fn new() -> Self { + return RewriteDependentJoin {}; + } +} +>>>>>>> 47ace22cf (spilt into rewrite_dependent_join & decorrelate_dependent_join):datafusion/optimizer/src/rewrite_dependent_join.rs -impl OptimizerRule for Decorrelation { +impl OptimizerRule for RewriteDependentJoin { fn supports_rewrite(&self) -> bool { true } - // There will be 2 rewrites going on - // - Convert all subqueries (maybe including lateral join in the future) to temporary - // LogicalPlan node called DependentJoin - // - Decorrelate DependentJoin following top-down approach recursively + // Convert all subqueries (maybe including lateral join in the future) to temporary + // LogicalPlan node called DependentJoin. 
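    // A minimal usage sketch, mirroring how the tests drive this rule (OptimizerContext
    // is assumed to supply the default alias generator consumed below):
    //
    //   let rule = RewriteDependentJoin::new();
    //   let config = OptimizerContext::new();
    //   let rewritten = rule.rewrite(plan, &config)?.data;
    //
    // After this pass each subquery expression has been replaced by a DependentJoin
    // node; the actual decorrelation is handled by a separate rule.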
fn rewrite( &self, plan: LogicalPlan, @@ -944,14 +957,18 @@ impl OptimizerRule for Decorrelation { DependentJoinRewriter::new(Arc::clone(config.alias_generator())); let rewrite_result = transformer.rewrite_subqueries_into_dependent_joins(plan)?; if rewrite_result.transformed { +<<<<<<< HEAD:datafusion/optimizer/src/decorrelate_general.rs // At this point, we have a logical plan with DependentJoin similar to duckdb unimplemented!("implement dependent join decorrelation") +======= + println!("dependent join plan {}", rewrite_result.data); +>>>>>>> 47ace22cf (spilt into rewrite_dependent_join & decorrelate_dependent_join):datafusion/optimizer/src/rewrite_dependent_join.rs } Ok(rewrite_result) } fn name(&self) -> &str { - "decorrelate_subquery" + "rewrite_dependent_join" } fn apply_order(&self) -> Option { @@ -967,6 +984,7 @@ mod tests { use crate::test::test_table_scan_with_name; ======= use crate::test::{test_table_scan_with_name, test_table_with_columns}; +<<<<<<< HEAD:datafusion/optimizer/src/decorrelate_general.rs use crate::{ assert_optimized_plan_eq_display_indent_snapshot, decorrelate_general::Decorrelation, OptimizerConfig, OptimizerContext, @@ -975,6 +993,9 @@ mod tests { >>>>>>> 496703d58 (fix: not expose subquery expr for dependentjoin) use arrow::datatypes::DataType as ArrowDataType; use arrow::datatypes::{DataType, Field}; +======= + use arrow::datatypes::DataType; +>>>>>>> 47ace22cf (spilt into rewrite_dependent_join & decorrelate_dependent_join):datafusion/optimizer/src/rewrite_dependent_join.rs use datafusion_common::{alias::AliasGenerator, Result, Spans}; use datafusion_expr::{ binary_expr, exists, expr::InSubquery, expr_fn::col, in_subquery, lit, @@ -1622,6 +1643,7 @@ mod tests { "); Ok(()) } +<<<<<<< HEAD:datafusion/optimizer/src/decorrelate_general.rs <<<<<<< HEAD ======= #[test] @@ -1673,6 +1695,8 @@ mod tests { Ok(()) } +======= +>>>>>>> 47ace22cf (spilt into rewrite_dependent_join & decorrelate_dependent_join):datafusion/optimizer/src/rewrite_dependent_join.rs // from duckdb test: https://github.com/duckdb/duckdb/blob/main/test/sql/subquery/any_all/test_correlated_any_all.test #[test] @@ -1839,6 +1863,7 @@ mod tests { } #[test] +<<<<<<< HEAD:datafusion/optimizer/src/decorrelate_general.rs fn decorrelate_two_subqueries_at_the_same_level() -> Result<()> { let outer_table = test_table_scan_with_name("outer_table")?; let inner_table_lv1 = test_table_scan_with_name("inner_table_lv1")?; @@ -2011,6 +2036,8 @@ mod tests { Ok(()) } +======= +>>>>>>> 47ace22cf (spilt into rewrite_dependent_join & decorrelate_dependent_join):datafusion/optimizer/src/rewrite_dependent_join.rs fn test_simple_correlated_agg_subquery() -> Result<()> { // CREATE TABLE t(a INT, b INT); // SELECT a, From f4e332e6dd8e40705b74e8648b99deacf5371986 Mon Sep 17 00:00:00 2001 From: Duong Cong Toai Date: Sat, 7 Jun 2025 08:13:45 +0200 Subject: [PATCH 55/70] fix: cherry-pick conflict --- .../optimizer/src/rewrite_dependent_join.rs | 202 +----------------- 1 file changed, 1 insertion(+), 201 deletions(-) diff --git a/datafusion/optimizer/src/rewrite_dependent_join.rs b/datafusion/optimizer/src/rewrite_dependent_join.rs index 25c8425edba1..528ef264ed06 100644 --- a/datafusion/optimizer/src/rewrite_dependent_join.rs +++ b/datafusion/optimizer/src/rewrite_dependent_join.rs @@ -924,11 +924,6 @@ impl TreeNodeRewriter for DependentJoinRewriter { } } -<<<<<<< HEAD:datafusion/optimizer/src/decorrelate_general.rs -#[allow(dead_code)] -#[derive(Debug)] -struct Decorrelation {} -======= /// Optimizer rule for rewriting 
subqueries to dependent join. #[allow(dead_code)] #[derive(Debug)] @@ -939,7 +934,6 @@ impl RewriteDependentJoin { return RewriteDependentJoin {}; } } ->>>>>>> 47ace22cf (spilt into rewrite_dependent_join & decorrelate_dependent_join):datafusion/optimizer/src/rewrite_dependent_join.rs impl OptimizerRule for RewriteDependentJoin { fn supports_rewrite(&self) -> bool { @@ -957,12 +951,7 @@ impl OptimizerRule for RewriteDependentJoin { DependentJoinRewriter::new(Arc::clone(config.alias_generator())); let rewrite_result = transformer.rewrite_subqueries_into_dependent_joins(plan)?; if rewrite_result.transformed { -<<<<<<< HEAD:datafusion/optimizer/src/decorrelate_general.rs - // At this point, we have a logical plan with DependentJoin similar to duckdb - unimplemented!("implement dependent join decorrelation") -======= println!("dependent join plan {}", rewrite_result.data); ->>>>>>> 47ace22cf (spilt into rewrite_dependent_join & decorrelate_dependent_join):datafusion/optimizer/src/rewrite_dependent_join.rs } Ok(rewrite_result) } @@ -980,22 +969,13 @@ impl OptimizerRule for RewriteDependentJoin { mod tests { use super::DependentJoinRewriter; -<<<<<<< HEAD - use crate::test::test_table_scan_with_name; -======= use crate::test::{test_table_scan_with_name, test_table_with_columns}; -<<<<<<< HEAD:datafusion/optimizer/src/decorrelate_general.rs use crate::{ assert_optimized_plan_eq_display_indent_snapshot, decorrelate_general::Decorrelation, OptimizerConfig, OptimizerContext, OptimizerRule, }; ->>>>>>> 496703d58 (fix: not expose subquery expr for dependentjoin) - use arrow::datatypes::DataType as ArrowDataType; use arrow::datatypes::{DataType, Field}; -======= - use arrow::datatypes::DataType; ->>>>>>> 47ace22cf (spilt into rewrite_dependent_join & decorrelate_dependent_join):datafusion/optimizer/src/rewrite_dependent_join.rs use datafusion_common::{alias::AliasGenerator, Result, Spans}; use datafusion_expr::{ binary_expr, exists, expr::InSubquery, expr_fn::col, in_subquery, lit, @@ -1643,9 +1623,7 @@ mod tests { "); Ok(()) } -<<<<<<< HEAD:datafusion/optimizer/src/decorrelate_general.rs -<<<<<<< HEAD -======= + #[test] fn decorrelate_with_in_subquery_has_dependent_column() -> Result<()> { let outer_table = test_table_scan_with_name("outer_table")?; @@ -1695,8 +1673,6 @@ mod tests { Ok(()) } -======= ->>>>>>> 47ace22cf (spilt into rewrite_dependent_join & decorrelate_dependent_join):datafusion/optimizer/src/rewrite_dependent_join.rs // from duckdb test: https://github.com/duckdb/duckdb/blob/main/test/sql/subquery/any_all/test_correlated_any_all.test #[test] @@ -1863,181 +1839,6 @@ mod tests { } #[test] -<<<<<<< HEAD:datafusion/optimizer/src/decorrelate_general.rs - fn decorrelate_two_subqueries_at_the_same_level() -> Result<()> { - let outer_table = test_table_scan_with_name("outer_table")?; - let inner_table_lv1 = test_table_scan_with_name("inner_table_lv1")?; - let in_sq_level1 = Arc::new( - LogicalPlanBuilder::from(inner_table_lv1.clone()) - .filter(col("inner_table_lv1.c").eq(lit(2)))? - .project(vec![col("inner_table_lv1.a")])? - .build()?, - ); - let exist_sq_level1 = Arc::new( - LogicalPlanBuilder::from(inner_table_lv1) - .filter( - col("inner_table_lv1.a").and(col("inner_table_lv1.b").eq(lit(1))), - )? - .build()?, - ); - - let plan = LogicalPlanBuilder::from(outer_table.clone()) - .filter( - col("outer_table.a") - .gt(lit(1)) - .and(exists(exist_sq_level1)) - .and(in_subquery(col("outer_table.b"), in_sq_level1)), - )? 
- .build()?; - assert_decorrelate!(plan, @r" - Projection: outer_table.a, outer_table.b, outer_table.c [a:UInt32, b:UInt32, c:UInt32] - Filter: outer_table.a > Int32(1) AND __exists_sq_1.output AND __in_sq_2.output [a:UInt32, b:UInt32, c:UInt32, __exists_sq_1.output:Boolean, __in_sq_2.output:Boolean] - Projection: outer_table.a, outer_table.b, outer_table.c, __exists_sq_1.output, inner_table_lv1.mark AS __in_sq_2.output [a:UInt32, b:UInt32, c:UInt32, __exists_sq_1.output:Boolean, __in_sq_2.output:Boolean] - LeftMark Join: Filter: outer_table.b = inner_table_lv1.a [a:UInt32, b:UInt32, c:UInt32, __exists_sq_1.output:Boolean, mark:Boolean] - Projection: outer_table.a, outer_table.b, outer_table.c, inner_table_lv1.mark AS __exists_sq_1.output [a:UInt32, b:UInt32, c:UInt32, __exists_sq_1.output:Boolean] - LeftMark Join: Filter: Boolean(true) [a:UInt32, b:UInt32, c:UInt32, mark:Boolean] - TableScan: outer_table [a:UInt32, b:UInt32, c:UInt32] - Cross Join: [a:UInt32, b:UInt32, c:UInt32] - Filter: inner_table_lv1.a AND inner_table_lv1.b = Int32(1) [a:UInt32, b:UInt32, c:UInt32] - TableScan: inner_table_lv1 [a:UInt32, b:UInt32, c:UInt32] - SubqueryAlias: delim_scan_1 [] - DelimGet: [] - Projection: inner_table_lv1.a [a:UInt32] - Cross Join: [a:UInt32, b:UInt32, c:UInt32] - Filter: inner_table_lv1.c = Int32(2) [a:UInt32, b:UInt32, c:UInt32] - TableScan: inner_table_lv1 [a:UInt32, b:UInt32, c:UInt32] - SubqueryAlias: delim_scan_1 [] - DelimGet: [] - "); - Ok(()) - } - #[test] - fn decorrelate_join_in_subquery_with_count_depth_1() -> Result<()> { - let outer_table = test_table_scan_with_name("outer_table")?; - let inner_table_lv1 = test_table_scan_with_name("inner_table_lv1")?; - let sq_level1 = Arc::new( - LogicalPlanBuilder::from(inner_table_lv1) - .filter( - col("inner_table_lv1.a") - .eq(out_ref_col(ArrowDataType::UInt32, "outer_table.a")) - .and( - out_ref_col(ArrowDataType::UInt32, "outer_table.a") - .gt(col("inner_table_lv1.c")), - ) - .and(col("inner_table_lv1.b").eq(lit(1))) - .and( - out_ref_col(ArrowDataType::UInt32, "outer_table.b") - .eq(col("inner_table_lv1.b")), - ), - )? - .aggregate(Vec::::new(), vec![count(col("inner_table_lv1.a"))])? - // TODO: if uncomment this the test fail - // .project(vec![count(col("inner_table_lv1.a")).alias("count_a")])? - .build()?, - ); - - let plan = LogicalPlanBuilder::from(outer_table.clone()) - .filter( - col("outer_table.a") - .gt(lit(1)) - .and(in_subquery(col("outer_table.c"), sq_level1)), - )? 
- .build()?; - assert_decorrelate!(plan, @r" - Projection: outer_table.a, outer_table.b, outer_table.c [a:UInt32, b:UInt32, c:UInt32] - Filter: outer_table.a > Int32(1) AND __in_sq_1.output [a:UInt32, b:UInt32, c:UInt32, __in_sq_1.output:Boolean] - Projection: outer_table.a, outer_table.b, outer_table.c, mark AS __in_sq_1.output [a:UInt32, b:UInt32, c:UInt32, __in_sq_1.output:Boolean] - LeftMark Join: Filter: outer_table.c = CASE WHEN count(inner_table_lv1.a) IS NULL THEN Int32(0) ELSE count(inner_table_lv1.a) END AND outer_table.a IS NOT DISTINCT FROM delim_scan_2.a AND outer_table.b IS NOT DISTINCT FROM delim_scan_2.b [a:UInt32, b:UInt32, c:UInt32, mark:Boolean] - TableScan: outer_table [a:UInt32, b:UInt32, c:UInt32] - Projection: CASE WHEN count(inner_table_lv1.a) IS NULL THEN Int32(0) ELSE count(inner_table_lv1.a) END, delim_scan_2.b, delim_scan_2.a, delim_scan_1.a, delim_scan_1.b [CASE WHEN count(inner_table_lv1.a) IS NULL THEN Int32(0) ELSE count(inner_table_lv1.a) END:Int32, b:UInt32;N, a:UInt32;N, a:UInt32;N, b:UInt32;N] - Inner Join: Filter: outer_table.a IS NOT DISTINCT FROM delim_scan_1.a AND outer_table.b IS NOT DISTINCT FROM delim_scan_1.b [count(inner_table_lv1.a):Int64, b:UInt32;N, a:UInt32;N, a:UInt32;N, b:UInt32;N] - Projection: count(inner_table_lv1.a), delim_scan_2.b, delim_scan_2.a [count(inner_table_lv1.a):Int64, b:UInt32;N, a:UInt32;N] - Aggregate: groupBy=[[delim_scan_2.a, delim_scan_2.b]], aggr=[[count(inner_table_lv1.a)]] [a:UInt32;N, b:UInt32;N, count(inner_table_lv1.a):Int64] - Filter: inner_table_lv1.a = delim_scan_2.a AND delim_scan_2.a > inner_table_lv1.c AND inner_table_lv1.b = Int32(1) AND delim_scan_2.b = inner_table_lv1.b [a:UInt32, b:UInt32, c:UInt32, a:UInt32;N, b:UInt32;N] - Cross Join: [a:UInt32, b:UInt32, c:UInt32, a:UInt32;N, b:UInt32;N] - TableScan: inner_table_lv1 [a:UInt32, b:UInt32, c:UInt32] - SubqueryAlias: delim_scan_2 [a:UInt32;N, b:UInt32;N] - DelimGet: outer_table.a, outer_table.b [a:UInt32;N, b:UInt32;N] - SubqueryAlias: delim_scan_1 [a:UInt32;N, b:UInt32;N] - DelimGet: outer_table.a, outer_table.b [a:UInt32;N, b:UInt32;N] - "); - Ok(()) - } - #[test] - fn decorrelated_two_nested_subqueries() -> Result<()> { - let outer_table = test_table_scan_with_name("outer_table")?; - let inner_table_lv1 = test_table_scan_with_name("inner_table_lv1")?; - - let inner_table_lv2 = test_table_scan_with_name("inner_table_lv2")?; - let scalar_sq_level2 = - Arc::new( - LogicalPlanBuilder::from(inner_table_lv2) - .filter( - col("inner_table_lv2.a") - .eq(out_ref_col(ArrowDataType::UInt32, "outer_table.a")) - .and(col("inner_table_lv2.b").eq(out_ref_col( - ArrowDataType::UInt32, - "inner_table_lv1.b", - ))), - )? - .aggregate(Vec::::new(), vec![count(col("inner_table_lv2.a"))])? - .build()?, - ); - let scalar_sq_level1 = Arc::new( - LogicalPlanBuilder::from(inner_table_lv1.clone()) - .filter( - col("inner_table_lv1.c") - .eq(out_ref_col(ArrowDataType::UInt32, "outer_table.c")) - .and(scalar_subquery(scalar_sq_level2).eq(lit(1))), - )? - .aggregate(Vec::::new(), vec![count(col("inner_table_lv1.a"))])? - .build()?, - ); - - let plan = LogicalPlanBuilder::from(outer_table.clone()) - .filter( - col("outer_table.a") - .gt(lit(1)) - .and(scalar_subquery(scalar_sq_level1).eq(col("outer_table.a"))), - )? 
- .build()?; - assert_decorrelate!(plan, @r" - Projection: outer_table.a, outer_table.b, outer_table.c [a:UInt32, b:UInt32, c:UInt32] - Filter: outer_table.a > Int32(1) AND __scalar_sq_2.output = outer_table.a [a:UInt32, b:UInt32, c:UInt32, CASE WHEN count(inner_table_lv1.a) IS NULL THEN Int32(0) ELSE count(inner_table_lv1.a) END:Int32;N, c:UInt32;N, a:UInt32;N, a:UInt32;N, c:UInt32;N, __scalar_sq_2.output:Int32;N] - Projection: outer_table.a, outer_table.b, outer_table.c, CASE WHEN count(inner_table_lv1.a) IS NULL THEN Int32(0) ELSE count(inner_table_lv1.a) END, delim_scan_2.c, delim_scan_2.a, delim_scan_1.a, delim_scan_1.c, CASE WHEN count(inner_table_lv1.a) IS NULL THEN Int32(0) ELSE count(inner_table_lv1.a) END AS __scalar_sq_2.output [a:UInt32, b:UInt32, c:UInt32, CASE WHEN count(inner_table_lv1.a) IS NULL THEN Int32(0) ELSE count(inner_table_lv1.a) END:Int32;N, c:UInt32;N, a:UInt32;N, a:UInt32;N, c:UInt32;N, __scalar_sq_2.output:Int32;N] - Left Join: Filter: outer_table.a IS NOT DISTINCT FROM delim_scan_2.a AND outer_table.c IS NOT DISTINCT FROM delim_scan_2.c [a:UInt32, b:UInt32, c:UInt32, CASE WHEN count(inner_table_lv1.a) IS NULL THEN Int32(0) ELSE count(inner_table_lv1.a) END:Int32;N, c:UInt32;N, a:UInt32;N, a:UInt32;N, c:UInt32;N] - TableScan: outer_table [a:UInt32, b:UInt32, c:UInt32] - Projection: CASE WHEN count(inner_table_lv1.a) IS NULL THEN Int32(0) ELSE count(inner_table_lv1.a) END, delim_scan_2.c, delim_scan_2.a, delim_scan_1.a, delim_scan_1.c [CASE WHEN count(inner_table_lv1.a) IS NULL THEN Int32(0) ELSE count(inner_table_lv1.a) END:Int32, c:UInt32;N, a:UInt32;N, a:UInt32;N, c:UInt32;N] - Inner Join: Filter: outer_table.a IS NOT DISTINCT FROM delim_scan_1.a AND outer_table.c IS NOT DISTINCT FROM delim_scan_1.c [count(inner_table_lv1.a):Int64, c:UInt32;N, a:UInt32;N, a:UInt32;N, c:UInt32;N] - Projection: count(inner_table_lv1.a), delim_scan_2.c, delim_scan_2.a [count(inner_table_lv1.a):Int64, c:UInt32;N, a:UInt32;N] - Aggregate: groupBy=[[delim_scan_2.a, delim_scan_2.c]], aggr=[[count(inner_table_lv1.a)]] [a:UInt32;N, c:UInt32;N, count(inner_table_lv1.a):Int64] - Projection: inner_table_lv1.a, inner_table_lv1.b, inner_table_lv1.c, delim_scan_2.a, delim_scan_2.c [a:UInt32, b:UInt32, c:UInt32, a:UInt32;N, c:UInt32;N] - Filter: inner_table_lv1.c = delim_scan_2.c AND __scalar_sq_1.output = Int32(1) [a:UInt32, b:UInt32, c:UInt32, a:UInt32;N, c:UInt32;N, CASE WHEN count(inner_table_lv2.a) IS NULL THEN Int32(0) ELSE count(inner_table_lv2.a) END:Int32;N, c:UInt32;N, a:UInt32;N, b:UInt32;N, b:UInt32;N, a:UInt32;N, c:UInt32;N, __scalar_sq_1.output:Int32;N] - Projection: inner_table_lv1.a, inner_table_lv1.b, inner_table_lv1.c, delim_scan_2.a, delim_scan_2.c, CASE WHEN count(inner_table_lv2.a) IS NULL THEN Int32(0) ELSE count(inner_table_lv2.a) END, delim_scan_4.c, delim_scan_4.a, delim_scan_4.b, delim_scan_3.b, delim_scan_3.a, delim_scan_3.c, CASE WHEN count(inner_table_lv2.a) IS NULL THEN Int32(0) ELSE count(inner_table_lv2.a) END AS __scalar_sq_1.output [a:UInt32, b:UInt32, c:UInt32, a:UInt32;N, c:UInt32;N, CASE WHEN count(inner_table_lv2.a) IS NULL THEN Int32(0) ELSE count(inner_table_lv2.a) END:Int32;N, c:UInt32;N, a:UInt32;N, b:UInt32;N, b:UInt32;N, a:UInt32;N, c:UInt32;N, __scalar_sq_1.output:Int32;N] - Left Join: Filter: inner_table_lv1.b IS NOT DISTINCT FROM delim_scan_4.b [a:UInt32, b:UInt32, c:UInt32, a:UInt32;N, c:UInt32;N, CASE WHEN count(inner_table_lv2.a) IS NULL THEN Int32(0) ELSE count(inner_table_lv2.a) END:Int32;N, c:UInt32;N, a:UInt32;N, b:UInt32;N, b:UInt32;N, 
a:UInt32;N, c:UInt32;N] - Cross Join: [a:UInt32, b:UInt32, c:UInt32, a:UInt32;N, c:UInt32;N] - TableScan: inner_table_lv1 [a:UInt32, b:UInt32, c:UInt32] - SubqueryAlias: delim_scan_2 [a:UInt32;N, c:UInt32;N] - DelimGet: outer_table.a, outer_table.c [a:UInt32;N, c:UInt32;N] - Projection: CASE WHEN count(inner_table_lv2.a) IS NULL THEN Int32(0) ELSE count(inner_table_lv2.a) END, delim_scan_4.c, delim_scan_4.a, delim_scan_4.b, delim_scan_3.b, delim_scan_3.a, delim_scan_3.c [CASE WHEN count(inner_table_lv2.a) IS NULL THEN Int32(0) ELSE count(inner_table_lv2.a) END:Int32, c:UInt32;N, a:UInt32;N, b:UInt32;N, b:UInt32;N, a:UInt32;N, c:UInt32;N] - Inner Join: Filter: inner_table_lv1.b IS NOT DISTINCT FROM delim_scan_3.b AND outer_table.a IS NOT DISTINCT FROM delim_scan_3.a AND outer_table.c IS NOT DISTINCT FROM delim_scan_3.c [count(inner_table_lv2.a):Int64, c:UInt32;N, a:UInt32;N, b:UInt32;N, b:UInt32;N, a:UInt32;N, c:UInt32;N] - Projection: count(inner_table_lv2.a), delim_scan_4.c, delim_scan_4.a, delim_scan_4.b [count(inner_table_lv2.a):Int64, c:UInt32;N, a:UInt32;N, b:UInt32;N] - Aggregate: groupBy=[[delim_scan_4.b, delim_scan_4.a, delim_scan_4.c]], aggr=[[count(inner_table_lv2.a)]] [b:UInt32;N, a:UInt32;N, c:UInt32;N, count(inner_table_lv2.a):Int64] - Filter: inner_table_lv2.a = delim_scan_4.a AND inner_table_lv2.b = delim_scan_4.b [a:UInt32, b:UInt32, c:UInt32, b:UInt32;N, a:UInt32;N, c:UInt32;N] - Cross Join: [a:UInt32, b:UInt32, c:UInt32, b:UInt32;N, a:UInt32;N, c:UInt32;N] - TableScan: inner_table_lv2 [a:UInt32, b:UInt32, c:UInt32] - SubqueryAlias: delim_scan_4 [b:UInt32;N, a:UInt32;N, c:UInt32;N] - DelimGet: inner_table_lv1.b, outer_table.a, outer_table.c [b:UInt32;N, a:UInt32;N, c:UInt32;N] - SubqueryAlias: delim_scan_3 [b:UInt32;N, a:UInt32;N, c:UInt32;N] - DelimGet: inner_table_lv1.b, outer_table.a, outer_table.c [b:UInt32;N, a:UInt32;N, c:UInt32;N] - SubqueryAlias: delim_scan_1 [a:UInt32;N, c:UInt32;N] - DelimGet: outer_table.a, outer_table.c [a:UInt32;N, c:UInt32;N] - "); - Ok(()) - } - -======= ->>>>>>> 47ace22cf (spilt into rewrite_dependent_join & decorrelate_dependent_join):datafusion/optimizer/src/rewrite_dependent_join.rs fn test_simple_correlated_agg_subquery() -> Result<()> { // CREATE TABLE t(a INT, b INT); // SELECT a, @@ -2164,5 +1965,4 @@ mod tests { Ok(()) } ->>>>>>> 496703d58 (fix: not expose subquery expr for dependentjoin) } From 2a324bdbee75c17af2f1b0ee4e3ca39fd8d3660c Mon Sep 17 00:00:00 2001 From: Duong Cong Toai Date: Sat, 7 Jun 2025 08:21:41 +0200 Subject: [PATCH 56/70] chore: move left over commit from feature branch --- datafusion/expr/src/logical_plan/tree_node.rs | 2 +- .../optimizer/src/rewrite_dependent_join.rs | 81 +------------------ 2 files changed, 5 insertions(+), 78 deletions(-) diff --git a/datafusion/expr/src/logical_plan/tree_node.rs b/datafusion/expr/src/logical_plan/tree_node.rs index 7f9a40a49c06..11d775953d8b 100644 --- a/datafusion/expr/src/logical_plan/tree_node.rs +++ b/datafusion/expr/src/logical_plan/tree_node.rs @@ -428,7 +428,7 @@ impl LogicalPlan { }) => { let correlated_column_exprs = correlated_columns .iter() - .map(|(_, c, _)| c.clone()) + .map(|(_, c, _)| Expr::Column(c.clone())) .collect::>(); let maybe_lateral_join_condition = match lateral_join_condition { Some((_, condition)) => Some(condition.clone()), diff --git a/datafusion/optimizer/src/rewrite_dependent_join.rs b/datafusion/optimizer/src/rewrite_dependent_join.rs index 528ef264ed06..c95ee2d45c2a 100644 --- a/datafusion/optimizer/src/rewrite_dependent_join.rs +++ 
b/datafusion/optimizer/src/rewrite_dependent_join.rs @@ -130,12 +130,7 @@ impl DependentJoinRewriter { let correlated_columns = column_accesses .iter() - .map(|ac| { - ( - ac.subquery_depth, - Expr::OuterReferenceColumn(ac.data_type.clone(), ac.col.clone()), - ) - }) + .map(|ac| (ac.subquery_depth, ac.col.clone(), ac.data_type.clone())) .unique() .collect(); @@ -220,12 +215,7 @@ impl DependentJoinRewriter { let correlated_columns = column_accesses .iter() - .map(|ac| { - ( - ac.subquery_depth, - Expr::OuterReferenceColumn(ac.data_type.clone(), ac.col.clone()), - ) - }) + .map(|ac| (ac.subquery_depth, ac.col.clone(), ac.data_type.clone())) .unique() .collect(); @@ -867,15 +857,7 @@ impl TreeNodeRewriter for DependentJoinRewriter { let alias = subquery_alias_by_offset.get(&0).unwrap(); let correlated_columns = column_accesses .iter() - .map(|ac| { - ( - ac.subquery_depth, - Expr::OuterReferenceColumn( - ac.data_type.clone(), - ac.col.clone(), - ), - ) - }) + .map(|ac| (ac.subquery_depth, ac.col.clone(), ac.data_type.clone())) .unique() .collect(); @@ -970,12 +952,7 @@ mod tests { use super::DependentJoinRewriter; use crate::test::{test_table_scan_with_name, test_table_with_columns}; - use crate::{ - assert_optimized_plan_eq_display_indent_snapshot, - decorrelate_general::Decorrelation, OptimizerConfig, OptimizerContext, - OptimizerRule, - }; - use arrow::datatypes::{DataType, Field}; + use arrow::datatypes::DataType; use datafusion_common::{alias::AliasGenerator, Result, Spans}; use datafusion_expr::{ binary_expr, exists, expr::InSubquery, expr_fn::col, in_subquery, lit, @@ -1624,56 +1601,6 @@ mod tests { Ok(()) } - #[test] - fn decorrelate_with_in_subquery_has_dependent_column() -> Result<()> { - let outer_table = test_table_scan_with_name("outer_table")?; - let inner_table_lv1 = test_table_scan_with_name("inner_table_lv1")?; - let sq_level1 = Arc::new( - LogicalPlanBuilder::from(inner_table_lv1) - .filter( - col("inner_table_lv1.a") - .eq(out_ref_col(ArrowDataType::UInt32, "outer_table.a")) - .and( - out_ref_col(ArrowDataType::UInt32, "outer_table.a") - .gt(col("inner_table_lv1.c")), - ) - .and(col("inner_table_lv1.b").eq(lit(1))) - .and( - out_ref_col(ArrowDataType::UInt32, "outer_table.b") - .eq(col("inner_table_lv1.b")), - ), - )? - .project(vec![out_ref_col(ArrowDataType::UInt32, "outer_table.b")])? - .build()?, - ); - - let plan = LogicalPlanBuilder::from(outer_table.clone()) - .filter( - col("outer_table.a") - .gt(lit(1)) - .and(in_subquery(col("outer_table.c"), sq_level1)), - )? 
- .build()?; - let dec = Decorrelation::new(); - let ctx: Box = Box::new(OptimizerContext::new()); - let plan = dec.rewrite(plan, ctx.as_ref())?.data; - assert_decorrelate!(plan, @r" - Projection: outer_table.a, outer_table.b, outer_table.c [a:UInt32, b:UInt32, c:UInt32] - Filter: outer_table.a > Int32(1) AND __in_sq_1.output [a:UInt32, b:UInt32, c:UInt32, __in_sq_1.output:Boolean] - Projection: outer_table.a, outer_table.b, outer_table.c, delim_scan_1.mark AS __in_sq_1.output [a:UInt32, b:UInt32, c:UInt32, __in_sq_1.output:Boolean] - LeftMark Join: Filter: outer_table.c = outer_ref(outer_table.b) AND outer_table.a IS NOT DISTINCT FROM delim_scan_1.a AND outer_table.b IS NOT DISTINCT FROM delim_scan_1.b [a:UInt32, b:UInt32, c:UInt32, mark:Boolean] - TableScan: outer_table [a:UInt32, b:UInt32, c:UInt32] - Projection: delim_scan_1.b, delim_scan_1.a, delim_scan_1.b [outer_ref(outer_table.b):UInt32;N, a:UInt32;N, b:UInt32;N] - Filter: inner_table_lv1.a = delim_scan_1.a AND delim_scan_1.a > inner_table_lv1.c AND inner_table_lv1.b = Int32(1) AND delim_scan_1.b = inner_table_lv1.b [a:UInt32, b:UInt32, c:UInt32, a:UInt32;N, b:UInt32;N] - Cross Join: [a:UInt32, b:UInt32, c:UInt32, a:UInt32;N, b:UInt32;N] - TableScan: inner_table_lv1 [a:UInt32, b:UInt32, c:UInt32] - SubqueryAlias: delim_scan_1 [a:UInt32;N, b:UInt32;N] - DelimGet: outer_table.a, outer_table.b [a:UInt32;N, b:UInt32;N] - "); - - Ok(()) - } - // from duckdb test: https://github.com/duckdb/duckdb/blob/main/test/sql/subquery/any_all/test_correlated_any_all.test #[test] fn test_correlated_any_all_1() -> Result<()> { From f0c9f0b66c5ff33000a0d5b73acc07612f6038cb Mon Sep 17 00:00:00 2001 From: Duong Cong Toai Date: Sat, 7 Jun 2025 08:22:59 +0200 Subject: [PATCH 57/70] chore: minor import format --- datafusion/expr/src/logical_plan/builder.rs | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/datafusion/expr/src/logical_plan/builder.rs b/datafusion/expr/src/logical_plan/builder.rs index a7c36a0f87b0..14f9fb122079 100644 --- a/datafusion/expr/src/logical_plan/builder.rs +++ b/datafusion/expr/src/logical_plan/builder.rs @@ -31,10 +31,10 @@ use crate::expr_rewriter::{ rewrite_sort_cols_by_aggs, }; use crate::logical_plan::{ - Aggregate, Analyze, Distinct, DistinctOn, EmptyRelation, Explain, Filter, Join, - JoinConstraint, JoinType, Limit, LogicalPlan, Partitioning, PlanType, Prepare, - Projection, Repartition, Sort, SubqueryAlias, TableScan, Union, Unnest, Values, - Window, + Aggregate, Analyze, DependentJoin, Distinct, DistinctOn, EmptyRelation, Explain, + Filter, Join, JoinConstraint, JoinType, Limit, LogicalPlan, Partitioning, PlanType, + Prepare, Projection, Repartition, Sort, SubqueryAlias, TableScan, Union, Unnest, + Values, Window, }; use crate::select_expr::SelectExpr; use crate::utils::{ @@ -49,7 +49,6 @@ use crate::{ use super::dml::InsertOp; use super::plan::{ColumnUnnestList, ExplainFormat}; -use super::DependentJoin; use arrow::compute::can_cast_types; use arrow::datatypes::{DataType, Field, Fields, Schema, SchemaRef}; use datafusion_common::display::ToStringifiedPlan; From e964d6ec3d8c5d4377dccab47293493629156a88 Mon Sep 17 00:00:00 2001 From: Duong Cong Toai Date: Sat, 7 Jun 2025 08:56:11 +0200 Subject: [PATCH 58/70] chore: clippy --- datafusion/expr/src/logical_plan/tree_node.rs | 8 ++++---- datafusion/optimizer/src/rewrite_dependent_join.rs | 8 +++++++- 2 files changed, 11 insertions(+), 5 deletions(-) diff --git a/datafusion/expr/src/logical_plan/tree_node.rs 
b/datafusion/expr/src/logical_plan/tree_node.rs index 11d775953d8b..aa2f4cc7646e 100644 --- a/datafusion/expr/src/logical_plan/tree_node.rs +++ b/datafusion/expr/src/logical_plan/tree_node.rs @@ -430,10 +430,10 @@ impl LogicalPlan { .iter() .map(|(_, c, _)| Expr::Column(c.clone())) .collect::>(); - let maybe_lateral_join_condition = match lateral_join_condition { - Some((_, condition)) => Some(condition.clone()), - None => None, - }; + let maybe_lateral_join_condition = lateral_join_condition + .as_ref() + .map(|(_, condition)| condition.clone()); + (&correlated_column_exprs, &maybe_lateral_join_condition) .apply_ref_elements(f) } diff --git a/datafusion/optimizer/src/rewrite_dependent_join.rs b/datafusion/optimizer/src/rewrite_dependent_join.rs index c95ee2d45c2a..3285589ed48e 100644 --- a/datafusion/optimizer/src/rewrite_dependent_join.rs +++ b/datafusion/optimizer/src/rewrite_dependent_join.rs @@ -911,9 +911,15 @@ impl TreeNodeRewriter for DependentJoinRewriter { #[derive(Debug)] pub struct RewriteDependentJoin {} +impl Default for RewriteDependentJoin { + fn default() -> Self { + Self::new() + } +} + impl RewriteDependentJoin { pub fn new() -> Self { - return RewriteDependentJoin {}; + RewriteDependentJoin {} } } From 2eb723eff59816d228317e651561961b64549147 Mon Sep 17 00:00:00 2001 From: Duong Cong Toai Date: Sat, 7 Jun 2025 09:16:11 +0200 Subject: [PATCH 59/70] fix: err msg --- datafusion/sqllogictest/test_files/subquery.slt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/datafusion/sqllogictest/test_files/subquery.slt b/datafusion/sqllogictest/test_files/subquery.slt index 796570633f67..df82ba1591d0 100644 --- a/datafusion/sqllogictest/test_files/subquery.slt +++ b/datafusion/sqllogictest/test_files/subquery.slt @@ -439,7 +439,7 @@ SELECT t1_id, t1_name, t1_int, (select t2_id, t2_name FROM t2 WHERE t2.t2_id = t #subquery_not_allowed #In/Exist Subquery is not allowed in ORDER BY clause. 
-statement error DataFusion error: Invalid \(non-executable\) plan after Analyzer\ncaused by\nError during planning: In/Exist subquery can only be used in Projection, Filter, TableScan, Window functions, Aggregate and Join plan nodes, but was used in \[Sort: t1.t1_int IN \(\) ASC NULLS LAST\] +statement error DataFusion error: Invalid \(non-executable\) plan after Analyzer\ncaused by\nError during planning: In/Exist subquery can only be used in Projection, Filter, TableScan, Window functions, Aggregate, Join and Dependent Join plan nodes, but was used in \[Sort: t1.t1_int IN \(\) ASC NULLS LAST\] SELECT t1_id, t1_name, t1_int FROM t1 order by t1_int in (SELECT t2_int FROM t2 WHERE t1.t1_id > t1.t1_int) #non_aggregated_correlated_scalar_subquery From b8a8de80d501fa7ae520c0a93c2de2c215306feb Mon Sep 17 00:00:00 2001 From: Duong Cong Toai Date: Sat, 7 Jun 2025 16:25:14 +0200 Subject: [PATCH 60/70] test: some more test cases --- .../optimizer/src/rewrite_dependent_join.rs | 115 +++++++++++++++--- 1 file changed, 101 insertions(+), 14 deletions(-) diff --git a/datafusion/optimizer/src/rewrite_dependent_join.rs b/datafusion/optimizer/src/rewrite_dependent_join.rs index 3285589ed48e..e65cc6405d41 100644 --- a/datafusion/optimizer/src/rewrite_dependent_join.rs +++ b/datafusion/optimizer/src/rewrite_dependent_join.rs @@ -958,7 +958,7 @@ mod tests { use super::DependentJoinRewriter; use crate::test::{test_table_scan_with_name, test_table_with_columns}; - use arrow::datatypes::DataType; + use arrow::datatypes::{DataType, TimeUnit}; use datafusion_common::{alias::AliasGenerator, Result, Spans}; use datafusion_expr::{ binary_expr, exists, expr::InSubquery, expr_fn::col, in_subquery, lit, @@ -984,9 +984,52 @@ mod tests { ) }}; } + #[test] + fn lateral_join() -> Result<()> { + let outer_table = test_table_scan_with_name("outer_table")?; + let inner_table_lv1 = test_table_scan_with_name("inner_table_lv1")?; + + let lateral_join_rhs = Arc::new( + LogicalPlanBuilder::from(inner_table_lv1.clone()) + .filter( + col("inner_table_lv1.c") + .eq(out_ref_col(DataType::UInt32, "outer_table.c")), + )? + .build()?, + ); + + let plan = LogicalPlanBuilder::from(outer_table.clone()) + .join_on( + LogicalPlan::Subquery(Subquery { + subquery: lateral_join_rhs, + outer_ref_columns: vec![out_ref_col( + DataType::UInt32, + "outer_table.c", + )], + spans: Spans::new(), + }), + JoinType::Inner, + vec![lit(true)], + )? 
+ .build()?; + + // Inner Join: Filter: Boolean(true) + // TableScan: outer_table + // Subquery: + // Filter: inner_table_lv1.c = outer_ref(outer_table.c) + // TableScan: inner_table_lv1 + + assert_dependent_join_rewrite!(plan, @r" + DependentJoin on [outer_table.c lvl 1] lateral Inner join with Boolean(true) depth 1 [a:UInt32, b:UInt32, c:UInt32] + TableScan: outer_table [a:UInt32, b:UInt32, c:UInt32] + Filter: inner_table_lv1.c = outer_ref(outer_table.c) [a:UInt32, b:UInt32, c:UInt32] + TableScan: inner_table_lv1 [a:UInt32, b:UInt32, c:UInt32] + "); + Ok(()) + } #[test] - fn rewrite_dependent_join_with_nested_lateral_join() -> Result<()> { + fn scalar_subquery_nested_inside_a_lateral_join() -> Result<()> { let outer_table = test_table_scan_with_name("outer_table")?; let inner_table_lv1 = test_table_scan_with_name("inner_table_lv1")?; @@ -1061,7 +1104,7 @@ mod tests { } #[test] - fn rewrite_dependent_join_with_lhs_as_a_join() -> Result<()> { + fn join_logical_plan_with_subquery_in_filter_expr() -> Result<()> { let outer_left_table = test_table_scan_with_name("outer_right_table")?; let outer_right_table = test_table_scan_with_name("outer_left_table")?; let inner_table_lv1 = test_table_scan_with_name("inner_table_lv1")?; @@ -1115,11 +1158,11 @@ mod tests { Ok(()) } #[test] - fn rewrite_dependent_join_in_from_expr() -> Result<()> { + fn subquery_in_from_expr() -> Result<()> { Ok(()) } #[test] - fn rewrite_dependent_join_inside_project_exprs() -> Result<()> { + fn nested_subquery_in_projection_expr() -> Result<()> { let outer_table = test_table_scan_with_name("outer_table")?; let inner_table_lv1 = test_table_scan_with_name("inner_table_lv1")?; @@ -1217,7 +1260,7 @@ mod tests { } #[test] - fn rewrite_dependent_join_two_nested_subqueries() -> Result<()> { + fn nested_subquery_in_filter() -> Result<()> { let outer_table = test_table_scan_with_name("outer_table")?; let inner_table_lv1 = test_table_scan_with_name("inner_table_lv1")?; @@ -1283,7 +1326,7 @@ mod tests { Ok(()) } #[test] - fn rewrite_dependent_join_two_subqueries_at_the_same_level() -> Result<()> { + fn two_subqueries_in_the_same_filter_expr() -> Result<()> { let outer_table = test_table_scan_with_name("outer_table")?; let inner_table_lv1 = test_table_scan_with_name("inner_table_lv1")?; let in_sq_level1 = Arc::new( @@ -1335,7 +1378,7 @@ mod tests { } #[test] - fn rewrite_dependent_join_in_subquery_with_count_depth_1() -> Result<()> { + fn in_subquery_with_count_of_1_depth() -> Result<()> { let outer_table = test_table_scan_with_name("outer_table")?; let inner_table_lv1 = test_table_scan_with_name("inner_table_lv1")?; let sq_level1 = Arc::new( @@ -1387,7 +1430,7 @@ mod tests { Ok(()) } #[test] - fn rewrite_dependent_join_exist_subquery_with_dependent_columns() -> Result<()> { + fn correlated_exist_subquery() -> Result<()> { let outer_table = test_table_scan_with_name("outer_table")?; let inner_table_lv1 = test_table_scan_with_name("inner_table_lv1")?; let sq_level1 = Arc::new( @@ -1436,8 +1479,7 @@ mod tests { } #[test] - fn rewrite_dependent_join_with_exist_subquery_with_no_dependent_columns() -> Result<()> - { + fn uncorrelated_exist_subquery() -> Result<()> { let outer_table = test_table_scan_with_name("outer_table")?; let inner_table_lv1 = test_table_scan_with_name("inner_table_lv1")?; let sq_level1 = Arc::new( @@ -1471,7 +1513,7 @@ mod tests { Ok(()) } #[test] - fn rewrite_dependent_join_with_in_subquery_no_dependent_column() -> Result<()> { + fn uncorrelated_in_subquery() -> Result<()> { let outer_table = 
test_table_scan_with_name("outer_table")?; let inner_table_lv1 = test_table_scan_with_name("inner_table_lv1")?; let sq_level1 = Arc::new( @@ -1509,7 +1551,7 @@ mod tests { Ok(()) } #[test] - fn rewrite_dependent_join_with_in_subquery_has_dependent_column() -> Result<()> { + fn correlated_in_subquery() -> Result<()> { let outer_table = test_table_scan_with_name("outer_table")?; let inner_table_lv1 = test_table_scan_with_name("inner_table_lv1")?; let sq_level1 = Arc::new( @@ -1561,7 +1603,7 @@ mod tests { } #[test] - fn rewrite_dependent_join_reference_outer_column_with_alias_name() -> Result<()> { + fn correlated_subquery_with_alias() -> Result<()> { let outer_table = test_table_scan_with_name("outer_table")?; let inner_table_lv1 = test_table_scan_with_name("inner_table_lv1")?; let sq_level1 = Arc::new( @@ -1898,4 +1940,49 @@ mod tests { Ok(()) } + + #[test] + // https://github.com/duckdb/duckdb/blob/4d7cb701cabd646d8232a9933dd058a089ea7348/test/sql/subquery/any_all/subquery_in.test + fn correlated_scalar_subquery_returning_more_than_1_row() -> Result<()> { + // SELECT (FALSE) IN (TRUE, (SELECT TIME '13:35:07' FROM t1) BETWEEN t0.c0 AND t0.c0) FROM t0; + let t0 = test_table_with_columns( + "t0", + &[ + ("c0", DataType::Time64(TimeUnit::Second)), + ("c1", DataType::Float64), + ], + )?; + let t1 = test_table_with_columns("t1", &[("c0", DataType::Int32)])?; + let t1_subquery = Arc::new( + LogicalPlanBuilder::from(t1) + .project(vec![lit("13:35:07")])? + .build()?, + ); + let plan = LogicalPlanBuilder::from(t0) + .project(vec![lit(false).in_list( + vec![ + lit(true), + scalar_subquery(t1_subquery).between(col("t0.c0"), col("t0.c0")), + ], + false, + )])? + .build()?; + // Projection: Boolean(false) IN ([Boolean(true), () BETWEEN t0.c0 AND t0.c0]) + // Subquery: + // Projection: Utf8("13:35:07") + // TableScan: t1 + // TableScan: t0 + assert_dependent_join_rewrite!( + plan, + @r#" + Projection: Boolean(false) IN ([Boolean(true), __scalar_sq_1.output BETWEEN t0.c0 AND t0.c0]) [Boolean(false) IN Boolean(true), __scalar_sq_1.output BETWEEN t0.c0 AND t0.c0:Boolean] + DependentJoin on [] with expr () depth 1 [c0:Time64(Second), c1:Float64, output:Utf8] + TableScan: t0 [c0:Time64(Second), c1:Float64] + Projection: Utf8("13:35:07") [Utf8("13:35:07"):Utf8] + TableScan: t1 [c0:Int32] + "# + ); + + Ok(()) + } } From a3d0b650cd0c2756cd84c3a4a9be520ad48ee313 Mon Sep 17 00:00:00 2001 From: Duong Cong Toai Date: Sat, 7 Jun 2025 17:43:15 +0200 Subject: [PATCH 61/70] refactor: shared rewrite function --- .../optimizer/src/rewrite_dependent_join.rs | 336 +++++++----------- 1 file changed, 127 insertions(+), 209 deletions(-) diff --git a/datafusion/optimizer/src/rewrite_dependent_join.rs b/datafusion/optimizer/src/rewrite_dependent_join.rs index e65cc6405d41..4214817a9380 100644 --- a/datafusion/optimizer/src/rewrite_dependent_join.rs +++ b/datafusion/optimizer/src/rewrite_dependent_join.rs @@ -27,7 +27,7 @@ use datafusion_common::alias::AliasGenerator; use datafusion_common::tree_node::{ Transformed, TreeNode, TreeNodeRecursion, TreeNodeRewriter, }; -use datafusion_common::{internal_err, Column, HashMap, Result}; +use datafusion_common::{internal_datafusion_err, internal_err, Column, HashMap, Result}; use datafusion_expr::{ col, lit, Aggregate, Expr, Filter, LogicalPlan, LogicalPlanBuilder, Projection, }; @@ -62,14 +62,15 @@ struct ColumnAccess { } impl DependentJoinRewriter { - fn rewrite_filter( - &mut self, - filter: &Filter, + // this function is to rewrite logical plan having arbitrary exprs that contain + 
// subquery expr into dependent join logical plan + fn rewrite_exprs_into_dependent_join_plan( + exprs: Vec>, dependent_join_node: &Node, current_subquery_depth: usize, mut current_plan: LogicalPlanBuilder, subquery_alias_by_offset: HashMap, - ) -> Result { + ) -> Result<(LogicalPlanBuilder, Vec>)> { // everytime we meet a subquery during traversal, we increment this by 1 // we can use this offset to lookup the original subquery info // in subquery_alias_by_offset @@ -79,52 +80,55 @@ impl DependentJoinRewriter { let mut offset = 0; let offset_ref = &mut offset; let mut subquery_expr_by_offset = HashMap::new(); - let new_predicate = filter - .predicate - .clone() - .transform(|e| { - // replace any subquery expr with subquery_alias.output - // column - let alias = match e { - Expr::InSubquery(_) | Expr::Exists(_) | Expr::ScalarSubquery(_) => { - subquery_alias_by_offset.get(offset_ref).unwrap() - } - _ => return Ok(Transformed::no(e)), - }; - // we are aware that the original subquery can be rewritten - // update the latest expr to this map - subquery_expr_by_offset.insert(*offset_ref, e); - *offset_ref += 1; - - // TODO: this assume that after decorrelation - // the dependent join will provide an extra column with the structure - // of "subquery_alias.output" - // On later step of decorrelation, it rely on this structure - // to again rename the expression after join - // for example if the real join type is LeftMark, the correct output - // column should be "mark" instead, else after the join - // one extra layer of projection is needed to alias "mark" into - // "alias.output" - Ok(Transformed::yes(col(format!("{alias}.output")))) - })? - .data; - // because dependent join may introduce extra columns - // to evaluate the subquery, the final plan should - // has another projection to remove these redundant columns - let post_join_projections: Vec = filter - .input - .schema() - .columns() - .iter() - .map(|c| col(c.clone())) - .collect(); + let mut rewritten_exprs_groups = vec![]; + for expr_group in exprs { + let rewritten_exprs = expr_group + .iter() + .cloned() + .map(|e| { + Ok(e.clone() + .transform(|e| { + // replace any subquery expr with subquery_alias.output column + let alias = match e { + Expr::InSubquery(_) + | Expr::Exists(_) + | Expr::ScalarSubquery(_) => subquery_alias_by_offset + .get(offset_ref) + .ok_or(internal_datafusion_err!( + "subquery alias not found at offset {}", + *offset_ref + )), + _ => return Ok(Transformed::no(e)), + }?; + + // We are aware that the original subquery can be rewritten update the + // latest expr to this map. + subquery_expr_by_offset.insert(*offset_ref, e); + *offset_ref += 1; + + Ok(Transformed::yes(col(format!("{alias}.output")))) + })? 
+ .data) + }) + .collect::>>()?; + rewritten_exprs_groups.push(rewritten_exprs); + } + for (subquery_offset, (_, column_accesses)) in dependent_join_node .columns_accesses_by_subquery_id .iter() .enumerate() { - let alias = subquery_alias_by_offset.get(&subquery_offset).unwrap(); - let subquery_expr = subquery_expr_by_offset.get(&subquery_offset).unwrap(); + let alias = subquery_alias_by_offset.get(&subquery_offset).ok_or( + internal_datafusion_err!( + "subquery alias not found at offset {subquery_offset}" + ), + )?; + let subquery_expr = subquery_expr_by_offset.get(&subquery_offset).ok_or( + internal_datafusion_err!( + "subquery expr not found at offset {subquery_offset}" + ), + )?; let subquery_input = unwrap_subquery_input_from_expr(subquery_expr); @@ -140,96 +144,75 @@ impl DependentJoinRewriter { Some(subquery_expr.clone()), current_subquery_depth, alias.clone(), - None, // TODO: handle this when we support lateral join rewrite + None, )?; } - current_plan - .filter(new_predicate.clone())? - .project(post_join_projections) + Ok((current_plan, rewritten_exprs_groups)) } - fn rewrite_projection( + fn rewrite_filter( &mut self, - original_proj: &Projection, + filter: &Filter, dependent_join_node: &Node, current_subquery_depth: usize, - mut current_plan: LogicalPlanBuilder, + current_plan: LogicalPlanBuilder, subquery_alias_by_offset: HashMap, ) -> Result { - // everytime we meet a subquery during traversal, we increment this by 1 - // we can use this offset to lookup the original subquery info - // in subquery_alias_by_offset - // the reason why we cannot create a hashmap keyed by Subquery object HashMap - // is that the subquery inside this filter expr may have been rewritten in - // the lower level - let mut offset = 0; - let offset_ref = &mut offset; - let mut subquery_expr_by_offset = HashMap::new(); - // for each projected expr, we convert the SubqueryExpr into a ColExpr - // with structure "{subquery_alias}.output" - let new_projections = original_proj - .expr - .iter() - .cloned() - .map(|e| { - Ok(e.transform(|e| { - // replace any subquery expr with subquery_alias.output - // column - let alias = match e { - Expr::InSubquery(_) - | Expr::Exists(_) - | Expr::ScalarSubquery(_) => { - subquery_alias_by_offset.get(offset_ref).unwrap() - } - _ => return Ok(Transformed::no(e)), - }; - // we are aware that the original subquery can be rewritten - // update the latest expr to this map - subquery_expr_by_offset.insert(*offset_ref, e); - *offset_ref += 1; - - // TODO: this assume that after decorrelation - // the dependent join will provide an extra column with the structure - // of "subquery_alias.output" - // On later step of decorrelation, it rely on this structure - // to again rename the expression after join - // for example if the real join type is LeftMark, the correct output - // column should be "mark" instead, else after the join - // one extra layer of projection is needed to alias "mark" into - // "alias.output" - Ok(Transformed::yes(col(format!("{alias}.output")))) - })? 
- .data) - }) - .collect::>>()?; - - for (subquery_offset, (_, column_accesses)) in dependent_join_node - .columns_accesses_by_subquery_id + // because dependent join may introduce extra columns + // to evaluate the subquery, the final plan should + // has another projection to remove these redundant columns + let post_join_projections: Vec = filter + .input + .schema() + .columns() .iter() - .enumerate() - { - let alias = subquery_alias_by_offset.get(&subquery_offset).unwrap(); - let subquery_expr = subquery_expr_by_offset.get(&subquery_offset).unwrap(); - - let subquery_input = unwrap_subquery_input_from_expr(subquery_expr); + .map(|c| col(c.clone())) + .collect(); + let (transformed_plan, transformed_exprs) = + Self::rewrite_exprs_into_dependent_join_plan( + vec![vec![&filter.predicate]], + dependent_join_node, + current_subquery_depth, + current_plan, + subquery_alias_by_offset, + )?; - let correlated_columns = column_accesses - .iter() - .map(|ac| (ac.subquery_depth, ac.col.clone(), ac.data_type.clone())) - .unique() - .collect(); + let transformed_predicate = transformed_exprs + .first() + .ok_or(internal_datafusion_err!( + "transform predicate does not return 1 element" + ))? + .first() + .ok_or(internal_datafusion_err!( + "transform predicate does not return 1 element" + ))?; + + transformed_plan + .filter(transformed_predicate.clone())? + .project(post_join_projections) + } - current_plan = current_plan.dependent_join( - subquery_input.deref().clone(), - correlated_columns, - Some(subquery_expr.clone()), + fn rewrite_projection( + &mut self, + original_proj: &Projection, + dependent_join_node: &Node, + current_subquery_depth: usize, + current_plan: LogicalPlanBuilder, + subquery_alias_by_offset: HashMap, + ) -> Result { + let (transformed_plan, transformed_exprs) = + Self::rewrite_exprs_into_dependent_join_plan( + vec![original_proj.expr.iter().collect::>()], + dependent_join_node, current_subquery_depth, - alias.clone(), - None, // TODO: handle this when we support lateral join rewrite + current_plan, + subquery_alias_by_offset, )?; - } - current_plan = current_plan.project(new_projections)?; - Ok(current_plan) + let transformed_proj_exprs = + transformed_exprs.first().ok_or(internal_datafusion_err!( + "transform projection expr does not return 1 element" + ))?; + transformed_plan.project(transformed_proj_exprs.clone()) } fn rewrite_aggregate( @@ -237,93 +220,9 @@ impl DependentJoinRewriter { aggregate: &Aggregate, dependent_join_node: &Node, current_subquery_depth: usize, - mut current_plan: LogicalPlanBuilder, + current_plan: LogicalPlanBuilder, subquery_alias_by_offset: HashMap, ) -> Result { - let mut offset = 0; - let offset_ref = &mut offset; - let mut subquery_expr_by_offset = HashMap::new(); - let new_group_expr = aggregate - .group_expr - .iter() - .cloned() - .map(|e| { - Ok(e.transform(|e| { - // replace any subquery expr with subquery_alias.output column - let alias = match e { - Expr::InSubquery(_) - | Expr::Exists(_) - | Expr::ScalarSubquery(_) => { - subquery_alias_by_offset.get(offset_ref).unwrap() - } - _ => return Ok(Transformed::no(e)), - }; - - // We are aware that the original subquery can be rewritten update the - // latest expr to this map. - subquery_expr_by_offset.insert(*offset_ref, e); - *offset_ref += 1; - - Ok(Transformed::yes(col(format!("{alias}.output")))) - })? 
- .data) - }) - .collect::>>()?; - - let new_agg_expr = aggregate - .aggr_expr - .clone() - .iter() - .cloned() - .map(|e| { - Ok(e.transform(|e| { - // replace any subquery expr with subquery_alias.output column - let alias = match e { - Expr::InSubquery(_) - | Expr::Exists(_) - | Expr::ScalarSubquery(_) => { - subquery_alias_by_offset.get(offset_ref).unwrap() - } - _ => return Ok(Transformed::no(e)), - }; - - // We are aware that the original subquery can be rewritten update the - // latest expr to this map. - subquery_expr_by_offset.insert(*offset_ref, e); - *offset_ref += 1; - - Ok(Transformed::yes(col(format!("{alias}.output")))) - })? - .data) - }) - .collect::>>()?; - - for (subquery_offset, (_, column_accesses)) in dependent_join_node - .columns_accesses_by_subquery_id - .iter() - .enumerate() - { - let alias = subquery_alias_by_offset.get(&subquery_offset).unwrap(); - let subquery_expr = subquery_expr_by_offset.get(&subquery_offset).unwrap(); - - let subquery_input = unwrap_subquery_input_from_expr(subquery_expr); - - let correlated_columns = column_accesses - .iter() - .map(|ac| (ac.subquery_depth, ac.col.clone(), ac.data_type.clone())) - .unique() - .collect(); - - current_plan = current_plan.dependent_join( - subquery_input.deref().clone(), - correlated_columns, - Some(subquery_expr.clone()), - current_subquery_depth, - alias.clone(), - None, // TODO: handle this when we support lateral join rewrite - )?; - } - // because dependent join may introduce extra columns // to evaluate the subquery, the final plan should // has another projection to remove these redundant columns @@ -334,8 +233,27 @@ impl DependentJoinRewriter { .map(|c| col(c.clone())) .collect(); - current_plan - .aggregate(new_group_expr.clone(), new_agg_expr.clone())? + let (transformed_plan, transformed_exprs) = + Self::rewrite_exprs_into_dependent_join_plan( + vec![ + aggregate.group_expr.iter().collect::>(), + aggregate.aggr_expr.iter().collect::>(), + ], + dependent_join_node, + current_subquery_depth, + current_plan, + subquery_alias_by_offset, + )?; + let (new_group_exprs, new_aggr_exprs) = match transformed_exprs.as_slice() { + [first, second] => (first, second), + _ => { + return internal_err!( + "transform group and aggr exprs does not return vector of 2 Vec") + } + }; + + transformed_plan + .aggregate(new_group_exprs.clone(), new_aggr_exprs.clone())? 
.project(post_join_projections) } From 8e858b4292b3418032fb40ecf2b0faf0423df76c Mon Sep 17 00:00:00 2001 From: Duong Cong Toai Date: Sat, 7 Jun 2025 20:07:07 +0200 Subject: [PATCH 62/70] refactor: remove all unwrap --- .../optimizer/src/rewrite_dependent_join.rs | 107 +++++++++++++++--- 1 file changed, 89 insertions(+), 18 deletions(-) diff --git a/datafusion/optimizer/src/rewrite_dependent_join.rs b/datafusion/optimizer/src/rewrite_dependent_join.rs index 4214817a9380..ebe2965165fd 100644 --- a/datafusion/optimizer/src/rewrite_dependent_join.rs +++ b/datafusion/optimizer/src/rewrite_dependent_join.rs @@ -306,7 +306,7 @@ impl DependentJoinRewriter { &mut self, child_id: usize, col: &Column, - ) { + ) -> Result<()> { if let Some(accesses) = self.all_outer_ref_columns.get(col) { for access in accesses.iter() { let mut cur_stack = self.stack.clone(); @@ -314,7 +314,11 @@ impl DependentJoinRewriter { cur_stack.push(child_id); let (dependent_join_node_id, subquery_node_id) = Self::dependent_join_and_subquery_node_ids(&cur_stack, &access.stack); - let node = self.nodes.get_mut(&dependent_join_node_id).unwrap(); + let node = self.nodes.get_mut(&dependent_join_node_id).ok_or( + internal_datafusion_err!( + "dependent join node with id {dependent_join_node_id} not found" + ), + )?; let accesses = node .columns_accesses_by_subquery_id .entry(subquery_node_id) @@ -328,6 +332,7 @@ impl DependentJoinRewriter { }); } } + Ok(()) } fn mark_outer_column_access( @@ -515,19 +520,22 @@ impl TreeNodeRewriter for DependentJoinRewriter { // TODO: maybe there are more logical plan that provides columns // aside from TableScan LogicalPlan::TableScan(tbl_scan) => { - tbl_scan.projected_schema.columns().iter().for_each(|col| { - self.conclude_lowest_dependent_join_node_if_any(new_id, col); - }); + tbl_scan + .projected_schema + .columns() + .iter() + .try_for_each(|col| { + self.conclude_lowest_dependent_join_node_if_any(new_id, col) + })?; } // Similar to TableScan, this node may provide column names which // is referenced inside some subqueries LogicalPlan::SubqueryAlias(alias) => { - alias.schema.columns().iter().for_each(|col| { - self.conclude_lowest_dependent_join_node_if_any(new_id, col); - }); + alias.schema.columns().iter().try_for_each(|col| { + self.conclude_lowest_dependent_join_node_if_any(new_id, col) + })?; } LogicalPlan::Unnest(_unnest) => {} - // TODO: this is untested LogicalPlan::Projection(proj) => { for expr in &proj.expr { if contains_subquery(expr) { @@ -542,8 +550,14 @@ impl TreeNodeRewriter for DependentJoinRewriter { } } LogicalPlan::Subquery(subquery) => { - let parent = self.stack.last().unwrap(); - let parent_node = self.nodes.get_mut(parent).unwrap(); + let parent = self.stack.last().ok_or(internal_datafusion_err!( + "subquery node cannot be at the beginning of the query plan" + ))?; + + let parent_node = self + .nodes + .get_mut(parent) + .ok_or(internal_datafusion_err!("node {parent} not found"))?; // the inserting sequence matter here // when a parent has multiple children subquery at the same time // we rely on the order in which subquery children are visited @@ -722,7 +736,9 @@ impl TreeNodeRewriter for DependentJoinRewriter { // if the node in the f_up meet any node in the stack, it means that node itself // is a dependent join node,transformation by // build a join based on - let current_node_id = self.stack.pop().unwrap(); + let current_node_id = self.stack.pop().ok_or(internal_datafusion_err!( + "stack cannot be empty during upward traversal" + ))?; let node_info = if let 
Entry::Occupied(e) = self.nodes.entry(current_node_id) { let node_info = e.get(); if !node_info.is_dependent_join_node { @@ -736,13 +752,20 @@ impl TreeNodeRewriter for DependentJoinRewriter { let current_subquery_depth = self.subquery_depth; self.subquery_depth -= 1; - let cloned_input = (**node.inputs().first().unwrap()).clone(); + let cloned_input = (**node.inputs().first().ok_or(internal_datafusion_err!( + "logical plan {} does not have any input", + node + ))?) + .clone(); let mut current_plan = LogicalPlanBuilder::new(cloned_input); let mut subquery_alias_by_offset = HashMap::new(); for (subquery_offset, (subquery_id, _)) in node_info.columns_accesses_by_subquery_id.iter().enumerate() { - let subquery_node = self.nodes.get(subquery_id).unwrap(); + let subquery_node = self + .nodes + .get(subquery_id) + .ok_or(internal_datafusion_err!("node {subquery_id} not found"))?; let alias = self .alias_generator .next(&subquery_node.subquery_type.prefix()); @@ -769,10 +792,20 @@ impl TreeNodeRewriter for DependentJoinRewriter { )?; } LogicalPlan::Join(join) => { + // this is lateral join assert!(node_info.columns_accesses_by_subquery_id.len() == 1); - let (_, column_accesses) = - node_info.columns_accesses_by_subquery_id.first().unwrap(); - let alias = subquery_alias_by_offset.get(&0).unwrap(); + let (_, column_accesses) = node_info + .columns_accesses_by_subquery_id + .first() + .ok_or(internal_datafusion_err!( + "a lateral join should always have one child subquery" + ))?; + let alias = + subquery_alias_by_offset + .get(&0) + .ok_or(internal_datafusion_err!( + "cannot find subquery alias for only-child of lateral join" + ))?; let correlated_columns = column_accesses .iter() .map(|ac| (ac.subquery_depth, ac.col.clone(), ac.data_type.clone())) @@ -902,8 +935,46 @@ mod tests { ) }}; } + + #[test] + fn uncorrelated_lateral_join() -> Result<()> { + let outer_table = test_table_scan_with_name("outer_table")?; + let inner_table_lv1 = test_table_scan_with_name("inner_table_lv1")?; + + let lateral_join_rhs = Arc::new( + LogicalPlanBuilder::from(inner_table_lv1.clone()) + .filter(col("inner_table_lv1.c").eq(lit(1)))? + .build()?, + ); + + let plan = LogicalPlanBuilder::from(outer_table.clone()) + .join_on( + LogicalPlan::Subquery(Subquery { + subquery: lateral_join_rhs, + outer_ref_columns: vec![], + spans: Spans::new(), + }), + JoinType::Inner, + vec![lit(true)], + )? 
+ .build()?; + + // Inner Join: Filter: Boolean(true) + // TableScan: outer_table + // Subquery: + // Filter: inner_table_lv1.c = outer_ref(outer_table.c) + // TableScan: inner_table_lv1 + + assert_dependent_join_rewrite!(plan, @r" + DependentJoin on [outer_table.c lvl 1] lateral Inner join with Boolean(true) depth 1 [a:UInt32, b:UInt32, c:UInt32] + TableScan: outer_table [a:UInt32, b:UInt32, c:UInt32] + Filter: inner_table_lv1.c = outer_ref(outer_table.c) [a:UInt32, b:UInt32, c:UInt32] + TableScan: inner_table_lv1 [a:UInt32, b:UInt32, c:UInt32] + "); + Ok(()) + } #[test] - fn lateral_join() -> Result<()> { + fn correlated_lateral_join() -> Result<()> { let outer_table = test_table_scan_with_name("outer_table")?; let inner_table_lv1 = test_table_scan_with_name("inner_table_lv1")?; From 30300d1440f9cf4fbaf28db77bc14ed2ba9e710e Mon Sep 17 00:00:00 2001 From: Duong Cong Toai Date: Sat, 7 Jun 2025 20:20:26 +0200 Subject: [PATCH 63/70] fix: test expectation --- datafusion/optimizer/src/rewrite_dependent_join.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/datafusion/optimizer/src/rewrite_dependent_join.rs b/datafusion/optimizer/src/rewrite_dependent_join.rs index ebe2965165fd..66029c9ff158 100644 --- a/datafusion/optimizer/src/rewrite_dependent_join.rs +++ b/datafusion/optimizer/src/rewrite_dependent_join.rs @@ -966,9 +966,9 @@ mod tests { // TableScan: inner_table_lv1 assert_dependent_join_rewrite!(plan, @r" - DependentJoin on [outer_table.c lvl 1] lateral Inner join with Boolean(true) depth 1 [a:UInt32, b:UInt32, c:UInt32] + DependentJoin on [] lateral Inner join with Boolean(true) depth 1 [a:UInt32, b:UInt32, c:UInt32] TableScan: outer_table [a:UInt32, b:UInt32, c:UInt32] - Filter: inner_table_lv1.c = outer_ref(outer_table.c) [a:UInt32, b:UInt32, c:UInt32] + Filter: inner_table_lv1.c = Int32(1) [a:UInt32, b:UInt32, c:UInt32] TableScan: inner_table_lv1 [a:UInt32, b:UInt32, c:UInt32] "); Ok(()) From a93f9010195a79e33212f4abdf5e37987218d79e Mon Sep 17 00:00:00 2001 From: irenjj Date: Sun, 8 Jun 2025 14:13:11 +0800 Subject: [PATCH 64/70] fix subquery in join filter --- .../optimizer/src/rewrite_dependent_join.rs | 181 +++++++++++++++--- 1 file changed, 155 insertions(+), 26 deletions(-) diff --git a/datafusion/optimizer/src/rewrite_dependent_join.rs b/datafusion/optimizer/src/rewrite_dependent_join.rs index 66029c9ff158..6052e403d3f2 100644 --- a/datafusion/optimizer/src/rewrite_dependent_join.rs +++ b/datafusion/optimizer/src/rewrite_dependent_join.rs @@ -29,7 +29,7 @@ use datafusion_common::tree_node::{ }; use datafusion_common::{internal_datafusion_err, internal_err, Column, HashMap, Result}; use datafusion_expr::{ - col, lit, Aggregate, Expr, Filter, LogicalPlan, LogicalPlanBuilder, Projection, + col, lit, Aggregate, Expr, Filter, Join, LogicalPlan, LogicalPlanBuilder, Projection, }; use indexmap::map::Entry; @@ -257,6 +257,42 @@ impl DependentJoinRewriter { .project(post_join_projections) } + fn rewrite_join( + &mut self, + join: &Join, + dependent_join_node: &Node, + current_subquery_depth: usize, + current_plan: LogicalPlanBuilder, + subquery_alias_by_offset: HashMap, + ) -> Result { + let filter = if let Some(filter) = &join.filter { + filter.clone() + } else { + return internal_err!("Join filter should not be empty"); + }; + + let (transformed_plan, transformed_exprs) = + Self::rewrite_exprs_into_dependent_join_plan( + vec![vec![&filter]], + dependent_join_node, + current_subquery_depth, + current_plan, + subquery_alias_by_offset, + )?; + + let 
transformed_predicate = transformed_exprs + .first() + .ok_or(internal_datafusion_err!( + "transform predicate does not return 1 element" + ))? + .first() + .ok_or(internal_datafusion_err!( + "transform predicate does not return 1 element" + ))?; + + transformed_plan.filter(transformed_predicate.clone()) + } + // lowest common ancestor from stack // given a tree of // n1 @@ -633,6 +669,7 @@ impl TreeNodeRewriter for DependentJoinRewriter { } } LogicalPlan::Join(join) => { + let mut is_has_correlated_subquery = false; let mut sq_count = if let LogicalPlan::Subquery(_) = &join.left.as_ref() { 1 } else { @@ -646,7 +683,7 @@ impl TreeNodeRewriter for DependentJoinRewriter { match sq_count { 0 => {} 1 => { - is_dependent_join_node = true; + is_has_correlated_subquery = true; } _ => { return internal_err!( @@ -656,14 +693,14 @@ impl TreeNodeRewriter for DependentJoinRewriter { } }; - if is_dependent_join_node { + if is_has_correlated_subquery { self.subquery_depth += 1; self.stack.push(new_id); self.nodes.insert( new_id, Node { plan: node.clone(), - is_dependent_join_node, + is_dependent_join_node: is_has_correlated_subquery, columns_accesses_by_subquery_id: IndexMap::new(), subquery_type, }, @@ -693,6 +730,20 @@ impl TreeNodeRewriter for DependentJoinRewriter { TreeNodeRecursion::Jump, )); } + + // If expr has correlated subquery. + if let Some(filter) = &join.filter { + if contains_subquery(filter) { + is_dependent_join_node = true; + } + + filter.apply(|expr| { + if let Expr::OuterReferenceColumn(data_type, col) = expr { + self.mark_outer_column_access(new_id, data_type, col); + } + Ok(TreeNodeRecursion::Continue) + })?; + } } LogicalPlan::Sort(sort) => { for expr in &sort.expr { @@ -813,29 +864,35 @@ impl TreeNodeRewriter for DependentJoinRewriter { .collect(); let subquery_plan = &join.right; - let sq = if let LogicalPlan::Subquery(sq) = subquery_plan.as_ref() { - sq + if let LogicalPlan::Subquery(sq) = subquery_plan.as_ref() { + let right = sq.subquery.deref().clone(); + // At the time of implementation lateral join condition is not fully clear yet + // So a TODO for future tracking + let lateral_join_condition = if let Some(ref filter) = join.filter { + filter.clone() + } else { + lit(true) + }; + current_plan = current_plan.dependent_join( + right, + correlated_columns, + None, + current_subquery_depth, + alias.to_string(), + Some((join.join_type, lateral_join_condition)), + )?; } else { - return internal_err!( - "lateral join must have right join as a subquery" - ); + // Correlated subquery in join filter. 
+ let mut cross_join = join.clone(); + cross_join.filter = None; + current_plan = self.rewrite_join( + join, + &node_info, + current_subquery_depth, + LogicalPlanBuilder::new(LogicalPlan::Join(cross_join)), + subquery_alias_by_offset, + )?; }; - let right = sq.subquery.deref().clone(); - // At the time of implementation lateral join condition is not fully clear yet - // So a TODO for future tracking - let lateral_join_condition = if let Some(ref filter) = join.filter { - filter.clone() - } else { - lit(true) - }; - current_plan = current_plan.dependent_join( - right, - correlated_columns, - None, - current_subquery_depth, - alias.to_string(), - Some((join.join_type, lateral_join_condition)), - )?; } LogicalPlan::Aggregate(aggregate) => { current_plan = self.rewrite_aggregate( @@ -912,7 +969,7 @@ mod tests { use arrow::datatypes::{DataType, TimeUnit}; use datafusion_common::{alias::AliasGenerator, Result, Spans}; use datafusion_expr::{ - binary_expr, exists, expr::InSubquery, expr_fn::col, in_subquery, lit, + and, binary_expr, exists, expr::InSubquery, expr_fn::col, in_subquery, lit, out_ref_col, scalar_subquery, Expr, JoinType, LogicalPlan, LogicalPlanBuilder, Operator, SortExpr, Subquery, }; @@ -973,6 +1030,7 @@ mod tests { "); Ok(()) } + #[test] fn correlated_lateral_join() -> Result<()> { let outer_table = test_table_scan_with_name("outer_table")?; @@ -1974,4 +2032,75 @@ mod tests { Ok(()) } + + #[test] + fn test_correlated_subquery_in_join_filter() -> Result<()> { + // Test demonstrates traversal order issue with subquery in JOIN condition + // Query pattern: + // SELECT * FROM t1 + // JOIN t2 ON t2.key = t1.key + // AND t2.val > (SELECT COUNT(*) FROM t3 WHERE t3.id = t1.id); + + let t1 = test_table_with_columns( + "t1", + &[ + ("key", DataType::Int32), + ("id", DataType::Int32), + ("val", DataType::Int32), + ], + )?; + + let t2 = test_table_with_columns( + "t2", + &[("key", DataType::Int32), ("val", DataType::Int32)], + )?; + + let t3 = test_table_with_columns( + "t3", + &[("id", DataType::Int32), ("val", DataType::Int32)], + )?; + + // Subquery in join condition: SELECT COUNT(*) FROM t3 WHERE t3.id = t1.id + let scalar_sq = Arc::new( + LogicalPlanBuilder::from(t3) + .filter(col("t3.id").eq(out_ref_col(DataType::Int32, "t1.id")))? + .aggregate(Vec::::new(), vec![count(lit(1))])? + .build()?, + ); + + // Build join condition: t2.key = t1.key AND t2.val > scalar_sq AND EXISTS(exists_sq) + let join_condition = and( + col("t2.key").eq(col("t1.key")), + col("t2.val").gt(scalar_subquery(scalar_sq)), + ); + + let plan = LogicalPlanBuilder::from(t1) + .join_on(t2, JoinType::Inner, vec![join_condition])? 
+ .build()?; + + // println!("{}", &plan.display_indent()); + // Inner Join: Filter: t2.key = t1.key AND t2.val > () + // Subquery: + // Aggregate: groupBy=[[]], aggr=[[count(Int32(1))]] + // Filter: t3.id = outer_ref(t1.id) + // TableScan: t3 + // TableScan: t1 + // TableScan: t2 + + assert_dependent_join_rewrite!( + plan, + @r#" + Filter: t2.key = t1.key AND t2.val > __lateral_sq_1.output [key:Int32, id:Int32, val:Int32, key:Int32, val:Int32, output:Int64] + DependentJoin on [t1.id lvl 1] with expr () depth 1 [key:Int32, id:Int32, val:Int32, key:Int32, val:Int32, output:Int64] + Cross Join: [key:Int32, id:Int32, val:Int32, key:Int32, val:Int32] + TableScan: t1 [key:Int32, id:Int32, val:Int32] + TableScan: t2 [key:Int32, val:Int32] + Aggregate: groupBy=[[]], aggr=[[count(Int32(1))]] [count(Int32(1)):Int64] + Filter: t3.id = outer_ref(t1.id) [id:Int32, val:Int32] + TableScan: t3 [id:Int32, val:Int32] + "# + ); + + Ok(()) + } } From 4aed14f10413a727699094d3299b5e2dbef79a2b Mon Sep 17 00:00:00 2001 From: irenjj Date: Sun, 8 Jun 2025 14:19:36 +0800 Subject: [PATCH 65/70] rename --- datafusion/optimizer/src/rewrite_dependent_join.rs | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/datafusion/optimizer/src/rewrite_dependent_join.rs b/datafusion/optimizer/src/rewrite_dependent_join.rs index 6052e403d3f2..29bfd3e179b3 100644 --- a/datafusion/optimizer/src/rewrite_dependent_join.rs +++ b/datafusion/optimizer/src/rewrite_dependent_join.rs @@ -669,7 +669,7 @@ impl TreeNodeRewriter for DependentJoinRewriter { } } LogicalPlan::Join(join) => { - let mut is_has_correlated_subquery = false; + let mut is_child_subquery = false; let mut sq_count = if let LogicalPlan::Subquery(_) = &join.left.as_ref() { 1 } else { @@ -683,7 +683,7 @@ impl TreeNodeRewriter for DependentJoinRewriter { match sq_count { 0 => {} 1 => { - is_has_correlated_subquery = true; + is_child_subquery = true; } _ => { return internal_err!( @@ -693,14 +693,14 @@ impl TreeNodeRewriter for DependentJoinRewriter { } }; - if is_has_correlated_subquery { + if is_child_subquery { self.subquery_depth += 1; self.stack.push(new_id); self.nodes.insert( new_id, Node { plan: node.clone(), - is_dependent_join_node: is_has_correlated_subquery, + is_dependent_join_node: is_child_subquery, columns_accesses_by_subquery_id: IndexMap::new(), subquery_type, }, From 6f2ce78d0c14dad3dc4aad5196ae5679b70bc281 Mon Sep 17 00:00:00 2001 From: irenjj Date: Sun, 8 Jun 2025 14:20:59 +0800 Subject: [PATCH 66/70] add todo --- datafusion/optimizer/src/rewrite_dependent_join.rs | 1 + 1 file changed, 1 insertion(+) diff --git a/datafusion/optimizer/src/rewrite_dependent_join.rs b/datafusion/optimizer/src/rewrite_dependent_join.rs index 29bfd3e179b3..9a8b0d748d3a 100644 --- a/datafusion/optimizer/src/rewrite_dependent_join.rs +++ b/datafusion/optimizer/src/rewrite_dependent_join.rs @@ -732,6 +732,7 @@ impl TreeNodeRewriter for DependentJoinRewriter { } // If expr has correlated subquery. + // TODO: what if both child and expr has subquery? 
if let Some(filter) = &join.filter { if contains_subquery(filter) { is_dependent_join_node = true; From 5be430a33f414b86654d60f69dcf1c291cd9d007 Mon Sep 17 00:00:00 2001 From: Duong Cong Toai Date: Sun, 8 Jun 2025 20:44:48 +0200 Subject: [PATCH 67/70] chore: more constraint on correlated subquery in join filter --- .../optimizer/src/rewrite_dependent_join.rs | 377 +++++++++++++----- 1 file changed, 285 insertions(+), 92 deletions(-) diff --git a/datafusion/optimizer/src/rewrite_dependent_join.rs b/datafusion/optimizer/src/rewrite_dependent_join.rs index 9a8b0d748d3a..408eb72858cd 100644 --- a/datafusion/optimizer/src/rewrite_dependent_join.rs +++ b/datafusion/optimizer/src/rewrite_dependent_join.rs @@ -27,7 +27,9 @@ use datafusion_common::alias::AliasGenerator; use datafusion_common::tree_node::{ Transformed, TreeNode, TreeNodeRecursion, TreeNodeRewriter, }; -use datafusion_common::{internal_datafusion_err, internal_err, Column, HashMap, Result}; +use datafusion_common::{ + internal_datafusion_err, internal_err, not_impl_err, Column, HashMap, Result, +}; use datafusion_expr::{ col, lit, Aggregate, Expr, Filter, Join, LogicalPlan, LogicalPlanBuilder, Projection, }; @@ -160,7 +162,7 @@ impl DependentJoinRewriter { ) -> Result { // because dependent join may introduce extra columns // to evaluate the subquery, the final plan should - // has another projection to remove these redundant columns + // have another projection to remove these redundant columns let post_join_projections: Vec = filter .input .schema() @@ -225,7 +227,7 @@ impl DependentJoinRewriter { ) -> Result { // because dependent join may introduce extra columns // to evaluate the subquery, the final plan should - // has another projection to remove these redundant columns + // have another projection to remove these redundant columns let post_join_projections: Vec = aggregate .schema .columns() @@ -257,7 +259,7 @@ impl DependentJoinRewriter { .project(post_join_projections) } - fn rewrite_join( + fn rewrite_lateral_join( &mut self, join: &Join, dependent_join_node: &Node, @@ -265,18 +267,78 @@ impl DependentJoinRewriter { current_plan: LogicalPlanBuilder, subquery_alias_by_offset: HashMap, ) -> Result { - let filter = if let Some(filter) = &join.filter { + // this is lateral join + assert!(dependent_join_node.columns_accesses_by_subquery_id.len() == 1); + let (_, column_accesses) = dependent_join_node + .columns_accesses_by_subquery_id + .first() + .ok_or(internal_datafusion_err!( + "a lateral join should always have one child subquery" + ))?; + let alias = subquery_alias_by_offset + .get(&0) + .ok_or(internal_datafusion_err!( + "cannot find subquery alias for only-child of lateral join" + ))?; + let correlated_columns = column_accesses + .iter() + .map(|ac| (ac.subquery_depth, ac.col.clone(), ac.data_type.clone())) + .unique() + .collect(); + + let sq = if let LogicalPlan::Subquery(sq) = join.right.as_ref() { + sq + } else { + return internal_err!("right side of a lateral join is not a subquery"); + }; + let right = sq.subquery.deref().clone(); + // At the time of implementation lateral join condition is not fully clear yet + // So a TODO for future tracking + let lateral_join_condition = if let Some(ref filter) = join.filter { filter.clone() } else { - return internal_err!("Join filter should not be empty"); + lit(true) + }; + current_plan.dependent_join( + right, + correlated_columns, + None, + current_subquery_depth, + alias.to_string(), + Some((join.join_type, lateral_join_condition)), + ) + } + + // TODO: it is sub-optimal 
that we completely remove all + // the filters (including the ones that have no subquery attached) + // from the original join + // We have to check if after decorrelation, the other optimizers + // that follows are capable of merging these filters back to the + // join node or not + fn rewrite_join( + &mut self, + join: &Join, + dependent_join_node: &Node, + current_subquery_depth: usize, + subquery_alias_by_offset: HashMap, + ) -> Result { + let mut new_join = join.clone(); + let filter = if let Some(ref filter) = join.filter { + filter + } else { + return internal_err!( + "rewriting a correlated join node without any filter condition" + ); }; + new_join.filter = None; + let (transformed_plan, transformed_exprs) = Self::rewrite_exprs_into_dependent_join_plan( - vec![vec![&filter]], + vec![vec![filter]], dependent_join_node, current_subquery_depth, - current_plan, + LogicalPlanBuilder::new(LogicalPlan::Join(new_join)), subquery_alias_by_offset, )?; @@ -424,6 +486,12 @@ struct Node { columns_accesses_by_subquery_id: IndexMap>, is_dependent_join_node: bool, + // a dependent join node with LogicalPlan::Join variation can have subquery children + // in two scenarios: + // - it is a lateral join + // - it is a normal join, but the join conditions contain subquery + // These two scenarios are mutually exclusive and we need to maintain a flag for this + is_lateral_join: bool, // note that for dependent join nodes, there can be more than 1 // subquery children at a time, but always 1 outer-column-providing-child @@ -603,7 +671,7 @@ impl TreeNodeRewriter for DependentJoinRewriter { .columns_accesses_by_subquery_id .insert(new_id, vec![]); - if let LogicalPlan::Join(_) = parent_node.plan { + if parent_node.is_lateral_join { subquery_type = SubqueryType::LateralJoin; } else { for expr in parent_node.plan.expressions() { @@ -669,47 +737,37 @@ impl TreeNodeRewriter for DependentJoinRewriter { } } LogicalPlan::Join(join) => { - let mut is_child_subquery = false; - let mut sq_count = if let LogicalPlan::Subquery(_) = &join.left.as_ref() { - 1 - } else { - 0 - }; - sq_count += if let LogicalPlan::Subquery(_) = join.right.as_ref() { - 1 - } else { - 0 - }; - match sq_count { - 0 => {} - 1 => { - is_child_subquery = true; - } - _ => { - return internal_err!( - "plan error: join logical plan has both children with type \ - Subquery" - ); - } - }; + if let LogicalPlan::Subquery(_) = &join.left.as_ref() { + return internal_err!("left side of a join cannot be a subquery"); + } - if is_child_subquery { + // Handle the case lateral join + if let LogicalPlan::Subquery(_) = join.right.as_ref() { + if let Some(ref filter) = join.filter { + if contains_subquery(filter) { + return not_impl_err!( + "subquery inside lateral join condition is not supported" + ); + } + } self.subquery_depth += 1; self.stack.push(new_id); self.nodes.insert( new_id, Node { plan: node.clone(), - is_dependent_join_node: is_child_subquery, + is_dependent_join_node: true, columns_accesses_by_subquery_id: IndexMap::new(), subquery_type, + is_lateral_join: true, }, ); - // we assume that RHS is always a subquery for the join - // and because this function assume that subquery side is visited first - // during f_down, we have to visit it at this step, else - // the function visit_with_subqueries will call f_down for the LHS instead + // we assume that RHS is always a subquery for the lateral join + // and because this function assume that subquery side is always + // visited first during f_down, we have to explicitly swap the rewrite + // order 
at this step, else the function visit_with_subqueries will + // call f_down for the LHS instead let transformed_subquery = self .rewrite_subqueries_into_dependent_joins( join.right.deref().clone(), @@ -731,8 +789,6 @@ impl TreeNodeRewriter for DependentJoinRewriter { )); } - // If expr has correlated subquery. - // TODO: what if both child and expr has subquery? if let Some(filter) = &join.filter { if contains_subquery(filter) { is_dependent_join_node = true; @@ -774,6 +830,7 @@ impl TreeNodeRewriter for DependentJoinRewriter { is_dependent_join_node, columns_accesses_by_subquery_id: IndexMap::new(), subquery_type, + is_lateral_join: false, }, ); @@ -844,53 +901,20 @@ impl TreeNodeRewriter for DependentJoinRewriter { )?; } LogicalPlan::Join(join) => { - // this is lateral join - assert!(node_info.columns_accesses_by_subquery_id.len() == 1); - let (_, column_accesses) = node_info - .columns_accesses_by_subquery_id - .first() - .ok_or(internal_datafusion_err!( - "a lateral join should always have one child subquery" - ))?; - let alias = - subquery_alias_by_offset - .get(&0) - .ok_or(internal_datafusion_err!( - "cannot find subquery alias for only-child of lateral join" - ))?; - let correlated_columns = column_accesses - .iter() - .map(|ac| (ac.subquery_depth, ac.col.clone(), ac.data_type.clone())) - .unique() - .collect(); - - let subquery_plan = &join.right; - if let LogicalPlan::Subquery(sq) = subquery_plan.as_ref() { - let right = sq.subquery.deref().clone(); - // At the time of implementation lateral join condition is not fully clear yet - // So a TODO for future tracking - let lateral_join_condition = if let Some(ref filter) = join.filter { - filter.clone() - } else { - lit(true) - }; - current_plan = current_plan.dependent_join( - right, - correlated_columns, - None, + if node_info.is_lateral_join { + current_plan = self.rewrite_lateral_join( + join, + &node_info, current_subquery_depth, - alias.to_string(), - Some((join.join_type, lateral_join_condition)), - )?; + current_plan, + subquery_alias_by_offset, + )? } else { // Correlated subquery in join filter. - let mut cross_join = join.clone(); - cross_join.filter = None; current_plan = self.rewrite_join( join, &node_info, current_subquery_depth, - LogicalPlanBuilder::new(LogicalPlan::Join(cross_join)), subquery_alias_by_offset, )?; }; @@ -978,6 +1002,25 @@ mod tests { use insta::assert_snapshot; use std::sync::Arc; + macro_rules! assert_dependent_join_rewrite_err { + ( + $plan:expr, + @ $expected:literal $(,)? + ) => {{ + let mut index = DependentJoinRewriter::new(Arc::new(AliasGenerator::new())); + let transformed = index.rewrite_subqueries_into_dependent_joins($plan.clone()); + if let Err(err) = transformed{ + assert_snapshot!( + err, + @ $expected, + ) + } else{ + panic!("rewriting {} was not returning error",$plan) + } + + }}; + } + macro_rules! assert_dependent_join_rewrite { ( $plan:expr, @@ -2074,12 +2117,10 @@ mod tests { col("t2.key").eq(col("t1.key")), col("t2.val").gt(scalar_subquery(scalar_sq)), ); - let plan = LogicalPlanBuilder::from(t1) .join_on(t2, JoinType::Inner, vec![join_condition])? 
.build()?; - // println!("{}", &plan.display_indent()); // Inner Join: Filter: t2.key = t1.key AND t2.val > () // Subquery: // Aggregate: groupBy=[[]], aggr=[[count(Int32(1))]] @@ -2090,16 +2131,168 @@ mod tests { assert_dependent_join_rewrite!( plan, - @r#" - Filter: t2.key = t1.key AND t2.val > __lateral_sq_1.output [key:Int32, id:Int32, val:Int32, key:Int32, val:Int32, output:Int64] - DependentJoin on [t1.id lvl 1] with expr () depth 1 [key:Int32, id:Int32, val:Int32, key:Int32, val:Int32, output:Int64] - Cross Join: [key:Int32, id:Int32, val:Int32, key:Int32, val:Int32] - TableScan: t1 [key:Int32, id:Int32, val:Int32] - TableScan: t2 [key:Int32, val:Int32] - Aggregate: groupBy=[[]], aggr=[[count(Int32(1))]] [count(Int32(1)):Int64] - Filter: t3.id = outer_ref(t1.id) [id:Int32, val:Int32] - TableScan: t3 [id:Int32, val:Int32] - "# + @r" + Filter: t2.key = t1.key AND t2.val > __scalar_sq_1.output [key:Int32, id:Int32, val:Int32, key:Int32, val:Int32, output:Int64] + DependentJoin on [t1.id lvl 1] with expr () depth 1 [key:Int32, id:Int32, val:Int32, key:Int32, val:Int32, output:Int64] + Cross Join: [key:Int32, id:Int32, val:Int32, key:Int32, val:Int32] + TableScan: t1 [key:Int32, id:Int32, val:Int32] + TableScan: t2 [key:Int32, val:Int32] + Aggregate: groupBy=[[]], aggr=[[count(Int32(1))]] [count(Int32(1)):Int64] + Filter: t3.id = outer_ref(t1.id) [id:Int32, val:Int32] + TableScan: t3 [id:Int32, val:Int32] + " + ); + + Ok(()) + } + + #[test] + fn test_correlated_subquery_in_lateral_join_filter() -> Result<()> { + // Test demonstrates traversal order issue with subquery in JOIN condition + // Query pattern: + // SELECT * FROM t1 + // JOIN t2 ON t2.key = t1.key + // AND t2.val > (SELECT COUNT(*) FROM t3 WHERE t3.id = t1.id); + + let t1 = test_table_with_columns( + "t1", + &[ + ("key", DataType::Int32), + ("id", DataType::Int32), + ("val", DataType::Int32), + ], + )?; + + let t2 = test_table_with_columns( + "t2", + &[("key", DataType::Int32), ("val", DataType::Int32)], + )?; + + let t3 = test_table_with_columns( + "t3", + &[("id", DataType::Int32), ("val", DataType::Int32)], + )?; + + // Subquery in join condition: SELECT COUNT(*) FROM t3 WHERE t3.id = t1.id + let scalar_sq = Arc::new( + LogicalPlanBuilder::from(t3) + .filter(col("t3.id").eq(out_ref_col(DataType::Int32, "t1.id")))? + .aggregate(Vec::::new(), vec![count(lit(1))])? + .build()?, + ); + + // Build join condition: t2.key = t1.key AND t2.val > scalar_sq AND EXISTS(exists_sq) + let join_condition = and( + col("t2.key").eq(col("t1.key")), + col("t2.val").gt(scalar_subquery(scalar_sq)), + ); + + let plan = LogicalPlanBuilder::from(t1) + .join_on( + LogicalPlan::Subquery(Subquery { + subquery: t2.into(), + outer_ref_columns: vec![], + spans: Spans::new(), + }), + JoinType::Inner, + vec![join_condition], + )? 
+ .build()?;
+
+ // Inner Join: Filter: t2.key = t1.key AND t2.val > ()
+ // Subquery:
+ // Aggregate: groupBy=[[]], aggr=[[count(Int32(1))]]
+ // Filter: t3.id = outer_ref(t1.id)
+ // TableScan: t3
+ // TableScan: t1
+ // TableScan: t2
+ assert_dependent_join_rewrite_err!(
+ plan,
+ @"This feature is not implemented: subquery inside lateral join condition is not supported"
+ );
+
+ Ok(())
+ }
+
+ #[test]
+ fn test_multiple_correlated_subqueries_in_join_filter() -> Result<()> {
+ // Test demonstrates traversal order issue with subquery in JOIN condition
+ // Query pattern:
+ // SELECT * FROM t1
+ // JOIN t2 ON (t2.key = t1.key
+ // AND t2.val > (SELECT COUNT(*) FROM t3 WHERE t3.id = t1.id))
+ // OR exists (
+ // SELECT * FROM T3 WHERE T3.ID = T2.KEY
+ // );
+
+ let t1 = test_table_with_columns(
+ "t1",
+ &[
+ ("key", DataType::Int32),
+ ("id", DataType::Int32),
+ ("val", DataType::Int32),
+ ],
+ )?;
+
+ let t2 = test_table_with_columns(
+ "t2",
+ &[("key", DataType::Int32), ("val", DataType::Int32)],
+ )?;
+
+ let t3 = test_table_with_columns(
+ "t3",
+ &[("id", DataType::Int32), ("val", DataType::Int32)],
+ )?;
+
+ // Subquery in join condition: SELECT COUNT(*) FROM t3 WHERE t3.id = t1.id
+ let scalar_sq = Arc::new(
+ LogicalPlanBuilder::from(t3.clone())
+ .filter(col("t3.id").eq(out_ref_col(DataType::Int32, "t1.id")))?
+ .aggregate(Vec::::new(), vec![count(lit(1))])?
+ .build()?,
+ );
+ let exists_sq = Arc::new(
+ LogicalPlanBuilder::from(t3)
+ .filter(col("t3.id").eq(out_ref_col(DataType::Int32, "t2.key")))?
+ .build()?,
+ );
+
+ // Build join condition: (t2.key = t1.key AND t2.val > scalar_sq) OR (exists(exists_sq))
+ let join_condition = and(
+ col("t2.key").eq(col("t1.key")),
+ col("t2.val").gt(scalar_subquery(scalar_sq)),
+ )
+ .or(exists(exists_sq));
+
+ let plan = LogicalPlanBuilder::from(t1)
+ .join_on(t2, JoinType::Inner, vec![join_condition])?
+ .build()?;
+ // Inner Join: Filter: t2.key = t1.key AND t2.val > () OR EXISTS ()
+ // Subquery:
+ // Aggregate: groupBy=[[]], aggr=[[count(Int32(1))]]
+ // Filter: t3.id = outer_ref(t1.id)
+ // TableScan: t3
+ // Subquery:
+ // Filter: t3.id = outer_ref(t2.key)
+ // TableScan: t3
+ // TableScan: t1
+ // TableScan: t2
+
+ assert_dependent_join_rewrite!(
+ plan,
+ @r"
+ Filter: t2.key = t1.key AND t2.val > __scalar_sq_1.output OR __exists_sq_2.output [key:Int32, id:Int32, val:Int32, key:Int32, val:Int32, output:Int64, output:Boolean]
+ DependentJoin on [t2.key lvl 1] with expr EXISTS () depth 1 [key:Int32, id:Int32, val:Int32, key:Int32, val:Int32, output:Int64, output:Boolean]
+ DependentJoin on [t1.id lvl 1] with expr () depth 1 [key:Int32, id:Int32, val:Int32, key:Int32, val:Int32, output:Int64]
+ Cross Join: [key:Int32, id:Int32, val:Int32, key:Int32, val:Int32]
+ TableScan: t1 [key:Int32, id:Int32, val:Int32]
+ TableScan: t2 [key:Int32, val:Int32]
+ Aggregate: groupBy=[[]], aggr=[[count(Int32(1))]] [count(Int32(1)):Int64]
+ Filter: t3.id = outer_ref(t1.id) [id:Int32, val:Int32]
+ TableScan: t3 [id:Int32, val:Int32]
+ Filter: t3.id = outer_ref(t2.key) [id:Int32, val:Int32]
+ TableScan: t3 [id:Int32, val:Int32]
+ "
 );

 Ok(())

From dc656f020186fc6704564c3877e32f4598360fb2 Mon Sep 17 00:00:00 2001
From: Duong Cong Toai
Date: Mon, 9 Jun 2025 09:01:16 +0200
Subject: [PATCH 68/70] chore: try fix snapshot

---
 datafusion/optimizer/src/rewrite_dependent_join.rs | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/datafusion/optimizer/src/rewrite_dependent_join.rs b/datafusion/optimizer/src/rewrite_dependent_join.rs
index 408eb72858cd..80958639e909 100644
--- a/datafusion/optimizer/src/rewrite_dependent_join.rs
+++ b/datafusion/optimizer/src/rewrite_dependent_join.rs
@@ -2208,7 +2208,7 @@ mod tests {
 // TableScan: t2
 assert_dependent_join_rewrite_err!(
 plan,
- @"This feature is not implemented: subquery inside lateral join condition is not supported"
+ @"This feature is not implemented: subquery inside lateral join condition is not supported
+ "
 );

 Ok(())

From f4c4ec0cbc28db925a5d2b3cd1bcc437bf475797 Mon Sep 17 00:00:00 2001
From: Duong Cong Toai
Date: Tue, 10 Jun 2025 20:08:50 +0200
Subject: [PATCH 69/70] chore: use normal assert

---
 datafusion/optimizer/src/rewrite_dependent_join.rs | 8 ++------
 1 file changed, 2 insertions(+), 6 deletions(-)

diff --git a/datafusion/optimizer/src/rewrite_dependent_join.rs b/datafusion/optimizer/src/rewrite_dependent_join.rs
index 80958639e909..491b00e69f39 100644
--- a/datafusion/optimizer/src/rewrite_dependent_join.rs
+++ b/datafusion/optimizer/src/rewrite_dependent_join.rs
@@ -1010,10 +1010,7 @@ mod tests {
 let mut index = DependentJoinRewriter::new(Arc::new(AliasGenerator::new()));
 let transformed = index.rewrite_subqueries_into_dependent_joins($plan.clone());
 if let Err(err) = transformed{
- assert_snapshot!(
- err,
- @ $expected,
- )
+ assert_eq!(err, @ $expected);
 } else{
 panic!("rewriting {} was not returning error",$plan)
@@ -2208,8 +2208,7 @@ mod tests {
 // TableScan: t2
 assert_dependent_join_rewrite_err!(
 plan,
- @"This feature is not implemented: subquery inside lateral join condition is not supported
- "
+ @"This feature is not implemented: subquery inside lateral join condition is not supported"
 );

 Ok(())

From 0f5278ff9ea2632a1438b909e6da6d89e7cf5ff9 Mon Sep 17 00:00:00 2001
From: Duong Cong Toai
Date: Tue, 10 Jun 2025 21:11:24 +0200
Subject: [PATCH 70/70] fix: correct err assert

---
 .../optimizer/src/rewrite_dependent_join.rs | 16 ++++++++--------
 1 file changed, 8 insertions(+), 8 deletions(-)

diff --git a/datafusion/optimizer/src/rewrite_dependent_join.rs b/datafusion/optimizer/src/rewrite_dependent_join.rs
index 491b00e69f39..d8df04aae049 100644
--- a/datafusion/optimizer/src/rewrite_dependent_join.rs
+++ b/datafusion/optimizer/src/rewrite_dependent_join.rs
@@ -1005,16 +1005,16 @@ mod tests {
 macro_rules! assert_dependent_join_rewrite_err {
 (
 $plan:expr,
- @ $expected:literal $(,)?
+ $expected:literal $(,)?
 ) => {{
 let mut index = DependentJoinRewriter::new(Arc::new(AliasGenerator::new()));
- let transformed = index.rewrite_subqueries_into_dependent_joins($plan.clone());
- if let Err(err) = transformed{
- assert_eq!(err, @ $expected);
- } else{
- panic!("rewriting {} was not returning error",$plan)
+ let transformed =
+ index.rewrite_subqueries_into_dependent_joins($plan.clone());
+ if let Err(err) = transformed {
+ assert_eq!(format!("{err}"), $expected);
+ } else {
+ panic!("rewriting {} was not returning error", $plan)
 }
-
 }};
 }

@@ -2205,7 +2205,7 @@ mod tests {
 // TableScan: t2
 assert_dependent_join_rewrite_err!(
 plan,
- @"This feature is not implemented: subquery inside lateral join condition is not supported"
+ "This feature is not implemented: subquery inside lateral join condition is not supported"
 );

 Ok(())
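
The join-filter handling introduced across these patches ultimately rests on one predicate: a Join node is treated as a dependent-join node exactly when its filter expression contains a subquery, and each correlated subquery found there is then pulled out into a DependentJoin below the rewritten filter. The standalone sketch below illustrates that detection step over a deliberately simplified expression tree; the enum and helper here are illustrative stand-ins, not the DataFusion Expr type or the contains_subquery implementation used in the patches above.

// Toy model of a join filter expression; stands in for the real Expr tree.
enum FilterExpr {
    Column(String),
    // Placeholder for any nested plan: scalar, EXISTS, or IN subquery.
    Subquery(String),
    Binary(Box<FilterExpr>, &'static str, Box<FilterExpr>),
}

// Walk the filter expression and report whether any node is a subquery.
// In the rewriter above, a join whose ON-condition passes a check like this
// is marked as a dependent-join node; otherwise it stays an ordinary join.
fn contains_subquery(expr: &FilterExpr) -> bool {
    match expr {
        FilterExpr::Subquery(_) => true,
        FilterExpr::Binary(left, _, right) => {
            contains_subquery(left) || contains_subquery(right)
        }
        FilterExpr::Column(_) => false,
    }
}

fn main() {
    // t2.key = t1.key AND t2.val > (SELECT count(*) FROM t3 WHERE t3.id = t1.id)
    let filter = FilterExpr::Binary(
        Box::new(FilterExpr::Binary(
            Box::new(FilterExpr::Column("t2.key".into())),
            "=",
            Box::new(FilterExpr::Column("t1.key".into())),
        )),
        "AND",
        Box::new(FilterExpr::Binary(
            Box::new(FilterExpr::Column("t2.val".into())),
            ">",
            Box::new(FilterExpr::Subquery("count(*) from t3 where t3.id = t1.id".into())),
        )),
    );
    assert!(contains_subquery(&filter));
}

In the actual rewriter this check is combined with the is_lateral_join distinction, since a subquery child (lateral join) and a subquery inside the join filter are handled as mutually exclusive cases.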