diff --git a/datafusion-examples/examples/advanced_udaf.rs b/datafusion-examples/examples/advanced_udaf.rs index 9c29e6b40d10..819fcb7dd225 100644 --- a/datafusion-examples/examples/advanced_udaf.rs +++ b/datafusion-examples/examples/advanced_udaf.rs @@ -94,7 +94,7 @@ impl AggregateUDFImpl for GeoMeanUdaf { /// This is the description of the state. accumulator's state() must match the types here. fn state_fields(&self, args: StateFieldsArgs) -> Result> { Ok(vec![ - Field::new("prod", args.return_type.clone(), true), + Field::new("prod", args.return_field.data_type().clone(), true), Field::new("n", DataType::UInt32, true), ]) } diff --git a/datafusion/core/tests/fuzz_cases/window_fuzz.rs b/datafusion/core/tests/fuzz_cases/window_fuzz.rs index 08a4b4fcb13b..3b665a337019 100644 --- a/datafusion/core/tests/fuzz_cases/window_fuzz.rs +++ b/datafusion/core/tests/fuzz_cases/window_fuzz.rs @@ -35,7 +35,7 @@ use datafusion::prelude::{SessionConfig, SessionContext}; use datafusion_common::HashMap; use datafusion_common::{Result, ScalarValue}; use datafusion_common_runtime::SpawnedTask; -use datafusion_expr::type_coercion::functions::data_types_with_aggregate_udf; +use datafusion_expr::type_coercion::functions::fields_with_aggregate_udf; use datafusion_expr::{ WindowFrame, WindowFrameBound, WindowFrameUnits, WindowFunctionDefinition, }; @@ -448,9 +448,9 @@ fn get_random_function( if !args.is_empty() { // Do type coercion first argument let a = args[0].clone(); - let dt = a.data_type(schema.as_ref()).unwrap(); - let coerced = data_types_with_aggregate_udf(&[dt], udf).unwrap(); - args[0] = cast(a, schema, coerced[0].clone()).unwrap(); + let dt = a.return_field(schema.as_ref()).unwrap(); + let coerced = fields_with_aggregate_udf(&[dt], udf).unwrap(); + args[0] = cast(a, schema, coerced[0].data_type().clone()).unwrap(); } } diff --git a/datafusion/core/tests/user_defined/user_defined_aggregates.rs b/datafusion/core/tests/user_defined/user_defined_aggregates.rs index 5cbb05f290a7..203fb6e85237 100644 --- a/datafusion/core/tests/user_defined/user_defined_aggregates.rs +++ b/datafusion/core/tests/user_defined/user_defined_aggregates.rs @@ -18,6 +18,8 @@ //! This module contains end to end demonstrations of creating //! user defined aggregate functions +use std::any::Any; +use std::collections::HashMap; use std::hash::{DefaultHasher, Hash, Hasher}; use std::mem::{size_of, size_of_val}; use std::sync::{ @@ -26,10 +28,10 @@ use std::sync::{ }; use arrow::array::{ - types::UInt64Type, AsArray, Int32Array, PrimitiveArray, StringArray, StructArray, + record_batch, types::UInt64Type, Array, AsArray, Int32Array, PrimitiveArray, + StringArray, StructArray, UInt64Array, }; use arrow::datatypes::{Fields, Schema}; - use datafusion::common::test_util::batches_to_string; use datafusion::dataframe::DataFrame; use datafusion::datasource::MemTable; @@ -48,11 +50,12 @@ use datafusion::{ prelude::SessionContext, scalar::ScalarValue, }; -use datafusion_common::assert_contains; +use datafusion_common::{assert_contains, exec_datafusion_err}; use datafusion_common::{cast::as_primitive_array, exec_err}; +use datafusion_expr::expr::WindowFunction; use datafusion_expr::{ - col, create_udaf, function::AccumulatorArgs, AggregateUDFImpl, GroupsAccumulator, - LogicalPlanBuilder, SimpleAggregateUDF, + col, create_udaf, function::AccumulatorArgs, AggregateUDFImpl, Expr, + GroupsAccumulator, LogicalPlanBuilder, SimpleAggregateUDF, WindowFunctionDefinition, }; use datafusion_functions_aggregate::average::AvgAccumulator; @@ -781,7 +784,7 @@ struct TestGroupsAccumulator { } impl AggregateUDFImpl for TestGroupsAccumulator { - fn as_any(&self) -> &dyn std::any::Any { + fn as_any(&self) -> &dyn Any { self } @@ -890,3 +893,263 @@ impl GroupsAccumulator for TestGroupsAccumulator { size_of::() } } + +#[derive(Debug)] +struct MetadataBasedAggregateUdf { + name: String, + signature: Signature, + metadata: HashMap, +} + +impl MetadataBasedAggregateUdf { + fn new(metadata: HashMap) -> Self { + // The name we return must be unique. Otherwise we will not call distinct + // instances of this UDF. This is a small hack for the unit tests to get unique + // names, but you could do something more elegant with the metadata. + let name = format!("metadata_based_udf_{}", metadata.len()); + Self { + name, + signature: Signature::exact(vec![DataType::UInt64], Volatility::Immutable), + metadata, + } + } +} + +impl AggregateUDFImpl for MetadataBasedAggregateUdf { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + &self.name + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> Result { + unimplemented!("this should never be called since return_field is implemented"); + } + + fn return_field(&self, _arg_fields: &[Field]) -> Result { + Ok(Field::new(self.name(), DataType::UInt64, true) + .with_metadata(self.metadata.clone())) + } + + fn accumulator(&self, acc_args: AccumulatorArgs) -> Result> { + let input_expr = acc_args + .exprs + .first() + .ok_or(exec_datafusion_err!("Expected one argument"))?; + let input_field = input_expr.return_field(acc_args.schema)?; + + let double_output = input_field + .metadata() + .get("modify_values") + .map(|v| v == "double_output") + .unwrap_or(false); + + Ok(Box::new(MetadataBasedAccumulator { + double_output, + curr_sum: 0, + })) + } +} + +#[derive(Debug)] +struct MetadataBasedAccumulator { + double_output: bool, + curr_sum: u64, +} + +impl Accumulator for MetadataBasedAccumulator { + fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { + let arr = values[0] + .as_any() + .downcast_ref::() + .ok_or(exec_datafusion_err!("Expected UInt64Array"))?; + + self.curr_sum = arr.iter().fold(self.curr_sum, |a, b| a + b.unwrap_or(0)); + + Ok(()) + } + + fn evaluate(&mut self) -> Result { + let v = match self.double_output { + true => self.curr_sum * 2, + false => self.curr_sum, + }; + + Ok(ScalarValue::from(v)) + } + + fn size(&self) -> usize { + 9 + } + + fn state(&mut self) -> Result> { + Ok(vec![ScalarValue::from(self.curr_sum)]) + } + + fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { + self.update_batch(states) + } +} + +#[tokio::test] +async fn test_metadata_based_aggregate() -> Result<()> { + let data_array = Arc::new(UInt64Array::from(vec![0, 5, 10, 15, 20])) as ArrayRef; + let schema = Arc::new(Schema::new(vec![ + Field::new("no_metadata", DataType::UInt64, true), + Field::new("with_metadata", DataType::UInt64, true).with_metadata( + [("modify_values".to_string(), "double_output".to_string())] + .into_iter() + .collect(), + ), + ])); + + let batch = RecordBatch::try_new( + schema, + vec![Arc::clone(&data_array), Arc::clone(&data_array)], + )?; + + let ctx = SessionContext::new(); + ctx.register_batch("t", batch)?; + let df = ctx.table("t").await?; + + let no_output_meta_udf = + AggregateUDF::from(MetadataBasedAggregateUdf::new(HashMap::new())); + let with_output_meta_udf = AggregateUDF::from(MetadataBasedAggregateUdf::new( + [("output_metatype".to_string(), "custom_value".to_string())] + .into_iter() + .collect(), + )); + + let df = df.aggregate( + vec![], + vec![ + no_output_meta_udf + .call(vec![col("no_metadata")]) + .alias("meta_no_in_no_out"), + no_output_meta_udf + .call(vec![col("with_metadata")]) + .alias("meta_with_in_no_out"), + with_output_meta_udf + .call(vec![col("no_metadata")]) + .alias("meta_no_in_with_out"), + with_output_meta_udf + .call(vec![col("with_metadata")]) + .alias("meta_with_in_with_out"), + ], + )?; + + let actual = df.collect().await?; + + // To test for output metadata handling, we set the expected values on the result + // To test for input metadata handling, we check the numbers returned + let mut output_meta = HashMap::new(); + let _ = output_meta.insert("output_metatype".to_string(), "custom_value".to_string()); + let expected_schema = Schema::new(vec![ + Field::new("meta_no_in_no_out", DataType::UInt64, true), + Field::new("meta_with_in_no_out", DataType::UInt64, true), + Field::new("meta_no_in_with_out", DataType::UInt64, true) + .with_metadata(output_meta.clone()), + Field::new("meta_with_in_with_out", DataType::UInt64, true) + .with_metadata(output_meta.clone()), + ]); + + let expected = record_batch!( + ("meta_no_in_no_out", UInt64, [50]), + ("meta_with_in_no_out", UInt64, [100]), + ("meta_no_in_with_out", UInt64, [50]), + ("meta_with_in_with_out", UInt64, [100]) + )? + .with_schema(Arc::new(expected_schema))?; + + assert_eq!(expected, actual[0]); + + Ok(()) +} + +#[tokio::test] +async fn test_metadata_based_aggregate_as_window() -> Result<()> { + let data_array = Arc::new(UInt64Array::from(vec![0, 5, 10, 15, 20])) as ArrayRef; + let schema = Arc::new(Schema::new(vec![ + Field::new("no_metadata", DataType::UInt64, true), + Field::new("with_metadata", DataType::UInt64, true).with_metadata( + [("modify_values".to_string(), "double_output".to_string())] + .into_iter() + .collect(), + ), + ])); + + let batch = RecordBatch::try_new( + schema, + vec![Arc::clone(&data_array), Arc::clone(&data_array)], + )?; + + let ctx = SessionContext::new(); + ctx.register_batch("t", batch)?; + let df = ctx.table("t").await?; + + let no_output_meta_udf = Arc::new(AggregateUDF::from( + MetadataBasedAggregateUdf::new(HashMap::new()), + )); + let with_output_meta_udf = + Arc::new(AggregateUDF::from(MetadataBasedAggregateUdf::new( + [("output_metatype".to_string(), "custom_value".to_string())] + .into_iter() + .collect(), + ))); + + let df = df.select(vec![ + Expr::WindowFunction(WindowFunction::new( + WindowFunctionDefinition::AggregateUDF(Arc::clone(&no_output_meta_udf)), + vec![col("no_metadata")], + )) + .alias("meta_no_in_no_out"), + Expr::WindowFunction(WindowFunction::new( + WindowFunctionDefinition::AggregateUDF(no_output_meta_udf), + vec![col("with_metadata")], + )) + .alias("meta_with_in_no_out"), + Expr::WindowFunction(WindowFunction::new( + WindowFunctionDefinition::AggregateUDF(Arc::clone(&with_output_meta_udf)), + vec![col("no_metadata")], + )) + .alias("meta_no_in_with_out"), + Expr::WindowFunction(WindowFunction::new( + WindowFunctionDefinition::AggregateUDF(with_output_meta_udf), + vec![col("with_metadata")], + )) + .alias("meta_with_in_with_out"), + ])?; + + let actual = df.collect().await?; + + // To test for output metadata handling, we set the expected values on the result + // To test for input metadata handling, we check the numbers returned + let mut output_meta = HashMap::new(); + let _ = output_meta.insert("output_metatype".to_string(), "custom_value".to_string()); + let expected_schema = Schema::new(vec![ + Field::new("meta_no_in_no_out", DataType::UInt64, true), + Field::new("meta_with_in_no_out", DataType::UInt64, true), + Field::new("meta_no_in_with_out", DataType::UInt64, true) + .with_metadata(output_meta.clone()), + Field::new("meta_with_in_with_out", DataType::UInt64, true) + .with_metadata(output_meta.clone()), + ]); + + let expected = record_batch!( + ("meta_no_in_no_out", UInt64, [50, 50, 50, 50, 50]), + ("meta_with_in_no_out", UInt64, [100, 100, 100, 100, 100]), + ("meta_no_in_with_out", UInt64, [50, 50, 50, 50, 50]), + ("meta_with_in_with_out", UInt64, [100, 100, 100, 100, 100]) + )? + .with_schema(Arc::new(expected_schema))?; + + assert_eq!(expected, actual[0]); + + Ok(()) +} diff --git a/datafusion/core/tests/user_defined/user_defined_window_functions.rs b/datafusion/core/tests/user_defined/user_defined_window_functions.rs index b3a3900c5023..6798c0d308de 100644 --- a/datafusion/core/tests/user_defined/user_defined_window_functions.rs +++ b/datafusion/core/tests/user_defined/user_defined_window_functions.rs @@ -18,11 +18,15 @@ //! This module contains end to end tests of creating //! user defined window functions -use arrow::array::{ArrayRef, AsArray, Int64Array, RecordBatch, StringArray}; +use arrow::array::{ + record_batch, Array, ArrayRef, AsArray, Int64Array, RecordBatch, StringArray, + UInt64Array, +}; use arrow::datatypes::{DataType, Field, Schema}; use datafusion::common::test_util::batches_to_string; use datafusion::common::{Result, ScalarValue}; use datafusion::prelude::SessionContext; +use datafusion_common::exec_datafusion_err; use datafusion_expr::{ PartitionEvaluator, Signature, TypeSignature, Volatility, WindowUDF, WindowUDFImpl, }; @@ -34,6 +38,7 @@ use datafusion_physical_expr::{ expressions::{col, lit}, PhysicalExpr, }; +use std::collections::HashMap; use std::{ any::Any, ops::Range, @@ -723,11 +728,11 @@ fn test_default_expressions() -> Result<()> { ]; for input_exprs in &test_cases { - let input_types = input_exprs + let input_fields = input_exprs .iter() - .map(|expr: &Arc| expr.data_type(&schema).unwrap()) + .map(|expr: &Arc| expr.return_field(&schema).unwrap()) .collect::>(); - let expr_args = ExpressionArgs::new(input_exprs, &input_types); + let expr_args = ExpressionArgs::new(input_exprs, &input_fields); let ret_exprs = udwf.expressions(expr_args); @@ -751,3 +756,148 @@ fn test_default_expressions() -> Result<()> { } Ok(()) } + +#[derive(Debug)] +struct MetadataBasedWindowUdf { + name: String, + signature: Signature, + metadata: HashMap, +} + +impl MetadataBasedWindowUdf { + fn new(metadata: HashMap) -> Self { + // The name we return must be unique. Otherwise we will not call distinct + // instances of this UDF. This is a small hack for the unit tests to get unique + // names, but you could do something more elegant with the metadata. + let name = format!("metadata_based_udf_{}", metadata.len()); + Self { + name, + signature: Signature::exact(vec![DataType::UInt64], Volatility::Immutable), + metadata, + } + } +} + +impl WindowUDFImpl for MetadataBasedWindowUdf { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + &self.name + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn partition_evaluator( + &self, + partition_evaluator_args: PartitionEvaluatorArgs, + ) -> Result> { + let input_field = partition_evaluator_args + .input_fields() + .first() + .ok_or(exec_datafusion_err!("Expected one argument"))?; + + let double_output = input_field + .metadata() + .get("modify_values") + .map(|v| v == "double_output") + .unwrap_or(false); + + Ok(Box::new(MetadataBasedPartitionEvaluator { double_output })) + } + + fn field(&self, field_args: WindowUDFFieldArgs) -> Result { + Ok(Field::new(field_args.name(), DataType::UInt64, true) + .with_metadata(self.metadata.clone())) + } +} + +#[derive(Debug)] +struct MetadataBasedPartitionEvaluator { + double_output: bool, +} + +impl PartitionEvaluator for MetadataBasedPartitionEvaluator { + fn evaluate_all(&mut self, values: &[ArrayRef], num_rows: usize) -> Result { + let values = values[0].as_any().downcast_ref::().unwrap(); + let sum = values.iter().fold(0_u64, |acc, v| acc + v.unwrap_or(0)); + + let result = if self.double_output { sum * 2 } else { sum }; + + Ok(Arc::new(UInt64Array::from_value(result, num_rows))) + } +} + +#[tokio::test] +async fn test_metadata_based_window_fn() -> Result<()> { + let data_array = Arc::new(UInt64Array::from(vec![0, 5, 10, 15, 20])) as ArrayRef; + let schema = Arc::new(Schema::new(vec![ + Field::new("no_metadata", DataType::UInt64, true), + Field::new("with_metadata", DataType::UInt64, true).with_metadata( + [("modify_values".to_string(), "double_output".to_string())] + .into_iter() + .collect(), + ), + ])); + + let batch = RecordBatch::try_new( + schema, + vec![Arc::clone(&data_array), Arc::clone(&data_array)], + )?; + + let ctx = SessionContext::new(); + ctx.register_batch("t", batch)?; + let df = ctx.table("t").await?; + + let no_output_meta_udf = WindowUDF::from(MetadataBasedWindowUdf::new(HashMap::new())); + let with_output_meta_udf = WindowUDF::from(MetadataBasedWindowUdf::new( + [("output_metatype".to_string(), "custom_value".to_string())] + .into_iter() + .collect(), + )); + + let df = df.select(vec![ + no_output_meta_udf + .call(vec![datafusion_expr::col("no_metadata")]) + .alias("meta_no_in_no_out"), + no_output_meta_udf + .call(vec![datafusion_expr::col("with_metadata")]) + .alias("meta_with_in_no_out"), + with_output_meta_udf + .call(vec![datafusion_expr::col("no_metadata")]) + .alias("meta_no_in_with_out"), + with_output_meta_udf + .call(vec![datafusion_expr::col("with_metadata")]) + .alias("meta_with_in_with_out"), + ])?; + + let actual = df.collect().await?; + + // To test for output metadata handling, we set the expected values on the result + // To test for input metadata handling, we check the numbers returned + let mut output_meta = HashMap::new(); + let _ = output_meta.insert("output_metatype".to_string(), "custom_value".to_string()); + let expected_schema = Schema::new(vec![ + Field::new("meta_no_in_no_out", DataType::UInt64, true), + Field::new("meta_with_in_no_out", DataType::UInt64, true), + Field::new("meta_no_in_with_out", DataType::UInt64, true) + .with_metadata(output_meta.clone()), + Field::new("meta_with_in_with_out", DataType::UInt64, true) + .with_metadata(output_meta.clone()), + ]); + + let expected = record_batch!( + ("meta_no_in_no_out", UInt64, [50, 50, 50, 50, 50]), + ("meta_with_in_no_out", UInt64, [100, 100, 100, 100, 100]), + ("meta_no_in_with_out", UInt64, [50, 50, 50, 50, 50]), + ("meta_with_in_with_out", UInt64, [100, 100, 100, 100, 100]) + )? + .with_schema(Arc::new(expected_schema))?; + + assert_eq!(expected, actual[0]); + + Ok(()) +} diff --git a/datafusion/expr-common/src/type_coercion/aggregates.rs b/datafusion/expr-common/src/type_coercion/aggregates.rs index 44839378d52c..7da4e938f5dd 100644 --- a/datafusion/expr-common/src/type_coercion/aggregates.rs +++ b/datafusion/expr-common/src/type_coercion/aggregates.rs @@ -17,7 +17,7 @@ use crate::signature::TypeSignature; use arrow::datatypes::{ - DataType, TimeUnit, DECIMAL128_MAX_PRECISION, DECIMAL128_MAX_SCALE, + DataType, Field, TimeUnit, DECIMAL128_MAX_PRECISION, DECIMAL128_MAX_SCALE, DECIMAL256_MAX_PRECISION, DECIMAL256_MAX_SCALE, }; @@ -82,48 +82,48 @@ pub static TIMES: &[DataType] = &[ DataType::Time64(TimeUnit::Nanosecond), ]; -/// Validate the length of `input_types` matches the `signature` for `agg_fun`. +/// Validate the length of `input_fields` matches the `signature` for `agg_fun`. /// -/// This method DOES NOT validate the argument types - only that (at least one, +/// This method DOES NOT validate the argument fields - only that (at least one, /// in the case of [`TypeSignature::OneOf`]) signature matches the desired /// number of input types. pub fn check_arg_count( func_name: &str, - input_types: &[DataType], + input_fields: &[Field], signature: &TypeSignature, ) -> Result<()> { match signature { TypeSignature::Uniform(agg_count, _) | TypeSignature::Any(agg_count) => { - if input_types.len() != *agg_count { + if input_fields.len() != *agg_count { return plan_err!( "The function {func_name} expects {:?} arguments, but {:?} were provided", agg_count, - input_types.len() + input_fields.len() ); } } TypeSignature::Exact(types) => { - if types.len() != input_types.len() { + if types.len() != input_fields.len() { return plan_err!( "The function {func_name} expects {:?} arguments, but {:?} were provided", types.len(), - input_types.len() + input_fields.len() ); } } TypeSignature::OneOf(variants) => { let ok = variants .iter() - .any(|v| check_arg_count(func_name, input_types, v).is_ok()); + .any(|v| check_arg_count(func_name, input_fields, v).is_ok()); if !ok { return plan_err!( "The function {func_name} does not accept {:?} function arguments.", - input_types.len() + input_fields.len() ); } } TypeSignature::VariadicAny => { - if input_types.is_empty() { + if input_fields.is_empty() { return plan_err!( "The function {func_name} expects at least one argument" ); diff --git a/datafusion/expr/src/expr.rs b/datafusion/expr/src/expr.rs index 92d4497918fa..a081a5430d40 100644 --- a/datafusion/expr/src/expr.rs +++ b/datafusion/expr/src/expr.rs @@ -28,7 +28,7 @@ use crate::logical_plan::Subquery; use crate::Volatility; use crate::{udaf, ExprSchemable, Operator, Signature, WindowFrame, WindowUDF}; -use arrow::datatypes::{DataType, FieldRef}; +use arrow::datatypes::{DataType, Field, FieldRef}; use datafusion_common::cse::{HashNode, NormalizeEq, Normalizeable}; use datafusion_common::tree_node::{ Transformed, TransformedResult, TreeNode, TreeNodeContainer, TreeNodeRecursion, @@ -844,19 +844,19 @@ pub enum WindowFunctionDefinition { impl WindowFunctionDefinition { /// Returns the datatype of the window function - pub fn return_type( + pub fn return_field( &self, - input_expr_types: &[DataType], + input_expr_fields: &[Field], _input_expr_nullable: &[bool], display_name: &str, - ) -> Result { + ) -> Result { match self { WindowFunctionDefinition::AggregateUDF(fun) => { - fun.return_type(input_expr_types) + fun.return_field(input_expr_fields) + } + WindowFunctionDefinition::WindowUDF(fun) => { + fun.field(WindowUDFFieldArgs::new(input_expr_fields, display_name)) } - WindowFunctionDefinition::WindowUDF(fun) => fun - .field(WindowUDFFieldArgs::new(input_expr_types, display_name)) - .map(|field| field.data_type().clone()), } } diff --git a/datafusion/expr/src/expr_schema.rs b/datafusion/expr/src/expr_schema.rs index 3786180e2cfa..6022182bfe67 100644 --- a/datafusion/expr/src/expr_schema.rs +++ b/datafusion/expr/src/expr_schema.rs @@ -22,7 +22,7 @@ use crate::expr::{ WindowFunctionParams, }; use crate::type_coercion::functions::{ - data_types_with_aggregate_udf, data_types_with_scalar_udf, data_types_with_window_udf, + data_types_with_scalar_udf, fields_with_aggregate_udf, fields_with_window_udf, }; use crate::udf::ReturnFieldArgs; use crate::{utils, LogicalPlan, Projection, Subquery, WindowFunctionDefinition}; @@ -158,12 +158,16 @@ impl ExprSchemable for Expr { func, params: AggregateFunctionParams { args, .. }, }) => { - let data_types = args + let fields = args .iter() - .map(|e| e.get_type(schema)) + .map(|e| e.to_field(schema).map(|(_, f)| f.as_ref().clone())) .collect::>>()?; - let new_types = data_types_with_aggregate_udf(&data_types, func) - .map_err(|err| { + let new_fields = + fields_with_aggregate_udf(&fields, func).map_err(|err| { + let data_types = fields + .iter() + .map(|f| f.data_type().clone()) + .collect::>(); plan_datafusion_err!( "{} {}", match err { @@ -177,7 +181,7 @@ impl ExprSchemable for Expr { ) ) })?; - Ok(func.return_type(&new_types)?) + Ok(func.return_field(&new_fields)?.data_type().clone()) } Expr::Not(_) | Expr::IsNull(_) @@ -452,6 +456,41 @@ impl ExprSchemable for Expr { )?; Ok(Field::new(&schema_name, dt, nullable)) } + Expr::AggregateFunction(aggregate_function) => { + let AggregateFunction { + func, + params: AggregateFunctionParams { args, .. }, + .. + } = aggregate_function; + + let fields = args + .iter() + .map(|e| e.to_field(schema).map(|(_, f)| f.as_ref().clone())) + .collect::>>()?; + // Verify that function is invoked with correct number and type of arguments as defined in `TypeSignature` + let new_fields = + fields_with_aggregate_udf(&fields, func).map_err(|err| { + let arg_types = fields + .iter() + .map(|f| f.data_type()) + .cloned() + .collect::>(); + plan_datafusion_err!( + "{} {}", + match err { + DataFusionError::Plan(msg) => msg, + err => err.to_string(), + }, + utils::generate_signature_error_msg( + func.name(), + func.signature().clone(), + &arg_types, + ) + ) + })?; + + func.return_field(&new_fields) + } Expr::ScalarFunction(ScalarFunction { func, args }) => { let (arg_types, fields): (Vec, Vec>) = args .iter() @@ -506,7 +545,6 @@ impl ExprSchemable for Expr { | Expr::Between(_) | Expr::Case(_) | Expr::TryCast(_) - | Expr::AggregateFunction(_) | Expr::InList(_) | Expr::InSubquery(_) | Expr::Wildcard { .. } @@ -572,14 +610,19 @@ impl Expr { .. } = window_function; - let data_types = args + let fields = args .iter() - .map(|e| e.get_type(schema)) + .map(|e| e.to_field(schema).map(|(_, f)| f.as_ref().clone())) .collect::>>()?; match fun { WindowFunctionDefinition::AggregateUDF(udaf) => { - let new_types = data_types_with_aggregate_udf(&data_types, udaf) - .map_err(|err| { + let data_types = fields + .iter() + .map(|f| f.data_type()) + .cloned() + .collect::>(); + let new_fields = + fields_with_aggregate_udf(&fields, udaf).map_err(|err| { plan_datafusion_err!( "{} {}", match err { @@ -594,14 +637,18 @@ impl Expr { ) })?; - let return_type = udaf.return_type(&new_types)?; - let nullable = udaf.is_nullable(); + let return_field = udaf.return_field(&new_fields)?; - Ok((return_type, nullable)) + Ok((return_field.data_type().clone(), return_field.is_nullable())) } WindowFunctionDefinition::WindowUDF(udwf) => { - let new_types = - data_types_with_window_udf(&data_types, udwf).map_err(|err| { + let data_types = fields + .iter() + .map(|f| f.data_type()) + .cloned() + .collect::>(); + let new_fields = + fields_with_window_udf(&fields, udwf).map_err(|err| { plan_datafusion_err!( "{} {}", match err { @@ -616,7 +663,7 @@ impl Expr { ) })?; let (_, function_name) = self.qualified_name(); - let field_args = WindowUDFFieldArgs::new(&new_types, &function_name); + let field_args = WindowUDFFieldArgs::new(&new_fields, &function_name); udwf.field(field_args) .map(|field| (field.data_type().clone(), field.is_nullable())) diff --git a/datafusion/expr/src/type_coercion/functions.rs b/datafusion/expr/src/type_coercion/functions.rs index 57be9e3debe6..6d1ed238646d 100644 --- a/datafusion/expr/src/type_coercion/functions.rs +++ b/datafusion/expr/src/type_coercion/functions.rs @@ -17,6 +17,7 @@ use super::binary::binary_numeric_coercion; use crate::{AggregateUDF, ScalarUDF, Signature, TypeSignature, WindowUDF}; +use arrow::datatypes::Field; use arrow::{ compute::can_cast_types, datatypes::{DataType, TimeUnit}, @@ -77,19 +78,19 @@ pub fn data_types_with_scalar_udf( /// Performs type coercion for aggregate function arguments. /// -/// Returns the data types to which each argument must be coerced to +/// Returns the fields to which each argument must be coerced to /// match `signature`. /// /// For more details on coercion in general, please see the /// [`type_coercion`](crate::type_coercion) module. -pub fn data_types_with_aggregate_udf( - current_types: &[DataType], +pub fn fields_with_aggregate_udf( + current_fields: &[Field], func: &AggregateUDF, -) -> Result> { +) -> Result> { let signature = func.signature(); let type_signature = &signature.type_signature; - if current_types.is_empty() && type_signature != &TypeSignature::UserDefined { + if current_fields.is_empty() && type_signature != &TypeSignature::UserDefined { if type_signature.supports_zero_argument() { return Ok(vec![]); } else if type_signature.used_to_support_zero_arguments() { @@ -99,17 +100,29 @@ pub fn data_types_with_aggregate_udf( return plan_err!("'{}' does not support zero arguments", func.name()); } } + let current_types = current_fields + .iter() + .map(|f| f.data_type()) + .cloned() + .collect::>(); let valid_types = - get_valid_types_with_aggregate_udf(type_signature, current_types, func)?; + get_valid_types_with_aggregate_udf(type_signature, ¤t_types, func)?; if valid_types .iter() - .any(|data_type| data_type == current_types) + .any(|data_type| data_type == ¤t_types) { - return Ok(current_types.to_vec()); + return Ok(current_fields.to_vec()); } - try_coerce_types(func.name(), valid_types, current_types, type_signature) + let updated_types = + try_coerce_types(func.name(), valid_types, ¤t_types, type_signature)?; + + Ok(current_fields + .iter() + .zip(updated_types) + .map(|(current_field, new_type)| current_field.clone().with_data_type(new_type)) + .collect()) } /// Performs type coercion for window function arguments. @@ -119,14 +132,14 @@ pub fn data_types_with_aggregate_udf( /// /// For more details on coercion in general, please see the /// [`type_coercion`](crate::type_coercion) module. -pub fn data_types_with_window_udf( - current_types: &[DataType], +pub fn fields_with_window_udf( + current_fields: &[Field], func: &WindowUDF, -) -> Result> { +) -> Result> { let signature = func.signature(); let type_signature = &signature.type_signature; - if current_types.is_empty() && type_signature != &TypeSignature::UserDefined { + if current_fields.is_empty() && type_signature != &TypeSignature::UserDefined { if type_signature.supports_zero_argument() { return Ok(vec![]); } else if type_signature.used_to_support_zero_arguments() { @@ -137,16 +150,28 @@ pub fn data_types_with_window_udf( } } + let current_types = current_fields + .iter() + .map(|f| f.data_type()) + .cloned() + .collect::>(); let valid_types = - get_valid_types_with_window_udf(type_signature, current_types, func)?; + get_valid_types_with_window_udf(type_signature, ¤t_types, func)?; if valid_types .iter() - .any(|data_type| data_type == current_types) + .any(|data_type| data_type == ¤t_types) { - return Ok(current_types.to_vec()); + return Ok(current_fields.to_vec()); } - try_coerce_types(func.name(), valid_types, current_types, type_signature) + let updated_types = + try_coerce_types(func.name(), valid_types, ¤t_types, type_signature)?; + + Ok(current_fields + .iter() + .zip(updated_types) + .map(|(current_field, new_type)| current_field.clone().with_data_type(new_type)) + .collect()) } /// Performs type coercion for function arguments. diff --git a/datafusion/expr/src/udaf.rs b/datafusion/expr/src/udaf.rs index 9ab532240d76..3a8d0253a389 100644 --- a/datafusion/expr/src/udaf.rs +++ b/datafusion/expr/src/udaf.rs @@ -224,6 +224,13 @@ impl AggregateUDF { self.inner.return_type(args) } + /// Return the field of the function given its input fields + /// + /// See [`AggregateUDFImpl::return_field`] for more details. + pub fn return_field(&self, args: &[Field]) -> Result { + self.inner.return_field(args) + } + /// Return an accumulator the given aggregate, given its return datatype pub fn accumulator(&self, acc_args: AccumulatorArgs) -> Result> { self.inner.accumulator(acc_args) @@ -403,7 +410,7 @@ where /// fn accumulator(&self, _acc_args: AccumulatorArgs) -> Result> { unimplemented!() } /// fn state_fields(&self, args: StateFieldsArgs) -> Result> { /// Ok(vec![ -/// Field::new("value", args.return_type.clone(), true), +/// args.return_field.clone().with_name("value"), /// Field::new("ordering", DataType::UInt32, true) /// ]) /// } @@ -674,6 +681,31 @@ pub trait AggregateUDFImpl: Debug + Send + Sync { /// the arguments fn return_type(&self, arg_types: &[DataType]) -> Result; + /// What type will be returned by this function, given the arguments? + /// + /// By default, this function calls [`Self::return_type`] with the + /// types of each argument. + /// + /// # Notes + /// + /// Most UDFs should implement [`Self::return_type`] and not this + /// function as the output type for most functions only depends on the types + /// of their inputs (e.g. `sum(f64)` is always `f64`). + /// + /// This function can be used for more advanced cases such as: + /// + /// 1. specifying nullability + /// 2. return types based on the **values** of the arguments (rather than + /// their **types**. + /// 3. return types based on metadata within the fields of the inputs + fn return_field(&self, arg_fields: &[Field]) -> Result { + let arg_types: Vec<_> = + arg_fields.iter().map(|f| f.data_type()).cloned().collect(); + let data_type = self.return_type(&arg_types)?; + + Ok(Field::new(self.name(), data_type, self.is_nullable())) + } + /// Whether the aggregate function is nullable. /// /// Nullable means that the function could return `null` for any inputs. @@ -713,11 +745,10 @@ pub trait AggregateUDFImpl: Debug + Send + Sync { /// be derived from `name`. See [`format_state_name`] for a utility function /// to generate a unique name. fn state_fields(&self, args: StateFieldsArgs) -> Result> { - let fields = vec![Field::new( - format_state_name(args.name, "value"), - args.return_type.clone(), - true, - )]; + let fields = vec![args + .return_field + .clone() + .with_name(format_state_name(args.name, "value"))]; Ok(fields .into_iter() diff --git a/datafusion/expr/src/udwf.rs b/datafusion/expr/src/udwf.rs index 4da63d7955f5..a52438fcc99c 100644 --- a/datafusion/expr/src/udwf.rs +++ b/datafusion/expr/src/udwf.rs @@ -280,7 +280,7 @@ where /// unimplemented!() /// } /// fn field(&self, field_args: WindowUDFFieldArgs) -> Result { -/// if let Some(DataType::Int32) = field_args.get_input_type(0) { +/// if let Some(DataType::Int32) = field_args.get_input_field(0).map(|f| f.data_type().clone()) { /// Ok(Field::new(field_args.name(), DataType::Int32, false)) /// } else { /// plan_err!("smooth_it only accepts Int32 arguments") diff --git a/datafusion/functions-aggregate-common/src/accumulator.rs b/datafusion/functions-aggregate-common/src/accumulator.rs index a230bb028909..f67e2f49dcbf 100644 --- a/datafusion/functions-aggregate-common/src/accumulator.rs +++ b/datafusion/functions-aggregate-common/src/accumulator.rs @@ -15,7 +15,7 @@ // specific language governing permissions and limitations // under the License. -use arrow::datatypes::{DataType, Field, Schema}; +use arrow::datatypes::{Field, Schema}; use datafusion_common::Result; use datafusion_expr_common::accumulator::Accumulator; use datafusion_physical_expr_common::physical_expr::PhysicalExpr; @@ -27,8 +27,8 @@ use std::sync::Arc; /// ordering expressions. #[derive(Debug)] pub struct AccumulatorArgs<'a> { - /// The return type of the aggregate function. - pub return_type: &'a DataType, + /// The return field of the aggregate function. + pub return_field: &'a Field, /// The schema of the input arguments pub schema: &'a Schema, @@ -81,11 +81,11 @@ pub struct StateFieldsArgs<'a> { /// The name of the aggregate function. pub name: &'a str, - /// The input types of the aggregate function. - pub input_types: &'a [DataType], + /// The input fields of the aggregate function. + pub input_fields: &'a [Field], - /// The return type of the aggregate function. - pub return_type: &'a DataType, + /// The return fields of the aggregate function. + pub return_field: &'a Field, /// The ordering fields of the aggregate function. pub ordering_fields: &'a [Field], diff --git a/datafusion/functions-aggregate/benches/count.rs b/datafusion/functions-aggregate/benches/count.rs index 8bde7d04c44d..fc7561dd8a56 100644 --- a/datafusion/functions-aggregate/benches/count.rs +++ b/datafusion/functions-aggregate/benches/count.rs @@ -28,7 +28,7 @@ use std::sync::Arc; fn prepare_accumulator() -> Box { let schema = Arc::new(Schema::new(vec![Field::new("f", DataType::Int32, true)])); let accumulator_args = AccumulatorArgs { - return_type: &DataType::Int64, + return_field: &Field::new("f", DataType::Int64, true), schema: &schema, ignore_nulls: false, ordering_req: &LexOrdering::default(), diff --git a/datafusion/functions-aggregate/benches/sum.rs b/datafusion/functions-aggregate/benches/sum.rs index fab53ae94b25..d05d5c5676c5 100644 --- a/datafusion/functions-aggregate/benches/sum.rs +++ b/datafusion/functions-aggregate/benches/sum.rs @@ -26,9 +26,10 @@ use datafusion_physical_expr_common::sort_expr::LexOrdering; use std::sync::Arc; fn prepare_accumulator(data_type: &DataType) -> Box { - let schema = Arc::new(Schema::new(vec![Field::new("f", data_type.clone(), true)])); + let field = Field::new("f", data_type.clone(), true); + let schema = Arc::new(Schema::new(vec![field.clone()])); let accumulator_args = AccumulatorArgs { - return_type: data_type, + return_field: &field, schema: &schema, ignore_nulls: false, ordering_req: &LexOrdering::default(), diff --git a/datafusion/functions-aggregate/src/array_agg.rs b/datafusion/functions-aggregate/src/array_agg.rs index ca3548d424a1..54479ee99fc3 100644 --- a/datafusion/functions-aggregate/src/array_agg.rs +++ b/datafusion/functions-aggregate/src/array_agg.rs @@ -114,7 +114,7 @@ impl AggregateUDFImpl for ArrayAgg { return Ok(vec![Field::new_list( format_state_name(args.name, "distinct_array_agg"), // See COMMENTS.md to understand why nullable is set to true - Field::new_list_field(args.input_types[0].clone(), true), + Field::new_list_field(args.input_fields[0].data_type().clone(), true), true, )]); } @@ -122,7 +122,7 @@ impl AggregateUDFImpl for ArrayAgg { let mut fields = vec![Field::new_list( format_state_name(args.name, "array_agg"), // See COMMENTS.md to understand why nullable is set to true - Field::new_list_field(args.input_types[0].clone(), true), + Field::new_list_field(args.input_fields[0].data_type().clone(), true), true, )]; @@ -984,7 +984,7 @@ mod tests { } struct ArrayAggAccumulatorBuilder { - data_type: DataType, + return_field: Field, distinct: bool, ordering: LexOrdering, schema: Schema, @@ -997,7 +997,7 @@ mod tests { fn new(data_type: DataType) -> Self { Self { - data_type: data_type.clone(), + return_field: Field::new("f", data_type.clone(), true), distinct: false, ordering: Default::default(), schema: Schema { @@ -1029,7 +1029,7 @@ mod tests { fn build(&self) -> Result> { ArrayAgg::default().accumulator(AccumulatorArgs { - return_type: &self.data_type, + return_field: &self.return_field, schema: &self.schema, ignore_nulls: false, ordering_req: &self.ordering, diff --git a/datafusion/functions-aggregate/src/average.rs b/datafusion/functions-aggregate/src/average.rs index 798a039f50b1..15b5db2d72e0 100644 --- a/datafusion/functions-aggregate/src/average.rs +++ b/datafusion/functions-aggregate/src/average.rs @@ -121,7 +121,7 @@ impl AggregateUDFImpl for Avg { let data_type = acc_args.exprs[0].data_type(acc_args.schema)?; // instantiate specialized accumulator based for the type - match (&data_type, acc_args.return_type) { + match (&data_type, acc_args.return_field.data_type()) { (Float64, Float64) => Ok(Box::::default()), ( Decimal128(sum_precision, sum_scale), @@ -159,7 +159,7 @@ impl AggregateUDFImpl for Avg { _ => exec_err!( "AvgAccumulator for ({} --> {})", &data_type, - acc_args.return_type + acc_args.return_field.data_type() ), } } @@ -173,7 +173,7 @@ impl AggregateUDFImpl for Avg { ), Field::new( format_state_name(args.name, "sum"), - args.input_types[0].clone(), + args.input_fields[0].data_type().clone(), true, ), ]) @@ -181,7 +181,7 @@ impl AggregateUDFImpl for Avg { fn groups_accumulator_supported(&self, args: AccumulatorArgs) -> bool { matches!( - args.return_type, + args.return_field.data_type(), DataType::Float64 | DataType::Decimal128(_, _) ) } @@ -194,11 +194,11 @@ impl AggregateUDFImpl for Avg { let data_type = args.exprs[0].data_type(args.schema)?; // instantiate specialized accumulator based for the type - match (&data_type, args.return_type) { + match (&data_type, args.return_field.data_type()) { (Float64, Float64) => { Ok(Box::new(AvgGroupsAccumulator::::new( &data_type, - args.return_type, + args.return_field.data_type(), |sum: f64, count: u64| Ok(sum / count as f64), ))) } @@ -217,7 +217,7 @@ impl AggregateUDFImpl for Avg { Ok(Box::new(AvgGroupsAccumulator::::new( &data_type, - args.return_type, + args.return_field.data_type(), avg_fn, ))) } @@ -238,7 +238,7 @@ impl AggregateUDFImpl for Avg { Ok(Box::new(AvgGroupsAccumulator::::new( &data_type, - args.return_type, + args.return_field.data_type(), avg_fn, ))) } @@ -246,7 +246,7 @@ impl AggregateUDFImpl for Avg { _ => not_impl_err!( "AvgGroupsAccumulator for ({} --> {})", &data_type, - args.return_type + args.return_field.data_type() ), } } diff --git a/datafusion/functions-aggregate/src/bit_and_or_xor.rs b/datafusion/functions-aggregate/src/bit_and_or_xor.rs index 1a9312ba1e92..5f51377484a0 100644 --- a/datafusion/functions-aggregate/src/bit_and_or_xor.rs +++ b/datafusion/functions-aggregate/src/bit_and_or_xor.rs @@ -87,7 +87,7 @@ macro_rules! accumulator_helper { /// `is_distinct` is boolean value indicating whether the operation is distinct or not. macro_rules! downcast_bitwise_accumulator { ($args:ident, $opr:expr, $is_distinct: expr) => { - match $args.return_type { + match $args.return_field.data_type() { DataType::Int8 => accumulator_helper!(Int8Type, $opr, $is_distinct), DataType::Int16 => accumulator_helper!(Int16Type, $opr, $is_distinct), DataType::Int32 => accumulator_helper!(Int32Type, $opr, $is_distinct), @@ -101,7 +101,7 @@ macro_rules! downcast_bitwise_accumulator { "{} not supported for {}: {}", stringify!($opr), $args.name, - $args.return_type + $args.return_field.data_type() ) } } @@ -271,13 +271,13 @@ impl AggregateUDFImpl for BitwiseOperation { format!("{} distinct", self.name()).as_str(), ), // See COMMENTS.md to understand why nullable is set to true - Field::new_list_field(args.return_type.clone(), true), + Field::new_list_field(args.return_field.data_type().clone(), true), false, )]) } else { Ok(vec![Field::new( format_state_name(args.name, self.name()), - args.return_type.clone(), + args.return_field.data_type().clone(), true, )]) } @@ -291,7 +291,7 @@ impl AggregateUDFImpl for BitwiseOperation { &self, args: AccumulatorArgs, ) -> Result> { - let data_type = args.return_type; + let data_type = args.return_field.data_type(); let operation = &self.operation; downcast_integer! { data_type => (group_accumulator_helper, data_type, operation), diff --git a/datafusion/functions-aggregate/src/bool_and_or.rs b/datafusion/functions-aggregate/src/bool_and_or.rs index 1b33a7900c00..034a28c27bb7 100644 --- a/datafusion/functions-aggregate/src/bool_and_or.rs +++ b/datafusion/functions-aggregate/src/bool_and_or.rs @@ -166,14 +166,14 @@ impl AggregateUDFImpl for BoolAnd { &self, args: AccumulatorArgs, ) -> Result> { - match args.return_type { + match args.return_field.data_type() { DataType::Boolean => { Ok(Box::new(BooleanGroupsAccumulator::new(|x, y| x && y, true))) } _ => not_impl_err!( "GroupsAccumulator not supported for {} with {}", args.name, - args.return_type + args.return_field.data_type() ), } } @@ -304,7 +304,7 @@ impl AggregateUDFImpl for BoolOr { &self, args: AccumulatorArgs, ) -> Result> { - match args.return_type { + match args.return_field.data_type() { DataType::Boolean => Ok(Box::new(BooleanGroupsAccumulator::new( |x, y| x || y, false, @@ -312,7 +312,7 @@ impl AggregateUDFImpl for BoolOr { _ => not_impl_err!( "GroupsAccumulator not supported for {} with {}", args.name, - args.return_type + args.return_field.data_type() ), } } diff --git a/datafusion/functions-aggregate/src/count.rs b/datafusion/functions-aggregate/src/count.rs index 2d995b4a4179..42078c735578 100644 --- a/datafusion/functions-aggregate/src/count.rs +++ b/datafusion/functions-aggregate/src/count.rs @@ -206,7 +206,7 @@ impl AggregateUDFImpl for Count { Ok(vec![Field::new_list( format_state_name(args.name, "count distinct"), // See COMMENTS.md to understand why nullable is set to true - Field::new_list_field(args.input_types[0].clone(), true), + Field::new_list_field(args.input_fields[0].data_type().clone(), true), false, )]) } else { diff --git a/datafusion/functions-aggregate/src/first_last.rs b/datafusion/functions-aggregate/src/first_last.rs index 530b4620809b..ea2ec63711b4 100644 --- a/datafusion/functions-aggregate/src/first_last.rs +++ b/datafusion/functions-aggregate/src/first_last.rs @@ -161,7 +161,7 @@ impl AggregateUDFImpl for FirstValue { acc_args.ordering_req.is_empty() || self.requirement_satisfied; FirstValueAccumulator::try_new( - acc_args.return_type, + acc_args.return_field.data_type(), &ordering_dtypes, acc_args.ordering_req.clone(), acc_args.ignore_nulls, @@ -172,7 +172,7 @@ impl AggregateUDFImpl for FirstValue { fn state_fields(&self, args: StateFieldsArgs) -> Result> { let mut fields = vec![Field::new( format_state_name(args.name, "first_value"), - args.return_type.clone(), + args.return_field.data_type().clone(), true, )]; fields.extend(args.ordering_fields.to_vec()); @@ -184,7 +184,7 @@ impl AggregateUDFImpl for FirstValue { // TODO: extract to function use DataType::*; matches!( - args.return_type, + args.return_field.data_type(), Int8 | Int16 | Int32 | Int64 @@ -225,13 +225,13 @@ impl AggregateUDFImpl for FirstValue { Ok(Box::new(FirstPrimitiveGroupsAccumulator::::try_new( args.ordering_req.clone(), args.ignore_nulls, - args.return_type, + args.return_field.data_type(), &ordering_dtypes, true, )?)) } - match args.return_type { + match args.return_field.data_type() { DataType::Int8 => create_accumulator::(args), DataType::Int16 => create_accumulator::(args), DataType::Int32 => create_accumulator::(args), @@ -279,7 +279,7 @@ impl AggregateUDFImpl for FirstValue { _ => { internal_err!( "GroupsAccumulator not supported for first_value({})", - args.return_type + args.return_field.data_type() ) } } @@ -1038,7 +1038,7 @@ impl AggregateUDFImpl for LastValue { acc_args.ordering_req.is_empty() || self.requirement_satisfied; LastValueAccumulator::try_new( - acc_args.return_type, + acc_args.return_field.data_type(), &ordering_dtypes, acc_args.ordering_req.clone(), acc_args.ignore_nulls, @@ -1049,14 +1049,14 @@ impl AggregateUDFImpl for LastValue { fn state_fields(&self, args: StateFieldsArgs) -> Result> { let StateFieldsArgs { name, - input_types, - return_type: _, + input_fields, + return_field: _, ordering_fields, is_distinct: _, } = args; let mut fields = vec![Field::new( format_state_name(name, "last_value"), - input_types[0].clone(), + input_fields[0].data_type().clone(), true, )]; fields.extend(ordering_fields.to_vec()); @@ -1092,7 +1092,7 @@ impl AggregateUDFImpl for LastValue { fn groups_accumulator_supported(&self, args: AccumulatorArgs) -> bool { use DataType::*; matches!( - args.return_type, + args.return_field.data_type(), Int8 | Int16 | Int32 | Int64 @@ -1132,13 +1132,13 @@ impl AggregateUDFImpl for LastValue { Ok(Box::new(FirstPrimitiveGroupsAccumulator::::try_new( args.ordering_req.clone(), args.ignore_nulls, - args.return_type, + args.return_field.data_type(), &ordering_dtypes, false, )?)) } - match args.return_type { + match args.return_field.data_type() { DataType::Int8 => create_accumulator::(args), DataType::Int16 => create_accumulator::(args), DataType::Int32 => create_accumulator::(args), @@ -1186,7 +1186,7 @@ impl AggregateUDFImpl for LastValue { _ => { internal_err!( "GroupsAccumulator not supported for last_value({})", - args.return_type + args.return_field.data_type() ) } } diff --git a/datafusion/functions-aggregate/src/median.rs b/datafusion/functions-aggregate/src/median.rs index ba6b63260e06..3d3f38503359 100644 --- a/datafusion/functions-aggregate/src/median.rs +++ b/datafusion/functions-aggregate/src/median.rs @@ -127,7 +127,7 @@ impl AggregateUDFImpl for Median { fn state_fields(&self, args: StateFieldsArgs) -> Result> { //Intermediate state is a list of the elements we have collected so far - let field = Field::new_list_field(args.input_types[0].clone(), true); + let field = Field::new_list_field(args.input_fields[0].data_type().clone(), true); let state_name = if args.is_distinct { "distinct_median" } else { diff --git a/datafusion/functions-aggregate/src/min_max.rs b/datafusion/functions-aggregate/src/min_max.rs index 2d51926a7bc6..f6b1589acaa0 100644 --- a/datafusion/functions-aggregate/src/min_max.rs +++ b/datafusion/functions-aggregate/src/min_max.rs @@ -234,7 +234,9 @@ impl AggregateUDFImpl for Max { } fn accumulator(&self, acc_args: AccumulatorArgs) -> Result> { - Ok(Box::new(MaxAccumulator::try_new(acc_args.return_type)?)) + Ok(Box::new(MaxAccumulator::try_new( + acc_args.return_field.data_type(), + )?)) } fn aliases(&self) -> &[String] { @@ -244,7 +246,7 @@ impl AggregateUDFImpl for Max { fn groups_accumulator_supported(&self, args: AccumulatorArgs) -> bool { use DataType::*; matches!( - args.return_type, + args.return_field.data_type(), Int8 | Int16 | Int32 | Int64 @@ -279,7 +281,7 @@ impl AggregateUDFImpl for Max { ) -> Result> { use DataType::*; use TimeUnit::*; - let data_type = args.return_type; + let data_type = args.return_field.data_type(); match data_type { Int8 => primitive_max_accumulator!(data_type, i8, Int8Type), Int16 => primitive_max_accumulator!(data_type, i16, Int16Type), @@ -357,7 +359,9 @@ impl AggregateUDFImpl for Max { &self, args: AccumulatorArgs, ) -> Result> { - Ok(Box::new(SlidingMaxAccumulator::try_new(args.return_type)?)) + Ok(Box::new(SlidingMaxAccumulator::try_new( + args.return_field.data_type(), + )?)) } fn is_descending(&self) -> Option { @@ -1179,7 +1183,9 @@ impl AggregateUDFImpl for Min { } fn accumulator(&self, acc_args: AccumulatorArgs) -> Result> { - Ok(Box::new(MinAccumulator::try_new(acc_args.return_type)?)) + Ok(Box::new(MinAccumulator::try_new( + acc_args.return_field.data_type(), + )?)) } fn aliases(&self) -> &[String] { @@ -1189,7 +1195,7 @@ impl AggregateUDFImpl for Min { fn groups_accumulator_supported(&self, args: AccumulatorArgs) -> bool { use DataType::*; matches!( - args.return_type, + args.return_field.data_type(), Int8 | Int16 | Int32 | Int64 @@ -1224,7 +1230,7 @@ impl AggregateUDFImpl for Min { ) -> Result> { use DataType::*; use TimeUnit::*; - let data_type = args.return_type; + let data_type = args.return_field.data_type(); match data_type { Int8 => primitive_min_accumulator!(data_type, i8, Int8Type), Int16 => primitive_min_accumulator!(data_type, i16, Int16Type), @@ -1302,7 +1308,9 @@ impl AggregateUDFImpl for Min { &self, args: AccumulatorArgs, ) -> Result> { - Ok(Box::new(SlidingMinAccumulator::try_new(args.return_type)?)) + Ok(Box::new(SlidingMinAccumulator::try_new( + args.return_field.data_type(), + )?)) } fn is_descending(&self) -> Option { diff --git a/datafusion/functions-aggregate/src/nth_value.rs b/datafusion/functions-aggregate/src/nth_value.rs index d84bd02a6baf..8a7c721dd472 100644 --- a/datafusion/functions-aggregate/src/nth_value.rs +++ b/datafusion/functions-aggregate/src/nth_value.rs @@ -168,7 +168,7 @@ impl AggregateUDFImpl for NthValueAgg { let mut fields = vec![Field::new_list( format_state_name(self.name(), "nth_value"), // See COMMENTS.md to understand why nullable is set to true - Field::new_list_field(args.input_types[0].clone(), true), + Field::new_list_field(args.input_fields[0].data_type().clone(), true), false, )]; let orderings = args.ordering_fields.to_vec(); diff --git a/datafusion/functions-aggregate/src/stddev.rs b/datafusion/functions-aggregate/src/stddev.rs index adf86a128cfb..5d3a6d5f70a7 100644 --- a/datafusion/functions-aggregate/src/stddev.rs +++ b/datafusion/functions-aggregate/src/stddev.rs @@ -436,7 +436,7 @@ mod tests { schema: &Schema, ) -> Result { let args1 = AccumulatorArgs { - return_type: &DataType::Float64, + return_field: &Field::new("f", DataType::Float64, true), schema, ignore_nulls: false, ordering_req: &LexOrdering::default(), @@ -447,7 +447,7 @@ mod tests { }; let args2 = AccumulatorArgs { - return_type: &DataType::Float64, + return_field: &Field::new("f", DataType::Float64, true), schema, ignore_nulls: false, ordering_req: &LexOrdering::default(), diff --git a/datafusion/functions-aggregate/src/string_agg.rs b/datafusion/functions-aggregate/src/string_agg.rs index a7594b9ccb01..d59f8a576e78 100644 --- a/datafusion/functions-aggregate/src/string_agg.rs +++ b/datafusion/functions-aggregate/src/string_agg.rs @@ -154,7 +154,11 @@ impl AggregateUDFImpl for StringAgg { }; let array_agg_acc = self.array_agg.accumulator(AccumulatorArgs { - return_type: &DataType::new_list(acc_args.return_type.clone(), true), + return_field: &Field::new( + "f", + DataType::new_list(acc_args.return_field.data_type().clone(), true), + true, + ), exprs: &filter_index(acc_args.exprs, 1), ..acc_args })?; @@ -436,7 +440,7 @@ mod tests { fn build(&self) -> Result> { StringAgg::new().accumulator(AccumulatorArgs { - return_type: &DataType::LargeUtf8, + return_field: &Field::new("f", DataType::LargeUtf8, true), schema: &self.schema, ignore_nulls: false, ordering_req: &self.ordering, diff --git a/datafusion/functions-aggregate/src/sum.rs b/datafusion/functions-aggregate/src/sum.rs index 76a1315c2d88..a54d0af34693 100644 --- a/datafusion/functions-aggregate/src/sum.rs +++ b/datafusion/functions-aggregate/src/sum.rs @@ -63,17 +63,27 @@ make_udaf_expr_and_func!( /// `helper` is a macro accepting (ArrowPrimitiveType, DataType) macro_rules! downcast_sum { ($args:ident, $helper:ident) => { - match $args.return_type { - DataType::UInt64 => $helper!(UInt64Type, $args.return_type), - DataType::Int64 => $helper!(Int64Type, $args.return_type), - DataType::Float64 => $helper!(Float64Type, $args.return_type), - DataType::Decimal128(_, _) => $helper!(Decimal128Type, $args.return_type), - DataType::Decimal256(_, _) => $helper!(Decimal256Type, $args.return_type), + match $args.return_field.data_type().clone() { + DataType::UInt64 => { + $helper!(UInt64Type, $args.return_field.data_type().clone()) + } + DataType::Int64 => { + $helper!(Int64Type, $args.return_field.data_type().clone()) + } + DataType::Float64 => { + $helper!(Float64Type, $args.return_field.data_type().clone()) + } + DataType::Decimal128(_, _) => { + $helper!(Decimal128Type, $args.return_field.data_type().clone()) + } + DataType::Decimal256(_, _) => { + $helper!(Decimal256Type, $args.return_field.data_type().clone()) + } _ => { not_impl_err!( "Sum not supported for {}: {}", $args.name, - $args.return_type + $args.return_field.data_type() ) } } @@ -196,13 +206,13 @@ impl AggregateUDFImpl for Sum { Ok(vec![Field::new_list( format_state_name(args.name, "sum distinct"), // See COMMENTS.md to understand why nullable is set to true - Field::new_list_field(args.return_type.clone(), true), + Field::new_list_field(args.return_field.data_type().clone(), true), false, )]) } else { Ok(vec![Field::new( format_state_name(args.name, "sum"), - args.return_type.clone(), + args.return_field.data_type().clone(), true, )]) } diff --git a/datafusion/functions-window-common/src/expr.rs b/datafusion/functions-window-common/src/expr.rs index 76e27b045b0a..eb2516a1e556 100644 --- a/datafusion/functions-window-common/src/expr.rs +++ b/datafusion/functions-window-common/src/expr.rs @@ -15,7 +15,7 @@ // specific language governing permissions and limitations // under the License. -use datafusion_common::arrow::datatypes::DataType; +use datafusion_common::arrow::datatypes::Field; use datafusion_physical_expr_common::physical_expr::PhysicalExpr; use std::sync::Arc; @@ -25,9 +25,9 @@ pub struct ExpressionArgs<'a> { /// The expressions passed as arguments to the user-defined window /// function. input_exprs: &'a [Arc], - /// The corresponding data types of expressions passed as arguments + /// The corresponding fields of expressions passed as arguments /// to the user-defined window function. - input_types: &'a [DataType], + input_fields: &'a [Field], } impl<'a> ExpressionArgs<'a> { @@ -42,11 +42,11 @@ impl<'a> ExpressionArgs<'a> { /// pub fn new( input_exprs: &'a [Arc], - input_types: &'a [DataType], + input_fields: &'a [Field], ) -> Self { Self { input_exprs, - input_types, + input_fields, } } @@ -56,9 +56,9 @@ impl<'a> ExpressionArgs<'a> { self.input_exprs } - /// Returns the [`DataType`]s corresponding to the input expressions + /// Returns the [`Field`]s corresponding to the input expressions /// to the user-defined window function. - pub fn input_types(&self) -> &'a [DataType] { - self.input_types + pub fn input_fields(&self) -> &'a [Field] { + self.input_fields } } diff --git a/datafusion/functions-window-common/src/field.rs b/datafusion/functions-window-common/src/field.rs index 03f88b0b95cc..9e1898908c95 100644 --- a/datafusion/functions-window-common/src/field.rs +++ b/datafusion/functions-window-common/src/field.rs @@ -15,14 +15,14 @@ // specific language governing permissions and limitations // under the License. -use datafusion_common::arrow::datatypes::DataType; +use datafusion_common::arrow::datatypes::Field; /// Metadata for defining the result field from evaluating a /// user-defined window function. pub struct WindowUDFFieldArgs<'a> { - /// The data types corresponding to the arguments to the + /// The fields corresponding to the arguments to the /// user-defined window function. - input_types: &'a [DataType], + input_fields: &'a [Field], /// The display name of the user-defined window function. display_name: &'a str, } @@ -32,22 +32,22 @@ impl<'a> WindowUDFFieldArgs<'a> { /// /// # Arguments /// - /// * `input_types` - The data types corresponding to the + /// * `input_fields` - The fields corresponding to the /// arguments to the user-defined window function. /// * `function_name` - The qualified schema name of the /// user-defined window function expression. /// - pub fn new(input_types: &'a [DataType], display_name: &'a str) -> Self { + pub fn new(input_fields: &'a [Field], display_name: &'a str) -> Self { WindowUDFFieldArgs { - input_types, + input_fields, display_name, } } - /// Returns the data type of input expressions passed as arguments + /// Returns the field of input expressions passed as arguments /// to the user-defined window function. - pub fn input_types(&self) -> &[DataType] { - self.input_types + pub fn input_fields(&self) -> &[Field] { + self.input_fields } /// Returns the name for the field of the final result of evaluating @@ -56,9 +56,9 @@ impl<'a> WindowUDFFieldArgs<'a> { self.display_name } - /// Returns `Some(DataType)` of input expression at index, otherwise + /// Returns `Some(Field)` of input expression at index, otherwise /// returns `None` if the index is out of bounds. - pub fn get_input_type(&self, index: usize) -> Option { - self.input_types.get(index).cloned() + pub fn get_input_field(&self, index: usize) -> Option { + self.input_fields.get(index).cloned() } } diff --git a/datafusion/functions-window-common/src/partition.rs b/datafusion/functions-window-common/src/partition.rs index e853aa8fb05d..64c28e61a2cd 100644 --- a/datafusion/functions-window-common/src/partition.rs +++ b/datafusion/functions-window-common/src/partition.rs @@ -15,7 +15,7 @@ // specific language governing permissions and limitations // under the License. -use datafusion_common::arrow::datatypes::DataType; +use datafusion_common::arrow::datatypes::Field; use datafusion_physical_expr_common::physical_expr::PhysicalExpr; use std::sync::Arc; @@ -26,9 +26,9 @@ pub struct PartitionEvaluatorArgs<'a> { /// The expressions passed as arguments to the user-defined window /// function. input_exprs: &'a [Arc], - /// The corresponding data types of expressions passed as arguments + /// The corresponding fields of expressions passed as arguments /// to the user-defined window function. - input_types: &'a [DataType], + input_fields: &'a [Field], /// Set to `true` if the user-defined window function is reversed. is_reversed: bool, /// Set to `true` if `IGNORE NULLS` is specified. @@ -51,13 +51,13 @@ impl<'a> PartitionEvaluatorArgs<'a> { /// pub fn new( input_exprs: &'a [Arc], - input_types: &'a [DataType], + input_fields: &'a [Field], is_reversed: bool, ignore_nulls: bool, ) -> Self { Self { input_exprs, - input_types, + input_fields, is_reversed, ignore_nulls, } @@ -69,10 +69,10 @@ impl<'a> PartitionEvaluatorArgs<'a> { self.input_exprs } - /// Returns the [`DataType`]s corresponding to the input expressions + /// Returns the [`Field`]s corresponding to the input expressions /// to the user-defined window function. - pub fn input_types(&self) -> &'a [DataType] { - self.input_types + pub fn input_fields(&self) -> &'a [Field] { + self.input_fields } /// Returns `true` when the user-defined window function is diff --git a/datafusion/functions-window/src/lead_lag.rs b/datafusion/functions-window/src/lead_lag.rs index 84628a77a26c..6ebbceaced5e 100644 --- a/datafusion/functions-window/src/lead_lag.rs +++ b/datafusion/functions-window/src/lead_lag.rs @@ -240,7 +240,7 @@ impl WindowUDFImpl for WindowShift { /// /// For more details see: fn expressions(&self, expr_args: ExpressionArgs) -> Vec> { - parse_expr(expr_args.input_exprs(), expr_args.input_types()) + parse_expr(expr_args.input_exprs(), expr_args.input_fields()) .into_iter() .collect::>() } @@ -263,7 +263,7 @@ impl WindowUDFImpl for WindowShift { })?; let default_value = parse_default_value( partition_evaluator_args.input_exprs(), - partition_evaluator_args.input_types(), + partition_evaluator_args.input_fields(), )?; Ok(Box::new(WindowShiftEvaluator { @@ -275,9 +275,9 @@ impl WindowUDFImpl for WindowShift { } fn field(&self, field_args: WindowUDFFieldArgs) -> Result { - let return_type = parse_expr_type(field_args.input_types())?; + let return_field = parse_expr_field(field_args.input_fields())?; - Ok(Field::new(field_args.name(), return_type, true)) + Ok(return_field.with_name(field_args.name())) } fn reverse_expr(&self) -> ReversedUDWF { @@ -309,16 +309,16 @@ impl WindowUDFImpl for WindowShift { /// For more details see: fn parse_expr( input_exprs: &[Arc], - input_types: &[DataType], + input_fields: &[Field], ) -> Result> { assert!(!input_exprs.is_empty()); - assert!(!input_types.is_empty()); + assert!(!input_fields.is_empty()); let expr = Arc::clone(input_exprs.first().unwrap()); - let expr_type = input_types.first().unwrap(); + let expr_field = input_fields.first().unwrap(); // Handles the most common case where NULL is unexpected - if !expr_type.is_null() { + if !expr_field.data_type().is_null() { return Ok(expr); } @@ -331,36 +331,39 @@ fn parse_expr( }) } -/// Returns the data type of the default value(if provided) when the +static NULL_FIELD: LazyLock = + LazyLock::new(|| Field::new("value", DataType::Null, true)); + +/// Returns the field of the default value(if provided) when the /// expression is `NULL`. /// -/// Otherwise, returns the expression type unchanged. -fn parse_expr_type(input_types: &[DataType]) -> Result { - assert!(!input_types.is_empty()); - let expr_type = input_types.first().unwrap_or(&DataType::Null); +/// Otherwise, returns the expression field unchanged. +fn parse_expr_field(input_fields: &[Field]) -> Result { + assert!(!input_fields.is_empty()); + let expr_field = input_fields.first().unwrap_or(&NULL_FIELD); // Handles the most common case where NULL is unexpected - if !expr_type.is_null() { - return Ok(expr_type.clone()); + if !expr_field.data_type().is_null() { + return Ok(expr_field.clone().with_nullable(true)); } - let default_value_type = input_types.get(2).unwrap_or(&DataType::Null); - Ok(default_value_type.clone()) + let default_value_field = input_fields.get(2).unwrap_or(&NULL_FIELD); + Ok(default_value_field.clone().with_nullable(true)) } /// Handles type coercion and null value refinement for default value /// argument depending on the data type of the input expression. fn parse_default_value( input_exprs: &[Arc], - input_types: &[DataType], + input_types: &[Field], ) -> Result { - let expr_type = parse_expr_type(input_types)?; + let expr_field = parse_expr_field(input_types)?; let unparsed = get_scalar_value_from_args(input_exprs, 2)?; unparsed .filter(|v| !v.data_type().is_null()) - .map(|v| v.cast_to(&expr_type)) - .unwrap_or_else(|| ScalarValue::try_from(expr_type)) + .map(|v| v.cast_to(expr_field.data_type())) + .unwrap_or_else(|| ScalarValue::try_from(expr_field.data_type())) } #[derive(Debug)] @@ -705,7 +708,12 @@ mod tests { test_i32_result( WindowShift::lead(), - PartitionEvaluatorArgs::new(&[expr], &[DataType::Int32], false, false), + PartitionEvaluatorArgs::new( + &[expr], + &[Field::new("f", DataType::Int32, true)], + false, + false, + ), [ Some(-2), Some(3), @@ -727,7 +735,12 @@ mod tests { test_i32_result( WindowShift::lag(), - PartitionEvaluatorArgs::new(&[expr], &[DataType::Int32], false, false), + PartitionEvaluatorArgs::new( + &[expr], + &[Field::new("f", DataType::Int32, true)], + false, + false, + ), [ None, Some(1), @@ -752,12 +765,14 @@ mod tests { as Arc; let input_exprs = &[expr, shift_offset, default_value]; - let input_types: &[DataType] = - &[DataType::Int32, DataType::Int32, DataType::Int32]; + let input_fields = [DataType::Int32, DataType::Int32, DataType::Int32] + .into_iter() + .map(|d| Field::new("f", d, true)) + .collect::>(); test_i32_result( WindowShift::lag(), - PartitionEvaluatorArgs::new(input_exprs, input_types, false, false), + PartitionEvaluatorArgs::new(input_exprs, &input_fields, false, false), [ Some(100), Some(1), diff --git a/datafusion/functions-window/src/macros.rs b/datafusion/functions-window/src/macros.rs index 2ef1eacba953..27799140931e 100644 --- a/datafusion/functions-window/src/macros.rs +++ b/datafusion/functions-window/src/macros.rs @@ -286,7 +286,7 @@ macro_rules! get_or_init_udwf { /// # fn field(&self, field_args: WindowUDFFieldArgs) -> datafusion_common::Result { /// # Ok(Field::new( /// # field_args.name(), -/// # field_args.get_input_type(0).unwrap(), +/// # field_args.get_input_field(0).unwrap().data_type().clone(), /// # false, /// # )) /// # } @@ -557,7 +557,7 @@ macro_rules! create_udwf_expr { /// # fn field(&self, field_args: WindowUDFFieldArgs) -> datafusion_common::Result { /// # Ok(Field::new( /// # field_args.name(), -/// # field_args.get_input_type(0).unwrap(), +/// # field_args.get_input_field(0).unwrap().data_type().clone(), /// # false, /// # )) /// # } @@ -646,7 +646,7 @@ macro_rules! create_udwf_expr { /// # fn field(&self, field_args: WindowUDFFieldArgs) -> datafusion_common::Result { /// # Ok(Field::new( /// # field_args.name(), -/// # field_args.get_input_type(0).unwrap(), +/// # field_args.get_input_field(0).unwrap().data_type().clone(), /// # false, /// # )) /// # } diff --git a/datafusion/functions-window/src/nth_value.rs b/datafusion/functions-window/src/nth_value.rs index 45c2ef243ab0..b2ecc87f4be8 100644 --- a/datafusion/functions-window/src/nth_value.rs +++ b/datafusion/functions-window/src/nth_value.rs @@ -310,10 +310,14 @@ impl WindowUDFImpl for NthValue { } fn field(&self, field_args: WindowUDFFieldArgs) -> Result { - let nullable = true; - let return_type = field_args.input_types().first().unwrap_or(&DataType::Null); - - Ok(Field::new(field_args.name(), return_type.clone(), nullable)) + let return_type = field_args + .input_fields() + .first() + .map(|f| f.data_type()) + .cloned() + .unwrap_or(DataType::Null); + + Ok(Field::new(field_args.name(), return_type, true)) } fn reverse_expr(&self) -> ReversedUDWF { @@ -551,7 +555,12 @@ mod tests { let expr = Arc::new(Column::new("c3", 0)) as Arc; test_i32_result( NthValue::first(), - PartitionEvaluatorArgs::new(&[expr], &[DataType::Int32], false, false), + PartitionEvaluatorArgs::new( + &[expr], + &[Field::new("f", DataType::Int32, true)], + false, + false, + ), Int32Array::from(vec![1; 8]).iter().collect::(), ) } @@ -561,7 +570,12 @@ mod tests { let expr = Arc::new(Column::new("c3", 0)) as Arc; test_i32_result( NthValue::last(), - PartitionEvaluatorArgs::new(&[expr], &[DataType::Int32], false, false), + PartitionEvaluatorArgs::new( + &[expr], + &[Field::new("f", DataType::Int32, true)], + false, + false, + ), Int32Array::from(vec![ Some(1), Some(-2), @@ -585,7 +599,7 @@ mod tests { NthValue::nth(), PartitionEvaluatorArgs::new( &[expr, n_value], - &[DataType::Int32], + &[Field::new("f", DataType::Int32, true)], false, false, ), @@ -604,7 +618,7 @@ mod tests { NthValue::nth(), PartitionEvaluatorArgs::new( &[expr, n_value], - &[DataType::Int32], + &[Field::new("f", DataType::Int32, true)], false, false, ), diff --git a/datafusion/optimizer/src/analyzer/type_coercion.rs b/datafusion/optimizer/src/analyzer/type_coercion.rs index 1c729fcbc2ca..e5779646b921 100644 --- a/datafusion/optimizer/src/analyzer/type_coercion.rs +++ b/datafusion/optimizer/src/analyzer/type_coercion.rs @@ -41,7 +41,7 @@ use datafusion_expr::expr_schema::cast_subquery; use datafusion_expr::logical_plan::Subquery; use datafusion_expr::type_coercion::binary::{comparison_coercion, like_coercion}; use datafusion_expr::type_coercion::functions::{ - data_types_with_aggregate_udf, data_types_with_scalar_udf, + data_types_with_scalar_udf, fields_with_aggregate_udf, }; use datafusion_expr::type_coercion::other::{ get_coerce_type_for_case_expression, get_coerce_type_for_list, @@ -809,12 +809,15 @@ fn coerce_arguments_for_signature_with_aggregate_udf( return Ok(expressions); } - let current_types = expressions + let current_fields = expressions .iter() - .map(|e| e.get_type(schema)) + .map(|e| e.to_field(schema).map(|(_, f)| f.as_ref().clone())) .collect::>>()?; - let new_types = data_types_with_aggregate_udf(¤t_types, func)?; + let new_types = fields_with_aggregate_udf(¤t_fields, func)? + .into_iter() + .map(|f| f.data_type().clone()) + .collect::>(); expressions .into_iter() diff --git a/datafusion/physical-expr/src/aggregate.rs b/datafusion/physical-expr/src/aggregate.rs index 49912954ac81..867b4e0fc955 100644 --- a/datafusion/physical-expr/src/aggregate.rs +++ b/datafusion/physical-expr/src/aggregate.rs @@ -213,18 +213,18 @@ impl AggregateExprBuilder { utils::ordering_fields(ordering_req.as_ref(), &ordering_types); } - let input_exprs_types = args + let input_exprs_fields = args .iter() - .map(|arg| arg.data_type(&schema)) + .map(|arg| arg.return_field(&schema)) .collect::>>()?; check_arg_count( fun.name(), - &input_exprs_types, + &input_exprs_fields, &fun.signature().type_signature, )?; - let data_type = fun.return_type(&input_exprs_types)?; + let return_field = fun.return_field(&input_exprs_fields)?; let is_nullable = fun.is_nullable(); let name = match alias { None => { @@ -238,7 +238,7 @@ impl AggregateExprBuilder { Ok(AggregateFunctionExpr { fun: Arc::unwrap_or_clone(fun), args, - data_type, + return_field, name, human_display, schema: Arc::unwrap_or_clone(schema), @@ -246,7 +246,7 @@ impl AggregateExprBuilder { ignore_nulls, ordering_fields, is_distinct, - input_types: input_exprs_types, + input_fields: input_exprs_fields, is_reversed, is_nullable, }) @@ -310,8 +310,8 @@ impl AggregateExprBuilder { pub struct AggregateFunctionExpr { fun: AggregateUDF, args: Vec>, - /// Output / return type of this aggregate - data_type: DataType, + /// Output / return field of this aggregate + return_field: Field, /// Output column name that this expression creates name: String, /// Simplified name for `tree` explain. @@ -325,7 +325,7 @@ pub struct AggregateFunctionExpr { ordering_fields: Vec, is_distinct: bool, is_reversed: bool, - input_types: Vec, + input_fields: Vec, is_nullable: bool, } @@ -373,7 +373,7 @@ impl AggregateFunctionExpr { /// the field of the final result of this aggregation. pub fn field(&self) -> Field { - Field::new(&self.name, self.data_type.clone(), self.is_nullable) + self.return_field.clone().with_name(&self.name) } /// the accumulator used to accumulate values from the expressions. @@ -381,7 +381,7 @@ impl AggregateFunctionExpr { /// return states with the same description as `state_fields` pub fn create_accumulator(&self) -> Result> { let acc_args = AccumulatorArgs { - return_type: &self.data_type, + return_field: &self.return_field, schema: &self.schema, ignore_nulls: self.ignore_nulls, ordering_req: self.ordering_req.as_ref(), @@ -398,8 +398,8 @@ impl AggregateFunctionExpr { pub fn state_fields(&self) -> Result> { let args = StateFieldsArgs { name: &self.name, - input_types: &self.input_types, - return_type: &self.data_type, + input_fields: &self.input_fields, + return_field: &self.return_field, ordering_fields: &self.ordering_fields, is_distinct: self.is_distinct, }; @@ -472,7 +472,7 @@ impl AggregateFunctionExpr { /// Creates accumulator implementation that supports retract pub fn create_sliding_accumulator(&self) -> Result> { let args = AccumulatorArgs { - return_type: &self.data_type, + return_field: &self.return_field, schema: &self.schema, ignore_nulls: self.ignore_nulls, ordering_req: self.ordering_req.as_ref(), @@ -541,7 +541,7 @@ impl AggregateFunctionExpr { /// `[Self::create_groups_accumulator`] will be called. pub fn groups_accumulator_supported(&self) -> bool { let args = AccumulatorArgs { - return_type: &self.data_type, + return_field: &self.return_field, schema: &self.schema, ignore_nulls: self.ignore_nulls, ordering_req: self.ordering_req.as_ref(), @@ -560,7 +560,7 @@ impl AggregateFunctionExpr { /// implemented in addition to [`Accumulator`]. pub fn create_groups_accumulator(&self) -> Result> { let args = AccumulatorArgs { - return_type: &self.data_type, + return_field: &self.return_field, schema: &self.schema, ignore_nulls: self.ignore_nulls, ordering_req: self.ordering_req.as_ref(), @@ -685,7 +685,7 @@ pub struct AggregatePhysicalExpressions { impl PartialEq for AggregateFunctionExpr { fn eq(&self, other: &Self) -> bool { self.name == other.name - && self.data_type == other.data_type + && self.return_field == other.return_field && self.fun == other.fun && self.args.len() == other.args.len() && self diff --git a/datafusion/physical-plan/src/windows/mod.rs b/datafusion/physical-plan/src/windows/mod.rs index d38bf2a186a8..f773391a6a70 100644 --- a/datafusion/physical-plan/src/windows/mod.rs +++ b/datafusion/physical-plan/src/windows/mod.rs @@ -30,7 +30,7 @@ use crate::{ InputOrderMode, PhysicalExpr, }; -use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; +use arrow::datatypes::{Field, Schema, SchemaRef}; use arrow_schema::SortOptions; use datafusion_common::{exec_err, Result}; use datafusion_expr::{ @@ -65,16 +65,16 @@ pub fn schema_add_window_field( window_fn: &WindowFunctionDefinition, fn_name: &str, ) -> Result> { - let data_types = args + let fields = args .iter() - .map(|e| Arc::clone(e).as_ref().data_type(schema)) + .map(|e| Arc::clone(e).as_ref().return_field(schema)) .collect::>>()?; let nullability = args .iter() .map(|e| Arc::clone(e).as_ref().nullable(schema)) .collect::>>()?; - let window_expr_return_type = - window_fn.return_type(&data_types, &nullability, fn_name)?; + let window_expr_return_field = + window_fn.return_field(&fields, &nullability, fn_name)?; let mut window_fields = schema .fields() .iter() @@ -84,11 +84,7 @@ pub fn schema_add_window_field( if let WindowFunctionDefinition::AggregateUDF(_) = window_fn { Ok(Arc::new(Schema::new(window_fields))) } else { - window_fields.extend_from_slice(&[Field::new( - fn_name, - window_expr_return_type, - false, - )]); + window_fields.extend_from_slice(&[window_expr_return_field.with_name(fn_name)]); Ok(Arc::new(Schema::new(window_fields))) } } @@ -165,15 +161,15 @@ pub fn create_udwf_window_expr( ignore_nulls: bool, ) -> Result> { // need to get the types into an owned vec for some reason - let input_types: Vec<_> = args + let input_fields: Vec<_> = args .iter() - .map(|arg| arg.data_type(input_schema)) + .map(|arg| arg.return_field(input_schema)) .collect::>()?; let udwf_expr = Arc::new(WindowUDFExpr { fun: Arc::clone(fun), args: args.to_vec(), - input_types, + input_fields, name, is_reversed: false, ignore_nulls, @@ -202,8 +198,8 @@ pub struct WindowUDFExpr { args: Vec>, /// Display name name: String, - /// Types of input expressions - input_types: Vec, + /// Fields of input expressions + input_fields: Vec, /// This is set to `true` only if the user-defined window function /// expression supports evaluation in reverse order, and the /// evaluation order is reversed. @@ -225,19 +221,19 @@ impl StandardWindowFunctionExpr for WindowUDFExpr { fn field(&self) -> Result { self.fun - .field(WindowUDFFieldArgs::new(&self.input_types, &self.name)) + .field(WindowUDFFieldArgs::new(&self.input_fields, &self.name)) } fn expressions(&self) -> Vec> { self.fun - .expressions(ExpressionArgs::new(&self.args, &self.input_types)) + .expressions(ExpressionArgs::new(&self.args, &self.input_fields)) } fn create_evaluator(&self) -> Result> { self.fun .partition_evaluator_factory(PartitionEvaluatorArgs::new( &self.args, - &self.input_types, + &self.input_fields, self.is_reversed, self.ignore_nulls, )) @@ -255,7 +251,7 @@ impl StandardWindowFunctionExpr for WindowUDFExpr { fun, args: self.args.clone(), name: self.name.clone(), - input_types: self.input_types.clone(), + input_fields: self.input_fields.clone(), is_reversed: !self.is_reversed, ignore_nulls: self.ignore_nulls, })), @@ -641,6 +637,7 @@ mod tests { use crate::test::exec::{assert_strong_count_converges_to_zero, BlockingExec}; use arrow::compute::SortOptions; + use arrow_schema::DataType; use datafusion_execution::TaskContext; use datafusion_functions_aggregate::count::count_udaf; diff --git a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs index ac3b14aeafc5..369700bded04 100644 --- a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs +++ b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs @@ -2517,13 +2517,13 @@ fn roundtrip_window() { } fn field(&self, field_args: WindowUDFFieldArgs) -> Result { - if let Some(return_type) = field_args.get_input_type(0) { - Ok(Field::new(field_args.name(), return_type, true)) + if let Some(return_field) = field_args.get_input_field(0) { + Ok(return_field.with_name(field_args.name())) } else { plan_err!( "dummy_udwf expects 1 argument, got {}: {:?}", - field_args.input_types().len(), - field_args.input_types() + field_args.input_fields().len(), + field_args.input_fields() ) } } diff --git a/docs/source/library-user-guide/upgrading.md b/docs/source/library-user-guide/upgrading.md index 0470bdbb27be..620be657b44b 100644 --- a/docs/source/library-user-guide/upgrading.md +++ b/docs/source/library-user-guide/upgrading.md @@ -38,6 +38,21 @@ has been removed, so you will need to remove all references to it. `ScalarFunctionArgs` now contains a field called `arg_fields`. You can use this to access the metadata associated with the columnar values during invocation. +To upgrade user defined aggregate functions, there is now a function +`return_field` that will allow you to specify both metadata and nullability of +your function. You are not required to implement this if you do not need to +handle metatdata. + +The largest change to aggregate functions happens in the accumulator arguments. +Both the `AccumulatorArgs` and `StateFieldsArgs` now contain `Field` rather +than `DataType`. + +To upgrade window functions, `ExpressionArgs` now contains input fields instead +of input data types. When setting these fields, the name of the field is +not important since this gets overwritten during the planning stage. All you +should need to do is wrap your existing data types in fields with nullability +set depending on your use case. + ### Physical Expression return `Field` To support the changes to user defined functions processing metadata, the