From fbff8822da565518181209543fffbdb0667ce1a1 Mon Sep 17 00:00:00 2001 From: Huaijin Date: Sat, 10 Aug 2024 11:53:18 +0800 Subject: [PATCH 1/2] fix: support ordering and pencentile function ser/der --- .../core/src/physical_optimizer/test_utils.rs | 1 - datafusion/core/src/physical_planner.rs | 1 - .../core/tests/fuzz_cases/window_fuzz.rs | 3 -- datafusion/physical-plan/src/windows/mod.rs | 6 ++-- .../proto/src/physical_plan/from_proto.rs | 3 -- datafusion/proto/src/physical_plan/mod.rs | 6 ++-- .../tests/cases/roundtrip_physical_plan.rs | 29 +++++++++++++++++++ 7 files changed, 33 insertions(+), 16 deletions(-) diff --git a/datafusion/core/src/physical_optimizer/test_utils.rs b/datafusion/core/src/physical_optimizer/test_utils.rs index 55a0fa814552..90853c347672 100644 --- a/datafusion/core/src/physical_optimizer/test_utils.rs +++ b/datafusion/core/src/physical_optimizer/test_utils.rs @@ -251,7 +251,6 @@ pub fn bounded_window_exec( "count".to_owned(), &[col(col_name, &schema).unwrap()], &[], - &[], &sort_exprs, Arc::new(WindowFrame::new(Some(false))), schema.as_ref(), diff --git a/datafusion/core/src/physical_planner.rs b/datafusion/core/src/physical_planner.rs index 7eb468f56eeb..9cc2f253f8da 100644 --- a/datafusion/core/src/physical_planner.rs +++ b/datafusion/core/src/physical_planner.rs @@ -1510,7 +1510,6 @@ pub fn create_window_expr_with_name( fun, name, &physical_args, - args, &partition_by, &order_by, window_frame, diff --git a/datafusion/core/tests/fuzz_cases/window_fuzz.rs b/datafusion/core/tests/fuzz_cases/window_fuzz.rs index 813862c4cc2f..dac9d0e67b7c 100644 --- a/datafusion/core/tests/fuzz_cases/window_fuzz.rs +++ b/datafusion/core/tests/fuzz_cases/window_fuzz.rs @@ -284,7 +284,6 @@ async fn bounded_window_causal_non_causal() -> Result<()> { let window_expr = create_window_expr( &window_fn, fn_name.to_string(), - &args, &logical_exprs, &partitionby_exprs, &orderby_exprs, @@ -674,7 +673,6 @@ async fn run_window_test( &window_fn, fn_name.clone(), &args, - &[], &partitionby_exprs, &orderby_exprs, Arc::new(window_frame.clone()), @@ -693,7 +691,6 @@ async fn run_window_test( &window_fn, fn_name, &args, - &[], &partitionby_exprs, &orderby_exprs, Arc::new(window_frame.clone()), diff --git a/datafusion/physical-plan/src/windows/mod.rs b/datafusion/physical-plan/src/windows/mod.rs index 2e6ad4e1a14f..1fd0ca36b1eb 100644 --- a/datafusion/physical-plan/src/windows/mod.rs +++ b/datafusion/physical-plan/src/windows/mod.rs @@ -32,8 +32,8 @@ use arrow::datatypes::Schema; use arrow_schema::{DataType, Field, SchemaRef}; use datafusion_common::{exec_err, DataFusionError, Result, ScalarValue}; use datafusion_expr::{ - BuiltInWindowFunction, Expr, PartitionEvaluator, WindowFrame, - WindowFunctionDefinition, WindowUDF, + BuiltInWindowFunction, PartitionEvaluator, WindowFrame, WindowFunctionDefinition, + WindowUDF, }; use datafusion_physical_expr::equivalence::collapse_lex_req; use datafusion_physical_expr::{ @@ -94,7 +94,6 @@ pub fn create_window_expr( fun: &WindowFunctionDefinition, name: String, args: &[Arc], - _logical_args: &[Expr], partition_by: &[Arc], order_by: &[PhysicalSortExpr], window_frame: Arc, @@ -746,7 +745,6 @@ mod tests { &[col("a", &schema)?], &[], &[], - &[], Arc::new(WindowFrame::new(None)), schema.as_ref(), false, diff --git a/datafusion/proto/src/physical_plan/from_proto.rs b/datafusion/proto/src/physical_plan/from_proto.rs index bc0a19336bae..b2f92f4b2ee4 100644 --- a/datafusion/proto/src/physical_plan/from_proto.rs +++ b/datafusion/proto/src/physical_plan/from_proto.rs @@ -169,13 +169,10 @@ pub fn parse_physical_window_expr( // TODO: Remove extended_schema if functions are all UDAF let extended_schema = schema_add_window_field(&window_node_expr, input_schema, &fun, &name)?; - // approx_percentile_cont and approx_percentile_cont_weight are not supported for UDAF from protobuf yet. - let logical_exprs = &[]; create_window_expr( &fun, name, &window_node_expr, - logical_exprs, &partition_by, &order_by, Arc::new(window_frame), diff --git a/datafusion/proto/src/physical_plan/mod.rs b/datafusion/proto/src/physical_plan/mod.rs index b5d28f40a68f..0f6722dd375b 100644 --- a/datafusion/proto/src/physical_plan/mod.rs +++ b/datafusion/proto/src/physical_plan/mod.rs @@ -477,7 +477,7 @@ impl AsExecutionPlan for protobuf::PhysicalPlanNode { ExprType::AggregateExpr(agg_node) => { let input_phy_expr: Vec> = agg_node.expr.iter() .map(|e| parse_physical_expr(e, registry, &physical_schema, extension_codec)).collect::>>()?; - let _ordering_req: Vec = agg_node.ordering_req.iter() + let ordering_req: Vec = agg_node.ordering_req.iter() .map(|e| parse_physical_sort_expr(e, registry, &physical_schema, extension_codec)).collect::>>()?; agg_node.aggregate_function.as_ref().map(|func| { match func { @@ -487,14 +487,12 @@ impl AsExecutionPlan for protobuf::PhysicalPlanNode { None => registry.udaf(udaf_name)? }; - // TODO: approx_percentile_cont and approx_percentile_cont_weight are not supported for UDAF from protobuf yet. - // TODO: `order by` is not supported for UDAF yet - // https://github.com/apache/datafusion/issues/11804 AggregateExprBuilder::new(agg_udf, input_phy_expr) .schema(Arc::clone(&physical_schema)) .alias(name) .with_ignore_nulls(agg_node.ignore_nulls) .with_distinct(agg_node.distinct) + .order_by(ordering_req) .build() } } diff --git a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs index 1a9c6d40ebe6..f66160df904d 100644 --- a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs +++ b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs @@ -25,6 +25,7 @@ use std::vec; use arrow::array::RecordBatch; use arrow::csv::WriterBuilder; use datafusion::physical_expr_functions_aggregate::aggregate::AggregateExprBuilder; +use datafusion_functions_aggregate::approx_percentile_cont::approx_percentile_cont_udaf; use datafusion_functions_aggregate::min_max::max_udaf; use prost::Message; @@ -412,6 +413,34 @@ fn rountrip_aggregate_with_limit() -> Result<()> { roundtrip_test(Arc::new(agg)) } +#[test] +fn rountrip_aggregate_with_sort() -> Result<()> { + let field_a = Field::new("a", DataType::Int64, false); + let field_b = Field::new("b", DataType::Int64, false); + let schema = Arc::new(Schema::new(vec![field_a, field_b])); + + let groups: Vec<(Arc, String)> = + vec![(col("a", &schema)?, "unused".to_string())]; + + let aggregates: Vec> = vec![AggregateExprBuilder::new( + approx_percentile_cont_udaf(), + vec![col("b", &schema)?, lit(0.5)], + ) + .schema(Arc::clone(&schema)) + .alias("APPROX_PERCENTILE_CONT(b, 0.5)") + .build()?]; + + let agg = AggregateExec::try_new( + AggregateMode::Final, + PhysicalGroupBy::new_single(groups.clone()), + aggregates.clone(), + vec![None], + Arc::new(EmptyExec::new(schema.clone())), + schema, + )?; + roundtrip_test(Arc::new(agg)) +} + #[test] fn roundtrip_aggregate_udaf() -> Result<()> { let field_a = Field::new("a", DataType::Int64, false); From 89dbe8c858c034170bf7a080ae1ddfe2f42182e5 Mon Sep 17 00:00:00 2001 From: Huaijin Date: Sat, 10 Aug 2024 12:20:59 +0800 Subject: [PATCH 2/2] add more test case --- .../core/tests/fuzz_cases/window_fuzz.rs | 3 +- .../src/windows/bounded_window_agg_exec.rs | 6 +-- .../tests/cases/roundtrip_physical_plan.rs | 39 ++++++++++++++++++- 3 files changed, 40 insertions(+), 8 deletions(-) diff --git a/datafusion/core/tests/fuzz_cases/window_fuzz.rs b/datafusion/core/tests/fuzz_cases/window_fuzz.rs index dac9d0e67b7c..d75d8e43370d 100644 --- a/datafusion/core/tests/fuzz_cases/window_fuzz.rs +++ b/datafusion/core/tests/fuzz_cases/window_fuzz.rs @@ -253,7 +253,6 @@ async fn bounded_window_causal_non_causal() -> Result<()> { let partitionby_exprs = vec![]; let orderby_exprs = vec![]; - let logical_exprs = vec![]; // Window frame starts with "UNBOUNDED PRECEDING": let start_bound = WindowFrameBound::Preceding(ScalarValue::UInt64(None)); @@ -284,7 +283,7 @@ async fn bounded_window_causal_non_causal() -> Result<()> { let window_expr = create_window_expr( &window_fn, fn_name.to_string(), - &logical_exprs, + &args, &partitionby_exprs, &orderby_exprs, Arc::new(window_frame), diff --git a/datafusion/physical-plan/src/windows/bounded_window_agg_exec.rs b/datafusion/physical-plan/src/windows/bounded_window_agg_exec.rs index 6311107f7b58..29ead35895fe 100644 --- a/datafusion/physical-plan/src/windows/bounded_window_agg_exec.rs +++ b/datafusion/physical-plan/src/windows/bounded_window_agg_exec.rs @@ -1196,7 +1196,7 @@ mod tests { RecordBatchStream, SendableRecordBatchStream, TaskContext, }; use datafusion_expr::{ - Expr, WindowFrame, WindowFrameBound, WindowFrameUnits, WindowFunctionDefinition, + WindowFrame, WindowFrameBound, WindowFrameUnits, WindowFunctionDefinition, }; use datafusion_functions_aggregate::count::count_udaf; use datafusion_physical_expr::expressions::{col, Column, NthValue}; @@ -1303,10 +1303,7 @@ mod tests { let window_fn = WindowFunctionDefinition::AggregateUDF(count_udaf()); let col_expr = Arc::new(Column::new(schema.fields[0].name(), 0)) as Arc; - let log_expr = - Expr::Column(datafusion_common::Column::from(schema.fields[0].name())); let args = vec![col_expr]; - let log_args = vec![log_expr]; let partitionby_exprs = vec![col(hash, &schema)?]; let orderby_exprs = vec![PhysicalSortExpr { expr: col(order_by, &schema)?, @@ -1327,7 +1324,6 @@ mod tests { &window_fn, fn_name, &args, - &log_args, &partitionby_exprs, &orderby_exprs, Arc::new(window_frame.clone()), diff --git a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs index f66160df904d..6766468ef443 100644 --- a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs +++ b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs @@ -26,6 +26,7 @@ use arrow::array::RecordBatch; use arrow::csv::WriterBuilder; use datafusion::physical_expr_functions_aggregate::aggregate::AggregateExprBuilder; use datafusion_functions_aggregate::approx_percentile_cont::approx_percentile_cont_udaf; +use datafusion_functions_aggregate::array_agg::array_agg_udaf; use datafusion_functions_aggregate::min_max::max_udaf; use prost::Message; @@ -414,7 +415,7 @@ fn rountrip_aggregate_with_limit() -> Result<()> { } #[test] -fn rountrip_aggregate_with_sort() -> Result<()> { +fn rountrip_aggregate_with_approx_pencentile_cont() -> Result<()> { let field_a = Field::new("a", DataType::Int64, false); let field_b = Field::new("b", DataType::Int64, false); let schema = Arc::new(Schema::new(vec![field_a, field_b])); @@ -441,6 +442,42 @@ fn rountrip_aggregate_with_sort() -> Result<()> { roundtrip_test(Arc::new(agg)) } +#[test] +fn rountrip_aggregate_with_sort() -> Result<()> { + let field_a = Field::new("a", DataType::Int64, false); + let field_b = Field::new("b", DataType::Int64, false); + let schema = Arc::new(Schema::new(vec![field_a, field_b])); + + let groups: Vec<(Arc, String)> = + vec![(col("a", &schema)?, "unused".to_string())]; + let sort_exprs = vec![PhysicalSortExpr { + expr: col("b", &schema)?, + options: SortOptions { + descending: false, + nulls_first: true, + }, + }]; + + let aggregates: Vec> = + vec![ + AggregateExprBuilder::new(array_agg_udaf(), vec![col("b", &schema)?]) + .schema(Arc::clone(&schema)) + .alias("ARRAY_AGG(b)") + .order_by(sort_exprs) + .build()?, + ]; + + let agg = AggregateExec::try_new( + AggregateMode::Final, + PhysicalGroupBy::new_single(groups.clone()), + aggregates.clone(), + vec![None], + Arc::new(EmptyExec::new(schema.clone())), + schema, + )?; + roundtrip_test(Arc::new(agg)) +} + #[test] fn roundtrip_aggregate_udaf() -> Result<()> { let field_a = Field::new("a", DataType::Int64, false);