Skip to content

Improve documentation for ExprVisitor, port simple uses to new walking function #4916

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Jan 18, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
85 changes: 47 additions & 38 deletions datafusion/expr/src/expr_visitor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,10 +33,46 @@ pub enum Recursion<V: ExpressionVisitor> {
Stop(V),
}

/// Encode the traversal of an expression tree. When passed to
/// `Expr::accept`, `ExpressionVisitor::visit` is invoked
/// recursively on all nodes of an expression tree. See the comments
/// on `Expr::accept` for details on its use
/// Implements the [visitor
/// pattern](https://en.wikipedia.org/wiki/Visitor_pattern) for recursively walking [`Expr`]s.
///
/// [`ExpressionVisitor`] allows keeping the algorithms
/// separate from the code to traverse the structure of the `Expr`
/// tree and makes it easier to add new types of expressions and
/// algorithms by.
///
/// When passed to[`Expr::accept`], [`ExpressionVisitor::pre_visit`]
/// and [`ExpressionVisitor::post_visit`] are invoked recursively
/// on all nodes of an expression tree.
///
///
/// For an expression tree such as
/// ```text
/// BinaryExpr (GT)
/// left: Column("foo")
/// right: Column("bar")
/// ```
///
/// The nodes are visited using the following order
/// ```text
/// pre_visit(BinaryExpr(GT))
/// pre_visit(Column("foo"))
/// post_visit(Column("foo"))
/// pre_visit(Column("bar"))
/// post_visit(Column("bar"))
/// post_visit(BinaryExpr(GT))
/// ```
///
/// If an [`Err`] result is returned, recursion is stopped
/// immediately.
///
/// If [`Recursion::Stop`] is returned on a call to pre_visit, no
/// children of that expression are visited, nor is post_visit
/// called on that expression
///
/// # See Also:
/// * [`Expr::accept`] to drive a visitor through an [`Expr`]
/// * [inspect_expr_pre]: For visiting [`Expr`]s using functions
pub trait ExpressionVisitor<E: ExprVisitable = Expr>: Sized {
/// Invoked before any children of `expr` are visited.
fn pre_visit(self, expr: &E) -> Result<Recursion<Self>>
Expand All @@ -58,37 +94,7 @@ pub trait ExprVisitable: Sized {

impl ExprVisitable for Expr {
/// Performs a depth first walk of an expression and
/// its children, calling [`ExpressionVisitor::pre_visit`] and
/// `visitor.post_visit`.
///
/// Implements the [visitor pattern](https://en.wikipedia.org/wiki/Visitor_pattern) to
/// separate expression algorithms from the structure of the
/// `Expr` tree and make it easier to add new types of expressions
/// and algorithms that walk the tree.
///
/// For an expression tree such as
/// ```text
/// BinaryExpr (GT)
/// left: Column("foo")
/// right: Column("bar")
/// ```
///
/// The nodes are visited using the following order
/// ```text
/// pre_visit(BinaryExpr(GT))
/// pre_visit(Column("foo"))
/// post_visit(Column("foo"))
/// pre_visit(Column("bar"))
/// post_visit(Column("bar"))
/// post_visit(BinaryExpr(GT))
/// ```
///
/// If an Err result is returned, recursion is stopped immediately
///
/// If `Recursion::Stop` is returned on a call to pre_visit, no
/// children of that expression are visited, nor is post_visit
/// called on that expression
///
/// its children, see [`ExpressionVisitor`] for more details
fn accept<V: ExpressionVisitor>(&self, visitor: V) -> Result<V> {
let visitor = match visitor.pre_visit(self)? {
Recursion::Continue(visitor) => visitor,
Expand Down Expand Up @@ -223,6 +229,7 @@ impl ExprVisitable for Expr {

struct VisitorAdapter<F, E> {
f: F,
// Store returned error as it my not be a DataFusionError
err: std::result::Result<(), E>,
}

Expand All @@ -242,10 +249,12 @@ where
}
}

/// Conveniece function for using a mutable function as an expression visiitor
/// Recursively inspect an [`Expr`] and all its childen.
///
/// TODO make this match names in physical plan
pub fn walk_expr_down<F, E>(expr: &Expr, f: F) -> std::result::Result<(), E>
/// Performs a pre-visit traversal by recursively calling `f(expr)` on
/// `expr`, and then on all its children. See [`ExpressionVisitor`]
/// for more details and more options to control the walk.
pub fn inspect_expr_pre<F, E>(expr: &Expr, f: F) -> std::result::Result<(), E>
where
F: FnMut(&Expr) -> std::result::Result<(), E>,
{
Expand Down
7 changes: 4 additions & 3 deletions datafusion/expr/src/logical_plan/plan.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
// under the License.

use crate::expr_rewriter::{ExprRewritable, ExprRewriter};
use crate::expr_visitor::walk_expr_down;
use crate::expr_visitor::inspect_expr_pre;
///! Logical plan types
use crate::logical_plan::builder::validate_unique_names;
use crate::logical_plan::display::{GraphvizVisitor, IndentVisitor};
Expand Down Expand Up @@ -580,7 +580,7 @@ impl LogicalPlan {
{
self.inspect_expressions(|expr| {
// recursively look for subqueries
walk_expr_down(expr, |expr| {
inspect_expr_pre(expr, |expr| {
match expr {
Expr::Exists { subquery, .. }
| Expr::InSubquery { subquery, .. }
Expand Down Expand Up @@ -1219,7 +1219,8 @@ pub struct DropView {
pub schema: DFSchemaRef,
}

/// Set a Variable's value -- value in [`ConfigOptions`]
/// Set a Variable's value -- value in
/// [`ConfigOptions`](datafusion_common::config::ConfigOptions)
#[derive(Clone)]
pub struct SetVariable {
/// The variable name
Expand Down
94 changes: 29 additions & 65 deletions datafusion/expr/src/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,9 @@

use crate::expr::{Sort, WindowFunction};
use crate::expr_rewriter::{ExprRewritable, ExprRewriter, RewriteRecursion};
use crate::expr_visitor::{ExprVisitable, ExpressionVisitor, Recursion};
use crate::expr_visitor::{
inspect_expr_pre, ExprVisitable, ExpressionVisitor, Recursion,
};
use crate::logical_plan::builder::build_join_schema;
use crate::logical_plan::{
Aggregate, Analyze, CreateMemoryTable, CreateView, Distinct, Extension, Filter, Join,
Expand Down Expand Up @@ -83,20 +85,16 @@ pub fn grouping_set_to_exprlist(group_expr: &[Expr]) -> Result<Vec<Expr>> {
}
}

/// Recursively walk an expression tree, collecting the unique set of column names
/// Recursively walk an expression tree, collecting the unique set of columns
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

here is a pretty good example of reducing the ceremony required do to a walk of an expr by using inspect_expr_pre rather than having to define a visitor explicitly

/// referenced in the expression
struct ColumnNameVisitor<'a> {
accum: &'a mut HashSet<Column>,
}

impl ExpressionVisitor for ColumnNameVisitor<'_> {
fn pre_visit(self, expr: &Expr) -> Result<Recursion<Self>> {
pub fn expr_to_columns(expr: &Expr, accum: &mut HashSet<Column>) -> Result<()> {
inspect_expr_pre(expr, |expr| {
Comment on lines +90 to +91
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

love this change❤️

match expr {
Expr::Column(qc) => {
self.accum.insert(qc.clone());
accum.insert(qc.clone());
}
Expr::ScalarVariable(_, var_names) => {
self.accum.insert(Column::from_name(var_names.join(".")));
accum.insert(Column::from_name(var_names.join(".")));
}
Expr::Alias(_, _)
| Expr::Literal(_)
Expand Down Expand Up @@ -134,15 +132,8 @@ impl ExpressionVisitor for ColumnNameVisitor<'_> {
| Expr::GetIndexedField { .. }
| Expr::Placeholder { .. } => {}
}
Ok(Recursion::Continue(self))
}
}

/// Recursively walk an expression tree, collecting the unique set of columns
/// referenced in the expression
pub fn expr_to_columns(expr: &Expr, accum: &mut HashSet<Column>) -> Result<()> {
expr.accept(ColumnNameVisitor { accum })?;
Ok(())
Ok(())
})
}

/// Resolves an `Expr::Wildcard` to a collection of `Expr::Column`'s.
Expand Down Expand Up @@ -861,27 +852,17 @@ pub fn find_column_exprs(exprs: &[Expr]) -> Vec<Expr> {
.collect()
}

/// Recursively find all columns referenced by an expression
#[derive(Debug, Default)]
struct ColumnCollector {
exprs: Vec<Column>,
}

impl ExpressionVisitor for ColumnCollector {
fn pre_visit(mut self, expr: &Expr) -> Result<Recursion<Self>> {
pub(crate) fn find_columns_referenced_by_expr(e: &Expr) -> Vec<Column> {
let mut exprs = vec![];
inspect_expr_pre(e, |expr| {
if let Expr::Column(c) = expr {
self.exprs.push(c.clone())
exprs.push(c.clone())
}
Ok(Recursion::Continue(self))
}
}

pub(crate) fn find_columns_referenced_by_expr(e: &Expr) -> Vec<Column> {
Ok(()) as Result<()>
})
// As the `ExpressionVisitor` impl above always returns Ok, this
// "can't" error
let ColumnCollector { exprs } = e
.accept(ColumnCollector::default())
.expect("Unexpected error");
.expect("Unexpected error");
exprs
}

Expand All @@ -898,43 +879,26 @@ pub fn expr_as_column_expr(expr: &Expr, plan: &LogicalPlan) -> Result<Expr> {

/// Recursively walk an expression tree, collecting the column indexes
/// referenced in the expression
struct ColumnIndexesCollector<'a> {
schema: &'a DFSchemaRef,
indexes: Vec<usize>,
}

impl ExpressionVisitor for ColumnIndexesCollector<'_> {
fn pre_visit(mut self, expr: &Expr) -> Result<Recursion<Self>>
where
Self: ExpressionVisitor,
{
pub(crate) fn find_column_indexes_referenced_by_expr(
e: &Expr,
schema: &DFSchemaRef,
) -> Vec<usize> {
let mut indexes = vec![];
inspect_expr_pre(e, |expr| {
match expr {
Expr::Column(qc) => {
if let Ok(idx) = self.schema.index_of_column(qc) {
self.indexes.push(idx);
if let Ok(idx) = schema.index_of_column(qc) {
indexes.push(idx);
}
}
Expr::Literal(_) => {
self.indexes.push(std::usize::MAX);
indexes.push(std::usize::MAX);
}
_ => {}
}
Ok(Recursion::Continue(self))
}
}

pub(crate) fn find_column_indexes_referenced_by_expr(
e: &Expr,
schema: &DFSchemaRef,
) -> Vec<usize> {
// As the `ExpressionVisitor` impl above always returns Ok, this
// "can't" error
let ColumnIndexesCollector { indexes, .. } = e
.accept(ColumnIndexesCollector {
schema,
indexes: vec![],
})
.expect("Unexpected error");
Ok(()) as Result<()>
})
.unwrap();
indexes
}

Expand Down
33 changes: 13 additions & 20 deletions datafusion/optimizer/src/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ use datafusion_common::Result;
use datafusion_common::{plan_err, Column, DFSchemaRef};
use datafusion_expr::expr::{BinaryExpr, Sort};
use datafusion_expr::expr_rewriter::{ExprRewritable, ExprRewriter};
use datafusion_expr::expr_visitor::{ExprVisitable, ExpressionVisitor, Recursion};
use datafusion_expr::expr_visitor::inspect_expr_pre;
use datafusion_expr::{
and, col,
logical_plan::{Filter, LogicalPlan},
Expand Down Expand Up @@ -232,28 +232,21 @@ pub fn unalias(expr: Expr) -> Expr {
///
/// A PlanError if a disjunction is found
pub fn verify_not_disjunction(predicates: &[&Expr]) -> Result<()> {
struct DisjunctionVisitor {}

impl ExpressionVisitor for DisjunctionVisitor {
fn pre_visit(self, expr: &Expr) -> Result<Recursion<Self>> {
match expr {
Expr::BinaryExpr(BinaryExpr {
left: _,
op: Operator::Or,
right: _,
}) => {
plan_err!("Optimizing disjunctions not supported!")
}
_ => Ok(Recursion::Continue(self)),
// recursively check for unallowed predicates in expr
fn check(expr: &&Expr) -> Result<()> {
inspect_expr_pre(expr, |expr| match expr {
Expr::BinaryExpr(BinaryExpr {
left: _,
op: Operator::Or,
right: _,
}) => {
plan_err!("Optimizing disjunctions not supported!")
}
}
}

for predicate in predicates.iter() {
predicate.accept(DisjunctionVisitor {})?;
_ => Ok(()),
})
}

Ok(())
predicates.iter().try_for_each(check)
}

/// returns a new [LogicalPlan] that wraps `plan` in a [LogicalPlan::Filter] with
Expand Down