Skip to content

Commit 7db4213

Browse files
authored
Refactor/simplify window frame utils (#11648)
* Simplify window frame utils * Remove unwrap calls * Fix format * Incorporate review feedback
1 parent 49d9d45 commit 7db4213

File tree

4 files changed

+105
-105
lines changed

4 files changed

+105
-105
lines changed

datafusion/core/tests/fuzz_cases/window_fuzz.rs

Lines changed: 40 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717

1818
use std::sync::Arc;
1919

20-
use arrow::array::{ArrayRef, Int32Array};
20+
use arrow::array::{ArrayRef, Int32Array, StringArray};
2121
use arrow::compute::{concat_batches, SortOptions};
2222
use arrow::datatypes::SchemaRef;
2323
use arrow::record_batch::RecordBatch;
@@ -45,6 +45,7 @@ use datafusion_physical_expr::{PhysicalExpr, PhysicalSortExpr};
4545
use test_utils::add_empty_batches;
4646

4747
use hashbrown::HashMap;
48+
use rand::distributions::Alphanumeric;
4849
use rand::rngs::StdRng;
4950
use rand::{Rng, SeedableRng};
5051

@@ -607,25 +608,6 @@ fn convert_bound_to_current_row_if_applicable(
607608
}
608609
}
609610

610-
/// This utility determines whether a given window frame can be executed with
611-
/// multiple ORDER BY expressions. As an example, range frames with offset (such
612-
/// as `RANGE BETWEEN 1 PRECEDING AND 1 FOLLOWING`) cannot have ORDER BY clauses
613-
/// of the form `\[ORDER BY a ASC, b ASC, ...]`
614-
fn can_accept_multi_orderby(window_frame: &WindowFrame) -> bool {
615-
match window_frame.units {
616-
WindowFrameUnits::Rows => true,
617-
WindowFrameUnits::Range => {
618-
// Range can only accept multi ORDER BY clauses when bounds are
619-
// CURRENT ROW or UNBOUNDED PRECEDING/FOLLOWING:
620-
(window_frame.start_bound.is_unbounded()
621-
|| window_frame.start_bound == WindowFrameBound::CurrentRow)
622-
&& (window_frame.end_bound.is_unbounded()
623-
|| window_frame.end_bound == WindowFrameBound::CurrentRow)
624-
}
625-
WindowFrameUnits::Groups => true,
626-
}
627-
}
628-
629611
/// Perform batch and running window same input
630612
/// and verify outputs of `WindowAggExec` and `BoundedWindowAggExec` are equal
631613
async fn run_window_test(
@@ -649,7 +631,7 @@ async fn run_window_test(
649631
options: SortOptions::default(),
650632
})
651633
}
652-
if orderby_exprs.len() > 1 && !can_accept_multi_orderby(&window_frame) {
634+
if orderby_exprs.len() > 1 && !window_frame.can_accept_multi_orderby() {
653635
orderby_exprs = orderby_exprs[0..1].to_vec();
654636
}
655637
let mut partitionby_exprs = vec![];
@@ -733,11 +715,30 @@ async fn run_window_test(
733715
)?) as _;
734716
let task_ctx = ctx.task_ctx();
735717
let collected_usual = collect(usual_window_exec, task_ctx.clone()).await?;
736-
let collected_running = collect(running_window_exec, task_ctx).await?;
718+
let collected_running = collect(running_window_exec, task_ctx)
719+
.await?
720+
.into_iter()
721+
.filter(|b| b.num_rows() > 0)
722+
.collect::<Vec<_>>();
737723

738724
// BoundedWindowAggExec should produce more chunk than the usual WindowAggExec.
739725
// Otherwise it means that we cannot generate result in running mode.
740-
assert!(collected_running.len() > collected_usual.len());
726+
let err_msg = format!("Inconsistent result for window_frame: {window_frame:?}, window_fn: {window_fn:?}, args:{args:?}, random_seed: {random_seed:?}, search_mode: {search_mode:?}, partition_by_columns:{partition_by_columns:?}, orderby_columns: {orderby_columns:?}");
727+
// Below check makes sure that, streaming execution generates more chunks than the bulk execution.
728+
// Since algorithms and operators works on sliding windows in the streaming execution.
729+
// However, in the current test setup for some random generated window frame clauses: It is not guaranteed
730+
// for streaming execution to generate more chunk than its non-streaming counter part in the Linear mode.
731+
// As an example window frame `OVER(PARTITION BY d ORDER BY a RANGE BETWEEN CURRENT ROW AND 9 FOLLOWING)`
732+
// needs to receive a=10 to generate result for the rows where a=0. If the input data generated is between the range [0, 9].
733+
// even in streaming mode, generated result will be single bulk as in the non-streaming version.
734+
if search_mode != Linear {
735+
assert!(
736+
collected_running.len() > collected_usual.len(),
737+
"{}",
738+
err_msg
739+
);
740+
}
741+
741742
// compare
742743
let usual_formatted = pretty_format_batches(&collected_usual)?.to_string();
743744
let running_formatted = pretty_format_batches(&collected_running)?.to_string();
@@ -767,10 +768,17 @@ async fn run_window_test(
767768
Ok(())
768769
}
769770

771+
fn generate_random_string(rng: &mut StdRng, length: usize) -> String {
772+
rng.sample_iter(&Alphanumeric)
773+
.take(length)
774+
.map(char::from)
775+
.collect()
776+
}
777+
770778
/// Return randomly sized record batches with:
771779
/// three sorted int32 columns 'a', 'b', 'c' ranged from 0..DISTINCT as columns
772780
/// one random int32 column x
773-
fn make_staggered_batches<const STREAM: bool>(
781+
pub(crate) fn make_staggered_batches<const STREAM: bool>(
774782
len: usize,
775783
n_distinct: usize,
776784
random_seed: u64,
@@ -779,6 +787,7 @@ fn make_staggered_batches<const STREAM: bool>(
779787
let mut rng = StdRng::seed_from_u64(random_seed);
780788
let mut input123: Vec<(i32, i32, i32)> = vec![(0, 0, 0); len];
781789
let mut input4: Vec<i32> = vec![0; len];
790+
let mut input5: Vec<String> = vec!["".to_string(); len];
782791
input123.iter_mut().for_each(|v| {
783792
*v = (
784793
rng.gen_range(0..n_distinct) as i32,
@@ -788,17 +797,23 @@ fn make_staggered_batches<const STREAM: bool>(
788797
});
789798
input123.sort();
790799
rng.fill(&mut input4[..]);
800+
input5.iter_mut().for_each(|v| {
801+
*v = generate_random_string(&mut rng, 1);
802+
});
803+
input5.sort();
791804
let input1 = Int32Array::from_iter_values(input123.iter().map(|k| k.0));
792805
let input2 = Int32Array::from_iter_values(input123.iter().map(|k| k.1));
793806
let input3 = Int32Array::from_iter_values(input123.iter().map(|k| k.2));
794807
let input4 = Int32Array::from_iter_values(input4);
808+
let input5 = StringArray::from_iter_values(input5);
795809

796810
// split into several record batches
797811
let mut remainder = RecordBatch::try_from_iter(vec![
798812
("a", Arc::new(input1) as ArrayRef),
799813
("b", Arc::new(input2) as ArrayRef),
800814
("c", Arc::new(input3) as ArrayRef),
801815
("x", Arc::new(input4) as ArrayRef),
816+
("string_field", Arc::new(input5) as ArrayRef),
802817
])
803818
.unwrap();
804819

@@ -807,6 +822,7 @@ fn make_staggered_batches<const STREAM: bool>(
807822
while remainder.num_rows() > 0 {
808823
let batch_size = rng.gen_range(0..50);
809824
if remainder.num_rows() < batch_size {
825+
batches.push(remainder);
810826
break;
811827
}
812828
batches.push(remainder.slice(0, batch_size));

datafusion/expr/src/window_frame.rs

Lines changed: 40 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -26,8 +26,7 @@
2626
use std::fmt::{self, Formatter};
2727
use std::hash::Hash;
2828

29-
use crate::expr::Sort;
30-
use crate::Expr;
29+
use crate::{lit, Expr};
3130

3231
use datafusion_common::{plan_err, sql_err, DataFusionError, Result, ScalarValue};
3332
use sqlparser::ast;
@@ -246,59 +245,51 @@ impl WindowFrame {
246245
causal,
247246
}
248247
}
249-
}
250248

251-
/// Regularizes ORDER BY clause for window definition for implicit corner cases.
252-
pub fn regularize_window_order_by(
253-
frame: &WindowFrame,
254-
order_by: &mut Vec<Expr>,
255-
) -> Result<()> {
256-
if frame.units == WindowFrameUnits::Range && order_by.len() != 1 {
257-
// Normally, RANGE frames require an ORDER BY clause with exactly one
258-
// column. However, an ORDER BY clause may be absent or present but with
259-
// more than one column in two edge cases:
260-
// 1. start bound is UNBOUNDED or CURRENT ROW
261-
// 2. end bound is CURRENT ROW or UNBOUNDED.
262-
// In these cases, we regularize the ORDER BY clause if the ORDER BY clause
263-
// is absent. If an ORDER BY clause is present but has more than one column,
264-
// the ORDER BY clause is unchanged. Note that this follows Postgres behavior.
265-
if (frame.start_bound.is_unbounded()
266-
|| frame.start_bound == WindowFrameBound::CurrentRow)
267-
&& (frame.end_bound == WindowFrameBound::CurrentRow
268-
|| frame.end_bound.is_unbounded())
269-
{
270-
// If an ORDER BY clause is absent, it is equivalent to a ORDER BY clause
271-
// with constant value as sort key.
272-
// If an ORDER BY clause is present but has more than one column, it is
273-
// unchanged.
274-
if order_by.is_empty() {
275-
order_by.push(Expr::Sort(Sort::new(
276-
Box::new(Expr::Literal(ScalarValue::UInt64(Some(1)))),
277-
true,
278-
false,
279-
)));
249+
/// Regularizes the ORDER BY clause of the window frame.
250+
pub fn regularize_order_bys(&self, order_by: &mut Vec<Expr>) -> Result<()> {
251+
match self.units {
252+
// Normally, RANGE frames require an ORDER BY clause with exactly
253+
// one column. However, an ORDER BY clause may be absent or have
254+
// more than one column when the start/end bounds are UNBOUNDED or
255+
// CURRENT ROW.
256+
WindowFrameUnits::Range if self.free_range() => {
257+
// If an ORDER BY clause is absent, it is equivalent to an
258+
// ORDER BY clause with constant value as sort key. If an
259+
// ORDER BY clause is present but has more than one column,
260+
// it is unchanged. Note that this follows PostgreSQL behavior.
261+
if order_by.is_empty() {
262+
order_by.push(lit(1u64).sort(true, false));
263+
}
264+
}
265+
WindowFrameUnits::Range if order_by.len() != 1 => {
266+
return plan_err!("RANGE requires exactly one ORDER BY column");
280267
}
268+
WindowFrameUnits::Groups if order_by.is_empty() => {
269+
return plan_err!("GROUPS requires an ORDER BY clause");
270+
}
271+
_ => {}
281272
}
273+
Ok(())
282274
}
283-
Ok(())
284-
}
285275

286-
/// Checks if given window frame is valid. In particular, if the frame is RANGE
287-
/// with offset PRECEDING/FOLLOWING, it must have exactly one ORDER BY column.
288-
pub fn check_window_frame(frame: &WindowFrame, order_bys: usize) -> Result<()> {
289-
if frame.units == WindowFrameUnits::Range && order_bys != 1 {
290-
// See `regularize_window_order_by`.
291-
if !(frame.start_bound.is_unbounded()
292-
|| frame.start_bound == WindowFrameBound::CurrentRow)
293-
|| !(frame.end_bound == WindowFrameBound::CurrentRow
294-
|| frame.end_bound.is_unbounded())
295-
{
296-
plan_err!("RANGE requires exactly one ORDER BY column")?
276+
/// Returns whether the window frame can accept multiple ORDER BY expressons.
277+
pub fn can_accept_multi_orderby(&self) -> bool {
278+
match self.units {
279+
WindowFrameUnits::Rows => true,
280+
WindowFrameUnits::Range => self.free_range(),
281+
WindowFrameUnits::Groups => true,
297282
}
298-
} else if frame.units == WindowFrameUnits::Groups && order_bys == 0 {
299-
plan_err!("GROUPS requires an ORDER BY clause")?
300-
};
301-
Ok(())
283+
}
284+
285+
/// Returns whether the window frame is "free range"; i.e. its start/end
286+
/// bounds are UNBOUNDED or CURRENT ROW.
287+
fn free_range(&self) -> bool {
288+
(self.start_bound.is_unbounded()
289+
|| self.start_bound == WindowFrameBound::CurrentRow)
290+
&& (self.end_bound.is_unbounded()
291+
|| self.end_bound == WindowFrameBound::CurrentRow)
292+
}
302293
}
303294

304295
/// There are five ways to describe starting and ending frame boundaries:

datafusion/proto/src/logical_plan/from_proto.rs

Lines changed: 16 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -19,18 +19,14 @@ use std::sync::Arc;
1919

2020
use datafusion::execution::registry::FunctionRegistry;
2121
use datafusion_common::{
22-
internal_err, plan_datafusion_err, DataFusionError, Result, ScalarValue,
22+
exec_datafusion_err, internal_err, plan_datafusion_err, Result, ScalarValue,
2323
TableReference, UnnestOptions,
2424
};
25-
use datafusion_expr::expr::Unnest;
26-
use datafusion_expr::expr::{Alias, Placeholder};
27-
use datafusion_expr::window_frame::{check_window_frame, regularize_window_order_by};
28-
use datafusion_expr::ExprFunctionExt;
2925
use datafusion_expr::{
30-
expr::{self, InList, Sort, WindowFunction},
26+
expr::{self, Alias, InList, Placeholder, Sort, Unnest, WindowFunction},
3127
logical_plan::{PlanType, StringifiedPlan},
3228
AggregateFunction, Between, BinaryExpr, BuiltInWindowFunction, Case, Cast, Expr,
33-
GroupingSet,
29+
ExprFunctionExt, GroupingSet,
3430
GroupingSet::GroupingSets,
3531
JoinConstraint, JoinType, Like, Operator, TryCast, WindowFrame, WindowFrameBound,
3632
WindowFrameUnits,
@@ -289,24 +285,22 @@ pub fn parse_expr(
289285
.window_frame
290286
.as_ref()
291287
.map::<Result<WindowFrame, _>, _>(|window_frame| {
292-
let window_frame = window_frame.clone().try_into()?;
293-
check_window_frame(&window_frame, order_by.len())
288+
let window_frame: WindowFrame = window_frame.clone().try_into()?;
289+
window_frame
290+
.regularize_order_bys(&mut order_by)
294291
.map(|_| window_frame)
295292
})
296293
.transpose()?
297294
.ok_or_else(|| {
298-
DataFusionError::Execution(
299-
"missing window frame during deserialization".to_string(),
300-
)
295+
exec_datafusion_err!("missing window frame during deserialization")
301296
})?;
302-
// TODO: support proto for null treatment
303-
regularize_window_order_by(&window_frame, &mut order_by)?;
304297

298+
// TODO: support proto for null treatment
305299
match window_function {
306300
window_expr_node::WindowFunction::AggrFunction(i) => {
307301
let aggr_function = parse_i32_to_aggregate_function(i)?;
308302

309-
Ok(Expr::WindowFunction(WindowFunction::new(
303+
Expr::WindowFunction(WindowFunction::new(
310304
expr::WindowFunctionDefinition::AggregateFunction(aggr_function),
311305
vec![parse_required_expr(
312306
expr.expr.as_deref(),
@@ -319,7 +313,7 @@ pub fn parse_expr(
319313
.order_by(order_by)
320314
.window_frame(window_frame)
321315
.build()
322-
.unwrap())
316+
.map_err(Error::DataFusionError)
323317
}
324318
window_expr_node::WindowFunction::BuiltInFunction(i) => {
325319
let built_in_function = protobuf::BuiltInWindowFunction::try_from(*i)
@@ -331,7 +325,7 @@ pub fn parse_expr(
331325
.map(|e| vec![e])
332326
.unwrap_or_else(Vec::new);
333327

334-
Ok(Expr::WindowFunction(WindowFunction::new(
328+
Expr::WindowFunction(WindowFunction::new(
335329
expr::WindowFunctionDefinition::BuiltInWindowFunction(
336330
built_in_function,
337331
),
@@ -341,7 +335,7 @@ pub fn parse_expr(
341335
.order_by(order_by)
342336
.window_frame(window_frame)
343337
.build()
344-
.unwrap())
338+
.map_err(Error::DataFusionError)
345339
}
346340
window_expr_node::WindowFunction::Udaf(udaf_name) => {
347341
let udaf_function = match &expr.fun_definition {
@@ -353,15 +347,15 @@ pub fn parse_expr(
353347
parse_optional_expr(expr.expr.as_deref(), registry, codec)?
354348
.map(|e| vec![e])
355349
.unwrap_or_else(Vec::new);
356-
Ok(Expr::WindowFunction(WindowFunction::new(
350+
Expr::WindowFunction(WindowFunction::new(
357351
expr::WindowFunctionDefinition::AggregateUDF(udaf_function),
358352
args,
359353
))
360354
.partition_by(partition_by)
361355
.order_by(order_by)
362356
.window_frame(window_frame)
363357
.build()
364-
.unwrap())
358+
.map_err(Error::DataFusionError)
365359
}
366360
window_expr_node::WindowFunction::Udwf(udwf_name) => {
367361
let udwf_function = match &expr.fun_definition {
@@ -373,15 +367,15 @@ pub fn parse_expr(
373367
parse_optional_expr(expr.expr.as_deref(), registry, codec)?
374368
.map(|e| vec![e])
375369
.unwrap_or_else(Vec::new);
376-
Ok(Expr::WindowFunction(WindowFunction::new(
370+
Expr::WindowFunction(WindowFunction::new(
377371
expr::WindowFunctionDefinition::WindowUDF(udwf_function),
378372
args,
379373
))
380374
.partition_by(partition_by)
381375
.order_by(order_by)
382376
.window_frame(window_frame)
383377
.build()
384-
.unwrap())
378+
.map_err(Error::DataFusionError)
385379
}
386380
}
387381
}

0 commit comments

Comments
 (0)