Add input_nullable for UDAF args StateField and Accumulator #11063

Closed
6 changes: 6 additions & 0 deletions datafusion/expr/src/function.rs
@@ -83,6 +83,9 @@ pub struct AccumulatorArgs<'a> {
/// The input type of the aggregate function.
pub input_type: &'a DataType,

/// If the input type is nullable.
pub input_nullable: bool,
Contributor

I think we only need nullable for state_field 🤔

Contributor Author

Yes, this is not used in the PR. I just thought it made sense to make the API more similar but will revert.

But I also noticed this is not enough to resolve the array_agg regression. There are two more limitations in the current UDAF API. First, here
https://github.com/eejbyfeldt/datafusion/blob/18042fd69138e19613844580408a71a200ea6caa/datafusion/physical-expr-common/src/aggregate/mod.rs#L287-L289
the nullability of the returned field is hardcoded to true and is not controllable from AggregateUDFImpl. What is the desired way to fix this?

Should the API be changed so that implementations provide a fn field method instead?

Or should we add a return_nullable method with a default implementation that returns false?

I also noticed that the current implementation for array_agg does not propagate the nullability of the input to the field in the returned array. This is probably because the return_type method does not have access to nullability. But probably something we want to be able to resolve in the long run.
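
To make the second option concrete, here is a minimal sketch of a trait with a defaulted return_nullable method. Everything in it is hypothetical: AggregateImpl is a simplified stand-in rather than the real AggregateUDFImpl, output_field returns a plain (name, nullable) pair instead of an Arrow Field, and the choice of which aggregates override the default is an assumption made for illustration.

```rust
/// Hypothetical, simplified stand-in for an aggregate UDF trait.
/// This is NOT the real `AggregateUDFImpl`.
trait AggregateImpl {
    fn name(&self) -> &str;

    /// Whether the aggregate's result can be null.
    /// The default of `false` follows the proposal in the comment above.
    fn return_nullable(&self) -> bool {
        false
    }
}

struct ArrayAgg;
struct Sum;

impl AggregateImpl for ArrayAgg {
    fn name(&self) -> &str {
        "array_agg"
    }
    // Keeps the default: assumed here to never produce a null list.
}

impl AggregateImpl for Sum {
    fn name(&self) -> &str {
        "sum"
    }

    // Assumed to produce null when it sees no values, so it overrides the default.
    fn return_nullable(&self) -> bool {
        true
    }
}

/// Simplified (name, nullable) pair standing in for an Arrow `Field`.
fn output_field(agg: &dyn AggregateImpl) -> (String, bool) {
    (agg.name().to_string(), agg.return_nullable())
}
```

With this shape, the code that builds the output field no longer hardcodes nullability; each aggregate decides it. A default of true would be the more conservative choice, since it matches the currently hardcoded value.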

Contributor

@jayzhan211 jayzhan211 Jun 22, 2024

the nullability of the returned field is hardcoded to true

I think we can use nullable for field().

    fn field(&self) -> Result<Field> {
        Ok(Field::new(&self.name, self.data_type.clone(), self.nullable))
    }

You can get input_nullable in create_aggregate_expr

    let expr = input_phy_exprs[0].clone();
    let input_nullable = expr.nullable(schema)?;
    Ok(Arc::new(AggregateFunctionExpr {
        fun: fun.clone(),
        args: input_phy_exprs.to_vec(),
        logical_args: input_exprs.to_vec(),
        data_type: fun.return_type(&input_exprs_types)?,
        name: name.into(),
        schema: schema.clone(),
        sort_exprs: sort_exprs.to_vec(),
        ordering_req: ordering_req.to_vec(),
        ignore_nulls,
        ordering_fields,
        is_distinct,
        input_type: input_exprs_types[0].clone(),
        input_nullable,
    }))

I also noticed that the current implementation for array_agg does not propagate the nullability of the input to the field in the returned array. This is probably because the return_type method does not have access to nullability. But probably something we want to be able to resolve in the long run.

I think nullable is set in both state_field and field, so the returned array should match their schema. 🤔

Contributor Author

@eejbyfeldt eejbyfeldt Jun 22, 2024

I think we can use nullable for field().

Looking closer at the code, I see that this will maintain the behavior of the old code. But it seems wrong to me to assume in general that an aggregate preserves the nullability of its input. Consider array_agg: there are two "nullable" flags in the return value, the "top level" value and the "field inside" the returned array. I think our array_agg (or at least a possible array_agg) will return an empty array when there are no values. This means that the nullability of the "top level" field should always be false regardless of input nullability, and the nullability that depends on the input is that of the "field inside" the array. Note that I think the existing code also does not implement this correctly.

I tried out the suggested fix and it breaks existing code, probably because it is wrong for some existing aggregates like sum, which might return null even if the input is not nullable. That is a further indication that this is not the correct way to go.
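
The difference between the two aggregates can be sketched without any Arrow types. The functions below are simplified, hypothetical models (not DataFusion's implementations) that use Option<i64> for SQL NULL; they only illustrate that output nullability does not have to follow input nullability.

```rust
/// Simplified model of `sum`: returns `None` (SQL NULL) when no non-null
/// value was seen, even if the input column itself is declared non-nullable.
fn sum(values: &[Option<i64>]) -> Option<i64> {
    let mut total: Option<i64> = None;
    // `flatten` skips the null entries and yields the remaining values.
    for v in values.iter().flatten() {
        total = Some(total.unwrap_or(0) + v);
    }
    total
}

/// Simplified model of a possible `array_agg`: always returns a list
/// (empty when there is no input), so the top-level value is never null.
/// Only the elements can be null, and only if the input can be.
fn array_agg(values: &[Option<i64>]) -> Vec<Option<i64>> {
    values.to_vec()
}
```

On an empty input, sum yields None while array_agg yields an empty list, so the top-level nullability of the two results differs even though the input is the same.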

This comment was marked as outdated.

Contributor

The nullability was introduced in #8055
There might be another way to fix #8055 🤔

Contributor

@jayzhan211 jayzhan211 Jun 23, 2024

I tested the code in #8032, and there is no error after I changed the "top level null" back to false 🤔

Contributor

@jayzhan211 jayzhan211 Jun 23, 2024

Ideally, the field for array_agg should look like this, where nullable is the nullability of the element, not the nullability of the List:

    fn field(&self) -> Result<Field> {
        Ok(Field::new_list(
            &self.name,
            // This should be the same as return type of AggregateFunction::ArrayAgg
            Field::new("item", self.input_data_type.clone(), self.nullable),
            false,
        ))
    }
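
The two levels of nullability can also be shown with a small self-contained model. The types below (Type, FieldModel) and the helper array_agg_field are hypothetical stand-ins for Arrow's DataType and Field, written only to illustrate the shape of the snippet above: the outer list field is never nullable, and the inner "item" field carries the input's nullability.

```rust
/// Hypothetical, simplified stand-in for Arrow's `DataType`.
#[derive(Debug, Clone, PartialEq)]
enum Type {
    Int64,
    List(Box<FieldModel>),
}

/// Hypothetical, simplified stand-in for Arrow's `Field`.
#[derive(Debug, Clone, PartialEq)]
struct FieldModel {
    name: String,
    data_type: Type,
    nullable: bool,
}

/// Builds the output field of an `array_agg`-like aggregate.
/// The input's nullability goes on the inner "item" field; the outer
/// list field is always non-nullable.
fn array_agg_field(name: &str, input_type: Type, input_nullable: bool) -> FieldModel {
    let item = FieldModel {
        name: "item".to_string(),
        data_type: input_type,
        nullable: input_nullable,
    };
    FieldModel {
        name: name.to_string(),
        data_type: Type::List(Box::new(item)),
        nullable: false,
    }
}
```

Calling array_agg_field with input_nullable set to true gives a non-nullable list whose elements are nullable, which is the layout described in this comment.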

Contributor Author

Yes, that was what I was trying to explain. I created #11093, which fixes that; it required some other changes to make that change possible.


/// The logical expression of arguments the aggregate function takes.
pub input_exprs: &'a [Expr],
}
@@ -98,6 +101,9 @@ pub struct StateFieldsArgs<'a> {
/// The input type of the aggregate function.
pub input_type: &'a DataType,

/// If the input type is nullable.
pub input_nullable: bool,

/// The return type of the aggregate function.
pub return_type: &'a DataType,

1 change: 1 addition & 0 deletions datafusion/functions-aggregate/src/first_last.rs
@@ -440,6 +440,7 @@ impl AggregateUDFImpl for LastValue {
let StateFieldsArgs {
name,
input_type,
input_nullable: _,
return_type: _,
ordering_fields,
is_distinct: _,
7 changes: 7 additions & 0 deletions datafusion/physical-expr-common/src/aggregate/mod.rs
@@ -87,6 +87,7 @@ pub fn create_aggregate_expr(
ordering_fields,
is_distinct,
input_type: input_exprs_types[0].clone(),
input_nullable: input_phy_exprs[0].nullable(&schema)?,
}))
}

@@ -248,6 +249,7 @@ pub struct AggregateFunctionExpr {
ordering_fields: Vec<Field>,
is_distinct: bool,
input_type: DataType,
input_nullable: bool,
}

impl AggregateFunctionExpr {
@@ -276,6 +278,7 @@ impl AggregateExpr for AggregateFunctionExpr {
let args = StateFieldsArgs {
name: &self.name,
input_type: &self.input_type,
input_nullable: self.input_nullable,
return_type: &self.data_type,
ordering_fields: &self.ordering_fields,
is_distinct: self.is_distinct,
@@ -296,6 +299,7 @@ impl AggregateExpr for AggregateFunctionExpr {
sort_exprs: &self.sort_exprs,
is_distinct: self.is_distinct,
input_type: &self.input_type,
input_nullable: self.input_nullable,
input_exprs: &self.logical_args,
name: &self.name,
};
@@ -311,6 +315,7 @@ impl AggregateExpr for AggregateFunctionExpr {
sort_exprs: &self.sort_exprs,
is_distinct: self.is_distinct,
input_type: &self.input_type,
input_nullable: self.input_nullable,
input_exprs: &self.logical_args,
name: &self.name,
};
@@ -381,6 +386,7 @@ impl AggregateExpr for AggregateFunctionExpr {
sort_exprs: &self.sort_exprs,
is_distinct: self.is_distinct,
input_type: &self.input_type,
input_nullable: self.input_nullable,
input_exprs: &self.logical_args,
name: &self.name,
};
@@ -395,6 +401,7 @@ impl AggregateExpr for AggregateFunctionExpr {
sort_exprs: &self.sort_exprs,
is_distinct: self.is_distinct,
input_type: &self.input_type,
input_nullable: self.input_nullable,
input_exprs: &self.logical_args,
name: &self.name,
};