diff --git a/datafusion/expr/src/expr_visitor.rs b/datafusion/expr/src/expr_visitor.rs index 82862eb873d7..5675eae1d955 100644 --- a/datafusion/expr/src/expr_visitor.rs +++ b/datafusion/expr/src/expr_visitor.rs @@ -33,10 +33,46 @@ pub enum Recursion { Stop(V), } -/// Encode the traversal of an expression tree. When passed to -/// `Expr::accept`, `ExpressionVisitor::visit` is invoked -/// recursively on all nodes of an expression tree. See the comments -/// on `Expr::accept` for details on its use +/// Implements the [visitor +/// pattern](https://en.wikipedia.org/wiki/Visitor_pattern) for recursively walking [`Expr`]s. +/// +/// [`ExpressionVisitor`] allows keeping the algorithms +/// separate from the code to traverse the structure of the `Expr` +/// tree and makes it easier to add new types of expressions and +/// algorithms that walk the tree. +/// +/// When passed to [`Expr::accept`], [`ExpressionVisitor::pre_visit`] +/// and [`ExpressionVisitor::post_visit`] are invoked recursively +/// on all nodes of an expression tree. +/// +/// +/// For an expression tree such as +/// ```text +/// BinaryExpr (GT) +/// left: Column("foo") +/// right: Column("bar") +/// ``` +/// +/// The nodes are visited using the following order +/// ```text +/// pre_visit(BinaryExpr(GT)) +/// pre_visit(Column("foo")) +/// post_visit(Column("foo")) +/// pre_visit(Column("bar")) +/// post_visit(Column("bar")) +/// post_visit(BinaryExpr(GT)) +/// ``` +/// +/// If an [`Err`] result is returned, recursion is stopped +/// immediately. +/// +/// If [`Recursion::Stop`] is returned on a call to pre_visit, no +/// children of that expression are visited, nor is post_visit +/// called on that expression +/// +/// # See Also: +/// * [`Expr::accept`] to drive a visitor through an [`Expr`] +/// * [inspect_expr_pre]: For visiting [`Expr`]s using functions pub trait ExpressionVisitor: Sized { /// Invoked before any children of `expr` are visited. 
fn pre_visit(self, expr: &E) -> Result> @@ -58,37 +94,7 @@ pub trait ExprVisitable: Sized { impl ExprVisitable for Expr { /// Performs a depth first walk of an expression and - /// its children, calling [`ExpressionVisitor::pre_visit`] and - /// `visitor.post_visit`. - /// - /// Implements the [visitor pattern](https://en.wikipedia.org/wiki/Visitor_pattern) to - /// separate expression algorithms from the structure of the - /// `Expr` tree and make it easier to add new types of expressions - /// and algorithms that walk the tree. - /// - /// For an expression tree such as - /// ```text - /// BinaryExpr (GT) - /// left: Column("foo") - /// right: Column("bar") - /// ``` - /// - /// The nodes are visited using the following order - /// ```text - /// pre_visit(BinaryExpr(GT)) - /// pre_visit(Column("foo")) - /// post_visit(Column("foo")) - /// pre_visit(Column("bar")) - /// post_visit(Column("bar")) - /// post_visit(BinaryExpr(GT)) - /// ``` - /// - /// If an Err result is returned, recursion is stopped immediately - /// - /// If `Recursion::Stop` is returned on a call to pre_visit, no - /// children of that expression are visited, nor is post_visit - /// called on that expression - /// + /// its children, see [`ExpressionVisitor`] for more details fn accept(&self, visitor: V) -> Result { let visitor = match visitor.pre_visit(self)? { Recursion::Continue(visitor) => visitor, @@ -223,6 +229,7 @@ impl ExprVisitable for Expr { struct VisitorAdapter { f: F, + // Store returned error as it may not be a DataFusionError err: std::result::Result<(), E>, } @@ -242,10 +249,12 @@ where } } -/// Conveniece function for using a mutable function as an expression visiitor +/// Recursively inspect an [`Expr`] and all its children. /// -/// TODO make this match names in physical plan -pub fn walk_expr_down(expr: &Expr, f: F) -> std::result::Result<(), E> +/// Performs a pre-visit traversal by recursively calling `f(expr)` on +/// `expr`, and then on all its children. 
See [`ExpressionVisitor`] +/// for more details and more options to control the walk. +pub fn inspect_expr_pre(expr: &Expr, f: F) -> std::result::Result<(), E> where F: FnMut(&Expr) -> std::result::Result<(), E>, { diff --git a/datafusion/expr/src/logical_plan/plan.rs b/datafusion/expr/src/logical_plan/plan.rs index 23748470c171..a9e4943be599 100644 --- a/datafusion/expr/src/logical_plan/plan.rs +++ b/datafusion/expr/src/logical_plan/plan.rs @@ -16,7 +16,7 @@ // under the License. use crate::expr_rewriter::{ExprRewritable, ExprRewriter}; -use crate::expr_visitor::walk_expr_down; +use crate::expr_visitor::inspect_expr_pre; ///! Logical plan types use crate::logical_plan::builder::validate_unique_names; use crate::logical_plan::display::{GraphvizVisitor, IndentVisitor}; @@ -580,7 +580,7 @@ impl LogicalPlan { { self.inspect_expressions(|expr| { // recursively look for subqueries - walk_expr_down(expr, |expr| { + inspect_expr_pre(expr, |expr| { match expr { Expr::Exists { subquery, .. } | Expr::InSubquery { subquery, .. 
} @@ -1219,7 +1219,8 @@ pub struct DropView { pub schema: DFSchemaRef, } -/// Set a Variable's value -- value in [`ConfigOptions`] +/// Set a Variable's value -- value in +/// [`ConfigOptions`](datafusion_common::config::ConfigOptions) #[derive(Clone)] pub struct SetVariable { /// The variable name diff --git a/datafusion/expr/src/utils.rs b/datafusion/expr/src/utils.rs index 8e22b5a23b3e..c48a1c5fb2eb 100644 --- a/datafusion/expr/src/utils.rs +++ b/datafusion/expr/src/utils.rs @@ -19,7 +19,9 @@ use crate::expr::{Sort, WindowFunction}; use crate::expr_rewriter::{ExprRewritable, ExprRewriter, RewriteRecursion}; -use crate::expr_visitor::{ExprVisitable, ExpressionVisitor, Recursion}; +use crate::expr_visitor::{ + inspect_expr_pre, ExprVisitable, ExpressionVisitor, Recursion, +}; use crate::logical_plan::builder::build_join_schema; use crate::logical_plan::{ Aggregate, Analyze, CreateMemoryTable, CreateView, Distinct, Extension, Filter, Join, @@ -83,20 +85,16 @@ pub fn grouping_set_to_exprlist(group_expr: &[Expr]) -> Result> { } } -/// Recursively walk an expression tree, collecting the unique set of column names +/// Recursively walk an expression tree, collecting the unique set of columns /// referenced in the expression -struct ColumnNameVisitor<'a> { - accum: &'a mut HashSet, -} - -impl ExpressionVisitor for ColumnNameVisitor<'_> { - fn pre_visit(self, expr: &Expr) -> Result> { +pub fn expr_to_columns(expr: &Expr, accum: &mut HashSet) -> Result<()> { + inspect_expr_pre(expr, |expr| { match expr { Expr::Column(qc) => { - self.accum.insert(qc.clone()); + accum.insert(qc.clone()); } Expr::ScalarVariable(_, var_names) => { - self.accum.insert(Column::from_name(var_names.join("."))); + accum.insert(Column::from_name(var_names.join("."))); } Expr::Alias(_, _) | Expr::Literal(_) @@ -134,15 +132,8 @@ impl ExpressionVisitor for ColumnNameVisitor<'_> { | Expr::GetIndexedField { .. } | Expr::Placeholder { .. 
} => {} } - Ok(Recursion::Continue(self)) - } -} - -/// Recursively walk an expression tree, collecting the unique set of columns -/// referenced in the expression -pub fn expr_to_columns(expr: &Expr, accum: &mut HashSet) -> Result<()> { - expr.accept(ColumnNameVisitor { accum })?; - Ok(()) + Ok(()) + }) } /// Resolves an `Expr::Wildcard` to a collection of `Expr::Column`'s. @@ -861,27 +852,17 @@ pub fn find_column_exprs(exprs: &[Expr]) -> Vec { .collect() } -/// Recursively find all columns referenced by an expression -#[derive(Debug, Default)] -struct ColumnCollector { - exprs: Vec, -} - -impl ExpressionVisitor for ColumnCollector { - fn pre_visit(mut self, expr: &Expr) -> Result> { +pub(crate) fn find_columns_referenced_by_expr(e: &Expr) -> Vec { + let mut exprs = vec![]; + inspect_expr_pre(e, |expr| { if let Expr::Column(c) = expr { - self.exprs.push(c.clone()) + exprs.push(c.clone()) } - Ok(Recursion::Continue(self)) - } -} - -pub(crate) fn find_columns_referenced_by_expr(e: &Expr) -> Vec { + Ok(()) as Result<()> + }) // As the `ExpressionVisitor` impl above always returns Ok, this // "can't" error - let ColumnCollector { exprs } = e - .accept(ColumnCollector::default()) - .expect("Unexpected error"); + .expect("Unexpected error"); exprs } @@ -898,43 +879,26 @@ pub fn expr_as_column_expr(expr: &Expr, plan: &LogicalPlan) -> Result { /// Recursively walk an expression tree, collecting the column indexes /// referenced in the expression -struct ColumnIndexesCollector<'a> { - schema: &'a DFSchemaRef, - indexes: Vec, -} - -impl ExpressionVisitor for ColumnIndexesCollector<'_> { - fn pre_visit(mut self, expr: &Expr) -> Result> - where - Self: ExpressionVisitor, - { +pub(crate) fn find_column_indexes_referenced_by_expr( + e: &Expr, + schema: &DFSchemaRef, +) -> Vec { + let mut indexes = vec![]; + inspect_expr_pre(e, |expr| { match expr { Expr::Column(qc) => { - if let Ok(idx) = self.schema.index_of_column(qc) { - self.indexes.push(idx); + if let Ok(idx) = 
schema.index_of_column(qc) { + indexes.push(idx); } } Expr::Literal(_) => { - self.indexes.push(std::usize::MAX); + indexes.push(std::usize::MAX); } _ => {} } - Ok(Recursion::Continue(self)) - } -} - -pub(crate) fn find_column_indexes_referenced_by_expr( - e: &Expr, - schema: &DFSchemaRef, -) -> Vec { - // As the `ExpressionVisitor` impl above always returns Ok, this - // "can't" error - let ColumnIndexesCollector { indexes, .. } = e - .accept(ColumnIndexesCollector { - schema, - indexes: vec![], - }) - .expect("Unexpected error"); + Ok(()) as Result<()> + }) + .unwrap(); indexes } diff --git a/datafusion/optimizer/src/utils.rs b/datafusion/optimizer/src/utils.rs index a5f06bc82441..4ac337ec21de 100644 --- a/datafusion/optimizer/src/utils.rs +++ b/datafusion/optimizer/src/utils.rs @@ -22,7 +22,7 @@ use datafusion_common::Result; use datafusion_common::{plan_err, Column, DFSchemaRef}; use datafusion_expr::expr::{BinaryExpr, Sort}; use datafusion_expr::expr_rewriter::{ExprRewritable, ExprRewriter}; -use datafusion_expr::expr_visitor::{ExprVisitable, ExpressionVisitor, Recursion}; +use datafusion_expr::expr_visitor::inspect_expr_pre; use datafusion_expr::{ and, col, logical_plan::{Filter, LogicalPlan}, @@ -232,28 +232,21 @@ pub fn unalias(expr: Expr) -> Expr { /// /// A PlanError if a disjunction is found pub fn verify_not_disjunction(predicates: &[&Expr]) -> Result<()> { - struct DisjunctionVisitor {} - - impl ExpressionVisitor for DisjunctionVisitor { - fn pre_visit(self, expr: &Expr) -> Result> { - match expr { - Expr::BinaryExpr(BinaryExpr { - left: _, - op: Operator::Or, - right: _, - }) => { - plan_err!("Optimizing disjunctions not supported!") - } - _ => Ok(Recursion::Continue(self)), + // recursively check for unallowed predicates in expr + fn check(expr: &&Expr) -> Result<()> { + inspect_expr_pre(expr, |expr| match expr { + Expr::BinaryExpr(BinaryExpr { + left: _, + op: Operator::Or, + right: _, + }) => { + plan_err!("Optimizing disjunctions not supported!") 
} - } - } - - for predicate in predicates.iter() { - predicate.accept(DisjunctionVisitor {})?; + _ => Ok(()), + }) } - Ok(()) + predicates.iter().try_for_each(check) } /// returns a new [LogicalPlan] that wraps `plan` in a [LogicalPlan::Filter] with