Skip to content

Commit c902ed8

Browse files
yyy1000findepi
authored andcommitted
Move Covariance (Population) covar_pop to be a User Defined Aggregate Function (apache#10418)
* move covariance * add sqllogictest
1 parent 20e4682 commit c902ed8

File tree

15 files changed

+273
-424
lines changed

15 files changed

+273
-424
lines changed

datafusion/expr/src/aggregate_function.rs

Lines changed: 1 addition & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -63,8 +63,6 @@ pub enum AggregateFunction {
6363
Stddev,
6464
/// Standard Deviation (Population)
6565
StddevPop,
66-
/// Covariance (Population)
67-
CovariancePop,
6866
/// Correlation
6967
Correlation,
7068
/// Slope from linear regression
@@ -126,7 +124,6 @@ impl AggregateFunction {
126124
VariancePop => "VAR_POP",
127125
Stddev => "STDDEV",
128126
StddevPop => "STDDEV_POP",
129-
CovariancePop => "COVAR_POP",
130127
Correlation => "CORR",
131128
RegrSlope => "REGR_SLOPE",
132129
RegrIntercept => "REGR_INTERCEPT",
@@ -181,7 +178,6 @@ impl FromStr for AggregateFunction {
181178
"string_agg" => AggregateFunction::StringAgg,
182179
// statistical
183180
"corr" => AggregateFunction::Correlation,
184-
"covar_pop" => AggregateFunction::CovariancePop,
185181
"stddev" => AggregateFunction::Stddev,
186182
"stddev_pop" => AggregateFunction::StddevPop,
187183
"stddev_samp" => AggregateFunction::Stddev,
@@ -255,9 +251,6 @@ impl AggregateFunction {
255251
AggregateFunction::VariancePop => {
256252
variance_return_type(&coerced_data_types[0])
257253
}
258-
AggregateFunction::CovariancePop => {
259-
covariance_return_type(&coerced_data_types[0])
260-
}
261254
AggregateFunction::Correlation => {
262255
correlation_return_type(&coerced_data_types[0])
263256
}
@@ -349,8 +342,7 @@ impl AggregateFunction {
349342
Signature::uniform(1, NUMERICS.to_vec(), Volatility::Immutable)
350343
}
351344
AggregateFunction::NthValue => Signature::any(2, Volatility::Immutable),
352-
AggregateFunction::CovariancePop
353-
| AggregateFunction::Correlation
345+
AggregateFunction::Correlation
354346
| AggregateFunction::RegrSlope
355347
| AggregateFunction::RegrIntercept
356348
| AggregateFunction::RegrCount

datafusion/expr/src/type_coercion/aggregates.rs

Lines changed: 0 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -183,16 +183,6 @@ pub fn coerce_types(
183183
}
184184
Ok(vec![Float64, Float64])
185185
}
186-
AggregateFunction::CovariancePop => {
187-
if !is_covariance_support_arg_type(&input_types[0]) {
188-
return plan_err!(
189-
"The function {:?} does not support inputs of type {:?}.",
190-
agg_fun,
191-
input_types[0]
192-
);
193-
}
194-
Ok(vec![Float64, Float64])
195-
}
196186
AggregateFunction::Stddev | AggregateFunction::StddevPop => {
197187
if !is_stddev_support_arg_type(&input_types[0]) {
198188
return plan_err!(

datafusion/functions-aggregate/src/covariance.rs

Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,14 @@ make_udaf_expr_and_func!(
4343
covar_samp_udaf
4444
);
4545

46+
make_udaf_expr_and_func!(
47+
CovariancePopulation,
48+
covar_pop,
49+
y x,
50+
"Computes the population covariance.",
51+
covar_pop_udaf
52+
);
53+
4654
pub struct CovarianceSample {
4755
signature: Signature,
4856
aliases: Vec<String>,
@@ -120,6 +128,79 @@ impl AggregateUDFImpl for CovarianceSample {
120128
}
121129
}
122130

131+
pub struct CovariancePopulation {
132+
signature: Signature,
133+
}
134+
135+
impl Debug for CovariancePopulation {
136+
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
137+
f.debug_struct("CovariancePopulation")
138+
.field("name", &self.name())
139+
.field("signature", &self.signature)
140+
.finish()
141+
}
142+
}
143+
144+
impl Default for CovariancePopulation {
145+
fn default() -> Self {
146+
Self::new()
147+
}
148+
}
149+
150+
impl CovariancePopulation {
151+
pub fn new() -> Self {
152+
Self {
153+
signature: Signature::uniform(2, NUMERICS.to_vec(), Volatility::Immutable),
154+
}
155+
}
156+
}
157+
158+
impl AggregateUDFImpl for CovariancePopulation {
159+
fn as_any(&self) -> &dyn std::any::Any {
160+
self
161+
}
162+
163+
fn name(&self) -> &str {
164+
"covar_pop"
165+
}
166+
167+
fn signature(&self) -> &Signature {
168+
&self.signature
169+
}
170+
171+
fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
172+
if !arg_types[0].is_numeric() {
173+
return plan_err!("Covariance requires numeric input types");
174+
}
175+
176+
Ok(DataType::Float64)
177+
}
178+
179+
fn state_fields(
180+
&self,
181+
name: &str,
182+
_value_type: DataType,
183+
_ordering_fields: Vec<Field>,
184+
) -> Result<Vec<Field>> {
185+
Ok(vec![
186+
Field::new(format_state_name(name, "count"), DataType::UInt64, true),
187+
Field::new(format_state_name(name, "mean1"), DataType::Float64, true),
188+
Field::new(format_state_name(name, "mean2"), DataType::Float64, true),
189+
Field::new(
190+
format_state_name(name, "algo_const"),
191+
DataType::Float64,
192+
true,
193+
),
194+
])
195+
}
196+
197+
fn accumulator(&self, _acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
198+
Ok(Box::new(CovarianceAccumulator::try_new(
199+
StatsType::Population,
200+
)?))
201+
}
202+
}
203+
123204
/// An accumulator to compute covariance
124205
/// The algorithm used is an online implementation and numerically stable. It is derived from the following paper
125206
/// for calculating variance:

datafusion/functions-aggregate/src/lib.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,7 @@ pub fn register_all(registry: &mut dyn FunctionRegistry) -> Result<()> {
7575
let functions: Vec<Arc<AggregateUDF>> = vec![
7676
first_last::first_value_udaf(),
7777
covariance::covar_samp_udaf(),
78+
covariance::covar_pop_udaf(),
7879
];
7980

8081
functions.into_iter().try_for_each(|udf| {

datafusion/physical-expr/src/aggregate/build_in.rs

Lines changed: 0 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -181,17 +181,6 @@ pub fn create_aggregate_expr(
181181
(AggregateFunction::VariancePop, true) => {
182182
return not_impl_err!("VAR_POP(DISTINCT) aggregations are not available");
183183
}
184-
(AggregateFunction::CovariancePop, false) => {
185-
Arc::new(expressions::CovariancePop::new(
186-
input_phy_exprs[0].clone(),
187-
input_phy_exprs[1].clone(),
188-
name,
189-
data_type,
190-
))
191-
}
192-
(AggregateFunction::CovariancePop, true) => {
193-
return not_impl_err!("COVAR_POP(DISTINCT) aggregations are not available");
194-
}
195184
(AggregateFunction::Stddev, false) => Arc::new(expressions::Stddev::new(
196185
input_phy_exprs[0].clone(),
197186
name,

0 commit comments

Comments
 (0)