Skip to content

Commit fbc371e

Browse files
alambappletreeisyellow
authored andcommitted
Fix group by aliased expression in LogicalPLanBuilder::aggregate (#8629)
1 parent 7ad104e commit fbc371e

File tree

2 files changed

+73
-21
lines changed

2 files changed

+73
-21
lines changed

datafusion/core/src/dataframe/mod.rs

Lines changed: 34 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1819,8 +1819,8 @@ mod tests {
18191819
let df_results = collect(physical_plan, ctx.task_ctx()).await?;
18201820

18211821
#[rustfmt::skip]
1822-
assert_batches_sorted_eq!(
1823-
[ "+----+",
1822+
assert_batches_sorted_eq!([
1823+
"+----+",
18241824
"| id |",
18251825
"+----+",
18261826
"| 1 |",
@@ -1831,6 +1831,38 @@ mod tests {
18311831
Ok(())
18321832
}
18331833

1834+
#[tokio::test]
1835+
async fn test_aggregate_alias() -> Result<()> {
1836+
let df = test_table().await?;
1837+
1838+
let df = df
1839+
// GROUP BY `c2 + 1`
1840+
.aggregate(vec![col("c2") + lit(1)], vec![])?
1841+
// SELECT `c2 + 1` as c2
1842+
.select(vec![(col("c2") + lit(1)).alias("c2")])?
1843+
// GROUP BY c2 as "c2" (alias in expr is not supported by SQL)
1844+
.aggregate(vec![col("c2").alias("c2")], vec![])?;
1845+
1846+
let df_results = df.collect().await?;
1847+
1848+
#[rustfmt::skip]
1849+
assert_batches_sorted_eq!([
1850+
"+----+",
1851+
"| c2 |",
1852+
"+----+",
1853+
"| 2 |",
1854+
"| 3 |",
1855+
"| 4 |",
1856+
"| 5 |",
1857+
"| 6 |",
1858+
"+----+",
1859+
],
1860+
&df_results
1861+
);
1862+
1863+
Ok(())
1864+
}
1865+
18341866
#[tokio::test]
18351867
async fn test_distinct() -> Result<()> {
18361868
let t = test_table().await?;

datafusion/expr/src/logical_plan/builder.rs

Lines changed: 39 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -904,27 +904,11 @@ impl LogicalPlanBuilder {
904904
group_expr: impl IntoIterator<Item = impl Into<Expr>>,
905905
aggr_expr: impl IntoIterator<Item = impl Into<Expr>>,
906906
) -> Result<Self> {
907-
let mut group_expr = normalize_cols(group_expr, &self.plan)?;
907+
let group_expr = normalize_cols(group_expr, &self.plan)?;
908908
let aggr_expr = normalize_cols(aggr_expr, &self.plan)?;
909909

910-
// Rewrite groupby exprs according to functional dependencies
911-
let group_by_expr_names = group_expr
912-
.iter()
913-
.map(|group_by_expr| group_by_expr.display_name())
914-
.collect::<Result<Vec<_>>>()?;
915-
let schema = self.plan.schema();
916-
if let Some(target_indices) =
917-
get_target_functional_dependencies(schema, &group_by_expr_names)
918-
{
919-
for idx in target_indices {
920-
let field = schema.field(idx);
921-
let expr =
922-
Expr::Column(Column::new(field.qualifier().cloned(), field.name()));
923-
if !group_expr.contains(&expr) {
924-
group_expr.push(expr);
925-
}
926-
}
927-
}
910+
let group_expr =
911+
add_group_by_exprs_from_dependencies(group_expr, self.plan.schema())?;
928912
Aggregate::try_new(Arc::new(self.plan), group_expr, aggr_expr)
929913
.map(LogicalPlan::Aggregate)
930914
.map(Self::from)
@@ -1189,6 +1173,42 @@ pub fn build_join_schema(
11891173
schema.with_functional_dependencies(func_dependencies)
11901174
}
11911175

1176+
/// Add additional "synthetic" group by expressions based on functional
1177+
/// dependencies.
1178+
///
1179+
/// For example, if we are grouping on `[c1]`, and we know from
1180+
/// functional dependencies that column `c1` determines `c2`, this function
1181+
/// adds `c2` to the group by list.
1182+
///
1183+
/// This allows MySQL style selects like
1184+
/// `SELECT col FROM t WHERE pk = 5` if col is unique
1185+
fn add_group_by_exprs_from_dependencies(
1186+
mut group_expr: Vec<Expr>,
1187+
schema: &DFSchemaRef,
1188+
) -> Result<Vec<Expr>> {
1189+
// Names of the fields produced by the GROUP BY exprs for example, `GROUP BY
1190+
// c1 + 1` produces an output field named `"c1 + 1"`
1191+
let mut group_by_field_names = group_expr
1192+
.iter()
1193+
.map(|e| e.display_name())
1194+
.collect::<Result<Vec<_>>>()?;
1195+
1196+
if let Some(target_indices) =
1197+
get_target_functional_dependencies(schema, &group_by_field_names)
1198+
{
1199+
for idx in target_indices {
1200+
let field = schema.field(idx);
1201+
let expr =
1202+
Expr::Column(Column::new(field.qualifier().cloned(), field.name()));
1203+
let expr_name = expr.display_name()?;
1204+
if !group_by_field_names.contains(&expr_name) {
1205+
group_by_field_names.push(expr_name);
1206+
group_expr.push(expr);
1207+
}
1208+
}
1209+
}
1210+
Ok(group_expr)
1211+
}
11921212
/// Errors if one or more expressions have equal names.
11931213
pub(crate) fn validate_unique_names<'a>(
11941214
node_name: &str,

0 commit comments

Comments
 (0)