Skip to content

Commit 742e3c5

Browse files
lewiszlw and alamb authored
Remove ScalarFunctionDefinition (#10325)
* Remove ScalarFunctionDefinition
* Fix test
* rename func_def to func

---------

Co-authored-by: Andrew Lamb <[email protected]>
1 parent 5146f44 commit 742e3c5

File tree

22 files changed

+124
-231
lines changed

22 files changed

+124
-231
lines changed

datafusion/core/src/datasource/listing/helpers.rs

Lines changed: 7 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ use log::{debug, trace};
3838

3939
use datafusion_common::tree_node::{TreeNode, TreeNodeRecursion};
4040
use datafusion_common::{Column, DFSchema, DataFusionError};
41-
use datafusion_expr::{Expr, ScalarFunctionDefinition, Volatility};
41+
use datafusion_expr::{Expr, Volatility};
4242
use datafusion_physical_expr::create_physical_expr;
4343
use object_store::path::Path;
4444
use object_store::{ObjectMeta, ObjectStore};
@@ -89,16 +89,12 @@ pub fn expr_applicable_for_cols(col_names: &[String], expr: &Expr) -> bool {
8989
| Expr::Case { .. } => Ok(TreeNodeRecursion::Continue),
9090

9191
Expr::ScalarFunction(scalar_function) => {
92-
match &scalar_function.func_def {
93-
ScalarFunctionDefinition::UDF(fun) => {
94-
match fun.signature().volatility {
95-
Volatility::Immutable => Ok(TreeNodeRecursion::Continue),
96-
// TODO: Stable functions could be `applicable`, but that would require access to the context
97-
Volatility::Stable | Volatility::Volatile => {
98-
is_applicable = false;
99-
Ok(TreeNodeRecursion::Stop)
100-
}
101-
}
92+
match scalar_function.func.signature().volatility {
93+
Volatility::Immutable => Ok(TreeNodeRecursion::Continue),
94+
// TODO: Stable functions could be `applicable`, but that would require access to the context
95+
Volatility::Stable | Volatility::Volatile => {
96+
is_applicable = false;
97+
Ok(TreeNodeRecursion::Stop)
10298
}
10399
}
104100
}

datafusion/core/src/physical_optimizer/projection_pushdown.rs

Lines changed: 5 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1301,8 +1301,7 @@ mod tests {
13011301
use datafusion_execution::object_store::ObjectStoreUrl;
13021302
use datafusion_execution::{SendableRecordBatchStream, TaskContext};
13031303
use datafusion_expr::{
1304-
ColumnarValue, Operator, ScalarFunctionDefinition, ScalarUDF, ScalarUDFImpl,
1305-
Signature, Volatility,
1304+
ColumnarValue, Operator, ScalarUDF, ScalarUDFImpl, Signature, Volatility,
13061305
};
13071306
use datafusion_physical_expr::expressions::{
13081307
BinaryExpr, CaseExpr, CastExpr, NegativeExpr,
@@ -1363,9 +1362,7 @@ mod tests {
13631362
Arc::new(NegativeExpr::new(Arc::new(Column::new("f", 4)))),
13641363
Arc::new(ScalarFunctionExpr::new(
13651364
"scalar_expr",
1366-
ScalarFunctionDefinition::UDF(Arc::new(ScalarUDF::new_from_impl(
1367-
DummyUDF::new(),
1368-
))),
1365+
Arc::new(ScalarUDF::new_from_impl(DummyUDF::new())),
13691366
vec![
13701367
Arc::new(BinaryExpr::new(
13711368
Arc::new(Column::new("b", 1)),
@@ -1431,9 +1428,7 @@ mod tests {
14311428
Arc::new(NegativeExpr::new(Arc::new(Column::new("f", 5)))),
14321429
Arc::new(ScalarFunctionExpr::new(
14331430
"scalar_expr",
1434-
ScalarFunctionDefinition::UDF(Arc::new(ScalarUDF::new_from_impl(
1435-
DummyUDF::new(),
1436-
))),
1431+
Arc::new(ScalarUDF::new_from_impl(DummyUDF::new())),
14371432
vec![
14381433
Arc::new(BinaryExpr::new(
14391434
Arc::new(Column::new("b", 1)),
@@ -1502,9 +1497,7 @@ mod tests {
15021497
Arc::new(NegativeExpr::new(Arc::new(Column::new("f", 4)))),
15031498
Arc::new(ScalarFunctionExpr::new(
15041499
"scalar_expr",
1505-
ScalarFunctionDefinition::UDF(Arc::new(ScalarUDF::new_from_impl(
1506-
DummyUDF::new(),
1507-
))),
1500+
Arc::new(ScalarUDF::new_from_impl(DummyUDF::new())),
15081501
vec![
15091502
Arc::new(BinaryExpr::new(
15101503
Arc::new(Column::new("b", 1)),
@@ -1570,9 +1563,7 @@ mod tests {
15701563
Arc::new(NegativeExpr::new(Arc::new(Column::new("f_new", 5)))),
15711564
Arc::new(ScalarFunctionExpr::new(
15721565
"scalar_expr",
1573-
ScalarFunctionDefinition::UDF(Arc::new(ScalarUDF::new_from_impl(
1574-
DummyUDF::new(),
1575-
))),
1566+
Arc::new(ScalarUDF::new_from_impl(DummyUDF::new())),
15761567
vec![
15771568
Arc::new(BinaryExpr::new(
15781569
Arc::new(Column::new("b_new", 1)),

datafusion/expr/src/expr.rs

Lines changed: 9 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -26,11 +26,11 @@ use std::sync::Arc;
2626
use crate::expr_fn::binary_expr;
2727
use crate::logical_plan::Subquery;
2828
use crate::utils::expr_to_columns;
29-
use crate::window_frame;
3029
use crate::{
3130
aggregate_function, built_in_window_function, udaf, ExprSchemable, Operator,
3231
Signature,
3332
};
33+
use crate::{window_frame, Volatility};
3434

3535
use arrow::datatypes::{DataType, FieldRef};
3636
use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode};
@@ -399,60 +399,26 @@ impl Between {
399399
}
400400
}
401401

402-
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
403-
/// Defines which implementation of a function for DataFusion to call.
404-
pub enum ScalarFunctionDefinition {
405-
/// Resolved to a user defined function
406-
UDF(Arc<crate::ScalarUDF>),
407-
}
408-
409402
/// ScalarFunction expression invokes a built-in scalar function
410403
#[derive(Clone, PartialEq, Eq, Hash, Debug)]
411404
pub struct ScalarFunction {
412405
/// The function
413-
pub func_def: ScalarFunctionDefinition,
406+
pub func: Arc<crate::ScalarUDF>,
414407
/// List of expressions to feed to the functions as arguments
415408
pub args: Vec<Expr>,
416409
}
417410

418411
impl ScalarFunction {
419412
// return the Function's name
420413
pub fn name(&self) -> &str {
421-
self.func_def.name()
422-
}
423-
}
424-
425-
impl ScalarFunctionDefinition {
426-
/// Function's name for display
427-
pub fn name(&self) -> &str {
428-
match self {
429-
ScalarFunctionDefinition::UDF(udf) => udf.name(),
430-
}
431-
}
432-
433-
/// Whether this function is volatile, i.e. whether it can return different results
434-
/// when evaluated multiple times with the same input.
435-
pub fn is_volatile(&self) -> Result<bool> {
436-
match self {
437-
ScalarFunctionDefinition::UDF(udf) => {
438-
Ok(udf.signature().volatility == crate::Volatility::Volatile)
439-
}
440-
}
414+
self.func.name()
441415
}
442416
}
443417

444418
impl ScalarFunction {
445419
/// Create a new ScalarFunction expression with a user-defined function (UDF)
446420
pub fn new_udf(udf: Arc<crate::ScalarUDF>, args: Vec<Expr>) -> Self {
447-
Self {
448-
func_def: ScalarFunctionDefinition::UDF(udf),
449-
args,
450-
}
451-
}
452-
453-
/// Create a new ScalarFunction expression with a user-defined function (UDF)
454-
pub fn new_func_def(func_def: ScalarFunctionDefinition, args: Vec<Expr>) -> Self {
455-
Self { func_def, args }
421+
Self { func: udf, args }
456422
}
457423
}
458424

@@ -1299,7 +1265,7 @@ impl Expr {
12991265
/// results when evaluated multiple times with the same input.
13001266
pub fn is_volatile(&self) -> Result<bool> {
13011267
self.exists(|expr| {
1302-
Ok(matches!(expr, Expr::ScalarFunction(func) if func.func_def.is_volatile()?))
1268+
Ok(matches!(expr, Expr::ScalarFunction(func) if func.func.signature().volatility == Volatility::Volatile ))
13031269
})
13041270
}
13051271

@@ -1334,9 +1300,7 @@ impl Expr {
13341300
/// and thus any side effects (like divide by zero) may not be encountered
13351301
pub fn short_circuits(&self) -> bool {
13361302
match self {
1337-
Expr::ScalarFunction(ScalarFunction { func_def, .. }) => {
1338-
matches!(func_def, ScalarFunctionDefinition::UDF(fun) if fun.short_circuits())
1339-
}
1303+
Expr::ScalarFunction(ScalarFunction { func, .. }) => func.short_circuits(),
13401304
Expr::BinaryExpr(BinaryExpr { op, .. }) => {
13411305
matches!(op, Operator::And | Operator::Or)
13421306
}
@@ -2071,7 +2035,7 @@ mod test {
20712035
}
20722036

20732037
#[test]
2074-
fn test_is_volatile_scalar_func_definition() {
2038+
fn test_is_volatile_scalar_func() {
20752039
// UDF
20762040
#[derive(Debug)]
20772041
struct TestScalarUDF {
@@ -2100,7 +2064,7 @@ mod test {
21002064
let udf = Arc::new(ScalarUDF::from(TestScalarUDF {
21012065
signature: Signature::uniform(1, vec![DataType::Float32], Volatility::Stable),
21022066
}));
2103-
assert!(!ScalarFunctionDefinition::UDF(udf).is_volatile().unwrap());
2067+
assert_ne!(udf.signature().volatility, Volatility::Volatile);
21042068

21052069
let udf = Arc::new(ScalarUDF::from(TestScalarUDF {
21062070
signature: Signature::uniform(
@@ -2109,7 +2073,7 @@ mod test {
21092073
Volatility::Volatile,
21102074
),
21112075
}));
2112-
assert!(ScalarFunctionDefinition::UDF(udf).is_volatile().unwrap());
2076+
assert_eq!(udf.signature().volatility, Volatility::Volatile);
21132077
}
21142078

21152079
use super::*;

datafusion/expr/src/expr_schema.rs

Lines changed: 6 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ use super::{Between, Expr, Like};
1919
use crate::expr::{
2020
AggregateFunction, AggregateFunctionDefinition, Alias, BinaryExpr, Cast,
2121
GetFieldAccess, GetIndexedField, InList, InSubquery, Placeholder, ScalarFunction,
22-
ScalarFunctionDefinition, Sort, TryCast, Unnest, WindowFunction,
22+
Sort, TryCast, Unnest, WindowFunction,
2323
};
2424
use crate::field_util::GetFieldAccessSchema;
2525
use crate::type_coercion::binary::get_result_type;
@@ -133,30 +133,26 @@ impl ExprSchemable for Expr {
133133
}
134134
}
135135
}
136-
Expr::ScalarFunction(ScalarFunction { func_def, args }) => {
136+
Expr::ScalarFunction(ScalarFunction { func, args }) => {
137137
let arg_data_types = args
138138
.iter()
139139
.map(|e| e.get_type(schema))
140140
.collect::<Result<Vec<_>>>()?;
141-
match func_def {
142-
ScalarFunctionDefinition::UDF(fun) => {
143141
// verify that function is invoked with correct number and type of arguments as defined in `TypeSignature`
144-
data_types(&arg_data_types, fun.signature()).map_err(|_| {
142+
data_types(&arg_data_types, func.signature()).map_err(|_| {
145143
plan_datafusion_err!(
146144
"{}",
147145
utils::generate_signature_error_msg(
148-
fun.name(),
149-
fun.signature().clone(),
146+
func.name(),
147+
func.signature().clone(),
150148
&arg_data_types,
151149
)
152150
)
153151
})?;
154152

155153
// perform additional function arguments validation (due to limited
156154
// expressiveness of `TypeSignature`), then infer return type
157-
Ok(fun.return_type_from_exprs(args, schema, &arg_data_types)?)
158-
}
159-
}
155+
Ok(func.return_type_from_exprs(args, schema, &arg_data_types)?)
160156
}
161157
Expr::WindowFunction(WindowFunction { fun, args, .. }) => {
162158
let data_types = args

datafusion/expr/src/lib.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ pub use built_in_window_function::BuiltInWindowFunction;
6363
pub use columnar_value::ColumnarValue;
6464
pub use expr::{
6565
Between, BinaryExpr, Case, Cast, Expr, GetFieldAccess, GetIndexedField, GroupingSet,
66-
Like, ScalarFunctionDefinition, TryCast, WindowFunctionDefinition,
66+
Like, TryCast, WindowFunctionDefinition,
6767
};
6868
pub use expr_fn::*;
6969
pub use expr_schema::ExprSchemable;

datafusion/expr/src/tree_node.rs

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
use crate::expr::{
2121
AggregateFunction, AggregateFunctionDefinition, Alias, Between, BinaryExpr, Case,
2222
Cast, GetIndexedField, GroupingSet, InList, InSubquery, Like, Placeholder,
23-
ScalarFunction, ScalarFunctionDefinition, Sort, TryCast, Unnest, WindowFunction,
23+
ScalarFunction, Sort, TryCast, Unnest, WindowFunction,
2424
};
2525
use crate::{Expr, GetFieldAccess};
2626

@@ -281,11 +281,11 @@ impl TreeNode for Expr {
281281
nulls_first,
282282
}) => transform_box(expr, &mut f)?
283283
.update_data(|be| Expr::Sort(Sort::new(be, asc, nulls_first))),
284-
Expr::ScalarFunction(ScalarFunction { func_def, args }) => {
285-
transform_vec(args, &mut f)?.map_data(|new_args| match func_def {
286-
ScalarFunctionDefinition::UDF(fun) => {
287-
Ok(Expr::ScalarFunction(ScalarFunction::new_udf(fun, new_args)))
288-
}
284+
Expr::ScalarFunction(ScalarFunction { func, args }) => {
285+
transform_vec(args, &mut f)?.map_data(|new_args| {
286+
Ok(Expr::ScalarFunction(ScalarFunction::new_udf(
287+
func, new_args,
288+
)))
289289
})?
290290
}
291291
Expr::WindowFunction(WindowFunction {

datafusion/functions-array/src/rewrite.rs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -182,20 +182,20 @@ impl FunctionRewrite for ArrayFunctionRewriter {
182182
/// Returns true if expr is a function call to the specified named function.
183183
/// Returns false otherwise.
184184
fn is_func(expr: &Expr, func_name: &str) -> bool {
185-
let Expr::ScalarFunction(ScalarFunction { func_def, args: _ }) = expr else {
185+
let Expr::ScalarFunction(ScalarFunction { func, args: _ }) = expr else {
186186
return false;
187187
};
188188

189-
func_def.name() == func_name
189+
func.name() == func_name
190190
}
191191

192192
/// Returns true if expr is a function call with one of the specified names
193193
fn is_one_of_func(expr: &Expr, func_names: &[&str]) -> bool {
194-
let Expr::ScalarFunction(ScalarFunction { func_def, args: _ }) = expr else {
194+
let Expr::ScalarFunction(ScalarFunction { func, args: _ }) = expr else {
195195
return false;
196196
};
197197

198-
func_names.contains(&func_def.name())
198+
func_names.contains(&func.name())
199199
}
200200

201201
/// returns Some(col) if this is Expr::Column

datafusion/functions/src/math/log.rs

Lines changed: 5 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -24,9 +24,7 @@ use datafusion_common::{
2424
};
2525
use datafusion_expr::expr::ScalarFunction;
2626
use datafusion_expr::simplify::{ExprSimplifyResult, SimplifyInfo};
27-
use datafusion_expr::{
28-
lit, ColumnarValue, Expr, FuncMonotonicity, ScalarFunctionDefinition,
29-
};
27+
use datafusion_expr::{lit, ColumnarValue, Expr, FuncMonotonicity, ScalarUDF};
3028

3129
use arrow::array::{ArrayRef, Float32Array, Float64Array};
3230
use datafusion_expr::TypeSignature::*;
@@ -178,8 +176,8 @@ impl ScalarUDFImpl for LogFunc {
178176
&info.get_data_type(&base)?,
179177
)?)))
180178
}
181-
Expr::ScalarFunction(ScalarFunction { func_def, mut args })
182-
if is_pow(&func_def) && args.len() == 2 && base == args[0] =>
179+
Expr::ScalarFunction(ScalarFunction { func, mut args })
180+
if is_pow(&func) && args.len() == 2 && base == args[0] =>
183181
{
184182
let b = args.pop().unwrap(); // length checked above
185183
Ok(ExprSimplifyResult::Simplified(b))
@@ -207,15 +205,8 @@ impl ScalarUDFImpl for LogFunc {
207205
}
208206

209207
/// Returns true if the function is `PowerFunc`
210-
fn is_pow(func_def: &ScalarFunctionDefinition) -> bool {
211-
match func_def {
212-
ScalarFunctionDefinition::UDF(fun) => fun
213-
.as_ref()
214-
.inner()
215-
.as_any()
216-
.downcast_ref::<PowerFunc>()
217-
.is_some(),
218-
}
208+
fn is_pow(func: &ScalarUDF) -> bool {
209+
func.inner().as_any().downcast_ref::<PowerFunc>().is_some()
219210
}
220211

221212
#[cfg(test)]

datafusion/functions/src/math/power.rs

Lines changed: 5 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ use datafusion_common::{
2323
};
2424
use datafusion_expr::expr::ScalarFunction;
2525
use datafusion_expr::simplify::{ExprSimplifyResult, SimplifyInfo};
26-
use datafusion_expr::{ColumnarValue, Expr, ScalarFunctionDefinition};
26+
use datafusion_expr::{ColumnarValue, Expr, ScalarUDF};
2727

2828
use arrow::array::{ArrayRef, Float64Array, Int64Array};
2929
use datafusion_expr::TypeSignature::*;
@@ -140,8 +140,8 @@ impl ScalarUDFImpl for PowerFunc {
140140
Expr::Literal(value) if value == ScalarValue::new_one(&exponent_type)? => {
141141
Ok(ExprSimplifyResult::Simplified(base))
142142
}
143-
Expr::ScalarFunction(ScalarFunction { func_def, mut args })
144-
if is_log(&func_def) && args.len() == 2 && base == args[0] =>
143+
Expr::ScalarFunction(ScalarFunction { func, mut args })
144+
if is_log(&func) && args.len() == 2 && base == args[0] =>
145145
{
146146
let b = args.pop().unwrap(); // length checked above
147147
Ok(ExprSimplifyResult::Simplified(b))
@@ -152,15 +152,8 @@ impl ScalarUDFImpl for PowerFunc {
152152
}
153153

154154
/// Return true if this function call is a call to `Log`
155-
fn is_log(func_def: &ScalarFunctionDefinition) -> bool {
156-
match func_def {
157-
ScalarFunctionDefinition::UDF(fun) => fun
158-
.as_ref()
159-
.inner()
160-
.as_any()
161-
.downcast_ref::<LogFunc>()
162-
.is_some(),
163-
}
155+
fn is_log(func: &ScalarUDF) -> bool {
156+
func.inner().as_any().downcast_ref::<LogFunc>().is_some()
164157
}
165158

166159
#[cfg(test)]

0 commit comments

Comments
 (0)