diff --git a/datafusion-examples/examples/dataframe_subquery.rs b/datafusion-examples/examples/dataframe_subquery.rs index 9fb61008b9f6..3dcca7aa264e 100644 --- a/datafusion-examples/examples/dataframe_subquery.rs +++ b/datafusion-examples/examples/dataframe_subquery.rs @@ -16,7 +16,6 @@ // under the License. use arrow_schema::DataType; -use std::sync::Arc; use datafusion::error::Result; use datafusion::prelude::*; @@ -44,7 +43,7 @@ async fn where_scalar_subquery(ctx: &SessionContext) -> Result<()> { ctx.table("t1") .await? .filter( - scalar_subquery(Arc::new( + scalar_subquery(Box::new( ctx.table("t2") .await? .filter(out_ref_col(DataType::Utf8, "t1.c1").eq(col("t2.c1")))? @@ -67,7 +66,7 @@ async fn where_in_subquery(ctx: &SessionContext) -> Result<()> { .await? .filter(in_subquery( col("t1.c2"), - Arc::new( + Box::new( ctx.table("t2") .await? .filter(col("t2.c1").gt(lit(ScalarValue::UInt8(Some(0)))))? @@ -87,7 +86,7 @@ async fn where_in_subquery(ctx: &SessionContext) -> Result<()> { async fn where_exist_subquery(ctx: &SessionContext) -> Result<()> { ctx.table("t1") .await? - .filter(exists(Arc::new( + .filter(exists(Box::new( ctx.table("t2") .await? .filter(out_ref_col(DataType::Utf8, "t1.c1").eq(col("t2.c1")))? diff --git a/datafusion-examples/examples/rewrite_expr.rs b/datafusion-examples/examples/rewrite_expr.rs index 541448ebf149..614654614323 100644 --- a/datafusion-examples/examples/rewrite_expr.rs +++ b/datafusion-examples/examples/rewrite_expr.rs @@ -59,7 +59,7 @@ pub fn main() -> Result<()> { // then run the optimizer with our custom rule let optimizer = Optimizer::with_rules(vec![Arc::new(MyOptimizerRule {})]); - let optimized_plan = optimizer.optimize(&analyzed_plan, &config, observe)?; + let optimized_plan = optimizer.optimize(analyzed_plan, &config, observe)?.data; println!( "Optimized Logical Plan:\n\n{}\n", optimized_plan.display_indent() diff --git a/datafusion/common/src/column.rs b/datafusion/common/src/column.rs index f0edc7175948..ca12a7df23d1 100644 --- a/datafusion/common/src/column.rs +++ b/datafusion/common/src/column.rs @@ -35,6 +35,15 @@ pub struct Column { pub name: String, } +impl Default for Column { + fn default() -> Self { + Self { + relation: None, + name: "".to_string(), + } + } +} + impl Column { /// Create Column from optional qualifier and name. The optional qualifier, if present, /// will be parsed and normalized by default. 
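Note on the pattern behind these changes: `Column` here (and `Expr` and `LogicalPlan` later in this patch) gains a `Default` impl so that the new by-value rewrite helpers in this patch can call `std::mem::take` on a `&mut` slot and move the child out without cloning it; the same shift to by-value APIs shows up in `Optimizer::optimize`, which now consumes its plan and returns a `Transformed` value that callers unwrap via `.data`. Below is a minimal sketch of the take/rewrite/put-back idiom; it is not part of the patch and the helper name is illustrative only.

use std::mem;

// Move the value out of `slot` (std::mem::take leaves T::default() behind),
// rewrite it by value, then move the result back in -- no clone required.
// If `f` fails, the slot is left holding the default placeholder.
fn rewrite_in_place<T: Default, E>(
    slot: &mut T,
    f: impl FnOnce(T) -> Result<T, E>,
) -> Result<(), E> {
    let old = mem::take(slot);
    *slot = f(old)?;
    Ok(())
}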
diff --git a/datafusion/core/src/execution/context/mod.rs b/datafusion/core/src/execution/context/mod.rs index 116e45c8c130..ac62f1f4df1e 100644 --- a/datafusion/core/src/execution/context/mod.rs +++ b/datafusion/core/src/execution/context/mod.rs @@ -529,8 +529,7 @@ impl SessionContext { column_defaults, } = cmd; - let input = Arc::try_unwrap(input).unwrap_or_else(|e| e.as_ref().clone()); - let input = self.state().optimize(&input)?; + let input = self.state().optimize(input.as_ref())?; let table = self.table(&name).await; match (if_not_exists, or_replace, table) { (true, false, Ok(_)) => self.return_empty_dataframe(), @@ -1877,7 +1876,7 @@ impl SessionState { // optimize the child plan, capturing the output of each optimizer let optimized_plan = self.optimizer.optimize( - &analyzed_plan, + analyzed_plan, self, |optimized_plan, optimizer| { let optimizer_name = optimizer.name().to_string(); @@ -1886,7 +1885,7 @@ impl SessionState { }, ); let (plan, logical_optimization_succeeded) = match optimized_plan { - Ok(plan) => (Arc::new(plan), true), + Ok(plan) => (Box::new(plan.data), true), Err(DataFusionError::Context(optimizer_name, err)) => { let plan_type = PlanType::OptimizedLogicalPlan { optimizer_name }; stringified_plans @@ -1907,7 +1906,9 @@ impl SessionState { let analyzed_plan = self.analyzer .execute_and_check(plan, self.options(), |_, _| {})?; - self.optimizer.optimize(&analyzed_plan, self, |_, _| {}) + self.optimizer + .optimize(analyzed_plan, self, |_, _| {}) + .map(|t| t.data) } } diff --git a/datafusion/core/src/physical_planner.rs b/datafusion/core/src/physical_planner.rs index ca708b05823e..09265cb534bd 100644 --- a/datafusion/core/src/physical_planner.rs +++ b/datafusion/core/src/physical_planner.rs @@ -919,7 +919,7 @@ impl DefaultPhysicalPlanner { Ok(Arc::new(filter.with_default_selectivity(selectivity)?)) } LogicalPlan::Union(Union { inputs, schema: _ }) => { - let physical_plans = self.create_initial_plan_multi(inputs.iter().map(|lp| lp.as_ref()), session_state).await?; + let physical_plans = self.create_initial_plan_multi(inputs.iter(), session_state).await?; Ok(Arc::new(UnionExec::new(physical_plans))) } @@ -1020,8 +1020,8 @@ impl DefaultPhysicalPlanner { let join_plan = LogicalPlan::Join(Join::try_new_with_project_input( logical_plan, - Arc::new(left), - Arc::new(right), + Box::new(left), + Box::new(right), column_on, )?); @@ -1037,7 +1037,7 @@ impl DefaultPhysicalPlanner { let projection = Projection::try_new( final_join_result, - Arc::new(join_plan), + Box::new(join_plan), )?; LogicalPlan::Projection(projection) } else { diff --git a/datafusion/core/tests/dataframe/mod.rs b/datafusion/core/tests/dataframe/mod.rs index 305a7e69fdb2..0bb6f55fc7ab 100644 --- a/datafusion/core/tests/dataframe/mod.rs +++ b/datafusion/core/tests/dataframe/mod.rs @@ -102,7 +102,7 @@ async fn test_count_wildcard_on_where_in() -> Result<()> { .await? .filter(in_subquery( col("a"), - Arc::new( + Box::new( ctx.table("t2") .await? .aggregate(vec![], vec![count(wildcard())])? @@ -139,7 +139,7 @@ async fn test_count_wildcard_on_where_exist() -> Result<()> { let df_results = ctx .table("t1") .await? - .filter(exists(Arc::new( + .filter(exists(Box::new( ctx.table("t2") .await? .aggregate(vec![], vec![count(wildcard())])? @@ -251,7 +251,7 @@ async fn test_count_wildcard_on_where_scalar_subquery() -> Result<()> { .table("t1") .await? .filter( - scalar_subquery(Arc::new( + scalar_subquery(Box::new( ctx.table("t2") .await? .filter(out_ref_col(DataType::UInt32, "t1.a").eq(col("t2.a")))? 
diff --git a/datafusion/core/tests/optimizer_integration.rs b/datafusion/core/tests/optimizer_integration.rs index 60010bdddfb8..20d55def22ca 100644 --- a/datafusion/core/tests/optimizer_integration.rs +++ b/datafusion/core/tests/optimizer_integration.rs @@ -110,7 +110,7 @@ fn test_sql(sql: &str) -> Result { let optimizer = Optimizer::new(); // analyze and optimize the logical plan let plan = analyzer.execute_and_check(&plan, config.options(), |_, _| {})?; - optimizer.optimize(&plan, &config, |_, _| {}) + optimizer.optimize(plan, &config, |_, _| {}).map(|p| p.data) } #[derive(Default)] diff --git a/datafusion/core/tests/user_defined/user_defined_table_functions.rs b/datafusion/core/tests/user_defined/user_defined_table_functions.rs index b5d10b1c5b9b..979e81911bf2 100644 --- a/datafusion/core/tests/user_defined/user_defined_table_functions.rs +++ b/datafusion/core/tests/user_defined/user_defined_table_functions.rs @@ -158,7 +158,7 @@ impl SimpleCsvTable { normalize_col(self.exprs[0].clone(), &plan)?, plan.schema(), )], - Arc::new(plan), + Box::new(plan), ) .map(LogicalPlan::Projection)?; let rbs = collect( diff --git a/datafusion/expr/src/expr.rs b/datafusion/expr/src/expr.rs index 7ede4cd8ffc9..91295fb0f89d 100644 --- a/datafusion/expr/src/expr.rs +++ b/datafusion/expr/src/expr.rs @@ -26,11 +26,11 @@ use std::sync::Arc; use crate::expr_fn::binary_expr; use crate::logical_plan::Subquery; use crate::utils::{expr_to_columns, find_out_reference_exprs}; -use crate::window_frame; use crate::{ aggregate_function, built_in_function, built_in_window_function, udaf, BuiltinScalarFunction, ExprSchemable, Operator, Signature, }; +use crate::{lit, window_frame}; use arrow::datatypes::DataType; use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; @@ -184,6 +184,12 @@ pub enum Expr { Unnest(Unnest), } +impl Default for Expr { + fn default() -> Self { + lit(0) + } +} + #[derive(Clone, PartialEq, Eq, Hash, Debug)] pub struct Unnest { pub exprs: Vec, diff --git a/datafusion/expr/src/expr_fn.rs b/datafusion/expr/src/expr_fn.rs index e1ab11c5b778..b6376708f119 100644 --- a/datafusion/expr/src/expr_fn.rs +++ b/datafusion/expr/src/expr_fn.rs @@ -377,7 +377,7 @@ pub fn approx_percentile_cont_with_weight( } /// Create an EXISTS subquery expression -pub fn exists(subquery: Arc) -> Expr { +pub fn exists(subquery: Box) -> Expr { let outer_ref_columns = subquery.all_out_ref_exprs(); Expr::Exists(Exists { subquery: Subquery { @@ -389,7 +389,7 @@ pub fn exists(subquery: Arc) -> Expr { } /// Create a NOT EXISTS subquery expression -pub fn not_exists(subquery: Arc) -> Expr { +pub fn not_exists(subquery: Box) -> Expr { let outer_ref_columns = subquery.all_out_ref_exprs(); Expr::Exists(Exists { subquery: Subquery { @@ -401,7 +401,7 @@ pub fn not_exists(subquery: Arc) -> Expr { } /// Create an IN subquery expression -pub fn in_subquery(expr: Expr, subquery: Arc) -> Expr { +pub fn in_subquery(expr: Expr, subquery: Box) -> Expr { let outer_ref_columns = subquery.all_out_ref_exprs(); Expr::InSubquery(InSubquery::new( Box::new(expr), @@ -414,7 +414,7 @@ pub fn in_subquery(expr: Expr, subquery: Arc) -> Expr { } /// Create a NOT IN subquery expression -pub fn not_in_subquery(expr: Expr, subquery: Arc) -> Expr { +pub fn not_in_subquery(expr: Expr, subquery: Box) -> Expr { let outer_ref_columns = subquery.all_out_ref_exprs(); Expr::InSubquery(InSubquery::new( Box::new(expr), @@ -427,7 +427,7 @@ pub fn not_in_subquery(expr: Expr, subquery: Arc) -> Expr { } /// Create a scalar subquery expression -pub fn 
scalar_subquery(subquery: Arc) -> Expr { +pub fn scalar_subquery(subquery: Box) -> Expr { let outer_ref_columns = subquery.all_out_ref_exprs(); Expr::ScalarSubquery(Subquery { subquery, diff --git a/datafusion/expr/src/expr_rewriter/mod.rs b/datafusion/expr/src/expr_rewriter/mod.rs index 7a227a91c455..5e1f86079850 100644 --- a/datafusion/expr/src/expr_rewriter/mod.rs +++ b/datafusion/expr/src/expr_rewriter/mod.rs @@ -19,7 +19,6 @@ use std::collections::HashMap; use std::collections::HashSet; -use std::sync::Arc; use crate::expr::{Alias, Unnest}; use crate::logical_plan::Projection; @@ -220,7 +219,7 @@ pub fn coerce_plan_expr_for_schema( let new_exprs = coerce_exprs_for_schema(exprs, plan.schema(), schema)?; let add_project = new_exprs.iter().any(|expr| expr.try_into_col().is_err()); if add_project { - let projection = Projection::try_new(new_exprs, Arc::new(plan.clone()))?; + let projection = Projection::try_new(new_exprs, Box::new(plan.clone()))?; Ok(LogicalPlan::Projection(projection)) } else { Ok(plan.clone()) diff --git a/datafusion/expr/src/expr_schema.rs b/datafusion/expr/src/expr_schema.rs index f1ac22d584ee..ec8d334a31f9 100644 --- a/datafusion/expr/src/expr_schema.rs +++ b/datafusion/expr/src/expr_schema.rs @@ -32,7 +32,6 @@ use datafusion_common::{ ExprSchema, Result, }; use std::collections::HashMap; -use std::sync::Arc; /// trait to allow expr to typable with respect to a schema pub trait ExprSchemable { @@ -544,7 +543,7 @@ pub fn cast_subquery(subquery: Subquery, cast_to_type: &DataType) -> Result, ) -> Result { Ok(Self::from(LogicalPlan::Copy(CopyTo { - input: Arc::new(input), + input: Box::new(input), output_url, format_options, options, @@ -293,7 +293,7 @@ impl LogicalPlanBuilder { table_name: table_name.into(), table_schema, op, - input: Arc::new(input), + input: Box::new(input), }))) } @@ -367,7 +367,7 @@ impl LogicalPlanBuilder { /// Apply a filter pub fn filter(self, expr: impl Into) -> Result { let expr = normalize_col(expr.into(), &self.plan)?; - Filter::try_new(expr, Arc::new(self.plan)) + Filter::try_new(expr, Box::new(self.plan)) .map(LogicalPlan::Filter) .map(Self::from) } @@ -377,7 +377,7 @@ impl LogicalPlanBuilder { Ok(Self::from(LogicalPlan::Prepare(Prepare { name, data_types, - input: Arc::new(self.plan), + input: Box::new(self.plan), }))) } @@ -391,7 +391,7 @@ impl LogicalPlanBuilder { Ok(Self::from(LogicalPlan::Limit(Limit { skip, fetch, - input: Arc::new(self.plan), + input: Box::new(self.plan), }))) } @@ -539,7 +539,7 @@ impl LogicalPlanBuilder { if missing_cols.is_empty() { return Ok(Self::from(LogicalPlan::Sort(Sort { expr: normalize_cols(exprs, &self.plan)?, - input: Arc::new(self.plan), + input: Box::new(self.plan), fetch: None, }))); } @@ -555,11 +555,11 @@ impl LogicalPlanBuilder { let plan = Self::add_missing_columns(self.plan, &missing_cols, is_distinct)?; let sort_plan = LogicalPlan::Sort(Sort { expr: normalize_cols(exprs, &plan)?, - input: Arc::new(plan), + input: Box::new(plan), fetch: None, }); - Projection::try_new(new_expr, Arc::new(sort_plan)) + Projection::try_new(new_expr, Box::new(sort_plan)) .map(LogicalPlan::Projection) .map(Self::from) } @@ -574,14 +574,14 @@ impl LogicalPlanBuilder { let left_plan: LogicalPlan = self.plan; let right_plan: LogicalPlan = plan; - Ok(Self::from(LogicalPlan::Distinct(Distinct::All(Arc::new( + Ok(Self::from(LogicalPlan::Distinct(Distinct::All(Box::new( union(left_plan, right_plan)?, ))))) } /// Apply deduplication: Only distinct (different) values are returned) pub fn distinct(self) -> Result { - 
Ok(Self::from(LogicalPlan::Distinct(Distinct::All(Arc::new( + Ok(Self::from(LogicalPlan::Distinct(Distinct::All(Box::new( self.plan, ))))) } @@ -595,7 +595,7 @@ impl LogicalPlanBuilder { sort_expr: Option>, ) -> Result { Ok(Self::from(LogicalPlan::Distinct(Distinct::On( - DistinctOn::try_new(on_expr, select_expr, sort_expr, Arc::new(self.plan))?, + DistinctOn::try_new(on_expr, select_expr, sort_expr, Box::new(self.plan))?, )))) } @@ -811,8 +811,8 @@ impl LogicalPlanBuilder { build_join_schema(self.plan.schema(), right.schema(), &join_type)?; Ok(Self::from(LogicalPlan::Join(Join { - left: Arc::new(self.plan), - right: Arc::new(right), + left: Box::new(self.plan), + right: Box::new(right), on, filter, join_type, @@ -875,8 +875,8 @@ impl LogicalPlanBuilder { })?) } else { Ok(Self::from(LogicalPlan::Join(Join { - left: Arc::new(self.plan), - right: Arc::new(right), + left: Box::new(self.plan), + right: Box::new(right), on: join_on, filter: filters, join_type, @@ -892,8 +892,8 @@ impl LogicalPlanBuilder { let join_schema = build_join_schema(self.plan.schema(), right.schema(), &JoinType::Inner)?; Ok(Self::from(LogicalPlan::CrossJoin(CrossJoin { - left: Arc::new(self.plan), - right: Arc::new(right), + left: Box::new(self.plan), + right: Box::new(right), schema: DFSchemaRef::new(join_schema), }))) } @@ -901,7 +901,7 @@ impl LogicalPlanBuilder { /// Repartition pub fn repartition(self, partitioning_scheme: Partitioning) -> Result { Ok(Self::from(LogicalPlan::Repartition(Repartition { - input: Arc::new(self.plan), + input: Box::new(self.plan), partitioning_scheme, }))) } @@ -915,7 +915,7 @@ impl LogicalPlanBuilder { validate_unique_names("Windows", &window_expr)?; Ok(Self::from(LogicalPlan::Window(Window::try_new( window_expr, - Arc::new(self.plan), + Box::new(self.plan), )?))) } @@ -932,7 +932,7 @@ impl LogicalPlanBuilder { let group_expr = add_group_by_exprs_from_dependencies(group_expr, self.plan.schema())?; - Aggregate::try_new(Arc::new(self.plan), group_expr, aggr_expr) + Aggregate::try_new(Box::new(self.plan), group_expr, aggr_expr) .map(LogicalPlan::Aggregate) .map(Self::from) } @@ -950,7 +950,7 @@ impl LogicalPlanBuilder { if analyze { Ok(Self::from(LogicalPlan::Analyze(Analyze { verbose, - input: Arc::new(self.plan), + input: Box::new(self.plan), schema, }))) } else { @@ -959,7 +959,7 @@ impl LogicalPlanBuilder { Ok(Self::from(LogicalPlan::Explain(Explain { verbose, - plan: Arc::new(self.plan), + plan: Box::new(self.plan), stringified_plans, schema, logical_optimization_succeeded: false, @@ -1096,8 +1096,8 @@ impl LogicalPlanBuilder { build_join_schema(self.plan.schema(), right.schema(), &join_type)?; Ok(Self::from(LogicalPlan::Join(Join { - left: Arc::new(self.plan), - right: Arc::new(right), + left: Box::new(self.plan), + right: Box::new(right), on: join_key_pairs, filter, join_type, @@ -1278,7 +1278,7 @@ pub(crate) fn validate_unique_names<'a>( pub fn project_with_column_index( expr: Vec, - input: Arc, + input: Box, schema: DFSchemaRef, ) -> Result { let alias_expr = expr @@ -1347,13 +1347,13 @@ pub fn union(left_plan: LogicalPlan, right_plan: LogicalPlan) -> Result { - Ok(Arc::new(project_with_column_index( + Ok(project_with_column_index( expr, input, Arc::new(union_schema.clone()), - )?)) + )?) 
} - other_plan => Ok(Arc::new(other_plan)), + other_plan => Ok(other_plan), } }) .collect::>>()?; @@ -1400,7 +1400,7 @@ pub fn project( validate_unique_names("Projections", projected_expr.iter())?; - Projection::try_new(projected_expr, Arc::new(plan)).map(LogicalPlan::Projection) + Projection::try_new(projected_expr, Box::new(plan)).map(LogicalPlan::Projection) } /// Create a SubqueryAlias to wrap a LogicalPlan. @@ -1408,7 +1408,7 @@ pub fn subquery_alias( plan: LogicalPlan, alias: impl Into, ) -> Result { - SubqueryAlias::try_new(Arc::new(plan), alias).map(LogicalPlan::SubqueryAlias) + SubqueryAlias::try_new(Box::new(plan), alias).map(LogicalPlan::SubqueryAlias) } /// Create a LogicalPlanBuilder representing a scan of a table with the provided name and schema. @@ -1578,7 +1578,7 @@ pub fn unnest_with_options( let schema = Arc::new(df_schema.with_functional_dependencies(deps)?); Ok(LogicalPlan::Unnest(Unnest { - input: Arc::new(input), + input: Box::new(input), column: unnested_field.qualified_column(), schema, options, @@ -1809,7 +1809,7 @@ mod tests { let outer_query = LogicalPlanBuilder::from(bar) .project(vec![col("a")])? - .filter(exists(Arc::new(subquery)))? + .filter(exists(Box::new(subquery)))? .build()?; let expected = "Filter: EXISTS ()\ @@ -1837,7 +1837,7 @@ mod tests { // SELECT a FROM bar WHERE a IN (SELECT a FROM foo WHERE a = bar.a) let outer_query = LogicalPlanBuilder::from(bar) .project(vec![col("a")])? - .filter(in_subquery(col("a"), Arc::new(subquery)))? + .filter(in_subquery(col("a"), Box::new(subquery)))? .build()?; let expected = "Filter: bar.a IN ()\ @@ -1864,7 +1864,7 @@ mod tests { // SELECT (SELECT a FROM foo WHERE a = bar.a) FROM bar let outer_query = LogicalPlanBuilder::from(bar) - .project(vec![scalar_subquery(Arc::new(subquery))])? + .project(vec![scalar_subquery(Box::new(subquery))])? .build()?; let expected = "Projection: ()\ diff --git a/datafusion/expr/src/logical_plan/ddl.rs b/datafusion/expr/src/logical_plan/ddl.rs index 968c40c8bf62..2b7a192200e0 100644 --- a/datafusion/expr/src/logical_plan/ddl.rs +++ b/datafusion/expr/src/logical_plan/ddl.rs @@ -16,7 +16,6 @@ // under the License. use std::collections::HashMap; -use std::sync::Arc; use std::{ fmt::{self, Display}, hash::{Hash, Hasher}, @@ -243,7 +242,7 @@ pub struct CreateMemoryTable { /// The list of constraints in the schema, such as primary key, unique, etc. 
pub constraints: Constraints, /// The logical plan - pub input: Arc, + pub input: Box, /// Option to not error if table already exists pub if_not_exists: bool, /// Option to replace table content if table already exists @@ -258,7 +257,7 @@ pub struct CreateView { /// The table name pub name: OwnedTableReference, /// The logical plan - pub input: Arc, + pub input: Box, /// Option to not error if table already exists pub or_replace: bool, /// SQL used to create the view, if available diff --git a/datafusion/expr/src/logical_plan/dml.rs b/datafusion/expr/src/logical_plan/dml.rs index 6ab06a57c1c2..b3dd1256916d 100644 --- a/datafusion/expr/src/logical_plan/dml.rs +++ b/datafusion/expr/src/logical_plan/dml.rs @@ -18,7 +18,6 @@ use std::collections::HashMap; use std::fmt::{self, Display}; use std::hash::{Hash, Hasher}; -use std::sync::Arc; use datafusion_common::config::FormatOptions; use datafusion_common::{DFSchemaRef, OwnedTableReference}; @@ -29,7 +28,7 @@ use crate::LogicalPlan; #[derive(Clone)] pub struct CopyTo { /// The relation that determines the tuples to write to the output file(s) - pub input: Arc, + pub input: Box, /// The location to write the file(s) pub output_url: String, /// Determines which, if any, columns should be used for hive-style partitioned writes @@ -69,7 +68,7 @@ pub struct DmlStatement { /// The type of operation to perform pub op: WriteOp, /// The relation that determines the tuples to add/remove/modify the schema must match with table_schema - pub input: Arc, + pub input: Box, } impl DmlStatement { diff --git a/datafusion/expr/src/logical_plan/plan.rs b/datafusion/expr/src/logical_plan/plan.rs index 05d7ac539458..23eb42679b97 100644 --- a/datafusion/expr/src/logical_plan/plan.rs +++ b/datafusion/expr/src/logical_plan/plan.rs @@ -38,7 +38,7 @@ use crate::utils::{ split_conjunction, }; use crate::{ - build_join_schema, expr_vec_fmt, BinaryExpr, BuiltInWindowFunction, + build_join_schema, col, expr_vec_fmt, BinaryExpr, BuiltInWindowFunction, CreateMemoryTable, CreateView, Expr, ExprSchemable, LogicalPlanBuilder, Operator, TableProviderFilterPushDown, TableSource, WindowFunctionDefinition, }; @@ -157,6 +157,15 @@ pub enum LogicalPlan { RecursiveQuery(RecursiveQuery), } +impl Default for LogicalPlan { + fn default() -> Self { + LogicalPlan::EmptyRelation(EmptyRelation { + schema: DFSchemaRef::new(DFSchema::empty()), + produce_one_row: false, + }) + } +} + impl LogicalPlan { /// Get a reference to the logical plan's schema pub fn schema(&self) -> &DFSchemaRef { @@ -281,6 +290,141 @@ impl LogicalPlan { exprs } + pub fn rewrite_exprs(mut self, mut f: F) -> Result + where + F: FnMut(Expr) -> Result, + { + match &mut self { + LogicalPlan::Projection(Projection { expr, .. }) + | LogicalPlan::Window(Window { + window_expr: expr, .. + }) + | LogicalPlan::Sort(Sort { expr, .. }) + | LogicalPlan::TableScan(TableScan { filters: expr, .. }) => { + expr.iter_mut().try_for_each(|e| { + let old_expr = std::mem::take(e); + *e = f(old_expr)?; + Ok::<_, DataFusionError>(()) + })?; + } + LogicalPlan::Values(Values { values, .. }) => { + values.iter_mut().flatten().try_for_each(|e| { + let old_expr = std::mem::take(e); + *e = f(old_expr)?; + Ok::<_, DataFusionError>(()) + })?; + } + LogicalPlan::Filter(Filter { predicate, .. }) => { + let old_predicate = std::mem::take(predicate); + *predicate = f(old_predicate)?; + } + LogicalPlan::Repartition(Repartition { + partitioning_scheme, + .. 
+            }) => match partitioning_scheme {
+                Partitioning::Hash(expr, _) | Partitioning::DistributeBy(expr) => {
+                    expr.iter_mut().try_for_each(|e| {
+                        let old_expr = std::mem::take(e);
+                        *e = f(old_expr)?;
+                        Ok::<_, DataFusionError>(())
+                    })?;
+                }
+                Partitioning::RoundRobinBatch(_) => {}
+            },
+            LogicalPlan::Aggregate(Aggregate {
+                group_expr,
+                aggr_expr,
+                ..
+            }) => {
+                group_expr.iter_mut().try_for_each(|e| {
+                    let old_expr = std::mem::take(e);
+                    *e = f(old_expr)?;
+                    Ok::<_, DataFusionError>(())
+                })?;
+                aggr_expr.iter_mut().try_for_each(|e| {
+                    let old_expr = std::mem::take(e);
+                    *e = f(old_expr)?;
+                    Ok::<_, DataFusionError>(())
+                })?;
+            }
+            // A join carries two kinds of expressions: the equijoin (`on`) and the non-equijoin (`filter`).
+            // 1. The first part is the `on.len()` equijoin expressions, where each expr has the form `left-on = right-on`.
+            // 2. The second part is the non-equijoin `filter` expression.
+            LogicalPlan::Join(Join { on, filter, .. }) => {
+                on.iter_mut().try_for_each(|(l, r)| {
+                    let old_l = std::mem::take(l);
+                    let old_r = std::mem::take(r);
+                    *l = f(old_l)?;
+                    *r = f(old_r)?;
+                    Ok::<_, DataFusionError>(())
+                })?;
+
+                if let Some(filter) = filter.as_mut() {
+                    let old_filter = std::mem::take(filter);
+                    *filter = f(old_filter)?;
+                }
+            }
+            LogicalPlan::Extension(_extension) => {
+                todo!("implement rewrite_exprs for extension node")
+            }
+            LogicalPlan::Unnest(Unnest { column, .. }) => {
+                let old_column = std::mem::take(column);
+                let new_col_expr = f(col(old_column))?;
+                match new_col_expr {
+                    Expr::Column(col) => *column = col,
+                    _ => {
+                        return internal_err!(
+                            "Simplified Unnest's column should be a column"
+                        )
+                    }
+                }
+            }
+            LogicalPlan::Distinct(Distinct::On(DistinctOn {
+                on_expr,
+                select_expr,
+                sort_expr,
+                ..
+            })) => {
+                on_expr.iter_mut().try_for_each(|e| {
+                    let old_expr = std::mem::take(e);
+                    *e = f(old_expr)?;
+                    Ok::<_, DataFusionError>(())
+                })?;
+                select_expr.iter_mut().try_for_each(|e| {
+                    let old_expr = std::mem::take(e);
+                    *e = f(old_expr)?;
+                    Ok::<_, DataFusionError>(())
+                })?;
+                if let Some(sort_expr) = sort_expr.as_mut() {
+                    sort_expr.iter_mut().try_for_each(|e| {
+                        let old_expr = std::mem::take(e);
+                        *e = f(old_expr)?;
+                        Ok::<_, DataFusionError>(())
+                    })?;
+                }
+            }
+            // plans without expressions
+            LogicalPlan::EmptyRelation(_)
+            | LogicalPlan::RecursiveQuery(_)
+            | LogicalPlan::Subquery(_)
+            | LogicalPlan::SubqueryAlias(_)
+            | LogicalPlan::Limit(_)
+            | LogicalPlan::Statement(_)
+            | LogicalPlan::CrossJoin(_)
+            | LogicalPlan::Analyze(_)
+            | LogicalPlan::Explain(_)
+            | LogicalPlan::Union(_)
+            | LogicalPlan::Distinct(Distinct::All(_))
+            | LogicalPlan::Dml(_)
+            | LogicalPlan::Ddl(_)
+            | LogicalPlan::Copy(_)
+            | LogicalPlan::DescribeTable(_)
+            | LogicalPlan::Prepare(_) => {}
+        }
+
+        Ok(self)
+    }
+
     /// Calls `f` on all expressions (non-recursively) in the current
     /// logical plan node. This does not include expressions in any
     /// children.
@@ -385,9 +529,7 @@ impl LogicalPlan {
             LogicalPlan::Subquery(Subquery { subquery, .. }) => vec![subquery],
             LogicalPlan::SubqueryAlias(SubqueryAlias { input, .. }) => vec![input],
             LogicalPlan::Extension(extension) => extension.node.inputs(),
-            LogicalPlan::Union(Union { inputs, .. }) => {
-                inputs.iter().map(|arc| arc.as_ref()).collect()
-            }
+            LogicalPlan::Union(Union { inputs, .. }) => inputs.iter().collect(),
             LogicalPlan::Distinct(
                 Distinct::All(input) | Distinct::On(DistinctOn { input, .. }),
             ) => vec![input],
@@ -412,6 +554,71 @@ impl LogicalPlan {
         }
     }
 
+    pub fn rewrite_inputs<F>(mut self, mut f: F) -> Result<Self>
+    where
+        F: FnMut(LogicalPlan) -> Result<LogicalPlan>,
+    {
+        match &mut self {
+            LogicalPlan::Projection(Projection { input, .. })
+            | LogicalPlan::Filter(Filter { input, .. })
+            | LogicalPlan::Repartition(Repartition { input, .. })
+            | LogicalPlan::Window(Window { input, .. })
+            | LogicalPlan::Aggregate(Aggregate { input, .. })
+            | LogicalPlan::Sort(Sort { input, .. })
+            | LogicalPlan::Limit(Limit { input, .. })
+            | LogicalPlan::Subquery(Subquery {
+                subquery: input, ..
+            })
+            | LogicalPlan::SubqueryAlias(SubqueryAlias { input, .. })
+            | LogicalPlan::Prepare(Prepare { input, .. })
+            | LogicalPlan::Unnest(Unnest { input, .. })
+            | LogicalPlan::Distinct(
+                Distinct::All(input) | Distinct::On(DistinctOn { input, .. }),
+            )
+            | LogicalPlan::Explain(Explain { plan: input, .. })
+            | LogicalPlan::Analyze(Analyze { input, .. })
+            | LogicalPlan::Dml(DmlStatement { input, .. })
+            | LogicalPlan::Copy(CopyTo { input, .. }) => {
+                let old_input = std::mem::take(input);
+                let plan = f(*old_input)?;
+                *input = Box::new(plan);
+            }
+            LogicalPlan::Join(Join { left, right, .. })
+            | LogicalPlan::CrossJoin(CrossJoin { left, right, .. })
+            | LogicalPlan::RecursiveQuery(RecursiveQuery {
+                static_term: left,
+                recursive_term: right,
+                ..
+            }) => {
+                let old_left = std::mem::take(left);
+                let old_right = std::mem::take(right);
+                let new_left = f(*old_left)?;
+                let new_right = f(*old_right)?;
+                *left = Box::new(new_left);
+                *right = Box::new(new_right);
+            }
+            LogicalPlan::Extension(_extension) => {
+                todo!("implement rewrite_inputs for extension node")
+            }
+            LogicalPlan::Union(Union { inputs, .. }) => {
+                inputs.iter_mut().try_for_each(|input| {
+                    let old_input = std::mem::take(input);
+                    *input = f(old_input)?;
+                    Ok::<_, DataFusionError>(())
+                })?;
+            }
+            LogicalPlan::Ddl(_ddl) => todo!("rewrite_inputs for DDL"),
+            // plans without inputs
+            LogicalPlan::TableScan { .. }
+            | LogicalPlan::Statement { .. }
+            | LogicalPlan::EmptyRelation { .. }
+            | LogicalPlan::Values { .. }
+            | LogicalPlan::DescribeTable(_) => {}
+        }
+
+        Ok(self)
+    }
+
     /// returns all `Using` join columns in a logical plan
     pub fn using_columns(&self) -> Result<Vec<HashSet<Column>>, DataFusionError> {
         let mut using_columns: Vec<HashSet<Column>> = vec![];
@@ -553,7 +760,7 @@ impl LogicalPlan {
                 // Since expr may be different than the previous expr, schema of the projection
                 // may change. We need to use try_new method instead of try_new_with_schema method.
                 LogicalPlan::Projection(Projection { ..
}) => { - Projection::try_new(expr, Arc::new(inputs.swap_remove(0))) + Projection::try_new(expr, Box::new(inputs.swap_remove(0))) .map(LogicalPlan::Projection) } LogicalPlan::Dml(DmlStatement { @@ -565,7 +772,7 @@ impl LogicalPlan { table_name: table_name.clone(), table_schema: table_schema.clone(), op: op.clone(), - input: Arc::new(inputs.swap_remove(0)), + input: Box::new(inputs.swap_remove(0)), })), LogicalPlan::Copy(CopyTo { input: _, @@ -574,7 +781,7 @@ impl LogicalPlan { options, partition_by, }) => Ok(LogicalPlan::Copy(CopyTo { - input: Arc::new(inputs.swap_remove(0)), + input: Box::new(inputs.swap_remove(0)), output_url: output_url.clone(), format_options: format_options.clone(), options: options.clone(), @@ -624,7 +831,7 @@ impl LogicalPlan { }) .data()?; - Filter::try_new(predicate, Arc::new(inputs.swap_remove(0))) + Filter::try_new(predicate, Box::new(inputs.swap_remove(0))) .map(LogicalPlan::Filter) } LogicalPlan::Repartition(Repartition { @@ -634,35 +841,35 @@ impl LogicalPlan { Partitioning::RoundRobinBatch(n) => { Ok(LogicalPlan::Repartition(Repartition { partitioning_scheme: Partitioning::RoundRobinBatch(*n), - input: Arc::new(inputs.swap_remove(0)), + input: Box::new(inputs.swap_remove(0)), })) } Partitioning::Hash(_, n) => Ok(LogicalPlan::Repartition(Repartition { partitioning_scheme: Partitioning::Hash(expr, *n), - input: Arc::new(inputs.swap_remove(0)), + input: Box::new(inputs.swap_remove(0)), })), Partitioning::DistributeBy(_) => { Ok(LogicalPlan::Repartition(Repartition { partitioning_scheme: Partitioning::DistributeBy(expr), - input: Arc::new(inputs.swap_remove(0)), + input: Box::new(inputs.swap_remove(0)), })) } }, LogicalPlan::Window(Window { window_expr, .. }) => { assert_eq!(window_expr.len(), expr.len()); - Window::try_new(expr, Arc::new(inputs.swap_remove(0))) + Window::try_new(expr, Box::new(inputs.swap_remove(0))) .map(LogicalPlan::Window) } LogicalPlan::Aggregate(Aggregate { group_expr, .. }) => { // group exprs are the first expressions let agg_expr = expr.split_off(group_expr.len()); - Aggregate::try_new(Arc::new(inputs.swap_remove(0)), expr, agg_expr) + Aggregate::try_new(Box::new(inputs.swap_remove(0)), expr, agg_expr) .map(LogicalPlan::Aggregate) } LogicalPlan::Sort(Sort { fetch, .. }) => Ok(LogicalPlan::Sort(Sort { expr, - input: Arc::new(inputs.swap_remove(0)), + input: Box::new(inputs.swap_remove(0)), fetch: *fetch, })), LogicalPlan::Join(Join { @@ -702,8 +909,8 @@ impl LogicalPlan { }).collect::>>()?; Ok(LogicalPlan::Join(Join { - left: Arc::new(inputs.swap_remove(0)), - right: Arc::new(inputs.swap_remove(0)), + left: Box::new(inputs.swap_remove(0)), + right: Box::new(inputs.swap_remove(0)), join_type: *join_type, join_constraint: *join_constraint, on: new_on, @@ -722,19 +929,19 @@ impl LogicalPlan { }) => { let subquery = LogicalPlanBuilder::from(inputs.swap_remove(0)).build()?; Ok(LogicalPlan::Subquery(Subquery { - subquery: Arc::new(subquery), + subquery: Box::new(subquery), outer_ref_columns: outer_ref_columns.clone(), })) } LogicalPlan::SubqueryAlias(SubqueryAlias { alias, .. }) => { - SubqueryAlias::try_new(Arc::new(inputs.swap_remove(0)), alias.clone()) + SubqueryAlias::try_new(Box::new(inputs.swap_remove(0)), alias.clone()) .map(LogicalPlan::SubqueryAlias) } LogicalPlan::Limit(Limit { skip, fetch, .. 
}) => { Ok(LogicalPlan::Limit(Limit { skip: *skip, fetch: *fetch, - input: Arc::new(inputs.swap_remove(0)), + input: Box::new(inputs.swap_remove(0)), })) } LogicalPlan::Ddl(DdlStatement::CreateMemoryTable(CreateMemoryTable { @@ -745,7 +952,7 @@ impl LogicalPlan { .. })) => Ok(LogicalPlan::Ddl(DdlStatement::CreateMemoryTable( CreateMemoryTable { - input: Arc::new(inputs.swap_remove(0)), + input: Box::new(inputs.swap_remove(0)), constraints: Constraints::empty(), name: name.clone(), if_not_exists: *if_not_exists, @@ -759,7 +966,7 @@ impl LogicalPlan { definition, .. })) => Ok(LogicalPlan::Ddl(DdlStatement::CreateView(CreateView { - input: Arc::new(inputs.swap_remove(0)), + input: Box::new(inputs.swap_remove(0)), name: name.clone(), or_replace: *or_replace, definition: definition.clone(), @@ -776,13 +983,13 @@ impl LogicalPlan { input_schema.clone() }; Ok(LogicalPlan::Union(Union { - inputs: inputs.into_iter().map(Arc::new).collect(), + inputs: inputs.into_iter().collect(), schema, })) } LogicalPlan::Distinct(distinct) => { let distinct = match distinct { - Distinct::All(_) => Distinct::All(Arc::new(inputs.swap_remove(0))), + Distinct::All(_) => Distinct::All(Box::new(inputs.swap_remove(0))), Distinct::On(DistinctOn { on_expr, select_expr, @@ -798,7 +1005,7 @@ impl LogicalPlan { } else { None }, - Arc::new(inputs.swap_remove(0)), + Box::new(inputs.swap_remove(0)), )?) } }; @@ -808,8 +1015,8 @@ impl LogicalPlan { name, is_distinct, .. }) => Ok(LogicalPlan::RecursiveQuery(RecursiveQuery { name: name.clone(), - static_term: Arc::new(inputs.swap_remove(0)), - recursive_term: Arc::new(inputs.swap_remove(0)), + static_term: Box::new(inputs.swap_remove(0)), + recursive_term: Box::new(inputs.swap_remove(0)), is_distinct: *is_distinct, })), LogicalPlan::Analyze(a) => { @@ -818,7 +1025,7 @@ impl LogicalPlan { Ok(LogicalPlan::Analyze(Analyze { verbose: a.verbose, schema: a.schema.clone(), - input: Arc::new(inputs.swap_remove(0)), + input: Box::new(inputs.swap_remove(0)), })) } LogicalPlan::Explain(e) => { @@ -829,7 +1036,7 @@ impl LogicalPlan { assert_eq!(inputs.len(), 1, "Invalid EXPLAIN command. Inputs are empty"); Ok(LogicalPlan::Explain(Explain { verbose: e.verbose, - plan: Arc::new(inputs.swap_remove(0)), + plan: Box::new(inputs.swap_remove(0)), stringified_plans: e.stringified_plans.clone(), schema: e.schema.clone(), logical_optimization_succeeded: e.logical_optimization_succeeded, @@ -840,7 +1047,7 @@ impl LogicalPlan { }) => Ok(LogicalPlan::Prepare(Prepare { name: name.clone(), data_types: data_types.clone(), - input: Arc::new(inputs.swap_remove(0)), + input: Box::new(inputs.swap_remove(0)), })), LogicalPlan::TableScan(ts) => { assert!(inputs.is_empty(), "{self:?} should have no inputs"); @@ -865,7 +1072,7 @@ impl LogicalPlan { .. }) => { // Update schema with unnested column type. 
- let input = Arc::new(inputs.swap_remove(0)); + let input = Box::new(inputs.swap_remove(0)); let nested_field = input.schema().field_from_column(column)?; let unnested_field = schema.field_from_column(column)?; let fields = input @@ -1200,7 +1407,7 @@ impl LogicalPlan { } Expr::ScalarSubquery(qry) => { let subquery = - Arc::new(qry.subquery.replace_params_with_values(param_values)?); + Box::new(qry.subquery.replace_params_with_values(param_values)?); Ok(Transformed::yes(Expr::ScalarSubquery(Subquery { subquery, outer_ref_columns: qry.outer_ref_columns.clone(), @@ -1723,10 +1930,10 @@ pub struct RecursiveQuery { /// Name of the query pub name: String, /// The static term (initial contents of the working table) - pub static_term: Arc, + pub static_term: Box, /// The recursive term (evaluated on the contents of the working table until /// it returns an empty set) - pub recursive_term: Arc, + pub recursive_term: Box, /// Should the output of the recursive term be deduplicated (`UNION`) or /// not (`UNION ALL`). pub is_distinct: bool, @@ -1752,22 +1959,22 @@ pub struct Projection { /// The list of expressions pub expr: Vec, /// The incoming logical plan - pub input: Arc, + pub input: Box, /// The schema description of the output pub schema: DFSchemaRef, } impl Projection { /// Create a new Projection - pub fn try_new(expr: Vec, input: Arc) -> Result { - let projection_schema = projection_schema(&input, &expr)?; + pub fn try_new(expr: Vec, input: Box) -> Result { + let projection_schema = projection_schema(input.as_ref(), &expr)?; Self::try_new_with_schema(expr, input, projection_schema) } /// Create a new Projection using the specified output schema pub fn try_new_with_schema( expr: Vec, - input: Arc, + input: Box, schema: DFSchemaRef, ) -> Result { if expr.len() != schema.fields().len() { @@ -1781,7 +1988,7 @@ impl Projection { } /// Create a new Projection using the specified output schema - pub fn new_from_schema(input: Arc, schema: DFSchemaRef) -> Self { + pub fn new_from_schema(input: Box, schema: DFSchemaRef) -> Self { let expr: Vec = schema .fields() .iter() @@ -1826,7 +2033,7 @@ pub fn projection_schema(input: &LogicalPlan, exprs: &[Expr]) -> Result, + pub input: Box, /// The alias for the input relation pub alias: OwnedTableReference, /// The schema with qualified field names @@ -1835,7 +2042,7 @@ pub struct SubqueryAlias { impl SubqueryAlias { pub fn try_new( - plan: Arc, + plan: Box, alias: impl Into, ) -> Result { let alias = alias.into(); @@ -1874,12 +2081,12 @@ pub struct Filter { /// The predicate expression, which must have Boolean type. pub predicate: Expr, /// The incoming logical plan - pub input: Arc, + pub input: Box, } impl Filter { /// Create a new filter operator. - pub fn try_new(predicate: Expr, input: Arc) -> Result { + pub fn try_new(predicate: Expr, input: Box) -> Result { // Filter predicates must return a boolean value so we try and validate that here. // Note that it is not always possible to resolve the predicate expression during plan // construction (such as with correlated subqueries) so we make a best effort here and @@ -1976,7 +2183,7 @@ impl Filter { #[derive(Clone, PartialEq, Eq, Hash)] pub struct Window { /// The incoming logical plan - pub input: Arc, + pub input: Box, /// The window function expression pub window_expr: Vec, /// The schema description of the window output @@ -1985,7 +2192,7 @@ pub struct Window { impl Window { /// Create a new window operator. 
- pub fn try_new(window_expr: Vec, input: Arc) -> Result { + pub fn try_new(window_expr: Vec, input: Box) -> Result { let fields = input.schema().fields(); let input_len = fields.len(); let mut window_fields = fields.clone(); @@ -2148,9 +2355,9 @@ impl TableScan { #[derive(Clone, PartialEq, Eq, Hash)] pub struct CrossJoin { /// Left input - pub left: Arc, + pub left: Box, /// Right input - pub right: Arc, + pub right: Box, /// The output schema, containing fields from the left and right inputs pub schema: DFSchemaRef, } @@ -2159,7 +2366,7 @@ pub struct CrossJoin { #[derive(Clone, PartialEq, Eq, Hash)] pub struct Repartition { /// The incoming logical plan - pub input: Arc, + pub input: Box, /// The partitioning scheme pub partitioning_scheme: Partitioning, } @@ -2168,7 +2375,7 @@ pub struct Repartition { #[derive(Clone, PartialEq, Eq, Hash)] pub struct Union { /// Inputs to merge - pub inputs: Vec>, + pub inputs: Vec, /// Union schema. Should be the same for all inputs. pub schema: DFSchemaRef, } @@ -2182,7 +2389,7 @@ pub struct Prepare { /// Data types of the parameters ([`Expr::Placeholder`]) pub data_types: Vec, /// The logical plan of the statements - pub input: Arc, + pub input: Box, } /// Describe the schema of table @@ -2222,7 +2429,7 @@ pub struct Explain { /// Should extra (detailed, intermediate plans) be included? pub verbose: bool, /// The logical plan that is being EXPLAIN'd - pub plan: Arc, + pub plan: Box, /// Represent the various stages plans have gone through pub stringified_plans: Vec, /// The output schema of the explain (2 columns of text) @@ -2238,7 +2445,7 @@ pub struct Analyze { /// Should extra detail be included? pub verbose: bool, /// The logical plan that is being EXPLAIN ANALYZE'd - pub input: Arc, + pub input: Box, /// The output schema of the explain (2 columns of text) pub schema: DFSchemaRef, } @@ -2272,14 +2479,14 @@ pub struct Limit { /// None means fetching all rows pub fetch: Option, /// The logical plan - pub input: Arc, + pub input: Box, } /// Removes duplicate rows from the input #[derive(Clone, PartialEq, Eq, Hash)] pub enum Distinct { /// Plain `DISTINCT` referencing all selection expressions - All(Arc), + All(Box), /// The `Postgres` addition, allowing separate control over DISTINCT'd and selected columns On(DistinctOn), } @@ -2296,7 +2503,7 @@ pub struct DistinctOn { /// additional info pertaining to the sorting procedure (i.e. ASC/DESC, and NULLS FIRST/LAST). pub sort_expr: Option>, /// The logical plan that is being DISTINCT'd - pub input: Arc, + pub input: Box, /// The schema description of the DISTINCT ON output pub schema: DFSchemaRef, } @@ -2307,7 +2514,7 @@ impl DistinctOn { on_expr: Vec, select_expr: Vec, sort_expr: Option>, - input: Arc, + input: Box, ) -> Result { if on_expr.is_empty() { return plan_err!("No `ON` expressions provided"); @@ -2373,7 +2580,7 @@ impl DistinctOn { #[non_exhaustive] pub struct Aggregate { /// The incoming logical plan - pub input: Arc, + pub input: Box, /// Grouping expressions pub group_expr: Vec, /// Aggregate expressions @@ -2385,7 +2592,7 @@ pub struct Aggregate { impl Aggregate { /// Create a new aggregate operator. pub fn try_new( - input: Arc, + input: Box, group_expr: Vec, aggr_expr: Vec, ) -> Result { @@ -2419,7 +2626,7 @@ impl Aggregate { /// This method should only be called when you are absolutely sure that the schema being /// provided is correct for the aggregate. If in doubt, call [try_new](Self::try_new) instead. 
pub fn try_new_with_schema( - input: Arc, + input: Box, group_expr: Vec, aggr_expr: Vec, schema: DFSchemaRef, @@ -2531,7 +2738,7 @@ pub struct Sort { /// The sort expressions pub expr: Vec, /// The incoming logical plan - pub input: Arc, + pub input: Box, /// Optional fetch limit pub fetch: Option, } @@ -2540,9 +2747,9 @@ pub struct Sort { #[derive(Clone, PartialEq, Eq, Hash)] pub struct Join { /// Left input - pub left: Arc, + pub left: Box, /// Right input - pub right: Arc, + pub right: Box, /// Equijoin clause expressed as pairs of (left, right) join expressions pub on: Vec<(Expr, Expr)>, /// Filters applied during join (non-equi conditions) @@ -2561,8 +2768,8 @@ impl Join { /// Create Join with input which wrapped with projection, this method is used to help create physical join. pub fn try_new_with_project_input( original: &LogicalPlan, - left: Arc, - right: Arc, + left: Box, + right: Box, column_on: (Vec, Vec), ) -> Result { let original_join = match original { @@ -2596,7 +2803,7 @@ impl Join { #[derive(Clone, PartialEq, Eq, Hash)] pub struct Subquery { /// The subquery - pub subquery: Arc, + pub subquery: Box, /// The outer references used in the subquery pub outer_ref_columns: Vec, } @@ -2610,7 +2817,7 @@ impl Subquery { } } - pub fn with_plan(&self, plan: Arc) -> Subquery { + pub fn with_plan(&self, plan: Box) -> Subquery { Subquery { subquery: plan, outer_ref_columns: self.outer_ref_columns.clone(), @@ -2641,7 +2848,7 @@ pub enum Partitioning { #[derive(Debug, Clone, PartialEq, Eq, Hash)] pub struct Unnest { /// The incoming logical plan - pub input: Arc, + pub input: Box, /// The column to unnest pub column: Column, /// The output schema, containing the unnested field column. @@ -2681,7 +2888,7 @@ mod tests { .build()?; table_scan(Some("employee_csv"), &employee_schema(), Some(vec![0, 3]))? - .filter(in_subquery(col("state"), Arc::new(plan1)))? + .filter(in_subquery(col("state"), Box::new(plan1)))? .project(vec![col("id")])? .build() } @@ -2718,7 +2925,7 @@ mod tests { fn test_display_subquery_alias() -> Result<()> { let plan1 = table_scan(Some("employee_csv"), &employee_schema(), Some(vec![3]))? .build()?; - let plan1 = Arc::new(plan1); + let plan1 = Box::new(plan1); let plan = table_scan(Some("employee_csv"), &employee_schema(), Some(vec![0, 3]))? 
@@ -3005,7 +3212,7 @@ digraph { let empty_schema = Arc::new(DFSchema::new_with_metadata(vec![], HashMap::new())?); let p = Projection::try_new_with_schema( vec![col("a")], - Arc::new(LogicalPlan::EmptyRelation(EmptyRelation { + Box::new(LogicalPlan::EmptyRelation(EmptyRelation { produce_one_row: false, schema: empty_schema.clone(), })), @@ -3168,7 +3375,7 @@ digraph { ) .unwrap(), ); - let scan = Arc::new(LogicalPlan::TableScan(TableScan { + let scan = Box::new(LogicalPlan::TableScan(TableScan { table_name: TableReference::bare("tab"), source: source.clone(), projection: None, @@ -3198,7 +3405,7 @@ digraph { ) .unwrap(), ); - let scan = Arc::new(LogicalPlan::TableScan(TableScan { + let scan = Box::new(LogicalPlan::TableScan(TableScan { table_name: TableReference::bare("tab"), source, projection: None, @@ -3240,7 +3447,7 @@ digraph { LogicalPlan::TableScan(table) => { let filter = Filter::try_new( external_filter.clone(), - Arc::new(LogicalPlan::TableScan(table)), + Box::new(LogicalPlan::TableScan(table)), ) .unwrap(); Ok(Transformed::yes(LogicalPlan::Filter(filter))) diff --git a/datafusion/expr/src/utils.rs b/datafusion/expr/src/utils.rs index c7907d0db16a..1cf85b2e2e6c 100644 --- a/datafusion/expr/src/utils.rs +++ b/datafusion/expr/src/utils.rs @@ -19,7 +19,6 @@ use std::cmp::Ordering; use std::collections::HashSet; -use std::sync::Arc; use crate::expr::{Alias, Sort, WindowFunction}; use crate::expr_rewriter::strip_outer_reference; @@ -1165,7 +1164,7 @@ pub fn add_filter(plan: LogicalPlan, predicates: &[&Expr]) -> Result::new(), vec![count(wildcard())])? .project(vec![count(wildcard())])? @@ -292,7 +292,7 @@ mod tests { let table_scan_t2 = test_table_scan_with_name("t2")?; let plan = LogicalPlanBuilder::from(table_scan_t1) - .filter(exists(Arc::new( + .filter(exists(Box::new( LogicalPlanBuilder::from(table_scan_t2) .aggregate(Vec::::new(), vec![count(wildcard())])? .project(vec![count(wildcard())])? @@ -316,7 +316,7 @@ mod tests { let plan = LogicalPlanBuilder::from(table_scan_t1) .filter( - scalar_subquery(Arc::new( + scalar_subquery(Box::new( LogicalPlanBuilder::from(table_scan_t2) .filter(out_ref_col(DataType::UInt32, "t1.a").eq(col("t2.a")))? .aggregate( diff --git a/datafusion/optimizer/src/analyzer/inline_table_scan.rs b/datafusion/optimizer/src/analyzer/inline_table_scan.rs index b21ec851dfcd..68547537c0e4 100644 --- a/datafusion/optimizer/src/analyzer/inline_table_scan.rs +++ b/datafusion/optimizer/src/analyzer/inline_table_scan.rs @@ -17,7 +17,6 @@ //! Analyzed rule to replace TableScan references //! such as DataFrames and Views and inlines the LogicalPlan. 
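The analyzer hunks that follow (inline_table_scan and type_coercion) all rebuild a subquery the same way: clone the inner plan, rewrite it, and wrap the result back in a `Box` (either via `Subquery::with_plan` or by constructing the `Subquery` directly). A minimal sketch of that shape, assuming the Box-based signatures introduced by this patch; the helper name and the `rewrite` closure are illustrative stand-ins for the rule-specific logic:

use datafusion_common::Result;
use datafusion_expr::{logical_plan::Subquery, LogicalPlan};

// Clone the (boxed) inner plan out of the subquery, rewrite it by value, and
// box the result into a new Subquery that keeps the same outer references.
fn rewrite_subquery_plan(
    subquery: Subquery,
    rewrite: impl Fn(LogicalPlan) -> Result<LogicalPlan>,
) -> Result<Subquery> {
    let new_plan = rewrite(subquery.subquery.as_ref().clone())?;
    Ok(subquery.with_plan(Box::new(new_plan)))
}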
-use std::sync::Arc; use crate::analyzer::AnalyzerRule; @@ -87,7 +86,7 @@ fn rewrite_subquery(expr: Expr) -> Result> { Expr::Exists(Exists { subquery, negated }) => { let plan = subquery.subquery.as_ref().clone(); let new_plan = plan.transform_up(&analyze_internal).data()?; - let subquery = subquery.with_plan(Arc::new(new_plan)); + let subquery = subquery.with_plan(Box::new(new_plan)); Ok(Transformed::yes(Expr::Exists(Exists { subquery, negated }))) } Expr::InSubquery(InSubquery { @@ -97,7 +96,7 @@ fn rewrite_subquery(expr: Expr) -> Result> { }) => { let plan = subquery.subquery.as_ref().clone(); let new_plan = plan.transform_up(&analyze_internal).data()?; - let subquery = subquery.with_plan(Arc::new(new_plan)); + let subquery = subquery.with_plan(Box::new(new_plan)); Ok(Transformed::yes(Expr::InSubquery(InSubquery::new( expr, subquery, negated, )))) @@ -105,7 +104,7 @@ fn rewrite_subquery(expr: Expr) -> Result> { Expr::ScalarSubquery(subquery) => { let plan = subquery.subquery.as_ref().clone(); let new_plan = plan.transform_up(&analyze_internal).data()?; - let subquery = subquery.with_plan(Arc::new(new_plan)); + let subquery = subquery.with_plan(Box::new(new_plan)); Ok(Transformed::yes(Expr::ScalarSubquery(subquery))) } _ => Ok(Transformed::no(expr)), diff --git a/datafusion/optimizer/src/analyzer/type_coercion.rs b/datafusion/optimizer/src/analyzer/type_coercion.rs index c76c1c8a7bd0..516a3ce536be 100644 --- a/datafusion/optimizer/src/analyzer/type_coercion.rs +++ b/datafusion/optimizer/src/analyzer/type_coercion.rs @@ -131,7 +131,7 @@ impl TreeNodeRewriter for TypeCoercionRewriter { }) => { let new_plan = analyze_internal(&self.schema, &subquery)?; Ok(Transformed::yes(Expr::ScalarSubquery(Subquery { - subquery: Arc::new(new_plan), + subquery: Box::new(new_plan), outer_ref_columns, }))) } @@ -139,7 +139,7 @@ impl TreeNodeRewriter for TypeCoercionRewriter { let new_plan = analyze_internal(&self.schema, &subquery.subquery)?; Ok(Transformed::yes(Expr::Exists(Exists { subquery: Subquery { - subquery: Arc::new(new_plan), + subquery: Box::new(new_plan), outer_ref_columns: subquery.outer_ref_columns, }, negated, @@ -158,7 +158,7 @@ impl TreeNodeRewriter for TypeCoercionRewriter { ), )?; let new_subquery = Subquery { - subquery: Arc::new(new_plan), + subquery: Box::new(new_plan), outer_ref_columns: subquery.outer_ref_columns, }; Ok(Transformed::yes(Expr::InSubquery(InSubquery::new( @@ -770,15 +770,15 @@ mod test { }; use datafusion_physical_expr::expressions::AvgAccumulator; - fn empty() -> Arc { - Arc::new(LogicalPlan::EmptyRelation(EmptyRelation { + fn empty() -> Box { + Box::new(LogicalPlan::EmptyRelation(EmptyRelation { produce_one_row: false, schema: Arc::new(DFSchema::empty()), })) } - fn empty_with_type(data_type: DataType) -> Arc { - Arc::new(LogicalPlan::EmptyRelation(EmptyRelation { + fn empty_with_type(data_type: DataType) -> Box { + Box::new(LogicalPlan::EmptyRelation(EmptyRelation { produce_one_row: false, schema: Arc::new( DFSchema::new_with_metadata( @@ -1040,7 +1040,7 @@ mod test { // a in (1,4,8), a is decimal let expr = col("a").in_list(vec![lit(1_i32), lit(4_i8), lit(8_i64)], false); - let empty = Arc::new(LogicalPlan::EmptyRelation(EmptyRelation { + let empty = Box::new(LogicalPlan::EmptyRelation(EmptyRelation { produce_one_row: false, schema: Arc::new(DFSchema::new_with_metadata( vec![DFField::new_unqualified( diff --git a/datafusion/optimizer/src/common_subexpr_eliminate.rs b/datafusion/optimizer/src/common_subexpr_eliminate.rs index 0c9064d0641f..13b75b011c93 100644 --- 
a/datafusion/optimizer/src/common_subexpr_eliminate.rs +++ b/datafusion/optimizer/src/common_subexpr_eliminate.rs @@ -224,7 +224,7 @@ impl CommonSubexprEliminate { new_window_expr.alias_if_changed(original_name) }) .collect::>>()?; - plan = LogicalPlan::Window(Window::try_new(new_window_expr, Arc::new(plan))?); + plan = LogicalPlan::Window(Window::try_new(new_window_expr, Box::new(plan))?); } Ok(plan) @@ -293,7 +293,7 @@ impl CommonSubexprEliminate { }) .collect::>>()?; // Since group_epxr changes, schema changes also. Use try_new method. - Aggregate::try_new(Arc::new(new_input), new_group_expr, new_aggr_expr) + Aggregate::try_new(Box::new(new_input), new_group_expr, new_aggr_expr) .map(LogicalPlan::Aggregate) } else { let mut agg_exprs = vec![]; @@ -335,14 +335,14 @@ impl CommonSubexprEliminate { } let agg = LogicalPlan::Aggregate(Aggregate::try_new( - Arc::new(new_input), + Box::new(new_input), new_group_expr, agg_exprs, )?); Ok(LogicalPlan::Projection(Projection::try_new( proj_exprs, - Arc::new(agg), + Box::new(agg), )?)) } } @@ -501,7 +501,7 @@ fn build_common_expr_project_plan( Ok(LogicalPlan::Projection(Projection::try_new( project_exprs, - Arc::new(input), + Box::new(input), )?)) } @@ -520,7 +520,7 @@ fn build_recover_project_plan( .collect(); Ok(LogicalPlan::Projection(Projection::try_new( col_exprs, - Arc::new(input), + Box::new(input), )?)) } diff --git a/datafusion/optimizer/src/decorrelate_predicate_subquery.rs b/datafusion/optimizer/src/decorrelate_predicate_subquery.rs index b94cf37c5c12..3c4ac6313f6b 100644 --- a/datafusion/optimizer/src/decorrelate_predicate_subquery.rs +++ b/datafusion/optimizer/src/decorrelate_predicate_subquery.rs @@ -74,7 +74,7 @@ impl DecorrelatePredicateSubquery { }) => { let subquery_plan = self .try_optimize(&subquery.subquery, config)? - .map(Arc::new) + .map(Box::new) .unwrap_or_else(|| subquery.subquery.clone()); let new_subquery = subquery.with_plan(subquery_plan); subqueries.push(SubqueryInfo::new_with_in_expr( @@ -86,7 +86,7 @@ impl DecorrelatePredicateSubquery { Expr::Exists(Exists { subquery, negated }) => { let subquery_plan = self .try_optimize(&subquery.subquery, config)? - .map(Arc::new) + .map(Box::new) .unwrap_or_else(|| subquery.subquery.clone()); let new_subquery = subquery.with_plan(subquery_plan); subqueries.push(SubqueryInfo::new(new_subquery, *negated)); @@ -151,7 +151,7 @@ impl OptimizerRule for DecorrelatePredicateSubquery { let expr = conjunction(other_exprs); if let Some(expr) = expr { - let new_filter = Filter::try_new(expr, Arc::new(cur_input))?; + let new_filter = Filter::try_new(expr, Box::new(cur_input))?; cur_input = LogicalPlan::Filter(new_filter); } Ok(Some(cur_input)) @@ -337,7 +337,7 @@ mod tests { Operator, }; - fn assert_optimized_plan_equal(plan: &LogicalPlan, expected: &str) -> Result<()> { + fn assert_optimized_plan_equal(plan: LogicalPlan, expected: &str) -> Result<()> { assert_optimized_plan_eq_display_indent( Arc::new(DecorrelatePredicateSubquery::new()), plan, @@ -346,9 +346,9 @@ mod tests { Ok(()) } - fn test_subquery_with_name(name: &str) -> Result> { + fn test_subquery_with_name(name: &str) -> Result> { let table_scan = test_table_scan_with_name(name)?; - Ok(Arc::new( + Ok(Box::new( LogicalPlanBuilder::from(table_scan) .project(vec![col("c")])? 
.build()?, @@ -377,7 +377,7 @@ mod tests { \n SubqueryAlias: __correlated_sq_2 [c:UInt32]\ \n Projection: sq_2.c [c:UInt32]\ \n TableScan: sq_2 [a:UInt32, b:UInt32, c:UInt32]"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } /// Test for IN subquery with additional AND filter @@ -403,7 +403,7 @@ mod tests { \n Projection: sq.c [c:UInt32]\ \n TableScan: sq [a:UInt32, b:UInt32, c:UInt32]"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } /// Test for IN subquery with additional OR filter @@ -429,7 +429,7 @@ mod tests { \n TableScan: sq [a:UInt32, b:UInt32, c:UInt32]\ \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -457,7 +457,7 @@ mod tests { \n Projection: sq2.c [c:UInt32]\ \n TableScan: sq2 [a:UInt32, b:UInt32, c:UInt32]"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } /// Test for nested IN subqueries @@ -471,7 +471,7 @@ mod tests { .build()?; let plan = LogicalPlanBuilder::from(table_scan) - .filter(in_subquery(col("b"), Arc::new(subquery)))? + .filter(in_subquery(col("b"), Box::new(subquery)))? .project(vec![col("test.b")])? .build()?; @@ -486,7 +486,7 @@ mod tests { \n Projection: sq_nested.c [c:UInt32]\ \n TableScan: sq_nested [a:UInt32, b:UInt32, c:UInt32]"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } /// Test for filter input modification in case filter not supported @@ -518,14 +518,14 @@ mod tests { \n Projection: sq_inner.c [c:UInt32]\ \n TableScan: sq_inner [a:UInt32, b:UInt32, c:UInt32]"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } /// Test multiple correlated subqueries /// See subqueries.rs where_in_multiple() #[test] fn multiple_subqueries() -> Result<()> { - let orders = Arc::new( + let orders = Box::new( LogicalPlanBuilder::from(scan_tpch_table("orders")) .filter( col("orders.o_custkey") @@ -556,7 +556,7 @@ mod tests { assert_optimized_plan_eq_display_indent( Arc::new(DecorrelatePredicateSubquery::new()), - &plan, + plan, expected, ); Ok(()) @@ -566,7 +566,7 @@ mod tests { /// See subqueries.rs where_in_recursive() #[test] fn recursive_subqueries() -> Result<()> { - let lineitem = Arc::new( + let lineitem = Box::new( LogicalPlanBuilder::from(scan_tpch_table("lineitem")) .filter( col("lineitem.l_orderkey") @@ -576,7 +576,7 @@ mod tests { .build()?, ); - let orders = Arc::new( + let orders = Box::new( LogicalPlanBuilder::from(scan_tpch_table("orders")) .filter( in_subquery(col("orders.o_orderkey"), lineitem).and( @@ -606,7 +606,7 @@ mod tests { assert_optimized_plan_eq_display_indent( Arc::new(DecorrelatePredicateSubquery::new()), - &plan, + plan, expected, ); Ok(()) @@ -615,7 +615,7 @@ mod tests { /// Test for correlated IN subquery filter with additional subquery filters #[test] fn in_subquery_with_subquery_filters() -> Result<()> { - let sq = Arc::new( + let sq = Box::new( LogicalPlanBuilder::from(scan_tpch_table("orders")) .filter( out_ref_col(DataType::Int64, "customer.c_custkey") @@ -641,7 +641,7 @@ mod tests { assert_optimized_plan_eq_display_indent( Arc::new(DecorrelatePredicateSubquery::new()), - &plan, + plan, expected, ); Ok(()) @@ -650,7 +650,7 @@ mod tests { /// Test for correlated IN subquery with no columns in schema #[test] fn in_subquery_no_cols() -> Result<()> { - let sq = Arc::new( + let sq = Box::new( 
LogicalPlanBuilder::from(scan_tpch_table("orders")) .filter( out_ref_col(DataType::Int64, "customer.c_custkey") @@ -674,7 +674,7 @@ mod tests { assert_optimized_plan_eq_display_indent( Arc::new(DecorrelatePredicateSubquery::new()), - &plan, + plan, expected, ); Ok(()) @@ -683,7 +683,7 @@ mod tests { /// Test for IN subquery with both columns in schema #[test] fn in_subquery_with_no_correlated_cols() -> Result<()> { - let sq = Arc::new( + let sq = Box::new( LogicalPlanBuilder::from(scan_tpch_table("orders")) .filter(col("orders.o_custkey").eq(col("orders.o_custkey")))? .project(vec![col("orders.o_custkey")])? @@ -705,7 +705,7 @@ mod tests { assert_optimized_plan_eq_display_indent( Arc::new(DecorrelatePredicateSubquery::new()), - &plan, + plan, expected, ); Ok(()) @@ -714,7 +714,7 @@ mod tests { /// Test for correlated IN subquery not equal #[test] fn in_subquery_where_not_eq() -> Result<()> { - let sq = Arc::new( + let sq = Box::new( LogicalPlanBuilder::from(scan_tpch_table("orders")) .filter( out_ref_col(DataType::Int64, "customer.c_custkey") @@ -738,7 +738,7 @@ mod tests { assert_optimized_plan_eq_display_indent( Arc::new(DecorrelatePredicateSubquery::new()), - &plan, + plan, expected, ); Ok(()) @@ -747,7 +747,7 @@ mod tests { /// Test for correlated IN subquery less than #[test] fn in_subquery_where_less_than() -> Result<()> { - let sq = Arc::new( + let sq = Box::new( LogicalPlanBuilder::from(scan_tpch_table("orders")) .filter( out_ref_col(DataType::Int64, "customer.c_custkey") @@ -771,7 +771,7 @@ mod tests { assert_optimized_plan_eq_display_indent( Arc::new(DecorrelatePredicateSubquery::new()), - &plan, + plan, expected, ); Ok(()) @@ -780,7 +780,7 @@ mod tests { /// Test for correlated IN subquery filter with subquery disjunction #[test] fn in_subquery_with_subquery_disjunction() -> Result<()> { - let sq = Arc::new( + let sq = Box::new( LogicalPlanBuilder::from(scan_tpch_table("orders")) .filter( out_ref_col(DataType::Int64, "customer.c_custkey") @@ -805,7 +805,7 @@ mod tests { assert_optimized_plan_eq_display_indent( Arc::new(DecorrelatePredicateSubquery::new()), - &plan, + plan, expected, ); @@ -815,7 +815,7 @@ mod tests { /// Test for correlated IN without projection #[test] fn in_subquery_no_projection() -> Result<()> { - let sq = Arc::new( + let sq = Box::new( LogicalPlanBuilder::from(scan_tpch_table("orders")) .filter(col("customer.c_custkey").eq(col("orders.o_custkey")))? 
.build()?, @@ -838,7 +838,7 @@ mod tests { /// Test for correlated IN subquery join on expression #[test] fn in_subquery_join_expr() -> Result<()> { - let sq = Arc::new( + let sq = Box::new( LogicalPlanBuilder::from(scan_tpch_table("orders")) .filter( out_ref_col(DataType::Int64, "customer.c_custkey") @@ -862,7 +862,7 @@ mod tests { assert_optimized_plan_eq_display_indent( Arc::new(DecorrelatePredicateSubquery::new()), - &plan, + plan, expected, ); Ok(()) @@ -871,7 +871,7 @@ mod tests { /// Test for correlated IN expressions #[test] fn in_subquery_project_expr() -> Result<()> { - let sq = Arc::new( + let sq = Box::new( LogicalPlanBuilder::from(scan_tpch_table("orders")) .filter( out_ref_col(DataType::Int64, "customer.c_custkey") @@ -895,7 +895,7 @@ mod tests { assert_optimized_plan_eq_display_indent( Arc::new(DecorrelatePredicateSubquery::new()), - &plan, + plan, expected, ); Ok(()) @@ -904,7 +904,7 @@ mod tests { /// Test for correlated IN subquery multiple projected columns #[test] fn in_subquery_multi_col() -> Result<()> { - let sq = Arc::new( + let sq = Box::new( LogicalPlanBuilder::from(scan_tpch_table("orders")) .filter( out_ref_col(DataType::Int64, "customer.c_custkey") @@ -933,7 +933,7 @@ mod tests { /// Test for correlated IN subquery filter with additional filters #[test] fn should_support_additional_filters() -> Result<()> { - let sq = Arc::new( + let sq = Box::new( LogicalPlanBuilder::from(scan_tpch_table("orders")) .filter( out_ref_col(DataType::Int64, "customer.c_custkey") @@ -961,7 +961,7 @@ mod tests { assert_optimized_plan_eq_display_indent( Arc::new(DecorrelatePredicateSubquery::new()), - &plan, + plan, expected, ); Ok(()) @@ -970,7 +970,7 @@ mod tests { /// Test for correlated IN subquery filter with disjustions #[test] fn in_subquery_disjunction() -> Result<()> { - let sq = Arc::new( + let sq = Box::new( LogicalPlanBuilder::from(scan_tpch_table("orders")) .filter( out_ref_col(DataType::Int64, "customer.c_custkey") @@ -999,7 +999,7 @@ mod tests { assert_optimized_plan_eq_display_indent( Arc::new(DecorrelatePredicateSubquery::new()), - &plan, + plan, expected, ); Ok(()) @@ -1008,7 +1008,7 @@ mod tests { /// Test for correlated IN subquery filter #[test] fn in_subquery_correlated() -> Result<()> { - let sq = Arc::new( + let sq = Box::new( LogicalPlanBuilder::from(test_table_scan_with_name("sq")?) .filter(out_ref_col(DataType::UInt32, "test.a").eq(col("sq.a")))? .project(vec![col("c")])? @@ -1029,7 +1029,7 @@ mod tests { assert_optimized_plan_eq_display_indent( Arc::new(DecorrelatePredicateSubquery::new()), - &plan, + plan, expected, ); Ok(()) @@ -1053,7 +1053,7 @@ mod tests { assert_optimized_plan_eq_display_indent( Arc::new(DecorrelatePredicateSubquery::new()), - &plan, + plan, expected, ); Ok(()) @@ -1077,7 +1077,7 @@ mod tests { assert_optimized_plan_eq_display_indent( Arc::new(DecorrelatePredicateSubquery::new()), - &plan, + plan, expected, ); Ok(()) @@ -1093,7 +1093,7 @@ mod tests { .build()?; let plan = LogicalPlanBuilder::from(table_scan) - .filter(in_subquery(col("c") + lit(1u32), Arc::new(subquery)))? + .filter(in_subquery(col("c") + lit(1u32), Box::new(subquery)))? .project(vec![col("test.b")])? .build()?; @@ -1106,7 +1106,7 @@ mod tests { assert_optimized_plan_eq_display_indent( Arc::new(DecorrelatePredicateSubquery::new()), - &plan, + plan, expected, ); Ok(()) @@ -1127,7 +1127,7 @@ mod tests { .build()?; let plan = LogicalPlanBuilder::from(table_scan) - .filter(in_subquery(col("c") + lit(1u32), Arc::new(subquery)))? 
+ .filter(in_subquery(col("c") + lit(1u32), Box::new(subquery)))? .project(vec![col("test.b")])? .build()?; @@ -1141,7 +1141,7 @@ mod tests { assert_optimized_plan_eq_display_indent( Arc::new(DecorrelatePredicateSubquery::new()), - &plan, + plan, expected, ); Ok(()) @@ -1163,7 +1163,7 @@ mod tests { .build()?; let plan = LogicalPlanBuilder::from(table_scan) - .filter(in_subquery(col("c") + lit(1u32), Arc::new(subquery)))? + .filter(in_subquery(col("c") + lit(1u32), Box::new(subquery)))? .project(vec![col("test.b")])? .build()?; @@ -1177,7 +1177,7 @@ mod tests { assert_optimized_plan_eq_display_indent( Arc::new(DecorrelatePredicateSubquery::new()), - &plan, + plan, expected, ); Ok(()) @@ -1201,8 +1201,8 @@ mod tests { let plan = LogicalPlanBuilder::from(table_scan) .filter( - in_subquery(col("c") + lit(1u32), Arc::new(subquery1)).and( - in_subquery(col("c") * lit(2u32), Arc::new(subquery2)) + in_subquery(col("c") + lit(1u32), Box::new(subquery1)).and( + in_subquery(col("c") * lit(2u32), Box::new(subquery2)) .and(col("test.c").gt(lit(1u32))), ), )? @@ -1223,7 +1223,7 @@ mod tests { assert_optimized_plan_eq_display_indent( Arc::new(DecorrelatePredicateSubquery::new()), - &plan, + plan, expected, ); Ok(()) @@ -1239,7 +1239,7 @@ mod tests { .build()?; let plan = LogicalPlanBuilder::from(outer_scan) - .filter(in_subquery(col("test.a"), Arc::new(subquery)))? + .filter(in_subquery(col("test.a"), Box::new(subquery)))? .project(vec![col("test.b")])? .build()?; @@ -1254,7 +1254,7 @@ mod tests { assert_optimized_plan_eq_display_indent( Arc::new(DecorrelatePredicateSubquery::new()), - &plan, + plan, expected, ); Ok(()) @@ -1263,7 +1263,7 @@ mod tests { /// Test for multiple exists subqueries in the same filter expression #[test] fn multiple_exists_subqueries() -> Result<()> { - let orders = Arc::new( + let orders = Box::new( LogicalPlanBuilder::from(scan_tpch_table("orders")) .filter( col("orders.o_custkey") @@ -1288,13 +1288,13 @@ mod tests { \n SubqueryAlias: __correlated_sq_2 [o_custkey:Int64]\ \n Projection: orders.o_custkey [o_custkey:Int64]\ \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } /// Test recursive correlated subqueries #[test] fn recursive_exists_subqueries() -> Result<()> { - let lineitem = Arc::new( + let lineitem = Box::new( LogicalPlanBuilder::from(scan_tpch_table("lineitem")) .filter( col("lineitem.l_orderkey") @@ -1304,7 +1304,7 @@ mod tests { .build()?, ); - let orders = Arc::new( + let orders = Box::new( LogicalPlanBuilder::from(scan_tpch_table("orders")) .filter( exists(lineitem).and( @@ -1331,13 +1331,13 @@ mod tests { \n SubqueryAlias: __correlated_sq_2 [l_orderkey:Int64]\ \n Projection: lineitem.l_orderkey [l_orderkey:Int64]\ \n TableScan: lineitem [l_orderkey:Int64, l_partkey:Int64, l_suppkey:Int64, l_linenumber:Int32, l_quantity:Float64, l_extendedprice:Float64]"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } /// Test for correlated exists subquery filter with additional subquery filters #[test] fn exists_subquery_with_subquery_filters() -> Result<()> { - let sq = Arc::new( + let sq = Box::new( LogicalPlanBuilder::from(scan_tpch_table("orders")) .filter( out_ref_col(DataType::Int64, "customer.c_custkey") @@ -1361,12 +1361,12 @@ mod tests { \n Filter: orders.o_orderkey = Int32(1) [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]\ \n TableScan: 
orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] fn exists_subquery_no_cols() -> Result<()> { - let sq = Arc::new( + let sq = Box::new( LogicalPlanBuilder::from(scan_tpch_table("orders")) .filter(out_ref_col(DataType::Int64, "customer.c_custkey").eq(lit(1u32)))? .project(vec![col("orders.o_custkey")])? @@ -1386,13 +1386,13 @@ mod tests { \n Projection: orders.o_custkey [o_custkey:Int64]\ \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } /// Test for exists subquery with both columns in schema #[test] fn exists_subquery_with_no_correlated_cols() -> Result<()> { - let sq = Arc::new( + let sq = Box::new( LogicalPlanBuilder::from(scan_tpch_table("orders")) .filter(col("orders.o_custkey").eq(col("orders.o_custkey")))? .project(vec![col("orders.o_custkey")])? @@ -1404,13 +1404,13 @@ mod tests { .project(vec![col("customer.c_custkey")])? .build()?; - assert_optimization_skipped(Arc::new(DecorrelatePredicateSubquery::new()), &plan) + assert_optimization_skipped(Arc::new(DecorrelatePredicateSubquery::new()), plan) } /// Test for correlated exists subquery not equal #[test] fn exists_subquery_where_not_eq() -> Result<()> { - let sq = Arc::new( + let sq = Box::new( LogicalPlanBuilder::from(scan_tpch_table("orders")) .filter( out_ref_col(DataType::Int64, "customer.c_custkey") @@ -1432,13 +1432,13 @@ mod tests { \n Projection: orders.o_custkey [o_custkey:Int64]\ \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } /// Test for correlated exists subquery less than #[test] fn exists_subquery_where_less_than() -> Result<()> { - let sq = Arc::new( + let sq = Box::new( LogicalPlanBuilder::from(scan_tpch_table("orders")) .filter( out_ref_col(DataType::Int64, "customer.c_custkey") @@ -1460,13 +1460,13 @@ mod tests { \n Projection: orders.o_custkey [o_custkey:Int64]\ \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } /// Test for correlated exists subquery filter with subquery disjunction #[test] fn exists_subquery_with_subquery_disjunction() -> Result<()> { - let sq = Arc::new( + let sq = Box::new( LogicalPlanBuilder::from(scan_tpch_table("orders")) .filter( out_ref_col(DataType::Int64, "customer.c_custkey") @@ -1489,13 +1489,13 @@ mod tests { \n Projection: orders.o_custkey, orders.o_orderkey [o_custkey:Int64, o_orderkey:Int64]\ \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } /// Test for correlated exists without projection #[test] fn exists_subquery_no_projection() -> Result<()> { - let sq = Arc::new( + let sq = Box::new( LogicalPlanBuilder::from(scan_tpch_table("orders")) .filter( out_ref_col(DataType::Int64, "customer.c_custkey") @@ -1515,13 +1515,13 @@ mod tests { \n SubqueryAlias: __correlated_sq_1 [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]\ \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]"; - 
assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } /// Test for correlated exists expressions #[test] fn exists_subquery_project_expr() -> Result<()> { - let sq = Arc::new( + let sq = Box::new( LogicalPlanBuilder::from(scan_tpch_table("orders")) .filter( out_ref_col(DataType::Int64, "customer.c_custkey") @@ -1543,13 +1543,13 @@ mod tests { \n Projection: orders.o_custkey + Int32(1), orders.o_custkey [orders.o_custkey + Int32(1):Int64, o_custkey:Int64]\ \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } /// Test for correlated exists subquery filter with additional filters #[test] fn exists_subquery_should_support_additional_filters() -> Result<()> { - let sq = Arc::new( + let sq = Box::new( LogicalPlanBuilder::from(scan_tpch_table("orders")) .filter( out_ref_col(DataType::Int64, "customer.c_custkey") @@ -1571,13 +1571,13 @@ mod tests { \n Projection: orders.o_custkey [o_custkey:Int64]\ \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } /// Test for correlated exists subquery filter with disjustions #[test] fn exists_subquery_disjunction() -> Result<()> { - let sq = Arc::new( + let sq = Box::new( LogicalPlanBuilder::from(scan_tpch_table("orders")) .filter(col("customer.c_custkey").eq(col("orders.o_custkey")))? .project(vec![col("orders.o_custkey")])? @@ -1598,13 +1598,13 @@ mod tests { TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N] TableScan: customer [c_custkey:Int64, c_name:Utf8]"#; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } /// Test for correlated EXISTS subquery filter #[test] fn exists_subquery_correlated() -> Result<()> { - let sq = Arc::new( + let sq = Box::new( LogicalPlanBuilder::from(test_table_scan_with_name("sq")?) .filter(out_ref_col(DataType::UInt32, "test.a").eq(col("sq.a")))? .project(vec![col("c")])? @@ -1623,7 +1623,7 @@ mod tests { \n Projection: sq.c, sq.a [c:UInt32, a:UInt32]\ \n TableScan: sq [a:UInt32, b:UInt32, c:UInt32]"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } /// Test for single exists subquery filter @@ -1635,7 +1635,7 @@ mod tests { .project(vec![col("test.b")])? .build()?; - assert_optimization_skipped(Arc::new(DecorrelatePredicateSubquery::new()), &plan) + assert_optimization_skipped(Arc::new(DecorrelatePredicateSubquery::new()), plan) } /// Test for single NOT exists subquery filter @@ -1647,7 +1647,7 @@ mod tests { .project(vec![col("test.b")])? .build()?; - assert_optimization_skipped(Arc::new(DecorrelatePredicateSubquery::new()), &plan) + assert_optimization_skipped(Arc::new(DecorrelatePredicateSubquery::new()), plan) } #[test] @@ -1668,8 +1668,8 @@ mod tests { let plan = LogicalPlanBuilder::from(table_scan) .filter( - exists(Arc::new(subquery1)) - .and(exists(Arc::new(subquery2)).and(col("test.c").gt(lit(1u32)))), + exists(Box::new(subquery1)) + .and(exists(Box::new(subquery2)).and(col("test.c").gt(lit(1u32)))), )? .project(vec![col("test.b")])? 
.build()?; @@ -1686,7 +1686,7 @@ mod tests { \n Projection: sq2.c, sq2.a [c:UInt32, a:UInt32]\ \n TableScan: sq2 [a:UInt32, b:UInt32, c:UInt32]"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -1701,7 +1701,7 @@ mod tests { .project(vec![lit(1u32)])? .build()?; let plan = LogicalPlanBuilder::from(table_scan) - .filter(exists(Arc::new(subquery)))? + .filter(exists(Box::new(subquery)))? .project(vec![col("test.b")])? .build()?; @@ -1712,7 +1712,7 @@ mod tests { \n Projection: UInt32(1), sq.a [UInt32(1):UInt32, a:UInt32]\ \n TableScan: sq [a:UInt32, b:UInt32, c:UInt32]"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -1725,7 +1725,7 @@ mod tests { .build()?; let plan = LogicalPlanBuilder::from(outer_scan) - .filter(exists(Arc::new(subquery)))? + .filter(exists(Box::new(subquery)))? .project(vec![col("test.b")])? .build()?; @@ -1738,7 +1738,7 @@ mod tests { \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]\ \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -1754,7 +1754,7 @@ mod tests { .distinct()? .build()?; let plan = LogicalPlanBuilder::from(table_scan) - .filter(exists(Arc::new(subquery)))? + .filter(exists(Box::new(subquery)))? .project(vec![col("test.b")])? .build()?; @@ -1766,7 +1766,7 @@ mod tests { \n Projection: sq.c, sq.a [c:UInt32, a:UInt32]\ \n TableScan: sq [a:UInt32, b:UInt32, c:UInt32]"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -1782,7 +1782,7 @@ mod tests { .distinct()? .build()?; let plan = LogicalPlanBuilder::from(table_scan) - .filter(exists(Arc::new(subquery)))? + .filter(exists(Box::new(subquery)))? .project(vec![col("test.b")])? .build()?; @@ -1794,7 +1794,7 @@ mod tests { \n Projection: sq.b + sq.c, sq.a [sq.b + sq.c:UInt32, a:UInt32]\ \n TableScan: sq [a:UInt32, b:UInt32, c:UInt32]"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -1810,7 +1810,7 @@ mod tests { .distinct()? .build()?; let plan = LogicalPlanBuilder::from(table_scan) - .filter(exists(Arc::new(subquery)))? + .filter(exists(Box::new(subquery)))? .project(vec![col("test.b")])? .build()?; @@ -1822,6 +1822,6 @@ mod tests { \n Projection: UInt32(1), sq.c, sq.a [UInt32(1):UInt32, c:UInt32, a:UInt32]\ \n TableScan: sq [a:UInt32, b:UInt32, c:UInt32]"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } } diff --git a/datafusion/optimizer/src/eliminate_cross_join.rs b/datafusion/optimizer/src/eliminate_cross_join.rs index 7f65690a4a7c..9f3ae5220483 100644 --- a/datafusion/optimizer/src/eliminate_cross_join.rs +++ b/datafusion/optimizer/src/eliminate_cross_join.rs @@ -116,7 +116,7 @@ impl OptimizerRule for EliminateCrossJoin { if plan.schema() != left.schema() { left = LogicalPlan::Projection(Projection::new_from_schema( - Arc::new(left), + Box::new(left), plan.schema().clone(), )); } @@ -127,12 +127,12 @@ impl OptimizerRule for EliminateCrossJoin { // If there are no join keys then do nothing: if all_join_keys.is_empty() { - Filter::try_new(predicate.clone(), Arc::new(left)) + Filter::try_new(predicate.clone(), Box::new(left)) .map(|f| Some(LogicalPlan::Filter(f))) } else { // Remove join expressions from filter: match remove_join_expressions(predicate, &all_join_keys)? 
{ - Some(filter_expr) => Filter::try_new(filter_expr, Arc::new(left)) + Some(filter_expr) => Filter::try_new(filter_expr, Box::new(left)) .map(|f| Some(LogicalPlan::Filter(f))), _ => Ok(Some(left)), } @@ -225,8 +225,8 @@ fn find_inner_join( )?); return Ok(LogicalPlan::Join(Join { - left: Arc::new(left_input.clone()), - right: Arc::new(right_input), + left: Box::new(left_input.clone()), + right: Box::new(right_input), join_type: JoinType::Inner, join_constraint: JoinConstraint::On, on: join_keys, @@ -244,8 +244,8 @@ fn find_inner_join( )?); Ok(LogicalPlan::CrossJoin(CrossJoin { - left: Arc::new(left_input.clone()), - right: Arc::new(right), + left: Box::new(left_input.clone()), + right: Box::new(right), schema: join_schema, })) } diff --git a/datafusion/optimizer/src/eliminate_duplicated_expr.rs b/datafusion/optimizer/src/eliminate_duplicated_expr.rs index de05717a72e2..fae0eb5c8b1d 100644 --- a/datafusion/optimizer/src/eliminate_duplicated_expr.rs +++ b/datafusion/optimizer/src/eliminate_duplicated_expr.rs @@ -114,7 +114,7 @@ mod tests { use datafusion_expr::{col, logical_plan::builder::LogicalPlanBuilder}; use std::sync::Arc; - fn assert_optimized_plan_eq(plan: &LogicalPlan, expected: &str) -> Result<()> { + fn assert_optimized_plan_eq(plan: LogicalPlan, expected: &str) -> Result<()> { crate::test::assert_optimized_plan_eq( Arc::new(EliminateDuplicatedExpr::new()), plan, @@ -132,7 +132,7 @@ mod tests { let expected = "Limit: skip=5, fetch=10\ \n Sort: test.a, test.b, test.c\ \n TableScan: test"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } #[test] @@ -151,6 +151,6 @@ mod tests { let expected = "Limit: skip=5, fetch=10\ \n Sort: test.a ASC NULLS FIRST, test.b ASC NULLS LAST\ \n TableScan: test"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } } diff --git a/datafusion/optimizer/src/eliminate_filter.rs b/datafusion/optimizer/src/eliminate_filter.rs index fea14342ca77..9287752a3f99 100644 --- a/datafusion/optimizer/src/eliminate_filter.rs +++ b/datafusion/optimizer/src/eliminate_filter.rs @@ -88,7 +88,7 @@ mod tests { use crate::test::*; - fn assert_optimized_plan_equal(plan: &LogicalPlan, expected: &str) -> Result<()> { + fn assert_optimized_plan_equal(plan: LogicalPlan, expected: &str) -> Result<()> { assert_optimized_plan_eq(Arc::new(EliminateFilter::new()), plan, expected) } @@ -104,7 +104,7 @@ mod tests { // No aggregate / scan / limit let expected = "EmptyRelation"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -119,7 +119,7 @@ mod tests { // No aggregate / scan / limit let expected = "EmptyRelation"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -141,7 +141,7 @@ mod tests { \n EmptyRelation\ \n Aggregate: groupBy=[[test.a]], aggr=[[SUM(test.b)]]\ \n TableScan: test"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -156,7 +156,7 @@ mod tests { let expected = "Aggregate: groupBy=[[test.a]], aggr=[[SUM(test.b)]]\ \n TableScan: test"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -179,7 +179,7 @@ mod tests { \n TableScan: test\ \n Aggregate: groupBy=[[test.a]], aggr=[[SUM(test.b)]]\ \n TableScan: test"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -202,6 +202,6 @@ mod tests { // Filter is removed let expected = 
"Projection: test.a\ \n EmptyRelation"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } } diff --git a/datafusion/optimizer/src/eliminate_join.rs b/datafusion/optimizer/src/eliminate_join.rs index 0dbebcc8a051..f4123c6503e8 100644 --- a/datafusion/optimizer/src/eliminate_join.rs +++ b/datafusion/optimizer/src/eliminate_join.rs @@ -82,7 +82,7 @@ mod tests { use datafusion_expr::{logical_plan::builder::LogicalPlanBuilder, Expr, LogicalPlan}; use std::sync::Arc; - fn assert_optimized_plan_equal(plan: &LogicalPlan, expected: &str) -> Result<()> { + fn assert_optimized_plan_equal(plan: LogicalPlan, expected: &str) -> Result<()> { assert_optimized_plan_eq(Arc::new(EliminateJoin::new()), plan, expected) } @@ -97,7 +97,7 @@ mod tests { .build()?; let expected = "EmptyRelation"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -114,6 +114,6 @@ mod tests { CrossJoin:\ \n EmptyRelation\ \n EmptyRelation"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } } diff --git a/datafusion/optimizer/src/eliminate_limit.rs b/datafusion/optimizer/src/eliminate_limit.rs index 4386253740aa..4c5089e85f00 100644 --- a/datafusion/optimizer/src/eliminate_limit.rs +++ b/datafusion/optimizer/src/eliminate_limit.rs @@ -93,7 +93,8 @@ mod tests { use crate::push_down_limit::PushDownLimit; - fn assert_optimized_plan_eq(plan: &LogicalPlan, expected: &str) -> Result<()> { + fn assert_optimized_plan_eq(plan: LogicalPlan, expected: &str) -> Result<()> { + let plan_schema = plan.schema().clone(); let optimizer = Optimizer::with_rules(vec![Arc::new(EliminateLimit::new())]); let optimized_plan = optimizer .optimize_recursively( @@ -101,18 +102,19 @@ mod tests { plan, &OptimizerContext::new(), )? 
- .unwrap_or_else(|| plan.clone()); + .data; let formatted_plan = format!("{optimized_plan:?}"); assert_eq!(formatted_plan, expected); - assert_eq!(plan.schema(), optimized_plan.schema()); + assert_eq!(plan_schema.as_ref(), optimized_plan.schema().as_ref()); Ok(()) } fn assert_optimized_plan_eq_with_pushdown( - plan: &LogicalPlan, + plan: LogicalPlan, expected: &str, ) -> Result<()> { + let plan_schema = plan.schema().clone(); fn observe(_plan: &LogicalPlan, _rule: &dyn OptimizerRule) {} let config = OptimizerContext::new().with_max_passes(1); let optimizer = Optimizer::with_rules(vec![ @@ -121,10 +123,11 @@ mod tests { ]); let optimized_plan = optimizer .optimize(plan, &config, observe) - .expect("failed to optimize plan"); + .expect("failed to optimize plan") + .data; let formatted_plan = format!("{optimized_plan:?}"); assert_eq!(formatted_plan, expected); - assert_eq!(plan.schema(), optimized_plan.schema()); + assert_eq!(&plan_schema, optimized_plan.schema()); Ok(()) } @@ -137,7 +140,7 @@ mod tests { .build()?; // No aggregate / scan / limit let expected = "EmptyRelation"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } #[test] @@ -157,7 +160,7 @@ mod tests { \n EmptyRelation\ \n Aggregate: groupBy=[[test.a]], aggr=[[SUM(test.b)]]\ \n TableScan: test"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } #[test] @@ -171,7 +174,7 @@ mod tests { // No aggregate / scan / limit let expected = "EmptyRelation"; - assert_optimized_plan_eq_with_pushdown(&plan, expected) + assert_optimized_plan_eq_with_pushdown(plan, expected) } #[test] @@ -191,7 +194,7 @@ mod tests { \n Limit: skip=0, fetch=2\ \n Aggregate: groupBy=[[test.a]], aggr=[[SUM(test.b)]]\ \n TableScan: test"; - assert_optimized_plan_eq_with_pushdown(&plan, expected) + assert_optimized_plan_eq_with_pushdown(plan, expected) } #[test] @@ -209,7 +212,7 @@ mod tests { \n Limit: skip=0, fetch=2\ \n Aggregate: groupBy=[[test.a]], aggr=[[SUM(test.b)]]\ \n TableScan: test"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } #[test] @@ -227,7 +230,7 @@ mod tests { \n Limit: skip=2, fetch=1\ \n Aggregate: groupBy=[[test.a]], aggr=[[SUM(test.b)]]\ \n TableScan: test"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } #[test] @@ -249,7 +252,7 @@ mod tests { \n Limit: skip=2, fetch=1\ \n TableScan: test\ \n TableScan: test1"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } #[test] @@ -262,6 +265,6 @@ mod tests { let expected = "Aggregate: groupBy=[[test.a]], aggr=[[SUM(test.b)]]\ \n TableScan: test"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } } diff --git a/datafusion/optimizer/src/eliminate_nested_union.rs b/datafusion/optimizer/src/eliminate_nested_union.rs index 5771ea2e19a2..516ba4701c4f 100644 --- a/datafusion/optimizer/src/eliminate_nested_union.rs +++ b/datafusion/optimizer/src/eliminate_nested_union.rs @@ -21,7 +21,6 @@ use crate::{OptimizerConfig, OptimizerRule}; use datafusion_common::Result; use datafusion_expr::expr_rewriter::coerce_plan_expr_for_schema; use datafusion_expr::{Distinct, LogicalPlan, Union}; -use std::sync::Arc; #[derive(Default)] /// An optimization rule that replaces nested unions with a single union. 
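Reviewer note: the test-helper updates above (e.g. in eliminate_limit.rs) all follow the pattern this PR introduces — `Optimizer::optimize` now consumes the plan and returns `Transformed<LogicalPlan>`, so callers clone whatever they still need (typically the schema) before the move and read the result out of `.data` / `.transformed`. The sketch below shows both sides of that contract using only the signatures added in this patch; the rule name `MyRule`, the helper `run_rule`, and the exact import paths are illustrative, not part of the change.

use std::sync::Arc;

use datafusion_common::tree_node::Transformed;
use datafusion_common::Result;
use datafusion_expr::logical_plan::LogicalPlan;
use datafusion_optimizer::optimizer::{
    Optimizer, OptimizerConfig, OptimizerContext, OptimizerRule,
};

/// A no-op rule that opts into the new owned-plan path.
struct MyRule;

impl OptimizerRule for MyRule {
    fn name(&self) -> &str {
        "my_rule"
    }

    // Signal that the optimizer should call `try_optimize_owned`
    // instead of the borrowing `try_optimize`.
    fn support_owned(&self) -> bool {
        true
    }

    fn try_optimize_owned(
        &self,
        plan: LogicalPlan,
        _config: &dyn OptimizerConfig,
    ) -> Result<Transformed<LogicalPlan>> {
        // A real rule would rewrite `plan` here and return Transformed::yes;
        // this sketch hands it back unchanged.
        Ok(Transformed::no(plan))
    }

    // Still required by the trait, but not invoked once `support_owned`
    // returns true.
    fn try_optimize(
        &self,
        _plan: &LogicalPlan,
        _config: &dyn OptimizerConfig,
    ) -> Result<Option<LogicalPlan>> {
        Ok(None)
    }
}

/// Caller side: the plan is moved into `optimize`, so capture the schema
/// first, then unpack the `Transformed` result.
fn run_rule(plan: LogicalPlan) -> Result<LogicalPlan> {
    let schema = plan.schema().clone();
    let optimizer = Optimizer::with_rules(vec![Arc::new(MyRule)]);
    let optimized = optimizer.optimize(plan, &OptimizerContext::new(), |_, _| {})?;
    // `transformed` reports whether any rule changed the plan; the plan
    // itself now lives in `data` (mirrors the assertions in the helpers above).
    assert_eq!(schema.as_ref(), optimized.data.schema().as_ref());
    Ok(optimized.data)
}

Rules that do not override `support_owned` are unaffected: the default `try_optimize_owned` delegates to the existing `try_optimize` and wraps the outcome in `Transformed`, so existing rule implementations continue to compile and run unchanged.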
@@ -60,7 +59,7 @@ impl OptimizerRule for EliminateNestedUnion { .flat_map(extract_plans_from_union) .collect::>(); - Ok(Some(LogicalPlan::Distinct(Distinct::All(Arc::new( + Ok(Some(LogicalPlan::Distinct(Distinct::All(Box::new( LogicalPlan::Union(Union { inputs, schema: schema.clone(), @@ -82,25 +81,27 @@ impl OptimizerRule for EliminateNestedUnion { } } -fn extract_plans_from_union(plan: &Arc) -> Vec> { - match plan.as_ref() { +fn extract_plans_from_union(plan: &LogicalPlan) -> Vec { + match plan { LogicalPlan::Union(Union { inputs, schema }) => inputs .iter() - .map(|plan| Arc::new(coerce_plan_expr_for_schema(plan, schema).unwrap())) + .map(|plan| coerce_plan_expr_for_schema(plan, schema).unwrap()) .collect::>(), _ => vec![plan.clone()], } } -fn extract_plan_from_distinct(plan: &Arc) -> &Arc { - match plan.as_ref() { - LogicalPlan::Distinct(Distinct::All(plan)) => plan, +fn extract_plan_from_distinct(plan: &LogicalPlan) -> &LogicalPlan { + match plan { + LogicalPlan::Distinct(Distinct::All(plan)) => plan.as_ref(), _ => plan, } } #[cfg(test)] mod tests { + use std::sync::Arc; + use super::*; use crate::test::*; use arrow::datatypes::{DataType, Field, Schema}; @@ -114,7 +115,7 @@ mod tests { ]) } - fn assert_optimized_plan_equal(plan: &LogicalPlan, expected: &str) -> Result<()> { + fn assert_optimized_plan_equal(plan: LogicalPlan, expected: &str) -> Result<()> { assert_optimized_plan_eq(Arc::new(EliminateNestedUnion::new()), plan, expected) } @@ -131,7 +132,7 @@ mod tests { Union\ \n TableScan: table\ \n TableScan: table"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -147,7 +148,7 @@ mod tests { \n Union\ \n TableScan: table\ \n TableScan: table"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -167,7 +168,7 @@ mod tests { \n TableScan: table\ \n TableScan: table\ \n TableScan: table"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -188,7 +189,7 @@ mod tests { \n TableScan: table\ \n TableScan: table\ \n TableScan: table"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -210,7 +211,7 @@ mod tests { \n TableScan: table\ \n TableScan: table\ \n TableScan: table"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -230,7 +231,7 @@ mod tests { \n TableScan: table\ \n TableScan: table\ \n TableScan: table"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } // We don't need to use project_with_column_index in logical optimizer, @@ -261,7 +262,7 @@ mod tests { \n TableScan: table\ \n Projection: table.id AS id, table.key, table.value\ \n TableScan: table"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -291,7 +292,7 @@ mod tests { \n TableScan: table\ \n Projection: table.id AS id, table.key, table.value\ \n TableScan: table"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -337,7 +338,7 @@ mod tests { \n TableScan: table_1\ \n Projection: CAST(table_1.id AS Int64) AS id, table_1.key, CAST(table_1.value AS Float64) AS value\ \n TableScan: table_1"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -384,6 +385,6 @@ mod tests { \n TableScan: table_1\ \n Projection: CAST(table_1.id AS Int64) AS id, 
table_1.key, CAST(table_1.value AS Float64) AS value\ \n TableScan: table_1"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } } diff --git a/datafusion/optimizer/src/eliminate_one_union.rs b/datafusion/optimizer/src/eliminate_one_union.rs index 70ee490346ff..f892cddc32b4 100644 --- a/datafusion/optimizer/src/eliminate_one_union.rs +++ b/datafusion/optimizer/src/eliminate_one_union.rs @@ -41,7 +41,7 @@ impl OptimizerRule for EliminateOneUnion { ) -> Result> { match plan { LogicalPlan::Union(Union { inputs, .. }) if inputs.len() == 1 => { - Ok(inputs.first().map(|input| input.as_ref().clone())) + Ok(inputs.first().cloned()) } _ => Ok(None), } @@ -76,7 +76,7 @@ mod tests { ]) } - fn assert_optimized_plan_equal(plan: &LogicalPlan, expected: &str) -> Result<()> { + fn assert_optimized_plan_equal(plan: LogicalPlan, expected: &str) -> Result<()> { assert_optimized_plan_eq_with_rules( vec![Arc::new(EliminateOneUnion::new())], plan, @@ -97,7 +97,7 @@ mod tests { Union\ \n TableScan: table\ \n TableScan: table"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -108,11 +108,11 @@ mod tests { )?; let schema = table_plan.schema().clone(); let single_union_plan = LogicalPlan::Union(Union { - inputs: vec![Arc::new(table_plan)], + inputs: vec![table_plan], schema, }); let expected = "TableScan: table"; - assert_optimized_plan_equal(&single_union_plan, expected) + assert_optimized_plan_equal(single_union_plan, expected) } } diff --git a/datafusion/optimizer/src/eliminate_outer_join.rs b/datafusion/optimizer/src/eliminate_outer_join.rs index 56a4a76987f7..33ee871c2112 100644 --- a/datafusion/optimizer/src/eliminate_outer_join.rs +++ b/datafusion/optimizer/src/eliminate_outer_join.rs @@ -97,8 +97,8 @@ impl OptimizerRule for EliminateOuterJoin { join.join_type }; let new_join = LogicalPlan::Join(Join { - left: Arc::new((*join.left).clone()), - right: Arc::new((*join.right).clone()), + left: Box::new((*join.left).clone()), + right: Box::new((*join.right).clone()), join_type: new_join_type, join_constraint: join.join_constraint, on: join.on.clone(), @@ -306,7 +306,7 @@ mod tests { Operator::{And, Or}, }; - fn assert_optimized_plan_equal(plan: &LogicalPlan, expected: &str) -> Result<()> { + fn assert_optimized_plan_equal(plan: LogicalPlan, expected: &str) -> Result<()> { assert_optimized_plan_eq(Arc::new(EliminateOuterJoin::new()), plan, expected) } @@ -330,7 +330,7 @@ mod tests { \n Left Join: t1.a = t2.a\ \n TableScan: t1\ \n TableScan: t2"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -353,7 +353,7 @@ mod tests { \n Inner Join: t1.a = t2.a\ \n TableScan: t1\ \n TableScan: t2"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -380,7 +380,7 @@ mod tests { \n Inner Join: t1.a = t2.a\ \n TableScan: t1\ \n TableScan: t2"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -407,7 +407,7 @@ mod tests { \n Inner Join: t1.a = t2.a\ \n TableScan: t1\ \n TableScan: t2"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -434,6 +434,6 @@ mod tests { \n Inner Join: t1.a = t2.a\ \n TableScan: t1\ \n TableScan: t2"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } } diff --git a/datafusion/optimizer/src/extract_equijoin_predicate.rs 
b/datafusion/optimizer/src/extract_equijoin_predicate.rs index 24664d57c38d..efe92e2702b3 100644 --- a/datafusion/optimizer/src/extract_equijoin_predicate.rs +++ b/datafusion/optimizer/src/extract_equijoin_predicate.rs @@ -164,7 +164,7 @@ mod tests { col, lit, logical_plan::builder::LogicalPlanBuilder, JoinType, }; - fn assert_plan_eq(plan: &LogicalPlan, expected: &str) -> Result<()> { + fn assert_plan_eq(plan: LogicalPlan, expected: &str) -> Result<()> { assert_optimized_plan_eq_display_indent( Arc::new(ExtractEquijoinPredicate {}), plan, @@ -186,7 +186,7 @@ mod tests { \n TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]\ \n TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]"; - assert_plan_eq(&plan, expected) + assert_plan_eq(plan, expected) } #[test] @@ -205,7 +205,7 @@ mod tests { \n TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]\ \n TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]"; - assert_plan_eq(&plan, expected) + assert_plan_eq(plan, expected) } #[test] @@ -228,7 +228,7 @@ mod tests { \n TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]\ \n TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]"; - assert_plan_eq(&plan, expected) + assert_plan_eq(plan, expected) } #[test] @@ -255,7 +255,7 @@ mod tests { \n TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]\ \n TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]"; - assert_plan_eq(&plan, expected) + assert_plan_eq(plan, expected) } #[test] @@ -281,7 +281,7 @@ mod tests { \n TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]\ \n TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]"; - assert_plan_eq(&plan, expected) + assert_plan_eq(plan, expected) } #[test] @@ -318,7 +318,7 @@ mod tests { \n TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]\ \n TableScan: t3 [a:UInt32, b:UInt32, c:UInt32]"; - assert_plan_eq(&plan, expected) + assert_plan_eq(plan, expected) } #[test] @@ -351,7 +351,7 @@ mod tests { \n TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]\ \n TableScan: t3 [a:UInt32, b:UInt32, c:UInt32]"; - assert_plan_eq(&plan, expected) + assert_plan_eq(plan, expected) } #[test] @@ -375,6 +375,6 @@ mod tests { \n TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]\ \n TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]"; - assert_plan_eq(&plan, expected) + assert_plan_eq(plan, expected) } } diff --git a/datafusion/optimizer/src/filter_null_join_keys.rs b/datafusion/optimizer/src/filter_null_join_keys.rs index 95cd8a9fd36c..0d967e69141f 100644 --- a/datafusion/optimizer/src/filter_null_join_keys.rs +++ b/datafusion/optimizer/src/filter_null_join_keys.rs @@ -26,7 +26,6 @@ use datafusion_common::Result; use datafusion_expr::{ and, logical_plan::Filter, logical_plan::JoinType, Expr, ExprSchemable, LogicalPlan, }; -use std::sync::Arc; /// The FilterNullJoinKeys rule will identify inner joins with equi-join conditions /// where the join key is nullable on one side and non-nullable on the other side @@ -71,14 +70,14 @@ impl OptimizerRule for FilterNullJoinKeys { if !left_filters.is_empty() { let predicate = create_not_null_predicate(left_filters); - join.left = Arc::new(LogicalPlan::Filter(Filter::try_new( + join.left = Box::new(LogicalPlan::Filter(Filter::try_new( predicate, join.left.clone(), )?)); } if !right_filters.is_empty() { let predicate = create_not_null_predicate(right_filters); - join.right = Arc::new(LogicalPlan::Filter(Filter::try_new( + join.right = Box::new(LogicalPlan::Filter(Filter::try_new( predicate, join.right.clone(), )?)); @@ -112,6 +111,8 @@ fn create_not_null_predicate(filters: Vec) -> Expr { #[cfg(test)] mod tests { + use std::sync::Arc; + use super::*; use crate::test::assert_optimized_plan_eq; 
use arrow::datatypes::{DataType, Field, Schema}; @@ -119,7 +120,7 @@ mod tests { use datafusion_expr::logical_plan::table_scan; use datafusion_expr::{col, lit, logical_plan::JoinType, LogicalPlanBuilder}; - fn assert_optimized_plan_equal(plan: &LogicalPlan, expected: &str) -> Result<()> { + fn assert_optimized_plan_equal(plan: LogicalPlan, expected: &str) -> Result<()> { assert_optimized_plan_eq(Arc::new(FilterNullJoinKeys {}), plan, expected) } @@ -131,7 +132,7 @@ mod tests { \n Filter: t1.optional_id IS NOT NULL\ \n TableScan: t1\ \n TableScan: t2"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -142,7 +143,7 @@ mod tests { \n Filter: t1.optional_id IS NOT NULL\ \n TableScan: t1\ \n TableScan: t2"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -179,7 +180,7 @@ mod tests { \n Filter: t1.optional_id IS NOT NULL\ \n TableScan: t1\ \n TableScan: t2"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -200,7 +201,7 @@ mod tests { \n Filter: t1.optional_id + UInt32(1) IS NOT NULL\ \n TableScan: t1\ \n TableScan: t2"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -221,7 +222,7 @@ mod tests { \n TableScan: t1\ \n Filter: t2.optional_id + UInt32(1) IS NOT NULL\ \n TableScan: t2"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -244,7 +245,7 @@ mod tests { \n TableScan: t1\ \n Filter: t2.optional_id + UInt32(1) IS NOT NULL\ \n TableScan: t2"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } fn build_plan( diff --git a/datafusion/optimizer/src/optimize_projections.rs b/datafusion/optimizer/src/optimize_projections.rs index 08ee38f64abd..809695c15c4e 100644 --- a/datafusion/optimizer/src/optimize_projections.rs +++ b/datafusion/optimizer/src/optimize_projections.rs @@ -257,7 +257,7 @@ fn optimize_projections( // Create a new aggregate plan with the updated input and only the // absolutely necessary fields: return Aggregate::try_new( - Arc::new(aggregate_input), + Box::new(aggregate_input), new_group_bys, new_aggr_expr, ) @@ -304,7 +304,7 @@ fn optimize_projections( get_required_exprs(window.input.schema(), &required_indices); let (window_child, _) = add_projection_on_top_if_helpful(window_child, required_exprs)?; - Window::try_new(new_window_expr, Arc::new(window_child)) + Window::try_new(new_window_expr, Box::new(window_child)) .map(|window| Some(LogicalPlan::Window(window))) }; } @@ -834,7 +834,7 @@ fn add_projection_on_top_if_helpful( if project_exprs.len() >= plan.schema().fields().len() { Ok((plan, false)) } else { - Projection::try_new(project_exprs, Arc::new(plan)) + Projection::try_new(project_exprs, Box::new(plan)) .map(|proj| (LogicalPlan::Projection(proj), true)) } } @@ -870,7 +870,7 @@ fn rewrite_projection_given_requirements( if is_projection_unnecessary(&input, &exprs_used)? 
{ Ok(Some(input)) } else { - Projection::try_new(exprs_used, Arc::new(input)) + Projection::try_new(exprs_used, Box::new(input)) .map(|proj| Some(LogicalPlan::Projection(proj))) } } else if exprs_used.len() < proj.expr.len() { @@ -910,7 +910,7 @@ mod tests { table_scan, try_cast, when, Expr, Like, LogicalPlan, Operator, }; - fn assert_optimized_plan_equal(plan: &LogicalPlan, expected: &str) -> Result<()> { + fn assert_optimized_plan_equal(plan: LogicalPlan, expected: &str) -> Result<()> { assert_optimized_plan_eq(Arc::new(OptimizeProjections::new()), plan, expected) } @@ -924,7 +924,7 @@ mod tests { let expected = "Projection: Int32(1) + test.a\ \n TableScan: test projection=[a]"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -938,7 +938,7 @@ mod tests { let expected = "Projection: Int32(1) + test.a\ \n TableScan: test projection=[a]"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -951,7 +951,7 @@ mod tests { let expected = "Projection: test.a AS alias\ \n TableScan: test projection=[a]"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -964,7 +964,7 @@ mod tests { let expected = "Projection: test.a AS alias\ \n TableScan: test projection=[a]"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -986,7 +986,7 @@ mod tests { \n Projection: \ \n Aggregate: groupBy=[[]], aggr=[[COUNT(Int32(1))]]\ \n TableScan: ?table? projection=[]"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -1009,7 +1009,7 @@ mod tests { .build()?; let expected = "Projection: (?table?.s)[x]\ \n TableScan: ?table? 
projection=[s]"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -1021,7 +1021,7 @@ mod tests { let expected = "Projection: (- test.a)\ \n TableScan: test projection=[a]"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -1033,7 +1033,7 @@ mod tests { let expected = "Projection: test.a IS NULL\ \n TableScan: test projection=[a]"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -1045,7 +1045,7 @@ mod tests { let expected = "Projection: test.a IS NOT NULL\ \n TableScan: test projection=[a]"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -1057,7 +1057,7 @@ mod tests { let expected = "Projection: test.a IS TRUE\ \n TableScan: test projection=[a]"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -1069,7 +1069,7 @@ mod tests { let expected = "Projection: test.a IS NOT TRUE\ \n TableScan: test projection=[a]"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -1081,7 +1081,7 @@ mod tests { let expected = "Projection: test.a IS FALSE\ \n TableScan: test projection=[a]"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -1093,7 +1093,7 @@ mod tests { let expected = "Projection: test.a IS NOT FALSE\ \n TableScan: test projection=[a]"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -1105,7 +1105,7 @@ mod tests { let expected = "Projection: test.a IS UNKNOWN\ \n TableScan: test projection=[a]"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -1117,7 +1117,7 @@ mod tests { let expected = "Projection: test.a IS NOT UNKNOWN\ \n TableScan: test projection=[a]"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -1129,7 +1129,7 @@ mod tests { let expected = "Projection: NOT test.a\ \n TableScan: test projection=[a]"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -1141,7 +1141,7 @@ mod tests { let expected = "Projection: TRY_CAST(test.a AS Float64)\ \n TableScan: test projection=[a]"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -1157,7 +1157,7 @@ mod tests { let expected = "Projection: test.a SIMILAR TO Utf8(\"[0-9]\")\ \n TableScan: test projection=[a]"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -1169,7 +1169,7 @@ mod tests { let expected = "Projection: test.a BETWEEN Int32(1) AND Int32(3)\ \n TableScan: test projection=[a]"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } // Test outer projection isn't discarded despite the same schema as inner @@ -1190,6 +1190,6 @@ mod tests { let expected = "Projection: test.a, CASE WHEN test.a = Int32(1) THEN Int32(10) ELSE d END AS d\ \n Projection: test.a, Int32(0) AS d\ \n TableScan: test projection=[a]"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } } diff --git a/datafusion/optimizer/src/optimizer.rs b/datafusion/optimizer/src/optimizer.rs index fe63766fc265..81f22edab32c 100644 --- a/datafusion/optimizer/src/optimizer.rs 
+++ b/datafusion/optimizer/src/optimizer.rs @@ -48,7 +48,8 @@ use crate::utils::log_plan; use datafusion_common::alias::AliasGenerator; use datafusion_common::config::ConfigOptions; use datafusion_common::instant::Instant; -use datafusion_common::{DataFusionError, Result}; +use datafusion_common::tree_node::Transformed; +use datafusion_common::{DFSchemaRef, DataFusionError, Result}; use datafusion_expr::logical_plan::LogicalPlan; use chrono::{DateTime, Utc}; @@ -76,6 +77,19 @@ pub trait OptimizerRule { config: &dyn OptimizerConfig, ) -> Result>; + fn try_optimize_owned( + &self, + plan: LogicalPlan, + config: &dyn OptimizerConfig, + ) -> Result> { + let optimized_plan = self.try_optimize(&plan, config)?; + if let Some(optimized_plan) = optimized_plan { + Ok(Transformed::yes(optimized_plan)) + } else { + Ok(Transformed::no(plan)) + } + } + /// A human readable name for this optimizer rule fn name(&self) -> &str; @@ -85,6 +99,10 @@ pub trait OptimizerRule { fn apply_order(&self) -> Option { None } + + fn support_owned(&self) -> bool { + false + } } /// Options to control the DataFusion Optimizer. @@ -279,72 +297,88 @@ impl Optimizer { /// invoking observer function after each call pub fn optimize( &self, - plan: &LogicalPlan, + plan: LogicalPlan, config: &dyn OptimizerConfig, mut observer: F, - ) -> Result + ) -> Result> where F: FnMut(&LogicalPlan, &dyn OptimizerRule), { let options = config.options(); - let mut new_plan = plan.clone(); + let mut cur_plan = plan.clone(); let start_time = Instant::now(); + let mut is_transformed = false; let mut previous_plans = HashSet::with_capacity(16); - previous_plans.insert(LogicalPlanSignature::new(&new_plan)); + previous_plans.insert(LogicalPlanSignature::new(&cur_plan)); let mut i = 0; while i < options.optimizer.max_passes { - log_plan(&format!("Optimizer input (pass {i})"), &new_plan); + log_plan(&format!("Optimizer input (pass {i})"), &cur_plan); for rule in &self.rules { + let prev_plan = if options.optimizer.skip_failed_rules { + Some(cur_plan.clone()) + } else { + None + }; + + let prev_schema = cur_plan.schema().clone(); + let result = - self.optimize_recursively(rule, &new_plan, config) + self.optimize_recursively(rule, cur_plan, config) .and_then(|plan| { - if let Some(plan) = &plan { - assert_schema_is_the_same(rule.name(), plan, &new_plan)?; - } + assert_schema_is_the_same( + rule.name(), + prev_schema, + &plan.data, + )?; Ok(plan) }); - match result { - Ok(Some(plan)) => { - new_plan = plan; - observer(&new_plan, rule.as_ref()); - log_plan(rule.name(), &new_plan); + + match (result, prev_plan) { + (Ok(t), _) if t.transformed => { + is_transformed = true; + cur_plan = t.data; + observer(&cur_plan, rule.as_ref()); + log_plan(rule.name(), &cur_plan); } - Ok(None) => { - observer(&new_plan, rule.as_ref()); + + (Ok(t), _) => { + cur_plan = t.data; + observer(&cur_plan, rule.as_ref()); debug!( "Plan unchanged by optimizer rule '{}' (pass {})", rule.name(), i ); } - Err(e) => { - if options.optimizer.skip_failed_rules { - // Note to future readers: if you see this warning it signals a - // bug in the DataFusion optimizer. Please consider filing a ticket - // https://github.com/apache/arrow-datafusion - warn!( + (Err(e), Some(prev_plan)) => { + // Note to future readers: if you see this warning it signals a + // bug in the DataFusion optimizer. 
Please consider filing a ticket + // https://github.com/apache/arrow-datafusion + warn!( "Skipping optimizer rule '{}' due to unexpected error: {}", rule.name(), e ); - } else { - return Err(DataFusionError::Context( - format!("Optimizer rule '{}' failed", rule.name(),), - Box::new(e), - )); - } + + cur_plan = prev_plan; + } + (Err(e), None) => { + return Err(DataFusionError::Context( + format!("Optimizer rule '{}' failed", rule.name(),), + Box::new(e), + )); } } } - log_plan(&format!("Optimized plan (pass {i})"), &new_plan); + log_plan(&format!("Optimized plan (pass {i})"), &cur_plan); // HashSet::insert returns, whether the value was newly inserted. let plan_is_fresh = - previous_plans.insert(LogicalPlanSignature::new(&new_plan)); + previous_plans.insert(LogicalPlanSignature::new(&cur_plan)); if !plan_is_fresh { // plan did not change, so no need to continue trying to optimize debug!("optimizer pass {} did not make changes", i); @@ -352,47 +386,56 @@ impl Optimizer { } i += 1; } - log_plan("Final optimized plan", &new_plan); + log_plan("Final optimized plan", &cur_plan); debug!("Optimizer took {} ms", start_time.elapsed().as_millis()); - Ok(new_plan) + + Ok(if is_transformed { + Transformed::yes(cur_plan) + } else { + Transformed::no(cur_plan) + }) } fn optimize_node( &self, rule: &Arc, - plan: &LogicalPlan, + plan: LogicalPlan, config: &dyn OptimizerConfig, - ) -> Result> { + ) -> Result> { // TODO: future feature: We can do Batch optimize - rule.try_optimize(plan, config) + if rule.support_owned() { + rule.try_optimize_owned(plan, config) + } else { + rule.try_optimize(&plan, config).map(|opt_plan| { + if let Some(opt_plan) = opt_plan { + Transformed::yes(opt_plan) + } else { + Transformed::no(plan) + } + }) + } } fn optimize_inputs( &self, rule: &Arc, - plan: &LogicalPlan, + plan: LogicalPlan, config: &dyn OptimizerConfig, - ) -> Result> { - let inputs = plan.inputs(); - let result = inputs - .iter() - .map(|sub_plan| self.optimize_recursively(rule, sub_plan, config)) - .collect::>>()?; - if result.is_empty() || result.iter().all(|o| o.is_none()) { - return Ok(None); - } - - let new_inputs = result - .into_iter() - .zip(inputs) - .map(|(new_plan, old_plan)| match new_plan { - Some(plan) => plan, - None => old_plan.clone(), - }) - .collect(); + ) -> Result> { + let mut is_transformed = false; + let inputs = plan.rewrite_inputs(|child| { + let t = self.optimize_recursively(rule, child, config)?; + if t.transformed { + is_transformed = true; + } + Ok(t.data) + })?; - let exprs = plan.expressions(); - plan.with_new_exprs(exprs, new_inputs).map(Some) + if is_transformed { + Ok(Transformed::yes(inputs)) + } else { + Ok(Transformed::no(inputs)) + } } /// Use a rule to optimize the whole plan. @@ -400,33 +443,35 @@ impl Optimizer { pub fn optimize_recursively( &self, rule: &Arc, - plan: &LogicalPlan, + plan: LogicalPlan, config: &dyn OptimizerConfig, - ) -> Result> { + ) -> Result> { match rule.apply_order() { Some(order) => match order { ApplyOrder::TopDown => { - let optimize_self_opt = self.optimize_node(rule, plan, config)?; - let optimize_inputs_opt = match &optimize_self_opt { - Some(optimized_plan) => { - self.optimize_inputs(rule, optimized_plan, config)? 
- } - _ => self.optimize_inputs(rule, plan, config)?, - }; - Ok(optimize_inputs_opt.or(optimize_self_opt)) + let optimizied_node = self.optimize_node(rule, plan, config)?; + let optimized_inputs = + self.optimize_inputs(rule, optimizied_node.data, config)?; + + if optimizied_node.transformed || optimized_inputs.transformed { + Ok(Transformed::yes(optimized_inputs.data)) + } else { + Ok(Transformed::no(optimized_inputs.data)) + } } ApplyOrder::BottomUp => { - let optimize_inputs_opt = self.optimize_inputs(rule, plan, config)?; - let optimize_self_opt = match &optimize_inputs_opt { - Some(optimized_plan) => { - self.optimize_node(rule, optimized_plan, config)? - } - _ => self.optimize_node(rule, plan, config)?, - }; - Ok(optimize_self_opt.or(optimize_inputs_opt)) + let optimized_inputs = self.optimize_inputs(rule, plan, config)?; + let optimized_node = + self.optimize_node(rule, optimized_inputs.data, config)?; + + if optimized_node.transformed || optimized_inputs.transformed { + Ok(Transformed::yes(optimized_node.data)) + } else { + Ok(Transformed::no(optimized_node.data)) + } } }, - _ => rule.try_optimize(plan, config), + _ => self.optimize_node(rule, plan, config), } } } @@ -436,17 +481,17 @@ impl Optimizer { /// It ignores metadata and nullability. pub(crate) fn assert_schema_is_the_same( rule_name: &str, - prev_plan: &LogicalPlan, + prev_schema: DFSchemaRef, new_plan: &LogicalPlan, ) -> Result<()> { let equivalent = new_plan .schema() - .equivalent_names_and_types(prev_plan.schema()); + .equivalent_names_and_types(prev_schema.as_ref()); if !equivalent { let e = DataFusionError::Internal(format!( "Failed due to a difference in schemas, original schema: {:?}, new schema: {:?}", - prev_plan.schema(), + prev_schema, new_plan.schema() )); Err(DataFusionError::Context( @@ -479,7 +524,7 @@ mod tests { produce_one_row: false, schema: Arc::new(DFSchema::empty()), }); - opt.optimize(&plan, &config, &observe).unwrap(); + opt.optimize(plan, &config, &observe).unwrap(); } #[test] @@ -490,7 +535,7 @@ mod tests { produce_one_row: false, schema: Arc::new(DFSchema::empty()), }); - let err = opt.optimize(&plan, &config, &observe).unwrap_err(); + let err = opt.optimize(plan, &config, &observe).unwrap_err(); assert_eq!( "Optimizer rule 'bad rule' failed\ncaused by\n\ Error during planning: rule failed", @@ -506,7 +551,7 @@ mod tests { produce_one_row: false, schema: Arc::new(DFSchema::empty()), }); - let err = opt.optimize(&plan, &config, &observe).unwrap_err(); + let err = opt.optimize(plan, &config, &observe).unwrap_err(); assert_eq!( "Optimizer rule 'get table_scan rule' failed\ncaused by\nget table_scan rule\ncaused by\n\ Internal error: Failed due to a difference in schemas, \ @@ -529,7 +574,7 @@ mod tests { produce_one_row: false, schema: Arc::new(DFSchema::empty()), }); - opt.optimize(&plan, &config, &observe).unwrap(); + opt.optimize(plan, &config, &observe).unwrap(); } #[test] @@ -539,7 +584,7 @@ mod tests { let opt = Optimizer::with_rules(vec![Arc::new(GetTableScanRule {})]); let config = OptimizerContext::new().with_skip_failing_rules(false); - let input = Arc::new(test_table_scan()?); + let input = Box::new(test_table_scan()?); let input_schema = input.schema().clone(); let plan = LogicalPlan::Projection(Projection::try_new_with_schema( @@ -550,9 +595,10 @@ mod tests { // optimizing should be ok, but the schema will have changed (no metadata) assert_ne!(plan.schema().as_ref(), input_schema.as_ref()); - let optimized_plan = opt.optimize(&plan, &config, &observe)?; + let optimized_plan = 
opt.optimize(plan, &config, &observe)?; // metadata was removed - assert_eq!(optimized_plan.schema().as_ref(), input_schema.as_ref()); + assert!(optimized_plan.transformed); + assert_eq!(optimized_plan.data.schema().as_ref(), input_schema.as_ref()); Ok(()) } @@ -571,13 +617,14 @@ mod tests { let mut plans: Vec = Vec::new(); let final_plan = - opt.optimize(&initial_plan, &config, |p, _| plans.push(p.clone()))?; + opt.optimize(initial_plan.clone(), &config, |p, _| plans.push(p.clone()))?; // initial_plan is not observed, so we have 3 plans assert_eq!(3, plans.len()); // we got again the initial_plan with [1, 2, 3] - assert_eq!(initial_plan, final_plan); + assert!(final_plan.transformed); + assert_eq!(initial_plan, final_plan.data); Ok(()) } @@ -597,13 +644,14 @@ mod tests { let mut plans: Vec = Vec::new(); let final_plan = - opt.optimize(&initial_plan, &config, |p, _| plans.push(p.clone()))?; + opt.optimize(initial_plan, &config, |p, _| plans.push(p.clone()))?; // initial_plan is not observed, so we have 4 plans assert_eq!(4, plans.len()); // we got again the plan with [3, 2, 1] - assert_eq!(plans[0], final_plan); + assert!(final_plan.transformed); + assert_eq!(plans[0], final_plan.data); Ok(()) } diff --git a/datafusion/optimizer/src/plan_signature.rs b/datafusion/optimizer/src/plan_signature.rs index 4143d52a053e..7cf0b346c856 100644 --- a/datafusion/optimizer/src/plan_signature.rs +++ b/datafusion/optimizer/src/plan_signature.rs @@ -98,16 +98,16 @@ mod tests { let schema = Arc::new(DFSchema::empty()); let one_node_plan = - Arc::new(LogicalPlan::EmptyRelation(datafusion_expr::EmptyRelation { + Box::new(LogicalPlan::EmptyRelation(datafusion_expr::EmptyRelation { produce_one_row: false, schema: schema.clone(), })); assert_eq!(1, get_node_number(&one_node_plan).get()); - let two_node_plan = Arc::new(LogicalPlan::Projection( + let two_node_plan = LogicalPlan::Projection( datafusion_expr::Projection::try_new(vec![lit(1), lit(2)], one_node_plan)?, - )); + ); assert_eq!(2, get_node_number(&two_node_plan).get()); diff --git a/datafusion/optimizer/src/propagate_empty_relation.rs b/datafusion/optimizer/src/propagate_empty_relation.rs index d1f9f87a32a3..e99fdd0ad51b 100644 --- a/datafusion/optimizer/src/propagate_empty_relation.rs +++ b/datafusion/optimizer/src/propagate_empty_relation.rs @@ -18,7 +18,6 @@ use datafusion_common::{plan_err, Result}; use datafusion_expr::logical_plan::LogicalPlan; use datafusion_expr::{EmptyRelation, JoinType, Projection, Union}; -use std::sync::Arc; use crate::optimizer::ApplyOrder; use crate::{OptimizerConfig, OptimizerRule}; @@ -89,7 +88,7 @@ impl OptimizerRule for PropagateEmptyRelation { let new_inputs = union .inputs .iter() - .filter(|input| match &***input { + .filter(|input| match input { LogicalPlan::EmptyRelation(empty) => empty.produce_one_row, _ => true, }) @@ -104,13 +103,13 @@ impl OptimizerRule for PropagateEmptyRelation { schema: plan.schema().clone(), }))); } else if new_inputs.len() == 1 { - let child = (*new_inputs[0]).clone(); + let child = new_inputs[0].clone(); if child.schema().eq(plan.schema()) { return Ok(Some(child)); } else { return Ok(Some(LogicalPlan::Projection( Projection::new_from_schema( - Arc::new(child), + Box::new(child), plan.schema().clone(), ), ))); @@ -181,6 +180,8 @@ fn empty_child(plan: &LogicalPlan) -> Result> { #[cfg(test)] mod tests { + use std::sync::Arc; + use crate::eliminate_filter::EliminateFilter; use crate::eliminate_nested_union::EliminateNestedUnion; use crate::test::{ @@ -197,12 +198,12 @@ mod tests { use 
super::*; - fn assert_eq(plan: &LogicalPlan, expected: &str) -> Result<()> { + fn assert_eq(plan: LogicalPlan, expected: &str) -> Result<()> { assert_optimized_plan_eq(Arc::new(PropagateEmptyRelation::new()), plan, expected) } fn assert_together_optimized_plan_eq( - plan: &LogicalPlan, + plan: LogicalPlan, expected: &str, ) -> Result<()> { assert_optimized_plan_eq_with_rules( @@ -225,7 +226,7 @@ mod tests { .build()?; let expected = "EmptyRelation"; - assert_eq(&plan, expected) + assert_eq(plan, expected) } #[test] @@ -248,7 +249,7 @@ mod tests { .build()?; let expected = "EmptyRelation"; - assert_together_optimized_plan_eq(&plan, expected) + assert_together_optimized_plan_eq(plan, expected) } #[test] @@ -261,7 +262,7 @@ mod tests { let plan = LogicalPlanBuilder::from(left).union(right)?.build()?; let expected = "TableScan: test"; - assert_together_optimized_plan_eq(&plan, expected) + assert_together_optimized_plan_eq(plan, expected) } #[test] @@ -286,7 +287,7 @@ mod tests { let expected = "Union\ \n TableScan: test1\ \n TableScan: test4"; - assert_together_optimized_plan_eq(&plan, expected) + assert_together_optimized_plan_eq(plan, expected) } #[test] @@ -311,7 +312,7 @@ mod tests { .build()?; let expected = "EmptyRelation"; - assert_together_optimized_plan_eq(&plan, expected) + assert_together_optimized_plan_eq(plan, expected) } #[test] @@ -338,7 +339,7 @@ mod tests { let expected = "Union\ \n TableScan: test2\ \n TableScan: test3"; - assert_together_optimized_plan_eq(&plan, expected) + assert_together_optimized_plan_eq(plan, expected) } #[test] @@ -351,7 +352,7 @@ mod tests { let plan = LogicalPlanBuilder::from(left).union(right)?.build()?; let expected = "TableScan: test"; - assert_together_optimized_plan_eq(&plan, expected) + assert_together_optimized_plan_eq(plan, expected) } #[test] @@ -366,7 +367,7 @@ mod tests { .build()?; let expected = "EmptyRelation"; - assert_together_optimized_plan_eq(&plan, expected) + assert_together_optimized_plan_eq(plan, expected) } #[test] @@ -399,6 +400,6 @@ mod tests { let expected = "Projection: a, b, c\ \n TableScan: test"; - assert_together_optimized_plan_eq(&plan, expected) + assert_together_optimized_plan_eq(plan, expected) } } diff --git a/datafusion/optimizer/src/push_down_filter.rs b/datafusion/optimizer/src/push_down_filter.rs index e93e171e0324..6290842c266e 100644 --- a/datafusion/optimizer/src/push_down_filter.rs +++ b/datafusion/optimizer/src/push_down_filter.rs @@ -15,12 +15,10 @@ //! [`PushDownFilter`] Moves filters so they are applied as early as possible in //! the plan. -use std::collections::{HashMap, HashSet}; -use std::sync::Arc; - use crate::optimizer::ApplyOrder; use crate::utils::is_volatile_expression; use crate::{OptimizerConfig, OptimizerRule}; +use std::collections::{HashMap, HashSet}; use datafusion_common::tree_node::{ Transformed, TransformedResult, TreeNode, TreeNodeRecursion, @@ -482,13 +480,13 @@ fn push_down_all_join( let left = match conjunction(left_push) { Some(predicate) => { - LogicalPlan::Filter(Filter::try_new(predicate, Arc::new(left.clone()))?) + LogicalPlan::Filter(Filter::try_new(predicate, Box::new(left.clone()))?) } None => left.clone(), }; let right = match conjunction(right_push) { Some(predicate) => { - LogicalPlan::Filter(Filter::try_new(predicate, Arc::new(right.clone()))?) + LogicalPlan::Filter(Filter::try_new(predicate, Box::new(right.clone()))?) 
} None => right.clone(), }; @@ -510,7 +508,7 @@ fn push_down_all_join( // wrap the join on the filter whose predicates must be kept match conjunction(keep_predicates) { Some(predicate) => { - Filter::try_new(predicate, Arc::new(plan)).map(LogicalPlan::Filter) + Filter::try_new(predicate, Box::new(plan)).map(LogicalPlan::Filter) } None => Ok(plan), } @@ -748,7 +746,7 @@ impl OptimizerRule for PushDownFilter { )?; LogicalPlan::Filter(Filter::try_new( keep_predicate, - Arc::new(child_plan), + Box::new(child_plan), )?) } } @@ -769,10 +767,10 @@ impl OptimizerRule for PushDownFilter { let push_predicate = replace_cols_by_name(filter.predicate.clone(), &replace_map)?; - inputs.push(Arc::new(LogicalPlan::Filter(Filter::try_new( + inputs.push(LogicalPlan::Filter(Filter::try_new( push_predicate, - input.clone(), - )?))) + Box::new(input.clone()), + )?)) } LogicalPlan::Union(Union { inputs, @@ -825,7 +823,7 @@ impl OptimizerRule for PushDownFilter { match conjunction(keep_predicates) { Some(predicate) => LogicalPlan::Filter(Filter::try_new( predicate, - Arc::new(new_agg), + Box::new(new_agg), )?), None => new_agg, } @@ -896,7 +894,7 @@ impl OptimizerRule for PushDownFilter { match conjunction(new_predicate) { Some(predicate) => LogicalPlan::Filter(Filter::try_new( predicate, - Arc::new(new_scan), + Box::new(new_scan), )?), None => new_scan, } @@ -926,7 +924,7 @@ impl OptimizerRule for PushDownFilter { .map(|child| { Ok(LogicalPlan::Filter(Filter::try_new( predicate.clone(), - Arc::new(child.clone()), + Box::new(child.clone()), )?)) }) .collect::>>()?, @@ -939,7 +937,7 @@ impl OptimizerRule for PushDownFilter { match conjunction(keep_predicates) { Some(predicate) => LogicalPlan::Filter(Filter::try_new( predicate, - Arc::new(new_extension), + Box::new(new_extension), )?), None => new_extension, } @@ -987,7 +985,7 @@ fn convert_to_cross_join_if_beneficial(plan: LogicalPlan) -> Result } else if let LogicalPlan::Filter(filter) = &plan { let new_input = convert_to_cross_join_if_beneficial(filter.input.as_ref().clone())?; - return Filter::try_new(filter.predicate.clone(), Arc::new(new_input)) + return Filter::try_new(filter.predicate.clone(), Box::new(new_input)) .map(LogicalPlan::Filter); } Ok(plan) @@ -1053,7 +1051,7 @@ mod tests { use async_trait::async_trait; - fn assert_optimized_plan_eq(plan: &LogicalPlan, expected: &str) -> Result<()> { + fn assert_optimized_plan_eq(plan: LogicalPlan, expected: &str) -> Result<()> { crate::test::assert_optimized_plan_eq( Arc::new(PushDownFilter::new()), plan, @@ -1062,29 +1060,30 @@ mod tests { } fn assert_optimized_plan_eq_with_rewrite_predicate( - plan: &LogicalPlan, + plan: LogicalPlan, expected: &str, ) -> Result<()> { let optimizer = Optimizer::with_rules(vec![ Arc::new(RewriteDisjunctivePredicate::new()), Arc::new(PushDownFilter::new()), ]); + let plan_schema = plan.schema().clone(); let mut optimized_plan = optimizer .optimize_recursively( optimizer.rules.first().unwrap(), plan, &OptimizerContext::new(), )? - .unwrap_or_else(|| plan.clone()); + .data; optimized_plan = optimizer .optimize_recursively( optimizer.rules.get(1).unwrap(), - &optimized_plan, + optimized_plan, &OptimizerContext::new(), )? 
- .unwrap_or_else(|| plan.clone()); + .data; let formatted_plan = format!("{optimized_plan:?}"); - assert_eq!(plan.schema(), optimized_plan.schema()); + assert_eq!(plan_schema.as_ref(), optimized_plan.schema().as_ref()); assert_eq!(expected, formatted_plan); Ok(()) } @@ -1100,7 +1099,7 @@ mod tests { let expected = "\ Projection: test.a, test.b\ \n TableScan: test, full_filters=[test.a = Int64(1)]"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } #[test] @@ -1117,7 +1116,7 @@ mod tests { \n Limit: skip=0, fetch=10\ \n Projection: test.a, test.b\ \n TableScan: test"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } #[test] @@ -1127,7 +1126,7 @@ mod tests { .filter(lit(0i64).eq(lit(1i64)))? .build()?; let expected = "TableScan: test, full_filters=[Int64(0) = Int64(1)]"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } #[test] @@ -1143,7 +1142,7 @@ mod tests { Projection: test.c, test.b\ \n Projection: test.a, test.b, test.c\ \n TableScan: test, full_filters=[test.a = Int64(1)]"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } #[test] @@ -1157,7 +1156,7 @@ mod tests { let expected = "\ Aggregate: groupBy=[[test.a]], aggr=[[SUM(test.b) AS total_salary]]\ \n TableScan: test, full_filters=[test.a > Int64(10)]"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } #[test] @@ -1170,7 +1169,7 @@ mod tests { let expected = "Filter: test.b > Int64(10)\ \n Aggregate: groupBy=[[test.b + test.a]], aggr=[[SUM(test.a), test.b]]\ \n TableScan: test"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } #[test] @@ -1182,7 +1181,7 @@ mod tests { let expected = "Aggregate: groupBy=[[test.b + test.a]], aggr=[[SUM(test.a), test.b]]\ \n TableScan: test, full_filters=[test.b + test.a > Int64(10)]"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } #[test] @@ -1197,7 +1196,7 @@ mod tests { Filter: b > Int64(10)\ \n Aggregate: groupBy=[[test.a]], aggr=[[SUM(test.b) AS b]]\ \n TableScan: test"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } /// verifies that a filter is pushed to before a projection, the filter expression is correctly re-written @@ -1212,7 +1211,7 @@ mod tests { let expected = "\ Projection: test.a AS b, test.c\ \n TableScan: test, full_filters=[test.a = Int64(1)]"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } fn add(left: Expr, right: Expr) -> Expr { @@ -1256,7 +1255,7 @@ mod tests { let expected = "\ Projection: test.a * Int32(2) + test.c AS b, test.c\ \n TableScan: test, full_filters=[test.a * Int32(2) + test.c = Int64(1)]"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } /// verifies that when a filter is pushed to after 2 projections, the filter expression is correctly re-written @@ -1288,7 +1287,7 @@ mod tests { Projection: b * Int32(3) AS a, test.c\ \n Projection: test.a * Int32(2) + test.c AS b, test.c\ \n TableScan: test, full_filters=[(test.a * Int32(2) + test.c) * Int32(3) = Int64(1)]"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } #[derive(Debug, PartialEq, Eq, Hash)] @@ -1351,7 +1350,7 @@ mod tests { let expected = "\ NoopPlan\ \n TableScan: test, full_filters=[test.a = Int64(1)]"; - assert_optimized_plan_eq(&plan, expected)?; + 
assert_optimized_plan_eq(plan, expected)?; let custom_plan = LogicalPlan::Extension(Extension { node: Arc::new(NoopPlan { @@ -1368,7 +1367,7 @@ mod tests { Filter: test.c = Int64(2)\ \n NoopPlan\ \n TableScan: test, full_filters=[test.a = Int64(1)]"; - assert_optimized_plan_eq(&plan, expected)?; + assert_optimized_plan_eq(plan, expected)?; let custom_plan = LogicalPlan::Extension(Extension { node: Arc::new(NoopPlan { @@ -1385,7 +1384,7 @@ mod tests { NoopPlan\ \n TableScan: test, full_filters=[test.a = Int64(1)]\ \n TableScan: test, full_filters=[test.a = Int64(1)]"; - assert_optimized_plan_eq(&plan, expected)?; + assert_optimized_plan_eq(plan, expected)?; let custom_plan = LogicalPlan::Extension(Extension { node: Arc::new(NoopPlan { @@ -1403,7 +1402,7 @@ mod tests { \n NoopPlan\ \n TableScan: test, full_filters=[test.a = Int64(1)]\ \n TableScan: test, full_filters=[test.a = Int64(1)]"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } /// verifies that when two filters apply after an aggregation that only allows one to be pushed, one is pushed @@ -1436,7 +1435,7 @@ mod tests { \n Aggregate: groupBy=[[b]], aggr=[[SUM(test.c)]]\ \n Projection: test.a AS b, test.c\ \n TableScan: test, full_filters=[test.a > Int64(10)]"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } /// verifies that when a filter with two predicates is applied after an aggregation that only allows one to be pushed, one is pushed @@ -1470,7 +1469,7 @@ mod tests { \n Aggregate: groupBy=[[b]], aggr=[[SUM(test.c)]]\ \n Projection: test.a AS b, test.c\ \n TableScan: test, full_filters=[test.a > Int64(10)]"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } /// verifies that when two limits are in place, we jump neither @@ -1492,7 +1491,7 @@ mod tests { \n Limit: skip=0, fetch=20\ \n Projection: test.a, test.b\ \n TableScan: test"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } #[test] @@ -1507,7 +1506,7 @@ mod tests { let expected = "Union\ \n TableScan: test, full_filters=[test.a = Int64(1)]\ \n TableScan: test2, full_filters=[test2.a = Int64(1)]"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } #[test] @@ -1530,7 +1529,7 @@ mod tests { \n SubqueryAlias: test2\ \n Projection: test.a AS b\ \n TableScan: test, full_filters=[test.a = Int64(1)]"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } #[test] @@ -1561,7 +1560,7 @@ mod tests { \n Projection: test1.d, test1.e, test1.f\ \n TableScan: test1, full_filters=[test1.d > Int32(2)]"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } #[test] @@ -1587,7 +1586,7 @@ mod tests { \n TableScan: test, full_filters=[test.a = Int32(1)]\ \n Projection: test1.a, test1.b, test1.c\ \n TableScan: test1, full_filters=[test1.a > Int32(2)]"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } /// verifies that filters with the same columns are correctly placed @@ -1621,7 +1620,7 @@ mod tests { \n Projection: test.a\ \n TableScan: test, full_filters=[test.a <= Int64(1)]"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } /// verifies that filters to be placed on the same depth are ANDed @@ -1651,7 +1650,7 @@ mod tests { \n Limit: skip=0, fetch=1\ \n TableScan: test"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, 
expected) } /// verifies that filters on a plan with user nodes are not lost @@ -1677,7 +1676,7 @@ mod tests { TestUserDefined\ \n TableScan: test, full_filters=[test.a <= Int64(1)]"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } /// post-on-join predicates on a column common to both sides is pushed to both sides @@ -1715,7 +1714,7 @@ mod tests { \n TableScan: test, full_filters=[test.a <= Int64(1)]\ \n Projection: test2.a\ \n TableScan: test2, full_filters=[test2.a <= Int64(1)]"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } /// post-using-join predicates on a column common to both sides is pushed to both sides @@ -1752,7 +1751,7 @@ mod tests { \n TableScan: test, full_filters=[test.a <= Int64(1)]\ \n Projection: test2.a\ \n TableScan: test2, full_filters=[test2.a <= Int64(1)]"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } /// post-join predicates with columns from both sides are converted to join filterss @@ -1794,7 +1793,7 @@ mod tests { \n TableScan: test\ \n Projection: test2.a, test2.b\ \n TableScan: test2"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } /// post-join predicates with columns from one side of a join are pushed only to that side @@ -1836,7 +1835,7 @@ mod tests { \n TableScan: test, full_filters=[test.b <= Int64(1)]\ \n Projection: test2.a, test2.c\ \n TableScan: test2"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } /// post-join predicates on the right side of a left join are not duplicated @@ -1875,7 +1874,7 @@ mod tests { \n TableScan: test\ \n Projection: test2.a\ \n TableScan: test2"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } /// post-join predicates on the left side of a right join are not duplicated @@ -1913,7 +1912,7 @@ mod tests { \n TableScan: test\ \n Projection: test2.a\ \n TableScan: test2"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } /// post-left-join predicate on a column common to both sides is only pushed to the left side @@ -1951,7 +1950,7 @@ mod tests { \n TableScan: test, full_filters=[test.a <= Int64(1)]\ \n Projection: test2.a\ \n TableScan: test2"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } /// post-right-join predicate on a column common to both sides is only pushed to the right side @@ -1989,7 +1988,7 @@ mod tests { \n TableScan: test\ \n Projection: test2.a\ \n TableScan: test2, full_filters=[test2.a <= Int64(1)]"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } /// single table predicate parts of ON condition should be pushed to both inputs @@ -2032,7 +2031,7 @@ mod tests { \n TableScan: test, full_filters=[test.c > UInt32(1)]\ \n Projection: test2.a, test2.b, test2.c\ \n TableScan: test2, full_filters=[test2.c > UInt32(4)]"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } /// join filter should be completely removed after pushdown @@ -2074,7 +2073,7 @@ mod tests { \n TableScan: test, full_filters=[test.b > UInt32(1)]\ \n Projection: test2.a, test2.b, test2.c\ \n TableScan: test2, full_filters=[test2.c > UInt32(4)]"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } /// predicate on join key in filter expression should be pushed down to both inputs @@ -2114,7 +2113,7 @@ mod 
tests { \n TableScan: test, full_filters=[test.a > UInt32(1)]\ \n Projection: test2.b\ \n TableScan: test2, full_filters=[test2.b > UInt32(1)]"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } /// single table predicate parts of ON condition should be pushed to right input @@ -2157,7 +2156,7 @@ mod tests { \n TableScan: test\ \n Projection: test2.a, test2.b, test2.c\ \n TableScan: test2, full_filters=[test2.c > UInt32(4)]"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } /// single table predicate parts of ON condition should be pushed to left input @@ -2200,7 +2199,7 @@ mod tests { \n TableScan: test, full_filters=[test.a > UInt32(1)]\ \n Projection: test2.a, test2.b, test2.c\ \n TableScan: test2"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } /// single table predicate parts of ON condition should not be pushed @@ -2238,7 +2237,7 @@ mod tests { ); let expected = &format!("{plan:?}"); - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } struct PushDownProvider { @@ -2297,7 +2296,7 @@ mod tests { let expected = "\ TableScan: test, full_filters=[a = Int64(1)]"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } #[test] @@ -2308,7 +2307,7 @@ mod tests { let expected = "\ Filter: a = Int64(1)\ \n TableScan: test, partial_filters=[a = Int64(1)]"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } #[test] @@ -2327,7 +2326,7 @@ mod tests { // Optimizing the same plan multiple times should produce the same plan // each time. - assert_optimized_plan_eq(&optimised_plan, expected) + assert_optimized_plan_eq(optimised_plan, expected) } #[test] @@ -2338,7 +2337,7 @@ mod tests { let expected = "\ Filter: a = Int64(1)\ \n TableScan: test"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } #[test] @@ -2367,7 +2366,7 @@ mod tests { \n Filter: a = Int64(10) AND b > Int64(11)\ \n TableScan: test projection=[a], partial_filters=[a = Int64(10), b > Int64(11)]"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } #[test] @@ -2398,7 +2397,7 @@ Projection: a, b "# .trim(); - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } #[test] @@ -2426,7 +2425,7 @@ Projection: a, b \n TableScan: test, full_filters=[test.a > Int64(10), test.c > Int64(10)]\ "; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } #[test] @@ -2458,7 +2457,7 @@ Projection: a, b \n TableScan: test, full_filters=[test.a > Int64(10), test.c > Int64(10)]\ "; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } #[test] @@ -2483,7 +2482,7 @@ Projection: a, b Projection: test.a AS b, test.c AS d\ \n TableScan: test, full_filters=[test.a > Int64(10), test.c > Int64(10)]"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } /// predicate on join key in filter expression should be pushed down to both inputs @@ -2523,7 +2522,7 @@ Projection: a, b \n TableScan: test, full_filters=[test.a > UInt32(1)]\ \n Projection: test2.b AS d\ \n TableScan: test2, full_filters=[test2.b > UInt32(1)]"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } #[test] @@ -2552,7 +2551,7 @@ Projection: a, b Projection: test.a AS b, test.c\ \n TableScan: test, full_filters=[test.a IN 
([UInt32(1), UInt32(2), UInt32(3), UInt32(4)])]"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } #[test] @@ -2584,7 +2583,7 @@ Projection: a, b \n Projection: test.a AS b, test.c\ \n TableScan: test, full_filters=[test.a IN ([UInt32(1), UInt32(2), UInt32(3), UInt32(4)])]"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } #[test] @@ -2593,7 +2592,7 @@ Projection: a, b // but we rename it as 'b', and use col 'b' in subquery filter let table_scan = test_table_scan()?; let table_scan_sq = test_table_scan_with_name("sq")?; - let subplan = Arc::new( + let subplan = Box::new( LogicalPlanBuilder::from(table_scan_sq) .project(vec![col("c")])? .build()?, @@ -2620,7 +2619,7 @@ Projection: a, b \n Subquery:\ \n Projection: sq.c\ \n TableScan: sq"; - assert_optimized_plan_eq(&plan, expected_after) + assert_optimized_plan_eq(plan, expected_after) } #[test] @@ -2653,7 +2652,7 @@ Projection: a, b \n Projection: Int64(0) AS a\ \n Filter: Int64(0) = Int64(1)\ \n EmptyRelation"; - assert_optimized_plan_eq(&plan, expected_after) + assert_optimized_plan_eq(plan, expected_after) } #[test] @@ -2681,14 +2680,14 @@ Projection: a, b \n TableScan: test, full_filters=[test.b > UInt32(1) OR test.c < UInt32(10)]\ \n Projection: test1.a AS d, test1.a AS e\ \n TableScan: test1"; - assert_optimized_plan_eq_with_rewrite_predicate(&plan, expected)?; + assert_optimized_plan_eq_with_rewrite_predicate(plan.clone(), expected)?; // Originally global state which can help to avoid duplicate Filters been generated and pushed down. // Now the global state is removed. Need to double confirm that avoid duplicate Filters. let optimized_plan = PushDownFilter::new() .try_optimize(&plan, &OptimizerContext::new())? .expect("failed to optimize plan"); - assert_optimized_plan_eq(&optimized_plan, expected) + assert_optimized_plan_eq(optimized_plan, expected) } #[test] @@ -2729,7 +2728,7 @@ Projection: a, b \n TableScan: test1, full_filters=[test1.b > UInt32(1)]\ \n Projection: test2.a, test2.b\ \n TableScan: test2, full_filters=[test2.b > UInt32(2)]"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } #[test] @@ -2770,7 +2769,7 @@ Projection: a, b \n TableScan: test1, full_filters=[test1.b > UInt32(1)]\ \n Projection: test2.a, test2.b\ \n TableScan: test2, full_filters=[test2.b > UInt32(2)]"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } #[test] @@ -2816,7 +2815,7 @@ Projection: a, b \n TableScan: test1\ \n Projection: test2.a, test2.b\ \n TableScan: test2, full_filters=[test2.b > UInt32(2)]"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } #[test] @@ -2861,7 +2860,7 @@ Projection: a, b \n TableScan: test1, full_filters=[test1.b > UInt32(1)]\ \n Projection: test2.a, test2.b\ \n TableScan: test2"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } #[test] @@ -2894,7 +2893,7 @@ Projection: a, b \n Projection: test1.a, SUM(test1.b), random() + Int32(1) AS r\ \n Aggregate: groupBy=[[test1.a]], aggr=[[SUM(test1.b)]]\ \n TableScan: test1, full_filters=[test1.a > Int32(5)]"; - assert_optimized_plan_eq(&plan, expected_after) + assert_optimized_plan_eq(plan, expected_after) } #[test] @@ -2936,6 +2935,6 @@ Projection: a, b \n Inner Join: test1.a = test2.a\ \n TableScan: test1\ \n TableScan: test2"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } } diff --git 
a/datafusion/optimizer/src/push_down_limit.rs b/datafusion/optimizer/src/push_down_limit.rs index 33d02d5c5628..bfcd471aedc7 100644 --- a/datafusion/optimizer/src/push_down_limit.rs +++ b/datafusion/optimizer/src/push_down_limit.rs @@ -18,8 +18,6 @@ //! Optimizer rule to push down LIMIT in the query plan //! It will push down through projection, limits (taking the smaller limit) -use std::sync::Arc; - use crate::optimizer::ApplyOrder; use crate::{OptimizerConfig, OptimizerRule}; @@ -96,7 +94,7 @@ impl OptimizerRule for PushDownLimit { let plan = LogicalPlan::Limit(Limit { skip: child.skip + parent_skip, fetch: new_fetch, - input: Arc::new((*child.input).clone()), + input: Box::new((*child.input).clone()), }); return self .try_optimize(&plan, _config) @@ -132,13 +130,13 @@ impl OptimizerRule for PushDownLimit { .inputs .iter() .map(|x| { - Ok(Arc::new(LogicalPlan::Limit(Limit { + LogicalPlan::Limit(Limit { skip: 0, fetch: Some(fetch + skip), - input: x.clone(), - }))) + input: Box::new(x.clone()), + }) }) - .collect::>()?; + .collect::>(); let union = LogicalPlan::Union(Union { inputs: new_inputs, schema: union.schema.clone(), @@ -159,8 +157,8 @@ impl OptimizerRule for PushDownLimit { input: cross_join.right.clone(), }); let new_cross_join = LogicalPlan::CrossJoin(CrossJoin { - left: Arc::new(new_left), - right: Arc::new(new_right), + left: Box::new(new_left), + right: Box::new(new_right), schema: plan.schema().clone(), }); plan.with_new_exprs(plan.expressions(), vec![new_cross_join]) @@ -242,7 +240,7 @@ fn push_down_join(join: &Join, limit: usize) -> Option { (None, None) => None, _ => { let left = match left_limit { - Some(limit) => Arc::new(LogicalPlan::Limit(Limit { + Some(limit) => Box::new(LogicalPlan::Limit(Limit { skip: 0, fetch: Some(limit), input: join.left.clone(), @@ -250,7 +248,7 @@ fn push_down_join(join: &Join, limit: usize) -> Option { None => join.left.clone(), }; let right = match right_limit { - Some(limit) => Arc::new(LogicalPlan::Limit(Limit { + Some(limit) => Box::new(LogicalPlan::Limit(Limit { skip: 0, fetch: Some(limit), input: join.right.clone(), @@ -273,7 +271,7 @@ fn push_down_join(join: &Join, limit: usize) -> Option { #[cfg(test)] mod test { - use std::vec; + use std::{sync::Arc, vec}; use super::*; use crate::test::*; @@ -284,7 +282,7 @@ mod test { max, }; - fn assert_optimized_plan_equal(plan: &LogicalPlan, expected: &str) -> Result<()> { + fn assert_optimized_plan_equal(plan: LogicalPlan, expected: &str) -> Result<()> { assert_optimized_plan_eq(Arc::new(PushDownLimit::new()), plan, expected) } @@ -303,7 +301,7 @@ mod test { \n Limit: skip=0, fetch=1000\ \n TableScan: test, fetch=1000"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -321,7 +319,7 @@ mod test { let expected = "Limit: skip=0, fetch=10\ \n TableScan: test, fetch=10"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -338,7 +336,7 @@ mod test { \n Aggregate: groupBy=[[test.a]], aggr=[[MAX(test.b)]]\ \n TableScan: test"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -358,7 +356,7 @@ mod test { \n Limit: skip=0, fetch=1000\ \n TableScan: test, fetch=1000"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -375,7 +373,7 @@ mod test { \n Sort: test.a, fetch=10\ \n TableScan: test"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } 
#[test] @@ -392,7 +390,7 @@ mod test { \n Sort: test.a, fetch=15\ \n TableScan: test"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -411,7 +409,7 @@ mod test { \n Limit: skip=0, fetch=1000\ \n TableScan: test, fetch=1000"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -426,7 +424,7 @@ mod test { let expected = "Limit: skip=10, fetch=None\ \n TableScan: test"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -444,7 +442,7 @@ mod test { \n Limit: skip=10, fetch=1000\ \n TableScan: test, fetch=1010"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -461,7 +459,7 @@ mod test { \n Limit: skip=10, fetch=990\ \n TableScan: test, fetch=1000"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -478,7 +476,7 @@ mod test { \n Limit: skip=10, fetch=1000\ \n TableScan: test, fetch=1010"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -494,7 +492,7 @@ mod test { let expected = "Limit: skip=10, fetch=10\ \n TableScan: test, fetch=20"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -511,7 +509,7 @@ mod test { \n Aggregate: groupBy=[[test.a]], aggr=[[MAX(test.b)]]\ \n TableScan: test"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -531,7 +529,7 @@ mod test { \n Limit: skip=0, fetch=1010\ \n TableScan: test, fetch=1010"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -555,7 +553,7 @@ mod test { \n TableScan: test\ \n TableScan: test2"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -579,7 +577,7 @@ mod test { \n TableScan: test\ \n TableScan: test2"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -594,7 +592,7 @@ mod test { let outer_query = LogicalPlanBuilder::from(table_scan_2) .project(vec![col("a")])? - .filter(exists(Arc::new(subquery)))? + .filter(exists(Box::new(subquery)))? .limit(10, Some(100))? .build()?; @@ -608,7 +606,7 @@ mod test { \n Projection: test2.a\ \n TableScan: test2"; - assert_optimized_plan_equal(&outer_query, expected) + assert_optimized_plan_equal(outer_query, expected) } #[test] @@ -623,7 +621,7 @@ mod test { let outer_query = LogicalPlanBuilder::from(table_scan_2) .project(vec![col("a")])? - .filter(exists(Arc::new(subquery)))? + .filter(exists(Box::new(subquery)))? .limit(10, Some(100))? 
.build()?; @@ -637,7 +635,7 @@ mod test { \n Projection: test2.a\ \n TableScan: test2"; - assert_optimized_plan_equal(&outer_query, expected) + assert_optimized_plan_equal(outer_query, expected) } #[test] @@ -663,7 +661,7 @@ mod test { \n Limit: skip=0, fetch=1000\ \n TableScan: test2, fetch=1000"; - assert_optimized_plan_equal(&plan, expected)?; + assert_optimized_plan_equal(plan, expected)?; let plan = LogicalPlanBuilder::from(table_scan_1.clone()) .join( @@ -682,7 +680,7 @@ mod test { \n Limit: skip=0, fetch=1000\ \n TableScan: test2, fetch=1000"; - assert_optimized_plan_equal(&plan, expected)?; + assert_optimized_plan_equal(plan, expected)?; let plan = LogicalPlanBuilder::from(table_scan_1.clone()) .join( @@ -701,7 +699,7 @@ mod test { \n Limit: skip=0, fetch=1000\ \n TableScan: test2, fetch=1000"; - assert_optimized_plan_equal(&plan, expected)?; + assert_optimized_plan_equal(plan, expected)?; let plan = LogicalPlanBuilder::from(table_scan_1.clone()) .join( @@ -719,7 +717,7 @@ mod test { \n TableScan: test, fetch=1000\ \n TableScan: test2"; - assert_optimized_plan_equal(&plan, expected)?; + assert_optimized_plan_equal(plan, expected)?; let plan = LogicalPlanBuilder::from(table_scan_1.clone()) .join( @@ -737,7 +735,7 @@ mod test { \n TableScan: test, fetch=1000\ \n TableScan: test2"; - assert_optimized_plan_equal(&plan, expected)?; + assert_optimized_plan_equal(plan, expected)?; let plan = LogicalPlanBuilder::from(table_scan_1.clone()) .join( @@ -755,7 +753,7 @@ mod test { \n Limit: skip=0, fetch=1000\ \n TableScan: test2, fetch=1000"; - assert_optimized_plan_equal(&plan, expected)?; + assert_optimized_plan_equal(plan, expected)?; let plan = LogicalPlanBuilder::from(table_scan_1) .join( @@ -773,7 +771,7 @@ mod test { \n Limit: skip=0, fetch=1000\ \n TableScan: test2, fetch=1000"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -798,7 +796,7 @@ mod test { \n TableScan: test, fetch=1000\ \n TableScan: test2"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -823,7 +821,7 @@ mod test { \n TableScan: test, fetch=1010\ \n TableScan: test2"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -848,7 +846,7 @@ mod test { \n Limit: skip=0, fetch=1000\ \n TableScan: test2, fetch=1000"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -873,7 +871,7 @@ mod test { \n Limit: skip=0, fetch=1010\ \n TableScan: test2, fetch=1010"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -893,7 +891,7 @@ mod test { \n Limit: skip=0, fetch=1000\ \n TableScan: test2, fetch=1000"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -913,7 +911,7 @@ mod test { \n Limit: skip=0, fetch=2000\ \n TableScan: test2, fetch=2000"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -928,7 +926,7 @@ mod test { let expected = "Limit: skip=1000, fetch=0\ \n TableScan: test, fetch=0"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -943,7 +941,7 @@ mod test { let expected = "Limit: skip=1000, fetch=0\ \n TableScan: test, fetch=0"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -960,6 +958,6 @@ mod test { \n Limit: skip=1000, 
fetch=0\ \n TableScan: test, fetch=0"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } } diff --git a/datafusion/optimizer/src/push_down_projection.rs b/datafusion/optimizer/src/push_down_projection.rs index 28b3ff090fe6..e2c01fd45545 100644 --- a/datafusion/optimizer/src/push_down_projection.rs +++ b/datafusion/optimizer/src/push_down_projection.rs @@ -51,7 +51,7 @@ mod tests { let expected = "Aggregate: groupBy=[[]], aggr=[[MAX(test.b)]]\ \n TableScan: test projection=[b]"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } #[test] @@ -65,7 +65,7 @@ mod tests { let expected = "Aggregate: groupBy=[[test.c]], aggr=[[MAX(test.b)]]\ \n TableScan: test projection=[b, c]"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } #[test] @@ -81,7 +81,7 @@ mod tests { \n SubqueryAlias: a\ \n TableScan: test projection=[b, c]"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } #[test] @@ -98,7 +98,7 @@ mod tests { \n Filter: test.c > Int32(1)\ \n TableScan: test projection=[b, c]"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } #[test] @@ -123,7 +123,7 @@ mod tests { Aggregate: groupBy=[[]], aggr=[[MAX(m4.tag.one) AS tag.one]]\ \n TableScan: m4 projection=[tag.one]"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } #[test] @@ -137,7 +137,7 @@ mod tests { let expected = "Projection: test.a, test.c, test.b\ \n TableScan: test projection=[a, b, c]"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } #[test] @@ -147,7 +147,7 @@ mod tests { let plan = table_scan(Some("test"), &schema, Some(vec![1, 0, 2]))?.build()?; let expected = "TableScan: test projection=[b, a, c]"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } #[test] @@ -160,7 +160,7 @@ mod tests { let expected = "Projection: test.a, test.b\ \n TableScan: test projection=[b, a]"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } #[test] @@ -173,7 +173,7 @@ mod tests { let expected = "Projection: test.c, test.b, test.a\ \n TableScan: test projection=[a, b, c]"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } #[test] @@ -195,7 +195,7 @@ mod tests { \n Filter: test.c > Int32(1)\ \n Projection: test.c, test.b, test.a\ \n TableScan: test projection=[a, b, c]"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } #[test] @@ -215,7 +215,7 @@ mod tests { \n TableScan: test projection=[a, b]\ \n TableScan: test2 projection=[c1]"; - let optimized_plan = optimize(&plan)?; + let optimized_plan = optimize(plan)?; let formatted_plan = format!("{optimized_plan:?}"); assert_eq!(formatted_plan, expected); @@ -258,7 +258,7 @@ mod tests { \n TableScan: test projection=[a, b]\ \n TableScan: test2 projection=[c1]"; - let optimized_plan = optimize(&plan)?; + let optimized_plan = optimize(plan)?; let formatted_plan = format!("{optimized_plan:?}"); assert_eq!(formatted_plan, expected); @@ -299,7 +299,7 @@ mod tests { \n TableScan: test projection=[a, b]\ \n TableScan: test2 projection=[a]"; - let optimized_plan = optimize(&plan)?; + let optimized_plan = optimize(plan)?; let formatted_plan = format!("{optimized_plan:?}"); assert_eq!(formatted_plan, expected); @@ -334,7 +334,7 @@ mod tests { let expected = "Projection: CAST(test.c AS Float64)\ 
\n TableScan: test projection=[c]"; - assert_optimized_plan_eq(&projection, expected) + assert_optimized_plan_eq(projection, expected) } #[test] @@ -350,7 +350,7 @@ mod tests { let expected = "TableScan: test projection=[a, b]"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } #[test] @@ -365,13 +365,13 @@ mod tests { // relation is `None`). PlanBuilder resolves the expressions let expr = vec![col("test.a"), col("test.b")]; let plan = - LogicalPlan::Projection(Projection::try_new(expr, Arc::new(table_scan))?); + LogicalPlan::Projection(Projection::try_new(expr, Box::new(table_scan))?); assert_fields_eq(&plan, vec!["a", "b"]); let expected = "TableScan: test projection=[a, b]"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } #[test] @@ -391,7 +391,7 @@ mod tests { \n Projection: test.c, test.a\ \n TableScan: test projection=[a, c]"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } #[test] @@ -400,7 +400,7 @@ mod tests { let plan = LogicalPlanBuilder::from(table_scan).build()?; // should expand projection to all columns without projection let expected = "TableScan: test projection=[a, b, c]"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } #[test] @@ -411,7 +411,7 @@ mod tests { .build()?; let expected = "Projection: Int64(1), Int64(2)\ \n TableScan: test projection=[]"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } /// tests that it removes unused columns in projections @@ -430,14 +430,14 @@ mod tests { assert_fields_eq(&plan, vec!["c", "MAX(test.a)"]); - let plan = optimize(&plan).expect("failed to optimize plan"); + let plan = optimize(plan).expect("failed to optimize plan"); let expected = "\ Aggregate: groupBy=[[test.c]], aggr=[[MAX(test.a)]]\ \n Filter: test.c > Int32(1)\ \n Projection: test.c, test.a\ \n TableScan: test projection=[a, c]"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } /// tests that it removes un-needed projections @@ -459,7 +459,7 @@ mod tests { Projection: Int32(1) AS a\ \n TableScan: test projection=[]"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } #[test] @@ -488,7 +488,7 @@ mod tests { Projection: Int32(1) AS a\ \n TableScan: test projection=[], full_filters=[b = Int32(1)]"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } /// tests that optimizing twice yields same plan @@ -501,12 +501,12 @@ mod tests { .project(vec![lit(1).alias("a")])? 
.build()?; - let optimized_plan1 = optimize(&plan).expect("failed to optimize plan"); - let optimized_plan2 = - optimize(&optimized_plan1).expect("failed to optimize plan"); - + let optimized_plan1 = optimize(plan).expect("failed to optimize plan"); let formatted_plan1 = format!("{optimized_plan1:?}"); + + let optimized_plan2 = optimize(optimized_plan1).expect("failed to optimize plan"); let formatted_plan2 = format!("{optimized_plan2:?}"); + assert_eq!(formatted_plan1, formatted_plan2); Ok(()) } @@ -532,7 +532,7 @@ mod tests { \n Aggregate: groupBy=[[test.a, test.c]], aggr=[[MAX(test.b)]]\ \n TableScan: test projection=[a, b, c]"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } #[test] @@ -558,7 +558,7 @@ mod tests { let expected = "Aggregate: groupBy=[[test.a]], aggr=[[COUNT(test.b), COUNT(test.b) FILTER (WHERE test.c > Int32(42)) AS count2]]\ \n TableScan: test projection=[a, b, c]"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } #[test] @@ -575,7 +575,7 @@ mod tests { \n Distinct:\ \n TableScan: test projection=[a, b]"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } #[test] @@ -614,17 +614,17 @@ mod tests { \n WindowAggr: windowExpr=[[MAX(test.a) PARTITION BY [test.b] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING]]\ \n TableScan: test projection=[a, b]"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } - fn assert_optimized_plan_eq(plan: &LogicalPlan, expected: &str) -> Result<()> { + fn assert_optimized_plan_eq(plan: LogicalPlan, expected: &str) -> Result<()> { let optimized_plan = optimize(plan).expect("failed to optimize plan"); let formatted_plan = format!("{optimized_plan:?}"); assert_eq!(formatted_plan, expected); Ok(()) } - fn optimize(plan: &LogicalPlan) -> Result { + fn optimize(plan: LogicalPlan) -> Result { let optimizer = Optimizer::with_rules(vec![Arc::new(OptimizeProjections::new())]); let optimized_plan = optimizer .optimize_recursively( @@ -632,7 +632,7 @@ mod tests { plan, &OptimizerContext::new(), )? 
- .unwrap_or_else(|| plan.clone()); + .data; Ok(optimized_plan) } } diff --git a/datafusion/optimizer/src/replace_distinct_aggregate.rs b/datafusion/optimizer/src/replace_distinct_aggregate.rs index 0666c324d12c..c4b70e929831 100644 --- a/datafusion/optimizer/src/replace_distinct_aggregate.rs +++ b/datafusion/optimizer/src/replace_distinct_aggregate.rs @@ -174,7 +174,7 @@ mod tests { assert_optimized_plan_eq( Arc::new(ReplaceDistinctWithAggregate::new()), - &plan, + plan, expected, ) } @@ -197,7 +197,7 @@ mod tests { assert_optimized_plan_eq( Arc::new(ReplaceDistinctWithAggregate::new()), - &plan, + plan, expected, ) } diff --git a/datafusion/optimizer/src/scalar_subquery_to_join.rs b/datafusion/optimizer/src/scalar_subquery_to_join.rs index 8acc36e479ca..d7b27af9234b 100644 --- a/datafusion/optimizer/src/scalar_subquery_to_join.rs +++ b/datafusion/optimizer/src/scalar_subquery_to_join.rs @@ -392,7 +392,7 @@ mod tests { /// Test multiple correlated subqueries #[test] fn multiple_subqueries() -> Result<()> { - let orders = Arc::new( + let orders = Box::new( LogicalPlanBuilder::from(scan_tpch_table("orders")) .filter( col("orders.o_custkey") @@ -436,7 +436,7 @@ mod tests { /// Test recursive correlated subqueries #[test] fn recursive_subqueries() -> Result<()> { - let lineitem = Arc::new( + let lineitem = Box::new( LogicalPlanBuilder::from(scan_tpch_table("lineitem")) .filter( col("lineitem.l_orderkey") @@ -450,7 +450,7 @@ mod tests { .build()?, ); - let orders = Arc::new( + let orders = Box::new( LogicalPlanBuilder::from(scan_tpch_table("orders")) .filter( col("orders.o_custkey") @@ -492,7 +492,7 @@ mod tests { /// Test for correlated scalar subquery filter with additional subquery filters #[test] fn scalar_subquery_with_subquery_filters() -> Result<()> { - let sq = Arc::new( + let sq = Box::new( LogicalPlanBuilder::from(scan_tpch_table("orders")) .filter( out_ref_col(DataType::Int64, "customer.c_custkey") @@ -530,7 +530,7 @@ mod tests { /// Test for correlated scalar subquery with no columns in schema #[test] fn scalar_subquery_no_cols() -> Result<()> { - let sq = Arc::new( + let sq = Box::new( LogicalPlanBuilder::from(scan_tpch_table("orders")) .filter( out_ref_col(DataType::Int64, "customer.c_custkey") @@ -566,7 +566,7 @@ mod tests { /// Test for scalar subquery with both columns in schema #[test] fn scalar_subquery_with_no_correlated_cols() -> Result<()> { - let sq = Arc::new( + let sq = Box::new( LogicalPlanBuilder::from(scan_tpch_table("orders")) .filter(col("orders.o_custkey").eq(col("orders.o_custkey")))? .aggregate(Vec::::new(), vec![max(col("orders.o_custkey"))])? 
@@ -600,7 +600,7 @@ mod tests { /// Test for correlated scalar subquery not equal #[test] fn scalar_subquery_where_not_eq() -> Result<()> { - let sq = Arc::new( + let sq = Box::new( LogicalPlanBuilder::from(scan_tpch_table("orders")) .filter( out_ref_col(DataType::Int64, "customer.c_custkey") @@ -627,7 +627,7 @@ mod tests { /// Test for correlated scalar subquery less than #[test] fn scalar_subquery_where_less_than() -> Result<()> { - let sq = Arc::new( + let sq = Box::new( LogicalPlanBuilder::from(scan_tpch_table("orders")) .filter( out_ref_col(DataType::Int64, "customer.c_custkey") @@ -654,7 +654,7 @@ mod tests { /// Test for correlated scalar subquery filter with subquery disjunction #[test] fn scalar_subquery_with_subquery_disjunction() -> Result<()> { - let sq = Arc::new( + let sq = Box::new( LogicalPlanBuilder::from(scan_tpch_table("orders")) .filter( out_ref_col(DataType::Int64, "customer.c_custkey") @@ -682,7 +682,7 @@ mod tests { /// Test for correlated scalar without projection #[test] fn scalar_subquery_no_projection() -> Result<()> { - let sq = Arc::new( + let sq = Box::new( LogicalPlanBuilder::from(scan_tpch_table("orders")) .filter(col("customer.c_custkey").eq(col("orders.o_custkey")))? .build()?, @@ -703,7 +703,7 @@ mod tests { /// Test for correlated scalar expressions #[test] fn scalar_subquery_project_expr() -> Result<()> { - let sq = Arc::new( + let sq = Box::new( LogicalPlanBuilder::from(scan_tpch_table("orders")) .filter( out_ref_col(DataType::Int64, "customer.c_custkey") @@ -739,7 +739,7 @@ mod tests { /// Test for correlated scalar subquery multiple projected columns #[test] fn scalar_subquery_multi_col() -> Result<()> { - let sq = Arc::new( + let sq = Box::new( LogicalPlanBuilder::from(scan_tpch_table("orders")) .filter(col("customer.c_custkey").eq(col("orders.o_custkey")))? .project(vec![col("orders.o_custkey"), col("orders.o_orderkey")])? @@ -765,7 +765,7 @@ mod tests { /// Test for correlated scalar subquery filter with additional filters #[test] fn scalar_subquery_additional_filters_with_non_equal_clause() -> Result<()> { - let sq = Arc::new( + let sq = Box::new( LogicalPlanBuilder::from(scan_tpch_table("orders")) .filter( out_ref_col(DataType::Int64, "customer.c_custkey") @@ -804,7 +804,7 @@ mod tests { #[test] fn scalar_subquery_additional_filters_with_equal_clause() -> Result<()> { - let sq = Arc::new( + let sq = Box::new( LogicalPlanBuilder::from(scan_tpch_table("orders")) .filter( out_ref_col(DataType::Int64, "customer.c_custkey") @@ -844,7 +844,7 @@ mod tests { /// Test for correlated scalar subquery filter with disjustions #[test] fn scalar_subquery_disjunction() -> Result<()> { - let sq = Arc::new( + let sq = Box::new( LogicalPlanBuilder::from(scan_tpch_table("orders")) .filter( out_ref_col(DataType::Int64, "customer.c_custkey") @@ -884,7 +884,7 @@ mod tests { /// Test for correlated scalar subquery filter #[test] fn exists_subquery_correlated() -> Result<()> { - let sq = Arc::new( + let sq = Box::new( LogicalPlanBuilder::from(test_table_scan_with_name("sq")?) .filter(out_ref_col(DataType::UInt32, "test.a").eq(col("sq.a")))? .aggregate(Vec::::new(), vec![min(col("c"))])? @@ -917,7 +917,7 @@ mod tests { /// Test for non-correlated scalar subquery with no filters #[test] fn scalar_subquery_non_correlated_no_filters_with_non_equal_clause() -> Result<()> { - let sq = Arc::new( + let sq = Box::new( LogicalPlanBuilder::from(scan_tpch_table("orders")) .aggregate(Vec::::new(), vec![max(col("orders.o_custkey"))])? .project(vec![max(col("orders.o_custkey"))])? 
@@ -948,7 +948,7 @@ mod tests { #[test] fn scalar_subquery_non_correlated_no_filters_with_equal_clause() -> Result<()> { - let sq = Arc::new( + let sq = Box::new( LogicalPlanBuilder::from(scan_tpch_table("orders")) .aggregate(Vec::::new(), vec![max(col("orders.o_custkey"))])? .project(vec![max(col("orders.o_custkey"))])? @@ -979,7 +979,7 @@ mod tests { #[test] fn correlated_scalar_subquery_in_between_clause() -> Result<()> { - let sq1 = Arc::new( + let sq1 = Box::new( LogicalPlanBuilder::from(scan_tpch_table("orders")) .filter( out_ref_col(DataType::Int64, "customer.c_custkey") @@ -989,7 +989,7 @@ mod tests { .project(vec![min(col("orders.o_custkey"))])? .build()?, ); - let sq2 = Arc::new( + let sq2 = Box::new( LogicalPlanBuilder::from(scan_tpch_table("orders")) .filter( out_ref_col(DataType::Int64, "customer.c_custkey") @@ -1036,13 +1036,13 @@ mod tests { #[test] fn uncorrelated_scalar_subquery_in_between_clause() -> Result<()> { - let sq1 = Arc::new( + let sq1 = Box::new( LogicalPlanBuilder::from(scan_tpch_table("orders")) .aggregate(Vec::::new(), vec![min(col("orders.o_custkey"))])? .project(vec![min(col("orders.o_custkey"))])? .build()?, ); - let sq2 = Arc::new( + let sq2 = Box::new( LogicalPlanBuilder::from(scan_tpch_table("orders")) .aggregate(Vec::::new(), vec![max(col("orders.o_custkey"))])? .project(vec![max(col("orders.o_custkey"))])? diff --git a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs index 1cbe7decf15b..1f8fd4ad8789 100644 --- a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs +++ b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs @@ -3415,7 +3415,7 @@ mod tests { col("c1").not_eq(lit(1)).and(col("c1").not_eq(lit(2))) ); - let subquery = Arc::new(test_table_scan_with_name("test").unwrap()); + let subquery = Box::new(test_table_scan_with_name("test").unwrap()); assert_eq!( simplify(in_list( col("c1"), @@ -3434,9 +3434,9 @@ mod tests { ); let subquery1 = - scalar_subquery(Arc::new(test_table_scan_with_name("test1").unwrap())); + scalar_subquery(Box::new(test_table_scan_with_name("test1").unwrap())); let subquery2 = - scalar_subquery(Arc::new(test_table_scan_with_name("test2").unwrap())); + scalar_subquery(Box::new(test_table_scan_with_name("test2").unwrap())); // c1 NOT IN (, ) -> c1 != AND c1 != assert_eq!( diff --git a/datafusion/optimizer/src/simplify_expressions/simplify_exprs.rs b/datafusion/optimizer/src/simplify_expressions/simplify_exprs.rs index 70b163acc208..c906510bde18 100644 --- a/datafusion/optimizer/src/simplify_expressions/simplify_exprs.rs +++ b/datafusion/optimizer/src/simplify_expressions/simplify_exprs.rs @@ -19,7 +19,9 @@ use std::sync::Arc; -use datafusion_common::{DFSchema, DFSchemaRef, Result}; +use datafusion_common::{ + internal_err, tree_node::Transformed, DFSchema, DFSchemaRef, Result, +}; use datafusion_expr::execution_props::ExecutionProps; use datafusion_expr::logical_plan::LogicalPlan; use datafusion_expr::simplify::SimplifyContext; @@ -52,23 +54,35 @@ impl OptimizerRule for SimplifyExpressions { fn try_optimize( &self, - plan: &LogicalPlan, - config: &dyn OptimizerConfig, + _plan: &LogicalPlan, + _config: &dyn OptimizerConfig, ) -> Result> { + internal_err!("SimplifyExpressions implements `try_optimize_owned`") + } + + fn support_owned(&self) -> bool { + true + } + + fn try_optimize_owned( + &self, + plan: LogicalPlan, + config: &dyn OptimizerConfig, + ) -> Result> { let mut execution_props = ExecutionProps::new(); 
execution_props.query_execution_start_time = config.query_execution_start_time(); - Ok(Some(Self::optimize_internal(plan, &execution_props)?)) + Self::optimize_internal(plan, &execution_props) } } impl SimplifyExpressions { fn optimize_internal( - plan: &LogicalPlan, + plan: LogicalPlan, execution_props: &ExecutionProps, - ) -> Result { + ) -> Result> { let schema = if !plan.inputs().is_empty() { DFSchemaRef::new(merge_schema(plan.inputs())) - } else if let LogicalPlan::TableScan(scan) = plan { + } else if let LogicalPlan::TableScan(scan) = &plan { // When predicates are pushed into a table scan, there is no input // schema to resolve predicates against, so it must be handled specially // @@ -88,11 +102,15 @@ impl SimplifyExpressions { }; let info = SimplifyContext::new(execution_props).with_schema(schema); - let new_inputs = plan - .inputs() - .iter() - .map(|input| Self::optimize_internal(input, execution_props)) - .collect::>>()?; + let mut is_transformed = false; + + let plan = plan.rewrite_inputs(|input| { + let t = Self::optimize_internal(input, execution_props)?; + if t.transformed { + is_transformed = true; + } + Ok(t.data) + })?; let simplifier = ExprSimplifier::new(info); @@ -109,18 +127,34 @@ impl SimplifyExpressions { simplifier }; - let exprs = plan - .expressions() - .into_iter() - .map(|e| { - // TODO: unify with `rewrite_preserving_name` - let original_name = e.name_for_alias()?; - let new_e = simplifier.simplify(e)?; - new_e.alias_if_changed(original_name) + let has_no_alias = matches!(plan, LogicalPlan::Filter(_)); + + let plan = plan.rewrite_exprs(|e| { + if has_no_alias { + // no aliasing for filters + return simplifier.simplify(e); + } + + // TODO: unify with `rewrite_preserving_name` + let original_name = e.name_for_alias()?; + // TODO: Track if `simplify` transform the expression + let new_e = simplifier.simplify(e)?; + + // alias if the name has changed + let new_name = new_e.name_for_alias()?; + Ok(if new_name == original_name { + new_e + } else { + is_transformed = true; + new_e.alias(original_name) }) - .collect::>>()?; + })?; - plan.with_new_exprs(exprs, new_inputs) + if is_transformed { + Ok(Transformed::yes(plan)) + } else { + Ok(Transformed::no(plan)) + } } } diff --git a/datafusion/optimizer/src/single_distinct_to_groupby.rs b/datafusion/optimizer/src/single_distinct_to_groupby.rs index 07a9d84f7d48..e7777dbf35ae 100644 --- a/datafusion/optimizer/src/single_distinct_to_groupby.rs +++ b/datafusion/optimizer/src/single_distinct_to_groupby.rs @@ -272,14 +272,14 @@ impl OptimizerRule for SingleDistinctToGroupBy { .collect(); let outer_aggr = LogicalPlan::Aggregate(Aggregate::try_new( - Arc::new(inner_agg), + Box::new(inner_agg), outer_group_exprs, outer_aggr_exprs, )?); Ok(Some(LogicalPlan::Projection(Projection::try_new( alias_expr, - Arc::new(outer_aggr), + Box::new(outer_aggr), )?))) } else { Ok(None) @@ -309,7 +309,7 @@ mod tests { min, sum, AggregateFunction, }; - fn assert_optimized_plan_equal(plan: &LogicalPlan, expected: &str) -> Result<()> { + fn assert_optimized_plan_equal(plan: LogicalPlan, expected: &str) -> Result<()> { assert_optimized_plan_eq_display_indent( Arc::new(SingleDistinctToGroupBy::new()), plan, @@ -331,7 +331,7 @@ mod tests { "Aggregate: groupBy=[[]], aggr=[[MAX(test.b)]] [MAX(test.b):UInt32;N]\ \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -348,7 +348,7 @@ mod tests { \n Aggregate: groupBy=[[test.b AS alias1]], aggr=[[]] 
[alias1:UInt32]\ \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } // Currently this optimization is disabled for CUBE/ROLLUP/GROUPING SET @@ -369,7 +369,7 @@ mod tests { let expected = "Aggregate: groupBy=[[GROUPING SETS ((test.a), (test.b))]], aggr=[[COUNT(DISTINCT test.c)]] [a:UInt32;N, b:UInt32;N, COUNT(DISTINCT test.c):Int64;N]\ \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } // Currently this optimization is disabled for CUBE/ROLLUP/GROUPING SET @@ -387,7 +387,7 @@ mod tests { let expected = "Aggregate: groupBy=[[CUBE (test.a, test.b)]], aggr=[[COUNT(DISTINCT test.c)]] [a:UInt32;N, b:UInt32;N, COUNT(DISTINCT test.c):Int64;N]\ \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } // Currently this optimization is disabled for CUBE/ROLLUP/GROUPING SET @@ -406,7 +406,7 @@ mod tests { let expected = "Aggregate: groupBy=[[ROLLUP (test.a, test.b)]], aggr=[[COUNT(DISTINCT test.c)]] [a:UInt32;N, b:UInt32;N, COUNT(DISTINCT test.c):Int64;N]\ \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -422,7 +422,7 @@ mod tests { \n Aggregate: groupBy=[[Int32(2) * test.b AS alias1]], aggr=[[]] [alias1:Int32]\ \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -439,7 +439,7 @@ mod tests { \n Aggregate: groupBy=[[test.a, test.b AS alias1]], aggr=[[]] [a:UInt32, alias1:UInt32]\ \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -457,7 +457,7 @@ mod tests { let expected = "Aggregate: groupBy=[[test.a]], aggr=[[COUNT(DISTINCT test.b), COUNT(DISTINCT test.c)]] [a:UInt32, COUNT(DISTINCT test.b):Int64;N, COUNT(DISTINCT test.c):Int64;N]\ \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -486,7 +486,7 @@ mod tests { \n Aggregate: groupBy=[[test.a, test.b AS alias1]], aggr=[[]] [a:UInt32, alias1:UInt32]\ \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -504,7 +504,7 @@ mod tests { let expected = "Aggregate: groupBy=[[test.a]], aggr=[[COUNT(DISTINCT test.b), COUNT(test.c)]] [a:UInt32, COUNT(DISTINCT test.b):Int64;N, COUNT(test.c):Int64;N]\ \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -521,7 +521,7 @@ mod tests { \n Aggregate: groupBy=[[test.a + Int32(1) AS group_alias_0, test.c AS alias1]], aggr=[[]] [group_alias_0:Int32, alias1:UInt32]\ \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -551,7 +551,7 @@ mod tests { \n Aggregate: groupBy=[[test.a, test.b AS alias1]], aggr=[[SUM(test.c) AS alias2]] [a:UInt32, alias1:UInt32, alias2:UInt64;N]\ \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } 
     #[test]
@@ -570,7 +570,7 @@ mod tests {
            \n Aggregate: groupBy=[[test.a, test.b AS alias1]], aggr=[[SUM(test.c) AS alias2, MAX(test.c) AS alias3]] [a:UInt32, alias1:UInt32, alias2:UInt64;N, alias3:UInt32;N]\
            \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]";
 
-        assert_optimized_plan_equal(&plan, expected)
+        assert_optimized_plan_equal(plan, expected)
    }
 
     #[test]
@@ -589,7 +589,7 @@ mod tests {
            \n Aggregate: groupBy=[[test.c, test.b AS alias1]], aggr=[[MIN(test.a) AS alias2]] [c:UInt32, alias1:UInt32, alias2:UInt32;N]\
            \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]";
 
-        assert_optimized_plan_equal(&plan, expected)
+        assert_optimized_plan_equal(plan, expected)
    }
 
     #[test]
@@ -612,7 +612,7 @@ mod tests {
         let expected = "Aggregate: groupBy=[[test.c]], aggr=[[SUM(test.a) FILTER (WHERE test.a > Int32(5)), COUNT(DISTINCT test.b)]] [c:UInt32, SUM(test.a) FILTER (WHERE test.a > Int32(5)):UInt64;N, COUNT(DISTINCT test.b):Int64;N]\
            \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]";
 
-        assert_optimized_plan_equal(&plan, expected)
+        assert_optimized_plan_equal(plan, expected)
    }
 
     #[test]
@@ -635,7 +635,7 @@ mod tests {
         let expected = "Aggregate: groupBy=[[test.c]], aggr=[[SUM(test.a), COUNT(DISTINCT test.a) FILTER (WHERE test.a > Int32(5))]] [c:UInt32, SUM(test.a):UInt64;N, COUNT(DISTINCT test.a) FILTER (WHERE test.a > Int32(5)):Int64;N]\
            \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]";
 
-        assert_optimized_plan_equal(&plan, expected)
+        assert_optimized_plan_equal(plan, expected)
    }
 
     #[test]
@@ -658,7 +658,7 @@ mod tests {
         let expected = "Aggregate: groupBy=[[test.c]], aggr=[[SUM(test.a) ORDER BY [test.a], COUNT(DISTINCT test.b)]] [c:UInt32, SUM(test.a) ORDER BY [test.a]:UInt64;N, COUNT(DISTINCT test.b):Int64;N]\
            \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]";
 
-        assert_optimized_plan_equal(&plan, expected)
+        assert_optimized_plan_equal(plan, expected)
    }
 
     #[test]
@@ -681,7 +681,7 @@ mod tests {
         let expected = "Aggregate: groupBy=[[test.c]], aggr=[[SUM(test.a), COUNT(DISTINCT test.a) ORDER BY [test.a]]] [c:UInt32, SUM(test.a):UInt64;N, COUNT(DISTINCT test.a) ORDER BY [test.a]:Int64;N]\
            \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]";
 
-        assert_optimized_plan_equal(&plan, expected)
+        assert_optimized_plan_equal(plan, expected)
    }
 
     #[test]
@@ -704,6 +704,6 @@ mod tests {
         let expected = "Aggregate: groupBy=[[test.c]], aggr=[[SUM(test.a), COUNT(DISTINCT test.a) FILTER (WHERE test.a > Int32(5)) ORDER BY [test.a]]] [c:UInt32, SUM(test.a):UInt64;N, COUNT(DISTINCT test.a) FILTER (WHERE test.a > Int32(5)) ORDER BY [test.a]:Int64;N]\
            \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]";
 
-        assert_optimized_plan_equal(&plan, expected)
+        assert_optimized_plan_equal(plan, expected)
    }
 }
diff --git a/datafusion/optimizer/src/test/mod.rs b/datafusion/optimizer/src/test/mod.rs
index e691fe9a5351..a1ed6e30f00b 100644
--- a/datafusion/optimizer/src/test/mod.rs
+++ b/datafusion/optimizer/src/test/mod.rs
@@ -152,9 +152,10 @@ pub fn assert_analyzer_check_err(
 }
 
 pub fn assert_optimized_plan_eq(
     rule: Arc<dyn OptimizerRule + Send + Sync>,
-    plan: &LogicalPlan,
+    plan: LogicalPlan,
     expected: &str,
 ) -> Result<()> {
+    let prev_schema = plan.schema().clone();
     let optimizer = Optimizer::with_rules(vec![rule.clone()]);
     let optimized_plan = optimizer
         .optimize_recursively(
@@ -162,10 +163,10 @@ pub fn assert_optimized_plan_eq(
             plan,
             &OptimizerContext::new(),
         )?
-        .unwrap_or_else(|| plan.clone());
+        .data;
 
     // Ensure schemas always match after an optimization
-    assert_schema_is_the_same(rule.name(), plan, &optimized_plan)?;
+    assert_schema_is_the_same(rule.name(), prev_schema, &optimized_plan)?;
     let formatted_plan = format!("{optimized_plan:?}");
     assert_eq!(formatted_plan, expected);
@@ -174,9 +175,11 @@ pub fn assert_optimized_plan_eq(
 
 pub fn assert_optimized_plan_eq_with_rules(
     rules: Vec<Arc<dyn OptimizerRule + Send + Sync>>,
-    plan: &LogicalPlan,
+    plan: LogicalPlan,
     expected: &str,
 ) -> Result<()> {
+    let plan_schema = plan.schema().clone();
+
     fn observe(_plan: &LogicalPlan, _rule: &dyn OptimizerRule) {}
     let config = &mut OptimizerContext::new()
         .with_max_passes(1)
@@ -184,16 +187,17 @@ pub fn assert_optimized_plan_eq_with_rules(
     let optimizer = Optimizer::with_rules(rules);
     let optimized_plan = optimizer
         .optimize(plan, config, observe)
-        .expect("failed to optimize plan");
+        .expect("failed to optimize plan")
+        .data;
     let formatted_plan = format!("{optimized_plan:?}");
     assert_eq!(formatted_plan, expected);
-    assert_eq!(plan.schema(), optimized_plan.schema());
+    assert_eq!(&plan_schema, optimized_plan.schema());
     Ok(())
 }
 
 pub fn assert_optimized_plan_eq_display_indent(
     rule: Arc<dyn OptimizerRule + Send + Sync>,
-    plan: &LogicalPlan,
+    plan: LogicalPlan,
     expected: &str,
 ) {
     let optimizer = Optimizer::with_rules(vec![rule]);
@@ -204,7 +208,7 @@ pub fn assert_optimized_plan_eq_display_indent(
             &OptimizerContext::new(),
         )
         .expect("failed to optimize plan")
-        .unwrap_or_else(|| plan.clone());
+        .data;
     let formatted_plan = optimized_plan.display_indent_schema().to_string();
     assert_eq!(formatted_plan, expected);
 }
@@ -218,9 +222,9 @@ pub fn assert_multi_rules_optimized_plan_eq_display_indent(
     let mut optimized_plan = plan.clone();
     for rule in &optimizer.rules {
         optimized_plan = optimizer
-            .optimize_recursively(rule, &optimized_plan, &OptimizerContext::new())
+            .optimize_recursively(rule, optimized_plan, &OptimizerContext::new())
             .expect("failed to optimize plan")
-            .unwrap_or_else(|| optimized_plan.clone());
+            .data;
     }
     let formatted_plan = optimized_plan.display_indent_schema().to_string();
     assert_eq!(formatted_plan, expected);
@@ -228,7 +232,7 @@ pub fn assert_optimizer_err(
     rule: Arc<dyn OptimizerRule + Send + Sync>,
-    plan: &LogicalPlan,
+    plan: LogicalPlan,
     expected: &str,
 ) {
     let optimizer = Optimizer::with_rules(vec![rule]);
@@ -238,7 +242,10 @@
         &OptimizerContext::new(),
     );
     match res {
-        Ok(plan) => assert_eq!(format!("{}", plan.unwrap().display_indent()), "An error"),
+        Ok(plan) => {
+            assert!(plan.transformed);
+            assert_eq!(format!("{}", plan.data.display_indent()), "An error");
+        }
         Err(ref e) => {
             let actual = format!("{e}");
             if expected.is_empty() || !actual.contains(expected) {
@@ -250,19 +257,20 @@ pub fn assert_optimization_skipped(
     rule: Arc<dyn OptimizerRule + Send + Sync>,
-    plan: &LogicalPlan,
+    plan: LogicalPlan,
 ) -> Result<()> {
     let optimizer = Optimizer::with_rules(vec![rule]);
+
+    let displayed_plan = format!("{}", plan.display_indent());
+
     let new_plan = optimizer
         .optimize_recursively(
             optimizer.rules.first().unwrap(),
             plan,
             &OptimizerContext::new(),
         )?
-        .unwrap_or_else(|| plan.clone());
-    assert_eq!(
-        format!("{}", plan.display_indent()),
-        format!("{}", new_plan.display_indent())
-    );
+        .data;
+
+    assert_eq!(displayed_plan, format!("{}", new_plan.display_indent()));
 
     Ok(())
 }
diff --git a/datafusion/optimizer/tests/optimizer_integration.rs b/datafusion/optimizer/tests/optimizer_integration.rs
index acafc0bafaf4..d5c9d2c43779 100644
--- a/datafusion/optimizer/tests/optimizer_integration.rs
+++ b/datafusion/optimizer/tests/optimizer_integration.rs
@@ -315,7 +315,7 @@ fn test_sql(sql: &str) -> Result<LogicalPlan> {
     let optimizer = Optimizer::new();
     // analyze and optimize the logical plan
     let plan = analyzer.execute_and_check(&plan, config.options(), |_, _| {})?;
-    optimizer.optimize(&plan, &config, |_, _| {})
+    optimizer.optimize(plan, &config, |_, _| {}).map(|p| p.data)
 }
 
 #[derive(Default)]
diff --git a/datafusion/proto/src/logical_plan/mod.rs b/datafusion/proto/src/logical_plan/mod.rs
index 9b3b677e3c0a..aac09a141c54 100644
--- a/datafusion/proto/src/logical_plan/mod.rs
+++ b/datafusion/proto/src/logical_plan/mod.rs
@@ -265,7 +265,7 @@ impl AsLogicalPlan for LogicalPlanNode {
                 Some(a) => match a {
                     protobuf::projection_node::OptionalAlias::Alias(alias) => {
                         Ok(LogicalPlan::SubqueryAlias(SubqueryAlias::try_new(
-                            Arc::new(new_proj),
+                            Box::new(new_proj),
                             alias.clone(),
                         )?))
                     }
@@ -590,7 +590,7 @@ impl AsLogicalPlan for LogicalPlanNode {
                         create_view.name.as_ref(),
                         "CreateView",
                     )?,
-                    input: Arc::new(plan),
+                    input: Box::new(plan),
                     or_replace: create_view.or_replace,
                     definition,
                 })))
@@ -853,7 +853,7 @@ impl AsLogicalPlan for LogicalPlanNode {
                 Ok(datafusion_expr::LogicalPlan::Copy(
                     datafusion_expr::dml::CopyTo {
-                        input: Arc::new(input),
+                        input: Box::new(input),
                         output_url: copy.output_url.clone(),
                         partition_by: copy.partition_by.clone(),
                         format_options: convert_required!(copy.format_options)?,
diff --git a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs
index 3c43f100750f..2b85f5df26f6 100644
--- a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs
+++ b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs
@@ -319,7 +319,7 @@ async fn roundtrip_logical_plan_copy_to_sql_options() -> Result<()> {
     table_options.set("format.delimiter", ";")?;
 
     let plan = LogicalPlan::Copy(CopyTo {
-        input: Arc::new(input),
+        input: Box::new(input),
         output_url: "test.csv".to_string(),
         partition_by: vec!["a".to_string(), "b".to_string(), "c".to_string()],
         format_options: FormatOptions::CSV(table_options.csv.clone()),
@@ -353,7 +353,7 @@ async fn roundtrip_logical_plan_copy_to_writer_options() -> Result<()> {
     parquet_format.global.max_row_group_size = 555;
 
     let plan = LogicalPlan::Copy(CopyTo {
-        input: Arc::new(input),
+        input: Box::new(input),
         output_url: "test.parquet".to_string(),
         format_options: FormatOptions::PARQUET(parquet_format.clone()),
         partition_by: vec!["a".to_string(), "b".to_string(), "c".to_string()],
@@ -385,7 +385,7 @@ async fn roundtrip_logical_plan_copy_to_arrow() -> Result<()> {
     let input = create_csv_scan(&ctx).await?;
 
     let plan = LogicalPlan::Copy(CopyTo {
-        input: Arc::new(input),
+        input: Box::new(input),
         output_url: "test.arrow".to_string(),
         partition_by: vec!["a".to_string(), "b".to_string(), "c".to_string()],
         format_options: FormatOptions::ARROW,
@@ -426,7 +426,7 @@ async fn roundtrip_logical_plan_copy_to_csv() -> Result<()> {
     csv_format.null_value = Some("NIL".to_string());
 
     let plan = LogicalPlan::Copy(CopyTo {
-        input: Arc::new(input),
+        input: Box::new(input),
         output_url: "test.csv".to_string(),
         partition_by: vec!["a".to_string(), "b".to_string(), "c".to_string()],
         format_options: FormatOptions::CSV(csv_format.clone()),
diff --git a/datafusion/sql/src/expr/subquery.rs b/datafusion/sql/src/expr/subquery.rs
index d34065d92fe5..5d642d3091f3 100644
--- a/datafusion/sql/src/expr/subquery.rs
+++ b/datafusion/sql/src/expr/subquery.rs
@@ -22,7 +22,6 @@ use datafusion_expr::expr::InSubquery;
 use datafusion_expr::{Expr, Subquery};
 use sqlparser::ast::Expr as SQLExpr;
 use sqlparser::ast::Query;
-use std::sync::Arc;
 
 impl<'a, S: ContextProvider> SqlToRel<'a, S> {
     pub(super) fn parse_exists_subquery(
@@ -39,7 +38,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
         planner_context.set_outer_query_schema(old_outer_query_schema);
         Ok(Expr::Exists(Exists {
             subquery: Subquery {
-                subquery: Arc::new(sub_plan),
+                subquery: Box::new(sub_plan),
                 outer_ref_columns,
             },
             negated,
@@ -63,7 +62,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
         Ok(Expr::InSubquery(InSubquery::new(
             expr,
             Subquery {
-                subquery: Arc::new(sub_plan),
+                subquery: Box::new(sub_plan),
                 outer_ref_columns,
             },
             negated,
@@ -82,7 +81,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
         let outer_ref_columns = sub_plan.all_out_ref_exprs();
         planner_context.set_outer_query_schema(old_outer_query_schema);
         Ok(Expr::ScalarSubquery(Subquery {
-            subquery: Arc::new(sub_plan),
+            subquery: Box::new(sub_plan),
             outer_ref_columns,
         }))
     }
diff --git a/datafusion/sql/src/query.rs b/datafusion/sql/src/query.rs
index ea8edd0771c8..9337671cd05f 100644
--- a/datafusion/sql/src/query.rs
+++ b/datafusion/sql/src/query.rs
@@ -197,7 +197,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
             LogicalPlan::Ddl(DdlStatement::CreateMemoryTable(CreateMemoryTable {
                 name: self.object_name_to_table_reference(select_into.name)?,
                 constraints: Constraints::empty(),
-                input: Arc::new(plan),
+                input: Box::new(plan),
                 if_not_exists: false,
                 or_replace: false,
                 column_defaults: vec![],
diff --git a/datafusion/sql/src/select.rs b/datafusion/sql/src/select.rs
index 1bfd60a8ce1a..96ca52beba11 100644
--- a/datafusion/sql/src/select.rs
+++ b/datafusion/sql/src/select.rs
@@ -16,7 +16,6 @@
 // under the License.
 
 use std::collections::HashSet;
-use std::sync::Arc;
 
 use crate::planner::{ContextProvider, PlannerContext, SqlToRel};
 use crate::utils::{
@@ -373,7 +372,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
                 Ok(LogicalPlan::Filter(Filter::try_new(
                     filter_expr,
-                    Arc::new(plan),
+                    Box::new(plan),
                 )?))
             }
             None => Ok(plan),
diff --git a/datafusion/sql/src/statement.rs b/datafusion/sql/src/statement.rs
index 7717f75d16b8..bfffb4f79d95 100644
--- a/datafusion/sql/src/statement.rs
+++ b/datafusion/sql/src/statement.rs
@@ -263,7 +263,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
                     CreateMemoryTable {
                         name: self.object_name_to_table_reference(name)?,
                         constraints,
-                        input: Arc::new(plan),
+                        input: Box::new(plan),
                         if_not_exists,
                         or_replace,
                         column_defaults,
@@ -286,7 +286,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
                     CreateMemoryTable {
                         name: self.object_name_to_table_reference(name)?,
                         constraints,
-                        input: Arc::new(plan),
+                        input: Box::new(plan),
                         if_not_exists,
                         or_replace,
                         column_defaults,
@@ -322,7 +322,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
                 Ok(LogicalPlan::Ddl(DdlStatement::CreateView(CreateView {
                     name: self.object_name_to_table_reference(name)?,
-                    input: Arc::new(plan),
+                    input: Box::new(plan),
                     or_replace,
                     definition: sql,
                 })))
@@ -431,7 +431,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
                 Ok(LogicalPlan::Prepare(Prepare {
                     name: ident_to_string(&name),
                     data_types,
-                    input: Arc::new(plan),
+                    input: Box::new(plan),
                 }))
             }
@@ -903,7 +903,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
             .collect();
 
         Ok(LogicalPlan::Copy(CopyTo {
-            input: Arc::new(input),
+            input: Box::new(input),
             output_url: statement.target,
             format_options: file_type.into(),
             partition_by,
@@ -1031,7 +1031,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
         if matches!(plan, LogicalPlan::Explain(_)) {
             return plan_err!("Nested EXPLAINs are not supported");
         }
-        let plan = Arc::new(plan);
+        let plan = Box::new(plan);
         let schema = LogicalPlan::explain_schema();
         let schema = schema.to_dfschema_ref()?;
@@ -1187,7 +1187,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
                     &[&[&schema]],
                     &[using_columns],
                 )?;
-                LogicalPlan::Filter(Filter::try_new(filter_expr, Arc::new(scan))?)
+                LogicalPlan::Filter(Filter::try_new(filter_expr, Box::new(scan))?)
             }
         };
@@ -1195,7 +1195,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
             table_name: table_ref,
             table_schema: schema.into(),
             op: WriteOp::Delete,
-            input: Arc::new(source),
+            input: Box::new(source),
         });
         Ok(plan)
     }
@@ -1257,7 +1257,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
                     &[&[scan.schema()]],
                     &[using_columns],
                 )?;
-                LogicalPlan::Filter(Filter::try_new(filter_expr, Arc::new(scan))?)
+                LogicalPlan::Filter(Filter::try_new(filter_expr, Box::new(scan))?)
             }
         };
@@ -1306,7 +1306,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
             table_name,
             table_schema,
             op: WriteOp::Update,
-            input: Arc::new(source),
+            input: Box::new(source),
         });
         Ok(plan)
     }
@@ -1430,7 +1430,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
             table_name,
             table_schema: Arc::new(table_schema),
             op,
-            input: Arc::new(source),
+            input: Box::new(source),
         });
         Ok(plan)
     }
diff --git a/datafusion/substrait/src/logical_plan/consumer.rs b/datafusion/substrait/src/logical_plan/consumer.rs
index ed1e48ca71a6..d76bad1acc46 100644
--- a/datafusion/substrait/src/logical_plan/consumer.rs
+++ b/datafusion/substrait/src/logical_plan/consumer.rs
@@ -574,7 +574,7 @@ pub async fn from_substrait_rel(
             let Some(input) = exchange.input.as_ref() else {
                 return substrait_err!("Unexpected empty input in ExchangeRel");
             };
-            let input = Arc::new(from_substrait_rel(ctx, input, extensions).await?);
+            let input = Box::new(from_substrait_rel(ctx, input, extensions).await?);
 
             let Some(exchange_kind) = &exchange.exchange_kind else {
                 return substrait_err!("Unexpected empty input in ExchangeRel");
@@ -1049,7 +1049,7 @@ pub async fn from_substrait_rex(
                         .clone(),
                 ),
                 subquery: Subquery {
-                    subquery: Arc::new(haystack_expr),
+                    subquery: Box::new(haystack_expr),
                     outer_ref_columns: outer_refs,
                 },
                 negated: false,
diff --git a/datafusion/substrait/src/logical_plan/producer.rs b/datafusion/substrait/src/logical_plan/producer.rs
index a6a38ab6145c..624359abb7e4 100644
--- a/datafusion/substrait/src/logical_plan/producer.rs
+++ b/datafusion/substrait/src/logical_plan/producer.rs
@@ -372,7 +372,7 @@ pub fn to_substrait_rel(
             let input_rels = union
                 .inputs
                 .iter()
-                .map(|input| to_substrait_rel(input.as_ref(), ctx, extension_info))
+                .map(|input| to_substrait_rel(input, ctx, extension_info))
                 .collect::<Result<Vec<_>>>()?
                 .into_iter()
                 .map(|ptr| *ptr)
diff --git a/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs b/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs
index bc9cc66b7626..3f1c61e3cb53 100644
--- a/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs
+++ b/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs
@@ -766,7 +766,7 @@ async fn roundtrip_repartition_roundrobin() -> Result<()> {
     let ctx = create_context().await?;
     let scan_plan = ctx.sql("SELECT * FROM data").await?.into_optimized_plan()?;
     let plan = LogicalPlan::Repartition(Repartition {
-        input: Arc::new(scan_plan),
+        input: Box::new(scan_plan),
         partitioning_scheme: Partitioning::RoundRobinBatch(8),
     });
 
@@ -783,7 +783,7 @@ async fn roundtrip_repartition_hash() -> Result<()> {
     let ctx = create_context().await?;
     let scan_plan = ctx.sql("SELECT * FROM data").await?.into_optimized_plan()?;
     let plan = LogicalPlan::Repartition(Repartition {
-        input: Arc::new(scan_plan),
+        input: Box::new(scan_plan),
         partitioning_scheme: Partitioning::Hash(vec![col("data.a")], 8),
     });
 
diff --git a/docs/src/library_logical_plan.rs b/docs/src/library_logical_plan.rs
index 355003941570..1c8f57aafd24 100644
--- a/docs/src/library_logical_plan.rs
+++ b/docs/src/library_logical_plan.rs
@@ -45,7 +45,7 @@ fn plan_1() -> Result<()> {
     // create a Filter plan that evaluates `id > 500` and wraps the TableScan
     let filter_expr = col("id").gt(lit(500));
-    let plan = LogicalPlan::Filter(Filter::try_new(filter_expr, Arc::new(table_scan))?);
+    let plan = LogicalPlan::Filter(Filter::try_new(filter_expr, Box::new(table_scan))?);
 
     // print the plan
     println!("{}", plan.display_indent_schema());
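
Taken together, the hunks above move Optimizer::optimize (and the test helpers built on it) to a by-value API. A minimal sketch of the resulting calling convention follows; it assumes the Transformed<LogicalPlan> return type that the .data accesses above imply, and uses module paths (datafusion_optimizer::optimizer, datafusion_common::tree_node) that may differ slightly from the final crate layout:

use datafusion_common::tree_node::Transformed;
use datafusion_common::Result;
use datafusion_expr::LogicalPlan;
use datafusion_optimizer::optimizer::{Optimizer, OptimizerContext};

fn run_optimizer(plan: LogicalPlan) -> Result<LogicalPlan> {
    let optimizer = Optimizer::new();
    let config = OptimizerContext::new();
    // the plan is moved into `optimize`; the observer callback keeps its old shape
    let transformed: Transformed<LogicalPlan> =
        optimizer.optimize(plan, &config, |_plan, _rule| {})?;
    // `transformed.transformed` records whether any rule actually rewrote the plan
    Ok(transformed.data)
}
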