Skip to content

Commit 3773fb7

Browse files
authored
Convert approx_median to UDAF (#10840)
* move tdigest to physical-expr-common * move approx_percentile_cont_accumulator to function-aggregate * implement approx_median udaf * remove approx_median aggregation function * fix sqllogictests * add removed type tests * cargo fmt and clippy * add logical roundtrip test * fix dataframe test * fix test and proto gen * update lock in datafusion-cli * fix typo * fix test and doc * fix sql_integration * cargo fmt * follow the checking style like other udaf * add comment and modified dependency * update lock and fmt * add missing test annotation
1 parent e8fdc09 commit 3773fb7

File tree

26 files changed

+471
-497
lines changed

26 files changed

+471
-497
lines changed

datafusion-cli/Cargo.lock

Lines changed: 10 additions & 10 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

datafusion/core/tests/dataframe/dataframe_functions.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ use datafusion::assert_batches_eq;
3333
use datafusion_common::{DFSchema, ScalarValue};
3434
use datafusion_expr::expr::Alias;
3535
use datafusion_expr::ExprSchemable;
36+
use datafusion_functions_aggregate::expr_fn::approx_median;
3637

3738
fn test_schema() -> SchemaRef {
3839
Arc::new(Schema::new(vec![
@@ -342,7 +343,7 @@ async fn test_fn_approx_median() -> Result<()> {
342343

343344
let expected = [
344345
"+-----------------------+",
345-
"| APPROX_MEDIAN(test.b) |",
346+
"| approx_median(test.b) |",
346347
"+-----------------------+",
347348
"| 10 |",
348349
"+-----------------------+",

datafusion/expr/src/aggregate_function.rs

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -71,8 +71,6 @@ pub enum AggregateFunction {
7171
ApproxPercentileCont,
7272
/// Approximate continuous percentile function with weight
7373
ApproxPercentileContWithWeight,
74-
/// ApproxMedian
75-
ApproxMedian,
7674
/// Grouping
7775
Grouping,
7876
/// Bit And
@@ -112,7 +110,6 @@ impl AggregateFunction {
112110
RegrSXY => "REGR_SXY",
113111
ApproxPercentileCont => "APPROX_PERCENTILE_CONT",
114112
ApproxPercentileContWithWeight => "APPROX_PERCENTILE_CONT_WITH_WEIGHT",
115-
ApproxMedian => "APPROX_MEDIAN",
116113
Grouping => "GROUPING",
117114
BitAnd => "BIT_AND",
118115
BitOr => "BIT_OR",
@@ -161,7 +158,6 @@ impl FromStr for AggregateFunction {
161158
"regr_sxy" => AggregateFunction::RegrSXY,
162159
// approximate
163160
"approx_distinct" => AggregateFunction::ApproxDistinct,
164-
"approx_median" => AggregateFunction::ApproxMedian,
165161
"approx_percentile_cont" => AggregateFunction::ApproxPercentileCont,
166162
"approx_percentile_cont_with_weight" => {
167163
AggregateFunction::ApproxPercentileContWithWeight
@@ -234,7 +230,6 @@ impl AggregateFunction {
234230
AggregateFunction::ApproxPercentileContWithWeight => {
235231
Ok(coerced_data_types[0].clone())
236232
}
237-
AggregateFunction::ApproxMedian => Ok(coerced_data_types[0].clone()),
238233
AggregateFunction::Grouping => Ok(DataType::Int32),
239234
AggregateFunction::NthValue => Ok(coerced_data_types[0].clone()),
240235
AggregateFunction::StringAgg => Ok(DataType::LargeUtf8),
@@ -284,7 +279,8 @@ impl AggregateFunction {
284279
AggregateFunction::BoolAnd | AggregateFunction::BoolOr => {
285280
Signature::uniform(1, vec![DataType::Boolean], Volatility::Immutable)
286281
}
287-
AggregateFunction::Avg | AggregateFunction::ApproxMedian => {
282+
283+
AggregateFunction::Avg => {
288284
Signature::uniform(1, NUMERICS.to_vec(), Volatility::Immutable)
289285
}
290286
AggregateFunction::NthValue => Signature::any(2, Volatility::Immutable),

datafusion/expr/src/expr_fn.rs

Lines changed: 0 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -284,18 +284,6 @@ pub fn approx_distinct(expr: Expr) -> Expr {
284284
))
285285
}
286286

287-
/// Calculate an approximation of the median for `expr`.
288-
pub fn approx_median(expr: Expr) -> Expr {
289-
Expr::AggregateFunction(AggregateFunction::new(
290-
aggregate_function::AggregateFunction::ApproxMedian,
291-
vec![expr],
292-
false,
293-
None,
294-
None,
295-
None,
296-
))
297-
}
298-
299287
/// Calculate an approximation of the specified `percentile` for `expr`.
300288
pub fn approx_percentile_cont(expr: Expr, percentile: Expr) -> Expr {
301289
Expr::AggregateFunction(AggregateFunction::new(

datafusion/expr/src/type_coercion/aggregates.rs

Lines changed: 0 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -231,16 +231,6 @@ pub fn coerce_types(
231231
}
232232
Ok(input_types.to_vec())
233233
}
234-
AggregateFunction::ApproxMedian => {
235-
if !is_approx_percentile_cont_supported_arg_type(&input_types[0]) {
236-
return plan_err!(
237-
"The function {:?} does not support inputs of type {:?}.",
238-
agg_fun,
239-
input_types[0]
240-
);
241-
}
242-
Ok(input_types.to_vec())
243-
}
244234
AggregateFunction::NthValue => Ok(input_types.to_vec()),
245235
AggregateFunction::Grouping => Ok(vec![input_types[0].clone()]),
246236
AggregateFunction::StringAgg => {
Lines changed: 129 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,129 @@
1+
// Licensed to the Apache Software Foundation (ASF) under one
2+
// or more contributor license agreements. See the NOTICE file
3+
// distributed with this work for additional information
4+
// regarding copyright ownership. The ASF licenses this file
5+
// to you under the Apache License, Version 2.0 (the
6+
// "License"); you may not use this file except in compliance
7+
// with the License. You may obtain a copy of the License at
8+
//
9+
// http://www.apache.org/licenses/LICENSE-2.0
10+
//
11+
// Unless required by applicable law or agreed to in writing,
12+
// software distributed under the License is distributed on an
13+
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
// KIND, either express or implied. See the License for the
15+
// specific language governing permissions and limitations
16+
// under the License.
17+
18+
//! Defines physical expressions for APPROX_MEDIAN that can be evaluated at runtime during query execution
19+
20+
use std::any::Any;
21+
use std::fmt::Debug;
22+
23+
use arrow::{datatypes::DataType, datatypes::Field};
24+
use arrow_schema::DataType::{Float64, UInt64};
25+
26+
use datafusion_common::{not_impl_err, plan_err, Result};
27+
use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs};
28+
use datafusion_expr::type_coercion::aggregates::NUMERICS;
29+
use datafusion_expr::utils::format_state_name;
30+
use datafusion_expr::{Accumulator, AggregateUDFImpl, Signature, Volatility};
31+
use datafusion_physical_expr_common::aggregate::utils::down_cast_any_ref;
32+
33+
use crate::approx_percentile_cont::ApproxPercentileAccumulator;
34+
35+
// Registers `ApproxMedian` through the crate's UDAF helper macro.
// NOTE(review): presumably this expands to an `approx_median(expression)`
// expression-builder function (documented with the string below) and an
// `approx_median_udaf()` accessor — confirm against the macro definition.
make_udaf_expr_and_func!(
36+
ApproxMedian,
37+
approx_median,
38+
expression,
39+
"Computes the approximate median of a set of numbers",
40+
approx_median_udaf
41+
);
42+
43+
/// APPROX_MEDIAN aggregate expression
44+
pub struct ApproxMedian {
45+
signature: Signature,
46+
}
47+
48+
impl Debug for ApproxMedian {
49+
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
50+
f.debug_struct("ApproxMedian")
51+
.field("name", &self.name())
52+
.field("signature", &self.signature)
53+
.finish()
54+
}
55+
}
56+
57+
impl Default for ApproxMedian {
58+
fn default() -> Self {
59+
Self::new()
60+
}
61+
}
62+
63+
impl ApproxMedian {
64+
/// Create a new APPROX_MEDIAN aggregate function
65+
pub fn new() -> Self {
66+
Self {
67+
signature: Signature::uniform(1, NUMERICS.to_vec(), Volatility::Immutable),
68+
}
69+
}
70+
}
71+
72+
impl AggregateUDFImpl for ApproxMedian {
73+
/// Return a reference to Any that can be used for downcasting
74+
fn as_any(&self) -> &dyn Any {
75+
self
76+
}
77+
78+
fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<Field>> {
79+
Ok(vec![
80+
Field::new(format_state_name(args.name, "max_size"), UInt64, false),
81+
Field::new(format_state_name(args.name, "sum"), Float64, false),
82+
Field::new(format_state_name(args.name, "count"), Float64, false),
83+
Field::new(format_state_name(args.name, "max"), Float64, false),
84+
Field::new(format_state_name(args.name, "min"), Float64, false),
85+
Field::new_list(
86+
format_state_name(args.name, "centroids"),
87+
Field::new("item", Float64, true),
88+
false,
89+
),
90+
])
91+
}
92+
93+
fn name(&self) -> &str {
94+
"approx_median"
95+
}
96+
97+
fn signature(&self) -> &Signature {
98+
&self.signature
99+
}
100+
101+
fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
102+
if !arg_types[0].is_numeric() {
103+
return plan_err!("ApproxMedian requires numeric input types");
104+
}
105+
Ok(arg_types[0].clone())
106+
}
107+
108+
fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
109+
if acc_args.is_distinct {
110+
return not_impl_err!(
111+
"APPROX_MEDIAN(DISTINCT) aggregations are not available"
112+
);
113+
}
114+
115+
Ok(Box::new(ApproxPercentileAccumulator::new(
116+
0.5_f64,
117+
acc_args.input_type.clone(),
118+
)))
119+
}
120+
}
121+
122+
impl PartialEq<dyn Any> for ApproxMedian {
123+
fn eq(&self, other: &dyn Any) -> bool {
124+
down_cast_any_ref(other)
125+
.downcast_ref::<Self>()
126+
.map(|x| self.signature == x.signature)
127+
.unwrap_or(false)
128+
}
129+
}

0 commit comments

Comments
 (0)