diff --git a/datafusion/core/tests/user_defined/user_defined_plan.rs b/datafusion/core/tests/user_defined/user_defined_plan.rs index 487063642345..fae4b2cd82ab 100644 --- a/datafusion/core/tests/user_defined/user_defined_plan.rs +++ b/datafusion/core/tests/user_defined/user_defined_plan.rs @@ -59,6 +59,7 @@ //! use std::fmt::Debug; +use std::hash::Hash; use std::task::{Context, Poll}; use std::{any::Any, collections::BTreeMap, fmt, sync::Arc}; @@ -93,7 +94,7 @@ use datafusion::{ use datafusion_common::config::ConfigOptions; use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; use datafusion_common::ScalarValue; -use datafusion_expr::{FetchType, Projection, SortExpr}; +use datafusion_expr::{FetchType, InvariantLevel, Projection, SortExpr}; use datafusion_optimizer::optimizer::ApplyOrder; use datafusion_optimizer::AnalyzerRule; use datafusion_physical_plan::execution_plan::{Boundedness, EmissionType}; @@ -295,7 +296,119 @@ async fn topk_plan() -> Result<()> { Ok(()) } +#[tokio::test] +/// Run invariant checks on the logical plan extension [`TopKPlanNode`]. +async fn topk_invariants() -> Result<()> { + // Test: pass an InvariantLevel::Always + let pass = InvariantMock { + should_fail_invariant: false, + kind: InvariantLevel::Always, + }; + let ctx = setup_table(make_topk_context_with_invariants(Some(pass))).await?; + run_and_compare_query(ctx, "Topk context").await?; + + // Test: fail an InvariantLevel::Always + let fail = InvariantMock { + should_fail_invariant: true, + kind: InvariantLevel::Always, + }; + let ctx = setup_table(make_topk_context_with_invariants(Some(fail))).await?; + matches!( + &*run_and_compare_query(ctx, "Topk context") + .await + .unwrap_err() + .message(), + "node fails check, such as improper inputs" + ); + + // Test: pass an InvariantLevel::Executable + let pass = InvariantMock { + should_fail_invariant: false, + kind: InvariantLevel::Executable, + }; + let ctx = setup_table(make_topk_context_with_invariants(Some(pass))).await?; + run_and_compare_query(ctx, "Topk context").await?; + + // Test: fail an InvariantLevel::Executable + let fail = InvariantMock { + should_fail_invariant: true, + kind: InvariantLevel::Executable, + }; + let ctx = setup_table(make_topk_context_with_invariants(Some(fail))).await?; + matches!( + &*run_and_compare_query(ctx, "Topk context") + .await + .unwrap_err() + .message(), + "node fails check, such as improper inputs" + ); + + Ok(()) +} + +#[tokio::test] +async fn topk_invariants_after_invalid_mutation() -> Result<()> { + // CONTROL + // Build a valid topK plan. + let config = SessionConfig::new().with_target_partitions(48); + let runtime = Arc::new(RuntimeEnv::default()); + let state = SessionStateBuilder::new() + .with_config(config) + .with_runtime_env(runtime) + .with_default_features() + .with_query_planner(Arc::new(TopKQueryPlanner {})) + // 1. adds a valid TopKPlanNode + .with_optimizer_rule(Arc::new(TopKOptimizerRule { + invariant_mock: Some(InvariantMock { + should_fail_invariant: false, + kind: InvariantLevel::Always, + }), + })) + .with_analyzer_rule(Arc::new(MyAnalyzerRule {})) + .build(); + let ctx = setup_table(SessionContext::new_with_state(state)).await?; + run_and_compare_query(ctx, "Topk context").await?; + + // Test + // Build a valid topK plan. + // Then have an invalid mutation in an optimizer run. + let config = SessionConfig::new().with_target_partitions(48); + let runtime = Arc::new(RuntimeEnv::default()); + let state = SessionStateBuilder::new() + .with_config(config) + .with_runtime_env(runtime) + .with_default_features() + .with_query_planner(Arc::new(TopKQueryPlanner {})) + // 1. adds a valid TopKPlanNode + .with_optimizer_rule(Arc::new(TopKOptimizerRule { + invariant_mock: Some(InvariantMock { + should_fail_invariant: false, + kind: InvariantLevel::Always, + }), + })) + // 2. break the TopKPlanNode + .with_optimizer_rule(Arc::new(OptimizerMakeExtensionNodeInvalid {})) + .with_analyzer_rule(Arc::new(MyAnalyzerRule {})) + .build(); + let ctx = setup_table(SessionContext::new_with_state(state)).await?; + matches!( + &*run_and_compare_query(ctx, "Topk context") + .await + .unwrap_err() + .message(), + "node fails check, such as improper inputs" + ); + + Ok(()) +} + fn make_topk_context() -> SessionContext { + make_topk_context_with_invariants(None) +} + +fn make_topk_context_with_invariants( + invariant_mock: Option, +) -> SessionContext { let config = SessionConfig::new().with_target_partitions(48); let runtime = Arc::new(RuntimeEnv::default()); let state = SessionStateBuilder::new() @@ -303,12 +416,55 @@ fn make_topk_context() -> SessionContext { .with_runtime_env(runtime) .with_default_features() .with_query_planner(Arc::new(TopKQueryPlanner {})) - .with_optimizer_rule(Arc::new(TopKOptimizerRule {})) + .with_optimizer_rule(Arc::new(TopKOptimizerRule { invariant_mock })) .with_analyzer_rule(Arc::new(MyAnalyzerRule {})) .build(); SessionContext::new_with_state(state) } +#[derive(Debug)] +struct OptimizerMakeExtensionNodeInvalid; + +impl OptimizerRule for OptimizerMakeExtensionNodeInvalid { + fn name(&self) -> &str { + "OptimizerMakeExtensionNodeInvalid" + } + + fn apply_order(&self) -> Option { + Some(ApplyOrder::TopDown) + } + + fn supports_rewrite(&self) -> bool { + true + } + + // Example rewrite pass which impacts validity of the extension node. + fn rewrite( + &self, + plan: LogicalPlan, + _config: &dyn OptimizerConfig, + ) -> Result, DataFusionError> { + if let LogicalPlan::Extension(Extension { node }) = &plan { + if let Some(prev) = node.as_any().downcast_ref::() { + return Ok(Transformed::yes(LogicalPlan::Extension(Extension { + node: Arc::new(TopKPlanNode { + k: prev.k, + input: prev.input.clone(), + expr: prev.expr.clone(), + // In a real use case, this rewriter could have change the number of inputs, etc + invariant_mock: Some(InvariantMock { + should_fail_invariant: true, + kind: InvariantLevel::Always, + }), + }), + }))); + } + }; + + Ok(Transformed::no(plan)) + } +} + // ------ The implementation of the TopK code follows ----- #[derive(Debug)] @@ -336,7 +492,10 @@ impl QueryPlanner for TopKQueryPlanner { } #[derive(Default, Debug)] -struct TopKOptimizerRule {} +struct TopKOptimizerRule { + /// A testing-only hashable fixture. + invariant_mock: Option, +} impl OptimizerRule for TopKOptimizerRule { fn name(&self) -> &str { @@ -380,6 +539,7 @@ impl OptimizerRule for TopKOptimizerRule { k: fetch, input: input.as_ref().clone(), expr: expr[0].clone(), + invariant_mock: self.invariant_mock.clone(), }), }))); } @@ -396,6 +556,10 @@ struct TopKPlanNode { /// The sort expression (this example only supports a single sort /// expr) expr: SortExpr, + + /// A testing-only hashable fixture. + /// For actual use, define the [`Invariant`] in the [`UserDefinedLogicalNodeCore::invariants`]. + invariant_mock: Option, } impl Debug for TopKPlanNode { @@ -406,6 +570,12 @@ impl Debug for TopKPlanNode { } } +#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Hash)] +struct InvariantMock { + should_fail_invariant: bool, + kind: InvariantLevel, +} + impl UserDefinedLogicalNodeCore for TopKPlanNode { fn name(&self) -> &str { "TopK" @@ -420,6 +590,19 @@ impl UserDefinedLogicalNodeCore for TopKPlanNode { self.input.schema() } + fn check_invariants(&self, check: InvariantLevel, _plan: &LogicalPlan) -> Result<()> { + if let Some(InvariantMock { + should_fail_invariant, + kind, + }) = self.invariant_mock.clone() + { + if should_fail_invariant && check == kind { + return internal_err!("node fails check, such as improper inputs"); + } + } + Ok(()) + } + fn expressions(&self) -> Vec { vec![self.expr.expr.clone()] } @@ -440,6 +623,7 @@ impl UserDefinedLogicalNodeCore for TopKPlanNode { k: self.k, input: inputs.swap_remove(0), expr: self.expr.with_expr(exprs.swap_remove(0)), + invariant_mock: self.invariant_mock.clone(), }) } diff --git a/datafusion/expr/src/logical_plan/extension.rs b/datafusion/expr/src/logical_plan/extension.rs index 19d4cb3db9ce..be7153cc4eaa 100644 --- a/datafusion/expr/src/logical_plan/extension.rs +++ b/datafusion/expr/src/logical_plan/extension.rs @@ -22,6 +22,8 @@ use std::cmp::Ordering; use std::hash::{Hash, Hasher}; use std::{any::Any, collections::HashSet, fmt, sync::Arc}; +use super::InvariantLevel; + /// This defines the interface for [`LogicalPlan`] nodes that can be /// used to extend DataFusion with custom relational operators. /// @@ -54,6 +56,9 @@ pub trait UserDefinedLogicalNode: fmt::Debug + Send + Sync { /// Return the output schema of this logical plan node. fn schema(&self) -> &DFSchemaRef; + /// Perform check of invariants for the extension node. + fn check_invariants(&self, check: InvariantLevel, plan: &LogicalPlan) -> Result<()>; + /// Returns all expressions in the current logical plan node. This should /// not include expressions of any inputs (aka non-recursively). /// @@ -244,6 +249,17 @@ pub trait UserDefinedLogicalNodeCore: /// Return the output schema of this logical plan node. fn schema(&self) -> &DFSchemaRef; + /// Perform check of invariants for the extension node. + /// + /// This is the default implementation for extension nodes. + fn check_invariants( + &self, + _check: InvariantLevel, + _plan: &LogicalPlan, + ) -> Result<()> { + Ok(()) + } + /// Returns all expressions in the current logical plan node. This /// should not include expressions of any inputs (aka /// non-recursively). These expressions are used for optimizer @@ -336,6 +352,10 @@ impl UserDefinedLogicalNode for T { self.schema() } + fn check_invariants(&self, check: InvariantLevel, plan: &LogicalPlan) -> Result<()> { + self.check_invariants(check, plan) + } + fn expressions(&self) -> Vec { self.expressions() } diff --git a/datafusion/expr/src/logical_plan/invariants.rs b/datafusion/expr/src/logical_plan/invariants.rs index bde4acaae562..c8f1fcd2d90b 100644 --- a/datafusion/expr/src/logical_plan/invariants.rs +++ b/datafusion/expr/src/logical_plan/invariants.rs @@ -28,6 +28,9 @@ use crate::{ Aggregate, Expr, Filter, Join, JoinType, LogicalPlan, Window, }; +use super::Extension; + +#[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Hash)] pub enum InvariantLevel { /// Invariants that are always true in DataFusion `LogicalPlan`s /// such as the number of expected children and no duplicated output fields @@ -41,19 +44,56 @@ pub enum InvariantLevel { Executable, } -pub fn assert_always_invariants(plan: &LogicalPlan) -> Result<()> { +/// Apply the [`InvariantLevel::Always`] check at the current plan node only. +/// +/// This does not recurs to any child nodes. +pub fn assert_always_invariants_at_current_node(plan: &LogicalPlan) -> Result<()> { // Refer to assert_unique_field_names(plan)?; Ok(()) } +/// Visit the plan nodes, and confirm the [`InvariantLevel::Executable`] +/// as well as the less stringent [`InvariantLevel::Always`] checks. pub fn assert_executable_invariants(plan: &LogicalPlan) -> Result<()> { - assert_always_invariants(plan)?; + // Always invariants + assert_always_invariants_at_current_node(plan)?; + assert_valid_extension_nodes(plan, InvariantLevel::Always)?; + + // Executable invariants + assert_valid_extension_nodes(plan, InvariantLevel::Executable)?; assert_valid_semantic_plan(plan)?; Ok(()) } +/// Asserts that the query plan, and subplan, extension nodes have valid invariants. +/// +/// Refer to [`UserDefinedLogicalNode::check_invariants`](super::UserDefinedLogicalNode) +/// for more details of user-provided extension node invariants. +fn assert_valid_extension_nodes(plan: &LogicalPlan, check: InvariantLevel) -> Result<()> { + plan.apply_with_subqueries(|plan: &LogicalPlan| { + if let LogicalPlan::Extension(Extension { node }) = plan { + node.check_invariants(check, plan)?; + } + plan.apply_expressions(|expr| { + // recursively look for subqueries + expr.apply(|expr| { + match expr { + Expr::Exists(Exists { subquery, .. }) + | Expr::InSubquery(InSubquery { subquery, .. }) + | Expr::ScalarSubquery(subquery) => { + assert_valid_extension_nodes(&subquery.subquery, check)?; + } + _ => {} + }; + Ok(TreeNodeRecursion::Continue) + }) + }) + }) + .map(|_| ()) +} + /// Returns an error if plan, and subplans, do not have unique fields. /// /// This invariant is subject to change. @@ -87,7 +127,7 @@ pub fn assert_expected_schema(schema: &DFSchemaRef, plan: &LogicalPlan) -> Resul /// Asserts that the subqueries are structured properly with valid node placement. /// -/// Refer to [`check_subquery_expr`] for more details. +/// Refer to [`check_subquery_expr`] for more details of the internal invariants. fn assert_subqueries_are_valid(plan: &LogicalPlan) -> Result<()> { plan.apply_with_subqueries(|plan: &LogicalPlan| { plan.apply_expressions(|expr| { diff --git a/datafusion/expr/src/logical_plan/plan.rs b/datafusion/expr/src/logical_plan/plan.rs index 446ae94108b1..c0a580d89f3e 100644 --- a/datafusion/expr/src/logical_plan/plan.rs +++ b/datafusion/expr/src/logical_plan/plan.rs @@ -25,7 +25,8 @@ use std::sync::{Arc, LazyLock}; use super::dml::CopyTo; use super::invariants::{ - assert_always_invariants, assert_executable_invariants, InvariantLevel, + assert_always_invariants_at_current_node, assert_executable_invariants, + InvariantLevel, }; use super::DdlStatement; use crate::builder::{change_redundant_column, unnest_with_options}; @@ -1135,7 +1136,7 @@ impl LogicalPlan { /// checks that the plan conforms to the listed invariant level, returning an Error if not pub fn check_invariants(&self, check: InvariantLevel) -> Result<()> { match check { - InvariantLevel::Always => assert_always_invariants(self), + InvariantLevel::Always => assert_always_invariants_at_current_node(self), InvariantLevel::Executable => assert_executable_invariants(self), } } diff --git a/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs b/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs index 7045729493b1..5fb357dfcd23 100644 --- a/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs +++ b/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs @@ -32,8 +32,8 @@ use datafusion::execution::registry::SerializerRegistry; use datafusion::execution::runtime_env::RuntimeEnv; use datafusion::execution::session_state::SessionStateBuilder; use datafusion::logical_expr::{ - Extension, LogicalPlan, PartitionEvaluator, Repartition, UserDefinedLogicalNode, - Values, Volatility, + Extension, InvariantLevel, LogicalPlan, PartitionEvaluator, Repartition, + UserDefinedLogicalNode, Values, Volatility, }; use datafusion::optimizer::simplify_expressions::expr_simplifier::THRESHOLD_INLINE_INLIST; use datafusion::prelude::*; @@ -111,6 +111,14 @@ impl UserDefinedLogicalNode for MockUserDefinedLogicalPlan { &self.empty_schema } + fn check_invariants( + &self, + _check: InvariantLevel, + _plan: &LogicalPlan, + ) -> Result<()> { + Ok(()) + } + fn expressions(&self) -> Vec { vec![] }