Skip to content

Add customizable equality and hash functions to UDFs #11392

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Jul 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
79 changes: 74 additions & 5 deletions datafusion/core/tests/user_defined/user_defined_aggregates.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,19 @@
//! This module contains end to end demonstrations of creating
//! user defined aggregate functions

use arrow::{array::AsArray, datatypes::Fields};
use arrow_array::{types::UInt64Type, Int32Array, PrimitiveArray, StructArray};
use arrow_schema::Schema;
use std::hash::{DefaultHasher, Hash, Hasher};
use std::sync::{
atomic::{AtomicBool, Ordering},
Arc,
};

use arrow::{array::AsArray, datatypes::Fields};
use arrow_array::{
types::UInt64Type, Int32Array, PrimitiveArray, StringArray, StructArray,
};
use arrow_schema::Schema;

use datafusion::dataframe::DataFrame;
use datafusion::datasource::MemTable;
use datafusion::test_util::plan_and_collect;
use datafusion::{
Expand All @@ -45,8 +50,8 @@ use datafusion::{
};
use datafusion_common::{assert_contains, cast::as_primitive_array, exec_err};
use datafusion_expr::{
create_udaf, function::AccumulatorArgs, AggregateUDFImpl, GroupsAccumulator,
SimpleAggregateUDF,
col, create_udaf, function::AccumulatorArgs, AggregateUDFImpl, GroupsAccumulator,
LogicalPlanBuilder, SimpleAggregateUDF,
};
use datafusion_functions_aggregate::average::AvgAccumulator;

Expand Down Expand Up @@ -377,6 +382,55 @@ async fn test_groups_accumulator() -> Result<()> {
Ok(())
}

/// Two instances of the same aggregate UDF type that differ only in a
/// parameter (`result`) must be kept as two distinct aggregate expressions.
///
/// Both instances print under the same name (`geo_mean`) in the plan and take
/// the same argument, so only the customized `equals` / `hash_value` on
/// `TestGroupsAccumulator` can tell them apart. The output must therefore
/// contain `1` for alias `a` and `2` for alias `b`.
#[tokio::test]
async fn test_parameterized_aggregate_udf() -> Result<()> {
    let batch = RecordBatch::try_from_iter([(
        "text",
        Arc::new(StringArray::from(vec!["foo"])) as ArrayRef,
    )])?;

    let ctx = SessionContext::new();
    ctx.register_batch("t", batch)?;
    let t = ctx.table("t").await?;
    let signature = Signature::exact(vec![DataType::Utf8], Volatility::Immutable);
    let udf1 = AggregateUDF::from(TestGroupsAccumulator {
        signature: signature.clone(),
        result: 1,
    });
    let udf2 = AggregateUDF::from(TestGroupsAccumulator {
        signature: signature.clone(),
        result: 2,
    });

    let plan = LogicalPlanBuilder::from(t.into_optimized_plan()?)
        .aggregate(
            [col("text")],
            [
                udf1.call(vec![col("text")]).alias("a"),
                udf2.call(vec![col("text")]).alias("b"),
            ],
        )?
        .build()?;

    // Child plan nodes are indented by two spaces per nesting level.
    assert_eq!(
        format!("{plan:?}"),
        "Aggregate: groupBy=[[t.text]], aggr=[[geo_mean(t.text) AS a, geo_mean(t.text) AS b]]\n  TableScan: t projection=[text]"
    );

    let actual = DataFrame::new(ctx.state(), plan).collect().await?;
    // Every row is padded to the width of the border line.
    let expected = [
        "+------+---+---+",
        "| text | a | b |",
        "+------+---+---+",
        "| foo  | 1 | 2 |",
        "+------+---+---+",
    ];
    assert_batches_eq!(expected, &actual);

    ctx.deregister_table("t")?;
    Ok(())
}

/// Returns a context with a table "t" and the "first" and "time_sum"
/// aggregate functions registered.
///
Expand Down Expand Up @@ -735,6 +789,21 @@ impl AggregateUDFImpl for TestGroupsAccumulator {
) -> Result<Box<dyn GroupsAccumulator>> {
Ok(Box::new(self.clone()))
}

// Equal only to another `TestGroupsAccumulator` that has the same `result`
// and the same `signature`; any other implementation compares unequal.
fn equals(&self, other: &dyn AggregateUDFImpl) -> bool {
    other
        .as_any()
        .downcast_ref::<TestGroupsAccumulator>()
        .is_some_and(|that| {
            self.result == that.result && self.signature == that.signature
        })
}

// Hashes exactly the fields that `equals` compares (`signature`, `result`),
// so instances that compare equal also hash equal.
fn hash_value(&self) -> u64 {
    let mut state = DefaultHasher::new();
    self.signature.hash(&mut state);
    self.result.hash(&mut state);
    state.finish()
}
}

impl Accumulator for TestGroupsAccumulator {
Expand Down
128 changes: 125 additions & 3 deletions datafusion/core/tests/user_defined/user_defined_scalar_functions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,20 @@
// under the License.

use std::any::Any;
use std::hash::{DefaultHasher, Hash, Hasher};
use std::sync::Arc;

use arrow::compute::kernels::numeric::add;
use arrow_array::{ArrayRef, Float32Array, Float64Array, Int32Array, RecordBatch};
use arrow_array::builder::BooleanBuilder;
use arrow_array::cast::AsArray;
use arrow_array::{
Array, ArrayRef, Float32Array, Float64Array, Int32Array, RecordBatch, StringArray,
};
use arrow_schema::{DataType, Field, Schema};
use parking_lot::Mutex;
use regex::Regex;
use sqlparser::ast::Ident;

use datafusion::execution::context::{FunctionFactory, RegisterFunction, SessionState};
use datafusion::prelude::*;
use datafusion::{execution::registry::FunctionRegistry, test_util};
Expand All @@ -37,8 +46,6 @@ use datafusion_expr::{
Volatility,
};
use datafusion_functions_array::range::range_udf;
use parking_lot::Mutex;
use sqlparser::ast::Ident;

/// test that casting happens on udfs.
/// c11 is f32, but `custom_sqrt` requires f64. Casting happens but the logical plan and
Expand Down Expand Up @@ -1021,6 +1028,121 @@ async fn create_scalar_function_from_sql_statement_postgres_syntax() -> Result<(
Ok(())
}

/// Scalar UDF used by the tests below that is parameterized by a regular
/// expression: the pattern is part of the function instance itself.
#[derive(Debug)]
struct MyRegexUdf {
// Argument signature reported through `ScalarUDFImpl::signature`.
signature: Signature,
// Compiled pattern that input strings are matched against.
regex: Regex,
}

impl MyRegexUdf {
    /// Builds a UDF instance that matches its single `Utf8` argument against
    /// `pattern`.
    ///
    /// Panics (with the message `regex`) when `pattern` fails to compile.
    fn new(pattern: &str) -> Self {
        let signature = Signature::exact(vec![DataType::Utf8], Volatility::Immutable);
        let regex = Regex::new(pattern).expect("regex");
        Self { signature, regex }
    }

    /// Returns whether `value` matches the pattern; a missing input (`None`)
    /// yields a missing output.
    fn matches(&self, value: Option<&str>) -> Option<bool> {
        value.map(|text| self.regex.is_match(text))
    }
}

impl ScalarUDFImpl for MyRegexUdf {
    fn as_any(&self) -> &dyn Any {
        self
    }

    /// Every instance reports the same name regardless of its pattern, so the
    /// name alone cannot distinguish two differently parameterized instances.
    fn name(&self) -> &str {
        "regex_udf"
    }

    fn signature(&self) -> &Signature {
        &self.signature
    }

    /// Returns `Boolean` for a single `Utf8` argument and a planning error
    /// for anything else.
    fn return_type(&self, args: &[DataType]) -> Result<DataType> {
        if matches!(args, [DataType::Utf8]) {
            Ok(DataType::Boolean)
        } else {
            plan_err!("regex_udf only accepts a Utf8 argument")
        }
    }

    /// Matches the single argument against the pattern.
    ///
    /// A scalar input produces a scalar boolean; an array input produces a
    /// boolean array of the same length. Missing inputs stay missing.
    ///
    /// # Errors
    ///
    /// Returns an execution error when the argument list is not exactly one
    /// `Utf8` scalar or one `Utf8` array.
    fn invoke(&self, args: &[ColumnarValue]) -> Result<ColumnarValue> {
        match args {
            [ColumnarValue::Scalar(ScalarValue::Utf8(value))] => {
                Ok(ColumnarValue::Scalar(ScalarValue::Boolean(
                    self.matches(value.as_deref()),
                )))
            }
            [ColumnarValue::Array(values)] => {
                // Report a non-Utf8 array as an error, like the scalar arm
                // does, instead of panicking inside the downcast.
                let Some(strings) = values.as_string_opt::<i32>() else {
                    return exec_err!("regex_udf only accepts a Utf8 argument");
                };
                let mut builder = BooleanBuilder::with_capacity(strings.len());
                for value in strings {
                    builder.append_option(self.matches(value));
                }
                Ok(ColumnarValue::Array(Arc::new(builder.finish())))
            }
            _ => exec_err!("regex_udf only accepts a Utf8 argument"),
        }
    }

    /// Two instances are equal when their pattern text is identical; any
    /// other `ScalarUDFImpl` compares unequal.
    fn equals(&self, other: &dyn ScalarUDFImpl) -> bool {
        if let Some(other) = other.as_any().downcast_ref::<MyRegexUdf>() {
            self.regex.as_str() == other.regex.as_str()
        } else {
            false
        }
    }

    /// Hashes the pattern text, the same data `equals` compares.
    fn hash_value(&self) -> u64 {
        let hasher = &mut DefaultHasher::new();
        self.regex.as_str().hash(hasher);
        hasher.finish()
    }
}

#[tokio::test]
async fn test_parameterized_scalar_udf() -> Result<()> {
let batch = RecordBatch::try_from_iter([(
"text",
Arc::new(StringArray::from(vec!["foo", "bar", "foobar", "barfoo"])) as ArrayRef,
)])?;

let ctx = SessionContext::new();
ctx.register_batch("t", batch)?;
let t = ctx.table("t").await?;
let foo_udf = ScalarUDF::from(MyRegexUdf::new("fo{2}"));
let bar_udf = ScalarUDF::from(MyRegexUdf::new("[Bb]ar"));

let plan = LogicalPlanBuilder::from(t.into_optimized_plan()?)
.filter(
foo_udf
.call(vec![col("text")])
.and(bar_udf.call(vec![col("text")])),
)?
.filter(col("text").is_not_null())?
.build()?;

assert_eq!(
format!("{plan:?}"),
"Filter: t.text IS NOT NULL\n Filter: regex_udf(t.text) AND regex_udf(t.text)\n TableScan: t projection=[text]"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Without the changes in this PR, are the expressions combined by CSE or something?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This particular case is deduplicated in PushDownFilter:

            LogicalPlan::Filter(child_filter) => {
                let parents_predicates = split_conjunction_owned(filter.predicate);

                // remove duplicated filters
                let child_predicates = split_conjunction_owned(child_filter.predicate);
                let new_predicates = parents_predicates
                    .into_iter()
                    .chain(child_predicates)
                    // use IndexSet to remove dupes while preserving predicate order
                    .collect::<IndexSet<_>>()
                    .into_iter()
                    .collect::<Vec<_>>();

);

let actual = DataFrame::new(ctx.state(), plan).collect().await?;
let expected = [
"+--------+",
"| text |",
"+--------+",
"| foobar |",
"| barfoo |",
"+--------+",
];
assert_batches_eq!(expected, &actual);

ctx.deregister_table("t")?;
Ok(())
}

fn create_udf_context() -> SessionContext {
let ctx = SessionContext::new();
// register a custom UDF
Expand Down
73 changes: 59 additions & 14 deletions datafusion/expr/src/udaf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,17 @@

//! [`AggregateUDF`]: User Defined Aggregate Functions

use std::any::Any;
use std::fmt::{self, Debug, Formatter};
use std::hash::{DefaultHasher, Hash, Hasher};
use std::sync::Arc;
use std::vec;

use arrow::datatypes::{DataType, Field};
use sqlparser::ast::NullTreatment;

use datafusion_common::{exec_err, not_impl_err, plan_err, Result};

use crate::expr::AggregateFunction;
use crate::function::{
AccumulatorArgs, AggregateFunctionSimplification, StateFieldsArgs,
Expand All @@ -26,13 +37,6 @@ use crate::utils::format_state_name;
use crate::utils::AggregateOrderSensitivity;
use crate::{Accumulator, Expr};
use crate::{AccumulatorFactoryFunction, ReturnTypeFunction, Signature};
use arrow::datatypes::{DataType, Field};
use datafusion_common::{exec_err, not_impl_err, plan_err, Result};
use sqlparser::ast::NullTreatment;
use std::any::Any;
use std::fmt::{self, Debug, Formatter};
use std::sync::Arc;
use std::vec;

/// Logical representation of a user-defined [aggregate function] (UDAF).
///
Expand Down Expand Up @@ -72,20 +76,19 @@ pub struct AggregateUDF {

impl PartialEq for AggregateUDF {
fn eq(&self, other: &Self) -> bool {
self.name() == other.name() && self.signature() == other.signature()
self.inner.equals(other.inner.as_ref())
}
}

impl Eq for AggregateUDF {}

impl std::hash::Hash for AggregateUDF {
fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
self.name().hash(state);
self.signature().hash(state);
impl Hash for AggregateUDF {
fn hash<H: Hasher>(&self, state: &mut H) {
self.inner.hash_value().hash(state)
}
}

impl std::fmt::Display for AggregateUDF {
impl fmt::Display for AggregateUDF {
fn fmt(&self, f: &mut Formatter) -> fmt::Result {
write!(f, "{}", self.name())
}
Expand Down Expand Up @@ -280,7 +283,7 @@ where
/// #[derive(Debug, Clone)]
/// struct GeoMeanUdf {
/// signature: Signature
/// };
/// }
///
/// impl GeoMeanUdf {
/// fn new() -> Self {
Expand Down Expand Up @@ -507,6 +510,33 @@ pub trait AggregateUDFImpl: Debug + Send + Sync {
fn coerce_types(&self, _arg_types: &[DataType]) -> Result<Vec<DataType>> {
// Default: the argument types are ignored and a "not implemented" error
// naming this function is returned.
not_impl_err!("Function {} does not implement coerce_types", self.name())
}

/// Returns `true` when this aggregate UDF should be treated as equal to
/// `other`.
///
/// Override this to customize equality, for example when the function
/// carries parameters beyond its name and signature. The result must agree
/// with [`Self::hash_value`] and obey the same laws as [`Eq`]:
///
/// - reflexive: `a.equals(a)`;
/// - symmetric: `a.equals(b)` implies `b.equals(a)`;
/// - transitive: `a.equals(b)` and `b.equals(c)` implies `a.equals(c)`.
///
/// The default implementation compares [`Self::name`] and
/// [`Self::signature`].
fn equals(&self, other: &dyn AggregateUDFImpl) -> bool {
    let same_name = self.name() == other.name();
    // The signature is consulted only when the names already agree.
    same_name && self.signature() == other.signature()
}

/// Computes the hash value of this aggregate UDF.
///
/// Override this together with [`Self::equals`]: as with [`Hash`] and
/// [`Eq`], two UDFs for which [`Self::equals`] returns true must produce the
/// same `hash_value`.
///
/// The default implementation hashes [`Self::name`] and
/// [`Self::signature`].
fn hash_value(&self) -> u64 {
    let mut state = DefaultHasher::new();
    self.name().hash(&mut state);
    self.signature().hash(&mut state);
    state.finish()
}
}

pub enum ReversedUDAF {
Expand Down Expand Up @@ -562,6 +592,21 @@ impl AggregateUDFImpl for AliasedAggregateUDFImpl {
// Returns the alias list stored on this wrapper.
fn aliases(&self) -> &[String] {
&self.aliases
}

fn equals(&self, other: &dyn AggregateUDFImpl) -> bool {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This makes sense to me, as the name and signature are the same as the inner UDF's.

if let Some(other) = other.as_any().downcast_ref::<AliasedAggregateUDFImpl>() {
self.inner.equals(other.inner.as_ref()) && self.aliases == other.aliases
} else {
false
}
}

// Combines the wrapped function's own hash value with the alias list.
fn hash_value(&self) -> u64 {
    let mut state = DefaultHasher::new();
    self.inner.hash_value().hash(&mut state);
    self.aliases.hash(&mut state);
    state.finish()
}
}

/// Implementation of [`AggregateUDFImpl`] that wraps the function style pointers
Expand Down
Loading