diff --git a/datafusion/core/tests/sql/sql_api.rs b/datafusion/core/tests/sql/sql_api.rs index 034d6fa23d9c..ec086bcc50c7 100644 --- a/datafusion/core/tests/sql/sql_api.rs +++ b/datafusion/core/tests/sql/sql_api.rs @@ -19,6 +19,23 @@ use datafusion::prelude::*; use tempfile::TempDir; +#[tokio::test] +async fn test_window_function() { + let ctx = SessionContext::new(); + let df = ctx + .sql( + r#"SELECT + t1.v1, + SUM(t1.v1) OVER w + 1 + FROM + generate_series(1, 10000) AS t1(v1) + WINDOW + w AS (ORDER BY t1.v1 ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW);"#, + ) + .await; + assert!(df.is_ok()); +} + #[tokio::test] async fn unsupported_ddl_returns_error() { // Verify SessionContext::with_sql_options errors appropriately diff --git a/datafusion/sql/src/select.rs b/datafusion/sql/src/select.rs index 2a2d0b3b3eb8..85dba0f43081 100644 --- a/datafusion/sql/src/select.rs +++ b/datafusion/sql/src/select.rs @@ -16,6 +16,7 @@ // under the License. use std::collections::HashSet; +use std::ops::ControlFlow; use std::sync::Arc; use crate::planner::{ContextProvider, PlannerContext, SqlToRel}; @@ -28,7 +29,7 @@ use crate::utils::{ use datafusion_common::error::DataFusionErrorBuilder; use datafusion_common::tree_node::{TreeNode, TreeNodeRecursion}; -use datafusion_common::{not_impl_err, plan_err, Result}; +use datafusion_common::{not_impl_err, plan_err, DataFusionError, Result}; use datafusion_common::{RecursionUnnestOption, UnnestOptions}; use datafusion_expr::expr::{Alias, PlannedReplaceSelectItem, WildcardOptions}; use datafusion_expr::expr_rewriter::{ @@ -45,8 +46,8 @@ use datafusion_expr::{ use indexmap::IndexMap; use sqlparser::ast::{ - Distinct, Expr as SQLExpr, GroupByExpr, NamedWindowExpr, OrderBy, - SelectItemQualifiedWildcardKind, WildcardAdditionalOptions, WindowType, + visit_expressions_mut, Distinct, Expr as SQLExpr, GroupByExpr, NamedWindowExpr, + OrderBy, SelectItemQualifiedWildcardKind, WildcardAdditionalOptions, WindowType, }; use sqlparser::ast::{NamedWindowDefinition, Select, SelectItem, TableWithJoins}; @@ -891,29 +892,42 @@ fn match_window_definitions( named_windows: &[NamedWindowDefinition], ) -> Result<()> { for proj in projection.iter_mut() { - if let SelectItem::ExprWithAlias { - expr: SQLExpr::Function(f), - alias: _, - } - | SelectItem::UnnamedExpr(SQLExpr::Function(f)) = proj + if let SelectItem::ExprWithAlias { expr, alias: _ } + | SelectItem::UnnamedExpr(expr) = proj { - for NamedWindowDefinition(window_ident, window_expr) in named_windows.iter() { - if let Some(WindowType::NamedWindow(ident)) = &f.over { - if ident.eq(window_ident) { - f.over = Some(match window_expr { - NamedWindowExpr::NamedWindow(ident) => { - WindowType::NamedWindow(ident.clone()) - } - NamedWindowExpr::WindowSpec(spec) => { - WindowType::WindowSpec(spec.clone()) + let mut err = None; + visit_expressions_mut(expr, |expr| { + if let SQLExpr::Function(f) = expr { + if let Some(WindowType::NamedWindow(_)) = &f.over { + for NamedWindowDefinition(window_ident, window_expr) in + named_windows + { + if let Some(WindowType::NamedWindow(ident)) = &f.over { + if ident.eq(window_ident) { + f.over = Some(match window_expr { + NamedWindowExpr::NamedWindow(ident) => { + WindowType::NamedWindow(ident.clone()) + } + NamedWindowExpr::WindowSpec(spec) => { + WindowType::WindowSpec(spec.clone()) + } + }) + } } - }) + } + // All named windows must be defined with a WindowSpec. + if let Some(WindowType::NamedWindow(ident)) = &f.over { + err = Some(DataFusionError::Plan(format!( + "The window {ident} is not defined!" + ))); + return ControlFlow::Break(()); + } } } - } - // All named windows must be defined with a WindowSpec. - if let Some(WindowType::NamedWindow(ident)) = &f.over { - return plan_err!("The window {ident} is not defined!"); + ControlFlow::Continue(()) + }); + if let Some(err) = err { + return Err(err); } } } diff --git a/datafusion/sqllogictest/test_files/window.slt b/datafusion/sqllogictest/test_files/window.slt index 76e3751e4b8e..c5c094cad3da 100644 --- a/datafusion/sqllogictest/test_files/window.slt +++ b/datafusion/sqllogictest/test_files/window.slt @@ -5537,6 +5537,21 @@ physical_plan 02)--WindowAggExec: wdw=[max(aggregate_test_100_ordered.c5) ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING: Ok(Field { name: "max(aggregate_test_100_ordered.c5) ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(NULL)), is_causal: false }] 03)----DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c5], file_type=csv, has_header=true +query II rowsort +SELECT + t1.v1, + SUM(t1.v1) OVER w + 1 +FROM + generate_series(1, 5) AS t1(v1) +WINDOW + w AS (ORDER BY t1.v1); +---- +1 2 +2 4 +3 7 +4 11 +5 16 + # Testing Utf8View with window statement ok CREATE TABLE aggregate_test_100_utf8view AS SELECT